Mirror of https://github.com/arnaucube/poulpy.git (synced 2026-02-10 13:16:44 +01:00)
Add cross-basek normalization (#90)
* added cross_basek_normalization
* updated method signatures to take layouts
* fixed cross-base normalization

fix #91, fix #93
Committed by GitHub
Parent: 4da790ea6a
Commit: 37e13b965c
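The change makes normalization work across different `base2k` limb sizes: the `vec_znx_normalize` and `vec_znx_big_normalize` entry points now take separate bases for the result and the input (`res_basek` / `a_basek`), the left/right shift paths get dedicated `VecZnxLshTmpBytes` / `VecZnxRshTmpBytes` scratch sizing, and new AVX kernels (`znx_mul_power_of_two_avx`, `znx_mul_add_power_of_two_avx`, `znx_extract_digit_addmul_avx`, `znx_normalize_digit_avx`) provide the digit/carry arithmetic. A minimal scalar sketch of the digit/carry rule these kernels are built on (illustrative only, not the library API):

```rust
/// Decompose a signed value into centered base-2^k digits, least significant
/// first. Cross-base2k normalization applies this same digit/carry rule while
/// letting the source and destination vectors use different `base2k` values.
fn decompose_base2k(base2k: usize, mut value: i64) -> Vec<i64> {
    assert!((1..=62).contains(&base2k));
    let full = 1i64 << base2k;
    let mut digits = Vec::new();
    while value != 0 {
        let mut d = value & (full - 1); // low base2k bits
        if d >= full >> 1 {
            d -= full; // re-center into [-2^(k-1), 2^(k-1))
        }
        digits.push(d);
        value = (value - d) >> base2k; // exact carry to the next limb
    }
    digits
}
```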
@@ -7,7 +7,7 @@ use poulpy_hal::{
fft64::{
reim::{
ReimAdd, ReimAddInplace, ReimAddMul, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimMul,
ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubABInplace, ReimSubBAInplace, ReimToZnx,
ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubInplace, ReimSubNegateInplace, ReimToZnx,
ReimToZnxInplace, ReimZero, reim_copy_ref, reim_zero_ref,
},
reim4::{
@@ -15,10 +15,11 @@ use poulpy_hal::{
},
},
znx::{
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulAddPowerOfTwo, ZnxMulPowerOfTwo,
ZnxMulPowerOfTwoInplace, ZnxNegate, ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep,
ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace,
ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxRotate, ZnxSub,
ZnxSubABInplace, ZnxSubBAInplace, ZnxSwitchRing, ZnxZero, znx_copy_ref, znx_rotate, znx_zero_ref,
ZnxSubInplace, ZnxSubNegateInplace, ZnxSwitchRing, ZnxZero, znx_copy_ref, znx_rotate, znx_zero_ref,
},
},
};
@@ -27,8 +28,8 @@ use crate::cpu_fft64_avx::{
FFT64Avx,
reim::{
ReimFFTAvx, ReimIFFTAvx, reim_add_avx2_fma, reim_add_inplace_avx2_fma, reim_addmul_avx2_fma, reim_from_znx_i64_bnd50_fma,
reim_mul_avx2_fma, reim_mul_inplace_avx2_fma, reim_negate_avx2_fma, reim_negate_inplace_avx2_fma,
reim_sub_ab_inplace_avx2_fma, reim_sub_avx2_fma, reim_sub_ba_inplace_avx2_fma, reim_to_znx_i64_inplace_bnd63_avx2_fma,
reim_mul_avx2_fma, reim_mul_inplace_avx2_fma, reim_negate_avx2_fma, reim_negate_inplace_avx2_fma, reim_sub_avx2_fma,
reim_sub_inplace_avx2_fma, reim_sub_negate_inplace_avx2_fma, reim_to_znx_i64_inplace_bnd63_avx2_fma,
},
reim_to_znx_i64_bnd63_avx2_fma,
reim4::{
@@ -36,11 +37,12 @@ use crate::cpu_fft64_avx::{
reim4_vec_mat1col_product_avx, reim4_vec_mat2cols_2ndcol_product_avx, reim4_vec_mat2cols_product_avx,
},
znx_avx::{
znx_add_avx, znx_add_inplace_avx, znx_automorphism_avx, znx_negate_avx, znx_negate_inplace_avx,
znx_normalize_final_step_avx, znx_normalize_final_step_inplace_avx, znx_normalize_first_step_avx,
znx_normalize_first_step_carry_only_avx, znx_normalize_first_step_inplace_avx, znx_normalize_middle_step_avx,
znx_normalize_middle_step_carry_only_avx, znx_normalize_middle_step_inplace_avx, znx_sub_ab_inplace_avx, znx_sub_avx,
znx_sub_ba_inplace_avx, znx_switch_ring_avx,
znx_add_avx, znx_add_inplace_avx, znx_automorphism_avx, znx_extract_digit_addmul_avx, znx_mul_add_power_of_two_avx,
znx_mul_power_of_two_avx, znx_mul_power_of_two_inplace_avx, znx_negate_avx, znx_negate_inplace_avx,
znx_normalize_digit_avx, znx_normalize_final_step_avx, znx_normalize_final_step_inplace_avx,
znx_normalize_first_step_avx, znx_normalize_first_step_carry_only_avx, znx_normalize_first_step_inplace_avx,
znx_normalize_middle_step_avx, znx_normalize_middle_step_carry_only_avx, znx_normalize_middle_step_inplace_avx,
znx_sub_avx, znx_sub_inplace_avx, znx_sub_negate_inplace_avx, znx_switch_ring_avx,
},
};

@@ -131,20 +133,20 @@ impl ZnxSub for FFT64Avx {
}
}

impl ZnxSubABInplace for FFT64Avx {
impl ZnxSubInplace for FFT64Avx {
#[inline(always)]
fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]) {
fn znx_sub_inplace(res: &mut [i64], a: &[i64]) {
unsafe {
znx_sub_ab_inplace_avx(res, a);
znx_sub_inplace_avx(res, a);
}
}
}

impl ZnxSubBAInplace for FFT64Avx {
impl ZnxSubNegateInplace for FFT64Avx {
#[inline(always)]
fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]) {
fn znx_sub_negate_inplace(res: &mut [i64], a: &[i64]) {
unsafe {
znx_sub_ba_inplace_avx(res, a);
znx_sub_negate_inplace_avx(res, a);
}
}
}
@@ -183,6 +185,33 @@ impl ZnxNegateInplace for FFT64Avx {
}
}

impl ZnxMulAddPowerOfTwo for FFT64Avx {
#[inline(always)]
fn znx_muladd_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
unsafe {
znx_mul_add_power_of_two_avx(k, res, a);
}
}
}

impl ZnxMulPowerOfTwo for FFT64Avx {
#[inline(always)]
fn znx_mul_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
unsafe {
znx_mul_power_of_two_avx(k, res, a);
}
}
}

impl ZnxMulPowerOfTwoInplace for FFT64Avx {
#[inline(always)]
fn znx_mul_power_of_two_inplace(k: i64, res: &mut [i64]) {
unsafe {
znx_mul_power_of_two_inplace_avx(k, res);
}
}
}

impl ZnxRotate for FFT64Avx {
#[inline(always)]
fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]) {
@@ -208,72 +237,90 @@ impl ZnxSwitchRing for FFT64Avx {

impl ZnxNormalizeFinalStep for FFT64Avx {
#[inline(always)]
fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
fn znx_normalize_final_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
unsafe {
znx_normalize_final_step_avx(basek, lsh, x, a, carry);
znx_normalize_final_step_avx(base2k, lsh, x, a, carry);
}
}
}

impl ZnxNormalizeFinalStepInplace for FFT64Avx {
#[inline(always)]
fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
fn znx_normalize_final_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
unsafe {
znx_normalize_final_step_inplace_avx(basek, lsh, x, carry);
znx_normalize_final_step_inplace_avx(base2k, lsh, x, carry);
}
}
}

impl ZnxNormalizeFirstStep for FFT64Avx {
#[inline(always)]
fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
fn znx_normalize_first_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
unsafe {
znx_normalize_first_step_avx(basek, lsh, x, a, carry);
znx_normalize_first_step_avx(base2k, lsh, x, a, carry);
}
}
}

impl ZnxNormalizeFirstStepCarryOnly for FFT64Avx {
#[inline(always)]
fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
fn znx_normalize_first_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
unsafe {
znx_normalize_first_step_carry_only_avx(basek, lsh, x, carry);
znx_normalize_first_step_carry_only_avx(base2k, lsh, x, carry);
}
}
}

impl ZnxNormalizeFirstStepInplace for FFT64Avx {
#[inline(always)]
fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
fn znx_normalize_first_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
unsafe {
znx_normalize_first_step_inplace_avx(basek, lsh, x, carry);
znx_normalize_first_step_inplace_avx(base2k, lsh, x, carry);
}
}
}

impl ZnxNormalizeMiddleStep for FFT64Avx {
#[inline(always)]
fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
fn znx_normalize_middle_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
unsafe {
znx_normalize_middle_step_avx(basek, lsh, x, a, carry);
znx_normalize_middle_step_avx(base2k, lsh, x, a, carry);
}
}
}

impl ZnxNormalizeMiddleStepCarryOnly for FFT64Avx {
#[inline(always)]
fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
fn znx_normalize_middle_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
unsafe {
znx_normalize_middle_step_carry_only_avx(basek, lsh, x, carry);
znx_normalize_middle_step_carry_only_avx(base2k, lsh, x, carry);
}
}
}

impl ZnxNormalizeMiddleStepInplace for FFT64Avx {
#[inline(always)]
fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
fn znx_normalize_middle_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
unsafe {
znx_normalize_middle_step_inplace_avx(basek, lsh, x, carry);
znx_normalize_middle_step_inplace_avx(base2k, lsh, x, carry);
}
}
}

impl ZnxExtractDigitAddMul for FFT64Avx {
#[inline(always)]
fn znx_extract_digit_addmul(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) {
unsafe {
znx_extract_digit_addmul_avx(base2k, lsh, res, src);
}
}
}

impl ZnxNormalizeDigit for FFT64Avx {
#[inline(always)]
fn znx_normalize_digit(base2k: usize, res: &mut [i64], src: &mut [i64]) {
unsafe {
znx_normalize_digit_avx(base2k, res, src);
}
}
}
@@ -346,20 +393,20 @@ impl ReimSub for FFT64Avx {
}
}

impl ReimSubABInplace for FFT64Avx {
impl ReimSubInplace for FFT64Avx {
#[inline(always)]
fn reim_sub_ab_inplace(res: &mut [f64], a: &[f64]) {
fn reim_sub_inplace(res: &mut [f64], a: &[f64]) {
unsafe {
reim_sub_ab_inplace_avx2_fma(res, a);
reim_sub_inplace_avx2_fma(res, a);
}
}
}

impl ReimSubBAInplace for FFT64Avx {
impl ReimSubNegateInplace for FFT64Avx {
#[inline(always)]
fn reim_sub_ba_inplace(res: &mut [f64], a: &[f64]) {
fn reim_sub_negate_inplace(res: &mut [f64], a: &[f64]) {
unsafe {
reim_sub_ba_inplace_avx2_fma(res, a);
reim_sub_negate_inplace_avx2_fma(res, a);
}
}
}

@@ -88,7 +88,7 @@ pub fn reim_sub_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) {
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
#[target_feature(enable = "avx2,fma")]
pub fn reim_sub_ab_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
pub fn reim_sub_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
@@ -115,7 +115,7 @@ pub fn reim_sub_ab_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
#[target_feature(enable = "avx2,fma")]
pub fn reim_sub_ba_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
pub fn reim_sub_negate_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());

@@ -253,9 +253,6 @@ fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]
(take_slice, rem_slice)
}
} else {
panic!(
"Attempted to take {} from scratch with {} aligned bytes left",
take_len, aligned_len,
);
panic!("Attempted to take {take_len} from scratch with {aligned_len} aligned bytes left");
}
}

@@ -5,15 +5,15 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_fft64_avx::FFT64Avx,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_vec_znx_add => poulpy_hal::test_suite::vec_znx::test_vec_znx_add,
test_vec_znx_add_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_inplace,
test_vec_znx_add_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar,
test_vec_znx_add_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar_inplace,
test_vec_znx_sub => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub,
test_vec_znx_sub_ab_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ab_inplace,
test_vec_znx_sub_ba_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ba_inplace,
test_vec_znx_sub_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_inplace,
test_vec_znx_sub_negate_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_negate_inplace,
test_vec_znx_sub_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar,
test_vec_znx_sub_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar_inplace,
test_vec_znx_rsh => poulpy_hal::test_suite::vec_znx::test_vec_znx_rsh,
@@ -41,7 +41,7 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_fft64_avx::FFT64Avx,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_svp_apply_dft_to_dft => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft,
test_svp_apply_dft_to_dft_inplace => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft_inplace,
@@ -53,20 +53,20 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_fft64_avx::FFT64Avx,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_vec_znx_big_add => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add,
test_vec_znx_big_add_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_inplace,
test_vec_znx_big_add_small => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small,
test_vec_znx_big_add_small_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small_inplace,
test_vec_znx_big_sub => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub,
test_vec_znx_big_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ab_inplace,
test_vec_znx_big_sub_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_inplace,
test_vec_znx_big_automorphism => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism,
test_vec_znx_big_automorphism_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism_inplace,
test_vec_znx_big_negate => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate,
test_vec_znx_big_negate_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate_inplace,
test_vec_znx_big_normalize => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_normalize,
test_vec_znx_big_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ba_inplace,
test_vec_znx_big_sub_negate_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_negate_inplace,
test_vec_znx_big_sub_small_a => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a,
test_vec_znx_big_sub_small_a_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a_inplace,
test_vec_znx_big_sub_small_b => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_b,
@@ -79,13 +79,13 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_fft64_avx::FFT64Avx,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_vec_znx_dft_add => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add,
test_vec_znx_dft_add_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add_inplace,
test_vec_znx_dft_sub => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub,
test_vec_znx_dft_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ab_inplace,
test_vec_znx_dft_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ba_inplace,
test_vec_znx_dft_sub_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_inplace,
test_vec_znx_dft_sub_negate_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_negate_inplace,
test_vec_znx_idft_apply => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply,
test_vec_znx_idft_apply_consume => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_consume,
test_vec_znx_idft_apply_tmpa => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_tmpa,
@@ -97,7 +97,7 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_fft64_avx::FFT64Avx,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_vmp_apply_dft_to_dft => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft,
test_vmp_apply_dft_to_dft_add => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft_add,

@@ -1,7 +1,8 @@
use poulpy_hal::{
api::{
TakeSlice, VecZnxAutomorphismInplaceTmpBytes, VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOneInplaceTmpBytes,
VecZnxNormalizeTmpBytes, VecZnxRotateInplaceTmpBytes, VecZnxSplitRingTmpBytes,
TakeSlice, VecZnxAutomorphismInplaceTmpBytes, VecZnxLshTmpBytes, VecZnxMergeRingsTmpBytes,
VecZnxMulXpMinusOneInplaceTmpBytes, VecZnxNormalizeTmpBytes, VecZnxRotateInplaceTmpBytes, VecZnxRshTmpBytes,
VecZnxSplitRingTmpBytes,
},
layouts::{Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
oep::{
@@ -12,7 +13,7 @@ use poulpy_hal::{
VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl,
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl,
},
reference::vec_znx::{
@@ -23,7 +24,7 @@ use poulpy_hal::{
vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_negate, vec_znx_negate_inplace, vec_znx_normalize,
vec_znx_normalize_inplace, vec_znx_normalize_tmp_bytes, vec_znx_rotate, vec_znx_rotate_inplace,
vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring,
vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace, vec_znx_sub_scalar,
vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace, vec_znx_sub_scalar,
vec_znx_sub_scalar_inplace, vec_znx_switch_ring,
},
source::Source,
@@ -43,9 +44,10 @@ where
{
fn vec_znx_normalize_impl<R, A>(
module: &Module<Self>,
basek: usize,
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<Self>,
@@ -54,7 +56,7 @@ where
A: VecZnxToRef,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_normalize::<R, A, Self>(basek, res, res_col, a, a_col, carry);
vec_znx_normalize::<R, A, Self>(res_basek, res, res_col, a_basek, a, a_col, carry);
}
}

@@ -64,7 +66,7 @@ where
{
fn vec_znx_normalize_inplace_impl<R>(
module: &Module<Self>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
scratch: &mut Scratch<Self>,
@@ -72,7 +74,7 @@ where
R: VecZnxToMut,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_normalize_inplace::<R, Self>(basek, res, res_col, carry);
vec_znx_normalize_inplace::<R, Self>(base2k, res, res_col, carry);
}
}

@@ -143,23 +145,23 @@ unsafe impl VecZnxSubImpl<Self> for FFT64Avx {
}
}

unsafe impl VecZnxSubABInplaceImpl<Self> for FFT64Avx {
fn vec_znx_sub_ab_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxSubInplaceImpl<Self> for FFT64Avx {
fn vec_znx_sub_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_sub_ab_inplace::<R, A, Self>(res, res_col, a, a_col);
vec_znx_sub_inplace::<R, A, Self>(res, res_col, a, a_col);
}
}

unsafe impl VecZnxSubBAInplaceImpl<Self> for FFT64Avx {
fn vec_znx_sub_ba_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxSubNegateInplaceImpl<Self> for FFT64Avx {
fn vec_znx_sub_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_sub_ba_inplace::<R, A, Self>(res, res_col, a, a_col);
vec_znx_sub_negate_inplace::<R, A, Self>(res, res_col, a, a_col);
}
}

@@ -234,9 +236,9 @@ where
Module<Self>: VecZnxNormalizeTmpBytes,
Scratch<Self>: TakeSlice,
{
fn vec_znx_lsh_inplace_impl<R, A>(
fn vec_znx_lsh_impl<R, A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -247,8 +249,8 @@ where
R: VecZnxToMut,
A: VecZnxToRef,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_lsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry);
let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::<i64>());
vec_znx_lsh::<_, _, Self>(base2k, k, res, res_col, a, a_col, carry);
}
}

@@ -259,7 +261,7 @@ where
{
fn vec_znx_lsh_inplace_impl<A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
a: &mut A,
a_col: usize,
@@ -267,8 +269,8 @@ where
) where
A: VecZnxToMut,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_lsh_inplace::<_, Self>(basek, k, a, a_col, carry);
let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::<i64>());
vec_znx_lsh_inplace::<_, Self>(base2k, k, a, a_col, carry);
}
}

@@ -277,9 +279,9 @@ where
Module<Self>: VecZnxNormalizeTmpBytes,
Scratch<Self>: TakeSlice,
{
fn vec_znx_rsh_inplace_impl<R, A>(
fn vec_znx_rsh_impl<R, A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -290,8 +292,8 @@ where
R: VecZnxToMut,
A: VecZnxToRef,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_rsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry);
let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::<i64>());
vec_znx_rsh::<_, _, Self>(base2k, k, res, res_col, a, a_col, carry);
}
}

@@ -302,7 +304,7 @@ where
{
fn vec_znx_rsh_inplace_impl<A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
a: &mut A,
a_col: usize,
@@ -310,8 +312,8 @@ where
) where
A: VecZnxToMut,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_rsh_inplace::<_, Self>(basek, k, a, a_col, carry);
let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::<i64>());
vec_znx_rsh_inplace::<_, Self>(base2k, k, a, a_col, carry);
}
}

@@ -495,18 +497,18 @@ unsafe impl VecZnxCopyImpl<Self> for FFT64Avx {
}

unsafe impl VecZnxFillUniformImpl<Self> for FFT64Avx {
fn vec_znx_fill_uniform_impl<R>(_module: &Module<Self>, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
fn vec_znx_fill_uniform_impl<R>(_module: &Module<Self>, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut,
{
vec_znx_fill_uniform_ref(basek, res, res_col, source)
vec_znx_fill_uniform_ref(base2k, res, res_col, source)
}
}

unsafe impl VecZnxFillNormalImpl<Self> for FFT64Avx {
fn vec_znx_fill_normal_impl<R>(
_module: &Module<Self>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -516,14 +518,14 @@ unsafe impl VecZnxFillNormalImpl<Self> for FFT64Avx {
) where
R: VecZnxToMut,
{
vec_znx_fill_normal_ref(basek, res, res_col, k, sigma, bound, source);
vec_znx_fill_normal_ref(base2k, res, res_col, k, sigma, bound, source);
}
}

unsafe impl VecZnxAddNormalImpl<Self> for FFT64Avx {
fn vec_znx_add_normal_impl<R>(
_module: &Module<Self>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -533,6 +535,6 @@ unsafe impl VecZnxAddNormalImpl<Self> for FFT64Avx {
) where
R: VecZnxToMut,
{
vec_znx_add_normal_ref(basek, res, res_col, k, sigma, bound, source);
vec_znx_add_normal_ref(base2k, res, res_col, k, sigma, bound, source);
}
}

@@ -10,15 +10,15 @@ use poulpy_hal::{
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl,
VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl,
VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl,
VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubImpl, VecZnxBigSubInplaceImpl, VecZnxBigSubNegateInplaceImpl,
VecZnxBigSubSmallAImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallInplaceImpl, VecZnxBigSubSmallNegateInplaceImpl,
},
reference::{
fft64::vec_znx_big::{
vec_znx_big_add, vec_znx_big_add_inplace, vec_znx_big_add_normal_ref, vec_znx_big_add_small,
vec_znx_big_add_small_inplace, vec_znx_big_automorphism, vec_znx_big_automorphism_inplace,
vec_znx_big_automorphism_inplace_tmp_bytes, vec_znx_big_negate, vec_znx_big_negate_inplace, vec_znx_big_normalize,
vec_znx_big_normalize_tmp_bytes, vec_znx_big_sub, vec_znx_big_sub_ab_inplace, vec_znx_big_sub_ba_inplace,
vec_znx_big_normalize_tmp_bytes, vec_znx_big_sub, vec_znx_big_sub_inplace, vec_znx_big_sub_negate_inplace,
vec_znx_big_sub_small_a, vec_znx_big_sub_small_a_inplace, vec_znx_big_sub_small_b, vec_znx_big_sub_small_b_inplace,
},
znx::{znx_copy_ref, znx_zero_ref},
@@ -76,7 +76,7 @@ unsafe impl VecZnxBigFromSmallImpl<Self> for FFT64Avx {
unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Avx {
fn add_normal_impl<R: VecZnxBigToMut<Self>>(
_module: &Module<Self>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -84,7 +84,7 @@ unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Avx {
sigma: f64,
bound: f64,
) {
vec_znx_big_add_normal_ref(basek, res, res_col, k, sigma, bound, source);
vec_znx_big_add_normal_ref(base2k, res, res_col, k, sigma, bound, source);
}
}

@@ -167,25 +167,25 @@ unsafe impl VecZnxBigSubImpl<Self> for FFT64Avx {
}
}

unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64Avx {
unsafe impl VecZnxBigSubInplaceImpl<Self> for FFT64Avx {
/// Subtracts `a` from `b` and stores the result on `b`.
fn vec_znx_big_sub_ab_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
{
vec_znx_big_sub_ab_inplace(res, res_col, a, a_col);
vec_znx_big_sub_inplace(res, res_col, a, a_col);
}
}

unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64Avx {
unsafe impl VecZnxBigSubNegateInplaceImpl<Self> for FFT64Avx {
/// Subtracts `b` from `a` and stores the result on `b`.
fn vec_znx_big_sub_ba_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
{
vec_znx_big_sub_ba_inplace(res, res_col, a, a_col);
vec_znx_big_sub_negate_inplace(res, res_col, a, a_col);
}
}

@@ -208,9 +208,9 @@ unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64Avx {
}
}

unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64Avx {
unsafe impl VecZnxBigSubSmallInplaceImpl<Self> for FFT64Avx {
/// Subtracts `a` from `res` and stores the result on `res`.
fn vec_znx_big_sub_small_a_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_small_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
@@ -238,9 +238,9 @@ unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64Avx {
}
}

unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64Avx {
unsafe impl VecZnxBigSubSmallNegateInplaceImpl<Self> for FFT64Avx {
/// Subtracts `res` from `a` and stores the result on `res`.
fn vec_znx_big_sub_small_b_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_small_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
@@ -280,9 +280,10 @@ where
{
fn vec_znx_big_normalize_impl<R, A>(
module: &Module<Self>,
basek: usize,
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<Self>,
@@ -291,7 +292,7 @@ where
A: VecZnxBigToRef<Self>,
{
let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_big_normalize(basek, res, res_col, a, a_col, carry);
vec_znx_big_normalize(res_basek, res, res_col, a_basek, a, a_col, carry);
}
}

@@ -326,7 +327,7 @@ where
) where
R: VecZnxBigToMut<Self>,
{
let (tmp, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
let (tmp, _) = scratch.take_slice(module.vec_znx_big_automorphism_inplace_tmp_bytes() / size_of::<i64>());
vec_znx_big_automorphism_inplace(p, res, res_col, tmp);
}
}

@@ -5,12 +5,12 @@ use poulpy_hal::{
},
oep::{
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl,
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl,
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl,
VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
},
reference::fft64::vec_znx_dft::{
vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub,
vec_znx_dft_sub_ab_inplace, vec_znx_dft_sub_ba_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume,
vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub, vec_znx_dft_sub_inplace,
vec_znx_dft_sub_negate_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume,
vec_znx_idft_apply_tmpa,
},
};
@@ -139,23 +139,23 @@ unsafe impl VecZnxDftSubImpl<Self> for FFT64Avx {
}
}

unsafe impl VecZnxDftSubABInplaceImpl<Self> for FFT64Avx {
fn vec_znx_dft_sub_ab_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxDftSubInplaceImpl<Self> for FFT64Avx {
fn vec_znx_dft_sub_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
{
vec_znx_dft_sub_ab_inplace(res, res_col, a, a_col);
vec_znx_dft_sub_inplace(res, res_col, a, a_col);
}
}

unsafe impl VecZnxDftSubBAInplaceImpl<Self> for FFT64Avx {
fn vec_znx_dft_sub_ba_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxDftSubNegateInplaceImpl<Self> for FFT64Avx {
fn vec_znx_dft_sub_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
{
vec_znx_dft_sub_ba_inplace(res, res_col, a, a_col);
vec_znx_dft_sub_negate_inplace(res, res_col, a, a_col);
}
}

@@ -18,21 +18,21 @@ unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64Avx
where
Self: TakeSliceImpl<Self>,
{
fn zn_normalize_inplace_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<Self>)
fn zn_normalize_inplace_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<Self>)
where
R: ZnToMut,
{
let (carry, _) = scratch.take_slice(n);
zn_normalize_inplace::<R, FFT64Avx>(n, basek, res, res_col, carry);
zn_normalize_inplace::<R, FFT64Avx>(n, base2k, res, res_col, carry);
}
}

unsafe impl ZnFillUniformImpl<Self> for FFT64Avx {
fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
fn zn_fill_uniform_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
zn_fill_uniform(n, basek, res, res_col, source);
zn_fill_uniform(n, base2k, res, res_col, source);
}
}

@@ -40,7 +40,7 @@ unsafe impl ZnFillNormalImpl<Self> for FFT64Avx {
#[allow(clippy::too_many_arguments)]
fn zn_fill_normal_impl<R>(
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -50,7 +50,7 @@ unsafe impl ZnFillNormalImpl<Self> for FFT64Avx {
) where
R: ZnToMut,
{
zn_fill_normal(n, basek, res, res_col, k, source, sigma, bound);
zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}

@@ -58,7 +58,7 @@ unsafe impl ZnAddNormalImpl<Self> for FFT64Avx {
#[allow(clippy::too_many_arguments)]
fn zn_add_normal_impl<R>(
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -68,6 +68,6 @@ unsafe impl ZnAddNormalImpl<Self> for FFT64Avx {
) where
R: ZnToMut,
{
zn_add_normal(n, basek, res, res_col, k, source, sigma, bound);
zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}

@@ -1,5 +1,6 @@
mod add;
mod automorphism;
mod mul;
mod neg;
mod normalization;
mod sub;
@@ -7,6 +8,7 @@ mod switch_ring;

pub(crate) use add::*;
pub(crate) use automorphism::*;
pub(crate) use mul::*;
pub(crate) use neg::*;
pub(crate) use normalization::*;
pub(crate) use sub::*;

poulpy-backend/src/cpu_fft64_avx/znx_avx/mul.rs (new file, 318 lines)
@@ -0,0 +1,318 @@
/// Multiply/divide by a power of two with rounding matching [poulpy_hal::reference::znx::znx_mul_power_of_two_ref].
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn znx_mul_power_of_two_avx(k: i64, res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}

use core::arch::x86_64::{
__m128i, __m256i, _mm_cvtsi32_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256,
_mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_sll_epi64, _mm256_srl_epi64, _mm256_srli_epi64,
_mm256_storeu_si256, _mm256_sub_epi64,
};

let n: usize = res.len();

if n == 0 {
return;
}

if k == 0 {
use poulpy_hal::reference::znx::znx_copy_ref;
znx_copy_ref(res, a);
return;
}

let span: usize = n >> 2; // number of 256-bit chunks

unsafe {
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut aa: *const __m256i = a.as_ptr() as *const __m256i;

if k > 0 {
// Left shift by k (variable count).
#[cfg(debug_assertions)]
{
debug_assert!(k <= 63);
}
let cnt128: __m128i = _mm_cvtsi32_si128(k as i32);
for _ in 0..span {
let x: __m256i = _mm256_loadu_si256(aa);
let y: __m256i = _mm256_sll_epi64(x, cnt128);
_mm256_storeu_si256(rr, y);
rr = rr.add(1);
aa = aa.add(1);
}

// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_power_of_two_ref;

znx_mul_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
}
return;
}

// k < 0 => arithmetic right shift with rounding:
// for each x:
// sign_bit = (x >> 63) & 1
// bias = (1<<(kp-1)) - sign_bit
// t = x + bias
// y = t >> kp (arithmetic)
let kp = -k;
#[cfg(debug_assertions)]
{
debug_assert!((1..=63).contains(&kp));
}

let cnt_right: __m128i = _mm_cvtsi32_si128(kp as i32);
let bias_base: __m256i = _mm256_set1_epi64x(1_i64 << (kp - 1));
let top_mask: __m256i = _mm256_set1_epi64x(-1_i64 << (64 - kp)); // high kp bits
let zero: __m256i = _mm256_setzero_si256();

for _ in 0..span {
let x = _mm256_loadu_si256(aa);

// bias = (1 << (kp-1)) - sign_bit
let sign_bit_x: __m256i = _mm256_srli_epi64(x, 63);
let bias: __m256i = _mm256_sub_epi64(bias_base, sign_bit_x);

// t = x + bias
let t: __m256i = _mm256_add_epi64(x, bias);

// logical shift
let lsr: __m256i = _mm256_srl_epi64(t, cnt_right);

// sign extension
let neg: __m256i = _mm256_cmpgt_epi64(zero, t);
let fill: __m256i = _mm256_and_si256(neg, top_mask);
let y: __m256i = _mm256_or_si256(lsr, fill);

_mm256_storeu_si256(rr, y);
rr = rr.add(1);
aa = aa.add(1);
}
}

// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_power_of_two_ref;

znx_mul_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
}
}
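// Scalar sketch of the rounding rule used above for the k < 0 path (divide by
// 2^|k|, round to nearest, ties away from zero); for k >= 0 it is a plain left
// shift. Illustrative only: the crate's reference implementation is
// poulpy_hal::reference::znx::znx_mul_power_of_two_ref.
fn mul_power_of_two_scalar(k: i64, x: i64) -> i64 {
    if k >= 0 {
        x << k
    } else {
        let kp = -k;
        let sign_bit = (x >> 63) & 1; // 1 if x is negative
        let bias = (1i64 << (kp - 1)) - sign_bit;
        (x + bias) >> kp // arithmetic shift
    }
}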

/// Multiply/divide inplace by a power of two with rounding matching [poulpy_hal::reference::znx::znx_mul_power_of_two_inplace_ref].
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn znx_mul_power_of_two_inplace_avx(k: i64, res: &mut [i64]) {
use core::arch::x86_64::{
__m128i, __m256i, _mm_cvtsi32_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256,
_mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_sll_epi64, _mm256_srl_epi64, _mm256_srli_epi64,
_mm256_storeu_si256, _mm256_sub_epi64,
};

let n: usize = res.len();

if n == 0 {
return;
}

if k == 0 {
return;
}

let span: usize = n >> 2; // number of 256-bit chunks

unsafe {
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;

if k > 0 {
// Left shift by k (variable count).
#[cfg(debug_assertions)]
{
debug_assert!(k <= 63);
}
let cnt128: __m128i = _mm_cvtsi32_si128(k as i32);
for _ in 0..span {
let x: __m256i = _mm256_loadu_si256(rr);
let y: __m256i = _mm256_sll_epi64(x, cnt128);
_mm256_storeu_si256(rr, y);
rr = rr.add(1);
}

// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_power_of_two_inplace_ref;
znx_mul_power_of_two_inplace_ref(k, &mut res[span << 2..]);
}
return;
}

// k < 0 => arithmetic right shift with rounding:
// for each x:
// sign_bit = (x >> 63) & 1
// bias = (1<<(kp-1)) - sign_bit
// t = x + bias
// y = t >> kp (arithmetic)
let kp = -k;
#[cfg(debug_assertions)]
{
debug_assert!((1..=63).contains(&kp));
}

let cnt_right: __m128i = _mm_cvtsi32_si128(kp as i32);
let bias_base: __m256i = _mm256_set1_epi64x(1_i64 << (kp - 1));
let top_mask: __m256i = _mm256_set1_epi64x(-1_i64 << (64 - kp)); // high kp bits
let zero: __m256i = _mm256_setzero_si256();

for _ in 0..span {
let x = _mm256_loadu_si256(rr);

// bias = (1 << (kp-1)) - sign_bit
let sign_bit_x: __m256i = _mm256_srli_epi64(x, 63);
let bias: __m256i = _mm256_sub_epi64(bias_base, sign_bit_x);

// t = x + bias
let t: __m256i = _mm256_add_epi64(x, bias);

// logical shift
let lsr: __m256i = _mm256_srl_epi64(t, cnt_right);

// sign extension
let neg: __m256i = _mm256_cmpgt_epi64(zero, t);
let fill: __m256i = _mm256_and_si256(neg, top_mask);
let y: __m256i = _mm256_or_si256(lsr, fill);

_mm256_storeu_si256(rr, y);
rr = rr.add(1);
}
}

// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_power_of_two_inplace_ref;
znx_mul_power_of_two_inplace_ref(k, &mut res[span << 2..]);
}
}

/// Multiply/divide by a power of two and add on the result with rounding matching [poulpy_hal::reference::znx::znx_mul_power_of_two_inplace_ref].
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn znx_mul_add_power_of_two_avx(k: i64, res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}

use core::arch::x86_64::{
__m128i, __m256i, _mm_cvtsi32_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256,
_mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_sll_epi64, _mm256_srl_epi64, _mm256_srli_epi64,
_mm256_storeu_si256, _mm256_sub_epi64,
};

let n: usize = res.len();

if n == 0 {
return;
}

if k == 0 {
use crate::cpu_fft64_avx::znx_avx::znx_add_inplace_avx;

znx_add_inplace_avx(res, a);
return;
}

let span: usize = n >> 2; // number of 256-bit chunks

unsafe {
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut aa: *const __m256i = a.as_ptr() as *const __m256i;

if k > 0 {
// Left shift by k (variable count).
#[cfg(debug_assertions)]
{
debug_assert!(k <= 63);
}
let cnt128: __m128i = _mm_cvtsi32_si128(k as i32);
for _ in 0..span {
let x: __m256i = _mm256_loadu_si256(aa);
let y: __m256i = _mm256_loadu_si256(rr);
_mm256_storeu_si256(rr, _mm256_add_epi64(y, _mm256_sll_epi64(x, cnt128)));
rr = rr.add(1);
aa = aa.add(1);
}

// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_add_power_of_two_ref;

znx_mul_add_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
}
return;
}

// k < 0 => arithmetic right shift with rounding:
// for each x:
// sign_bit = (x >> 63) & 1
// bias = (1<<(kp-1)) - sign_bit
// t = x + bias
// y = t >> kp (arithmetic)
let kp = -k;
#[cfg(debug_assertions)]
{
debug_assert!((1..=63).contains(&kp));
}

let cnt_right: __m128i = _mm_cvtsi32_si128(kp as i32);
let bias_base: __m256i = _mm256_set1_epi64x(1_i64 << (kp - 1));
let top_mask: __m256i = _mm256_set1_epi64x(-1_i64 << (64 - kp)); // high kp bits
let zero: __m256i = _mm256_setzero_si256();

for _ in 0..span {
let x: __m256i = _mm256_loadu_si256(aa);
let y: __m256i = _mm256_loadu_si256(rr);

// bias = (1 << (kp-1)) - sign_bit
let sign_bit_x: __m256i = _mm256_srli_epi64(x, 63);
let bias: __m256i = _mm256_sub_epi64(bias_base, sign_bit_x);

// t = x + bias
let t: __m256i = _mm256_add_epi64(x, bias);

// logical shift
let lsr: __m256i = _mm256_srl_epi64(t, cnt_right);

// sign extension
let neg: __m256i = _mm256_cmpgt_epi64(zero, t);
let fill: __m256i = _mm256_and_si256(neg, top_mask);
let out: __m256i = _mm256_or_si256(lsr, fill);

_mm256_storeu_si256(rr, _mm256_add_epi64(y, out));
rr = rr.add(1);
aa = aa.add(1);
}
}

// tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_mul_add_power_of_two_ref;
znx_mul_add_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
}
}
@@ -6,14 +6,14 @@ use std::arch::x86_64::__m256i;
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
fn normalize_consts_avx(basek: usize) -> (__m256i, __m256i, __m256i, __m256i) {
fn normalize_consts_avx(base2k: usize) -> (__m256i, __m256i, __m256i, __m256i) {
use std::arch::x86_64::_mm256_set1_epi64x;

assert!((1..=63).contains(&basek));
let mask_k: i64 = ((1u64 << basek) - 1) as i64; // 0..k-1 bits set
let sign_k: i64 = (1u64 << (basek - 1)) as i64; // bit k-1
let topmask: i64 = (!0u64 << (64 - basek)) as i64; // top k bits set
let sh_k: __m256i = _mm256_set1_epi64x(basek as i64);
assert!((1..=63).contains(&base2k));
let mask_k: i64 = ((1u64 << base2k) - 1) as i64; // 0..k-1 bits set
let sign_k: i64 = (1u64 << (base2k - 1)) as i64; // bit k-1
let topmask: i64 = (!0u64 << (64 - base2k)) as i64; // top k bits set
let sh_k: __m256i = _mm256_set1_epi64x(base2k as i64);
(
_mm256_set1_epi64x(mask_k), // mask_k_vec
_mm256_set1_epi64x(sign_k), // sign_k_vec
@@ -46,14 +46,14 @@ fn get_digit_avx(x: __m256i, mask_k: __m256i, sign_k: __m256i) -> __m256i {
unsafe fn get_carry_avx(
x: __m256i,
digit: __m256i,
basek: __m256i, // _mm256_set1_epi64x(k)
base2k: __m256i, // _mm256_set1_epi64x(k)
top_mask: __m256i, // (!0 << (64 - k)) broadcast
) -> __m256i {
use std::arch::x86_64::{
__m256i, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_or_si256, _mm256_setzero_si256, _mm256_srlv_epi64, _mm256_sub_epi64,
};
let diff: __m256i = _mm256_sub_epi64(x, digit);
let lsr: __m256i = _mm256_srlv_epi64(diff, basek); // logical >>
let lsr: __m256i = _mm256_srlv_epi64(diff, base2k); // logical >>
let neg: __m256i = _mm256_cmpgt_epi64(_mm256_setzero_si256(), diff); // 0xFFFF.. where v<0
let fill: __m256i = _mm256_and_si256(neg, top_mask); // top k bits if negative
_mm256_or_si256(lsr, fill)
@@ -61,13 +61,121 @@ unsafe fn get_carry_avx(

/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
/// `res` and `src` must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub fn znx_normalize_first_step_carry_only_avx(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
pub fn znx_extract_digit_addmul_avx(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), carry.len());
assert!(lsh < basek);
assert_eq!(res.len(), src.len());
assert!(lsh < base2k);
}

use std::arch::x86_64::{
__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_set1_epi64x, _mm256_sllv_epi64, _mm256_storeu_si256,
};

let n: usize = res.len();
let span: usize = n >> 2;

unsafe {
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut ss: *mut __m256i = src.as_mut_ptr() as *mut __m256i;

// constants for digit/carry extraction
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);

for _ in 0..span {
// load source & extract digit/carry
let sv: __m256i = _mm256_loadu_si256(ss);
let digit_256: __m256i = get_digit_avx(sv, mask, sign);
let carry_256: __m256i = get_carry_avx(sv, digit_256, basek_vec, top_mask);

// res += (digit << lsh)
let rv: __m256i = _mm256_loadu_si256(rr);
let madd: __m256i = _mm256_sllv_epi64(digit_256, lsh_v);
let sum: __m256i = _mm256_add_epi64(rv, madd);

_mm256_storeu_si256(rr, sum);
_mm256_storeu_si256(ss, carry_256);

rr = rr.add(1);
ss = ss.add(1);
}
}

// tail (scalar)
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_extract_digit_addmul_ref;

let off: usize = span << 2;
znx_extract_digit_addmul_ref(base2k, lsh, &mut res[off..], &mut src[off..]);
}
}
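// Scalar sketch of the per-element behaviour of znx_extract_digit_addmul_avx
// above (illustrative only; the crate's reference is
// poulpy_hal::reference::znx::znx_extract_digit_addmul_ref): the low base2k
// bits of `src` are re-centered into [-2^(base2k-1), 2^(base2k-1)), shifted by
// `lsh` and added onto `res`, while `src` keeps the carry for the next limb.
fn extract_digit_addmul_scalar(base2k: usize, lsh: usize, res: &mut i64, src: &mut i64) {
    let full = 1i64 << base2k;
    let mut digit = *src & (full - 1);
    if digit >= full >> 1 {
        digit -= full; // re-center the digit
    }
    *res += digit << lsh;
    *src = (*src - digit) >> base2k; // exact: this is the carry
}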

/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// `res` and `src` must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub fn znx_normalize_digit_avx(base2k: usize, res: &mut [i64], src: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), src.len());
}

use std::arch::x86_64::{__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_storeu_si256};

let n: usize = res.len();
let span: usize = n >> 2;

unsafe {
// Pointers to 256-bit lanes
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
let mut ss: *mut __m256i = src.as_mut_ptr() as *mut __m256i;

// Constants for digit/carry extraction
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);

for _ in 0..span {
// Load res lane
let rv: __m256i = _mm256_loadu_si256(rr);

// Extract digit and carry from res
let digit_256: __m256i = get_digit_avx(rv, mask, sign);
let carry_256: __m256i = get_carry_avx(rv, digit_256, basek_vec, top_mask);

// src += carry
let sv: __m256i = _mm256_loadu_si256(ss);
let sum: __m256i = _mm256_add_epi64(sv, carry_256);

_mm256_storeu_si256(ss, sum);
_mm256_storeu_si256(rr, digit_256);

rr = rr.add(1);
ss = ss.add(1);
}
}

// scalar tail
if !n.is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_normalize_digit_ref;

let off = span << 2;
znx_normalize_digit_ref(base2k, &mut res[off..], &mut src[off..]);
}
}
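// Scalar sketch of the per-element behaviour of znx_normalize_digit_avx above
// (illustrative only; the crate's reference is
// poulpy_hal::reference::znx::znx_normalize_digit_ref): `res` is reduced to a
// centered base-2^base2k digit and the carry is accumulated onto `src`.
fn normalize_digit_scalar(base2k: usize, res: &mut i64, src: &mut i64) {
    let full = 1i64 << base2k;
    let mut digit = *res & (full - 1);
    if digit >= full >> 1 {
        digit -= full; // re-center into [-2^(base2k-1), 2^(base2k-1))
    }
    *src += (*res - digit) >> base2k; // propagate carry to the next limb
    *res = digit;
}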

/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_first_step_carry_only_avx(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < base2k);
}

use std::arch::x86_64::{_mm256_loadu_si256, _mm256_storeu_si256};
@@ -81,19 +189,19 @@ pub fn znx_normalize_first_step_carry_only_avx(basek: usize, lsh: usize, x: &[i6
let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i;

let (mask, sign, basek_vec, top_mask) = if lsh == 0 {
normalize_consts_avx(basek)
normalize_consts_avx(base2k)
} else {
normalize_consts_avx(basek - lsh)
normalize_consts_avx(base2k - lsh)
};

for _ in 0..span {
let xx_256: __m256i = _mm256_loadu_si256(xx);
let xv: __m256i = _mm256_loadu_si256(xx);

// (x << (64 - basek)) >> (64 - basek)
let digit_256: __m256i = get_digit_avx(xx_256, mask, sign);
// (x << (64 - base2k)) >> (64 - base2k)
let digit_256: __m256i = get_digit_avx(xv, mask, sign);

// (x - digit) >> basek
let carry_256: __m256i = get_carry_avx(xx_256, digit_256, basek_vec, top_mask);
// (x - digit) >> base2k
let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask);

_mm256_storeu_si256(cc, carry_256);

@@ -106,7 +214,7 @@ pub fn znx_normalize_first_step_carry_only_avx(basek: usize, lsh: usize, x: &[i6
if !x.len().is_multiple_of(4) {
use poulpy_hal::reference::znx::znx_normalize_first_step_carry_only_ref;

znx_normalize_first_step_carry_only_ref(basek, lsh, &x[span << 2..], &mut carry[span << 2..]);
znx_normalize_first_step_carry_only_ref(base2k, lsh, &x[span << 2..], &mut carry[span << 2..]);
}
}

@@ -114,11 +222,11 @@ pub fn znx_normalize_first_step_carry_only_avx(basek: usize, lsh: usize, x: &[i6
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[target_feature(enable = "avx2")]
pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
pub fn znx_normalize_first_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), carry.len());
assert!(lsh < basek);
assert!(x.len() <= carry.len());
assert!(lsh < base2k);
}

use std::arch::x86_64::{_mm256_loadu_si256, _mm256_set1_epi64x, _mm256_sllv_epi64, _mm256_storeu_si256};
@@ -132,16 +240,16 @@ pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
|
||||
let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i;
|
||||
|
||||
if lsh == 0 {
|
||||
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek);
|
||||
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
|
||||
|
||||
for _ in 0..span {
|
||||
let xx_256: __m256i = _mm256_loadu_si256(xx);
|
||||
let xv: __m256i = _mm256_loadu_si256(xx);
|
||||
|
||||
// (x << (64 - basek)) >> (64 - basek)
|
||||
let digit_256: __m256i = get_digit_avx(xx_256, mask, sign);
|
||||
// (x << (64 - base2k)) >> (64 - base2k)
|
||||
let digit_256: __m256i = get_digit_avx(xv, mask, sign);
|
||||
|
||||
// (x - digit) >> basek
|
||||
let carry_256: __m256i = get_carry_avx(xx_256, digit_256, basek_vec, top_mask);
|
||||
// (x - digit) >> base2k
|
||||
let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask);
|
||||
|
||||
_mm256_storeu_si256(xx, digit_256);
|
||||
_mm256_storeu_si256(cc, carry_256);
|
||||
@@ -150,18 +258,18 @@ pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
|
||||
cc = cc.add(1);
|
||||
}
|
||||
} else {
|
||||
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek - lsh);
|
||||
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k - lsh);
|
||||
|
||||
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
|
||||
|
||||
for _ in 0..span {
|
||||
let xx_256: __m256i = _mm256_loadu_si256(xx);
|
||||
let xv: __m256i = _mm256_loadu_si256(xx);
|
||||
|
||||
// (x << (64 - basek)) >> (64 - basek)
|
||||
let digit_256: __m256i = get_digit_avx(xx_256, mask, sign);
|
||||
// (x << (64 - base2k)) >> (64 - base2k)
|
||||
let digit_256: __m256i = get_digit_avx(xv, mask, sign);
|
||||
|
||||
// (x - digit) >> basek
|
||||
let carry_256: __m256i = get_carry_avx(xx_256, digit_256, basek_vec, top_mask);
|
||||
// (x - digit) >> base2k
|
||||
let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask);
|
||||
|
||||
_mm256_storeu_si256(xx, _mm256_sllv_epi64(digit_256, lsh_v));
|
||||
_mm256_storeu_si256(cc, carry_256);
|
||||
@@ -176,7 +284,7 @@ pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
|
||||
if !x.len().is_multiple_of(4) {
|
||||
use poulpy_hal::reference::znx::znx_normalize_first_step_inplace_ref;
|
||||
|
||||
znx_normalize_first_step_inplace_ref(basek, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
|
||||
znx_normalize_first_step_inplace_ref(base2k, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -184,12 +292,12 @@ pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
|
||||
/// all inputs must have the same length and must not alias.
|
||||
#[target_feature(enable = "avx2")]
|
||||
pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
pub fn znx_normalize_first_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(x.len(), carry.len());
|
||||
assert_eq!(a.len(), carry.len());
|
||||
assert!(lsh < basek);
|
||||
assert_eq!(x.len(), a.len());
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < base2k);
|
||||
}
|
||||
|
||||
use std::arch::x86_64::{_mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
|
||||
@@ -204,16 +312,16 @@ pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
|
||||
let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i;
|
||||
|
||||
if lsh == 0 {
|
||||
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek);
|
||||
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
|
||||
|
||||
for _ in 0..span {
|
||||
let aa_256: __m256i = _mm256_loadu_si256(aa);
|
||||
let av: __m256i = _mm256_loadu_si256(aa);
|
||||
|
||||
// (x << (64 - basek)) >> (64 - basek)
|
||||
let digit_256: __m256i = get_digit_avx(aa_256, mask, sign);
|
||||
// (x << (64 - base2k)) >> (64 - base2k)
|
||||
let digit_256: __m256i = get_digit_avx(av, mask, sign);
|
||||
|
||||
// (x - digit) >> basek
|
||||
let carry_256: __m256i = get_carry_avx(aa_256, digit_256, basek_vec, top_mask);
|
||||
// (x - digit) >> base2k
|
||||
let carry_256: __m256i = get_carry_avx(av, digit_256, basek_vec, top_mask);
|
||||
|
||||
_mm256_storeu_si256(xx, digit_256);
|
||||
_mm256_storeu_si256(cc, carry_256);
|
||||
@@ -225,18 +333,18 @@ pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
|
||||
} else {
|
||||
use std::arch::x86_64::_mm256_set1_epi64x;
|
||||
|
||||
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek - lsh);
|
||||
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k - lsh);
|
||||
|
||||
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
|
||||
|
||||
for _ in 0..span {
|
||||
let aa_256: __m256i = _mm256_loadu_si256(aa);
|
||||
let av: __m256i = _mm256_loadu_si256(aa);
|
||||
|
||||
// (x << (64 - basek)) >> (64 - basek)
|
||||
let digit_256: __m256i = get_digit_avx(aa_256, mask, sign);
|
||||
// (x << (64 - base2k)) >> (64 - base2k)
|
||||
let digit_256: __m256i = get_digit_avx(av, mask, sign);
|
||||
|
||||
// (x - digit) >> basek
|
||||
let carry_256: __m256i = get_carry_avx(aa_256, digit_256, basek_vec, top_mask);
|
||||
// (x - digit) >> base2k
|
||||
let carry_256: __m256i = get_carry_avx(av, digit_256, basek_vec, top_mask);
|
||||
|
||||
_mm256_storeu_si256(xx, _mm256_sllv_epi64(digit_256, lsh_v));
|
||||
_mm256_storeu_si256(cc, carry_256);
|
||||
@@ -253,7 +361,7 @@ pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
|
||||
use poulpy_hal::reference::znx::znx_normalize_first_step_ref;
|
||||
|
||||
znx_normalize_first_step_ref(
|
||||
basek,
|
||||
base2k,
|
||||
lsh,
|
||||
&mut x[span << 2..],
|
||||
&a[span << 2..],
|
||||
@@ -266,11 +374,11 @@ pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
|
||||
/// all inputs must have the same length and must not alias.
|
||||
#[target_feature(enable = "avx2")]
|
||||
pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
pub fn znx_normalize_middle_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(x.len(), carry.len());
|
||||
assert!(lsh < basek);
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < base2k);
|
||||
}
|
||||
|
||||
use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
|
||||
@@ -279,7 +387,7 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [
|
||||
|
||||
let span: usize = n >> 2;
|
||||
|
||||
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek);
|
||||
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
|
||||
|
||||
unsafe {
|
||||
let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i;
|
||||
@@ -287,13 +395,13 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [
|
||||
|
||||
if lsh == 0 {
|
||||
for _ in 0..span {
|
||||
let xx_256: __m256i = _mm256_loadu_si256(xx);
|
||||
let cc_256: __m256i = _mm256_loadu_si256(cc);
|
||||
let xv: __m256i = _mm256_loadu_si256(xx);
|
||||
let cv: __m256i = _mm256_loadu_si256(cc);
|
||||
|
||||
let d0: __m256i = get_digit_avx(xx_256, mask, sign);
|
||||
let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec, top_mask);
|
||||
let d0: __m256i = get_digit_avx(xv, mask, sign);
|
||||
let c0: __m256i = get_carry_avx(xv, d0, basek_vec, top_mask);
|
||||
|
||||
let s: __m256i = _mm256_add_epi64(d0, cc_256);
|
||||
let s: __m256i = _mm256_add_epi64(d0, cv);
|
||||
let x1: __m256i = get_digit_avx(s, mask, sign);
|
||||
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
|
||||
let cout: __m256i = _mm256_add_epi64(c0, c1);
|
||||
@@ -307,20 +415,20 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [
|
||||
} else {
|
||||
use std::arch::x86_64::_mm256_set1_epi64x;
|
||||
|
||||
let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(basek - lsh);
|
||||
let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);
|
||||
|
||||
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
|
||||
|
||||
for _ in 0..span {
|
||||
let xx_256: __m256i = _mm256_loadu_si256(xx);
|
||||
let cc_256: __m256i = _mm256_loadu_si256(cc);
|
||||
let xv: __m256i = _mm256_loadu_si256(xx);
|
||||
let cv: __m256i = _mm256_loadu_si256(cc);
|
||||
|
||||
let d0: __m256i = get_digit_avx(xx_256, mask_lsh, sign_lsh);
|
||||
let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec_lsh, top_mask_lsh);
|
||||
let d0: __m256i = get_digit_avx(xv, mask_lsh, sign_lsh);
|
||||
let c0: __m256i = get_carry_avx(xv, d0, basek_vec_lsh, top_mask_lsh);
|
||||
|
||||
let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v);
|
||||
|
||||
let s: __m256i = _mm256_add_epi64(d0_lsh, cc_256);
|
||||
let s: __m256i = _mm256_add_epi64(d0_lsh, cv);
|
||||
let x1: __m256i = get_digit_avx(s, mask, sign);
|
||||
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
|
||||
let cout: __m256i = _mm256_add_epi64(c0, c1);
|
||||
@@ -337,7 +445,7 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [
|
||||
if !x.len().is_multiple_of(4) {
|
||||
use poulpy_hal::reference::znx::znx_normalize_middle_step_inplace_ref;
|
||||
|
||||
znx_normalize_middle_step_inplace_ref(basek, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
|
||||
znx_normalize_middle_step_inplace_ref(base2k, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -345,11 +453,11 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
|
||||
/// all inputs must have the same length and must not alias.
|
||||
#[target_feature(enable = "avx2")]
|
||||
pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
|
||||
pub fn znx_normalize_middle_step_carry_only_avx(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(x.len(), carry.len());
|
||||
assert!(lsh < basek);
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < base2k);
|
||||
}
|
||||
|
||||
use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
|
||||
@@ -358,7 +466,7 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i
|
||||
|
||||
let span: usize = n >> 2;
|
||||
|
||||
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek);
|
||||
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
|
||||
|
||||
unsafe {
|
||||
let mut xx: *const __m256i = x.as_ptr() as *const __m256i;
|
||||
@@ -366,13 +474,13 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i
|
||||
|
||||
if lsh == 0 {
|
||||
for _ in 0..span {
|
||||
let xx_256: __m256i = _mm256_loadu_si256(xx);
|
||||
let cc_256: __m256i = _mm256_loadu_si256(cc);
|
||||
let xv: __m256i = _mm256_loadu_si256(xx);
|
||||
let cv: __m256i = _mm256_loadu_si256(cc);
|
||||
|
||||
let d0: __m256i = get_digit_avx(xx_256, mask, sign);
|
||||
let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec, top_mask);
|
||||
let d0: __m256i = get_digit_avx(xv, mask, sign);
|
||||
let c0: __m256i = get_carry_avx(xv, d0, basek_vec, top_mask);
|
||||
|
||||
let s: __m256i = _mm256_add_epi64(d0, cc_256);
|
||||
let s: __m256i = _mm256_add_epi64(d0, cv);
|
||||
let x1: __m256i = get_digit_avx(s, mask, sign);
|
||||
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
|
||||
let cout: __m256i = _mm256_add_epi64(c0, c1);
|
||||
@@ -385,20 +493,20 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i
|
||||
} else {
|
||||
use std::arch::x86_64::_mm256_set1_epi64x;
|
||||
|
||||
let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(basek - lsh);
|
||||
let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);
|
||||
|
||||
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
|
||||
|
||||
for _ in 0..span {
|
||||
let xx_256: __m256i = _mm256_loadu_si256(xx);
|
||||
let cc_256: __m256i = _mm256_loadu_si256(cc);
|
||||
let xv: __m256i = _mm256_loadu_si256(xx);
|
||||
let cv: __m256i = _mm256_loadu_si256(cc);
|
||||
|
||||
let d0: __m256i = get_digit_avx(xx_256, mask_lsh, sign_lsh);
|
||||
let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec_lsh, top_mask_lsh);
|
||||
let d0: __m256i = get_digit_avx(xv, mask_lsh, sign_lsh);
|
||||
let c0: __m256i = get_carry_avx(xv, d0, basek_vec_lsh, top_mask_lsh);
|
||||
|
||||
let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v);
|
||||
|
||||
let s: __m256i = _mm256_add_epi64(d0_lsh, cc_256);
|
||||
let s: __m256i = _mm256_add_epi64(d0_lsh, cv);
|
||||
let x1: __m256i = get_digit_avx(s, mask, sign);
|
||||
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
|
||||
let cout: __m256i = _mm256_add_epi64(c0, c1);
|
||||
@@ -414,7 +522,7 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i
|
||||
if !x.len().is_multiple_of(4) {
|
||||
use poulpy_hal::reference::znx::znx_normalize_middle_step_carry_only_ref;
|
||||
|
||||
znx_normalize_middle_step_carry_only_ref(basek, lsh, &x[span << 2..], &mut carry[span << 2..]);
|
||||
znx_normalize_middle_step_carry_only_ref(base2k, lsh, &x[span << 2..], &mut carry[span << 2..]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -422,12 +530,12 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
|
||||
/// all inputs must have the same length and must not alias.
|
||||
#[target_feature(enable = "avx2")]
|
||||
pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
pub fn znx_normalize_middle_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(x.len(), carry.len());
|
||||
assert_eq!(a.len(), carry.len());
|
||||
assert!(lsh < basek);
|
||||
assert_eq!(x.len(), a.len());
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < base2k);
|
||||
}
|
||||
|
||||
use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
|
||||
@@ -436,7 +544,7 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
|
||||
|
||||
let span: usize = n >> 2;
|
||||
|
||||
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek);
|
||||
let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k);
|
||||
|
||||
unsafe {
|
||||
let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i;
|
||||
@@ -445,13 +553,13 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
|
||||
|
||||
if lsh == 0 {
|
||||
for _ in 0..span {
|
||||
let aa_256: __m256i = _mm256_loadu_si256(aa);
|
||||
let cc_256: __m256i = _mm256_loadu_si256(cc);
|
||||
let av: __m256i = _mm256_loadu_si256(aa);
|
||||
let cv: __m256i = _mm256_loadu_si256(cc);
|
||||
|
||||
let d0: __m256i = get_digit_avx(aa_256, mask, sign);
|
||||
let c0: __m256i = get_carry_avx(aa_256, d0, basek_vec, top_mask);
|
||||
let d0: __m256i = get_digit_avx(av, mask, sign);
|
||||
let c0: __m256i = get_carry_avx(av, d0, basek_vec, top_mask);
|
||||
|
||||
let s: __m256i = _mm256_add_epi64(d0, cc_256);
|
||||
let s: __m256i = _mm256_add_epi64(d0, cv);
|
||||
let x1: __m256i = get_digit_avx(s, mask, sign);
|
||||
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
|
||||
let cout: __m256i = _mm256_add_epi64(c0, c1);
|
||||
@@ -466,20 +574,20 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
|
||||
} else {
|
||||
use std::arch::x86_64::_mm256_set1_epi64x;
|
||||
|
||||
let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(basek - lsh);
|
||||
let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh);
|
||||
|
||||
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
|
||||
|
||||
for _ in 0..span {
|
||||
let aa_256: __m256i = _mm256_loadu_si256(aa);
|
||||
let cc_256: __m256i = _mm256_loadu_si256(cc);
|
||||
let av: __m256i = _mm256_loadu_si256(aa);
|
||||
let cv: __m256i = _mm256_loadu_si256(cc);
|
||||
|
||||
let d0: __m256i = get_digit_avx(aa_256, mask_lsh, sign_lsh);
|
||||
let c0: __m256i = get_carry_avx(aa_256, d0, basek_vec_lsh, top_mask_lsh);
|
||||
let d0: __m256i = get_digit_avx(av, mask_lsh, sign_lsh);
|
||||
let c0: __m256i = get_carry_avx(av, d0, basek_vec_lsh, top_mask_lsh);
|
||||
|
||||
let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v);
|
||||
|
||||
let s: __m256i = _mm256_add_epi64(d0_lsh, cc_256);
|
||||
let s: __m256i = _mm256_add_epi64(d0_lsh, cv);
|
||||
let x1: __m256i = get_digit_avx(s, mask, sign);
|
||||
let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask);
|
||||
let cout: __m256i = _mm256_add_epi64(c0, c1);
|
||||
@@ -498,7 +606,7 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
|
||||
use poulpy_hal::reference::znx::znx_normalize_middle_step_ref;
|
||||
|
||||
znx_normalize_middle_step_ref(
|
||||
basek,
|
||||
base2k,
|
||||
lsh,
|
||||
&mut x[span << 2..],
|
||||
&a[span << 2..],
|
||||
@@ -511,11 +619,11 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
|
||||
/// all inputs must have the same length and must not alias.
|
||||
#[target_feature(enable = "avx2")]
|
||||
pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
pub fn znx_normalize_final_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(x.len(), carry.len());
|
||||
assert!(lsh < basek);
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < base2k);
|
||||
}
|
||||
|
||||
use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
|
||||
@@ -524,7 +632,7 @@ pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
|
||||
|
||||
let span: usize = n >> 2;
|
||||
|
||||
let (mask, sign, _, _) = normalize_consts_avx(basek);
|
||||
let (mask, sign, _, _) = normalize_consts_avx(base2k);
|
||||
|
||||
unsafe {
|
||||
let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i;
|
||||
@@ -547,7 +655,7 @@ pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
|
||||
} else {
|
||||
use std::arch::x86_64::_mm256_set1_epi64x;
|
||||
|
||||
let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(basek - lsh);
|
||||
let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(base2k - lsh);
|
||||
|
||||
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
|
||||
|
||||
@@ -573,7 +681,7 @@ pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
|
||||
if !x.len().is_multiple_of(4) {
|
||||
use poulpy_hal::reference::znx::znx_normalize_final_step_inplace_ref;
|
||||
|
||||
znx_normalize_final_step_inplace_ref(basek, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
|
||||
znx_normalize_final_step_inplace_ref(base2k, lsh, &mut x[span << 2..], &mut carry[span << 2..]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -581,12 +689,12 @@ pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
|
||||
/// all inputs must have the same length and must not alias.
|
||||
#[target_feature(enable = "avx2")]
|
||||
pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
pub fn znx_normalize_final_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(x.len(), carry.len());
|
||||
assert_eq!(a.len(), carry.len());
|
||||
assert!(lsh < basek);
|
||||
assert_eq!(x.len(), a.len());
|
||||
assert!(x.len() <= carry.len());
|
||||
assert!(lsh < base2k);
|
||||
}
|
||||
|
||||
use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256};
|
||||
@@ -595,7 +703,7 @@ pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
|
||||
|
||||
let span: usize = n >> 2;
|
||||
|
||||
let (mask, sign, _, _) = normalize_consts_avx(basek);
|
||||
let (mask, sign, _, _) = normalize_consts_avx(base2k);
|
||||
|
||||
unsafe {
|
||||
let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i;
|
||||
@@ -620,7 +728,7 @@ pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
|
||||
} else {
|
||||
use std::arch::x86_64::_mm256_set1_epi64x;
|
||||
|
||||
let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(basek - lsh);
|
||||
let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(base2k - lsh);
|
||||
|
||||
let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64);
|
||||
|
||||
@@ -647,7 +755,7 @@ pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
|
||||
use poulpy_hal::reference::znx::znx_normalize_final_step_ref;
|
||||
|
||||
znx_normalize_final_step_ref(
|
||||
basek,
|
||||
base2k,
|
||||
lsh,
|
||||
&mut x[span << 2..],
|
||||
&a[span << 2..],
|
||||
@@ -658,9 +766,9 @@ pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a:
|
||||
|
||||
mod tests {
|
||||
use poulpy_hal::reference::znx::{
|
||||
get_carry, get_digit, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref,
|
||||
znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref, znx_normalize_middle_step_inplace_ref,
|
||||
znx_normalize_middle_step_ref,
|
||||
get_carry_i64, get_digit_i64, znx_extract_digit_addmul_ref, znx_normalize_digit_ref,
|
||||
znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref, znx_normalize_first_step_inplace_ref,
|
||||
znx_normalize_first_step_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
@@ -670,7 +778,7 @@ mod tests {
|
||||
#[allow(dead_code)]
|
||||
#[target_feature(enable = "avx2")]
|
||||
fn test_get_digit_avx_internal() {
|
||||
let basek: usize = 12;
|
||||
let base2k: usize = 12;
|
||||
let x: [i64; 4] = [
|
||||
7638646372408325293,
|
||||
-61440197422348985,
|
||||
@@ -678,15 +786,15 @@ mod tests {
|
||||
-4835376105455195188,
|
||||
];
|
||||
let y0: Vec<i64> = vec![
|
||||
get_digit(basek, x[0]),
|
||||
get_digit(basek, x[1]),
|
||||
get_digit(basek, x[2]),
|
||||
get_digit(basek, x[3]),
|
||||
get_digit_i64(base2k, x[0]),
|
||||
get_digit_i64(base2k, x[1]),
|
||||
get_digit_i64(base2k, x[2]),
|
||||
get_digit_i64(base2k, x[3]),
|
||||
];
|
||||
let mut y1: Vec<i64> = vec![0i64; 4];
|
||||
unsafe {
|
||||
let x_256: __m256i = _mm256_loadu_si256(x.as_ptr() as *const __m256i);
|
||||
let (mask, sign, _, _) = normalize_consts_avx(basek);
|
||||
let (mask, sign, _, _) = normalize_consts_avx(base2k);
|
||||
let digit: __m256i = get_digit_avx(x_256, mask, sign);
|
||||
_mm256_storeu_si256(y1.as_mut_ptr() as *mut __m256i, digit);
|
||||
}
|
||||
@@ -707,7 +815,7 @@ mod tests {
|
||||
#[allow(dead_code)]
|
||||
#[target_feature(enable = "avx2")]
|
||||
fn test_get_carry_avx_internal() {
|
||||
let basek: usize = 12;
|
||||
let base2k: usize = 12;
|
||||
let x: [i64; 4] = [
|
||||
7638646372408325293,
|
||||
-61440197422348985,
|
||||
@@ -716,16 +824,16 @@ mod tests {
|
||||
];
|
||||
let carry: [i64; 4] = [1174467039, -144794816, -1466676977, 513122840];
|
||||
let y0: Vec<i64> = vec![
|
||||
get_carry(basek, x[0], carry[0]),
|
||||
get_carry(basek, x[1], carry[1]),
|
||||
get_carry(basek, x[2], carry[2]),
|
||||
get_carry(basek, x[3], carry[3]),
|
||||
get_carry_i64(base2k, x[0], carry[0]),
|
||||
get_carry_i64(base2k, x[1], carry[1]),
|
||||
get_carry_i64(base2k, x[2], carry[2]),
|
||||
get_carry_i64(base2k, x[3], carry[3]),
|
||||
];
|
||||
let mut y1: Vec<i64> = vec![0i64; 4];
|
||||
unsafe {
|
||||
let x_256: __m256i = _mm256_loadu_si256(x.as_ptr() as *const __m256i);
|
||||
let d_256: __m256i = _mm256_loadu_si256(carry.as_ptr() as *const __m256i);
|
||||
let (_, _, basek_vec, top_mask) = normalize_consts_avx(basek);
|
||||
let (_, _, basek_vec, top_mask) = normalize_consts_avx(base2k);
|
||||
let digit: __m256i = get_carry_avx(x_256, d_256, basek_vec, top_mask);
|
||||
_mm256_storeu_si256(y1.as_mut_ptr() as *mut __m256i, digit);
|
||||
}
|
||||
@@ -762,16 +870,16 @@ mod tests {
|
||||
];
|
||||
let mut c1: [i64; 4] = c0;
|
||||
|
||||
let basek = 12;
|
||||
let base2k = 12;
|
||||
|
||||
znx_normalize_first_step_inplace_ref(basek, 0, &mut y0, &mut c0);
|
||||
znx_normalize_first_step_inplace_avx(basek, 0, &mut y1, &mut c1);
|
||||
znx_normalize_first_step_inplace_ref(base2k, 0, &mut y0, &mut c0);
|
||||
znx_normalize_first_step_inplace_avx(base2k, 0, &mut y1, &mut c1);
|
||||
|
||||
assert_eq!(y0, y1);
|
||||
assert_eq!(c0, c1);
|
||||
|
||||
znx_normalize_first_step_inplace_ref(basek, basek - 1, &mut y0, &mut c0);
|
||||
znx_normalize_first_step_inplace_avx(basek, basek - 1, &mut y1, &mut c1);
|
||||
znx_normalize_first_step_inplace_ref(base2k, base2k - 1, &mut y0, &mut c0);
|
||||
znx_normalize_first_step_inplace_avx(base2k, base2k - 1, &mut y1, &mut c1);
|
||||
|
||||
assert_eq!(y0, y1);
|
||||
assert_eq!(c0, c1);
|
||||
@@ -807,16 +915,16 @@ mod tests {
|
||||
];
|
||||
let mut c1: [i64; 4] = c0;
|
||||
|
||||
let basek = 12;
|
||||
let base2k = 12;
|
||||
|
||||
znx_normalize_middle_step_inplace_ref(basek, 0, &mut y0, &mut c0);
|
||||
znx_normalize_middle_step_inplace_avx(basek, 0, &mut y1, &mut c1);
|
||||
znx_normalize_middle_step_inplace_ref(base2k, 0, &mut y0, &mut c0);
|
||||
znx_normalize_middle_step_inplace_avx(base2k, 0, &mut y1, &mut c1);
|
||||
|
||||
assert_eq!(y0, y1);
|
||||
assert_eq!(c0, c1);
|
||||
|
||||
znx_normalize_middle_step_inplace_ref(basek, basek - 1, &mut y0, &mut c0);
|
||||
znx_normalize_middle_step_inplace_avx(basek, basek - 1, &mut y1, &mut c1);
|
||||
znx_normalize_middle_step_inplace_ref(base2k, base2k - 1, &mut y0, &mut c0);
|
||||
znx_normalize_middle_step_inplace_avx(base2k, base2k - 1, &mut y1, &mut c1);
|
||||
|
||||
assert_eq!(y0, y1);
|
||||
assert_eq!(c0, c1);
|
||||
@@ -852,16 +960,16 @@ mod tests {
|
||||
];
|
||||
let mut c1: [i64; 4] = c0;
|
||||
|
||||
let basek = 12;
|
||||
let base2k = 12;
|
||||
|
||||
znx_normalize_final_step_inplace_ref(basek, 0, &mut y0, &mut c0);
|
||||
znx_normalize_final_step_inplace_avx(basek, 0, &mut y1, &mut c1);
|
||||
znx_normalize_final_step_inplace_ref(base2k, 0, &mut y0, &mut c0);
|
||||
znx_normalize_final_step_inplace_avx(base2k, 0, &mut y1, &mut c1);
|
||||
|
||||
assert_eq!(y0, y1);
|
||||
assert_eq!(c0, c1);
|
||||
|
||||
znx_normalize_final_step_inplace_ref(basek, basek - 1, &mut y0, &mut c0);
|
||||
znx_normalize_final_step_inplace_avx(basek, basek - 1, &mut y1, &mut c1);
|
||||
znx_normalize_final_step_inplace_ref(base2k, base2k - 1, &mut y0, &mut c0);
|
||||
znx_normalize_final_step_inplace_avx(base2k, base2k - 1, &mut y1, &mut c1);
|
||||
|
||||
assert_eq!(y0, y1);
|
||||
assert_eq!(c0, c1);
|
||||
@@ -898,16 +1006,16 @@ mod tests {
|
||||
];
|
||||
let mut c1: [i64; 4] = c0;
|
||||
|
||||
let basek = 12;
|
||||
let base2k = 12;
|
||||
|
||||
znx_normalize_first_step_ref(basek, 0, &mut y0, &a, &mut c0);
|
||||
znx_normalize_first_step_avx(basek, 0, &mut y1, &a, &mut c1);
|
||||
znx_normalize_first_step_ref(base2k, 0, &mut y0, &a, &mut c0);
|
||||
znx_normalize_first_step_avx(base2k, 0, &mut y1, &a, &mut c1);
|
||||
|
||||
assert_eq!(y0, y1);
|
||||
assert_eq!(c0, c1);
|
||||
|
||||
znx_normalize_first_step_ref(basek, basek - 1, &mut y0, &a, &mut c0);
|
||||
znx_normalize_first_step_avx(basek, basek - 1, &mut y1, &a, &mut c1);
|
||||
znx_normalize_first_step_ref(base2k, base2k - 1, &mut y0, &a, &mut c0);
|
||||
znx_normalize_first_step_avx(base2k, base2k - 1, &mut y1, &a, &mut c1);
|
||||
|
||||
assert_eq!(y0, y1);
|
||||
assert_eq!(c0, c1);
|
||||
@@ -944,16 +1052,16 @@ mod tests {
|
||||
];
|
||||
let mut c1: [i64; 4] = c0;
|
||||
|
||||
let basek = 12;
|
||||
let base2k = 12;
|
||||
|
||||
znx_normalize_middle_step_ref(basek, 0, &mut y0, &a, &mut c0);
|
||||
znx_normalize_middle_step_avx(basek, 0, &mut y1, &a, &mut c1);
|
||||
znx_normalize_middle_step_ref(base2k, 0, &mut y0, &a, &mut c0);
|
||||
znx_normalize_middle_step_avx(base2k, 0, &mut y1, &a, &mut c1);
|
||||
|
||||
assert_eq!(y0, y1);
|
||||
assert_eq!(c0, c1);
|
||||
|
||||
znx_normalize_middle_step_ref(basek, basek - 1, &mut y0, &a, &mut c0);
|
||||
znx_normalize_middle_step_avx(basek, basek - 1, &mut y1, &a, &mut c1);
|
||||
znx_normalize_middle_step_ref(base2k, base2k - 1, &mut y0, &a, &mut c0);
|
||||
znx_normalize_middle_step_avx(base2k, base2k - 1, &mut y1, &a, &mut c1);
|
||||
|
||||
assert_eq!(y0, y1);
|
||||
assert_eq!(c0, c1);
|
||||
@@ -990,16 +1098,16 @@ mod tests {
|
||||
];
|
||||
let mut c1: [i64; 4] = c0;
|
||||
|
||||
let basek = 12;
|
||||
let base2k = 12;
|
||||
|
||||
znx_normalize_final_step_ref(basek, 0, &mut y0, &a, &mut c0);
|
||||
znx_normalize_final_step_avx(basek, 0, &mut y1, &a, &mut c1);
|
||||
znx_normalize_final_step_ref(base2k, 0, &mut y0, &a, &mut c0);
|
||||
znx_normalize_final_step_avx(base2k, 0, &mut y1, &a, &mut c1);
|
||||
|
||||
assert_eq!(y0, y1);
|
||||
assert_eq!(c0, c1);
|
||||
|
||||
znx_normalize_final_step_ref(basek, basek - 1, &mut y0, &a, &mut c0);
|
||||
znx_normalize_final_step_avx(basek, basek - 1, &mut y1, &a, &mut c1);
|
||||
znx_normalize_final_step_ref(base2k, base2k - 1, &mut y0, &a, &mut c0);
|
||||
znx_normalize_final_step_avx(base2k, base2k - 1, &mut y1, &a, &mut c1);
|
||||
|
||||
assert_eq!(y0, y1);
|
||||
assert_eq!(c0, c1);
|
||||
@@ -1015,4 +1123,86 @@ mod tests {
|
||||
test_znx_normalize_final_step_avx_internal();
|
||||
}
|
||||
}
|
||||
|
||||
#[target_feature(enable = "avx2")]
|
||||
fn znx_extract_digit_addmul_internal() {
|
||||
let mut y0: [i64; 4] = [
|
||||
7638646372408325293,
|
||||
-61440197422348985,
|
||||
6835891051541717957,
|
||||
-4835376105455195188,
|
||||
];
|
||||
let mut y1: [i64; 4] = y0;
|
||||
|
||||
let mut c0: [i64; 4] = [
|
||||
621182201135793202,
|
||||
9000856573317006236,
|
||||
5542252755421113668,
|
||||
-6036847263131690631,
|
||||
];
|
||||
let mut c1: [i64; 4] = c0;
|
||||
|
||||
let base2k: usize = 12;
|
||||
|
||||
znx_extract_digit_addmul_ref(base2k, 0, &mut y0, &mut c0);
|
||||
znx_extract_digit_addmul_avx(base2k, 0, &mut y1, &mut c1);
|
||||
|
||||
assert_eq!(y0, y1);
|
||||
assert_eq!(c0, c1);
|
||||
|
||||
znx_extract_digit_addmul_ref(base2k, base2k - 1, &mut y0, &mut c0);
|
||||
znx_extract_digit_addmul_avx(base2k, base2k - 1, &mut y1, &mut c1);
|
||||
|
||||
assert_eq!(y0, y1);
|
||||
assert_eq!(c0, c1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_znx_extract_digit_addmul_avx() {
|
||||
if !std::is_x86_feature_detected!("avx2") {
|
||||
eprintln!("skipping: CPU lacks avx2");
|
||||
return;
|
||||
};
|
||||
unsafe {
|
||||
znx_extract_digit_addmul_internal();
|
||||
}
|
||||
}
|
||||
|
||||
#[target_feature(enable = "avx2")]
|
||||
fn znx_normalize_digit_internal() {
|
||||
let mut y0: [i64; 4] = [
|
||||
7638646372408325293,
|
||||
-61440197422348985,
|
||||
6835891051541717957,
|
||||
-4835376105455195188,
|
||||
];
|
||||
let mut y1: [i64; 4] = y0;
|
||||
|
||||
let mut c0: [i64; 4] = [
|
||||
621182201135793202,
|
||||
9000856573317006236,
|
||||
5542252755421113668,
|
||||
-6036847263131690631,
|
||||
];
|
||||
let mut c1: [i64; 4] = c0;
|
||||
|
||||
let base2k: usize = 12;
|
||||
|
||||
znx_normalize_digit_ref(base2k, &mut y0, &mut c0);
|
||||
znx_normalize_digit_avx(base2k, &mut y1, &mut c1);
|
||||
|
||||
assert_eq!(y0, y1);
|
||||
assert_eq!(c0, c1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_znx_normalize_digit_internal_avx() {
|
||||
if !std::is_x86_feature_detected!("avx2") {
|
||||
eprintln!("skipping: CPU lacks avx2");
|
||||
return;
|
||||
};
|
||||
unsafe {
|
||||
znx_normalize_digit_internal();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ pub fn znx_sub_avx(res: &mut [i64], a: &[i64], b: &[i64]) {
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
|
||||
/// all inputs must have the same length and must not alias.
|
||||
#[target_feature(enable = "avx2")]
|
||||
pub fn znx_sub_ab_inplace_avx(res: &mut [i64], a: &[i64]) {
|
||||
pub fn znx_sub_inplace_avx(res: &mut [i64], a: &[i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.len(), a.len());
|
||||
@@ -67,9 +67,9 @@ pub fn znx_sub_ab_inplace_avx(res: &mut [i64], a: &[i64]) {
|
||||
|
||||
// tail
|
||||
if !res.len().is_multiple_of(4) {
|
||||
use poulpy_hal::reference::znx::znx_sub_ab_inplace_ref;
|
||||
use poulpy_hal::reference::znx::znx_sub_inplace_ref;
|
||||
|
||||
znx_sub_ab_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
|
||||
znx_sub_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,7 +77,7 @@ pub fn znx_sub_ab_inplace_avx(res: &mut [i64], a: &[i64]) {
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
|
||||
/// all inputs must have the same length and must not alias.
|
||||
#[target_feature(enable = "avx2")]
|
||||
pub fn znx_sub_ba_inplace_avx(res: &mut [i64], a: &[i64]) {
|
||||
pub fn znx_sub_negate_inplace_avx(res: &mut [i64], a: &[i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.len(), a.len());
|
||||
@@ -103,8 +103,8 @@ pub fn znx_sub_ba_inplace_avx(res: &mut [i64], a: &[i64]) {
|
||||
|
||||
// tail
|
||||
if !res.len().is_multiple_of(4) {
|
||||
use poulpy_hal::reference::znx::znx_sub_ba_inplace_ref;
|
||||
use poulpy_hal::reference::znx::znx_sub_negate_inplace_ref;
|
||||
|
||||
znx_sub_ba_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
|
||||
znx_sub_negate_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
use poulpy_hal::reference::fft64::{
|
||||
reim::{
|
||||
ReimAdd, ReimAddInplace, ReimAddMul, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimMul,
|
||||
ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubABInplace, ReimSubBAInplace, ReimToZnx, ReimToZnxInplace,
|
||||
ReimZero, fft_ref, ifft_ref, reim_add_inplace_ref, reim_add_ref, reim_addmul_ref, reim_copy_ref, reim_from_znx_i64_ref,
|
||||
reim_mul_inplace_ref, reim_mul_ref, reim_negate_inplace_ref, reim_negate_ref, reim_sub_ab_inplace_ref,
|
||||
reim_sub_ba_inplace_ref, reim_sub_ref, reim_to_znx_i64_inplace_ref, reim_to_znx_i64_ref, reim_zero_ref,
|
||||
ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubInplace, ReimSubNegateInplace, ReimToZnx,
|
||||
ReimToZnxInplace, ReimZero, fft_ref, ifft_ref, reim_add_inplace_ref, reim_add_ref, reim_addmul_ref, reim_copy_ref,
|
||||
reim_from_znx_i64_ref, reim_mul_inplace_ref, reim_mul_ref, reim_negate_inplace_ref, reim_negate_ref,
|
||||
reim_sub_inplace_ref, reim_sub_negate_inplace_ref, reim_sub_ref, reim_to_znx_i64_inplace_ref, reim_to_znx_i64_ref,
|
||||
reim_zero_ref,
|
||||
},
|
||||
reim4::{
|
||||
Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks,
|
||||
@@ -69,17 +70,17 @@ impl ReimSub for FFT64Ref {
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimSubABInplace for FFT64Ref {
|
||||
impl ReimSubInplace for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim_sub_ab_inplace(res: &mut [f64], a: &[f64]) {
|
||||
reim_sub_ab_inplace_ref(res, a);
|
||||
fn reim_sub_inplace(res: &mut [f64], a: &[f64]) {
|
||||
reim_sub_inplace_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimSubBAInplace for FFT64Ref {
|
||||
impl ReimSubNegateInplace for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim_sub_ba_inplace(res: &mut [f64], a: &[f64]) {
|
||||
reim_sub_ba_inplace_ref(res, a);
|
||||
fn reim_sub_negate_inplace(res: &mut [f64], a: &[f64]) {
|
||||
reim_sub_negate_inplace_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -253,9 +253,6 @@ fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]
|
||||
(take_slice, rem_slice)
|
||||
}
|
||||
} else {
|
||||
panic!(
|
||||
"Attempted to take {} from scratch with {} aligned bytes left",
|
||||
take_len, aligned_len,
|
||||
);
|
||||
panic!("Attempted to take {take_len} from scratch with {aligned_len} aligned bytes left");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
TakeSlice, VecZnxAutomorphismInplaceTmpBytes, VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOneInplaceTmpBytes,
|
||||
VecZnxNormalizeTmpBytes, VecZnxRotateInplaceTmpBytes, VecZnxSplitRingTmpBytes,
|
||||
TakeSlice, VecZnxAutomorphismInplaceTmpBytes, VecZnxLshTmpBytes, VecZnxMergeRingsTmpBytes,
|
||||
VecZnxMulXpMinusOneInplaceTmpBytes, VecZnxNormalizeTmpBytes, VecZnxRotateInplaceTmpBytes, VecZnxRshTmpBytes,
|
||||
VecZnxSplitRingTmpBytes,
|
||||
},
|
||||
layouts::{Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
|
||||
oep::{
|
||||
@@ -12,7 +13,7 @@ use poulpy_hal::{
|
||||
VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
|
||||
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
|
||||
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
|
||||
VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl,
|
||||
VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl,
|
||||
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl,
|
||||
},
|
||||
reference::vec_znx::{
|
||||
@@ -23,7 +24,7 @@ use poulpy_hal::{
|
||||
vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_negate, vec_znx_negate_inplace, vec_znx_normalize,
|
||||
vec_znx_normalize_inplace, vec_znx_normalize_tmp_bytes, vec_znx_rotate, vec_znx_rotate_inplace,
|
||||
vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring,
|
||||
vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace, vec_znx_sub_scalar,
|
||||
vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace, vec_znx_sub_scalar,
|
||||
vec_znx_sub_scalar_inplace, vec_znx_switch_ring,
|
||||
},
|
||||
source::Source,
|
||||
@@ -43,9 +44,10 @@ where
|
||||
{
|
||||
fn vec_znx_normalize_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
res_basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a_basek: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
@@ -54,7 +56,7 @@ where
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_normalize::<R, A, Self>(basek, res, res_col, a, a_col, carry);
|
||||
vec_znx_normalize::<R, A, Self>(res_basek, res, res_col, a_basek, a, a_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,7 +66,7 @@ where
|
||||
{
|
||||
fn vec_znx_normalize_inplace_impl<R>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
base2k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
@@ -72,7 +74,7 @@ where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_normalize_inplace::<R, Self>(basek, res, res_col, carry);
|
||||
vec_znx_normalize_inplace::<R, Self>(base2k, res, res_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,23 +145,23 @@ unsafe impl VecZnxSubImpl<Self> for FFT64Ref {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubABInplaceImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_sub_ab_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
unsafe impl VecZnxSubInplaceImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_sub_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_sub_ab_inplace::<R, A, Self>(res, res_col, a, a_col);
|
||||
vec_znx_sub_inplace::<R, A, Self>(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubBAInplaceImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_sub_ba_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
unsafe impl VecZnxSubNegateInplaceImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_sub_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_sub_ba_inplace::<R, A, Self>(res, res_col, a, a_col);
|
||||
vec_znx_sub_negate_inplace::<R, A, Self>(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -234,9 +236,9 @@ where
|
||||
Module<Self>: VecZnxNormalizeTmpBytes,
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_lsh_inplace_impl<R, A>(
|
||||
fn vec_znx_lsh_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
base2k: usize,
|
||||
k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
@@ -247,8 +249,8 @@ where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_lsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry);
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_lsh::<_, _, Self>(base2k, k, res, res_col, a, a_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -259,7 +261,7 @@ where
|
||||
{
|
||||
fn vec_znx_lsh_inplace_impl<A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
base2k: usize,
|
||||
k: usize,
|
||||
a: &mut A,
|
||||
a_col: usize,
|
||||
@@ -267,8 +269,8 @@ where
|
||||
) where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_lsh_inplace::<_, Self>(basek, k, a, a_col, carry);
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_lsh_inplace::<_, Self>(base2k, k, a, a_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -277,9 +279,9 @@ where
|
||||
Module<Self>: VecZnxNormalizeTmpBytes,
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_rsh_inplace_impl<R, A>(
|
||||
fn vec_znx_rsh_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
base2k: usize,
|
||||
k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
@@ -290,8 +292,8 @@ where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_rsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry);
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_rsh::<_, _, Self>(base2k, k, res, res_col, a, a_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -302,7 +304,7 @@ where
|
||||
{
|
||||
fn vec_znx_rsh_inplace_impl<A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
base2k: usize,
|
||||
k: usize,
|
||||
a: &mut A,
|
||||
a_col: usize,
|
||||
@@ -310,8 +312,8 @@ where
|
||||
) where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_rsh_inplace::<_, Self>(basek, k, a, a_col, carry);
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_rsh_inplace::<_, Self>(base2k, k, a, a_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -495,18 +497,18 @@ unsafe impl VecZnxCopyImpl<Self> for FFT64Ref {
|
||||
}
|
||||
|
||||
unsafe impl VecZnxFillUniformImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_fill_uniform_impl<R>(_module: &Module<Self>, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
|
||||
fn vec_znx_fill_uniform_impl<R>(_module: &Module<Self>, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
vec_znx_fill_uniform_ref(basek, res, res_col, source)
|
||||
vec_znx_fill_uniform_ref(base2k, res, res_col, source)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxFillNormalImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_fill_normal_impl<R>(
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
base2k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
@@ -516,14 +518,14 @@ unsafe impl VecZnxFillNormalImpl<Self> for FFT64Ref {
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
vec_znx_fill_normal_ref(basek, res, res_col, k, sigma, bound, source);
|
||||
vec_znx_fill_normal_ref(base2k, res, res_col, k, sigma, bound, source);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAddNormalImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_add_normal_impl<R>(
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
base2k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
@@ -533,6 +535,6 @@ unsafe impl VecZnxAddNormalImpl<Self> for FFT64Ref {
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
vec_znx_add_normal_ref(basek, res, res_col, k, sigma, bound, source);
|
||||
vec_znx_add_normal_ref(base2k, res, res_col, k, sigma, bound, source);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,15 +10,15 @@ use poulpy_hal::{
|
||||
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
|
||||
VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl,
|
||||
VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl,
|
||||
VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl,
|
||||
VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
|
||||
VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubImpl, VecZnxBigSubInplaceImpl, VecZnxBigSubNegateInplaceImpl,
|
||||
VecZnxBigSubSmallAImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallInplaceImpl, VecZnxBigSubSmallNegateInplaceImpl,
|
||||
},
|
||||
reference::{
|
||||
fft64::vec_znx_big::{
|
||||
vec_znx_big_add, vec_znx_big_add_inplace, vec_znx_big_add_normal_ref, vec_znx_big_add_small,
|
||||
vec_znx_big_add_small_inplace, vec_znx_big_automorphism, vec_znx_big_automorphism_inplace,
|
||||
vec_znx_big_automorphism_inplace_tmp_bytes, vec_znx_big_negate, vec_znx_big_negate_inplace, vec_znx_big_normalize,
|
||||
vec_znx_big_normalize_tmp_bytes, vec_znx_big_sub, vec_znx_big_sub_ab_inplace, vec_znx_big_sub_ba_inplace,
|
||||
vec_znx_big_normalize_tmp_bytes, vec_znx_big_sub, vec_znx_big_sub_inplace, vec_znx_big_sub_negate_inplace,
|
||||
vec_znx_big_sub_small_a, vec_znx_big_sub_small_a_inplace, vec_znx_big_sub_small_b, vec_znx_big_sub_small_b_inplace,
|
||||
},
|
||||
znx::{znx_copy_ref, znx_zero_ref},
|
||||
@@ -76,7 +76,7 @@ unsafe impl VecZnxBigFromSmallImpl<Self> for FFT64Ref {
|
||||
unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Ref {
|
||||
fn add_normal_impl<R: VecZnxBigToMut<Self>>(
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
base2k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
@@ -84,7 +84,7 @@ unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Ref {
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) {
|
||||
vec_znx_big_add_normal_ref(basek, res, res_col, k, sigma, bound, source);
|
||||
vec_znx_big_add_normal_ref(base2k, res, res_col, k, sigma, bound, source);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -167,25 +167,25 @@ unsafe impl VecZnxBigSubImpl<Self> for FFT64Ref {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64Ref {
|
||||
unsafe impl VecZnxBigSubInplaceImpl<Self> for FFT64Ref {
|
||||
/// Subtracts `a` from `b` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_ab_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
fn vec_znx_big_sub_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
{
|
||||
vec_znx_big_sub_ab_inplace(res, res_col, a, a_col);
|
||||
vec_znx_big_sub_inplace(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64Ref {
|
||||
unsafe impl VecZnxBigSubNegateInplaceImpl<Self> for FFT64Ref {
|
||||
/// Subtracts `b` from `a` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_ba_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
fn vec_znx_big_sub_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
{
|
||||
vec_znx_big_sub_ba_inplace(res, res_col, a, a_col);
|
||||
vec_znx_big_sub_negate_inplace(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -208,9 +208,9 @@ unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64Ref {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64Ref {
|
||||
unsafe impl VecZnxBigSubSmallInplaceImpl<Self> for FFT64Ref {
|
||||
/// Subtracts `a` from `res` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_a_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
fn vec_znx_big_sub_small_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxToRef,
|
||||
@@ -238,9 +238,9 @@ unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64Ref {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64Ref {
|
||||
unsafe impl VecZnxBigSubSmallNegateInplaceImpl<Self> for FFT64Ref {
|
||||
/// Subtracts `res` from `a` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_b_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
fn vec_znx_big_sub_small_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxToRef,
|
||||
@@ -280,9 +280,10 @@ where
|
||||
{
|
||||
fn vec_znx_big_normalize_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
res_basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a_basek: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
@@ -291,7 +292,7 @@ where
|
||||
A: VecZnxBigToRef<Self>,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_big_normalize(basek, res, res_col, a, a_col, carry);
|
||||
vec_znx_big_normalize(res_basek, res, res_col, a_basek, a, a_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -326,7 +327,7 @@ where
|
||||
) where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
{
|
||||
let (tmp, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
|
||||
let (tmp, _) = scratch.take_slice(module.vec_znx_big_automorphism_inplace_tmp_bytes() / size_of::<i64>());
vec_znx_big_automorphism_inplace(p, res, res_col, tmp);
}
}

@@ -5,12 +5,12 @@ use poulpy_hal::{
},
oep::{
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl,
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl,
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl,
VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
},
reference::fft64::vec_znx_dft::{
vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub,
vec_znx_dft_sub_ab_inplace, vec_znx_dft_sub_ba_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume,
vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub, vec_znx_dft_sub_inplace,
vec_znx_dft_sub_negate_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume,
vec_znx_idft_apply_tmpa,
},
};
@@ -139,23 +139,23 @@ unsafe impl VecZnxDftSubImpl<Self> for FFT64Ref {
}
}

unsafe impl VecZnxDftSubABInplaceImpl<Self> for FFT64Ref {
fn vec_znx_dft_sub_ab_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxDftSubInplaceImpl<Self> for FFT64Ref {
fn vec_znx_dft_sub_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
{
vec_znx_dft_sub_ab_inplace(res, res_col, a, a_col);
vec_znx_dft_sub_inplace(res, res_col, a, a_col);
}
}

unsafe impl VecZnxDftSubBAInplaceImpl<Self> for FFT64Ref {
fn vec_znx_dft_sub_ba_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxDftSubNegateInplaceImpl<Self> for FFT64Ref {
fn vec_znx_dft_sub_negate_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
{
vec_znx_dft_sub_ba_inplace(res, res_col, a, a_col);
vec_znx_dft_sub_negate_inplace(res, res_col, a, a_col);
}
}


@@ -18,21 +18,21 @@ unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64Ref
where
Self: TakeSliceImpl<Self>,
{
fn zn_normalize_inplace_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<Self>)
fn zn_normalize_inplace_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<Self>)
where
R: ZnToMut,
{
let (carry, _) = scratch.take_slice(n);
zn_normalize_inplace::<R, FFT64Ref>(n, basek, res, res_col, carry);
zn_normalize_inplace::<R, FFT64Ref>(n, base2k, res, res_col, carry);
}
}

unsafe impl ZnFillUniformImpl<Self> for FFT64Ref {
fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
fn zn_fill_uniform_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
zn_fill_uniform(n, basek, res, res_col, source);
zn_fill_uniform(n, base2k, res, res_col, source);
}
}

@@ -40,7 +40,7 @@ unsafe impl ZnFillNormalImpl<Self> for FFT64Ref {
#[allow(clippy::too_many_arguments)]
fn zn_fill_normal_impl<R>(
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -50,7 +50,7 @@ unsafe impl ZnFillNormalImpl<Self> for FFT64Ref {
) where
R: ZnToMut,
{
zn_fill_normal(n, basek, res, res_col, k, source, sigma, bound);
zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}

@@ -58,7 +58,7 @@ unsafe impl ZnAddNormalImpl<Self> for FFT64Ref {
#[allow(clippy::too_many_arguments)]
fn zn_add_normal_impl<R>(
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -68,6 +68,6 @@ unsafe impl ZnAddNormalImpl<Self> for FFT64Ref {
) where
R: ZnToMut,
{
zn_add_normal(n, basek, res, res_col, k, source, sigma, bound);
zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}

@@ -1,12 +1,14 @@
use poulpy_hal::reference::znx::{
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace,
ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxRotate, ZnxSub, ZnxSubABInplace,
ZnxSubBAInplace, ZnxSwitchRing, ZnxZero, znx_add_inplace_ref, znx_add_ref, znx_automorphism_ref, znx_copy_ref,
znx_negate_inplace_ref, znx_negate_ref, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref,
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulAddPowerOfTwo, ZnxMulPowerOfTwo,
ZnxMulPowerOfTwoInplace, ZnxNegate, ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace,
ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep,
ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxRotate, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace,
ZnxSwitchRing, ZnxZero, znx_add_inplace_ref, znx_add_ref, znx_automorphism_ref, znx_copy_ref, znx_extract_digit_addmul_ref,
znx_mul_add_power_of_two_ref, znx_mul_power_of_two_inplace_ref, znx_mul_power_of_two_ref, znx_negate_inplace_ref,
znx_negate_ref, znx_normalize_digit_ref, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref,
znx_normalize_first_step_carry_only_ref, znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref,
znx_normalize_middle_step_carry_only_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref, znx_rotate,
znx_sub_ab_inplace_ref, znx_sub_ba_inplace_ref, znx_sub_ref, znx_switch_ring_ref, znx_zero_ref,
znx_sub_inplace_ref, znx_sub_negate_inplace_ref, znx_sub_ref, znx_switch_ring_ref, znx_zero_ref,
};

use crate::cpu_fft64_ref::FFT64Ref;
@@ -32,17 +34,38 @@ impl ZnxSub for FFT64Ref {
}
}

impl ZnxSubABInplace for FFT64Ref {
impl ZnxSubInplace for FFT64Ref {
#[inline(always)]
fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_ab_inplace_ref(res, a);
fn znx_sub_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_inplace_ref(res, a);
}
}

impl ZnxSubBAInplace for FFT64Ref {
impl ZnxSubNegateInplace for FFT64Ref {
#[inline(always)]
fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_ba_inplace_ref(res, a);
fn znx_sub_negate_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_negate_inplace_ref(res, a);
}
}

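The ZnxSubABInplace/ZnxSubBAInplace pair is renamed to ZnxSubInplace/ZnxSubNegateInplace; the two differ only in the orientation of the difference. A minimal sketch of the element-wise semantics, inferred from the trait doc comments later in this diff (illustrative only, not the reference implementation):

// Sketch: semantics assumed from the doc comments in this commit.
// znx_sub_inplace:        res[i] <- res[i] - a[i]
// znx_sub_negate_inplace: res[i] <- a[i] - res[i]  (the negated difference, hence "SubNegate")
fn znx_sub_inplace_sketch(res: &mut [i64], a: &[i64]) {
    for (r, &x) in res.iter_mut().zip(a.iter()) {
        *r = r.wrapping_sub(x);
    }
}

fn znx_sub_negate_inplace_sketch(res: &mut [i64], a: &[i64]) {
    for (r, &x) in res.iter_mut().zip(a.iter()) {
        *r = x.wrapping_sub(*r);
    }
}

fn main() {
    let a = [1i64, 2, 3];
    let mut r1 = [10i64, 10, 10];
    znx_sub_inplace_sketch(&mut r1, &a);
    assert_eq!(r1, [9, 8, 7]);

    let mut r2 = [10i64, 10, 10];
    znx_sub_negate_inplace_sketch(&mut r2, &a);
    assert_eq!(r2, [-9, -8, -7]);
}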
impl ZnxMulAddPowerOfTwo for FFT64Ref {
#[inline(always)]
fn znx_muladd_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
znx_mul_add_power_of_two_ref(k, res, a);
}
}

impl ZnxMulPowerOfTwo for FFT64Ref {
#[inline(always)]
fn znx_mul_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
znx_mul_power_of_two_ref(k, res, a);
}
}

impl ZnxMulPowerOfTwoInplace for FFT64Ref {
#[inline(always)]
fn znx_mul_power_of_two_inplace(k: i64, res: &mut [i64]) {
znx_mul_power_of_two_inplace_ref(k, res);
}
}

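The power-of-two kernels are new in this commit and support the cross-base2k conversion: they rescale coefficients by 2^k. A rough sketch of the plain multiply for non-negative k; how negative k (a rounded right shift) is handled is an assumption left to the reference implementation:

// Sketch only: covers k >= 0, i.e. res = a * 2^k.
fn znx_mul_power_of_two_sketch(k: i64, res: &mut [i64], a: &[i64]) {
    assert!(k >= 0, "negative k (rounding right shift) is not covered by this sketch");
    for (r, &x) in res.iter_mut().zip(a.iter()) {
        *r = x << k;
    }
}

fn main() {
    let a = [1i64, -2, 3];
    let mut res = [0i64; 3];
    znx_mul_power_of_two_sketch(4, &mut res, &a);
    assert_eq!(res, [16, -32, 48]);
}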
@@ -97,56 +120,70 @@ impl ZnxSwitchRing for FFT64Ref {

impl ZnxNormalizeFinalStep for FFT64Ref {
#[inline(always)]
fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_final_step_ref(basek, lsh, x, a, carry);
fn znx_normalize_final_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_final_step_ref(base2k, lsh, x, a, carry);
}
}

impl ZnxNormalizeFinalStepInplace for FFT64Ref {
#[inline(always)]
fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_final_step_inplace_ref(basek, lsh, x, carry);
fn znx_normalize_final_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_final_step_inplace_ref(base2k, lsh, x, carry);
}
}

impl ZnxNormalizeFirstStep for FFT64Ref {
#[inline(always)]
fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_ref(basek, lsh, x, a, carry);
fn znx_normalize_first_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_ref(base2k, lsh, x, a, carry);
}
}

impl ZnxNormalizeFirstStepCarryOnly for FFT64Ref {
#[inline(always)]
fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_carry_only_ref(basek, lsh, x, carry);
fn znx_normalize_first_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_carry_only_ref(base2k, lsh, x, carry);
}
}

impl ZnxNormalizeFirstStepInplace for FFT64Ref {
#[inline(always)]
fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_first_step_inplace_ref(basek, lsh, x, carry);
fn znx_normalize_first_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_first_step_inplace_ref(base2k, lsh, x, carry);
}
}

impl ZnxNormalizeMiddleStep for FFT64Ref {
#[inline(always)]
fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_ref(basek, lsh, x, a, carry);
fn znx_normalize_middle_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_ref(base2k, lsh, x, a, carry);
}
}

impl ZnxNormalizeMiddleStepCarryOnly for FFT64Ref {
#[inline(always)]
fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_carry_only_ref(basek, lsh, x, carry);
fn znx_normalize_middle_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_carry_only_ref(base2k, lsh, x, carry);
}
}

impl ZnxNormalizeMiddleStepInplace for FFT64Ref {
#[inline(always)]
fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_middle_step_inplace_ref(basek, lsh, x, carry);
fn znx_normalize_middle_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_middle_step_inplace_ref(base2k, lsh, x, carry);
}
}

impl ZnxExtractDigitAddMul for FFT64Ref {
#[inline(always)]
fn znx_extract_digit_addmul(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) {
znx_extract_digit_addmul_ref(base2k, lsh, res, src);
}
}

impl ZnxNormalizeDigit for FFT64Ref {
#[inline(always)]
fn znx_normalize_digit(base2k: usize, res: &mut [i64], src: &mut [i64]) {
znx_normalize_digit_ref(base2k, res, src);
}
}

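These first/middle/final step kernels, together with the new digit helpers, are the building blocks of base-2^k normalization: each limb of a coefficient is reduced to a centered digit while a per-coefficient carry is propagated towards the most significant limb. A self-contained illustration of that carry propagation on a single coefficient; the real kernels work column-wise over whole polynomials with a carry slice, and the limb ordering and rounding here are assumptions, not the library's exact algorithm:

// Normalize one coefficient split into base-2^k digits, limbs[0] taken as most significant.
fn normalize_coefficient_sketch(base2k: usize, limbs: &mut [i64]) {
    let full = 1i64 << base2k;
    let half = full / 2;
    let mut carry = 0i64;
    for limb in limbs.iter_mut().rev() {
        let v = *limb + carry;
        // reduce to a centered digit in [-2^(base2k-1), 2^(base2k-1))
        let mut digit = ((v % full) + full) % full;
        if digit >= half {
            digit -= full;
        }
        carry = (v - digit) >> base2k;
        *limb = digit;
    }
    // whatever carry remains falls off the most significant end
    let _ = carry;
}

fn main() {
    let mut limbs = [3i64, 5_000, -70_000]; // 12-bit digits, not yet normalized
    normalize_coefficient_sketch(12, &mut limbs);
    assert!(limbs.iter().all(|&d| -(1 << 11) <= d && d < (1 << 11)));
    println!("{limbs:?}");
}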
@@ -6,5 +6,6 @@ mod vec_znx_big;
mod vec_znx_dft;
mod vmp_pmat;
mod zn;
mod znx;

pub struct FFT64Spqlios;

@@ -3,20 +3,11 @@ use std::ptr::NonNull;
use poulpy_hal::{
layouts::{Backend, Module},
oep::ModuleNewImpl,
reference::znx::{
ZnxCopy, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
ZnxRotate, ZnxSwitchRing, ZnxZero, znx_copy_ref, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref,
znx_normalize_first_step_carry_only_ref, znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref,
znx_normalize_middle_step_carry_only_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref,
znx_switch_ring_ref, znx_zero_ref,
},
};

use crate::cpu_spqlios::{
FFT64Spqlios,
ffi::module::{MODULE, delete_module_info, new_module_info},
znx::znx_rotate_i64,
};

impl Backend for FFT64Spqlios {
@@ -41,85 +32,3 @@ unsafe impl ModuleNewImpl<Self> for FFT64Spqlios {
unsafe { Module::from_raw_parts(new_module_info(n, 0), n) }
}
}

impl ZnxCopy for FFT64Spqlios {
fn znx_copy(res: &mut [i64], a: &[i64]) {
znx_copy_ref(res, a);
}
}

impl ZnxZero for FFT64Spqlios {
fn znx_zero(res: &mut [i64]) {
znx_zero_ref(res);
}
}

impl ZnxSwitchRing for FFT64Spqlios {
fn znx_switch_ring(res: &mut [i64], a: &[i64]) {
znx_switch_ring_ref(res, a);
}
}

impl ZnxRotate for FFT64Spqlios {
fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]) {
unsafe {
znx_rotate_i64(res.len() as u64, p, res.as_mut_ptr(), src.as_ptr());
}
}
}

impl ZnxNormalizeFinalStep for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_final_step_ref(basek, lsh, x, a, carry);
}
}

impl ZnxNormalizeFinalStepInplace for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_final_step_inplace_ref(basek, lsh, x, carry);
}
}

impl ZnxNormalizeFirstStep for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_ref(basek, lsh, x, a, carry);
}
}

impl ZnxNormalizeFirstStepCarryOnly for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_carry_only_ref(basek, lsh, x, carry);
}
}

impl ZnxNormalizeFirstStepInplace for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_first_step_inplace_ref(basek, lsh, x, carry);
}
}

impl ZnxNormalizeMiddleStep for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_ref(basek, lsh, x, a, carry);
}
}

impl ZnxNormalizeMiddleStepCarryOnly for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_carry_only_ref(basek, lsh, x, carry);
}
}

impl ZnxNormalizeMiddleStepInplace for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_middle_step_inplace_ref(basek, lsh, x, carry);
}
}

@@ -253,9 +253,6 @@ fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]
(take_slice, rem_slice)
}
} else {
panic!(
"Attempted to take {} from scratch with {} aligned bytes left",
take_len, aligned_len,
);
panic!("Attempted to take {take_len} from scratch with {aligned_len} aligned bytes left");
}
}

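As an aside, the multi-line panic is collapsed into a single statement using Rust's inlined format arguments (identifiers captured directly in the format string, stable since Rust 1.58); the same pattern is used for the galois-element assertion further down. A tiny standalone example of the syntax:

fn main() {
    let take_len = 64usize;
    let aligned_len = 32usize;
    // `{take_len}` captures the local variable; no positional `{}` arguments are needed.
    println!("Attempted to take {take_len} from scratch with {aligned_len} aligned bytes left");
}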
@@ -1,5 +1,8 @@
use poulpy_hal::{
api::{TakeSlice, VecZnxMergeRingsTmpBytes, VecZnxNormalizeTmpBytes, VecZnxSplitRingTmpBytes},
api::{
TakeSlice, VecZnxLshTmpBytes, VecZnxMergeRingsTmpBytes, VecZnxNormalizeTmpBytes, VecZnxRshTmpBytes,
VecZnxSplitRingTmpBytes,
},
layouts::{
Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut,
},
@@ -11,16 +14,16 @@ use poulpy_hal::{
VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl,
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl,
},
reference::{
vec_znx::{
vec_znx_add_normal_ref, vec_znx_automorphism_inplace_tmp_bytes, vec_znx_copy, vec_znx_fill_normal_ref,
vec_znx_fill_uniform_ref, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_lsh_tmp_bytes, vec_znx_merge_rings,
vec_znx_merge_rings_tmp_bytes, vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_rotate_inplace_tmp_bytes,
vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring, vec_znx_split_ring_tmp_bytes,
vec_znx_switch_ring,
vec_znx_merge_rings_tmp_bytes, vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_normalize_tmp_bytes,
vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring,
vec_znx_split_ring_tmp_bytes, vec_znx_switch_ring,
},
znx::{znx_copy_ref, znx_zero_ref},
},
@@ -34,7 +37,7 @@ use crate::cpu_spqlios::{

unsafe impl VecZnxNormalizeTmpBytesImpl<Self> for FFT64Spqlios {
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr() as *const module_info_t) as usize }
vec_znx_normalize_tmp_bytes(module.n())
}
}

@@ -44,9 +47,10 @@ where
{
fn vec_znx_normalize_impl<R, A>(
module: &Module<Self>,
basek: usize,
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<Self>,
@@ -60,6 +64,10 @@ where
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), a.n());
assert_eq!(
res_basek, a_basek,
"res_basek != a_basek -> base2k conversion is not supported"
)
}

let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes());
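The normalization entry point now takes a base-2^k per operand (res_basek and a_basek), which is the cross-base2k normalization this commit introduces. The spqlios FFI path still requires both bases to match, hence the debug assertion above, while the portable reference path used for vec_znx_big_normalize below performs the actual conversion. A standalone illustration of what re-encoding a value from one limb width to another means; the limb-weight convention 2^-(j+1)k used here is an assumption made only for this example:

fn recompose(digits: &[i64], base2k: u32) -> f64 {
    digits
        .iter()
        .enumerate()
        .map(|(j, &d)| d as f64 / 2f64.powi((base2k * (j as u32 + 1)) as i32))
        .sum()
}

fn main() {
    let (a_base2k, res_base2k) = (17u32, 12u32);
    let a = [41_321i64, -12_045, 9_877]; // a value stored as three 17-bit limbs
    let x = recompose(&a, a_base2k);

    // Re-encode x into four centered 12-bit limbs (most significant limb first).
    let total_bits = (res_base2k * 4) as i32;
    let mut rem = (x * 2f64.powi(total_bits)).round() as i64;
    let mut res = [0i64; 4];
    for j in (0..4).rev() {
        let full = 1i64 << res_base2k;
        let mut d = ((rem % full) + full) % full;
        if d >= full / 2 {
            d -= full;
        }
        res[j] = d;
        rem = (rem - d) >> res_base2k;
    }
    assert!((recompose(&res, res_base2k) - x).abs() < 1e-9);
    println!("re-encoded {x} from 17-bit to 12-bit limbs: {res:?}");
}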
@@ -67,7 +75,7 @@ where
unsafe {
vec_znx::vec_znx_normalize_base2k(
module.ptr() as *const module_info_t,
basek as u64,
res_basek as u64,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
@@ -86,7 +94,7 @@ where
{
fn vec_znx_normalize_inplace_impl<A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
a: &mut A,
a_col: usize,
scratch: &mut Scratch<Self>,
@@ -100,7 +108,7 @@ where
unsafe {
vec_znx::vec_znx_normalize_base2k(
module.ptr() as *const module_info_t,
basek as u64,
base2k as u64,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
@@ -301,8 +309,8 @@ unsafe impl VecZnxSubImpl<Self> for FFT64Spqlios {
}
}

unsafe impl VecZnxSubABInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxSubInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_sub_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
@@ -330,8 +338,8 @@ unsafe impl VecZnxSubABInplaceImpl<Self> for FFT64Spqlios {
}
}

unsafe impl VecZnxSubBAInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxSubNegateInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_sub_negate_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
@@ -512,9 +520,9 @@ where
Module<Self>: VecZnxNormalizeTmpBytes,
Scratch<Self>: TakeSlice,
{
fn vec_znx_lsh_inplace_impl<R, A>(
fn vec_znx_lsh_impl<R, A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -525,8 +533,8 @@ where
R: VecZnxToMut,
A: VecZnxToRef,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_lsh::<_, _, FFT64Spqlios>(basek, k, res, res_col, a, a_col, carry)
let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::<i64>());
vec_znx_lsh::<_, _, FFT64Spqlios>(base2k, k, res, res_col, a, a_col, carry)
}
}

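The shift helpers now size their carry buffer from the dedicated VecZnxLshTmpBytes/VecZnxRshTmpBytes contracts instead of reusing the normalization bound. The take_slice argument converts that byte budget into a number of i64 limbs, a trivial conversion worth spelling out once:

fn carry_len_from_bytes(tmp_bytes: usize) -> usize {
    // e.g. a 256-byte scratch budget yields 32 i64 carry slots
    tmp_bytes / core::mem::size_of::<i64>()
}

fn main() {
    assert_eq!(carry_len_from_bytes(256), 32);
}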
@@ -537,7 +545,7 @@ where
{
fn vec_znx_lsh_inplace_impl<A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
a: &mut A,
a_col: usize,
@@ -545,8 +553,8 @@ where
) where
A: VecZnxToMut,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_lsh_inplace::<_, FFT64Spqlios>(basek, k, a, a_col, carry)
let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::<i64>());
vec_znx_lsh_inplace::<_, FFT64Spqlios>(base2k, k, a, a_col, carry)
}
}

@@ -555,9 +563,9 @@ where
Module<Self>: VecZnxNormalizeTmpBytes,
Scratch<Self>: TakeSlice,
{
fn vec_znx_rsh_inplace_impl<R, A>(
fn vec_znx_rsh_impl<R, A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -568,8 +576,8 @@ where
R: VecZnxToMut,
A: VecZnxToRef,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_rsh::<_, _, FFT64Spqlios>(basek, k, res, res_col, a, a_col, carry)
let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::<i64>());
vec_znx_rsh::<_, _, FFT64Spqlios>(base2k, k, res, res_col, a, a_col, carry)
}
}

@@ -580,7 +588,7 @@ where
{
fn vec_znx_rsh_inplace_impl<A>(
module: &Module<Self>,
basek: usize,
base2k: usize,
k: usize,
a: &mut A,
a_col: usize,
@@ -588,8 +596,8 @@ where
) where
A: VecZnxToMut,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_rsh_inplace::<_, FFT64Spqlios>(basek, k, a, a_col, carry)
let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::<i64>());
vec_znx_rsh_inplace::<_, FFT64Spqlios>(base2k, k, a, a_col, carry)
}
}

@@ -690,11 +698,7 @@ unsafe impl VecZnxAutomorphismInplaceImpl<Self> for FFT64Spqlios {
let mut a: VecZnx<&mut [u8]> = a.to_mut();
#[cfg(debug_assertions)]
{
assert!(
k & 1 != 0,
"invalid galois element: must be odd but is {}",
k
);
assert!(k & 1 != 0, "invalid galois element: must be odd but is {k}");
}
unsafe {
vec_znx::vec_znx_automorphism(
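The oddness check is unchanged in substance (only rewritten with an inlined format argument): for the negacyclic ring Z[X]/(X^N + 1) with N a power of two, the map X -> X^k is a ring automorphism only when gcd(k, 2N) = 1, which for a power-of-two modulus is exactly the condition that k is odd. A tiny check of that equivalence:

fn gcd(mut a: u64, mut b: u64) -> u64 {
    while b != 0 {
        let t = a % b;
        a = b;
        b = t;
    }
    a
}

fn main() {
    let two_n = 64u64; // 2N for N = 32
    for k in 1..two_n {
        assert_eq!(gcd(k, two_n) == 1, k % 2 == 1);
    }
}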
@@ -852,18 +856,18 @@ unsafe impl VecZnxCopyImpl<Self> for FFT64Spqlios {
}

unsafe impl VecZnxFillUniformImpl<Self> for FFT64Spqlios {
fn vec_znx_fill_uniform_impl<R>(_module: &Module<Self>, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
fn vec_znx_fill_uniform_impl<R>(_module: &Module<Self>, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut,
{
vec_znx_fill_uniform_ref(basek, res, res_col, source)
vec_znx_fill_uniform_ref(base2k, res, res_col, source)
}
}

unsafe impl VecZnxFillNormalImpl<Self> for FFT64Spqlios {
fn vec_znx_fill_normal_impl<R>(
_module: &Module<Self>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -873,14 +877,14 @@ unsafe impl VecZnxFillNormalImpl<Self> for FFT64Spqlios {
) where
R: VecZnxToMut,
{
vec_znx_fill_normal_ref(basek, res, res_col, k, sigma, bound, source);
vec_znx_fill_normal_ref(base2k, res, res_col, k, sigma, bound, source);
}
}

unsafe impl VecZnxAddNormalImpl<Self> for FFT64Spqlios {
fn vec_znx_add_normal_impl<R>(
_module: &Module<Self>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -890,6 +894,6 @@ unsafe impl VecZnxAddNormalImpl<Self> for FFT64Spqlios {
) where
R: VecZnxToMut,
{
vec_znx_add_normal_ref(basek, res, res_col, k, sigma, bound, source);
vec_znx_add_normal_ref(base2k, res, res_col, k, sigma, bound, source);
}
}

@@ -10,11 +10,12 @@ use poulpy_hal::{
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl,
VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl,
VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl,
VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubImpl, VecZnxBigSubInplaceImpl, VecZnxBigSubNegateInplaceImpl,
VecZnxBigSubSmallAImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallInplaceImpl, VecZnxBigSubSmallNegateInplaceImpl,
},
reference::{
vec_znx::vec_znx_add_normal_ref,
fft64::vec_znx_big::vec_znx_big_normalize,
vec_znx::{vec_znx_add_normal_ref, vec_znx_normalize_tmp_bytes},
znx::{znx_copy_ref, znx_zero_ref},
},
source::Source,
@@ -70,7 +71,7 @@ unsafe impl VecZnxBigFromSmallImpl<Self> for FFT64Spqlios {
unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Spqlios {
fn add_normal_impl<R: VecZnxBigToMut<Self>>(
_module: &Module<Self>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -88,7 +89,7 @@ unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Spqlios {
max_size: res.max_size,
};

vec_znx_add_normal_ref(basek, &mut res_znx, res_col, k, sigma, bound, source);
vec_znx_add_normal_ref(base2k, &mut res_znx, res_col, k, sigma, bound, source);
}
}

@@ -266,9 +267,9 @@ unsafe impl VecZnxBigSubImpl<Self> for FFT64Spqlios {
}
}

unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64Spqlios {
unsafe impl VecZnxBigSubInplaceImpl<Self> for FFT64Spqlios {
/// Subtracts `a` from `b` and stores the result on `b`.
fn vec_znx_big_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
@@ -297,9 +298,9 @@ unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64Spqlios {
}
}

unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64Spqlios {
unsafe impl VecZnxBigSubNegateInplaceImpl<Self> for FFT64Spqlios {
/// Subtracts `b` from `a` and stores the result on `b`.
fn vec_znx_big_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_negate_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
@@ -370,9 +371,9 @@ unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64Spqlios {
}
}

unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64Spqlios {
unsafe impl VecZnxBigSubSmallInplaceImpl<Self> for FFT64Spqlios {
/// Subtracts `a` from `res` and stores the result on `res`.
fn vec_znx_big_sub_small_a_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_small_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
@@ -443,9 +444,9 @@ unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64Spqlios {
}
}

unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64Spqlios {
unsafe impl VecZnxBigSubSmallNegateInplaceImpl<Self> for FFT64Spqlios {
/// Subtracts `res` from `a` and stores the result on `res`.
fn vec_znx_big_sub_small_b_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_small_negate_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
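Taken together, the renamed in-place big-vector subtractions follow one naming rule, spelled out by the doc comments above. A compact scalar illustration of that rule (the `small` variants only change the operand type, not the arithmetic):

fn main() {
    let (res0, a) = (10i64, 3i64);

    // sub_inplace:        res <- res - a
    let mut res = res0;
    res -= a;
    assert_eq!(res, 7);

    // sub_negate_inplace: res <- a - res (the negated difference)
    let mut res = res0;
    res = a - res;
    assert_eq!(res, -7);
}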
@@ -518,7 +519,7 @@ unsafe impl VecZnxBigNegateInplaceImpl<Self> for FFT64Spqlios {

unsafe impl VecZnxBigNormalizeTmpBytesImpl<Self> for FFT64Spqlios {
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr()) as usize }
vec_znx_normalize_tmp_bytes(module.n())
}
}

@@ -528,9 +529,10 @@ where
{
fn vec_znx_big_normalize_impl<R, A>(
module: &Module<Self>,
basek: usize,
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<Self>,
@@ -538,28 +540,21 @@ where
R: VecZnxToMut,
A: VecZnxBigToRef<Self>,
{
let a: VecZnxBig<&[u8], Self> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();

#[cfg(debug_assertions)]
{
assert_eq!(res.n(), a.n());
}

let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes());
unsafe {
vec_znx::vec_znx_normalize_base2k(
module.ptr(),
basek as u64,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
tmp_bytes.as_mut_ptr(),
);
}
let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
// unsafe {
// vec_znx::vec_znx_normalize_base2k(
// module.ptr(),
// base2k as u64,
// res.at_mut_ptr(res_col, 0),
// res.size() as u64,
// res.sl() as u64,
// a.at_ptr(a_col, 0),
// a.size() as u64,
// a.sl() as u64,
// tmp_bytes.as_mut_ptr(),
// );
// }
vec_znx_big_normalize(res_basek, res, res_col, a_basek, a, a_col, carry);
}
}


@@ -6,7 +6,7 @@ use poulpy_hal::{
},
oep::{
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl,
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl,
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl,
VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
},
reference::{
@@ -336,8 +336,8 @@ unsafe impl VecZnxDftSubImpl<Self> for FFT64Spqlios {
}
}

unsafe impl VecZnxDftSubABInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_dft_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxDftSubInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_dft_sub_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,
@@ -363,8 +363,8 @@ unsafe impl VecZnxDftSubABInplaceImpl<Self> for FFT64Spqlios {
}
}

unsafe impl VecZnxDftSubBAInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_dft_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
unsafe impl VecZnxDftSubNegateInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_dft_sub_negate_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: VecZnxDftToRef<Self>,

@@ -12,7 +12,7 @@ unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64Spqlios
where
Self: TakeSliceImpl<Self>,
{
fn zn_normalize_inplace_impl<A>(n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<Self>)
fn zn_normalize_inplace_impl<A>(n: usize, base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<Self>)
where
A: ZnToMut,
{
@@ -23,7 +23,7 @@ where
unsafe {
zn64::zn64_normalize_base2k_ref(
n as u64,
basek as u64,
base2k as u64,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
@@ -37,11 +37,11 @@ where
}

unsafe impl ZnFillUniformImpl<Self> for FFT64Spqlios {
fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
fn zn_fill_uniform_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
zn_fill_uniform(n, basek, res, res_col, source);
zn_fill_uniform(n, base2k, res, res_col, source);
}
}

@@ -49,7 +49,7 @@ unsafe impl ZnFillNormalImpl<Self> for FFT64Spqlios {
#[allow(clippy::too_many_arguments)]
fn zn_fill_normal_impl<R>(
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -59,7 +59,7 @@ unsafe impl ZnFillNormalImpl<Self> for FFT64Spqlios {
) where
R: ZnToMut,
{
zn_fill_normal(n, basek, res, res_col, k, source, sigma, bound);
zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}

@@ -67,7 +67,7 @@ unsafe impl ZnAddNormalImpl<Self> for FFT64Spqlios {
#[allow(clippy::too_many_arguments)]
fn zn_add_normal_impl<R>(
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -77,6 +77,6 @@ unsafe impl ZnAddNormalImpl<Self> for FFT64Spqlios {
) where
R: ZnToMut,
{
zn_add_normal(n, basek, res, res_col, k, source, sigma, bound);
zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}

189 poulpy-backend/src/cpu_spqlios/fft64/znx.rs Normal file
@@ -0,0 +1,189 @@
use poulpy_hal::reference::znx::{
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulAddPowerOfTwo, ZnxMulPowerOfTwo,
ZnxMulPowerOfTwoInplace, ZnxNegate, ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace,
ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep,
ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxRotate, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace,
ZnxSwitchRing, ZnxZero, znx_add_inplace_ref, znx_add_ref, znx_automorphism_ref, znx_copy_ref, znx_extract_digit_addmul_ref,
znx_mul_add_power_of_two_ref, znx_mul_power_of_two_inplace_ref, znx_mul_power_of_two_ref, znx_negate_inplace_ref,
znx_negate_ref, znx_normalize_digit_ref, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref,
znx_normalize_first_step_carry_only_ref, znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref,
znx_normalize_middle_step_carry_only_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref, znx_rotate,
znx_sub_inplace_ref, znx_sub_negate_inplace_ref, znx_sub_ref, znx_switch_ring_ref, znx_zero_ref,
};

use crate::FFT64Spqlios;

impl ZnxAdd for FFT64Spqlios {
#[inline(always)]
fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]) {
znx_add_ref(res, a, b);
}
}

impl ZnxAddInplace for FFT64Spqlios {
#[inline(always)]
fn znx_add_inplace(res: &mut [i64], a: &[i64]) {
znx_add_inplace_ref(res, a);
}
}

impl ZnxSub for FFT64Spqlios {
#[inline(always)]
fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]) {
znx_sub_ref(res, a, b);
}
}

impl ZnxSubInplace for FFT64Spqlios {
#[inline(always)]
fn znx_sub_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_inplace_ref(res, a);
}
}

impl ZnxSubNegateInplace for FFT64Spqlios {
#[inline(always)]
fn znx_sub_negate_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_negate_inplace_ref(res, a);
}
}

impl ZnxMulAddPowerOfTwo for FFT64Spqlios {
#[inline(always)]
fn znx_muladd_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
znx_mul_add_power_of_two_ref(k, res, a);
}
}

impl ZnxMulPowerOfTwo for FFT64Spqlios {
#[inline(always)]
fn znx_mul_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
znx_mul_power_of_two_ref(k, res, a);
}
}

impl ZnxMulPowerOfTwoInplace for FFT64Spqlios {
#[inline(always)]
fn znx_mul_power_of_two_inplace(k: i64, res: &mut [i64]) {
znx_mul_power_of_two_inplace_ref(k, res);
}
}

impl ZnxAutomorphism for FFT64Spqlios {
#[inline(always)]
fn znx_automorphism(p: i64, res: &mut [i64], a: &[i64]) {
znx_automorphism_ref(p, res, a);
}
}

impl ZnxCopy for FFT64Spqlios {
#[inline(always)]
fn znx_copy(res: &mut [i64], a: &[i64]) {
znx_copy_ref(res, a);
}
}

impl ZnxNegate for FFT64Spqlios {
#[inline(always)]
fn znx_negate(res: &mut [i64], src: &[i64]) {
znx_negate_ref(res, src);
}
}

impl ZnxNegateInplace for FFT64Spqlios {
#[inline(always)]
fn znx_negate_inplace(res: &mut [i64]) {
znx_negate_inplace_ref(res);
}
}

impl ZnxRotate for FFT64Spqlios {
#[inline(always)]
fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]) {
znx_rotate::<Self>(p, res, src);
}
}

impl ZnxZero for FFT64Spqlios {
#[inline(always)]
fn znx_zero(res: &mut [i64]) {
znx_zero_ref(res);
}
}

impl ZnxSwitchRing for FFT64Spqlios {
#[inline(always)]
fn znx_switch_ring(res: &mut [i64], a: &[i64]) {
znx_switch_ring_ref(res, a);
}
}

impl ZnxNormalizeFinalStep for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_final_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_final_step_ref(base2k, lsh, x, a, carry);
}
}

impl ZnxNormalizeFinalStepInplace for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_final_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_final_step_inplace_ref(base2k, lsh, x, carry);
}
}

impl ZnxNormalizeFirstStep for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_first_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_ref(base2k, lsh, x, a, carry);
}
}

impl ZnxNormalizeFirstStepCarryOnly for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_first_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_carry_only_ref(base2k, lsh, x, carry);
}
}

impl ZnxNormalizeFirstStepInplace for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_first_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_first_step_inplace_ref(base2k, lsh, x, carry);
}
}

impl ZnxNormalizeMiddleStep for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_middle_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_ref(base2k, lsh, x, a, carry);
}
}

impl ZnxNormalizeMiddleStepCarryOnly for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_middle_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_carry_only_ref(base2k, lsh, x, carry);
}
}

impl ZnxNormalizeMiddleStepInplace for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_middle_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_middle_step_inplace_ref(base2k, lsh, x, carry);
}
}

impl ZnxExtractDigitAddMul for FFT64Spqlios {
#[inline(always)]
fn znx_extract_digit_addmul(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) {
znx_extract_digit_addmul_ref(base2k, lsh, res, src);
}
}

impl ZnxNormalizeDigit for FFT64Spqlios {
#[inline(always)]
fn znx_normalize_digit(base2k: usize, res: &mut [i64], src: &mut [i64]) {
znx_normalize_digit_ref(base2k, res, src);
}
}
@@ -5,15 +5,15 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_spqlios::FFT64Spqlios,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_vec_znx_add => poulpy_hal::test_suite::vec_znx::test_vec_znx_add,
test_vec_znx_add_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_inplace,
test_vec_znx_add_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar,
test_vec_znx_add_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar_inplace,
test_vec_znx_sub => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub,
test_vec_znx_sub_ab_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ab_inplace,
test_vec_znx_sub_ba_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ba_inplace,
test_vec_znx_sub_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_inplace,
test_vec_znx_sub_negate_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_negate_inplace,
test_vec_znx_sub_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar,
test_vec_znx_sub_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar_inplace,
test_vec_znx_rsh => poulpy_hal::test_suite::vec_znx::test_vec_znx_rsh,
@@ -41,7 +41,7 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_spqlios::FFT64Spqlios,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_svp_apply_dft_to_dft => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft,
test_svp_apply_dft_to_dft_inplace => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft_inplace,
@@ -53,20 +53,20 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_spqlios::FFT64Spqlios,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_vec_znx_big_add => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add,
test_vec_znx_big_add_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_inplace,
test_vec_znx_big_add_small => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small,
test_vec_znx_big_add_small_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small_inplace,
test_vec_znx_big_sub => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub,
test_vec_znx_big_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ab_inplace,
test_vec_znx_big_sub_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_inplace,
test_vec_znx_big_automorphism => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism,
test_vec_znx_big_automorphism_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism_inplace,
test_vec_znx_big_negate => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate,
test_vec_znx_big_negate_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate_inplace,
test_vec_znx_big_normalize => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_normalize,
test_vec_znx_big_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ba_inplace,
test_vec_znx_big_sub_negate_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_negate_inplace,
test_vec_znx_big_sub_small_a => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a,
test_vec_znx_big_sub_small_a_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a_inplace,
test_vec_znx_big_sub_small_b => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_b,
@@ -79,13 +79,13 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_spqlios::FFT64Spqlios,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_vec_znx_dft_add => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add,
test_vec_znx_dft_add_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add_inplace,
test_vec_znx_dft_sub => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub,
test_vec_znx_dft_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ab_inplace,
test_vec_znx_dft_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ba_inplace,
test_vec_znx_dft_sub_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_inplace,
test_vec_znx_dft_sub_negate_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_negate_inplace,
test_vec_znx_idft_apply => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply,
test_vec_znx_idft_apply_consume => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_consume,
test_vec_znx_idft_apply_tmpa => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_tmpa,
@@ -97,7 +97,7 @@ cross_backend_test_suite! {
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_spqlios::FFT64Spqlios,
size = 1 << 5,
basek = 12,
base2k = 12,
tests = {
test_vmp_apply_dft_to_dft => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft,
test_vmp_apply_dft_to_dft_add => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft_add,