Add cross-basek normalization (#90)

* Added cross-base2k normalization

* Updated method signatures to take layouts

* Fixed cross-base normalization

Fixes #91
Fixes #93
Author: Jean-Philippe Bossuat
Date: 2025-09-30 14:40:10 +02:00
Committed by: GitHub
Parent: 4da790ea6a
Commit: 37e13b965c
216 changed files with 12481 additions and 7745 deletions
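
Before the per-file diffs, a minimal self-contained sketch of the idea behind the commit may help. Everything below is illustrative (hypothetical function names, `i128` instead of the limb-streaming `i64` kernels) and not the library's code: "cross-base2k" normalization re-encodes the same signed value from balanced base-2^ka limbs into balanced base-2^kr limbs.

```rust
/// Recompose most-significant-first limbs: x = sum limbs[i] * 2^(base2k*(len-1-i)).
fn recompose(base2k: usize, limbs: &[i128]) -> i128 {
    limbs.iter().fold(0i128, |acc, &d| (acc << base2k) + d)
}

/// Decompose x into `size` balanced digits in [-2^(base2k-1), 2^(base2k-1)).
fn decompose(base2k: usize, size: usize, mut x: i128) -> Vec<i128> {
    let b = 1i128 << base2k;
    let half = b >> 1;
    let mut limbs = vec![0i128; size];
    for i in (0..size).rev() {
        let mut d = x & (b - 1); // low base2k bits
        if d >= half {
            d -= b; // center the digit
        }
        limbs[i] = d;
        x = (x - d) >> base2k; // exact division: the remainder was just removed
    }
    limbs
}

fn main() {
    let x = 123_456_789i128;
    let a = decompose(12, 4, x); // encoding at base2k = 12
    let r = decompose(17, 3, recompose(12, &a)); // re-encode at base2k = 17
    assert_eq!(recompose(17, &r), x); // same value, different base-2^k layout
}
```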

View File

@@ -6,5 +6,6 @@ mod vec_znx_big;
mod vec_znx_dft;
mod vmp_pmat;
mod zn;
mod znx;
pub struct FFT64Spqlios;

View File

@@ -3,20 +3,11 @@ use std::ptr::NonNull;
use poulpy_hal::{
layouts::{Backend, Module},
oep::ModuleNewImpl,
reference::znx::{
ZnxCopy, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
ZnxRotate, ZnxSwitchRing, ZnxZero, znx_copy_ref, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref,
znx_normalize_first_step_carry_only_ref, znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref,
znx_normalize_middle_step_carry_only_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref,
znx_switch_ring_ref, znx_zero_ref,
},
};
use crate::cpu_spqlios::{
FFT64Spqlios,
ffi::module::{MODULE, delete_module_info, new_module_info},
znx::znx_rotate_i64,
};
impl Backend for FFT64Spqlios {
@@ -41,85 +32,3 @@ unsafe impl ModuleNewImpl<Self> for FFT64Spqlios {
unsafe { Module::from_raw_parts(new_module_info(n, 0), n) }
}
}
impl ZnxCopy for FFT64Spqlios {
fn znx_copy(res: &mut [i64], a: &[i64]) {
znx_copy_ref(res, a);
}
}
impl ZnxZero for FFT64Spqlios {
fn znx_zero(res: &mut [i64]) {
znx_zero_ref(res);
}
}
impl ZnxSwitchRing for FFT64Spqlios {
fn znx_switch_ring(res: &mut [i64], a: &[i64]) {
znx_switch_ring_ref(res, a);
}
}
impl ZnxRotate for FFT64Spqlios {
fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]) {
unsafe {
znx_rotate_i64(res.len() as u64, p, res.as_mut_ptr(), src.as_ptr());
}
}
}
impl ZnxNormalizeFinalStep for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_final_step_ref(basek, lsh, x, a, carry);
}
}
impl ZnxNormalizeFinalStepInplace for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_final_step_inplace_ref(basek, lsh, x, carry);
}
}
impl ZnxNormalizeFirstStep for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_ref(basek, lsh, x, a, carry);
}
}
impl ZnxNormalizeFirstStepCarryOnly for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_carry_only_ref(basek, lsh, x, carry);
}
}
impl ZnxNormalizeFirstStepInplace for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_first_step_inplace_ref(basek, lsh, x, carry);
}
}
impl ZnxNormalizeMiddleStep for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_ref(basek, lsh, x, a, carry);
}
}
impl ZnxNormalizeMiddleStepCarryOnly for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_carry_only_ref(basek, lsh, x, carry);
}
}
impl ZnxNormalizeMiddleStepInplace for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_middle_step_inplace_ref(basek, lsh, x, carry);
}
}

View File

@@ -253,9 +253,6 @@ fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]
(take_slice, rem_slice)
}
} else {
panic!(
"Attempted to take {} from scratch with {} aligned bytes left",
take_len, aligned_len,
);
panic!("Attempted to take {take_len} from scratch with {aligned_len} aligned bytes left");
}
}
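
The only change in this hunk is stylistic: the panic message now uses Rust 2021 inline format arguments. A simplified sketch of the same "take from scratch" pattern (alignment handling omitted; the real function aligns the start of the buffer first):

```rust
// Split off `take_len` bytes or panic with inline format args.
fn take_slice(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) {
    let aligned_len = data.len(); // the real code computes the aligned remainder
    if take_len <= aligned_len {
        data.split_at_mut(take_len)
    } else {
        panic!("Attempted to take {take_len} from scratch with {aligned_len} aligned bytes left");
    }
}
```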

View File

@@ -1,5 +1,8 @@
use poulpy_hal::{
api::{TakeSlice, VecZnxMergeRingsTmpBytes, VecZnxNormalizeTmpBytes, VecZnxSplitRingTmpBytes},
api::{
TakeSlice, VecZnxLshTmpBytes, VecZnxMergeRingsTmpBytes, VecZnxNormalizeTmpBytes, VecZnxRshTmpBytes,
VecZnxSplitRingTmpBytes,
},
layouts::{
Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut,
},
@@ -11,16 +14,16 @@ use poulpy_hal::{
VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl,
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl,
},
reference::{
vec_znx::{
vec_znx_add_normal_ref, vec_znx_automorphism_inplace_tmp_bytes, vec_znx_copy, vec_znx_fill_normal_ref,
vec_znx_fill_uniform_ref, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_lsh_tmp_bytes, vec_znx_merge_rings,
vec_znx_merge_rings_tmp_bytes, vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_rotate_inplace_tmp_bytes,
vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring, vec_znx_split_ring_tmp_bytes,
vec_znx_switch_ring,
vec_znx_merge_rings_tmp_bytes, vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_normalize_tmp_bytes,
vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring,
vec_znx_split_ring_tmp_bytes, vec_znx_switch_ring,
},
znx::{znx_copy_ref, znx_zero_ref},
},
@@ -34,7 +37,7 @@ use crate::cpu_spqlios::{
unsafe impl VecZnxNormalizeTmpBytesImpl<Self> for FFT64Spqlios {
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr() as *const module_info_t) as usize }
vec_znx_normalize_tmp_bytes(module.n())
}
}
@@ -44,9 +47,10 @@ where
{
fn vec_znx_normalize_impl<R, A>(
module: &Module<Self>,
basek: usize,
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<Self>,
@@ -60,6 +64,10 @@ where
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), a.n());
assert_eq!(
res_basek, a_basek,
"res_basek != a_basek -> base2k conversion is not supported"
)
}
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes());
@@ -67,7 +75,7 @@ where
unsafe {
vec_znx::vec_znx_normalize_base2k(
module.ptr() as *const module_info_t,
basek as u64,
res_basek as u64,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
@@ -86,7 +94,7 @@ where
{
fn vec_znx_normalize_inplace_impl<A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
a: &mut A,
a_col: usize,
scratch: &mut Scratch<Self>,
@@ -100,7 +108,7 @@ where
unsafe {
vec_znx::vec_znx_normalize_base2k(
module.ptr() as *const module_info_t,
basek as u64,
base2k as u64,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
@@ -301,8 +309,8 @@ unsafe impl VecZnxSubImpl<Self> for FFT64Spqlios {
}
}
unsafe impl VecZnxSubABInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxSubInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_sub_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
@@ -330,8 +338,8 @@ unsafe impl VecZnxSubABInplaceImpl<Self> for FFT64Spqlios {
}
}
unsafe impl VecZnxSubBAInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxSubNegateInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_sub_negate_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
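
The `SubABInplace`/`SubBAInplace` traits are renamed to `SubInplace`/`SubNegateInplace` here. Going by the doc comments on the big-vector variants later in this commit, the intended semantics are as follows; this is an illustrative sketch on flat slices, not the layout-aware backend code:

```rust
fn sub_inplace(res: &mut [i64], a: &[i64]) {
    // res <- res - a   (formerly `sub_ab_inplace`)
    for (r, &x) in res.iter_mut().zip(a) {
        *r -= x;
    }
}

fn sub_negate_inplace(res: &mut [i64], a: &[i64]) {
    // res <- a - res   (formerly `sub_ba_inplace`), i.e. -(res - a)
    for (r, &x) in res.iter_mut().zip(a) {
        *r = x - *r;
    }
}

fn main() {
    let (mut r1, mut r2) = ([5i64, 7], [5i64, 7]);
    sub_inplace(&mut r1, &[2, 3]); // [3, 4]
    sub_negate_inplace(&mut r2, &[2, 3]); // [-3, -4]
    assert_eq!(r1, [3, 4]);
    assert_eq!(r2, [-3, -4]);
}
```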
@@ -512,9 +520,9 @@ where
Module<Self>: VecZnxNormalizeTmpBytes,
Scratch<Self>: TakeSlice,
{
fn vec_znx_lsh_inplace_impl<R, A>(
fn vec_znx_lsh_impl<R, A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -525,8 +533,8 @@ where
R: VecZnxToMut,
A: VecZnxToRef,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_lsh::<_, _, FFT64Spqlios>(basek, k, res, res_col, a, a_col, carry)
let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::<i64>());
vec_znx_lsh::<_, _, FFT64Spqlios>(base2k, k, res, res_col, a, a_col, carry)
}
}
@@ -537,7 +545,7 @@ where
{
fn vec_znx_lsh_inplace_impl<A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
a: &mut A,
a_col: usize,
@@ -545,8 +553,8 @@ where
) where
A: VecZnxToMut,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_lsh_inplace::<_, FFT64Spqlios>(basek, k, a, a_col, carry)
let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::<i64>());
vec_znx_lsh_inplace::<_, FFT64Spqlios>(base2k, k, a, a_col, carry)
}
}
@@ -555,9 +563,9 @@ where
Module<Self>: VecZnxNormalizeTmpBytes,
Scratch<Self>: TakeSlice,
{
fn vec_znx_rsh_inplace_impl<R, A>(
fn vec_znx_rsh_impl<R, A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -568,8 +576,8 @@ where
R: VecZnxToMut,
A: VecZnxToRef,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_rsh::<_, _, FFT64Spqlios>(basek, k, res, res_col, a, a_col, carry)
let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::<i64>());
vec_znx_rsh::<_, _, FFT64Spqlios>(base2k, k, res, res_col, a, a_col, carry)
}
}
@@ -580,7 +588,7 @@ where
{
fn vec_znx_rsh_inplace_impl<A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
a: &mut A,
a_col: usize,
@@ -588,8 +596,8 @@ where
) where
A: VecZnxToMut,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_rsh_inplace::<_, FFT64Spqlios>(basek, k, a, a_col, carry)
let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::<i64>());
vec_znx_rsh_inplace::<_, FFT64Spqlios>(base2k, k, a, a_col, carry)
}
}
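
These hunks switch the scratch request from the generic normalize query to the operation-specific `vec_znx_lsh_tmp_bytes`/`vec_znx_rsh_tmp_bytes`. In each case the byte count is turned into an `i64` carry slice; a sketch of that sizing convention (the exact tmp-bytes formula is an assumption):

```rust
use std::mem::size_of;

// Assumed convention: the shift carry buffer holds one i64 per coefficient,
// so tmp-bytes is a multiple of size_of::<i64>() and the slice length is the
// byte count divided by 8.
fn carry_len(tmp_bytes: usize) -> usize {
    debug_assert_eq!(tmp_bytes % size_of::<i64>(), 0);
    tmp_bytes / size_of::<i64>()
}

fn main() {
    let n = 1 << 5; // ring degree used by the cross-backend tests below
    let tmp_bytes = n * size_of::<i64>(); // plausible lsh/rsh tmp-bytes value
    assert_eq!(carry_len(tmp_bytes), n);
}
```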
@@ -690,11 +698,7 @@ unsafe impl VecZnxAutomorphismInplaceImpl<Self> for FFT64Spqlios {
let mut a: VecZnx<&mut [u8]> = a.to_mut();
#[cfg(debug_assertions)]
{
assert!(
k & 1 != 0,
"invalid galois element: must be odd but is {}",
k
);
assert!(k & 1 != 0, "invalid galois element: must be odd but is {k}");
}
unsafe {
vec_znx::vec_znx_automorphism(
@@ -852,18 +856,18 @@ unsafe impl VecZnxCopyImpl<Self> for FFT64Spqlios {
}
unsafe impl VecZnxFillUniformImpl<Self> for FFT64Spqlios {
fn vec_znx_fill_uniform_impl<R>(_module: &Module<Self>, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
fn vec_znx_fill_uniform_impl<R>(_module: &Module<Self>, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut,
{
vec_znx_fill_uniform_ref(basek, res, res_col, source)
vec_znx_fill_uniform_ref(base2k, res, res_col, source)
}
}
unsafe impl VecZnxFillNormalImpl<Self> for FFT64Spqlios {
fn vec_znx_fill_normal_impl<R>(
_module: &Module<Self>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -873,14 +877,14 @@ unsafe impl VecZnxFillNormalImpl<Self> for FFT64Spqlios {
) where
R: VecZnxToMut,
{
vec_znx_fill_normal_ref(basek, res, res_col, k, sigma, bound, source);
vec_znx_fill_normal_ref(base2k, res, res_col, k, sigma, bound, source);
}
}
unsafe impl VecZnxAddNormalImpl<Self> for FFT64Spqlios {
fn vec_znx_add_normal_impl<R>(
_module: &Module<Self>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -890,6 +894,6 @@ unsafe impl VecZnxAddNormalImpl<Self> for FFT64Spqlios {
) where
R: VecZnxToMut,
{
vec_znx_add_normal_ref(basek, res, res_col, k, sigma, bound, source);
vec_znx_add_normal_ref(base2k, res, res_col, k, sigma, bound, source);
}
}
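
`vec_znx_fill_uniform` and the normal-noise samplers now name the base parameter `base2k`. As a rough mental model (an assumption about the reference sampler, shown with a `rand` 0.8-style API in place of the crate's own `Source`), filling uniformly means drawing every limb from the centered digit range:

```rust
use rand::Rng; // illustrative dependency; the crate uses its own `Source`

// Assumed per-coefficient behaviour of a uniform fill: each base-2^k limb is
// drawn uniformly from [-2^(base2k-1), 2^(base2k-1)).
fn fill_uniform(base2k: usize, limbs: &mut [i64], rng: &mut impl Rng) {
    let half = 1i64 << (base2k - 1);
    for d in limbs.iter_mut() {
        *d = rng.gen_range(-half..half);
    }
}
```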

View File

@@ -10,11 +10,12 @@ use poulpy_hal::{
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl,
VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl,
VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl,
VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubImpl, VecZnxBigSubInplaceImpl, VecZnxBigSubNegateInplaceImpl,
VecZnxBigSubSmallAImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallInplaceImpl, VecZnxBigSubSmallNegateInplaceImpl,
},
reference::{
vec_znx::vec_znx_add_normal_ref,
fft64::vec_znx_big::vec_znx_big_normalize,
vec_znx::{vec_znx_add_normal_ref, vec_znx_normalize_tmp_bytes},
znx::{znx_copy_ref, znx_zero_ref},
},
source::Source,
@@ -70,7 +71,7 @@ unsafe impl VecZnxBigFromSmallImpl<Self> for FFT64Spqlios {
unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Spqlios {
fn add_normal_impl<R: VecZnxBigToMut<Self>>(
_module: &Module<Self>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -88,7 +89,7 @@ unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Spqlios {
max_size: res.max_size,
};
vec_znx_add_normal_ref(basek, &mut res_znx, res_col, k, sigma, bound, source);
vec_znx_add_normal_ref(base2k, &mut res_znx, res_col, k, sigma, bound, source);
}
}
@@ -266,9 +267,9 @@ unsafe impl VecZnxBigSubImpl<Self> for FFT64Spqlios {
}
}
unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64Spqlios {
unsafe impl VecZnxBigSubInplaceImpl<Self> for FFT64Spqlios {
/// Subtracts `a` from `res` and stores the result in `res` (`res -= a`).
fn vec_znx_big_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
@@ -297,9 +298,9 @@ unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64Spqlios {
}
}
unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64Spqlios {
unsafe impl VecZnxBigSubNegateInplaceImpl<Self> for FFT64Spqlios {
/// Subtracts `res` from `a` and stores the result in `res` (`res = a - res`).
fn vec_znx_big_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_negate_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
@@ -370,9 +371,9 @@ unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64Spqlios {
}
}
unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64Spqlios {
unsafe impl VecZnxBigSubSmallInplaceImpl<Self> for FFT64Spqlios {
/// Subtracts `a` from `res` and stores the result in `res` (`res -= a`).
fn vec_znx_big_sub_small_a_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_small_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
@@ -443,9 +444,9 @@ unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64Spqlios {
}
}
unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64Spqlios {
unsafe impl VecZnxBigSubSmallNegateInplaceImpl<Self> for FFT64Spqlios {
/// Subtracts `res` from `a` and stores the result in `res` (`res = a - res`).
fn vec_znx_big_sub_small_b_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_small_negate_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
@@ -518,7 +519,7 @@ unsafe impl VecZnxBigNegateInplaceImpl<Self> for FFT64Spqlios {
unsafe impl VecZnxBigNormalizeTmpBytesImpl<Self> for FFT64Spqlios {
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr()) as usize }
vec_znx_normalize_tmp_bytes(module.n())
}
}
@@ -528,9 +529,10 @@ where
{
fn vec_znx_big_normalize_impl<R, A>(
module: &Module<Self>,
basek: usize,
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<Self>,
@@ -538,28 +540,21 @@ where
R: VecZnxToMut,
A: VecZnxBigToRef<Self>,
{
let a: VecZnxBig<&[u8], Self> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), a.n());
}
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes());
unsafe {
vec_znx::vec_znx_normalize_base2k(
module.ptr(),
basek as u64,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
tmp_bytes.as_mut_ptr(),
);
}
let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_big_normalize(res_basek, res, res_col, a_basek, a, a_col, carry);
}
}
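
This is the heart of the change: the raw FFI normalization (which only understood a single `basek`) is replaced by the portable `vec_znx_big_normalize`, which takes separate `res_basek` and `a_basek`. A per-limb sketch of the single-base case it generalizes, using the one-`i64`-per-coefficient carry buffer taken above (the details are assumptions; the real code runs the first/middle/final-step kernels shown later in this commit):

```rust
// Sketch: stream limbs from least to most significant, replacing each digit
// by its centered remainder mod 2^base2k and propagating the quotient upward
// through `carry` (one i64 per coefficient).
fn normalize_limbs(base2k: usize, limbs: &mut [Vec<i64>], carry: &mut [i64]) {
    let b = 1i64 << base2k;
    let half = b >> 1;
    carry.fill(0); // "first step": no incoming carry yet
    for limb in limbs.iter_mut().rev() { // limbs[0] is most significant
        for (x, c) in limb.iter_mut().zip(carry.iter_mut()) {
            let v = *x + *c; // sketch only: real kernels guard against overflow
            let mut d = v.rem_euclid(b); // digit in [0, 2^base2k)
            if d >= half {
                d -= b; // center to [-2^(base2k-1), 2^(base2k-1))
            }
            *x = d;
            *c = (v - d) >> base2k; // exact quotient feeds the next limb up
        }
    }
    // "final step": what happens to leftover carry (fold into or drop from the
    // top limb) depends on the fixed-point convention -- an assumption here.
}

fn main() {
    // Two limbs at base 2^4: value 1*16 + 20 = 36 -> digits [2, 4] (2*16 + 4).
    let mut limbs = vec![vec![1i64], vec![20]];
    let mut carry = vec![0i64; 1];
    normalize_limbs(4, &mut limbs, &mut carry);
    assert_eq!((limbs[0][0], limbs[1][0]), (2, 4));
}
```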

View File

@@ -6,7 +6,7 @@ use poulpy_hal::{
},
oep::{
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl,
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl,
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl,
VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
},
reference::{
@@ -336,8 +336,8 @@ unsafe impl VecZnxDftSubImpl<Self> for FFT64Spqlios {
}
}
unsafe impl VecZnxDftSubABInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_dft_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxDftSubInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_dft_sub_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
@@ -363,8 +363,8 @@ unsafe impl VecZnxDftSubABInplaceImpl<Self> for FFT64Spqlios {
}
}
unsafe impl VecZnxDftSubBAInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_dft_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxDftSubNegateInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_dft_sub_negate_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,

View File

@@ -12,7 +12,7 @@ unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64Spqlios
where
Self: TakeSliceImpl<Self>,
{
fn zn_normalize_inplace_impl<A>(n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<Self>)
fn zn_normalize_inplace_impl<A>(n: usize, base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<Self>)
where
A: ZnToMut,
{
@@ -23,7 +23,7 @@ where
unsafe {
zn64::zn64_normalize_base2k_ref(
n as u64,
basek as u64,
base2k as u64,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
@@ -37,11 +37,11 @@ where
}
unsafe impl ZnFillUniformImpl<Self> for FFT64Spqlios {
fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
fn zn_fill_uniform_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
zn_fill_uniform(n, basek, res, res_col, source);
zn_fill_uniform(n, base2k, res, res_col, source);
}
}
@@ -49,7 +49,7 @@ unsafe impl ZnFillNormalImpl<Self> for FFT64Spqlios {
#[allow(clippy::too_many_arguments)]
fn zn_fill_normal_impl<R>(
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -59,7 +59,7 @@ unsafe impl ZnFillNormalImpl<Self> for FFT64Spqlios {
) where
R: ZnToMut,
{
zn_fill_normal(n, basek, res, res_col, k, source, sigma, bound);
zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}
@@ -67,7 +67,7 @@ unsafe impl ZnAddNormalImpl<Self> for FFT64Spqlios {
#[allow(clippy::too_many_arguments)]
fn zn_add_normal_impl<R>(
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -77,6 +77,6 @@ unsafe impl ZnAddNormalImpl<Self> for FFT64Spqlios {
) where
R: ZnToMut,
{
zn_add_normal(n, basek, res, res_col, k, source, sigma, bound);
zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}

View File

@@ -0,0 +1,189 @@
use poulpy_hal::reference::znx::{
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulAddPowerOfTwo, ZnxMulPowerOfTwo,
ZnxMulPowerOfTwoInplace, ZnxNegate, ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace,
ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep,
ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxRotate, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace,
ZnxSwitchRing, ZnxZero, znx_add_inplace_ref, znx_add_ref, znx_automorphism_ref, znx_copy_ref, znx_extract_digit_addmul_ref,
znx_mul_add_power_of_two_ref, znx_mul_power_of_two_inplace_ref, znx_mul_power_of_two_ref, znx_negate_inplace_ref,
znx_negate_ref, znx_normalize_digit_ref, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref,
znx_normalize_first_step_carry_only_ref, znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref,
znx_normalize_middle_step_carry_only_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref, znx_rotate,
znx_sub_inplace_ref, znx_sub_negate_inplace_ref, znx_sub_ref, znx_switch_ring_ref, znx_zero_ref,
};
use crate::FFT64Spqlios;
impl ZnxAdd for FFT64Spqlios {
#[inline(always)]
fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]) {
znx_add_ref(res, a, b);
}
}
impl ZnxAddInplace for FFT64Spqlios {
#[inline(always)]
fn znx_add_inplace(res: &mut [i64], a: &[i64]) {
znx_add_inplace_ref(res, a);
}
}
impl ZnxSub for FFT64Spqlios {
#[inline(always)]
fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]) {
znx_sub_ref(res, a, b);
}
}
impl ZnxSubInplace for FFT64Spqlios {
#[inline(always)]
fn znx_sub_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_inplace_ref(res, a);
}
}
impl ZnxSubNegateInplace for FFT64Spqlios {
#[inline(always)]
fn znx_sub_negate_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_negate_inplace_ref(res, a);
}
}
impl ZnxMulAddPowerOfTwo for FFT64Spqlios {
#[inline(always)]
fn znx_muladd_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
znx_mul_add_power_of_two_ref(k, res, a);
}
}
impl ZnxMulPowerOfTwo for FFT64Spqlios {
#[inline(always)]
fn znx_mul_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
znx_mul_power_of_two_ref(k, res, a);
}
}
impl ZnxMulPowerOfTwoInplace for FFT64Spqlios {
#[inline(always)]
fn znx_mul_power_of_two_inplace(k: i64, res: &mut [i64]) {
znx_mul_power_of_two_inplace_ref(k, res);
}
}
impl ZnxAutomorphism for FFT64Spqlios {
#[inline(always)]
fn znx_automorphism(p: i64, res: &mut [i64], a: &[i64]) {
znx_automorphism_ref(p, res, a);
}
}
impl ZnxCopy for FFT64Spqlios {
#[inline(always)]
fn znx_copy(res: &mut [i64], a: &[i64]) {
znx_copy_ref(res, a);
}
}
impl ZnxNegate for FFT64Spqlios {
#[inline(always)]
fn znx_negate(res: &mut [i64], src: &[i64]) {
znx_negate_ref(res, src);
}
}
impl ZnxNegateInplace for FFT64Spqlios {
#[inline(always)]
fn znx_negate_inplace(res: &mut [i64]) {
znx_negate_inplace_ref(res);
}
}
impl ZnxRotate for FFT64Spqlios {
#[inline(always)]
fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]) {
znx_rotate::<Self>(p, res, src);
}
}
impl ZnxZero for FFT64Spqlios {
#[inline(always)]
fn znx_zero(res: &mut [i64]) {
znx_zero_ref(res);
}
}
impl ZnxSwitchRing for FFT64Spqlios {
#[inline(always)]
fn znx_switch_ring(res: &mut [i64], a: &[i64]) {
znx_switch_ring_ref(res, a);
}
}
impl ZnxNormalizeFinalStep for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_final_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_final_step_ref(base2k, lsh, x, a, carry);
}
}
impl ZnxNormalizeFinalStepInplace for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_final_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_final_step_inplace_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeFirstStep for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_first_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_ref(base2k, lsh, x, a, carry);
}
}
impl ZnxNormalizeFirstStepCarryOnly for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_first_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_carry_only_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeFirstStepInplace for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_first_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_first_step_inplace_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeMiddleStep for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_middle_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_ref(base2k, lsh, x, a, carry);
}
}
impl ZnxNormalizeMiddleStepCarryOnly for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_middle_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_carry_only_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeMiddleStepInplace for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_middle_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_middle_step_inplace_ref(base2k, lsh, x, carry);
}
}
impl ZnxExtractDigitAddMul for FFT64Spqlios {
#[inline(always)]
fn znx_extract_digit_addmul(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) {
znx_extract_digit_addmul_ref(base2k, lsh, res, src);
}
}
impl ZnxNormalizeDigit for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_digit(base2k: usize, res: &mut [i64], src: &mut [i64]) {
znx_normalize_digit_ref(base2k, res, src);
}
}
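
This new file is pure delegation: every `Znx*` trait is implemented for `FFT64Spqlios` by forwarding to the shared portable `_ref` routine, which lets the backend later override individual hot paths with vectorized kernels. A minimal reproduction of the pattern (hypothetical trait and function names, not the HAL's):

```rust
trait ZnxAddExample {
    fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]);
}

// Shared reference implementation, usable by every backend.
fn znx_add_ref(res: &mut [i64], a: &[i64], b: &[i64]) {
    for ((r, &x), &y) in res.iter_mut().zip(a).zip(b) {
        *r = x + y;
    }
}

struct MyBackend;

impl ZnxAddExample for MyBackend {
    #[inline(always)] // forwarding is free after inlining
    fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]) {
        znx_add_ref(res, a, b);
    }
}
```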

View File

@@ -5,15 +5,15 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_spqlios::FFT64Spqlios,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_vec_znx_add => poulpy_hal::test_suite::vec_znx::test_vec_znx_add,
test_vec_znx_add_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_inplace,
test_vec_znx_add_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar,
test_vec_znx_add_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar_inplace,
test_vec_znx_sub => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub,
test_vec_znx_sub_ab_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ab_inplace,
test_vec_znx_sub_ba_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ba_inplace,
test_vec_znx_sub_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_inplace,
test_vec_znx_sub_negate_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_negate_inplace,
test_vec_znx_sub_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar,
test_vec_znx_sub_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar_inplace,
test_vec_znx_rsh => poulpy_hal::test_suite::vec_znx::test_vec_znx_rsh,
@@ -41,7 +41,7 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_spqlios::FFT64Spqlios,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_svp_apply_dft_to_dft => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft,
test_svp_apply_dft_to_dft_inplace => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft_inplace,
@@ -53,20 +53,20 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_spqlios::FFT64Spqlios,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_vec_znx_big_add => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add,
test_vec_znx_big_add_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_inplace,
test_vec_znx_big_add_small => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small,
test_vec_znx_big_add_small_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small_inplace,
test_vec_znx_big_sub => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub,
test_vec_znx_big_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ab_inplace,
test_vec_znx_big_sub_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_inplace,
test_vec_znx_big_automorphism => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism,
test_vec_znx_big_automorphism_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism_inplace,
test_vec_znx_big_negate => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate,
test_vec_znx_big_negate_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate_inplace,
test_vec_znx_big_normalize => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_normalize,
test_vec_znx_big_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ba_inplace,
test_vec_znx_big_sub_negate_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_negate_inplace,
test_vec_znx_big_sub_small_a => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a,
test_vec_znx_big_sub_small_a_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a_inplace,
test_vec_znx_big_sub_small_b => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_b,
@@ -79,13 +79,13 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_spqlios::FFT64Spqlios,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_vec_znx_dft_add => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add,
test_vec_znx_dft_add_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add_inplace,
test_vec_znx_dft_sub => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub,
test_vec_znx_dft_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ab_inplace,
test_vec_znx_dft_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ba_inplace,
test_vec_znx_dft_sub_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_inplace,
test_vec_znx_dft_sub_negate_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_negate_inplace,
test_vec_znx_idft_apply => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply,
test_vec_znx_idft_apply_consume => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_consume,
test_vec_znx_idft_apply_tmpa => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_tmpa,
@@ -97,7 +97,7 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_spqlios::FFT64Spqlios,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_vmp_apply_dft_to_dft => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft,
test_vmp_apply_dft_to_dft_add => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft_add,