Mirror of https://github.com/arnaucube/poulpy.git (synced 2026-02-10 05:06:44 +01:00)

Commit 56dbd29c59 (parent 99b9e3e10e), committed via GitHub:
Ref. + AVX code & generic tests + benches (#85)
@@ -17,7 +17,10 @@ rand = {workspace = true}
rand_distr = {workspace = true}
rand_core = {workspace = true}
byteorder = {workspace = true}
once_cell = {workspace = true}
rand_chacha = "0.9.0"
bytemuck = "1.23.2"

[build-dependencies]
cmake = "0.1.54"
@@ -1,4 +1,6 @@
use crate::layouts::{Backend, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef};
use crate::layouts::{
    Backend, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
};

/// Allocates as [crate::layouts::SvpPPol].
pub trait SvpPPolAlloc<B: Backend> {
@@ -25,8 +27,26 @@ pub trait SvpPrepare<B: Backend> {
}

/// Applies a scalar-vector product between `a[a_col]` and `b[b_col]` and stores the result in `res[res_col]`.
pub trait SvpApply<B: Backend> {
    fn svp_apply<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
pub trait SvpApplyDft<B: Backend> {
    fn svp_apply_dft<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
    where
        R: VecZnxDftToMut<B>,
        A: SvpPPolToRef<B>,
        C: VecZnxToRef;
}

/// Applies a scalar-vector product between `a[a_col]` and `b[b_col]` and stores the result in `res[res_col]`.
pub trait SvpApplyDftToDft<B: Backend> {
    fn svp_apply_dft_to_dft<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
    where
        R: VecZnxDftToMut<B>,
        A: SvpPPolToRef<B>,
        C: VecZnxDftToRef<B>;
}

/// Applies a scalar-vector product between `a[a_col]` and `b[b_col]` and adds the result to `res[res_col]`.
pub trait SvpApplyDftToDftAdd<B: Backend> {
    fn svp_apply_dft_to_dft_add<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
    where
        R: VecZnxDftToMut<B>,
        A: SvpPPolToRef<B>,
@@ -34,8 +54,8 @@ pub trait SvpApply<B: Backend> {
}

/// Applies a scalar-vector product between `res[res_col]` and `a[a_col]` and stores the result in `res[res_col]`.
pub trait SvpApplyInplace<B: Backend> {
    fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
pub trait SvpApplyDftToDftInplace<B: Backend> {
    fn svp_apply_dft_to_dft_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxDftToMut<B>,
        A: SvpPPolToRef<B>;
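For orientation: the former `SvpApply`/`SvpApplyInplace` traits are split into DFT-aware variants (`SvpApplyDft`, `SvpApplyDftToDft`, `SvpApplyDftToDftAdd`, `SvpApplyDftToDftInplace`). A minimal usage sketch, assuming the owned layout types implement the usual `*ToRef`/`*ToMut` conversion traits (the bench suite added below relies on the same assumption); paths follow the crate-internal `crate::{api, layouts}` modules:

use crate::{
    api::{SvpApplyDftToDft, SvpPPolAlloc, SvpPrepare, VecZnxDftAlloc},
    layouts::{Backend, Module, ScalarZnx, SvpPPol, VecZnxDft},
};

// Prepares a scalar polynomial and multiplies a DFT-domain vector by it,
// column by column (mirrors the calls used by the benchmarks below).
fn svp_example<B: Backend>(module: &Module<B>, scalar: &ScalarZnx<Vec<u8>>)
where
    Module<B>: SvpPrepare<B> + SvpPPolAlloc<B> + SvpApplyDftToDft<B> + VecZnxDftAlloc<B>,
{
    let (cols, size) = (2, 4);

    // Convert the scalar into the backend-specific prepared layout.
    let mut ppol: SvpPPol<Vec<u8>, B> = module.svp_ppol_alloc(cols);
    module.svp_prepare(&mut ppol, 0, scalar, 0);

    // res[col] = ppol[col] * a[col], entirely in the DFT domain.
    // (inputs left unfilled for brevity; see the benches for randomized data)
    let a: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
    let mut res: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
    for col in 0..cols {
        module.svp_apply_dft_to_dft(&mut res, col, &ppol, col, &a, col);
    }
}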
@@ -1,5 +1,3 @@
use rand_distr::Distribution;

use crate::{
    layouts::{Backend, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
    source::Source,
@@ -42,6 +40,16 @@ pub trait VecZnxAddInplace {
        A: VecZnxToRef;
}

pub trait VecZnxAddScalar {
    /// Adds the selected column of `a` to the selected column and limb of `b` and writes the result to the selected column of `res`.
    #[allow(clippy::too_many_arguments)]
    fn vec_znx_add_scalar<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize, b_limb: usize)
    where
        R: VecZnxToMut,
        A: ScalarZnxToRef,
        B: VecZnxToRef;
}

pub trait VecZnxAddScalarInplace {
    /// Adds the selected column of `a` to the selected column and limb of `res`.
    fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
@@ -79,6 +87,16 @@ pub trait VecZnxSubBAInplace {
        A: VecZnxToRef;
}

pub trait VecZnxSubScalar {
    /// Subtracts the selected column of `a` from the selected column and limb of `b` and writes the result to the selected column of `res`.
    #[allow(clippy::too_many_arguments)]
    fn vec_znx_sub_scalar<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize, b_limb: usize)
    where
        R: VecZnxToMut,
        A: ScalarZnxToRef,
        B: VecZnxToRef;
}

pub trait VecZnxSubScalarInplace {
    /// Subtracts the selected column of `a` from the selected column and limb of `res`.
    fn vec_znx_sub_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
@@ -102,31 +120,61 @@ pub trait VecZnxNegateInplace {
        A: VecZnxToMut;
}

pub trait VecZnxLshInplace {
pub trait VecZnxLshTmpBytes {
    fn vec_znx_lsh_tmp_bytes(&self) -> usize;
}

pub trait VecZnxLsh<B: Backend> {
    /// Left-shifts all columns of `a` by `k` bits.
    fn vec_znx_lsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
    #[allow(clippy::too_many_arguments)]
    fn vec_znx_lsh<R, A>(&self, basek: usize, k: usize, r: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
    where
        R: VecZnxToMut,
        A: VecZnxToRef;
}

pub trait VecZnxRshTmpBytes {
    fn vec_znx_rsh_tmp_bytes(&self) -> usize;
}

pub trait VecZnxRsh<B: Backend> {
    /// Right-shifts all columns of `a` by `k` bits.
    #[allow(clippy::too_many_arguments)]
    fn vec_znx_rsh<R, A>(&self, basek: usize, k: usize, r: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
    where
        R: VecZnxToMut,
        A: VecZnxToRef;
}

pub trait VecZnxLshInplace<B: Backend> {
    /// Left-shifts all columns of `a` by `k` bits.
    fn vec_znx_lsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
    where
        A: VecZnxToMut;
}

pub trait VecZnxRshInplace {
pub trait VecZnxRshInplace<B: Backend> {
    /// Right-shifts all columns of `a` by `k` bits.
    fn vec_znx_rsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
    fn vec_znx_rsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
    where
        A: VecZnxToMut;
}

pub trait VecZnxRotate {
    /// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`.
    fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
    fn vec_znx_rotate<R, A>(&self, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxToMut,
        A: VecZnxToRef;
}

pub trait VecZnxRotateInplace {
pub trait VecZnxRotateInplaceTmpBytes {
    fn vec_znx_rotate_inplace_tmp_bytes(&self) -> usize;
}

pub trait VecZnxRotateInplace<B: Backend> {
    /// Multiplies the selected column of `a` by X^k.
    fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
    fn vec_znx_rotate_inplace<A>(&self, p: i64, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
    where
        A: VecZnxToMut;
}
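The in-place shift/rotate operations now take a `Scratch<B>` argument and gain companion `*TmpBytes` traits. A sketch of the intended allocate-then-borrow pattern, assuming `ScratchOwned<B>` provides `ScratchOwnedAlloc`/`ScratchOwnedBorrow` and that `VecZnx<Vec<u8>>` implements `VecZnxToMut`, as the bench suite below relies on:

use crate::{
    api::{ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes},
    layouts::{Backend, Module, ScratchOwned, VecZnx},
};

fn rotate_in_place<B: Backend>(module: &Module<B>, a: &mut VecZnx<Vec<u8>>, cols: usize)
where
    Module<B>: VecZnxRotateInplace<B> + VecZnxRotateInplaceTmpBytes,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    // Query the backend for the temporary space needed, allocate it once,
    // and re-borrow it for every call.
    let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_rotate_inplace_tmp_bytes());
    for col in 0..cols {
        // Multiply column `col` of `a` by X^p (here p = 3).
        module.vec_znx_rotate_inplace(3, a, col, scratch.borrow());
    }
}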
@@ -139,54 +187,70 @@ pub trait VecZnxAutomorphism {
        A: VecZnxToRef;
}

pub trait VecZnxAutomorphismInplace {
pub trait VecZnxAutomorphismInplaceTmpBytes {
    fn vec_znx_automorphism_inplace_tmp_bytes(&self) -> usize;
}

pub trait VecZnxAutomorphismInplace<B: Backend> {
    /// Applies the automorphism X^i -> X^ik on the selected column of `a`.
    fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
    fn vec_znx_automorphism_inplace<R>(&self, k: i64, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
    where
        A: VecZnxToMut;
        R: VecZnxToMut;
}

pub trait VecZnxMulXpMinusOne {
    fn vec_znx_mul_xp_minus_one<R, A>(&self, p: i64, r: &mut R, r_col: usize, a: &A, a_col: usize)
    fn vec_znx_mul_xp_minus_one<R, A>(&self, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxToMut,
        A: VecZnxToRef;
}

pub trait VecZnxMulXpMinusOneInplace {
    fn vec_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, r: &mut R, r_col: usize)
pub trait VecZnxMulXpMinusOneInplaceTmpBytes {
    fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes(&self) -> usize;
}

pub trait VecZnxMulXpMinusOneInplace<B: Backend> {
    fn vec_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
    where
        R: VecZnxToMut;
}

pub trait VecZnxSplit<B: Backend> {
pub trait VecZnxSplitRingTmpBytes {
    fn vec_znx_split_ring_tmp_bytes(&self) -> usize;
}

pub trait VecZnxSplitRing<B: Backend> {
    /// Splits the selected columns of `b` into subrings and copies them into the selected column of `res`.
    ///
    /// # Panics
    ///
    /// This method requires that all [crate::layouts::VecZnx] of b have the same ring degree
    /// and that b.n() * b.len() <= a.n()
    fn vec_znx_split<R, A>(&self, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
    fn vec_znx_split_ring<R, A>(&self, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
    where
        R: VecZnxToMut,
        A: VecZnxToRef;
}

pub trait VecZnxMerge {
pub trait VecZnxMergeRingsTmpBytes {
    fn vec_znx_merge_rings_tmp_bytes(&self) -> usize;
}

pub trait VecZnxMergeRings<B: Backend> {
    /// Merges the subrings of the selected column of `a` into the selected column of `res`.
    ///
    /// # Panics
    ///
    /// This method requires that all [crate::layouts::VecZnx] of a have the same ring degree
    /// and that a.n() * a.len() <= b.n()
    fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: &[A], a_col: usize)
    fn vec_znx_merge_rings<R, A>(&self, res: &mut R, res_col: usize, a: &[A], a_col: usize, scratch: &mut Scratch<B>)
    where
        R: VecZnxToMut,
        A: VecZnxToRef;
}
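`VecZnxSplit`/`VecZnxMerge` are renamed to `VecZnxSplitRing`/`VecZnxMergeRings` and both now take scratch space. A round-trip sketch; the subring degree `module.n() / parts` and the shared scratch sizing are assumptions, not something this diff pins down:

use crate::{
    api::{
        ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxMergeRings, VecZnxMergeRingsTmpBytes, VecZnxSplitRing,
        VecZnxSplitRingTmpBytes,
    },
    layouts::{Backend, Module, ScratchOwned, VecZnx},
};

fn split_and_merge<B: Backend>(module: &Module<B>, a: &VecZnx<Vec<u8>>)
where
    Module<B>: VecZnxSplitRing<B> + VecZnxSplitRingTmpBytes + VecZnxMergeRings<B> + VecZnxMergeRingsTmpBytes,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    let (cols, size, parts) = (2, 4, 4);

    // One smaller-ring output per subring (all must share the same degree).
    let mut pieces: Vec<VecZnx<Vec<u8>>> = (0..parts)
        .map(|_| VecZnx::alloc(module.n() / parts, cols, size))
        .collect();

    let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(
        module.vec_znx_split_ring_tmp_bytes().max(module.vec_znx_merge_rings_tmp_bytes()),
    );

    // Split column 0 of `a` into the subrings, then merge them back together.
    module.vec_znx_split_ring(&mut pieces, 0, a, 0, scratch.borrow());

    let mut merged: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), cols, size);
    module.vec_znx_merge_rings(&mut merged, 0, &pieces, 0, scratch.borrow());
}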
pub trait VecZnxSwithcDegree {
    fn vec_znx_switch_degree<R, A>(&self, res: &mut R, res_col: usize, a: &A, col_a: usize)
pub trait VecZnxSwitchRing {
    fn vec_znx_switch_ring<R, A>(&self, res: &mut R, res_col: usize, a: &A, col_a: usize)
    where
        R: VecZnxToMut,
        A: VecZnxToRef;
@@ -201,42 +265,11 @@ pub trait VecZnxCopy {

pub trait VecZnxFillUniform {
    /// Fills the first `size` limbs with uniform values in \[-2^{basek-1}, 2^{basek-1}\].
    fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
    fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
    where
        R: VecZnxToMut;
}
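`vec_znx_fill_uniform` loses its `k` argument. A small sketch of the new call, assuming (as the doc comment suggests) that every limb of the selected column is filled:

use crate::{
    api::VecZnxFillUniform,
    layouts::{Backend, Module, VecZnx},
    source::Source,
};

fn fill_uniform_example<B: Backend>(module: &Module<B>)
where
    Module<B>: VecZnxFillUniform,
{
    let basek: usize = 50;
    let mut source = Source::new([0u8; 32]);
    let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 2, 4);

    // New signature: no trailing `k` argument.
    module.vec_znx_fill_uniform(basek, &mut res, 0, &mut source);
}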
#[allow(clippy::too_many_arguments)]
pub trait VecZnxFillDistF64 {
    fn vec_znx_fill_dist_f64<R, D: Distribution<f64>>(
        &self,
        basek: usize,
        res: &mut R,
        res_col: usize,
        k: usize,
        source: &mut Source,
        dist: D,
        bound: f64,
    ) where
        R: VecZnxToMut;
}

#[allow(clippy::too_many_arguments)]
pub trait VecZnxAddDistF64 {
    /// Adds a vector sampled according to the provided distribution, scaled by 2^{-k} and bounded to \[-bound, bound\].
    fn vec_znx_add_dist_f64<R, D: Distribution<f64>>(
        &self,
        basek: usize,
        res: &mut R,
        res_col: usize,
        k: usize,
        source: &mut Source,
        dist: D,
        bound: f64,
    ) where
        R: VecZnxToMut;
}

#[allow(clippy::too_many_arguments)]
pub trait VecZnxFillNormal {
    fn vec_znx_fill_normal<R>(
@@ -1,10 +1,15 @@
use rand_distr::Distribution;

use crate::{
    layouts::{Backend, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef},
    source::Source,
};

pub trait VecZnxBigFromSmall<B: Backend> {
    fn vec_znx_big_from_small<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxBigToMut<B>,
        A: VecZnxToRef;
}

/// Allocates as [crate::layouts::VecZnxBig].
pub trait VecZnxBigAlloc<B: Backend> {
    fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned<B>;
@@ -45,48 +50,6 @@ pub trait VecZnxBigAddNormal<B: Backend> {
    );
}

#[allow(clippy::too_many_arguments)]
pub trait VecZnxBigFillNormal<B: Backend> {
    fn vec_znx_big_fill_normal<R: VecZnxBigToMut<B>>(
        &self,
        basek: usize,
        res: &mut R,
        res_col: usize,
        k: usize,
        source: &mut Source,
        sigma: f64,
        bound: f64,
    );
}

#[allow(clippy::too_many_arguments)]
pub trait VecZnxBigFillDistF64<B: Backend> {
    fn vec_znx_big_fill_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
        &self,
        basek: usize,
        res: &mut R,
        res_col: usize,
        k: usize,
        source: &mut Source,
        dist: D,
        bound: f64,
    );
}

#[allow(clippy::too_many_arguments)]
pub trait VecZnxBigAddDistF64<B: Backend> {
    fn vec_znx_big_add_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
        &self,
        basek: usize,
        res: &mut R,
        res_col: usize,
        k: usize,
        source: &mut Source,
        dist: D,
        bound: f64,
    );
}

pub trait VecZnxBigAdd<B: Backend> {
    /// Adds `a` to `b` and stores the result in `c`.
    fn vec_znx_big_add<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
@@ -180,10 +143,17 @@ pub trait VecZnxBigSubSmallBInplace<B: Backend> {
        A: VecZnxToRef;
}

pub trait VecZnxBigNegateInplace<B: Backend> {
    fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize)
pub trait VecZnxBigNegate<B: Backend> {
    fn vec_znx_big_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        A: VecZnxBigToMut<B>;
        R: VecZnxBigToMut<B>,
        A: VecZnxBigToRef<B>;
}

pub trait VecZnxBigNegateInplace<B: Backend> {
    fn vec_znx_big_negate_inplace<R>(&self, res: &mut R, res_col: usize)
    where
        R: VecZnxBigToMut<B>;
}

pub trait VecZnxBigNormalizeTmpBytes {
@@ -204,9 +174,13 @@ pub trait VecZnxBigNormalize<B: Backend> {
        A: VecZnxBigToRef<B>;
}

pub trait VecZnxBigAutomorphismInplaceTmpBytes {
    fn vec_znx_big_automorphism_inplace_tmp_bytes(&self) -> usize;
}

pub trait VecZnxBigAutomorphism<B: Backend> {
    /// Applies the automorphism X^i -> X^ik on `a` and stores the result in `b`.
    fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
    fn vec_znx_big_automorphism<R, A>(&self, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxBigToMut<B>,
        A: VecZnxBigToRef<B>;
@@ -214,7 +188,7 @@ pub trait VecZnxBigAutomorphism<B: Backend> {

pub trait VecZnxBigAutomorphismInplace<B: Backend> {
    /// Applies the automorphism X^i -> X^ik on `a` and stores the result in `a`.
    fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
    fn vec_znx_big_automorphism_inplace<R>(&self, p: i64, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
    where
        A: VecZnxBigToMut<B>;
        R: VecZnxBigToMut<B>;
}
@@ -14,33 +14,33 @@ pub trait VecZnxDftAllocBytes {
    fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize;
}

pub trait DFT<B: Backend> {
    fn dft<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
pub trait VecZnxDftApply<B: Backend> {
    fn vec_znx_dft_apply<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxDftToMut<B>,
        A: VecZnxToRef;
}

pub trait VecZnxIDFTTmpBytes {
    fn vec_znx_idft_tmp_bytes(&self) -> usize;
pub trait VecZnxIdftApplyTmpBytes {
    fn vec_znx_idft_apply_tmp_bytes(&self) -> usize;
}

pub trait IDFT<B: Backend> {
    fn idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
pub trait VecZnxIdftApply<B: Backend> {
    fn vec_znx_idft_apply<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
    where
        R: VecZnxBigToMut<B>,
        A: VecZnxDftToRef<B>;
}

pub trait IDFTTmpA<B: Backend> {
    fn idft_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
pub trait VecZnxIdftApplyTmpA<B: Backend> {
    fn vec_znx_idft_apply_tmpa<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
    where
        R: VecZnxBigToMut<B>,
        A: VecZnxDftToMut<B>;
}

pub trait IDFTConsume<B: Backend> {
    fn vec_znx_idft_consume<D: Data>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
pub trait VecZnxIdftApplyConsume<B: Backend> {
    fn vec_znx_idft_apply_consume<D: Data>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
    where
        VecZnxDft<D, B>: VecZnxDftToMut<B>;
}
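The DFT entry points are renamed from `DFT`/`IDFT`/`IDFTTmpA`/`IDFTConsume` to `VecZnxDftApply`/`VecZnxIdftApply`/`VecZnxIdftApplyTmpA`/`VecZnxIdftApplyConsume`. A forward/backward sketch; `step = 1, offset = 0` follow the benchmarks below, and the trait-bound wiring (owned layouts implementing the `*ToMut` traits) is the same assumption the bench suite makes:

use crate::{
    api::{VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyConsume},
    layouts::{Backend, Module, VecZnx, VecZnxBig, VecZnxDft},
};

fn dft_roundtrip<B: Backend>(module: &Module<B>, a: &VecZnx<Vec<u8>>) -> VecZnxBig<Vec<u8>, B>
where
    Module<B>: VecZnxDftAlloc<B> + VecZnxDftApply<B> + VecZnxIdftApplyConsume<B>,
{
    let (cols, size) = (2, 4);
    let mut a_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);

    // Forward DFT, column by column (step = 1, offset = 0).
    for col in 0..cols {
        module.vec_znx_dft_apply(1, 0, &mut a_dft, col, a, col);
    }

    // Inverse DFT that consumes the DFT buffer and reuses its storage as a VecZnxBig.
    module.vec_znx_idft_apply_consume(a_dft)
}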
@@ -1,4 +1,6 @@
use crate::layouts::{Backend, MatZnxToRef, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef};
use crate::layouts::{
    Backend, MatZnxToRef, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef,
};

pub trait VmpPMatAlloc<B: Backend> {
    fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B>;
@@ -17,12 +19,33 @@ pub trait VmpPrepareTmpBytes {
}

pub trait VmpPrepare<B: Backend> {
    fn vmp_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<B>)
    fn vmp_prepare<R, A>(&self, pmat: &mut R, mat: &A, scratch: &mut Scratch<B>)
    where
        R: VmpPMatToMut<B>,
        A: MatZnxToRef;
}

#[allow(clippy::too_many_arguments)]
pub trait VmpApplyDftTmpBytes {
    fn vmp_apply_dft_tmp_bytes(
        &self,
        res_size: usize,
        a_size: usize,
        b_rows: usize,
        b_cols_in: usize,
        b_cols_out: usize,
        b_size: usize,
    ) -> usize;
}

pub trait VmpApplyDft<B: Backend> {
    fn vmp_apply_dft<R, A, C>(&self, res: &mut R, a: &A, pmat: &C, scratch: &mut Scratch<B>)
    where
        R: VecZnxDftToMut<B>,
        A: VecZnxToRef,
        C: VmpPMatToRef<B>;
}
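`VmpApplyDft` is the new entry point that takes the left operand directly as a small-coefficient `VecZnx` (previously only DFT-domain inputs went through `VmpApplyDftToDft`). A prepare-then-apply sketch; the dimensions are arbitrary and the caller-provided `scratch` is assumed to be sized via the corresponding `*TmpBytes` traits:

use crate::{
    api::{VecZnxDftAlloc, VmpApplyDft, VmpPMatAlloc, VmpPrepare},
    layouts::{Backend, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxToRef, VmpPMatOwned},
};

fn vmp_example<B: Backend, M, A>(
    module: &Module<B>,
    mat: &M,
    a: &A,
    scratch: &mut Scratch<B>,
) -> VecZnxDft<Vec<u8>, B>
where
    Module<B>: VmpPMatAlloc<B> + VmpPrepare<B> + VmpApplyDft<B> + VecZnxDftAlloc<B>,
    M: MatZnxToRef,
    A: VecZnxToRef,
{
    let (rows, cols_in, cols_out, size) = (4, 1, 2, 4);

    // Prepare the matrix into the backend-specific VmpPMat layout.
    let mut pmat: VmpPMatOwned<B> = module.vmp_pmat_alloc(rows, cols_in, cols_out, size);
    module.vmp_prepare(&mut pmat, mat, scratch);

    // Vector-matrix product taking the small-coefficient vector `a` straight into the DFT domain.
    let mut res: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols_out, size);
    module.vmp_apply_dft(&mut res, a, &pmat, scratch);
    res
}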
#[allow(clippy::too_many_arguments)]
pub trait VmpApplyDftToDftTmpBytes {
    fn vmp_apply_dft_to_dft_tmp_bytes(
@@ -61,7 +84,7 @@ pub trait VmpApplyDftToDft<B: Backend> {
    /// * `a`: the left operand [crate::layouts::VecZnxDft] of the vector matrix product.
    /// * `b`: the right operand [crate::layouts::VmpPMat] of the vector matrix product.
    /// * `buf`: scratch space, the size of which can be obtained with [VmpApplyTmpBytes::vmp_apply_tmp_bytes].
    fn vmp_apply_dft_to_dft<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
    fn vmp_apply_dft_to_dft<R, A, C>(&self, res: &mut R, a: &A, pmat: &C, scratch: &mut Scratch<B>)
    where
        R: VecZnxDftToMut<B>,
        A: VecZnxDftToRef<B>,
@@ -82,7 +105,7 @@ pub trait VmpApplyDftToDftAddTmpBytes {
}

pub trait VmpApplyDftToDftAdd<B: Backend> {
    fn vmp_apply_dft_to_dft_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
    fn vmp_apply_dft_to_dft_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, limb_offset: usize, scratch: &mut Scratch<B>)
    where
        R: VecZnxDftToMut<B>,
        A: VecZnxDftToRef<B>,
@@ -1,57 +1,29 @@
use rand_distr::Distribution;

use crate::{
    layouts::{Backend, Scratch, ZnToMut},
    reference::zn::zn_normalize_tmp_bytes,
    source::Source,
};

pub trait ZnNormalizeTmpBytes {
    fn zn_normalize_tmp_bytes(&self, n: usize) -> usize {
        zn_normalize_tmp_bytes(n)
    }
}

pub trait ZnNormalizeInplace<B: Backend> {
    /// Normalizes the selected column of `a`.
    fn zn_normalize_inplace<A>(&self, n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
    fn zn_normalize_inplace<R>(&self, n: usize, basek: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
    where
        A: ZnToMut;
        R: ZnToMut;
}
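`ZnNormalizeTmpBytes` now ships a default implementation that simply delegates to the reference `zn_normalize_tmp_bytes`. A sketch of sizing scratch with it and normalizing one column in place (the `ScratchOwned` bounds mirror the bench suite; the call shape is otherwise an assumption):

use crate::{
    api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace, ZnNormalizeTmpBytes},
    layouts::{Backend, Module, ScratchOwned, ZnToMut},
};

fn normalize_column<B: Backend, R: ZnToMut>(module: &Module<B>, n: usize, basek: usize, res: &mut R)
where
    Module<B>: ZnNormalizeInplace<B> + ZnNormalizeTmpBytes,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    // The default impl returns the reference backend's requirement for degree n.
    let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.zn_normalize_tmp_bytes(n));
    module.zn_normalize_inplace(n, basek, res, 0, scratch.borrow());
}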
pub trait ZnFillUniform {
    /// Fills the first `size` limbs with uniform values in \[-2^{basek-1}, 2^{basek-1}\].
    fn zn_fill_uniform<R>(&self, n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
    fn zn_fill_uniform<R>(&self, n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
    where
        R: ZnToMut;
}

#[allow(clippy::too_many_arguments)]
pub trait ZnFillDistF64 {
    fn zn_fill_dist_f64<R, D: Distribution<f64>>(
        &self,
        n: usize,
        basek: usize,
        res: &mut R,
        res_col: usize,
        k: usize,
        source: &mut Source,
        dist: D,
        bound: f64,
    ) where
        R: ZnToMut;
}

#[allow(clippy::too_many_arguments)]
pub trait ZnAddDistF64 {
    /// Adds a vector sampled according to the provided distribution, scaled by 2^{-k} and bounded to \[-bound, bound\].
    fn zn_add_dist_f64<R, D: Distribution<f64>>(
        &self,
        n: usize,
        basek: usize,
        res: &mut R,
        res_col: usize,
        k: usize,
        source: &mut Source,
        dist: D,
        bound: f64,
    ) where
        R: ZnToMut;
}

#[allow(clippy::too_many_arguments)]
pub trait ZnFillNormal {
    fn zn_fill_normal<R>(
poulpy-hal/src/bench_suite/mod.rs (new file, 5 lines)
@@ -0,0 +1,5 @@
pub mod svp;
pub mod vec_znx;
pub mod vec_znx_big;
pub mod vec_znx_dft;
pub mod vmp;
poulpy-hal/src/bench_suite/svp.rs (new file, 237 lines)
@@ -0,0 +1,237 @@
use std::hint::black_box;
|
||||
|
||||
use criterion::{BenchmarkId, Criterion};
|
||||
use rand::RngCore;
|
||||
|
||||
use crate::{
|
||||
api::{
|
||||
ModuleNew, SvpApplyDft, SvpApplyDftToDft, SvpApplyDftToDftAdd, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPrepare,
|
||||
VecZnxDftAlloc,
|
||||
},
|
||||
layouts::{Backend, DataViewMut, FillUniform, Module, ScalarZnx, SvpPPol, VecZnx, VecZnxDft},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn bench_svp_prepare<B>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: SvpPrepare<B> + SvpPPolAlloc<B> + ModuleNew<B>,
|
||||
B: Backend<ScalarPrep = f64>,
|
||||
{
|
||||
let group_name: String = format!("svp_prepare::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B>(log_n: usize) -> impl FnMut()
|
||||
where
|
||||
Module<B>: SvpPrepare<B> + SvpPPolAlloc<B> + ModuleNew<B>,
|
||||
B: Backend<ScalarPrep = f64>,
|
||||
{
|
||||
let module: Module<B> = Module::<B>::new(1 << log_n);
|
||||
|
||||
let cols: usize = 2;
|
||||
|
||||
let mut svp: SvpPPol<Vec<u8>, B> = module.svp_ppol_alloc(cols);
|
||||
let mut a: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(module.n(), cols);
|
||||
let mut source = Source::new([0u8; 32]);
|
||||
a.fill_uniform(50, &mut source);
|
||||
|
||||
move || {
|
||||
module.svp_prepare(&mut svp, 0, &a, 0);
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for log_n in [10, 11, 12, 13, 14] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}", 1 << log_n));
|
||||
let mut runner = runner::<B>(log_n);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
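The suite functions are meant to be driven from a backend crate's Criterion benches. A hypothetical wiring sketch; `my_backend::MyBackend` and the bench names are placeholders (not part of this commit), and it assumes `poulpy_hal` exposes `bench_suite` publicly:

use criterion::{criterion_group, criterion_main, Criterion};
use poulpy_hal::bench_suite::svp::{bench_svp_apply_dft_to_dft, bench_svp_prepare};

// MyBackend: some concrete Backend<ScalarPrep = f64> implementation (placeholder).
use my_backend::MyBackend;

fn svp_benches(c: &mut Criterion) {
    bench_svp_prepare::<MyBackend>(c, "my_backend");
    bench_svp_apply_dft_to_dft::<MyBackend>(c, "my_backend");
}

criterion_group!(benches, svp_benches);
criterion_main!(benches);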
|
||||
|
||||
pub fn bench_svp_apply_dft<B>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: SvpApplyDft<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
|
||||
B: Backend<ScalarPrep = f64>,
|
||||
{
|
||||
let group_name: String = format!("svp_apply_dft::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: SvpApplyDft<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
|
||||
B: Backend<ScalarPrep = f64>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut svp: SvpPPol<Vec<u8>, B> = module.svp_ppol_alloc(cols);
|
||||
let mut res: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
let mut source = Source::new([0u8; 32]);
|
||||
|
||||
source.fill_bytes(svp.data_mut());
|
||||
source.fill_bytes(res.data_mut());
|
||||
source.fill_bytes(a.data_mut());
|
||||
|
||||
move || {
|
||||
for j in 0..cols {
|
||||
module.svp_apply_dft(&mut res, j, &svp, j, &a, j);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 7], [13, 2, 15], [14, 2, 31]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_svp_apply_dft_to_dft<B>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: SvpApplyDftToDft<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
|
||||
B: Backend<ScalarPrep = f64>,
|
||||
{
|
||||
let group_name: String = format!("svp_apply_dft_to_dft::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: SvpApplyDftToDft<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
|
||||
B: Backend<ScalarPrep = f64>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut svp: SvpPPol<Vec<u8>, B> = module.svp_ppol_alloc(cols);
|
||||
let mut res: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
|
||||
let mut a: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
|
||||
|
||||
let mut source = Source::new([0u8; 32]);
|
||||
|
||||
source.fill_bytes(svp.data_mut());
|
||||
source.fill_bytes(res.data_mut());
|
||||
source.fill_bytes(a.data_mut());
|
||||
|
||||
move || {
|
||||
for j in 0..cols {
|
||||
module.svp_apply_dft_to_dft(&mut res, j, &svp, j, &a, j);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 7], [13, 2, 15], [14, 2, 31]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_svp_apply_dft_to_dft_add<B>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: SvpApplyDftToDftAdd<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
|
||||
B: Backend<ScalarPrep = f64>,
|
||||
{
|
||||
let group_name: String = format!("svp_apply_dft_to_dft_add::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: SvpApplyDftToDftAdd<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
|
||||
B: Backend<ScalarPrep = f64>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut svp: SvpPPol<Vec<u8>, B> = module.svp_ppol_alloc(cols);
|
||||
let mut res: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
|
||||
let mut a: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
|
||||
|
||||
let mut source = Source::new([0u8; 32]);
|
||||
|
||||
source.fill_bytes(svp.data_mut());
|
||||
source.fill_bytes(res.data_mut());
|
||||
source.fill_bytes(a.data_mut());
|
||||
|
||||
move || {
|
||||
for j in 0..cols {
|
||||
module.svp_apply_dft_to_dft_add(&mut res, j, &svp, j, &a, j);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 7], [13, 2, 15], [14, 2, 31]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_svp_apply_dft_to_dft_inplace<B>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: SvpApplyDftToDftInplace<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
|
||||
B: Backend<ScalarPrep = f64>,
|
||||
{
|
||||
let group_name: String = format!("svp_apply_dft_to_dft_inplace::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: SvpApplyDftToDftInplace<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
|
||||
B: Backend<ScalarPrep = f64>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut svp: SvpPPol<Vec<u8>, B> = module.svp_ppol_alloc(cols);
|
||||
let mut res: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
|
||||
|
||||
let mut source = Source::new([0u8; 32]);
|
||||
|
||||
source.fill_bytes(svp.data_mut());
|
||||
source.fill_bytes(res.data_mut());
|
||||
|
||||
move || {
|
||||
for j in 0..cols {
|
||||
module.svp_apply_dft_to_dft_inplace(&mut res, j, &svp, j);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 7], [13, 2, 15], [14, 2, 31]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
poulpy-hal/src/bench_suite/vec_znx.rs (new file, 1 line)
@@ -0,0 +1 @@
poulpy-hal/src/bench_suite/vec_znx_big.rs (new file, 641 lines)
@@ -0,0 +1,641 @@
use std::hint::black_box;
|
||||
|
||||
use criterion::{BenchmarkId, Criterion};
|
||||
use rand::RngCore;
|
||||
|
||||
use crate::{
|
||||
api::{
|
||||
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAdd, VecZnxBigAddInplace, VecZnxBigAddSmall,
|
||||
VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace,
|
||||
VecZnxBigAutomorphismInplaceTmpBytes, VecZnxBigNegate, VecZnxBigNegateInplace, VecZnxBigNormalize,
|
||||
VecZnxBigNormalizeTmpBytes, VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA,
|
||||
VecZnxBigSubSmallB,
|
||||
},
|
||||
layouts::{Backend, DataViewMut, Module, ScratchOwned, VecZnx, VecZnxBig},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn bench_vec_znx_big_add<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxBigAdd<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_big_add::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxBigAdd<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0]; // params[0] is log2(n), matching the 1 << params[0] used in the benchmark label
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
let mut b: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
let mut c: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
|
||||
// Fill a with random bytes
|
||||
source.fill_bytes(a.data_mut());
|
||||
source.fill_bytes(b.data_mut());
|
||||
source.fill_bytes(c.data_mut());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_big_add(&mut c, i, &a, i, &b, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_big_add_inplace<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxBigAddInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_big_add_inplace::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxBigAddInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0]; // params[0] is log2(n), matching the 1 << params[0] used in the benchmark label
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
let mut c: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
|
||||
// Fill a with random i64
|
||||
source.fill_bytes(a.data_mut());
|
||||
source.fill_bytes(c.data_mut());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_big_add_inplace(&mut c, i, &a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_big_add_small<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxBigAddSmall<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_big_add_small::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxBigAddSmall<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0]; // params[0] is log2(n), matching the 1 << params[0] used in the benchmark label
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut c: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
|
||||
// Fill a with random i64
|
||||
source.fill_bytes(a.data_mut());
|
||||
source.fill_bytes(b.data_mut());
|
||||
source.fill_bytes(c.data_mut());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_big_add_small(&mut c, i, &a, i, &b, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_big_add_small_inplace<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxBigAddSmallInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_big_add_small_inplace::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxBigAddSmallInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0]; // params[0] is log2(n), matching the 1 << params[0] used in the benchmark label
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut c: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
|
||||
// Fill a with random i64
|
||||
source.fill_bytes(a.data_mut());
|
||||
source.fill_bytes(c.data_mut());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_big_add_small_inplace(&mut c, i, &a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_big_automorphism<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxBigAutomorphism<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_big_automorphism::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxBigAutomorphism<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0]; // params[0] is log2(n), matching the 1 << params[0] used in the benchmark label
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
let mut res: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
|
||||
// Fill a with random i64
|
||||
source.fill_bytes(a.data_mut());
|
||||
source.fill_bytes(res.data_mut());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_big_automorphism(-7, &mut res, i, &a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_automorphism_inplace<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxBigAutomorphismInplace<B> + VecZnxBigAutomorphismInplaceTmpBytes + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_automorphism_inplace::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxBigAutomorphismInplace<B> + ModuleNew<B> + VecZnxBigAutomorphismInplaceTmpBytes + VecZnxBigAlloc<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0]; // params[0] is log2(n), matching the 1 << params[0] used in the benchmark label
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut res: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_big_automorphism_inplace_tmp_bytes());
|
||||
|
||||
// Fill a with random i64
|
||||
source.fill_bytes(res.data_mut());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_big_automorphism_inplace(-7, &mut res, i, scratch.borrow());
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_big_negate<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxBigNegate<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_big_negate::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxBigNegate<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let module: Module<B> = Module::<B>::new(1 << params[0]);
|
||||
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
let mut b: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
|
||||
// Fill a with random i64
|
||||
source.fill_bytes(a.data_mut());
|
||||
source.fill_bytes(b.data_mut());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_big_negate(&mut b, i, &a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_big_negate_inplace<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxBigNegateInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_negate_big_inplace::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxBigNegateInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let module: Module<B> = Module::<B>::new(1 << params[0]);
|
||||
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
|
||||
// Fill a with random i64
|
||||
source.fill_bytes(a.data_mut());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_big_negate_inplace(&mut a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_normalize<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxBigNormalize<B> + ModuleNew<B> + VecZnxBigNormalizeTmpBytes + VecZnxBigAlloc<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_big_normalize::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxBigNormalize<B> + ModuleNew<B> + VecZnxBigNormalizeTmpBytes + VecZnxBigAlloc<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0]; // params[0] is log2(n), matching the 1 << params[0] used in the benchmark label
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let basek: usize = 50;
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), cols, size);
|
||||
|
||||
// Fill a with random i64
|
||||
source.fill_bytes(a.data_mut());
|
||||
source.fill_bytes(res.data_mut());
|
||||
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_big_normalize_tmp_bytes());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_big_normalize(basek, &mut res, i, &a, i, scratch.borrow());
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_big_sub<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxBigSub<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_big_sub::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxBigSub<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let module: Module<B> = Module::<B>::new(1 << params[0]);
|
||||
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
let mut b: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
let mut c: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
|
||||
// Fill a with random bytes
|
||||
source.fill_bytes(a.data_mut());
|
||||
source.fill_bytes(b.data_mut());
|
||||
source.fill_bytes(c.data_mut());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_big_sub(&mut c, i, &a, i, &b, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_big_sub_ab_inplace<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxBigSubABInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_big_sub_inplace::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxBigSubABInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let module: Module<B> = Module::<B>::new(1 << params[0]);
|
||||
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
let mut c: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
|
||||
// Fill a with random bytes
|
||||
source.fill_bytes(a.data_mut());
|
||||
source.fill_bytes(c.data_mut());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_big_sub_ab_inplace(&mut c, i, &a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_big_sub_ba_inplace<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxBigSubBAInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_big_sub_inplace::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxBigSubBAInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let module: Module<B> = Module::<B>::new(1 << params[0]);
|
||||
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
let mut c: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
|
||||
// Fill a with random bytes
|
||||
source.fill_bytes(a.data_mut());
|
||||
source.fill_bytes(c.data_mut());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_big_sub_ba_inplace(&mut c, i, &a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_big_sub_small_a<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxBigSubSmallA<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_big_sub_small_a::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxBigSubSmallA<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let module: Module<B> = Module::<B>::new(1 << params[0]);
|
||||
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), cols, size);
|
||||
let mut b: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
let mut c: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
|
||||
// Fill a with random bytes
|
||||
source.fill_bytes(a.data_mut());
|
||||
source.fill_bytes(b.data_mut());
|
||||
source.fill_bytes(c.data_mut());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_big_sub_small_a(&mut c, i, &a, i, &b, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_big_sub_small_b<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxBigSubSmallB<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_big_sub_small_b::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxBigSubSmallB<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let module: Module<B> = Module::<B>::new(1 << params[0]);
|
||||
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), cols, size);
|
||||
let mut c: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
|
||||
// Fill a with random bytes
|
||||
source.fill_bytes(a.data_mut());
|
||||
source.fill_bytes(b.data_mut());
|
||||
source.fill_bytes(c.data_mut());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_big_sub_small_b(&mut c, i, &a, i, &b, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
poulpy-hal/src/bench_suite/vec_znx_dft.rs (new file, 365 lines)
@@ -0,0 +1,365 @@
use std::hint::black_box;
|
||||
|
||||
use criterion::{BenchmarkId, Criterion};
|
||||
use rand::RngCore;
|
||||
|
||||
use crate::{
|
||||
api::{
|
||||
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAlloc, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc,
|
||||
VecZnxDftApply, VecZnxDftSub, VecZnxDftSubABInplace, VecZnxDftSubBAInplace, VecZnxIdftApply, VecZnxIdftApplyTmpA,
|
||||
VecZnxIdftApplyTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataViewMut, Module, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn bench_vec_znx_dft_add<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxDftAdd<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_dft_add::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxDftAdd<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0]; // params[0] is log2(n), matching the 1 << params[0] used in the benchmark label
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
|
||||
let mut b: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
|
||||
let mut c: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
|
||||
|
||||
source.fill_bytes(a.data_mut());
|
||||
source.fill_bytes(b.data_mut());
|
||||
source.fill_bytes(c.data_mut());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_dft_add(&mut c, i, &a, i, &b, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_dft_add_inplace<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxDftAddInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_dft_add_inplace::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxDftAddInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0]; // params[0] is log2(n), matching the benchmark label
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
|
||||
let mut c: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
|
||||
|
||||
// Fill a and c with random bytes
|
||||
source.fill_bytes(a.data_mut());
|
||||
source.fill_bytes(c.data_mut());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_dft_add_inplace(&mut c, i, &a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_dft_apply<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxDftApply<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_dft_apply::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxDftApply<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0]; // params[0] is log2(n), matching the benchmark label
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut res: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
source.fill_bytes(res.data_mut());
|
||||
source.fill_bytes(a.data_mut());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_dft_apply(1, 0, &mut res, i, &a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_idft_apply<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxIdftApply<B> + ModuleNew<B> + VecZnxIdftApplyTmpBytes + VecZnxDftAlloc<B> + VecZnxBigAlloc<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_idft_apply::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxIdftApply<B> + ModuleNew<B> + VecZnxIdftApplyTmpBytes + VecZnxDftAlloc<B> + VecZnxBigAlloc<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0]; // params[0] is log2(n), matching the benchmark label
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut res: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
let mut a: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
|
||||
source.fill_bytes(res.data_mut());
|
||||
source.fill_bytes(a.data_mut());
|
||||
|
||||
let mut scratch = ScratchOwned::alloc(module.vec_znx_idft_apply_tmp_bytes());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_idft_apply(&mut res, i, &a, i, scratch.borrow());
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_idft_apply_tmpa<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxIdftApplyTmpA<B> + ModuleNew<B> + VecZnxDftAlloc<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_idft_apply_tmpa::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxIdftApplyTmpA<B> + ModuleNew<B> + VecZnxDftAlloc<B> + VecZnxBigAlloc<B>,
|
||||
{
|
||||
let module: Module<B> = Module::<B>::new(1 << params[0]);
|
||||
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut res: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);
|
||||
let mut a: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
|
||||
source.fill_bytes(res.data_mut());
|
||||
source.fill_bytes(a.data_mut());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_idft_apply_tmpa(&mut res, i, &mut a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_dft_sub<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxDftSub<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_dft_sub::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxDftSub<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0]; // params[0] is log2(n), matching the benchmark label
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
|
||||
let mut b: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
|
||||
let mut c: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
|
||||
|
||||
source.fill_bytes(a.data_mut());
|
||||
source.fill_bytes(b.data_mut());
|
||||
source.fill_bytes(c.data_mut());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_dft_sub(&mut c, i, &a, i, &b, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_dft_sub_ab_inplace<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxDftSubABInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_dft_sub_ab_inplace::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxDftSubABInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0]; // params[0] is log2(n), matching the benchmark label
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
|
||||
let mut c: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
|
||||
|
||||
// Fill a and c with random bytes
|
||||
source.fill_bytes(a.data_mut());
|
||||
source.fill_bytes(c.data_mut());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_dft_sub_ab_inplace(&mut c, i, &a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_dft_sub_ba_inplace<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxDftSubBAInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_dft_sub_ba_inplace::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxDftSubBAInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0]; // params[0] is log2(n), matching the benchmark label
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
|
||||
let mut c: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
|
||||
|
||||
// Fill a and c with random bytes
|
||||
source.fill_bytes(a.data_mut());
|
||||
source.fill_bytes(c.data_mut());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_dft_sub_ba_inplace(&mut c, i, &a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
259
poulpy-hal/src/bench_suite/vmp.rs
Normal file
@@ -0,0 +1,259 @@
|
||||
use std::hint::black_box;
|
||||
|
||||
use criterion::{BenchmarkId, Criterion};
|
||||
use rand::RngCore;
|
||||
|
||||
use crate::{
|
||||
api::{
|
||||
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxDftAlloc, VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft,
|
||||
VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, VmpPrepareTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataViewMut, MatZnx, Module, ScratchOwned, VecZnx, VecZnxDft, VmpPMat},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn bench_vmp_prepare<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: ModuleNew<B> + VmpPMatAlloc<B> + VmpPrepare<B> + VmpPrepareTmpBytes,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vmp_prepare::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 5]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: ModuleNew<B> + VmpPMatAlloc<B> + VmpPrepare<B> + VmpPrepareTmpBytes,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let module: Module<B> = Module::<B>::new(1 << params[0]);
|
||||
|
||||
let rows: usize = params[1];
|
||||
let cols_in: usize = params[2];
|
||||
let cols_out: usize = params[3];
|
||||
let size: usize = params[4];
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vmp_prepare_tmp_bytes(rows, cols_in, cols_out, size));
|
||||
|
||||
let mut mat: MatZnx<Vec<u8>> = MatZnx::alloc(module.n(), rows, cols_in, cols_out, size);
|
||||
let mut pmat: VmpPMat<Vec<u8>, B> = module.vmp_pmat_alloc(rows, cols_in, cols_out, size);
|
||||
|
||||
source.fill_bytes(mat.data_mut());
|
||||
source.fill_bytes(pmat.data_mut());
|
||||
|
||||
move || {
|
||||
module.vmp_prepare(&mut pmat, &mat, scratch.borrow());
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [
|
||||
[10, 2, 1, 2, 3],
|
||||
[11, 4, 1, 2, 5],
|
||||
[12, 7, 1, 2, 8],
|
||||
[13, 15, 1, 2, 16],
|
||||
[14, 31, 1, 2, 32],
|
||||
] {
|
||||
let id = BenchmarkId::from_parameter(format!(
|
||||
"{}x({}x{})x({}x{})",
|
||||
1 << params[0],
|
||||
params[2],
|
||||
params[1],
|
||||
params[3],
|
||||
params[4]
|
||||
));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vmp_apply_dft<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: ModuleNew<B> + VmpApplyDftTmpBytes + VmpApplyDft<B> + VmpPMatAlloc<B> + VecZnxDftAlloc<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vmp_apply_dft::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 5]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: ModuleNew<B> + VmpApplyDftTmpBytes + VmpApplyDft<B> + VmpPMatAlloc<B> + VecZnxDftAlloc<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let module: Module<B> = Module::<B>::new(1 << params[0]);
|
||||
|
||||
let rows: usize = params[1];
|
||||
let cols_in: usize = params[2];
|
||||
let cols_out: usize = params[3];
|
||||
let size: usize = params[4];
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(1 << 20);
|
||||
|
||||
let mut res: VecZnxDft<Vec<u8>, _> = module.vec_znx_dft_alloc(cols_out, size);
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), cols_in, size);
|
||||
let mut pmat: VmpPMat<Vec<u8>, B> = module.vmp_pmat_alloc(rows, cols_in, cols_out, size);
|
||||
|
||||
source.fill_bytes(pmat.data_mut());
|
||||
source.fill_bytes(res.data_mut());
|
||||
source.fill_bytes(a.data_mut());
|
||||
|
||||
move || {
|
||||
module.vmp_apply_dft(&mut res, &a, &pmat, scratch.borrow());
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [
|
||||
[10, 2, 1, 2, 3],
|
||||
[11, 4, 1, 2, 5],
|
||||
[12, 7, 1, 2, 8],
|
||||
[13, 15, 1, 2, 16],
|
||||
[14, 31, 1, 2, 32],
|
||||
] {
|
||||
let id = BenchmarkId::from_parameter(format!(
|
||||
"{}x({}x{})x({}x{})",
|
||||
1 << params[0],
|
||||
params[2],
|
||||
params[1],
|
||||
params[3],
|
||||
params[4]
|
||||
));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vmp_apply_dft_to_dft<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: ModuleNew<B> + VecZnxDftAlloc<B> + VmpPMatAlloc<B> + VmpApplyDftToDft<B> + VmpApplyDftToDftTmpBytes,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vmp_apply_dft_to_dft::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 5]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: ModuleNew<B> + VecZnxDftAlloc<B> + VmpPMatAlloc<B> + VmpApplyDftToDft<B> + VmpApplyDftToDftTmpBytes,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let module: Module<B> = Module::<B>::new(1 << params[0]);
|
||||
|
||||
let rows: usize = params[1];
|
||||
let cols_in: usize = params[2];
|
||||
let cols_out: usize = params[3];
|
||||
let size: usize = params[4];
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned<B> =
|
||||
ScratchOwned::alloc(module.vmp_apply_dft_to_dft_tmp_bytes(size, size, rows, cols_in, cols_out, size));
|
||||
|
||||
let mut res: VecZnxDft<Vec<u8>, _> = module.vec_znx_dft_alloc(cols_out, size);
|
||||
let mut a: VecZnxDft<Vec<u8>, _> = module.vec_znx_dft_alloc(cols_in, size);
|
||||
|
||||
let mut pmat: VmpPMat<Vec<u8>, B> = module.vmp_pmat_alloc(rows, cols_in, cols_out, size);
|
||||
|
||||
source.fill_bytes(pmat.data_mut());
|
||||
source.fill_bytes(res.data_mut());
|
||||
source.fill_bytes(a.data_mut());
|
||||
|
||||
move || {
|
||||
module.vmp_apply_dft_to_dft(&mut res, &a, &pmat, scratch.borrow());
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [
|
||||
[10, 2, 1, 2, 3],
|
||||
[11, 4, 1, 2, 5],
|
||||
[12, 7, 1, 2, 8],
|
||||
[13, 15, 1, 2, 16],
|
||||
[14, 31, 1, 2, 32],
|
||||
] {
|
||||
let id = BenchmarkId::from_parameter(format!(
|
||||
"{}x({}x{})x({}x{})",
|
||||
1 << params[0], // n
|
||||
params[2], // cols_in
|
||||
params[1], // size_in (=rows)
|
||||
params[3], // cols_out
|
||||
params[4] // size_out
|
||||
));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vmp_apply_dft_to_dft_add<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: ModuleNew<B> + VecZnxDftAlloc<B> + VmpPMatAlloc<B> + VmpApplyDftToDftAdd<B> + VmpApplyDftToDftAddTmpBytes,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vmp_apply_dft_to_dft_add::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 5]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: ModuleNew<B> + VecZnxDftAlloc<B> + VmpPMatAlloc<B> + VmpApplyDftToDftAdd<B> + VmpApplyDftToDftAddTmpBytes,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let module: Module<B> = Module::<B>::new(1 << params[0]);
|
||||
|
||||
let rows: usize = params[1];
|
||||
let cols_in: usize = params[2];
|
||||
let cols_out: usize = params[3];
|
||||
let size: usize = params[4];
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned<B> =
|
||||
ScratchOwned::alloc(module.vmp_apply_dft_to_dft_add_tmp_bytes(size, size, rows, cols_in, cols_out, size));
|
||||
|
||||
let mut res: VecZnxDft<Vec<u8>, _> = module.vec_znx_dft_alloc(cols_out, size);
|
||||
let mut a: VecZnxDft<Vec<u8>, _> = module.vec_znx_dft_alloc(cols_in, size);
|
||||
|
||||
let mut pmat: VmpPMat<Vec<u8>, B> = module.vmp_pmat_alloc(rows, cols_in, cols_out, size);
|
||||
|
||||
source.fill_bytes(pmat.data_mut());
|
||||
source.fill_bytes(res.data_mut());
|
||||
source.fill_bytes(a.data_mut());
|
||||
|
||||
move || {
|
||||
module.vmp_apply_dft_to_dft_add(&mut res, &a, &pmat, 1, scratch.borrow());
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [
|
||||
[10, 2, 1, 2, 3],
|
||||
[11, 4, 1, 2, 5],
|
||||
[12, 7, 1, 2, 8],
|
||||
[13, 15, 1, 2, 16],
|
||||
[14, 31, 1, 2, 32],
|
||||
] {
|
||||
let id = BenchmarkId::from_parameter(format!(
|
||||
"{}x({}x{})x({}x{})",
|
||||
1 << params[0],
|
||||
params[2],
|
||||
params[1],
|
||||
params[3],
|
||||
params[4]
|
||||
));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
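Not part of the diff, just a sketch: if the four vmp benches are usually run together, a grouping helper along these lines could sit at the end of this file, reusing the imports above; the function name is hypothetical.

/// Hypothetical convenience wrapper: runs the full vmp suite for one backend.
pub fn bench_vmp_suite<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: ModuleNew<B>
        + VmpPMatAlloc<B>
        + VmpPrepare<B>
        + VmpPrepareTmpBytes
        + VmpApplyDft<B>
        + VmpApplyDftTmpBytes
        + VecZnxDftAlloc<B>
        + VmpApplyDftToDft<B>
        + VmpApplyDftToDftTmpBytes
        + VmpApplyDftToDftAdd<B>
        + VmpApplyDftToDftAddTmpBytes,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    bench_vmp_prepare::<B>(c, label);
    bench_vmp_apply_dft::<B>(c, label);
    bench_vmp_apply_dft_to_dft::<B>(c, label);
    bench_vmp_apply_dft_to_dft_add::<B>(c, label);
}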
|
||||
@@ -1,7 +1,15 @@
|
||||
use crate::{
|
||||
api::{SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPPolFromBytes, SvpPrepare},
|
||||
layouts::{Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef},
|
||||
oep::{SvpApplyImpl, SvpApplyInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl},
|
||||
api::{
|
||||
SvpApplyDft, SvpApplyDftToDft, SvpApplyDftToDftAdd, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes,
|
||||
SvpPPolFromBytes, SvpPrepare,
|
||||
},
|
||||
layouts::{
|
||||
Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
|
||||
},
|
||||
oep::{
|
||||
SvpApplyDftImpl, SvpApplyDftToDftAddImpl, SvpApplyDftToDftImpl, SvpApplyDftToDftInplaceImpl, SvpPPolAllocBytesImpl,
|
||||
SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl,
|
||||
},
|
||||
};
|
||||
|
||||
impl<B> SvpPPolFromBytes<B> for Module<B>
|
||||
@@ -44,29 +52,57 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> SvpApply<B> for Module<B>
|
||||
impl<B> SvpApplyDft<B> for Module<B>
|
||||
where
|
||||
B: Backend + SvpApplyImpl<B>,
|
||||
B: Backend + SvpApplyDftImpl<B>,
|
||||
{
|
||||
fn svp_apply<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
fn svp_apply_dft<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: SvpPPolToRef<B>,
|
||||
C: VecZnxToRef,
|
||||
{
|
||||
B::svp_apply_dft_impl(self, res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> SvpApplyDftToDft<B> for Module<B>
|
||||
where
|
||||
B: Backend + SvpApplyDftToDftImpl<B>,
|
||||
{
|
||||
fn svp_apply_dft_to_dft<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: SvpPPolToRef<B>,
|
||||
C: VecZnxDftToRef<B>,
|
||||
{
|
||||
B::svp_apply_impl(self, res, res_col, a, a_col, b, b_col);
|
||||
B::svp_apply_dft_to_dft_impl(self, res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> SvpApplyInplace<B> for Module<B>
|
||||
impl<B> SvpApplyDftToDftAdd<B> for Module<B>
|
||||
where
|
||||
B: Backend + SvpApplyInplaceImpl,
|
||||
B: Backend + SvpApplyDftToDftAddImpl<B>,
|
||||
{
|
||||
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
fn svp_apply_dft_to_dft_add<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: SvpPPolToRef<B>,
|
||||
C: VecZnxDftToRef<B>,
|
||||
{
|
||||
B::svp_apply_dft_to_dft_add_impl(self, res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> SvpApplyDftToDftInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + SvpApplyDftToDftInplaceImpl,
|
||||
{
|
||||
fn svp_apply_dft_to_dft_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: SvpPPolToRef<B>,
|
||||
{
|
||||
B::svp_apply_inplace_impl(self, res, res_col, a, a_col);
|
||||
B::svp_apply_dft_to_dft_inplace_impl(self, res, res_col, a, a_col);
|
||||
}
|
||||
}
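Illustration only (not in the commit): a minimal generic caller of the renamed inplace entry point, multiplying every column of a DFT-domain vector by the matching column of a prepared scalar polynomial; the helper name and the column-wise loop are assumptions.

fn svp_scale_all_columns<B, R, A>(module: &Module<B>, res: &mut R, s: &A, cols: usize)
where
    B: Backend,
    Module<B>: SvpApplyDftToDftInplace<B>,
    R: VecZnxDftToMut<B>,
    A: SvpPPolToRef<B>,
{
    // res[col] <- s[col] * res[col], column by column, entirely in the DFT domain.
    for col in 0..cols {
        module.svp_apply_dft_to_dft_inplace(res, col, s, col);
    }
}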
|
||||
|
||||
@@ -1,19 +1,24 @@
|
||||
use crate::{
|
||||
api::{
|
||||
VecZnxAdd, VecZnxAddDistF64, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism,
|
||||
VecZnxAutomorphismInplace, VecZnxCopy, VecZnxFillDistF64, VecZnxFillNormal, VecZnxFillUniform, VecZnxLshInplace,
|
||||
VecZnxMerge, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize,
|
||||
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSplit,
|
||||
VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalarInplace, VecZnxSwithcDegree,
|
||||
VecZnxAdd, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalar, VecZnxAddScalarInplace, VecZnxAutomorphism,
|
||||
VecZnxAutomorphismInplace, VecZnxAutomorphismInplaceTmpBytes, VecZnxCopy, VecZnxFillNormal, VecZnxFillUniform, VecZnxLsh,
|
||||
VecZnxLshInplace, VecZnxLshTmpBytes, VecZnxMergeRings, VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOne,
|
||||
VecZnxMulXpMinusOneInplace, VecZnxMulXpMinusOneInplaceTmpBytes, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize,
|
||||
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes,
|
||||
VecZnxRsh, VecZnxRshInplace, VecZnxRshTmpBytes, VecZnxSplitRing, VecZnxSplitRingTmpBytes, VecZnxSub, VecZnxSubABInplace,
|
||||
VecZnxSubBAInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing,
|
||||
},
|
||||
layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
|
||||
oep::{
|
||||
VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
|
||||
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl,
|
||||
VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl,
|
||||
VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl,
|
||||
VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxSubABInplaceImpl,
|
||||
VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
|
||||
VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarImpl, VecZnxAddScalarInplaceImpl,
|
||||
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxAutomorphismInplaceTmpBytesImpl, VecZnxCopyImpl,
|
||||
VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshImpl, VecZnxLshInplaceImpl, VecZnxLshTmpBytesImpl,
|
||||
VecZnxMergeRingsImpl, VecZnxMergeRingsTmpBytesImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl,
|
||||
VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
|
||||
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
|
||||
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
|
||||
VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl,
|
||||
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl,
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
@@ -79,6 +84,20 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxAddScalar for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxAddScalarImpl<B>,
|
||||
{
|
||||
fn vec_znx_add_scalar<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize, b_limb: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
D: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_add_scalar_impl(self, res, res_col, a, a_col, b, b_col, b_limb)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxAddScalarInplace for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxAddScalarInplaceImpl<B>,
|
||||
@@ -132,6 +151,20 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxSubScalar for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxSubScalarImpl<B>,
|
||||
{
|
||||
fn vec_znx_sub_scalar<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize, b_limb: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
D: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_sub_scalar_impl(self, res, res_col, a, a_col, b, b_col, b_limb)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxSubScalarInplace for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxSubScalarInplaceImpl<B>,
|
||||
@@ -170,27 +203,87 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxLshInplace for Module<B>
|
||||
impl<B> VecZnxRshTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxLshInplaceImpl<B>,
|
||||
B: Backend + VecZnxRshTmpBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_lsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_lsh_inplace_impl(self, basek, k, a)
|
||||
fn vec_znx_rsh_tmp_bytes(&self) -> usize {
|
||||
B::vec_znx_rsh_tmp_bytes_impl(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxRshInplace for Module<B>
|
||||
impl<B> VecZnxLshTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxRshInplaceImpl<B>,
|
||||
B: Backend + VecZnxLshTmpBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_rsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
|
||||
fn vec_znx_lsh_tmp_bytes(&self) -> usize {
|
||||
B::vec_znx_lsh_tmp_bytes_impl(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxLsh<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxLshImpl<B>,
|
||||
{
|
||||
fn vec_znx_lsh<R, A>(
|
||||
&self,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_lsh_impl(self, basek, k, res, res_col, a, a_col, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxRsh<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxRshImpl<B>,
|
||||
{
|
||||
fn vec_znx_rsh<R, A>(
|
||||
&self,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_rsh_impl(self, basek, k, res, res_col, a, a_col, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxLshInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxLshInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_lsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_rsh_inplace_impl(self, basek, k, a)
|
||||
B::vec_znx_lsh_inplace_impl(self, basek, k, a, a_col, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxRshInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxRshInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_rsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_rsh_inplace_impl(self, basek, k, a, a_col, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -207,15 +300,24 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxRotateInplace for Module<B>
|
||||
impl<B> VecZnxRotateInplaceTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxRotateInplaceTmpBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_rotate_inplace_tmp_bytes(&self) -> usize {
|
||||
B::vec_znx_rotate_inplace_tmp_bytes_impl(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxRotateInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxRotateInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_rotate_inplace_impl(self, k, a, a_col)
|
||||
B::vec_znx_rotate_inplace_impl(self, k, a, a_col, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -232,15 +334,24 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxAutomorphismInplace for Module<B>
|
||||
impl<B> VecZnxAutomorphismInplaceTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxAutomorphismInplaceTmpBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_automorphism_inplace_tmp_bytes(&self) -> usize {
|
||||
B::vec_znx_automorphism_inplace_tmp_bytes_impl(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxAutomorphismInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxAutomorphismInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
fn vec_znx_automorphism_inplace<R>(&self, k: i64, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_automorphism_inplace_impl(self, k, a, a_col)
|
||||
B::vec_znx_automorphism_inplace_impl(self, k, res, res_col, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -257,54 +368,81 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxMulXpMinusOneInplace for Module<B>
|
||||
impl<B> VecZnxMulXpMinusOneInplaceTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxMulXpMinusOneInplaceTmpBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes(&self) -> usize {
|
||||
B::vec_znx_mul_xp_minus_one_inplace_tmp_bytes_impl(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxMulXpMinusOneInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxMulXpMinusOneInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, res: &mut R, res_col: usize)
|
||||
fn vec_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_mul_xp_minus_one_inplace_impl(self, p, res, res_col);
|
||||
B::vec_znx_mul_xp_minus_one_inplace_impl(self, p, res, res_col, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxSplit<B> for Module<B>
|
||||
impl<B> VecZnxSplitRingTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxSplitImpl<B>,
|
||||
B: Backend + VecZnxSplitRingTmpBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_split<R, A>(&self, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
fn vec_znx_split_ring_tmp_bytes(&self) -> usize {
|
||||
B::vec_znx_split_ring_tmp_bytes_impl(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxSplitRing<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxSplitRingImpl<B>,
|
||||
{
|
||||
fn vec_znx_split_ring<R, A>(&self, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_split_impl(self, res, res_col, a, a_col, scratch)
|
||||
B::vec_znx_split_ring_impl(self, res, res_col, a, a_col, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxMerge for Module<B>
|
||||
impl<B> VecZnxMergeRingsTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxMergeImpl<B>,
|
||||
B: Backend + VecZnxMergeRingsTmpBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: &[A], a_col: usize)
|
||||
fn vec_znx_merge_rings_tmp_bytes(&self) -> usize {
|
||||
B::vec_znx_merge_rings_tmp_bytes_impl(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxMergeRings<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxMergeRingsImpl<B>,
|
||||
{
|
||||
fn vec_znx_merge_rings<R, A>(&self, res: &mut R, res_col: usize, a: &[A], a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_merge_impl(self, res, res_col, a, a_col)
|
||||
B::vec_znx_merge_rings_impl(self, res, res_col, a, a_col, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxSwithcDegree for Module<B>
|
||||
impl<B> VecZnxSwitchRing for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxSwithcDegreeImpl<B>,
|
||||
B: Backend + VecZnxSwitchRingImpl<B>,
|
||||
{
|
||||
fn vec_znx_switch_degree<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
fn vec_znx_switch_ring<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_switch_degree_impl(self, res, res_col, a, a_col)
|
||||
B::vec_znx_switch_ring_impl(self, res, res_col, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -325,51 +463,11 @@ impl<B> VecZnxFillUniform for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxFillUniformImpl<B>,
|
||||
{
|
||||
fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
|
||||
fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_fill_uniform_impl(self, basek, res, res_col, k, source);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxFillDistF64 for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxFillDistF64Impl<B>,
|
||||
{
|
||||
fn vec_znx_fill_dist_f64<R, D: rand::prelude::Distribution<f64>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_fill_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxAddDistF64 for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxAddDistF64Impl<B>,
|
||||
{
|
||||
fn vec_znx_add_dist_f64<R, D: rand::prelude::Distribution<f64>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_add_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
|
||||
B::vec_znx_fill_uniform_impl(self, basek, res, res_col, source);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
use rand_distr::Distribution;
|
||||
|
||||
use crate::{
|
||||
api::{
|
||||
VecZnxBigAdd, VecZnxBigAddDistF64, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmall, VecZnxBigAddSmallInplace,
|
||||
VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigFillDistF64,
|
||||
VecZnxBigFillNormal, VecZnxBigFromBytes, VecZnxBigNegateInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
|
||||
VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA, VecZnxBigSubSmallAInplace,
|
||||
VecZnxBigSubSmallB, VecZnxBigSubSmallBInplace,
|
||||
VecZnxBigAdd, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmall, VecZnxBigAddSmallInplace, VecZnxBigAlloc,
|
||||
VecZnxBigAllocBytes, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigAutomorphismInplaceTmpBytes,
|
||||
VecZnxBigFromBytes, VecZnxBigFromSmall, VecZnxBigNegate, VecZnxBigNegateInplace, VecZnxBigNormalize,
|
||||
VecZnxBigNormalizeTmpBytes, VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA,
|
||||
VecZnxBigSubSmallAInplace, VecZnxBigSubSmallB, VecZnxBigSubSmallBInplace,
|
||||
},
|
||||
layouts::{Backend, Module, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef},
|
||||
oep::{
|
||||
VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl,
|
||||
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
|
||||
VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl, VecZnxBigFromBytesImpl,
|
||||
VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl, VecZnxBigAddSmallInplaceImpl,
|
||||
VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl, VecZnxBigAutomorphismInplaceImpl,
|
||||
VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl, VecZnxBigFromSmallImpl, VecZnxBigNegateImpl,
|
||||
VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl,
|
||||
VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl,
|
||||
VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
|
||||
@@ -20,6 +18,19 @@ use crate::{
|
||||
source::Source,
|
||||
};
|
||||
|
||||
impl<B> VecZnxBigFromSmall<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigFromSmallImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_from_small<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_big_from_small_impl(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigAlloc<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigAllocImpl<B>,
|
||||
@@ -47,24 +58,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigAddDistF64<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigAddDistF64Impl<B>,
|
||||
{
|
||||
fn vec_znx_big_add_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
B::add_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigAddNormal<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigAddNormalImpl<B>,
|
||||
@@ -83,42 +76,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigFillDistF64<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigFillDistF64Impl<B>,
|
||||
{
|
||||
fn vec_znx_big_fill_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
B::fill_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigFillNormal<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigFillNormalImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_fill_normal<R: VecZnxBigToMut<B>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) {
|
||||
B::fill_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigAdd<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigAddImpl<B>,
|
||||
@@ -267,6 +224,19 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigNegate<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigNegateImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
{
|
||||
B::vec_znx_big_negate_impl(self, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigNegateInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigNegateInplaceImpl<B>,
|
||||
@@ -321,14 +291,23 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigAutomorphismInplaceTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigAutomorphismInplaceTmpBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_automorphism_inplace_tmp_bytes(&self) -> usize {
|
||||
B::vec_znx_big_automorphism_inplace_tmp_bytes_impl(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigAutomorphismInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigAutomorphismInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
A: VecZnxBigToMut<B>,
|
||||
{
|
||||
B::vec_znx_big_automorphism_inplace_impl(self, k, a, a_col);
|
||||
B::vec_znx_big_automorphism_inplace_impl(self, k, a, a_col, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
use crate::{
|
||||
api::{
|
||||
DFT, IDFT, IDFTConsume, IDFTTmpA, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftCopy,
|
||||
VecZnxDftFromBytes, VecZnxDftSub, VecZnxDftSubABInplace, VecZnxDftSubBAInplace, VecZnxDftZero, VecZnxIDFTTmpBytes,
|
||||
VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy,
|
||||
VecZnxDftFromBytes, VecZnxDftSub, VecZnxDftSubABInplace, VecZnxDftSubBAInplace, VecZnxDftZero, VecZnxIdftApply,
|
||||
VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes,
|
||||
},
|
||||
layouts::{
|
||||
Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef,
|
||||
VecZnxToRef,
|
||||
},
|
||||
oep::{
|
||||
DFTImpl, IDFTConsumeImpl, IDFTImpl, IDFTTmpAImpl, VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl,
|
||||
VecZnxDftAllocImpl, VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl,
|
||||
VecZnxDftSubImpl, VecZnxDftZeroImpl, VecZnxIDFTTmpBytesImpl,
|
||||
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl,
|
||||
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl,
|
||||
VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -41,63 +42,63 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxIDFTTmpBytes for Module<B>
|
||||
impl<B> VecZnxIdftApplyTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxIDFTTmpBytesImpl<B>,
|
||||
B: Backend + VecZnxIdftApplyTmpBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_idft_tmp_bytes(&self) -> usize {
|
||||
B::vec_znx_idft_tmp_bytes_impl(self)
|
||||
fn vec_znx_idft_apply_tmp_bytes(&self) -> usize {
|
||||
B::vec_znx_idft_apply_tmp_bytes_impl(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> IDFT<B> for Module<B>
|
||||
impl<B> VecZnxIdftApply<B> for Module<B>
|
||||
where
|
||||
B: Backend + IDFTImpl<B>,
|
||||
B: Backend + VecZnxIdftApplyImpl<B>,
|
||||
{
|
||||
fn idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
fn vec_znx_idft_apply<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
{
|
||||
B::idft_impl(self, res, res_col, a, a_col, scratch);
|
||||
B::vec_znx_idft_apply_impl(self, res, res_col, a, a_col, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> IDFTTmpA<B> for Module<B>
|
||||
impl<B> VecZnxIdftApplyTmpA<B> for Module<B>
|
||||
where
|
||||
B: Backend + IDFTTmpAImpl<B>,
|
||||
B: Backend + VecZnxIdftApplyTmpAImpl<B>,
|
||||
{
|
||||
fn idft_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
fn vec_znx_idft_apply_tmpa<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxDftToMut<B>,
|
||||
{
|
||||
B::idft_tmp_a_impl(self, res, res_col, a, a_col);
|
||||
B::vec_znx_idft_apply_tmpa_impl(self, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> IDFTConsume<B> for Module<B>
|
||||
impl<B> VecZnxIdftApplyConsume<B> for Module<B>
|
||||
where
|
||||
B: Backend + IDFTConsumeImpl<B>,
|
||||
B: Backend + VecZnxIdftApplyConsumeImpl<B>,
|
||||
{
|
||||
fn vec_znx_idft_consume<D: Data>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
|
||||
fn vec_znx_idft_apply_consume<D: Data>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
|
||||
where
|
||||
VecZnxDft<D, B>: VecZnxDftToMut<B>,
|
||||
{
|
||||
B::idft_consume_impl(self, a)
|
||||
B::vec_znx_idft_apply_consume_impl(self, a)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> DFT<B> for Module<B>
|
||||
impl<B> VecZnxDftApply<B> for Module<B>
|
||||
where
|
||||
B: Backend + DFTImpl<B>,
|
||||
B: Backend + VecZnxDftApplyImpl<B>,
|
||||
{
|
||||
fn dft<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
fn vec_znx_dft_apply<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::dft_impl(self, step, offset, res, res_col, a, a_col);
|
||||
B::vec_znx_dft_apply_impl(self, step, offset, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
use crate::{
|
||||
api::{
|
||||
VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes, VmpApplyDftToDftTmpBytes, VmpPMatAlloc,
|
||||
VmpPMatAllocBytes, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes,
|
||||
VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes,
|
||||
VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes,
|
||||
},
|
||||
layouts::{
|
||||
Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, VmpPMatOwned, VmpPMatToMut,
|
||||
VmpPMatToRef,
|
||||
},
|
||||
layouts::{Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef},
|
||||
oep::{
|
||||
VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl,
|
||||
VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
|
||||
VmpApplyDftImpl, VmpApplyDftTmpBytesImpl, VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl,
|
||||
VmpApplyDftToDftTmpBytesImpl, VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPrepareImpl,
|
||||
VmpPrepareTmpBytesImpl,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -48,7 +52,7 @@ where
|
||||
|
||||
impl<B> VmpPrepare<B> for Module<B>
|
||||
where
|
||||
B: Backend + VmpPMatPrepareImpl<B>,
|
||||
B: Backend + VmpPrepareImpl<B>,
|
||||
{
|
||||
fn vmp_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<B>)
|
||||
where
|
||||
@@ -59,6 +63,39 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VmpApplyDftTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VmpApplyDftTmpBytesImpl<B>,
|
||||
{
|
||||
fn vmp_apply_dft_tmp_bytes(
|
||||
&self,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
B::vmp_apply_dft_tmp_bytes_impl(
|
||||
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VmpApplyDft<B> for Module<B>
|
||||
where
|
||||
B: Backend + VmpApplyDftImpl<B>,
|
||||
{
|
||||
fn vmp_apply_dft<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxToRef,
|
||||
C: VmpPMatToRef<B>,
|
||||
{
|
||||
B::vmp_apply_dft_impl(self, res, a, b, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VmpApplyDftToDftTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VmpApplyDftToDftTmpBytesImpl<B>,
|
||||
|
||||
@@ -1,10 +1,19 @@
|
||||
use crate::{
|
||||
api::{ZnAddDistF64, ZnAddNormal, ZnFillDistF64, ZnFillNormal, ZnFillUniform, ZnNormalizeInplace},
|
||||
api::{ZnAddNormal, ZnFillNormal, ZnFillUniform, ZnNormalizeInplace, ZnNormalizeTmpBytes},
|
||||
layouts::{Backend, Module, Scratch, ZnToMut},
|
||||
oep::{ZnAddDistF64Impl, ZnAddNormalImpl, ZnFillDistF64Impl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl},
|
||||
oep::{ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl, ZnNormalizeTmpBytesImpl},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
impl<B> ZnNormalizeTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + ZnNormalizeTmpBytesImpl<B>,
|
||||
{
|
||||
fn zn_normalize_tmp_bytes(&self, n: usize) -> usize {
|
||||
B::zn_normalize_tmp_bytes_impl(n)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ZnNormalizeInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + ZnNormalizeInplaceImpl<B>,
|
||||
@@ -21,53 +30,11 @@ impl<B> ZnFillUniform for Module<B>
|
||||
where
|
||||
B: Backend + ZnFillUniformImpl<B>,
|
||||
{
|
||||
fn zn_fill_uniform<R>(&self, n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
|
||||
fn zn_fill_uniform<R>(&self, n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
|
||||
where
|
||||
R: ZnToMut,
|
||||
{
|
||||
B::zn_fill_uniform_impl(n, basek, res, res_col, k, source);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ZnFillDistF64 for Module<B>
|
||||
where
|
||||
B: Backend + ZnFillDistF64Impl<B>,
|
||||
{
|
||||
fn zn_fill_dist_f64<R, D: rand::prelude::Distribution<f64>>(
|
||||
&self,
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
B::zn_fill_dist_f64_impl(n, basek, res, res_col, k, source, dist, bound);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ZnAddDistF64 for Module<B>
|
||||
where
|
||||
B: Backend + ZnAddDistF64Impl<B>,
|
||||
{
|
||||
fn zn_add_dist_f64<R, D: rand::prelude::Distribution<f64>>(
|
||||
&self,
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
B::zn_add_dist_f64_impl(n, basek, res, res_col, k, source, dist, bound);
|
||||
B::zn_fill_uniform_impl(n, basek, res, res_col, source);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
use itertools::izip;
|
||||
use rug::{Assign, Float};
|
||||
|
||||
use crate::layouts::{
|
||||
DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, Zn, ZnToMut, ZnToRef, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero,
|
||||
use crate::{
|
||||
layouts::{DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, Zn, ZnToMut, ZnToRef, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
reference::znx::znx_zero_ref,
|
||||
};
|
||||
|
||||
impl<D: DataMut> VecZnx<D> {
|
||||
@@ -28,7 +29,7 @@ impl<D: DataMut> VecZnx<D> {
|
||||
|
||||
// Zeroes coefficients of the i-th column
|
||||
(0..a.size()).for_each(|i| {
|
||||
a.zero_at(col, i);
|
||||
znx_zero_ref(a.at_mut(col, i));
|
||||
});
|
||||
|
||||
// If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
|
||||
@@ -183,7 +184,7 @@ impl<D: DataRef> VecZnx<D> {
|
||||
let prec: u32 = (basek * size) as u32;
|
||||
|
||||
// 2^{basek}
|
||||
let base = Float::with_val(prec, (1 << basek) as f64);
|
||||
let base = Float::with_val(prec, (1u64 << basek) as f64);
|
||||
|
||||
// y[i] = sum x[j][i] * 2^{-basek*j}
|
||||
(0..size).for_each(|i| {
|
||||
|
||||
@@ -1,17 +1,21 @@
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
layouts::{
|
||||
Data, DataMut, DataRef, DataView, DataViewMut, FillUniform, ReaderFrom, Reset, ToOwnedDeep, VecZnx, WriterTo, ZnxInfos,
|
||||
ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
|
||||
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, Reset, ToOwnedDeep, VecZnx, WriterTo,
|
||||
ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
use std::fmt;
|
||||
use std::{
|
||||
fmt,
|
||||
hash::{DefaultHasher, Hasher},
|
||||
};
|
||||
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
use rand::RngCore;
|
||||
|
||||
#[derive(PartialEq, Eq, Clone)]
|
||||
#[repr(C)]
|
||||
#[derive(PartialEq, Eq, Clone, Hash)]
|
||||
pub struct MatZnx<D: Data> {
|
||||
data: D,
|
||||
n: usize,
|
||||
@@ -21,6 +25,19 @@ pub struct MatZnx<D: Data> {
|
||||
cols_out: usize,
|
||||
}
|
||||
|
||||
impl<D: DataRef> DigestU64 for MatZnx<D> {
|
||||
fn digest_u64(&self) -> u64 {
|
||||
let mut h: DefaultHasher = DefaultHasher::new();
|
||||
h.write(self.data.as_ref());
|
||||
h.write_usize(self.n);
|
||||
h.write_usize(self.size);
|
||||
h.write_usize(self.rows);
|
||||
h.write_usize(self.cols_in);
|
||||
h.write_usize(self.cols_out);
|
||||
h.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> ToOwnedDeep for MatZnx<D> {
|
||||
type Owned = MatZnx<Vec<u8>>;
|
||||
fn to_owned_deep(&self) -> Self::Owned {
|
||||
@@ -57,6 +74,10 @@ impl<D: Data> ZnxInfos for MatZnx<D> {
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
|
||||
fn poly_count(&self) -> usize {
|
||||
self.rows() * self.cols_in() * self.cols_out() * self.size()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxSliceSize for MatZnx<D> {
|
||||
@@ -175,8 +196,18 @@ impl<D: DataMut> MatZnx<D> {
|
||||
}
|
||||
|
||||
impl<D: DataMut> FillUniform for MatZnx<D> {
|
||||
fn fill_uniform(&mut self, source: &mut Source) {
|
||||
source.fill_bytes(self.data.as_mut());
|
||||
fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
|
||||
match log_bound {
|
||||
64 => source.fill_bytes(self.data.as_mut()),
|
||||
0 => panic!("invalid log_bound, cannot be zero"),
|
||||
_ => {
|
||||
let mask: u64 = (1u64 << log_bound) - 1;
|
||||
for x in self.raw_mut().iter_mut() {
|
||||
let r = source.next_u64() & mask;
|
||||
*x = ((r << (64 - log_bound)) as i64) >> (64 - log_bound);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
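// Illustrative sketch (not part of this diff): the mask-and-shift pair used in
// fill_uniform above bounds each sampled coefficient to the signed range
// [-2^(log_bound-1), 2^(log_bound-1)). `bound_coeff` is a hypothetical helper
// showing the same trick on a single raw u64 sample.
fn bound_coeff(r: u64, log_bound: usize) -> i64 {
    debug_assert!(log_bound >= 1 && log_bound < 64);
    let r = r & ((1u64 << log_bound) - 1); // keep the low `log_bound` bits
    // shift the kept bits to the top, then arithmetic-shift back down to sign-extend bit log_bound-1
    ((r << (64 - log_bound)) as i64) >> (64 - log_bound)
}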
|
||||
|
||||
|
||||
@@ -34,3 +34,7 @@ pub trait ToOwnedDeep {
|
||||
type Owned;
|
||||
fn to_owned_deep(&self) -> Self::Owned;
|
||||
}
|
||||
|
||||
pub trait DigestU64 {
|
||||
fn digest_u64(&self) -> u64;
|
||||
}
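// Illustrative usage sketch (not part of this diff): digest_u64 is a cheap
// content fingerprint over the raw data and the shape fields, e.g. to assert in
// tests that an in-place operation left an operand untouched. `assert_unchanged`
// is a hypothetical helper, not an API of this crate.
fn assert_unchanged<T: DigestU64>(digest_before: u64, value: &T) {
    assert_eq!(digest_before, value.digest_u64(), "operand was modified");
}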
|
||||
|
||||
@@ -20,7 +20,30 @@ pub struct Module<B: Backend> {
|
||||
_marker: PhantomData<B>,
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> Sync for Module<B> {}
|
||||
unsafe impl<B: Backend> Send for Module<B> {}
|
||||
|
||||
impl<B: Backend> Module<B> {
|
||||
#[allow(clippy::missing_safety_doc)]
|
||||
#[inline]
|
||||
pub fn new_marker(n: u64) -> Self {
|
||||
Self {
|
||||
ptr: NonNull::dangling(),
|
||||
n,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::missing_safety_doc)]
|
||||
#[inline]
|
||||
pub unsafe fn from_nonnull(ptr: NonNull<B::Handle>, n: u64) -> Self {
|
||||
Self {
|
||||
ptr,
|
||||
n,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct from a raw pointer managed elsewhere.
|
||||
/// SAFETY: `ptr` must be non-null and remain valid for the lifetime of this Module.
|
||||
#[inline]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::hash::{DefaultHasher, Hasher};
|
||||
|
||||
use rand::seq::SliceRandom;
|
||||
use rand_core::RngCore;
|
||||
use rand_distr::{Distribution, weighted::WeightedIndex};
|
||||
@@ -5,19 +7,30 @@ use rand_distr::{Distribution, weighted::WeightedIndex};
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
layouts::{
|
||||
Data, DataMut, DataRef, DataView, DataViewMut, FillUniform, ReaderFrom, Reset, ToOwnedDeep, VecZnx, WriterTo, ZnxInfos,
|
||||
ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
|
||||
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, Reset, ToOwnedDeep, VecZnx, WriterTo,
|
||||
ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
#[derive(PartialEq, Eq, Debug, Clone)]
|
||||
#[repr(C)]
|
||||
#[derive(PartialEq, Eq, Debug, Clone, Hash)]
|
||||
pub struct ScalarZnx<D: Data> {
|
||||
pub data: D,
|
||||
pub n: usize,
|
||||
pub cols: usize,
|
||||
}
|
||||
|
||||
impl<D: DataRef> DigestU64 for ScalarZnx<D> {
|
||||
fn digest_u64(&self) -> u64 {
|
||||
let mut h: DefaultHasher = DefaultHasher::new();
|
||||
h.write(self.data.as_ref());
|
||||
h.write_usize(self.n);
|
||||
h.write_usize(self.cols);
|
||||
h.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> ToOwnedDeep for ScalarZnx<D> {
|
||||
type Owned = ScalarZnx<Vec<u8>>;
|
||||
fn to_owned_deep(&self) -> Self::Owned {
|
||||
@@ -145,8 +158,18 @@ impl<D: DataMut> ZnxZero for ScalarZnx<D> {
|
||||
}
|
||||
|
||||
impl<D: DataMut> FillUniform for ScalarZnx<D> {
|
||||
fn fill_uniform(&mut self, source: &mut Source) {
|
||||
source.fill_bytes(self.data.as_mut());
|
||||
fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
|
||||
match log_bound {
|
||||
64 => source.fill_bytes(self.data.as_mut()),
|
||||
0 => panic!("invalid log_bound, cannot be zero"),
|
||||
_ => {
|
||||
let mask: u64 = (1u64 << log_bound) - 1;
|
||||
for x in self.raw_mut().iter_mut() {
|
||||
let r = source.next_u64() & mask;
|
||||
*x = ((r << (64 - log_bound)) as i64) >> (64 - log_bound);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,11 +2,13 @@ use std::marker::PhantomData;
|
||||
|
||||
use crate::layouts::Backend;
|
||||
|
||||
#[repr(C)]
|
||||
pub struct ScratchOwned<B: Backend> {
|
||||
pub data: Vec<u8>,
|
||||
pub _phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
pub struct Scratch<B: Backend> {
|
||||
pub _phantom: PhantomData<B>,
|
||||
pub data: [u8],
|
||||
|
||||
@@ -4,7 +4,7 @@ use rug::{
|
||||
ops::{AddAssignRound, DivAssignRound, SubAssignRound},
|
||||
};
|
||||
|
||||
use crate::layouts::{DataRef, VecZnx, ZnxInfos};
|
||||
use crate::layouts::{Backend, DataRef, VecZnx, VecZnxBig, VecZnxBigToRef, ZnxInfos};
|
||||
|
||||
impl<D: DataRef> VecZnx<D> {
|
||||
pub fn std(&self, basek: usize, col: usize) -> f64 {
|
||||
@@ -27,3 +27,17 @@ impl<D: DataRef> VecZnx<D> {
|
||||
std.to_f64()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef, B: Backend + Backend<ScalarBig = i64>> VecZnxBig<D, B> {
|
||||
pub fn std(&self, basek: usize, col: usize) -> f64 {
|
||||
let self_ref: VecZnxBig<&[u8], B> = self.to_ref();
|
||||
let znx: VecZnx<&[u8]> = VecZnx {
|
||||
data: self_ref.data,
|
||||
n: self_ref.n,
|
||||
cols: self_ref.cols,
|
||||
size: self_ref.size,
|
||||
max_size: self_ref.max_size,
|
||||
};
|
||||
znx.std(basek, col)
|
||||
}
|
||||
}
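// Note (illustrative, not part of this diff): the borrowed VecZnx above is
// rebuilt field by field from the VecZnxBig view; this works because, for
// backends whose ScalarBig is i64, a VecZnxBig carries exactly the fields of a
// VecZnx (data, n, cols, size, max_size), so the existing column-wise
// standard-deviation routine can be reused unchanged.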
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
use std::marker::PhantomData;
|
||||
use std::{
|
||||
fmt,
|
||||
hash::{DefaultHasher, Hasher},
|
||||
marker::PhantomData,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
layouts::{Backend, Data, DataMut, DataRef, DataView, DataViewMut, ReaderFrom, WriterTo, ZnxInfos, ZnxSliceSize, ZnxView},
|
||||
layouts::{
|
||||
Backend, Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, ReaderFrom, WriterTo, ZnxInfos, ZnxSliceSize, ZnxView,
|
||||
},
|
||||
oep::SvpPPolAllocBytesImpl,
|
||||
};
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
#[repr(C)]
|
||||
#[derive(PartialEq, Eq, Hash)]
|
||||
pub struct SvpPPol<D: Data, B: Backend> {
|
||||
pub data: D,
|
||||
pub n: usize,
|
||||
@@ -14,6 +21,16 @@ pub struct SvpPPol<D: Data, B: Backend> {
|
||||
pub _phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<D: DataRef, B: Backend> DigestU64 for SvpPPol<D, B> {
|
||||
fn digest_u64(&self) -> u64 {
|
||||
let mut h: DefaultHasher = DefaultHasher::new();
|
||||
h.write(self.data.as_ref());
|
||||
h.write_usize(self.n);
|
||||
h.write_usize(self.cols);
|
||||
h.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> ZnxSliceSize for SvpPPol<D, B> {
|
||||
fn sl(&self) -> usize {
|
||||
B::layout_prep_word_count() * self.n()
|
||||
@@ -153,3 +170,32 @@ impl<D: DataRef, B: Backend> WriterTo for SvpPPol<D, B> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef, B: Backend> fmt::Display for SvpPPol<D, B> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(f, "SvpPPol(n={}, cols={})", self.n, self.cols)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {}:", col)?;
|
||||
let coeffs = self.at(col, 0);
|
||||
write!(f, "[")?;
|
||||
|
||||
let max_show = 100;
|
||||
let show_count = coeffs.len().min(max_show);
|
||||
|
||||
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "{}", coeff)?;
|
||||
}
|
||||
|
||||
if coeffs.len() > max_show {
|
||||
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
|
||||
}
|
||||
|
||||
writeln!(f, "]")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
use std::fmt;
|
||||
use std::{
|
||||
fmt,
|
||||
hash::{DefaultHasher, Hasher},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
layouts::{
|
||||
Data, DataMut, DataRef, DataView, DataViewMut, FillUniform, ReaderFrom, Reset, ToOwnedDeep, WriterTo, ZnxInfos,
|
||||
ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
|
||||
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, Reset, ToOwnedDeep, WriterTo,
|
||||
ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
@@ -12,7 +15,8 @@ use crate::{
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
use rand::RngCore;
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||
#[repr(C)]
|
||||
#[derive(PartialEq, Eq, Clone, Copy, Hash)]
|
||||
pub struct VecZnx<D: Data> {
|
||||
pub data: D,
|
||||
pub n: usize,
|
||||
@@ -21,6 +25,18 @@ pub struct VecZnx<D: Data> {
|
||||
pub max_size: usize,
|
||||
}
|
||||
|
||||
impl<D: DataRef> DigestU64 for VecZnx<D> {
|
||||
fn digest_u64(&self) -> u64 {
|
||||
let mut h: DefaultHasher = DefaultHasher::new();
|
||||
h.write(self.data.as_ref());
|
||||
h.write_usize(self.n);
|
||||
h.write_usize(self.cols);
|
||||
h.write_usize(self.size);
|
||||
h.write_usize(self.max_size);
|
||||
h.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> ToOwnedDeep for VecZnx<D> {
|
||||
type Owned = VecZnx<Vec<u8>>;
|
||||
fn to_owned_deep(&self) -> Self::Owned {
|
||||
@@ -173,8 +189,18 @@ impl<D: DataRef> fmt::Display for VecZnx<D> {
|
||||
}
|
||||
|
||||
impl<D: DataMut> FillUniform for VecZnx<D> {
|
||||
fn fill_uniform(&mut self, source: &mut Source) {
|
||||
source.fill_bytes(self.data.as_mut());
|
||||
fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
|
||||
match log_bound {
|
||||
64 => source.fill_bytes(self.data.as_mut()),
|
||||
0 => panic!("invalid log_bound, cannot be zero"),
|
||||
_ => {
|
||||
let mask: u64 = (1u64 << log_bound) - 1;
|
||||
for x in self.raw_mut().iter_mut() {
|
||||
let r = source.next_u64() & mask;
|
||||
*x = ((r << (64 - log_bound)) as i64) >> (64 - log_bound);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,15 +1,21 @@
|
||||
use std::marker::PhantomData;
|
||||
use std::{
|
||||
hash::{DefaultHasher, Hasher},
|
||||
marker::PhantomData,
|
||||
};
|
||||
|
||||
use rand_distr::num_traits::Zero;
|
||||
use std::fmt;
|
||||
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
layouts::{Backend, Data, DataMut, DataRef, DataView, DataViewMut, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
|
||||
layouts::{
|
||||
Backend, Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
|
||||
},
|
||||
oep::VecZnxBigAllocBytesImpl,
|
||||
};
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
#[repr(C)]
|
||||
#[derive(PartialEq, Eq, Hash)]
|
||||
pub struct VecZnxBig<D: Data, B: Backend> {
|
||||
pub data: D,
|
||||
pub n: usize,
|
||||
@@ -19,6 +25,18 @@ pub struct VecZnxBig<D: Data, B: Backend> {
|
||||
pub _phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<D: DataRef, B: Backend> DigestU64 for VecZnxBig<D, B> {
|
||||
fn digest_u64(&self) -> u64 {
|
||||
let mut h: DefaultHasher = DefaultHasher::new();
|
||||
h.write(self.data.as_ref());
|
||||
h.write_usize(self.n);
|
||||
h.write_usize(self.cols);
|
||||
h.write_usize(self.size);
|
||||
h.write_usize(self.max_size);
|
||||
h.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> ZnxSliceSize for VecZnxBig<D, B> {
|
||||
fn sl(&self) -> usize {
|
||||
B::layout_big_word_count() * self.n() * self.cols()
|
||||
|
||||
@@ -1,14 +1,21 @@
|
||||
use std::{fmt, marker::PhantomData};
|
||||
use std::{
|
||||
fmt,
|
||||
hash::{DefaultHasher, Hasher},
|
||||
marker::PhantomData,
|
||||
};
|
||||
|
||||
use rand_distr::num_traits::Zero;
|
||||
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
layouts::{
|
||||
Backend, Data, DataMut, DataRef, DataView, DataViewMut, VecZnxBig, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
|
||||
Backend, Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, VecZnxBig, ZnxInfos, ZnxSliceSize, ZnxView,
|
||||
ZnxViewMut, ZnxZero,
|
||||
},
|
||||
oep::VecZnxBigAllocBytesImpl,
|
||||
oep::VecZnxDftAllocBytesImpl,
|
||||
};
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct VecZnxDft<D: Data, B: Backend> {
|
||||
pub data: D,
|
||||
@@ -19,6 +26,18 @@ pub struct VecZnxDft<D: Data, B: Backend> {
|
||||
pub _phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<D: DataRef, B: Backend> DigestU64 for VecZnxDft<D, B> {
|
||||
fn digest_u64(&self) -> u64 {
|
||||
let mut h: DefaultHasher = DefaultHasher::new();
|
||||
h.write(self.data.as_ref());
|
||||
h.write_usize(self.n);
|
||||
h.write_usize(self.cols);
|
||||
h.write_usize(self.size);
|
||||
h.write_usize(self.max_size);
|
||||
h.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> ZnxSliceSize for VecZnxDft<D, B> {
|
||||
fn sl(&self) -> usize {
|
||||
B::layout_prep_word_count() * self.n() * self.cols()
|
||||
@@ -94,10 +113,10 @@ where
|
||||
|
||||
impl<D: DataRef + From<Vec<u8>>, B: Backend> VecZnxDft<D, B>
|
||||
where
|
||||
B: VecZnxBigAllocBytesImpl<B>,
|
||||
B: VecZnxDftAllocBytesImpl<B>,
|
||||
{
|
||||
pub fn alloc(n: usize, cols: usize, size: usize) -> Self {
|
||||
let data: Vec<u8> = alloc_aligned::<u8>(B::vec_znx_big_alloc_bytes_impl(n, cols, size));
|
||||
let data: Vec<u8> = alloc_aligned::<u8>(B::vec_znx_dft_alloc_bytes_impl(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
@@ -110,7 +129,7 @@ where
|
||||
|
||||
pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == B::vec_znx_big_alloc_bytes_impl(n, cols, size));
|
||||
assert!(data.len() == B::vec_znx_dft_alloc_bytes_impl(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
use std::marker::PhantomData;
|
||||
use std::{
|
||||
hash::{DefaultHasher, Hasher},
|
||||
marker::PhantomData,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
layouts::{Backend, Data, DataMut, DataRef, DataView, DataViewMut, ZnxInfos, ZnxView},
|
||||
layouts::{Backend, Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, ZnxInfos, ZnxView},
|
||||
oep::VmpPMatAllocBytesImpl,
|
||||
};
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
#[repr(C)]
|
||||
#[derive(PartialEq, Eq, Hash)]
|
||||
pub struct VmpPMat<D: Data, B: Backend> {
|
||||
data: D,
|
||||
n: usize,
|
||||
@@ -17,6 +21,19 @@ pub struct VmpPMat<D: Data, B: Backend> {
|
||||
_phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<D: DataRef, B: Backend> DigestU64 for VmpPMat<D, B> {
|
||||
fn digest_u64(&self) -> u64 {
|
||||
let mut h: DefaultHasher = DefaultHasher::new();
|
||||
h.write(self.data.as_ref());
|
||||
h.write_usize(self.n);
|
||||
h.write_usize(self.size);
|
||||
h.write_usize(self.rows);
|
||||
h.write_usize(self.cols_in);
|
||||
h.write_usize(self.cols_out);
|
||||
h.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef, B: Backend> ZnxView for VmpPMat<D, B> {
|
||||
type Scalar = B::ScalarPrep;
|
||||
}
|
||||
@@ -37,6 +54,10 @@ impl<D: Data, B: Backend> ZnxInfos for VmpPMat<D, B> {
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
|
||||
fn poly_count(&self) -> usize {
|
||||
self.rows() * self.cols_in() * self.size() * self.cols_out()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> DataView for VmpPMat<D, B> {
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
use std::fmt;
|
||||
use std::{
|
||||
fmt,
|
||||
hash::{DefaultHasher, Hasher},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
layouts::{
|
||||
Data, DataMut, DataRef, DataView, DataViewMut, FillUniform, ReaderFrom, Reset, ToOwnedDeep, WriterTo, ZnxInfos,
|
||||
ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
|
||||
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, Reset, ToOwnedDeep, WriterTo,
|
||||
ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
@@ -12,7 +15,8 @@ use crate::{
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
use rand::RngCore;
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||
#[repr(C)]
|
||||
#[derive(PartialEq, Eq, Clone, Copy, Hash)]
|
||||
pub struct Zn<D: Data> {
|
||||
pub data: D,
|
||||
pub n: usize,
|
||||
@@ -21,6 +25,18 @@ pub struct Zn<D: Data> {
|
||||
pub max_size: usize,
|
||||
}
|
||||
|
||||
impl<D: DataRef> DigestU64 for Zn<D> {
|
||||
fn digest_u64(&self) -> u64 {
|
||||
let mut h: DefaultHasher = DefaultHasher::new();
|
||||
h.write(self.data.as_ref());
|
||||
h.write_usize(self.n);
|
||||
h.write_usize(self.cols);
|
||||
h.write_usize(self.size);
|
||||
h.write_usize(self.max_size);
|
||||
h.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> ToOwnedDeep for Zn<D> {
|
||||
type Owned = Zn<Vec<u8>>;
|
||||
fn to_owned_deep(&self) -> Self::Owned {
|
||||
@@ -173,8 +189,18 @@ impl<D: DataRef> fmt::Display for Zn<D> {
|
||||
}
|
||||
|
||||
impl<D: DataMut> FillUniform for Zn<D> {
|
||||
fn fill_uniform(&mut self, source: &mut Source) {
|
||||
source.fill_bytes(self.data.as_mut());
|
||||
fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
|
||||
match log_bound {
|
||||
64 => source.fill_bytes(self.data.as_mut()),
|
||||
0 => panic!("invalid log_bound, cannot be zero"),
|
||||
_ => {
|
||||
let mask: u64 = (1u64 << log_bound) - 1;
|
||||
for x in self.raw_mut().iter_mut() {
|
||||
let r = source.next_u64() & mask;
|
||||
*x = ((r << (64 - log_bound)) as i64) >> (64 - log_bound);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -117,7 +117,7 @@ where
|
||||
}
|
||||
|
||||
pub trait FillUniform {
|
||||
fn fill_uniform(&mut self, source: &mut Source);
|
||||
fn fill_uniform(&mut self, log_bound: usize, source: &mut Source);
|
||||
}
|
||||
|
||||
pub trait Reset {
|
||||
|
||||
@@ -4,11 +4,13 @@
|
||||
#![feature(trait_alias)]
|
||||
|
||||
pub mod api;
|
||||
pub mod bench_suite;
|
||||
pub mod delegates;
|
||||
pub mod layouts;
|
||||
pub mod oep;
|
||||
pub mod reference;
|
||||
pub mod source;
|
||||
pub mod tests;
|
||||
pub mod test_suite;
|
||||
|
||||
pub mod doc {
|
||||
#[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/backend_safety_contract.md"))]
|
||||
@@ -85,13 +87,20 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
|
||||
/// Allocates a block of T aligned with [DEFAULTALIGN].
|
||||
/// Size of T * size must be a multiple of [DEFAULTALIGN].
|
||||
pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
|
||||
assert_eq!(
|
||||
(size * size_of::<T>()) % (align / size_of::<T>()),
|
||||
0,
|
||||
"size={} must be a multiple of align={}",
|
||||
size,
|
||||
assert!(
|
||||
align.is_power_of_two(),
|
||||
"Alignment must be a power of two but is {}",
|
||||
align
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
(size * size_of::<T>()) % align,
|
||||
0,
|
||||
"size*size_of::<T>()={} must be a multiple of align={}",
|
||||
size * size_of::<T>(),
|
||||
align
|
||||
);
|
||||
|
||||
let mut vec_u8: Vec<u8> = alloc_aligned_custom_u8(size_of::<T>() * size, align);
|
||||
let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T;
|
||||
let len: usize = vec_u8.len() / size_of::<T>();
|
||||
@@ -100,11 +109,11 @@ pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
|
||||
unsafe { Vec::from_raw_parts(ptr, len, cap) }
|
||||
}
|
||||
|
||||
/// Allocates an aligned vector of size equal to the smallest multiple
|
||||
/// of [DEFAULTALIGN]/`size_of::<T>`() that is equal or greater to `size`.
|
||||
/// Allocates an aligned vector of the given size.
|
||||
/// Pads so that its size in [u8] is a multiple of [DEFAULTALIGN].
|
||||
pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
|
||||
alloc_aligned_custom::<T>(
|
||||
size + (DEFAULTALIGN - (size % (DEFAULTALIGN / size_of::<T>()))) % DEFAULTALIGN,
|
||||
(size * size_of::<T>()).next_multiple_of(DEFAULTALIGN) / size_of::<T>(),
|
||||
DEFAULTALIGN,
|
||||
)
|
||||
}
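// Worked example (illustrative; assumes DEFAULTALIGN = 64 bytes, which is not
// confirmed here): for T = i64 (8 bytes) a request of 5 elements occupies 40
// bytes, is padded up to 64 bytes, and therefore yields 64 / 8 = 8 elements;
// a request of 8 elements (64 bytes) is already a multiple and stays at 8.
// `padded_len` is a hypothetical helper mirroring the arithmetic above.
fn padded_len(size: usize, elem_bytes: usize, align: usize) -> usize {
    (size * elem_bytes).next_multiple_of(align) / elem_bytes
}
// padded_len(5, 8, 64) == 8, padded_len(8, 8, 64) == 8, padded_len(9, 8, 64) == 16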
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
use crate::layouts::{Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef};
|
||||
use crate::layouts::{
|
||||
Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
|
||||
};
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO for reference code.
|
||||
@@ -39,9 +41,28 @@ pub unsafe trait SvpPrepareImpl<B: Backend> {
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait SvpApplyImpl<B: Backend> {
|
||||
fn svp_apply_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
pub unsafe trait SvpApplyDftImpl<B: Backend> {
|
||||
fn svp_apply_dft_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: SvpPPolToRef<B>,
|
||||
C: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait SvpApplyDftToDftImpl<B: Backend> {
|
||||
fn svp_apply_dft_to_dft_impl<R, A, C>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &C,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: SvpPPolToRef<B>,
|
||||
C: VecZnxDftToRef<B>;
|
||||
@@ -51,8 +72,27 @@ pub unsafe trait SvpApplyImpl<B: Backend> {
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait SvpApplyInplaceImpl: Backend {
|
||||
fn svp_apply_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
pub unsafe trait SvpApplyDftToDftAddImpl<B: Backend> {
|
||||
fn svp_apply_dft_to_dft_add_impl<R, A, C>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &C,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: SvpPPolToRef<B>,
|
||||
C: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait SvpApplyDftToDftInplaceImpl: Backend {
|
||||
fn svp_apply_dft_to_dft_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: SvpPPolToRef<Self>;
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
use rand_distr::Distribution;
|
||||
|
||||
use crate::{
|
||||
layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
|
||||
source::Source,
|
||||
@@ -64,6 +62,28 @@ pub unsafe trait VecZnxAddInplaceImpl<B: Backend> {
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code.
|
||||
/// * See [crate::api::VecZnxAddScalar] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxAddScalarImpl<D: Backend> {
|
||||
/// Adds the selected column of `a` to the selected column and limb of `b` and writes the result on the selected column of `res`.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn vec_znx_add_scalar_impl<R, A, B>(
|
||||
module: &Module<D>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
b_limb: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
B: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code.
|
||||
/// * See [crate::api::VecZnxAddScalarInplace] for corresponding public API.
|
||||
@@ -115,6 +135,28 @@ pub unsafe trait VecZnxSubBAInplaceImpl<B: Backend> {
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO.
|
||||
/// * See [crate::api::VecZnxAddScalar] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxSubScalarImpl<D: Backend> {
|
||||
/// Subtracts the selected column of `a` from the selected column and limb of `b` and writes the result on the selected column of `res`.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn vec_znx_sub_scalar_impl<R, A, B>(
|
||||
module: &Module<D>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
b_limb: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
B: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code.
|
||||
/// * See [crate::api::VecZnxSubScalarInplace] for corresponding public API.
|
||||
@@ -153,14 +195,76 @@ pub unsafe trait VecZnxNegateInplaceImpl<B: Backend> {
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::reference::vec_znx::shift::vec_znx_rsh_tmp_bytes] for reference code.
|
||||
/// * See [crate::api::VecZnxRshTmpBytes] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxRshTmpBytesImpl<B: Backend> {
|
||||
fn vec_znx_rsh_tmp_bytes_impl(module: &Module<B>) -> usize;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::reference::vec_znx::shift::vec_znx_rsh_inplace] for reference code.
|
||||
/// * See [crate::api::VecZnxRsh] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxRshImpl<B: Backend> {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn vec_znx_rsh_inplace_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::reference::vec_znx::shift::vec_znx_lsh_tmp_bytes] for reference code.
|
||||
/// * See [crate::api::VecZnxLshTmpBytes] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxLshTmpBytesImpl<B: Backend> {
|
||||
fn vec_znx_lsh_tmp_bytes_impl(module: &Module<B>) -> usize;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::reference::vec_znx::shift::vec_znx_lsh_inplace] for reference code.
|
||||
/// * See [crate::api::VecZnxLsh] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxLshImpl<B: Backend> {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn vec_znx_lsh_inplace_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::cpu_spqlios::vec_znx::vec_znx_rsh_inplace_ref] for reference code.
|
||||
/// * See [crate::api::VecZnxRshInplace] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxRshInplaceImpl<B: Backend> {
|
||||
fn vec_znx_rsh_inplace_impl<A>(module: &Module<B>, basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
fn vec_znx_rsh_inplace_impl<R>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
@@ -168,9 +272,15 @@ pub unsafe trait VecZnxRshInplaceImpl<B: Backend> {
|
||||
/// * See [crate::api::VecZnxLshInplace] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxLshInplaceImpl<B: Backend> {
|
||||
fn vec_znx_lsh_inplace_impl<A>(module: &Module<B>, basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
fn vec_znx_lsh_inplace_impl<R>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
@@ -184,12 +294,20 @@ pub unsafe trait VecZnxRotateImpl<B: Backend> {
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO;
|
||||
/// * See [crate::api::VecZnxRotateInplaceTmpBytes] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxRotateInplaceTmpBytesImpl<B: Backend> {
|
||||
fn vec_znx_rotate_inplace_tmp_bytes_impl(module: &Module<B>) -> usize;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_rotate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L164) for reference code.
|
||||
/// * See [crate::api::VecZnxRotateInplace] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxRotateInplaceImpl<B: Backend> {
|
||||
fn vec_znx_rotate_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
|
||||
fn vec_znx_rotate_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
@@ -199,20 +317,28 @@ pub unsafe trait VecZnxRotateInplaceImpl<B: Backend> {
|
||||
/// * See [crate::api::VecZnxAutomorphism] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxAutomorphismImpl<B: Backend> {
|
||||
fn vec_znx_automorphism_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
fn vec_znx_automorphism_impl<R, A>(module: &Module<B>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO;
|
||||
/// * See [crate::api::VecZnxAutomorphismInplaceTmpBytes] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxAutomorphismInplaceTmpBytesImpl<B: Backend> {
|
||||
fn vec_znx_automorphism_inplace_tmp_bytes_impl(module: &Module<B>) -> usize;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_automorphism_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L188) for reference code.
|
||||
/// * See [crate::api::VecZnxAutomorphismInplace] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxAutomorphismInplaceImpl<B: Backend> {
|
||||
fn vec_znx_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
|
||||
fn vec_znx_automorphism_inplace_impl<R>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
@@ -226,34 +352,75 @@ pub unsafe trait VecZnxMulXpMinusOneImpl<B: Backend> {
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO;
|
||||
/// * See [crate::api::VecZnxMulXpMinusOneInplaceTmpBytes] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxMulXpMinusOneInplaceTmpBytesImpl<B: Backend> {
|
||||
fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes_impl(module: &Module<B>) -> usize;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_mul_xp_minus_one_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/7160f588da49712a042931ea247b4259b95cefcc/spqlios/arithmetic/vec_znx.c#L200C13-L200C41) for reference code.
|
||||
/// * See [crate::api::VecZnxMulXpMinusOneInplace] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxMulXpMinusOneInplaceImpl<B: Backend> {
|
||||
fn vec_znx_mul_xp_minus_one_inplace_impl<R>(module: &Module<B>, p: i64, res: &mut R, res_col: usize)
|
||||
where
|
||||
fn vec_znx_mul_xp_minus_one_inplace_impl<R>(
|
||||
module: &Module<B>,
|
||||
p: i64,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::cpu_spqlios::vec_znx::vec_znx_split_ref] for reference code.
|
||||
/// * See [crate::api::VecZnxSplit] for corresponding public API.
|
||||
/// * See TODO;
|
||||
/// * See [crate::api::VecZnxSplitRingTmpBytes] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxSplitImpl<B: Backend> {
|
||||
fn vec_znx_split_impl<R, A>(module: &Module<B>, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
pub unsafe trait VecZnxSplitRingTmpBytesImpl<B: Backend> {
|
||||
fn vec_znx_split_ring_tmp_bytes_impl(module: &Module<B>) -> usize;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::cpu_spqlios::vec_znx::vec_znx_split_ref] for reference code.
|
||||
/// * See [crate::api::VecZnxSplitRing] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxSplitRingImpl<B: Backend> {
|
||||
fn vec_znx_split_ring_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
res: &mut [R],
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO;
|
||||
/// * See [crate::api::VecZnxMergeRingsTmpBytes] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxMergeRingsTmpBytesImpl<B: Backend> {
|
||||
fn vec_znx_merge_rings_tmp_bytes_impl(module: &Module<B>) -> usize;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::cpu_spqlios::vec_znx::vec_znx_merge_ref] for reference code.
|
||||
/// * See [crate::api::VecZnxMerge] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxMergeImpl<B: Backend> {
|
||||
fn vec_znx_merge_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &[A], a_col: usize)
|
||||
where
|
||||
pub unsafe trait VecZnxMergeRingsImpl<B: Backend> {
|
||||
fn vec_znx_merge_rings_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &[A],
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
@@ -262,8 +429,8 @@ pub unsafe trait VecZnxMergeImpl<B: Backend> {
|
||||
/// * See [crate::cpu_spqlios::vec_znx::vec_znx_switch_degree_ref] for reference code.
|
||||
/// * See [crate::api::VecZnxSwithcDegree] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxSwithcDegreeImpl<B: Backend> {
|
||||
fn vec_znx_switch_degree_impl<R: VecZnxToMut, A: VecZnxToRef>(
|
||||
pub unsafe trait VecZnxSwitchRingImpl<B: Backend> {
|
||||
fn vec_znx_switch_ring_impl<R: VecZnxToMut, A: VecZnxToRef>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
@@ -287,47 +454,11 @@ pub unsafe trait VecZnxCopyImpl<B: Backend> {
|
||||
/// * See [crate::api::VecZnxFillUniform] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxFillUniformImpl<B: Backend> {
|
||||
fn vec_znx_fill_uniform_impl<R>(module: &Module<B>, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
|
||||
fn vec_znx_fill_uniform_impl<R>(module: &Module<B>, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
|
||||
where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::api::VecZnxFillDistF64] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxFillDistF64Impl<B: Backend> {
|
||||
fn vec_znx_fill_dist_f64_impl<R, D: Distribution<f64>>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::api::VecZnxAddDistF64] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxAddDistF64Impl<B: Backend> {
|
||||
fn vec_znx_add_dist_f64_impl<R, D: Distribution<f64>>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::api::VecZnxFillNormal] for corresponding public API.
|
||||
|
||||
@@ -1,10 +1,19 @@
|
||||
use rand_distr::Distribution;
|
||||
|
||||
use crate::{
|
||||
layouts::{Backend, Module, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxBigFromSmallImpl<B: Backend> {
|
||||
fn vec_znx_big_from_small_impl<R, A>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
@@ -47,60 +56,6 @@ pub unsafe trait VecZnxBigAddNormalImpl<B: Backend> {
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxBigFillNormalImpl<B: Backend> {
|
||||
fn fill_normal_impl<R: VecZnxBigToMut<B>>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxBigFillDistF64Impl<B: Backend> {
|
||||
fn fill_dist_f64_impl<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxBigAddDistF64Impl<B: Backend> {
|
||||
fn add_dist_f64_impl<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
@@ -248,6 +203,17 @@ pub unsafe trait VecZnxBigSubSmallBInplaceImpl<B: Backend> {
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxBigNegateImpl<B: Backend> {
|
||||
fn vec_znx_big_negate_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
@@ -295,12 +261,20 @@ pub unsafe trait VecZnxBigAutomorphismImpl<B: Backend> {
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxBigAutomorphismInplaceTmpBytesImpl<B: Backend> {
|
||||
fn vec_znx_big_automorphism_inplace_tmp_bytes_impl(module: &Module<B>) -> usize;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxBigAutomorphismInplaceImpl<B: Backend> {
|
||||
fn vec_znx_big_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
|
||||
fn vec_znx_big_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
A: VecZnxBigToMut<B>;
|
||||
}
|
||||
|
||||
@@ -23,9 +23,16 @@ pub unsafe trait VecZnxDftFromBytesImpl<B: Backend> {
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait DFTImpl<B: Backend> {
|
||||
fn dft_impl<R, A>(module: &Module<B>, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
pub unsafe trait VecZnxDftApplyImpl<B: Backend> {
|
||||
fn vec_znx_dft_apply_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
step: usize,
|
||||
offset: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
@@ -42,17 +49,23 @@ pub unsafe trait VecZnxDftAllocBytesImpl<B: Backend> {
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxIDFTTmpBytesImpl<B: Backend> {
|
||||
fn vec_znx_idft_tmp_bytes_impl(module: &Module<B>) -> usize;
|
||||
pub unsafe trait VecZnxIdftApplyTmpBytesImpl<B: Backend> {
|
||||
fn vec_znx_idft_apply_tmp_bytes_impl(module: &Module<B>) -> usize;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait IDFTImpl<B: Backend> {
|
||||
fn idft_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
pub unsafe trait VecZnxIdftApplyImpl<B: Backend> {
|
||||
fn vec_znx_idft_apply_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
}
|
||||
@@ -61,8 +74,8 @@ pub unsafe trait IDFTImpl<B: Backend> {
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait IDFTTmpAImpl<B: Backend> {
|
||||
fn idft_tmp_a_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
pub unsafe trait VecZnxIdftApplyTmpAImpl<B: Backend> {
|
||||
fn vec_znx_idft_apply_tmpa_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxDftToMut<B>;
|
||||
@@ -72,8 +85,8 @@ pub unsafe trait IDFTTmpAImpl<B: Backend> {
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait IDFTConsumeImpl<B: Backend> {
|
||||
fn idft_consume_impl<D: Data>(module: &Module<B>, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
|
||||
pub unsafe trait VecZnxIdftApplyConsumeImpl<B: Backend> {
|
||||
fn vec_znx_idft_apply_consume_impl<D: Data>(module: &Module<B>, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
|
||||
where
|
||||
VecZnxDft<D, B>: VecZnxDftToMut<B>;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::layouts::{
|
||||
Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef,
|
||||
Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef,
|
||||
};
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
@@ -45,13 +45,42 @@ pub unsafe trait VmpPrepareTmpBytesImpl<B: Backend> {
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VmpPMatPrepareImpl<B: Backend> {
|
||||
pub unsafe trait VmpPrepareImpl<B: Backend> {
|
||||
fn vmp_prepare_impl<R, A>(module: &Module<B>, res: &mut R, a: &A, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VmpPMatToMut<B>,
|
||||
A: MatZnxToRef;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VmpApplyDftTmpBytesImpl<B: Backend> {
|
||||
fn vmp_apply_dft_tmp_bytes_impl(
|
||||
module: &Module<B>,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO for reference code.
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VmpApplyDftImpl<B: Backend> {
|
||||
fn vmp_apply_dft_impl<R, A, C>(module: &Module<B>, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxToRef,
|
||||
C: VmpPMatToRef<B>;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO for reference code.
|
||||
|
||||
@@ -1,65 +1,35 @@
|
||||
use rand_distr::Distribution;
|
||||
|
||||
use crate::{
|
||||
layouts::{Backend, Scratch, ZnToMut},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See TODO
|
||||
/// * See [crate::api::ZnNormalizeTmpBytes] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait ZnNormalizeTmpBytesImpl<B: Backend> {
|
||||
fn zn_normalize_tmp_bytes_impl(n: usize) -> usize;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [zn_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/zn64.c#L9) for reference code.
|
||||
/// * See [crate::api::ZnxNormalizeInplace] for corresponding public API.
|
||||
/// * See [crate::api::ZnNormalizeInplace] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait ZnNormalizeInplaceImpl<B: Backend> {
|
||||
fn zn_normalize_inplace_impl<A>(n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
fn zn_normalize_inplace_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
A: ZnToMut;
|
||||
R: ZnToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::api::ZnFillUniform] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait ZnFillUniformImpl<B: Backend> {
|
||||
fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
|
||||
fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
|
||||
where
|
||||
R: ZnToMut;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::api::ZnFillDistF64] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait ZnFillDistF64Impl<B: Backend> {
|
||||
fn zn_fill_dist_f64_impl<R, D: Distribution<f64>>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::api::ZnAddDistF64] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait ZnAddDistF64Impl<B: Backend> {
|
||||
fn zn_add_dist_f64_impl<R, D: Distribution<f64>>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::api::ZnFillNormal] for corresponding public API.
|
||||
|
||||
24
poulpy-hal/src/reference/fft64/mod.rs
Normal file
@@ -0,0 +1,24 @@
pub mod reim;
pub mod reim4;
pub mod svp;
pub mod vec_znx_big;
pub mod vec_znx_dft;
pub mod vmp;

pub(crate) fn assert_approx_eq_slice(a: &[f64], b: &[f64], tol: f64) {
    assert_eq!(a.len(), b.len(), "Slices have different lengths");

    for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
        let diff: f64 = (x - y).abs();
        let scale: f64 = x.abs().max(y.abs()).max(1.0);
        assert!(
            diff <= tol * scale,
            "Difference at index {}: left={} right={} rel_diff={} > tol={}",
            i,
            x,
            y,
            diff / scale,
            tol
        );
    }
}
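// Illustrative usage (not part of this diff): the check is relative, so a pair
// is accepted when its difference is within `tol` of the larger magnitude
// (clamped below by 1.0).
#[cfg(test)]
mod approx_eq_usage {
    use super::assert_approx_eq_slice;

    #[test]
    fn relative_tolerance_example() {
        assert_approx_eq_slice(&[1.0, 1e6], &[1.0 + 1e-12, 1e6 + 1.0], 1e-5);
    }
}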
|
||||
31
poulpy-hal/src/reference/fft64/reim/conversion.rs
Normal file
@@ -0,0 +1,31 @@
#[inline(always)]
pub fn reim_from_znx_i64_ref(res: &mut [f64], a: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len())
    }

    for i in 0..res.len() {
        res[i] = a[i] as f64
    }
}

#[inline(always)]
pub fn reim_to_znx_i64_ref(res: &mut [i64], divisor: f64, a: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len())
    }
    let inv_div = 1. / divisor;
    for i in 0..res.len() {
        res[i] = (a[i] * inv_div).round() as i64
    }
}

#[inline(always)]
pub fn reim_to_znx_i64_inplace_ref(res: &mut [f64], divisor: f64) {
    let inv_div = 1. / divisor;
    for ri in res {
        *ri = f64::from_bits(((*ri * inv_div).round() as i64) as u64)
    }
}
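// Illustrative roundtrip sketch (not part of this diff): with divisor 1.0 the
// i64 -> f64 -> i64 conversion above is exact as long as the coefficients fit
// in f64's 53-bit mantissa.
#[cfg(test)]
mod conversion_roundtrip {
    use super::{reim_from_znx_i64_ref, reim_to_znx_i64_ref};

    #[test]
    fn roundtrip_small_coeffs() {
        let a: [i64; 4] = [1, -2, 3, -4];
        let mut re = [0.0f64; 4];
        let mut back = [0i64; 4];
        reim_from_znx_i64_ref(&mut re, &a);
        reim_to_znx_i64_ref(&mut back, 1.0, &re);
        assert_eq!(a, back);
    }
}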
|
||||
327
poulpy-hal/src/reference/fft64/reim/fft_ref.rs
Normal file
@@ -0,0 +1,327 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use rand_distr::num_traits::{Float, FloatConst};
|
||||
|
||||
use crate::reference::fft64::reim::{as_arr, as_arr_mut};
|
||||
|
||||
#[inline(always)]
|
||||
pub fn fft_ref<R: Float + FloatConst + Debug>(m: usize, omg: &[R], data: &mut [R]) {
|
||||
assert!(data.len() == 2 * m);
|
||||
let (re, im) = data.split_at_mut(m);
|
||||
|
||||
if m <= 16 {
|
||||
match m {
|
||||
1 => {}
|
||||
2 => fft2_ref(
|
||||
as_arr_mut::<2, R>(re),
|
||||
as_arr_mut::<2, R>(im),
|
||||
*as_arr::<2, R>(omg),
|
||||
),
|
||||
4 => fft4_ref(
|
||||
as_arr_mut::<4, R>(re),
|
||||
as_arr_mut::<4, R>(im),
|
||||
*as_arr::<4, R>(omg),
|
||||
),
|
||||
8 => fft8_ref(
|
||||
as_arr_mut::<8, R>(re),
|
||||
as_arr_mut::<8, R>(im),
|
||||
*as_arr::<8, R>(omg),
|
||||
),
|
||||
16 => fft16_ref(
|
||||
as_arr_mut::<16, R>(re),
|
||||
as_arr_mut::<16, R>(im),
|
||||
*as_arr::<16, R>(omg),
|
||||
),
|
||||
_ => {}
|
||||
}
|
||||
} else if m <= 2048 {
|
||||
fft_bfs_16_ref(m, re, im, omg, 0);
|
||||
} else {
|
||||
fft_rec_16_ref(m, re, im, omg, 0);
|
||||
}
|
||||
}
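// Layout note (illustrative, not part of this diff): `data` holds the m real
// parts followed by the m imaginary parts (hence the split_at_mut(m) above),
// and `omg` is a precomputed twiddle table consumed sequentially in the same
// order as the recursion; how that table is generated is defined elsewhere in
// this module and is not shown in this hunk.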
|
||||
|
||||
#[inline(always)]
|
||||
fn fft_rec_16_ref<R: Float + FloatConst + Debug>(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize {
|
||||
if m <= 2048 {
|
||||
return fft_bfs_16_ref(m, re, im, omg, pos);
|
||||
};
|
||||
|
||||
let h = m >> 1;
|
||||
twiddle_fft_ref(h, re, im, as_arr::<2, R>(&omg[pos..]));
|
||||
pos += 2;
|
||||
pos = fft_rec_16_ref(h, re, im, omg, pos);
|
||||
pos = fft_rec_16_ref(h, &mut re[h..], &mut im[h..], omg, pos);
|
||||
pos
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn cplx_twiddle<R: Float + FloatConst>(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) {
|
||||
let dr: R = *rb * omg_re - *ib * omg_im;
|
||||
let di: R = *rb * omg_im + *ib * omg_re;
|
||||
*rb = *ra - dr;
|
||||
*ib = *ia - di;
|
||||
*ra = *ra + dr;
|
||||
*ia = *ia + di;
|
||||
}
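// Note (illustrative, not part of this diff): cplx_twiddle is a radix-2
// decimation-in-time butterfly on the complex pair (a, b) with twiddle
// w = omg_re + i*omg_im:
//     d = b * w,   b <- a - d,   a <- a + d.
// cplx_i_twiddle below is the same butterfly with w replaced by i*w, which is
// the form the split-radix kernels fft4/fft8/fft16 rely on.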
|
||||
|
||||
#[inline(always)]
|
||||
fn cplx_i_twiddle<R: Float + FloatConst>(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) {
|
||||
let dr: R = *rb * omg_im + *ib * omg_re;
|
||||
let di: R = *rb * omg_re - *ib * omg_im;
|
||||
*rb = *ra + dr;
|
||||
*ib = *ia - di;
|
||||
*ra = *ra - dr;
|
||||
*ia = *ia + di;
|
||||
}
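cplx_twiddle is the decimation-in-frequency butterfly (a, b) -> (a + w*b, a - w*b), and cplx_i_twiddle is the same butterfly with w replaced by i*w. A self-contained numeric sketch of the identity it implements, using plain f64 pairs and no crate types:

fn main() {
    // a, b and the twiddle w as (re, im) pairs.
    let (ra, ia) = (1.0_f64, 2.0_f64);
    let (rb, ib) = (0.5_f64, -1.5_f64);
    let (wr, wi) = (0.6_f64, 0.8_f64);

    // w * b
    let dr = rb * wr - ib * wi;
    let di = rb * wi + ib * wr;

    // The butterfly writes a' = a + w*b into the "a" slot and b' = a - w*b into the
    // "b" slot, which is exactly what cplx_twiddle does; cplx_i_twiddle uses i*w instead.
    let (ra2, ia2) = (ra + dr, ia + di);
    let (rb2, ib2) = (ra - dr, ia - di);
    println!("a' = {ra2} + {ia2}i, b' = {rb2} + {ib2}i");
}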
|
||||
|
||||
#[inline(always)]
|
||||
fn fft2_ref<R: Float + FloatConst>(re: &mut [R; 2], im: &mut [R; 2], omg: [R; 2]) {
|
||||
let [ra, rb] = re;
|
||||
let [ia, ib] = im;
|
||||
let [romg, iomg] = omg;
|
||||
cplx_twiddle(ra, ia, rb, ib, romg, iomg);
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn fft4_ref<R: Float + FloatConst>(re: &mut [R; 4], im: &mut [R; 4], omg: [R; 4]) {
|
||||
let [re_0, re_1, re_2, re_3] = re;
|
||||
let [im_0, im_1, im_2, im_3] = im;
|
||||
|
||||
{
|
||||
let omg_0 = omg[0];
|
||||
let omg_1 = omg[1];
|
||||
cplx_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1);
|
||||
cplx_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1);
|
||||
}
|
||||
|
||||
{
|
||||
let omg_0 = omg[2];
|
||||
let omg_1 = omg[3];
|
||||
cplx_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_1);
|
||||
cplx_i_twiddle(re_2, im_2, re_3, im_3, omg_0, omg_1);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn fft8_ref<R: Float + FloatConst>(re: &mut [R; 8], im: &mut [R; 8], omg: [R; 8]) {
|
||||
let [re_0, re_1, re_2, re_3, re_4, re_5, re_6, re_7] = re;
|
||||
let [im_0, im_1, im_2, im_3, im_4, im_5, im_6, im_7] = im;
|
||||
|
||||
{
|
||||
let omg_0 = omg[0];
|
||||
let omg_1 = omg[1];
|
||||
cplx_twiddle(re_0, im_0, re_4, im_4, omg_0, omg_1);
|
||||
cplx_twiddle(re_1, im_1, re_5, im_5, omg_0, omg_1);
|
||||
cplx_twiddle(re_2, im_2, re_6, im_6, omg_0, omg_1);
|
||||
cplx_twiddle(re_3, im_3, re_7, im_7, omg_0, omg_1);
|
||||
}
|
||||
|
||||
{
|
||||
let omg_2 = omg[2];
|
||||
let omg_3 = omg[3];
|
||||
cplx_twiddle(re_0, im_0, re_2, im_2, omg_2, omg_3);
|
||||
cplx_twiddle(re_1, im_1, re_3, im_3, omg_2, omg_3);
|
||||
cplx_i_twiddle(re_4, im_4, re_6, im_6, omg_2, omg_3);
|
||||
cplx_i_twiddle(re_5, im_5, re_7, im_7, omg_2, omg_3);
|
||||
}
|
||||
|
||||
{
|
||||
let omg_4 = omg[4];
|
||||
let omg_5 = omg[5];
|
||||
let omg_6 = omg[6];
|
||||
let omg_7 = omg[7];
|
||||
cplx_twiddle(re_0, im_0, re_1, im_1, omg_4, omg_6);
|
||||
cplx_i_twiddle(re_2, im_2, re_3, im_3, omg_4, omg_6);
|
||||
cplx_twiddle(re_4, im_4, re_5, im_5, omg_5, omg_7);
|
||||
cplx_i_twiddle(re_6, im_6, re_7, im_7, omg_5, omg_7);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn fft16_ref<R: Float + FloatConst + Debug>(re: &mut [R; 16], im: &mut [R; 16], omg: [R; 16]) {
|
||||
let [
|
||||
re_0,
|
||||
re_1,
|
||||
re_2,
|
||||
re_3,
|
||||
re_4,
|
||||
re_5,
|
||||
re_6,
|
||||
re_7,
|
||||
re_8,
|
||||
re_9,
|
||||
re_10,
|
||||
re_11,
|
||||
re_12,
|
||||
re_13,
|
||||
re_14,
|
||||
re_15,
|
||||
] = re;
|
||||
let [
|
||||
im_0,
|
||||
im_1,
|
||||
im_2,
|
||||
im_3,
|
||||
im_4,
|
||||
im_5,
|
||||
im_6,
|
||||
im_7,
|
||||
im_8,
|
||||
im_9,
|
||||
im_10,
|
||||
im_11,
|
||||
im_12,
|
||||
im_13,
|
||||
im_14,
|
||||
im_15,
|
||||
] = im;
|
||||
|
||||
{
|
||||
let omg_0: R = omg[0];
|
||||
let omg_1: R = omg[1];
|
||||
cplx_twiddle(re_0, im_0, re_8, im_8, omg_0, omg_1);
|
||||
cplx_twiddle(re_1, im_1, re_9, im_9, omg_0, omg_1);
|
||||
cplx_twiddle(re_2, im_2, re_10, im_10, omg_0, omg_1);
|
||||
cplx_twiddle(re_3, im_3, re_11, im_11, omg_0, omg_1);
|
||||
|
||||
cplx_twiddle(re_4, im_4, re_12, im_12, omg_0, omg_1);
|
||||
cplx_twiddle(re_5, im_5, re_13, im_13, omg_0, omg_1);
|
||||
cplx_twiddle(re_6, im_6, re_14, im_14, omg_0, omg_1);
|
||||
cplx_twiddle(re_7, im_7, re_15, im_15, omg_0, omg_1);
|
||||
}
|
||||
|
||||
{
|
||||
let omg_2: R = omg[2];
|
||||
let omg_3: R = omg[3];
|
||||
cplx_twiddle(re_0, im_0, re_4, im_4, omg_2, omg_3);
|
||||
cplx_twiddle(re_1, im_1, re_5, im_5, omg_2, omg_3);
|
||||
cplx_twiddle(re_2, im_2, re_6, im_6, omg_2, omg_3);
|
||||
cplx_twiddle(re_3, im_3, re_7, im_7, omg_2, omg_3);
|
||||
|
||||
cplx_i_twiddle(re_8, im_8, re_12, im_12, omg_2, omg_3);
|
||||
cplx_i_twiddle(re_9, im_9, re_13, im_13, omg_2, omg_3);
|
||||
cplx_i_twiddle(re_10, im_10, re_14, im_14, omg_2, omg_3);
|
||||
cplx_i_twiddle(re_11, im_11, re_15, im_15, omg_2, omg_3);
|
||||
}
|
||||
|
||||
{
|
||||
let omg_0: R = omg[4];
|
||||
let omg_1: R = omg[5];
|
||||
let omg_2: R = omg[6];
|
||||
let omg_3: R = omg[7];
|
||||
cplx_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1);
|
||||
cplx_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1);
|
||||
cplx_twiddle(re_8, im_8, re_10, im_10, omg_2, omg_3);
|
||||
cplx_twiddle(re_9, im_9, re_11, im_11, omg_2, omg_3);
|
||||
|
||||
cplx_i_twiddle(re_4, im_4, re_6, im_6, omg_0, omg_1);
|
||||
cplx_i_twiddle(re_5, im_5, re_7, im_7, omg_0, omg_1);
|
||||
cplx_i_twiddle(re_12, im_12, re_14, im_14, omg_2, omg_3);
|
||||
cplx_i_twiddle(re_13, im_13, re_15, im_15, omg_2, omg_3);
|
||||
}
|
||||
|
||||
{
|
||||
let omg_0: R = omg[8];
|
||||
let omg_1: R = omg[9];
|
||||
let omg_2: R = omg[10];
|
||||
let omg_3: R = omg[11];
|
||||
let omg_4: R = omg[12];
|
||||
let omg_5: R = omg[13];
|
||||
let omg_6: R = omg[14];
|
||||
let omg_7: R = omg[15];
|
||||
cplx_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_4);
|
||||
cplx_twiddle(re_4, im_4, re_5, im_5, omg_1, omg_5);
|
||||
cplx_twiddle(re_8, im_8, re_9, im_9, omg_2, omg_6);
|
||||
cplx_twiddle(re_12, im_12, re_13, im_13, omg_3, omg_7);
|
||||
|
||||
cplx_i_twiddle(re_2, im_2, re_3, im_3, omg_0, omg_4);
|
||||
cplx_i_twiddle(re_6, im_6, re_7, im_7, omg_1, omg_5);
|
||||
cplx_i_twiddle(re_10, im_10, re_11, im_11, omg_2, omg_6);
|
||||
cplx_i_twiddle(re_14, im_14, re_15, im_15, omg_3, omg_7);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn fft_bfs_16_ref<R: Float + FloatConst + Debug>(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize {
|
||||
let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
|
||||
let mut mm: usize = m;
|
||||
|
||||
if !log_m.is_multiple_of(2) {
|
||||
let h: usize = mm >> 1;
|
||||
twiddle_fft_ref(h, re, im, as_arr::<2, R>(&omg[pos..]));
|
||||
pos += 2;
|
||||
mm = h
|
||||
}
|
||||
|
||||
while mm > 16 {
|
||||
let h: usize = mm >> 2;
|
||||
for off in (0..m).step_by(mm) {
|
||||
bitwiddle_fft_ref(
|
||||
h,
|
||||
&mut re[off..],
|
||||
&mut im[off..],
|
||||
as_arr::<4, R>(&omg[pos..]),
|
||||
);
|
||||
pos += 4;
|
||||
}
|
||||
mm = h
|
||||
}
|
||||
|
||||
for off in (0..m).step_by(16) {
|
||||
fft16_ref(
|
||||
as_arr_mut::<16, R>(&mut re[off..]),
|
||||
as_arr_mut::<16, R>(&mut im[off..]),
|
||||
*as_arr::<16, R>(&omg[pos..]),
|
||||
);
|
||||
pos += 16;
|
||||
}
|
||||
|
||||
pos
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn twiddle_fft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 2]) {
|
||||
let romg = omg[0];
|
||||
let iomg = omg[1];
|
||||
|
||||
let (re_lhs, re_rhs) = re.split_at_mut(h);
|
||||
let (im_lhs, im_rhs) = im.split_at_mut(h);
|
||||
|
||||
for i in 0..h {
|
||||
cplx_twiddle(
|
||||
&mut re_lhs[i],
|
||||
&mut im_lhs[i],
|
||||
&mut re_rhs[i],
|
||||
&mut im_rhs[i],
|
||||
romg,
|
||||
iomg,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn bitwiddle_fft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 4]) {
|
||||
let (r0, r2) = re.split_at_mut(2 * h);
|
||||
let (r0, r1) = r0.split_at_mut(h);
|
||||
let (r2, r3) = r2.split_at_mut(h);
|
||||
|
||||
let (i0, i2) = im.split_at_mut(2 * h);
|
||||
let (i0, i1) = i0.split_at_mut(h);
|
||||
let (i2, i3) = i2.split_at_mut(h);
|
||||
|
||||
let omg_0: R = omg[0];
|
||||
let omg_1: R = omg[1];
|
||||
let omg_2: R = omg[2];
|
||||
let omg_3: R = omg[3];
|
||||
|
||||
for i in 0..h {
|
||||
cplx_twiddle(&mut r0[i], &mut i0[i], &mut r2[i], &mut i2[i], omg_0, omg_1);
|
||||
cplx_twiddle(&mut r1[i], &mut i1[i], &mut r3[i], &mut i3[i], omg_0, omg_1);
|
||||
}
|
||||
|
||||
for i in 0..h {
|
||||
cplx_twiddle(&mut r0[i], &mut i0[i], &mut r1[i], &mut i1[i], omg_2, omg_3);
|
||||
cplx_i_twiddle(&mut r2[i], &mut i2[i], &mut r3[i], &mut i3[i], omg_2, omg_3);
|
||||
}
|
||||
}
156  poulpy-hal/src/reference/fft64/reim/fft_vec.rs  Normal file
@@ -0,0 +1,156 @@
|
||||
#[inline(always)]
|
||||
pub fn reim_add_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
assert_eq!(b.len(), res.len());
|
||||
}
|
||||
|
||||
for i in 0..res.len() {
|
||||
res[i] = a[i] + b[i]
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim_add_inplace_ref(res: &mut [f64], a: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
}
|
||||
|
||||
for i in 0..res.len() {
|
||||
res[i] += a[i]
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim_sub_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
assert_eq!(b.len(), res.len());
|
||||
}
|
||||
|
||||
for i in 0..res.len() {
|
||||
res[i] = a[i] - b[i]
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim_sub_ab_inplace_ref(res: &mut [f64], a: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
}
|
||||
|
||||
for i in 0..res.len() {
|
||||
res[i] -= a[i]
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim_sub_ba_inplace_ref(res: &mut [f64], a: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
}
|
||||
|
||||
for i in 0..res.len() {
|
||||
res[i] = a[i] - res[i]
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim_negate_ref(res: &mut [f64], a: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
}
|
||||
|
||||
for i in 0..res.len() {
|
||||
res[i] = -a[i]
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim_negate_inplace_ref(res: &mut [f64]) {
|
||||
for ri in res {
|
||||
*ri = -*ri
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim_addmul_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
assert_eq!(b.len(), res.len());
|
||||
}
|
||||
|
||||
let m: usize = res.len() >> 1;
|
||||
|
||||
let (rr, ri) = res.split_at_mut(m);
|
||||
let (ar, ai) = a.split_at(m);
|
||||
let (br, bi) = b.split_at(m);
|
||||
|
||||
for i in 0..m {
|
||||
let _ar: f64 = ar[i];
|
||||
let _ai: f64 = ai[i];
|
||||
let _br: f64 = br[i];
|
||||
let _bi: f64 = bi[i];
|
||||
let _rr: f64 = _ar * _br - _ai * _bi;
|
||||
let _ri: f64 = _ar * _bi + _ai * _br;
|
||||
rr[i] += _rr;
|
||||
ri[i] += _ri;
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim_mul_inplace_ref(res: &mut [f64], a: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
}
|
||||
|
||||
let m: usize = res.len() >> 1;
|
||||
|
||||
let (rr, ri) = res.split_at_mut(m);
|
||||
let (ar, ai) = a.split_at(m);
|
||||
|
||||
for i in 0..m {
|
||||
let _ar: f64 = ar[i];
|
||||
let _ai: f64 = ai[i];
|
||||
let _br: f64 = rr[i];
|
||||
let _bi: f64 = ri[i];
|
||||
let _rr: f64 = _ar * _br - _ai * _bi;
|
||||
let _ri: f64 = _ar * _bi + _ai * _br;
|
||||
rr[i] = _rr;
|
||||
ri[i] = _ri;
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim_mul_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
assert_eq!(b.len(), res.len());
|
||||
}
|
||||
|
||||
let m: usize = res.len() >> 1;
|
||||
|
||||
let (rr, ri) = res.split_at_mut(m);
|
||||
let (ar, ai) = a.split_at(m);
|
||||
let (br, bi) = b.split_at(m);
|
||||
|
||||
for i in 0..m {
|
||||
let _ar: f64 = ar[i];
|
||||
let _ai: f64 = ai[i];
|
||||
let _br: f64 = br[i];
|
||||
let _bi: f64 = bi[i];
|
||||
let _rr: f64 = _ar * _br - _ai * _bi;
|
||||
let _ri: f64 = _ar * _bi + _ai * _br;
|
||||
rr[i] = _rr;
|
||||
ri[i] = _ri;
|
||||
}
|
||||
}
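reim_mul_ref and its addmul/inplace variants treat a slice of length 2m as m complex numbers in split layout: the first m entries are the real parts, the last m the imaginary parts. A small usage sketch; the import path is assumed from this commit's module layout.

use poulpy_hal::reference::fft64::reim::reim_mul_ref; // path assumed

fn main() {
    // Two complex vectors of length m = 2 in split [re | im] layout:
    // a = (1 + 2i, 3 + 4i), b = (5 + 6i, 7 + 8i)
    let a = [1.0, 3.0, 2.0, 4.0];
    let b = [5.0, 7.0, 6.0, 8.0];
    let mut res = [0.0; 4];

    reim_mul_ref(&mut res, &a, &b);

    // Element-wise complex products: (-7 + 16i, -11 + 52i)
    assert_eq!(res, [-7.0, -11.0, 16.0, 52.0]);
}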
322  poulpy-hal/src/reference/fft64/reim/ifft_ref.rs  Normal file
@@ -0,0 +1,322 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use rand_distr::num_traits::{Float, FloatConst};
|
||||
|
||||
use crate::reference::fft64::reim::{as_arr, as_arr_mut};
|
||||
|
||||
pub fn ifft_ref<R: Float + FloatConst + Debug>(m: usize, omg: &[R], data: &mut [R]) {
|
||||
assert!(data.len() == 2 * m);
|
||||
let (re, im) = data.split_at_mut(m);
|
||||
|
||||
if m <= 16 {
|
||||
match m {
|
||||
1 => {}
|
||||
2 => ifft2_ref(
|
||||
as_arr_mut::<2, R>(re),
|
||||
as_arr_mut::<2, R>(im),
|
||||
*as_arr::<2, R>(omg),
|
||||
),
|
||||
4 => ifft4_ref(
|
||||
as_arr_mut::<4, R>(re),
|
||||
as_arr_mut::<4, R>(im),
|
||||
*as_arr::<4, R>(omg),
|
||||
),
|
||||
8 => ifft8_ref(
|
||||
as_arr_mut::<8, R>(re),
|
||||
as_arr_mut::<8, R>(im),
|
||||
*as_arr::<8, R>(omg),
|
||||
),
|
||||
16 => ifft16_ref(
|
||||
as_arr_mut::<16, R>(re),
|
||||
as_arr_mut::<16, R>(im),
|
||||
*as_arr::<16, R>(omg),
|
||||
),
|
||||
_ => {}
|
||||
}
|
||||
} else if m <= 2048 {
|
||||
ifft_bfs_16_ref(m, re, im, omg, 0);
|
||||
} else {
|
||||
ifft_rec_16_ref(m, re, im, omg, 0);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn ifft_rec_16_ref<R: Float + FloatConst>(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize {
|
||||
if m <= 2048 {
|
||||
return ifft_bfs_16_ref(m, re, im, omg, pos);
|
||||
};
|
||||
let h: usize = m >> 1;
|
||||
pos = ifft_rec_16_ref(h, re, im, omg, pos);
|
||||
pos = ifft_rec_16_ref(h, &mut re[h..], &mut im[h..], omg, pos);
|
||||
inv_twiddle_ifft_ref(h, re, im, as_arr::<2, R>(&omg[pos..]));
|
||||
pos += 2;
|
||||
pos
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn ifft_bfs_16_ref<R: Float + FloatConst>(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize {
|
||||
let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
|
||||
|
||||
for off in (0..m).step_by(16) {
|
||||
ifft16_ref(
|
||||
as_arr_mut::<16, R>(&mut re[off..]),
|
||||
as_arr_mut::<16, R>(&mut im[off..]),
|
||||
*as_arr::<16, R>(&omg[pos..]),
|
||||
);
|
||||
pos += 16;
|
||||
}
|
||||
|
||||
let mut h: usize = 16;
|
||||
let m_half: usize = m >> 1;
|
||||
|
||||
while h < m_half {
|
||||
let mm: usize = h << 2;
|
||||
for off in (0..m).step_by(mm) {
|
||||
inv_bitwiddle_ifft_ref(
|
||||
h,
|
||||
&mut re[off..],
|
||||
&mut im[off..],
|
||||
as_arr::<4, R>(&omg[pos..]),
|
||||
);
|
||||
pos += 4;
|
||||
}
|
||||
h = mm;
|
||||
}
|
||||
|
||||
if !log_m.is_multiple_of(2) {
|
||||
inv_twiddle_ifft_ref(h, re, im, as_arr::<2, R>(&omg[pos..]));
|
||||
pos += 2;
|
||||
}
|
||||
|
||||
pos
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn inv_twiddle<R: Float + FloatConst>(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) {
|
||||
let r_diff: R = *ra - *rb;
|
||||
let i_diff: R = *ia - *ib;
|
||||
*ra = *ra + *rb;
|
||||
*ia = *ia + *ib;
|
||||
*rb = r_diff * omg_re - i_diff * omg_im;
|
||||
*ib = r_diff * omg_im + i_diff * omg_re;
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn inv_itwiddle<R: Float + FloatConst>(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) {
|
||||
let r_diff: R = *ra - *rb;
|
||||
let i_diff: R = *ia - *ib;
|
||||
*ra = *ra + *rb;
|
||||
*ia = *ia + *ib;
|
||||
*rb = r_diff * omg_im + i_diff * omg_re;
|
||||
*ib = -r_diff * omg_re + i_diff * omg_im;
|
||||
}
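inv_twiddle is the inverse butterfly: applied with the conjugate twiddle to (a', b') = (a + w*b, a - w*b) it returns (a' + b', (a' - b') * conj(w)) = (2a, 2b), so every level doubles the data and the composition of fft_ref and ifft_ref scales by m rather than returning the input unchanged. A standalone numeric check of that identity with plain f64 pairs:

fn main() {
    let (ra, ia) = (0.25_f64, -1.0_f64); // a
    let (rb, ib) = (2.0_f64, 0.5_f64);   // b
    let (wr, wi) = (0.6_f64, 0.8_f64);   // unit twiddle w

    // Forward butterfly (cplx_twiddle): a' = a + w*b, b' = a - w*b.
    let dr = rb * wr - ib * wi;
    let di = rb * wi + ib * wr;
    let (ra1, ia1) = (ra + dr, ia + di);
    let (rb1, ib1) = (ra - dr, ia - di);

    // Inverse butterfly (inv_twiddle) with the conjugate twiddle (wr, -wi):
    // the sum returns to the "a" slot, the difference times conj(w) to the "b" slot.
    let (r_diff, i_diff) = (ra1 - rb1, ia1 - ib1);
    let (ra2, ia2) = (ra1 + rb1, ia1 + ib1);
    let (rb2, ib2) = (r_diff * wr - i_diff * (-wi), r_diff * (-wi) + i_diff * wr);

    // Each level doubles the data, hence the overall factor m after log2(m) levels.
    assert!((ra2 - 2.0 * ra).abs() < 1e-12 && (ia2 - 2.0 * ia).abs() < 1e-12);
    assert!((rb2 - 2.0 * rb).abs() < 1e-12 && (ib2 - 2.0 * ib).abs() < 1e-12);
}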
|
||||
|
||||
#[inline(always)]
|
||||
fn ifft2_ref<R: Float + FloatConst>(re: &mut [R; 2], im: &mut [R; 2], omg: [R; 2]) {
|
||||
let [ra, rb] = re;
|
||||
let [ia, ib] = im;
|
||||
let [romg, iomg] = omg;
|
||||
inv_twiddle(ra, ia, rb, ib, romg, iomg);
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn ifft4_ref<R: Float + FloatConst>(re: &mut [R; 4], im: &mut [R; 4], omg: [R; 4]) {
|
||||
let [re_0, re_1, re_2, re_3] = re;
|
||||
let [im_0, im_1, im_2, im_3] = im;
|
||||
|
||||
{
|
||||
let omg_0: R = omg[0];
|
||||
let omg_1: R = omg[1];
|
||||
|
||||
inv_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_1);
|
||||
inv_itwiddle(re_2, im_2, re_3, im_3, omg_0, omg_1);
|
||||
}
|
||||
|
||||
{
|
||||
let omg_0: R = omg[2];
|
||||
let omg_1: R = omg[3];
|
||||
inv_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1);
|
||||
inv_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn ifft8_ref<R: Float + FloatConst>(re: &mut [R; 8], im: &mut [R; 8], omg: [R; 8]) {
|
||||
let [re_0, re_1, re_2, re_3, re_4, re_5, re_6, re_7] = re;
|
||||
let [im_0, im_1, im_2, im_3, im_4, im_5, im_6, im_7] = im;
|
||||
|
||||
{
|
||||
let omg_4: R = omg[0];
|
||||
let omg_5: R = omg[1];
|
||||
let omg_6: R = omg[2];
|
||||
let omg_7: R = omg[3];
|
||||
inv_twiddle(re_0, im_0, re_1, im_1, omg_4, omg_6);
|
||||
inv_itwiddle(re_2, im_2, re_3, im_3, omg_4, omg_6);
|
||||
inv_twiddle(re_4, im_4, re_5, im_5, omg_5, omg_7);
|
||||
inv_itwiddle(re_6, im_6, re_7, im_7, omg_5, omg_7);
|
||||
}
|
||||
|
||||
{
|
||||
let omg_2: R = omg[4];
|
||||
let omg_3: R = omg[5];
|
||||
inv_twiddle(re_0, im_0, re_2, im_2, omg_2, omg_3);
|
||||
inv_twiddle(re_1, im_1, re_3, im_3, omg_2, omg_3);
|
||||
inv_itwiddle(re_4, im_4, re_6, im_6, omg_2, omg_3);
|
||||
inv_itwiddle(re_5, im_5, re_7, im_7, omg_2, omg_3);
|
||||
}
|
||||
|
||||
{
|
||||
let omg_0: R = omg[6];
|
||||
let omg_1: R = omg[7];
|
||||
inv_twiddle(re_0, im_0, re_4, im_4, omg_0, omg_1);
|
||||
inv_twiddle(re_1, im_1, re_5, im_5, omg_0, omg_1);
|
||||
inv_twiddle(re_2, im_2, re_6, im_6, omg_0, omg_1);
|
||||
inv_twiddle(re_3, im_3, re_7, im_7, omg_0, omg_1);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn ifft16_ref<R: Float + FloatConst>(re: &mut [R; 16], im: &mut [R; 16], omg: [R; 16]) {
|
||||
let [
|
||||
re_0,
|
||||
re_1,
|
||||
re_2,
|
||||
re_3,
|
||||
re_4,
|
||||
re_5,
|
||||
re_6,
|
||||
re_7,
|
||||
re_8,
|
||||
re_9,
|
||||
re_10,
|
||||
re_11,
|
||||
re_12,
|
||||
re_13,
|
||||
re_14,
|
||||
re_15,
|
||||
] = re;
|
||||
let [
|
||||
im_0,
|
||||
im_1,
|
||||
im_2,
|
||||
im_3,
|
||||
im_4,
|
||||
im_5,
|
||||
im_6,
|
||||
im_7,
|
||||
im_8,
|
||||
im_9,
|
||||
im_10,
|
||||
im_11,
|
||||
im_12,
|
||||
im_13,
|
||||
im_14,
|
||||
im_15,
|
||||
] = im;
|
||||
|
||||
{
|
||||
let omg_0: R = omg[0];
|
||||
let omg_1: R = omg[1];
|
||||
let omg_2: R = omg[2];
|
||||
let omg_3: R = omg[3];
|
||||
let omg_4: R = omg[4];
|
||||
let omg_5: R = omg[5];
|
||||
let omg_6: R = omg[6];
|
||||
let omg_7: R = omg[7];
|
||||
inv_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_4);
|
||||
inv_itwiddle(re_2, im_2, re_3, im_3, omg_0, omg_4);
|
||||
inv_twiddle(re_4, im_4, re_5, im_5, omg_1, omg_5);
|
||||
inv_itwiddle(re_6, im_6, re_7, im_7, omg_1, omg_5);
|
||||
inv_twiddle(re_8, im_8, re_9, im_9, omg_2, omg_6);
|
||||
inv_itwiddle(re_10, im_10, re_11, im_11, omg_2, omg_6);
|
||||
inv_twiddle(re_12, im_12, re_13, im_13, omg_3, omg_7);
|
||||
inv_itwiddle(re_14, im_14, re_15, im_15, omg_3, omg_7);
|
||||
}
|
||||
|
||||
{
|
||||
let omg_0: R = omg[8];
|
||||
let omg_1: R = omg[9];
|
||||
let omg_2: R = omg[10];
|
||||
let omg_3: R = omg[11];
|
||||
inv_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1);
|
||||
inv_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1);
|
||||
inv_itwiddle(re_4, im_4, re_6, im_6, omg_0, omg_1);
|
||||
inv_itwiddle(re_5, im_5, re_7, im_7, omg_0, omg_1);
|
||||
inv_twiddle(re_8, im_8, re_10, im_10, omg_2, omg_3);
|
||||
inv_twiddle(re_9, im_9, re_11, im_11, omg_2, omg_3);
|
||||
inv_itwiddle(re_12, im_12, re_14, im_14, omg_2, omg_3);
|
||||
inv_itwiddle(re_13, im_13, re_15, im_15, omg_2, omg_3);
|
||||
}
|
||||
|
||||
{
|
||||
let omg_2: R = omg[12];
|
||||
let omg_3: R = omg[13];
|
||||
inv_twiddle(re_0, im_0, re_4, im_4, omg_2, omg_3);
|
||||
inv_twiddle(re_1, im_1, re_5, im_5, omg_2, omg_3);
|
||||
inv_twiddle(re_2, im_2, re_6, im_6, omg_2, omg_3);
|
||||
inv_twiddle(re_3, im_3, re_7, im_7, omg_2, omg_3);
|
||||
inv_itwiddle(re_8, im_8, re_12, im_12, omg_2, omg_3);
|
||||
inv_itwiddle(re_9, im_9, re_13, im_13, omg_2, omg_3);
|
||||
inv_itwiddle(re_10, im_10, re_14, im_14, omg_2, omg_3);
|
||||
inv_itwiddle(re_11, im_11, re_15, im_15, omg_2, omg_3);
|
||||
}
|
||||
|
||||
{
|
||||
let omg_0: R = omg[14];
|
||||
let omg_1: R = omg[15];
|
||||
inv_twiddle(re_0, im_0, re_8, im_8, omg_0, omg_1);
|
||||
inv_twiddle(re_1, im_1, re_9, im_9, omg_0, omg_1);
|
||||
inv_twiddle(re_2, im_2, re_10, im_10, omg_0, omg_1);
|
||||
inv_twiddle(re_3, im_3, re_11, im_11, omg_0, omg_1);
|
||||
inv_twiddle(re_4, im_4, re_12, im_12, omg_0, omg_1);
|
||||
inv_twiddle(re_5, im_5, re_13, im_13, omg_0, omg_1);
|
||||
inv_twiddle(re_6, im_6, re_14, im_14, omg_0, omg_1);
|
||||
inv_twiddle(re_7, im_7, re_15, im_15, omg_0, omg_1);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn inv_twiddle_ifft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 2]) {
|
||||
let romg = omg[0];
|
||||
let iomg = omg[1];
|
||||
|
||||
let (re_lhs, re_rhs) = re.split_at_mut(h);
|
||||
let (im_lhs, im_rhs) = im.split_at_mut(h);
|
||||
|
||||
for i in 0..h {
|
||||
inv_twiddle(
|
||||
&mut re_lhs[i],
|
||||
&mut im_lhs[i],
|
||||
&mut re_rhs[i],
|
||||
&mut im_rhs[i],
|
||||
romg,
|
||||
iomg,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn inv_bitwiddle_ifft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 4]) {
|
||||
let (r0, r2) = re.split_at_mut(2 * h);
|
||||
let (r0, r1) = r0.split_at_mut(h);
|
||||
let (r2, r3) = r2.split_at_mut(h);
|
||||
|
||||
let (i0, i2) = im.split_at_mut(2 * h);
|
||||
let (i0, i1) = i0.split_at_mut(h);
|
||||
let (i2, i3) = i2.split_at_mut(h);
|
||||
|
||||
let omg_0: R = omg[0];
|
||||
let omg_1: R = omg[1];
|
||||
let omg_2: R = omg[2];
|
||||
let omg_3: R = omg[3];
|
||||
|
||||
for i in 0..h {
|
||||
inv_twiddle(&mut r0[i], &mut i0[i], &mut r1[i], &mut i1[i], omg_0, omg_1);
|
||||
inv_itwiddle(&mut r2[i], &mut i2[i], &mut r3[i], &mut i3[i], omg_0, omg_1);
|
||||
}
|
||||
|
||||
for i in 0..h {
|
||||
inv_twiddle(&mut r0[i], &mut i0[i], &mut r2[i], &mut i2[i], omg_2, omg_3);
|
||||
inv_twiddle(&mut r1[i], &mut i1[i], &mut r3[i], &mut i3[i], omg_2, omg_3);
|
||||
}
|
||||
}
128  poulpy-hal/src/reference/fft64/reim/mod.rs  Normal file
@@ -0,0 +1,128 @@
|
||||
// ----------------------------------------------------------------------
|
||||
// DISCLAIMER
|
||||
//
|
||||
// This module contains code that has been directly ported from the
|
||||
// spqlios-arithmetic library
|
||||
// (https://github.com/tfhe/spqlios-arithmetic), which is licensed
|
||||
// under the Apache License, Version 2.0.
|
||||
//
|
||||
// The porting process from C to Rust was done with minimal changes
|
||||
// in order to preserve the semantics and performance characteristics
|
||||
// of the original implementation.
|
||||
//
|
||||
// Both Poulpy and spqlios-arithmetic are distributed under the terms
|
||||
// of the Apache License, Version 2.0. See the LICENSE file for details.
|
||||
//
|
||||
// ----------------------------------------------------------------------
|
||||
|
||||
#![allow(bad_asm_style)]
|
||||
|
||||
mod conversion;
|
||||
mod fft_ref;
|
||||
mod fft_vec;
|
||||
mod ifft_ref;
|
||||
mod table_fft;
|
||||
mod table_ifft;
|
||||
mod zero;
|
||||
|
||||
pub use conversion::*;
|
||||
pub use fft_ref::*;
|
||||
pub use fft_vec::*;
|
||||
pub use ifft_ref::*;
|
||||
pub use table_fft::*;
|
||||
pub use table_ifft::*;
|
||||
pub use zero::*;
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn as_arr<const size: usize, R: Float + FloatConst>(x: &[R]) -> &[R; size] {
|
||||
debug_assert!(x.len() >= size);
|
||||
unsafe { &*(x.as_ptr() as *const [R; size]) }
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn as_arr_mut<const size: usize, R: Float + FloatConst>(x: &mut [R]) -> &mut [R; size] {
|
||||
debug_assert!(x.len() >= size);
|
||||
unsafe { &mut *(x.as_mut_ptr() as *mut [R; size]) }
|
||||
}
|
||||
|
||||
use rand_distr::num_traits::{Float, FloatConst};
|
||||
#[inline(always)]
|
||||
pub(crate) fn frac_rev_bits<R: Float + FloatConst>(x: usize) -> R {
|
||||
let half: R = R::from(0.5).unwrap();
|
||||
|
||||
match x {
|
||||
0 => R::zero(),
|
||||
1 => half,
|
||||
_ => {
|
||||
if x.is_multiple_of(2) {
|
||||
frac_rev_bits::<R>(x >> 1) * half
|
||||
} else {
|
||||
frac_rev_bits::<R>(x >> 1) * half + half
|
||||
}
|
||||
}
|
||||
}
|
||||
}
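frac_rev_bits(x) returns the fraction whose binary expansion is the bit-reversal of x, i.e. x = sum_i b_i 2^i maps to sum_i b_i 2^(-(i+1)); this is what orders the per-block twiddles in the iterative passes. Since the function is pub(crate), a check like the following would have to live inside the crate; it is a sketch of a possible unit test, not part of the commit.

#[cfg(test)]
mod frac_rev_bits_sketch {
    use super::frac_rev_bits;

    #[test]
    fn small_values() {
        // 1 = 0b1   -> 0.1b   = 0.5
        // 2 = 0b10  -> 0.01b  = 0.25
        // 3 = 0b11  -> 0.11b  = 0.75
        // 6 = 0b110 -> 0.011b = 0.375
        assert_eq!(frac_rev_bits::<f64>(0), 0.0);
        assert_eq!(frac_rev_bits::<f64>(1), 0.5);
        assert_eq!(frac_rev_bits::<f64>(2), 0.25);
        assert_eq!(frac_rev_bits::<f64>(3), 0.75);
        assert_eq!(frac_rev_bits::<f64>(6), 0.375);
    }
}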
|
||||
|
||||
pub trait ReimDFTExecute<D, T> {
|
||||
fn reim_dft_execute(table: &D, data: &mut [T]);
|
||||
}
|
||||
|
||||
pub trait ReimFromZnx {
|
||||
fn reim_from_znx(res: &mut [f64], a: &[i64]);
|
||||
}
|
||||
|
||||
pub trait ReimToZnx {
|
||||
fn reim_to_znx(res: &mut [i64], divisor: f64, a: &[f64]);
|
||||
}
|
||||
|
||||
pub trait ReimToZnxInplace {
|
||||
fn reim_to_znx_inplace(res: &mut [f64], divisor: f64);
|
||||
}
|
||||
|
||||
pub trait ReimAdd {
|
||||
fn reim_add(res: &mut [f64], a: &[f64], b: &[f64]);
|
||||
}
|
||||
|
||||
pub trait ReimAddInplace {
|
||||
fn reim_add_inplace(res: &mut [f64], a: &[f64]);
|
||||
}
|
||||
|
||||
pub trait ReimSub {
|
||||
fn reim_sub(res: &mut [f64], a: &[f64], b: &[f64]);
|
||||
}
|
||||
|
||||
pub trait ReimSubABInplace {
|
||||
fn reim_sub_ab_inplace(res: &mut [f64], a: &[f64]);
|
||||
}
|
||||
|
||||
pub trait ReimSubBAInplace {
|
||||
fn reim_sub_ba_inplace(res: &mut [f64], a: &[f64]);
|
||||
}
|
||||
|
||||
pub trait ReimNegate {
|
||||
fn reim_negate(res: &mut [f64], a: &[f64]);
|
||||
}
|
||||
|
||||
pub trait ReimNegateInplace {
|
||||
fn reim_negate_inplace(res: &mut [f64]);
|
||||
}
|
||||
|
||||
pub trait ReimMul {
|
||||
fn reim_mul(res: &mut [f64], a: &[f64], b: &[f64]);
|
||||
}
|
||||
|
||||
pub trait ReimMulInplace {
|
||||
fn reim_mul_inplace(res: &mut [f64], a: &[f64]);
|
||||
}
|
||||
|
||||
pub trait ReimAddMul {
|
||||
fn reim_addmul(res: &mut [f64], a: &[f64], b: &[f64]);
|
||||
}
|
||||
|
||||
pub trait ReimCopy {
|
||||
fn reim_copy(res: &mut [f64], a: &[f64]);
|
||||
}
|
||||
|
||||
pub trait ReimZero {
|
||||
fn reim_zero(res: &mut [f64]);
|
||||
}
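The traits above are the backend-facing surface for the reference kernels; a backend implements them by forwarding to the *_ref functions. A minimal sketch with a hypothetical marker type (the Fft64RefSketch name and this particular set of impls are illustrative only; the trait and function names are the ones defined in this commit, import path assumed):

use poulpy_hal::reference::fft64::reim::{reim_add_ref, reim_zero_ref, ReimAdd, ReimZero};

pub struct Fft64RefSketch; // hypothetical marker type, for illustration only

impl ReimAdd for Fft64RefSketch {
    fn reim_add(res: &mut [f64], a: &[f64], b: &[f64]) {
        reim_add_ref(res, a, b);
    }
}

impl ReimZero for Fft64RefSketch {
    fn reim_zero(res: &mut [f64]) {
        reim_zero_ref(res);
    }
}

fn main() {
    let (a, b) = ([1.0, 2.0], [3.0, 4.0]);
    let mut r = [0.0; 2];
    <Fft64RefSketch as ReimAdd>::reim_add(&mut r, &a, &b);
    assert_eq!(r, [4.0, 6.0]);
}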
207  poulpy-hal/src/reference/fft64/reim/table_fft.rs  Normal file
@@ -0,0 +1,207 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use rand_distr::num_traits::{Float, FloatConst};
|
||||
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
reference::fft64::reim::{ReimDFTExecute, fft_ref, frac_rev_bits},
|
||||
};
|
||||
|
||||
pub struct ReimFFTRef;
|
||||
|
||||
impl ReimDFTExecute<ReimFFTTable<f64>, f64> for ReimFFTRef {
|
||||
fn reim_dft_execute(table: &ReimFFTTable<f64>, data: &mut [f64]) {
|
||||
fft_ref(table.m, &table.omg, data);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ReimFFTTable<R: Float + FloatConst + Debug> {
|
||||
m: usize,
|
||||
omg: Vec<R>,
|
||||
}
|
||||
|
||||
impl<R: Float + FloatConst + Debug + 'static> ReimFFTTable<R> {
|
||||
pub fn new(m: usize) -> Self {
|
||||
assert!(m & (m - 1) == 0, "m must be a power of two but is {}", m);
|
||||
let mut omg: Vec<R> = alloc_aligned::<R>(2 * m);
|
||||
|
||||
let quarter: R = R::from(1. / 4.).unwrap();
|
||||
|
||||
if m <= 16 {
|
||||
match m {
|
||||
1 => {}
|
||||
2 => {
|
||||
fill_fft2_omegas(quarter, &mut omg, 0);
|
||||
}
|
||||
4 => {
|
||||
fill_fft4_omegas(quarter, &mut omg, 0);
|
||||
}
|
||||
8 => {
|
||||
fill_fft8_omegas(quarter, &mut omg, 0);
|
||||
}
|
||||
16 => {
|
||||
fill_fft16_omegas(quarter, &mut omg, 0);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
} else if m <= 2048 {
|
||||
fill_fft_bfs_16_omegas(m, quarter, &mut omg, 0);
|
||||
} else {
|
||||
fill_fft_rec_16_omegas(m, quarter, &mut omg, 0);
|
||||
}
|
||||
|
||||
Self { m, omg }
|
||||
}
|
||||
|
||||
pub fn m(&self) -> usize {
|
||||
self.m
|
||||
}
|
||||
|
||||
pub fn omg(&self) -> &[R] {
|
||||
&self.omg
|
||||
}
|
||||
}
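A quick sketch of constructing a forward table: m must be a power of two, and the omega buffer is allocated with 2*m slots. The import path is assumed from this commit's layout.

use poulpy_hal::reference::fft64::reim::ReimFFTTable; // path assumed

fn main() {
    let table: ReimFFTTable<f64> = ReimFFTTable::new(16);
    assert_eq!(table.m(), 16);
    assert_eq!(table.omg().len(), 32); // alloc_aligned::<f64>(2 * m)
}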
|
||||
|
||||
#[inline(always)]
|
||||
fn fill_fft2_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
|
||||
let omg_pos: &mut [R] = &mut omg[pos..];
|
||||
assert!(omg_pos.len() >= 2);
|
||||
let angle: R = j / R::from(2).unwrap();
|
||||
let two_pi: R = R::from(2).unwrap() * R::PI();
|
||||
omg_pos[0] = R::cos(two_pi * angle);
|
||||
omg_pos[1] = R::sin(two_pi * angle);
|
||||
pos + 2
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn fill_fft4_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
|
||||
let omg_pos: &mut [R] = &mut omg[pos..];
|
||||
assert!(omg_pos.len() >= 4);
|
||||
let angle_1: R = j / R::from(2).unwrap();
|
||||
let angle_2: R = j / R::from(4).unwrap();
|
||||
let two_pi: R = R::from(2).unwrap() * R::PI();
|
||||
omg_pos[0] = R::cos(two_pi * angle_1);
|
||||
omg_pos[1] = R::sin(two_pi * angle_1);
|
||||
omg_pos[2] = R::cos(two_pi * angle_2);
|
||||
omg_pos[3] = R::sin(two_pi * angle_2);
|
||||
pos + 4
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn fill_fft8_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
|
||||
let omg_pos: &mut [R] = &mut omg[pos..];
|
||||
assert!(omg_pos.len() >= 8);
|
||||
let _8th: R = R::from(1. / 8.).unwrap();
|
||||
let angle_1: R = j / R::from(2).unwrap();
|
||||
let angle_2: R = j / R::from(4).unwrap();
|
||||
let angle_4: R = j / R::from(8).unwrap();
|
||||
let two_pi: R = R::from(2).unwrap() * R::PI();
|
||||
omg_pos[0] = R::cos(two_pi * angle_1);
|
||||
omg_pos[1] = R::sin(two_pi * angle_1);
|
||||
omg_pos[2] = R::cos(two_pi * angle_2);
|
||||
omg_pos[3] = R::sin(two_pi * angle_2);
|
||||
omg_pos[4] = R::cos(two_pi * angle_4);
|
||||
omg_pos[5] = R::cos(two_pi * (angle_4 + _8th));
|
||||
omg_pos[6] = R::sin(two_pi * angle_4);
|
||||
omg_pos[7] = R::sin(two_pi * (angle_4 + _8th));
|
||||
pos + 8
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn fill_fft16_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
|
||||
let omg_pos: &mut [R] = &mut omg[pos..];
|
||||
assert!(omg_pos.len() >= 16);
|
||||
let _8th: R = R::from(1. / 8.).unwrap();
|
||||
let _16th: R = R::from(1. / 16.).unwrap();
|
||||
let angle_1: R = j / R::from(2).unwrap();
|
||||
let angle_2: R = j / R::from(4).unwrap();
|
||||
let angle_4: R = j / R::from(8).unwrap();
|
||||
let angle_8: R = j / R::from(16).unwrap();
|
||||
let two_pi: R = R::from(2).unwrap() * R::PI();
|
||||
omg_pos[0] = R::cos(two_pi * angle_1);
|
||||
omg_pos[1] = R::sin(two_pi * angle_1);
|
||||
omg_pos[2] = R::cos(two_pi * angle_2);
|
||||
omg_pos[3] = R::sin(two_pi * angle_2);
|
||||
omg_pos[4] = R::cos(two_pi * angle_4);
|
||||
omg_pos[5] = R::sin(two_pi * angle_4);
|
||||
omg_pos[6] = R::cos(two_pi * (angle_4 + _8th));
|
||||
omg_pos[7] = R::sin(two_pi * (angle_4 + _8th));
|
||||
omg_pos[8] = R::cos(two_pi * angle_8);
|
||||
omg_pos[9] = R::cos(two_pi * (angle_8 + _8th));
|
||||
omg_pos[10] = R::cos(two_pi * (angle_8 + _16th));
|
||||
omg_pos[11] = R::cos(two_pi * (angle_8 + _8th + _16th));
|
||||
omg_pos[12] = R::sin(two_pi * angle_8);
|
||||
omg_pos[13] = R::sin(two_pi * (angle_8 + _8th));
|
||||
omg_pos[14] = R::sin(two_pi * (angle_8 + _16th));
|
||||
omg_pos[15] = R::sin(two_pi * (angle_8 + _8th + _16th));
|
||||
pos + 16
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn fill_fft_bfs_16_omegas<R: Float + FloatConst>(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize {
|
||||
let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
|
||||
let mut mm: usize = m;
|
||||
let mut jj: R = j;
|
||||
|
||||
let two_pi: R = R::from(2).unwrap() * R::PI();
|
||||
|
||||
if !log_m.is_multiple_of(2) {
|
||||
let h = mm >> 1;
|
||||
let j: R = jj * R::from(0.5).unwrap();
|
||||
omg[pos] = R::cos(two_pi * j);
|
||||
omg[pos + 1] = R::sin(two_pi * j);
|
||||
pos += 2;
|
||||
mm = h;
|
||||
jj = j
|
||||
}
|
||||
|
||||
while mm > 16 {
|
||||
let h: usize = mm >> 2;
|
||||
let j: R = jj * R::from(1. / 4.).unwrap();
|
||||
for i in (0..m).step_by(mm) {
|
||||
let rs_0 = j + frac_rev_bits::<R>(i / mm) * R::from(1. / 4.).unwrap();
|
||||
let rs_1 = R::from(2).unwrap() * rs_0;
|
||||
omg[pos] = R::cos(two_pi * rs_1);
|
||||
omg[pos + 1] = R::sin(two_pi * rs_1);
|
||||
omg[pos + 2] = R::cos(two_pi * rs_0);
|
||||
omg[pos + 3] = R::sin(two_pi * rs_0);
|
||||
pos += 4;
|
||||
}
|
||||
mm = h;
|
||||
jj = j;
|
||||
}
|
||||
|
||||
for i in (0..m).step_by(16) {
|
||||
let j = jj + frac_rev_bits(i >> 4);
|
||||
fill_fft16_omegas(j, omg, pos);
|
||||
pos += 16
|
||||
}
|
||||
|
||||
pos
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn fill_fft_rec_16_omegas<R: Float + FloatConst>(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize {
|
||||
if m <= 2048 {
|
||||
return fill_fft_bfs_16_omegas(m, j, omg, pos);
|
||||
}
|
||||
let h: usize = m >> 1;
|
||||
let s: R = j * R::from(0.5).unwrap();
|
||||
let _2pi = R::from(2).unwrap() * R::PI();
|
||||
omg[pos] = R::cos(_2pi * s);
|
||||
omg[pos + 1] = R::sin(_2pi * s);
|
||||
pos += 2;
|
||||
pos = fill_fft_rec_16_omegas(h, s, omg, pos);
|
||||
pos = fill_fft_rec_16_omegas(h, s + R::from(0.5).unwrap(), omg, pos);
|
||||
pos
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn ctwiddle_ref(ra: &mut f64, ia: &mut f64, rb: &mut f64, ib: &mut f64, omg_re: f64, omg_im: f64) {
|
||||
let dr: f64 = *rb * omg_re - *ib * omg_im;
|
||||
let di: f64 = *rb * omg_im + *ib * omg_re;
|
||||
*rb = *ra - dr;
|
||||
*ib = *ia - di;
|
||||
*ra += dr;
|
||||
*ia += di;
|
||||
}
201  poulpy-hal/src/reference/fft64/reim/table_ifft.rs  Normal file
@@ -0,0 +1,201 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use rand_distr::num_traits::{Float, FloatConst};
|
||||
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
reference::fft64::reim::{ReimDFTExecute, frac_rev_bits, ifft_ref::ifft_ref},
|
||||
};
|
||||
|
||||
pub struct ReimIFFTRef;
|
||||
|
||||
impl ReimDFTExecute<ReimIFFTTable<f64>, f64> for ReimIFFTRef {
|
||||
fn reim_dft_execute(table: &ReimIFFTTable<f64>, data: &mut [f64]) {
|
||||
ifft_ref(table.m, &table.omg, data);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ReimIFFTTable<R: Float + FloatConst + Debug> {
|
||||
m: usize,
|
||||
omg: Vec<R>,
|
||||
}
|
||||
|
||||
impl<R: Float + FloatConst + Debug> ReimIFFTTable<R> {
|
||||
pub fn new(m: usize) -> Self {
|
||||
assert!(m & (m - 1) == 0, "m must be a power of two but is {}", m);
|
||||
let mut omg: Vec<R> = alloc_aligned::<R>(2 * m);
|
||||
|
||||
let quarter: R = R::exp2(R::from(-2).unwrap());
|
||||
|
||||
if m <= 16 {
|
||||
match m {
|
||||
1 => {}
|
||||
2 => {
|
||||
fill_ifft2_omegas::<R>(quarter, &mut omg, 0);
|
||||
}
|
||||
4 => {
|
||||
fill_ifft4_omegas(quarter, &mut omg, 0);
|
||||
}
|
||||
8 => {
|
||||
fill_ifft8_omegas(quarter, &mut omg, 0);
|
||||
}
|
||||
16 => {
|
||||
fill_ifft16_omegas(quarter, &mut omg, 0);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
} else if m <= 2048 {
|
||||
fill_ifft_bfs_16_omegas(m, quarter, &mut omg, 0);
|
||||
} else {
|
||||
fill_ifft_rec_16_omegas(m, quarter, &mut omg, 0);
|
||||
}
|
||||
|
||||
Self { m, omg }
|
||||
}
|
||||
|
||||
pub fn execute(&self, data: &mut [R]) {
|
||||
ifft_ref(self.m, &self.omg, data);
|
||||
}
|
||||
|
||||
pub fn m(&self) -> usize {
|
||||
self.m
|
||||
}
|
||||
|
||||
pub fn omg(&self) -> &[R] {
|
||||
&self.omg
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn fill_ifft2_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
|
||||
let omg_pos: &mut [R] = &mut omg[pos..];
|
||||
assert!(omg_pos.len() >= 2);
|
||||
let angle: R = j / R::from(2).unwrap();
let two_pi: R = R::from(2).unwrap() * R::PI();
|
||||
omg_pos[0] = R::cos(two_pi * angle);
|
||||
omg_pos[1] = -R::sin(two_pi * angle);
|
||||
pos + 2
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn fill_ifft4_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
|
||||
let omg_pos: &mut [R] = &mut omg[pos..];
|
||||
assert!(omg_pos.len() >= 4);
|
||||
let angle_1: R = j / R::from(2).unwrap();
|
||||
let angle_2: R = j / R::from(4).unwrap();
|
||||
let two_pi: R = R::from(2).unwrap() * R::PI();
|
||||
omg_pos[0] = R::cos(two_pi * angle_2);
|
||||
omg_pos[1] = -R::sin(two_pi * angle_2);
|
||||
omg_pos[2] = R::cos(two_pi * angle_1);
|
||||
omg_pos[3] = -R::sin(two_pi * angle_1);
|
||||
pos + 4
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn fill_ifft8_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
|
||||
let omg_pos: &mut [R] = &mut omg[pos..];
|
||||
assert!(omg_pos.len() >= 8);
|
||||
let _8th: R = R::from(1. / 8.).unwrap();
|
||||
let angle_1: R = j / R::from(2).unwrap();
|
||||
let angle_2: R = j / R::from(4).unwrap();
|
||||
let angle_4: R = j / R::from(8).unwrap(); // finest level uses j / m, mirroring fill_fft8_omegas
|
||||
let two_pi: R = R::from(2).unwrap() * R::PI();
|
||||
omg_pos[0] = R::cos(two_pi * angle_4);
|
||||
omg_pos[1] = R::cos(two_pi * (angle_4 + _8th));
|
||||
omg_pos[2] = -R::sin(two_pi * angle_4);
|
||||
omg_pos[3] = -R::sin(two_pi * (angle_4 + _8th));
|
||||
omg_pos[4] = R::cos(two_pi * angle_2);
|
||||
omg_pos[5] = -R::sin(two_pi * angle_2);
|
||||
omg_pos[6] = R::cos(two_pi * angle_1);
|
||||
omg_pos[7] = -R::sin(two_pi * angle_1);
|
||||
pos + 8
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn fill_ifft16_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
|
||||
let omg_pos: &mut [R] = &mut omg[pos..];
|
||||
assert!(omg_pos.len() >= 16);
|
||||
let _8th: R = R::from(1. / 8.).unwrap();
|
||||
let _16th: R = R::from(1. / 16.).unwrap();
|
||||
let angle_1: R = j / R::from(2).unwrap();
|
||||
let angle_2: R = j / R::from(4).unwrap();
|
||||
let angle_4: R = j / R::from(8).unwrap();
|
||||
let angle_8: R = j / R::from(16).unwrap();
|
||||
let two_pi: R = R::from(2).unwrap() * R::PI();
|
||||
omg_pos[0] = R::cos(two_pi * angle_8);
|
||||
omg_pos[1] = R::cos(two_pi * (angle_8 + _8th));
|
||||
omg_pos[2] = R::cos(two_pi * (angle_8 + _16th));
|
||||
omg_pos[3] = R::cos(two_pi * (angle_8 + _8th + _16th));
|
||||
omg_pos[4] = -R::sin(two_pi * angle_8);
|
||||
omg_pos[5] = -R::sin(two_pi * (angle_8 + _8th));
|
||||
omg_pos[6] = -R::sin(two_pi * (angle_8 + _16th));
|
||||
omg_pos[7] = -R::sin(two_pi * (angle_8 + _8th + _16th));
|
||||
omg_pos[8] = R::cos(two_pi * angle_4);
|
||||
omg_pos[9] = -R::sin(two_pi * angle_4);
|
||||
omg_pos[10] = R::cos(two_pi * (angle_4 + _8th));
|
||||
omg_pos[11] = -R::sin(two_pi * (angle_4 + _8th));
|
||||
omg_pos[12] = R::cos(two_pi * angle_2);
|
||||
omg_pos[13] = -R::sin(two_pi * angle_2);
|
||||
omg_pos[14] = R::cos(two_pi * angle_1);
|
||||
omg_pos[15] = -R::sin(two_pi * angle_1);
|
||||
pos + 16
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn fill_ifft_bfs_16_omegas<R: Float + FloatConst + Debug>(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize {
|
||||
let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
|
||||
let mut jj: R = j * R::from(16).unwrap() / R::from(m).unwrap();
|
||||
|
||||
for i in (0..m).step_by(16) {
|
||||
let j = jj + frac_rev_bits(i >> 4);
|
||||
fill_ifft16_omegas(j, omg, pos);
|
||||
pos += 16
|
||||
}
|
||||
|
||||
let mut h: usize = 16;
|
||||
let m_half: usize = m >> 1;
|
||||
|
||||
let two_pi: R = R::from(2).unwrap() * R::PI();
|
||||
|
||||
while h < m_half {
|
||||
let mm: usize = h << 2;
|
||||
for i in (0..m).step_by(mm) {
|
||||
let rs_0 = jj + frac_rev_bits::<R>(i / mm) / R::from(4).unwrap();
|
||||
let rs_1 = R::from(2).unwrap() * rs_0;
|
||||
omg[pos] = R::cos(two_pi * rs_0);
|
||||
omg[pos + 1] = -R::sin(two_pi * rs_0);
|
||||
omg[pos + 2] = R::cos(two_pi * rs_1);
|
||||
omg[pos + 3] = -R::sin(two_pi * rs_1);
|
||||
pos += 4;
|
||||
}
|
||||
h = mm;
|
||||
jj = jj * R::from(4).unwrap();
|
||||
}
|
||||
|
||||
if !log_m.is_multiple_of(2) {
|
||||
omg[pos] = R::cos(two_pi * jj);
|
||||
omg[pos + 1] = -R::sin(two_pi * jj);
|
||||
pos += 2;
|
||||
jj = jj * R::from(2).unwrap();
|
||||
}
|
||||
|
||||
assert_eq!(jj, j);
|
||||
|
||||
pos
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn fill_ifft_rec_16_omegas<R: Float + FloatConst + Debug>(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize {
|
||||
if m <= 2048 {
|
||||
return fill_ifft_bfs_16_omegas(m, j, omg, pos);
|
||||
}
|
||||
let h: usize = m >> 1;
|
||||
let s: R = j / R::from(2).unwrap();
|
||||
pos = fill_ifft_rec_16_omegas(h, s, omg, pos);
|
||||
pos = fill_ifft_rec_16_omegas(h, s + R::from(0.5).unwrap(), omg, pos);
|
||||
let _2pi = R::from(2).unwrap() * R::PI();
|
||||
omg[pos] = R::cos(_2pi * s);
|
||||
omg[pos + 1] = -R::sin(_2pi * s);
|
||||
pos += 2;
|
||||
pos
|
||||
}
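The two tables pair up: running fft_ref and then ifft_ref (via the executor structs) returns a constant multiple of the input, because the transforms are unnormalized; callers undo the factor via the divisor argument of reim_to_znx. The sketch below only asserts proportionality and notes that, for this port, the constant is m. Import path assumed from this commit.

use poulpy_hal::reference::fft64::reim::{
    ReimDFTExecute, ReimFFTRef, ReimFFTTable, ReimIFFTRef, ReimIFFTTable,
}; // path assumed

fn main() {
    let m: usize = 32;
    let fft: ReimFFTTable<f64> = ReimFFTTable::new(m);
    let ifft: ReimIFFTTable<f64> = ReimIFFTTable::new(m);

    // data holds m complex values in split layout: [re_0..re_{m-1} | im_0..im_{m-1}].
    let x: Vec<f64> = (1..=2 * m).map(|i| i as f64).collect();
    let mut data = x.clone();

    ReimFFTRef::reim_dft_execute(&fft, &mut data);
    ReimIFFTRef::reim_dft_execute(&ifft, &mut data);

    // Unnormalized round trip: data ~= scale * x, with scale = m for this port
    // (one factor of 2 per butterfly level).
    let scale = data[0] / x[0];
    for (y, xi) in data.iter().zip(x.iter()) {
        assert!((y - scale * xi).abs() < 1e-6 * scale.abs());
    }
}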
11  poulpy-hal/src/reference/fft64/reim/zero.rs  Normal file
@@ -0,0 +1,11 @@
pub fn reim_zero_ref(res: &mut [f64]) {
    res.fill(0.);
}

pub fn reim_copy_ref(res: &mut [f64], a: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len())
    }
    res.copy_from_slice(a);
}
209  poulpy-hal/src/reference/fft64/reim4/arithmetic_ref.rs  Normal file
@@ -0,0 +1,209 @@
|
||||
use crate::reference::fft64::reim::as_arr;
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim4_extract_1blk_from_reim_ref(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
|
||||
let mut offset: usize = blk << 2;
|
||||
|
||||
debug_assert!(blk < (m >> 2));
|
||||
debug_assert!(dst.len() >= 2 * rows * 4);
|
||||
|
||||
for chunk in dst.chunks_exact_mut(4).take(2 * rows) {
|
||||
chunk.copy_from_slice(&src[offset..offset + 4]);
|
||||
offset += m
|
||||
}
|
||||
}
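reim4_extract_1blk_from_reim_ref gathers one aligned block of 4 lanes from each of rows split-layout vectors packed contiguously in src (each vector occupying 2*m doubles), writing 8 doubles per row into dst: 4 reals followed by 4 imaginaries. A small sketch for a single row; the import path is assumed from this commit's layout.

use poulpy_hal::reference::fft64::reim4::reim4_extract_1blk_from_reim_ref; // path assumed

fn main() {
    let m: usize = 8;
    // One split-layout vector of m = 8 complex values: re = 0..8, im = 10..18.
    let mut src: Vec<f64> = (0..8).map(|i| i as f64).collect();
    src.extend((10..18).map(|i| i as f64));

    // Extract block 1 (lanes 4..8) of the single row: 4 reals then 4 imaginaries.
    let mut dst = [0.0f64; 8];
    reim4_extract_1blk_from_reim_ref(m, 1, 1, &mut dst, &src);
    assert_eq!(dst, [4.0, 5.0, 6.0, 7.0, 14.0, 15.0, 16.0, 17.0]);
}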
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim4_save_1blk_to_reim_ref<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
|
||||
let mut offset: usize = blk << 2;
|
||||
|
||||
debug_assert!(blk < (m >> 2));
|
||||
debug_assert!(dst.len() >= offset + m + 4);
|
||||
debug_assert!(src.len() >= 8);
|
||||
|
||||
let dst_off = &mut dst[offset..offset + 4];
|
||||
|
||||
if OVERWRITE {
|
||||
dst_off.copy_from_slice(&src[0..4]);
|
||||
} else {
|
||||
dst_off[0] += src[0];
|
||||
dst_off[1] += src[1];
|
||||
dst_off[2] += src[2];
|
||||
dst_off[3] += src[3];
|
||||
}
|
||||
|
||||
offset += m;
|
||||
|
||||
let dst_off = &mut dst[offset..offset + 4];
|
||||
if OVERWRITE {
|
||||
dst_off.copy_from_slice(&src[4..8]);
|
||||
} else {
|
||||
dst_off[0] += src[4];
|
||||
dst_off[1] += src[5];
|
||||
dst_off[2] += src[6];
|
||||
dst_off[3] += src[7];
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim4_save_2blk_to_reim_ref<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
|
||||
let mut offset: usize = blk << 2;
|
||||
|
||||
debug_assert!(blk < (m >> 2));
|
||||
debug_assert!(dst.len() >= offset + 3 * m + 4);
|
||||
debug_assert!(src.len() >= 16);
|
||||
|
||||
let dst_off = &mut dst[offset..offset + 4];
|
||||
if OVERWRITE {
|
||||
dst_off.copy_from_slice(&src[0..4]);
|
||||
} else {
|
||||
dst_off[0] += src[0];
|
||||
dst_off[1] += src[1];
|
||||
dst_off[2] += src[2];
|
||||
dst_off[3] += src[3];
|
||||
}
|
||||
|
||||
offset += m;
|
||||
let dst_off = &mut dst[offset..offset + 4];
|
||||
if OVERWRITE {
|
||||
dst_off.copy_from_slice(&src[4..8]);
|
||||
} else {
|
||||
dst_off[0] += src[4];
|
||||
dst_off[1] += src[5];
|
||||
dst_off[2] += src[6];
|
||||
dst_off[3] += src[7];
|
||||
}
|
||||
|
||||
offset += m;
|
||||
|
||||
let dst_off = &mut dst[offset..offset + 4];
|
||||
if OVERWRITE {
|
||||
dst_off.copy_from_slice(&src[8..12]);
|
||||
} else {
|
||||
dst_off[0] += src[8];
|
||||
dst_off[1] += src[9];
|
||||
dst_off[2] += src[10];
|
||||
dst_off[3] += src[11];
|
||||
}
|
||||
|
||||
offset += m;
|
||||
let dst_off = &mut dst[offset..offset + 4];
|
||||
if OVERWRITE {
|
||||
dst_off.copy_from_slice(&src[12..16]);
|
||||
} else {
|
||||
dst_off[0] += src[12];
|
||||
dst_off[1] += src[13];
|
||||
dst_off[2] += src[14];
|
||||
dst_off[3] += src[15];
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim4_vec_mat1col_product_ref(
|
||||
nrows: usize,
|
||||
dst: &mut [f64], // 8 doubles: [re1(4), im1(4)]
|
||||
u: &[f64], // nrows * 8 doubles: [ur(4) | ui(4)] per row
|
||||
v: &[f64], // nrows * 8 doubles: [ar(4) | ai(4)] per row
|
||||
) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(dst.len() >= 8, "dst must have at least 8 doubles");
|
||||
assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles");
|
||||
assert!(v.len() >= nrows * 8, "v must be at least nrows * 8 doubles");
|
||||
}
|
||||
let mut acc: [f64; 8] = [0f64; 8];
|
||||
let mut j = 0;
|
||||
for _ in 0..nrows {
|
||||
reim4_add_mul(&mut acc, as_arr(&u[j..]), as_arr(&v[j..]));
|
||||
j += 8;
|
||||
}
|
||||
dst[0..8].copy_from_slice(&acc);
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim4_vec_mat2cols_product_ref(
|
||||
nrows: usize,
|
||||
dst: &mut [f64], // 16 doubles: [re1(4), im1(4), re2(4), im2(4)]
|
||||
u: &[f64], // nrows * 8 doubles: [ur(4) | ui(4)] per row
|
||||
v: &[f64], // nrows * 16 doubles: [ar(4) | ai(4) | br(4) | bi(4)] per row
|
||||
) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(dst.len(), 16, "dst must have 16 doubles");
|
||||
assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles");
|
||||
assert!(
|
||||
v.len() >= nrows * 16,
|
||||
"v must be at least nrows * 16 doubles"
|
||||
);
|
||||
}
|
||||
|
||||
// zero accumulators
|
||||
let mut acc_0: [f64; 8] = [0f64; 8];
|
||||
let mut acc_1: [f64; 8] = [0f64; 8];
|
||||
for i in 0..nrows {
|
||||
let _1j: usize = i << 3;
|
||||
let _2j: usize = i << 4;
|
||||
let u_j: &[f64; 8] = as_arr(&u[_1j..]);
|
||||
reim4_add_mul(&mut acc_0, u_j, as_arr(&v[_2j..]));
|
||||
reim4_add_mul(&mut acc_1, u_j, as_arr(&v[_2j + 8..]));
|
||||
}
|
||||
dst[0..8].copy_from_slice(&acc_0);
|
||||
dst[8..16].copy_from_slice(&acc_1);
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim4_vec_mat2cols_2ndcol_product_ref(
|
||||
nrows: usize,
|
||||
dst: &mut [f64], // 8 doubles: [re2(4), im2(4)]
|
||||
u: &[f64], // nrows * 8 doubles: [ur(4) | ui(4)] per row
|
||||
v: &[f64], // nrows * 16 doubles: [x | x | br(4) | bi(4)] per row
|
||||
) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
dst.len() >= 8,
|
||||
"dst must be at least 8 doubles but is {}",
|
||||
dst.len()
|
||||
);
|
||||
assert!(
|
||||
u.len() >= nrows * 8,
|
||||
"u must be at least nrows={} * 8 doubles but is {}",
|
||||
nrows,
|
||||
u.len()
|
||||
);
|
||||
assert!(
|
||||
v.len() >= nrows * 16,
|
||||
"v must be at least nrows={} * 16 doubles but is {}",
|
||||
nrows,
|
||||
v.len()
|
||||
);
|
||||
}
|
||||
|
||||
// zero accumulators
|
||||
let mut acc: [f64; 8] = [0f64; 8];
|
||||
for i in 0..nrows {
|
||||
let _1j: usize = i << 3;
|
||||
let _2j: usize = i << 4;
|
||||
reim4_add_mul(&mut acc, as_arr(&u[_1j..]), as_arr(&v[_2j + 8..]));
|
||||
}
|
||||
dst[0..8].copy_from_slice(&acc);
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim4_add_mul(dst: &mut [f64; 8], a: &[f64; 8], b: &[f64; 8]) {
|
||||
for k in 0..4 {
|
||||
let ar: f64 = a[k];
|
||||
let br: f64 = b[k];
|
||||
let ai: f64 = a[k + 4];
|
||||
let bi: f64 = b[k + 4];
|
||||
dst[k] += ar * br - ai * bi;
|
||||
dst[k + 4] += ar * bi + ai * br;
|
||||
}
|
||||
}
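reim4_add_mul is the 4-lane complex fused accumulate used by the mat-column products above: for each lane k, dst[k] += Re(a_k * b_k) and dst[k + 4] += Im(a_k * b_k). A quick standalone check on one lane; import path assumed from this commit.

use poulpy_hal::reference::fft64::reim4::reim4_add_mul; // path assumed

fn main() {
    // Lane 0 holds a = 1 + 2i and b = 3 + 4i; the other three lanes are zero.
    let a = [1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0];
    let b = [3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0];
    let mut dst = [0.0f64; 8];

    reim4_add_mul(&mut dst, &a, &b);
    reim4_add_mul(&mut dst, &a, &b); // accumulates, it does not overwrite

    // (1 + 2i)(3 + 4i) = -5 + 10i, accumulated twice.
    assert_eq!(dst[0], -10.0);
    assert_eq!(dst[4], 20.0);
}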
27  poulpy-hal/src/reference/fft64/reim4/mod.rs  Normal file
@@ -0,0 +1,27 @@
mod arithmetic_ref;

pub use arithmetic_ref::*;

pub trait Reim4Extract1Blk {
    fn reim4_extract_1blk(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]);
}

pub trait Reim4Save1Blk {
    fn reim4_save_1blk<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]);
}

pub trait Reim4Save2Blks {
    fn reim4_save_2blks<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]);
}

pub trait Reim4Mat1ColProd {
    fn reim4_mat1col_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]);
}

pub trait Reim4Mat2ColsProd {
    fn reim4_mat2cols_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]);
}

pub trait Reim4Mat2Cols2ndColProd {
    fn reim4_mat2cols_2ndcol_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]);
}
119  poulpy-hal/src/reference/fft64/svp.rs  Normal file
@@ -0,0 +1,119 @@
|
||||
use crate::{
|
||||
layouts::{
|
||||
Backend, ScalarZnx, ScalarZnxToRef, SvpPPol, SvpPPolToMut, SvpPPolToRef, VecZnx, VecZnxDft, VecZnxDftToMut,
|
||||
VecZnxDftToRef, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut,
|
||||
},
|
||||
reference::fft64::reim::{ReimAddMul, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimMul, ReimMulInplace, ReimZero},
|
||||
};
|
||||
|
||||
pub fn svp_prepare<R, A, BE>(table: &ReimFFTTable<f64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx,
|
||||
R: SvpPPolToMut<BE>,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
let mut res: SvpPPol<&mut [u8], BE> = res.to_mut();
|
||||
let a: ScalarZnx<&[u8]> = a.to_ref();
|
||||
BE::reim_from_znx(res.at_mut(res_col, 0), a.at(a_col, 0));
|
||||
BE::reim_dft_execute(table, res.at_mut(res_col, 0));
|
||||
}
|
||||
|
||||
pub fn svp_apply_dft<R, A, B, BE>(
|
||||
table: &ReimFFTTable<f64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimZero + ReimFromZnx + ReimMulInplace,
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: SvpPPolToRef<BE>,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
|
||||
let a: SvpPPol<&[u8], BE> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let b_size: usize = b.size();
|
||||
let min_size: usize = res_size.min(b_size);
|
||||
|
||||
let ppol: &[f64] = a.at(a_col, 0);
|
||||
for j in 0..min_size {
|
||||
let out: &mut [f64] = res.at_mut(res_col, j);
|
||||
BE::reim_from_znx(out, b.at(b_col, j));
|
||||
BE::reim_dft_execute(table, out);
|
||||
BE::reim_mul_inplace(out, ppol);
|
||||
}
|
||||
|
||||
for j in min_size..res_size {
|
||||
BE::reim_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
}
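svp_apply_dft is, per limb, the composition reim_from_znx -> forward DFT -> pointwise multiply with the prepared scalar, with the remaining limbs of res zeroed. Stripped of the SvpPPol/VecZnxDft layouts, one limb looks like the sketch below on raw slices; the import paths and the placeholder data are assumptions, only the function names come from this commit.

use poulpy_hal::reference::fft64::reim::{
    reim_from_znx_i64_ref, reim_mul_inplace_ref, ReimDFTExecute, ReimFFTRef, ReimFFTTable,
}; // paths assumed

fn main() {
    let m: usize = 16; // ring degree n = 2 * m
    let table: ReimFFTTable<f64> = ReimFFTTable::new(m);

    // `prepared` stands in for one column of the SvpPPol (already in the DFT domain),
    // `limb` for one limb of the VecZnx operand; both are illustrative data.
    let prepared: Vec<f64> = vec![1.0; 2 * m];
    let limb: Vec<i64> = (0..(2 * m) as i64).collect();

    let mut out: Vec<f64> = vec![0.0; 2 * m];
    reim_from_znx_i64_ref(&mut out, &limb);          // lift i64 coefficients to f64
    ReimFFTRef::reim_dft_execute(&table, &mut out);  // forward DFT of the limb
    reim_mul_inplace_ref(&mut out, &prepared);       // pointwise multiply by the scalar
}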
|
||||
|
||||
pub fn svp_apply_dft_to_dft<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64> + ReimMul + ReimZero,
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: SvpPPolToRef<BE>,
|
||||
B: VecZnxDftToRef<BE>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
|
||||
let a: SvpPPol<&[u8], BE> = a.to_ref();
|
||||
let b: VecZnxDft<&[u8], BE> = b.to_ref();
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let b_size: usize = b.size();
|
||||
let min_size: usize = res_size.min(b_size);
|
||||
|
||||
let ppol: &[f64] = a.at(a_col, 0);
|
||||
for j in 0..min_size {
|
||||
BE::reim_mul(res.at_mut(res_col, j), ppol, b.at(b_col, j));
|
||||
}
|
||||
|
||||
for j in min_size..res_size {
|
||||
BE::reim_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn svp_apply_dft_to_dft_add<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64> + ReimAddMul + ReimZero,
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: SvpPPolToRef<BE>,
|
||||
B: VecZnxDftToRef<BE>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
|
||||
let a: SvpPPol<&[u8], BE> = a.to_ref();
|
||||
let b: VecZnxDft<&[u8], BE> = b.to_ref();
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let b_size: usize = b.size();
|
||||
let min_size: usize = res_size.min(b_size);
|
||||
|
||||
let ppol: &[f64] = a.at(a_col, 0);
|
||||
for j in 0..min_size {
|
||||
BE::reim_addmul(res.at_mut(res_col, j), ppol, b.at(b_col, j));
|
||||
}
|
||||
|
||||
for j in min_size..res_size {
|
||||
BE::reim_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn svp_apply_dft_to_dft_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64> + ReimMulInplace,
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: SvpPPolToRef<BE>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
|
||||
let a: SvpPPol<&[u8], BE> = a.to_ref();
|
||||
|
||||
let ppol: &[f64] = a.at(a_col, 0);
|
||||
for j in 0..res.size() {
|
||||
BE::reim_mul_inplace(res.at_mut(res_col, j), ppol);
|
||||
}
|
||||
}
521  poulpy-hal/src/reference/fft64/vec_znx_big.rs  Normal file
@@ -0,0 +1,521 @@
|
||||
use std::f64::consts::SQRT_2;
|
||||
|
||||
use crate::{
|
||||
api::VecZnxBigAddNormal,
|
||||
layouts::{
|
||||
Backend, Module, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef, ZnxView, ZnxViewMut,
|
||||
},
|
||||
oep::VecZnxBigAllocBytesImpl,
|
||||
reference::{
|
||||
vec_znx::{
|
||||
vec_znx_add, vec_znx_add_inplace, vec_znx_automorphism, vec_znx_automorphism_inplace, vec_znx_negate,
|
||||
vec_znx_negate_inplace, vec_znx_normalize, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace,
|
||||
},
|
||||
znx::{
|
||||
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
|
||||
ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly,
|
||||
ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero, znx_add_normal_f64_ref,
|
||||
},
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn vec_znx_big_add<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarBig = i64> + ZnxAdd + ZnxCopy + ZnxZero,
|
||||
R: VecZnxBigToMut<BE>,
|
||||
A: VecZnxBigToRef<BE>,
|
||||
B: VecZnxBigToRef<BE>,
|
||||
{
|
||||
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
|
||||
let a: VecZnxBig<&[u8], BE> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], BE> = b.to_ref();
|
||||
|
||||
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
|
||||
data: res.data,
|
||||
n: res.n,
|
||||
cols: res.cols,
|
||||
size: res.size,
|
||||
max_size: res.max_size,
|
||||
};
|
||||
|
||||
let a_vznx: VecZnx<&[u8]> = VecZnx {
|
||||
data: a.data,
|
||||
n: a.n,
|
||||
cols: a.cols,
|
||||
size: a.size,
|
||||
max_size: a.max_size,
|
||||
};
|
||||
|
||||
let b_vznx: VecZnx<&[u8]> = VecZnx {
|
||||
data: b.data,
|
||||
n: b.n,
|
||||
cols: b.cols,
|
||||
size: b.size,
|
||||
max_size: b.max_size,
|
||||
};
|
||||
|
||||
vec_znx_add::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, &b_vznx, b_col);
|
||||
}
|
||||
|
||||
pub fn vec_znx_big_add_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarBig = i64> + ZnxAddInplace,
|
||||
R: VecZnxBigToMut<BE>,
|
||||
A: VecZnxBigToRef<BE>,
|
||||
{
|
||||
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
|
||||
let a: VecZnxBig<&[u8], BE> = a.to_ref();
|
||||
|
||||
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
|
||||
data: res.data,
|
||||
n: res.n,
|
||||
cols: res.cols,
|
||||
size: res.size,
|
||||
max_size: res.max_size,
|
||||
};
|
||||
|
||||
let a_vznx: VecZnx<&[u8]> = VecZnx {
|
||||
data: a.data,
|
||||
n: a.n,
|
||||
cols: a.cols,
|
||||
size: a.size,
|
||||
max_size: a.max_size,
|
||||
};
|
||||
|
||||
vec_znx_add_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
|
||||
}
|
||||
|
||||
pub fn vec_znx_big_add_small<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarBig = i64> + ZnxAdd + ZnxCopy + ZnxZero,
|
||||
R: VecZnxBigToMut<BE>,
|
||||
A: VecZnxBigToRef<BE>,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
|
||||
let a: VecZnxBig<&[u8], BE> = a.to_ref();
|
||||
|
||||
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
|
||||
data: res.data,
|
||||
n: res.n,
|
||||
cols: res.cols,
|
||||
size: res.size,
|
||||
max_size: res.max_size,
|
||||
};
|
||||
|
||||
let a_vznx: VecZnx<&[u8]> = VecZnx {
|
||||
data: a.data,
|
||||
n: a.n,
|
||||
cols: a.cols,
|
||||
size: a.size,
|
||||
max_size: a.max_size,
|
||||
};
|
||||
|
||||
vec_znx_add::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, b, b_col);
|
||||
}
|
||||
|
||||
pub fn vec_znx_big_add_small_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarBig = i64> + ZnxAddInplace,
|
||||
R: VecZnxBigToMut<BE>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
|
||||
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
|
||||
data: res.data,
|
||||
n: res.n,
|
||||
cols: res.cols,
|
||||
size: res.size,
|
||||
max_size: res.max_size,
|
||||
};
|
||||
|
||||
vec_znx_add_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
|
||||
}
|
||||
|
||||
pub fn vec_znx_big_automorphism_inplace_tmp_bytes(n: usize) -> usize {
|
||||
n * size_of::<i64>()
|
||||
}
|
||||
|
||||
pub fn vec_znx_big_automorphism<R, A, BE>(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarBig = i64> + ZnxAutomorphism + ZnxZero,
|
||||
R: VecZnxBigToMut<BE>,
|
||||
A: VecZnxBigToRef<BE>,
|
||||
{
|
||||
let res: VecZnxBig<&mut [u8], _> = res.to_mut();
|
||||
let a: VecZnxBig<&[u8], _> = a.to_ref();
|
||||
|
||||
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
|
||||
data: res.data,
|
||||
n: res.n,
|
||||
cols: res.cols,
|
||||
size: res.size,
|
||||
max_size: res.max_size,
|
||||
};
|
||||
|
||||
let a_vznx: VecZnx<&[u8]> = VecZnx {
|
||||
data: a.data,
|
||||
n: a.n,
|
||||
cols: a.cols,
|
||||
size: a.size,
|
||||
max_size: a.max_size,
|
||||
};
|
||||
|
||||
vec_znx_automorphism::<_, _, BE>(p, &mut res_vznx, res_col, &a_vznx, a_col);
|
||||
}
|
||||
|
||||
pub fn vec_znx_big_automorphism_inplace<R, BE>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
|
||||
where
|
||||
BE: Backend<ScalarBig = i64> + ZnxAutomorphism + ZnxCopy,
|
||||
R: VecZnxBigToMut<BE>,
|
||||
{
|
||||
let res: VecZnxBig<&mut [u8], _> = res.to_mut();
|
||||
|
||||
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
|
||||
data: res.data,
|
||||
n: res.n,
|
||||
cols: res.cols,
|
||||
size: res.size,
|
||||
max_size: res.max_size,
|
||||
};
|
||||
|
||||
vec_znx_automorphism_inplace::<_, BE>(p, &mut res_vznx, res_col, tmp);
|
||||
}
|
||||
|
||||
pub fn vec_znx_big_negate<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarBig = i64> + ZnxNegate + ZnxZero,
|
||||
R: VecZnxBigToMut<BE>,
|
||||
A: VecZnxBigToRef<BE>,
|
||||
{
|
||||
let res: VecZnxBig<&mut [u8], _> = res.to_mut();
|
||||
let a: VecZnxBig<&[u8], _> = a.to_ref();
|
||||
|
||||
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
|
||||
data: res.data,
|
||||
n: res.n,
|
||||
cols: res.cols,
|
||||
size: res.size,
|
||||
max_size: res.max_size,
|
||||
};
|
||||
|
||||
let a_vznx: VecZnx<&[u8]> = VecZnx {
|
||||
data: a.data,
|
||||
n: a.n,
|
||||
cols: a.cols,
|
||||
size: a.size,
|
||||
max_size: a.max_size,
|
||||
};
|
||||
|
||||
vec_znx_negate::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
|
||||
}
|
||||
|
||||
pub fn vec_znx_big_negate_inplace<R, BE>(res: &mut R, res_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarBig = i64> + ZnxNegateInplace,
|
||||
R: VecZnxBigToMut<BE>,
|
||||
{
|
||||
let res: VecZnxBig<&mut [u8], _> = res.to_mut();
|
||||
|
||||
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
|
||||
data: res.data,
|
||||
n: res.n,
|
||||
cols: res.cols,
|
||||
size: res.size,
|
||||
max_size: res.max_size,
|
||||
};
|
||||
|
||||
vec_znx_negate_inplace::<_, BE>(&mut res_vznx, res_col);
|
||||
}
|
||||
|
||||
pub fn vec_znx_big_normalize_tmp_bytes(n: usize) -> usize {
    n * size_of::<i64>()
}

pub fn vec_znx_big_normalize<R, A, BE>(basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
    R: VecZnxToMut,
    A: VecZnxBigToRef<BE>,
    BE: Backend<ScalarBig = i64>
        + ZnxNormalizeFirstStepCarryOnly
        + ZnxNormalizeMiddleStepCarryOnly
        + ZnxNormalizeMiddleStep
        + ZnxNormalizeFinalStep
        + ZnxNormalizeFirstStep
        + ZnxZero,
{
    let a: VecZnxBig<&[u8], _> = a.to_ref();
    let a_vznx: VecZnx<&[u8]> = VecZnx {
        data: a.data,
        n: a.n,
        cols: a.cols,
        size: a.size,
        max_size: a.max_size,
    };

    vec_znx_normalize::<_, _, BE>(basek, res, res_col, &a_vznx, a_col, carry);
}

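A sketch of the carry-buffer contract, not part of the diff: the caller provides one i64 per coefficient, which is exactly what vec_znx_big_normalize_tmp_bytes accounts for. The helper name and bounds below are illustrative:

fn normalize_example<B>(basek: usize, n: usize, res: &mut VecZnx<Vec<u8>>, a: &VecZnxBig<Vec<u8>, B>)
where
    B: Backend<ScalarBig = i64>
        + ZnxNormalizeFirstStep
        + ZnxNormalizeFirstStepCarryOnly
        + ZnxNormalizeMiddleStep
        + ZnxNormalizeMiddleStepCarryOnly
        + ZnxNormalizeFinalStep
        + ZnxZero,
    VecZnxBig<Vec<u8>, B>: VecZnxBigToRef<B>,
{
    // n i64 words of carry, i.e. vec_znx_big_normalize_tmp_bytes(n) bytes.
    let mut carry: Vec<i64> = vec![0i64; n];
    vec_znx_big_normalize::<_, _, B>(basek, res, 0, a, 0, &mut carry);
}
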
pub fn vec_znx_big_add_normal_ref<R, B: Backend<ScalarBig = i64>>(
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
source: &mut Source,
|
||||
) where
|
||||
R: VecZnxBigToMut<B>,
|
||||
{
|
||||
let mut res: VecZnxBig<&mut [u8], B> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
|
||||
znx_add_normal_f64_ref(
|
||||
res.at_mut(res_col, limb),
|
||||
sigma * scale,
|
||||
bound * scale,
|
||||
source,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn test_vec_znx_big_add_normal<B>(module: &Module<B>)
|
||||
where
|
||||
Module<B>: VecZnxBigAddNormal<B>,
|
||||
B: Backend<ScalarBig = i64> + VecZnxBigAllocBytesImpl<B>,
|
||||
{
|
||||
let n: usize = module.n();
|
||||
let basek: usize = 17;
|
||||
let k: usize = 2 * 17;
|
||||
let size: usize = 5;
|
||||
let sigma: f64 = 3.2;
|
||||
let bound: f64 = 6.0 * sigma;
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
let zero: Vec<i64> = vec![0; n];
|
||||
let k_f64: f64 = (1u64 << k as u64) as f64;
|
||||
let sqrt2: f64 = SQRT_2;
|
||||
(0..cols).for_each(|col_i| {
|
||||
let mut a: VecZnxBig<Vec<u8>, B> = VecZnxBig::alloc(n, cols, size);
|
||||
module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
|
||||
module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
|
||||
(0..cols).for_each(|col_j| {
|
||||
if col_j != col_i {
|
||||
(0..size).for_each(|limb_i| {
|
||||
assert_eq!(a.at(col_j, limb_i), zero);
|
||||
})
|
||||
} else {
|
||||
let std: f64 = a.std(basek, col_i) * k_f64;
|
||||
assert!(
|
||||
(std - sigma * sqrt2).abs() < 0.1,
|
||||
"std={} ~!= {}",
|
||||
std,
|
||||
sigma * sqrt2
|
||||
);
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
/// R <- A - B
|
||||
pub fn vec_znx_big_sub<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarBig = i64> + ZnxSub + ZnxNegate + ZnxZero + ZnxCopy,
|
||||
R: VecZnxBigToMut<BE>,
|
||||
A: VecZnxBigToRef<BE>,
|
||||
B: VecZnxBigToRef<BE>,
|
||||
{
|
||||
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
|
||||
let a: VecZnxBig<&[u8], BE> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], BE> = b.to_ref();
|
||||
|
||||
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
|
||||
data: res.data,
|
||||
n: res.n,
|
||||
cols: res.cols,
|
||||
size: res.size,
|
||||
max_size: res.max_size,
|
||||
};
|
||||
|
||||
let a_vznx: VecZnx<&[u8]> = VecZnx {
|
||||
data: a.data,
|
||||
n: a.n,
|
||||
cols: a.cols,
|
||||
size: a.size,
|
||||
max_size: a.max_size,
|
||||
};
|
||||
|
||||
let b_vznx: VecZnx<&[u8]> = VecZnx {
|
||||
data: b.data,
|
||||
n: b.n,
|
||||
cols: b.cols,
|
||||
size: b.size,
|
||||
max_size: b.max_size,
|
||||
};
|
||||
|
||||
vec_znx_sub::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, &b_vznx, b_col);
|
||||
}
|
||||
|
||||
/// R <- R - A
|
||||
pub fn vec_znx_big_sub_ab_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarBig = i64> + ZnxSubABInplace,
|
||||
R: VecZnxBigToMut<BE>,
|
||||
A: VecZnxBigToRef<BE>,
|
||||
{
|
||||
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
|
||||
let a: VecZnxBig<&[u8], BE> = a.to_ref();
|
||||
|
||||
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
|
||||
data: res.data,
|
||||
n: res.n,
|
||||
cols: res.cols,
|
||||
size: res.size,
|
||||
max_size: res.max_size,
|
||||
};
|
||||
|
||||
let a_vznx: VecZnx<&[u8]> = VecZnx {
|
||||
data: a.data,
|
||||
n: a.n,
|
||||
cols: a.cols,
|
||||
size: a.size,
|
||||
max_size: a.max_size,
|
||||
};
|
||||
|
||||
vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
|
||||
}
|
||||
|
||||
/// R <- A - R
|
||||
pub fn vec_znx_big_sub_ba_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarBig = i64> + ZnxSubBAInplace + ZnxNegateInplace,
|
||||
R: VecZnxBigToMut<BE>,
|
||||
A: VecZnxBigToRef<BE>,
|
||||
{
|
||||
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
|
||||
let a: VecZnxBig<&[u8], BE> = a.to_ref();
|
||||
|
||||
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
|
||||
data: res.data,
|
||||
n: res.n,
|
||||
cols: res.cols,
|
||||
size: res.size,
|
||||
max_size: res.max_size,
|
||||
};
|
||||
|
||||
let a_vznx: VecZnx<&[u8]> = VecZnx {
|
||||
data: a.data,
|
||||
n: a.n,
|
||||
cols: a.cols,
|
||||
size: a.size,
|
||||
max_size: a.max_size,
|
||||
};
|
||||
|
||||
vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
|
||||
}
|
||||
|
||||
/// R <- A - B
|
||||
pub fn vec_znx_big_sub_small_a<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarBig = i64> + ZnxSub + ZnxNegate + ZnxZero + ZnxCopy,
|
||||
R: VecZnxBigToMut<BE>,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxBigToRef<BE>,
|
||||
{
|
||||
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
|
||||
let b: VecZnxBig<&[u8], BE> = b.to_ref();
|
||||
|
||||
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
|
||||
data: res.data,
|
||||
n: res.n,
|
||||
cols: res.cols,
|
||||
size: res.size,
|
||||
max_size: res.max_size,
|
||||
};
|
||||
|
||||
let b_vznx: VecZnx<&[u8]> = VecZnx {
|
||||
data: b.data,
|
||||
n: b.n,
|
||||
cols: b.cols,
|
||||
size: b.size,
|
||||
max_size: b.max_size,
|
||||
};
|
||||
|
||||
vec_znx_sub::<_, _, _, BE>(&mut res_vznx, res_col, a, a_col, &b_vznx, b_col);
|
||||
}
|
||||
|
||||
/// R <- A - B
|
||||
pub fn vec_znx_big_sub_small_b<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarBig = i64> + ZnxSub + ZnxNegate + ZnxZero + ZnxCopy,
|
||||
R: VecZnxBigToMut<BE>,
|
||||
A: VecZnxBigToRef<BE>,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
|
||||
let a: VecZnxBig<&[u8], BE> = a.to_ref();
|
||||
|
||||
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
|
||||
data: res.data,
|
||||
n: res.n,
|
||||
cols: res.cols,
|
||||
size: res.size,
|
||||
max_size: res.max_size,
|
||||
};
|
||||
|
||||
let a_vznx: VecZnx<&[u8]> = VecZnx {
|
||||
data: a.data,
|
||||
n: a.n,
|
||||
cols: a.cols,
|
||||
size: a.size,
|
||||
max_size: a.max_size,
|
||||
};
|
||||
|
||||
vec_znx_sub::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, b, b_col);
|
||||
}
|
||||
|
||||
/// R <- R - A
|
||||
pub fn vec_znx_big_sub_small_a_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarBig = i64> + ZnxSubABInplace,
|
||||
R: VecZnxBigToMut<BE>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
|
||||
|
||||
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
|
||||
data: res.data,
|
||||
n: res.n,
|
||||
cols: res.cols,
|
||||
size: res.size,
|
||||
max_size: res.max_size,
|
||||
};
|
||||
|
||||
vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
|
||||
}
|
||||
|
||||
/// R <- A - R
|
||||
pub fn vec_znx_big_sub_small_b_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarBig = i64> + ZnxSubBAInplace + ZnxNegateInplace,
|
||||
R: VecZnxBigToMut<BE>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
|
||||
|
||||
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
|
||||
data: res.data,
|
||||
n: res.n,
|
||||
cols: res.cols,
|
||||
size: res.size,
|
||||
max_size: res.max_size,
|
||||
};
|
||||
|
||||
vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
|
||||
}
|
||||
369
poulpy-hal/src/reference/fft64/vec_znx_dft.rs
Normal file
@@ -0,0 +1,369 @@
use bytemuck::cast_slice_mut;

use crate::{
    layouts::{
        Backend, Data, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos,
        ZnxView, ZnxViewMut,
    },
    reference::{
        fft64::reim::{
            ReimAdd, ReimAddInplace, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimNegate,
            ReimNegateInplace, ReimSub, ReimSubABInplace, ReimSubBAInplace, ReimToZnx, ReimToZnxInplace, ReimZero,
        },
        znx::ZnxZero,
    },
};

pub fn vec_znx_dft_add<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    BE: Backend<ScalarPrep = f64> + ReimAdd + ReimCopy + ReimZero,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
    B: VecZnxDftToRef<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: VecZnxDft<&[u8], BE> = a.to_ref();
    let b: VecZnxDft<&[u8], BE> = b.to_ref();

    #[cfg(debug_assertions)]
    {
        assert_eq!(a.n(), res.n());
        assert_eq!(b.n(), res.n());
    }

    let res_size: usize = res.size();
    let a_size: usize = a.size();
    let b_size: usize = b.size();

    if a_size <= b_size {
        let sum_size: usize = a_size.min(res_size);
        let cpy_size: usize = b_size.min(res_size);

        for j in 0..sum_size {
            BE::reim_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
        }

        for j in sum_size..cpy_size {
            BE::reim_copy(res.at_mut(res_col, j), b.at(b_col, j));
        }

        for j in cpy_size..res_size {
            BE::reim_zero(res.at_mut(res_col, j));
        }
    } else {
        let sum_size: usize = b_size.min(res_size);
        let cpy_size: usize = a_size.min(res_size);

        for j in 0..sum_size {
            BE::reim_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
        }

        for j in sum_size..cpy_size {
            BE::reim_copy(res.at_mut(res_col, j), a.at(a_col, j));
        }

        for j in cpy_size..res_size {
            BE::reim_zero(res.at_mut(res_col, j));
        }
    }
}

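An illustrative test, not part of the diff, spelling out the sum/copy/zero split performed above when the operands have different limb counts:

#[cfg(test)]
mod vec_znx_dft_add_size_sketch {
    // With a_size = 2, b_size = 4, res_size = 5:
    //   limbs 0..2 hold a[j] + b[j], limbs 2..4 copy b[j], limb 4 is zeroed.
    #[test]
    fn size_split() {
        let (a_size, b_size, res_size) = (2usize, 4usize, 5usize);
        assert!(a_size <= b_size);
        let sum_size = a_size.min(res_size);
        let cpy_size = b_size.min(res_size);
        assert_eq!((sum_size, cpy_size), (2, 4));
    }
}
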
pub fn vec_znx_dft_add_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64> + ReimAddInplace,
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: VecZnxDftToRef<BE>,
|
||||
{
|
||||
let a: VecZnxDft<&[u8], BE> = a.to_ref();
|
||||
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
|
||||
let sum_size: usize = a_size.min(res_size);
|
||||
|
||||
for j in 0..sum_size {
|
||||
BE::reim_add_inplace(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_dft_copy<R, A, BE>(step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64> + ReimCopy + ReimZero,
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: VecZnxDftToRef<BE>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], BE> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), a.n())
|
||||
}
|
||||
|
||||
let steps: usize = a.size().div_ceil(step);
|
||||
let min_steps: usize = res.size().min(steps);
|
||||
|
||||
(0..min_steps).for_each(|j| {
|
||||
let limb: usize = offset + j * step;
|
||||
if limb < a.size() {
|
||||
BE::reim_copy(res.at_mut(res_col, j), a.at(a_col, limb));
|
||||
}
|
||||
});
|
||||
(min_steps..res.size()).for_each(|j| {
|
||||
BE::reim_zero(res.at_mut(res_col, j));
|
||||
})
|
||||
}
|
||||
|
||||
pub fn vec_znx_dft_apply<R, A, BE>(
|
||||
table: &ReimFFTTable<f64>,
|
||||
step: usize,
|
||||
offset: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + ReimZero,
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(step > 0);
|
||||
assert_eq!(table.m() << 1, res.n());
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
|
||||
let a_size: usize = a.size();
|
||||
let res_size: usize = res.size();
|
||||
|
||||
let steps: usize = a_size.div_ceil(step);
|
||||
let min_steps: usize = res_size.min(steps);
|
||||
|
||||
for j in 0..min_steps {
|
||||
let limb = offset + j * step;
|
||||
if limb < a_size {
|
||||
BE::reim_from_znx(res.at_mut(res_col, j), a.at(a_col, limb));
|
||||
BE::reim_dft_execute(table, res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
(min_steps..res.size()).for_each(|j| {
|
||||
BE::reim_zero(res.at_mut(res_col, j));
|
||||
});
|
||||
}
|
||||
|
||||
pub fn vec_znx_idft_apply<R, A, BE>(table: &ReimIFFTTable<f64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64, ScalarBig = i64>
|
||||
+ ReimDFTExecute<ReimIFFTTable<f64>, f64>
|
||||
+ ReimCopy
|
||||
+ ReimToZnxInplace
|
||||
+ ZnxZero,
|
||||
R: VecZnxBigToMut<BE>,
|
||||
A: VecZnxDftToRef<BE>,
|
||||
{
|
||||
let mut res: VecZnxBig<&mut [u8], BE> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], BE> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(table.m() << 1, res.n());
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let min_size: usize = res_size.min(a.size());
|
||||
|
||||
let divisor: f64 = table.m() as f64;
|
||||
|
||||
for j in 0..min_size {
|
||||
let res_slice_f64: &mut [f64] = cast_slice_mut(res.at_mut(res_col, j));
|
||||
BE::reim_copy(res_slice_f64, a.at(a_col, j));
|
||||
BE::reim_dft_execute(table, res_slice_f64);
|
||||
BE::reim_to_znx_inplace(res_slice_f64, divisor);
|
||||
}
|
||||
|
||||
for j in min_size..res_size {
|
||||
BE::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_idft_apply_tmpa<R, A, BE>(table: &ReimIFFTTable<f64>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64, ScalarBig = i64> + ReimDFTExecute<ReimIFFTTable<f64>, f64> + ReimToZnx + ZnxZero,
|
||||
R: VecZnxBigToMut<BE>,
|
||||
A: VecZnxDftToMut<BE>,
|
||||
{
|
||||
let mut res: VecZnxBig<&mut [u8], BE> = res.to_mut();
|
||||
let mut a: VecZnxDft<&mut [u8], BE> = a.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(table.m() << 1, res.n());
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
|
||||
let res_size = res.size();
|
||||
let min_size: usize = res_size.min(a.size());
|
||||
|
||||
let divisor: f64 = table.m() as f64;
|
||||
|
||||
for j in 0..min_size {
|
||||
BE::reim_dft_execute(table, a.at_mut(a_col, j));
|
||||
BE::reim_to_znx(res.at_mut(res_col, j), divisor, a.at(a_col, j));
|
||||
}
|
||||
|
||||
for j in min_size..res_size {
|
||||
BE::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_idft_apply_consume<D: Data, BE>(table: &ReimIFFTTable<f64>, mut res: VecZnxDft<D, BE>) -> VecZnxBig<D, BE>
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64, ScalarBig = i64> + ReimDFTExecute<ReimIFFTTable<f64>, f64> + ReimToZnxInplace,
|
||||
VecZnxDft<D, BE>: VecZnxDftToMut<BE>,
|
||||
{
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(table.m() << 1, res.n());
|
||||
}
|
||||
|
||||
let divisor: f64 = table.m() as f64;
|
||||
|
||||
for i in 0..res.cols() {
|
||||
for j in 0..res.size() {
|
||||
BE::reim_dft_execute(table, res.at_mut(i, j));
|
||||
BE::reim_to_znx_inplace(res.at_mut(i, j), divisor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
res.into_big()
|
||||
}
|
||||
|
||||
pub fn vec_znx_dft_sub<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64> + ReimSub + ReimNegate + ReimZero + ReimCopy,
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: VecZnxDftToRef<BE>,
|
||||
B: VecZnxDftToRef<BE>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], BE> = a.to_ref();
|
||||
let b: VecZnxDft<&[u8], BE> = b.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
assert_eq!(b.n(), res.n());
|
||||
}
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
let b_size: usize = b.size();
|
||||
|
||||
if a_size <= b_size {
|
||||
let sum_size: usize = a_size.min(res_size);
|
||||
let cpy_size: usize = b_size.min(res_size);
|
||||
|
||||
for j in 0..sum_size {
|
||||
BE::reim_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
|
||||
}
|
||||
|
||||
for j in sum_size..cpy_size {
|
||||
BE::reim_negate(res.at_mut(res_col, j), b.at(b_col, j));
|
||||
}
|
||||
|
||||
for j in cpy_size..res_size {
|
||||
BE::reim_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
} else {
|
||||
let sum_size: usize = b_size.min(res_size);
|
||||
let cpy_size: usize = a_size.min(res_size);
|
||||
|
||||
for j in 0..sum_size {
|
||||
BE::reim_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
|
||||
}
|
||||
|
||||
for j in sum_size..cpy_size {
|
||||
BE::reim_copy(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
|
||||
for j in cpy_size..res_size {
|
||||
BE::reim_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_dft_sub_ab_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64> + ReimSubABInplace,
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: VecZnxDftToRef<BE>,
|
||||
{
|
||||
let a: VecZnxDft<&[u8], BE> = a.to_ref();
|
||||
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
|
||||
let sum_size: usize = a_size.min(res_size);
|
||||
|
||||
for j in 0..sum_size {
|
||||
BE::reim_sub_ab_inplace(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_dft_sub_ba_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64> + ReimSubBAInplace + ReimNegateInplace,
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: VecZnxDftToRef<BE>,
|
||||
{
|
||||
let a: VecZnxDft<&[u8], BE> = a.to_ref();
|
||||
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
|
||||
let sum_size: usize = a_size.min(res_size);
|
||||
|
||||
for j in 0..sum_size {
|
||||
BE::reim_sub_ba_inplace(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
|
||||
for j in sum_size..res_size {
|
||||
BE::reim_negate_inplace(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_dft_zero<R, BE>(res: &mut R)
|
||||
where
|
||||
R: VecZnxDftToMut<BE>,
|
||||
BE: Backend<ScalarPrep = f64> + ReimZero,
|
||||
{
|
||||
BE::reim_zero(res.to_mut().raw_mut());
|
||||
}
|
||||
365
poulpy-hal/src/reference/fft64/vmp.rs
Normal file
@@ -0,0 +1,365 @@
|
||||
use crate::{
|
||||
cast_mut,
|
||||
layouts::{MatZnx, MatZnxToRef, VecZnx, VecZnxToRef, VmpPMatToMut, ZnxView, ZnxViewMut},
|
||||
oep::VecZnxDftAllocBytesImpl,
|
||||
reference::fft64::{
|
||||
reim::{ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimZero},
|
||||
reim4::{Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks},
|
||||
vec_znx_dft::vec_znx_dft_apply,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::layouts::{Backend, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatToRef, ZnxInfos};
|
||||
|
||||
pub fn vmp_prepare_tmp_bytes(n: usize) -> usize {
|
||||
n * size_of::<i64>()
|
||||
}
|
||||
|
||||
pub fn vmp_prepare<R, A, BE>(table: &ReimFFTTable<f64>, pmat: &mut R, mat: &A, tmp: &mut [f64])
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + Reim4Extract1Blk,
|
||||
R: VmpPMatToMut<BE>,
|
||||
A: MatZnxToRef,
|
||||
{
|
||||
let mut res: crate::layouts::VmpPMat<&mut [u8], BE> = pmat.to_mut();
|
||||
let a: MatZnx<&[u8]> = mat.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
assert_eq!(
|
||||
res.cols_in(),
|
||||
a.cols_in(),
|
||||
"res.cols_in: {} != a.cols_in: {}",
|
||||
res.cols_in(),
|
||||
a.cols_in()
|
||||
);
|
||||
assert_eq!(
|
||||
res.rows(),
|
||||
a.rows(),
|
||||
"res.rows: {} != a.rows: {}",
|
||||
res.rows(),
|
||||
a.rows()
|
||||
);
|
||||
assert_eq!(
|
||||
res.cols_out(),
|
||||
a.cols_out(),
|
||||
"res.cols_out: {} != a.cols_out: {}",
|
||||
res.cols_out(),
|
||||
a.cols_out()
|
||||
);
|
||||
assert_eq!(
|
||||
res.size(),
|
||||
a.size(),
|
||||
"res.size: {} != a.size: {}",
|
||||
res.size(),
|
||||
a.size()
|
||||
);
|
||||
}
|
||||
|
||||
let nrows: usize = a.cols_in() * a.rows();
|
||||
let ncols: usize = a.cols_out() * a.size();
|
||||
vmp_prepare_core::<BE>(table, res.raw_mut(), a.raw(), nrows, ncols, tmp);
|
||||
}
|
||||
|
||||
pub(crate) fn vmp_prepare_core<REIM>(
|
||||
table: &ReimFFTTable<f64>,
|
||||
pmat: &mut [f64],
|
||||
mat: &[i64],
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
tmp: &mut [f64],
|
||||
) where
|
||||
REIM: ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + Reim4Extract1Blk,
|
||||
{
|
||||
let m: usize = table.m();
|
||||
let n: usize = m << 1;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(n >= 8);
|
||||
assert_eq!(mat.len(), n * nrows * ncols);
|
||||
assert_eq!(pmat.len(), n * nrows * ncols);
|
||||
assert_eq!(tmp.len(), vmp_prepare_tmp_bytes(n) / size_of::<i64>())
|
||||
}
|
||||
|
||||
let offset: usize = nrows * ncols * 8;
|
||||
|
||||
for row_i in 0..nrows {
|
||||
for col_i in 0..ncols {
|
||||
let pos: usize = n * (row_i * ncols + col_i);
|
||||
|
||||
REIM::reim_from_znx(tmp, &mat[pos..pos + n]);
|
||||
REIM::reim_dft_execute(table, tmp);
|
||||
|
||||
let dst: &mut [f64] = if col_i == (ncols - 1) && !ncols.is_multiple_of(2) {
|
||||
&mut pmat[col_i * nrows * 8 + row_i * 8..]
|
||||
} else {
|
||||
&mut pmat[(col_i / 2) * (nrows * 16) + row_i * 16 + (col_i % 2) * 8..]
|
||||
};
|
||||
|
||||
for blk_i in 0..m >> 2 {
|
||||
REIM::reim4_extract_1blk(m, 1, blk_i, &mut dst[blk_i * offset..], tmp);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vmp_apply_dft_tmp_bytes(n: usize, a_size: usize, prows: usize, pcols_in: usize) -> usize {
|
||||
let row_max: usize = (a_size).min(prows);
|
||||
(16 + (n + 8) * row_max * pcols_in) * size_of::<f64>()
|
||||
}
|
||||
|
||||
pub fn vmp_apply_dft<R, A, M, BE>(table: &ReimFFTTable<f64>, res: &mut R, a: &A, pmat: &M, tmp_bytes: &mut [f64])
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64>
|
||||
+ VecZnxDftAllocBytesImpl<BE>
|
||||
+ ReimDFTExecute<ReimFFTTable<f64>, f64>
|
||||
+ ReimZero
|
||||
+ Reim4Extract1Blk
|
||||
+ Reim4Mat1ColProd
|
||||
+ Reim4Mat2Cols2ndColProd
|
||||
+ Reim4Mat2ColsProd
|
||||
+ Reim4Save2Blks
|
||||
+ Reim4Save1Blk
|
||||
+ ReimFromZnx,
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: VecZnxToRef,
|
||||
M: VmpPMatToRef<BE>,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let pmat: VmpPMat<&[u8], BE> = pmat.to_ref();
|
||||
|
||||
let n: usize = a.n();
|
||||
let cols: usize = pmat.cols_in();
|
||||
let size: usize = a.size().min(pmat.rows());
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(tmp_bytes.len() >= vmp_apply_dft_tmp_bytes(n, size, pmat.rows(), cols));
|
||||
assert!(a.cols() <= cols);
|
||||
}
|
||||
|
||||
let (data, tmp_bytes) = tmp_bytes.split_at_mut(BE::vec_znx_dft_alloc_bytes_impl(n, cols, size));
|
||||
|
||||
let mut a_dft: VecZnxDft<&mut [u8], BE> = VecZnxDft::from_data(cast_mut(data), n, cols, size);
|
||||
|
||||
let offset: usize = cols - a.cols();
|
||||
for j in 0..cols {
|
||||
vec_znx_dft_apply(table, 1, 0, &mut a_dft, j, &a, offset + j);
|
||||
}
|
||||
|
||||
vmp_apply_dft_to_dft(res, &a_dft, &pmat, tmp_bytes);
|
||||
}
|
||||
|
||||
pub fn vmp_apply_dft_to_dft_tmp_bytes(a_size: usize, prows: usize, pcols_in: usize) -> usize {
    let row_max: usize = a_size.min(prows);
    (16 + 8 * row_max * pcols_in) * size_of::<f64>()
}

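A hypothetical helper, not part of the diff, showing how a caller would turn this byte count into the f64 scratch slice expected by vmp_apply_dft_to_dft:

fn alloc_vmp_scratch(a_size: usize, prows: usize, pcols_in: usize) -> Vec<f64> {
    // 16 f64 for the per-block accumulator plus 8 f64 per retained (row, input column) pair.
    let bytes = vmp_apply_dft_to_dft_tmp_bytes(a_size, prows, pcols_in);
    vec![0.0f64; bytes / size_of::<f64>()]
}
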
pub fn vmp_apply_dft_to_dft<R, A, M, BE>(res: &mut R, a: &A, pmat: &M, tmp_bytes: &mut [f64])
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64>
|
||||
+ ReimZero
|
||||
+ Reim4Extract1Blk
|
||||
+ Reim4Mat1ColProd
|
||||
+ Reim4Mat2Cols2ndColProd
|
||||
+ Reim4Mat2ColsProd
|
||||
+ Reim4Save2Blks
|
||||
+ Reim4Save1Blk,
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: VecZnxDftToRef<BE>,
|
||||
M: VmpPMatToRef<BE>,
|
||||
{
|
||||
use crate::layouts::{ZnxView, ZnxViewMut};
|
||||
|
||||
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], BE> = a.to_ref();
|
||||
let pmat: VmpPMat<&[u8], BE> = pmat.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), pmat.n());
|
||||
assert_eq!(a.n(), pmat.n());
|
||||
assert_eq!(res.cols(), pmat.cols_out());
|
||||
assert_eq!(a.cols(), pmat.cols_in());
|
||||
}
|
||||
|
||||
let n: usize = res.n();
|
||||
let nrows: usize = pmat.cols_in() * pmat.rows();
|
||||
let ncols: usize = pmat.cols_out() * pmat.size();
|
||||
|
||||
let pmat_raw: &[f64] = pmat.raw();
|
||||
let a_raw: &[f64] = a.raw();
|
||||
let res_raw: &mut [f64] = res.raw_mut();
|
||||
|
||||
vmp_apply_dft_to_dft_core::<true, BE>(n, res_raw, a_raw, pmat_raw, 0, nrows, ncols, tmp_bytes)
|
||||
}
|
||||
|
||||
pub fn vmp_apply_dft_to_dft_add<R, A, M, BE>(res: &mut R, a: &A, pmat: &M, limb_offset: usize, tmp_bytes: &mut [f64])
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64>
|
||||
+ ReimZero
|
||||
+ Reim4Extract1Blk
|
||||
+ Reim4Mat1ColProd
|
||||
+ Reim4Mat2Cols2ndColProd
|
||||
+ Reim4Mat2ColsProd
|
||||
+ Reim4Save2Blks
|
||||
+ Reim4Save1Blk,
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: VecZnxDftToRef<BE>,
|
||||
M: VmpPMatToRef<BE>,
|
||||
{
|
||||
use crate::layouts::{ZnxView, ZnxViewMut};
|
||||
|
||||
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], BE> = a.to_ref();
|
||||
let pmat: VmpPMat<&[u8], BE> = pmat.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), pmat.n());
|
||||
assert_eq!(a.n(), pmat.n());
|
||||
assert_eq!(res.cols(), pmat.cols_out());
|
||||
assert_eq!(a.cols(), pmat.cols_in());
|
||||
}
|
||||
|
||||
let n: usize = res.n();
|
||||
let nrows: usize = pmat.cols_in() * pmat.rows();
|
||||
let ncols: usize = pmat.cols_out() * pmat.size();
|
||||
|
||||
let pmat_raw: &[f64] = pmat.raw();
|
||||
let a_raw: &[f64] = a.raw();
|
||||
let res_raw: &mut [f64] = res.raw_mut();
|
||||
|
||||
vmp_apply_dft_to_dft_core::<false, BE>(
|
||||
n,
|
||||
res_raw,
|
||||
a_raw,
|
||||
pmat_raw,
|
||||
limb_offset,
|
||||
nrows,
|
||||
ncols,
|
||||
tmp_bytes,
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn vmp_apply_dft_to_dft_core<const OVERWRITE: bool, REIM>(
|
||||
n: usize,
|
||||
res: &mut [f64],
|
||||
a: &[f64],
|
||||
pmat: &[f64],
|
||||
limb_offset: usize,
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
tmp_bytes: &mut [f64],
|
||||
) where
|
||||
REIM: ReimZero
|
||||
+ Reim4Extract1Blk
|
||||
+ Reim4Mat1ColProd
|
||||
+ Reim4Mat2Cols2ndColProd
|
||||
+ Reim4Mat2ColsProd
|
||||
+ Reim4Save2Blks
|
||||
+ Reim4Save1Blk,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(n >= 8);
|
||||
assert!(n.is_power_of_two());
|
||||
assert_eq!(pmat.len(), n * nrows * ncols);
|
||||
assert!(res.len() & (n - 1) == 0);
|
||||
assert!(a.len() & (n - 1) == 0);
|
||||
}
|
||||
|
||||
let a_size: usize = a.len() / n;
|
||||
let res_size: usize = res.len() / n;
|
||||
|
||||
let m: usize = n >> 1;
|
||||
|
||||
let (mat2cols_output, extracted_blk) = tmp_bytes.split_at_mut(16);
|
||||
|
||||
let row_max: usize = nrows.min(a_size);
|
||||
let col_max: usize = ncols.min(res_size);
|
||||
|
||||
if limb_offset >= col_max {
|
||||
if OVERWRITE {
|
||||
REIM::reim_zero(res);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for blk_i in 0..(m >> 2) {
|
||||
let mat_blk_start: &[f64] = &pmat[blk_i * (8 * nrows * ncols)..];
|
||||
|
||||
REIM::reim4_extract_1blk(m, row_max, blk_i, extracted_blk, a);
|
||||
|
||||
if limb_offset.is_multiple_of(2) {
|
||||
for (col_res, col_pmat) in (0..).step_by(2).zip((limb_offset..col_max - 1).step_by(2)) {
|
||||
let col_offset: usize = col_pmat * (8 * nrows);
|
||||
REIM::reim4_mat2cols_prod(
|
||||
row_max,
|
||||
mat2cols_output,
|
||||
extracted_blk,
|
||||
&mat_blk_start[col_offset..],
|
||||
);
|
||||
REIM::reim4_save_2blks::<OVERWRITE>(m, blk_i, &mut res[col_res * n..], mat2cols_output);
|
||||
}
|
||||
} else {
|
||||
let col_offset: usize = (limb_offset - 1) * (8 * nrows);
|
||||
REIM::reim4_mat2cols_2ndcol_prod(
|
||||
row_max,
|
||||
mat2cols_output,
|
||||
extracted_blk,
|
||||
&mat_blk_start[col_offset..],
|
||||
);
|
||||
|
||||
REIM::reim4_save_1blk::<OVERWRITE>(m, blk_i, res, mat2cols_output);
|
||||
|
||||
for (col_res, col_pmat) in (1..)
|
||||
.step_by(2)
|
||||
.zip((limb_offset + 1..col_max - 1).step_by(2))
|
||||
{
|
||||
let col_offset: usize = col_pmat * (8 * nrows);
|
||||
REIM::reim4_mat2cols_prod(
|
||||
row_max,
|
||||
mat2cols_output,
|
||||
extracted_blk,
|
||||
&mat_blk_start[col_offset..],
|
||||
);
|
||||
REIM::reim4_save_2blks::<OVERWRITE>(m, blk_i, &mut res[col_res * n..], mat2cols_output);
|
||||
}
|
||||
}
|
||||
|
||||
if !col_max.is_multiple_of(2) {
|
||||
let last_col: usize = col_max - 1;
|
||||
let col_offset: usize = last_col * (8 * nrows);
|
||||
|
||||
if last_col >= limb_offset {
|
||||
if ncols == col_max {
|
||||
REIM::reim4_mat1col_prod(
|
||||
row_max,
|
||||
mat2cols_output,
|
||||
extracted_blk,
|
||||
&mat_blk_start[col_offset..],
|
||||
);
|
||||
} else {
|
||||
REIM::reim4_mat2cols_prod(
|
||||
row_max,
|
||||
mat2cols_output,
|
||||
extracted_blk,
|
||||
&mat_blk_start[col_offset..],
|
||||
);
|
||||
}
|
||||
REIM::reim4_save_1blk::<OVERWRITE>(
|
||||
m,
|
||||
blk_i,
|
||||
&mut res[(last_col - limb_offset) * n..],
|
||||
mat2cols_output,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
REIM::reim_zero(&mut res[col_max * n..]);
|
||||
}
|
||||
4
poulpy-hal/src/reference/mod.rs
Normal file
@@ -0,0 +1,4 @@
pub mod fft64;
pub mod vec_znx;
pub mod zn;
pub mod znx;
177
poulpy-hal/src/reference/vec_znx/add.rs
Normal file
@@ -0,0 +1,177 @@
use std::hint::black_box;

use criterion::{BenchmarkId, Criterion};

use crate::{
    api::{ModuleNew, VecZnxAdd, VecZnxAddInplace},
    layouts::{Backend, FillUniform, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
    reference::znx::{ZnxAdd, ZnxAddInplace, ZnxCopy, ZnxZero},
    source::Source,
};

pub fn vec_znx_add<R, A, B, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxToRef,
|
||||
ZNXARI: ZnxAdd + ZnxCopy + ZnxZero,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
assert_eq!(b.n(), res.n());
|
||||
}
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
let b_size: usize = b.size();
|
||||
|
||||
if a_size <= b_size {
|
||||
let sum_size: usize = a_size.min(res_size);
|
||||
let cpy_size: usize = b_size.min(res_size);
|
||||
|
||||
for j in 0..sum_size {
|
||||
ZNXARI::znx_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
|
||||
}
|
||||
|
||||
for j in sum_size..cpy_size {
|
||||
ZNXARI::znx_copy(res.at_mut(res_col, j), b.at(b_col, j));
|
||||
}
|
||||
|
||||
for j in cpy_size..res_size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
} else {
|
||||
let sum_size: usize = b_size.min(res_size);
|
||||
let cpy_size: usize = a_size.min(res_size);
|
||||
|
||||
for j in 0..sum_size {
|
||||
ZNXARI::znx_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
|
||||
}
|
||||
|
||||
for j in sum_size..cpy_size {
|
||||
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
|
||||
for j in cpy_size..res_size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_add_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
ZNXARI: ZnxAddInplace,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
|
||||
let sum_size: usize = a_size.min(res_size);
|
||||
|
||||
for j in 0..sum_size {
|
||||
ZNXARI::znx_add_inplace(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_add<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxAdd + ModuleNew<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_add::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxAdd + ModuleNew<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut c: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
// Fill a with random i64
|
||||
a.fill_uniform(50, &mut source);
|
||||
b.fill_uniform(50, &mut source);
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_add(&mut c, i, &a, i, &b, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_add_inplace<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxAddInplace + ModuleNew<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_add_inplace::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxAddInplace + ModuleNew<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
// Fill a with random i64
|
||||
a.fill_uniform(50, &mut source);
|
||||
b.fill_uniform(50, &mut source);
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_add_inplace(&mut b, i, &a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
57
poulpy-hal/src/reference/vec_znx/add_scalar.rs
Normal file
@@ -0,0 +1,57 @@
use crate::{
    layouts::{ScalarZnx, ScalarZnxToRef, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
    reference::znx::{ZnxAdd, ZnxAddInplace, ZnxCopy, ZnxZero},
};

pub fn vec_znx_add_scalar<R, A, B, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize, b_limb: usize)
where
    R: VecZnxToMut,
    A: ScalarZnxToRef,
    B: VecZnxToRef,
    ZNXARI: ZnxAdd + ZnxCopy + ZnxZero,
{
    let a: ScalarZnx<&[u8]> = a.to_ref();
    let b: VecZnx<&[u8]> = b.to_ref();
    let mut res: VecZnx<&mut [u8]> = res.to_mut();

    let min_size: usize = b.size().min(res.size());

    #[cfg(debug_assertions)]
    {
        assert!(
            b_limb < min_size,
            "b_limb: {} >= min_size: {}",
            b_limb,
            min_size
        );
    }

    for j in 0..min_size {
        if j == b_limb {
            ZNXARI::znx_add(res.at_mut(res_col, j), a.at(a_col, 0), b.at(b_col, j));
        } else {
            ZNXARI::znx_copy(res.at_mut(res_col, j), b.at(b_col, j));
        }
    }

    for j in min_size..res.size() {
        ZNXARI::znx_zero(res.at_mut(res_col, j));
    }
}

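An illustrative test, not part of the diff, of which result limbs are summed, copied, or zeroed by vec_znx_add_scalar:

#[cfg(test)]
mod vec_znx_add_scalar_sketch {
    #[test]
    fn limb_selection() {
        // b has 3 limbs, res has 4, and the scalar lands on limb 1.
        let (b_size, res_size, b_limb) = (3usize, 4usize, 1usize);
        let min_size = b_size.min(res_size);
        assert!(b_limb < min_size);
        let limbs: Vec<&str> = (0..res_size)
            .map(|j| match j {
                j if j >= min_size => "zeroed",
                j if j == b_limb => "b + a",
                _ => "copy of b",
            })
            .collect();
        assert_eq!(limbs, ["copy of b", "b + a", "copy of b", "zeroed"]);
    }
}
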
pub fn vec_znx_add_scalar_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
ZNXARI: ZnxAddInplace,
|
||||
{
|
||||
let a: ScalarZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(res_limb < res.size());
|
||||
}
|
||||
|
||||
ZNXARI::znx_add_inplace(res.at_mut(res_col, res_limb), a.at(a_col, 0));
|
||||
}
|
||||
150
poulpy-hal/src/reference/vec_znx/automorphism.rs
Normal file
@@ -0,0 +1,150 @@
|
||||
use std::hint::black_box;
|
||||
|
||||
use criterion::{BenchmarkId, Criterion};
|
||||
|
||||
use crate::{
|
||||
api::{
|
||||
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAutomorphism, VecZnxAutomorphismInplace,
|
||||
VecZnxAutomorphismInplaceTmpBytes,
|
||||
},
|
||||
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
reference::znx::{ZnxAutomorphism, ZnxCopy, ZnxZero},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn vec_znx_automorphism_inplace_tmp_bytes(n: usize) -> usize {
    n * size_of::<i64>()
}

pub fn vec_znx_automorphism<R, A, ZNXARI>(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxAutomorphism + ZnxZero,
{
    let a: VecZnx<&[u8]> = a.to_ref();
    let mut res: VecZnx<&mut [u8]> = res.to_mut();

    #[cfg(debug_assertions)]
    {
        use crate::layouts::ZnxInfos;

        assert_eq!(a.n(), res.n());
    }

    let min_size: usize = res.size().min(a.size());

    for j in 0..min_size {
        ZNXARI::znx_automorphism(p, res.at_mut(res_col, j), a.at(a_col, j));
    }

    for j in min_size..res.size() {
        ZNXARI::znx_zero(res.at_mut(res_col, j));
    }
}

pub fn vec_znx_automorphism_inplace<R, ZNXARI>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
where
    R: VecZnxToMut,
    ZNXARI: ZnxAutomorphism + ZnxCopy,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.n(), tmp.len());
    }
    for j in 0..res.size() {
        ZNXARI::znx_automorphism(p, tmp, res.at(res_col, j));
        ZNXARI::znx_copy(res.at_mut(res_col, j), tmp);
    }
}

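Minimal sketch, not part of the diff, of the tmp-buffer contract for the in-place variant (helper name illustrative):

fn automorphism_inplace_example<ZNXARI>(p: i64, res: &mut VecZnx<Vec<u8>>)
where
    ZNXARI: ZnxAutomorphism + ZnxCopy,
{
    // tmp holds n coefficients, matching vec_znx_automorphism_inplace_tmp_bytes(n).
    let mut tmp: Vec<i64> = vec![0i64; res.n()];
    vec_znx_automorphism_inplace::<_, ZNXARI>(p, res, 0, &mut tmp);
}
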
pub fn bench_vec_znx_automorphism<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxAutomorphism + ModuleNew<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_automorphism::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxAutomorphism + ModuleNew<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
// Fill a with random i64
|
||||
a.fill_uniform(50, &mut source);
|
||||
res.fill_uniform(50, &mut source);
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_automorphism(-7, &mut res, i, &a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_automorphism_inplace<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxAutomorphismInplace<B> + VecZnxAutomorphismInplaceTmpBytes + ModuleNew<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_automorphism_inplace::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxAutomorphismInplace<B> + ModuleNew<B> + VecZnxAutomorphismInplaceTmpBytes,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
let mut scratch = ScratchOwned::alloc(module.vec_znx_automorphism_inplace_tmp_bytes());
|
||||
|
||||
// Fill a with random i64
|
||||
res.fill_uniform(50, &mut source);
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_automorphism_inplace(-7, &mut res, i, scratch.borrow());
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
32
poulpy-hal/src/reference/vec_znx/copy.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use crate::{
|
||||
layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
reference::znx::{ZnxCopy, ZnxZero},
|
||||
};
|
||||
|
||||
pub fn vec_znx_copy<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
ZNXARI: ZnxCopy + ZnxZero,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), a.n())
|
||||
}
|
||||
|
||||
let res_size = res.size();
|
||||
let a_size = a.size();
|
||||
|
||||
let min_size = res_size.min(a_size);
|
||||
|
||||
for j in 0..min_size {
|
||||
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
|
||||
for j in min_size..res_size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
49
poulpy-hal/src/reference/vec_znx/merge_rings.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
use crate::{
|
||||
layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos},
|
||||
reference::{
|
||||
vec_znx::{vec_znx_rotate_inplace, vec_znx_switch_ring},
|
||||
znx::{ZnxCopy, ZnxRotate, ZnxSwitchRing, ZnxZero},
|
||||
},
|
||||
};
|
||||
|
||||
pub fn vec_znx_merge_rings_tmp_bytes(n: usize) -> usize {
|
||||
n * size_of::<i64>()
|
||||
}
|
||||
|
||||
pub fn vec_znx_merge_rings<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &[A], a_col: usize, tmp: &mut [i64])
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
ZNXARI: ZnxCopy + ZnxSwitchRing + ZnxRotate + ZnxZero,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
let (n_out, n_in) = (res.n(), a[0].to_ref().n());
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(tmp.len(), res.n());
|
||||
|
||||
debug_assert!(
|
||||
n_out > n_in,
|
||||
"invalid a: output ring degree should be greater"
|
||||
);
|
||||
a[1..].iter().for_each(|ai| {
|
||||
debug_assert_eq!(
|
||||
ai.to_ref().n(),
|
||||
n_in,
|
||||
"invalid input a: all VecZnx must have the same degree"
|
||||
)
|
||||
});
|
||||
|
||||
assert!(n_out.is_multiple_of(n_in));
|
||||
assert_eq!(a.len(), n_out / n_in);
|
||||
}
|
||||
|
||||
a.iter().for_each(|ai| {
|
||||
vec_znx_switch_ring::<_, _, ZNXARI>(&mut res, res_col, ai, a_col);
|
||||
vec_znx_rotate_inplace::<_, ZNXARI>(-1, &mut res, res_col, tmp);
|
||||
});
|
||||
|
||||
vec_znx_rotate_inplace::<_, ZNXARI>(a.len() as i64, &mut res, res_col, tmp);
|
||||
}
|
||||
31
poulpy-hal/src/reference/vec_znx/mod.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
mod add;
|
||||
mod add_scalar;
|
||||
mod automorphism;
|
||||
mod copy;
|
||||
mod merge_rings;
|
||||
mod mul_xp_minus_one;
|
||||
mod negate;
|
||||
mod normalize;
|
||||
mod rotate;
|
||||
mod sampling;
|
||||
mod shift;
|
||||
mod split_ring;
|
||||
mod sub;
|
||||
mod sub_scalar;
|
||||
mod switch_ring;
|
||||
|
||||
pub use add::*;
|
||||
pub use add_scalar::*;
|
||||
pub use automorphism::*;
|
||||
pub use copy::*;
|
||||
pub use merge_rings::*;
|
||||
pub use mul_xp_minus_one::*;
|
||||
pub use negate::*;
|
||||
pub use normalize::*;
|
||||
pub use rotate::*;
|
||||
pub use sampling::*;
|
||||
pub use shift::*;
|
||||
pub use split_ring::*;
|
||||
pub use sub::*;
|
||||
pub use sub_scalar::*;
|
||||
pub use switch_ring::*;
|
||||
136
poulpy-hal/src/reference/vec_znx/mul_xp_minus_one.rs
Normal file
@@ -0,0 +1,136 @@
|
||||
use std::hint::black_box;
|
||||
|
||||
use criterion::{BenchmarkId, Criterion};
|
||||
|
||||
use crate::{
|
||||
api::{
|
||||
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace,
|
||||
VecZnxMulXpMinusOneInplaceTmpBytes,
|
||||
},
|
||||
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
reference::{
|
||||
vec_znx::{vec_znx_rotate, vec_znx_sub_ab_inplace},
|
||||
znx::{ZnxNegate, ZnxRotate, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero},
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes(n: usize) -> usize {
    n * size_of::<i64>()
}

pub fn vec_znx_mul_xp_minus_one<R, A, ZNXARI>(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxRotate + ZnxZero + ZnxSubABInplace,
{
    vec_znx_rotate::<_, _, ZNXARI>(p, res, res_col, a, a_col);
    vec_znx_sub_ab_inplace::<_, _, ZNXARI>(res, res_col, a, a_col);
}

pub fn vec_znx_mul_xp_minus_one_inplace<R, ZNXARI>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
where
    R: VecZnxToMut,
    ZNXARI: ZnxRotate + ZnxNegate + ZnxSubBAInplace,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.n(), tmp.len());
    }
    for j in 0..res.size() {
        ZNXARI::znx_rotate(p, tmp, res.at(res_col, j));
        ZNXARI::znx_sub_ba_inplace(res.at_mut(res_col, j), tmp);
    }
}

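For intuition (illustrative example, not from the diff): in Z[X]/(X^4 + 1) with a = 1 + 2X and p = 1, the two calls above compute X*a - a = (X + 2X^2) - (1 + 2X) = -1 - X + 2X^2, so the coefficient vector (1, 2, 0, 0) becomes (-1, -1, 2, 0).
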
pub fn bench_vec_znx_mul_xp_minus_one<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxMulXpMinusOne + ModuleNew<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_mul_xp_minus_one::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxMulXpMinusOne + ModuleNew<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
// Fill a with random i64
|
||||
a.fill_uniform(50, &mut source);
|
||||
res.fill_uniform(50, &mut source);
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_mul_xp_minus_one(-7, &mut res, i, &a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_mul_xp_minus_one_inplace<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxMulXpMinusOneInplace<B> + VecZnxMulXpMinusOneInplaceTmpBytes + ModuleNew<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_mul_xp_minus_one_inplace::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxMulXpMinusOneInplace<B> + ModuleNew<B> + VecZnxMulXpMinusOneInplaceTmpBytes,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
let mut scratch = ScratchOwned::alloc(module.vec_znx_mul_xp_minus_one_inplace_tmp_bytes());
|
||||
|
||||
// Fill a with random i64
|
||||
res.fill_uniform(50, &mut source);
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_mul_xp_minus_one_inplace(-7, &mut res, i, scratch.borrow());
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
131
poulpy-hal/src/reference/vec_znx/negate.rs
Normal file
@@ -0,0 +1,131 @@
|
||||
use std::hint::black_box;
|
||||
|
||||
use criterion::{BenchmarkId, Criterion};
|
||||
|
||||
use crate::{
|
||||
api::{ModuleNew, VecZnxNegate, VecZnxNegateInplace},
|
||||
layouts::{Backend, FillUniform, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
reference::znx::{ZnxNegate, ZnxNegateInplace, ZnxZero},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn vec_znx_negate<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
ZNXARI: ZnxNegate + ZnxZero,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
|
||||
let min_size: usize = res.size().min(a.size());
|
||||
|
||||
for j in 0..min_size {
|
||||
ZNXARI::znx_negate(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
|
||||
for j in min_size..res.size() {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_negate_inplace<R, ZNXARI>(res: &mut R, res_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
ZNXARI: ZnxNegateInplace,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
for j in 0..res.size() {
|
||||
ZNXARI::znx_negate_inplace(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_negate<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxNegate + ModuleNew<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_negate::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxNegate + ModuleNew<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
// Fill a with random i64
|
||||
a.fill_uniform(50, &mut source);
|
||||
b.fill_uniform(50, &mut source);
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_negate(&mut b, i, &a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_negate_inplace<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxNegateInplace + ModuleNew<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_negate_inplace::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxNegateInplace + ModuleNew<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
// Fill a with random i64
|
||||
a.fill_uniform(50, &mut source);
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_negate_inplace(&mut a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
193
poulpy-hal/src/reference/vec_znx/normalize.rs
Normal file
@@ -0,0 +1,193 @@
|
||||
use std::hint::black_box;
|
||||
|
||||
use criterion::{BenchmarkId, Criterion};
|
||||
|
||||
use crate::{
|
||||
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes},
|
||||
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
reference::znx::{
|
||||
ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
|
||||
ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
|
||||
ZnxZero,
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn vec_znx_normalize_tmp_bytes(n: usize) -> usize {
|
||||
n * size_of::<i64>()
|
||||
}
|
||||
|
||||
pub fn vec_znx_normalize<R, A, ZNXARI>(basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
ZNXARI: ZnxZero
|
||||
+ ZnxNormalizeFirstStepCarryOnly
|
||||
+ ZnxNormalizeMiddleStepCarryOnly
|
||||
+ ZnxNormalizeMiddleStep
|
||||
+ ZnxNormalizeFinalStep
|
||||
+ ZnxNormalizeFirstStep,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(carry.len() >= res.n());
|
||||
}
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let a_size = a.size();
|
||||
|
||||
if a_size > res_size {
|
||||
for j in (res_size..a_size).rev() {
|
||||
if j == a_size - 1 {
|
||||
ZNXARI::znx_normalize_first_step_carry_only(basek, 0, a.at(a_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_carry_only(basek, 0, a.at(a_col, j), carry);
|
||||
}
|
||||
}
|
||||
|
||||
for j in (1..res_size).rev() {
|
||||
ZNXARI::znx_normalize_middle_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
|
||||
}
|
||||
|
||||
ZNXARI::znx_normalize_final_step(basek, 0, res.at_mut(res_col, 0), a.at(a_col, 0), carry);
|
||||
} else {
|
||||
for j in (0..a_size).rev() {
|
||||
if j == a_size - 1 {
|
||||
ZNXARI::znx_normalize_first_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
|
||||
} else if j == 0 {
|
||||
ZNXARI::znx_normalize_final_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
|
||||
}
|
||||
}
|
||||
|
||||
for j in a_size..res_size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
}
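// Illustrative sketch (added; an assumption about the limb layout, not taken
// from this diff): vec_znx_normalize above walks limbs from least significant
// (largest j) to most significant (j = 0), propagating a per-coefficient carry.
// Assuming each limb stores a signed digit in base 2^basek, a scalar analogue
// of that carry propagation is:
fn normalize_digits_sketch(basek: usize, digits: &mut [i64]) -> i64 {
    // digits[0] is taken to be the most significant limb, matching the
    // reverse iteration order used above.
    debug_assert!(basek < 63);
    let base: i64 = 1 << basek;
    let half: i64 = base >> 1;
    let mut carry: i64 = 0;
    for d in digits.iter_mut().rev() {
        let v = *d + carry;
        // Center the digit in [-base/2, base/2).
        let mut r = v % base;
        if r >= half {
            r -= base;
        } else if r < -half {
            r += base;
        }
        *d = r;
        carry = (v - r) >> basek; // exact: v - r is a multiple of base
    }
    // Leftover carry; the final step above drops it, i.e. the value is reduced
    // modulo 2^(basek * digits.len()).
    carry
}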
|
||||
|
||||
pub fn vec_znx_normalize_inplace<R: VecZnxToMut, ZNXARI>(basek: usize, res: &mut R, res_col: usize, carry: &mut [i64])
|
||||
where
|
||||
ZNXARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeMiddleStepInplace + ZnxNormalizeFinalStepInplace,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(carry.len() >= res.n());
|
||||
}
|
||||
|
||||
let res_size: usize = res.size();
|
||||
|
||||
for j in (0..res_size).rev() {
|
||||
if j == res_size - 1 {
|
||||
ZNXARI::znx_normalize_first_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
|
||||
} else if j == 0 {
|
||||
ZNXARI::znx_normalize_final_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_normalize<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxNormalize<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_normalize::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxNormalize<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let basek: usize = 50;
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
// Fill a with random i64
|
||||
a.fill_uniform(50, &mut source);
|
||||
res.fill_uniform(50, &mut source);
|
||||
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_normalize(basek, &mut res, i, &a, i, scratch.borrow());
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_normalize_inplace<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxNormalizeInplace<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_normalize_inplace::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxNormalizeInplace<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let basek: usize = 50;
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
// Fill a with random i64
|
||||
a.fill_uniform(50, &mut source);
|
||||
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes());
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_normalize_inplace(basek, &mut a, i, scratch.borrow());
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
148
poulpy-hal/src/reference/vec_znx/rotate.rs
Normal file
@@ -0,0 +1,148 @@
|
||||
use std::hint::black_box;
|
||||
|
||||
use criterion::{BenchmarkId, Criterion};
|
||||
|
||||
use crate::{
|
||||
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes},
|
||||
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
reference::znx::{ZnxCopy, ZnxRotate, ZnxZero},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn vec_znx_rotate_inplace_tmp_bytes(n: usize) -> usize {
|
||||
n * size_of::<i64>()
|
||||
}
|
||||
|
||||
pub fn vec_znx_rotate<R, A, ZNXARI>(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
ZNXARI: ZnxRotate + ZnxZero,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), a.n())
|
||||
}
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
|
||||
let min_size: usize = res_size.min(a_size);
|
||||
|
||||
for j in 0..min_size {
|
||||
ZNXARI::znx_rotate(p, res.at_mut(res_col, j), a.at(a_col, j))
|
||||
}
|
||||
|
||||
for j in min_size..res_size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_rotate_inplace<R, ZNXARI>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
ZNXARI: ZnxRotate + ZnxCopy,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), tmp.len());
|
||||
}
|
||||
for j in 0..res.size() {
|
||||
ZNXARI::znx_rotate(p, tmp, res.at(res_col, j));
|
||||
ZNXARI::znx_copy(res.at_mut(res_col, j), tmp);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_rotate<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxRotate + ModuleNew<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_rotate::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxRotate + ModuleNew<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
// Fill a with random i64
|
||||
a.fill_uniform(50, &mut source);
|
||||
res.fill_uniform(50, &mut source);
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_rotate(-7, &mut res, i, &a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_rotate_inplace<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxRotateInplace<B> + VecZnxRotateInplaceTmpBytes + ModuleNew<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_rotate_inplace::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxRotateInplace<B> + ModuleNew<B> + VecZnxRotateInplaceTmpBytes,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
let mut scratch = ScratchOwned::alloc(module.vec_znx_rotate_inplace_tmp_bytes());
|
||||
|
||||
// Fill res with random values
|
||||
res.fill_uniform(50, &mut source);
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_rotate_inplace(-7, &mut res, i, scratch.borrow());
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
64
poulpy-hal/src/reference/vec_znx/sampling.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
use crate::{
|
||||
layouts::{VecZnx, VecZnxToMut, ZnxInfos, ZnxViewMut},
|
||||
reference::znx::{znx_add_normal_f64_ref, znx_fill_normal_f64_ref, znx_fill_uniform_ref},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn vec_znx_fill_uniform_ref<R>(basek: usize, res: &mut R, res_col: usize, source: &mut Source)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
for j in 0..res.size() {
|
||||
znx_fill_uniform_ref(basek, res.at_mut(res_col, j), source)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_fill_normal_ref<R>(
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
source: &mut Source,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
|
||||
znx_fill_normal_f64_ref(
|
||||
res.at_mut(res_col, limb),
|
||||
sigma * scale,
|
||||
bound * scale,
|
||||
source,
|
||||
)
|
||||
}
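// Worked example (added for illustration): the limb index and scaling factor
// computed above. With basek = 50 and k = 80, the noise is written into limb 1
// and scaled by 2^((limb + 1) * basek - k) = 2^20, which (as I read the code)
// places noise specified at precision k at the right position inside that limb.
#[test]
fn fill_normal_scale_example() {
    let (basek, k): (usize, usize) = (50, 80);
    let limb = k.div_ceil(basek) - 1;
    let scale = (1u64 << ((limb + 1) * basek - k)) as f64;
    assert_eq!(limb, 1);
    assert_eq!(scale, (1u64 << 20) as f64);
}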
|
||||
|
||||
pub fn vec_znx_add_normal_ref<R>(basek: usize, res: &mut R, res_col: usize, k: usize, sigma: f64, bound: f64, source: &mut Source)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
|
||||
znx_add_normal_f64_ref(
|
||||
res.at_mut(res_col, limb),
|
||||
sigma * scale,
|
||||
bound * scale,
|
||||
source,
|
||||
)
|
||||
}
|
||||
672
poulpy-hal/src/reference/vec_znx/shift.rs
Normal file
@@ -0,0 +1,672 @@
|
||||
use std::hint::black_box;
|
||||
|
||||
use criterion::{BenchmarkId, Criterion};
|
||||
|
||||
use crate::{
|
||||
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxLsh, VecZnxLshInplace, VecZnxRsh, VecZnxRshInplace},
|
||||
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
reference::{
|
||||
vec_znx::vec_znx_copy,
|
||||
znx::{
|
||||
ZnxCopy, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
|
||||
ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
|
||||
ZnxZero,
|
||||
},
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn vec_znx_lsh_tmp_bytes(n: usize) -> usize {
|
||||
n * size_of::<i64>()
|
||||
}
|
||||
|
||||
pub fn vec_znx_lsh_inplace<R, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
ZNXARI: ZnxZero
+ ZnxCopy
+ ZnxNormalizeFirstStepInplace
+ ZnxNormalizeMiddleStepInplace
+ ZnxNormalizeFinalStepInplace,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
let n: usize = res.n();
|
||||
let cols: usize = res.cols();
|
||||
let size: usize = res.size();
|
||||
let steps: usize = k / basek;
|
||||
let k_rem: usize = k % basek;
|
||||
|
||||
if steps >= size {
|
||||
for j in 0..size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// In-place shift of limbs by k/basek
|
||||
if steps > 0 {
|
||||
let start: usize = n * res_col;
|
||||
let end: usize = start + n;
|
||||
let slice_size: usize = n * cols;
|
||||
let res_raw: &mut [i64] = res.raw_mut();
|
||||
|
||||
(0..size - steps).for_each(|j| {
|
||||
let (lhs, rhs) = res_raw.split_at_mut(slice_size * (j + steps));
|
||||
ZNXARI::znx_copy(
|
||||
&mut lhs[start + j * slice_size..end + j * slice_size],
|
||||
&rhs[start..end],
|
||||
);
|
||||
});
|
||||
|
||||
for j in size - steps..size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
// Inplace normalization with left shift of k % basek
|
||||
if !k.is_multiple_of(basek) {
|
||||
for j in (0..size - steps).rev() {
|
||||
if j == size - steps - 1 {
|
||||
ZNXARI::znx_normalize_first_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
|
||||
} else if j == 0 {
|
||||
ZNXARI::znx_normalize_final_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
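// Worked example (added for illustration): how a left shift by k bits splits
// into the `steps`/`k_rem` decomposition used above, i.e. a whole-limb shift
// of `steps` limbs plus an intra-limb shift of `k_rem` bits.
#[test]
fn lsh_shift_decomposition_example() {
    let basek: usize = 50;
    let k: usize = 120;
    let steps = k / basek; // whole limbs: 2
    let k_rem = k % basek; // remaining bits: 20
    assert_eq!((steps, k_rem), (2, 20));
    assert_eq!(steps * basek + k_rem, k);
}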
|
||||
|
||||
pub fn vec_znx_lsh<R, A, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
ZNXARI: ZnxZero + ZnxNormalizeFirstStep + ZnxNormalizeMiddleStep + ZnxCopy + ZnxNormalizeFinalStep,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let a_size = a.size();
|
||||
let steps: usize = k / basek;
|
||||
let k_rem: usize = k % basek;
|
||||
|
||||
if steps >= res_size.min(a_size) {
|
||||
for j in 0..res_size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
let min_size: usize = a_size.min(res_size) - steps;
|
||||
|
||||
// Simply a left-shifted normalization: limbs are shifted by k/basek and the
// intra-limb shift of k % basek is folded into the normalization steps.
|
||||
if !k.is_multiple_of(basek) {
|
||||
for j in (0..min_size).rev() {
|
||||
if j == min_size - 1 {
|
||||
ZNXARI::znx_normalize_first_step(
|
||||
basek,
|
||||
k_rem,
|
||||
res.at_mut(res_col, j),
|
||||
a.at(a_col, j + steps),
|
||||
carry,
|
||||
);
|
||||
} else if j == 0 {
|
||||
ZNXARI::znx_normalize_final_step(
|
||||
basek,
|
||||
k_rem,
|
||||
res.at_mut(res_col, j),
|
||||
a.at(a_col, j + steps),
|
||||
carry,
|
||||
);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step(
|
||||
basek,
|
||||
k_rem,
|
||||
res.at_mut(res_col, j),
|
||||
a.at(a_col, j + steps),
|
||||
carry,
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If k % basek = 0, then this is simply a copy.
|
||||
for j in (0..min_size).rev() {
|
||||
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j + steps));
|
||||
}
|
||||
}
|
||||
|
||||
// Zeroes bottom
|
||||
for j in min_size..res_size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_rsh_tmp_bytes(n: usize) -> usize {
|
||||
n * size_of::<i64>()
|
||||
}
|
||||
|
||||
pub fn vec_znx_rsh_inplace<R, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
ZNXARI: ZnxZero
|
||||
+ ZnxCopy
|
||||
+ ZnxNormalizeFirstStepCarryOnly
|
||||
+ ZnxNormalizeMiddleStepCarryOnly
|
||||
+ ZnxNormalizeMiddleStep
|
||||
+ ZnxNormalizeMiddleStepInplace
|
||||
+ ZnxNormalizeFirstStepInplace
|
||||
+ ZnxNormalizeFinalStepInplace,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let n: usize = res.n();
|
||||
let cols: usize = res.cols();
|
||||
let size: usize = res.size();
|
||||
|
||||
let mut steps: usize = k / basek;
|
||||
let k_rem: usize = k % basek;
|
||||
|
||||
if k == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
if steps >= size {
|
||||
for j in 0..size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
let start: usize = n * res_col;
|
||||
let end: usize = start + n;
|
||||
let slice_size: usize = n * cols;
|
||||
|
||||
if !k.is_multiple_of(basek) {
|
||||
// We rsh by an additional basek and then lsh by basek - k_rem.
// This re-uses the efficient normalization code, avoids
// overflows, and produces an output that is normalized.
|
||||
steps += 1;
|
||||
|
||||
// All limbs that would fall outside of the limbs of res are discarded,
// but the carry still needs to be computed.
|
||||
(size - steps..size).rev().for_each(|j| {
|
||||
if j == size - 1 {
|
||||
ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry);
|
||||
}
|
||||
});
|
||||
|
||||
// Continues with shifted normalization
|
||||
let res_raw: &mut [i64] = res.raw_mut();
|
||||
(steps..size).rev().for_each(|j| {
|
||||
let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
|
||||
let rhs_slice: &mut [i64] = &mut rhs[start..end];
|
||||
let lhs_slice: &[i64] = &lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end];
|
||||
ZNXARI::znx_normalize_middle_step(basek, basek - k_rem, rhs_slice, lhs_slice, carry);
|
||||
});
|
||||
|
||||
// Propagates carry on the rest of the limbs of res
|
||||
for j in (0..steps).rev() {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
if j == 0 {
|
||||
ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Shift by multiples of basek
|
||||
let res_raw: &mut [i64] = res.raw_mut();
|
||||
(steps..size).rev().for_each(|j| {
|
||||
let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
|
||||
ZNXARI::znx_copy(
|
||||
&mut rhs[start..end],
|
||||
&lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end],
|
||||
);
|
||||
});
|
||||
|
||||
// Zeroes the top
|
||||
(0..steps).for_each(|j| {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
});
|
||||
}
|
||||
}
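// Worked example (added for illustration): when k is not a multiple of basek,
// the code above shifts by one extra limb and then normalizes with a left
// shift of `basek - k_rem`, which nets out to a right shift of exactly k bits.
#[test]
fn rsh_extra_limb_decomposition_example() {
    let basek: usize = 50;
    let k: usize = 70;
    let k_rem = k % basek; // 20
    let steps = k / basek + 1; // one extra limb: 2
    // steps * basek bits right, then (basek - k_rem) bits left:
    assert_eq!(steps * basek - (basek - k_rem), k);
}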
|
||||
|
||||
pub fn vec_znx_rsh<R, A, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
ZNXARI: ZnxZero
|
||||
+ ZnxCopy
|
||||
+ ZnxNormalizeFirstStepCarryOnly
|
||||
+ ZnxNormalizeMiddleStepCarryOnly
|
||||
+ ZnxNormalizeFirstStep
|
||||
+ ZnxNormalizeMiddleStep
|
||||
+ ZnxNormalizeMiddleStepInplace
|
||||
+ ZnxNormalizeFirstStepInplace
|
||||
+ ZnxNormalizeFinalStepInplace,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
|
||||
let mut steps: usize = k / basek;
|
||||
let k_rem: usize = k % basek;
|
||||
|
||||
if k == 0 {
|
||||
vec_znx_copy::<_, _, ZNXARI>(&mut res, res_col, &a, a_col);
|
||||
return;
|
||||
}
|
||||
|
||||
if steps >= res_size {
|
||||
for j in 0..res_size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if !k.is_multiple_of(basek) {
|
||||
// We rsh by an additional basek and then lsh by basek - k_rem.
// This re-uses the efficient normalization code, avoids
// overflows, and produces an output that is normalized.
|
||||
steps += 1;
|
||||
|
||||
// All limbs of a that are moved outside of the limbs of res are discarded,
// but the carry still needs to be computed.
|
||||
for j in (res_size..a_size + steps).rev() {
|
||||
if j == a_size + steps - 1 {
|
||||
ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry);
|
||||
}
|
||||
}
|
||||
|
||||
// Avoids overflow of the limbs of res
|
||||
let min_size: usize = res_size.min(a_size + steps);
|
||||
|
||||
// Zeroes lower limbs of res if a_size + steps < res_size
|
||||
(min_size..res_size).for_each(|j| {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
});
|
||||
|
||||
// Continues with shifted normalization
|
||||
for j in (steps..min_size).rev() {
|
||||
// Case if no limb of a was previously discarded
|
||||
if res_size.saturating_sub(steps) >= a_size && j == min_size - 1 {
|
||||
ZNXARI::znx_normalize_first_step(
|
||||
basek,
|
||||
basek - k_rem,
|
||||
res.at_mut(res_col, j),
|
||||
a.at(a_col, j - steps),
|
||||
carry,
|
||||
);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step(
|
||||
basek,
|
||||
basek - k_rem,
|
||||
res.at_mut(res_col, j),
|
||||
a.at(a_col, j - steps),
|
||||
carry,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Propagates carry on the rest of the limbs of res
|
||||
for j in (0..steps).rev() {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
if j == 0 {
|
||||
ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let min_size: usize = res_size.min(a_size + steps);
|
||||
|
||||
// Zeroes the top
|
||||
(0..steps).for_each(|j| {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
});
|
||||
|
||||
// Shift a into res, up to the maximum
|
||||
for j in (steps..min_size).rev() {
|
||||
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j - steps));
|
||||
}
|
||||
|
||||
// Zeroes bottom if a_size + steps < res_size
|
||||
(min_size..res_size).for_each(|j| {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_lsh_inplace<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: ModuleNew<B> + VecZnxLshInplace<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_lsh_inplace::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxLshInplace<B> + ModuleNew<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let basek: usize = 50;
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());
|
||||
|
||||
// Fill a with random i64
|
||||
a.fill_uniform(50, &mut source);
|
||||
b.fill_uniform(50, &mut source);
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_lsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow());
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_lsh<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxLsh<B> + ModuleNew<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_lsh::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxLsh<B> + ModuleNew<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let basek: usize = 50;
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());
|
||||
|
||||
// Fill a with random i64
|
||||
a.fill_uniform(50, &mut source);
|
||||
res.fill_uniform(50, &mut source);
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_lsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow());
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_rsh_inplace<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxRshInplace<B> + ModuleNew<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_rsh_inplace::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxRshInplace<B> + ModuleNew<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let basek: usize = 50;
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());
|
||||
|
||||
// Fill a with random i64
|
||||
a.fill_uniform(50, &mut source);
|
||||
b.fill_uniform(50, &mut source);
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_rsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow());
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_rsh<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxRsh<B> + ModuleNew<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_rsh::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxRsh<B> + ModuleNew<B>,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let basek: usize = 50;
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());
|
||||
|
||||
// Fill a with random i64
|
||||
a.fill_uniform(50, &mut source);
|
||||
res.fill_uniform(50, &mut source);
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_rsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow());
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
layouts::{FillUniform, VecZnx, ZnxView},
|
||||
reference::{
|
||||
vec_znx::{
|
||||
vec_znx_copy, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_normalize_inplace, vec_znx_rsh, vec_znx_rsh_inplace,
|
||||
vec_znx_sub_ab_inplace,
|
||||
},
|
||||
znx::ZnxRef,
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_vec_znx_lsh() {
|
||||
let n: usize = 8;
|
||||
let cols: usize = 2;
|
||||
let size: usize = 7;
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut carry: Vec<i64> = vec![0i64; n];
|
||||
|
||||
let basek: usize = 50;
|
||||
|
||||
for k in 0..256 {
|
||||
a.fill_uniform(50, &mut source);
|
||||
|
||||
for i in 0..cols {
|
||||
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry);
|
||||
vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i);
|
||||
}
|
||||
|
||||
for i in 0..cols {
|
||||
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, i, &mut carry);
|
||||
vec_znx_lsh::<_, _, ZnxRef>(basek, k, &mut res_test, i, &a, i, &mut carry);
|
||||
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, i, &mut carry);
|
||||
}
|
||||
|
||||
assert_eq!(res_ref, res_test);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vec_znx_rsh() {
|
||||
let n: usize = 8;
|
||||
let cols: usize = 2;
|
||||
|
||||
let res_size: usize = 7;
|
||||
|
||||
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
|
||||
let mut carry: Vec<i64> = vec![0i64; n];
|
||||
|
||||
let basek: usize = 50;
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let zero: Vec<i64> = vec![0i64; n];
|
||||
|
||||
for a_size in [res_size - 1, res_size, res_size + 1] {
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
|
||||
|
||||
for k in 0..res_size * basek {
|
||||
a.fill_uniform(50, &mut source);
|
||||
|
||||
for i in 0..cols {
|
||||
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry);
|
||||
vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i);
|
||||
}
|
||||
|
||||
res_test.fill_uniform(50, &mut source);
|
||||
|
||||
for j in 0..cols {
|
||||
vec_znx_rsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry);
|
||||
vec_znx_rsh::<_, _, ZnxRef>(basek, k, &mut res_test, j, &a, j, &mut carry);
|
||||
}
|
||||
|
||||
for j in 0..cols {
|
||||
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry);
|
||||
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_test, j, &mut carry);
|
||||
}
|
||||
|
||||
// Case where res has enough limbs to fully store a right-shifted a without any loss.
// In this case we can check exact equality.
|
||||
if a_size + k.div_ceil(basek) <= res_size {
|
||||
assert_eq!(res_ref, res_test);
|
||||
|
||||
for i in 0..cols {
|
||||
for j in 0..a_size {
|
||||
assert_eq!(res_ref.at(i, j), a.at(i, j), "r0 {} {}", i, j);
|
||||
assert_eq!(res_test.at(i, j), a.at(i, j), "r1 {} {}", i, j);
|
||||
}
|
||||
|
||||
for j in a_size..res_size {
|
||||
assert_eq!(res_ref.at(i, j), zero, "r0 {} {}", i, j);
|
||||
assert_eq!(res_test.at(i, j), zero, "r1 {} {}", i, j);
|
||||
}
|
||||
}
|
||||
// Some loss occurs, either because a initially has more precision than res,
// or because storing the right shift of a requires more precision than
// res provides.
|
||||
} else {
|
||||
for j in 0..cols {
|
||||
vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_ref, j, &a, j);
|
||||
vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_test, j, &a, j);
|
||||
|
||||
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_ref, j, &mut carry);
|
||||
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, j, &mut carry);
|
||||
|
||||
assert!(res_ref.std(basek, j).log2() - (k as f64) <= (k * basek) as f64);
|
||||
assert!(res_test.std(basek, j).log2() - (k as f64) <= (k * basek) as f64);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
62
poulpy-hal/src/reference/vec_znx/split_ring.rs
Normal file
@@ -0,0 +1,62 @@
|
||||
use crate::{
|
||||
layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
reference::znx::{ZnxRotate, ZnxSwitchRing, ZnxZero},
|
||||
};
|
||||
|
||||
pub fn vec_znx_split_ring_tmp_bytes(n: usize) -> usize {
|
||||
n * size_of::<i64>()
|
||||
}
|
||||
|
||||
pub fn vec_znx_split_ring<R, A, ZNXARI>(res: &mut [R], res_col: usize, a: &A, a_col: usize, tmp: &mut [i64])
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
ZNXARI: ZnxSwitchRing + ZnxRotate + ZnxZero,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let a_size = a.size();
|
||||
|
||||
let (n_in, n_out) = (a.n(), res[0].to_mut().n());
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(tmp.len(), a.n());
|
||||
|
||||
assert!(
|
||||
n_out < n_in,
|
||||
"invalid a: output ring degree should be smaller"
|
||||
);
|
||||
|
||||
res[1..].iter_mut().for_each(|bi| {
|
||||
assert_eq!(
|
||||
bi.to_mut().n(),
|
||||
n_out,
|
||||
"invalid input a: all VecZnx must have the same degree"
|
||||
)
|
||||
});
|
||||
|
||||
assert!(n_in.is_multiple_of(n_out));
|
||||
assert_eq!(res.len(), n_in / n_out);
|
||||
}
|
||||
|
||||
res.iter_mut().enumerate().for_each(|(i, bi)| {
|
||||
let mut bi: VecZnx<&mut [u8]> = bi.to_mut();
|
||||
|
||||
let min_size = bi.size().min(a_size);
|
||||
|
||||
if i == 0 {
|
||||
for j in 0..min_size {
|
||||
ZNXARI::znx_switch_ring(bi.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
} else {
|
||||
for j in 0..min_size {
|
||||
ZNXARI::znx_rotate(-(i as i64), tmp, a.at(a_col, j));
|
||||
ZNXARI::znx_switch_ring(bi.at_mut(res_col, j), tmp);
|
||||
}
|
||||
}
|
||||
|
||||
for j in min_size..bi.size() {
|
||||
ZNXARI::znx_zero(bi.at_mut(res_col, j));
|
||||
}
|
||||
})
|
||||
}
|
||||
250
poulpy-hal/src/reference/vec_znx/sub.rs
Normal file
@@ -0,0 +1,250 @@
|
||||
use std::hint::black_box;
|
||||
|
||||
use criterion::{BenchmarkId, Criterion};
|
||||
|
||||
use crate::{
|
||||
api::{ModuleNew, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace},
|
||||
layouts::{Backend, FillUniform, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
oep::{ModuleNewImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl},
|
||||
reference::znx::{ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn vec_znx_sub<R, A, B, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxToRef,
|
||||
ZNXARI: ZnxSub + ZnxNegate + ZnxZero + ZnxCopy,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
assert_eq!(b.n(), res.n());
|
||||
}
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
let b_size: usize = b.size();
|
||||
|
||||
if a_size <= b_size {
|
||||
let sum_size: usize = a_size.min(res_size);
|
||||
let cpy_size: usize = b_size.min(res_size);
|
||||
|
||||
for j in 0..sum_size {
|
||||
ZNXARI::znx_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
|
||||
}
|
||||
|
||||
for j in sum_size..cpy_size {
|
||||
ZNXARI::znx_negate(res.at_mut(res_col, j), b.at(b_col, j));
|
||||
}
|
||||
|
||||
for j in cpy_size..res_size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
} else {
|
||||
let sum_size: usize = b_size.min(res_size);
|
||||
let cpy_size: usize = a_size.min(res_size);
|
||||
|
||||
for j in 0..sum_size {
|
||||
ZNXARI::znx_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
|
||||
}
|
||||
|
||||
for j in sum_size..cpy_size {
|
||||
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
|
||||
for j in cpy_size..res_size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
}
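// Worked example (added for illustration): how the three loops above partition
// the output limbs when operand sizes differ. With res_size = 5, a_size = 2
// and b_size = 4 (the a_size <= b_size branch): limbs 0..2 hold a - b,
// limbs 2..4 hold -b, and limbs 4..5 are zeroed.
#[test]
fn sub_limb_partition_example() {
    let (res_size, a_size, b_size): (usize, usize, usize) = (5, 2, 4);
    let sum_size = a_size.min(res_size);
    let cpy_size = b_size.min(res_size);
    assert_eq!((sum_size, cpy_size, res_size), (2, 4, 5));
}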
|
||||
|
||||
pub fn vec_znx_sub_ab_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
ZNXARI: ZnxSubABInplace,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
|
||||
let sum_size: usize = a_size.min(res_size);
|
||||
|
||||
for j in 0..sum_size {
|
||||
ZNXARI::znx_sub_ab_inplace(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_sub_ba_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
ZNXARI: ZnxSubBAInplace + ZnxNegateInplace,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
|
||||
let sum_size: usize = a_size.min(res_size);
|
||||
|
||||
for j in 0..sum_size {
|
||||
ZNXARI::znx_sub_ba_inplace(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
|
||||
for j in sum_size..res_size {
|
||||
ZNXARI::znx_negate_inplace(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_sub<B>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
B: Backend + ModuleNewImpl<B> + VecZnxSubImpl<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_sub::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxSub + ModuleNew<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut c: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
// Fill a with random i64
|
||||
a.fill_uniform(50, &mut source);
|
||||
b.fill_uniform(50, &mut source);
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_sub(&mut c, i, &a, i, &b, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_sub_ab_inplace<B>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
B: Backend + ModuleNewImpl<B> + VecZnxSubABInplaceImpl<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_sub_ab_inplace::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxSubABInplace + ModuleNew<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
// Fill a with random i64
|
||||
a.fill_uniform(50, &mut source);
|
||||
b.fill_uniform(50, &mut source);
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_sub_ab_inplace(&mut b, i, &a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_vec_znx_sub_ba_inplace<B>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
B: Backend + ModuleNewImpl<B> + VecZnxSubBAInplaceImpl<B>,
|
||||
{
|
||||
let group_name: String = format!("vec_znx_sub_ba_inplace::{}", label);
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
|
||||
where
|
||||
Module<B>: VecZnxSubBAInplace + ModuleNew<B>,
|
||||
{
|
||||
let n: usize = 1 << params[0];
|
||||
let cols: usize = params[1];
|
||||
let size: usize = params[2];
|
||||
|
||||
let module: Module<B> = Module::<B>::new(n as u64);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
// Fill a with random i64
|
||||
a.fill_uniform(50, &mut source);
|
||||
b.fill_uniform(50, &mut source);
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_sub_ba_inplace(&mut b, i, &a, i);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
|
||||
let mut runner = runner::<B>(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
58
poulpy-hal/src/reference/vec_znx/sub_scalar.rs
Normal file
@@ -0,0 +1,58 @@
|
||||
use crate::layouts::{ScalarZnxToRef, VecZnxToMut, VecZnxToRef};
|
||||
use crate::{
|
||||
layouts::{ScalarZnx, VecZnx, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
reference::znx::{ZnxSub, ZnxSubABInplace, ZnxZero},
|
||||
};
|
||||
|
||||
pub fn vec_znx_sub_scalar<R, A, B, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize, b_limb: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
B: VecZnxToRef,
|
||||
ZNXARI: ZnxSub + ZnxZero,
|
||||
{
|
||||
let a: ScalarZnx<&[u8]> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
let min_size: usize = b.size().min(res.size());
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
b_limb < min_size,
|
||||
"b_limb: {} > min_size: {}",
|
||||
b_limb,
|
||||
min_size
|
||||
);
|
||||
}
|
||||
|
||||
for j in 0..min_size {
|
||||
if j == b_limb {
|
||||
ZNXARI::znx_sub(res.at_mut(res_col, j), b.at(b_col, j), a.at(a_col, 0));
|
||||
} else {
|
||||
res.at_mut(res_col, j).copy_from_slice(b.at(b_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
for j in min_size..res.size() {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_sub_scalar_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
ZNXARI: ZnxSubABInplace,
|
||||
{
|
||||
let a: ScalarZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(res_limb < res.size());
|
||||
}
|
||||
|
||||
ZNXARI::znx_sub_ab_inplace(res.at_mut(res_col, res_limb), a.at(a_col, 0));
|
||||
}
|
||||
37
poulpy-hal/src/reference/vec_znx/switch_ring.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use crate::{
|
||||
layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
reference::{
|
||||
vec_znx::vec_znx_copy,
|
||||
znx::{ZnxCopy, ZnxSwitchRing, ZnxZero},
|
||||
},
|
||||
};
|
||||
|
||||
/// Maps between negacyclic rings by changing the polynomial degree.
|
||||
/// Up: Z[X]/(X^N+1) -> Z[X]/(X^{2^d N}+1) via X ↦ X^{2^d}
|
||||
/// Down: Z[X]/(X^N+1) -> Z[X]/(X^{N/2^d}+1) by folding indices.
|
||||
pub fn vec_znx_switch_ring<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
ZNXARI: ZnxCopy + ZnxSwitchRing + ZnxZero,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
let (n_in, n_out) = (a.n(), res.n());
|
||||
|
||||
if n_in == n_out {
|
||||
vec_znx_copy::<_, _, ZNXARI>(&mut res, res_col, &a, a_col);
|
||||
return;
|
||||
}
|
||||
|
||||
let min_size: usize = a.size().min(res.size());
|
||||
|
||||
for j in 0..min_size {
|
||||
ZNXARI::znx_switch_ring(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
|
||||
for j in min_size..res.size() {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
}
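// Illustrative sketch of the "up" map documented above (X ↦ X^{2^d}):
// coefficient i of the small ring lands at index i * 2^d in the large ring.
// This is an index-map sketch only, not the library routine; the "down"
// (folding) direction is not shown.
fn switch_ring_up_sketch(small: &[i64], large: &mut [i64]) {
    assert!(large.len() % small.len() == 0);
    let gap = large.len() / small.len(); // 2^d
    large.fill(0);
    for (i, &c) in small.iter().enumerate() {
        large[i * gap] = c;
    }
}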
|
||||
5
poulpy-hal/src/reference/zn/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
mod normalization;
|
||||
mod sampling;
|
||||
|
||||
pub use normalization::*;
|
||||
pub use sampling::*;
|
||||
72
poulpy-hal/src/reference/zn/normalization.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
use crate::{
|
||||
api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace, ZnNormalizeTmpBytes},
|
||||
layouts::{Backend, Module, ScratchOwned, Zn, ZnToMut, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
reference::znx::{ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStepInplace, ZnxRef},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn zn_normalize_tmp_bytes(n: usize) -> usize {
|
||||
n * size_of::<i64>()
|
||||
}
|
||||
|
||||
pub fn zn_normalize_inplace<R, ARI>(n: usize, basek: usize, res: &mut R, res_col: usize, carry: &mut [i64])
|
||||
where
|
||||
R: ZnToMut,
|
||||
ARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeFinalStepInplace + ZnxNormalizeMiddleStepInplace,
|
||||
{
|
||||
let mut res: Zn<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(carry.len(), res.n());
|
||||
}
|
||||
|
||||
let res_size: usize = res.size();
|
||||
|
||||
for j in (0..res_size).rev() {
|
||||
let out = &mut res.at_mut(res_col, j)[..n];
|
||||
|
||||
if j == res_size - 1 {
|
||||
ARI::znx_normalize_first_step_inplace(basek, 0, out, carry);
|
||||
} else if j == 0 {
|
||||
ARI::znx_normalize_final_step_inplace(basek, 0, out, carry);
|
||||
} else {
|
||||
ARI::znx_normalize_middle_step_inplace(basek, 0, out, carry);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_zn_normalize_inplace<B: Backend>(module: &Module<B>)
|
||||
where
|
||||
Module<B>: ZnNormalizeInplace<B> + ZnNormalizeTmpBytes,
|
||||
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
|
||||
{
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
let basek: usize = 12;
|
||||
|
||||
let n = 33;
|
||||
|
||||
let mut carry: Vec<i64> = vec![0i64; n];
|
||||
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.zn_normalize_tmp_bytes(module.n()));
|
||||
|
||||
for res_size in [1, 2, 6, 11] {
|
||||
let mut res_0: Zn<Vec<u8>> = Zn::alloc(n, cols, res_size);
|
||||
let mut res_1: Zn<Vec<u8>> = Zn::alloc(n, cols, res_size);
|
||||
|
||||
res_0
|
||||
.raw_mut()
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = source.next_i32() as i64);
|
||||
res_1.raw_mut().copy_from_slice(res_0.raw());
|
||||
|
||||
// Reference
|
||||
for i in 0..cols {
|
||||
zn_normalize_inplace::<_, ZnxRef>(n, basek, &mut res_0, i, &mut carry);
|
||||
module.zn_normalize_inplace(n, basek, &mut res_1, i, scratch.borrow());
|
||||
}
|
||||
|
||||
assert_eq!(res_0.raw(), res_1.raw());
|
||||
}
|
||||
}
|
||||
75 poulpy-hal/src/reference/zn/sampling.rs Normal file
@@ -0,0 +1,75 @@
|
||||
use crate::{
|
||||
layouts::{Zn, ZnToMut, ZnxInfos, ZnxViewMut},
|
||||
reference::znx::{znx_add_normal_f64_ref, znx_fill_normal_f64_ref, znx_fill_uniform_ref},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn zn_fill_uniform<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
|
||||
where
|
||||
R: ZnToMut,
|
||||
{
|
||||
let mut res: Zn<&mut [u8]> = res.to_mut();
|
||||
for j in 0..res.size() {
|
||||
znx_fill_uniform_ref(basek, &mut res.at_mut(res_col, j)[..n], source)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn zn_fill_normal<R>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
let mut res: Zn<&mut [u8]> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
|
||||
znx_fill_normal_f64_ref(
|
||||
&mut res.at_mut(res_col, limb)[..n],
|
||||
sigma * scale,
|
||||
bound * scale,
|
||||
source,
|
||||
)
|
||||
}
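As a worked check of the limb/scale computation above: with basek = 12 and k = 30, the noise is written to limb ceil(30/12) - 1 = 2 and pre-scaled by 2^(3*12 - 30) = 64, so its standard deviation is sigma * 2^-k relative to the fixed-point value (a sketch of the arithmetic only):

let (basek, k): (usize, usize) = (12, 30);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1u64 << ((limb + 1) * basek - k)) as f64;
assert_eq!((limb, scale), (2, 64.0));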
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn zn_add_normal<R>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
let mut res: Zn<&mut [u8]> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
|
||||
znx_add_normal_f64_ref(
|
||||
&mut res.at_mut(res_col, limb)[..n],
|
||||
sigma * scale,
|
||||
bound * scale,
|
||||
source,
|
||||
)
|
||||
}
|
||||
25 poulpy-hal/src/reference/znx/add.rs Normal file
@@ -0,0 +1,25 @@
|
||||
#[inline(always)]
|
||||
pub fn znx_add_ref(res: &mut [i64], a: &[i64], b: &[i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.len(), a.len());
|
||||
assert_eq!(res.len(), b.len());
|
||||
}
|
||||
|
||||
let n: usize = res.len();
|
||||
for i in 0..n {
|
||||
res[i] = a[i] + b[i];
|
||||
}
|
||||
}
|
||||
|
||||
pub fn znx_add_inplace_ref(res: &mut [i64], a: &[i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.len(), a.len());
|
||||
}
|
||||
|
||||
let n: usize = res.len();
|
||||
for i in 0..n {
|
||||
res[i] += a[i];
|
||||
}
|
||||
}
|
||||
153 poulpy-hal/src/reference/znx/arithmetic_ref.rs Normal file
@@ -0,0 +1,153 @@
|
||||
use crate::reference::znx::{
|
||||
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
|
||||
ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace,
|
||||
ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxSub, ZnxSubABInplace,
|
||||
ZnxSubBAInplace, ZnxSwitchRing, ZnxZero,
|
||||
add::{znx_add_inplace_ref, znx_add_ref},
|
||||
automorphism::znx_automorphism_ref,
|
||||
copy::znx_copy_ref,
|
||||
neg::{znx_negate_inplace_ref, znx_negate_ref},
|
||||
normalization::{
|
||||
znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref, znx_normalize_first_step_carry_only_ref,
|
||||
znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref, znx_normalize_middle_step_carry_only_ref,
|
||||
znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref,
|
||||
},
|
||||
sub::{znx_sub_ab_inplace_ref, znx_sub_ba_inplace_ref, znx_sub_ref},
|
||||
switch_ring::znx_switch_ring_ref,
|
||||
zero::znx_zero_ref,
|
||||
};
|
||||
|
||||
pub struct ZnxRef {}
|
||||
|
||||
impl ZnxAdd for ZnxRef {
|
||||
#[inline(always)]
|
||||
fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]) {
|
||||
znx_add_ref(res, a, b);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxAddInplace for ZnxRef {
|
||||
#[inline(always)]
|
||||
fn znx_add_inplace(res: &mut [i64], a: &[i64]) {
|
||||
znx_add_inplace_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxSub for ZnxRef {
|
||||
#[inline(always)]
|
||||
fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]) {
|
||||
znx_sub_ref(res, a, b);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxSubABInplace for ZnxRef {
|
||||
#[inline(always)]
|
||||
fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]) {
|
||||
znx_sub_ab_inplace_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxSubBAInplace for ZnxRef {
|
||||
#[inline(always)]
|
||||
fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]) {
|
||||
znx_sub_ba_inplace_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxAutomorphism for ZnxRef {
|
||||
#[inline(always)]
|
||||
fn znx_automorphism(p: i64, res: &mut [i64], a: &[i64]) {
|
||||
znx_automorphism_ref(p, res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxCopy for ZnxRef {
|
||||
#[inline(always)]
|
||||
fn znx_copy(res: &mut [i64], a: &[i64]) {
|
||||
znx_copy_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNegate for ZnxRef {
|
||||
#[inline(always)]
|
||||
fn znx_negate(res: &mut [i64], src: &[i64]) {
|
||||
znx_negate_ref(res, src);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNegateInplace for ZnxRef {
|
||||
#[inline(always)]
|
||||
fn znx_negate_inplace(res: &mut [i64]) {
|
||||
znx_negate_inplace_ref(res);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxZero for ZnxRef {
|
||||
#[inline(always)]
|
||||
fn znx_zero(res: &mut [i64]) {
|
||||
znx_zero_ref(res);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxSwitchRing for ZnxRef {
|
||||
#[inline(always)]
|
||||
fn znx_switch_ring(res: &mut [i64], a: &[i64]) {
|
||||
znx_switch_ring_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeFinalStep for ZnxRef {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
znx_normalize_final_step_ref(basek, lsh, x, a, carry);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeFinalStepInplace for ZnxRef {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
znx_normalize_final_step_inplace_ref(basek, lsh, x, carry);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeFirstStep for ZnxRef {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
znx_normalize_first_step_ref(basek, lsh, x, a, carry);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeFirstStepCarryOnly for ZnxRef {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
|
||||
znx_normalize_first_step_carry_only_ref(basek, lsh, x, carry);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeFirstStepInplace for ZnxRef {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
znx_normalize_first_step_inplace_ref(basek, lsh, x, carry);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeMiddleStep for ZnxRef {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
znx_normalize_middle_step_ref(basek, lsh, x, a, carry);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeMiddleStepCarryOnly for ZnxRef {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
|
||||
znx_normalize_middle_step_carry_only_ref(basek, lsh, x, carry);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeMiddleStepInplace for ZnxRef {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
znx_normalize_middle_step_inplace_ref(basek, lsh, x, carry);
|
||||
}
|
||||
}
|
||||
21 poulpy-hal/src/reference/znx/automorphism.rs Normal file
@@ -0,0 +1,21 @@
|
||||
pub fn znx_automorphism_ref(p: i64, res: &mut [i64], a: &[i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.len(), a.len());
|
||||
}
|
||||
|
||||
let n: usize = res.len();
|
||||
let mut k: usize = 0usize;
|
||||
let mask: usize = 2 * n - 1;
|
||||
let p_2n = (p & mask as i64) as usize;
|
||||
|
||||
res[0] = a[0];
|
||||
for ai in a.iter().take(n).skip(1) {
|
||||
k = (k + p_2n) & mask;
|
||||
if k < n {
|
||||
res[k] = *ai
|
||||
} else {
|
||||
res[k - n] = -*ai
|
||||
}
|
||||
}
|
||||
}
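A quick sanity check of the index arithmetic: for the automorphism X ↦ X^p with n = 4 and p = 3, reducing modulo X^4 + 1 gives a0 + a3*X - a2*X^2 + a1*X^3 (a sketch using the function defined above):

let a: [i64; 4] = [10, 11, 12, 13];
let mut res: [i64; 4] = [0; 4];
znx_automorphism_ref(3, &mut res, &a);
assert_eq!(res, [10, 13, -12, 11]);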
|
||||
8 poulpy-hal/src/reference/znx/copy.rs Normal file
@@ -0,0 +1,8 @@
|
||||
#[inline(always)]
|
||||
pub fn znx_copy_ref(res: &mut [i64], a: &[i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.len(), a.len())
|
||||
}
|
||||
res.copy_from_slice(a);
|
||||
}
|
||||
104 poulpy-hal/src/reference/znx/mod.rs Normal file
@@ -0,0 +1,104 @@
|
||||
mod add;
|
||||
mod arithmetic_ref;
|
||||
mod automorphism;
|
||||
mod copy;
|
||||
mod neg;
|
||||
mod normalization;
|
||||
mod rotate;
|
||||
mod sampling;
|
||||
mod sub;
|
||||
mod switch_ring;
|
||||
mod zero;
|
||||
|
||||
pub use add::*;
|
||||
pub use arithmetic_ref::*;
|
||||
pub use automorphism::*;
|
||||
pub use copy::*;
|
||||
pub use neg::*;
|
||||
pub use normalization::*;
|
||||
pub use rotate::*;
|
||||
pub use sub::*;
|
||||
pub use switch_ring::*;
|
||||
pub use zero::*;
|
||||
|
||||
pub use sampling::*;
|
||||
|
||||
pub trait ZnxAdd {
|
||||
fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]);
|
||||
}
|
||||
|
||||
pub trait ZnxAddInplace {
|
||||
fn znx_add_inplace(res: &mut [i64], a: &[i64]);
|
||||
}
|
||||
|
||||
pub trait ZnxSub {
|
||||
fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]);
|
||||
}
|
||||
|
||||
pub trait ZnxSubABInplace {
|
||||
fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]);
|
||||
}
|
||||
|
||||
pub trait ZnxSubBAInplace {
|
||||
fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]);
|
||||
}
|
||||
|
||||
pub trait ZnxAutomorphism {
|
||||
fn znx_automorphism(p: i64, res: &mut [i64], a: &[i64]);
|
||||
}
|
||||
|
||||
pub trait ZnxCopy {
|
||||
fn znx_copy(res: &mut [i64], a: &[i64]);
|
||||
}
|
||||
|
||||
pub trait ZnxNegate {
|
||||
fn znx_negate(res: &mut [i64], src: &[i64]);
|
||||
}
|
||||
|
||||
pub trait ZnxNegateInplace {
|
||||
fn znx_negate_inplace(res: &mut [i64]);
|
||||
}
|
||||
|
||||
pub trait ZnxRotate {
|
||||
fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]);
|
||||
}
|
||||
|
||||
pub trait ZnxZero {
|
||||
fn znx_zero(res: &mut [i64]);
|
||||
}
|
||||
|
||||
pub trait ZnxSwitchRing {
|
||||
fn znx_switch_ring(res: &mut [i64], a: &[i64]);
|
||||
}
|
||||
|
||||
pub trait ZnxNormalizeFirstStepCarryOnly {
|
||||
fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]);
|
||||
}
|
||||
|
||||
pub trait ZnxNormalizeFirstStepInplace {
|
||||
fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
|
||||
}
|
||||
|
||||
pub trait ZnxNormalizeFirstStep {
|
||||
fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
|
||||
}
|
||||
|
||||
pub trait ZnxNormalizeMiddleStepCarryOnly {
|
||||
fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]);
|
||||
}
|
||||
|
||||
pub trait ZnxNormalizeMiddleStepInplace {
|
||||
fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
|
||||
}
|
||||
|
||||
pub trait ZnxNormalizeMiddleStep {
|
||||
fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
|
||||
}
|
||||
|
||||
pub trait ZnxNormalizeFinalStepInplace {
|
||||
fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
|
||||
}
|
||||
|
||||
pub trait ZnxNormalizeFinalStep {
|
||||
fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
|
||||
}
|
||||
18 poulpy-hal/src/reference/znx/neg.rs Normal file
@@ -0,0 +1,18 @@
|
||||
#[inline(always)]
|
||||
pub fn znx_negate_ref(res: &mut [i64], src: &[i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.len(), src.len())
|
||||
}
|
||||
|
||||
for i in 0..res.len() {
|
||||
res[i] = -src[i]
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn znx_negate_inplace_ref(res: &mut [i64]) {
|
||||
for value in res {
|
||||
*value = -*value
|
||||
}
|
||||
}
|
||||
199 poulpy-hal/src/reference/znx/normalization.rs Normal file
@@ -0,0 +1,199 @@
|
||||
use itertools::izip;
|
||||
|
||||
#[inline(always)]
|
||||
pub fn get_digit(basek: usize, x: i64) -> i64 {
|
||||
(x << (u64::BITS - basek as u32)) >> (u64::BITS - basek as u32)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn get_carry(basek: usize, x: i64, digit: i64) -> i64 {
|
||||
(x.wrapping_sub(digit)) >> basek
|
||||
}
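These two helpers split x into a centered base-2^basek digit and the carry that propagates to the next limb, so that x == carry * 2^basek + digit with digit in [-2^(basek-1), 2^(basek-1)); a small worked check:

let basek: usize = 4;
for x in [23i64, 12, -5] {
    let digit: i64 = get_digit(basek, x);
    let carry: i64 = get_carry(basek, x, digit);
    assert_eq!(x, (carry << basek) + digit);
    assert!((-8..8).contains(&digit));
}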
|
||||
|
||||
#[inline(always)]
|
||||
pub fn znx_normalize_first_step_carry_only_ref(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < basek);
|
||||
}
|
||||
|
||||
if lsh == 0 {
|
||||
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
*c = get_carry(basek, *x, get_digit(basek, *x));
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = basek - lsh;
|
||||
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
*c = get_carry(basek_lsh, *x, get_digit(basek_lsh, *x));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn znx_normalize_first_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < basek);
|
||||
}
|
||||
|
||||
if lsh == 0 {
|
||||
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
let digit: i64 = get_digit(basek, *x);
|
||||
*c = get_carry(basek, *x, digit);
|
||||
*x = digit;
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = basek - lsh;
|
||||
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
let digit: i64 = get_digit(basek_lsh, *x);
|
||||
*c = get_carry(basek_lsh, *x, digit);
|
||||
*x = digit << lsh;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn znx_normalize_first_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(x.len(), a.len());
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < basek);
|
||||
}
|
||||
|
||||
if lsh == 0 {
|
||||
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
|
||||
let digit: i64 = get_digit(basek, *a);
|
||||
*c = get_carry(basek, *a, digit);
|
||||
*x = digit;
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = basek - lsh;
|
||||
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
|
||||
let digit: i64 = get_digit(basek_lsh, *a);
|
||||
*c = get_carry(basek_lsh, *a, digit);
|
||||
*x = digit << lsh;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn znx_normalize_middle_step_carry_only_ref(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < basek);
|
||||
}
|
||||
if lsh == 0 {
|
||||
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
let digit: i64 = get_digit(basek, *x);
|
||||
let carry: i64 = get_carry(basek, *x, digit);
|
||||
let digit_plus_c: i64 = digit + *c;
|
||||
*c = carry + get_carry(basek, digit_plus_c, get_digit(basek, digit_plus_c));
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = basek - lsh;
|
||||
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
let digit: i64 = get_digit(basek_lsh, *x);
|
||||
let carry: i64 = get_carry(basek_lsh, *x, digit);
|
||||
let digit_plus_c: i64 = (digit << lsh) + *c;
|
||||
*c = carry + get_carry(basek, digit_plus_c, get_digit(basek, digit_plus_c));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn znx_normalize_middle_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < basek);
|
||||
}
|
||||
if lsh == 0 {
|
||||
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
let digit: i64 = get_digit(basek, *x);
|
||||
let carry: i64 = get_carry(basek, *x, digit);
|
||||
let digit_plus_c: i64 = digit + *c;
|
||||
*x = get_digit(basek, digit_plus_c);
|
||||
*c = carry + get_carry(basek, digit_plus_c, *x);
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = basek - lsh;
|
||||
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
let digit: i64 = get_digit(basek_lsh, *x);
|
||||
let carry: i64 = get_carry(basek_lsh, *x, digit);
|
||||
let digit_plus_c: i64 = (digit << lsh) + *c;
|
||||
*x = get_digit(basek, digit_plus_c);
|
||||
*c = carry + get_carry(basek, digit_plus_c, *x);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn znx_normalize_middle_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(x.len(), a.len());
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < basek);
|
||||
}
|
||||
if lsh == 0 {
|
||||
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
|
||||
let digit: i64 = get_digit(basek, *a);
|
||||
let carry: i64 = get_carry(basek, *a, digit);
|
||||
let digit_plus_c: i64 = digit + *c;
|
||||
*x = get_digit(basek, digit_plus_c);
|
||||
*c = carry + get_carry(basek, digit_plus_c, *x);
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = basek - lsh;
|
||||
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
|
||||
let digit: i64 = get_digit(basek_lsh, *a);
|
||||
let carry: i64 = get_carry(basek_lsh, *a, digit);
|
||||
let digit_plus_c: i64 = (digit << lsh) + *c;
|
||||
*x = get_digit(basek, digit_plus_c);
|
||||
*c = carry + get_carry(basek, digit_plus_c, *x);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn znx_normalize_final_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < basek);
|
||||
}
|
||||
|
||||
if lsh == 0 {
|
||||
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
*x = get_digit(basek, get_digit(basek, *x) + *c);
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = basek - lsh;
|
||||
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
*x = get_digit(basek, (get_digit(basek_lsh, *x) << lsh) + *c);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn znx_normalize_final_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < basek);
|
||||
}
|
||||
if lsh == 0 {
|
||||
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
|
||||
*x = get_digit(basek, get_digit(basek, *a) + *c);
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = basek - lsh;
|
||||
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
|
||||
*x = get_digit(basek, (get_digit(basek_lsh, *a) << lsh) + *c);
|
||||
});
|
||||
}
|
||||
}
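Taken together, the three step families implement a single right-to-left carry sweep over the limbs; the driving loop has the same shape as zn_normalize_inplace earlier (a minimal sketch over plain Vec<i64> limbs, limb 0 being the most significant):

fn normalize_limbs(basek: usize, limbs: &mut [Vec<i64>], carry: &mut [i64]) {
    let size: usize = limbs.len();
    for j in (0..size).rev() {
        if j == size - 1 {
            // Least significant limb: seed the carry.
            znx_normalize_first_step_inplace_ref(basek, 0, &mut limbs[j], carry);
        } else if j == 0 {
            // Most significant limb: absorb the remaining carry.
            znx_normalize_final_step_inplace_ref(basek, 0, &mut limbs[j], carry);
        } else {
            znx_normalize_middle_step_inplace_ref(basek, 0, &mut limbs[j], carry);
        }
    }
}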
|
||||
26 poulpy-hal/src/reference/znx/rotate.rs Normal file
@@ -0,0 +1,26 @@
|
||||
use crate::reference::znx::{ZnxCopy, ZnxNegate};
|
||||
|
||||
pub fn znx_rotate<ZNXARI: ZnxNegate + ZnxCopy>(p: i64, res: &mut [i64], src: &[i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.len(), src.len());
|
||||
}
|
||||
|
||||
let n: usize = res.len();
|
||||
|
||||
let mp_2n: usize = (p & (2 * n as i64 - 1)) as usize; // p mod 2n
let mp_1n: usize = mp_2n & (n - 1); // p mod n
let mp_1n_neg: usize = n - mp_1n; // n - (p mod n)
|
||||
let neg_first: bool = mp_2n < n;
|
||||
|
||||
let (dst1, dst2) = res.split_at_mut(mp_1n);
|
||||
let (src1, src2) = src.split_at(mp_1n_neg);
|
||||
|
||||
if neg_first {
|
||||
ZNXARI::znx_negate(dst1, src2);
|
||||
ZNXARI::znx_copy(dst2, src1);
|
||||
} else {
|
||||
ZNXARI::znx_copy(dst1, src2);
|
||||
ZNXARI::znx_negate(dst2, src1);
|
||||
}
|
||||
}
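For example, with n = 4 and p = 1 the rotation is multiplication by X in Z[X]/(X^4 + 1), so the top coefficient wraps around negated (a sketch using the ZnxRef implementations from arithmetic_ref.rs):

use crate::reference::znx::{ZnxRef, znx_rotate};

let src: [i64; 4] = [1, 2, 3, 4];
let mut res: [i64; 4] = [0; 4];
znx_rotate::<ZnxRef>(1, &mut res, &src);
assert_eq!(res, [-4, 1, 2, 3]);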
|
||||
53 poulpy-hal/src/reference/znx/sampling.rs Normal file
@@ -0,0 +1,53 @@
|
||||
use rand_distr::{Distribution, Normal};
|
||||
|
||||
use crate::source::Source;
|
||||
|
||||
pub fn znx_fill_uniform_ref(basek: usize, res: &mut [i64], source: &mut Source) {
|
||||
let pow2k: u64 = 1 << basek;
|
||||
let mask: u64 = pow2k - 1;
|
||||
let pow2k_half: i64 = (pow2k >> 1) as i64;
|
||||
res.iter_mut()
|
||||
.for_each(|xi| *xi = (source.next_u64n(pow2k, mask) as i64) - pow2k_half)
|
||||
}
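Each sampled coefficient is centered, landing in [-2^(basek-1), 2^(basek-1)); a minimal check reusing Source from this crate:

let mut source: Source = Source::new([0u8; 32]);
let mut res: Vec<i64> = vec![0i64; 64];
znx_fill_uniform_ref(12, &mut res, &mut source);
assert!(res.iter().all(|&x| -2048 <= x && x < 2048));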
|
||||
|
||||
pub fn znx_fill_dist_f64_ref<D: rand::prelude::Distribution<f64>>(res: &mut [i64], dist: D, bound: f64, source: &mut Source) {
|
||||
res.iter_mut().for_each(|xi| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*xi = dist_f64.round() as i64
|
||||
})
|
||||
}
|
||||
|
||||
pub fn znx_add_dist_f64_ref<D: rand::prelude::Distribution<f64>>(res: &mut [i64], dist: D, bound: f64, source: &mut Source) {
|
||||
res.iter_mut().for_each(|xi| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*xi += dist_f64.round() as i64
|
||||
})
|
||||
}
|
||||
|
||||
pub fn znx_fill_normal_f64_ref(res: &mut [i64], sigma: f64, bound: f64, source: &mut Source) {
|
||||
let normal: Normal<f64> = Normal::new(0.0, sigma).unwrap();
|
||||
res.iter_mut().for_each(|xi| {
|
||||
let mut dist_f64: f64 = normal.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = normal.sample(source)
|
||||
}
|
||||
*xi = dist_f64.round() as i64
|
||||
})
|
||||
}
|
||||
|
||||
pub fn znx_add_normal_f64_ref(res: &mut [i64], sigma: f64, bound: f64, source: &mut Source) {
|
||||
let normal: Normal<f64> = Normal::new(0.0, sigma).unwrap();
|
||||
res.iter_mut().for_each(|xi| {
|
||||
let mut dist_f64: f64 = normal.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = normal.sample(source)
|
||||
}
|
||||
*xi += dist_f64.round() as i64
|
||||
})
|
||||
}
|
||||
36 poulpy-hal/src/reference/znx/sub.rs Normal file
@@ -0,0 +1,36 @@
|
||||
pub fn znx_sub_ref(res: &mut [i64], a: &[i64], b: &[i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.len(), a.len());
|
||||
assert_eq!(res.len(), b.len());
|
||||
}
|
||||
|
||||
let n: usize = res.len();
|
||||
for i in 0..n {
|
||||
res[i] = a[i] - b[i];
|
||||
}
|
||||
}
|
||||
|
||||
pub fn znx_sub_ab_inplace_ref(res: &mut [i64], a: &[i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.len(), a.len());
|
||||
}
|
||||
|
||||
let n: usize = res.len();
|
||||
for i in 0..n {
|
||||
res[i] -= a[i];
|
||||
}
|
||||
}
|
||||
|
||||
pub fn znx_sub_ba_inplace_ref(res: &mut [i64], a: &[i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.len(), a.len());
|
||||
}
|
||||
|
||||
let n: usize = res.len();
|
||||
for i in 0..n {
|
||||
res[i] = a[i] - res[i];
|
||||
}
|
||||
}
|
||||
29 poulpy-hal/src/reference/znx/switch_ring.rs Normal file
@@ -0,0 +1,29 @@
|
||||
use crate::reference::znx::{copy::znx_copy_ref, zero::znx_zero_ref};
|
||||
|
||||
pub fn znx_switch_ring_ref(res: &mut [i64], a: &[i64]) {
|
||||
let (n_in, n_out) = (a.len(), res.len());
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(n_in.is_power_of_two());
|
||||
assert!(n_in.max(n_out).is_multiple_of(n_in.min(n_out)))
|
||||
}
|
||||
|
||||
if n_in == n_out {
|
||||
znx_copy_ref(res, a);
|
||||
return;
|
||||
}
|
||||
|
||||
let (gap_in, gap_out): (usize, usize);
|
||||
if n_in > n_out {
|
||||
(gap_in, gap_out) = (n_in / n_out, 1)
|
||||
} else {
|
||||
(gap_in, gap_out) = (1, n_out / n_in);
|
||||
znx_zero_ref(res);
|
||||
}
|
||||
|
||||
res.iter_mut()
|
||||
.step_by(gap_out)
|
||||
.zip(a.iter().step_by(gap_in))
|
||||
.for_each(|(x_out, x_in)| *x_out = *x_in);
|
||||
}
|
||||
3 poulpy-hal/src/reference/znx/zero.rs Normal file
@@ -0,0 +1,3 @@
|
||||
pub fn znx_zero_ref(res: &mut [i64]) {
|
||||
res.fill(0);
|
||||
}
|
||||
@@ -39,6 +39,12 @@ impl Source {
|
||||
min + ((self.next_u64() << 11 >> 11) as f64) / MAXF64 * (max - min)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn next_i32(&mut self) -> i32 {
|
||||
self.next_u32() as i32
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn next_i64(&mut self) -> i64 {
|
||||
self.next_u64() as i64
|
||||
}
|
||||
|
||||
68 poulpy-hal/src/test_suite/mod.rs Normal file
@@ -0,0 +1,68 @@
|
||||
pub mod serialization;
|
||||
pub mod svp;
|
||||
pub mod vec_znx;
|
||||
pub mod vec_znx_big;
|
||||
pub mod vec_znx_dft;
|
||||
pub mod vmp;
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! backend_test_suite {
|
||||
(
|
||||
mod $modname:ident,
|
||||
backend = $backend:ty,
|
||||
size = $size:expr,
|
||||
tests = {
|
||||
$( $(#[$attr:meta])* $test_name:ident => $impl:path ),+ $(,)?
|
||||
}
|
||||
) => {
|
||||
mod $modname {
|
||||
use poulpy_hal::{api::ModuleNew, layouts::Module};
|
||||
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
static MODULE: Lazy<Module<$backend>> =
|
||||
Lazy::new(|| Module::<$backend>::new($size));
|
||||
|
||||
$(
|
||||
$(#[$attr])*
|
||||
#[test]
|
||||
fn $test_name() {
|
||||
($impl)(&*MODULE);
|
||||
}
|
||||
)+
|
||||
}
|
||||
};
|
||||
}
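For instance, a backend crate could instantiate the suite as below; the backend type, module size and test path are placeholders chosen for illustration, only the macro grammar above is authoritative:

poulpy_hal::backend_test_suite! {
    mod zn_ref,
    backend = FFT64Ref, // hypothetical backend type
    size = 64,
    tests = {
        test_zn_normalize_inplace => poulpy_hal::reference::zn::test_zn_normalize_inplace, // path assumed
    }
}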
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! cross_backend_test_suite {
|
||||
(
|
||||
mod $modname:ident,
|
||||
backend_ref = $backend_ref:ty,
|
||||
backend_test = $backend_test:ty,
|
||||
size = $size:expr,
|
||||
basek = $basek:expr,
|
||||
tests = {
|
||||
$( $(#[$attr:meta])* $test_name:ident => $impl:path ),+ $(,)?
|
||||
}
|
||||
) => {
|
||||
mod $modname {
|
||||
use poulpy_hal::{api::ModuleNew, layouts::Module};
|
||||
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
static MODULE_REF: Lazy<Module<$backend_ref>> =
|
||||
Lazy::new(|| Module::<$backend_ref>::new($size));
|
||||
static MODULE_TEST: Lazy<Module<$backend_test>> =
|
||||
Lazy::new(|| Module::<$backend_test>::new($size));
|
||||
|
||||
$(
|
||||
$(#[$attr])*
|
||||
#[test]
|
||||
fn $test_name() {
|
||||
($impl)($basek, &*MODULE_REF, &*MODULE_TEST);
|
||||
}
|
||||
)+
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -14,7 +14,7 @@ where
|
||||
{
|
||||
// Fill original with uniform random data
|
||||
let mut source = Source::new([0u8; 32]);
|
||||
original.fill_uniform(&mut source);
|
||||
original.fill_uniform(50, &mut source);
|
||||
|
||||
// Serialize into a buffer
|
||||
let mut buffer = Vec::new();
|
||||
470 poulpy-hal/src/test_suite/svp.rs Normal file
@@ -0,0 +1,470 @@
|
||||
use rand::RngCore;
|
||||
|
||||
use crate::{
|
||||
api::{
|
||||
ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDft, SvpApplyDftToDft, SvpApplyDftToDftAdd, SvpApplyDftToDftInplace,
|
||||
SvpPPolAlloc, SvpPrepare, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VecZnxDftApply,
|
||||
VecZnxIdftApplyConsume,
|
||||
},
|
||||
layouts::{Backend, DataViewMut, DigestU64, FillUniform, Module, ScalarZnx, ScratchOwned, SvpPPol, VecZnx, VecZnxDft},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn test_svp_apply_dft<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
|
||||
where
|
||||
Module<BR>: SvpPrepare<BR>
|
||||
+ SvpApplyDft<BR>
|
||||
+ SvpPPolAlloc<BR>
|
||||
+ VecZnxDftAlloc<BR>
|
||||
+ VecZnxBigNormalize<BR>
|
||||
+ VecZnxIdftApplyConsume<BR>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
Module<BT>: SvpPrepare<BT>
|
||||
+ SvpApplyDft<BT>
|
||||
+ SvpPPolAlloc<BT>
|
||||
+ VecZnxDftAlloc<BT>
|
||||
+ VecZnxBigNormalize<BT>
|
||||
+ VecZnxIdftApplyConsume<BT>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
|
||||
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
|
||||
{
|
||||
assert_eq!(module_ref.n(), module_test.n());
|
||||
let n: usize = module_ref.n();
|
||||
|
||||
let cols: usize = 2;
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
|
||||
|
||||
let mut scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
|
||||
scalar.fill_uniform(basek, &mut source);
|
||||
|
||||
let scalar_digest: u64 = scalar.digest_u64();
|
||||
|
||||
let mut svp_ref: SvpPPol<Vec<u8>, BR> = module_ref.svp_ppol_alloc(cols);
|
||||
let mut svp_test: SvpPPol<Vec<u8>, BT> = module_test.svp_ppol_alloc(cols);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.svp_prepare(&mut svp_ref, j, &scalar, j);
|
||||
module_test.svp_prepare(&mut svp_test, j, &scalar, j);
|
||||
}
|
||||
|
||||
assert_eq!(scalar.digest_u64(), scalar_digest);
|
||||
|
||||
let svp_ref_digest: u64 = svp_ref.digest_u64();
|
||||
let svp_test_digest: u64 = svp_test.digest_u64();
|
||||
|
||||
for a_size in [1, 2, 3, 4] {
|
||||
// Create a random input VecZnx
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
|
||||
a.fill_uniform(basek, &mut source);
|
||||
|
||||
let a_digest: u64 = a.digest_u64();
|
||||
|
||||
for res_size in [1, 2, 3, 4] {
|
||||
// Allocate VecZnxDft from FFT64Ref and module to test
|
||||
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
|
||||
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
|
||||
|
||||
// Fill output with garbage
|
||||
source.fill_bytes(res_dft_ref.data_mut());
|
||||
source.fill_bytes(res_dft_test.data_mut());
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.svp_apply_dft(&mut res_dft_ref, j, &svp_ref, j, &a, j);
|
||||
module_test.svp_apply_dft(&mut res_dft_test, j, &svp_test, j, &a, j);
|
||||
}
|
||||
|
||||
// Assert no change to inputs
|
||||
assert_eq!(svp_ref.digest_u64(), svp_ref_digest);
|
||||
assert_eq!(svp_test.digest_u64(), svp_test_digest);
|
||||
assert_eq!(a.digest_u64(), a_digest);
|
||||
|
||||
let res_big_ref: crate::layouts::VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
|
||||
let res_big_test: crate::layouts::VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
|
||||
|
||||
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_ref,
|
||||
j,
|
||||
&res_big_ref,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_test,
|
||||
j,
|
||||
&res_big_test,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(res_ref, res_test);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_svp_apply_dft_to_dft<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
|
||||
where
|
||||
Module<BR>: SvpPrepare<BR>
|
||||
+ SvpApplyDftToDft<BR>
|
||||
+ SvpPPolAlloc<BR>
|
||||
+ VecZnxDftAlloc<BR>
|
||||
+ VecZnxBigNormalize<BR>
|
||||
+ VecZnxDftApply<BR>
|
||||
+ VecZnxIdftApplyConsume<BR>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
Module<BT>: SvpPrepare<BT>
|
||||
+ SvpApplyDftToDft<BT>
|
||||
+ SvpPPolAlloc<BT>
|
||||
+ VecZnxDftAlloc<BT>
|
||||
+ VecZnxBigNormalize<BT>
|
||||
+ VecZnxDftApply<BT>
|
||||
+ VecZnxIdftApplyConsume<BT>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
|
||||
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
|
||||
{
|
||||
assert_eq!(module_ref.n(), module_test.n());
|
||||
let n: usize = module_ref.n();
|
||||
|
||||
let cols: usize = 2;
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
|
||||
|
||||
let mut scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
|
||||
scalar.fill_uniform(basek, &mut source);
|
||||
|
||||
let scalar_digest: u64 = scalar.digest_u64();
|
||||
|
||||
let mut svp_ref: SvpPPol<Vec<u8>, BR> = module_ref.svp_ppol_alloc(cols);
|
||||
let mut svp_test: SvpPPol<Vec<u8>, BT> = module_test.svp_ppol_alloc(cols);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.svp_prepare(&mut svp_ref, j, &scalar, j);
|
||||
module_test.svp_prepare(&mut svp_test, j, &scalar, j);
|
||||
}
|
||||
|
||||
assert_eq!(scalar.digest_u64(), scalar_digest);
|
||||
|
||||
let svp_ref_digest: u64 = svp_ref.digest_u64();
|
||||
let svp_test_digest: u64 = svp_test.digest_u64();
|
||||
|
||||
for a_size in [3] {
|
||||
// Create a random input VecZnx
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
|
||||
a.fill_uniform(basek, &mut source);
|
||||
|
||||
let a_digest: u64 = a.digest_u64();
|
||||
|
||||
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
|
||||
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
|
||||
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
|
||||
}
|
||||
|
||||
assert_eq!(a.digest_u64(), a_digest);
|
||||
|
||||
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
|
||||
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
|
||||
|
||||
for res_size in [3] {
|
||||
// Allocate VecZnxDft from FFT64Ref and module to test
|
||||
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
|
||||
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
|
||||
|
||||
// Fill output with garbage
|
||||
source.fill_bytes(res_dft_ref.data_mut());
|
||||
source.fill_bytes(res_dft_test.data_mut());
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.svp_apply_dft_to_dft(&mut res_dft_ref, j, &svp_ref, j, &a_dft_ref, j);
|
||||
module_test.svp_apply_dft_to_dft(&mut res_dft_test, j, &svp_test, j, &a_dft_test, j);
|
||||
}
|
||||
|
||||
// Assert no change to inputs
|
||||
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
|
||||
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
|
||||
assert_eq!(svp_ref.digest_u64(), svp_ref_digest);
|
||||
assert_eq!(svp_test.digest_u64(), svp_test_digest);
|
||||
assert_eq!(a.digest_u64(), a_digest);
|
||||
|
||||
let res_big_ref: crate::layouts::VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
|
||||
let res_big_test: crate::layouts::VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
|
||||
|
||||
println!("res_big_ref: {}", res_big_ref);
|
||||
println!("res_big_test: {}", res_big_test);
|
||||
|
||||
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_ref,
|
||||
j,
|
||||
&res_big_ref,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_test,
|
||||
j,
|
||||
&res_big_test,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(res_ref, res_test);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_svp_apply_dft_to_dft_add<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
|
||||
where
|
||||
Module<BR>: SvpPrepare<BR>
|
||||
+ SvpApplyDftToDftAdd<BR>
|
||||
+ SvpPPolAlloc<BR>
|
||||
+ VecZnxDftAlloc<BR>
|
||||
+ VecZnxBigNormalize<BR>
|
||||
+ VecZnxDftApply<BR>
|
||||
+ VecZnxIdftApplyConsume<BR>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
Module<BT>: SvpPrepare<BT>
|
||||
+ SvpApplyDftToDftAdd<BT>
|
||||
+ SvpPPolAlloc<BT>
|
||||
+ VecZnxDftAlloc<BT>
|
||||
+ VecZnxBigNormalize<BT>
|
||||
+ VecZnxDftApply<BT>
|
||||
+ VecZnxIdftApplyConsume<BT>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
|
||||
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
|
||||
{
|
||||
assert_eq!(module_ref.n(), module_test.n());
|
||||
let n: usize = module_ref.n();
|
||||
|
||||
let cols: usize = 2;
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
|
||||
|
||||
let mut scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
|
||||
scalar.fill_uniform(basek, &mut source);
|
||||
|
||||
let scalar_digest: u64 = scalar.digest_u64();
|
||||
|
||||
let mut svp_ref: SvpPPol<Vec<u8>, BR> = module_ref.svp_ppol_alloc(cols);
|
||||
let mut svp_test: SvpPPol<Vec<u8>, BT> = module_test.svp_ppol_alloc(cols);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.svp_prepare(&mut svp_ref, j, &scalar, j);
|
||||
module_test.svp_prepare(&mut svp_test, j, &scalar, j);
|
||||
}
|
||||
|
||||
assert_eq!(scalar.digest_u64(), scalar_digest);
|
||||
|
||||
let svp_ref_digest: u64 = svp_ref.digest_u64();
|
||||
let svp_test_digest: u64 = svp_test.digest_u64();
|
||||
|
||||
for a_size in [1, 2, 3, 4] {
|
||||
// Create a random input VecZnx
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
|
||||
a.fill_uniform(basek, &mut source);
|
||||
|
||||
let a_digest: u64 = a.digest_u64();
|
||||
|
||||
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
|
||||
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
|
||||
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
|
||||
}
|
||||
|
||||
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
|
||||
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
|
||||
|
||||
for res_size in [1, 2, 3, 4] {
|
||||
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
res.fill_uniform(basek, &mut source);
|
||||
|
||||
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
|
||||
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j);
|
||||
module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j);
|
||||
}
|
||||
|
||||
// Fill output with garbage
|
||||
source.fill_bytes(res_dft_ref.data_mut());
|
||||
source.fill_bytes(res_dft_test.data_mut());
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.svp_apply_dft_to_dft_add(&mut res_dft_ref, j, &svp_ref, j, &a_dft_ref, j);
|
||||
module_test.svp_apply_dft_to_dft_add(&mut res_dft_test, j, &svp_test, j, &a_dft_test, j);
|
||||
}
|
||||
|
||||
// Assert no change to inputs
|
||||
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
|
||||
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
|
||||
assert_eq!(svp_ref.digest_u64(), svp_ref_digest);
|
||||
assert_eq!(svp_test.digest_u64(), svp_test_digest);
|
||||
assert_eq!(a.digest_u64(), a_digest);
|
||||
|
||||
let res_big_ref: crate::layouts::VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
|
||||
let res_big_test: crate::layouts::VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
|
||||
|
||||
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_ref,
|
||||
j,
|
||||
&res_big_ref,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_test,
|
||||
j,
|
||||
&res_big_test,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(res_ref, res_test);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_svp_apply_dft_to_dft_inplace<BR: Backend, BT: Backend>(
|
||||
basek: usize,
|
||||
module_ref: &Module<BR>,
|
||||
module_test: &Module<BT>,
|
||||
) where
|
||||
Module<BR>: SvpPrepare<BR>
|
||||
+ SvpApplyDftToDftInplace<BR>
|
||||
+ SvpPPolAlloc<BR>
|
||||
+ VecZnxDftAlloc<BR>
|
||||
+ VecZnxBigNormalize<BR>
|
||||
+ VecZnxDftApply<BR>
|
||||
+ VecZnxIdftApplyConsume<BR>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
Module<BT>: SvpPrepare<BT>
|
||||
+ SvpApplyDftToDftInplace<BT>
|
||||
+ SvpPPolAlloc<BT>
|
||||
+ VecZnxDftAlloc<BT>
|
||||
+ VecZnxBigNormalize<BT>
|
||||
+ VecZnxDftApply<BT>
|
||||
+ VecZnxIdftApplyConsume<BT>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
|
||||
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
|
||||
{
|
||||
assert_eq!(module_ref.n(), module_test.n());
|
||||
let n: usize = module_ref.n();
|
||||
|
||||
let cols: usize = 2;
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
|
||||
|
||||
let mut scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
|
||||
scalar.fill_uniform(basek, &mut source);
|
||||
|
||||
let scalar_digest: u64 = scalar.digest_u64();
|
||||
|
||||
let mut svp_ref: SvpPPol<Vec<u8>, BR> = module_ref.svp_ppol_alloc(cols);
|
||||
let mut svp_test: SvpPPol<Vec<u8>, BT> = module_test.svp_ppol_alloc(cols);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.svp_prepare(&mut svp_ref, j, &scalar, j);
|
||||
module_test.svp_prepare(&mut svp_test, j, &scalar, j);
|
||||
}
|
||||
|
||||
assert_eq!(scalar.digest_u64(), scalar_digest);
|
||||
|
||||
let svp_ref_digest: u64 = svp_ref.digest_u64();
|
||||
let svp_test_digest: u64 = svp_test.digest_u64();
|
||||
|
||||
for res_size in [1, 2, 3, 4] {
|
||||
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
res.fill_uniform(basek, &mut source);
|
||||
let res_digest: u64 = res.digest_u64();
|
||||
|
||||
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
|
||||
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j);
|
||||
module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j);
|
||||
}
|
||||
|
||||
assert_eq!(res.digest_u64(), res_digest);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.svp_apply_dft_to_dft_inplace(&mut res_dft_ref, j, &svp_ref, j);
|
||||
module_test.svp_apply_dft_to_dft_inplace(&mut res_dft_test, j, &svp_test, j);
|
||||
}
|
||||
|
||||
// Assert no change to inputs
|
||||
assert_eq!(svp_ref.digest_u64(), svp_ref_digest);
|
||||
assert_eq!(svp_test.digest_u64(), svp_test_digest);
|
||||
|
||||
let res_big_ref: crate::layouts::VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
|
||||
let res_big_test: crate::layouts::VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
|
||||
|
||||
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_ref,
|
||||
j,
|
||||
&res_big_ref,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_test,
|
||||
j,
|
||||
&res_big_test,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
}
|
||||
|
||||
println!("res_ref: {}", res_ref);
|
||||
println!("res_test: {}", res_test);
|
||||
|
||||
assert_eq!(res_ref, res_test);
|
||||
}
|
||||
}
|
||||
1255 poulpy-hal/src/test_suite/vec_znx.rs Normal file
File diff suppressed because it is too large
1432 poulpy-hal/src/test_suite/vec_znx_big.rs Normal file
File diff suppressed because it is too large
930 poulpy-hal/src/test_suite/vec_znx_dft.rs Normal file
@@ -0,0 +1,930 @@
|
||||
use rand::RngCore;
|
||||
|
||||
use crate::{
|
||||
api::{
|
||||
ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAdd,
|
||||
VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftApply, VecZnxDftCopy, VecZnxDftSub, VecZnxDftSubABInplace,
|
||||
VecZnxDftSubBAInplace, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataViewMut, DigestU64, FillUniform, Module, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn test_vec_znx_dft_add<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
|
||||
where
|
||||
Module<BR>: VecZnxDftAdd<BR>
|
||||
+ VecZnxDftAlloc<BR>
|
||||
+ VecZnxDftApply<BR>
|
||||
+ VecZnxIdftApplyConsume<BR>
|
||||
+ VecZnxBigNormalize<BR>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
Module<BT>: VecZnxDftAdd<BT>
|
||||
+ VecZnxDftAlloc<BT>
|
||||
+ VecZnxDftApply<BT>
|
||||
+ VecZnxIdftApplyConsume<BT>
|
||||
+ VecZnxBigNormalize<BT>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
|
||||
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
|
||||
{
|
||||
assert_eq!(module_ref.n(), module_test.n());
|
||||
let n: usize = module_ref.n();
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
|
||||
|
||||
for a_size in [1, 2, 3, 4] {
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
|
||||
a.fill_uniform(basek, &mut source);
|
||||
let a_digest = a.digest_u64();
|
||||
|
||||
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
|
||||
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
|
||||
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
|
||||
}
|
||||
|
||||
assert_eq!(a.digest_u64(), a_digest);
|
||||
|
||||
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
|
||||
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
|
||||
|
||||
for b_size in [1, 2, 3, 4] {
|
||||
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, b_size);
|
||||
b.fill_uniform(basek, &mut source);
|
||||
let b_digest: u64 = b.digest_u64();
|
||||
|
||||
let mut b_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, b_size);
|
||||
let mut b_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, b_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_dft_apply(1, 0, &mut b_dft_ref, j, &b, j);
|
||||
module_test.vec_znx_dft_apply(1, 0, &mut b_dft_test, j, &b, j);
|
||||
}
|
||||
|
||||
assert_eq!(b.digest_u64(), b_digest);
|
||||
|
||||
let b_dft_ref_digest: u64 = b_dft_ref.digest_u64();
|
||||
let b_dft_test_digest: u64 = b_dft_test.digest_u64();
|
||||
|
||||
for res_size in [1, 2, 3, 4] {
|
||||
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
|
||||
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
|
||||
|
||||
// Set d to garbage
|
||||
source.fill_bytes(res_dft_ref.data_mut());
|
||||
source.fill_bytes(res_dft_test.data_mut());
|
||||
|
||||
// Reference
|
||||
for i in 0..cols {
|
||||
module_ref.vec_znx_dft_add(&mut res_dft_ref, i, &a_dft_ref, i, &b_dft_ref, i);
|
||||
module_test.vec_znx_dft_add(&mut res_dft_test, i, &a_dft_test, i, &b_dft_test, i);
|
||||
}
|
||||
|
||||
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
|
||||
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
|
||||
assert_eq!(b_dft_ref.digest_u64(), b_dft_ref_digest);
|
||||
assert_eq!(b_dft_test.digest_u64(), b_dft_test_digest);
|
||||
|
||||
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
|
||||
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
|
||||
|
||||
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
|
||||
let res_ref_digest: u64 = res_big_ref.digest_u64();
|
||||
let res_test_digest: u64 = res_big_test.digest_u64();
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_ref,
|
||||
j,
|
||||
&res_big_ref,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_test,
|
||||
j,
|
||||
&res_big_test,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
|
||||
assert_eq!(res_big_test.digest_u64(), res_test_digest);
|
||||
|
||||
assert_eq!(res_small_ref, res_small_test);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_vec_znx_dft_add_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
|
||||
where
|
||||
Module<BR>: VecZnxDftAddInplace<BR>
|
||||
+ VecZnxDftAlloc<BR>
|
||||
+ VecZnxDftApply<BR>
|
||||
+ VecZnxIdftApplyConsume<BR>
|
||||
+ VecZnxBigNormalize<BR>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
Module<BT>: VecZnxDftAddInplace<BT>
|
||||
+ VecZnxDftAlloc<BT>
|
||||
+ VecZnxDftApply<BT>
|
||||
+ VecZnxIdftApplyConsume<BT>
|
||||
+ VecZnxBigNormalize<BT>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
|
||||
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
|
||||
{
|
||||
assert_eq!(module_ref.n(), module_test.n());
|
||||
let n: usize = module_ref.n();
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
|
||||
for a_size in [1, 2, 3, 4] {
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
|
||||
a.fill_uniform(basek, &mut source);
|
||||
let a_digest = a.digest_u64();
|
||||
|
||||
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
|
||||
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
|
||||
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
|
||||
}
|
||||
|
||||
assert_eq!(a.digest_u64(), a_digest);
|
||||
|
||||
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
|
||||
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
|
||||
|
||||
for res_size in [1, 2, 3, 4] {
|
||||
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
|
||||
res.fill_uniform(basek, &mut source);
|
||||
let res_digest: u64 = res.digest_u64();
|
||||
|
||||
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
|
||||
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j);
|
||||
module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j);
|
||||
}
|
||||
|
||||
assert_eq!(res.digest_u64(), res_digest);
|
||||
|
||||
// Reference
|
||||
for i in 0..cols {
|
||||
module_ref.vec_znx_dft_add_inplace(&mut res_dft_ref, i, &a_dft_ref, i);
|
||||
module_test.vec_znx_dft_add_inplace(&mut res_dft_test, i, &a_dft_test, i);
|
||||
}
|
||||
|
||||
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
|
||||
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
|
||||
|
||||
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
|
||||
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
|
||||
|
||||
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
|
||||
let res_ref_digest: u64 = res_big_ref.digest_u64();
|
||||
let res_test_digest: u64 = res_big_test.digest_u64();
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_ref,
|
||||
j,
|
||||
&res_big_ref,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_test,
|
||||
j,
|
||||
&res_big_test,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
|
||||
assert_eq!(res_big_test.digest_u64(), res_test_digest);
|
||||
|
||||
assert_eq!(res_small_ref, res_small_test);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_vec_znx_copy<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
|
||||
where
|
||||
Module<BR>: VecZnxDftCopy<BR>
|
||||
+ VecZnxDftAlloc<BR>
|
||||
+ VecZnxDftApply<BR>
|
||||
+ VecZnxIdftApplyConsume<BR>
|
||||
+ VecZnxBigNormalize<BR>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
Module<BT>: VecZnxDftCopy<BT>
|
||||
+ VecZnxDftAlloc<BT>
|
||||
+ VecZnxDftApply<BT>
|
||||
+ VecZnxIdftApplyConsume<BT>
|
||||
+ VecZnxBigNormalize<BT>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
|
||||
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
|
||||
{
|
||||
assert_eq!(module_ref.n(), module_test.n());
|
||||
let n: usize = module_ref.n();
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
|
||||
|
||||
for a_size in [1, 2, 6, 11] {
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
|
||||
a.fill_uniform(basek, &mut source);
|
||||
let a_digest = a.digest_u64();
|
||||
|
||||
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
|
||||
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
|
||||
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
|
||||
}
|
||||
|
||||
assert_eq!(a.digest_u64(), a_digest);
|
||||
|
||||
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
|
||||
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
|
||||
|
||||
for res_size in [1, 2, 6, 11] {
|
||||
for params in [[1, 0], [1, 1], [1, 2], [2, 2]] {
|
||||
let steps: usize = params[0];
|
||||
let offset: usize = params[1];
|
||||
|
||||
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
|
||||
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
|
||||
|
||||
// Set res_dft to garbage
|
||||
source.fill_bytes(res_dft_ref.data_mut());
|
||||
source.fill_bytes(res_dft_test.data_mut());
|
||||
|
||||
// Reference
|
||||
for i in 0..cols {
|
||||
module_ref.vec_znx_dft_copy(steps, offset, &mut res_dft_ref, i, &a_dft_ref, i);
|
||||
module_test.vec_znx_dft_copy(steps, offset, &mut res_dft_test, i, &a_dft_test, i);
|
||||
}
|
||||
|
||||
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
|
||||
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
|
||||
|
||||
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
|
||||
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
|
||||
|
||||
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
|
||||
let res_ref_digest: u64 = res_big_ref.digest_u64();
|
||||
let res_test_digest: u64 = res_big_test.digest_u64();
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_ref,
|
||||
j,
|
||||
&res_big_ref,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_test,
|
||||
j,
|
||||
&res_big_test,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
|
||||
assert_eq!(res_big_test.digest_u64(), res_test_digest);
|
||||
|
||||
assert_eq!(res_small_ref, res_small_test);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_vec_znx_idft_apply<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
|
||||
where
|
||||
Module<BR>: VecZnxDftApply<BR>
|
||||
+ VecZnxDftAlloc<BR>
|
||||
+ VecZnxBigAlloc<BR>
|
||||
+ VecZnxBigNormalize<BR>
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VecZnxIdftApply<BR>,
|
||||
Module<BT>: VecZnxDftApply<BT>
|
||||
+ VecZnxDftAlloc<BT>
|
||||
+ VecZnxBigAlloc<BT>
|
||||
+ VecZnxBigNormalize<BT>
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VecZnxIdftApply<BT>,
|
||||
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
|
||||
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
|
||||
{
|
||||
assert_eq!(module_ref.n(), module_test.n());
|
||||
let n: usize = module_ref.n();
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
|
||||
|
||||
for a_size in [1, 2, 3, 4] {
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
|
||||
a.fill_uniform(basek, &mut source);
|
||||
let a_digest: u64 = a.digest_u64();
|
||||
|
||||
for res_size in [1, 2, 3, 4] {
|
||||
for params in [[1, 0], [1, 1], [1, 2], [2, 2]] {
|
||||
let steps: usize = params[0];
|
||||
let offset: usize = params[1];
|
||||
|
||||
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
|
||||
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_dft_apply(steps, offset, &mut res_dft_ref, j, &a, j);
|
||||
module_test.vec_znx_dft_apply(steps, offset, &mut res_dft_test, j, &a, j);
|
||||
}
|
||||
|
||||
assert_eq!(a.digest_u64(), a_digest);
|
||||
let res_dft_ref_digest: u64 = res_dft_ref.digest_u64();
let res_dft_test_digest: u64 = res_dft_test.digest_u64();
|
||||
|
||||
let mut res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, res_size);
|
||||
let mut res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, res_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_idft_apply(&mut res_big_ref, j, &res_dft_ref, j, scratch_ref.borrow());
|
||||
module_test.vec_znx_idft_apply(
|
||||
&mut res_big_test,
|
||||
j,
|
||||
&res_dft_test,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(res_dft_ref.digest_u64(), res_dft_ref_digest);
assert_eq!(res_dft_test.digest_u64(), res_dft_test_digest);
|
||||
|
||||
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
|
||||
let res_ref_digest: u64 = res_big_ref.digest_u64();
|
||||
let res_test_digest: u64 = res_big_test.digest_u64();
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_ref,
|
||||
j,
|
||||
&res_big_ref,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_test,
|
||||
j,
|
||||
&res_big_test,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
|
||||
assert_eq!(res_big_test.digest_u64(), res_test_digest);
|
||||
|
||||
assert_eq!(res_small_ref, res_small_test);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_vec_znx_idft_apply_tmpa<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
|
||||
where
|
||||
Module<BR>: VecZnxDftApply<BR>
|
||||
+ VecZnxDftAlloc<BR>
|
||||
+ VecZnxBigAlloc<BR>
|
||||
+ VecZnxBigNormalize<BR>
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VecZnxIdftApplyTmpA<BR>,
|
||||
Module<BT>: VecZnxDftApply<BT>
|
||||
+ VecZnxDftAlloc<BT>
|
||||
+ VecZnxBigAlloc<BT>
|
||||
+ VecZnxBigNormalize<BT>
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VecZnxIdftApplyTmpA<BT>,
|
||||
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
|
||||
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
|
||||
{
|
||||
assert_eq!(module_ref.n(), module_test.n());
|
||||
let n: usize = module_ref.n();
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
|
||||
|
||||
for a_size in [1, 2, 3, 4] {
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
|
||||
a.fill_uniform(basek, &mut source);
|
||||
let a_digest: u64 = a.digest_u64();
|
||||
|
||||
for res_size in [1, 2, 3, 4] {
|
||||
for params in [[1, 0], [1, 1], [1, 2], [2, 2]] {
|
||||
let steps: usize = params[0];
|
||||
let offset: usize = params[1];
|
||||
|
||||
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
|
||||
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_dft_apply(steps, offset, &mut res_dft_ref, j, &a, j);
|
||||
module_test.vec_znx_dft_apply(steps, offset, &mut res_dft_test, j, &a, j);
|
||||
}
|
||||
|
||||
assert_eq!(a.digest_u64(), a_digest);
|
||||
|
||||
let mut res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, res_size);
|
||||
let mut res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, res_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_idft_apply_tmpa(&mut res_big_ref, j, &mut res_dft_ref, j);
|
||||
module_test.vec_znx_idft_apply_tmpa(&mut res_big_test, j, &mut res_dft_test, j);
|
||||
}
|
||||
|
||||
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
|
||||
let res_ref_digest: u64 = res_big_ref.digest_u64();
|
||||
let res_test_digest: u64 = res_big_test.digest_u64();
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_ref,
|
||||
j,
|
||||
&res_big_ref,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_test,
|
||||
j,
|
||||
&res_big_test,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
|
||||
assert_eq!(res_big_test.digest_u64(), res_test_digest);
|
||||
|
||||
assert_eq!(res_small_ref, res_small_test);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_vec_znx_idft_apply_consume<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
|
||||
where
|
||||
Module<BR>: VecZnxDftApply<BR>
|
||||
+ VecZnxIdftApplyTmpBytes
|
||||
+ VecZnxDftAlloc<BR>
|
||||
+ VecZnxBigNormalize<BR>
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VecZnxIdftApplyConsume<BR>,
|
||||
Module<BT>: VecZnxDftApply<BT>
|
||||
+ VecZnxIdftApplyTmpBytes
|
||||
+ VecZnxDftAlloc<BT>
|
||||
+ VecZnxBigNormalize<BT>
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VecZnxIdftApplyConsume<BT>,
|
||||
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
|
||||
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
|
||||
{
|
||||
assert_eq!(module_ref.n(), module_test.n());
|
||||
let n: usize = module_ref.n();
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
|
||||
let mut scratch_ref: ScratchOwned<BR> =
|
||||
ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes() | module_ref.vec_znx_idft_apply_tmp_bytes());
|
||||
let mut scratch_test: ScratchOwned<BT> =
|
||||
ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes() | module_test.vec_znx_idft_apply_tmp_bytes());
|
||||
|
||||
for a_size in [1, 2, 3, 4] {
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
|
||||
a.fill_uniform(basek, &mut source);
|
||||
let a_digest: u64 = a.digest_u64();
|
||||
|
||||
for res_size in [1, 2, 3, 4] {
|
||||
for params in [[1, 0], [1, 1], [1, 2], [2, 2]] {
|
||||
let steps: usize = params[0];
|
||||
let offset: usize = params[1];
|
||||
|
||||
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
|
||||
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_dft_apply(steps, offset, &mut res_dft_ref, j, &a, j);
|
||||
module_test.vec_znx_dft_apply(steps, offset, &mut res_dft_test, j, &a, j);
|
||||
}
|
||||
|
||||
assert_eq!(a.digest_u64(), a_digest);
|
||||
|
||||
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
|
||||
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
|
||||
|
||||
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
|
||||
let res_ref_digest: u64 = res_big_ref.digest_u64();
|
||||
let res_test_digest: u64 = res_big_test.digest_u64();
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_ref,
|
||||
j,
|
||||
&res_big_ref,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_test,
|
||||
j,
|
||||
&res_big_test,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
|
||||
assert_eq!(res_big_test.digest_u64(), res_test_digest);
|
||||
|
||||
assert_eq!(res_small_ref, res_small_test);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
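
// Illustrative sketch (not part of this diff): the three inverse-DFT entry points
// exercised by the tests above, side by side for a single column `j`. The calls
// mirror the ones used in the tests; the concrete sizes are placeholders, and
// `scratch` is assumed to be at least `vec_znx_idft_apply_tmp_bytes()` large.
fn idft_variants_sketch<B: Backend>(module: &Module<B>, scratch: &mut ScratchOwned<B>)
where
    Module<B>: VecZnxDftAlloc<B>
        + VecZnxBigAlloc<B>
        + VecZnxIdftApply<B>
        + VecZnxIdftApplyTmpA<B>
        + VecZnxIdftApplyConsume<B>,
    ScratchOwned<B>: ScratchOwnedBorrow<B>,
{
    let (cols, size, j) = (2, 3, 0);
    let mut res_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols, size);
    let mut res_big: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(cols, size);

    // 1) Out-of-place: writes into a pre-allocated big vector and needs scratch space.
    module.vec_znx_idft_apply(&mut res_big, j, &res_dft, j, scratch.borrow());

    // 2) In-place temporary: the DFT operand is mutable and reused as workspace.
    module.vec_znx_idft_apply_tmpa(&mut res_big, j, &mut res_dft, j);

    // 3) Consuming: takes ownership of the DFT buffer and returns it reinterpreted
    //    as a big-coefficient vector, so no extra allocation or scratch is needed.
    let _res_big_consumed: VecZnxBig<Vec<u8>, B> = module.vec_znx_idft_apply_consume(res_dft);
}
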
pub fn test_vec_znx_dft_sub<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
|
||||
where
|
||||
Module<BR>: VecZnxDftSub<BR>
|
||||
+ VecZnxDftAlloc<BR>
|
||||
+ VecZnxDftApply<BR>
|
||||
+ VecZnxIdftApplyConsume<BR>
|
||||
+ VecZnxBigNormalize<BR>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
Module<BT>: VecZnxDftSub<BT>
|
||||
+ VecZnxDftAlloc<BT>
|
||||
+ VecZnxDftApply<BT>
|
||||
+ VecZnxIdftApplyConsume<BT>
|
||||
+ VecZnxBigNormalize<BT>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
|
||||
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
|
||||
{
|
||||
assert_eq!(module_ref.n(), module_test.n());
|
||||
let n: usize = module_ref.n();
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
|
||||
|
||||
for a_size in [1, 2, 3, 4] {
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
|
||||
a.fill_uniform(basek, &mut source);
|
||||
let a_digest = a.digest_u64();
|
||||
|
||||
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
|
||||
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
|
||||
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
|
||||
}
|
||||
|
||||
assert_eq!(a.digest_u64(), a_digest);
|
||||
|
||||
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
|
||||
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
|
||||
|
||||
for b_size in [1, 2, 3, 4] {
|
||||
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, b_size);
|
||||
b.fill_uniform(basek, &mut source);
|
||||
let b_digest: u64 = b.digest_u64();
|
||||
|
||||
let mut b_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, b_size);
|
||||
let mut b_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, b_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_dft_apply(1, 0, &mut b_dft_ref, j, &b, j);
|
||||
module_test.vec_znx_dft_apply(1, 0, &mut b_dft_test, j, &b, j);
|
||||
}
|
||||
|
||||
assert_eq!(b.digest_u64(), b_digest);
|
||||
|
||||
let b_dft_ref_digest: u64 = b_dft_ref.digest_u64();
|
||||
let b_dft_test_digest: u64 = b_dft_test.digest_u64();
|
||||
|
||||
for res_size in [1, 2, 3, 4] {
|
||||
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
|
||||
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
|
||||
|
||||
// Set res_dft to garbage
|
||||
source.fill_bytes(res_dft_ref.data_mut());
|
||||
source.fill_bytes(res_dft_test.data_mut());
|
||||
|
||||
// Reference
|
||||
for i in 0..cols {
|
||||
module_ref.vec_znx_dft_sub(&mut res_dft_ref, i, &a_dft_ref, i, &b_dft_ref, i);
|
||||
module_test.vec_znx_dft_sub(&mut res_dft_test, i, &a_dft_test, i, &b_dft_test, i);
|
||||
}
|
||||
|
||||
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
|
||||
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
|
||||
assert_eq!(b_dft_ref.digest_u64(), b_dft_ref_digest);
|
||||
assert_eq!(b_dft_test.digest_u64(), b_dft_test_digest);
|
||||
|
||||
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
|
||||
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
|
||||
|
||||
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
|
||||
let res_ref_digest: u64 = res_big_ref.digest_u64();
|
||||
let res_test_digest: u64 = res_big_test.digest_u64();
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_ref,
|
||||
j,
|
||||
&res_big_ref,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_test,
|
||||
j,
|
||||
&res_big_test,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
|
||||
assert_eq!(res_big_test.digest_u64(), res_test_digest);
|
||||
|
||||
assert_eq!(res_small_ref, res_small_test);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_vec_znx_dft_sub_ab_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
|
||||
where
|
||||
Module<BR>: VecZnxDftSubABInplace<BR>
|
||||
+ VecZnxDftAlloc<BR>
|
||||
+ VecZnxDftApply<BR>
|
||||
+ VecZnxIdftApplyConsume<BR>
|
||||
+ VecZnxBigNormalize<BR>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
Module<BT>: VecZnxDftSubABInplace<BT>
|
||||
+ VecZnxDftAlloc<BT>
|
||||
+ VecZnxDftApply<BT>
|
||||
+ VecZnxIdftApplyConsume<BT>
|
||||
+ VecZnxBigNormalize<BT>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
|
||||
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
|
||||
{
|
||||
assert_eq!(module_ref.n(), module_test.n());
|
||||
let n: usize = module_ref.n();
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
|
||||
for a_size in [1, 2, 3, 4] {
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
|
||||
a.fill_uniform(basek, &mut source);
|
||||
let a_digest = a.digest_u64();
|
||||
|
||||
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
|
||||
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
|
||||
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
|
||||
}
|
||||
|
||||
assert_eq!(a.digest_u64(), a_digest);
|
||||
|
||||
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
|
||||
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
|
||||
|
||||
for res_size in [1, 2, 3, 4] {
|
||||
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
|
||||
res.fill_uniform(basek, &mut source);
|
||||
let res_digest: u64 = res.digest_u64();
|
||||
|
||||
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
|
||||
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j);
|
||||
module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j);
|
||||
}
|
||||
|
||||
assert_eq!(res.digest_u64(), res_digest);
|
||||
|
||||
// Reference
|
||||
for i in 0..cols {
|
||||
module_ref.vec_znx_dft_sub_ab_inplace(&mut res_dft_ref, i, &a_dft_ref, i);
|
||||
module_test.vec_znx_dft_sub_ab_inplace(&mut res_dft_test, i, &a_dft_test, i);
|
||||
}
|
||||
|
||||
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
|
||||
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
|
||||
|
||||
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
|
||||
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
|
||||
|
||||
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
|
||||
let res_ref_digest: u64 = res_big_ref.digest_u64();
|
||||
let res_test_digest: u64 = res_big_test.digest_u64();
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_ref,
|
||||
j,
|
||||
&res_big_ref,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_test,
|
||||
j,
|
||||
&res_big_test,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
|
||||
assert_eq!(res_big_test.digest_u64(), res_test_digest);
|
||||
|
||||
assert_eq!(res_small_ref, res_small_test);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_vec_znx_dft_sub_ba_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
|
||||
where
|
||||
Module<BR>: VecZnxDftSubBAInplace<BR>
|
||||
+ VecZnxDftAlloc<BR>
|
||||
+ VecZnxDftApply<BR>
|
||||
+ VecZnxIdftApplyConsume<BR>
|
||||
+ VecZnxBigNormalize<BR>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
Module<BT>: VecZnxDftSubBAInplace<BT>
|
||||
+ VecZnxDftAlloc<BT>
|
||||
+ VecZnxDftApply<BT>
|
||||
+ VecZnxIdftApplyConsume<BT>
|
||||
+ VecZnxBigNormalize<BT>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
|
||||
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
|
||||
{
|
||||
assert_eq!(module_ref.n(), module_test.n());
|
||||
let n: usize = module_ref.n();
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes());
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
|
||||
for a_size in [1, 2, 3, 4] {
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
|
||||
a.fill_uniform(basek, &mut source);
|
||||
let a_digest = a.digest_u64();
|
||||
|
||||
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
|
||||
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, a_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
|
||||
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
|
||||
}
|
||||
|
||||
assert_eq!(a.digest_u64(), a_digest);
|
||||
|
||||
let a_dft_ref_digest: u64 = a_dft_ref.digest_u64();
|
||||
let a_dft_test_digest: u64 = a_dft_test.digest_u64();
|
||||
|
||||
for res_size in [1, 2, 3, 4] {
|
||||
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
|
||||
res.fill_uniform(basek, &mut source);
|
||||
let res_digest: u64 = res.digest_u64();
|
||||
|
||||
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
|
||||
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j);
|
||||
module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j);
|
||||
}
|
||||
|
||||
assert_eq!(res.digest_u64(), res_digest);
|
||||
|
||||
// Reference
|
||||
for i in 0..cols {
|
||||
module_ref.vec_znx_dft_sub_ba_inplace(&mut res_dft_ref, i, &a_dft_ref, i);
|
||||
module_test.vec_znx_dft_sub_ba_inplace(&mut res_dft_test, i, &a_dft_test, i);
|
||||
}
|
||||
|
||||
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
|
||||
assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest);
|
||||
|
||||
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
|
||||
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
|
||||
|
||||
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
|
||||
let res_ref_digest: u64 = res_big_ref.digest_u64();
|
||||
let res_test_digest: u64 = res_big_test.digest_u64();
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_ref,
|
||||
j,
|
||||
&res_big_ref,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_test,
|
||||
j,
|
||||
&res_big_test,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
|
||||
assert_eq!(res_big_test.digest_u64(), res_test_digest);
|
||||
|
||||
assert_eq!(res_small_ref, res_small_test);
|
||||
}
|
||||
}
|
||||
}
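
// Hypothetical wiring (not part of this diff): a backend crate can pin its DFT
// subtraction family against the reference backend like this. `ReferenceBackend`
// and `AvxBackend` are placeholder type names, `Module::<B>::new(n)` is assumed to
// be provided by the `ModuleNew` trait, and the ring degree and base-2^k below are
// illustrative only.
#[test]
fn dft_sub_family_matches_reference() {
    let basek: usize = 17;
    let module_ref: Module<ReferenceBackend> = Module::<ReferenceBackend>::new(64);
    let module_avx: Module<AvxBackend> = Module::<AvxBackend>::new(64);
    test_vec_znx_dft_sub(basek, &module_ref, &module_avx);
    test_vec_znx_dft_sub_ab_inplace(basek, &module_ref, &module_avx);
    test_vec_znx_dft_sub_ba_inplace(basek, &module_ref, &module_avx);
}
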
poulpy-hal/src/test_suite/vmp.rs (new file)
@@ -0,0 +1,384 @@
|
||||
use crate::{
|
||||
api::{
|
||||
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigNormalize, VecZnxDftAlloc, VecZnxDftApply,
|
||||
VecZnxIdftApplyConsume, VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
|
||||
VmpApplyDftToDftAddTmpBytes, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare,
|
||||
},
|
||||
layouts::{DataViewMut, DigestU64, FillUniform, MatZnx, Module, ScratchOwned, VecZnx, VecZnxBig},
|
||||
source::Source,
|
||||
};
|
||||
use rand::RngCore;
|
||||
|
||||
use crate::layouts::{Backend, VecZnxDft, VmpPMat};
|
||||
|
||||
pub fn test_vmp_apply_dft<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
|
||||
where
|
||||
Module<BR>: ModuleNew<BR>
|
||||
+ VmpApplyDftTmpBytes
|
||||
+ VmpApplyDft<BR>
|
||||
+ VmpPMatAlloc<BR>
|
||||
+ VecZnxDftAlloc<BR>
|
||||
+ VmpPrepare<BR>
|
||||
+ VecZnxDftAlloc<BR>
|
||||
+ VecZnxIdftApplyConsume<BR>
|
||||
+ VecZnxBigNormalize<BR>,
|
||||
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
|
||||
Module<BT>: ModuleNew<BT>
|
||||
+ VmpApplyDftTmpBytes
|
||||
+ VmpApplyDft<BT>
|
||||
+ VmpPMatAlloc<BT>
|
||||
+ VecZnxDftAlloc<BT>
|
||||
+ VmpPrepare<BT>
|
||||
+ VecZnxDftAlloc<BT>
|
||||
+ VecZnxIdftApplyConsume<BT>
|
||||
+ VecZnxBigNormalize<BT>,
|
||||
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
|
||||
{
|
||||
assert_eq!(module_ref.n(), module_test.n());
|
||||
let n: usize = module_ref.n();
|
||||
|
||||
let max_size: usize = 4;
|
||||
let max_cols: usize = 2;
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch_ref: ScratchOwned<BR> =
|
||||
ScratchOwned::alloc(module_ref.vmp_apply_dft_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size));
|
||||
let mut scratch_test: ScratchOwned<BT> =
|
||||
ScratchOwned::alloc(module_test.vmp_apply_dft_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size));
|
||||
|
||||
for cols_in in 1..max_cols + 1 {
|
||||
for cols_out in 1..max_cols + 1 {
|
||||
for size_in in 1..max_size + 1 {
|
||||
for size_out in 1..max_size + 1 {
|
||||
let rows: usize = cols_in;
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_in, size_in);
|
||||
a.fill_uniform(basek, &mut source);
|
||||
let a_digest: u64 = a.digest_u64();
|
||||
|
||||
let mut mat: MatZnx<Vec<u8>> = MatZnx::alloc(n, rows, cols_in, cols_out, size_out);
|
||||
mat.fill_uniform(basek, &mut source);
|
||||
let mat_digest: u64 = mat.digest_u64();
|
||||
|
||||
let mut pmat_ref: VmpPMat<Vec<u8>, BR> = module_ref.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
|
||||
let mut pmat_test: VmpPMat<Vec<u8>, BT> = module_test.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
|
||||
|
||||
module_ref.vmp_prepare(&mut pmat_ref, &mat, scratch_ref.borrow());
|
||||
module_test.vmp_prepare(&mut pmat_test, &mat, scratch_test.borrow());
|
||||
|
||||
assert_eq!(mat.digest_u64(), mat_digest);
|
||||
|
||||
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols_out, size_out);
|
||||
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols_out, size_out);
|
||||
|
||||
source.fill_bytes(res_dft_ref.data_mut());
|
||||
source.fill_bytes(res_dft_test.data_mut());
|
||||
|
||||
module_ref.vmp_apply_dft(&mut res_dft_ref, &a, &pmat_ref, scratch_ref.borrow());
|
||||
module_test.vmp_apply_dft(&mut res_dft_test, &a, &pmat_test, scratch_test.borrow());
|
||||
|
||||
assert_eq!(a.digest_u64(), a_digest);
|
||||
|
||||
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
|
||||
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
|
||||
|
||||
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
|
||||
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
|
||||
|
||||
let res_ref_digest: u64 = res_big_ref.digest_u64();
|
||||
let res_test_digest: u64 = res_big_test.digest_u64();
|
||||
|
||||
for j in 0..cols_out {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_ref,
|
||||
j,
|
||||
&res_big_ref,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_test,
|
||||
j,
|
||||
&res_big_test,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
|
||||
assert_eq!(res_big_test.digest_u64(), res_test_digest);
|
||||
|
||||
assert_eq!(res_small_ref, res_small_test);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_vmp_apply_dft_to_dft<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
|
||||
where
|
||||
Module<BR>: ModuleNew<BR>
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VmpApplyDftToDft<BR>
|
||||
+ VmpPMatAlloc<BR>
|
||||
+ VecZnxDftAlloc<BR>
|
||||
+ VmpPrepare<BR>
|
||||
+ VecZnxIdftApplyConsume<BR>
|
||||
+ VecZnxBigNormalize<BR>
|
||||
+ VecZnxDftApply<BR>,
|
||||
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
|
||||
Module<BT>: ModuleNew<BT>
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VmpApplyDftToDft<BT>
|
||||
+ VmpPMatAlloc<BT>
|
||||
+ VecZnxDftAlloc<BT>
|
||||
+ VmpPrepare<BT>
|
||||
+ VecZnxIdftApplyConsume<BT>
|
||||
+ VecZnxBigNormalize<BT>
|
||||
+ VecZnxDftApply<BT>,
|
||||
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
|
||||
{
|
||||
assert_eq!(module_ref.n(), module_test.n());
|
||||
let n: usize = module_ref.n();
|
||||
|
||||
let max_size: usize = 4;
|
||||
let max_cols: usize = 2;
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(
|
||||
module_ref.vmp_apply_dft_to_dft_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size),
|
||||
);
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(
|
||||
module_test.vmp_apply_dft_to_dft_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size),
|
||||
);
|
||||
|
||||
for cols_in in 1..max_cols + 1 {
|
||||
for cols_out in 1..max_cols + 1 {
|
||||
for size_in in 1..max_size + 1 {
|
||||
for size_out in 1..max_size + 1 {
|
||||
let rows: usize = size_in;
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_in, size_in);
|
||||
a.fill_uniform(basek, &mut source);
|
||||
let a_digest: u64 = a.digest_u64();
|
||||
|
||||
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols_in, size_in);
|
||||
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols_in, size_in);
|
||||
|
||||
for j in 0..cols_in {
|
||||
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
|
||||
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
|
||||
}
|
||||
|
||||
assert_eq!(a.digest_u64(), a_digest);
|
||||
|
||||
let mut mat: MatZnx<Vec<u8>> = MatZnx::alloc(n, rows, cols_in, cols_out, size_out);
|
||||
mat.fill_uniform(basek, &mut source);
|
||||
let mat_digest: u64 = mat.digest_u64();
|
||||
|
||||
let mut pmat_ref: VmpPMat<Vec<u8>, BR> = module_ref.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
|
||||
let mut pmat_test: VmpPMat<Vec<u8>, BT> = module_test.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
|
||||
|
||||
module_ref.vmp_prepare(&mut pmat_ref, &mat, scratch_ref.borrow());
|
||||
module_test.vmp_prepare(&mut pmat_test, &mat, scratch_test.borrow());
|
||||
|
||||
assert_eq!(mat.digest_u64(), mat_digest);
|
||||
|
||||
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols_out, size_out);
|
||||
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols_out, size_out);
|
||||
|
||||
source.fill_bytes(res_dft_ref.data_mut());
|
||||
source.fill_bytes(res_dft_test.data_mut());
|
||||
|
||||
module_ref.vmp_apply_dft_to_dft(
|
||||
&mut res_dft_ref,
|
||||
&a_dft_ref,
|
||||
&pmat_ref,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vmp_apply_dft_to_dft(
|
||||
&mut res_dft_test,
|
||||
&a_dft_test,
|
||||
&pmat_test,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
|
||||
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
|
||||
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
|
||||
|
||||
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
|
||||
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
|
||||
|
||||
let res_ref_digest: u64 = res_big_ref.digest_u64();
|
||||
let res_test_digest: u64 = res_big_test.digest_u64();
|
||||
|
||||
for j in 0..cols_out {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_ref,
|
||||
j,
|
||||
&res_big_ref,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_test,
|
||||
j,
|
||||
&res_big_test,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
|
||||
assert_eq!(res_big_test.digest_u64(), res_test_digest);
|
||||
|
||||
assert_eq!(res_small_ref, res_small_test);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_vmp_apply_dft_to_dft_add<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
|
||||
where
|
||||
Module<BR>: ModuleNew<BR>
|
||||
+ VmpApplyDftToDftAddTmpBytes
|
||||
+ VmpApplyDftToDftAdd<BR>
|
||||
+ VmpPMatAlloc<BR>
|
||||
+ VecZnxDftAlloc<BR>
|
||||
+ VmpPrepare<BR>
|
||||
+ VecZnxIdftApplyConsume<BR>
|
||||
+ VecZnxBigNormalize<BR>
|
||||
+ VecZnxDftApply<BR>,
|
||||
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
|
||||
Module<BT>: ModuleNew<BT>
|
||||
+ VmpApplyDftToDftAddTmpBytes
|
||||
+ VmpApplyDftToDftAdd<BT>
|
||||
+ VmpPMatAlloc<BT>
|
||||
+ VecZnxDftAlloc<BT>
|
||||
+ VmpPrepare<BT>
|
||||
+ VecZnxIdftApplyConsume<BT>
|
||||
+ VecZnxBigNormalize<BT>
|
||||
+ VecZnxDftApply<BT>,
|
||||
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
|
||||
{
|
||||
assert_eq!(module_ref.n(), module_test.n());
|
||||
let n: usize = module_ref.n();
|
||||
|
||||
let max_size: usize = 4;
|
||||
let max_cols: usize = 2;
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(
|
||||
module_ref.vmp_apply_dft_to_dft_add_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size),
|
||||
);
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(
|
||||
module_test.vmp_apply_dft_to_dft_add_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size),
|
||||
);
|
||||
|
||||
for cols_in in 1..max_cols + 1 {
|
||||
for cols_out in 1..max_cols + 1 {
|
||||
for size_in in 1..max_size + 1 {
|
||||
for size_out in 1..max_size + 1 {
|
||||
let rows: usize = size_in;
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_in, size_in);
|
||||
a.fill_uniform(basek, &mut source);
|
||||
let a_digest: u64 = a.digest_u64();
|
||||
|
||||
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols_in, size_in);
|
||||
let mut a_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols_in, size_in);
|
||||
|
||||
for j in 0..cols_in {
|
||||
module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j);
|
||||
module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j);
|
||||
}
|
||||
|
||||
assert_eq!(a.digest_u64(), a_digest);
|
||||
|
||||
let mut mat: MatZnx<Vec<u8>> = MatZnx::alloc(n, rows, cols_in, cols_out, size_out);
|
||||
mat.fill_uniform(basek, &mut source);
|
||||
let mat_digest: u64 = mat.digest_u64();
|
||||
|
||||
let mut pmat_ref: VmpPMat<Vec<u8>, BR> = module_ref.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
|
||||
let mut pmat_test: VmpPMat<Vec<u8>, BT> = module_test.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
|
||||
|
||||
module_ref.vmp_prepare(&mut pmat_ref, &mat, scratch_ref.borrow());
|
||||
module_test.vmp_prepare(&mut pmat_test, &mat, scratch_test.borrow());
|
||||
|
||||
assert_eq!(mat.digest_u64(), mat_digest);
|
||||
|
||||
for limb_offset in 0..size_out {
|
||||
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
|
||||
res.fill_uniform(basek, &mut source);
|
||||
let res_digest: u64 = res.digest_u64();
|
||||
|
||||
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols_out, size_out);
|
||||
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols_out, size_out);
|
||||
|
||||
for j in 0..cols_out {
|
||||
module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j);
|
||||
module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j);
|
||||
}
|
||||
|
||||
assert_eq!(res.digest_u64(), res_digest);
|
||||
|
||||
module_ref.vmp_apply_dft_to_dft_add(
|
||||
&mut res_dft_ref,
|
||||
&a_dft_ref,
|
||||
&pmat_ref,
|
||||
limb_offset * cols_out,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vmp_apply_dft_to_dft_add(
|
||||
&mut res_dft_test,
|
||||
&a_dft_test,
|
||||
&pmat_test,
|
||||
limb_offset * cols_out,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
|
||||
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
|
||||
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
|
||||
|
||||
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
|
||||
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
|
||||
|
||||
let res_ref_digest: u64 = res_big_ref.digest_u64();
|
||||
let res_test_digest: u64 = res_big_test.digest_u64();
|
||||
|
||||
for j in 0..cols_out {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_ref,
|
||||
j,
|
||||
&res_big_ref,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
basek,
|
||||
&mut res_small_test,
|
||||
j,
|
||||
&res_big_test,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
|
||||
assert_eq!(res_big_test.digest_u64(), res_test_digest);
|
||||
|
||||
assert_eq!(res_small_ref, res_small_test);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
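
// Condensed single-backend sketch (not part of this diff) of the VMP pipeline that
// the tests above drive on two backends at once: prepare the matrix, DFT the input,
// apply the prepared matrix, leave the DFT domain, normalize. Sizes are
// placeholders; `scratch` is assumed to be at least
// `vmp_apply_dft_to_dft_tmp_bytes(..)` and `vec_znx_big_normalize_tmp_bytes()` large.
fn vmp_pipeline_sketch<B: Backend>(basek: usize, module: &Module<B>, scratch: &mut ScratchOwned<B>)
where
    Module<B>: VmpPMatAlloc<B>
        + VmpPrepare<B>
        + VecZnxDftAlloc<B>
        + VecZnxDftApply<B>
        + VmpApplyDftToDft<B>
        + VecZnxIdftApplyConsume<B>
        + VecZnxBigNormalize<B>,
    ScratchOwned<B>: ScratchOwnedBorrow<B>,
{
    let n: usize = module.n();
    let (rows, cols_in, cols_out, size_in, size_out) = (2, 2, 2, 2, 2);

    // Prepare the matrix into the backend-specific packed layout.
    let mut source: Source = Source::new([0u8; 32]);
    let mut mat: MatZnx<Vec<u8>> = MatZnx::alloc(n, rows, cols_in, cols_out, size_out);
    mat.fill_uniform(basek, &mut source);
    let mut pmat: VmpPMat<Vec<u8>, B> = module.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
    module.vmp_prepare(&mut pmat, &mat, scratch.borrow());

    // DFT the input vector column by column.
    let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_in, size_in);
    a.fill_uniform(basek, &mut source);
    let mut a_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols_in, size_in);
    for j in 0..cols_in {
        module.vec_znx_dft_apply(1, 0, &mut a_dft, j, &a, j);
    }

    // Vector-matrix product in the DFT domain, then back to normalized coefficients.
    let mut res_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(cols_out, size_out);
    module.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &pmat, scratch.borrow());
    let res_big: VecZnxBig<Vec<u8>, B> = module.vec_znx_idft_apply_consume(res_dft);

    let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
    for j in 0..cols_out {
        module.vec_znx_big_normalize(basek, &mut res, j, &res_big, j, scratch.borrow());
    }
}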

@@ -1,3 +0,0 @@
pub mod serialization;
pub mod vec_znx;
pub mod vmp_pmat;

@@ -1,50 +0,0 @@
use crate::{
    layouts::{VecZnx, ZnxInfos, ZnxViewMut},
    source::Source,
};

pub fn test_vec_znx_encode_vec_i64_lo_norm() {
    let n: usize = 32;
    let basek: usize = 17;
    let size: usize = 5;
    let k: usize = size * basek - 5;
    let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
    let mut source: Source = Source::new([0u8; 32]);
    let raw: &mut [i64] = a.raw_mut();
    raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
    (0..a.cols()).for_each(|col_i| {
        let mut have: Vec<i64> = vec![i64::default(); n];
        have.iter_mut()
            .for_each(|x| *x = (source.next_i64() << 56) >> 56);
        a.encode_vec_i64(basek, col_i, k, &have, 10);
        let mut want: Vec<i64> = vec![i64::default(); n];
        a.decode_vec_i64(basek, col_i, k, &mut want);
        assert_eq!(have, want, "{:?} != {:?}", &have, &want);
    });
}

pub fn test_vec_znx_encode_vec_i64_hi_norm() {
    let n: usize = 32;
    let basek: usize = 17;
    let size: usize = 5;
    for k in [1, basek / 2, size * basek - 5] {
        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
        let mut source = Source::new([0u8; 32]);
        let raw: &mut [i64] = a.raw_mut();
        raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
        (0..a.cols()).for_each(|col_i| {
            let mut have: Vec<i64> = vec![i64::default(); n];
            have.iter_mut().for_each(|x| {
                if k < 64 {
                    *x = source.next_u64n(1 << k, (1 << k) - 1) as i64;
                } else {
                    *x = source.next_i64();
                }
            });
            a.encode_vec_i64(basek, col_i, k, &have, 63);
            let mut want: Vec<i64> = vec![i64::default(); n];
            a.decode_vec_i64(basek, col_i, k, &mut want);
            assert_eq!(have, want, "{:?} != {:?}", &have, &want);
        })
    }
}

@@ -1,67 +0,0 @@
use crate::{
    api::{VecZnxAddNormal, VecZnxFillUniform},
    layouts::{Backend, Module, VecZnx, ZnxView},
    source::Source,
};

pub fn test_vec_znx_fill_uniform<B: Backend>(module: &Module<B>)
where
    Module<B>: VecZnxFillUniform,
{
    let n: usize = module.n();
    let basek: usize = 17;
    let size: usize = 5;
    let mut source: Source = Source::new([0u8; 32]);
    let cols: usize = 2;
    let zero: Vec<i64> = vec![0; n];
    let one_12_sqrt: f64 = 0.28867513459481287;
    (0..cols).for_each(|col_i| {
        let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size);
        module.vec_znx_fill_uniform(basek, &mut a, col_i, size * basek, &mut source);
        (0..cols).for_each(|col_j| {
            if col_j != col_i {
                (0..size).for_each(|limb_i| {
                    assert_eq!(a.at(col_j, limb_i), zero);
                })
            } else {
                let std: f64 = a.std(basek, col_i);
                assert!(
                    (std - one_12_sqrt).abs() < 0.01,
                    "std={} ~!= {}",
                    std,
                    one_12_sqrt
                );
            }
        })
    });
}

pub fn test_vec_znx_add_normal<B: Backend>(module: &Module<B>)
where
    Module<B>: VecZnxAddNormal,
{
    let n: usize = module.n();
    let basek: usize = 17;
    let k: usize = 2 * 17;
    let size: usize = 5;
    let sigma: f64 = 3.2;
    let bound: f64 = 6.0 * sigma;
    let mut source: Source = Source::new([0u8; 32]);
    let cols: usize = 2;
    let zero: Vec<i64> = vec![0; n];
    let k_f64: f64 = (1u64 << k as u64) as f64;
    (0..cols).for_each(|col_i| {
        let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size);
        module.vec_znx_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
        (0..cols).for_each(|col_j| {
            if col_j != col_i {
                (0..size).for_each(|limb_i| {
                    assert_eq!(a.at(col_j, limb_i), zero);
                })
            } else {
                let std: f64 = a.std(basek, col_i) * k_f64;
                assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma);
            }
        })
    });
}

@@ -1,5 +0,0 @@
mod generics;
pub use generics::*;

#[cfg(test)]
mod encoding;

@@ -1,3 +0,0 @@
mod vmp_apply;

pub use vmp_apply::*;