mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
Crates io (#76)
* crates re-organisation * fixed typo in layout & added test for vmp_apply * updated dependencies
This commit is contained in:
committed by
GitHub
parent
dce4d82706
commit
a1de248567
@@ -1,27 +1,28 @@
|
||||
[package]
|
||||
name = "poulpy-backend"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
license = "Apache-2.0"
|
||||
readme = "README.md"
|
||||
description = "A crate implementing bivariate polynomial arithmetic"
|
||||
repository = "https://github.com/phantomzone-org/poulpy"
|
||||
homepage = "https://github.com/phantomzone-org/poulpy"
|
||||
documentation = "https://docs.rs/poulpy"
|
||||
|
||||
[dependencies]
|
||||
rug = {workspace = true}
|
||||
criterion = {workspace = true}
|
||||
itertools = {workspace = true}
|
||||
rand = {workspace = true}
|
||||
rand_distr = {workspace = true}
|
||||
rand_core = {workspace = true}
|
||||
byteorder = {workspace = true}
|
||||
rand_chacha = "0.9.0"
|
||||
|
||||
[build-dependencies]
|
||||
cmake = "0.1.54"
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
all-features = true
|
||||
[package]
|
||||
name = "poulpy-backend"
|
||||
version = "0.1.2"
|
||||
edition = "2024"
|
||||
license = "Apache-2.0"
|
||||
readme = "README.md"
|
||||
description = "A crate providing concrete implementations of poulpy-hal through its open extension points"
|
||||
repository = "https://github.com/phantomzone-org/poulpy"
|
||||
homepage = "https://github.com/phantomzone-org/poulpy"
|
||||
documentation = "https://docs.rs/poulpy"
|
||||
|
||||
[dependencies]
|
||||
poulpy-hal = "0.1.2"
|
||||
rug = {workspace = true}
|
||||
criterion = {workspace = true}
|
||||
itertools = {workspace = true}
|
||||
rand = {workspace = true}
|
||||
rand_distr = {workspace = true}
|
||||
rand_core = {workspace = true}
|
||||
byteorder = {workspace = true}
|
||||
rand_chacha = "0.9.0"
|
||||
|
||||
[build-dependencies]
|
||||
cmake = "0.1.54"
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
all-features = true
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
@@ -1,12 +1,15 @@
|
||||
|
||||
## WSL/Ubuntu
|
||||
To use this crate you need to build spqlios-arithmetic, which is provided a as a git submodule:
|
||||
1) Initialize the sub-module
|
||||
2) $ cd backend/spqlios-arithmetic
|
||||
3) mdkir build
|
||||
4) cd build
|
||||
5) cmake ..
|
||||
6) make
|
||||
|
||||
## Others
|
||||
|
||||
|
||||
## spqlios-arithmetic
|
||||
|
||||
### WSL/Ubuntu
|
||||
To use this crate you need to build spqlios-arithmetic, which is provided a as a git submodule:
|
||||
1) Initialize the sub-module
|
||||
2) $ cd backend/spqlios-arithmetic
|
||||
3) mdkir build
|
||||
4) cd build
|
||||
5) cmake ..
|
||||
6) make
|
||||
|
||||
### Others
|
||||
Steps 3 to 6 might change depending of your platform. See [spqlios-arithmetic/wiki/build](https://github.com/tfhe/spqlios-arithmetic/wiki/build) for additional information and build options.
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
pub fn build() {
|
||||
let dst: PathBuf = cmake::Config::new("src/implementation/cpu_spqlios/spqlios-arithmetic")
|
||||
let dst: PathBuf = cmake::Config::new("src/cpu_spqlios/spqlios-arithmetic")
|
||||
.define("ENABLE_TESTING", "FALSE")
|
||||
.build();
|
||||
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
Implementors must uphold all of the following for **every** call:
|
||||
|
||||
* **Memory domains**: Pointers produced by to_ref() / to_mut() must be valid
|
||||
in the target execution domain for Self (e.g., CPU host memory for CPU,
|
||||
device memory for a specific GPU). If host↔device transfers are required,
|
||||
perform them inside the implementation; do not assume the caller synchronized.
|
||||
|
||||
* **Alignment & layout**: All data must match the layout, stride, and element
|
||||
size expected by the kernel. size(), rows(), cols_in(), cols_out(),
|
||||
n(), etc... must be interpreted identically to the reference CPU implementation.
|
||||
|
||||
* **Scratch lifetime**: Any scratch obtained from scratch.tmp_slice(...) (or a
|
||||
backend-specific variant) must remain valid for the duration of the call; it
|
||||
may be reused by the caller afterwards. Do not retain pointers past return.
|
||||
|
||||
* **Synchronization**: The call must appear **logically synchronous** to the
|
||||
caller. If you enqueue asynchronous work (e.g., CUDA streams), you must
|
||||
ensure completion before returning or clearly document and implement a
|
||||
synchronization contract used by all backends consistently.
|
||||
|
||||
* **Aliasing & overlaps**: If res, a, b, etc... alias or overlap in ways
|
||||
that violate your kernel’s requirements, you must either handle safely or reject
|
||||
with a defined error path (e.g., debug assert). Never trigger UB.
|
||||
|
||||
* **Numerical contract**: For modular/integer arithmetic, results must be
|
||||
bit-exact to the specification. For floating-point, any permitted tolerance
|
||||
must be documented and consistent with the crate’s guarantees.
|
||||
@@ -1,15 +1,13 @@
|
||||
use itertools::izip;
|
||||
use poulpy_backend::{
|
||||
hal::{
|
||||
api::{
|
||||
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPrepare, VecZnxAddNormal,
|
||||
VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace,
|
||||
VecZnxDftAlloc, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigTmpA, VecZnxFillUniform, VecZnxNormalizeInplace, ZnxInfos,
|
||||
},
|
||||
layouts::{Module, ScalarZnx, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft},
|
||||
source::Source,
|
||||
use poulpy_backend::cpu_spqlios::FFT64;
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPrepare, VecZnxAddNormal,
|
||||
VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace,
|
||||
VecZnxDftAlloc, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigTmpA, VecZnxFillUniform, VecZnxNormalizeInplace, ZnxInfos,
|
||||
},
|
||||
implementation::cpu_spqlios::FFT64,
|
||||
layouts::{Module, ScalarZnx, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
fn main() {
|
||||
|
||||
15
poulpy-backend/src/cpu_spqlios/ffi/mod.rs
Normal file
15
poulpy-backend/src/cpu_spqlios/ffi/mod.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
#[allow(non_camel_case_types)]
|
||||
pub mod module;
|
||||
#[allow(non_camel_case_types)]
|
||||
pub mod svp;
|
||||
#[allow(non_camel_case_types)]
|
||||
pub mod vec_znx;
|
||||
#[allow(dead_code)]
|
||||
#[allow(non_camel_case_types)]
|
||||
pub mod vec_znx_big;
|
||||
#[allow(non_camel_case_types)]
|
||||
pub mod vec_znx_dft;
|
||||
#[allow(non_camel_case_types)]
|
||||
pub mod vmp;
|
||||
#[allow(non_camel_case_types)]
|
||||
pub mod znx;
|
||||
@@ -1,19 +1,17 @@
|
||||
pub struct module_info_t {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
|
||||
pub type module_type_t = ::std::os::raw::c_uint;
|
||||
pub use self::module_type_t as MODULE_TYPE;
|
||||
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
pub type MODULE = module_info_t;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_module_info(N: u64, mode: MODULE_TYPE) -> *mut MODULE;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn delete_module_info(module_info: *mut MODULE);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn module_get_n(module: *const MODULE) -> u64;
|
||||
}
|
||||
#[repr(C)]
|
||||
pub struct module_info_t {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
|
||||
pub type module_type_t = ::std::os::raw::c_uint;
|
||||
pub use self::module_type_t as MODULE_TYPE;
|
||||
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
pub type MODULE = module_info_t;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_module_info(N: u64, mode: MODULE_TYPE) -> *mut MODULE;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn delete_module_info(module_info: *mut MODULE);
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::implementation::cpu_spqlios::ffi::{module::MODULE, vec_znx_dft::VEC_ZNX_DFT};
|
||||
use crate::cpu_spqlios::ffi::{module::MODULE, vec_znx_dft::VEC_ZNX_DFT};
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
@@ -7,20 +7,11 @@ pub struct svp_ppol_t {
|
||||
}
|
||||
pub type SVP_PPOL = svp_ppol_t;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn bytes_of_svp_ppol(module: *const MODULE) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_svp_ppol(module: *const MODULE) -> *mut SVP_PPOL;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn delete_svp_ppol(res: *mut SVP_PPOL);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn svp_prepare(module: *const MODULE, ppol: *mut SVP_PPOL, pol: *const i64);
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn svp_apply_dft(
|
||||
module: *const MODULE,
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::implementation::cpu_spqlios::ffi::module::MODULE;
|
||||
use crate::cpu_spqlios::ffi::module::MODULE;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_add(
|
||||
@@ -53,6 +53,7 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_rotate(
|
||||
module: *const MODULE,
|
||||
@@ -81,9 +82,12 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_zero(module: *const MODULE, res: *mut i64, res_size: u64, res_sl: u64);
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_copy(
|
||||
module: *const MODULE,
|
||||
@@ -1,163 +1,153 @@
|
||||
use crate::implementation::cpu_spqlios::ffi::module::MODULE;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct vec_znx_big_t {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub type VEC_ZNX_BIG = vec_znx_big_t;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn bytes_of_vec_znx_big(module: *const MODULE, size: u64) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_vec_znx_big(module: *const MODULE, size: u64) -> *mut VEC_ZNX_BIG;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn delete_vec_znx_big(res: *mut VEC_ZNX_BIG);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_add(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_BIG,
|
||||
a_size: u64,
|
||||
b: *const VEC_ZNX_BIG,
|
||||
b_size: u64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_add_small(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_BIG,
|
||||
a_size: u64,
|
||||
b: *const i64,
|
||||
b_size: u64,
|
||||
b_sl: u64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_add_small2(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const i64,
|
||||
a_size: u64,
|
||||
a_sl: u64,
|
||||
b: *const i64,
|
||||
b_size: u64,
|
||||
b_sl: u64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_sub(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_BIG,
|
||||
a_size: u64,
|
||||
b: *const VEC_ZNX_BIG,
|
||||
b_size: u64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_sub_small_b(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_BIG,
|
||||
a_size: u64,
|
||||
b: *const i64,
|
||||
b_size: u64,
|
||||
b_sl: u64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_sub_small_a(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const i64,
|
||||
a_size: u64,
|
||||
a_sl: u64,
|
||||
b: *const VEC_ZNX_BIG,
|
||||
b_size: u64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_sub_small2(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const i64,
|
||||
a_size: u64,
|
||||
a_sl: u64,
|
||||
b: *const i64,
|
||||
b_size: u64,
|
||||
b_sl: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_normalize_base2k(
|
||||
module: *const MODULE,
|
||||
n: u64,
|
||||
log2_base2k: u64,
|
||||
res: *mut i64,
|
||||
res_size: u64,
|
||||
res_sl: u64,
|
||||
a: *const VEC_ZNX_BIG,
|
||||
a_size: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_range_normalize_base2k(
|
||||
module: *const MODULE,
|
||||
n: u64,
|
||||
log2_base2k: u64,
|
||||
res: *mut i64,
|
||||
res_size: u64,
|
||||
res_sl: u64,
|
||||
a: *const VEC_ZNX_BIG,
|
||||
a_range_begin: u64,
|
||||
a_range_xend: u64,
|
||||
a_range_step: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_automorphism(
|
||||
module: *const MODULE,
|
||||
p: i64,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_BIG,
|
||||
a_size: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_rotate(
|
||||
module: *const MODULE,
|
||||
p: i64,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_BIG,
|
||||
a_size: u64,
|
||||
);
|
||||
}
|
||||
use crate::cpu_spqlios::ffi::module::MODULE;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct vec_znx_big_t {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub type VEC_ZNX_BIG = vec_znx_big_t;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_add(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_BIG,
|
||||
a_size: u64,
|
||||
b: *const VEC_ZNX_BIG,
|
||||
b_size: u64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_add_small(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_BIG,
|
||||
a_size: u64,
|
||||
b: *const i64,
|
||||
b_size: u64,
|
||||
b_sl: u64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_add_small2(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const i64,
|
||||
a_size: u64,
|
||||
a_sl: u64,
|
||||
b: *const i64,
|
||||
b_size: u64,
|
||||
b_sl: u64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_sub(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_BIG,
|
||||
a_size: u64,
|
||||
b: *const VEC_ZNX_BIG,
|
||||
b_size: u64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_sub_small_b(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_BIG,
|
||||
a_size: u64,
|
||||
b: *const i64,
|
||||
b_size: u64,
|
||||
b_sl: u64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_sub_small_a(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const i64,
|
||||
a_size: u64,
|
||||
a_sl: u64,
|
||||
b: *const VEC_ZNX_BIG,
|
||||
b_size: u64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_sub_small2(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const i64,
|
||||
a_size: u64,
|
||||
a_sl: u64,
|
||||
b: *const i64,
|
||||
b_size: u64,
|
||||
b_sl: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_normalize_base2k(
|
||||
module: *const MODULE,
|
||||
n: u64,
|
||||
log2_base2k: u64,
|
||||
res: *mut i64,
|
||||
res_size: u64,
|
||||
res_sl: u64,
|
||||
a: *const VEC_ZNX_BIG,
|
||||
a_size: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_range_normalize_base2k(
|
||||
module: *const MODULE,
|
||||
n: u64,
|
||||
log2_base2k: u64,
|
||||
res: *mut i64,
|
||||
res_size: u64,
|
||||
res_sl: u64,
|
||||
a: *const VEC_ZNX_BIG,
|
||||
a_range_begin: u64,
|
||||
a_range_xend: u64,
|
||||
a_range_step: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_automorphism(
|
||||
module: *const MODULE,
|
||||
p: i64,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_BIG,
|
||||
a_size: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_rotate(
|
||||
module: *const MODULE,
|
||||
p: i64,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_BIG,
|
||||
a_size: u64,
|
||||
);
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::implementation::cpu_spqlios::ffi::{module::MODULE, vec_znx_big::VEC_ZNX_BIG};
|
||||
use crate::cpu_spqlios::ffi::{module::MODULE, vec_znx_big::VEC_ZNX_BIG};
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
@@ -7,19 +7,6 @@ pub struct vec_znx_dft_t {
|
||||
}
|
||||
pub type VEC_ZNX_DFT = vec_znx_dft_t;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn bytes_of_vec_znx_dft(module: *const MODULE, size: u64) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_vec_znx_dft(module: *const MODULE, size: u64) -> *mut VEC_ZNX_DFT;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn delete_vec_znx_dft(res: *mut VEC_ZNX_DFT);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_dft_zero(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_dft_add(
|
||||
module: *const MODULE,
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::implementation::cpu_spqlios::ffi::{module::MODULE, vec_znx_dft::VEC_ZNX_DFT};
|
||||
use crate::cpu_spqlios::ffi::{module::MODULE, vec_znx_dft::VEC_ZNX_DFT};
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
@@ -9,16 +9,7 @@ pub struct vmp_pmat_t {
|
||||
// [rows][cols] = [#Decomposition][#Limbs]
|
||||
pub type VMP_PMAT = vmp_pmat_t;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn bytes_of_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> *mut VMP_PMAT;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn delete_vmp_pmat(res: *mut VMP_PMAT);
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft(
|
||||
module: *const MODULE,
|
||||
@@ -34,6 +25,7 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_add(
|
||||
module: *const MODULE,
|
||||
@@ -50,6 +42,7 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_tmp_bytes(module: *const MODULE, res_size: u64, a_size: u64, nrows: u64, ncols: u64) -> u64;
|
||||
}
|
||||
@@ -105,10 +98,6 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_contiguous_dft(module: *const MODULE, pmat: *mut VMP_PMAT, mat: *const f64, nrows: u64, ncols: u64);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nn: u64, nrows: u64, ncols: u64) -> u64;
|
||||
}
|
||||
7
poulpy-backend/src/cpu_spqlios/ffi/znx.rs
Normal file
7
poulpy-backend/src/cpu_spqlios/ffi/znx.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_rotate_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_rotate_inplace_i64(nn: u64, p: i64, res: *mut i64);
|
||||
}
|
||||
15
poulpy-backend/src/cpu_spqlios/fft64/mod.rs
Normal file
15
poulpy-backend/src/cpu_spqlios/fft64/mod.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
mod module;
|
||||
mod scratch;
|
||||
mod svp_ppol;
|
||||
mod vec_znx;
|
||||
mod vec_znx_big;
|
||||
mod vec_znx_dft;
|
||||
mod vmp_pmat;
|
||||
|
||||
pub use module::FFT64;
|
||||
|
||||
/// For external documentation
|
||||
pub use vec_znx::{
|
||||
vec_znx_copy_ref, vec_znx_lsh_inplace_ref, vec_znx_merge_ref, vec_znx_rsh_inplace_ref, vec_znx_split_ref,
|
||||
vec_znx_switch_degree_ref,
|
||||
};
|
||||
@@ -1,25 +1,29 @@
|
||||
use std::ptr::NonNull;
|
||||
|
||||
use crate::{
|
||||
hal::{
|
||||
layouts::{Backend, Module},
|
||||
oep::ModuleNewImpl,
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
CPUAVX,
|
||||
ffi::module::{MODULE, delete_module_info, new_module_info},
|
||||
},
|
||||
use poulpy_hal::{
|
||||
layouts::{Backend, Module},
|
||||
oep::ModuleNewImpl,
|
||||
};
|
||||
|
||||
use crate::cpu_spqlios::ffi::module::{MODULE, delete_module_info, new_module_info};
|
||||
|
||||
pub struct FFT64;
|
||||
|
||||
impl CPUAVX for FFT64 {}
|
||||
|
||||
impl Backend for FFT64 {
|
||||
type ScalarPrep = f64;
|
||||
type ScalarBig = i64;
|
||||
type Handle = MODULE;
|
||||
unsafe fn destroy(handle: NonNull<Self::Handle>) {
|
||||
unsafe { delete_module_info(handle.as_ptr()) }
|
||||
}
|
||||
|
||||
fn layout_big_word_count() -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn layout_prep_word_count() -> usize {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ModuleNewImpl<Self> for FFT64 {
|
||||
@@ -1,24 +1,20 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
use poulpy_hal::{
|
||||
DEFAULTALIGN, alloc_aligned,
|
||||
hal::{
|
||||
api::ScratchFromBytes,
|
||||
layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
|
||||
oep::{
|
||||
ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, SvpPPolAllocBytesImpl,
|
||||
TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl,
|
||||
TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, VecZnxBigAllocBytesImpl,
|
||||
VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl,
|
||||
},
|
||||
api::ScratchFromBytes,
|
||||
layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
|
||||
oep::{
|
||||
ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, SvpPPolAllocBytesImpl,
|
||||
TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl,
|
||||
TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, VecZnxBigAllocBytesImpl,
|
||||
VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl,
|
||||
},
|
||||
implementation::cpu_spqlios::CPUAVX,
|
||||
};
|
||||
|
||||
unsafe impl<B: Backend> ScratchOwnedAllocImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
use crate::cpu_spqlios::FFT64;
|
||||
|
||||
unsafe impl<B: Backend> ScratchOwnedAllocImpl<B> for FFT64 {
|
||||
fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned<B> {
|
||||
let data: Vec<u8> = alloc_aligned(size);
|
||||
ScratchOwned {
|
||||
@@ -28,28 +24,22 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScratchOwnedBorrowImpl<B> for B
|
||||
unsafe impl<B: Backend> ScratchOwnedBorrowImpl<B> for FFT64
|
||||
where
|
||||
B: CPUAVX,
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn scratch_owned_borrow_impl(scratch: &mut ScratchOwned<B>) -> &mut Scratch<B> {
|
||||
Scratch::from_bytes(&mut scratch.data)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScratchFromBytesImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
unsafe impl<B: Backend> ScratchFromBytesImpl<B> for FFT64 {
|
||||
fn scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch<B> {
|
||||
unsafe { &mut *(data as *mut [u8] as *mut Scratch<B>) }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScratchAvailableImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
unsafe impl<B: Backend> ScratchAvailableImpl<B> for FFT64 {
|
||||
fn scratch_available_impl(scratch: &Scratch<B>) -> usize {
|
||||
let ptr: *const u8 = scratch.data.as_ptr();
|
||||
let self_len: usize = scratch.data.len();
|
||||
@@ -58,9 +48,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeSliceImpl<B> for B
|
||||
unsafe impl<B: Backend> TakeSliceImpl<B> for FFT64
|
||||
where
|
||||
B: CPUAVX,
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_slice_impl<T>(scratch: &mut Scratch<B>, len: usize) -> (&mut [T], &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, len * std::mem::size_of::<T>());
|
||||
@@ -74,9 +64,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeScalarZnxImpl<B> for B
|
||||
unsafe impl<B: Backend> TakeScalarZnxImpl<B> for FFT64
|
||||
where
|
||||
B: CPUAVX,
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_scalar_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, ScalarZnx::alloc_bytes(n, cols));
|
||||
@@ -87,9 +77,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeSvpPPolImpl<B> for B
|
||||
unsafe impl<B: Backend> TakeSvpPPolImpl<B> for FFT64
|
||||
where
|
||||
B: CPUAVX + SvpPPolAllocBytesImpl<B>,
|
||||
B: SvpPPolAllocBytesImpl<B> + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_svp_ppol_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::svp_ppol_alloc_bytes_impl(n, cols));
|
||||
@@ -100,9 +90,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxImpl<B> for B
|
||||
unsafe impl<B: Backend> TakeVecZnxImpl<B> for FFT64
|
||||
where
|
||||
B: CPUAVX,
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, VecZnx::alloc_bytes(n, cols, size));
|
||||
@@ -113,9 +103,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxBigImpl<B> for B
|
||||
unsafe impl<B: Backend> TakeVecZnxBigImpl<B> for FFT64
|
||||
where
|
||||
B: CPUAVX + VecZnxBigAllocBytesImpl<B>,
|
||||
B: VecZnxBigAllocBytesImpl<B> + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_big_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
@@ -134,9 +124,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxDftImpl<B> for B
|
||||
unsafe impl<B: Backend> TakeVecZnxDftImpl<B> for FFT64
|
||||
where
|
||||
B: CPUAVX + VecZnxDftAllocBytesImpl<B>,
|
||||
B: VecZnxDftAllocBytesImpl<B> + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_dft_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
@@ -156,9 +146,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxDftSliceImpl<B> for B
|
||||
unsafe impl<B: Backend> TakeVecZnxDftSliceImpl<B> for FFT64
|
||||
where
|
||||
B: CPUAVX + VecZnxDftAllocBytesImpl<B>,
|
||||
B: VecZnxDftAllocBytesImpl<B> + ScratchFromBytesImpl<B> + TakeVecZnxDftImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_dft_slice_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
@@ -178,9 +168,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxSliceImpl<B> for B
|
||||
unsafe impl<B: Backend> TakeVecZnxSliceImpl<B> for FFT64
|
||||
where
|
||||
B: CPUAVX,
|
||||
B: ScratchFromBytesImpl<B> + TakeVecZnxImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_slice_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
@@ -200,9 +190,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVmpPMatImpl<B> for B
|
||||
unsafe impl<B: Backend> TakeVmpPMatImpl<B> for FFT64
|
||||
where
|
||||
B: CPUAVX + VmpPMatAllocBytesImpl<B>,
|
||||
B: VmpPMatAllocBytesImpl<B> + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_vmp_pmat_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
@@ -223,9 +213,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeMatZnxImpl<B> for B
|
||||
unsafe impl<B: Backend> TakeMatZnxImpl<B> for FFT64
|
||||
where
|
||||
B: CPUAVX,
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_mat_znx_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
@@ -1,35 +1,16 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut},
|
||||
layouts::{
|
||||
Data, DataRef, Module, ScalarZnxToRef, SvpPPol, SvpPPolBytesOf, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDft,
|
||||
VecZnxDftToMut, VecZnxDftToRef,
|
||||
},
|
||||
oep::{SvpApplyImpl, SvpApplyInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl},
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
ffi::{svp, vec_znx_dft::vec_znx_dft_t},
|
||||
module_fft64::FFT64,
|
||||
use poulpy_hal::{
|
||||
api::{ZnxInfos, ZnxView, ZnxViewMut},
|
||||
layouts::{
|
||||
Backend, Module, ScalarZnxToRef, SvpPPol, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDft, VecZnxDftToMut,
|
||||
VecZnxDftToRef,
|
||||
},
|
||||
oep::{SvpApplyImpl, SvpApplyInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl},
|
||||
};
|
||||
|
||||
const SVP_PPOL_FFT64_WORD_SIZE: usize = 1;
|
||||
|
||||
impl<D: Data> SvpPPolBytesOf for SvpPPol<D, FFT64> {
|
||||
fn bytes_of(n: usize, cols: usize) -> usize {
|
||||
SVP_PPOL_FFT64_WORD_SIZE * n * cols * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxSliceSize for SvpPPol<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
SVP_PPOL_FFT64_WORD_SIZE * self.n()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> ZnxView for SvpPPol<D, FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
use crate::cpu_spqlios::{
|
||||
FFT64,
|
||||
ffi::{svp, vec_znx_dft::vec_znx_dft_t},
|
||||
};
|
||||
|
||||
unsafe impl SvpPPolFromBytesImpl<Self> for FFT64 {
|
||||
fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<Self> {
|
||||
@@ -45,7 +26,7 @@ unsafe impl SvpPPolAllocImpl<Self> for FFT64 {
|
||||
|
||||
unsafe impl SvpPPolAllocBytesImpl<Self> for FFT64 {
|
||||
fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize {
|
||||
SvpPPol::<Vec<u8>, Self>::bytes_of(n, cols)
|
||||
FFT64::layout_prep_word_count() * n * cols * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,48 +1,47 @@
|
||||
use itertools::izip;
|
||||
use rand_distr::Normal;
|
||||
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{
|
||||
TakeSlice, TakeVecZnx, VecZnxAddDistF64, VecZnxCopy, VecZnxFillDistF64, VecZnxNormalizeTmpBytes, VecZnxRotate,
|
||||
VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
|
||||
},
|
||||
layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef},
|
||||
oep::{
|
||||
VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
|
||||
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl,
|
||||
VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl,
|
||||
VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
|
||||
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
|
||||
VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl,
|
||||
VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
|
||||
},
|
||||
source::Source,
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
TakeSlice, TakeVecZnx, VecZnxAddDistF64, VecZnxCopy, VecZnxFillDistF64, VecZnxNormalizeTmpBytes, VecZnxRotate,
|
||||
VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
CPUAVX,
|
||||
ffi::{module::module_info_t, vec_znx, znx},
|
||||
layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef},
|
||||
oep::{
|
||||
TakeSliceImpl, TakeVecZnxImpl, VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl,
|
||||
VecZnxAddScalarInplaceImpl, VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl,
|
||||
VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl,
|
||||
VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
|
||||
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl,
|
||||
VecZnxSplitImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl,
|
||||
VecZnxSwithcDegreeImpl,
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
unsafe impl<B: Backend> VecZnxNormalizeTmpBytesImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<B>, n: usize) -> usize {
|
||||
use crate::cpu_spqlios::{
|
||||
FFT64,
|
||||
ffi::{module::module_info_t, vec_znx, znx},
|
||||
};
|
||||
|
||||
unsafe impl VecZnxNormalizeTmpBytesImpl<Self> for FFT64 {
|
||||
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<Self>, n: usize) -> usize {
|
||||
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr() as *const module_info_t, n as u64) as usize }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxNormalizeImpl<B> for B {
|
||||
unsafe impl VecZnxNormalizeImpl<Self> for FFT64
|
||||
where
|
||||
Self: TakeSliceImpl<Self> + VecZnxNormalizeTmpBytesImpl<Self>,
|
||||
{
|
||||
fn vec_znx_normalize_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
@@ -74,9 +73,17 @@ unsafe impl<B: Backend + CPUAVX> VecZnxNormalizeImpl<B> for B {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxNormalizeInplaceImpl<B> for B {
|
||||
fn vec_znx_normalize_inplace_impl<A>(module: &Module<B>, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
unsafe impl VecZnxNormalizeInplaceImpl<Self> for FFT64
|
||||
where
|
||||
Self: TakeSliceImpl<Self> + VecZnxNormalizeTmpBytesImpl<Self>,
|
||||
{
|
||||
fn vec_znx_normalize_inplace_impl<A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
a: &mut A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||
@@ -100,8 +107,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxNormalizeInplaceImpl<B> for B {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxAddImpl<B> for B {
|
||||
fn vec_znx_add_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
unsafe impl VecZnxAddImpl<Self> for FFT64 {
|
||||
fn vec_znx_add_impl<R, A, C>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
@@ -134,8 +141,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxAddImpl<B> for B {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxAddInplaceImpl<B> for B {
|
||||
fn vec_znx_add_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
unsafe impl VecZnxAddInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_add_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
@@ -164,9 +171,9 @@ unsafe impl<B: Backend + CPUAVX> VecZnxAddInplaceImpl<B> for B {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxAddScalarInplaceImpl<B> for B {
|
||||
unsafe impl VecZnxAddScalarInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_add_scalar_inplace_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
res_limb: usize,
|
||||
@@ -201,8 +208,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxAddScalarInplaceImpl<B> for B {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxSubImpl<B> for B {
|
||||
fn vec_znx_sub_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
unsafe impl VecZnxSubImpl<Self> for FFT64 {
|
||||
fn vec_znx_sub_impl<R, A, C>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
@@ -235,8 +242,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxSubImpl<B> for B {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxSubABInplaceImpl<B> for B {
|
||||
fn vec_znx_sub_ab_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
unsafe impl VecZnxSubABInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
@@ -264,8 +271,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxSubABInplaceImpl<B> for B {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxSubBAInplaceImpl<B> for B {
|
||||
fn vec_znx_sub_ba_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
unsafe impl VecZnxSubBAInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
@@ -293,9 +300,9 @@ unsafe impl<B: Backend + CPUAVX> VecZnxSubBAInplaceImpl<B> for B {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxSubScalarInplaceImpl<B> for B {
|
||||
unsafe impl VecZnxSubScalarInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_sub_scalar_inplace_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
res_limb: usize,
|
||||
@@ -330,8 +337,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxSubScalarInplaceImpl<B> for B {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxNegateImpl<B> for B {
|
||||
fn vec_znx_negate_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
unsafe impl VecZnxNegateImpl<Self> for FFT64 {
|
||||
fn vec_znx_negate_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
@@ -356,8 +363,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxNegateImpl<B> for B {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxNegateInplaceImpl<B> for B {
|
||||
fn vec_znx_negate_inplace_impl<A>(module: &Module<B>, a: &mut A, a_col: usize)
|
||||
unsafe impl VecZnxNegateInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_negate_inplace_impl<A>(module: &Module<Self>, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
@@ -376,8 +383,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxNegateInplaceImpl<B> for B {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxLshInplaceImpl<B> for B {
|
||||
fn vec_znx_lsh_inplace_impl<A>(_module: &Module<B>, basek: usize, k: usize, a: &mut A)
|
||||
unsafe impl VecZnxLshInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_lsh_inplace_impl<A>(_module: &Module<Self>, basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
@@ -417,8 +424,8 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxRshInplaceImpl<B> for B {
|
||||
fn vec_znx_rsh_inplace_impl<A>(_module: &Module<B>, basek: usize, k: usize, a: &mut A)
|
||||
unsafe impl VecZnxRshInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_rsh_inplace_impl<A>(_module: &Module<Self>, basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
@@ -461,8 +468,8 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxRotateImpl<B> for B {
|
||||
fn vec_znx_rotate_impl<R, A>(_module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
unsafe impl VecZnxRotateImpl<Self> for FFT64 {
|
||||
fn vec_znx_rotate_impl<R, A>(_module: &Module<Self>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
@@ -486,8 +493,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxRotateImpl<B> for B {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxRotateInplaceImpl<B> for B {
|
||||
fn vec_znx_rotate_inplace_impl<A>(_module: &Module<B>, k: i64, a: &mut A, a_col: usize)
|
||||
unsafe impl VecZnxRotateInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_rotate_inplace_impl<A>(_module: &Module<Self>, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
@@ -500,8 +507,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxRotateInplaceImpl<B> for B {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxAutomorphismImpl<B> for B {
|
||||
fn vec_znx_automorphism_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
unsafe impl VecZnxAutomorphismImpl<Self> for FFT64 {
|
||||
fn vec_znx_automorphism_impl<R, A>(module: &Module<Self>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
@@ -527,8 +534,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxAutomorphismImpl<B> for B {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxAutomorphismInplaceImpl<B> for B {
|
||||
fn vec_znx_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
|
||||
unsafe impl VecZnxAutomorphismInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_automorphism_inplace_impl<A>(module: &Module<Self>, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
@@ -556,8 +563,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxAutomorphismInplaceImpl<B> for B {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxMulXpMinusOneImpl<B> for B {
|
||||
fn vec_znx_mul_xp_minus_one_impl<R, A>(module: &Module<B>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
unsafe impl VecZnxMulXpMinusOneImpl<Self> for FFT64 {
|
||||
fn vec_znx_mul_xp_minus_one_impl<R, A>(module: &Module<Self>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
@@ -584,8 +591,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxMulXpMinusOneImpl<B> for B {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxMulXpMinusOneInplaceImpl<B> for B {
|
||||
fn vec_znx_mul_xp_minus_one_inplace_impl<R>(module: &Module<B>, p: i64, res: &mut R, res_col: usize)
|
||||
unsafe impl VecZnxMulXpMinusOneInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_mul_xp_minus_one_inplace_impl<R>(module: &Module<Self>, p: i64, res: &mut R, res_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
@@ -609,9 +616,22 @@ unsafe impl<B: Backend + CPUAVX> VecZnxMulXpMinusOneInplaceImpl<B> for B {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxSplitImpl<B> for B {
|
||||
fn vec_znx_split_impl<R, A>(module: &Module<B>, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
unsafe impl VecZnxSplitImpl<Self> for FFT64
|
||||
where
|
||||
Self: TakeVecZnxImpl<Self>
|
||||
+ TakeVecZnxImpl<Self>
|
||||
+ VecZnxSwithcDegreeImpl<Self>
|
||||
+ VecZnxRotateImpl<Self>
|
||||
+ VecZnxRotateInplaceImpl<Self>,
|
||||
{
|
||||
fn vec_znx_split_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
res: &mut [R],
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
@@ -627,7 +647,7 @@ pub fn vec_znx_split_ref<R, A, B>(
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
B: Backend + CPUAVX,
|
||||
B: Backend + TakeVecZnxImpl<B> + VecZnxSwithcDegreeImpl<B> + VecZnxRotateImpl<B> + VecZnxRotateInplaceImpl<B>,
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
@@ -660,8 +680,11 @@ pub fn vec_znx_split_ref<R, A, B>(
|
||||
})
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxMergeImpl<B> for B {
|
||||
fn vec_znx_merge_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &[A], a_col: usize)
|
||||
unsafe impl VecZnxMergeImpl<Self> for FFT64
|
||||
where
|
||||
Self: VecZnxSwithcDegreeImpl<Self> + VecZnxRotateInplaceImpl<Self>,
|
||||
{
|
||||
fn vec_znx_merge_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &[A], a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
@@ -672,7 +695,7 @@ unsafe impl<B: Backend + CPUAVX> VecZnxMergeImpl<B> for B {
|
||||
|
||||
pub fn vec_znx_merge_ref<R, A, B>(module: &Module<B>, res: &mut R, res_col: usize, a: &[A], a_col: usize)
|
||||
where
|
||||
B: Backend + CPUAVX,
|
||||
B: Backend + VecZnxSwithcDegreeImpl<B> + VecZnxRotateInplaceImpl<B>,
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
@@ -700,8 +723,11 @@ where
|
||||
module.vec_znx_rotate_inplace(a.len() as i64, &mut res, res_col);
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxSwithcDegreeImpl<B> for B {
|
||||
fn vec_znx_switch_degree_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
unsafe impl VecZnxSwithcDegreeImpl<Self> for FFT64
|
||||
where
|
||||
Self: VecZnxCopyImpl<Self>,
|
||||
{
|
||||
fn vec_znx_switch_degree_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
@@ -712,7 +738,7 @@ unsafe impl<B: Backend + CPUAVX> VecZnxSwithcDegreeImpl<B> for B {
|
||||
|
||||
pub fn vec_znx_switch_degree_ref<R, A, B>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
B: Backend + CPUAVX,
|
||||
B: Backend + VecZnxCopyImpl<B>,
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
@@ -745,8 +771,8 @@ where
|
||||
});
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxCopyImpl<B> for B {
|
||||
fn vec_znx_copy_impl<R, A>(_module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
unsafe impl VecZnxCopyImpl<Self> for FFT64 {
|
||||
fn vec_znx_copy_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
@@ -775,9 +801,15 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxFillUniformImpl<B> for B {
|
||||
fn vec_znx_fill_uniform_impl<R>(_module: &Module<B>, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
|
||||
where
|
||||
unsafe impl VecZnxFillUniformImpl<Self> for FFT64 {
|
||||
fn vec_znx_fill_uniform_impl<R>(
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = res.to_mut();
|
||||
@@ -792,9 +824,9 @@ unsafe impl<B: Backend + CPUAVX> VecZnxFillUniformImpl<B> for B {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxFillDistF64Impl<B> for B {
|
||||
unsafe impl VecZnxFillDistF64Impl<Self> for FFT64 {
|
||||
fn vec_znx_fill_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
|
||||
_module: &Module<B>,
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
@@ -835,9 +867,9 @@ unsafe impl<B: Backend + CPUAVX> VecZnxFillDistF64Impl<B> for B {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxAddDistF64Impl<B> for B {
|
||||
unsafe impl VecZnxAddDistF64Impl<Self> for FFT64 {
|
||||
fn vec_znx_add_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
|
||||
_module: &Module<B>,
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
@@ -878,9 +910,12 @@ unsafe impl<B: Backend + CPUAVX> VecZnxAddDistF64Impl<B> for B {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxFillNormalImpl<B> for B {
|
||||
unsafe impl VecZnxFillNormalImpl<Self> for FFT64
|
||||
where
|
||||
Self: VecZnxFillDistF64Impl<Self>,
|
||||
{
|
||||
fn vec_znx_fill_normal_impl<R>(
|
||||
module: &Module<B>,
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
@@ -903,9 +938,12 @@ unsafe impl<B: Backend + CPUAVX> VecZnxFillNormalImpl<B> for B {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend + CPUAVX> VecZnxAddNormalImpl<B> for B {
|
||||
unsafe impl VecZnxAddNormalImpl<Self> for FFT64
|
||||
where
|
||||
Self: VecZnxAddDistF64Impl<Self>,
|
||||
{
|
||||
fn vec_znx_add_normal_impl<R>(
|
||||
module: &Module<B>,
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
@@ -1,69 +1,46 @@
|
||||
use std::fmt;
|
||||
|
||||
use rand_distr::{Distribution, Normal};
|
||||
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{
|
||||
TakeSlice, VecZnxBigAddDistF64, VecZnxBigFillDistF64, VecZnxBigNormalizeTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView,
|
||||
ZnxViewMut,
|
||||
},
|
||||
layouts::{
|
||||
Data, DataRef, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigBytesOf, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef,
|
||||
VecZnxToMut, VecZnxToRef,
|
||||
},
|
||||
oep::{
|
||||
VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl,
|
||||
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
|
||||
VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl, VecZnxBigFromBytesImpl,
|
||||
VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl,
|
||||
VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl,
|
||||
VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
|
||||
},
|
||||
source::Source,
|
||||
use crate::cpu_spqlios::{FFT64, ffi::vec_znx};
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
TakeSlice, VecZnxBigAddDistF64, VecZnxBigFillDistF64, VecZnxBigNormalizeTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView,
|
||||
ZnxViewMut,
|
||||
},
|
||||
implementation::cpu_spqlios::{ffi::vec_znx, module_fft64::FFT64},
|
||||
layouts::{
|
||||
Backend, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef,
|
||||
},
|
||||
oep::{
|
||||
TakeSliceImpl, VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl,
|
||||
VecZnxBigAddSmallImpl, VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl,
|
||||
VecZnxBigAutomorphismImpl, VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl,
|
||||
VecZnxBigFromBytesImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl,
|
||||
VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl,
|
||||
VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
const VEC_ZNX_BIG_FFT64_WORDSIZE: usize = 1;
|
||||
|
||||
impl<D: DataRef> ZnxView for VecZnxBig<D, FFT64> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
impl<D: Data> VecZnxBigBytesOf for VecZnxBig<D, FFT64> {
|
||||
fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
|
||||
VEC_ZNX_BIG_FFT64_WORDSIZE * n * cols * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxSliceSize for VecZnxBig<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
VEC_ZNX_BIG_FFT64_WORDSIZE * self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAllocImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned<FFT64> {
|
||||
VecZnxBig::<Vec<u8>, FFT64>::new(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFromBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<FFT64> {
|
||||
VecZnxBig::<Vec<u8>, FFT64>::new_from_bytes(n, cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAllocBytesImpl<FFT64> for FFT64 {
|
||||
unsafe impl VecZnxBigAllocBytesImpl<Self> for FFT64 {
|
||||
fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
VecZnxBig::<Vec<u8>, FFT64>::bytes_of(n, cols, size)
|
||||
Self::layout_big_word_count() * n * cols * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddDistF64Impl<FFT64> for FFT64 {
|
||||
fn add_dist_f64_impl<R: VecZnxBigToMut<FFT64>, D: Distribution<f64>>(
|
||||
_module: &Module<FFT64>,
|
||||
unsafe impl VecZnxBigAllocImpl<Self> for FFT64 {
|
||||
fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned<Self> {
|
||||
VecZnxBig::alloc(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFromBytesImpl<Self> for FFT64 {
|
||||
fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<Self> {
|
||||
VecZnxBig::from_bytes(n, cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddDistF64Impl<Self> for FFT64 {
|
||||
fn add_dist_f64_impl<R: VecZnxBigToMut<Self>, D: Distribution<f64>>(
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
@@ -72,7 +49,7 @@ unsafe impl VecZnxBigAddDistF64Impl<FFT64> for FFT64 {
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
@@ -102,9 +79,9 @@ unsafe impl VecZnxBigAddDistF64Impl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddNormalImpl<FFT64> for FFT64 {
|
||||
fn add_normal_impl<R: VecZnxBigToMut<FFT64>>(
|
||||
module: &Module<FFT64>,
|
||||
unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64 {
|
||||
fn add_normal_impl<R: VecZnxBigToMut<Self>>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
@@ -125,9 +102,9 @@ unsafe impl VecZnxBigAddNormalImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFillDistF64Impl<FFT64> for FFT64 {
|
||||
fn fill_dist_f64_impl<R: VecZnxBigToMut<FFT64>, D: Distribution<f64>>(
|
||||
_module: &Module<FFT64>,
|
||||
unsafe impl VecZnxBigFillDistF64Impl<Self> for FFT64 {
|
||||
fn fill_dist_f64_impl<R: VecZnxBigToMut<Self>, D: Distribution<f64>>(
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
@@ -136,7 +113,7 @@ unsafe impl VecZnxBigFillDistF64Impl<FFT64> for FFT64 {
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
@@ -166,9 +143,9 @@ unsafe impl VecZnxBigFillDistF64Impl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFillNormalImpl<FFT64> for FFT64 {
|
||||
fn fill_normal_impl<R: VecZnxBigToMut<FFT64>>(
|
||||
module: &Module<FFT64>,
|
||||
unsafe impl VecZnxBigFillNormalImpl<Self> for FFT64 {
|
||||
fn fill_normal_impl<R: VecZnxBigToMut<Self>>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
@@ -189,24 +166,17 @@ unsafe impl VecZnxBigFillNormalImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddImpl<FFT64> for FFT64 {
|
||||
unsafe impl VecZnxBigAddImpl<Self> for FFT64 {
|
||||
/// Adds `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_add_impl<R, A, B>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxBigToRef<FFT64>,
|
||||
fn vec_znx_big_add_impl<R, A, B>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
B: VecZnxBigToRef<Self>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: VecZnxBig<&[u8], Self> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], Self> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -231,15 +201,15 @@ unsafe impl VecZnxBigAddImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddInplaceImpl<FFT64> for FFT64 {
|
||||
unsafe impl VecZnxBigAddInplaceImpl<Self> for FFT64 {
|
||||
/// Adds `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_add_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
fn vec_znx_big_add_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: VecZnxBig<&[u8], Self> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -262,10 +232,10 @@ unsafe impl VecZnxBigAddInplaceImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddSmallImpl<FFT64> for FFT64 {
|
||||
unsafe impl VecZnxBigAddSmallImpl<Self> for FFT64 {
|
||||
/// Adds `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_add_small_impl<R, A, B>(
|
||||
module: &Module<FFT64>,
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
@@ -273,13 +243,13 @@ unsafe impl VecZnxBigAddSmallImpl<FFT64> for FFT64 {
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let a: VecZnxBig<&[u8], Self> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -304,15 +274,15 @@ unsafe impl VecZnxBigAddSmallImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddSmallInplaceImpl<FFT64> for FFT64 {
|
||||
unsafe impl VecZnxBigAddSmallInplaceImpl<Self> for FFT64 {
|
||||
/// Adds `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_add_small_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
fn vec_znx_big_add_small_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -335,24 +305,17 @@ unsafe impl VecZnxBigAddSmallInplaceImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubImpl<FFT64> for FFT64 {
|
||||
unsafe impl VecZnxBigSubImpl<Self> for FFT64 {
|
||||
/// Subtracts `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_impl<R, A, B>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxBigToRef<FFT64>,
|
||||
fn vec_znx_big_sub_impl<R, A, B>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
B: VecZnxBigToRef<Self>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: VecZnxBig<&[u8], Self> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], Self> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -377,15 +340,15 @@ unsafe impl VecZnxBigSubImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubABInplaceImpl<FFT64> for FFT64 {
|
||||
unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64 {
|
||||
/// Subtracts `a` from `b` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_ab_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
fn vec_znx_big_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: VecZnxBig<&[u8], Self> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -408,15 +371,15 @@ unsafe impl VecZnxBigSubABInplaceImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubBAInplaceImpl<FFT64> for FFT64 {
|
||||
unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64 {
|
||||
/// Subtracts `b` from `a` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_ba_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
fn vec_znx_big_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: VecZnxBig<&[u8], Self> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -439,10 +402,10 @@ unsafe impl VecZnxBigSubBAInplaceImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallAImpl<FFT64> for FFT64 {
|
||||
unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64 {
|
||||
/// Subtracts `b` from `a` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_small_a_impl<R, A, B>(
|
||||
module: &Module<FFT64>,
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
@@ -450,13 +413,13 @@ unsafe impl VecZnxBigSubSmallAImpl<FFT64> for FFT64 {
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxBigToRef<Self>,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let b: VecZnxBig<&[u8], Self> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -481,15 +444,15 @@ unsafe impl VecZnxBigSubSmallAImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallAInplaceImpl<FFT64> for FFT64 {
|
||||
unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64 {
|
||||
/// Subtracts `a` from `res` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_a_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
fn vec_znx_big_sub_small_a_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -512,10 +475,10 @@ unsafe impl VecZnxBigSubSmallAInplaceImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallBImpl<FFT64> for FFT64 {
|
||||
unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64 {
|
||||
/// Subtracts `b` from `a` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_small_b_impl<R, A, B>(
|
||||
module: &Module<FFT64>,
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
@@ -523,13 +486,13 @@ unsafe impl VecZnxBigSubSmallBImpl<FFT64> for FFT64 {
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let a: VecZnxBig<&[u8], Self> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -554,15 +517,15 @@ unsafe impl VecZnxBigSubSmallBImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallBInplaceImpl<FFT64> for FFT64 {
|
||||
unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64 {
|
||||
/// Subtracts `res` from `a` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_b_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
fn vec_znx_big_sub_small_b_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -585,12 +548,12 @@ unsafe impl VecZnxBigSubSmallBInplaceImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNegateInplaceImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_negate_inplace_impl<A>(module: &Module<FFT64>, a: &mut A, a_col: usize)
|
||||
unsafe impl VecZnxBigNegateInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_big_negate_inplace_impl<A>(module: &Module<Self>, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToMut<Self>,
|
||||
{
|
||||
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
|
||||
let mut a: VecZnxBig<&mut [u8], Self> = a.to_mut();
|
||||
unsafe {
|
||||
vec_znx::vec_znx_negate(
|
||||
module.ptr(),
|
||||
@@ -605,26 +568,29 @@ unsafe impl VecZnxBigNegateInplaceImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNormalizeTmpBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<FFT64>, n: usize) -> usize {
|
||||
unsafe impl VecZnxBigNormalizeTmpBytesImpl<Self> for FFT64 {
|
||||
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<Self>, n: usize) -> usize {
|
||||
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr(), n as u64) as usize }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNormalizeImpl<FFT64> for FFT64 {
|
||||
unsafe impl VecZnxBigNormalizeImpl<Self> for FFT64
|
||||
where
|
||||
Self: TakeSliceImpl<Self>,
|
||||
{
|
||||
fn vec_znx_big_normalize_impl<R, A>(
|
||||
module: &Module<FFT64>,
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<FFT64>,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let a: VecZnxBig<&[u8], Self> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
@@ -650,15 +616,15 @@ unsafe impl VecZnxBigNormalizeImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAutomorphismImpl<FFT64> for FFT64 {
|
||||
unsafe impl VecZnxBigAutomorphismImpl<Self> for FFT64 {
|
||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
|
||||
fn vec_znx_big_automorphism_impl<R, A>(module: &Module<FFT64>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
fn vec_znx_big_automorphism_impl<R, A>(module: &Module<Self>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: VecZnxBig<&[u8], Self> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -679,13 +645,13 @@ unsafe impl VecZnxBigAutomorphismImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAutomorphismInplaceImpl<FFT64> for FFT64 {
|
||||
unsafe impl VecZnxBigAutomorphismInplaceImpl<Self> for FFT64 {
|
||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
|
||||
fn vec_znx_big_automorphism_inplace_impl<A>(module: &Module<FFT64>, k: i64, a: &mut A, a_col: usize)
|
||||
fn vec_znx_big_automorphism_inplace_impl<A>(module: &Module<Self>, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToMut<Self>,
|
||||
{
|
||||
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
|
||||
let mut a: VecZnxBig<&mut [u8], Self> = a.to_mut();
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
module.ptr(),
|
||||
@@ -700,38 +666,3 @@ unsafe impl VecZnxBigAutomorphismInplaceImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> fmt::Display for VecZnxBig<D, FFT64> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"VecZnxBig(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {}:", col)?;
|
||||
for size in 0..self.size {
|
||||
let coeffs = self.at(col, size);
|
||||
write!(f, " Size {}: [", size)?;
|
||||
|
||||
let max_show = 100;
|
||||
let show_count = coeffs.len().min(max_show);
|
||||
|
||||
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "{}", coeff)?;
|
||||
}
|
||||
|
||||
if coeffs.len() > max_show {
|
||||
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
|
||||
}
|
||||
|
||||
writeln!(f, "]")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,78 +1,57 @@
|
||||
use std::fmt;
|
||||
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{TakeSlice, VecZnxDftToVecZnxBigTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
|
||||
layouts::{
|
||||
Data, DataRef, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftBytesOf, VecZnxDftOwned,
|
||||
VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
|
||||
},
|
||||
oep::{
|
||||
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftCopyImpl,
|
||||
VecZnxDftFromBytesImpl, VecZnxDftFromVecZnxImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl,
|
||||
VecZnxDftSubImpl, VecZnxDftToVecZnxBigConsumeImpl, VecZnxDftToVecZnxBigImpl, VecZnxDftToVecZnxBigTmpAImpl,
|
||||
VecZnxDftToVecZnxBigTmpBytesImpl, VecZnxDftZeroImpl,
|
||||
},
|
||||
use poulpy_hal::{
|
||||
api::{TakeSlice, VecZnxDftToVecZnxBigTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
|
||||
layouts::{
|
||||
Backend, Data, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut,
|
||||
VecZnxDftToRef, VecZnxToRef,
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
ffi::{vec_znx_big, vec_znx_dft},
|
||||
module_fft64::FFT64,
|
||||
oep::{
|
||||
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftCopyImpl,
|
||||
VecZnxDftFromBytesImpl, VecZnxDftFromVecZnxImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl,
|
||||
VecZnxDftToVecZnxBigConsumeImpl, VecZnxDftToVecZnxBigImpl, VecZnxDftToVecZnxBigTmpAImpl,
|
||||
VecZnxDftToVecZnxBigTmpBytesImpl, VecZnxDftZeroImpl,
|
||||
},
|
||||
};
|
||||
|
||||
const VEC_ZNX_DFT_FFT64_WORDSIZE: usize = 1;
|
||||
use crate::cpu_spqlios::{
|
||||
FFT64,
|
||||
ffi::{vec_znx_big, vec_znx_dft},
|
||||
};
|
||||
|
||||
impl<D: Data> ZnxSliceSize for VecZnxDft<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
VEC_ZNX_DFT_FFT64_WORDSIZE * self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> VecZnxDftBytesOf for VecZnxDft<D, FFT64> {
|
||||
fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
|
||||
VEC_ZNX_DFT_FFT64_WORDSIZE * n * cols * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> ZnxView for VecZnxDft<D, FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftFromBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<FFT64> {
|
||||
unsafe impl VecZnxDftFromBytesImpl<Self> for FFT64 {
|
||||
fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<Self> {
|
||||
VecZnxDft::<Vec<u8>, FFT64>::from_bytes(n, cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAllocBytesImpl<FFT64> for FFT64 {
|
||||
unsafe impl VecZnxDftAllocBytesImpl<Self> for FFT64 {
|
||||
fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
VecZnxDft::<Vec<u8>, FFT64>::bytes_of(n, cols, size)
|
||||
FFT64::layout_prep_word_count() * n * cols * size * size_of::<<FFT64 as Backend>::ScalarPrep>()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAllocImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<FFT64> {
|
||||
unsafe impl VecZnxDftAllocImpl<Self> for FFT64 {
|
||||
fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<Self> {
|
||||
VecZnxDftOwned::alloc(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftToVecZnxBigTmpBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<FFT64>, n: usize) -> usize {
|
||||
unsafe impl VecZnxDftToVecZnxBigTmpBytesImpl<Self> for FFT64 {
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<Self>, n: usize) -> usize {
|
||||
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(module.ptr(), n as u64) as usize }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftToVecZnxBigImpl<FFT64> for FFT64 {
|
||||
unsafe impl VecZnxDftToVecZnxBigImpl<Self> for FFT64 {
|
||||
fn vec_znx_dft_to_vec_znx_big_impl<R, A>(
|
||||
module: &Module<FFT64>,
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<FFT64>,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
@@ -104,11 +83,11 @@ unsafe impl VecZnxDftToVecZnxBigImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftToVecZnxBigTmpAImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_a_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
unsafe impl VecZnxDftToVecZnxBigTmpAImpl<Self> for FFT64 {
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_a_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxDftToMut<FFT64>,
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxDftToMut<Self>,
|
||||
{
|
||||
let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
|
||||
@@ -132,10 +111,10 @@ unsafe impl VecZnxDftToVecZnxBigTmpAImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftToVecZnxBigConsumeImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_to_vec_znx_big_consume_impl<D: Data>(module: &Module<FFT64>, mut a: VecZnxDft<D, FFT64>) -> VecZnxBig<D, FFT64>
|
||||
unsafe impl VecZnxDftToVecZnxBigConsumeImpl<Self> for FFT64 {
|
||||
fn vec_znx_dft_to_vec_znx_big_consume_impl<D: Data>(module: &Module<Self>, mut a: VecZnxDft<D, FFT64>) -> VecZnxBig<D, FFT64>
|
||||
where
|
||||
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>,
|
||||
VecZnxDft<D, FFT64>: VecZnxDftToMut<Self>,
|
||||
{
|
||||
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
|
||||
|
||||
@@ -158,9 +137,9 @@ unsafe impl VecZnxDftToVecZnxBigConsumeImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftFromVecZnxImpl<FFT64> for FFT64 {
|
||||
unsafe impl VecZnxDftFromVecZnxImpl<Self> for FFT64 {
|
||||
fn vec_znx_dft_from_vec_znx_impl<R, A>(
|
||||
module: &Module<FFT64>,
|
||||
module: &Module<Self>,
|
||||
step: usize,
|
||||
offset: usize,
|
||||
res: &mut R,
|
||||
@@ -168,7 +147,7 @@ unsafe impl VecZnxDftFromVecZnxImpl<FFT64> for FFT64 {
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
@@ -196,19 +175,12 @@ unsafe impl VecZnxDftFromVecZnxImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAddImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_add_impl<R, A, D>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &D,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
D: VecZnxDftToRef<FFT64>,
|
||||
unsafe impl VecZnxDftAddImpl<Self> for FFT64 {
|
||||
fn vec_znx_dft_add_impl<R, A, D>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
D: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
@@ -235,11 +207,11 @@ unsafe impl VecZnxDftAddImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAddInplaceImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_add_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
unsafe impl VecZnxDftAddInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_dft_add_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
@@ -262,19 +234,12 @@ unsafe impl VecZnxDftAddInplaceImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftSubImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_sub_impl<R, A, D>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &D,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
D: VecZnxDftToRef<FFT64>,
|
||||
unsafe impl VecZnxDftSubImpl<Self> for FFT64 {
|
||||
fn vec_znx_dft_sub_impl<R, A, D>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
D: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
@@ -301,11 +266,11 @@ unsafe impl VecZnxDftSubImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftSubABInplaceImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_sub_ab_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
unsafe impl VecZnxDftSubABInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_dft_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
@@ -328,11 +293,11 @@ unsafe impl VecZnxDftSubABInplaceImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftSubBAInplaceImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_sub_ba_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
unsafe impl VecZnxDftSubBAInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_dft_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
@@ -355,9 +320,9 @@ unsafe impl VecZnxDftSubBAInplaceImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftCopyImpl<FFT64> for FFT64 {
|
||||
unsafe impl VecZnxDftCopyImpl<Self> for FFT64 {
|
||||
fn vec_znx_dft_copy_impl<R, A>(
|
||||
_module: &Module<FFT64>,
|
||||
_module: &Module<Self>,
|
||||
step: usize,
|
||||
offset: usize,
|
||||
res: &mut R,
|
||||
@@ -365,8 +330,8 @@ unsafe impl VecZnxDftCopyImpl<FFT64> for FFT64 {
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
@@ -388,46 +353,11 @@ unsafe impl VecZnxDftCopyImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftZeroImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_zero_impl<R>(_module: &Module<FFT64>, res: &mut R)
|
||||
unsafe impl VecZnxDftZeroImpl<Self> for FFT64 {
|
||||
fn vec_znx_dft_zero_impl<R>(_module: &Module<Self>, res: &mut R)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
R: VecZnxDftToMut<Self>,
|
||||
{
|
||||
res.to_mut().data.fill(0);
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> fmt::Display for VecZnxDft<D, FFT64> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"VecZnxDft(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {}:", col)?;
|
||||
for size in 0..self.size {
|
||||
let coeffs = self.at(col, size);
|
||||
write!(f, " Size {}: [", size)?;
|
||||
|
||||
let max_show = 100;
|
||||
let show_count = coeffs.len().min(max_show);
|
||||
|
||||
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "{}", coeff)?;
|
||||
}
|
||||
|
||||
if coeffs.len() > max_show {
|
||||
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
|
||||
}
|
||||
|
||||
writeln!(f, "]")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,39 +1,23 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{TakeSlice, VmpApplyTmpBytes, VmpPrepareTmpBytes, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
layouts::{
|
||||
DataRef, MatZnx, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatBytesOf,
|
||||
VmpPMatOwned, VmpPMatToMut, VmpPMatToRef,
|
||||
},
|
||||
oep::{
|
||||
VmpApplyAddImpl, VmpApplyAddTmpBytesImpl, VmpApplyImpl, VmpApplyTmpBytesImpl, VmpPMatAllocBytesImpl,
|
||||
VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
|
||||
},
|
||||
use poulpy_hal::{
|
||||
api::{TakeSlice, VmpApplyTmpBytes, VmpPrepareTmpBytes, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
layouts::{
|
||||
Backend, MatZnx, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatOwned,
|
||||
VmpPMatToMut, VmpPMatToRef,
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
ffi::{vec_znx_dft::vec_znx_dft_t, vmp},
|
||||
module_fft64::FFT64,
|
||||
oep::{
|
||||
VmpApplyAddImpl, VmpApplyAddTmpBytesImpl, VmpApplyImpl, VmpApplyTmpBytesImpl, VmpPMatAllocBytesImpl, VmpPMatAllocImpl,
|
||||
VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
|
||||
},
|
||||
};
|
||||
|
||||
const VMP_PMAT_FFT64_WORDSIZE: usize = 1;
|
||||
use crate::cpu_spqlios::{
|
||||
FFT64,
|
||||
ffi::{vec_znx_dft::vec_znx_dft_t, vmp},
|
||||
};
|
||||
|
||||
impl<D: DataRef> ZnxView for VmpPMat<D, FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
impl VmpPMatBytesOf for FFT64 {
|
||||
fn vmp_pmat_bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
VMP_PMAT_FFT64_WORDSIZE * n * rows * cols_in * cols_out * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPMatAllocBytesImpl<FFT64> for FFT64
|
||||
where
|
||||
FFT64: VmpPMatBytesOf,
|
||||
{
|
||||
unsafe impl VmpPMatAllocBytesImpl<FFT64> for FFT64 {
|
||||
fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
FFT64::vmp_pmat_bytes_of(n, rows, cols_in, cols_out, size)
|
||||
FFT64::layout_prep_word_count() * n * rows * cols_in * cols_out * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -251,8 +235,6 @@ unsafe impl VmpApplyAddImpl<FFT64> for FFT64 {
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
use crate::hal::api::ZnxInfos;
|
||||
|
||||
assert_eq!(b.n(), res.n());
|
||||
assert_eq!(a.n(), res.n());
|
||||
assert_eq!(
|
||||
9
poulpy-backend/src/cpu_spqlios/mod.rs
Normal file
9
poulpy-backend/src/cpu_spqlios/mod.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
mod ffi;
|
||||
mod fft64;
|
||||
mod ntt120;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
pub use fft64::*;
|
||||
pub use ntt120::*;
|
||||
@@ -1,7 +1,7 @@
|
||||
mod module;
|
||||
mod scratch;
|
||||
mod svp_ppol;
|
||||
mod vec_znx;
|
||||
mod vec_znx_big;
|
||||
mod vec_znx_dft;
|
||||
mod vmp_pmat;
|
||||
|
||||
pub use module::NTT120;
|
||||
@@ -1,25 +1,29 @@
|
||||
use std::ptr::NonNull;
|
||||
|
||||
use crate::{
|
||||
hal::{
|
||||
layouts::{Backend, Module},
|
||||
oep::ModuleNewImpl,
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
CPUAVX,
|
||||
ffi::module::{MODULE, delete_module_info, new_module_info},
|
||||
},
|
||||
use poulpy_hal::{
|
||||
layouts::{Backend, Module},
|
||||
oep::ModuleNewImpl,
|
||||
};
|
||||
|
||||
use crate::cpu_spqlios::ffi::module::{MODULE, delete_module_info, new_module_info};
|
||||
|
||||
pub struct NTT120;
|
||||
|
||||
impl CPUAVX for NTT120 {}
|
||||
|
||||
impl Backend for NTT120 {
|
||||
type ScalarPrep = i64;
|
||||
type ScalarBig = i128;
|
||||
type Handle = MODULE;
|
||||
unsafe fn destroy(handle: NonNull<Self::Handle>) {
|
||||
unsafe { delete_module_info(handle.as_ptr()) }
|
||||
}
|
||||
|
||||
fn layout_big_word_count() -> usize {
|
||||
4
|
||||
}
|
||||
|
||||
fn layout_prep_word_count() -> usize {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ModuleNewImpl<Self> for NTT120 {
|
||||
24
poulpy-backend/src/cpu_spqlios/ntt120/svp_ppol.rs
Normal file
24
poulpy-backend/src/cpu_spqlios/ntt120/svp_ppol.rs
Normal file
@@ -0,0 +1,24 @@
|
||||
use poulpy_hal::{
|
||||
layouts::{Backend, SvpPPolOwned},
|
||||
oep::{SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl},
|
||||
};
|
||||
|
||||
use crate::cpu_spqlios::NTT120;
|
||||
|
||||
unsafe impl SvpPPolFromBytesImpl<Self> for NTT120 {
|
||||
fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<NTT120> {
|
||||
SvpPPolOwned::from_bytes(n, cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolAllocImpl<Self> for NTT120 {
|
||||
fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned<NTT120> {
|
||||
SvpPPolOwned::alloc(n, cols)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolAllocBytesImpl<Self> for NTT120 {
|
||||
fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize {
|
||||
NTT120::layout_prep_word_count() * n * cols * size_of::<i64>()
|
||||
}
|
||||
}
|
||||
9
poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_big.rs
Normal file
9
poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_big.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
use poulpy_hal::{layouts::Backend, oep::VecZnxBigAllocBytesImpl};
|
||||
|
||||
use crate::cpu_spqlios::NTT120;
|
||||
|
||||
unsafe impl VecZnxBigAllocBytesImpl<NTT120> for NTT120 {
|
||||
fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
NTT120::layout_big_word_count() * n * cols * size * size_of::<i128>()
|
||||
}
|
||||
}
|
||||
18
poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_dft.rs
Normal file
18
poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_dft.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
use poulpy_hal::{
|
||||
layouts::{Backend, VecZnxDftOwned},
|
||||
oep::{VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl},
|
||||
};
|
||||
|
||||
use crate::cpu_spqlios::NTT120;
|
||||
|
||||
unsafe impl VecZnxDftAllocBytesImpl<NTT120> for NTT120 {
|
||||
fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
NTT120::layout_prep_word_count() * n * cols * size * size_of::<i64>()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAllocImpl<NTT120> for NTT120 {
|
||||
fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<NTT120> {
|
||||
VecZnxDftOwned::alloc(n, cols, size)
|
||||
}
|
||||
}
|
||||
1
poulpy-backend/src/cpu_spqlios/ntt120/vmp_pmat.rs
Normal file
1
poulpy-backend/src/cpu_spqlios/ntt120/vmp_pmat.rs
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
# Use the Google style in this project.
|
||||
BasedOnStyle: Google
|
||||
|
||||
# Some folks prefer to write "int& foo" while others prefer "int &foo". The
|
||||
# Google Style Guide only asks for consistency within a project, we chose
|
||||
# "int& foo" for this project:
|
||||
DerivePointerAlignment: false
|
||||
PointerAlignment: Left
|
||||
|
||||
# The Google Style Guide only asks for consistency w.r.t. "east const" vs.
|
||||
# "const west" alignment of cv-qualifiers. In this project we use "east const".
|
||||
QualifierAlignment: Left
|
||||
|
||||
ColumnLimit: 120
|
||||
20
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/.github/workflows/auto-release.yml
vendored
Normal file
20
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/.github/workflows/auto-release.yml
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
name: Auto-Release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Auto-Release
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 3
|
||||
# sparse-checkout: manifest.yaml scripts/auto-release.sh
|
||||
|
||||
- run:
|
||||
${{github.workspace}}/scripts/auto-release.sh
|
||||
6
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/.gitignore
vendored
Normal file
6
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/.gitignore
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
cmake-build-*
|
||||
.idea
|
||||
|
||||
build
|
||||
.vscode
|
||||
.*.sh
|
||||
@@ -0,0 +1,69 @@
|
||||
cmake_minimum_required(VERSION 3.8)
|
||||
project(spqlios)
|
||||
|
||||
# read the current version from the manifest file
|
||||
file(READ "manifest.yaml" manifest)
|
||||
string(REGEX MATCH "version: +(([0-9]+)\\.([0-9]+)\\.([0-9]+))" SPQLIOS_VERSION_BLAH ${manifest})
|
||||
#message(STATUS "Version: ${SPQLIOS_VERSION_BLAH}")
|
||||
set(SPQLIOS_VERSION ${CMAKE_MATCH_1})
|
||||
set(SPQLIOS_VERSION_MAJOR ${CMAKE_MATCH_2})
|
||||
set(SPQLIOS_VERSION_MINOR ${CMAKE_MATCH_3})
|
||||
set(SPQLIOS_VERSION_PATCH ${CMAKE_MATCH_4})
|
||||
message(STATUS "Compiling spqlios-fft version: ${SPQLIOS_VERSION_MAJOR}.${SPQLIOS_VERSION_MINOR}.${SPQLIOS_VERSION_PATCH}")
|
||||
|
||||
#set(ENABLE_SPQLIOS_F128 ON CACHE BOOL "Enable float128 via libquadmath")
|
||||
set(WARNING_PARANOID ON CACHE BOOL "Treat all warnings as errors")
|
||||
set(ENABLE_TESTING ON CACHE BOOL "Compiles unittests and integration tests")
|
||||
set(DEVMODE_INSTALL OFF CACHE BOOL "Install private headers and testlib (mainly for CI)")
|
||||
|
||||
if (NOT CMAKE_BUILD_TYPE OR CMAKE_BUILD_TYPE STREQUAL "")
|
||||
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Build type: Release or Debug" FORCE)
|
||||
endif()
|
||||
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
|
||||
|
||||
if (WARNING_PARANOID)
|
||||
add_compile_options(-Wall -Werror -Wno-unused-command-line-argument)
|
||||
endif()
|
||||
|
||||
message(STATUS "CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}")
|
||||
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
|
||||
message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
|
||||
|
||||
if (CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)")
|
||||
set(X86 ON)
|
||||
set(AARCH64 OFF)
|
||||
else ()
|
||||
set(X86 OFF)
|
||||
# set(ENABLE_SPQLIOS_F128 OFF) # float128 are only supported for x86 targets
|
||||
endif ()
|
||||
if (CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64)|(arm64)")
|
||||
set(AARCH64 ON)
|
||||
endif ()
|
||||
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "(Windows)|(MSYS)")
|
||||
set(WIN32 ON)
|
||||
endif ()
|
||||
if (WIN32)
|
||||
#overrides for win32
|
||||
set(X86 OFF)
|
||||
set(AARCH64 OFF)
|
||||
set(X86_WIN32 ON)
|
||||
else()
|
||||
set(X86_WIN32 OFF)
|
||||
set(WIN32 OFF)
|
||||
endif (WIN32)
|
||||
|
||||
message(STATUS "--> WIN32: ${WIN32}")
|
||||
message(STATUS "--> X86_WIN32: ${X86_WIN32}")
|
||||
message(STATUS "--> X86_LINUX: ${X86}")
|
||||
message(STATUS "--> AARCH64: ${AARCH64}")
|
||||
|
||||
# compiles the main library in spqlios
|
||||
add_subdirectory(spqlios)
|
||||
|
||||
# compiles and activates unittests and itests
|
||||
if (${ENABLE_TESTING})
|
||||
enable_testing()
|
||||
add_subdirectory(test)
|
||||
endif()
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
# Contributing to SPQlios-fft
|
||||
|
||||
The spqlios-fft team encourages contributions.
|
||||
We encourage users to fix bugs, improve the documentation, write tests and to enhance the code, or ask for new features.
|
||||
We encourage researchers to contribute with implementations of their FFT or NTT algorithms.
|
||||
In the following we are trying to give some guidance on how to contribute effectively.
|
||||
|
||||
## Communication ##
|
||||
|
||||
Communication in the spqlios-fft project happens mainly on [GitHub](https://github.com/tfhe/spqlios-fft/issues).
|
||||
|
||||
All communications are public, so please make sure to maintain professional behaviour in
|
||||
all published comments. See [Code of Conduct](https://www.contributor-covenant.org/version/2/1/code_of_conduct/) for
|
||||
guidelines.
|
||||
|
||||
## Reporting Bugs or Requesting features ##
|
||||
|
||||
Bug should be filed at [https://github.com/tfhe/spqlios-fft/issues](https://github.com/tfhe/spqlios-fft/issues).
|
||||
|
||||
Features can also be requested there, in this case, please ensure that the features you request are self-contained,
|
||||
easy to define, and generic enough to be used in different use-cases. Please provide an example of use-cases if
|
||||
possible.
|
||||
|
||||
## Setting up topic branches and generating pull requests
|
||||
|
||||
This section applies to people that already have write access to the repository. Specific instructions for pull-requests
|
||||
from public forks will be given later.
|
||||
|
||||
To implement some changes, please follow these steps:
|
||||
|
||||
- Create a "topic branch". Usually, the branch name should be `username/small-title`
|
||||
or better `username/issuenumber-small-title` where `issuenumber` is the number of
|
||||
the github issue number that is tackled.
|
||||
- Push any needed commits to your branch. Make sure it compiles in `CMAKE_BUILD_TYPE=Debug` and `=Release`, with `-DWARNING_PARANOID=ON`.
|
||||
- When the branch is nearly ready for review, please open a pull request, and add the label `check-on-arm`
|
||||
- Do as many commits as necessary until all CI checks pass and all PR comments have been resolved.
|
||||
|
||||
> _During the process, you may optionnally use `git rebase -i` to clean up your commit history. If you elect to do so,
|
||||
please at the very least make sure that nobody else is working or has forked from your branch: the conflicts it would generate
|
||||
and the human hours to fix them are not worth it. `Git merge` remains the preferred option._
|
||||
|
||||
- Finally, when all reviews are positive and all CI checks pass, you may merge your branch via the github webpage.
|
||||
|
||||
### Keep your pull requests limited to a single issue
|
||||
|
||||
Pull requests should be as small/atomic as possible.
|
||||
|
||||
### Coding Conventions
|
||||
|
||||
* Please make sure that your code is formatted according to the `.clang-format` file and
|
||||
that all files end with a newline character.
|
||||
* Please make sure that all the functions declared in the public api have relevant doxygen comments.
|
||||
Preferably, functions in the private apis should also contain a brief doxygen description.
|
||||
|
||||
### Versions and History
|
||||
|
||||
* **Stable API** The project uses semantic versioning on the functions that are listed as `stable` in the documentation. A version has
|
||||
the form `x.y.z`
|
||||
* a patch release that increments `z` does not modify the stable API.
|
||||
* a minor release that increments `y` adds a new feature to the stable API.
|
||||
* In the unlikely case where we need to change or remove a feature, we will trigger a major release that
|
||||
increments `x`.
|
||||
|
||||
> _If any, we will mark those features as deprecated at least six months before the major release._
|
||||
|
||||
* **Experimental API** Features that are not part of the stable section in the documentation are experimental features: you may test them at
|
||||
your own risk,
|
||||
but keep in mind that semantic versioning does not apply to them.
|
||||
|
||||
> _If you have a use-case that uses an experimental feature, we encourage
|
||||
> you to tell us about it, so that this feature reaches to the stable section faster!_
|
||||
|
||||
* **Version history** The current version is reported in `manifest.yaml`, any change of version comes up with a tag on the main branch, and the history between releases is summarized in `Changelog.md`. It is the main source of truth for anyone who wishes to
|
||||
get insight about
|
||||
the history of the repository (not the commit graph).
|
||||
|
||||
> Note: _The commit graph of git is for git's internal use only. Its main purpose is to reduce potential merge conflicts to a minimum, even in scenario where multiple features are developped in parallel: it may therefore be non-linear. If, as humans, we like to see a linear history, please read `Changelog.md` instead!_
|
||||
@@ -0,0 +1,18 @@
|
||||
# Changelog
|
||||
|
||||
All notable changes to this project will be documented in this file.
|
||||
this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [2.0.0] - 2024-08-21
|
||||
|
||||
- Initial release of the `vec_znx` (except convolution products), `vec_rnx` and `zn` apis.
|
||||
- Hardware acceleration available: AVX2 (most parts)
|
||||
- APIs are documented in the wiki and are in "beta mode": during the 2.x -> 3.x transition, functions whose API is satisfactory in test projects will pass in "stable mode".
|
||||
|
||||
## [1.0.0] - 2023-07-18
|
||||
|
||||
- Initial release of the double precision fft on the reim and cplx backends
|
||||
- Coeffs-space conversions cplx <-> znx32 and tnx32
|
||||
- FFT-space conversions cplx <-> reim4 layouts
|
||||
- FFT-space multiplications on the cplx, reim and reim4 layouts.
|
||||
- In this first release, the only platform supported is linux x86_64 (generic C code, and avx2/fma). It compiles on arm64, but without any acceleration.
|
||||
201
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/LICENSE
Normal file
201
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/LICENSE
Normal file
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
65
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/README.md
Normal file
65
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/README.md
Normal file
@@ -0,0 +1,65 @@
|
||||
# SPQlios library
|
||||
|
||||
|
||||
|
||||
The SPQlios library provides fast arithmetic for Fully Homomorphic Encryption, and other lattice constructions that arise in post quantum cryptography.
|
||||
|
||||
<img src="docs/api-full.svg">
|
||||
|
||||
Namely, it is divided into 4 sections:
|
||||
|
||||
* The low-level DFT section support FFT over 64-bit floats, as well as NTT modulo one fixed 120-bit modulus. It is an upgrade of the original spqlios-fft module embedded in the TFHE library since 2016. The DFT section exposes the traditional DFT, inverse-DFT, and coefficient-wise multiplications in DFT space.
|
||||
* The VEC_ZNX section exposes fast algebra over vectors of small integer polynomial modulo $X^N+1$. It proposed in particular efficient (prepared) vector-matrix products, scalar-vector products, convolution products, and element-wise products, operations that naturally occurs on gadget-decomposed Ring-LWE coordinates.
|
||||
* The RNX section is a simpler variant of VEC_ZNX, to represent single polynomials modulo $X^N+1$ (over the reals or over the torus) when the coefficient precision fits on 64-bit doubles. The small vector-matrix API of the RNX section is particularly adapted to reproducing the fastest CGGI-based bootstrappings.
|
||||
* The ZN section focuses over vector and matrix algebra over scalars (used by scalar LWE, or scalar key-switches, but also on non-ring schemes like Frodo, FrodoPIR, and SimplePIR).
|
||||
|
||||
### A high value target for hardware accelerations
|
||||
|
||||
SPQlios is more than a library, it is also a good target for hardware developers.
|
||||
On one hand, the arithmetic operations that are defined in the library have a clear standalone mathematical definition. And at the same time, the amount of work in each operations is sufficiently large so that meaningful functions only require a few of these.
|
||||
|
||||
This makes the SPQlios API a high value target for hardware acceleration, that targets FHE.
|
||||
|
||||
### SPQLios is not an FHE library, but a huge enabler
|
||||
|
||||
SPQlios itself is not an FHE library: there is no ciphertext, plaintext or key. It is a mathematical library that exposes efficient algebra over polynomials. Using the functions exposed, it is possible to quickly build efficient FHE libraries, with support for the main schemes based on Ring-LWE: BFV, BGV, CGGI, DM, CKKS.
|
||||
|
||||
|
||||
## Dependencies
|
||||
|
||||
The SPQLIOS-FFT library is a C library that can be compiled with a standard C compiler, and depends only on libc and libm. The API
|
||||
interface can be used in a regular C code, and any other language via classical foreign APIs.
|
||||
|
||||
The unittests and integration tests are in an optional part of the code, and are written in C++. These tests rely on
|
||||
[```benchmark```](https://github.com/google/benchmark), and [```gtest```](https://github.com/google/googletest) libraries, and therefore require a C++17 compiler.
|
||||
|
||||
Currently, the project has been tested with the gcc,g++ >= 11.3.0 compiler under Linux (x86_64). In the future, we plan to
|
||||
extend the compatibility to other compilers, platforms and operating systems.
|
||||
|
||||
|
||||
## Installation
|
||||
|
||||
The library uses a classical ```cmake``` build mechanism: use ```cmake``` to create a ```build``` folder in the top level directory and run ```make``` from inside it. This assumes that the standard tool ```cmake``` is already installed on the system, and an up-to-date c++ compiler (i.e. g++ >=11.3.0) as well.
|
||||
|
||||
It will compile the shared library in optimized mode, and ```make install``` install it to the desired prefix folder (by default ```/usr/local/lib```).
|
||||
|
||||
If you want to choose additional compile options (i.e. other installation folder, debug mode, tests), you need to run cmake manually and pass the desired options:
|
||||
```
|
||||
mkdir build
|
||||
cd build
|
||||
cmake ../src -CMAKE_INSTALL_PREFIX=/usr/
|
||||
make
|
||||
```
|
||||
The available options are the following:
|
||||
|
||||
| Variable Name | values |
|
||||
| -------------------- | ------------------------------------------------------------ |
|
||||
| CMAKE_INSTALL_PREFIX | */usr/local* installation folder (libs go in lib/ and headers in include/) |
|
||||
| WARNING_PARANOID | All warnings are shown and treated as errors. Off by default |
|
||||
| ENABLE_TESTING | Compiles unit tests and integration tests |
|
||||
|
||||
------
|
||||
|
||||
<img src="docs/logo-sandboxaq-black.svg">
|
||||
|
||||
<img src="docs/logo-inpher1.png">
|
||||
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 550 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 24 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 24 KiB |
@@ -0,0 +1,139 @@
|
||||
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||
<!-- Generator: Adobe Illustrator 24.2.1, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
|
||||
|
||||
<svg
|
||||
version="1.1"
|
||||
id="Layer_1"
|
||||
x="0px"
|
||||
y="0px"
|
||||
viewBox="0 0 270 49.4"
|
||||
style="enable-background:new 0 0 270 49.4;"
|
||||
xml:space="preserve"
|
||||
sodipodi:docname="logo-sandboxaq-black.svg"
|
||||
inkscape:version="1.3.2 (1:1.3.2+202311252150+091e20ef0f)"
|
||||
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
|
||||
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
xmlns:svg="http://www.w3.org/2000/svg"><defs
|
||||
id="defs9839">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</defs><sodipodi:namedview
|
||||
id="namedview9837"
|
||||
pagecolor="#ffffff"
|
||||
bordercolor="#000000"
|
||||
borderopacity="0.25"
|
||||
inkscape:showpageshadow="2"
|
||||
inkscape:pageopacity="0.0"
|
||||
inkscape:pagecheckerboard="0"
|
||||
inkscape:deskcolor="#d1d1d1"
|
||||
showgrid="false"
|
||||
inkscape:zoom="1.194332"
|
||||
inkscape:cx="135.64068"
|
||||
inkscape:cy="25.118645"
|
||||
inkscape:window-width="804"
|
||||
inkscape:window-height="436"
|
||||
inkscape:window-x="190"
|
||||
inkscape:window-y="27"
|
||||
inkscape:window-maximized="0"
|
||||
inkscape:current-layer="Layer_1" />
|
||||
<style
|
||||
type="text/css"
|
||||
id="style9786">
|
||||
.st0{fill:#EBB028;}
|
||||
.st1{fill:#FFFFFF;}
|
||||
</style>
|
||||
<text
|
||||
transform="matrix(1 0 0 1 393.832 -491.944)"
|
||||
class="st1"
|
||||
style="font-family:'Satoshi-Medium'; font-size:86.2078px;"
|
||||
id="text9788">SANDBOX </text>
|
||||
<text
|
||||
transform="matrix(1 0 0 1 896.332 -491.944)"
|
||||
class="st1"
|
||||
style="font-family:'Satoshi-Black'; font-size:86.2078px;"
|
||||
id="text9790">AQ</text>
|
||||
<g
|
||||
id="g9808">
|
||||
<g
|
||||
id="g9800">
|
||||
<g
|
||||
id="g9798">
|
||||
<path
|
||||
class="st0"
|
||||
d="m 8.9,9.7 v 3.9 l 29.6,17.1 v 2.7 c 0,1.2 -0.6,2.3 -1.6,2.9 L 31,39.8 v -4 L 1.4,18.6 V 15.9 C 1.4,14.7 2,13.6 3.1,13 Z"
|
||||
id="path9792" />
|
||||
<path
|
||||
class="st0"
|
||||
d="M 18.3,45.1 3.1,36.3 C 2.1,35.7 1.4,34.6 1.4,33.4 V 26 L 28,41.4 21.5,45.1 c -0.9,0.6 -2.2,0.6 -3.2,0 z"
|
||||
id="path9794" />
|
||||
<path
|
||||
class="st0"
|
||||
d="m 21.6,4.3 15.2,8.8 c 1,0.6 1.7,1.7 1.7,2.9 v 7.5 L 11.8,8 18.3,4.3 c 1,-0.6 2.3,-0.6 3.3,0 z"
|
||||
id="path9796" />
|
||||
</g>
|
||||
</g>
|
||||
<g
|
||||
id="g9806">
|
||||
<polygon
|
||||
class="st0"
|
||||
points="248.1,23.2 248.1,30 251.4,33.8 257.3,33.8 "
|
||||
id="polygon9802" />
|
||||
<path
|
||||
class="st0"
|
||||
d="m 246.9,31 -0.1,-0.1 h -0.1 c -0.2,0 -0.4,0 -0.6,0 -3.5,0 -5.7,-2.6 -5.7,-6.7 0,-4.1 2.2,-6.7 5.7,-6.7 3.5,0 5.7,2.6 5.7,6.7 0,0.3 0,0.6 0,0.9 l 3.6,4.2 c 0.7,-1.5 1,-3.2 1,-5.1 0,-6.5 -4.2,-11 -10.3,-11 -6.1,0 -10.3,4.5 -10.3,11 0,6.5 4.2,11 10.3,11 1.2,0 2.3,-0.2 3.4,-0.5 l 0.5,-0.2 z"
|
||||
id="path9804" />
|
||||
</g>
|
||||
</g><g
|
||||
id="g9824"
|
||||
style="fill:#1a1a1a">
|
||||
<path
|
||||
class="st1"
|
||||
d="m 58.7,13.2 c 4.6,0 7.4,2.5 7.4,6.5 h -4.6 c 0,-1.5 -1.1,-2.4 -2.9,-2.4 -1.9,0 -3.1,0.9 -3.1,2.3 0,1.3 0.7,1.9 2.2,2.2 l 3.2,0.7 c 3.8,0.8 5.6,2.6 5.6,5.9 0,4.1 -3.2,6.8 -8.1,6.8 -4.7,0 -7.8,-2.6 -7.8,-6.5 h 4.6 c 0,1.6 1.1,2.4 3.2,2.4 2.1,0 3.4,-0.8 3.4,-2.2 0,-1.2 -0.5,-1.8 -2,-2.1 l -3.2,-0.7 c -3.8,-0.8 -5.7,-2.9 -5.7,-6.4 0,-3.7 3.2,-6.5 7.8,-6.5 z"
|
||||
id="path9810"
|
||||
style="fill:#1a1a1a" />
|
||||
<path
|
||||
class="st1"
|
||||
d="M 70.4,34.9 78,13.6 h 4.5 l 7.6,21.3 h -4.9 l -1.5,-4.5 h -6.9 l -1.5,4.5 z m 7.7,-8.4 h 4.2 L 80.8,22 c -0.2,-0.7 -0.5,-1.6 -0.6,-2.1 -0.1,0.5 -0.3,1.3 -0.6,2.1 z"
|
||||
id="path9812"
|
||||
style="fill:#1a1a1a" />
|
||||
<path
|
||||
class="st1"
|
||||
d="M 95.3,34.9 V 13.6 h 4.6 l 9,13.5 V 13.6 h 4.6 v 21.3 h -4.6 l -9,-13.5 v 13.5 z"
|
||||
id="path9814"
|
||||
style="fill:#1a1a1a" />
|
||||
<path
|
||||
class="st1"
|
||||
d="M 120.7,34.9 V 13.6 h 8 c 6.2,0 10.6,4.4 10.6,10.7 0,6.2 -4.2,10.6 -10.3,10.6 z m 4.7,-17 v 12.6 h 3.2 c 3.7,0 5.8,-2.3 5.8,-6.3 0,-4 -2.3,-6.4 -6.1,-6.4 h -2.9 z"
|
||||
id="path9816"
|
||||
style="fill:#1a1a1a" />
|
||||
<path
|
||||
class="st1"
|
||||
d="m 145.4,13.6 h 8.8 c 4.3,0 6.9,2.2 6.9,5.9 0,2.3 -1,3.9 -3,4.8 2.1,0.7 3.2,2.3 3.2,4.7 0,3.8 -2.5,5.9 -7.1,5.9 h -8.8 z m 4.7,4.1 v 4.6 h 3.7 c 1.7,0 2.6,-0.8 2.6,-2.4 0,-1.5 -0.9,-2.3 -2.6,-2.3 h -3.7 z m 0,8.5 v 4.6 h 3.9 c 1.7,0 2.6,-0.8 2.6,-2.4 0,-1.4 -0.9,-2.2 -2.6,-2.2 z"
|
||||
id="path9818"
|
||||
style="fill:#1a1a1a" />
|
||||
<path
|
||||
class="st1"
|
||||
d="m 176.5,35.2 c -6.1,0 -10.4,-4.5 -10.4,-11 0,-6.5 4.3,-11 10.4,-11 6.2,0 10.4,4.5 10.4,11 0,6.5 -4.2,11 -10.4,11 z m 0.1,-17.5 c -3.4,0 -5.5,2.4 -5.5,6.5 0,4.1 2.1,6.5 5.5,6.5 3.4,0 5.5,-2.5 5.5,-6.5 0,-4 -2.1,-6.5 -5.5,-6.5 z"
|
||||
id="path9820"
|
||||
style="fill:#1a1a1a" />
|
||||
<path
|
||||
class="st1"
|
||||
d="m 190.4,13.6 h 5.5 l 1.8,2.8 c 0.8,1.2 1.5,2.5 2.5,4.3 l 4.3,-7 h 5.4 l -6.7,10.6 6.7,10.6 h -5.5 L 203,32.7 c -1.1,-1.7 -1.8,-3 -2.8,-4.9 l -4.6,7.1 h -5.5 l 7.1,-10.6 z"
|
||||
id="path9822"
|
||||
style="fill:#1a1a1a" />
|
||||
</g><path
|
||||
class="st0"
|
||||
d="m 229,34.9 h 4.7 L 226,13.6 h -4.3 L 214,34.8 h 4.6 l 1.6,-4.5 h 7.1 z m -5.1,-14.6 c 0,0 0,0 0,0 0,-0.1 0,-0.1 0,0 l 2.2,6.2 h -4.4 z"
|
||||
id="path9826" /><g
|
||||
id="g9832">
|
||||
<path
|
||||
class="st1"
|
||||
d="m 259.5,11.2 h 3.9 v 1 h -1.3 v 3.1 h -1.3 v -3.1 h -1.3 z m 4.5,0 h 1.7 l 0.6,2.5 0.6,-2.5 h 1.7 v 4.1 h -1 v -3.1 l -0.8,3.1 h -0.9 l -0.8,-3.1 v 3.1 h -1 v -4.1 z"
|
||||
id="path9830" />
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 5.0 KiB |
@@ -0,0 +1,133 @@
|
||||
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||
<!-- Generator: Adobe Illustrator 24.2.1, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
|
||||
|
||||
<svg
|
||||
version="1.1"
|
||||
id="Layer_1"
|
||||
x="0px"
|
||||
y="0px"
|
||||
viewBox="0 0 270 49.4"
|
||||
style="enable-background:new 0 0 270 49.4;"
|
||||
xml:space="preserve"
|
||||
sodipodi:docname="logo-sandboxaq-white.svg"
|
||||
inkscape:version="1.2.2 (1:1.2.2+202212051551+b0a8486541)"
|
||||
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
|
||||
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
xmlns:svg="http://www.w3.org/2000/svg"><defs
|
||||
id="defs9839" /><sodipodi:namedview
|
||||
id="namedview9837"
|
||||
pagecolor="#ffffff"
|
||||
bordercolor="#000000"
|
||||
borderopacity="0.25"
|
||||
inkscape:showpageshadow="2"
|
||||
inkscape:pageopacity="0.0"
|
||||
inkscape:pagecheckerboard="0"
|
||||
inkscape:deskcolor="#d1d1d1"
|
||||
showgrid="false"
|
||||
inkscape:zoom="2.3886639"
|
||||
inkscape:cx="135.22204"
|
||||
inkscape:cy="25.327967"
|
||||
inkscape:window-width="1072"
|
||||
inkscape:window-height="688"
|
||||
inkscape:window-x="0"
|
||||
inkscape:window-y="0"
|
||||
inkscape:window-maximized="1"
|
||||
inkscape:current-layer="Layer_1" />
|
||||
<style
|
||||
type="text/css"
|
||||
id="style9786">
|
||||
.st0{fill:#EBB028;}
|
||||
.st1{fill:#FFFFFF;}
|
||||
</style>
|
||||
<text
|
||||
transform="matrix(1 0 0 1 393.832 -491.944)"
|
||||
class="st1"
|
||||
style="font-family:'Satoshi-Medium'; font-size:86.2078px;"
|
||||
id="text9788">SANDBOX </text>
|
||||
<text
|
||||
transform="matrix(1 0 0 1 896.332 -491.944)"
|
||||
class="st1"
|
||||
style="font-family:'Satoshi-Black'; font-size:86.2078px;"
|
||||
id="text9790">AQ</text>
|
||||
<g
|
||||
id="g9834">
|
||||
<g
|
||||
id="g9828">
|
||||
<g
|
||||
id="g9808">
|
||||
<g
|
||||
id="g9800">
|
||||
<g
|
||||
id="g9798">
|
||||
<path
|
||||
class="st0"
|
||||
d="M8.9,9.7v3.9l29.6,17.1v2.7c0,1.2-0.6,2.3-1.6,2.9L31,39.8v-4L1.4,18.6v-2.7c0-1.2,0.6-2.3,1.7-2.9 L8.9,9.7z"
|
||||
id="path9792" />
|
||||
<path
|
||||
class="st0"
|
||||
d="M18.3,45.1L3.1,36.3c-1-0.6-1.7-1.7-1.7-2.9V26L28,41.4l-6.5,3.7C20.6,45.7,19.3,45.7,18.3,45.1z"
|
||||
id="path9794" />
|
||||
<path
|
||||
class="st0"
|
||||
d="M21.6,4.3l15.2,8.8c1,0.6,1.7,1.7,1.7,2.9v7.5L11.8,8l6.5-3.7C19.3,3.7,20.6,3.7,21.6,4.3z"
|
||||
id="path9796" />
|
||||
</g>
|
||||
</g>
|
||||
<g
|
||||
id="g9806">
|
||||
<polygon
|
||||
class="st0"
|
||||
points="248.1,23.2 248.1,30 251.4,33.8 257.3,33.8 "
|
||||
id="polygon9802" />
|
||||
<path
|
||||
class="st0"
|
||||
d="M246.9,31l-0.1-0.1l-0.1,0c-0.2,0-0.4,0-0.6,0c-3.5,0-5.7-2.6-5.7-6.7c0-4.1,2.2-6.7,5.7-6.7 s5.7,2.6,5.7,6.7c0,0.3,0,0.6,0,0.9l3.6,4.2c0.7-1.5,1-3.2,1-5.1c0-6.5-4.2-11-10.3-11c-6.1,0-10.3,4.5-10.3,11s4.2,11,10.3,11 c1.2,0,2.3-0.2,3.4-0.5l0.5-0.2L246.9,31z"
|
||||
id="path9804" />
|
||||
</g>
|
||||
</g>
|
||||
<g
|
||||
id="g9824">
|
||||
<path
|
||||
class="st1"
|
||||
d="M58.7,13.2c4.6,0,7.4,2.5,7.4,6.5h-4.6c0-1.5-1.1-2.4-2.9-2.4c-1.9,0-3.1,0.9-3.1,2.3c0,1.3,0.7,1.9,2.2,2.2 l3.2,0.7c3.8,0.8,5.6,2.6,5.6,5.9c0,4.1-3.2,6.8-8.1,6.8c-4.7,0-7.8-2.6-7.8-6.5h4.6c0,1.6,1.1,2.4,3.2,2.4 c2.1,0,3.4-0.8,3.4-2.2c0-1.2-0.5-1.8-2-2.1l-3.2-0.7c-3.8-0.8-5.7-2.9-5.7-6.4C50.9,16,54.1,13.2,58.7,13.2z"
|
||||
id="path9810" />
|
||||
<path
|
||||
class="st1"
|
||||
d="M70.4,34.9L78,13.6h4.5l7.6,21.3h-4.9l-1.5-4.5h-6.9l-1.5,4.5H70.4z M78.1,26.5h4.2L80.8,22 c-0.2-0.7-0.5-1.6-0.6-2.1c-0.1,0.5-0.3,1.3-0.6,2.1L78.1,26.5z"
|
||||
id="path9812" />
|
||||
<path
|
||||
class="st1"
|
||||
d="M95.3,34.9V13.6h4.6l9,13.5V13.6h4.6v21.3h-4.6l-9-13.5v13.5H95.3z"
|
||||
id="path9814" />
|
||||
<path
|
||||
class="st1"
|
||||
d="M120.7,34.9V13.6h8c6.2,0,10.6,4.4,10.6,10.7c0,6.2-4.2,10.6-10.3,10.6H120.7z M125.4,17.9v12.6h3.2 c3.7,0,5.8-2.3,5.8-6.3c0-4-2.3-6.4-6.1-6.4H125.4z"
|
||||
id="path9816" />
|
||||
<path
|
||||
class="st1"
|
||||
d="M145.4,13.6h8.8c4.3,0,6.9,2.2,6.9,5.9c0,2.3-1,3.9-3,4.8c2.1,0.7,3.2,2.3,3.2,4.7c0,3.8-2.5,5.9-7.1,5.9 h-8.8V13.6z M150.1,17.7v4.6h3.7c1.7,0,2.6-0.8,2.6-2.4c0-1.5-0.9-2.3-2.6-2.3H150.1z M150.1,26.2v4.6h3.9c1.7,0,2.6-0.8,2.6-2.4 c0-1.4-0.9-2.2-2.6-2.2H150.1z"
|
||||
id="path9818" />
|
||||
<path
|
||||
class="st1"
|
||||
d="M176.5,35.2c-6.1,0-10.4-4.5-10.4-11s4.3-11,10.4-11c6.2,0,10.4,4.5,10.4,11S182.7,35.2,176.5,35.2z M176.6,17.7c-3.4,0-5.5,2.4-5.5,6.5c0,4.1,2.1,6.5,5.5,6.5c3.4,0,5.5-2.5,5.5-6.5C182.1,20.2,180,17.7,176.6,17.7z"
|
||||
id="path9820" />
|
||||
<path
|
||||
class="st1"
|
||||
d="M190.4,13.6h5.5l1.8,2.8c0.8,1.2,1.5,2.5,2.5,4.3l4.3-7h5.4l-6.7,10.6l6.7,10.6h-5.5l-1.4-2.2 c-1.1-1.7-1.8-3-2.8-4.9l-4.6,7.1h-5.5l7.1-10.6L190.4,13.6z"
|
||||
id="path9822" />
|
||||
</g>
|
||||
<path
|
||||
class="st0"
|
||||
d="M229,34.9h4.7L226,13.6h-4.3l-7.7,21.2h4.6l1.6-4.5h7.1L229,34.9z M223.9,20.3 C223.9,20.3,223.9,20.3,223.9,20.3C223.9,20.2,223.9,20.2,223.9,20.3l2.2,6.2h-4.4L223.9,20.3z"
|
||||
id="path9826" />
|
||||
</g>
|
||||
<g
|
||||
id="g9832">
|
||||
<path
|
||||
class="st1"
|
||||
d="M259.5,11.2h3.9v1h-1.3v3.1h-1.3v-3.1h-1.3V11.2L259.5,11.2z M264,11.2h1.7l0.6,2.5l0.6-2.5h1.7v4.1h-1v-3.1 l-0.8,3.1h-0.9l-0.8-3.1v3.1h-1V11.2L264,11.2z"
|
||||
id="path9830" />
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 4.7 KiB |
@@ -0,0 +1,2 @@
|
||||
library: spqlios-fft
|
||||
version: 2.0.0
|
||||
27
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/scripts/auto-release.sh
Executable file
27
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/scripts/auto-release.sh
Executable file
@@ -0,0 +1,27 @@
|
||||
#!/bin/sh
|
||||
|
||||
# this script generates one tag if there is a version change in manifest.yaml
|
||||
cd `dirname $0`/..
|
||||
if [ "v$1" = "v-y" ]; then
|
||||
echo "production mode!";
|
||||
fi
|
||||
changes=`git diff HEAD~1..HEAD -- manifest.yaml | grep 'version:'`
|
||||
oldversion=$(echo "$changes" | grep '^-version:' | cut '-d ' -f2)
|
||||
version=$(echo "$changes" | grep '^+version:' | cut '-d ' -f2)
|
||||
echo "Versions: $oldversion --> $version"
|
||||
if [ "v$oldversion" = "v$version" ]; then
|
||||
echo "Same version - nothing to do"; exit 0;
|
||||
fi
|
||||
if [ "v$1" = "v-y" ]; then
|
||||
git config user.name github-actions
|
||||
git config user.email github-actions@github.com
|
||||
git tag -a "v$version" -m "Version $version"
|
||||
git push origin "v$version"
|
||||
else
|
||||
cat <<EOF
|
||||
# the script would do:
|
||||
git tag -a "v$version" -m "Version $version"
|
||||
git push origin "v$version"
|
||||
EOF
|
||||
fi
|
||||
|
||||
102
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/scripts/ci-pkg
Executable file
102
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/scripts/ci-pkg
Executable file
@@ -0,0 +1,102 @@
|
||||
#!/bin/sh
|
||||
|
||||
# ONLY USE A PREFIX YOU ARE CONFIDENT YOU CAN WIPE OUT ENTIRELY
|
||||
CI_INSTALL_PREFIX=/opt/spqlios
|
||||
CI_REPO_URL=https://spq-dav.algonics.net/ci
|
||||
WORKDIR=`pwd`
|
||||
if [ "x$DESTDIR" = "x" ]; then
|
||||
DESTDIR=/
|
||||
else
|
||||
mkdir -p $DESTDIR
|
||||
DESTDIR=`realpath $DESTDIR`
|
||||
fi
|
||||
DIR=`dirname "$0"`
|
||||
cd $DIR/..
|
||||
DIR=`pwd`
|
||||
|
||||
FULL_UNAME=`uname -a | tr '[A-Z]' '[a-z]'`
|
||||
HOST=`echo $FULL_UNAME | sed 's/ .*//'`
|
||||
ARCH=none
|
||||
case "$HOST" in
|
||||
*linux*)
|
||||
DISTRIB=`lsb_release -c | awk '{print $2}' | tr '[A-Z]' '[a-z]'`
|
||||
HOST=linux-$DISTRIB
|
||||
;;
|
||||
*darwin*)
|
||||
HOST=darwin
|
||||
;;
|
||||
*mingw*|*msys*)
|
||||
DISTRIB=`echo $MSYSTEM | tr '[A-Z]' '[a-z]'`
|
||||
HOST=msys64-$DISTRIB
|
||||
;;
|
||||
*)
|
||||
echo "Host unknown: $HOST";
|
||||
exit 1
|
||||
esac
|
||||
case "$FULL_UNAME" in
|
||||
*x86_64*)
|
||||
ARCH=x86_64
|
||||
;;
|
||||
*aarch64*)
|
||||
ARCH=aarch64
|
||||
;;
|
||||
*arm64*)
|
||||
ARCH=arm64
|
||||
;;
|
||||
*)
|
||||
echo "Architecture unknown: $FULL_UNAME";
|
||||
exit 1
|
||||
esac
|
||||
UNAME="$HOST-$ARCH"
|
||||
CMH=
|
||||
if [ -d lib/spqlios/.git ]; then
|
||||
CMH=`git submodule status | sed 's/\(..........\).*/\1/'`
|
||||
else
|
||||
CMH=`git rev-parse HEAD | sed 's/\(..........\).*/\1/'`
|
||||
fi
|
||||
FNAME=spqlios-arithmetic-$CMH-$UNAME.tar.gz
|
||||
|
||||
cat <<EOF
|
||||
================= CI MINI-PACKAGER ==================
|
||||
Work Dir: WORKDIR=$WORKDIR
|
||||
Spq Dir: DIR=$DIR
|
||||
Install Root: DESTDIR=$DESTDIR
|
||||
Install Prefix: CI_INSTALL_PREFIX=$CI_INSTALL_PREFIX
|
||||
Archive Name: FNAME=$FNAME
|
||||
CI WebDav: CI_REPO_URL=$CI_REPO_URL
|
||||
=====================================================
|
||||
EOF
|
||||
|
||||
if [ "x$1" = "xcreate" ]; then
|
||||
rm -rf dist
|
||||
cmake -B build -S . -DCMAKE_INSTALL_PREFIX="$CI_INSTALL_PREFIX" -DCMAKE_BUILD_TYPE=Release -DENABLE_TESTING=ON -DWARNING_PARANOID=ON -DDEVMODE_INSTALL=ON || exit 1
|
||||
cmake --build build || exit 1
|
||||
rm -rf "$DIR/dist" 2>/dev/null
|
||||
rm -f "$DIR/$FNAME" 2>/dev/null
|
||||
DESTDIR="$DIR/dist" cmake --install build || exit 1
|
||||
if [ -d "$DIR/dist$CI_INSTALL_PREFIX" ]; then
|
||||
tar -C "$DIR/dist" -cvzf "$DIR/$FNAME" .
|
||||
else
|
||||
# fix since msys can mess up the paths
|
||||
REAL_DEST=`find "$DIR/dist" -type d -exec test -d "{}$CI_INSTALL_PREFIX" \; -print`
|
||||
echo "REAL_DEST: $REAL_DEST"
|
||||
[ -d "$REAL_DEST$CI_INSTALL_PREFIX" ] && tar -C "$REAL_DEST" -cvzf "$DIR/$FNAME" .
|
||||
fi
|
||||
[ -f "$DIR/$FNAME" ] || { echo "failed to create $DIR/$FNAME"; exit 1; }
|
||||
[ "x$CI_CREDS" = "x" ] && { echo "CI_CREDS is not set: not uploading"; exit 1; }
|
||||
curl -u "$CI_CREDS" -T "$DIR/$FNAME" "$CI_REPO_URL/$FNAME"
|
||||
fi
|
||||
|
||||
if [ "x$1" = "xinstall" ]; then
|
||||
[ "x$CI_CREDS" = "x" ] && { echo "CI_CREDS is not set: not downloading"; exit 1; }
|
||||
# cleaning
|
||||
rm -rf "$DESTDIR$CI_INSTALL_PREFIX"/* 2>/dev/null
|
||||
rm -f "$DIR/$FNAME" 2>/dev/null
|
||||
# downloading
|
||||
curl -u "$CI_CREDS" -o "$DIR/$FNAME" "$CI_REPO_URL/$FNAME"
|
||||
[ -f "$DIR/$FNAME" ] || { echo "failed to download $DIR/$FNAME"; exit 0; }
|
||||
# installing
|
||||
mkdir -p $DESTDIR
|
||||
tar -C "$DESTDIR" -xvzf "$DIR/$FNAME"
|
||||
exit 0
|
||||
fi
|
||||
181
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/scripts/prepare-release
Executable file
181
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/scripts/prepare-release
Executable file
@@ -0,0 +1,181 @@
|
||||
#!/usr/bin/perl
|
||||
##
|
||||
## This script will help update manifest.yaml and Changelog.md before a release
|
||||
## Any merge to master that changes the version line in manifest.yaml
|
||||
## is considered as a new release.
|
||||
##
|
||||
## When ready to make a release, please run ./scripts/prepare-release
|
||||
## and commit push the final result!
|
||||
use File::Basename;
|
||||
use Cwd 'abs_path';
|
||||
|
||||
# find its way to the root of git's repository
|
||||
my $scriptsdirname = dirname(abs_path(__FILE__));
|
||||
chdir "$scriptsdirname/..";
|
||||
print "✓ Entering directory:".`pwd`;
|
||||
|
||||
# ensures that the current branch is ahead of origin/main
|
||||
my $diff= `git diff`;
|
||||
chop $diff;
|
||||
if ($diff =~ /./) {
|
||||
die("ERROR: Please commit all the changes before calling the prepare-release script.");
|
||||
} else {
|
||||
print("✓ All changes are comitted.\n");
|
||||
}
|
||||
system("git fetch origin");
|
||||
my $vcount = `git rev-list --left-right --count origin/main...HEAD`;
|
||||
$vcount =~ /^([0-9]+)[ \t]*([0-9]+)$/;
|
||||
if ($2>0) {
|
||||
die("ERROR: the current HEAD is not ahead of origin/main\n. Please use git merge origin/main.");
|
||||
} else {
|
||||
print("✓ Current HEAD is up to date with origin/main.\n");
|
||||
}
|
||||
|
||||
mkdir ".changes";
|
||||
my $currentbranch = `git rev-parse --abbrev-ref HEAD`;
|
||||
chop $currentbranch;
|
||||
$currentbranch =~ s/[^a-zA-Z._-]+/-/g;
|
||||
my $changefile=".changes/$currentbranch.md";
|
||||
my $origmanifestfile=".changes/$currentbranch--manifest.yaml";
|
||||
my $origchangelogfile=".changes/$currentbranch--Changelog.md";
|
||||
|
||||
my $exit_code=system("wget -O $origmanifestfile https://raw.githubusercontent.com/tfhe/spqlios-fft/main/manifest.yaml");
|
||||
if ($exit_code!=0 or ! -f $origmanifestfile) {
|
||||
die("ERROR: failed to download manifest.yaml");
|
||||
}
|
||||
$exit_code=system("wget -O $origchangelogfile https://raw.githubusercontent.com/tfhe/spqlios-fft/main/Changelog.md");
|
||||
if ($exit_code!=0 or ! -f $origchangelogfile) {
|
||||
die("ERROR: failed to download Changelog.md");
|
||||
}
|
||||
|
||||
# read the current version (from origin/main manifest)
|
||||
my $vmajor = 0;
|
||||
my $vminor = 0;
|
||||
my $vpatch = 0;
|
||||
my $versionline = `grep '^version: ' $origmanifestfile | cut -d" " -f2`;
|
||||
chop $versionline;
|
||||
if (not $versionline =~ /^([0-9]+)\.([0-9]+)\.([0-9]+)$/) {
|
||||
die("ERROR: invalid version in manifest file: $versionline\n");
|
||||
} else {
|
||||
$vmajor = int($1);
|
||||
$vminor = int($2);
|
||||
$vpatch = int($3);
|
||||
}
|
||||
print "Version in manifest file: $vmajor.$vminor.$vpatch\n";
|
||||
|
||||
if (not -f $changefile) {
|
||||
## create a changes file
|
||||
open F,">$changefile";
|
||||
print F "# Changefile for branch $currentbranch\n\n";
|
||||
print F "## Type of release (major,minor,patch)?\n\n";
|
||||
print F "releasetype: patch\n\n";
|
||||
print F "## What has changed (please edit)?\n\n";
|
||||
print F "- This has changed.\n";
|
||||
close F;
|
||||
}
|
||||
|
||||
system("editor $changefile");
|
||||
|
||||
# compute the new version
|
||||
my $nvmajor;
|
||||
my $nvminor;
|
||||
my $nvpatch;
|
||||
my $changelog;
|
||||
my $recordchangelog=0;
|
||||
open F,"$changefile";
|
||||
while ($line=<F>) {
|
||||
chop $line;
|
||||
if ($recordchangelog) {
|
||||
($line =~ /^$/) and next;
|
||||
$changelog .= "$line\n";
|
||||
next;
|
||||
}
|
||||
if ($line =~ /^releasetype *: *patch *$/) {
|
||||
$nvmajor=$vmajor;
|
||||
$nvminor=$vminor;
|
||||
$nvpatch=$vpatch+1;
|
||||
}
|
||||
if ($line =~ /^releasetype *: *minor *$/) {
|
||||
$nvmajor=$vmajor;
|
||||
$nvminor=$vminor+1;
|
||||
$nvpatch=0;
|
||||
}
|
||||
if ($line =~ /^releasetype *: *major *$/) {
|
||||
$nvmajor=$vmajor+1;
|
||||
$nvminor=0;
|
||||
$nvpatch=0;
|
||||
}
|
||||
if ($line =~ /^## What has changed/) {
|
||||
$recordchangelog=1;
|
||||
}
|
||||
}
|
||||
close F;
|
||||
print "New version: $nvmajor.$nvminor.$nvpatch\n";
|
||||
print "Changes:\n$changelog";
|
||||
|
||||
# updating manifest.yaml
|
||||
open F,"manifest.yaml";
|
||||
open G,">.changes/manifest.yaml";
|
||||
while ($line=<F>) {
|
||||
if ($line =~ /^version *: */) {
|
||||
print G "version: $nvmajor.$nvminor.$nvpatch\n";
|
||||
next;
|
||||
}
|
||||
print G $line;
|
||||
}
|
||||
close F;
|
||||
close G;
|
||||
# updating Changelog.md
|
||||
open F,"$origchangelogfile";
|
||||
open G,">.changes/Changelog.md";
|
||||
print G <<EOF
|
||||
# Changelog
|
||||
|
||||
All notable changes to this project will be documented in this file.
|
||||
this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
EOF
|
||||
;
|
||||
print G "## [$nvmajor.$nvminor.$nvpatch] - ".`date '+%Y-%m-%d'`."\n";
|
||||
print G "$changelog\n";
|
||||
my $skip_section=1;
|
||||
while ($line=<F>) {
|
||||
if ($line =~ /^## +\[([0-9]+)\.([0-9]+)\.([0-9]+)\] +/) {
|
||||
if ($1>$nvmajor) {
|
||||
die("ERROR: found larger version $1.$2.$3 in the Changelog.md\n");
|
||||
} elsif ($1<$nvmajor) {
|
||||
$skip_section=0;
|
||||
} elsif ($2>$nvminor) {
|
||||
die("ERROR: found larger version $1.$2.$3 in the Changelog.md\n");
|
||||
} elsif ($2<$nvminor) {
|
||||
$skip_section=0;
|
||||
} elsif ($3>$nvpatch) {
|
||||
die("ERROR: found larger version $1.$2.$3 in the Changelog.md\n");
|
||||
} elsif ($2<$nvpatch) {
|
||||
$skip_section=0;
|
||||
} else {
|
||||
$skip_section=1;
|
||||
}
|
||||
}
|
||||
($skip_section) and next;
|
||||
print G $line;
|
||||
}
|
||||
close F;
|
||||
close G;
|
||||
|
||||
print "-------------------------------------\n";
|
||||
print "THIS WILL BE UPDATED:\n";
|
||||
print "-------------------------------------\n";
|
||||
system("diff -u manifest.yaml .changes/manifest.yaml");
|
||||
system("diff -u Changelog.md .changes/Changelog.md");
|
||||
print "-------------------------------------\n";
|
||||
print "To proceed: press <enter> otherwise <CTRL+C>\n";
|
||||
my $bla;
|
||||
$bla=<STDIN>;
|
||||
system("cp -vf .changes/manifest.yaml manifest.yaml");
|
||||
system("cp -vf .changes/Changelog.md Changelog.md");
|
||||
system("git commit -a -m \"Update version and changelog.\"");
|
||||
system("git push");
|
||||
print("✓ Changes have been committed and pushed!\n");
|
||||
print("✓ A new release will be created when this branch is merged to main.\n");
|
||||
|
||||
@@ -0,0 +1,223 @@
|
||||
enable_language(ASM)
|
||||
|
||||
# C source files that are compiled for all targets (i.e. reference code)
|
||||
set(SRCS_GENERIC
|
||||
commons.c
|
||||
commons_private.c
|
||||
coeffs/coeffs_arithmetic.c
|
||||
arithmetic/vec_znx.c
|
||||
arithmetic/vec_znx_dft.c
|
||||
arithmetic/vector_matrix_product.c
|
||||
cplx/cplx_common.c
|
||||
cplx/cplx_conversions.c
|
||||
cplx/cplx_fft_asserts.c
|
||||
cplx/cplx_fft_ref.c
|
||||
cplx/cplx_fftvec_ref.c
|
||||
cplx/cplx_ifft_ref.c
|
||||
cplx/spqlios_cplx_fft.c
|
||||
reim4/reim4_arithmetic_ref.c
|
||||
reim4/reim4_fftvec_addmul_ref.c
|
||||
reim4/reim4_fftvec_conv_ref.c
|
||||
reim/reim_conversions.c
|
||||
reim/reim_fft_ifft.c
|
||||
reim/reim_fft_ref.c
|
||||
reim/reim_fftvec_ref.c
|
||||
reim/reim_ifft_ref.c
|
||||
reim/reim_ifft_ref.c
|
||||
reim/reim_to_tnx_ref.c
|
||||
q120/q120_ntt.c
|
||||
q120/q120_arithmetic_ref.c
|
||||
q120/q120_arithmetic_simple.c
|
||||
arithmetic/scalar_vector_product.c
|
||||
arithmetic/vec_znx_big.c
|
||||
arithmetic/znx_small.c
|
||||
arithmetic/module_api.c
|
||||
arithmetic/zn_vmp_int8_ref.c
|
||||
arithmetic/zn_vmp_int16_ref.c
|
||||
arithmetic/zn_vmp_int32_ref.c
|
||||
arithmetic/zn_vmp_ref.c
|
||||
arithmetic/zn_api.c
|
||||
arithmetic/zn_conversions_ref.c
|
||||
arithmetic/zn_approxdecomp_ref.c
|
||||
arithmetic/vec_rnx_api.c
|
||||
arithmetic/vec_rnx_conversions_ref.c
|
||||
arithmetic/vec_rnx_svp_ref.c
|
||||
reim/reim_execute.c
|
||||
cplx/cplx_execute.c
|
||||
reim4/reim4_execute.c
|
||||
arithmetic/vec_rnx_arithmetic.c
|
||||
arithmetic/vec_rnx_approxdecomp_ref.c
|
||||
arithmetic/vec_rnx_vmp_ref.c
|
||||
)
|
||||
# C or assembly source files compiled only on x86 targets
|
||||
set(SRCS_X86
|
||||
)
|
||||
# C or assembly source files compiled only on aarch64 targets
|
||||
set(SRCS_AARCH64
|
||||
cplx/cplx_fallbacks_aarch64.c
|
||||
reim/reim_fallbacks_aarch64.c
|
||||
reim4/reim4_fallbacks_aarch64.c
|
||||
q120/q120_fallbacks_aarch64.c
|
||||
reim/reim_fft_neon.c
|
||||
)
|
||||
|
||||
# C or assembly source files compiled only on x86: avx, avx2, fma targets
|
||||
set(SRCS_FMA_C
|
||||
arithmetic/vector_matrix_product_avx.c
|
||||
cplx/cplx_conversions_avx2_fma.c
|
||||
cplx/cplx_fft_avx2_fma.c
|
||||
cplx/cplx_fft_sse.c
|
||||
cplx/cplx_fftvec_avx2_fma.c
|
||||
cplx/cplx_ifft_avx2_fma.c
|
||||
reim4/reim4_arithmetic_avx2.c
|
||||
reim4/reim4_fftvec_conv_fma.c
|
||||
reim4/reim4_fftvec_addmul_fma.c
|
||||
reim/reim_conversions_avx.c
|
||||
reim/reim_fft4_avx_fma.c
|
||||
reim/reim_fft8_avx_fma.c
|
||||
reim/reim_ifft4_avx_fma.c
|
||||
reim/reim_ifft8_avx_fma.c
|
||||
reim/reim_fft_avx2.c
|
||||
reim/reim_ifft_avx2.c
|
||||
reim/reim_to_tnx_avx.c
|
||||
reim/reim_fftvec_fma.c
|
||||
)
|
||||
set(SRCS_FMA_ASM
|
||||
cplx/cplx_fft16_avx_fma.s
|
||||
cplx/cplx_ifft16_avx_fma.s
|
||||
reim/reim_fft16_avx_fma.s
|
||||
reim/reim_ifft16_avx_fma.s
|
||||
)
|
||||
set(SRCS_FMA_WIN32_ASM
|
||||
cplx/cplx_fft16_avx_fma_win32.s
|
||||
cplx/cplx_ifft16_avx_fma_win32.s
|
||||
reim/reim_fft16_avx_fma_win32.s
|
||||
reim/reim_ifft16_avx_fma_win32.s
|
||||
)
|
||||
set_source_files_properties(${SRCS_FMA_C} PROPERTIES COMPILE_OPTIONS "-mfma;-mavx;-mavx2")
|
||||
set_source_files_properties(${SRCS_FMA_ASM} PROPERTIES COMPILE_OPTIONS "-mfma;-mavx;-mavx2")
|
||||
|
||||
# C or assembly source files compiled only on x86: avx512f/vl/dq + fma targets
|
||||
set(SRCS_AVX512
|
||||
cplx/cplx_fft_avx512.c
|
||||
)
|
||||
set_source_files_properties(${SRCS_AVX512} PROPERTIES COMPILE_OPTIONS "-mfma;-mavx512f;-mavx512vl;-mavx512dq")
|
||||
|
||||
# C or assembly source files compiled only on x86: avx2 + bmi targets
|
||||
set(SRCS_AVX2
|
||||
arithmetic/vec_znx_avx.c
|
||||
coeffs/coeffs_arithmetic_avx.c
|
||||
arithmetic/vec_znx_dft_avx2.c
|
||||
arithmetic/zn_vmp_int8_avx.c
|
||||
arithmetic/zn_vmp_int16_avx.c
|
||||
arithmetic/zn_vmp_int32_avx.c
|
||||
q120/q120_arithmetic_avx2.c
|
||||
q120/q120_ntt_avx2.c
|
||||
arithmetic/vec_rnx_arithmetic_avx.c
|
||||
arithmetic/vec_rnx_approxdecomp_avx.c
|
||||
arithmetic/vec_rnx_vmp_avx.c
|
||||
|
||||
)
|
||||
set_source_files_properties(${SRCS_AVX2} PROPERTIES COMPILE_OPTIONS "-mbmi2;-mavx2")
|
||||
|
||||
# C source files on float128 via libquadmath on x86 targets targets
|
||||
set(SRCS_F128
|
||||
cplx_f128/cplx_fft_f128.c
|
||||
cplx_f128/cplx_fft_f128.h
|
||||
)
|
||||
|
||||
# H header files containing the public API (these headers are installed)
|
||||
set(HEADERSPUBLIC
|
||||
commons.h
|
||||
arithmetic/vec_znx_arithmetic.h
|
||||
arithmetic/vec_rnx_arithmetic.h
|
||||
arithmetic/zn_arithmetic.h
|
||||
cplx/cplx_fft.h
|
||||
reim/reim_fft.h
|
||||
q120/q120_common.h
|
||||
q120/q120_arithmetic.h
|
||||
q120/q120_ntt.h
|
||||
)
|
||||
|
||||
# H header files containing the private API (these headers are used internally)
|
||||
set(HEADERSPRIVATE
|
||||
commons_private.h
|
||||
cplx/cplx_fft_internal.h
|
||||
cplx/cplx_fft_private.h
|
||||
reim4/reim4_arithmetic.h
|
||||
reim4/reim4_fftvec_internal.h
|
||||
reim4/reim4_fftvec_private.h
|
||||
reim4/reim4_fftvec_public.h
|
||||
reim/reim_fft_internal.h
|
||||
reim/reim_fft_private.h
|
||||
q120/q120_arithmetic_private.h
|
||||
q120/q120_ntt_private.h
|
||||
arithmetic/vec_znx_arithmetic.h
|
||||
arithmetic/vec_rnx_arithmetic_private.h
|
||||
arithmetic/vec_rnx_arithmetic_plugin.h
|
||||
arithmetic/zn_arithmetic_private.h
|
||||
arithmetic/zn_arithmetic_plugin.h
|
||||
coeffs/coeffs_arithmetic.h
|
||||
reim/reim_fft_core_template.h
|
||||
)
|
||||
|
||||
set(SPQLIOSSOURCES
|
||||
${SRCS_GENERIC}
|
||||
${HEADERSPUBLIC}
|
||||
${HEADERSPRIVATE}
|
||||
)
|
||||
if (${X86})
|
||||
set(SPQLIOSSOURCES ${SPQLIOSSOURCES}
|
||||
${SRCS_X86}
|
||||
${SRCS_FMA_C}
|
||||
${SRCS_FMA_ASM}
|
||||
${SRCS_AVX2}
|
||||
${SRCS_AVX512}
|
||||
)
|
||||
elseif (${X86_WIN32})
|
||||
set(SPQLIOSSOURCES ${SPQLIOSSOURCES}
|
||||
#${SRCS_X86}
|
||||
${SRCS_FMA_C}
|
||||
${SRCS_FMA_WIN32_ASM}
|
||||
${SRCS_AVX2}
|
||||
${SRCS_AVX512}
|
||||
)
|
||||
elseif (${AARCH64})
|
||||
set(SPQLIOSSOURCES ${SPQLIOSSOURCES}
|
||||
${SRCS_AARCH64}
|
||||
)
|
||||
endif ()
|
||||
|
||||
|
||||
set(SPQLIOSLIBDEP
|
||||
m # libmath depencency for cosinus/sinus functions
|
||||
)
|
||||
|
||||
if (ENABLE_SPQLIOS_F128)
|
||||
find_library(quadmath REQUIRED NAMES quadmath)
|
||||
set(SPQLIOSSOURCES ${SPQLIOSSOURCES} ${SRCS_F128})
|
||||
set(SPQLIOSLIBDEP ${SPQLIOSLIBDEP} quadmath)
|
||||
endif (ENABLE_SPQLIOS_F128)
|
||||
|
||||
add_library(libspqlios-static STATIC ${SPQLIOSSOURCES})
|
||||
add_library(libspqlios SHARED ${SPQLIOSSOURCES})
|
||||
set_property(TARGET libspqlios-static PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET libspqlios PROPERTY OUTPUT_NAME spqlios)
|
||||
set_property(TARGET libspqlios-static PROPERTY OUTPUT_NAME spqlios)
|
||||
set_property(TARGET libspqlios PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET libspqlios PROPERTY SOVERSION ${SPQLIOS_VERSION_MAJOR})
|
||||
set_property(TARGET libspqlios PROPERTY VERSION ${SPQLIOS_VERSION})
|
||||
if (NOT APPLE)
|
||||
target_link_options(libspqlios-static PUBLIC -Wl,--no-undefined)
|
||||
target_link_options(libspqlios PUBLIC -Wl,--no-undefined)
|
||||
endif()
|
||||
target_link_libraries(libspqlios ${SPQLIOSLIBDEP})
|
||||
target_link_libraries(libspqlios-static ${SPQLIOSLIBDEP})
|
||||
install(TARGETS libspqlios-static)
|
||||
install(TARGETS libspqlios)
|
||||
|
||||
# install the public headers only
|
||||
foreach (file ${HEADERSPUBLIC})
|
||||
get_filename_component(dir ${file} DIRECTORY)
|
||||
install(FILES ${file} DESTINATION include/spqlios/${dir})
|
||||
endforeach ()
|
||||
@@ -0,0 +1,172 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "vec_znx_arithmetic_private.h"
|
||||
|
||||
static void fill_generic_virtual_table(MODULE* module) {
|
||||
// TODO add default ref handler here
|
||||
module->func.vec_znx_zero = vec_znx_zero_ref;
|
||||
module->func.vec_znx_copy = vec_znx_copy_ref;
|
||||
module->func.vec_znx_negate = vec_znx_negate_ref;
|
||||
module->func.vec_znx_add = vec_znx_add_ref;
|
||||
module->func.vec_znx_sub = vec_znx_sub_ref;
|
||||
module->func.vec_znx_rotate = vec_znx_rotate_ref;
|
||||
module->func.vec_znx_mul_xp_minus_one = vec_znx_mul_xp_minus_one_ref;
|
||||
module->func.vec_znx_automorphism = vec_znx_automorphism_ref;
|
||||
module->func.vec_znx_normalize_base2k = vec_znx_normalize_base2k_ref;
|
||||
module->func.vec_znx_normalize_base2k_tmp_bytes = vec_znx_normalize_base2k_tmp_bytes_ref;
|
||||
if (CPU_SUPPORTS("avx2")) {
|
||||
// TODO add avx handlers here
|
||||
module->func.vec_znx_negate = vec_znx_negate_avx;
|
||||
module->func.vec_znx_add = vec_znx_add_avx;
|
||||
module->func.vec_znx_sub = vec_znx_sub_avx;
|
||||
}
|
||||
}
|
||||
|
||||
static void fill_fft64_virtual_table(MODULE* module) {
|
||||
// TODO add default ref handler here
|
||||
// module->func.vec_znx_dft = ...;
|
||||
module->func.vec_znx_big_normalize_base2k = fft64_vec_znx_big_normalize_base2k;
|
||||
module->func.vec_znx_big_normalize_base2k_tmp_bytes = fft64_vec_znx_big_normalize_base2k_tmp_bytes;
|
||||
module->func.vec_znx_big_range_normalize_base2k = fft64_vec_znx_big_range_normalize_base2k;
|
||||
module->func.vec_znx_big_range_normalize_base2k_tmp_bytes = fft64_vec_znx_big_range_normalize_base2k_tmp_bytes;
|
||||
module->func.vec_znx_dft = fft64_vec_znx_dft;
|
||||
module->func.vec_znx_idft = fft64_vec_znx_idft;
|
||||
module->func.vec_dft_add = fft64_vec_dft_add;
|
||||
module->func.vec_dft_sub = fft64_vec_dft_sub;
|
||||
module->func.vec_znx_idft_tmp_bytes = fft64_vec_znx_idft_tmp_bytes;
|
||||
module->func.vec_znx_idft_tmp_a = fft64_vec_znx_idft_tmp_a;
|
||||
module->func.vec_znx_big_add = fft64_vec_znx_big_add;
|
||||
module->func.vec_znx_big_add_small = fft64_vec_znx_big_add_small;
|
||||
module->func.vec_znx_big_add_small2 = fft64_vec_znx_big_add_small2;
|
||||
module->func.vec_znx_big_sub = fft64_vec_znx_big_sub;
|
||||
module->func.vec_znx_big_sub_small_a = fft64_vec_znx_big_sub_small_a;
|
||||
module->func.vec_znx_big_sub_small_b = fft64_vec_znx_big_sub_small_b;
|
||||
module->func.vec_znx_big_sub_small2 = fft64_vec_znx_big_sub_small2;
|
||||
module->func.vec_znx_big_rotate = fft64_vec_znx_big_rotate;
|
||||
module->func.vec_znx_big_automorphism = fft64_vec_znx_big_automorphism;
|
||||
module->func.svp_prepare = fft64_svp_prepare_ref;
|
||||
module->func.svp_apply_dft = fft64_svp_apply_dft_ref;
|
||||
module->func.svp_apply_dft_to_dft = fft64_svp_apply_dft_to_dft_ref;
|
||||
module->func.znx_small_single_product = fft64_znx_small_single_product;
|
||||
module->func.znx_small_single_product_tmp_bytes = fft64_znx_small_single_product_tmp_bytes;
|
||||
module->func.vmp_prepare_contiguous = fft64_vmp_prepare_contiguous_ref;
|
||||
module->func.vmp_prepare_tmp_bytes = fft64_vmp_prepare_tmp_bytes;
|
||||
module->func.vmp_apply_dft = fft64_vmp_apply_dft_ref;
|
||||
module->func.vmp_apply_dft_add = fft64_vmp_apply_dft_add_ref;
|
||||
module->func.vmp_apply_dft_tmp_bytes = fft64_vmp_apply_dft_tmp_bytes;
|
||||
module->func.vmp_apply_dft_to_dft = fft64_vmp_apply_dft_to_dft_ref;
|
||||
module->func.vmp_apply_dft_to_dft_add = fft64_vmp_apply_dft_to_dft_add_ref;
|
||||
module->func.vmp_apply_dft_to_dft_tmp_bytes = fft64_vmp_apply_dft_to_dft_tmp_bytes;
|
||||
module->func.bytes_of_vec_znx_dft = fft64_bytes_of_vec_znx_dft;
|
||||
module->func.bytes_of_vec_znx_big = fft64_bytes_of_vec_znx_big;
|
||||
module->func.bytes_of_svp_ppol = fft64_bytes_of_svp_ppol;
|
||||
module->func.bytes_of_vmp_pmat = fft64_bytes_of_vmp_pmat;
|
||||
if (CPU_SUPPORTS("avx2")) {
|
||||
// TODO add avx handlers here
|
||||
// TODO: enable when avx implementation is done
|
||||
module->func.vmp_prepare_contiguous = fft64_vmp_prepare_contiguous_avx;
|
||||
module->func.vmp_apply_dft = fft64_vmp_apply_dft_avx;
|
||||
module->func.vmp_apply_dft_add = fft64_vmp_apply_dft_add_avx;
|
||||
module->func.vmp_apply_dft_to_dft = fft64_vmp_apply_dft_to_dft_avx;
|
||||
module->func.vmp_apply_dft_to_dft_add = fft64_vmp_apply_dft_to_dft_add_avx;
|
||||
}
|
||||
}
|
||||
|
||||
static void fill_ntt120_virtual_table(MODULE* module) {
|
||||
// TODO add default ref handler here
|
||||
// module->func.vec_znx_dft = ...;
|
||||
if (CPU_SUPPORTS("avx2")) {
|
||||
// TODO add avx handlers here
|
||||
module->func.vec_znx_dft = ntt120_vec_znx_dft_avx;
|
||||
module->func.vec_znx_idft = ntt120_vec_znx_idft_avx;
|
||||
module->func.vec_znx_idft_tmp_bytes = ntt120_vec_znx_idft_tmp_bytes_avx;
|
||||
module->func.vec_znx_idft_tmp_a = ntt120_vec_znx_idft_tmp_a_avx;
|
||||
}
|
||||
}
|
||||
|
||||
static void fill_virtual_table(MODULE* module) {
|
||||
fill_generic_virtual_table(module);
|
||||
switch (module->module_type) {
|
||||
case FFT64:
|
||||
fill_fft64_virtual_table(module);
|
||||
break;
|
||||
case NTT120:
|
||||
fill_ntt120_virtual_table(module);
|
||||
break;
|
||||
default:
|
||||
NOT_SUPPORTED(); // invalid type
|
||||
}
|
||||
}
|
||||
|
||||
static void fill_fft64_precomp(MODULE* module) {
|
||||
// fill any necessary precomp stuff
|
||||
module->mod.fft64.p_conv = new_reim_from_znx64_precomp(module->m, 50);
|
||||
module->mod.fft64.p_fft = new_reim_fft_precomp(module->m, 0);
|
||||
module->mod.fft64.p_reim_to_znx = new_reim_to_znx64_precomp(module->m, module->m, 63);
|
||||
module->mod.fft64.p_ifft = new_reim_ifft_precomp(module->m, 0);
|
||||
module->mod.fft64.p_addmul = new_reim_fftvec_addmul_precomp(module->m);
|
||||
module->mod.fft64.mul_fft = new_reim_fftvec_mul_precomp(module->m);
|
||||
module->mod.fft64.add_fft = new_reim_fftvec_add_precomp(module->m);
|
||||
module->mod.fft64.sub_fft = new_reim_fftvec_sub_precomp(module->m);
|
||||
}
|
||||
static void fill_ntt120_precomp(MODULE* module) {
|
||||
// fill any necessary precomp stuff
|
||||
if (CPU_SUPPORTS("avx2")) {
|
||||
module->mod.q120.p_ntt = q120_new_ntt_bb_precomp(module->nn);
|
||||
module->mod.q120.p_intt = q120_new_intt_bb_precomp(module->nn);
|
||||
}
|
||||
}
|
||||
|
||||
static void fill_module_precomp(MODULE* module) {
|
||||
switch (module->module_type) {
|
||||
case FFT64:
|
||||
fill_fft64_precomp(module);
|
||||
break;
|
||||
case NTT120:
|
||||
fill_ntt120_precomp(module);
|
||||
break;
|
||||
default:
|
||||
NOT_SUPPORTED(); // invalid type
|
||||
}
|
||||
}
|
||||
|
||||
static void fill_module(MODULE* module, uint64_t nn, MODULE_TYPE mtype) {
|
||||
// init to zero to ensure that any non-initialized field bug is detected
|
||||
// by at least a "proper" segfault
|
||||
memset(module, 0, sizeof(MODULE));
|
||||
module->module_type = mtype;
|
||||
module->nn = nn;
|
||||
module->m = nn >> 1;
|
||||
fill_module_precomp(module);
|
||||
fill_virtual_table(module);
|
||||
}
|
||||
|
||||
EXPORT MODULE* new_module_info(uint64_t N, MODULE_TYPE mtype) {
|
||||
MODULE* m = (MODULE*)malloc(sizeof(MODULE));
|
||||
fill_module(m, N, mtype);
|
||||
return m;
|
||||
}
|
||||
|
||||
EXPORT void delete_module_info(MODULE* mod) {
|
||||
switch (mod->module_type) {
|
||||
case FFT64:
|
||||
free(mod->mod.fft64.p_conv);
|
||||
free(mod->mod.fft64.p_fft);
|
||||
free(mod->mod.fft64.p_ifft);
|
||||
free(mod->mod.fft64.p_reim_to_znx);
|
||||
free(mod->mod.fft64.mul_fft);
|
||||
free(mod->mod.fft64.p_addmul);
|
||||
break;
|
||||
case NTT120:
|
||||
if (CPU_SUPPORTS("avx2")) {
|
||||
q120_del_ntt_bb_precomp(mod->mod.q120.p_ntt);
|
||||
q120_del_intt_bb_precomp(mod->mod.q120.p_intt);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
free(mod);
|
||||
}
|
||||
|
||||
EXPORT uint64_t module_get_n(const MODULE* module) { return module->nn; }
|
||||
@@ -0,0 +1,102 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "vec_znx_arithmetic_private.h"
|
||||
|
||||
EXPORT uint64_t bytes_of_svp_ppol(const MODULE* module) { return module->func.bytes_of_svp_ppol(module); }
|
||||
|
||||
EXPORT uint64_t fft64_bytes_of_svp_ppol(const MODULE* module) { return module->nn * sizeof(double); }
|
||||
|
||||
EXPORT SVP_PPOL* new_svp_ppol(const MODULE* module) { return spqlios_alloc(bytes_of_svp_ppol(module)); }
|
||||
|
||||
EXPORT void delete_svp_ppol(SVP_PPOL* ppol) { spqlios_free(ppol); }
|
||||
|
||||
// public wrappers
|
||||
EXPORT void svp_prepare(const MODULE* module, // N
|
||||
SVP_PPOL* ppol, // output
|
||||
const int64_t* pol // a
|
||||
) {
|
||||
module->func.svp_prepare(module, ppol, pol);
|
||||
}
|
||||
|
||||
/** @brief prepares a svp polynomial */
|
||||
EXPORT void fft64_svp_prepare_ref(const MODULE* module, // N
|
||||
SVP_PPOL* ppol, // output
|
||||
const int64_t* pol // a
|
||||
) {
|
||||
reim_from_znx64(module->mod.fft64.p_conv, ppol, pol);
|
||||
reim_fft(module->mod.fft64.p_fft, (double*)ppol);
|
||||
}
|
||||
|
||||
EXPORT void svp_apply_dft(const MODULE* module, // N
|
||||
const VEC_ZNX_DFT* res, uint64_t res_size, // output
|
||||
const SVP_PPOL* ppol, // prepared pol
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl) {
|
||||
module->func.svp_apply_dft(module, // N
|
||||
res,
|
||||
res_size, // output
|
||||
ppol, // prepared pol
|
||||
a, a_size, a_sl);
|
||||
}
|
||||
|
||||
EXPORT void svp_apply_dft_to_dft(const MODULE* module, // N
|
||||
const VEC_ZNX_DFT* res, uint64_t res_size,
|
||||
uint64_t res_cols, // output
|
||||
const SVP_PPOL* ppol, // prepared pol
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, uint64_t a_cols) {
|
||||
module->func.svp_apply_dft_to_dft(module, // N
|
||||
res, res_size, res_cols, // output
|
||||
ppol, a, a_size, a_cols // prepared pol
|
||||
);
|
||||
}
|
||||
|
||||
// result = ppol * a
|
||||
EXPORT void fft64_svp_apply_dft_ref(const MODULE* module, // N
|
||||
const VEC_ZNX_DFT* res, uint64_t res_size, // output
|
||||
const SVP_PPOL* ppol, // prepared pol
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
double* const dres = (double*)res;
|
||||
double* const dppol = (double*)ppol;
|
||||
|
||||
const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < auto_end_idx; ++i) {
|
||||
const int64_t* a_ptr = a + i * a_sl;
|
||||
double* const res_ptr = dres + i * nn;
|
||||
// copy the polynomial to res, apply fft in place, call fftvec_mul in place.
|
||||
reim_from_znx64(module->mod.fft64.p_conv, res_ptr, a_ptr);
|
||||
reim_fft(module->mod.fft64.p_fft, res_ptr);
|
||||
reim_fftvec_mul(module->mod.fft64.mul_fft, res_ptr, res_ptr, dppol);
|
||||
}
|
||||
|
||||
// then extend with zeros
|
||||
memset(dres + auto_end_idx * nn, 0, (res_size - auto_end_idx) * nn * sizeof(double));
|
||||
}
|
||||
|
||||
// result = ppol * a
|
||||
EXPORT void fft64_svp_apply_dft_to_dft_ref(const MODULE* module, // N
|
||||
const VEC_ZNX_DFT* res, uint64_t res_size,
|
||||
uint64_t res_cols, // output
|
||||
const SVP_PPOL* ppol, // prepared pol
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size,
|
||||
uint64_t a_cols // a
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t res_sl = nn * res_cols;
|
||||
const uint64_t a_sl = nn * a_cols;
|
||||
double* const dres = (double*)res;
|
||||
double* const da = (double*)a;
|
||||
double* const dppol = (double*)ppol;
|
||||
|
||||
const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < auto_end_idx; ++i) {
|
||||
const double* a_ptr = da + i * a_sl;
|
||||
double* const res_ptr = dres + i * res_sl;
|
||||
reim_fftvec_mul(module->mod.fft64.mul_fft, res_ptr, a_ptr, dppol);
|
||||
}
|
||||
|
||||
// then extend with zeros
|
||||
for (uint64_t i = auto_end_idx; i < res_size; i++) {
|
||||
memset(dres + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,344 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "vec_rnx_arithmetic_private.h"
|
||||
|
||||
void fft64_init_rnx_module_precomp(MOD_RNX* module) {
|
||||
// Add here initialization of items that are in the precomp
|
||||
const uint64_t m = module->m;
|
||||
module->precomp.fft64.p_fft = new_reim_fft_precomp(m, 0);
|
||||
module->precomp.fft64.p_ifft = new_reim_ifft_precomp(m, 0);
|
||||
module->precomp.fft64.p_fftvec_add = new_reim_fftvec_add_precomp(m);
|
||||
module->precomp.fft64.p_fftvec_mul = new_reim_fftvec_mul_precomp(m);
|
||||
module->precomp.fft64.p_fftvec_addmul = new_reim_fftvec_addmul_precomp(m);
|
||||
}
|
||||
|
||||
void fft64_finalize_rnx_module_precomp(MOD_RNX* module) {
|
||||
// Add here deleters for items that are in the precomp
|
||||
delete_reim_fft_precomp(module->precomp.fft64.p_fft);
|
||||
delete_reim_ifft_precomp(module->precomp.fft64.p_ifft);
|
||||
delete_reim_fftvec_add_precomp(module->precomp.fft64.p_fftvec_add);
|
||||
delete_reim_fftvec_mul_precomp(module->precomp.fft64.p_fftvec_mul);
|
||||
delete_reim_fftvec_addmul_precomp(module->precomp.fft64.p_fftvec_addmul);
|
||||
}
|
||||
|
||||
void fft64_init_rnx_module_vtable(MOD_RNX* module) {
|
||||
// Add function pointers here
|
||||
module->vtable.vec_rnx_add = vec_rnx_add_ref;
|
||||
module->vtable.vec_rnx_zero = vec_rnx_zero_ref;
|
||||
module->vtable.vec_rnx_copy = vec_rnx_copy_ref;
|
||||
module->vtable.vec_rnx_negate = vec_rnx_negate_ref;
|
||||
module->vtable.vec_rnx_sub = vec_rnx_sub_ref;
|
||||
module->vtable.vec_rnx_rotate = vec_rnx_rotate_ref;
|
||||
module->vtable.vec_rnx_automorphism = vec_rnx_automorphism_ref;
|
||||
module->vtable.vec_rnx_mul_xp_minus_one = vec_rnx_mul_xp_minus_one_ref;
|
||||
module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref;
|
||||
module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_ref;
|
||||
module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref;
|
||||
module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_ref;
|
||||
module->vtable.rnx_vmp_prepare_tmp_bytes = fft64_rnx_vmp_prepare_tmp_bytes_ref;
|
||||
module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_ref;
|
||||
module->vtable.rnx_vmp_prepare_dblptr = fft64_rnx_vmp_prepare_dblptr_ref;
|
||||
module->vtable.rnx_vmp_prepare_row = fft64_rnx_vmp_prepare_row_ref;
|
||||
module->vtable.bytes_of_rnx_vmp_pmat = fft64_bytes_of_rnx_vmp_pmat;
|
||||
module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_ref;
|
||||
module->vtable.vec_rnx_to_znx32 = vec_rnx_to_znx32_ref;
|
||||
module->vtable.vec_rnx_from_znx32 = vec_rnx_from_znx32_ref;
|
||||
module->vtable.vec_rnx_to_tnx32 = vec_rnx_to_tnx32_ref;
|
||||
module->vtable.vec_rnx_from_tnx32 = vec_rnx_from_tnx32_ref;
|
||||
module->vtable.vec_rnx_to_tnxdbl = vec_rnx_to_tnxdbl_ref;
|
||||
module->vtable.bytes_of_rnx_svp_ppol = fft64_bytes_of_rnx_svp_ppol;
|
||||
module->vtable.rnx_svp_prepare = fft64_rnx_svp_prepare_ref;
|
||||
module->vtable.rnx_svp_apply = fft64_rnx_svp_apply_ref;
|
||||
|
||||
// Add optimized function pointers here
|
||||
if (CPU_SUPPORTS("avx")) {
|
||||
module->vtable.vec_rnx_add = vec_rnx_add_avx;
|
||||
module->vtable.vec_rnx_sub = vec_rnx_sub_avx;
|
||||
module->vtable.vec_rnx_negate = vec_rnx_negate_avx;
|
||||
module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx;
|
||||
module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_avx;
|
||||
module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx;
|
||||
module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_avx;
|
||||
module->vtable.rnx_vmp_prepare_tmp_bytes = fft64_rnx_vmp_prepare_tmp_bytes_avx;
|
||||
module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_avx;
|
||||
module->vtable.rnx_vmp_prepare_dblptr = fft64_rnx_vmp_prepare_dblptr_avx;
|
||||
module->vtable.rnx_vmp_prepare_row = fft64_rnx_vmp_prepare_row_avx;
|
||||
module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_avx;
|
||||
}
|
||||
}
|
||||
|
||||
void init_rnx_module_info(MOD_RNX* module, //
|
||||
uint64_t n, RNX_MODULE_TYPE mtype) {
|
||||
memset(module, 0, sizeof(MOD_RNX));
|
||||
module->n = n;
|
||||
module->m = n >> 1;
|
||||
module->mtype = mtype;
|
||||
switch (mtype) {
|
||||
case FFT64:
|
||||
fft64_init_rnx_module_precomp(module);
|
||||
fft64_init_rnx_module_vtable(module);
|
||||
break;
|
||||
default:
|
||||
NOT_SUPPORTED(); // unknown mtype
|
||||
}
|
||||
}
|
||||
|
||||
void finalize_rnx_module_info(MOD_RNX* module) {
|
||||
if (module->custom) module->custom_deleter(module->custom);
|
||||
switch (module->mtype) {
|
||||
case FFT64:
|
||||
fft64_finalize_rnx_module_precomp(module);
|
||||
// fft64_finalize_rnx_module_vtable(module); // nothing to finalize
|
||||
break;
|
||||
default:
|
||||
NOT_SUPPORTED(); // unknown mtype
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT MOD_RNX* new_rnx_module_info(uint64_t nn, RNX_MODULE_TYPE mtype) {
|
||||
MOD_RNX* res = (MOD_RNX*)malloc(sizeof(MOD_RNX));
|
||||
init_rnx_module_info(res, nn, mtype);
|
||||
return res;
|
||||
}
|
||||
|
||||
EXPORT void delete_rnx_module_info(MOD_RNX* module_info) {
|
||||
finalize_rnx_module_info(module_info);
|
||||
free(module_info);
|
||||
}
|
||||
|
||||
EXPORT uint64_t rnx_module_get_n(const MOD_RNX* module) { return module->n; }
|
||||
|
||||
/** @brief allocates a prepared matrix (release with delete_rnx_vmp_pmat) */
|
||||
EXPORT RNX_VMP_PMAT* new_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||
uint64_t nrows, uint64_t ncols) { // dimensions
|
||||
return (RNX_VMP_PMAT*)spqlios_alloc(bytes_of_rnx_vmp_pmat(module, nrows, ncols));
|
||||
}
|
||||
EXPORT void delete_rnx_vmp_pmat(RNX_VMP_PMAT* ptr) { spqlios_free(ptr); }
|
||||
|
||||
//////////////// wrappers //////////////////
|
||||
|
||||
/** @brief sets res = a + b */
|
||||
EXPORT void vec_rnx_add( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
module->vtable.vec_rnx_add(module, res, res_size, res_sl, a, a_size, a_sl, b, b_size, b_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = 0 */
|
||||
EXPORT void vec_rnx_zero( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl // res
|
||||
) {
|
||||
module->vtable.vec_rnx_zero(module, res, res_size, res_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = a */
|
||||
EXPORT void vec_rnx_copy( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.vec_rnx_copy(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = -a */
|
||||
EXPORT void vec_rnx_negate( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.vec_rnx_negate(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = a - b */
|
||||
EXPORT void vec_rnx_sub( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
module->vtable.vec_rnx_sub(module, res, res_size, res_sl, a, a_size, a_sl, b, b_size, b_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = a . X^p */
|
||||
EXPORT void vec_rnx_rotate( //
|
||||
const MOD_RNX* module, // N
|
||||
const int64_t p, // rotation value
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.vec_rnx_rotate(module, p, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = a(X^p) */
|
||||
EXPORT void vec_rnx_automorphism( //
|
||||
const MOD_RNX* module, // N
|
||||
int64_t p, // X -> X^p
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.vec_rnx_automorphism(module, p, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
EXPORT void vec_rnx_mul_xp_minus_one( //
|
||||
const MOD_RNX* module, // N
|
||||
const int64_t p, // rotation value
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.vec_rnx_mul_xp_minus_one(module, p, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */
|
||||
EXPORT uint64_t bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||
uint64_t nrows, uint64_t ncols) { // dimensions
|
||||
return module->vtable.bytes_of_rnx_vmp_pmat(module, nrows, ncols);
|
||||
}
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void rnx_vmp_prepare_contiguous( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* a, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
module->vtable.rnx_vmp_prepare_contiguous(module, pmat, a, nrows, ncols, tmp_space);
|
||||
}
|
||||
|
||||
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||
EXPORT void rnx_vmp_prepare_dblptr( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double** a, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
module->vtable.rnx_vmp_prepare_dblptr(module, pmat, a, nrows, ncols, tmp_space);
|
||||
}
|
||||
|
||||
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||
EXPORT void rnx_vmp_prepare_row( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* a, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
module->vtable.rnx_vmp_prepare_row(module, pmat, a, row_i, nrows, ncols, tmp_space);
|
||||
}
|
||||
|
||||
/** @brief number of scratch bytes necessary to prepare a matrix */
|
||||
EXPORT uint64_t rnx_vmp_prepare_tmp_bytes(const MOD_RNX* module) {
|
||||
return module->vtable.rnx_vmp_prepare_tmp_bytes(module);
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product res = a x pmat */
|
||||
EXPORT void rnx_vmp_apply_tmp_a( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
module->vtable.rnx_vmp_apply_tmp_a(module, res, res_size, res_sl, tmpa, a_size, a_sl, pmat, nrows, ncols, tmp_space);
|
||||
}
|
||||
|
||||
EXPORT uint64_t rnx_vmp_apply_tmp_a_tmp_bytes( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res size
|
||||
uint64_t a_size, // a size
|
||||
uint64_t nrows, uint64_t ncols // prep matrix dims
|
||||
) {
|
||||
return module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes(module, res_size, a_size, nrows, ncols);
|
||||
}
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT void rnx_vmp_apply_dft_to_dft( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
) {
|
||||
module->vtable.rnx_vmp_apply_dft_to_dft(module, res, res_size, res_sl, a_dft, a_size, a_sl, pmat, nrows, ncols,
|
||||
tmp_space);
|
||||
}
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT uint64_t rnx_vmp_apply_dft_to_dft_tmp_bytes( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
) {
|
||||
return module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes(module, res_size, a_size, nrows, ncols);
|
||||
}
|
||||
|
||||
EXPORT uint64_t bytes_of_rnx_svp_ppol(const MOD_RNX* module) { return module->vtable.bytes_of_rnx_svp_ppol(module); }
|
||||
|
||||
EXPORT void rnx_svp_prepare(const MOD_RNX* module, // N
|
||||
RNX_SVP_PPOL* ppol, // output
|
||||
const double* pol // a
|
||||
) {
|
||||
module->vtable.rnx_svp_prepare(module, ppol, pol);
|
||||
}
|
||||
|
||||
EXPORT void rnx_svp_apply( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // output
|
||||
const RNX_SVP_PPOL* ppol, // prepared pol
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.rnx_svp_apply(module, // N
|
||||
res, res_size, res_sl, // output
|
||||
ppol, // prepared pol
|
||||
a, a_size, a_sl);
|
||||
}
|
||||
|
||||
EXPORT void rnx_approxdecomp_from_tnxdbl( //
|
||||
const MOD_RNX* module, // N
|
||||
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a) { // a
|
||||
module->vtable.rnx_approxdecomp_from_tnxdbl(module, gadget, res, res_size, res_sl, a);
|
||||
}
|
||||
|
||||
EXPORT void vec_rnx_to_znx32( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.vec_rnx_to_znx32(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
EXPORT void vec_rnx_from_znx32( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.vec_rnx_from_znx32(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
EXPORT void vec_rnx_to_tnx32( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.vec_rnx_to_tnx32(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
EXPORT void vec_rnx_from_tnx32( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.vec_rnx_from_tnx32(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
EXPORT void vec_rnx_to_tnxdbl( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->vtable.vec_rnx_to_tnxdbl(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
#include <memory.h>
|
||||
|
||||
#include "immintrin.h"
|
||||
#include "vec_rnx_arithmetic_private.h"
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) */
|
||||
EXPORT void rnx_approxdecomp_from_tnxdbl_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
if (nn < 4) return rnx_approxdecomp_from_tnxdbl_ref(module, gadget, res, res_size, res_sl, a);
|
||||
const uint64_t ell = gadget->ell;
|
||||
const __m256i k = _mm256_set1_epi64x(gadget->k);
|
||||
const __m256d add_cst = _mm256_set1_pd(gadget->add_cst);
|
||||
const __m256i and_mask = _mm256_set1_epi64x(gadget->and_mask);
|
||||
const __m256i or_mask = _mm256_set1_epi64x(gadget->or_mask);
|
||||
const __m256d sub_cst = _mm256_set1_pd(gadget->sub_cst);
|
||||
const uint64_t msize = res_size <= ell ? res_size : ell;
|
||||
// gadget decompose column by column
|
||||
if (msize == ell) {
|
||||
// this is the main scenario when msize == ell
|
||||
double* const last_r = res + (msize - 1) * res_sl;
|
||||
for (uint64_t j = 0; j < nn; j += 4) {
|
||||
double* rr = last_r + j;
|
||||
const double* aa = a + j;
|
||||
__m256d t_dbl = _mm256_add_pd(_mm256_loadu_pd(aa), add_cst);
|
||||
__m256i t_int = _mm256_castpd_si256(t_dbl);
|
||||
do {
|
||||
__m256i u_int = _mm256_or_si256(_mm256_and_si256(t_int, and_mask), or_mask);
|
||||
_mm256_storeu_pd(rr, _mm256_sub_pd(_mm256_castsi256_pd(u_int), sub_cst));
|
||||
t_int = _mm256_srlv_epi64(t_int, k);
|
||||
rr -= res_sl;
|
||||
} while (rr >= res);
|
||||
}
|
||||
} else if (msize > 0) {
|
||||
// otherwise, if msize < ell: there is one additional rshift
|
||||
const __m256i first_rsh = _mm256_set1_epi64x((ell - msize) * gadget->k);
|
||||
double* const last_r = res + (msize - 1) * res_sl;
|
||||
for (uint64_t j = 0; j < nn; j += 4) {
|
||||
double* rr = last_r + j;
|
||||
const double* aa = a + j;
|
||||
__m256d t_dbl = _mm256_add_pd(_mm256_loadu_pd(aa), add_cst);
|
||||
__m256i t_int = _mm256_srlv_epi64(_mm256_castpd_si256(t_dbl), first_rsh);
|
||||
do {
|
||||
__m256i u_int = _mm256_or_si256(_mm256_and_si256(t_int, and_mask), or_mask);
|
||||
_mm256_storeu_pd(rr, _mm256_sub_pd(_mm256_castsi256_pd(u_int), sub_cst));
|
||||
t_int = _mm256_srlv_epi64(t_int, k);
|
||||
rr -= res_sl;
|
||||
} while (rr >= res);
|
||||
}
|
||||
}
|
||||
// zero-out the last slices (if any)
|
||||
for (uint64_t i = msize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
#include <memory.h>
|
||||
|
||||
#include "vec_rnx_arithmetic_private.h"
|
||||
|
||||
typedef union di {
|
||||
double dv;
|
||||
uint64_t uv;
|
||||
} di_t;
|
||||
|
||||
/** @brief new gadget: delete with delete_tnxdbl_approxdecomp_gadget */
|
||||
EXPORT TNXDBL_APPROXDECOMP_GADGET* new_tnxdbl_approxdecomp_gadget( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t k, uint64_t ell // base 2^K and size
|
||||
) {
|
||||
if (k * ell > 50) return spqlios_error("gadget requires a too large fp precision");
|
||||
TNXDBL_APPROXDECOMP_GADGET* res = spqlios_alloc(sizeof(TNXDBL_APPROXDECOMP_GADGET));
|
||||
res->k = k;
|
||||
res->ell = ell;
|
||||
// double add_cst; // double(3.2^(51-ell.K) + 1/2.(sum 2^(-iK)) for i=[0,ell[)
|
||||
union di add_cst;
|
||||
add_cst.dv = UINT64_C(3) << (51 - ell * k);
|
||||
for (uint64_t i = 0; i < ell; ++i) {
|
||||
add_cst.uv |= UINT64_C(1) << ((i + 1) * k - 1);
|
||||
}
|
||||
res->add_cst = add_cst.dv;
|
||||
// uint64_t and_mask; // uint64(2^(K)-1)
|
||||
res->and_mask = (UINT64_C(1) << k) - 1;
|
||||
// uint64_t or_mask; // double(2^52)
|
||||
union di or_mask;
|
||||
or_mask.dv = (UINT64_C(1) << 52);
|
||||
res->or_mask = or_mask.uv;
|
||||
// double sub_cst; // double(2^52 + 2^(K-1))
|
||||
res->sub_cst = ((UINT64_C(1) << 52) + (UINT64_C(1) << (k - 1)));
|
||||
return res;
|
||||
}
|
||||
|
||||
EXPORT void delete_tnxdbl_approxdecomp_gadget(TNXDBL_APPROXDECOMP_GADGET* gadget) { spqlios_free(gadget); }
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) */
|
||||
EXPORT void rnx_approxdecomp_from_tnxdbl_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t k = gadget->k;
|
||||
const uint64_t ell = gadget->ell;
|
||||
const double add_cst = gadget->add_cst;
|
||||
const uint64_t and_mask = gadget->and_mask;
|
||||
const uint64_t or_mask = gadget->or_mask;
|
||||
const double sub_cst = gadget->sub_cst;
|
||||
const uint64_t msize = res_size <= ell ? res_size : ell;
|
||||
const uint64_t first_rsh = (ell - msize) * k;
|
||||
// gadget decompose column by column
|
||||
if (msize > 0) {
|
||||
double* const last_r = res + (msize - 1) * res_sl;
|
||||
for (uint64_t j = 0; j < nn; ++j) {
|
||||
double* rr = last_r + j;
|
||||
di_t t = {.dv = a[j] + add_cst};
|
||||
if (msize < ell) t.uv >>= first_rsh;
|
||||
do {
|
||||
di_t u;
|
||||
u.uv = (t.uv & and_mask) | or_mask;
|
||||
*rr = u.dv - sub_cst;
|
||||
t.uv >>= k;
|
||||
rr -= res_sl;
|
||||
} while (rr >= res);
|
||||
}
|
||||
}
|
||||
// zero-out the last slices (if any)
|
||||
for (uint64_t i = msize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,223 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "../coeffs/coeffs_arithmetic.h"
|
||||
#include "vec_rnx_arithmetic_private.h"
|
||||
|
||||
void rnx_add_ref(uint64_t nn, double* res, const double* a, const double* b) {
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
res[i] = a[i] + b[i];
|
||||
}
|
||||
}
|
||||
|
||||
void rnx_sub_ref(uint64_t nn, double* res, const double* a, const double* b) {
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
res[i] = a[i] - b[i];
|
||||
}
|
||||
}
|
||||
|
||||
void rnx_negate_ref(uint64_t nn, double* res, const double* a) {
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
res[i] = -a[i];
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief sets res = a + b */
|
||||
EXPORT void vec_rnx_add_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
if (a_size < b_size) {
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
const uint64_t nsize = res_size < b_size ? res_size : b_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
rnx_add_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
for (uint64_t i = msize; i < nsize; ++i) {
|
||||
memcpy(res + i * res_sl, b + i * b_sl, nn * sizeof(double));
|
||||
}
|
||||
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
} else {
|
||||
const uint64_t msize = res_size < b_size ? res_size : b_size;
|
||||
const uint64_t nsize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
rnx_add_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
for (uint64_t i = msize; i < nsize; ++i) {
|
||||
memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double));
|
||||
}
|
||||
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief sets res = 0 */
|
||||
EXPORT void vec_rnx_zero_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl // res
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
for (uint64_t i = 0; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief sets res = a */
|
||||
EXPORT void vec_rnx_copy_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
|
||||
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||
// rotate up to the smallest dimension
|
||||
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||
double* res_ptr = res + i * res_sl;
|
||||
const double* a_ptr = a + i * a_sl;
|
||||
memcpy(res_ptr, a_ptr, nn * sizeof(double));
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief sets res = -a */
|
||||
EXPORT void vec_rnx_negate_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
|
||||
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||
// rotate up to the smallest dimension
|
||||
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||
double* res_ptr = res + i * res_sl;
|
||||
const double* a_ptr = a + i * a_sl;
|
||||
rnx_negate_ref(nn, res_ptr, a_ptr);
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief sets res = a - b */
|
||||
EXPORT void vec_rnx_sub_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
if (a_size < b_size) {
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
const uint64_t nsize = res_size < b_size ? res_size : b_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
rnx_sub_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
for (uint64_t i = msize; i < nsize; ++i) {
|
||||
rnx_negate_ref(nn, res + i * res_sl, b + i * b_sl);
|
||||
}
|
||||
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
} else {
|
||||
const uint64_t msize = res_size < b_size ? res_size : b_size;
|
||||
const uint64_t nsize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
rnx_sub_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
for (uint64_t i = msize; i < nsize; ++i) {
|
||||
memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double));
|
||||
}
|
||||
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief sets res = a . X^p */
|
||||
EXPORT void vec_rnx_rotate_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
const int64_t p, // rotation value
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
|
||||
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||
// rotate up to the smallest dimension
|
||||
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||
double* res_ptr = res + i * res_sl;
|
||||
const double* a_ptr = a + i * a_sl;
|
||||
if (res_ptr == a_ptr) {
|
||||
rnx_rotate_inplace_f64(nn, p, res_ptr);
|
||||
} else {
|
||||
rnx_rotate_f64(nn, p, res_ptr, a_ptr);
|
||||
}
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief sets res = a(X^p) */
|
||||
EXPORT void vec_rnx_automorphism_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
int64_t p, // X -> X^p
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
|
||||
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||
// rotate up to the smallest dimension
|
||||
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||
double* res_ptr = res + i * res_sl;
|
||||
const double* a_ptr = a + i * a_sl;
|
||||
if (res_ptr == a_ptr) {
|
||||
rnx_automorphism_inplace_f64(nn, p, res_ptr);
|
||||
} else {
|
||||
rnx_automorphism_f64(nn, p, res_ptr, a_ptr);
|
||||
}
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief sets res = a . (X^p - 1) */
|
||||
EXPORT void vec_rnx_mul_xp_minus_one_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
const int64_t p, // rotation value
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
|
||||
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||
// rotate up to the smallest dimension
|
||||
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||
double* res_ptr = res + i * res_sl;
|
||||
const double* a_ptr = a + i * a_sl;
|
||||
if (res_ptr == a_ptr) {
|
||||
rnx_mul_xp_minus_one_inplace_f64(nn, p, res_ptr);
|
||||
} else {
|
||||
rnx_mul_xp_minus_one_f64(nn, p, res_ptr, a_ptr);
|
||||
}
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,356 @@
|
||||
#ifndef SPQLIOS_VEC_RNX_ARITHMETIC_H
|
||||
#define SPQLIOS_VEC_RNX_ARITHMETIC_H
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "../commons.h"
|
||||
|
||||
/**
|
||||
* We support the following module families:
|
||||
* - FFT64:
|
||||
* the overall precision should fit at all times over 52 bits.
|
||||
*/
|
||||
typedef enum rnx_module_type_t { FFT64 } RNX_MODULE_TYPE;
|
||||
|
||||
/** @brief opaque structure that describes the modules (RnX,ZnX,TnX) and the hardware */
|
||||
typedef struct rnx_module_info_t MOD_RNX;
|
||||
|
||||
/**
|
||||
* @brief obtain a module info for ring dimension N
|
||||
* the module-info knows about:
|
||||
* - the dimension N (or the complex dimension m=N/2)
|
||||
* - any moduleuted fft or ntt items
|
||||
* - the hardware (avx, arm64, x86, ...)
|
||||
*/
|
||||
EXPORT MOD_RNX* new_rnx_module_info(uint64_t N, RNX_MODULE_TYPE mode);
|
||||
EXPORT void delete_rnx_module_info(MOD_RNX* module_info);
|
||||
EXPORT uint64_t rnx_module_get_n(const MOD_RNX* module);
|
||||
|
||||
// basic arithmetic
|
||||
|
||||
/** @brief sets res = 0 */
|
||||
EXPORT void vec_rnx_zero( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl // res
|
||||
);
|
||||
|
||||
/** @brief sets res = a */
|
||||
EXPORT void vec_rnx_copy( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = -a */
|
||||
EXPORT void vec_rnx_negate( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a + b */
|
||||
EXPORT void vec_rnx_add( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = a - b */
|
||||
EXPORT void vec_rnx_sub( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = a . X^p */
|
||||
EXPORT void vec_rnx_rotate( //
|
||||
const MOD_RNX* module, // N
|
||||
const int64_t p, // rotation value
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a . (X^p - 1) */
|
||||
EXPORT void vec_rnx_mul_xp_minus_one( //
|
||||
const MOD_RNX* module, // N
|
||||
const int64_t p, // rotation value
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a(X^p) */
|
||||
EXPORT void vec_rnx_automorphism( //
|
||||
const MOD_RNX* module, // N
|
||||
int64_t p, // X -> X^p
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
// conversions //
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
||||
EXPORT void vec_rnx_to_znx32( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_from_znx32( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_to_tnx32( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_from_tnx32( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_to_tnx32x2( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_from_tnx32x2( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_to_tnxdbl( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
// isolated products (n.log(n), but not particularly optimized //
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
||||
/** @brief res = a * b : small polynomial product */
|
||||
EXPORT void rnx_small_single_product( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, // output
|
||||
const double* a, // a
|
||||
const double* b, // b
|
||||
uint8_t* tmp); // scratch space
|
||||
|
||||
EXPORT uint64_t rnx_small_single_product_tmp_bytes(const MOD_RNX* module);
|
||||
|
||||
/** @brief res = a * b centermod 1: small polynomial product */
|
||||
EXPORT void tnxdbl_small_single_product( //
|
||||
const MOD_RNX* module, // N
|
||||
double* torus_res, // output
|
||||
const double* int_a, // a
|
||||
const double* torus_b, // b
|
||||
uint8_t* tmp); // scratch space
|
||||
|
||||
EXPORT uint64_t tnxdbl_small_single_product_tmp_bytes(const MOD_RNX* module);
|
||||
|
||||
/** @brief res = a * b: small polynomial product */
|
||||
EXPORT void znx32_small_single_product( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* int_res, // output
|
||||
const int32_t* int_a, // a
|
||||
const int32_t* int_b, // b
|
||||
uint8_t* tmp); // scratch space
|
||||
|
||||
EXPORT uint64_t znx32_small_single_product_tmp_bytes(const MOD_RNX* module);
|
||||
|
||||
/** @brief res = a * b centermod 1: small polynomial product */
|
||||
EXPORT void tnx32_small_single_product( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* torus_res, // output
|
||||
const int32_t* int_a, // a
|
||||
const int32_t* torus_b, // b
|
||||
uint8_t* tmp); // scratch space
|
||||
|
||||
EXPORT uint64_t tnx32_small_single_product_tmp_bytes(const MOD_RNX* module);
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
// prepared gadget decompositions (optimized) //
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
||||
// decompose from tnx32
|
||||
|
||||
typedef struct tnx32_approxdecomp_gadget_t TNX32_APPROXDECOMP_GADGET;
|
||||
|
||||
/** @brief new gadget: delete with delete_tnx32_approxdecomp_gadget */
|
||||
EXPORT TNX32_APPROXDECOMP_GADGET* new_tnx32_approxdecomp_gadget( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t k, uint64_t ell // base 2^K and size
|
||||
);
|
||||
EXPORT void delete_tnx32_approxdecomp_gadget(const MOD_RNX* module);
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) */
|
||||
EXPORT void rnx_approxdecomp_from_tnx32( //
|
||||
const MOD_RNX* module, // N
|
||||
const TNX32_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a // a
|
||||
);
|
||||
|
||||
// decompose from tnx32x2
|
||||
|
||||
typedef struct tnx32x2_approxdecomp_gadget_t TNX32X2_APPROXDECOMP_GADGET;
|
||||
|
||||
/** @brief new gadget: delete with delete_tnx32x2_approxdecomp_gadget */
|
||||
EXPORT TNX32X2_APPROXDECOMP_GADGET* new_tnx32x2_approxdecomp_gadget(const MOD_RNX* module, uint64_t ka, uint64_t ella,
|
||||
uint64_t kb, uint64_t ellb);
|
||||
EXPORT void delete_tnx32x2_approxdecomp_gadget(const MOD_RNX* module);
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) */
|
||||
EXPORT void rnx_approxdecomp_from_tnx32x2( //
|
||||
const MOD_RNX* module, // N
|
||||
const TNX32X2_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a // a
|
||||
);
|
||||
|
||||
// decompose from tnxdbl
|
||||
|
||||
typedef struct tnxdbl_approxdecomp_gadget_t TNXDBL_APPROXDECOMP_GADGET;
|
||||
|
||||
/** @brief new gadget: delete with delete_tnxdbl_approxdecomp_gadget */
|
||||
EXPORT TNXDBL_APPROXDECOMP_GADGET* new_tnxdbl_approxdecomp_gadget( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t k, uint64_t ell // base 2^K and size
|
||||
);
|
||||
EXPORT void delete_tnxdbl_approxdecomp_gadget(TNXDBL_APPROXDECOMP_GADGET* gadget);
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) */
|
||||
EXPORT void rnx_approxdecomp_from_tnxdbl( //
|
||||
const MOD_RNX* module, // N
|
||||
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a); // a
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
// prepared scalar-vector product (optimized) //
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
||||
/** @brief opaque type that represents a polynomial of RnX prepared for a scalar-vector product */
|
||||
typedef struct rnx_svp_ppol_t RNX_SVP_PPOL;
|
||||
|
||||
/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */
|
||||
EXPORT uint64_t bytes_of_rnx_svp_ppol(const MOD_RNX* module); // N
|
||||
|
||||
/** @brief allocates a prepared vector (release with delete_rnx_svp_ppol) */
|
||||
EXPORT RNX_SVP_PPOL* new_rnx_svp_ppol(const MOD_RNX* module); // N
|
||||
|
||||
/** @brief frees memory for a prepared vector */
|
||||
EXPORT void delete_rnx_svp_ppol(RNX_SVP_PPOL* res);
|
||||
|
||||
/** @brief prepares a svp polynomial */
|
||||
EXPORT void rnx_svp_prepare(const MOD_RNX* module, // N
|
||||
RNX_SVP_PPOL* ppol, // output
|
||||
const double* pol // a
|
||||
);
|
||||
|
||||
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
|
||||
EXPORT void rnx_svp_apply( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // output
|
||||
const RNX_SVP_PPOL* ppol, // prepared pol
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
// prepared vector-matrix product (optimized) //
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
||||
typedef struct rnx_vmp_pmat_t RNX_VMP_PMAT;
|
||||
|
||||
/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */
|
||||
EXPORT uint64_t bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||
uint64_t nrows, uint64_t ncols); // dimensions
|
||||
|
||||
/** @brief allocates a prepared matrix (release with delete_rnx_vmp_pmat) */
|
||||
EXPORT RNX_VMP_PMAT* new_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||
uint64_t nrows, uint64_t ncols); // dimensions
|
||||
EXPORT void delete_rnx_vmp_pmat(RNX_VMP_PMAT* ptr);
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void rnx_vmp_prepare_contiguous( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* a, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||
EXPORT void rnx_vmp_prepare_dblptr( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double** a, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||
EXPORT void rnx_vmp_prepare_row( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* a, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief number of scratch bytes necessary to prepare a matrix */
|
||||
EXPORT uint64_t rnx_vmp_prepare_tmp_bytes(const MOD_RNX* module);
|
||||
|
||||
/** @brief applies a vmp product res = a x pmat */
|
||||
EXPORT void rnx_vmp_apply_tmp_a( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
EXPORT uint64_t rnx_vmp_apply_tmp_a_tmp_bytes( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res size
|
||||
uint64_t a_size, // a size
|
||||
uint64_t nrows, uint64_t ncols // prep matrix dims
|
||||
);
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT void rnx_vmp_apply_dft_to_dft( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
);
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT uint64_t rnx_vmp_apply_dft_to_dft_tmp_bytes( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
);
|
||||
|
||||
/** @brief sets res = DFT(a) */
|
||||
EXPORT void vec_rnx_dft(const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = iDFT(a_dft) -- idft is not normalized */
|
||||
EXPORT void vec_rnx_idft(const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a_dft, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
#endif // SPQLIOS_VEC_RNX_ARITHMETIC_H
|
||||
@@ -0,0 +1,189 @@
|
||||
#include <immintrin.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "vec_rnx_arithmetic_private.h"
|
||||
|
||||
void rnx_add_avx(uint64_t nn, double* res, const double* a, const double* b) {
|
||||
if (nn < 8) {
|
||||
if (nn == 4) {
|
||||
_mm256_storeu_pd(res, _mm256_add_pd(_mm256_loadu_pd(a), _mm256_loadu_pd(b)));
|
||||
} else if (nn == 2) {
|
||||
_mm_storeu_pd(res, _mm_add_pd(_mm_loadu_pd(a), _mm_loadu_pd(b)));
|
||||
} else if (nn == 1) {
|
||||
*res = *a + *b;
|
||||
} else {
|
||||
NOT_SUPPORTED(); // not a power of 2
|
||||
}
|
||||
return;
|
||||
}
|
||||
// general case: nn >= 8
|
||||
__m256d x0, x1, x2, x3, x4, x5;
|
||||
const double* aa = a;
|
||||
const double* bb = b;
|
||||
double* rr = res;
|
||||
double* const rrend = res + nn;
|
||||
do {
|
||||
x0 = _mm256_loadu_pd(aa);
|
||||
x1 = _mm256_loadu_pd(aa + 4);
|
||||
x2 = _mm256_loadu_pd(bb);
|
||||
x3 = _mm256_loadu_pd(bb + 4);
|
||||
x4 = _mm256_add_pd(x0, x2);
|
||||
x5 = _mm256_add_pd(x1, x3);
|
||||
_mm256_storeu_pd(rr, x4);
|
||||
_mm256_storeu_pd(rr + 4, x5);
|
||||
aa += 8;
|
||||
bb += 8;
|
||||
rr += 8;
|
||||
} while (rr < rrend);
|
||||
}
|
||||
|
||||
void rnx_sub_avx(uint64_t nn, double* res, const double* a, const double* b) {
|
||||
if (nn < 8) {
|
||||
if (nn == 4) {
|
||||
_mm256_storeu_pd(res, _mm256_sub_pd(_mm256_loadu_pd(a), _mm256_loadu_pd(b)));
|
||||
} else if (nn == 2) {
|
||||
_mm_storeu_pd(res, _mm_sub_pd(_mm_loadu_pd(a), _mm_loadu_pd(b)));
|
||||
} else if (nn == 1) {
|
||||
*res = *a - *b;
|
||||
} else {
|
||||
NOT_SUPPORTED(); // not a power of 2
|
||||
}
|
||||
return;
|
||||
}
|
||||
// general case: nn >= 8
|
||||
__m256d x0, x1, x2, x3, x4, x5;
|
||||
const double* aa = a;
|
||||
const double* bb = b;
|
||||
double* rr = res;
|
||||
double* const rrend = res + nn;
|
||||
do {
|
||||
x0 = _mm256_loadu_pd(aa);
|
||||
x1 = _mm256_loadu_pd(aa + 4);
|
||||
x2 = _mm256_loadu_pd(bb);
|
||||
x3 = _mm256_loadu_pd(bb + 4);
|
||||
x4 = _mm256_sub_pd(x0, x2);
|
||||
x5 = _mm256_sub_pd(x1, x3);
|
||||
_mm256_storeu_pd(rr, x4);
|
||||
_mm256_storeu_pd(rr + 4, x5);
|
||||
aa += 8;
|
||||
bb += 8;
|
||||
rr += 8;
|
||||
} while (rr < rrend);
|
||||
}
|
||||
|
||||
void rnx_negate_avx(uint64_t nn, double* res, const double* b) {
|
||||
if (nn < 8) {
|
||||
if (nn == 4) {
|
||||
_mm256_storeu_pd(res, _mm256_sub_pd(_mm256_set1_pd(0), _mm256_loadu_pd(b)));
|
||||
} else if (nn == 2) {
|
||||
_mm_storeu_pd(res, _mm_sub_pd(_mm_set1_pd(0), _mm_loadu_pd(b)));
|
||||
} else if (nn == 1) {
|
||||
*res = -*b;
|
||||
} else {
|
||||
NOT_SUPPORTED(); // not a power of 2
|
||||
}
|
||||
return;
|
||||
}
|
||||
// general case: nn >= 8
|
||||
__m256d x2, x3, x4, x5;
|
||||
const __m256d ZERO = _mm256_set1_pd(0);
|
||||
const double* bb = b;
|
||||
double* rr = res;
|
||||
double* const rrend = res + nn;
|
||||
do {
|
||||
x2 = _mm256_loadu_pd(bb);
|
||||
x3 = _mm256_loadu_pd(bb + 4);
|
||||
x4 = _mm256_sub_pd(ZERO, x2);
|
||||
x5 = _mm256_sub_pd(ZERO, x3);
|
||||
_mm256_storeu_pd(rr, x4);
|
||||
_mm256_storeu_pd(rr + 4, x5);
|
||||
bb += 8;
|
||||
rr += 8;
|
||||
} while (rr < rrend);
|
||||
}
|
||||
|
||||
/** @brief sets res = a + b */
|
||||
EXPORT void vec_rnx_add_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
if (a_size < b_size) {
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
const uint64_t nsize = res_size < b_size ? res_size : b_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
rnx_add_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
for (uint64_t i = msize; i < nsize; ++i) {
|
||||
memcpy(res + i * res_sl, b + i * b_sl, nn * sizeof(double));
|
||||
}
|
||||
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
} else {
|
||||
const uint64_t msize = res_size < b_size ? res_size : b_size;
|
||||
const uint64_t nsize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
rnx_add_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
for (uint64_t i = msize; i < nsize; ++i) {
|
||||
memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double));
|
||||
}
|
||||
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief sets res = -a */
|
||||
EXPORT void vec_rnx_negate_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
rnx_negate_avx(nn, res + i * res_sl, a + i * a_sl);
|
||||
}
|
||||
for (uint64_t i = msize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief sets res = a - b */
|
||||
EXPORT void vec_rnx_sub_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
if (a_size < b_size) {
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
const uint64_t nsize = res_size < b_size ? res_size : b_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
rnx_sub_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
for (uint64_t i = msize; i < nsize; ++i) {
|
||||
rnx_negate_avx(nn, res + i * res_sl, b + i * b_sl);
|
||||
}
|
||||
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
} else {
|
||||
const uint64_t msize = res_size < b_size ? res_size : b_size;
|
||||
const uint64_t nsize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
rnx_sub_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
for (uint64_t i = msize; i < nsize; ++i) {
|
||||
memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double));
|
||||
}
|
||||
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
#ifndef SPQLIOS_VEC_RNX_ARITHMETIC_PLUGIN_H
|
||||
#define SPQLIOS_VEC_RNX_ARITHMETIC_PLUGIN_H
|
||||
|
||||
#include "vec_rnx_arithmetic.h"
|
||||
|
||||
typedef typeof(vec_rnx_zero) VEC_RNX_ZERO_F;
|
||||
typedef typeof(vec_rnx_copy) VEC_RNX_COPY_F;
|
||||
typedef typeof(vec_rnx_negate) VEC_RNX_NEGATE_F;
|
||||
typedef typeof(vec_rnx_add) VEC_RNX_ADD_F;
|
||||
typedef typeof(vec_rnx_sub) VEC_RNX_SUB_F;
|
||||
typedef typeof(vec_rnx_rotate) VEC_RNX_ROTATE_F;
|
||||
typedef typeof(vec_rnx_mul_xp_minus_one) VEC_RNX_MUL_XP_MINUS_ONE_F;
|
||||
typedef typeof(vec_rnx_automorphism) VEC_RNX_AUTOMORPHISM_F;
|
||||
typedef typeof(vec_rnx_to_znx32) VEC_RNX_TO_ZNX32_F;
|
||||
typedef typeof(vec_rnx_from_znx32) VEC_RNX_FROM_ZNX32_F;
|
||||
typedef typeof(vec_rnx_to_tnx32) VEC_RNX_TO_TNX32_F;
|
||||
typedef typeof(vec_rnx_from_tnx32) VEC_RNX_FROM_TNX32_F;
|
||||
typedef typeof(vec_rnx_to_tnx32x2) VEC_RNX_TO_TNX32X2_F;
|
||||
typedef typeof(vec_rnx_from_tnx32x2) VEC_RNX_FROM_TNX32X2_F;
|
||||
typedef typeof(vec_rnx_to_tnxdbl) VEC_RNX_TO_TNXDBL_F;
|
||||
// typedef typeof(vec_rnx_from_tnxdbl) VEC_RNX_FROM_TNXDBL_F;
|
||||
typedef typeof(rnx_small_single_product) RNX_SMALL_SINGLE_PRODUCT_F;
|
||||
typedef typeof(rnx_small_single_product_tmp_bytes) RNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F;
|
||||
typedef typeof(tnxdbl_small_single_product) TNXDBL_SMALL_SINGLE_PRODUCT_F;
|
||||
typedef typeof(tnxdbl_small_single_product_tmp_bytes) TNXDBL_SMALL_SINGLE_PRODUCT_TMP_BYTES_F;
|
||||
typedef typeof(znx32_small_single_product) ZNX32_SMALL_SINGLE_PRODUCT_F;
|
||||
typedef typeof(znx32_small_single_product_tmp_bytes) ZNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F;
|
||||
typedef typeof(tnx32_small_single_product) TNX32_SMALL_SINGLE_PRODUCT_F;
|
||||
typedef typeof(tnx32_small_single_product_tmp_bytes) TNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F;
|
||||
typedef typeof(rnx_approxdecomp_from_tnx32) RNX_APPROXDECOMP_FROM_TNX32_F;
|
||||
typedef typeof(rnx_approxdecomp_from_tnx32x2) RNX_APPROXDECOMP_FROM_TNX32X2_F;
|
||||
typedef typeof(rnx_approxdecomp_from_tnxdbl) RNX_APPROXDECOMP_FROM_TNXDBL_F;
|
||||
typedef typeof(bytes_of_rnx_svp_ppol) BYTES_OF_RNX_SVP_PPOL_F;
|
||||
typedef typeof(rnx_svp_prepare) RNX_SVP_PREPARE_F;
|
||||
typedef typeof(rnx_svp_apply) RNX_SVP_APPLY_F;
|
||||
typedef typeof(bytes_of_rnx_vmp_pmat) BYTES_OF_RNX_VMP_PMAT_F;
|
||||
typedef typeof(rnx_vmp_prepare_contiguous) RNX_VMP_PREPARE_CONTIGUOUS_F;
|
||||
typedef typeof(rnx_vmp_prepare_dblptr) RNX_VMP_PREPARE_DBLPTR_F;
|
||||
typedef typeof(rnx_vmp_prepare_row) RNX_VMP_PREPARE_ROW_F;
|
||||
typedef typeof(rnx_vmp_prepare_tmp_bytes) RNX_VMP_PREPARE_TMP_BYTES_F;
|
||||
typedef typeof(rnx_vmp_apply_tmp_a) RNX_VMP_APPLY_TMP_A_F;
|
||||
typedef typeof(rnx_vmp_apply_tmp_a_tmp_bytes) RNX_VMP_APPLY_TMP_A_TMP_BYTES_F;
|
||||
typedef typeof(rnx_vmp_apply_dft_to_dft) RNX_VMP_APPLY_DFT_TO_DFT_F;
|
||||
typedef typeof(rnx_vmp_apply_dft_to_dft_tmp_bytes) RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F;
|
||||
typedef typeof(vec_rnx_dft) VEC_RNX_DFT_F;
|
||||
typedef typeof(vec_rnx_idft) VEC_RNX_IDFT_F;
|
||||
|
||||
typedef struct rnx_module_vtable_t RNX_MODULE_VTABLE;
|
||||
struct rnx_module_vtable_t {
|
||||
VEC_RNX_ZERO_F* vec_rnx_zero;
|
||||
VEC_RNX_COPY_F* vec_rnx_copy;
|
||||
VEC_RNX_NEGATE_F* vec_rnx_negate;
|
||||
VEC_RNX_ADD_F* vec_rnx_add;
|
||||
VEC_RNX_SUB_F* vec_rnx_sub;
|
||||
VEC_RNX_ROTATE_F* vec_rnx_rotate;
|
||||
VEC_RNX_MUL_XP_MINUS_ONE_F* vec_rnx_mul_xp_minus_one;
|
||||
VEC_RNX_AUTOMORPHISM_F* vec_rnx_automorphism;
|
||||
VEC_RNX_TO_ZNX32_F* vec_rnx_to_znx32;
|
||||
VEC_RNX_FROM_ZNX32_F* vec_rnx_from_znx32;
|
||||
VEC_RNX_TO_TNX32_F* vec_rnx_to_tnx32;
|
||||
VEC_RNX_FROM_TNX32_F* vec_rnx_from_tnx32;
|
||||
VEC_RNX_TO_TNX32X2_F* vec_rnx_to_tnx32x2;
|
||||
VEC_RNX_FROM_TNX32X2_F* vec_rnx_from_tnx32x2;
|
||||
VEC_RNX_TO_TNXDBL_F* vec_rnx_to_tnxdbl;
|
||||
RNX_SMALL_SINGLE_PRODUCT_F* rnx_small_single_product;
|
||||
RNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* rnx_small_single_product_tmp_bytes;
|
||||
TNXDBL_SMALL_SINGLE_PRODUCT_F* tnxdbl_small_single_product;
|
||||
TNXDBL_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* tnxdbl_small_single_product_tmp_bytes;
|
||||
ZNX32_SMALL_SINGLE_PRODUCT_F* znx32_small_single_product;
|
||||
ZNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* znx32_small_single_product_tmp_bytes;
|
||||
TNX32_SMALL_SINGLE_PRODUCT_F* tnx32_small_single_product;
|
||||
TNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* tnx32_small_single_product_tmp_bytes;
|
||||
RNX_APPROXDECOMP_FROM_TNX32_F* rnx_approxdecomp_from_tnx32;
|
||||
RNX_APPROXDECOMP_FROM_TNX32X2_F* rnx_approxdecomp_from_tnx32x2;
|
||||
RNX_APPROXDECOMP_FROM_TNXDBL_F* rnx_approxdecomp_from_tnxdbl;
|
||||
BYTES_OF_RNX_SVP_PPOL_F* bytes_of_rnx_svp_ppol;
|
||||
RNX_SVP_PREPARE_F* rnx_svp_prepare;
|
||||
RNX_SVP_APPLY_F* rnx_svp_apply;
|
||||
BYTES_OF_RNX_VMP_PMAT_F* bytes_of_rnx_vmp_pmat;
|
||||
RNX_VMP_PREPARE_CONTIGUOUS_F* rnx_vmp_prepare_contiguous;
|
||||
RNX_VMP_PREPARE_DBLPTR_F* rnx_vmp_prepare_dblptr;
|
||||
RNX_VMP_PREPARE_ROW_F* rnx_vmp_prepare_row;
|
||||
RNX_VMP_PREPARE_TMP_BYTES_F* rnx_vmp_prepare_tmp_bytes;
|
||||
RNX_VMP_APPLY_TMP_A_F* rnx_vmp_apply_tmp_a;
|
||||
RNX_VMP_APPLY_TMP_A_TMP_BYTES_F* rnx_vmp_apply_tmp_a_tmp_bytes;
|
||||
RNX_VMP_APPLY_DFT_TO_DFT_F* rnx_vmp_apply_dft_to_dft;
|
||||
RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* rnx_vmp_apply_dft_to_dft_tmp_bytes;
|
||||
VEC_RNX_DFT_F* vec_rnx_dft;
|
||||
VEC_RNX_IDFT_F* vec_rnx_idft;
|
||||
};
|
||||
|
||||
#endif // SPQLIOS_VEC_RNX_ARITHMETIC_PLUGIN_H
|
||||
@@ -0,0 +1,309 @@
|
||||
#ifndef SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H
|
||||
#define SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H
|
||||
|
||||
#include "../commons_private.h"
|
||||
#include "../reim/reim_fft.h"
|
||||
#include "vec_rnx_arithmetic.h"
|
||||
#include "vec_rnx_arithmetic_plugin.h"
|
||||
|
||||
typedef struct fft64_rnx_module_precomp_t FFT64_RNX_MODULE_PRECOMP;
|
||||
struct fft64_rnx_module_precomp_t {
|
||||
REIM_FFT_PRECOMP* p_fft;
|
||||
REIM_IFFT_PRECOMP* p_ifft;
|
||||
REIM_FFTVEC_ADD_PRECOMP* p_fftvec_add;
|
||||
REIM_FFTVEC_MUL_PRECOMP* p_fftvec_mul;
|
||||
REIM_FFTVEC_ADDMUL_PRECOMP* p_fftvec_addmul;
|
||||
};
|
||||
|
||||
typedef union rnx_module_precomp_t RNX_MODULE_PRECOMP;
|
||||
union rnx_module_precomp_t {
|
||||
FFT64_RNX_MODULE_PRECOMP fft64;
|
||||
};
|
||||
|
||||
void fft64_init_rnx_module_precomp(MOD_RNX* module);
|
||||
|
||||
void fft64_finalize_rnx_module_precomp(MOD_RNX* module);
|
||||
|
||||
/** @brief opaque structure that describes the modules (RnX,ZnX,TnX) and the hardware */
|
||||
struct rnx_module_info_t {
|
||||
uint64_t n;
|
||||
uint64_t m;
|
||||
RNX_MODULE_TYPE mtype;
|
||||
RNX_MODULE_VTABLE vtable;
|
||||
RNX_MODULE_PRECOMP precomp;
|
||||
void* custom;
|
||||
void (*custom_deleter)(void*);
|
||||
};
|
||||
|
||||
void init_rnx_module_info(MOD_RNX* module, //
|
||||
uint64_t, RNX_MODULE_TYPE mtype);
|
||||
|
||||
void finalize_rnx_module_info(MOD_RNX* module);
|
||||
|
||||
void fft64_init_rnx_module_vtable(MOD_RNX* module);
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
// prepared gadget decompositions (optimized) //
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
||||
struct tnx32_approxdec_gadget_t {
|
||||
uint64_t k;
|
||||
uint64_t ell;
|
||||
int32_t add_cst; // 1/2.(sum 2^-(i+1)K)
|
||||
int32_t rshift_base; // 32 - K
|
||||
int64_t and_mask; // 2^K-1
|
||||
int64_t or_mask; // double(2^52)
|
||||
double sub_cst; // double(2^52 + 2^(K-1))
|
||||
uint8_t rshifts[8]; // 32 - (i+1).K
|
||||
};
|
||||
|
||||
struct tnx32x2_approxdec_gadget_t {
|
||||
// TODO
|
||||
};
|
||||
|
||||
struct tnxdbl_approxdecomp_gadget_t {
|
||||
uint64_t k;
|
||||
uint64_t ell;
|
||||
double add_cst; // double(3.2^(51-ell.K) + 1/2.(sum 2^(-iK)) for i=[0,ell[)
|
||||
uint64_t and_mask; // uint64(2^(K)-1)
|
||||
uint64_t or_mask; // double(2^52)
|
||||
double sub_cst; // double(2^52 + 2^(K-1))
|
||||
};
|
||||
|
||||
EXPORT void vec_rnx_add_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
EXPORT void vec_rnx_add_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = 0 */
|
||||
EXPORT void vec_rnx_zero_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl // res
|
||||
);
|
||||
|
||||
/** @brief sets res = a */
|
||||
EXPORT void vec_rnx_copy_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = -a */
|
||||
EXPORT void vec_rnx_negate_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = -a */
|
||||
EXPORT void vec_rnx_negate_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a - b */
|
||||
EXPORT void vec_rnx_sub_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = a - b */
|
||||
EXPORT void vec_rnx_sub_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = a . X^p */
|
||||
EXPORT void vec_rnx_rotate_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
const int64_t p, // rotation value
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a(X^p) */
|
||||
EXPORT void vec_rnx_automorphism_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
int64_t p, // X -> X^p
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */
|
||||
EXPORT uint64_t fft64_bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||
uint64_t nrows, uint64_t ncols);
|
||||
|
||||
EXPORT void fft64_rnx_vmp_apply_dft_to_dft_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
);
|
||||
EXPORT void fft64_rnx_vmp_apply_dft_to_dft_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
);
|
||||
EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
);
|
||||
EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
);
|
||||
EXPORT void fft64_rnx_vmp_prepare_contiguous_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
EXPORT void fft64_rnx_vmp_prepare_contiguous_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
EXPORT void fft64_rnx_vmp_prepare_dblptr_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double** mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
EXPORT void fft64_rnx_vmp_prepare_dblptr_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double** mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
EXPORT void fft64_rnx_vmp_prepare_row_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* mat, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
EXPORT void fft64_rnx_vmp_prepare_row_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* mat, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_ref(const MOD_RNX* module);
|
||||
EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_avx(const MOD_RNX* module);
|
||||
|
||||
EXPORT void fft64_rnx_vmp_apply_tmp_a_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a)
|
||||
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
EXPORT void fft64_rnx_vmp_apply_tmp_a_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a)
|
||||
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
);
|
||||
EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
);
|
||||
|
||||
/// gadget decompositions
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) */
|
||||
EXPORT void rnx_approxdecomp_from_tnxdbl_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a); // a
|
||||
EXPORT void rnx_approxdecomp_from_tnxdbl_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a); // a
|
||||
|
||||
EXPORT void vec_rnx_mul_xp_minus_one_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
const int64_t p, // rotation value
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_to_znx32_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_from_znx32_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_to_tnx32_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_from_tnx32_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_rnx_to_tnxdbl_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT uint64_t fft64_bytes_of_rnx_svp_ppol(const MOD_RNX* module); // N
|
||||
|
||||
/** @brief prepares a svp polynomial */
|
||||
EXPORT void fft64_rnx_svp_prepare_ref(const MOD_RNX* module, // N
|
||||
RNX_SVP_PPOL* ppol, // output
|
||||
const double* pol // a
|
||||
);
|
||||
|
||||
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
|
||||
EXPORT void fft64_rnx_svp_apply_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // output
|
||||
const RNX_SVP_PPOL* ppol, // prepared pol
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
#endif // SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H
|
||||
@@ -0,0 +1,91 @@
|
||||
#include <memory.h>
|
||||
|
||||
#include "vec_rnx_arithmetic_private.h"
|
||||
#include "zn_arithmetic_private.h"
|
||||
|
||||
EXPORT void vec_rnx_to_znx32_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
dbl_round_to_i32_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
|
||||
}
|
||||
for (uint64_t i = msize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void vec_rnx_from_znx32_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
i32_to_dbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
|
||||
}
|
||||
for (uint64_t i = msize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
EXPORT void vec_rnx_to_tnx32_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
dbl_to_tn32_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
|
||||
}
|
||||
for (uint64_t i = msize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
EXPORT void vec_rnx_from_tnx32_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
tn32_to_dbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
|
||||
}
|
||||
for (uint64_t i = msize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
|
||||
static void dbl_to_tndbl_ref( //
|
||||
const void* UNUSED, // N
|
||||
double* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
) {
|
||||
static const double OFF_CST = INT64_C(3) << 51;
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
double ai = a[i] + OFF_CST;
|
||||
res[i] = a[i] - (ai - OFF_CST);
|
||||
}
|
||||
memset(res + msize, 0, (res_size - msize) * sizeof(double));
|
||||
}
|
||||
|
||||
EXPORT void vec_rnx_to_tnxdbl_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
dbl_to_tndbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
|
||||
}
|
||||
for (uint64_t i = msize; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "../coeffs/coeffs_arithmetic.h"
|
||||
#include "vec_rnx_arithmetic_private.h"
|
||||
|
||||
EXPORT uint64_t fft64_bytes_of_rnx_svp_ppol(const MOD_RNX* module) { return module->n * sizeof(double); }
|
||||
|
||||
EXPORT RNX_SVP_PPOL* new_rnx_svp_ppol(const MOD_RNX* module) { return spqlios_alloc(bytes_of_rnx_svp_ppol(module)); }
|
||||
|
||||
EXPORT void delete_rnx_svp_ppol(RNX_SVP_PPOL* ppol) { spqlios_free(ppol); }
|
||||
|
||||
/** @brief prepares a svp polynomial */
|
||||
EXPORT void fft64_rnx_svp_prepare_ref(const MOD_RNX* module, // N
|
||||
RNX_SVP_PPOL* ppol, // output
|
||||
const double* pol // a
|
||||
) {
|
||||
double* const dppol = (double*)ppol;
|
||||
rnx_divide_by_m_ref(module->n, module->m, dppol, pol);
|
||||
reim_fft(module->precomp.fft64.p_fft, dppol);
|
||||
}
|
||||
|
||||
EXPORT void fft64_rnx_svp_apply_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // output
|
||||
const RNX_SVP_PPOL* ppol, // prepared pol
|
||||
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
double* const dppol = (double*)ppol;
|
||||
|
||||
const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < auto_end_idx; ++i) {
|
||||
const double* a_ptr = a + i * a_sl;
|
||||
double* const res_ptr = res + i * res_sl;
|
||||
// copy the polynomial to res, apply fft in place, call fftvec
|
||||
// _mul, apply ifft in place.
|
||||
memcpy(res_ptr, a_ptr, nn * sizeof(double));
|
||||
reim_fft(module->precomp.fft64.p_fft, (double*)res_ptr);
|
||||
reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, res_ptr, res_ptr, dppol);
|
||||
reim_ifft(module->precomp.fft64.p_ifft, res_ptr);
|
||||
}
|
||||
|
||||
// then extend with zeros
|
||||
for (uint64_t i = auto_end_idx; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,254 @@
|
||||
#include <assert.h>
|
||||
#include <immintrin.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "../coeffs/coeffs_arithmetic.h"
|
||||
#include "../reim/reim_fft.h"
|
||||
#include "../reim4/reim4_arithmetic.h"
|
||||
#include "vec_rnx_arithmetic_private.h"
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void fft64_rnx_vmp_prepare_contiguous_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
// there is an edge case if nn < 8
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t m = module->m;
|
||||
|
||||
double* const dtmp = (double*)tmp_space;
|
||||
double* const output_mat = (double*)pmat;
|
||||
double* start_addr = (double*)pmat;
|
||||
uint64_t offset = nrows * ncols * 8;
|
||||
|
||||
if (nn >= 8) {
|
||||
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
rnx_divide_by_m_avx(nn, m, dtmp, mat + (row_i * ncols + col_i) * nn);
|
||||
reim_fft(module->precomp.fft64.p_fft, dtmp);
|
||||
|
||||
if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
|
||||
// special case: last column out of an odd column number
|
||||
start_addr = output_mat + col_i * nrows * 8 // col == ncols-1
|
||||
+ row_i * 8;
|
||||
} else {
|
||||
// general case: columns go by pair
|
||||
start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
|
||||
+ row_i * 2 * 8 // third: row index
|
||||
+ (col_i % 2) * 8;
|
||||
}
|
||||
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
// extract blk from tmp and save it
|
||||
reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, dtmp);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
double* res = output_mat + (col_i * nrows + row_i) * nn;
|
||||
rnx_divide_by_m_avx(nn, m, res, mat + (row_i * ncols + col_i) * nn);
|
||||
reim_fft(module->precomp.fft64.p_fft, res);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||
EXPORT void fft64_rnx_vmp_prepare_dblptr_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double** mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||
fft64_rnx_vmp_prepare_row_avx(module, pmat, mat[row_i], row_i, nrows, ncols, tmp_space);
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||
EXPORT void fft64_rnx_vmp_prepare_row_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* row, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
// there is an edge case if nn < 8
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t m = module->m;
|
||||
|
||||
double* const dtmp = (double*)tmp_space;
|
||||
double* const output_mat = (double*)pmat;
|
||||
double* start_addr = (double*)pmat;
|
||||
uint64_t offset = nrows * ncols * 8;
|
||||
|
||||
if (nn >= 8) {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
rnx_divide_by_m_avx(nn, m, dtmp, row + col_i * nn);
|
||||
reim_fft(module->precomp.fft64.p_fft, dtmp);
|
||||
|
||||
if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
|
||||
// special case: last column out of an odd column number
|
||||
start_addr = output_mat + col_i * nrows * 8 // col == ncols-1
|
||||
+ row_i * 8;
|
||||
} else {
|
||||
// general case: columns go by pair
|
||||
start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
|
||||
+ row_i * 2 * 8 // third: row index
|
||||
+ (col_i % 2) * 8;
|
||||
}
|
||||
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
// extract blk from tmp and save it
|
||||
reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, dtmp);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
double* res = output_mat + (col_i * nrows + row_i) * nn;
|
||||
rnx_divide_by_m_avx(nn, m, res, row + col_i * nn);
|
||||
reim_fft(module->precomp.fft64.p_fft, res);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT void fft64_rnx_vmp_apply_dft_to_dft_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
) {
|
||||
const uint64_t m = module->m;
|
||||
const uint64_t nn = module->n;
|
||||
|
||||
double* mat2cols_output = (double*)tmp_space; // 128 bytes
|
||||
double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes
|
||||
|
||||
double* mat_input = (double*)pmat;
|
||||
|
||||
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||
const uint64_t col_max = ncols < res_size ? ncols : res_size;
|
||||
|
||||
if (row_max > 0 && col_max > 0) {
|
||||
if (nn >= 8) {
|
||||
// let's do some prefetching of the GSW key, since on some cpus,
|
||||
// it helps
|
||||
const uint64_t ms4 = m >> 2; // m/4
|
||||
const uint64_t gsw_iter_doubles = 8 * nrows * ncols;
|
||||
const uint64_t pref_doubles = 1200;
|
||||
const double* gsw_pref_ptr = mat_input;
|
||||
const double* const gsw_ptr_end = mat_input + ms4 * gsw_iter_doubles;
|
||||
const double* gsw_pref_ptr_target = mat_input + pref_doubles;
|
||||
for (; gsw_pref_ptr < gsw_pref_ptr_target; gsw_pref_ptr += 8) {
|
||||
__builtin_prefetch(gsw_pref_ptr, 0, _MM_HINT_T0);
|
||||
}
|
||||
const double* mat_blk_start;
|
||||
uint64_t blk_i;
|
||||
for (blk_i = 0, mat_blk_start = mat_input; blk_i < ms4; blk_i++, mat_blk_start += gsw_iter_doubles) {
|
||||
// prefetch the next iteration
|
||||
if (gsw_pref_ptr_target < gsw_ptr_end) {
|
||||
gsw_pref_ptr_target += gsw_iter_doubles;
|
||||
if (gsw_pref_ptr_target > gsw_ptr_end) gsw_pref_ptr_target = gsw_ptr_end;
|
||||
for (; gsw_pref_ptr < gsw_pref_ptr_target; gsw_pref_ptr += 8) {
|
||||
__builtin_prefetch(gsw_pref_ptr, 0, _MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
reim4_extract_1blk_from_contiguous_reim_sl_avx(m, a_sl, row_max, blk_i, extracted_blk, a_dft);
|
||||
// apply mat2cols
|
||||
for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) {
|
||||
uint64_t col_offset = col_i * (8 * nrows);
|
||||
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
|
||||
reim4_save_1blk_to_reim_avx(m, blk_i, res + col_i * res_sl, mat2cols_output);
|
||||
reim4_save_1blk_to_reim_avx(m, blk_i, res + (col_i + 1) * res_sl, mat2cols_output + 8);
|
||||
}
|
||||
|
||||
// check if col_max is odd, then special case
|
||||
if (col_max % 2 == 1) {
|
||||
uint64_t last_col = col_max - 1;
|
||||
uint64_t col_offset = last_col * (8 * nrows);
|
||||
|
||||
// the last column is alone in the pmat: vec_mat1col
|
||||
if (ncols == col_max) {
|
||||
reim4_vec_mat1col_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
} else {
|
||||
// the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
|
||||
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
}
|
||||
reim4_save_1blk_to_reim_avx(m, blk_i, res + last_col * res_sl, mat2cols_output);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const double* in;
|
||||
uint64_t in_sl;
|
||||
if (res == a_dft) {
|
||||
// it is in place: copy the input vector
|
||||
in = (double*)tmp_space;
|
||||
in_sl = nn;
|
||||
// vec_rnx_copy(module, (double*)tmp_space, row_max, nn, a_dft, row_max, a_sl);
|
||||
for (uint64_t row_i = 0; row_i < row_max; row_i++) {
|
||||
memcpy((double*)tmp_space + row_i * nn, a_dft + row_i * a_sl, nn * sizeof(double));
|
||||
}
|
||||
} else {
|
||||
// it is out of place: do the product directly
|
||||
in = a_dft;
|
||||
in_sl = a_sl;
|
||||
}
|
||||
for (uint64_t col_i = 0; col_i < col_max; col_i++) {
|
||||
double* pmat_col = mat_input + col_i * nrows * nn;
|
||||
{
|
||||
reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, //
|
||||
res + col_i * res_sl, //
|
||||
in, //
|
||||
pmat_col);
|
||||
}
|
||||
for (uint64_t row_i = 1; row_i < row_max; row_i++) {
|
||||
reim_fftvec_addmul(module->precomp.fft64.p_fftvec_addmul, //
|
||||
res + col_i * res_sl, //
|
||||
in + row_i * in_sl, //
|
||||
pmat_col + row_i * nn);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// zero out remaining bytes (if any)
|
||||
for (uint64_t i = col_max; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product res = a x pmat */
|
||||
EXPORT void fft64_rnx_vmp_apply_tmp_a_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a)
|
||||
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t rows = nrows < a_size ? nrows : a_size;
|
||||
const uint64_t cols = ncols < res_size ? ncols : res_size;
|
||||
|
||||
// fft is done in place on the input (tmpa is destroyed)
|
||||
for (uint64_t i = 0; i < rows; ++i) {
|
||||
reim_fft(module->precomp.fft64.p_fft, tmpa + i * a_sl);
|
||||
}
|
||||
fft64_rnx_vmp_apply_dft_to_dft_avx(module, //
|
||||
res, cols, res_sl, //
|
||||
tmpa, rows, a_sl, //
|
||||
pmat, nrows, ncols, //
|
||||
tmp_space);
|
||||
// ifft is done in place on the output
|
||||
for (uint64_t i = 0; i < cols; ++i) {
|
||||
reim_ifft(module->precomp.fft64.p_ifft, res + i * res_sl);
|
||||
}
|
||||
// zero out the remaining positions
|
||||
for (uint64_t i = cols; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,309 @@
|
||||
#include <assert.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "../coeffs/coeffs_arithmetic.h"
|
||||
#include "../reim/reim_fft.h"
|
||||
#include "../reim4/reim4_arithmetic.h"
|
||||
#include "vec_rnx_arithmetic_private.h"
|
||||
|
||||
/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */
|
||||
EXPORT uint64_t fft64_bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||
uint64_t nrows, uint64_t ncols) { // dimensions
|
||||
return nrows * ncols * module->n * sizeof(double);
|
||||
}
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void fft64_rnx_vmp_prepare_contiguous_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
// there is an edge case if nn < 8
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t m = module->m;
|
||||
|
||||
double* const dtmp = (double*)tmp_space;
|
||||
double* const output_mat = (double*)pmat;
|
||||
double* start_addr = (double*)pmat;
|
||||
uint64_t offset = nrows * ncols * 8;
|
||||
|
||||
if (nn >= 8) {
|
||||
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
rnx_divide_by_m_ref(nn, m, dtmp, mat + (row_i * ncols + col_i) * nn);
|
||||
reim_fft(module->precomp.fft64.p_fft, dtmp);
|
||||
|
||||
if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
|
||||
// special case: last column out of an odd column number
|
||||
start_addr = output_mat + col_i * nrows * 8 // col == ncols-1
|
||||
+ row_i * 8;
|
||||
} else {
|
||||
// general case: columns go by pair
|
||||
start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
|
||||
+ row_i * 2 * 8 // third: row index
|
||||
+ (col_i % 2) * 8;
|
||||
}
|
||||
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
// extract blk from tmp and save it
|
||||
reim4_extract_1blk_from_reim_ref(m, blk_i, start_addr + blk_i * offset, dtmp);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
double* res = output_mat + (col_i * nrows + row_i) * nn;
|
||||
rnx_divide_by_m_ref(nn, m, res, mat + (row_i * ncols + col_i) * nn);
|
||||
reim_fft(module->precomp.fft64.p_fft, res);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||
EXPORT void fft64_rnx_vmp_prepare_dblptr_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double** mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||
fft64_rnx_vmp_prepare_row_ref(module, pmat, mat[row_i], row_i, nrows, ncols, tmp_space);
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||
EXPORT void fft64_rnx_vmp_prepare_row_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
RNX_VMP_PMAT* pmat, // output
|
||||
const double* row, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
// there is an edge case if nn < 8
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t m = module->m;
|
||||
|
||||
double* const dtmp = (double*)tmp_space;
|
||||
double* const output_mat = (double*)pmat;
|
||||
double* start_addr = (double*)pmat;
|
||||
uint64_t offset = nrows * ncols * 8;
|
||||
|
||||
if (nn >= 8) {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
rnx_divide_by_m_ref(nn, m, dtmp, row + col_i * nn);
|
||||
reim_fft(module->precomp.fft64.p_fft, dtmp);
|
||||
|
||||
if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
|
||||
// special case: last column out of an odd column number
|
||||
start_addr = output_mat + col_i * nrows * 8 // col == ncols-1
|
||||
+ row_i * 8;
|
||||
} else {
|
||||
// general case: columns go by pair
|
||||
start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
|
||||
+ row_i * 2 * 8 // third: row index
|
||||
+ (col_i % 2) * 8;
|
||||
}
|
||||
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
// extract blk from tmp and save it
|
||||
reim4_extract_1blk_from_reim_ref(m, blk_i, start_addr + blk_i * offset, dtmp);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
double* res = output_mat + (col_i * nrows + row_i) * nn;
|
||||
rnx_divide_by_m_ref(nn, m, res, row + col_i * nn);
|
||||
reim_fft(module->precomp.fft64.p_fft, res);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief number of scratch bytes necessary to prepare a matrix */
|
||||
EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_ref(const MOD_RNX* module) {
|
||||
const uint64_t nn = module->n;
|
||||
return nn * sizeof(int64_t);
|
||||
}
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT void fft64_rnx_vmp_apply_dft_to_dft_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
) {
|
||||
const uint64_t m = module->m;
|
||||
const uint64_t nn = module->n;
|
||||
|
||||
double* mat2cols_output = (double*)tmp_space; // 128 bytes
|
||||
double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes
|
||||
|
||||
double* mat_input = (double*)pmat;
|
||||
|
||||
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||
const uint64_t col_max = ncols < res_size ? ncols : res_size;
|
||||
|
||||
if (row_max > 0 && col_max > 0) {
|
||||
if (nn >= 8) {
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols);
|
||||
|
||||
reim4_extract_1blk_from_contiguous_reim_sl_ref(m, a_sl, row_max, blk_i, extracted_blk, a_dft);
|
||||
// apply mat2cols
|
||||
for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) {
|
||||
uint64_t col_offset = col_i * (8 * nrows);
|
||||
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
|
||||
reim4_save_1blk_to_reim_ref(m, blk_i, res + col_i * res_sl, mat2cols_output);
|
||||
reim4_save_1blk_to_reim_ref(m, blk_i, res + (col_i + 1) * res_sl, mat2cols_output + 8);
|
||||
}
|
||||
|
||||
// check if col_max is odd, then special case
|
||||
if (col_max % 2 == 1) {
|
||||
uint64_t last_col = col_max - 1;
|
||||
uint64_t col_offset = last_col * (8 * nrows);
|
||||
|
||||
// the last column is alone in the pmat: vec_mat1col
|
||||
if (ncols == col_max) {
|
||||
reim4_vec_mat1col_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
} else {
|
||||
// the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
|
||||
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
}
|
||||
reim4_save_1blk_to_reim_ref(m, blk_i, res + last_col * res_sl, mat2cols_output);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const double* in;
|
||||
uint64_t in_sl;
|
||||
if (res == a_dft) {
|
||||
// it is in place: copy the input vector
|
||||
in = (double*)tmp_space;
|
||||
in_sl = nn;
|
||||
// vec_rnx_copy(module, (double*)tmp_space, row_max, nn, a_dft, row_max, a_sl);
|
||||
for (uint64_t row_i = 0; row_i < row_max; row_i++) {
|
||||
memcpy((double*)tmp_space + row_i * nn, a_dft + row_i * a_sl, nn * sizeof(double));
|
||||
}
|
||||
} else {
|
||||
// it is out of place: do the product directly
|
||||
in = a_dft;
|
||||
in_sl = a_sl;
|
||||
}
|
||||
for (uint64_t col_i = 0; col_i < col_max; col_i++) {
|
||||
double* pmat_col = mat_input + col_i * nrows * nn;
|
||||
{
|
||||
reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, //
|
||||
res + col_i * res_sl, //
|
||||
in, //
|
||||
pmat_col);
|
||||
}
|
||||
for (uint64_t row_i = 1; row_i < row_max; row_i++) {
|
||||
reim_fftvec_addmul(module->precomp.fft64.p_fftvec_addmul, //
|
||||
res + col_i * res_sl, //
|
||||
in + row_i * in_sl, //
|
||||
pmat_col + row_i * nn);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// zero out remaining bytes (if any)
|
||||
for (uint64_t i = col_max; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product res = a x pmat */
|
||||
EXPORT void fft64_rnx_vmp_apply_tmp_a_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a)
|
||||
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
const uint64_t nn = module->n;
|
||||
const uint64_t rows = nrows < a_size ? nrows : a_size;
|
||||
const uint64_t cols = ncols < res_size ? ncols : res_size;
|
||||
|
||||
// fft is done in place on the input (tmpa is destroyed)
|
||||
for (uint64_t i = 0; i < rows; ++i) {
|
||||
reim_fft(module->precomp.fft64.p_fft, tmpa + i * a_sl);
|
||||
}
|
||||
fft64_rnx_vmp_apply_dft_to_dft_ref(module, //
|
||||
res, cols, res_sl, //
|
||||
tmpa, rows, a_sl, //
|
||||
pmat, nrows, ncols, //
|
||||
tmp_space);
|
||||
// ifft is done in place on the output
|
||||
for (uint64_t i = 0; i < cols; ++i) {
|
||||
reim_ifft(module->precomp.fft64.p_ifft, res + i * res_sl);
|
||||
}
|
||||
// zero out the remaining positions
|
||||
for (uint64_t i = cols; i < res_size; ++i) {
|
||||
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
) {
|
||||
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||
|
||||
return (128) + (64 * row_max);
|
||||
}
|
||||
|
||||
#ifdef __APPLE__
|
||||
EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
) {
|
||||
return fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref(module, res_size, a_size, nrows, ncols);
|
||||
}
|
||||
#else
|
||||
EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
) __attribute((alias("fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref")));
|
||||
#endif
|
||||
// avx aliases that need to be defined in the same .c file
|
||||
|
||||
/** @brief number of scratch bytes necessary to prepare a matrix */
|
||||
#ifdef __APPLE__
|
||||
#pragma weak fft64_rnx_vmp_prepare_tmp_bytes_avx = fft64_rnx_vmp_prepare_tmp_bytes_ref
|
||||
#else
|
||||
EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_avx(const MOD_RNX* module)
|
||||
__attribute((alias("fft64_rnx_vmp_prepare_tmp_bytes_ref")));
|
||||
#endif
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
#ifdef __APPLE__
|
||||
#pragma weak fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref
|
||||
#else
|
||||
EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
) __attribute((alias("fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref")));
|
||||
#endif
|
||||
|
||||
#ifdef __APPLE__
|
||||
#pragma weak fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref
|
||||
#else
|
||||
EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx( //
|
||||
const MOD_RNX* module, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
) __attribute((alias("fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref")));
|
||||
#endif
|
||||
// wrappers
|
||||
@@ -0,0 +1,369 @@
|
||||
#include <assert.h>
|
||||
#include <math.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "../coeffs/coeffs_arithmetic.h"
|
||||
#include "../q120/q120_arithmetic.h"
|
||||
#include "../q120/q120_ntt.h"
|
||||
#include "../reim/reim_fft_internal.h"
|
||||
#include "../reim4/reim4_arithmetic.h"
|
||||
#include "vec_znx_arithmetic.h"
|
||||
#include "vec_znx_arithmetic_private.h"
|
||||
|
||||
// general function (virtual dispatch)
|
||||
|
||||
EXPORT void vec_znx_add(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
module->func.vec_znx_add(module, // N
|
||||
res, res_size, res_sl, // res
|
||||
a, a_size, a_sl, // a
|
||||
b, b_size, b_sl // b
|
||||
);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_sub(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
module->func.vec_znx_sub(module, // N
|
||||
res, res_size, res_sl, // res
|
||||
a, a_size, a_sl, // a
|
||||
b, b_size, b_sl // b
|
||||
);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_rotate(const MODULE* module, // N
|
||||
const int64_t p, // rotation value
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->func.vec_znx_rotate(module, // N
|
||||
p, // p
|
||||
res, res_size, res_sl, // res
|
||||
a, a_size, a_sl // a
|
||||
);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_mul_xp_minus_one(const MODULE* module, // N
|
||||
const int64_t p, // p
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->func.vec_znx_mul_xp_minus_one(module, // N
|
||||
p, // p
|
||||
res, res_size, res_sl, // res
|
||||
a, a_size, a_sl // a
|
||||
);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_automorphism(const MODULE* module, // N
|
||||
const int64_t p, // X->X^p
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->func.vec_znx_automorphism(module, // N
|
||||
p, // p
|
||||
res, res_size, res_sl, // res
|
||||
a, a_size, a_sl // a
|
||||
);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_normalize_base2k(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t log2_base2k, // output base 2^K
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
uint8_t* tmp_space // scratch space of size >= N
|
||||
) {
|
||||
module->func.vec_znx_normalize_base2k(module, nn, // N
|
||||
log2_base2k, // log2_base2k
|
||||
res, res_size, res_sl, // res
|
||||
a, a_size, a_sl, // a
|
||||
tmp_space);
|
||||
}
|
||||
|
||||
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N
|
||||
) {
|
||||
return module->func.vec_znx_normalize_base2k_tmp_bytes(module, nn // N
|
||||
);
|
||||
}
|
||||
|
||||
// specialized function (ref)
|
||||
|
||||
EXPORT void vec_znx_add_ref(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
if (a_size <= b_size) {
|
||||
const uint64_t sum_idx = res_size < a_size ? res_size : a_size;
|
||||
const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
|
||||
// add up to the smallest dimension
|
||||
for (uint64_t i = 0; i < sum_idx; ++i) {
|
||||
znx_add_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
// then copy to the largest dimension
|
||||
for (uint64_t i = sum_idx; i < copy_idx; ++i) {
|
||||
znx_copy_i64_ref(nn, res + i * res_sl, b + i * b_sl);
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
} else {
|
||||
const uint64_t sum_idx = res_size < b_size ? res_size : b_size;
|
||||
const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
|
||||
// add up to the smallest dimension
|
||||
for (uint64_t i = 0; i < sum_idx; ++i) {
|
||||
znx_add_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
// then copy to the largest dimension
|
||||
for (uint64_t i = sum_idx; i < copy_idx; ++i) {
|
||||
znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl);
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_sub_ref(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
if (a_size <= b_size) {
|
||||
const uint64_t sub_idx = res_size < a_size ? res_size : a_size;
|
||||
const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
|
||||
// subtract up to the smallest dimension
|
||||
for (uint64_t i = 0; i < sub_idx; ++i) {
|
||||
znx_sub_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
// then negate to the largest dimension
|
||||
for (uint64_t i = sub_idx; i < copy_idx; ++i) {
|
||||
znx_negate_i64_ref(nn, res + i * res_sl, b + i * b_sl);
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
} else {
|
||||
const uint64_t sub_idx = res_size < b_size ? res_size : b_size;
|
||||
const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
|
||||
// subtract up to the smallest dimension
|
||||
for (uint64_t i = 0; i < sub_idx; ++i) {
|
||||
znx_sub_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
// then copy to the largest dimension
|
||||
for (uint64_t i = sub_idx; i < copy_idx; ++i) {
|
||||
znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl);
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_rotate_ref(const MODULE* module, // N
|
||||
const int64_t p, // rotation value
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
|
||||
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||
// rotate up to the smallest dimension
|
||||
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||
int64_t* res_ptr = res + i * res_sl;
|
||||
const int64_t* a_ptr = a + i * a_sl;
|
||||
if (res_ptr == a_ptr) {
|
||||
znx_rotate_inplace_i64(nn, p, res_ptr);
|
||||
} else {
|
||||
znx_rotate_i64(nn, p, res_ptr, a_ptr);
|
||||
}
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_mul_xp_minus_one_ref(const MODULE* module, // N
|
||||
const int64_t p, // p
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
|
||||
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||
int64_t* res_ptr = res + i * res_sl;
|
||||
const int64_t* a_ptr = a + i * a_sl;
|
||||
if (res_ptr == a_ptr) {
|
||||
znx_mul_xp_minus_one_inplace_i64(nn, p, res_ptr);
|
||||
} else {
|
||||
znx_mul_xp_minus_one_i64(nn, p, res_ptr, a_ptr);
|
||||
}
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_automorphism_ref(const MODULE* module, // N
|
||||
const int64_t p, // X->X^p
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
|
||||
const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size;
|
||||
|
||||
for (uint64_t i = 0; i < auto_end_idx; ++i) {
|
||||
int64_t* res_ptr = res + i * res_sl;
|
||||
const int64_t* a_ptr = a + i * a_sl;
|
||||
if (res_ptr == a_ptr) {
|
||||
znx_automorphism_inplace_i64(nn, p, res_ptr);
|
||||
} else {
|
||||
znx_automorphism_i64(nn, p, res_ptr, a_ptr);
|
||||
}
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = auto_end_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_normalize_base2k_ref(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t log2_base2k, // output base 2^K
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
uint8_t* tmp_space // scratch space of size >= N
|
||||
) {
|
||||
|
||||
// use MSB limb of res for carry propagation
|
||||
int64_t* cout = (int64_t*)tmp_space;
|
||||
int64_t* cin = 0x0;
|
||||
|
||||
// propagate carry until first limb of res
|
||||
int64_t i = a_size - 1;
|
||||
for (; i >= res_size; --i) {
|
||||
znx_normalize(nn, log2_base2k, 0x0, cout, a + i * a_sl, cin);
|
||||
cin = cout;
|
||||
}
|
||||
|
||||
// propagate carry and normalize
|
||||
for (; i >= 1; --i) {
|
||||
znx_normalize(nn, log2_base2k, res + i * res_sl, cout, a + i * a_sl, cin);
|
||||
cin = cout;
|
||||
}
|
||||
|
||||
// normalize last limb
|
||||
znx_normalize(nn, log2_base2k, res, 0x0, a, cin);
|
||||
|
||||
// extend result with zeros
|
||||
for (uint64_t i = a_size; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes_ref(const MODULE* module, uint64_t nn // N
|
||||
) {
|
||||
return nn * sizeof(int64_t);
|
||||
}
|
||||
|
||||
// alias have to be defined in this unit: do not move
|
||||
#ifdef __APPLE__
|
||||
EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes( //
|
||||
const MODULE* module, // N
|
||||
uint64_t nn
|
||||
) {
|
||||
return vec_znx_normalize_base2k_tmp_bytes_ref(module, nn);
|
||||
}
|
||||
EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes( //
|
||||
const MODULE* module, // N
|
||||
uint64_t nn
|
||||
) {
|
||||
return vec_znx_normalize_base2k_tmp_bytes_ref(module, nn);
|
||||
}
|
||||
#else
|
||||
EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes( //
|
||||
const MODULE* module, // N
|
||||
uint64_t nn
|
||||
) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref")));
|
||||
|
||||
EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes( //
|
||||
const MODULE* module, // N
|
||||
uint64_t nn
|
||||
) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref")));
|
||||
#endif
|
||||
|
||||
/** @brief sets res = 0 */
|
||||
EXPORT void vec_znx_zero(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl // res
|
||||
) {
|
||||
module->func.vec_znx_zero(module, res, res_size, res_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = a */
|
||||
EXPORT void vec_znx_copy(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->func.vec_znx_copy(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = a */
|
||||
EXPORT void vec_znx_negate(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
module->func.vec_znx_negate(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_zero_ref(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl // res
|
||||
) {
|
||||
uint64_t nn = module->nn;
|
||||
for (uint64_t i = 0; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_copy_ref(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
uint64_t nn = module->nn;
|
||||
uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < smin; ++i) {
|
||||
znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl);
|
||||
}
|
||||
for (uint64_t i = smin; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_negate_ref(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
uint64_t nn = module->nn;
|
||||
uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < smin; ++i) {
|
||||
znx_negate_i64_ref(nn, res + i * res_sl, a + i * a_sl);
|
||||
}
|
||||
for (uint64_t i = smin; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,370 @@
|
||||
#ifndef SPQLIOS_VEC_ZNX_ARITHMETIC_H
|
||||
#define SPQLIOS_VEC_ZNX_ARITHMETIC_H
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "../commons.h"
|
||||
#include "../reim/reim_fft.h"
|
||||
|
||||
/**
|
||||
* We support the following module families:
|
||||
* - FFT64:
|
||||
* all the polynomials should fit at all times over 52 bits.
|
||||
* for FHE implementations, the recommended limb-sizes are
|
||||
* between K=10 and 20, which is good for low multiplicative depths.
|
||||
* - NTT120:
|
||||
* all the polynomials should fit at all times over 119 bits.
|
||||
* for FHE implementations, the recommended limb-sizes are
|
||||
* between K=20 and 40, which is good for large multiplicative depths.
|
||||
*/
|
||||
typedef enum module_type_t { FFT64, NTT120 } MODULE_TYPE;
|
||||
|
||||
/** @brief opaque structure that describr the modules (ZnX,TnX) and the hardware */
|
||||
typedef struct module_info_t MODULE;
|
||||
/** @brief opaque type that represents a prepared matrix */
|
||||
typedef struct vmp_pmat_t VMP_PMAT;
|
||||
/** @brief opaque type that represents a vector of znx in DFT space */
|
||||
typedef struct vec_znx_dft_t VEC_ZNX_DFT;
|
||||
/** @brief opaque type that represents a vector of znx in large coeffs space */
|
||||
typedef struct vec_znx_bigcoeff_t VEC_ZNX_BIG;
|
||||
/** @brief opaque type that represents a prepared scalar vector product */
|
||||
typedef struct svp_ppol_t SVP_PPOL;
|
||||
/** @brief opaque type that represents a prepared left convolution vector product */
|
||||
typedef struct cnv_pvec_l_t CNV_PVEC_L;
|
||||
/** @brief opaque type that represents a prepared right convolution vector product */
|
||||
typedef struct cnv_pvec_r_t CNV_PVEC_R;
|
||||
|
||||
/** @brief bytes needed for a vec_znx in DFT space */
|
||||
EXPORT uint64_t bytes_of_vec_znx_dft(const MODULE* module, // N
|
||||
uint64_t size);
|
||||
|
||||
/** @brief allocates a vec_znx in DFT space */
|
||||
EXPORT VEC_ZNX_DFT* new_vec_znx_dft(const MODULE* module, // N
|
||||
uint64_t size);
|
||||
|
||||
/** @brief frees memory from a vec_znx in DFT space */
|
||||
EXPORT void delete_vec_znx_dft(VEC_ZNX_DFT* res);
|
||||
|
||||
/** @brief bytes needed for a vec_znx_big */
|
||||
EXPORT uint64_t bytes_of_vec_znx_big(const MODULE* module, // N
|
||||
uint64_t size);
|
||||
|
||||
/** @brief allocates a vec_znx_big */
|
||||
EXPORT VEC_ZNX_BIG* new_vec_znx_big(const MODULE* module, // N
|
||||
uint64_t size);
|
||||
/** @brief frees memory from a vec_znx_big */
|
||||
EXPORT void delete_vec_znx_big(VEC_ZNX_BIG* res);
|
||||
|
||||
/** @brief bytes needed for a prepared vector */
|
||||
EXPORT uint64_t bytes_of_svp_ppol(const MODULE* module); // N
|
||||
|
||||
/** @brief allocates a prepared vector */
|
||||
EXPORT SVP_PPOL* new_svp_ppol(const MODULE* module); // N
|
||||
|
||||
/** @brief frees memory for a prepared vector */
|
||||
EXPORT void delete_svp_ppol(SVP_PPOL* res);
|
||||
|
||||
/** @brief bytes needed for a prepared matrix */
|
||||
EXPORT uint64_t bytes_of_vmp_pmat(const MODULE* module, // N
|
||||
uint64_t nrows, uint64_t ncols);
|
||||
|
||||
/** @brief allocates a prepared matrix */
|
||||
EXPORT VMP_PMAT* new_vmp_pmat(const MODULE* module, // N
|
||||
uint64_t nrows, uint64_t ncols);
|
||||
|
||||
/** @brief frees memory for a prepared matrix */
|
||||
EXPORT void delete_vmp_pmat(VMP_PMAT* res);
|
||||
|
||||
/**
|
||||
* @brief obtain a module info for ring dimension N
|
||||
* the module-info knows about:
|
||||
* - the dimension N (or the complex dimension m=N/2)
|
||||
* - any moduleuted fft or ntt items
|
||||
* - the hardware (avx, arm64, x86, ...)
|
||||
*/
|
||||
EXPORT MODULE* new_module_info(uint64_t N, MODULE_TYPE mode);
|
||||
EXPORT void delete_module_info(MODULE* module_info);
|
||||
EXPORT uint64_t module_get_n(const MODULE* module);
|
||||
|
||||
/** @brief sets res = 0 */
|
||||
EXPORT void vec_znx_zero(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl // res
|
||||
);
|
||||
|
||||
/** @brief sets res = a */
|
||||
EXPORT void vec_znx_copy(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a */
|
||||
EXPORT void vec_znx_negate(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a + b */
|
||||
EXPORT void vec_znx_add(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = a - b */
|
||||
EXPORT void vec_znx_sub(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = k-normalize-reduce(a) */
|
||||
EXPORT void vec_znx_normalize_base2k(const MODULE* module, // MODULE
|
||||
uint64_t nn, // N
|
||||
uint64_t log2_base2k, // output base 2^K
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
uint8_t* tmp_space // scratch space (size >= N)
|
||||
);
|
||||
|
||||
/** @brief returns the minimal byte length of scratch space for vec_znx_normalize_base2k */
|
||||
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N
|
||||
);
|
||||
|
||||
/** @brief sets res = a . X^p */
|
||||
EXPORT void vec_znx_rotate(const MODULE* module, // N
|
||||
const int64_t p, // rotation value
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a * (X^{p} - 1) */
|
||||
EXPORT void vec_znx_mul_xp_minus_one(const MODULE* module, // N
|
||||
const int64_t p, // rotation value
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a(X^p) */
|
||||
EXPORT void vec_znx_automorphism(const MODULE* module, // N
|
||||
const int64_t p, // X-X^p
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = 0 */
|
||||
EXPORT void vec_dft_zero(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size // res
|
||||
);
|
||||
|
||||
/** @brief sets res = a+b */
|
||||
EXPORT void vec_dft_add(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||
);
|
||||
|
||||
/** @brief sets res = a-b */
|
||||
EXPORT void vec_dft_sub(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||
);
|
||||
|
||||
/** @brief sets res = DFT(a) */
|
||||
EXPORT void vec_znx_dft(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief sets res = iDFT(a_dft) -- output in big coeffs space */
|
||||
EXPORT void vec_znx_idft(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
uint8_t* tmp // scratch space
|
||||
);
|
||||
|
||||
/** @brief tmp bytes required for vec_znx_idft */
|
||||
EXPORT uint64_t vec_znx_idft_tmp_bytes(const MODULE* module, uint64_t nn);
|
||||
|
||||
/**
|
||||
* @brief sets res = iDFT(a_dft) -- output in big coeffs space
|
||||
*
|
||||
* @note a_dft is overwritten
|
||||
*/
|
||||
EXPORT void vec_znx_idft_tmp_a(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten
|
||||
);
|
||||
|
||||
/** @brief sets res = a+b */
|
||||
EXPORT void vec_znx_big_add(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
);
|
||||
/** @brief sets res = a+b */
|
||||
EXPORT void vec_znx_big_add_small(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
EXPORT void vec_znx_big_add_small2(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = a-b */
|
||||
EXPORT void vec_znx_big_sub(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
);
|
||||
EXPORT void vec_znx_big_sub_small_b(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
EXPORT void vec_znx_big_sub_small_a(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
);
|
||||
EXPORT void vec_znx_big_sub_small2(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = k-normalize(a) -- output in int64 coeffs space */
|
||||
EXPORT void vec_znx_big_normalize_base2k(const MODULE* module, // MODULE
|
||||
uint64_t nn, // N
|
||||
uint64_t log2_base2k, // base-2^k
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
uint8_t* tmp_space // temp space
|
||||
);
|
||||
|
||||
/** @brief returns the minimal byte length of scratch space for vec_znx_big_normalize_base2k */
|
||||
EXPORT uint64_t vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N
|
||||
);
|
||||
|
||||
/** @brief sets res = k-normalize(a.subrange) -- output in int64 coeffs space */
|
||||
EXPORT void vec_znx_big_range_normalize_base2k( //
|
||||
const MODULE* module, // MODULE
|
||||
uint64_t nn,
|
||||
uint64_t log2_base2k, // base-2^k
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_range_begin, uint64_t a_range_xend, uint64_t a_range_step, // range
|
||||
uint8_t* tmp_space // temp space
|
||||
);
|
||||
|
||||
/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */
|
||||
EXPORT uint64_t vec_znx_big_range_normalize_base2k_tmp_bytes( //
|
||||
const MODULE* module, uint64_t nn // N
|
||||
);
|
||||
|
||||
/** @brief sets res = a . X^p */
|
||||
EXPORT void vec_znx_big_rotate(const MODULE* module, // N
|
||||
int64_t p, // rotation value
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a(X^p) */
|
||||
EXPORT void vec_znx_big_automorphism(const MODULE* module, // N
|
||||
int64_t p, // X-X^p
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
|
||||
EXPORT void svp_apply_dft(const MODULE* module, // N
|
||||
const VEC_ZNX_DFT* res, uint64_t res_size, // output
|
||||
const SVP_PPOL* ppol, // prepared pol
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
|
||||
EXPORT void svp_apply_dft_to_dft(const MODULE* module, // N
|
||||
const VEC_ZNX_DFT* res, uint64_t res_size,
|
||||
uint64_t res_cols, // output
|
||||
const SVP_PPOL* ppol, // prepared pol
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, uint64_t a_cols // a
|
||||
);
|
||||
|
||||
/** @brief prepares a svp polynomial */
|
||||
EXPORT void svp_prepare(const MODULE* module, // N
|
||||
SVP_PPOL* ppol, // output
|
||||
const int64_t* pol // a
|
||||
);
|
||||
|
||||
/** @brief res = a * b : small integer polynomial product */
|
||||
EXPORT void znx_small_single_product(const MODULE* module, // N
|
||||
int64_t* res, // output
|
||||
const int64_t* a, // a
|
||||
const int64_t* b, // b
|
||||
uint8_t* tmp);
|
||||
|
||||
/** @brief tmp bytes required for znx_small_single_product */
|
||||
EXPORT uint64_t znx_small_single_product_tmp_bytes(const MODULE* module, uint64_t nn);
|
||||
|
||||
/** @brief minimal scratch space byte-size required for the vmp_prepare function */
|
||||
EXPORT uint64_t vmp_prepare_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t nrows, uint64_t ncols);
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void vmp_prepare_contiguous(const MODULE* module, // N
|
||||
VMP_PMAT* pmat, // output
|
||||
const int64_t* mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) adds to res inplace */
|
||||
EXPORT void vmp_apply_dft_add(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) */
|
||||
EXPORT void vmp_apply_dft(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT uint64_t vmp_apply_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
);
|
||||
|
||||
/** @brief applies vmp product */
|
||||
EXPORT void vmp_apply_dft_to_dft(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||
const uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
);
|
||||
|
||||
/** @brief applies vmp product and adds to res inplace */
|
||||
EXPORT void vmp_apply_dft_to_dft_add(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols,
|
||||
const uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
);
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT uint64_t vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
);
|
||||
#endif // SPQLIOS_VEC_ZNX_ARITHMETIC_H
|
||||
@@ -0,0 +1,563 @@
|
||||
#ifndef SPQLIOS_VEC_ZNX_ARITHMETIC_PRIVATE_H
|
||||
#define SPQLIOS_VEC_ZNX_ARITHMETIC_PRIVATE_H
|
||||
|
||||
#include "../commons_private.h"
|
||||
#include "../q120/q120_ntt.h"
|
||||
#include "vec_znx_arithmetic.h"
|
||||
|
||||
/**
|
||||
* Layouts families:
|
||||
*
|
||||
* fft64:
|
||||
* K: <= 20, N: <= 65536, ell: <= 200
|
||||
* vec<ZnX> normalized: represented by int64
|
||||
* vec<ZnX> large: represented by int64 (expect <=52 bits)
|
||||
* vec<ZnX> DFT: represented by double (reim_fft space)
|
||||
* On AVX2 inftastructure, PMAT, LCNV, RCNV use a special reim4_fft space
|
||||
*
|
||||
* ntt120:
|
||||
* K: <= 50, N: <= 65536, ell: <= 80
|
||||
* vec<ZnX> normalized: represented by int64
|
||||
* vec<ZnX> large: represented by int128 (expect <=120 bits)
|
||||
* vec<ZnX> DFT: represented by int64x4 (ntt120 space)
|
||||
* On AVX2 inftastructure, PMAT, LCNV, RCNV use a special ntt120 space
|
||||
*
|
||||
* ntt104:
|
||||
* K: <= 40, N: <= 65536, ell: <= 80
|
||||
* vec<ZnX> normalized: represented by int64
|
||||
* vec<ZnX> large: represented by int128 (expect <=120 bits)
|
||||
* vec<ZnX> DFT: represented by int64x4 (ntt120 space)
|
||||
* On AVX512 inftastructure, PMAT, LCNV, RCNV use a special ntt104 space
|
||||
*/
|
||||
|
||||
struct fft64_module_info_t {
|
||||
// pre-computation for reim_fft
|
||||
REIM_FFT_PRECOMP* p_fft;
|
||||
// pre-computation for add_fft
|
||||
REIM_FFTVEC_ADD_PRECOMP* add_fft;
|
||||
// pre-computation for add_fft
|
||||
REIM_FFTVEC_SUB_PRECOMP* sub_fft;
|
||||
// pre-computation for mul_fft
|
||||
REIM_FFTVEC_MUL_PRECOMP* mul_fft;
|
||||
// pre-computation for reim_from_znx6
|
||||
REIM_FROM_ZNX64_PRECOMP* p_conv;
|
||||
// pre-computation for reim_tp_znx6
|
||||
REIM_TO_ZNX64_PRECOMP* p_reim_to_znx;
|
||||
// pre-computation for reim_fft
|
||||
REIM_IFFT_PRECOMP* p_ifft;
|
||||
// pre-computation for reim_fftvec_addmul
|
||||
REIM_FFTVEC_ADDMUL_PRECOMP* p_addmul;
|
||||
};
|
||||
|
||||
struct q120_module_info_t {
|
||||
// pre-computation for q120b to q120b ntt
|
||||
q120_ntt_precomp* p_ntt;
|
||||
// pre-computation for q120b to q120b intt
|
||||
q120_ntt_precomp* p_intt;
|
||||
};
|
||||
|
||||
// TODO add function types here
|
||||
typedef typeof(vec_znx_zero) VEC_ZNX_ZERO_F;
|
||||
typedef typeof(vec_znx_copy) VEC_ZNX_COPY_F;
|
||||
typedef typeof(vec_znx_negate) VEC_ZNX_NEGATE_F;
|
||||
typedef typeof(vec_znx_add) VEC_ZNX_ADD_F;
|
||||
typedef typeof(vec_znx_dft) VEC_ZNX_DFT_F;
|
||||
typedef typeof(vec_dft_add) VEC_DFT_ADD_F;
|
||||
typedef typeof(vec_dft_sub) VEC_DFT_SUB_F;
|
||||
typedef typeof(vec_znx_idft) VEC_ZNX_IDFT_F;
|
||||
typedef typeof(vec_znx_idft_tmp_bytes) VEC_ZNX_IDFT_TMP_BYTES_F;
|
||||
typedef typeof(vec_znx_idft_tmp_a) VEC_ZNX_IDFT_TMP_A_F;
|
||||
typedef typeof(vec_znx_sub) VEC_ZNX_SUB_F;
|
||||
typedef typeof(vec_znx_rotate) VEC_ZNX_ROTATE_F;
|
||||
typedef typeof(vec_znx_mul_xp_minus_one) VEC_ZNX_MUL_XP_MINUS_ONE_F;
|
||||
typedef typeof(vec_znx_automorphism) VEC_ZNX_AUTOMORPHISM_F;
|
||||
typedef typeof(vec_znx_normalize_base2k) VEC_ZNX_NORMALIZE_BASE2K_F;
|
||||
typedef typeof(vec_znx_normalize_base2k_tmp_bytes) VEC_ZNX_NORMALIZE_BASE2K_TMP_BYTES_F;
|
||||
typedef typeof(vec_znx_big_normalize_base2k) VEC_ZNX_BIG_NORMALIZE_BASE2K_F;
|
||||
typedef typeof(vec_znx_big_normalize_base2k_tmp_bytes) VEC_ZNX_BIG_NORMALIZE_BASE2K_TMP_BYTES_F;
|
||||
typedef typeof(vec_znx_big_range_normalize_base2k) VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_F;
|
||||
typedef typeof(vec_znx_big_range_normalize_base2k_tmp_bytes) VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_TMP_BYTES_F;
|
||||
typedef typeof(vec_znx_big_add) VEC_ZNX_BIG_ADD_F;
|
||||
typedef typeof(vec_znx_big_add_small) VEC_ZNX_BIG_ADD_SMALL_F;
|
||||
typedef typeof(vec_znx_big_add_small2) VEC_ZNX_BIG_ADD_SMALL2_F;
|
||||
typedef typeof(vec_znx_big_sub) VEC_ZNX_BIG_SUB_F;
|
||||
typedef typeof(vec_znx_big_sub_small_a) VEC_ZNX_BIG_SUB_SMALL_A_F;
|
||||
typedef typeof(vec_znx_big_sub_small_b) VEC_ZNX_BIG_SUB_SMALL_B_F;
|
||||
typedef typeof(vec_znx_big_sub_small2) VEC_ZNX_BIG_SUB_SMALL2_F;
|
||||
typedef typeof(vec_znx_big_rotate) VEC_ZNX_BIG_ROTATE_F;
|
||||
typedef typeof(vec_znx_big_automorphism) VEC_ZNX_BIG_AUTOMORPHISM_F;
|
||||
typedef typeof(svp_prepare) SVP_PREPARE;
|
||||
typedef typeof(svp_apply_dft) SVP_APPLY_DFT_F;
|
||||
typedef typeof(svp_apply_dft_to_dft) SVP_APPLY_DFT_TO_DFT_F;
|
||||
typedef typeof(znx_small_single_product) ZNX_SMALL_SINGLE_PRODUCT_F;
|
||||
typedef typeof(znx_small_single_product_tmp_bytes) ZNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F;
|
||||
typedef typeof(vmp_prepare_contiguous) VMP_PREPARE_CONTIGUOUS_F;
|
||||
typedef typeof(vmp_prepare_tmp_bytes) VMP_PREPARE_TMP_BYTES_F;
|
||||
typedef typeof(vmp_apply_dft) VMP_APPLY_DFT_F;
|
||||
typedef typeof(vmp_apply_dft_add) VMP_APPLY_DFT_ADD_F;
|
||||
typedef typeof(vmp_apply_dft_tmp_bytes) VMP_APPLY_DFT_TMP_BYTES_F;
|
||||
typedef typeof(vmp_apply_dft_to_dft) VMP_APPLY_DFT_TO_DFT_F;
|
||||
typedef typeof(vmp_apply_dft_to_dft_add) VMP_APPLY_DFT_TO_DFT_ADD_F;
|
||||
typedef typeof(vmp_apply_dft_to_dft_tmp_bytes) VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F;
|
||||
typedef typeof(bytes_of_vec_znx_dft) BYTES_OF_VEC_ZNX_DFT_F;
|
||||
typedef typeof(bytes_of_vec_znx_big) BYTES_OF_VEC_ZNX_BIG_F;
|
||||
typedef typeof(bytes_of_svp_ppol) BYTES_OF_SVP_PPOL_F;
|
||||
typedef typeof(bytes_of_vmp_pmat) BYTES_OF_VMP_PMAT_F;
|
||||
|
||||
struct module_virtual_functions_t {
|
||||
// TODO add functions here
|
||||
VEC_ZNX_ZERO_F* vec_znx_zero;
|
||||
VEC_ZNX_COPY_F* vec_znx_copy;
|
||||
VEC_ZNX_NEGATE_F* vec_znx_negate;
|
||||
VEC_ZNX_ADD_F* vec_znx_add;
|
||||
VEC_ZNX_DFT_F* vec_znx_dft;
|
||||
VEC_DFT_ADD_F* vec_dft_add;
|
||||
VEC_DFT_SUB_F* vec_dft_sub;
|
||||
VEC_ZNX_IDFT_F* vec_znx_idft;
|
||||
VEC_ZNX_IDFT_TMP_BYTES_F* vec_znx_idft_tmp_bytes;
|
||||
VEC_ZNX_IDFT_TMP_A_F* vec_znx_idft_tmp_a;
|
||||
VEC_ZNX_SUB_F* vec_znx_sub;
|
||||
VEC_ZNX_ROTATE_F* vec_znx_rotate;
|
||||
VEC_ZNX_MUL_XP_MINUS_ONE_F* vec_znx_mul_xp_minus_one;
|
||||
VEC_ZNX_AUTOMORPHISM_F* vec_znx_automorphism;
|
||||
VEC_ZNX_NORMALIZE_BASE2K_F* vec_znx_normalize_base2k;
|
||||
VEC_ZNX_NORMALIZE_BASE2K_TMP_BYTES_F* vec_znx_normalize_base2k_tmp_bytes;
|
||||
VEC_ZNX_BIG_NORMALIZE_BASE2K_F* vec_znx_big_normalize_base2k;
|
||||
VEC_ZNX_BIG_NORMALIZE_BASE2K_TMP_BYTES_F* vec_znx_big_normalize_base2k_tmp_bytes;
|
||||
VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_F* vec_znx_big_range_normalize_base2k;
|
||||
VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_TMP_BYTES_F* vec_znx_big_range_normalize_base2k_tmp_bytes;
|
||||
VEC_ZNX_BIG_ADD_F* vec_znx_big_add;
|
||||
VEC_ZNX_BIG_ADD_SMALL_F* vec_znx_big_add_small;
|
||||
VEC_ZNX_BIG_ADD_SMALL2_F* vec_znx_big_add_small2;
|
||||
VEC_ZNX_BIG_SUB_F* vec_znx_big_sub;
|
||||
VEC_ZNX_BIG_SUB_SMALL_A_F* vec_znx_big_sub_small_a;
|
||||
VEC_ZNX_BIG_SUB_SMALL_B_F* vec_znx_big_sub_small_b;
|
||||
VEC_ZNX_BIG_SUB_SMALL2_F* vec_znx_big_sub_small2;
|
||||
VEC_ZNX_BIG_ROTATE_F* vec_znx_big_rotate;
|
||||
VEC_ZNX_BIG_AUTOMORPHISM_F* vec_znx_big_automorphism;
|
||||
SVP_PREPARE* svp_prepare;
|
||||
SVP_APPLY_DFT_F* svp_apply_dft;
|
||||
SVP_APPLY_DFT_TO_DFT_F* svp_apply_dft_to_dft;
|
||||
ZNX_SMALL_SINGLE_PRODUCT_F* znx_small_single_product;
|
||||
ZNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* znx_small_single_product_tmp_bytes;
|
||||
VMP_PREPARE_CONTIGUOUS_F* vmp_prepare_contiguous;
|
||||
VMP_PREPARE_TMP_BYTES_F* vmp_prepare_tmp_bytes;
|
||||
VMP_APPLY_DFT_F* vmp_apply_dft;
|
||||
VMP_APPLY_DFT_ADD_F* vmp_apply_dft_add;
|
||||
VMP_APPLY_DFT_TMP_BYTES_F* vmp_apply_dft_tmp_bytes;
|
||||
VMP_APPLY_DFT_TO_DFT_F* vmp_apply_dft_to_dft;
|
||||
VMP_APPLY_DFT_TO_DFT_ADD_F* vmp_apply_dft_to_dft_add;
|
||||
VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* vmp_apply_dft_to_dft_tmp_bytes;
|
||||
BYTES_OF_VEC_ZNX_DFT_F* bytes_of_vec_znx_dft;
|
||||
BYTES_OF_VEC_ZNX_BIG_F* bytes_of_vec_znx_big;
|
||||
BYTES_OF_SVP_PPOL_F* bytes_of_svp_ppol;
|
||||
BYTES_OF_VMP_PMAT_F* bytes_of_vmp_pmat;
|
||||
};
|
||||
|
||||
union backend_module_info_t {
|
||||
struct fft64_module_info_t fft64;
|
||||
struct q120_module_info_t q120;
|
||||
};
|
||||
|
||||
struct module_info_t {
|
||||
// generic parameters
|
||||
MODULE_TYPE module_type;
|
||||
uint64_t nn;
|
||||
uint64_t m;
|
||||
// backend_dependent functions
|
||||
union backend_module_info_t mod;
|
||||
// virtual functions
|
||||
struct module_virtual_functions_t func;
|
||||
};
|
||||
|
||||
EXPORT uint64_t fft64_bytes_of_vec_znx_dft(const MODULE* module, // N
|
||||
uint64_t size);
|
||||
|
||||
EXPORT uint64_t fft64_bytes_of_vec_znx_big(const MODULE* module, // N
|
||||
uint64_t size);
|
||||
|
||||
EXPORT uint64_t fft64_bytes_of_svp_ppol(const MODULE* module); // N
|
||||
|
||||
EXPORT uint64_t fft64_bytes_of_vmp_pmat(const MODULE* module, // N
|
||||
uint64_t nrows, uint64_t ncols);
|
||||
|
||||
EXPORT void vec_znx_zero_ref(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl // res
|
||||
);
|
||||
|
||||
EXPORT void vec_znx_copy_ref(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_znx_negate_ref(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_znx_negate_avx(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_znx_add_ref(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
EXPORT void vec_znx_add_avx(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
EXPORT void vec_znx_sub_ref(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
EXPORT void vec_znx_sub_avx(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
EXPORT void vec_znx_normalize_base2k_ref(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t log2_base2k, // output base 2^K
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // inp
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes_ref(const MODULE* module, uint64_t nn // N
|
||||
);
|
||||
|
||||
EXPORT void vec_znx_rotate_ref(const MODULE* module, // N
|
||||
const int64_t p, // rotation value
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_znx_mul_xp_minus_one_ref(const MODULE* module, // N
|
||||
const int64_t p, // rotation value
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_znx_automorphism_ref(const MODULE* module, // N
|
||||
const int64_t p, // X->X^p
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vmp_prepare_ref(const MODULE* module, // N
|
||||
VMP_PMAT* pmat, // output
|
||||
const int64_t* mat, uint64_t nrows, uint64_t ncols // a
|
||||
);
|
||||
|
||||
EXPORT void vmp_apply_dft_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols // prep matrix
|
||||
);
|
||||
|
||||
EXPORT void vec_dft_zero_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size // res
|
||||
);
|
||||
|
||||
EXPORT void vec_dft_add_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||
);
|
||||
|
||||
EXPORT void vec_dft_sub_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||
);
|
||||
|
||||
EXPORT void vec_dft_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void vec_idft_ref(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size);
|
||||
|
||||
EXPORT void vec_znx_big_normalize_ref(const MODULE* module, // MODULE
|
||||
uint64_t nn, // N
|
||||
uint64_t k, // base-2^k
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
|
||||
EXPORT void fft64_svp_apply_dft_ref(const MODULE* module, // N
|
||||
const VEC_ZNX_DFT* res, uint64_t res_size, // output
|
||||
const SVP_PPOL* ppol, // prepared pol
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
|
||||
EXPORT void fft64_svp_apply_dft_to_dft_ref(const MODULE* module, // N
|
||||
const VEC_ZNX_DFT* res, uint64_t res_size,
|
||||
uint64_t res_cols, // output
|
||||
const SVP_PPOL* ppol, // prepared pol
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size,
|
||||
uint64_t a_cols // a
|
||||
);
|
||||
|
||||
/** @brief sets res = k-normalize(a) -- output in int64 coeffs space */
|
||||
EXPORT void fft64_vec_znx_big_normalize_base2k(const MODULE* module, // MODULE
|
||||
uint64_t nn, // N
|
||||
uint64_t k, // base-2^k
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
uint8_t* tmp_space // temp space
|
||||
);
|
||||
|
||||
/** @brief returns the minimal byte length of scratch space for vec_znx_big_normalize_base2k */
|
||||
EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N
|
||||
|
||||
);
|
||||
|
||||
/** @brief sets res = k-normalize(a.subrange) -- output in int64 coeffs space */
|
||||
EXPORT void fft64_vec_znx_big_range_normalize_base2k(const MODULE* module, // MODULE
|
||||
uint64_t nn,
|
||||
uint64_t log2_base2k, // base-2^k
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_range_begin, // a
|
||||
uint64_t a_range_xend, uint64_t a_range_step, // range
|
||||
uint8_t* tmp_space // temp space
|
||||
);
|
||||
|
||||
/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */
|
||||
EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N
|
||||
);
|
||||
|
||||
EXPORT void fft64_vec_znx_dft(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
EXPORT void fft64_vec_dft_add(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||
);
|
||||
|
||||
EXPORT void fft64_vec_dft_sub(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||
);
|
||||
|
||||
EXPORT void fft64_vec_znx_idft(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
uint8_t* tmp // scratch space
|
||||
);
|
||||
|
||||
EXPORT uint64_t fft64_vec_znx_idft_tmp_bytes(const MODULE* module, uint64_t nn);
|
||||
|
||||
EXPORT void fft64_vec_znx_idft_tmp_a(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten
|
||||
);
|
||||
|
||||
EXPORT void ntt120_vec_znx_dft_avx(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
);
|
||||
|
||||
/** */
|
||||
EXPORT void ntt120_vec_znx_idft_avx(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
uint8_t* tmp // scratch space
|
||||
);
|
||||
|
||||
EXPORT uint64_t ntt120_vec_znx_idft_tmp_bytes_avx(const MODULE* module, uint64_t nn);
|
||||
|
||||
EXPORT void ntt120_vec_znx_idft_tmp_a_avx(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten
|
||||
);
|
||||
|
||||
// big additions/subtractions
|
||||
|
||||
/** @brief sets res = a+b */
|
||||
EXPORT void fft64_vec_znx_big_add(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
);
|
||||
/** @brief sets res = a+b */
|
||||
EXPORT void fft64_vec_znx_big_add_small(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
EXPORT void fft64_vec_znx_big_add_small2(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = a-b */
|
||||
EXPORT void fft64_vec_znx_big_sub(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
);
|
||||
EXPORT void fft64_vec_znx_big_sub_small_b(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
EXPORT void fft64_vec_znx_big_sub_small_a(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
);
|
||||
EXPORT void fft64_vec_znx_big_sub_small2(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
);
|
||||
|
||||
/** @brief sets res = a . X^p */
|
||||
EXPORT void fft64_vec_znx_big_rotate(const MODULE* module, // N
|
||||
int64_t p, // rotation value
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** @brief sets res = a(X^p) */
|
||||
EXPORT void fft64_vec_znx_big_automorphism(const MODULE* module, // N
|
||||
int64_t p, // X-X^p
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** @brief prepares a svp polynomial */
|
||||
EXPORT void fft64_svp_prepare_ref(const MODULE* module, // N
|
||||
SVP_PPOL* ppol, // output
|
||||
const int64_t* pol // a
|
||||
);
|
||||
|
||||
/** @brief res = a * b : small integer polynomial product */
|
||||
EXPORT void fft64_znx_small_single_product(const MODULE* module, // N
|
||||
int64_t* res, // output
|
||||
const int64_t* a, // a
|
||||
const int64_t* b, // b
|
||||
uint8_t* tmp);
|
||||
|
||||
/** @brief tmp bytes required for znx_small_single_product */
|
||||
EXPORT uint64_t fft64_znx_small_single_product_tmp_bytes(const MODULE* module, uint64_t nn);
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void fft64_vmp_prepare_contiguous_ref(const MODULE* module, // N
|
||||
VMP_PMAT* pmat, // output
|
||||
const int64_t* mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void fft64_vmp_prepare_contiguous_avx(const MODULE* module, // N
|
||||
VMP_PMAT* pmat, // output
|
||||
const int64_t* mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief minimal scratch space byte-size required for the vmp_prepare function */
|
||||
EXPORT uint64_t fft64_vmp_prepare_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t nrows, uint64_t ncols);
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) and adds to res inplace */
|
||||
EXPORT void fft64_vmp_apply_dft_add_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,
|
||||
uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) */
|
||||
EXPORT void fft64_vmp_apply_dft_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) */
|
||||
EXPORT void fft64_vmp_apply_dft_avx(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) and adds to res inplace*/
|
||||
EXPORT void fft64_vmp_apply_dft_add_avx(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,
|
||||
uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
);
|
||||
|
||||
/** @brief this inner function could be very handy */
|
||||
EXPORT void fft64_vmp_apply_dft_to_dft_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||
const uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
);
|
||||
|
||||
/** @brief applies rmp product and adds to res inplace */
|
||||
EXPORT void fft64_vmp_apply_dft_to_dft_add_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols,
|
||||
uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
);
|
||||
|
||||
/** @brief this inner function could be very handy */
|
||||
EXPORT void fft64_vmp_apply_dft_to_dft_avx(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||
const uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
);
|
||||
|
||||
/** @brief applies rmp product and adds to res inplace */
|
||||
EXPORT void fft64_vmp_apply_dft_to_dft_add_avx(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols,
|
||||
uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
);
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT uint64_t fft64_vmp_apply_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
);
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT uint64_t fft64_vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
);
|
||||
#endif // SPQLIOS_VEC_ZNX_ARITHMETIC_PRIVATE_H
|
||||
@@ -0,0 +1,103 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "../coeffs/coeffs_arithmetic.h"
|
||||
#include "../reim4/reim4_arithmetic.h"
|
||||
#include "vec_znx_arithmetic_private.h"
|
||||
|
||||
// specialized function (ref)
|
||||
|
||||
// Note: these functions do not have an avx variant.
|
||||
#define znx_copy_i64_avx znx_copy_i64_ref
|
||||
#define znx_zero_i64_avx znx_zero_i64_ref
|
||||
|
||||
EXPORT void vec_znx_add_avx(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
if (a_size <= b_size) {
|
||||
const uint64_t sum_idx = res_size < a_size ? res_size : a_size;
|
||||
const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
|
||||
// add up to the smallest dimension
|
||||
for (uint64_t i = 0; i < sum_idx; ++i) {
|
||||
znx_add_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
// then copy to the largest dimension
|
||||
for (uint64_t i = sum_idx; i < copy_idx; ++i) {
|
||||
znx_copy_i64_avx(nn, res + i * res_sl, b + i * b_sl);
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_avx(nn, res + i * res_sl);
|
||||
}
|
||||
} else {
|
||||
const uint64_t sum_idx = res_size < b_size ? res_size : b_size;
|
||||
const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
|
||||
// add up to the smallest dimension
|
||||
for (uint64_t i = 0; i < sum_idx; ++i) {
|
||||
znx_add_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
// then copy to the largest dimension
|
||||
for (uint64_t i = sum_idx; i < copy_idx; ++i) {
|
||||
znx_copy_i64_avx(nn, res + i * res_sl, a + i * a_sl);
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_avx(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_sub_avx(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
if (a_size <= b_size) {
|
||||
const uint64_t sub_idx = res_size < a_size ? res_size : a_size;
|
||||
const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
|
||||
// subtract up to the smallest dimension
|
||||
for (uint64_t i = 0; i < sub_idx; ++i) {
|
||||
znx_sub_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
// then negate to the largest dimension
|
||||
for (uint64_t i = sub_idx; i < copy_idx; ++i) {
|
||||
znx_negate_i64_avx(nn, res + i * res_sl, b + i * b_sl);
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_avx(nn, res + i * res_sl);
|
||||
}
|
||||
} else {
|
||||
const uint64_t sub_idx = res_size < b_size ? res_size : b_size;
|
||||
const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
|
||||
// subtract up to the smallest dimension
|
||||
for (uint64_t i = 0; i < sub_idx; ++i) {
|
||||
znx_sub_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||
}
|
||||
// then copy to the largest dimension
|
||||
for (uint64_t i = sub_idx; i < copy_idx; ++i) {
|
||||
znx_copy_i64_avx(nn, res + i * res_sl, a + i * a_sl);
|
||||
}
|
||||
// then extend with zeros
|
||||
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||
znx_zero_i64_avx(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_negate_avx(const MODULE* module, // N
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
uint64_t nn = module->nn;
|
||||
uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < smin; ++i) {
|
||||
znx_negate_i64_avx(nn, res + i * res_sl, a + i * a_sl);
|
||||
}
|
||||
for (uint64_t i = smin; i < res_size; ++i) {
|
||||
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,278 @@
|
||||
#include "vec_znx_arithmetic_private.h"
|
||||
|
||||
EXPORT uint64_t bytes_of_vec_znx_big(const MODULE* module, // N
|
||||
uint64_t size) {
|
||||
return module->func.bytes_of_vec_znx_big(module, size);
|
||||
}
|
||||
|
||||
// public wrappers
|
||||
|
||||
/** @brief sets res = a+b */
|
||||
EXPORT void vec_znx_big_add(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
) {
|
||||
module->func.vec_znx_big_add(module, res, res_size, a, a_size, b, b_size);
|
||||
}
|
||||
|
||||
/** @brief sets res = a+b */
|
||||
EXPORT void vec_znx_big_add_small(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
module->func.vec_znx_big_add_small(module, res, res_size, a, a_size, b, b_size, b_sl);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_big_add_small2(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
module->func.vec_znx_big_add_small2(module, res, res_size, a, a_size, a_sl, b, b_size, b_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = a-b */
|
||||
EXPORT void vec_znx_big_sub(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
) {
|
||||
module->func.vec_znx_big_sub(module, res, res_size, a, a_size, b, b_size);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_big_sub_small_b(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
module->func.vec_znx_big_sub_small_b(module, res, res_size, a, a_size, b, b_size, b_sl);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_big_sub_small_a(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
) {
|
||||
module->func.vec_znx_big_sub_small_a(module, res, res_size, a, a_size, a_sl, b, b_size);
|
||||
}
|
||||
EXPORT void vec_znx_big_sub_small2(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
module->func.vec_znx_big_sub_small2(module, res, res_size, a, a_size, a_sl, b, b_size, b_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = a . X^p */
|
||||
EXPORT void vec_znx_big_rotate(const MODULE* module, // N
|
||||
int64_t p, // rotation value
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||
) {
|
||||
module->func.vec_znx_big_rotate(module, p, res, res_size, a, a_size);
|
||||
}
|
||||
|
||||
/** @brief sets res = a(X^p) */
|
||||
EXPORT void vec_znx_big_automorphism(const MODULE* module, // N
|
||||
int64_t p, // X-X^p
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||
) {
|
||||
module->func.vec_znx_big_automorphism(module, p, res, res_size, a, a_size);
|
||||
}
|
||||
|
||||
// private wrappers
|
||||
|
||||
EXPORT uint64_t fft64_bytes_of_vec_znx_big(const MODULE* module, // N
|
||||
uint64_t size) {
|
||||
return module->nn * size * sizeof(double);
|
||||
}
|
||||
|
||||
EXPORT VEC_ZNX_BIG* new_vec_znx_big(const MODULE* module, // N
|
||||
uint64_t size) {
|
||||
return spqlios_alloc(bytes_of_vec_znx_big(module, size));
|
||||
}
|
||||
|
||||
EXPORT void delete_vec_znx_big(VEC_ZNX_BIG* res) { spqlios_free(res); }
|
||||
|
||||
/** @brief sets res = a+b */
|
||||
EXPORT void fft64_vec_znx_big_add(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
) {
|
||||
const uint64_t n = module->nn;
|
||||
vec_znx_add(module, //
|
||||
(int64_t*)res, res_size, n, //
|
||||
(int64_t*)a, a_size, n, //
|
||||
(int64_t*)b, b_size, n);
|
||||
}
|
||||
/** @brief sets res = a+b */
|
||||
EXPORT void fft64_vec_znx_big_add_small(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t n = module->nn;
|
||||
vec_znx_add(module, //
|
||||
(int64_t*)res, res_size, n, //
|
||||
(int64_t*)a, a_size, n, //
|
||||
b, b_size, b_sl);
|
||||
}
|
||||
EXPORT void fft64_vec_znx_big_add_small2(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t n = module->nn;
|
||||
vec_znx_add(module, //
|
||||
(int64_t*)res, res_size, n, //
|
||||
a, a_size, a_sl, //
|
||||
b, b_size, b_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = a-b */
|
||||
EXPORT void fft64_vec_znx_big_sub(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
) {
|
||||
const uint64_t n = module->nn;
|
||||
vec_znx_sub(module, //
|
||||
(int64_t*)res, res_size, n, //
|
||||
(int64_t*)a, a_size, n, //
|
||||
(int64_t*)b, b_size, n);
|
||||
}
|
||||
|
||||
EXPORT void fft64_vec_znx_big_sub_small_b(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t n = module->nn;
|
||||
vec_znx_sub(module, //
|
||||
(int64_t*)res, res_size, n, //
|
||||
(int64_t*)a, a_size, //
|
||||
n, b, b_size, b_sl);
|
||||
}
|
||||
EXPORT void fft64_vec_znx_big_sub_small_a(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||
) {
|
||||
const uint64_t n = module->nn;
|
||||
vec_znx_sub(module, //
|
||||
(int64_t*)res, res_size, n, //
|
||||
a, a_size, a_sl, //
|
||||
(int64_t*)b, b_size, n);
|
||||
}
|
||||
EXPORT void fft64_vec_znx_big_sub_small2(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||
) {
|
||||
const uint64_t n = module->nn;
|
||||
vec_znx_sub(module, //
|
||||
(int64_t*)res, res_size, //
|
||||
n, a, a_size, //
|
||||
a_sl, b, b_size, b_sl);
|
||||
}
|
||||
|
||||
/** @brief sets res = a . X^p */
|
||||
EXPORT void fft64_vec_znx_big_rotate(const MODULE* module, // N
|
||||
int64_t p, // rotation value
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||
) {
|
||||
uint64_t nn = module->nn;
|
||||
vec_znx_rotate(module, p, (int64_t*)res, res_size, nn, (int64_t*)a, a_size, nn);
|
||||
}
|
||||
|
||||
/** @brief sets res = a(X^p) */
|
||||
EXPORT void fft64_vec_znx_big_automorphism(const MODULE* module, // N
|
||||
int64_t p, // X-X^p
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||
) {
|
||||
uint64_t nn = module->nn;
|
||||
vec_znx_automorphism(module, p, (int64_t*)res, res_size, nn, (int64_t*)a, a_size, nn);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_big_normalize_base2k(const MODULE* module, // MODULE
|
||||
uint64_t nn, // N
|
||||
uint64_t k, // base-2^k
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
uint8_t* tmp_space // temp space
|
||||
) {
|
||||
module->func.vec_znx_big_normalize_base2k(module, // MODULE
|
||||
nn, // N
|
||||
k, // base-2^k
|
||||
res, res_size, res_sl, // res
|
||||
a, a_size, // a
|
||||
tmp_space);
|
||||
}
|
||||
|
||||
EXPORT uint64_t vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N
|
||||
) {
|
||||
return module->func.vec_znx_big_normalize_base2k_tmp_bytes(module, nn // N
|
||||
);
|
||||
}
|
||||
|
||||
/** @brief sets res = k-normalize(a.subrange) -- output in int64 coeffs space */
|
||||
EXPORT void vec_znx_big_range_normalize_base2k( //
|
||||
const MODULE* module, // MODULE
|
||||
uint64_t nn, // N
|
||||
uint64_t log2_base2k, // base-2^k
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_range_begin, uint64_t a_range_xend, uint64_t a_range_step, // range
|
||||
uint8_t* tmp_space // temp space
|
||||
) {
|
||||
module->func.vec_znx_big_range_normalize_base2k(module, nn, log2_base2k, res, res_size, res_sl, a, a_range_begin,
|
||||
a_range_xend, a_range_step, tmp_space);
|
||||
}
|
||||
|
||||
/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */
|
||||
EXPORT uint64_t vec_znx_big_range_normalize_base2k_tmp_bytes( //
|
||||
const MODULE* module, // MODULE
|
||||
uint64_t nn // N
|
||||
) {
|
||||
return module->func.vec_znx_big_range_normalize_base2k_tmp_bytes(module, nn);
|
||||
}
|
||||
|
||||
EXPORT void fft64_vec_znx_big_normalize_base2k(const MODULE* module, // MODULE
|
||||
uint64_t nn, // N
|
||||
uint64_t k, // base-2^k
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||
uint8_t* tmp_space) {
|
||||
uint64_t a_sl = nn;
|
||||
module->func.vec_znx_normalize_base2k(module, // N
|
||||
nn,
|
||||
k, // log2_base2k
|
||||
res, res_size, res_sl, // res
|
||||
(int64_t*)a, a_size, a_sl, // a
|
||||
tmp_space);
|
||||
}
|
||||
|
||||
EXPORT void fft64_vec_znx_big_range_normalize_base2k( //
|
||||
const MODULE* module, // MODULE
|
||||
uint64_t nn, // N
|
||||
uint64_t k, // base-2^k
|
||||
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||
const VEC_ZNX_BIG* a, uint64_t a_begin, uint64_t a_end, uint64_t a_step, // a
|
||||
uint8_t* tmp_space) {
|
||||
// convert the range indexes to int64[] slices
|
||||
const int64_t* a_st = ((int64_t*)a) + nn * a_begin;
|
||||
const uint64_t a_size = (a_end + a_step - 1 - a_begin) / a_step;
|
||||
const uint64_t a_sl = nn * a_step;
|
||||
// forward the call
|
||||
module->func.vec_znx_normalize_base2k(module, // MODULE
|
||||
nn, // N
|
||||
k, // log2_base2k
|
||||
res, res_size, res_sl, // res
|
||||
a_st, a_size, a_sl, // a
|
||||
tmp_space);
|
||||
}
|
||||
@@ -0,0 +1,214 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "../q120/q120_arithmetic.h"
|
||||
#include "vec_znx_arithmetic_private.h"
|
||||
|
||||
EXPORT void vec_znx_dft(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
return module->func.vec_znx_dft(module, res, res_size, a, a_size, a_sl);
|
||||
}
|
||||
|
||||
EXPORT void vec_dft_add(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||
) {
|
||||
return module->func.vec_dft_add(module, res, res_size, a, a_size, b, b_size);
|
||||
}
|
||||
|
||||
EXPORT void vec_dft_sub(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||
) {
|
||||
return module->func.vec_dft_sub(module, res, res_size, a, a_size, b, b_size);
|
||||
}
|
||||
|
||||
EXPORT void vec_znx_idft(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
uint8_t* tmp // scratch space
|
||||
) {
|
||||
return module->func.vec_znx_idft(module, res, res_size, a_dft, a_size, tmp);
|
||||
}
|
||||
|
||||
EXPORT uint64_t vec_znx_idft_tmp_bytes(const MODULE* module, uint64_t nn) { return module->func.vec_znx_idft_tmp_bytes(module, nn); }
|
||||
|
||||
EXPORT void vec_znx_idft_tmp_a(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten
|
||||
) {
|
||||
return module->func.vec_znx_idft_tmp_a(module, res, res_size, a_dft, a_size);
|
||||
}
|
||||
|
||||
EXPORT uint64_t bytes_of_vec_znx_dft(const MODULE* module, // N
|
||||
uint64_t size) {
|
||||
return module->func.bytes_of_vec_znx_dft(module, size);
|
||||
}
|
||||
|
||||
// fft64 backend
|
||||
EXPORT uint64_t fft64_bytes_of_vec_znx_dft(const MODULE* module, // N
|
||||
uint64_t size) {
|
||||
return module->nn * size * sizeof(double);
|
||||
}
|
||||
|
||||
EXPORT VEC_ZNX_DFT* new_vec_znx_dft(const MODULE* module, // N
|
||||
uint64_t size) {
|
||||
return spqlios_alloc(bytes_of_vec_znx_dft(module, size));
|
||||
}
|
||||
|
||||
EXPORT void delete_vec_znx_dft(VEC_ZNX_DFT* res) { spqlios_free(res); }
|
||||
|
||||
EXPORT void fft64_vec_znx_dft(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||
const uint64_t nn = module->nn;
|
||||
|
||||
for (uint64_t i = 0; i < smin; i++) {
|
||||
reim_from_znx64(module->mod.fft64.p_conv, ((double*)res) + i * nn, a + i * a_sl);
|
||||
reim_fft(module->mod.fft64.p_fft, ((double*)res) + i * nn);
|
||||
}
|
||||
|
||||
// fill up remaining part with 0's
|
||||
double* const dres = (double*)res;
|
||||
memset(dres + smin * nn, 0, (res_size - smin) * nn * sizeof(double));
|
||||
}
|
||||
|
||||
EXPORT void fft64_vec_dft_add(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||
) {
|
||||
const uint64_t smin0 = a_size < b_size ? a_size : b_size;
|
||||
const uint64_t smin = res_size < smin0 ? res_size : smin0;
|
||||
const uint64_t nn = module->nn;
|
||||
|
||||
for (uint64_t i = 0; i < smin; i++) {
|
||||
reim_fftvec_add(module->mod.fft64.add_fft, ((double*)res) + i * nn, ((double*)a) + i * nn, ((double*)b) + i * nn);
|
||||
}
|
||||
|
||||
// fill remain `res` part with 0's
|
||||
double* const dres = (double*)res;
|
||||
memset(dres + smin * nn, 0, (res_size - smin) * nn * sizeof(double));
|
||||
}
|
||||
|
||||
EXPORT void fft64_vec_dft_sub(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||
) {
|
||||
const uint64_t smin0 = a_size < b_size ? a_size : b_size;
|
||||
const uint64_t smin = res_size < smin0 ? res_size : smin0;
|
||||
const uint64_t nn = module->nn;
|
||||
|
||||
for (uint64_t i = 0; i < smin; i++) {
|
||||
reim_fftvec_sub(module->mod.fft64.sub_fft, ((double*)res) + i * nn, ((double*)a) + i * nn, ((double*)b) + i * nn);
|
||||
}
|
||||
|
||||
// fill remain `res` part with 0's
|
||||
double* const dres = (double*)res;
|
||||
memset(dres + smin * nn, 0, (res_size - smin) * nn * sizeof(double));
|
||||
}
|
||||
|
||||
EXPORT void fft64_vec_znx_idft(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
uint8_t* tmp // unused
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||
if ((double*)res != (double*)a_dft) {
|
||||
memcpy(res, a_dft, smin * nn * sizeof(double));
|
||||
}
|
||||
|
||||
for (uint64_t i = 0; i < smin; i++) {
|
||||
reim_ifft(module->mod.fft64.p_ifft, ((double*)res) + i * nn);
|
||||
reim_to_znx64(module->mod.fft64.p_reim_to_znx, ((int64_t*)res) + i * nn, ((int64_t*)res) + i * nn);
|
||||
}
|
||||
|
||||
// fill up remaining part with 0's
|
||||
int64_t* const dres = (int64_t*)res;
|
||||
memset(dres + smin * nn, 0, (res_size - smin) * nn * sizeof(double));
|
||||
}
|
||||
|
||||
EXPORT uint64_t fft64_vec_znx_idft_tmp_bytes(const MODULE* module, uint64_t nn) { return 0; }
|
||||
|
||||
EXPORT void fft64_vec_znx_idft_tmp_a(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||
|
||||
int64_t* const tres = (int64_t*)res;
|
||||
double* const ta = (double*)a_dft;
|
||||
for (uint64_t i = 0; i < smin; i++) {
|
||||
reim_ifft(module->mod.fft64.p_ifft, ta + i * nn);
|
||||
reim_to_znx64(module->mod.fft64.p_reim_to_znx, tres + i * nn, ta + i * nn);
|
||||
}
|
||||
|
||||
// fill up remaining part with 0's
|
||||
memset(tres + smin * nn, 0, (res_size - smin) * nn * sizeof(double));
|
||||
}
|
||||
|
||||
// ntt120 backend
|
||||
|
||||
EXPORT void ntt120_vec_znx_dft_avx(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||
|
||||
int64_t* tres = (int64_t*)res;
|
||||
for (uint64_t i = 0; i < smin; i++) {
|
||||
q120_b_from_znx64_simple(nn, (q120b*)(tres + i * nn * 4), a + i * a_sl);
|
||||
q120_ntt_bb_avx2(module->mod.q120.p_ntt, (q120b*)(tres + i * nn * 4));
|
||||
}
|
||||
|
||||
// fill up remaining part with 0's
|
||||
memset(tres + smin * nn * 4, 0, (res_size - smin) * nn * 4 * sizeof(int64_t));
|
||||
}
|
||||
|
||||
EXPORT void ntt120_vec_znx_idft_avx(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
uint8_t* tmp) {
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||
|
||||
__int128_t* const tres = (__int128_t*)res;
|
||||
const int64_t* const ta = (int64_t*)a_dft;
|
||||
for (uint64_t i = 0; i < smin; i++) {
|
||||
memcpy(tmp, ta + i * nn * 4, nn * 4 * sizeof(uint64_t));
|
||||
q120_intt_bb_avx2(module->mod.q120.p_intt, (q120b*)tmp);
|
||||
q120_b_to_znx128_simple(nn, tres + i * nn, (q120b*)tmp);
|
||||
}
|
||||
|
||||
// fill up remaining part with 0's
|
||||
memset(tres + smin * nn, 0, (res_size - smin) * nn * sizeof(*tres));
|
||||
}
|
||||
|
||||
EXPORT uint64_t ntt120_vec_znx_idft_tmp_bytes_avx(const MODULE* module, uint64_t nn) { return nn * 4 * sizeof(uint64_t); }
|
||||
|
||||
EXPORT void ntt120_vec_znx_idft_tmp_a_avx(const MODULE* module, // N
|
||||
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||
VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||
|
||||
__int128_t* const tres = (__int128_t*)res;
|
||||
int64_t* const ta = (int64_t*)a_dft;
|
||||
for (uint64_t i = 0; i < smin; i++) {
|
||||
q120_intt_bb_avx2(module->mod.q120.p_intt, (q120b*)(ta + i * nn * 4));
|
||||
q120_b_to_znx128_simple(nn, tres + i * nn, (q120b*)(ta + i * nn * 4));
|
||||
}
|
||||
|
||||
// fill up remaining part with 0's
|
||||
memset(tres + smin * nn, 0, (res_size - smin) * nn * sizeof(*tres));
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
#include "vec_znx_arithmetic_private.h"
|
||||
@@ -0,0 +1,369 @@
|
||||
#include <assert.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "../reim4/reim4_arithmetic.h"
|
||||
#include "vec_znx_arithmetic_private.h"
|
||||
|
||||
EXPORT uint64_t bytes_of_vmp_pmat(const MODULE* module, // N
|
||||
uint64_t nrows, uint64_t ncols // dimensions
|
||||
) {
|
||||
return module->func.bytes_of_vmp_pmat(module, nrows, ncols);
|
||||
}
|
||||
|
||||
// fft64
|
||||
EXPORT uint64_t fft64_bytes_of_vmp_pmat(const MODULE* module, // N
|
||||
uint64_t nrows, uint64_t ncols // dimensions
|
||||
) {
|
||||
return module->nn * nrows * ncols * sizeof(double);
|
||||
}
|
||||
|
||||
EXPORT VMP_PMAT* new_vmp_pmat(const MODULE* module, // N
|
||||
uint64_t nrows, uint64_t ncols // dimensions
|
||||
) {
|
||||
return spqlios_alloc(bytes_of_vmp_pmat(module, nrows, ncols));
|
||||
}
|
||||
|
||||
EXPORT void delete_vmp_pmat(VMP_PMAT* res) { spqlios_free(res); }
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void vmp_prepare_contiguous(const MODULE* module, // N
|
||||
VMP_PMAT* pmat, // output
|
||||
const int64_t* mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
module->func.vmp_prepare_contiguous(module, pmat, mat, nrows, ncols, tmp_space);
|
||||
}
|
||||
|
||||
/** @brief minimal scratch space byte-size required for the vmp_prepare function */
|
||||
EXPORT uint64_t vmp_prepare_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t nrows, uint64_t ncols) {
|
||||
return module->func.vmp_prepare_tmp_bytes(module, nn, nrows, ncols);
|
||||
}
|
||||
|
||||
EXPORT double* get_blk_addr(uint64_t row_i, uint64_t col_i, uint64_t nrows, uint64_t ncols, const VMP_PMAT* pmat) {
|
||||
double* output_mat = (double*)pmat;
|
||||
|
||||
if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
|
||||
// special case: last column out of an odd column number
|
||||
return output_mat + col_i * nrows * 8 // col == ncols-1
|
||||
+ row_i * 8;
|
||||
} else {
|
||||
// general case: columns go by pair
|
||||
return output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
|
||||
+ row_i * 2 * 8 // third: row index
|
||||
+ (col_i % 2) * 8;
|
||||
}
|
||||
}
|
||||
|
||||
void fft64_store_svp_ppol_into_vmp_pmat_row_blk_ref(uint64_t nn, uint64_t m, const SVP_PPOL* svp_ppol, uint64_t row_i,
|
||||
uint64_t col_i, uint64_t nrows, uint64_t ncols, VMP_PMAT* pmat) {
|
||||
double* start_addr = get_blk_addr(row_i, col_i, nrows, ncols, pmat);
|
||||
uint64_t offset = nrows * ncols * 8;
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
reim4_extract_1blk_from_reim_ref(m, blk_i, start_addr + blk_i * offset, (double*)svp_ppol);
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void fft64_vmp_prepare_contiguous_ref(const MODULE* module, // N
|
||||
VMP_PMAT* pmat, // output
|
||||
const int64_t* mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
// there is an edge case if nn < 8
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t m = module->m;
|
||||
|
||||
if (nn >= 8) {
|
||||
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)tmp_space, mat + (row_i * ncols + col_i) * nn);
|
||||
reim_fft(module->mod.fft64.p_fft, (double*)tmp_space);
|
||||
fft64_store_svp_ppol_into_vmp_pmat_row_blk_ref(nn, m, (SVP_PPOL*)tmp_space, row_i, col_i, nrows, ncols, pmat);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
double* res = (double*)pmat + (col_i * nrows + row_i) * nn;
|
||||
reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)res, mat + (row_i * ncols + col_i) * nn);
|
||||
reim_fft(module->mod.fft64.p_fft, res);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief minimal scratch space byte-size required for the vmp_prepare function */
|
||||
EXPORT uint64_t fft64_vmp_prepare_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t nrows, uint64_t ncols) {
|
||||
return nn * sizeof(int64_t);
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) and adds to res inplace */
|
||||
EXPORT void fft64_vmp_apply_dft_add_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,
|
||||
uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t rows = nrows < a_size ? nrows : a_size;
|
||||
|
||||
VEC_ZNX_DFT* a_dft = (VEC_ZNX_DFT*)tmp_space;
|
||||
uint8_t* new_tmp_space = (uint8_t*)tmp_space + rows * nn * sizeof(double);
|
||||
|
||||
fft64_vec_znx_dft(module, a_dft, rows, a, a_size, a_sl);
|
||||
fft64_vmp_apply_dft_to_dft_add_ref(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, pmat_scale,
|
||||
new_tmp_space);
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) */
|
||||
EXPORT void fft64_vmp_apply_dft_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t rows = nrows < a_size ? nrows : a_size;
|
||||
|
||||
VEC_ZNX_DFT* a_dft = (VEC_ZNX_DFT*)tmp_space;
|
||||
uint8_t* new_tmp_space = (uint8_t*)tmp_space + rows * nn * sizeof(double);
|
||||
|
||||
fft64_vec_znx_dft(module, a_dft, rows, a, a_size, a_sl);
|
||||
fft64_vmp_apply_dft_to_dft_ref(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, new_tmp_space);
|
||||
}
|
||||
|
||||
/** @brief like fft64_vmp_apply_dft_to_dft_ref but adds in place */
|
||||
EXPORT void fft64_vmp_apply_dft_to_dft_add_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols,
|
||||
uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
) {
|
||||
const uint64_t m = module->m;
|
||||
const uint64_t nn = module->nn;
|
||||
assert(nn >= 8);
|
||||
|
||||
double* mat2cols_output = (double*)tmp_space; // 128 bytes
|
||||
double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes
|
||||
|
||||
double* mat_input = (double*)pmat;
|
||||
double* vec_input = (double*)a_dft;
|
||||
double* vec_output = (double*)res;
|
||||
|
||||
// const uint64_t row_max0 = res_size < a_size ? res_size: a_size;
|
||||
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||
const uint64_t col_max = ncols < res_size ? ncols : res_size;
|
||||
|
||||
if (nn >= 8) {
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols);
|
||||
|
||||
reim4_extract_1blk_from_contiguous_reim_ref(m, row_max, blk_i, (double*)extracted_blk, (double*)a_dft);
|
||||
|
||||
if (pmat_scale % 2 == 0) {
|
||||
// apply mat2cols
|
||||
for (uint64_t col_res = 0, col_pmat = pmat_scale; col_pmat < col_max - 1; col_res += 2, col_pmat += 2) {
|
||||
uint64_t col_offset = col_pmat * (8 * nrows);
|
||||
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
reim4_add_1blk_to_reim_ref(m, blk_i, vec_output + col_res * nn, mat2cols_output);
|
||||
reim4_add_1blk_to_reim_ref(m, blk_i, vec_output + (col_res + 1) * nn, mat2cols_output + 8);
|
||||
}
|
||||
} else {
|
||||
uint64_t col_offset = (pmat_scale - 1) * (8 * nrows);
|
||||
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
reim4_add_1blk_to_reim_ref(m, blk_i, vec_output, mat2cols_output + 8);
|
||||
|
||||
// apply mat2cols
|
||||
for (uint64_t col_res = 1, col_pmat = pmat_scale + 1; col_pmat < col_max - 1; col_res += 2, col_pmat += 2) {
|
||||
uint64_t col_offset = col_pmat * (8 * nrows);
|
||||
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
|
||||
reim4_add_1blk_to_reim_ref(m, blk_i, vec_output + col_res * nn, mat2cols_output);
|
||||
reim4_add_1blk_to_reim_ref(m, blk_i, vec_output + (col_res + 1) * nn, mat2cols_output + 8);
|
||||
}
|
||||
}
|
||||
|
||||
// check if col_max is odd, then special case
|
||||
if (col_max % 2 == 1) {
|
||||
uint64_t last_col = col_max - 1;
|
||||
uint64_t col_offset = last_col * (8 * nrows);
|
||||
|
||||
if (last_col >= pmat_scale) {
|
||||
// the last column is alone in the pmat: vec_mat1col
|
||||
if (ncols == col_max) {
|
||||
reim4_vec_mat1col_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
} else {
|
||||
// the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
|
||||
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
}
|
||||
|
||||
reim4_add_1blk_to_reim_ref(m, blk_i, vec_output + (last_col - pmat_scale) * nn, mat2cols_output);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint64_t col_res = 0, col_pmat = pmat_scale; col_pmat < col_max; col_res += 1, col_pmat += 1) {
|
||||
double* pmat_col = mat_input + col_pmat * nrows * nn;
|
||||
for (uint64_t row_i = 0; row_i < row_max; row_i++) {
|
||||
reim_fftvec_addmul(module->mod.fft64.p_addmul, vec_output + col_res * nn, vec_input + row_i * nn,
|
||||
pmat_col + row_i * nn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// zero out remaining bytes
|
||||
memset(vec_output + col_max * nn, 0, (res_size - col_max) * nn * sizeof(double));
|
||||
}
|
||||
|
||||
/** @brief this inner function could be very handy */
|
||||
EXPORT void fft64_vmp_apply_dft_to_dft_ref(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||
const uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
) {
|
||||
const uint64_t m = module->m;
|
||||
const uint64_t nn = module->nn;
|
||||
assert(nn >= 8);
|
||||
|
||||
double* mat2cols_output = (double*)tmp_space; // 128 bytes
|
||||
double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes
|
||||
|
||||
double* mat_input = (double*)pmat;
|
||||
double* vec_input = (double*)a_dft;
|
||||
double* vec_output = (double*)res;
|
||||
|
||||
// const uint64_t row_max0 = res_size < a_size ? res_size: a_size;
|
||||
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||
const uint64_t col_max = ncols < res_size ? ncols : res_size;
|
||||
|
||||
if (nn >= 8) {
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols);
|
||||
|
||||
reim4_extract_1blk_from_contiguous_reim_ref(m, row_max, blk_i, (double*)extracted_blk, (double*)a_dft);
|
||||
// apply mat2cols
|
||||
for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) {
|
||||
uint64_t col_offset = col_i * (8 * nrows);
|
||||
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
|
||||
reim4_save_1blk_to_reim_ref(m, blk_i, vec_output + col_i * nn, mat2cols_output);
|
||||
reim4_save_1blk_to_reim_ref(m, blk_i, vec_output + (col_i + 1) * nn, mat2cols_output + 8);
|
||||
}
|
||||
|
||||
// check if col_max is odd, then special case
|
||||
if (col_max % 2 == 1) {
|
||||
uint64_t last_col = col_max - 1;
|
||||
uint64_t col_offset = last_col * (8 * nrows);
|
||||
|
||||
// the last column is alone in the pmat: vec_mat1col
|
||||
if (ncols == col_max) {
|
||||
reim4_vec_mat1col_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
} else {
|
||||
// the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
|
||||
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
}
|
||||
reim4_save_1blk_to_reim_ref(m, blk_i, vec_output + last_col * nn, mat2cols_output);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint64_t col_i = 0; col_i < col_max; col_i++) {
|
||||
double* pmat_col = mat_input + col_i * nrows * nn;
|
||||
for (uint64_t row_i = 0; row_i < 1; row_i++) {
|
||||
reim_fftvec_mul(module->mod.fft64.mul_fft, vec_output + col_i * nn, vec_input + row_i * nn,
|
||||
pmat_col + row_i * nn);
|
||||
}
|
||||
for (uint64_t row_i = 1; row_i < row_max; row_i++) {
|
||||
reim_fftvec_addmul(module->mod.fft64.p_addmul, vec_output + col_i * nn, vec_input + row_i * nn,
|
||||
pmat_col + row_i * nn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// zero out remaining bytes
|
||||
memset(vec_output + col_max * nn, 0, (res_size - col_max) * nn * sizeof(double));
|
||||
}
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT uint64_t fft64_vmp_apply_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
) {
|
||||
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||
return (row_max * nn * sizeof(double)) + (128) + (64 * row_max);
|
||||
}
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT uint64_t fft64_vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
) {
|
||||
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||
|
||||
return (128) + (64 * row_max);
|
||||
}
|
||||
|
||||
EXPORT void vmp_apply_dft_to_dft(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||
const uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
) {
|
||||
module->func.vmp_apply_dft_to_dft(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, tmp_space);
|
||||
}
|
||||
|
||||
EXPORT void vmp_apply_dft_to_dft_add(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols,
|
||||
uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
) {
|
||||
module->func.vmp_apply_dft_to_dft_add(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, pmat_scale,
|
||||
tmp_space);
|
||||
}
|
||||
|
||||
EXPORT uint64_t vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
) {
|
||||
return module->func.vmp_apply_dft_to_dft_tmp_bytes(module, nn, res_size, a_size, nrows, ncols);
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) adds to res inplace */
|
||||
EXPORT void vmp_apply_dft_add(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
module->func.vmp_apply_dft_add(module, res, res_size, a, a_size, a_sl, pmat, nrows, ncols, pmat_scale, tmp_space);
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) */
|
||||
EXPORT void vmp_apply_dft(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
module->func.vmp_apply_dft(module, res, res_size, a, a_size, a_sl, pmat, nrows, ncols, tmp_space);
|
||||
}
|
||||
|
||||
/** @brief minimal size of the tmp_space */
|
||||
EXPORT uint64_t vmp_apply_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||
uint64_t res_size, // res
|
||||
uint64_t a_size, // a
|
||||
uint64_t nrows, uint64_t ncols // prep matrix
|
||||
) {
|
||||
return module->func.vmp_apply_dft_tmp_bytes(module, nn, res_size, a_size, nrows, ncols);
|
||||
}
|
||||
@@ -0,0 +1,244 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "../reim4/reim4_arithmetic.h"
|
||||
#include "vec_znx_arithmetic_private.h"
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void fft64_vmp_prepare_contiguous_avx(const MODULE* module, // N
|
||||
VMP_PMAT* pmat, // output
|
||||
const int64_t* mat, uint64_t nrows, uint64_t ncols, // a
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
// there is an edge case if nn < 8
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t m = module->m;
|
||||
|
||||
double* output_mat = (double*)pmat;
|
||||
double* start_addr = (double*)pmat;
|
||||
uint64_t offset = nrows * ncols * 8;
|
||||
|
||||
if (nn >= 8) {
|
||||
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)tmp_space, mat + (row_i * ncols + col_i) * nn);
|
||||
reim_fft(module->mod.fft64.p_fft, (double*)tmp_space);
|
||||
|
||||
if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
|
||||
// special case: last column out of an odd column number
|
||||
start_addr = output_mat + col_i * nrows * 8 // col == ncols-1
|
||||
+ row_i * 8;
|
||||
} else {
|
||||
// general case: columns go by pair
|
||||
start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
|
||||
+ row_i * 2 * 8 // third: row index
|
||||
+ (col_i % 2) * 8;
|
||||
}
|
||||
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
// extract blk from tmp and save it
|
||||
reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, (double*)tmp_space);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||
double* res = (double*)pmat + (col_i * nrows + row_i) * nn;
|
||||
reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)res, mat + (row_i * ncols + col_i) * nn);
|
||||
reim_fft(module->mod.fft64.p_fft, res);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
double* get_blk_addr(int row, int col, int nrows, int ncols, VMP_PMAT* pmat);
|
||||
|
||||
void fft64_store_svp_ppol_into_vmp_pmat_row_blk_avx(uint64_t nn, uint64_t m, const SVP_PPOL* svp_ppol, uint64_t row_i,
|
||||
uint64_t col_i, uint64_t nrows, uint64_t ncols, VMP_PMAT* pmat) {
|
||||
double* start_addr = get_blk_addr(row_i, col_i, nrows, ncols, pmat);
|
||||
uint64_t offset = nrows * ncols * 8;
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, (double*)svp_ppol);
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) abd adds to res inplace */
|
||||
EXPORT void fft64_vmp_apply_dft_add_avx(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,
|
||||
uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t rows = nrows < a_size ? nrows : a_size;
|
||||
|
||||
VEC_ZNX_DFT* a_dft = (VEC_ZNX_DFT*)tmp_space;
|
||||
uint8_t* new_tmp_space = (uint8_t*)tmp_space + rows * nn * sizeof(double);
|
||||
|
||||
fft64_vec_znx_dft(module, a_dft, rows, a, a_size, a_sl);
|
||||
fft64_vmp_apply_dft_to_dft_add_avx(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, pmat_scale,
|
||||
new_tmp_space);
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product (result in DFT space) */
|
||||
EXPORT void fft64_vmp_apply_dft_avx(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space
|
||||
) {
|
||||
const uint64_t nn = module->nn;
|
||||
const uint64_t rows = nrows < a_size ? nrows : a_size;
|
||||
|
||||
VEC_ZNX_DFT* a_dft = (VEC_ZNX_DFT*)tmp_space;
|
||||
uint8_t* new_tmp_space = (uint8_t*)tmp_space + rows * nn * sizeof(double);
|
||||
|
||||
fft64_vec_znx_dft(module, a_dft, rows, a, a_size, a_sl);
|
||||
fft64_vmp_apply_dft_to_dft_avx(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, new_tmp_space);
|
||||
}
|
||||
|
||||
EXPORT void fft64_vmp_apply_dft_to_dft_add_avx(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols,
|
||||
uint64_t pmat_scale, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
) {
|
||||
const uint64_t m = module->m;
|
||||
const uint64_t nn = module->nn;
|
||||
|
||||
double* mat2cols_output = (double*)tmp_space; // 128 bytes
|
||||
double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes
|
||||
|
||||
double* mat_input = (double*)pmat;
|
||||
double* vec_input = (double*)a_dft;
|
||||
double* vec_output = (double*)res;
|
||||
|
||||
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||
const uint64_t col_max = ncols < res_size ? ncols : res_size;
|
||||
|
||||
if (nn >= 8) {
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols);
|
||||
|
||||
reim4_extract_1blk_from_contiguous_reim_avx(m, row_max, blk_i, (double*)extracted_blk, (double*)a_dft);
|
||||
|
||||
if (pmat_scale % 2 == 0) {
|
||||
for (uint64_t col_res = 0, col_pmat = pmat_scale; col_pmat < col_max - 1; col_res += 2, col_pmat += 2) {
|
||||
uint64_t col_offset = col_pmat * (8 * nrows);
|
||||
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
reim4_add_1blk_to_reim_avx(m, blk_i, vec_output + col_res * nn, mat2cols_output);
|
||||
reim4_add_1blk_to_reim_avx(m, blk_i, vec_output + (col_res + 1) * nn, mat2cols_output + 8);
|
||||
}
|
||||
} else {
|
||||
uint64_t col_offset = (pmat_scale - 1) * (8 * nrows);
|
||||
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
reim4_add_1blk_to_reim_avx(m, blk_i, vec_output, mat2cols_output + 8);
|
||||
|
||||
for (uint64_t col_res = 1, col_pmat = pmat_scale + 1; col_pmat < col_max - 1; col_res += 2, col_pmat += 2) {
|
||||
uint64_t col_offset = col_pmat * (8 * nrows);
|
||||
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
reim4_add_1blk_to_reim_avx(m, blk_i, vec_output + col_res * nn, mat2cols_output);
|
||||
reim4_add_1blk_to_reim_avx(m, blk_i, vec_output + (col_res + 1) * nn, mat2cols_output + 8);
|
||||
}
|
||||
}
|
||||
|
||||
// check if col_max is odd, then special case
|
||||
if (col_max % 2 == 1) {
|
||||
uint64_t last_col = col_max - 1;
|
||||
uint64_t col_offset = last_col * (8 * nrows);
|
||||
|
||||
if (last_col >= pmat_scale) {
|
||||
// the last column is alone in the pmat: vec_mat1col
|
||||
if (ncols == col_max)
|
||||
reim4_vec_mat1col_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
else {
|
||||
// the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
|
||||
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
}
|
||||
reim4_add_1blk_to_reim_avx(m, blk_i, vec_output + (last_col - pmat_scale) * nn, mat2cols_output);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint64_t col_res = 0, col_pmat = pmat_scale; col_pmat < col_max; col_res += 1, col_pmat += 1) {
|
||||
double* pmat_col = mat_input + col_pmat * nrows * nn;
|
||||
for (uint64_t row_i = 0; row_i < row_max; row_i++) {
|
||||
reim_fftvec_addmul(module->mod.fft64.p_addmul, vec_output + col_res * nn, vec_input + row_i * nn,
|
||||
pmat_col + row_i * nn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// zero out remaining bytes
|
||||
memset(vec_output + col_max * nn, 0, (res_size - col_max) * nn * sizeof(double));
|
||||
}
|
||||
|
||||
/** @brief this inner function could be very handy */
|
||||
EXPORT void fft64_vmp_apply_dft_to_dft_avx(const MODULE* module, // N
|
||||
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||
const uint64_t ncols, // prep matrix
|
||||
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||
) {
|
||||
const uint64_t m = module->m;
|
||||
const uint64_t nn = module->nn;
|
||||
|
||||
double* mat2cols_output = (double*)tmp_space; // 128 bytes
|
||||
double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes
|
||||
|
||||
double* mat_input = (double*)pmat;
|
||||
double* vec_input = (double*)a_dft;
|
||||
double* vec_output = (double*)res;
|
||||
|
||||
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||
const uint64_t col_max = ncols < res_size ? ncols : res_size;
|
||||
|
||||
if (nn >= 8) {
|
||||
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||
double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols);
|
||||
|
||||
reim4_extract_1blk_from_contiguous_reim_avx(m, row_max, blk_i, (double*)extracted_blk, (double*)a_dft);
|
||||
// apply mat2cols
|
||||
for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) {
|
||||
uint64_t col_offset = col_i * (8 * nrows);
|
||||
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
|
||||
reim4_save_1blk_to_reim_avx(m, blk_i, vec_output + col_i * nn, mat2cols_output);
|
||||
reim4_save_1blk_to_reim_avx(m, blk_i, vec_output + (col_i + 1) * nn, mat2cols_output + 8);
|
||||
}
|
||||
|
||||
// check if col_max is odd, then special case
|
||||
if (col_max % 2 == 1) {
|
||||
uint64_t last_col = col_max - 1;
|
||||
uint64_t col_offset = last_col * (8 * nrows);
|
||||
|
||||
// the last column is alone in the pmat: vec_mat1col
|
||||
if (ncols == col_max)
|
||||
reim4_vec_mat1col_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
else {
|
||||
// the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
|
||||
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||
}
|
||||
reim4_save_1blk_to_reim_avx(m, blk_i, vec_output + last_col * nn, mat2cols_output);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint64_t col_i = 0; col_i < col_max; col_i++) {
|
||||
double* pmat_col = mat_input + col_i * nrows * nn;
|
||||
for (uint64_t row_i = 0; row_i < 1; row_i++) {
|
||||
reim_fftvec_mul(module->mod.fft64.mul_fft, vec_output + col_i * nn, vec_input + row_i * nn,
|
||||
pmat_col + row_i * nn);
|
||||
}
|
||||
for (uint64_t row_i = 1; row_i < row_max; row_i++) {
|
||||
reim_fftvec_addmul(module->mod.fft64.p_addmul, vec_output + col_i * nn, vec_input + row_i * nn,
|
||||
pmat_col + row_i * nn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// zero out remaining bytes
|
||||
memset(vec_output + col_max * nn, 0, (res_size - col_max) * nn * sizeof(double));
|
||||
}
|
||||
@@ -0,0 +1,185 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "zn_arithmetic_private.h"
|
||||
|
||||
void default_init_z_module_precomp(MOD_Z* module) {
|
||||
// Add here initialization of items that are in the precomp
|
||||
}
|
||||
|
||||
void default_finalize_z_module_precomp(MOD_Z* module) {
|
||||
// Add here deleters for items that are in the precomp
|
||||
}
|
||||
|
||||
void default_init_z_module_vtable(MOD_Z* module) {
|
||||
// Add function pointers here
|
||||
module->vtable.i8_approxdecomp_from_tndbl = default_i8_approxdecomp_from_tndbl_ref;
|
||||
module->vtable.i16_approxdecomp_from_tndbl = default_i16_approxdecomp_from_tndbl_ref;
|
||||
module->vtable.i32_approxdecomp_from_tndbl = default_i32_approxdecomp_from_tndbl_ref;
|
||||
module->vtable.zn32_vmp_prepare_contiguous = default_zn32_vmp_prepare_contiguous_ref;
|
||||
module->vtable.zn32_vmp_prepare_dblptr = default_zn32_vmp_prepare_dblptr_ref;
|
||||
module->vtable.zn32_vmp_prepare_row = default_zn32_vmp_prepare_row_ref;
|
||||
module->vtable.zn32_vmp_apply_i8 = default_zn32_vmp_apply_i8_ref;
|
||||
module->vtable.zn32_vmp_apply_i16 = default_zn32_vmp_apply_i16_ref;
|
||||
module->vtable.zn32_vmp_apply_i32 = default_zn32_vmp_apply_i32_ref;
|
||||
module->vtable.dbl_to_tn32 = dbl_to_tn32_ref;
|
||||
module->vtable.tn32_to_dbl = tn32_to_dbl_ref;
|
||||
module->vtable.dbl_round_to_i32 = dbl_round_to_i32_ref;
|
||||
module->vtable.i32_to_dbl = i32_to_dbl_ref;
|
||||
module->vtable.dbl_round_to_i64 = dbl_round_to_i64_ref;
|
||||
module->vtable.i64_to_dbl = i64_to_dbl_ref;
|
||||
|
||||
// Add optimized function pointers here
|
||||
if (CPU_SUPPORTS("avx")) {
|
||||
module->vtable.zn32_vmp_apply_i8 = default_zn32_vmp_apply_i8_avx;
|
||||
module->vtable.zn32_vmp_apply_i16 = default_zn32_vmp_apply_i16_avx;
|
||||
module->vtable.zn32_vmp_apply_i32 = default_zn32_vmp_apply_i32_avx;
|
||||
}
|
||||
}
|
||||
|
||||
void init_z_module_info(MOD_Z* module, //
|
||||
Z_MODULE_TYPE mtype) {
|
||||
memset(module, 0, sizeof(MOD_Z));
|
||||
module->mtype = mtype;
|
||||
switch (mtype) {
|
||||
case DEFAULT:
|
||||
default_init_z_module_precomp(module);
|
||||
default_init_z_module_vtable(module);
|
||||
break;
|
||||
default:
|
||||
NOT_SUPPORTED(); // unknown mtype
|
||||
}
|
||||
}
|
||||
|
||||
void finalize_z_module_info(MOD_Z* module) {
|
||||
if (module->custom) module->custom_deleter(module->custom);
|
||||
switch (module->mtype) {
|
||||
case DEFAULT:
|
||||
default_finalize_z_module_precomp(module);
|
||||
// fft64_finalize_rnx_module_vtable(module); // nothing to finalize
|
||||
break;
|
||||
default:
|
||||
NOT_SUPPORTED(); // unknown mtype
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT MOD_Z* new_z_module_info(Z_MODULE_TYPE mtype) {
|
||||
MOD_Z* res = (MOD_Z*)malloc(sizeof(MOD_Z));
|
||||
init_z_module_info(res, mtype);
|
||||
return res;
|
||||
}
|
||||
|
||||
EXPORT void delete_z_module_info(MOD_Z* module_info) {
|
||||
finalize_z_module_info(module_info);
|
||||
free(module_info);
|
||||
}
|
||||
|
||||
//////////////// wrappers //////////////////
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) (int8_t* output) */
|
||||
EXPORT void i8_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int8_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||
const double* a, uint64_t a_size) { // a
|
||||
module->vtable.i8_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size);
|
||||
}
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) (int16_t* output) */
|
||||
EXPORT void i16_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int16_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||
const double* a, uint64_t a_size) { // a
|
||||
module->vtable.i16_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size);
|
||||
}
|
||||
/** @brief sets res = gadget_decompose(a) (int32_t* output) */
|
||||
EXPORT void i32_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int32_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||
const double* a, uint64_t a_size) { // a
|
||||
module->vtable.i32_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size);
|
||||
}
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void zn32_vmp_prepare_contiguous(const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t* mat, uint64_t nrows, uint64_t ncols) { // a
|
||||
module->vtable.zn32_vmp_prepare_contiguous(module, pmat, mat, nrows, ncols);
|
||||
}
|
||||
|
||||
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||
EXPORT void zn32_vmp_prepare_dblptr(const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t** mat, uint64_t nrows, uint64_t ncols) { // a
|
||||
module->vtable.zn32_vmp_prepare_dblptr(module, pmat, mat, nrows, ncols);
|
||||
}
|
||||
|
||||
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||
EXPORT void zn32_vmp_prepare_row(const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols) { // a
|
||||
module->vtable.zn32_vmp_prepare_row(module, pmat, row, row_i, nrows, ncols);
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product (int32_t* input) */
|
||||
EXPORT void zn32_vmp_apply_i32(const MOD_Z* module, int32_t* res, uint64_t res_size, const int32_t* a, uint64_t a_size,
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||
module->vtable.zn32_vmp_apply_i32(module, res, res_size, a, a_size, pmat, nrows, ncols);
|
||||
}
|
||||
/** @brief applies a vmp product (int16_t* input) */
|
||||
EXPORT void zn32_vmp_apply_i16(const MOD_Z* module, int32_t* res, uint64_t res_size, const int16_t* a, uint64_t a_size,
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||
module->vtable.zn32_vmp_apply_i16(module, res, res_size, a, a_size, pmat, nrows, ncols);
|
||||
}
|
||||
|
||||
/** @brief applies a vmp product (int8_t* input) */
|
||||
EXPORT void zn32_vmp_apply_i8(const MOD_Z* module, int32_t* res, uint64_t res_size, const int8_t* a, uint64_t a_size,
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||
module->vtable.zn32_vmp_apply_i8(module, res, res_size, a, a_size, pmat, nrows, ncols);
|
||||
}
|
||||
|
||||
/** reduction mod 1, output in torus32 space */
|
||||
EXPORT void dbl_to_tn32(const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
) {
|
||||
module->vtable.dbl_to_tn32(module, res, res_size, a, a_size);
|
||||
}
|
||||
|
||||
/** real centerlift mod 1, output in double space */
|
||||
EXPORT void tn32_to_dbl(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size // a
|
||||
) {
|
||||
module->vtable.tn32_to_dbl(module, res, res_size, a, a_size);
|
||||
}
|
||||
|
||||
/** round to the nearest int, output in i32 space */
|
||||
EXPORT void dbl_round_to_i32(const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
) {
|
||||
module->vtable.dbl_round_to_i32(module, res, res_size, a, a_size);
|
||||
}
|
||||
|
||||
/** small int (int32 space) to double */
|
||||
EXPORT void i32_to_dbl(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size // a
|
||||
) {
|
||||
module->vtable.i32_to_dbl(module, res, res_size, a, a_size);
|
||||
}
|
||||
|
||||
/** round to the nearest int, output in int64 space */
|
||||
EXPORT void dbl_round_to_i64(const MOD_Z* module, //
|
||||
int64_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
) {
|
||||
module->vtable.dbl_round_to_i64(module, res, res_size, a, a_size);
|
||||
}
|
||||
|
||||
/** small int (int64 space, <= 2^50) to double */
|
||||
EXPORT void i64_to_dbl(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size // a
|
||||
) {
|
||||
module->vtable.i64_to_dbl(module, res, res_size, a, a_size);
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
#include <memory.h>
|
||||
|
||||
#include "zn_arithmetic_private.h"
|
||||
|
||||
EXPORT TNDBL_APPROXDECOMP_GADGET* new_tndbl_approxdecomp_gadget(const MOD_Z* module, //
|
||||
uint64_t k, uint64_t ell) {
|
||||
if (k * ell > 50) {
|
||||
return spqlios_error("approx decomposition requested is too precise for doubles");
|
||||
}
|
||||
if (k < 1) {
|
||||
return spqlios_error("approx decomposition supports k>=1");
|
||||
}
|
||||
TNDBL_APPROXDECOMP_GADGET* res = malloc(sizeof(TNDBL_APPROXDECOMP_GADGET));
|
||||
memset(res, 0, sizeof(TNDBL_APPROXDECOMP_GADGET));
|
||||
res->k = k;
|
||||
res->ell = ell;
|
||||
double add_cst = INT64_C(3) << (51 - k * ell);
|
||||
for (uint64_t i = 0; i < ell; ++i) {
|
||||
add_cst += pow(2., -(double)(i * k + 1));
|
||||
}
|
||||
res->add_cst = add_cst;
|
||||
res->and_mask = (UINT64_C(1) << k) - 1;
|
||||
res->sub_cst = UINT64_C(1) << (k - 1);
|
||||
for (uint64_t i = 0; i < ell; ++i) res->rshifts[i] = (ell - 1 - i) * k;
|
||||
return res;
|
||||
}
|
||||
EXPORT void delete_tndbl_approxdecomp_gadget(TNDBL_APPROXDECOMP_GADGET* ptr) { free(ptr); }
|
||||
|
||||
EXPORT int default_init_tndbl_approxdecomp_gadget(const MOD_Z* module, //
|
||||
TNDBL_APPROXDECOMP_GADGET* res, //
|
||||
uint64_t k, uint64_t ell) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
typedef union {
|
||||
double dv;
|
||||
uint64_t uv;
|
||||
} du_t;
|
||||
|
||||
#define IMPL_ixx_approxdecomp_from_tndbl_ref(ITYPE) \
|
||||
if (res_size != a_size * gadget->ell) NOT_IMPLEMENTED(); \
|
||||
const uint64_t ell = gadget->ell; \
|
||||
const double add_cst = gadget->add_cst; \
|
||||
const uint8_t* const rshifts = gadget->rshifts; \
|
||||
const ITYPE and_mask = gadget->and_mask; \
|
||||
const ITYPE sub_cst = gadget->sub_cst; \
|
||||
ITYPE* rr = res; \
|
||||
const double* aa = a; \
|
||||
const double* aaend = a + a_size; \
|
||||
while (aa < aaend) { \
|
||||
du_t t = {.dv = *aa + add_cst}; \
|
||||
for (uint64_t i = 0; i < ell; ++i) { \
|
||||
ITYPE v = (ITYPE)(t.uv >> rshifts[i]); \
|
||||
*rr = (v & and_mask) - sub_cst; \
|
||||
++rr; \
|
||||
} \
|
||||
++aa; \
|
||||
}
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) (int8_t* output) */
|
||||
EXPORT void default_i8_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int8_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||
const double* a, uint64_t a_size //
|
||||
){IMPL_ixx_approxdecomp_from_tndbl_ref(int8_t)}
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) (int16_t* output) */
|
||||
EXPORT void default_i16_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int16_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
){IMPL_ixx_approxdecomp_from_tndbl_ref(int16_t)}
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) (int32_t* output) */
|
||||
EXPORT void default_i32_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
) {
|
||||
IMPL_ixx_approxdecomp_from_tndbl_ref(int32_t)
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
#ifndef SPQLIOS_ZN_ARITHMETIC_H
|
||||
#define SPQLIOS_ZN_ARITHMETIC_H
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "../commons.h"
|
||||
|
||||
typedef enum z_module_type_t { DEFAULT } Z_MODULE_TYPE;
|
||||
|
||||
/** @brief opaque structure that describes the module and the hardware */
|
||||
typedef struct z_module_info_t MOD_Z;
|
||||
|
||||
/**
|
||||
* @brief obtain a module info for ring dimension N
|
||||
* the module-info knows about:
|
||||
* - the dimension N (or the complex dimension m=N/2)
|
||||
* - any moduleuted fft or ntt items
|
||||
* - the hardware (avx, arm64, x86, ...)
|
||||
*/
|
||||
EXPORT MOD_Z* new_z_module_info(Z_MODULE_TYPE mode);
|
||||
EXPORT void delete_z_module_info(MOD_Z* module_info);
|
||||
|
||||
typedef struct tndbl_approxdecomp_gadget_t TNDBL_APPROXDECOMP_GADGET;
|
||||
|
||||
EXPORT TNDBL_APPROXDECOMP_GADGET* new_tndbl_approxdecomp_gadget(const MOD_Z* module, //
|
||||
uint64_t k,
|
||||
uint64_t ell); // base 2^k, and size
|
||||
|
||||
EXPORT void delete_tndbl_approxdecomp_gadget(TNDBL_APPROXDECOMP_GADGET* ptr);
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) (int8_t* output) */
|
||||
EXPORT void i8_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int8_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||
const double* a, uint64_t a_size); // a
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) (int16_t* output) */
|
||||
EXPORT void i16_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int16_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||
const double* a, uint64_t a_size); // a
|
||||
/** @brief sets res = gadget_decompose(a) (int32_t* output) */
|
||||
EXPORT void i32_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int32_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||
const double* a, uint64_t a_size); // a
|
||||
|
||||
/** @brief opaque type that represents a prepared matrix */
|
||||
typedef struct zn32_vmp_pmat_t ZN32_VMP_PMAT;
|
||||
|
||||
/** @brief size in bytes of a prepared matrix (for custom allocation) */
|
||||
EXPORT uint64_t bytes_of_zn32_vmp_pmat(const MOD_Z* module, // N
|
||||
uint64_t nrows, uint64_t ncols); // dimensions
|
||||
|
||||
/** @brief allocates a prepared matrix (release with delete_zn32_vmp_pmat) */
|
||||
EXPORT ZN32_VMP_PMAT* new_zn32_vmp_pmat(const MOD_Z* module, // N
|
||||
uint64_t nrows, uint64_t ncols); // dimensions
|
||||
|
||||
/** @brief deletes a prepared matrix (release with free) */
|
||||
EXPORT void delete_zn32_vmp_pmat(ZN32_VMP_PMAT* ptr); // dimensions
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void zn32_vmp_prepare_contiguous( //
|
||||
const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t* mat, uint64_t nrows, uint64_t ncols); // a
|
||||
|
||||
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||
EXPORT void zn32_vmp_prepare_dblptr( //
|
||||
const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t** mat, uint64_t nrows, uint64_t ncols); // a
|
||||
|
||||
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||
EXPORT void zn32_vmp_prepare_row( //
|
||||
const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols); // a
|
||||
|
||||
/** @brief applies a vmp product (int32_t* input) */
|
||||
EXPORT void zn32_vmp_apply_i32( //
|
||||
const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size, // a
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||
|
||||
/** @brief applies a vmp product (int16_t* input) */
|
||||
EXPORT void zn32_vmp_apply_i16( //
|
||||
const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const int16_t* a, uint64_t a_size, // a
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||
|
||||
/** @brief applies a vmp product (int8_t* input) */
|
||||
EXPORT void zn32_vmp_apply_i8( //
|
||||
const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const int8_t* a, uint64_t a_size, // a
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||
|
||||
// explicit conversions
|
||||
|
||||
/** reduction mod 1, output in torus32 space */
|
||||
EXPORT void dbl_to_tn32(const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** real centerlift mod 1, output in double space */
|
||||
EXPORT void tn32_to_dbl(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** round to the nearest int, output in i32 space.
|
||||
* WARNING: ||a||_inf must be <= 2^18 in this function
|
||||
*/
|
||||
EXPORT void dbl_round_to_i32(const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** small int (int32 space) to double
|
||||
* WARNING: ||a||_inf must be <= 2^18 in this function
|
||||
*/
|
||||
EXPORT void i32_to_dbl(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** round to the nearest int, output in int64 space
|
||||
* WARNING: ||a||_inf must be <= 2^50 in this function
|
||||
*/
|
||||
EXPORT void dbl_round_to_i64(const MOD_Z* module, //
|
||||
int64_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** small int (int64 space, <= 2^50) to double
|
||||
* WARNING: ||a||_inf must be <= 2^50 in this function
|
||||
*/
|
||||
EXPORT void i64_to_dbl(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
#endif // SPQLIOS_ZN_ARITHMETIC_H
|
||||
@@ -0,0 +1,43 @@
|
||||
#ifndef SPQLIOS_ZN_ARITHMETIC_PLUGIN_H
|
||||
#define SPQLIOS_ZN_ARITHMETIC_PLUGIN_H
|
||||
|
||||
#include "zn_arithmetic.h"
|
||||
|
||||
typedef typeof(i8_approxdecomp_from_tndbl) I8_APPROXDECOMP_FROM_TNDBL_F;
|
||||
typedef typeof(i16_approxdecomp_from_tndbl) I16_APPROXDECOMP_FROM_TNDBL_F;
|
||||
typedef typeof(i32_approxdecomp_from_tndbl) I32_APPROXDECOMP_FROM_TNDBL_F;
|
||||
typedef typeof(bytes_of_zn32_vmp_pmat) BYTES_OF_ZN32_VMP_PMAT_F;
|
||||
typedef typeof(zn32_vmp_prepare_contiguous) ZN32_VMP_PREPARE_CONTIGUOUS_F;
|
||||
typedef typeof(zn32_vmp_prepare_dblptr) ZN32_VMP_PREPARE_DBLPTR_F;
|
||||
typedef typeof(zn32_vmp_prepare_row) ZN32_VMP_PREPARE_ROW_F;
|
||||
typedef typeof(zn32_vmp_apply_i32) ZN32_VMP_APPLY_I32_F;
|
||||
typedef typeof(zn32_vmp_apply_i16) ZN32_VMP_APPLY_I16_F;
|
||||
typedef typeof(zn32_vmp_apply_i8) ZN32_VMP_APPLY_I8_F;
|
||||
typedef typeof(dbl_to_tn32) DBL_TO_TN32_F;
|
||||
typedef typeof(tn32_to_dbl) TN32_TO_DBL_F;
|
||||
typedef typeof(dbl_round_to_i32) DBL_ROUND_TO_I32_F;
|
||||
typedef typeof(i32_to_dbl) I32_TO_DBL_F;
|
||||
typedef typeof(dbl_round_to_i64) DBL_ROUND_TO_I64_F;
|
||||
typedef typeof(i64_to_dbl) I64_TO_DBL_F;
|
||||
|
||||
typedef struct z_module_vtable_t Z_MODULE_VTABLE;
|
||||
struct z_module_vtable_t {
|
||||
I8_APPROXDECOMP_FROM_TNDBL_F* i8_approxdecomp_from_tndbl;
|
||||
I16_APPROXDECOMP_FROM_TNDBL_F* i16_approxdecomp_from_tndbl;
|
||||
I32_APPROXDECOMP_FROM_TNDBL_F* i32_approxdecomp_from_tndbl;
|
||||
BYTES_OF_ZN32_VMP_PMAT_F* bytes_of_zn32_vmp_pmat;
|
||||
ZN32_VMP_PREPARE_CONTIGUOUS_F* zn32_vmp_prepare_contiguous;
|
||||
ZN32_VMP_PREPARE_DBLPTR_F* zn32_vmp_prepare_dblptr;
|
||||
ZN32_VMP_PREPARE_ROW_F* zn32_vmp_prepare_row;
|
||||
ZN32_VMP_APPLY_I32_F* zn32_vmp_apply_i32;
|
||||
ZN32_VMP_APPLY_I16_F* zn32_vmp_apply_i16;
|
||||
ZN32_VMP_APPLY_I8_F* zn32_vmp_apply_i8;
|
||||
DBL_TO_TN32_F* dbl_to_tn32;
|
||||
TN32_TO_DBL_F* tn32_to_dbl;
|
||||
DBL_ROUND_TO_I32_F* dbl_round_to_i32;
|
||||
I32_TO_DBL_F* i32_to_dbl;
|
||||
DBL_ROUND_TO_I64_F* dbl_round_to_i64;
|
||||
I64_TO_DBL_F* i64_to_dbl;
|
||||
};
|
||||
|
||||
#endif // SPQLIOS_ZN_ARITHMETIC_PLUGIN_H
|
||||
@@ -0,0 +1,164 @@
|
||||
#ifndef SPQLIOS_ZN_ARITHMETIC_PRIVATE_H
|
||||
#define SPQLIOS_ZN_ARITHMETIC_PRIVATE_H
|
||||
|
||||
#include "../commons_private.h"
|
||||
#include "zn_arithmetic.h"
|
||||
#include "zn_arithmetic_plugin.h"
|
||||
|
||||
typedef struct main_z_module_precomp_t MAIN_Z_MODULE_PRECOMP;
|
||||
struct main_z_module_precomp_t {
|
||||
// TODO
|
||||
};
|
||||
|
||||
typedef union z_module_precomp_t Z_MODULE_PRECOMP;
|
||||
union z_module_precomp_t {
|
||||
MAIN_Z_MODULE_PRECOMP main;
|
||||
};
|
||||
|
||||
void main_init_z_module_precomp(MOD_Z* module);
|
||||
|
||||
void main_finalize_z_module_precomp(MOD_Z* module);
|
||||
|
||||
/** @brief opaque structure that describes the modules (RnX,ZnX,TnX) and the hardware */
|
||||
struct z_module_info_t {
|
||||
Z_MODULE_TYPE mtype;
|
||||
Z_MODULE_VTABLE vtable;
|
||||
Z_MODULE_PRECOMP precomp;
|
||||
void* custom;
|
||||
void (*custom_deleter)(void*);
|
||||
};
|
||||
|
||||
void init_z_module_info(MOD_Z* module, Z_MODULE_TYPE mtype);
|
||||
|
||||
void main_init_z_module_vtable(MOD_Z* module);
|
||||
|
||||
struct tndbl_approxdecomp_gadget_t {
|
||||
uint64_t k;
|
||||
uint64_t ell;
|
||||
double add_cst; // 3.2^51-(K.ell) + 1/2.(sum 2^-(i+1)K)
|
||||
int64_t and_mask; // (2^K)-1
|
||||
int64_t sub_cst; // 2^(K-1)
|
||||
uint8_t rshifts[64]; // 2^(ell-1-i).K for i in [0:ell-1]
|
||||
};
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) (int8_t* output) */
|
||||
EXPORT void default_i8_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int8_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||
const double* a, uint64_t a_size); // a
|
||||
|
||||
/** @brief sets res = gadget_decompose(a) (int16_t* output) */
|
||||
EXPORT void default_i16_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int16_t* res,
|
||||
uint64_t res_size, // res (in general, size ell.a_size)
|
||||
const double* a, uint64_t a_size); // a
|
||||
/** @brief sets res = gadget_decompose(a) (int32_t* output) */
|
||||
EXPORT void default_i32_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||
int32_t* res,
|
||||
uint64_t res_size, // res (in general, size ell.a_size)
|
||||
const double* a, uint64_t a_size); // a
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void default_zn32_vmp_prepare_contiguous_ref( //
|
||||
const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t* mat, uint64_t nrows, uint64_t ncols // a
|
||||
);
|
||||
|
||||
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||
EXPORT void default_zn32_vmp_prepare_dblptr_ref( //
|
||||
const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t** mat, uint64_t nrows, uint64_t ncols // a
|
||||
);
|
||||
|
||||
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||
EXPORT void default_zn32_vmp_prepare_row_ref( //
|
||||
const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols // a
|
||||
);
|
||||
|
||||
/** @brief applies a vmp product (int32_t* input) */
|
||||
EXPORT void default_zn32_vmp_apply_i32_ref( //
|
||||
const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size, // a
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||
|
||||
/** @brief applies a vmp product (int16_t* input) */
|
||||
EXPORT void default_zn32_vmp_apply_i16_ref( //
|
||||
const MOD_Z* module, // N
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const int16_t* a, uint64_t a_size, // a
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||
|
||||
/** @brief applies a vmp product (int8_t* input) */
|
||||
EXPORT void default_zn32_vmp_apply_i8_ref( //
|
||||
const MOD_Z* module, // N
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const int8_t* a, uint64_t a_size, // a
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||
|
||||
/** @brief applies a vmp product (int32_t* input) */
|
||||
EXPORT void default_zn32_vmp_apply_i32_avx( //
|
||||
const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size, // a
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||
|
||||
/** @brief applies a vmp product (int16_t* input) */
|
||||
EXPORT void default_zn32_vmp_apply_i16_avx( //
|
||||
const MOD_Z* module, // N
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const int16_t* a, uint64_t a_size, // a
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||
|
||||
/** @brief applies a vmp product (int8_t* input) */
|
||||
EXPORT void default_zn32_vmp_apply_i8_avx( //
|
||||
const MOD_Z* module, // N
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const int8_t* a, uint64_t a_size, // a
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||
|
||||
// explicit conversions
|
||||
|
||||
/** reduction mod 1, output in torus32 space */
|
||||
EXPORT void dbl_to_tn32_ref(const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** real centerlift mod 1, output in double space */
|
||||
EXPORT void tn32_to_dbl_ref(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** round to the nearest int, output in i32 space */
|
||||
EXPORT void dbl_round_to_i32_ref(const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** small int (int32 space) to double */
|
||||
EXPORT void i32_to_dbl_ref(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** round to the nearest int, output in int64 space */
|
||||
EXPORT void dbl_round_to_i64_ref(const MOD_Z* module, //
|
||||
int64_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
/** small int (int64 space) to double */
|
||||
EXPORT void i64_to_dbl_ref(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size // a
|
||||
);
|
||||
|
||||
#endif // SPQLIOS_ZN_ARITHMETIC_PRIVATE_H
|
||||
@@ -0,0 +1,108 @@
|
||||
#include <memory.h>
|
||||
|
||||
#include "zn_arithmetic_private.h"
|
||||
|
||||
typedef union {
|
||||
double dv;
|
||||
int64_t s64v;
|
||||
int32_t s32v;
|
||||
uint64_t u64v;
|
||||
uint32_t u32v;
|
||||
} di_t;
|
||||
|
||||
/** reduction mod 1, output in torus32 space */
|
||||
EXPORT void dbl_to_tn32_ref(const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
) {
|
||||
static const double ADD_CST = 0.5 + (double)(INT64_C(3) << (51 - 32));
|
||||
static const int32_t XOR_CST = (INT32_C(1) << 31);
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
di_t t = {.dv = a[i] + ADD_CST};
|
||||
res[i] = t.s32v ^ XOR_CST;
|
||||
}
|
||||
memset(res + msize, 0, (res_size - msize) * sizeof(int32_t));
|
||||
}
|
||||
|
||||
/** real centerlift mod 1, output in double space */
|
||||
EXPORT void tn32_to_dbl_ref(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size // a
|
||||
) {
|
||||
static const uint32_t XOR_CST = (UINT32_C(1) << 31);
|
||||
static const di_t OR_CST = {.dv = (double)(INT64_C(1) << (52 - 32))};
|
||||
static const double SUB_CST = 0.5 + (double)(INT64_C(1) << (52 - 32));
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
uint32_t ai = a[i] ^ XOR_CST;
|
||||
di_t t = {.u64v = OR_CST.u64v | (uint64_t)ai};
|
||||
res[i] = t.dv - SUB_CST;
|
||||
}
|
||||
memset(res + msize, 0, (res_size - msize) * sizeof(double));
|
||||
}
|
||||
|
||||
/** round to the nearest int, output in i32 space */
|
||||
EXPORT void dbl_round_to_i32_ref(const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
) {
|
||||
static const double ADD_CST = (double)((INT64_C(3) << (51)) + (INT64_C(1) << (31)));
|
||||
static const int32_t XOR_CST = INT32_C(1) << 31;
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
di_t t = {.dv = a[i] + ADD_CST};
|
||||
res[i] = t.s32v ^ XOR_CST;
|
||||
}
|
||||
memset(res + msize, 0, (res_size - msize) * sizeof(int32_t));
|
||||
}
|
||||
|
||||
/** small int (int32 space) to double */
|
||||
EXPORT void i32_to_dbl_ref(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int32_t* a, uint64_t a_size // a
|
||||
) {
|
||||
static const uint32_t XOR_CST = (UINT32_C(1) << 31);
|
||||
static const di_t OR_CST = {.dv = (double)(INT64_C(1) << 52)};
|
||||
static const double SUB_CST = (double)((INT64_C(1) << 52) + (INT64_C(1) << 31));
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
uint32_t ai = a[i] ^ XOR_CST;
|
||||
di_t t = {.u64v = OR_CST.u64v | (uint64_t)ai};
|
||||
res[i] = t.dv - SUB_CST;
|
||||
}
|
||||
memset(res + msize, 0, (res_size - msize) * sizeof(double));
|
||||
}
|
||||
|
||||
/** round to the nearest int, output in int64 space */
|
||||
EXPORT void dbl_round_to_i64_ref(const MOD_Z* module, //
|
||||
int64_t* res, uint64_t res_size, // res
|
||||
const double* a, uint64_t a_size // a
|
||||
) {
|
||||
static const double ADD_CST = (double)(INT64_C(3) << (51));
|
||||
static const int64_t AND_CST = (INT64_C(1) << 52) - 1;
|
||||
static const int64_t SUB_CST = INT64_C(1) << 51;
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
di_t t = {.dv = a[i] + ADD_CST};
|
||||
res[i] = (t.s64v & AND_CST) - SUB_CST;
|
||||
}
|
||||
memset(res + msize, 0, (res_size - msize) * sizeof(int64_t));
|
||||
}
|
||||
|
||||
/** small int (int64 space) to double */
|
||||
EXPORT void i64_to_dbl_ref(const MOD_Z* module, //
|
||||
double* res, uint64_t res_size, // res
|
||||
const int64_t* a, uint64_t a_size // a
|
||||
) {
|
||||
static const uint64_t ADD_CST = UINT64_C(1) << 51;
|
||||
static const uint64_t AND_CST = (UINT64_C(1) << 52) - 1;
|
||||
static const di_t OR_CST = {.dv = (INT64_C(1) << 52)};
|
||||
static const double SUB_CST = INT64_C(3) << 51;
|
||||
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||
for (uint64_t i = 0; i < msize; ++i) {
|
||||
di_t t = {.u64v = ((a[i] + ADD_CST) & AND_CST) | OR_CST.u64v};
|
||||
res[i] = t.dv - SUB_CST;
|
||||
}
|
||||
memset(res + msize, 0, (res_size - msize) * sizeof(double));
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
#define INTTYPE int16_t
|
||||
#define INTSN i16
|
||||
|
||||
#include "zn_vmp_int32_avx.c"
|
||||
@@ -0,0 +1,4 @@
|
||||
#define INTTYPE int16_t
|
||||
#define INTSN i16
|
||||
|
||||
#include "zn_vmp_int32_ref.c"
|
||||
@@ -0,0 +1,223 @@
|
||||
// This file is actually a template: it will be compiled multiple times with
|
||||
// different INTTYPES
|
||||
#ifndef INTTYPE
|
||||
#define INTTYPE int32_t
|
||||
#define INTSN i32
|
||||
#endif
|
||||
|
||||
#include <immintrin.h>
|
||||
#include <memory.h>
|
||||
|
||||
#include "zn_arithmetic_private.h"
|
||||
|
||||
#define concat_inner(aa, bb, cc) aa##_##bb##_##cc
|
||||
#define concat(aa, bb, cc) concat_inner(aa, bb, cc)
|
||||
#define zn32_vec_fn(cc) concat(zn32_vec, INTSN, cc)
|
||||
|
||||
static void zn32_vec_mat32cols_avx_prefetch(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b) {
|
||||
if (nrows == 0) {
|
||||
memset(res, 0, 32 * sizeof(int32_t));
|
||||
return;
|
||||
}
|
||||
const int32_t* bb = b;
|
||||
const int32_t* pref_bb = b;
|
||||
const uint64_t pref_iters = 128;
|
||||
const uint64_t pref_start = pref_iters < nrows ? pref_iters : nrows;
|
||||
const uint64_t pref_last = pref_iters > nrows ? 0 : nrows - pref_iters;
|
||||
// let's do some prefetching of the GSW key, since on some cpus,
|
||||
// it helps
|
||||
for (uint64_t i = 0; i < pref_start; ++i) {
|
||||
__builtin_prefetch(pref_bb, 0, _MM_HINT_T0);
|
||||
__builtin_prefetch(pref_bb + 16, 0, _MM_HINT_T0);
|
||||
pref_bb += 32;
|
||||
}
|
||||
// we do the first iteration
|
||||
__m256i x = _mm256_set1_epi32(a[0]);
|
||||
__m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)));
|
||||
__m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)));
|
||||
__m256i r2 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)));
|
||||
__m256i r3 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24)));
|
||||
bb += 32;
|
||||
uint64_t row = 1;
|
||||
for (; //
|
||||
row < pref_last; //
|
||||
++row, bb += 32) {
|
||||
// prefetch the next iteration
|
||||
__builtin_prefetch(pref_bb, 0, _MM_HINT_T0);
|
||||
__builtin_prefetch(pref_bb + 16, 0, _MM_HINT_T0);
|
||||
pref_bb += 32;
|
||||
INTTYPE ai = a[row];
|
||||
if (ai == 0) continue;
|
||||
x = _mm256_set1_epi32(ai);
|
||||
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||
r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))));
|
||||
r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))));
|
||||
r3 = _mm256_add_epi32(r3, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24))));
|
||||
}
|
||||
for (; //
|
||||
row < nrows; //
|
||||
++row, bb += 32) {
|
||||
INTTYPE ai = a[row];
|
||||
if (ai == 0) continue;
|
||||
x = _mm256_set1_epi32(ai);
|
||||
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||
r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))));
|
||||
r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))));
|
||||
r3 = _mm256_add_epi32(r3, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24))));
|
||||
}
|
||||
_mm256_storeu_si256((__m256i*)(res), r0);
|
||||
_mm256_storeu_si256((__m256i*)(res + 8), r1);
|
||||
_mm256_storeu_si256((__m256i*)(res + 16), r2);
|
||||
_mm256_storeu_si256((__m256i*)(res + 24), r3);
|
||||
}
|
||||
|
||||
void zn32_vec_fn(mat32cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||
if (nrows == 0) {
|
||||
memset(res, 0, 32 * sizeof(int32_t));
|
||||
return;
|
||||
}
|
||||
const INTTYPE* aa = a;
|
||||
const INTTYPE* const aaend = a + nrows;
|
||||
const int32_t* bb = b;
|
||||
__m256i x = _mm256_set1_epi32(*aa);
|
||||
__m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)));
|
||||
__m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)));
|
||||
__m256i r2 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)));
|
||||
__m256i r3 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24)));
|
||||
bb += b_sl;
|
||||
++aa;
|
||||
for (; //
|
||||
aa < aaend; //
|
||||
bb += b_sl, ++aa) {
|
||||
INTTYPE ai = *aa;
|
||||
if (ai == 0) continue;
|
||||
x = _mm256_set1_epi32(ai);
|
||||
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||
r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))));
|
||||
r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))));
|
||||
r3 = _mm256_add_epi32(r3, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24))));
|
||||
}
|
||||
_mm256_storeu_si256((__m256i*)(res), r0);
|
||||
_mm256_storeu_si256((__m256i*)(res + 8), r1);
|
||||
_mm256_storeu_si256((__m256i*)(res + 16), r2);
|
||||
_mm256_storeu_si256((__m256i*)(res + 24), r3);
|
||||
}
|
||||
|
||||
void zn32_vec_fn(mat24cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||
if (nrows == 0) {
|
||||
memset(res, 0, 24 * sizeof(int32_t));
|
||||
return;
|
||||
}
|
||||
const INTTYPE* aa = a;
|
||||
const INTTYPE* const aaend = a + nrows;
|
||||
const int32_t* bb = b;
|
||||
__m256i x = _mm256_set1_epi32(*aa);
|
||||
__m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)));
|
||||
__m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)));
|
||||
__m256i r2 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)));
|
||||
bb += b_sl;
|
||||
++aa;
|
||||
for (; //
|
||||
aa < aaend; //
|
||||
bb += b_sl, ++aa) {
|
||||
INTTYPE ai = *aa;
|
||||
if (ai == 0) continue;
|
||||
x = _mm256_set1_epi32(ai);
|
||||
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||
r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))));
|
||||
r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))));
|
||||
}
|
||||
_mm256_storeu_si256((__m256i*)(res), r0);
|
||||
_mm256_storeu_si256((__m256i*)(res + 8), r1);
|
||||
_mm256_storeu_si256((__m256i*)(res + 16), r2);
|
||||
}
|
||||
void zn32_vec_fn(mat16cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||
if (nrows == 0) {
|
||||
memset(res, 0, 16 * sizeof(int32_t));
|
||||
return;
|
||||
}
|
||||
const INTTYPE* aa = a;
|
||||
const INTTYPE* const aaend = a + nrows;
|
||||
const int32_t* bb = b;
|
||||
__m256i x = _mm256_set1_epi32(*aa);
|
||||
__m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)));
|
||||
__m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)));
|
||||
bb += b_sl;
|
||||
++aa;
|
||||
for (; //
|
||||
aa < aaend; //
|
||||
bb += b_sl, ++aa) {
|
||||
INTTYPE ai = *aa;
|
||||
if (ai == 0) continue;
|
||||
x = _mm256_set1_epi32(ai);
|
||||
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||
r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))));
|
||||
}
|
||||
_mm256_storeu_si256((__m256i*)(res), r0);
|
||||
_mm256_storeu_si256((__m256i*)(res + 8), r1);
|
||||
}
|
||||
|
||||
void zn32_vec_fn(mat8cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||
if (nrows == 0) {
|
||||
memset(res, 0, 8 * sizeof(int32_t));
|
||||
return;
|
||||
}
|
||||
const INTTYPE* aa = a;
|
||||
const INTTYPE* const aaend = a + nrows;
|
||||
const int32_t* bb = b;
|
||||
__m256i x = _mm256_set1_epi32(*aa);
|
||||
__m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)));
|
||||
bb += b_sl;
|
||||
++aa;
|
||||
for (; //
|
||||
aa < aaend; //
|
||||
bb += b_sl, ++aa) {
|
||||
INTTYPE ai = *aa;
|
||||
if (ai == 0) continue;
|
||||
x = _mm256_set1_epi32(ai);
|
||||
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||
}
|
||||
_mm256_storeu_si256((__m256i*)(res), r0);
|
||||
}
|
||||
|
||||
typedef void (*vm_f)(uint64_t nrows, //
|
||||
int32_t* res, //
|
||||
const INTTYPE* a, //
|
||||
const int32_t* b, uint64_t b_sl //
|
||||
);
|
||||
static const vm_f zn32_vec_mat8kcols_avx[4] = { //
|
||||
zn32_vec_fn(mat8cols_avx), //
|
||||
zn32_vec_fn(mat16cols_avx), //
|
||||
zn32_vec_fn(mat24cols_avx), //
|
||||
zn32_vec_fn(mat32cols_avx)};
|
||||
|
||||
/** @brief applies a vmp product (int32_t* input) */
|
||||
EXPORT void concat(default_zn32_vmp_apply, INTSN, avx)( //
|
||||
const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, //
|
||||
const INTTYPE* a, uint64_t a_size, //
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||
const uint64_t rows = a_size < nrows ? a_size : nrows;
|
||||
const uint64_t cols = res_size < ncols ? res_size : ncols;
|
||||
const uint64_t ncolblk = cols >> 5;
|
||||
const uint64_t ncolrem = cols & 31;
|
||||
// copy the first full blocks
|
||||
const uint64_t full_blk_size = nrows * 32;
|
||||
const int32_t* mat = (int32_t*)pmat;
|
||||
int32_t* rr = res;
|
||||
for (uint64_t blk = 0; //
|
||||
blk < ncolblk; //
|
||||
++blk, mat += full_blk_size, rr += 32) {
|
||||
zn32_vec_mat32cols_avx_prefetch(rows, rr, a, mat);
|
||||
}
|
||||
// last block
|
||||
if (ncolrem) {
|
||||
uint64_t orig_rem = ncols - (ncolblk << 5);
|
||||
uint64_t b_sl = orig_rem >= 32 ? 32 : orig_rem;
|
||||
int32_t tmp[32];
|
||||
zn32_vec_mat8kcols_avx[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl);
|
||||
memcpy(rr, tmp, ncolrem * sizeof(int32_t));
|
||||
}
|
||||
// trailing bytes
|
||||
memset(res + cols, 0, (res_size - cols) * sizeof(int32_t));
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
// This file is actually a template: it will be compiled multiple times with
|
||||
// different INTTYPES
|
||||
#ifndef INTTYPE
|
||||
#define INTTYPE int32_t
|
||||
#define INTSN i32
|
||||
#endif
|
||||
|
||||
#include <memory.h>
|
||||
|
||||
#include "zn_arithmetic_private.h"
|
||||
|
||||
#define concat_inner(aa, bb, cc) aa##_##bb##_##cc
|
||||
#define concat(aa, bb, cc) concat_inner(aa, bb, cc)
|
||||
#define zn32_vec_fn(cc) concat(zn32_vec, INTSN, cc)
|
||||
|
||||
// the ref version shares the same implementation for each fixed column size
|
||||
// optimized implementations may do something different.
|
||||
static __always_inline void IMPL_zn32_vec_matcols_ref(
|
||||
const uint64_t NCOLS, // fixed number of columns
|
||||
uint64_t nrows, // nrows of b
|
||||
int32_t* res, // result: size NCOLS, only the first min(b_sl, NCOLS) are relevant
|
||||
const INTTYPE* a, // a: nrows-sized vector
|
||||
const int32_t* b, uint64_t b_sl // b: nrows * min(b_sl, NCOLS) matrix
|
||||
) {
|
||||
memset(res, 0, NCOLS * sizeof(int32_t));
|
||||
for (uint64_t row = 0; row < nrows; ++row) {
|
||||
int32_t ai = a[row];
|
||||
const int32_t* bb = b + row * b_sl;
|
||||
for (uint64_t i = 0; i < NCOLS; ++i) {
|
||||
res[i] += ai * bb[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void zn32_vec_fn(mat32cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||
IMPL_zn32_vec_matcols_ref(32, nrows, res, a, b, b_sl);
|
||||
}
|
||||
void zn32_vec_fn(mat24cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||
IMPL_zn32_vec_matcols_ref(24, nrows, res, a, b, b_sl);
|
||||
}
|
||||
void zn32_vec_fn(mat16cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||
IMPL_zn32_vec_matcols_ref(16, nrows, res, a, b, b_sl);
|
||||
}
|
||||
void zn32_vec_fn(mat8cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||
IMPL_zn32_vec_matcols_ref(8, nrows, res, a, b, b_sl);
|
||||
}
|
||||
|
||||
typedef void (*vm_f)(uint64_t nrows, //
|
||||
int32_t* res, //
|
||||
const INTTYPE* a, //
|
||||
const int32_t* b, uint64_t b_sl //
|
||||
);
|
||||
static const vm_f zn32_vec_mat8kcols_ref[4] = { //
|
||||
zn32_vec_fn(mat8cols_ref), //
|
||||
zn32_vec_fn(mat16cols_ref), //
|
||||
zn32_vec_fn(mat24cols_ref), //
|
||||
zn32_vec_fn(mat32cols_ref)};
|
||||
|
||||
/** @brief applies a vmp product (int32_t* input) */
|
||||
EXPORT void concat(default_zn32_vmp_apply, INTSN, ref)( //
|
||||
const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, //
|
||||
const INTTYPE* a, uint64_t a_size, //
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||
const uint64_t rows = a_size < nrows ? a_size : nrows;
|
||||
const uint64_t cols = res_size < ncols ? res_size : ncols;
|
||||
const uint64_t ncolblk = cols >> 5;
|
||||
const uint64_t ncolrem = cols & 31;
|
||||
// copy the first full blocks
|
||||
const uint32_t full_blk_size = nrows * 32;
|
||||
const int32_t* mat = (int32_t*)pmat;
|
||||
int32_t* rr = res;
|
||||
for (uint64_t blk = 0; //
|
||||
blk < ncolblk; //
|
||||
++blk, mat += full_blk_size, rr += 32) {
|
||||
zn32_vec_fn(mat32cols_ref)(rows, rr, a, mat, 32);
|
||||
}
|
||||
// last block
|
||||
if (ncolrem) {
|
||||
uint64_t orig_rem = ncols - (ncolblk << 5);
|
||||
uint64_t b_sl = orig_rem >= 32 ? 32 : orig_rem;
|
||||
int32_t tmp[32];
|
||||
zn32_vec_mat8kcols_ref[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl);
|
||||
memcpy(rr, tmp, ncolrem * sizeof(int32_t));
|
||||
}
|
||||
// trailing bytes
|
||||
memset(res + cols, 0, (res_size - cols) * sizeof(int32_t));
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
#define INTTYPE int8_t
|
||||
#define INTSN i8
|
||||
|
||||
#include "zn_vmp_int32_avx.c"
|
||||
@@ -0,0 +1,4 @@
|
||||
#define INTTYPE int8_t
|
||||
#define INTSN i8
|
||||
|
||||
#include "zn_vmp_int32_ref.c"
|
||||
@@ -0,0 +1,185 @@
|
||||
#include <memory.h>
|
||||
|
||||
#include "zn_arithmetic_private.h"
|
||||
|
||||
/** @brief size in bytes of a prepared matrix (for custom allocation) */
|
||||
EXPORT uint64_t bytes_of_zn32_vmp_pmat(const MOD_Z* module, // N
|
||||
uint64_t nrows, uint64_t ncols // dimensions
|
||||
) {
|
||||
return (nrows * ncols + 7) * sizeof(int32_t);
|
||||
}
|
||||
|
||||
/** @brief allocates a prepared matrix (release with delete_zn32_vmp_pmat) */
|
||||
EXPORT ZN32_VMP_PMAT* new_zn32_vmp_pmat(const MOD_Z* module, // N
|
||||
uint64_t nrows, uint64_t ncols) {
|
||||
return (ZN32_VMP_PMAT*)spqlios_alloc(bytes_of_zn32_vmp_pmat(module, nrows, ncols));
|
||||
}
|
||||
|
||||
/** @brief deletes a prepared matrix (release with free) */
|
||||
EXPORT void delete_zn32_vmp_pmat(ZN32_VMP_PMAT* ptr) { spqlios_free(ptr); }
|
||||
|
||||
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||
EXPORT void default_zn32_vmp_prepare_contiguous_ref( //
|
||||
const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t* mat, uint64_t nrows, uint64_t ncols // a
|
||||
) {
|
||||
int32_t* const out = (int32_t*)pmat;
|
||||
const uint64_t nblk = ncols >> 5;
|
||||
const uint64_t ncols_rem = ncols & 31;
|
||||
const uint64_t final_elems = (8 - nrows * ncols) & 7;
|
||||
for (uint64_t blk = 0; blk < nblk; ++blk) {
|
||||
int32_t* outblk = out + blk * nrows * 32;
|
||||
const int32_t* srcblk = mat + blk * 32;
|
||||
for (uint64_t row = 0; row < nrows; ++row) {
|
||||
int32_t* dest = outblk + row * 32;
|
||||
const int32_t* src = srcblk + row * ncols;
|
||||
for (uint64_t i = 0; i < 32; ++i) {
|
||||
dest[i] = src[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
// copy the last block if any
|
||||
if (ncols_rem) {
|
||||
int32_t* outblk = out + nblk * nrows * 32;
|
||||
const int32_t* srcblk = mat + nblk * 32;
|
||||
for (uint64_t row = 0; row < nrows; ++row) {
|
||||
int32_t* dest = outblk + row * ncols_rem;
|
||||
const int32_t* src = srcblk + row * ncols;
|
||||
for (uint64_t i = 0; i < ncols_rem; ++i) {
|
||||
dest[i] = src[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
// zero-out the final elements that may be accessed
|
||||
if (final_elems) {
|
||||
int32_t* f = out + nrows * ncols;
|
||||
for (uint64_t i = 0; i < final_elems; ++i) {
|
||||
f[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||
EXPORT void default_zn32_vmp_prepare_dblptr_ref( //
|
||||
const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t** mat, uint64_t nrows, uint64_t ncols // a
|
||||
) {
|
||||
for (uint64_t row_i = 0; row_i < nrows; ++row_i) {
|
||||
default_zn32_vmp_prepare_row_ref(module, pmat, mat[row_i], row_i, nrows, ncols);
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||
EXPORT void default_zn32_vmp_prepare_row_ref( //
|
||||
const MOD_Z* module,
|
||||
ZN32_VMP_PMAT* pmat, // output
|
||||
const int32_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols // a
|
||||
) {
|
||||
int32_t* const out = (int32_t*)pmat;
|
||||
const uint64_t nblk = ncols >> 5;
|
||||
const uint64_t ncols_rem = ncols & 31;
|
||||
const uint64_t final_elems = (row_i == nrows - 1) && (8 - nrows * ncols) & 7;
|
||||
for (uint64_t blk = 0; blk < nblk; ++blk) {
|
||||
int32_t* outblk = out + blk * nrows * 32;
|
||||
int32_t* dest = outblk + row_i * 32;
|
||||
const int32_t* src = row + blk * 32;
|
||||
for (uint64_t i = 0; i < 32; ++i) {
|
||||
dest[i] = src[i];
|
||||
}
|
||||
}
|
||||
// copy the last block if any
|
||||
if (ncols_rem) {
|
||||
int32_t* outblk = out + nblk * nrows * 32;
|
||||
int32_t* dest = outblk + row_i * ncols_rem;
|
||||
const int32_t* src = row + nblk * 32;
|
||||
for (uint64_t i = 0; i < ncols_rem; ++i) {
|
||||
dest[i] = src[i];
|
||||
}
|
||||
}
|
||||
// zero-out the final elements that may be accessed
|
||||
if (final_elems) {
|
||||
int32_t* f = out + nrows * ncols;
|
||||
for (uint64_t i = 0; i < final_elems; ++i) {
|
||||
f[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
|
||||
#define IMPL_zn32_vec_ixxx_matyyycols_ref(NCOLS) \
|
||||
memset(res, 0, NCOLS * sizeof(int32_t)); \
|
||||
for (uint64_t row = 0; row < nrows; ++row) { \
|
||||
int32_t ai = a[row]; \
|
||||
const int32_t* bb = b + row * b_sl; \
|
||||
for (uint64_t i = 0; i < NCOLS; ++i) { \
|
||||
res[i] += ai * bb[i]; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define IMPL_zn32_vec_ixxx_mat8cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(8)
|
||||
#define IMPL_zn32_vec_ixxx_mat16cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(16)
|
||||
#define IMPL_zn32_vec_ixxx_mat24cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(24)
|
||||
#define IMPL_zn32_vec_ixxx_mat32cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(32)
|
||||
|
||||
void zn32_vec_i8_mat32cols_ref(uint64_t nrows, int32_t* res, const int8_t* a, const int32_t* b, uint64_t b_sl) {
|
||||
IMPL_zn32_vec_ixxx_mat32cols_ref()
|
||||
}
|
||||
void zn32_vec_i16_mat32cols_ref(uint64_t nrows, int32_t* res, const int16_t* a, const int32_t* b, uint64_t b_sl) {
|
||||
IMPL_zn32_vec_ixxx_mat32cols_ref()
|
||||
}
|
||||
|
||||
void zn32_vec_i32_mat32cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) {
|
||||
IMPL_zn32_vec_ixxx_mat32cols_ref()
|
||||
}
|
||||
void zn32_vec_i32_mat24cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) {
|
||||
IMPL_zn32_vec_ixxx_mat24cols_ref()
|
||||
}
|
||||
void zn32_vec_i32_mat16cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) {
|
||||
IMPL_zn32_vec_ixxx_mat16cols_ref()
|
||||
}
|
||||
void zn32_vec_i32_mat8cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) {
|
||||
IMPL_zn32_vec_ixxx_mat8cols_ref()
|
||||
}
|
||||
typedef void (*zn32_vec_i32_mat8kcols_ref_f)(uint64_t nrows, //
|
||||
int32_t* res, //
|
||||
const int32_t* a, //
|
||||
const int32_t* b, uint64_t b_sl //
|
||||
);
|
||||
zn32_vec_i32_mat8kcols_ref_f zn32_vec_i32_mat8kcols_ref[4] = { //
|
||||
zn32_vec_i32_mat8cols_ref, zn32_vec_i32_mat16cols_ref, //
|
||||
zn32_vec_i32_mat24cols_ref, zn32_vec_i32_mat32cols_ref};
|
||||
|
||||
/** @brief applies a vmp product (int32_t* input) */
|
||||
EXPORT void default_zn32_vmp_apply_i32_ref(const MOD_Z* module, //
|
||||
int32_t* res, uint64_t res_size, //
|
||||
const int32_t* a, uint64_t a_size, //
|
||||
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||
const uint64_t rows = a_size < nrows ? a_size : nrows;
|
||||
const uint64_t cols = res_size < ncols ? res_size : ncols;
|
||||
const uint64_t ncolblk = cols >> 5;
|
||||
const uint64_t ncolrem = cols & 31;
|
||||
// copy the first full blocks
|
||||
const uint32_t full_blk_size = nrows * 32;
|
||||
const int32_t* mat = (int32_t*)pmat;
|
||||
int32_t* rr = res;
|
||||
for (uint64_t blk = 0; //
|
||||
blk < ncolblk; //
|
||||
++blk, mat += full_blk_size, rr += 32) {
|
||||
zn32_vec_i32_mat32cols_ref(rows, rr, a, mat, 32);
|
||||
}
|
||||
// last block
|
||||
if (ncolrem) {
|
||||
uint64_t orig_rem = ncols - (ncolblk << 5);
|
||||
uint64_t b_sl = orig_rem >= 32 ? 32 : orig_rem;
|
||||
int32_t tmp[32];
|
||||
zn32_vec_i32_mat8kcols_ref[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl);
|
||||
memcpy(rr, tmp, ncolrem * sizeof(int32_t));
|
||||
}
|
||||
// trailing bytes
|
||||
memset(res + cols, 0, (res_size - cols) * sizeof(int32_t));
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,38 @@
|
||||
#include "vec_znx_arithmetic_private.h"
|
||||
|
||||
/** @brief res = a * b : small integer polynomial product */
|
||||
EXPORT void fft64_znx_small_single_product(const MODULE* module, // N
|
||||
int64_t* res, // output
|
||||
const int64_t* a, // a
|
||||
const int64_t* b, // b
|
||||
uint8_t* tmp) {
|
||||
const uint64_t nn = module->nn;
|
||||
double* const ffta = (double*)tmp;
|
||||
double* const fftb = ((double*)tmp) + nn;
|
||||
reim_from_znx64(module->mod.fft64.p_conv, ffta, a);
|
||||
reim_from_znx64(module->mod.fft64.p_conv, fftb, b);
|
||||
reim_fft(module->mod.fft64.p_fft, ffta);
|
||||
reim_fft(module->mod.fft64.p_fft, fftb);
|
||||
reim_fftvec_mul_simple(module->m, ffta, ffta, fftb);
|
||||
reim_ifft(module->mod.fft64.p_ifft, ffta);
|
||||
reim_to_znx64(module->mod.fft64.p_reim_to_znx, res, ffta);
|
||||
}
|
||||
|
||||
/** @brief tmp bytes required for znx_small_single_product */
|
||||
EXPORT uint64_t fft64_znx_small_single_product_tmp_bytes(const MODULE* module, uint64_t nn) {
|
||||
return 2 * nn * sizeof(double);
|
||||
}
|
||||
|
||||
/** @brief res = a * b : small integer polynomial product */
|
||||
EXPORT void znx_small_single_product(const MODULE* module, // N
|
||||
int64_t* res, // output
|
||||
const int64_t* a, // a
|
||||
const int64_t* b, // b
|
||||
uint8_t* tmp) {
|
||||
module->func.znx_small_single_product(module, res, a, b, tmp);
|
||||
}
|
||||
|
||||
/** @brief tmp bytes required for znx_small_single_product */
|
||||
EXPORT uint64_t znx_small_single_product_tmp_bytes(const MODULE* module, uint64_t nn) {
|
||||
return module->func.znx_small_single_product_tmp_bytes(module, nn);
|
||||
}
|
||||
@@ -0,0 +1,524 @@
|
||||
#include "coeffs_arithmetic.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <memory.h>
|
||||
|
||||
/** res = a + b */
|
||||
EXPORT void znx_add_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
res[i] = a[i] + b[i];
|
||||
}
|
||||
}
|
||||
/** res = a - b */
|
||||
EXPORT void znx_sub_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
res[i] = a[i] - b[i];
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void znx_negate_i64_ref(uint64_t nn, int64_t* res, const int64_t* a) {
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
res[i] = -a[i];
|
||||
}
|
||||
}
|
||||
EXPORT void znx_copy_i64_ref(uint64_t nn, int64_t* res, const int64_t* a) { memcpy(res, a, nn * sizeof(int64_t)); }
|
||||
|
||||
EXPORT void znx_zero_i64_ref(uint64_t nn, int64_t* res) { memset(res, 0, nn * sizeof(int64_t)); }
|
||||
|
||||
EXPORT void rnx_divide_by_m_ref(uint64_t n, double m, double* res, const double* a) {
|
||||
const double invm = 1. / m;
|
||||
for (uint64_t i = 0; i < n; ++i) {
|
||||
res[i] = a[i] * invm;
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void rnx_rotate_f64(uint64_t nn, int64_t p, double* res, const double* in) {
|
||||
uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn)
|
||||
|
||||
if (a < nn) { // rotate to the left
|
||||
uint64_t nma = nn - a;
|
||||
// rotate first half
|
||||
for (uint64_t j = 0; j < nma; j++) {
|
||||
res[j] = in[j + a];
|
||||
}
|
||||
for (uint64_t j = nma; j < nn; j++) {
|
||||
res[j] = -in[j - nma];
|
||||
}
|
||||
} else {
|
||||
a -= nn;
|
||||
uint64_t nma = nn - a;
|
||||
for (uint64_t j = 0; j < nma; j++) {
|
||||
res[j] = -in[j + a];
|
||||
}
|
||||
for (uint64_t j = nma; j < nn; j++) {
|
||||
// rotate first half
|
||||
res[j] = in[j - nma];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void znx_rotate_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) {
|
||||
uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn)
|
||||
|
||||
if (a < nn) { // rotate to the left
|
||||
uint64_t nma = nn - a;
|
||||
// rotate first half
|
||||
for (uint64_t j = 0; j < nma; j++) {
|
||||
res[j] = in[j + a];
|
||||
}
|
||||
for (uint64_t j = nma; j < nn; j++) {
|
||||
res[j] = -in[j - nma];
|
||||
}
|
||||
} else {
|
||||
a -= nn;
|
||||
uint64_t nma = nn - a;
|
||||
for (uint64_t j = 0; j < nma; j++) {
|
||||
res[j] = -in[j + a];
|
||||
}
|
||||
for (uint64_t j = nma; j < nn; j++) {
|
||||
// rotate first half
|
||||
res[j] = in[j - nma];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void rnx_mul_xp_minus_one_f64(uint64_t nn, int64_t p, double* res, const double* in) {
|
||||
uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn)
|
||||
if (a < nn) { // rotate to the left
|
||||
uint64_t nma = nn - a;
|
||||
// rotate first half
|
||||
for (uint64_t j = 0; j < nma; j++) {
|
||||
res[j] = in[j + a] - in[j];
|
||||
}
|
||||
for (uint64_t j = nma; j < nn; j++) {
|
||||
res[j] = -in[j - nma] - in[j];
|
||||
}
|
||||
} else {
|
||||
a -= nn;
|
||||
uint64_t nma = nn - a;
|
||||
for (uint64_t j = 0; j < nma; j++) {
|
||||
res[j] = -in[j + a] - in[j];
|
||||
}
|
||||
for (uint64_t j = nma; j < nn; j++) {
|
||||
// rotate first half
|
||||
res[j] = in[j - nma] - in[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void znx_mul_xp_minus_one_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) {
|
||||
uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn)
|
||||
if (a < nn) { // rotate to the left
|
||||
uint64_t nma = nn - a;
|
||||
// rotate first half
|
||||
for (uint64_t j = 0; j < nma; j++) {
|
||||
res[j] = in[j + a] - in[j];
|
||||
}
|
||||
for (uint64_t j = nma; j < nn; j++) {
|
||||
res[j] = -in[j - nma] - in[j];
|
||||
}
|
||||
} else {
|
||||
a -= nn;
|
||||
uint64_t nma = nn - a;
|
||||
for (uint64_t j = 0; j < nma; j++) {
|
||||
res[j] = -in[j + a] - in[j];
|
||||
}
|
||||
for (uint64_t j = nma; j < nn; j++) {
|
||||
// rotate first half
|
||||
res[j] = in[j - nma] - in[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void znx_mul_xp_minus_one_inplace_i64(uint64_t nn, int64_t p, int64_t* res) {
|
||||
const uint64_t _2mn = 2 * nn - 1;
|
||||
const uint64_t _mn = nn - 1;
|
||||
uint64_t nb_modif = 0;
|
||||
uint64_t j_start = 0;
|
||||
while (nb_modif < nn) {
|
||||
// follow the cycle that start with j_start
|
||||
uint64_t j = j_start;
|
||||
int64_t tmp1 = res[j];
|
||||
do {
|
||||
// find where the value should go, and with which sign
|
||||
uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign
|
||||
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||
// exchange this position with tmp1 (and take care of the sign)
|
||||
int64_t tmp2 = res[new_j_n];
|
||||
res[new_j_n] = ((new_j < nn) ? tmp1 : -tmp1) - res[new_j_n];
|
||||
tmp1 = tmp2;
|
||||
// move to the new location, and store the number of items modified
|
||||
++nb_modif;
|
||||
j = new_j_n;
|
||||
} while (j != j_start);
|
||||
// move to the start of the next cycle:
|
||||
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||
// in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator.
|
||||
++j_start;
|
||||
}
|
||||
}
|
||||
|
||||
// 0 < p < 2nn
|
||||
EXPORT void rnx_automorphism_f64(uint64_t nn, int64_t p, double* res, const double* in) {
|
||||
res[0] = in[0];
|
||||
uint64_t a = 0;
|
||||
uint64_t _2mn = 2 * nn - 1;
|
||||
for (uint64_t i = 1; i < nn; i++) {
|
||||
a = (a + p) & _2mn; // i*p mod 2n
|
||||
if (a < nn) {
|
||||
res[a] = in[i]; // res[ip mod 2n] = res[i]
|
||||
} else {
|
||||
res[a - nn] = -in[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void znx_automorphism_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) {
|
||||
res[0] = in[0];
|
||||
uint64_t a = 0;
|
||||
uint64_t _2mn = 2 * nn - 1;
|
||||
for (uint64_t i = 1; i < nn; i++) {
|
||||
a = (a + p) & _2mn;
|
||||
if (a < nn) {
|
||||
res[a] = in[i]; // res[ip mod 2n] = res[i]
|
||||
} else {
|
||||
res[a - nn] = -in[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void rnx_rotate_inplace_f64(uint64_t nn, int64_t p, double* res) {
|
||||
const uint64_t _2mn = 2 * nn - 1;
|
||||
const uint64_t _mn = nn - 1;
|
||||
uint64_t nb_modif = 0;
|
||||
uint64_t j_start = 0;
|
||||
while (nb_modif < nn) {
|
||||
// follow the cycle that start with j_start
|
||||
uint64_t j = j_start;
|
||||
double tmp1 = res[j];
|
||||
do {
|
||||
// find where the value should go, and with which sign
|
||||
uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign
|
||||
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||
// exchange this position with tmp1 (and take care of the sign)
|
||||
double tmp2 = res[new_j_n];
|
||||
res[new_j_n] = (new_j < nn) ? tmp1 : -tmp1;
|
||||
tmp1 = tmp2;
|
||||
// move to the new location, and store the number of items modified
|
||||
++nb_modif;
|
||||
j = new_j_n;
|
||||
} while (j != j_start);
|
||||
// move to the start of the next cycle:
|
||||
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||
// in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator.
|
||||
++j_start;
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void znx_rotate_inplace_i64(uint64_t nn, int64_t p, int64_t* res) {
|
||||
const uint64_t _2mn = 2 * nn - 1;
|
||||
const uint64_t _mn = nn - 1;
|
||||
uint64_t nb_modif = 0;
|
||||
uint64_t j_start = 0;
|
||||
while (nb_modif < nn) {
|
||||
// follow the cycle that start with j_start
|
||||
uint64_t j = j_start;
|
||||
int64_t tmp1 = res[j];
|
||||
do {
|
||||
// find where the value should go, and with which sign
|
||||
uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign
|
||||
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||
// exchange this position with tmp1 (and take care of the sign)
|
||||
int64_t tmp2 = res[new_j_n];
|
||||
res[new_j_n] = (new_j < nn) ? tmp1 : -tmp1;
|
||||
tmp1 = tmp2;
|
||||
// move to the new location, and store the number of items modified
|
||||
++nb_modif;
|
||||
j = new_j_n;
|
||||
} while (j != j_start);
|
||||
// move to the start of the next cycle:
|
||||
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||
// in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator.
|
||||
++j_start;
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void rnx_mul_xp_minus_one_inplace_f64(uint64_t nn, int64_t p, double* res) {
|
||||
const uint64_t _2mn = 2 * nn - 1;
|
||||
const uint64_t _mn = nn - 1;
|
||||
uint64_t nb_modif = 0;
|
||||
uint64_t j_start = 0;
|
||||
while (nb_modif < nn) {
|
||||
// follow the cycle that start with j_start
|
||||
uint64_t j = j_start;
|
||||
double tmp1 = res[j];
|
||||
do {
|
||||
// find where the value should go, and with which sign
|
||||
uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign
|
||||
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||
// exchange this position with tmp1 (and take care of the sign)
|
||||
double tmp2 = res[new_j_n];
|
||||
res[new_j_n] = ((new_j < nn) ? tmp1 : -tmp1) - res[new_j_n];
|
||||
tmp1 = tmp2;
|
||||
// move to the new location, and store the number of items modified
|
||||
++nb_modif;
|
||||
j = new_j_n;
|
||||
} while (j != j_start);
|
||||
// move to the start of the next cycle:
|
||||
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||
// in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator.
|
||||
++j_start;
|
||||
}
|
||||
}
|
||||
|
||||
__always_inline int64_t get_base_k_digit(const int64_t x, const uint64_t base_k) {
|
||||
return (x << (64 - base_k)) >> (64 - base_k);
|
||||
}
|
||||
|
||||
__always_inline int64_t get_base_k_carry(const int64_t x, const int64_t digit, const uint64_t base_k) {
|
||||
return (x - digit) >> base_k;
|
||||
}
|
||||
|
||||
EXPORT void znx_normalize(uint64_t nn, uint64_t base_k, int64_t* out, int64_t* carry_out, const int64_t* in,
|
||||
const int64_t* carry_in) {
|
||||
assert(in);
|
||||
if (out != 0) {
|
||||
if (carry_in != 0x0 && carry_out != 0x0) {
|
||||
// with carry in and carry out is computed
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
const int64_t x = in[i];
|
||||
const int64_t cin = carry_in[i];
|
||||
|
||||
int64_t digit = get_base_k_digit(x, base_k);
|
||||
int64_t carry = get_base_k_carry(x, digit, base_k);
|
||||
int64_t digit_plus_cin = digit + cin;
|
||||
int64_t y = get_base_k_digit(digit_plus_cin, base_k);
|
||||
int64_t cout = carry + get_base_k_carry(digit_plus_cin, y, base_k);
|
||||
|
||||
out[i] = y;
|
||||
carry_out[i] = cout;
|
||||
}
|
||||
} else if (carry_in != 0) {
|
||||
// with carry in and carry out is dropped
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
const int64_t x = in[i];
|
||||
const int64_t cin = carry_in[i];
|
||||
|
||||
int64_t digit = get_base_k_digit(x, base_k);
|
||||
int64_t digit_plus_cin = digit + cin;
|
||||
int64_t y = get_base_k_digit(digit_plus_cin, base_k);
|
||||
|
||||
out[i] = y;
|
||||
}
|
||||
|
||||
} else if (carry_out != 0) {
|
||||
// no carry in and carry out is computed
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
const int64_t x = in[i];
|
||||
|
||||
int64_t y = get_base_k_digit(x, base_k);
|
||||
int64_t cout = get_base_k_carry(x, y, base_k);
|
||||
|
||||
out[i] = y;
|
||||
carry_out[i] = cout;
|
||||
}
|
||||
|
||||
} else {
|
||||
// no carry in and carry out is dropped
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
out[i] = get_base_k_digit(in[i], base_k);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
assert(carry_out);
|
||||
if (carry_in != 0x0) {
|
||||
// with carry in and carry out is computed
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
const int64_t x = in[i];
|
||||
const int64_t cin = carry_in[i];
|
||||
|
||||
int64_t digit = get_base_k_digit(x, base_k);
|
||||
int64_t carry = get_base_k_carry(x, digit, base_k);
|
||||
int64_t digit_plus_cin = digit + cin;
|
||||
int64_t y = get_base_k_digit(digit_plus_cin, base_k);
|
||||
int64_t cout = carry + get_base_k_carry(digit_plus_cin, y, base_k);
|
||||
|
||||
carry_out[i] = cout;
|
||||
}
|
||||
} else {
|
||||
// no carry in and carry out is computed
|
||||
for (uint64_t i = 0; i < nn; ++i) {
|
||||
const int64_t x = in[i];
|
||||
|
||||
int64_t y = get_base_k_digit(x, base_k);
|
||||
int64_t cout = get_base_k_carry(x, y, base_k);
|
||||
|
||||
carry_out[i] = cout;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void znx_automorphism_inplace_i64(uint64_t nn, int64_t p, int64_t* res) {
|
||||
const uint64_t _2mn = 2 * nn - 1;
|
||||
const uint64_t _mn = nn - 1;
|
||||
const uint64_t m = nn >> 1;
|
||||
// reduce p mod 2n
|
||||
p &= _2mn;
|
||||
// uint64_t vp = p & _2mn;
|
||||
/// uint64_t target_modifs = m >> 1;
|
||||
// we proceed by increasing binary valuation
|
||||
for (uint64_t binval = 1, vp = p & _2mn, orb_size = m; binval < nn;
|
||||
binval <<= 1, vp = (vp << 1) & _2mn, orb_size >>= 1) {
|
||||
// In this loop, we are going to treat the orbit of indexes = binval mod 2.binval.
|
||||
// At the beginning of this loop we have:
|
||||
// vp = binval * p mod 2n
|
||||
// target_modif = m / binval (i.e. order of the orbit binval % 2.binval)
|
||||
|
||||
// first, handle the orders 1 and 2.
|
||||
// if p*binval == binval % 2n: we're done!
|
||||
if (vp == binval) return;
|
||||
// if p*binval == -binval % 2n: nega-mirror the orbit and all the sub-orbits and exit!
|
||||
if (((vp + binval) & _2mn) == 0) {
|
||||
for (uint64_t j = binval; j < m; j += binval) {
|
||||
int64_t tmp = res[j];
|
||||
res[j] = -res[nn - j];
|
||||
res[nn - j] = -tmp;
|
||||
}
|
||||
res[m] = -res[m];
|
||||
return;
|
||||
}
|
||||
// if p*binval == binval + n % 2n: negate the orbit and exit
|
||||
if (((vp - binval) & _mn) == 0) {
|
||||
for (uint64_t j = binval; j < nn; j += 2 * binval) {
|
||||
res[j] = -res[j];
|
||||
}
|
||||
return;
|
||||
}
|
||||
// if p*binval == n - binval % 2n: mirror the orbit and continue!
|
||||
if (((vp + binval) & _mn) == 0) {
|
||||
for (uint64_t j = binval; j < m; j += 2 * binval) {
|
||||
int64_t tmp = res[j];
|
||||
res[j] = res[nn - j];
|
||||
res[nn - j] = tmp;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// otherwise we will follow the orbit cycles,
|
||||
// starting from binval and -binval in parallel
|
||||
uint64_t j_start = binval;
|
||||
uint64_t nb_modif = 0;
|
||||
while (nb_modif < orb_size) {
|
||||
// follow the cycle that start with j_start
|
||||
uint64_t j = j_start;
|
||||
int64_t tmp1 = res[j];
|
||||
int64_t tmp2 = res[nn - j];
|
||||
do {
|
||||
// find where the value should go, and with which sign
|
||||
uint64_t new_j = (j * p) & _2mn; // mod 2n to get the position and sign
|
||||
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||
// exchange this position with tmp1 (and take care of the sign)
|
||||
int64_t tmp1a = res[new_j_n];
|
||||
int64_t tmp2a = res[nn - new_j_n];
|
||||
if (new_j < nn) {
|
||||
res[new_j_n] = tmp1;
|
||||
res[nn - new_j_n] = tmp2;
|
||||
} else {
|
||||
res[new_j_n] = -tmp1;
|
||||
res[nn - new_j_n] = -tmp2;
|
||||
}
|
||||
tmp1 = tmp1a;
|
||||
tmp2 = tmp2a;
|
||||
// move to the new location, and store the number of items modified
|
||||
nb_modif += 2;
|
||||
j = new_j_n;
|
||||
} while (j != j_start);
|
||||
// move to the start of the next cycle:
|
||||
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||
// in practice, it is enough to do *5, because 5 is a generator.
|
||||
j_start = (5 * j_start) & _mn;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void rnx_automorphism_inplace_f64(uint64_t nn, int64_t p, double* res) {
|
||||
const uint64_t _2mn = 2 * nn - 1;
|
||||
const uint64_t _mn = nn - 1;
|
||||
const uint64_t m = nn >> 1;
|
||||
// reduce p mod 2n
|
||||
p &= _2mn;
|
||||
// uint64_t vp = p & _2mn;
|
||||
/// uint64_t target_modifs = m >> 1;
|
||||
// we proceed by increasing binary valuation
|
||||
for (uint64_t binval = 1, vp = p & _2mn, orb_size = m; binval < nn;
|
||||
binval <<= 1, vp = (vp << 1) & _2mn, orb_size >>= 1) {
|
||||
// In this loop, we are going to treat the orbit of indexes = binval mod 2.binval.
|
||||
// At the beginning of this loop we have:
|
||||
// vp = binval * p mod 2n
|
||||
// target_modif = m / binval (i.e. order of the orbit binval % 2.binval)
|
||||
|
||||
// first, handle the orders 1 and 2.
|
||||
// if p*binval == binval % 2n: we're done!
|
||||
if (vp == binval) return;
|
||||
// if p*binval == -binval % 2n: nega-mirror the orbit and all the sub-orbits and exit!
|
||||
if (((vp + binval) & _2mn) == 0) {
|
||||
for (uint64_t j = binval; j < m; j += binval) {
|
||||
double tmp = res[j];
|
||||
res[j] = -res[nn - j];
|
||||
res[nn - j] = -tmp;
|
||||
}
|
||||
res[m] = -res[m];
|
||||
return;
|
||||
}
|
||||
// if p*binval == binval + n % 2n: negate the orbit and exit
|
||||
if (((vp - binval) & _mn) == 0) {
|
||||
for (uint64_t j = binval; j < nn; j += 2 * binval) {
|
||||
res[j] = -res[j];
|
||||
}
|
||||
return;
|
||||
}
|
||||
// if p*binval == n - binval % 2n: mirror the orbit and continue!
|
||||
if (((vp + binval) & _mn) == 0) {
|
||||
for (uint64_t j = binval; j < m; j += 2 * binval) {
|
||||
double tmp = res[j];
|
||||
res[j] = res[nn - j];
|
||||
res[nn - j] = tmp;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// otherwise we will follow the orbit cycles,
|
||||
// starting from binval and -binval in parallel
|
||||
uint64_t j_start = binval;
|
||||
uint64_t nb_modif = 0;
|
||||
while (nb_modif < orb_size) {
|
||||
// follow the cycle that start with j_start
|
||||
uint64_t j = j_start;
|
||||
double tmp1 = res[j];
|
||||
double tmp2 = res[nn - j];
|
||||
do {
|
||||
// find where the value should go, and with which sign
|
||||
uint64_t new_j = (j * p) & _2mn; // mod 2n to get the position and sign
|
||||
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||
// exchange this position with tmp1 (and take care of the sign)
|
||||
double tmp1a = res[new_j_n];
|
||||
double tmp2a = res[nn - new_j_n];
|
||||
if (new_j < nn) {
|
||||
res[new_j_n] = tmp1;
|
||||
res[nn - new_j_n] = tmp2;
|
||||
} else {
|
||||
res[new_j_n] = -tmp1;
|
||||
res[nn - new_j_n] = -tmp2;
|
||||
}
|
||||
tmp1 = tmp1a;
|
||||
tmp2 = tmp2a;
|
||||
// move to the new location, and store the number of items modified
|
||||
nb_modif += 2;
|
||||
j = new_j_n;
|
||||
} while (j != j_start);
|
||||
// move to the start of the next cycle:
|
||||
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||
// in practice, it is enough to do *5, because 5 is a generator.
|
||||
j_start = (5 * j_start) & _mn;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
#ifndef SPQLIOS_COEFFS_ARITHMETIC_H
|
||||
#define SPQLIOS_COEFFS_ARITHMETIC_H
|
||||
|
||||
#include "../commons.h"
|
||||
|
||||
/** res = a + b */
|
||||
EXPORT void znx_add_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
|
||||
EXPORT void znx_add_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
|
||||
/** res = a - b */
|
||||
EXPORT void znx_sub_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
|
||||
EXPORT void znx_sub_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
|
||||
/** res = -a */
|
||||
EXPORT void znx_negate_i64_ref(uint64_t nn, int64_t* res, const int64_t* a);
|
||||
EXPORT void znx_negate_i64_avx(uint64_t nn, int64_t* res, const int64_t* a);
|
||||
/** res = a */
|
||||
EXPORT void znx_copy_i64_ref(uint64_t nn, int64_t* res, const int64_t* a);
|
||||
/** res = 0 */
|
||||
EXPORT void znx_zero_i64_ref(uint64_t nn, int64_t* res);
|
||||
|
||||
/** res = a / m where m is a power of 2 */
|
||||
EXPORT void rnx_divide_by_m_ref(uint64_t nn, double m, double* res, const double* a);
|
||||
EXPORT void rnx_divide_by_m_avx(uint64_t nn, double m, double* res, const double* a);
|
||||
|
||||
/**
|
||||
* @param res = X^p *in mod X^nn +1
|
||||
* @param nn the ring dimension
|
||||
* @param p a power for the rotation -2nn <= p <= 2nn
|
||||
* @param in is a rnx/znx vector of dimension nn
|
||||
*/
|
||||
EXPORT void rnx_rotate_f64(uint64_t nn, int64_t p, double* res, const double* in);
|
||||
EXPORT void znx_rotate_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in);
|
||||
EXPORT void rnx_rotate_inplace_f64(uint64_t nn, int64_t p, double* res);
|
||||
EXPORT void znx_rotate_inplace_i64(uint64_t nn, int64_t p, int64_t* res);
|
||||
|
||||
/**
|
||||
* @brief res(X) = in(X^p)
|
||||
* @param nn the ring dimension
|
||||
* @param p is odd integer and must be between 0 < p < 2nn
|
||||
* @param in is a rnx/znx vector of dimension nn
|
||||
*/
|
||||
EXPORT void rnx_automorphism_f64(uint64_t nn, int64_t p, double* res, const double* in);
|
||||
EXPORT void znx_automorphism_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in);
|
||||
EXPORT void rnx_automorphism_inplace_f64(uint64_t nn, int64_t p, double* res);
|
||||
EXPORT void znx_automorphism_inplace_i64(uint64_t nn, int64_t p, int64_t* res);
|
||||
|
||||
/**
|
||||
* @brief res = (X^p-1).in
|
||||
* @param nn the ring dimension
|
||||
* @param p must be between -2nn <= p <= 2nn
|
||||
* @param in is a rnx/znx vector of dimension nn
|
||||
*/
|
||||
EXPORT void rnx_mul_xp_minus_one_f64(uint64_t nn, int64_t p, double* res, const double* in);
|
||||
EXPORT void znx_mul_xp_minus_one_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in);
|
||||
EXPORT void rnx_mul_xp_minus_one_inplace_f64(uint64_t nn, int64_t p, double* res);
|
||||
EXPORT void znx_mul_xp_minus_one_inplace_i64(uint64_t nn, int64_t p, int64_t* res);
|
||||
|
||||
/**
|
||||
* @brief Normalize input plus carry mod-2^k. The following
|
||||
* equality holds @c {in + carry_in == out + carry_out . 2^k}.
|
||||
*
|
||||
* @c in must be in [-2^62 .. 2^62]
|
||||
*
|
||||
* @c out is in [ -2^(base_k-1), 2^(base_k-1) [.
|
||||
*
|
||||
* @c carry_in and @carry_out have at most 64+1-k bits.
|
||||
*
|
||||
* Null @c carry_in or @c carry_out are ignored.
|
||||
*
|
||||
* @param[in] nn the ring dimension
|
||||
* @param[in] base_k the base k
|
||||
* @param out output normalized znx
|
||||
* @param carry_out output carry znx
|
||||
* @param[in] in input znx
|
||||
* @param[in] carry_in input carry znx
|
||||
*/
|
||||
EXPORT void znx_normalize(uint64_t nn, uint64_t base_k, int64_t* out, int64_t* carry_out, const int64_t* in,
|
||||
const int64_t* carry_in);
|
||||
|
||||
#endif // SPQLIOS_COEFFS_ARITHMETIC_H
|
||||
@@ -0,0 +1,124 @@
|
||||
#include <immintrin.h>
|
||||
|
||||
#include "../commons_private.h"
|
||||
#include "coeffs_arithmetic.h"
|
||||
|
||||
// res = a + b. dimension n must be a power of 2
|
||||
EXPORT void znx_add_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
|
||||
if (nn <= 2) {
|
||||
if (nn == 1) {
|
||||
res[0] = a[0] + b[0];
|
||||
} else {
|
||||
_mm_storeu_si128((__m128i*)res, //
|
||||
_mm_add_epi64( //
|
||||
_mm_loadu_si128((__m128i*)a), //
|
||||
_mm_loadu_si128((__m128i*)b)));
|
||||
}
|
||||
} else {
|
||||
const __m256i* aa = (__m256i*)a;
|
||||
const __m256i* bb = (__m256i*)b;
|
||||
__m256i* rr = (__m256i*)res;
|
||||
__m256i* const rrend = (__m256i*)(res + nn);
|
||||
do {
|
||||
_mm256_storeu_si256(rr, //
|
||||
_mm256_add_epi64( //
|
||||
_mm256_loadu_si256(aa), //
|
||||
_mm256_loadu_si256(bb)));
|
||||
++rr;
|
||||
++aa;
|
||||
++bb;
|
||||
} while (rr < rrend);
|
||||
}
|
||||
}
|
||||
|
||||
// res = a - b. dimension n must be a power of 2
|
||||
EXPORT void znx_sub_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
|
||||
if (nn <= 2) {
|
||||
if (nn == 1) {
|
||||
res[0] = a[0] - b[0];
|
||||
} else {
|
||||
_mm_storeu_si128((__m128i*)res, //
|
||||
_mm_sub_epi64( //
|
||||
_mm_loadu_si128((__m128i*)a), //
|
||||
_mm_loadu_si128((__m128i*)b)));
|
||||
}
|
||||
} else {
|
||||
const __m256i* aa = (__m256i*)a;
|
||||
const __m256i* bb = (__m256i*)b;
|
||||
__m256i* rr = (__m256i*)res;
|
||||
__m256i* const rrend = (__m256i*)(res + nn);
|
||||
do {
|
||||
_mm256_storeu_si256(rr, //
|
||||
_mm256_sub_epi64( //
|
||||
_mm256_loadu_si256(aa), //
|
||||
_mm256_loadu_si256(bb)));
|
||||
++rr;
|
||||
++aa;
|
||||
++bb;
|
||||
} while (rr < rrend);
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void znx_negate_i64_avx(uint64_t nn, int64_t* res, const int64_t* a) {
|
||||
if (nn <= 2) {
|
||||
if (nn == 1) {
|
||||
res[0] = -a[0];
|
||||
} else {
|
||||
_mm_storeu_si128((__m128i*)res, //
|
||||
_mm_sub_epi64( //
|
||||
_mm_set1_epi64x(0), //
|
||||
_mm_loadu_si128((__m128i*)a)));
|
||||
}
|
||||
} else {
|
||||
const __m256i* aa = (__m256i*)a;
|
||||
__m256i* rr = (__m256i*)res;
|
||||
__m256i* const rrend = (__m256i*)(res + nn);
|
||||
do {
|
||||
_mm256_storeu_si256(rr, //
|
||||
_mm256_sub_epi64( //
|
||||
_mm256_set1_epi64x(0), //
|
||||
_mm256_loadu_si256(aa)));
|
||||
++rr;
|
||||
++aa;
|
||||
} while (rr < rrend);
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void rnx_divide_by_m_avx(uint64_t n, double m, double* res, const double* a) {
|
||||
// TODO: see if there is a faster way of dividing by a power of 2?
|
||||
const double invm = 1. / m;
|
||||
if (n < 8) {
|
||||
switch (n) {
|
||||
case 1:
|
||||
*res = *a * invm;
|
||||
break;
|
||||
case 2:
|
||||
_mm_storeu_pd(res, //
|
||||
_mm_mul_pd(_mm_loadu_pd(a), //
|
||||
_mm_set1_pd(invm)));
|
||||
break;
|
||||
case 4:
|
||||
_mm256_storeu_pd(res, //
|
||||
_mm256_mul_pd(_mm256_loadu_pd(a), //
|
||||
_mm256_set1_pd(invm)));
|
||||
break;
|
||||
default:
|
||||
NOT_SUPPORTED(); // non-power of 2
|
||||
}
|
||||
return;
|
||||
}
|
||||
const __m256d invm256 = _mm256_set1_pd(invm);
|
||||
double* rr = res;
|
||||
const double* aa = a;
|
||||
const double* const aaend = a + n;
|
||||
do {
|
||||
_mm256_storeu_pd(rr, //
|
||||
_mm256_mul_pd(_mm256_loadu_pd(aa), //
|
||||
invm256));
|
||||
_mm256_storeu_pd(rr + 4, //
|
||||
_mm256_mul_pd(_mm256_loadu_pd(aa + 4), //
|
||||
invm256));
|
||||
rr += 8;
|
||||
aa += 8;
|
||||
} while (aa < aaend);
|
||||
}
|
||||
@@ -0,0 +1,165 @@
|
||||
#include "commons.h"
|
||||
|
||||
#include <math.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
EXPORT void* UNDEFINED_p_ii(int32_t n, int32_t m) { UNDEFINED(); }
|
||||
EXPORT void* UNDEFINED_p_uu(uint32_t n, uint32_t m) { UNDEFINED(); }
|
||||
EXPORT double* UNDEFINED_dp_pi(const void* p, int32_t n) { UNDEFINED(); }
|
||||
EXPORT void* UNDEFINED_vp_pi(const void* p, int32_t n) { UNDEFINED(); }
|
||||
EXPORT void* UNDEFINED_vp_pu(const void* p, uint32_t n) { UNDEFINED(); }
|
||||
EXPORT void UNDEFINED_v_vpdp(const void* p, double* a) { UNDEFINED(); }
|
||||
EXPORT void UNDEFINED_v_vpvp(const void* p, void* a) { UNDEFINED(); }
|
||||
EXPORT double* NOT_IMPLEMENTED_dp_i(int32_t n) { NOT_IMPLEMENTED(); }
|
||||
EXPORT void* NOT_IMPLEMENTED_vp_i(int32_t n) { NOT_IMPLEMENTED(); }
|
||||
EXPORT void* NOT_IMPLEMENTED_vp_u(uint32_t n) { NOT_IMPLEMENTED(); }
|
||||
EXPORT void NOT_IMPLEMENTED_v_dp(double* a) { NOT_IMPLEMENTED(); }
|
||||
EXPORT void NOT_IMPLEMENTED_v_vp(void* p) { NOT_IMPLEMENTED(); }
|
||||
EXPORT void NOT_IMPLEMENTED_v_idpdpdp(int32_t n, double* a, const double* b, const double* c) { NOT_IMPLEMENTED(); }
|
||||
EXPORT void NOT_IMPLEMENTED_v_uvpcvpcvp(uint32_t n, void* r, const void* a, const void* b) { NOT_IMPLEMENTED(); }
|
||||
EXPORT void NOT_IMPLEMENTED_v_uvpvpcvp(uint32_t n, void* a, void* b, const void* o) { NOT_IMPLEMENTED(); }
|
||||
|
||||
#ifdef _WIN32
|
||||
#define __always_inline inline __attribute((always_inline))
|
||||
#endif
|
||||
|
||||
void internal_accurate_sincos(double* rcos, double* rsin, double x) {
|
||||
double _4_x_over_pi = 4 * x / M_PI;
|
||||
int64_t int_part = ((int64_t)rint(_4_x_over_pi)) & 7;
|
||||
double frac_part = _4_x_over_pi - (double)(int_part);
|
||||
double frac_x = M_PI * frac_part / 4.;
|
||||
// compute the taylor series
|
||||
double cosp = 1.;
|
||||
double sinp = 0.;
|
||||
double powx = 1.;
|
||||
int64_t nn = 0;
|
||||
while (fabs(powx) > 1e-20) {
|
||||
++nn;
|
||||
powx = powx * frac_x / (double)(nn); // x^n/n!
|
||||
switch (nn & 3) {
|
||||
case 0:
|
||||
cosp += powx;
|
||||
break;
|
||||
case 1:
|
||||
sinp += powx;
|
||||
break;
|
||||
case 2:
|
||||
cosp -= powx;
|
||||
break;
|
||||
case 3:
|
||||
sinp -= powx;
|
||||
break;
|
||||
default:
|
||||
abort(); // impossible
|
||||
}
|
||||
}
|
||||
// final multiplication
|
||||
switch (int_part) {
|
||||
case 0:
|
||||
*rcos = cosp;
|
||||
*rsin = sinp;
|
||||
break;
|
||||
case 1:
|
||||
*rcos = M_SQRT1_2 * (cosp - sinp);
|
||||
*rsin = M_SQRT1_2 * (cosp + sinp);
|
||||
break;
|
||||
case 2:
|
||||
*rcos = -sinp;
|
||||
*rsin = cosp;
|
||||
break;
|
||||
case 3:
|
||||
*rcos = -M_SQRT1_2 * (cosp + sinp);
|
||||
*rsin = M_SQRT1_2 * (cosp - sinp);
|
||||
break;
|
||||
case 4:
|
||||
*rcos = -cosp;
|
||||
*rsin = -sinp;
|
||||
break;
|
||||
case 5:
|
||||
*rcos = -M_SQRT1_2 * (cosp - sinp);
|
||||
*rsin = -M_SQRT1_2 * (cosp + sinp);
|
||||
break;
|
||||
case 6:
|
||||
*rcos = sinp;
|
||||
*rsin = -cosp;
|
||||
break;
|
||||
case 7:
|
||||
*rcos = M_SQRT1_2 * (cosp + sinp);
|
||||
*rsin = -M_SQRT1_2 * (cosp - sinp);
|
||||
break;
|
||||
default:
|
||||
abort(); // impossible
|
||||
}
|
||||
if (fabs(cos(x) - *rcos) > 1e-10 || fabs(sin(x) - *rsin) > 1e-10) {
|
||||
printf("cos(%.17lf) =? %.17lf instead of %.17lf\n", x, *rcos, cos(x));
|
||||
printf("sin(%.17lf) =? %.17lf instead of %.17lf\n", x, *rsin, sin(x));
|
||||
printf("fracx = %.17lf\n", frac_x);
|
||||
printf("cosp = %.17lf\n", cosp);
|
||||
printf("sinp = %.17lf\n", sinp);
|
||||
printf("nn = %d\n", (int)(nn));
|
||||
}
|
||||
}
|
||||
|
||||
double internal_accurate_cos(double x) {
|
||||
double rcos, rsin;
|
||||
internal_accurate_sincos(&rcos, &rsin, x);
|
||||
return rcos;
|
||||
}
|
||||
double internal_accurate_sin(double x) {
|
||||
double rcos, rsin;
|
||||
internal_accurate_sincos(&rcos, &rsin, x);
|
||||
return rsin;
|
||||
}
|
||||
|
||||
EXPORT void spqlios_debug_free(void* addr) { free((uint8_t*)addr - 64); }
|
||||
|
||||
EXPORT void* spqlios_debug_alloc(uint64_t size) { return (uint8_t*)malloc(size + 64) + 64; }
|
||||
|
||||
EXPORT void spqlios_free(void* addr) {
|
||||
#ifndef NDEBUG
|
||||
// in debug mode, we deallocated with spqlios_debug_free()
|
||||
spqlios_debug_free(addr);
|
||||
#else
|
||||
// in release mode, the function will free aligned memory
|
||||
#ifdef _WIN32
|
||||
_aligned_free(addr);
|
||||
#else
|
||||
free(addr);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
EXPORT void* spqlios_alloc(uint64_t size) {
|
||||
#ifndef NDEBUG
|
||||
// in debug mode, the function will not necessarily have any particular alignment
|
||||
// it will also ensure that memory can only be deallocated with spqlios_free()
|
||||
return spqlios_debug_alloc(size);
|
||||
#else
|
||||
// in release mode, the function will return 64-bytes aligned memory
|
||||
#ifdef _WIN32
|
||||
void* reps = _aligned_malloc((size + 63) & (UINT64_C(-64)), 64);
|
||||
#else
|
||||
void* reps = aligned_alloc(64, (size + 63) & (UINT64_C(-64)));
|
||||
#endif
|
||||
if (reps == 0) FATAL_ERROR("Out of memory");
|
||||
return reps;
|
||||
#endif
|
||||
}
|
||||
|
||||
EXPORT void* spqlios_alloc_custom_align(uint64_t align, uint64_t size) {
|
||||
#ifndef NDEBUG
|
||||
// in debug mode, the function will not necessarily have any particular alignment
|
||||
// it will also ensure that memory can only be deallocated with spqlios_free()
|
||||
return spqlios_debug_alloc(size);
|
||||
#else
|
||||
// in release mode, the function will return aligned memory
|
||||
#ifdef _WIN32
|
||||
void* reps = _aligned_malloc(size, align);
|
||||
#else
|
||||
void* reps = aligned_alloc(align, size);
|
||||
#endif
|
||||
if (reps == 0) FATAL_ERROR("Out of memory");
|
||||
return reps;
|
||||
#endif
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
#ifndef SPQLIOS_COMMONS_H
|
||||
#define SPQLIOS_COMMONS_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#define EXPORT extern "C"
|
||||
#define EXPORT_DECL extern "C"
|
||||
#else
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#define EXPORT
|
||||
#define EXPORT_DECL extern
|
||||
#define nullptr 0x0;
|
||||
#endif
|
||||
|
||||
#define UNDEFINED() \
|
||||
{ \
|
||||
fprintf(stderr, "UNDEFINED!!!\n"); \
|
||||
abort(); \
|
||||
}
|
||||
#define NOT_IMPLEMENTED() \
|
||||
{ \
|
||||
fprintf(stderr, "NOT IMPLEMENTED!!!\n"); \
|
||||
abort(); \
|
||||
}
|
||||
#define FATAL_ERROR(MESSAGE) \
|
||||
{ \
|
||||
fprintf(stderr, "ERROR: %s\n", (MESSAGE)); \
|
||||
abort(); \
|
||||
}
|
||||
|
||||
EXPORT void* UNDEFINED_p_ii(int32_t n, int32_t m);
|
||||
EXPORT void* UNDEFINED_p_uu(uint32_t n, uint32_t m);
|
||||
EXPORT double* UNDEFINED_dp_pi(const void* p, int32_t n);
|
||||
EXPORT void* UNDEFINED_vp_pi(const void* p, int32_t n);
|
||||
EXPORT void* UNDEFINED_vp_pu(const void* p, uint32_t n);
|
||||
EXPORT void UNDEFINED_v_vpdp(const void* p, double* a);
|
||||
EXPORT void UNDEFINED_v_vpvp(const void* p, void* a);
|
||||
EXPORT double* NOT_IMPLEMENTED_dp_i(int32_t n);
|
||||
EXPORT void* NOT_IMPLEMENTED_vp_i(int32_t n);
|
||||
EXPORT void* NOT_IMPLEMENTED_vp_u(uint32_t n);
|
||||
EXPORT void NOT_IMPLEMENTED_v_dp(double* a);
|
||||
EXPORT void NOT_IMPLEMENTED_v_vp(void* p);
|
||||
EXPORT void NOT_IMPLEMENTED_v_idpdpdp(int32_t n, double* a, const double* b, const double* c);
|
||||
EXPORT void NOT_IMPLEMENTED_v_uvpcvpcvp(uint32_t n, void* r, const void* a, const void* b);
|
||||
EXPORT void NOT_IMPLEMENTED_v_uvpvpcvp(uint32_t n, void* a, void* b, const void* o);
|
||||
|
||||
// windows
|
||||
|
||||
#if defined(_WIN32) || defined(__APPLE__)
|
||||
#define __always_inline inline __attribute((always_inline))
|
||||
#endif
|
||||
|
||||
EXPORT void spqlios_free(void* address);
|
||||
|
||||
EXPORT void* spqlios_alloc(uint64_t size);
|
||||
EXPORT void* spqlios_alloc_custom_align(uint64_t align, uint64_t size);
|
||||
|
||||
#define USE_LIBM_SIN_COS
|
||||
#ifndef USE_LIBM_SIN_COS
|
||||
// if at some point, we want to remove the libm dependency, we can
|
||||
// consider this:
|
||||
EXPORT double internal_accurate_cos(double x);
|
||||
EXPORT double internal_accurate_sin(double x);
|
||||
EXPORT void internal_accurate_sincos(double* rcos, double* rsin, double x);
|
||||
#define m_accurate_cos internal_accurate_cos
|
||||
#define m_accurate_sin internal_accurate_sin
|
||||
#else
|
||||
// let's use libm sin and cos
|
||||
#define m_accurate_cos cos
|
||||
#define m_accurate_sin sin
|
||||
#endif
|
||||
|
||||
#endif // SPQLIOS_COMMONS_H
|
||||
@@ -0,0 +1,55 @@
|
||||
#include "commons_private.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "commons.h"
|
||||
|
||||
EXPORT void* spqlios_error(const char* error) {
|
||||
fputs(error, stderr);
|
||||
abort();
|
||||
return nullptr;
|
||||
}
|
||||
EXPORT void* spqlios_keep_or_free(void* ptr, void* ptr2) {
|
||||
if (!ptr2) {
|
||||
free(ptr);
|
||||
}
|
||||
return ptr2;
|
||||
}
|
||||
|
||||
EXPORT uint32_t log2m(uint32_t m) {
|
||||
uint32_t a = m - 1;
|
||||
if (m & a) FATAL_ERROR("m must be a power of two");
|
||||
a = (a & 0x55555555u) + ((a >> 1) & 0x55555555u);
|
||||
a = (a & 0x33333333u) + ((a >> 2) & 0x33333333u);
|
||||
a = (a & 0x0F0F0F0Fu) + ((a >> 4) & 0x0F0F0F0Fu);
|
||||
a = (a & 0x00FF00FFu) + ((a >> 8) & 0x00FF00FFu);
|
||||
return (a & 0x0000FFFFu) + ((a >> 16) & 0x0000FFFFu);
|
||||
}
|
||||
|
||||
EXPORT uint64_t is_not_pow2_double(void* doublevalue) { return (*(uint64_t*)doublevalue) & 0x7FFFFFFFFFFFFUL; }
|
||||
|
||||
uint32_t revbits(uint32_t nbits, uint32_t value) {
|
||||
uint32_t res = 0;
|
||||
for (uint32_t i = 0; i < nbits; ++i) {
|
||||
res = (res << 1) + (value & 1);
|
||||
value >>= 1;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief this computes the sequence: 0,1/2,1/4,3/4,1/8,5/8,3/8,7/8,...
|
||||
* essentially: the bits of (i+1) in lsb order on the basis (1/2^k) mod 1*/
|
||||
double fracrevbits(uint32_t i) {
|
||||
if (i == 0) return 0;
|
||||
if (i == 1) return 0.5;
|
||||
if (i % 2 == 0)
|
||||
return fracrevbits(i / 2) / 2.;
|
||||
else
|
||||
return fracrevbits((i - 1) / 2) / 2. + 0.5;
|
||||
}
|
||||
|
||||
uint64_t ceilto64b(uint64_t size) { return (size + UINT64_C(63)) & (UINT64_C(-64)); }
|
||||
|
||||
uint64_t ceilto32b(uint64_t size) { return (size + UINT64_C(31)) & (UINT64_C(-32)); }
|
||||
@@ -0,0 +1,72 @@
|
||||
#ifndef SPQLIOS_COMMONS_PRIVATE_H
|
||||
#define SPQLIOS_COMMONS_PRIVATE_H
|
||||
|
||||
#include "commons.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#else
|
||||
#include <math.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#define nullptr 0x0;
|
||||
#endif
|
||||
|
||||
/** @brief log2 of a power of two (UB if m is not a power of two) */
|
||||
EXPORT uint32_t log2m(uint32_t m);
|
||||
|
||||
/** @brief checks if the doublevalue is a power of two */
|
||||
EXPORT uint64_t is_not_pow2_double(void* doublevalue);
|
||||
|
||||
#define UNDEFINED() \
|
||||
{ \
|
||||
fprintf(stderr, "UNDEFINED!!!\n"); \
|
||||
abort(); \
|
||||
}
|
||||
#define NOT_IMPLEMENTED() \
|
||||
{ \
|
||||
fprintf(stderr, "NOT IMPLEMENTED!!!\n"); \
|
||||
abort(); \
|
||||
}
|
||||
#define NOT_SUPPORTED() \
|
||||
{ \
|
||||
fprintf(stderr, "NOT SUPPORTED!!!\n"); \
|
||||
abort(); \
|
||||
}
|
||||
#define FATAL_ERROR(MESSAGE) \
|
||||
{ \
|
||||
fprintf(stderr, "ERROR: %s\n", (MESSAGE)); \
|
||||
abort(); \
|
||||
}
|
||||
|
||||
#define STATIC_ASSERT(condition) (void)sizeof(char[-1 + 2 * !!(condition)])
|
||||
|
||||
/** @brief reports the error and returns nullptr */
|
||||
EXPORT void* spqlios_error(const char* error);
|
||||
/** @brief if ptr2 is not null, returns ptr, otherwise free ptr and return null */
|
||||
EXPORT void* spqlios_keep_or_free(void* ptr, void* ptr2);
|
||||
|
||||
#ifdef __x86_64__
|
||||
#define CPU_SUPPORTS __builtin_cpu_supports
|
||||
#else
|
||||
// TODO for now, we do not have any optimization for non x86 targets
|
||||
#define CPU_SUPPORTS(xxxx) 0
|
||||
#endif
|
||||
|
||||
/** @brief returns the n bits of value in reversed order */
|
||||
EXPORT uint32_t revbits(uint32_t nbits, uint32_t value);
|
||||
|
||||
/**
|
||||
* @brief this computes the sequence: 0,1/2,1/4,3/4,1/8,5/8,3/8,7/8,...
|
||||
* essentially: the bits of (i+1) in lsb order on the basis (1/2^k) mod 1*/
|
||||
EXPORT double fracrevbits(uint32_t i);
|
||||
|
||||
/** @brief smallest multiple of 64 higher or equal to size */
|
||||
EXPORT uint64_t ceilto64b(uint64_t size);
|
||||
|
||||
/** @brief smallest multiple of 32 higher or equal to size */
|
||||
EXPORT uint64_t ceilto32b(uint64_t size);
|
||||
|
||||
#endif // SPQLIOS_COMMONS_PRIVATE_H
|
||||
@@ -0,0 +1,22 @@
|
||||
In this folder, we deal with the full complex FFT in `C[X] mod X^M-i`.
|
||||
One complex is represented by two consecutive doubles `(real,imag)`
|
||||
Note that a real polynomial sum_{j=0}^{N-1} p_j.X^j mod X^N+1
|
||||
corresponds to the complex polynomial of half degree `M=N/2`:
|
||||
`sum_{j=0}^{M-1} (p_{j} + i.p_{j+M}) X^j mod X^M-i`
|
||||
|
||||
For a complex polynomial A(X) sum c_i X^i of degree M-1
|
||||
or a real polynomial sum a_i X^i of degree N
|
||||
|
||||
coefficient space:
|
||||
a_0,a_M,a_1,a_{M+1},...,a_{M-1},a_{2M-1}
|
||||
or equivalently
|
||||
Re(c_0),Im(c_0),Re(c_1),Im(c_1),...Re(c_{M-1}),Im(c_{M-1})
|
||||
|
||||
eval space:
|
||||
c(omega_{0}),...,c(omega_{M-1})
|
||||
|
||||
where
|
||||
omega_j = omega^{1+rev_{2N}(j)}
|
||||
and omega = exp(i.pi/N)
|
||||
|
||||
rev_{2N}(j) is the number that has the log2(2N) bits of j in reverse order.
|
||||
@@ -0,0 +1,80 @@
|
||||
#include "cplx_fft_internal.h"
|
||||
|
||||
void cplx_set(CPLX r, const CPLX a) {
|
||||
r[0] = a[0];
|
||||
r[1] = a[1];
|
||||
}
|
||||
void cplx_neg(CPLX r, const CPLX a) {
|
||||
r[0] = -a[0];
|
||||
r[1] = -a[1];
|
||||
}
|
||||
void cplx_add(CPLX r, const CPLX a, const CPLX b) {
|
||||
r[0] = a[0] + b[0];
|
||||
r[1] = a[1] + b[1];
|
||||
}
|
||||
void cplx_sub(CPLX r, const CPLX a, const CPLX b) {
|
||||
r[0] = a[0] - b[0];
|
||||
r[1] = a[1] - b[1];
|
||||
}
|
||||
void cplx_mul(CPLX r, const CPLX a, const CPLX b) {
|
||||
double re = a[0] * b[0] - a[1] * b[1];
|
||||
r[1] = a[0] * b[1] + a[1] * b[0];
|
||||
r[0] = re;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief splits 2h evaluations of one polynomials into 2 times h evaluations of even/odd polynomial
|
||||
* Input: Q_0(y),...,Q_{h-1}(y),Q_0(-y),...,Q_{h-1}(-y)
|
||||
* Output: P_0(z),...,P_{h-1}(z),P_h(z),...,P_{2h-1}(z)
|
||||
* where Q_i(X)=P_i(X^2)+X.P_{h+i}(X^2) and y^2 = z
|
||||
* @param h number of "coefficients" h >= 1
|
||||
* @param data 2h complex coefficients interleaved and 256b aligned
|
||||
* @param powom y represented as (yre,yim)
|
||||
*/
|
||||
EXPORT void cplx_split_fft_ref(int32_t h, CPLX* data, const CPLX powom) {
|
||||
CPLX* d0 = data;
|
||||
CPLX* d1 = data + h;
|
||||
for (uint64_t i = 0; i < h; ++i) {
|
||||
CPLX diff;
|
||||
cplx_sub(diff, d0[i], d1[i]);
|
||||
cplx_add(d0[i], d0[i], d1[i]);
|
||||
cplx_mul(d1[i], diff, powom);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Do two layers of itwiddle (i.e. split).
|
||||
* Input/output: d0,d1,d2,d3 of length h
|
||||
* Algo:
|
||||
* itwiddle(d0,d1,om[0]),itwiddle(d2,d3,i.om[0])
|
||||
* itwiddle(d0,d2,om[1]),itwiddle(d1,d3,om[1])
|
||||
* @param h number of "coefficients" h >= 1
|
||||
* @param data 4h complex coefficients interleaved and 256b aligned
|
||||
* @param powom om[0] (re,im) and om[1] where om[1]=om[0]^2
|
||||
*/
|
||||
EXPORT void cplx_bisplit_fft_ref(int32_t h, CPLX* data, const CPLX powom[2]) {
|
||||
CPLX* d0 = data;
|
||||
CPLX* d2 = data + 2 * h;
|
||||
const CPLX* om0 = powom;
|
||||
CPLX iom0;
|
||||
iom0[0] = powom[0][1];
|
||||
iom0[1] = -powom[0][0];
|
||||
const CPLX* om1 = powom + 1;
|
||||
cplx_split_fft_ref(h, d0, *om0);
|
||||
cplx_split_fft_ref(h, d2, iom0);
|
||||
cplx_split_fft_ref(2 * h, d0, *om1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Input: Q(y),Q(-y)
|
||||
* Output: P_0(z),P_1(z)
|
||||
* where Q(X)=P_0(X^2)+X.P_1(X^2) and y^2 = z
|
||||
* @param data 2 complexes coefficients interleaved and 256b aligned
|
||||
* @param powom (z,-z) interleaved: (zre,zim,-zre,-zim)
|
||||
*/
|
||||
void split_fft_last_ref(CPLX* data, const CPLX powom) {
|
||||
CPLX diff;
|
||||
cplx_sub(diff, data[0], data[1]);
|
||||
cplx_add(data[0], data[0], data[1]);
|
||||
cplx_mul(data[1], diff, powom);
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
#include <errno.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "../commons_private.h"
|
||||
#include "cplx_fft_internal.h"
|
||||
#include "cplx_fft_private.h"
|
||||
|
||||
EXPORT void cplx_from_znx32_ref(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) {
|
||||
const uint32_t m = precomp->m;
|
||||
const int32_t* inre = x;
|
||||
const int32_t* inim = x + m;
|
||||
CPLX* out = r;
|
||||
for (uint32_t i = 0; i < m; ++i) {
|
||||
out[i][0] = (double)inre[i];
|
||||
out[i][1] = (double)inim[i];
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void cplx_from_tnx32_ref(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) {
|
||||
static const double _2p32 = 1. / (INT64_C(1) << 32);
|
||||
const uint32_t m = precomp->m;
|
||||
const int32_t* inre = x;
|
||||
const int32_t* inim = x + m;
|
||||
CPLX* out = r;
|
||||
for (uint32_t i = 0; i < m; ++i) {
|
||||
out[i][0] = ((double)inre[i]) * _2p32;
|
||||
out[i][1] = ((double)inim[i]) * _2p32;
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void cplx_to_tnx32_ref(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* r, const void* x) {
|
||||
static const double _2p32 = (INT64_C(1) << 32);
|
||||
const uint32_t m = precomp->m;
|
||||
double factor = _2p32 / precomp->divisor;
|
||||
int32_t* outre = r;
|
||||
int32_t* outim = r + m;
|
||||
const CPLX* in = x;
|
||||
// Note: this formula will only work if abs(in) < 2^32
|
||||
for (uint32_t i = 0; i < m; ++i) {
|
||||
outre[i] = (int32_t)(int64_t)(rint(in[i][0] * factor));
|
||||
outim[i] = (int32_t)(int64_t)(rint(in[i][1] * factor));
|
||||
}
|
||||
}
|
||||
|
||||
void* init_cplx_from_znx32_precomp(CPLX_FROM_ZNX32_PRECOMP* res, uint32_t m) {
|
||||
res->m = m;
|
||||
if (CPU_SUPPORTS("avx2")) {
|
||||
if (m >= 8) {
|
||||
res->function = cplx_from_znx32_avx2_fma;
|
||||
} else {
|
||||
res->function = cplx_from_znx32_ref;
|
||||
}
|
||||
} else {
|
||||
res->function = cplx_from_znx32_ref;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
CPLX_FROM_ZNX32_PRECOMP* new_cplx_from_znx32_precomp(uint32_t m) {
|
||||
CPLX_FROM_ZNX32_PRECOMP* res = malloc(sizeof(CPLX_FROM_ZNX32_PRECOMP));
|
||||
if (!res) return spqlios_error(strerror(errno));
|
||||
return spqlios_keep_or_free(res, init_cplx_from_znx32_precomp(res, m));
|
||||
}
|
||||
|
||||
void* init_cplx_from_tnx32_precomp(CPLX_FROM_TNX32_PRECOMP* res, uint32_t m) {
|
||||
res->m = m;
|
||||
if (CPU_SUPPORTS("avx2")) {
|
||||
if (m >= 8) {
|
||||
res->function = cplx_from_tnx32_avx2_fma;
|
||||
} else {
|
||||
res->function = cplx_from_tnx32_ref;
|
||||
}
|
||||
} else {
|
||||
res->function = cplx_from_tnx32_ref;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
CPLX_FROM_TNX32_PRECOMP* new_cplx_from_tnx32_precomp(uint32_t m) {
|
||||
CPLX_FROM_TNX32_PRECOMP* res = malloc(sizeof(CPLX_FROM_TNX32_PRECOMP));
|
||||
if (!res) return spqlios_error(strerror(errno));
|
||||
return spqlios_keep_or_free(res, init_cplx_from_tnx32_precomp(res, m));
|
||||
}
|
||||
|
||||
void* init_cplx_to_tnx32_precomp(CPLX_TO_TNX32_PRECOMP* res, uint32_t m, double divisor, uint32_t log2overhead) {
|
||||
if (is_not_pow2_double(&divisor)) return spqlios_error("divisor must be a power of 2");
|
||||
if (m & (m - 1)) return spqlios_error("m must be a power of 2");
|
||||
if (log2overhead > 52) return spqlios_error("log2overhead is too large");
|
||||
res->m = m;
|
||||
res->divisor = divisor;
|
||||
if (CPU_SUPPORTS("avx2")) {
|
||||
if (log2overhead <= 18) {
|
||||
if (m >= 8) {
|
||||
res->function = cplx_to_tnx32_avx2_fma;
|
||||
} else {
|
||||
res->function = cplx_to_tnx32_ref;
|
||||
}
|
||||
} else {
|
||||
res->function = cplx_to_tnx32_ref;
|
||||
}
|
||||
} else {
|
||||
res->function = cplx_to_tnx32_ref;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
EXPORT CPLX_TO_TNX32_PRECOMP* new_cplx_to_tnx32_precomp(uint32_t m, double divisor, uint32_t log2overhead) {
|
||||
CPLX_TO_TNX32_PRECOMP* res = malloc(sizeof(CPLX_TO_TNX32_PRECOMP));
|
||||
if (!res) return spqlios_error(strerror(errno));
|
||||
return spqlios_keep_or_free(res, init_cplx_to_tnx32_precomp(res, m, divisor, log2overhead));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Simpler API for the znx32 to cplx conversion.
|
||||
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||
EXPORT void cplx_from_znx32_simple(uint32_t m, void* r, const int32_t* x) {
|
||||
// not checking for log2bound which is not relevant here
|
||||
static CPLX_FROM_ZNX32_PRECOMP precomp[32];
|
||||
CPLX_FROM_ZNX32_PRECOMP* p = precomp + log2m(m);
|
||||
if (!p->function) {
|
||||
if (!init_cplx_from_znx32_precomp(p, m)) abort();
|
||||
}
|
||||
p->function(p, r, x);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Simpler API for the tnx32 to cplx conversion.
|
||||
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||
EXPORT void cplx_from_tnx32_simple(uint32_t m, void* r, const int32_t* x) {
|
||||
static CPLX_FROM_TNX32_PRECOMP precomp[32];
|
||||
CPLX_FROM_TNX32_PRECOMP* p = precomp + log2m(m);
|
||||
if (!p->function) {
|
||||
if (!init_cplx_from_tnx32_precomp(p, m)) abort();
|
||||
}
|
||||
p->function(p, r, x);
|
||||
}
|
||||
/**
|
||||
* @brief Simpler API for the cplx to tnx32 conversion.
|
||||
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||
EXPORT void cplx_to_tnx32_simple(uint32_t m, double divisor, uint32_t log2overhead, int32_t* r, const void* x) {
|
||||
struct LAST_CPLX_TO_TNX32_PRECOMP {
|
||||
CPLX_TO_TNX32_PRECOMP p;
|
||||
double last_divisor;
|
||||
double last_log2over;
|
||||
};
|
||||
static __thread struct LAST_CPLX_TO_TNX32_PRECOMP precomp[32];
|
||||
struct LAST_CPLX_TO_TNX32_PRECOMP* p = precomp + log2m(m);
|
||||
if (!p->p.function || divisor != p->last_divisor || log2overhead != p->last_log2over) {
|
||||
memset(p, 0, sizeof(*p));
|
||||
if (!init_cplx_to_tnx32_precomp(&p->p, m, divisor, log2overhead)) abort();
|
||||
p->last_divisor = divisor;
|
||||
p->last_log2over = log2overhead;
|
||||
}
|
||||
p->p.function(&p->p, r, x);
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
#include <immintrin.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "cplx_fft_internal.h"
|
||||
#include "cplx_fft_private.h"
|
||||
|
||||
typedef int32_t I8MEM[8];
|
||||
typedef double D4MEM[4];
|
||||
|
||||
__always_inline void cplx_from_any_fma(uint64_t m, void* r, const int32_t* x, const __m256i C, const __m256d R) {
|
||||
const __m256i S = _mm256_set1_epi32(0x80000000);
|
||||
const I8MEM* inre = (I8MEM*)(x);
|
||||
const I8MEM* inim = (I8MEM*)(x + m);
|
||||
D4MEM* out = (D4MEM*)r;
|
||||
const uint64_t ms8 = m / 8;
|
||||
for (uint32_t i = 0; i < ms8; ++i) {
|
||||
__m256i rea = _mm256_loadu_si256((__m256i*)inre[0]);
|
||||
__m256i ima = _mm256_loadu_si256((__m256i*)inim[0]);
|
||||
rea = _mm256_add_epi32(rea, S);
|
||||
ima = _mm256_add_epi32(ima, S);
|
||||
__m256i tmpa = _mm256_unpacklo_epi32(rea, ima);
|
||||
__m256i tmpc = _mm256_unpackhi_epi32(rea, ima);
|
||||
__m256i cpla = _mm256_permute2x128_si256(tmpa, tmpc, 0x20);
|
||||
__m256i cplc = _mm256_permute2x128_si256(tmpa, tmpc, 0x31);
|
||||
tmpa = _mm256_unpacklo_epi32(cpla, C);
|
||||
__m256i tmpb = _mm256_unpackhi_epi32(cpla, C);
|
||||
tmpc = _mm256_unpacklo_epi32(cplc, C);
|
||||
__m256i tmpd = _mm256_unpackhi_epi32(cplc, C);
|
||||
cpla = _mm256_permute2x128_si256(tmpa, tmpb, 0x20);
|
||||
__m256i cplb = _mm256_permute2x128_si256(tmpa, tmpb, 0x31);
|
||||
cplc = _mm256_permute2x128_si256(tmpc, tmpd, 0x20);
|
||||
__m256i cpld = _mm256_permute2x128_si256(tmpc, tmpd, 0x31);
|
||||
__m256d dcpla = _mm256_sub_pd(_mm256_castsi256_pd(cpla), R);
|
||||
__m256d dcplb = _mm256_sub_pd(_mm256_castsi256_pd(cplb), R);
|
||||
__m256d dcplc = _mm256_sub_pd(_mm256_castsi256_pd(cplc), R);
|
||||
__m256d dcpld = _mm256_sub_pd(_mm256_castsi256_pd(cpld), R);
|
||||
_mm256_storeu_pd(out[0], dcpla);
|
||||
_mm256_storeu_pd(out[1], dcplb);
|
||||
_mm256_storeu_pd(out[2], dcplc);
|
||||
_mm256_storeu_pd(out[3], dcpld);
|
||||
inre += 1;
|
||||
inim += 1;
|
||||
out += 4;
|
||||
}
|
||||
}
|
||||
|
||||
EXPORT void cplx_from_znx32_avx2_fma(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) {
|
||||
// note: the hex code of 2^31 + 2^52 is 0x4330000080000000
|
||||
const __m256i C = _mm256_set1_epi32(0x43300000);
|
||||
const __m256d R = _mm256_set1_pd((INT64_C(1) << 31) + (INT64_C(1) << 52));
|
||||
// double XX = INT64_C(1) + (INT64_C(1)<<31) + (INT64_C(1)<<52);
|
||||
// printf("\n\n%016lx\n", *(uint64_t*)&XX);
|
||||
// abort();
|
||||
const uint64_t m = precomp->m;
|
||||
cplx_from_any_fma(m, r, x, C, R);
|
||||
}
|
||||
|
||||
EXPORT void cplx_from_tnx32_avx2_fma(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) {
|
||||
// note: the hex code of 2^-1 + 2^30 is 0x4130000080000000
|
||||
const __m256i C = _mm256_set1_epi32(0x41300000);
|
||||
const __m256d R = _mm256_set1_pd(0.5 + (INT64_C(1) << 20));
|
||||
// double XX = (double)(INT64_C(1) + (INT64_C(1)<<31) + (INT64_C(1)<<52))/(INT64_C(1)<<32);
|
||||
// printf("\n\n%016lx\n", *(uint64_t*)&XX);
|
||||
// abort();
|
||||
const uint64_t m = precomp->m;
|
||||
cplx_from_any_fma(m, r, x, C, R);
|
||||
}
|
||||
|
||||
EXPORT void cplx_to_tnx32_avx2_fma(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* r, const void* x) {
|
||||
const __m256d R = _mm256_set1_pd((0.5 + (INT64_C(3) << 19)) * precomp->divisor);
|
||||
const __m256i MASK = _mm256_set1_epi64x(0xFFFFFFFFUL);
|
||||
const __m256i S = _mm256_set1_epi32(0x80000000);
|
||||
// const __m256i IDX = _mm256_set_epi32(0,4,1,5,2,6,3,7);
|
||||
const __m256i IDX = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
|
||||
const uint64_t m = precomp->m;
|
||||
const uint64_t ms8 = m / 8;
|
||||
I8MEM* outre = (I8MEM*)r;
|
||||
I8MEM* outim = (I8MEM*)(r + m);
|
||||
const D4MEM* in = x;
|
||||
// Note: this formula will only work if abs(in) < 2^32
|
||||
for (uint32_t i = 0; i < ms8; ++i) {
|
||||
__m256d cpla = _mm256_loadu_pd(in[0]);
|
||||
__m256d cplb = _mm256_loadu_pd(in[1]);
|
||||
__m256d cplc = _mm256_loadu_pd(in[2]);
|
||||
__m256d cpld = _mm256_loadu_pd(in[3]);
|
||||
__m256i icpla = _mm256_castpd_si256(_mm256_add_pd(cpla, R));
|
||||
__m256i icplb = _mm256_castpd_si256(_mm256_add_pd(cplb, R));
|
||||
__m256i icplc = _mm256_castpd_si256(_mm256_add_pd(cplc, R));
|
||||
__m256i icpld = _mm256_castpd_si256(_mm256_add_pd(cpld, R));
|
||||
icpla = _mm256_or_si256(_mm256_and_si256(icpla, MASK), _mm256_slli_epi64(icplb, 32));
|
||||
icplc = _mm256_or_si256(_mm256_and_si256(icplc, MASK), _mm256_slli_epi64(icpld, 32));
|
||||
icpla = _mm256_xor_si256(icpla, S);
|
||||
icplc = _mm256_xor_si256(icplc, S);
|
||||
__m256i re = _mm256_unpacklo_epi64(icpla, icplc);
|
||||
__m256i im = _mm256_unpackhi_epi64(icpla, icplc);
|
||||
re = _mm256_permutevar8x32_epi32(re, IDX);
|
||||
im = _mm256_permutevar8x32_epi32(im, IDX);
|
||||
_mm256_storeu_si256((__m256i*)outre[0], re);
|
||||
_mm256_storeu_si256((__m256i*)outim[0], im);
|
||||
outre += 1;
|
||||
outim += 1;
|
||||
in += 4;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
#include "cplx_fft_internal.h"
|
||||
#include "cplx_fft_private.h"
|
||||
|
||||
EXPORT void cplx_from_znx32(const CPLX_FROM_ZNX32_PRECOMP* tables, void* r, const int32_t* a) {
|
||||
tables->function(tables, r, a);
|
||||
}
|
||||
EXPORT void cplx_from_tnx32(const CPLX_FROM_TNX32_PRECOMP* tables, void* r, const int32_t* a) {
|
||||
tables->function(tables, r, a);
|
||||
}
|
||||
EXPORT void cplx_to_tnx32(const CPLX_TO_TNX32_PRECOMP* tables, int32_t* r, const void* a) {
|
||||
tables->function(tables, r, a);
|
||||
}
|
||||
EXPORT void cplx_fftvec_mul(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b) {
|
||||
tables->function(tables, r, a, b);
|
||||
}
|
||||
EXPORT void cplx_fftvec_addmul(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b) {
|
||||
tables->function(tables, r, a, b);
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
#include "cplx_fft_internal.h"
|
||||
#include "cplx_fft_private.h"
|
||||
|
||||
EXPORT void cplx_fftvec_addmul_fma(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b) {
|
||||
UNDEFINED(); // not defined for non x86 targets
|
||||
}
|
||||
EXPORT void cplx_fftvec_mul_fma(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b) {
|
||||
UNDEFINED();
|
||||
}
|
||||
EXPORT void cplx_fftvec_addmul_sse(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) {
|
||||
UNDEFINED();
|
||||
}
|
||||
EXPORT void cplx_fftvec_addmul_avx512(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a,
|
||||
const void* b) {
|
||||
UNDEFINED();
|
||||
}
|
||||
EXPORT void cplx_fft16_avx_fma(void* data, const void* omega) { UNDEFINED(); }
|
||||
EXPORT void cplx_ifft16_avx_fma(void* data, const void* omega) { UNDEFINED(); }
|
||||
EXPORT void cplx_from_znx32_avx2_fma(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) { UNDEFINED(); }
|
||||
EXPORT void cplx_from_tnx32_avx2_fma(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) { UNDEFINED(); }
|
||||
EXPORT void cplx_to_tnx32_avx2_fma(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* x, const void* c) { UNDEFINED(); }
|
||||
EXPORT void cplx_fft_avx2_fma(const CPLX_FFT_PRECOMP* tables, void* data){UNDEFINED()} EXPORT
|
||||
void cplx_ifft_avx2_fma(const CPLX_IFFT_PRECOMP* itables, void* data){UNDEFINED()} EXPORT
|
||||
void cplx_fftvec_twiddle_fma(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om){
|
||||
UNDEFINED()} EXPORT void cplx_fftvec_twiddle_avx512(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b,
|
||||
const void* om){UNDEFINED()} EXPORT
|
||||
void cplx_fftvec_bitwiddle_fma(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice,
|
||||
const void* om){UNDEFINED()} EXPORT
|
||||
void cplx_fftvec_bitwiddle_avx512(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice,
|
||||
const void* om){UNDEFINED()}
|
||||
|
||||
// DEPRECATED?
|
||||
EXPORT void cplx_fftvec_add_fma(uint32_t m, void* r, const void* a, const void* b){UNDEFINED()} EXPORT
|
||||
void cplx_fftvec_sub2_to_fma(uint32_t m, void* r, const void* a, const void* b){UNDEFINED()} EXPORT
|
||||
void cplx_fftvec_copy_fma(uint32_t m, void* r, const void* a) {
|
||||
UNDEFINED()
|
||||
}
|
||||
|
||||
// executors
|
||||
// EXPORT void cplx_ifft(const CPLX_IFFT_PRECOMP* itables, void* data) {
|
||||
// itables->function(itables, data);
|
||||
//}
|
||||
// EXPORT void cplx_fft(const CPLX_FFT_PRECOMP* tables, void* data) { tables->function(tables, data); }
|
||||
@@ -0,0 +1,221 @@
|
||||
#ifndef SPQLIOS_CPLX_FFT_H
|
||||
#define SPQLIOS_CPLX_FFT_H
|
||||
|
||||
#include "../commons.h"
|
||||
|
||||
typedef struct cplx_fft_precomp CPLX_FFT_PRECOMP;
|
||||
typedef struct cplx_ifft_precomp CPLX_IFFT_PRECOMP;
|
||||
typedef struct cplx_mul_precomp CPLX_FFTVEC_MUL_PRECOMP;
|
||||
typedef struct cplx_addmul_precomp CPLX_FFTVEC_ADDMUL_PRECOMP;
|
||||
typedef struct cplx_from_znx32_precomp CPLX_FROM_ZNX32_PRECOMP;
|
||||
typedef struct cplx_from_tnx32_precomp CPLX_FROM_TNX32_PRECOMP;
|
||||
typedef struct cplx_to_tnx32_precomp CPLX_TO_TNX32_PRECOMP;
|
||||
typedef struct cplx_to_znx32_precomp CPLX_TO_ZNX32_PRECOMP;
|
||||
typedef struct cplx_from_rnx64_precomp CPLX_FROM_RNX64_PRECOMP;
|
||||
typedef struct cplx_to_rnx64_precomp CPLX_TO_RNX64_PRECOMP;
|
||||
typedef struct cplx_round_to_rnx64_precomp CPLX_ROUND_TO_RNX64_PRECOMP;
|
||||
|
||||
/**
|
||||
* @brief precomputes fft tables.
|
||||
* The FFT tables contains a constant section that is required for efficient FFT operations in dimension nn.
|
||||
* The resulting pointer is to be passed as "tables" argument to any call to the fft function.
|
||||
* The user can optionnally allocate zero or more computation buffers, which are scratch spaces that are contiguous to
|
||||
* the constant tables in memory, and allow for more efficient operations. It is the user's responsibility to ensure
|
||||
* that each of those buffers are never used simultaneously by two ffts on different threads at the same time. The fft
|
||||
* table must be deleted by delete_fft_precomp after its last usage.
|
||||
*/
|
||||
EXPORT CPLX_FFT_PRECOMP* new_cplx_fft_precomp(uint32_t m, uint32_t num_buffers);
|
||||
|
||||
/**
|
||||
* @brief gets the address of a fft buffer allocated during new_fft_precomp.
|
||||
* This buffer can be used as data pointer in subsequent calls to fft,
|
||||
* and does not need to be released afterwards.
|
||||
*/
|
||||
EXPORT void* cplx_fft_precomp_get_buffer(const CPLX_FFT_PRECOMP* tables, uint32_t buffer_index);
|
||||
|
||||
/**
|
||||
* @brief allocates a new fft buffer.
|
||||
* This buffer can be used as data pointer in subsequent calls to fft,
|
||||
* and must be deleted afterwards by calling delete_fft_buffer.
|
||||
*/
|
||||
EXPORT void* new_cplx_fft_buffer(uint32_t m);
|
||||
|
||||
/**
|
||||
* @brief allocates a new fft buffer.
|
||||
* This buffer can be used as data pointer in subsequent calls to fft,
|
||||
* and must be deleted afterwards by calling delete_fft_buffer.
|
||||
*/
|
||||
EXPORT void delete_cplx_fft_buffer(void* buffer);
|
||||
|
||||
/**
|
||||
* @brief deallocates a fft table and all its built-in buffers.
|
||||
*/
|
||||
#define delete_cplx_fft_precomp free
|
||||
|
||||
/**
|
||||
* @brief computes a direct fft in-place over data.
|
||||
*/
|
||||
EXPORT void cplx_fft(const CPLX_FFT_PRECOMP* tables, void* data);
|
||||
|
||||
EXPORT CPLX_IFFT_PRECOMP* new_cplx_ifft_precomp(uint32_t m, uint32_t num_buffers);
|
||||
EXPORT void* cplx_ifft_precomp_get_buffer(const CPLX_IFFT_PRECOMP* tables, uint32_t buffer_index);
|
||||
EXPORT void cplx_ifft(const CPLX_IFFT_PRECOMP* tables, void* data);
|
||||
#define delete_cplx_ifft_precomp free
|
||||
|
||||
EXPORT CPLX_FFTVEC_MUL_PRECOMP* new_cplx_fftvec_mul_precomp(uint32_t m);
|
||||
EXPORT void cplx_fftvec_mul(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b);
|
||||
#define delete_cplx_fftvec_mul_precomp free
|
||||
|
||||
EXPORT CPLX_FFTVEC_ADDMUL_PRECOMP* new_cplx_fftvec_addmul_precomp(uint32_t m);
|
||||
EXPORT void cplx_fftvec_addmul(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b);
|
||||
#define delete_cplx_fftvec_addmul_precomp free
|
||||
|
||||
/**
|
||||
* @brief prepares a conversion from ZnX to the cplx layout.
|
||||
* All the coefficients must be strictly lower than 2^log2bound in absolute value. Any attempt to use
|
||||
* this function on a larger coefficient is undefined behaviour. The resulting precomputed data must
|
||||
* be freed with `new_cplx_from_znx32_precomp`
|
||||
* @param m the target complex dimension m from C[X] mod X^m-i. Note that the inputs have n=2m
|
||||
* int32 coefficients in natural order modulo X^n+1
|
||||
* @param log2bound bound on the input coefficients. Must be between 0 and 32
|
||||
*/
|
||||
EXPORT CPLX_FROM_ZNX32_PRECOMP* new_cplx_from_znx32_precomp(uint32_t m);
|
||||
/**
|
||||
* @brief converts from ZnX to the cplx layout.
|
||||
* @param tables precomputed data obtained by new_cplx_from_znx32_precomp.
|
||||
* @param r resulting array of m complexes coefficients mod X^m-i
|
||||
* @param x input array of n bounded integer coefficients mod X^n+1
|
||||
*/
|
||||
EXPORT void cplx_from_znx32(const CPLX_FROM_ZNX32_PRECOMP* tables, void* r, const int32_t* a);
|
||||
/** @brief frees a precomputed conversion data initialized with new_cplx_from_znx32_precomp. */
|
||||
#define delete_cplx_from_znx32_precomp free
|
||||
|
||||
/**
|
||||
* @brief prepares a conversion from TnX to the cplx layout.
|
||||
* @param m the target complex dimension m from C[X] mod X^m-i. Note that the inputs have n=2m
|
||||
* torus32 coefficients. The resulting precomputed data must
|
||||
* be freed with `delete_cplx_from_tnx32_precomp`
|
||||
*/
|
||||
EXPORT CPLX_FROM_TNX32_PRECOMP* new_cplx_from_tnx32_precomp(uint32_t m);
|
||||
/**
|
||||
* @brief converts from TnX to the cplx layout.
|
||||
* @param tables precomputed data obtained by new_cplx_from_tnx32_precomp.
|
||||
* @param r resulting array of m complexes coefficients mod X^m-i
|
||||
* @param x input array of n torus32 coefficients mod X^n+1
|
||||
*/
|
||||
EXPORT void cplx_from_tnx32(const CPLX_FROM_TNX32_PRECOMP* tables, void* r, const int32_t* a);
|
||||
/** @brief frees a precomputed conversion data initialized with new_cplx_from_tnx32_precomp. */
|
||||
#define delete_cplx_from_tnx32_precomp free
|
||||
|
||||
/**
|
||||
* @brief prepares a rescale and conversion from the cplx layout to TnX.
|
||||
* @param m the target complex dimension m from C[X] mod X^m-i. Note that the outputs have n=2m
|
||||
* torus32 coefficients.
|
||||
* @param divisor must be a power of two. The inputs are rescaled by divisor before being reduced modulo 1.
|
||||
* Remember that the output of an iFFT must be divided by m.
|
||||
* @param log2overhead all inputs absolute values must be within divisor.2^log2overhead.
|
||||
* For any inputs outside of these bounds, the conversion is undefined behaviour.
|
||||
* The maximum supported log2overhead is 52, and the algorithm is faster for log2overhead=18.
|
||||
*/
|
||||
EXPORT CPLX_TO_TNX32_PRECOMP* new_cplx_to_tnx32_precomp(uint32_t m, double divisor, uint32_t log2overhead);
|
||||
/**
|
||||
* @brief rescale, converts and reduce mod 1 from cplx layout to torus32.
|
||||
* @param tables precomputed data obtained by new_cplx_from_tnx32_precomp.
|
||||
* @param r resulting array of n torus32 coefficients mod X^n+1
|
||||
* @param x input array of m cplx coefficients mod X^m-i
|
||||
*/
|
||||
EXPORT void cplx_to_tnx32(const CPLX_TO_TNX32_PRECOMP* tables, int32_t* r, const void* a);
|
||||
#define delete_cplx_to_tnx32_precomp free
|
||||
|
||||
EXPORT CPLX_TO_ZNX32_PRECOMP* new_cplx_to_znx32_precomp(uint32_t m, double divisor);
|
||||
EXPORT void cplx_to_znx32(const CPLX_TO_ZNX32_PRECOMP* precomp, int32_t* r, const void* x);
|
||||
#define delete_cplx_to_znx32_simple free
|
||||
|
||||
EXPORT CPLX_FROM_RNX64_PRECOMP* new_cplx_from_rnx64_simple(uint32_t m);
|
||||
EXPORT void cplx_from_rnx64(const CPLX_FROM_RNX64_PRECOMP* precomp, void* r, const double* x);
|
||||
#define delete_cplx_from_rnx64_simple free
|
||||
|
||||
EXPORT CPLX_TO_RNX64_PRECOMP* new_cplx_to_rnx64(uint32_t m, double divisor);
|
||||
EXPORT void cplx_to_rnx64(const CPLX_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
|
||||
#define delete_cplx_round_to_rnx64_simple free
|
||||
|
||||
EXPORT CPLX_ROUND_TO_RNX64_PRECOMP* new_cplx_round_to_rnx64(uint32_t m, double divisor, uint32_t log2bound);
|
||||
EXPORT void cplx_round_to_rnx64(const CPLX_ROUND_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
|
||||
#define delete_cplx_round_to_rnx64_simple free
|
||||
|
||||
/**
|
||||
* @brief Simpler API for the fft function.
|
||||
* For each dimension, the precomputed tables for this dimension are generated automatically.
|
||||
* It is advised to do one dry-run per desired dimension before using in a multithread environment */
|
||||
EXPORT void cplx_fft_simple(uint32_t m, void* data);
|
||||
/**
|
||||
* @brief Simpler API for the ifft function.
|
||||
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||
* It is advised to do one dry-run call per desired dimension in the main thread before using in a multithread
|
||||
* environment */
|
||||
EXPORT void cplx_ifft_simple(uint32_t m, void* data);
|
||||
/**
|
||||
* @brief Simpler API for the fftvec multiplication function.
|
||||
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||
EXPORT void cplx_fftvec_mul_simple(uint32_t m, void* r, const void* a, const void* b);
|
||||
/**
|
||||
* @brief Simpler API for the fftvec addmul function.
|
||||
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||
EXPORT void cplx_fftvec_addmul_simple(uint32_t m, void* r, const void* a, const void* b);
|
||||
/**
|
||||
* @brief Simpler API for the znx32 to cplx conversion.
|
||||
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||
EXPORT void cplx_from_znx32_simple(uint32_t m, void* r, const int32_t* x);
|
||||
/**
|
||||
* @brief Simpler API for the tnx32 to cplx conversion.
|
||||
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||
EXPORT void cplx_from_tnx32_simple(uint32_t m, void* r, const int32_t* x);
|
||||
/**
|
||||
* @brief Simpler API for the cplx to tnx32 conversion.
|
||||
* For each dimension, the precomputed tables for this dimension are generated automatically the first time.
|
||||
* It is advised to do one dry-run call per desired dimension before using in a multithread environment */
|
||||
EXPORT void cplx_to_tnx32_simple(uint32_t m, double divisor, uint32_t log2overhead, int32_t* r, const void* x);
|
||||
|
||||
/**
|
||||
* @brief converts, divides and round from cplx to znx32 (simple API)
|
||||
* @param m the complex dimension
|
||||
* @param divisor the divisor: a power of two, often m after an ifft
|
||||
* @param r the result: must be a double array of size 2m. r must be distinct from x
|
||||
* @param x the input: must hold m complex numbers.
|
||||
*/
|
||||
EXPORT void cplx_to_znx32_simple(uint32_t m, double divisor, int32_t* r, const void* x);
|
||||
|
||||
/**
|
||||
* @brief converts from rnx64 to cplx (simple API)
|
||||
* The bound on the output is assumed to be within ]2^-31,2^31[.
|
||||
* Any coefficient that would fall outside this range is undefined behaviour.
|
||||
* @param m the complex dimension
|
||||
* @param r the result: must be an array of m complex numbers. r must be distinct from x
|
||||
* @param x the input: must be an array of 2m doubles.
|
||||
*/
|
||||
EXPORT void cplx_from_rnx64_simple(uint32_t m, void* r, const double* x);
|
||||
|
||||
/**
|
||||
* @brief converts, divides from cplx to rnx64 (simple API)
|
||||
* @param m the complex dimension
|
||||
* @param divisor the divisor: a power of two, often m after an ifft
|
||||
* @param r the result: must be a double array of size 2m. r must be distinct from x
|
||||
* @param x the input: must hold m complex numbers.
|
||||
*/
|
||||
EXPORT void cplx_to_rnx64_simple(uint32_t m, double divisor, double* r, const void* x);
|
||||
|
||||
/**
|
||||
* @brief converts, divides and round to integer from cplx to rnx32 (simple API)
|
||||
* @param m the complex dimension
|
||||
* @param divisor the divisor: a power of two, often m after an ifft
|
||||
* @param log2bound a guarantee on the log2bound of the output. log2bound<=48 will use a more efficient algorithm.
|
||||
* @param r the result: must be a double array of size 2m. r must be distinct from x
|
||||
* @param x the input: must hold m complex numbers.
|
||||
*/
|
||||
EXPORT void cplx_round_to_rnx64_simple(uint32_t m, double divisor, uint32_t log2bound, double* r, const void* x);
|
||||
|
||||
#endif // SPQLIOS_CPLX_FFT_H
|
||||
@@ -0,0 +1,156 @@
|
||||
# shifted FFT over X^16-i
|
||||
# 1st argument (rdi) contains 16 complexes
|
||||
# 2nd argument (rsi) contains: 8 complexes
|
||||
# omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma
|
||||
# alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta)
|
||||
# j = sqrt(i), k=sqrt(j)
|
||||
.globl cplx_fft16_avx_fma
|
||||
cplx_fft16_avx_fma:
|
||||
vmovupd (%rdi),%ymm8
|
||||
vmovupd 0x20(%rdi),%ymm9
|
||||
vmovupd 0x40(%rdi),%ymm10
|
||||
vmovupd 0x60(%rdi),%ymm11
|
||||
vmovupd 0x80(%rdi),%ymm12
|
||||
vmovupd 0xa0(%rdi),%ymm13
|
||||
vmovupd 0xc0(%rdi),%ymm14
|
||||
vmovupd 0xe0(%rdi),%ymm15
|
||||
|
||||
.first_pass:
|
||||
vmovupd (%rsi),%xmm0 /* omri */
|
||||
vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */
|
||||
vshufpd $5, %ymm12, %ymm12, %ymm4
|
||||
vshufpd $5, %ymm13, %ymm13, %ymm5
|
||||
vshufpd $5, %ymm14, %ymm14, %ymm6
|
||||
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||
vmulpd %ymm4,%ymm1,%ymm4
|
||||
vmulpd %ymm5,%ymm1,%ymm5
|
||||
vmulpd %ymm6,%ymm1,%ymm6
|
||||
vmulpd %ymm7,%ymm1,%ymm7
|
||||
vfmaddsub231pd %ymm12, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm12) +/- ymm4
|
||||
vfmaddsub231pd %ymm13, %ymm0, %ymm5
|
||||
vfmaddsub231pd %ymm14, %ymm0, %ymm6
|
||||
vfmaddsub231pd %ymm15, %ymm0, %ymm7
|
||||
vsubpd %ymm4,%ymm8,%ymm12
|
||||
vsubpd %ymm5,%ymm9,%ymm13
|
||||
vsubpd %ymm6,%ymm10,%ymm14
|
||||
vsubpd %ymm7,%ymm11,%ymm15
|
||||
vaddpd %ymm4,%ymm8,%ymm8
|
||||
vaddpd %ymm5,%ymm9,%ymm9
|
||||
vaddpd %ymm6,%ymm10,%ymm10
|
||||
vaddpd %ymm7,%ymm11,%ymm11
|
||||
|
||||
.second_pass:
|
||||
vmovupd 16(%rsi),%xmm0 /* omri */
|
||||
vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */
|
||||
vshufpd $5, %ymm10, %ymm10, %ymm4
|
||||
vshufpd $5, %ymm11, %ymm11, %ymm5
|
||||
vshufpd $5, %ymm14, %ymm14, %ymm6
|
||||
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||
vmulpd %ymm4,%ymm1,%ymm4
|
||||
vmulpd %ymm5,%ymm1,%ymm5
|
||||
vmulpd %ymm6,%ymm0,%ymm6
|
||||
vmulpd %ymm7,%ymm0,%ymm7
|
||||
vfmaddsub231pd %ymm10, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm10) +/- ymm4
|
||||
vfmaddsub231pd %ymm11, %ymm0, %ymm5
|
||||
vfmsubadd231pd %ymm14, %ymm1, %ymm6
|
||||
vfmsubadd231pd %ymm15, %ymm1, %ymm7
|
||||
vsubpd %ymm4,%ymm8,%ymm10
|
||||
vsubpd %ymm5,%ymm9,%ymm11
|
||||
vaddpd %ymm6,%ymm12,%ymm14
|
||||
vaddpd %ymm7,%ymm13,%ymm15
|
||||
vaddpd %ymm4,%ymm8,%ymm8
|
||||
vaddpd %ymm5,%ymm9,%ymm9
|
||||
vsubpd %ymm6,%ymm12,%ymm12
|
||||
vsubpd %ymm7,%ymm13,%ymm13
|
||||
|
||||
.third_pass:
|
||||
vmovupd 32(%rsi),%xmm0 /* gamma */
|
||||
vmovupd 48(%rsi),%xmm2 /* delta */
|
||||
vinsertf128 $1, %xmm0, %ymm0, %ymm0
|
||||
vinsertf128 $1, %xmm2, %ymm2, %ymm2
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */
|
||||
vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */
|
||||
vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */
|
||||
vshufpd $5, %ymm9, %ymm9, %ymm4
|
||||
vshufpd $5, %ymm11, %ymm11, %ymm5
|
||||
vshufpd $5, %ymm13, %ymm13, %ymm6
|
||||
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||
vmulpd %ymm4,%ymm1,%ymm4
|
||||
vmulpd %ymm5,%ymm0,%ymm5
|
||||
vmulpd %ymm6,%ymm3,%ymm6
|
||||
vmulpd %ymm7,%ymm2,%ymm7
|
||||
vfmaddsub231pd %ymm9, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm10) +/- ymm4
|
||||
vfmsubadd231pd %ymm11, %ymm1, %ymm5
|
||||
vfmaddsub231pd %ymm13, %ymm2, %ymm6
|
||||
vfmsubadd231pd %ymm15, %ymm3, %ymm7
|
||||
vsubpd %ymm4,%ymm8,%ymm9
|
||||
vaddpd %ymm5,%ymm10,%ymm11
|
||||
vsubpd %ymm6,%ymm12,%ymm13
|
||||
vaddpd %ymm7,%ymm14,%ymm15
|
||||
vaddpd %ymm4,%ymm8,%ymm8
|
||||
vsubpd %ymm5,%ymm10,%ymm10
|
||||
vaddpd %ymm6,%ymm12,%ymm12
|
||||
vsubpd %ymm7,%ymm14,%ymm14
|
||||
|
||||
.fourth_pass:
|
||||
vmovupd 64(%rsi),%ymm0 /* gamma */
|
||||
vmovupd 96(%rsi),%ymm2 /* delta */
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */
|
||||
vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */
|
||||
vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */
|
||||
vperm2f128 $0x31,%ymm10,%ymm8,%ymm4 # ymm4 contains c1,c5 -- x gamma
|
||||
vperm2f128 $0x31,%ymm11,%ymm9,%ymm5 # ymm5 contains c3,c7 -- x igamma
|
||||
vperm2f128 $0x31,%ymm14,%ymm12,%ymm6 # ymm6 contains c9,c13 -- x delta
|
||||
vperm2f128 $0x31,%ymm15,%ymm13,%ymm7 # ymm7 contains c11,c15 -- x idelta
|
||||
vperm2f128 $0x20,%ymm10,%ymm8,%ymm8 # ymm8 contains c0,c4
|
||||
vperm2f128 $0x20,%ymm11,%ymm9,%ymm9 # ymm9 contains c2,c6
|
||||
vperm2f128 $0x20,%ymm14,%ymm12,%ymm10 # ymm10 contains c8,c12
|
||||
vperm2f128 $0x20,%ymm15,%ymm13,%ymm11 # ymm11 contains c10,c14
|
||||
vshufpd $5, %ymm4, %ymm4, %ymm12
|
||||
vshufpd $5, %ymm5, %ymm5, %ymm13
|
||||
vshufpd $5, %ymm6, %ymm6, %ymm14
|
||||
vshufpd $5, %ymm7, %ymm7, %ymm15
|
||||
vmulpd %ymm12,%ymm1,%ymm12
|
||||
vmulpd %ymm13,%ymm0,%ymm13
|
||||
vmulpd %ymm14,%ymm3,%ymm14
|
||||
vmulpd %ymm15,%ymm2,%ymm15
|
||||
vfmaddsub231pd %ymm4, %ymm0, %ymm12 # ymm12 = (ymm0 * ymm4) +/- ymm12
|
||||
vfmsubadd231pd %ymm5, %ymm1, %ymm13
|
||||
vfmaddsub231pd %ymm6, %ymm2, %ymm14
|
||||
vfmsubadd231pd %ymm7, %ymm3, %ymm15
|
||||
vsubpd %ymm12,%ymm8,%ymm4
|
||||
vaddpd %ymm13,%ymm9,%ymm5
|
||||
vsubpd %ymm14,%ymm10,%ymm6
|
||||
vaddpd %ymm15,%ymm11,%ymm7
|
||||
vaddpd %ymm12,%ymm8,%ymm8
|
||||
vsubpd %ymm13,%ymm9,%ymm9
|
||||
vaddpd %ymm14,%ymm10,%ymm10
|
||||
vsubpd %ymm15,%ymm11,%ymm11
|
||||
|
||||
vperm2f128 $0x20,%ymm6,%ymm10,%ymm12 # ymm4 contains c1,c5 -- x gamma
|
||||
vperm2f128 $0x20,%ymm7,%ymm11,%ymm13 # ymm5 contains c3,c7 -- x igamma
|
||||
vperm2f128 $0x31,%ymm6,%ymm10,%ymm14 # ymm6 contains c9,c13 -- x delta
|
||||
vperm2f128 $0x31,%ymm7,%ymm11,%ymm15 # ymm7 contains c11,c15 -- x idelta
|
||||
vperm2f128 $0x31,%ymm4,%ymm8,%ymm10 # ymm10 contains c8,c12
|
||||
vperm2f128 $0x31,%ymm5,%ymm9,%ymm11 # ymm11 contains c10,c14
|
||||
vperm2f128 $0x20,%ymm4,%ymm8,%ymm8 # ymm8 contains c0,c4
|
||||
vperm2f128 $0x20,%ymm5,%ymm9,%ymm9 # ymm9 contains c2,c6
|
||||
|
||||
.save_and_return:
|
||||
vmovupd %ymm8,(%rdi)
|
||||
vmovupd %ymm9,0x20(%rdi)
|
||||
vmovupd %ymm10,0x40(%rdi)
|
||||
vmovupd %ymm11,0x60(%rdi)
|
||||
vmovupd %ymm12,0x80(%rdi)
|
||||
vmovupd %ymm13,0xa0(%rdi)
|
||||
vmovupd %ymm14,0xc0(%rdi)
|
||||
vmovupd %ymm15,0xe0(%rdi)
|
||||
ret
|
||||
.size cplx_fft16_avx_fma, .-cplx_fft16_avx_fma
|
||||
.section .note.GNU-stack,"",@progbits
|
||||
@@ -0,0 +1,190 @@
|
||||
.text
|
||||
.p2align 4
|
||||
.globl cplx_fft16_avx_fma
|
||||
.def cplx_fft16_avx_fma; .scl 2; .type 32; .endef
|
||||
cplx_fft16_avx_fma:
|
||||
|
||||
pushq %rdi
|
||||
pushq %rsi
|
||||
movq %rcx,%rdi
|
||||
movq %rdx,%rsi
|
||||
subq $0x100,%rsp
|
||||
movdqu %xmm6,(%rsp)
|
||||
movdqu %xmm7,0x10(%rsp)
|
||||
movdqu %xmm8,0x20(%rsp)
|
||||
movdqu %xmm9,0x30(%rsp)
|
||||
movdqu %xmm10,0x40(%rsp)
|
||||
movdqu %xmm11,0x50(%rsp)
|
||||
movdqu %xmm12,0x60(%rsp)
|
||||
movdqu %xmm13,0x70(%rsp)
|
||||
movdqu %xmm14,0x80(%rsp)
|
||||
movdqu %xmm15,0x90(%rsp)
|
||||
callq cplx_fft16_avx_fma_amd64
|
||||
movdqu (%rsp),%xmm6
|
||||
movdqu 0x10(%rsp),%xmm7
|
||||
movdqu 0x20(%rsp),%xmm8
|
||||
movdqu 0x30(%rsp),%xmm9
|
||||
movdqu 0x40(%rsp),%xmm10
|
||||
movdqu 0x50(%rsp),%xmm11
|
||||
movdqu 0x60(%rsp),%xmm12
|
||||
movdqu 0x70(%rsp),%xmm13
|
||||
movdqu 0x80(%rsp),%xmm14
|
||||
movdqu 0x90(%rsp),%xmm15
|
||||
addq $0x100,%rsp
|
||||
popq %rsi
|
||||
popq %rdi
|
||||
retq
|
||||
|
||||
# shifted FFT over X^16-i
|
||||
# 1st argument (rdi) contains 16 complexes
|
||||
# 2nd argument (rsi) contains: 8 complexes
|
||||
# omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma
|
||||
# alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta)
|
||||
# j = sqrt(i), k=sqrt(j)
|
||||
cplx_fft16_avx_fma_amd64:
|
||||
vmovupd (%rdi),%ymm8
|
||||
vmovupd 0x20(%rdi),%ymm9
|
||||
vmovupd 0x40(%rdi),%ymm10
|
||||
vmovupd 0x60(%rdi),%ymm11
|
||||
vmovupd 0x80(%rdi),%ymm12
|
||||
vmovupd 0xa0(%rdi),%ymm13
|
||||
vmovupd 0xc0(%rdi),%ymm14
|
||||
vmovupd 0xe0(%rdi),%ymm15
|
||||
|
||||
.first_pass:
|
||||
vmovupd (%rsi),%xmm0 /* omri */
|
||||
vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */
|
||||
vshufpd $5, %ymm12, %ymm12, %ymm4
|
||||
vshufpd $5, %ymm13, %ymm13, %ymm5
|
||||
vshufpd $5, %ymm14, %ymm14, %ymm6
|
||||
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||
vmulpd %ymm4,%ymm1,%ymm4
|
||||
vmulpd %ymm5,%ymm1,%ymm5
|
||||
vmulpd %ymm6,%ymm1,%ymm6
|
||||
vmulpd %ymm7,%ymm1,%ymm7
|
||||
vfmaddsub231pd %ymm12, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm12) +/- ymm4
|
||||
vfmaddsub231pd %ymm13, %ymm0, %ymm5
|
||||
vfmaddsub231pd %ymm14, %ymm0, %ymm6
|
||||
vfmaddsub231pd %ymm15, %ymm0, %ymm7
|
||||
vsubpd %ymm4,%ymm8,%ymm12
|
||||
vsubpd %ymm5,%ymm9,%ymm13
|
||||
vsubpd %ymm6,%ymm10,%ymm14
|
||||
vsubpd %ymm7,%ymm11,%ymm15
|
||||
vaddpd %ymm4,%ymm8,%ymm8
|
||||
vaddpd %ymm5,%ymm9,%ymm9
|
||||
vaddpd %ymm6,%ymm10,%ymm10
|
||||
vaddpd %ymm7,%ymm11,%ymm11
|
||||
|
||||
.second_pass:
|
||||
vmovupd 16(%rsi),%xmm0 /* omri */
|
||||
vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */
|
||||
vshufpd $5, %ymm10, %ymm10, %ymm4
|
||||
vshufpd $5, %ymm11, %ymm11, %ymm5
|
||||
vshufpd $5, %ymm14, %ymm14, %ymm6
|
||||
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||
vmulpd %ymm4,%ymm1,%ymm4
|
||||
vmulpd %ymm5,%ymm1,%ymm5
|
||||
vmulpd %ymm6,%ymm0,%ymm6
|
||||
vmulpd %ymm7,%ymm0,%ymm7
|
||||
vfmaddsub231pd %ymm10, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm10) +/- ymm4
|
||||
vfmaddsub231pd %ymm11, %ymm0, %ymm5
|
||||
vfmsubadd231pd %ymm14, %ymm1, %ymm6
|
||||
vfmsubadd231pd %ymm15, %ymm1, %ymm7
|
||||
vsubpd %ymm4,%ymm8,%ymm10
|
||||
vsubpd %ymm5,%ymm9,%ymm11
|
||||
vaddpd %ymm6,%ymm12,%ymm14
|
||||
vaddpd %ymm7,%ymm13,%ymm15
|
||||
vaddpd %ymm4,%ymm8,%ymm8
|
||||
vaddpd %ymm5,%ymm9,%ymm9
|
||||
vsubpd %ymm6,%ymm12,%ymm12
|
||||
vsubpd %ymm7,%ymm13,%ymm13
|
||||
|
||||
.third_pass:
|
||||
vmovupd 32(%rsi),%xmm0 /* gamma */
|
||||
vmovupd 48(%rsi),%xmm2 /* delta */
|
||||
vinsertf128 $1, %xmm0, %ymm0, %ymm0
|
||||
vinsertf128 $1, %xmm2, %ymm2, %ymm2
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */
|
||||
vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */
|
||||
vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */
|
||||
vshufpd $5, %ymm9, %ymm9, %ymm4
|
||||
vshufpd $5, %ymm11, %ymm11, %ymm5
|
||||
vshufpd $5, %ymm13, %ymm13, %ymm6
|
||||
vshufpd $5, %ymm15, %ymm15, %ymm7
|
||||
vmulpd %ymm4,%ymm1,%ymm4
|
||||
vmulpd %ymm5,%ymm0,%ymm5
|
||||
vmulpd %ymm6,%ymm3,%ymm6
|
||||
vmulpd %ymm7,%ymm2,%ymm7
|
||||
vfmaddsub231pd %ymm9, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm10) +/- ymm4
|
||||
vfmsubadd231pd %ymm11, %ymm1, %ymm5
|
||||
vfmaddsub231pd %ymm13, %ymm2, %ymm6
|
||||
vfmsubadd231pd %ymm15, %ymm3, %ymm7
|
||||
vsubpd %ymm4,%ymm8,%ymm9
|
||||
vaddpd %ymm5,%ymm10,%ymm11
|
||||
vsubpd %ymm6,%ymm12,%ymm13
|
||||
vaddpd %ymm7,%ymm14,%ymm15
|
||||
vaddpd %ymm4,%ymm8,%ymm8
|
||||
vsubpd %ymm5,%ymm10,%ymm10
|
||||
vaddpd %ymm6,%ymm12,%ymm12
|
||||
vsubpd %ymm7,%ymm14,%ymm14
|
||||
|
||||
.fourth_pass:
|
||||
vmovupd 64(%rsi),%ymm0 /* gamma */
|
||||
vmovupd 96(%rsi),%ymm2 /* delta */
|
||||
vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */
|
||||
vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */
|
||||
vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */
|
||||
vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */
|
||||
vperm2f128 $0x31,%ymm10,%ymm8,%ymm4 # ymm4 contains c1,c5 -- x gamma
|
||||
vperm2f128 $0x31,%ymm11,%ymm9,%ymm5 # ymm5 contains c3,c7 -- x igamma
|
||||
vperm2f128 $0x31,%ymm14,%ymm12,%ymm6 # ymm6 contains c9,c13 -- x delta
|
||||
vperm2f128 $0x31,%ymm15,%ymm13,%ymm7 # ymm7 contains c11,c15 -- x idelta
|
||||
vperm2f128 $0x20,%ymm10,%ymm8,%ymm8 # ymm8 contains c0,c4
|
||||
vperm2f128 $0x20,%ymm11,%ymm9,%ymm9 # ymm9 contains c2,c6
|
||||
vperm2f128 $0x20,%ymm14,%ymm12,%ymm10 # ymm10 contains c8,c12
|
||||
vperm2f128 $0x20,%ymm15,%ymm13,%ymm11 # ymm11 contains c10,c14
|
||||
vshufpd $5, %ymm4, %ymm4, %ymm12
|
||||
vshufpd $5, %ymm5, %ymm5, %ymm13
|
||||
vshufpd $5, %ymm6, %ymm6, %ymm14
|
||||
vshufpd $5, %ymm7, %ymm7, %ymm15
|
||||
vmulpd %ymm12,%ymm1,%ymm12
|
||||
vmulpd %ymm13,%ymm0,%ymm13
|
||||
vmulpd %ymm14,%ymm3,%ymm14
|
||||
vmulpd %ymm15,%ymm2,%ymm15
|
||||
vfmaddsub231pd %ymm4, %ymm0, %ymm12 # ymm12 = (ymm0 * ymm4) +/- ymm12
|
||||
vfmsubadd231pd %ymm5, %ymm1, %ymm13
|
||||
vfmaddsub231pd %ymm6, %ymm2, %ymm14
|
||||
vfmsubadd231pd %ymm7, %ymm3, %ymm15
|
||||
vsubpd %ymm12,%ymm8,%ymm4
|
||||
vaddpd %ymm13,%ymm9,%ymm5
|
||||
vsubpd %ymm14,%ymm10,%ymm6
|
||||
vaddpd %ymm15,%ymm11,%ymm7
|
||||
vaddpd %ymm12,%ymm8,%ymm8
|
||||
vsubpd %ymm13,%ymm9,%ymm9
|
||||
vaddpd %ymm14,%ymm10,%ymm10
|
||||
vsubpd %ymm15,%ymm11,%ymm11
|
||||
|
||||
vperm2f128 $0x20,%ymm6,%ymm10,%ymm12 # ymm4 contains c1,c5 -- x gamma
|
||||
vperm2f128 $0x20,%ymm7,%ymm11,%ymm13 # ymm5 contains c3,c7 -- x igamma
|
||||
vperm2f128 $0x31,%ymm6,%ymm10,%ymm14 # ymm6 contains c9,c13 -- x delta
|
||||
vperm2f128 $0x31,%ymm7,%ymm11,%ymm15 # ymm7 contains c11,c15 -- x idelta
|
||||
vperm2f128 $0x31,%ymm4,%ymm8,%ymm10 # ymm10 contains c8,c12
|
||||
vperm2f128 $0x31,%ymm5,%ymm9,%ymm11 # ymm11 contains c10,c14
|
||||
vperm2f128 $0x20,%ymm4,%ymm8,%ymm8 # ymm8 contains c0,c4
|
||||
vperm2f128 $0x20,%ymm5,%ymm9,%ymm9 # ymm9 contains c2,c6
|
||||
|
||||
.save_and_return:
|
||||
vmovupd %ymm8,(%rdi)
|
||||
vmovupd %ymm9,0x20(%rdi)
|
||||
vmovupd %ymm10,0x40(%rdi)
|
||||
vmovupd %ymm11,0x60(%rdi)
|
||||
vmovupd %ymm12,0x80(%rdi)
|
||||
vmovupd %ymm13,0xa0(%rdi)
|
||||
vmovupd %ymm14,0xc0(%rdi)
|
||||
vmovupd %ymm15,0xe0(%rdi)
|
||||
ret
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user