Mirror of https://github.com/arnaucube/poulpy.git (synced 2026-02-10 05:06:44 +01:00)

updated repo for publishing (#74)

committed by GitHub
parent 0be569eca0
commit 62eb87cc07

7  poulpy-backend/src/implementation/cpu_spqlios/ffi/cnv.rs  Normal file

@@ -0,0 +1,7 @@
pub type CNV_PVEC_L = cnv_pvec_l_t;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct cnv_pvec_r_t {
    _unused: [u8; 0],
}
pub type CNV_PVEC_R = cnv_pvec_r_t;

8  poulpy-backend/src/implementation/cpu_spqlios/ffi/mod.rs  Normal file

@@ -0,0 +1,8 @@
pub mod module;
pub mod reim;
pub mod svp;
pub mod vec_znx;
pub mod vec_znx_big;
pub mod vec_znx_dft;
pub mod vmp;
pub mod znx;

19  poulpy-backend/src/implementation/cpu_spqlios/ffi/module.rs  Normal file

@@ -0,0 +1,19 @@
pub struct module_info_t {
    _unused: [u8; 0],
}

pub type module_type_t = ::std::os::raw::c_uint;
pub use self::module_type_t as MODULE_TYPE;

#[allow(clippy::upper_case_acronyms)]
pub type MODULE = module_info_t;

unsafe extern "C" {
    pub unsafe fn new_module_info(N: u64, mode: MODULE_TYPE) -> *mut MODULE;
}
unsafe extern "C" {
    pub unsafe fn delete_module_info(module_info: *mut MODULE);
}
unsafe extern "C" {
    pub unsafe fn module_get_n(module: *const MODULE) -> u64;
}
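A minimal lifetime sketch for the three entry points above, assuming only these declarations are in scope; the mode values 0 and 1 are the ones this commit passes for the FFT64 and NTT120 backends respectively (see module_fft64.rs and module_ntt120.rs below).

fn module_lifetime_sketch() {
    // Hypothetical direct use of the raw API; the crate itself wraps these calls
    // behind Module::from_raw_parts and Backend::destroy later in this commit.
    unsafe {
        let module: *mut MODULE = new_module_info(1 << 12, 0); // 0 is the FFT64 mode here
        assert_eq!(module_get_n(module), 1 << 12);
        delete_module_info(module);
    }
}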

172  poulpy-backend/src/implementation/cpu_spqlios/ffi/reim.rs  Normal file

@@ -0,0 +1,172 @@
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_fft_precomp {
    _unused: [u8; 0],
}
pub type REIM_FFT_PRECOMP = reim_fft_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_ifft_precomp {
    _unused: [u8; 0],
}
pub type REIM_IFFT_PRECOMP = reim_ifft_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_mul_precomp {
    _unused: [u8; 0],
}
pub type REIM_FFTVEC_MUL_PRECOMP = reim_mul_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_addmul_precomp {
    _unused: [u8; 0],
}
pub type REIM_FFTVEC_ADDMUL_PRECOMP = reim_addmul_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_from_znx32_precomp {
    _unused: [u8; 0],
}
pub type REIM_FROM_ZNX32_PRECOMP = reim_from_znx32_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_from_znx64_precomp {
    _unused: [u8; 0],
}
pub type REIM_FROM_ZNX64_PRECOMP = reim_from_znx64_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_from_tnx32_precomp {
    _unused: [u8; 0],
}
pub type REIM_FROM_TNX32_PRECOMP = reim_from_tnx32_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_to_tnx32_precomp {
    _unused: [u8; 0],
}
pub type REIM_TO_TNX32_PRECOMP = reim_to_tnx32_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_to_tnx_precomp {
    _unused: [u8; 0],
}
pub type REIM_TO_TNX_PRECOMP = reim_to_tnx_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_to_znx64_precomp {
    _unused: [u8; 0],
}
pub type REIM_TO_ZNX64_PRECOMP = reim_to_znx64_precomp;
unsafe extern "C" {
    pub unsafe fn new_reim_fft_precomp(m: u32, num_buffers: u32) -> *mut REIM_FFT_PRECOMP;
}
unsafe extern "C" {
    pub unsafe fn reim_fft_precomp_get_buffer(tables: *const REIM_FFT_PRECOMP, buffer_index: u32) -> *mut f64;
}
unsafe extern "C" {
    pub unsafe fn new_reim_fft_buffer(m: u32) -> *mut f64;
}
unsafe extern "C" {
    pub unsafe fn delete_reim_fft_buffer(buffer: *mut f64);
}
unsafe extern "C" {
    pub unsafe fn reim_fft(tables: *const REIM_FFT_PRECOMP, data: *mut f64);
}
unsafe extern "C" {
    pub unsafe fn new_reim_ifft_precomp(m: u32, num_buffers: u32) -> *mut REIM_IFFT_PRECOMP;
}
unsafe extern "C" {
    pub unsafe fn reim_ifft_precomp_get_buffer(tables: *const REIM_IFFT_PRECOMP, buffer_index: u32) -> *mut f64;
}
unsafe extern "C" {
    pub unsafe fn reim_ifft(tables: *const REIM_IFFT_PRECOMP, data: *mut f64);
}
unsafe extern "C" {
    pub unsafe fn new_reim_fftvec_mul_precomp(m: u32) -> *mut REIM_FFTVEC_MUL_PRECOMP;
}
unsafe extern "C" {
    pub unsafe fn reim_fftvec_mul(tables: *const REIM_FFTVEC_MUL_PRECOMP, r: *mut f64, a: *const f64, b: *const f64);
}
unsafe extern "C" {
    pub unsafe fn new_reim_fftvec_addmul_precomp(m: u32) -> *mut REIM_FFTVEC_ADDMUL_PRECOMP;
}
unsafe extern "C" {
    pub unsafe fn reim_fftvec_addmul(tables: *const REIM_FFTVEC_ADDMUL_PRECOMP, r: *mut f64, a: *const f64, b: *const f64);
}
unsafe extern "C" {
    pub unsafe fn new_reim_from_znx32_precomp(m: u32, log2bound: u32) -> *mut REIM_FROM_ZNX32_PRECOMP;
}
unsafe extern "C" {
    pub unsafe fn reim_from_znx32(tables: *const REIM_FROM_ZNX32_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i32);
}
unsafe extern "C" {
    pub unsafe fn reim_from_znx64(tables: *const REIM_FROM_ZNX64_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i64);
}
unsafe extern "C" {
    pub unsafe fn new_reim_from_znx64_precomp(m: u32, maxbnd: u32) -> *mut REIM_FROM_ZNX64_PRECOMP;
}
unsafe extern "C" {
    pub unsafe fn reim_from_znx64_simple(m: u32, log2bound: u32, r: *mut ::std::os::raw::c_void, a: *const i64);
}
unsafe extern "C" {
    pub unsafe fn new_reim_from_tnx32_precomp(m: u32) -> *mut REIM_FROM_TNX32_PRECOMP;
}
unsafe extern "C" {
    pub unsafe fn reim_from_tnx32(tables: *const REIM_FROM_TNX32_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i32);
}
unsafe extern "C" {
    pub unsafe fn new_reim_to_tnx32_precomp(m: u32, divisor: f64, log2overhead: u32) -> *mut REIM_TO_TNX32_PRECOMP;
}
unsafe extern "C" {
    pub unsafe fn reim_to_tnx32(tables: *const REIM_TO_TNX32_PRECOMP, r: *mut i32, a: *const ::std::os::raw::c_void);
}
unsafe extern "C" {
    pub unsafe fn new_reim_to_tnx_precomp(m: u32, divisor: f64, log2overhead: u32) -> *mut REIM_TO_TNX_PRECOMP;
}
unsafe extern "C" {
    pub unsafe fn reim_to_tnx(tables: *const REIM_TO_TNX_PRECOMP, r: *mut f64, a: *const f64);
}
unsafe extern "C" {
    pub unsafe fn reim_to_tnx_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut f64, a: *const f64);
}
unsafe extern "C" {
    pub unsafe fn new_reim_to_znx64_precomp(m: u32, divisor: f64, log2bound: u32) -> *mut REIM_TO_ZNX64_PRECOMP;
}
unsafe extern "C" {
    pub unsafe fn reim_to_znx64(precomp: *const REIM_TO_ZNX64_PRECOMP, r: *mut i64, a: *const ::std::os::raw::c_void);
}
unsafe extern "C" {
    pub unsafe fn reim_to_znx64_simple(m: u32, divisor: f64, log2bound: u32, r: *mut i64, a: *const ::std::os::raw::c_void);
}
unsafe extern "C" {
    pub unsafe fn reim_fft_simple(m: u32, data: *mut ::std::os::raw::c_void);
}
unsafe extern "C" {
    pub unsafe fn reim_ifft_simple(m: u32, data: *mut ::std::os::raw::c_void);
}
unsafe extern "C" {
    pub unsafe fn reim_fftvec_mul_simple(
        m: u32,
        r: *mut ::std::os::raw::c_void,
        a: *const ::std::os::raw::c_void,
        b: *const ::std::os::raw::c_void,
    );
}
unsafe extern "C" {
    pub unsafe fn reim_fftvec_addmul_simple(
        m: u32,
        r: *mut ::std::os::raw::c_void,
        a: *const ::std::os::raw::c_void,
        b: *const ::std::os::raw::c_void,
    );
}
unsafe extern "C" {
    pub unsafe fn reim_from_znx32_simple(m: u32, log2bound: u32, r: *mut ::std::os::raw::c_void, x: *const i32);
}
unsafe extern "C" {
    pub unsafe fn reim_from_tnx32_simple(m: u32, r: *mut ::std::os::raw::c_void, x: *const i32);
}
unsafe extern "C" {
    pub unsafe fn reim_to_tnx32_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut i32, x: *const ::std::os::raw::c_void);
}
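A small round-trip sketch for the "simple" FFT entry points above, assuming the reim layout packs m complex values as 2*m contiguous f64 (a split real/imaginary representation) and that reim_ifft_simple inverts reim_fft_simple up to a library-defined scaling; neither assumption is stated by these declarations, so check the spqlios-arithmetic documentation before relying on them.

fn reim_fft_roundtrip_sketch() {
    // Hypothetical usage, written as if next to these bindings; not part of this commit.
    let m: u32 = 1 << 10;
    let mut data: Vec<f64> = vec![0.0; 2 * m as usize]; // assumed layout: re[0..m] then im[0..m]
    data[0] = 1.0;
    unsafe {
        reim_fft_simple(m, data.as_mut_ptr() as *mut ::std::os::raw::c_void);
        reim_ifft_simple(m, data.as_mut_ptr() as *mut ::std::os::raw::c_void);
    }
    // data should now hold the original coefficients again, possibly multiplied by a
    // normalization factor that these declarations do not specify.
}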

47  poulpy-backend/src/implementation/cpu_spqlios/ffi/svp.rs  Normal file

@@ -0,0 +1,47 @@
use crate::implementation::cpu_spqlios::ffi::{module::MODULE, vec_znx_dft::VEC_ZNX_DFT};

#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct svp_ppol_t {
    _unused: [u8; 0],
}
pub type SVP_PPOL = svp_ppol_t;

unsafe extern "C" {
    pub unsafe fn bytes_of_svp_ppol(module: *const MODULE) -> u64;
}
unsafe extern "C" {
    pub unsafe fn new_svp_ppol(module: *const MODULE) -> *mut SVP_PPOL;
}
unsafe extern "C" {
    pub unsafe fn delete_svp_ppol(res: *mut SVP_PPOL);
}

unsafe extern "C" {
    pub unsafe fn svp_prepare(module: *const MODULE, ppol: *mut SVP_PPOL, pol: *const i64);
}

unsafe extern "C" {
    pub unsafe fn svp_apply_dft(
        module: *const MODULE,
        res: *const VEC_ZNX_DFT,
        res_size: u64,
        ppol: *const SVP_PPOL,
        a: *const i64,
        a_size: u64,
        a_sl: u64,
    );
}

unsafe extern "C" {
    pub unsafe fn svp_apply_dft_to_dft(
        module: *const MODULE,
        res: *const VEC_ZNX_DFT,
        res_size: u64,
        res_cols: u64,
        ppol: *const SVP_PPOL,
        a: *const VEC_ZNX_DFT,
        a_size: u64,
        a_cols: u64,
    );
}

115  poulpy-backend/src/implementation/cpu_spqlios/ffi/vec_znx.rs  Normal file

@@ -0,0 +1,115 @@
use crate::implementation::cpu_spqlios::ffi::module::MODULE;

unsafe extern "C" {
    pub unsafe fn vec_znx_add(
        module: *const MODULE,
        res: *mut i64,
        res_size: u64,
        res_sl: u64,
        a: *const i64,
        a_size: u64,
        a_sl: u64,
        b: *const i64,
        b_size: u64,
        b_sl: u64,
    );
}

unsafe extern "C" {
    pub unsafe fn vec_znx_automorphism(
        module: *const MODULE,
        p: i64,
        res: *mut i64,
        res_size: u64,
        res_sl: u64,
        a: *const i64,
        a_size: u64,
        a_sl: u64,
    );
}

unsafe extern "C" {
    pub unsafe fn vec_znx_mul_xp_minus_one(
        module: *const MODULE,
        p: i64,
        res: *mut i64,
        res_size: u64,
        res_sl: u64,
        a: *const i64,
        a_size: u64,
        a_sl: u64,
    );
}

unsafe extern "C" {
    pub unsafe fn vec_znx_negate(
        module: *const MODULE,
        res: *mut i64,
        res_size: u64,
        res_sl: u64,
        a: *const i64,
        a_size: u64,
        a_sl: u64,
    );
}

unsafe extern "C" {
    pub unsafe fn vec_znx_rotate(
        module: *const MODULE,
        p: i64,
        res: *mut i64,
        res_size: u64,
        res_sl: u64,
        a: *const i64,
        a_size: u64,
        a_sl: u64,
    );
}

unsafe extern "C" {
    pub unsafe fn vec_znx_sub(
        module: *const MODULE,
        res: *mut i64,
        res_size: u64,
        res_sl: u64,
        a: *const i64,
        a_size: u64,
        a_sl: u64,
        b: *const i64,
        b_size: u64,
        b_sl: u64,
    );
}

unsafe extern "C" {
    pub unsafe fn vec_znx_zero(module: *const MODULE, res: *mut i64, res_size: u64, res_sl: u64);
}
unsafe extern "C" {
    pub unsafe fn vec_znx_copy(
        module: *const MODULE,
        res: *mut i64,
        res_size: u64,
        res_sl: u64,
        a: *const i64,
        a_size: u64,
        a_sl: u64,
    );
}

unsafe extern "C" {
    pub unsafe fn vec_znx_normalize_base2k(
        module: *const MODULE,
        n: u64,
        base2k: u64,
        res: *mut i64,
        res_size: u64,
        res_sl: u64,
        a: *const i64,
        a_size: u64,
        a_sl: u64,
        tmp_space: *mut u8,
    );
}
unsafe extern "C" {
    pub unsafe fn vec_znx_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64;
}

163  poulpy-backend/src/implementation/cpu_spqlios/ffi/vec_znx_big.rs  Normal file

@@ -0,0 +1,163 @@
use crate::implementation::cpu_spqlios::ffi::module::MODULE;

#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct vec_znx_big_t {
    _unused: [u8; 0],
}
pub type VEC_ZNX_BIG = vec_znx_big_t;

unsafe extern "C" {
    pub unsafe fn bytes_of_vec_znx_big(module: *const MODULE, size: u64) -> u64;
}
unsafe extern "C" {
    pub unsafe fn new_vec_znx_big(module: *const MODULE, size: u64) -> *mut VEC_ZNX_BIG;
}
unsafe extern "C" {
    pub unsafe fn delete_vec_znx_big(res: *mut VEC_ZNX_BIG);
}

unsafe extern "C" {
    pub unsafe fn vec_znx_big_add(
        module: *const MODULE,
        res: *mut VEC_ZNX_BIG,
        res_size: u64,
        a: *const VEC_ZNX_BIG,
        a_size: u64,
        b: *const VEC_ZNX_BIG,
        b_size: u64,
    );
}
unsafe extern "C" {
    pub unsafe fn vec_znx_big_add_small(
        module: *const MODULE,
        res: *mut VEC_ZNX_BIG,
        res_size: u64,
        a: *const VEC_ZNX_BIG,
        a_size: u64,
        b: *const i64,
        b_size: u64,
        b_sl: u64,
    );
}
unsafe extern "C" {
    pub unsafe fn vec_znx_big_add_small2(
        module: *const MODULE,
        res: *mut VEC_ZNX_BIG,
        res_size: u64,
        a: *const i64,
        a_size: u64,
        a_sl: u64,
        b: *const i64,
        b_size: u64,
        b_sl: u64,
    );
}
unsafe extern "C" {
    pub unsafe fn vec_znx_big_sub(
        module: *const MODULE,
        res: *mut VEC_ZNX_BIG,
        res_size: u64,
        a: *const VEC_ZNX_BIG,
        a_size: u64,
        b: *const VEC_ZNX_BIG,
        b_size: u64,
    );
}
unsafe extern "C" {
    pub unsafe fn vec_znx_big_sub_small_b(
        module: *const MODULE,
        res: *mut VEC_ZNX_BIG,
        res_size: u64,
        a: *const VEC_ZNX_BIG,
        a_size: u64,
        b: *const i64,
        b_size: u64,
        b_sl: u64,
    );
}
unsafe extern "C" {
    pub unsafe fn vec_znx_big_sub_small_a(
        module: *const MODULE,
        res: *mut VEC_ZNX_BIG,
        res_size: u64,
        a: *const i64,
        a_size: u64,
        a_sl: u64,
        b: *const VEC_ZNX_BIG,
        b_size: u64,
    );
}
unsafe extern "C" {
    pub unsafe fn vec_znx_big_sub_small2(
        module: *const MODULE,
        res: *mut VEC_ZNX_BIG,
        res_size: u64,
        a: *const i64,
        a_size: u64,
        a_sl: u64,
        b: *const i64,
        b_size: u64,
        b_sl: u64,
    );
}

unsafe extern "C" {
    pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64;
}

unsafe extern "C" {
    pub unsafe fn vec_znx_big_normalize_base2k(
        module: *const MODULE,
        n: u64,
        log2_base2k: u64,
        res: *mut i64,
        res_size: u64,
        res_sl: u64,
        a: *const VEC_ZNX_BIG,
        a_size: u64,
        tmp_space: *mut u8,
    );
}

unsafe extern "C" {
    pub unsafe fn vec_znx_big_range_normalize_base2k(
        module: *const MODULE,
        n: u64,
        log2_base2k: u64,
        res: *mut i64,
        res_size: u64,
        res_sl: u64,
        a: *const VEC_ZNX_BIG,
        a_range_begin: u64,
        a_range_xend: u64,
        a_range_step: u64,
        tmp_space: *mut u8,
    );
}

unsafe extern "C" {
    pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64;
}

unsafe extern "C" {
    pub unsafe fn vec_znx_big_automorphism(
        module: *const MODULE,
        p: i64,
        res: *mut VEC_ZNX_BIG,
        res_size: u64,
        a: *const VEC_ZNX_BIG,
        a_size: u64,
    );
}

unsafe extern "C" {
    pub unsafe fn vec_znx_big_rotate(
        module: *const MODULE,
        p: i64,
        res: *mut VEC_ZNX_BIG,
        res_size: u64,
        a: *const VEC_ZNX_BIG,
        a_size: u64,
    );
}

69  poulpy-backend/src/implementation/cpu_spqlios/ffi/vec_znx_dft.rs  Normal file

@@ -0,0 +1,69 @@
use crate::implementation::cpu_spqlios::ffi::{module::MODULE, vec_znx_big::VEC_ZNX_BIG};

#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct vec_znx_dft_t {
    _unused: [u8; 0],
}
pub type VEC_ZNX_DFT = vec_znx_dft_t;

unsafe extern "C" {
    pub unsafe fn bytes_of_vec_znx_dft(module: *const MODULE, size: u64) -> u64;
}
unsafe extern "C" {
    pub unsafe fn new_vec_znx_dft(module: *const MODULE, size: u64) -> *mut VEC_ZNX_DFT;
}
unsafe extern "C" {
    pub unsafe fn delete_vec_znx_dft(res: *mut VEC_ZNX_DFT);
}

unsafe extern "C" {
    pub unsafe fn vec_dft_zero(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64);
}
unsafe extern "C" {
    pub unsafe fn vec_dft_add(
        module: *const MODULE,
        res: *mut VEC_ZNX_DFT,
        res_size: u64,
        a: *const VEC_ZNX_DFT,
        a_size: u64,
        b: *const VEC_ZNX_DFT,
        b_size: u64,
    );
}
unsafe extern "C" {
    pub unsafe fn vec_dft_sub(
        module: *const MODULE,
        res: *mut VEC_ZNX_DFT,
        res_size: u64,
        a: *const VEC_ZNX_DFT,
        a_size: u64,
        b: *const VEC_ZNX_DFT,
        b_size: u64,
    );
}
unsafe extern "C" {
    pub unsafe fn vec_znx_dft(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64, a: *const i64, a_size: u64, a_sl: u64);
}
unsafe extern "C" {
    pub unsafe fn vec_znx_idft(
        module: *const MODULE,
        res: *mut VEC_ZNX_BIG,
        res_size: u64,
        a_dft: *const VEC_ZNX_DFT,
        a_size: u64,
        tmp: *mut u8,
    );
}
unsafe extern "C" {
    pub unsafe fn vec_znx_idft_tmp_bytes(module: *const MODULE, n: u64) -> u64;
}
unsafe extern "C" {
    pub unsafe fn vec_znx_idft_tmp_a(
        module: *const MODULE,
        res: *mut VEC_ZNX_BIG,
        res_size: u64,
        a_dft: *mut VEC_ZNX_DFT,
        a_size: u64,
    );
}

114  poulpy-backend/src/implementation/cpu_spqlios/ffi/vmp.rs  Normal file

@@ -0,0 +1,114 @@
use crate::implementation::cpu_spqlios::ffi::{module::MODULE, vec_znx_dft::VEC_ZNX_DFT};

#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct vmp_pmat_t {
    _unused: [u8; 0],
}

// [rows][cols] = [#Decomposition][#Limbs]
pub type VMP_PMAT = vmp_pmat_t;

unsafe extern "C" {
    pub unsafe fn bytes_of_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
}
unsafe extern "C" {
    pub unsafe fn new_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> *mut VMP_PMAT;
}
unsafe extern "C" {
    pub unsafe fn delete_vmp_pmat(res: *mut VMP_PMAT);
}

unsafe extern "C" {
    pub unsafe fn vmp_apply_dft(
        module: *const MODULE,
        res: *mut VEC_ZNX_DFT,
        res_size: u64,
        a: *const i64,
        a_size: u64,
        a_sl: u64,
        pmat: *const VMP_PMAT,
        nrows: u64,
        ncols: u64,
        tmp_space: *mut u8,
    );
}

unsafe extern "C" {
    pub unsafe fn vmp_apply_dft_add(
        module: *const MODULE,
        res: *mut VEC_ZNX_DFT,
        res_size: u64,
        a: *const i64,
        a_size: u64,
        a_sl: u64,
        pmat: *const VMP_PMAT,
        nrows: u64,
        ncols: u64,
        pmat_scale: u64,
        tmp_space: *mut u8,
    );
}

unsafe extern "C" {
    pub unsafe fn vmp_apply_dft_tmp_bytes(module: *const MODULE, res_size: u64, a_size: u64, nrows: u64, ncols: u64) -> u64;
}

unsafe extern "C" {
    pub unsafe fn vmp_apply_dft_to_dft(
        module: *const MODULE,
        res: *mut VEC_ZNX_DFT,
        res_size: u64,
        a_dft: *const VEC_ZNX_DFT,
        a_size: u64,
        pmat: *const VMP_PMAT,
        nrows: u64,
        ncols: u64,
        tmp_space: *mut u8,
    );
}

unsafe extern "C" {
    pub unsafe fn vmp_apply_dft_to_dft_add(
        module: *const MODULE,
        res: *mut VEC_ZNX_DFT,
        res_size: u64,
        a_dft: *const VEC_ZNX_DFT,
        a_size: u64,
        pmat: *const VMP_PMAT,
        nrows: u64,
        ncols: u64,
        pmat_scale: u64,
        tmp_space: *mut u8,
    );
}

unsafe extern "C" {
    pub unsafe fn vmp_apply_dft_to_dft_tmp_bytes(
        module: *const MODULE,
        nn: u64,
        res_size: u64,
        a_size: u64,
        nrows: u64,
        ncols: u64,
    ) -> u64;
}

unsafe extern "C" {
    pub unsafe fn vmp_prepare_contiguous(
        module: *const MODULE,
        pmat: *mut VMP_PMAT,
        mat: *const i64,
        nrows: u64,
        ncols: u64,
        tmp_space: *mut u8,
    );
}

unsafe extern "C" {
    pub unsafe fn vmp_prepare_contiguous_dft(module: *const MODULE, pmat: *mut VMP_PMAT, mat: *const f64, nrows: u64, ncols: u64);
}

unsafe extern "C" {
    pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nn: u64, nrows: u64, ncols: u64) -> u64;
}

79  poulpy-backend/src/implementation/cpu_spqlios/ffi/znx.rs  Normal file

@@ -0,0 +1,79 @@
use crate::implementation::cpu_spqlios::ffi::module::MODULE;

unsafe extern "C" {
    pub unsafe fn znx_add_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
}
unsafe extern "C" {
    pub unsafe fn znx_add_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
}
unsafe extern "C" {
    pub unsafe fn znx_sub_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
}
unsafe extern "C" {
    pub unsafe fn znx_sub_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
}
unsafe extern "C" {
    pub unsafe fn znx_negate_i64_ref(nn: u64, res: *mut i64, a: *const i64);
}
unsafe extern "C" {
    pub unsafe fn znx_negate_i64_avx(nn: u64, res: *mut i64, a: *const i64);
}
unsafe extern "C" {
    pub unsafe fn znx_copy_i64_ref(nn: u64, res: *mut i64, a: *const i64);
}
unsafe extern "C" {
    pub unsafe fn znx_zero_i64_ref(nn: u64, res: *mut i64);
}
unsafe extern "C" {
    pub unsafe fn rnx_divide_by_m_ref(nn: u64, m: f64, res: *mut f64, a: *const f64);
}
unsafe extern "C" {
    pub unsafe fn rnx_divide_by_m_avx(nn: u64, m: f64, res: *mut f64, a: *const f64);
}
unsafe extern "C" {
    pub unsafe fn rnx_rotate_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64);
}
unsafe extern "C" {
    pub unsafe fn znx_rotate_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
}
unsafe extern "C" {
    pub unsafe fn rnx_rotate_inplace_f64(nn: u64, p: i64, res: *mut f64);
}
unsafe extern "C" {
    pub unsafe fn znx_rotate_inplace_i64(nn: u64, p: i64, res: *mut i64);
}
unsafe extern "C" {
    pub unsafe fn rnx_automorphism_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64);
}
unsafe extern "C" {
    pub unsafe fn znx_automorphism_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
}
unsafe extern "C" {
    pub unsafe fn rnx_automorphism_inplace_f64(nn: u64, p: i64, res: *mut f64);
}
unsafe extern "C" {
    pub unsafe fn znx_automorphism_inplace_i64(nn: u64, p: i64, res: *mut i64);
}
unsafe extern "C" {
    pub unsafe fn rnx_mul_xp_minus_one_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64);
}
unsafe extern "C" {
    pub unsafe fn znx_mul_xp_minus_one_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
}
unsafe extern "C" {
    pub unsafe fn rnx_mul_xp_minus_one_inplace_f64(nn: u64, p: i64, res: *mut f64);
}
unsafe extern "C" {
    pub unsafe fn znx_mul_xp_minus_one_inplace_i64(nn: u64, p: i64, res: *mut i64);
}
unsafe extern "C" {
    pub unsafe fn znx_normalize(nn: u64, base_k: u64, out: *mut i64, carry_out: *mut i64, in_: *const i64, carry_in: *const i64);
}

unsafe extern "C" {
    pub unsafe fn znx_small_single_product(module: *const MODULE, res: *mut i64, a: *const i64, b: *const i64, tmp: *mut u8);
}

unsafe extern "C" {
    pub unsafe fn znx_small_single_product_tmp_bytes(module: *const MODULE) -> u64;
}
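A minimal sketch of calling one of the scalar kernels above; the buffers are plain i64 arrays of length nn, and the _ref/_avx suffixes select the portable or AVX implementation. Only the znx_add_i64_ref declaration above is assumed to be in scope.

fn znx_add_sketch() {
    // Hypothetical direct call to the reference kernel; not part of this commit.
    let nn: usize = 8;
    let a: Vec<i64> = (0..nn as i64).collect();
    let b: Vec<i64> = vec![1i64; nn];
    let mut res: Vec<i64> = vec![0i64; nn];
    // Element-wise i64 addition of the two coefficient vectors into res.
    unsafe { znx_add_i64_ref(nn as u64, res.as_mut_ptr(), a.as_ptr(), b.as_ptr()) };
    assert_eq!(res[3], a[3] + b[3]);
}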

27  poulpy-backend/src/implementation/cpu_spqlios/mod.rs  Normal file

@@ -0,0 +1,27 @@
mod ffi;
mod module_fft64;
mod module_ntt120;
mod scratch;
mod svp_ppol_fft64;
mod svp_ppol_ntt120;
mod vec_znx;
mod vec_znx_big_fft64;
mod vec_znx_big_ntt120;
mod vec_znx_dft_fft64;
mod vec_znx_dft_ntt120;
mod vmp_pmat_fft64;
mod vmp_pmat_ntt120;

#[cfg(test)]
mod test;

pub use module_fft64::*;
pub use module_ntt120::*;

/// For external documentation
pub use vec_znx::{
    vec_znx_copy_ref, vec_znx_lsh_inplace_ref, vec_znx_merge_ref, vec_znx_rsh_inplace_ref, vec_znx_split_ref,
    vec_znx_switch_degree_ref,
};

pub trait CPUAVX {}

29  poulpy-backend/src/implementation/cpu_spqlios/module_fft64.rs  Normal file

@@ -0,0 +1,29 @@
use std::ptr::NonNull;

use crate::{
    hal::{
        layouts::{Backend, Module},
        oep::ModuleNewImpl,
    },
    implementation::cpu_spqlios::{
        CPUAVX,
        ffi::module::{MODULE, delete_module_info, new_module_info},
    },
};

pub struct FFT64;

impl CPUAVX for FFT64 {}

impl Backend for FFT64 {
    type Handle = MODULE;
    unsafe fn destroy(handle: NonNull<Self::Handle>) {
        unsafe { delete_module_info(handle.as_ptr()) }
    }
}

unsafe impl ModuleNewImpl<Self> for FFT64 {
    fn new_impl(n: u64) -> Module<Self> {
        unsafe { Module::from_raw_parts(new_module_info(n, 0), n) }
    }
}

29  poulpy-backend/src/implementation/cpu_spqlios/module_ntt120.rs  Normal file

@@ -0,0 +1,29 @@
use std::ptr::NonNull;

use crate::{
    hal::{
        layouts::{Backend, Module},
        oep::ModuleNewImpl,
    },
    implementation::cpu_spqlios::{
        CPUAVX,
        ffi::module::{MODULE, delete_module_info, new_module_info},
    },
};

pub struct NTT120;

impl CPUAVX for NTT120 {}

impl Backend for NTT120 {
    type Handle = MODULE;
    unsafe fn destroy(handle: NonNull<Self::Handle>) {
        unsafe { delete_module_info(handle.as_ptr()) }
    }
}

unsafe impl ModuleNewImpl<Self> for NTT120 {
    fn new_impl(n: u64) -> Module<Self> {
        unsafe { Module::from_raw_parts(new_module_info(n, 1), n) }
    }
}

271  poulpy-backend/src/implementation/cpu_spqlios/scratch.rs  Normal file

@@ -0,0 +1,271 @@
use std::marker::PhantomData;

use crate::{
    DEFAULTALIGN, alloc_aligned,
    hal::{
        api::ScratchFromBytes,
        layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
        oep::{
            ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, SvpPPolAllocBytesImpl,
            TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl,
            TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, VecZnxBigAllocBytesImpl,
            VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl,
        },
    },
    implementation::cpu_spqlios::CPUAVX,
};

unsafe impl<B: Backend> ScratchOwnedAllocImpl<B> for B
where
    B: CPUAVX,
{
    fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned<B> {
        let data: Vec<u8> = alloc_aligned(size);
        ScratchOwned {
            data,
            _phantom: PhantomData,
        }
    }
}

unsafe impl<B: Backend> ScratchOwnedBorrowImpl<B> for B
where
    B: CPUAVX,
{
    fn scratch_owned_borrow_impl(scratch: &mut ScratchOwned<B>) -> &mut Scratch<B> {
        Scratch::from_bytes(&mut scratch.data)
    }
}

unsafe impl<B: Backend> ScratchFromBytesImpl<B> for B
where
    B: CPUAVX,
{
    fn scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch<B> {
        unsafe { &mut *(data as *mut [u8] as *mut Scratch<B>) }
    }
}

unsafe impl<B: Backend> ScratchAvailableImpl<B> for B
where
    B: CPUAVX,
{
    fn scratch_available_impl(scratch: &Scratch<B>) -> usize {
        let ptr: *const u8 = scratch.data.as_ptr();
        let self_len: usize = scratch.data.len();
        let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
        self_len.saturating_sub(aligned_offset)
    }
}

unsafe impl<B: Backend> TakeSliceImpl<B> for B
where
    B: CPUAVX,
{
    fn take_slice_impl<T>(scratch: &mut Scratch<B>, len: usize) -> (&mut [T], &mut Scratch<B>) {
        let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, len * std::mem::size_of::<T>());

        unsafe {
            (
                &mut *(std::ptr::slice_from_raw_parts_mut(take_slice.as_mut_ptr() as *mut T, len)),
                Scratch::from_bytes(rem_slice),
            )
        }
    }
}

unsafe impl<B: Backend> TakeScalarZnxImpl<B> for B
where
    B: CPUAVX,
{
    fn take_scalar_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch<B>) {
        let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, ScalarZnx::alloc_bytes(n, cols));
        (
            ScalarZnx::from_data(take_slice, n, cols),
            Scratch::from_bytes(rem_slice),
        )
    }
}

unsafe impl<B: Backend> TakeSvpPPolImpl<B> for B
where
    B: CPUAVX + SvpPPolAllocBytesImpl<B>,
{
    fn take_svp_ppol_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch<B>) {
        let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::svp_ppol_alloc_bytes_impl(n, cols));
        (
            SvpPPol::from_data(take_slice, n, cols),
            Scratch::from_bytes(rem_slice),
        )
    }
}

unsafe impl<B: Backend> TakeVecZnxImpl<B> for B
where
    B: CPUAVX,
{
    fn take_vec_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch<B>) {
        let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, VecZnx::alloc_bytes(n, cols, size));
        (
            VecZnx::from_data(take_slice, n, cols, size),
            Scratch::from_bytes(rem_slice),
        )
    }
}

unsafe impl<B: Backend> TakeVecZnxBigImpl<B> for B
where
    B: CPUAVX + VecZnxBigAllocBytesImpl<B>,
{
    fn take_vec_znx_big_impl(
        scratch: &mut Scratch<B>,
        n: usize,
        cols: usize,
        size: usize,
    ) -> (VecZnxBig<&mut [u8], B>, &mut Scratch<B>) {
        let (take_slice, rem_slice) = take_slice_aligned(
            &mut scratch.data,
            B::vec_znx_big_alloc_bytes_impl(n, cols, size),
        );
        (
            VecZnxBig::from_data(take_slice, n, cols, size),
            Scratch::from_bytes(rem_slice),
        )
    }
}

unsafe impl<B: Backend> TakeVecZnxDftImpl<B> for B
where
    B: CPUAVX + VecZnxDftAllocBytesImpl<B>,
{
    fn take_vec_znx_dft_impl(
        scratch: &mut Scratch<B>,
        n: usize,
        cols: usize,
        size: usize,
    ) -> (VecZnxDft<&mut [u8], B>, &mut Scratch<B>) {
        let (take_slice, rem_slice) = take_slice_aligned(
            &mut scratch.data,
            B::vec_znx_dft_alloc_bytes_impl(n, cols, size),
        );

        (
            VecZnxDft::from_data(take_slice, n, cols, size),
            Scratch::from_bytes(rem_slice),
        )
    }
}

unsafe impl<B: Backend> TakeVecZnxDftSliceImpl<B> for B
where
    B: CPUAVX + VecZnxDftAllocBytesImpl<B>,
{
    fn take_vec_znx_dft_slice_impl(
        scratch: &mut Scratch<B>,
        len: usize,
        n: usize,
        cols: usize,
        size: usize,
    ) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Scratch<B>) {
        let mut scratch: &mut Scratch<B> = scratch;
        let mut slice: Vec<VecZnxDft<&mut [u8], B>> = Vec::with_capacity(len);
        for _ in 0..len {
            let (znx, new_scratch) = B::take_vec_znx_dft_impl(scratch, n, cols, size);
            scratch = new_scratch;
            slice.push(znx);
        }
        (slice, scratch)
    }
}

unsafe impl<B: Backend> TakeVecZnxSliceImpl<B> for B
where
    B: CPUAVX,
{
    fn take_vec_znx_slice_impl(
        scratch: &mut Scratch<B>,
        len: usize,
        n: usize,
        cols: usize,
        size: usize,
    ) -> (Vec<VecZnx<&mut [u8]>>, &mut Scratch<B>) {
        let mut scratch: &mut Scratch<B> = scratch;
        let mut slice: Vec<VecZnx<&mut [u8]>> = Vec::with_capacity(len);
        for _ in 0..len {
            let (znx, new_scratch) = B::take_vec_znx_impl(scratch, n, cols, size);
            scratch = new_scratch;
            slice.push(znx);
        }
        (slice, scratch)
    }
}

unsafe impl<B: Backend> TakeVmpPMatImpl<B> for B
where
    B: CPUAVX + VmpPMatAllocBytesImpl<B>,
{
    fn take_vmp_pmat_impl(
        scratch: &mut Scratch<B>,
        n: usize,
        rows: usize,
        cols_in: usize,
        cols_out: usize,
        size: usize,
    ) -> (VmpPMat<&mut [u8], B>, &mut Scratch<B>) {
        let (take_slice, rem_slice) = take_slice_aligned(
            &mut scratch.data,
            B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size),
        );
        (
            VmpPMat::from_data(take_slice, n, rows, cols_in, cols_out, size),
            Scratch::from_bytes(rem_slice),
        )
    }
}

unsafe impl<B: Backend> TakeMatZnxImpl<B> for B
where
    B: CPUAVX,
{
    fn take_mat_znx_impl(
        scratch: &mut Scratch<B>,
        n: usize,
        rows: usize,
        cols_in: usize,
        cols_out: usize,
        size: usize,
    ) -> (MatZnx<&mut [u8]>, &mut Scratch<B>) {
        let (take_slice, rem_slice) = take_slice_aligned(
            &mut scratch.data,
            MatZnx::alloc_bytes(n, rows, cols_in, cols_out, size),
        );
        (
            MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size),
            Scratch::from_bytes(rem_slice),
        )
    }
}

fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) {
    let ptr: *mut u8 = data.as_mut_ptr();
    let self_len: usize = data.len();

    let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
    let aligned_len: usize = self_len.saturating_sub(aligned_offset);

    if let Some(rem_len) = aligned_len.checked_sub(take_len) {
        unsafe {
            let rem_ptr: *mut u8 = ptr.add(aligned_offset).add(take_len);
            let rem_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len);

            let take_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len);

            (take_slice, rem_slice)
        }
    } else {
        panic!(
            "Attempted to take {} from scratch with {} aligned bytes left",
            take_len, aligned_len,
        );
    }
}
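A short illustration of the split performed by take_slice_aligned above: it skips to the next DEFAULTALIGN boundary, hands back take_len bytes, and returns the rest as the new scratch region, panicking if not enough aligned bytes remain. The example is hypothetical and only uses the signature shown in this file.

fn take_slice_aligned_sketch() {
    // Hypothetical usage next to the function above; not part of this commit.
    let mut buf: Vec<u8> = vec![0u8; 1024];
    let (head, tail) = take_slice_aligned(&mut buf, 128);
    assert_eq!(head.len(), 128);
    // tail is whatever remains after the alignment offset and the 128 taken bytes.
    assert!(tail.len() <= 1024 - 128);
}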
Submodule poulpy-backend/src/implementation/cpu_spqlios/spqlios-arithmetic added at de62af3507

114  poulpy-backend/src/implementation/cpu_spqlios/svp_ppol_fft64.rs  Normal file

@@ -0,0 +1,114 @@
use crate::{
    hal::{
        api::{ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut},
        layouts::{
            Data, DataRef, Module, ScalarZnxToRef, SvpPPol, SvpPPolBytesOf, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDft,
            VecZnxDftToMut, VecZnxDftToRef,
        },
        oep::{SvpApplyImpl, SvpApplyInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl},
    },
    implementation::cpu_spqlios::{
        ffi::{svp, vec_znx_dft::vec_znx_dft_t},
        module_fft64::FFT64,
    },
};

const SVP_PPOL_FFT64_WORD_SIZE: usize = 1;

impl<D: Data> SvpPPolBytesOf for SvpPPol<D, FFT64> {
    fn bytes_of(n: usize, cols: usize) -> usize {
        SVP_PPOL_FFT64_WORD_SIZE * n * cols * size_of::<f64>()
    }
}

impl<D: Data> ZnxSliceSize for SvpPPol<D, FFT64> {
    fn sl(&self) -> usize {
        SVP_PPOL_FFT64_WORD_SIZE * self.n()
    }
}

impl<D: DataRef> ZnxView for SvpPPol<D, FFT64> {
    type Scalar = f64;
}

unsafe impl SvpPPolFromBytesImpl<Self> for FFT64 {
    fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<Self> {
        SvpPPolOwned::from_bytes(n, cols, bytes)
    }
}

unsafe impl SvpPPolAllocImpl<Self> for FFT64 {
    fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned<Self> {
        SvpPPolOwned::alloc(n, cols)
    }
}

unsafe impl SvpPPolAllocBytesImpl<Self> for FFT64 {
    fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize {
        SvpPPol::<Vec<u8>, Self>::bytes_of(n, cols)
    }
}

unsafe impl SvpPrepareImpl<Self> for FFT64 {
    fn svp_prepare_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: SvpPPolToMut<Self>,
        A: ScalarZnxToRef,
    {
        unsafe {
            svp::svp_prepare(
                module.ptr(),
                res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t,
                a.to_ref().at_ptr(a_col, 0),
            )
        }
    }
}

unsafe impl SvpApplyImpl<Self> for FFT64 {
    fn svp_apply_impl<R, A, B>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
    where
        R: VecZnxDftToMut<Self>,
        A: SvpPPolToRef<Self>,
        B: VecZnxDftToRef<Self>,
    {
        let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
        let a: SvpPPol<&[u8], Self> = a.to_ref();
        let b: VecZnxDft<&[u8], Self> = b.to_ref();
        unsafe {
            svp::svp_apply_dft_to_dft(
                module.ptr(),
                res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
                res.size() as u64,
                res.cols() as u64,
                a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
                b.at_ptr(b_col, 0) as *const vec_znx_dft_t,
                b.size() as u64,
                b.cols() as u64,
            )
        }
    }
}

unsafe impl SvpApplyInplaceImpl for FFT64 {
    fn svp_apply_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxDftToMut<Self>,
        A: SvpPPolToRef<Self>,
    {
        let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
        let a: SvpPPol<&[u8], Self> = a.to_ref();
        unsafe {
            svp::svp_apply_dft_to_dft(
                module.ptr(),
                res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
                res.size() as u64,
                res.cols() as u64,
                a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
                res.at_ptr(res_col, 0) as *const vec_znx_dft_t,
                res.size() as u64,
                res.cols() as u64,
            )
        }
    }
}

44  poulpy-backend/src/implementation/cpu_spqlios/svp_ppol_ntt120.rs  Normal file

@@ -0,0 +1,44 @@
use crate::{
    hal::{
        api::{ZnxInfos, ZnxSliceSize, ZnxView},
        layouts::{Data, DataRef, SvpPPol, SvpPPolBytesOf, SvpPPolOwned},
        oep::{SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl},
    },
    implementation::cpu_spqlios::module_ntt120::NTT120,
};

const SVP_PPOL_NTT120_WORD_SIZE: usize = 4;

impl<D: Data> SvpPPolBytesOf for SvpPPol<D, NTT120> {
    fn bytes_of(n: usize, cols: usize) -> usize {
        SVP_PPOL_NTT120_WORD_SIZE * n * cols * size_of::<i64>()
    }
}

impl<D: Data> ZnxSliceSize for SvpPPol<D, NTT120> {
    fn sl(&self) -> usize {
        SVP_PPOL_NTT120_WORD_SIZE * self.n()
    }
}

impl<D: DataRef> ZnxView for SvpPPol<D, NTT120> {
    type Scalar = i64;
}

unsafe impl SvpPPolFromBytesImpl<Self> for NTT120 {
    fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<NTT120> {
        SvpPPolOwned::from_bytes(n, cols, bytes)
    }
}

unsafe impl SvpPPolAllocImpl<Self> for NTT120 {
    fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned<NTT120> {
        SvpPPolOwned::alloc(n, cols)
    }
}

unsafe impl SvpPPolAllocBytesImpl<Self> for NTT120 {
    fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize {
        SvpPPol::<Vec<u8>, Self>::bytes_of(n, cols)
    }
}

@@ -0,0 +1 @@
mod vec_znx_fft64;

@@ -0,0 +1,20 @@
use crate::{
    hal::{
        api::ModuleNew,
        layouts::Module,
        tests::vec_znx::{test_vec_znx_add_normal, test_vec_znx_fill_uniform},
    },
    implementation::cpu_spqlios::FFT64,
};

#[test]
fn test_vec_znx_fill_uniform_fft64() {
    let module: Module<FFT64> = Module::<FFT64>::new(1 << 12);
    test_vec_znx_fill_uniform(&module);
}

#[test]
fn test_vec_znx_add_normal_fft64() {
    let module: Module<FFT64> = Module::<FFT64>::new(1 << 12);
    test_vec_znx_add_normal(&module);
}

929  poulpy-backend/src/implementation/cpu_spqlios/vec_znx.rs  Normal file
@@ -0,0 +1,929 @@
|
||||
use itertools::izip;
|
||||
use rand_distr::Normal;
|
||||
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{
|
||||
TakeSlice, TakeVecZnx, VecZnxAddDistF64, VecZnxCopy, VecZnxFillDistF64, VecZnxNormalizeTmpBytes, VecZnxRotate,
|
||||
VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
|
||||
},
|
||||
layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef},
|
||||
oep::{
|
||||
VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
|
||||
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl,
|
||||
VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl,
|
||||
VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
|
||||
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
|
||||
VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl,
|
||||
VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
|
||||
},
|
||||
source::Source,
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
CPUAVX,
|
||||
ffi::{module::module_info_t, vec_znx, znx},
|
||||
},
|
||||
};
|
||||
|
||||
unsafe impl<B: Backend> VecZnxNormalizeTmpBytesImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<B>, n: usize) -> usize {
|
||||
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr() as *const module_info_t, n as u64) as usize }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxNormalizeImpl<B> for B {
|
||||
fn vec_znx_normalize_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), a.n());
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes(a.n()));
|
||||
|
||||
unsafe {
|
||||
vec_znx::vec_znx_normalize_base2k(
|
||||
module.ptr() as *const module_info_t,
|
||||
a.n() as u64,
|
||||
basek as u64,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxNormalizeInplaceImpl<B> for B {
|
||||
fn vec_znx_normalize_inplace_impl<A>(module: &Module<B>, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||
|
||||
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes(a.n()));
|
||||
|
||||
unsafe {
|
||||
vec_znx::vec_znx_normalize_base2k(
|
||||
module.ptr() as *const module_info_t,
|
||||
a.n() as u64,
|
||||
basek as u64,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxAddImpl<B> for B {
|
||||
fn vec_znx_add_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
C: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
assert_eq!(b.n(), res.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
module.ptr() as *const module_info_t,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxAddInplaceImpl<B> for B {
|
||||
fn vec_znx_add_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
module.ptr() as *const module_info_t,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxAddScalarInplaceImpl<B> for B {
|
||||
fn vec_znx_add_scalar_inplace_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
res_limb: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let a: ScalarZnx<&[u8]> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
module.ptr() as *const module_info_t,
|
||||
res.at_mut_ptr(res_col, res_limb),
|
||||
1_u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, res_limb),
|
||||
1_u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxSubImpl<B> for B {
|
||||
fn vec_znx_sub_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
C: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
assert_eq!(b.n(), res.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr() as *const module_info_t,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxSubABInplaceImpl<B> for B {
|
||||
fn vec_znx_sub_ab_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr() as *const module_info_t,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxSubBAInplaceImpl<B> for B {
|
||||
fn vec_znx_sub_ba_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr() as *const module_info_t,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxSubScalarInplaceImpl<B> for B {
|
||||
fn vec_znx_sub_scalar_inplace_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
res_limb: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let a: ScalarZnx<&[u8]> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr() as *const module_info_t,
|
||||
res.at_mut_ptr(res_col, res_limb),
|
||||
1_u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, res_limb),
|
||||
1_u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxNegateImpl<B> for B {
|
||||
fn vec_znx_negate_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_negate(
|
||||
module.ptr() as *const module_info_t,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxNegateInplaceImpl<B> for B {
|
||||
fn vec_znx_negate_inplace_impl<A>(module: &Module<B>, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||
unsafe {
|
||||
vec_znx::vec_znx_negate(
|
||||
module.ptr() as *const module_info_t,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxLshInplaceImpl<B> for B {
|
||||
fn vec_znx_lsh_inplace_impl<A>(_module: &Module<B>, basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
vec_znx_lsh_inplace_ref(basek, k, a)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_lsh_inplace_ref<A>(basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||
|
||||
let n: usize = a.n();
|
||||
let cols: usize = a.cols();
|
||||
let size: usize = a.size();
|
||||
let steps: usize = k / basek;
|
||||
|
||||
a.raw_mut().rotate_left(n * steps * cols);
|
||||
(0..cols).for_each(|i| {
|
||||
(size - steps..size).for_each(|j| {
|
||||
a.zero_at(i, j);
|
||||
})
|
||||
});
|
||||
|
||||
let k_rem: usize = k % basek;
|
||||
|
||||
if k_rem != 0 {
|
||||
let shift: usize = i64::BITS as usize - k_rem;
|
||||
(0..cols).for_each(|i| {
|
||||
(0..steps).for_each(|j| {
|
||||
a.at_mut(i, j).iter_mut().for_each(|xi| {
|
||||
*xi <<= shift;
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxRshInplaceImpl<B> for B {
|
||||
fn vec_znx_rsh_inplace_impl<A>(_module: &Module<B>, basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
vec_znx_rsh_inplace_ref(basek, k, a)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_rsh_inplace_ref<A>(basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||
let n: usize = a.n();
|
||||
let cols: usize = a.cols();
|
||||
let size: usize = a.size();
|
||||
let steps: usize = k / basek;
|
||||
|
||||
a.raw_mut().rotate_right(n * steps * cols);
|
||||
(0..cols).for_each(|i| {
|
||||
(0..steps).for_each(|j| {
|
||||
a.zero_at(i, j);
|
||||
})
|
||||
});
|
||||
|
||||
let k_rem: usize = k % basek;
|
||||
|
||||
if k_rem != 0 {
|
||||
let mut carry: Vec<i64> = vec![0i64; n]; // ALLOC (but small so OK)
|
||||
let shift: usize = i64::BITS as usize - k_rem;
|
||||
(0..cols).for_each(|i| {
|
||||
carry.fill(0);
|
||||
(steps..size).for_each(|j| {
|
||||
izip!(carry.iter_mut(), a.at_mut(i, j).iter_mut()).for_each(|(ci, xi)| {
|
||||
*xi += *ci << basek;
|
||||
*ci = (*xi << shift) >> shift;
|
||||
*xi = (*xi - *ci) >> k_rem;
|
||||
});
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxRotateImpl<B> for B {
|
||||
fn vec_znx_rotate_impl<R, A>(_module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), a.n());
|
||||
}
|
||||
unsafe {
|
||||
(0..a.size()).for_each(|j| {
|
||||
znx::znx_rotate_i64(
|
||||
a.n() as u64,
|
||||
k,
|
||||
res.at_mut_ptr(res_col, j),
|
||||
a.at_ptr(a_col, j),
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxRotateInplaceImpl<B> for B {
|
||||
fn vec_znx_rotate_inplace_impl<A>(_module: &Module<B>, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||
unsafe {
|
||||
(0..a.size()).for_each(|j| {
|
||||
znx::znx_rotate_inplace_i64(a.n() as u64, k, a.at_mut_ptr(a_col, j));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxAutomorphismImpl<B> for B {
|
||||
fn vec_znx_automorphism_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
module.ptr() as *const module_info_t,
|
||||
k,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxAutomorphismInplaceImpl<B> for B {
|
||||
fn vec_znx_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
k & 1 != 0,
|
||||
"invalid galois element: must be odd but is {}",
|
||||
k
|
||||
);
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
module.ptr() as *const module_info_t,
|
||||
k,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxMulXpMinusOneImpl<B> for B {
    fn vec_znx_mul_xp_minus_one_impl<R, A>(module: &Module<B>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxToMut,
        A: VecZnxToRef,
    {
        let a: VecZnx<&[u8]> = a.to_ref();
        let mut res: VecZnx<&mut [u8]> = res.to_mut();
        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), res.n());
        }
        unsafe {
            vec_znx::vec_znx_mul_xp_minus_one(
                module.ptr() as *const module_info_t,
                p,
                res.at_mut_ptr(res_col, 0),
                res.size() as u64,
                res.sl() as u64,
                a.at_ptr(a_col, 0),
                a.size() as u64,
                a.sl() as u64,
            )
        }
    }
}

unsafe impl<B: Backend + CPUAVX> VecZnxMulXpMinusOneInplaceImpl<B> for B {
    fn vec_znx_mul_xp_minus_one_inplace_impl<R>(module: &Module<B>, p: i64, res: &mut R, res_col: usize)
    where
        R: VecZnxToMut,
    {
        let mut res: VecZnx<&mut [u8]> = res.to_mut();
        unsafe {
            vec_znx::vec_znx_mul_xp_minus_one(
                module.ptr() as *const module_info_t,
                p,
                res.at_mut_ptr(res_col, 0),
                res.size() as u64,
                res.sl() as u64,
                res.at_ptr(res_col, 0),
                res.size() as u64,
                res.sl() as u64,
            )
        }
    }
}

unsafe impl<B: Backend + CPUAVX> VecZnxSplitImpl<B> for B {
    fn vec_znx_split_impl<R, A>(module: &Module<B>, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
    where
        R: VecZnxToMut,
        A: VecZnxToRef,
    {
        vec_znx_split_ref(module, res, res_col, a, a_col, scratch)
    }
}

pub fn vec_znx_split_ref<R, A, B>(
    module: &Module<B>,
    res: &mut [R],
    res_col: usize,
    a: &A,
    a_col: usize,
    scratch: &mut Scratch<B>,
) where
    B: Backend + CPUAVX,
    R: VecZnxToMut,
    A: VecZnxToRef,
{
    let a: VecZnx<&[u8]> = a.to_ref();

    let (n_in, n_out) = (a.n(), res[0].to_mut().n());

    let (mut buf, _) = scratch.take_vec_znx(n_in.max(n_out), 1, a.size());

    debug_assert!(
        n_out < n_in,
        "invalid a: output ring degree should be smaller"
    );
    res[1..].iter_mut().for_each(|bi| {
        debug_assert_eq!(
            bi.to_mut().n(),
            n_out,
            "invalid input a: all VecZnx must have the same degree"
        )
    });

    res.iter_mut().enumerate().for_each(|(i, bi)| {
        if i == 0 {
            module.vec_znx_switch_degree(bi, res_col, &a, a_col);
            module.vec_znx_rotate(-1, &mut buf, 0, &a, a_col);
        } else {
            module.vec_znx_switch_degree(bi, res_col, &buf, a_col);
            module.vec_znx_rotate_inplace(-1, &mut buf, a_col);
        }
    })
}

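// Editor's note (illustrative comment, not part of the crate): `vec_znx_split_ref`
// decomposes one degree-`n_in` vector into `res.len()` vectors of smaller degree `n_out`.
// The first output is obtained by switching the degree of `a` directly; each later output
// is taken from a working buffer that is rotated by X^-1 one more time per step, so the
// degree switch (which keeps every `gap_in`-th coefficient, see `vec_znx_switch_degree_ref`
// below) picks up a different residue class of coefficient indices each time.
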
unsafe impl<B: Backend + CPUAVX> VecZnxMergeImpl<B> for B {
    fn vec_znx_merge_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &[A], a_col: usize)
    where
        R: VecZnxToMut,
        A: VecZnxToRef,
    {
        vec_znx_merge_ref(module, res, res_col, a, a_col)
    }
}

pub fn vec_znx_merge_ref<R, A, B>(module: &Module<B>, res: &mut R, res_col: usize, a: &[A], a_col: usize)
where
    B: Backend + CPUAVX,
    R: VecZnxToMut,
    A: VecZnxToRef,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();

    let (n_in, n_out) = (res.n(), a[0].to_ref().n());

    debug_assert!(
        n_out < n_in,
        "invalid a: input ring degree should be smaller than the output"
    );
    a[1..].iter().for_each(|ai| {
        debug_assert_eq!(
            ai.to_ref().n(),
            n_out,
            "invalid input a: all VecZnx must have the same degree"
        )
    });

    a.iter().for_each(|ai| {
        module.vec_znx_switch_degree(&mut res, res_col, ai, a_col);
        module.vec_znx_rotate_inplace(-1, &mut res, res_col);
    });

    module.vec_znx_rotate_inplace(a.len() as i64, &mut res, res_col);
}

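// Editor's note (illustrative comment, not part of the crate): `vec_znx_merge_ref` is the
// inverse of the split above. Each small-degree input is embedded into `res` with a stride
// of `n_in / n_out` by `vec_znx_switch_degree`, and the running rotation by X^-1 moves the
// already-merged inputs into distinct residue classes; the final rotation by `a.len()`
// undoes the accumulated offset.
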
unsafe impl<B: Backend + CPUAVX> VecZnxSwithcDegreeImpl<B> for B {
    fn vec_znx_switch_degree_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxToMut,
        A: VecZnxToRef,
    {
        vec_znx_switch_degree_ref(module, res, res_col, a, a_col)
    }
}

pub fn vec_znx_switch_degree_ref<R, A, B>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    B: Backend + CPUAVX,
    R: VecZnxToMut,
    A: VecZnxToRef,
{
    let a: VecZnx<&[u8]> = a.to_ref();
    let mut res: VecZnx<&mut [u8]> = res.to_mut();

    let (n_in, n_out) = (a.n(), res.n());

    if n_in == n_out {
        module.vec_znx_copy(&mut res, res_col, &a, a_col);
        return;
    }

    let (gap_in, gap_out): (usize, usize);
    if n_in > n_out {
        (gap_in, gap_out) = (n_in / n_out, 1)
    } else {
        (gap_in, gap_out) = (1, n_out / n_in);
        res.zero();
    }

    let size: usize = a.size().min(res.size());

    (0..size).for_each(|i| {
        izip!(
            a.at(a_col, i).iter().step_by(gap_in),
            res.at_mut(res_col, i).iter_mut().step_by(gap_out)
        )
        .for_each(|(x_in, x_out)| *x_out = *x_in);
    });
}

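// Editor's note (illustrative comment, not part of the crate): `vec_znx_switch_degree_ref`
// maps between ring degrees by sub-sampling or spreading coefficients. Going down
// (n_in > n_out) it keeps every (n_in / n_out)-th coefficient of `a`; going up it zeroes
// `res` and writes the coefficients of `a` every (n_out / n_in) slots. Only the first
// min(a.size(), res.size()) limbs are written, so deeper limbs of `res` are left as they
// were in the down-switching case.
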
unsafe impl<B: Backend + CPUAVX> VecZnxCopyImpl<B> for B {
|
||||
fn vec_znx_copy_impl<R, A>(_module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_copy_ref(res, res_col, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_copy_ref<R, A>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let mut res_mut: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let a_ref: VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
|
||||
(0..min_size).for_each(|j| {
|
||||
res_mut
|
||||
.at_mut(res_col, j)
|
||||
.copy_from_slice(a_ref.at(a_col, j));
|
||||
});
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxFillUniformImpl<B> for B {
    fn vec_znx_fill_uniform_impl<R>(_module: &Module<B>, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
    where
        R: VecZnxToMut,
    {
        let mut a: VecZnx<&mut [u8]> = res.to_mut();
        let base2k: u64 = 1 << basek;
        let mask: u64 = base2k - 1;
        let base2k_half: i64 = (base2k >> 1) as i64;
        (0..k.div_ceil(basek)).for_each(|j| {
            a.at_mut(res_col, j)
                .iter_mut()
                .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
        })
    }
}

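// Editor's note (illustrative comment, not part of the crate): the uniform fill draws each
// base-2^basek digit from [0, 2^basek) with `source.next_u64n(base2k, mask)` and recenters
// it by subtracting 2^(basek-1), so every written limb holds a digit in
// [-2^(basek-1), 2^(basek-1)). Only the first ceil(k / basek) limbs are filled; deeper
// limbs keep their previous contents.
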
unsafe impl<B: Backend + CPUAVX> VecZnxFillDistF64Impl<B> for B {
|
||||
fn vec_znx_fill_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
|
||||
_module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
a.at_mut(res_col, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(res_col, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
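// Editor's note (illustrative comment, not part of the crate): the f64-distribution fill
// above (and the add variant that follows) rejection-samples until |x| <= bound, rounds to
// the nearest integer, and writes the sample into the last limb covered by `k`, i.e. limb
// ceil(k / basek) - 1. When k is not a multiple of basek, the sample is pre-shifted left by
// basek_rem bits so that the noise lands at precision 2^-k instead of at the limb boundary.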
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxAddDistF64Impl<B> for B {
|
||||
fn vec_znx_add_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
|
||||
_module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
a.at_mut(res_col, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(res_col, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxFillNormalImpl<B> for B {
|
||||
fn vec_znx_fill_normal_impl<R>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
module.vec_znx_fill_dist_f64(
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxAddNormalImpl<B> for B {
|
||||
fn vec_znx_add_normal_impl<R>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
module.vec_znx_add_dist_f64(
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,737 @@
|
||||
use std::fmt;
|
||||
|
||||
use rand_distr::{Distribution, Normal};
|
||||
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{
|
||||
TakeSlice, VecZnxBigAddDistF64, VecZnxBigFillDistF64, VecZnxBigNormalizeTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView,
|
||||
ZnxViewMut,
|
||||
},
|
||||
layouts::{
|
||||
Data, DataRef, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigBytesOf, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef,
|
||||
VecZnxToMut, VecZnxToRef,
|
||||
},
|
||||
oep::{
|
||||
VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl,
|
||||
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
|
||||
VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl, VecZnxBigFromBytesImpl,
|
||||
VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl,
|
||||
VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl,
|
||||
VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
|
||||
},
|
||||
source::Source,
|
||||
},
|
||||
implementation::cpu_spqlios::{ffi::vec_znx, module_fft64::FFT64},
|
||||
};
|
||||
|
||||
const VEC_ZNX_BIG_FFT64_WORDSIZE: usize = 1;
|
||||
|
||||
impl<D: DataRef> ZnxView for VecZnxBig<D, FFT64> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
impl<D: Data> VecZnxBigBytesOf for VecZnxBig<D, FFT64> {
|
||||
fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
|
||||
VEC_ZNX_BIG_FFT64_WORDSIZE * n * cols * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxSliceSize for VecZnxBig<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
VEC_ZNX_BIG_FFT64_WORDSIZE * self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAllocImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned<FFT64> {
|
||||
VecZnxBig::<Vec<u8>, FFT64>::new(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFromBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<FFT64> {
|
||||
VecZnxBig::<Vec<u8>, FFT64>::new_from_bytes(n, cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAllocBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
VecZnxBig::<Vec<u8>, FFT64>::bytes_of(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddDistF64Impl<FFT64> for FFT64 {
|
||||
fn add_dist_f64_impl<R: VecZnxBigToMut<FFT64>, D: Distribution<f64>>(
|
||||
_module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*x += (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*x += dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddNormalImpl<FFT64> for FFT64 {
|
||||
fn add_normal_impl<R: VecZnxBigToMut<FFT64>>(
|
||||
module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) {
|
||||
module.vec_znx_big_add_dist_f64(
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFillDistF64Impl<FFT64> for FFT64 {
|
||||
fn fill_dist_f64_impl<R: VecZnxBigToMut<FFT64>, D: Distribution<f64>>(
|
||||
_module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*x = (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*x = dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFillNormalImpl<FFT64> for FFT64 {
|
||||
fn fill_normal_impl<R: VecZnxBigToMut<FFT64>>(
|
||||
module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) {
|
||||
module.vec_znx_big_fill_dist_f64(
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddImpl<FFT64> for FFT64 {
|
||||
/// Adds `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_add_impl<R, A, B>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
assert_eq!(b.n(), res.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddInplaceImpl<FFT64> for FFT64 {
|
||||
/// Adds `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_add_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddSmallImpl<FFT64> for FFT64 {
|
||||
/// Adds `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_add_small_impl<R, A, B>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
assert_eq!(b.n(), res.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddSmallInplaceImpl<FFT64> for FFT64 {
|
||||
/// Adds `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_add_small_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubImpl<FFT64> for FFT64 {
|
||||
/// Subtracts `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_impl<R, A, B>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
assert_eq!(b.n(), res.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubABInplaceImpl<FFT64> for FFT64 {
|
||||
/// Subtracts `a` from `b` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_ab_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubBAInplaceImpl<FFT64> for FFT64 {
|
||||
/// Subtracts `b` from `a` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_ba_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallAImpl<FFT64> for FFT64 {
|
||||
/// Subtracts `b` from `a` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_small_a_impl<R, A, B>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
assert_eq!(b.n(), res.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallAInplaceImpl<FFT64> for FFT64 {
|
||||
/// Subtracts `a` from `res` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_a_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallBImpl<FFT64> for FFT64 {
|
||||
/// Subtracts `b` from `a` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_small_b_impl<R, A, B>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
assert_eq!(b.n(), res.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallBInplaceImpl<FFT64> for FFT64 {
|
||||
/// Subtracts `res` from `a` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_b_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNegateInplaceImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_negate_inplace_impl<A>(module: &Module<FFT64>, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
|
||||
unsafe {
|
||||
vec_znx::vec_znx_negate(
|
||||
module.ptr(),
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNormalizeTmpBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<FFT64>, n: usize) -> usize {
|
||||
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr(), n as u64) as usize }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNormalizeImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_normalize_impl<R, A>(
|
||||
module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<FFT64>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), a.n());
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes(a.n()));
|
||||
unsafe {
|
||||
vec_znx::vec_znx_normalize_base2k(
|
||||
module.ptr(),
|
||||
a.n() as u64,
|
||||
basek as u64,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
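// Editor's note (illustrative comment, not part of the crate): normalization takes the
// wide coefficients of a VecZnxBig and rewrites them as a canonical base-2^basek VecZnx,
// which in general requires propagating carries across limbs; the spqlios kernel does this
// with a scratch buffer whose size is reported by vec_znx_normalize_base2k_tmp_bytes and
// exposed here through vec_znx_big_normalize_tmp_bytes.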
|
||||
|
||||
unsafe impl VecZnxBigAutomorphismImpl<FFT64> for FFT64 {
|
||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
|
||||
fn vec_znx_big_automorphism_impl<R, A>(module: &Module<FFT64>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
module.ptr(),
|
||||
k,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAutomorphismInplaceImpl<FFT64> for FFT64 {
|
||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
|
||||
fn vec_znx_big_automorphism_inplace_impl<A>(module: &Module<FFT64>, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
module.ptr(),
|
||||
k,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> fmt::Display for VecZnxBig<D, FFT64> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"VecZnxBig(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {}:", col)?;
|
||||
for size in 0..self.size {
|
||||
let coeffs = self.at(col, size);
|
||||
write!(f, " Size {}: [", size)?;
|
||||
|
||||
let max_show = 100;
|
||||
let show_count = coeffs.len().min(max_show);
|
||||
|
||||
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "{}", coeff)?;
|
||||
}
|
||||
|
||||
if coeffs.len() > max_show {
|
||||
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
|
||||
}
|
||||
|
||||
writeln!(f, "]")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{ZnxInfos, ZnxSliceSize, ZnxView},
|
||||
layouts::{Data, DataRef, VecZnxBig, VecZnxBigBytesOf},
|
||||
oep::VecZnxBigAllocBytesImpl,
|
||||
},
|
||||
implementation::cpu_spqlios::module_ntt120::NTT120,
|
||||
};
|
||||
|
||||
const VEC_ZNX_BIG_NTT120_WORDSIZE: usize = 4;
|
||||
|
||||
impl<D: DataRef> ZnxView for VecZnxBig<D, NTT120> {
|
||||
type Scalar = i128;
|
||||
}
|
||||
|
||||
impl<D: Data> VecZnxBigBytesOf for VecZnxBig<D, NTT120> {
|
||||
fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
|
||||
VEC_ZNX_BIG_NTT120_WORDSIZE * n * cols * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxSliceSize for VecZnxBig<D, NTT120> {
|
||||
fn sl(&self) -> usize {
|
||||
VEC_ZNX_BIG_NTT120_WORDSIZE * self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAllocBytesImpl<NTT120> for NTT120 {
|
||||
fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
VecZnxBig::<Vec<u8>, NTT120>::bytes_of(n, cols, size)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,433 @@
|
||||
use std::fmt;
|
||||
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{TakeSlice, VecZnxDftToVecZnxBigTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
|
||||
layouts::{
|
||||
Data, DataRef, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftBytesOf, VecZnxDftOwned,
|
||||
VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
|
||||
},
|
||||
oep::{
|
||||
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftCopyImpl,
|
||||
VecZnxDftFromBytesImpl, VecZnxDftFromVecZnxImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl,
|
||||
VecZnxDftSubImpl, VecZnxDftToVecZnxBigConsumeImpl, VecZnxDftToVecZnxBigImpl, VecZnxDftToVecZnxBigTmpAImpl,
|
||||
VecZnxDftToVecZnxBigTmpBytesImpl, VecZnxDftZeroImpl,
|
||||
},
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
ffi::{vec_znx_big, vec_znx_dft},
|
||||
module_fft64::FFT64,
|
||||
},
|
||||
};
|
||||
|
||||
const VEC_ZNX_DFT_FFT64_WORDSIZE: usize = 1;
|
||||
|
||||
impl<D: Data> ZnxSliceSize for VecZnxDft<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
VEC_ZNX_DFT_FFT64_WORDSIZE * self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> VecZnxDftBytesOf for VecZnxDft<D, FFT64> {
|
||||
fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
|
||||
VEC_ZNX_DFT_FFT64_WORDSIZE * n * cols * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> ZnxView for VecZnxDft<D, FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftFromBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<FFT64> {
|
||||
VecZnxDft::<Vec<u8>, FFT64>::from_bytes(n, cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAllocBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
VecZnxDft::<Vec<u8>, FFT64>::bytes_of(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAllocImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<FFT64> {
|
||||
VecZnxDftOwned::alloc(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftToVecZnxBigTmpBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<FFT64>, n: usize) -> usize {
|
||||
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(module.ptr(), n as u64) as usize }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftToVecZnxBigImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_to_vec_znx_big_impl<R, A>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<FFT64>,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), a.n())
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_dft_to_vec_znx_big_tmp_bytes(a.n()));
|
||||
|
||||
let min_size: usize = res.size().min(a.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_znx_idft(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
|
||||
1_u64,
|
||||
a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1_u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
});
|
||||
(min_size..res.size()).for_each(|j| {
|
||||
res.zero_at(res_col, j);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
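// Editor's note (illustrative comment, not part of the crate): the DFT -> big conversion
// runs the inverse transform limb by limb through vec_znx_idft, using scratch sized by
// vec_znx_idft_tmp_bytes, and zeroes any destination limbs beyond those available in the
// source. The `tmp_a` variant below uses the kernel that is allowed to clobber its input,
// and the `consume` variant reuses the input buffer as the output.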
|
||||
|
||||
unsafe impl VecZnxDftToVecZnxBigTmpAImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_a_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_mut.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_znx_idft_tmp_a(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
|
||||
1_u64,
|
||||
a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1_u64,
|
||||
)
|
||||
});
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftToVecZnxBigConsumeImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_to_vec_znx_big_consume_impl<D: Data>(module: &Module<FFT64>, mut a: VecZnxDft<D, FFT64>) -> VecZnxBig<D, FFT64>
|
||||
where
|
||||
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
|
||||
|
||||
unsafe {
|
||||
// Rev col and rows because ZnxDft.sl() >= ZnxBig.sl()
|
||||
(0..a_mut.size()).for_each(|j| {
|
||||
(0..a_mut.cols()).for_each(|i| {
|
||||
vec_znx_dft::vec_znx_idft_tmp_a(
|
||||
module.ptr(),
|
||||
a_mut.at_mut_ptr(i, j) as *mut vec_znx_big::vec_znx_big_t,
|
||||
1_u64,
|
||||
a_mut.at_mut_ptr(i, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1_u64,
|
||||
)
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
a.into_big()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftFromVecZnxImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_from_vec_znx_impl<R, A>(
|
||||
module: &Module<FFT64>,
|
||||
step: usize,
|
||||
offset: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnx<&[u8]> = a.to_ref();
|
||||
let steps: usize = a_ref.size().div_ceil(step);
|
||||
let min_steps: usize = res_mut.size().min(steps);
|
||||
unsafe {
|
||||
(0..min_steps).for_each(|j| {
|
||||
let limb: usize = offset + j * step;
|
||||
if limb < a_ref.size() {
|
||||
vec_znx_dft::vec_znx_dft(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1_u64,
|
||||
a_ref.at_ptr(a_col, limb),
|
||||
1_u64,
|
||||
a_ref.sl() as u64,
|
||||
)
|
||||
}
|
||||
});
|
||||
(min_steps..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
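// Editor's note (illustrative comment, not part of the crate): vec_znx_dft_from_vec_znx
// forward-transforms every `step`-th limb of the coefficient-domain input, starting at
// `offset`, into the corresponding DFT limb; destination limbs with no matching source
// limb are zeroed. vec_znx_dft_copy further below follows the same step/offset convention
// without applying the transform.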
|
||||
|
||||
unsafe impl VecZnxDftAddImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_add_impl<R, A, D>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &D,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
D: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_add(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAddInplaceImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_add_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_add(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftSubImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_sub_impl<R, A, D>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &D,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
D: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftSubABInplaceImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_sub_ab_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftSubBAInplaceImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_sub_ba_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftCopyImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_copy_impl<R, A>(
|
||||
_module: &Module<FFT64>,
|
||||
step: usize,
|
||||
offset: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let steps: usize = a_ref.size().div_ceil(step);
|
||||
let min_steps: usize = res_mut.size().min(steps);
|
||||
|
||||
(0..min_steps).for_each(|j| {
|
||||
let limb: usize = offset + j * step;
|
||||
if limb < a_ref.size() {
|
||||
res_mut
|
||||
.at_mut(res_col, j)
|
||||
.copy_from_slice(a_ref.at(a_col, limb));
|
||||
}
|
||||
});
|
||||
(min_steps..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftZeroImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_zero_impl<R>(_module: &Module<FFT64>, res: &mut R)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
res.to_mut().data.fill(0);
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> fmt::Display for VecZnxDft<D, FFT64> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"VecZnxDft(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {}:", col)?;
|
||||
for size in 0..self.size {
|
||||
let coeffs = self.at(col, size);
|
||||
write!(f, " Size {}: [", size)?;
|
||||
|
||||
let max_show = 100;
|
||||
let show_count = coeffs.len().min(max_show);
|
||||
|
||||
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "{}", coeff)?;
|
||||
}
|
||||
|
||||
if coeffs.len() > max_show {
|
||||
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
|
||||
}
|
||||
|
||||
writeln!(f, "]")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{ZnxInfos, ZnxSliceSize, ZnxView},
|
||||
layouts::{Data, DataRef, VecZnxDft, VecZnxDftBytesOf, VecZnxDftOwned},
|
||||
oep::{VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl},
|
||||
},
|
||||
implementation::cpu_spqlios::module_ntt120::NTT120,
|
||||
};
|
||||
|
||||
const VEC_ZNX_DFT_NTT120_WORDSIZE: usize = 4;
|
||||
|
||||
impl<D: Data> ZnxSliceSize for VecZnxDft<D, NTT120> {
|
||||
fn sl(&self) -> usize {
|
||||
VEC_ZNX_DFT_NTT120_WORDSIZE * self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> VecZnxDftBytesOf for VecZnxDft<D, NTT120> {
|
||||
fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
|
||||
VEC_ZNX_DFT_NTT120_WORDSIZE * n * cols * size * size_of::<i64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> ZnxView for VecZnxDft<D, NTT120> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAllocBytesImpl<NTT120> for NTT120 {
|
||||
fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
VecZnxDft::<Vec<u8>, NTT120>::bytes_of(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAllocImpl<NTT120> for NTT120 {
|
||||
fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<NTT120> {
|
||||
VecZnxDftOwned::alloc(n, cols, size)
|
||||
}
|
||||
}
|
||||
298
poulpy-backend/src/implementation/cpu_spqlios/vmp_pmat_fft64.rs
Normal file
@@ -0,0 +1,298 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{TakeSlice, VmpApplyTmpBytes, VmpPrepareTmpBytes, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
layouts::{
|
||||
DataRef, MatZnx, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatBytesOf,
|
||||
VmpPMatOwned, VmpPMatToMut, VmpPMatToRef,
|
||||
},
|
||||
oep::{
|
||||
VmpApplyAddImpl, VmpApplyAddTmpBytesImpl, VmpApplyImpl, VmpApplyTmpBytesImpl, VmpPMatAllocBytesImpl,
|
||||
VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
|
||||
},
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
ffi::{vec_znx_dft::vec_znx_dft_t, vmp},
|
||||
module_fft64::FFT64,
|
||||
},
|
||||
};
|
||||
|
||||
const VMP_PMAT_FFT64_WORDSIZE: usize = 1;
|
||||
|
||||
impl<D: DataRef> ZnxView for VmpPMat<D, FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
impl VmpPMatBytesOf for FFT64 {
|
||||
fn vmp_pmat_bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
VMP_PMAT_FFT64_WORDSIZE * n * rows * cols_in * cols_out * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPMatAllocBytesImpl<FFT64> for FFT64
|
||||
where
|
||||
FFT64: VmpPMatBytesOf,
|
||||
{
|
||||
fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
FFT64::vmp_pmat_bytes_of(n, rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPMatFromBytesImpl<FFT64> for FFT64 {
|
||||
fn vmp_pmat_from_bytes_impl(
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: Vec<u8>,
|
||||
) -> VmpPMatOwned<FFT64> {
|
||||
VmpPMatOwned::from_bytes(n, rows, cols_in, cols_out, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPMatAllocImpl<FFT64> for FFT64 {
|
||||
fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<FFT64> {
|
||||
VmpPMatOwned::alloc(n, rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPrepareTmpBytesImpl<FFT64> for FFT64 {
|
||||
fn vmp_prepare_tmp_bytes_impl(
|
||||
module: &Module<FFT64>,
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_prepare_tmp_bytes(
|
||||
module.ptr(),
|
||||
n as u64,
|
||||
(rows * cols_in) as u64,
|
||||
(cols_out * size) as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPMatPrepareImpl<FFT64> for FFT64 {
|
||||
fn vmp_prepare_impl<R, A>(module: &Module<FFT64>, res: &mut R, a: &A, scratch: &mut Scratch<FFT64>)
|
||||
where
|
||||
R: VmpPMatToMut<FFT64>,
|
||||
A: MatZnxToRef,
|
||||
{
|
||||
let mut res: VmpPMat<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: MatZnx<&[u8]> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
assert_eq!(
|
||||
res.cols_in(),
|
||||
a.cols_in(),
|
||||
"res.cols_in: {} != a.cols_in: {}",
|
||||
res.cols_in(),
|
||||
a.cols_in()
|
||||
);
|
||||
assert_eq!(
|
||||
res.rows(),
|
||||
a.rows(),
|
||||
"res.rows: {} != a.rows: {}",
|
||||
res.rows(),
|
||||
a.rows()
|
||||
);
|
||||
assert_eq!(
|
||||
res.cols_out(),
|
||||
a.cols_out(),
|
||||
"res.cols_out: {} != a.cols_out: {}",
|
||||
res.cols_out(),
|
||||
a.cols_out()
|
||||
);
|
||||
assert_eq!(
|
||||
res.size(),
|
||||
a.size(),
|
||||
"res.size: {} != a.size: {}",
|
||||
res.size(),
|
||||
a.size()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) =
|
||||
scratch.take_slice(module.vmp_prepare_tmp_bytes(res.n(), a.rows(), a.cols_in(), a.cols_out(), a.size()));
|
||||
|
||||
unsafe {
|
||||
vmp::vmp_prepare_contiguous(
|
||||
module.ptr(),
|
||||
res.as_mut_ptr() as *mut vmp::vmp_pmat_t,
|
||||
a.as_ptr(),
|
||||
(a.rows() * a.cols_in()) as u64,
|
||||
(a.size() * a.cols_out()) as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
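// Editor's note (illustrative comment, not part of the crate): "preparing" a matrix turns
// a coefficient-domain MatZnx (rows x cols_in input limbs by size x cols_out output limbs)
// into the backend's packed DFT layout (vmp_pmat_t) with vmp_prepare_contiguous, so that
// the vector-matrix products below can run entirely in the DFT domain. The scratch size
// comes from vmp_prepare_tmp_bytes with the same (rows * cols_in, size * cols_out) packing.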
|
||||
|
||||
unsafe impl VmpApplyTmpBytesImpl<FFT64> for FFT64 {
|
||||
fn vmp_apply_tmp_bytes_impl(
|
||||
module: &Module<FFT64>,
|
||||
n: usize,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_tmp_bytes(
|
||||
module.ptr(),
|
||||
n as u64,
|
||||
(res_size * b_cols_out) as u64,
|
||||
(a_size * b_cols_in) as u64,
|
||||
(b_rows * b_cols_in) as u64,
|
||||
(b_size * b_cols_out) as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyImpl<FFT64> for FFT64 {
|
||||
fn vmp_apply_impl<R, A, C>(module: &Module<FFT64>, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<FFT64>)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
C: VmpPMatToRef<FFT64>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], _> = a.to_ref();
|
||||
let b: VmpPMat<&[u8], _> = b.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(b.n(), res.n());
|
||||
assert_eq!(a.n(), res.n());
|
||||
assert_eq!(
|
||||
res.cols(),
|
||||
b.cols_out(),
|
||||
"res.cols(): {} != b.cols_out: {}",
|
||||
res.cols(),
|
||||
b.cols_out()
|
||||
);
|
||||
assert_eq!(
|
||||
a.cols(),
|
||||
b.cols_in(),
|
||||
"a.cols(): {} != b.cols_in: {}",
|
||||
a.cols(),
|
||||
b.cols_in()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes(
|
||||
res.n(),
|
||||
res.size(),
|
||||
a.size(),
|
||||
b.rows(),
|
||||
b.cols_in(),
|
||||
b.cols_out(),
|
||||
b.size(),
|
||||
));
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft(
|
||||
module.ptr(),
|
||||
res.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
(res.size() * res.cols()) as u64,
|
||||
a.as_ptr() as *const vec_znx_dft_t,
|
||||
(a.size() * a.cols()) as u64,
|
||||
b.as_ptr() as *const vmp::vmp_pmat_t,
|
||||
(b.rows() * b.cols_in()) as u64,
|
||||
(b.size() * b.cols_out()) as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyAddTmpBytesImpl<FFT64> for FFT64 {
|
||||
fn vmp_apply_add_tmp_bytes_impl(
|
||||
module: &Module<FFT64>,
|
||||
n: usize,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_tmp_bytes(
|
||||
module.ptr(),
|
||||
n as u64,
|
||||
(res_size * b_cols_out) as u64,
|
||||
(a_size * b_cols_in) as u64,
|
||||
(b_rows * b_cols_in) as u64,
|
||||
(b_size * b_cols_out) as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyAddImpl<FFT64> for FFT64 {
|
||||
fn vmp_apply_add_impl<R, A, C>(module: &Module<FFT64>, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<FFT64>)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
C: VmpPMatToRef<FFT64>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], _> = a.to_ref();
|
||||
let b: VmpPMat<&[u8], _> = b.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
use crate::hal::api::ZnxInfos;
|
||||
|
||||
assert_eq!(b.n(), res.n());
|
||||
assert_eq!(a.n(), res.n());
|
||||
assert_eq!(
|
||||
res.cols(),
|
||||
b.cols_out(),
|
||||
"res.cols(): {} != b.cols_out: {}",
|
||||
res.cols(),
|
||||
b.cols_out()
|
||||
);
|
||||
assert_eq!(
|
||||
a.cols(),
|
||||
b.cols_in(),
|
||||
"a.cols(): {} != b.cols_in: {}",
|
||||
a.cols(),
|
||||
b.cols_in()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes(
|
||||
res.n(),
|
||||
res.size(),
|
||||
a.size(),
|
||||
b.rows(),
|
||||
b.cols_in(),
|
||||
b.cols_out(),
|
||||
b.size(),
|
||||
));
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_add(
|
||||
module.ptr(),
|
||||
res.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
(res.size() * res.cols()) as u64,
|
||||
a.as_ptr() as *const vec_znx_dft_t,
|
||||
(a.size() * a.cols()) as u64,
|
||||
b.as_ptr() as *const vmp::vmp_pmat_t,
|
||||
(b.rows() * b.cols_in()) as u64,
|
||||
(b.size() * b.cols_out()) as u64,
|
||||
(scale * b.cols_out()) as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
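// Editor's note (illustrative comment, not part of the crate): vmp_apply computes
// res = a x pmat on DFT-domain operands, and vmp_apply_add accumulates into res while
// forwarding the extra `scale * cols_out` argument to the spqlios kernel. Both size their
// scratch with vmp_apply_dft_to_dft_tmp_bytes using the (rows * cols_in, size * cols_out)
// packing fixed at preparation time.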
|
||||
@@ -0,0 +1,11 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::ZnxView,
|
||||
layouts::{DataRef, VmpPMat},
|
||||
},
|
||||
implementation::cpu_spqlios::module_ntt120::NTT120,
|
||||
};
|
||||
|
||||
impl<D: DataRef> ZnxView for VmpPMat<D, NTT120> {
|
||||
type Scalar = i64;
|
||||
}
Block a user