Add cross-basek normalization (#90)

* added cross-base2k normalization

* updated method signatures to take layouts

* fixed cross-base normalization

fix #91
fix #93
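
The feature renormalizes a `VecZnxBig` expressed in one base `2^k` into a `VecZnx` expressed in another, which is why `vec_znx_big_normalize` now takes separate `res_basek` and `a_basek` parameters (see the last hunk below). The following is a minimal sketch of the arithmetic for a single coefficient — hypothetical names, an `i128` accumulator standing in for the per-coefficient carry buffer the real code uses, and assuming a torus-style layout in which limb `i` carries weight `2^-((i+1)*k)` with the most-significant limb first:

// Sketch only: cross-base2k renormalization of one coefficient.
// Names and layout are assumptions, not the poulpy API.
fn cross_base_normalize(a_k: usize, a: &[i64], res_k: usize, res_len: usize) -> Vec<i64> {
    // Accumulate the encoded value as v / 2^(a_k * a.len()),
    // most-significant limb first.
    let mut v: i128 = 0;
    for &limb in a {
        v = (v << a_k) + limb as i128;
    }
    // Rescale so the same value reads as v / 2^(res_k * res_len).
    let (have, want) = (a_k * a.len(), res_k * res_len);
    if want > have { v <<= want - have } else { v >>= have - want }
    // Peel off balanced digits in [-2^(res_k-1), 2^(res_k-1)),
    // least-significant limb first, propagating carries upward.
    let mut out = vec![0i64; res_len];
    for i in (0..res_len).rev() {
        let mut d = v & ((1i128 << res_k) - 1);
        v >>= res_k;
        if d >= 1i128 << (res_k - 1) {
            d -= 1i128 << res_k;
            v += 1; // carry into the next (more significant) limb
        }
        out[i] = d as i64;
    }
    out
}

For example, two base-2^17 limbs (34 fractional bits) re-expressed as three base-2^12 limbs (36 fractional bits) keep the same value and simply gain two zero low-order bits.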
Author: Jean-Philippe Bossuat
Committed: 2025-09-30 14:40:10 +02:00 (via GitHub)
Parent: 4da790ea6a
Commit: 37e13b965c
216 changed files with 12481 additions and 7745 deletions


@@ -10,11 +10,12 @@ use poulpy_hal::{
         VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
         VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl,
         VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl,
-        VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl,
-        VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
+        VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubImpl, VecZnxBigSubInplaceImpl, VecZnxBigSubNegateInplaceImpl,
+        VecZnxBigSubSmallAImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallInplaceImpl, VecZnxBigSubSmallNegateInplaceImpl,
     },
     reference::{
-        vec_znx::vec_znx_add_normal_ref,
+        fft64::vec_znx_big::vec_znx_big_normalize,
+        vec_znx::{vec_znx_add_normal_ref, vec_znx_normalize_tmp_bytes},
         znx::{znx_copy_ref, znx_zero_ref},
     },
     source::Source,
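
The import churn above reflects one consistent rename, visible in the hunks that follow: the `*SubABInplace*` traits become `*SubInplace*` (`res = res - a`) and the `*SubBAInplace*` traits become `*SubNegateInplace*` (`res = a - res`, the negated difference). A plain-slice sketch of the two semantics (hypothetical helpers, not the poulpy API):

// Hypothetical helpers illustrating the two in-place subtraction flavours.
fn sub_inplace(res: &mut [i64], a: &[i64]) {
    // res <- res - a (formerly "sub_ab_inplace")
    for (r, &x) in res.iter_mut().zip(a) {
        *r -= x;
    }
}

fn sub_negate_inplace(res: &mut [i64], a: &[i64]) {
    // res <- a - res (formerly "sub_ba_inplace")
    for (r, &x) in res.iter_mut().zip(a) {
        *r = x - *r;
    }
}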
@@ -70,7 +71,7 @@ unsafe impl VecZnxBigFromSmallImpl<Self> for FFT64Spqlios {
 unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Spqlios {
     fn add_normal_impl<R: VecZnxBigToMut<Self>>(
         _module: &Module<Self>,
-        basek: usize,
+        base2k: usize,
         res: &mut R,
         res_col: usize,
         k: usize,
@@ -88,7 +89,7 @@ unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Spqlios {
             max_size: res.max_size,
         };
-        vec_znx_add_normal_ref(basek, &mut res_znx, res_col, k, sigma, bound, source);
+        vec_znx_add_normal_ref(base2k, &mut res_znx, res_col, k, sigma, bound, source);
     }
 }
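
`basek` is renamed to `base2k` throughout, making explicit that limbs are digits in base `2^k`. Under that reading, the `k` in `add_normal` is a bit precision: noise of weight `2^-k` lands in a specific limb. A sketch of the index arithmetic, assuming limb `i` carries weight `2^-((i+1)*base2k)` (an assumption about the convention, for illustration only):

// Illustration only: which limb a noise term of weight 2^-k touches,
// and how many leftover bits it must be pre-scaled by.
fn noise_limb(base2k: usize, k: usize) -> (usize, u32) {
    let limb = k.div_ceil(base2k) - 1;            // limb holding the 2^-k bit
    let scale = ((limb + 1) * base2k - k) as u32; // multiply noise by 2^scale
    (limb, scale)
}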
@@ -266,9 +267,9 @@ unsafe impl VecZnxBigSubImpl<Self> for FFT64Spqlios {
     }
 }

-unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64Spqlios {
+unsafe impl VecZnxBigSubInplaceImpl<Self> for FFT64Spqlios {
     /// Subtracts `a` from `res` and stores the result in `res` (`res -= a`).
-    fn vec_znx_big_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    fn vec_znx_big_sub_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
         R: VecZnxBigToMut<Self>,
         A: VecZnxBigToRef<Self>,
@@ -297,9 +298,9 @@ unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64Spqlios {
     }
 }

-unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64Spqlios {
+unsafe impl VecZnxBigSubNegateInplaceImpl<Self> for FFT64Spqlios {
     /// Subtracts `res` from `a` and stores the result in `res` (`res = a - res`).
-    fn vec_znx_big_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    fn vec_znx_big_sub_negate_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
         R: VecZnxBigToMut<Self>,
         A: VecZnxBigToRef<Self>,
@@ -370,9 +371,9 @@ unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64Spqlios {
     }
 }

-unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64Spqlios {
+unsafe impl VecZnxBigSubSmallInplaceImpl<Self> for FFT64Spqlios {
     /// Subtracts the small vector `a` from `res` and stores the result in `res` (`res -= a`).
-    fn vec_znx_big_sub_small_a_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    fn vec_znx_big_sub_small_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
         R: VecZnxBigToMut<Self>,
         A: VecZnxToRef,
@@ -443,9 +444,9 @@ unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64Spqlios {
     }
 }

-unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64Spqlios {
+unsafe impl VecZnxBigSubSmallNegateInplaceImpl<Self> for FFT64Spqlios {
     /// Subtracts `res` from the small vector `a` and stores the result in `res` (`res = a - res`).
-    fn vec_znx_big_sub_small_b_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    fn vec_znx_big_sub_small_negate_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
         R: VecZnxBigToMut<Self>,
         A: VecZnxToRef,
@@ -518,7 +519,7 @@ unsafe impl VecZnxBigNegateInplaceImpl<Self> for FFT64Spqlios {
 unsafe impl VecZnxBigNormalizeTmpBytesImpl<Self> for FFT64Spqlios {
     fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
-        unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr()) as usize }
+        vec_znx_normalize_tmp_bytes(module.n())
     }
 }
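
The temporary-buffer size is now computed on the Rust side from the ring degree instead of via FFI into spqlios. Given that the caller below reinterprets the byte count as a slice of `i64` carries, a plausible shape for the reference helper is one carry slot per coefficient — an assumption to check against `poulpy_hal`'s actual source, not the verified implementation:

// Assumed, not verified: one i64 carry per ring coefficient.
fn vec_znx_normalize_tmp_bytes(n: usize) -> usize {
    n * core::mem::size_of::<i64>()
}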
@@ -528,9 +529,10 @@ where
 {
     fn vec_znx_big_normalize_impl<R, A>(
         module: &Module<Self>,
-        basek: usize,
+        res_basek: usize,
         res: &mut R,
         res_col: usize,
+        a_basek: usize,
         a: &A,
         a_col: usize,
         scratch: &mut Scratch<Self>,
@@ -538,28 +540,21 @@ where
         R: VecZnxToMut,
         A: VecZnxBigToRef<Self>,
     {
-        let a: VecZnxBig<&[u8], Self> = a.to_ref();
-        let mut res: VecZnx<&mut [u8]> = res.to_mut();
-        #[cfg(debug_assertions)]
-        {
-            assert_eq!(res.n(), a.n());
-        }
-        let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes());
-        unsafe {
-            vec_znx::vec_znx_normalize_base2k(
-                module.ptr(),
-                basek as u64,
-                res.at_mut_ptr(res_col, 0),
-                res.size() as u64,
-                res.sl() as u64,
-                a.at_ptr(a_col, 0),
-                a.size() as u64,
-                a.sl() as u64,
-                tmp_bytes.as_mut_ptr(),
-            );
-        }
+        let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
+        // unsafe {
+        //     vec_znx::vec_znx_normalize_base2k(
+        //         module.ptr(),
+        //         base2k as u64,
+        //         res.at_mut_ptr(res_col, 0),
+        //         res.size() as u64,
+        //         res.sl() as u64,
+        //         a.at_ptr(a_col, 0),
+        //         a.size() as u64,
+        //         a.sl() as u64,
+        //         tmp_bytes.as_mut_ptr(),
+        //     );
+        // }
+        vec_znx_big_normalize(res_basek, res, res_col, a_basek, a, a_col, carry);
     }
 }
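
The `carry` slice taken from scratch explains the division by `size_of::<i64>()`: normalization sweeps the limbs from least to most significant while keeping one running carry per coefficient, and the split `res_basek`/`a_basek` parameters let the same call renormalize between bases in one pass. A single-column sketch under the same layout assumptions as above (limb 0 most significant; hypothetical code, not the reference implementation):

// Balanced base-2^k renormalization with an explicit carry buffer.
fn normalize_inplace(base2k: usize, limbs: &mut [Vec<i64>], carry: &mut [i64]) {
    let half = 1i64 << (base2k - 1);
    let full = 1i64 << base2k;
    carry.fill(0);
    // limbs[0] is most significant, so sweep in reverse.
    for limb in limbs.iter_mut().rev() {
        for (x, c) in limb.iter_mut().zip(carry.iter_mut()) {
            let t = *x + *c;
            let mut d = t % full; // remainder in (-2^base2k, 2^base2k)
            if d >= half {
                d -= full;
            } else if d < -half {
                d += full;
            }
            *x = d;                 // balanced digit in [-2^(base2k-1), 2^(base2k-1))
            *c = (t - d) >> base2k; // exact: t - d is a multiple of 2^base2k
        }
    }
}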