Add cross-basek normalization (#90)

* added cross_basek_normalization

* updated method signatures to take layouts

* fixed cross-basek normalization

fix #91
fix #93
This commit is contained in:
Jean-Philippe Bossuat
2025-09-30 14:40:10 +02:00
committed by GitHub
parent 4da790ea6a
commit 37e13b965c
216 changed files with 12481 additions and 7745 deletions

View File

@@ -5,8 +5,8 @@ use crate::{
VecZnxLshInplace, VecZnxLshTmpBytes, VecZnxMergeRings, VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOne,
VecZnxMulXpMinusOneInplace, VecZnxMulXpMinusOneInplaceTmpBytes, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes,
VecZnxRsh, VecZnxRshInplace, VecZnxRshTmpBytes, VecZnxSplitRing, VecZnxSplitRingTmpBytes, VecZnxSub, VecZnxSubABInplace,
VecZnxSubBAInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing,
VecZnxRsh, VecZnxRshInplace, VecZnxRshTmpBytes, VecZnxSplitRing, VecZnxSplitRingTmpBytes, VecZnxSub, VecZnxSubInplace,
VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing,
},
layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
oep::{
@@ -17,7 +17,7 @@ use crate::{
VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl,
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl,
},
source::Source,
@@ -36,12 +36,21 @@ impl<B> VecZnxNormalize<B> for Module<B>
where
B: Backend + VecZnxNormalizeImpl<B>,
{
fn vec_znx_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
#[allow(clippy::too_many_arguments)]
fn vec_znx_normalize<R, A>(
&self,
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_normalize_impl(self, basek, res, res_col, a, a_col, scratch)
B::vec_znx_normalize_impl(self, res_basek, res, res_col, a_basek, a, a_col, scratch)
}
}
@@ -49,11 +58,11 @@ impl<B> VecZnxNormalizeInplace<B> for Module<B>
where
B: Backend + VecZnxNormalizeInplaceImpl<B>,
{
fn vec_znx_normalize_inplace<A>(&self, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
fn vec_znx_normalize_inplace<A>(&self, base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut,
{
B::vec_znx_normalize_inplace_impl(self, basek, a, a_col, scratch)
B::vec_znx_normalize_inplace_impl(self, base2k, a, a_col, scratch)
}
}
@@ -125,29 +134,29 @@ where
}
}
impl<B> VecZnxSubABInplace for Module<B>
impl<B> VecZnxSubInplace for Module<B>
where
B: Backend + VecZnxSubABInplaceImpl<B>,
B: Backend + VecZnxSubInplaceImpl<B>,
{
fn vec_znx_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_sub_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_sub_ab_inplace_impl(self, res, res_col, a, a_col)
B::vec_znx_sub_inplace_impl(self, res, res_col, a, a_col)
}
}
impl<B> VecZnxSubBAInplace for Module<B>
impl<B> VecZnxSubNegateInplace for Module<B>
where
B: Backend + VecZnxSubBAInplaceImpl<B>,
B: Backend + VecZnxSubNegateInplaceImpl<B>,
{
fn vec_znx_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_sub_negate_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_sub_ba_inplace_impl(self, res, res_col, a, a_col)
B::vec_znx_sub_negate_inplace_impl(self, res, res_col, a, a_col)
}
}
@@ -227,7 +236,7 @@ where
{
fn vec_znx_lsh<R, A>(
&self,
basek: usize,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -238,7 +247,7 @@ where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_lsh_inplace_impl(self, basek, k, res, res_col, a, a_col, scratch);
B::vec_znx_lsh_impl(self, base2k, k, res, res_col, a, a_col, scratch);
}
}
@@ -248,7 +257,7 @@ where
{
fn vec_znx_rsh<R, A>(
&self,
basek: usize,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -259,7 +268,7 @@ where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_rsh_inplace_impl(self, basek, k, res, res_col, a, a_col, scratch);
B::vec_znx_rsh_impl(self, base2k, k, res, res_col, a, a_col, scratch);
}
}
@@ -267,11 +276,11 @@ impl<B> VecZnxLshInplace<B> for Module<B>
where
B: Backend + VecZnxLshInplaceImpl<B>,
{
fn vec_znx_lsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
fn vec_znx_lsh_inplace<A>(&self, base2k: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut,
{
B::vec_znx_lsh_inplace_impl(self, basek, k, a, a_col, scratch)
B::vec_znx_lsh_inplace_impl(self, base2k, k, a, a_col, scratch)
}
}
@@ -279,11 +288,11 @@ impl<B> VecZnxRshInplace<B> for Module<B>
where
B: Backend + VecZnxRshInplaceImpl<B>,
{
fn vec_znx_rsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
fn vec_znx_rsh_inplace<A>(&self, base2k: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut,
{
B::vec_znx_rsh_inplace_impl(self, basek, k, a, a_col, scratch)
B::vec_znx_rsh_inplace_impl(self, base2k, k, a, a_col, scratch)
}
}
@@ -463,11 +472,11 @@ impl<B> VecZnxFillUniform for Module<B>
where
B: Backend + VecZnxFillUniformImpl<B>,
{
fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
fn vec_znx_fill_uniform<R>(&self, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut,
{
B::vec_znx_fill_uniform_impl(self, basek, res, res_col, source);
B::vec_znx_fill_uniform_impl(self, base2k, res, res_col, source);
}
}
@@ -477,7 +486,7 @@ where
{
fn vec_znx_fill_normal<R>(
&self,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -487,7 +496,7 @@ where
) where
R: VecZnxToMut,
{
B::vec_znx_fill_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
B::vec_znx_fill_normal_impl(self, base2k, res, res_col, k, source, sigma, bound);
}
}
@@ -497,7 +506,7 @@ where
{
fn vec_znx_add_normal<R>(
&self,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -507,6 +516,6 @@ where
) where
R: VecZnxToMut,
{
B::vec_znx_add_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
B::vec_znx_add_normal_impl(self, base2k, res, res_col, k, source, sigma, bound);
}
}