Ref. + AVX code & generic tests + benches (#85)

This commit is contained in:
Jean-Philippe Bossuat
2025-09-15 16:16:11 +02:00
committed by GitHub
parent 99b9e3e10e
commit 56dbd29c59
286 changed files with 27797 additions and 7270 deletions

View File

@@ -17,7 +17,10 @@ rand = {workspace = true}
rand_distr = {workspace = true}
rand_core = {workspace = true}
byteorder = {workspace = true}
once_cell = {workspace = true}
rand_chacha = "0.9.0"
bytemuck = "1.23.2"
[build-dependencies]
cmake = "0.1.54"

View File

@@ -1,4 +1,6 @@
use crate::layouts::{Backend, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef};
use crate::layouts::{
Backend, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
};
/// Allocates as [crate::layouts::SvpPPol].
pub trait SvpPPolAlloc<B: Backend> {
@@ -25,8 +27,26 @@ pub trait SvpPrepare<B: Backend> {
}
/// Apply a scalar-vector product between `a[a_col]` and `b[b_col]` and stores the result on `res[res_col]`.
pub trait SvpApply<B: Backend> {
fn svp_apply<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
pub trait SvpApplyDft<B: Backend> {
fn svp_apply_dft<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: SvpPPolToRef<B>,
C: VecZnxToRef;
}
/// Apply a scalar-vector product between `a[a_col]` and `b[b_col]` and stores the result on `res[res_col]`.
pub trait SvpApplyDftToDft<B: Backend> {
fn svp_apply_dft_to_dft<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: SvpPPolToRef<B>,
C: VecZnxDftToRef<B>;
}
/// Apply a scalar-vector product between `a[a_col]` and `b[b_col]` and adds the result on `res[res_col]`.
pub trait SvpApplyDftToDftAdd<B: Backend> {
fn svp_apply_dft_to_dft_add<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: SvpPPolToRef<B>,
@@ -34,8 +54,8 @@ pub trait SvpApply<B: Backend> {
}
/// Apply a scalar-vector product between `res[res_col]` and `a[a_col]` and stores the result on `res[res_col]`.
pub trait SvpApplyInplace<B: Backend> {
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
pub trait SvpApplyDftToDftInplace<B: Backend> {
fn svp_apply_dft_to_dft_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: SvpPPolToRef<B>;

View File

@@ -1,5 +1,3 @@
use rand_distr::Distribution;
use crate::{
layouts::{Backend, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
source::Source,
@@ -42,6 +40,16 @@ pub trait VecZnxAddInplace {
A: VecZnxToRef;
}
pub trait VecZnxAddScalar {
/// Adds the selected column of `a` on the selected column and limb of `b` and writes the result on the selected column of `res`.
#[allow(clippy::too_many_arguments)]
fn vec_znx_add_scalar<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize, b_limb: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
B: VecZnxToRef;
}
pub trait VecZnxAddScalarInplace {
/// Adds the selected column of `a` on the selected column and limb of `res`.
fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
@@ -79,6 +87,16 @@ pub trait VecZnxSubBAInplace {
A: VecZnxToRef;
}
pub trait VecZnxSubScalar {
/// Subtracts the selected column of `a` on the selected column and limb of `b` and writes the result on the selected column of `res`.
#[allow(clippy::too_many_arguments)]
fn vec_znx_sub_scalar<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize, b_limb: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
B: VecZnxToRef;
}
pub trait VecZnxSubScalarInplace {
/// Subtracts the selected column of `a` on the selected column and limb of `res`.
fn vec_znx_sub_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
@@ -102,31 +120,61 @@ pub trait VecZnxNegateInplace {
A: VecZnxToMut;
}
pub trait VecZnxLshInplace {
pub trait VecZnxLshTmpBytes {
fn vec_znx_lsh_tmp_bytes(&self) -> usize;
}
pub trait VecZnxLsh<B: Backend> {
/// Left shift by k bits all columns of `a`.
fn vec_znx_lsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
#[allow(clippy::too_many_arguments)]
fn vec_znx_lsh<R, A>(&self, basek: usize, k: usize, r: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxRshTmpBytes {
fn vec_znx_rsh_tmp_bytes(&self) -> usize;
}
pub trait VecZnxRsh<B: Backend> {
/// Right shift by k bits all columns of `a`.
#[allow(clippy::too_many_arguments)]
fn vec_znx_rsh<R, A>(&self, basek: usize, k: usize, r: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxLshInplace<B: Backend> {
/// Left shift by k bits all columns of `a`.
fn vec_znx_lsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut;
}
pub trait VecZnxRshInplace {
pub trait VecZnxRshInplace<B: Backend> {
/// Right shift by k bits all columns of `a`.
fn vec_znx_rsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
fn vec_znx_rsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut;
}
pub trait VecZnxRotate {
/// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`.
fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_rotate<R, A>(&self, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxRotateInplace {
pub trait VecZnxRotateInplaceTmpBytes {
fn vec_znx_rotate_inplace_tmp_bytes(&self) -> usize;
}
pub trait VecZnxRotateInplace<B: Backend> {
/// Multiplies the selected column of `a` by X^k.
fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
fn vec_znx_rotate_inplace<A>(&self, p: i64, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut;
}
@@ -139,54 +187,70 @@ pub trait VecZnxAutomorphism {
A: VecZnxToRef;
}
pub trait VecZnxAutomorphismInplace {
pub trait VecZnxAutomorphismInplaceTmpBytes {
fn vec_znx_automorphism_inplace_tmp_bytes(&self) -> usize;
}
pub trait VecZnxAutomorphismInplace<B: Backend> {
/// Applies the automorphism X^i -> X^ik on the selected column of `a`.
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
fn vec_znx_automorphism_inplace<R>(&self, k: i64, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut;
R: VecZnxToMut;
}
pub trait VecZnxMulXpMinusOne {
fn vec_znx_mul_xp_minus_one<R, A>(&self, p: i64, r: &mut R, r_col: usize, a: &A, a_col: usize)
fn vec_znx_mul_xp_minus_one<R, A>(&self, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxMulXpMinusOneInplace {
fn vec_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, r: &mut R, r_col: usize)
pub trait VecZnxMulXpMinusOneInplaceTmpBytes {
fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes(&self) -> usize;
}
pub trait VecZnxMulXpMinusOneInplace<B: Backend> {
fn vec_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxToMut;
}
pub trait VecZnxSplit<B: Backend> {
pub trait VecZnxSplitRingTmpBytes {
fn vec_znx_split_ring_tmp_bytes(&self) -> usize;
}
pub trait VecZnxSplitRing<B: Backend> {
/// Splits the selected columns of `b` into subrings and copies them them into the selected column of `res`.
///
/// # Panics
///
/// This method requires that all [crate::layouts::VecZnx] of b have the same ring degree
/// and that b.n() * b.len() <= a.n()
fn vec_znx_split<R, A>(&self, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
fn vec_znx_split_ring<R, A>(&self, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxMerge {
pub trait VecZnxMergeRingsTmpBytes {
fn vec_znx_merge_rings_tmp_bytes(&self) -> usize;
}
pub trait VecZnxMergeRings<B: Backend> {
/// Merges the subrings of the selected column of `a` into the selected column of `res`.
///
/// # Panics
///
/// This method requires that all [crate::layouts::VecZnx] of a have the same ring degree
/// and that a.n() * a.len() <= b.n()
fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: &[A], a_col: usize)
fn vec_znx_merge_rings<R, A>(&self, res: &mut R, res_col: usize, a: &[A], a_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxSwithcDegree {
fn vec_znx_switch_degree<R, A>(&self, res: &mut R, res_col: usize, a: &A, col_a: usize)
pub trait VecZnxSwitchRing {
fn vec_znx_switch_ring<R, A>(&self, res: &mut R, res_col: usize, a: &A, col_a: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
@@ -201,42 +265,11 @@ pub trait VecZnxCopy {
pub trait VecZnxFillUniform {
/// Fills the first `size` size with uniform values in \[-2^{basek-1}, 2^{basek-1}\]
fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut;
}
#[allow(clippy::too_many_arguments)]
pub trait VecZnxFillDistF64 {
fn vec_znx_fill_dist_f64<R, D: Distribution<f64>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: VecZnxToMut;
}
#[allow(clippy::too_many_arguments)]
pub trait VecZnxAddDistF64 {
/// Adds vector sampled according to the provided distribution, scaled by 2^{-k} and bounded to \[-bound, bound\].
fn vec_znx_add_dist_f64<R, D: Distribution<f64>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: VecZnxToMut;
}
#[allow(clippy::too_many_arguments)]
pub trait VecZnxFillNormal {
fn vec_znx_fill_normal<R>(

View File

@@ -1,10 +1,15 @@
use rand_distr::Distribution;
use crate::{
layouts::{Backend, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef},
source::Source,
};
pub trait VecZnxBigFromSmall<B: Backend> {
fn vec_znx_big_from_small<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxToRef;
}
/// Allocates as [crate::layouts::VecZnxBig].
pub trait VecZnxBigAlloc<B: Backend> {
fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned<B>;
@@ -45,48 +50,6 @@ pub trait VecZnxBigAddNormal<B: Backend> {
);
}
#[allow(clippy::too_many_arguments)]
pub trait VecZnxBigFillNormal<B: Backend> {
fn vec_znx_big_fill_normal<R: VecZnxBigToMut<B>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
);
}
#[allow(clippy::too_many_arguments)]
pub trait VecZnxBigFillDistF64<B: Backend> {
fn vec_znx_big_fill_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
);
}
#[allow(clippy::too_many_arguments)]
pub trait VecZnxBigAddDistF64<B: Backend> {
fn vec_znx_big_add_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
);
}
pub trait VecZnxBigAdd<B: Backend> {
/// Adds `a` to `b` and stores the result on `c`.
fn vec_znx_big_add<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
@@ -180,10 +143,17 @@ pub trait VecZnxBigSubSmallBInplace<B: Backend> {
A: VecZnxToRef;
}
pub trait VecZnxBigNegateInplace<B: Backend> {
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize)
pub trait VecZnxBigNegate<B: Backend> {
fn vec_znx_big_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
A: VecZnxBigToMut<B>;
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>;
}
pub trait VecZnxBigNegateInplace<B: Backend> {
fn vec_znx_big_negate_inplace<R>(&self, res: &mut R, res_col: usize)
where
R: VecZnxBigToMut<B>;
}
pub trait VecZnxBigNormalizeTmpBytes {
@@ -204,9 +174,13 @@ pub trait VecZnxBigNormalize<B: Backend> {
A: VecZnxBigToRef<B>;
}
pub trait VecZnxBigAutomorphismInplaceTmpBytes {
fn vec_znx_big_automorphism_inplace_tmp_bytes(&self) -> usize;
}
pub trait VecZnxBigAutomorphism<B: Backend> {
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_automorphism<R, A>(&self, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>;
@@ -214,7 +188,7 @@ pub trait VecZnxBigAutomorphism<B: Backend> {
pub trait VecZnxBigAutomorphismInplace<B: Backend> {
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
fn vec_znx_big_automorphism_inplace<R>(&self, p: i64, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxBigToMut<B>;
R: VecZnxBigToMut<B>;
}

View File

@@ -14,33 +14,33 @@ pub trait VecZnxDftAllocBytes {
fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize;
}
pub trait DFT<B: Backend> {
fn dft<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
pub trait VecZnxDftApply<B: Backend> {
fn vec_znx_dft_apply<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxToRef;
}
pub trait VecZnxIDFTTmpBytes {
fn vec_znx_idft_tmp_bytes(&self) -> usize;
pub trait VecZnxIdftApplyTmpBytes {
fn vec_znx_idft_apply_tmp_bytes(&self) -> usize;
}
pub trait IDFT<B: Backend> {
fn idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
pub trait VecZnxIdftApply<B: Backend> {
fn vec_znx_idft_apply<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxBigToMut<B>,
A: VecZnxDftToRef<B>;
}
pub trait IDFTTmpA<B: Backend> {
fn idft_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
pub trait VecZnxIdftApplyTmpA<B: Backend> {
fn vec_znx_idft_apply_tmpa<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxDftToMut<B>;
}
pub trait IDFTConsume<B: Backend> {
fn vec_znx_idft_consume<D: Data>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
pub trait VecZnxIdftApplyConsume<B: Backend> {
fn vec_znx_idft_apply_consume<D: Data>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
where
VecZnxDft<D, B>: VecZnxDftToMut<B>;
}

View File

@@ -1,4 +1,6 @@
use crate::layouts::{Backend, MatZnxToRef, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef};
use crate::layouts::{
Backend, MatZnxToRef, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef,
};
pub trait VmpPMatAlloc<B: Backend> {
fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B>;
@@ -17,12 +19,33 @@ pub trait VmpPrepareTmpBytes {
}
pub trait VmpPrepare<B: Backend> {
fn vmp_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<B>)
fn vmp_prepare<R, A>(&self, pmat: &mut R, mat: &A, scratch: &mut Scratch<B>)
where
R: VmpPMatToMut<B>,
A: MatZnxToRef;
}
#[allow(clippy::too_many_arguments)]
pub trait VmpApplyDftTmpBytes {
fn vmp_apply_dft_tmp_bytes(
&self,
res_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize;
}
pub trait VmpApplyDft<B: Backend> {
fn vmp_apply_dft<R, A, C>(&self, res: &mut R, a: &A, pmat: &C, scratch: &mut Scratch<B>)
where
R: VecZnxDftToMut<B>,
A: VecZnxToRef,
C: VmpPMatToRef<B>;
}
#[allow(clippy::too_many_arguments)]
pub trait VmpApplyDftToDftTmpBytes {
fn vmp_apply_dft_to_dft_tmp_bytes(
@@ -61,7 +84,7 @@ pub trait VmpApplyDftToDft<B: Backend> {
/// * `a`: the left operand [crate::layouts::VecZnxDft] of the vector matrix product.
/// * `b`: the right operand [crate::layouts::VmpPMat] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpApplyTmpBytes::vmp_apply_tmp_bytes].
fn vmp_apply_dft_to_dft<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
fn vmp_apply_dft_to_dft<R, A, C>(&self, res: &mut R, a: &A, pmat: &C, scratch: &mut Scratch<B>)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
@@ -82,7 +105,7 @@ pub trait VmpApplyDftToDftAddTmpBytes {
}
pub trait VmpApplyDftToDftAdd<B: Backend> {
fn vmp_apply_dft_to_dft_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
fn vmp_apply_dft_to_dft_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, limb_offset: usize, scratch: &mut Scratch<B>)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,

View File

@@ -1,57 +1,29 @@
use rand_distr::Distribution;
use crate::{
layouts::{Backend, Scratch, ZnToMut},
reference::zn::zn_normalize_tmp_bytes,
source::Source,
};
pub trait ZnNormalizeTmpBytes {
fn zn_normalize_tmp_bytes(&self, n: usize) -> usize {
zn_normalize_tmp_bytes(n)
}
}
pub trait ZnNormalizeInplace<B: Backend> {
/// Normalizes the selected column of `a`.
fn zn_normalize_inplace<A>(&self, n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
fn zn_normalize_inplace<R>(&self, n: usize, basek: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
where
A: ZnToMut;
R: ZnToMut;
}
pub trait ZnFillUniform {
/// Fills the first `size` size with uniform values in \[-2^{basek-1}, 2^{basek-1}\]
fn zn_fill_uniform<R>(&self, n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
fn zn_fill_uniform<R>(&self, n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut;
}
#[allow(clippy::too_many_arguments)]
pub trait ZnFillDistF64 {
fn zn_fill_dist_f64<R, D: Distribution<f64>>(
&self,
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: ZnToMut;
}
#[allow(clippy::too_many_arguments)]
pub trait ZnAddDistF64 {
/// Adds vector sampled according to the provided distribution, scaled by 2^{-k} and bounded to \[-bound, bound\].
fn zn_add_dist_f64<R, D: Distribution<f64>>(
&self,
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: ZnToMut;
}
#[allow(clippy::too_many_arguments)]
pub trait ZnFillNormal {
fn zn_fill_normal<R>(

View File

@@ -0,0 +1,5 @@
pub mod svp;
pub mod vec_znx;
pub mod vec_znx_big;
pub mod vec_znx_dft;
pub mod vmp;

View File

@@ -0,0 +1,237 @@
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use rand::RngCore;
use crate::{
api::{
ModuleNew, SvpApplyDft, SvpApplyDftToDft, SvpApplyDftToDftAdd, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPrepare,
VecZnxDftAlloc,
},
layouts::{Backend, DataViewMut, FillUniform, Module, ScalarZnx, SvpPPol, VecZnx, VecZnxDft},
source::Source,
};
pub fn bench_svp_prepare<B>(c: &mut Criterion, label: &str)
where
Module<B>: SvpPrepare<B> + SvpPPolAlloc<B> + ModuleNew<B>,
B: Backend<ScalarPrep = f64>,
{
let group_name: String = format!("svp_prepare::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B>(log_n: usize) -> impl FnMut()
where
Module<B>: SvpPrepare<B> + SvpPPolAlloc<B> + ModuleNew<B>,
B: Backend<ScalarPrep = f64>,
{
let module: Module<B> = Module::<B>::new(1 << log_n);
let cols: usize = 2;
let mut svp: SvpPPol<Vec<u8>, B> = module.svp_ppol_alloc(cols);
let mut a: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(module.n(), cols);
let mut source = Source::new([0u8; 32]);
a.fill_uniform(50, &mut source);
move || {
module.svp_prepare(&mut svp, 0, &a, 0);
black_box(());
}
}
for log_n in [10, 11, 12, 13, 14] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}", 1 << log_n));
let mut runner = runner::<B>(log_n);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_svp_apply_dft<B>(c: &mut Criterion, label: &str)
where
Module<B>: SvpApplyDft<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
{
let group_name: String = format!("svp_apply_dft::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: SvpApplyDft<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut svp: SvpPPol<Vec<u8>, B> = module.svp_ppol_alloc(cols);
let mut res: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut source = Source::new([0u8; 32]);
source.fill_bytes(svp.data_mut());
source.fill_bytes(res.data_mut());
source.fill_bytes(a.data_mut());
move || {
for j in 0..cols {
module.svp_apply_dft(&mut res, j, &svp, j, &a, j);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 7], [13, 2, 15], [14, 2, 31]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_svp_apply_dft_to_dft<B>(c: &mut Criterion, label: &str)
where
Module<B>: SvpApplyDftToDft<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
{
let group_name: String = format!("svp_apply_dft_to_dft::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: SvpApplyDftToDft<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut svp: SvpPPol<Vec<u8>, B> = module.svp_ppol_alloc(cols);
let mut res: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
let mut a: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
let mut source = Source::new([0u8; 32]);
source.fill_bytes(svp.data_mut());
source.fill_bytes(res.data_mut());
source.fill_bytes(a.data_mut());
move || {
for j in 0..cols {
module.svp_apply_dft_to_dft(&mut res, j, &svp, j, &a, j);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 7], [13, 2, 15], [14, 2, 31]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_svp_apply_dft_to_dft_add<B>(c: &mut Criterion, label: &str)
where
Module<B>: SvpApplyDftToDftAdd<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
{
let group_name: String = format!("svp_apply_dft_to_dft_add::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: SvpApplyDftToDftAdd<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut svp: SvpPPol<Vec<u8>, B> = module.svp_ppol_alloc(cols);
let mut res: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
let mut a: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
let mut source = Source::new([0u8; 32]);
source.fill_bytes(svp.data_mut());
source.fill_bytes(res.data_mut());
source.fill_bytes(a.data_mut());
move || {
for j in 0..cols {
module.svp_apply_dft_to_dft_add(&mut res, j, &svp, j, &a, j);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 7], [13, 2, 15], [14, 2, 31]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_svp_apply_dft_to_dft_inplace<B>(c: &mut Criterion, label: &str)
where
Module<B>: SvpApplyDftToDftInplace<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
{
let group_name: String = format!("svp_apply_dft_to_dft_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: SvpApplyDftToDftInplace<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut svp: SvpPPol<Vec<u8>, B> = module.svp_ppol_alloc(cols);
let mut res: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
let mut source = Source::new([0u8; 32]);
source.fill_bytes(svp.data_mut());
source.fill_bytes(res.data_mut());
move || {
for j in 0..cols {
module.svp_apply_dft_to_dft_inplace(&mut res, j, &svp, j);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 7], [13, 2, 15], [14, 2, 31]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,641 @@
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use rand::RngCore;
use crate::{
api::{
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAdd, VecZnxBigAddInplace, VecZnxBigAddSmall,
VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace,
VecZnxBigAutomorphismInplaceTmpBytes, VecZnxBigNegate, VecZnxBigNegateInplace, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA,
VecZnxBigSubSmallB,
},
layouts::{Backend, DataViewMut, Module, ScratchOwned, VecZnx, VecZnxBig},
source::Source,
};
pub fn bench_vec_znx_big_add<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigAdd<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_add::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxBigAdd<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let n: usize = params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
let mut b: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
let mut c: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
// Fill a with random bytes
source.fill_bytes(a.data_mut());
source.fill_bytes(b.data_mut());
source.fill_bytes(c.data_mut());
move || {
for i in 0..cols {
module.vec_znx_big_add(&mut c, i, &a, i, &b, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_big_add_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigAddInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_add_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxBigAddInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let n: usize = params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
let mut c: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
// Fill a with random i64
source.fill_bytes(a.data_mut());
source.fill_bytes(c.data_mut());
move || {
for i in 0..cols {
module.vec_znx_big_add_inplace(&mut c, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_big_add_small<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigAddSmall<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_add_small::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxBigAddSmall<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let n: usize = params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut c: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
// Fill a with random i64
source.fill_bytes(a.data_mut());
source.fill_bytes(b.data_mut());
source.fill_bytes(c.data_mut());
move || {
for i in 0..cols {
module.vec_znx_big_add_small(&mut c, i, &a, i, &b, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_big_add_small_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigAddSmallInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_add_small_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxBigAddSmallInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let n: usize = params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut c: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
// Fill a with random i64
source.fill_bytes(a.data_mut());
source.fill_bytes(c.data_mut());
move || {
for i in 0..cols {
module.vec_znx_big_add_small_inplace(&mut c, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_big_automorphism<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigAutomorphism<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_automorphism::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxBigAutomorphism<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let n: usize = params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
let mut res: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
// Fill a with random i64
source.fill_bytes(a.data_mut());
source.fill_bytes(res.data_mut());
move || {
for i in 0..cols {
module.vec_znx_big_automorphism(-7, &mut res, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_automorphism_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigAutomorphismInplace<B> + VecZnxBigAutomorphismInplaceTmpBytes + ModuleNew<B> + VecZnxBigAlloc<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_automorphism_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxBigAutomorphismInplace<B> + ModuleNew<B> + VecZnxBigAutomorphismInplaceTmpBytes + VecZnxBigAlloc<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let n: usize = params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut res: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_big_automorphism_inplace_tmp_bytes());
// Fill a with random i64
source.fill_bytes(res.data_mut());
move || {
for i in 0..cols {
module.vec_znx_big_automorphism_inplace(-7, &mut res, i, scratch.borrow());
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_big_negate<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigNegate<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_negate::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxBigNegate<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let module: Module<B> = Module::<B>::new(1 << params[0]);
let cols: usize = params[1];
let size: usize = params[2];
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
let mut b: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
// Fill a with random i64
source.fill_bytes(a.data_mut());
source.fill_bytes(b.data_mut());
move || {
for i in 0..cols {
module.vec_znx_big_negate(&mut b, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_big_negate_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigNegateInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_negate_big_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxBigNegateInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let module: Module<B> = Module::<B>::new(1 << params[0]);
let cols: usize = params[1];
let size: usize = params[2];
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
// Fill a with random i64
source.fill_bytes(a.data_mut());
move || {
for i in 0..cols {
module.vec_znx_big_negate_inplace(&mut a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_normalize<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigNormalize<B> + ModuleNew<B> + VecZnxBigNormalizeTmpBytes + VecZnxBigAlloc<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_big_normalize::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxBigNormalize<B> + ModuleNew<B> + VecZnxBigNormalizeTmpBytes + VecZnxBigAlloc<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let n: usize = params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), cols, size);
// Fill a with random i64
source.fill_bytes(a.data_mut());
source.fill_bytes(res.data_mut());
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_big_normalize_tmp_bytes());
move || {
for i in 0..cols {
module.vec_znx_big_normalize(basek, &mut res, i, &a, i, scratch.borrow());
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_big_sub<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigSub<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_sub::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxBigSub<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let module: Module<B> = Module::<B>::new(1 << params[0]);
let cols: usize = params[1];
let size: usize = params[2];
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
let mut b: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
let mut c: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
// Fill a with random bytes
source.fill_bytes(a.data_mut());
source.fill_bytes(b.data_mut());
source.fill_bytes(c.data_mut());
move || {
for i in 0..cols {
module.vec_znx_big_sub(&mut c, i, &a, i, &b, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_big_sub_ab_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigSubABInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_sub_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxBigSubABInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let module: Module<B> = Module::<B>::new(1 << params[0]);
let cols: usize = params[1];
let size: usize = params[2];
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
let mut c: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
// Fill a with random bytes
source.fill_bytes(a.data_mut());
source.fill_bytes(c.data_mut());
move || {
for i in 0..cols {
module.vec_znx_big_sub_ab_inplace(&mut c, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_big_sub_ba_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigSubBAInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_sub_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxBigSubBAInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let module: Module<B> = Module::<B>::new(1 << params[0]);
let cols: usize = params[1];
let size: usize = params[2];
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
let mut c: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
// Fill a with random bytes
source.fill_bytes(a.data_mut());
source.fill_bytes(c.data_mut());
move || {
for i in 0..cols {
module.vec_znx_big_sub_ba_inplace(&mut c, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_big_sub_small_a<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigSubSmallA<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_sub_small_a::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxBigSubSmallA<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let module: Module<B> = Module::<B>::new(1 << params[0]);
let cols: usize = params[1];
let size: usize = params[2];
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), cols, size);
let mut b: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
let mut c: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
// Fill a with random bytes
source.fill_bytes(a.data_mut());
source.fill_bytes(b.data_mut());
source.fill_bytes(c.data_mut());
move || {
for i in 0..cols {
module.vec_znx_big_sub_small_a(&mut c, i, &a, i, &b, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_big_sub_small_b<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigSubSmallB<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_sub_small_b::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxBigSubSmallB<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let module: Module<B> = Module::<B>::new(1 << params[0]);
let cols: usize = params[1];
let size: usize = params[2];
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), cols, size);
let mut c: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
// Fill a with random bytes
source.fill_bytes(a.data_mut());
source.fill_bytes(b.data_mut());
source.fill_bytes(c.data_mut());
move || {
for i in 0..cols {
module.vec_znx_big_sub_small_b(&mut c, i, &a, i, &b, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}

View File

@@ -0,0 +1,365 @@
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use rand::RngCore;
use crate::{
api::{
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAlloc, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc,
VecZnxDftApply, VecZnxDftSub, VecZnxDftSubABInplace, VecZnxDftSubBAInplace, VecZnxIdftApply, VecZnxIdftApplyTmpA,
VecZnxIdftApplyTmpBytes,
},
layouts::{Backend, DataViewMut, Module, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft},
source::Source,
};
pub fn bench_vec_znx_dft_add<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxDftAdd<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let group_name: String = format!("vec_znx_dft_add::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxDftAdd<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let n: usize = params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
let mut b: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
let mut c: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
source.fill_bytes(a.data_mut());
source.fill_bytes(b.data_mut());
source.fill_bytes(c.data_mut());
move || {
for i in 0..cols {
module.vec_znx_dft_add(&mut c, i, &a, i, &b, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_dft_add_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxDftAddInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let group_name: String = format!("vec_znx_dft_add_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxDftAddInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let n: usize = params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
let mut c: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
// Fill a with random i64
source.fill_bytes(a.data_mut());
source.fill_bytes(c.data_mut());
move || {
for i in 0..cols {
module.vec_znx_dft_add_inplace(&mut c, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_dft_apply<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxDftApply<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let group_name: String = format!("vec_znx_dft_apply::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxDftApply<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let n: usize = params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut res: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
source.fill_bytes(res.data_mut());
source.fill_bytes(a.data_mut());
move || {
for i in 0..cols {
module.vec_znx_dft_apply(1, 0, &mut res, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_idft_apply<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxIdftApply<B> + ModuleNew<B> + VecZnxIdftApplyTmpBytes + VecZnxDftAlloc<B> + VecZnxBigAlloc<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_idft_apply::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxIdftApply<B> + ModuleNew<B> + VecZnxIdftApplyTmpBytes + VecZnxDftAlloc<B> + VecZnxBigAlloc<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let n: usize = params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut res: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
let mut a: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
source.fill_bytes(res.data_mut());
source.fill_bytes(a.data_mut());
let mut scratch = ScratchOwned::alloc(module.vec_znx_idft_apply_tmp_bytes());
move || {
for i in 0..cols {
module.vec_znx_idft_apply(&mut res, i, &a, i, scratch.borrow());
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_idft_apply_tmpa<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxIdftApplyTmpA<B> + ModuleNew<B> + VecZnxDftAlloc<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_idft_apply_tmpa::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxIdftApplyTmpA<B> + ModuleNew<B> + VecZnxDftAlloc<B> + VecZnxBigAlloc<B>,
{
let module: Module<B> = Module::<B>::new(1 << params[0]);
let cols: usize = params[1];
let size: usize = params[2];
let mut source: Source = Source::new([0u8; 32]);
let mut res: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
let mut a: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
source.fill_bytes(res.data_mut());
source.fill_bytes(a.data_mut());
move || {
for i in 0..cols {
module.vec_znx_idft_apply_tmpa(&mut res, i, &mut a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_dft_sub<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxDftSub<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let group_name: String = format!("vec_znx_dft_sub::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxDftSub<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let n: usize = params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
let mut b: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
let mut c: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
source.fill_bytes(a.data_mut());
source.fill_bytes(b.data_mut());
source.fill_bytes(c.data_mut());
move || {
for i in 0..cols {
module.vec_znx_dft_sub(&mut c, i, &a, i, &b, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_dft_sub_ab_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxDftSubABInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let group_name: String = format!("vec_znx_dft_sub_ab_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxDftSubABInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let n: usize = params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
let mut c: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
// Fill a with random i64
source.fill_bytes(a.data_mut());
source.fill_bytes(c.data_mut());
move || {
for i in 0..cols {
module.vec_znx_dft_sub_ab_inplace(&mut c, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_dft_sub_ba_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxDftSubBAInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let group_name: String = format!("vec_znx_dft_sub_ba_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxDftSubBAInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let n: usize = params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
let mut c: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
// Fill a with random i64
source.fill_bytes(a.data_mut());
source.fill_bytes(c.data_mut());
move || {
for i in 0..cols {
module.vec_znx_dft_sub_ba_inplace(&mut c, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}

View File

@@ -0,0 +1,259 @@
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use rand::RngCore;
use crate::{
api::{
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxDftAlloc, VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft,
VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, VmpPrepareTmpBytes,
},
layouts::{Backend, DataViewMut, MatZnx, Module, ScratchOwned, VecZnx, VecZnxDft, VmpPMat},
source::Source,
};
pub fn bench_vmp_prepare<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: ModuleNew<B> + VmpPMatAlloc<B> + VmpPrepare<B> + VmpPrepareTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vmp_prepare::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 5]) -> impl FnMut()
where
Module<B>: ModuleNew<B> + VmpPMatAlloc<B> + VmpPrepare<B> + VmpPrepareTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let module: Module<B> = Module::<B>::new(1 << params[0]);
let rows: usize = params[1];
let cols_in: usize = params[2];
let cols_out: usize = params[3];
let size: usize = params[4];
let mut source: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vmp_prepare_tmp_bytes(rows, cols_in, cols_out, size));
let mut mat: MatZnx<Vec<u8>> = MatZnx::alloc(module.n(), rows, cols_in, cols_out, size);
let mut pmat: VmpPMat<Vec<u8>, B> = module.vmp_pmat_alloc(rows, cols_in, cols_out, size);
source.fill_bytes(mat.data_mut());
source.fill_bytes(pmat.data_mut());
move || {
module.vmp_prepare(&mut pmat, &mat, scratch.borrow());
black_box(());
}
}
for params in [
[10, 2, 1, 2, 3],
[11, 4, 1, 2, 5],
[12, 7, 1, 2, 8],
[13, 15, 1, 2, 16],
[14, 31, 1, 2, 32],
] {
let id = BenchmarkId::from_parameter(format!(
"{}x({}x{})x({}x{})",
1 << params[0],
params[2],
params[1],
params[3],
params[4]
));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vmp_apply_dft<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: ModuleNew<B> + VmpApplyDftTmpBytes + VmpApplyDft<B> + VmpPMatAlloc<B> + VecZnxDftAlloc<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vmp_apply_dft::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 5]) -> impl FnMut()
where
Module<B>: ModuleNew<B> + VmpApplyDftTmpBytes + VmpApplyDft<B> + VmpPMatAlloc<B> + VecZnxDftAlloc<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let module: Module<B> = Module::<B>::new(1 << params[0]);
let rows: usize = params[1];
let cols_in: usize = params[2];
let cols_out: usize = params[3];
let size: usize = params[4];
let mut source: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(1 << 20);
let mut res: VecZnxDft<Vec<u8>, _> = module.vec_znx_dft_alloc(cols_out, size);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), cols_in, size);
let mut pmat: VmpPMat<Vec<u8>, B> = module.vmp_pmat_alloc(rows, cols_in, cols_out, size);
source.fill_bytes(pmat.data_mut());
source.fill_bytes(res.data_mut());
source.fill_bytes(a.data_mut());
move || {
module.vmp_apply_dft(&mut res, &a, &pmat, scratch.borrow());
black_box(());
}
}
for params in [
[10, 2, 1, 2, 3],
[11, 4, 1, 2, 5],
[12, 7, 1, 2, 8],
[13, 15, 1, 2, 16],
[14, 31, 1, 2, 32],
] {
let id = BenchmarkId::from_parameter(format!(
"{}x({}x{})x({}x{})",
1 << params[0],
params[2],
params[1],
params[3],
params[4]
));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vmp_apply_dft_to_dft<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: ModuleNew<B> + VecZnxDftAlloc<B> + VmpPMatAlloc<B> + VmpApplyDftToDft<B> + VmpApplyDftToDftTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vmp_apply_dft_to_dft::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 5]) -> impl FnMut()
where
Module<B>: ModuleNew<B> + VecZnxDftAlloc<B> + VmpPMatAlloc<B> + VmpApplyDftToDft<B> + VmpApplyDftToDftTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let module: Module<B> = Module::<B>::new(1 << params[0]);
let rows: usize = params[1];
let cols_in: usize = params[2];
let cols_out: usize = params[3];
let size: usize = params[4];
let mut source: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned<B> =
ScratchOwned::alloc(module.vmp_apply_dft_to_dft_tmp_bytes(size, size, rows, cols_in, cols_out, size));
let mut res: VecZnxDft<Vec<u8>, _> = module.vec_znx_dft_alloc(cols_out, size);
let mut a: VecZnxDft<Vec<u8>, _> = module.vec_znx_dft_alloc(cols_in, size);
let mut pmat: VmpPMat<Vec<u8>, B> = module.vmp_pmat_alloc(rows, cols_in, cols_out, size);
source.fill_bytes(pmat.data_mut());
source.fill_bytes(res.data_mut());
source.fill_bytes(a.data_mut());
move || {
module.vmp_apply_dft_to_dft(&mut res, &a, &pmat, scratch.borrow());
black_box(());
}
}
for params in [
[10, 2, 1, 2, 3],
[11, 4, 1, 2, 5],
[12, 7, 1, 2, 8],
[13, 15, 1, 2, 16],
[14, 31, 1, 2, 32],
] {
let id = BenchmarkId::from_parameter(format!(
"{}x({}x{})x({}x{})",
1 << params[0], // n
params[2], // cols_in
params[1], // size_in (=rows)
params[3], // cols_out
params[4] // size_out
));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vmp_apply_dft_to_dft_add<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: ModuleNew<B> + VecZnxDftAlloc<B> + VmpPMatAlloc<B> + VmpApplyDftToDftAdd<B> + VmpApplyDftToDftAddTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vmp_apply_dft_to_dft_add::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 5]) -> impl FnMut()
where
Module<B>: ModuleNew<B> + VecZnxDftAlloc<B> + VmpPMatAlloc<B> + VmpApplyDftToDftAdd<B> + VmpApplyDftToDftAddTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let module: Module<B> = Module::<B>::new(1 << params[0]);
let rows: usize = params[1];
let cols_in: usize = params[2];
let cols_out: usize = params[3];
let size: usize = params[4];
let mut source: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned<B> =
ScratchOwned::alloc(module.vmp_apply_dft_to_dft_add_tmp_bytes(size, size, rows, cols_in, cols_out, size));
let mut res: VecZnxDft<Vec<u8>, _> = module.vec_znx_dft_alloc(cols_out, size);
let mut a: VecZnxDft<Vec<u8>, _> = module.vec_znx_dft_alloc(cols_in, size);
let mut pmat: VmpPMat<Vec<u8>, B> = module.vmp_pmat_alloc(rows, cols_in, cols_out, size);
source.fill_bytes(pmat.data_mut());
source.fill_bytes(res.data_mut());
source.fill_bytes(a.data_mut());
move || {
module.vmp_apply_dft_to_dft_add(&mut res, &a, &pmat, 1, scratch.borrow());
black_box(());
}
}
for params in [
[10, 2, 1, 2, 3],
[11, 4, 1, 2, 5],
[12, 7, 1, 2, 8],
[13, 15, 1, 2, 16],
[14, 31, 1, 2, 32],
] {
let id = BenchmarkId::from_parameter(format!(
"{}x({}x{})x({}x{})",
1 << params[0],
params[2],
params[1],
params[3],
params[4]
));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}

View File

@@ -1,7 +1,15 @@
use crate::{
api::{SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPPolFromBytes, SvpPrepare},
layouts::{Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef},
oep::{SvpApplyImpl, SvpApplyInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl},
api::{
SvpApplyDft, SvpApplyDftToDft, SvpApplyDftToDftAdd, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes,
SvpPPolFromBytes, SvpPrepare,
},
layouts::{
Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
},
oep::{
SvpApplyDftImpl, SvpApplyDftToDftAddImpl, SvpApplyDftToDftImpl, SvpApplyDftToDftInplaceImpl, SvpPPolAllocBytesImpl,
SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl,
},
};
impl<B> SvpPPolFromBytes<B> for Module<B>
@@ -44,29 +52,57 @@ where
}
}
impl<B> SvpApply<B> for Module<B>
impl<B> SvpApplyDft<B> for Module<B>
where
B: Backend + SvpApplyImpl<B>,
B: Backend + SvpApplyDftImpl<B>,
{
fn svp_apply<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
fn svp_apply_dft<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: SvpPPolToRef<B>,
C: VecZnxToRef,
{
B::svp_apply_dft_impl(self, res, res_col, a, a_col, b, b_col);
}
}
impl<B> SvpApplyDftToDft<B> for Module<B>
where
B: Backend + SvpApplyDftToDftImpl<B>,
{
fn svp_apply_dft_to_dft<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: SvpPPolToRef<B>,
C: VecZnxDftToRef<B>,
{
B::svp_apply_impl(self, res, res_col, a, a_col, b, b_col);
B::svp_apply_dft_to_dft_impl(self, res, res_col, a, a_col, b, b_col);
}
}
impl<B> SvpApplyInplace<B> for Module<B>
impl<B> SvpApplyDftToDftAdd<B> for Module<B>
where
B: Backend + SvpApplyInplaceImpl,
B: Backend + SvpApplyDftToDftAddImpl<B>,
{
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn svp_apply_dft_to_dft_add<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: SvpPPolToRef<B>,
C: VecZnxDftToRef<B>,
{
B::svp_apply_dft_to_dft_add_impl(self, res, res_col, a, a_col, b, b_col);
}
}
impl<B> SvpApplyDftToDftInplace<B> for Module<B>
where
B: Backend + SvpApplyDftToDftInplaceImpl,
{
fn svp_apply_dft_to_dft_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: SvpPPolToRef<B>,
{
B::svp_apply_inplace_impl(self, res, res_col, a, a_col);
B::svp_apply_dft_to_dft_inplace_impl(self, res, res_col, a, a_col);
}
}

View File

@@ -1,19 +1,24 @@
use crate::{
api::{
VecZnxAdd, VecZnxAddDistF64, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism,
VecZnxAutomorphismInplace, VecZnxCopy, VecZnxFillDistF64, VecZnxFillNormal, VecZnxFillUniform, VecZnxLshInplace,
VecZnxMerge, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSplit,
VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalarInplace, VecZnxSwithcDegree,
VecZnxAdd, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalar, VecZnxAddScalarInplace, VecZnxAutomorphism,
VecZnxAutomorphismInplace, VecZnxAutomorphismInplaceTmpBytes, VecZnxCopy, VecZnxFillNormal, VecZnxFillUniform, VecZnxLsh,
VecZnxLshInplace, VecZnxLshTmpBytes, VecZnxMergeRings, VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOne,
VecZnxMulXpMinusOneInplace, VecZnxMulXpMinusOneInplaceTmpBytes, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes,
VecZnxRsh, VecZnxRshInplace, VecZnxRshTmpBytes, VecZnxSplitRing, VecZnxSplitRingTmpBytes, VecZnxSub, VecZnxSubABInplace,
VecZnxSubBAInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing,
},
layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
oep::{
VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl,
VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl,
VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl,
VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxSubABInplaceImpl,
VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarImpl, VecZnxAddScalarInplaceImpl,
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxAutomorphismInplaceTmpBytesImpl, VecZnxCopyImpl,
VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshImpl, VecZnxLshInplaceImpl, VecZnxLshTmpBytesImpl,
VecZnxMergeRingsImpl, VecZnxMergeRingsTmpBytesImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl,
VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl,
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl,
},
source::Source,
};
@@ -79,6 +84,20 @@ where
}
}
impl<B> VecZnxAddScalar for Module<B>
where
B: Backend + VecZnxAddScalarImpl<B>,
{
fn vec_znx_add_scalar<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize, b_limb: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
D: VecZnxToRef,
{
B::vec_znx_add_scalar_impl(self, res, res_col, a, a_col, b, b_col, b_limb)
}
}
impl<B> VecZnxAddScalarInplace for Module<B>
where
B: Backend + VecZnxAddScalarInplaceImpl<B>,
@@ -132,6 +151,20 @@ where
}
}
impl<B> VecZnxSubScalar for Module<B>
where
B: Backend + VecZnxSubScalarImpl<B>,
{
fn vec_znx_sub_scalar<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize, b_limb: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
D: VecZnxToRef,
{
B::vec_znx_sub_scalar_impl(self, res, res_col, a, a_col, b, b_col, b_limb)
}
}
impl<B> VecZnxSubScalarInplace for Module<B>
where
B: Backend + VecZnxSubScalarInplaceImpl<B>,
@@ -170,27 +203,87 @@ where
}
}
impl<B> VecZnxLshInplace for Module<B>
impl<B> VecZnxRshTmpBytes for Module<B>
where
B: Backend + VecZnxLshInplaceImpl<B>,
B: Backend + VecZnxRshTmpBytesImpl<B>,
{
fn vec_znx_lsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
where
A: VecZnxToMut,
{
B::vec_znx_lsh_inplace_impl(self, basek, k, a)
fn vec_znx_rsh_tmp_bytes(&self) -> usize {
B::vec_znx_rsh_tmp_bytes_impl(self)
}
}
impl<B> VecZnxRshInplace for Module<B>
impl<B> VecZnxLshTmpBytes for Module<B>
where
B: Backend + VecZnxRshInplaceImpl<B>,
B: Backend + VecZnxLshTmpBytesImpl<B>,
{
fn vec_znx_rsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
fn vec_znx_lsh_tmp_bytes(&self) -> usize {
B::vec_znx_lsh_tmp_bytes_impl(self)
}
}
impl<B> VecZnxLsh<B> for Module<B>
where
B: Backend + VecZnxLshImpl<B>,
{
fn vec_znx_lsh<R, A>(
&self,
basek: usize,
k: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_lsh_inplace_impl(self, basek, k, res, res_col, a, a_col, scratch);
}
}
impl<B> VecZnxRsh<B> for Module<B>
where
B: Backend + VecZnxRshImpl<B>,
{
fn vec_znx_rsh<R, A>(
&self,
basek: usize,
k: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_rsh_inplace_impl(self, basek, k, res, res_col, a, a_col, scratch);
}
}
impl<B> VecZnxLshInplace<B> for Module<B>
where
B: Backend + VecZnxLshInplaceImpl<B>,
{
fn vec_znx_lsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut,
{
B::vec_znx_rsh_inplace_impl(self, basek, k, a)
B::vec_znx_lsh_inplace_impl(self, basek, k, a, a_col, scratch)
}
}
impl<B> VecZnxRshInplace<B> for Module<B>
where
B: Backend + VecZnxRshInplaceImpl<B>,
{
fn vec_znx_rsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut,
{
B::vec_znx_rsh_inplace_impl(self, basek, k, a, a_col, scratch)
}
}
@@ -207,15 +300,24 @@ where
}
}
impl<B> VecZnxRotateInplace for Module<B>
impl<B> VecZnxRotateInplaceTmpBytes for Module<B>
where
B: Backend + VecZnxRotateInplaceTmpBytesImpl<B>,
{
fn vec_znx_rotate_inplace_tmp_bytes(&self) -> usize {
B::vec_znx_rotate_inplace_tmp_bytes_impl(self)
}
}
impl<B> VecZnxRotateInplace<B> for Module<B>
where
B: Backend + VecZnxRotateInplaceImpl<B>,
{
fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut,
{
B::vec_znx_rotate_inplace_impl(self, k, a, a_col)
B::vec_znx_rotate_inplace_impl(self, k, a, a_col, scratch)
}
}
@@ -232,15 +334,24 @@ where
}
}
impl<B> VecZnxAutomorphismInplace for Module<B>
impl<B> VecZnxAutomorphismInplaceTmpBytes for Module<B>
where
B: Backend + VecZnxAutomorphismInplaceTmpBytesImpl<B>,
{
fn vec_znx_automorphism_inplace_tmp_bytes(&self) -> usize {
B::vec_znx_automorphism_inplace_tmp_bytes_impl(self)
}
}
impl<B> VecZnxAutomorphismInplace<B> for Module<B>
where
B: Backend + VecZnxAutomorphismInplaceImpl<B>,
{
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
fn vec_znx_automorphism_inplace<R>(&self, k: i64, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut,
R: VecZnxToMut,
{
B::vec_znx_automorphism_inplace_impl(self, k, a, a_col)
B::vec_znx_automorphism_inplace_impl(self, k, res, res_col, scratch)
}
}
@@ -257,54 +368,81 @@ where
}
}
impl<B> VecZnxMulXpMinusOneInplace for Module<B>
impl<B> VecZnxMulXpMinusOneInplaceTmpBytes for Module<B>
where
B: Backend + VecZnxMulXpMinusOneInplaceTmpBytesImpl<B>,
{
fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes(&self) -> usize {
B::vec_znx_mul_xp_minus_one_inplace_tmp_bytes_impl(self)
}
}
impl<B> VecZnxMulXpMinusOneInplace<B> for Module<B>
where
B: Backend + VecZnxMulXpMinusOneInplaceImpl<B>,
{
fn vec_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, res: &mut R, res_col: usize)
fn vec_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxToMut,
{
B::vec_znx_mul_xp_minus_one_inplace_impl(self, p, res, res_col);
B::vec_znx_mul_xp_minus_one_inplace_impl(self, p, res, res_col, scratch);
}
}
impl<B> VecZnxSplit<B> for Module<B>
impl<B> VecZnxSplitRingTmpBytes for Module<B>
where
B: Backend + VecZnxSplitImpl<B>,
B: Backend + VecZnxSplitRingTmpBytesImpl<B>,
{
fn vec_znx_split<R, A>(&self, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
fn vec_znx_split_ring_tmp_bytes(&self) -> usize {
B::vec_znx_split_ring_tmp_bytes_impl(self)
}
}
impl<B> VecZnxSplitRing<B> for Module<B>
where
B: Backend + VecZnxSplitRingImpl<B>,
{
fn vec_znx_split_ring<R, A>(&self, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_split_impl(self, res, res_col, a, a_col, scratch)
B::vec_znx_split_ring_impl(self, res, res_col, a, a_col, scratch)
}
}
impl<B> VecZnxMerge for Module<B>
impl<B> VecZnxMergeRingsTmpBytes for Module<B>
where
B: Backend + VecZnxMergeImpl<B>,
B: Backend + VecZnxMergeRingsTmpBytesImpl<B>,
{
fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: &[A], a_col: usize)
fn vec_znx_merge_rings_tmp_bytes(&self) -> usize {
B::vec_znx_merge_rings_tmp_bytes_impl(self)
}
}
impl<B> VecZnxMergeRings<B> for Module<B>
where
B: Backend + VecZnxMergeRingsImpl<B>,
{
fn vec_znx_merge_rings<R, A>(&self, res: &mut R, res_col: usize, a: &[A], a_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_merge_impl(self, res, res_col, a, a_col)
B::vec_znx_merge_rings_impl(self, res, res_col, a, a_col, scratch)
}
}
impl<B> VecZnxSwithcDegree for Module<B>
impl<B> VecZnxSwitchRing for Module<B>
where
B: Backend + VecZnxSwithcDegreeImpl<B>,
B: Backend + VecZnxSwitchRingImpl<B>,
{
fn vec_znx_switch_degree<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_switch_ring<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_switch_degree_impl(self, res, res_col, a, a_col)
B::vec_znx_switch_ring_impl(self, res, res_col, a, a_col)
}
}
@@ -325,51 +463,11 @@ impl<B> VecZnxFillUniform for Module<B>
where
B: Backend + VecZnxFillUniformImpl<B>,
{
fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut,
{
B::vec_znx_fill_uniform_impl(self, basek, res, res_col, k, source);
}
}
impl<B> VecZnxFillDistF64 for Module<B>
where
B: Backend + VecZnxFillDistF64Impl<B>,
{
fn vec_znx_fill_dist_f64<R, D: rand::prelude::Distribution<f64>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: VecZnxToMut,
{
B::vec_znx_fill_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
}
}
impl<B> VecZnxAddDistF64 for Module<B>
where
B: Backend + VecZnxAddDistF64Impl<B>,
{
fn vec_znx_add_dist_f64<R, D: rand::prelude::Distribution<f64>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: VecZnxToMut,
{
B::vec_znx_add_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
B::vec_znx_fill_uniform_impl(self, basek, res, res_col, source);
}
}

View File

@@ -1,18 +1,16 @@
use rand_distr::Distribution;
use crate::{
api::{
VecZnxBigAdd, VecZnxBigAddDistF64, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmall, VecZnxBigAddSmallInplace,
VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigFillDistF64,
VecZnxBigFillNormal, VecZnxBigFromBytes, VecZnxBigNegateInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA, VecZnxBigSubSmallAInplace,
VecZnxBigSubSmallB, VecZnxBigSubSmallBInplace,
VecZnxBigAdd, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmall, VecZnxBigAddSmallInplace, VecZnxBigAlloc,
VecZnxBigAllocBytes, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigAutomorphismInplaceTmpBytes,
VecZnxBigFromBytes, VecZnxBigFromSmall, VecZnxBigNegate, VecZnxBigNegateInplace, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA,
VecZnxBigSubSmallAInplace, VecZnxBigSubSmallB, VecZnxBigSubSmallBInplace,
},
layouts::{Backend, Module, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef},
oep::{
VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl,
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl, VecZnxBigFromBytesImpl,
VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl, VecZnxBigAddSmallInplaceImpl,
VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl, VecZnxBigAutomorphismInplaceImpl,
VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl, VecZnxBigFromSmallImpl, VecZnxBigNegateImpl,
VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl,
VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl,
VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
@@ -20,6 +18,19 @@ use crate::{
source::Source,
};
impl<B> VecZnxBigFromSmall<B> for Module<B>
where
B: Backend + VecZnxBigFromSmallImpl<B>,
{
fn vec_znx_big_from_small<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxToRef,
{
B::vec_znx_big_from_small_impl(res, res_col, a, a_col);
}
}
impl<B> VecZnxBigAlloc<B> for Module<B>
where
B: Backend + VecZnxBigAllocImpl<B>,
@@ -47,24 +58,6 @@ where
}
}
impl<B> VecZnxBigAddDistF64<B> for Module<B>
where
B: Backend + VecZnxBigAddDistF64Impl<B>,
{
fn vec_znx_big_add_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) {
B::add_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
}
}
impl<B> VecZnxBigAddNormal<B> for Module<B>
where
B: Backend + VecZnxBigAddNormalImpl<B>,
@@ -83,42 +76,6 @@ where
}
}
impl<B> VecZnxBigFillDistF64<B> for Module<B>
where
B: Backend + VecZnxBigFillDistF64Impl<B>,
{
fn vec_znx_big_fill_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) {
B::fill_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
}
}
impl<B> VecZnxBigFillNormal<B> for Module<B>
where
B: Backend + VecZnxBigFillNormalImpl<B>,
{
fn vec_znx_big_fill_normal<R: VecZnxBigToMut<B>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) {
B::fill_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
}
}
impl<B> VecZnxBigAdd<B> for Module<B>
where
B: Backend + VecZnxBigAddImpl<B>,
@@ -267,6 +224,19 @@ where
}
}
impl<B> VecZnxBigNegate<B> for Module<B>
where
B: Backend + VecZnxBigNegateImpl<B>,
{
fn vec_znx_big_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>,
{
B::vec_znx_big_negate_impl(self, res, res_col, a, a_col);
}
}
impl<B> VecZnxBigNegateInplace<B> for Module<B>
where
B: Backend + VecZnxBigNegateInplaceImpl<B>,
@@ -321,14 +291,23 @@ where
}
}
impl<B> VecZnxBigAutomorphismInplaceTmpBytes for Module<B>
where
B: Backend + VecZnxBigAutomorphismInplaceTmpBytesImpl<B>,
{
fn vec_znx_big_automorphism_inplace_tmp_bytes(&self) -> usize {
B::vec_znx_big_automorphism_inplace_tmp_bytes_impl(self)
}
}
impl<B> VecZnxBigAutomorphismInplace<B> for Module<B>
where
B: Backend + VecZnxBigAutomorphismInplaceImpl<B>,
{
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxBigToMut<B>,
{
B::vec_znx_big_automorphism_inplace_impl(self, k, a, a_col);
B::vec_znx_big_automorphism_inplace_impl(self, k, a, a_col, scratch);
}
}

View File

@@ -1,16 +1,17 @@
use crate::{
api::{
DFT, IDFT, IDFTConsume, IDFTTmpA, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftCopy,
VecZnxDftFromBytes, VecZnxDftSub, VecZnxDftSubABInplace, VecZnxDftSubBAInplace, VecZnxDftZero, VecZnxIDFTTmpBytes,
VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy,
VecZnxDftFromBytes, VecZnxDftSub, VecZnxDftSubABInplace, VecZnxDftSubBAInplace, VecZnxDftZero, VecZnxIdftApply,
VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes,
},
layouts::{
Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef,
VecZnxToRef,
},
oep::{
DFTImpl, IDFTConsumeImpl, IDFTImpl, IDFTTmpAImpl, VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl,
VecZnxDftAllocImpl, VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl,
VecZnxDftSubImpl, VecZnxDftZeroImpl, VecZnxIDFTTmpBytesImpl,
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl,
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl,
VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
},
};
@@ -41,63 +42,63 @@ where
}
}
impl<B> VecZnxIDFTTmpBytes for Module<B>
impl<B> VecZnxIdftApplyTmpBytes for Module<B>
where
B: Backend + VecZnxIDFTTmpBytesImpl<B>,
B: Backend + VecZnxIdftApplyTmpBytesImpl<B>,
{
fn vec_znx_idft_tmp_bytes(&self) -> usize {
B::vec_znx_idft_tmp_bytes_impl(self)
fn vec_znx_idft_apply_tmp_bytes(&self) -> usize {
B::vec_znx_idft_apply_tmp_bytes_impl(self)
}
}
impl<B> IDFT<B> for Module<B>
impl<B> VecZnxIdftApply<B> for Module<B>
where
B: Backend + IDFTImpl<B>,
B: Backend + VecZnxIdftApplyImpl<B>,
{
fn idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
fn vec_znx_idft_apply<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxBigToMut<B>,
A: VecZnxDftToRef<B>,
{
B::idft_impl(self, res, res_col, a, a_col, scratch);
B::vec_znx_idft_apply_impl(self, res, res_col, a, a_col, scratch);
}
}
impl<B> IDFTTmpA<B> for Module<B>
impl<B> VecZnxIdftApplyTmpA<B> for Module<B>
where
B: Backend + IDFTTmpAImpl<B>,
B: Backend + VecZnxIdftApplyTmpAImpl<B>,
{
fn idft_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
fn vec_znx_idft_apply_tmpa<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxDftToMut<B>,
{
B::idft_tmp_a_impl(self, res, res_col, a, a_col);
B::vec_znx_idft_apply_tmpa_impl(self, res, res_col, a, a_col);
}
}
impl<B> IDFTConsume<B> for Module<B>
impl<B> VecZnxIdftApplyConsume<B> for Module<B>
where
B: Backend + IDFTConsumeImpl<B>,
B: Backend + VecZnxIdftApplyConsumeImpl<B>,
{
fn vec_znx_idft_consume<D: Data>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
fn vec_znx_idft_apply_consume<D: Data>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
where
VecZnxDft<D, B>: VecZnxDftToMut<B>,
{
B::idft_consume_impl(self, a)
B::vec_znx_idft_apply_consume_impl(self, a)
}
}
impl<B> DFT<B> for Module<B>
impl<B> VecZnxDftApply<B> for Module<B>
where
B: Backend + DFTImpl<B>,
B: Backend + VecZnxDftApplyImpl<B>,
{
fn dft<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_dft_apply<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxToRef,
{
B::dft_impl(self, step, offset, res, res_col, a, a_col);
B::vec_znx_dft_apply_impl(self, step, offset, res, res_col, a, a_col);
}
}

View File

@@ -1,12 +1,16 @@
use crate::{
api::{
VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes, VmpApplyDftToDftTmpBytes, VmpPMatAlloc,
VmpPMatAllocBytes, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes,
VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes,
VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes,
},
layouts::{
Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, VmpPMatOwned, VmpPMatToMut,
VmpPMatToRef,
},
layouts::{Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef},
oep::{
VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl,
VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
VmpApplyDftImpl, VmpApplyDftTmpBytesImpl, VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl,
VmpApplyDftToDftTmpBytesImpl, VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPrepareImpl,
VmpPrepareTmpBytesImpl,
},
};
@@ -48,7 +52,7 @@ where
impl<B> VmpPrepare<B> for Module<B>
where
B: Backend + VmpPMatPrepareImpl<B>,
B: Backend + VmpPrepareImpl<B>,
{
fn vmp_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<B>)
where
@@ -59,6 +63,39 @@ where
}
}
impl<B> VmpApplyDftTmpBytes for Module<B>
where
B: Backend + VmpApplyDftTmpBytesImpl<B>,
{
fn vmp_apply_dft_tmp_bytes(
&self,
res_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize {
B::vmp_apply_dft_tmp_bytes_impl(
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
)
}
}
impl<B> VmpApplyDft<B> for Module<B>
where
B: Backend + VmpApplyDftImpl<B>,
{
fn vmp_apply_dft<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
where
R: VecZnxDftToMut<B>,
A: VecZnxToRef,
C: VmpPMatToRef<B>,
{
B::vmp_apply_dft_impl(self, res, a, b, scratch);
}
}
impl<B> VmpApplyDftToDftTmpBytes for Module<B>
where
B: Backend + VmpApplyDftToDftTmpBytesImpl<B>,

View File

@@ -1,10 +1,19 @@
use crate::{
api::{ZnAddDistF64, ZnAddNormal, ZnFillDistF64, ZnFillNormal, ZnFillUniform, ZnNormalizeInplace},
api::{ZnAddNormal, ZnFillNormal, ZnFillUniform, ZnNormalizeInplace, ZnNormalizeTmpBytes},
layouts::{Backend, Module, Scratch, ZnToMut},
oep::{ZnAddDistF64Impl, ZnAddNormalImpl, ZnFillDistF64Impl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl},
oep::{ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl, ZnNormalizeTmpBytesImpl},
source::Source,
};
impl<B> ZnNormalizeTmpBytes for Module<B>
where
B: Backend + ZnNormalizeTmpBytesImpl<B>,
{
fn zn_normalize_tmp_bytes(&self, n: usize) -> usize {
B::zn_normalize_tmp_bytes_impl(n)
}
}
impl<B> ZnNormalizeInplace<B> for Module<B>
where
B: Backend + ZnNormalizeInplaceImpl<B>,
@@ -21,53 +30,11 @@ impl<B> ZnFillUniform for Module<B>
where
B: Backend + ZnFillUniformImpl<B>,
{
fn zn_fill_uniform<R>(&self, n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
fn zn_fill_uniform<R>(&self, n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
B::zn_fill_uniform_impl(n, basek, res, res_col, k, source);
}
}
impl<B> ZnFillDistF64 for Module<B>
where
B: Backend + ZnFillDistF64Impl<B>,
{
fn zn_fill_dist_f64<R, D: rand::prelude::Distribution<f64>>(
&self,
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: ZnToMut,
{
B::zn_fill_dist_f64_impl(n, basek, res, res_col, k, source, dist, bound);
}
}
impl<B> ZnAddDistF64 for Module<B>
where
B: Backend + ZnAddDistF64Impl<B>,
{
fn zn_add_dist_f64<R, D: rand::prelude::Distribution<f64>>(
&self,
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: ZnToMut,
{
B::zn_add_dist_f64_impl(n, basek, res, res_col, k, source, dist, bound);
B::zn_fill_uniform_impl(n, basek, res, res_col, source);
}
}

View File

@@ -1,8 +1,9 @@
use itertools::izip;
use rug::{Assign, Float};
use crate::layouts::{
DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, Zn, ZnToMut, ZnToRef, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero,
use crate::{
layouts::{DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, Zn, ZnToMut, ZnToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::znx_zero_ref,
};
impl<D: DataMut> VecZnx<D> {
@@ -28,7 +29,7 @@ impl<D: DataMut> VecZnx<D> {
// Zeroes coefficients of the i-th column
(0..a.size()).for_each(|i| {
a.zero_at(col, i);
znx_zero_ref(a.at_mut(col, i));
});
// If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
@@ -183,7 +184,7 @@ impl<D: DataRef> VecZnx<D> {
let prec: u32 = (basek * size) as u32;
// 2^{basek}
let base = Float::with_val(prec, (1 << basek) as f64);
let base = Float::with_val(prec, (1u64 << basek) as f64);
// y[i] = sum x[j][i] * 2^{-basek*j}
(0..size).for_each(|i| {

View File

@@ -1,17 +1,21 @@
use crate::{
alloc_aligned,
layouts::{
Data, DataMut, DataRef, DataView, DataViewMut, FillUniform, ReaderFrom, Reset, ToOwnedDeep, VecZnx, WriterTo, ZnxInfos,
ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, Reset, ToOwnedDeep, VecZnx, WriterTo,
ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
},
source::Source,
};
use std::fmt;
use std::{
fmt,
hash::{DefaultHasher, Hasher},
};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use rand::RngCore;
#[derive(PartialEq, Eq, Clone)]
#[repr(C)]
#[derive(PartialEq, Eq, Clone, Hash)]
pub struct MatZnx<D: Data> {
data: D,
n: usize,
@@ -21,6 +25,19 @@ pub struct MatZnx<D: Data> {
cols_out: usize,
}
impl<D: DataRef> DigestU64 for MatZnx<D> {
fn digest_u64(&self) -> u64 {
let mut h: DefaultHasher = DefaultHasher::new();
h.write(self.data.as_ref());
h.write_usize(self.n);
h.write_usize(self.size);
h.write_usize(self.rows);
h.write_usize(self.cols_in);
h.write_usize(self.cols_out);
h.finish()
}
}
impl<D: DataRef> ToOwnedDeep for MatZnx<D> {
type Owned = MatZnx<Vec<u8>>;
fn to_owned_deep(&self) -> Self::Owned {
@@ -57,6 +74,10 @@ impl<D: Data> ZnxInfos for MatZnx<D> {
fn size(&self) -> usize {
self.size
}
fn poly_count(&self) -> usize {
self.rows() * self.cols_in() * self.cols_out() * self.size()
}
}
impl<D: Data> ZnxSliceSize for MatZnx<D> {
@@ -175,8 +196,18 @@ impl<D: DataMut> MatZnx<D> {
}
impl<D: DataMut> FillUniform for MatZnx<D> {
fn fill_uniform(&mut self, source: &mut Source) {
source.fill_bytes(self.data.as_mut());
fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
match log_bound {
64 => source.fill_bytes(self.data.as_mut()),
0 => panic!("invalid log_bound, cannot be zero"),
_ => {
let mask: u64 = (1u64 << log_bound) - 1;
for x in self.raw_mut().iter_mut() {
let r = source.next_u64() & mask;
*x = ((r << (64 - log_bound)) as i64) >> (64 - log_bound);
}
}
}
}
}

View File

@@ -34,3 +34,7 @@ pub trait ToOwnedDeep {
type Owned;
fn to_owned_deep(&self) -> Self::Owned;
}
pub trait DigestU64 {
fn digest_u64(&self) -> u64;
}

View File

@@ -20,7 +20,30 @@ pub struct Module<B: Backend> {
_marker: PhantomData<B>,
}
unsafe impl<B: Backend> Sync for Module<B> {}
unsafe impl<B: Backend> Send for Module<B> {}
impl<B: Backend> Module<B> {
#[allow(clippy::missing_safety_doc)]
#[inline]
pub fn new_marker(n: u64) -> Self {
Self {
ptr: NonNull::dangling(),
n,
_marker: PhantomData,
}
}
#[allow(clippy::missing_safety_doc)]
#[inline]
pub unsafe fn from_nonnull(ptr: NonNull<B::Handle>, n: u64) -> Self {
Self {
ptr,
n,
_marker: PhantomData,
}
}
/// Construct from a raw pointer managed elsewhere.
/// SAFETY: `ptr` must be non-null and remain valid for the lifetime of this Module.
#[inline]

View File

@@ -1,3 +1,5 @@
use std::hash::{DefaultHasher, Hasher};
use rand::seq::SliceRandom;
use rand_core::RngCore;
use rand_distr::{Distribution, weighted::WeightedIndex};
@@ -5,19 +7,30 @@ use rand_distr::{Distribution, weighted::WeightedIndex};
use crate::{
alloc_aligned,
layouts::{
Data, DataMut, DataRef, DataView, DataViewMut, FillUniform, ReaderFrom, Reset, ToOwnedDeep, VecZnx, WriterTo, ZnxInfos,
ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, Reset, ToOwnedDeep, VecZnx, WriterTo,
ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
},
source::Source,
};
#[derive(PartialEq, Eq, Debug, Clone)]
#[repr(C)]
#[derive(PartialEq, Eq, Debug, Clone, Hash)]
pub struct ScalarZnx<D: Data> {
pub data: D,
pub n: usize,
pub cols: usize,
}
impl<D: DataRef> DigestU64 for ScalarZnx<D> {
fn digest_u64(&self) -> u64 {
let mut h: DefaultHasher = DefaultHasher::new();
h.write(self.data.as_ref());
h.write_usize(self.n);
h.write_usize(self.cols);
h.finish()
}
}
impl<D: DataRef> ToOwnedDeep for ScalarZnx<D> {
type Owned = ScalarZnx<Vec<u8>>;
fn to_owned_deep(&self) -> Self::Owned {
@@ -145,8 +158,18 @@ impl<D: DataMut> ZnxZero for ScalarZnx<D> {
}
impl<D: DataMut> FillUniform for ScalarZnx<D> {
fn fill_uniform(&mut self, source: &mut Source) {
source.fill_bytes(self.data.as_mut());
fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
match log_bound {
64 => source.fill_bytes(self.data.as_mut()),
0 => panic!("invalid log_bound, cannot be zero"),
_ => {
let mask: u64 = (1u64 << log_bound) - 1;
for x in self.raw_mut().iter_mut() {
let r = source.next_u64() & mask;
*x = ((r << (64 - log_bound)) as i64) >> (64 - log_bound);
}
}
}
}
}

View File

@@ -2,11 +2,13 @@ use std::marker::PhantomData;
use crate::layouts::Backend;
#[repr(C)]
pub struct ScratchOwned<B: Backend> {
pub data: Vec<u8>,
pub _phantom: PhantomData<B>,
}
#[repr(C)]
pub struct Scratch<B: Backend> {
pub _phantom: PhantomData<B>,
pub data: [u8],

View File

@@ -4,7 +4,7 @@ use rug::{
ops::{AddAssignRound, DivAssignRound, SubAssignRound},
};
use crate::layouts::{DataRef, VecZnx, ZnxInfos};
use crate::layouts::{Backend, DataRef, VecZnx, VecZnxBig, VecZnxBigToRef, ZnxInfos};
impl<D: DataRef> VecZnx<D> {
pub fn std(&self, basek: usize, col: usize) -> f64 {
@@ -27,3 +27,17 @@ impl<D: DataRef> VecZnx<D> {
std.to_f64()
}
}
impl<D: DataRef, B: Backend + Backend<ScalarBig = i64>> VecZnxBig<D, B> {
pub fn std(&self, basek: usize, col: usize) -> f64 {
let self_ref: VecZnxBig<&[u8], B> = self.to_ref();
let znx: VecZnx<&[u8]> = VecZnx {
data: self_ref.data,
n: self_ref.n,
cols: self_ref.cols,
size: self_ref.size,
max_size: self_ref.max_size,
};
znx.std(basek, col)
}
}

View File

@@ -1,12 +1,19 @@
use std::marker::PhantomData;
use std::{
fmt,
hash::{DefaultHasher, Hasher},
marker::PhantomData,
};
use crate::{
alloc_aligned,
layouts::{Backend, Data, DataMut, DataRef, DataView, DataViewMut, ReaderFrom, WriterTo, ZnxInfos, ZnxSliceSize, ZnxView},
layouts::{
Backend, Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, ReaderFrom, WriterTo, ZnxInfos, ZnxSliceSize, ZnxView,
},
oep::SvpPPolAllocBytesImpl,
};
#[derive(PartialEq, Eq)]
#[repr(C)]
#[derive(PartialEq, Eq, Hash)]
pub struct SvpPPol<D: Data, B: Backend> {
pub data: D,
pub n: usize,
@@ -14,6 +21,16 @@ pub struct SvpPPol<D: Data, B: Backend> {
pub _phantom: PhantomData<B>,
}
impl<D: DataRef, B: Backend> DigestU64 for SvpPPol<D, B> {
fn digest_u64(&self) -> u64 {
let mut h: DefaultHasher = DefaultHasher::new();
h.write(self.data.as_ref());
h.write_usize(self.n);
h.write_usize(self.cols);
h.finish()
}
}
impl<D: Data, B: Backend> ZnxSliceSize for SvpPPol<D, B> {
fn sl(&self) -> usize {
B::layout_prep_word_count() * self.n()
@@ -153,3 +170,32 @@ impl<D: DataRef, B: Backend> WriterTo for SvpPPol<D, B> {
Ok(())
}
}
impl<D: DataRef, B: Backend> fmt::Display for SvpPPol<D, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "SvpPPol(n={}, cols={})", self.n, self.cols)?;
for col in 0..self.cols {
writeln!(f, "Column {}:", col)?;
let coeffs = self.at(col, 0);
write!(f, "[")?;
let max_show = 100;
let show_count = coeffs.len().min(max_show);
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", coeff)?;
}
if coeffs.len() > max_show {
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
}
writeln!(f, "]")?;
}
Ok(())
}
}

View File

@@ -1,10 +1,13 @@
use std::fmt;
use std::{
fmt,
hash::{DefaultHasher, Hasher},
};
use crate::{
alloc_aligned,
layouts::{
Data, DataMut, DataRef, DataView, DataViewMut, FillUniform, ReaderFrom, Reset, ToOwnedDeep, WriterTo, ZnxInfos,
ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, Reset, ToOwnedDeep, WriterTo,
ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
},
source::Source,
};
@@ -12,7 +15,8 @@ use crate::{
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use rand::RngCore;
#[derive(PartialEq, Eq, Clone, Copy)]
#[repr(C)]
#[derive(PartialEq, Eq, Clone, Copy, Hash)]
pub struct VecZnx<D: Data> {
pub data: D,
pub n: usize,
@@ -21,6 +25,18 @@ pub struct VecZnx<D: Data> {
pub max_size: usize,
}
impl<D: DataRef> DigestU64 for VecZnx<D> {
fn digest_u64(&self) -> u64 {
let mut h: DefaultHasher = DefaultHasher::new();
h.write(self.data.as_ref());
h.write_usize(self.n);
h.write_usize(self.cols);
h.write_usize(self.size);
h.write_usize(self.max_size);
h.finish()
}
}
impl<D: DataRef> ToOwnedDeep for VecZnx<D> {
type Owned = VecZnx<Vec<u8>>;
fn to_owned_deep(&self) -> Self::Owned {
@@ -173,8 +189,18 @@ impl<D: DataRef> fmt::Display for VecZnx<D> {
}
impl<D: DataMut> FillUniform for VecZnx<D> {
fn fill_uniform(&mut self, source: &mut Source) {
source.fill_bytes(self.data.as_mut());
fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
match log_bound {
64 => source.fill_bytes(self.data.as_mut()),
0 => panic!("invalid log_bound, cannot be zero"),
_ => {
let mask: u64 = (1u64 << log_bound) - 1;
for x in self.raw_mut().iter_mut() {
let r = source.next_u64() & mask;
*x = ((r << (64 - log_bound)) as i64) >> (64 - log_bound);
}
}
}
}
}

View File

@@ -1,15 +1,21 @@
use std::marker::PhantomData;
use std::{
hash::{DefaultHasher, Hasher},
marker::PhantomData,
};
use rand_distr::num_traits::Zero;
use std::fmt;
use crate::{
alloc_aligned,
layouts::{Backend, Data, DataMut, DataRef, DataView, DataViewMut, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
layouts::{
Backend, Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
},
oep::VecZnxBigAllocBytesImpl,
};
#[derive(PartialEq, Eq)]
#[repr(C)]
#[derive(PartialEq, Eq, Hash)]
pub struct VecZnxBig<D: Data, B: Backend> {
pub data: D,
pub n: usize,
@@ -19,6 +25,18 @@ pub struct VecZnxBig<D: Data, B: Backend> {
pub _phantom: PhantomData<B>,
}
impl<D: DataRef, B: Backend> DigestU64 for VecZnxBig<D, B> {
fn digest_u64(&self) -> u64 {
let mut h: DefaultHasher = DefaultHasher::new();
h.write(self.data.as_ref());
h.write_usize(self.n);
h.write_usize(self.cols);
h.write_usize(self.size);
h.write_usize(self.max_size);
h.finish()
}
}
impl<D: Data, B: Backend> ZnxSliceSize for VecZnxBig<D, B> {
fn sl(&self) -> usize {
B::layout_big_word_count() * self.n() * self.cols()

View File

@@ -1,14 +1,21 @@
use std::{fmt, marker::PhantomData};
use std::{
fmt,
hash::{DefaultHasher, Hasher},
marker::PhantomData,
};
use rand_distr::num_traits::Zero;
use crate::{
alloc_aligned,
layouts::{
Backend, Data, DataMut, DataRef, DataView, DataViewMut, VecZnxBig, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
Backend, Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, VecZnxBig, ZnxInfos, ZnxSliceSize, ZnxView,
ZnxViewMut, ZnxZero,
},
oep::VecZnxBigAllocBytesImpl,
oep::VecZnxDftAllocBytesImpl,
};
#[repr(C)]
#[derive(PartialEq, Eq)]
pub struct VecZnxDft<D: Data, B: Backend> {
pub data: D,
@@ -19,6 +26,18 @@ pub struct VecZnxDft<D: Data, B: Backend> {
pub _phantom: PhantomData<B>,
}
impl<D: DataRef, B: Backend> DigestU64 for VecZnxDft<D, B> {
fn digest_u64(&self) -> u64 {
let mut h: DefaultHasher = DefaultHasher::new();
h.write(self.data.as_ref());
h.write_usize(self.n);
h.write_usize(self.cols);
h.write_usize(self.size);
h.write_usize(self.max_size);
h.finish()
}
}
impl<D: Data, B: Backend> ZnxSliceSize for VecZnxDft<D, B> {
fn sl(&self) -> usize {
B::layout_prep_word_count() * self.n() * self.cols()
@@ -94,10 +113,10 @@ where
impl<D: DataRef + From<Vec<u8>>, B: Backend> VecZnxDft<D, B>
where
B: VecZnxBigAllocBytesImpl<B>,
B: VecZnxDftAllocBytesImpl<B>,
{
pub fn alloc(n: usize, cols: usize, size: usize) -> Self {
let data: Vec<u8> = alloc_aligned::<u8>(B::vec_znx_big_alloc_bytes_impl(n, cols, size));
let data: Vec<u8> = alloc_aligned::<u8>(B::vec_znx_dft_alloc_bytes_impl(n, cols, size));
Self {
data: data.into(),
n,
@@ -110,7 +129,7 @@ where
pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == B::vec_znx_big_alloc_bytes_impl(n, cols, size));
assert!(data.len() == B::vec_znx_dft_alloc_bytes_impl(n, cols, size));
Self {
data: data.into(),
n,

View File

@@ -1,12 +1,16 @@
use std::marker::PhantomData;
use std::{
hash::{DefaultHasher, Hasher},
marker::PhantomData,
};
use crate::{
alloc_aligned,
layouts::{Backend, Data, DataMut, DataRef, DataView, DataViewMut, ZnxInfos, ZnxView},
layouts::{Backend, Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, ZnxInfos, ZnxView},
oep::VmpPMatAllocBytesImpl,
};
#[derive(PartialEq, Eq)]
#[repr(C)]
#[derive(PartialEq, Eq, Hash)]
pub struct VmpPMat<D: Data, B: Backend> {
data: D,
n: usize,
@@ -17,6 +21,19 @@ pub struct VmpPMat<D: Data, B: Backend> {
_phantom: PhantomData<B>,
}
impl<D: DataRef, B: Backend> DigestU64 for VmpPMat<D, B> {
fn digest_u64(&self) -> u64 {
let mut h: DefaultHasher = DefaultHasher::new();
h.write(self.data.as_ref());
h.write_usize(self.n);
h.write_usize(self.size);
h.write_usize(self.rows);
h.write_usize(self.cols_in);
h.write_usize(self.cols_out);
h.finish()
}
}
impl<D: DataRef, B: Backend> ZnxView for VmpPMat<D, B> {
type Scalar = B::ScalarPrep;
}
@@ -37,6 +54,10 @@ impl<D: Data, B: Backend> ZnxInfos for VmpPMat<D, B> {
fn size(&self) -> usize {
self.size
}
fn poly_count(&self) -> usize {
self.rows() * self.cols_in() * self.size() * self.cols_out()
}
}
impl<D: Data, B: Backend> DataView for VmpPMat<D, B> {

View File

@@ -1,10 +1,13 @@
use std::fmt;
use std::{
fmt,
hash::{DefaultHasher, Hasher},
};
use crate::{
alloc_aligned,
layouts::{
Data, DataMut, DataRef, DataView, DataViewMut, FillUniform, ReaderFrom, Reset, ToOwnedDeep, WriterTo, ZnxInfos,
ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, Reset, ToOwnedDeep, WriterTo,
ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
},
source::Source,
};
@@ -12,7 +15,8 @@ use crate::{
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use rand::RngCore;
#[derive(PartialEq, Eq, Clone, Copy)]
#[repr(C)]
#[derive(PartialEq, Eq, Clone, Copy, Hash)]
pub struct Zn<D: Data> {
pub data: D,
pub n: usize,
@@ -21,6 +25,18 @@ pub struct Zn<D: Data> {
pub max_size: usize,
}
impl<D: DataRef> DigestU64 for Zn<D> {
fn digest_u64(&self) -> u64 {
let mut h: DefaultHasher = DefaultHasher::new();
h.write(self.data.as_ref());
h.write_usize(self.n);
h.write_usize(self.cols);
h.write_usize(self.size);
h.write_usize(self.max_size);
h.finish()
}
}
impl<D: DataRef> ToOwnedDeep for Zn<D> {
type Owned = Zn<Vec<u8>>;
fn to_owned_deep(&self) -> Self::Owned {
@@ -173,8 +189,18 @@ impl<D: DataRef> fmt::Display for Zn<D> {
}
impl<D: DataMut> FillUniform for Zn<D> {
fn fill_uniform(&mut self, source: &mut Source) {
source.fill_bytes(self.data.as_mut());
fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
match log_bound {
64 => source.fill_bytes(self.data.as_mut()),
0 => panic!("invalid log_bound, cannot be zero"),
_ => {
let mask: u64 = (1u64 << log_bound) - 1;
for x in self.raw_mut().iter_mut() {
let r = source.next_u64() & mask;
*x = ((r << (64 - log_bound)) as i64) >> (64 - log_bound);
}
}
}
}
}

View File

@@ -117,7 +117,7 @@ where
}
pub trait FillUniform {
fn fill_uniform(&mut self, source: &mut Source);
fn fill_uniform(&mut self, log_bound: usize, source: &mut Source);
}
pub trait Reset {

View File

@@ -4,11 +4,13 @@
#![feature(trait_alias)]
pub mod api;
pub mod bench_suite;
pub mod delegates;
pub mod layouts;
pub mod oep;
pub mod reference;
pub mod source;
pub mod tests;
pub mod test_suite;
pub mod doc {
#[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/backend_safety_contract.md"))]
@@ -85,13 +87,20 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
/// Allocates a block of T aligned with [DEFAULTALIGN].
/// Size of T * size msut be a multiple of [DEFAULTALIGN].
pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
assert_eq!(
(size * size_of::<T>()) % (align / size_of::<T>()),
0,
"size={} must be a multiple of align={}",
size,
assert!(
align.is_power_of_two(),
"Alignment must be a power of two but is {}",
align
);
assert_eq!(
(size * size_of::<T>()) % align,
0,
"size*size_of::<T>()={} must be a multiple of align={}",
size * size_of::<T>(),
align
);
let mut vec_u8: Vec<u8> = alloc_aligned_custom_u8(size_of::<T>() * size, align);
let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T;
let len: usize = vec_u8.len() / size_of::<T>();
@@ -100,11 +109,11 @@ pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
/// Allocates an aligned vector of size equal to the smallest multiple
/// of [DEFAULTALIGN]/`size_of::<T>`() that is equal or greater to `size`.
/// Allocates an aligned vector of the given size.
/// Padds until it is size in [u8] a multiple of [DEFAULTALIGN].
pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
alloc_aligned_custom::<T>(
size + (DEFAULTALIGN - (size % (DEFAULTALIGN / size_of::<T>()))) % DEFAULTALIGN,
(size * size_of::<T>()).next_multiple_of(DEFAULTALIGN) / size_of::<T>(),
DEFAULTALIGN,
)
}

View File

@@ -1,4 +1,6 @@
use crate::layouts::{Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef};
use crate::layouts::{
Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
@@ -39,9 +41,28 @@ pub unsafe trait SvpPrepareImpl<B: Backend> {
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait SvpApplyImpl<B: Backend> {
fn svp_apply_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
pub unsafe trait SvpApplyDftImpl<B: Backend> {
fn svp_apply_dft_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: SvpPPolToRef<B>,
C: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait SvpApplyDftToDftImpl<B: Backend> {
fn svp_apply_dft_to_dft_impl<R, A, C>(
module: &Module<B>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &C,
b_col: usize,
) where
R: VecZnxDftToMut<B>,
A: SvpPPolToRef<B>,
C: VecZnxDftToRef<B>;
@@ -51,8 +72,27 @@ pub unsafe trait SvpApplyImpl<B: Backend> {
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait SvpApplyInplaceImpl: Backend {
fn svp_apply_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
pub unsafe trait SvpApplyDftToDftAddImpl<B: Backend> {
fn svp_apply_dft_to_dft_add_impl<R, A, C>(
module: &Module<B>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &C,
b_col: usize,
) where
R: VecZnxDftToMut<B>,
A: SvpPPolToRef<B>,
C: VecZnxDftToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait SvpApplyDftToDftInplaceImpl: Backend {
fn svp_apply_dft_to_dft_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: SvpPPolToRef<Self>;

View File

@@ -1,5 +1,3 @@
use rand_distr::Distribution;
use crate::{
layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
source::Source,
@@ -64,6 +62,28 @@ pub unsafe trait VecZnxAddInplaceImpl<B: Backend> {
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code.
/// * See [crate::api::VecZnxAddScalar] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxAddScalarImpl<D: Backend> {
/// Adds the selected column of `a` on the selected column and limb of `b` and writes the result on the selected column of `res`.
#[allow(clippy::too_many_arguments)]
fn vec_znx_add_scalar_impl<R, A, B>(
module: &Module<D>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
b_limb: usize,
) where
R: VecZnxToMut,
A: ScalarZnxToRef,
B: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code.
/// * See [crate::api::VecZnxAddScalarInplace] for corresponding public API.
@@ -115,6 +135,28 @@ pub unsafe trait VecZnxSubBAInplaceImpl<B: Backend> {
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO.
/// * See [crate::api::VecZnxAddScalar] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxSubScalarImpl<D: Backend> {
/// Adds the selected column of `a` on the selected column and limb of `b` and writes the result on the selected column of `res`.
#[allow(clippy::too_many_arguments)]
fn vec_znx_sub_scalar_impl<R, A, B>(
module: &Module<D>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
b_limb: usize,
) where
R: VecZnxToMut,
A: ScalarZnxToRef,
B: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code.
/// * See [crate::api::VecZnxSubScalarInplace] for corresponding public API.
@@ -153,14 +195,76 @@ pub unsafe trait VecZnxNegateInplaceImpl<B: Backend> {
A: VecZnxToMut;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::reference::vec_znx::shift::vec_znx_rsh_tmp_bytes] for reference code.
/// * See [crate::api::VecZnxRshTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxRshTmpBytesImpl<B: Backend> {
fn vec_znx_rsh_tmp_bytes_impl(module: &Module<B>) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::reference::vec_znx::shift::vec_znx_rsh_inplace] for reference code.
/// * See [crate::api::VecZnxRsh] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxRshImpl<B: Backend> {
#[allow(clippy::too_many_arguments)]
fn vec_znx_rsh_inplace_impl<R, A>(
module: &Module<B>,
basek: usize,
k: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::reference::vec_znx::shift::vec_znx_lsh_tmp_bytes] for reference code.
/// * See [crate::api::VecZnxLshTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxLshTmpBytesImpl<B: Backend> {
fn vec_znx_lsh_tmp_bytes_impl(module: &Module<B>) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::reference::vec_znx::shift::vec_znx_lsh_inplace] for reference code.
/// * See [crate::api::VecZnxLsh] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxLshImpl<B: Backend> {
#[allow(clippy::too_many_arguments)]
fn vec_znx_lsh_inplace_impl<R, A>(
module: &Module<B>,
basek: usize,
k: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::cpu_spqlios::vec_znx::vec_znx_rsh_inplace_ref] for reference code.
/// * See [crate::api::VecZnxRshInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxRshInplaceImpl<B: Backend> {
fn vec_znx_rsh_inplace_impl<A>(module: &Module<B>, basek: usize, k: usize, a: &mut A)
where
A: VecZnxToMut;
fn vec_znx_rsh_inplace_impl<R>(
module: &Module<B>,
basek: usize,
k: usize,
res: &mut R,
res_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
@@ -168,9 +272,15 @@ pub unsafe trait VecZnxRshInplaceImpl<B: Backend> {
/// * See [crate::api::VecZnxLshInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxLshInplaceImpl<B: Backend> {
fn vec_znx_lsh_inplace_impl<A>(module: &Module<B>, basek: usize, k: usize, a: &mut A)
where
A: VecZnxToMut;
fn vec_znx_lsh_inplace_impl<R>(
module: &Module<B>,
basek: usize,
k: usize,
res: &mut R,
res_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
@@ -184,12 +294,20 @@ pub unsafe trait VecZnxRotateImpl<B: Backend> {
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO;
/// * See [crate::api::VecZnxRotateInplaceTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxRotateInplaceTmpBytesImpl<B: Backend> {
fn vec_znx_rotate_inplace_tmp_bytes_impl(module: &Module<B>) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_rotate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L164) for reference code.
/// * See [crate::api::VecZnxRotateInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxRotateInplaceImpl<B: Backend> {
fn vec_znx_rotate_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
fn vec_znx_rotate_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut;
}
@@ -199,20 +317,28 @@ pub unsafe trait VecZnxRotateInplaceImpl<B: Backend> {
/// * See [crate::api::VecZnxAutomorphism] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxAutomorphismImpl<B: Backend> {
fn vec_znx_automorphism_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_automorphism_impl<R, A>(module: &Module<B>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO;
/// * See [crate::api::VecZnxAutomorphismInplaceTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxAutomorphismInplaceTmpBytesImpl<B: Backend> {
fn vec_znx_automorphism_inplace_tmp_bytes_impl(module: &Module<B>) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_automorphism_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L188) for reference code.
/// * See [crate::api::VecZnxAutomorphismInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxAutomorphismInplaceImpl<B: Backend> {
fn vec_znx_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
fn vec_znx_automorphism_inplace_impl<R>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut;
R: VecZnxToMut;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
@@ -226,34 +352,75 @@ pub unsafe trait VecZnxMulXpMinusOneImpl<B: Backend> {
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO;
/// * See [crate::api::VecZnxMulXpMinusOneInplaceTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxMulXpMinusOneInplaceTmpBytesImpl<B: Backend> {
fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes_impl(module: &Module<B>) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_mul_xp_minus_one_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/7160f588da49712a042931ea247b4259b95cefcc/spqlios/arithmetic/vec_znx.c#L200C13-L200C41) for reference code.
/// * See [crate::api::VecZnxMulXpMinusOneInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxMulXpMinusOneInplaceImpl<B: Backend> {
fn vec_znx_mul_xp_minus_one_inplace_impl<R>(module: &Module<B>, p: i64, res: &mut R, res_col: usize)
where
fn vec_znx_mul_xp_minus_one_inplace_impl<R>(
module: &Module<B>,
p: i64,
res: &mut R,
res_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::cpu_spqlios::vec_znx::vec_znx_split_ref] for reference code.
/// * See [crate::api::VecZnxSplit] for corresponding public API.
/// * See TODO;
/// * See [crate::api::VecZnxSplitRingTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxSplitImpl<B: Backend> {
fn vec_znx_split_impl<R, A>(module: &Module<B>, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
pub unsafe trait VecZnxSplitRingTmpBytesImpl<B: Backend> {
fn vec_znx_split_ring_tmp_bytes_impl(module: &Module<B>) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::cpu_spqlios::vec_znx::vec_znx_split_ref] for reference code.
/// * See [crate::api::VecZnxSplitRing] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxSplitRingImpl<B: Backend> {
fn vec_znx_split_ring_impl<R, A>(
module: &Module<B>,
res: &mut [R],
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO;
/// * See [crate::api::VecZnxMergeRingsTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxMergeRingsTmpBytesImpl<B: Backend> {
fn vec_znx_merge_rings_tmp_bytes_impl(module: &Module<B>) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::cpu_spqlios::vec_znx::vec_znx_merge_ref] for reference code.
/// * See [crate::api::VecZnxMerge] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxMergeImpl<B: Backend> {
fn vec_znx_merge_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &[A], a_col: usize)
where
pub unsafe trait VecZnxMergeRingsImpl<B: Backend> {
fn vec_znx_merge_rings_impl<R, A>(
module: &Module<B>,
res: &mut R,
res_col: usize,
a: &[A],
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut,
A: VecZnxToRef;
}
@@ -262,8 +429,8 @@ pub unsafe trait VecZnxMergeImpl<B: Backend> {
/// * See [crate::cpu_spqlios::vec_znx::vec_znx_switch_degree_ref] for reference code.
/// * See [crate::api::VecZnxSwithcDegree] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxSwithcDegreeImpl<B: Backend> {
fn vec_znx_switch_degree_impl<R: VecZnxToMut, A: VecZnxToRef>(
pub unsafe trait VecZnxSwitchRingImpl<B: Backend> {
fn vec_znx_switch_ring_impl<R: VecZnxToMut, A: VecZnxToRef>(
module: &Module<B>,
res: &mut R,
res_col: usize,
@@ -287,47 +454,11 @@ pub unsafe trait VecZnxCopyImpl<B: Backend> {
/// * See [crate::api::VecZnxFillUniform] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxFillUniformImpl<B: Backend> {
fn vec_znx_fill_uniform_impl<R>(module: &Module<B>, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
fn vec_znx_fill_uniform_impl<R>(module: &Module<B>, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::api::VecZnxFillDistF64] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxFillDistF64Impl<B: Backend> {
fn vec_znx_fill_dist_f64_impl<R, D: Distribution<f64>>(
module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: VecZnxToMut;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::api::VecZnxAddDistF64] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxAddDistF64Impl<B: Backend> {
fn vec_znx_add_dist_f64_impl<R, D: Distribution<f64>>(
module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: VecZnxToMut;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::api::VecZnxFillNormal] for corresponding public API.

View File

@@ -1,10 +1,19 @@
use rand_distr::Distribution;
use crate::{
layouts::{Backend, Module, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef},
source::Source,
};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigFromSmallImpl<B: Backend> {
fn vec_znx_big_from_small_impl<R, A>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
@@ -47,60 +56,6 @@ pub unsafe trait VecZnxBigAddNormalImpl<B: Backend> {
);
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigFillNormalImpl<B: Backend> {
fn fill_normal_impl<R: VecZnxBigToMut<B>>(
module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
);
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigFillDistF64Impl<B: Backend> {
fn fill_dist_f64_impl<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
);
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAddDistF64Impl<B: Backend> {
fn add_dist_f64_impl<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
);
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
@@ -248,6 +203,17 @@ pub unsafe trait VecZnxBigSubSmallBInplaceImpl<B: Backend> {
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigNegateImpl<B: Backend> {
fn vec_znx_big_negate_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
@@ -295,12 +261,20 @@ pub unsafe trait VecZnxBigAutomorphismImpl<B: Backend> {
A: VecZnxBigToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAutomorphismInplaceTmpBytesImpl<B: Backend> {
fn vec_znx_big_automorphism_inplace_tmp_bytes_impl(module: &Module<B>) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAutomorphismInplaceImpl<B: Backend> {
fn vec_znx_big_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
fn vec_znx_big_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxBigToMut<B>;
}

View File

@@ -23,9 +23,16 @@ pub unsafe trait VecZnxDftFromBytesImpl<B: Backend> {
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait DFTImpl<B: Backend> {
fn dft_impl<R, A>(module: &Module<B>, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
pub unsafe trait VecZnxDftApplyImpl<B: Backend> {
fn vec_znx_dft_apply_impl<R, A>(
module: &Module<B>,
step: usize,
offset: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
) where
R: VecZnxDftToMut<B>,
A: VecZnxToRef;
}
@@ -42,17 +49,23 @@ pub unsafe trait VecZnxDftAllocBytesImpl<B: Backend> {
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxIDFTTmpBytesImpl<B: Backend> {
fn vec_znx_idft_tmp_bytes_impl(module: &Module<B>) -> usize;
pub unsafe trait VecZnxIdftApplyTmpBytesImpl<B: Backend> {
fn vec_znx_idft_apply_tmp_bytes_impl(module: &Module<B>) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait IDFTImpl<B: Backend> {
fn idft_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
pub unsafe trait VecZnxIdftApplyImpl<B: Backend> {
fn vec_znx_idft_apply_impl<R, A>(
module: &Module<B>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxBigToMut<B>,
A: VecZnxDftToRef<B>;
}
@@ -61,8 +74,8 @@ pub unsafe trait IDFTImpl<B: Backend> {
/// * See TODO for reference code.
/// * See for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait IDFTTmpAImpl<B: Backend> {
fn idft_tmp_a_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
pub unsafe trait VecZnxIdftApplyTmpAImpl<B: Backend> {
fn vec_znx_idft_apply_tmpa_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxDftToMut<B>;
@@ -72,8 +85,8 @@ pub unsafe trait IDFTTmpAImpl<B: Backend> {
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait IDFTConsumeImpl<B: Backend> {
fn idft_consume_impl<D: Data>(module: &Module<B>, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
pub unsafe trait VecZnxIdftApplyConsumeImpl<B: Backend> {
fn vec_znx_idft_apply_consume_impl<D: Data>(module: &Module<B>, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
where
VecZnxDft<D, B>: VecZnxDftToMut<B>;
}

View File

@@ -1,5 +1,5 @@
use crate::layouts::{
Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef,
Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef,
};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
@@ -45,13 +45,42 @@ pub unsafe trait VmpPrepareTmpBytesImpl<B: Backend> {
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpPMatPrepareImpl<B: Backend> {
pub unsafe trait VmpPrepareImpl<B: Backend> {
fn vmp_prepare_impl<R, A>(module: &Module<B>, res: &mut R, a: &A, scratch: &mut Scratch<B>)
where
R: VmpPMatToMut<B>,
A: MatZnxToRef;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpApplyDftTmpBytesImpl<B: Backend> {
fn vmp_apply_dft_tmp_bytes_impl(
module: &Module<B>,
res_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpApplyDftImpl<B: Backend> {
fn vmp_apply_dft_impl<R, A, C>(module: &Module<B>, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
where
R: VecZnxDftToMut<B>,
A: VecZnxToRef,
C: VmpPMatToRef<B>;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.

View File

@@ -1,65 +1,35 @@
use rand_distr::Distribution;
use crate::{
layouts::{Backend, Scratch, ZnToMut},
source::Source,
};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO
/// * See [crate::api::ZnNormalizeTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ZnNormalizeTmpBytesImpl<B: Backend> {
fn zn_normalize_tmp_bytes_impl(n: usize) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [zn_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/zn64.c#L9) for reference code.
/// * See [crate::api::ZnxNormalizeInplace] for corresponding public API.
/// * See [crate::api::ZnNormalizeInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ZnNormalizeInplaceImpl<B: Backend> {
fn zn_normalize_inplace_impl<A>(n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
fn zn_normalize_inplace_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
where
A: ZnToMut;
R: ZnToMut;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::api::ZnFillUniform] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ZnFillUniformImpl<B: Backend> {
fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::api::ZnFillDistF64] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ZnFillDistF64Impl<B: Backend> {
fn zn_fill_dist_f64_impl<R, D: Distribution<f64>>(
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: ZnToMut;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::api::ZnAddDistF64] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ZnAddDistF64Impl<B: Backend> {
fn zn_add_dist_f64_impl<R, D: Distribution<f64>>(
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: ZnToMut;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::api::ZnFillNormal] for corresponding public API.

View File

@@ -0,0 +1,24 @@
pub mod reim;
pub mod reim4;
pub mod svp;
pub mod vec_znx_big;
pub mod vec_znx_dft;
pub mod vmp;
pub(crate) fn assert_approx_eq_slice(a: &[f64], b: &[f64], tol: f64) {
assert_eq!(a.len(), b.len(), "Slices have different lengths");
for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
let diff: f64 = (x - y).abs();
let scale: f64 = x.abs().max(y.abs()).max(1.0);
assert!(
diff <= tol * scale,
"Difference at index {}: left={} right={} rel_diff={} > tol={}",
i,
x,
y,
diff / scale,
tol
);
}
}

View File

@@ -0,0 +1,31 @@
#[inline(always)]
pub fn reim_from_znx_i64_ref(res: &mut [f64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len())
}
for i in 0..res.len() {
res[i] = a[i] as f64
}
}
#[inline(always)]
pub fn reim_to_znx_i64_ref(res: &mut [i64], divisor: f64, a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len())
}
let inv_div = 1. / divisor;
for i in 0..res.len() {
res[i] = (a[i] * inv_div).round() as i64
}
}
#[inline(always)]
pub fn reim_to_znx_i64_inplace_ref(res: &mut [f64], divisor: f64) {
let inv_div = 1. / divisor;
for ri in res {
*ri = f64::from_bits(((*ri * inv_div).round() as i64) as u64)
}
}

View File

@@ -0,0 +1,327 @@
use std::fmt::Debug;
use rand_distr::num_traits::{Float, FloatConst};
use crate::reference::fft64::reim::{as_arr, as_arr_mut};
#[inline(always)]
pub fn fft_ref<R: Float + FloatConst + Debug>(m: usize, omg: &[R], data: &mut [R]) {
assert!(data.len() == 2 * m);
let (re, im) = data.split_at_mut(m);
if m <= 16 {
match m {
1 => {}
2 => fft2_ref(
as_arr_mut::<2, R>(re),
as_arr_mut::<2, R>(im),
*as_arr::<2, R>(omg),
),
4 => fft4_ref(
as_arr_mut::<4, R>(re),
as_arr_mut::<4, R>(im),
*as_arr::<4, R>(omg),
),
8 => fft8_ref(
as_arr_mut::<8, R>(re),
as_arr_mut::<8, R>(im),
*as_arr::<8, R>(omg),
),
16 => fft16_ref(
as_arr_mut::<16, R>(re),
as_arr_mut::<16, R>(im),
*as_arr::<16, R>(omg),
),
_ => {}
}
} else if m <= 2048 {
fft_bfs_16_ref(m, re, im, omg, 0);
} else {
fft_rec_16_ref(m, re, im, omg, 0);
}
}
#[inline(always)]
fn fft_rec_16_ref<R: Float + FloatConst + Debug>(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize {
if m <= 2048 {
return fft_bfs_16_ref(m, re, im, omg, pos);
};
let h = m >> 1;
twiddle_fft_ref(h, re, im, as_arr::<2, R>(&omg[pos..]));
pos += 2;
pos = fft_rec_16_ref(h, re, im, omg, pos);
pos = fft_rec_16_ref(h, &mut re[h..], &mut im[h..], omg, pos);
pos
}
#[inline(always)]
fn cplx_twiddle<R: Float + FloatConst>(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) {
let dr: R = *rb * omg_re - *ib * omg_im;
let di: R = *rb * omg_im + *ib * omg_re;
*rb = *ra - dr;
*ib = *ia - di;
*ra = *ra + dr;
*ia = *ia + di;
}
#[inline(always)]
fn cplx_i_twiddle<R: Float + FloatConst>(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) {
let dr: R = *rb * omg_im + *ib * omg_re;
let di: R = *rb * omg_re - *ib * omg_im;
*rb = *ra + dr;
*ib = *ia - di;
*ra = *ra - dr;
*ia = *ia + di;
}
#[inline(always)]
fn fft2_ref<R: Float + FloatConst>(re: &mut [R; 2], im: &mut [R; 2], omg: [R; 2]) {
let [ra, rb] = re;
let [ia, ib] = im;
let [romg, iomg] = omg;
cplx_twiddle(ra, ia, rb, ib, romg, iomg);
}
#[inline(always)]
fn fft4_ref<R: Float + FloatConst>(re: &mut [R; 4], im: &mut [R; 4], omg: [R; 4]) {
let [re_0, re_1, re_2, re_3] = re;
let [im_0, im_1, im_2, im_3] = im;
{
let omg_0 = omg[0];
let omg_1 = omg[1];
cplx_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1);
cplx_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1);
}
{
let omg_0 = omg[2];
let omg_1 = omg[3];
cplx_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_1);
cplx_i_twiddle(re_2, im_2, re_3, im_3, omg_0, omg_1);
}
}
#[inline(always)]
fn fft8_ref<R: Float + FloatConst>(re: &mut [R; 8], im: &mut [R; 8], omg: [R; 8]) {
let [re_0, re_1, re_2, re_3, re_4, re_5, re_6, re_7] = re;
let [im_0, im_1, im_2, im_3, im_4, im_5, im_6, im_7] = im;
{
let omg_0 = omg[0];
let omg_1 = omg[1];
cplx_twiddle(re_0, im_0, re_4, im_4, omg_0, omg_1);
cplx_twiddle(re_1, im_1, re_5, im_5, omg_0, omg_1);
cplx_twiddle(re_2, im_2, re_6, im_6, omg_0, omg_1);
cplx_twiddle(re_3, im_3, re_7, im_7, omg_0, omg_1);
}
{
let omg_2 = omg[2];
let omg_3 = omg[3];
cplx_twiddle(re_0, im_0, re_2, im_2, omg_2, omg_3);
cplx_twiddle(re_1, im_1, re_3, im_3, omg_2, omg_3);
cplx_i_twiddle(re_4, im_4, re_6, im_6, omg_2, omg_3);
cplx_i_twiddle(re_5, im_5, re_7, im_7, omg_2, omg_3);
}
{
let omg_4 = omg[4];
let omg_5 = omg[5];
let omg_6 = omg[6];
let omg_7 = omg[7];
cplx_twiddle(re_0, im_0, re_1, im_1, omg_4, omg_6);
cplx_i_twiddle(re_2, im_2, re_3, im_3, omg_4, omg_6);
cplx_twiddle(re_4, im_4, re_5, im_5, omg_5, omg_7);
cplx_i_twiddle(re_6, im_6, re_7, im_7, omg_5, omg_7);
}
}
#[inline(always)]
fn fft16_ref<R: Float + FloatConst + Debug>(re: &mut [R; 16], im: &mut [R; 16], omg: [R; 16]) {
let [
re_0,
re_1,
re_2,
re_3,
re_4,
re_5,
re_6,
re_7,
re_8,
re_9,
re_10,
re_11,
re_12,
re_13,
re_14,
re_15,
] = re;
let [
im_0,
im_1,
im_2,
im_3,
im_4,
im_5,
im_6,
im_7,
im_8,
im_9,
im_10,
im_11,
im_12,
im_13,
im_14,
im_15,
] = im;
{
let omg_0: R = omg[0];
let omg_1: R = omg[1];
cplx_twiddle(re_0, im_0, re_8, im_8, omg_0, omg_1);
cplx_twiddle(re_1, im_1, re_9, im_9, omg_0, omg_1);
cplx_twiddle(re_2, im_2, re_10, im_10, omg_0, omg_1);
cplx_twiddle(re_3, im_3, re_11, im_11, omg_0, omg_1);
cplx_twiddle(re_4, im_4, re_12, im_12, omg_0, omg_1);
cplx_twiddle(re_5, im_5, re_13, im_13, omg_0, omg_1);
cplx_twiddle(re_6, im_6, re_14, im_14, omg_0, omg_1);
cplx_twiddle(re_7, im_7, re_15, im_15, omg_0, omg_1);
}
{
let omg_2: R = omg[2];
let omg_3: R = omg[3];
cplx_twiddle(re_0, im_0, re_4, im_4, omg_2, omg_3);
cplx_twiddle(re_1, im_1, re_5, im_5, omg_2, omg_3);
cplx_twiddle(re_2, im_2, re_6, im_6, omg_2, omg_3);
cplx_twiddle(re_3, im_3, re_7, im_7, omg_2, omg_3);
cplx_i_twiddle(re_8, im_8, re_12, im_12, omg_2, omg_3);
cplx_i_twiddle(re_9, im_9, re_13, im_13, omg_2, omg_3);
cplx_i_twiddle(re_10, im_10, re_14, im_14, omg_2, omg_3);
cplx_i_twiddle(re_11, im_11, re_15, im_15, omg_2, omg_3);
}
{
let omg_0: R = omg[4];
let omg_1: R = omg[5];
let omg_2: R = omg[6];
let omg_3: R = omg[7];
cplx_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1);
cplx_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1);
cplx_twiddle(re_8, im_8, re_10, im_10, omg_2, omg_3);
cplx_twiddle(re_9, im_9, re_11, im_11, omg_2, omg_3);
cplx_i_twiddle(re_4, im_4, re_6, im_6, omg_0, omg_1);
cplx_i_twiddle(re_5, im_5, re_7, im_7, omg_0, omg_1);
cplx_i_twiddle(re_12, im_12, re_14, im_14, omg_2, omg_3);
cplx_i_twiddle(re_13, im_13, re_15, im_15, omg_2, omg_3);
}
{
let omg_0: R = omg[8];
let omg_1: R = omg[9];
let omg_2: R = omg[10];
let omg_3: R = omg[11];
let omg_4: R = omg[12];
let omg_5: R = omg[13];
let omg_6: R = omg[14];
let omg_7: R = omg[15];
cplx_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_4);
cplx_twiddle(re_4, im_4, re_5, im_5, omg_1, omg_5);
cplx_twiddle(re_8, im_8, re_9, im_9, omg_2, omg_6);
cplx_twiddle(re_12, im_12, re_13, im_13, omg_3, omg_7);
cplx_i_twiddle(re_2, im_2, re_3, im_3, omg_0, omg_4);
cplx_i_twiddle(re_6, im_6, re_7, im_7, omg_1, omg_5);
cplx_i_twiddle(re_10, im_10, re_11, im_11, omg_2, omg_6);
cplx_i_twiddle(re_14, im_14, re_15, im_15, omg_3, omg_7);
}
}
#[inline(always)]
fn fft_bfs_16_ref<R: Float + FloatConst + Debug>(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize {
let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
let mut mm: usize = m;
if !log_m.is_multiple_of(2) {
let h: usize = mm >> 1;
twiddle_fft_ref(h, re, im, as_arr::<2, R>(&omg[pos..]));
pos += 2;
mm = h
}
while mm > 16 {
let h: usize = mm >> 2;
for off in (0..m).step_by(mm) {
bitwiddle_fft_ref(
h,
&mut re[off..],
&mut im[off..],
as_arr::<4, R>(&omg[pos..]),
);
pos += 4;
}
mm = h
}
for off in (0..m).step_by(16) {
fft16_ref(
as_arr_mut::<16, R>(&mut re[off..]),
as_arr_mut::<16, R>(&mut im[off..]),
*as_arr::<16, R>(&omg[pos..]),
);
pos += 16;
}
pos
}
#[inline(always)]
fn twiddle_fft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 2]) {
let romg = omg[0];
let iomg = omg[1];
let (re_lhs, re_rhs) = re.split_at_mut(h);
let (im_lhs, im_rhs) = im.split_at_mut(h);
for i in 0..h {
cplx_twiddle(
&mut re_lhs[i],
&mut im_lhs[i],
&mut re_rhs[i],
&mut im_rhs[i],
romg,
iomg,
);
}
}
#[inline(always)]
fn bitwiddle_fft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 4]) {
let (r0, r2) = re.split_at_mut(2 * h);
let (r0, r1) = r0.split_at_mut(h);
let (r2, r3) = r2.split_at_mut(h);
let (i0, i2) = im.split_at_mut(2 * h);
let (i0, i1) = i0.split_at_mut(h);
let (i2, i3) = i2.split_at_mut(h);
let omg_0: R = omg[0];
let omg_1: R = omg[1];
let omg_2: R = omg[2];
let omg_3: R = omg[3];
for i in 0..h {
cplx_twiddle(&mut r0[i], &mut i0[i], &mut r2[i], &mut i2[i], omg_0, omg_1);
cplx_twiddle(&mut r1[i], &mut i1[i], &mut r3[i], &mut i3[i], omg_0, omg_1);
}
for i in 0..h {
cplx_twiddle(&mut r0[i], &mut i0[i], &mut r1[i], &mut i1[i], omg_2, omg_3);
cplx_i_twiddle(&mut r2[i], &mut i2[i], &mut r3[i], &mut i3[i], omg_2, omg_3);
}
}

View File

@@ -0,0 +1,156 @@
#[inline(always)]
pub fn reim_add_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
assert_eq!(b.len(), res.len());
}
for i in 0..res.len() {
res[i] = a[i] + b[i]
}
}
#[inline(always)]
pub fn reim_add_inplace_ref(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
}
for i in 0..res.len() {
res[i] += a[i]
}
}
#[inline(always)]
pub fn reim_sub_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
assert_eq!(b.len(), res.len());
}
for i in 0..res.len() {
res[i] = a[i] - b[i]
}
}
#[inline(always)]
pub fn reim_sub_ab_inplace_ref(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
}
for i in 0..res.len() {
res[i] -= a[i]
}
}
#[inline(always)]
pub fn reim_sub_ba_inplace_ref(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
}
for i in 0..res.len() {
res[i] = a[i] - res[i]
}
}
#[inline(always)]
pub fn reim_negate_ref(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
}
for i in 0..res.len() {
res[i] = -a[i]
}
}
#[inline(always)]
pub fn reim_negate_inplace_ref(res: &mut [f64]) {
for ri in res {
*ri = -*ri
}
}
#[inline(always)]
pub fn reim_addmul_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
assert_eq!(b.len(), res.len());
}
let m: usize = res.len() >> 1;
let (rr, ri) = res.split_at_mut(m);
let (ar, ai) = a.split_at(m);
let (br, bi) = b.split_at(m);
for i in 0..m {
let _ar: f64 = ar[i];
let _ai: f64 = ai[i];
let _br: f64 = br[i];
let _bi: f64 = bi[i];
let _rr: f64 = _ar * _br - _ai * _bi;
let _ri: f64 = _ar * _bi + _ai * _br;
rr[i] += _rr;
ri[i] += _ri;
}
}
#[inline(always)]
pub fn reim_mul_inplace_ref(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
}
let m: usize = res.len() >> 1;
let (rr, ri) = res.split_at_mut(m);
let (ar, ai) = a.split_at(m);
for i in 0..m {
let _ar: f64 = ar[i];
let _ai: f64 = ai[i];
let _br: f64 = rr[i];
let _bi: f64 = ri[i];
let _rr: f64 = _ar * _br - _ai * _bi;
let _ri: f64 = _ar * _bi + _ai * _br;
rr[i] = _rr;
ri[i] = _ri;
}
}
#[inline(always)]
pub fn reim_mul_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
assert_eq!(b.len(), res.len());
}
let m: usize = res.len() >> 1;
let (rr, ri) = res.split_at_mut(m);
let (ar, ai) = a.split_at(m);
let (br, bi) = b.split_at(m);
for i in 0..m {
let _ar: f64 = ar[i];
let _ai: f64 = ai[i];
let _br: f64 = br[i];
let _bi: f64 = bi[i];
let _rr: f64 = _ar * _br - _ai * _bi;
let _ri: f64 = _ar * _bi + _ai * _br;
rr[i] = _rr;
ri[i] = _ri;
}
}

View File

@@ -0,0 +1,322 @@
use std::fmt::Debug;
use rand_distr::num_traits::{Float, FloatConst};
use crate::reference::fft64::reim::{as_arr, as_arr_mut};
pub fn ifft_ref<R: Float + FloatConst + Debug>(m: usize, omg: &[R], data: &mut [R]) {
assert!(data.len() == 2 * m);
let (re, im) = data.split_at_mut(m);
if m <= 16 {
match m {
1 => {}
2 => ifft2_ref(
as_arr_mut::<2, R>(re),
as_arr_mut::<2, R>(im),
*as_arr::<2, R>(omg),
),
4 => ifft4_ref(
as_arr_mut::<4, R>(re),
as_arr_mut::<4, R>(im),
*as_arr::<4, R>(omg),
),
8 => ifft8_ref(
as_arr_mut::<8, R>(re),
as_arr_mut::<8, R>(im),
*as_arr::<8, R>(omg),
),
16 => ifft16_ref(
as_arr_mut::<16, R>(re),
as_arr_mut::<16, R>(im),
*as_arr::<16, R>(omg),
),
_ => {}
}
} else if m <= 2048 {
ifft_bfs_16_ref(m, re, im, omg, 0);
} else {
ifft_rec_16_ref(m, re, im, omg, 0);
}
}
#[inline(always)]
fn ifft_rec_16_ref<R: Float + FloatConst>(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize {
if m <= 2048 {
return ifft_bfs_16_ref(m, re, im, omg, pos);
};
let h: usize = m >> 1;
pos = ifft_rec_16_ref(h, re, im, omg, pos);
pos = ifft_rec_16_ref(h, &mut re[h..], &mut im[h..], omg, pos);
inv_twiddle_ifft_ref(h, re, im, as_arr::<2, R>(&omg[pos..]));
pos += 2;
pos
}
#[inline(always)]
fn ifft_bfs_16_ref<R: Float + FloatConst>(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize {
let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
for off in (0..m).step_by(16) {
ifft16_ref(
as_arr_mut::<16, R>(&mut re[off..]),
as_arr_mut::<16, R>(&mut im[off..]),
*as_arr::<16, R>(&omg[pos..]),
);
pos += 16;
}
let mut h: usize = 16;
let m_half: usize = m >> 1;
while h < m_half {
let mm: usize = h << 2;
for off in (0..m).step_by(mm) {
inv_bitwiddle_ifft_ref(
h,
&mut re[off..],
&mut im[off..],
as_arr::<4, R>(&omg[pos..]),
);
pos += 4;
}
h = mm;
}
if !log_m.is_multiple_of(2) {
inv_twiddle_ifft_ref(h, re, im, as_arr::<2, R>(&omg[pos..]));
pos += 2;
}
pos
}
#[inline(always)]
fn inv_twiddle<R: Float + FloatConst>(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) {
let r_diff: R = *ra - *rb;
let i_diff: R = *ia - *ib;
*ra = *ra + *rb;
*ia = *ia + *ib;
*rb = r_diff * omg_re - i_diff * omg_im;
*ib = r_diff * omg_im + i_diff * omg_re;
}
#[inline(always)]
fn inv_itwiddle<R: Float + FloatConst>(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) {
let r_diff: R = *ra - *rb;
let i_diff: R = *ia - *ib;
*ra = *ra + *rb;
*ia = *ia + *ib;
*rb = r_diff * omg_im + i_diff * omg_re;
*ib = -r_diff * omg_re + i_diff * omg_im;
}
#[inline(always)]
fn ifft2_ref<R: Float + FloatConst>(re: &mut [R; 2], im: &mut [R; 2], omg: [R; 2]) {
let [ra, rb] = re;
let [ia, ib] = im;
let [romg, iomg] = omg;
inv_twiddle(ra, ia, rb, ib, romg, iomg);
}
#[inline(always)]
fn ifft4_ref<R: Float + FloatConst>(re: &mut [R; 4], im: &mut [R; 4], omg: [R; 4]) {
let [re_0, re_1, re_2, re_3] = re;
let [im_0, im_1, im_2, im_3] = im;
{
let omg_0: R = omg[0];
let omg_1: R = omg[1];
inv_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_1);
inv_itwiddle(re_2, im_2, re_3, im_3, omg_0, omg_1);
}
{
let omg_0: R = omg[2];
let omg_1: R = omg[3];
inv_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1);
inv_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1);
}
}
#[inline(always)]
fn ifft8_ref<R: Float + FloatConst>(re: &mut [R; 8], im: &mut [R; 8], omg: [R; 8]) {
let [re_0, re_1, re_2, re_3, re_4, re_5, re_6, re_7] = re;
let [im_0, im_1, im_2, im_3, im_4, im_5, im_6, im_7] = im;
{
let omg_4: R = omg[0];
let omg_5: R = omg[1];
let omg_6: R = omg[2];
let omg_7: R = omg[3];
inv_twiddle(re_0, im_0, re_1, im_1, omg_4, omg_6);
inv_itwiddle(re_2, im_2, re_3, im_3, omg_4, omg_6);
inv_twiddle(re_4, im_4, re_5, im_5, omg_5, omg_7);
inv_itwiddle(re_6, im_6, re_7, im_7, omg_5, omg_7);
}
{
let omg_2: R = omg[4];
let omg_3: R = omg[5];
inv_twiddle(re_0, im_0, re_2, im_2, omg_2, omg_3);
inv_twiddle(re_1, im_1, re_3, im_3, omg_2, omg_3);
inv_itwiddle(re_4, im_4, re_6, im_6, omg_2, omg_3);
inv_itwiddle(re_5, im_5, re_7, im_7, omg_2, omg_3);
}
{
let omg_0: R = omg[6];
let omg_1: R = omg[7];
inv_twiddle(re_0, im_0, re_4, im_4, omg_0, omg_1);
inv_twiddle(re_1, im_1, re_5, im_5, omg_0, omg_1);
inv_twiddle(re_2, im_2, re_6, im_6, omg_0, omg_1);
inv_twiddle(re_3, im_3, re_7, im_7, omg_0, omg_1);
}
}
#[inline(always)]
fn ifft16_ref<R: Float + FloatConst>(re: &mut [R; 16], im: &mut [R; 16], omg: [R; 16]) {
let [
re_0,
re_1,
re_2,
re_3,
re_4,
re_5,
re_6,
re_7,
re_8,
re_9,
re_10,
re_11,
re_12,
re_13,
re_14,
re_15,
] = re;
let [
im_0,
im_1,
im_2,
im_3,
im_4,
im_5,
im_6,
im_7,
im_8,
im_9,
im_10,
im_11,
im_12,
im_13,
im_14,
im_15,
] = im;
{
let omg_0: R = omg[0];
let omg_1: R = omg[1];
let omg_2: R = omg[2];
let omg_3: R = omg[3];
let omg_4: R = omg[4];
let omg_5: R = omg[5];
let omg_6: R = omg[6];
let omg_7: R = omg[7];
inv_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_4);
inv_itwiddle(re_2, im_2, re_3, im_3, omg_0, omg_4);
inv_twiddle(re_4, im_4, re_5, im_5, omg_1, omg_5);
inv_itwiddle(re_6, im_6, re_7, im_7, omg_1, omg_5);
inv_twiddle(re_8, im_8, re_9, im_9, omg_2, omg_6);
inv_itwiddle(re_10, im_10, re_11, im_11, omg_2, omg_6);
inv_twiddle(re_12, im_12, re_13, im_13, omg_3, omg_7);
inv_itwiddle(re_14, im_14, re_15, im_15, omg_3, omg_7);
}
{
let omg_0: R = omg[8];
let omg_1: R = omg[9];
let omg_2: R = omg[10];
let omg_3: R = omg[11];
inv_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1);
inv_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1);
inv_itwiddle(re_4, im_4, re_6, im_6, omg_0, omg_1);
inv_itwiddle(re_5, im_5, re_7, im_7, omg_0, omg_1);
inv_twiddle(re_8, im_8, re_10, im_10, omg_2, omg_3);
inv_twiddle(re_9, im_9, re_11, im_11, omg_2, omg_3);
inv_itwiddle(re_12, im_12, re_14, im_14, omg_2, omg_3);
inv_itwiddle(re_13, im_13, re_15, im_15, omg_2, omg_3);
}
{
let omg_2: R = omg[12];
let omg_3: R = omg[13];
inv_twiddle(re_0, im_0, re_4, im_4, omg_2, omg_3);
inv_twiddle(re_1, im_1, re_5, im_5, omg_2, omg_3);
inv_twiddle(re_2, im_2, re_6, im_6, omg_2, omg_3);
inv_twiddle(re_3, im_3, re_7, im_7, omg_2, omg_3);
inv_itwiddle(re_8, im_8, re_12, im_12, omg_2, omg_3);
inv_itwiddle(re_9, im_9, re_13, im_13, omg_2, omg_3);
inv_itwiddle(re_10, im_10, re_14, im_14, omg_2, omg_3);
inv_itwiddle(re_11, im_11, re_15, im_15, omg_2, omg_3);
}
{
let omg_0: R = omg[14];
let omg_1: R = omg[15];
inv_twiddle(re_0, im_0, re_8, im_8, omg_0, omg_1);
inv_twiddle(re_1, im_1, re_9, im_9, omg_0, omg_1);
inv_twiddle(re_2, im_2, re_10, im_10, omg_0, omg_1);
inv_twiddle(re_3, im_3, re_11, im_11, omg_0, omg_1);
inv_twiddle(re_4, im_4, re_12, im_12, omg_0, omg_1);
inv_twiddle(re_5, im_5, re_13, im_13, omg_0, omg_1);
inv_twiddle(re_6, im_6, re_14, im_14, omg_0, omg_1);
inv_twiddle(re_7, im_7, re_15, im_15, omg_0, omg_1);
}
}
#[inline(always)]
fn inv_twiddle_ifft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 2]) {
let romg = omg[0];
let iomg = omg[1];
let (re_lhs, re_rhs) = re.split_at_mut(h);
let (im_lhs, im_rhs) = im.split_at_mut(h);
for i in 0..h {
inv_twiddle(
&mut re_lhs[i],
&mut im_lhs[i],
&mut re_rhs[i],
&mut im_rhs[i],
romg,
iomg,
);
}
}
#[inline(always)]
fn inv_bitwiddle_ifft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 4]) {
let (r0, r2) = re.split_at_mut(2 * h);
let (r0, r1) = r0.split_at_mut(h);
let (r2, r3) = r2.split_at_mut(h);
let (i0, i2) = im.split_at_mut(2 * h);
let (i0, i1) = i0.split_at_mut(h);
let (i2, i3) = i2.split_at_mut(h);
let omg_0: R = omg[0];
let omg_1: R = omg[1];
let omg_2: R = omg[2];
let omg_3: R = omg[3];
for i in 0..h {
inv_twiddle(&mut r0[i], &mut i0[i], &mut r1[i], &mut i1[i], omg_0, omg_1);
inv_itwiddle(&mut r2[i], &mut i2[i], &mut r3[i], &mut i3[i], omg_0, omg_1);
}
for i in 0..h {
inv_twiddle(&mut r0[i], &mut i0[i], &mut r2[i], &mut i2[i], omg_2, omg_3);
inv_twiddle(&mut r1[i], &mut i1[i], &mut r3[i], &mut i3[i], omg_2, omg_3);
}
}

View File

@@ -0,0 +1,128 @@
// ----------------------------------------------------------------------
// DISCLAIMER
//
// This module contains code that has been directly ported from the
// spqlios-arithmetic library
// (https://github.com/tfhe/spqlios-arithmetic), which is licensed
// under the Apache License, Version 2.0.
//
// The porting process from C to Rust was done with minimal changes
// in order to preserve the semantics and performance characteristics
// of the original implementation.
//
// Both Poulpy and spqlios-arithmetic are distributed under the terms
// of the Apache License, Version 2.0. See the LICENSE file for details.
//
// ----------------------------------------------------------------------
#![allow(bad_asm_style)]
mod conversion;
mod fft_ref;
mod fft_vec;
mod ifft_ref;
mod table_fft;
mod table_ifft;
mod zero;
pub use conversion::*;
pub use fft_ref::*;
pub use fft_vec::*;
pub use ifft_ref::*;
pub use table_fft::*;
pub use table_ifft::*;
pub use zero::*;
#[inline(always)]
pub(crate) fn as_arr<const size: usize, R: Float + FloatConst>(x: &[R]) -> &[R; size] {
debug_assert!(x.len() >= size);
unsafe { &*(x.as_ptr() as *const [R; size]) }
}
#[inline(always)]
pub(crate) fn as_arr_mut<const size: usize, R: Float + FloatConst>(x: &mut [R]) -> &mut [R; size] {
debug_assert!(x.len() >= size);
unsafe { &mut *(x.as_mut_ptr() as *mut [R; size]) }
}
use rand_distr::num_traits::{Float, FloatConst};
#[inline(always)]
pub(crate) fn frac_rev_bits<R: Float + FloatConst>(x: usize) -> R {
let half: R = R::from(0.5).unwrap();
match x {
0 => R::zero(),
1 => half,
_ => {
if x.is_multiple_of(2) {
frac_rev_bits::<R>(x >> 1) * half
} else {
frac_rev_bits::<R>(x >> 1) * half + half
}
}
}
}
pub trait ReimDFTExecute<D, T> {
fn reim_dft_execute(table: &D, data: &mut [T]);
}
pub trait ReimFromZnx {
fn reim_from_znx(res: &mut [f64], a: &[i64]);
}
pub trait ReimToZnx {
fn reim_to_znx(res: &mut [i64], divisor: f64, a: &[f64]);
}
pub trait ReimToZnxInplace {
fn reim_to_znx_inplace(res: &mut [f64], divisor: f64);
}
pub trait ReimAdd {
fn reim_add(res: &mut [f64], a: &[f64], b: &[f64]);
}
pub trait ReimAddInplace {
fn reim_add_inplace(res: &mut [f64], a: &[f64]);
}
pub trait ReimSub {
fn reim_sub(res: &mut [f64], a: &[f64], b: &[f64]);
}
pub trait ReimSubABInplace {
fn reim_sub_ab_inplace(res: &mut [f64], a: &[f64]);
}
pub trait ReimSubBAInplace {
fn reim_sub_ba_inplace(res: &mut [f64], a: &[f64]);
}
pub trait ReimNegate {
fn reim_negate(res: &mut [f64], a: &[f64]);
}
pub trait ReimNegateInplace {
fn reim_negate_inplace(res: &mut [f64]);
}
pub trait ReimMul {
fn reim_mul(res: &mut [f64], a: &[f64], b: &[f64]);
}
pub trait ReimMulInplace {
fn reim_mul_inplace(res: &mut [f64], a: &[f64]);
}
pub trait ReimAddMul {
fn reim_addmul(res: &mut [f64], a: &[f64], b: &[f64]);
}
pub trait ReimCopy {
fn reim_copy(res: &mut [f64], a: &[f64]);
}
pub trait ReimZero {
fn reim_zero(res: &mut [f64]);
}

View File

@@ -0,0 +1,207 @@
use std::fmt::Debug;
use rand_distr::num_traits::{Float, FloatConst};
use crate::{
alloc_aligned,
reference::fft64::reim::{ReimDFTExecute, fft_ref, frac_rev_bits},
};
pub struct ReimFFTRef;
impl ReimDFTExecute<ReimFFTTable<f64>, f64> for ReimFFTRef {
fn reim_dft_execute(table: &ReimFFTTable<f64>, data: &mut [f64]) {
fft_ref(table.m, &table.omg, data);
}
}
pub struct ReimFFTTable<R: Float + FloatConst + Debug> {
m: usize,
omg: Vec<R>,
}
impl<R: Float + FloatConst + Debug + 'static> ReimFFTTable<R> {
pub fn new(m: usize) -> Self {
assert!(m & (m - 1) == 0, "m must be a power of two but is {}", m);
let mut omg: Vec<R> = alloc_aligned::<R>(2 * m);
let quarter: R = R::from(1. / 4.).unwrap();
if m <= 16 {
match m {
1 => {}
2 => {
fill_fft2_omegas(quarter, &mut omg, 0);
}
4 => {
fill_fft4_omegas(quarter, &mut omg, 0);
}
8 => {
fill_fft8_omegas(quarter, &mut omg, 0);
}
16 => {
fill_fft16_omegas(quarter, &mut omg, 0);
}
_ => {}
}
} else if m <= 2048 {
fill_fft_bfs_16_omegas(m, quarter, &mut omg, 0);
} else {
fill_fft_rec_16_omegas(m, quarter, &mut omg, 0);
}
Self { m, omg }
}
pub fn m(&self) -> usize {
self.m
}
pub fn omg(&self) -> &[R] {
&self.omg
}
}
#[inline(always)]
fn fill_fft2_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
let omg_pos: &mut [R] = &mut omg[pos..];
assert!(omg_pos.len() >= 2);
let angle: R = j / R::from(2).unwrap();
let two_pi: R = R::from(2).unwrap() * R::PI();
omg_pos[0] = R::cos(two_pi * angle);
omg_pos[1] = R::sin(two_pi * angle);
pos + 2
}
#[inline(always)]
fn fill_fft4_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
let omg_pos: &mut [R] = &mut omg[pos..];
assert!(omg_pos.len() >= 4);
let angle_1: R = j / R::from(2).unwrap();
let angle_2: R = j / R::from(4).unwrap();
let two_pi: R = R::from(2).unwrap() * R::PI();
omg_pos[0] = R::cos(two_pi * angle_1);
omg_pos[1] = R::sin(two_pi * angle_1);
omg_pos[2] = R::cos(two_pi * angle_2);
omg_pos[3] = R::sin(two_pi * angle_2);
pos + 4
}
#[inline(always)]
fn fill_fft8_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
let omg_pos: &mut [R] = &mut omg[pos..];
assert!(omg_pos.len() >= 8);
let _8th: R = R::from(1. / 8.).unwrap();
let angle_1: R = j / R::from(2).unwrap();
let angle_2: R = j / R::from(4).unwrap();
let angle_4: R = j / R::from(8).unwrap();
let two_pi: R = R::from(2).unwrap() * R::PI();
omg_pos[0] = R::cos(two_pi * angle_1);
omg_pos[1] = R::sin(two_pi * angle_1);
omg_pos[2] = R::cos(two_pi * angle_2);
omg_pos[3] = R::sin(two_pi * angle_2);
omg_pos[4] = R::cos(two_pi * angle_4);
omg_pos[5] = R::cos(two_pi * (angle_4 + _8th));
omg_pos[6] = R::sin(two_pi * angle_4);
omg_pos[7] = R::sin(two_pi * (angle_4 + _8th));
pos + 8
}
#[inline(always)]
fn fill_fft16_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
let omg_pos: &mut [R] = &mut omg[pos..];
assert!(omg_pos.len() >= 16);
let _8th: R = R::from(1. / 8.).unwrap();
let _16th: R = R::from(1. / 16.).unwrap();
let angle_1: R = j / R::from(2).unwrap();
let angle_2: R = j / R::from(4).unwrap();
let angle_4: R = j / R::from(8).unwrap();
let angle_8: R = j / R::from(16).unwrap();
let two_pi: R = R::from(2).unwrap() * R::PI();
omg_pos[0] = R::cos(two_pi * angle_1);
omg_pos[1] = R::sin(two_pi * angle_1);
omg_pos[2] = R::cos(two_pi * angle_2);
omg_pos[3] = R::sin(two_pi * angle_2);
omg_pos[4] = R::cos(two_pi * angle_4);
omg_pos[5] = R::sin(two_pi * angle_4);
omg_pos[6] = R::cos(two_pi * (angle_4 + _8th));
omg_pos[7] = R::sin(two_pi * (angle_4 + _8th));
omg_pos[8] = R::cos(two_pi * angle_8);
omg_pos[9] = R::cos(two_pi * (angle_8 + _8th));
omg_pos[10] = R::cos(two_pi * (angle_8 + _16th));
omg_pos[11] = R::cos(two_pi * (angle_8 + _8th + _16th));
omg_pos[12] = R::sin(two_pi * angle_8);
omg_pos[13] = R::sin(two_pi * (angle_8 + _8th));
omg_pos[14] = R::sin(two_pi * (angle_8 + _16th));
omg_pos[15] = R::sin(two_pi * (angle_8 + _8th + _16th));
pos + 16
}
#[inline(always)]
fn fill_fft_bfs_16_omegas<R: Float + FloatConst>(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize {
let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
let mut mm: usize = m;
let mut jj: R = j;
let two_pi: R = R::from(2).unwrap() * R::PI();
if !log_m.is_multiple_of(2) {
let h = mm >> 1;
let j: R = jj * R::from(0.5).unwrap();
omg[pos] = R::cos(two_pi * j);
omg[pos + 1] = R::sin(two_pi * j);
pos += 2;
mm = h;
jj = j
}
while mm > 16 {
let h: usize = mm >> 2;
let j: R = jj * R::from(1. / 4.).unwrap();
for i in (0..m).step_by(mm) {
let rs_0 = j + frac_rev_bits::<R>(i / mm) * R::from(1. / 4.).unwrap();
let rs_1 = R::from(2).unwrap() * rs_0;
omg[pos] = R::cos(two_pi * rs_1);
omg[pos + 1] = R::sin(two_pi * rs_1);
omg[pos + 2] = R::cos(two_pi * rs_0);
omg[pos + 3] = R::sin(two_pi * rs_0);
pos += 4;
}
mm = h;
jj = j;
}
for i in (0..m).step_by(16) {
let j = jj + frac_rev_bits(i >> 4);
fill_fft16_omegas(j, omg, pos);
pos += 16
}
pos
}
#[inline(always)]
fn fill_fft_rec_16_omegas<R: Float + FloatConst>(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize {
if m <= 2048 {
return fill_fft_bfs_16_omegas(m, j, omg, pos);
}
let h: usize = m >> 1;
let s: R = j * R::from(0.5).unwrap();
let _2pi = R::from(2).unwrap() * R::PI();
omg[pos] = R::cos(_2pi * s);
omg[pos + 1] = R::sin(_2pi * s);
pos += 2;
pos = fill_fft_rec_16_omegas(h, s, omg, pos);
pos = fill_fft_rec_16_omegas(h, s + R::from(0.5).unwrap(), omg, pos);
pos
}
#[inline(always)]
fn ctwiddle_ref(ra: &mut f64, ia: &mut f64, rb: &mut f64, ib: &mut f64, omg_re: f64, omg_im: f64) {
let dr: f64 = *rb * omg_re - *ib * omg_im;
let di: f64 = *rb * omg_im + *ib * omg_re;
*rb = *ra - dr;
*ib = *ia - di;
*ra += dr;
*ia += di;
}

View File

@@ -0,0 +1,201 @@
use std::fmt::Debug;
use rand_distr::num_traits::{Float, FloatConst};
use crate::{
alloc_aligned,
reference::fft64::reim::{ReimDFTExecute, frac_rev_bits, ifft_ref::ifft_ref},
};
pub struct ReimIFFTRef;
impl ReimDFTExecute<ReimIFFTTable<f64>, f64> for ReimIFFTRef {
fn reim_dft_execute(table: &ReimIFFTTable<f64>, data: &mut [f64]) {
ifft_ref(table.m, &table.omg, data);
}
}
pub struct ReimIFFTTable<R: Float + FloatConst + Debug> {
m: usize,
omg: Vec<R>,
}
impl<R: Float + FloatConst + Debug> ReimIFFTTable<R> {
pub fn new(m: usize) -> Self {
assert!(m & (m - 1) == 0, "m must be a power of two but is {}", m);
let mut omg: Vec<R> = alloc_aligned::<R>(2 * m);
let quarter: R = R::exp2(R::from(-2).unwrap());
if m <= 16 {
match m {
1 => {}
2 => {
fill_ifft2_omegas::<R>(quarter, &mut omg, 0);
}
4 => {
fill_ifft4_omegas(quarter, &mut omg, 0);
}
8 => {
fill_ifft8_omegas(quarter, &mut omg, 0);
}
16 => {
fill_ifft16_omegas(quarter, &mut omg, 0);
}
_ => {}
}
} else if m <= 2048 {
fill_ifft_bfs_16_omegas(m, quarter, &mut omg, 0);
} else {
fill_ifft_rec_16_omegas(m, quarter, &mut omg, 0);
}
Self { m, omg }
}
pub fn execute(&self, data: &mut [R]) {
ifft_ref(self.m, &self.omg, data);
}
pub fn m(&self) -> usize {
self.m
}
pub fn omg(&self) -> &[R] {
&self.omg
}
}
#[inline(always)]
fn fill_ifft2_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
let omg_pos: &mut [R] = &mut omg[pos..];
assert!(omg_pos.len() >= 2);
let angle: R = j / R::exp2(R::from(2).unwrap());
let two_pi: R = R::exp2(R::from(2).unwrap()) * R::PI();
omg_pos[0] = R::cos(two_pi * angle);
omg_pos[1] = -R::sin(two_pi * angle);
pos + 2
}
#[inline(always)]
fn fill_ifft4_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
let omg_pos: &mut [R] = &mut omg[pos..];
assert!(omg_pos.len() >= 4);
let angle_1: R = j / R::from(2).unwrap();
let angle_2: R = j / R::from(4).unwrap();
let two_pi: R = R::from(2).unwrap() * R::PI();
omg_pos[0] = R::cos(two_pi * angle_2);
omg_pos[1] = -R::sin(two_pi * angle_2);
omg_pos[2] = R::cos(two_pi * angle_1);
omg_pos[3] = -R::sin(two_pi * angle_1);
pos + 4
}
#[inline(always)]
fn fill_ifft8_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
let omg_pos: &mut [R] = &mut omg[pos..];
assert!(omg_pos.len() >= 8);
let _8th: R = R::from(1. / 8.).unwrap();
let angle_1: R = j / R::from(2).unwrap();
let angle_2: R = j / R::from(4).unwrap();
let angle_4: R = j / R::from(2).unwrap();
let two_pi: R = R::from(2).unwrap() * R::PI();
omg_pos[0] = R::cos(two_pi * angle_4);
omg_pos[1] = R::cos(two_pi * (angle_4 + _8th));
omg_pos[2] = -R::sin(two_pi * angle_4);
omg_pos[3] = -R::sin(two_pi * (angle_4 + _8th));
omg_pos[4] = R::cos(two_pi * angle_2);
omg_pos[5] = -R::sin(two_pi * angle_2);
omg_pos[6] = R::cos(two_pi * angle_1);
omg_pos[7] = -R::sin(two_pi * angle_1);
pos + 8
}
#[inline(always)]
fn fill_ifft16_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
let omg_pos: &mut [R] = &mut omg[pos..];
assert!(omg_pos.len() >= 16);
let _8th: R = R::from(1. / 8.).unwrap();
let _16th: R = R::from(1. / 16.).unwrap();
let angle_1: R = j / R::from(2).unwrap();
let angle_2: R = j / R::from(4).unwrap();
let angle_4: R = j / R::from(8).unwrap();
let angle_8: R = j / R::from(16).unwrap();
let two_pi: R = R::from(2).unwrap() * R::PI();
omg_pos[0] = R::cos(two_pi * angle_8);
omg_pos[1] = R::cos(two_pi * (angle_8 + _8th));
omg_pos[2] = R::cos(two_pi * (angle_8 + _16th));
omg_pos[3] = R::cos(two_pi * (angle_8 + _8th + _16th));
omg_pos[4] = -R::sin(two_pi * angle_8);
omg_pos[5] = -R::sin(two_pi * (angle_8 + _8th));
omg_pos[6] = -R::sin(two_pi * (angle_8 + _16th));
omg_pos[7] = -R::sin(two_pi * (angle_8 + _8th + _16th));
omg_pos[8] = R::cos(two_pi * angle_4);
omg_pos[9] = -R::sin(two_pi * angle_4);
omg_pos[10] = R::cos(two_pi * (angle_4 + _8th));
omg_pos[11] = -R::sin(two_pi * (angle_4 + _8th));
omg_pos[12] = R::cos(two_pi * angle_2);
omg_pos[13] = -R::sin(two_pi * angle_2);
omg_pos[14] = R::cos(two_pi * angle_1);
omg_pos[15] = -R::sin(two_pi * angle_1);
pos + 16
}
#[inline(always)]
fn fill_ifft_bfs_16_omegas<R: Float + FloatConst + Debug>(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize {
let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
let mut jj: R = j * R::from(16).unwrap() / R::from(m).unwrap();
for i in (0..m).step_by(16) {
let j = jj + frac_rev_bits(i >> 4);
fill_ifft16_omegas(j, omg, pos);
pos += 16
}
let mut h: usize = 16;
let m_half: usize = m >> 1;
let two_pi: R = R::from(2).unwrap() * R::PI();
while h < m_half {
let mm: usize = h << 2;
for i in (0..m).step_by(mm) {
let rs_0 = jj + frac_rev_bits::<R>(i / mm) / R::from(4).unwrap();
let rs_1 = R::from(2).unwrap() * rs_0;
omg[pos] = R::cos(two_pi * rs_0);
omg[pos + 1] = -R::sin(two_pi * rs_0);
omg[pos + 2] = R::cos(two_pi * rs_1);
omg[pos + 3] = -R::sin(two_pi * rs_1);
pos += 4;
}
h = mm;
jj = jj * R::from(4).unwrap();
}
if !log_m.is_multiple_of(2) {
omg[pos] = R::cos(two_pi * jj);
omg[pos + 1] = -R::sin(two_pi * jj);
pos += 2;
jj = jj * R::from(2).unwrap();
}
assert_eq!(jj, j);
pos
}
#[inline(always)]
fn fill_ifft_rec_16_omegas<R: Float + FloatConst + Debug>(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize {
if m <= 2048 {
return fill_ifft_bfs_16_omegas(m, j, omg, pos);
}
let h: usize = m >> 1;
let s: R = j / R::from(2).unwrap();
pos = fill_ifft_rec_16_omegas(h, s, omg, pos);
pos = fill_ifft_rec_16_omegas(h, s + R::from(0.5).unwrap(), omg, pos);
let _2pi = R::from(2).unwrap() * R::PI();
omg[pos] = R::cos(_2pi * s);
omg[pos + 1] = -R::sin(_2pi * s);
pos += 2;
pos
}

View File

@@ -0,0 +1,11 @@
pub fn reim_zero_ref(res: &mut [f64]) {
res.fill(0.);
}
pub fn reim_copy_ref(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len())
}
res.copy_from_slice(a);
}

View File

@@ -0,0 +1,209 @@
use crate::reference::fft64::reim::as_arr;
#[inline(always)]
pub fn reim4_extract_1blk_from_reim_ref(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
let mut offset: usize = blk << 2;
debug_assert!(blk < (m >> 2));
debug_assert!(dst.len() >= 2 * rows * 4);
for chunk in dst.chunks_exact_mut(4).take(2 * rows) {
chunk.copy_from_slice(&src[offset..offset + 4]);
offset += m
}
}
#[inline(always)]
pub fn reim4_save_1blk_to_reim_ref<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
let mut offset: usize = blk << 2;
debug_assert!(blk < (m >> 2));
debug_assert!(dst.len() >= offset + m + 4);
debug_assert!(src.len() >= 8);
let dst_off = &mut dst[offset..offset + 4];
if OVERWRITE {
dst_off.copy_from_slice(&src[0..4]);
} else {
dst_off[0] += src[0];
dst_off[1] += src[1];
dst_off[2] += src[2];
dst_off[3] += src[3];
}
offset += m;
let dst_off = &mut dst[offset..offset + 4];
if OVERWRITE {
dst_off.copy_from_slice(&src[4..8]);
} else {
dst_off[0] += src[4];
dst_off[1] += src[5];
dst_off[2] += src[6];
dst_off[3] += src[7];
}
}
#[inline(always)]
pub fn reim4_save_2blk_to_reim_ref<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
let mut offset: usize = blk << 2;
debug_assert!(blk < (m >> 2));
debug_assert!(dst.len() >= offset + 3 * m + 4);
debug_assert!(src.len() >= 16);
let dst_off = &mut dst[offset..offset + 4];
if OVERWRITE {
dst_off.copy_from_slice(&src[0..4]);
} else {
dst_off[0] += src[0];
dst_off[1] += src[1];
dst_off[2] += src[2];
dst_off[3] += src[3];
}
offset += m;
let dst_off = &mut dst[offset..offset + 4];
if OVERWRITE {
dst_off.copy_from_slice(&src[4..8]);
} else {
dst_off[0] += src[4];
dst_off[1] += src[5];
dst_off[2] += src[6];
dst_off[3] += src[7];
}
offset += m;
let dst_off = &mut dst[offset..offset + 4];
if OVERWRITE {
dst_off.copy_from_slice(&src[8..12]);
} else {
dst_off[0] += src[8];
dst_off[1] += src[9];
dst_off[2] += src[10];
dst_off[3] += src[11];
}
offset += m;
let dst_off = &mut dst[offset..offset + 4];
if OVERWRITE {
dst_off.copy_from_slice(&src[12..16]);
} else {
dst_off[0] += src[12];
dst_off[1] += src[13];
dst_off[2] += src[14];
dst_off[3] += src[15];
}
}
#[inline(always)]
pub fn reim4_vec_mat1col_product_ref(
nrows: usize,
dst: &mut [f64], // 8 doubles: [re1(4), im1(4)]
u: &[f64], // nrows * 8 doubles: [ur(4) | ui(4)] per row
v: &[f64], // nrows * 8 doubles: [ar(4) | ai(4)] per row
) {
#[cfg(debug_assertions)]
{
assert!(dst.len() >= 8, "dst must have at least 8 doubles");
assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles");
assert!(v.len() >= nrows * 8, "v must be at least nrows * 8 doubles");
}
println!("u_ref: {:?}", &u[..nrows * 8]);
println!("v_ref: {:?}", &v[..nrows * 8]);
let mut acc: [f64; 8] = [0f64; 8];
let mut j = 0;
for _ in 0..nrows {
reim4_add_mul(&mut acc, as_arr(&u[j..]), as_arr(&v[j..]));
j += 8;
}
dst[0..8].copy_from_slice(&acc);
println!("dst_ref: {:?}", &dst[..8]);
println!();
}
#[inline(always)]
pub fn reim4_vec_mat2cols_product_ref(
nrows: usize,
dst: &mut [f64], // 16 doubles: [re1(4), im1(4), re2(4), im2(4)]
u: &[f64], // nrows * 8 doubles: [ur(4) | ui(4)] per row
v: &[f64], // nrows * 16 doubles: [ar(4) | ai(4) | br(4) | bi(4)] per row
) {
#[cfg(debug_assertions)]
{
assert_eq!(dst.len(), 16, "dst must have 16 doubles");
assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles");
assert!(
v.len() >= nrows * 16,
"v must be at least nrows * 16 doubles"
);
}
// zero accumulators
let mut acc_0: [f64; 8] = [0f64; 8];
let mut acc_1: [f64; 8] = [0f64; 8];
for i in 0..nrows {
let _1j: usize = i << 3;
let _2j: usize = i << 4;
let u_j: &[f64; 8] = as_arr(&u[_1j..]);
reim4_add_mul(&mut acc_0, u_j, as_arr(&v[_2j..]));
reim4_add_mul(&mut acc_1, u_j, as_arr(&v[_2j + 8..]));
}
dst[0..8].copy_from_slice(&acc_0);
dst[8..16].copy_from_slice(&acc_1);
}
#[inline(always)]
pub fn reim4_vec_mat2cols_2ndcol_product_ref(
nrows: usize,
dst: &mut [f64], // 8 doubles: [re1(4), im1(4), re2(4), im2(4)]
u: &[f64], // nrows * 8 doubles: [ur(4) | ui(4)] per row
v: &[f64], // nrows * 16 doubles: [x | x | br(4) | bi(4)] per row
) {
#[cfg(debug_assertions)]
{
assert!(
dst.len() >= 8,
"dst must be at least 8 doubles but is {}",
dst.len()
);
assert!(
u.len() >= nrows * 8,
"u must be at least nrows={} * 8 doubles but is {}",
nrows,
u.len()
);
assert!(
v.len() >= nrows * 16,
"v must be at least nrows={} * 16 doubles but is {}",
nrows,
v.len()
);
}
// zero accumulators
let mut acc: [f64; 8] = [0f64; 8];
for i in 0..nrows {
let _1j: usize = i << 3;
let _2j: usize = i << 4;
reim4_add_mul(&mut acc, as_arr(&u[_1j..]), as_arr(&v[_2j + 8..]));
}
dst[0..8].copy_from_slice(&acc);
}
#[inline(always)]
pub fn reim4_add_mul(dst: &mut [f64; 8], a: &[f64; 8], b: &[f64; 8]) {
for k in 0..4 {
let ar: f64 = a[k];
let br: f64 = b[k];
let ai: f64 = a[k + 4];
let bi: f64 = b[k + 4];
dst[k] += ar * br - ai * bi;
dst[k + 4] += ar * bi + ai * br;
}
}

View File

@@ -0,0 +1,27 @@
mod arithmetic_ref;
pub use arithmetic_ref::*;
pub trait Reim4Extract1Blk {
fn reim4_extract_1blk(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]);
}
pub trait Reim4Save1Blk {
fn reim4_save_1blk<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]);
}
pub trait Reim4Save2Blks {
fn reim4_save_2blks<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]);
}
pub trait Reim4Mat1ColProd {
fn reim4_mat1col_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]);
}
pub trait Reim4Mat2ColsProd {
fn reim4_mat2cols_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]);
}
pub trait Reim4Mat2Cols2ndColProd {
fn reim4_mat2cols_2ndcol_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]);
}

View File

@@ -0,0 +1,119 @@
use crate::{
layouts::{
Backend, ScalarZnx, ScalarZnxToRef, SvpPPol, SvpPPolToMut, SvpPPolToRef, VecZnx, VecZnxDft, VecZnxDftToMut,
VecZnxDftToRef, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut,
},
reference::fft64::reim::{ReimAddMul, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimMul, ReimMulInplace, ReimZero},
};
pub fn svp_prepare<R, A, BE>(table: &ReimFFTTable<f64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx,
R: SvpPPolToMut<BE>,
A: ScalarZnxToRef,
{
let mut res: SvpPPol<&mut [u8], BE> = res.to_mut();
let a: ScalarZnx<&[u8]> = a.to_ref();
BE::reim_from_znx(res.at_mut(res_col, 0), a.at(a_col, 0));
BE::reim_dft_execute(table, res.at_mut(res_col, 0));
}
pub fn svp_apply_dft<R, A, B, BE>(
table: &ReimFFTTable<f64>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
) where
BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimZero + ReimFromZnx + ReimMulInplace,
R: VecZnxDftToMut<BE>,
A: SvpPPolToRef<BE>,
B: VecZnxToRef,
{
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
let a: SvpPPol<&[u8], BE> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let res_size: usize = res.size();
let b_size: usize = b.size();
let min_size: usize = res_size.min(b_size);
let ppol: &[f64] = a.at(a_col, 0);
for j in 0..min_size {
let out: &mut [f64] = res.at_mut(res_col, j);
BE::reim_from_znx(out, b.at(b_col, j));
BE::reim_dft_execute(table, out);
BE::reim_mul_inplace(out, ppol);
}
for j in min_size..res_size {
BE::reim_zero(res.at_mut(res_col, j));
}
}
pub fn svp_apply_dft_to_dft<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimMul + ReimZero,
R: VecZnxDftToMut<BE>,
A: SvpPPolToRef<BE>,
B: VecZnxDftToRef<BE>,
{
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
let a: SvpPPol<&[u8], BE> = a.to_ref();
let b: VecZnxDft<&[u8], BE> = b.to_ref();
let res_size: usize = res.size();
let b_size: usize = b.size();
let min_size: usize = res_size.min(b_size);
let ppol: &[f64] = a.at(a_col, 0);
for j in 0..min_size {
BE::reim_mul(res.at_mut(res_col, j), ppol, b.at(b_col, j));
}
for j in min_size..res_size {
BE::reim_zero(res.at_mut(res_col, j));
}
}
pub fn svp_apply_dft_to_dft_add<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimAddMul + ReimZero,
R: VecZnxDftToMut<BE>,
A: SvpPPolToRef<BE>,
B: VecZnxDftToRef<BE>,
{
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
let a: SvpPPol<&[u8], BE> = a.to_ref();
let b: VecZnxDft<&[u8], BE> = b.to_ref();
let res_size: usize = res.size();
let b_size: usize = b.size();
let min_size: usize = res_size.min(b_size);
let ppol: &[f64] = a.at(a_col, 0);
for j in 0..min_size {
BE::reim_addmul(res.at_mut(res_col, j), ppol, b.at(b_col, j));
}
for j in min_size..res_size {
BE::reim_zero(res.at_mut(res_col, j));
}
}
pub fn svp_apply_dft_to_dft_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimMulInplace,
R: VecZnxDftToMut<BE>,
A: SvpPPolToRef<BE>,
{
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
let a: SvpPPol<&[u8], BE> = a.to_ref();
let ppol: &[f64] = a.at(a_col, 0);
for j in 0..res.size() {
BE::reim_mul_inplace(res.at_mut(res_col, j), ppol);
}
}

View File

@@ -0,0 +1,521 @@
use std::f64::consts::SQRT_2;
use crate::{
api::VecZnxBigAddNormal,
layouts::{
Backend, Module, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef, ZnxView, ZnxViewMut,
},
oep::VecZnxBigAllocBytesImpl,
reference::{
vec_znx::{
vec_znx_add, vec_znx_add_inplace, vec_znx_automorphism, vec_znx_automorphism_inplace, vec_znx_negate,
vec_znx_negate_inplace, vec_znx_normalize, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace,
},
znx::{
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly,
ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero, znx_add_normal_f64_ref,
},
},
source::Source,
};
pub fn vec_znx_big_add<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxAdd + ZnxCopy + ZnxZero,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
B: VecZnxBigToRef<BE>,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let a: VecZnxBig<&[u8], BE> = a.to_ref();
let b: VecZnxBig<&[u8], BE> = b.to_ref();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
let a_vznx: VecZnx<&[u8]> = VecZnx {
data: a.data,
n: a.n,
cols: a.cols,
size: a.size,
max_size: a.max_size,
};
let b_vznx: VecZnx<&[u8]> = VecZnx {
data: b.data,
n: b.n,
cols: b.cols,
size: b.size,
max_size: b.max_size,
};
vec_znx_add::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, &b_vznx, b_col);
}
pub fn vec_znx_big_add_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxAddInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let a: VecZnxBig<&[u8], BE> = a.to_ref();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
let a_vznx: VecZnx<&[u8]> = VecZnx {
data: a.data,
n: a.n,
cols: a.cols,
size: a.size,
max_size: a.max_size,
};
vec_znx_add_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}
pub fn vec_znx_big_add_small<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxAdd + ZnxCopy + ZnxZero,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
B: VecZnxToRef,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let a: VecZnxBig<&[u8], BE> = a.to_ref();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
let a_vznx: VecZnx<&[u8]> = VecZnx {
data: a.data,
n: a.n,
cols: a.cols,
size: a.size,
max_size: a.max_size,
};
vec_znx_add::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, b, b_col);
}
pub fn vec_znx_big_add_small_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxAddInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxToRef,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
vec_znx_add_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
}
pub fn vec_znx_big_automorphism_inplace_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_big_automorphism<R, A, BE>(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxAutomorphism + ZnxZero,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
{
let res: VecZnxBig<&mut [u8], _> = res.to_mut();
let a: VecZnxBig<&[u8], _> = a.to_ref();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
let a_vznx: VecZnx<&[u8]> = VecZnx {
data: a.data,
n: a.n,
cols: a.cols,
size: a.size,
max_size: a.max_size,
};
vec_znx_automorphism::<_, _, BE>(p, &mut res_vznx, res_col, &a_vznx, a_col);
}
pub fn vec_znx_big_automorphism_inplace<R, BE>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
where
BE: Backend<ScalarBig = i64> + ZnxAutomorphism + ZnxCopy,
R: VecZnxBigToMut<BE>,
{
let res: VecZnxBig<&mut [u8], _> = res.to_mut();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
vec_znx_automorphism_inplace::<_, BE>(p, &mut res_vznx, res_col, tmp);
}
pub fn vec_znx_big_negate<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxNegate + ZnxZero,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
{
let res: VecZnxBig<&mut [u8], _> = res.to_mut();
let a: VecZnxBig<&[u8], _> = a.to_ref();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
let a_vznx: VecZnx<&[u8]> = VecZnx {
data: a.data,
n: a.n,
cols: a.cols,
size: a.size,
max_size: a.max_size,
};
vec_znx_negate::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}
pub fn vec_znx_big_negate_inplace<R, BE>(res: &mut R, res_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxNegateInplace,
R: VecZnxBigToMut<BE>,
{
let res: VecZnxBig<&mut [u8], _> = res.to_mut();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
vec_znx_negate_inplace::<_, BE>(&mut res_vznx, res_col);
}
pub fn vec_znx_big_normalize_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_big_normalize<R, A, BE>(basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
A: VecZnxBigToRef<BE>,
BE: Backend<ScalarBig = i64>
+ ZnxNormalizeFirstStepCarryOnly
+ ZnxNormalizeMiddleStepCarryOnly
+ ZnxNormalizeMiddleStep
+ ZnxNormalizeFinalStep
+ ZnxNormalizeFirstStep
+ ZnxZero,
{
let a: VecZnxBig<&[u8], _> = a.to_ref();
let a_vznx: VecZnx<&[u8]> = VecZnx {
data: a.data,
n: a.n,
cols: a.cols,
size: a.size,
max_size: a.max_size,
};
vec_znx_normalize::<_, _, BE>(basek, res, res_col, &a_vznx, a_col, carry);
}
pub fn vec_znx_big_add_normal_ref<R, B: Backend<ScalarBig = i64>>(
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
sigma: f64,
bound: f64,
source: &mut Source,
) where
R: VecZnxBigToMut<B>,
{
let mut res: VecZnxBig<&mut [u8], B> = res.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
znx_add_normal_f64_ref(
res.at_mut(res_col, limb),
sigma * scale,
bound * scale,
source,
)
}
pub fn test_vec_znx_big_add_normal<B>(module: &Module<B>)
where
Module<B>: VecZnxBigAddNormal<B>,
B: Backend<ScalarBig = i64> + VecZnxBigAllocBytesImpl<B>,
{
let n: usize = module.n();
let basek: usize = 17;
let k: usize = 2 * 17;
let size: usize = 5;
let sigma: f64 = 3.2;
let bound: f64 = 6.0 * sigma;
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let zero: Vec<i64> = vec![0; n];
let k_f64: f64 = (1u64 << k as u64) as f64;
let sqrt2: f64 = SQRT_2;
(0..cols).for_each(|col_i| {
let mut a: VecZnxBig<Vec<u8>, B> = VecZnxBig::alloc(n, cols, size);
module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..size).for_each(|limb_i| {
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(basek, col_i) * k_f64;
assert!(
(std - sigma * sqrt2).abs() < 0.1,
"std={} ~!= {}",
std,
sigma * sqrt2
);
}
})
});
}
/// R <- A - B
pub fn vec_znx_big_sub<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSub + ZnxNegate + ZnxZero + ZnxCopy,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
B: VecZnxBigToRef<BE>,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let a: VecZnxBig<&[u8], BE> = a.to_ref();
let b: VecZnxBig<&[u8], BE> = b.to_ref();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
let a_vznx: VecZnx<&[u8]> = VecZnx {
data: a.data,
n: a.n,
cols: a.cols,
size: a.size,
max_size: a.max_size,
};
let b_vznx: VecZnx<&[u8]> = VecZnx {
data: b.data,
n: b.n,
cols: b.cols,
size: b.size,
max_size: b.max_size,
};
vec_znx_sub::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, &b_vznx, b_col);
}
/// R <- A - B
pub fn vec_znx_big_sub_ab_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSubABInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let a: VecZnxBig<&[u8], BE> = a.to_ref();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
let a_vznx: VecZnx<&[u8]> = VecZnx {
data: a.data,
n: a.n,
cols: a.cols,
size: a.size,
max_size: a.max_size,
};
vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}
/// R <- B - A
pub fn vec_znx_big_sub_ba_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSubBAInplace + ZnxNegateInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let a: VecZnxBig<&[u8], BE> = a.to_ref();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
let a_vznx: VecZnx<&[u8]> = VecZnx {
data: a.data,
n: a.n,
cols: a.cols,
size: a.size,
max_size: a.max_size,
};
vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}
/// R <- A - B
pub fn vec_znx_big_sub_small_a<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSub + ZnxNegate + ZnxZero + ZnxCopy,
R: VecZnxBigToMut<BE>,
A: VecZnxToRef,
B: VecZnxBigToRef<BE>,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let b: VecZnxBig<&[u8], BE> = b.to_ref();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
let b_vznx: VecZnx<&[u8]> = VecZnx {
data: b.data,
n: b.n,
cols: b.cols,
size: b.size,
max_size: b.max_size,
};
vec_znx_sub::<_, _, _, BE>(&mut res_vznx, res_col, a, a_col, &b_vznx, b_col);
}
/// R <- A - B
pub fn vec_znx_big_sub_small_b<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSub + ZnxNegate + ZnxZero + ZnxCopy,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
B: VecZnxToRef,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let a: VecZnxBig<&[u8], BE> = a.to_ref();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
let a_vznx: VecZnx<&[u8]> = VecZnx {
data: a.data,
n: a.n,
cols: a.cols,
size: a.size,
max_size: a.max_size,
};
vec_znx_sub::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, b, b_col);
}
/// R <- R - A
pub fn vec_znx_big_sub_small_a_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSubABInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxToRef,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
}
/// R <- A - R
pub fn vec_znx_big_sub_small_b_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSubBAInplace + ZnxNegateInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxToRef,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
}

View File

@@ -0,0 +1,369 @@
use bytemuck::cast_slice_mut;
use crate::{
layouts::{
Backend, Data, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos,
ZnxView, ZnxViewMut,
},
reference::{
fft64::reim::{
ReimAdd, ReimAddInplace, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimNegate,
ReimNegateInplace, ReimSub, ReimSubABInplace, ReimSubBAInplace, ReimToZnx, ReimToZnxInplace, ReimZero,
},
znx::ZnxZero,
},
};
pub fn vec_znx_dft_add<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimAdd + ReimCopy + ReimZero,
R: VecZnxDftToMut<BE>,
A: VecZnxDftToRef<BE>,
B: VecZnxDftToRef<BE>,
{
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
let a: VecZnxDft<&[u8], BE> = a.to_ref();
let b: VecZnxDft<&[u8], BE> = b.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let b_size: usize = b.size();
if a_size <= b_size {
let sum_size: usize = a_size.min(res_size);
let cpy_size: usize = b_size.min(res_size);
for j in 0..sum_size {
BE::reim_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
}
for j in sum_size..cpy_size {
BE::reim_copy(res.at_mut(res_col, j), b.at(b_col, j));
}
for j in cpy_size..res_size {
BE::reim_zero(res.at_mut(res_col, j));
}
} else {
let sum_size: usize = b_size.min(res_size);
let cpy_size: usize = a_size.min(res_size);
for j in 0..sum_size {
BE::reim_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
}
for j in sum_size..cpy_size {
BE::reim_copy(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in cpy_size..res_size {
BE::reim_zero(res.at_mut(res_col, j));
}
}
}
pub fn vec_znx_dft_add_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimAddInplace,
R: VecZnxDftToMut<BE>,
A: VecZnxDftToRef<BE>,
{
let a: VecZnxDft<&[u8], BE> = a.to_ref();
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
BE::reim_add_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
}
pub fn vec_znx_dft_copy<R, A, BE>(step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimCopy + ReimZero,
R: VecZnxDftToMut<BE>,
A: VecZnxDftToRef<BE>,
{
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
let a: VecZnxDft<&[u8], BE> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), a.n())
}
let steps: usize = a.size().div_ceil(step);
let min_steps: usize = res.size().min(steps);
(0..min_steps).for_each(|j| {
let limb: usize = offset + j * step;
if limb < a.size() {
BE::reim_copy(res.at_mut(res_col, j), a.at(a_col, limb));
}
});
(min_steps..res.size()).for_each(|j| {
BE::reim_zero(res.at_mut(res_col, j));
})
}
pub fn vec_znx_dft_apply<R, A, BE>(
table: &ReimFFTTable<f64>,
step: usize,
offset: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
) where
BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + ReimZero,
R: VecZnxDftToMut<BE>,
A: VecZnxToRef,
{
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
let a: VecZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert!(step > 0);
assert_eq!(table.m() << 1, res.n());
assert_eq!(a.n(), res.n());
}
let a_size: usize = a.size();
let res_size: usize = res.size();
let steps: usize = a_size.div_ceil(step);
let min_steps: usize = res_size.min(steps);
for j in 0..min_steps {
let limb = offset + j * step;
if limb < a_size {
BE::reim_from_znx(res.at_mut(res_col, j), a.at(a_col, limb));
BE::reim_dft_execute(table, res.at_mut(res_col, j));
}
}
(min_steps..res.size()).for_each(|j| {
BE::reim_zero(res.at_mut(res_col, j));
});
}
pub fn vec_znx_idft_apply<R, A, BE>(table: &ReimIFFTTable<f64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarPrep = f64, ScalarBig = i64>
+ ReimDFTExecute<ReimIFFTTable<f64>, f64>
+ ReimCopy
+ ReimToZnxInplace
+ ZnxZero,
R: VecZnxBigToMut<BE>,
A: VecZnxDftToRef<BE>,
{
let mut res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let a: VecZnxDft<&[u8], BE> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(table.m() << 1, res.n());
assert_eq!(a.n(), res.n());
}
let res_size: usize = res.size();
let min_size: usize = res_size.min(a.size());
let divisor: f64 = table.m() as f64;
for j in 0..min_size {
let res_slice_f64: &mut [f64] = cast_slice_mut(res.at_mut(res_col, j));
BE::reim_copy(res_slice_f64, a.at(a_col, j));
BE::reim_dft_execute(table, res_slice_f64);
BE::reim_to_znx_inplace(res_slice_f64, divisor);
}
for j in min_size..res_size {
BE::znx_zero(res.at_mut(res_col, j));
}
}
pub fn vec_znx_idft_apply_tmpa<R, A, BE>(table: &ReimIFFTTable<f64>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
where
BE: Backend<ScalarPrep = f64, ScalarBig = i64> + ReimDFTExecute<ReimIFFTTable<f64>, f64> + ReimToZnx + ZnxZero,
R: VecZnxBigToMut<BE>,
A: VecZnxDftToMut<BE>,
{
let mut res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let mut a: VecZnxDft<&mut [u8], BE> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(table.m() << 1, res.n());
assert_eq!(a.n(), res.n());
}
let res_size = res.size();
let min_size: usize = res_size.min(a.size());
let divisor: f64 = table.m() as f64;
for j in 0..min_size {
BE::reim_dft_execute(table, a.at_mut(a_col, j));
BE::reim_to_znx(res.at_mut(res_col, j), divisor, a.at(a_col, j));
}
for j in min_size..res_size {
BE::znx_zero(res.at_mut(res_col, j));
}
}
pub fn vec_znx_idft_apply_consume<D: Data, BE>(table: &ReimIFFTTable<f64>, mut res: VecZnxDft<D, BE>) -> VecZnxBig<D, BE>
where
BE: Backend<ScalarPrep = f64, ScalarBig = i64> + ReimDFTExecute<ReimIFFTTable<f64>, f64> + ReimToZnxInplace,
VecZnxDft<D, BE>: VecZnxDftToMut<BE>,
{
{
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(table.m() << 1, res.n());
}
let divisor: f64 = table.m() as f64;
for i in 0..res.cols() {
for j in 0..res.size() {
BE::reim_dft_execute(table, res.at_mut(i, j));
BE::reim_to_znx_inplace(res.at_mut(i, j), divisor);
}
}
}
res.into_big()
}
pub fn vec_znx_dft_sub<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimSub + ReimNegate + ReimZero + ReimCopy,
R: VecZnxDftToMut<BE>,
A: VecZnxDftToRef<BE>,
B: VecZnxDftToRef<BE>,
{
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
let a: VecZnxDft<&[u8], BE> = a.to_ref();
let b: VecZnxDft<&[u8], BE> = b.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let b_size: usize = b.size();
if a_size <= b_size {
let sum_size: usize = a_size.min(res_size);
let cpy_size: usize = b_size.min(res_size);
for j in 0..sum_size {
BE::reim_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
}
for j in sum_size..cpy_size {
BE::reim_negate(res.at_mut(res_col, j), b.at(b_col, j));
}
for j in cpy_size..res_size {
BE::reim_zero(res.at_mut(res_col, j));
}
} else {
let sum_size: usize = b_size.min(res_size);
let cpy_size: usize = a_size.min(res_size);
for j in 0..sum_size {
BE::reim_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
}
for j in sum_size..cpy_size {
BE::reim_copy(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in cpy_size..res_size {
BE::reim_zero(res.at_mut(res_col, j));
}
}
}
pub fn vec_znx_dft_sub_ab_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimSubABInplace,
R: VecZnxDftToMut<BE>,
A: VecZnxDftToRef<BE>,
{
let a: VecZnxDft<&[u8], BE> = a.to_ref();
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
BE::reim_sub_ab_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
}
pub fn vec_znx_dft_sub_ba_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimSubBAInplace + ReimNegateInplace,
R: VecZnxDftToMut<BE>,
A: VecZnxDftToRef<BE>,
{
let a: VecZnxDft<&[u8], BE> = a.to_ref();
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
BE::reim_sub_ba_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in sum_size..res_size {
BE::reim_negate_inplace(res.at_mut(res_col, j));
}
}
pub fn vec_znx_dft_zero<R, BE>(res: &mut R)
where
R: VecZnxDftToMut<BE>,
BE: Backend<ScalarPrep = f64> + ReimZero,
{
BE::reim_zero(res.to_mut().raw_mut());
}

View File

@@ -0,0 +1,365 @@
use crate::{
cast_mut,
layouts::{MatZnx, MatZnxToRef, VecZnx, VecZnxToRef, VmpPMatToMut, ZnxView, ZnxViewMut},
oep::VecZnxDftAllocBytesImpl,
reference::fft64::{
reim::{ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimZero},
reim4::{Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks},
vec_znx_dft::vec_znx_dft_apply,
},
};
use crate::layouts::{Backend, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatToRef, ZnxInfos};
pub fn vmp_prepare_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vmp_prepare<R, A, BE>(table: &ReimFFTTable<f64>, pmat: &mut R, mat: &A, tmp: &mut [f64])
where
BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + Reim4Extract1Blk,
R: VmpPMatToMut<BE>,
A: MatZnxToRef,
{
let mut res: crate::layouts::VmpPMat<&mut [u8], BE> = pmat.to_mut();
let a: MatZnx<&[u8]> = mat.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
assert_eq!(
res.cols_in(),
a.cols_in(),
"res.cols_in: {} != a.cols_in: {}",
res.cols_in(),
a.cols_in()
);
assert_eq!(
res.rows(),
a.rows(),
"res.rows: {} != a.rows: {}",
res.rows(),
a.rows()
);
assert_eq!(
res.cols_out(),
a.cols_out(),
"res.cols_out: {} != a.cols_out: {}",
res.cols_out(),
a.cols_out()
);
assert_eq!(
res.size(),
a.size(),
"res.size: {} != a.size: {}",
res.size(),
a.size()
);
}
let nrows: usize = a.cols_in() * a.rows();
let ncols: usize = a.cols_out() * a.size();
vmp_prepare_core::<BE>(table, res.raw_mut(), a.raw(), nrows, ncols, tmp);
}
pub(crate) fn vmp_prepare_core<REIM>(
table: &ReimFFTTable<f64>,
pmat: &mut [f64],
mat: &[i64],
nrows: usize,
ncols: usize,
tmp: &mut [f64],
) where
REIM: ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + Reim4Extract1Blk,
{
let m: usize = table.m();
let n: usize = m << 1;
#[cfg(debug_assertions)]
{
assert!(n >= 8);
assert_eq!(mat.len(), n * nrows * ncols);
assert_eq!(pmat.len(), n * nrows * ncols);
assert_eq!(tmp.len(), vmp_prepare_tmp_bytes(n) / size_of::<i64>())
}
let offset: usize = nrows * ncols * 8;
for row_i in 0..nrows {
for col_i in 0..ncols {
let pos: usize = n * (row_i * ncols + col_i);
REIM::reim_from_znx(tmp, &mat[pos..pos + n]);
REIM::reim_dft_execute(table, tmp);
let dst: &mut [f64] = if col_i == (ncols - 1) && !ncols.is_multiple_of(2) {
&mut pmat[col_i * nrows * 8 + row_i * 8..]
} else {
&mut pmat[(col_i / 2) * (nrows * 16) + row_i * 16 + (col_i % 2) * 8..]
};
for blk_i in 0..m >> 2 {
REIM::reim4_extract_1blk(m, 1, blk_i, &mut dst[blk_i * offset..], tmp);
}
}
}
}
pub fn vmp_apply_dft_tmp_bytes(n: usize, a_size: usize, prows: usize, pcols_in: usize) -> usize {
let row_max: usize = (a_size).min(prows);
(16 + (n + 8) * row_max * pcols_in) * size_of::<f64>()
}
pub fn vmp_apply_dft<R, A, M, BE>(table: &ReimFFTTable<f64>, res: &mut R, a: &A, pmat: &M, tmp_bytes: &mut [f64])
where
BE: Backend<ScalarPrep = f64>
+ VecZnxDftAllocBytesImpl<BE>
+ ReimDFTExecute<ReimFFTTable<f64>, f64>
+ ReimZero
+ Reim4Extract1Blk
+ Reim4Mat1ColProd
+ Reim4Mat2Cols2ndColProd
+ Reim4Mat2ColsProd
+ Reim4Save2Blks
+ Reim4Save1Blk
+ ReimFromZnx,
R: VecZnxDftToMut<BE>,
A: VecZnxToRef,
M: VmpPMatToRef<BE>,
{
let a: VecZnx<&[u8]> = a.to_ref();
let pmat: VmpPMat<&[u8], BE> = pmat.to_ref();
let n: usize = a.n();
let cols: usize = pmat.cols_in();
let size: usize = a.size().min(pmat.rows());
#[cfg(debug_assertions)]
{
assert!(tmp_bytes.len() >= vmp_apply_dft_tmp_bytes(n, size, pmat.rows(), cols));
assert!(a.cols() <= cols);
}
let (data, tmp_bytes) = tmp_bytes.split_at_mut(BE::vec_znx_dft_alloc_bytes_impl(n, cols, size));
let mut a_dft: VecZnxDft<&mut [u8], BE> = VecZnxDft::from_data(cast_mut(data), n, cols, size);
let offset: usize = cols - a.cols();
for j in 0..cols {
vec_znx_dft_apply(table, 1, 0, &mut a_dft, j, &a, offset + j);
}
vmp_apply_dft_to_dft(res, &a_dft, &pmat, tmp_bytes);
}
pub fn vmp_apply_dft_to_dft_tmp_bytes(a_size: usize, prows: usize, pcols_in: usize) -> usize {
let row_max: usize = (a_size).min(prows);
(16 + 8 * row_max * pcols_in) * size_of::<f64>()
}
pub fn vmp_apply_dft_to_dft<R, A, M, BE>(res: &mut R, a: &A, pmat: &M, tmp_bytes: &mut [f64])
where
BE: Backend<ScalarPrep = f64>
+ ReimZero
+ Reim4Extract1Blk
+ Reim4Mat1ColProd
+ Reim4Mat2Cols2ndColProd
+ Reim4Mat2ColsProd
+ Reim4Save2Blks
+ Reim4Save1Blk,
R: VecZnxDftToMut<BE>,
A: VecZnxDftToRef<BE>,
M: VmpPMatToRef<BE>,
{
use crate::layouts::{ZnxView, ZnxViewMut};
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
let a: VecZnxDft<&[u8], BE> = a.to_ref();
let pmat: VmpPMat<&[u8], BE> = pmat.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), pmat.n());
assert_eq!(a.n(), pmat.n());
assert_eq!(res.cols(), pmat.cols_out());
assert_eq!(a.cols(), pmat.cols_in());
}
let n: usize = res.n();
let nrows: usize = pmat.cols_in() * pmat.rows();
let ncols: usize = pmat.cols_out() * pmat.size();
let pmat_raw: &[f64] = pmat.raw();
let a_raw: &[f64] = a.raw();
let res_raw: &mut [f64] = res.raw_mut();
vmp_apply_dft_to_dft_core::<true, BE>(n, res_raw, a_raw, pmat_raw, 0, nrows, ncols, tmp_bytes)
}
pub fn vmp_apply_dft_to_dft_add<R, A, M, BE>(res: &mut R, a: &A, pmat: &M, limb_offset: usize, tmp_bytes: &mut [f64])
where
BE: Backend<ScalarPrep = f64>
+ ReimZero
+ Reim4Extract1Blk
+ Reim4Mat1ColProd
+ Reim4Mat2Cols2ndColProd
+ Reim4Mat2ColsProd
+ Reim4Save2Blks
+ Reim4Save1Blk,
R: VecZnxDftToMut<BE>,
A: VecZnxDftToRef<BE>,
M: VmpPMatToRef<BE>,
{
use crate::layouts::{ZnxView, ZnxViewMut};
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
let a: VecZnxDft<&[u8], BE> = a.to_ref();
let pmat: VmpPMat<&[u8], BE> = pmat.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), pmat.n());
assert_eq!(a.n(), pmat.n());
assert_eq!(res.cols(), pmat.cols_out());
assert_eq!(a.cols(), pmat.cols_in());
}
let n: usize = res.n();
let nrows: usize = pmat.cols_in() * pmat.rows();
let ncols: usize = pmat.cols_out() * pmat.size();
let pmat_raw: &[f64] = pmat.raw();
let a_raw: &[f64] = a.raw();
let res_raw: &mut [f64] = res.raw_mut();
vmp_apply_dft_to_dft_core::<false, BE>(
n,
res_raw,
a_raw,
pmat_raw,
limb_offset,
nrows,
ncols,
tmp_bytes,
)
}
#[allow(clippy::too_many_arguments)]
fn vmp_apply_dft_to_dft_core<const OVERWRITE: bool, REIM>(
n: usize,
res: &mut [f64],
a: &[f64],
pmat: &[f64],
limb_offset: usize,
nrows: usize,
ncols: usize,
tmp_bytes: &mut [f64],
) where
REIM: ReimZero
+ Reim4Extract1Blk
+ Reim4Mat1ColProd
+ Reim4Mat2Cols2ndColProd
+ Reim4Mat2ColsProd
+ Reim4Save2Blks
+ Reim4Save1Blk,
{
#[cfg(debug_assertions)]
{
assert!(n >= 8);
assert!(n.is_power_of_two());
assert_eq!(pmat.len(), n * nrows * ncols);
assert!(res.len() & (n - 1) == 0);
assert!(a.len() & (n - 1) == 0);
}
let a_size: usize = a.len() / n;
let res_size: usize = res.len() / n;
let m: usize = n >> 1;
let (mat2cols_output, extracted_blk) = tmp_bytes.split_at_mut(16);
let row_max: usize = nrows.min(a_size);
let col_max: usize = ncols.min(res_size);
if limb_offset >= col_max {
if OVERWRITE {
REIM::reim_zero(res);
}
return;
}
for blk_i in 0..(m >> 2) {
let mat_blk_start: &[f64] = &pmat[blk_i * (8 * nrows * ncols)..];
REIM::reim4_extract_1blk(m, row_max, blk_i, extracted_blk, a);
if limb_offset.is_multiple_of(2) {
for (col_res, col_pmat) in (0..).step_by(2).zip((limb_offset..col_max - 1).step_by(2)) {
let col_offset: usize = col_pmat * (8 * nrows);
REIM::reim4_mat2cols_prod(
row_max,
mat2cols_output,
extracted_blk,
&mat_blk_start[col_offset..],
);
REIM::reim4_save_2blks::<OVERWRITE>(m, blk_i, &mut res[col_res * n..], mat2cols_output);
}
} else {
let col_offset: usize = (limb_offset - 1) * (8 * nrows);
REIM::reim4_mat2cols_2ndcol_prod(
row_max,
mat2cols_output,
extracted_blk,
&mat_blk_start[col_offset..],
);
REIM::reim4_save_1blk::<OVERWRITE>(m, blk_i, res, mat2cols_output);
for (col_res, col_pmat) in (1..)
.step_by(2)
.zip((limb_offset + 1..col_max - 1).step_by(2))
{
let col_offset: usize = col_pmat * (8 * nrows);
REIM::reim4_mat2cols_prod(
row_max,
mat2cols_output,
extracted_blk,
&mat_blk_start[col_offset..],
);
REIM::reim4_save_2blks::<OVERWRITE>(m, blk_i, &mut res[col_res * n..], mat2cols_output);
}
}
if !col_max.is_multiple_of(2) {
let last_col: usize = col_max - 1;
let col_offset: usize = last_col * (8 * nrows);
if last_col >= limb_offset {
if ncols == col_max {
REIM::reim4_mat1col_prod(
row_max,
mat2cols_output,
extracted_blk,
&mat_blk_start[col_offset..],
);
} else {
REIM::reim4_mat2cols_prod(
row_max,
mat2cols_output,
extracted_blk,
&mat_blk_start[col_offset..],
);
}
REIM::reim4_save_1blk::<OVERWRITE>(
m,
blk_i,
&mut res[(last_col - limb_offset) * n..],
mat2cols_output,
);
}
}
}
REIM::reim_zero(&mut res[col_max * n..]);
}

View File

@@ -0,0 +1,4 @@
pub mod fft64;
pub mod vec_znx;
pub mod zn;
pub mod znx;

View File

@@ -0,0 +1,177 @@
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use crate::{
api::{ModuleNew, VecZnxAdd, VecZnxAddInplace},
layouts::{Backend, FillUniform, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxAdd, ZnxAddInplace, ZnxCopy, ZnxZero},
source::Source,
};
pub fn vec_znx_add<R, A, B, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
B: VecZnxToRef,
ZNXARI: ZnxAdd + ZnxCopy + ZnxZero,
{
let a: VecZnx<&[u8]> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let b_size: usize = b.size();
if a_size <= b_size {
let sum_size: usize = a_size.min(res_size);
let cpy_size: usize = b_size.min(res_size);
for j in 0..sum_size {
ZNXARI::znx_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
}
for j in sum_size..cpy_size {
ZNXARI::znx_copy(res.at_mut(res_col, j), b.at(b_col, j));
}
for j in cpy_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
} else {
let sum_size: usize = b_size.min(res_size);
let cpy_size: usize = a_size.min(res_size);
for j in 0..sum_size {
ZNXARI::znx_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
}
for j in sum_size..cpy_size {
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in cpy_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
}
pub fn vec_znx_add_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxAddInplace,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
ZNXARI::znx_add_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
}
pub fn bench_vec_znx_add<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxAdd + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_add::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxAdd + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut c: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
a.fill_uniform(50, &mut source);
b.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_add(&mut c, i, &a, i, &b, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_add_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxAddInplace + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_add_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxAddInplace + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
a.fill_uniform(50, &mut source);
b.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_add_inplace(&mut b, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}

View File

@@ -0,0 +1,57 @@
use crate::{
layouts::{ScalarZnx, ScalarZnxToRef, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxAdd, ZnxAddInplace, ZnxCopy, ZnxZero},
};
pub fn vec_znx_add_scalar<R, A, B, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize, b_limb: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
B: VecZnxToRef,
ZNXARI: ZnxAdd + ZnxCopy + ZnxZero,
{
let a: ScalarZnx<&[u8]> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let min_size: usize = b.size().min(res.size());
#[cfg(debug_assertions)]
{
assert!(
b_limb < min_size,
"b_limb: {} > min_size: {}",
b_limb,
min_size
);
}
for j in 0..min_size {
if j == b_limb {
ZNXARI::znx_add(res.at_mut(res_col, j), a.at(a_col, 0), b.at(b_col, j));
} else {
ZNXARI::znx_copy(res.at_mut(res_col, j), b.at(b_col, j));
}
}
for j in min_size..res.size() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
pub fn vec_znx_add_scalar_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
ZNXARI: ZnxAddInplace,
{
let a: ScalarZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert!(res_limb < res.size());
}
ZNXARI::znx_add_inplace(res.at_mut(res_col, res_limb), a.at(a_col, 0));
}

View File

@@ -0,0 +1,150 @@
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use crate::{
api::{
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAutomorphism, VecZnxAutomorphismInplace,
VecZnxAutomorphismInplaceTmpBytes,
},
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxAutomorphism, ZnxCopy, ZnxZero},
source::Source,
};
pub fn vec_znx_automorphism_inplace_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_automorphism<R, A, ZNXARI>(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxAutomorphism + ZnxZero,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
use crate::layouts::ZnxInfos;
assert_eq!(a.n(), res.n());
}
let min_size: usize = res.size().min(a.size());
for j in 0..min_size {
ZNXARI::znx_automorphism(p, res.at_mut(res_col, j), a.at(a_col, j));
}
for j in min_size..res.size() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
pub fn vec_znx_automorphism_inplace<R, ZNXARI>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxAutomorphism + ZnxCopy,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), tmp.len());
}
for j in 0..res.size() {
ZNXARI::znx_automorphism(p, tmp, res.at(res_col, j));
ZNXARI::znx_copy(res.at_mut(res_col, j), tmp);
}
}
pub fn bench_vec_znx_automorphism<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxAutomorphism + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_automorphism::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxAutomorphism + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
a.fill_uniform(50, &mut source);
res.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_automorphism(-7, &mut res, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_automorphism_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxAutomorphismInplace<B> + VecZnxAutomorphismInplaceTmpBytes + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_automorphism_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxAutomorphismInplace<B> + ModuleNew<B> + VecZnxAutomorphismInplaceTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut scratch = ScratchOwned::alloc(module.vec_znx_automorphism_inplace_tmp_bytes());
// Fill a with random i64
res.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_automorphism_inplace(-7, &mut res, i, scratch.borrow());
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}

View File

@@ -0,0 +1,32 @@
use crate::{
layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxCopy, ZnxZero},
};
pub fn vec_znx_copy<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxCopy + ZnxZero,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: VecZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), a.n())
}
let res_size = res.size();
let a_size = a.size();
let min_size = res_size.min(a_size);
for j in 0..min_size {
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in min_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}

View File

@@ -0,0 +1,49 @@
use crate::{
layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos},
reference::{
vec_znx::{vec_znx_rotate_inplace, vec_znx_switch_ring},
znx::{ZnxCopy, ZnxRotate, ZnxSwitchRing, ZnxZero},
},
};
pub fn vec_znx_merge_rings_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_merge_rings<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &[A], a_col: usize, tmp: &mut [i64])
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxCopy + ZnxSwitchRing + ZnxRotate + ZnxZero,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let (n_out, n_in) = (res.n(), a[0].to_ref().n());
#[cfg(debug_assertions)]
{
assert_eq!(tmp.len(), res.n());
debug_assert!(
n_out > n_in,
"invalid a: output ring degree should be greater"
);
a[1..].iter().for_each(|ai| {
debug_assert_eq!(
ai.to_ref().n(),
n_in,
"invalid input a: all VecZnx must have the same degree"
)
});
assert!(n_out.is_multiple_of(n_in));
assert_eq!(a.len(), n_out / n_in);
}
a.iter().for_each(|ai| {
vec_znx_switch_ring::<_, _, ZNXARI>(&mut res, res_col, ai, a_col);
vec_znx_rotate_inplace::<_, ZNXARI>(-1, &mut res, res_col, tmp);
});
vec_znx_rotate_inplace::<_, ZNXARI>(a.len() as i64, &mut res, res_col, tmp);
}

View File

@@ -0,0 +1,31 @@
mod add;
mod add_scalar;
mod automorphism;
mod copy;
mod merge_rings;
mod mul_xp_minus_one;
mod negate;
mod normalize;
mod rotate;
mod sampling;
mod shift;
mod split_ring;
mod sub;
mod sub_scalar;
mod switch_ring;
pub use add::*;
pub use add_scalar::*;
pub use automorphism::*;
pub use copy::*;
pub use merge_rings::*;
pub use mul_xp_minus_one::*;
pub use negate::*;
pub use normalize::*;
pub use rotate::*;
pub use sampling::*;
pub use shift::*;
pub use split_ring::*;
pub use sub::*;
pub use sub_scalar::*;
pub use switch_ring::*;

View File

@@ -0,0 +1,136 @@
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use crate::{
api::{
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace,
VecZnxMulXpMinusOneInplaceTmpBytes,
},
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::{
vec_znx::{vec_znx_rotate, vec_znx_sub_ab_inplace},
znx::{ZnxNegate, ZnxRotate, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero},
},
source::Source,
};
pub fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_mul_xp_minus_one<R, A, ZNXARI>(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxRotate + ZnxZero + ZnxSubABInplace,
{
vec_znx_rotate::<_, _, ZNXARI>(p, res, res_col, a, a_col);
vec_znx_sub_ab_inplace::<_, _, ZNXARI>(res, res_col, a, a_col);
}
pub fn vec_znx_mul_xp_minus_one_inplace<R, ZNXARI>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxRotate + ZnxNegate + ZnxSubBAInplace,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), tmp.len());
}
for j in 0..res.size() {
ZNXARI::znx_rotate(p, tmp, res.at(res_col, j));
ZNXARI::znx_sub_ba_inplace(res.at_mut(res_col, j), tmp);
}
}
pub fn bench_vec_znx_mul_xp_minus_one<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxMulXpMinusOne + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_mul_xp_minus_one::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxMulXpMinusOne + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
a.fill_uniform(50, &mut source);
res.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_mul_xp_minus_one(-7, &mut res, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_mul_xp_minus_one_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxMulXpMinusOneInplace<B> + VecZnxMulXpMinusOneInplaceTmpBytes + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_mul_xp_minus_one_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxMulXpMinusOneInplace<B> + ModuleNew<B> + VecZnxMulXpMinusOneInplaceTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut scratch = ScratchOwned::alloc(module.vec_znx_mul_xp_minus_one_inplace_tmp_bytes());
// Fill a with random i64
res.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_mul_xp_minus_one_inplace(-7, &mut res, i, scratch.borrow());
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}

View File

@@ -0,0 +1,131 @@
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use crate::{
api::{ModuleNew, VecZnxNegate, VecZnxNegateInplace},
layouts::{Backend, FillUniform, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxNegate, ZnxNegateInplace, ZnxZero},
source::Source,
};
pub fn vec_znx_negate<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxNegate + ZnxZero,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
let min_size: usize = res.size().min(a.size());
for j in 0..min_size {
ZNXARI::znx_negate(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in min_size..res.size() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
pub fn vec_znx_negate_inplace<R, ZNXARI>(res: &mut R, res_col: usize)
where
R: VecZnxToMut,
ZNXARI: ZnxNegateInplace,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
for j in 0..res.size() {
ZNXARI::znx_negate_inplace(res.at_mut(res_col, j));
}
}
pub fn bench_vec_znx_negate<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxNegate + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_negate::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxNegate + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
a.fill_uniform(50, &mut source);
b.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_negate(&mut b, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_negate_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxNegateInplace + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_negate_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxNegateInplace + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
a.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_negate_inplace(&mut a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}

View File

@@ -0,0 +1,193 @@
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use crate::{
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes},
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{
ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
ZnxZero,
},
source::Source,
};
pub fn vec_znx_normalize_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_normalize<R, A, ZNXARI>(basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxZero
+ ZnxNormalizeFirstStepCarryOnly
+ ZnxNormalizeMiddleStepCarryOnly
+ ZnxNormalizeMiddleStep
+ ZnxNormalizeFinalStep
+ ZnxNormalizeFirstStep,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: VecZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert!(carry.len() >= res.n());
}
let res_size: usize = res.size();
let a_size = a.size();
if a_size > res_size {
for j in (res_size..a_size).rev() {
if j == a_size - 1 {
ZNXARI::znx_normalize_first_step_carry_only(basek, 0, a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(basek, 0, a.at(a_col, j), carry);
}
}
for j in (1..res_size).rev() {
ZNXARI::znx_normalize_middle_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
}
ZNXARI::znx_normalize_final_step(basek, 0, res.at_mut(res_col, 0), a.at(a_col, 0), carry);
} else {
for j in (0..a_size).rev() {
if j == a_size - 1 {
ZNXARI::znx_normalize_first_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
}
}
for j in a_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
}
pub fn vec_znx_normalize_inplace<R: VecZnxToMut, ZNXARI>(basek: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
ZNXARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeMiddleStepInplace + ZnxNormalizeFinalStepInplace,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert!(carry.len() >= res.n());
}
let res_size: usize = res.size();
for j in (0..res_size).rev() {
if j == res_size - 1 {
ZNXARI::znx_normalize_first_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
}
}
}
pub fn bench_vec_znx_normalize<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxNormalize<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_normalize::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxNormalize<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
a.fill_uniform(50, &mut source);
res.fill_uniform(50, &mut source);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes());
move || {
for i in 0..cols {
module.vec_znx_normalize(basek, &mut res, i, &a, i, scratch.borrow());
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_normalize_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxNormalizeInplace<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_normalize_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxNormalizeInplace<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
a.fill_uniform(50, &mut source);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes());
move || {
for i in 0..cols {
module.vec_znx_normalize_inplace(basek, &mut a, i, scratch.borrow());
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}

View File

@@ -0,0 +1,148 @@
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use crate::{
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes},
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxCopy, ZnxRotate, ZnxZero},
source::Source,
};
pub fn vec_znx_rotate_inplace_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_rotate<R, A, ZNXARI>(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxRotate + ZnxZero,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: VecZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), a.n())
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let min_size: usize = res_size.min(a_size);
for j in 0..min_size {
ZNXARI::znx_rotate(p, res.at_mut(res_col, j), a.at(a_col, j))
}
for j in min_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
pub fn vec_znx_rotate_inplace<R, ZNXARI>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxRotate + ZnxCopy,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), tmp.len());
}
for j in 0..res.size() {
ZNXARI::znx_rotate(p, tmp, res.at(res_col, j));
ZNXARI::znx_copy(res.at_mut(res_col, j), tmp);
}
}
pub fn bench_vec_znx_rotate<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxRotate + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_rotate::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxRotate + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
a.fill_uniform(50, &mut source);
res.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_rotate(-7, &mut res, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_rotate_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxRotateInplace<B> + VecZnxRotateInplaceTmpBytes + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_rotate_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxRotateInplace<B> + ModuleNew<B> + VecZnxRotateInplaceTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut scratch = ScratchOwned::alloc(module.vec_znx_rotate_inplace_tmp_bytes());
// Fill a with random i64
res.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_rotate_inplace(-7, &mut res, i, scratch.borrow());
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}

View File

@@ -0,0 +1,64 @@
use crate::{
layouts::{VecZnx, VecZnxToMut, ZnxInfos, ZnxViewMut},
reference::znx::{znx_add_normal_f64_ref, znx_fill_normal_f64_ref, znx_fill_uniform_ref},
source::Source,
};
pub fn vec_znx_fill_uniform_ref<R>(basek: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
for j in 0..res.size() {
znx_fill_uniform_ref(basek, res.at_mut(res_col, j), source)
}
}
pub fn vec_znx_fill_normal_ref<R>(
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
sigma: f64,
bound: f64,
source: &mut Source,
) where
R: VecZnxToMut,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
znx_fill_normal_f64_ref(
res.at_mut(res_col, limb),
sigma * scale,
bound * scale,
source,
)
}
pub fn vec_znx_add_normal_ref<R>(basek: usize, res: &mut R, res_col: usize, k: usize, sigma: f64, bound: f64, source: &mut Source)
where
R: VecZnxToMut,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
znx_add_normal_f64_ref(
res.at_mut(res_col, limb),
sigma * scale,
bound * scale,
source,
)
}

View File

@@ -0,0 +1,672 @@
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use crate::{
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxLsh, VecZnxLshInplace, VecZnxRsh, VecZnxRshInplace},
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::{
vec_znx::vec_znx_copy,
znx::{
ZnxCopy, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
ZnxZero,
},
},
source::Source,
};
pub fn vec_znx_lsh_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_lsh_inplace<R, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxZero
+ ZnxCopy
+ ZnxNormalizeFirstStepInplace
+ ZnxNormalizeMiddleStepInplace
+ ZnxNormalizeFirstStepInplace
+ ZnxNormalizeFinalStepInplace,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let n: usize = res.n();
let cols: usize = res.cols();
let size: usize = res.size();
let steps: usize = k / basek;
let k_rem: usize = k % basek;
if steps >= size {
for j in 0..size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
return;
}
// Inplace shift of limbs by a k/basek
if steps > 0 {
let start: usize = n * res_col;
let end: usize = start + n;
let slice_size: usize = n * cols;
let res_raw: &mut [i64] = res.raw_mut();
(0..size - steps).for_each(|j| {
let (lhs, rhs) = res_raw.split_at_mut(slice_size * (j + steps));
ZNXARI::znx_copy(
&mut lhs[start + j * slice_size..end + j * slice_size],
&rhs[start..end],
);
});
for j in size - steps..size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
// Inplace normalization with left shift of k % basek
if !k.is_multiple_of(basek) {
for j in (0..size - steps).rev() {
if j == size - steps - 1 {
ZNXARI::znx_normalize_first_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
}
}
}
}
pub fn vec_znx_lsh<R, A, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxZero + ZnxNormalizeFirstStep + ZnxNormalizeMiddleStep + ZnxNormalizeFirstStep + ZnxCopy + ZnxNormalizeFinalStep,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: VecZnx<&[u8]> = a.to_ref();
let res_size: usize = res.size();
let a_size = a.size();
let steps: usize = k / basek;
let k_rem: usize = k % basek;
if steps >= res_size.min(a_size) {
for j in 0..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
return;
}
let min_size: usize = a_size.min(res_size) - steps;
// Simply a left shifted normalization of limbs
// by k/basek and intra-limb by basek - k%basek
if !k.is_multiple_of(basek) {
for j in (0..min_size).rev() {
if j == min_size - 1 {
ZNXARI::znx_normalize_first_step(
basek,
k_rem,
res.at_mut(res_col, j),
a.at(a_col, j + steps),
carry,
);
} else if j == 0 {
ZNXARI::znx_normalize_final_step(
basek,
k_rem,
res.at_mut(res_col, j),
a.at(a_col, j + steps),
carry,
);
} else {
ZNXARI::znx_normalize_middle_step(
basek,
k_rem,
res.at_mut(res_col, j),
a.at(a_col, j + steps),
carry,
);
}
}
} else {
// If k % basek = 0, then this is simply a copy.
for j in (0..min_size).rev() {
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j + steps));
}
}
// Zeroes bottom
for j in min_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
pub fn vec_znx_rsh_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_rsh_inplace<R, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxZero
+ ZnxCopy
+ ZnxNormalizeFirstStepCarryOnly
+ ZnxNormalizeMiddleStepCarryOnly
+ ZnxNormalizeMiddleStep
+ ZnxNormalizeMiddleStepInplace
+ ZnxNormalizeFirstStepInplace
+ ZnxNormalizeFinalStepInplace,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let n: usize = res.n();
let cols: usize = res.cols();
let size: usize = res.size();
let mut steps: usize = k / basek;
let k_rem: usize = k % basek;
if k == 0 {
return;
}
if steps >= size {
for j in 0..size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
return;
}
let start: usize = n * res_col;
let end: usize = start + n;
let slice_size: usize = n * cols;
if !k.is_multiple_of(basek) {
// We rsh by an additional basek and then lsh by basek-k
// Allows to re-use efficient normalization code, avoids
// avoids overflows & produce output that is normalized
steps += 1;
// All limbs of a that would fall outside of the limbs of res are discarded,
// but the carry still need to be computed.
(size - steps..size).rev().for_each(|j| {
if j == size - 1 {
ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry);
}
});
// Continues with shifted normalization
let res_raw: &mut [i64] = res.raw_mut();
(steps..size).rev().for_each(|j| {
let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
let rhs_slice: &mut [i64] = &mut rhs[start..end];
let lhs_slice: &[i64] = &lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end];
ZNXARI::znx_normalize_middle_step(basek, basek - k_rem, rhs_slice, lhs_slice, carry);
});
// Propagates carry on the rest of the limbs of res
for j in (0..steps).rev() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
}
}
} else {
// Shift by multiples of basek
let res_raw: &mut [i64] = res.raw_mut();
(steps..size).rev().for_each(|j| {
let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
ZNXARI::znx_copy(
&mut rhs[start..end],
&lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end],
);
});
// Zeroes the top
(0..steps).for_each(|j| {
ZNXARI::znx_zero(res.at_mut(res_col, j));
});
}
}
pub fn vec_znx_rsh<R, A, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxZero
+ ZnxCopy
+ ZnxNormalizeFirstStepCarryOnly
+ ZnxNormalizeMiddleStepCarryOnly
+ ZnxNormalizeFirstStep
+ ZnxNormalizeMiddleStep
+ ZnxNormalizeMiddleStepInplace
+ ZnxNormalizeFirstStepInplace
+ ZnxNormalizeFinalStepInplace,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: VecZnx<&[u8]> = a.to_ref();
let res_size: usize = res.size();
let a_size: usize = a.size();
let mut steps: usize = k / basek;
let k_rem: usize = k % basek;
if k == 0 {
vec_znx_copy::<_, _, ZNXARI>(&mut res, res_col, &a, a_col);
return;
}
if steps >= res_size {
for j in 0..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
return;
}
if !k.is_multiple_of(basek) {
// We rsh by an additional basek and then lsh by basek-k
// Allows to re-use efficient normalization code, avoids
// avoids overflows & produce output that is normalized
steps += 1;
// All limbs of a that are moved outside of the limbs of res are discarded,
// but the carry still need to be computed.
for j in (res_size..a_size + steps).rev() {
if j == a_size + steps - 1 {
ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry);
}
}
// Avoids over flow of limbs of res
let min_size: usize = res_size.min(a_size + steps);
// Zeroes lower limbs of res if a_size + steps < res_size
(min_size..res_size).for_each(|j| {
ZNXARI::znx_zero(res.at_mut(res_col, j));
});
// Continues with shifted normalization
for j in (steps..min_size).rev() {
// Case if no limb of a was previously discarded
if res_size.saturating_sub(steps) >= a_size && j == min_size - 1 {
ZNXARI::znx_normalize_first_step(
basek,
basek - k_rem,
res.at_mut(res_col, j),
a.at(a_col, j - steps),
carry,
);
} else {
ZNXARI::znx_normalize_middle_step(
basek,
basek - k_rem,
res.at_mut(res_col, j),
a.at(a_col, j - steps),
carry,
);
}
}
// Propagates carry on the rest of the limbs of res
for j in (0..steps).rev() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
}
}
} else {
let min_size: usize = res_size.min(a_size + steps);
// Zeroes the top
(0..steps).for_each(|j| {
ZNXARI::znx_zero(res.at_mut(res_col, j));
});
// Shift a into res, up to the maximum
for j in (steps..min_size).rev() {
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j - steps));
}
// Zeroes bottom if a_size + steps < res_size
(min_size..res_size).for_each(|j| {
ZNXARI::znx_zero(res.at_mut(res_col, j));
});
}
}
pub fn bench_vec_znx_lsh_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: ModuleNew<B> + VecZnxLshInplace<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_lsh_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxLshInplace<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());
// Fill a with random i64
a.fill_uniform(50, &mut source);
b.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_lsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow());
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_lsh<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxLsh<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_lsh::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxLsh<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());
// Fill a with random i64
a.fill_uniform(50, &mut source);
res.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_lsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow());
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_rsh_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxRshInplace<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_rsh_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxRshInplace<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());
// Fill a with random i64
a.fill_uniform(50, &mut source);
b.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_rsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow());
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_rsh<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxRsh<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_rsh::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxRsh<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());
// Fill a with random i64
a.fill_uniform(50, &mut source);
res.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_rsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow());
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
#[cfg(test)]
mod tests {
use crate::{
layouts::{FillUniform, VecZnx, ZnxView},
reference::{
vec_znx::{
vec_znx_copy, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_normalize_inplace, vec_znx_rsh, vec_znx_rsh_inplace,
vec_znx_sub_ab_inplace,
},
znx::ZnxRef,
},
source::Source,
};
#[test]
fn test_vec_znx_lsh() {
let n: usize = 8;
let cols: usize = 2;
let size: usize = 7;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut source: Source = Source::new([0u8; 32]);
let mut carry: Vec<i64> = vec![0i64; n];
let basek: usize = 50;
for k in 0..256 {
a.fill_uniform(50, &mut source);
for i in 0..cols {
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry);
vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i);
}
for i in 0..cols {
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, i, &mut carry);
vec_znx_lsh::<_, _, ZnxRef>(basek, k, &mut res_test, i, &a, i, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, i, &mut carry);
}
assert_eq!(res_ref, res_test);
}
}
#[test]
fn test_vec_znx_rsh() {
let n: usize = 8;
let cols: usize = 2;
let res_size: usize = 7;
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut carry: Vec<i64> = vec![0i64; n];
let basek: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
let zero: Vec<i64> = vec![0i64; n];
for a_size in [res_size - 1, res_size, res_size + 1] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
for k in 0..res_size * basek {
a.fill_uniform(50, &mut source);
for i in 0..cols {
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry);
vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i);
}
res_test.fill_uniform(50, &mut source);
for j in 0..cols {
vec_znx_rsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry);
vec_znx_rsh::<_, _, ZnxRef>(basek, k, &mut res_test, j, &a, j, &mut carry);
}
for j in 0..cols {
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry);
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_test, j, &mut carry);
}
// Case where res has enough to fully store a right shifted without any loss
// In this case we can check exact equality.
if a_size + k.div_ceil(basek) <= res_size {
assert_eq!(res_ref, res_test);
for i in 0..cols {
for j in 0..a_size {
assert_eq!(res_ref.at(i, j), a.at(i, j), "r0 {} {}", i, j);
assert_eq!(res_test.at(i, j), a.at(i, j), "r1 {} {}", i, j);
}
for j in a_size..res_size {
assert_eq!(res_ref.at(i, j), zero, "r0 {} {}", i, j);
assert_eq!(res_test.at(i, j), zero, "r1 {} {}", i, j);
}
}
// Some loss occures, either because a initially has more precision than res
// or because the storage of the right shift of a requires more precision than
// res.
} else {
for j in 0..cols {
vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_ref, j, &a, j);
vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_test, j, &a, j);
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_ref, j, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, j, &mut carry);
assert!(res_ref.std(basek, j).log2() - (k as f64) <= (k * basek) as f64);
assert!(res_test.std(basek, j).log2() - (k as f64) <= (k * basek) as f64);
}
}
}
}
}
}

View File

@@ -0,0 +1,62 @@
use crate::{
layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxRotate, ZnxSwitchRing, ZnxZero},
};
pub fn vec_znx_split_ring_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_split_ring<R, A, ZNXARI>(res: &mut [R], res_col: usize, a: &A, a_col: usize, tmp: &mut [i64])
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxSwitchRing + ZnxRotate + ZnxZero,
{
let a: VecZnx<&[u8]> = a.to_ref();
let a_size = a.size();
let (n_in, n_out) = (a.n(), res[0].to_mut().n());
#[cfg(debug_assertions)]
{
assert_eq!(tmp.len(), a.n());
assert!(
n_out < n_in,
"invalid a: output ring degree should be smaller"
);
res[1..].iter_mut().for_each(|bi| {
assert_eq!(
bi.to_mut().n(),
n_out,
"invalid input a: all VecZnx must have the same degree"
)
});
assert!(n_in.is_multiple_of(n_out));
assert_eq!(res.len(), n_in / n_out);
}
res.iter_mut().enumerate().for_each(|(i, bi)| {
let mut bi: VecZnx<&mut [u8]> = bi.to_mut();
let min_size = bi.size().min(a_size);
if i == 0 {
for j in 0..min_size {
ZNXARI::znx_switch_ring(bi.at_mut(res_col, j), a.at(a_col, j));
}
} else {
for j in 0..min_size {
ZNXARI::znx_rotate(-(i as i64), tmp, a.at(a_col, j));
ZNXARI::znx_switch_ring(bi.at_mut(res_col, j), tmp);
}
}
for j in min_size..bi.size() {
ZNXARI::znx_zero(bi.at_mut(res_col, j));
}
})
}

View File

@@ -0,0 +1,250 @@
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use crate::{
api::{ModuleNew, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace},
layouts::{Backend, FillUniform, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
oep::{ModuleNewImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl},
reference::znx::{ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero},
source::Source,
};
pub fn vec_znx_sub<R, A, B, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
B: VecZnxToRef,
ZNXARI: ZnxSub + ZnxNegate + ZnxZero + ZnxCopy,
{
let a: VecZnx<&[u8]> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let b_size: usize = b.size();
if a_size <= b_size {
let sum_size: usize = a_size.min(res_size);
let cpy_size: usize = b_size.min(res_size);
for j in 0..sum_size {
ZNXARI::znx_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
}
for j in sum_size..cpy_size {
ZNXARI::znx_negate(res.at_mut(res_col, j), b.at(b_col, j));
}
for j in cpy_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
} else {
let sum_size: usize = b_size.min(res_size);
let cpy_size: usize = a_size.min(res_size);
for j in 0..sum_size {
ZNXARI::znx_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
}
for j in sum_size..cpy_size {
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in cpy_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
}
pub fn vec_znx_sub_ab_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxSubABInplace,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
ZNXARI::znx_sub_ab_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
}
pub fn vec_znx_sub_ba_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxSubBAInplace + ZnxNegateInplace,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
ZNXARI::znx_sub_ba_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in sum_size..res_size {
ZNXARI::znx_negate_inplace(res.at_mut(res_col, j));
}
}
pub fn bench_vec_znx_sub<B>(c: &mut Criterion, label: &str)
where
B: Backend + ModuleNewImpl<B> + VecZnxSubImpl<B>,
{
let group_name: String = format!("vec_znx_sub::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxSub + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut c: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
a.fill_uniform(50, &mut source);
b.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_sub(&mut c, i, &a, i, &b, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_sub_ab_inplace<B>(c: &mut Criterion, label: &str)
where
B: Backend + ModuleNewImpl<B> + VecZnxSubABInplaceImpl<B>,
{
let group_name: String = format!("vec_znx_sub_ab_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxSubABInplace + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
a.fill_uniform(50, &mut source);
b.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_sub_ab_inplace(&mut b, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_sub_ba_inplace<B>(c: &mut Criterion, label: &str)
where
B: Backend + ModuleNewImpl<B> + VecZnxSubBAInplaceImpl<B>,
{
let group_name: String = format!("vec_znx_sub_ba_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxSubBAInplace + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
a.fill_uniform(50, &mut source);
b.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_sub_ba_inplace(&mut b, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}

View File

@@ -0,0 +1,58 @@
use crate::layouts::{ScalarZnxToRef, VecZnxToMut, VecZnxToRef};
use crate::{
layouts::{ScalarZnx, VecZnx, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxSub, ZnxSubABInplace, ZnxZero},
};
pub fn vec_znx_sub_scalar<R, A, B, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize, b_limb: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
B: VecZnxToRef,
ZNXARI: ZnxSub + ZnxZero,
{
let a: ScalarZnx<&[u8]> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let min_size: usize = b.size().min(res.size());
#[cfg(debug_assertions)]
{
assert!(
b_limb < min_size,
"b_limb: {} > min_size: {}",
b_limb,
min_size
);
}
for j in 0..min_size {
if j == b_limb {
ZNXARI::znx_sub(res.at_mut(res_col, j), b.at(b_col, j), a.at(a_col, 0));
} else {
res.at_mut(res_col, j).copy_from_slice(b.at(b_col, j));
}
}
for j in min_size..res.size() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
pub fn vec_znx_sub_scalar_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
ZNXARI: ZnxSubABInplace,
{
let a: ScalarZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert!(res_limb < res.size());
}
ZNXARI::znx_sub_ab_inplace(res.at_mut(res_col, res_limb), a.at(a_col, 0));
}

View File

@@ -0,0 +1,37 @@
use crate::{
layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::{
vec_znx::vec_znx_copy,
znx::{ZnxCopy, ZnxSwitchRing, ZnxZero},
},
};
/// Maps between negacyclic rings by changing the polynomial degree.
/// Up: Z[X]/(X^N+1) -> Z[X]/(X^{2^d N}+1) via X ↦ X^{2^d}
/// Down: Z[X]/(X^N+1) -> Z[X]/(X^{N/2^d}+1) by folding indices.
pub fn vec_znx_switch_ring<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxCopy + ZnxSwitchRing + ZnxZero,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let (n_in, n_out) = (a.n(), res.n());
if n_in == n_out {
vec_znx_copy::<_, _, ZNXARI>(&mut res, res_col, &a, a_col);
return;
}
let min_size: usize = a.size().min(res.size());
for j in 0..min_size {
ZNXARI::znx_switch_ring(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in min_size..res.size() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}

View File

@@ -0,0 +1,5 @@
mod normalization;
mod sampling;
pub use normalization::*;
pub use sampling::*;

View File

@@ -0,0 +1,72 @@
use crate::{
api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace, ZnNormalizeTmpBytes},
layouts::{Backend, Module, ScratchOwned, Zn, ZnToMut, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStepInplace, ZnxRef},
source::Source,
};
pub fn zn_normalize_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn zn_normalize_inplace<R, ARI>(n: usize, basek: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
R: ZnToMut,
ARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeFinalStepInplace + ZnxNormalizeMiddleStepInplace,
{
let mut res: Zn<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(carry.len(), res.n());
}
let res_size: usize = res.size();
for j in (0..res_size).rev() {
let out = &mut res.at_mut(res_col, j)[..n];
if j == res_size - 1 {
ARI::znx_normalize_first_step_inplace(basek, 0, out, carry);
} else if j == 0 {
ARI::znx_normalize_final_step_inplace(basek, 0, out, carry);
} else {
ARI::znx_normalize_middle_step_inplace(basek, 0, out, carry);
}
}
}
pub fn test_zn_normalize_inplace<B: Backend>(module: &Module<B>)
where
Module<B>: ZnNormalizeInplace<B> + ZnNormalizeTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let basek: usize = 12;
let n = 33;
let mut carry: Vec<i64> = vec![0i64; zn_normalize_tmp_bytes(n)];
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.zn_normalize_tmp_bytes(module.n()));
for res_size in [1, 2, 6, 11] {
let mut res_0: Zn<Vec<u8>> = Zn::alloc(n, cols, res_size);
let mut res_1: Zn<Vec<u8>> = Zn::alloc(n, cols, res_size);
res_0
.raw_mut()
.iter_mut()
.for_each(|x| *x = source.next_i32() as i64);
res_1.raw_mut().copy_from_slice(res_0.raw());
// Reference
for i in 0..cols {
zn_normalize_inplace::<_, ZnxRef>(n, basek, &mut res_0, i, &mut carry);
module.zn_normalize_inplace(n, basek, &mut res_1, i, scratch.borrow());
}
assert_eq!(res_0.raw(), res_1.raw());
}
}

View File

@@ -0,0 +1,75 @@
use crate::{
layouts::{Zn, ZnToMut, ZnxInfos, ZnxViewMut},
reference::znx::{znx_add_normal_f64_ref, znx_fill_normal_f64_ref, znx_fill_uniform_ref},
source::Source,
};
pub fn zn_fill_uniform<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
let mut res: Zn<&mut [u8]> = res.to_mut();
for j in 0..res.size() {
znx_fill_uniform_ref(basek, &mut res.at_mut(res_col, j)[..n], source)
}
}
#[allow(clippy::too_many_arguments)]
pub fn zn_fill_normal<R>(
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
let mut res: Zn<&mut [u8]> = res.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
znx_fill_normal_f64_ref(
&mut res.at_mut(res_col, limb)[..n],
sigma * scale,
bound * scale,
source,
)
}
#[allow(clippy::too_many_arguments)]
pub fn zn_add_normal<R>(
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
let mut res: Zn<&mut [u8]> = res.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
znx_add_normal_f64_ref(
&mut res.at_mut(res_col, limb)[..n],
sigma * scale,
bound * scale,
source,
)
}

View File

@@ -0,0 +1,25 @@
#[inline(always)]
pub fn znx_add_ref(res: &mut [i64], a: &[i64], b: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
assert_eq!(res.len(), b.len());
}
let n: usize = res.len();
for i in 0..n {
res[i] = a[i] + b[i];
}
}
pub fn znx_add_inplace_ref(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
let n: usize = res.len();
for i in 0..n {
res[i] += a[i];
}
}

View File

@@ -0,0 +1,153 @@
use crate::reference::znx::{
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace,
ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxSub, ZnxSubABInplace,
ZnxSubBAInplace, ZnxSwitchRing, ZnxZero,
add::{znx_add_inplace_ref, znx_add_ref},
automorphism::znx_automorphism_ref,
copy::znx_copy_ref,
neg::{znx_negate_inplace_ref, znx_negate_ref},
normalization::{
znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref, znx_normalize_first_step_carry_only_ref,
znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref, znx_normalize_middle_step_carry_only_ref,
znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref,
},
sub::{znx_sub_ab_inplace_ref, znx_sub_ba_inplace_ref, znx_sub_ref},
switch_ring::znx_switch_ring_ref,
zero::znx_zero_ref,
};
pub struct ZnxRef {}
impl ZnxAdd for ZnxRef {
#[inline(always)]
fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]) {
znx_add_ref(res, a, b);
}
}
impl ZnxAddInplace for ZnxRef {
#[inline(always)]
fn znx_add_inplace(res: &mut [i64], a: &[i64]) {
znx_add_inplace_ref(res, a);
}
}
impl ZnxSub for ZnxRef {
#[inline(always)]
fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]) {
znx_sub_ref(res, a, b);
}
}
impl ZnxSubABInplace for ZnxRef {
#[inline(always)]
fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_ab_inplace_ref(res, a);
}
}
impl ZnxSubBAInplace for ZnxRef {
#[inline(always)]
fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_ba_inplace_ref(res, a);
}
}
impl ZnxAutomorphism for ZnxRef {
#[inline(always)]
fn znx_automorphism(p: i64, res: &mut [i64], a: &[i64]) {
znx_automorphism_ref(p, res, a);
}
}
impl ZnxCopy for ZnxRef {
#[inline(always)]
fn znx_copy(res: &mut [i64], a: &[i64]) {
znx_copy_ref(res, a);
}
}
impl ZnxNegate for ZnxRef {
#[inline(always)]
fn znx_negate(res: &mut [i64], src: &[i64]) {
znx_negate_ref(res, src);
}
}
impl ZnxNegateInplace for ZnxRef {
#[inline(always)]
fn znx_negate_inplace(res: &mut [i64]) {
znx_negate_inplace_ref(res);
}
}
impl ZnxZero for ZnxRef {
#[inline(always)]
fn znx_zero(res: &mut [i64]) {
znx_zero_ref(res);
}
}
impl ZnxSwitchRing for ZnxRef {
#[inline(always)]
fn znx_switch_ring(res: &mut [i64], a: &[i64]) {
znx_switch_ring_ref(res, a);
}
}
impl ZnxNormalizeFinalStep for ZnxRef {
#[inline(always)]
fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_final_step_ref(basek, lsh, x, a, carry);
}
}
impl ZnxNormalizeFinalStepInplace for ZnxRef {
#[inline(always)]
fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_final_step_inplace_ref(basek, lsh, x, carry);
}
}
impl ZnxNormalizeFirstStep for ZnxRef {
#[inline(always)]
fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_ref(basek, lsh, x, a, carry);
}
}
impl ZnxNormalizeFirstStepCarryOnly for ZnxRef {
#[inline(always)]
fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_carry_only_ref(basek, lsh, x, carry);
}
}
impl ZnxNormalizeFirstStepInplace for ZnxRef {
#[inline(always)]
fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_first_step_inplace_ref(basek, lsh, x, carry);
}
}
impl ZnxNormalizeMiddleStep for ZnxRef {
#[inline(always)]
fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_ref(basek, lsh, x, a, carry);
}
}
impl ZnxNormalizeMiddleStepCarryOnly for ZnxRef {
#[inline(always)]
fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_carry_only_ref(basek, lsh, x, carry);
}
}
impl ZnxNormalizeMiddleStepInplace for ZnxRef {
#[inline(always)]
fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_middle_step_inplace_ref(basek, lsh, x, carry);
}
}

View File

@@ -0,0 +1,21 @@
pub fn znx_automorphism_ref(p: i64, res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
let n: usize = res.len();
let mut k: usize = 0usize;
let mask: usize = 2 * n - 1;
let p_2n = (p & mask as i64) as usize;
res[0] = a[0];
for ai in a.iter().take(n).skip(1) {
k = (k + p_2n) & mask;
if k < n {
res[k] = *ai
} else {
res[k - n] = -*ai
}
}
}

View File

@@ -0,0 +1,8 @@
#[inline(always)]
pub fn znx_copy_ref(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len())
}
res.copy_from_slice(a);
}

View File

@@ -0,0 +1,104 @@
mod add;
mod arithmetic_ref;
mod automorphism;
mod copy;
mod neg;
mod normalization;
mod rotate;
mod sampling;
mod sub;
mod switch_ring;
mod zero;
pub use add::*;
pub use arithmetic_ref::*;
pub use automorphism::*;
pub use copy::*;
pub use neg::*;
pub use normalization::*;
pub use rotate::*;
pub use sub::*;
pub use switch_ring::*;
pub use zero::*;
pub use sampling::*;
pub trait ZnxAdd {
fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]);
}
pub trait ZnxAddInplace {
fn znx_add_inplace(res: &mut [i64], a: &[i64]);
}
pub trait ZnxSub {
fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]);
}
pub trait ZnxSubABInplace {
fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]);
}
pub trait ZnxSubBAInplace {
fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]);
}
pub trait ZnxAutomorphism {
fn znx_automorphism(p: i64, res: &mut [i64], a: &[i64]);
}
pub trait ZnxCopy {
fn znx_copy(res: &mut [i64], a: &[i64]);
}
pub trait ZnxNegate {
fn znx_negate(res: &mut [i64], src: &[i64]);
}
pub trait ZnxNegateInplace {
fn znx_negate_inplace(res: &mut [i64]);
}
pub trait ZnxRotate {
fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]);
}
pub trait ZnxZero {
fn znx_zero(res: &mut [i64]);
}
pub trait ZnxSwitchRing {
fn znx_switch_ring(res: &mut [i64], a: &[i64]);
}
pub trait ZnxNormalizeFirstStepCarryOnly {
fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeFirstStepInplace {
fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeFirstStep {
fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeMiddleStepCarryOnly {
fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeMiddleStepInplace {
fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeMiddleStep {
fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeFinalStepInplace {
fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeFinalStep {
fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
}

View File

@@ -0,0 +1,18 @@
#[inline(always)]
pub fn znx_negate_ref(res: &mut [i64], src: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), src.len())
}
for i in 0..res.len() {
res[i] = -src[i]
}
}
#[inline(always)]
pub fn znx_negate_inplace_ref(res: &mut [i64]) {
for value in res {
*value = -*value
}
}

View File

@@ -0,0 +1,199 @@
use itertools::izip;
#[inline(always)]
pub fn get_digit(basek: usize, x: i64) -> i64 {
(x << (u64::BITS - basek as u32)) >> (u64::BITS - basek as u32)
}
#[inline(always)]
pub fn get_carry(basek: usize, x: i64, digit: i64) -> i64 {
(x.wrapping_sub(digit)) >> basek
}
#[inline(always)]
pub fn znx_normalize_first_step_carry_only_ref(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
}
if lsh == 0 {
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
*c = get_carry(basek, *x, get_digit(basek, *x));
});
} else {
let basek_lsh: usize = basek - lsh;
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
*c = get_carry(basek_lsh, *x, get_digit(basek_lsh, *x));
});
}
}
#[inline(always)]
pub fn znx_normalize_first_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
}
if lsh == 0 {
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek, *x);
*c = get_carry(basek, *x, digit);
*x = digit;
});
} else {
let basek_lsh: usize = basek - lsh;
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek_lsh, *x);
*c = get_carry(basek_lsh, *x, digit);
*x = digit << lsh;
});
}
}
#[inline(always)]
pub fn znx_normalize_first_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), a.len());
assert!(x.len() <= carry.len());
assert!(lsh < basek);
}
if lsh == 0 {
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
let digit: i64 = get_digit(basek, *a);
*c = get_carry(basek, *a, digit);
*x = digit;
});
} else {
let basek_lsh: usize = basek - lsh;
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
let digit: i64 = get_digit(basek_lsh, *a);
*c = get_carry(basek_lsh, *a, digit);
*x = digit << lsh;
});
}
}
#[inline(always)]
pub fn znx_normalize_middle_step_carry_only_ref(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
}
if lsh == 0 {
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek, *x);
let carry: i64 = get_carry(basek, *x, digit);
let digit_plus_c: i64 = digit + *c;
*c = carry + get_carry(basek, digit_plus_c, get_digit(basek, digit_plus_c));
});
} else {
let basek_lsh: usize = basek - lsh;
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek_lsh, *x);
let carry: i64 = get_carry(basek_lsh, *x, digit);
let digit_plus_c: i64 = (digit << lsh) + *c;
*c = carry + get_carry(basek, digit_plus_c, get_digit(basek, digit_plus_c));
});
}
}
#[inline(always)]
pub fn znx_normalize_middle_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
}
if lsh == 0 {
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek, *x);
let carry: i64 = get_carry(basek, *x, digit);
let digit_plus_c: i64 = digit + *c;
*x = get_digit(basek, digit_plus_c);
*c = carry + get_carry(basek, digit_plus_c, *x);
});
} else {
let basek_lsh: usize = basek - lsh;
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek_lsh, *x);
let carry: i64 = get_carry(basek_lsh, *x, digit);
let digit_plus_c: i64 = (digit << lsh) + *c;
*x = get_digit(basek, digit_plus_c);
*c = carry + get_carry(basek, digit_plus_c, *x);
});
}
}
#[inline(always)]
pub fn znx_normalize_middle_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), a.len());
assert!(x.len() <= carry.len());
assert!(lsh < basek);
}
if lsh == 0 {
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
let digit: i64 = get_digit(basek, *a);
let carry: i64 = get_carry(basek, *a, digit);
let digit_plus_c: i64 = digit + *c;
*x = get_digit(basek, digit_plus_c);
*c = carry + get_carry(basek, digit_plus_c, *x);
});
} else {
let basek_lsh: usize = basek - lsh;
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
let digit: i64 = get_digit(basek_lsh, *a);
let carry: i64 = get_carry(basek_lsh, *a, digit);
let digit_plus_c: i64 = (digit << lsh) + *c;
*x = get_digit(basek, digit_plus_c);
*c = carry + get_carry(basek, digit_plus_c, *x);
});
}
}
#[inline(always)]
pub fn znx_normalize_final_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
}
if lsh == 0 {
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
*x = get_digit(basek, get_digit(basek, *x) + *c);
});
} else {
let basek_lsh: usize = basek - lsh;
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
*x = get_digit(basek, (get_digit(basek_lsh, *x) << lsh) + *c);
});
}
}
#[inline(always)]
pub fn znx_normalize_final_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
}
if lsh == 0 {
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
*x = get_digit(basek, get_digit(basek, *a) + *c);
});
} else {
let basek_lsh: usize = basek - lsh;
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
*x = get_digit(basek, (get_digit(basek_lsh, *a) << lsh) + *c);
});
}
}

View File

@@ -0,0 +1,26 @@
use crate::reference::znx::{ZnxCopy, ZnxNegate};
pub fn znx_rotate<ZNXARI: ZnxNegate + ZnxCopy>(p: i64, res: &mut [i64], src: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), src.len());
}
let n: usize = res.len();
let mp_2n: usize = (p & (2 * n as i64 - 1)) as usize; // -p % 2n
let mp_1n: usize = mp_2n & (n - 1); // -p % n
let mp_1n_neg: usize = n - mp_1n; // p % n
let neg_first: bool = mp_2n < n;
let (dst1, dst2) = res.split_at_mut(mp_1n);
let (src1, src2) = src.split_at(mp_1n_neg);
if neg_first {
ZNXARI::znx_negate(dst1, src2);
ZNXARI::znx_copy(dst2, src1);
} else {
ZNXARI::znx_copy(dst1, src2);
ZNXARI::znx_negate(dst2, src1);
}
}

View File

@@ -0,0 +1,53 @@
use rand_distr::{Distribution, Normal};
use crate::source::Source;
pub fn znx_fill_uniform_ref(basek: usize, res: &mut [i64], source: &mut Source) {
let pow2k: u64 = 1 << basek;
let mask: u64 = pow2k - 1;
let pow2k_half: i64 = (pow2k >> 1) as i64;
res.iter_mut()
.for_each(|xi| *xi = (source.next_u64n(pow2k, mask) as i64) - pow2k_half)
}
pub fn znx_fill_dist_f64_ref<D: rand::prelude::Distribution<f64>>(res: &mut [i64], dist: D, bound: f64, source: &mut Source) {
res.iter_mut().for_each(|xi| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*xi = dist_f64.round() as i64
})
}
pub fn znx_add_dist_f64_ref<D: rand::prelude::Distribution<f64>>(res: &mut [i64], dist: D, bound: f64, source: &mut Source) {
res.iter_mut().for_each(|xi| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*xi += dist_f64.round() as i64
})
}
pub fn znx_fill_normal_f64_ref(res: &mut [i64], sigma: f64, bound: f64, source: &mut Source) {
let normal: Normal<f64> = Normal::new(0.0, sigma).unwrap();
res.iter_mut().for_each(|xi| {
let mut dist_f64: f64 = normal.sample(source);
while dist_f64.abs() > bound {
dist_f64 = normal.sample(source)
}
*xi = dist_f64.round() as i64
})
}
pub fn znx_add_normal_f64_ref(res: &mut [i64], sigma: f64, bound: f64, source: &mut Source) {
let normal: Normal<f64> = Normal::new(0.0, sigma).unwrap();
res.iter_mut().for_each(|xi| {
let mut dist_f64: f64 = normal.sample(source);
while dist_f64.abs() > bound {
dist_f64 = normal.sample(source)
}
*xi += dist_f64.round() as i64
})
}

View File

@@ -0,0 +1,36 @@
pub fn znx_sub_ref(res: &mut [i64], a: &[i64], b: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
assert_eq!(res.len(), b.len());
}
let n: usize = res.len();
for i in 0..n {
res[i] = a[i] - b[i];
}
}
pub fn znx_sub_ab_inplace_ref(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
let n: usize = res.len();
for i in 0..n {
res[i] -= a[i];
}
}
pub fn znx_sub_ba_inplace_ref(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
let n: usize = res.len();
for i in 0..n {
res[i] = a[i] - res[i];
}
}

View File

@@ -0,0 +1,29 @@
use crate::reference::znx::{copy::znx_copy_ref, zero::znx_zero_ref};
pub fn znx_switch_ring_ref(res: &mut [i64], a: &[i64]) {
let (n_in, n_out) = (a.len(), res.len());
#[cfg(debug_assertions)]
{
assert!(n_in.is_power_of_two());
assert!(n_in.max(n_out).is_multiple_of(n_in.min(n_out)))
}
if n_in == n_out {
znx_copy_ref(res, a);
return;
}
let (gap_in, gap_out): (usize, usize);
if n_in > n_out {
(gap_in, gap_out) = (n_in / n_out, 1)
} else {
(gap_in, gap_out) = (1, n_out / n_in);
znx_zero_ref(res);
}
res.iter_mut()
.step_by(gap_out)
.zip(a.iter().step_by(gap_in))
.for_each(|(x_out, x_in)| *x_out = *x_in);
}

View File

@@ -0,0 +1,3 @@
pub fn znx_zero_ref(res: &mut [i64]) {
res.fill(0);
}

View File

@@ -39,6 +39,12 @@ impl Source {
min + ((self.next_u64() << 11 >> 11) as f64) / MAXF64 * (max - min)
}
#[inline(always)]
pub fn next_i32(&mut self) -> i32 {
self.next_u32() as i32
}
#[inline(always)]
pub fn next_i64(&mut self) -> i64 {
self.next_u64() as i64
}

View File

@@ -0,0 +1,68 @@
pub mod serialization;
pub mod svp;
pub mod vec_znx;
pub mod vec_znx_big;
pub mod vec_znx_dft;
pub mod vmp;
#[macro_export]
macro_rules! backend_test_suite {
(
mod $modname:ident,
backend = $backend:ty,
size = $size:expr,
tests = {
$( $(#[$attr:meta])* $test_name:ident => $impl:path ),+ $(,)?
}
) => {
mod $modname {
use poulpy_hal::{api::ModuleNew, layouts::Module};
use once_cell::sync::Lazy;
static MODULE: Lazy<Module<$backend>> =
Lazy::new(|| Module::<$backend>::new($size));
$(
$(#[$attr])*
#[test]
fn $test_name() {
($impl)(&*MODULE);
}
)+
}
};
}
#[macro_export]
macro_rules! cross_backend_test_suite {
(
mod $modname:ident,
backend_ref = $backend_ref:ty,
backend_test = $backend_test:ty,
size = $size:expr,
basek = $basek:expr,
tests = {
$( $(#[$attr:meta])* $test_name:ident => $impl:path ),+ $(,)?
}
) => {
mod $modname {
use poulpy_hal::{api::ModuleNew, layouts::Module};
use once_cell::sync::Lazy;
static MODULE_REF: Lazy<Module<$backend_ref>> =
Lazy::new(|| Module::<$backend_ref>::new($size));
static MODULE_TEST: Lazy<Module<$backend_test>> =
Lazy::new(|| Module::<$backend_test>::new($size));
$(
$(#[$attr])*
#[test]
fn $test_name() {
($impl)($basek, &*MODULE_REF, &*MODULE_TEST);
}
)+
}
};
}

View File

@@ -14,7 +14,7 @@ where
{
// Fill original with uniform random data
let mut source = Source::new([0u8; 32]);
original.fill_uniform(&mut source);
original.fill_uniform(50, &mut source);
// Serialize into a buffer
let mut buffer = Vec::new();

View File

@@ -0,0 +1,470 @@
use rand::RngCore;
use crate::{
api::{
ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDft, SvpApplyDftToDft, SvpApplyDftToDftAdd, SvpApplyDftToDftInplace,
SvpPPolAlloc, SvpPrepare, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VecZnxDftApply,
VecZnxIdftApplyConsume,
},
layouts::{Backend, DataViewMut, DigestU64, FillUniform, Module, ScalarZnx, ScratchOwned, SvpPPol, VecZnx, VecZnxDft},
source::Source,
};
pub fn test_svp_apply_dft<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: SvpPrepare<BR>
+ SvpApplyDft<BR>
+ SvpPPolAlloc<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: SvpPrepare<BT>
+ SvpApplyDft<BT>
+ SvpPPolAlloc<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let cols: usize = 2;
let mut source: Source = Source::new([0u8; 32]);
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
let mut scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
scalar.fill_uniform(basek, &mut source);
let scalar_digest: u64 = scalar.digest_u64();
let mut svp_ref: SvpPPol<Vec<u8>, BR> = module_ref.svp_ppol_alloc(cols);
let mut svp_test: SvpPPol<Vec<u8>, BT> = module_test.svp_ppol_alloc(cols);
for j in 0..cols {
module_ref.svp_prepare(&mut svp_ref, j, &scalar, j);
module_test.svp_prepare(&mut svp_test, j, &scalar, j);
}
assert_eq!(scalar.digest_u64(), scalar_digest);
let svp_ref_digest: u64 = svp_ref.digest_u64();
let svp_test_digest: u64 = svp_test.digest_u64();
for a_size in [1, 2, 3, 4] {
// Create a random input VecZnx
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
// Allocate VecZnxDft from FFT64Ref and module to test
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
// Fill output with garbage
source.fill_bytes(res_dft_ref.data_mut());
source.fill_bytes(res_dft_test.data_mut());
for j in 0..cols {
module_ref.svp_apply_dft(&mut res_dft_ref, j, &svp_ref, j, &a, j);
module_test.svp_apply_dft(&mut res_dft_test, j, &svp_test, j, &a, j);
}
// Assert no change to inputs
assert_eq!(svp_ref.digest_u64(), svp_ref_digest);
assert_eq!(svp_test.digest_u64(), svp_test_digest);
assert_eq!(a.digest_u64(), a_digest);
let res_big_ref: crate::layouts::VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: crate::layouts::VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_ref, res_test);
}
}
}
pub fn test_svp_apply_dft_to_dft<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: SvpPrepare<BR>
+ SvpApplyDftToDft<BR>
+ SvpPPolAlloc<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxDftApply<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: SvpPrepare<BT>
+ SvpApplyDftToDft<BT>
+ SvpPPolAlloc<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxDftApply<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let cols: usize = 2;
let mut source: Source = Source::new([0u8; 32]);
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
let mut scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
scalar.fill_uniform(basek, &mut source);
let scalar_digest: u64 = scalar.digest_u64();
let mut svp_ref: SvpPPol<Vec<u8>, BR> = module_ref.svp_ppol_alloc(cols);
let mut svp_test: SvpPPol<Vec<u8>, BT> = module_test.svp_ppol_alloc(cols);
for j in 0..cols {
module_ref.svp_prepare(&mut svp_ref, j, &scalar, j);
module_test.svp_prepare(&mut svp_test, j, &scalar, j);
}
assert_eq!(scalar.digest_u64(), scalar_digest);
let svp_ref_digest: u64 = svp_ref.digest_u64();
let svp_test_digest: u64 = svp_test.digest_u64();
for a_size in [3] {
// Create a random input VecZnx
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest: u64 = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
for res_size in [3] {
// Allocate VecZnxDft from FFT64Ref and module to test
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
// Fill output with garbage
source.fill_bytes(res_dft_ref.data_mut());
source.fill_bytes(res_dft_test.data_mut());
for j in 0..cols {
module_ref.svp_apply_dft_to_dft(&mut res_dft_ref, j, &svp_ref, j, &a_dft_ref, j);
module_test.svp_apply_dft_to_dft(&mut res_dft_test, j, &svp_test, j, &a_dft_test, j);
}
// Assert no change to inputs
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
assert_eq!(svp_ref.digest_u64(), svp_ref_digest);
assert_eq!(svp_test.digest_u64(), svp_test_digest);
assert_eq!(a.digest_u64(), a_digest);
let res_big_ref: crate::layouts::VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: crate::layouts::VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
println!("res_big_ref: {}", res_big_ref);
println!("res_big_test: {}", res_big_test);
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_ref, res_test);
}
}
}
pub fn test_svp_apply_dft_to_dft_add<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: SvpPrepare<BR>
+ SvpApplyDftToDftAdd<BR>
+ SvpPPolAlloc<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxDftApply<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: SvpPrepare<BT>
+ SvpApplyDftToDftAdd<BT>
+ SvpPPolAlloc<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxDftApply<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let cols: usize = 2;
let mut source: Source = Source::new([0u8; 32]);
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
let mut scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
scalar.fill_uniform(basek, &mut source);
let scalar_digest: u64 = scalar.digest_u64();
let mut svp_ref: SvpPPol<Vec<u8>, BR> = module_ref.svp_ppol_alloc(cols);
let mut svp_test: SvpPPol<Vec<u8>, BT> = module_test.svp_ppol_alloc(cols);
for j in 0..cols {
module_ref.svp_prepare(&mut svp_ref, j, &scalar, j);
module_test.svp_prepare(&mut svp_test, j, &scalar, j);
}
assert_eq!(scalar.digest_u64(), scalar_digest);
let svp_ref_digest: u64 = svp_ref.digest_u64();
let svp_test_digest: u64 = svp_test.digest_u64();
for a_size in [1, 2, 3, 4] {
// Create a random input VecZnx
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest: u64 = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
}
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res.fill_uniform(basek, &mut source);
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j);
module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j);
}
// Fill output with garbage
source.fill_bytes(res_dft_ref.data_mut());
source.fill_bytes(res_dft_test.data_mut());
for j in 0..cols {
module_ref.svp_apply_dft_to_dft_add(&mut res_dft_ref, j, &svp_ref, j, &a_dft_ref, j);
module_test.svp_apply_dft_to_dft_add(&mut res_dft_test, j, &svp_test, j, &a_dft_test, j);
}
// Assert no change to inputs
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
assert_eq!(svp_ref.digest_u64(), svp_ref_digest);
assert_eq!(svp_test.digest_u64(), svp_test_digest);
assert_eq!(a.digest_u64(), a_digest);
let res_big_ref: crate::layouts::VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: crate::layouts::VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_ref, res_test);
}
}
}
pub fn test_svp_apply_dft_to_dft_inplace<BR: Backend, BT: Backend>(
basek: usize,
module_ref: &Module<BR>,
module_test: &Module<BT>,
) where
Module<BR>: SvpPrepare<BR>
+ SvpApplyDftToDftInplace<BR>
+ SvpPPolAlloc<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxDftApply<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: SvpPrepare<BT>
+ SvpApplyDftToDftInplace<BT>
+ SvpPPolAlloc<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxDftApply<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let cols: usize = 2;
let mut source: Source = Source::new([0u8; 32]);
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
let mut scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
scalar.fill_uniform(basek, &mut source);
let scalar_digest: u64 = scalar.digest_u64();
let mut svp_ref: SvpPPol<Vec<u8>, BR> = module_ref.svp_ppol_alloc(cols);
let mut svp_test: SvpPPol<Vec<u8>, BT> = module_test.svp_ppol_alloc(cols);
for j in 0..cols {
module_ref.svp_prepare(&mut svp_ref, j, &scalar, j);
module_test.svp_prepare(&mut svp_test, j, &scalar, j);
}
assert_eq!(scalar.digest_u64(), scalar_digest);
let svp_ref_digest: u64 = svp_ref.digest_u64();
let svp_test_digest: u64 = svp_test.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res.fill_uniform(basek, &mut source);
let res_digest: u64 = res.digest_u64();
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j);
module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j);
}
assert_eq!(res.digest_u64(), res_digest);
for j in 0..cols {
module_ref.svp_apply_dft_to_dft_inplace(&mut res_dft_ref, j, &svp_ref, j);
module_test.svp_apply_dft_to_dft_inplace(&mut res_dft_test, j, &svp_test, j);
}
// Assert no change to inputs
assert_eq!(svp_ref.digest_u64(), svp_ref_digest);
assert_eq!(svp_test.digest_u64(), svp_test_digest);
let res_big_ref: crate::layouts::VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: crate::layouts::VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
println!("res_ref: {}", res_ref);
println!("res_test: {}", res_test);
assert_eq!(res_ref, res_test);
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,930 @@
use rand::RngCore;
use crate::{
api::{
ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAdd,
VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftApply, VecZnxDftCopy, VecZnxDftSub, VecZnxDftSubABInplace,
VecZnxDftSubBAInplace, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes,
},
layouts::{Backend, DataViewMut, DigestU64, FillUniform, Module, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft},
source::Source,
};
pub fn test_vec_znx_dft_add<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftAdd<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxDftApply<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: VecZnxDftAdd<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxDftApply<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
for b_size in [1, 2, 3, 4] {
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, b_size);
b.fill_uniform(basek, &mut source);
let b_digest: u64 = b.digest_u64();
let mut b_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, b_size);
let mut b_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, b_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut b_dft_ref, j, &b, j);
module_test.vec_znx_dft_apply(1, 0, &mut b_dft_test, j, &b, j);
}
assert_eq!(b.digest_u64(), b_digest);
let b_dft_ref_digest: u64 = b_dft_ref.digest_u64();
let b_dft_test_digest: u64 = b_dft_test.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
// Set d to garbage
source.fill_bytes(res_dft_ref.data_mut());
source.fill_bytes(res_dft_test.data_mut());
// Reference
for i in 0..cols {
module_ref.vec_znx_dft_add(&mut res_dft_ref, i, &a_dft_ref, i, &b_dft_ref, i);
module_test.vec_znx_dft_add(&mut res_dft_test, i, &a_dft_test, i, &b_dft_test, i);
}
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
assert_eq!(b_dft_ref.digest_u64(), b_dft_ref_digest);
assert_eq!(b_dft_test.digest_u64(), b_dft_test_digest);
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
}
pub fn test_vec_znx_dft_add_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftAddInplace<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxDftApply<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: VecZnxDftAddInplace<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxDftApply<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
res.fill_uniform(basek, &mut source);
let res_digest: u64 = res.digest_u64();
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j);
module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j);
}
assert_eq!(res.digest_u64(), res_digest);
// Reference
for i in 0..cols {
module_ref.vec_znx_dft_add_inplace(&mut res_dft_ref, i, &a_dft_ref, i);
module_test.vec_znx_dft_add_inplace(&mut res_dft_test, i, &a_dft_test, i);
}
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
pub fn test_vec_znx_copy<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftCopy<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxDftApply<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: VecZnxDftCopy<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxDftApply<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
for a_size in [1, 2, 6, 11] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
for res_size in [1, 2, 6, 11] {
for params in [[1, 0], [1, 1], [1, 2], [2, 2]] {
let steps: usize = params[0];
let offset: usize = params[1];
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
// Set d to garbage
source.fill_bytes(res_dft_ref.data_mut());
source.fill_bytes(res_dft_test.data_mut());
// Reference
for i in 0..cols {
module_ref.vec_znx_dft_copy(steps, offset, &mut res_dft_ref, i, &a_dft_ref, i);
module_test.vec_znx_dft_copy(steps, offset, &mut res_dft_test, i, &a_dft_test, i);
}
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
}
pub fn test_vec_znx_idft_apply<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftApply<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxBigAlloc<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIdftApply<BR>,
Module<BT>: VecZnxDftApply<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxBigAlloc<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIdftApply<BT>,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
for params in [[1, 0], [1, 1], [1, 2], [2, 2]] {
let steps: usize = params[0];
let offset: usize = params[1];
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(steps, offset, &mut res_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(steps, offset, &mut res_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let res_dft_ref_digest: u64 = res_dft_ref.digest_u64();
let rest_dft_test_digest: u64 = res_dft_test.digest_u64();
let mut res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, res_size);
let mut res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, res_size);
for j in 0..cols {
module_ref.vec_znx_idft_apply(&mut res_big_ref, j, &res_dft_ref, j, scratch_ref.borrow());
module_test.vec_znx_idft_apply(
&mut res_big_test,
j,
&res_dft_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_dft_ref.digest_u64(), res_dft_ref_digest);
assert_eq!(res_dft_test.digest_u64(), rest_dft_test_digest);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
}
pub fn test_vec_znx_idft_apply_tmpa<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftApply<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxBigAlloc<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIdftApplyTmpA<BR>,
Module<BT>: VecZnxDftApply<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxBigAlloc<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIdftApplyTmpA<BT>,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
for params in [[1, 0], [1, 1], [1, 2], [2, 2]] {
let steps: usize = params[0];
let offset: usize = params[1];
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(steps, offset, &mut res_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(steps, offset, &mut res_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let mut res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, res_size);
let mut res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, res_size);
for j in 0..cols {
module_ref.vec_znx_idft_apply_tmpa(&mut res_big_ref, j, &mut res_dft_ref, j);
module_test.vec_znx_idft_apply_tmpa(&mut res_big_test, j, &mut res_dft_test, j);
}
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
}
pub fn test_vec_znx_idft_apply_consume<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftApply<BR>
+ VecZnxIdftApplyTmpBytes
+ VecZnxDftAlloc<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIdftApplyConsume<BR>,
Module<BT>: VecZnxDftApply<BT>
+ VecZnxIdftApplyTmpBytes
+ VecZnxDftAlloc<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIdftApplyConsume<BT>,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> =
ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes() | module_ref.vec_znx_idft_apply_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> =
ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes() | module_test.vec_znx_idft_apply_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
for params in [[1, 0], [1, 1], [1, 2], [2, 2]] {
let steps: usize = params[0];
let offset: usize = params[1];
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(steps, offset, &mut res_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(steps, offset, &mut res_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
}
pub fn test_vec_znx_dft_sub<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftSub<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxDftApply<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: VecZnxDftSub<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxDftApply<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
for b_size in [1, 2, 3, 4] {
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, b_size);
b.fill_uniform(basek, &mut source);
let b_digest: u64 = b.digest_u64();
let mut b_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, b_size);
let mut b_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, b_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut b_dft_ref, j, &b, j);
module_test.vec_znx_dft_apply(1, 0, &mut b_dft_test, j, &b, j);
}
assert_eq!(b.digest_u64(), b_digest);
let b_dft_ref_digest: u64 = b_dft_ref.digest_u64();
let b_dft_test_digest: u64 = b_dft_test.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
// Set d to garbage
source.fill_bytes(res_dft_ref.data_mut());
source.fill_bytes(res_dft_test.data_mut());
// Reference
for i in 0..cols {
module_ref.vec_znx_dft_sub(&mut res_dft_ref, i, &a_dft_ref, i, &b_dft_ref, i);
module_test.vec_znx_dft_sub(&mut res_dft_test, i, &a_dft_test, i, &b_dft_test, i);
}
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
assert_eq!(b_dft_ref.digest_u64(), b_dft_ref_digest);
assert_eq!(b_dft_test.digest_u64(), b_dft_test_digest);
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
}
pub fn test_vec_znx_dft_sub_ab_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftSubABInplace<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxDftApply<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: VecZnxDftSubABInplace<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxDftApply<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
res.fill_uniform(basek, &mut source);
let res_digest: u64 = res.digest_u64();
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j);
module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j);
}
assert_eq!(res.digest_u64(), res_digest);
// Reference
for i in 0..cols {
module_ref.vec_znx_dft_sub_ab_inplace(&mut res_dft_ref, i, &a_dft_ref, i);
module_test.vec_znx_dft_sub_ab_inplace(&mut res_dft_test, i, &a_dft_test, i);
}
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
pub fn test_vec_znx_dft_sub_ba_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftSubBAInplace<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxDftApply<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: VecZnxDftSubBAInplace<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxDftApply<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxBigNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
let a_digest = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
res.fill_uniform(basek, &mut source);
let res_digest: u64 = res.digest_u64();
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
for j in 0..cols {
module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j);
module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j);
}
assert_eq!(res.digest_u64(), res_digest);
// Reference
for i in 0..cols {
module_ref.vec_znx_dft_sub_ba_inplace(&mut res_dft_ref, i, &a_dft_ref, i);
module_test.vec_znx_dft_sub_ba_inplace(&mut res_dft_test, i, &a_dft_test, i);
}
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}

View File

@@ -0,0 +1,384 @@
use crate::{
api::{
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigNormalize, VecZnxDftAlloc, VecZnxDftApply,
VecZnxIdftApplyConsume, VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftAddTmpBytes, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare,
},
layouts::{DataViewMut, DigestU64, FillUniform, MatZnx, Module, ScratchOwned, VecZnx, VecZnxBig},
source::Source,
};
use rand::RngCore;
use crate::layouts::{Backend, VecZnxDft, VmpPMat};
pub fn test_vmp_apply_dft<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: ModuleNew<BR>
+ VmpApplyDftTmpBytes
+ VmpApplyDft<BR>
+ VmpPMatAlloc<BR>
+ VecZnxDftAlloc<BR>
+ VmpPrepare<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
Module<BT>: ModuleNew<BT>
+ VmpApplyDftTmpBytes
+ VmpApplyDft<BT>
+ VmpPMatAlloc<BT>
+ VecZnxDftAlloc<BT>
+ VmpPrepare<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalize<BT>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let max_size: usize = 4;
let max_cols: usize = 2;
let mut source: Source = Source::new([0u8; 32]);
let mut scratch_ref: ScratchOwned<BR> =
ScratchOwned::alloc(module_ref.vmp_apply_dft_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size));
let mut scratch_test: ScratchOwned<BT> =
ScratchOwned::alloc(module_test.vmp_apply_dft_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size));
for cols_in in 1..max_cols + 1 {
for cols_out in 1..max_cols + 1 {
for size_in in 1..max_size + 1 {
for size_out in 1..max_size + 1 {
let rows: usize = cols_in;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_in, size_in);
a.fill_uniform(basek, &mut source);
let a_digest: u64 = a.digest_u64();
let mut mat: MatZnx<Vec<u8>> = MatZnx::alloc(n, rows, cols_in, cols_out, size_out);
mat.fill_uniform(basek, &mut source);
let mat_digest: u64 = mat.digest_u64();
let mut pmat_ref: VmpPMat<Vec<u8>, BR> = module_ref.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
let mut pmat_test: VmpPMat<Vec<u8>, BT> = module_test.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
module_ref.vmp_prepare(&mut pmat_ref, &mat, scratch_ref.borrow());
module_test.vmp_prepare(&mut pmat_test, &mat, scratch_test.borrow());
assert_eq!(mat.digest_u64(), mat_digest);
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols_out, size_out);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols_out, size_out);
source.fill_bytes(res_dft_ref.data_mut());
source.fill_bytes(res_dft_test.data_mut());
module_ref.vmp_apply_dft(&mut res_dft_ref, &a, &pmat_ref, scratch_ref.borrow());
module_test.vmp_apply_dft(&mut res_dft_test, &a, &pmat_test, scratch_test.borrow());
assert_eq!(a.digest_u64(), a_digest);
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols_out {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
}
}
pub fn test_vmp_apply_dft_to_dft<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: ModuleNew<BR>
+ VmpApplyDftToDftTmpBytes
+ VmpApplyDftToDft<BR>
+ VmpPMatAlloc<BR>
+ VecZnxDftAlloc<BR>
+ VmpPrepare<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxDftApply<BR>,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
Module<BT>: ModuleNew<BT>
+ VmpApplyDftToDftTmpBytes
+ VmpApplyDftToDft<BT>
+ VmpPMatAlloc<BT>
+ VecZnxDftAlloc<BT>
+ VmpPrepare<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxDftApply<BT>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let max_size: usize = 4;
let max_cols: usize = 2;
let mut source: Source = Source::new([0u8; 32]);
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(
module_ref.vmp_apply_dft_to_dft_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size),
);
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(
module_test.vmp_apply_dft_to_dft_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size),
);
for cols_in in 1..max_cols + 1 {
for cols_out in 1..max_cols + 1 {
for size_in in 1..max_size + 1 {
for size_out in 1..max_size + 1 {
let rows: usize = size_in;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_in, size_in);
a.fill_uniform(basek, &mut source);
let a_digest: u64 = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols_in, size_in);
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols_in, size_in);
for j in 0..cols_in {
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let mut mat: MatZnx<Vec<u8>> = MatZnx::alloc(n, rows, cols_in, cols_out, size_out);
mat.fill_uniform(basek, &mut source);
let mat_digest: u64 = mat.digest_u64();
let mut pmat_ref: VmpPMat<Vec<u8>, BR> = module_ref.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
let mut pmat_test: VmpPMat<Vec<u8>, BT> = module_test.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
module_ref.vmp_prepare(&mut pmat_ref, &mat, scratch_ref.borrow());
module_test.vmp_prepare(&mut pmat_test, &mat, scratch_test.borrow());
assert_eq!(mat.digest_u64(), mat_digest);
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols_out, size_out);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols_out, size_out);
source.fill_bytes(res_dft_ref.data_mut());
source.fill_bytes(res_dft_test.data_mut());
module_ref.vmp_apply_dft_to_dft(
&mut res_dft_ref,
&a_dft_ref,
&pmat_ref,
scratch_ref.borrow(),
);
module_test.vmp_apply_dft_to_dft(
&mut res_dft_test,
&a_dft_test,
&pmat_test,
scratch_test.borrow(),
);
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols_out {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
}
}
pub fn test_vmp_apply_dft_to_dft_add<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: ModuleNew<BR>
+ VmpApplyDftToDftAddTmpBytes
+ VmpApplyDftToDftAdd<BR>
+ VmpPMatAlloc<BR>
+ VecZnxDftAlloc<BR>
+ VmpPrepare<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxDftApply<BR>,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
Module<BT>: ModuleNew<BT>
+ VmpApplyDftToDftAddTmpBytes
+ VmpApplyDftToDftAdd<BT>
+ VmpPMatAlloc<BT>
+ VecZnxDftAlloc<BT>
+ VmpPrepare<BT>
+ VecZnxIdftApplyConsume<BT>
+ VecZnxBigNormalize<BT>
+ VecZnxDftApply<BT>,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
let max_size: usize = 4;
let max_cols: usize = 2;
let mut source: Source = Source::new([0u8; 32]);
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(
module_ref.vmp_apply_dft_to_dft_add_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size),
);
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(
module_test.vmp_apply_dft_to_dft_add_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size),
);
for cols_in in 1..max_cols + 1 {
for cols_out in 1..max_cols + 1 {
for size_in in 1..max_size + 1 {
for size_out in 1..max_size + 1 {
let rows: usize = size_in;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_in, size_in);
a.fill_uniform(basek, &mut source);
let a_digest: u64 = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols_in, size_in);
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols_in, size_in);
for j in 0..cols_in {
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
}
assert_eq!(a.digest_u64(), a_digest);
let mut mat: MatZnx<Vec<u8>> = MatZnx::alloc(n, rows, cols_in, cols_out, size_out);
mat.fill_uniform(basek, &mut source);
let mat_digest: u64 = mat.digest_u64();
let mut pmat_ref: VmpPMat<Vec<u8>, BR> = module_ref.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
let mut pmat_test: VmpPMat<Vec<u8>, BT> = module_test.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
module_ref.vmp_prepare(&mut pmat_ref, &mat, scratch_ref.borrow());
module_test.vmp_prepare(&mut pmat_test, &mat, scratch_test.borrow());
assert_eq!(mat.digest_u64(), mat_digest);
for limb_offset in 0..size_out {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
res.fill_uniform(basek, &mut source);
let res_digest: u64 = res.digest_u64();
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols_out, size_out);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols_out, size_out);
for j in 0..cols_out {
module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j);
module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j);
}
assert_eq!(res.digest_u64(), res_digest);
module_ref.vmp_apply_dft_to_dft_add(
&mut res_dft_ref,
&a_dft_ref,
&pmat_ref,
limb_offset * cols_out,
scratch_ref.borrow(),
);
module_test.vmp_apply_dft_to_dft_add(
&mut res_dft_test,
&a_dft_test,
&pmat_test,
limb_offset * cols_out,
scratch_test.borrow(),
);
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols_out {
module_ref.vec_znx_big_normalize(
basek,
&mut res_small_ref,
j,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
&mut res_small_test,
j,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}
}
}
}

View File

@@ -1,3 +0,0 @@
pub mod serialization;
pub mod vec_znx;
pub mod vmp_pmat;

View File

@@ -1,50 +0,0 @@
use crate::{
layouts::{VecZnx, ZnxInfos, ZnxViewMut},
source::Source,
};
pub fn test_vec_znx_encode_vec_i64_lo_norm() {
let n: usize = 32;
let basek: usize = 17;
let size: usize = 5;
let k: usize = size * basek - 5;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
let mut source: Source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
(0..a.cols()).for_each(|col_i| {
let mut have: Vec<i64> = vec![i64::default(); n];
have.iter_mut()
.for_each(|x| *x = (source.next_i64() << 56) >> 56);
a.encode_vec_i64(basek, col_i, k, &have, 10);
let mut want: Vec<i64> = vec![i64::default(); n];
a.decode_vec_i64(basek, col_i, k, &mut want);
assert_eq!(have, want, "{:?} != {:?}", &have, &want);
});
}
pub fn test_vec_znx_encode_vec_i64_hi_norm() {
let n: usize = 32;
let basek: usize = 17;
let size: usize = 5;
for k in [1, basek / 2, size * basek - 5] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
let mut source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
(0..a.cols()).for_each(|col_i| {
let mut have: Vec<i64> = vec![i64::default(); n];
have.iter_mut().for_each(|x| {
if k < 64 {
*x = source.next_u64n(1 << k, (1 << k) - 1) as i64;
} else {
*x = source.next_i64();
}
});
a.encode_vec_i64(basek, col_i, k, &have, 63);
let mut want: Vec<i64> = vec![i64::default(); n];
a.decode_vec_i64(basek, col_i, k, &mut want);
assert_eq!(have, want, "{:?} != {:?}", &have, &want);
})
}
}

View File

@@ -1,67 +0,0 @@
use crate::{
api::{VecZnxAddNormal, VecZnxFillUniform},
layouts::{Backend, Module, VecZnx, ZnxView},
source::Source,
};
pub fn test_vec_znx_fill_uniform<B: Backend>(module: &Module<B>)
where
Module<B>: VecZnxFillUniform,
{
let n: usize = module.n();
let basek: usize = 17;
let size: usize = 5;
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let zero: Vec<i64> = vec![0; n];
let one_12_sqrt: f64 = 0.28867513459481287;
(0..cols).for_each(|col_i| {
let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size);
module.vec_znx_fill_uniform(basek, &mut a, col_i, size * basek, &mut source);
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..size).for_each(|limb_i| {
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(basek, col_i);
assert!(
(std - one_12_sqrt).abs() < 0.01,
"std={} ~!= {}",
std,
one_12_sqrt
);
}
})
});
}
pub fn test_vec_znx_add_normal<B: Backend>(module: &Module<B>)
where
Module<B>: VecZnxAddNormal,
{
let n: usize = module.n();
let basek: usize = 17;
let k: usize = 2 * 17;
let size: usize = 5;
let sigma: f64 = 3.2;
let bound: f64 = 6.0 * sigma;
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let zero: Vec<i64> = vec![0; n];
let k_f64: f64 = (1u64 << k as u64) as f64;
(0..cols).for_each(|col_i| {
let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size);
module.vec_znx_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..size).for_each(|limb_i| {
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(basek, col_i) * k_f64;
assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma);
}
})
});
}

View File

@@ -1,5 +0,0 @@
mod generics;
pub use generics::*;
#[cfg(test)]
mod encoding;

View File

@@ -1,3 +0,0 @@
mod vmp_apply;
pub use vmp_apply::*;

Some files were not shown because too many files have changed in this diff Show More