Add cross-basek normalization (#90)

* Added cross-base2k normalization
* Updated method signatures to take layouts
* Fixed cross-base2k normalization

Fixes #91
Fixes #93
Author: Jean-Philippe Bossuat
Date: 2025-09-30 14:40:10 +02:00
Committed by: GitHub
Parent: 4da790ea6a
Commit: 37e13b965c
216 changed files with 12481 additions and 7745 deletions
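Context for the change: these vectors store each coefficient as a sequence of signed digits in base 2^k, and "cross-base2k normalization" re-expresses digits given in one base (the `a_basek` parameter appearing below) as digits in another (`res_basek`). A minimal scalar sketch of the idea, assuming balanced digits in most-significant-first order (an illustration only, not the poulpy_hal API, whose digit conventions may differ):

    // Re-express a value given as signed base-2^ka digits (most significant
    // first) as balanced signed base-2^kr digits.
    fn rebase_digits(ka: u32, kr: u32, digits_a: &[i64], out_len: usize) -> Vec<i64> {
        let mut v: i128 = 0;
        for &d in digits_a {
            v = (v << ka) + d as i128; // recompose the value exactly
        }
        let base = 1i128 << kr;
        let half = base >> 1;
        let mut out = vec![0i64; out_len];
        for slot in out.iter_mut().rev() {
            let mut d = v & (base - 1); // low kr bits, in [0, 2^kr)
            if d >= half {
                d -= base; // balance into [-2^(kr-1), 2^(kr-1))
            }
            v = (v - d) >> kr; // exact: v - d is a multiple of 2^kr
            *slot = d as i64;
        }
        out
    }

    fn main() {
        // [2, 5, 6] in base 2^3 encodes 2*64 + 5*8 + 6 = 174.
        let out = rebase_digits(3, 4, &[2, 5, 6], 3);
        let v = out.iter().fold(0i128, |acc, &d| (acc << 4) + d as i128);
        assert_eq!(v, 174); // same value, now in balanced base-2^4 digits
    }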

View File

@@ -1,10 +1,11 @@
use poulpy_hal::reference::fft64::{
reim::{
ReimAdd, ReimAddInplace, ReimAddMul, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimMul,
-ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubABInplace, ReimSubBAInplace, ReimToZnx, ReimToZnxInplace,
-ReimZero, fft_ref, ifft_ref, reim_add_inplace_ref, reim_add_ref, reim_addmul_ref, reim_copy_ref, reim_from_znx_i64_ref,
-reim_mul_inplace_ref, reim_mul_ref, reim_negate_inplace_ref, reim_negate_ref, reim_sub_ab_inplace_ref,
-reim_sub_ba_inplace_ref, reim_sub_ref, reim_to_znx_i64_inplace_ref, reim_to_znx_i64_ref, reim_zero_ref,
+ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubInplace, ReimSubNegateInplace, ReimToZnx,
+ReimToZnxInplace, ReimZero, fft_ref, ifft_ref, reim_add_inplace_ref, reim_add_ref, reim_addmul_ref, reim_copy_ref,
+reim_from_znx_i64_ref, reim_mul_inplace_ref, reim_mul_ref, reim_negate_inplace_ref, reim_negate_ref,
+reim_sub_inplace_ref, reim_sub_negate_inplace_ref, reim_sub_ref, reim_to_znx_i64_inplace_ref, reim_to_znx_i64_ref,
+reim_zero_ref,
},
reim4::{
Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks,
@@ -69,17 +70,17 @@ impl ReimSub for FFT64Ref {
}
}
-impl ReimSubABInplace for FFT64Ref {
+impl ReimSubInplace for FFT64Ref {
#[inline(always)]
-fn reim_sub_ab_inplace(res: &mut [f64], a: &[f64]) {
-reim_sub_ab_inplace_ref(res, a);
+fn reim_sub_inplace(res: &mut [f64], a: &[f64]) {
+reim_sub_inplace_ref(res, a);
}
}
-impl ReimSubBAInplace for FFT64Ref {
+impl ReimSubNegateInplace for FFT64Ref {
#[inline(always)]
-fn reim_sub_ba_inplace(res: &mut [f64], a: &[f64]) {
-reim_sub_ba_inplace_ref(res, a);
+fn reim_sub_negate_inplace(res: &mut [f64], a: &[f64]) {
+reim_sub_negate_inplace_ref(res, a);
}
}
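The rename replaces the opaque AB/BA suffixes with names that say where the negation lands: `sub_inplace` computes `res -= a`, while `sub_negate_inplace` computes `res = a - res`, i.e. the subtraction followed by a negation. A scalar sketch of the two semantics, assuming only what this diff and the doc comments further down show (illustrative, not the library kernels):

    // res <- res - a (formerly "sub_ab_inplace")
    fn sub_inplace(res: &mut [f64], a: &[f64]) {
        for (r, &x) in res.iter_mut().zip(a) {
            *r -= x;
        }
    }

    // res <- a - res (formerly "sub_ba_inplace"), hence "negate" in the name
    fn sub_negate_inplace(res: &mut [f64], a: &[f64]) {
        for (r, &x) in res.iter_mut().zip(a) {
            *r = x - *r;
        }
    }

    fn main() {
        let mut r = [5.0, 1.0];
        sub_inplace(&mut r, &[2.0, 4.0]);
        assert_eq!(r, [3.0, -3.0]);
        sub_negate_inplace(&mut r, &[2.0, 4.0]);
        assert_eq!(r, [-1.0, 7.0]);
    }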

View File

@@ -253,9 +253,6 @@ fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]
(take_slice, rem_slice)
}
} else {
-panic!(
-"Attempted to take {} from scratch with {} aligned bytes left",
-take_len, aligned_len,
-);
+panic!("Attempted to take {take_len} from scratch with {aligned_len} aligned bytes left");
}
}
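The panic! rewrite switches to inline format arguments (stable since Rust 1.58), which keep the message and its captured variables in one place; the output is unchanged:

    fn main() {
        let (take_len, aligned_len) = (64usize, 32usize);
        let new = format!("Attempted to take {take_len} from scratch with {aligned_len} aligned bytes left");
        let old = format!("Attempted to take {} from scratch with {} aligned bytes left", take_len, aligned_len);
        assert_eq!(new, old); // identical text, no positional list to keep in sync
    }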

View File

@@ -1,7 +1,8 @@
use poulpy_hal::{
api::{
-TakeSlice, VecZnxAutomorphismInplaceTmpBytes, VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOneInplaceTmpBytes,
-VecZnxNormalizeTmpBytes, VecZnxRotateInplaceTmpBytes, VecZnxSplitRingTmpBytes,
+TakeSlice, VecZnxAutomorphismInplaceTmpBytes, VecZnxLshTmpBytes, VecZnxMergeRingsTmpBytes,
+VecZnxMulXpMinusOneInplaceTmpBytes, VecZnxNormalizeTmpBytes, VecZnxRotateInplaceTmpBytes, VecZnxRshTmpBytes,
+VecZnxSplitRingTmpBytes,
},
layouts::{Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
oep::{
@@ -12,7 +13,7 @@ use poulpy_hal::{
VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
-VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl,
+VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl,
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl,
},
reference::vec_znx::{
@@ -23,7 +24,7 @@ use poulpy_hal::{
vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_negate, vec_znx_negate_inplace, vec_znx_normalize,
vec_znx_normalize_inplace, vec_znx_normalize_tmp_bytes, vec_znx_rotate, vec_znx_rotate_inplace,
vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring,
-vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace, vec_znx_sub_scalar,
+vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace, vec_znx_sub_scalar,
vec_znx_sub_scalar_inplace, vec_znx_switch_ring,
},
source::Source,
@@ -43,9 +44,10 @@ where
{
fn vec_znx_normalize_impl<R, A>(
module: &Module<Self>,
-basek: usize,
+res_basek: usize,
res: &mut R,
res_col: usize,
+a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<Self>,
@@ -54,7 +56,7 @@ where
A: VecZnxToRef,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
-vec_znx_normalize::<R, A, Self>(basek, res, res_col, a, a_col, carry);
+vec_znx_normalize::<R, A, Self>(res_basek, res, res_col, a_basek, a, a_col, carry);
}
}
@@ -64,7 +66,7 @@ where
{
fn vec_znx_normalize_inplace_impl<R>(
module: &Module<Self>,
-basek: usize,
+base2k: usize,
res: &mut R,
res_col: usize,
scratch: &mut Scratch<Self>,
@@ -72,7 +74,7 @@ where
R: VecZnxToMut,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
-vec_znx_normalize_inplace::<R, Self>(basek, res, res_col, carry);
+vec_znx_normalize_inplace::<R, Self>(base2k, res, res_col, carry);
}
}
@@ -143,23 +145,23 @@ unsafe impl VecZnxSubImpl<Self> for FFT64Ref {
}
}
-unsafe impl VecZnxSubABInplaceImpl<Self> for FFT64Ref {
-fn vec_znx_sub_ab_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+unsafe impl VecZnxSubInplaceImpl<Self> for FFT64Ref {
+fn vec_znx_sub_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
-vec_znx_sub_ab_inplace::<R, A, Self>(res, res_col, a, a_col);
+vec_znx_sub_inplace::<R, A, Self>(res, res_col, a, a_col);
}
}
-unsafe impl VecZnxSubBAInplaceImpl<Self> for FFT64Ref {
-fn vec_znx_sub_ba_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+unsafe impl VecZnxSubNegateInplaceImpl<Self> for FFT64Ref {
+fn vec_znx_sub_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
-vec_znx_sub_ba_inplace::<R, A, Self>(res, res_col, a, a_col);
+vec_znx_sub_negate_inplace::<R, A, Self>(res, res_col, a, a_col);
}
}
@@ -234,9 +236,9 @@ where
Module<Self>: VecZnxNormalizeTmpBytes,
Scratch<Self>: TakeSlice,
{
-fn vec_znx_lsh_inplace_impl<R, A>(
+fn vec_znx_lsh_impl<R, A>(
module: &Module<Self>,
-basek: usize,
+base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -247,8 +249,8 @@ where
R: VecZnxToMut,
A: VecZnxToRef,
{
-let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
-vec_znx_lsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry);
+let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::<i64>());
+vec_znx_lsh::<_, _, Self>(base2k, k, res, res_col, a, a_col, carry);
}
}
@@ -259,7 +261,7 @@ where
{
fn vec_znx_lsh_inplace_impl<A>(
module: &Module<Self>,
-basek: usize,
+base2k: usize,
k: usize,
a: &mut A,
a_col: usize,
@@ -267,8 +269,8 @@ where
) where
A: VecZnxToMut,
{
-let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
-vec_znx_lsh_inplace::<_, Self>(basek, k, a, a_col, carry);
+let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::<i64>());
+vec_znx_lsh_inplace::<_, Self>(base2k, k, a, a_col, carry);
}
}
@@ -277,9 +279,9 @@ where
Module<Self>: VecZnxNormalizeTmpBytes,
Scratch<Self>: TakeSlice,
{
-fn vec_znx_rsh_inplace_impl<R, A>(
+fn vec_znx_rsh_impl<R, A>(
module: &Module<Self>,
-basek: usize,
+base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -290,8 +292,8 @@ where
R: VecZnxToMut,
A: VecZnxToRef,
{
-let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
-vec_znx_rsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry);
+let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::<i64>());
+vec_znx_rsh::<_, _, Self>(base2k, k, res, res_col, a, a_col, carry);
}
}
@@ -302,7 +304,7 @@ where
{
fn vec_znx_rsh_inplace_impl<A>(
module: &Module<Self>,
-basek: usize,
+base2k: usize,
k: usize,
a: &mut A,
a_col: usize,
@@ -310,8 +312,8 @@ where
) where
A: VecZnxToMut,
{
-let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
-vec_znx_rsh_inplace::<_, Self>(basek, k, a, a_col, carry);
+let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::<i64>());
+vec_znx_rsh_inplace::<_, Self>(base2k, k, a, a_col, carry);
}
}
@@ -495,18 +497,18 @@ unsafe impl VecZnxCopyImpl<Self> for FFT64Ref {
}
unsafe impl VecZnxFillUniformImpl<Self> for FFT64Ref {
-fn vec_znx_fill_uniform_impl<R>(_module: &Module<Self>, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
+fn vec_znx_fill_uniform_impl<R>(_module: &Module<Self>, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut,
{
-vec_znx_fill_uniform_ref(basek, res, res_col, source)
+vec_znx_fill_uniform_ref(base2k, res, res_col, source)
}
}
unsafe impl VecZnxFillNormalImpl<Self> for FFT64Ref {
fn vec_znx_fill_normal_impl<R>(
_module: &Module<Self>,
-basek: usize,
+base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -516,14 +518,14 @@ unsafe impl VecZnxFillNormalImpl<Self> for FFT64Ref {
) where
R: VecZnxToMut,
{
-vec_znx_fill_normal_ref(basek, res, res_col, k, sigma, bound, source);
+vec_znx_fill_normal_ref(base2k, res, res_col, k, sigma, bound, source);
}
}
unsafe impl VecZnxAddNormalImpl<Self> for FFT64Ref {
fn vec_znx_add_normal_impl<R>(
_module: &Module<Self>,
-basek: usize,
+base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -533,6 +535,6 @@ unsafe impl VecZnxAddNormalImpl<Self> for FFT64Ref {
) where
R: VecZnxToMut,
{
-vec_znx_add_normal_ref(basek, res, res_col, k, sigma, bound, source);
+vec_znx_add_normal_ref(base2k, res, res_col, k, sigma, bound, source);
}
}
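Two patterns recur in this file: the normalize entry points now thread separate `res_basek`/`a_basek` bases through to the reference kernels, and the lsh/rsh shims size their carry buffer from the matching `vec_znx_lsh_tmp_bytes`/`vec_znx_rsh_tmp_bytes` rather than borrowing the normalize budget. The carry slice itself always comes from the same bytes-to-elements conversion; a toy model of just that arithmetic (the real Scratch/TakeSlice machinery is richer than this):

    use std::mem::size_of;

    // take_slice(tmp_bytes / size_of::<i64>()) hands out this many i64 slots.
    fn carry_len(tmp_bytes: usize) -> usize {
        tmp_bytes / size_of::<i64>()
    }

    fn main() {
        assert_eq!(carry_len(8192), 1024); // 8 KiB of scratch -> 1024 carry slots
    }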

View File

@@ -10,15 +10,15 @@ use poulpy_hal::{
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl,
VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl,
-VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl,
-VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
+VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubImpl, VecZnxBigSubInplaceImpl, VecZnxBigSubNegateInplaceImpl,
+VecZnxBigSubSmallAImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallInplaceImpl, VecZnxBigSubSmallNegateInplaceImpl,
},
reference::{
fft64::vec_znx_big::{
vec_znx_big_add, vec_znx_big_add_inplace, vec_znx_big_add_normal_ref, vec_znx_big_add_small,
vec_znx_big_add_small_inplace, vec_znx_big_automorphism, vec_znx_big_automorphism_inplace,
vec_znx_big_automorphism_inplace_tmp_bytes, vec_znx_big_negate, vec_znx_big_negate_inplace, vec_znx_big_normalize,
-vec_znx_big_normalize_tmp_bytes, vec_znx_big_sub, vec_znx_big_sub_ab_inplace, vec_znx_big_sub_ba_inplace,
+vec_znx_big_normalize_tmp_bytes, vec_znx_big_sub, vec_znx_big_sub_inplace, vec_znx_big_sub_negate_inplace,
vec_znx_big_sub_small_a, vec_znx_big_sub_small_a_inplace, vec_znx_big_sub_small_b, vec_znx_big_sub_small_b_inplace,
},
znx::{znx_copy_ref, znx_zero_ref},
@@ -76,7 +76,7 @@ unsafe impl VecZnxBigFromSmallImpl<Self> for FFT64Ref {
unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Ref {
fn add_normal_impl<R: VecZnxBigToMut<Self>>(
_module: &Module<Self>,
-basek: usize,
+base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -84,7 +84,7 @@ unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Ref {
sigma: f64,
bound: f64,
) {
-vec_znx_big_add_normal_ref(basek, res, res_col, k, sigma, bound, source);
+vec_znx_big_add_normal_ref(base2k, res, res_col, k, sigma, bound, source);
}
}
@@ -167,25 +167,25 @@ unsafe impl VecZnxBigSubImpl<Self> for FFT64Ref {
}
}
-unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64Ref {
+unsafe impl VecZnxBigSubInplaceImpl<Self> for FFT64Ref {
/// Subtracts `a` from `res` and stores the result on `res`.
-fn vec_znx_big_sub_ab_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+fn vec_znx_big_sub_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
{
-vec_znx_big_sub_ab_inplace(res, res_col, a, a_col);
+vec_znx_big_sub_inplace(res, res_col, a, a_col);
}
}
-unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64Ref {
+unsafe impl VecZnxBigSubNegateInplaceImpl<Self> for FFT64Ref {
/// Subtracts `b` from `a` and stores the result on `b`.
-fn vec_znx_big_sub_ba_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+fn vec_znx_big_sub_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
{
-vec_znx_big_sub_ba_inplace(res, res_col, a, a_col);
+vec_znx_big_sub_negate_inplace(res, res_col, a, a_col);
}
}
@@ -208,9 +208,9 @@ unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64Ref {
}
}
-unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64Ref {
+unsafe impl VecZnxBigSubSmallInplaceImpl<Self> for FFT64Ref {
/// Subtracts `a` from `res` and stores the result on `res`.
-fn vec_znx_big_sub_small_a_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+fn vec_znx_big_sub_small_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
@@ -238,9 +238,9 @@ unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64Ref {
}
}
-unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64Ref {
+unsafe impl VecZnxBigSubSmallNegateInplaceImpl<Self> for FFT64Ref {
/// Subtracts `res` from `a` and stores the result on `res`.
-fn vec_znx_big_sub_small_b_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+fn vec_znx_big_sub_small_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
@@ -280,9 +280,10 @@ where
{
fn vec_znx_big_normalize_impl<R, A>(
module: &Module<Self>,
-basek: usize,
+res_basek: usize,
res: &mut R,
res_col: usize,
+a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<Self>,
@@ -291,7 +292,7 @@ where
A: VecZnxBigToRef<Self>,
{
let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
-vec_znx_big_normalize(basek, res, res_col, a, a_col, carry);
+vec_znx_big_normalize(res_basek, res, res_col, a_basek, a, a_col, carry);
}
}
@@ -326,7 +327,7 @@ where
) where
R: VecZnxBigToMut<Self>,
{
-let (tmp, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
+let (tmp, _) = scratch.take_slice(module.vec_znx_big_automorphism_inplace_tmp_bytes() / size_of::<i64>());
vec_znx_big_automorphism_inplace(p, res, res_col, tmp);
}
}
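The last hunk is a genuine fix rather than a rename: the in-place automorphism previously sized its temporary from `vec_znx_big_normalize_tmp_bytes` and now asks `vec_znx_big_automorphism_inplace_tmp_bytes` for its own budget. For context, a toy negacyclic automorphism X -> X^p on Z[X]/(X^N + 1) for odd p (an illustration only; the library kernels operate per column and use the scratch buffer):

    // Map sum a_i X^i to sum a_i X^(i*p) in Z[X]/(X^N + 1), where X^N = -1.
    fn automorphism(p: usize, a: &[i64]) -> Vec<i64> {
        let n = a.len();
        let mut res = vec![0i64; n];
        for (i, &c) in a.iter().enumerate() {
            let j = (i * p) % (2 * n); // exponent reduced modulo X^(2N) = 1
            if j < n {
                res[j] = c;
            } else {
                res[j - n] = -c; // wrapped past X^N, so the sign flips
            }
        }
        res
    }

    fn main() {
        // (X + X^2) under X -> X^3 in Z[X]/(X^4 + 1): X^3 + X^6 = X^3 - X^2.
        assert_eq!(automorphism(3, &[0, 1, 1, 0]), vec![0, 0, -1, 1]);
    }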

View File

@@ -5,12 +5,12 @@ use poulpy_hal::{
},
oep::{
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl,
-VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl,
+VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl,
VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
},
reference::fft64::vec_znx_dft::{
-vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub,
-vec_znx_dft_sub_ab_inplace, vec_znx_dft_sub_ba_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume,
+vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub, vec_znx_dft_sub_inplace,
+vec_znx_dft_sub_negate_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume,
vec_znx_idft_apply_tmpa,
},
};
@@ -139,23 +139,23 @@ unsafe impl VecZnxDftSubImpl<Self> for FFT64Ref {
}
}
-unsafe impl VecZnxDftSubABInplaceImpl<Self> for FFT64Ref {
-fn vec_znx_dft_sub_ab_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+unsafe impl VecZnxDftSubInplaceImpl<Self> for FFT64Ref {
+fn vec_znx_dft_sub_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
{
-vec_znx_dft_sub_ab_inplace(res, res_col, a, a_col);
+vec_znx_dft_sub_inplace(res, res_col, a, a_col);
}
}
-unsafe impl VecZnxDftSubBAInplaceImpl<Self> for FFT64Ref {
-fn vec_znx_dft_sub_ba_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+unsafe impl VecZnxDftSubNegateInplaceImpl<Self> for FFT64Ref {
+fn vec_znx_dft_sub_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
{
-vec_znx_dft_sub_ba_inplace(res, res_col, a, a_col);
+vec_znx_dft_sub_negate_inplace(res, res_col, a, a_col);
}
}

View File

@@ -18,21 +18,21 @@ unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64Ref
where
Self: TakeSliceImpl<Self>,
{
-fn zn_normalize_inplace_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<Self>)
+fn zn_normalize_inplace_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<Self>)
where
R: ZnToMut,
{
let (carry, _) = scratch.take_slice(n);
-zn_normalize_inplace::<R, FFT64Ref>(n, basek, res, res_col, carry);
+zn_normalize_inplace::<R, FFT64Ref>(n, base2k, res, res_col, carry);
}
}
unsafe impl ZnFillUniformImpl<Self> for FFT64Ref {
-fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
+fn zn_fill_uniform_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
-zn_fill_uniform(n, basek, res, res_col, source);
+zn_fill_uniform(n, base2k, res, res_col, source);
}
}
@@ -40,7 +40,7 @@ unsafe impl ZnFillNormalImpl<Self> for FFT64Ref {
#[allow(clippy::too_many_arguments)]
fn zn_fill_normal_impl<R>(
n: usize,
-basek: usize,
+base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -50,7 +50,7 @@ unsafe impl ZnFillNormalImpl<Self> for FFT64Ref {
) where
R: ZnToMut,
{
-zn_fill_normal(n, basek, res, res_col, k, source, sigma, bound);
+zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}
@@ -58,7 +58,7 @@ unsafe impl ZnAddNormalImpl<Self> for FFT64Ref {
#[allow(clippy::too_many_arguments)]
fn zn_add_normal_impl<R>(
n: usize,
-basek: usize,
+base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -68,6 +68,6 @@ unsafe impl ZnAddNormalImpl<Self> for FFT64Ref {
) where
R: ZnToMut,
{
-zn_add_normal(n, basek, res, res_col, k, source, sigma, bound);
+zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}

View File

@@ -1,12 +1,14 @@
use poulpy_hal::reference::znx::{
-ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
-ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace,
-ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxRotate, ZnxSub, ZnxSubABInplace,
-ZnxSubBAInplace, ZnxSwitchRing, ZnxZero, znx_add_inplace_ref, znx_add_ref, znx_automorphism_ref, znx_copy_ref,
-znx_negate_inplace_ref, znx_negate_ref, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref,
+ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulAddPowerOfTwo, ZnxMulPowerOfTwo,
+ZnxMulPowerOfTwoInplace, ZnxNegate, ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace,
+ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep,
+ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxRotate, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace,
+ZnxSwitchRing, ZnxZero, znx_add_inplace_ref, znx_add_ref, znx_automorphism_ref, znx_copy_ref, znx_extract_digit_addmul_ref,
+znx_mul_add_power_of_two_ref, znx_mul_power_of_two_inplace_ref, znx_mul_power_of_two_ref, znx_negate_inplace_ref,
+znx_negate_ref, znx_normalize_digit_ref, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref,
znx_normalize_first_step_carry_only_ref, znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref,
znx_normalize_middle_step_carry_only_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref, znx_rotate,
-znx_sub_ab_inplace_ref, znx_sub_ba_inplace_ref, znx_sub_ref, znx_switch_ring_ref, znx_zero_ref,
+znx_sub_inplace_ref, znx_sub_negate_inplace_ref, znx_sub_ref, znx_switch_ring_ref, znx_zero_ref,
};
use crate::cpu_fft64_ref::FFT64Ref;
@@ -32,17 +34,38 @@ impl ZnxSub for FFT64Ref {
}
}
-impl ZnxSubABInplace for FFT64Ref {
+impl ZnxSubInplace for FFT64Ref {
#[inline(always)]
-fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]) {
-znx_sub_ab_inplace_ref(res, a);
+fn znx_sub_inplace(res: &mut [i64], a: &[i64]) {
+znx_sub_inplace_ref(res, a);
}
}
-impl ZnxSubBAInplace for FFT64Ref {
+impl ZnxSubNegateInplace for FFT64Ref {
#[inline(always)]
-fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]) {
-znx_sub_ba_inplace_ref(res, a);
+fn znx_sub_negate_inplace(res: &mut [i64], a: &[i64]) {
+znx_sub_negate_inplace_ref(res, a);
}
}
+impl ZnxMulAddPowerOfTwo for FFT64Ref {
+#[inline(always)]
+fn znx_muladd_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
+znx_mul_add_power_of_two_ref(k, res, a);
+}
+}
+impl ZnxMulPowerOfTwo for FFT64Ref {
+#[inline(always)]
+fn znx_mul_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
+znx_mul_power_of_two_ref(k, res, a);
+}
+}
+impl ZnxMulPowerOfTwoInplace for FFT64Ref {
+#[inline(always)]
+fn znx_mul_power_of_two_inplace(k: i64, res: &mut [i64]) {
+znx_mul_power_of_two_inplace_ref(k, res);
+}
+}
@@ -97,56 +120,70 @@ impl ZnxSwitchRing for FFT64Ref {
impl ZnxNormalizeFinalStep for FFT64Ref {
#[inline(always)]
-fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
-znx_normalize_final_step_ref(basek, lsh, x, a, carry);
+fn znx_normalize_final_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
+znx_normalize_final_step_ref(base2k, lsh, x, a, carry);
}
}
impl ZnxNormalizeFinalStepInplace for FFT64Ref {
#[inline(always)]
-fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
-znx_normalize_final_step_inplace_ref(basek, lsh, x, carry);
+fn znx_normalize_final_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
+znx_normalize_final_step_inplace_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeFirstStep for FFT64Ref {
#[inline(always)]
-fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
-znx_normalize_first_step_ref(basek, lsh, x, a, carry);
+fn znx_normalize_first_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
+znx_normalize_first_step_ref(base2k, lsh, x, a, carry);
}
}
impl ZnxNormalizeFirstStepCarryOnly for FFT64Ref {
#[inline(always)]
-fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
-znx_normalize_first_step_carry_only_ref(basek, lsh, x, carry);
+fn znx_normalize_first_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
+znx_normalize_first_step_carry_only_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeFirstStepInplace for FFT64Ref {
#[inline(always)]
-fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
-znx_normalize_first_step_inplace_ref(basek, lsh, x, carry);
+fn znx_normalize_first_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
+znx_normalize_first_step_inplace_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeMiddleStep for FFT64Ref {
#[inline(always)]
-fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
-znx_normalize_middle_step_ref(basek, lsh, x, a, carry);
+fn znx_normalize_middle_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
+znx_normalize_middle_step_ref(base2k, lsh, x, a, carry);
}
}
impl ZnxNormalizeMiddleStepCarryOnly for FFT64Ref {
#[inline(always)]
-fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
-znx_normalize_middle_step_carry_only_ref(basek, lsh, x, carry);
+fn znx_normalize_middle_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
+znx_normalize_middle_step_carry_only_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeMiddleStepInplace for FFT64Ref {
#[inline(always)]
-fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
-znx_normalize_middle_step_inplace_ref(basek, lsh, x, carry);
+fn znx_normalize_middle_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
+znx_normalize_middle_step_inplace_ref(base2k, lsh, x, carry);
}
}
+impl ZnxExtractDigitAddMul for FFT64Ref {
+#[inline(always)]
+fn znx_extract_digit_addmul(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) {
+znx_extract_digit_addmul_ref(base2k, lsh, res, src);
+}
+}
+impl ZnxNormalizeDigit for FFT64Ref {
+#[inline(always)]
+fn znx_normalize_digit(base2k: usize, res: &mut [i64], src: &mut [i64]) {
+znx_normalize_digit_ref(base2k, res, src);
+}
+}
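The new ZnxExtractDigitAddMul and ZnxNormalizeDigit primitives are the per-digit building blocks that cross-base2k normalization needs: peel a balanced base-2^k digit off a coefficient and propagate the carry. Their exact slice-level contracts are not shown in this diff; a hedged scalar sketch of the digit step they revolve around:

    // Split x into a balanced digit d in [-2^(k-1), 2^(k-1)) and a carry c
    // such that x == d + (c << k). Scalar illustration only.
    fn extract_balanced_digit(k: u32, x: i64) -> (i64, i64) {
        let base = 1i64 << k;
        let half = base >> 1;
        let mut d = x & (base - 1); // low k bits, in [0, 2^k)
        if d >= half {
            d -= base; // balance the digit
        }
        let c = (x - d) >> k; // exact: x - d is a multiple of 2^k
        (d, c)
    }

    fn main() {
        let (d, c) = extract_balanced_digit(4, 23); // 23 = 7 + (1 << 4)
        assert_eq!((d, c), (7, 1));
        let (d, c) = extract_balanced_digit(4, -9); // -9 = 7 + (-1 << 4)
        assert_eq!((d, c), (7, -1));
    }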