Mirror of https://github.com/arnaucube/poulpy.git, synced 2026-02-10 13:16:44 +01:00
Ref. + AVX code & generic tests + benches (#85)
Committed via GitHub. Parent: 99b9e3e10e. Commit: 56dbd29c59.
@@ -1,6 +1,8 @@
#[allow(non_camel_case_types)]
pub mod module;
#[allow(non_camel_case_types)]
pub mod reim;
#[allow(non_camel_case_types)]
pub mod svp;
#[allow(non_camel_case_types)]
pub mod vec_znx;
poulpy-backend/src/cpu_spqlios/ffi/reim.rs (new file, 172 lines)
@@ -0,0 +1,172 @@
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct reim_fft_precomp {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub type REIM_FFT_PRECOMP = reim_fft_precomp;
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct reim_ifft_precomp {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub type REIM_IFFT_PRECOMP = reim_ifft_precomp;
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct reim_mul_precomp {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub type REIM_FFTVEC_MUL_PRECOMP = reim_mul_precomp;
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct reim_addmul_precomp {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub type REIM_FFTVEC_ADDMUL_PRECOMP = reim_addmul_precomp;
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct reim_from_znx32_precomp {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub type REIM_FROM_ZNX32_PRECOMP = reim_from_znx32_precomp;
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct reim_from_znx64_precomp {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub type REIM_FROM_ZNX64_PRECOMP = reim_from_znx64_precomp;
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct reim_from_tnx32_precomp {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub type REIM_FROM_TNX32_PRECOMP = reim_from_tnx32_precomp;
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct reim_to_tnx32_precomp {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub type REIM_TO_TNX32_PRECOMP = reim_to_tnx32_precomp;
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct reim_to_tnx_precomp {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub type REIM_TO_TNX_PRECOMP = reim_to_tnx_precomp;
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct reim_to_znx64_precomp {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub type REIM_TO_ZNX64_PRECOMP = reim_to_znx64_precomp;
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_reim_fft_precomp(m: u32, num_buffers: u32) -> *mut REIM_FFT_PRECOMP;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_fft_precomp_get_buffer(tables: *const REIM_FFT_PRECOMP, buffer_index: u32) -> *mut f64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_reim_fft_buffer(m: u32) -> *mut f64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn delete_reim_fft_buffer(buffer: *mut f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_fft(tables: *const REIM_FFT_PRECOMP, data: *mut f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_reim_ifft_precomp(m: u32, num_buffers: u32) -> *mut REIM_IFFT_PRECOMP;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_ifft_precomp_get_buffer(tables: *const REIM_IFFT_PRECOMP, buffer_index: u32) -> *mut f64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_ifft(tables: *const REIM_IFFT_PRECOMP, data: *mut f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_reim_fftvec_mul_precomp(m: u32) -> *mut REIM_FFTVEC_MUL_PRECOMP;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_fftvec_mul(tables: *const REIM_FFTVEC_MUL_PRECOMP, r: *mut f64, a: *const f64, b: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_reim_fftvec_addmul_precomp(m: u32) -> *mut REIM_FFTVEC_ADDMUL_PRECOMP;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_fftvec_addmul(tables: *const REIM_FFTVEC_ADDMUL_PRECOMP, r: *mut f64, a: *const f64, b: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_reim_from_znx32_precomp(m: u32, log2bound: u32) -> *mut REIM_FROM_ZNX32_PRECOMP;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_from_znx32(tables: *const REIM_FROM_ZNX32_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i32);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_from_znx64(tables: *const REIM_FROM_ZNX64_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_reim_from_znx64_precomp(m: u32, maxbnd: u32) -> *mut REIM_FROM_ZNX64_PRECOMP;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_from_znx64_simple(m: u32, log2bound: u32, r: *mut ::std::os::raw::c_void, a: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_reim_from_tnx32_precomp(m: u32) -> *mut REIM_FROM_TNX32_PRECOMP;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_from_tnx32(tables: *const REIM_FROM_TNX32_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i32);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_reim_to_tnx32_precomp(m: u32, divisor: f64, log2overhead: u32) -> *mut REIM_TO_TNX32_PRECOMP;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_to_tnx32(tables: *const REIM_TO_TNX32_PRECOMP, r: *mut i32, a: *const ::std::os::raw::c_void);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_reim_to_tnx_precomp(m: u32, divisor: f64, log2overhead: u32) -> *mut REIM_TO_TNX_PRECOMP;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_to_tnx(tables: *const REIM_TO_TNX_PRECOMP, r: *mut f64, a: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_to_tnx_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut f64, a: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_reim_to_znx64_precomp(m: u32, divisor: f64, log2bound: u32) -> *mut REIM_TO_ZNX64_PRECOMP;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_to_znx64(precomp: *const REIM_TO_ZNX64_PRECOMP, r: *mut i64, a: *const ::std::os::raw::c_void);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_to_znx64_simple(m: u32, divisor: f64, log2bound: u32, r: *mut i64, a: *const ::std::os::raw::c_void);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_fft_simple(m: u32, data: *mut ::std::os::raw::c_void);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_ifft_simple(m: u32, data: *mut ::std::os::raw::c_void);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_fftvec_mul_simple(
|
||||
m: u32,
|
||||
r: *mut ::std::os::raw::c_void,
|
||||
a: *const ::std::os::raw::c_void,
|
||||
b: *const ::std::os::raw::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_fftvec_addmul_simple(
|
||||
m: u32,
|
||||
r: *mut ::std::os::raw::c_void,
|
||||
a: *const ::std::os::raw::c_void,
|
||||
b: *const ::std::os::raw::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_from_znx32_simple(m: u32, log2bound: u32, r: *mut ::std::os::raw::c_void, x: *const i32);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_from_tnx32_simple(m: u32, r: *mut ::std::os::raw::c_void, x: *const i32);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn reim_to_tnx32_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut i32, x: *const ::std::os::raw::c_void);
|
||||
}
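Every function in this file is a raw binding into the spqlios C library, so each call site needs an unsafe block and must uphold the buffer layout itself. The sketch below is illustrative only and is not part of this commit; it assumes the spqlios convention that a reim vector of m complex coefficients is stored as 2*m contiguous f64 values, the m real parts followed by the m imaginary parts.

use std::os::raw::c_void;

use crate::cpu_spqlios::ffi::reim;

/// Hypothetical safe wrapper: in-place forward FFT over a reim buffer of `m` complex values.
pub fn reim_fft_in_place(m: u32, data: &mut [f64]) {
    assert_eq!(data.len(), 2 * m as usize, "reim buffer must hold 2*m f64 values");
    // Safety: `data` is exclusively borrowed and has the length the C code expects.
    unsafe { reim::reim_fft_simple(m, data.as_mut_ptr() as *mut c_void) };
}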
@@ -7,10 +7,4 @@ mod vec_znx_dft;
mod vmp_pmat;
mod zn;

pub use module::FFT64;

/// For external documentation
pub use vec_znx::{
    vec_znx_copy_ref, vec_znx_lsh_inplace_ref, vec_znx_merge_ref, vec_znx_rsh_inplace_ref, vec_znx_split_ref,
    vec_znx_switch_degree_ref,
};
pub struct FFT64Spqlios;
@@ -3,13 +3,23 @@ use std::ptr::NonNull;
|
||||
use poulpy_hal::{
|
||||
layouts::{Backend, Module},
|
||||
oep::ModuleNewImpl,
|
||||
reference::znx::{
|
||||
ZnxCopy, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
|
||||
ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
|
||||
ZnxRotate, ZnxSwitchRing, ZnxZero, znx_copy_ref, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref,
|
||||
znx_normalize_first_step_carry_only_ref, znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref,
|
||||
znx_normalize_middle_step_carry_only_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref,
|
||||
znx_switch_ring_ref, znx_zero_ref,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::cpu_spqlios::ffi::module::{MODULE, delete_module_info, new_module_info};
|
||||
use crate::cpu_spqlios::{
|
||||
FFT64Spqlios,
|
||||
ffi::module::{MODULE, delete_module_info, new_module_info},
|
||||
znx::znx_rotate_i64,
|
||||
};
|
||||
|
||||
pub struct FFT64;
|
||||
|
||||
impl Backend for FFT64 {
|
||||
impl Backend for FFT64Spqlios {
|
||||
type ScalarPrep = f64;
|
||||
type ScalarBig = i64;
|
||||
type Handle = MODULE;
|
||||
@@ -26,8 +36,90 @@ impl Backend for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ModuleNewImpl<Self> for FFT64 {
|
||||
unsafe impl ModuleNewImpl<Self> for FFT64Spqlios {
|
||||
fn new_impl(n: u64) -> Module<Self> {
|
||||
unsafe { Module::from_raw_parts(new_module_info(n, 0), n) }
|
||||
}
|
||||
}
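With the backend renamed, constructing a module still goes through new_module_info under the hood. A minimal usage sketch, assuming the poulpy_hal::api::ModuleNew front-end trait that dispatches to this ModuleNewImpl::new_impl:

use poulpy_hal::{api::ModuleNew, layouts::Module};

use crate::cpu_spqlios::FFT64Spqlios;

fn make_module() -> Module<FFT64Spqlios> {
    // Ring degree: the spqlios tables expect a power of two.
    Module::<FFT64Spqlios>::new(1 << 10)
}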
|
||||
|
||||
impl ZnxCopy for FFT64Spqlios {
|
||||
fn znx_copy(res: &mut [i64], a: &[i64]) {
|
||||
znx_copy_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxZero for FFT64Spqlios {
|
||||
fn znx_zero(res: &mut [i64]) {
|
||||
znx_zero_ref(res);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxSwitchRing for FFT64Spqlios {
|
||||
fn znx_switch_ring(res: &mut [i64], a: &[i64]) {
|
||||
znx_switch_ring_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxRotate for FFT64Spqlios {
|
||||
fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]) {
|
||||
unsafe {
|
||||
znx_rotate_i64(res.len() as u64, p, res.as_mut_ptr(), src.as_ptr());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeFinalStep for FFT64Spqlios {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
znx_normalize_final_step_ref(basek, lsh, x, a, carry);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeFinalStepInplace for FFT64Spqlios {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
znx_normalize_final_step_inplace_ref(basek, lsh, x, carry);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeFirstStep for FFT64Spqlios {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
znx_normalize_first_step_ref(basek, lsh, x, a, carry);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeFirstStepCarryOnly for FFT64Spqlios {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
|
||||
znx_normalize_first_step_carry_only_ref(basek, lsh, x, carry);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeFirstStepInplace for FFT64Spqlios {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
znx_normalize_first_step_inplace_ref(basek, lsh, x, carry);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeMiddleStep for FFT64Spqlios {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
znx_normalize_middle_step_ref(basek, lsh, x, a, carry);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeMiddleStepCarryOnly for FFT64Spqlios {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
|
||||
znx_normalize_middle_step_carry_only_ref(basek, lsh, x, carry);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeMiddleStepInplace for FFT64Spqlios {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
znx_normalize_middle_step_inplace_ref(basek, lsh, x, carry);
|
||||
}
|
||||
}
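These step traits are meant to be chained over the limbs of one column, reusing a single carry buffer of length n. A rough driver sketch, illustrative only, assuming limb index size - 1 holds the least-significant base-2^k digit (which matches the shift code later in this commit) and using lsh = 0:

use poulpy_hal::{
    layouts::{VecZnx, ZnxInfos, ZnxViewMut},
    reference::znx::{ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStepInplace},
};

fn normalize_col_inplace<B>(basek: usize, a: &mut VecZnx<&mut [u8]>, col: usize, carry: &mut [i64])
where
    B: ZnxNormalizeFirstStepInplace + ZnxNormalizeMiddleStepInplace + ZnxNormalizeFinalStepInplace,
{
    let size = a.size();
    assert!(size >= 2, "sketch assumes at least two limbs");
    // The least-significant limb seeds the carry...
    B::znx_normalize_first_step_inplace(basek, 0, a.at_mut(col, size - 1), carry);
    // ...interior limbs propagate it...
    for j in (1..size - 1).rev() {
        B::znx_normalize_middle_step_inplace(basek, 0, a.at_mut(col, j), carry);
    }
    // ...and the most-significant limb absorbs what is left.
    B::znx_normalize_final_step_inplace(basek, 0, a.at_mut(col, 0), carry);
}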
|
||||
|
||||
@@ -12,9 +12,9 @@ use poulpy_hal::{
|
||||
},
|
||||
};
|
||||
|
||||
use crate::cpu_spqlios::FFT64;
|
||||
use crate::cpu_spqlios::FFT64Spqlios;
|
||||
|
||||
unsafe impl<B: Backend> ScratchOwnedAllocImpl<B> for FFT64 {
|
||||
unsafe impl<B: Backend> ScratchOwnedAllocImpl<B> for FFT64Spqlios {
|
||||
fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned<B> {
|
||||
let data: Vec<u8> = alloc_aligned(size);
|
||||
ScratchOwned {
|
||||
@@ -24,7 +24,7 @@ unsafe impl<B: Backend> ScratchOwnedAllocImpl<B> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScratchOwnedBorrowImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> ScratchOwnedBorrowImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
@@ -33,13 +33,13 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScratchFromBytesImpl<B> for FFT64 {
|
||||
unsafe impl<B: Backend> ScratchFromBytesImpl<B> for FFT64Spqlios {
|
||||
fn scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch<B> {
|
||||
unsafe { &mut *(data as *mut [u8] as *mut Scratch<B>) }
|
||||
}
|
||||
}
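The scratch machinery lets backends carve temporaries out of one preallocated aligned arena instead of allocating in hot paths; the vec_znx_* impls later in this commit use exactly this pattern for their carry buffers. A minimal usage sketch, assuming the ScratchOwnedAlloc, ScratchOwnedBorrow and TakeSlice front-end traits in poulpy_hal::api:

use poulpy_hal::{
    api::{ScratchOwnedAlloc, ScratchOwnedBorrow, TakeSlice},
    layouts::{Scratch, ScratchOwned},
};

use crate::cpu_spqlios::FFT64Spqlios;

fn with_carry(n: usize) {
    // Allocate one aligned arena up front (size in bytes)...
    let mut owned: ScratchOwned<FFT64Spqlios> = ScratchOwned::alloc(n * size_of::<i64>());
    let scratch: &mut Scratch<FFT64Spqlios> = owned.borrow();
    // ...then carve an n-element i64 slice out of it, with no further allocation.
    let (carry, _): (&mut [i64], _) = scratch.take_slice(n);
    carry.fill(0);
}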
|
||||
|
||||
unsafe impl<B: Backend> ScratchAvailableImpl<B> for FFT64 {
|
||||
unsafe impl<B: Backend> ScratchAvailableImpl<B> for FFT64Spqlios {
|
||||
fn scratch_available_impl(scratch: &Scratch<B>) -> usize {
|
||||
let ptr: *const u8 = scratch.data.as_ptr();
|
||||
let self_len: usize = scratch.data.len();
|
||||
@@ -48,7 +48,7 @@ unsafe impl<B: Backend> ScratchAvailableImpl<B> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeSliceImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> TakeSliceImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
@@ -64,7 +64,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeScalarZnxImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> TakeScalarZnxImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
@@ -77,7 +77,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeSvpPPolImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> TakeSvpPPolImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: SvpPPolAllocBytesImpl<B> + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
@@ -90,7 +90,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> TakeVecZnxImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
@@ -103,7 +103,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxBigImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> TakeVecZnxBigImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: VecZnxBigAllocBytesImpl<B> + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
@@ -124,7 +124,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxDftImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> TakeVecZnxDftImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: VecZnxDftAllocBytesImpl<B> + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
@@ -146,7 +146,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxDftSliceImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> TakeVecZnxDftSliceImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: VecZnxDftAllocBytesImpl<B> + ScratchFromBytesImpl<B> + TakeVecZnxDftImpl<B>,
|
||||
{
|
||||
@@ -168,7 +168,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxSliceImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> TakeVecZnxSliceImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: ScratchFromBytesImpl<B> + TakeVecZnxImpl<B>,
|
||||
{
|
||||
@@ -190,7 +190,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVmpPMatImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> TakeVmpPMatImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: VmpPMatAllocBytesImpl<B> + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
@@ -213,7 +213,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeMatZnxImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> TakeMatZnxImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
|
||||
@@ -3,33 +3,36 @@ use poulpy_hal::{
|
||||
Backend, Module, ScalarZnxToRef, SvpPPol, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDft, VecZnxDftToMut,
|
||||
VecZnxDftToRef, ZnxInfos, ZnxView, ZnxViewMut,
|
||||
},
|
||||
oep::{SvpApplyImpl, SvpApplyInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl},
|
||||
oep::{
|
||||
SvpApplyDftToDftImpl, SvpApplyDftToDftInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl,
|
||||
SvpPrepareImpl,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::cpu_spqlios::{
|
||||
FFT64,
|
||||
FFT64Spqlios,
|
||||
ffi::{svp, vec_znx_dft::vec_znx_dft_t},
|
||||
};
|
||||
|
||||
unsafe impl SvpPPolFromBytesImpl<Self> for FFT64 {
|
||||
unsafe impl SvpPPolFromBytesImpl<Self> for FFT64Spqlios {
|
||||
fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<Self> {
|
||||
SvpPPolOwned::from_bytes(n, cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolAllocImpl<Self> for FFT64 {
|
||||
unsafe impl SvpPPolAllocImpl<Self> for FFT64Spqlios {
|
||||
fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned<Self> {
|
||||
SvpPPolOwned::alloc(n, cols)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolAllocBytesImpl<Self> for FFT64 {
|
||||
unsafe impl SvpPPolAllocBytesImpl<Self> for FFT64Spqlios {
|
||||
fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize {
|
||||
FFT64::layout_prep_word_count() * n * cols * size_of::<f64>()
|
||||
FFT64Spqlios::layout_prep_word_count() * n * cols * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPrepareImpl<Self> for FFT64 {
|
||||
unsafe impl SvpPrepareImpl<Self> for FFT64Spqlios {
|
||||
fn svp_prepare_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: SvpPPolToMut<Self>,
|
||||
@@ -45,9 +48,16 @@ unsafe impl SvpPrepareImpl<Self> for FFT64 {
|
||||
}
|
||||
}
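The rename from SvpApplyImpl to SvpApplyDftToDftImpl below makes the domain explicit: a prepared scalar polynomial multiplies a vector that is already in the DFT domain and produces a DFT-domain result. A hedged sketch of the prepare-then-apply flow, calling the OEP impls directly; the concrete layout type parameters used here are assumptions, not taken from this diff:

use poulpy_hal::{
    layouts::{Module, ScalarZnx, SvpPPolOwned, VecZnxDft},
    oep::{SvpApplyDftToDftImpl, SvpPrepareImpl},
};

use crate::cpu_spqlios::FFT64Spqlios;

fn scalar_times_dft(
    module: &Module<FFT64Spqlios>,
    scalar: &ScalarZnx<Vec<u8>>,
    res: &mut VecZnxDft<Vec<u8>, FFT64Spqlios>,
    a: &VecZnxDft<Vec<u8>, FFT64Spqlios>,
) {
    // Prepare once (coefficient domain -> spqlios SVP layout)...
    let mut ppol: SvpPPolOwned<FFT64Spqlios> = SvpPPolOwned::alloc(module.n(), 1);
    FFT64Spqlios::svp_prepare_impl(module, &mut ppol, 0, scalar, 0);
    // ...then multiply the DFT vector by it, column 0 against column 0.
    FFT64Spqlios::svp_apply_dft_to_dft_impl(module, res, 0, &ppol, 0, a, 0);
}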
|
||||
|
||||
unsafe impl SvpApplyImpl<Self> for FFT64 {
|
||||
fn svp_apply_impl<R, A, B>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
unsafe impl SvpApplyDftToDftImpl<Self> for FFT64Spqlios {
|
||||
fn svp_apply_dft_to_dft_impl<R, A, B>(
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: SvpPPolToRef<Self>,
|
||||
B: VecZnxDftToRef<Self>,
|
||||
@@ -70,8 +80,8 @@ unsafe impl SvpApplyImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpApplyInplaceImpl for FFT64 {
|
||||
fn svp_apply_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
unsafe impl SvpApplyDftToDftInplaceImpl for FFT64Spqlios {
|
||||
fn svp_apply_dft_to_dft_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: SvpPPolToRef<Self>,
|
||||
|
||||
@@ -1,39 +1,44 @@
|
||||
use itertools::izip;
|
||||
use rand_distr::Normal;
|
||||
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
TakeSlice, TakeVecZnx, VecZnxAddDistF64, VecZnxCopy, VecZnxFillDistF64, VecZnxNormalizeTmpBytes, VecZnxRotate,
|
||||
VecZnxRotateInplace, VecZnxSwithcDegree,
|
||||
},
|
||||
api::{TakeSlice, VecZnxMergeRingsTmpBytes, VecZnxNormalizeTmpBytes, VecZnxSplitRingTmpBytes},
|
||||
layouts::{
|
||||
Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView,
|
||||
ZnxViewMut, ZnxZero,
|
||||
Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut,
|
||||
},
|
||||
oep::{
|
||||
TakeSliceImpl, TakeVecZnxImpl, VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl,
|
||||
VecZnxAddScalarInplaceImpl, VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl,
|
||||
VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl,
|
||||
VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
|
||||
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl,
|
||||
VecZnxSplitImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl,
|
||||
VecZnxSwithcDegreeImpl,
|
||||
TakeSliceImpl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarImpl, VecZnxAddScalarInplaceImpl,
|
||||
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxAutomorphismInplaceTmpBytesImpl, VecZnxCopyImpl,
|
||||
VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshImpl, VecZnxLshInplaceImpl, VecZnxLshTmpBytesImpl,
|
||||
VecZnxMergeRingsImpl, VecZnxMergeRingsTmpBytesImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl,
|
||||
VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
|
||||
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
|
||||
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
|
||||
VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl,
|
||||
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl,
|
||||
},
|
||||
reference::{
|
||||
vec_znx::{
|
||||
vec_znx_add_normal_ref, vec_znx_automorphism_inplace_tmp_bytes, vec_znx_copy, vec_znx_fill_normal_ref,
|
||||
vec_znx_fill_uniform_ref, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_lsh_tmp_bytes, vec_znx_merge_rings,
|
||||
vec_znx_merge_rings_tmp_bytes, vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_rotate_inplace_tmp_bytes,
|
||||
vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring, vec_znx_split_ring_tmp_bytes,
|
||||
vec_znx_switch_ring,
|
||||
},
|
||||
znx::{znx_copy_ref, znx_zero_ref},
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
use crate::cpu_spqlios::{
|
||||
FFT64,
|
||||
FFT64Spqlios,
|
||||
ffi::{module::module_info_t, vec_znx, znx},
|
||||
};
|
||||
|
||||
unsafe impl VecZnxNormalizeTmpBytesImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxNormalizeTmpBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr() as *const module_info_t) as usize }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxNormalizeImpl<Self> for FFT64
|
||||
unsafe impl VecZnxNormalizeImpl<Self> for FFT64Spqlios
|
||||
where
|
||||
Self: TakeSliceImpl<Self> + VecZnxNormalizeTmpBytesImpl<Self>,
|
||||
{
|
||||
@@ -75,7 +80,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxNormalizeInplaceImpl<Self> for FFT64
|
||||
unsafe impl VecZnxNormalizeInplaceImpl<Self> for FFT64Spqlios
|
||||
where
|
||||
Self: TakeSliceImpl<Self> + VecZnxNormalizeTmpBytesImpl<Self>,
|
||||
{
|
||||
@@ -108,7 +113,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAddImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxAddImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_add_impl<R, A, C>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
@@ -142,7 +147,7 @@ unsafe impl VecZnxAddImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAddInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxAddInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_add_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
@@ -172,7 +177,7 @@ unsafe impl VecZnxAddInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAddScalarInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxAddScalarInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_add_scalar_inplace_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
@@ -209,7 +214,60 @@ unsafe impl VecZnxAddScalarInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxAddScalarImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_add_scalar_impl<R, A, B>(
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
b_limb: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let a: ScalarZnx<&[u8]> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
|
||||
let min_size: usize = b.size().min(res.size());
|
||||
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
module.ptr() as *const module_info_t,
|
||||
res.at_mut_ptr(res_col, b_limb),
|
||||
1_u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, b_limb),
|
||||
1_u64,
|
||||
b.sl() as u64,
|
||||
);
|
||||
|
||||
for j in 0..min_size {
|
||||
if j != b_limb {
|
||||
znx_copy_ref(res.at_mut(res_col, j), b.at(b_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
for j in min_size..res.size() {
|
||||
znx_zero_ref(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
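In words, the new add-scalar entry point touches a single base-2^k limb: the scalar is added into limb b_limb, every other limb inside the common size is copied from b unchanged, and any extra limbs of res are zeroed. A plain-Rust model of that per-coefficient behaviour, illustrative only; it ignores the column/stride layout handled by the FFI call above:

fn add_scalar_model(res: &mut [Vec<i64>], b: &[Vec<i64>], scalar: &[i64], b_limb: usize) {
    let min_size = b.len().min(res.len());
    for j in 0..min_size {
        res[j].copy_from_slice(&b[j]);
        if j == b_limb {
            for (r, s) in res[j].iter_mut().zip(scalar) {
                *r += *s; // the scalar only lands in limb `b_limb`
            }
        }
    }
    for limb in res.iter_mut().skip(min_size) {
        limb.iter_mut().for_each(|x| *x = 0); // zero-pad the extra output limbs
    }
}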
|
||||
|
||||
unsafe impl VecZnxSubImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_sub_impl<R, A, C>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
@@ -243,7 +301,7 @@ unsafe impl VecZnxSubImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubABInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxSubABInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
@@ -272,7 +330,7 @@ unsafe impl VecZnxSubABInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubBAInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxSubBAInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
@@ -301,7 +359,60 @@ unsafe impl VecZnxSubBAInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubScalarInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxSubScalarImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_sub_scalar_impl<R, A, B>(
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
b_limb: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let a: ScalarZnx<&[u8]> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
|
||||
let min_size: usize = b.size().min(res.size());
|
||||
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr() as *const module_info_t,
|
||||
res.at_mut_ptr(res_col, b_limb),
|
||||
1_u64,
|
||||
res.sl() as u64,
|
||||
b.at_ptr(b_col, b_limb),
|
||||
1_u64,
|
||||
b.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
);
|
||||
|
||||
for j in 0..min_size {
|
||||
if j != b_limb {
|
||||
res.at_mut(res_col, j).copy_from_slice(b.at(b_col, j))
|
||||
}
|
||||
}
|
||||
|
||||
for j in min_size..res.size() {
|
||||
znx_zero_ref(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubScalarInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_sub_scalar_inplace_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
@@ -327,18 +438,18 @@ unsafe impl VecZnxSubScalarInplaceImpl<Self> for FFT64 {
|
||||
res.at_mut_ptr(res_col, res_limb),
|
||||
1_u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, res_limb),
|
||||
1_u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxNegateImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxNegateImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_negate_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
@@ -364,7 +475,7 @@ unsafe impl VecZnxNegateImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxNegateInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxNegateInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_negate_inplace_impl<A>(module: &Module<Self>, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
@@ -384,92 +495,105 @@ unsafe impl VecZnxNegateInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxLshInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_lsh_inplace_impl<A>(_module: &Module<Self>, basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
unsafe impl VecZnxLshTmpBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_lsh_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_lsh_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxRshTmpBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_rsh_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_rsh_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxLshImpl<Self> for FFT64Spqlios
|
||||
where
|
||||
Module<Self>: VecZnxNormalizeTmpBytes,
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_lsh_inplace_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_lsh::<_, _, FFT64Spqlios>(basek, k, res, res_col, a, a_col, carry)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxLshInplaceImpl<Self> for FFT64Spqlios
|
||||
where
|
||||
Module<Self>: VecZnxNormalizeTmpBytes,
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_lsh_inplace_impl<A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
a: &mut A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
vec_znx_lsh_inplace_ref(basek, k, a)
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_lsh_inplace::<_, FFT64Spqlios>(basek, k, a, a_col, carry)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_lsh_inplace_ref<A>(basek: usize, k: usize, a: &mut A)
|
||||
unsafe impl VecZnxRshImpl<Self> for FFT64Spqlios
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
Module<Self>: VecZnxNormalizeTmpBytes,
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||
|
||||
let n: usize = a.n();
|
||||
let cols: usize = a.cols();
|
||||
let size: usize = a.size();
|
||||
let steps: usize = k / basek;
|
||||
|
||||
a.raw_mut().rotate_left(n * steps * cols);
|
||||
(0..cols).for_each(|i| {
|
||||
(size - steps..size).for_each(|j| {
|
||||
a.zero_at(i, j);
|
||||
})
|
||||
});
|
||||
|
||||
let k_rem: usize = k % basek;
|
||||
|
||||
if k_rem != 0 {
|
||||
let shift: usize = i64::BITS as usize - k_rem;
|
||||
(0..cols).for_each(|i| {
|
||||
(0..steps).for_each(|j| {
|
||||
a.at_mut(i, j).iter_mut().for_each(|xi| {
|
||||
*xi <<= shift;
|
||||
});
|
||||
});
|
||||
});
|
||||
fn vec_znx_rsh_inplace_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_rsh::<_, _, FFT64Spqlios>(basek, k, res, res_col, a, a_col, carry)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxRshInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_rsh_inplace_impl<A>(_module: &Module<Self>, basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
unsafe impl VecZnxRshInplaceImpl<Self> for FFT64Spqlios
|
||||
where
|
||||
Module<Self>: VecZnxNormalizeTmpBytes,
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_rsh_inplace_impl<A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
a: &mut A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
vec_znx_rsh_inplace_ref(basek, k, a)
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_rsh_inplace::<_, FFT64Spqlios>(basek, k, a, a_col, carry)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_rsh_inplace_ref<A>(basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||
let n: usize = a.n();
|
||||
let cols: usize = a.cols();
|
||||
let size: usize = a.size();
|
||||
let steps: usize = k / basek;
|
||||
|
||||
a.raw_mut().rotate_right(n * steps * cols);
|
||||
(0..cols).for_each(|i| {
|
||||
(0..steps).for_each(|j| {
|
||||
a.zero_at(i, j);
|
||||
})
|
||||
});
|
||||
|
||||
let k_rem: usize = k % basek;
|
||||
|
||||
if k_rem != 0 {
|
||||
let mut carry: Vec<i64> = vec![0i64; n]; // ALLOC (but small so OK)
|
||||
let shift: usize = i64::BITS as usize - k_rem;
|
||||
(0..cols).for_each(|i| {
|
||||
carry.fill(0);
|
||||
(steps..size).for_each(|j| {
|
||||
izip!(carry.iter_mut(), a.at_mut(i, j).iter_mut()).for_each(|(ci, xi)| {
|
||||
*xi += *ci << basek;
|
||||
*ci = (*xi << shift) >> shift;
|
||||
*xi = (*xi - *ci) >> k_rem;
|
||||
});
|
||||
});
|
||||
})
|
||||
}
|
||||
}
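The fractional part of the right shift works digit by digit in the signed base-2^basek representation: each limb first absorbs the carry from the more significant limb (scaled by 2^basek), its low k_rem bits are extracted sign-extended as the carry for the next limb, and what remains is divided exactly by 2^k_rem with an arithmetic shift. A small self-contained worked example, illustrative only, with limbs stored most-significant first as in the code above:

fn rsh_frac_step(basek: usize, k_rem: usize, limbs: &mut [i64]) {
    let shift = i64::BITS as usize - k_rem;
    let mut carry = 0i64;
    for x in limbs.iter_mut() {
        *x += carry << basek;            // fold in the carry from the more significant limb
        carry = (*x << shift) >> shift;  // sign-extended low k_rem bits become the next carry
        *x = (*x - carry) >> k_rem;      // exact division by 2^k_rem
    }
}

#[test]
fn rsh_frac_step_example() {
    // [3, 5] in base 2^4 encodes 3*16 + 5 = 53; shifting right by 2 bits gives 13 = 1*16 - 3.
    let mut limbs = [3i64, 5];
    rsh_frac_step(4, 2, &mut limbs);
    assert_eq!(limbs, [1, -3]);
}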
|
||||
|
||||
unsafe impl VecZnxRotateImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxRotateImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_rotate_impl<R, A>(_module: &Module<Self>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
@@ -482,7 +606,8 @@ unsafe impl VecZnxRotateImpl<Self> for FFT64 {
|
||||
assert_eq!(res.n(), a.n());
|
||||
}
|
||||
unsafe {
|
||||
(0..a.size()).for_each(|j| {
|
||||
let min_size = res.size().min(a.size());
|
||||
(0..min_size).for_each(|j| {
|
||||
znx::znx_rotate_i64(
|
||||
a.n() as u64,
|
||||
k,
|
||||
@@ -490,12 +615,28 @@ unsafe impl VecZnxRotateImpl<Self> for FFT64 {
|
||||
a.at_ptr(a_col, j),
|
||||
);
|
||||
});
|
||||
|
||||
(min_size..res.size()).for_each(|j| {
|
||||
znx_zero_ref(res.at_mut(res_col, j));
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxRotateInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_rotate_inplace_impl<A>(_module: &Module<Self>, k: i64, a: &mut A, a_col: usize)
|
||||
unsafe impl VecZnxRotateInplaceTmpBytesImpl<Self> for FFT64Spqlios
|
||||
where
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_rotate_inplace_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_rotate_inplace_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxRotateInplaceImpl<Self> for FFT64Spqlios
|
||||
where
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_rotate_inplace_impl<A>(_module: &Module<Self>, k: i64, a: &mut A, a_col: usize, _scratch: &mut Scratch<Self>)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
@@ -508,7 +649,7 @@ unsafe impl VecZnxRotateInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAutomorphismImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxAutomorphismImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_automorphism_impl<R, A>(module: &Module<Self>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
@@ -535,8 +676,14 @@ unsafe impl VecZnxAutomorphismImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAutomorphismInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_automorphism_inplace_impl<A>(module: &Module<Self>, k: i64, a: &mut A, a_col: usize)
|
||||
unsafe impl VecZnxAutomorphismInplaceTmpBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_automorphism_inplace_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_automorphism_inplace_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAutomorphismInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_automorphism_inplace_impl<A>(module: &Module<Self>, k: i64, a: &mut A, a_col: usize, _scratch: &mut Scratch<Self>)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
@@ -564,7 +711,7 @@ unsafe impl VecZnxAutomorphismInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxMulXpMinusOneImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxMulXpMinusOneImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_mul_xp_minus_one_impl<R, A>(module: &Module<Self>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
@@ -592,9 +739,20 @@ unsafe impl VecZnxMulXpMinusOneImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxMulXpMinusOneInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_mul_xp_minus_one_inplace_impl<R>(module: &Module<Self>, p: i64, res: &mut R, res_col: usize)
|
||||
where
|
||||
unsafe impl VecZnxMulXpMinusOneInplaceTmpBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_mul_xp_minus_one_inplace_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxMulXpMinusOneInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_mul_xp_minus_one_inplace_impl<R>(
|
||||
module: &Module<Self>,
|
||||
p: i64,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
_scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
@@ -617,15 +775,18 @@ unsafe impl VecZnxMulXpMinusOneInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSplitImpl<Self> for FFT64
|
||||
unsafe impl VecZnxSplitRingTmpBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_split_ring_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_split_ring_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSplitRingImpl<Self> for FFT64Spqlios
|
||||
where
|
||||
Self: TakeVecZnxImpl<Self>
|
||||
+ TakeVecZnxImpl<Self>
|
||||
+ VecZnxSwithcDegreeImpl<Self>
|
||||
+ VecZnxRotateImpl<Self>
|
||||
+ VecZnxRotateInplaceImpl<Self>,
|
||||
Module<Self>: VecZnxSplitRingTmpBytes,
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_split_impl<R, A>(
|
||||
fn vec_znx_split_ring_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
res: &mut [R],
|
||||
res_col: usize,
|
||||
@@ -636,287 +797,72 @@ where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_split_ref(module, res, res_col, a, a_col, scratch)
|
||||
let (tmp, _) = scratch.take_slice(module.vec_znx_split_ring_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_split_ring::<_, _, FFT64Spqlios>(res, res_col, a, a_col, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_split_ref<R, A, B>(
|
||||
module: &Module<B>,
|
||||
res: &mut [R],
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
B: Backend + TakeVecZnxImpl<B> + VecZnxSwithcDegreeImpl<B> + VecZnxRotateImpl<B> + VecZnxRotateInplaceImpl<B>,
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
let (n_in, n_out) = (a.n(), res[0].to_mut().n());
|
||||
|
||||
let (mut buf, _) = scratch.take_vec_znx(n_in.max(n_out), 1, a.size());
|
||||
|
||||
debug_assert!(
|
||||
n_out < n_in,
|
||||
"invalid a: output ring degree should be smaller"
|
||||
);
|
||||
res[1..].iter_mut().for_each(|bi| {
|
||||
debug_assert_eq!(
|
||||
bi.to_mut().n(),
|
||||
n_out,
|
||||
"invalid input a: all VecZnx must have the same degree"
|
||||
)
|
||||
});
|
||||
|
||||
res.iter_mut().enumerate().for_each(|(i, bi)| {
|
||||
if i == 0 {
|
||||
module.vec_znx_switch_degree(bi, res_col, &a, a_col);
|
||||
module.vec_znx_rotate(-1, &mut buf, 0, &a, a_col);
|
||||
} else {
|
||||
module.vec_znx_switch_degree(bi, res_col, &buf, a_col);
|
||||
module.vec_znx_rotate_inplace(-1, &mut buf, a_col);
|
||||
}
|
||||
})
|
||||
unsafe impl VecZnxMergeRingsTmpBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_merge_rings_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_merge_rings_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxMergeImpl<Self> for FFT64
|
||||
unsafe impl VecZnxMergeRingsImpl<Self> for FFT64Spqlios
|
||||
where
|
||||
Self: VecZnxSwithcDegreeImpl<Self> + VecZnxRotateInplaceImpl<Self>,
|
||||
Module<Self>: VecZnxMergeRingsTmpBytes,
|
||||
{
|
||||
fn vec_znx_merge_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &[A], a_col: usize)
|
||||
where
|
||||
fn vec_znx_merge_rings_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &[A],
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_merge_ref(module, res, res_col, a, a_col)
|
||||
let (tmp, _) = scratch.take_slice(module.vec_znx_merge_rings_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_merge_rings::<_, _, FFT64Spqlios>(res, res_col, a, a_col, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_merge_ref<R, A, B>(module: &Module<B>, res: &mut R, res_col: usize, a: &[A], a_col: usize)
|
||||
where
|
||||
B: Backend + VecZnxSwithcDegreeImpl<B> + VecZnxRotateInplaceImpl<B>,
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
let (n_in, n_out) = (res.n(), a[0].to_ref().n());
|
||||
|
||||
debug_assert!(
|
||||
n_out < n_in,
|
||||
"invalid a: output ring degree should be smaller"
|
||||
);
|
||||
a[1..].iter().for_each(|ai| {
|
||||
debug_assert_eq!(
|
||||
ai.to_ref().n(),
|
||||
n_out,
|
||||
"invalid input a: all VecZnx must have the same degree"
|
||||
)
|
||||
});
|
||||
|
||||
a.iter().for_each(|ai| {
|
||||
module.vec_znx_switch_degree(&mut res, res_col, ai, a_col);
|
||||
module.vec_znx_rotate_inplace(-1, &mut res, res_col);
|
||||
});
|
||||
|
||||
module.vec_znx_rotate_inplace(a.len() as i64, &mut res, res_col);
|
||||
}
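Taken together, split and merge are inverse interleavings: splitting a degree-n vector into m parts of degree n/m deals every m-th coefficient out to each part, and merging puts them back. A coefficient-level model of the intended semantics, illustrative only; the real code achieves this with the switch-ring and negacyclic-rotate primitives used above:

fn merge_model(parts: &[Vec<i64>]) -> Vec<i64> {
    let m = parts.len();
    let small_n = parts[0].len();
    let mut res = vec![0i64; m * small_n];
    for (i, part) in parts.iter().enumerate() {
        for (j, &c) in part.iter().enumerate() {
            res[j * m + i] = c; // coefficient j of part i lands at position j*m + i
        }
    }
    res
}

#[test]
fn merge_model_example() {
    // Two degree-2 halves interleave into a degree-4 vector.
    assert_eq!(merge_model(&[vec![1, 2], vec![3, 4]]), vec![1, 3, 2, 4]);
}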
|
||||
|
||||
unsafe impl VecZnxSwithcDegreeImpl<Self> for FFT64
|
||||
unsafe impl VecZnxSwitchRingImpl<Self> for FFT64Spqlios
|
||||
where
|
||||
Self: VecZnxCopyImpl<Self>,
|
||||
{
|
||||
fn vec_znx_switch_degree_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
fn vec_znx_switch_ring_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_switch_degree_ref(module, res, res_col, a, a_col)
|
||||
vec_znx_switch_ring::<_, _, FFT64Spqlios>(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_switch_degree_ref<R, A, B>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
B: Backend + VecZnxCopyImpl<B>,
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
let (n_in, n_out) = (a.n(), res.n());
|
||||
|
||||
if n_in == n_out {
|
||||
module.vec_znx_copy(&mut res, res_col, &a, a_col);
|
||||
return;
|
||||
}
|
||||
|
||||
let (gap_in, gap_out): (usize, usize);
|
||||
if n_in > n_out {
|
||||
(gap_in, gap_out) = (n_in / n_out, 1)
|
||||
} else {
|
||||
(gap_in, gap_out) = (1, n_out / n_in);
|
||||
res.zero();
|
||||
}
|
||||
|
||||
let size: usize = a.size().min(res.size());
|
||||
|
||||
(0..size).for_each(|i| {
|
||||
izip!(
|
||||
a.at(a_col, i).iter().step_by(gap_in),
|
||||
res.at_mut(res_col, i).iter_mut().step_by(gap_out)
|
||||
)
|
||||
.for_each(|(x_in, x_out)| *x_out = *x_in);
|
||||
});
|
||||
}
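A compact model of one limb of the degree switch above, illustrative only: downsampling keeps every (n_in/n_out)-th coefficient, while upsampling spreads the input with zero gaps of width n_out/n_in.

fn switch_degree_model(a: &[i64], n_out: usize) -> Vec<i64> {
    let n_in = a.len();
    let mut res = vec![0i64; n_out];
    if n_in >= n_out {
        let gap_in = n_in / n_out;
        for (r, x) in res.iter_mut().zip(a.iter().step_by(gap_in)) {
            *r = *x;
        }
    } else {
        let gap_out = n_out / n_in;
        for (slot, x) in res.iter_mut().step_by(gap_out).zip(a) {
            *slot = *x;
        }
    }
    res
}

#[test]
fn switch_degree_model_example() {
    assert_eq!(switch_degree_model(&[1, 2, 3, 4], 2), vec![1, 3]); // downsample by 2
    assert_eq!(switch_degree_model(&[1, 2], 4), vec![1, 0, 2, 0]); // upsample by 2
}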
|
||||
|
||||
unsafe impl VecZnxCopyImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxCopyImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_copy_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_copy_ref(res, res_col, a, a_col)
|
||||
vec_znx_copy::<_, _, FFT64Spqlios>(res, res_col, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_copy_ref<R, A>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let mut res_mut: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let a_ref: VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
|
||||
(0..min_size).for_each(|j| {
|
||||
res_mut
|
||||
.at_mut(res_col, j)
|
||||
.copy_from_slice(a_ref.at(a_col, j));
|
||||
});
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
|
||||
unsafe impl VecZnxFillUniformImpl<Self> for FFT64 {
|
||||
fn vec_znx_fill_uniform_impl<R>(
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
) where
|
||||
unsafe impl VecZnxFillUniformImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_fill_uniform_impl<R>(_module: &Module<Self>, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let base2k: u64 = 1 << basek;
|
||||
let mask: u64 = base2k - 1;
|
||||
let base2k_half: i64 = (base2k >> 1) as i64;
|
||||
(0..k.div_ceil(basek)).for_each(|j| {
|
||||
a.at_mut(res_col, j)
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
|
||||
})
|
||||
vec_znx_fill_uniform_ref(basek, res, res_col, source)
|
||||
}
|
||||
}
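The inline sampler replaced here draws every coefficient of the touched limbs uniformly from the centered range [-2^(basek-1), 2^(basek-1)). A one-line model of the mapping from a raw basek-bit draw to a signed coefficient, illustrative only:

fn centered(basek: usize, raw: u64) -> i64 {
    debug_assert!(raw < (1u64 << basek));
    raw as i64 - (1i64 << (basek - 1))
}

#[test]
fn centered_example() {
    assert_eq!(centered(8, 0), -128);
    assert_eq!(centered(8, 255), 127);
}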
|
||||
|
||||
unsafe impl VecZnxFillDistF64Impl<Self> for FFT64 {
|
||||
fn vec_znx_fill_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
a.at_mut(res_col, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(res_col, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAddDistF64Impl<Self> for FFT64 {
|
||||
fn vec_znx_add_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
a.at_mut(res_col, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(res_col, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxFillNormalImpl<Self> for FFT64
|
||||
where
|
||||
Self: VecZnxFillDistF64Impl<Self>,
|
||||
{
|
||||
unsafe impl VecZnxFillNormalImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_fill_normal_impl<R>(
|
||||
module: &Module<Self>,
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
@@ -927,24 +873,13 @@ where
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
module.vec_znx_fill_dist_f64(
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
vec_znx_fill_normal_ref(basek, res, res_col, k, sigma, bound, source);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAddNormalImpl<Self> for FFT64
|
||||
where
|
||||
Self: VecZnxAddDistF64Impl<Self>,
|
||||
{
|
||||
unsafe impl VecZnxAddNormalImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_add_normal_impl<R>(
|
||||
module: &Module<Self>,
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
@@ -955,14 +890,6 @@ where
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
module.vec_znx_add_dist_f64(
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
vec_znx_add_normal_ref(basek, res, res_col, k, sigma, bound, source);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,170 +1,98 @@
|
||||
use rand_distr::{Distribution, Normal};
|
||||
|
||||
use crate::cpu_spqlios::{FFT64, ffi::vec_znx};
|
||||
use crate::cpu_spqlios::{FFT64Spqlios, ffi::vec_znx};
|
||||
use poulpy_hal::{
|
||||
api::{TakeSlice, VecZnxBigAddDistF64, VecZnxBigFillDistF64, VecZnxBigNormalizeTmpBytes},
|
||||
api::{TakeSlice, VecZnxBigNormalizeTmpBytes},
|
||||
layouts::{
|
||||
Backend, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef,
|
||||
ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut,
|
||||
},
|
||||
oep::{
|
||||
TakeSliceImpl, VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl,
|
||||
VecZnxBigAddSmallImpl, VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl,
|
||||
VecZnxBigAutomorphismImpl, VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl,
|
||||
VecZnxBigFromBytesImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl,
|
||||
VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl,
|
||||
VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
|
||||
TakeSliceImpl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl,
|
||||
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
|
||||
VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl,
|
||||
VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl,
|
||||
VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl,
|
||||
VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
|
||||
},
|
||||
reference::{
|
||||
vec_znx::vec_znx_add_normal_ref,
|
||||
znx::{znx_copy_ref, znx_zero_ref},
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
unsafe impl VecZnxBigAllocBytesImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigAllocBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
Self::layout_big_word_count() * n * cols * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAllocImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigAllocImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned<Self> {
|
||||
VecZnxBig::alloc(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFromBytesImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigFromBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<Self> {
|
||||
VecZnxBig::from_bytes(n, cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddDistF64Impl<Self> for FFT64 {
|
||||
fn add_dist_f64_impl<R: VecZnxBigToMut<Self>, D: Distribution<f64>>(
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
unsafe impl VecZnxBigFromSmallImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_big_from_small_impl<R, A>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), a.n());
|
||||
}
|
||||
|
||||
if basek_rem != 0 {
|
||||
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*x += (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*x += dist_f64.round() as i64
|
||||
});
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
|
||||
let min_size: usize = res_size.min(a_size);
|
||||
|
||||
for j in 0..min_size {
|
||||
znx_copy_ref(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
|
||||
for j in min_size..res_size {
|
||||
znx_zero_ref(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Spqlios {
|
||||
fn add_normal_impl<R: VecZnxBigToMut<Self>>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) {
|
||||
module.vec_znx_big_add_dist_f64(
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFillDistF64Impl<Self> for FFT64 {
|
||||
fn fill_dist_f64_impl<R: VecZnxBigToMut<Self>, D: Distribution<f64>>(
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*x = (dist_f64.round() as i64) << basek_rem;
});
} else {
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*x = dist_f64.round() as i64
});
}
}
}

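For readers skimming the diff, the two branches above implement the same idea: rejection-sample within `bound`, round to an integer, and shift it so the sample lands at bit-precision `k` inside limbs of `basek` bits. A minimal standalone sketch of that pattern (plain Rust; it assumes nothing about the crate's API and all names are illustrative):

// Illustrative only: where a rejection-sampled value ends up in a base-2^basek limb decomposition.
fn sample_bounded_at_precision(
    mut sample: impl FnMut() -> f64, // assumed f64 source, e.g. a Gaussian
    bound: f64,
    basek: usize,
    k: usize,
) -> (usize, i64) {
    let limb = k.div_ceil(basek) - 1; // limb that carries precision k
    let basek_rem = (limb + 1) * basek - k; // unused low bits of that limb
    let mut x = sample();
    while x.abs() > bound {
        x = sample(); // reject anything outside the bound, as above
    }
    (limb, (x.round() as i64) << basek_rem) // shift aligns the value with precision k
}

fn main() {
    // Deterministic toy "distribution" for demonstration only.
    let mut vals = [3.7_f64, 1.2].into_iter().cycle();
    let (limb, v) = sample_bounded_at_precision(|| vals.next().unwrap(), 2.0, 12, 20);
    assert_eq!((limb, v), (1, 1 << 4)); // 3.7 is rejected, 1.2 rounds to 1, shifted by 4 bits
}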
unsafe impl VecZnxBigFillNormalImpl<Self> for FFT64 {
|
||||
fn fill_normal_impl<R: VecZnxBigToMut<Self>>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) {
|
||||
module.vec_znx_big_fill_dist_f64(
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
let res: VecZnxBig<&mut [u8], FFT64Spqlios> = res.to_mut();
|
||||
|
||||
let mut res_znx: VecZnx<&mut [u8]> = VecZnx {
|
||||
data: res.data,
|
||||
n: res.n,
|
||||
cols: res.cols,
|
||||
size: res.size,
|
||||
max_size: res.max_size,
|
||||
};
|
||||
|
||||
vec_znx_add_normal_ref(basek, &mut res_znx, res_col, k, sigma, bound, source);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigAddImpl<Self> for FFT64Spqlios {
|
||||
/// Adds `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_add_impl<R, A, B>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
@@ -199,7 +127,7 @@ unsafe impl VecZnxBigAddImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigAddInplaceImpl<Self> for FFT64Spqlios {
|
||||
/// Adds `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_add_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
@@ -230,7 +158,7 @@ unsafe impl VecZnxBigAddInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddSmallImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigAddSmallImpl<Self> for FFT64Spqlios {
|
||||
/// Adds `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_add_small_impl<R, A, B>(
|
||||
module: &Module<Self>,
|
||||
@@ -272,7 +200,7 @@ unsafe impl VecZnxBigAddSmallImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddSmallInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigAddSmallInplaceImpl<Self> for FFT64Spqlios {
|
||||
/// Adds `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_add_small_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
@@ -303,7 +231,7 @@ unsafe impl VecZnxBigAddSmallInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigSubImpl<Self> for FFT64Spqlios {
|
||||
/// Subtracts `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_impl<R, A, B>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
@@ -338,7 +266,7 @@ unsafe impl VecZnxBigSubImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64Spqlios {
|
||||
/// Subtracts `a` from `b` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
@@ -369,7 +297,7 @@ unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64Spqlios {
|
||||
/// Subtracts `b` from `a` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
@@ -400,7 +328,7 @@ unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64Spqlios {
|
||||
/// Subtracts `b` from `a` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_small_a_impl<R, A, B>(
|
||||
module: &Module<Self>,
|
||||
@@ -442,7 +370,7 @@ unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64Spqlios {
|
||||
/// Subtracts `a` from `res` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_a_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
@@ -473,7 +401,7 @@ unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64Spqlios {
|
||||
/// Subtracts `b` from `a` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_small_b_impl<R, A, B>(
|
||||
module: &Module<Self>,
|
||||
@@ -515,7 +443,7 @@ unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64Spqlios {
|
||||
/// Subtracts `res` from `a` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_b_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
@@ -546,7 +474,29 @@ unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNegateInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigNegateImpl<Self> for FFT64Spqlios {
fn vec_znx_big_negate_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<Self>,
A: VecZnxBigToRef<Self>,
{
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
let a: VecZnxBig<&[u8], Self> = a.to_ref();
unsafe {
vec_znx::vec_znx_negate(
module.ptr(),
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
}

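The FFI call above passes a base pointer plus a limb count (size) and a slice length (sl), i.e. a strided view over the limbs of one column. As a plain-Rust picture of that negation (illustrative only; the layout here is an assumption, not the crate's):

fn negate_column(res: &mut [i64], a: &[i64], size: usize, sl: usize, n: usize) {
    // Limb j of the selected column occupies [j * sl .. j * sl + n].
    for j in 0..size {
        for i in 0..n {
            res[j * sl + i] = -a[j * sl + i];
        }
    }
}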
unsafe impl VecZnxBigNegateInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_big_negate_inplace_impl<A>(module: &Module<Self>, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<Self>,
|
||||
@@ -566,13 +516,13 @@ unsafe impl VecZnxBigNegateInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNormalizeTmpBytesImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigNormalizeTmpBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr()) as usize }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNormalizeImpl<Self> for FFT64
|
||||
unsafe impl VecZnxBigNormalizeImpl<Self> for FFT64Spqlios
|
||||
where
|
||||
Self: TakeSliceImpl<Self>,
|
||||
{
|
||||
@@ -613,7 +563,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAutomorphismImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigAutomorphismImpl<Self> for FFT64Spqlios {
|
||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
|
||||
fn vec_znx_big_automorphism_impl<R, A>(module: &Module<Self>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
@@ -642,10 +592,21 @@ unsafe impl VecZnxBigAutomorphismImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAutomorphismInplaceImpl<Self> for FFT64 {
unsafe impl VecZnxBigAutomorphismInplaceTmpBytesImpl<Self> for FFT64Spqlios {
fn vec_znx_big_automorphism_inplace_tmp_bytes_impl(_module: &Module<Self>) -> usize {
0
}
}

unsafe impl VecZnxBigAutomorphismInplaceImpl<Self> for FFT64Spqlios {
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
fn vec_znx_big_automorphism_inplace_impl<A>(module: &Module<Self>, k: i64, a: &mut A, a_col: usize)
where
fn vec_znx_big_automorphism_inplace_impl<A>(
module: &Module<Self>,
k: i64,
a: &mut A,
a_col: usize,
_scratch: &mut Scratch<Self>,
) where
A: VecZnxBigToMut<Self>,
{
let mut a: VecZnxBig<&mut [u8], Self> = a.to_mut();

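The automorphism these impls delegate to the FFI acts on the negacyclic ring Z[X]/(X^N + 1). A self-contained sketch of the map X^i -> X^{i*k} on a coefficient vector (illustrative plain Rust, not the crate's implementation):

fn automorphism(k: i64, a: &[i64]) -> Vec<i64> {
    let n = a.len() as i64;
    let two_n = 2 * n;
    let mut res = vec![0i64; a.len()];
    for (i, &c) in a.iter().enumerate() {
        // X^i maps to X^{i*k}; reduce the exponent mod 2N and flip the sign
        // whenever it lands past N, since X^N = -1 in this ring.
        let e = ((i as i64 * k) % two_n + two_n) % two_n;
        if e < n {
            res[e as usize] += c;
        } else {
            res[(e - n) as usize] -= c;
        }
    }
    res
}

For k odd (coprime with 2N) this is a signed permutation of the coefficients, consistent with the in-place variant above requesting 0 scratch bytes.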
@@ -1,60 +1,73 @@
|
||||
use poulpy_hal::{
|
||||
api::{TakeSlice, VecZnxIDFTTmpBytes},
|
||||
api::{TakeSlice, VecZnxIdftApplyTmpBytes},
|
||||
layouts::{
|
||||
Backend, Data, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut,
|
||||
VecZnxDftToRef, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
|
||||
VecZnxDftToRef, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut,
|
||||
},
|
||||
oep::{
|
||||
DFTImpl, IDFTConsumeImpl, IDFTImpl, IDFTTmpAImpl, VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl,
|
||||
VecZnxDftAllocImpl, VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl,
|
||||
VecZnxDftSubImpl, VecZnxDftZeroImpl, VecZnxIDFTTmpBytesImpl,
|
||||
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl,
|
||||
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl,
|
||||
VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
|
||||
},
|
||||
reference::{
|
||||
fft64::{
|
||||
reim::{ReimCopy, ReimZero, reim_copy_ref, reim_negate_inplace_ref, reim_negate_ref, reim_zero_ref},
|
||||
vec_znx_dft::vec_znx_dft_copy,
|
||||
},
|
||||
znx::znx_zero_ref,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::cpu_spqlios::{
|
||||
FFT64,
|
||||
FFT64Spqlios,
|
||||
ffi::{vec_znx_big, vec_znx_dft},
|
||||
};
|
||||
|
||||
unsafe impl VecZnxDftFromBytesImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxDftFromBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<Self> {
|
||||
VecZnxDft::<Vec<u8>, FFT64>::from_bytes(n, cols, size, bytes)
|
||||
VecZnxDft::<Vec<u8>, Self>::from_bytes(n, cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAllocBytesImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxDftAllocBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
FFT64::layout_prep_word_count() * n * cols * size * size_of::<<FFT64 as Backend>::ScalarPrep>()
|
||||
Self::layout_prep_word_count() * n * cols * size * size_of::<<FFT64Spqlios as Backend>::ScalarPrep>()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAllocImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxDftAllocImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<Self> {
|
||||
VecZnxDftOwned::alloc(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxIDFTTmpBytesImpl<Self> for FFT64 {
|
||||
fn vec_znx_idft_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
unsafe impl VecZnxIdftApplyTmpBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_idft_apply_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(module.ptr()) as usize }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl IDFTImpl<Self> for FFT64 {
|
||||
fn idft_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<Self>)
|
||||
where
|
||||
unsafe impl VecZnxIdftApplyImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_idft_apply_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], Self> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), a.n())
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_idft_tmp_bytes());
|
||||
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_idft_apply_tmp_bytes());
|
||||
|
||||
let min_size: usize = res.size().min(a.size());
|
||||
|
||||
@@ -69,47 +82,43 @@ unsafe impl IDFTImpl<Self> for FFT64 {
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
});
|
||||
(min_size..res.size()).for_each(|j| {
|
||||
res.zero_at(res_col, j);
|
||||
});
|
||||
(min_size..res.size()).for_each(|j| znx_zero_ref(res.at_mut(res_col, j)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl IDFTTmpAImpl<Self> for FFT64 {
|
||||
fn idft_tmp_a_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
unsafe impl VecZnxIdftApplyTmpAImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_idft_apply_tmpa_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxDftToMut<Self>,
|
||||
{
|
||||
let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
let mut a_mut: VecZnxDft<&mut [u8], Self> = a.to_mut();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_mut.size());
|
||||
let min_size: usize = res.size().min(a_mut.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_znx_idft_tmp_a(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
|
||||
res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
|
||||
1_u64,
|
||||
a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1_u64,
|
||||
)
|
||||
});
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
(min_size..res.size()).for_each(|j| znx_zero_ref(res.at_mut(res_col, j)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl IDFTConsumeImpl<Self> for FFT64 {
|
||||
fn idft_consume_impl<D: Data>(module: &Module<Self>, mut a: VecZnxDft<D, FFT64>) -> VecZnxBig<D, FFT64>
|
||||
unsafe impl VecZnxIdftApplyConsumeImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_idft_apply_consume_impl<D: Data>(module: &Module<Self>, mut a: VecZnxDft<D, Self>) -> VecZnxBig<D, Self>
|
||||
where
|
||||
VecZnxDft<D, FFT64>: VecZnxDftToMut<Self>,
|
||||
VecZnxDft<D, Self>: VecZnxDftToMut<Self>,
|
||||
{
|
||||
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
|
||||
let mut a_mut: VecZnxDft<&mut [u8], Self> = a.to_mut();
|
||||
|
||||
unsafe {
|
||||
// Rev col and rows because ZnxDft.sl() >= ZnxBig.sl()
|
||||
@@ -130,89 +139,129 @@ unsafe impl IDFTConsumeImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl DFTImpl<Self> for FFT64 {
|
||||
fn dft_impl<R, A>(module: &Module<Self>, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
unsafe impl VecZnxDftApplyImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_apply_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
step: usize,
|
||||
offset: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnx<&[u8]> = a.to_ref();
|
||||
let steps: usize = a_ref.size().div_ceil(step);
|
||||
let min_steps: usize = res_mut.size().min(steps);
|
||||
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let steps: usize = a.size().div_ceil(step);
|
||||
let min_steps: usize = res.size().min(steps);
|
||||
unsafe {
|
||||
(0..min_steps).for_each(|j| {
|
||||
let limb: usize = offset + j * step;
|
||||
if limb < a_ref.size() {
|
||||
if limb < a.size() {
|
||||
vec_znx_dft::vec_znx_dft(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1_u64,
|
||||
a_ref.at_ptr(a_col, limb),
|
||||
a.at_ptr(a_col, limb),
|
||||
1_u64,
|
||||
a_ref.sl() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
});
|
||||
(min_steps..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
});
|
||||
(min_steps..res.size()).for_each(|j| reim_zero_ref(res.at_mut(res_col, j)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAddImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxDftAddImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_add_impl<R, A, D>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
D: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], Self> = a.to_ref();
|
||||
let b: VecZnxDft<&[u8], Self> = b.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size());
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
let b_size: usize = b.size();
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_add(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
if a_size <= b_size {
|
||||
let sum_size: usize = a_size.min(res_size);
|
||||
let cpy_size: usize = b_size.min(res_size);
|
||||
|
||||
(0..sum_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_add(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
|
||||
for j in sum_size..cpy_size {
|
||||
reim_copy_ref(res.at_mut(res_col, j), b.at(b_col, j));
|
||||
}
|
||||
|
||||
for j in cpy_size..res_size {
|
||||
reim_zero_ref(res.at_mut(res_col, j));
|
||||
}
|
||||
} else {
|
||||
let sum_size: usize = b_size.min(res_size);
|
||||
let cpy_size: usize = a_size.min(res_size);
|
||||
|
||||
(0..sum_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_add(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
|
||||
for j in sum_size..cpy_size {
|
||||
reim_copy_ref(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
|
||||
for j in cpy_size..res_size {
|
||||
reim_zero_ref(res.at_mut(res_col, j));
|
||||
}
|
||||
}
}
(min_size..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
})
}
}

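The branching above (and the analogous one in the sub impl further down) handles operands with different limb counts: combine over the common prefix, copy the tail limbs of the longer operand, and zero whatever is left of the result. A standalone sketch of that pattern, assuming all limbs share the same length (illustrative plain Rust, not the crate's types):

fn add_limbwise(res: &mut [Vec<f64>], a: &[Vec<f64>], b: &[Vec<f64>]) {
    let (short, long) = if a.len() <= b.len() { (a, b) } else { (b, a) };
    let sum_size = short.len().min(res.len());
    let cpy_size = long.len().min(res.len());
    for j in 0..sum_size {
        // Limbs present in both inputs are added element-wise.
        for (r, (x, y)) in res[j].iter_mut().zip(short[j].iter().zip(&long[j])) {
            *r = x + y;
        }
    }
    for j in sum_size..cpy_size {
        // Limbs only present in the longer input are copied through.
        res[j].copy_from_slice(&long[j]);
    }
    for j in cpy_size..res.len() {
        // Anything beyond both inputs is zeroed.
        res[j].iter_mut().for_each(|x| *x = 0.0);
    }
}

The subtraction variant below follows the same shape, except that leftover limbs of b are negated rather than copied.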
unsafe impl VecZnxDftAddInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxDftAddInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_add_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], Self> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
let min_size: usize = res.size().min(a.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_add(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
res.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
@@ -220,58 +269,93 @@ unsafe impl VecZnxDftAddInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftSubImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxDftSubImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_sub_impl<R, A, D>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
D: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size());
|
||||
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], Self> = a.to_ref();
|
||||
let b: VecZnxDft<&[u8], Self> = b.to_ref();
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
let b_size: usize = b.size();
|
||||
|
||||
if a_size <= b_size {
|
||||
let sum_size: usize = a_size.min(res_size);
|
||||
let cpy_size: usize = b_size.min(res_size);
|
||||
|
||||
(0..sum_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
|
||||
for j in sum_size..cpy_size {
|
||||
reim_negate_ref(res.at_mut(res_col, j), b.at(b_col, j));
|
||||
}
|
||||
|
||||
for j in cpy_size..res_size {
|
||||
reim_zero_ref(res.at_mut(res_col, j));
|
||||
}
|
||||
} else {
|
||||
let sum_size: usize = b_size.min(res_size);
|
||||
let cpy_size: usize = a_size.min(res_size);
|
||||
|
||||
(0..sum_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
|
||||
for j in sum_size..cpy_size {
|
||||
reim_copy_ref(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
|
||||
for j in cpy_size..res_size {
|
||||
reim_zero_ref(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftSubABInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxDftSubABInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], Self> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
let min_size: usize = res.size().min(a.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
res.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
@@ -279,34 +363,38 @@ unsafe impl VecZnxDftSubABInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftSubBAInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxDftSubBAInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], Self> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
let min_size: usize = res.size().min(a.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
res.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
|
||||
for j in min_size..res.size() {
|
||||
reim_negate_inplace_ref(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftCopyImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxDftCopyImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_copy_impl<R, A>(
|
||||
_module: &Module<Self>,
|
||||
step: usize,
|
||||
@@ -319,27 +407,25 @@ unsafe impl VecZnxDftCopyImpl<Self> for FFT64 {
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let steps: usize = a_ref.size().div_ceil(step);
|
||||
let min_steps: usize = res_mut.size().min(steps);
|
||||
|
||||
(0..min_steps).for_each(|j| {
|
||||
let limb: usize = offset + j * step;
|
||||
if limb < a_ref.size() {
|
||||
res_mut
|
||||
.at_mut(res_col, j)
|
||||
.copy_from_slice(a_ref.at(a_col, limb));
|
||||
}
|
||||
});
|
||||
(min_steps..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
vec_znx_dft_copy(step, offset, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftZeroImpl<Self> for FFT64 {
|
||||
impl ReimCopy for FFT64Spqlios {
|
||||
#[inline(always)]
|
||||
fn reim_copy(res: &mut [f64], a: &[f64]) {
|
||||
reim_copy_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimZero for FFT64Spqlios {
|
||||
#[inline(always)]
|
||||
fn reim_zero(res: &mut [f64]) {
|
||||
reim_zero_ref(res);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftZeroImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_zero_impl<R>(_module: &Module<Self>, res: &mut R)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
|
||||
@@ -6,22 +6,22 @@ use poulpy_hal::{
|
||||
},
|
||||
oep::{
|
||||
VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl,
|
||||
VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
|
||||
VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPrepareImpl, VmpPrepareTmpBytesImpl,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::cpu_spqlios::{
|
||||
FFT64,
|
||||
FFT64Spqlios,
|
||||
ffi::{vec_znx_dft::vec_znx_dft_t, vmp},
|
||||
};
|
||||
|
||||
unsafe impl VmpPMatAllocBytesImpl<FFT64> for FFT64 {
|
||||
unsafe impl VmpPMatAllocBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
FFT64::layout_prep_word_count() * n * rows * cols_in * cols_out * size * size_of::<f64>()
|
||||
Self::layout_prep_word_count() * n * rows * cols_in * cols_out * size * size_of::<f64>()
|
||||
}
|
||||
}
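As a worked example of the prepared-matrix byte count above (parameters are illustrative, not from this diff): with layout_prep_word_count() == 1, n = 4096, rows = 4, cols_in = 2, cols_out = 2 and size = 4, the buffer is 1 * 4096 * 4 * 2 * 2 * 4 * 8 = 2097152 bytes, i.e. 2 MiB.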
|
||||
|
||||
unsafe impl VmpPMatFromBytesImpl<FFT64> for FFT64 {
|
||||
unsafe impl VmpPMatFromBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vmp_pmat_from_bytes_impl(
|
||||
n: usize,
|
||||
rows: usize,
|
||||
@@ -29,19 +29,19 @@ unsafe impl VmpPMatFromBytesImpl<FFT64> for FFT64 {
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: Vec<u8>,
|
||||
) -> VmpPMatOwned<FFT64> {
|
||||
) -> VmpPMatOwned<Self> {
|
||||
VmpPMatOwned::from_bytes(n, rows, cols_in, cols_out, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPMatAllocImpl<FFT64> for FFT64 {
|
||||
fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<FFT64> {
|
||||
unsafe impl VmpPMatAllocImpl<Self> for FFT64Spqlios {
|
||||
fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<Self> {
|
||||
VmpPMatOwned::alloc(n, rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPrepareTmpBytesImpl<FFT64> for FFT64 {
|
||||
fn vmp_prepare_tmp_bytes_impl(module: &Module<FFT64>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
unsafe impl VmpPrepareTmpBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vmp_prepare_tmp_bytes_impl(module: &Module<Self>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_prepare_tmp_bytes(
|
||||
module.ptr(),
|
||||
@@ -52,13 +52,13 @@ unsafe impl VmpPrepareTmpBytesImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPMatPrepareImpl<FFT64> for FFT64 {
|
||||
fn vmp_prepare_impl<R, A>(module: &Module<FFT64>, res: &mut R, a: &A, scratch: &mut Scratch<FFT64>)
|
||||
unsafe impl VmpPrepareImpl<Self> for FFT64Spqlios {
|
||||
fn vmp_prepare_impl<R, A>(module: &Module<Self>, res: &mut R, a: &A, scratch: &mut Scratch<Self>)
|
||||
where
|
||||
R: VmpPMatToMut<FFT64>,
|
||||
R: VmpPMatToMut<Self>,
|
||||
A: MatZnxToRef,
|
||||
{
|
||||
let mut res: VmpPMat<&mut [u8], FFT64> = res.to_mut();
|
||||
let mut res: VmpPMat<&mut [u8], Self> = res.to_mut();
|
||||
let a: MatZnx<&[u8]> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
@@ -109,9 +109,9 @@ unsafe impl VmpPMatPrepareImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyDftToDftTmpBytesImpl<FFT64> for FFT64 {
|
||||
unsafe impl VmpApplyDftToDftTmpBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vmp_apply_dft_to_dft_tmp_bytes_impl(
|
||||
module: &Module<FFT64>,
|
||||
module: &Module<Self>,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
@@ -131,12 +131,12 @@ unsafe impl VmpApplyDftToDftTmpBytesImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyDftToDftImpl<FFT64> for FFT64 {
|
||||
fn vmp_apply_dft_to_dft_impl<R, A, C>(module: &Module<FFT64>, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<FFT64>)
|
||||
unsafe impl VmpApplyDftToDftImpl<Self> for FFT64Spqlios {
|
||||
fn vmp_apply_dft_to_dft_impl<R, A, C>(module: &Module<Self>, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<Self>)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
C: VmpPMatToRef<FFT64>,
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
C: VmpPMatToRef<Self>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], _> = a.to_ref();
|
||||
@@ -186,9 +186,9 @@ unsafe impl VmpApplyDftToDftImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyDftToDftAddTmpBytesImpl<FFT64> for FFT64 {
|
||||
unsafe impl VmpApplyDftToDftAddTmpBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vmp_apply_dft_to_dft_add_tmp_bytes_impl(
|
||||
module: &Module<FFT64>,
|
||||
module: &Module<Self>,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
@@ -208,18 +208,18 @@ unsafe impl VmpApplyDftToDftAddTmpBytesImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyDftToDftAddImpl<FFT64> for FFT64 {
|
||||
unsafe impl VmpApplyDftToDftAddImpl<Self> for FFT64Spqlios {
|
||||
fn vmp_apply_dft_to_dft_add_impl<R, A, C>(
|
||||
module: &Module<FFT64>,
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
a: &A,
|
||||
b: &C,
|
||||
scale: usize,
|
||||
scratch: &mut Scratch<FFT64>,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
C: VmpPMatToRef<FFT64>,
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
C: VmpPMatToRef<Self>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], _> = a.to_ref();
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
use poulpy_hal::{
|
||||
api::TakeSlice,
|
||||
layouts::{Scratch, Zn, ZnToMut, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut},
|
||||
oep::{
|
||||
TakeSliceImpl, ZnAddDistF64Impl, ZnAddNormalImpl, ZnFillDistF64Impl, ZnFillNormalImpl, ZnFillUniformImpl,
|
||||
ZnNormalizeInplaceImpl,
|
||||
},
|
||||
oep::{TakeSliceImpl, ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl},
|
||||
reference::zn::{zn_add_normal, zn_fill_normal, zn_fill_uniform},
|
||||
source::Source,
|
||||
};
|
||||
use rand_distr::Normal;
|
||||
|
||||
use crate::cpu_spqlios::{FFT64, ffi::zn64};
|
||||
use crate::cpu_spqlios::{FFT64Spqlios, ffi::zn64};
|
||||
|
||||
unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64
|
||||
unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64Spqlios
|
||||
where
|
||||
Self: TakeSliceImpl<Self>,
|
||||
{
|
||||
@@ -39,113 +36,17 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ZnFillUniformImpl<Self> for FFT64 {
fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
unsafe impl ZnFillUniformImpl<Self> for FFT64Spqlios {
fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
let mut a: Zn<&mut [u8]> = res.to_mut();
let base2k: u64 = 1 << basek;
let mask: u64 = base2k - 1;
let base2k_half: i64 = (base2k >> 1) as i64;
(0..k.div_ceil(basek)).for_each(|j| {
a.at_mut(res_col, j)[..n]
.iter_mut()
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
})
zn_fill_uniform(n, basek, res, res_col, source);
}
}

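Both the removed body above and the zn_fill_uniform reference routine that replaces it draw each coefficient uniformly from the centered range [-2^(basek-1), 2^(basek-1)). A minimal sketch of that sampling step (plain Rust with an assumed u64 source; not the crate's API):

fn fill_uniform_centered(coeffs: &mut [i64], basek: usize, mut next_u64: impl FnMut() -> u64) {
    let base2k: u64 = 1 << basek;
    let mask: u64 = base2k - 1;
    let half: i64 = (base2k >> 1) as i64;
    for x in coeffs.iter_mut() {
        // Masking a uniform u64 yields a uniform value in [0, 2^basek);
        // subtracting half recenters it around zero.
        *x = ((next_u64() & mask) as i64) - half;
    }
}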
unsafe impl ZnFillDistF64Impl<Self> for FFT64 {
|
||||
fn zn_fill_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
let mut a: Zn<&mut [u8]> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ZnAddDistF64Impl<Self> for FFT64 {
|
||||
fn zn_add_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
let mut a: Zn<&mut [u8]> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ZnFillNormalImpl<Self> for FFT64
|
||||
where
|
||||
Self: ZnFillDistF64Impl<Self>,
|
||||
{
|
||||
unsafe impl ZnFillNormalImpl<Self> for FFT64Spqlios {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn zn_fill_normal_impl<R>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
@@ -158,23 +59,12 @@ where
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
Self::zn_fill_dist_f64_impl(
|
||||
n,
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
zn_fill_normal(n, basek, res, res_col, k, source, sigma, bound);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ZnAddNormalImpl<Self> for FFT64
|
||||
where
|
||||
Self: ZnAddDistF64Impl<Self>,
|
||||
{
|
||||
unsafe impl ZnAddNormalImpl<Self> for FFT64Spqlios {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn zn_add_normal_impl<R>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
@@ -187,15 +77,6 @@ where
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
Self::zn_add_dist_f64_impl(
|
||||
n,
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
zn_add_normal(n, basek, res, res_col, k, source, sigma, bound);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,8 @@ mod fft64;
|
||||
mod ntt120;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
mod tests;
|
||||
|
||||
pub use ffi::*;
|
||||
pub use fft64::*;
|
||||
pub use ntt120::*;
|
||||
|
||||
Submodule poulpy-backend/src/cpu_spqlios/spqlios-arithmetic updated: 708e5d7e86...b6938df774
@@ -1,2 +0,0 @@
|
||||
mod vec_znx_fft64;
|
||||
mod vmp_pmat_fft64;
|
||||
@@ -1,19 +0,0 @@
|
||||
use poulpy_hal::{
|
||||
api::ModuleNew,
|
||||
layouts::Module,
|
||||
tests::vec_znx::{test_vec_znx_add_normal, test_vec_znx_fill_uniform},
|
||||
};
|
||||
|
||||
use crate::cpu_spqlios::FFT64;
|
||||
|
||||
#[test]
|
||||
fn test_vec_znx_fill_uniform_fft64() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << 12);
|
||||
test_vec_znx_fill_uniform(&module);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vec_znx_add_normal_fft64() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << 12);
|
||||
test_vec_znx_add_normal(&module);
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
use poulpy_hal::tests::vmp_pmat::test_vmp_apply;
|
||||
|
||||
use crate::cpu_spqlios::FFT64;
|
||||
|
||||
#[test]
|
||||
fn vmp_apply() {
|
||||
test_vmp_apply::<FFT64>();
|
||||
}
|
||||
117
poulpy-backend/src/cpu_spqlios/tests.rs
Normal file
@@ -0,0 +1,117 @@
|
||||
use poulpy_hal::{backend_test_suite, cross_backend_test_suite};
|
||||
|
||||
cross_backend_test_suite! {
|
||||
mod vec_znx,
|
||||
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
|
||||
backend_test = crate::cpu_spqlios::FFT64Spqlios,
|
||||
size = 1 << 5,
|
||||
basek = 12,
|
||||
tests = {
|
||||
test_vec_znx_add => poulpy_hal::test_suite::vec_znx::test_vec_znx_add,
|
||||
test_vec_znx_add_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_inplace,
|
||||
test_vec_znx_add_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar,
|
||||
test_vec_znx_add_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar_inplace,
|
||||
test_vec_znx_sub => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub,
|
||||
test_vec_znx_sub_ab_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ab_inplace,
|
||||
test_vec_znx_sub_ba_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ba_inplace,
|
||||
test_vec_znx_sub_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar,
|
||||
test_vec_znx_sub_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar_inplace,
|
||||
test_vec_znx_rsh => poulpy_hal::test_suite::vec_znx::test_vec_znx_rsh,
|
||||
test_vec_znx_rsh_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_rsh_inplace,
|
||||
test_vec_znx_lsh => poulpy_hal::test_suite::vec_znx::test_vec_znx_lsh,
|
||||
test_vec_znx_lsh_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_lsh_inplace,
|
||||
test_vec_znx_negate => poulpy_hal::test_suite::vec_znx::test_vec_znx_negate,
|
||||
test_vec_znx_negate_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_negate_inplace,
|
||||
test_vec_znx_rotate => poulpy_hal::test_suite::vec_znx::test_vec_znx_rotate,
|
||||
test_vec_znx_rotate_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_rotate_inplace,
|
||||
test_vec_znx_automorphism => poulpy_hal::test_suite::vec_znx::test_vec_znx_automorphism,
|
||||
test_vec_znx_automorphism_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_automorphism_inplace,
|
||||
test_vec_znx_mul_xp_minus_one => poulpy_hal::test_suite::vec_znx::test_vec_znx_mul_xp_minus_one,
|
||||
test_vec_znx_mul_xp_minus_one_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_mul_xp_minus_one_inplace,
|
||||
test_vec_znx_normalize => poulpy_hal::test_suite::vec_znx::test_vec_znx_normalize,
|
||||
test_vec_znx_normalize_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_normalize_inplace,
|
||||
test_vec_znx_switch_ring => poulpy_hal::test_suite::vec_znx::test_vec_znx_switch_ring,
|
||||
test_vec_znx_split_ring => poulpy_hal::test_suite::vec_znx::test_vec_znx_split_ring,
|
||||
test_vec_znx_copy => poulpy_hal::test_suite::vec_znx::test_vec_znx_copy,
|
||||
}
|
||||
}
|
||||
|
||||
cross_backend_test_suite! {
mod svp,
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
backend_test = crate::cpu_spqlios::FFT64Spqlios,
size = 1 << 5,
basek = 12,
tests = {
test_svp_apply_dft_to_dft => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft,
test_svp_apply_dft_to_dft_inplace => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft_inplace,
}
}

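Conceptually, every entry in these cross_backend_test_suite! blocks runs one operation on the reference backend (FFT64Ref) and on the backend under test (FFT64Spqlios) and checks that the results agree. A hedged sketch of that comparison, not the macro's actual expansion:

fn assert_backends_agree<T: PartialEq + std::fmt::Debug>(
    reference: impl Fn(&[i64]) -> Vec<T>,  // reference implementation of one op
    under_test: impl Fn(&[i64]) -> Vec<T>, // optimized backend's version
    input: &[i64],
) {
    let want = reference(input);
    let got = under_test(input);
    assert_eq!(want, got, "backend under test diverged from the reference");
}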
cross_backend_test_suite! {
|
||||
mod vec_znx_big,
|
||||
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
|
||||
backend_test = crate::cpu_spqlios::FFT64Spqlios,
|
||||
size = 1 << 5,
|
||||
basek = 12,
|
||||
tests = {
|
||||
test_vec_znx_big_add => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add,
|
||||
test_vec_znx_big_add_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_inplace,
|
||||
test_vec_znx_big_add_small => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small,
|
||||
test_vec_znx_big_add_small_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small_inplace,
|
||||
test_vec_znx_big_sub => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub,
|
||||
test_vec_znx_big_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ab_inplace,
|
||||
test_vec_znx_big_automorphism => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism,
|
||||
test_vec_znx_big_automorphism_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism_inplace,
|
||||
test_vec_znx_big_negate => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate,
|
||||
test_vec_znx_big_negate_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate_inplace,
|
||||
test_vec_znx_big_normalize => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_normalize,
|
||||
test_vec_znx_big_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ba_inplace,
|
||||
test_vec_znx_big_sub_small_a => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a,
|
||||
test_vec_znx_big_sub_small_a_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a_inplace,
|
||||
test_vec_znx_big_sub_small_b => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_b,
|
||||
test_vec_znx_big_sub_small_b_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_b_inplace,
|
||||
}
|
||||
}
|
||||
|
||||
cross_backend_test_suite! {
|
||||
mod vec_znx_dft,
|
||||
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
|
||||
backend_test = crate::cpu_spqlios::FFT64Spqlios,
|
||||
size = 1 << 5,
|
||||
basek = 12,
|
||||
tests = {
|
||||
test_vec_znx_dft_add => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add,
|
||||
test_vec_znx_dft_add_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add_inplace,
|
||||
test_vec_znx_dft_sub => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub,
|
||||
test_vec_znx_dft_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ab_inplace,
|
||||
test_vec_znx_dft_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ba_inplace,
|
||||
test_vec_znx_idft_apply => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply,
|
||||
test_vec_znx_idft_apply_consume => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_consume,
|
||||
test_vec_znx_idft_apply_tmpa => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_tmpa,
|
||||
}
|
||||
}
|
||||
|
||||
cross_backend_test_suite! {
|
||||
mod vmp,
|
||||
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
|
||||
backend_test = crate::cpu_spqlios::FFT64Spqlios,
|
||||
size = 1 << 5,
|
||||
basek = 12,
|
||||
tests = {
|
||||
test_vmp_apply_dft_to_dft => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft,
|
||||
test_vmp_apply_dft_to_dft_add => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft_add,
|
||||
}
|
||||
}
|
||||
|
||||
backend_test_suite! {
|
||||
mod sampling,
|
||||
backend = crate::cpu_spqlios::FFT64Spqlios,
|
||||
size = 1 << 12,
|
||||
tests = {
|
||||
test_vec_znx_fill_uniform => poulpy_hal::test_suite::vec_znx::test_vec_znx_fill_uniform,
|
||||
test_vec_znx_fill_normal => poulpy_hal::test_suite::vec_znx::test_vec_znx_fill_normal,
|
||||
test_vec_znx_add_normal => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_normal,
|
||||
test_vec_znx_big_sub_small_b_inplace => poulpy_hal::reference::fft64::vec_znx_big::test_vec_znx_big_add_normal,
|
||||
}
|
||||
}