Add cross-basek normalization (#90)

* added cross_basek_normalization

* updated method signatures to take layouts

* fixed cross-base normalization

fix #91
fix #93
Jean-Philippe Bossuat
2025-09-30 14:40:10 +02:00
committed by GitHub
parent 4da790ea6a
commit 37e13b965c
216 changed files with 12481 additions and 7745 deletions
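Editorial note on what "cross-base2k" means in this commit: the normalization entry points now take separate `res_base2k` and `a_base2k` limb widths (see the `vec_znx_normalize` and `vec_znx_big_normalize` hunks below), so a vector encoded with one base-2^k limb width can be re-normalized into another. The following is a minimal, hypothetical scalar sketch of that idea, with the least-significant limb first and the same signed digit/carry rule used by the AVX kernels further down (digit = sign-extended low `base2k` bits, carry = the remainder shifted down); it is an illustration only, not the crate's layout or API.

```rust
/// Decompose a value into signed base-2^k digits, least-significant limb first.
/// digit ∈ [-2^(base2k-1), 2^(base2k-1)); carry = (x - digit) >> base2k (exact).
fn to_signed_base2k(mut x: i128, base2k: u32, len: usize) -> Vec<i64> {
    let full = 1i128 << base2k;
    let half = full >> 1;
    let mut digits = Vec::with_capacity(len);
    for _ in 0..len {
        let mut d = x & (full - 1); // low base2k bits
        if d >= half {
            d -= full; // sign-extend into [-2^(base2k-1), 2^(base2k-1))
        }
        digits.push(d as i64);
        x = (x - d) >> base2k; // carry toward the next limb
    }
    digits
}

/// Recombine signed base-2^k digits back into a single value.
fn from_signed_base2k(digits: &[i64], base2k: u32) -> i128 {
    digits
        .iter()
        .enumerate()
        .map(|(i, &d)| (d as i128) << (i as u32 * base2k))
        .sum()
}

/// "Cross-base2k" normalization, conceptually: recombine under `a_base2k`,
/// then re-decompose under `res_base2k`.
fn cross_base2k(a_base2k: u32, res_base2k: u32, a: &[i64], res_len: usize) -> Vec<i64> {
    to_signed_base2k(from_signed_base2k(a, a_base2k), res_base2k, res_len)
}
```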

View File

@@ -7,7 +7,7 @@ use poulpy_hal::{
fft64::{
reim::{
ReimAdd, ReimAddInplace, ReimAddMul, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimMul,
ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubABInplace, ReimSubBAInplace, ReimToZnx,
ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubInplace, ReimSubNegateInplace, ReimToZnx,
ReimToZnxInplace, ReimZero, reim_copy_ref, reim_zero_ref,
},
reim4::{
@@ -15,10 +15,11 @@ use poulpy_hal::{
},
},
znx::{
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulAddPowerOfTwo, ZnxMulPowerOfTwo,
ZnxMulPowerOfTwoInplace, ZnxNegate, ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep,
ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace,
ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxRotate, ZnxSub,
ZnxSubABInplace, ZnxSubBAInplace, ZnxSwitchRing, ZnxZero, znx_copy_ref, znx_rotate, znx_zero_ref,
ZnxSubInplace, ZnxSubNegateInplace, ZnxSwitchRing, ZnxZero, znx_copy_ref, znx_rotate, znx_zero_ref,
},
},
};
@@ -27,8 +28,8 @@ use crate::cpu_fft64_avx::{
FFT64Avx,
reim::{
ReimFFTAvx, ReimIFFTAvx, reim_add_avx2_fma, reim_add_inplace_avx2_fma, reim_addmul_avx2_fma, reim_from_znx_i64_bnd50_fma,
reim_mul_avx2_fma, reim_mul_inplace_avx2_fma, reim_negate_avx2_fma, reim_negate_inplace_avx2_fma,
reim_sub_ab_inplace_avx2_fma, reim_sub_avx2_fma, reim_sub_ba_inplace_avx2_fma, reim_to_znx_i64_inplace_bnd63_avx2_fma,
reim_mul_avx2_fma, reim_mul_inplace_avx2_fma, reim_negate_avx2_fma, reim_negate_inplace_avx2_fma, reim_sub_avx2_fma,
reim_sub_inplace_avx2_fma, reim_sub_negate_inplace_avx2_fma, reim_to_znx_i64_inplace_bnd63_avx2_fma,
},
reim_to_znx_i64_bnd63_avx2_fma,
reim4::{
@@ -36,11 +37,12 @@ use crate::cpu_fft64_avx::{
reim4_vec_mat1col_product_avx, reim4_vec_mat2cols_2ndcol_product_avx, reim4_vec_mat2cols_product_avx,
},
znx_avx::{
znx_add_avx, znx_add_inplace_avx, znx_automorphism_avx, znx_negate_avx, znx_negate_inplace_avx,
znx_normalize_final_step_avx, znx_normalize_final_step_inplace_avx, znx_normalize_first_step_avx,
znx_normalize_first_step_carry_only_avx, znx_normalize_first_step_inplace_avx, znx_normalize_middle_step_avx,
znx_normalize_middle_step_carry_only_avx, znx_normalize_middle_step_inplace_avx, znx_sub_ab_inplace_avx, znx_sub_avx,
znx_sub_ba_inplace_avx, znx_switch_ring_avx,
znx_add_avx, znx_add_inplace_avx, znx_automorphism_avx, znx_extract_digit_addmul_avx, znx_mul_add_power_of_two_avx,
znx_mul_power_of_two_avx, znx_mul_power_of_two_inplace_avx, znx_negate_avx, znx_negate_inplace_avx,
znx_normalize_digit_avx, znx_normalize_final_step_avx, znx_normalize_final_step_inplace_avx,
znx_normalize_first_step_avx, znx_normalize_first_step_carry_only_avx, znx_normalize_first_step_inplace_avx,
znx_normalize_middle_step_avx, znx_normalize_middle_step_carry_only_avx, znx_normalize_middle_step_inplace_avx,
znx_sub_avx, znx_sub_inplace_avx, znx_sub_negate_inplace_avx, znx_switch_ring_avx,
},
};
@@ -131,20 +133,20 @@ impl ZnxSub for FFT64Avx {
}
}
impl ZnxSubABInplace for FFT64Avx {
impl ZnxSubInplace for FFT64Avx {
#[inline(always)]
fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]) {
fn znx_sub_inplace(res: &mut [i64], a: &[i64]) {
unsafe {
znx_sub_ab_inplace_avx(res, a);
znx_sub_inplace_avx(res, a);
}
}
}
impl ZnxSubBAInplace for FFT64Avx {
impl ZnxSubNegateInplace for FFT64Avx {
#[inline(always)]
fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]) {
fn znx_sub_negate_inplace(res: &mut [i64], a: &[i64]) {
unsafe {
znx_sub_ba_inplace_avx(res, a);
znx_sub_negate_inplace_avx(res, a);
}
}
}
@@ -183,6 +185,33 @@ impl ZnxNegateInplace for FFT64Avx {
}
}
impl ZnxMulAddPowerOfTwo for FFT64Avx {
#[inline(always)]
fn znx_muladd_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
unsafe {
znx_mul_add_power_of_two_avx(k, res, a);
}
}
}
impl ZnxMulPowerOfTwo for FFT64Avx {
#[inline(always)]
fn znx_mul_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
unsafe {
znx_mul_power_of_two_avx(k, res, a);
}
}
}
impl ZnxMulPowerOfTwoInplace for FFT64Avx {
#[inline(always)]
fn znx_mul_power_of_two_inplace(k: i64, res: &mut [i64]) {
unsafe {
znx_mul_power_of_two_inplace_avx(k, res);
}
}
}
impl ZnxRotate for FFT64Avx {
#[inline(always)]
fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]) {
@@ -208,72 +237,90 @@ impl ZnxSwitchRing for FFT64Avx {
impl ZnxNormalizeFinalStep for FFT64Avx {
#[inline(always)]
fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
fn znx_normalize_final_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
unsafe {
znx_normalize_final_step_avx(basek, lsh, x, a, carry);
znx_normalize_final_step_avx(base2k, lsh, x, a, carry);
}
}
}
impl ZnxNormalizeFinalStepInplace for FFT64Avx {
#[inline(always)]
fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
fn znx_normalize_final_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
unsafe {
znx_normalize_final_step_inplace_avx(basek, lsh, x, carry);
znx_normalize_final_step_inplace_avx(base2k, lsh, x, carry);
}
}
}
impl ZnxNormalizeFirstStep for FFT64Avx {
#[inline(always)]
fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
fn znx_normalize_first_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
unsafe {
znx_normalize_first_step_avx(basek, lsh, x, a, carry);
znx_normalize_first_step_avx(base2k, lsh, x, a, carry);
}
}
}
impl ZnxNormalizeFirstStepCarryOnly for FFT64Avx {
#[inline(always)]
fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
fn znx_normalize_first_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
unsafe {
znx_normalize_first_step_carry_only_avx(basek, lsh, x, carry);
znx_normalize_first_step_carry_only_avx(base2k, lsh, x, carry);
}
}
}
impl ZnxNormalizeFirstStepInplace for FFT64Avx {
#[inline(always)]
fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
fn znx_normalize_first_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
unsafe {
znx_normalize_first_step_inplace_avx(basek, lsh, x, carry);
znx_normalize_first_step_inplace_avx(base2k, lsh, x, carry);
}
}
}
impl ZnxNormalizeMiddleStep for FFT64Avx {
#[inline(always)]
fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
fn znx_normalize_middle_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
unsafe {
znx_normalize_middle_step_avx(basek, lsh, x, a, carry);
znx_normalize_middle_step_avx(base2k, lsh, x, a, carry);
}
}
}
impl ZnxNormalizeMiddleStepCarryOnly for FFT64Avx {
#[inline(always)]
fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
fn znx_normalize_middle_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
unsafe {
znx_normalize_middle_step_carry_only_avx(basek, lsh, x, carry);
znx_normalize_middle_step_carry_only_avx(base2k, lsh, x, carry);
}
}
}
impl ZnxNormalizeMiddleStepInplace for FFT64Avx {
#[inline(always)]
fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
fn znx_normalize_middle_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
unsafe {
znx_normalize_middle_step_inplace_avx(basek, lsh, x, carry);
znx_normalize_middle_step_inplace_avx(base2k, lsh, x, carry);
}
}
}
impl ZnxExtractDigitAddMul for FFT64Avx {
#[inline(always)]
fn znx_extract_digit_addmul(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) {
unsafe {
znx_extract_digit_addmul_avx(base2k, lsh, res, src);
}
}
}
impl ZnxNormalizeDigit for FFT64Avx {
#[inline(always)]
fn znx_normalize_digit(base2k: usize, res: &mut [i64], src: &mut [i64]) {
unsafe {
znx_normalize_digit_avx(base2k, res, src);
}
}
}
@@ -346,20 +393,20 @@ impl ReimSub for FFT64Avx {
}
}
impl ReimSubABInplace for FFT64Avx {
impl ReimSubInplace for FFT64Avx {
#[inline(always)]
fn reim_sub_ab_inplace(res: &mut [f64], a: &[f64]) {
fn reim_sub_inplace(res: &mut [f64], a: &[f64]) {
unsafe {
reim_sub_ab_inplace_avx2_fma(res, a);
reim_sub_inplace_avx2_fma(res, a);
}
}
}
impl ReimSubBAInplace for FFT64Avx {
impl ReimSubNegateInplace for FFT64Avx {
#[inline(always)]
fn reim_sub_ba_inplace(res: &mut [f64], a: &[f64]) {
fn reim_sub_negate_inplace(res: &mut [f64], a: &[f64]) {
unsafe {
reim_sub_ba_inplace_avx2_fma(res, a);
reim_sub_negate_inplace_avx2_fma(res, a);
}
}
}

View File

@@ -88,7 +88,7 @@ pub fn reim_sub_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) {
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
#[target_feature(enable = "avx2,fma")]
pub fn reim_sub_ab_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
pub fn reim_sub_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
@@ -115,7 +115,7 @@ pub fn reim_sub_ab_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
#[target_feature(enable = "avx2,fma")]
pub fn reim_sub_ba_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
pub fn reim_sub_negate_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());

View File

@@ -253,9 +253,6 @@ fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]
(take_slice, rem_slice)
}
} else {
panic!(
"Attempted to take {} from scratch with {} aligned bytes left",
take_len, aligned_len,
);
panic!("Attempted to take {take_len} from scratch with {aligned_len} aligned bytes left");
}
}

View File

@@ -5,15 +5,15 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_fft64_avx::FFT64Avx,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_vec_znx_add => poulpy_hal::test_suite::vec_znx::test_vec_znx_add,
test_vec_znx_add_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_inplace,
test_vec_znx_add_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar,
test_vec_znx_add_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar_inplace,
test_vec_znx_sub => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub,
test_vec_znx_sub_ab_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ab_inplace,
test_vec_znx_sub_ba_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ba_inplace,
test_vec_znx_sub_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_inplace,
test_vec_znx_sub_negate_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_negate_inplace,
test_vec_znx_sub_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar,
test_vec_znx_sub_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar_inplace,
test_vec_znx_rsh => poulpy_hal::test_suite::vec_znx::test_vec_znx_rsh,
@@ -41,7 +41,7 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_fft64_avx::FFT64Avx,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_svp_apply_dft_to_dft => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft,
test_svp_apply_dft_to_dft_inplace => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft_inplace,
@@ -53,20 +53,20 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_fft64_avx::FFT64Avx,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_vec_znx_big_add => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add,
test_vec_znx_big_add_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_inplace,
test_vec_znx_big_add_small => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small,
test_vec_znx_big_add_small_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small_inplace,
test_vec_znx_big_sub => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub,
test_vec_znx_big_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ab_inplace,
test_vec_znx_big_sub_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_inplace,
test_vec_znx_big_automorphism => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism,
test_vec_znx_big_automorphism_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism_inplace,
test_vec_znx_big_negate => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate,
test_vec_znx_big_negate_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate_inplace,
test_vec_znx_big_normalize => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_normalize,
test_vec_znx_big_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ba_inplace,
test_vec_znx_big_sub_negate_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_negate_inplace,
test_vec_znx_big_sub_small_a => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a,
test_vec_znx_big_sub_small_a_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a_inplace,
test_vec_znx_big_sub_small_b => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_b,
@@ -79,13 +79,13 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_fft64_avx::FFT64Avx,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_vec_znx_dft_add => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add,
test_vec_znx_dft_add_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add_inplace,
test_vec_znx_dft_sub => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub,
test_vec_znx_dft_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ab_inplace,
test_vec_znx_dft_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ba_inplace,
test_vec_znx_dft_sub_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_inplace,
test_vec_znx_dft_sub_negate_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_negate_inplace,
test_vec_znx_idft_apply => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply,
test_vec_znx_idft_apply_consume => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_consume,
test_vec_znx_idft_apply_tmpa => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_tmpa,
@@ -97,7 +97,7 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_fft64_avx::FFT64Avx,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_vmp_apply_dft_to_dft => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft,
test_vmp_apply_dft_to_dft_add => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft_add,

View File

@@ -1,7 +1,8 @@
use poulpy_hal::{
api::{
TakeSlice, VecZnxAutomorphismInplaceTmpBytes, VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOneInplaceTmpBytes,
VecZnxNormalizeTmpBytes, VecZnxRotateInplaceTmpBytes, VecZnxSplitRingTmpBytes,
TakeSlice, VecZnxAutomorphismInplaceTmpBytes, VecZnxLshTmpBytes, VecZnxMergeRingsTmpBytes,
VecZnxMulXpMinusOneInplaceTmpBytes, VecZnxNormalizeTmpBytes, VecZnxRotateInplaceTmpBytes, VecZnxRshTmpBytes,
VecZnxSplitRingTmpBytes,
},
layouts::{Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
oep::{
@@ -12,7 +13,7 @@ use poulpy_hal::{
VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl,
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl,
},
reference::vec_znx::{
@@ -23,7 +24,7 @@ use poulpy_hal::{
vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_negate, vec_znx_negate_inplace, vec_znx_normalize,
vec_znx_normalize_inplace, vec_znx_normalize_tmp_bytes, vec_znx_rotate, vec_znx_rotate_inplace,
vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring,
vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace, vec_znx_sub_scalar,
vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace, vec_znx_sub_scalar,
vec_znx_sub_scalar_inplace, vec_znx_switch_ring,
},
source::Source,
@@ -43,9 +44,10 @@ where
{
fn vec_znx_normalize_impl<R, A>(
module: &Module<Self>,
basek: usize,
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<Self>,
@@ -54,7 +56,7 @@ where
A: VecZnxToRef,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_normalize::<R, A, Self>(basek, res, res_col, a, a_col, carry);
vec_znx_normalize::<R, A, Self>(res_basek, res, res_col, a_basek, a, a_col, carry);
}
}
@@ -64,7 +66,7 @@ where
{
fn vec_znx_normalize_inplace_impl<R>(
module: &Module<Self>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
scratch: &mut Scratch<Self>,
@@ -72,7 +74,7 @@ where
R: VecZnxToMut,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_normalize_inplace::<R, Self>(basek, res, res_col, carry);
vec_znx_normalize_inplace::<R, Self>(base2k, res, res_col, carry);
}
}
@@ -143,23 +145,23 @@ unsafe impl VecZnxSubImpl<Self> for FFT64Avx {
}
}
unsafe impl VecZnxSubABInplaceImpl<Self> for FFT64Avx {
fn vec_znx_sub_ab_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxSubInplaceImpl<Self> for FFT64Avx {
fn vec_znx_sub_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_sub_ab_inplace::<R, A, Self>(res, res_col, a, a_col);
vec_znx_sub_inplace::<R, A, Self>(res, res_col, a, a_col);
}
}
unsafe impl VecZnxSubBAInplaceImpl<Self> for FFT64Avx {
fn vec_znx_sub_ba_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxSubNegateInplaceImpl<Self> for FFT64Avx {
fn vec_znx_sub_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_sub_ba_inplace::<R, A, Self>(res, res_col, a, a_col);
vec_znx_sub_negate_inplace::<R, A, Self>(res, res_col, a, a_col);
}
}
@@ -234,9 +236,9 @@ where
Module<Self>: VecZnxNormalizeTmpBytes,
Scratch<Self>: TakeSlice,
{
fn vec_znx_lsh_inplace_impl<R, A>(
fn vec_znx_lsh_impl<R, A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -247,8 +249,8 @@ where
R: VecZnxToMut,
A: VecZnxToRef,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_lsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry);
let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::<i64>());
vec_znx_lsh::<_, _, Self>(base2k, k, res, res_col, a, a_col, carry);
}
}
@@ -259,7 +261,7 @@ where
{
fn vec_znx_lsh_inplace_impl<A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
a: &mut A,
a_col: usize,
@@ -267,8 +269,8 @@ where
) where
A: VecZnxToMut,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_lsh_inplace::<_, Self>(basek, k, a, a_col, carry);
let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::<i64>());
vec_znx_lsh_inplace::<_, Self>(base2k, k, a, a_col, carry);
}
}
@@ -277,9 +279,9 @@ where
Module<Self>: VecZnxNormalizeTmpBytes,
Scratch<Self>: TakeSlice,
{
fn vec_znx_rsh_inplace_impl<R, A>(
fn vec_znx_rsh_impl<R, A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -290,8 +292,8 @@ where
R: VecZnxToMut,
A: VecZnxToRef,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_rsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry);
let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::<i64>());
vec_znx_rsh::<_, _, Self>(base2k, k, res, res_col, a, a_col, carry);
}
}
@@ -302,7 +304,7 @@ where
{
fn vec_znx_rsh_inplace_impl<A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
a: &mut A,
a_col: usize,
@@ -310,8 +312,8 @@ where
) where
A: VecZnxToMut,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_rsh_inplace::<_, Self>(basek, k, a, a_col, carry);
let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::<i64>());
vec_znx_rsh_inplace::<_, Self>(base2k, k, a, a_col, carry);
}
}
@@ -495,18 +497,18 @@ unsafe impl VecZnxCopyImpl<Self> for FFT64Avx {
}
unsafe impl VecZnxFillUniformImpl<Self> for FFT64Avx {
fn vec_znx_fill_uniform_impl<R>(_module: &Module<Self>, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
fn vec_znx_fill_uniform_impl<R>(_module: &Module<Self>, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut,
{
vec_znx_fill_uniform_ref(basek, res, res_col, source)
vec_znx_fill_uniform_ref(base2k, res, res_col, source)
}
}
unsafe impl VecZnxFillNormalImpl<Self> for FFT64Avx {
fn vec_znx_fill_normal_impl<R>(
_module: &Module<Self>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -516,14 +518,14 @@ unsafe impl VecZnxFillNormalImpl<Self> for FFT64Avx {
) where
R: VecZnxToMut,
{
vec_znx_fill_normal_ref(basek, res, res_col, k, sigma, bound, source);
vec_znx_fill_normal_ref(base2k, res, res_col, k, sigma, bound, source);
}
}
unsafe impl VecZnxAddNormalImpl<Self> for FFT64Avx {
fn vec_znx_add_normal_impl<R>(
_module: &Module<Self>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -533,6 +535,6 @@ unsafe impl VecZnxAddNormalImpl<Self> for FFT64Avx {
) where
R: VecZnxToMut,
{
vec_znx_add_normal_ref(basek, res, res_col, k, sigma, bound, source);
vec_znx_add_normal_ref(base2k, res, res_col, k, sigma, bound, source);
}
}

View File

@@ -10,15 +10,15 @@ use poulpy_hal::{
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl,
VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl,
VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl,
VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubImpl, VecZnxBigSubInplaceImpl, VecZnxBigSubNegateInplaceImpl,
VecZnxBigSubSmallAImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallInplaceImpl, VecZnxBigSubSmallNegateInplaceImpl,
},
reference::{
fft64::vec_znx_big::{
vec_znx_big_add, vec_znx_big_add_inplace, vec_znx_big_add_normal_ref, vec_znx_big_add_small,
vec_znx_big_add_small_inplace, vec_znx_big_automorphism, vec_znx_big_automorphism_inplace,
vec_znx_big_automorphism_inplace_tmp_bytes, vec_znx_big_negate, vec_znx_big_negate_inplace, vec_znx_big_normalize,
vec_znx_big_normalize_tmp_bytes, vec_znx_big_sub, vec_znx_big_sub_ab_inplace, vec_znx_big_sub_ba_inplace,
vec_znx_big_normalize_tmp_bytes, vec_znx_big_sub, vec_znx_big_sub_inplace, vec_znx_big_sub_negate_inplace,
vec_znx_big_sub_small_a, vec_znx_big_sub_small_a_inplace, vec_znx_big_sub_small_b, vec_znx_big_sub_small_b_inplace,
},
znx::{znx_copy_ref, znx_zero_ref},
@@ -76,7 +76,7 @@ unsafe impl VecZnxBigFromSmallImpl<Self> for FFT64Avx {
unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Avx {
fn add_normal_impl<R: VecZnxBigToMut<Self>>(
_module: &Module<Self>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -84,7 +84,7 @@ unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Avx {
sigma: f64,
bound: f64,
) {
vec_znx_big_add_normal_ref(basek, res, res_col, k, sigma, bound, source);
vec_znx_big_add_normal_ref(base2k, res, res_col, k, sigma, bound, source);
}
}
@@ -167,25 +167,25 @@ unsafe impl VecZnxBigSubImpl<Self> for FFT64Avx {
}
}
unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64Avx {
unsafe impl VecZnxBigSubInplaceImpl<Self> for FFT64Avx {
/// Subtracts `a` from `res` and stores the result on `res`.
fn vec_znx_big_sub_ab_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
{
vec_znx_big_sub_ab_inplace(res, res_col, a, a_col);
vec_znx_big_sub_inplace(res, res_col, a, a_col);
}
}
unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64Avx {
unsafe impl VecZnxBigSubNegateInplaceImpl<Self> for FFT64Avx {
/// Subtracts `res` from `a` and stores the result on `res`.
fn vec_znx_big_sub_ba_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
{
vec_znx_big_sub_ba_inplace(res, res_col, a, a_col);
vec_znx_big_sub_negate_inplace(res, res_col, a, a_col);
}
}
@@ -208,9 +208,9 @@ unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64Avx {
}
}
unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64Avx {
unsafe impl VecZnxBigSubSmallInplaceImpl<Self> for FFT64Avx {
/// Subtracts `a` from `res` and stores the result on `res`.
fn vec_znx_big_sub_small_a_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_small_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
@@ -238,9 +238,9 @@ unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64Avx {
}
}
unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64Avx {
unsafe impl VecZnxBigSubSmallNegateInplaceImpl<Self> for FFT64Avx {
/// Subtracts `res` from `a` and stores the result on `res`.
fn vec_znx_big_sub_small_b_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_small_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
@@ -280,9 +280,10 @@ where
{
fn vec_znx_big_normalize_impl<R, A>(
module: &Module<Self>,
basek: usize,
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<Self>,
@@ -291,7 +292,7 @@ where
A: VecZnxBigToRef<Self>,
{
let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_big_normalize(basek, res, res_col, a, a_col, carry);
vec_znx_big_normalize(res_basek, res, res_col, a_basek, a, a_col, carry);
}
}
@@ -326,7 +327,7 @@ where
) where
R: VecZnxBigToMut<Self>,
{
let (tmp, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
let (tmp, _) = scratch.take_slice(module.vec_znx_big_automorphism_inplace_tmp_bytes() / size_of::<i64>());
vec_znx_big_automorphism_inplace(p, res, res_col, tmp);
}
}
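The in-place subtraction flavours renamed throughout these files differ only in operand order, as the doc comments above state. A throwaway scalar analogue, with hypothetical helper names rather than the crate's API:

```rust
// res = res - a   (formerly *_sub_ab_inplace)
fn sub_inplace(res: &mut [i64], a: &[i64]) {
    for (r, &x) in res.iter_mut().zip(a) {
        *r -= x;
    }
}

// res = a - res   (formerly *_sub_ba_inplace; "negate" because the old res enters negated)
fn sub_negate_inplace(res: &mut [i64], a: &[i64]) {
    for (r, &x) in res.iter_mut().zip(a) {
        *r = x - *r;
    }
}
```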

View File

@@ -5,12 +5,12 @@ use poulpy_hal::{
},
oep::{
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl,
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl,
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl,
VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
},
reference::fft64::vec_znx_dft::{
vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub,
vec_znx_dft_sub_ab_inplace, vec_znx_dft_sub_ba_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume,
vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub, vec_znx_dft_sub_inplace,
vec_znx_dft_sub_negate_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume,
vec_znx_idft_apply_tmpa,
},
};
@@ -139,23 +139,23 @@ unsafe impl VecZnxDftSubImpl<Self> for FFT64Avx {
}
}
unsafe impl VecZnxDftSubABInplaceImpl<Self> for FFT64Avx {
fn vec_znx_dft_sub_ab_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxDftSubInplaceImpl<Self> for FFT64Avx {
fn vec_znx_dft_sub_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
{
vec_znx_dft_sub_ab_inplace(res, res_col, a, a_col);
vec_znx_dft_sub_inplace(res, res_col, a, a_col);
}
}
unsafe impl VecZnxDftSubBAInplaceImpl<Self> for FFT64Avx {
fn vec_znx_dft_sub_ba_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxDftSubNegateInplaceImpl<Self> for FFT64Avx {
fn vec_znx_dft_sub_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
{
vec_znx_dft_sub_ba_inplace(res, res_col, a, a_col);
vec_znx_dft_sub_negate_inplace(res, res_col, a, a_col);
}
}

View File

@@ -18,21 +18,21 @@ unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64Avx
where
Self: TakeSliceImpl<Self>,
{
fn zn_normalize_inplace_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<Self>)
fn zn_normalize_inplace_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<Self>)
where
R: ZnToMut,
{
let (carry, _) = scratch.take_slice(n);
zn_normalize_inplace::<R, FFT64Avx>(n, basek, res, res_col, carry);
zn_normalize_inplace::<R, FFT64Avx>(n, base2k, res, res_col, carry);
}
}
unsafe impl ZnFillUniformImpl<Self> for FFT64Avx {
fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
fn zn_fill_uniform_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
zn_fill_uniform(n, basek, res, res_col, source);
zn_fill_uniform(n, base2k, res, res_col, source);
}
}
@@ -40,7 +40,7 @@ unsafe impl ZnFillNormalImpl<Self> for FFT64Avx {
#[allow(clippy::too_many_arguments)]
fn zn_fill_normal_impl<R>(
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -50,7 +50,7 @@ unsafe impl ZnFillNormalImpl<Self> for FFT64Avx {
) where
R: ZnToMut,
{
zn_fill_normal(n, basek, res, res_col, k, source, sigma, bound);
zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}
@@ -58,7 +58,7 @@ unsafe impl ZnAddNormalImpl<Self> for FFT64Avx {
#[allow(clippy::too_many_arguments)]
fn zn_add_normal_impl<R>(
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -68,6 +68,6 @@ unsafe impl ZnAddNormalImpl<Self> for FFT64Avx {
) where
R: ZnToMut,
{
zn_add_normal(n, basek, res, res_col, k, source, sigma, bound);
zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}

View File

@@ -1,5 +1,6 @@
mod add;
mod automorphism;
mod mul;
mod neg;
mod normalization;
mod sub;
@@ -7,6 +8,7 @@ mod switch_ring;
pub(crate) use add::*;
pub(crate) use automorphism::*;
pub(crate) use mul::*;
pub(crate) use neg::*;
pub(crate) use normalization::*;
pub(crate) use sub::*;

View File

@@ -0,0 +1,318 @@
/// Multiply/divide by a power of two with rounding matching [poulpy_hal::reference::znx::znx_mul_power_of_two_ref].
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn znx_mul_power_of_two_avx(k: i64, res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
use core::arch::x86_64::{
__m128i, __m256i, _mm_cvtsi32_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256,
_mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_sll_epi64, _mm256_srl_epi64, _mm256_srli_epi64,
_mm256_storeu_si256, _mm256_sub_epi64,
};
let n: usize = res.len();
if n == 0 {
return;
}
if k == 0 {
use poulpy_hal::reference::znx::znx_copy_ref;
znx_copy_ref(res, a);
return;
}
let span: usize = n >> 2; // number of 256-bit chunks
unsafe {
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut aa: *const __m256i = a.as_ptr() as *const __m256i;
if k > 0 {
// Left shift by k (variable count).
#[cfg(debug_assertions)]
{
debug_assert!(k <= 63);
}
let cnt128: __m128i = _mm_cvtsi32_si128(k as i32);
for _ in 0..span {
let x: __m256i = _mm256_loadu_si256(aa);
let y: __m256i = _mm256_sll_epi64(x, cnt128);
_mm256_storeu_si256(rr, y);
rr = rr.add(1);
aa = aa.add(1);
}
// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_power_of_two_ref;
znx_mul_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
}
return;
}
// k < 0 => arithmetic right shift with rounding:
// for each x:
// sign_bit = (x >> 63) & 1
// bias = (1<<(kp-1)) - sign_bit
// t = x + bias
// y = t >> kp (arithmetic)
let kp = -k;
#[cfg(debug_assertions)]
{
debug_assert!((1..=63).contains(&kp));
}
let cnt_right: __m128i = _mm_cvtsi32_si128(kp as i32);
let bias_base: __m256i = _mm256_set1_epi64x(1_i64 << (kp - 1));
let top_mask: __m256i = _mm256_set1_epi64x(-1_i64 << (64 - kp)); // high kp bits
let zero: __m256i = _mm256_setzero_si256();
for _ in 0..span {
let x = _mm256_loadu_si256(aa);
// bias = (1 << (kp-1)) - sign_bit
let sign_bit_x: __m256i = _mm256_srli_epi64(x, 63);
let bias: __m256i = _mm256_sub_epi64(bias_base, sign_bit_x);
// t = x + bias
let t: __m256i = _mm256_add_epi64(x, bias);
// logical shift
let lsr: __m256i = _mm256_srl_epi64(t, cnt_right);
// sign extension
let neg: __m256i = _mm256_cmpgt_epi64(zero, t);
let fill: __m256i = _mm256_and_si256(neg, top_mask);
let y: __m256i = _mm256_or_si256(lsr, fill);
_mm256_storeu_si256(rr, y);
rr = rr.add(1);
aa = aa.add(1);
}
}
// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_power_of_two_ref;
znx_mul_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
}
}
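The rounding comments in this kernel (and in the two related kernels below) can be condensed into a scalar model. A small sketch, assuming the same semantics as the reference tail `znx_mul_power_of_two_ref` that handles the remainder elements above:

```rust
/// Scalar model of the `k < 0` branch: divide by 2^kp, rounding to nearest,
/// with the same bias trick as the AVX lanes (the sign bit is only subtracted
/// from the bias for negative inputs).
fn div_pow2_rounded(x: i64, kp: u32) -> i64 {
    let sign_bit = ((x as u64) >> 63) as i64; // 1 if x < 0, else 0
    let bias = (1i64 << (kp - 1)) - sign_bit; // 2^(kp-1) - sign_bit
    x.wrapping_add(bias) >> kp                // arithmetic shift = floor division
}

// With kp = 4 (divide by 16):
// div_pow2_rounded( 8, 4) ==  1   // +0.5 rounds up
// div_pow2_rounded(-8, 4) == -1   // -0.5 rounds away from zero
// div_pow2_rounded( 7, 4) ==  0
// div_pow2_rounded(-7, 4) ==  0
```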
/// Multiply/divide in place by a power of two with rounding matching [poulpy_hal::reference::znx::znx_mul_power_of_two_inplace_ref].
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn znx_mul_power_of_two_inplace_avx(k: i64, res: &mut [i64]) {
use core::arch::x86_64::{
__m128i, __m256i, _mm_cvtsi32_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256,
_mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_sll_epi64, _mm256_srl_epi64, _mm256_srli_epi64,
_mm256_storeu_si256, _mm256_sub_epi64,
};
let n: usize = res.len();
if n == 0 {
return;
}
if k == 0 {
return;
}
let span: usize = n >> 2; // number of 256-bit chunks
unsafe {
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
if k > 0 {
// Left shift by k (variable count).
#[cfg(debug_assertions)]
{
debug_assert!(k <= 63);
}
let cnt128: __m128i = _mm_cvtsi32_si128(k as i32);
for _ in 0..span {
let x: __m256i = _mm256_loadu_si256(rr);
let y: __m256i = _mm256_sll_epi64(x, cnt128);
_mm256_storeu_si256(rr, y);
rr = rr.add(1);
}
// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_power_of_two_inplace_ref;
znx_mul_power_of_two_inplace_ref(k, &mut res[span << 2..]);
}
return;
}
// k < 0 => arithmetic right shift with rounding:
// for each x:
// sign_bit = (x >> 63) & 1
// bias = (1<<(kp-1)) - sign_bit
// t = x + bias
// y = t >> kp (arithmetic)
let kp = -k;
#[cfg(debug_assertions)]
{
debug_assert!((1..=63).contains(&kp));
}
let cnt_right: __m128i = _mm_cvtsi32_si128(kp as i32);
let bias_base: __m256i = _mm256_set1_epi64x(1_i64 << (kp - 1));
let top_mask: __m256i = _mm256_set1_epi64x(-1_i64 << (64 - kp)); // high kp bits
let zero: __m256i = _mm256_setzero_si256();
for _ in 0..span {
let x = _mm256_loadu_si256(rr);
// bias = (1 << (kp-1)) - sign_bit
let sign_bit_x: __m256i = _mm256_srli_epi64(x, 63);
let bias: __m256i = _mm256_sub_epi64(bias_base, sign_bit_x);
// t = x + bias
let t: __m256i = _mm256_add_epi64(x, bias);
// logical shift
let lsr: __m256i = _mm256_srl_epi64(t, cnt_right);
// sign extension
let neg: __m256i = _mm256_cmpgt_epi64(zero, t);
let fill: __m256i = _mm256_and_si256(neg, top_mask);
let y: __m256i = _mm256_or_si256(lsr, fill);
_mm256_storeu_si256(rr, y);
rr = rr.add(1);
}
}
// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_power_of_two_inplace_ref;
znx_mul_power_of_two_inplace_ref(k, &mut res[span << 2..]);
}
}
/// Multiply/divide by a power of two and add onto the result, with rounding matching [poulpy_hal::reference::znx::znx_mul_add_power_of_two_ref].
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn znx_mul_add_power_of_two_avx(k: i64, res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
use core::arch::x86_64::{
__m128i, __m256i, _mm_cvtsi32_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256,
_mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_sll_epi64, _mm256_srl_epi64, _mm256_srli_epi64,
_mm256_storeu_si256, _mm256_sub_epi64,
};
let n: usize = res.len();
if n == 0 {
return;
}
if k == 0 {
use crate::cpu_fft64_avx::znx_avx::znx_add_inplace_avx;
znx_add_inplace_avx(res, a);
return;
}
let span: usize = n >> 2; // number of 256-bit chunks
unsafe {
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut aa: *const __m256i = a.as_ptr() as *const __m256i;
if k > 0 {
// Left shift by k (variable count).
#[cfg(debug_assertions)]
{
debug_assert!(k <= 63);
}
let cnt128: __m128i = _mm_cvtsi32_si128(k as i32);
for _ in 0..span {
let x: __m256i = _mm256_loadu_si256(aa);
let y: __m256i = _mm256_loadu_si256(rr);
_mm256_storeu_si256(rr, _mm256_add_epi64(y, _mm256_sll_epi64(x, cnt128)));
rr = rr.add(1);
aa = aa.add(1);
}
// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_add_power_of_two_ref;
znx_mul_add_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
}
return;
}
// k < 0 => arithmetic right shift with rounding:
// for each x:
// sign_bit = (x >> 63) & 1
// bias = (1<<(kp-1)) - sign_bit
// t = x + bias
// y = t >> kp (arithmetic)
let kp = -k;
#[cfg(debug_assertions)]
{
debug_assert!((1..=63).contains(&kp));
}
let cnt_right: __m128i = _mm_cvtsi32_si128(kp as i32);
let bias_base: __m256i = _mm256_set1_epi64x(1_i64 << (kp - 1));
let top_mask: __m256i = _mm256_set1_epi64x(-1_i64 << (64 - kp)); // high kp bits
let zero: __m256i = _mm256_setzero_si256();
for _ in 0..span {
let x: __m256i = _mm256_loadu_si256(aa);
let y: __m256i = _mm256_loadu_si256(rr);
// bias = (1 << (kp-1)) - sign_bit
let sign_bit_x: __m256i = _mm256_srli_epi64(x, 63);
let bias: __m256i = _mm256_sub_epi64(bias_base, sign_bit_x);
// t = x + bias
let t: __m256i = _mm256_add_epi64(x, bias);
// logical shift
let lsr: __m256i = _mm256_srl_epi64(t, cnt_right);
// sign extension
let neg: __m256i = _mm256_cmpgt_epi64(zero, t);
let fill: __m256i = _mm256_and_si256(neg, top_mask);
let out: __m256i = _mm256_or_si256(lsr, fill);
_mm256_storeu_si256(rr, _mm256_add_epi64(y, out));
rr = rr.add(1);
aa = aa.add(1);
}
}
// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_add_power_of_two_ref;
znx_mul_add_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
}
}

View File

@@ -6,14 +6,14 @@ use std::arch::x86_64::__m256i;
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
fn normalize_consts_avx(basek: usize) -> (__m256i, __m256i, __m256i, __m256i) {
fn normalize_consts_avx(base2k: usize) -> (__m256i, __m256i, __m256i, __m256i) {
use std::arch::x86_64::_mm256_set1_epi64x;
assert!((1..=63).contains(&basek));
let mask_k: i64 = ((1u64 << basek) - 1) as i64; // 0..k-1 bits set
let sign_k: i64 = (1u64 << (basek - 1)) as i64; // bit k-1
let topmask: i64 = (!0u64 << (64 - basek)) as i64; // top k bits set
let sh_k: __m256i = _mm256_set1_epi64x(basek as i64);
assert!((1..=63).contains(&base2k));
let mask_k: i64 = ((1u64 << base2k) - 1) as i64; // 0..k-1 bits set
let sign_k: i64 = (1u64 << (base2k - 1)) as i64; // bit k-1
let topmask: i64 = (!0u64 << (64 - base2k)) as i64; // top k bits set
let sh_k: __m256i = _mm256_set1_epi64x(base2k as i64);
(
_mm256_set1_epi64x(mask_k), // mask_k_vec
_mm256_set1_epi64x(sign_k), // sign_k_vec
@@ -46,14 +46,14 @@ fn get_digit_avx(x: __m256i, mask_k: __m256i, sign_k: __m256i) -> __m256i {
unsafe fn get_carry_avx(
x: __m256i,
digit: __m256i,
basek: __m256i, // _mm256_set1_epi64x(k)
base2k: __m256i, // _mm256_set1_epi64x(k)
top_mask: __m256i, // (!0 << (64 - k)) broadcast
) -> __m256i {
use std::arch::x86_64::{
__m256i, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_or_si256, _mm256_setzero_si256, _mm256_srlv_epi64, _mm256_sub_epi64,
};
let diff: __m256i = _mm256_sub_epi64(x, digit);
let lsr: __m256i = _mm256_srlv_epi64(diff, basek); // logical >>
let lsr: __m256i = _mm256_srlv_epi64(diff, base2k); // logical >>
let neg: __m256i = _mm256_cmpgt_epi64(_mm256_setzero_si256(), diff); // 0xFFFF.. where v<0
let fill: __m256i = _mm256_and_si256(neg, top_mask); // top k bits if negative
_mm256_or_si256(lsr, fill)
@@ -61,13 +61,121 @@ unsafe fn get_carry_avx(
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
/// `res` and `src` must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub fn znx_normalize_first_step_carry_only_avx(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
pub fn znx_extract_digit_addmul_avx(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), carry.len());
assert!(lsh < basek);
assert_eq!(res.len(), src.len());
assert!(lsh < base2k);
}
use std::arch::x86_64::{
__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_set1_epi64x, _mm256_sllv_epi64, _mm256_storeu_si256,
};
let n: usize = res.len();
let span: usize = n >> 2;
unsafe {
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut ss: *mut __m256i = src.as_mut_ptr() as *mut __m256i;
// constants for digit/carry extraction
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
for _ in 0..span {
// load source & extract digit/carry
let sv: __m256i = _mm256_loadu_si256(ss);
let digit_256: __m256i = get_digit_avx(sv, mask, sign);
let carry_256: __m256i = get_carry_avx(sv, digit_256, basek_vec, top_mask);
// res += (digit << lsh)
let rv: __m256i = _mm256_loadu_si256(rr);
let madd: __m256i = _mm256_sllv_epi64(digit_256, lsh_v);
let sum: __m256i = _mm256_add_epi64(rv, madd);
_mm256_storeu_si256(rr, sum);
_mm256_storeu_si256(ss, carry_256);
rr = rr.add(1);
ss = ss.add(1);
}
}
// tail (scalar)
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_extract_digit_addmul_ref;
let off: usize = span << 2;
znx_extract_digit_addmul_ref(base2k, lsh, &mut res[off..], &mut src[off..]);
}
}
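For readers tracking the lane arithmetic above, this is a scalar model of what each element undergoes, under the assumptions the debug assertions require (`0 < base2k < 64`, `lsh < base2k`). Whether it matches `znx_extract_digit_addmul_ref` bit-for-bit at wrapping boundaries is an assumption, so treat it as a sketch rather than a reference:

```rust
/// Hypothetical scalar model of one element of `znx_extract_digit_addmul_avx`.
fn extract_digit_addmul_scalar(base2k: u32, lsh: u32, res: &mut i64, src: &mut i64) {
    let sh = 64 - base2k;
    // Signed digit: (x << (64 - base2k)) >> (64 - base2k), as in the comments above.
    let digit = (*src << sh) >> sh;
    // Carry: (x - digit) >> base2k, exact because the low base2k bits cancel.
    let carry = src.wrapping_sub(digit) >> base2k;
    // res += digit << lsh; src keeps only the carry.
    *res = (*res).wrapping_add(digit << lsh);
    *src = carry;
}
```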
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// `res` and `src` must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub fn znx_normalize_digit_avx(base2k: usize, res: &mut [i64], src: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), src.len());
}
use std::arch::x86_64::{__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_storeu_si256};
let n: usize = res.len();
let span: usize = n >> 2;
unsafe {
// Pointers to 256-bit lanes
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut ss: *mut __m256i = src.as_mut_ptr() as *mut __m256i;
// Constants for digit/carry extraction
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
for _ in 0..span {
// Load res lane
let rv: __m256i = _mm256_loadu_si256(rr);
// Extract digit and carry from res
let digit_256: __m256i = get_digit_avx(rv, mask, sign);
let carry_256: __m256i = get_carry_avx(rv, digit_256, basek_vec, top_mask);
// src += carry
let sv: __m256i = _mm256_loadu_si256(ss);
let sum: __m256i = _mm256_add_epi64(sv, carry_256);
_mm256_storeu_si256(ss, sum);
_mm256_storeu_si256(rr, digit_256);
rr = rr.add(1);
ss = ss.add(1);
}
}
// scalar tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_normalize_digit_ref;
let off = span << 2;
znx_normalize_digit_ref(base2k, &mut res[off..], &mut src[off..]);
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_first_step_carry_only_avx(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < base2k);
}
use std::arch::x86_64::{_mm256_loadu_si256, _mm256_storeu_si256};
@@ -81,19 +189,19 @@ pub fn znx_normalize_first_step_carry_only_avx(basek: usize, lsh: usize, x: &[i6
let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i;
let (mask, sign, basek_vec, top_mask) = if lsh == 0 {
normalize_consts_avx(basek)
normalize_consts_avx(base2k)
} else {
normalize_consts_avx(basek - lsh)
normalize_consts_avx(base2k - lsh)
};
for _ in 0..span {
let xx_256: __m256i = _mm256_loadu_si256(xx);
let xv: __m256i = _mm256_loadu_si256(xx);
// (x << (64 - basek)) >> (64 - basek)
let digit_256: __m256i = get_digit_avx(xx_256, mask, sign);
// (x << (64 - base2k)) >> (64 - base2k)
let digit_256: __m256i = get_digit_avx(xv, mask, sign);
// (x - digit) >> basek
let carry_256: __m256i = get_carry_avx(xx_256, digit_256, basek_vec, top_mask);
// (x - digit) >> base2k
let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask);
_mm256_storeu_si256(cc, carry_256);
@@ -106,7 +214,7 @@ pub fn znx_normalize_first_step_carry_only_avx(basek: usize, lsh: usize, x: &[i6
if !x.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_normalize_first_step_carry_only_ref;
znx_normalize_first_step_carry_only_ref(basek, lsh, &x[span << 2..], &mut carry[span << 2..]);
znx_normalize_first_step_carry_only_ref(base2k, lsh, &x[span << 2..], &mut carry[span << 2..]);
}
}
@@ -114,11 +222,11 @@ pub fn znx_normalize_first_step_carry_only_avx(basek: usize, lsh: usize, x: &[i6
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
pub fn znx_normalize_first_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), carry.len());
assert!(lsh < basek);
assert!(x.len() <= carry.len());
assert!(lsh < base2k);
}
use std::arch::x86_64::{_mm256_loadu_si256, _mm256_set1_epi64x, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -132,16 +240,16 @@ pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i;
if lsh == 0 {
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek);
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
for _ in 0..span {
let xx_256: __m256i = _mm256_loadu_si256(xx);
let xv: __m256i = _mm256_loadu_si256(xx);
// (x << (64 - basek)) >> (64 - basek)
let digit_256: __m256i = get_digit_avx(xx_256, mask, sign);
// (x << (64 - base2k)) >> (64 - base2k)
let digit_256: __m256i = get_digit_avx(xv, mask, sign);
// (x - digit) >> basek
let carry_256: __m256i = get_carry_avx(xx_256, digit_256, basek_vec, top_mask);
// (x - digit) >> base2k
let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask);
_mm256_storeu_si256(xx, digit_256);
_mm256_storeu_si256(cc, carry_256);
@@ -150,18 +258,18 @@ pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
cc = cc.add(1);
}
} else {
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek - lsh);
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k - lsh);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
for _ in 0..span {
let xx_256: __m256i = _mm256_loadu_si256(xx);
let xv: __m256i = _mm256_loadu_si256(xx);
// (x << (64 - basek)) >> (64 - basek)
let digit_256: __m256i = get_digit_avx(xx_256, mask, sign);
// (x << (64 - base2k)) >> (64 - base2k)
let digit_256: __m256i = get_digit_avx(xv, mask, sign);
// (x - digit) >> basek
let carry_256: __m256i = get_carry_avx(xx_256, digit_256, basek_vec, top_mask);
// (x - digit) >> base2k
let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask);
_mm256_storeu_si256(xx, _mm256_sllv_epi64(digit_256, lsh_v));
_mm256_storeu_si256(cc, carry_256);
@@ -176,7 +284,7 @@ pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
if !x.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_normalize_first_step_inplace_ref;
znx_normalize_first_step_inplace_ref(basek, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
znx_normalize_first_step_inplace_ref(base2k, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
}
}
@@ -184,12 +292,12 @@ pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
pub fn znx_normalize_first_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), carry.len());
assert_eq!(a.len(), carry.len());
assert!(lsh < basek);
assert_eq!(x.len(), a.len());
assert!(x.len() <= carry.len());
assert!(lsh < base2k);
}
use std::arch::x86_64::{_mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -204,16 +312,16 @@ pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i;
if lsh == 0 {
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek);
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
for _ in 0..span {
let aa_256: __m256i = _mm256_loadu_si256(aa);
let av: __m256i = _mm256_loadu_si256(aa);
// (x << (64 - basek)) >> (64 - basek)
let digit_256: __m256i = get_digit_avx(aa_256, mask, sign);
// (x << (64 - base2k)) >> (64 - base2k)
let digit_256: __m256i = get_digit_avx(av, mask, sign);
// (x - digit) >> basek
let carry_256: __m256i = get_carry_avx(aa_256, digit_256, basek_vec, top_mask);
// (x - digit) >> base2k
let carry_256: __m256i = get_carry_avx(av, digit_256, basek_vec, top_mask);
_mm256_storeu_si256(xx, digit_256);
_mm256_storeu_si256(cc, carry_256);
@@ -225,18 +333,18 @@ pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
} else {
use std::arch::x86_64::_mm256_set1_epi64x;
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek - lsh);
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k - lsh);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
for _ in 0..span {
let aa_256: __m256i = _mm256_loadu_si256(aa);
let av: __m256i = _mm256_loadu_si256(aa);
// (x << (64 - basek)) >> (64 - basek)
let digit_256: __m256i = get_digit_avx(aa_256, mask, sign);
// (x << (64 - base2k)) >> (64 - base2k)
let digit_256: __m256i = get_digit_avx(av, mask, sign);
// (x - digit) >> basek
let carry_256: __m256i = get_carry_avx(aa_256, digit_256, basek_vec, top_mask);
// (x - digit) >> base2k
let carry_256: __m256i = get_carry_avx(av, digit_256, basek_vec, top_mask);
_mm256_storeu_si256(xx, _mm256_sllv_epi64(digit_256, lsh_v));
_mm256_storeu_si256(cc, carry_256);
@@ -253,7 +361,7 @@ pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
use poulpy_hal::reference::znx::znx_normalize_first_step_ref;
znx_normalize_first_step_ref(
basek,
base2k,
lsh,
&mut x[span << 2..],
&a[span << 2..],
@@ -266,11 +374,11 @@ pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
pub fn znx_normalize_middle_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), carry.len());
assert!(lsh < basek);
assert!(x.len() <= carry.len());
assert!(lsh < base2k);
}
use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -279,7 +387,7 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [
let span: usize = n >> 2;
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek);
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
unsafe {
let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i;
@@ -287,13 +395,13 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [
if lsh == 0 {
for _ in 0..span {
let xx_256: __m256i = _mm256_loadu_si256(xx);
let cc_256: __m256i = _mm256_loadu_si256(cc);
let xv: __m256i = _mm256_loadu_si256(xx);
let cv: __m256i = _mm256_loadu_si256(cc);
let d0: __m256i = get_digit_avx(xx_256, mask, sign);
let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec, top_mask);
let d0: __m256i = get_digit_avx(xv, mask, sign);
let c0: __m256i = get_carry_avx(xv, d0, basek_vec, top_mask);
let s: __m256i = _mm256_add_epi64(d0, cc_256);
let s: __m256i = _mm256_add_epi64(d0, cv);
let x1: __m256i = get_digit_avx(s, mask, sign);
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
let cout: __m256i = _mm256_add_epi64(c0, c1);
@@ -307,20 +415,20 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [
} else {
use std::arch::x86_64::_mm256_set1_epi64x;
let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(basek - lsh);
let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
for _ in 0..span {
let xx_256: __m256i = _mm256_loadu_si256(xx);
let cc_256: __m256i = _mm256_loadu_si256(cc);
let xv: __m256i = _mm256_loadu_si256(xx);
let cv: __m256i = _mm256_loadu_si256(cc);
let d0: __m256i = get_digit_avx(xx_256, mask_lsh, sign_lsh);
let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec_lsh, top_mask_lsh);
let d0: __m256i = get_digit_avx(xv, mask_lsh, sign_lsh);
let c0: __m256i = get_carry_avx(xv, d0, basek_vec_lsh, top_mask_lsh);
let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v);
let s: __m256i = _mm256_add_epi64(d0_lsh, cc_256);
let s: __m256i = _mm256_add_epi64(d0_lsh, cv);
let x1: __m256i = get_digit_avx(s, mask, sign);
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
let cout: __m256i = _mm256_add_epi64(c0, c1);
@@ -337,7 +445,7 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [
if !x.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_normalize_middle_step_inplace_ref;
znx_normalize_middle_step_inplace_ref(basek, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
znx_normalize_middle_step_inplace_ref(base2k, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
}
}
@@ -345,11 +453,11 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// `carry` must be at least as long as `x`, and the slices must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
pub fn znx_normalize_middle_step_carry_only_avx(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), carry.len());
assert!(lsh < basek);
assert!(x.len() <= carry.len());
assert!(lsh < base2k);
}
use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -358,7 +466,7 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i
let span: usize = n >> 2;
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek);
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
unsafe {
let mut xx: *const __m256i = x.as_ptr() as *const __m256i;
@@ -366,13 +474,13 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i
if lsh == 0 {
for _ in 0..span {
let xx_256: __m256i = _mm256_loadu_si256(xx);
let cc_256: __m256i = _mm256_loadu_si256(cc);
let xv: __m256i = _mm256_loadu_si256(xx);
let cv: __m256i = _mm256_loadu_si256(cc);
let d0: __m256i = get_digit_avx(xx_256, mask, sign);
let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec, top_mask);
let d0: __m256i = get_digit_avx(xv, mask, sign);
let c0: __m256i = get_carry_avx(xv, d0, basek_vec, top_mask);
let s: __m256i = _mm256_add_epi64(d0, cc_256);
let s: __m256i = _mm256_add_epi64(d0, cv);
let x1: __m256i = get_digit_avx(s, mask, sign);
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
let cout: __m256i = _mm256_add_epi64(c0, c1);
@@ -385,20 +493,20 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i
} else {
use std::arch::x86_64::_mm256_set1_epi64x;
let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(basek - lsh);
let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
for _ in 0..span {
let xx_256: __m256i = _mm256_loadu_si256(xx);
let cc_256: __m256i = _mm256_loadu_si256(cc);
let xv: __m256i = _mm256_loadu_si256(xx);
let cv: __m256i = _mm256_loadu_si256(cc);
let d0: __m256i = get_digit_avx(xx_256, mask_lsh, sign_lsh);
let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec_lsh, top_mask_lsh);
let d0: __m256i = get_digit_avx(xv, mask_lsh, sign_lsh);
let c0: __m256i = get_carry_avx(xv, d0, basek_vec_lsh, top_mask_lsh);
let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v);
let s: __m256i = _mm256_add_epi64(d0_lsh, cc_256);
let s: __m256i = _mm256_add_epi64(d0_lsh, cv);
let x1: __m256i = get_digit_avx(s, mask, sign);
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
let cout: __m256i = _mm256_add_epi64(c0, c1);
@@ -414,7 +522,7 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i
if !x.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_normalize_middle_step_carry_only_ref;
znx_normalize_middle_step_carry_only_ref(basek, lsh, &x[span << 2..], &mut carry[span << 2..]);
znx_normalize_middle_step_carry_only_ref(base2k, lsh, &x[span << 2..], &mut carry[span << 2..]);
}
}
@@ -422,12 +530,12 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// `x` and `a` must have the same length, `carry` must be at least as long, and the slices must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
pub fn znx_normalize_middle_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), carry.len());
assert_eq!(a.len(), carry.len());
assert!(lsh < basek);
assert_eq!(x.len(), a.len());
assert!(x.len() <= carry.len());
assert!(lsh < base2k);
}
use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -436,7 +544,7 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
let span: usize = n >> 2;
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek);
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
unsafe {
let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i;
@@ -445,13 +553,13 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
if lsh == 0 {
for _ in 0..span {
let aa_256: __m256i = _mm256_loadu_si256(aa);
let cc_256: __m256i = _mm256_loadu_si256(cc);
let av: __m256i = _mm256_loadu_si256(aa);
let cv: __m256i = _mm256_loadu_si256(cc);
let d0: __m256i = get_digit_avx(aa_256, mask, sign);
let c0: __m256i = get_carry_avx(aa_256, d0, basek_vec, top_mask);
let d0: __m256i = get_digit_avx(av, mask, sign);
let c0: __m256i = get_carry_avx(av, d0, basek_vec, top_mask);
let s: __m256i = _mm256_add_epi64(d0, cc_256);
let s: __m256i = _mm256_add_epi64(d0, cv);
let x1: __m256i = get_digit_avx(s, mask, sign);
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
let cout: __m256i = _mm256_add_epi64(c0, c1);
@@ -466,20 +574,20 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
} else {
use std::arch::x86_64::_mm256_set1_epi64x;
let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(basek - lsh);
let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
for _ in 0..span {
let aa_256: __m256i = _mm256_loadu_si256(aa);
let cc_256: __m256i = _mm256_loadu_si256(cc);
let av: __m256i = _mm256_loadu_si256(aa);
let cv: __m256i = _mm256_loadu_si256(cc);
let d0: __m256i = get_digit_avx(aa_256, mask_lsh, sign_lsh);
let c0: __m256i = get_carry_avx(aa_256, d0, basek_vec_lsh, top_mask_lsh);
let d0: __m256i = get_digit_avx(av, mask_lsh, sign_lsh);
let c0: __m256i = get_carry_avx(av, d0, basek_vec_lsh, top_mask_lsh);
let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v);
let s: __m256i = _mm256_add_epi64(d0_lsh, cc_256);
let s: __m256i = _mm256_add_epi64(d0_lsh, cv);
let x1: __m256i = get_digit_avx(s, mask, sign);
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
let cout: __m256i = _mm256_add_epi64(c0, c1);
@@ -498,7 +606,7 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
use poulpy_hal::reference::znx::znx_normalize_middle_step_ref;
znx_normalize_middle_step_ref(
basek,
base2k,
lsh,
&mut x[span << 2..],
&a[span << 2..],
@@ -511,11 +619,11 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// `carry` must be at least as long as `x`, and the slices must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
pub fn znx_normalize_final_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), carry.len());
assert!(lsh < basek);
assert!(x.len() <= carry.len());
assert!(lsh < base2k);
}
use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -524,7 +632,7 @@ pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
let span: usize = n >> 2;
let (mask, sign, _, _) = normalize_consts_avx(basek);
let (mask, sign, _, _) = normalize_consts_avx(base2k);
unsafe {
let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i;
@@ -547,7 +655,7 @@ pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
} else {
use std::arch::x86_64::_mm256_set1_epi64x;
let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(basek - lsh);
let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(base2k - lsh);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
@@ -573,7 +681,7 @@ pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
if !x.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_normalize_final_step_inplace_ref;
znx_normalize_final_step_inplace_ref(basek, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
znx_normalize_final_step_inplace_ref(base2k, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
}
}
@@ -581,12 +689,12 @@ pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// `x` and `a` must have the same length, `carry` must be at least as long, and the slices must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
pub fn znx_normalize_final_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), carry.len());
assert_eq!(a.len(), carry.len());
assert!(lsh < basek);
assert_eq!(x.len(), a.len());
assert!(x.len() <= carry.len());
assert!(lsh < base2k);
}
use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -595,7 +703,7 @@ pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
let span: usize = n >> 2;
let (mask, sign, _, _) = normalize_consts_avx(basek);
let (mask, sign, _, _) = normalize_consts_avx(base2k);
unsafe {
let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i;
@@ -620,7 +728,7 @@ pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
} else {
use std::arch::x86_64::_mm256_set1_epi64x;
let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(basek - lsh);
let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(base2k - lsh);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
@@ -647,7 +755,7 @@ pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
use poulpy_hal::reference::znx::znx_normalize_final_step_ref;
znx_normalize_final_step_ref(
basek,
base2k,
lsh,
&mut x[span << 2..],
&a[span << 2..],
@@ -658,9 +766,9 @@ pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
mod tests {
use poulpy_hal::reference::znx::{
get_carry, get_digit, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref,
znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref, znx_normalize_middle_step_inplace_ref,
znx_normalize_middle_step_ref,
get_carry_i64, get_digit_i64, znx_extract_digit_addmul_ref, znx_normalize_digit_ref,
znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref, znx_normalize_first_step_inplace_ref,
znx_normalize_first_step_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref,
};
use super::*;
@@ -670,7 +778,7 @@ mod tests {
#[allow(dead_code)]
#[target_feature(enable = "avx2")]
fn test_get_digit_avx_internal() {
let basek: usize = 12;
let base2k: usize = 12;
let x: [i64; 4] = [
7638646372408325293,
-61440197422348985,
@@ -678,15 +786,15 @@ mod tests {
-4835376105455195188,
];
let y0: Vec<i64> = vec![
get_digit(basek, x[0]),
get_digit(basek, x[1]),
get_digit(basek, x[2]),
get_digit(basek, x[3]),
get_digit_i64(base2k, x[0]),
get_digit_i64(base2k, x[1]),
get_digit_i64(base2k, x[2]),
get_digit_i64(base2k, x[3]),
];
let mut y1: Vec<i64> = vec![0i64; 4];
unsafe {
let x_256: __m256i = _mm256_loadu_si256(x.as_ptr() as *const __m256i);
let (mask, sign, _, _) = normalize_consts_avx(basek);
let (mask, sign, _, _) = normalize_consts_avx(base2k);
let digit: __m256i = get_digit_avx(x_256, mask, sign);
_mm256_storeu_si256(y1.as_mut_ptr() as *mut __m256i, digit);
}
@@ -707,7 +815,7 @@ mod tests {
#[allow(dead_code)]
#[target_feature(enable = "avx2")]
fn test_get_carry_avx_internal() {
let basek: usize = 12;
let base2k: usize = 12;
let x: [i64; 4] = [
7638646372408325293,
-61440197422348985,
@@ -716,16 +824,16 @@ mod tests {
];
let carry: [i64; 4] = [1174467039, -144794816, -1466676977, 513122840];
let y0: Vec<i64> = vec![
get_carry(basek, x[0], carry[0]),
get_carry(basek, x[1], carry[1]),
get_carry(basek, x[2], carry[2]),
get_carry(basek, x[3], carry[3]),
get_carry_i64(base2k, x[0], carry[0]),
get_carry_i64(base2k, x[1], carry[1]),
get_carry_i64(base2k, x[2], carry[2]),
get_carry_i64(base2k, x[3], carry[3]),
];
let mut y1: Vec<i64> = vec![0i64; 4];
unsafe {
let x_256: __m256i = _mm256_loadu_si256(x.as_ptr() as *const __m256i);
let d_256: __m256i = _mm256_loadu_si256(carry.as_ptr() as *const __m256i);
let (_, _, basek_vec, top_mask) = normalize_consts_avx(basek);
let (_, _, basek_vec, top_mask) = normalize_consts_avx(base2k);
let digit: __m256i = get_carry_avx(x_256, d_256, basek_vec, top_mask);
_mm256_storeu_si256(y1.as_mut_ptr() as *mut __m256i, digit);
}
@@ -762,16 +870,16 @@ mod tests {
];
let mut c1: [i64; 4] = c0;
let basek = 12;
let base2k = 12;
znx_normalize_first_step_inplace_ref(basek, 0, &mut y0, &mut c0);
znx_normalize_first_step_inplace_avx(basek, 0, &mut y1, &mut c1);
znx_normalize_first_step_inplace_ref(base2k, 0, &mut y0, &mut c0);
znx_normalize_first_step_inplace_avx(base2k, 0, &mut y1, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
znx_normalize_first_step_inplace_ref(basek, basek - 1, &mut y0, &mut c0);
znx_normalize_first_step_inplace_avx(basek, basek - 1, &mut y1, &mut c1);
znx_normalize_first_step_inplace_ref(base2k, base2k - 1, &mut y0, &mut c0);
znx_normalize_first_step_inplace_avx(base2k, base2k - 1, &mut y1, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
@@ -807,16 +915,16 @@ mod tests {
];
let mut c1: [i64; 4] = c0;
let basek = 12;
let base2k = 12;
znx_normalize_middle_step_inplace_ref(basek, 0, &mut y0, &mut c0);
znx_normalize_middle_step_inplace_avx(basek, 0, &mut y1, &mut c1);
znx_normalize_middle_step_inplace_ref(base2k, 0, &mut y0, &mut c0);
znx_normalize_middle_step_inplace_avx(base2k, 0, &mut y1, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
znx_normalize_middle_step_inplace_ref(basek, basek - 1, &mut y0, &mut c0);
znx_normalize_middle_step_inplace_avx(basek, basek - 1, &mut y1, &mut c1);
znx_normalize_middle_step_inplace_ref(base2k, base2k - 1, &mut y0, &mut c0);
znx_normalize_middle_step_inplace_avx(base2k, base2k - 1, &mut y1, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
@@ -852,16 +960,16 @@ mod tests {
];
let mut c1: [i64; 4] = c0;
let basek = 12;
let base2k = 12;
znx_normalize_final_step_inplace_ref(basek, 0, &mut y0, &mut c0);
znx_normalize_final_step_inplace_avx(basek, 0, &mut y1, &mut c1);
znx_normalize_final_step_inplace_ref(base2k, 0, &mut y0, &mut c0);
znx_normalize_final_step_inplace_avx(base2k, 0, &mut y1, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
znx_normalize_final_step_inplace_ref(basek, basek - 1, &mut y0, &mut c0);
znx_normalize_final_step_inplace_avx(basek, basek - 1, &mut y1, &mut c1);
znx_normalize_final_step_inplace_ref(base2k, base2k - 1, &mut y0, &mut c0);
znx_normalize_final_step_inplace_avx(base2k, base2k - 1, &mut y1, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
@@ -898,16 +1006,16 @@ mod tests {
];
let mut c1: [i64; 4] = c0;
let basek = 12;
let base2k = 12;
znx_normalize_first_step_ref(basek, 0, &mut y0, &a, &mut c0);
znx_normalize_first_step_avx(basek, 0, &mut y1, &a, &mut c1);
znx_normalize_first_step_ref(base2k, 0, &mut y0, &a, &mut c0);
znx_normalize_first_step_avx(base2k, 0, &mut y1, &a, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
znx_normalize_first_step_ref(basek, basek - 1, &mut y0, &a, &mut c0);
znx_normalize_first_step_avx(basek, basek - 1, &mut y1, &a, &mut c1);
znx_normalize_first_step_ref(base2k, base2k - 1, &mut y0, &a, &mut c0);
znx_normalize_first_step_avx(base2k, base2k - 1, &mut y1, &a, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
@@ -944,16 +1052,16 @@ mod tests {
];
let mut c1: [i64; 4] = c0;
let basek = 12;
let base2k = 12;
znx_normalize_middle_step_ref(basek, 0, &mut y0, &a, &mut c0);
znx_normalize_middle_step_avx(basek, 0, &mut y1, &a, &mut c1);
znx_normalize_middle_step_ref(base2k, 0, &mut y0, &a, &mut c0);
znx_normalize_middle_step_avx(base2k, 0, &mut y1, &a, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
znx_normalize_middle_step_ref(basek, basek - 1, &mut y0, &a, &mut c0);
znx_normalize_middle_step_avx(basek, basek - 1, &mut y1, &a, &mut c1);
znx_normalize_middle_step_ref(base2k, base2k - 1, &mut y0, &a, &mut c0);
znx_normalize_middle_step_avx(base2k, base2k - 1, &mut y1, &a, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
@@ -990,16 +1098,16 @@ mod tests {
];
let mut c1: [i64; 4] = c0;
let basek = 12;
let base2k = 12;
znx_normalize_final_step_ref(basek, 0, &mut y0, &a, &mut c0);
znx_normalize_final_step_avx(basek, 0, &mut y1, &a, &mut c1);
znx_normalize_final_step_ref(base2k, 0, &mut y0, &a, &mut c0);
znx_normalize_final_step_avx(base2k, 0, &mut y1, &a, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
znx_normalize_final_step_ref(basek, basek - 1, &mut y0, &a, &mut c0);
znx_normalize_final_step_avx(basek, basek - 1, &mut y1, &a, &mut c1);
znx_normalize_final_step_ref(base2k, base2k - 1, &mut y0, &a, &mut c0);
znx_normalize_final_step_avx(base2k, base2k - 1, &mut y1, &a, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
@@ -1015,4 +1123,86 @@ mod tests {
test_znx_normalize_final_step_avx_internal();
}
}
#[target_feature(enable = "avx2")]
fn znx_extract_digit_addmul_internal() {
let mut y0: [i64; 4] = [
7638646372408325293,
-61440197422348985,
6835891051541717957,
-4835376105455195188,
];
let mut y1: [i64; 4] = y0;
let mut c0: [i64; 4] = [
621182201135793202,
9000856573317006236,
5542252755421113668,
-6036847263131690631,
];
let mut c1: [i64; 4] = c0;
let base2k: usize = 12;
znx_extract_digit_addmul_ref(base2k, 0, &mut y0, &mut c0);
znx_extract_digit_addmul_avx(base2k, 0, &mut y1, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
znx_extract_digit_addmul_ref(base2k, base2k - 1, &mut y0, &mut c0);
znx_extract_digit_addmul_avx(base2k, base2k - 1, &mut y1, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
}
#[test]
fn test_znx_extract_digit_addmul_avx() {
if !std::is_x86_feature_detected!("avx2") {
eprintln!("skipping: CPU lacks avx2");
return;
}
unsafe {
znx_extract_digit_addmul_internal();
}
}
#[target_feature(enable = "avx2")]
fn znx_normalize_digit_internal() {
let mut y0: [i64; 4] = [
7638646372408325293,
-61440197422348985,
6835891051541717957,
-4835376105455195188,
];
let mut y1: [i64; 4] = y0;
let mut c0: [i64; 4] = [
621182201135793202,
9000856573317006236,
5542252755421113668,
-6036847263131690631,
];
let mut c1: [i64; 4] = c0;
let base2k: usize = 12;
znx_normalize_digit_ref(base2k, &mut y0, &mut c0);
znx_normalize_digit_avx(base2k, &mut y1, &mut c1);
assert_eq!(y0, y1);
assert_eq!(c0, c1);
}
#[test]
fn test_znx_normalize_digit_internal_avx() {
if !std::is_x86_feature_detected!("avx2") {
eprintln!("skipping: CPU lacks avx2");
return;
}
unsafe {
znx_normalize_digit_internal();
}
}
}
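The `#[target_feature(enable = "avx2")]` kernels in this file are only sound to call after a runtime check, which is why the doc comments and the tests above guard on `is_x86_feature_detected!`. A minimal dispatch sketch, assuming a hypothetical wrapper name (`normalize_middle_step_inplace_dispatch` is not part of this crate):

use poulpy_hal::reference::znx::znx_normalize_middle_step_inplace_ref;

/// Illustrative only: route to the AVX2 kernel when the CPU supports it,
/// otherwise fall back to the scalar reference implementation.
fn normalize_middle_step_inplace_dispatch(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
    if std::is_x86_feature_detected!("avx2") {
        // SAFETY: AVX2 availability was just verified; the caller upholds the slice
        // preconditions (carry at least as long as x, no aliasing).
        unsafe { znx_normalize_middle_step_inplace_avx(base2k, lsh, x, carry) }
    } else {
        znx_normalize_middle_step_inplace_ref(base2k, lsh, x, carry)
    }
}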

View File

@@ -41,7 +41,7 @@ pub fn znx_sub_avx(res: &mut [i64], a: &[i64], b: &[i64]) {
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_sub_ab_inplace_avx(res: &mut [i64], a: &[i64]) {
pub fn znx_sub_inplace_avx(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
@@ -67,9 +67,9 @@ pub fn znx_sub_ab_inplace_avx(res: &mut [i64], a: &[i64]) {
// tail
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_sub_ab_inplace_ref;
use poulpy_hal::reference::znx::znx_sub_inplace_ref;
znx_sub_ab_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
znx_sub_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
}
}
@@ -77,7 +77,7 @@ pub fn znx_sub_ab_inplace_avx(res: &mut [i64], a: &[i64]) {
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_sub_ba_inplace_avx(res: &mut [i64], a: &[i64]) {
pub fn znx_sub_negate_inplace_avx(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
@@ -103,8 +103,8 @@ pub fn znx_sub_ba_inplace_avx(res: &mut [i64], a: &[i64]) {
// tail
if !res.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_sub_ba_inplace_ref;
use poulpy_hal::reference::znx::znx_sub_negate_inplace_ref;
znx_sub_ba_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
znx_sub_negate_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
}
}
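The rename from `sub_ab`/`sub_ba` to `sub`/`sub_negate` puts the direction of the operation into the name: `*_sub_inplace` computes `res -= a`, while `*_sub_negate_inplace` computes `res = a - res` (the same convention spelled out for the `VecZnxBig*` variants later in this commit). A scalar model of those two shapes, using plain `i64` subtraction (the real kernels' overflow behaviour is not claimed here):

/// res[i] <- res[i] - a[i], the shape behind the `*_sub_inplace` names.
fn sub_inplace_model(res: &mut [i64], a: &[i64]) {
    assert_eq!(res.len(), a.len());
    for (r, &x) in res.iter_mut().zip(a) {
        *r -= x;
    }
}

/// res[i] <- a[i] - res[i], the shape behind the `*_sub_negate_inplace` names.
fn sub_negate_inplace_model(res: &mut [i64], a: &[i64]) {
    assert_eq!(res.len(), a.len());
    for (r, &x) in res.iter_mut().zip(a) {
        *r = x - *r;
    }
}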

View File

@@ -1,10 +1,11 @@
use poulpy_hal::reference::fft64::{
reim::{
ReimAdd, ReimAddInplace, ReimAddMul, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimMul,
ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubABInplace, ReimSubBAInplace, ReimToZnx, ReimToZnxInplace,
ReimZero, fft_ref, ifft_ref, reim_add_inplace_ref, reim_add_ref, reim_addmul_ref, reim_copy_ref, reim_from_znx_i64_ref,
reim_mul_inplace_ref, reim_mul_ref, reim_negate_inplace_ref, reim_negate_ref, reim_sub_ab_inplace_ref,
reim_sub_ba_inplace_ref, reim_sub_ref, reim_to_znx_i64_inplace_ref, reim_to_znx_i64_ref, reim_zero_ref,
ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubInplace, ReimSubNegateInplace, ReimToZnx,
ReimToZnxInplace, ReimZero, fft_ref, ifft_ref, reim_add_inplace_ref, reim_add_ref, reim_addmul_ref, reim_copy_ref,
reim_from_znx_i64_ref, reim_mul_inplace_ref, reim_mul_ref, reim_negate_inplace_ref, reim_negate_ref,
reim_sub_inplace_ref, reim_sub_negate_inplace_ref, reim_sub_ref, reim_to_znx_i64_inplace_ref, reim_to_znx_i64_ref,
reim_zero_ref,
},
reim4::{
Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks,
@@ -69,17 +70,17 @@ impl ReimSub for FFT64Ref {
}
}
impl ReimSubABInplace for FFT64Ref {
impl ReimSubInplace for FFT64Ref {
#[inline(always)]
fn reim_sub_ab_inplace(res: &mut [f64], a: &[f64]) {
reim_sub_ab_inplace_ref(res, a);
fn reim_sub_inplace(res: &mut [f64], a: &[f64]) {
reim_sub_inplace_ref(res, a);
}
}
impl ReimSubBAInplace for FFT64Ref {
impl ReimSubNegateInplace for FFT64Ref {
#[inline(always)]
fn reim_sub_ba_inplace(res: &mut [f64], a: &[f64]) {
reim_sub_ba_inplace_ref(res, a);
fn reim_sub_negate_inplace(res: &mut [f64], a: &[f64]) {
reim_sub_negate_inplace_ref(res, a);
}
}

View File

@@ -253,9 +253,6 @@ fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]
(take_slice, rem_slice)
}
} else {
panic!(
"Attempted to take {} from scratch with {} aligned bytes left",
take_len, aligned_len,
);
panic!("Attempted to take {take_len} from scratch with {aligned_len} aligned bytes left");
}
}
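The panic above fires when a caller requests more scratch than remains. Throughout this commit the request size comes from the matching `*_tmp_bytes` query, converted from bytes to `i64` slots. A hypothetical helper illustrating that pattern (`take_normalize_carry` and its exact bounds are assumptions, not part of the crate):

use poulpy_hal::{
    api::{TakeSlice, VecZnxNormalizeTmpBytes},
    layouts::{Backend, Module, Scratch},
};

/// Hypothetical: carve out the i64 carry buffer the normalize kernels expect,
/// sized from the module's own tmp-bytes query. Oversized requests end up in
/// the panic above.
fn take_normalize_carry<'a, B: Backend>(
    module: &Module<B>,
    scratch: &'a mut Scratch<B>,
) -> &'a mut [i64]
where
    Module<B>: VecZnxNormalizeTmpBytes,
    Scratch<B>: TakeSlice,
{
    let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
    carry
}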

View File

@@ -1,7 +1,8 @@
use poulpy_hal::{
api::{
TakeSlice, VecZnxAutomorphismInplaceTmpBytes, VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOneInplaceTmpBytes,
VecZnxNormalizeTmpBytes, VecZnxRotateInplaceTmpBytes, VecZnxSplitRingTmpBytes,
TakeSlice, VecZnxAutomorphismInplaceTmpBytes, VecZnxLshTmpBytes, VecZnxMergeRingsTmpBytes,
VecZnxMulXpMinusOneInplaceTmpBytes, VecZnxNormalizeTmpBytes, VecZnxRotateInplaceTmpBytes, VecZnxRshTmpBytes,
VecZnxSplitRingTmpBytes,
},
layouts::{Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
oep::{
@@ -12,7 +13,7 @@ use poulpy_hal::{
VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl,
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl,
},
reference::vec_znx::{
@@ -23,7 +24,7 @@ use poulpy_hal::{
vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_negate, vec_znx_negate_inplace, vec_znx_normalize,
vec_znx_normalize_inplace, vec_znx_normalize_tmp_bytes, vec_znx_rotate, vec_znx_rotate_inplace,
vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring,
vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace, vec_znx_sub_scalar,
vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace, vec_znx_sub_scalar,
vec_znx_sub_scalar_inplace, vec_znx_switch_ring,
},
source::Source,
@@ -43,9 +44,10 @@ where
{
fn vec_znx_normalize_impl<R, A>(
module: &Module<Self>,
basek: usize,
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<Self>,
@@ -54,7 +56,7 @@ where
A: VecZnxToRef,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_normalize::<R, A, Self>(basek, res, res_col, a, a_col, carry);
vec_znx_normalize::<R, A, Self>(res_basek, res, res_col, a_basek, a, a_col, carry);
}
}
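The extra `a_basek` parameter is what makes cross-base2k normalization possible: the source column is read at base `2^a_basek` and re-encoded at base `2^res_basek`. A sketch of a caller using the free reference routine directly; the wrapper name `renormalize_column` and the single shared column index are illustrative choices, and `carry` is sized exactly as in the impl above:

use poulpy_hal::{
    layouts::{VecZnxToMut, VecZnxToRef},
    reference::vec_znx::vec_znx_normalize,
};
use crate::cpu_fft64_ref::FFT64Ref;

/// Hypothetical: re-encode column `col` of `a` (limbs at base 2^a_base2k) into
/// `res` (limbs at base 2^res_base2k). `carry` is the scratch buffer obtained from
/// `vec_znx_normalize_tmp_bytes`, as in the impl above.
fn renormalize_column<R, A>(
    res_base2k: usize,
    res: &mut R,
    a_base2k: usize,
    a: &A,
    col: usize,
    carry: &mut [i64],
) where
    R: VecZnxToMut,
    A: VecZnxToRef,
{
    vec_znx_normalize::<R, A, FFT64Ref>(res_base2k, res, col, a_base2k, a, col, carry);
}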
@@ -64,7 +66,7 @@ where
{
fn vec_znx_normalize_inplace_impl<R>(
module: &Module<Self>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
scratch: &mut Scratch<Self>,
@@ -72,7 +74,7 @@ where
R: VecZnxToMut,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_normalize_inplace::<R, Self>(basek, res, res_col, carry);
vec_znx_normalize_inplace::<R, Self>(base2k, res, res_col, carry);
}
}
@@ -143,23 +145,23 @@ unsafe impl VecZnxSubImpl<Self> for FFT64Ref {
}
}
unsafe impl VecZnxSubABInplaceImpl<Self> for FFT64Ref {
fn vec_znx_sub_ab_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxSubInplaceImpl<Self> for FFT64Ref {
fn vec_znx_sub_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_sub_ab_inplace::<R, A, Self>(res, res_col, a, a_col);
vec_znx_sub_inplace::<R, A, Self>(res, res_col, a, a_col);
}
}
unsafe impl VecZnxSubBAInplaceImpl<Self> for FFT64Ref {
fn vec_znx_sub_ba_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxSubNegateInplaceImpl<Self> for FFT64Ref {
fn vec_znx_sub_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_sub_ba_inplace::<R, A, Self>(res, res_col, a, a_col);
vec_znx_sub_negate_inplace::<R, A, Self>(res, res_col, a, a_col);
}
}
@@ -234,9 +236,9 @@ where
Module<Self>: VecZnxNormalizeTmpBytes,
Scratch<Self>: TakeSlice,
{
fn vec_znx_lsh_inplace_impl<R, A>(
fn vec_znx_lsh_impl<R, A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -247,8 +249,8 @@ where
R: VecZnxToMut,
A: VecZnxToRef,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_lsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry);
let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::<i64>());
vec_znx_lsh::<_, _, Self>(base2k, k, res, res_col, a, a_col, carry);
}
}
@@ -259,7 +261,7 @@ where
{
fn vec_znx_lsh_inplace_impl<A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
a: &mut A,
a_col: usize,
@@ -267,8 +269,8 @@ where
) where
A: VecZnxToMut,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_lsh_inplace::<_, Self>(basek, k, a, a_col, carry);
let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::<i64>());
vec_znx_lsh_inplace::<_, Self>(base2k, k, a, a_col, carry);
}
}
@@ -277,9 +279,9 @@ where
Module<Self>: VecZnxNormalizeTmpBytes,
Scratch<Self>: TakeSlice,
{
fn vec_znx_rsh_inplace_impl<R, A>(
fn vec_znx_rsh_impl<R, A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -290,8 +292,8 @@ where
R: VecZnxToMut,
A: VecZnxToRef,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_rsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry);
let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::<i64>());
vec_znx_rsh::<_, _, Self>(base2k, k, res, res_col, a, a_col, carry);
}
}
@@ -302,7 +304,7 @@ where
{
fn vec_znx_rsh_inplace_impl<A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
a: &mut A,
a_col: usize,
@@ -310,8 +312,8 @@ where
) where
A: VecZnxToMut,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_rsh_inplace::<_, Self>(basek, k, a, a_col, carry);
let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::<i64>());
vec_znx_rsh_inplace::<_, Self>(base2k, k, a, a_col, carry);
}
}
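`vec_znx_lsh`/`vec_znx_rsh` shift a base-2^base2k limb decomposition by `k` bits, which naturally splits into whole-limb moves plus a residual in-limb shift smaller than `base2k`; that residual is plausibly the `lsh` the normalize-step kernels assert to be `< base2k`. A trivial helper making the split explicit (the `split_shift` name and this interpretation are assumptions, not taken from the crate):

/// Split a k-bit shift into (whole limbs, remaining bits) for limbs of base2k bits.
fn split_shift(base2k: usize, k: usize) -> (usize, usize) {
    assert!(base2k > 0);
    (k / base2k, k % base2k)
}

// e.g. with base2k = 12, a shift by 30 bits is 2 whole limbs plus 6 in-limb bits:
// assert_eq!(split_shift(12, 30), (2, 6));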
@@ -495,18 +497,18 @@ unsafe impl VecZnxCopyImpl<Self> for FFT64Ref {
}
unsafe impl VecZnxFillUniformImpl<Self> for FFT64Ref {
fn vec_znx_fill_uniform_impl<R>(_module: &Module<Self>, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
fn vec_znx_fill_uniform_impl<R>(_module: &Module<Self>, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut,
{
vec_znx_fill_uniform_ref(basek, res, res_col, source)
vec_znx_fill_uniform_ref(base2k, res, res_col, source)
}
}
unsafe impl VecZnxFillNormalImpl<Self> for FFT64Ref {
fn vec_znx_fill_normal_impl<R>(
_module: &Module<Self>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -516,14 +518,14 @@ unsafe impl VecZnxFillNormalImpl<Self> for FFT64Ref {
) where
R: VecZnxToMut,
{
vec_znx_fill_normal_ref(basek, res, res_col, k, sigma, bound, source);
vec_znx_fill_normal_ref(base2k, res, res_col, k, sigma, bound, source);
}
}
unsafe impl VecZnxAddNormalImpl<Self> for FFT64Ref {
fn vec_znx_add_normal_impl<R>(
_module: &Module<Self>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -533,6 +535,6 @@ unsafe impl VecZnxAddNormalImpl<Self> for FFT64Ref {
) where
R: VecZnxToMut,
{
vec_znx_add_normal_ref(basek, res, res_col, k, sigma, bound, source);
vec_znx_add_normal_ref(base2k, res, res_col, k, sigma, bound, source);
}
}

View File

@@ -10,15 +10,15 @@ use poulpy_hal::{
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl,
VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl,
VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl,
VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubImpl, VecZnxBigSubInplaceImpl, VecZnxBigSubNegateInplaceImpl,
VecZnxBigSubSmallAImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallInplaceImpl, VecZnxBigSubSmallNegateInplaceImpl,
},
reference::{
fft64::vec_znx_big::{
vec_znx_big_add, vec_znx_big_add_inplace, vec_znx_big_add_normal_ref, vec_znx_big_add_small,
vec_znx_big_add_small_inplace, vec_znx_big_automorphism, vec_znx_big_automorphism_inplace,
vec_znx_big_automorphism_inplace_tmp_bytes, vec_znx_big_negate, vec_znx_big_negate_inplace, vec_znx_big_normalize,
vec_znx_big_normalize_tmp_bytes, vec_znx_big_sub, vec_znx_big_sub_ab_inplace, vec_znx_big_sub_ba_inplace,
vec_znx_big_normalize_tmp_bytes, vec_znx_big_sub, vec_znx_big_sub_inplace, vec_znx_big_sub_negate_inplace,
vec_znx_big_sub_small_a, vec_znx_big_sub_small_a_inplace, vec_znx_big_sub_small_b, vec_znx_big_sub_small_b_inplace,
},
znx::{znx_copy_ref, znx_zero_ref},
@@ -76,7 +76,7 @@ unsafe impl VecZnxBigFromSmallImpl<Self> for FFT64Ref {
unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Ref {
fn add_normal_impl<R: VecZnxBigToMut<Self>>(
_module: &Module<Self>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -84,7 +84,7 @@ unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Ref {
sigma: f64,
bound: f64,
) {
vec_znx_big_add_normal_ref(basek, res, res_col, k, sigma, bound, source);
vec_znx_big_add_normal_ref(base2k, res, res_col, k, sigma, bound, source);
}
}
@@ -167,25 +167,25 @@ unsafe impl VecZnxBigSubImpl<Self> for FFT64Ref {
}
}
unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64Ref {
unsafe impl VecZnxBigSubInplaceImpl<Self> for FFT64Ref {
/// Subtracts `a` from `res` and stores the result on `res` (`res -= a`).
fn vec_znx_big_sub_ab_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
{
vec_znx_big_sub_ab_inplace(res, res_col, a, a_col);
vec_znx_big_sub_inplace(res, res_col, a, a_col);
}
}
unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64Ref {
unsafe impl VecZnxBigSubNegateInplaceImpl<Self> for FFT64Ref {
/// Subtracts `res` from `a` and stores the result on `res` (`res = a - res`).
fn vec_znx_big_sub_ba_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
{
vec_znx_big_sub_ba_inplace(res, res_col, a, a_col);
vec_znx_big_sub_negate_inplace(res, res_col, a, a_col);
}
}
@@ -208,9 +208,9 @@ unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64Ref {
}
}
unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64Ref {
unsafe impl VecZnxBigSubSmallInplaceImpl<Self> for FFT64Ref {
/// Subtracts `a` from `res` and stores the result on `res`.
fn vec_znx_big_sub_small_a_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_small_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
@@ -238,9 +238,9 @@ unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64Ref {
}
}
unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64Ref {
unsafe impl VecZnxBigSubSmallNegateInplaceImpl<Self> for FFT64Ref {
/// Subtracts `res` from `a` and stores the result on `res`.
fn vec_znx_big_sub_small_b_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_small_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
@@ -280,9 +280,10 @@ where
{
fn vec_znx_big_normalize_impl<R, A>(
module: &Module<Self>,
basek: usize,
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<Self>,
@@ -291,7 +292,7 @@ where
A: VecZnxBigToRef<Self>,
{
let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_big_normalize(basek, res, res_col, a, a_col, carry);
vec_znx_big_normalize(res_basek, res, res_col, a_basek, a, a_col, carry);
}
}
@@ -326,7 +327,7 @@ where
) where
R: VecZnxBigToMut<Self>,
{
let (tmp, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
let (tmp, _) = scratch.take_slice(module.vec_znx_big_automorphism_inplace_tmp_bytes() / size_of::<i64>());
vec_znx_big_automorphism_inplace(p, res, res_col, tmp);
}
}

View File

@@ -5,12 +5,12 @@ use poulpy_hal::{
},
oep::{
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl,
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl,
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl,
VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
},
reference::fft64::vec_znx_dft::{
vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub,
vec_znx_dft_sub_ab_inplace, vec_znx_dft_sub_ba_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume,
vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub, vec_znx_dft_sub_inplace,
vec_znx_dft_sub_negate_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume,
vec_znx_idft_apply_tmpa,
},
};
@@ -139,23 +139,23 @@ unsafe impl VecZnxDftSubImpl<Self> for FFT64Ref {
}
}
unsafe impl VecZnxDftSubABInplaceImpl<Self> for FFT64Ref {
fn vec_znx_dft_sub_ab_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxDftSubInplaceImpl<Self> for FFT64Ref {
fn vec_znx_dft_sub_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
{
vec_znx_dft_sub_ab_inplace(res, res_col, a, a_col);
vec_znx_dft_sub_inplace(res, res_col, a, a_col);
}
}
unsafe impl VecZnxDftSubBAInplaceImpl<Self> for FFT64Ref {
fn vec_znx_dft_sub_ba_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxDftSubNegateInplaceImpl<Self> for FFT64Ref {
fn vec_znx_dft_sub_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
{
vec_znx_dft_sub_ba_inplace(res, res_col, a, a_col);
vec_znx_dft_sub_negate_inplace(res, res_col, a, a_col);
}
}

View File

@@ -18,21 +18,21 @@ unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64Ref
where
Self: TakeSliceImpl<Self>,
{
fn zn_normalize_inplace_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<Self>)
fn zn_normalize_inplace_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<Self>)
where
R: ZnToMut,
{
let (carry, _) = scratch.take_slice(n);
zn_normalize_inplace::<R, FFT64Ref>(n, basek, res, res_col, carry);
zn_normalize_inplace::<R, FFT64Ref>(n, base2k, res, res_col, carry);
}
}
unsafe impl ZnFillUniformImpl<Self> for FFT64Ref {
fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
fn zn_fill_uniform_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
zn_fill_uniform(n, basek, res, res_col, source);
zn_fill_uniform(n, base2k, res, res_col, source);
}
}
@@ -40,7 +40,7 @@ unsafe impl ZnFillNormalImpl<Self> for FFT64Ref {
#[allow(clippy::too_many_arguments)]
fn zn_fill_normal_impl<R>(
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -50,7 +50,7 @@ unsafe impl ZnFillNormalImpl<Self> for FFT64Ref {
) where
R: ZnToMut,
{
zn_fill_normal(n, basek, res, res_col, k, source, sigma, bound);
zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}
@@ -58,7 +58,7 @@ unsafe impl ZnAddNormalImpl<Self> for FFT64Ref {
#[allow(clippy::too_many_arguments)]
fn zn_add_normal_impl<R>(
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -68,6 +68,6 @@ unsafe impl ZnAddNormalImpl<Self> for FFT64Ref {
) where
R: ZnToMut,
{
zn_add_normal(n, basek, res, res_col, k, source, sigma, bound);
zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}

View File

@@ -1,12 +1,14 @@
use poulpy_hal::reference::znx::{
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace,
ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxRotate, ZnxSub, ZnxSubABInplace,
ZnxSubBAInplace, ZnxSwitchRing, ZnxZero, znx_add_inplace_ref, znx_add_ref, znx_automorphism_ref, znx_copy_ref,
znx_negate_inplace_ref, znx_negate_ref, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref,
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulAddPowerOfTwo, ZnxMulPowerOfTwo,
ZnxMulPowerOfTwoInplace, ZnxNegate, ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace,
ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep,
ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxRotate, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace,
ZnxSwitchRing, ZnxZero, znx_add_inplace_ref, znx_add_ref, znx_automorphism_ref, znx_copy_ref, znx_extract_digit_addmul_ref,
znx_mul_add_power_of_two_ref, znx_mul_power_of_two_inplace_ref, znx_mul_power_of_two_ref, znx_negate_inplace_ref,
znx_negate_ref, znx_normalize_digit_ref, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref,
znx_normalize_first_step_carry_only_ref, znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref,
znx_normalize_middle_step_carry_only_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref, znx_rotate,
znx_sub_ab_inplace_ref, znx_sub_ba_inplace_ref, znx_sub_ref, znx_switch_ring_ref, znx_zero_ref,
znx_sub_inplace_ref, znx_sub_negate_inplace_ref, znx_sub_ref, znx_switch_ring_ref, znx_zero_ref,
};
use crate::cpu_fft64_ref::FFT64Ref;
@@ -32,17 +34,38 @@ impl ZnxSub for FFT64Ref {
}
}
impl ZnxSubABInplace for FFT64Ref {
impl ZnxSubInplace for FFT64Ref {
#[inline(always)]
fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_ab_inplace_ref(res, a);
fn znx_sub_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_inplace_ref(res, a);
}
}
impl ZnxSubBAInplace for FFT64Ref {
impl ZnxSubNegateInplace for FFT64Ref {
#[inline(always)]
fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_ba_inplace_ref(res, a);
fn znx_sub_negate_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_negate_inplace_ref(res, a);
}
}
impl ZnxMulAddPowerOfTwo for FFT64Ref {
#[inline(always)]
fn znx_muladd_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
znx_mul_add_power_of_two_ref(k, res, a);
}
}
impl ZnxMulPowerOfTwo for FFT64Ref {
#[inline(always)]
fn znx_mul_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
znx_mul_power_of_two_ref(k, res, a);
}
}
impl ZnxMulPowerOfTwoInplace for FFT64Ref {
#[inline(always)]
fn znx_mul_power_of_two_inplace(k: i64, res: &mut [i64]) {
znx_mul_power_of_two_inplace_ref(k, res);
}
}
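`ZnxMulPowerOfTwo`, `ZnxMulPowerOfTwoInplace` and `ZnxMulAddPowerOfTwo` take a signed exponent `k`. A toy model of one plausible convention, left shift for `k >= 0` and arithmetic right shift otherwise; the reference implementations may round or handle overflow differently, so only the shape of the operation is meant here:

/// Toy model of res[i] <- a[i] * 2^k for signed k.
fn mul_power_of_two_model(k: i64, res: &mut [i64], a: &[i64]) {
    assert_eq!(res.len(), a.len());
    for (r, &x) in res.iter_mut().zip(a) {
        *r = if k >= 0 { x << k } else { x >> -k };
    }
}

/// Toy model of res[i] <- res[i] + a[i] * 2^k.
fn muladd_power_of_two_model(k: i64, res: &mut [i64], a: &[i64]) {
    assert_eq!(res.len(), a.len());
    for (r, &x) in res.iter_mut().zip(a) {
        *r += if k >= 0 { x << k } else { x >> -k };
    }
}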
@@ -97,56 +120,70 @@ impl ZnxSwitchRing for FFT64Ref {
impl ZnxNormalizeFinalStep for FFT64Ref {
#[inline(always)]
fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_final_step_ref(basek, lsh, x, a, carry);
fn znx_normalize_final_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_final_step_ref(base2k, lsh, x, a, carry);
}
}
impl ZnxNormalizeFinalStepInplace for FFT64Ref {
#[inline(always)]
fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_final_step_inplace_ref(basek, lsh, x, carry);
fn znx_normalize_final_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_final_step_inplace_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeFirstStep for FFT64Ref {
#[inline(always)]
fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_ref(basek, lsh, x, a, carry);
fn znx_normalize_first_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_ref(base2k, lsh, x, a, carry);
}
}
impl ZnxNormalizeFirstStepCarryOnly for FFT64Ref {
#[inline(always)]
fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_carry_only_ref(basek, lsh, x, carry);
fn znx_normalize_first_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_carry_only_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeFirstStepInplace for FFT64Ref {
#[inline(always)]
fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_first_step_inplace_ref(basek, lsh, x, carry);
fn znx_normalize_first_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_first_step_inplace_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeMiddleStep for FFT64Ref {
#[inline(always)]
fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_ref(basek, lsh, x, a, carry);
fn znx_normalize_middle_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_ref(base2k, lsh, x, a, carry);
}
}
impl ZnxNormalizeMiddleStepCarryOnly for FFT64Ref {
#[inline(always)]
fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_carry_only_ref(basek, lsh, x, carry);
fn znx_normalize_middle_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_carry_only_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeMiddleStepInplace for FFT64Ref {
#[inline(always)]
fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_middle_step_inplace_ref(basek, lsh, x, carry);
fn znx_normalize_middle_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_middle_step_inplace_ref(base2k, lsh, x, carry);
}
}
impl ZnxExtractDigitAddMul for FFT64Ref {
#[inline(always)]
fn znx_extract_digit_addmul(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) {
znx_extract_digit_addmul_ref(base2k, lsh, res, src);
}
}
impl ZnxNormalizeDigit for FFT64Ref {
#[inline(always)]
fn znx_normalize_digit(base2k: usize, res: &mut [i64], src: &mut [i64]) {
znx_normalize_digit_ref(base2k, res, src);
}
}
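`ZnxExtractDigitAddMul` and `ZnxNormalizeDigit` both revolve around the balanced base-2^k digit/carry split that `get_digit_i64`/`get_carry_i64` exercise in the AVX tests earlier in this commit. A scalar sketch of that split as plain arithmetic, not as the crate's exact definitions:

/// Balanced split of x into (digit, carry) with x == digit + (carry << base2k) and
/// digit in [-2^(base2k-1), 2^(base2k-1)), for 0 < base2k < 63. The crate's
/// get_digit_i64 / get_carry_i64 may differ in edge-case handling.
fn balanced_split(base2k: usize, x: i64) -> (i64, i64) {
    let modulus = 1i64 << base2k;
    let half = modulus >> 1;
    // Centered remainder of x modulo 2^base2k.
    let digit = ((x & (modulus - 1)) + half).rem_euclid(modulus) - half;
    // x - digit is an exact multiple of 2^base2k, so the shift is an exact division.
    let carry = (x - digit) >> base2k;
    debug_assert_eq!(x, digit + (carry << base2k));
    (digit, carry)
}

// e.g. balanced_split(12, 5000) == (904, 1) and balanced_split(12, 4000) == (-96, 1).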

View File

@@ -6,5 +6,6 @@ mod vec_znx_big;
mod vec_znx_dft;
mod vmp_pmat;
mod zn;
mod znx;
pub struct FFT64Spqlios;

View File

@@ -3,20 +3,11 @@ use std::ptr::NonNull;
use poulpy_hal::{
layouts::{Backend, Module},
oep::ModuleNewImpl,
reference::znx::{
ZnxCopy, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
ZnxRotate, ZnxSwitchRing, ZnxZero, znx_copy_ref, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref,
znx_normalize_first_step_carry_only_ref, znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref,
znx_normalize_middle_step_carry_only_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref,
znx_switch_ring_ref, znx_zero_ref,
},
};
use crate::cpu_spqlios::{
FFT64Spqlios,
ffi::module::{MODULE, delete_module_info, new_module_info},
znx::znx_rotate_i64,
};
impl Backend for FFT64Spqlios {
@@ -41,85 +32,3 @@ unsafe impl ModuleNewImpl<Self> for FFT64Spqlios {
unsafe { Module::from_raw_parts(new_module_info(n, 0), n) }
}
}
impl ZnxCopy for FFT64Spqlios {
fn znx_copy(res: &mut [i64], a: &[i64]) {
znx_copy_ref(res, a);
}
}
impl ZnxZero for FFT64Spqlios {
fn znx_zero(res: &mut [i64]) {
znx_zero_ref(res);
}
}
impl ZnxSwitchRing for FFT64Spqlios {
fn znx_switch_ring(res: &mut [i64], a: &[i64]) {
znx_switch_ring_ref(res, a);
}
}
impl ZnxRotate for FFT64Spqlios {
fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]) {
unsafe {
znx_rotate_i64(res.len() as u64, p, res.as_mut_ptr(), src.as_ptr());
}
}
}
impl ZnxNormalizeFinalStep for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_final_step_ref(basek, lsh, x, a, carry);
}
}
impl ZnxNormalizeFinalStepInplace for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_final_step_inplace_ref(basek, lsh, x, carry);
}
}
impl ZnxNormalizeFirstStep for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_ref(basek, lsh, x, a, carry);
}
}
impl ZnxNormalizeFirstStepCarryOnly for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_carry_only_ref(basek, lsh, x, carry);
}
}
impl ZnxNormalizeFirstStepInplace for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_first_step_inplace_ref(basek, lsh, x, carry);
}
}
impl ZnxNormalizeMiddleStep for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_ref(basek, lsh, x, a, carry);
}
}
impl ZnxNormalizeMiddleStepCarryOnly for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_carry_only_ref(basek, lsh, x, carry);
}
}
impl ZnxNormalizeMiddleStepInplace for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_middle_step_inplace_ref(basek, lsh, x, carry);
}
}

View File

@@ -253,9 +253,6 @@ fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]
(take_slice, rem_slice)
}
} else {
panic!(
"Attempted to take {} from scratch with {} aligned bytes left",
take_len, aligned_len,
);
panic!("Attempted to take {take_len} from scratch with {aligned_len} aligned bytes left");
}
}

View File

@@ -1,5 +1,8 @@
use poulpy_hal::{
api::{TakeSlice, VecZnxMergeRingsTmpBytes, VecZnxNormalizeTmpBytes, VecZnxSplitRingTmpBytes},
api::{
TakeSlice, VecZnxLshTmpBytes, VecZnxMergeRingsTmpBytes, VecZnxNormalizeTmpBytes, VecZnxRshTmpBytes,
VecZnxSplitRingTmpBytes,
},
layouts::{
Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut,
},
@@ -11,16 +14,16 @@ use poulpy_hal::{
VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl,
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl,
},
reference::{
vec_znx::{
vec_znx_add_normal_ref, vec_znx_automorphism_inplace_tmp_bytes, vec_znx_copy, vec_znx_fill_normal_ref,
vec_znx_fill_uniform_ref, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_lsh_tmp_bytes, vec_znx_merge_rings,
vec_znx_merge_rings_tmp_bytes, vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_rotate_inplace_tmp_bytes,
vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring, vec_znx_split_ring_tmp_bytes,
vec_znx_switch_ring,
vec_znx_merge_rings_tmp_bytes, vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_normalize_tmp_bytes,
vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring,
vec_znx_split_ring_tmp_bytes, vec_znx_switch_ring,
},
znx::{znx_copy_ref, znx_zero_ref},
},
@@ -34,7 +37,7 @@ use crate::cpu_spqlios::{
unsafe impl VecZnxNormalizeTmpBytesImpl<Self> for FFT64Spqlios {
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr() as *const module_info_t) as usize }
vec_znx_normalize_tmp_bytes(module.n())
}
}
@@ -44,9 +47,10 @@ where
{
fn vec_znx_normalize_impl<R, A>(
module: &Module<Self>,
basek: usize,
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<Self>,
@@ -60,6 +64,10 @@ where
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), a.n());
assert_eq!(
res_basek, a_basek,
"res_basek != a_basek -> base2k conversion is not supported"
)
}
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes());
@@ -67,7 +75,7 @@ where
unsafe {
vec_znx::vec_znx_normalize_base2k(
module.ptr() as *const module_info_t,
basek as u64,
res_basek as u64,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
@@ -86,7 +94,7 @@ where
{
fn vec_znx_normalize_inplace_impl<A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
a: &mut A,
a_col: usize,
scratch: &mut Scratch<Self>,
@@ -100,7 +108,7 @@ where
unsafe {
vec_znx::vec_znx_normalize_base2k(
module.ptr() as *const module_info_t,
basek as u64,
base2k as u64,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
@@ -301,8 +309,8 @@ unsafe impl VecZnxSubImpl<Self> for FFT64Spqlios {
}
}
unsafe impl VecZnxSubABInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxSubInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_sub_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
@@ -330,8 +338,8 @@ unsafe impl VecZnxSubABInplaceImpl<Self> for FFT64Spqlios {
}
}
unsafe impl VecZnxSubBAInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxSubNegateInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_sub_negate_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
@@ -512,9 +520,9 @@ where
Module<Self>: VecZnxNormalizeTmpBytes,
Scratch<Self>: TakeSlice,
{
fn vec_znx_lsh_inplace_impl<R, A>(
fn vec_znx_lsh_impl<R, A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -525,8 +533,8 @@ where
R: VecZnxToMut,
A: VecZnxToRef,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_lsh::<_, _, FFT64Spqlios>(basek, k, res, res_col, a, a_col, carry)
let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::<i64>());
vec_znx_lsh::<_, _, FFT64Spqlios>(base2k, k, res, res_col, a, a_col, carry)
}
}
@@ -537,7 +545,7 @@ where
{
fn vec_znx_lsh_inplace_impl<A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
a: &mut A,
a_col: usize,
@@ -545,8 +553,8 @@ where
) where
A: VecZnxToMut,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_lsh_inplace::<_, FFT64Spqlios>(basek, k, a, a_col, carry)
let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::<i64>());
vec_znx_lsh_inplace::<_, FFT64Spqlios>(base2k, k, a, a_col, carry)
}
}
@@ -555,9 +563,9 @@ where
Module<Self>: VecZnxNormalizeTmpBytes,
Scratch<Self>: TakeSlice,
{
fn vec_znx_rsh_inplace_impl<R, A>(
fn vec_znx_rsh_impl<R, A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -568,8 +576,8 @@ where
R: VecZnxToMut,
A: VecZnxToRef,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_rsh::<_, _, FFT64Spqlios>(basek, k, res, res_col, a, a_col, carry)
let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::<i64>());
vec_znx_rsh::<_, _, FFT64Spqlios>(base2k, k, res, res_col, a, a_col, carry)
}
}
@@ -580,7 +588,7 @@ where
{
fn vec_znx_rsh_inplace_impl<A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
a: &mut A,
a_col: usize,
@@ -588,8 +596,8 @@ where
) where
A: VecZnxToMut,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_rsh_inplace::<_, FFT64Spqlios>(basek, k, a, a_col, carry)
let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::<i64>());
vec_znx_rsh_inplace::<_, FFT64Spqlios>(base2k, k, a, a_col, carry)
}
}
@@ -690,11 +698,7 @@ unsafe impl VecZnxAutomorphismInplaceImpl<Self> for FFT64Spqlios {
let mut a: VecZnx<&mut [u8]> = a.to_mut();
#[cfg(debug_assertions)]
{
assert!(
k & 1 != 0,
"invalid galois element: must be odd but is {}",
k
);
assert!(k & 1 != 0, "invalid galois element: must be odd but is {k}");
}
unsafe {
vec_znx::vec_znx_automorphism(
@@ -852,18 +856,18 @@ unsafe impl VecZnxCopyImpl<Self> for FFT64Spqlios {
}
unsafe impl VecZnxFillUniformImpl<Self> for FFT64Spqlios {
fn vec_znx_fill_uniform_impl<R>(_module: &Module<Self>, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
fn vec_znx_fill_uniform_impl<R>(_module: &Module<Self>, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut,
{
vec_znx_fill_uniform_ref(basek, res, res_col, source)
vec_znx_fill_uniform_ref(base2k, res, res_col, source)
}
}
unsafe impl VecZnxFillNormalImpl<Self> for FFT64Spqlios {
fn vec_znx_fill_normal_impl<R>(
_module: &Module<Self>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -873,14 +877,14 @@ unsafe impl VecZnxFillNormalImpl<Self> for FFT64Spqlios {
) where
R: VecZnxToMut,
{
vec_znx_fill_normal_ref(basek, res, res_col, k, sigma, bound, source);
vec_znx_fill_normal_ref(base2k, res, res_col, k, sigma, bound, source);
}
}
unsafe impl VecZnxAddNormalImpl<Self> for FFT64Spqlios {
fn vec_znx_add_normal_impl<R>(
_module: &Module<Self>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -890,6 +894,6 @@ unsafe impl VecZnxAddNormalImpl<Self> for FFT64Spqlios {
) where
R: VecZnxToMut,
{
vec_znx_add_normal_ref(basek, res, res_col, k, sigma, bound, source);
vec_znx_add_normal_ref(base2k, res, res_col, k, sigma, bound, source);
}
}
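
`vec_znx_fill_uniform`, `vec_znx_fill_normal` and `vec_znx_add_normal` only pick up the `base2k` rename here. As a rough picture of what filling a column with uniform digits means at base 2^base2k (the centered digit range and the stand-in generator below are assumptions; the crate draws from its own `Source`):

```rust
// Illustrative stand-in for a randomness source; not the crate's `Source` type.
fn xorshift64(state: &mut u64) -> u64 {
    let mut x = *state;
    x ^= x << 13;
    x ^= x >> 7;
    x ^= x << 17;
    *state = x;
    x
}

// Assumed semantics: each digit drawn uniformly from [-2^(base2k-1), 2^(base2k-1)).
fn fill_uniform(base2k: u32, digits: &mut [i64], state: &mut u64) {
    let half = 1i64 << (base2k - 1);
    for d in digits.iter_mut() {
        *d = (xorshift64(state) % (1u64 << base2k)) as i64 - half;
    }
}

fn main() {
    let mut state = 0x1234_5678_9abc_def0_u64;
    let mut digits = vec![0_i64; 8];
    fill_uniform(12, &mut digits, &mut state);
    assert!(digits.iter().all(|&d| (-2048..2048).contains(&d)));
    println!("{digits:?}");
}
```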

View File

@@ -10,11 +10,12 @@ use poulpy_hal::{
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl,
VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl,
VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl,
VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubImpl, VecZnxBigSubInplaceImpl, VecZnxBigSubNegateInplaceImpl,
VecZnxBigSubSmallAImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallInplaceImpl, VecZnxBigSubSmallNegateInplaceImpl,
},
reference::{
vec_znx::vec_znx_add_normal_ref,
fft64::vec_znx_big::vec_znx_big_normalize,
vec_znx::{vec_znx_add_normal_ref, vec_znx_normalize_tmp_bytes},
znx::{znx_copy_ref, znx_zero_ref},
},
source::Source,
@@ -70,7 +71,7 @@ unsafe impl VecZnxBigFromSmallImpl<Self> for FFT64Spqlios {
unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Spqlios {
fn add_normal_impl<R: VecZnxBigToMut<Self>>(
_module: &Module<Self>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -88,7 +89,7 @@ unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Spqlios {
max_size: res.max_size,
};
vec_znx_add_normal_ref(basek, &mut res_znx, res_col, k, sigma, bound, source);
vec_znx_add_normal_ref(base2k, &mut res_znx, res_col, k, sigma, bound, source);
}
}
@@ -266,9 +267,9 @@ unsafe impl VecZnxBigSubImpl<Self> for FFT64Spqlios {
}
}
unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64Spqlios {
unsafe impl VecZnxBigSubInplaceImpl<Self> for FFT64Spqlios {
/// Subtracts `a` from `res` and stores the result on `res`.
fn vec_znx_big_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
@@ -297,9 +298,9 @@ unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64Spqlios {
}
}
unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64Spqlios {
unsafe impl VecZnxBigSubNegateInplaceImpl<Self> for FFT64Spqlios {
/// Subtracts `res` from `a` and stores the result on `res`.
fn vec_znx_big_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_negate_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
@@ -370,9 +371,9 @@ unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64Spqlios {
}
}
unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64Spqlios {
unsafe impl VecZnxBigSubSmallInplaceImpl<Self> for FFT64Spqlios {
/// Subtracts `a` from `res` and stores the result on `res`.
fn vec_znx_big_sub_small_a_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_small_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
@@ -443,9 +444,9 @@ unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64Spqlios {
}
}
unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64Spqlios {
unsafe impl VecZnxBigSubSmallNegateInplaceImpl<Self> for FFT64Spqlios {
/// Subtracts `res` from `a` and stores the result on `res`.
fn vec_znx_big_sub_small_b_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_small_negate_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
@@ -518,7 +519,7 @@ unsafe impl VecZnxBigNegateInplaceImpl<Self> for FFT64Spqlios {
unsafe impl VecZnxBigNormalizeTmpBytesImpl<Self> for FFT64Spqlios {
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr()) as usize }
vec_znx_normalize_tmp_bytes(module.n())
}
}
@@ -528,9 +529,10 @@ where
{
fn vec_znx_big_normalize_impl<R, A>(
module: &Module<Self>,
basek: usize,
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<Self>,
@@ -538,28 +540,21 @@ where
R: VecZnxToMut,
A: VecZnxBigToRef<Self>,
{
let a: VecZnxBig<&[u8], Self> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), a.n());
}
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes());
unsafe {
vec_znx::vec_znx_normalize_base2k(
module.ptr(),
basek as u64,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
tmp_bytes.as_mut_ptr(),
);
}
let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_big_normalize(res_basek, res, res_col, a_basek, a, a_col, carry);
}
}
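
This is the cross-basek entry point: `vec_znx_big_normalize` now takes `res_basek` and `a_basek` separately, so the input and output columns may use different base-2^k decompositions. The round trip below only pins down the intended relation between the two representations; the real kernel presumably works through the `carry` scratch slice taken above rather than materializing the whole value, and the little-endian layout here is an assumption:

```rust
// Conceptual sketch of cross-base normalization: recompose digits in base 2^a_basek,
// then redecompose into centered digits in base 2^res_basek. The i128 accumulator
// obviously does not scale to real ring dimensions; it is only meant to show the
// input/output relation.
fn recompose(digits: &[i64], base2k: u32) -> i128 {
    digits
        .iter()
        .enumerate()
        .map(|(i, &d)| (d as i128) << (i as u32 * base2k))
        .sum()
}

/// Decompose into `len` centered digits in [-2^(base2k-1), 2^(base2k-1)).
fn decompose(mut x: i128, base2k: u32, len: usize) -> Vec<i64> {
    let base = 1i128 << base2k;
    let half = base >> 1;
    let mut out = vec![0_i64; len];
    for d in out.iter_mut() {
        let mut lo = x & (base - 1);
        x >>= base2k;
        if lo >= half {
            lo -= base;
            x += 1;
        }
        *d = lo as i64;
    }
    out
}

fn main() {
    let (a_basek, res_basek) = (17u32, 12u32);
    let a = vec![70_000_i64, -3, 5]; // digits in base 2^17, not yet normalized
    let x = recompose(&a, a_basek);
    let res = decompose(x, res_basek, 5); // same value, digits in base 2^12
    assert_eq!(recompose(&res, res_basek), x);
    assert!(res.iter().all(|&d| (-2048..2048).contains(&d)));
}
```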

View File

@@ -6,7 +6,7 @@ use poulpy_hal::{
},
oep::{
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl,
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl,
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl,
VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
},
reference::{
@@ -336,8 +336,8 @@ unsafe impl VecZnxDftSubImpl<Self> for FFT64Spqlios {
}
}
unsafe impl VecZnxDftSubABInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_dft_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxDftSubInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_dft_sub_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
@@ -363,8 +363,8 @@ unsafe impl VecZnxDftSubABInplaceImpl<Self> for FFT64Spqlios {
}
}
unsafe impl VecZnxDftSubBAInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_dft_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxDftSubNegateInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_dft_sub_negate_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,

View File

@@ -12,7 +12,7 @@ unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64Spqlios
where
Self: TakeSliceImpl<Self>,
{
fn zn_normalize_inplace_impl<A>(n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<Self>)
fn zn_normalize_inplace_impl<A>(n: usize, base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<Self>)
where
A: ZnToMut,
{
@@ -23,7 +23,7 @@ where
unsafe {
zn64::zn64_normalize_base2k_ref(
n as u64,
basek as u64,
base2k as u64,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
@@ -37,11 +37,11 @@ where
}
unsafe impl ZnFillUniformImpl<Self> for FFT64Spqlios {
fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
fn zn_fill_uniform_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
zn_fill_uniform(n, basek, res, res_col, source);
zn_fill_uniform(n, base2k, res, res_col, source);
}
}
@@ -49,7 +49,7 @@ unsafe impl ZnFillNormalImpl<Self> for FFT64Spqlios {
#[allow(clippy::too_many_arguments)]
fn zn_fill_normal_impl<R>(
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -59,7 +59,7 @@ unsafe impl ZnFillNormalImpl<Self> for FFT64Spqlios {
) where
R: ZnToMut,
{
zn_fill_normal(n, basek, res, res_col, k, source, sigma, bound);
zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}
@@ -67,7 +67,7 @@ unsafe impl ZnAddNormalImpl<Self> for FFT64Spqlios {
#[allow(clippy::too_many_arguments)]
fn zn_add_normal_impl<R>(
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -77,6 +77,6 @@ unsafe impl ZnAddNormalImpl<Self> for FFT64Spqlios {
) where
R: ZnToMut,
{
zn_add_normal(n, basek, res, res_col, k, source, sigma, bound);
zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}
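
The `zn_normalize_inplace` path (like the fill/add-normal helpers above) reduces each digit to its balanced representative modulo 2^base2k. A standalone formulation of that per-digit step, as a sketch independent of the `zn64_normalize_base2k_ref` FFI routine actually called here and ignoring the cross-limb carry:

```rust
/// Centered representative of v modulo 2^base2k, in [-2^(base2k-1), 2^(base2k-1)).
/// Assumes |v| is far from i64::MAX so the intermediate addition cannot overflow.
fn centered_mod(base2k: u32, v: i64) -> i64 {
    let half = 1i64 << (base2k - 1);
    (v + half).rem_euclid(1i64 << base2k) - half
}

fn main() {
    assert_eq!(centered_mod(12, 4095), -1);
    assert_eq!(centered_mod(12, 2048), -2048);
    assert_eq!(centered_mod(12, -5), -5);
}
```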

View File

@@ -0,0 +1,189 @@
use poulpy_hal::reference::znx::{
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulAddPowerOfTwo, ZnxMulPowerOfTwo,
ZnxMulPowerOfTwoInplace, ZnxNegate, ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace,
ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep,
ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxRotate, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace,
ZnxSwitchRing, ZnxZero, znx_add_inplace_ref, znx_add_ref, znx_automorphism_ref, znx_copy_ref, znx_extract_digit_addmul_ref,
znx_mul_add_power_of_two_ref, znx_mul_power_of_two_inplace_ref, znx_mul_power_of_two_ref, znx_negate_inplace_ref,
znx_negate_ref, znx_normalize_digit_ref, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref,
znx_normalize_first_step_carry_only_ref, znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref,
znx_normalize_middle_step_carry_only_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref, znx_rotate,
znx_sub_inplace_ref, znx_sub_negate_inplace_ref, znx_sub_ref, znx_switch_ring_ref, znx_zero_ref,
};
use crate::FFT64Spqlios;
impl ZnxAdd for FFT64Spqlios {
#[inline(always)]
fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]) {
znx_add_ref(res, a, b);
}
}
impl ZnxAddInplace for FFT64Spqlios {
#[inline(always)]
fn znx_add_inplace(res: &mut [i64], a: &[i64]) {
znx_add_inplace_ref(res, a);
}
}
impl ZnxSub for FFT64Spqlios {
#[inline(always)]
fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]) {
znx_sub_ref(res, a, b);
}
}
impl ZnxSubInplace for FFT64Spqlios {
#[inline(always)]
fn znx_sub_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_inplace_ref(res, a);
}
}
impl ZnxSubNegateInplace for FFT64Spqlios {
#[inline(always)]
fn znx_sub_negate_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_negate_inplace_ref(res, a);
}
}
impl ZnxMulAddPowerOfTwo for FFT64Spqlios {
#[inline(always)]
fn znx_muladd_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
znx_mul_add_power_of_two_ref(k, res, a);
}
}
impl ZnxMulPowerOfTwo for FFT64Spqlios {
#[inline(always)]
fn znx_mul_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
znx_mul_power_of_two_ref(k, res, a);
}
}
impl ZnxMulPowerOfTwoInplace for FFT64Spqlios {
#[inline(always)]
fn znx_mul_power_of_two_inplace(k: i64, res: &mut [i64]) {
znx_mul_power_of_two_inplace_ref(k, res);
}
}
impl ZnxAutomorphism for FFT64Spqlios {
#[inline(always)]
fn znx_automorphism(p: i64, res: &mut [i64], a: &[i64]) {
znx_automorphism_ref(p, res, a);
}
}
impl ZnxCopy for FFT64Spqlios {
#[inline(always)]
fn znx_copy(res: &mut [i64], a: &[i64]) {
znx_copy_ref(res, a);
}
}
impl ZnxNegate for FFT64Spqlios {
#[inline(always)]
fn znx_negate(res: &mut [i64], src: &[i64]) {
znx_negate_ref(res, src);
}
}
impl ZnxNegateInplace for FFT64Spqlios {
#[inline(always)]
fn znx_negate_inplace(res: &mut [i64]) {
znx_negate_inplace_ref(res);
}
}
impl ZnxRotate for FFT64Spqlios {
#[inline(always)]
fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]) {
znx_rotate::<Self>(p, res, src);
}
}
impl ZnxZero for FFT64Spqlios {
#[inline(always)]
fn znx_zero(res: &mut [i64]) {
znx_zero_ref(res);
}
}
impl ZnxSwitchRing for FFT64Spqlios {
#[inline(always)]
fn znx_switch_ring(res: &mut [i64], a: &[i64]) {
znx_switch_ring_ref(res, a);
}
}
impl ZnxNormalizeFinalStep for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_final_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_final_step_ref(base2k, lsh, x, a, carry);
}
}
impl ZnxNormalizeFinalStepInplace for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_final_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_final_step_inplace_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeFirstStep for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_first_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_ref(base2k, lsh, x, a, carry);
}
}
impl ZnxNormalizeFirstStepCarryOnly for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_first_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_carry_only_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeFirstStepInplace for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_first_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_first_step_inplace_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeMiddleStep for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_middle_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_ref(base2k, lsh, x, a, carry);
}
}
impl ZnxNormalizeMiddleStepCarryOnly for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_middle_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_carry_only_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeMiddleStepInplace for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_middle_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_middle_step_inplace_ref(base2k, lsh, x, carry);
}
}
impl ZnxExtractDigitAddMul for FFT64Spqlios {
#[inline(always)]
fn znx_extract_digit_addmul(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) {
znx_extract_digit_addmul_ref(base2k, lsh, res, src);
}
}
impl ZnxNormalizeDigit for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_digit(base2k: usize, res: &mut [i64], src: &mut [i64]) {
znx_normalize_digit_ref(base2k, res, src);
}
}
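
This new backend file wires `FFT64Spqlios` to the reference `znx_*` kernels, including the split `ZnxNormalizeFirstStep` / `ZnxNormalizeMiddleStep` / `ZnxNormalizeFinalStep` traits. The sketch below shows why the pass is split that way: limbs are normalized one at a time while a shared carry buffer threads between steps. The helpers are illustrative stand-ins (not the `*_ref` signatures, which also take `lsh`), and the least-significant-limb-first order is an assumption:

```rust
// Split v into a centered digit and an outgoing carry.
fn center(base2k: u32, v: i64) -> (i64, i64) {
    let base = 1i64 << base2k;
    let mut digit = v & (base - 1);
    let mut carry = v >> base2k;
    if digit >= base >> 1 {
        digit -= base;
        carry += 1;
    }
    (digit, carry)
}

fn first_step(base2k: u32, x: &mut [i64], carry: &mut [i64]) {
    for (xi, c) in x.iter_mut().zip(carry.iter_mut()) {
        let (d, cy) = center(base2k, *xi); // no incoming carry yet
        *xi = d;
        *c = cy;
    }
}

fn middle_step(base2k: u32, x: &mut [i64], carry: &mut [i64]) {
    for (xi, c) in x.iter_mut().zip(carry.iter_mut()) {
        let (d, cy) = center(base2k, *xi + *c); // absorb and re-emit the carry
        *xi = d;
        *c = cy;
    }
}

fn final_step(base2k: u32, x: &mut [i64], carry: &[i64]) {
    for (xi, &c) in x.iter_mut().zip(carry.iter()) {
        let (d, _top) = center(base2k, *xi + c); // carry past the top limb is dropped here
        *xi = d;
    }
}

fn main() {
    // 4 coefficients, 3 limbs each, least-significant limb first (assumed order).
    let base2k: u32 = 12;
    let mut limbs = vec![vec![5000_i64, -4100, 7, 0], vec![0_i64; 4], vec![1_i64, 2, 3, 4]];
    let mut carry = vec![0_i64; 4];
    first_step(base2k, &mut limbs[0], &mut carry);
    middle_step(base2k, &mut limbs[1], &mut carry);
    final_step(base2k, &mut limbs[2], &carry);
    assert!(limbs.iter().flatten().all(|&d| (-2048..2048).contains(&d)));
}
```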

View File

@@ -5,15 +5,15 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_spqlios::FFT64Spqlios,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_vec_znx_add => poulpy_hal::test_suite::vec_znx::test_vec_znx_add,
test_vec_znx_add_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_inplace,
test_vec_znx_add_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar,
test_vec_znx_add_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar_inplace,
test_vec_znx_sub => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub,
test_vec_znx_sub_ab_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ab_inplace,
test_vec_znx_sub_ba_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ba_inplace,
test_vec_znx_sub_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_inplace,
test_vec_znx_sub_negate_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_negate_inplace,
test_vec_znx_sub_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar,
test_vec_znx_sub_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar_inplace,
test_vec_znx_rsh => poulpy_hal::test_suite::vec_znx::test_vec_znx_rsh,
@@ -41,7 +41,7 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_spqlios::FFT64Spqlios,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_svp_apply_dft_to_dft => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft,
test_svp_apply_dft_to_dft_inplace => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft_inplace,
@@ -53,20 +53,20 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_spqlios::FFT64Spqlios,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_vec_znx_big_add => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add,
test_vec_znx_big_add_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_inplace,
test_vec_znx_big_add_small => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small,
test_vec_znx_big_add_small_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small_inplace,
test_vec_znx_big_sub => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub,
test_vec_znx_big_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ab_inplace,
test_vec_znx_big_sub_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_inplace,
test_vec_znx_big_automorphism => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism,
test_vec_znx_big_automorphism_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism_inplace,
test_vec_znx_big_negate => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate,
test_vec_znx_big_negate_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate_inplace,
test_vec_znx_big_normalize => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_normalize,
test_vec_znx_big_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ba_inplace,
test_vec_znx_big_sub_negate_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_negate_inplace,
test_vec_znx_big_sub_small_a => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a,
test_vec_znx_big_sub_small_a_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a_inplace,
test_vec_znx_big_sub_small_b => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_b,
@@ -79,13 +79,13 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_spqlios::FFT64Spqlios,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_vec_znx_dft_add => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add,
test_vec_znx_dft_add_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add_inplace,
test_vec_znx_dft_sub => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub,
test_vec_znx_dft_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ab_inplace,
test_vec_znx_dft_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ba_inplace,
test_vec_znx_dft_sub_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_inplace,
test_vec_znx_dft_sub_negate_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_negate_inplace,
test_vec_znx_idft_apply => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply,
test_vec_znx_idft_apply_consume => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_consume,
test_vec_znx_idft_apply_tmpa => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_tmpa,
@@ -97,7 +97,7 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_spqlios::FFT64Spqlios,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_vmp_apply_dft_to_dft => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft,
test_vmp_apply_dft_to_dft_add => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft_add,