Ref. + AVX code & generic tests + benches (#85)

Jean-Philippe Bossuat
2025-09-15 16:16:11 +02:00
committed by GitHub
parent 99b9e3e10e
commit 56dbd29c59
286 changed files with 27797 additions and 7270 deletions

View File

@@ -1,7 +1,15 @@
use crate::{
api::{SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPPolFromBytes, SvpPrepare},
layouts::{Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef},
oep::{SvpApplyImpl, SvpApplyInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl},
api::{
SvpApplyDft, SvpApplyDftToDft, SvpApplyDftToDftAdd, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes,
SvpPPolFromBytes, SvpPrepare,
},
layouts::{
Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
},
oep::{
SvpApplyDftImpl, SvpApplyDftToDftAddImpl, SvpApplyDftToDftImpl, SvpApplyDftToDftInplaceImpl, SvpPPolAllocBytesImpl,
SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl,
},
};
impl<B> SvpPPolFromBytes<B> for Module<B>
@@ -44,29 +52,57 @@ where
}
}
impl<B> SvpApply<B> for Module<B>
impl<B> SvpApplyDft<B> for Module<B>
where
B: Backend + SvpApplyImpl<B>,
B: Backend + SvpApplyDftImpl<B>,
{
fn svp_apply<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
fn svp_apply_dft<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: SvpPPolToRef<B>,
C: VecZnxToRef,
{
B::svp_apply_dft_impl(self, res, res_col, a, a_col, b, b_col);
}
}
impl<B> SvpApplyDftToDft<B> for Module<B>
where
B: Backend + SvpApplyDftToDftImpl<B>,
{
fn svp_apply_dft_to_dft<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: SvpPPolToRef<B>,
C: VecZnxDftToRef<B>,
{
B::svp_apply_impl(self, res, res_col, a, a_col, b, b_col);
B::svp_apply_dft_to_dft_impl(self, res, res_col, a, a_col, b, b_col);
}
}
impl<B> SvpApplyInplace<B> for Module<B>
impl<B> SvpApplyDftToDftAdd<B> for Module<B>
where
B: Backend + SvpApplyInplaceImpl,
B: Backend + SvpApplyDftToDftAddImpl<B>,
{
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn svp_apply_dft_to_dft_add<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: SvpPPolToRef<B>,
C: VecZnxDftToRef<B>,
{
B::svp_apply_dft_to_dft_add_impl(self, res, res_col, a, a_col, b, b_col);
}
}
impl<B> SvpApplyDftToDftInplace<B> for Module<B>
where
B: Backend + SvpApplyDftToDftInplaceImpl,
{
fn svp_apply_dft_to_dft_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: SvpPPolToRef<B>,
{
B::svp_apply_inplace_impl(self, res, res_col, a, a_col);
B::svp_apply_dft_to_dft_inplace_impl(self, res, res_col, a, a_col);
}
}
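For orientation, here is a minimal usage sketch of the renamed SVP entry points. It is not part of this diff: the helper name and the use of column 0 are illustrative, and it is written as if it lived inside this crate so the `crate::api`/`crate::layouts` paths resolve.

```rust
use crate::{
    api::{SvpApplyDft, SvpApplyDftToDftInplace},
    layouts::{Backend, Module, SvpPPolToRef, VecZnxDftToMut, VecZnxToRef},
};

// Hypothetical helper: multiply the prepared scalar polynomial `a` into `b`,
// writing the DFT-domain product into `res`, then fold `a` in once more in place.
fn svp_example<B, R, A, C>(module: &Module<B>, res: &mut R, a: &A, b: &C)
where
    B: Backend,
    Module<B>: SvpApplyDft<B> + SvpApplyDftToDftInplace<B>,
    R: VecZnxDftToMut<B>,
    A: SvpPPolToRef<B>,
    C: VecZnxToRef,
{
    // res[col 0] = a[col 0] * DFT(b[col 0])
    module.svp_apply_dft(res, 0, a, 0, b, 0);
    // res[col 0] *= a[col 0], with both operands already in the DFT domain
    module.svp_apply_dft_to_dft_inplace(res, 0, a, 0);
}
```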

View File

@@ -1,19 +1,24 @@
use crate::{
api::{
VecZnxAdd, VecZnxAddDistF64, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism,
VecZnxAutomorphismInplace, VecZnxCopy, VecZnxFillDistF64, VecZnxFillNormal, VecZnxFillUniform, VecZnxLshInplace,
VecZnxMerge, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSplit,
VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalarInplace, VecZnxSwithcDegree,
VecZnxAdd, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalar, VecZnxAddScalarInplace, VecZnxAutomorphism,
VecZnxAutomorphismInplace, VecZnxAutomorphismInplaceTmpBytes, VecZnxCopy, VecZnxFillNormal, VecZnxFillUniform, VecZnxLsh,
VecZnxLshInplace, VecZnxLshTmpBytes, VecZnxMergeRings, VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOne,
VecZnxMulXpMinusOneInplace, VecZnxMulXpMinusOneInplaceTmpBytes, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes,
VecZnxRsh, VecZnxRshInplace, VecZnxRshTmpBytes, VecZnxSplitRing, VecZnxSplitRingTmpBytes, VecZnxSub, VecZnxSubABInplace,
VecZnxSubBAInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing,
},
layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
oep::{
VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl,
VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl,
VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl,
VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxSubABInplaceImpl,
VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarImpl, VecZnxAddScalarInplaceImpl,
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxAutomorphismInplaceTmpBytesImpl, VecZnxCopyImpl,
VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshImpl, VecZnxLshInplaceImpl, VecZnxLshTmpBytesImpl,
VecZnxMergeRingsImpl, VecZnxMergeRingsTmpBytesImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl,
VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl,
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl,
},
source::Source,
};
@@ -79,6 +84,20 @@ where
}
}
impl<B> VecZnxAddScalar for Module<B>
where
B: Backend + VecZnxAddScalarImpl<B>,
{
fn vec_znx_add_scalar<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize, b_limb: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
D: VecZnxToRef,
{
B::vec_znx_add_scalar_impl(self, res, res_col, a, a_col, b, b_col, b_limb)
}
}
impl<B> VecZnxAddScalarInplace for Module<B>
where
B: Backend + VecZnxAddScalarInplaceImpl<B>,
@@ -132,6 +151,20 @@ where
}
}
impl<B> VecZnxSubScalar for Module<B>
where
B: Backend + VecZnxSubScalarImpl<B>,
{
fn vec_znx_sub_scalar<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize, b_limb: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
D: VecZnxToRef,
{
B::vec_znx_sub_scalar_impl(self, res, res_col, a, a_col, b, b_col, b_limb)
}
}
impl<B> VecZnxSubScalarInplace for Module<B>
where
B: Backend + VecZnxSubScalarInplaceImpl<B>,
@@ -170,27 +203,87 @@ where
}
}
impl<B> VecZnxLshInplace for Module<B>
impl<B> VecZnxRshTmpBytes for Module<B>
where
B: Backend + VecZnxLshInplaceImpl<B>,
B: Backend + VecZnxRshTmpBytesImpl<B>,
{
fn vec_znx_lsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
where
A: VecZnxToMut,
{
B::vec_znx_lsh_inplace_impl(self, basek, k, a)
fn vec_znx_rsh_tmp_bytes(&self) -> usize {
B::vec_znx_rsh_tmp_bytes_impl(self)
}
}
impl<B> VecZnxRshInplace for Module<B>
impl<B> VecZnxLshTmpBytes for Module<B>
where
B: Backend + VecZnxRshInplaceImpl<B>,
B: Backend + VecZnxLshTmpBytesImpl<B>,
{
fn vec_znx_rsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
fn vec_znx_lsh_tmp_bytes(&self) -> usize {
B::vec_znx_lsh_tmp_bytes_impl(self)
}
}
impl<B> VecZnxLsh<B> for Module<B>
where
B: Backend + VecZnxLshImpl<B>,
{
fn vec_znx_lsh<R, A>(
&self,
basek: usize,
k: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_lsh_impl(self, basek, k, res, res_col, a, a_col, scratch);
}
}
impl<B> VecZnxRsh<B> for Module<B>
where
B: Backend + VecZnxRshImpl<B>,
{
fn vec_znx_rsh<R, A>(
&self,
basek: usize,
k: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_rsh_impl(self, basek, k, res, res_col, a, a_col, scratch);
}
}
impl<B> VecZnxLshInplace<B> for Module<B>
where
B: Backend + VecZnxLshInplaceImpl<B>,
{
fn vec_znx_lsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut,
{
B::vec_znx_rsh_inplace_impl(self, basek, k, a)
B::vec_znx_lsh_inplace_impl(self, basek, k, a, a_col, scratch)
}
}
impl<B> VecZnxRshInplace<B> for Module<B>
where
B: Backend + VecZnxRshInplaceImpl<B>,
{
fn vec_znx_rsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut,
{
B::vec_znx_rsh_inplace_impl(self, basek, k, a, a_col, scratch)
}
}
@@ -207,15 +300,24 @@ where
}
}
impl<B> VecZnxRotateInplace for Module<B>
impl<B> VecZnxRotateInplaceTmpBytes for Module<B>
where
B: Backend + VecZnxRotateInplaceTmpBytesImpl<B>,
{
fn vec_znx_rotate_inplace_tmp_bytes(&self) -> usize {
B::vec_znx_rotate_inplace_tmp_bytes_impl(self)
}
}
impl<B> VecZnxRotateInplace<B> for Module<B>
where
B: Backend + VecZnxRotateInplaceImpl<B>,
{
fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut,
{
B::vec_znx_rotate_inplace_impl(self, k, a, a_col)
B::vec_znx_rotate_inplace_impl(self, k, a, a_col, scratch)
}
}
@@ -232,15 +334,24 @@ where
}
}
impl<B> VecZnxAutomorphismInplace for Module<B>
impl<B> VecZnxAutomorphismInplaceTmpBytes for Module<B>
where
B: Backend + VecZnxAutomorphismInplaceTmpBytesImpl<B>,
{
fn vec_znx_automorphism_inplace_tmp_bytes(&self) -> usize {
B::vec_znx_automorphism_inplace_tmp_bytes_impl(self)
}
}
impl<B> VecZnxAutomorphismInplace<B> for Module<B>
where
B: Backend + VecZnxAutomorphismInplaceImpl<B>,
{
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
fn vec_znx_automorphism_inplace<R>(&self, k: i64, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut,
R: VecZnxToMut,
{
B::vec_znx_automorphism_inplace_impl(self, k, a, a_col)
B::vec_znx_automorphism_inplace_impl(self, k, res, res_col, scratch)
}
}
@@ -257,54 +368,81 @@ where
}
}
impl<B> VecZnxMulXpMinusOneInplace for Module<B>
impl<B> VecZnxMulXpMinusOneInplaceTmpBytes for Module<B>
where
B: Backend + VecZnxMulXpMinusOneInplaceTmpBytesImpl<B>,
{
fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes(&self) -> usize {
B::vec_znx_mul_xp_minus_one_inplace_tmp_bytes_impl(self)
}
}
impl<B> VecZnxMulXpMinusOneInplace<B> for Module<B>
where
B: Backend + VecZnxMulXpMinusOneInplaceImpl<B>,
{
fn vec_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, res: &mut R, res_col: usize)
fn vec_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxToMut,
{
B::vec_znx_mul_xp_minus_one_inplace_impl(self, p, res, res_col);
B::vec_znx_mul_xp_minus_one_inplace_impl(self, p, res, res_col, scratch);
}
}
impl<B> VecZnxSplit<B> for Module<B>
impl<B> VecZnxSplitRingTmpBytes for Module<B>
where
B: Backend + VecZnxSplitImpl<B>,
B: Backend + VecZnxSplitRingTmpBytesImpl<B>,
{
fn vec_znx_split<R, A>(&self, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
fn vec_znx_split_ring_tmp_bytes(&self) -> usize {
B::vec_znx_split_ring_tmp_bytes_impl(self)
}
}
impl<B> VecZnxSplitRing<B> for Module<B>
where
B: Backend + VecZnxSplitRingImpl<B>,
{
fn vec_znx_split_ring<R, A>(&self, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_split_impl(self, res, res_col, a, a_col, scratch)
B::vec_znx_split_ring_impl(self, res, res_col, a, a_col, scratch)
}
}
impl<B> VecZnxMerge for Module<B>
impl<B> VecZnxMergeRingsTmpBytes for Module<B>
where
B: Backend + VecZnxMergeImpl<B>,
B: Backend + VecZnxMergeRingsTmpBytesImpl<B>,
{
fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: &[A], a_col: usize)
fn vec_znx_merge_rings_tmp_bytes(&self) -> usize {
B::vec_znx_merge_rings_tmp_bytes_impl(self)
}
}
impl<B> VecZnxMergeRings<B> for Module<B>
where
B: Backend + VecZnxMergeRingsImpl<B>,
{
fn vec_znx_merge_rings<R, A>(&self, res: &mut R, res_col: usize, a: &[A], a_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_merge_impl(self, res, res_col, a, a_col)
B::vec_znx_merge_rings_impl(self, res, res_col, a, a_col, scratch)
}
}
impl<B> VecZnxSwithcDegree for Module<B>
impl<B> VecZnxSwitchRing for Module<B>
where
B: Backend + VecZnxSwithcDegreeImpl<B>,
B: Backend + VecZnxSwitchRingImpl<B>,
{
fn vec_znx_switch_degree<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_switch_ring<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_switch_degree_impl(self, res, res_col, a, a_col)
B::vec_znx_switch_ring_impl(self, res, res_col, a, a_col)
}
}
@@ -325,51 +463,11 @@ impl<B> VecZnxFillUniform for Module<B>
where
B: Backend + VecZnxFillUniformImpl<B>,
{
fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut,
{
B::vec_znx_fill_uniform_impl(self, basek, res, res_col, k, source);
}
}
impl<B> VecZnxFillDistF64 for Module<B>
where
B: Backend + VecZnxFillDistF64Impl<B>,
{
fn vec_znx_fill_dist_f64<R, D: rand::prelude::Distribution<f64>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: VecZnxToMut,
{
B::vec_znx_fill_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
}
}
impl<B> VecZnxAddDistF64 for Module<B>
where
B: Backend + VecZnxAddDistF64Impl<B>,
{
fn vec_znx_add_dist_f64<R, D: rand::prelude::Distribution<f64>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: VecZnxToMut,
{
B::vec_znx_add_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
B::vec_znx_fill_uniform_impl(self, basek, res, res_col, source);
}
}
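The out-of-place shifts and the in-place rotation now thread a `Scratch<B>` buffer. Below is a sketch of the intended call pattern, not part of this diff; the helper is hypothetical and assumes the caller reserved at least the sizes reported by the matching `*_tmp_bytes` queries.

```rust
use crate::{
    api::{VecZnxLsh, VecZnxLshTmpBytes, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes},
    layouts::{Backend, Module, Scratch, VecZnxToMut, VecZnxToRef},
};

// Hypothetical helper: left-shift `a` by `k` into `res`, then rotate `res`
// by X^p, reusing one caller-provided scratch buffer.
fn shift_then_rotate<B, R, A>(
    module: &Module<B>,
    basek: usize,
    k: usize,
    p: i64,
    res: &mut R,
    a: &A,
    scratch: &mut Scratch<B>,
) where
    B: Backend,
    Module<B>: VecZnxLsh<B> + VecZnxLshTmpBytes + VecZnxRotateInplace<B> + VecZnxRotateInplaceTmpBytes,
    R: VecZnxToMut,
    A: VecZnxToRef,
{
    // The scratch must cover both steps; the required size is backend-defined.
    let _needed = module
        .vec_znx_lsh_tmp_bytes()
        .max(module.vec_znx_rotate_inplace_tmp_bytes());
    module.vec_znx_lsh(basek, k, res, 0, a, 0, scratch);
    module.vec_znx_rotate_inplace(p, res, 0, scratch);
}
```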

View File

@@ -1,18 +1,16 @@
use rand_distr::Distribution;
use crate::{
api::{
VecZnxBigAdd, VecZnxBigAddDistF64, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmall, VecZnxBigAddSmallInplace,
VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigFillDistF64,
VecZnxBigFillNormal, VecZnxBigFromBytes, VecZnxBigNegateInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA, VecZnxBigSubSmallAInplace,
VecZnxBigSubSmallB, VecZnxBigSubSmallBInplace,
VecZnxBigAdd, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmall, VecZnxBigAddSmallInplace, VecZnxBigAlloc,
VecZnxBigAllocBytes, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigAutomorphismInplaceTmpBytes,
VecZnxBigFromBytes, VecZnxBigFromSmall, VecZnxBigNegate, VecZnxBigNegateInplace, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA,
VecZnxBigSubSmallAInplace, VecZnxBigSubSmallB, VecZnxBigSubSmallBInplace,
},
layouts::{Backend, Module, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef},
oep::{
VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl,
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl, VecZnxBigFromBytesImpl,
VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl, VecZnxBigAddSmallInplaceImpl,
VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl, VecZnxBigAutomorphismInplaceImpl,
VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl, VecZnxBigFromSmallImpl, VecZnxBigNegateImpl,
VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl,
VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl,
VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
@@ -20,6 +18,19 @@ use crate::{
source::Source,
};
impl<B> VecZnxBigFromSmall<B> for Module<B>
where
B: Backend + VecZnxBigFromSmallImpl<B>,
{
fn vec_znx_big_from_small<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxToRef,
{
B::vec_znx_big_from_small_impl(res, res_col, a, a_col);
}
}
impl<B> VecZnxBigAlloc<B> for Module<B>
where
B: Backend + VecZnxBigAllocImpl<B>,
@@ -47,24 +58,6 @@ where
}
}
impl<B> VecZnxBigAddDistF64<B> for Module<B>
where
B: Backend + VecZnxBigAddDistF64Impl<B>,
{
fn vec_znx_big_add_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) {
B::add_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
}
}
impl<B> VecZnxBigAddNormal<B> for Module<B>
where
B: Backend + VecZnxBigAddNormalImpl<B>,
@@ -83,42 +76,6 @@ where
}
}
impl<B> VecZnxBigFillDistF64<B> for Module<B>
where
B: Backend + VecZnxBigFillDistF64Impl<B>,
{
fn vec_znx_big_fill_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) {
B::fill_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
}
}
impl<B> VecZnxBigFillNormal<B> for Module<B>
where
B: Backend + VecZnxBigFillNormalImpl<B>,
{
fn vec_znx_big_fill_normal<R: VecZnxBigToMut<B>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) {
B::fill_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
}
}
impl<B> VecZnxBigAdd<B> for Module<B>
where
B: Backend + VecZnxBigAddImpl<B>,
@@ -267,6 +224,19 @@ where
}
}
impl<B> VecZnxBigNegate<B> for Module<B>
where
B: Backend + VecZnxBigNegateImpl<B>,
{
fn vec_znx_big_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>,
{
B::vec_znx_big_negate_impl(self, res, res_col, a, a_col);
}
}
impl<B> VecZnxBigNegateInplace<B> for Module<B>
where
B: Backend + VecZnxBigNegateInplaceImpl<B>,
@@ -321,14 +291,23 @@ where
}
}
impl<B> VecZnxBigAutomorphismInplaceTmpBytes for Module<B>
where
B: Backend + VecZnxBigAutomorphismInplaceTmpBytesImpl<B>,
{
fn vec_znx_big_automorphism_inplace_tmp_bytes(&self) -> usize {
B::vec_znx_big_automorphism_inplace_tmp_bytes_impl(self)
}
}
impl<B> VecZnxBigAutomorphismInplace<B> for Module<B>
where
B: Backend + VecZnxBigAutomorphismInplaceImpl<B>,
{
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxBigToMut<B>,
{
B::vec_znx_big_automorphism_inplace_impl(self, k, a, a_col);
B::vec_znx_big_automorphism_inplace_impl(self, k, a, a_col, scratch);
}
}
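Similarly, the big-coefficient layer gains `vec_znx_big_from_small` and a scratch-passing in-place automorphism. A sketch, not part of this diff; the helper name and column choice are illustrative.

```rust
use crate::{
    api::{VecZnxBigAutomorphismInplace, VecZnxBigAutomorphismInplaceTmpBytes, VecZnxBigFromSmall},
    layouts::{Backend, Module, Scratch, VecZnxBigToMut, VecZnxToRef},
};

// Hypothetical helper: lift the small vector `a` into the big layout `res`,
// then apply the automorphism X -> X^k in place. `scratch` is assumed to hold
// at least `vec_znx_big_automorphism_inplace_tmp_bytes()` bytes.
fn lift_and_automorphism<B, R, A>(module: &Module<B>, k: i64, res: &mut R, a: &A, scratch: &mut Scratch<B>)
where
    B: Backend,
    Module<B>: VecZnxBigFromSmall<B> + VecZnxBigAutomorphismInplace<B> + VecZnxBigAutomorphismInplaceTmpBytes,
    R: VecZnxBigToMut<B>,
    A: VecZnxToRef,
{
    module.vec_znx_big_from_small(res, 0, a, 0);
    module.vec_znx_big_automorphism_inplace(k, res, 0, scratch);
}
```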

View File

@@ -1,16 +1,17 @@
use crate::{
api::{
DFT, IDFT, IDFTConsume, IDFTTmpA, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftCopy,
VecZnxDftFromBytes, VecZnxDftSub, VecZnxDftSubABInplace, VecZnxDftSubBAInplace, VecZnxDftZero, VecZnxIDFTTmpBytes,
VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy,
VecZnxDftFromBytes, VecZnxDftSub, VecZnxDftSubABInplace, VecZnxDftSubBAInplace, VecZnxDftZero, VecZnxIdftApply,
VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes,
},
layouts::{
Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef,
VecZnxToRef,
},
oep::{
DFTImpl, IDFTConsumeImpl, IDFTImpl, IDFTTmpAImpl, VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl,
VecZnxDftAllocImpl, VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl,
VecZnxDftSubImpl, VecZnxDftZeroImpl, VecZnxIDFTTmpBytesImpl,
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl,
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl,
VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
},
};
@@ -41,63 +42,63 @@ where
}
}
impl<B> VecZnxIDFTTmpBytes for Module<B>
impl<B> VecZnxIdftApplyTmpBytes for Module<B>
where
B: Backend + VecZnxIDFTTmpBytesImpl<B>,
B: Backend + VecZnxIdftApplyTmpBytesImpl<B>,
{
fn vec_znx_idft_tmp_bytes(&self) -> usize {
B::vec_znx_idft_tmp_bytes_impl(self)
fn vec_znx_idft_apply_tmp_bytes(&self) -> usize {
B::vec_znx_idft_apply_tmp_bytes_impl(self)
}
}
impl<B> IDFT<B> for Module<B>
impl<B> VecZnxIdftApply<B> for Module<B>
where
B: Backend + IDFTImpl<B>,
B: Backend + VecZnxIdftApplyImpl<B>,
{
fn idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
fn vec_znx_idft_apply<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxBigToMut<B>,
A: VecZnxDftToRef<B>,
{
B::idft_impl(self, res, res_col, a, a_col, scratch);
B::vec_znx_idft_apply_impl(self, res, res_col, a, a_col, scratch);
}
}
impl<B> IDFTTmpA<B> for Module<B>
impl<B> VecZnxIdftApplyTmpA<B> for Module<B>
where
B: Backend + IDFTTmpAImpl<B>,
B: Backend + VecZnxIdftApplyTmpAImpl<B>,
{
fn idft_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
fn vec_znx_idft_apply_tmpa<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxDftToMut<B>,
{
B::idft_tmp_a_impl(self, res, res_col, a, a_col);
B::vec_znx_idft_apply_tmpa_impl(self, res, res_col, a, a_col);
}
}
impl<B> IDFTConsume<B> for Module<B>
impl<B> VecZnxIdftApplyConsume<B> for Module<B>
where
B: Backend + IDFTConsumeImpl<B>,
B: Backend + VecZnxIdftApplyConsumeImpl<B>,
{
fn vec_znx_idft_consume<D: Data>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
fn vec_znx_idft_apply_consume<D: Data>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
where
VecZnxDft<D, B>: VecZnxDftToMut<B>,
{
B::idft_consume_impl(self, a)
B::vec_znx_idft_apply_consume_impl(self, a)
}
}
impl<B> DFT<B> for Module<B>
impl<B> VecZnxDftApply<B> for Module<B>
where
B: Backend + DFTImpl<B>,
B: Backend + VecZnxDftApplyImpl<B>,
{
fn dft<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_dft_apply<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxToRef,
{
B::dft_impl(self, step, offset, res, res_col, a, a_col);
B::vec_znx_dft_apply_impl(self, step, offset, res, res_col, a, a_col);
}
}
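The former `DFT`/`IDFT` traits are now `VecZnxDftApply`/`VecZnxIdftApply*`. A round-trip sketch, not part of this diff; the helper and the `step = 1`, `offset = 0` arguments are illustrative.

```rust
use crate::{
    api::{VecZnxDftApply, VecZnxIdftApply, VecZnxIdftApplyTmpBytes},
    layouts::{Backend, Module, Scratch, VecZnxBigToMut, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef},
};

// Hypothetical helper: forward DFT of `a` into `a_dft`, then inverse DFT into
// the big layout `res`. `scratch` is assumed to hold at least
// `vec_znx_idft_apply_tmp_bytes()` bytes.
fn dft_round_trip<B, R, D, A>(module: &Module<B>, res: &mut R, a_dft: &mut D, a: &A, scratch: &mut Scratch<B>)
where
    B: Backend,
    Module<B>: VecZnxDftApply<B> + VecZnxIdftApply<B> + VecZnxIdftApplyTmpBytes,
    R: VecZnxBigToMut<B>,
    D: VecZnxDftToMut<B> + VecZnxDftToRef<B>,
    A: VecZnxToRef,
{
    module.vec_znx_dft_apply(1, 0, a_dft, 0, a, 0);
    module.vec_znx_idft_apply(res, 0, &*a_dft, 0, scratch);
}
```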

View File

@@ -1,12 +1,16 @@
use crate::{
api::{
VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes, VmpApplyDftToDftTmpBytes, VmpPMatAlloc,
VmpPMatAllocBytes, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes,
VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes,
VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes,
},
layouts::{
Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, VmpPMatOwned, VmpPMatToMut,
VmpPMatToRef,
},
layouts::{Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef},
oep::{
VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl,
VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
VmpApplyDftImpl, VmpApplyDftTmpBytesImpl, VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl,
VmpApplyDftToDftTmpBytesImpl, VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPrepareImpl,
VmpPrepareTmpBytesImpl,
},
};
@@ -48,7 +52,7 @@ where
impl<B> VmpPrepare<B> for Module<B>
where
B: Backend + VmpPMatPrepareImpl<B>,
B: Backend + VmpPrepareImpl<B>,
{
fn vmp_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<B>)
where
@@ -59,6 +63,39 @@ where
}
}
impl<B> VmpApplyDftTmpBytes for Module<B>
where
B: Backend + VmpApplyDftTmpBytesImpl<B>,
{
fn vmp_apply_dft_tmp_bytes(
&self,
res_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize {
B::vmp_apply_dft_tmp_bytes_impl(
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
)
}
}
impl<B> VmpApplyDft<B> for Module<B>
where
B: Backend + VmpApplyDftImpl<B>,
{
fn vmp_apply_dft<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
where
R: VecZnxDftToMut<B>,
A: VecZnxToRef,
C: VmpPMatToRef<B>,
{
B::vmp_apply_dft_impl(self, res, a, b, scratch);
}
}
impl<B> VmpApplyDftToDftTmpBytes for Module<B>
where
B: Backend + VmpApplyDftToDftTmpBytesImpl<B>,
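The new `VmpApplyDft` path takes a plain (non-DFT) vector input. A minimal sketch, not part of this diff, with the helper name and dimension handling left to the caller.

```rust
use crate::{
    api::{VmpApplyDft, VmpApplyDftTmpBytes},
    layouts::{Backend, Module, Scratch, VecZnxDftToMut, VecZnxToRef, VmpPMatToRef},
};

// Hypothetical helper: apply the prepared matrix `b` to the small vector `a`,
// producing `res` in the DFT domain. `scratch` is assumed to be at least as
// large as `vmp_apply_dft_tmp_bytes(...)` for the dimensions in use.
fn vmp_apply_example<B, R, A, C>(module: &Module<B>, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
where
    B: Backend,
    Module<B>: VmpApplyDft<B> + VmpApplyDftTmpBytes,
    R: VecZnxDftToMut<B>,
    A: VecZnxToRef,
    C: VmpPMatToRef<B>,
{
    module.vmp_apply_dft(res, a, b, scratch);
}
```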

View File

@@ -1,10 +1,19 @@
use crate::{
api::{ZnAddDistF64, ZnAddNormal, ZnFillDistF64, ZnFillNormal, ZnFillUniform, ZnNormalizeInplace},
api::{ZnAddNormal, ZnFillNormal, ZnFillUniform, ZnNormalizeInplace, ZnNormalizeTmpBytes},
layouts::{Backend, Module, Scratch, ZnToMut},
oep::{ZnAddDistF64Impl, ZnAddNormalImpl, ZnFillDistF64Impl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl},
oep::{ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl, ZnNormalizeTmpBytesImpl},
source::Source,
};
impl<B> ZnNormalizeTmpBytes for Module<B>
where
B: Backend + ZnNormalizeTmpBytesImpl<B>,
{
fn zn_normalize_tmp_bytes(&self, n: usize) -> usize {
B::zn_normalize_tmp_bytes_impl(n)
}
}
impl<B> ZnNormalizeInplace<B> for Module<B>
where
B: Backend + ZnNormalizeInplaceImpl<B>,
@@ -21,53 +30,11 @@ impl<B> ZnFillUniform for Module<B>
where
B: Backend + ZnFillUniformImpl<B>,
{
fn zn_fill_uniform<R>(&self, n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
fn zn_fill_uniform<R>(&self, n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
B::zn_fill_uniform_impl(n, basek, res, res_col, k, source);
}
}
impl<B> ZnFillDistF64 for Module<B>
where
B: Backend + ZnFillDistF64Impl<B>,
{
fn zn_fill_dist_f64<R, D: rand::prelude::Distribution<f64>>(
&self,
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: ZnToMut,
{
B::zn_fill_dist_f64_impl(n, basek, res, res_col, k, source, dist, bound);
}
}
impl<B> ZnAddDistF64 for Module<B>
where
B: Backend + ZnAddDistF64Impl<B>,
{
fn zn_add_dist_f64<R, D: rand::prelude::Distribution<f64>>(
&self,
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: ZnToMut,
{
B::zn_add_dist_f64_impl(n, basek, res, res_col, k, source, dist, bound);
B::zn_fill_uniform_impl(n, basek, res, res_col, source);
}
}
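The `Zn` layer drops the `k` argument from `zn_fill_uniform` and exposes `zn_normalize_tmp_bytes`. A sketch, not part of this diff; the helper is hypothetical.

```rust
use crate::{
    api::{ZnFillUniform, ZnNormalizeTmpBytes},
    layouts::{Backend, Module, ZnToMut},
    source::Source,
};

// Hypothetical helper: fill column 0 of `res` with uniform limbs in base
// 2^basek and report how many scratch bytes a later normalization would need.
fn zn_fill_example<B, R>(module: &Module<B>, n: usize, basek: usize, res: &mut R, source: &mut Source) -> usize
where
    B: Backend,
    Module<B>: ZnFillUniform + ZnNormalizeTmpBytes,
    R: ZnToMut,
{
    module.zn_fill_uniform(n, basek, res, 0, source);
    module.zn_normalize_tmp_bytes(n)
}
```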