diff --git a/Cargo.lock b/Cargo.lock index e37f7d3..2185f0c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -347,12 +347,13 @@ dependencies = [ [[package]] name = "poulpy-backend" -version = "0.1.0" +version = "0.1.2" dependencies = [ "byteorder", "cmake", "criterion", "itertools 0.14.0", + "poulpy-hal 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", "rand", "rand_chacha", "rand_core", @@ -362,9 +363,51 @@ dependencies = [ [[package]] name = "poulpy-backend" -version = "0.1.0" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d47fbc27d0c03c2bfffd972795c62a243e4a3a3068acdb95ef55fb335a58d00f" +checksum = "e0c6c0ad35bd5399e72a7d51b8bad5aa03e54bfd63bf1a09c4a595bd51145ca6" +dependencies = [ + "byteorder", + "cmake", + "criterion", + "itertools 0.14.0", + "poulpy-hal 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", + "rand", + "rand_chacha", + "rand_core", + "rand_distr", + "rug", +] + +[[package]] +name = "poulpy-core" +version = "0.1.1" +dependencies = [ + "byteorder", + "criterion", + "itertools 0.14.0", + "poulpy-backend 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", + "poulpy-hal 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", + "rug", +] + +[[package]] +name = "poulpy-core" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34afc307c185e288395d9f298a3261177dc850229e2bd6d53aa4059ae7e98cab" +dependencies = [ + "byteorder", + "criterion", + "itertools 0.14.0", + "poulpy-backend 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", + "poulpy-hal 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", + "rug", +] + +[[package]] +name = "poulpy-hal" +version = "0.1.2" dependencies = [ "byteorder", "cmake", @@ -378,26 +421,17 @@ dependencies = [ ] [[package]] -name = "poulpy-core" -version = "0.1.0" -dependencies = [ - "byteorder", - "criterion", - "itertools 0.14.0", - "poulpy-backend 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", - "rug", -] - -[[package]] -name = "poulpy-core" -version = "0.1.0" +name = "poulpy-hal" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ff4e1acd3f4a84e861b07184fd28fe3143a57360bd51e923aeadbc94b8b38d0" +checksum = "63312a7be7c5fd91e1f5151735d646294a4592d80027d8e90778076b2070a0ec" dependencies = [ "byteorder", + "cmake", "criterion", "itertools 0.14.0", - "poulpy-backend 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", + "rand", + "rand_chacha", "rand_core", "rand_distr", "rug", @@ -405,12 +439,13 @@ dependencies = [ [[package]] name = "poulpy-schemes" -version = "0.1.0" +version = "0.1.1" dependencies = [ "byteorder", "itertools 0.14.0", - "poulpy-backend 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", - "poulpy-core 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", + "poulpy-backend 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", + "poulpy-core 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", + "poulpy-hal 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 8774029..2b08242 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["poulpy-backend", "poulpy-core", "poulpy-schemes"] +members = ["poulpy-hal", "poulpy-core", "poulpy-backend", "poulpy-schemes"] resolver = "3" [workspace.dependencies] diff --git a/README.md b/README.md index 8d19123..efc2df0 100644 --- 
a/README.md +++ b/README.md @@ -1,3 +1,4 @@ + # 🐙 Poulpy

@@ -8,6 +9,19 @@ **Poulpy** is a fast & modular FHE library that implements Ring-Learning-With-Errors based homomorphic encryption. It adopts the bivariate polynomial representation proposed in [Revisiting Key Decomposition Techniques for FHE: Simpler, Faster and More Generic](https://eprint.iacr.org/2023/771). In addition to simpler and more efficient arithmetic than the residue number system (RNS), this representation provides a common plaintext space for all schemes and allows easy switching between any two schemes. Poulpy also decouples the scheme implementations from the polynomial arithmetic backend by being built around a hardware abstraction layer (HAL). This enables users to easily provide or use a custom backend. +## Library Overview + +- **`poulpy-hal`**: a crate providing layouts and a trait-based hardware acceleration layer with open extension points, matching the API and types of spqlios-arithmetic. + - **`api`**: fixed public low-level polynomial arithmetic API closely matching spqlios-arithmetic. + - **`delegates`**: link between the user-facing API and the implementation OEP. Each trait of `api` is implemented by calling its corresponding trait on the `oep`. + - **`layouts`**: layouts of the front-end algebraic structs matching spqlios-arithmetic types, such as `ScalarZnx` and `VecZnx`, or opaque backend-prepared structs such as `SvpPPol` and `VmpPMat`. + - **`oep`**: open extension points, which can be (re-)implemented by the user to provide a concrete backend. + - **`tests`**: backend-agnostic & generic tests for the OEP/layouts. +- **`poulpy-backend`**: a crate providing concrete implementations of **`poulpy-hal`**. + - **`cpu_spqlios`**: CPU implementation of **`poulpy-hal`** through the `oep` using bindings to spqlios-arithmetic. This implementation currently supports the `FFT64` backend and will be extended to support the `NTT120` backend once it is available in spqlios-arithmetic. +- **`poulpy-core`**: a backend-agnostic crate implementing scheme-agnostic RLWE arithmetic for LWE, GLWE, GGLWE and GGSW ciphertexts using **`poulpy-hal`**. +- **`poulpy-schemes`**: a backend-agnostic crate implementing mainstream FHE schemes using **`poulpy-core`** and **`poulpy-hal`**. + ### Bivariate Polynomial Representation Existing FHE implementations (such as [Lattigo](https://github.com/tuneinsight/lattigo) or [OpenFHE](https://github.com/openfheorg/openfhe-development)) use the [residue-number-system](https://en.wikipedia.org/wiki/Residue_number_system) (RNS) to represent large integers. Although the parallelism and carry-less arithmetic provided by the RNS representation make modular arithmetic over large integers very efficient, it suffers from various drawbacks when used in the context of FHE. The main idea behind the bivariate representation is to decouple the cyclotomic arithmetic from the large-number arithmetic. Instead of using the RNS representation for large integers, integers are decomposed in base $2^{-K}$ over the Torus $\mathbb{T}_{N}[X]$. @@ -30,114 +44,13 @@ This provides the following benefits: In addition to providing a general-purpose FHE library over a unified plaintext space, Poulpy is also designed from the ground up around a **hardware abstraction layer** that closely matches the API of [spqlios-arithmetic](https://github.com/tfhe/spqlios-arithmetic). The bivariate representation is by itself hardware friendly as it uses a flat, aligned & vectorized memory layout.
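To make the layering concrete, the sketch below condenses the three `poulpy-hal` layers (the fixed `api` trait, the `delegates` blanket implementation on `Module<B>`, and the `oep` extension point) into one self-contained example. It follows the `SvpPrepare`/`SvpPrepareImpl` excerpts from the previous version of this README, but the signatures are deliberately simplified (the real traits operate on layout types such as `ScalarZnx` and `SvpPPol` and take column indices), so treat it as an illustration of the pattern rather than the exact `poulpy-hal` API.

```rust
use std::marker::PhantomData;

// Simplified stand-ins for the real poulpy-hal types; the actual crate
// parameterizes Module over a Backend with handle/layout associated items.
pub trait Backend: Sized {}

pub struct Module<B: Backend> {
    _backend: PhantomData<B>,
}

// `api`: fixed, user-facing trait that front-end code is written against.
pub trait SvpPrepare<B: Backend> {
    fn svp_prepare(&self, res: &mut [f64], a: &[i64]);
}

// `oep`: open extension point that a concrete backend implements.
pub unsafe trait SvpPrepareImpl: Backend {
    fn svp_prepare_impl(module: &Module<Self>, res: &mut [f64], a: &[i64]);
}

// `delegates`: the api trait is implemented once for Module<B> by
// forwarding to the backend's oep implementation.
impl<B> SvpPrepare<B> for Module<B>
where
    B: Backend + SvpPrepareImpl,
{
    fn svp_prepare(&self, res: &mut [f64], a: &[i64]) {
        B::svp_prepare_impl(self, res, a);
    }
}

// A concrete backend (here a stand-in for the FFT64 CPU backend) only
// implements the oep trait; the api and delegate layers come for free.
pub struct FFT64;
impl Backend for FFT64 {}

unsafe impl SvpPrepareImpl for FFT64 {
    fn svp_prepare_impl(_module: &Module<Self>, res: &mut [f64], a: &[i64]) {
        // The real backend calls into spqlios-arithmetic here; this stub
        // only copies coefficients so that the sketch type-checks.
        for (r, &x) in res.iter_mut().zip(a.iter()) {
            *r = x as f64;
        }
    }
}
```

A custom backend therefore only has to implement the `*Impl` traits from `oep`; everything written against the `api` traits, including `poulpy-core` and `poulpy-schemes`, then runs on it unchanged.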
Finally, generic opaque write only structs (prepared versions) are provided, making it easy for developers to provide hardware focused/optimized operations. This makes possible for anyone to provide or use a custom backend. -## Library Overview - -- **`backend/hal`**: hardware abstraction layer. This layer targets users that want to provide their own backend or use a third party backend. - - - **`api`**: fixed public low-level polynomial level arithmetic API closely matching spqlios-arithmetic. The goal is to eventually freeze this API, in order to decouple it from the OEP traits, ensuring that changes to implementations do not affect the front end API. - - ```rust - pub trait SvpPrepare { - fn svp_prepare(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: SvpPPolToMut, - A: ScalarZnxToRef; - } - ```` - - - **`delegates`**: link between the user facing API and implementation OEP. Each trait of `api` is implemented by calling its corresponding trait on the `oep`. - - ```rust - impl SvpPrepare for Module - where - B: Backend + SvpPrepareImpl, - { - fn svp_prepare(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: SvpPPolToMut, - A: ScalarZnxToRef, - { - B::svp_prepare_impl(self, res, res_col, a, a_col); - } - } - ``` - - - **`layouts`**: defines the layouts of the front-end algebraic structs matching spqlios-arithmetic definitions, such as `ScalarZnx`, `VecZnx` or opaque backend prepared struct such as `SvpPPol` and `VmpPMat`. - - ```rust - pub struct SvpPPol { - data: D, - n: usize, - cols: usize, - _phantom: PhantomData, - } - ``` - - - **`oep`**: open extension points, which can be implemented by the user to provide a custom backend. - - ```rust - pub unsafe trait SvpPrepareImpl { - fn svp_prepare_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: SvpPPolToMut, - A: ScalarZnxToRef; - } - ``` - - - **`tests`**: exported generic tests for the OEP/structs. Their goal is to enable a user to automatically be able to test its backend implementation, without having to re-implement any tests. - -- **`backend/implementation`**: - - **`cpu_spqlios`**: concrete cpu implementation of the hal through the oep using bindings on spqlios-arithmetic. This implementation currently supports the `FFT64` backend and will be extended to support the `NTT120` backend once it is available in spqlios-arithmetic. - - ```rust - unsafe impl SvpPrepareImpl for FFT64 { - fn svp_prepare_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: SvpPPolToMut, - A: ScalarZnxToRef, - { - unsafe { - svp::svp_prepare( - module.ptr(), - res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t, - a.to_ref().at_ptr(a_col, 0), - ) - } - } - } - ``` - -- **`core`**: core of the FHE library, implementing scheme agnostic RLWE arithmetic for LWE, GLWE, GGLWE and GGSW ciphertexts. It notably includes all possible cross-ciphertext operations, for example applying an external product on a GGLWE or an automorphism on a GGSW, as well as blind rotation. This crate is entirely implemented using the hardware abstraction layer API, and is thus solely defined over generic and traits (including tests). As such it will work over any backend, as long as it implements the necessary traits defined in the OEP. 
- - ```rust - pub struct GLWESecret { - pub(crate) data: ScalarZnx, - pub(crate) dist: Distribution, - } - - pub struct GLWESecrecPrepared { - pub(crate) data: SvpPPol, - pub(crate) dist: Distribution, - } - - impl GLWESecretPrepared { - pub fn prepare(&mut self, module: &Module, sk: &GLWESecret) - where - O: DataRef, - Module: SvpPrepare, - { - (0..self.rank()).for_each(|i| { - module.svp_prepare(&mut self.data, i, &sk.data, i); - }); - self.dist = sk.dist - } - } - ``` - ## Installation -TBD — currently not published on crates.io. Clone the repository and use via path-based dependencies. - +- **`poulpy-hal`**: https://crates.io/crates/poulpy-hal/0.1.0 +- **`poulpy-backend`**: https://crates.io/crates/poulpy-backend/0.1.0 +- **`poulpy-core`**: https://crates.io/crates/poulpy-core/0.1.0 +- **`poulpy-schemes`**: https://crates.io/crates/poulpy-schemes/0.1.0 +- ## Documentation * Full `cargo doc` documentation is coming soon. diff --git a/poulpy-backend/Cargo.toml b/poulpy-backend/Cargo.toml index e23ee54..62c5906 100644 --- a/poulpy-backend/Cargo.toml +++ b/poulpy-backend/Cargo.toml @@ -1,27 +1,28 @@ -[package] -name = "poulpy-backend" -version = "0.1.0" -edition = "2024" -license = "Apache-2.0" -readme = "README.md" -description = "A crate implementing bivariate polynomial arithmetic" -repository = "https://github.com/phantomzone-org/poulpy" -homepage = "https://github.com/phantomzone-org/poulpy" -documentation = "https://docs.rs/poulpy" - -[dependencies] -rug = {workspace = true} -criterion = {workspace = true} -itertools = {workspace = true} -rand = {workspace = true} -rand_distr = {workspace = true} -rand_core = {workspace = true} -byteorder = {workspace = true} -rand_chacha = "0.9.0" - -[build-dependencies] -cmake = "0.1.54" - -[package.metadata.docs.rs] -all-features = true +[package] +name = "poulpy-backend" +version = "0.1.2" +edition = "2024" +license = "Apache-2.0" +readme = "README.md" +description = "A crate providing concrete implementations of poulpy-hal through its open extension points" +repository = "https://github.com/phantomzone-org/poulpy" +homepage = "https://github.com/phantomzone-org/poulpy" +documentation = "https://docs.rs/poulpy" + +[dependencies] +poulpy-hal = "0.1.2" +rug = {workspace = true} +criterion = {workspace = true} +itertools = {workspace = true} +rand = {workspace = true} +rand_distr = {workspace = true} +rand_core = {workspace = true} +byteorder = {workspace = true} +rand_chacha = "0.9.0" + +[build-dependencies] +cmake = "0.1.54" + +[package.metadata.docs.rs] +all-features = true rustdoc-args = ["--cfg", "docsrs"] \ No newline at end of file diff --git a/poulpy-backend/README.md b/poulpy-backend/README.md index 86c8c65..23c72b0 100644 --- a/poulpy-backend/README.md +++ b/poulpy-backend/README.md @@ -1,12 +1,15 @@ - -## WSL/Ubuntu -To use this crate you need to build spqlios-arithmetic, which is provided a as a git submodule: -1) Initialize the sub-module -2) $ cd backend/spqlios-arithmetic -3) mdkir build -4) cd build -5) cmake .. -6) make - -## Others + + +## spqlios-arithmetic + +### WSL/Ubuntu +To use this crate you need to build spqlios-arithmetic, which is provided a as a git submodule: +1) Initialize the sub-module +2) $ cd backend/spqlios-arithmetic +3) mdkir build +4) cd build +5) cmake .. +6) make + +### Others Steps 3 to 6 might change depending of your platform. See [spqlios-arithmetic/wiki/build](https://github.com/tfhe/spqlios-arithmetic/wiki/build) for additional information and build options. 
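With spqlios-arithmetic built, a downstream project consumes the split crates by depending on `poulpy-hal` for the generic API and layouts and on `poulpy-backend` for a concrete backend. The minimal sketch below reuses the import paths from the updated `examples/rlwe_encrypt.rs`; the dependency versions in the comment and the `Module` constructor call (via the `ModuleNew` trait) are assumptions and may not match the published API exactly.

```rust
// Assumed Cargo.toml entries for a downstream crate:
//
// [dependencies]
// poulpy-hal = "0.1.2"
// poulpy-backend = "0.1.2"

use poulpy_backend::cpu_spqlios::FFT64;
use poulpy_hal::{api::ModuleNew, layouts::Module};

fn main() {
    // Ring degree N (a power of two, as required by spqlios-arithmetic).
    let n: u64 = 1 << 10;

    // Assumed constructor exposed through the `ModuleNew` api trait;
    // the exact signature may differ in the published crate.
    let module: Module<FFT64> = Module::<FFT64>::new(n);
    let _ = &module;
}
```

The updated example then performs an RLWE encryption entirely through `poulpy-hal` traits, so switching to a different backend only changes the type parameter.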
\ No newline at end of file diff --git a/poulpy-backend/builds/cpu_spqlios.rs b/poulpy-backend/builds/cpu_spqlios.rs index 0f6c07c..8abb07a 100644 --- a/poulpy-backend/builds/cpu_spqlios.rs +++ b/poulpy-backend/builds/cpu_spqlios.rs @@ -1,7 +1,7 @@ use std::path::PathBuf; pub fn build() { - let dst: PathBuf = cmake::Config::new("src/implementation/cpu_spqlios/spqlios-arithmetic") + let dst: PathBuf = cmake::Config::new("src/cpu_spqlios/spqlios-arithmetic") .define("ENABLE_TESTING", "FALSE") .build(); diff --git a/poulpy-backend/examples/rlwe_encrypt.rs b/poulpy-backend/examples/rlwe_encrypt.rs index 392b673..2e970dc 100644 --- a/poulpy-backend/examples/rlwe_encrypt.rs +++ b/poulpy-backend/examples/rlwe_encrypt.rs @@ -1,15 +1,13 @@ use itertools::izip; -use poulpy_backend::{ - hal::{ - api::{ - ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPrepare, VecZnxAddNormal, - VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, - VecZnxDftAlloc, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigTmpA, VecZnxFillUniform, VecZnxNormalizeInplace, ZnxInfos, - }, - layouts::{Module, ScalarZnx, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft}, - source::Source, +use poulpy_backend::cpu_spqlios::FFT64; +use poulpy_hal::{ + api::{ + ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPrepare, VecZnxAddNormal, + VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, + VecZnxDftAlloc, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigTmpA, VecZnxFillUniform, VecZnxNormalizeInplace, ZnxInfos, }, - implementation::cpu_spqlios::FFT64, + layouts::{Module, ScalarZnx, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft}, + source::Source, }; fn main() { diff --git a/poulpy-backend/src/implementation/cpu_spqlios/ffi/cnv.rs b/poulpy-backend/src/cpu_spqlios/ffi/cnv.rs similarity index 100% rename from poulpy-backend/src/implementation/cpu_spqlios/ffi/cnv.rs rename to poulpy-backend/src/cpu_spqlios/ffi/cnv.rs diff --git a/poulpy-backend/src/cpu_spqlios/ffi/mod.rs b/poulpy-backend/src/cpu_spqlios/ffi/mod.rs new file mode 100644 index 0000000..6d40a1e --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/ffi/mod.rs @@ -0,0 +1,15 @@ +#[allow(non_camel_case_types)] +pub mod module; +#[allow(non_camel_case_types)] +pub mod svp; +#[allow(non_camel_case_types)] +pub mod vec_znx; +#[allow(dead_code)] +#[allow(non_camel_case_types)] +pub mod vec_znx_big; +#[allow(non_camel_case_types)] +pub mod vec_znx_dft; +#[allow(non_camel_case_types)] +pub mod vmp; +#[allow(non_camel_case_types)] +pub mod znx; diff --git a/poulpy-backend/src/implementation/cpu_spqlios/ffi/module.rs b/poulpy-backend/src/cpu_spqlios/ffi/module.rs similarity index 79% rename from poulpy-backend/src/implementation/cpu_spqlios/ffi/module.rs rename to poulpy-backend/src/cpu_spqlios/ffi/module.rs index b593448..9e67c30 100644 --- a/poulpy-backend/src/implementation/cpu_spqlios/ffi/module.rs +++ b/poulpy-backend/src/cpu_spqlios/ffi/module.rs @@ -1,19 +1,17 @@ -pub struct module_info_t { - _unused: [u8; 0], -} - -pub type module_type_t = ::std::os::raw::c_uint; -pub use self::module_type_t as MODULE_TYPE; - -#[allow(clippy::upper_case_acronyms)] -pub type MODULE = module_info_t; - -unsafe extern "C" { - pub unsafe fn new_module_info(N: u64, mode: MODULE_TYPE) -> *mut MODULE; -} -unsafe extern "C" { - pub unsafe fn delete_module_info(module_info: *mut MODULE); -} -unsafe extern "C" { - pub 
unsafe fn module_get_n(module: *const MODULE) -> u64; -} +#[repr(C)] +pub struct module_info_t { + _unused: [u8; 0], +} + +pub type module_type_t = ::std::os::raw::c_uint; +pub use self::module_type_t as MODULE_TYPE; + +#[allow(clippy::upper_case_acronyms)] +pub type MODULE = module_info_t; + +unsafe extern "C" { + pub unsafe fn new_module_info(N: u64, mode: MODULE_TYPE) -> *mut MODULE; +} +unsafe extern "C" { + pub unsafe fn delete_module_info(module_info: *mut MODULE); +} diff --git a/poulpy-backend/src/implementation/cpu_spqlios/ffi/svp.rs b/poulpy-backend/src/cpu_spqlios/ffi/svp.rs similarity index 68% rename from poulpy-backend/src/implementation/cpu_spqlios/ffi/svp.rs rename to poulpy-backend/src/cpu_spqlios/ffi/svp.rs index f9db97f..1b72fb3 100644 --- a/poulpy-backend/src/implementation/cpu_spqlios/ffi/svp.rs +++ b/poulpy-backend/src/cpu_spqlios/ffi/svp.rs @@ -1,4 +1,4 @@ -use crate::implementation::cpu_spqlios::ffi::{module::MODULE, vec_znx_dft::VEC_ZNX_DFT}; +use crate::cpu_spqlios::ffi::{module::MODULE, vec_znx_dft::VEC_ZNX_DFT}; #[repr(C)] #[derive(Debug, Copy, Clone)] @@ -7,20 +7,11 @@ pub struct svp_ppol_t { } pub type SVP_PPOL = svp_ppol_t; -unsafe extern "C" { - pub unsafe fn bytes_of_svp_ppol(module: *const MODULE) -> u64; -} -unsafe extern "C" { - pub unsafe fn new_svp_ppol(module: *const MODULE) -> *mut SVP_PPOL; -} -unsafe extern "C" { - pub unsafe fn delete_svp_ppol(res: *mut SVP_PPOL); -} - unsafe extern "C" { pub unsafe fn svp_prepare(module: *const MODULE, ppol: *mut SVP_PPOL, pol: *const i64); } +#[allow(dead_code)] unsafe extern "C" { pub unsafe fn svp_apply_dft( module: *const MODULE, diff --git a/poulpy-backend/src/implementation/cpu_spqlios/ffi/vec_znx.rs b/poulpy-backend/src/cpu_spqlios/ffi/vec_znx.rs similarity index 91% rename from poulpy-backend/src/implementation/cpu_spqlios/ffi/vec_znx.rs rename to poulpy-backend/src/cpu_spqlios/ffi/vec_znx.rs index f4ea531..020fb9e 100644 --- a/poulpy-backend/src/implementation/cpu_spqlios/ffi/vec_znx.rs +++ b/poulpy-backend/src/cpu_spqlios/ffi/vec_znx.rs @@ -1,4 +1,4 @@ -use crate::implementation::cpu_spqlios::ffi::module::MODULE; +use crate::cpu_spqlios::ffi::module::MODULE; unsafe extern "C" { pub unsafe fn vec_znx_add( @@ -53,6 +53,7 @@ unsafe extern "C" { ); } +#[allow(dead_code)] unsafe extern "C" { pub unsafe fn vec_znx_rotate( module: *const MODULE, @@ -81,9 +82,12 @@ unsafe extern "C" { ); } +#[allow(dead_code)] unsafe extern "C" { pub unsafe fn vec_znx_zero(module: *const MODULE, res: *mut i64, res_size: u64, res_sl: u64); } + +#[allow(dead_code)] unsafe extern "C" { pub unsafe fn vec_znx_copy( module: *const MODULE, diff --git a/poulpy-backend/src/implementation/cpu_spqlios/ffi/vec_znx_big.rs b/poulpy-backend/src/cpu_spqlios/ffi/vec_znx_big.rs similarity index 86% rename from poulpy-backend/src/implementation/cpu_spqlios/ffi/vec_znx_big.rs rename to poulpy-backend/src/cpu_spqlios/ffi/vec_znx_big.rs index 16d5647..55b9ea7 100644 --- a/poulpy-backend/src/implementation/cpu_spqlios/ffi/vec_znx_big.rs +++ b/poulpy-backend/src/cpu_spqlios/ffi/vec_znx_big.rs @@ -1,163 +1,153 @@ -use crate::implementation::cpu_spqlios::ffi::module::MODULE; - -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct vec_znx_big_t { - _unused: [u8; 0], -} -pub type VEC_ZNX_BIG = vec_znx_big_t; - -unsafe extern "C" { - pub unsafe fn bytes_of_vec_znx_big(module: *const MODULE, size: u64) -> u64; -} -unsafe extern "C" { - pub unsafe fn new_vec_znx_big(module: *const MODULE, size: u64) -> *mut VEC_ZNX_BIG; -} -unsafe extern "C" { - pub unsafe 
fn delete_vec_znx_big(res: *mut VEC_ZNX_BIG); -} - -unsafe extern "C" { - pub unsafe fn vec_znx_big_add( - module: *const MODULE, - res: *mut VEC_ZNX_BIG, - res_size: u64, - a: *const VEC_ZNX_BIG, - a_size: u64, - b: *const VEC_ZNX_BIG, - b_size: u64, - ); -} -unsafe extern "C" { - pub unsafe fn vec_znx_big_add_small( - module: *const MODULE, - res: *mut VEC_ZNX_BIG, - res_size: u64, - a: *const VEC_ZNX_BIG, - a_size: u64, - b: *const i64, - b_size: u64, - b_sl: u64, - ); -} -unsafe extern "C" { - pub unsafe fn vec_znx_big_add_small2( - module: *const MODULE, - res: *mut VEC_ZNX_BIG, - res_size: u64, - a: *const i64, - a_size: u64, - a_sl: u64, - b: *const i64, - b_size: u64, - b_sl: u64, - ); -} -unsafe extern "C" { - pub unsafe fn vec_znx_big_sub( - module: *const MODULE, - res: *mut VEC_ZNX_BIG, - res_size: u64, - a: *const VEC_ZNX_BIG, - a_size: u64, - b: *const VEC_ZNX_BIG, - b_size: u64, - ); -} -unsafe extern "C" { - pub unsafe fn vec_znx_big_sub_small_b( - module: *const MODULE, - res: *mut VEC_ZNX_BIG, - res_size: u64, - a: *const VEC_ZNX_BIG, - a_size: u64, - b: *const i64, - b_size: u64, - b_sl: u64, - ); -} -unsafe extern "C" { - pub unsafe fn vec_znx_big_sub_small_a( - module: *const MODULE, - res: *mut VEC_ZNX_BIG, - res_size: u64, - a: *const i64, - a_size: u64, - a_sl: u64, - b: *const VEC_ZNX_BIG, - b_size: u64, - ); -} -unsafe extern "C" { - pub unsafe fn vec_znx_big_sub_small2( - module: *const MODULE, - res: *mut VEC_ZNX_BIG, - res_size: u64, - a: *const i64, - a_size: u64, - a_sl: u64, - b: *const i64, - b_size: u64, - b_sl: u64, - ); -} - -unsafe extern "C" { - pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64; -} - -unsafe extern "C" { - pub unsafe fn vec_znx_big_normalize_base2k( - module: *const MODULE, - n: u64, - log2_base2k: u64, - res: *mut i64, - res_size: u64, - res_sl: u64, - a: *const VEC_ZNX_BIG, - a_size: u64, - tmp_space: *mut u8, - ); -} - -unsafe extern "C" { - pub unsafe fn vec_znx_big_range_normalize_base2k( - module: *const MODULE, - n: u64, - log2_base2k: u64, - res: *mut i64, - res_size: u64, - res_sl: u64, - a: *const VEC_ZNX_BIG, - a_range_begin: u64, - a_range_xend: u64, - a_range_step: u64, - tmp_space: *mut u8, - ); -} - -unsafe extern "C" { - pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64; -} - -unsafe extern "C" { - pub unsafe fn vec_znx_big_automorphism( - module: *const MODULE, - p: i64, - res: *mut VEC_ZNX_BIG, - res_size: u64, - a: *const VEC_ZNX_BIG, - a_size: u64, - ); -} - -unsafe extern "C" { - pub unsafe fn vec_znx_big_rotate( - module: *const MODULE, - p: i64, - res: *mut VEC_ZNX_BIG, - res_size: u64, - a: *const VEC_ZNX_BIG, - a_size: u64, - ); -} +use crate::cpu_spqlios::ffi::module::MODULE; + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct vec_znx_big_t { + _unused: [u8; 0], +} +pub type VEC_ZNX_BIG = vec_znx_big_t; + +unsafe extern "C" { + pub unsafe fn vec_znx_big_add( + module: *const MODULE, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const VEC_ZNX_BIG, + a_size: u64, + b: *const VEC_ZNX_BIG, + b_size: u64, + ); +} +unsafe extern "C" { + pub unsafe fn vec_znx_big_add_small( + module: *const MODULE, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const VEC_ZNX_BIG, + a_size: u64, + b: *const i64, + b_size: u64, + b_sl: u64, + ); +} +unsafe extern "C" { + pub unsafe fn vec_znx_big_add_small2( + module: *const MODULE, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const i64, + a_size: u64, + a_sl: u64, + b: *const i64, + 
b_size: u64, + b_sl: u64, + ); +} +unsafe extern "C" { + pub unsafe fn vec_znx_big_sub( + module: *const MODULE, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const VEC_ZNX_BIG, + a_size: u64, + b: *const VEC_ZNX_BIG, + b_size: u64, + ); +} +unsafe extern "C" { + pub unsafe fn vec_znx_big_sub_small_b( + module: *const MODULE, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const VEC_ZNX_BIG, + a_size: u64, + b: *const i64, + b_size: u64, + b_sl: u64, + ); +} +unsafe extern "C" { + pub unsafe fn vec_znx_big_sub_small_a( + module: *const MODULE, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const i64, + a_size: u64, + a_sl: u64, + b: *const VEC_ZNX_BIG, + b_size: u64, + ); +} +unsafe extern "C" { + pub unsafe fn vec_znx_big_sub_small2( + module: *const MODULE, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const i64, + a_size: u64, + a_sl: u64, + b: *const i64, + b_size: u64, + b_sl: u64, + ); +} + +unsafe extern "C" { + pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64; +} + +unsafe extern "C" { + pub unsafe fn vec_znx_big_normalize_base2k( + module: *const MODULE, + n: u64, + log2_base2k: u64, + res: *mut i64, + res_size: u64, + res_sl: u64, + a: *const VEC_ZNX_BIG, + a_size: u64, + tmp_space: *mut u8, + ); +} + +unsafe extern "C" { + pub unsafe fn vec_znx_big_range_normalize_base2k( + module: *const MODULE, + n: u64, + log2_base2k: u64, + res: *mut i64, + res_size: u64, + res_sl: u64, + a: *const VEC_ZNX_BIG, + a_range_begin: u64, + a_range_xend: u64, + a_range_step: u64, + tmp_space: *mut u8, + ); +} + +unsafe extern "C" { + pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64; +} + +unsafe extern "C" { + pub unsafe fn vec_znx_big_automorphism( + module: *const MODULE, + p: i64, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const VEC_ZNX_BIG, + a_size: u64, + ); +} + +unsafe extern "C" { + pub unsafe fn vec_znx_big_rotate( + module: *const MODULE, + p: i64, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const VEC_ZNX_BIG, + a_size: u64, + ); +} diff --git a/poulpy-backend/src/implementation/cpu_spqlios/ffi/vec_znx_dft.rs b/poulpy-backend/src/cpu_spqlios/ffi/vec_znx_dft.rs similarity index 72% rename from poulpy-backend/src/implementation/cpu_spqlios/ffi/vec_znx_dft.rs rename to poulpy-backend/src/cpu_spqlios/ffi/vec_znx_dft.rs index fbf1e49..9612f37 100644 --- a/poulpy-backend/src/implementation/cpu_spqlios/ffi/vec_znx_dft.rs +++ b/poulpy-backend/src/cpu_spqlios/ffi/vec_znx_dft.rs @@ -1,4 +1,4 @@ -use crate::implementation::cpu_spqlios::ffi::{module::MODULE, vec_znx_big::VEC_ZNX_BIG}; +use crate::cpu_spqlios::ffi::{module::MODULE, vec_znx_big::VEC_ZNX_BIG}; #[repr(C)] #[derive(Debug, Copy, Clone)] @@ -7,19 +7,6 @@ pub struct vec_znx_dft_t { } pub type VEC_ZNX_DFT = vec_znx_dft_t; -unsafe extern "C" { - pub unsafe fn bytes_of_vec_znx_dft(module: *const MODULE, size: u64) -> u64; -} -unsafe extern "C" { - pub unsafe fn new_vec_znx_dft(module: *const MODULE, size: u64) -> *mut VEC_ZNX_DFT; -} -unsafe extern "C" { - pub unsafe fn delete_vec_znx_dft(res: *mut VEC_ZNX_DFT); -} - -unsafe extern "C" { - pub unsafe fn vec_dft_zero(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64); -} unsafe extern "C" { pub unsafe fn vec_dft_add( module: *const MODULE, diff --git a/poulpy-backend/src/implementation/cpu_spqlios/ffi/vmp.rs b/poulpy-backend/src/cpu_spqlios/ffi/vmp.rs similarity index 79% rename from poulpy-backend/src/implementation/cpu_spqlios/ffi/vmp.rs rename to 
poulpy-backend/src/cpu_spqlios/ffi/vmp.rs index b9ae29a..48c3a84 100644 --- a/poulpy-backend/src/implementation/cpu_spqlios/ffi/vmp.rs +++ b/poulpy-backend/src/cpu_spqlios/ffi/vmp.rs @@ -1,4 +1,4 @@ -use crate::implementation::cpu_spqlios::ffi::{module::MODULE, vec_znx_dft::VEC_ZNX_DFT}; +use crate::cpu_spqlios::ffi::{module::MODULE, vec_znx_dft::VEC_ZNX_DFT}; #[repr(C)] #[derive(Debug, Copy, Clone)] @@ -9,16 +9,7 @@ pub struct vmp_pmat_t { // [rows][cols] = [#Decomposition][#Limbs] pub type VMP_PMAT = vmp_pmat_t; -unsafe extern "C" { - pub unsafe fn bytes_of_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> u64; -} -unsafe extern "C" { - pub unsafe fn new_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> *mut VMP_PMAT; -} -unsafe extern "C" { - pub unsafe fn delete_vmp_pmat(res: *mut VMP_PMAT); -} - +#[allow(dead_code)] unsafe extern "C" { pub unsafe fn vmp_apply_dft( module: *const MODULE, @@ -34,6 +25,7 @@ unsafe extern "C" { ); } +#[allow(dead_code)] unsafe extern "C" { pub unsafe fn vmp_apply_dft_add( module: *const MODULE, @@ -50,6 +42,7 @@ unsafe extern "C" { ); } +#[allow(dead_code)] unsafe extern "C" { pub unsafe fn vmp_apply_dft_tmp_bytes(module: *const MODULE, res_size: u64, a_size: u64, nrows: u64, ncols: u64) -> u64; } @@ -105,10 +98,6 @@ unsafe extern "C" { ); } -unsafe extern "C" { - pub unsafe fn vmp_prepare_contiguous_dft(module: *const MODULE, pmat: *mut VMP_PMAT, mat: *const f64, nrows: u64, ncols: u64); -} - unsafe extern "C" { pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nn: u64, nrows: u64, ncols: u64) -> u64; } diff --git a/poulpy-backend/src/cpu_spqlios/ffi/znx.rs b/poulpy-backend/src/cpu_spqlios/ffi/znx.rs new file mode 100644 index 0000000..e669c35 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/ffi/znx.rs @@ -0,0 +1,7 @@ +unsafe extern "C" { + pub unsafe fn znx_rotate_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64); +} + +unsafe extern "C" { + pub unsafe fn znx_rotate_inplace_i64(nn: u64, p: i64, res: *mut i64); +} diff --git a/poulpy-backend/src/cpu_spqlios/fft64/mod.rs b/poulpy-backend/src/cpu_spqlios/fft64/mod.rs new file mode 100644 index 0000000..b81e73d --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/fft64/mod.rs @@ -0,0 +1,15 @@ +mod module; +mod scratch; +mod svp_ppol; +mod vec_znx; +mod vec_znx_big; +mod vec_znx_dft; +mod vmp_pmat; + +pub use module::FFT64; + +/// For external documentation +pub use vec_znx::{ + vec_znx_copy_ref, vec_znx_lsh_inplace_ref, vec_znx_merge_ref, vec_znx_rsh_inplace_ref, vec_znx_split_ref, + vec_znx_switch_degree_ref, +}; diff --git a/poulpy-backend/src/implementation/cpu_spqlios/module_fft64.rs b/poulpy-backend/src/cpu_spqlios/fft64/module.rs similarity index 53% rename from poulpy-backend/src/implementation/cpu_spqlios/module_fft64.rs rename to poulpy-backend/src/cpu_spqlios/fft64/module.rs index 3f86c6c..7bd4ff6 100644 --- a/poulpy-backend/src/implementation/cpu_spqlios/module_fft64.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/module.rs @@ -1,25 +1,29 @@ use std::ptr::NonNull; -use crate::{ - hal::{ - layouts::{Backend, Module}, - oep::ModuleNewImpl, - }, - implementation::cpu_spqlios::{ - CPUAVX, - ffi::module::{MODULE, delete_module_info, new_module_info}, - }, +use poulpy_hal::{ + layouts::{Backend, Module}, + oep::ModuleNewImpl, }; +use crate::cpu_spqlios::ffi::module::{MODULE, delete_module_info, new_module_info}; + pub struct FFT64; -impl CPUAVX for FFT64 {} - impl Backend for FFT64 { + type ScalarPrep = f64; + type ScalarBig = i64; type Handle = MODULE; unsafe fn 
destroy(handle: NonNull) { unsafe { delete_module_info(handle.as_ptr()) } } + + fn layout_big_word_count() -> usize { + 1 + } + + fn layout_prep_word_count() -> usize { + 1 + } } unsafe impl ModuleNewImpl for FFT64 { diff --git a/poulpy-backend/src/implementation/cpu_spqlios/scratch.rs b/poulpy-backend/src/cpu_spqlios/fft64/scratch.rs similarity index 77% rename from poulpy-backend/src/implementation/cpu_spqlios/scratch.rs rename to poulpy-backend/src/cpu_spqlios/fft64/scratch.rs index 3ebb6d2..43ff74e 100644 --- a/poulpy-backend/src/implementation/cpu_spqlios/scratch.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/scratch.rs @@ -1,24 +1,20 @@ use std::marker::PhantomData; -use crate::{ +use poulpy_hal::{ DEFAULTALIGN, alloc_aligned, - hal::{ - api::ScratchFromBytes, - layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, - oep::{ - ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, SvpPPolAllocBytesImpl, - TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, - TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, VecZnxBigAllocBytesImpl, - VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl, - }, + api::ScratchFromBytes, + layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, + oep::{ + ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, SvpPPolAllocBytesImpl, + TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, + TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, VecZnxBigAllocBytesImpl, + VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl, }, - implementation::cpu_spqlios::CPUAVX, }; -unsafe impl ScratchOwnedAllocImpl for B -where - B: CPUAVX, -{ +use crate::cpu_spqlios::FFT64; + +unsafe impl ScratchOwnedAllocImpl for FFT64 { fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned { let data: Vec = alloc_aligned(size); ScratchOwned { @@ -28,28 +24,22 @@ where } } -unsafe impl ScratchOwnedBorrowImpl for B +unsafe impl ScratchOwnedBorrowImpl for FFT64 where - B: CPUAVX, + B: ScratchFromBytesImpl, { fn scratch_owned_borrow_impl(scratch: &mut ScratchOwned) -> &mut Scratch { Scratch::from_bytes(&mut scratch.data) } } -unsafe impl ScratchFromBytesImpl for B -where - B: CPUAVX, -{ +unsafe impl ScratchFromBytesImpl for FFT64 { fn scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch { unsafe { &mut *(data as *mut [u8] as *mut Scratch) } } } -unsafe impl ScratchAvailableImpl for B -where - B: CPUAVX, -{ +unsafe impl ScratchAvailableImpl for FFT64 { fn scratch_available_impl(scratch: &Scratch) -> usize { let ptr: *const u8 = scratch.data.as_ptr(); let self_len: usize = scratch.data.len(); @@ -58,9 +48,9 @@ where } } -unsafe impl TakeSliceImpl for B +unsafe impl TakeSliceImpl for FFT64 where - B: CPUAVX, + B: ScratchFromBytesImpl, { fn take_slice_impl(scratch: &mut Scratch, len: usize) -> (&mut [T], &mut Scratch) { let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, len * std::mem::size_of::()); @@ -74,9 +64,9 @@ where } } -unsafe impl TakeScalarZnxImpl for B +unsafe impl TakeScalarZnxImpl for FFT64 where - B: CPUAVX, + B: ScratchFromBytesImpl, { fn take_scalar_znx_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch) { let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, 
ScalarZnx::alloc_bytes(n, cols)); @@ -87,9 +77,9 @@ where } } -unsafe impl TakeSvpPPolImpl for B +unsafe impl TakeSvpPPolImpl for FFT64 where - B: CPUAVX + SvpPPolAllocBytesImpl, + B: SvpPPolAllocBytesImpl + ScratchFromBytesImpl, { fn take_svp_ppol_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch) { let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::svp_ppol_alloc_bytes_impl(n, cols)); @@ -100,9 +90,9 @@ where } } -unsafe impl TakeVecZnxImpl for B +unsafe impl TakeVecZnxImpl for FFT64 where - B: CPUAVX, + B: ScratchFromBytesImpl, { fn take_vec_znx_impl(scratch: &mut Scratch, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch) { let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, VecZnx::alloc_bytes(n, cols, size)); @@ -113,9 +103,9 @@ where } } -unsafe impl TakeVecZnxBigImpl for B +unsafe impl TakeVecZnxBigImpl for FFT64 where - B: CPUAVX + VecZnxBigAllocBytesImpl, + B: VecZnxBigAllocBytesImpl + ScratchFromBytesImpl, { fn take_vec_znx_big_impl( scratch: &mut Scratch, @@ -134,9 +124,9 @@ where } } -unsafe impl TakeVecZnxDftImpl for B +unsafe impl TakeVecZnxDftImpl for FFT64 where - B: CPUAVX + VecZnxDftAllocBytesImpl, + B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl, { fn take_vec_znx_dft_impl( scratch: &mut Scratch, @@ -156,9 +146,9 @@ where } } -unsafe impl TakeVecZnxDftSliceImpl for B +unsafe impl TakeVecZnxDftSliceImpl for FFT64 where - B: CPUAVX + VecZnxDftAllocBytesImpl, + B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl + TakeVecZnxDftImpl, { fn take_vec_znx_dft_slice_impl( scratch: &mut Scratch, @@ -178,9 +168,9 @@ where } } -unsafe impl TakeVecZnxSliceImpl for B +unsafe impl TakeVecZnxSliceImpl for FFT64 where - B: CPUAVX, + B: ScratchFromBytesImpl + TakeVecZnxImpl, { fn take_vec_znx_slice_impl( scratch: &mut Scratch, @@ -200,9 +190,9 @@ where } } -unsafe impl TakeVmpPMatImpl for B +unsafe impl TakeVmpPMatImpl for FFT64 where - B: CPUAVX + VmpPMatAllocBytesImpl, + B: VmpPMatAllocBytesImpl + ScratchFromBytesImpl, { fn take_vmp_pmat_impl( scratch: &mut Scratch, @@ -223,9 +213,9 @@ where } } -unsafe impl TakeMatZnxImpl for B +unsafe impl TakeMatZnxImpl for FFT64 where - B: CPUAVX, + B: ScratchFromBytesImpl, { fn take_mat_znx_impl( scratch: &mut Scratch, diff --git a/poulpy-backend/src/implementation/cpu_spqlios/svp_ppol_fft64.rs b/poulpy-backend/src/cpu_spqlios/fft64/svp_ppol.rs similarity index 72% rename from poulpy-backend/src/implementation/cpu_spqlios/svp_ppol_fft64.rs rename to poulpy-backend/src/cpu_spqlios/fft64/svp_ppol.rs index 265840a..6a82dc6 100644 --- a/poulpy-backend/src/implementation/cpu_spqlios/svp_ppol_fft64.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/svp_ppol.rs @@ -1,35 +1,16 @@ -use crate::{ - hal::{ - api::{ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut}, - layouts::{ - Data, DataRef, Module, ScalarZnxToRef, SvpPPol, SvpPPolBytesOf, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDft, - VecZnxDftToMut, VecZnxDftToRef, - }, - oep::{SvpApplyImpl, SvpApplyInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl}, - }, - implementation::cpu_spqlios::{ - ffi::{svp, vec_znx_dft::vec_znx_dft_t}, - module_fft64::FFT64, +use poulpy_hal::{ + api::{ZnxInfos, ZnxView, ZnxViewMut}, + layouts::{ + Backend, Module, ScalarZnxToRef, SvpPPol, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDft, VecZnxDftToMut, + VecZnxDftToRef, }, + oep::{SvpApplyImpl, SvpApplyInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, 
SvpPrepareImpl}, }; -const SVP_PPOL_FFT64_WORD_SIZE: usize = 1; - -impl SvpPPolBytesOf for SvpPPol { - fn bytes_of(n: usize, cols: usize) -> usize { - SVP_PPOL_FFT64_WORD_SIZE * n * cols * size_of::() - } -} - -impl ZnxSliceSize for SvpPPol { - fn sl(&self) -> usize { - SVP_PPOL_FFT64_WORD_SIZE * self.n() - } -} - -impl ZnxView for SvpPPol { - type Scalar = f64; -} +use crate::cpu_spqlios::{ + FFT64, + ffi::{svp, vec_znx_dft::vec_znx_dft_t}, +}; unsafe impl SvpPPolFromBytesImpl for FFT64 { fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec) -> SvpPPolOwned { @@ -45,7 +26,7 @@ unsafe impl SvpPPolAllocImpl for FFT64 { unsafe impl SvpPPolAllocBytesImpl for FFT64 { fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize { - SvpPPol::, Self>::bytes_of(n, cols) + FFT64::layout_prep_word_count() * n * cols * size_of::() } } diff --git a/poulpy-backend/src/implementation/cpu_spqlios/vec_znx.rs b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs similarity index 77% rename from poulpy-backend/src/implementation/cpu_spqlios/vec_znx.rs rename to poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs index b928b9c..c5271cc 100644 --- a/poulpy-backend/src/implementation/cpu_spqlios/vec_znx.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs @@ -1,48 +1,47 @@ use itertools::izip; use rand_distr::Normal; -use crate::{ - hal::{ - api::{ - TakeSlice, TakeVecZnx, VecZnxAddDistF64, VecZnxCopy, VecZnxFillDistF64, VecZnxNormalizeTmpBytes, VecZnxRotate, - VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, - }, - layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef}, - oep::{ - VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl, - VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl, - VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, - VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, - VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, - VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, - VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl, - }, - source::Source, +use poulpy_hal::{ + api::{ + TakeSlice, TakeVecZnx, VecZnxAddDistF64, VecZnxCopy, VecZnxFillDistF64, VecZnxNormalizeTmpBytes, VecZnxRotate, + VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, }, - implementation::cpu_spqlios::{ - CPUAVX, - ffi::{module::module_info_t, vec_znx, znx}, + layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef}, + oep::{ + TakeSliceImpl, TakeVecZnxImpl, VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, + VecZnxAddScalarInplaceImpl, VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl, + VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, + VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, + VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, + VecZnxSplitImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, + VecZnxSwithcDegreeImpl, }, + source::Source, }; -unsafe 
impl VecZnxNormalizeTmpBytesImpl for B -where - B: CPUAVX, -{ - fn vec_znx_normalize_tmp_bytes_impl(module: &Module, n: usize) -> usize { +use crate::cpu_spqlios::{ + FFT64, + ffi::{module::module_info_t, vec_znx, znx}, +}; + +unsafe impl VecZnxNormalizeTmpBytesImpl for FFT64 { + fn vec_znx_normalize_tmp_bytes_impl(module: &Module, n: usize) -> usize { unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr() as *const module_info_t, n as u64) as usize } } } -unsafe impl VecZnxNormalizeImpl for B { +unsafe impl VecZnxNormalizeImpl for FFT64 +where + Self: TakeSliceImpl + VecZnxNormalizeTmpBytesImpl, +{ fn vec_znx_normalize_impl( - module: &Module, + module: &Module, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where R: VecZnxToMut, A: VecZnxToRef, @@ -74,9 +73,17 @@ unsafe impl VecZnxNormalizeImpl for B { } } -unsafe impl VecZnxNormalizeInplaceImpl for B { - fn vec_znx_normalize_inplace_impl(module: &Module, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) - where +unsafe impl VecZnxNormalizeInplaceImpl for FFT64 +where + Self: TakeSliceImpl + VecZnxNormalizeTmpBytesImpl, +{ + fn vec_znx_normalize_inplace_impl( + module: &Module, + basek: usize, + a: &mut A, + a_col: usize, + scratch: &mut Scratch, + ) where A: VecZnxToMut, { let mut a: VecZnx<&mut [u8]> = a.to_mut(); @@ -100,8 +107,8 @@ unsafe impl VecZnxNormalizeInplaceImpl for B { } } -unsafe impl VecZnxAddImpl for B { - fn vec_znx_add_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) +unsafe impl VecZnxAddImpl for FFT64 { + fn vec_znx_add_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) where R: VecZnxToMut, A: VecZnxToRef, @@ -134,8 +141,8 @@ unsafe impl VecZnxAddImpl for B { } } -unsafe impl VecZnxAddInplaceImpl for B { - fn vec_znx_add_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxAddInplaceImpl for FFT64 { + fn vec_znx_add_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, @@ -164,9 +171,9 @@ unsafe impl VecZnxAddInplaceImpl for B { } } -unsafe impl VecZnxAddScalarInplaceImpl for B { +unsafe impl VecZnxAddScalarInplaceImpl for FFT64 { fn vec_znx_add_scalar_inplace_impl( - module: &Module, + module: &Module, res: &mut R, res_col: usize, res_limb: usize, @@ -201,8 +208,8 @@ unsafe impl VecZnxAddScalarInplaceImpl for B { } } -unsafe impl VecZnxSubImpl for B { - fn vec_znx_sub_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) +unsafe impl VecZnxSubImpl for FFT64 { + fn vec_znx_sub_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) where R: VecZnxToMut, A: VecZnxToRef, @@ -235,8 +242,8 @@ unsafe impl VecZnxSubImpl for B { } } -unsafe impl VecZnxSubABInplaceImpl for B { - fn vec_znx_sub_ab_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxSubABInplaceImpl for FFT64 { + fn vec_znx_sub_ab_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, @@ -264,8 +271,8 @@ unsafe impl VecZnxSubABInplaceImpl for B { } } -unsafe impl VecZnxSubBAInplaceImpl for B { - fn vec_znx_sub_ba_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxSubBAInplaceImpl for FFT64 { + fn vec_znx_sub_ba_inplace_impl(module: &Module, res: 
&mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, @@ -293,9 +300,9 @@ unsafe impl VecZnxSubBAInplaceImpl for B { } } -unsafe impl VecZnxSubScalarInplaceImpl for B { +unsafe impl VecZnxSubScalarInplaceImpl for FFT64 { fn vec_znx_sub_scalar_inplace_impl( - module: &Module, + module: &Module, res: &mut R, res_col: usize, res_limb: usize, @@ -330,8 +337,8 @@ unsafe impl VecZnxSubScalarInplaceImpl for B { } } -unsafe impl VecZnxNegateImpl for B { - fn vec_znx_negate_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxNegateImpl for FFT64 { + fn vec_znx_negate_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, @@ -356,8 +363,8 @@ unsafe impl VecZnxNegateImpl for B { } } -unsafe impl VecZnxNegateInplaceImpl for B { - fn vec_znx_negate_inplace_impl(module: &Module, a: &mut A, a_col: usize) +unsafe impl VecZnxNegateInplaceImpl for FFT64 { + fn vec_znx_negate_inplace_impl(module: &Module, a: &mut A, a_col: usize) where A: VecZnxToMut, { @@ -376,8 +383,8 @@ unsafe impl VecZnxNegateInplaceImpl for B { } } -unsafe impl VecZnxLshInplaceImpl for B { - fn vec_znx_lsh_inplace_impl(_module: &Module, basek: usize, k: usize, a: &mut A) +unsafe impl VecZnxLshInplaceImpl for FFT64 { + fn vec_znx_lsh_inplace_impl(_module: &Module, basek: usize, k: usize, a: &mut A) where A: VecZnxToMut, { @@ -417,8 +424,8 @@ where } } -unsafe impl VecZnxRshInplaceImpl for B { - fn vec_znx_rsh_inplace_impl(_module: &Module, basek: usize, k: usize, a: &mut A) +unsafe impl VecZnxRshInplaceImpl for FFT64 { + fn vec_znx_rsh_inplace_impl(_module: &Module, basek: usize, k: usize, a: &mut A) where A: VecZnxToMut, { @@ -461,8 +468,8 @@ where } } -unsafe impl VecZnxRotateImpl for B { - fn vec_znx_rotate_impl(_module: &Module, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxRotateImpl for FFT64 { + fn vec_znx_rotate_impl(_module: &Module, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, @@ -486,8 +493,8 @@ unsafe impl VecZnxRotateImpl for B { } } -unsafe impl VecZnxRotateInplaceImpl for B { - fn vec_znx_rotate_inplace_impl(_module: &Module, k: i64, a: &mut A, a_col: usize) +unsafe impl VecZnxRotateInplaceImpl for FFT64 { + fn vec_znx_rotate_inplace_impl(_module: &Module, k: i64, a: &mut A, a_col: usize) where A: VecZnxToMut, { @@ -500,8 +507,8 @@ unsafe impl VecZnxRotateInplaceImpl for B { } } -unsafe impl VecZnxAutomorphismImpl for B { - fn vec_znx_automorphism_impl(module: &Module, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxAutomorphismImpl for FFT64 { + fn vec_znx_automorphism_impl(module: &Module, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, @@ -527,8 +534,8 @@ unsafe impl VecZnxAutomorphismImpl for B { } } -unsafe impl VecZnxAutomorphismInplaceImpl for B { - fn vec_znx_automorphism_inplace_impl(module: &Module, k: i64, a: &mut A, a_col: usize) +unsafe impl VecZnxAutomorphismInplaceImpl for FFT64 { + fn vec_znx_automorphism_inplace_impl(module: &Module, k: i64, a: &mut A, a_col: usize) where A: VecZnxToMut, { @@ -556,8 +563,8 @@ unsafe impl VecZnxAutomorphismInplaceImpl for B { } } -unsafe impl VecZnxMulXpMinusOneImpl for B { - fn vec_znx_mul_xp_minus_one_impl(module: &Module, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxMulXpMinusOneImpl for FFT64 { + fn vec_znx_mul_xp_minus_one_impl(module: &Module, p: i64, res: 
&mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, @@ -584,8 +591,8 @@ unsafe impl VecZnxMulXpMinusOneImpl for B { } } -unsafe impl VecZnxMulXpMinusOneInplaceImpl for B { - fn vec_znx_mul_xp_minus_one_inplace_impl(module: &Module, p: i64, res: &mut R, res_col: usize) +unsafe impl VecZnxMulXpMinusOneInplaceImpl for FFT64 { + fn vec_znx_mul_xp_minus_one_inplace_impl(module: &Module, p: i64, res: &mut R, res_col: usize) where R: VecZnxToMut, { @@ -609,9 +616,22 @@ unsafe impl VecZnxMulXpMinusOneInplaceImpl for B { } } -unsafe impl VecZnxSplitImpl for B { - fn vec_znx_split_impl(module: &Module, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) - where +unsafe impl VecZnxSplitImpl for FFT64 +where + Self: TakeVecZnxImpl + + TakeVecZnxImpl + + VecZnxSwithcDegreeImpl + + VecZnxRotateImpl + + VecZnxRotateInplaceImpl, +{ + fn vec_znx_split_impl( + module: &Module, + res: &mut [R], + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where R: VecZnxToMut, A: VecZnxToRef, { @@ -627,7 +647,7 @@ pub fn vec_znx_split_ref( a_col: usize, scratch: &mut Scratch, ) where - B: Backend + CPUAVX, + B: Backend + TakeVecZnxImpl + VecZnxSwithcDegreeImpl + VecZnxRotateImpl + VecZnxRotateInplaceImpl, R: VecZnxToMut, A: VecZnxToRef, { @@ -660,8 +680,11 @@ pub fn vec_znx_split_ref( }) } -unsafe impl VecZnxMergeImpl for B { - fn vec_znx_merge_impl(module: &Module, res: &mut R, res_col: usize, a: &[A], a_col: usize) +unsafe impl VecZnxMergeImpl for FFT64 +where + Self: VecZnxSwithcDegreeImpl + VecZnxRotateInplaceImpl, +{ + fn vec_znx_merge_impl(module: &Module, res: &mut R, res_col: usize, a: &[A], a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, @@ -672,7 +695,7 @@ unsafe impl VecZnxMergeImpl for B { pub fn vec_znx_merge_ref(module: &Module, res: &mut R, res_col: usize, a: &[A], a_col: usize) where - B: Backend + CPUAVX, + B: Backend + VecZnxSwithcDegreeImpl + VecZnxRotateInplaceImpl, R: VecZnxToMut, A: VecZnxToRef, { @@ -700,8 +723,11 @@ where module.vec_znx_rotate_inplace(a.len() as i64, &mut res, res_col); } -unsafe impl VecZnxSwithcDegreeImpl for B { - fn vec_znx_switch_degree_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxSwithcDegreeImpl for FFT64 +where + Self: VecZnxCopyImpl, +{ + fn vec_znx_switch_degree_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, @@ -712,7 +738,7 @@ unsafe impl VecZnxSwithcDegreeImpl for B { pub fn vec_znx_switch_degree_ref(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where - B: Backend + CPUAVX, + B: Backend + VecZnxCopyImpl, R: VecZnxToMut, A: VecZnxToRef, { @@ -745,8 +771,8 @@ where }); } -unsafe impl VecZnxCopyImpl for B { - fn vec_znx_copy_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxCopyImpl for FFT64 { + fn vec_znx_copy_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, @@ -775,9 +801,15 @@ where }) } -unsafe impl VecZnxFillUniformImpl for B { - fn vec_znx_fill_uniform_impl(_module: &Module, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source) - where +unsafe impl VecZnxFillUniformImpl for FFT64 { + fn vec_znx_fill_uniform_impl( + _module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + ) where R: VecZnxToMut, { let mut a: VecZnx<&mut [u8]> = res.to_mut(); @@ -792,9 +824,9 @@ unsafe impl 
VecZnxFillUniformImpl for B { } } -unsafe impl VecZnxFillDistF64Impl for B { +unsafe impl VecZnxFillDistF64Impl for FFT64 { fn vec_znx_fill_dist_f64_impl>( - _module: &Module, + _module: &Module, basek: usize, res: &mut R, res_col: usize, @@ -835,9 +867,9 @@ unsafe impl VecZnxFillDistF64Impl for B { } } -unsafe impl VecZnxAddDistF64Impl for B { +unsafe impl VecZnxAddDistF64Impl for FFT64 { fn vec_znx_add_dist_f64_impl>( - _module: &Module, + _module: &Module, basek: usize, res: &mut R, res_col: usize, @@ -878,9 +910,12 @@ unsafe impl VecZnxAddDistF64Impl for B { } } -unsafe impl VecZnxFillNormalImpl for B { +unsafe impl VecZnxFillNormalImpl for FFT64 +where + Self: VecZnxFillDistF64Impl, +{ fn vec_znx_fill_normal_impl( - module: &Module, + module: &Module, basek: usize, res: &mut R, res_col: usize, @@ -903,9 +938,12 @@ unsafe impl VecZnxFillNormalImpl for B { } } -unsafe impl VecZnxAddNormalImpl for B { +unsafe impl VecZnxAddNormalImpl for FFT64 +where + Self: VecZnxAddDistF64Impl, +{ fn vec_znx_add_normal_impl( - module: &Module, + module: &Module, basek: usize, res: &mut R, res_col: usize, diff --git a/poulpy-backend/src/implementation/cpu_spqlios/vec_znx_big_fft64.rs b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs similarity index 61% rename from poulpy-backend/src/implementation/cpu_spqlios/vec_znx_big_fft64.rs rename to poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs index 7c499d1..c571371 100644 --- a/poulpy-backend/src/implementation/cpu_spqlios/vec_znx_big_fft64.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs @@ -1,69 +1,46 @@ -use std::fmt; - use rand_distr::{Distribution, Normal}; -use crate::{ - hal::{ - api::{ - TakeSlice, VecZnxBigAddDistF64, VecZnxBigFillDistF64, VecZnxBigNormalizeTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView, - ZnxViewMut, - }, - layouts::{ - Data, DataRef, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigBytesOf, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, - VecZnxToMut, VecZnxToRef, - }, - oep::{ - VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl, - VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl, - VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl, VecZnxBigFromBytesImpl, - VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, - VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, - VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl, - }, - source::Source, +use crate::cpu_spqlios::{FFT64, ffi::vec_znx}; +use poulpy_hal::{ + api::{ + TakeSlice, VecZnxBigAddDistF64, VecZnxBigFillDistF64, VecZnxBigNormalizeTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView, + ZnxViewMut, }, - implementation::cpu_spqlios::{ffi::vec_znx, module_fft64::FFT64}, + layouts::{ + Backend, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef, + }, + oep::{ + TakeSliceImpl, VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, + VecZnxBigAddSmallImpl, VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, + VecZnxBigAutomorphismImpl, VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl, + VecZnxBigFromBytesImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, + VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, 
VecZnxBigSubSmallAImpl, + VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl, + }, + source::Source, }; -const VEC_ZNX_BIG_FFT64_WORDSIZE: usize = 1; - -impl ZnxView for VecZnxBig { - type Scalar = i64; -} - -impl VecZnxBigBytesOf for VecZnxBig { - fn bytes_of(n: usize, cols: usize, size: usize) -> usize { - VEC_ZNX_BIG_FFT64_WORDSIZE * n * cols * size * size_of::() - } -} - -impl ZnxSliceSize for VecZnxBig { - fn sl(&self) -> usize { - VEC_ZNX_BIG_FFT64_WORDSIZE * self.n() * self.cols() - } -} - -unsafe impl VecZnxBigAllocImpl for FFT64 { - fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned { - VecZnxBig::, FFT64>::new(n, cols, size) - } -} - -unsafe impl VecZnxBigFromBytesImpl for FFT64 { - fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned { - VecZnxBig::, FFT64>::new_from_bytes(n, cols, size, bytes) - } -} - -unsafe impl VecZnxBigAllocBytesImpl for FFT64 { +unsafe impl VecZnxBigAllocBytesImpl for FFT64 { fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { - VecZnxBig::, FFT64>::bytes_of(n, cols, size) + Self::layout_big_word_count() * n * cols * size * size_of::() } } -unsafe impl VecZnxBigAddDistF64Impl for FFT64 { - fn add_dist_f64_impl, D: Distribution>( - _module: &Module, +unsafe impl VecZnxBigAllocImpl for FFT64 { + fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned { + VecZnxBig::alloc(n, cols, size) + } +} + +unsafe impl VecZnxBigFromBytesImpl for FFT64 { + fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned { + VecZnxBig::from_bytes(n, cols, size, bytes) + } +} + +unsafe impl VecZnxBigAddDistF64Impl for FFT64 { + fn add_dist_f64_impl, D: Distribution>( + _module: &Module, basek: usize, res: &mut R, res_col: usize, @@ -72,7 +49,7 @@ unsafe impl VecZnxBigAddDistF64Impl for FFT64 { dist: D, bound: f64, ) { - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut(); assert!( (bound.log2().ceil() as i64) < 64, "invalid bound: ceil(log2(bound))={} > 63", @@ -102,9 +79,9 @@ unsafe impl VecZnxBigAddDistF64Impl for FFT64 { } } -unsafe impl VecZnxBigAddNormalImpl for FFT64 { - fn add_normal_impl>( - module: &Module, +unsafe impl VecZnxBigAddNormalImpl for FFT64 { + fn add_normal_impl>( + module: &Module, basek: usize, res: &mut R, res_col: usize, @@ -125,9 +102,9 @@ unsafe impl VecZnxBigAddNormalImpl for FFT64 { } } -unsafe impl VecZnxBigFillDistF64Impl for FFT64 { - fn fill_dist_f64_impl, D: Distribution>( - _module: &Module, +unsafe impl VecZnxBigFillDistF64Impl for FFT64 { + fn fill_dist_f64_impl, D: Distribution>( + _module: &Module, basek: usize, res: &mut R, res_col: usize, @@ -136,7 +113,7 @@ unsafe impl VecZnxBigFillDistF64Impl for FFT64 { dist: D, bound: f64, ) { - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut(); assert!( (bound.log2().ceil() as i64) < 64, "invalid bound: ceil(log2(bound))={} > 63", @@ -166,9 +143,9 @@ unsafe impl VecZnxBigFillDistF64Impl for FFT64 { } } -unsafe impl VecZnxBigFillNormalImpl for FFT64 { - fn fill_normal_impl>( - module: &Module, +unsafe impl VecZnxBigFillNormalImpl for FFT64 { + fn fill_normal_impl>( + module: &Module, basek: usize, res: &mut R, res_col: usize, @@ -189,24 +166,17 @@ unsafe impl VecZnxBigFillNormalImpl for FFT64 { } } -unsafe impl VecZnxBigAddImpl for FFT64 { +unsafe impl VecZnxBigAddImpl for FFT64 { 
/// Adds `a` to `b` and stores the result on `c`. - fn vec_znx_big_add_impl( - module: &Module, - res: &mut R, - res_col: usize, - a: &A, - a_col: usize, - b: &B, - b_col: usize, - ) where - R: VecZnxBigToMut, - A: VecZnxBigToRef, - B: VecZnxBigToRef, + fn vec_znx_big_add_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxBigToRef, { - let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); - let b: VecZnxBig<&[u8], FFT64> = b.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + let a: VecZnxBig<&[u8], Self> = a.to_ref(); + let b: VecZnxBig<&[u8], Self> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut(); #[cfg(debug_assertions)] { @@ -231,15 +201,15 @@ unsafe impl VecZnxBigAddImpl for FFT64 { } } -unsafe impl VecZnxBigAddInplaceImpl for FFT64 { +unsafe impl VecZnxBigAddInplaceImpl for FFT64 { /// Adds `a` to `b` and stores the result on `b`. - fn vec_znx_big_add_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_add_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where - R: VecZnxBigToMut, - A: VecZnxBigToRef, + R: VecZnxBigToMut, + A: VecZnxBigToRef, { - let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + let a: VecZnxBig<&[u8], Self> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut(); #[cfg(debug_assertions)] { @@ -262,10 +232,10 @@ unsafe impl VecZnxBigAddInplaceImpl for FFT64 { } } -unsafe impl VecZnxBigAddSmallImpl for FFT64 { +unsafe impl VecZnxBigAddSmallImpl for FFT64 { /// Adds `a` to `b` and stores the result on `c`. fn vec_znx_big_add_small_impl( - module: &Module, + module: &Module, res: &mut R, res_col: usize, a: &A, @@ -273,13 +243,13 @@ unsafe impl VecZnxBigAddSmallImpl for FFT64 { b: &B, b_col: usize, ) where - R: VecZnxBigToMut, - A: VecZnxBigToRef, + R: VecZnxBigToMut, + A: VecZnxBigToRef, B: VecZnxToRef, { - let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let a: VecZnxBig<&[u8], Self> = a.to_ref(); let b: VecZnx<&[u8]> = b.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut(); #[cfg(debug_assertions)] { @@ -304,15 +274,15 @@ unsafe impl VecZnxBigAddSmallImpl for FFT64 { } } -unsafe impl VecZnxBigAddSmallInplaceImpl for FFT64 { +unsafe impl VecZnxBigAddSmallInplaceImpl for FFT64 { /// Adds `a` to `b` and stores the result on `b`. - fn vec_znx_big_add_small_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_add_small_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where - R: VecZnxBigToMut, + R: VecZnxBigToMut, A: VecZnxToRef, { let a: VecZnx<&[u8]> = a.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut(); #[cfg(debug_assertions)] { @@ -335,24 +305,17 @@ unsafe impl VecZnxBigAddSmallInplaceImpl for FFT64 { } } -unsafe impl VecZnxBigSubImpl for FFT64 { +unsafe impl VecZnxBigSubImpl for FFT64 { /// Subtracts `a` to `b` and stores the result on `c`. 
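Note on the recurring pattern in this file: each `oep::*Impl` trait is now implemented directly for the concrete `FFT64` marker type, with `Self`-tagged layouts such as `VecZnxBig<&mut [u8], Self>`, instead of for a blanket generic backend. The sketch below is a minimal, standalone illustration of that marker-type-plus-`*Impl` delegation shape; every name in it (`ToyBackend`, `AddImpl`, `Add`) is invented, and the blanket `api`-style forwarding trait is an assumption about how poulpy-hal dispatches, included only so the example compiles on its own — it is not the real poulpy-hal API.

```rust
// Standalone sketch of the backend-delegation shape seen in this diff.
// All names are illustrative; none of this is the poulpy-hal API.

/// Marker trait for a compute backend (a stand-in for a Backend-like trait).
trait BackendMarker {
    fn name() -> &'static str;
}

/// Backend-facing implementation trait (stands in for the oep::*Impl traits).
/// `unsafe` mirrors the `unsafe impl` convention of the oep traits; the actual
/// safety contract is not restated here.
unsafe trait AddImpl: BackendMarker {
    fn add_impl(a: i64, b: i64) -> i64;
}

/// User-facing trait (stands in for the api::* traits). The blanket impl below
/// forwards every call to the backend's `*_impl` method.
trait Add: BackendMarker {
    fn add(a: i64, b: i64) -> i64;
}

impl<B: AddImpl> Add for B {
    fn add(a: i64, b: i64) -> i64 {
        B::add_impl(a, b)
    }
}

/// A concrete backend marker, analogous in role to FFT64 or NTT120.
struct ToyBackend;

impl BackendMarker for ToyBackend {
    fn name() -> &'static str {
        "toy"
    }
}

unsafe impl AddImpl for ToyBackend {
    fn add_impl(a: i64, b: i64) -> i64 {
        a + b
    }
}

fn main() {
    // Markers are unit structs used almost entirely at the type level.
    let _marker = ToyBackend;
    // Callers only see the api-style trait; the backend is a type parameter.
    assert_eq!(<ToyBackend as Add>::add(2, 3), 5);
    println!("{} backend: 2 + 3 = {}", ToyBackend::name(), <ToyBackend as Add>::add(2, 3));
}
```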
- fn vec_znx_big_sub_impl( - module: &Module, - res: &mut R, - res_col: usize, - a: &A, - a_col: usize, - b: &B, - b_col: usize, - ) where - R: VecZnxBigToMut, - A: VecZnxBigToRef, - B: VecZnxBigToRef, + fn vec_znx_big_sub_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxBigToRef, { - let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); - let b: VecZnxBig<&[u8], FFT64> = b.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + let a: VecZnxBig<&[u8], Self> = a.to_ref(); + let b: VecZnxBig<&[u8], Self> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut(); #[cfg(debug_assertions)] { @@ -377,15 +340,15 @@ unsafe impl VecZnxBigSubImpl for FFT64 { } } -unsafe impl VecZnxBigSubABInplaceImpl for FFT64 { +unsafe impl VecZnxBigSubABInplaceImpl for FFT64 { /// Subtracts `a` from `b` and stores the result on `b`. - fn vec_znx_big_sub_ab_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_ab_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where - R: VecZnxBigToMut, - A: VecZnxBigToRef, + R: VecZnxBigToMut, + A: VecZnxBigToRef, { - let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + let a: VecZnxBig<&[u8], Self> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut(); #[cfg(debug_assertions)] { @@ -408,15 +371,15 @@ unsafe impl VecZnxBigSubABInplaceImpl for FFT64 { } } -unsafe impl VecZnxBigSubBAInplaceImpl for FFT64 { +unsafe impl VecZnxBigSubBAInplaceImpl for FFT64 { /// Subtracts `b` from `a` and stores the result on `b`. - fn vec_znx_big_sub_ba_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_ba_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where - R: VecZnxBigToMut, - A: VecZnxBigToRef, + R: VecZnxBigToMut, + A: VecZnxBigToRef, { - let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + let a: VecZnxBig<&[u8], Self> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut(); #[cfg(debug_assertions)] { @@ -439,10 +402,10 @@ unsafe impl VecZnxBigSubBAInplaceImpl for FFT64 { } } -unsafe impl VecZnxBigSubSmallAImpl for FFT64 { +unsafe impl VecZnxBigSubSmallAImpl for FFT64 { /// Subtracts `b` from `a` and stores the result on `c`. fn vec_znx_big_sub_small_a_impl( - module: &Module, + module: &Module, res: &mut R, res_col: usize, a: &A, @@ -450,13 +413,13 @@ unsafe impl VecZnxBigSubSmallAImpl for FFT64 { b: &B, b_col: usize, ) where - R: VecZnxBigToMut, + R: VecZnxBigToMut, A: VecZnxToRef, - B: VecZnxBigToRef, + B: VecZnxBigToRef, { let a: VecZnx<&[u8]> = a.to_ref(); - let b: VecZnxBig<&[u8], FFT64> = b.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + let b: VecZnxBig<&[u8], Self> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut(); #[cfg(debug_assertions)] { @@ -481,15 +444,15 @@ unsafe impl VecZnxBigSubSmallAImpl for FFT64 { } } -unsafe impl VecZnxBigSubSmallAInplaceImpl for FFT64 { +unsafe impl VecZnxBigSubSmallAInplaceImpl for FFT64 { /// Subtracts `a` from `res` and stores the result on `res`. 
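A second note on a pattern shared by all of these method bodies: the arguments are first converted with `to_ref()` / `to_mut()` into byte-slice-backed views (for example `VecZnxBig<&[u8], Self>` and `VecZnxBig<&mut [u8], Self>`), debug-only shape assertions run, and only then is the spqlios FFI called. The following self-contained sketch shows that owned-versus-borrowed-bytes container shape with invented names; it is not the real poulpy-hal layout code, only a simplified mirror of the idea.

```rust
// Standalone sketch of the "owned or borrowed bytes" container pattern used by
// VecZnxBig / VecZnxDft in this diff. Names are illustrative only.
use std::marker::PhantomData;

/// A vector-like container whose storage `D` can be any byte buffer
/// (`Vec<u8>`, `&[u8]` or `&mut [u8]`), tagged with a backend marker `B`.
struct ToyVec<D, B> {
    data: D,
    n: usize,
    _backend: PhantomData<B>,
}

/// Read-only, byte-slice-backed view (mirrors the role of a ToRef trait).
trait ToyToRef<B> {
    fn to_ref(&self) -> ToyVec<&[u8], B>;
}

/// Mutable, byte-slice-backed view (mirrors the role of a ToMut trait).
trait ToyToMut<B> {
    fn to_mut(&mut self) -> ToyVec<&mut [u8], B>;
}

impl<D: AsRef<[u8]>, B> ToyToRef<B> for ToyVec<D, B> {
    fn to_ref(&self) -> ToyVec<&[u8], B> {
        ToyVec { data: self.data.as_ref(), n: self.n, _backend: PhantomData }
    }
}

impl<D: AsRef<[u8]> + AsMut<[u8]>, B> ToyToMut<B> for ToyVec<D, B> {
    fn to_mut(&mut self) -> ToyVec<&mut [u8], B> {
        ToyVec { data: self.data.as_mut(), n: self.n, _backend: PhantomData }
    }
}

/// Unit marker standing in for a backend type such as FFT64 or NTT120.
struct ToyBackend;

fn main() {
    let _marker = ToyBackend;
    // Owned storage; a kernel would only ever see the borrowed views.
    let mut owned: ToyVec<Vec<u8>, ToyBackend> = ToyVec {
        data: vec![0u8; 4 * 8],
        n: 4,
        _backend: PhantomData,
    };

    let view = owned.to_ref();
    assert_eq!(view.data.len(), 32);

    {
        let view_mut = owned.to_mut();
        // A backend kernel would reinterpret these bytes as coefficient words.
        view_mut.data[0] = 1;
    }
    assert_eq!(owned.data[0], 1);
    assert_eq!(owned.n, 4);
}
```

Keeping the storage parameter generic is what lets one set of FFI wrappers serve owned buffers and scratch-backed slices alike.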
- fn vec_znx_big_sub_small_a_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_small_a_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where - R: VecZnxBigToMut, + R: VecZnxBigToMut, A: VecZnxToRef, { let a: VecZnx<&[u8]> = a.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut(); #[cfg(debug_assertions)] { @@ -512,10 +475,10 @@ unsafe impl VecZnxBigSubSmallAInplaceImpl for FFT64 { } } -unsafe impl VecZnxBigSubSmallBImpl for FFT64 { +unsafe impl VecZnxBigSubSmallBImpl for FFT64 { /// Subtracts `b` from `a` and stores the result on `c`. fn vec_znx_big_sub_small_b_impl( - module: &Module, + module: &Module, res: &mut R, res_col: usize, a: &A, @@ -523,13 +486,13 @@ unsafe impl VecZnxBigSubSmallBImpl for FFT64 { b: &B, b_col: usize, ) where - R: VecZnxBigToMut, - A: VecZnxBigToRef, + R: VecZnxBigToMut, + A: VecZnxBigToRef, B: VecZnxToRef, { - let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let a: VecZnxBig<&[u8], Self> = a.to_ref(); let b: VecZnx<&[u8]> = b.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut(); #[cfg(debug_assertions)] { @@ -554,15 +517,15 @@ unsafe impl VecZnxBigSubSmallBImpl for FFT64 { } } -unsafe impl VecZnxBigSubSmallBInplaceImpl for FFT64 { +unsafe impl VecZnxBigSubSmallBInplaceImpl for FFT64 { /// Subtracts `res` from `a` and stores the result on `res`. - fn vec_znx_big_sub_small_b_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_small_b_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where - R: VecZnxBigToMut, + R: VecZnxBigToMut, A: VecZnxToRef, { let a: VecZnx<&[u8]> = a.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut(); #[cfg(debug_assertions)] { @@ -585,12 +548,12 @@ unsafe impl VecZnxBigSubSmallBInplaceImpl for FFT64 { } } -unsafe impl VecZnxBigNegateInplaceImpl for FFT64 { - fn vec_znx_big_negate_inplace_impl(module: &Module, a: &mut A, a_col: usize) +unsafe impl VecZnxBigNegateInplaceImpl for FFT64 { + fn vec_znx_big_negate_inplace_impl(module: &Module, a: &mut A, a_col: usize) where - A: VecZnxBigToMut, + A: VecZnxBigToMut, { - let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut(); + let mut a: VecZnxBig<&mut [u8], Self> = a.to_mut(); unsafe { vec_znx::vec_znx_negate( module.ptr(), @@ -605,26 +568,29 @@ unsafe impl VecZnxBigNegateInplaceImpl for FFT64 { } } -unsafe impl VecZnxBigNormalizeTmpBytesImpl for FFT64 { - fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module, n: usize) -> usize { +unsafe impl VecZnxBigNormalizeTmpBytesImpl for FFT64 { + fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module, n: usize) -> usize { unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr(), n as u64) as usize } } } -unsafe impl VecZnxBigNormalizeImpl for FFT64 { +unsafe impl VecZnxBigNormalizeImpl for FFT64 +where + Self: TakeSliceImpl, +{ fn vec_znx_big_normalize_impl( - module: &Module, + module: &Module, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where R: VecZnxToMut, - A: VecZnxBigToRef, + A: VecZnxBigToRef, { - let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let a: VecZnxBig<&[u8], Self> = a.to_ref(); let mut res: VecZnx<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] @@ -650,15 +616,15 @@ unsafe impl 
VecZnxBigNormalizeImpl for FFT64 { } } -unsafe impl VecZnxBigAutomorphismImpl for FFT64 { +unsafe impl VecZnxBigAutomorphismImpl for FFT64 { /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. - fn vec_znx_big_automorphism_impl(module: &Module, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_automorphism_impl(module: &Module, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) where - R: VecZnxBigToMut, - A: VecZnxBigToRef, + R: VecZnxBigToMut, + A: VecZnxBigToRef, { - let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + let a: VecZnxBig<&[u8], Self> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut(); #[cfg(debug_assertions)] { @@ -679,13 +645,13 @@ unsafe impl VecZnxBigAutomorphismImpl for FFT64 { } } -unsafe impl VecZnxBigAutomorphismInplaceImpl for FFT64 { +unsafe impl VecZnxBigAutomorphismInplaceImpl for FFT64 { /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. - fn vec_znx_big_automorphism_inplace_impl(module: &Module, k: i64, a: &mut A, a_col: usize) + fn vec_znx_big_automorphism_inplace_impl(module: &Module, k: i64, a: &mut A, a_col: usize) where - A: VecZnxBigToMut, + A: VecZnxBigToMut, { - let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut(); + let mut a: VecZnxBig<&mut [u8], Self> = a.to_mut(); unsafe { vec_znx::vec_znx_automorphism( module.ptr(), @@ -700,38 +666,3 @@ unsafe impl VecZnxBigAutomorphismInplaceImpl for FFT64 { } } } - -impl fmt::Display for VecZnxBig { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!( - f, - "VecZnxBig(n={}, cols={}, size={})", - self.n, self.cols, self.size - )?; - - for col in 0..self.cols { - writeln!(f, "Column {}:", col)?; - for size in 0..self.size { - let coeffs = self.at(col, size); - write!(f, " Size {}: [", size)?; - - let max_show = 100; - let show_count = coeffs.len().min(max_show); - - for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{}", coeff)?; - } - - if coeffs.len() > max_show { - write!(f, ", ... 
({} more)", coeffs.len() - max_show)?; - } - - writeln!(f, "]")?; - } - } - Ok(()) - } -} diff --git a/poulpy-backend/src/implementation/cpu_spqlios/vec_znx_dft_fft64.rs b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs similarity index 63% rename from poulpy-backend/src/implementation/cpu_spqlios/vec_znx_dft_fft64.rs rename to poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs index 182f89a..669485b 100644 --- a/poulpy-backend/src/implementation/cpu_spqlios/vec_znx_dft_fft64.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs @@ -1,78 +1,57 @@ -use std::fmt; - -use crate::{ - hal::{ - api::{TakeSlice, VecZnxDftToVecZnxBigTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, - layouts::{ - Data, DataRef, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftBytesOf, VecZnxDftOwned, - VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, - }, - oep::{ - VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftCopyImpl, - VecZnxDftFromBytesImpl, VecZnxDftFromVecZnxImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, - VecZnxDftSubImpl, VecZnxDftToVecZnxBigConsumeImpl, VecZnxDftToVecZnxBigImpl, VecZnxDftToVecZnxBigTmpAImpl, - VecZnxDftToVecZnxBigTmpBytesImpl, VecZnxDftZeroImpl, - }, +use poulpy_hal::{ + api::{TakeSlice, VecZnxDftToVecZnxBigTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, + layouts::{ + Backend, Data, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, + VecZnxDftToRef, VecZnxToRef, }, - implementation::cpu_spqlios::{ - ffi::{vec_znx_big, vec_znx_dft}, - module_fft64::FFT64, + oep::{ + VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftCopyImpl, + VecZnxDftFromBytesImpl, VecZnxDftFromVecZnxImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl, + VecZnxDftToVecZnxBigConsumeImpl, VecZnxDftToVecZnxBigImpl, VecZnxDftToVecZnxBigTmpAImpl, + VecZnxDftToVecZnxBigTmpBytesImpl, VecZnxDftZeroImpl, }, }; -const VEC_ZNX_DFT_FFT64_WORDSIZE: usize = 1; +use crate::cpu_spqlios::{ + FFT64, + ffi::{vec_znx_big, vec_znx_dft}, +}; -impl ZnxSliceSize for VecZnxDft { - fn sl(&self) -> usize { - VEC_ZNX_DFT_FFT64_WORDSIZE * self.n() * self.cols() - } -} - -impl VecZnxDftBytesOf for VecZnxDft { - fn bytes_of(n: usize, cols: usize, size: usize) -> usize { - VEC_ZNX_DFT_FFT64_WORDSIZE * n * cols * size * size_of::() - } -} - -impl ZnxView for VecZnxDft { - type Scalar = f64; -} - -unsafe impl VecZnxDftFromBytesImpl for FFT64 { - fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned { +unsafe impl VecZnxDftFromBytesImpl for FFT64 { + fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned { VecZnxDft::, FFT64>::from_bytes(n, cols, size, bytes) } } -unsafe impl VecZnxDftAllocBytesImpl for FFT64 { +unsafe impl VecZnxDftAllocBytesImpl for FFT64 { fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { - VecZnxDft::, FFT64>::bytes_of(n, cols, size) + FFT64::layout_prep_word_count() * n * cols * size * size_of::<::ScalarPrep>() } } -unsafe impl VecZnxDftAllocImpl for FFT64 { - fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned { +unsafe impl VecZnxDftAllocImpl for FFT64 { + fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned { VecZnxDftOwned::alloc(n, cols, size) } } -unsafe impl VecZnxDftToVecZnxBigTmpBytesImpl for FFT64 { - fn 
vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module, n: usize) -> usize { +unsafe impl VecZnxDftToVecZnxBigTmpBytesImpl for FFT64 { + fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module, n: usize) -> usize { unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(module.ptr(), n as u64) as usize } } } -unsafe impl VecZnxDftToVecZnxBigImpl for FFT64 { +unsafe impl VecZnxDftToVecZnxBigImpl for FFT64 { fn vec_znx_dft_to_vec_znx_big_impl( - module: &Module, + module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where - R: VecZnxBigToMut, - A: VecZnxDftToRef, + R: VecZnxBigToMut, + A: VecZnxDftToRef, { let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); let a: VecZnxDft<&[u8], FFT64> = a.to_ref(); @@ -104,11 +83,11 @@ unsafe impl VecZnxDftToVecZnxBigImpl for FFT64 { } } -unsafe impl VecZnxDftToVecZnxBigTmpAImpl for FFT64 { - fn vec_znx_dft_to_vec_znx_big_tmp_a_impl(module: &Module, res: &mut R, res_col: usize, a: &mut A, a_col: usize) +unsafe impl VecZnxDftToVecZnxBigTmpAImpl for FFT64 { + fn vec_znx_dft_to_vec_znx_big_tmp_a_impl(module: &Module, res: &mut R, res_col: usize, a: &mut A, a_col: usize) where - R: VecZnxBigToMut, - A: VecZnxDftToMut, + R: VecZnxBigToMut, + A: VecZnxDftToMut, { let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut(); @@ -132,10 +111,10 @@ unsafe impl VecZnxDftToVecZnxBigTmpAImpl for FFT64 { } } -unsafe impl VecZnxDftToVecZnxBigConsumeImpl for FFT64 { - fn vec_znx_dft_to_vec_znx_big_consume_impl(module: &Module, mut a: VecZnxDft) -> VecZnxBig +unsafe impl VecZnxDftToVecZnxBigConsumeImpl for FFT64 { + fn vec_znx_dft_to_vec_znx_big_consume_impl(module: &Module, mut a: VecZnxDft) -> VecZnxBig where - VecZnxDft: VecZnxDftToMut, + VecZnxDft: VecZnxDftToMut, { let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut(); @@ -158,9 +137,9 @@ unsafe impl VecZnxDftToVecZnxBigConsumeImpl for FFT64 { } } -unsafe impl VecZnxDftFromVecZnxImpl for FFT64 { +unsafe impl VecZnxDftFromVecZnxImpl for FFT64 { fn vec_znx_dft_from_vec_znx_impl( - module: &Module, + module: &Module, step: usize, offset: usize, res: &mut R, @@ -168,7 +147,7 @@ unsafe impl VecZnxDftFromVecZnxImpl for FFT64 { a: &A, a_col: usize, ) where - R: VecZnxDftToMut, + R: VecZnxDftToMut, A: VecZnxToRef, { let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); @@ -196,19 +175,12 @@ unsafe impl VecZnxDftFromVecZnxImpl for FFT64 { } } -unsafe impl VecZnxDftAddImpl for FFT64 { - fn vec_znx_dft_add_impl( - module: &Module, - res: &mut R, - res_col: usize, - a: &A, - a_col: usize, - b: &D, - b_col: usize, - ) where - R: VecZnxDftToMut, - A: VecZnxDftToRef, - D: VecZnxDftToRef, +unsafe impl VecZnxDftAddImpl for FFT64 { + fn vec_znx_dft_add_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + D: VecZnxDftToRef, { let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); @@ -235,11 +207,11 @@ unsafe impl VecZnxDftAddImpl for FFT64 { } } -unsafe impl VecZnxDftAddInplaceImpl for FFT64 { - fn vec_znx_dft_add_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxDftAddInplaceImpl for FFT64 { + fn vec_znx_dft_add_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where - R: VecZnxDftToMut, - A: VecZnxDftToRef, + R: VecZnxDftToMut, + A: VecZnxDftToRef, { let mut res_mut: 
VecZnxDft<&mut [u8], FFT64> = res.to_mut(); let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); @@ -262,19 +234,12 @@ unsafe impl VecZnxDftAddInplaceImpl for FFT64 { } } -unsafe impl VecZnxDftSubImpl for FFT64 { - fn vec_znx_dft_sub_impl( - module: &Module, - res: &mut R, - res_col: usize, - a: &A, - a_col: usize, - b: &D, - b_col: usize, - ) where - R: VecZnxDftToMut, - A: VecZnxDftToRef, - D: VecZnxDftToRef, +unsafe impl VecZnxDftSubImpl for FFT64 { + fn vec_znx_dft_sub_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + D: VecZnxDftToRef, { let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); @@ -301,11 +266,11 @@ unsafe impl VecZnxDftSubImpl for FFT64 { } } -unsafe impl VecZnxDftSubABInplaceImpl for FFT64 { - fn vec_znx_dft_sub_ab_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxDftSubABInplaceImpl for FFT64 { + fn vec_znx_dft_sub_ab_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where - R: VecZnxDftToMut, - A: VecZnxDftToRef, + R: VecZnxDftToMut, + A: VecZnxDftToRef, { let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); @@ -328,11 +293,11 @@ unsafe impl VecZnxDftSubABInplaceImpl for FFT64 { } } -unsafe impl VecZnxDftSubBAInplaceImpl for FFT64 { - fn vec_znx_dft_sub_ba_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxDftSubBAInplaceImpl for FFT64 { + fn vec_znx_dft_sub_ba_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where - R: VecZnxDftToMut, - A: VecZnxDftToRef, + R: VecZnxDftToMut, + A: VecZnxDftToRef, { let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); @@ -355,9 +320,9 @@ unsafe impl VecZnxDftSubBAInplaceImpl for FFT64 { } } -unsafe impl VecZnxDftCopyImpl for FFT64 { +unsafe impl VecZnxDftCopyImpl for FFT64 { fn vec_znx_dft_copy_impl( - _module: &Module, + _module: &Module, step: usize, offset: usize, res: &mut R, @@ -365,8 +330,8 @@ unsafe impl VecZnxDftCopyImpl for FFT64 { a: &A, a_col: usize, ) where - R: VecZnxDftToMut, - A: VecZnxDftToRef, + R: VecZnxDftToMut, + A: VecZnxDftToRef, { let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); @@ -388,46 +353,11 @@ unsafe impl VecZnxDftCopyImpl for FFT64 { } } -unsafe impl VecZnxDftZeroImpl for FFT64 { - fn vec_znx_dft_zero_impl(_module: &Module, res: &mut R) +unsafe impl VecZnxDftZeroImpl for FFT64 { + fn vec_znx_dft_zero_impl(_module: &Module, res: &mut R) where - R: VecZnxDftToMut, + R: VecZnxDftToMut, { res.to_mut().data.fill(0); } } - -impl fmt::Display for VecZnxDft { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!( - f, - "VecZnxDft(n={}, cols={}, size={})", - self.n, self.cols, self.size - )?; - - for col in 0..self.cols { - writeln!(f, "Column {}:", col)?; - for size in 0..self.size { - let coeffs = self.at(col, size); - write!(f, " Size {}: [", size)?; - - let max_show = 100; - let show_count = coeffs.len().min(max_show); - - for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{}", coeff)?; - } - - if coeffs.len() > max_show { - write!(f, ", ... 
({} more)", coeffs.len() - max_show)?; - } - - writeln!(f, "]")?; - } - } - Ok(()) - } -} diff --git a/poulpy-backend/src/implementation/cpu_spqlios/vmp_pmat_fft64.rs b/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs similarity index 86% rename from poulpy-backend/src/implementation/cpu_spqlios/vmp_pmat_fft64.rs rename to poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs index 8730880..68c7ccd 100644 --- a/poulpy-backend/src/implementation/cpu_spqlios/vmp_pmat_fft64.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs @@ -1,39 +1,23 @@ -use crate::{ - hal::{ - api::{TakeSlice, VmpApplyTmpBytes, VmpPrepareTmpBytes, ZnxInfos, ZnxView, ZnxViewMut}, - layouts::{ - DataRef, MatZnx, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatBytesOf, - VmpPMatOwned, VmpPMatToMut, VmpPMatToRef, - }, - oep::{ - VmpApplyAddImpl, VmpApplyAddTmpBytesImpl, VmpApplyImpl, VmpApplyTmpBytesImpl, VmpPMatAllocBytesImpl, - VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl, - }, +use poulpy_hal::{ + api::{TakeSlice, VmpApplyTmpBytes, VmpPrepareTmpBytes, ZnxInfos, ZnxView, ZnxViewMut}, + layouts::{ + Backend, MatZnx, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatOwned, + VmpPMatToMut, VmpPMatToRef, }, - implementation::cpu_spqlios::{ - ffi::{vec_znx_dft::vec_znx_dft_t, vmp}, - module_fft64::FFT64, + oep::{ + VmpApplyAddImpl, VmpApplyAddTmpBytesImpl, VmpApplyImpl, VmpApplyTmpBytesImpl, VmpPMatAllocBytesImpl, VmpPMatAllocImpl, + VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl, }, }; -const VMP_PMAT_FFT64_WORDSIZE: usize = 1; +use crate::cpu_spqlios::{ + FFT64, + ffi::{vec_znx_dft::vec_znx_dft_t, vmp}, +}; -impl ZnxView for VmpPMat { - type Scalar = f64; -} - -impl VmpPMatBytesOf for FFT64 { - fn vmp_pmat_bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - VMP_PMAT_FFT64_WORDSIZE * n * rows * cols_in * cols_out * size * size_of::() - } -} - -unsafe impl VmpPMatAllocBytesImpl for FFT64 -where - FFT64: VmpPMatBytesOf, -{ +unsafe impl VmpPMatAllocBytesImpl for FFT64 { fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - FFT64::vmp_pmat_bytes_of(n, rows, cols_in, cols_out, size) + FFT64::layout_prep_word_count() * n * rows * cols_in * cols_out * size * size_of::() } } @@ -251,8 +235,6 @@ unsafe impl VmpApplyAddImpl for FFT64 { #[cfg(debug_assertions)] { - use crate::hal::api::ZnxInfos; - assert_eq!(b.n(), res.n()); assert_eq!(a.n(), res.n()); assert_eq!( diff --git a/poulpy-backend/src/cpu_spqlios/mod.rs b/poulpy-backend/src/cpu_spqlios/mod.rs new file mode 100644 index 0000000..40baf00 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/mod.rs @@ -0,0 +1,9 @@ +mod ffi; +mod fft64; +mod ntt120; + +#[cfg(test)] +mod test; + +pub use fft64::*; +pub use ntt120::*; diff --git a/poulpy-backend/src/cpu_spqlios/ntt120/mod.rs b/poulpy-backend/src/cpu_spqlios/ntt120/mod.rs new file mode 100644 index 0000000..cd4cb13 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/ntt120/mod.rs @@ -0,0 +1,7 @@ +mod module; +mod svp_ppol; +mod vec_znx_big; +mod vec_znx_dft; +mod vmp_pmat; + +pub use module::NTT120; diff --git a/poulpy-backend/src/implementation/cpu_spqlios/module_ntt120.rs b/poulpy-backend/src/cpu_spqlios/ntt120/module.rs similarity index 53% rename from poulpy-backend/src/implementation/cpu_spqlios/module_ntt120.rs rename to poulpy-backend/src/cpu_spqlios/ntt120/module.rs index 94e0bdb..65aa143 100644 --- 
a/poulpy-backend/src/implementation/cpu_spqlios/module_ntt120.rs +++ b/poulpy-backend/src/cpu_spqlios/ntt120/module.rs @@ -1,25 +1,29 @@ use std::ptr::NonNull; -use crate::{ - hal::{ - layouts::{Backend, Module}, - oep::ModuleNewImpl, - }, - implementation::cpu_spqlios::{ - CPUAVX, - ffi::module::{MODULE, delete_module_info, new_module_info}, - }, +use poulpy_hal::{ + layouts::{Backend, Module}, + oep::ModuleNewImpl, }; +use crate::cpu_spqlios::ffi::module::{MODULE, delete_module_info, new_module_info}; + pub struct NTT120; -impl CPUAVX for NTT120 {} - impl Backend for NTT120 { + type ScalarPrep = i64; + type ScalarBig = i128; type Handle = MODULE; unsafe fn destroy(handle: NonNull) { unsafe { delete_module_info(handle.as_ptr()) } } + + fn layout_big_word_count() -> usize { + 4 + } + + fn layout_prep_word_count() -> usize { + 1 + } } unsafe impl ModuleNewImpl for NTT120 { diff --git a/poulpy-backend/src/cpu_spqlios/ntt120/svp_ppol.rs b/poulpy-backend/src/cpu_spqlios/ntt120/svp_ppol.rs new file mode 100644 index 0000000..c98237a --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/ntt120/svp_ppol.rs @@ -0,0 +1,24 @@ +use poulpy_hal::{ + layouts::{Backend, SvpPPolOwned}, + oep::{SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl}, +}; + +use crate::cpu_spqlios::NTT120; + +unsafe impl SvpPPolFromBytesImpl for NTT120 { + fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec) -> SvpPPolOwned { + SvpPPolOwned::from_bytes(n, cols, bytes) + } +} + +unsafe impl SvpPPolAllocImpl for NTT120 { + fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned { + SvpPPolOwned::alloc(n, cols) + } +} + +unsafe impl SvpPPolAllocBytesImpl for NTT120 { + fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize { + NTT120::layout_prep_word_count() * n * cols * size_of::() + } +} diff --git a/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_big.rs b/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_big.rs new file mode 100644 index 0000000..715b432 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_big.rs @@ -0,0 +1,9 @@ +use poulpy_hal::{layouts::Backend, oep::VecZnxBigAllocBytesImpl}; + +use crate::cpu_spqlios::NTT120; + +unsafe impl VecZnxBigAllocBytesImpl for NTT120 { + fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + NTT120::layout_big_word_count() * n * cols * size * size_of::() + } +} diff --git a/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_dft.rs b/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_dft.rs new file mode 100644 index 0000000..53dd24f --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_dft.rs @@ -0,0 +1,18 @@ +use poulpy_hal::{ + layouts::{Backend, VecZnxDftOwned}, + oep::{VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl}, +}; + +use crate::cpu_spqlios::NTT120; + +unsafe impl VecZnxDftAllocBytesImpl for NTT120 { + fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + NTT120::layout_prep_word_count() * n * cols * size * size_of::() + } +} + +unsafe impl VecZnxDftAllocImpl for NTT120 { + fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned { + VecZnxDftOwned::alloc(n, cols, size) + } +} diff --git a/poulpy-backend/src/cpu_spqlios/ntt120/vmp_pmat.rs b/poulpy-backend/src/cpu_spqlios/ntt120/vmp_pmat.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/ntt120/vmp_pmat.rs @@ -0,0 +1 @@ + diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/.clang-format b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/.clang-format new 
file mode 100644 index 0000000..120c0ac --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/.clang-format @@ -0,0 +1,14 @@ +# Use the Google style in this project. +BasedOnStyle: Google + +# Some folks prefer to write "int& foo" while others prefer "int &foo". The +# Google Style Guide only asks for consistency within a project, we chose +# "int& foo" for this project: +DerivePointerAlignment: false +PointerAlignment: Left + +# The Google Style Guide only asks for consistency w.r.t. "east const" vs. +# "const west" alignment of cv-qualifiers. In this project we use "east const". +QualifierAlignment: Left + +ColumnLimit: 120 diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/.github/workflows/auto-release.yml b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/.github/workflows/auto-release.yml new file mode 100644 index 0000000..9f9203a --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/.github/workflows/auto-release.yml @@ -0,0 +1,20 @@ +name: Auto-Release + +on: + workflow_dispatch: + push: + branches: [ "main" ] + +jobs: + build: + name: Auto-Release + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 3 +# sparse-checkout: manifest.yaml scripts/auto-release.sh + + - run: + ${{github.workspace}}/scripts/auto-release.sh diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/.gitignore b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/.gitignore new file mode 100644 index 0000000..6adb7f2 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/.gitignore @@ -0,0 +1,6 @@ +cmake-build-* +.idea + +build +.vscode +.*.sh \ No newline at end of file diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/CMakeLists.txt b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/CMakeLists.txt new file mode 100644 index 0000000..5711d8e --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/CMakeLists.txt @@ -0,0 +1,69 @@ +cmake_minimum_required(VERSION 3.8) +project(spqlios) + +# read the current version from the manifest file +file(READ "manifest.yaml" manifest) +string(REGEX MATCH "version: +(([0-9]+)\\.([0-9]+)\\.([0-9]+))" SPQLIOS_VERSION_BLAH ${manifest}) +#message(STATUS "Version: ${SPQLIOS_VERSION_BLAH}") +set(SPQLIOS_VERSION ${CMAKE_MATCH_1}) +set(SPQLIOS_VERSION_MAJOR ${CMAKE_MATCH_2}) +set(SPQLIOS_VERSION_MINOR ${CMAKE_MATCH_3}) +set(SPQLIOS_VERSION_PATCH ${CMAKE_MATCH_4}) +message(STATUS "Compiling spqlios-fft version: ${SPQLIOS_VERSION_MAJOR}.${SPQLIOS_VERSION_MINOR}.${SPQLIOS_VERSION_PATCH}") + +#set(ENABLE_SPQLIOS_F128 ON CACHE BOOL "Enable float128 via libquadmath") +set(WARNING_PARANOID ON CACHE BOOL "Treat all warnings as errors") +set(ENABLE_TESTING ON CACHE BOOL "Compiles unittests and integration tests") +set(DEVMODE_INSTALL OFF CACHE BOOL "Install private headers and testlib (mainly for CI)") + +if (NOT CMAKE_BUILD_TYPE OR CMAKE_BUILD_TYPE STREQUAL "") + set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Build type: Release or Debug" FORCE) +endif() +message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") + +if (WARNING_PARANOID) + add_compile_options(-Wall -Werror -Wno-unused-command-line-argument) +endif() + +message(STATUS "CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}") +message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") +message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}") + +if (CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)") + set(X86 ON) + set(AARCH64 OFF) +else () + set(X86 OFF) + # set(ENABLE_SPQLIOS_F128 OFF) # 
float128 are only supported for x86 targets +endif () +if (CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64)|(arm64)") + set(AARCH64 ON) +endif () + +if (CMAKE_SYSTEM_NAME MATCHES "(Windows)|(MSYS)") + set(WIN32 ON) +endif () +if (WIN32) + #overrides for win32 + set(X86 OFF) + set(AARCH64 OFF) + set(X86_WIN32 ON) +else() + set(X86_WIN32 OFF) + set(WIN32 OFF) +endif (WIN32) + +message(STATUS "--> WIN32: ${WIN32}") +message(STATUS "--> X86_WIN32: ${X86_WIN32}") +message(STATUS "--> X86_LINUX: ${X86}") +message(STATUS "--> AARCH64: ${AARCH64}") + +# compiles the main library in spqlios +add_subdirectory(spqlios) + +# compiles and activates unittests and itests +if (${ENABLE_TESTING}) + enable_testing() + add_subdirectory(test) +endif() + diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/CONTRIBUTING.md b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/CONTRIBUTING.md new file mode 100644 index 0000000..a30c304 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/CONTRIBUTING.md @@ -0,0 +1,77 @@ +# Contributing to SPQlios-fft + +The spqlios-fft team encourages contributions. +We encourage users to fix bugs, improve the documentation, write tests and to enhance the code, or ask for new features. +We encourage researchers to contribute with implementations of their FFT or NTT algorithms. +In the following we are trying to give some guidance on how to contribute effectively. + +## Communication ## + +Communication in the spqlios-fft project happens mainly on [GitHub](https://github.com/tfhe/spqlios-fft/issues). + +All communications are public, so please make sure to maintain professional behaviour in +all published comments. See [Code of Conduct](https://www.contributor-covenant.org/version/2/1/code_of_conduct/) for +guidelines. + +## Reporting Bugs or Requesting features ## + +Bug should be filed at [https://github.com/tfhe/spqlios-fft/issues](https://github.com/tfhe/spqlios-fft/issues). + +Features can also be requested there, in this case, please ensure that the features you request are self-contained, +easy to define, and generic enough to be used in different use-cases. Please provide an example of use-cases if +possible. + +## Setting up topic branches and generating pull requests + +This section applies to people that already have write access to the repository. Specific instructions for pull-requests +from public forks will be given later. + +To implement some changes, please follow these steps: + +- Create a "topic branch". Usually, the branch name should be `username/small-title` + or better `username/issuenumber-small-title` where `issuenumber` is the number of + the github issue number that is tackled. +- Push any needed commits to your branch. Make sure it compiles in `CMAKE_BUILD_TYPE=Debug` and `=Release`, with `-DWARNING_PARANOID=ON`. +- When the branch is nearly ready for review, please open a pull request, and add the label `check-on-arm` +- Do as many commits as necessary until all CI checks pass and all PR comments have been resolved. + + > _During the process, you may optionnally use `git rebase -i` to clean up your commit history. If you elect to do so, + please at the very least make sure that nobody else is working or has forked from your branch: the conflicts it would generate + and the human hours to fix them are not worth it. `Git merge` remains the preferred option._ + +- Finally, when all reviews are positive and all CI checks pass, you may merge your branch via the github webpage. 
+ +### Keep your pull requests limited to a single issue + +Pull requests should be as small/atomic as possible. + +### Coding Conventions + +* Please make sure that your code is formatted according to the `.clang-format` file and + that all files end with a newline character. +* Please make sure that all the functions declared in the public api have relevant doxygen comments. + Preferably, functions in the private apis should also contain a brief doxygen description. + +### Versions and History + +* **Stable API** The project uses semantic versioning on the functions that are listed as `stable` in the documentation. A version has + the form `x.y.z` + * a patch release that increments `z` does not modify the stable API. + * a minor release that increments `y` adds a new feature to the stable API. + * In the unlikely case where we need to change or remove a feature, we will trigger a major release that + increments `x`. + + > _If any, we will mark those features as deprecated at least six months before the major release._ + +* **Experimental API** Features that are not part of the stable section in the documentation are experimental features: you may test them at + your own risk, + but keep in mind that semantic versioning does not apply to them. + +> _If you have a use-case that uses an experimental feature, we encourage +> you to tell us about it, so that this feature reaches to the stable section faster!_ + +* **Version history** The current version is reported in `manifest.yaml`, any change of version comes up with a tag on the main branch, and the history between releases is summarized in `Changelog.md`. It is the main source of truth for anyone who wishes to + get insight about + the history of the repository (not the commit graph). + +> Note: _The commit graph of git is for git's internal use only. Its main purpose is to reduce potential merge conflicts to a minimum, even in scenario where multiple features are developped in parallel: it may therefore be non-linear. If, as humans, we like to see a linear history, please read `Changelog.md` instead!_ diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/Changelog.md b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/Changelog.md new file mode 100644 index 0000000..5c0d2ea --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/Changelog.md @@ -0,0 +1,18 @@ +# Changelog + +All notable changes to this project will be documented in this file. +this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [2.0.0] - 2024-08-21 + +- Initial release of the `vec_znx` (except convolution products), `vec_rnx` and `zn` apis. +- Hardware acceleration available: AVX2 (most parts) +- APIs are documented in the wiki and are in "beta mode": during the 2.x -> 3.x transition, functions whose API is satisfactory in test projects will pass in "stable mode". + +## [1.0.0] - 2023-07-18 + +- Initial release of the double precision fft on the reim and cplx backends +- Coeffs-space conversions cplx <-> znx32 and tnx32 +- FFT-space conversions cplx <-> reim4 layouts +- FFT-space multiplications on the cplx, reim and reim4 layouts. +- In this first release, the only platform supported is linux x86_64 (generic C code, and avx2/fma). It compiles on arm64, but without any acceleration. 
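Before the remaining vendored spqlios-arithmetic files, one note on the backend layout convention introduced earlier in this diff: `Backend` implementations such as NTT120 declare `ScalarPrep = i64`, `ScalarBig = i128`, `layout_prep_word_count() = 1` and `layout_big_word_count() = 4`, and the `*AllocBytesImpl` methods compute buffer sizes as `word_count * n * cols * size * size_of::<scalar>()`. The sketch below mirrors only that formula shape with a simplified stand-in trait; it is not the real poulpy-hal `Backend` definition (which also carries a `Handle` type and a `destroy` method), and the exact `size_of` targets used by the real impls are not reproduced here.

```rust
// Standalone sketch of the layout-size convention used by the *AllocBytesImpl
// methods in this diff. ToyBackendLayout is a simplified stand-in trait; the
// constants echo the NTT120 impl shown earlier, but the trait is illustrative.

/// Minimal stand-in for a Backend-like layout trait.
trait ToyBackendLayout {
    /// Scalar stored in the "prepared" (DFT/NTT) representation.
    type ScalarPrep;
    /// Scalar stored in the "big" (un-normalized) representation.
    type ScalarBig;
    /// Machine words per prepared coefficient.
    fn layout_prep_word_count() -> usize;
    /// Machine words per big coefficient.
    fn layout_big_word_count() -> usize;
}

struct ToyBackend;

impl ToyBackendLayout for ToyBackend {
    type ScalarPrep = i64;
    type ScalarBig = i128;
    fn layout_prep_word_count() -> usize { 1 }
    fn layout_big_word_count() -> usize { 4 }
}

/// Byte size of a prepared vector with `n` coefficients, `cols` columns and
/// `size` limbs, mirroring the shape of vec_znx_dft_alloc_bytes_impl.
fn prep_alloc_bytes<B: ToyBackendLayout>(n: usize, cols: usize, size: usize) -> usize {
    B::layout_prep_word_count() * n * cols * size * std::mem::size_of::<B::ScalarPrep>()
}

/// Byte size of a "big" vector, mirroring the shape of vec_znx_big_alloc_bytes_impl.
fn big_alloc_bytes<B: ToyBackendLayout>(n: usize, cols: usize, size: usize) -> usize {
    B::layout_big_word_count() * n * cols * size * std::mem::size_of::<B::ScalarBig>()
}

fn main() {
    let _marker = ToyBackend;
    let (n, cols, size) = (1024, 2, 3);
    println!("prep bytes: {}", prep_alloc_bytes::<ToyBackend>(n, cols, size));
    println!("big  bytes: {}", big_alloc_bytes::<ToyBackend>(n, cols, size));
}
```

With this convention only the per-backend word counts and scalar types differ; the allocation formulas themselves stay shared across FFT64 and NTT120.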
diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/LICENSE b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/README.md b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/README.md new file mode 100644 index 0000000..9edc19d --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/README.md @@ -0,0 +1,65 @@ +# SPQlios library + + + +The SPQlios library provides fast arithmetic for Fully Homomorphic Encryption, and other lattice constructions that arise in post quantum cryptography. + + + +Namely, it is divided into 4 sections: + +* The low-level DFT section support FFT over 64-bit floats, as well as NTT modulo one fixed 120-bit modulus. It is an upgrade of the original spqlios-fft module embedded in the TFHE library since 2016. The DFT section exposes the traditional DFT, inverse-DFT, and coefficient-wise multiplications in DFT space. +* The VEC_ZNX section exposes fast algebra over vectors of small integer polynomial modulo $X^N+1$. It proposed in particular efficient (prepared) vector-matrix products, scalar-vector products, convolution products, and element-wise products, operations that naturally occurs on gadget-decomposed Ring-LWE coordinates. +* The RNX section is a simpler variant of VEC_ZNX, to represent single polynomials modulo $X^N+1$ (over the reals or over the torus) when the coefficient precision fits on 64-bit doubles. The small vector-matrix API of the RNX section is particularly adapted to reproducing the fastest CGGI-based bootstrappings. +* The ZN section focuses over vector and matrix algebra over scalars (used by scalar LWE, or scalar key-switches, but also on non-ring schemes like Frodo, FrodoPIR, and SimplePIR). + +### A high value target for hardware accelerations + +SPQlios is more than a library, it is also a good target for hardware developers. +On one hand, the arithmetic operations that are defined in the library have a clear standalone mathematical definition. And at the same time, the amount of work in each operations is sufficiently large so that meaningful functions only require a few of these. + +This makes the SPQlios API a high value target for hardware acceleration, that targets FHE. + +### SPQLios is not an FHE library, but a huge enabler + +SPQlios itself is not an FHE library: there is no ciphertext, plaintext or key. It is a mathematical library that exposes efficient algebra over polynomials. Using the functions exposed, it is possible to quickly build efficient FHE libraries, with support for the main schemes based on Ring-LWE: BFV, BGV, CGGI, DM, CKKS. + + +## Dependencies + +The SPQLIOS-FFT library is a C library that can be compiled with a standard C compiler, and depends only on libc and libm. The API +interface can be used in a regular C code, and any other language via classical foreign APIs. + +The unittests and integration tests are in an optional part of the code, and are written in C++. These tests rely on +[```benchmark```](https://github.com/google/benchmark), and [```gtest```](https://github.com/google/googletest) libraries, and therefore require a C++17 compiler. + +Currently, the project has been tested with the gcc,g++ >= 11.3.0 compiler under Linux (x86_64). 
In the future, we plan to +extend the compatibility to other compilers, platforms and operating systems. + + +## Installation + +The library uses a classical ```cmake``` build mechanism: use ```cmake``` to create a ```build``` folder in the top level directory and run ```make``` from inside it. This assumes that the standard tool ```cmake``` is already installed on the system, and an up-to-date c++ compiler (i.e. g++ >=11.3.0) as well. + +It will compile the shared library in optimized mode, and ```make install``` install it to the desired prefix folder (by default ```/usr/local/lib```). + +If you want to choose additional compile options (i.e. other installation folder, debug mode, tests), you need to run cmake manually and pass the desired options: +``` +mkdir build +cd build +cmake ../src -CMAKE_INSTALL_PREFIX=/usr/ +make +``` +The available options are the following: + +| Variable Name | values | +| -------------------- | ------------------------------------------------------------ | +| CMAKE_INSTALL_PREFIX | */usr/local* installation folder (libs go in lib/ and headers in include/) | +| WARNING_PARANOID | All warnings are shown and treated as errors. Off by default | +| ENABLE_TESTING | Compiles unit tests and integration tests | + +------ + + + + diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/docs/api-full.svg b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/docs/api-full.svg new file mode 100644 index 0000000..8cc9743 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/docs/api-full.svg @@ -0,0 +1,416 @@ + + + +VEC_ZNX arithmetic APIFFT/NTT APIRNX APIZN APIFHE librariesBGV, BFV, CKKSFHE librariesTFHE/CGGI, DMOther PQC:NTRU,FalconOnion-PIRsOnion, spiralSPQlios APIChimeraSchemeSwitchLWE/SISFrodo, FrodoPIRPotential use-casesthat would benefit from the API diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/docs/logo-inpher1.png b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/docs/logo-inpher1.png new file mode 100644 index 0000000..ce8de01 Binary files /dev/null and b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/docs/logo-inpher1.png differ diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/docs/logo-inpher2.png b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/docs/logo-inpher2.png new file mode 100644 index 0000000..a25c87f Binary files /dev/null and b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/docs/logo-inpher2.png differ diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/docs/logo-sandboxaq-black.svg b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/docs/logo-sandboxaq-black.svg new file mode 100644 index 0000000..bb31eae --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/docs/logo-sandboxaq-black.svg @@ -0,0 +1,139 @@ + + + + + + + + + + + + +SANDBOX +AQ + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/docs/logo-sandboxaq-white.svg b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/docs/logo-sandboxaq-white.svg new file mode 100644 index 0000000..036ce5a --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/docs/logo-sandboxaq-white.svg @@ -0,0 +1,133 @@ + + + + + +SANDBOX +AQ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/manifest.yaml b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/manifest.yaml new file mode 100644 index 0000000..02235cb --- /dev/null +++ 
b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/manifest.yaml @@ -0,0 +1,2 @@ +library: spqlios-fft +version: 2.0.0 diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/scripts/auto-release.sh b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/scripts/auto-release.sh new file mode 100755 index 0000000..c31efbe --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/scripts/auto-release.sh @@ -0,0 +1,27 @@ +#!/bin/sh + +# this script generates one tag if there is a version change in manifest.yaml +cd `dirname $0`/.. +if [ "v$1" = "v-y" ]; then + echo "production mode!"; +fi +changes=`git diff HEAD~1..HEAD -- manifest.yaml | grep 'version:'` +oldversion=$(echo "$changes" | grep '^-version:' | cut '-d ' -f2) +version=$(echo "$changes" | grep '^+version:' | cut '-d ' -f2) +echo "Versions: $oldversion --> $version" +if [ "v$oldversion" = "v$version" ]; then + echo "Same version - nothing to do"; exit 0; +fi +if [ "v$1" = "v-y" ]; then + git config user.name github-actions + git config user.email github-actions@github.com + git tag -a "v$version" -m "Version $version" + git push origin "v$version" +else +cat </dev/null + rm -f "$DIR/$FNAME" 2>/dev/null + DESTDIR="$DIR/dist" cmake --install build || exit 1 + if [ -d "$DIR/dist$CI_INSTALL_PREFIX" ]; then + tar -C "$DIR/dist" -cvzf "$DIR/$FNAME" . + else + # fix since msys can mess up the paths + REAL_DEST=`find "$DIR/dist" -type d -exec test -d "{}$CI_INSTALL_PREFIX" \; -print` + echo "REAL_DEST: $REAL_DEST" + [ -d "$REAL_DEST$CI_INSTALL_PREFIX" ] && tar -C "$REAL_DEST" -cvzf "$DIR/$FNAME" . + fi + [ -f "$DIR/$FNAME" ] || { echo "failed to create $DIR/$FNAME"; exit 1; } + [ "x$CI_CREDS" = "x" ] && { echo "CI_CREDS is not set: not uploading"; exit 1; } + curl -u "$CI_CREDS" -T "$DIR/$FNAME" "$CI_REPO_URL/$FNAME" +fi + +if [ "x$1" = "xinstall" ]; then + [ "x$CI_CREDS" = "x" ] && { echo "CI_CREDS is not set: not downloading"; exit 1; } + # cleaning + rm -rf "$DESTDIR$CI_INSTALL_PREFIX"/* 2>/dev/null + rm -f "$DIR/$FNAME" 2>/dev/null + # downloading + curl -u "$CI_CREDS" -o "$DIR/$FNAME" "$CI_REPO_URL/$FNAME" + [ -f "$DIR/$FNAME" ] || { echo "failed to download $DIR/$FNAME"; exit 0; } + # installing + mkdir -p $DESTDIR + tar -C "$DESTDIR" -xvzf "$DIR/$FNAME" + exit 0 +fi diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/scripts/prepare-release b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/scripts/prepare-release new file mode 100755 index 0000000..4e4843a --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/scripts/prepare-release @@ -0,0 +1,181 @@ +#!/usr/bin/perl +## +## This script will help update manifest.yaml and Changelog.md before a release +## Any merge to master that changes the version line in manifest.yaml +## is considered as a new release. +## +## When ready to make a release, please run ./scripts/prepare-release +## and commit push the final result! 
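## ---------------------------------------------------------------------------
## Editorial note (illustration only, not part of the upstream script): the
## per-branch changefile that this script creates under .changes/<branch>.md
## and opens in the `editor` command looks like the following; the value of
## the "releasetype" line (patch, minor or major) decides which component of
## the version from manifest.yaml is bumped, and the lines after
## "## What has changed" are prepended to Changelog.md:
##
##   # Changefile for branch my-branch
##   ## Type of release (major,minor,patch)?
##   releasetype: patch
##   ## What has changed (please edit)?
##   - This has changed.
## ---------------------------------------------------------------------------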
+use File::Basename; +use Cwd 'abs_path'; + +# find its way to the root of git's repository +my $scriptsdirname = dirname(abs_path(__FILE__)); +chdir "$scriptsdirname/.."; +print "✓ Entering directory:".`pwd`; + +# ensures that the current branch is ahead of origin/main +my $diff= `git diff`; +chop $diff; +if ($diff =~ /./) { + die("ERROR: Please commit all the changes before calling the prepare-release script."); +} else { + print("✓ All changes are comitted.\n"); +} +system("git fetch origin"); +my $vcount = `git rev-list --left-right --count origin/main...HEAD`; +$vcount =~ /^([0-9]+)[ \t]*([0-9]+)$/; +if ($2>0) { + die("ERROR: the current HEAD is not ahead of origin/main\n. Please use git merge origin/main."); +} else { + print("✓ Current HEAD is up to date with origin/main.\n"); +} + +mkdir ".changes"; +my $currentbranch = `git rev-parse --abbrev-ref HEAD`; +chop $currentbranch; +$currentbranch =~ s/[^a-zA-Z._-]+/-/g; +my $changefile=".changes/$currentbranch.md"; +my $origmanifestfile=".changes/$currentbranch--manifest.yaml"; +my $origchangelogfile=".changes/$currentbranch--Changelog.md"; + +my $exit_code=system("wget -O $origmanifestfile https://raw.githubusercontent.com/tfhe/spqlios-fft/main/manifest.yaml"); +if ($exit_code!=0 or ! -f $origmanifestfile) { + die("ERROR: failed to download manifest.yaml"); +} +$exit_code=system("wget -O $origchangelogfile https://raw.githubusercontent.com/tfhe/spqlios-fft/main/Changelog.md"); +if ($exit_code!=0 or ! -f $origchangelogfile) { + die("ERROR: failed to download Changelog.md"); +} + +# read the current version (from origin/main manifest) +my $vmajor = 0; +my $vminor = 0; +my $vpatch = 0; +my $versionline = `grep '^version: ' $origmanifestfile | cut -d" " -f2`; +chop $versionline; +if (not $versionline =~ /^([0-9]+)\.([0-9]+)\.([0-9]+)$/) { + die("ERROR: invalid version in manifest file: $versionline\n"); +} else { + $vmajor = int($1); + $vminor = int($2); + $vpatch = int($3); +} +print "Version in manifest file: $vmajor.$vminor.$vpatch\n"; + +if (not -f $changefile) { + ## create a changes file + open F,">$changefile"; + print F "# Changefile for branch $currentbranch\n\n"; + print F "## Type of release (major,minor,patch)?\n\n"; + print F "releasetype: patch\n\n"; + print F "## What has changed (please edit)?\n\n"; + print F "- This has changed.\n"; + close F; +} + +system("editor $changefile"); + +# compute the new version +my $nvmajor; +my $nvminor; +my $nvpatch; +my $changelog; +my $recordchangelog=0; +open F,"$changefile"; +while ($line=) { + chop $line; + if ($recordchangelog) { + ($line =~ /^$/) and next; + $changelog .= "$line\n"; + next; + } + if ($line =~ /^releasetype *: *patch *$/) { + $nvmajor=$vmajor; + $nvminor=$vminor; + $nvpatch=$vpatch+1; + } + if ($line =~ /^releasetype *: *minor *$/) { + $nvmajor=$vmajor; + $nvminor=$vminor+1; + $nvpatch=0; + } + if ($line =~ /^releasetype *: *major *$/) { + $nvmajor=$vmajor+1; + $nvminor=0; + $nvpatch=0; + } + if ($line =~ /^## What has changed/) { + $recordchangelog=1; + } +} +close F; +print "New version: $nvmajor.$nvminor.$nvpatch\n"; +print "Changes:\n$changelog"; + +# updating manifest.yaml +open F,"manifest.yaml"; +open G,">.changes/manifest.yaml"; +while ($line=) { + if ($line =~ /^version *: */) { + print G "version: $nvmajor.$nvminor.$nvpatch\n"; + next; + } + print G $line; +} +close F; +close G; +# updating Changelog.md +open F,"$origchangelogfile"; +open G,">.changes/Changelog.md"; +print G <) { + if ($line =~ /^## +\[([0-9]+)\.([0-9]+)\.([0-9]+)\] +/) { + if ($1>$nvmajor) { 
+ die("ERROR: found larger version $1.$2.$3 in the Changelog.md\n"); + } elsif ($1<$nvmajor) { + $skip_section=0; + } elsif ($2>$nvminor) { + die("ERROR: found larger version $1.$2.$3 in the Changelog.md\n"); + } elsif ($2<$nvminor) { + $skip_section=0; + } elsif ($3>$nvpatch) { + die("ERROR: found larger version $1.$2.$3 in the Changelog.md\n"); + } elsif ($2<$nvpatch) { + $skip_section=0; + } else { + $skip_section=1; + } + } + ($skip_section) and next; + print G $line; +} +close F; +close G; + +print "-------------------------------------\n"; +print "THIS WILL BE UPDATED:\n"; +print "-------------------------------------\n"; +system("diff -u manifest.yaml .changes/manifest.yaml"); +system("diff -u Changelog.md .changes/Changelog.md"); +print "-------------------------------------\n"; +print "To proceed: press otherwise \n"; +my $bla; +$bla=; +system("cp -vf .changes/manifest.yaml manifest.yaml"); +system("cp -vf .changes/Changelog.md Changelog.md"); +system("git commit -a -m \"Update version and changelog.\""); +system("git push"); +print("✓ Changes have been committed and pushed!\n"); +print("✓ A new release will be created when this branch is merged to main.\n"); + diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/CMakeLists.txt b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/CMakeLists.txt new file mode 100644 index 0000000..75198b7 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/CMakeLists.txt @@ -0,0 +1,223 @@ +enable_language(ASM) + +# C source files that are compiled for all targets (i.e. reference code) +set(SRCS_GENERIC + commons.c + commons_private.c + coeffs/coeffs_arithmetic.c + arithmetic/vec_znx.c + arithmetic/vec_znx_dft.c + arithmetic/vector_matrix_product.c + cplx/cplx_common.c + cplx/cplx_conversions.c + cplx/cplx_fft_asserts.c + cplx/cplx_fft_ref.c + cplx/cplx_fftvec_ref.c + cplx/cplx_ifft_ref.c + cplx/spqlios_cplx_fft.c + reim4/reim4_arithmetic_ref.c + reim4/reim4_fftvec_addmul_ref.c + reim4/reim4_fftvec_conv_ref.c + reim/reim_conversions.c + reim/reim_fft_ifft.c + reim/reim_fft_ref.c + reim/reim_fftvec_ref.c + reim/reim_ifft_ref.c + reim/reim_ifft_ref.c + reim/reim_to_tnx_ref.c + q120/q120_ntt.c + q120/q120_arithmetic_ref.c + q120/q120_arithmetic_simple.c + arithmetic/scalar_vector_product.c + arithmetic/vec_znx_big.c + arithmetic/znx_small.c + arithmetic/module_api.c + arithmetic/zn_vmp_int8_ref.c + arithmetic/zn_vmp_int16_ref.c + arithmetic/zn_vmp_int32_ref.c + arithmetic/zn_vmp_ref.c + arithmetic/zn_api.c + arithmetic/zn_conversions_ref.c + arithmetic/zn_approxdecomp_ref.c + arithmetic/vec_rnx_api.c + arithmetic/vec_rnx_conversions_ref.c + arithmetic/vec_rnx_svp_ref.c + reim/reim_execute.c + cplx/cplx_execute.c + reim4/reim4_execute.c + arithmetic/vec_rnx_arithmetic.c + arithmetic/vec_rnx_approxdecomp_ref.c + arithmetic/vec_rnx_vmp_ref.c +) +# C or assembly source files compiled only on x86 targets +set(SRCS_X86 + ) +# C or assembly source files compiled only on aarch64 targets +set(SRCS_AARCH64 + cplx/cplx_fallbacks_aarch64.c + reim/reim_fallbacks_aarch64.c + reim4/reim4_fallbacks_aarch64.c + q120/q120_fallbacks_aarch64.c + reim/reim_fft_neon.c +) + +# C or assembly source files compiled only on x86: avx, avx2, fma targets +set(SRCS_FMA_C + arithmetic/vector_matrix_product_avx.c + cplx/cplx_conversions_avx2_fma.c + cplx/cplx_fft_avx2_fma.c + cplx/cplx_fft_sse.c + cplx/cplx_fftvec_avx2_fma.c + cplx/cplx_ifft_avx2_fma.c + reim4/reim4_arithmetic_avx2.c + reim4/reim4_fftvec_conv_fma.c + 
reim4/reim4_fftvec_addmul_fma.c + reim/reim_conversions_avx.c + reim/reim_fft4_avx_fma.c + reim/reim_fft8_avx_fma.c + reim/reim_ifft4_avx_fma.c + reim/reim_ifft8_avx_fma.c + reim/reim_fft_avx2.c + reim/reim_ifft_avx2.c + reim/reim_to_tnx_avx.c + reim/reim_fftvec_fma.c +) +set(SRCS_FMA_ASM + cplx/cplx_fft16_avx_fma.s + cplx/cplx_ifft16_avx_fma.s + reim/reim_fft16_avx_fma.s + reim/reim_ifft16_avx_fma.s +) +set(SRCS_FMA_WIN32_ASM + cplx/cplx_fft16_avx_fma_win32.s + cplx/cplx_ifft16_avx_fma_win32.s + reim/reim_fft16_avx_fma_win32.s + reim/reim_ifft16_avx_fma_win32.s +) +set_source_files_properties(${SRCS_FMA_C} PROPERTIES COMPILE_OPTIONS "-mfma;-mavx;-mavx2") +set_source_files_properties(${SRCS_FMA_ASM} PROPERTIES COMPILE_OPTIONS "-mfma;-mavx;-mavx2") + +# C or assembly source files compiled only on x86: avx512f/vl/dq + fma targets +set(SRCS_AVX512 + cplx/cplx_fft_avx512.c + ) +set_source_files_properties(${SRCS_AVX512} PROPERTIES COMPILE_OPTIONS "-mfma;-mavx512f;-mavx512vl;-mavx512dq") + +# C or assembly source files compiled only on x86: avx2 + bmi targets +set(SRCS_AVX2 + arithmetic/vec_znx_avx.c + coeffs/coeffs_arithmetic_avx.c + arithmetic/vec_znx_dft_avx2.c + arithmetic/zn_vmp_int8_avx.c + arithmetic/zn_vmp_int16_avx.c + arithmetic/zn_vmp_int32_avx.c + q120/q120_arithmetic_avx2.c + q120/q120_ntt_avx2.c + arithmetic/vec_rnx_arithmetic_avx.c + arithmetic/vec_rnx_approxdecomp_avx.c + arithmetic/vec_rnx_vmp_avx.c + +) +set_source_files_properties(${SRCS_AVX2} PROPERTIES COMPILE_OPTIONS "-mbmi2;-mavx2") + +# C source files on float128 via libquadmath on x86 targets targets +set(SRCS_F128 + cplx_f128/cplx_fft_f128.c + cplx_f128/cplx_fft_f128.h + ) + +# H header files containing the public API (these headers are installed) +set(HEADERSPUBLIC + commons.h + arithmetic/vec_znx_arithmetic.h + arithmetic/vec_rnx_arithmetic.h + arithmetic/zn_arithmetic.h + cplx/cplx_fft.h + reim/reim_fft.h + q120/q120_common.h + q120/q120_arithmetic.h + q120/q120_ntt.h + ) + +# H header files containing the private API (these headers are used internally) +set(HEADERSPRIVATE + commons_private.h + cplx/cplx_fft_internal.h + cplx/cplx_fft_private.h + reim4/reim4_arithmetic.h + reim4/reim4_fftvec_internal.h + reim4/reim4_fftvec_private.h + reim4/reim4_fftvec_public.h + reim/reim_fft_internal.h + reim/reim_fft_private.h + q120/q120_arithmetic_private.h + q120/q120_ntt_private.h + arithmetic/vec_znx_arithmetic.h + arithmetic/vec_rnx_arithmetic_private.h + arithmetic/vec_rnx_arithmetic_plugin.h + arithmetic/zn_arithmetic_private.h + arithmetic/zn_arithmetic_plugin.h + coeffs/coeffs_arithmetic.h + reim/reim_fft_core_template.h +) + +set(SPQLIOSSOURCES + ${SRCS_GENERIC} + ${HEADERSPUBLIC} + ${HEADERSPRIVATE} + ) +if (${X86}) + set(SPQLIOSSOURCES ${SPQLIOSSOURCES} + ${SRCS_X86} + ${SRCS_FMA_C} + ${SRCS_FMA_ASM} + ${SRCS_AVX2} + ${SRCS_AVX512} + ) +elseif (${X86_WIN32}) + set(SPQLIOSSOURCES ${SPQLIOSSOURCES} + #${SRCS_X86} + ${SRCS_FMA_C} + ${SRCS_FMA_WIN32_ASM} + ${SRCS_AVX2} + ${SRCS_AVX512} + ) +elseif (${AARCH64}) + set(SPQLIOSSOURCES ${SPQLIOSSOURCES} + ${SRCS_AARCH64} + ) +endif () + + +set(SPQLIOSLIBDEP + m # libmath depencency for cosinus/sinus functions + ) + +if (ENABLE_SPQLIOS_F128) + find_library(quadmath REQUIRED NAMES quadmath) + set(SPQLIOSSOURCES ${SPQLIOSSOURCES} ${SRCS_F128}) + set(SPQLIOSLIBDEP ${SPQLIOSLIBDEP} quadmath) +endif (ENABLE_SPQLIOS_F128) + +add_library(libspqlios-static STATIC ${SPQLIOSSOURCES}) +add_library(libspqlios SHARED ${SPQLIOSSOURCES}) +set_property(TARGET libspqlios-static PROPERTY 
POSITION_INDEPENDENT_CODE ON) +set_property(TARGET libspqlios PROPERTY OUTPUT_NAME spqlios) +set_property(TARGET libspqlios-static PROPERTY OUTPUT_NAME spqlios) +set_property(TARGET libspqlios PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET libspqlios PROPERTY SOVERSION ${SPQLIOS_VERSION_MAJOR}) +set_property(TARGET libspqlios PROPERTY VERSION ${SPQLIOS_VERSION}) +if (NOT APPLE) +target_link_options(libspqlios-static PUBLIC -Wl,--no-undefined) +target_link_options(libspqlios PUBLIC -Wl,--no-undefined) +endif() +target_link_libraries(libspqlios ${SPQLIOSLIBDEP}) +target_link_libraries(libspqlios-static ${SPQLIOSLIBDEP}) +install(TARGETS libspqlios-static) +install(TARGETS libspqlios) + +# install the public headers only +foreach (file ${HEADERSPUBLIC}) + get_filename_component(dir ${file} DIRECTORY) + install(FILES ${file} DESTINATION include/spqlios/${dir}) +endforeach () diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/module_api.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/module_api.c new file mode 100644 index 0000000..15f46c1 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/module_api.c @@ -0,0 +1,172 @@ +#include + +#include "vec_znx_arithmetic_private.h" + +static void fill_generic_virtual_table(MODULE* module) { + // TODO add default ref handler here + module->func.vec_znx_zero = vec_znx_zero_ref; + module->func.vec_znx_copy = vec_znx_copy_ref; + module->func.vec_znx_negate = vec_znx_negate_ref; + module->func.vec_znx_add = vec_znx_add_ref; + module->func.vec_znx_sub = vec_znx_sub_ref; + module->func.vec_znx_rotate = vec_znx_rotate_ref; + module->func.vec_znx_mul_xp_minus_one = vec_znx_mul_xp_minus_one_ref; + module->func.vec_znx_automorphism = vec_znx_automorphism_ref; + module->func.vec_znx_normalize_base2k = vec_znx_normalize_base2k_ref; + module->func.vec_znx_normalize_base2k_tmp_bytes = vec_znx_normalize_base2k_tmp_bytes_ref; + if (CPU_SUPPORTS("avx2")) { + // TODO add avx handlers here + module->func.vec_znx_negate = vec_znx_negate_avx; + module->func.vec_znx_add = vec_znx_add_avx; + module->func.vec_znx_sub = vec_znx_sub_avx; + } +} + +static void fill_fft64_virtual_table(MODULE* module) { + // TODO add default ref handler here + // module->func.vec_znx_dft = ...; + module->func.vec_znx_big_normalize_base2k = fft64_vec_znx_big_normalize_base2k; + module->func.vec_znx_big_normalize_base2k_tmp_bytes = fft64_vec_znx_big_normalize_base2k_tmp_bytes; + module->func.vec_znx_big_range_normalize_base2k = fft64_vec_znx_big_range_normalize_base2k; + module->func.vec_znx_big_range_normalize_base2k_tmp_bytes = fft64_vec_znx_big_range_normalize_base2k_tmp_bytes; + module->func.vec_znx_dft = fft64_vec_znx_dft; + module->func.vec_znx_idft = fft64_vec_znx_idft; + module->func.vec_dft_add = fft64_vec_dft_add; + module->func.vec_dft_sub = fft64_vec_dft_sub; + module->func.vec_znx_idft_tmp_bytes = fft64_vec_znx_idft_tmp_bytes; + module->func.vec_znx_idft_tmp_a = fft64_vec_znx_idft_tmp_a; + module->func.vec_znx_big_add = fft64_vec_znx_big_add; + module->func.vec_znx_big_add_small = fft64_vec_znx_big_add_small; + module->func.vec_znx_big_add_small2 = fft64_vec_znx_big_add_small2; + module->func.vec_znx_big_sub = fft64_vec_znx_big_sub; + module->func.vec_znx_big_sub_small_a = fft64_vec_znx_big_sub_small_a; + module->func.vec_znx_big_sub_small_b = fft64_vec_znx_big_sub_small_b; + module->func.vec_znx_big_sub_small2 = fft64_vec_znx_big_sub_small2; + module->func.vec_znx_big_rotate = 
fft64_vec_znx_big_rotate; + module->func.vec_znx_big_automorphism = fft64_vec_znx_big_automorphism; + module->func.svp_prepare = fft64_svp_prepare_ref; + module->func.svp_apply_dft = fft64_svp_apply_dft_ref; + module->func.svp_apply_dft_to_dft = fft64_svp_apply_dft_to_dft_ref; + module->func.znx_small_single_product = fft64_znx_small_single_product; + module->func.znx_small_single_product_tmp_bytes = fft64_znx_small_single_product_tmp_bytes; + module->func.vmp_prepare_contiguous = fft64_vmp_prepare_contiguous_ref; + module->func.vmp_prepare_tmp_bytes = fft64_vmp_prepare_tmp_bytes; + module->func.vmp_apply_dft = fft64_vmp_apply_dft_ref; + module->func.vmp_apply_dft_add = fft64_vmp_apply_dft_add_ref; + module->func.vmp_apply_dft_tmp_bytes = fft64_vmp_apply_dft_tmp_bytes; + module->func.vmp_apply_dft_to_dft = fft64_vmp_apply_dft_to_dft_ref; + module->func.vmp_apply_dft_to_dft_add = fft64_vmp_apply_dft_to_dft_add_ref; + module->func.vmp_apply_dft_to_dft_tmp_bytes = fft64_vmp_apply_dft_to_dft_tmp_bytes; + module->func.bytes_of_vec_znx_dft = fft64_bytes_of_vec_znx_dft; + module->func.bytes_of_vec_znx_big = fft64_bytes_of_vec_znx_big; + module->func.bytes_of_svp_ppol = fft64_bytes_of_svp_ppol; + module->func.bytes_of_vmp_pmat = fft64_bytes_of_vmp_pmat; + if (CPU_SUPPORTS("avx2")) { + // TODO add avx handlers here + // TODO: enable when avx implementation is done + module->func.vmp_prepare_contiguous = fft64_vmp_prepare_contiguous_avx; + module->func.vmp_apply_dft = fft64_vmp_apply_dft_avx; + module->func.vmp_apply_dft_add = fft64_vmp_apply_dft_add_avx; + module->func.vmp_apply_dft_to_dft = fft64_vmp_apply_dft_to_dft_avx; + module->func.vmp_apply_dft_to_dft_add = fft64_vmp_apply_dft_to_dft_add_avx; + } +} + +static void fill_ntt120_virtual_table(MODULE* module) { + // TODO add default ref handler here + // module->func.vec_znx_dft = ...; + if (CPU_SUPPORTS("avx2")) { + // TODO add avx handlers here + module->func.vec_znx_dft = ntt120_vec_znx_dft_avx; + module->func.vec_znx_idft = ntt120_vec_znx_idft_avx; + module->func.vec_znx_idft_tmp_bytes = ntt120_vec_znx_idft_tmp_bytes_avx; + module->func.vec_znx_idft_tmp_a = ntt120_vec_znx_idft_tmp_a_avx; + } +} + +static void fill_virtual_table(MODULE* module) { + fill_generic_virtual_table(module); + switch (module->module_type) { + case FFT64: + fill_fft64_virtual_table(module); + break; + case NTT120: + fill_ntt120_virtual_table(module); + break; + default: + NOT_SUPPORTED(); // invalid type + } +} + +static void fill_fft64_precomp(MODULE* module) { + // fill any necessary precomp stuff + module->mod.fft64.p_conv = new_reim_from_znx64_precomp(module->m, 50); + module->mod.fft64.p_fft = new_reim_fft_precomp(module->m, 0); + module->mod.fft64.p_reim_to_znx = new_reim_to_znx64_precomp(module->m, module->m, 63); + module->mod.fft64.p_ifft = new_reim_ifft_precomp(module->m, 0); + module->mod.fft64.p_addmul = new_reim_fftvec_addmul_precomp(module->m); + module->mod.fft64.mul_fft = new_reim_fftvec_mul_precomp(module->m); + module->mod.fft64.add_fft = new_reim_fftvec_add_precomp(module->m); + module->mod.fft64.sub_fft = new_reim_fftvec_sub_precomp(module->m); +} +static void fill_ntt120_precomp(MODULE* module) { + // fill any necessary precomp stuff + if (CPU_SUPPORTS("avx2")) { + module->mod.q120.p_ntt = q120_new_ntt_bb_precomp(module->nn); + module->mod.q120.p_intt = q120_new_intt_bb_precomp(module->nn); + } +} + +static void fill_module_precomp(MODULE* module) { + switch (module->module_type) { + case FFT64: + fill_fft64_precomp(module); + break; + case 
NTT120: + fill_ntt120_precomp(module); + break; + default: + NOT_SUPPORTED(); // invalid type + } +} + +static void fill_module(MODULE* module, uint64_t nn, MODULE_TYPE mtype) { + // init to zero to ensure that any non-initialized field bug is detected + // by at least a "proper" segfault + memset(module, 0, sizeof(MODULE)); + module->module_type = mtype; + module->nn = nn; + module->m = nn >> 1; + fill_module_precomp(module); + fill_virtual_table(module); +} + +EXPORT MODULE* new_module_info(uint64_t N, MODULE_TYPE mtype) { + MODULE* m = (MODULE*)malloc(sizeof(MODULE)); + fill_module(m, N, mtype); + return m; +} + +EXPORT void delete_module_info(MODULE* mod) { + switch (mod->module_type) { + case FFT64: + free(mod->mod.fft64.p_conv); + free(mod->mod.fft64.p_fft); + free(mod->mod.fft64.p_ifft); + free(mod->mod.fft64.p_reim_to_znx); + free(mod->mod.fft64.mul_fft); + free(mod->mod.fft64.p_addmul); + break; + case NTT120: + if (CPU_SUPPORTS("avx2")) { + q120_del_ntt_bb_precomp(mod->mod.q120.p_ntt); + q120_del_intt_bb_precomp(mod->mod.q120.p_intt); + } + break; + default: + break; + } + free(mod); +} + +EXPORT uint64_t module_get_n(const MODULE* module) { return module->nn; } diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/scalar_vector_product.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/scalar_vector_product.c new file mode 100644 index 0000000..b865862 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/scalar_vector_product.c @@ -0,0 +1,102 @@ +#include + +#include "vec_znx_arithmetic_private.h" + +EXPORT uint64_t bytes_of_svp_ppol(const MODULE* module) { return module->func.bytes_of_svp_ppol(module); } + +EXPORT uint64_t fft64_bytes_of_svp_ppol(const MODULE* module) { return module->nn * sizeof(double); } + +EXPORT SVP_PPOL* new_svp_ppol(const MODULE* module) { return spqlios_alloc(bytes_of_svp_ppol(module)); } + +EXPORT void delete_svp_ppol(SVP_PPOL* ppol) { spqlios_free(ppol); } + +// public wrappers +EXPORT void svp_prepare(const MODULE* module, // N + SVP_PPOL* ppol, // output + const int64_t* pol // a +) { + module->func.svp_prepare(module, ppol, pol); +} + +/** @brief prepares a svp polynomial */ +EXPORT void fft64_svp_prepare_ref(const MODULE* module, // N + SVP_PPOL* ppol, // output + const int64_t* pol // a +) { + reim_from_znx64(module->mod.fft64.p_conv, ppol, pol); + reim_fft(module->mod.fft64.p_fft, (double*)ppol); +} + +EXPORT void svp_apply_dft(const MODULE* module, // N + const VEC_ZNX_DFT* res, uint64_t res_size, // output + const SVP_PPOL* ppol, // prepared pol + const int64_t* a, uint64_t a_size, uint64_t a_sl) { + module->func.svp_apply_dft(module, // N + res, + res_size, // output + ppol, // prepared pol + a, a_size, a_sl); +} + +EXPORT void svp_apply_dft_to_dft(const MODULE* module, // N + const VEC_ZNX_DFT* res, uint64_t res_size, + uint64_t res_cols, // output + const SVP_PPOL* ppol, // prepared pol + const VEC_ZNX_DFT* a, uint64_t a_size, uint64_t a_cols) { + module->func.svp_apply_dft_to_dft(module, // N + res, res_size, res_cols, // output + ppol, a, a_size, a_cols // prepared pol + ); +} + +// result = ppol * a +EXPORT void fft64_svp_apply_dft_ref(const MODULE* module, // N + const VEC_ZNX_DFT* res, uint64_t res_size, // output + const SVP_PPOL* ppol, // prepared pol + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->nn; + double* const dres = (double*)res; + double* const dppol = (double*)ppol; + + const uint64_t 
auto_end_idx = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < auto_end_idx; ++i) { + const int64_t* a_ptr = a + i * a_sl; + double* const res_ptr = dres + i * nn; + // copy the polynomial to res, apply fft in place, call fftvec_mul in place. + reim_from_znx64(module->mod.fft64.p_conv, res_ptr, a_ptr); + reim_fft(module->mod.fft64.p_fft, res_ptr); + reim_fftvec_mul(module->mod.fft64.mul_fft, res_ptr, res_ptr, dppol); + } + + // then extend with zeros + memset(dres + auto_end_idx * nn, 0, (res_size - auto_end_idx) * nn * sizeof(double)); +} + +// result = ppol * a +EXPORT void fft64_svp_apply_dft_to_dft_ref(const MODULE* module, // N + const VEC_ZNX_DFT* res, uint64_t res_size, + uint64_t res_cols, // output + const SVP_PPOL* ppol, // prepared pol + const VEC_ZNX_DFT* a, uint64_t a_size, + uint64_t a_cols // a +) { + const uint64_t nn = module->nn; + const uint64_t res_sl = nn * res_cols; + const uint64_t a_sl = nn * a_cols; + double* const dres = (double*)res; + double* const da = (double*)a; + double* const dppol = (double*)ppol; + + const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < auto_end_idx; ++i) { + const double* a_ptr = da + i * a_sl; + double* const res_ptr = dres + i * res_sl; + reim_fftvec_mul(module->mod.fft64.mul_fft, res_ptr, a_ptr, dppol); + } + + // then extend with zeros + for (uint64_t i = auto_end_idx; i < res_size; i++) { + memset(dres + i * res_sl, 0, nn * sizeof(double)); + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_api.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_api.c new file mode 100644 index 0000000..0e116fc --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_api.c @@ -0,0 +1,344 @@ +#include + +#include "vec_rnx_arithmetic_private.h" + +void fft64_init_rnx_module_precomp(MOD_RNX* module) { + // Add here initialization of items that are in the precomp + const uint64_t m = module->m; + module->precomp.fft64.p_fft = new_reim_fft_precomp(m, 0); + module->precomp.fft64.p_ifft = new_reim_ifft_precomp(m, 0); + module->precomp.fft64.p_fftvec_add = new_reim_fftvec_add_precomp(m); + module->precomp.fft64.p_fftvec_mul = new_reim_fftvec_mul_precomp(m); + module->precomp.fft64.p_fftvec_addmul = new_reim_fftvec_addmul_precomp(m); +} + +void fft64_finalize_rnx_module_precomp(MOD_RNX* module) { + // Add here deleters for items that are in the precomp + delete_reim_fft_precomp(module->precomp.fft64.p_fft); + delete_reim_ifft_precomp(module->precomp.fft64.p_ifft); + delete_reim_fftvec_add_precomp(module->precomp.fft64.p_fftvec_add); + delete_reim_fftvec_mul_precomp(module->precomp.fft64.p_fftvec_mul); + delete_reim_fftvec_addmul_precomp(module->precomp.fft64.p_fftvec_addmul); +} + +void fft64_init_rnx_module_vtable(MOD_RNX* module) { + // Add function pointers here + module->vtable.vec_rnx_add = vec_rnx_add_ref; + module->vtable.vec_rnx_zero = vec_rnx_zero_ref; + module->vtable.vec_rnx_copy = vec_rnx_copy_ref; + module->vtable.vec_rnx_negate = vec_rnx_negate_ref; + module->vtable.vec_rnx_sub = vec_rnx_sub_ref; + module->vtable.vec_rnx_rotate = vec_rnx_rotate_ref; + module->vtable.vec_rnx_automorphism = vec_rnx_automorphism_ref; + module->vtable.vec_rnx_mul_xp_minus_one = vec_rnx_mul_xp_minus_one_ref; + module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref; + module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_ref; + 
module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref; + module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_ref; + module->vtable.rnx_vmp_prepare_tmp_bytes = fft64_rnx_vmp_prepare_tmp_bytes_ref; + module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_ref; + module->vtable.rnx_vmp_prepare_dblptr = fft64_rnx_vmp_prepare_dblptr_ref; + module->vtable.rnx_vmp_prepare_row = fft64_rnx_vmp_prepare_row_ref; + module->vtable.bytes_of_rnx_vmp_pmat = fft64_bytes_of_rnx_vmp_pmat; + module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_ref; + module->vtable.vec_rnx_to_znx32 = vec_rnx_to_znx32_ref; + module->vtable.vec_rnx_from_znx32 = vec_rnx_from_znx32_ref; + module->vtable.vec_rnx_to_tnx32 = vec_rnx_to_tnx32_ref; + module->vtable.vec_rnx_from_tnx32 = vec_rnx_from_tnx32_ref; + module->vtable.vec_rnx_to_tnxdbl = vec_rnx_to_tnxdbl_ref; + module->vtable.bytes_of_rnx_svp_ppol = fft64_bytes_of_rnx_svp_ppol; + module->vtable.rnx_svp_prepare = fft64_rnx_svp_prepare_ref; + module->vtable.rnx_svp_apply = fft64_rnx_svp_apply_ref; + + // Add optimized function pointers here + if (CPU_SUPPORTS("avx")) { + module->vtable.vec_rnx_add = vec_rnx_add_avx; + module->vtable.vec_rnx_sub = vec_rnx_sub_avx; + module->vtable.vec_rnx_negate = vec_rnx_negate_avx; + module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx; + module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_avx; + module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx; + module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_avx; + module->vtable.rnx_vmp_prepare_tmp_bytes = fft64_rnx_vmp_prepare_tmp_bytes_avx; + module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_avx; + module->vtable.rnx_vmp_prepare_dblptr = fft64_rnx_vmp_prepare_dblptr_avx; + module->vtable.rnx_vmp_prepare_row = fft64_rnx_vmp_prepare_row_avx; + module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_avx; + } +} + +void init_rnx_module_info(MOD_RNX* module, // + uint64_t n, RNX_MODULE_TYPE mtype) { + memset(module, 0, sizeof(MOD_RNX)); + module->n = n; + module->m = n >> 1; + module->mtype = mtype; + switch (mtype) { + case FFT64: + fft64_init_rnx_module_precomp(module); + fft64_init_rnx_module_vtable(module); + break; + default: + NOT_SUPPORTED(); // unknown mtype + } +} + +void finalize_rnx_module_info(MOD_RNX* module) { + if (module->custom) module->custom_deleter(module->custom); + switch (module->mtype) { + case FFT64: + fft64_finalize_rnx_module_precomp(module); + // fft64_finalize_rnx_module_vtable(module); // nothing to finalize + break; + default: + NOT_SUPPORTED(); // unknown mtype + } +} + +EXPORT MOD_RNX* new_rnx_module_info(uint64_t nn, RNX_MODULE_TYPE mtype) { + MOD_RNX* res = (MOD_RNX*)malloc(sizeof(MOD_RNX)); + init_rnx_module_info(res, nn, mtype); + return res; +} + +EXPORT void delete_rnx_module_info(MOD_RNX* module_info) { + finalize_rnx_module_info(module_info); + free(module_info); +} + +EXPORT uint64_t rnx_module_get_n(const MOD_RNX* module) { return module->n; } + +/** @brief allocates a prepared matrix (release with delete_rnx_vmp_pmat) */ +EXPORT RNX_VMP_PMAT* new_rnx_vmp_pmat(const MOD_RNX* module, // N + uint64_t nrows, uint64_t ncols) { // dimensions + return (RNX_VMP_PMAT*)spqlios_alloc(bytes_of_rnx_vmp_pmat(module, nrows, ncols)); +} +EXPORT void delete_rnx_vmp_pmat(RNX_VMP_PMAT* ptr) { spqlios_free(ptr); } + +//////////////// 
wrappers ////////////////// + +/** @brief sets res = a + b */ +EXPORT void vec_rnx_add( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +) { + module->vtable.vec_rnx_add(module, res, res_size, res_sl, a, a_size, a_sl, b, b_size, b_sl); +} + +/** @brief sets res = 0 */ +EXPORT void vec_rnx_zero( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl // res +) { + module->vtable.vec_rnx_zero(module, res, res_size, res_sl); +} + +/** @brief sets res = a */ +EXPORT void vec_rnx_copy( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_copy(module, res, res_size, res_sl, a, a_size, a_sl); +} + +/** @brief sets res = -a */ +EXPORT void vec_rnx_negate( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_negate(module, res, res_size, res_sl, a, a_size, a_sl); +} + +/** @brief sets res = a - b */ +EXPORT void vec_rnx_sub( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +) { + module->vtable.vec_rnx_sub(module, res, res_size, res_sl, a, a_size, a_sl, b, b_size, b_sl); +} + +/** @brief sets res = a . X^p */ +EXPORT void vec_rnx_rotate( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_rotate(module, p, res, res_size, res_sl, a, a_size, a_sl); +} + +/** @brief sets res = a(X^p) */ +EXPORT void vec_rnx_automorphism( // + const MOD_RNX* module, // N + int64_t p, // X -> X^p + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_automorphism(module, p, res, res_size, res_sl, a, a_size, a_sl); +} + +EXPORT void vec_rnx_mul_xp_minus_one( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_mul_xp_minus_one(module, p, res, res_size, res_sl, a, a_size, a_sl); +} +/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */ +EXPORT uint64_t bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N + uint64_t nrows, uint64_t ncols) { // dimensions + return module->vtable.bytes_of_rnx_vmp_pmat(module, nrows, ncols); +} + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void rnx_vmp_prepare_contiguous( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* a, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + module->vtable.rnx_vmp_prepare_contiguous(module, pmat, a, nrows, ncols, tmp_space); +} + +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void rnx_vmp_prepare_dblptr( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double** a, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + module->vtable.rnx_vmp_prepare_dblptr(module, pmat, a, nrows, ncols, tmp_space); +} 
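/* ---------------------------------------------------------------------------
 * Illustrative usage sketch (editor's note, not part of this diff): how the
 * public vmp wrappers in this file are meant to be chained by a caller.
 * The include path, the plain calloc/malloc scratch buffers and the chosen
 * dimensions are assumptions made for the example, not library requirements.
 * ------------------------------------------------------------------------- */
#include <stdint.h>
#include <stdlib.h>
#include <spqlios/arithmetic/vec_rnx_arithmetic.h> /* assumed install path */

static void vmp_usage_sketch(void) {
  const uint64_t nn = 1024, nrows = 2, ncols = 3;
  MOD_RNX* module = new_rnx_module_info(nn, FFT64);

  /* row-major contiguous matrix: nrows*ncols polynomials of nn coefficients */
  double* mat = calloc(nrows * ncols * nn, sizeof(double));
  double* vec = calloc(nrows * nn, sizeof(double)); /* input, overwritten below */
  double* res = calloc(ncols * nn, sizeof(double));

  /* prepare the matrix once... */
  RNX_VMP_PMAT* pmat = new_rnx_vmp_pmat(module, nrows, ncols);
  uint8_t* tmp_prep = malloc(rnx_vmp_prepare_tmp_bytes(module));
  rnx_vmp_prepare_contiguous(module, pmat, mat, nrows, ncols, tmp_prep);

  /* ...then apply it to a vector; rnx_vmp_apply_tmp_a overwrites its input */
  uint8_t* tmp_apply = malloc(rnx_vmp_apply_tmp_a_tmp_bytes(module, ncols, nrows, nrows, ncols));
  rnx_vmp_apply_tmp_a(module, res, ncols, nn, vec, nrows, nn, pmat, nrows, ncols, tmp_apply);

  free(tmp_apply); free(tmp_prep); free(res); free(vec); free(mat);
  delete_rnx_vmp_pmat(pmat);
  delete_rnx_module_info(module);
}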
+ +/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */ +EXPORT void rnx_vmp_prepare_row( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* a, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + module->vtable.rnx_vmp_prepare_row(module, pmat, a, row_i, nrows, ncols, tmp_space); +} + +/** @brief number of scratch bytes necessary to prepare a matrix */ +EXPORT uint64_t rnx_vmp_prepare_tmp_bytes(const MOD_RNX* module) { + return module->vtable.rnx_vmp_prepare_tmp_bytes(module); +} + +/** @brief applies a vmp product res = a x pmat */ +EXPORT void rnx_vmp_apply_tmp_a( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten) + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +) { + module->vtable.rnx_vmp_apply_tmp_a(module, res, res_size, res_sl, tmpa, a_size, a_sl, pmat, nrows, ncols, tmp_space); +} + +EXPORT uint64_t rnx_vmp_apply_tmp_a_tmp_bytes( // + const MOD_RNX* module, // N + uint64_t res_size, // res size + uint64_t a_size, // a size + uint64_t nrows, uint64_t ncols // prep matrix dims +) { + return module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes(module, res_size, a_size, nrows, ncols); +} + +/** @brief minimal size of the tmp_space */ +EXPORT void rnx_vmp_apply_dft_to_dft( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a_dft, uint64_t a_size, uint64_t a_sl, // a + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +) { + module->vtable.rnx_vmp_apply_dft_to_dft(module, res, res_size, res_sl, a_dft, a_size, a_sl, pmat, nrows, ncols, + tmp_space); +} + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t rnx_vmp_apply_dft_to_dft_tmp_bytes( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +) { + return module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes(module, res_size, a_size, nrows, ncols); +} + +EXPORT uint64_t bytes_of_rnx_svp_ppol(const MOD_RNX* module) { return module->vtable.bytes_of_rnx_svp_ppol(module); } + +EXPORT void rnx_svp_prepare(const MOD_RNX* module, // N + RNX_SVP_PPOL* ppol, // output + const double* pol // a +) { + module->vtable.rnx_svp_prepare(module, ppol, pol); +} + +EXPORT void rnx_svp_apply( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // output + const RNX_SVP_PPOL* ppol, // prepared pol + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.rnx_svp_apply(module, // N + res, res_size, res_sl, // output + ppol, // prepared pol + a, a_size, a_sl); +} + +EXPORT void rnx_approxdecomp_from_tnxdbl( // + const MOD_RNX* module, // N + const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a) { // a + module->vtable.rnx_approxdecomp_from_tnxdbl(module, gadget, res, res_size, res_sl, a); +} + +EXPORT void vec_rnx_to_znx32( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_to_znx32(module, res, res_size, res_sl, a, a_size, a_sl); +} + +EXPORT void vec_rnx_from_znx32( // + const MOD_RNX* module, // N + double* res, 
uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_from_znx32(module, res, res_size, res_sl, a, a_size, a_sl); +} + +EXPORT void vec_rnx_to_tnx32( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_to_tnx32(module, res, res_size, res_sl, a, a_size, a_sl); +} + +EXPORT void vec_rnx_from_tnx32( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_from_tnx32(module, res, res_size, res_sl, a, a_size, a_sl); +} + +EXPORT void vec_rnx_to_tnxdbl( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_to_tnxdbl(module, res, res_size, res_sl, a, a_size, a_sl); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_approxdecomp_avx.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_approxdecomp_avx.c new file mode 100644 index 0000000..2acda14 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_approxdecomp_avx.c @@ -0,0 +1,59 @@ +#include + +#include "immintrin.h" +#include "vec_rnx_arithmetic_private.h" + +/** @brief sets res = gadget_decompose(a) */ +EXPORT void rnx_approxdecomp_from_tnxdbl_avx( // + const MOD_RNX* module, // N + const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a // a +) { + const uint64_t nn = module->n; + if (nn < 4) return rnx_approxdecomp_from_tnxdbl_ref(module, gadget, res, res_size, res_sl, a); + const uint64_t ell = gadget->ell; + const __m256i k = _mm256_set1_epi64x(gadget->k); + const __m256d add_cst = _mm256_set1_pd(gadget->add_cst); + const __m256i and_mask = _mm256_set1_epi64x(gadget->and_mask); + const __m256i or_mask = _mm256_set1_epi64x(gadget->or_mask); + const __m256d sub_cst = _mm256_set1_pd(gadget->sub_cst); + const uint64_t msize = res_size <= ell ? 
res_size : ell; + // gadget decompose column by column + if (msize == ell) { + // this is the main scenario when msize == ell + double* const last_r = res + (msize - 1) * res_sl; + for (uint64_t j = 0; j < nn; j += 4) { + double* rr = last_r + j; + const double* aa = a + j; + __m256d t_dbl = _mm256_add_pd(_mm256_loadu_pd(aa), add_cst); + __m256i t_int = _mm256_castpd_si256(t_dbl); + do { + __m256i u_int = _mm256_or_si256(_mm256_and_si256(t_int, and_mask), or_mask); + _mm256_storeu_pd(rr, _mm256_sub_pd(_mm256_castsi256_pd(u_int), sub_cst)); + t_int = _mm256_srlv_epi64(t_int, k); + rr -= res_sl; + } while (rr >= res); + } + } else if (msize > 0) { + // otherwise, if msize < ell: there is one additional rshift + const __m256i first_rsh = _mm256_set1_epi64x((ell - msize) * gadget->k); + double* const last_r = res + (msize - 1) * res_sl; + for (uint64_t j = 0; j < nn; j += 4) { + double* rr = last_r + j; + const double* aa = a + j; + __m256d t_dbl = _mm256_add_pd(_mm256_loadu_pd(aa), add_cst); + __m256i t_int = _mm256_srlv_epi64(_mm256_castpd_si256(t_dbl), first_rsh); + do { + __m256i u_int = _mm256_or_si256(_mm256_and_si256(t_int, and_mask), or_mask); + _mm256_storeu_pd(rr, _mm256_sub_pd(_mm256_castsi256_pd(u_int), sub_cst)); + t_int = _mm256_srlv_epi64(t_int, k); + rr -= res_sl; + } while (rr >= res); + } + } + // zero-out the last slices (if any) + for (uint64_t i = msize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_approxdecomp_ref.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_approxdecomp_ref.c new file mode 100644 index 0000000..eab2d12 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_approxdecomp_ref.c @@ -0,0 +1,75 @@ +#include + +#include "vec_rnx_arithmetic_private.h" + +typedef union di { + double dv; + uint64_t uv; +} di_t; + +/** @brief new gadget: delete with delete_tnxdbl_approxdecomp_gadget */ +EXPORT TNXDBL_APPROXDECOMP_GADGET* new_tnxdbl_approxdecomp_gadget( // + const MOD_RNX* module, // N + uint64_t k, uint64_t ell // base 2^K and size +) { + if (k * ell > 50) return spqlios_error("gadget requires a too large fp precision"); + TNXDBL_APPROXDECOMP_GADGET* res = spqlios_alloc(sizeof(TNXDBL_APPROXDECOMP_GADGET)); + res->k = k; + res->ell = ell; + // double add_cst; // double(3.2^(51-ell.K) + 1/2.(sum 2^(-iK)) for i=[0,ell[) + union di add_cst; + add_cst.dv = UINT64_C(3) << (51 - ell * k); + for (uint64_t i = 0; i < ell; ++i) { + add_cst.uv |= UINT64_C(1) << ((i + 1) * k - 1); + } + res->add_cst = add_cst.dv; + // uint64_t and_mask; // uint64(2^(K)-1) + res->and_mask = (UINT64_C(1) << k) - 1; + // uint64_t or_mask; // double(2^52) + union di or_mask; + or_mask.dv = (UINT64_C(1) << 52); + res->or_mask = or_mask.uv; + // double sub_cst; // double(2^52 + 2^(K-1)) + res->sub_cst = ((UINT64_C(1) << 52) + (UINT64_C(1) << (k - 1))); + return res; +} + +EXPORT void delete_tnxdbl_approxdecomp_gadget(TNXDBL_APPROXDECOMP_GADGET* gadget) { spqlios_free(gadget); } + +/** @brief sets res = gadget_decompose(a) */ +EXPORT void rnx_approxdecomp_from_tnxdbl_ref( // + const MOD_RNX* module, // N + const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a // a +) { + const uint64_t nn = module->n; + const uint64_t k = gadget->k; + const uint64_t ell = gadget->ell; + const double add_cst = gadget->add_cst; + const 
uint64_t and_mask = gadget->and_mask; + const uint64_t or_mask = gadget->or_mask; + const double sub_cst = gadget->sub_cst; + const uint64_t msize = res_size <= ell ? res_size : ell; + const uint64_t first_rsh = (ell - msize) * k; + // gadget decompose column by column + if (msize > 0) { + double* const last_r = res + (msize - 1) * res_sl; + for (uint64_t j = 0; j < nn; ++j) { + double* rr = last_r + j; + di_t t = {.dv = a[j] + add_cst}; + if (msize < ell) t.uv >>= first_rsh; + do { + di_t u; + u.uv = (t.uv & and_mask) | or_mask; + *rr = u.dv - sub_cst; + t.uv >>= k; + rr -= res_sl; + } while (rr >= res); + } + } + // zero-out the last slices (if any) + for (uint64_t i = msize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_arithmetic.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_arithmetic.c new file mode 100644 index 0000000..ccdfa85 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_arithmetic.c @@ -0,0 +1,223 @@ +#include + +#include "../coeffs/coeffs_arithmetic.h" +#include "vec_rnx_arithmetic_private.h" + +void rnx_add_ref(uint64_t nn, double* res, const double* a, const double* b) { + for (uint64_t i = 0; i < nn; ++i) { + res[i] = a[i] + b[i]; + } +} + +void rnx_sub_ref(uint64_t nn, double* res, const double* a, const double* b) { + for (uint64_t i = 0; i < nn; ++i) { + res[i] = a[i] - b[i]; + } +} + +void rnx_negate_ref(uint64_t nn, double* res, const double* a) { + for (uint64_t i = 0; i < nn; ++i) { + res[i] = -a[i]; + } +} + +/** @brief sets res = a + b */ +EXPORT void vec_rnx_add_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->n; + if (a_size < b_size) { + const uint64_t msize = res_size < a_size ? res_size : a_size; + const uint64_t nsize = res_size < b_size ? res_size : b_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_add_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + memcpy(res + i * res_sl, b + i * b_sl, nn * sizeof(double)); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } else { + const uint64_t msize = res_size < b_size ? res_size : b_size; + const uint64_t nsize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_add_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double)); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } +} + +/** @brief sets res = 0 */ +EXPORT void vec_rnx_zero_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl // res +) { + const uint64_t nn = module->n; + for (uint64_t i = 0; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief sets res = a */ +EXPORT void vec_rnx_copy_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + + const uint64_t rot_end_idx = res_size < a_size ? 
res_size : a_size; + // rotate up to the smallest dimension + for (uint64_t i = 0; i < rot_end_idx; ++i) { + double* res_ptr = res + i * res_sl; + const double* a_ptr = a + i * a_sl; + memcpy(res_ptr, a_ptr, nn * sizeof(double)); + } + // then extend with zeros + for (uint64_t i = rot_end_idx; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief sets res = -a */ +EXPORT void vec_rnx_negate_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + + const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size; + // rotate up to the smallest dimension + for (uint64_t i = 0; i < rot_end_idx; ++i) { + double* res_ptr = res + i * res_sl; + const double* a_ptr = a + i * a_sl; + rnx_negate_ref(nn, res_ptr, a_ptr); + } + // then extend with zeros + for (uint64_t i = rot_end_idx; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief sets res = a - b */ +EXPORT void vec_rnx_sub_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->n; + if (a_size < b_size) { + const uint64_t msize = res_size < a_size ? res_size : a_size; + const uint64_t nsize = res_size < b_size ? res_size : b_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_sub_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + rnx_negate_ref(nn, res + i * res_sl, b + i * b_sl); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } else { + const uint64_t msize = res_size < b_size ? res_size : b_size; + const uint64_t nsize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_sub_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double)); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } +} + +/** @brief sets res = a . X^p */ +EXPORT void vec_rnx_rotate_ref( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + + const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size; + // rotate up to the smallest dimension + for (uint64_t i = 0; i < rot_end_idx; ++i) { + double* res_ptr = res + i * res_sl; + const double* a_ptr = a + i * a_sl; + if (res_ptr == a_ptr) { + rnx_rotate_inplace_f64(nn, p, res_ptr); + } else { + rnx_rotate_f64(nn, p, res_ptr, a_ptr); + } + } + // then extend with zeros + for (uint64_t i = rot_end_idx; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief sets res = a(X^p) */ +EXPORT void vec_rnx_automorphism_ref( // + const MOD_RNX* module, // N + int64_t p, // X -> X^p + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + + const uint64_t rot_end_idx = res_size < a_size ? 
res_size : a_size; + // rotate up to the smallest dimension + for (uint64_t i = 0; i < rot_end_idx; ++i) { + double* res_ptr = res + i * res_sl; + const double* a_ptr = a + i * a_sl; + if (res_ptr == a_ptr) { + rnx_automorphism_inplace_f64(nn, p, res_ptr); + } else { + rnx_automorphism_f64(nn, p, res_ptr, a_ptr); + } + } + // then extend with zeros + for (uint64_t i = rot_end_idx; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief sets res = a . (X^p - 1) */ +EXPORT void vec_rnx_mul_xp_minus_one_ref( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + + const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size; + // rotate up to the smallest dimension + for (uint64_t i = 0; i < rot_end_idx; ++i) { + double* res_ptr = res + i * res_sl; + const double* a_ptr = a + i * a_sl; + if (res_ptr == a_ptr) { + rnx_mul_xp_minus_one_inplace_f64(nn, p, res_ptr); + } else { + rnx_mul_xp_minus_one_f64(nn, p, res_ptr, a_ptr); + } + } + // then extend with zeros + for (uint64_t i = rot_end_idx; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_arithmetic.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_arithmetic.h new file mode 100644 index 0000000..c88bddb --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_arithmetic.h @@ -0,0 +1,356 @@ +#ifndef SPQLIOS_VEC_RNX_ARITHMETIC_H +#define SPQLIOS_VEC_RNX_ARITHMETIC_H + +#include + +#include "../commons.h" + +/** + * We support the following module families: + * - FFT64: + * the overall precision should fit at all times over 52 bits. + */ +typedef enum rnx_module_type_t { FFT64 } RNX_MODULE_TYPE; + +/** @brief opaque structure that describes the modules (RnX,ZnX,TnX) and the hardware */ +typedef struct rnx_module_info_t MOD_RNX; + +/** + * @brief obtain a module info for ring dimension N + * the module-info knows about: + * - the dimension N (or the complex dimension m=N/2) + * - any moduleuted fft or ntt items + * - the hardware (avx, arm64, x86, ...) 
+ */ +EXPORT MOD_RNX* new_rnx_module_info(uint64_t N, RNX_MODULE_TYPE mode); +EXPORT void delete_rnx_module_info(MOD_RNX* module_info); +EXPORT uint64_t rnx_module_get_n(const MOD_RNX* module); + +// basic arithmetic + +/** @brief sets res = 0 */ +EXPORT void vec_rnx_zero( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl // res +); + +/** @brief sets res = a */ +EXPORT void vec_rnx_copy( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = -a */ +EXPORT void vec_rnx_negate( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = a + b */ +EXPORT void vec_rnx_add( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = a - b */ +EXPORT void vec_rnx_sub( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = a . X^p */ +EXPORT void vec_rnx_rotate( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = a . (X^p - 1) */ +EXPORT void vec_rnx_mul_xp_minus_one( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = a(X^p) */ +EXPORT void vec_rnx_automorphism( // + const MOD_RNX* module, // N + int64_t p, // X -> X^p + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/////////////////////////////////////////////////////////////////// +// conversions // +/////////////////////////////////////////////////////////////////// + +EXPORT void vec_rnx_to_znx32( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_from_znx32( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_to_tnx32( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_from_tnx32( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_to_tnx32x2( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_from_tnx32x2( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_to_tnxdbl( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + 
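/* ---------------------------------------------------------------------------
 * Illustrative sketch (editor's note, not part of this header): the
 * (size, sl) calling convention shared by the vec_rnx_* entry points above.
 * Each operand is a vector of `size` polynomials of N doubles, the i-th
 * polynomial starting at offset i*sl; operand sizes may differ and the result
 * is zero-extended, as in the reference implementations.  The use of calloc
 * and the chosen parameters are assumptions made for the example only.
 * ------------------------------------------------------------------------- */
#include <stdlib.h>

static void vec_rnx_convention_demo(void) {
  const uint64_t nn = 2048;
  MOD_RNX* m = new_rnx_module_info(nn, FFT64);
  double* a = calloc(3 * nn, sizeof(double)); /* 3 limbs, contiguous: sl = nn */
  double* b = calloc(2 * nn, sizeof(double)); /* 2 limbs */
  double* r = calloc(3 * nn, sizeof(double));
  /* r[i] = a[i] + b[i] for i < 2, then r[2] = a[2] (shorter operand padded with zeros) */
  vec_rnx_add(m, r, 3, nn, a, 3, nn, b, 2, nn);
  /* in-place negacyclic rotation of every limb: r = r . X^5 */
  vec_rnx_rotate(m, 5, r, 3, nn, r, 3, nn);
  free(r); free(b); free(a);
  delete_rnx_module_info(m);
}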
+/////////////////////////////////////////////////////////////////// +// isolated products (n.log(n), but not particularly optimized // +/////////////////////////////////////////////////////////////////// + +/** @brief res = a * b : small polynomial product */ +EXPORT void rnx_small_single_product( // + const MOD_RNX* module, // N + double* res, // output + const double* a, // a + const double* b, // b + uint8_t* tmp); // scratch space + +EXPORT uint64_t rnx_small_single_product_tmp_bytes(const MOD_RNX* module); + +/** @brief res = a * b centermod 1: small polynomial product */ +EXPORT void tnxdbl_small_single_product( // + const MOD_RNX* module, // N + double* torus_res, // output + const double* int_a, // a + const double* torus_b, // b + uint8_t* tmp); // scratch space + +EXPORT uint64_t tnxdbl_small_single_product_tmp_bytes(const MOD_RNX* module); + +/** @brief res = a * b: small polynomial product */ +EXPORT void znx32_small_single_product( // + const MOD_RNX* module, // N + int32_t* int_res, // output + const int32_t* int_a, // a + const int32_t* int_b, // b + uint8_t* tmp); // scratch space + +EXPORT uint64_t znx32_small_single_product_tmp_bytes(const MOD_RNX* module); + +/** @brief res = a * b centermod 1: small polynomial product */ +EXPORT void tnx32_small_single_product( // + const MOD_RNX* module, // N + int32_t* torus_res, // output + const int32_t* int_a, // a + const int32_t* torus_b, // b + uint8_t* tmp); // scratch space + +EXPORT uint64_t tnx32_small_single_product_tmp_bytes(const MOD_RNX* module); + +/////////////////////////////////////////////////////////////////// +// prepared gadget decompositions (optimized) // +/////////////////////////////////////////////////////////////////// + +// decompose from tnx32 + +typedef struct tnx32_approxdecomp_gadget_t TNX32_APPROXDECOMP_GADGET; + +/** @brief new gadget: delete with delete_tnx32_approxdecomp_gadget */ +EXPORT TNX32_APPROXDECOMP_GADGET* new_tnx32_approxdecomp_gadget( // + const MOD_RNX* module, // N + uint64_t k, uint64_t ell // base 2^K and size +); +EXPORT void delete_tnx32_approxdecomp_gadget(const MOD_RNX* module); + +/** @brief sets res = gadget_decompose(a) */ +EXPORT void rnx_approxdecomp_from_tnx32( // + const MOD_RNX* module, // N + const TNX32_APPROXDECOMP_GADGET* gadget, // output base 2^K + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a // a +); + +// decompose from tnx32x2 + +typedef struct tnx32x2_approxdecomp_gadget_t TNX32X2_APPROXDECOMP_GADGET; + +/** @brief new gadget: delete with delete_tnx32x2_approxdecomp_gadget */ +EXPORT TNX32X2_APPROXDECOMP_GADGET* new_tnx32x2_approxdecomp_gadget(const MOD_RNX* module, uint64_t ka, uint64_t ella, + uint64_t kb, uint64_t ellb); +EXPORT void delete_tnx32x2_approxdecomp_gadget(const MOD_RNX* module); + +/** @brief sets res = gadget_decompose(a) */ +EXPORT void rnx_approxdecomp_from_tnx32x2( // + const MOD_RNX* module, // N + const TNX32X2_APPROXDECOMP_GADGET* gadget, // output base 2^K + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a // a +); + +// decompose from tnxdbl + +typedef struct tnxdbl_approxdecomp_gadget_t TNXDBL_APPROXDECOMP_GADGET; + +/** @brief new gadget: delete with delete_tnxdbl_approxdecomp_gadget */ +EXPORT TNXDBL_APPROXDECOMP_GADGET* new_tnxdbl_approxdecomp_gadget( // + const MOD_RNX* module, // N + uint64_t k, uint64_t ell // base 2^K and size +); +EXPORT void delete_tnxdbl_approxdecomp_gadget(TNXDBL_APPROXDECOMP_GADGET* gadget); + +/** @brief sets res = gadget_decompose(a) */ +EXPORT void 
rnx_approxdecomp_from_tnxdbl( // + const MOD_RNX* module, // N + const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a); // a + +/////////////////////////////////////////////////////////////////// +// prepared scalar-vector product (optimized) // +/////////////////////////////////////////////////////////////////// + +/** @brief opaque type that represents a polynomial of RnX prepared for a scalar-vector product */ +typedef struct rnx_svp_ppol_t RNX_SVP_PPOL; + +/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */ +EXPORT uint64_t bytes_of_rnx_svp_ppol(const MOD_RNX* module); // N + +/** @brief allocates a prepared vector (release with delete_rnx_svp_ppol) */ +EXPORT RNX_SVP_PPOL* new_rnx_svp_ppol(const MOD_RNX* module); // N + +/** @brief frees memory for a prepared vector */ +EXPORT void delete_rnx_svp_ppol(RNX_SVP_PPOL* res); + +/** @brief prepares a svp polynomial */ +EXPORT void rnx_svp_prepare(const MOD_RNX* module, // N + RNX_SVP_PPOL* ppol, // output + const double* pol // a +); + +/** @brief apply a svp product, result = ppol * a, presented in DFT space */ +EXPORT void rnx_svp_apply( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // output + const RNX_SVP_PPOL* ppol, // prepared pol + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/////////////////////////////////////////////////////////////////// +// prepared vector-matrix product (optimized) // +/////////////////////////////////////////////////////////////////// + +typedef struct rnx_vmp_pmat_t RNX_VMP_PMAT; + +/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */ +EXPORT uint64_t bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N + uint64_t nrows, uint64_t ncols); // dimensions + +/** @brief allocates a prepared matrix (release with delete_rnx_vmp_pmat) */ +EXPORT RNX_VMP_PMAT* new_rnx_vmp_pmat(const MOD_RNX* module, // N + uint64_t nrows, uint64_t ncols); // dimensions +EXPORT void delete_rnx_vmp_pmat(RNX_VMP_PMAT* ptr); + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void rnx_vmp_prepare_contiguous( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* a, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); + +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void rnx_vmp_prepare_dblptr( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double** a, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); + +/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */ +EXPORT void rnx_vmp_prepare_row( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* a, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); + +/** @brief number of scratch bytes necessary to prepare a matrix */ +EXPORT uint64_t rnx_vmp_prepare_tmp_bytes(const MOD_RNX* module); + +/** @brief applies a vmp product res = a x pmat */ +EXPORT void rnx_vmp_apply_tmp_a( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten) + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +); + +EXPORT uint64_t rnx_vmp_apply_tmp_a_tmp_bytes( // + const MOD_RNX* module, // N + uint64_t res_size, // res size + 
uint64_t a_size, // a size + uint64_t nrows, uint64_t ncols // prep matrix dims +); + +/** @brief applies a vmp product res = a_dft x pmat, with input and output in DFT space */ +EXPORT void rnx_vmp_apply_dft_to_dft( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a_dft, uint64_t a_size, uint64_t a_sl, // a + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +); + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t rnx_vmp_apply_dft_to_dft_tmp_bytes( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); + +/** @brief sets res = DFT(a) */ +EXPORT void vec_rnx_dft(const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = iDFT(a_dft) -- idft is not normalized */ +EXPORT void vec_rnx_idft(const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a_dft, uint64_t a_size, uint64_t a_sl // a +); + +#endif // SPQLIOS_VEC_RNX_ARITHMETIC_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_arithmetic_avx.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_arithmetic_avx.c new file mode 100644 index 0000000..04b3ec0 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_arithmetic_avx.c @@ -0,0 +1,189 @@ +#include <immintrin.h> +#include <string.h> + +#include "vec_rnx_arithmetic_private.h" + +void rnx_add_avx(uint64_t nn, double* res, const double* a, const double* b) { + if (nn < 8) { + if (nn == 4) { + _mm256_storeu_pd(res, _mm256_add_pd(_mm256_loadu_pd(a), _mm256_loadu_pd(b))); + } else if (nn == 2) { + _mm_storeu_pd(res, _mm_add_pd(_mm_loadu_pd(a), _mm_loadu_pd(b))); + } else if (nn == 1) { + *res = *a + *b; + } else { + NOT_SUPPORTED(); // not a power of 2 + } + return; + } + // general case: nn >= 8 + __m256d x0, x1, x2, x3, x4, x5; + const double* aa = a; + const double* bb = b; + double* rr = res; + double* const rrend = res + nn; + do { + x0 = _mm256_loadu_pd(aa); + x1 = _mm256_loadu_pd(aa + 4); + x2 = _mm256_loadu_pd(bb); + x3 = _mm256_loadu_pd(bb + 4); + x4 = _mm256_add_pd(x0, x2); + x5 = _mm256_add_pd(x1, x3); + _mm256_storeu_pd(rr, x4); + _mm256_storeu_pd(rr + 4, x5); + aa += 8; + bb += 8; + rr += 8; + } while (rr < rrend); +} + +void rnx_sub_avx(uint64_t nn, double* res, const double* a, const double* b) { + if (nn < 8) { + if (nn == 4) { + _mm256_storeu_pd(res, _mm256_sub_pd(_mm256_loadu_pd(a), _mm256_loadu_pd(b))); + } else if (nn == 2) { + _mm_storeu_pd(res, _mm_sub_pd(_mm_loadu_pd(a), _mm_loadu_pd(b))); + } else if (nn == 1) { + *res = *a - *b; + } else { + NOT_SUPPORTED(); // not a power of 2 + } + return; + } + // general case: nn >= 8 + __m256d x0, x1, x2, x3, x4, x5; + const double* aa = a; + const double* bb = b; + double* rr = res; + double* const rrend = res + nn; + do { + x0 = _mm256_loadu_pd(aa); + x1 = _mm256_loadu_pd(aa + 4); + x2 = _mm256_loadu_pd(bb); + x3 = _mm256_loadu_pd(bb + 4); + x4 = _mm256_sub_pd(x0, x2); + x5 = _mm256_sub_pd(x1, x3); + _mm256_storeu_pd(rr, x4); + _mm256_storeu_pd(rr + 4, x5); + aa += 8; + bb += 8; + rr += 8; + } while (rr < rrend); +} + +void rnx_negate_avx(uint64_t nn, double* res, const double* b) { + if (nn < 8) { + if (nn == 4) { + _mm256_storeu_pd(res, _mm256_sub_pd(_mm256_set1_pd(0), _mm256_loadu_pd(b))); + }
else if (nn == 2) { + _mm_storeu_pd(res, _mm_sub_pd(_mm_set1_pd(0), _mm_loadu_pd(b))); + } else if (nn == 1) { + *res = -*b; + } else { + NOT_SUPPORTED(); // not a power of 2 + } + return; + } + // general case: nn >= 8 + __m256d x2, x3, x4, x5; + const __m256d ZERO = _mm256_set1_pd(0); + const double* bb = b; + double* rr = res; + double* const rrend = res + nn; + do { + x2 = _mm256_loadu_pd(bb); + x3 = _mm256_loadu_pd(bb + 4); + x4 = _mm256_sub_pd(ZERO, x2); + x5 = _mm256_sub_pd(ZERO, x3); + _mm256_storeu_pd(rr, x4); + _mm256_storeu_pd(rr + 4, x5); + bb += 8; + rr += 8; + } while (rr < rrend); +} + +/** @brief sets res = a + b */ +EXPORT void vec_rnx_add_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->n; + if (a_size < b_size) { + const uint64_t msize = res_size < a_size ? res_size : a_size; + const uint64_t nsize = res_size < b_size ? res_size : b_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_add_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + memcpy(res + i * res_sl, b + i * b_sl, nn * sizeof(double)); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } else { + const uint64_t msize = res_size < b_size ? res_size : b_size; + const uint64_t nsize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_add_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double)); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } +} + +/** @brief sets res = -a */ +EXPORT void vec_rnx_negate_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_negate_avx(nn, res + i * res_sl, a + i * a_sl); + } + for (uint64_t i = msize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief sets res = a - b */ +EXPORT void vec_rnx_sub_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->n; + if (a_size < b_size) { + const uint64_t msize = res_size < a_size ? res_size : a_size; + const uint64_t nsize = res_size < b_size ? res_size : b_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_sub_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + rnx_negate_avx(nn, res + i * res_sl, b + i * b_sl); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } else { + const uint64_t msize = res_size < b_size ? res_size : b_size; + const uint64_t nsize = res_size < a_size ? 
res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_sub_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double)); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_arithmetic_plugin.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_arithmetic_plugin.h new file mode 100644 index 0000000..99277b4 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_arithmetic_plugin.h @@ -0,0 +1,92 @@ +#ifndef SPQLIOS_VEC_RNX_ARITHMETIC_PLUGIN_H +#define SPQLIOS_VEC_RNX_ARITHMETIC_PLUGIN_H + +#include "vec_rnx_arithmetic.h" + +typedef typeof(vec_rnx_zero) VEC_RNX_ZERO_F; +typedef typeof(vec_rnx_copy) VEC_RNX_COPY_F; +typedef typeof(vec_rnx_negate) VEC_RNX_NEGATE_F; +typedef typeof(vec_rnx_add) VEC_RNX_ADD_F; +typedef typeof(vec_rnx_sub) VEC_RNX_SUB_F; +typedef typeof(vec_rnx_rotate) VEC_RNX_ROTATE_F; +typedef typeof(vec_rnx_mul_xp_minus_one) VEC_RNX_MUL_XP_MINUS_ONE_F; +typedef typeof(vec_rnx_automorphism) VEC_RNX_AUTOMORPHISM_F; +typedef typeof(vec_rnx_to_znx32) VEC_RNX_TO_ZNX32_F; +typedef typeof(vec_rnx_from_znx32) VEC_RNX_FROM_ZNX32_F; +typedef typeof(vec_rnx_to_tnx32) VEC_RNX_TO_TNX32_F; +typedef typeof(vec_rnx_from_tnx32) VEC_RNX_FROM_TNX32_F; +typedef typeof(vec_rnx_to_tnx32x2) VEC_RNX_TO_TNX32X2_F; +typedef typeof(vec_rnx_from_tnx32x2) VEC_RNX_FROM_TNX32X2_F; +typedef typeof(vec_rnx_to_tnxdbl) VEC_RNX_TO_TNXDBL_F; +// typedef typeof(vec_rnx_from_tnxdbl) VEC_RNX_FROM_TNXDBL_F; +typedef typeof(rnx_small_single_product) RNX_SMALL_SINGLE_PRODUCT_F; +typedef typeof(rnx_small_single_product_tmp_bytes) RNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F; +typedef typeof(tnxdbl_small_single_product) TNXDBL_SMALL_SINGLE_PRODUCT_F; +typedef typeof(tnxdbl_small_single_product_tmp_bytes) TNXDBL_SMALL_SINGLE_PRODUCT_TMP_BYTES_F; +typedef typeof(znx32_small_single_product) ZNX32_SMALL_SINGLE_PRODUCT_F; +typedef typeof(znx32_small_single_product_tmp_bytes) ZNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F; +typedef typeof(tnx32_small_single_product) TNX32_SMALL_SINGLE_PRODUCT_F; +typedef typeof(tnx32_small_single_product_tmp_bytes) TNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F; +typedef typeof(rnx_approxdecomp_from_tnx32) RNX_APPROXDECOMP_FROM_TNX32_F; +typedef typeof(rnx_approxdecomp_from_tnx32x2) RNX_APPROXDECOMP_FROM_TNX32X2_F; +typedef typeof(rnx_approxdecomp_from_tnxdbl) RNX_APPROXDECOMP_FROM_TNXDBL_F; +typedef typeof(bytes_of_rnx_svp_ppol) BYTES_OF_RNX_SVP_PPOL_F; +typedef typeof(rnx_svp_prepare) RNX_SVP_PREPARE_F; +typedef typeof(rnx_svp_apply) RNX_SVP_APPLY_F; +typedef typeof(bytes_of_rnx_vmp_pmat) BYTES_OF_RNX_VMP_PMAT_F; +typedef typeof(rnx_vmp_prepare_contiguous) RNX_VMP_PREPARE_CONTIGUOUS_F; +typedef typeof(rnx_vmp_prepare_dblptr) RNX_VMP_PREPARE_DBLPTR_F; +typedef typeof(rnx_vmp_prepare_row) RNX_VMP_PREPARE_ROW_F; +typedef typeof(rnx_vmp_prepare_tmp_bytes) RNX_VMP_PREPARE_TMP_BYTES_F; +typedef typeof(rnx_vmp_apply_tmp_a) RNX_VMP_APPLY_TMP_A_F; +typedef typeof(rnx_vmp_apply_tmp_a_tmp_bytes) RNX_VMP_APPLY_TMP_A_TMP_BYTES_F; +typedef typeof(rnx_vmp_apply_dft_to_dft) RNX_VMP_APPLY_DFT_TO_DFT_F; +typedef typeof(rnx_vmp_apply_dft_to_dft_tmp_bytes) RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F; +typedef typeof(vec_rnx_dft) VEC_RNX_DFT_F; +typedef typeof(vec_rnx_idft) VEC_RNX_IDFT_F; + +typedef struct 
rnx_module_vtable_t RNX_MODULE_VTABLE; +struct rnx_module_vtable_t { + VEC_RNX_ZERO_F* vec_rnx_zero; + VEC_RNX_COPY_F* vec_rnx_copy; + VEC_RNX_NEGATE_F* vec_rnx_negate; + VEC_RNX_ADD_F* vec_rnx_add; + VEC_RNX_SUB_F* vec_rnx_sub; + VEC_RNX_ROTATE_F* vec_rnx_rotate; + VEC_RNX_MUL_XP_MINUS_ONE_F* vec_rnx_mul_xp_minus_one; + VEC_RNX_AUTOMORPHISM_F* vec_rnx_automorphism; + VEC_RNX_TO_ZNX32_F* vec_rnx_to_znx32; + VEC_RNX_FROM_ZNX32_F* vec_rnx_from_znx32; + VEC_RNX_TO_TNX32_F* vec_rnx_to_tnx32; + VEC_RNX_FROM_TNX32_F* vec_rnx_from_tnx32; + VEC_RNX_TO_TNX32X2_F* vec_rnx_to_tnx32x2; + VEC_RNX_FROM_TNX32X2_F* vec_rnx_from_tnx32x2; + VEC_RNX_TO_TNXDBL_F* vec_rnx_to_tnxdbl; + RNX_SMALL_SINGLE_PRODUCT_F* rnx_small_single_product; + RNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* rnx_small_single_product_tmp_bytes; + TNXDBL_SMALL_SINGLE_PRODUCT_F* tnxdbl_small_single_product; + TNXDBL_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* tnxdbl_small_single_product_tmp_bytes; + ZNX32_SMALL_SINGLE_PRODUCT_F* znx32_small_single_product; + ZNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* znx32_small_single_product_tmp_bytes; + TNX32_SMALL_SINGLE_PRODUCT_F* tnx32_small_single_product; + TNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* tnx32_small_single_product_tmp_bytes; + RNX_APPROXDECOMP_FROM_TNX32_F* rnx_approxdecomp_from_tnx32; + RNX_APPROXDECOMP_FROM_TNX32X2_F* rnx_approxdecomp_from_tnx32x2; + RNX_APPROXDECOMP_FROM_TNXDBL_F* rnx_approxdecomp_from_tnxdbl; + BYTES_OF_RNX_SVP_PPOL_F* bytes_of_rnx_svp_ppol; + RNX_SVP_PREPARE_F* rnx_svp_prepare; + RNX_SVP_APPLY_F* rnx_svp_apply; + BYTES_OF_RNX_VMP_PMAT_F* bytes_of_rnx_vmp_pmat; + RNX_VMP_PREPARE_CONTIGUOUS_F* rnx_vmp_prepare_contiguous; + RNX_VMP_PREPARE_DBLPTR_F* rnx_vmp_prepare_dblptr; + RNX_VMP_PREPARE_ROW_F* rnx_vmp_prepare_row; + RNX_VMP_PREPARE_TMP_BYTES_F* rnx_vmp_prepare_tmp_bytes; + RNX_VMP_APPLY_TMP_A_F* rnx_vmp_apply_tmp_a; + RNX_VMP_APPLY_TMP_A_TMP_BYTES_F* rnx_vmp_apply_tmp_a_tmp_bytes; + RNX_VMP_APPLY_DFT_TO_DFT_F* rnx_vmp_apply_dft_to_dft; + RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* rnx_vmp_apply_dft_to_dft_tmp_bytes; + VEC_RNX_DFT_F* vec_rnx_dft; + VEC_RNX_IDFT_F* vec_rnx_idft; +}; + +#endif // SPQLIOS_VEC_RNX_ARITHMETIC_PLUGIN_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_arithmetic_private.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_arithmetic_private.h new file mode 100644 index 0000000..7f1ae31 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_arithmetic_private.h @@ -0,0 +1,309 @@ +#ifndef SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H +#define SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H + +#include "../commons_private.h" +#include "../reim/reim_fft.h" +#include "vec_rnx_arithmetic.h" +#include "vec_rnx_arithmetic_plugin.h" + +typedef struct fft64_rnx_module_precomp_t FFT64_RNX_MODULE_PRECOMP; +struct fft64_rnx_module_precomp_t { + REIM_FFT_PRECOMP* p_fft; + REIM_IFFT_PRECOMP* p_ifft; + REIM_FFTVEC_ADD_PRECOMP* p_fftvec_add; + REIM_FFTVEC_MUL_PRECOMP* p_fftvec_mul; + REIM_FFTVEC_ADDMUL_PRECOMP* p_fftvec_addmul; +}; + +typedef union rnx_module_precomp_t RNX_MODULE_PRECOMP; +union rnx_module_precomp_t { + FFT64_RNX_MODULE_PRECOMP fft64; +}; + +void fft64_init_rnx_module_precomp(MOD_RNX* module); + +void fft64_finalize_rnx_module_precomp(MOD_RNX* module); + +/** @brief opaque structure that describes the modules (RnX,ZnX,TnX) and the hardware */ +struct rnx_module_info_t { + uint64_t n; + uint64_t m; + RNX_MODULE_TYPE mtype; + RNX_MODULE_VTABLE vtable; + RNX_MODULE_PRECOMP 
precomp; + void* custom; + void (*custom_deleter)(void*); +}; + +void init_rnx_module_info(MOD_RNX* module, // + uint64_t, RNX_MODULE_TYPE mtype); + +void finalize_rnx_module_info(MOD_RNX* module); + +void fft64_init_rnx_module_vtable(MOD_RNX* module); + +/////////////////////////////////////////////////////////////////// +// prepared gadget decompositions (optimized) // +/////////////////////////////////////////////////////////////////// + +struct tnx32_approxdec_gadget_t { + uint64_t k; + uint64_t ell; + int32_t add_cst; // 1/2.(sum 2^-(i+1)K) + int32_t rshift_base; // 32 - K + int64_t and_mask; // 2^K-1 + int64_t or_mask; // double(2^52) + double sub_cst; // double(2^52 + 2^(K-1)) + uint8_t rshifts[8]; // 32 - (i+1).K +}; + +struct tnx32x2_approxdec_gadget_t { + // TODO +}; + +struct tnxdbl_approxdecomp_gadget_t { + uint64_t k; + uint64_t ell; + double add_cst; // double(3.2^(51-ell.K) + 1/2.(sum 2^(-iK)) for i=[0,ell[) + uint64_t and_mask; // uint64(2^(K)-1) + uint64_t or_mask; // double(2^52) + double sub_cst; // double(2^52 + 2^(K-1)) +}; + +EXPORT void vec_rnx_add_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +); +EXPORT void vec_rnx_add_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = 0 */ +EXPORT void vec_rnx_zero_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl // res +); + +/** @brief sets res = a */ +EXPORT void vec_rnx_copy_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = -a */ +EXPORT void vec_rnx_negate_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = -a */ +EXPORT void vec_rnx_negate_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = a - b */ +EXPORT void vec_rnx_sub_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = a - b */ +EXPORT void vec_rnx_sub_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = a . 
X^p */ +EXPORT void vec_rnx_rotate_ref( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = a(X^p) */ +EXPORT void vec_rnx_automorphism_ref( // + const MOD_RNX* module, // N + int64_t p, // X -> X^p + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */ +EXPORT uint64_t fft64_bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N + uint64_t nrows, uint64_t ncols); + +EXPORT void fft64_rnx_vmp_apply_dft_to_dft_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a_dft, uint64_t a_size, uint64_t a_sl, // a + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +); +EXPORT void fft64_rnx_vmp_apply_dft_to_dft_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a_dft, uint64_t a_size, uint64_t a_sl, // a + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +); +EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); +EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); +EXPORT void fft64_rnx_vmp_prepare_contiguous_ref( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); +EXPORT void fft64_rnx_vmp_prepare_contiguous_avx( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); +EXPORT void fft64_rnx_vmp_prepare_dblptr_ref( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double** mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); +EXPORT void fft64_rnx_vmp_prepare_dblptr_avx( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double** mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); +EXPORT void fft64_rnx_vmp_prepare_row_ref( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* mat, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); +EXPORT void fft64_rnx_vmp_prepare_row_avx( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* mat, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); +EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_ref(const MOD_RNX* module); +EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_avx(const MOD_RNX* module); + +EXPORT void fft64_rnx_vmp_apply_tmp_a_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a) + double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten) + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch 
space +); +EXPORT void fft64_rnx_vmp_apply_tmp_a_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a) + double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten) + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +); + +EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); +EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); + +/// gadget decompositions + +/** @brief sets res = gadget_decompose(a) */ +EXPORT void rnx_approxdecomp_from_tnxdbl_ref( // + const MOD_RNX* module, // N + const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a); // a +EXPORT void rnx_approxdecomp_from_tnxdbl_avx( // + const MOD_RNX* module, // N + const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a); // a + +EXPORT void vec_rnx_mul_xp_minus_one_ref( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_to_znx32_ref( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_from_znx32_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_to_tnx32_ref( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_from_tnx32_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_to_tnxdbl_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT uint64_t fft64_bytes_of_rnx_svp_ppol(const MOD_RNX* module); // N + +/** @brief prepares a svp polynomial */ +EXPORT void fft64_rnx_svp_prepare_ref(const MOD_RNX* module, // N + RNX_SVP_PPOL* ppol, // output + const double* pol // a +); + +/** @brief apply a svp product, result = ppol * a, presented in DFT space */ +EXPORT void fft64_rnx_svp_apply_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // output + const RNX_SVP_PPOL* ppol, // prepared pol + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +#endif // SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_conversions_ref.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_conversions_ref.c new file mode 100644 index 0000000..2a1b296 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_conversions_ref.c @@ -0,0 +1,91 @@ +#include + +#include "vec_rnx_arithmetic_private.h" +#include "zn_arithmetic_private.h" + +EXPORT void 
vec_rnx_to_znx32_ref( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + dbl_round_to_i32_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn); + } + for (uint64_t i = msize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(int32_t)); + } +} + +EXPORT void vec_rnx_from_znx32_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + i32_to_dbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn); + } + for (uint64_t i = msize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} +EXPORT void vec_rnx_to_tnx32_ref( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + dbl_to_tn32_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn); + } + for (uint64_t i = msize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(int32_t)); + } +} +EXPORT void vec_rnx_from_tnx32_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + tn32_to_dbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn); + } + for (uint64_t i = msize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +static void dbl_to_tndbl_ref( // + const void* UNUSED, // N + double* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + static const double OFF_CST = INT64_C(3) << 51; + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + double ai = a[i] + OFF_CST; + res[i] = a[i] - (ai - OFF_CST); + } + memset(res + msize, 0, (res_size - msize) * sizeof(double)); +} + +EXPORT void vec_rnx_to_tnxdbl_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + const uint64_t msize = res_size < a_size ?
res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + dbl_to_tndbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn); + } + for (uint64_t i = msize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_svp_ref.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_svp_ref.c new file mode 100644 index 0000000..f811148 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_svp_ref.c @@ -0,0 +1,47 @@ +#include <string.h> + +#include "../coeffs/coeffs_arithmetic.h" +#include "vec_rnx_arithmetic_private.h" + +EXPORT uint64_t fft64_bytes_of_rnx_svp_ppol(const MOD_RNX* module) { return module->n * sizeof(double); } + +EXPORT RNX_SVP_PPOL* new_rnx_svp_ppol(const MOD_RNX* module) { return spqlios_alloc(bytes_of_rnx_svp_ppol(module)); } + +EXPORT void delete_rnx_svp_ppol(RNX_SVP_PPOL* ppol) { spqlios_free(ppol); } + +/** @brief prepares a svp polynomial */ +EXPORT void fft64_rnx_svp_prepare_ref(const MOD_RNX* module, // N + RNX_SVP_PPOL* ppol, // output + const double* pol // a +) { + double* const dppol = (double*)ppol; + rnx_divide_by_m_ref(module->n, module->m, dppol, pol); + reim_fft(module->precomp.fft64.p_fft, dppol); +} + +EXPORT void fft64_rnx_svp_apply_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // output + const RNX_SVP_PPOL* ppol, // prepared pol + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + double* const dppol = (double*)ppol; + + const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < auto_end_idx; ++i) { + const double* a_ptr = a + i * a_sl; + double* const res_ptr = res + i * res_sl; + // copy the polynomial to res, apply fft in place, call fftvec + // _mul, apply ifft in place.
+ memcpy(res_ptr, a_ptr, nn * sizeof(double)); + reim_fft(module->precomp.fft64.p_fft, (double*)res_ptr); + reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, res_ptr, res_ptr, dppol); + reim_ifft(module->precomp.fft64.p_ifft, res_ptr); + } + + // then extend with zeros + for (uint64_t i = auto_end_idx; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_vmp_avx.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_vmp_avx.c new file mode 100644 index 0000000..7a492bc --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_vmp_avx.c @@ -0,0 +1,254 @@ +#include +#include +#include + +#include "../coeffs/coeffs_arithmetic.h" +#include "../reim/reim_fft.h" +#include "../reim4/reim4_arithmetic.h" +#include "vec_rnx_arithmetic_private.h" + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void fft64_rnx_vmp_prepare_contiguous_avx( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + // there is an edge case if nn < 8 + const uint64_t nn = module->n; + const uint64_t m = module->m; + + double* const dtmp = (double*)tmp_space; + double* const output_mat = (double*)pmat; + double* start_addr = (double*)pmat; + uint64_t offset = nrows * ncols * 8; + + if (nn >= 8) { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + rnx_divide_by_m_avx(nn, m, dtmp, mat + (row_i * ncols + col_i) * nn); + reim_fft(module->precomp.fft64.p_fft, dtmp); + + if (col_i == (ncols - 1) && (ncols % 2 == 1)) { + // special case: last column out of an odd column number + start_addr = output_mat + col_i * nrows * 8 // col == ncols-1 + + row_i * 8; + } else { + // general case: columns go by pair + start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index + + row_i * 2 * 8 // third: row index + + (col_i % 2) * 8; + } + + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + // extract blk from tmp and save it + reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, dtmp); + } + } + } + } else { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + double* res = output_mat + (col_i * nrows + row_i) * nn; + rnx_divide_by_m_avx(nn, m, res, mat + (row_i * ncols + col_i) * nn); + reim_fft(module->precomp.fft64.p_fft, res); + } + } + } +} + +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void fft64_rnx_vmp_prepare_dblptr_avx( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double** mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + fft64_rnx_vmp_prepare_row_avx(module, pmat, mat[row_i], row_i, nrows, ncols, tmp_space); + } +} + +/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */ +EXPORT void fft64_rnx_vmp_prepare_row_avx( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* row, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + // there is an edge case if nn < 8 + const uint64_t nn = module->n; + const uint64_t m = module->m; + + double* const dtmp = (double*)tmp_space; + double* const output_mat = (double*)pmat; + double* 
start_addr = (double*)pmat; + uint64_t offset = nrows * ncols * 8; + + if (nn >= 8) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + rnx_divide_by_m_avx(nn, m, dtmp, row + col_i * nn); + reim_fft(module->precomp.fft64.p_fft, dtmp); + + if (col_i == (ncols - 1) && (ncols % 2 == 1)) { + // special case: last column out of an odd column number + start_addr = output_mat + col_i * nrows * 8 // col == ncols-1 + + row_i * 8; + } else { + // general case: columns go by pair + start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index + + row_i * 2 * 8 // third: row index + + (col_i % 2) * 8; + } + + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + // extract blk from tmp and save it + reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, dtmp); + } + } + } else { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + double* res = output_mat + (col_i * nrows + row_i) * nn; + rnx_divide_by_m_avx(nn, m, res, row + col_i * nn); + reim_fft(module->precomp.fft64.p_fft, res); + } + } +} + +/** @brief minimal size of the tmp_space */ +EXPORT void fft64_rnx_vmp_apply_dft_to_dft_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a_dft, uint64_t a_size, uint64_t a_sl, // a + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +) { + const uint64_t m = module->m; + const uint64_t nn = module->n; + + double* mat2cols_output = (double*)tmp_space; // 128 bytes + double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes + + double* mat_input = (double*)pmat; + + const uint64_t row_max = nrows < a_size ? nrows : a_size; + const uint64_t col_max = ncols < res_size ? 
ncols : res_size; + + if (row_max > 0 && col_max > 0) { + if (nn >= 8) { + // let's do some prefetching of the GSW key, since on some cpus, + // it helps + const uint64_t ms4 = m >> 2; // m/4 + const uint64_t gsw_iter_doubles = 8 * nrows * ncols; + const uint64_t pref_doubles = 1200; + const double* gsw_pref_ptr = mat_input; + const double* const gsw_ptr_end = mat_input + ms4 * gsw_iter_doubles; + const double* gsw_pref_ptr_target = mat_input + pref_doubles; + for (; gsw_pref_ptr < gsw_pref_ptr_target; gsw_pref_ptr += 8) { + __builtin_prefetch(gsw_pref_ptr, 0, _MM_HINT_T0); + } + const double* mat_blk_start; + uint64_t blk_i; + for (blk_i = 0, mat_blk_start = mat_input; blk_i < ms4; blk_i++, mat_blk_start += gsw_iter_doubles) { + // prefetch the next iteration + if (gsw_pref_ptr_target < gsw_ptr_end) { + gsw_pref_ptr_target += gsw_iter_doubles; + if (gsw_pref_ptr_target > gsw_ptr_end) gsw_pref_ptr_target = gsw_ptr_end; + for (; gsw_pref_ptr < gsw_pref_ptr_target; gsw_pref_ptr += 8) { + __builtin_prefetch(gsw_pref_ptr, 0, _MM_HINT_T0); + } + } + reim4_extract_1blk_from_contiguous_reim_sl_avx(m, a_sl, row_max, blk_i, extracted_blk, a_dft); + // apply mat2cols + for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) { + uint64_t col_offset = col_i * (8 * nrows); + reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + + reim4_save_1blk_to_reim_avx(m, blk_i, res + col_i * res_sl, mat2cols_output); + reim4_save_1blk_to_reim_avx(m, blk_i, res + (col_i + 1) * res_sl, mat2cols_output + 8); + } + + // check if col_max is odd, then special case + if (col_max % 2 == 1) { + uint64_t last_col = col_max - 1; + uint64_t col_offset = last_col * (8 * nrows); + + // the last column is alone in the pmat: vec_mat1col + if (ncols == col_max) { + reim4_vec_mat1col_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + } else { + // the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position + reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + } + reim4_save_1blk_to_reim_avx(m, blk_i, res + last_col * res_sl, mat2cols_output); + } + } + } else { + const double* in; + uint64_t in_sl; + if (res == a_dft) { + // it is in place: copy the input vector + in = (double*)tmp_space; + in_sl = nn; + // vec_rnx_copy(module, (double*)tmp_space, row_max, nn, a_dft, row_max, a_sl); + for (uint64_t row_i = 0; row_i < row_max; row_i++) { + memcpy((double*)tmp_space + row_i * nn, a_dft + row_i * a_sl, nn * sizeof(double)); + } + } else { + // it is out of place: do the product directly + in = a_dft; + in_sl = a_sl; + } + for (uint64_t col_i = 0; col_i < col_max; col_i++) { + double* pmat_col = mat_input + col_i * nrows * nn; + { + reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, // + res + col_i * res_sl, // + in, // + pmat_col); + } + for (uint64_t row_i = 1; row_i < row_max; row_i++) { + reim_fftvec_addmul(module->precomp.fft64.p_fftvec_addmul, // + res + col_i * res_sl, // + in + row_i * in_sl, // + pmat_col + row_i * nn); + } + } + } + } + // zero out remaining bytes (if any) + for (uint64_t i = col_max; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief applies a vmp product res = a x pmat */ +EXPORT void fft64_rnx_vmp_apply_tmp_a_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a) + double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be 
overwritten) + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +) { + const uint64_t nn = module->n; + const uint64_t rows = nrows < a_size ? nrows : a_size; + const uint64_t cols = ncols < res_size ? ncols : res_size; + + // fft is done in place on the input (tmpa is destroyed) + for (uint64_t i = 0; i < rows; ++i) { + reim_fft(module->precomp.fft64.p_fft, tmpa + i * a_sl); + } + fft64_rnx_vmp_apply_dft_to_dft_avx(module, // + res, cols, res_sl, // + tmpa, rows, a_sl, // + pmat, nrows, ncols, // + tmp_space); + // ifft is done in place on the output + for (uint64_t i = 0; i < cols; ++i) { + reim_ifft(module->precomp.fft64.p_ifft, res + i * res_sl); + } + // zero out the remaining positions + for (uint64_t i = cols; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_vmp_ref.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_vmp_ref.c new file mode 100644 index 0000000..1a91f3c --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_rnx_vmp_ref.c @@ -0,0 +1,309 @@ +#include +#include + +#include "../coeffs/coeffs_arithmetic.h" +#include "../reim/reim_fft.h" +#include "../reim4/reim4_arithmetic.h" +#include "vec_rnx_arithmetic_private.h" + +/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */ +EXPORT uint64_t fft64_bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N + uint64_t nrows, uint64_t ncols) { // dimensions + return nrows * ncols * module->n * sizeof(double); +} + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void fft64_rnx_vmp_prepare_contiguous_ref( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + // there is an edge case if nn < 8 + const uint64_t nn = module->n; + const uint64_t m = module->m; + + double* const dtmp = (double*)tmp_space; + double* const output_mat = (double*)pmat; + double* start_addr = (double*)pmat; + uint64_t offset = nrows * ncols * 8; + + if (nn >= 8) { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + rnx_divide_by_m_ref(nn, m, dtmp, mat + (row_i * ncols + col_i) * nn); + reim_fft(module->precomp.fft64.p_fft, dtmp); + + if (col_i == (ncols - 1) && (ncols % 2 == 1)) { + // special case: last column out of an odd column number + start_addr = output_mat + col_i * nrows * 8 // col == ncols-1 + + row_i * 8; + } else { + // general case: columns go by pair + start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index + + row_i * 2 * 8 // third: row index + + (col_i % 2) * 8; + } + + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + // extract blk from tmp and save it + reim4_extract_1blk_from_reim_ref(m, blk_i, start_addr + blk_i * offset, dtmp); + } + } + } + } else { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + double* res = output_mat + (col_i * nrows + row_i) * nn; + rnx_divide_by_m_ref(nn, m, res, mat + (row_i * ncols + col_i) * nn); + reim_fft(module->precomp.fft64.p_fft, res); + } + } + } +} + +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void fft64_rnx_vmp_prepare_dblptr_ref( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double** mat, uint64_t 
nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + fft64_rnx_vmp_prepare_row_ref(module, pmat, mat[row_i], row_i, nrows, ncols, tmp_space); + } +} + +/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */ +EXPORT void fft64_rnx_vmp_prepare_row_ref( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* row, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + // there is an edge case if nn < 8 + const uint64_t nn = module->n; + const uint64_t m = module->m; + + double* const dtmp = (double*)tmp_space; + double* const output_mat = (double*)pmat; + double* start_addr = (double*)pmat; + uint64_t offset = nrows * ncols * 8; + + if (nn >= 8) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + rnx_divide_by_m_ref(nn, m, dtmp, row + col_i * nn); + reim_fft(module->precomp.fft64.p_fft, dtmp); + + if (col_i == (ncols - 1) && (ncols % 2 == 1)) { + // special case: last column out of an odd column number + start_addr = output_mat + col_i * nrows * 8 // col == ncols-1 + + row_i * 8; + } else { + // general case: columns go by pair + start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index + + row_i * 2 * 8 // third: row index + + (col_i % 2) * 8; + } + + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + // extract blk from tmp and save it + reim4_extract_1blk_from_reim_ref(m, blk_i, start_addr + blk_i * offset, dtmp); + } + } + } else { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + double* res = output_mat + (col_i * nrows + row_i) * nn; + rnx_divide_by_m_ref(nn, m, res, row + col_i * nn); + reim_fft(module->precomp.fft64.p_fft, res); + } + } +} + +/** @brief number of scratch bytes necessary to prepare a matrix */ +EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_ref(const MOD_RNX* module) { + const uint64_t nn = module->n; + return nn * sizeof(int64_t); +} + +/** @brief minimal size of the tmp_space */ +EXPORT void fft64_rnx_vmp_apply_dft_to_dft_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a_dft, uint64_t a_size, uint64_t a_sl, // a + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +) { + const uint64_t m = module->m; + const uint64_t nn = module->n; + + double* mat2cols_output = (double*)tmp_space; // 128 bytes + double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes + + double* mat_input = (double*)pmat; + + const uint64_t row_max = nrows < a_size ? nrows : a_size; + const uint64_t col_max = ncols < res_size ? 
ncols : res_size; + + if (row_max > 0 && col_max > 0) { + if (nn >= 8) { + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols); + + reim4_extract_1blk_from_contiguous_reim_sl_ref(m, a_sl, row_max, blk_i, extracted_blk, a_dft); + // apply mat2cols + for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) { + uint64_t col_offset = col_i * (8 * nrows); + reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + + reim4_save_1blk_to_reim_ref(m, blk_i, res + col_i * res_sl, mat2cols_output); + reim4_save_1blk_to_reim_ref(m, blk_i, res + (col_i + 1) * res_sl, mat2cols_output + 8); + } + + // check if col_max is odd, then special case + if (col_max % 2 == 1) { + uint64_t last_col = col_max - 1; + uint64_t col_offset = last_col * (8 * nrows); + + // the last column is alone in the pmat: vec_mat1col + if (ncols == col_max) { + reim4_vec_mat1col_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + } else { + // the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position + reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + } + reim4_save_1blk_to_reim_ref(m, blk_i, res + last_col * res_sl, mat2cols_output); + } + } + } else { + const double* in; + uint64_t in_sl; + if (res == a_dft) { + // it is in place: copy the input vector + in = (double*)tmp_space; + in_sl = nn; + // vec_rnx_copy(module, (double*)tmp_space, row_max, nn, a_dft, row_max, a_sl); + for (uint64_t row_i = 0; row_i < row_max; row_i++) { + memcpy((double*)tmp_space + row_i * nn, a_dft + row_i * a_sl, nn * sizeof(double)); + } + } else { + // it is out of place: do the product directly + in = a_dft; + in_sl = a_sl; + } + for (uint64_t col_i = 0; col_i < col_max; col_i++) { + double* pmat_col = mat_input + col_i * nrows * nn; + { + reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, // + res + col_i * res_sl, // + in, // + pmat_col); + } + for (uint64_t row_i = 1; row_i < row_max; row_i++) { + reim_fftvec_addmul(module->precomp.fft64.p_fftvec_addmul, // + res + col_i * res_sl, // + in + row_i * in_sl, // + pmat_col + row_i * nn); + } + } + } + } + // zero out remaining bytes (if any) + for (uint64_t i = col_max; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief applies a vmp product res = a x pmat */ +EXPORT void fft64_rnx_vmp_apply_tmp_a_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a) + double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten) + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +) { + const uint64_t nn = module->n; + const uint64_t rows = nrows < a_size ? nrows : a_size; + const uint64_t cols = ncols < res_size ? 
ncols : res_size; + + // fft is done in place on the input (tmpa is destroyed) + for (uint64_t i = 0; i < rows; ++i) { + reim_fft(module->precomp.fft64.p_fft, tmpa + i * a_sl); + } + fft64_rnx_vmp_apply_dft_to_dft_ref(module, // + res, cols, res_sl, // + tmpa, rows, a_sl, // + pmat, nrows, ncols, // + tmp_space); + // ifft is done in place on the output + for (uint64_t i = 0; i < cols; ++i) { + reim_ifft(module->precomp.fft64.p_ifft, res + i * res_sl); + } + // zero out the remaining positions + for (uint64_t i = cols; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +) { + const uint64_t row_max = nrows < a_size ? nrows : a_size; + + return (128) + (64 * row_max); +} + +#ifdef __APPLE__ +EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +) { + return fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref(module, res_size, a_size, nrows, ncols); +} +#else +EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix + ) __attribute((alias("fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref"))); +#endif +// avx aliases that need to be defined in the same .c file + +/** @brief number of scratch bytes necessary to prepare a matrix */ +#ifdef __APPLE__ +#pragma weak fft64_rnx_vmp_prepare_tmp_bytes_avx = fft64_rnx_vmp_prepare_tmp_bytes_ref +#else +EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_avx(const MOD_RNX* module) + __attribute((alias("fft64_rnx_vmp_prepare_tmp_bytes_ref"))); +#endif + +/** @brief minimal size of the tmp_space */ +#ifdef __APPLE__ +#pragma weak fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref +#else +EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix + ) __attribute((alias("fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref"))); +#endif + +#ifdef __APPLE__ +#pragma weak fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref +#else +EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix + ) __attribute((alias("fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref"))); +#endif +// wrappers diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_znx.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_znx.c new file mode 100644 index 0000000..75771c0 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_znx.c @@ -0,0 +1,369 @@ +#include +#include +#include +#include +#include + +#include "../coeffs/coeffs_arithmetic.h" +#include "../q120/q120_arithmetic.h" +#include "../q120/q120_ntt.h" +#include "../reim/reim_fft_internal.h" +#include "../reim4/reim4_arithmetic.h" +#include "vec_znx_arithmetic.h" +#include "vec_znx_arithmetic_private.h" + +// general function (virtual dispatch) + +EXPORT void vec_znx_add(const MODULE* module, // N + 
int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + module->func.vec_znx_add(module, // N + res, res_size, res_sl, // res + a, a_size, a_sl, // a + b, b_size, b_sl // b + ); +} + +EXPORT void vec_znx_sub(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + module->func.vec_znx_sub(module, // N + res, res_size, res_sl, // res + a, a_size, a_sl, // a + b, b_size, b_sl // b + ); +} + +EXPORT void vec_znx_rotate(const MODULE* module, // N + const int64_t p, // rotation value + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + module->func.vec_znx_rotate(module, // N + p, // p + res, res_size, res_sl, // res + a, a_size, a_sl // a + ); +} + +EXPORT void vec_znx_mul_xp_minus_one(const MODULE* module, // N + const int64_t p, // p + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + module->func.vec_znx_mul_xp_minus_one(module, // N + p, // p + res, res_size, res_sl, // res + a, a_size, a_sl // a + ); +} + +EXPORT void vec_znx_automorphism(const MODULE* module, // N + const int64_t p, // X->X^p + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + module->func.vec_znx_automorphism(module, // N + p, // p + res, res_size, res_sl, // res + a, a_size, a_sl // a + ); +} + +EXPORT void vec_znx_normalize_base2k(const MODULE* module, uint64_t nn, // N + uint64_t log2_base2k, // output base 2^K + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + uint8_t* tmp_space // scratch space of size >= N +) { + module->func.vec_znx_normalize_base2k(module, nn, // N + log2_base2k, // log2_base2k + res, res_size, res_sl, // res + a, a_size, a_sl, // a + tmp_space); +} + +EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N +) { + return module->func.vec_znx_normalize_base2k_tmp_bytes(module, nn // N + ); +} + +// specialized function (ref) + +EXPORT void vec_znx_add_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->nn; + if (a_size <= b_size) { + const uint64_t sum_idx = res_size < a_size ? res_size : a_size; + const uint64_t copy_idx = res_size < b_size ? res_size : b_size; + // add up to the smallest dimension + for (uint64_t i = 0; i < sum_idx; ++i) { + znx_add_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then copy to the largest dimension + for (uint64_t i = sum_idx; i < copy_idx; ++i) { + znx_copy_i64_ref(nn, res + i * res_sl, b + i * b_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } + } else { + const uint64_t sum_idx = res_size < b_size ? res_size : b_size; + const uint64_t copy_idx = res_size < a_size ? 
res_size : a_size; + // add up to the smallest dimension + for (uint64_t i = 0; i < sum_idx; ++i) { + znx_add_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then copy to the largest dimension + for (uint64_t i = sum_idx; i < copy_idx; ++i) { + znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } + } +} + +EXPORT void vec_znx_sub_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->nn; + if (a_size <= b_size) { + const uint64_t sub_idx = res_size < a_size ? res_size : a_size; + const uint64_t copy_idx = res_size < b_size ? res_size : b_size; + // subtract up to the smallest dimension + for (uint64_t i = 0; i < sub_idx; ++i) { + znx_sub_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then negate to the largest dimension + for (uint64_t i = sub_idx; i < copy_idx; ++i) { + znx_negate_i64_ref(nn, res + i * res_sl, b + i * b_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } + } else { + const uint64_t sub_idx = res_size < b_size ? res_size : b_size; + const uint64_t copy_idx = res_size < a_size ? res_size : a_size; + // subtract up to the smallest dimension + for (uint64_t i = 0; i < sub_idx; ++i) { + znx_sub_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then copy to the largest dimension + for (uint64_t i = sub_idx; i < copy_idx; ++i) { + znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } + } +} + +EXPORT void vec_znx_rotate_ref(const MODULE* module, // N + const int64_t p, // rotation value + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->nn; + + const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size; + // rotate up to the smallest dimension + for (uint64_t i = 0; i < rot_end_idx; ++i) { + int64_t* res_ptr = res + i * res_sl; + const int64_t* a_ptr = a + i * a_sl; + if (res_ptr == a_ptr) { + znx_rotate_inplace_i64(nn, p, res_ptr); + } else { + znx_rotate_i64(nn, p, res_ptr, a_ptr); + } + } + // then extend with zeros + for (uint64_t i = rot_end_idx; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} + +EXPORT void vec_znx_mul_xp_minus_one_ref(const MODULE* module, // N + const int64_t p, // p + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->nn; + + const uint64_t rot_end_idx = res_size < a_size ? 
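Editor's note: the ref implementation above fully defines the size-mismatch semantics of vec_znx_add (and of vec_znx_sub below): limbs are added up to the shorter operand, copied (negated, for sub) from the longer one, and the result is zero-extended. A usage sketch through the public API declared in vec_znx_arithmetic.h later in this patch; N = 64, the include path and the plain calloc/malloc calls are illustrative assumptions only:

#include <stdint.h>
#include <stdlib.h>
#include "vec_znx_arithmetic.h"  /* header added later in this patch; path assumed */

void add_with_mixed_sizes(void) {
  const uint64_t nn = 64;                          /* illustrative ring dimension */
  MODULE* m = new_module_info(nn, FFT64);
  int64_t* a   = calloc(2 * nn, sizeof(int64_t));  /* 2 limbs */
  int64_t* b   = calloc(3 * nn, sizeof(int64_t));  /* 3 limbs */
  int64_t* res = malloc(4 * nn * sizeof(int64_t)); /* 4 limbs, fully written by the call */
  a[0] = 1; b[0] = 2; b[2 * nn] = 5;               /* constant coeffs in limbs 0 and 2 */

  vec_znx_add(m, res, 4, nn, a, 2, nn, b, 3, nn);
  /* limb 0 = a0+b0, limb 1 = a1+b1, limb 2 = copy of b2, limb 3 = 0 */

  free(a); free(b); free(res);
  delete_module_info(m);
}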
res_size : a_size; + for (uint64_t i = 0; i < rot_end_idx; ++i) { + int64_t* res_ptr = res + i * res_sl; + const int64_t* a_ptr = a + i * a_sl; + if (res_ptr == a_ptr) { + znx_mul_xp_minus_one_inplace_i64(nn, p, res_ptr); + } else { + znx_mul_xp_minus_one_i64(nn, p, res_ptr, a_ptr); + } + } + // then extend with zeros + for (uint64_t i = rot_end_idx; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} + +EXPORT void vec_znx_automorphism_ref(const MODULE* module, // N + const int64_t p, // X->X^p + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->nn; + + const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size; + + for (uint64_t i = 0; i < auto_end_idx; ++i) { + int64_t* res_ptr = res + i * res_sl; + const int64_t* a_ptr = a + i * a_sl; + if (res_ptr == a_ptr) { + znx_automorphism_inplace_i64(nn, p, res_ptr); + } else { + znx_automorphism_i64(nn, p, res_ptr, a_ptr); + } + } + // then extend with zeros + for (uint64_t i = auto_end_idx; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} + +EXPORT void vec_znx_normalize_base2k_ref(const MODULE* module, uint64_t nn, // N + uint64_t log2_base2k, // output base 2^K + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + uint8_t* tmp_space // scratch space of size >= N +) { + + // use MSB limb of res for carry propagation + int64_t* cout = (int64_t*)tmp_space; + int64_t* cin = 0x0; + + // propagate carry until first limb of res + int64_t i = a_size - 1; + for (; i >= res_size; --i) { + znx_normalize(nn, log2_base2k, 0x0, cout, a + i * a_sl, cin); + cin = cout; + } + + // propagate carry and normalize + for (; i >= 1; --i) { + znx_normalize(nn, log2_base2k, res + i * res_sl, cout, a + i * a_sl, cin); + cin = cout; + } + + // normalize last limb + znx_normalize(nn, log2_base2k, res, 0x0, a, cin); + + // extend result with zeros + for (uint64_t i = a_size; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} + +EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes_ref(const MODULE* module, uint64_t nn // N +) { + return nn * sizeof(int64_t); +} + +// alias have to be defined in this unit: do not move +#ifdef __APPLE__ +EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes( // + const MODULE* module, // N + uint64_t nn +) { + return vec_znx_normalize_base2k_tmp_bytes_ref(module, nn); +} +EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes( // + const MODULE* module, // N + uint64_t nn +) { + return vec_znx_normalize_base2k_tmp_bytes_ref(module, nn); +} +#else +EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes( // + const MODULE* module, // N + uint64_t nn + ) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref"))); + +EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes( // + const MODULE* module, // N + uint64_t nn + ) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref"))); +#endif + +/** @brief sets res = 0 */ +EXPORT void vec_znx_zero(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl // res +) { + module->func.vec_znx_zero(module, res, res_size, res_sl); +} + +/** @brief sets res = a */ +EXPORT void vec_znx_copy(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + module->func.vec_znx_copy(module, res, res_size, res_sl, a, a_size, a_sl); +} + +/** 
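Editor's note: vec_znx_rotate_ref, vec_znx_mul_xp_minus_one_ref and vec_znx_automorphism_ref above explicitly check res_ptr == a_ptr and switch to the *_inplace_i64 kernels, so operating on a single buffer is supported by these ref kernels. A small illustrative sketch, not part of the patch; the include path is an assumption:

#include <stdint.h>
#include "vec_znx_arithmetic.h"

/* rotate and re-index a limb vector in place: res and a point to the same buffer */
void rotate_in_place(const MODULE* m, int64_t* v, uint64_t size, uint64_t nn) {
  vec_znx_rotate(m, 1, v, size, nn, v, size, nn);        /* v(X) becomes X * v(X) */
  vec_znx_automorphism(m, 5, v, size, nn, v, size, nn);  /* v(X) becomes v(X^5)   */
}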
@brief sets res = a */ +EXPORT void vec_znx_negate(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + module->func.vec_znx_negate(module, res, res_size, res_sl, a, a_size, a_sl); +} + +EXPORT void vec_znx_zero_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl // res +) { + uint64_t nn = module->nn; + for (uint64_t i = 0; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} + +EXPORT void vec_znx_copy_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + uint64_t nn = module->nn; + uint64_t smin = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < smin; ++i) { + znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl); + } + for (uint64_t i = smin; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} + +EXPORT void vec_znx_negate_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + uint64_t nn = module->nn; + uint64_t smin = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < smin; ++i) { + znx_negate_i64_ref(nn, res + i * res_sl, a + i * a_sl); + } + for (uint64_t i = smin; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_znx_arithmetic.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_znx_arithmetic.h new file mode 100644 index 0000000..1dea5c6 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_znx_arithmetic.h @@ -0,0 +1,370 @@ +#ifndef SPQLIOS_VEC_ZNX_ARITHMETIC_H +#define SPQLIOS_VEC_ZNX_ARITHMETIC_H + +#include + +#include "../commons.h" +#include "../reim/reim_fft.h" + +/** + * We support the following module families: + * - FFT64: + * all the polynomials should fit at all times over 52 bits. + * for FHE implementations, the recommended limb-sizes are + * between K=10 and 20, which is good for low multiplicative depths. + * - NTT120: + * all the polynomials should fit at all times over 119 bits. + * for FHE implementations, the recommended limb-sizes are + * between K=20 and 40, which is good for large multiplicative depths. 
+ */ +typedef enum module_type_t { FFT64, NTT120 } MODULE_TYPE; + +/** @brief opaque structure that describes the modules (ZnX,TnX) and the hardware */ +typedef struct module_info_t MODULE; +/** @brief opaque type that represents a prepared matrix */ +typedef struct vmp_pmat_t VMP_PMAT; +/** @brief opaque type that represents a vector of znx in DFT space */ +typedef struct vec_znx_dft_t VEC_ZNX_DFT; +/** @brief opaque type that represents a vector of znx in large coeffs space */ +typedef struct vec_znx_bigcoeff_t VEC_ZNX_BIG; +/** @brief opaque type that represents a prepared scalar vector product */ +typedef struct svp_ppol_t SVP_PPOL; +/** @brief opaque type that represents a prepared left convolution vector product */ +typedef struct cnv_pvec_l_t CNV_PVEC_L; +/** @brief opaque type that represents a prepared right convolution vector product */ +typedef struct cnv_pvec_r_t CNV_PVEC_R; + +/** @brief bytes needed for a vec_znx in DFT space */ +EXPORT uint64_t bytes_of_vec_znx_dft(const MODULE* module, // N + uint64_t size); + +/** @brief allocates a vec_znx in DFT space */ +EXPORT VEC_ZNX_DFT* new_vec_znx_dft(const MODULE* module, // N + uint64_t size); + +/** @brief frees memory from a vec_znx in DFT space */ +EXPORT void delete_vec_znx_dft(VEC_ZNX_DFT* res); + +/** @brief bytes needed for a vec_znx_big */ +EXPORT uint64_t bytes_of_vec_znx_big(const MODULE* module, // N + uint64_t size); + +/** @brief allocates a vec_znx_big */ +EXPORT VEC_ZNX_BIG* new_vec_znx_big(const MODULE* module, // N + uint64_t size); +/** @brief frees memory from a vec_znx_big */ +EXPORT void delete_vec_znx_big(VEC_ZNX_BIG* res); + +/** @brief bytes needed for a prepared vector */ +EXPORT uint64_t bytes_of_svp_ppol(const MODULE* module); // N + +/** @brief allocates a prepared vector */ +EXPORT SVP_PPOL* new_svp_ppol(const MODULE* module); // N + +/** @brief frees memory for a prepared vector */ +EXPORT void delete_svp_ppol(SVP_PPOL* res); + +/** @brief bytes needed for a prepared matrix */ +EXPORT uint64_t bytes_of_vmp_pmat(const MODULE* module, // N + uint64_t nrows, uint64_t ncols); + +/** @brief allocates a prepared matrix */ +EXPORT VMP_PMAT* new_vmp_pmat(const MODULE* module, // N + uint64_t nrows, uint64_t ncols); + +/** @brief frees memory for a prepared matrix */ +EXPORT void delete_vmp_pmat(VMP_PMAT* res); + +/** + * @brief obtain a module info for ring dimension N + * the module-info knows about: + * - the dimension N (or the complex dimension m=N/2) + * - any precomputed fft or ntt items + * - the hardware (avx, arm64, x86, ...)
+ */ +EXPORT MODULE* new_module_info(uint64_t N, MODULE_TYPE mode); +EXPORT void delete_module_info(MODULE* module_info); +EXPORT uint64_t module_get_n(const MODULE* module); + +/** @brief sets res = 0 */ +EXPORT void vec_znx_zero(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl // res +); + +/** @brief sets res = a */ +EXPORT void vec_znx_copy(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = a */ +EXPORT void vec_znx_negate(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = a + b */ +EXPORT void vec_znx_add(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = a - b */ +EXPORT void vec_znx_sub(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = k-normalize-reduce(a) */ +EXPORT void vec_znx_normalize_base2k(const MODULE* module, // MODULE + uint64_t nn, // N + uint64_t log2_base2k, // output base 2^K + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + uint8_t* tmp_space // scratch space (size >= N) +); + +/** @brief returns the minimal byte length of scratch space for vec_znx_normalize_base2k */ +EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N +); + +/** @brief sets res = a . 
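Editor's note: a compact end-to-end sketch of the allocation and vec_znx calls declared above, not part of the patch. N = 1024 and K = 19 are illustrative choices (K sits in the 10..20 range recommended for FFT64 earlier in this header); plain calloc/malloc is used for brevity and any alignment requirement of the backend is ignored here.

#include <stdint.h>
#include <stdlib.h>
#include "vec_znx_arithmetic.h"  /* header added by this patch; include path assumed */

void demo(void) {
  const uint64_t nn = 1024, k = 19, size = 4;
  MODULE* m = new_module_info(nn, FFT64);
  int64_t* a  = calloc(size * nn, sizeof(int64_t));
  int64_t* b  = calloc(size * nn, sizeof(int64_t));
  int64_t* r  = calloc(size * nn, sizeof(int64_t));
  int64_t* rn = calloc(size * nn, sizeof(int64_t));

  /* put 2^18 in coefficient 0 of limb 1 of both inputs: their sum 2^19
     overflows a base-2^19 digit and must carry into limb 0 */
  a[nn] = 1 << 18;
  b[nn] = 1 << 18;

  vec_znx_add(m, r, size, nn, a, size, nn, b, size, nn);

  /* re-normalize r into rn, using the required scratch size */
  uint8_t* tmp = malloc(vec_znx_normalize_base2k_tmp_bytes(m, nn));
  vec_znx_normalize_base2k(m, nn, k, rn, size, nn, r, size, nn, tmp);

  free(tmp); free(a); free(b); free(r); free(rn);
  delete_module_info(m);
}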
X^p */ +EXPORT void vec_znx_rotate(const MODULE* module, // N + const int64_t p, // rotation value + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = a * (X^{p} - 1) */ +EXPORT void vec_znx_mul_xp_minus_one(const MODULE* module, // N + const int64_t p, // rotation value + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = a(X^p) */ +EXPORT void vec_znx_automorphism(const MODULE* module, // N + const int64_t p, // X-X^p + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = 0 */ +EXPORT void vec_dft_zero(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size // res +); + +/** @brief sets res = a+b */ +EXPORT void vec_dft_add(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a, uint64_t a_size, // a + const VEC_ZNX_DFT* b, uint64_t b_size // b +); + +/** @brief sets res = a-b */ +EXPORT void vec_dft_sub(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a, uint64_t a_size, // a + const VEC_ZNX_DFT* b, uint64_t b_size // b +); + +/** @brief sets res = DFT(a) */ +EXPORT void vec_znx_dft(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = iDFT(a_dft) -- output in big coeffs space */ +EXPORT void vec_znx_idft(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + uint8_t* tmp // scratch space +); + +/** @brief tmp bytes required for vec_znx_idft */ +EXPORT uint64_t vec_znx_idft_tmp_bytes(const MODULE* module, uint64_t nn); + +/** + * @brief sets res = iDFT(a_dft) -- output in big coeffs space + * + * @note a_dft is overwritten + */ +EXPORT void vec_znx_idft_tmp_a(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten +); + +/** @brief sets res = a+b */ +EXPORT void vec_znx_big_add(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +); +/** @brief sets res = a+b */ +EXPORT void vec_znx_big_add_small(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); +EXPORT void vec_znx_big_add_small2(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = a-b */ +EXPORT void vec_znx_big_sub(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +); +EXPORT void vec_znx_big_sub_small_b(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); +EXPORT void vec_znx_big_sub_small_a(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +); +EXPORT void 
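Editor's note: a sketch of the DFT round trip declared above, going from int64 limbs to DFT space with vec_znx_dft, back to big-coefficient space with vec_znx_idft_tmp_a (which destroys its DFT input), and finally to normalized int64 limbs with vec_znx_big_normalize_base2k declared just below. Illustrative only and not part of the patch; the caller supplies the module, buffers and sizes, and scratch alignment is not addressed.

#include <stdint.h>
#include <stdlib.h>
#include "vec_znx_arithmetic.h"

void dft_round_trip(const MODULE* m, uint64_t nn, uint64_t k,
                    const int64_t* a, uint64_t size, int64_t* out) {
  VEC_ZNX_DFT* a_dft = new_vec_znx_dft(m, size);
  VEC_ZNX_BIG* a_big = new_vec_znx_big(m, size);
  uint8_t* tmp = malloc(vec_znx_big_normalize_base2k_tmp_bytes(m, nn));

  vec_znx_dft(m, a_dft, size, a, size, nn);         /* forward DFT                     */
  vec_znx_idft_tmp_a(m, a_big, size, a_dft, size);  /* inverse DFT, a_dft is destroyed */
  vec_znx_big_normalize_base2k(m, nn, k, out, size, nn, a_big, size, tmp);

  free(tmp);
  delete_vec_znx_big(a_big);
  delete_vec_znx_dft(a_dft);
}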
vec_znx_big_sub_small2(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = k-normalize(a) -- output in int64 coeffs space */ +EXPORT void vec_znx_big_normalize_base2k(const MODULE* module, // MODULE + uint64_t nn, // N + uint64_t log2_base2k, // base-2^k + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + uint8_t* tmp_space // temp space +); + +/** @brief returns the minimal byte length of scratch space for vec_znx_big_normalize_base2k */ +EXPORT uint64_t vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N +); + +/** @brief sets res = k-normalize(a.subrange) -- output in int64 coeffs space */ +EXPORT void vec_znx_big_range_normalize_base2k( // + const MODULE* module, // MODULE + uint64_t nn, + uint64_t log2_base2k, // base-2^k + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const VEC_ZNX_BIG* a, uint64_t a_range_begin, uint64_t a_range_xend, uint64_t a_range_step, // range + uint8_t* tmp_space // temp space +); + +/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */ +EXPORT uint64_t vec_znx_big_range_normalize_base2k_tmp_bytes( // + const MODULE* module, uint64_t nn // N +); + +/** @brief sets res = a . X^p */ +EXPORT void vec_znx_big_rotate(const MODULE* module, // N + int64_t p, // rotation value + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size // a +); + +/** @brief sets res = a(X^p) */ +EXPORT void vec_znx_big_automorphism(const MODULE* module, // N + int64_t p, // X-X^p + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size // a +); + +/** @brief apply a svp product, result = ppol * a, presented in DFT space */ +EXPORT void svp_apply_dft(const MODULE* module, // N + const VEC_ZNX_DFT* res, uint64_t res_size, // output + const SVP_PPOL* ppol, // prepared pol + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief apply a svp product, result = ppol * a, presented in DFT space */ +EXPORT void svp_apply_dft_to_dft(const MODULE* module, // N + const VEC_ZNX_DFT* res, uint64_t res_size, + uint64_t res_cols, // output + const SVP_PPOL* ppol, // prepared pol + const VEC_ZNX_DFT* a, uint64_t a_size, uint64_t a_cols // a +); + +/** @brief prepares a svp polynomial */ +EXPORT void svp_prepare(const MODULE* module, // N + SVP_PPOL* ppol, // output + const int64_t* pol // a +); + +/** @brief res = a * b : small integer polynomial product */ +EXPORT void znx_small_single_product(const MODULE* module, // N + int64_t* res, // output + const int64_t* a, // a + const int64_t* b, // b + uint8_t* tmp); + +/** @brief tmp bytes required for znx_small_single_product */ +EXPORT uint64_t znx_small_single_product_tmp_bytes(const MODULE* module, uint64_t nn); + +/** @brief minimal scratch space byte-size required for the vmp_prepare function */ +EXPORT uint64_t vmp_prepare_tmp_bytes(const MODULE* module, uint64_t nn, // N + uint64_t nrows, uint64_t ncols); + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void vmp_prepare_contiguous(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); + +/** @brief applies a vmp product (result in DFT space) adds to res inplace */ +EXPORT void 
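Editor's note: usage sketch for the scalar-vector product (SVP) declared above, not part of the patch. The prepared polynomial is built once with svp_prepare and can then be applied to many limb vectors; the result stays in DFT space.

#include <stdint.h>
#include "vec_znx_arithmetic.h"

void svp_demo(const MODULE* m, uint64_t nn,
              const int64_t* pol,                    /* nn coefficients */
              const int64_t* a, uint64_t a_size) {
  SVP_PPOL* ppol = new_svp_ppol(m);
  VEC_ZNX_DFT* prod = new_vec_znx_dft(m, a_size);

  svp_prepare(m, ppol, pol);                            /* one-time preparation */
  svp_apply_dft(m, prod, a_size, ppol, a, a_size, nn);  /* prod = DFT(pol * a)  */

  /* ... continue with vec_znx_idft / vec_znx_big_normalize_base2k as above ... */
  delete_vec_znx_dft(prod);
  delete_svp_ppol(ppol);
}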
vmp_apply_dft_add(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, uint64_t pmat_scale, // prep matrix + uint8_t* tmp_space // scratch space +); + +/** @brief applies a vmp product (result in DFT space) */ +EXPORT void vmp_apply_dft(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +); + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t vmp_apply_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); + +/** @brief applies vmp product */ +EXPORT void vmp_apply_dft_to_dft(const MODULE* module, // N + VEC_ZNX_DFT* res, const uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + const VMP_PMAT* pmat, const uint64_t nrows, + const uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +); + +/** @brief applies vmp product and adds to res inplace */ +EXPORT void vmp_apply_dft_to_dft_add(const MODULE* module, // N + VEC_ZNX_DFT* res, const uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols, + const uint64_t pmat_scale, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +); + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); +#endif // SPQLIOS_VEC_ZNX_ARITHMETIC_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_znx_arithmetic_private.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_znx_arithmetic_private.h new file mode 100644 index 0000000..73c1729 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_znx_arithmetic_private.h @@ -0,0 +1,563 @@ +#ifndef SPQLIOS_VEC_ZNX_ARITHMETIC_PRIVATE_H +#define SPQLIOS_VEC_ZNX_ARITHMETIC_PRIVATE_H + +#include "../commons_private.h" +#include "../q120/q120_ntt.h" +#include "vec_znx_arithmetic.h" + +/** + * Layouts families: + * + * fft64: + * K: <= 20, N: <= 65536, ell: <= 200 + * vec normalized: represented by int64 + * vec large: represented by int64 (expect <=52 bits) + * vec DFT: represented by double (reim_fft space) + * On AVX2 inftastructure, PMAT, LCNV, RCNV use a special reim4_fft space + * + * ntt120: + * K: <= 50, N: <= 65536, ell: <= 80 + * vec normalized: represented by int64 + * vec large: represented by int128 (expect <=120 bits) + * vec DFT: represented by int64x4 (ntt120 space) + * On AVX2 inftastructure, PMAT, LCNV, RCNV use a special ntt120 space + * + * ntt104: + * K: <= 40, N: <= 65536, ell: <= 80 + * vec normalized: represented by int64 + * vec large: represented by int128 (expect <=120 bits) + * vec DFT: represented by int64x4 (ntt120 space) + * On AVX512 inftastructure, PMAT, LCNV, RCNV use a special ntt104 space + */ + +struct fft64_module_info_t { + // pre-computation for reim_fft + REIM_FFT_PRECOMP* p_fft; + // pre-computation for add_fft + REIM_FFTVEC_ADD_PRECOMP* add_fft; + // pre-computation for add_fft + 
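Editor's note: vector-matrix product (VMP) usage sketch for the API above, not part of the patch. The int64 matrix (nrows x ncols polynomials, contiguous row-major) is prepared once into a VMP_PMAT and then applied to a limb vector; both steps take caller-provided scratch sized by the *_tmp_bytes functions. Plain malloc is used for brevity and any alignment requirement is ignored here.

#include <stdint.h>
#include <stdlib.h>
#include "vec_znx_arithmetic.h"

void vmp_demo(const MODULE* m, uint64_t nn,
              const int64_t* mat, uint64_t nrows, uint64_t ncols,  /* nrows*ncols polynomials */
              const int64_t* a, uint64_t a_size) {
  const uint64_t res_size = ncols;
  VMP_PMAT* pmat = new_vmp_pmat(m, nrows, ncols);
  VEC_ZNX_DFT* res = new_vec_znx_dft(m, res_size);

  uint8_t* tmp1 = malloc(vmp_prepare_tmp_bytes(m, nn, nrows, ncols));
  vmp_prepare_contiguous(m, pmat, mat, nrows, ncols, tmp1);   /* one-time preparation */
  free(tmp1);

  uint8_t* tmp2 = malloc(vmp_apply_dft_tmp_bytes(m, nn, res_size, a_size, nrows, ncols));
  vmp_apply_dft(m, res, res_size, a, a_size, nn, pmat, nrows, ncols, tmp2);
  free(tmp2);

  delete_vec_znx_dft(res);
  delete_vmp_pmat(pmat);
}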
REIM_FFTVEC_SUB_PRECOMP* sub_fft; + // pre-computation for mul_fft + REIM_FFTVEC_MUL_PRECOMP* mul_fft; + // pre-computation for reim_from_znx6 + REIM_FROM_ZNX64_PRECOMP* p_conv; + // pre-computation for reim_tp_znx6 + REIM_TO_ZNX64_PRECOMP* p_reim_to_znx; + // pre-computation for reim_fft + REIM_IFFT_PRECOMP* p_ifft; + // pre-computation for reim_fftvec_addmul + REIM_FFTVEC_ADDMUL_PRECOMP* p_addmul; +}; + +struct q120_module_info_t { + // pre-computation for q120b to q120b ntt + q120_ntt_precomp* p_ntt; + // pre-computation for q120b to q120b intt + q120_ntt_precomp* p_intt; +}; + +// TODO add function types here +typedef typeof(vec_znx_zero) VEC_ZNX_ZERO_F; +typedef typeof(vec_znx_copy) VEC_ZNX_COPY_F; +typedef typeof(vec_znx_negate) VEC_ZNX_NEGATE_F; +typedef typeof(vec_znx_add) VEC_ZNX_ADD_F; +typedef typeof(vec_znx_dft) VEC_ZNX_DFT_F; +typedef typeof(vec_dft_add) VEC_DFT_ADD_F; +typedef typeof(vec_dft_sub) VEC_DFT_SUB_F; +typedef typeof(vec_znx_idft) VEC_ZNX_IDFT_F; +typedef typeof(vec_znx_idft_tmp_bytes) VEC_ZNX_IDFT_TMP_BYTES_F; +typedef typeof(vec_znx_idft_tmp_a) VEC_ZNX_IDFT_TMP_A_F; +typedef typeof(vec_znx_sub) VEC_ZNX_SUB_F; +typedef typeof(vec_znx_rotate) VEC_ZNX_ROTATE_F; +typedef typeof(vec_znx_mul_xp_minus_one) VEC_ZNX_MUL_XP_MINUS_ONE_F; +typedef typeof(vec_znx_automorphism) VEC_ZNX_AUTOMORPHISM_F; +typedef typeof(vec_znx_normalize_base2k) VEC_ZNX_NORMALIZE_BASE2K_F; +typedef typeof(vec_znx_normalize_base2k_tmp_bytes) VEC_ZNX_NORMALIZE_BASE2K_TMP_BYTES_F; +typedef typeof(vec_znx_big_normalize_base2k) VEC_ZNX_BIG_NORMALIZE_BASE2K_F; +typedef typeof(vec_znx_big_normalize_base2k_tmp_bytes) VEC_ZNX_BIG_NORMALIZE_BASE2K_TMP_BYTES_F; +typedef typeof(vec_znx_big_range_normalize_base2k) VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_F; +typedef typeof(vec_znx_big_range_normalize_base2k_tmp_bytes) VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_TMP_BYTES_F; +typedef typeof(vec_znx_big_add) VEC_ZNX_BIG_ADD_F; +typedef typeof(vec_znx_big_add_small) VEC_ZNX_BIG_ADD_SMALL_F; +typedef typeof(vec_znx_big_add_small2) VEC_ZNX_BIG_ADD_SMALL2_F; +typedef typeof(vec_znx_big_sub) VEC_ZNX_BIG_SUB_F; +typedef typeof(vec_znx_big_sub_small_a) VEC_ZNX_BIG_SUB_SMALL_A_F; +typedef typeof(vec_znx_big_sub_small_b) VEC_ZNX_BIG_SUB_SMALL_B_F; +typedef typeof(vec_znx_big_sub_small2) VEC_ZNX_BIG_SUB_SMALL2_F; +typedef typeof(vec_znx_big_rotate) VEC_ZNX_BIG_ROTATE_F; +typedef typeof(vec_znx_big_automorphism) VEC_ZNX_BIG_AUTOMORPHISM_F; +typedef typeof(svp_prepare) SVP_PREPARE; +typedef typeof(svp_apply_dft) SVP_APPLY_DFT_F; +typedef typeof(svp_apply_dft_to_dft) SVP_APPLY_DFT_TO_DFT_F; +typedef typeof(znx_small_single_product) ZNX_SMALL_SINGLE_PRODUCT_F; +typedef typeof(znx_small_single_product_tmp_bytes) ZNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F; +typedef typeof(vmp_prepare_contiguous) VMP_PREPARE_CONTIGUOUS_F; +typedef typeof(vmp_prepare_tmp_bytes) VMP_PREPARE_TMP_BYTES_F; +typedef typeof(vmp_apply_dft) VMP_APPLY_DFT_F; +typedef typeof(vmp_apply_dft_add) VMP_APPLY_DFT_ADD_F; +typedef typeof(vmp_apply_dft_tmp_bytes) VMP_APPLY_DFT_TMP_BYTES_F; +typedef typeof(vmp_apply_dft_to_dft) VMP_APPLY_DFT_TO_DFT_F; +typedef typeof(vmp_apply_dft_to_dft_add) VMP_APPLY_DFT_TO_DFT_ADD_F; +typedef typeof(vmp_apply_dft_to_dft_tmp_bytes) VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F; +typedef typeof(bytes_of_vec_znx_dft) BYTES_OF_VEC_ZNX_DFT_F; +typedef typeof(bytes_of_vec_znx_big) BYTES_OF_VEC_ZNX_BIG_F; +typedef typeof(bytes_of_svp_ppol) BYTES_OF_SVP_PPOL_F; +typedef typeof(bytes_of_vmp_pmat) BYTES_OF_VMP_PMAT_F; + +struct module_virtual_functions_t { + // TODO add 
functions here + VEC_ZNX_ZERO_F* vec_znx_zero; + VEC_ZNX_COPY_F* vec_znx_copy; + VEC_ZNX_NEGATE_F* vec_znx_negate; + VEC_ZNX_ADD_F* vec_znx_add; + VEC_ZNX_DFT_F* vec_znx_dft; + VEC_DFT_ADD_F* vec_dft_add; + VEC_DFT_SUB_F* vec_dft_sub; + VEC_ZNX_IDFT_F* vec_znx_idft; + VEC_ZNX_IDFT_TMP_BYTES_F* vec_znx_idft_tmp_bytes; + VEC_ZNX_IDFT_TMP_A_F* vec_znx_idft_tmp_a; + VEC_ZNX_SUB_F* vec_znx_sub; + VEC_ZNX_ROTATE_F* vec_znx_rotate; + VEC_ZNX_MUL_XP_MINUS_ONE_F* vec_znx_mul_xp_minus_one; + VEC_ZNX_AUTOMORPHISM_F* vec_znx_automorphism; + VEC_ZNX_NORMALIZE_BASE2K_F* vec_znx_normalize_base2k; + VEC_ZNX_NORMALIZE_BASE2K_TMP_BYTES_F* vec_znx_normalize_base2k_tmp_bytes; + VEC_ZNX_BIG_NORMALIZE_BASE2K_F* vec_znx_big_normalize_base2k; + VEC_ZNX_BIG_NORMALIZE_BASE2K_TMP_BYTES_F* vec_znx_big_normalize_base2k_tmp_bytes; + VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_F* vec_znx_big_range_normalize_base2k; + VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_TMP_BYTES_F* vec_znx_big_range_normalize_base2k_tmp_bytes; + VEC_ZNX_BIG_ADD_F* vec_znx_big_add; + VEC_ZNX_BIG_ADD_SMALL_F* vec_znx_big_add_small; + VEC_ZNX_BIG_ADD_SMALL2_F* vec_znx_big_add_small2; + VEC_ZNX_BIG_SUB_F* vec_znx_big_sub; + VEC_ZNX_BIG_SUB_SMALL_A_F* vec_znx_big_sub_small_a; + VEC_ZNX_BIG_SUB_SMALL_B_F* vec_znx_big_sub_small_b; + VEC_ZNX_BIG_SUB_SMALL2_F* vec_znx_big_sub_small2; + VEC_ZNX_BIG_ROTATE_F* vec_znx_big_rotate; + VEC_ZNX_BIG_AUTOMORPHISM_F* vec_znx_big_automorphism; + SVP_PREPARE* svp_prepare; + SVP_APPLY_DFT_F* svp_apply_dft; + SVP_APPLY_DFT_TO_DFT_F* svp_apply_dft_to_dft; + ZNX_SMALL_SINGLE_PRODUCT_F* znx_small_single_product; + ZNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* znx_small_single_product_tmp_bytes; + VMP_PREPARE_CONTIGUOUS_F* vmp_prepare_contiguous; + VMP_PREPARE_TMP_BYTES_F* vmp_prepare_tmp_bytes; + VMP_APPLY_DFT_F* vmp_apply_dft; + VMP_APPLY_DFT_ADD_F* vmp_apply_dft_add; + VMP_APPLY_DFT_TMP_BYTES_F* vmp_apply_dft_tmp_bytes; + VMP_APPLY_DFT_TO_DFT_F* vmp_apply_dft_to_dft; + VMP_APPLY_DFT_TO_DFT_ADD_F* vmp_apply_dft_to_dft_add; + VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* vmp_apply_dft_to_dft_tmp_bytes; + BYTES_OF_VEC_ZNX_DFT_F* bytes_of_vec_znx_dft; + BYTES_OF_VEC_ZNX_BIG_F* bytes_of_vec_znx_big; + BYTES_OF_SVP_PPOL_F* bytes_of_svp_ppol; + BYTES_OF_VMP_PMAT_F* bytes_of_vmp_pmat; +}; + +union backend_module_info_t { + struct fft64_module_info_t fft64; + struct q120_module_info_t q120; +}; + +struct module_info_t { + // generic parameters + MODULE_TYPE module_type; + uint64_t nn; + uint64_t m; + // backend_dependent functions + union backend_module_info_t mod; + // virtual functions + struct module_virtual_functions_t func; +}; + +EXPORT uint64_t fft64_bytes_of_vec_znx_dft(const MODULE* module, // N + uint64_t size); + +EXPORT uint64_t fft64_bytes_of_vec_znx_big(const MODULE* module, // N + uint64_t size); + +EXPORT uint64_t fft64_bytes_of_svp_ppol(const MODULE* module); // N + +EXPORT uint64_t fft64_bytes_of_vmp_pmat(const MODULE* module, // N + uint64_t nrows, uint64_t ncols); + +EXPORT void vec_znx_zero_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl // res +); + +EXPORT void vec_znx_copy_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_znx_negate_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_znx_negate_avx(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, 
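Editor's note: the typeof-based typedefs above and this table of function pointers are what make the public wrappers in vec_znx.c virtual. How a backend fills the table is not shown in this file, so the wiring sketch below is hypothetical (example_init_fft64_vtable and the has_avx2 flag are made up); it only illustrates the intended shape of the dispatch.

#include "vec_znx_arithmetic_private.h"

/* hypothetical: the real module construction is not part of this file */
static void example_init_fft64_vtable(struct module_info_t* m, int has_avx2) {
  m->func.vec_znx_zero = vec_znx_zero_ref;
  m->func.vec_znx_copy = vec_znx_copy_ref;
  m->func.vec_znx_add  = has_avx2 ? vec_znx_add_avx : vec_znx_add_ref;
  m->func.vec_znx_sub  = has_avx2 ? vec_znx_sub_avx : vec_znx_sub_ref;
  /* ... the remaining members are filled the same way ... */
}

Every public entry point such as vec_znx_add then reduces to a single indirect call through m->func.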
// res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_znx_add_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); +EXPORT void vec_znx_add_avx(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); + +EXPORT void vec_znx_sub_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); + +EXPORT void vec_znx_sub_avx(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); + +EXPORT void vec_znx_normalize_base2k_ref(const MODULE* module, uint64_t nn, // N + uint64_t log2_base2k, // output base 2^K + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // inp + uint8_t* tmp_space // scratch space +); + +EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes_ref(const MODULE* module, uint64_t nn // N +); + +EXPORT void vec_znx_rotate_ref(const MODULE* module, // N + const int64_t p, // rotation value + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_znx_mul_xp_minus_one_ref(const MODULE* module, // N + const int64_t p, // rotation value + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_znx_automorphism_ref(const MODULE* module, // N + const int64_t p, // X->X^p + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vmp_prepare_ref(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t* mat, uint64_t nrows, uint64_t ncols // a +); + +EXPORT void vmp_apply_dft_ref(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols // prep matrix +); + +EXPORT void vec_dft_zero_ref(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size // res +); + +EXPORT void vec_dft_add_ref(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a, uint64_t a_size, // a + const VEC_ZNX_DFT* b, uint64_t b_size // b +); + +EXPORT void vec_dft_sub_ref(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a, uint64_t a_size, // a + const VEC_ZNX_DFT* b, uint64_t b_size // b +); + +EXPORT void vec_dft_ref(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_idft_ref(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size); + +EXPORT void vec_znx_big_normalize_ref(const MODULE* module, // MODULE + uint64_t nn, // N + uint64_t k, // base-2^k + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const VEC_ZNX_BIG* a, uint64_t a_size // a +); + +/** @brief apply a svp product, result = ppol * a, presented in DFT space */ 
+EXPORT void fft64_svp_apply_dft_ref(const MODULE* module, // N + const VEC_ZNX_DFT* res, uint64_t res_size, // output + const SVP_PPOL* ppol, // prepared pol + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief apply a svp product, result = ppol * a, presented in DFT space */ +EXPORT void fft64_svp_apply_dft_to_dft_ref(const MODULE* module, // N + const VEC_ZNX_DFT* res, uint64_t res_size, + uint64_t res_cols, // output + const SVP_PPOL* ppol, // prepared pol + const VEC_ZNX_DFT* a, uint64_t a_size, + uint64_t a_cols // a +); + +/** @brief sets res = k-normalize(a) -- output in int64 coeffs space */ +EXPORT void fft64_vec_znx_big_normalize_base2k(const MODULE* module, // MODULE + uint64_t nn, // N + uint64_t k, // base-2^k + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + uint8_t* tmp_space // temp space +); + +/** @brief returns the minimal byte length of scratch space for vec_znx_big_normalize_base2k */ +EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N + +); + +/** @brief sets res = k-normalize(a.subrange) -- output in int64 coeffs space */ +EXPORT void fft64_vec_znx_big_range_normalize_base2k(const MODULE* module, // MODULE + uint64_t nn, + uint64_t log2_base2k, // base-2^k + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const VEC_ZNX_BIG* a, uint64_t a_range_begin, // a + uint64_t a_range_xend, uint64_t a_range_step, // range + uint8_t* tmp_space // temp space +); + +/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */ +EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N +); + +EXPORT void fft64_vec_znx_dft(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void fft64_vec_dft_add(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a, uint64_t a_size, // a + const VEC_ZNX_DFT* b, uint64_t b_size // b +); + +EXPORT void fft64_vec_dft_sub(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a, uint64_t a_size, // a + const VEC_ZNX_DFT* b, uint64_t b_size // b +); + +EXPORT void fft64_vec_znx_idft(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + uint8_t* tmp // scratch space +); + +EXPORT uint64_t fft64_vec_znx_idft_tmp_bytes(const MODULE* module, uint64_t nn); + +EXPORT void fft64_vec_znx_idft_tmp_a(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten +); + +EXPORT void ntt120_vec_znx_dft_avx(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +); + +/** */ +EXPORT void ntt120_vec_znx_idft_avx(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + uint8_t* tmp // scratch space +); + +EXPORT uint64_t ntt120_vec_znx_idft_tmp_bytes_avx(const MODULE* module, uint64_t nn); + +EXPORT void ntt120_vec_znx_idft_tmp_a_avx(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten +); + +// big additions/subtractions + +/** @brief sets res = a+b */ +EXPORT void fft64_vec_znx_big_add(const MODULE* module, // N + 
VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +); +/** @brief sets res = a+b */ +EXPORT void fft64_vec_znx_big_add_small(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); +EXPORT void fft64_vec_znx_big_add_small2(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = a-b */ +EXPORT void fft64_vec_znx_big_sub(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +); +EXPORT void fft64_vec_znx_big_sub_small_b(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); +EXPORT void fft64_vec_znx_big_sub_small_a(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +); +EXPORT void fft64_vec_znx_big_sub_small2(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = a . X^p */ +EXPORT void fft64_vec_znx_big_rotate(const MODULE* module, // N + int64_t p, // rotation value + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size // a +); + +/** @brief sets res = a(X^p) */ +EXPORT void fft64_vec_znx_big_automorphism(const MODULE* module, // N + int64_t p, // X-X^p + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size // a +); + +/** @brief prepares a svp polynomial */ +EXPORT void fft64_svp_prepare_ref(const MODULE* module, // N + SVP_PPOL* ppol, // output + const int64_t* pol // a +); + +/** @brief res = a * b : small integer polynomial product */ +EXPORT void fft64_znx_small_single_product(const MODULE* module, // N + int64_t* res, // output + const int64_t* a, // a + const int64_t* b, // b + uint8_t* tmp); + +/** @brief tmp bytes required for znx_small_single_product */ +EXPORT uint64_t fft64_znx_small_single_product_tmp_bytes(const MODULE* module, uint64_t nn); + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void fft64_vmp_prepare_contiguous_ref(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void fft64_vmp_prepare_contiguous_avx(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); + +/** @brief minimal scratch space byte-size required for the vmp_prepare function */ +EXPORT uint64_t fft64_vmp_prepare_tmp_bytes(const MODULE* module, uint64_t nn, // N + uint64_t nrows, uint64_t ncols); + +/** @brief applies a vmp product (result in DFT space) and adds to res inplace */ +EXPORT void fft64_vmp_apply_dft_add_ref(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const 
VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, + uint64_t pmat_scale, // prep matrix + uint8_t* tmp_space // scratch space +); + +/** @brief applies a vmp product (result in DFT space) */ +EXPORT void fft64_vmp_apply_dft_ref(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +); + +/** @brief applies a vmp product (result in DFT space) */ +EXPORT void fft64_vmp_apply_dft_avx(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +); + +/** @brief applies a vmp product (result in DFT space) and adds to res inplace*/ +EXPORT void fft64_vmp_apply_dft_add_avx(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, + uint64_t pmat_scale, // prep matrix + uint8_t* tmp_space // scratch space +); + +/** @brief this inner function could be very handy */ +EXPORT void fft64_vmp_apply_dft_to_dft_ref(const MODULE* module, // N + VEC_ZNX_DFT* res, const uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + const VMP_PMAT* pmat, const uint64_t nrows, + const uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +); + +/** @brief applies rmp product and adds to res inplace */ +EXPORT void fft64_vmp_apply_dft_to_dft_add_ref(const MODULE* module, // N + VEC_ZNX_DFT* res, const uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols, + uint64_t pmat_scale, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +); + +/** @brief this inner function could be very handy */ +EXPORT void fft64_vmp_apply_dft_to_dft_avx(const MODULE* module, // N + VEC_ZNX_DFT* res, const uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + const VMP_PMAT* pmat, const uint64_t nrows, + const uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +); + +/** @brief applies rmp product and adds to res inplace */ +EXPORT void fft64_vmp_apply_dft_to_dft_add_avx(const MODULE* module, // N + VEC_ZNX_DFT* res, const uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols, + uint64_t pmat_scale, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +); + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t fft64_vmp_apply_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t fft64_vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); +#endif // SPQLIOS_VEC_ZNX_ARITHMETIC_PRIVATE_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_znx_avx.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_znx_avx.c new file mode 100644 index 
0000000..100902d --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_znx_avx.c @@ -0,0 +1,103 @@ +#include + +#include "../coeffs/coeffs_arithmetic.h" +#include "../reim4/reim4_arithmetic.h" +#include "vec_znx_arithmetic_private.h" + +// specialized function (ref) + +// Note: these functions do not have an avx variant. +#define znx_copy_i64_avx znx_copy_i64_ref +#define znx_zero_i64_avx znx_zero_i64_ref + +EXPORT void vec_znx_add_avx(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->nn; + if (a_size <= b_size) { + const uint64_t sum_idx = res_size < a_size ? res_size : a_size; + const uint64_t copy_idx = res_size < b_size ? res_size : b_size; + // add up to the smallest dimension + for (uint64_t i = 0; i < sum_idx; ++i) { + znx_add_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then copy to the largest dimension + for (uint64_t i = sum_idx; i < copy_idx; ++i) { + znx_copy_i64_avx(nn, res + i * res_sl, b + i * b_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_avx(nn, res + i * res_sl); + } + } else { + const uint64_t sum_idx = res_size < b_size ? res_size : b_size; + const uint64_t copy_idx = res_size < a_size ? res_size : a_size; + // add up to the smallest dimension + for (uint64_t i = 0; i < sum_idx; ++i) { + znx_add_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then copy to the largest dimension + for (uint64_t i = sum_idx; i < copy_idx; ++i) { + znx_copy_i64_avx(nn, res + i * res_sl, a + i * a_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_avx(nn, res + i * res_sl); + } + } +} + +EXPORT void vec_znx_sub_avx(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->nn; + if (a_size <= b_size) { + const uint64_t sub_idx = res_size < a_size ? res_size : a_size; + const uint64_t copy_idx = res_size < b_size ? res_size : b_size; + // subtract up to the smallest dimension + for (uint64_t i = 0; i < sub_idx; ++i) { + znx_sub_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then negate to the largest dimension + for (uint64_t i = sub_idx; i < copy_idx; ++i) { + znx_negate_i64_avx(nn, res + i * res_sl, b + i * b_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_avx(nn, res + i * res_sl); + } + } else { + const uint64_t sub_idx = res_size < b_size ? res_size : b_size; + const uint64_t copy_idx = res_size < a_size ? 
res_size : a_size; + // subtract up to the smallest dimension + for (uint64_t i = 0; i < sub_idx; ++i) { + znx_sub_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then copy to the largest dimension + for (uint64_t i = sub_idx; i < copy_idx; ++i) { + znx_copy_i64_avx(nn, res + i * res_sl, a + i * a_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_avx(nn, res + i * res_sl); + } + } +} + +EXPORT void vec_znx_negate_avx(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + uint64_t nn = module->nn; + uint64_t smin = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < smin; ++i) { + znx_negate_i64_avx(nn, res + i * res_sl, a + i * a_sl); + } + for (uint64_t i = smin; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_znx_big.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_znx_big.c new file mode 100644 index 0000000..9e29560 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_znx_big.c @@ -0,0 +1,278 @@ +#include "vec_znx_arithmetic_private.h" + +EXPORT uint64_t bytes_of_vec_znx_big(const MODULE* module, // N + uint64_t size) { + return module->func.bytes_of_vec_znx_big(module, size); +} + +// public wrappers + +/** @brief sets res = a+b */ +EXPORT void vec_znx_big_add(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +) { + module->func.vec_znx_big_add(module, res, res_size, a, a_size, b, b_size); +} + +/** @brief sets res = a+b */ +EXPORT void vec_znx_big_add_small(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + module->func.vec_znx_big_add_small(module, res, res_size, a, a_size, b, b_size, b_sl); +} + +EXPORT void vec_znx_big_add_small2(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + module->func.vec_znx_big_add_small2(module, res, res_size, a, a_size, a_sl, b, b_size, b_sl); +} + +/** @brief sets res = a-b */ +EXPORT void vec_znx_big_sub(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +) { + module->func.vec_znx_big_sub(module, res, res_size, a, a_size, b, b_size); +} + +EXPORT void vec_znx_big_sub_small_b(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + module->func.vec_znx_big_sub_small_b(module, res, res_size, a, a_size, b, b_size, b_sl); +} + +EXPORT void vec_znx_big_sub_small_a(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +) { + module->func.vec_znx_big_sub_small_a(module, res, res_size, a, a_size, a_sl, b, b_size); +} +EXPORT void vec_znx_big_sub_small2(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, 
// a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + module->func.vec_znx_big_sub_small2(module, res, res_size, a, a_size, a_sl, b, b_size, b_sl); +} + +/** @brief sets res = a . X^p */ +EXPORT void vec_znx_big_rotate(const MODULE* module, // N + int64_t p, // rotation value + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size // a +) { + module->func.vec_znx_big_rotate(module, p, res, res_size, a, a_size); +} + +/** @brief sets res = a(X^p) */ +EXPORT void vec_znx_big_automorphism(const MODULE* module, // N + int64_t p, // X-X^p + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size // a +) { + module->func.vec_znx_big_automorphism(module, p, res, res_size, a, a_size); +} + +// private wrappers + +EXPORT uint64_t fft64_bytes_of_vec_znx_big(const MODULE* module, // N + uint64_t size) { + return module->nn * size * sizeof(double); +} + +EXPORT VEC_ZNX_BIG* new_vec_znx_big(const MODULE* module, // N + uint64_t size) { + return spqlios_alloc(bytes_of_vec_znx_big(module, size)); +} + +EXPORT void delete_vec_znx_big(VEC_ZNX_BIG* res) { spqlios_free(res); } + +/** @brief sets res = a+b */ +EXPORT void fft64_vec_znx_big_add(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +) { + const uint64_t n = module->nn; + vec_znx_add(module, // + (int64_t*)res, res_size, n, // + (int64_t*)a, a_size, n, // + (int64_t*)b, b_size, n); +} +/** @brief sets res = a+b */ +EXPORT void fft64_vec_znx_big_add_small(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t n = module->nn; + vec_znx_add(module, // + (int64_t*)res, res_size, n, // + (int64_t*)a, a_size, n, // + b, b_size, b_sl); +} +EXPORT void fft64_vec_znx_big_add_small2(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t n = module->nn; + vec_znx_add(module, // + (int64_t*)res, res_size, n, // + a, a_size, a_sl, // + b, b_size, b_sl); +} + +/** @brief sets res = a-b */ +EXPORT void fft64_vec_znx_big_sub(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +) { + const uint64_t n = module->nn; + vec_znx_sub(module, // + (int64_t*)res, res_size, n, // + (int64_t*)a, a_size, n, // + (int64_t*)b, b_size, n); +} + +EXPORT void fft64_vec_znx_big_sub_small_b(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t n = module->nn; + vec_znx_sub(module, // + (int64_t*)res, res_size, n, // + (int64_t*)a, a_size, // + n, b, b_size, b_sl); +} +EXPORT void fft64_vec_znx_big_sub_small_a(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VEC_ZNX_BIG* b, uint64_t b_size // b +) { + const uint64_t n = module->nn; + vec_znx_sub(module, // + (int64_t*)res, res_size, n, // + a, a_size, a_sl, // + (int64_t*)b, b_size, n); +} +EXPORT void fft64_vec_znx_big_sub_small2(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const int64_t* a, 
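Editor's note: in the FFT64 backend a VEC_ZNX_BIG is a contiguous nn x size block of 64-bit words (fft64_bytes_of_vec_znx_big above sizes it as nn * size * sizeof(double), and sizeof(double) == sizeof(int64_t)), so these wrappers only cast and forward to the plain int64 vec_znx_* routines with slice length nn. The same call written out explicitly, as an illustration and not part of the patch:

#include <stdint.h>
#include "vec_znx_arithmetic.h"

/* same effect as fft64_vec_znx_big_add(m, r, size, a, size, b, size) */
void big_add_spelled_out(const MODULE* m, uint64_t nn, uint64_t size,
                         VEC_ZNX_BIG* r, const VEC_ZNX_BIG* a, const VEC_ZNX_BIG* b) {
  vec_znx_add(m, (int64_t*)r, size, nn,
              (const int64_t*)a, size, nn,
              (const int64_t*)b, size, nn);
}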
uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t n = module->nn; + vec_znx_sub(module, // + (int64_t*)res, res_size, // + n, a, a_size, // + a_sl, b, b_size, b_sl); +} + +/** @brief sets res = a . X^p */ +EXPORT void fft64_vec_znx_big_rotate(const MODULE* module, // N + int64_t p, // rotation value + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size // a +) { + uint64_t nn = module->nn; + vec_znx_rotate(module, p, (int64_t*)res, res_size, nn, (int64_t*)a, a_size, nn); +} + +/** @brief sets res = a(X^p) */ +EXPORT void fft64_vec_znx_big_automorphism(const MODULE* module, // N + int64_t p, // X-X^p + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_BIG* a, uint64_t a_size // a +) { + uint64_t nn = module->nn; + vec_znx_automorphism(module, p, (int64_t*)res, res_size, nn, (int64_t*)a, a_size, nn); +} + +EXPORT void vec_znx_big_normalize_base2k(const MODULE* module, // MODULE + uint64_t nn, // N + uint64_t k, // base-2^k + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + uint8_t* tmp_space // temp space +) { + module->func.vec_znx_big_normalize_base2k(module, // MODULE + nn, // N + k, // base-2^k + res, res_size, res_sl, // res + a, a_size, // a + tmp_space); +} + +EXPORT uint64_t vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N +) { + return module->func.vec_znx_big_normalize_base2k_tmp_bytes(module, nn // N + ); +} + +/** @brief sets res = k-normalize(a.subrange) -- output in int64 coeffs space */ +EXPORT void vec_znx_big_range_normalize_base2k( // + const MODULE* module, // MODULE + uint64_t nn, // N + uint64_t log2_base2k, // base-2^k + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const VEC_ZNX_BIG* a, uint64_t a_range_begin, uint64_t a_range_xend, uint64_t a_range_step, // range + uint8_t* tmp_space // temp space +) { + module->func.vec_znx_big_range_normalize_base2k(module, nn, log2_base2k, res, res_size, res_sl, a, a_range_begin, + a_range_xend, a_range_step, tmp_space); +} + +/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */ +EXPORT uint64_t vec_znx_big_range_normalize_base2k_tmp_bytes( // + const MODULE* module, // MODULE + uint64_t nn // N +) { + return module->func.vec_znx_big_range_normalize_base2k_tmp_bytes(module, nn); +} + +EXPORT void fft64_vec_znx_big_normalize_base2k(const MODULE* module, // MODULE + uint64_t nn, // N + uint64_t k, // base-2^k + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const VEC_ZNX_BIG* a, uint64_t a_size, // a + uint8_t* tmp_space) { + uint64_t a_sl = nn; + module->func.vec_znx_normalize_base2k(module, // N + nn, + k, // log2_base2k + res, res_size, res_sl, // res + (int64_t*)a, a_size, a_sl, // a + tmp_space); +} + +EXPORT void fft64_vec_znx_big_range_normalize_base2k( // + const MODULE* module, // MODULE + uint64_t nn, // N + uint64_t k, // base-2^k + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const VEC_ZNX_BIG* a, uint64_t a_begin, uint64_t a_end, uint64_t a_step, // a + uint8_t* tmp_space) { + // convert the range indexes to int64[] slices + const int64_t* a_st = ((int64_t*)a) + nn * a_begin; + const uint64_t a_size = (a_end + a_step - 1 - a_begin) / a_step; + const uint64_t a_sl = nn * a_step; + // forward the call + module->func.vec_znx_normalize_base2k(module, // MODULE + nn, // N + k, // log2_base2k + res, res_size, res_sl, // res + a_st, a_size, a_sl, // a + 
tmp_space); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_znx_dft.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_znx_dft.c new file mode 100644 index 0000000..bb6271f --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_znx_dft.c @@ -0,0 +1,214 @@ +#include + +#include "../q120/q120_arithmetic.h" +#include "vec_znx_arithmetic_private.h" + +EXPORT void vec_znx_dft(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + return module->func.vec_znx_dft(module, res, res_size, a, a_size, a_sl); +} + +EXPORT void vec_dft_add(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a, uint64_t a_size, // a + const VEC_ZNX_DFT* b, uint64_t b_size // b +) { + return module->func.vec_dft_add(module, res, res_size, a, a_size, b, b_size); +} + +EXPORT void vec_dft_sub(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a, uint64_t a_size, // a + const VEC_ZNX_DFT* b, uint64_t b_size // b +) { + return module->func.vec_dft_sub(module, res, res_size, a, a_size, b, b_size); +} + +EXPORT void vec_znx_idft(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + uint8_t* tmp // scratch space +) { + return module->func.vec_znx_idft(module, res, res_size, a_dft, a_size, tmp); +} + +EXPORT uint64_t vec_znx_idft_tmp_bytes(const MODULE* module, uint64_t nn) { return module->func.vec_znx_idft_tmp_bytes(module, nn); } + +EXPORT void vec_znx_idft_tmp_a(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten +) { + return module->func.vec_znx_idft_tmp_a(module, res, res_size, a_dft, a_size); +} + +EXPORT uint64_t bytes_of_vec_znx_dft(const MODULE* module, // N + uint64_t size) { + return module->func.bytes_of_vec_znx_dft(module, size); +} + +// fft64 backend +EXPORT uint64_t fft64_bytes_of_vec_znx_dft(const MODULE* module, // N + uint64_t size) { + return module->nn * size * sizeof(double); +} + +EXPORT VEC_ZNX_DFT* new_vec_znx_dft(const MODULE* module, // N + uint64_t size) { + return spqlios_alloc(bytes_of_vec_znx_dft(module, size)); +} + +EXPORT void delete_vec_znx_dft(VEC_ZNX_DFT* res) { spqlios_free(res); } + +EXPORT void fft64_vec_znx_dft(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t smin = res_size < a_size ? res_size : a_size; + const uint64_t nn = module->nn; + + for (uint64_t i = 0; i < smin; i++) { + reim_from_znx64(module->mod.fft64.p_conv, ((double*)res) + i * nn, a + i * a_sl); + reim_fft(module->mod.fft64.p_fft, ((double*)res) + i * nn); + } + + // fill up remaining part with 0's + double* const dres = (double*)res; + memset(dres + smin * nn, 0, (res_size - smin) * nn * sizeof(double)); +} + +EXPORT void fft64_vec_dft_add(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a, uint64_t a_size, // a + const VEC_ZNX_DFT* b, uint64_t b_size // b +) { + const uint64_t smin0 = a_size < b_size ? a_size : b_size; + const uint64_t smin = res_size < smin0 ? 
res_size : smin0; + const uint64_t nn = module->nn; + + for (uint64_t i = 0; i < smin; i++) { + reim_fftvec_add(module->mod.fft64.add_fft, ((double*)res) + i * nn, ((double*)a) + i * nn, ((double*)b) + i * nn); + } + + // fill remain `res` part with 0's + double* const dres = (double*)res; + memset(dres + smin * nn, 0, (res_size - smin) * nn * sizeof(double)); +} + +EXPORT void fft64_vec_dft_sub(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a, uint64_t a_size, // a + const VEC_ZNX_DFT* b, uint64_t b_size // b +) { + const uint64_t smin0 = a_size < b_size ? a_size : b_size; + const uint64_t smin = res_size < smin0 ? res_size : smin0; + const uint64_t nn = module->nn; + + for (uint64_t i = 0; i < smin; i++) { + reim_fftvec_sub(module->mod.fft64.sub_fft, ((double*)res) + i * nn, ((double*)a) + i * nn, ((double*)b) + i * nn); + } + + // fill remain `res` part with 0's + double* const dres = (double*)res; + memset(dres + smin * nn, 0, (res_size - smin) * nn * sizeof(double)); +} + +EXPORT void fft64_vec_znx_idft(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + uint8_t* tmp // unused +) { + const uint64_t nn = module->nn; + const uint64_t smin = res_size < a_size ? res_size : a_size; + if ((double*)res != (double*)a_dft) { + memcpy(res, a_dft, smin * nn * sizeof(double)); + } + + for (uint64_t i = 0; i < smin; i++) { + reim_ifft(module->mod.fft64.p_ifft, ((double*)res) + i * nn); + reim_to_znx64(module->mod.fft64.p_reim_to_znx, ((int64_t*)res) + i * nn, ((int64_t*)res) + i * nn); + } + + // fill up remaining part with 0's + int64_t* const dres = (int64_t*)res; + memset(dres + smin * nn, 0, (res_size - smin) * nn * sizeof(double)); +} + +EXPORT uint64_t fft64_vec_znx_idft_tmp_bytes(const MODULE* module, uint64_t nn) { return 0; } + +EXPORT void fft64_vec_znx_idft_tmp_a(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten +) { + const uint64_t nn = module->nn; + const uint64_t smin = res_size < a_size ? res_size : a_size; + + int64_t* const tres = (int64_t*)res; + double* const ta = (double*)a_dft; + for (uint64_t i = 0; i < smin; i++) { + reim_ifft(module->mod.fft64.p_ifft, ta + i * nn); + reim_to_znx64(module->mod.fft64.p_reim_to_znx, tres + i * nn, ta + i * nn); + } + + // fill up remaining part with 0's + memset(tres + smin * nn, 0, (res_size - smin) * nn * sizeof(double)); +} + +// ntt120 backend + +EXPORT void ntt120_vec_znx_dft_avx(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->nn; + const uint64_t smin = res_size < a_size ? res_size : a_size; + + int64_t* tres = (int64_t*)res; + for (uint64_t i = 0; i < smin; i++) { + q120_b_from_znx64_simple(nn, (q120b*)(tres + i * nn * 4), a + i * a_sl); + q120_ntt_bb_avx2(module->mod.q120.p_ntt, (q120b*)(tres + i * nn * 4)); + } + + // fill up remaining part with 0's + memset(tres + smin * nn * 4, 0, (res_size - smin) * nn * 4 * sizeof(int64_t)); +} + +EXPORT void ntt120_vec_znx_idft_avx(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + uint8_t* tmp) { + const uint64_t nn = module->nn; + const uint64_t smin = res_size < a_size ? 
res_size : a_size; + + __int128_t* const tres = (__int128_t*)res; + const int64_t* const ta = (int64_t*)a_dft; + for (uint64_t i = 0; i < smin; i++) { + memcpy(tmp, ta + i * nn * 4, nn * 4 * sizeof(uint64_t)); + q120_intt_bb_avx2(module->mod.q120.p_intt, (q120b*)tmp); + q120_b_to_znx128_simple(nn, tres + i * nn, (q120b*)tmp); + } + + // fill up remaining part with 0's + memset(tres + smin * nn, 0, (res_size - smin) * nn * sizeof(*tres)); +} + +EXPORT uint64_t ntt120_vec_znx_idft_tmp_bytes_avx(const MODULE* module, uint64_t nn) { return nn * 4 * sizeof(uint64_t); } + +EXPORT void ntt120_vec_znx_idft_tmp_a_avx(const MODULE* module, // N + VEC_ZNX_BIG* res, uint64_t res_size, // res + VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten +) { + const uint64_t nn = module->nn; + const uint64_t smin = res_size < a_size ? res_size : a_size; + + __int128_t* const tres = (__int128_t*)res; + int64_t* const ta = (int64_t*)a_dft; + for (uint64_t i = 0; i < smin; i++) { + q120_intt_bb_avx2(module->mod.q120.p_intt, (q120b*)(ta + i * nn * 4)); + q120_b_to_znx128_simple(nn, tres + i * nn, (q120b*)(ta + i * nn * 4)); + } + + // fill up remaining part with 0's + memset(tres + smin * nn, 0, (res_size - smin) * nn * sizeof(*tres)); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_znx_dft_avx2.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_znx_dft_avx2.c new file mode 100644 index 0000000..dbca7cc --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vec_znx_dft_avx2.c @@ -0,0 +1 @@ +#include "vec_znx_arithmetic_private.h" diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vector_matrix_product.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vector_matrix_product.c new file mode 100644 index 0000000..fd59996 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vector_matrix_product.c @@ -0,0 +1,369 @@ +#include +#include + +#include "../reim4/reim4_arithmetic.h" +#include "vec_znx_arithmetic_private.h" + +EXPORT uint64_t bytes_of_vmp_pmat(const MODULE* module, // N + uint64_t nrows, uint64_t ncols // dimensions +) { + return module->func.bytes_of_vmp_pmat(module, nrows, ncols); +} + +// fft64 +EXPORT uint64_t fft64_bytes_of_vmp_pmat(const MODULE* module, // N + uint64_t nrows, uint64_t ncols // dimensions +) { + return module->nn * nrows * ncols * sizeof(double); +} + +EXPORT VMP_PMAT* new_vmp_pmat(const MODULE* module, // N + uint64_t nrows, uint64_t ncols // dimensions +) { + return spqlios_alloc(bytes_of_vmp_pmat(module, nrows, ncols)); +} + +EXPORT void delete_vmp_pmat(VMP_PMAT* res) { spqlios_free(res); } + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void vmp_prepare_contiguous(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + module->func.vmp_prepare_contiguous(module, pmat, mat, nrows, ncols, tmp_space); +} + +/** @brief minimal scratch space byte-size required for the vmp_prepare function */ +EXPORT uint64_t vmp_prepare_tmp_bytes(const MODULE* module, uint64_t nn, // N + uint64_t nrows, uint64_t ncols) { + return module->func.vmp_prepare_tmp_bytes(module, nn, nrows, ncols); +} + +EXPORT double* get_blk_addr(uint64_t row_i, uint64_t col_i, uint64_t nrows, uint64_t ncols, const VMP_PMAT* pmat) { + double* output_mat = (double*)pmat; + + if (col_i == (ncols - 1) && 
(ncols % 2 == 1)) { + // special case: last column out of an odd column number + return output_mat + col_i * nrows * 8 // col == ncols-1 + + row_i * 8; + } else { + // general case: columns go by pair + return output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index + + row_i * 2 * 8 // third: row index + + (col_i % 2) * 8; + } +} + +void fft64_store_svp_ppol_into_vmp_pmat_row_blk_ref(uint64_t nn, uint64_t m, const SVP_PPOL* svp_ppol, uint64_t row_i, + uint64_t col_i, uint64_t nrows, uint64_t ncols, VMP_PMAT* pmat) { + double* start_addr = get_blk_addr(row_i, col_i, nrows, ncols, pmat); + uint64_t offset = nrows * ncols * 8; + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + reim4_extract_1blk_from_reim_ref(m, blk_i, start_addr + blk_i * offset, (double*)svp_ppol); + } +} + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void fft64_vmp_prepare_contiguous_ref(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + // there is an edge case if nn < 8 + const uint64_t nn = module->nn; + const uint64_t m = module->m; + + if (nn >= 8) { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)tmp_space, mat + (row_i * ncols + col_i) * nn); + reim_fft(module->mod.fft64.p_fft, (double*)tmp_space); + fft64_store_svp_ppol_into_vmp_pmat_row_blk_ref(nn, m, (SVP_PPOL*)tmp_space, row_i, col_i, nrows, ncols, pmat); + } + } + } else { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + double* res = (double*)pmat + (col_i * nrows + row_i) * nn; + reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)res, mat + (row_i * ncols + col_i) * nn); + reim_fft(module->mod.fft64.p_fft, res); + } + } + } +} + +/** @brief minimal scratch space byte-size required for the vmp_prepare function */ +EXPORT uint64_t fft64_vmp_prepare_tmp_bytes(const MODULE* module, uint64_t nn, // N + uint64_t nrows, uint64_t ncols) { + return nn * sizeof(int64_t); +} + +/** @brief applies a vmp product (result in DFT space) and adds to res inplace */ +EXPORT void fft64_vmp_apply_dft_add_ref(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, + uint64_t pmat_scale, // prep matrix + uint8_t* tmp_space // scratch space +) { + const uint64_t nn = module->nn; + const uint64_t rows = nrows < a_size ? nrows : a_size; + + VEC_ZNX_DFT* a_dft = (VEC_ZNX_DFT*)tmp_space; + uint8_t* new_tmp_space = (uint8_t*)tmp_space + rows * nn * sizeof(double); + + fft64_vec_znx_dft(module, a_dft, rows, a, a_size, a_sl); + fft64_vmp_apply_dft_to_dft_add_ref(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, pmat_scale, + new_tmp_space); +} + +/** @brief applies a vmp product (result in DFT space) */ +EXPORT void fft64_vmp_apply_dft_ref(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +) { + const uint64_t nn = module->nn; + const uint64_t rows = nrows < a_size ? 
nrows : a_size; + + VEC_ZNX_DFT* a_dft = (VEC_ZNX_DFT*)tmp_space; + uint8_t* new_tmp_space = (uint8_t*)tmp_space + rows * nn * sizeof(double); + + fft64_vec_znx_dft(module, a_dft, rows, a, a_size, a_sl); + fft64_vmp_apply_dft_to_dft_ref(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, new_tmp_space); +} + +/** @brief like fft64_vmp_apply_dft_to_dft_ref but adds in place */ +EXPORT void fft64_vmp_apply_dft_to_dft_add_ref(const MODULE* module, // N + VEC_ZNX_DFT* res, const uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols, + uint64_t pmat_scale, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +) { + const uint64_t m = module->m; + const uint64_t nn = module->nn; + assert(nn >= 8); + + double* mat2cols_output = (double*)tmp_space; // 128 bytes + double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes + + double* mat_input = (double*)pmat; + double* vec_input = (double*)a_dft; + double* vec_output = (double*)res; + + // const uint64_t row_max0 = res_size < a_size ? res_size: a_size; + const uint64_t row_max = nrows < a_size ? nrows : a_size; + const uint64_t col_max = ncols < res_size ? ncols : res_size; + + if (nn >= 8) { + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols); + + reim4_extract_1blk_from_contiguous_reim_ref(m, row_max, blk_i, (double*)extracted_blk, (double*)a_dft); + + if (pmat_scale % 2 == 0) { + // apply mat2cols + for (uint64_t col_res = 0, col_pmat = pmat_scale; col_pmat < col_max - 1; col_res += 2, col_pmat += 2) { + uint64_t col_offset = col_pmat * (8 * nrows); + reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + reim4_add_1blk_to_reim_ref(m, blk_i, vec_output + col_res * nn, mat2cols_output); + reim4_add_1blk_to_reim_ref(m, blk_i, vec_output + (col_res + 1) * nn, mat2cols_output + 8); + } + } else { + uint64_t col_offset = (pmat_scale - 1) * (8 * nrows); + reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + reim4_add_1blk_to_reim_ref(m, blk_i, vec_output, mat2cols_output + 8); + + // apply mat2cols + for (uint64_t col_res = 1, col_pmat = pmat_scale + 1; col_pmat < col_max - 1; col_res += 2, col_pmat += 2) { + uint64_t col_offset = col_pmat * (8 * nrows); + reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + + reim4_add_1blk_to_reim_ref(m, blk_i, vec_output + col_res * nn, mat2cols_output); + reim4_add_1blk_to_reim_ref(m, blk_i, vec_output + (col_res + 1) * nn, mat2cols_output + 8); + } + } + + // check if col_max is odd, then special case + if (col_max % 2 == 1) { + uint64_t last_col = col_max - 1; + uint64_t col_offset = last_col * (8 * nrows); + + if (last_col >= pmat_scale) { + // the last column is alone in the pmat: vec_mat1col + if (ncols == col_max) { + reim4_vec_mat1col_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + } else { + // the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position + reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + } + + reim4_add_1blk_to_reim_ref(m, blk_i, vec_output + (last_col - pmat_scale) * nn, mat2cols_output); + } + } + } + } else { + for (uint64_t col_res = 0, col_pmat = pmat_scale; col_pmat < col_max; col_res += 1, col_pmat += 1) { + 
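+      // small-ring fallback (nn < 8): here the prepared matrix is stored column-major as
+      // plain reim vectors (see the nn < 8 branch of fft64_vmp_prepare_contiguous_ref),
+      // so each selected column is accumulated directly into res with reim_fftvec_addmul;
+      // res already holds the values being added to, hence no initial mul pass as in the
+      // non-add variant fft64_vmp_apply_dft_to_dft_ref below.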
double* pmat_col = mat_input + col_pmat * nrows * nn; + for (uint64_t row_i = 0; row_i < row_max; row_i++) { + reim_fftvec_addmul(module->mod.fft64.p_addmul, vec_output + col_res * nn, vec_input + row_i * nn, + pmat_col + row_i * nn); + } + } + } + + // zero out remaining bytes + memset(vec_output + col_max * nn, 0, (res_size - col_max) * nn * sizeof(double)); +} + +/** @brief this inner function could be very handy */ +EXPORT void fft64_vmp_apply_dft_to_dft_ref(const MODULE* module, // N + VEC_ZNX_DFT* res, const uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + const VMP_PMAT* pmat, const uint64_t nrows, + const uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +) { + const uint64_t m = module->m; + const uint64_t nn = module->nn; + assert(nn >= 8); + + double* mat2cols_output = (double*)tmp_space; // 128 bytes + double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes + + double* mat_input = (double*)pmat; + double* vec_input = (double*)a_dft; + double* vec_output = (double*)res; + + // const uint64_t row_max0 = res_size < a_size ? res_size: a_size; + const uint64_t row_max = nrows < a_size ? nrows : a_size; + const uint64_t col_max = ncols < res_size ? ncols : res_size; + + if (nn >= 8) { + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols); + + reim4_extract_1blk_from_contiguous_reim_ref(m, row_max, blk_i, (double*)extracted_blk, (double*)a_dft); + // apply mat2cols + for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) { + uint64_t col_offset = col_i * (8 * nrows); + reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + + reim4_save_1blk_to_reim_ref(m, blk_i, vec_output + col_i * nn, mat2cols_output); + reim4_save_1blk_to_reim_ref(m, blk_i, vec_output + (col_i + 1) * nn, mat2cols_output + 8); + } + + // check if col_max is odd, then special case + if (col_max % 2 == 1) { + uint64_t last_col = col_max - 1; + uint64_t col_offset = last_col * (8 * nrows); + + // the last column is alone in the pmat: vec_mat1col + if (ncols == col_max) { + reim4_vec_mat1col_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + } else { + // the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position + reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + } + reim4_save_1blk_to_reim_ref(m, blk_i, vec_output + last_col * nn, mat2cols_output); + } + } + } else { + for (uint64_t col_i = 0; col_i < col_max; col_i++) { + double* pmat_col = mat_input + col_i * nrows * nn; + for (uint64_t row_i = 0; row_i < 1; row_i++) { + reim_fftvec_mul(module->mod.fft64.mul_fft, vec_output + col_i * nn, vec_input + row_i * nn, + pmat_col + row_i * nn); + } + for (uint64_t row_i = 1; row_i < row_max; row_i++) { + reim_fftvec_addmul(module->mod.fft64.p_addmul, vec_output + col_i * nn, vec_input + row_i * nn, + pmat_col + row_i * nn); + } + } + } + + // zero out remaining bytes + memset(vec_output + col_max * nn, 0, (res_size - col_max) * nn * sizeof(double)); +} + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t fft64_vmp_apply_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +) { + const uint64_t row_max = nrows < a_size ? 
nrows : a_size; + return (row_max * nn * sizeof(double)) + (128) + (64 * row_max); +} + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t fft64_vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +) { + const uint64_t row_max = nrows < a_size ? nrows : a_size; + + return (128) + (64 * row_max); +} + +EXPORT void vmp_apply_dft_to_dft(const MODULE* module, // N + VEC_ZNX_DFT* res, const uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + const VMP_PMAT* pmat, const uint64_t nrows, + const uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +) { + module->func.vmp_apply_dft_to_dft(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, tmp_space); +} + +EXPORT void vmp_apply_dft_to_dft_add(const MODULE* module, // N + VEC_ZNX_DFT* res, const uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols, + uint64_t pmat_scale, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +) { + module->func.vmp_apply_dft_to_dft_add(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, pmat_scale, + tmp_space); +} + +EXPORT uint64_t vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +) { + return module->func.vmp_apply_dft_to_dft_tmp_bytes(module, nn, res_size, a_size, nrows, ncols); +} + +/** @brief applies a vmp product (result in DFT space) adds to res inplace */ +EXPORT void vmp_apply_dft_add(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, uint64_t pmat_scale, // prep matrix + uint8_t* tmp_space // scratch space +) { + module->func.vmp_apply_dft_add(module, res, res_size, a, a_size, a_sl, pmat, nrows, ncols, pmat_scale, tmp_space); +} + +/** @brief applies a vmp product (result in DFT space) */ +EXPORT void vmp_apply_dft(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +) { + module->func.vmp_apply_dft(module, res, res_size, a, a_size, a_sl, pmat, nrows, ncols, tmp_space); +} + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t vmp_apply_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +) { + return module->func.vmp_apply_dft_tmp_bytes(module, nn, res_size, a_size, nrows, ncols); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vector_matrix_product_avx.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vector_matrix_product_avx.c new file mode 100644 index 0000000..7088935 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/vector_matrix_product_avx.c @@ -0,0 +1,244 @@ +#include + +#include "../reim4/reim4_arithmetic.h" +#include "vec_znx_arithmetic_private.h" + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void fft64_vmp_prepare_contiguous_avx(const MODULE* module, // N + VMP_PMAT* pmat, // output + const 
int64_t* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + // there is an edge case if nn < 8 + const uint64_t nn = module->nn; + const uint64_t m = module->m; + + double* output_mat = (double*)pmat; + double* start_addr = (double*)pmat; + uint64_t offset = nrows * ncols * 8; + + if (nn >= 8) { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)tmp_space, mat + (row_i * ncols + col_i) * nn); + reim_fft(module->mod.fft64.p_fft, (double*)tmp_space); + + if (col_i == (ncols - 1) && (ncols % 2 == 1)) { + // special case: last column out of an odd column number + start_addr = output_mat + col_i * nrows * 8 // col == ncols-1 + + row_i * 8; + } else { + // general case: columns go by pair + start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index + + row_i * 2 * 8 // third: row index + + (col_i % 2) * 8; + } + + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + // extract blk from tmp and save it + reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, (double*)tmp_space); + } + } + } + } else { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + double* res = (double*)pmat + (col_i * nrows + row_i) * nn; + reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)res, mat + (row_i * ncols + col_i) * nn); + reim_fft(module->mod.fft64.p_fft, res); + } + } + } +} + +double* get_blk_addr(int row, int col, int nrows, int ncols, VMP_PMAT* pmat); + +void fft64_store_svp_ppol_into_vmp_pmat_row_blk_avx(uint64_t nn, uint64_t m, const SVP_PPOL* svp_ppol, uint64_t row_i, + uint64_t col_i, uint64_t nrows, uint64_t ncols, VMP_PMAT* pmat) { + double* start_addr = get_blk_addr(row_i, col_i, nrows, ncols, pmat); + uint64_t offset = nrows * ncols * 8; + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, (double*)svp_ppol); + } +} + +/** @brief applies a vmp product (result in DFT space) abd adds to res inplace */ +EXPORT void fft64_vmp_apply_dft_add_avx(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, + uint64_t pmat_scale, // prep matrix + uint8_t* tmp_space // scratch space +) { + const uint64_t nn = module->nn; + const uint64_t rows = nrows < a_size ? nrows : a_size; + + VEC_ZNX_DFT* a_dft = (VEC_ZNX_DFT*)tmp_space; + uint8_t* new_tmp_space = (uint8_t*)tmp_space + rows * nn * sizeof(double); + + fft64_vec_znx_dft(module, a_dft, rows, a, a_size, a_sl); + fft64_vmp_apply_dft_to_dft_add_avx(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, pmat_scale, + new_tmp_space); +} + +/** @brief applies a vmp product (result in DFT space) */ +EXPORT void fft64_vmp_apply_dft_avx(const MODULE* module, // N + VEC_ZNX_DFT* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +) { + const uint64_t nn = module->nn; + const uint64_t rows = nrows < a_size ? 
nrows : a_size; + + VEC_ZNX_DFT* a_dft = (VEC_ZNX_DFT*)tmp_space; + uint8_t* new_tmp_space = (uint8_t*)tmp_space + rows * nn * sizeof(double); + + fft64_vec_znx_dft(module, a_dft, rows, a, a_size, a_sl); + fft64_vmp_apply_dft_to_dft_avx(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, new_tmp_space); +} + +EXPORT void fft64_vmp_apply_dft_to_dft_add_avx(const MODULE* module, // N + VEC_ZNX_DFT* res, const uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols, + uint64_t pmat_scale, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +) { + const uint64_t m = module->m; + const uint64_t nn = module->nn; + + double* mat2cols_output = (double*)tmp_space; // 128 bytes + double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes + + double* mat_input = (double*)pmat; + double* vec_input = (double*)a_dft; + double* vec_output = (double*)res; + + const uint64_t row_max = nrows < a_size ? nrows : a_size; + const uint64_t col_max = ncols < res_size ? ncols : res_size; + + if (nn >= 8) { + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols); + + reim4_extract_1blk_from_contiguous_reim_avx(m, row_max, blk_i, (double*)extracted_blk, (double*)a_dft); + + if (pmat_scale % 2 == 0) { + for (uint64_t col_res = 0, col_pmat = pmat_scale; col_pmat < col_max - 1; col_res += 2, col_pmat += 2) { + uint64_t col_offset = col_pmat * (8 * nrows); + reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + reim4_add_1blk_to_reim_avx(m, blk_i, vec_output + col_res * nn, mat2cols_output); + reim4_add_1blk_to_reim_avx(m, blk_i, vec_output + (col_res + 1) * nn, mat2cols_output + 8); + } + } else { + uint64_t col_offset = (pmat_scale - 1) * (8 * nrows); + reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + reim4_add_1blk_to_reim_avx(m, blk_i, vec_output, mat2cols_output + 8); + + for (uint64_t col_res = 1, col_pmat = pmat_scale + 1; col_pmat < col_max - 1; col_res += 2, col_pmat += 2) { + uint64_t col_offset = col_pmat * (8 * nrows); + reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + reim4_add_1blk_to_reim_avx(m, blk_i, vec_output + col_res * nn, mat2cols_output); + reim4_add_1blk_to_reim_avx(m, blk_i, vec_output + (col_res + 1) * nn, mat2cols_output + 8); + } + } + + // check if col_max is odd, then special case + if (col_max % 2 == 1) { + uint64_t last_col = col_max - 1; + uint64_t col_offset = last_col * (8 * nrows); + + if (last_col >= pmat_scale) { + // the last column is alone in the pmat: vec_mat1col + if (ncols == col_max) + reim4_vec_mat1col_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + else { + // the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position + reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + } + reim4_add_1blk_to_reim_avx(m, blk_i, vec_output + (last_col - pmat_scale) * nn, mat2cols_output); + } + } + } + } else { + for (uint64_t col_res = 0, col_pmat = pmat_scale; col_pmat < col_max; col_res += 1, col_pmat += 1) { + double* pmat_col = mat_input + col_pmat * nrows * nn; + for (uint64_t row_i = 0; row_i < row_max; row_i++) { + reim_fftvec_addmul(module->mod.fft64.p_addmul, vec_output + col_res * nn, vec_input + 
row_i * nn, + pmat_col + row_i * nn); + } + } + } + + // zero out remaining bytes + memset(vec_output + col_max * nn, 0, (res_size - col_max) * nn * sizeof(double)); +} + +/** @brief this inner function could be very handy */ +EXPORT void fft64_vmp_apply_dft_to_dft_avx(const MODULE* module, // N + VEC_ZNX_DFT* res, const uint64_t res_size, // res + const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a + const VMP_PMAT* pmat, const uint64_t nrows, + const uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +) { + const uint64_t m = module->m; + const uint64_t nn = module->nn; + + double* mat2cols_output = (double*)tmp_space; // 128 bytes + double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes + + double* mat_input = (double*)pmat; + double* vec_input = (double*)a_dft; + double* vec_output = (double*)res; + + const uint64_t row_max = nrows < a_size ? nrows : a_size; + const uint64_t col_max = ncols < res_size ? ncols : res_size; + + if (nn >= 8) { + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols); + + reim4_extract_1blk_from_contiguous_reim_avx(m, row_max, blk_i, (double*)extracted_blk, (double*)a_dft); + // apply mat2cols + for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) { + uint64_t col_offset = col_i * (8 * nrows); + reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + + reim4_save_1blk_to_reim_avx(m, blk_i, vec_output + col_i * nn, mat2cols_output); + reim4_save_1blk_to_reim_avx(m, blk_i, vec_output + (col_i + 1) * nn, mat2cols_output + 8); + } + + // check if col_max is odd, then special case + if (col_max % 2 == 1) { + uint64_t last_col = col_max - 1; + uint64_t col_offset = last_col * (8 * nrows); + + // the last column is alone in the pmat: vec_mat1col + if (ncols == col_max) + reim4_vec_mat1col_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + else { + // the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position + reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + } + reim4_save_1blk_to_reim_avx(m, blk_i, vec_output + last_col * nn, mat2cols_output); + } + } + } else { + for (uint64_t col_i = 0; col_i < col_max; col_i++) { + double* pmat_col = mat_input + col_i * nrows * nn; + for (uint64_t row_i = 0; row_i < 1; row_i++) { + reim_fftvec_mul(module->mod.fft64.mul_fft, vec_output + col_i * nn, vec_input + row_i * nn, + pmat_col + row_i * nn); + } + for (uint64_t row_i = 1; row_i < row_max; row_i++) { + reim_fftvec_addmul(module->mod.fft64.p_addmul, vec_output + col_i * nn, vec_input + row_i * nn, + pmat_col + row_i * nn); + } + } + } + + // zero out remaining bytes + memset(vec_output + col_max * nn, 0, (res_size - col_max) * nn * sizeof(double)); +} \ No newline at end of file diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_api.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_api.c new file mode 100644 index 0000000..4b81750 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_api.c @@ -0,0 +1,185 @@ +#include + +#include "zn_arithmetic_private.h" + +void default_init_z_module_precomp(MOD_Z* module) { + // Add here initialization of items that are in the precomp +} + +void default_finalize_z_module_precomp(MOD_Z* module) { + // Add here deleters for items that are in 
the precomp +} + +void default_init_z_module_vtable(MOD_Z* module) { + // Add function pointers here + module->vtable.i8_approxdecomp_from_tndbl = default_i8_approxdecomp_from_tndbl_ref; + module->vtable.i16_approxdecomp_from_tndbl = default_i16_approxdecomp_from_tndbl_ref; + module->vtable.i32_approxdecomp_from_tndbl = default_i32_approxdecomp_from_tndbl_ref; + module->vtable.zn32_vmp_prepare_contiguous = default_zn32_vmp_prepare_contiguous_ref; + module->vtable.zn32_vmp_prepare_dblptr = default_zn32_vmp_prepare_dblptr_ref; + module->vtable.zn32_vmp_prepare_row = default_zn32_vmp_prepare_row_ref; + module->vtable.zn32_vmp_apply_i8 = default_zn32_vmp_apply_i8_ref; + module->vtable.zn32_vmp_apply_i16 = default_zn32_vmp_apply_i16_ref; + module->vtable.zn32_vmp_apply_i32 = default_zn32_vmp_apply_i32_ref; + module->vtable.dbl_to_tn32 = dbl_to_tn32_ref; + module->vtable.tn32_to_dbl = tn32_to_dbl_ref; + module->vtable.dbl_round_to_i32 = dbl_round_to_i32_ref; + module->vtable.i32_to_dbl = i32_to_dbl_ref; + module->vtable.dbl_round_to_i64 = dbl_round_to_i64_ref; + module->vtable.i64_to_dbl = i64_to_dbl_ref; + + // Add optimized function pointers here + if (CPU_SUPPORTS("avx")) { + module->vtable.zn32_vmp_apply_i8 = default_zn32_vmp_apply_i8_avx; + module->vtable.zn32_vmp_apply_i16 = default_zn32_vmp_apply_i16_avx; + module->vtable.zn32_vmp_apply_i32 = default_zn32_vmp_apply_i32_avx; + } +} + +void init_z_module_info(MOD_Z* module, // + Z_MODULE_TYPE mtype) { + memset(module, 0, sizeof(MOD_Z)); + module->mtype = mtype; + switch (mtype) { + case DEFAULT: + default_init_z_module_precomp(module); + default_init_z_module_vtable(module); + break; + default: + NOT_SUPPORTED(); // unknown mtype + } +} + +void finalize_z_module_info(MOD_Z* module) { + if (module->custom) module->custom_deleter(module->custom); + switch (module->mtype) { + case DEFAULT: + default_finalize_z_module_precomp(module); + // fft64_finalize_rnx_module_vtable(module); // nothing to finalize + break; + default: + NOT_SUPPORTED(); // unknown mtype + } +} + +EXPORT MOD_Z* new_z_module_info(Z_MODULE_TYPE mtype) { + MOD_Z* res = (MOD_Z*)malloc(sizeof(MOD_Z)); + init_z_module_info(res, mtype); + return res; +} + +EXPORT void delete_z_module_info(MOD_Z* module_info) { + finalize_z_module_info(module_info); + free(module_info); +} + +//////////////// wrappers ////////////////// + +/** @brief sets res = gadget_decompose(a) (int8_t* output) */ +EXPORT void i8_approxdecomp_from_tndbl(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int8_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size) { // a + module->vtable.i8_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size); +} + +/** @brief sets res = gadget_decompose(a) (int16_t* output) */ +EXPORT void i16_approxdecomp_from_tndbl(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int16_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size) { // a + module->vtable.i16_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size); +} +/** @brief sets res = gadget_decompose(a) (int32_t* output) */ +EXPORT void i32_approxdecomp_from_tndbl(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int32_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size) { // a + module->vtable.i32_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size); +} + +/** @brief 
prepares a vmp matrix (contiguous row-major version) */ +EXPORT void zn32_vmp_prepare_contiguous(const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t* mat, uint64_t nrows, uint64_t ncols) { // a + module->vtable.zn32_vmp_prepare_contiguous(module, pmat, mat, nrows, ncols); +} + +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void zn32_vmp_prepare_dblptr(const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t** mat, uint64_t nrows, uint64_t ncols) { // a + module->vtable.zn32_vmp_prepare_dblptr(module, pmat, mat, nrows, ncols); +} + +/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */ +EXPORT void zn32_vmp_prepare_row(const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols) { // a + module->vtable.zn32_vmp_prepare_row(module, pmat, row, row_i, nrows, ncols); +} + +/** @brief applies a vmp product (int32_t* input) */ +EXPORT void zn32_vmp_apply_i32(const MOD_Z* module, int32_t* res, uint64_t res_size, const int32_t* a, uint64_t a_size, + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) { + module->vtable.zn32_vmp_apply_i32(module, res, res_size, a, a_size, pmat, nrows, ncols); +} +/** @brief applies a vmp product (int16_t* input) */ +EXPORT void zn32_vmp_apply_i16(const MOD_Z* module, int32_t* res, uint64_t res_size, const int16_t* a, uint64_t a_size, + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) { + module->vtable.zn32_vmp_apply_i16(module, res, res_size, a, a_size, pmat, nrows, ncols); +} + +/** @brief applies a vmp product (int8_t* input) */ +EXPORT void zn32_vmp_apply_i8(const MOD_Z* module, int32_t* res, uint64_t res_size, const int8_t* a, uint64_t a_size, + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) { + module->vtable.zn32_vmp_apply_i8(module, res, res_size, a, a_size, pmat, nrows, ncols); +} + +/** reduction mod 1, output in torus32 space */ +EXPORT void dbl_to_tn32(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + module->vtable.dbl_to_tn32(module, res, res_size, a, a_size); +} + +/** real centerlift mod 1, output in double space */ +EXPORT void tn32_to_dbl(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +) { + module->vtable.tn32_to_dbl(module, res, res_size, a, a_size); +} + +/** round to the nearest int, output in i32 space */ +EXPORT void dbl_round_to_i32(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + module->vtable.dbl_round_to_i32(module, res, res_size, a, a_size); +} + +/** small int (int32 space) to double */ +EXPORT void i32_to_dbl(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +) { + module->vtable.i32_to_dbl(module, res, res_size, a, a_size); +} + +/** round to the nearest int, output in int64 space */ +EXPORT void dbl_round_to_i64(const MOD_Z* module, // + int64_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + module->vtable.dbl_round_to_i64(module, res, res_size, a, a_size); +} + +/** small int (int64 space, <= 2^50) to double */ +EXPORT void i64_to_dbl(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size // a +) { + module->vtable.i64_to_dbl(module, res, res_size, a, a_size); +} diff --git 
a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_approxdecomp_ref.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_approxdecomp_ref.c new file mode 100644 index 0000000..616b9a3 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_approxdecomp_ref.c @@ -0,0 +1,81 @@ +#include + +#include "zn_arithmetic_private.h" + +EXPORT TNDBL_APPROXDECOMP_GADGET* new_tndbl_approxdecomp_gadget(const MOD_Z* module, // + uint64_t k, uint64_t ell) { + if (k * ell > 50) { + return spqlios_error("approx decomposition requested is too precise for doubles"); + } + if (k < 1) { + return spqlios_error("approx decomposition supports k>=1"); + } + TNDBL_APPROXDECOMP_GADGET* res = malloc(sizeof(TNDBL_APPROXDECOMP_GADGET)); + memset(res, 0, sizeof(TNDBL_APPROXDECOMP_GADGET)); + res->k = k; + res->ell = ell; + double add_cst = INT64_C(3) << (51 - k * ell); + for (uint64_t i = 0; i < ell; ++i) { + add_cst += pow(2., -(double)(i * k + 1)); + } + res->add_cst = add_cst; + res->and_mask = (UINT64_C(1) << k) - 1; + res->sub_cst = UINT64_C(1) << (k - 1); + for (uint64_t i = 0; i < ell; ++i) res->rshifts[i] = (ell - 1 - i) * k; + return res; +} +EXPORT void delete_tndbl_approxdecomp_gadget(TNDBL_APPROXDECOMP_GADGET* ptr) { free(ptr); } + +EXPORT int default_init_tndbl_approxdecomp_gadget(const MOD_Z* module, // + TNDBL_APPROXDECOMP_GADGET* res, // + uint64_t k, uint64_t ell) { + return 0; +} + +typedef union { + double dv; + uint64_t uv; +} du_t; + +#define IMPL_ixx_approxdecomp_from_tndbl_ref(ITYPE) \ + if (res_size != a_size * gadget->ell) NOT_IMPLEMENTED(); \ + const uint64_t ell = gadget->ell; \ + const double add_cst = gadget->add_cst; \ + const uint8_t* const rshifts = gadget->rshifts; \ + const ITYPE and_mask = gadget->and_mask; \ + const ITYPE sub_cst = gadget->sub_cst; \ + ITYPE* rr = res; \ + const double* aa = a; \ + const double* aaend = a + a_size; \ + while (aa < aaend) { \ + du_t t = {.dv = *aa + add_cst}; \ + for (uint64_t i = 0; i < ell; ++i) { \ + ITYPE v = (ITYPE)(t.uv >> rshifts[i]); \ + *rr = (v & and_mask) - sub_cst; \ + ++rr; \ + } \ + ++aa; \ + } + +/** @brief sets res = gadget_decompose(a) (int8_t* output) */ +EXPORT void default_i8_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int8_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size // +){IMPL_ixx_approxdecomp_from_tndbl_ref(int8_t)} + +/** @brief sets res = gadget_decompose(a) (int16_t* output) */ +EXPORT void default_i16_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int16_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +){IMPL_ixx_approxdecomp_from_tndbl_ref(int16_t)} + +/** @brief sets res = gadget_decompose(a) (int32_t* output) */ +EXPORT void default_i32_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + IMPL_ixx_approxdecomp_from_tndbl_ref(int32_t) +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_arithmetic.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_arithmetic.h new file mode 100644 index 0000000..7aec10a --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_arithmetic.h @@ -0,0 +1,147 @@ +#ifndef 
SPQLIOS_ZN_ARITHMETIC_H +#define SPQLIOS_ZN_ARITHMETIC_H + +#include + +#include "../commons.h" + +typedef enum z_module_type_t { DEFAULT } Z_MODULE_TYPE; + +/** @brief opaque structure that describes the module and the hardware */ +typedef struct z_module_info_t MOD_Z; + +/** + * @brief obtain a module info for ring dimension N + * the module-info knows about: + * - the dimension N (or the complex dimension m=N/2) + * - any moduleuted fft or ntt items + * - the hardware (avx, arm64, x86, ...) + */ +EXPORT MOD_Z* new_z_module_info(Z_MODULE_TYPE mode); +EXPORT void delete_z_module_info(MOD_Z* module_info); + +typedef struct tndbl_approxdecomp_gadget_t TNDBL_APPROXDECOMP_GADGET; + +EXPORT TNDBL_APPROXDECOMP_GADGET* new_tndbl_approxdecomp_gadget(const MOD_Z* module, // + uint64_t k, + uint64_t ell); // base 2^k, and size + +EXPORT void delete_tndbl_approxdecomp_gadget(TNDBL_APPROXDECOMP_GADGET* ptr); + +/** @brief sets res = gadget_decompose(a) (int8_t* output) */ +EXPORT void i8_approxdecomp_from_tndbl(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int8_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size); // a + +/** @brief sets res = gadget_decompose(a) (int16_t* output) */ +EXPORT void i16_approxdecomp_from_tndbl(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int16_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size); // a +/** @brief sets res = gadget_decompose(a) (int32_t* output) */ +EXPORT void i32_approxdecomp_from_tndbl(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int32_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size); // a + +/** @brief opaque type that represents a prepared matrix */ +typedef struct zn32_vmp_pmat_t ZN32_VMP_PMAT; + +/** @brief size in bytes of a prepared matrix (for custom allocation) */ +EXPORT uint64_t bytes_of_zn32_vmp_pmat(const MOD_Z* module, // N + uint64_t nrows, uint64_t ncols); // dimensions + +/** @brief allocates a prepared matrix (release with delete_zn32_vmp_pmat) */ +EXPORT ZN32_VMP_PMAT* new_zn32_vmp_pmat(const MOD_Z* module, // N + uint64_t nrows, uint64_t ncols); // dimensions + +/** @brief deletes a prepared matrix (release with free) */ +EXPORT void delete_zn32_vmp_pmat(ZN32_VMP_PMAT* ptr); // dimensions + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void zn32_vmp_prepare_contiguous( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t* mat, uint64_t nrows, uint64_t ncols); // a + +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void zn32_vmp_prepare_dblptr( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t** mat, uint64_t nrows, uint64_t ncols); // a + +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void zn32_vmp_prepare_row( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols); // a + +/** @brief applies a vmp product (int32_t* input) */ +EXPORT void zn32_vmp_apply_i32( // + const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +/** @brief applies a vmp product (int16_t* input) */ +EXPORT void zn32_vmp_apply_i16( // + const MOD_Z* module, // + int32_t* 
res, uint64_t res_size, // res + const int16_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +/** @brief applies a vmp product (int8_t* input) */ +EXPORT void zn32_vmp_apply_i8( // + const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const int8_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +// explicit conversions + +/** reduction mod 1, output in torus32 space */ +EXPORT void dbl_to_tn32(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +); + +/** real centerlift mod 1, output in double space */ +EXPORT void tn32_to_dbl(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +); + +/** round to the nearest int, output in i32 space. + * WARNING: ||a||_inf must be <= 2^18 in this function + */ +EXPORT void dbl_round_to_i32(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +); + +/** small int (int32 space) to double + * WARNING: ||a||_inf must be <= 2^18 in this function + */ +EXPORT void i32_to_dbl(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +); + +/** round to the nearest int, output in int64 space + * WARNING: ||a||_inf must be <= 2^50 in this function + */ +EXPORT void dbl_round_to_i64(const MOD_Z* module, // + int64_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +); + +/** small int (int64 space, <= 2^50) to double + * WARNING: ||a||_inf must be <= 2^50 in this function + */ +EXPORT void i64_to_dbl(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size // a +); + +#endif // SPQLIOS_ZN_ARITHMETIC_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_arithmetic_plugin.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_arithmetic_plugin.h new file mode 100644 index 0000000..eb573cc --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_arithmetic_plugin.h @@ -0,0 +1,43 @@ +#ifndef SPQLIOS_ZN_ARITHMETIC_PLUGIN_H +#define SPQLIOS_ZN_ARITHMETIC_PLUGIN_H + +#include "zn_arithmetic.h" + +typedef typeof(i8_approxdecomp_from_tndbl) I8_APPROXDECOMP_FROM_TNDBL_F; +typedef typeof(i16_approxdecomp_from_tndbl) I16_APPROXDECOMP_FROM_TNDBL_F; +typedef typeof(i32_approxdecomp_from_tndbl) I32_APPROXDECOMP_FROM_TNDBL_F; +typedef typeof(bytes_of_zn32_vmp_pmat) BYTES_OF_ZN32_VMP_PMAT_F; +typedef typeof(zn32_vmp_prepare_contiguous) ZN32_VMP_PREPARE_CONTIGUOUS_F; +typedef typeof(zn32_vmp_prepare_dblptr) ZN32_VMP_PREPARE_DBLPTR_F; +typedef typeof(zn32_vmp_prepare_row) ZN32_VMP_PREPARE_ROW_F; +typedef typeof(zn32_vmp_apply_i32) ZN32_VMP_APPLY_I32_F; +typedef typeof(zn32_vmp_apply_i16) ZN32_VMP_APPLY_I16_F; +typedef typeof(zn32_vmp_apply_i8) ZN32_VMP_APPLY_I8_F; +typedef typeof(dbl_to_tn32) DBL_TO_TN32_F; +typedef typeof(tn32_to_dbl) TN32_TO_DBL_F; +typedef typeof(dbl_round_to_i32) DBL_ROUND_TO_I32_F; +typedef typeof(i32_to_dbl) I32_TO_DBL_F; +typedef typeof(dbl_round_to_i64) DBL_ROUND_TO_I64_F; +typedef typeof(i64_to_dbl) I64_TO_DBL_F; + +typedef struct z_module_vtable_t Z_MODULE_VTABLE; +struct z_module_vtable_t { + I8_APPROXDECOMP_FROM_TNDBL_F* i8_approxdecomp_from_tndbl; + I16_APPROXDECOMP_FROM_TNDBL_F* i16_approxdecomp_from_tndbl; + I32_APPROXDECOMP_FROM_TNDBL_F* 
i32_approxdecomp_from_tndbl; + BYTES_OF_ZN32_VMP_PMAT_F* bytes_of_zn32_vmp_pmat; + ZN32_VMP_PREPARE_CONTIGUOUS_F* zn32_vmp_prepare_contiguous; + ZN32_VMP_PREPARE_DBLPTR_F* zn32_vmp_prepare_dblptr; + ZN32_VMP_PREPARE_ROW_F* zn32_vmp_prepare_row; + ZN32_VMP_APPLY_I32_F* zn32_vmp_apply_i32; + ZN32_VMP_APPLY_I16_F* zn32_vmp_apply_i16; + ZN32_VMP_APPLY_I8_F* zn32_vmp_apply_i8; + DBL_TO_TN32_F* dbl_to_tn32; + TN32_TO_DBL_F* tn32_to_dbl; + DBL_ROUND_TO_I32_F* dbl_round_to_i32; + I32_TO_DBL_F* i32_to_dbl; + DBL_ROUND_TO_I64_F* dbl_round_to_i64; + I64_TO_DBL_F* i64_to_dbl; +}; + +#endif // SPQLIOS_ZN_ARITHMETIC_PLUGIN_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_arithmetic_private.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_arithmetic_private.h new file mode 100644 index 0000000..2de8a84 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_arithmetic_private.h @@ -0,0 +1,164 @@ +#ifndef SPQLIOS_ZN_ARITHMETIC_PRIVATE_H +#define SPQLIOS_ZN_ARITHMETIC_PRIVATE_H + +#include "../commons_private.h" +#include "zn_arithmetic.h" +#include "zn_arithmetic_plugin.h" + +typedef struct main_z_module_precomp_t MAIN_Z_MODULE_PRECOMP; +struct main_z_module_precomp_t { + // TODO +}; + +typedef union z_module_precomp_t Z_MODULE_PRECOMP; +union z_module_precomp_t { + MAIN_Z_MODULE_PRECOMP main; +}; + +void main_init_z_module_precomp(MOD_Z* module); + +void main_finalize_z_module_precomp(MOD_Z* module); + +/** @brief opaque structure that describes the modules (RnX,ZnX,TnX) and the hardware */ +struct z_module_info_t { + Z_MODULE_TYPE mtype; + Z_MODULE_VTABLE vtable; + Z_MODULE_PRECOMP precomp; + void* custom; + void (*custom_deleter)(void*); +}; + +void init_z_module_info(MOD_Z* module, Z_MODULE_TYPE mtype); + +void main_init_z_module_vtable(MOD_Z* module); + +struct tndbl_approxdecomp_gadget_t { + uint64_t k; + uint64_t ell; + double add_cst; // 3.2^51-(K.ell) + 1/2.(sum 2^-(i+1)K) + int64_t and_mask; // (2^K)-1 + int64_t sub_cst; // 2^(K-1) + uint8_t rshifts[64]; // 2^(ell-1-i).K for i in [0:ell-1] +}; + +/** @brief sets res = gadget_decompose(a) (int8_t* output) */ +EXPORT void default_i8_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int8_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size); // a + +/** @brief sets res = gadget_decompose(a) (int16_t* output) */ +EXPORT void default_i16_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int16_t* res, + uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size); // a +/** @brief sets res = gadget_decompose(a) (int32_t* output) */ +EXPORT void default_i32_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int32_t* res, + uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size); // a + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void default_zn32_vmp_prepare_contiguous_ref( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t* mat, uint64_t nrows, uint64_t ncols // a +); + +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void default_zn32_vmp_prepare_dblptr_ref( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t** mat, uint64_t nrows, uint64_t ncols // a +); + 
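+/* Minimal usage sketch of the public zn32 VMP API (declared in zn_arithmetic.h)
+ * that these default_* implementations back. Illustrative only; the sizes, the
+ * values, and the mapping a_size = nrows / res_size = ncols are assumptions.
+ *
+ *   MOD_Z* mod = new_z_module_info(DEFAULT);
+ *   uint64_t nrows = 4, ncols = 8;
+ *   int32_t mat[4 * 8] = {0};        // row-major matrix to prepare
+ *   int32_t a[4] = {1, 2, 3, 4};     // input vector
+ *   int32_t res[8];                  // output vector
+ *   ZN32_VMP_PMAT* pmat = new_zn32_vmp_pmat(mod, nrows, ncols);
+ *   zn32_vmp_prepare_contiguous(mod, pmat, mat, nrows, ncols);
+ *   zn32_vmp_apply_i32(mod, res, ncols, a, nrows, pmat, nrows, ncols);
+ *   delete_zn32_vmp_pmat(pmat);
+ *   delete_z_module_info(mod);
+ */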
+/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */ +EXPORT void default_zn32_vmp_prepare_row_ref( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols // a +); + +/** @brief applies a vmp product (int32_t* input) */ +EXPORT void default_zn32_vmp_apply_i32_ref( // + const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +/** @brief applies a vmp product (int16_t* input) */ +EXPORT void default_zn32_vmp_apply_i16_ref( // + const MOD_Z* module, // N + int32_t* res, uint64_t res_size, // res + const int16_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +/** @brief applies a vmp product (int8_t* input) */ +EXPORT void default_zn32_vmp_apply_i8_ref( // + const MOD_Z* module, // N + int32_t* res, uint64_t res_size, // res + const int8_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +/** @brief applies a vmp product (int32_t* input) */ +EXPORT void default_zn32_vmp_apply_i32_avx( // + const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +/** @brief applies a vmp product (int16_t* input) */ +EXPORT void default_zn32_vmp_apply_i16_avx( // + const MOD_Z* module, // N + int32_t* res, uint64_t res_size, // res + const int16_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +/** @brief applies a vmp product (int8_t* input) */ +EXPORT void default_zn32_vmp_apply_i8_avx( // + const MOD_Z* module, // N + int32_t* res, uint64_t res_size, // res + const int8_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +// explicit conversions + +/** reduction mod 1, output in torus32 space */ +EXPORT void dbl_to_tn32_ref(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +); + +/** real centerlift mod 1, output in double space */ +EXPORT void tn32_to_dbl_ref(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +); + +/** round to the nearest int, output in i32 space */ +EXPORT void dbl_round_to_i32_ref(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +); + +/** small int (int32 space) to double */ +EXPORT void i32_to_dbl_ref(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +); + +/** round to the nearest int, output in int64 space */ +EXPORT void dbl_round_to_i64_ref(const MOD_Z* module, // + int64_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +); + +/** small int (int64 space) to double */ +EXPORT void i64_to_dbl_ref(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size // a +); + +#endif // SPQLIOS_ZN_ARITHMETIC_PRIVATE_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_conversions_ref.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_conversions_ref.c new file mode 100644 index 0000000..f016a71 --- /dev/null +++ 
b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_conversions_ref.c @@ -0,0 +1,108 @@ +#include <string.h> + +#include "zn_arithmetic_private.h" + +typedef union { + double dv; + int64_t s64v; + int32_t s32v; + uint64_t u64v; + uint32_t u32v; +} di_t; + +/** reduction mod 1, output in torus32 space */ +EXPORT void dbl_to_tn32_ref(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + static const double ADD_CST = 0.5 + (double)(INT64_C(3) << (51 - 32)); + static const int32_t XOR_CST = (INT32_C(1) << 31); + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + di_t t = {.dv = a[i] + ADD_CST}; + res[i] = t.s32v ^ XOR_CST; + } + memset(res + msize, 0, (res_size - msize) * sizeof(int32_t)); +} + +/** real centerlift mod 1, output in double space */ +EXPORT void tn32_to_dbl_ref(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +) { + static const uint32_t XOR_CST = (UINT32_C(1) << 31); + static const di_t OR_CST = {.dv = (double)(INT64_C(1) << (52 - 32))}; + static const double SUB_CST = 0.5 + (double)(INT64_C(1) << (52 - 32)); + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + uint32_t ai = a[i] ^ XOR_CST; + di_t t = {.u64v = OR_CST.u64v | (uint64_t)ai}; + res[i] = t.dv - SUB_CST; + } + memset(res + msize, 0, (res_size - msize) * sizeof(double)); +} + +/** round to the nearest int, output in i32 space */ +EXPORT void dbl_round_to_i32_ref(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + static const double ADD_CST = (double)((INT64_C(3) << (51)) + (INT64_C(1) << (31))); + static const int32_t XOR_CST = INT32_C(1) << 31; + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + di_t t = {.dv = a[i] + ADD_CST}; + res[i] = t.s32v ^ XOR_CST; + } + memset(res + msize, 0, (res_size - msize) * sizeof(int32_t)); +} + +/** small int (int32 space) to double */ +EXPORT void i32_to_dbl_ref(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +) { + static const uint32_t XOR_CST = (UINT32_C(1) << 31); + static const di_t OR_CST = {.dv = (double)(INT64_C(1) << 52)}; + static const double SUB_CST = (double)((INT64_C(1) << 52) + (INT64_C(1) << 31)); + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + uint32_t ai = a[i] ^ XOR_CST; + di_t t = {.u64v = OR_CST.u64v | (uint64_t)ai}; + res[i] = t.dv - SUB_CST; + } + memset(res + msize, 0, (res_size - msize) * sizeof(double)); +} + +/** round to the nearest int, output in int64 space */ +EXPORT void dbl_round_to_i64_ref(const MOD_Z* module, // + int64_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + static const double ADD_CST = (double)(INT64_C(3) << (51)); + static const int64_t AND_CST = (INT64_C(1) << 52) - 1; + static const int64_t SUB_CST = INT64_C(1) << 51; + const uint64_t msize = res_size < a_size ? 
res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + di_t t = {.dv = a[i] + ADD_CST}; + res[i] = (t.s64v & AND_CST) - SUB_CST; + } + memset(res + msize, 0, (res_size - msize) * sizeof(int64_t)); +} + +/** small int (int64 space) to double */ +EXPORT void i64_to_dbl_ref(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size // a +) { + static const uint64_t ADD_CST = UINT64_C(1) << 51; + static const uint64_t AND_CST = (UINT64_C(1) << 52) - 1; + static const di_t OR_CST = {.dv = (INT64_C(1) << 52)}; + static const double SUB_CST = INT64_C(3) << 51; + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + di_t t = {.u64v = ((a[i] + ADD_CST) & AND_CST) | OR_CST.u64v}; + res[i] = t.dv - SUB_CST; + } + memset(res + msize, 0, (res_size - msize) * sizeof(double)); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_vmp_int16_avx.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_vmp_int16_avx.c new file mode 100644 index 0000000..563f199 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_vmp_int16_avx.c @@ -0,0 +1,4 @@ +#define INTTYPE int16_t +#define INTSN i16 + +#include "zn_vmp_int32_avx.c" diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_vmp_int16_ref.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_vmp_int16_ref.c new file mode 100644 index 0000000..0626c9b --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_vmp_int16_ref.c @@ -0,0 +1,4 @@ +#define INTTYPE int16_t +#define INTSN i16 + +#include "zn_vmp_int32_ref.c" diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_vmp_int32_avx.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_vmp_int32_avx.c new file mode 100644 index 0000000..3fbc8fb --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_vmp_int32_avx.c @@ -0,0 +1,223 @@ +// This file is actually a template: it will be compiled multiple times with +// different INTTYPES +#ifndef INTTYPE +#define INTTYPE int32_t +#define INTSN i32 +#endif + +#include +#include + +#include "zn_arithmetic_private.h" + +#define concat_inner(aa, bb, cc) aa##_##bb##_##cc +#define concat(aa, bb, cc) concat_inner(aa, bb, cc) +#define zn32_vec_fn(cc) concat(zn32_vec, INTSN, cc) + +static void zn32_vec_mat32cols_avx_prefetch(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b) { + if (nrows == 0) { + memset(res, 0, 32 * sizeof(int32_t)); + return; + } + const int32_t* bb = b; + const int32_t* pref_bb = b; + const uint64_t pref_iters = 128; + const uint64_t pref_start = pref_iters < nrows ? pref_iters : nrows; + const uint64_t pref_last = pref_iters > nrows ? 
0 : nrows - pref_iters; + // let's do some prefetching of the GSW key, since on some cpus, + // it helps + for (uint64_t i = 0; i < pref_start; ++i) { + __builtin_prefetch(pref_bb, 0, _MM_HINT_T0); + __builtin_prefetch(pref_bb + 16, 0, _MM_HINT_T0); + pref_bb += 32; + } + // we do the first iteration + __m256i x = _mm256_set1_epi32(a[0]); + __m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))); + __m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))); + __m256i r2 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))); + __m256i r3 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24))); + bb += 32; + uint64_t row = 1; + for (; // + row < pref_last; // + ++row, bb += 32) { + // prefetch the next iteration + __builtin_prefetch(pref_bb, 0, _MM_HINT_T0); + __builtin_prefetch(pref_bb + 16, 0, _MM_HINT_T0); + pref_bb += 32; + INTTYPE ai = a[row]; + if (ai == 0) continue; + x = _mm256_set1_epi32(ai); + r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)))); + r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)))); + r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)))); + r3 = _mm256_add_epi32(r3, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24)))); + } + for (; // + row < nrows; // + ++row, bb += 32) { + INTTYPE ai = a[row]; + if (ai == 0) continue; + x = _mm256_set1_epi32(ai); + r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)))); + r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)))); + r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)))); + r3 = _mm256_add_epi32(r3, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24)))); + } + _mm256_storeu_si256((__m256i*)(res), r0); + _mm256_storeu_si256((__m256i*)(res + 8), r1); + _mm256_storeu_si256((__m256i*)(res + 16), r2); + _mm256_storeu_si256((__m256i*)(res + 24), r3); +} + +void zn32_vec_fn(mat32cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + if (nrows == 0) { + memset(res, 0, 32 * sizeof(int32_t)); + return; + } + const INTTYPE* aa = a; + const INTTYPE* const aaend = a + nrows; + const int32_t* bb = b; + __m256i x = _mm256_set1_epi32(*aa); + __m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))); + __m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))); + __m256i r2 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))); + __m256i r3 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24))); + bb += b_sl; + ++aa; + for (; // + aa < aaend; // + bb += b_sl, ++aa) { + INTTYPE ai = *aa; + if (ai == 0) continue; + x = _mm256_set1_epi32(ai); + r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)))); + r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)))); + r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)))); + r3 = _mm256_add_epi32(r3, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24)))); + } + _mm256_storeu_si256((__m256i*)(res), r0); + _mm256_storeu_si256((__m256i*)(res + 8), r1); + _mm256_storeu_si256((__m256i*)(res + 16), r2); + _mm256_storeu_si256((__m256i*)(res + 24), r3); +} + +void zn32_vec_fn(mat24cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + if (nrows == 0) { + memset(res, 0, 24 * sizeof(int32_t)); + return; 
+ } + const INTTYPE* aa = a; + const INTTYPE* const aaend = a + nrows; + const int32_t* bb = b; + __m256i x = _mm256_set1_epi32(*aa); + __m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))); + __m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))); + __m256i r2 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))); + bb += b_sl; + ++aa; + for (; // + aa < aaend; // + bb += b_sl, ++aa) { + INTTYPE ai = *aa; + if (ai == 0) continue; + x = _mm256_set1_epi32(ai); + r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)))); + r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)))); + r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)))); + } + _mm256_storeu_si256((__m256i*)(res), r0); + _mm256_storeu_si256((__m256i*)(res + 8), r1); + _mm256_storeu_si256((__m256i*)(res + 16), r2); +} +void zn32_vec_fn(mat16cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + if (nrows == 0) { + memset(res, 0, 16 * sizeof(int32_t)); + return; + } + const INTTYPE* aa = a; + const INTTYPE* const aaend = a + nrows; + const int32_t* bb = b; + __m256i x = _mm256_set1_epi32(*aa); + __m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))); + __m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))); + bb += b_sl; + ++aa; + for (; // + aa < aaend; // + bb += b_sl, ++aa) { + INTTYPE ai = *aa; + if (ai == 0) continue; + x = _mm256_set1_epi32(ai); + r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)))); + r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)))); + } + _mm256_storeu_si256((__m256i*)(res), r0); + _mm256_storeu_si256((__m256i*)(res + 8), r1); +} + +void zn32_vec_fn(mat8cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + if (nrows == 0) { + memset(res, 0, 8 * sizeof(int32_t)); + return; + } + const INTTYPE* aa = a; + const INTTYPE* const aaend = a + nrows; + const int32_t* bb = b; + __m256i x = _mm256_set1_epi32(*aa); + __m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))); + bb += b_sl; + ++aa; + for (; // + aa < aaend; // + bb += b_sl, ++aa) { + INTTYPE ai = *aa; + if (ai == 0) continue; + x = _mm256_set1_epi32(ai); + r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)))); + } + _mm256_storeu_si256((__m256i*)(res), r0); +} + +typedef void (*vm_f)(uint64_t nrows, // + int32_t* res, // + const INTTYPE* a, // + const int32_t* b, uint64_t b_sl // +); +static const vm_f zn32_vec_mat8kcols_avx[4] = { // + zn32_vec_fn(mat8cols_avx), // + zn32_vec_fn(mat16cols_avx), // + zn32_vec_fn(mat24cols_avx), // + zn32_vec_fn(mat32cols_avx)}; + +/** @brief applies a vmp product (int32_t* input) */ +EXPORT void concat(default_zn32_vmp_apply, INTSN, avx)( // + const MOD_Z* module, // + int32_t* res, uint64_t res_size, // + const INTTYPE* a, uint64_t a_size, // + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) { + const uint64_t rows = a_size < nrows ? a_size : nrows; + const uint64_t cols = res_size < ncols ? 
res_size : ncols; + const uint64_t ncolblk = cols >> 5; + const uint64_t ncolrem = cols & 31; + // copy the first full blocks + const uint64_t full_blk_size = nrows * 32; + const int32_t* mat = (int32_t*)pmat; + int32_t* rr = res; + for (uint64_t blk = 0; // + blk < ncolblk; // + ++blk, mat += full_blk_size, rr += 32) { + zn32_vec_mat32cols_avx_prefetch(rows, rr, a, mat); + } + // last block + if (ncolrem) { + uint64_t orig_rem = ncols - (ncolblk << 5); + uint64_t b_sl = orig_rem >= 32 ? 32 : orig_rem; + int32_t tmp[32]; + zn32_vec_mat8kcols_avx[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl); + memcpy(rr, tmp, ncolrem * sizeof(int32_t)); + } + // trailing bytes + memset(res + cols, 0, (res_size - cols) * sizeof(int32_t)); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_vmp_int32_ref.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_vmp_int32_ref.c new file mode 100644 index 0000000..c3d0bc9 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_vmp_int32_ref.c @@ -0,0 +1,88 @@ +// This file is actually a template: it will be compiled multiple times with +// different INTTYPES +#ifndef INTTYPE +#define INTTYPE int32_t +#define INTSN i32 +#endif + +#include + +#include "zn_arithmetic_private.h" + +#define concat_inner(aa, bb, cc) aa##_##bb##_##cc +#define concat(aa, bb, cc) concat_inner(aa, bb, cc) +#define zn32_vec_fn(cc) concat(zn32_vec, INTSN, cc) + +// the ref version shares the same implementation for each fixed column size +// optimized implementations may do something different. +static __always_inline void IMPL_zn32_vec_matcols_ref( + const uint64_t NCOLS, // fixed number of columns + uint64_t nrows, // nrows of b + int32_t* res, // result: size NCOLS, only the first min(b_sl, NCOLS) are relevant + const INTTYPE* a, // a: nrows-sized vector + const int32_t* b, uint64_t b_sl // b: nrows * min(b_sl, NCOLS) matrix +) { + memset(res, 0, NCOLS * sizeof(int32_t)); + for (uint64_t row = 0; row < nrows; ++row) { + int32_t ai = a[row]; + const int32_t* bb = b + row * b_sl; + for (uint64_t i = 0; i < NCOLS; ++i) { + res[i] += ai * bb[i]; + } + } +} + +void zn32_vec_fn(mat32cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_matcols_ref(32, nrows, res, a, b, b_sl); +} +void zn32_vec_fn(mat24cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_matcols_ref(24, nrows, res, a, b, b_sl); +} +void zn32_vec_fn(mat16cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_matcols_ref(16, nrows, res, a, b, b_sl); +} +void zn32_vec_fn(mat8cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_matcols_ref(8, nrows, res, a, b, b_sl); +} + +typedef void (*vm_f)(uint64_t nrows, // + int32_t* res, // + const INTTYPE* a, // + const int32_t* b, uint64_t b_sl // +); +static const vm_f zn32_vec_mat8kcols_ref[4] = { // + zn32_vec_fn(mat8cols_ref), // + zn32_vec_fn(mat16cols_ref), // + zn32_vec_fn(mat24cols_ref), // + zn32_vec_fn(mat32cols_ref)}; + +/** @brief applies a vmp product (int32_t* input) */ +EXPORT void concat(default_zn32_vmp_apply, INTSN, ref)( // + const MOD_Z* module, // + int32_t* res, uint64_t res_size, // + const INTTYPE* a, uint64_t a_size, // + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) { + const uint64_t rows = a_size < nrows ? 
a_size : nrows; + const uint64_t cols = res_size < ncols ? res_size : ncols; + const uint64_t ncolblk = cols >> 5; + const uint64_t ncolrem = cols & 31; + // copy the first full blocks + const uint32_t full_blk_size = nrows * 32; + const int32_t* mat = (int32_t*)pmat; + int32_t* rr = res; + for (uint64_t blk = 0; // + blk < ncolblk; // + ++blk, mat += full_blk_size, rr += 32) { + zn32_vec_fn(mat32cols_ref)(rows, rr, a, mat, 32); + } + // last block + if (ncolrem) { + uint64_t orig_rem = ncols - (ncolblk << 5); + uint64_t b_sl = orig_rem >= 32 ? 32 : orig_rem; + int32_t tmp[32]; + zn32_vec_mat8kcols_ref[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl); + memcpy(rr, tmp, ncolrem * sizeof(int32_t)); + } + // trailing bytes + memset(res + cols, 0, (res_size - cols) * sizeof(int32_t)); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_vmp_int8_avx.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_vmp_int8_avx.c new file mode 100644 index 0000000..74480aa --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_vmp_int8_avx.c @@ -0,0 +1,4 @@ +#define INTTYPE int8_t +#define INTSN i8 + +#include "zn_vmp_int32_avx.c" diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_vmp_int8_ref.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_vmp_int8_ref.c new file mode 100644 index 0000000..d1de571 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_vmp_int8_ref.c @@ -0,0 +1,4 @@ +#define INTTYPE int8_t +#define INTSN i8 + +#include "zn_vmp_int32_ref.c" diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_vmp_ref.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_vmp_ref.c new file mode 100644 index 0000000..dd0b527 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/zn_vmp_ref.c @@ -0,0 +1,185 @@ +#include + +#include "zn_arithmetic_private.h" + +/** @brief size in bytes of a prepared matrix (for custom allocation) */ +EXPORT uint64_t bytes_of_zn32_vmp_pmat(const MOD_Z* module, // N + uint64_t nrows, uint64_t ncols // dimensions +) { + return (nrows * ncols + 7) * sizeof(int32_t); +} + +/** @brief allocates a prepared matrix (release with delete_zn32_vmp_pmat) */ +EXPORT ZN32_VMP_PMAT* new_zn32_vmp_pmat(const MOD_Z* module, // N + uint64_t nrows, uint64_t ncols) { + return (ZN32_VMP_PMAT*)spqlios_alloc(bytes_of_zn32_vmp_pmat(module, nrows, ncols)); +} + +/** @brief deletes a prepared matrix (release with free) */ +EXPORT void delete_zn32_vmp_pmat(ZN32_VMP_PMAT* ptr) { spqlios_free(ptr); } + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void default_zn32_vmp_prepare_contiguous_ref( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t* mat, uint64_t nrows, uint64_t ncols // a +) { + int32_t* const out = (int32_t*)pmat; + const uint64_t nblk = ncols >> 5; + const uint64_t ncols_rem = ncols & 31; + const uint64_t final_elems = (8 - nrows * ncols) & 7; + for (uint64_t blk = 0; blk < nblk; ++blk) { + int32_t* outblk = out + blk * nrows * 32; + const int32_t* srcblk = mat + blk * 32; + for (uint64_t row = 0; row < nrows; ++row) { + int32_t* dest = outblk + row * 32; + const int32_t* src = srcblk + row * ncols; + for (uint64_t i = 0; i < 32; ++i) { + dest[i] = src[i]; + } + } + } + // copy the last block if any + if (ncols_rem) { + int32_t* outblk = out + nblk * nrows * 32; + const 
int32_t* srcblk = mat + nblk * 32; + for (uint64_t row = 0; row < nrows; ++row) { + int32_t* dest = outblk + row * ncols_rem; + const int32_t* src = srcblk + row * ncols; + for (uint64_t i = 0; i < ncols_rem; ++i) { + dest[i] = src[i]; + } + } + } + // zero-out the final elements that may be accessed + if (final_elems) { + int32_t* f = out + nrows * ncols; + for (uint64_t i = 0; i < final_elems; ++i) { + f[i] = 0; + } + } +} + +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void default_zn32_vmp_prepare_dblptr_ref( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t** mat, uint64_t nrows, uint64_t ncols // a +) { + for (uint64_t row_i = 0; row_i < nrows; ++row_i) { + default_zn32_vmp_prepare_row_ref(module, pmat, mat[row_i], row_i, nrows, ncols); + } +} + +/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */ +EXPORT void default_zn32_vmp_prepare_row_ref( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols // a +) { + int32_t* const out = (int32_t*)pmat; + const uint64_t nblk = ncols >> 5; + const uint64_t ncols_rem = ncols & 31; + const uint64_t final_elems = (row_i == nrows - 1) && (8 - nrows * ncols) & 7; + for (uint64_t blk = 0; blk < nblk; ++blk) { + int32_t* outblk = out + blk * nrows * 32; + int32_t* dest = outblk + row_i * 32; + const int32_t* src = row + blk * 32; + for (uint64_t i = 0; i < 32; ++i) { + dest[i] = src[i]; + } + } + // copy the last block if any + if (ncols_rem) { + int32_t* outblk = out + nblk * nrows * 32; + int32_t* dest = outblk + row_i * ncols_rem; + const int32_t* src = row + nblk * 32; + for (uint64_t i = 0; i < ncols_rem; ++i) { + dest[i] = src[i]; + } + } + // zero-out the final elements that may be accessed + if (final_elems) { + int32_t* f = out + nrows * ncols; + for (uint64_t i = 0; i < final_elems; ++i) { + f[i] = 0; + } + } +} + +#if 0 + +#define IMPL_zn32_vec_ixxx_matyyycols_ref(NCOLS) \ + memset(res, 0, NCOLS * sizeof(int32_t)); \ + for (uint64_t row = 0; row < nrows; ++row) { \ + int32_t ai = a[row]; \ + const int32_t* bb = b + row * b_sl; \ + for (uint64_t i = 0; i < NCOLS; ++i) { \ + res[i] += ai * bb[i]; \ + } \ + } + +#define IMPL_zn32_vec_ixxx_mat8cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(8) +#define IMPL_zn32_vec_ixxx_mat16cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(16) +#define IMPL_zn32_vec_ixxx_mat24cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(24) +#define IMPL_zn32_vec_ixxx_mat32cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(32) + +void zn32_vec_i8_mat32cols_ref(uint64_t nrows, int32_t* res, const int8_t* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_ixxx_mat32cols_ref() +} +void zn32_vec_i16_mat32cols_ref(uint64_t nrows, int32_t* res, const int16_t* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_ixxx_mat32cols_ref() +} + +void zn32_vec_i32_mat32cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_ixxx_mat32cols_ref() +} +void zn32_vec_i32_mat24cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_ixxx_mat24cols_ref() +} +void zn32_vec_i32_mat16cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_ixxx_mat16cols_ref() +} +void zn32_vec_i32_mat8cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_ixxx_mat8cols_ref() +} +typedef void 
(*zn32_vec_i32_mat8kcols_ref_f)(uint64_t nrows, // + int32_t* res, // + const int32_t* a, // + const int32_t* b, uint64_t b_sl // +); +zn32_vec_i32_mat8kcols_ref_f zn32_vec_i32_mat8kcols_ref[4] = { // + zn32_vec_i32_mat8cols_ref, zn32_vec_i32_mat16cols_ref, // + zn32_vec_i32_mat24cols_ref, zn32_vec_i32_mat32cols_ref}; + +/** @brief applies a vmp product (int32_t* input) */ +EXPORT void default_zn32_vmp_apply_i32_ref(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // + const int32_t* a, uint64_t a_size, // + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) { + const uint64_t rows = a_size < nrows ? a_size : nrows; + const uint64_t cols = res_size < ncols ? res_size : ncols; + const uint64_t ncolblk = cols >> 5; + const uint64_t ncolrem = cols & 31; + // copy the first full blocks + const uint32_t full_blk_size = nrows * 32; + const int32_t* mat = (int32_t*)pmat; + int32_t* rr = res; + for (uint64_t blk = 0; // + blk < ncolblk; // + ++blk, mat += full_blk_size, rr += 32) { + zn32_vec_i32_mat32cols_ref(rows, rr, a, mat, 32); + } + // last block + if (ncolrem) { + uint64_t orig_rem = ncols - (ncolblk << 5); + uint64_t b_sl = orig_rem >= 32 ? 32 : orig_rem; + int32_t tmp[32]; + zn32_vec_i32_mat8kcols_ref[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl); + memcpy(rr, tmp, ncolrem * sizeof(int32_t)); + } + // trailing bytes + memset(res + cols, 0, (res_size - cols) * sizeof(int32_t)); +} + +#endif diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/znx_small.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/znx_small.c new file mode 100644 index 0000000..24d3ef6 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/arithmetic/znx_small.c @@ -0,0 +1,38 @@ +#include "vec_znx_arithmetic_private.h" + +/** @brief res = a * b : small integer polynomial product */ +EXPORT void fft64_znx_small_single_product(const MODULE* module, // N + int64_t* res, // output + const int64_t* a, // a + const int64_t* b, // b + uint8_t* tmp) { + const uint64_t nn = module->nn; + double* const ffta = (double*)tmp; + double* const fftb = ((double*)tmp) + nn; + reim_from_znx64(module->mod.fft64.p_conv, ffta, a); + reim_from_znx64(module->mod.fft64.p_conv, fftb, b); + reim_fft(module->mod.fft64.p_fft, ffta); + reim_fft(module->mod.fft64.p_fft, fftb); + reim_fftvec_mul_simple(module->m, ffta, ffta, fftb); + reim_ifft(module->mod.fft64.p_ifft, ffta); + reim_to_znx64(module->mod.fft64.p_reim_to_znx, res, ffta); +} + +/** @brief tmp bytes required for znx_small_single_product */ +EXPORT uint64_t fft64_znx_small_single_product_tmp_bytes(const MODULE* module, uint64_t nn) { + return 2 * nn * sizeof(double); +} + +/** @brief res = a * b : small integer polynomial product */ +EXPORT void znx_small_single_product(const MODULE* module, // N + int64_t* res, // output + const int64_t* a, // a + const int64_t* b, // b + uint8_t* tmp) { + module->func.znx_small_single_product(module, res, a, b, tmp); +} + +/** @brief tmp bytes required for znx_small_single_product */ +EXPORT uint64_t znx_small_single_product_tmp_bytes(const MODULE* module, uint64_t nn) { + return module->func.znx_small_single_product_tmp_bytes(module, nn); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/coeffs/coeffs_arithmetic.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/coeffs/coeffs_arithmetic.c new file mode 100644 index 0000000..52d0bca --- /dev/null +++ 
b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/coeffs/coeffs_arithmetic.c @@ -0,0 +1,524 @@ +#include "coeffs_arithmetic.h" + +#include +#include + +/** res = a + b */ +EXPORT void znx_add_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) { + for (uint64_t i = 0; i < nn; ++i) { + res[i] = a[i] + b[i]; + } +} +/** res = a - b */ +EXPORT void znx_sub_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) { + for (uint64_t i = 0; i < nn; ++i) { + res[i] = a[i] - b[i]; + } +} + +EXPORT void znx_negate_i64_ref(uint64_t nn, int64_t* res, const int64_t* a) { + for (uint64_t i = 0; i < nn; ++i) { + res[i] = -a[i]; + } +} +EXPORT void znx_copy_i64_ref(uint64_t nn, int64_t* res, const int64_t* a) { memcpy(res, a, nn * sizeof(int64_t)); } + +EXPORT void znx_zero_i64_ref(uint64_t nn, int64_t* res) { memset(res, 0, nn * sizeof(int64_t)); } + +EXPORT void rnx_divide_by_m_ref(uint64_t n, double m, double* res, const double* a) { + const double invm = 1. / m; + for (uint64_t i = 0; i < n; ++i) { + res[i] = a[i] * invm; + } +} + +EXPORT void rnx_rotate_f64(uint64_t nn, int64_t p, double* res, const double* in) { + uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn) + + if (a < nn) { // rotate to the left + uint64_t nma = nn - a; + // rotate first half + for (uint64_t j = 0; j < nma; j++) { + res[j] = in[j + a]; + } + for (uint64_t j = nma; j < nn; j++) { + res[j] = -in[j - nma]; + } + } else { + a -= nn; + uint64_t nma = nn - a; + for (uint64_t j = 0; j < nma; j++) { + res[j] = -in[j + a]; + } + for (uint64_t j = nma; j < nn; j++) { + // rotate first half + res[j] = in[j - nma]; + } + } +} + +EXPORT void znx_rotate_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) { + uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn) + + if (a < nn) { // rotate to the left + uint64_t nma = nn - a; + // rotate first half + for (uint64_t j = 0; j < nma; j++) { + res[j] = in[j + a]; + } + for (uint64_t j = nma; j < nn; j++) { + res[j] = -in[j - nma]; + } + } else { + a -= nn; + uint64_t nma = nn - a; + for (uint64_t j = 0; j < nma; j++) { + res[j] = -in[j + a]; + } + for (uint64_t j = nma; j < nn; j++) { + // rotate first half + res[j] = in[j - nma]; + } + } +} + +EXPORT void rnx_mul_xp_minus_one_f64(uint64_t nn, int64_t p, double* res, const double* in) { + uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn) + if (a < nn) { // rotate to the left + uint64_t nma = nn - a; + // rotate first half + for (uint64_t j = 0; j < nma; j++) { + res[j] = in[j + a] - in[j]; + } + for (uint64_t j = nma; j < nn; j++) { + res[j] = -in[j - nma] - in[j]; + } + } else { + a -= nn; + uint64_t nma = nn - a; + for (uint64_t j = 0; j < nma; j++) { + res[j] = -in[j + a] - in[j]; + } + for (uint64_t j = nma; j < nn; j++) { + // rotate first half + res[j] = in[j - nma] - in[j]; + } + } +} + +EXPORT void znx_mul_xp_minus_one_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) { + uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn) + if (a < nn) { // rotate to the left + uint64_t nma = nn - a; + // rotate first half + for (uint64_t j = 0; j < nma; j++) { + res[j] = in[j + a] - in[j]; + } + for (uint64_t j = nma; j < nn; j++) { + res[j] = -in[j - nma] - in[j]; + } + } else { + a -= nn; + uint64_t nma = nn - a; + for (uint64_t j = 0; j < nma; j++) { + res[j] = -in[j + a] - in[j]; + } + for (uint64_t j = nma; j < nn; j++) { + // rotate first half + res[j] = in[j - nma] - in[j]; + } + } +} + +EXPORT void znx_mul_xp_minus_one_inplace_i64(uint64_t nn, 
int64_t p, int64_t* res) { + const uint64_t _2mn = 2 * nn - 1; + const uint64_t _mn = nn - 1; + uint64_t nb_modif = 0; + uint64_t j_start = 0; + while (nb_modif < nn) { + // follow the cycle that start with j_start + uint64_t j = j_start; + int64_t tmp1 = res[j]; + do { + // find where the value should go, and with which sign + uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign + uint64_t new_j_n = new_j & _mn; // mod n to get just the position + // exchange this position with tmp1 (and take care of the sign) + int64_t tmp2 = res[new_j_n]; + res[new_j_n] = ((new_j < nn) ? tmp1 : -tmp1) - res[new_j_n]; + tmp1 = tmp2; + // move to the new location, and store the number of items modified + ++nb_modif; + j = new_j_n; + } while (j != j_start); + // move to the start of the next cycle: + // we need to find an index that has not been touched yet, and pick it as next j_start. + // in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator. + ++j_start; + } +} + +// 0 < p < 2nn +EXPORT void rnx_automorphism_f64(uint64_t nn, int64_t p, double* res, const double* in) { + res[0] = in[0]; + uint64_t a = 0; + uint64_t _2mn = 2 * nn - 1; + for (uint64_t i = 1; i < nn; i++) { + a = (a + p) & _2mn; // i*p mod 2n + if (a < nn) { + res[a] = in[i]; // res[ip mod 2n] = res[i] + } else { + res[a - nn] = -in[i]; + } + } +} + +EXPORT void znx_automorphism_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) { + res[0] = in[0]; + uint64_t a = 0; + uint64_t _2mn = 2 * nn - 1; + for (uint64_t i = 1; i < nn; i++) { + a = (a + p) & _2mn; + if (a < nn) { + res[a] = in[i]; // res[ip mod 2n] = res[i] + } else { + res[a - nn] = -in[i]; + } + } +} + +EXPORT void rnx_rotate_inplace_f64(uint64_t nn, int64_t p, double* res) { + const uint64_t _2mn = 2 * nn - 1; + const uint64_t _mn = nn - 1; + uint64_t nb_modif = 0; + uint64_t j_start = 0; + while (nb_modif < nn) { + // follow the cycle that start with j_start + uint64_t j = j_start; + double tmp1 = res[j]; + do { + // find where the value should go, and with which sign + uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign + uint64_t new_j_n = new_j & _mn; // mod n to get just the position + // exchange this position with tmp1 (and take care of the sign) + double tmp2 = res[new_j_n]; + res[new_j_n] = (new_j < nn) ? tmp1 : -tmp1; + tmp1 = tmp2; + // move to the new location, and store the number of items modified + ++nb_modif; + j = new_j_n; + } while (j != j_start); + // move to the start of the next cycle: + // we need to find an index that has not been touched yet, and pick it as next j_start. + // in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator. + ++j_start; + } +} + +EXPORT void znx_rotate_inplace_i64(uint64_t nn, int64_t p, int64_t* res) { + const uint64_t _2mn = 2 * nn - 1; + const uint64_t _mn = nn - 1; + uint64_t nb_modif = 0; + uint64_t j_start = 0; + while (nb_modif < nn) { + // follow the cycle that start with j_start + uint64_t j = j_start; + int64_t tmp1 = res[j]; + do { + // find where the value should go, and with which sign + uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign + uint64_t new_j_n = new_j & _mn; // mod n to get just the position + // exchange this position with tmp1 (and take care of the sign) + int64_t tmp2 = res[new_j_n]; + res[new_j_n] = (new_j < nn) ? 
tmp1 : -tmp1; + tmp1 = tmp2; + // move to the new location, and store the number of items modified + ++nb_modif; + j = new_j_n; + } while (j != j_start); + // move to the start of the next cycle: + // we need to find an index that has not been touched yet, and pick it as next j_start. + // in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator. + ++j_start; + } +} + +EXPORT void rnx_mul_xp_minus_one_inplace_f64(uint64_t nn, int64_t p, double* res) { + const uint64_t _2mn = 2 * nn - 1; + const uint64_t _mn = nn - 1; + uint64_t nb_modif = 0; + uint64_t j_start = 0; + while (nb_modif < nn) { + // follow the cycle that start with j_start + uint64_t j = j_start; + double tmp1 = res[j]; + do { + // find where the value should go, and with which sign + uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign + uint64_t new_j_n = new_j & _mn; // mod n to get just the position + // exchange this position with tmp1 (and take care of the sign) + double tmp2 = res[new_j_n]; + res[new_j_n] = ((new_j < nn) ? tmp1 : -tmp1) - res[new_j_n]; + tmp1 = tmp2; + // move to the new location, and store the number of items modified + ++nb_modif; + j = new_j_n; + } while (j != j_start); + // move to the start of the next cycle: + // we need to find an index that has not been touched yet, and pick it as next j_start. + // in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator. + ++j_start; + } +} + +__always_inline int64_t get_base_k_digit(const int64_t x, const uint64_t base_k) { + return (x << (64 - base_k)) >> (64 - base_k); +} + +__always_inline int64_t get_base_k_carry(const int64_t x, const int64_t digit, const uint64_t base_k) { + return (x - digit) >> base_k; +} + +EXPORT void znx_normalize(uint64_t nn, uint64_t base_k, int64_t* out, int64_t* carry_out, const int64_t* in, + const int64_t* carry_in) { + assert(in); + if (out != 0) { + if (carry_in != 0x0 && carry_out != 0x0) { + // with carry in and carry out is computed + for (uint64_t i = 0; i < nn; ++i) { + const int64_t x = in[i]; + const int64_t cin = carry_in[i]; + + int64_t digit = get_base_k_digit(x, base_k); + int64_t carry = get_base_k_carry(x, digit, base_k); + int64_t digit_plus_cin = digit + cin; + int64_t y = get_base_k_digit(digit_plus_cin, base_k); + int64_t cout = carry + get_base_k_carry(digit_plus_cin, y, base_k); + + out[i] = y; + carry_out[i] = cout; + } + } else if (carry_in != 0) { + // with carry in and carry out is dropped + for (uint64_t i = 0; i < nn; ++i) { + const int64_t x = in[i]; + const int64_t cin = carry_in[i]; + + int64_t digit = get_base_k_digit(x, base_k); + int64_t digit_plus_cin = digit + cin; + int64_t y = get_base_k_digit(digit_plus_cin, base_k); + + out[i] = y; + } + + } else if (carry_out != 0) { + // no carry in and carry out is computed + for (uint64_t i = 0; i < nn; ++i) { + const int64_t x = in[i]; + + int64_t y = get_base_k_digit(x, base_k); + int64_t cout = get_base_k_carry(x, y, base_k); + + out[i] = y; + carry_out[i] = cout; + } + + } else { + // no carry in and carry out is dropped + for (uint64_t i = 0; i < nn; ++i) { + out[i] = get_base_k_digit(in[i], base_k); + } + } + } else { + assert(carry_out); + if (carry_in != 0x0) { + // with carry in and carry out is computed + for (uint64_t i = 0; i < nn; ++i) { + const int64_t x = in[i]; + const int64_t cin = carry_in[i]; + + int64_t digit = get_base_k_digit(x, base_k); + int64_t carry = get_base_k_carry(x, digit, base_k); + int64_t digit_plus_cin = digit + cin; 
+ int64_t y = get_base_k_digit(digit_plus_cin, base_k); + int64_t cout = carry + get_base_k_carry(digit_plus_cin, y, base_k); + + carry_out[i] = cout; + } + } else { + // no carry in and carry out is computed + for (uint64_t i = 0; i < nn; ++i) { + const int64_t x = in[i]; + + int64_t y = get_base_k_digit(x, base_k); + int64_t cout = get_base_k_carry(x, y, base_k); + + carry_out[i] = cout; + } + } + } +} + +void znx_automorphism_inplace_i64(uint64_t nn, int64_t p, int64_t* res) { + const uint64_t _2mn = 2 * nn - 1; + const uint64_t _mn = nn - 1; + const uint64_t m = nn >> 1; + // reduce p mod 2n + p &= _2mn; + // uint64_t vp = p & _2mn; + /// uint64_t target_modifs = m >> 1; + // we proceed by increasing binary valuation + for (uint64_t binval = 1, vp = p & _2mn, orb_size = m; binval < nn; + binval <<= 1, vp = (vp << 1) & _2mn, orb_size >>= 1) { + // In this loop, we are going to treat the orbit of indexes = binval mod 2.binval. + // At the beginning of this loop we have: + // vp = binval * p mod 2n + // target_modif = m / binval (i.e. order of the orbit binval % 2.binval) + + // first, handle the orders 1 and 2. + // if p*binval == binval % 2n: we're done! + if (vp == binval) return; + // if p*binval == -binval % 2n: nega-mirror the orbit and all the sub-orbits and exit! + if (((vp + binval) & _2mn) == 0) { + for (uint64_t j = binval; j < m; j += binval) { + int64_t tmp = res[j]; + res[j] = -res[nn - j]; + res[nn - j] = -tmp; + } + res[m] = -res[m]; + return; + } + // if p*binval == binval + n % 2n: negate the orbit and exit + if (((vp - binval) & _mn) == 0) { + for (uint64_t j = binval; j < nn; j += 2 * binval) { + res[j] = -res[j]; + } + return; + } + // if p*binval == n - binval % 2n: mirror the orbit and continue! + if (((vp + binval) & _mn) == 0) { + for (uint64_t j = binval; j < m; j += 2 * binval) { + int64_t tmp = res[j]; + res[j] = res[nn - j]; + res[nn - j] = tmp; + } + continue; + } + // otherwise we will follow the orbit cycles, + // starting from binval and -binval in parallel + uint64_t j_start = binval; + uint64_t nb_modif = 0; + while (nb_modif < orb_size) { + // follow the cycle that start with j_start + uint64_t j = j_start; + int64_t tmp1 = res[j]; + int64_t tmp2 = res[nn - j]; + do { + // find where the value should go, and with which sign + uint64_t new_j = (j * p) & _2mn; // mod 2n to get the position and sign + uint64_t new_j_n = new_j & _mn; // mod n to get just the position + // exchange this position with tmp1 (and take care of the sign) + int64_t tmp1a = res[new_j_n]; + int64_t tmp2a = res[nn - new_j_n]; + if (new_j < nn) { + res[new_j_n] = tmp1; + res[nn - new_j_n] = tmp2; + } else { + res[new_j_n] = -tmp1; + res[nn - new_j_n] = -tmp2; + } + tmp1 = tmp1a; + tmp2 = tmp2a; + // move to the new location, and store the number of items modified + nb_modif += 2; + j = new_j_n; + } while (j != j_start); + // move to the start of the next cycle: + // we need to find an index that has not been touched yet, and pick it as next j_start. + // in practice, it is enough to do *5, because 5 is a generator. 
+ j_start = (5 * j_start) & _mn; + } + } +} + +void rnx_automorphism_inplace_f64(uint64_t nn, int64_t p, double* res) { + const uint64_t _2mn = 2 * nn - 1; + const uint64_t _mn = nn - 1; + const uint64_t m = nn >> 1; + // reduce p mod 2n + p &= _2mn; + // uint64_t vp = p & _2mn; + /// uint64_t target_modifs = m >> 1; + // we proceed by increasing binary valuation + for (uint64_t binval = 1, vp = p & _2mn, orb_size = m; binval < nn; + binval <<= 1, vp = (vp << 1) & _2mn, orb_size >>= 1) { + // In this loop, we are going to treat the orbit of indexes = binval mod 2.binval. + // At the beginning of this loop we have: + // vp = binval * p mod 2n + // target_modif = m / binval (i.e. order of the orbit binval % 2.binval) + + // first, handle the orders 1 and 2. + // if p*binval == binval % 2n: we're done! + if (vp == binval) return; + // if p*binval == -binval % 2n: nega-mirror the orbit and all the sub-orbits and exit! + if (((vp + binval) & _2mn) == 0) { + for (uint64_t j = binval; j < m; j += binval) { + double tmp = res[j]; + res[j] = -res[nn - j]; + res[nn - j] = -tmp; + } + res[m] = -res[m]; + return; + } + // if p*binval == binval + n % 2n: negate the orbit and exit + if (((vp - binval) & _mn) == 0) { + for (uint64_t j = binval; j < nn; j += 2 * binval) { + res[j] = -res[j]; + } + return; + } + // if p*binval == n - binval % 2n: mirror the orbit and continue! + if (((vp + binval) & _mn) == 0) { + for (uint64_t j = binval; j < m; j += 2 * binval) { + double tmp = res[j]; + res[j] = res[nn - j]; + res[nn - j] = tmp; + } + continue; + } + // otherwise we will follow the orbit cycles, + // starting from binval and -binval in parallel + uint64_t j_start = binval; + uint64_t nb_modif = 0; + while (nb_modif < orb_size) { + // follow the cycle that start with j_start + uint64_t j = j_start; + double tmp1 = res[j]; + double tmp2 = res[nn - j]; + do { + // find where the value should go, and with which sign + uint64_t new_j = (j * p) & _2mn; // mod 2n to get the position and sign + uint64_t new_j_n = new_j & _mn; // mod n to get just the position + // exchange this position with tmp1 (and take care of the sign) + double tmp1a = res[new_j_n]; + double tmp2a = res[nn - new_j_n]; + if (new_j < nn) { + res[new_j_n] = tmp1; + res[nn - new_j_n] = tmp2; + } else { + res[new_j_n] = -tmp1; + res[nn - new_j_n] = -tmp2; + } + tmp1 = tmp1a; + tmp2 = tmp2a; + // move to the new location, and store the number of items modified + nb_modif += 2; + j = new_j_n; + } while (j != j_start); + // move to the start of the next cycle: + // we need to find an index that has not been touched yet, and pick it as next j_start. + // in practice, it is enough to do *5, because 5 is a generator. 
+ j_start = (5 * j_start) & _mn; + } + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/coeffs/coeffs_arithmetic.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/coeffs/coeffs_arithmetic.h new file mode 100644 index 0000000..d6b9721 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/coeffs/coeffs_arithmetic.h @@ -0,0 +1,79 @@ +#ifndef SPQLIOS_COEFFS_ARITHMETIC_H +#define SPQLIOS_COEFFS_ARITHMETIC_H + +#include "../commons.h" + +/** res = a + b */ +EXPORT void znx_add_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b); +EXPORT void znx_add_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b); +/** res = a - b */ +EXPORT void znx_sub_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b); +EXPORT void znx_sub_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b); +/** res = -a */ +EXPORT void znx_negate_i64_ref(uint64_t nn, int64_t* res, const int64_t* a); +EXPORT void znx_negate_i64_avx(uint64_t nn, int64_t* res, const int64_t* a); +/** res = a */ +EXPORT void znx_copy_i64_ref(uint64_t nn, int64_t* res, const int64_t* a); +/** res = 0 */ +EXPORT void znx_zero_i64_ref(uint64_t nn, int64_t* res); + +/** res = a / m where m is a power of 2 */ +EXPORT void rnx_divide_by_m_ref(uint64_t nn, double m, double* res, const double* a); +EXPORT void rnx_divide_by_m_avx(uint64_t nn, double m, double* res, const double* a); + +/** + * @param res = X^p *in mod X^nn +1 + * @param nn the ring dimension + * @param p a power for the rotation -2nn <= p <= 2nn + * @param in is a rnx/znx vector of dimension nn + */ +EXPORT void rnx_rotate_f64(uint64_t nn, int64_t p, double* res, const double* in); +EXPORT void znx_rotate_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in); +EXPORT void rnx_rotate_inplace_f64(uint64_t nn, int64_t p, double* res); +EXPORT void znx_rotate_inplace_i64(uint64_t nn, int64_t p, int64_t* res); + +/** + * @brief res(X) = in(X^p) + * @param nn the ring dimension + * @param p is odd integer and must be between 0 < p < 2nn + * @param in is a rnx/znx vector of dimension nn + */ +EXPORT void rnx_automorphism_f64(uint64_t nn, int64_t p, double* res, const double* in); +EXPORT void znx_automorphism_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in); +EXPORT void rnx_automorphism_inplace_f64(uint64_t nn, int64_t p, double* res); +EXPORT void znx_automorphism_inplace_i64(uint64_t nn, int64_t p, int64_t* res); + +/** + * @brief res = (X^p-1).in + * @param nn the ring dimension + * @param p must be between -2nn <= p <= 2nn + * @param in is a rnx/znx vector of dimension nn + */ +EXPORT void rnx_mul_xp_minus_one_f64(uint64_t nn, int64_t p, double* res, const double* in); +EXPORT void znx_mul_xp_minus_one_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in); +EXPORT void rnx_mul_xp_minus_one_inplace_f64(uint64_t nn, int64_t p, double* res); +EXPORT void znx_mul_xp_minus_one_inplace_i64(uint64_t nn, int64_t p, int64_t* res); + +/** + * @brief Normalize input plus carry mod-2^k. The following + * equality holds @c {in + carry_in == out + carry_out . 2^k}. + * + * @c in must be in [-2^62 .. 2^62] + * + * @c out is in [ -2^(base_k-1), 2^(base_k-1) [. + * + * @c carry_in and @carry_out have at most 64+1-k bits. + * + * Null @c carry_in or @c carry_out are ignored. 
+ * + * @param[in] nn the ring dimension + * @param[in] base_k the base k + * @param out output normalized znx + * @param carry_out output carry znx + * @param[in] in input znx + * @param[in] carry_in input carry znx + */ +EXPORT void znx_normalize(uint64_t nn, uint64_t base_k, int64_t* out, int64_t* carry_out, const int64_t* in, + const int64_t* carry_in); + +#endif // SPQLIOS_COEFFS_ARITHMETIC_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/coeffs/coeffs_arithmetic_avx.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/coeffs/coeffs_arithmetic_avx.c new file mode 100644 index 0000000..8fcd608 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/coeffs/coeffs_arithmetic_avx.c @@ -0,0 +1,124 @@ +#include + +#include "../commons_private.h" +#include "coeffs_arithmetic.h" + +// res = a + b. dimension n must be a power of 2 +EXPORT void znx_add_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) { + if (nn <= 2) { + if (nn == 1) { + res[0] = a[0] + b[0]; + } else { + _mm_storeu_si128((__m128i*)res, // + _mm_add_epi64( // + _mm_loadu_si128((__m128i*)a), // + _mm_loadu_si128((__m128i*)b))); + } + } else { + const __m256i* aa = (__m256i*)a; + const __m256i* bb = (__m256i*)b; + __m256i* rr = (__m256i*)res; + __m256i* const rrend = (__m256i*)(res + nn); + do { + _mm256_storeu_si256(rr, // + _mm256_add_epi64( // + _mm256_loadu_si256(aa), // + _mm256_loadu_si256(bb))); + ++rr; + ++aa; + ++bb; + } while (rr < rrend); + } +} + +// res = a - b. dimension n must be a power of 2 +EXPORT void znx_sub_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) { + if (nn <= 2) { + if (nn == 1) { + res[0] = a[0] - b[0]; + } else { + _mm_storeu_si128((__m128i*)res, // + _mm_sub_epi64( // + _mm_loadu_si128((__m128i*)a), // + _mm_loadu_si128((__m128i*)b))); + } + } else { + const __m256i* aa = (__m256i*)a; + const __m256i* bb = (__m256i*)b; + __m256i* rr = (__m256i*)res; + __m256i* const rrend = (__m256i*)(res + nn); + do { + _mm256_storeu_si256(rr, // + _mm256_sub_epi64( // + _mm256_loadu_si256(aa), // + _mm256_loadu_si256(bb))); + ++rr; + ++aa; + ++bb; + } while (rr < rrend); + } +} + +EXPORT void znx_negate_i64_avx(uint64_t nn, int64_t* res, const int64_t* a) { + if (nn <= 2) { + if (nn == 1) { + res[0] = -a[0]; + } else { + _mm_storeu_si128((__m128i*)res, // + _mm_sub_epi64( // + _mm_set1_epi64x(0), // + _mm_loadu_si128((__m128i*)a))); + } + } else { + const __m256i* aa = (__m256i*)a; + __m256i* rr = (__m256i*)res; + __m256i* const rrend = (__m256i*)(res + nn); + do { + _mm256_storeu_si256(rr, // + _mm256_sub_epi64( // + _mm256_set1_epi64x(0), // + _mm256_loadu_si256(aa))); + ++rr; + ++aa; + } while (rr < rrend); + } +} + +EXPORT void rnx_divide_by_m_avx(uint64_t n, double m, double* res, const double* a) { + // TODO: see if there is a faster way of dividing by a power of 2? + const double invm = 1. 
/ m; + if (n < 8) { + switch (n) { + case 1: + *res = *a * invm; + break; + case 2: + _mm_storeu_pd(res, // + _mm_mul_pd(_mm_loadu_pd(a), // + _mm_set1_pd(invm))); + break; + case 4: + _mm256_storeu_pd(res, // + _mm256_mul_pd(_mm256_loadu_pd(a), // + _mm256_set1_pd(invm))); + break; + default: + NOT_SUPPORTED(); // non-power of 2 + } + return; + } + const __m256d invm256 = _mm256_set1_pd(invm); + double* rr = res; + const double* aa = a; + const double* const aaend = a + n; + do { + _mm256_storeu_pd(rr, // + _mm256_mul_pd(_mm256_loadu_pd(aa), // + invm256)); + _mm256_storeu_pd(rr + 4, // + _mm256_mul_pd(_mm256_loadu_pd(aa + 4), // + invm256)); + rr += 8; + aa += 8; + } while (aa < aaend); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/commons.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/commons.c new file mode 100644 index 0000000..adcff79 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/commons.c @@ -0,0 +1,165 @@ +#include "commons.h" + +#include +#include +#include + +EXPORT void* UNDEFINED_p_ii(int32_t n, int32_t m) { UNDEFINED(); } +EXPORT void* UNDEFINED_p_uu(uint32_t n, uint32_t m) { UNDEFINED(); } +EXPORT double* UNDEFINED_dp_pi(const void* p, int32_t n) { UNDEFINED(); } +EXPORT void* UNDEFINED_vp_pi(const void* p, int32_t n) { UNDEFINED(); } +EXPORT void* UNDEFINED_vp_pu(const void* p, uint32_t n) { UNDEFINED(); } +EXPORT void UNDEFINED_v_vpdp(const void* p, double* a) { UNDEFINED(); } +EXPORT void UNDEFINED_v_vpvp(const void* p, void* a) { UNDEFINED(); } +EXPORT double* NOT_IMPLEMENTED_dp_i(int32_t n) { NOT_IMPLEMENTED(); } +EXPORT void* NOT_IMPLEMENTED_vp_i(int32_t n) { NOT_IMPLEMENTED(); } +EXPORT void* NOT_IMPLEMENTED_vp_u(uint32_t n) { NOT_IMPLEMENTED(); } +EXPORT void NOT_IMPLEMENTED_v_dp(double* a) { NOT_IMPLEMENTED(); } +EXPORT void NOT_IMPLEMENTED_v_vp(void* p) { NOT_IMPLEMENTED(); } +EXPORT void NOT_IMPLEMENTED_v_idpdpdp(int32_t n, double* a, const double* b, const double* c) { NOT_IMPLEMENTED(); } +EXPORT void NOT_IMPLEMENTED_v_uvpcvpcvp(uint32_t n, void* r, const void* a, const void* b) { NOT_IMPLEMENTED(); } +EXPORT void NOT_IMPLEMENTED_v_uvpvpcvp(uint32_t n, void* a, void* b, const void* o) { NOT_IMPLEMENTED(); } + +#ifdef _WIN32 +#define __always_inline inline __attribute((always_inline)) +#endif + +void internal_accurate_sincos(double* rcos, double* rsin, double x) { + double _4_x_over_pi = 4 * x / M_PI; + int64_t int_part = ((int64_t)rint(_4_x_over_pi)) & 7; + double frac_part = _4_x_over_pi - (double)(int_part); + double frac_x = M_PI * frac_part / 4.; + // compute the taylor series + double cosp = 1.; + double sinp = 0.; + double powx = 1.; + int64_t nn = 0; + while (fabs(powx) > 1e-20) { + ++nn; + powx = powx * frac_x / (double)(nn); // x^n/n! 
+ switch (nn & 3) { + case 0: + cosp += powx; + break; + case 1: + sinp += powx; + break; + case 2: + cosp -= powx; + break; + case 3: + sinp -= powx; + break; + default: + abort(); // impossible + } + } + // final multiplication + switch (int_part) { + case 0: + *rcos = cosp; + *rsin = sinp; + break; + case 1: + *rcos = M_SQRT1_2 * (cosp - sinp); + *rsin = M_SQRT1_2 * (cosp + sinp); + break; + case 2: + *rcos = -sinp; + *rsin = cosp; + break; + case 3: + *rcos = -M_SQRT1_2 * (cosp + sinp); + *rsin = M_SQRT1_2 * (cosp - sinp); + break; + case 4: + *rcos = -cosp; + *rsin = -sinp; + break; + case 5: + *rcos = -M_SQRT1_2 * (cosp - sinp); + *rsin = -M_SQRT1_2 * (cosp + sinp); + break; + case 6: + *rcos = sinp; + *rsin = -cosp; + break; + case 7: + *rcos = M_SQRT1_2 * (cosp + sinp); + *rsin = -M_SQRT1_2 * (cosp - sinp); + break; + default: + abort(); // impossible + } + if (fabs(cos(x) - *rcos) > 1e-10 || fabs(sin(x) - *rsin) > 1e-10) { + printf("cos(%.17lf) =? %.17lf instead of %.17lf\n", x, *rcos, cos(x)); + printf("sin(%.17lf) =? %.17lf instead of %.17lf\n", x, *rsin, sin(x)); + printf("fracx = %.17lf\n", frac_x); + printf("cosp = %.17lf\n", cosp); + printf("sinp = %.17lf\n", sinp); + printf("nn = %d\n", (int)(nn)); + } +} + +double internal_accurate_cos(double x) { + double rcos, rsin; + internal_accurate_sincos(&rcos, &rsin, x); + return rcos; +} +double internal_accurate_sin(double x) { + double rcos, rsin; + internal_accurate_sincos(&rcos, &rsin, x); + return rsin; +} + +EXPORT void spqlios_debug_free(void* addr) { free((uint8_t*)addr - 64); } + +EXPORT void* spqlios_debug_alloc(uint64_t size) { return (uint8_t*)malloc(size + 64) + 64; } + +EXPORT void spqlios_free(void* addr) { +#ifndef NDEBUG + // in debug mode, we deallocated with spqlios_debug_free() + spqlios_debug_free(addr); +#else + // in release mode, the function will free aligned memory +#ifdef _WIN32 + _aligned_free(addr); +#else + free(addr); +#endif +#endif +} + +EXPORT void* spqlios_alloc(uint64_t size) { +#ifndef NDEBUG + // in debug mode, the function will not necessarily have any particular alignment + // it will also ensure that memory can only be deallocated with spqlios_free() + return spqlios_debug_alloc(size); +#else + // in release mode, the function will return 64-bytes aligned memory +#ifdef _WIN32 + void* reps = _aligned_malloc((size + 63) & (UINT64_C(-64)), 64); +#else + void* reps = aligned_alloc(64, (size + 63) & (UINT64_C(-64))); +#endif + if (reps == 0) FATAL_ERROR("Out of memory"); + return reps; +#endif +} + +EXPORT void* spqlios_alloc_custom_align(uint64_t align, uint64_t size) { +#ifndef NDEBUG + // in debug mode, the function will not necessarily have any particular alignment + // it will also ensure that memory can only be deallocated with spqlios_free() + return spqlios_debug_alloc(size); +#else + // in release mode, the function will return aligned memory +#ifdef _WIN32 + void* reps = _aligned_malloc(size, align); +#else + void* reps = aligned_alloc(align, size); +#endif + if (reps == 0) FATAL_ERROR("Out of memory"); + return reps; +#endif +} \ No newline at end of file diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/commons.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/commons.h new file mode 100644 index 0000000..653d083 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/commons.h @@ -0,0 +1,77 @@ +#ifndef SPQLIOS_COMMONS_H +#define SPQLIOS_COMMONS_H + +#ifdef __cplusplus +#include +#include +#include +#define EXPORT extern "C" 
+#define EXPORT_DECL extern "C" +#else +#include +#include +#include +#define EXPORT +#define EXPORT_DECL extern +#define nullptr 0x0; +#endif + +#define UNDEFINED() \ + { \ + fprintf(stderr, "UNDEFINED!!!\n"); \ + abort(); \ + } +#define NOT_IMPLEMENTED() \ + { \ + fprintf(stderr, "NOT IMPLEMENTED!!!\n"); \ + abort(); \ + } +#define FATAL_ERROR(MESSAGE) \ + { \ + fprintf(stderr, "ERROR: %s\n", (MESSAGE)); \ + abort(); \ + } + +EXPORT void* UNDEFINED_p_ii(int32_t n, int32_t m); +EXPORT void* UNDEFINED_p_uu(uint32_t n, uint32_t m); +EXPORT double* UNDEFINED_dp_pi(const void* p, int32_t n); +EXPORT void* UNDEFINED_vp_pi(const void* p, int32_t n); +EXPORT void* UNDEFINED_vp_pu(const void* p, uint32_t n); +EXPORT void UNDEFINED_v_vpdp(const void* p, double* a); +EXPORT void UNDEFINED_v_vpvp(const void* p, void* a); +EXPORT double* NOT_IMPLEMENTED_dp_i(int32_t n); +EXPORT void* NOT_IMPLEMENTED_vp_i(int32_t n); +EXPORT void* NOT_IMPLEMENTED_vp_u(uint32_t n); +EXPORT void NOT_IMPLEMENTED_v_dp(double* a); +EXPORT void NOT_IMPLEMENTED_v_vp(void* p); +EXPORT void NOT_IMPLEMENTED_v_idpdpdp(int32_t n, double* a, const double* b, const double* c); +EXPORT void NOT_IMPLEMENTED_v_uvpcvpcvp(uint32_t n, void* r, const void* a, const void* b); +EXPORT void NOT_IMPLEMENTED_v_uvpvpcvp(uint32_t n, void* a, void* b, const void* o); + +// windows + +#if defined(_WIN32) || defined(__APPLE__) +#define __always_inline inline __attribute((always_inline)) +#endif + +EXPORT void spqlios_free(void* address); + +EXPORT void* spqlios_alloc(uint64_t size); +EXPORT void* spqlios_alloc_custom_align(uint64_t align, uint64_t size); + +#define USE_LIBM_SIN_COS +#ifndef USE_LIBM_SIN_COS +// if at some point, we want to remove the libm dependency, we can +// consider this: +EXPORT double internal_accurate_cos(double x); +EXPORT double internal_accurate_sin(double x); +EXPORT void internal_accurate_sincos(double* rcos, double* rsin, double x); +#define m_accurate_cos internal_accurate_cos +#define m_accurate_sin internal_accurate_sin +#else +// let's use libm sin and cos +#define m_accurate_cos cos +#define m_accurate_sin sin +#endif + +#endif // SPQLIOS_COMMONS_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/commons_private.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/commons_private.c new file mode 100644 index 0000000..fddc190 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/commons_private.c @@ -0,0 +1,55 @@ +#include "commons_private.h" + +#include +#include + +#include "commons.h" + +EXPORT void* spqlios_error(const char* error) { + fputs(error, stderr); + abort(); + return nullptr; +} +EXPORT void* spqlios_keep_or_free(void* ptr, void* ptr2) { + if (!ptr2) { + free(ptr); + } + return ptr2; +} + +EXPORT uint32_t log2m(uint32_t m) { + uint32_t a = m - 1; + if (m & a) FATAL_ERROR("m must be a power of two"); + a = (a & 0x55555555u) + ((a >> 1) & 0x55555555u); + a = (a & 0x33333333u) + ((a >> 2) & 0x33333333u); + a = (a & 0x0F0F0F0Fu) + ((a >> 4) & 0x0F0F0F0Fu); + a = (a & 0x00FF00FFu) + ((a >> 8) & 0x00FF00FFu); + return (a & 0x0000FFFFu) + ((a >> 16) & 0x0000FFFFu); +} + +EXPORT uint64_t is_not_pow2_double(void* doublevalue) { return (*(uint64_t*)doublevalue) & 0x7FFFFFFFFFFFFUL; } + +uint32_t revbits(uint32_t nbits, uint32_t value) { + uint32_t res = 0; + for (uint32_t i = 0; i < nbits; ++i) { + res = (res << 1) + (value & 1); + value >>= 1; + } + return res; +} + +/** + * @brief this computes the sequence: 0,1/2,1/4,3/4,1/8,5/8,3/8,7/8,... 
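+ * (for example, the recursion below gives fracrevbits(5) = fracrevbits(2)/2 + 1/2 = 1/8 + 1/2 = 5/8)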
+ * essentially: the bits of (i+1) in lsb order on the basis (1/2^k) mod 1*/ +double fracrevbits(uint32_t i) { + if (i == 0) return 0; + if (i == 1) return 0.5; + if (i % 2 == 0) + return fracrevbits(i / 2) / 2.; + else + return fracrevbits((i - 1) / 2) / 2. + 0.5; +} + +uint64_t ceilto64b(uint64_t size) { return (size + UINT64_C(63)) & (UINT64_C(-64)); } + +uint64_t ceilto32b(uint64_t size) { return (size + UINT64_C(31)) & (UINT64_C(-32)); } diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/commons_private.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/commons_private.h new file mode 100644 index 0000000..e2b0514 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/commons_private.h @@ -0,0 +1,72 @@ +#ifndef SPQLIOS_COMMONS_PRIVATE_H +#define SPQLIOS_COMMONS_PRIVATE_H + +#include "commons.h" + +#ifdef __cplusplus +#include +#include +#include +#else +#include +#include +#include +#define nullptr 0x0; +#endif + +/** @brief log2 of a power of two (UB if m is not a power of two) */ +EXPORT uint32_t log2m(uint32_t m); + +/** @brief checks if the doublevalue is a power of two */ +EXPORT uint64_t is_not_pow2_double(void* doublevalue); + +#define UNDEFINED() \ + { \ + fprintf(stderr, "UNDEFINED!!!\n"); \ + abort(); \ + } +#define NOT_IMPLEMENTED() \ + { \ + fprintf(stderr, "NOT IMPLEMENTED!!!\n"); \ + abort(); \ + } +#define NOT_SUPPORTED() \ + { \ + fprintf(stderr, "NOT SUPPORTED!!!\n"); \ + abort(); \ + } +#define FATAL_ERROR(MESSAGE) \ + { \ + fprintf(stderr, "ERROR: %s\n", (MESSAGE)); \ + abort(); \ + } + +#define STATIC_ASSERT(condition) (void)sizeof(char[-1 + 2 * !!(condition)]) + +/** @brief reports the error and returns nullptr */ +EXPORT void* spqlios_error(const char* error); +/** @brief if ptr2 is not null, returns ptr, otherwise free ptr and return null */ +EXPORT void* spqlios_keep_or_free(void* ptr, void* ptr2); + +#ifdef __x86_64__ +#define CPU_SUPPORTS __builtin_cpu_supports +#else +// TODO for now, we do not have any optimization for non x86 targets +#define CPU_SUPPORTS(xxxx) 0 +#endif + +/** @brief returns the n bits of value in reversed order */ +EXPORT uint32_t revbits(uint32_t nbits, uint32_t value); + +/** + * @brief this computes the sequence: 0,1/2,1/4,3/4,1/8,5/8,3/8,7/8,... + * essentially: the bits of (i+1) in lsb order on the basis (1/2^k) mod 1*/ +EXPORT double fracrevbits(uint32_t i); + +/** @brief smallest multiple of 64 higher or equal to size */ +EXPORT uint64_t ceilto64b(uint64_t size); + +/** @brief smallest multiple of 32 higher or equal to size */ +EXPORT uint64_t ceilto32b(uint64_t size); + +#endif // SPQLIOS_COMMONS_PRIVATE_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/README.md b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/README.md new file mode 100644 index 0000000..443cfa4 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/README.md @@ -0,0 +1,22 @@ +In this folder, we deal with the full complex FFT in `C[X] mod X^M-i`. 
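+(As a concrete instance of the packing described below: a real polynomial with N = 8 coefficients,
+i.e. M = 4, is stored as the four complexes `(p_0 + i.p_4), (p_1 + i.p_5), (p_2 + i.p_6), (p_3 + i.p_7)`.)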
+One complex is represented by two consecutive doubles `(real,imag)` +Note that a real polynomial sum_{j=0}^{N-1} p_j.X^j mod X^N+1 +corresponds to the complex polynomial of half degree `M=N/2`: +`sum_{j=0}^{M-1} (p_{j} + i.p_{j+M}) X^j mod X^M-i` + +For a complex polynomial A(X) sum c_i X^i of degree M-1 +or a real polynomial sum a_i X^i of degree N + +coefficient space: +a_0,a_M,a_1,a_{M+1},...,a_{M-1},a_{2M-1} +or equivalently +Re(c_0),Im(c_0),Re(c_1),Im(c_1),...Re(c_{M-1}),Im(c_{M-1}) + +eval space: +c(omega_{0}),...,c(omega_{M-1}) + +where +omega_j = omega^{1+rev_{2N}(j)} +and omega = exp(i.pi/N) + +rev_{2N}(j) is the number that has the log2(2N) bits of j in reverse order. \ No newline at end of file diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_common.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_common.c new file mode 100644 index 0000000..5306068 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_common.c @@ -0,0 +1,80 @@ +#include "cplx_fft_internal.h" + +void cplx_set(CPLX r, const CPLX a) { + r[0] = a[0]; + r[1] = a[1]; +} +void cplx_neg(CPLX r, const CPLX a) { + r[0] = -a[0]; + r[1] = -a[1]; +} +void cplx_add(CPLX r, const CPLX a, const CPLX b) { + r[0] = a[0] + b[0]; + r[1] = a[1] + b[1]; +} +void cplx_sub(CPLX r, const CPLX a, const CPLX b) { + r[0] = a[0] - b[0]; + r[1] = a[1] - b[1]; +} +void cplx_mul(CPLX r, const CPLX a, const CPLX b) { + double re = a[0] * b[0] - a[1] * b[1]; + r[1] = a[0] * b[1] + a[1] * b[0]; + r[0] = re; +} + +/** + * @brief splits 2h evaluations of one polynomials into 2 times h evaluations of even/odd polynomial + * Input: Q_0(y),...,Q_{h-1}(y),Q_0(-y),...,Q_{h-1}(-y) + * Output: P_0(z),...,P_{h-1}(z),P_h(z),...,P_{2h-1}(z) + * where Q_i(X)=P_i(X^2)+X.P_{h+i}(X^2) and y^2 = z + * @param h number of "coefficients" h >= 1 + * @param data 2h complex coefficients interleaved and 256b aligned + * @param powom y represented as (yre,yim) + */ +EXPORT void cplx_split_fft_ref(int32_t h, CPLX* data, const CPLX powom) { + CPLX* d0 = data; + CPLX* d1 = data + h; + for (uint64_t i = 0; i < h; ++i) { + CPLX diff; + cplx_sub(diff, d0[i], d1[i]); + cplx_add(d0[i], d0[i], d1[i]); + cplx_mul(d1[i], diff, powom); + } +} + +/** + * @brief Do two layers of itwiddle (i.e. split). 
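+ * (in the code this is three calls to cplx_split_fft_ref: two of length h with om[0] and i.om[0],
+ * then one of length 2h with om[1])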
+ * Input/output: d0,d1,d2,d3 of length h + * Algo: + * itwiddle(d0,d1,om[0]),itwiddle(d2,d3,i.om[0]) + * itwiddle(d0,d2,om[1]),itwiddle(d1,d3,om[1]) + * @param h number of "coefficients" h >= 1 + * @param data 4h complex coefficients interleaved and 256b aligned + * @param powom om[0] (re,im) and om[1] where om[1]=om[0]^2 + */ +EXPORT void cplx_bisplit_fft_ref(int32_t h, CPLX* data, const CPLX powom[2]) { + CPLX* d0 = data; + CPLX* d2 = data + 2 * h; + const CPLX* om0 = powom; + CPLX iom0; + iom0[0] = powom[0][1]; + iom0[1] = -powom[0][0]; + const CPLX* om1 = powom + 1; + cplx_split_fft_ref(h, d0, *om0); + cplx_split_fft_ref(h, d2, iom0); + cplx_split_fft_ref(2 * h, d0, *om1); +} + +/** + * Input: Q(y),Q(-y) + * Output: P_0(z),P_1(z) + * where Q(X)=P_0(X^2)+X.P_1(X^2) and y^2 = z + * @param data 2 complexes coefficients interleaved and 256b aligned + * @param powom (z,-z) interleaved: (zre,zim,-zre,-zim) + */ +void split_fft_last_ref(CPLX* data, const CPLX powom) { + CPLX diff; + cplx_sub(diff, data[0], data[1]); + cplx_add(data[0], data[0], data[1]); + cplx_mul(data[1], diff, powom); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_conversions.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_conversions.c new file mode 100644 index 0000000..d912ccf --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_conversions.c @@ -0,0 +1,158 @@ +#include +#include + +#include "../commons_private.h" +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +EXPORT void cplx_from_znx32_ref(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) { + const uint32_t m = precomp->m; + const int32_t* inre = x; + const int32_t* inim = x + m; + CPLX* out = r; + for (uint32_t i = 0; i < m; ++i) { + out[i][0] = (double)inre[i]; + out[i][1] = (double)inim[i]; + } +} + +EXPORT void cplx_from_tnx32_ref(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) { + static const double _2p32 = 1. 
/ (INT64_C(1) << 32); + const uint32_t m = precomp->m; + const int32_t* inre = x; + const int32_t* inim = x + m; + CPLX* out = r; + for (uint32_t i = 0; i < m; ++i) { + out[i][0] = ((double)inre[i]) * _2p32; + out[i][1] = ((double)inim[i]) * _2p32; + } +} + +EXPORT void cplx_to_tnx32_ref(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* r, const void* x) { + static const double _2p32 = (INT64_C(1) << 32); + const uint32_t m = precomp->m; + double factor = _2p32 / precomp->divisor; + int32_t* outre = r; + int32_t* outim = r + m; + const CPLX* in = x; + // Note: this formula will only work if abs(in) < 2^32 + for (uint32_t i = 0; i < m; ++i) { + outre[i] = (int32_t)(int64_t)(rint(in[i][0] * factor)); + outim[i] = (int32_t)(int64_t)(rint(in[i][1] * factor)); + } +} + +void* init_cplx_from_znx32_precomp(CPLX_FROM_ZNX32_PRECOMP* res, uint32_t m) { + res->m = m; + if (CPU_SUPPORTS("avx2")) { + if (m >= 8) { + res->function = cplx_from_znx32_avx2_fma; + } else { + res->function = cplx_from_znx32_ref; + } + } else { + res->function = cplx_from_znx32_ref; + } + return res; +} + +CPLX_FROM_ZNX32_PRECOMP* new_cplx_from_znx32_precomp(uint32_t m) { + CPLX_FROM_ZNX32_PRECOMP* res = malloc(sizeof(CPLX_FROM_ZNX32_PRECOMP)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_cplx_from_znx32_precomp(res, m)); +} + +void* init_cplx_from_tnx32_precomp(CPLX_FROM_TNX32_PRECOMP* res, uint32_t m) { + res->m = m; + if (CPU_SUPPORTS("avx2")) { + if (m >= 8) { + res->function = cplx_from_tnx32_avx2_fma; + } else { + res->function = cplx_from_tnx32_ref; + } + } else { + res->function = cplx_from_tnx32_ref; + } + return res; +} + +CPLX_FROM_TNX32_PRECOMP* new_cplx_from_tnx32_precomp(uint32_t m) { + CPLX_FROM_TNX32_PRECOMP* res = malloc(sizeof(CPLX_FROM_TNX32_PRECOMP)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_cplx_from_tnx32_precomp(res, m)); +} + +void* init_cplx_to_tnx32_precomp(CPLX_TO_TNX32_PRECOMP* res, uint32_t m, double divisor, uint32_t log2overhead) { + if (is_not_pow2_double(&divisor)) return spqlios_error("divisor must be a power of 2"); + if (m & (m - 1)) return spqlios_error("m must be a power of 2"); + if (log2overhead > 52) return spqlios_error("log2overhead is too large"); + res->m = m; + res->divisor = divisor; + if (CPU_SUPPORTS("avx2")) { + if (log2overhead <= 18) { + if (m >= 8) { + res->function = cplx_to_tnx32_avx2_fma; + } else { + res->function = cplx_to_tnx32_ref; + } + } else { + res->function = cplx_to_tnx32_ref; + } + } else { + res->function = cplx_to_tnx32_ref; + } + return res; +} + +EXPORT CPLX_TO_TNX32_PRECOMP* new_cplx_to_tnx32_precomp(uint32_t m, double divisor, uint32_t log2overhead) { + CPLX_TO_TNX32_PRECOMP* res = malloc(sizeof(CPLX_TO_TNX32_PRECOMP)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_cplx_to_tnx32_precomp(res, m, divisor, log2overhead)); +} + +/** + * @brief Simpler API for the znx32 to cplx conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. 
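+ * (the per-dimension table lives in a static array and is initialized lazily without locking,
+ * which is why the very first call for a given m is not thread-safe)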
+ * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void cplx_from_znx32_simple(uint32_t m, void* r, const int32_t* x) { + // not checking for log2bound which is not relevant here + static CPLX_FROM_ZNX32_PRECOMP precomp[32]; + CPLX_FROM_ZNX32_PRECOMP* p = precomp + log2m(m); + if (!p->function) { + if (!init_cplx_from_znx32_precomp(p, m)) abort(); + } + p->function(p, r, x); +} + +/** + * @brief Simpler API for the tnx32 to cplx conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void cplx_from_tnx32_simple(uint32_t m, void* r, const int32_t* x) { + static CPLX_FROM_TNX32_PRECOMP precomp[32]; + CPLX_FROM_TNX32_PRECOMP* p = precomp + log2m(m); + if (!p->function) { + if (!init_cplx_from_tnx32_precomp(p, m)) abort(); + } + p->function(p, r, x); +} +/** + * @brief Simpler API for the cplx to tnx32 conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void cplx_to_tnx32_simple(uint32_t m, double divisor, uint32_t log2overhead, int32_t* r, const void* x) { + struct LAST_CPLX_TO_TNX32_PRECOMP { + CPLX_TO_TNX32_PRECOMP p; + double last_divisor; + double last_log2over; + }; + static __thread struct LAST_CPLX_TO_TNX32_PRECOMP precomp[32]; + struct LAST_CPLX_TO_TNX32_PRECOMP* p = precomp + log2m(m); + if (!p->p.function || divisor != p->last_divisor || log2overhead != p->last_log2over) { + memset(p, 0, sizeof(*p)); + if (!init_cplx_to_tnx32_precomp(&p->p, m, divisor, log2overhead)) abort(); + p->last_divisor = divisor; + p->last_log2over = log2overhead; + } + p->p.function(&p->p, r, x); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_conversions_avx2_fma.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_conversions_avx2_fma.c new file mode 100644 index 0000000..23fe004 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_conversions_avx2_fma.c @@ -0,0 +1,104 @@ +#include +#include + +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +typedef int32_t I8MEM[8]; +typedef double D4MEM[4]; + +__always_inline void cplx_from_any_fma(uint64_t m, void* r, const int32_t* x, const __m256i C, const __m256d R) { + const __m256i S = _mm256_set1_epi32(0x80000000); + const I8MEM* inre = (I8MEM*)(x); + const I8MEM* inim = (I8MEM*)(x + m); + D4MEM* out = (D4MEM*)r; + const uint64_t ms8 = m / 8; + for (uint32_t i = 0; i < ms8; ++i) { + __m256i rea = _mm256_loadu_si256((__m256i*)inre[0]); + __m256i ima = _mm256_loadu_si256((__m256i*)inim[0]); + rea = _mm256_add_epi32(rea, S); + ima = _mm256_add_epi32(ima, S); + __m256i tmpa = _mm256_unpacklo_epi32(rea, ima); + __m256i tmpc = _mm256_unpackhi_epi32(rea, ima); + __m256i cpla = _mm256_permute2x128_si256(tmpa, tmpc, 0x20); + __m256i cplc = _mm256_permute2x128_si256(tmpa, tmpc, 0x31); + tmpa = _mm256_unpacklo_epi32(cpla, C); + __m256i tmpb = _mm256_unpackhi_epi32(cpla, C); + tmpc = _mm256_unpacklo_epi32(cplc, C); + __m256i tmpd = _mm256_unpackhi_epi32(cplc, C); + cpla = _mm256_permute2x128_si256(tmpa, tmpb, 0x20); + __m256i cplb = _mm256_permute2x128_si256(tmpa, tmpb, 0x31); + cplc = _mm256_permute2x128_si256(tmpc, tmpd, 0x20); + __m256i cpld 
= _mm256_permute2x128_si256(tmpc, tmpd, 0x31); + __m256d dcpla = _mm256_sub_pd(_mm256_castsi256_pd(cpla), R); + __m256d dcplb = _mm256_sub_pd(_mm256_castsi256_pd(cplb), R); + __m256d dcplc = _mm256_sub_pd(_mm256_castsi256_pd(cplc), R); + __m256d dcpld = _mm256_sub_pd(_mm256_castsi256_pd(cpld), R); + _mm256_storeu_pd(out[0], dcpla); + _mm256_storeu_pd(out[1], dcplb); + _mm256_storeu_pd(out[2], dcplc); + _mm256_storeu_pd(out[3], dcpld); + inre += 1; + inim += 1; + out += 4; + } +} + +EXPORT void cplx_from_znx32_avx2_fma(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) { + // note: the hex code of 2^31 + 2^52 is 0x4330000080000000 + const __m256i C = _mm256_set1_epi32(0x43300000); + const __m256d R = _mm256_set1_pd((INT64_C(1) << 31) + (INT64_C(1) << 52)); + // double XX = INT64_C(1) + (INT64_C(1)<<31) + (INT64_C(1)<<52); + // printf("\n\n%016lx\n", *(uint64_t*)&XX); + // abort(); + const uint64_t m = precomp->m; + cplx_from_any_fma(m, r, x, C, R); +} + +EXPORT void cplx_from_tnx32_avx2_fma(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) { + // note: the hex code of 2^-1 + 2^30 is 0x4130000080000000 + const __m256i C = _mm256_set1_epi32(0x41300000); + const __m256d R = _mm256_set1_pd(0.5 + (INT64_C(1) << 20)); + // double XX = (double)(INT64_C(1) + (INT64_C(1)<<31) + (INT64_C(1)<<52))/(INT64_C(1)<<32); + // printf("\n\n%016lx\n", *(uint64_t*)&XX); + // abort(); + const uint64_t m = precomp->m; + cplx_from_any_fma(m, r, x, C, R); +} + +EXPORT void cplx_to_tnx32_avx2_fma(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* r, const void* x) { + const __m256d R = _mm256_set1_pd((0.5 + (INT64_C(3) << 19)) * precomp->divisor); + const __m256i MASK = _mm256_set1_epi64x(0xFFFFFFFFUL); + const __m256i S = _mm256_set1_epi32(0x80000000); + // const __m256i IDX = _mm256_set_epi32(0,4,1,5,2,6,3,7); + const __m256i IDX = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); + const uint64_t m = precomp->m; + const uint64_t ms8 = m / 8; + I8MEM* outre = (I8MEM*)r; + I8MEM* outim = (I8MEM*)(r + m); + const D4MEM* in = x; + // Note: this formula will only work if abs(in) < 2^32 + for (uint32_t i = 0; i < ms8; ++i) { + __m256d cpla = _mm256_loadu_pd(in[0]); + __m256d cplb = _mm256_loadu_pd(in[1]); + __m256d cplc = _mm256_loadu_pd(in[2]); + __m256d cpld = _mm256_loadu_pd(in[3]); + __m256i icpla = _mm256_castpd_si256(_mm256_add_pd(cpla, R)); + __m256i icplb = _mm256_castpd_si256(_mm256_add_pd(cplb, R)); + __m256i icplc = _mm256_castpd_si256(_mm256_add_pd(cplc, R)); + __m256i icpld = _mm256_castpd_si256(_mm256_add_pd(cpld, R)); + icpla = _mm256_or_si256(_mm256_and_si256(icpla, MASK), _mm256_slli_epi64(icplb, 32)); + icplc = _mm256_or_si256(_mm256_and_si256(icplc, MASK), _mm256_slli_epi64(icpld, 32)); + icpla = _mm256_xor_si256(icpla, S); + icplc = _mm256_xor_si256(icplc, S); + __m256i re = _mm256_unpacklo_epi64(icpla, icplc); + __m256i im = _mm256_unpackhi_epi64(icpla, icplc); + re = _mm256_permutevar8x32_epi32(re, IDX); + im = _mm256_permutevar8x32_epi32(im, IDX); + _mm256_storeu_si256((__m256i*)outre[0], re); + _mm256_storeu_si256((__m256i*)outim[0], im); + outre += 1; + outim += 1; + in += 4; + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_execute.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_execute.c new file mode 100644 index 0000000..323f7b1 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_execute.c @@ -0,0 +1,18 @@ +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + 
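+// Thin dispatchers: each wrapper below simply forwards to the implementation pointer
+// (e.g. reference or AVX2) that was selected when the corresponding precomp table was built.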
+EXPORT void cplx_from_znx32(const CPLX_FROM_ZNX32_PRECOMP* tables, void* r, const int32_t* a) { + tables->function(tables, r, a); +} +EXPORT void cplx_from_tnx32(const CPLX_FROM_TNX32_PRECOMP* tables, void* r, const int32_t* a) { + tables->function(tables, r, a); +} +EXPORT void cplx_to_tnx32(const CPLX_TO_TNX32_PRECOMP* tables, int32_t* r, const void* a) { + tables->function(tables, r, a); +} +EXPORT void cplx_fftvec_mul(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b) { + tables->function(tables, r, a, b); +} +EXPORT void cplx_fftvec_addmul(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b) { + tables->function(tables, r, a, b); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fallbacks_aarch64.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fallbacks_aarch64.c new file mode 100644 index 0000000..87bd6f6 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fallbacks_aarch64.c @@ -0,0 +1,43 @@ +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +EXPORT void cplx_fftvec_addmul_fma(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b) { + UNDEFINED(); // not defined for non x86 targets +} +EXPORT void cplx_fftvec_mul_fma(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b) { + UNDEFINED(); +} +EXPORT void cplx_fftvec_addmul_sse(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) { + UNDEFINED(); +} +EXPORT void cplx_fftvec_addmul_avx512(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, + const void* b) { + UNDEFINED(); +} +EXPORT void cplx_fft16_avx_fma(void* data, const void* omega) { UNDEFINED(); } +EXPORT void cplx_ifft16_avx_fma(void* data, const void* omega) { UNDEFINED(); } +EXPORT void cplx_from_znx32_avx2_fma(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) { UNDEFINED(); } +EXPORT void cplx_from_tnx32_avx2_fma(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) { UNDEFINED(); } +EXPORT void cplx_to_tnx32_avx2_fma(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* x, const void* c) { UNDEFINED(); } +EXPORT void cplx_fft_avx2_fma(const CPLX_FFT_PRECOMP* tables, void* data){UNDEFINED()} EXPORT + void cplx_ifft_avx2_fma(const CPLX_IFFT_PRECOMP* itables, void* data){UNDEFINED()} EXPORT + void cplx_fftvec_twiddle_fma(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om){ + UNDEFINED()} EXPORT void cplx_fftvec_twiddle_avx512(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, + const void* om){UNDEFINED()} EXPORT + void cplx_fftvec_bitwiddle_fma(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice, + const void* om){UNDEFINED()} EXPORT + void cplx_fftvec_bitwiddle_avx512(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice, + const void* om){UNDEFINED()} + +// DEPRECATED? 
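+// (these too are link-time stubs for non-x86 targets: like the functions above, they abort
+// via UNDEFINED() if they are ever reached)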
+EXPORT void cplx_fftvec_add_fma(uint32_t m, void* r, const void* a, const void* b){UNDEFINED()} EXPORT + void cplx_fftvec_sub2_to_fma(uint32_t m, void* r, const void* a, const void* b){UNDEFINED()} EXPORT + void cplx_fftvec_copy_fma(uint32_t m, void* r, const void* a) { + UNDEFINED() +} + +// executors +// EXPORT void cplx_ifft(const CPLX_IFFT_PRECOMP* itables, void* data) { +// itables->function(itables, data); +//} +// EXPORT void cplx_fft(const CPLX_FFT_PRECOMP* tables, void* data) { tables->function(tables, data); } diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft.h new file mode 100644 index 0000000..01699bd --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft.h @@ -0,0 +1,221 @@ +#ifndef SPQLIOS_CPLX_FFT_H +#define SPQLIOS_CPLX_FFT_H + +#include "../commons.h" + +typedef struct cplx_fft_precomp CPLX_FFT_PRECOMP; +typedef struct cplx_ifft_precomp CPLX_IFFT_PRECOMP; +typedef struct cplx_mul_precomp CPLX_FFTVEC_MUL_PRECOMP; +typedef struct cplx_addmul_precomp CPLX_FFTVEC_ADDMUL_PRECOMP; +typedef struct cplx_from_znx32_precomp CPLX_FROM_ZNX32_PRECOMP; +typedef struct cplx_from_tnx32_precomp CPLX_FROM_TNX32_PRECOMP; +typedef struct cplx_to_tnx32_precomp CPLX_TO_TNX32_PRECOMP; +typedef struct cplx_to_znx32_precomp CPLX_TO_ZNX32_PRECOMP; +typedef struct cplx_from_rnx64_precomp CPLX_FROM_RNX64_PRECOMP; +typedef struct cplx_to_rnx64_precomp CPLX_TO_RNX64_PRECOMP; +typedef struct cplx_round_to_rnx64_precomp CPLX_ROUND_TO_RNX64_PRECOMP; + +/** + * @brief precomputes fft tables. + * The FFT tables contains a constant section that is required for efficient FFT operations in dimension nn. + * The resulting pointer is to be passed as "tables" argument to any call to the fft function. + * The user can optionnally allocate zero or more computation buffers, which are scratch spaces that are contiguous to + * the constant tables in memory, and allow for more efficient operations. It is the user's responsibility to ensure + * that each of those buffers are never used simultaneously by two ffts on different threads at the same time. The fft + * table must be deleted by delete_fft_precomp after its last usage. + */ +EXPORT CPLX_FFT_PRECOMP* new_cplx_fft_precomp(uint32_t m, uint32_t num_buffers); + +/** + * @brief gets the address of a fft buffer allocated during new_fft_precomp. + * This buffer can be used as data pointer in subsequent calls to fft, + * and does not need to be released afterwards. + */ +EXPORT void* cplx_fft_precomp_get_buffer(const CPLX_FFT_PRECOMP* tables, uint32_t buffer_index); + +/** + * @brief allocates a new fft buffer. + * This buffer can be used as data pointer in subsequent calls to fft, + * and must be deleted afterwards by calling delete_fft_buffer. + */ +EXPORT void* new_cplx_fft_buffer(uint32_t m); + +/** + * @brief allocates a new fft buffer. + * This buffer can be used as data pointer in subsequent calls to fft, + * and must be deleted afterwards by calling delete_fft_buffer. + */ +EXPORT void delete_cplx_fft_buffer(void* buffer); + +/** + * @brief deallocates a fft table and all its built-in buffers. + */ +#define delete_cplx_fft_precomp free + +/** + * @brief computes a direct fft in-place over data. 
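+ * A minimal call sequence, based only on the declarations in this header, could look like:
+ *   CPLX_FFT_PRECOMP* tables = new_cplx_fft_precomp(m, 1);
+ *   double* buf = cplx_fft_precomp_get_buffer(tables, 0);  // room for m complexes, (re,im) interleaved
+ *   // ... fill buf with the m complex coefficients, then:
+ *   cplx_fft(tables, buf);
+ *   delete_cplx_fft_precomp(tables);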
+ */ +EXPORT void cplx_fft(const CPLX_FFT_PRECOMP* tables, void* data); + +EXPORT CPLX_IFFT_PRECOMP* new_cplx_ifft_precomp(uint32_t m, uint32_t num_buffers); +EXPORT void* cplx_ifft_precomp_get_buffer(const CPLX_IFFT_PRECOMP* tables, uint32_t buffer_index); +EXPORT void cplx_ifft(const CPLX_IFFT_PRECOMP* tables, void* data); +#define delete_cplx_ifft_precomp free + +EXPORT CPLX_FFTVEC_MUL_PRECOMP* new_cplx_fftvec_mul_precomp(uint32_t m); +EXPORT void cplx_fftvec_mul(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b); +#define delete_cplx_fftvec_mul_precomp free + +EXPORT CPLX_FFTVEC_ADDMUL_PRECOMP* new_cplx_fftvec_addmul_precomp(uint32_t m); +EXPORT void cplx_fftvec_addmul(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b); +#define delete_cplx_fftvec_addmul_precomp free + +/** + * @brief prepares a conversion from ZnX to the cplx layout. + * All the coefficients must be strictly lower than 2^log2bound in absolute value. Any attempt to use + * this function on a larger coefficient is undefined behaviour. The resulting precomputed data must + * be freed with `new_cplx_from_znx32_precomp` + * @param m the target complex dimension m from C[X] mod X^m-i. Note that the inputs have n=2m + * int32 coefficients in natural order modulo X^n+1 + * @param log2bound bound on the input coefficients. Must be between 0 and 32 + */ +EXPORT CPLX_FROM_ZNX32_PRECOMP* new_cplx_from_znx32_precomp(uint32_t m); +/** + * @brief converts from ZnX to the cplx layout. + * @param tables precomputed data obtained by new_cplx_from_znx32_precomp. + * @param r resulting array of m complexes coefficients mod X^m-i + * @param x input array of n bounded integer coefficients mod X^n+1 + */ +EXPORT void cplx_from_znx32(const CPLX_FROM_ZNX32_PRECOMP* tables, void* r, const int32_t* a); +/** @brief frees a precomputed conversion data initialized with new_cplx_from_znx32_precomp. */ +#define delete_cplx_from_znx32_precomp free + +/** + * @brief prepares a conversion from TnX to the cplx layout. + * @param m the target complex dimension m from C[X] mod X^m-i. Note that the inputs have n=2m + * torus32 coefficients. The resulting precomputed data must + * be freed with `delete_cplx_from_tnx32_precomp` + */ +EXPORT CPLX_FROM_TNX32_PRECOMP* new_cplx_from_tnx32_precomp(uint32_t m); +/** + * @brief converts from TnX to the cplx layout. + * @param tables precomputed data obtained by new_cplx_from_tnx32_precomp. + * @param r resulting array of m complexes coefficients mod X^m-i + * @param x input array of n torus32 coefficients mod X^n+1 + */ +EXPORT void cplx_from_tnx32(const CPLX_FROM_TNX32_PRECOMP* tables, void* r, const int32_t* a); +/** @brief frees a precomputed conversion data initialized with new_cplx_from_tnx32_precomp. */ +#define delete_cplx_from_tnx32_precomp free + +/** + * @brief prepares a rescale and conversion from the cplx layout to TnX. + * @param m the target complex dimension m from C[X] mod X^m-i. Note that the outputs have n=2m + * torus32 coefficients. + * @param divisor must be a power of two. The inputs are rescaled by divisor before being reduced modulo 1. + * Remember that the output of an iFFT must be divided by m. + * @param log2overhead all inputs absolute values must be within divisor.2^log2overhead. + * For any inputs outside of these bounds, the conversion is undefined behaviour. + * The maximum supported log2overhead is 52, and the algorithm is faster for log2overhead=18. 
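+ * (concretely, on AVX2-capable CPUs the vectorized conversion is only selected when
+ * log2overhead <= 18 and m >= 8; larger overheads fall back to the reference implementation)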
+ */ +EXPORT CPLX_TO_TNX32_PRECOMP* new_cplx_to_tnx32_precomp(uint32_t m, double divisor, uint32_t log2overhead); +/** + * @brief rescale, converts and reduce mod 1 from cplx layout to torus32. + * @param tables precomputed data obtained by new_cplx_from_tnx32_precomp. + * @param r resulting array of n torus32 coefficients mod X^n+1 + * @param x input array of m cplx coefficients mod X^m-i + */ +EXPORT void cplx_to_tnx32(const CPLX_TO_TNX32_PRECOMP* tables, int32_t* r, const void* a); +#define delete_cplx_to_tnx32_precomp free + +EXPORT CPLX_TO_ZNX32_PRECOMP* new_cplx_to_znx32_precomp(uint32_t m, double divisor); +EXPORT void cplx_to_znx32(const CPLX_TO_ZNX32_PRECOMP* precomp, int32_t* r, const void* x); +#define delete_cplx_to_znx32_simple free + +EXPORT CPLX_FROM_RNX64_PRECOMP* new_cplx_from_rnx64_simple(uint32_t m); +EXPORT void cplx_from_rnx64(const CPLX_FROM_RNX64_PRECOMP* precomp, void* r, const double* x); +#define delete_cplx_from_rnx64_simple free + +EXPORT CPLX_TO_RNX64_PRECOMP* new_cplx_to_rnx64(uint32_t m, double divisor); +EXPORT void cplx_to_rnx64(const CPLX_TO_RNX64_PRECOMP* precomp, double* r, const void* x); +#define delete_cplx_round_to_rnx64_simple free + +EXPORT CPLX_ROUND_TO_RNX64_PRECOMP* new_cplx_round_to_rnx64(uint32_t m, double divisor, uint32_t log2bound); +EXPORT void cplx_round_to_rnx64(const CPLX_ROUND_TO_RNX64_PRECOMP* precomp, double* r, const void* x); +#define delete_cplx_round_to_rnx64_simple free + +/** + * @brief Simpler API for the fft function. + * For each dimension, the precomputed tables for this dimension are generated automatically. + * It is advised to do one dry-run per desired dimension before using in a multithread environment */ +EXPORT void cplx_fft_simple(uint32_t m, void* data); +/** + * @brief Simpler API for the ifft function. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension in the main thread before using in a multithread + * environment */ +EXPORT void cplx_ifft_simple(uint32_t m, void* data); +/** + * @brief Simpler API for the fftvec multiplication function. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void cplx_fftvec_mul_simple(uint32_t m, void* r, const void* a, const void* b); +/** + * @brief Simpler API for the fftvec addmul function. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void cplx_fftvec_addmul_simple(uint32_t m, void* r, const void* a, const void* b); +/** + * @brief Simpler API for the znx32 to cplx conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void cplx_from_znx32_simple(uint32_t m, void* r, const int32_t* x); +/** + * @brief Simpler API for the tnx32 to cplx conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. 
+ * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void cplx_from_tnx32_simple(uint32_t m, void* r, const int32_t* x); +/** + * @brief Simpler API for the cplx to tnx32 conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void cplx_to_tnx32_simple(uint32_t m, double divisor, uint32_t log2overhead, int32_t* r, const void* x); + +/** + * @brief converts, divides and round from cplx to znx32 (simple API) + * @param m the complex dimension + * @param divisor the divisor: a power of two, often m after an ifft + * @param r the result: must be a double array of size 2m. r must be distinct from x + * @param x the input: must hold m complex numbers. + */ +EXPORT void cplx_to_znx32_simple(uint32_t m, double divisor, int32_t* r, const void* x); + +/** + * @brief converts from rnx64 to cplx (simple API) + * The bound on the output is assumed to be within ]2^-31,2^31[. + * Any coefficient that would fall outside this range is undefined behaviour. + * @param m the complex dimension + * @param r the result: must be an array of m complex numbers. r must be distinct from x + * @param x the input: must be an array of 2m doubles. + */ +EXPORT void cplx_from_rnx64_simple(uint32_t m, void* r, const double* x); + +/** + * @brief converts, divides from cplx to rnx64 (simple API) + * @param m the complex dimension + * @param divisor the divisor: a power of two, often m after an ifft + * @param r the result: must be a double array of size 2m. r must be distinct from x + * @param x the input: must hold m complex numbers. + */ +EXPORT void cplx_to_rnx64_simple(uint32_t m, double divisor, double* r, const void* x); + +/** + * @brief converts, divides and round to integer from cplx to rnx32 (simple API) + * @param m the complex dimension + * @param divisor the divisor: a power of two, often m after an ifft + * @param log2bound a guarantee on the log2bound of the output. log2bound<=48 will use a more efficient algorithm. + * @param r the result: must be a double array of size 2m. r must be distinct from x + * @param x the input: must hold m complex numbers. 
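+ * A typical call right after cplx_ifft_simple(m, x) would be
+ * cplx_round_to_rnx64_simple(m, (double)m, 48, r, x), assuming the rounded outputs fit in 48 bits.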
+ */ +EXPORT void cplx_round_to_rnx64_simple(uint32_t m, double divisor, uint32_t log2bound, double* r, const void* x); + +#endif // SPQLIOS_CPLX_FFT_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft16_avx_fma.s b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft16_avx_fma.s new file mode 100644 index 0000000..40e3985 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft16_avx_fma.s @@ -0,0 +1,156 @@ +# shifted FFT over X^16-i +# 1st argument (rdi) contains 16 complexes +# 2nd argument (rsi) contains: 8 complexes +# omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma +# alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta) +# j = sqrt(i), k=sqrt(j) +.globl cplx_fft16_avx_fma +cplx_fft16_avx_fma: +vmovupd (%rdi),%ymm8 +vmovupd 0x20(%rdi),%ymm9 +vmovupd 0x40(%rdi),%ymm10 +vmovupd 0x60(%rdi),%ymm11 +vmovupd 0x80(%rdi),%ymm12 +vmovupd 0xa0(%rdi),%ymm13 +vmovupd 0xc0(%rdi),%ymm14 +vmovupd 0xe0(%rdi),%ymm15 + +.first_pass: +vmovupd (%rsi),%xmm0 /* omri */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */ +vshufpd $5, %ymm12, %ymm12, %ymm4 +vshufpd $5, %ymm13, %ymm13, %ymm5 +vshufpd $5, %ymm14, %ymm14, %ymm6 +vshufpd $5, %ymm15, %ymm15, %ymm7 +vmulpd %ymm4,%ymm1,%ymm4 +vmulpd %ymm5,%ymm1,%ymm5 +vmulpd %ymm6,%ymm1,%ymm6 +vmulpd %ymm7,%ymm1,%ymm7 +vfmaddsub231pd %ymm12, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm12) +/- ymm4 +vfmaddsub231pd %ymm13, %ymm0, %ymm5 +vfmaddsub231pd %ymm14, %ymm0, %ymm6 +vfmaddsub231pd %ymm15, %ymm0, %ymm7 +vsubpd %ymm4,%ymm8,%ymm12 +vsubpd %ymm5,%ymm9,%ymm13 +vsubpd %ymm6,%ymm10,%ymm14 +vsubpd %ymm7,%ymm11,%ymm15 +vaddpd %ymm4,%ymm8,%ymm8 +vaddpd %ymm5,%ymm9,%ymm9 +vaddpd %ymm6,%ymm10,%ymm10 +vaddpd %ymm7,%ymm11,%ymm11 + +.second_pass: +vmovupd 16(%rsi),%xmm0 /* omri */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */ +vshufpd $5, %ymm10, %ymm10, %ymm4 +vshufpd $5, %ymm11, %ymm11, %ymm5 +vshufpd $5, %ymm14, %ymm14, %ymm6 +vshufpd $5, %ymm15, %ymm15, %ymm7 +vmulpd %ymm4,%ymm1,%ymm4 +vmulpd %ymm5,%ymm1,%ymm5 +vmulpd %ymm6,%ymm0,%ymm6 +vmulpd %ymm7,%ymm0,%ymm7 +vfmaddsub231pd %ymm10, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm10) +/- ymm4 +vfmaddsub231pd %ymm11, %ymm0, %ymm5 +vfmsubadd231pd %ymm14, %ymm1, %ymm6 +vfmsubadd231pd %ymm15, %ymm1, %ymm7 +vsubpd %ymm4,%ymm8,%ymm10 +vsubpd %ymm5,%ymm9,%ymm11 +vaddpd %ymm6,%ymm12,%ymm14 +vaddpd %ymm7,%ymm13,%ymm15 +vaddpd %ymm4,%ymm8,%ymm8 +vaddpd %ymm5,%ymm9,%ymm9 +vsubpd %ymm6,%ymm12,%ymm12 +vsubpd %ymm7,%ymm13,%ymm13 + +.third_pass: +vmovupd 32(%rsi),%xmm0 /* gamma */ +vmovupd 48(%rsi),%xmm2 /* delta */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 +vinsertf128 $1, %xmm2, %ymm2, %ymm2 +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */ +vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */ +vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */ +vshufpd $5, %ymm9, %ymm9, %ymm4 +vshufpd $5, %ymm11, %ymm11, %ymm5 +vshufpd $5, %ymm13, %ymm13, %ymm6 +vshufpd $5, %ymm15, %ymm15, %ymm7 +vmulpd %ymm4,%ymm1,%ymm4 +vmulpd %ymm5,%ymm0,%ymm5 +vmulpd %ymm6,%ymm3,%ymm6 +vmulpd %ymm7,%ymm2,%ymm7 +vfmaddsub231pd %ymm9, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm10) +/- ymm4 +vfmsubadd231pd %ymm11, %ymm1, %ymm5 +vfmaddsub231pd %ymm13, %ymm2, %ymm6 +vfmsubadd231pd %ymm15, %ymm3, %ymm7 +vsubpd %ymm4,%ymm8,%ymm9 
+vaddpd %ymm5,%ymm10,%ymm11 +vsubpd %ymm6,%ymm12,%ymm13 +vaddpd %ymm7,%ymm14,%ymm15 +vaddpd %ymm4,%ymm8,%ymm8 +vsubpd %ymm5,%ymm10,%ymm10 +vaddpd %ymm6,%ymm12,%ymm12 +vsubpd %ymm7,%ymm14,%ymm14 + +.fourth_pass: +vmovupd 64(%rsi),%ymm0 /* gamma */ +vmovupd 96(%rsi),%ymm2 /* delta */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */ +vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */ +vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */ +vperm2f128 $0x31,%ymm10,%ymm8,%ymm4 # ymm4 contains c1,c5 -- x gamma +vperm2f128 $0x31,%ymm11,%ymm9,%ymm5 # ymm5 contains c3,c7 -- x igamma +vperm2f128 $0x31,%ymm14,%ymm12,%ymm6 # ymm6 contains c9,c13 -- x delta +vperm2f128 $0x31,%ymm15,%ymm13,%ymm7 # ymm7 contains c11,c15 -- x idelta +vperm2f128 $0x20,%ymm10,%ymm8,%ymm8 # ymm8 contains c0,c4 +vperm2f128 $0x20,%ymm11,%ymm9,%ymm9 # ymm9 contains c2,c6 +vperm2f128 $0x20,%ymm14,%ymm12,%ymm10 # ymm10 contains c8,c12 +vperm2f128 $0x20,%ymm15,%ymm13,%ymm11 # ymm11 contains c10,c14 +vshufpd $5, %ymm4, %ymm4, %ymm12 +vshufpd $5, %ymm5, %ymm5, %ymm13 +vshufpd $5, %ymm6, %ymm6, %ymm14 +vshufpd $5, %ymm7, %ymm7, %ymm15 +vmulpd %ymm12,%ymm1,%ymm12 +vmulpd %ymm13,%ymm0,%ymm13 +vmulpd %ymm14,%ymm3,%ymm14 +vmulpd %ymm15,%ymm2,%ymm15 +vfmaddsub231pd %ymm4, %ymm0, %ymm12 # ymm12 = (ymm0 * ymm4) +/- ymm12 +vfmsubadd231pd %ymm5, %ymm1, %ymm13 +vfmaddsub231pd %ymm6, %ymm2, %ymm14 +vfmsubadd231pd %ymm7, %ymm3, %ymm15 +vsubpd %ymm12,%ymm8,%ymm4 +vaddpd %ymm13,%ymm9,%ymm5 +vsubpd %ymm14,%ymm10,%ymm6 +vaddpd %ymm15,%ymm11,%ymm7 +vaddpd %ymm12,%ymm8,%ymm8 +vsubpd %ymm13,%ymm9,%ymm9 +vaddpd %ymm14,%ymm10,%ymm10 +vsubpd %ymm15,%ymm11,%ymm11 + +vperm2f128 $0x20,%ymm6,%ymm10,%ymm12 # ymm4 contains c1,c5 -- x gamma +vperm2f128 $0x20,%ymm7,%ymm11,%ymm13 # ymm5 contains c3,c7 -- x igamma +vperm2f128 $0x31,%ymm6,%ymm10,%ymm14 # ymm6 contains c9,c13 -- x delta +vperm2f128 $0x31,%ymm7,%ymm11,%ymm15 # ymm7 contains c11,c15 -- x idelta +vperm2f128 $0x31,%ymm4,%ymm8,%ymm10 # ymm10 contains c8,c12 +vperm2f128 $0x31,%ymm5,%ymm9,%ymm11 # ymm11 contains c10,c14 +vperm2f128 $0x20,%ymm4,%ymm8,%ymm8 # ymm8 contains c0,c4 +vperm2f128 $0x20,%ymm5,%ymm9,%ymm9 # ymm9 contains c2,c6 + +.save_and_return: +vmovupd %ymm8,(%rdi) +vmovupd %ymm9,0x20(%rdi) +vmovupd %ymm10,0x40(%rdi) +vmovupd %ymm11,0x60(%rdi) +vmovupd %ymm12,0x80(%rdi) +vmovupd %ymm13,0xa0(%rdi) +vmovupd %ymm14,0xc0(%rdi) +vmovupd %ymm15,0xe0(%rdi) +ret +.size cplx_fft16_avx_fma, .-cplx_fft16_avx_fma +.section .note.GNU-stack,"",@progbits diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft16_avx_fma_win32.s b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft16_avx_fma_win32.s new file mode 100644 index 0000000..d7d4bf1 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft16_avx_fma_win32.s @@ -0,0 +1,190 @@ + .text + .p2align 4 + .globl cplx_fft16_avx_fma + .def cplx_fft16_avx_fma; .scl 2; .type 32; .endef +cplx_fft16_avx_fma: + + pushq %rdi + pushq %rsi + movq %rcx,%rdi + movq %rdx,%rsi + subq $0x100,%rsp + movdqu %xmm6,(%rsp) + movdqu %xmm7,0x10(%rsp) + movdqu %xmm8,0x20(%rsp) + movdqu %xmm9,0x30(%rsp) + movdqu %xmm10,0x40(%rsp) + movdqu %xmm11,0x50(%rsp) + movdqu %xmm12,0x60(%rsp) + movdqu %xmm13,0x70(%rsp) + movdqu %xmm14,0x80(%rsp) + movdqu %xmm15,0x90(%rsp) + callq cplx_fft16_avx_fma_amd64 + movdqu (%rsp),%xmm6 + movdqu 0x10(%rsp),%xmm7 + movdqu 0x20(%rsp),%xmm8 + movdqu 0x30(%rsp),%xmm9 + movdqu 0x40(%rsp),%xmm10 + movdqu 
0x50(%rsp),%xmm11 + movdqu 0x60(%rsp),%xmm12 + movdqu 0x70(%rsp),%xmm13 + movdqu 0x80(%rsp),%xmm14 + movdqu 0x90(%rsp),%xmm15 + addq $0x100,%rsp + popq %rsi + popq %rdi + retq + +# shifted FFT over X^16-i +# 1st argument (rdi) contains 16 complexes +# 2nd argument (rsi) contains: 8 complexes +# omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma +# alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta) +# j = sqrt(i), k=sqrt(j) +cplx_fft16_avx_fma_amd64: +vmovupd (%rdi),%ymm8 +vmovupd 0x20(%rdi),%ymm9 +vmovupd 0x40(%rdi),%ymm10 +vmovupd 0x60(%rdi),%ymm11 +vmovupd 0x80(%rdi),%ymm12 +vmovupd 0xa0(%rdi),%ymm13 +vmovupd 0xc0(%rdi),%ymm14 +vmovupd 0xe0(%rdi),%ymm15 + +.first_pass: +vmovupd (%rsi),%xmm0 /* omri */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */ +vshufpd $5, %ymm12, %ymm12, %ymm4 +vshufpd $5, %ymm13, %ymm13, %ymm5 +vshufpd $5, %ymm14, %ymm14, %ymm6 +vshufpd $5, %ymm15, %ymm15, %ymm7 +vmulpd %ymm4,%ymm1,%ymm4 +vmulpd %ymm5,%ymm1,%ymm5 +vmulpd %ymm6,%ymm1,%ymm6 +vmulpd %ymm7,%ymm1,%ymm7 +vfmaddsub231pd %ymm12, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm12) +/- ymm4 +vfmaddsub231pd %ymm13, %ymm0, %ymm5 +vfmaddsub231pd %ymm14, %ymm0, %ymm6 +vfmaddsub231pd %ymm15, %ymm0, %ymm7 +vsubpd %ymm4,%ymm8,%ymm12 +vsubpd %ymm5,%ymm9,%ymm13 +vsubpd %ymm6,%ymm10,%ymm14 +vsubpd %ymm7,%ymm11,%ymm15 +vaddpd %ymm4,%ymm8,%ymm8 +vaddpd %ymm5,%ymm9,%ymm9 +vaddpd %ymm6,%ymm10,%ymm10 +vaddpd %ymm7,%ymm11,%ymm11 + +.second_pass: +vmovupd 16(%rsi),%xmm0 /* omri */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */ +vshufpd $5, %ymm10, %ymm10, %ymm4 +vshufpd $5, %ymm11, %ymm11, %ymm5 +vshufpd $5, %ymm14, %ymm14, %ymm6 +vshufpd $5, %ymm15, %ymm15, %ymm7 +vmulpd %ymm4,%ymm1,%ymm4 +vmulpd %ymm5,%ymm1,%ymm5 +vmulpd %ymm6,%ymm0,%ymm6 +vmulpd %ymm7,%ymm0,%ymm7 +vfmaddsub231pd %ymm10, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm10) +/- ymm4 +vfmaddsub231pd %ymm11, %ymm0, %ymm5 +vfmsubadd231pd %ymm14, %ymm1, %ymm6 +vfmsubadd231pd %ymm15, %ymm1, %ymm7 +vsubpd %ymm4,%ymm8,%ymm10 +vsubpd %ymm5,%ymm9,%ymm11 +vaddpd %ymm6,%ymm12,%ymm14 +vaddpd %ymm7,%ymm13,%ymm15 +vaddpd %ymm4,%ymm8,%ymm8 +vaddpd %ymm5,%ymm9,%ymm9 +vsubpd %ymm6,%ymm12,%ymm12 +vsubpd %ymm7,%ymm13,%ymm13 + +.third_pass: +vmovupd 32(%rsi),%xmm0 /* gamma */ +vmovupd 48(%rsi),%xmm2 /* delta */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 +vinsertf128 $1, %xmm2, %ymm2, %ymm2 +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */ +vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */ +vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */ +vshufpd $5, %ymm9, %ymm9, %ymm4 +vshufpd $5, %ymm11, %ymm11, %ymm5 +vshufpd $5, %ymm13, %ymm13, %ymm6 +vshufpd $5, %ymm15, %ymm15, %ymm7 +vmulpd %ymm4,%ymm1,%ymm4 +vmulpd %ymm5,%ymm0,%ymm5 +vmulpd %ymm6,%ymm3,%ymm6 +vmulpd %ymm7,%ymm2,%ymm7 +vfmaddsub231pd %ymm9, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm10) +/- ymm4 +vfmsubadd231pd %ymm11, %ymm1, %ymm5 +vfmaddsub231pd %ymm13, %ymm2, %ymm6 +vfmsubadd231pd %ymm15, %ymm3, %ymm7 +vsubpd %ymm4,%ymm8,%ymm9 +vaddpd %ymm5,%ymm10,%ymm11 +vsubpd %ymm6,%ymm12,%ymm13 +vaddpd %ymm7,%ymm14,%ymm15 +vaddpd %ymm4,%ymm8,%ymm8 +vsubpd %ymm5,%ymm10,%ymm10 +vaddpd %ymm6,%ymm12,%ymm12 +vsubpd %ymm7,%ymm14,%ymm14 + +.fourth_pass: +vmovupd 64(%rsi),%ymm0 /* gamma */ +vmovupd 96(%rsi),%ymm2 /* delta */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */ 
+vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */ +vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */ +vperm2f128 $0x31,%ymm10,%ymm8,%ymm4 # ymm4 contains c1,c5 -- x gamma +vperm2f128 $0x31,%ymm11,%ymm9,%ymm5 # ymm5 contains c3,c7 -- x igamma +vperm2f128 $0x31,%ymm14,%ymm12,%ymm6 # ymm6 contains c9,c13 -- x delta +vperm2f128 $0x31,%ymm15,%ymm13,%ymm7 # ymm7 contains c11,c15 -- x idelta +vperm2f128 $0x20,%ymm10,%ymm8,%ymm8 # ymm8 contains c0,c4 +vperm2f128 $0x20,%ymm11,%ymm9,%ymm9 # ymm9 contains c2,c6 +vperm2f128 $0x20,%ymm14,%ymm12,%ymm10 # ymm10 contains c8,c12 +vperm2f128 $0x20,%ymm15,%ymm13,%ymm11 # ymm11 contains c10,c14 +vshufpd $5, %ymm4, %ymm4, %ymm12 +vshufpd $5, %ymm5, %ymm5, %ymm13 +vshufpd $5, %ymm6, %ymm6, %ymm14 +vshufpd $5, %ymm7, %ymm7, %ymm15 +vmulpd %ymm12,%ymm1,%ymm12 +vmulpd %ymm13,%ymm0,%ymm13 +vmulpd %ymm14,%ymm3,%ymm14 +vmulpd %ymm15,%ymm2,%ymm15 +vfmaddsub231pd %ymm4, %ymm0, %ymm12 # ymm12 = (ymm0 * ymm4) +/- ymm12 +vfmsubadd231pd %ymm5, %ymm1, %ymm13 +vfmaddsub231pd %ymm6, %ymm2, %ymm14 +vfmsubadd231pd %ymm7, %ymm3, %ymm15 +vsubpd %ymm12,%ymm8,%ymm4 +vaddpd %ymm13,%ymm9,%ymm5 +vsubpd %ymm14,%ymm10,%ymm6 +vaddpd %ymm15,%ymm11,%ymm7 +vaddpd %ymm12,%ymm8,%ymm8 +vsubpd %ymm13,%ymm9,%ymm9 +vaddpd %ymm14,%ymm10,%ymm10 +vsubpd %ymm15,%ymm11,%ymm11 + +vperm2f128 $0x20,%ymm6,%ymm10,%ymm12 # ymm4 contains c1,c5 -- x gamma +vperm2f128 $0x20,%ymm7,%ymm11,%ymm13 # ymm5 contains c3,c7 -- x igamma +vperm2f128 $0x31,%ymm6,%ymm10,%ymm14 # ymm6 contains c9,c13 -- x delta +vperm2f128 $0x31,%ymm7,%ymm11,%ymm15 # ymm7 contains c11,c15 -- x idelta +vperm2f128 $0x31,%ymm4,%ymm8,%ymm10 # ymm10 contains c8,c12 +vperm2f128 $0x31,%ymm5,%ymm9,%ymm11 # ymm11 contains c10,c14 +vperm2f128 $0x20,%ymm4,%ymm8,%ymm8 # ymm8 contains c0,c4 +vperm2f128 $0x20,%ymm5,%ymm9,%ymm9 # ymm9 contains c2,c6 + +.save_and_return: +vmovupd %ymm8,(%rdi) +vmovupd %ymm9,0x20(%rdi) +vmovupd %ymm10,0x40(%rdi) +vmovupd %ymm11,0x60(%rdi) +vmovupd %ymm12,0x80(%rdi) +vmovupd %ymm13,0xa0(%rdi) +vmovupd %ymm14,0xc0(%rdi) +vmovupd %ymm15,0xe0(%rdi) +ret diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft_asserts.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft_asserts.c new file mode 100644 index 0000000..924aed0 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft_asserts.c @@ -0,0 +1,8 @@ +#include "../commons_private.h" +#include "cplx_fft_private.h" + +__always_inline void my_asserts() { + STATIC_ASSERT(sizeof(FFT_FUNCTION) == 8); + STATIC_ASSERT(sizeof(CPLX_FFT_PRECOMP) == 40); + STATIC_ASSERT(sizeof(CPLX_IFFT_PRECOMP) == 40); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft_avx2_fma.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft_avx2_fma.c new file mode 100644 index 0000000..bdd5dad --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft_avx2_fma.c @@ -0,0 +1,266 @@ +#include +#include + +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +typedef double D4MEM[4]; + +/** + * @brief complex fft via bfs strategy (for m between 2 and 8) + * @param dat the data to run the algorithm on + * @param omg precomputed tables (must have been filled with fill_omega) + * @param m ring dimension of the FFT (modulo X^m-i) + */ +void cplx_fft_avx2_fma_bfs_2(D4MEM* dat, const D4MEM** omg, uint32_t m) { + double* data = (double*)dat; + int32_t _2nblock = m >> 1; 
// = h in ref code + D4MEM* const finaldd = (D4MEM*)(data + 2 * m); + while (_2nblock >= 2) { + int32_t nblock = _2nblock >> 1; // =h/2 in ref code + D4MEM* dd = (D4MEM*)data; + do { + const __m256d om = _mm256_load_pd(*omg[0]); + const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om)); + const __m256d omre = _mm256_unpacklo_pd(om, om); + D4MEM* const ddend = (dd + nblock); + D4MEM* ddmid = ddend; + do { + const __m256d b = _mm256_loadu_pd(ddmid[0]); + const __m256d t1 = _mm256_mul_pd(b, omre); + const __m256d barb = _mm256_shuffle_pd(b, b, 5); + const __m256d t2 = _mm256_fmadd_pd(barb, omim, t1); + const __m256d a = _mm256_loadu_pd(dd[0]); + const __m256d newa = _mm256_add_pd(a, t2); + const __m256d newb = _mm256_sub_pd(a, t2); + _mm256_storeu_pd(dd[0], newa); + _mm256_storeu_pd(ddmid[0], newb); + dd += 1; + ddmid += 1; + } while (dd < ddend); + dd += nblock; + *omg += 1; + } while (dd < finaldd); + _2nblock >>= 1; + } + // last iteration when _2nblock == 1 + { + D4MEM* dd = (D4MEM*)data; + do { + const __m256d om = _mm256_load_pd(*omg[0]); + const __m256d omre = _mm256_unpacklo_pd(om, om); + const __m256d omim = _mm256_unpackhi_pd(om, om); + const __m256d ab = _mm256_loadu_pd(dd[0]); + const __m256d bb = _mm256_permute4x64_pd(ab, 0b11101110); + const __m256d bbbar = _mm256_permute4x64_pd(ab, 0b10111011); + const __m256d t1 = _mm256_mul_pd(bbbar, omim); + const __m256d t2 = _mm256_fmaddsub_pd(bb, omre, t1); + const __m256d aa = _mm256_permute4x64_pd(ab, 0b01000100); + const __m256d newab = _mm256_add_pd(aa, t2); + _mm256_storeu_pd(dd[0], newab); + dd += 1; + *omg += 1; + } while (dd < finaldd); + } +} + +__always_inline void cplx_twiddle_fft_avx2(int32_t h, D4MEM* data, const void* omg) { + const __m256d om = _mm256_loadu_pd(omg); + const __m256d omim = _mm256_unpackhi_pd(om, om); + const __m256d omre = _mm256_unpacklo_pd(om, om); + D4MEM* d0 = data; + D4MEM* const ddend = d0 + (h >> 1); + D4MEM* d1 = ddend; + do { + const __m256d b = _mm256_loadu_pd(d1[0]); + const __m256d barb = _mm256_shuffle_pd(b, b, 5); + const __m256d t1 = _mm256_mul_pd(barb, omim); + const __m256d t2 = _mm256_fmaddsub_pd(b, omre, t1); + const __m256d a = _mm256_loadu_pd(d0[0]); + const __m256d newa = _mm256_add_pd(a, t2); + const __m256d newb = _mm256_sub_pd(a, t2); + _mm256_storeu_pd(d0[0], newa); + _mm256_storeu_pd(d1[0], newb); + d0 += 1; + d1 += 1; + } while (d0 < ddend); +} + +__always_inline void cplx_bitwiddle_fft_avx2(int32_t h, void* data, const void* powom) { + const __m256d omx = _mm256_loadu_pd(powom); + const __m256d oma = _mm256_permute2f128_pd(omx, omx, 0x00); + const __m256d omb = _mm256_permute2f128_pd(omx, omx, 0x11); + const __m256d omaim = _mm256_unpackhi_pd(oma, oma); + const __m256d omare = _mm256_unpacklo_pd(oma, oma); + const __m256d ombim = _mm256_unpackhi_pd(omb, omb); + const __m256d ombre = _mm256_unpacklo_pd(omb, omb); + D4MEM* d0 = (D4MEM*)data; + D4MEM* const ddend = d0 + (h >> 1); + D4MEM* d1 = ddend; + D4MEM* d2 = d0 + h; + D4MEM* d3 = d1 + h; + __m256d reg0, reg1, reg2, reg3, tmp0, tmp1; + do { + reg0 = _mm256_loadu_pd(d0[0]); + reg1 = _mm256_loadu_pd(d1[0]); + reg2 = _mm256_loadu_pd(d2[0]); + reg3 = _mm256_loadu_pd(d3[0]); + tmp0 = _mm256_shuffle_pd(reg2, reg2, 5); + tmp1 = _mm256_shuffle_pd(reg3, reg3, 5); + tmp0 = _mm256_mul_pd(tmp0, omaim); + tmp1 = _mm256_mul_pd(tmp1, omaim); + tmp0 = _mm256_fmaddsub_pd(reg2, omare, tmp0); + tmp1 = _mm256_fmaddsub_pd(reg3, omare, tmp1); + reg2 = _mm256_sub_pd(reg0, tmp0); + reg3 = _mm256_sub_pd(reg1, tmp1); + 
reg0 = _mm256_add_pd(reg0, tmp0); + reg1 = _mm256_add_pd(reg1, tmp1); + //-------------------------------------- + tmp0 = _mm256_shuffle_pd(reg1, reg1, 5); + tmp1 = _mm256_shuffle_pd(reg3, reg3, 5); + tmp0 = _mm256_mul_pd(tmp0, ombim); //(r,i) + tmp1 = _mm256_mul_pd(tmp1, ombre); //(-i,r) + tmp0 = _mm256_fmaddsub_pd(reg1, ombre, tmp0); + tmp1 = _mm256_fmsubadd_pd(reg3, ombim, tmp1); + reg1 = _mm256_sub_pd(reg0, tmp0); + reg3 = _mm256_add_pd(reg2, tmp1); + reg0 = _mm256_add_pd(reg0, tmp0); + reg2 = _mm256_sub_pd(reg2, tmp1); + ///// + _mm256_storeu_pd(d0[0], reg0); + _mm256_storeu_pd(d1[0], reg1); + _mm256_storeu_pd(d2[0], reg2); + _mm256_storeu_pd(d3[0], reg3); + d0 += 1; + d1 += 1; + d2 += 1; + d3 += 1; + } while (d0 < ddend); +} + +/** + * @brief complex fft via bfs strategy (for m >= 16) + * @param dat the data to run the algorithm on + * @param omg precomputed tables (must have been filled with fill_omega) + * @param m ring dimension of the FFT (modulo X^m-i) + */ +void cplx_fft_avx2_fma_bfs_16(D4MEM* dat, const D4MEM** omg, uint32_t m) { + double* data = (double*)dat; + D4MEM* const finaldd = (D4MEM*)(data + 2 * m); + uint32_t mm = m; + uint32_t log2m = _mm_popcnt_u32(m - 1); // log2(m) + if (log2m % 2 == 1) { + uint32_t h = mm >> 1; + cplx_twiddle_fft_avx2(h, dat, **omg); + *omg += 1; + mm >>= 1; + } + while (mm > 16) { + uint32_t h = mm / 4; + for (CPLX* d = (CPLX*)data; d < (CPLX*)finaldd; d += mm) { + cplx_bitwiddle_fft_avx2(h, d, (CPLX*)*omg); + *omg += 1; + } + mm = h; + } + { + D4MEM* dd = (D4MEM*)data; + do { + cplx_fft16_avx_fma(dd, *omg); + dd += 8; + *omg += 4; + } while (dd < finaldd); + _mm256_zeroupper(); + } + /* + int32_t _2nblock = m >> 1; // = h in ref code + while (_2nblock >= 16) { + int32_t nblock = _2nblock >> 1; // =h/2 in ref code + D4MEM* dd = (D4MEM*)data; + do { + const __m256d om = _mm256_load_pd(*omg[0]); + const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om)); + const __m256d omre = _mm256_unpacklo_pd(om, om); + D4MEM* const ddend = (dd + nblock); + D4MEM* ddmid = ddend; + do { + const __m256d b = _mm256_loadu_pd(ddmid[0]); + const __m256d t1 = _mm256_mul_pd(b, omre); + const __m256d barb = _mm256_shuffle_pd(b, b, 5); + const __m256d t2 = _mm256_fmadd_pd(barb, omim, t1); + const __m256d a = _mm256_loadu_pd(dd[0]); + const __m256d newa = _mm256_add_pd(a, t2); + const __m256d newb = _mm256_sub_pd(a, t2); + _mm256_storeu_pd(dd[0], newa); + _mm256_storeu_pd(ddmid[0], newb); + dd += 1; + ddmid += 1; + } while (dd < ddend); + dd += nblock; + *omg += 1; + } while (dd < finaldd); + _2nblock >>= 1; + } + // last iteration when _2nblock == 8 + { + D4MEM* dd = (D4MEM*)data; + do { + cplx_fft16_avx_fma(dd, *omg); + dd += 8; + *omg += 4; + } while (dd < finaldd); + _mm256_zeroupper(); + } + */ +} + +/** + * @brief complex fft via dfs recursion (for m >= 16) + * @param dat the data to run the algorithm on + * @param omg precomputed tables (must have been filled with fill_omega) + * @param m ring dimension of the FFT (modulo X^m-i) + */ +void cplx_fft_avx2_fma_rec_16(D4MEM* dat, const D4MEM** omg, uint32_t m) { + if (m <= 8) return cplx_fft_avx2_fma_bfs_2(dat, omg, m); + if (m <= 2048) return cplx_fft_avx2_fma_bfs_16(dat, omg, m); + double* data = (double*)dat; + int32_t _2nblock = m >> 1; // = h in ref code + int32_t nblock = _2nblock >> 1; // =h/2 in ref code + D4MEM* dd = (D4MEM*)data; + const __m256d om = _mm256_load_pd(*omg[0]); + const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om)); + const __m256d 
omre = _mm256_unpacklo_pd(om, om); + D4MEM* const ddend = (dd + nblock); + D4MEM* ddmid = ddend; + do { + const __m256d b = _mm256_loadu_pd(ddmid[0]); + const __m256d t1 = _mm256_mul_pd(b, omre); + const __m256d barb = _mm256_shuffle_pd(b, b, 5); + const __m256d t2 = _mm256_fmadd_pd(barb, omim, t1); + const __m256d a = _mm256_loadu_pd(dd[0]); + const __m256d newa = _mm256_add_pd(a, t2); + const __m256d newb = _mm256_sub_pd(a, t2); + _mm256_storeu_pd(dd[0], newa); + _mm256_storeu_pd(ddmid[0], newb); + dd += 1; + ddmid += 1; + } while (dd < ddend); + *omg += 1; + cplx_fft_avx2_fma_rec_16(dat, omg, _2nblock); + cplx_fft_avx2_fma_rec_16(ddend, omg, _2nblock); +} + +/** + * @brief complex fft via best strategy (for m>=1) + * @param dat the data to run the algorithm on: m complex numbers + * @param omg precomputed tables (must have been filled with fill_omega) + * @param m ring dimension of the FFT (modulo X^m-i) + */ +EXPORT void cplx_fft_avx2_fma(const CPLX_FFT_PRECOMP* precomp, void* d) { + const uint32_t m = precomp->m; + const D4MEM* omg = (D4MEM*)precomp->powomegas; + if (m <= 1) return; + if (m <= 8) return cplx_fft_avx2_fma_bfs_2(d, &omg, m); + if (m <= 2048) return cplx_fft_avx2_fma_bfs_16(d, &omg, m); + cplx_fft_avx2_fma_rec_16(d, &omg, m); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft_avx512.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft_avx512.c new file mode 100644 index 0000000..f0cf764 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft_avx512.c @@ -0,0 +1,451 @@ +#include + +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +typedef double D2MEM[2]; +typedef double D4MEM[4]; +typedef double D8MEM[8]; + +EXPORT void cplx_fftvec_addmul_avx512(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, + const void* b) { + const uint32_t m = precomp->m; + const D8MEM* aa = (D8MEM*)a; + const D8MEM* bb = (D8MEM*)b; + D8MEM* rr = (D8MEM*)r; + const D8MEM* const aend = aa + (m >> 2); + do { + /* +BEGIN_TEMPLATE +const __m512d ari% = _mm512_loadu_pd(aa[%]); +const __m512d bri% = _mm512_loadu_pd(bb[%]); +const __m512d rri% = _mm512_loadu_pd(rr[%]); +const __m512d bir% = _mm512_shuffle_pd(bri%,bri%, 0b01010101); +const __m512d aii% = _mm512_shuffle_pd(ari%,ari%, 0b11111111); +const __m512d pro% = _mm512_fmaddsub_pd(aii%,bir%,rri%); +const __m512d arr% = _mm512_shuffle_pd(ari%,ari%, 0b00000000); +const __m512d res% = _mm512_fmaddsub_pd(arr%,bri%,pro%); +_mm512_storeu_pd(rr[%],res%); +rr += @; // ONCE +aa += @; // ONCE +bb += @; // ONCE +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 2 + const __m512d ari0 = _mm512_loadu_pd(aa[0]); + const __m512d ari1 = _mm512_loadu_pd(aa[1]); + const __m512d bri0 = _mm512_loadu_pd(bb[0]); + const __m512d bri1 = _mm512_loadu_pd(bb[1]); + const __m512d rri0 = _mm512_loadu_pd(rr[0]); + const __m512d rri1 = _mm512_loadu_pd(rr[1]); + const __m512d bir0 = _mm512_shuffle_pd(bri0, bri0, 0b01010101); + const __m512d bir1 = _mm512_shuffle_pd(bri1, bri1, 0b01010101); + const __m512d aii0 = _mm512_shuffle_pd(ari0, ari0, 0b11111111); + const __m512d aii1 = _mm512_shuffle_pd(ari1, ari1, 0b11111111); + const __m512d pro0 = _mm512_fmaddsub_pd(aii0, bir0, rri0); + const __m512d pro1 = _mm512_fmaddsub_pd(aii1, bir1, rri1); + const __m512d arr0 = _mm512_shuffle_pd(ari0, ari0, 0b00000000); + const __m512d arr1 = _mm512_shuffle_pd(ari1, ari1, 0b00000000); + const __m512d res0 = _mm512_fmaddsub_pd(arr0, bri0, pro0); + const __m512d res1 = 
_mm512_fmaddsub_pd(arr1, bri1, pro1); + _mm512_storeu_pd(rr[0], res0); + _mm512_storeu_pd(rr[1], res1); + rr += 2; // ONCE + aa += 2; // ONCE + bb += 2; // ONCE + // END_INTERLEAVE + } while (aa < aend); +} + +#if 0 +EXPORT void cplx_fftvec_mul_fma(uint32_t m, void* r, const void* a, const void* b) { + const double(*aa)[4] = (double(*)[4])a; + const double(*bb)[4] = (double(*)[4])b; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d bir% = _mm256_shuffle_pd(bri%,bri%, 5); // conj of b +const __m256d aii% = _mm256_shuffle_pd(ari%,ari%, 15); // im of a +const __m256d pro% = _mm256_mul_pd(aii%,bir%); +const __m256d arr% = _mm256_shuffle_pd(ari%,ari%, 0); // rr of a +const __m256d res% = _mm256_fmaddsub_pd(arr%,bri%,pro%); +_mm256_storeu_pd(rr[%],res%); +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 +const __m256d ari0 = _mm256_loadu_pd(aa[0]); +const __m256d ari1 = _mm256_loadu_pd(aa[1]); +const __m256d ari2 = _mm256_loadu_pd(aa[2]); +const __m256d ari3 = _mm256_loadu_pd(aa[3]); +const __m256d bri0 = _mm256_loadu_pd(bb[0]); +const __m256d bri1 = _mm256_loadu_pd(bb[1]); +const __m256d bri2 = _mm256_loadu_pd(bb[2]); +const __m256d bri3 = _mm256_loadu_pd(bb[3]); +const __m256d bir0 = _mm256_shuffle_pd(bri0,bri0, 5); // conj of b +const __m256d bir1 = _mm256_shuffle_pd(bri1,bri1, 5); // conj of b +const __m256d bir2 = _mm256_shuffle_pd(bri2,bri2, 5); // conj of b +const __m256d bir3 = _mm256_shuffle_pd(bri3,bri3, 5); // conj of b +const __m256d aii0 = _mm256_shuffle_pd(ari0,ari0, 15); // im of a +const __m256d aii1 = _mm256_shuffle_pd(ari1,ari1, 15); // im of a +const __m256d aii2 = _mm256_shuffle_pd(ari2,ari2, 15); // im of a +const __m256d aii3 = _mm256_shuffle_pd(ari3,ari3, 15); // im of a +const __m256d pro0 = _mm256_mul_pd(aii0,bir0); +const __m256d pro1 = _mm256_mul_pd(aii1,bir1); +const __m256d pro2 = _mm256_mul_pd(aii2,bir2); +const __m256d pro3 = _mm256_mul_pd(aii3,bir3); +const __m256d arr0 = _mm256_shuffle_pd(ari0,ari0, 0); // rr of a +const __m256d arr1 = _mm256_shuffle_pd(ari1,ari1, 0); // rr of a +const __m256d arr2 = _mm256_shuffle_pd(ari2,ari2, 0); // rr of a +const __m256d arr3 = _mm256_shuffle_pd(ari3,ari3, 0); // rr of a +const __m256d res0 = _mm256_fmaddsub_pd(arr0,bri0,pro0); +const __m256d res1 = _mm256_fmaddsub_pd(arr1,bri1,pro1); +const __m256d res2 = _mm256_fmaddsub_pd(arr2,bri2,pro2); +const __m256d res3 = _mm256_fmaddsub_pd(arr3,bri3,pro3); +_mm256_storeu_pd(rr[0],res0); +_mm256_storeu_pd(rr[1],res1); +_mm256_storeu_pd(rr[2],res2); +_mm256_storeu_pd(rr[3],res3); + // END_INTERLEAVE + rr += 4; + aa += 4; + bb += 4; + } while (aa < aend); +} + +EXPORT void cplx_fftvec_add_fma(uint32_t m, void* r, const void* a, const void* b) { + const double(*aa)[4] = (double(*)[4])a; + const double(*bb)[4] = (double(*)[4])b; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d res% = _mm256_add_pd(ari%,bri%); +_mm256_storeu_pd(rr[%],res%); +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 +const __m256d ari0 = _mm256_loadu_pd(aa[0]); +const __m256d ari1 = _mm256_loadu_pd(aa[1]); +const __m256d ari2 = _mm256_loadu_pd(aa[2]); +const __m256d ari3 = _mm256_loadu_pd(aa[3]); +const __m256d bri0 = _mm256_loadu_pd(bb[0]); +const __m256d bri1 = _mm256_loadu_pd(bb[1]); +const 
__m256d bri2 = _mm256_loadu_pd(bb[2]); +const __m256d bri3 = _mm256_loadu_pd(bb[3]); +const __m256d res0 = _mm256_add_pd(ari0,bri0); +const __m256d res1 = _mm256_add_pd(ari1,bri1); +const __m256d res2 = _mm256_add_pd(ari2,bri2); +const __m256d res3 = _mm256_add_pd(ari3,bri3); +_mm256_storeu_pd(rr[0],res0); +_mm256_storeu_pd(rr[1],res1); +_mm256_storeu_pd(rr[2],res2); +_mm256_storeu_pd(rr[3],res3); + // END_INTERLEAVE + rr += 4; + aa += 4; + bb += 4; + } while (aa < aend); +} + +EXPORT void cplx_fftvec_sub2_to_fma(uint32_t m, void* r, const void* a, const void* b) { + const double(*aa)[4] = (double(*)[4])a; + const double(*bb)[4] = (double(*)[4])b; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d sum% = _mm256_add_pd(ari%,bri%); +const __m256d rri% = _mm256_loadu_pd(rr[%]); +const __m256d res% = _mm256_sub_pd(rri%,sum%); +_mm256_storeu_pd(rr[%],res%); +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 +const __m256d ari0 = _mm256_loadu_pd(aa[0]); +const __m256d ari1 = _mm256_loadu_pd(aa[1]); +const __m256d ari2 = _mm256_loadu_pd(aa[2]); +const __m256d ari3 = _mm256_loadu_pd(aa[3]); +const __m256d bri0 = _mm256_loadu_pd(bb[0]); +const __m256d bri1 = _mm256_loadu_pd(bb[1]); +const __m256d bri2 = _mm256_loadu_pd(bb[2]); +const __m256d bri3 = _mm256_loadu_pd(bb[3]); +const __m256d sum0 = _mm256_add_pd(ari0,bri0); +const __m256d sum1 = _mm256_add_pd(ari1,bri1); +const __m256d sum2 = _mm256_add_pd(ari2,bri2); +const __m256d sum3 = _mm256_add_pd(ari3,bri3); +const __m256d rri0 = _mm256_loadu_pd(rr[0]); +const __m256d rri1 = _mm256_loadu_pd(rr[1]); +const __m256d rri2 = _mm256_loadu_pd(rr[2]); +const __m256d rri3 = _mm256_loadu_pd(rr[3]); +const __m256d res0 = _mm256_sub_pd(rri0,sum0); +const __m256d res1 = _mm256_sub_pd(rri1,sum1); +const __m256d res2 = _mm256_sub_pd(rri2,sum2); +const __m256d res3 = _mm256_sub_pd(rri3,sum3); +_mm256_storeu_pd(rr[0],res0); +_mm256_storeu_pd(rr[1],res1); +_mm256_storeu_pd(rr[2],res2); +_mm256_storeu_pd(rr[3],res3); + // END_INTERLEAVE + rr += 4; + aa += 4; + bb += 4; + } while (aa < aend); +} + +EXPORT void cplx_fftvec_copy_fma(uint32_t m, void* r, const void* a) { + const double(*aa)[4] = (double(*)[4])a; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +_mm256_storeu_pd(rr[%],ari%); +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 +const __m256d ari0 = _mm256_loadu_pd(aa[0]); +const __m256d ari1 = _mm256_loadu_pd(aa[1]); +const __m256d ari2 = _mm256_loadu_pd(aa[2]); +const __m256d ari3 = _mm256_loadu_pd(aa[3]); +_mm256_storeu_pd(rr[0],ari0); +_mm256_storeu_pd(rr[1],ari1); +_mm256_storeu_pd(rr[2],ari2); +_mm256_storeu_pd(rr[3],ari3); + // END_INTERLEAVE + rr += 4; + aa += 4; + } while (aa < aend); +} + +EXPORT void cplx_fftvec_twiddle_fma(uint32_t m, void* a, void* b, const void* omg) { + double(*aa)[4] = (double(*)[4])a; + double(*bb)[4] = (double(*)[4])b; + const double(*const aend)[4] = aa + (m >> 1); + const __m256d om = _mm256_loadu_pd(omg); + const __m256d omrr = _mm256_shuffle_pd(om, om, 0); + const __m256d omii = _mm256_shuffle_pd(om, om, 15); + do { + /* +BEGIN_TEMPLATE +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d bir% = _mm256_shuffle_pd(bri%,bri%,5); +__m256d p% = _mm256_mul_pd(bir%,omii); +p% = _mm256_fmaddsub_pd(bri%,omrr,p%); +const __m256d ari% = _mm256_loadu_pd(aa[%]); 
+_mm256_storeu_pd(aa[%],_mm256_add_pd(ari%,p%)); +_mm256_storeu_pd(bb[%],_mm256_sub_pd(ari%,p%)); +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 +const __m256d bri0 = _mm256_loadu_pd(bb[0]); +const __m256d bri1 = _mm256_loadu_pd(bb[1]); +const __m256d bri2 = _mm256_loadu_pd(bb[2]); +const __m256d bri3 = _mm256_loadu_pd(bb[3]); +const __m256d bir0 = _mm256_shuffle_pd(bri0,bri0,5); +const __m256d bir1 = _mm256_shuffle_pd(bri1,bri1,5); +const __m256d bir2 = _mm256_shuffle_pd(bri2,bri2,5); +const __m256d bir3 = _mm256_shuffle_pd(bri3,bri3,5); +__m256d p0 = _mm256_mul_pd(bir0,omii); +__m256d p1 = _mm256_mul_pd(bir1,omii); +__m256d p2 = _mm256_mul_pd(bir2,omii); +__m256d p3 = _mm256_mul_pd(bir3,omii); +p0 = _mm256_fmaddsub_pd(bri0,omrr,p0); +p1 = _mm256_fmaddsub_pd(bri1,omrr,p1); +p2 = _mm256_fmaddsub_pd(bri2,omrr,p2); +p3 = _mm256_fmaddsub_pd(bri3,omrr,p3); +const __m256d ari0 = _mm256_loadu_pd(aa[0]); +const __m256d ari1 = _mm256_loadu_pd(aa[1]); +const __m256d ari2 = _mm256_loadu_pd(aa[2]); +const __m256d ari3 = _mm256_loadu_pd(aa[3]); +_mm256_storeu_pd(aa[0],_mm256_add_pd(ari0,p0)); +_mm256_storeu_pd(aa[1],_mm256_add_pd(ari1,p1)); +_mm256_storeu_pd(aa[2],_mm256_add_pd(ari2,p2)); +_mm256_storeu_pd(aa[3],_mm256_add_pd(ari3,p3)); +_mm256_storeu_pd(bb[0],_mm256_sub_pd(ari0,p0)); +_mm256_storeu_pd(bb[1],_mm256_sub_pd(ari1,p1)); +_mm256_storeu_pd(bb[2],_mm256_sub_pd(ari2,p2)); +_mm256_storeu_pd(bb[3],_mm256_sub_pd(ari3,p3)); + // END_INTERLEAVE + bb += 4; + aa += 4; + } while (aa < aend); +} +#endif + +EXPORT void cplx_fftvec_twiddle_avx512(const CPLX_FFTVEC_TWIDDLE_PRECOMP* precomp, void* a, void* b, const void* omg) { + const uint32_t m = precomp->m; + D8MEM* aa = (D8MEM*)a; + D8MEM* bb = (D8MEM*)b; + D8MEM* const aend = aa + (m >> 2); + const __m512d om = _mm512_broadcast_f64x4(_mm256_loadu_pd(omg)); + const __m512d omrr = _mm512_shuffle_pd(om, om, 0b00000000); + const __m512d omii = _mm512_shuffle_pd(om, om, 0b11111111); + do { + /* + BEGIN_TEMPLATE + const __m512d bri% = _mm512_loadu_pd(bb[%]); + const __m512d bir% = _mm512_shuffle_pd(bri%,bri%,0b10011001); + __m512d p% = _mm512_mul_pd(bir%,omii); + p% = _mm512_fmaddsub_pd(bri%,omrr,p%); + const __m512d ari% = _mm512_loadu_pd(aa[%]); + _mm512_storeu_pd(aa[%],_mm512_add_pd(ari%,p%)); + _mm512_storeu_pd(bb[%],_mm512_sub_pd(ari%,p%)); + bb += @; // ONCE + aa += @; // ONCE + END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 + const __m512d bri0 = _mm512_loadu_pd(bb[0]); + const __m512d bri1 = _mm512_loadu_pd(bb[1]); + const __m512d bri2 = _mm512_loadu_pd(bb[2]); + const __m512d bri3 = _mm512_loadu_pd(bb[3]); + const __m512d bir0 = _mm512_shuffle_pd(bri0, bri0, 0b10011001); + const __m512d bir1 = _mm512_shuffle_pd(bri1, bri1, 0b10011001); + const __m512d bir2 = _mm512_shuffle_pd(bri2, bri2, 0b10011001); + const __m512d bir3 = _mm512_shuffle_pd(bri3, bri3, 0b10011001); + __m512d p0 = _mm512_mul_pd(bir0, omii); + __m512d p1 = _mm512_mul_pd(bir1, omii); + __m512d p2 = _mm512_mul_pd(bir2, omii); + __m512d p3 = _mm512_mul_pd(bir3, omii); + p0 = _mm512_fmaddsub_pd(bri0, omrr, p0); + p1 = _mm512_fmaddsub_pd(bri1, omrr, p1); + p2 = _mm512_fmaddsub_pd(bri2, omrr, p2); + p3 = _mm512_fmaddsub_pd(bri3, omrr, p3); + const __m512d ari0 = _mm512_loadu_pd(aa[0]); + const __m512d ari1 = _mm512_loadu_pd(aa[1]); + const __m512d ari2 = _mm512_loadu_pd(aa[2]); + const __m512d ari3 = _mm512_loadu_pd(aa[3]); + _mm512_storeu_pd(aa[0], _mm512_add_pd(ari0, p0)); + _mm512_storeu_pd(aa[1], _mm512_add_pd(ari1, p1)); + _mm512_storeu_pd(aa[2], _mm512_add_pd(ari2, p2)); + 
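+    // note: the aa[]/bb[] stores in this block realise the radix-2 butterfly
+    // (a, b) <- (a + p, a - p), where p plays the role of om*b, assembled above
+    // from omrr/omii with the shuffle + fmaddsub construction (see ctwiddle in
+    // cplx_fft_ref.c for the scalar form of the same step).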
_mm512_storeu_pd(aa[3], _mm512_add_pd(ari3, p3)); + _mm512_storeu_pd(bb[0], _mm512_sub_pd(ari0, p0)); + _mm512_storeu_pd(bb[1], _mm512_sub_pd(ari1, p1)); + _mm512_storeu_pd(bb[2], _mm512_sub_pd(ari2, p2)); + _mm512_storeu_pd(bb[3], _mm512_sub_pd(ari3, p3)); + bb += 4; // ONCE + aa += 4; // ONCE + // END_INTERLEAVE + } while (aa < aend); +} + +EXPORT void cplx_fftvec_bitwiddle_avx512(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* precomp, void* a, uint64_t slicea, + const void* omg) { + const uint32_t m = precomp->m; + const uint64_t OFFSET = slicea / sizeof(D8MEM); + D8MEM* aa = (D8MEM*)a; + const D8MEM* aend = aa + (m >> 2); + const __m512d om = _mm512_broadcast_f64x4(_mm256_loadu_pd(omg)); + const __m512d om1rr = _mm512_shuffle_pd(om, om, 0); + const __m512d om1ii = _mm512_shuffle_pd(om, om, 15); + const __m512d om2rr = _mm512_shuffle_pd(om, om, 0); + const __m512d om2ii = _mm512_shuffle_pd(om, om, 0); + const __m512d om3rr = _mm512_shuffle_pd(om, om, 15); + const __m512d om3ii = _mm512_shuffle_pd(om, om, 15); + do { + /* + BEGIN_TEMPLATE + __m512d ari% = _mm512_loadu_pd(aa[%]); + __m512d bri% = _mm512_loadu_pd((aa+OFFSET)[%]); + __m512d cri% = _mm512_loadu_pd((aa+2*OFFSET)[%]); + __m512d dri% = _mm512_loadu_pd((aa+3*OFFSET)[%]); + __m512d pa% = _mm512_shuffle_pd(cri%,cri%,5); + __m512d pb% = _mm512_shuffle_pd(dri%,dri%,5); + pa% = _mm512_mul_pd(pa%,om1ii); + pb% = _mm512_mul_pd(pb%,om1ii); + pa% = _mm512_fmaddsub_pd(cri%,om1rr,pa%); + pb% = _mm512_fmaddsub_pd(dri%,om1rr,pb%); + cri% = _mm512_sub_pd(ari%,pa%); + dri% = _mm512_sub_pd(bri%,pb%); + ari% = _mm512_add_pd(ari%,pa%); + bri% = _mm512_add_pd(bri%,pb%); + pa% = _mm512_shuffle_pd(bri%,bri%,5); + pb% = _mm512_shuffle_pd(dri%,dri%,5); + pa% = _mm512_mul_pd(pa%,om2ii); + pb% = _mm512_mul_pd(pb%,om3ii); + pa% = _mm512_fmaddsub_pd(bri%,om2rr,pa%); + pb% = _mm512_fmaddsub_pd(dri%,om3rr,pb%); + bri% = _mm512_sub_pd(ari%,pa%); + dri% = _mm512_sub_pd(cri%,pb%); + ari% = _mm512_add_pd(ari%,pa%); + cri% = _mm512_add_pd(cri%,pb%); + _mm512_storeu_pd(aa[%], ari%); + _mm512_storeu_pd((aa+OFFSET)[%],bri%); + _mm512_storeu_pd((aa+2*OFFSET)[%],cri%); + _mm512_storeu_pd((aa+3*OFFSET)[%],dri%); + aa += @; // ONCE + END_TEMPLATE + */ + // BEGIN_INTERLEAVE 2 + __m512d ari0 = _mm512_loadu_pd(aa[0]); + __m512d ari1 = _mm512_loadu_pd(aa[1]); + __m512d bri0 = _mm512_loadu_pd((aa + OFFSET)[0]); + __m512d bri1 = _mm512_loadu_pd((aa + OFFSET)[1]); + __m512d cri0 = _mm512_loadu_pd((aa + 2 * OFFSET)[0]); + __m512d cri1 = _mm512_loadu_pd((aa + 2 * OFFSET)[1]); + __m512d dri0 = _mm512_loadu_pd((aa + 3 * OFFSET)[0]); + __m512d dri1 = _mm512_loadu_pd((aa + 3 * OFFSET)[1]); + __m512d pa0 = _mm512_shuffle_pd(cri0, cri0, 5); + __m512d pa1 = _mm512_shuffle_pd(cri1, cri1, 5); + __m512d pb0 = _mm512_shuffle_pd(dri0, dri0, 5); + __m512d pb1 = _mm512_shuffle_pd(dri1, dri1, 5); + pa0 = _mm512_mul_pd(pa0, om1ii); + pa1 = _mm512_mul_pd(pa1, om1ii); + pb0 = _mm512_mul_pd(pb0, om1ii); + pb1 = _mm512_mul_pd(pb1, om1ii); + pa0 = _mm512_fmaddsub_pd(cri0, om1rr, pa0); + pa1 = _mm512_fmaddsub_pd(cri1, om1rr, pa1); + pb0 = _mm512_fmaddsub_pd(dri0, om1rr, pb0); + pb1 = _mm512_fmaddsub_pd(dri1, om1rr, pb1); + cri0 = _mm512_sub_pd(ari0, pa0); + cri1 = _mm512_sub_pd(ari1, pa1); + dri0 = _mm512_sub_pd(bri0, pb0); + dri1 = _mm512_sub_pd(bri1, pb1); + ari0 = _mm512_add_pd(ari0, pa0); + ari1 = _mm512_add_pd(ari1, pa1); + bri0 = _mm512_add_pd(bri0, pb0); + bri1 = _mm512_add_pd(bri1, pb1); + pa0 = _mm512_shuffle_pd(bri0, bri0, 5); + pa1 = _mm512_shuffle_pd(bri1, bri1, 5); + pb0 = _mm512_shuffle_pd(dri0, 
dri0, 5); + pb1 = _mm512_shuffle_pd(dri1, dri1, 5); + pa0 = _mm512_mul_pd(pa0, om2ii); + pa1 = _mm512_mul_pd(pa1, om2ii); + pb0 = _mm512_mul_pd(pb0, om3ii); + pb1 = _mm512_mul_pd(pb1, om3ii); + pa0 = _mm512_fmaddsub_pd(bri0, om2rr, pa0); + pa1 = _mm512_fmaddsub_pd(bri1, om2rr, pa1); + pb0 = _mm512_fmaddsub_pd(dri0, om3rr, pb0); + pb1 = _mm512_fmaddsub_pd(dri1, om3rr, pb1); + bri0 = _mm512_sub_pd(ari0, pa0); + bri1 = _mm512_sub_pd(ari1, pa1); + dri0 = _mm512_sub_pd(cri0, pb0); + dri1 = _mm512_sub_pd(cri1, pb1); + ari0 = _mm512_add_pd(ari0, pa0); + ari1 = _mm512_add_pd(ari1, pa1); + cri0 = _mm512_add_pd(cri0, pb0); + cri1 = _mm512_add_pd(cri1, pb1); + _mm512_storeu_pd(aa[0], ari0); + _mm512_storeu_pd(aa[1], ari1); + _mm512_storeu_pd((aa + OFFSET)[0], bri0); + _mm512_storeu_pd((aa + OFFSET)[1], bri1); + _mm512_storeu_pd((aa + 2 * OFFSET)[0], cri0); + _mm512_storeu_pd((aa + 2 * OFFSET)[1], cri1); + _mm512_storeu_pd((aa + 3 * OFFSET)[0], dri0); + _mm512_storeu_pd((aa + 3 * OFFSET)[1], dri1); + aa += 2; // ONCE + // END_INTERLEAVE + } while (aa < aend); + _mm256_zeroupper(); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft_internal.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft_internal.h new file mode 100644 index 0000000..7aa17fd --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft_internal.h @@ -0,0 +1,123 @@ +#ifndef SPQLIOS_CPLX_FFT_INTERNAL_H +#define SPQLIOS_CPLX_FFT_INTERNAL_H + +#include "cplx_fft.h" + +/** @brief a complex number contains two doubles real,imag */ +typedef double CPLX[2]; + +EXPORT void cplx_set(CPLX r, const CPLX a); +EXPORT void cplx_neg(CPLX r, const CPLX a); +EXPORT void cplx_add(CPLX r, const CPLX a, const CPLX b); +EXPORT void cplx_sub(CPLX r, const CPLX a, const CPLX b); +EXPORT void cplx_mul(CPLX r, const CPLX a, const CPLX b); + +/** + * @brief splits 2h evaluations of one polynomials into 2 times h evaluations of even/odd polynomial + * Input: Q_0(y),...,Q_{h-1}(y),Q_0(-y),...,Q_{h-1}(-y) + * Output: P_0(z),...,P_{h-1}(z),P_h(z),...,P_{2h-1}(z) + * where Q_i(X)=P_i(X^2)+X.P_{h+i}(X^2) and y^2 = z + * @param h number of "coefficients" h >= 1 + * @param data 2h complex coefficients interleaved and 256b aligned + * @param powom y represented as (yre,yim) + */ +EXPORT void cplx_split_fft_ref(int32_t h, CPLX* data, const CPLX powom); +EXPORT void cplx_bisplit_fft_ref(int32_t h, CPLX* data, const CPLX powom[2]); + +/** + * Input: Q(y),Q(-y) + * Output: P_0(z),P_1(z) + * where Q(X)=P_0(X^2)+X.P_1(X^2) and y^2 = z + * @param data 2 complexes coefficients interleaved and 256b aligned + * @param powom (z,-z) interleaved: (zre,zim,-zre,-zim) + */ +EXPORT void split_fft_last_ref(CPLX* data, const CPLX powom); + +EXPORT void cplx_ifft_naive(const uint32_t m, const double entry_pwr, CPLX* data); +EXPORT void cplx_ifft16_avx_fma(void* data, const void* omega); +EXPORT void cplx_ifft16_ref(void* data, const void* omega); + +/** + * @brief compute the ifft evaluations of P in place + * ifft(data) = ifft_rec(data, i); + * function ifft_rec(data, omega) { + * if #data = 1: return data + * let s = sqrt(omega) w. 
re(s)>0 + * let (u,v) = data + * return split_fft([ifft_rec(u, s), ifft_rec(v, -s)],s) + * } + * @param itables precomputed tables (contains all the powers of omega in the order they are used) + * @param data vector of m complexes (coeffs as input, evals as output) + */ +EXPORT void cplx_ifft_ref(const CPLX_IFFT_PRECOMP* itables, void* data); +EXPORT void cplx_ifft_avx2_fma(const CPLX_IFFT_PRECOMP* itables, void* data); +EXPORT void cplx_fft_naive(const uint32_t m, const double entry_pwr, CPLX* data); +EXPORT void cplx_fft16_avx_fma(void* data, const void* omega); +EXPORT void cplx_fft16_ref(void* data, const void* omega); + +/** + * @brief compute the fft evaluations of P in place + * fft(data) = fft_rec(data, i); + * function fft_rec(data, omega) { + * if #data = 1: return data + * let s = sqrt(omega) w. re(s)>0 + * let (u,v) = merge_fft(data, s) + * return [fft_rec(u, s), fft_rec(v, -s)] + * } + * @param tables precomputed tables (contains all the powers of omega in the order they are used) + * @param data vector of m complexes (coeffs as input, evals as output) + */ +EXPORT void cplx_fft_ref(const CPLX_FFT_PRECOMP* tables, void* data); +EXPORT void cplx_fft_avx2_fma(const CPLX_FFT_PRECOMP* tables, void* data); + +/** + * @brief merges 2 times h evaluations of even/odd polynomials into 2h evaluations of a sigle polynomial + * Input: P_0(z),...,P_{h-1}(z),P_h(z),...,P_{2h-1}(z) + * Output: Q_0(y),...,Q_{h-1}(y),Q_0(-y),...,Q_{h-1}(-y) + * where Q_i(X)=P_i(X^2)+X.P_{h+i}(X^2) and y^2 = z + * @param h number of "coefficients" h >= 1 + * @param data 2h complex coefficients interleaved and 256b aligned + * @param powom y represented as (yre,yim) + */ +EXPORT void cplx_twiddle_fft_ref(int32_t h, CPLX* data, const CPLX powom); + +EXPORT void citwiddle(CPLX a, CPLX b, const CPLX om); +EXPORT void ctwiddle(CPLX a, CPLX b, const CPLX om); +EXPORT void invctwiddle(CPLX a, CPLX b, const CPLX ombar); +EXPORT void invcitwiddle(CPLX a, CPLX b, const CPLX ombar); + +// CONVERSIONS + +/** @brief r = x from ZnX (coeffs as signed int32_t's ) to double */ +EXPORT void cplx_from_znx32_ref(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x); +EXPORT void cplx_from_znx32_avx2_fma(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x); +/** @brief r = x to ZnX (coeffs as signed int32_t's ) to double */ +EXPORT void cplx_to_znx32_ref(const CPLX_TO_ZNX32_PRECOMP* precomp, int32_t* r, const void* x); +EXPORT void cplx_to_znx32_avx2_fma(const CPLX_TO_ZNX32_PRECOMP* precomp, int32_t* r, const void* x); +/** @brief r = x mod 1 from TnX (coeffs as signed int32_t's) to double */ +EXPORT void cplx_from_tnx32_ref(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x); +EXPORT void cplx_from_tnx32_avx2_fma(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x); +/** @brief r = x mod 1 from TnX (coeffs as signed int32_t's) */ +EXPORT void cplx_to_tnx32_ref(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* x, const void* c); +EXPORT void cplx_to_tnx32_avx2_fma(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* x, const void* c); +/** @brief r = x from RnX (coeffs as doubles ) to double */ +EXPORT void cplx_from_rnx64_ref(const CPLX_FROM_RNX64_PRECOMP* precomp, void* r, const double* x); +EXPORT void cplx_from_rnx64_avx2_fma(const CPLX_FROM_RNX64_PRECOMP* precomp, void* r, const double* x); +/** @brief r = x to RnX (coeffs as doubles ) to double */ +EXPORT void cplx_to_rnx64_ref(const CPLX_TO_RNX64_PRECOMP* precomp, double* r, const void* x); +EXPORT void 
cplx_to_rnx64_avx2_fma(const CPLX_TO_RNX64_PRECOMP* precomp, double* r, const void* x); +/** @brief r = x to integers in RnX (coeffs as doubles ) to double */ +EXPORT void cplx_round_to_rnx64_ref(const CPLX_ROUND_TO_RNX64_PRECOMP* precomp, double* r, const void* x); +EXPORT void cplx_round_to_rnx64_avx2_fma(const CPLX_ROUND_TO_RNX64_PRECOMP* precomp, double* r, const void* x); + +// fftvec operations +/** @brief element-wise addmul r += ab */ +EXPORT void cplx_fftvec_addmul_ref(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b); +EXPORT void cplx_fftvec_addmul_fma(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b); +EXPORT void cplx_fftvec_addmul_sse(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b); +EXPORT void cplx_fftvec_addmul_avx512(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b); +/** @brief element-wise mul r = ab */ +EXPORT void cplx_fftvec_mul_ref(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b); +EXPORT void cplx_fftvec_mul_fma(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b); + +#endif // SPQLIOS_CPLX_FFT_INTERNAL_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft_private.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft_private.h new file mode 100644 index 0000000..8791dbe --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft_private.h @@ -0,0 +1,109 @@ +#ifndef SPQLIOS_CPLX_FFT_PRIVATE_H +#define SPQLIOS_CPLX_FFT_PRIVATE_H + +#include "cplx_fft.h" + +typedef struct cplx_twiddle_precomp CPLX_FFTVEC_TWIDDLE_PRECOMP; +typedef struct cplx_bitwiddle_precomp CPLX_FFTVEC_BITWIDDLE_PRECOMP; + +typedef void (*IFFT_FUNCTION)(const CPLX_IFFT_PRECOMP*, void*); +typedef void (*FFT_FUNCTION)(const CPLX_FFT_PRECOMP*, void*); +// conversions +typedef void (*FROM_ZNX32_FUNCTION)(const CPLX_FROM_ZNX32_PRECOMP*, void*, const int32_t*); +typedef void (*TO_ZNX32_FUNCTION)(const CPLX_FROM_ZNX32_PRECOMP*, int32_t*, const void*); +typedef void (*FROM_TNX32_FUNCTION)(const CPLX_FROM_TNX32_PRECOMP*, void*, const int32_t*); +typedef void (*TO_TNX32_FUNCTION)(const CPLX_TO_TNX32_PRECOMP*, int32_t*, const void*); +typedef void (*FROM_RNX64_FUNCTION)(const CPLX_FROM_RNX64_PRECOMP* precomp, void* r, const double* x); +typedef void (*TO_RNX64_FUNCTION)(const CPLX_TO_RNX64_PRECOMP* precomp, double* r, const void* x); +typedef void (*ROUND_TO_RNX64_FUNCTION)(const CPLX_ROUND_TO_RNX64_PRECOMP* precomp, double* r, const void* x); +// fftvec operations +typedef void (*FFTVEC_MUL_FUNCTION)(const CPLX_FFTVEC_MUL_PRECOMP*, void*, const void*, const void*); +typedef void (*FFTVEC_ADDMUL_FUNCTION)(const CPLX_FFTVEC_ADDMUL_PRECOMP*, void*, const void*, const void*); + +typedef void (*FFTVEC_TWIDDLE_FUNCTION)(const CPLX_FFTVEC_TWIDDLE_PRECOMP*, void*, const void*, const void*); +typedef void (*FFTVEC_BITWIDDLE_FUNCTION)(const CPLX_FFTVEC_BITWIDDLE_PRECOMP*, void*, uint64_t, const void*); + +struct cplx_ifft_precomp { + IFFT_FUNCTION function; + int64_t m; + uint64_t buf_size; + double* powomegas; + void* aligned_buffers; +}; + +struct cplx_fft_precomp { + FFT_FUNCTION function; + int64_t m; + uint64_t buf_size; + double* powomegas; + void* aligned_buffers; +}; + +struct cplx_from_znx32_precomp { + FROM_ZNX32_FUNCTION function; + int64_t m; +}; + +struct cplx_to_znx32_precomp { + TO_ZNX32_FUNCTION function; + int64_t m; + double divisor; 
+}; + +struct cplx_from_tnx32_precomp { + FROM_TNX32_FUNCTION function; + int64_t m; +}; + +struct cplx_to_tnx32_precomp { + TO_TNX32_FUNCTION function; + int64_t m; + double divisor; +}; + +struct cplx_from_rnx64_precomp { + FROM_RNX64_FUNCTION function; + int64_t m; +}; + +struct cplx_to_rnx64_precomp { + TO_RNX64_FUNCTION function; + int64_t m; + double divisor; +}; + +struct cplx_round_to_rnx64_precomp { + ROUND_TO_RNX64_FUNCTION function; + int64_t m; + double divisor; + uint32_t log2bound; +}; + +typedef struct cplx_mul_precomp { + FFTVEC_MUL_FUNCTION function; + int64_t m; +} CPLX_FFTVEC_MUL_PRECOMP; + +typedef struct cplx_addmul_precomp { + FFTVEC_ADDMUL_FUNCTION function; + int64_t m; +} CPLX_FFTVEC_ADDMUL_PRECOMP; + +struct cplx_twiddle_precomp { + FFTVEC_TWIDDLE_FUNCTION function; + int64_t m; +}; + +struct cplx_bitwiddle_precomp { + FFTVEC_BITWIDDLE_FUNCTION function; + int64_t m; +}; + +EXPORT void cplx_fftvec_twiddle_fma(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om); +EXPORT void cplx_fftvec_twiddle_avx512(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om); +EXPORT void cplx_fftvec_bitwiddle_fma(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice, + const void* om); +EXPORT void cplx_fftvec_bitwiddle_avx512(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice, + const void* om); + +#endif // SPQLIOS_CPLX_FFT_PRIVATE_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft_ref.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft_ref.c new file mode 100644 index 0000000..19cc777 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft_ref.c @@ -0,0 +1,367 @@ +#include +#include + +#include "../commons_private.h" +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +/** @brief (a,b) <- (a+omega.b,a-omega.b) */ +void ctwiddle(CPLX a, CPLX b, const CPLX om) { + double re = om[0] * b[0] - om[1] * b[1]; + double im = om[0] * b[1] + om[1] * b[0]; + b[0] = a[0] - re; + b[1] = a[1] - im; + a[0] += re; + a[1] += im; +} + +/** @brief (a,b) <- (a+i.omega.b,a-i.omega.b) */ +void citwiddle(CPLX a, CPLX b, const CPLX om) { + double re = -om[1] * b[0] - om[0] * b[1]; + double im = -om[1] * b[1] + om[0] * b[0]; + b[0] = a[0] - re; + b[1] = a[1] - im; + a[0] += re; + a[1] += im; +} + +/** + * @brief FFT modulo X^16-omega^2 (in registers) + * @param data contains 16 complexes + * @param omega 8 complexes in this order: + * omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma + * alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta) + * j = sqrt(i), k=sqrt(j) + */ +void cplx_fft16_ref(void* data, const void* omega) { + CPLX* d = data; + const CPLX* om = omega; + // first pass + for (uint64_t i = 0; i < 8; ++i) { + ctwiddle(d[0 + i], d[8 + i], om[0]); + } + // + ctwiddle(d[0], d[4], om[1]); + ctwiddle(d[1], d[5], om[1]); + ctwiddle(d[2], d[6], om[1]); + ctwiddle(d[3], d[7], om[1]); + citwiddle(d[8], d[12], om[1]); + citwiddle(d[9], d[13], om[1]); + citwiddle(d[10], d[14], om[1]); + citwiddle(d[11], d[15], om[1]); + // + ctwiddle(d[0], d[2], om[2]); + ctwiddle(d[1], d[3], om[2]); + citwiddle(d[4], d[6], om[2]); + citwiddle(d[5], d[7], om[2]); + ctwiddle(d[8], d[10], om[3]); + ctwiddle(d[9], d[11], om[3]); + citwiddle(d[12], d[14], om[3]); + citwiddle(d[13], d[15], om[3]); + // + ctwiddle(d[0], d[1], om[4]); + citwiddle(d[2], d[3], om[4]); + ctwiddle(d[4], d[5], om[5]); + citwiddle(d[6], d[7], om[5]); + 
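+  // (still the last stride-1 pass: om[4..7] are gamma, j.gamma, k.gamma, kj.gamma
+  //  from the omega layout documented above, each driving one ctwiddle/citwiddle
+  //  pair on a block of four coefficients)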
ctwiddle(d[8], d[9], om[6]); + citwiddle(d[10], d[11], om[6]); + ctwiddle(d[12], d[13], om[7]); + citwiddle(d[14], d[15], om[7]); +} + +double cos_2pix(double x) { return m_accurate_cos(2 * M_PI * x); } +double sin_2pix(double x) { return m_accurate_sin(2 * M_PI * x); } +void cplx_set_e2pix(CPLX res, double x) { + res[0] = cos_2pix(x); + res[1] = sin_2pix(x); +} + +void cplx_fft16_precomp(const double entry_pwr, CPLX** omg) { + static const double j_pow = 1. / 8.; + static const double k_pow = 1. / 16.; + const double pom = entry_pwr / 2.; + const double pom_2 = entry_pwr / 4.; + const double pom_4 = entry_pwr / 8.; + const double pom_8 = entry_pwr / 16.; + cplx_set_e2pix((*omg)[0], pom); + cplx_set_e2pix((*omg)[1], pom_2); + cplx_set_e2pix((*omg)[2], pom_4); + cplx_set_e2pix((*omg)[3], pom_4 + j_pow); + cplx_set_e2pix((*omg)[4], pom_8); + cplx_set_e2pix((*omg)[5], pom_8 + j_pow); + cplx_set_e2pix((*omg)[6], pom_8 + k_pow); + cplx_set_e2pix((*omg)[7], pom_8 + j_pow + k_pow); + *omg += 8; +} + +/** + * @brief h twiddles-fft on the same omega + * (also called merge-fft)merges 2 times h evaluations of even/odd polynomials into 2h evaluations of a sigle polynomial + * Input: P_0(z),...,P_{h-1}(z),P_h(z),...,P_{2h-1}(z) + * Output: Q_0(y),...,Q_{h-1}(y),Q_0(-y),...,Q_{h-1}(-y) + * where Q_i(X)=P_i(X^2)+X.P_{h+i}(X^2) and y^2 = z + * @param h number of "coefficients" h >= 1 + * @param data 2h complex coefficients interleaved and 256b aligned + * @param powom y represented as (yre,yim) + */ +void cplx_twiddle_fft_ref(int32_t h, CPLX* data, const CPLX powom) { + CPLX* d0 = data; + CPLX* d1 = data + h; + for (uint64_t i = 0; i < h; ++i) { + ctwiddle(d0[i], d1[i], powom); + } +} + +void cplx_bitwiddle_fft_ref(int32_t h, CPLX* data, const CPLX powom[2]) { + CPLX* d0 = data; + CPLX* d1 = data + h; + CPLX* d2 = data + 2 * h; + CPLX* d3 = data + 3 * h; + for (uint64_t i = 0; i < h; ++i) { + ctwiddle(d0[i], d2[i], powom[0]); + ctwiddle(d1[i], d3[i], powom[0]); + } + for (uint64_t i = 0; i < h; ++i) { + ctwiddle(d0[i], d1[i], powom[1]); + citwiddle(d2[i], d3[i], powom[1]); + } +} + +/** + * Input: P_0(z),P_1(z) + * Output: Q(y),Q(-y) + * where Q(X)=P_0(X^2)+X.P_1(X^2) and y^2 = z + * @param data 2 complexes coefficients interleaved and 256b aligned + * @param powom (z,-z) interleaved: (zre,zim,-zre,-zim) + */ +void merge_fft_last_ref(CPLX* data, const CPLX powom) { + CPLX prod; + cplx_mul(prod, data[1], powom); + cplx_sub(data[1], data[0], prod); + cplx_add(data[0], data[0], prod); +} + +void cplx_fft_ref_bfs_2(CPLX* dat, const CPLX** omg, uint32_t m) { + CPLX* data = (CPLX*)dat; + CPLX* const dend = data + m; + for (int32_t h = m / 2; h >= 2; h >>= 1) { + for (CPLX* d = data; d < dend; d += 2 * h) { + if (memcmp((*omg)[0], (*omg)[1], 8) != 0) abort(); + cplx_twiddle_fft_ref(h, d, **omg); + *omg += 2; + } +#if 0 + printf("after merge %d: ", h); + for (uint64_t ii=0; ii>= 1; + } + while (mm > 16) { + uint32_t h = mm / 4; + for (CPLX* d = data; d < dend; d += mm) { + cplx_bitwiddle_fft_ref(h, d, *omg); + *omg += 2; + } + mm = h; + } + for (CPLX* d = data; d < dend; d += 16) { + cplx_fft16_ref(d, *omg); + *omg += 8; + } +#if 0 + printf("after last: "); + for (uint64_t ii=0; ii 16) { + double pom = ss / 4.; + uint32_t h = mm / 4; + for (uint32_t i = 0; i < m / mm; i++) { + double om = pom + fracrevbits(i) / 4.; + cplx_set_e2pix(omg[0][0], 2. 
* om); + cplx_set_e2pix(omg[0][1], om); + *omg += 2; + } + mm = h; + ss = pom; + } + { + // mm=16 + for (uint32_t i = 0; i < m / 16; i++) { + cplx_fft16_precomp(ss + fracrevbits(i), omg); + } + } +} + +/** @brief fills omega for cplx_fft_bfs_2 modulo X^m-exp(i.2.pi.entry_pwr) */ +void fill_cplx_fft_omegas_bfs_2(const double entry_pwr, CPLX** omg, uint32_t m) { + double pom = entry_pwr / 2.; + for (int32_t h = m / 2; h >= 2; h >>= 1) { + for (uint32_t i = 0; i < m / (2 * h); i++) { + cplx_set_e2pix(omg[0][0], pom + fracrevbits(i) / 2.); + cplx_set(omg[0][1], omg[0][0]); + *omg += 2; + } + pom /= 2; + } + { + // h=1 + for (uint32_t i = 0; i < m / 2; i++) { + cplx_set_e2pix((*omg)[0], pom + fracrevbits(i) / 2.); + cplx_neg((*omg)[1], (*omg)[0]); + *omg += 2; + } + } +} + +/** @brief fills omega for cplx_fft_rec modulo X^m-exp(i.2.pi.entry_pwr) */ +void fill_cplx_fft_omegas_rec_16(const double entry_pwr, CPLX** omg, uint32_t m) { + // note that the cases below are for recursive calls only! + // externally, this function shall only be called with m>=4096 + if (m == 1) return; + if (m <= 8) return fill_cplx_fft_omegas_bfs_2(entry_pwr, omg, m); + if (m <= 2048) return fill_cplx_fft_omegas_bfs_16(entry_pwr, omg, m); + double pom = entry_pwr / 2.; + cplx_set_e2pix((*omg)[0], pom); + cplx_set_e2pix((*omg)[1], pom); + *omg += 2; + fill_cplx_fft_omegas_rec_16(pom, omg, m / 2); + fill_cplx_fft_omegas_rec_16(pom + 0.5, omg, m / 2); +} + +void cplx_fft_ref_rec_16(CPLX* dat, const CPLX** omg, uint32_t m) { + if (m == 1) return; + if (m <= 8) return cplx_fft_ref_bfs_2(dat, omg, m); + if (m <= 2048) return cplx_fft_ref_bfs_16(dat, omg, m); + const uint32_t h = m / 2; + if (memcmp((*omg)[0], (*omg)[1], 8) != 0) abort(); + cplx_twiddle_fft_ref(h, dat, **omg); + *omg += 2; + cplx_fft_ref_rec_16(dat, omg, h); + cplx_fft_ref_rec_16(dat + h, omg, h); +} + +void cplx_fft_ref(const CPLX_FFT_PRECOMP* precomp, void* d) { + CPLX* data = (CPLX*)d; + const int32_t m = precomp->m; + const CPLX* omg = (CPLX*)precomp->powomegas; + if (m == 1) return; + if (m <= 8) return cplx_fft_ref_bfs_2(data, &omg, m); + if (m <= 2048) return cplx_fft_ref_bfs_16(data, &omg, m); + cplx_fft_ref_rec_16(data, &omg, m); +} + +EXPORT CPLX_FFT_PRECOMP* new_cplx_fft_precomp(uint32_t m, uint32_t num_buffers) { + const uint64_t OMG_SPACE = ceilto64b((2 * m) * sizeof(CPLX)); + const uint64_t BUF_SIZE = ceilto64b(m * sizeof(CPLX)); + void* reps = malloc(sizeof(CPLX_FFT_PRECOMP) + 63 // padding + + OMG_SPACE // tables //TODO 16? 
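+                     // memory layout: the precomp header, up to 63 bytes of padding to
+                     // reach a 64-byte boundary, the omega table (2*m complexes), then
+                     // num_buffers scratch buffers of m complexes each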
+ + num_buffers * BUF_SIZE // buffers + ); + uint64_t aligned_addr = ceilto64b((uint64_t)(reps) + sizeof(CPLX_FFT_PRECOMP)); + CPLX_FFT_PRECOMP* r = (CPLX_FFT_PRECOMP*)reps; + r->m = m; + r->buf_size = BUF_SIZE; + r->powomegas = (double*)aligned_addr; + r->aligned_buffers = (void*)(aligned_addr + OMG_SPACE); + // fill in powomegas + CPLX* omg = (CPLX*)r->powomegas; + if (m <= 8) { + fill_cplx_fft_omegas_bfs_2(0.25, &omg, m); + } else if (m <= 2048) { + fill_cplx_fft_omegas_bfs_16(0.25, &omg, m); + } else { + fill_cplx_fft_omegas_rec_16(0.25, &omg, m); + } + if (((uint64_t)omg) - aligned_addr > OMG_SPACE) abort(); + // dispatch the right implementation + { + if (m <= 4) { + // currently, we do not have any acceletated + // implementation for m<=4 + r->function = cplx_fft_ref; + } else if (CPU_SUPPORTS("fma")) { + r->function = cplx_fft_avx2_fma; + } else { + r->function = cplx_fft_ref; + } + } + return reps; +} + +EXPORT void* cplx_fft_precomp_get_buffer(const CPLX_FFT_PRECOMP* tables, uint32_t buffer_index) { + return (uint8_t*)tables->aligned_buffers + buffer_index * tables->buf_size; +} + +EXPORT void cplx_fft_simple(uint32_t m, void* data) { + static CPLX_FFT_PRECOMP* p[31] = {0}; + CPLX_FFT_PRECOMP** f = p + log2m(m); + if (!*f) *f = new_cplx_fft_precomp(m, 0); + (*f)->function(*f, data); +} + +EXPORT void cplx_fft(const CPLX_FFT_PRECOMP* tables, void* data) { tables->function(tables, data); } diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft_sse.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft_sse.c new file mode 100644 index 0000000..90ab8d2 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fft_sse.c @@ -0,0 +1,309 @@ +#include + +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +typedef double D2MEM[2]; + +EXPORT void cplx_fftvec_addmul_sse(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) { + const uint32_t m = precomp->m; + const D2MEM* aa = (D2MEM*)a; + const D2MEM* bb = (D2MEM*)b; + D2MEM* rr = (D2MEM*)r; + const D2MEM* const aend = aa + m; + do { + /* +BEGIN_TEMPLATE +const __m128d ari% = _mm_loadu_pd(aa[%]); +const __m128d bri% = _mm_loadu_pd(bb[%]); +const __m128d rri% = _mm_loadu_pd(rr[%]); +const __m128d bir% = _mm_shuffle_pd(bri%,bri%, 5); +const __m128d aii% = _mm_shuffle_pd(ari%,ari%, 15); +const __m128d pro% = _mm_fmaddsub_pd(aii%,bir%,rri%); +const __m128d arr% = _mm_shuffle_pd(ari%,ari%, 0); +const __m128d res% = _mm_fmaddsub_pd(arr%,bri%,pro%); +_mm_storeu_pd(rr[%],res%); +rr += @; // ONCE +aa += @; // ONCE +bb += @; // ONCE +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 2 + const __m128d ari0 = _mm_loadu_pd(aa[0]); + const __m128d ari1 = _mm_loadu_pd(aa[1]); + const __m128d bri0 = _mm_loadu_pd(bb[0]); + const __m128d bri1 = _mm_loadu_pd(bb[1]); + const __m128d rri0 = _mm_loadu_pd(rr[0]); + const __m128d rri1 = _mm_loadu_pd(rr[1]); + const __m128d bir0 = _mm_shuffle_pd(bri0, bri0, 0b01); + const __m128d bir1 = _mm_shuffle_pd(bri1, bri1, 0b01); + const __m128d aii0 = _mm_shuffle_pd(ari0, ari0, 0b11); + const __m128d aii1 = _mm_shuffle_pd(ari1, ari1, 0b11); + const __m128d pro0 = _mm_fmaddsub_pd(aii0, bir0, rri0); + const __m128d pro1 = _mm_fmaddsub_pd(aii1, bir1, rri1); + const __m128d arr0 = _mm_shuffle_pd(ari0, ari0, 0b00); + const __m128d arr1 = _mm_shuffle_pd(ari1, ari1, 0b00); + const __m128d res0 = _mm_fmaddsub_pd(arr0, bri0, pro0); + const __m128d res1 = _mm_fmaddsub_pd(arr1, bri1, pro1); + _mm_storeu_pd(rr[0], res0); + 
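+    // res0/res1 hold r + a*b per complex: pro = fmaddsub(aii, bir, rri) gives
+    // (a.im*b.im - r.re, a.im*b.re + r.im), and res = fmaddsub(arr, bri, pro)
+    // then yields (a.re*b.re - a.im*b.im + r.re, a.re*b.im + a.im*b.re + r.im),
+    // i.e. the complex multiply-accumulate written with two fmaddsub steps.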
_mm_storeu_pd(rr[1], res1); + rr += 2; // ONCE + aa += 2; // ONCE + bb += 2; // ONCE + // END_INTERLEAVE + } while (aa < aend); +} + +#if 0 +EXPORT void cplx_fftvec_mul_fma(const CPLX_FFTVEC_MUL_PRECOMP* precomp, void* r, const void* a, const void* b) { + const uint32_t m = precomp->m; + const double(*aa)[4] = (double(*)[4])a; + const double(*bb)[4] = (double(*)[4])b; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d bir% = _mm256_shuffle_pd(bri%,bri%, 5); // conj of b +const __m256d aii% = _mm256_shuffle_pd(ari%,ari%, 15); // im of a +const __m256d pro% = _mm256_mul_pd(aii%,bir%); +const __m256d arr% = _mm256_shuffle_pd(ari%,ari%, 0); // rr of a +const __m256d res% = _mm256_fmaddsub_pd(arr%,bri%,pro%); +_mm256_storeu_pd(rr[%],res%); +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 +const __m256d ari0 = _mm256_loadu_pd(aa[0]); +const __m256d ari1 = _mm256_loadu_pd(aa[1]); +const __m256d ari2 = _mm256_loadu_pd(aa[2]); +const __m256d ari3 = _mm256_loadu_pd(aa[3]); +const __m256d bri0 = _mm256_loadu_pd(bb[0]); +const __m256d bri1 = _mm256_loadu_pd(bb[1]); +const __m256d bri2 = _mm256_loadu_pd(bb[2]); +const __m256d bri3 = _mm256_loadu_pd(bb[3]); +const __m256d bir0 = _mm256_shuffle_pd(bri0,bri0, 5); // conj of b +const __m256d bir1 = _mm256_shuffle_pd(bri1,bri1, 5); // conj of b +const __m256d bir2 = _mm256_shuffle_pd(bri2,bri2, 5); // conj of b +const __m256d bir3 = _mm256_shuffle_pd(bri3,bri3, 5); // conj of b +const __m256d aii0 = _mm256_shuffle_pd(ari0,ari0, 15); // im of a +const __m256d aii1 = _mm256_shuffle_pd(ari1,ari1, 15); // im of a +const __m256d aii2 = _mm256_shuffle_pd(ari2,ari2, 15); // im of a +const __m256d aii3 = _mm256_shuffle_pd(ari3,ari3, 15); // im of a +const __m256d pro0 = _mm256_mul_pd(aii0,bir0); +const __m256d pro1 = _mm256_mul_pd(aii1,bir1); +const __m256d pro2 = _mm256_mul_pd(aii2,bir2); +const __m256d pro3 = _mm256_mul_pd(aii3,bir3); +const __m256d arr0 = _mm256_shuffle_pd(ari0,ari0, 0); // rr of a +const __m256d arr1 = _mm256_shuffle_pd(ari1,ari1, 0); // rr of a +const __m256d arr2 = _mm256_shuffle_pd(ari2,ari2, 0); // rr of a +const __m256d arr3 = _mm256_shuffle_pd(ari3,ari3, 0); // rr of a +const __m256d res0 = _mm256_fmaddsub_pd(arr0,bri0,pro0); +const __m256d res1 = _mm256_fmaddsub_pd(arr1,bri1,pro1); +const __m256d res2 = _mm256_fmaddsub_pd(arr2,bri2,pro2); +const __m256d res3 = _mm256_fmaddsub_pd(arr3,bri3,pro3); +_mm256_storeu_pd(rr[0],res0); +_mm256_storeu_pd(rr[1],res1); +_mm256_storeu_pd(rr[2],res2); +_mm256_storeu_pd(rr[3],res3); + // END_INTERLEAVE + rr += 4; + aa += 4; + bb += 4; + } while (aa < aend); +} + +EXPORT void cplx_fftvec_add_fma(uint32_t m, void* r, const void* a, const void* b) { + const double(*aa)[4] = (double(*)[4])a; + const double(*bb)[4] = (double(*)[4])b; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d res% = _mm256_add_pd(ari%,bri%); +_mm256_storeu_pd(rr[%],res%); +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 +const __m256d ari0 = _mm256_loadu_pd(aa[0]); +const __m256d ari1 = _mm256_loadu_pd(aa[1]); +const __m256d ari2 = _mm256_loadu_pd(aa[2]); +const __m256d ari3 = _mm256_loadu_pd(aa[3]); +const __m256d bri0 = _mm256_loadu_pd(bb[0]); +const __m256d bri1 = _mm256_loadu_pd(bb[1]); +const __m256d bri2 = 
_mm256_loadu_pd(bb[2]); +const __m256d bri3 = _mm256_loadu_pd(bb[3]); +const __m256d res0 = _mm256_add_pd(ari0,bri0); +const __m256d res1 = _mm256_add_pd(ari1,bri1); +const __m256d res2 = _mm256_add_pd(ari2,bri2); +const __m256d res3 = _mm256_add_pd(ari3,bri3); +_mm256_storeu_pd(rr[0],res0); +_mm256_storeu_pd(rr[1],res1); +_mm256_storeu_pd(rr[2],res2); +_mm256_storeu_pd(rr[3],res3); + // END_INTERLEAVE + rr += 4; + aa += 4; + bb += 4; + } while (aa < aend); +} + +EXPORT void cplx_fftvec_sub2_to_fma(uint32_t m, void* r, const void* a, const void* b) { + const double(*aa)[4] = (double(*)[4])a; + const double(*bb)[4] = (double(*)[4])b; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d sum% = _mm256_add_pd(ari%,bri%); +const __m256d rri% = _mm256_loadu_pd(rr[%]); +const __m256d res% = _mm256_sub_pd(rri%,sum%); +_mm256_storeu_pd(rr[%],res%); +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 +const __m256d ari0 = _mm256_loadu_pd(aa[0]); +const __m256d ari1 = _mm256_loadu_pd(aa[1]); +const __m256d ari2 = _mm256_loadu_pd(aa[2]); +const __m256d ari3 = _mm256_loadu_pd(aa[3]); +const __m256d bri0 = _mm256_loadu_pd(bb[0]); +const __m256d bri1 = _mm256_loadu_pd(bb[1]); +const __m256d bri2 = _mm256_loadu_pd(bb[2]); +const __m256d bri3 = _mm256_loadu_pd(bb[3]); +const __m256d sum0 = _mm256_add_pd(ari0,bri0); +const __m256d sum1 = _mm256_add_pd(ari1,bri1); +const __m256d sum2 = _mm256_add_pd(ari2,bri2); +const __m256d sum3 = _mm256_add_pd(ari3,bri3); +const __m256d rri0 = _mm256_loadu_pd(rr[0]); +const __m256d rri1 = _mm256_loadu_pd(rr[1]); +const __m256d rri2 = _mm256_loadu_pd(rr[2]); +const __m256d rri3 = _mm256_loadu_pd(rr[3]); +const __m256d res0 = _mm256_sub_pd(rri0,sum0); +const __m256d res1 = _mm256_sub_pd(rri1,sum1); +const __m256d res2 = _mm256_sub_pd(rri2,sum2); +const __m256d res3 = _mm256_sub_pd(rri3,sum3); +_mm256_storeu_pd(rr[0],res0); +_mm256_storeu_pd(rr[1],res1); +_mm256_storeu_pd(rr[2],res2); +_mm256_storeu_pd(rr[3],res3); + // END_INTERLEAVE + rr += 4; + aa += 4; + bb += 4; + } while (aa < aend); +} + +EXPORT void cplx_fftvec_copy_fma(uint32_t m, void* r, const void* a) { + const double(*aa)[4] = (double(*)[4])a; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +_mm256_storeu_pd(rr[%],ari%); +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 +const __m256d ari0 = _mm256_loadu_pd(aa[0]); +const __m256d ari1 = _mm256_loadu_pd(aa[1]); +const __m256d ari2 = _mm256_loadu_pd(aa[2]); +const __m256d ari3 = _mm256_loadu_pd(aa[3]); +_mm256_storeu_pd(rr[0],ari0); +_mm256_storeu_pd(rr[1],ari1); +_mm256_storeu_pd(rr[2],ari2); +_mm256_storeu_pd(rr[3],ari3); + // END_INTERLEAVE + rr += 4; + aa += 4; + } while (aa < aend); +} + +EXPORT void cplx_fftvec_twiddle_fma(uint32_t m, void* a, void* b, const void* omg) { + double(*aa)[4] = (double(*)[4])a; + double(*bb)[4] = (double(*)[4])b; + const double(*const aend)[4] = aa + (m >> 1); + const __m256d om = _mm256_loadu_pd(omg); + const __m256d omrr = _mm256_shuffle_pd(om, om, 0); + const __m256d omii = _mm256_shuffle_pd(om, om, 15); + do { + /* +BEGIN_TEMPLATE +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d bir% = _mm256_shuffle_pd(bri%,bri%,5); +__m256d p% = _mm256_mul_pd(bir%,omii); +p% = _mm256_fmaddsub_pd(bri%,omrr,p%); +const __m256d ari% = _mm256_loadu_pd(aa[%]); 
+_mm256_storeu_pd(aa[%],_mm256_add_pd(ari%,p%)); +_mm256_storeu_pd(bb[%],_mm256_sub_pd(ari%,p%)); +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 +const __m256d bri0 = _mm256_loadu_pd(bb[0]); +const __m256d bri1 = _mm256_loadu_pd(bb[1]); +const __m256d bri2 = _mm256_loadu_pd(bb[2]); +const __m256d bri3 = _mm256_loadu_pd(bb[3]); +const __m256d bir0 = _mm256_shuffle_pd(bri0,bri0,5); +const __m256d bir1 = _mm256_shuffle_pd(bri1,bri1,5); +const __m256d bir2 = _mm256_shuffle_pd(bri2,bri2,5); +const __m256d bir3 = _mm256_shuffle_pd(bri3,bri3,5); +__m256d p0 = _mm256_mul_pd(bir0,omii); +__m256d p1 = _mm256_mul_pd(bir1,omii); +__m256d p2 = _mm256_mul_pd(bir2,omii); +__m256d p3 = _mm256_mul_pd(bir3,omii); +p0 = _mm256_fmaddsub_pd(bri0,omrr,p0); +p1 = _mm256_fmaddsub_pd(bri1,omrr,p1); +p2 = _mm256_fmaddsub_pd(bri2,omrr,p2); +p3 = _mm256_fmaddsub_pd(bri3,omrr,p3); +const __m256d ari0 = _mm256_loadu_pd(aa[0]); +const __m256d ari1 = _mm256_loadu_pd(aa[1]); +const __m256d ari2 = _mm256_loadu_pd(aa[2]); +const __m256d ari3 = _mm256_loadu_pd(aa[3]); +_mm256_storeu_pd(aa[0],_mm256_add_pd(ari0,p0)); +_mm256_storeu_pd(aa[1],_mm256_add_pd(ari1,p1)); +_mm256_storeu_pd(aa[2],_mm256_add_pd(ari2,p2)); +_mm256_storeu_pd(aa[3],_mm256_add_pd(ari3,p3)); +_mm256_storeu_pd(bb[0],_mm256_sub_pd(ari0,p0)); +_mm256_storeu_pd(bb[1],_mm256_sub_pd(ari1,p1)); +_mm256_storeu_pd(bb[2],_mm256_sub_pd(ari2,p2)); +_mm256_storeu_pd(bb[3],_mm256_sub_pd(ari3,p3)); + // END_INTERLEAVE + bb += 4; + aa += 4; + } while (aa < aend); +} + +EXPORT void cplx_fftvec_innerprod_avx2_fma(const CPLX_FFTVEC_INNERPROD_PRECOMP* precomp, const int32_t ellbar, + const uint64_t lda, const uint64_t ldb, + void* r, const void* a, const void* b) { + const uint32_t m = precomp->m; + const uint32_t blk = precomp->blk; + const uint32_t nblocks = precomp->nblocks; + const CPLX* aa = (CPLX*)a; + const CPLX* bb = (CPLX*)b; + CPLX* rr = (CPLX*)r; + const uint64_t ldda = lda >> 4; // in CPLX + const uint64_t lddb = ldb >> 4; + if (m==0) { + memset(r, 0, m*sizeof(CPLX)); + return; + } + for (uint32_t k=0; kmul_func, rrr, aaa, bbb); + for (int32_t i=1; iaddmul_func, rrr, aaa + i * ldda, bbb + i * lddb); + } + } +} +#endif diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fftvec_avx2_fma.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fftvec_avx2_fma.c new file mode 100644 index 0000000..3418a0f --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fftvec_avx2_fma.c @@ -0,0 +1,387 @@ +#include +#include + +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +typedef double D4MEM[4]; + +EXPORT void cplx_fftvec_mul_fma(const CPLX_FFTVEC_MUL_PRECOMP* precomp, void* r, const void* a, const void* b) { + const uint32_t m = precomp->m; + const double(*aa)[4] = (double(*)[4])a; + const double(*bb)[4] = (double(*)[4])b; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d bir% = _mm256_shuffle_pd(bri%,bri%, 5); // conj of b +const __m256d aii% = _mm256_shuffle_pd(ari%,ari%, 15); // im of a +const __m256d pro% = _mm256_mul_pd(aii%,bir%); +const __m256d arr% = _mm256_shuffle_pd(ari%,ari%, 0); // rr of a +const __m256d res% = _mm256_fmaddsub_pd(arr%,bri%,pro%); +_mm256_storeu_pd(rr[%],res%); +rr += @; // ONCE +aa += @; // ONCE +bb += @; // ONCE +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 + // This block is automatically 
generated from the template above + // by the interleave.pl script. Please do not edit by hand + const __m256d ari0 = _mm256_loadu_pd(aa[0]); + const __m256d ari1 = _mm256_loadu_pd(aa[1]); + const __m256d ari2 = _mm256_loadu_pd(aa[2]); + const __m256d ari3 = _mm256_loadu_pd(aa[3]); + const __m256d bri0 = _mm256_loadu_pd(bb[0]); + const __m256d bri1 = _mm256_loadu_pd(bb[1]); + const __m256d bri2 = _mm256_loadu_pd(bb[2]); + const __m256d bri3 = _mm256_loadu_pd(bb[3]); + const __m256d bir0 = _mm256_shuffle_pd(bri0, bri0, 5); // conj of b + const __m256d bir1 = _mm256_shuffle_pd(bri1, bri1, 5); // conj of b + const __m256d bir2 = _mm256_shuffle_pd(bri2, bri2, 5); // conj of b + const __m256d bir3 = _mm256_shuffle_pd(bri3, bri3, 5); // conj of b + const __m256d aii0 = _mm256_shuffle_pd(ari0, ari0, 15); // im of a + const __m256d aii1 = _mm256_shuffle_pd(ari1, ari1, 15); // im of a + const __m256d aii2 = _mm256_shuffle_pd(ari2, ari2, 15); // im of a + const __m256d aii3 = _mm256_shuffle_pd(ari3, ari3, 15); // im of a + const __m256d pro0 = _mm256_mul_pd(aii0, bir0); + const __m256d pro1 = _mm256_mul_pd(aii1, bir1); + const __m256d pro2 = _mm256_mul_pd(aii2, bir2); + const __m256d pro3 = _mm256_mul_pd(aii3, bir3); + const __m256d arr0 = _mm256_shuffle_pd(ari0, ari0, 0); // rr of a + const __m256d arr1 = _mm256_shuffle_pd(ari1, ari1, 0); // rr of a + const __m256d arr2 = _mm256_shuffle_pd(ari2, ari2, 0); // rr of a + const __m256d arr3 = _mm256_shuffle_pd(ari3, ari3, 0); // rr of a + const __m256d res0 = _mm256_fmaddsub_pd(arr0, bri0, pro0); + const __m256d res1 = _mm256_fmaddsub_pd(arr1, bri1, pro1); + const __m256d res2 = _mm256_fmaddsub_pd(arr2, bri2, pro2); + const __m256d res3 = _mm256_fmaddsub_pd(arr3, bri3, pro3); + _mm256_storeu_pd(rr[0], res0); + _mm256_storeu_pd(rr[1], res1); + _mm256_storeu_pd(rr[2], res2); + _mm256_storeu_pd(rr[3], res3); + rr += 4; // ONCE + aa += 4; // ONCE + bb += 4; // ONCE + // END_INTERLEAVE + } while (aa < aend); +} + +EXPORT void cplx_fftvec_addmul_fma(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) { + const uint32_t m = precomp->m; + const double(*aa)[4] = (double(*)[4])a; + const double(*bb)[4] = (double(*)[4])b; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d rri% = _mm256_loadu_pd(rr[%]); +const __m256d bir% = _mm256_shuffle_pd(bri%,bri%, 5); +const __m256d aii% = _mm256_shuffle_pd(ari%,ari%, 15); +const __m256d pro% = _mm256_fmaddsub_pd(aii%,bir%,rri%); +const __m256d arr% = _mm256_shuffle_pd(ari%,ari%, 0); +const __m256d res% = _mm256_fmaddsub_pd(arr%,bri%,pro%); +_mm256_storeu_pd(rr[%],res%); +rr += @; // ONCE +aa += @; // ONCE +bb += @; // ONCE +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 2 + // This block is automatically generated from the template above + // by the interleave.pl script. 
Please do not edit by hand + const __m256d ari0 = _mm256_loadu_pd(aa[0]); + const __m256d ari1 = _mm256_loadu_pd(aa[1]); + const __m256d bri0 = _mm256_loadu_pd(bb[0]); + const __m256d bri1 = _mm256_loadu_pd(bb[1]); + const __m256d rri0 = _mm256_loadu_pd(rr[0]); + const __m256d rri1 = _mm256_loadu_pd(rr[1]); + const __m256d bir0 = _mm256_shuffle_pd(bri0, bri0, 5); + const __m256d bir1 = _mm256_shuffle_pd(bri1, bri1, 5); + const __m256d aii0 = _mm256_shuffle_pd(ari0, ari0, 15); + const __m256d aii1 = _mm256_shuffle_pd(ari1, ari1, 15); + const __m256d pro0 = _mm256_fmaddsub_pd(aii0, bir0, rri0); + const __m256d pro1 = _mm256_fmaddsub_pd(aii1, bir1, rri1); + const __m256d arr0 = _mm256_shuffle_pd(ari0, ari0, 0); + const __m256d arr1 = _mm256_shuffle_pd(ari1, ari1, 0); + const __m256d res0 = _mm256_fmaddsub_pd(arr0, bri0, pro0); + const __m256d res1 = _mm256_fmaddsub_pd(arr1, bri1, pro1); + _mm256_storeu_pd(rr[0], res0); + _mm256_storeu_pd(rr[1], res1); + rr += 2; // ONCE + aa += 2; // ONCE + bb += 2; // ONCE + // END_INTERLEAVE + } while (aa < aend); +} + +EXPORT void cplx_fftvec_bitwiddle_fma(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* precomp, void* a, uint64_t slicea, + const void* omg) { + const uint32_t m = precomp->m; + const uint64_t OFFSET = slicea / sizeof(D4MEM); + D4MEM* aa = (D4MEM*)a; + const double(*const aend)[4] = aa + (m >> 1); + const __m256d om = _mm256_loadu_pd(omg); + const __m256d om1rr = _mm256_shuffle_pd(om, om, 0); + const __m256d om1ii = _mm256_shuffle_pd(om, om, 15); + const __m256d om2rr = _mm256_shuffle_pd(om, om, 0); + const __m256d om2ii = _mm256_shuffle_pd(om, om, 0); + const __m256d om3rr = _mm256_shuffle_pd(om, om, 15); + const __m256d om3ii = _mm256_shuffle_pd(om, om, 15); + do { + /* +BEGIN_TEMPLATE +__m256d ari% = _mm256_loadu_pd(aa[%]); +__m256d bri% = _mm256_loadu_pd((aa+OFFSET)[%]); +__m256d cri% = _mm256_loadu_pd((aa+2*OFFSET)[%]); +__m256d dri% = _mm256_loadu_pd((aa+3*OFFSET)[%]); +__m256d pa% = _mm256_shuffle_pd(cri%,cri%,5); +__m256d pb% = _mm256_shuffle_pd(dri%,dri%,5); +pa% = _mm256_mul_pd(pa%,om1ii); +pb% = _mm256_mul_pd(pb%,om1ii); +pa% = _mm256_fmaddsub_pd(cri%,om1rr,pa%); +pb% = _mm256_fmaddsub_pd(dri%,om1rr,pb%); +cri% = _mm256_sub_pd(ari%,pa%); +dri% = _mm256_sub_pd(bri%,pb%); +ari% = _mm256_add_pd(ari%,pa%); +bri% = _mm256_add_pd(bri%,pb%); +pa% = _mm256_shuffle_pd(bri%,bri%,5); +pb% = _mm256_shuffle_pd(dri%,dri%,5); +pa% = _mm256_mul_pd(pa%,om2ii); +pb% = _mm256_mul_pd(pb%,om3ii); +pa% = _mm256_fmaddsub_pd(bri%,om2rr,pa%); +pb% = _mm256_fmaddsub_pd(dri%,om3rr,pb%); +bri% = _mm256_sub_pd(ari%,pa%); +dri% = _mm256_sub_pd(cri%,pb%); +ari% = _mm256_add_pd(ari%,pa%); +cri% = _mm256_add_pd(cri%,pb%); +_mm256_storeu_pd(aa[%], ari%); +_mm256_storeu_pd((aa+OFFSET)[%],bri%); +_mm256_storeu_pd((aa+2*OFFSET)[%],cri%); +_mm256_storeu_pd((aa+3*OFFSET)[%],dri%); +aa += @; // ONCE +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 1 + // This block is automatically generated from the template above + // by the interleave.pl script. 
Please do not edit by hand + __m256d ari0 = _mm256_loadu_pd(aa[0]); + __m256d bri0 = _mm256_loadu_pd((aa + OFFSET)[0]); + __m256d cri0 = _mm256_loadu_pd((aa + 2 * OFFSET)[0]); + __m256d dri0 = _mm256_loadu_pd((aa + 3 * OFFSET)[0]); + __m256d pa0 = _mm256_shuffle_pd(cri0, cri0, 5); + __m256d pb0 = _mm256_shuffle_pd(dri0, dri0, 5); + pa0 = _mm256_mul_pd(pa0, om1ii); + pb0 = _mm256_mul_pd(pb0, om1ii); + pa0 = _mm256_fmaddsub_pd(cri0, om1rr, pa0); + pb0 = _mm256_fmaddsub_pd(dri0, om1rr, pb0); + cri0 = _mm256_sub_pd(ari0, pa0); + dri0 = _mm256_sub_pd(bri0, pb0); + ari0 = _mm256_add_pd(ari0, pa0); + bri0 = _mm256_add_pd(bri0, pb0); + pa0 = _mm256_shuffle_pd(bri0, bri0, 5); + pb0 = _mm256_shuffle_pd(dri0, dri0, 5); + pa0 = _mm256_mul_pd(pa0, om2ii); + pb0 = _mm256_mul_pd(pb0, om3ii); + pa0 = _mm256_fmaddsub_pd(bri0, om2rr, pa0); + pb0 = _mm256_fmaddsub_pd(dri0, om3rr, pb0); + bri0 = _mm256_sub_pd(ari0, pa0); + dri0 = _mm256_sub_pd(cri0, pb0); + ari0 = _mm256_add_pd(ari0, pa0); + cri0 = _mm256_add_pd(cri0, pb0); + _mm256_storeu_pd(aa[0], ari0); + _mm256_storeu_pd((aa + OFFSET)[0], bri0); + _mm256_storeu_pd((aa + 2 * OFFSET)[0], cri0); + _mm256_storeu_pd((aa + 3 * OFFSET)[0], dri0); + aa += 1; // ONCE + // END_INTERLEAVE + } while (aa < aend); +} + +EXPORT void cplx_fftvec_sub2_to_fma(uint32_t m, void* r, const void* a, const void* b) { + const double(*aa)[4] = (double(*)[4])a; + const double(*bb)[4] = (double(*)[4])b; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d sum% = _mm256_add_pd(ari%,bri%); +const __m256d rri% = _mm256_loadu_pd(rr[%]); +const __m256d res% = _mm256_sub_pd(rri%,sum%); +_mm256_storeu_pd(rr[%],res%); +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 + // This block is automatically generated from the template above + // by the interleave.pl script. 
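For reference, every BEGIN_TEMPLATE/END_TEMPLATE comment in this file is the source for the BEGIN_INTERLEAVE block that follows it. Judging from the expansions in this diff, the number after BEGIN_INTERLEAVE is the unroll factor n, the '%' placeholder is replaced by each index 0..n-1, the '@' placeholder is replaced by n itself, and lines tagged // ONCE are emitted a single time rather than once per index. A minimal illustration with a made-up variable name:

/* template:
 *   const __m256d t% = _mm256_loadu_pd(aa[%]);
 *   aa += @;  // ONCE
 * expansion for BEGIN_INTERLEAVE 2:
 *   const __m256d t0 = _mm256_loadu_pd(aa[0]);
 *   const __m256d t1 = _mm256_loadu_pd(aa[1]);
 *   aa += 2;  // ONCE
 */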
Please do not edit by hand + const __m256d ari0 = _mm256_loadu_pd(aa[0]); + const __m256d ari1 = _mm256_loadu_pd(aa[1]); + const __m256d ari2 = _mm256_loadu_pd(aa[2]); + const __m256d ari3 = _mm256_loadu_pd(aa[3]); + const __m256d bri0 = _mm256_loadu_pd(bb[0]); + const __m256d bri1 = _mm256_loadu_pd(bb[1]); + const __m256d bri2 = _mm256_loadu_pd(bb[2]); + const __m256d bri3 = _mm256_loadu_pd(bb[3]); + const __m256d sum0 = _mm256_add_pd(ari0, bri0); + const __m256d sum1 = _mm256_add_pd(ari1, bri1); + const __m256d sum2 = _mm256_add_pd(ari2, bri2); + const __m256d sum3 = _mm256_add_pd(ari3, bri3); + const __m256d rri0 = _mm256_loadu_pd(rr[0]); + const __m256d rri1 = _mm256_loadu_pd(rr[1]); + const __m256d rri2 = _mm256_loadu_pd(rr[2]); + const __m256d rri3 = _mm256_loadu_pd(rr[3]); + const __m256d res0 = _mm256_sub_pd(rri0, sum0); + const __m256d res1 = _mm256_sub_pd(rri1, sum1); + const __m256d res2 = _mm256_sub_pd(rri2, sum2); + const __m256d res3 = _mm256_sub_pd(rri3, sum3); + _mm256_storeu_pd(rr[0], res0); + _mm256_storeu_pd(rr[1], res1); + _mm256_storeu_pd(rr[2], res2); + _mm256_storeu_pd(rr[3], res3); + // END_INTERLEAVE + rr += 4; + aa += 4; + bb += 4; + } while (aa < aend); +} + +EXPORT void cplx_fftvec_add_fma(uint32_t m, void* r, const void* a, const void* b) { + const double(*aa)[4] = (double(*)[4])a; + const double(*bb)[4] = (double(*)[4])b; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* + BEGIN_TEMPLATE + const __m256d ari% = _mm256_loadu_pd(aa[%]); + const __m256d bri% = _mm256_loadu_pd(bb[%]); + const __m256d res% = _mm256_add_pd(ari%,bri%); + _mm256_storeu_pd(rr[%],res%); + rr += @; // ONCE + aa += @; // ONCE + bb += @; // ONCE + END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 + // This block is automatically generated from the template above + // by the interleave.pl script. 
Please do not edit by hand + const __m256d ari0 = _mm256_loadu_pd(aa[0]); + const __m256d ari1 = _mm256_loadu_pd(aa[1]); + const __m256d ari2 = _mm256_loadu_pd(aa[2]); + const __m256d ari3 = _mm256_loadu_pd(aa[3]); + const __m256d bri0 = _mm256_loadu_pd(bb[0]); + const __m256d bri1 = _mm256_loadu_pd(bb[1]); + const __m256d bri2 = _mm256_loadu_pd(bb[2]); + const __m256d bri3 = _mm256_loadu_pd(bb[3]); + const __m256d res0 = _mm256_add_pd(ari0, bri0); + const __m256d res1 = _mm256_add_pd(ari1, bri1); + const __m256d res2 = _mm256_add_pd(ari2, bri2); + const __m256d res3 = _mm256_add_pd(ari3, bri3); + _mm256_storeu_pd(rr[0], res0); + _mm256_storeu_pd(rr[1], res1); + _mm256_storeu_pd(rr[2], res2); + _mm256_storeu_pd(rr[3], res3); + rr += 4; // ONCE + aa += 4; // ONCE + bb += 4; // ONCE + // END_INTERLEAVE + } while (aa < aend); +} + +EXPORT void cplx_fftvec_twiddle_fma(const CPLX_FFTVEC_TWIDDLE_PRECOMP* precomp, void* a, void* b, const void* omg) { + const uint32_t m = precomp->m; + double(*aa)[4] = (double(*)[4])a; + double(*bb)[4] = (double(*)[4])b; + const double(*const aend)[4] = aa + (m >> 1); + const __m256d om = _mm256_loadu_pd(omg); + const __m256d omrr = _mm256_shuffle_pd(om, om, 0); + const __m256d omii = _mm256_shuffle_pd(om, om, 15); + do { + /* +BEGIN_TEMPLATE +const __m256d bri% = _mm256_loadu_pd(bb[%]); +const __m256d bir% = _mm256_shuffle_pd(bri%,bri%,5); +__m256d p% = _mm256_mul_pd(bir%,omii); +p% = _mm256_fmaddsub_pd(bri%,omrr,p%); +const __m256d ari% = _mm256_loadu_pd(aa[%]); +_mm256_storeu_pd(aa[%],_mm256_add_pd(ari%,p%)); +_mm256_storeu_pd(bb[%],_mm256_sub_pd(ari%,p%)); +bb += @; // ONCE +aa += @; // ONCE +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 + // This block is automatically generated from the template above + // by the interleave.pl script. 
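The template above for cplx_fftvec_twiddle_fma is the forward split butterfly used by the FFT passes: p = om*b (computed with the same shuffle/fmaddsub complex product as before), then a <- a + p and b <- a - p, with b overwritten from the original a. Its inverse, up to a factor of 2, is the invctwiddle step of cplx_ifft_ref.c later in this diff. A scalar sketch (illustrative helper, not part of the diff):

static void twiddle_scalar_sketch(double _Complex* a, double _Complex* b, double _Complex om) {
  const double _Complex p = om * *b;
  *b = *a - p;  /* uses the original a */
  *a = *a + p;
}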
Please do not edit by hand + const __m256d bri0 = _mm256_loadu_pd(bb[0]); + const __m256d bri1 = _mm256_loadu_pd(bb[1]); + const __m256d bri2 = _mm256_loadu_pd(bb[2]); + const __m256d bri3 = _mm256_loadu_pd(bb[3]); + const __m256d bir0 = _mm256_shuffle_pd(bri0, bri0, 5); + const __m256d bir1 = _mm256_shuffle_pd(bri1, bri1, 5); + const __m256d bir2 = _mm256_shuffle_pd(bri2, bri2, 5); + const __m256d bir3 = _mm256_shuffle_pd(bri3, bri3, 5); + __m256d p0 = _mm256_mul_pd(bir0, omii); + __m256d p1 = _mm256_mul_pd(bir1, omii); + __m256d p2 = _mm256_mul_pd(bir2, omii); + __m256d p3 = _mm256_mul_pd(bir3, omii); + p0 = _mm256_fmaddsub_pd(bri0, omrr, p0); + p1 = _mm256_fmaddsub_pd(bri1, omrr, p1); + p2 = _mm256_fmaddsub_pd(bri2, omrr, p2); + p3 = _mm256_fmaddsub_pd(bri3, omrr, p3); + const __m256d ari0 = _mm256_loadu_pd(aa[0]); + const __m256d ari1 = _mm256_loadu_pd(aa[1]); + const __m256d ari2 = _mm256_loadu_pd(aa[2]); + const __m256d ari3 = _mm256_loadu_pd(aa[3]); + _mm256_storeu_pd(aa[0], _mm256_add_pd(ari0, p0)); + _mm256_storeu_pd(aa[1], _mm256_add_pd(ari1, p1)); + _mm256_storeu_pd(aa[2], _mm256_add_pd(ari2, p2)); + _mm256_storeu_pd(aa[3], _mm256_add_pd(ari3, p3)); + _mm256_storeu_pd(bb[0], _mm256_sub_pd(ari0, p0)); + _mm256_storeu_pd(bb[1], _mm256_sub_pd(ari1, p1)); + _mm256_storeu_pd(bb[2], _mm256_sub_pd(ari2, p2)); + _mm256_storeu_pd(bb[3], _mm256_sub_pd(ari3, p3)); + bb += 4; // ONCE + aa += 4; // ONCE + // END_INTERLEAVE + } while (aa < aend); +} + +EXPORT void cplx_fftvec_copy_fma(uint32_t m, void* r, const void* a) { + const double(*aa)[4] = (double(*)[4])a; + double(*rr)[4] = (double(*)[4])r; + const double(*const aend)[4] = aa + (m >> 1); + do { + /* +BEGIN_TEMPLATE +const __m256d ari% = _mm256_loadu_pd(aa[%]); +_mm256_storeu_pd(rr[%],ari%); +rr += @; // ONCE +aa += @; // ONCE +END_TEMPLATE + */ + // BEGIN_INTERLEAVE 4 + // This block is automatically generated from the template above + // by the interleave.pl script. 
Please do not edit by hand + const __m256d ari0 = _mm256_loadu_pd(aa[0]); + const __m256d ari1 = _mm256_loadu_pd(aa[1]); + const __m256d ari2 = _mm256_loadu_pd(aa[2]); + const __m256d ari3 = _mm256_loadu_pd(aa[3]); + _mm256_storeu_pd(rr[0], ari0); + _mm256_storeu_pd(rr[1], ari1); + _mm256_storeu_pd(rr[2], ari2); + _mm256_storeu_pd(rr[3], ari3); + rr += 4; // ONCE + aa += 4; // ONCE + // END_INTERLEAVE + } while (aa < aend); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fftvec_ref.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fftvec_ref.c new file mode 100644 index 0000000..f7d4629 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_fftvec_ref.c @@ -0,0 +1,85 @@ +#include + +#include "../commons_private.h" +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +EXPORT void cplx_fftvec_addmul_ref(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) { + const uint32_t m = precomp->m; + const CPLX* aa = (CPLX*)a; + const CPLX* bb = (CPLX*)b; + CPLX* rr = (CPLX*)r; + for (uint32_t i = 0; i < m; ++i) { + const double re = aa[i][0] * bb[i][0] - aa[i][1] * bb[i][1]; + const double im = aa[i][0] * bb[i][1] + aa[i][1] * bb[i][0]; + rr[i][0] += re; + rr[i][1] += im; + } +} + +EXPORT void cplx_fftvec_mul_ref(const CPLX_FFTVEC_MUL_PRECOMP* precomp, void* r, const void* a, const void* b) { + const uint32_t m = precomp->m; + const CPLX* aa = (CPLX*)a; + const CPLX* bb = (CPLX*)b; + CPLX* rr = (CPLX*)r; + for (uint32_t i = 0; i < m; ++i) { + const double re = aa[i][0] * bb[i][0] - aa[i][1] * bb[i][1]; + const double im = aa[i][0] * bb[i][1] + aa[i][1] * bb[i][0]; + rr[i][0] = re; + rr[i][1] = im; + } +} + +EXPORT void* init_cplx_fftvec_addmul_precomp(CPLX_FFTVEC_ADDMUL_PRECOMP* r, uint32_t m) { + if (m & (m - 1)) return spqlios_error("m must be a power of two"); + r->m = m; + if (m <= 4) { + r->function = cplx_fftvec_addmul_ref; + } else if (CPU_SUPPORTS("fma")) { + r->function = cplx_fftvec_addmul_fma; + } else { + r->function = cplx_fftvec_addmul_ref; + } + return r; +} + +EXPORT void* init_cplx_fftvec_mul_precomp(CPLX_FFTVEC_MUL_PRECOMP* r, uint32_t m) { + if (m & (m - 1)) return spqlios_error("m must be a power of two"); + r->m = m; + if (m <= 4) { + r->function = cplx_fftvec_mul_ref; + } else if (CPU_SUPPORTS("fma")) { + r->function = cplx_fftvec_mul_fma; + } else { + r->function = cplx_fftvec_mul_ref; + } + return r; +} + +EXPORT CPLX_FFTVEC_ADDMUL_PRECOMP* new_cplx_fftvec_addmul_precomp(uint32_t m) { + CPLX_FFTVEC_ADDMUL_PRECOMP* r = malloc(sizeof(CPLX_FFTVEC_MUL_PRECOMP)); + return spqlios_keep_or_free(r, init_cplx_fftvec_addmul_precomp(r, m)); +} + +EXPORT CPLX_FFTVEC_MUL_PRECOMP* new_cplx_fftvec_mul_precomp(uint32_t m) { + CPLX_FFTVEC_MUL_PRECOMP* r = malloc(sizeof(CPLX_FFTVEC_MUL_PRECOMP)); + return spqlios_keep_or_free(r, init_cplx_fftvec_mul_precomp(r, m)); +} + +EXPORT void cplx_fftvec_mul_simple(uint32_t m, void* r, const void* a, const void* b) { + static CPLX_FFTVEC_MUL_PRECOMP p[31] = {0}; + CPLX_FFTVEC_MUL_PRECOMP* f = p + log2m(m); + if (!f->function) { + if (!init_cplx_fftvec_mul_precomp(f, m)) abort(); + } + f->function(f, r, a, b); +} + +EXPORT void cplx_fftvec_addmul_simple(uint32_t m, void* r, const void* a, const void* b) { + static CPLX_FFTVEC_ADDMUL_PRECOMP p[31] = {0}; + CPLX_FFTVEC_ADDMUL_PRECOMP* f = p + log2m(m); + if (!f->function) { + if (!init_cplx_fftvec_addmul_precomp(f, m)) abort(); + } + f->function(f, r, a, b); +} diff --git 
a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_ifft16_avx_fma.s b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_ifft16_avx_fma.s new file mode 100644 index 0000000..bc9ea10 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_ifft16_avx_fma.s @@ -0,0 +1,157 @@ +# shifted FFT over X^16-i +# 1st argument (rdi) contains 16 complexes +# 2nd argument (rsi) contains: 8 complexes +# omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma +# alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta) +# j = sqrt(i), k=sqrt(j) +.globl cplx_ifft16_avx_fma +cplx_ifft16_avx_fma: +vmovupd (%rdi),%ymm8 # load data into registers %ymm8 -> %ymm15 +vmovupd 0x20(%rdi),%ymm9 +vmovupd 0x40(%rdi),%ymm10 +vmovupd 0x60(%rdi),%ymm11 +vmovupd 0x80(%rdi),%ymm12 +vmovupd 0xa0(%rdi),%ymm13 +vmovupd 0xc0(%rdi),%ymm14 +vmovupd 0xe0(%rdi),%ymm15 + +.fourth_pass: +vmovupd 0(%rsi),%ymm0 /* gamma */ +vmovupd 32(%rsi),%ymm2 /* delta */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */ +vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */ +vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */ +vperm2f128 $0x31,%ymm10,%ymm8,%ymm4 # ymm4 contains c1,c5 +vperm2f128 $0x31,%ymm11,%ymm9,%ymm5 # ymm5 contains c3,c7 +vperm2f128 $0x31,%ymm14,%ymm12,%ymm6 # ymm6 contains c9,c13 +vperm2f128 $0x31,%ymm15,%ymm13,%ymm7 # ymm7 contains c11,c15 +vperm2f128 $0x20,%ymm10,%ymm8,%ymm8 # ymm8 contains c0,c4 +vperm2f128 $0x20,%ymm11,%ymm9,%ymm9 # ymm9 contains c2,c6 +vperm2f128 $0x20,%ymm14,%ymm12,%ymm10 # ymm10 contains c8,c12 +vperm2f128 $0x20,%ymm15,%ymm13,%ymm11 # ymm11 contains c10,c14 +vsubpd %ymm4,%ymm8,%ymm12 # tw: to mul by gamma +vsubpd %ymm5,%ymm9,%ymm13 # itw: to mul by i.gamma +vsubpd %ymm6,%ymm10,%ymm14 # tw: to mul by delta +vsubpd %ymm7,%ymm11,%ymm15 # itw: to mul by i.delta +vaddpd %ymm4,%ymm8,%ymm8 +vaddpd %ymm5,%ymm9,%ymm9 +vaddpd %ymm6,%ymm10,%ymm10 +vaddpd %ymm7,%ymm11,%ymm11 +vshufpd $5, %ymm12, %ymm12, %ymm4 +vshufpd $5, %ymm13, %ymm13, %ymm5 +vshufpd $5, %ymm14, %ymm14, %ymm6 +vshufpd $5, %ymm15, %ymm15, %ymm7 +vmulpd %ymm4,%ymm1,%ymm4 +vmulpd %ymm5,%ymm0,%ymm5 +vmulpd %ymm6,%ymm3,%ymm6 +vmulpd %ymm7,%ymm2,%ymm7 +vfmaddsub231pd %ymm12, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm12) +/- ymm4 +vfmsubadd231pd %ymm13, %ymm1, %ymm5 +vfmaddsub231pd %ymm14, %ymm2, %ymm6 +vfmsubadd231pd %ymm15, %ymm3, %ymm7 + +vperm2f128 $0x20,%ymm6,%ymm10,%ymm12 # ymm4 contains c1,c5 -- x gamma +vperm2f128 $0x20,%ymm7,%ymm11,%ymm13 # ymm5 contains c3,c7 -- x igamma +vperm2f128 $0x31,%ymm6,%ymm10,%ymm14 # ymm6 contains c9,c13 -- x delta +vperm2f128 $0x31,%ymm7,%ymm11,%ymm15 # ymm7 contains c11,c15 -- x idelta +vperm2f128 $0x31,%ymm4,%ymm8,%ymm10 # ymm10 contains c8,c12 +vperm2f128 $0x31,%ymm5,%ymm9,%ymm11 # ymm11 contains c10,c14 +vperm2f128 $0x20,%ymm4,%ymm8,%ymm8 # ymm8 contains c0,c4 +vperm2f128 $0x20,%ymm5,%ymm9,%ymm9 # ymm9 contains c2,c6 + + +.third_pass: +vmovupd 64(%rsi),%xmm0 /* gamma */ +vmovupd 80(%rsi),%xmm2 /* delta */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 +vinsertf128 $1, %xmm2, %ymm2, %ymm2 +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */ +vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */ +vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */ +vsubpd %ymm9,%ymm8,%ymm4 +vsubpd %ymm11,%ymm10,%ymm5 +vsubpd %ymm13,%ymm12,%ymm6 +vsubpd %ymm15,%ymm14,%ymm7 +vaddpd %ymm9,%ymm8,%ymm8 +vaddpd %ymm11,%ymm10,%ymm10 +vaddpd %ymm13,%ymm12,%ymm12 
+vaddpd %ymm15,%ymm14,%ymm14 +vshufpd $5, %ymm4, %ymm4, %ymm9 +vshufpd $5, %ymm5, %ymm5, %ymm11 +vshufpd $5, %ymm6, %ymm6, %ymm13 +vshufpd $5, %ymm7, %ymm7, %ymm15 +vmulpd %ymm9,%ymm1,%ymm9 +vmulpd %ymm11,%ymm0,%ymm11 +vmulpd %ymm13,%ymm3,%ymm13 +vmulpd %ymm15,%ymm2,%ymm15 +vfmaddsub231pd %ymm4, %ymm0, %ymm9 # ymm9 = (ymm0 * ymm4) +/- ymm9 +vfmsubadd231pd %ymm5, %ymm1, %ymm11 +vfmaddsub231pd %ymm6, %ymm2, %ymm13 +vfmsubadd231pd %ymm7, %ymm3, %ymm15 + +.second_pass: +vmovupd 96(%rsi),%xmm0 /* omri */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */ +vsubpd %ymm10,%ymm8,%ymm4 +vsubpd %ymm11,%ymm9,%ymm5 +vsubpd %ymm14,%ymm12,%ymm6 +vsubpd %ymm15,%ymm13,%ymm7 +vaddpd %ymm10,%ymm8,%ymm8 +vaddpd %ymm11,%ymm9,%ymm9 +vaddpd %ymm14,%ymm12,%ymm12 +vaddpd %ymm15,%ymm13,%ymm13 +vshufpd $5, %ymm4, %ymm4, %ymm10 +vshufpd $5, %ymm5, %ymm5, %ymm11 +vshufpd $5, %ymm6, %ymm6, %ymm14 +vshufpd $5, %ymm7, %ymm7, %ymm15 +vmulpd %ymm10,%ymm1,%ymm10 +vmulpd %ymm11,%ymm1,%ymm11 +vmulpd %ymm14,%ymm0,%ymm14 +vmulpd %ymm15,%ymm0,%ymm15 +vfmaddsub231pd %ymm4, %ymm0, %ymm10 # ymm10 = (ymm0 * ymm4) +/- ymm10 +vfmaddsub231pd %ymm5, %ymm0, %ymm11 +vfmsubadd231pd %ymm6, %ymm1, %ymm14 +vfmsubadd231pd %ymm7, %ymm1, %ymm15 + +.first_pass: +vmovupd 112(%rsi),%xmm0 /* omri */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */ +vsubpd %ymm12,%ymm8,%ymm4 +vsubpd %ymm13,%ymm9,%ymm5 +vsubpd %ymm14,%ymm10,%ymm6 +vsubpd %ymm15,%ymm11,%ymm7 +vaddpd %ymm12,%ymm8,%ymm8 +vaddpd %ymm13,%ymm9,%ymm9 +vaddpd %ymm14,%ymm10,%ymm10 +vaddpd %ymm15,%ymm11,%ymm11 +vshufpd $5, %ymm4, %ymm4, %ymm12 +vshufpd $5, %ymm5, %ymm5, %ymm13 +vshufpd $5, %ymm6, %ymm6, %ymm14 +vshufpd $5, %ymm7, %ymm7, %ymm15 +vmulpd %ymm12,%ymm1,%ymm12 +vmulpd %ymm13,%ymm1,%ymm13 +vmulpd %ymm14,%ymm1,%ymm14 +vmulpd %ymm15,%ymm1,%ymm15 +vfmaddsub231pd %ymm4, %ymm0, %ymm12 # ymm12 = (ymm0 * ymm4) +/- ymm12 +vfmaddsub231pd %ymm5, %ymm0, %ymm13 +vfmaddsub231pd %ymm6, %ymm0, %ymm14 +vfmaddsub231pd %ymm7, %ymm0, %ymm15 + +.save_and_return: +vmovupd %ymm8,(%rdi) +vmovupd %ymm9,0x20(%rdi) +vmovupd %ymm10,0x40(%rdi) +vmovupd %ymm11,0x60(%rdi) +vmovupd %ymm12,0x80(%rdi) +vmovupd %ymm13,0xa0(%rdi) +vmovupd %ymm14,0xc0(%rdi) +vmovupd %ymm15,0xe0(%rdi) +ret +.size cplx_ifft16_avx_fma, .-cplx_ifft16_avx_fma +.section .note.GNU-stack,"",@progbits diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_ifft16_avx_fma_win32.s b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_ifft16_avx_fma_win32.s new file mode 100644 index 0000000..1803882 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_ifft16_avx_fma_win32.s @@ -0,0 +1,192 @@ + .text + .p2align 4 + .globl cplx_ifft16_avx_fma + .def cplx_ifft16_avx_fma; .scl 2; .type 32; .endef +cplx_ifft16_avx_fma: + + pushq %rdi + pushq %rsi + movq %rcx,%rdi + movq %rdx,%rsi + subq $0x100,%rsp + movdqu %xmm6,(%rsp) + movdqu %xmm7,0x10(%rsp) + movdqu %xmm8,0x20(%rsp) + movdqu %xmm9,0x30(%rsp) + movdqu %xmm10,0x40(%rsp) + movdqu %xmm11,0x50(%rsp) + movdqu %xmm12,0x60(%rsp) + movdqu %xmm13,0x70(%rsp) + movdqu %xmm14,0x80(%rsp) + movdqu %xmm15,0x90(%rsp) + callq cplx_ifft16_avx_fma_amd64 + movdqu (%rsp),%xmm6 + movdqu 0x10(%rsp),%xmm7 + movdqu 0x20(%rsp),%xmm8 + movdqu 0x30(%rsp),%xmm9 + movdqu 0x40(%rsp),%xmm10 + movdqu 0x50(%rsp),%xmm11 + movdqu 0x60(%rsp),%xmm12 + 
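# Note on this wrapper: on Windows x64 the first two arguments arrive in
# %rcx/%rdx and %xmm6-%xmm15 are callee-saved, so the movq instructions above
# translate the arguments into the System V registers %rdi/%rsi, and the
# surrounding movdqu sequence saves and restores those xmm registers around
# the call to the SysV-style body below.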
movdqu 0x70(%rsp),%xmm13 + movdqu 0x80(%rsp),%xmm14 + movdqu 0x90(%rsp),%xmm15 + addq $0x100,%rsp + popq %rsi + popq %rdi + retq + +# shifted FFT over X^16-i +# 1st argument (rdi) contains 16 complexes +# 2nd argument (rsi) contains: 8 complexes +# omega,alpha,beta,j.beta,gamma,j.gamma,k.gamma,kj.gamma +# alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta) +# j = sqrt(i), k=sqrt(j) + +cplx_ifft16_avx_fma_amd64: +vmovupd (%rdi),%ymm8 # load data into registers %ymm8 -> %ymm15 +vmovupd 0x20(%rdi),%ymm9 +vmovupd 0x40(%rdi),%ymm10 +vmovupd 0x60(%rdi),%ymm11 +vmovupd 0x80(%rdi),%ymm12 +vmovupd 0xa0(%rdi),%ymm13 +vmovupd 0xc0(%rdi),%ymm14 +vmovupd 0xe0(%rdi),%ymm15 + +.fourth_pass: +vmovupd 0(%rsi),%ymm0 /* gamma */ +vmovupd 32(%rsi),%ymm2 /* delta */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */ +vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */ +vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */ +vperm2f128 $0x31,%ymm10,%ymm8,%ymm4 # ymm4 contains c1,c5 +vperm2f128 $0x31,%ymm11,%ymm9,%ymm5 # ymm5 contains c3,c7 +vperm2f128 $0x31,%ymm14,%ymm12,%ymm6 # ymm6 contains c9,c13 +vperm2f128 $0x31,%ymm15,%ymm13,%ymm7 # ymm7 contains c11,c15 +vperm2f128 $0x20,%ymm10,%ymm8,%ymm8 # ymm8 contains c0,c4 +vperm2f128 $0x20,%ymm11,%ymm9,%ymm9 # ymm9 contains c2,c6 +vperm2f128 $0x20,%ymm14,%ymm12,%ymm10 # ymm10 contains c8,c12 +vperm2f128 $0x20,%ymm15,%ymm13,%ymm11 # ymm11 contains c10,c14 +vsubpd %ymm4,%ymm8,%ymm12 # tw: to mul by gamma +vsubpd %ymm5,%ymm9,%ymm13 # itw: to mul by i.gamma +vsubpd %ymm6,%ymm10,%ymm14 # tw: to mul by delta +vsubpd %ymm7,%ymm11,%ymm15 # itw: to mul by i.delta +vaddpd %ymm4,%ymm8,%ymm8 +vaddpd %ymm5,%ymm9,%ymm9 +vaddpd %ymm6,%ymm10,%ymm10 +vaddpd %ymm7,%ymm11,%ymm11 +vshufpd $5, %ymm12, %ymm12, %ymm4 +vshufpd $5, %ymm13, %ymm13, %ymm5 +vshufpd $5, %ymm14, %ymm14, %ymm6 +vshufpd $5, %ymm15, %ymm15, %ymm7 +vmulpd %ymm4,%ymm1,%ymm4 +vmulpd %ymm5,%ymm0,%ymm5 +vmulpd %ymm6,%ymm3,%ymm6 +vmulpd %ymm7,%ymm2,%ymm7 +vfmaddsub231pd %ymm12, %ymm0, %ymm4 # ymm4 = (ymm0 * ymm12) +/- ymm4 +vfmsubadd231pd %ymm13, %ymm1, %ymm5 +vfmaddsub231pd %ymm14, %ymm2, %ymm6 +vfmsubadd231pd %ymm15, %ymm3, %ymm7 + +vperm2f128 $0x20,%ymm6,%ymm10,%ymm12 # ymm4 contains c1,c5 -- x gamma +vperm2f128 $0x20,%ymm7,%ymm11,%ymm13 # ymm5 contains c3,c7 -- x igamma +vperm2f128 $0x31,%ymm6,%ymm10,%ymm14 # ymm6 contains c9,c13 -- x delta +vperm2f128 $0x31,%ymm7,%ymm11,%ymm15 # ymm7 contains c11,c15 -- x idelta +vperm2f128 $0x31,%ymm4,%ymm8,%ymm10 # ymm10 contains c8,c12 +vperm2f128 $0x31,%ymm5,%ymm9,%ymm11 # ymm11 contains c10,c14 +vperm2f128 $0x20,%ymm4,%ymm8,%ymm8 # ymm8 contains c0,c4 +vperm2f128 $0x20,%ymm5,%ymm9,%ymm9 # ymm9 contains c2,c6 + + +.third_pass: +vmovupd 64(%rsi),%xmm0 /* gamma */ +vmovupd 80(%rsi),%xmm2 /* delta */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 +vinsertf128 $1, %xmm2, %ymm2, %ymm2 +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: gama.iiii */ +vshufpd $15, %ymm2, %ymm2, %ymm3 /* ymm3: delta.iiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: gama.rrrr */ +vshufpd $0, %ymm2, %ymm2, %ymm2 /* ymm2: delta.rrrr */ +vsubpd %ymm9,%ymm8,%ymm4 +vsubpd %ymm11,%ymm10,%ymm5 +vsubpd %ymm13,%ymm12,%ymm6 +vsubpd %ymm15,%ymm14,%ymm7 +vaddpd %ymm9,%ymm8,%ymm8 +vaddpd %ymm11,%ymm10,%ymm10 +vaddpd %ymm13,%ymm12,%ymm12 +vaddpd %ymm15,%ymm14,%ymm14 +vshufpd $5, %ymm4, %ymm4, %ymm9 +vshufpd $5, %ymm5, %ymm5, %ymm11 +vshufpd $5, %ymm6, %ymm6, %ymm13 +vshufpd $5, %ymm7, %ymm7, %ymm15 +vmulpd %ymm9,%ymm1,%ymm9 +vmulpd %ymm11,%ymm0,%ymm11 +vmulpd 
%ymm13,%ymm3,%ymm13 +vmulpd %ymm15,%ymm2,%ymm15 +vfmaddsub231pd %ymm4, %ymm0, %ymm9 # ymm9 = (ymm0 * ymm4) +/- ymm9 +vfmsubadd231pd %ymm5, %ymm1, %ymm11 +vfmaddsub231pd %ymm6, %ymm2, %ymm13 +vfmsubadd231pd %ymm7, %ymm3, %ymm15 + +.second_pass: +vmovupd 96(%rsi),%xmm0 /* omri */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */ +vsubpd %ymm10,%ymm8,%ymm4 +vsubpd %ymm11,%ymm9,%ymm5 +vsubpd %ymm14,%ymm12,%ymm6 +vsubpd %ymm15,%ymm13,%ymm7 +vaddpd %ymm10,%ymm8,%ymm8 +vaddpd %ymm11,%ymm9,%ymm9 +vaddpd %ymm14,%ymm12,%ymm12 +vaddpd %ymm15,%ymm13,%ymm13 +vshufpd $5, %ymm4, %ymm4, %ymm10 +vshufpd $5, %ymm5, %ymm5, %ymm11 +vshufpd $5, %ymm6, %ymm6, %ymm14 +vshufpd $5, %ymm7, %ymm7, %ymm15 +vmulpd %ymm10,%ymm1,%ymm10 +vmulpd %ymm11,%ymm1,%ymm11 +vmulpd %ymm14,%ymm0,%ymm14 +vmulpd %ymm15,%ymm0,%ymm15 +vfmaddsub231pd %ymm4, %ymm0, %ymm10 # ymm10 = (ymm0 * ymm4) +/- ymm10 +vfmaddsub231pd %ymm5, %ymm0, %ymm11 +vfmsubadd231pd %ymm6, %ymm1, %ymm14 +vfmsubadd231pd %ymm7, %ymm1, %ymm15 + +.first_pass: +vmovupd 112(%rsi),%xmm0 /* omri */ +vinsertf128 $1, %xmm0, %ymm0, %ymm0 /* omriri */ +vshufpd $15, %ymm0, %ymm0, %ymm1 /* ymm1: omiiii */ +vshufpd $0, %ymm0, %ymm0, %ymm0 /* ymm0: omrrrr */ +vsubpd %ymm12,%ymm8,%ymm4 +vsubpd %ymm13,%ymm9,%ymm5 +vsubpd %ymm14,%ymm10,%ymm6 +vsubpd %ymm15,%ymm11,%ymm7 +vaddpd %ymm12,%ymm8,%ymm8 +vaddpd %ymm13,%ymm9,%ymm9 +vaddpd %ymm14,%ymm10,%ymm10 +vaddpd %ymm15,%ymm11,%ymm11 +vshufpd $5, %ymm4, %ymm4, %ymm12 +vshufpd $5, %ymm5, %ymm5, %ymm13 +vshufpd $5, %ymm6, %ymm6, %ymm14 +vshufpd $5, %ymm7, %ymm7, %ymm15 +vmulpd %ymm12,%ymm1,%ymm12 +vmulpd %ymm13,%ymm1,%ymm13 +vmulpd %ymm14,%ymm1,%ymm14 +vmulpd %ymm15,%ymm1,%ymm15 +vfmaddsub231pd %ymm4, %ymm0, %ymm12 # ymm12 = (ymm0 * ymm4) +/- ymm12 +vfmaddsub231pd %ymm5, %ymm0, %ymm13 +vfmaddsub231pd %ymm6, %ymm0, %ymm14 +vfmaddsub231pd %ymm7, %ymm0, %ymm15 + +.save_and_return: +vmovupd %ymm8,(%rdi) +vmovupd %ymm9,0x20(%rdi) +vmovupd %ymm10,0x40(%rdi) +vmovupd %ymm11,0x60(%rdi) +vmovupd %ymm12,0x80(%rdi) +vmovupd %ymm13,0xa0(%rdi) +vmovupd %ymm14,0xc0(%rdi) +vmovupd %ymm15,0xe0(%rdi) +ret diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_ifft_avx2_fma.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_ifft_avx2_fma.c new file mode 100644 index 0000000..c902f5e --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_ifft_avx2_fma.c @@ -0,0 +1,267 @@ +#include +#include +#include + +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +typedef double D4MEM[4]; +typedef double D2MEM[2]; + +/** + * @brief complex ifft via bfs strategy (for m between 2 and 8) + * @param dat the data to run the algorithm on + * @param omg precomputed tables (must have been filled with fill_omega) + * @param m ring dimension of the FFT (modulo X^m-i) + */ +void cplx_ifft_avx2_fma_bfs_2(D4MEM* dat, const D2MEM** omga, uint32_t m) { + double* data = (double*)dat; + D4MEM* const finaldd = (D4MEM*)(data + 2 * m); + { + // loop with h = 1 + // we do not do any particular optimization in this loop, + // since this function is only called for small dimensions + D4MEM* dd = (D4MEM*)data; + do { + /* + BEGIN_TEMPLATE + const __m256d ab% = _mm256_loadu_pd(dd[0+2*%]); + const __m256d cd% = _mm256_loadu_pd(dd[1+2*%]); + const __m256d ac% = _mm256_permute2f128_pd(ab%, cd%, 0b100000); + const __m256d bd% = _mm256_permute2f128_pd(ab%, cd%, 0b110001); + const __m256d sum% = 
_mm256_add_pd(ac%, bd%); + const __m256d diff% = _mm256_sub_pd(ac%, bd%); + const __m256d diffbar% = _mm256_shuffle_pd(diff%, diff%, 5); + const __m256d om% = _mm256_load_pd((*omg)[0+%]); + const __m256d omre% = _mm256_unpacklo_pd(om%, om%); + const __m256d omim% = _mm256_unpackhi_pd(om%, om%); + const __m256d t1% = _mm256_mul_pd(diffbar%, omim%); + const __m256d t2% = _mm256_fmaddsub_pd(diff%, omre%, t1%); + const __m256d newab% = _mm256_permute2f128_pd(sum%, t2%, 0b100000); + const __m256d newcd% = _mm256_permute2f128_pd(sum%, t2%, 0b110001); + _mm256_storeu_pd(dd[0+2*%], newab%); + _mm256_storeu_pd(dd[1+2*%], newcd%); + dd += 2*@; + *omg += 2*@; + END_TEMPLATE + */ + // BEGIN_INTERLEAVE 1 + const __m256d ab0 = _mm256_loadu_pd(dd[0 + 2 * 0]); + const __m256d cd0 = _mm256_loadu_pd(dd[1 + 2 * 0]); + const __m256d ac0 = _mm256_permute2f128_pd(ab0, cd0, 0b100000); + const __m256d bd0 = _mm256_permute2f128_pd(ab0, cd0, 0b110001); + const __m256d sum0 = _mm256_add_pd(ac0, bd0); + const __m256d diff0 = _mm256_sub_pd(ac0, bd0); + const __m256d diffbar0 = _mm256_shuffle_pd(diff0, diff0, 5); + const __m256d om0 = _mm256_load_pd((*omga)[0 + 0]); + const __m256d omre0 = _mm256_unpacklo_pd(om0, om0); + const __m256d omim0 = _mm256_unpackhi_pd(om0, om0); + const __m256d t10 = _mm256_mul_pd(diffbar0, omim0); + const __m256d t20 = _mm256_fmaddsub_pd(diff0, omre0, t10); + const __m256d newab0 = _mm256_permute2f128_pd(sum0, t20, 0b100000); + const __m256d newcd0 = _mm256_permute2f128_pd(sum0, t20, 0b110001); + _mm256_storeu_pd(dd[0 + 2 * 0], newab0); + _mm256_storeu_pd(dd[1 + 2 * 0], newcd0); + dd += 2 * 1; + *omga += 2 * 1; + // END_INTERLEAVE + } while (dd < finaldd); +#if 0 + printf("c after first: "); + for (uint64_t ii=0; ii> 1; + for (uint32_t _2nblock = 2; _2nblock <= ms2; _2nblock <<= 1) { + // _2nblock = h in ref code + uint32_t nblock = _2nblock >> 1; // =h/2 in ref code + D4MEM* dd = (D4MEM*)data; + do { + const __m256d om = _mm256_load_pd((*omga)[0]); + const __m256d omre = _mm256_unpacklo_pd(om, om); + const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om)); + D4MEM* const ddend = (dd + nblock); + D4MEM* ddmid = ddend; + do { + const __m256d a = _mm256_loadu_pd(dd[0]); + const __m256d b = _mm256_loadu_pd(ddmid[0]); + const __m256d newa = _mm256_add_pd(a, b); + _mm256_storeu_pd(dd[0], newa); + const __m256d diff = _mm256_sub_pd(a, b); + const __m256d t1 = _mm256_mul_pd(diff, omre); + const __m256d bardiff = _mm256_shuffle_pd(diff, diff, 5); + const __m256d t2 = _mm256_fmadd_pd(bardiff, omim, t1); + _mm256_storeu_pd(ddmid[0], t2); + dd += 1; + ddmid += 1; + } while (dd < ddend); + dd += nblock; + *omga += 2; + } while (dd < finaldd); + } +} + +/** + * @brief complex fft via bfs strategy (for m >= 16) + * @param dat the data to run the algorithm on + * @param omg precomputed tables (must have been filled with fill_omega) + * @param m ring dimension of the FFT (modulo X^m-i) + */ +void cplx_ifft_avx2_fma_bfs_16(D4MEM* dat, const D2MEM** omga, uint32_t m) { + double* data = (double*)dat; + D4MEM* const finaldd = (D4MEM*)(data + 2 * m); + // base iteration when h = _2nblock == 8 + { + D4MEM* dd = (D4MEM*)data; + do { + cplx_ifft16_avx_fma(dd, *omga); + dd += 8; + *omga += 8; + } while (dd < finaldd); + } + // general case + const uint32_t log2m = _mm_popcnt_u32(m - 1); //_popcnt32(m-1); //log2(m); + uint32_t h = 16; + if (log2m % 2 == 1) { + uint32_t nblock = h >> 1; // =h/2 in ref code + D4MEM* dd = (D4MEM*)data; + do { + const __m128d om1 = _mm_loadu_pd((*omga)[0]); 
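      /* The twiddle om1 = (re, im) is broadcast to both 128-bit lanes and split
       * into omre = (re, re, re, re) and omim = (-im, im, -im, im): the addsub
       * against zero negates the even lanes.  With bardiff = swap(diff), the
       * fmadd(bardiff, omim, diff * omre) computed in the inner loop is then the
       * packed complex product om1 * diff, so each iteration performs the
       * (a, b) <- (a + b, om * (a - b)) inverse butterfly, cf. invctwiddle in
       * cplx_ifft_ref.c later in this diff. */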
+ const __m256d om = _mm256_set_m128d(om1, om1); + const __m256d omre = _mm256_unpacklo_pd(om, om); + const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om)); + D4MEM* const ddend = (dd + nblock); + D4MEM* ddmid = ddend; + do { + const __m256d a = _mm256_loadu_pd(dd[0]); + const __m256d b = _mm256_loadu_pd(ddmid[0]); + const __m256d newa = _mm256_add_pd(a, b); + _mm256_storeu_pd(dd[0], newa); + const __m256d diff = _mm256_sub_pd(a, b); + const __m256d t1 = _mm256_mul_pd(diff, omre); + const __m256d bardiff = _mm256_shuffle_pd(diff, diff, 5); + const __m256d t2 = _mm256_fmadd_pd(bardiff, omim, t1); + _mm256_storeu_pd(ddmid[0], t2); + dd += 1; + ddmid += 1; + } while (dd < ddend); + dd += nblock; + *omga += 1; + } while (dd < finaldd); + h = 32; + } + for (; h < m; h <<= 2) { + // _2nblock = h in ref code + uint32_t nblock = h >> 1; // =h/2 in ref code + D4MEM* dd0 = (D4MEM*)data; + do { + const __m128d om1 = _mm_loadu_pd((*omga)[0]); + const __m128d al1 = _mm_loadu_pd((*omga)[1]); + const __m256d om = _mm256_set_m128d(om1, om1); + const __m256d al = _mm256_set_m128d(al1, al1); + const __m256d omre = _mm256_unpacklo_pd(om, om); + const __m256d omim = _mm256_unpackhi_pd(om, om); + const __m256d alre = _mm256_unpacklo_pd(al, al); + const __m256d alim = _mm256_unpackhi_pd(al, al); + D4MEM* const ddend = (dd0 + nblock); + D4MEM* dd1 = ddend; + D4MEM* dd2 = dd1 + nblock; + D4MEM* dd3 = dd2 + nblock; + do { + __m256d u0 = _mm256_loadu_pd(dd0[0]); + __m256d u1 = _mm256_loadu_pd(dd1[0]); + __m256d u2 = _mm256_loadu_pd(dd2[0]); + __m256d u3 = _mm256_loadu_pd(dd3[0]); + __m256d u4 = _mm256_add_pd(u0, u1); + __m256d u5 = _mm256_sub_pd(u0, u1); + __m256d u6 = _mm256_add_pd(u2, u3); + __m256d u7 = _mm256_sub_pd(u2, u3); + u0 = _mm256_shuffle_pd(u5, u5, 5); + u2 = _mm256_shuffle_pd(u7, u7, 5); + u1 = _mm256_mul_pd(u0, omim); + u3 = _mm256_mul_pd(u2, omre); + u5 = _mm256_fmaddsub_pd(u5, omre, u1); + u7 = _mm256_fmsubadd_pd(u7, omim, u3); + ////// + u0 = _mm256_add_pd(u4, u6); + u1 = _mm256_add_pd(u5, u7); + u2 = _mm256_sub_pd(u4, u6); + u3 = _mm256_sub_pd(u5, u7); + u4 = _mm256_shuffle_pd(u2, u2, 5); + u5 = _mm256_shuffle_pd(u3, u3, 5); + u6 = _mm256_mul_pd(u4, alim); + u7 = _mm256_mul_pd(u5, alim); + u2 = _mm256_fmaddsub_pd(u2, alre, u6); + u3 = _mm256_fmaddsub_pd(u3, alre, u7); + /////// + _mm256_storeu_pd(dd0[0], u0); + _mm256_storeu_pd(dd1[0], u1); + _mm256_storeu_pd(dd2[0], u2); + _mm256_storeu_pd(dd3[0], u3); + dd0 += 1; + dd1 += 1; + dd2 += 1; + dd3 += 1; + } while (dd0 < ddend); + dd0 += 3 * nblock; + *omga += 2; + } while (dd0 < finaldd); + } +} + +/** + * @brief complex ifft via dfs recursion (for m >= 16) + * @param dat the data to run the algorithm on + * @param omg precomputed tables (must have been filled with fill_omega) + * @param m ring dimension of the FFT (modulo X^m-i) + */ +void cplx_ifft_avx2_fma_rec_16(D4MEM* dat, const D2MEM** omga, uint32_t m) { + if (m <= 8) return cplx_ifft_avx2_fma_bfs_2(dat, omga, m); + if (m <= 2048) return cplx_ifft_avx2_fma_bfs_16(dat, omga, m); + const uint32_t _2nblock = m >> 1; // = h in ref code + const uint32_t nblock = _2nblock >> 1; // =h/2 in ref code + cplx_ifft_avx2_fma_rec_16(dat, omga, _2nblock); + cplx_ifft_avx2_fma_rec_16(dat + nblock, omga, _2nblock); + { + // final iteration + D4MEM* dd = dat; + const __m256d om = _mm256_load_pd((*omga)[0]); + const __m256d omre = _mm256_unpacklo_pd(om, om); + const __m256d omim = _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_unpackhi_pd(om, om)); + D4MEM* const ddend = (dd + 
nblock); + D4MEM* ddmid = ddend; + do { + const __m256d a = _mm256_loadu_pd(dd[0]); + const __m256d b = _mm256_loadu_pd(ddmid[0]); + const __m256d newa = _mm256_add_pd(a, b); + _mm256_storeu_pd(dd[0], newa); + const __m256d diff = _mm256_sub_pd(a, b); + const __m256d t1 = _mm256_mul_pd(diff, omre); + const __m256d bardiff = _mm256_shuffle_pd(diff, diff, 5); + const __m256d t2 = _mm256_fmadd_pd(bardiff, omim, t1); + _mm256_storeu_pd(ddmid[0], t2); + dd += 1; + ddmid += 1; + } while (dd < ddend); + *omga += 2; + } +} + +/** + * @brief complex ifft via best strategy (for m>=1) + * @param dat the data to run the algorithm on: m complex numbers + * @param omg precomputed tables (must have been filled with fill_omega) + * @param m ring dimension of the FFT (modulo X^m-i) + */ +EXPORT void cplx_ifft_avx2_fma(const CPLX_IFFT_PRECOMP* precomp, void* d) { + const uint32_t m = precomp->m; + const D2MEM* omg = (D2MEM*)precomp->powomegas; + if (m <= 1) return; + if (m <= 8) return cplx_ifft_avx2_fma_bfs_2(d, &omg, m); + if (m <= 2048) return cplx_ifft_avx2_fma_bfs_16(d, &omg, m); + cplx_ifft_avx2_fma_rec_16(d, &omg, m); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_ifft_ref.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_ifft_ref.c new file mode 100644 index 0000000..8380e12 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/cplx_ifft_ref.c @@ -0,0 +1,312 @@ +#include + +#include "../commons_private.h" +#include "cplx_fft.h" +#include "cplx_fft_internal.h" +#include "cplx_fft_private.h" + +/** @brief (a,b) <- (a+b,omegabar.(a-b)) */ +void invctwiddle(CPLX a, CPLX b, const CPLX ombar) { + double diffre = a[0] - b[0]; + double diffim = a[1] - b[1]; + a[0] = a[0] + b[0]; + a[1] = a[1] + b[1]; + b[0] = diffre * ombar[0] - diffim * ombar[1]; + b[1] = diffre * ombar[1] + diffim * ombar[0]; +} + +/** @brief (a,b) <- (a+b,-i.omegabar(a-b)) */ +void invcitwiddle(CPLX a, CPLX b, const CPLX ombar) { + double diffre = a[0] - b[0]; + double diffim = a[1] - b[1]; + a[0] = a[0] + b[0]; + a[1] = a[1] + b[1]; + //-i(x+iy)=-ix+y + b[0] = diffre * ombar[1] + diffim * ombar[0]; + b[1] = -diffre * ombar[0] + diffim * ombar[1]; +} + +/** @brief exp(-i.2pi.x) */ +void cplx_set_e2pimx(CPLX res, double x) { + res[0] = m_accurate_cos(2 * M_PI * x); + res[1] = -m_accurate_sin(2 * M_PI * x); +} +/** + * @brief this computes the sequence: 0,1/2,1/4,3/4,1/8,5/8,3/8,7/8,... + * essentially: the bits of (i+1) in lsb order on the basis (1/2^k) mod 1*/ +double fracrevbits(uint32_t i); +/** @brief fft modulo X^m-exp(i.2pi.entry+pwr) -- reference code */ +void cplx_ifft_naive(const uint32_t m, const double entry_pwr, CPLX* data) { + if (m == 1) return; + const double pom = entry_pwr / 2.; + const uint32_t h = m / 2; + CPLX cpom; + cplx_set_e2pimx(cpom, pom); + // do the recursive calls + cplx_ifft_naive(h, pom, data); + cplx_ifft_naive(h, pom + 0.5, data + h); + // apply the inverse twiddle factors + for (uint64_t i = 0; i < h; ++i) { + invctwiddle(data[i], data[i + h], cpom); + } +} + +void cplx_ifft16_precomp(const double entry_pwr, CPLX** omg) { + static const double j_pow = 1. / 8.; + static const double k_pow = 1. 
/ 16.; + const double pom = entry_pwr / 2.; + const double pom_2 = entry_pwr / 4.; + const double pom_4 = entry_pwr / 8.; + const double pom_8 = entry_pwr / 16.; + cplx_set_e2pimx((*omg)[0], pom_8); + cplx_set_e2pimx((*omg)[1], pom_8 + j_pow); + cplx_set_e2pimx((*omg)[2], pom_8 + k_pow); + cplx_set_e2pimx((*omg)[3], pom_8 + j_pow + k_pow); + cplx_set_e2pimx((*omg)[4], pom_4); + cplx_set_e2pimx((*omg)[5], pom_4 + j_pow); + cplx_set_e2pimx((*omg)[6], pom_2); + cplx_set_e2pimx((*omg)[7], pom); + *omg += 8; +} + +/** + * @brief iFFT modulo X^16-omega^2 (in registers) + * @param data contains 16 complexes + * @param omegabar 8 complexes in this order: + * gammabar,jb.gammabar,kb.gammabar,kbjb.gammabar, + * betabar,jb.betabar,alphabar,omegabar + * alpha = sqrt(omega), beta = sqrt(alpha), gamma = sqrt(beta) + * jb = sqrt(ib), kb=sqrt(jb) + */ +void cplx_ifft16_ref(void* data, const void* omegabar) { + CPLX* d = data; + const CPLX* om = omegabar; + // fourth pass inverse + invctwiddle(d[0], d[1], om[0]); + invcitwiddle(d[2], d[3], om[0]); + invctwiddle(d[4], d[5], om[1]); + invcitwiddle(d[6], d[7], om[1]); + invctwiddle(d[8], d[9], om[2]); + invcitwiddle(d[10], d[11], om[2]); + invctwiddle(d[12], d[13], om[3]); + invcitwiddle(d[14], d[15], om[3]); + // third pass inverse + invctwiddle(d[0], d[2], om[4]); + invctwiddle(d[1], d[3], om[4]); + invcitwiddle(d[4], d[6], om[4]); + invcitwiddle(d[5], d[7], om[4]); + invctwiddle(d[8], d[10], om[5]); + invctwiddle(d[9], d[11], om[5]); + invcitwiddle(d[12], d[14], om[5]); + invcitwiddle(d[13], d[15], om[5]); + // second pass inverse + invctwiddle(d[0], d[4], om[6]); + invctwiddle(d[1], d[5], om[6]); + invctwiddle(d[2], d[6], om[6]); + invctwiddle(d[3], d[7], om[6]); + invcitwiddle(d[8], d[12], om[6]); + invcitwiddle(d[9], d[13], om[6]); + invcitwiddle(d[10], d[14], om[6]); + invcitwiddle(d[11], d[15], om[6]); + // first pass + for (uint64_t i = 0; i < 8; ++i) { + invctwiddle(d[0 + i], d[8 + i], om[7]); + } +} + +void cplx_ifft_ref_bfs_2(CPLX* dat, const CPLX** omg, uint32_t m) { + CPLX* const dend = dat + m; + for (CPLX* d = dat; d < dend; d += 2) { + split_fft_last_ref(d, (*omg)[0]); + *omg += 1; + } +#if 0 + printf("after first: "); + for (uint64_t ii=0; iim; + const CPLX* omg = (CPLX*)precomp->powomegas; + if (m == 1) return; + if (m <= 8) return cplx_ifft_ref_bfs_2(data, &omg, m); + if (m <= 2048) return cplx_ifft_ref_bfs_16(data, &omg, m); + cplx_ifft_ref_rec_16(data, &omg, m); +} + +EXPORT CPLX_IFFT_PRECOMP* new_cplx_ifft_precomp(uint32_t m, uint32_t num_buffers) { + const uint64_t OMG_SPACE = ceilto64b(2 * m * sizeof(CPLX)); + const uint64_t BUF_SIZE = ceilto64b(m * sizeof(CPLX)); + void* reps = malloc(sizeof(CPLX_IFFT_PRECOMP) + 63 // padding + + OMG_SPACE // tables + + num_buffers * BUF_SIZE // buffers + ); + uint64_t aligned_addr = ceilto64b((uint64_t)reps + sizeof(CPLX_IFFT_PRECOMP)); + CPLX_IFFT_PRECOMP* r = (CPLX_IFFT_PRECOMP*)reps; + r->m = m; + r->buf_size = BUF_SIZE; + r->powomegas = (double*)aligned_addr; + r->aligned_buffers = (void*)(aligned_addr + OMG_SPACE); + // fill in powomegas + CPLX* omg = (CPLX*)r->powomegas; + fill_cplx_ifft_omegas_rec_16(0.25, &omg, m); + if (((uint64_t)omg) - aligned_addr > OMG_SPACE) abort(); + { + if (m <= 4) { + // currently, we do not have any acceletated + // implementation for m<=4 + r->function = cplx_ifft_ref; + } else if (CPU_SUPPORTS("fma")) { + r->function = cplx_ifft_avx2_fma; + } else { + r->function = cplx_ifft_ref; + } + } + return reps; +} + +EXPORT void* cplx_ifft_precomp_get_buffer(const 
CPLX_IFFT_PRECOMP* itables, uint32_t buffer_index) { + return (uint8_t*)itables->aligned_buffers + buffer_index * itables->buf_size; +} + +EXPORT void cplx_ifft(const CPLX_IFFT_PRECOMP* itables, void* data) { itables->function(itables, data); } + +EXPORT void cplx_ifft_simple(uint32_t m, void* data) { + static CPLX_IFFT_PRECOMP* p[31] = {0}; + CPLX_IFFT_PRECOMP** f = p + log2m(m); + if (!*f) *f = new_cplx_ifft_precomp(m, 0); + (*f)->function(*f, data); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/spqlios_cplx_fft.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/cplx/spqlios_cplx_fft.c new file mode 100644 index 0000000..e69de29 diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/ext/neon_accel/macrof.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/ext/neon_accel/macrof.h new file mode 100644 index 0000000..1db1895 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/ext/neon_accel/macrof.h @@ -0,0 +1,136 @@ +/* + * This file is extracted from the implementation of the FFT on Arm64/Neon + * available in https://github.com/cothan/Falcon-Arm (neon/macrof.h). + * ============================================================================= + * Copyright (c) 2022 by Cryptographic Engineering Research Group (CERG) + * ECE Department, George Mason University + * Fairfax, VA, U.S.A. + * @author: Duc Tri Nguyen dnguye69@gmu.edu, cothannguyen@gmail.com + * Licensed under the Apache License, Version 2.0 (the "License"); + * ============================================================================= + * + * This 64-bit Floating point NEON macro x1 has not been modified and is provided as is. + */ + +#ifndef MACROF_H +#define MACROF_H + +#include + +// c <= addr x1 +#define vload(c, addr) c = vld1q_f64(addr); +// c <= addr interleave 2 +#define vload2(c, addr) c = vld2q_f64(addr); +// c <= addr interleave 4 +#define vload4(c, addr) c = vld4q_f64(addr); + +#define vstore(addr, c) vst1q_f64(addr, c); +// addr <= c +#define vstore2(addr, c) vst2q_f64(addr, c); +// addr <= c +#define vstore4(addr, c) vst4q_f64(addr, c); + +// c <= addr x2 +#define vloadx2(c, addr) c = vld1q_f64_x2(addr); +// c <= addr x3 +#define vloadx3(c, addr) c = vld1q_f64_x3(addr); + +// addr <= c +#define vstorex2(addr, c) vst1q_f64_x2(addr, c); + +// c = a - b +#define vfsub(c, a, b) c = vsubq_f64(a, b); + +// c = a + b +#define vfadd(c, a, b) c = vaddq_f64(a, b); + +// c = a * b +#define vfmul(c, a, b) c = vmulq_f64(a, b); + +// c = a * n (n is constant) +#define vfmuln(c, a, n) c = vmulq_n_f64(a, n); + +// Swap from a|b to b|a +#define vswap(c, a) c = vextq_f64(a, a, 1); + +// c = a * b[i] +#define vfmul_lane(c, a, b, i) c = vmulq_laneq_f64(a, b, i); + +// c = 1/a +#define vfinv(c, a) c = vdivq_f64(vdupq_n_f64(1.0), a); + +// c = -a +#define vfneg(c, a) c = vnegq_f64(a); + +#define transpose_f64(a, b, t, ia, ib, it) \ + t.val[it] = a.val[ia]; \ + a.val[ia] = vzip1q_f64(a.val[ia], b.val[ib]); \ + b.val[ib] = vzip2q_f64(t.val[it], b.val[ib]); + +/* + * c = a + jb + * c[0] = a[0] - b[1] + * c[1] = a[1] + b[0] + */ +#define vfcaddj(c, a, b) c = vcaddq_rot90_f64(a, b); + +/* + * c = a - jb + * c[0] = a[0] + b[1] + * c[1] = a[1] - b[0] + */ +#define vfcsubj(c, a, b) c = vcaddq_rot270_f64(a, b); + +// c[0] = c[0] + b[0]*a[0], c[1] = c[1] + b[1]*a[0] +#define vfcmla(c, a, b) c = vcmlaq_f64(c, a, b); + +// c[0] = c[0] - b[1]*a[1], c[1] = c[1] + b[0]*a[1] +#define vfcmla_90(c, a, b) c = vcmlaq_rot90_f64(c, a, b); + +// c[0] = c[0] - 
b[0]*a[0], c[1] = c[1] - b[1]*a[0] +#define vfcmla_180(c, a, b) c = vcmlaq_rot180_f64(c, a, b); + +// c[0] = c[0] + b[1]*a[1], c[1] = c[1] - b[0]*a[1] +#define vfcmla_270(c, a, b) c = vcmlaq_rot270_f64(c, a, b); + +/* + * Complex MUL: c = a*b + * c[0] = a[0]*b[0] - a[1]*b[1] + * c[1] = a[0]*b[1] + a[1]*b[0] + */ +#define FPC_CMUL(c, a, b) \ + c = vmulq_laneq_f64(b, a, 0); \ + c = vcmlaq_rot90_f64(c, a, b); + +/* + * Complex MUL: c = a * conjugate(b) = a * (b[0], -b[1]) + * c[0] = b[0]*a[0] + b[1]*a[1] + * c[1] = + b[0]*a[1] - b[1]*a[0] + */ +#define FPC_CMUL_CONJ(c, a, b) \ + c = vmulq_laneq_f64(a, b, 0); \ + c = vcmlaq_rot270_f64(c, b, a); + +#if FMA == 1 +// d = c + a *b +#define vfmla(d, c, a, b) d = vfmaq_f64(c, a, b); +// d = c - a * b +#define vfmls(d, c, a, b) d = vfmsq_f64(c, a, b); +// d = c + a * b[i] +#define vfmla_lane(d, c, a, b, i) d = vfmaq_laneq_f64(c, a, b, i); +// d = c - a * b[i] +#define vfmls_lane(d, c, a, b, i) d = vfmsq_laneq_f64(c, a, b, i); + +#else +// d = c + a *b +#define vfmla(d, c, a, b) d = vaddq_f64(c, vmulq_f64(a, b)); +// d = c - a *b +#define vfmls(d, c, a, b) d = vsubq_f64(c, vmulq_f64(a, b)); +// d = c + a * b[i] +#define vfmla_lane(d, c, a, b, i) d = vaddq_f64(c, vmulq_laneq_f64(a, b, i)); + +#define vfmls_lane(d, c, a, b, i) d = vsubq_f64(c, vmulq_laneq_f64(a, b, i)); + +#endif + +#endif diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/ext/neon_accel/macrofx4.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/ext/neon_accel/macrofx4.h new file mode 100644 index 0000000..a5124dd --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/ext/neon_accel/macrofx4.h @@ -0,0 +1,420 @@ +/* + * This file is extracted from the implementation of the FFT on Arm64/Neon + * available in https://github.com/cothan/Falcon-Arm (neon/macrof.h). + * ============================================================================= + * Copyright (c) 2022 by Cryptographic Engineering Research Group (CERG) + * ECE Department, George Mason University + * Fairfax, VA, U.S.A. + * @author: Duc Tri Nguyen dnguye69@gmu.edu, cothannguyen@gmail.com + * Licensed under the Apache License, Version 2.0 (the "License"); + * ============================================================================= + * + * This 64-bit Floating point NEON macro x4 has not been modified and is provided as is. 
+ */ + +#ifndef MACROFX4_H +#define MACROFX4_H + +#include + +#include "macrof.h" + +#define vloadx4(c, addr) c = vld1q_f64_x4(addr); + +#define vstorex4(addr, c) vst1q_f64_x4(addr, c); + +#define vfdupx4(c, constant) \ + c.val[0] = vdupq_n_f64(constant); \ + c.val[1] = vdupq_n_f64(constant); \ + c.val[2] = vdupq_n_f64(constant); \ + c.val[3] = vdupq_n_f64(constant); + +#define vfnegx4(c, a) \ + c.val[0] = vnegq_f64(a.val[0]); \ + c.val[1] = vnegq_f64(a.val[1]); \ + c.val[2] = vnegq_f64(a.val[2]); \ + c.val[3] = vnegq_f64(a.val[3]); + +#define vfmulnx4(c, a, n) \ + c.val[0] = vmulq_n_f64(a.val[0], n); \ + c.val[1] = vmulq_n_f64(a.val[1], n); \ + c.val[2] = vmulq_n_f64(a.val[2], n); \ + c.val[3] = vmulq_n_f64(a.val[3], n); + +// c = a - b +#define vfsubx4(c, a, b) \ + c.val[0] = vsubq_f64(a.val[0], b.val[0]); \ + c.val[1] = vsubq_f64(a.val[1], b.val[1]); \ + c.val[2] = vsubq_f64(a.val[2], b.val[2]); \ + c.val[3] = vsubq_f64(a.val[3], b.val[3]); + +// c = a + b +#define vfaddx4(c, a, b) \ + c.val[0] = vaddq_f64(a.val[0], b.val[0]); \ + c.val[1] = vaddq_f64(a.val[1], b.val[1]); \ + c.val[2] = vaddq_f64(a.val[2], b.val[2]); \ + c.val[3] = vaddq_f64(a.val[3], b.val[3]); + +#define vfmulx4(c, a, b) \ + c.val[0] = vmulq_f64(a.val[0], b.val[0]); \ + c.val[1] = vmulq_f64(a.val[1], b.val[1]); \ + c.val[2] = vmulq_f64(a.val[2], b.val[2]); \ + c.val[3] = vmulq_f64(a.val[3], b.val[3]); + +#define vfmulx4_i(c, a, b) \ + c.val[0] = vmulq_f64(a.val[0], b); \ + c.val[1] = vmulq_f64(a.val[1], b); \ + c.val[2] = vmulq_f64(a.val[2], b); \ + c.val[3] = vmulq_f64(a.val[3], b); + +#define vfinvx4(c, a) \ + c.val[0] = vdivq_f64(vdupq_n_f64(1.0), a.val[0]); \ + c.val[1] = vdivq_f64(vdupq_n_f64(1.0), a.val[1]); \ + c.val[2] = vdivq_f64(vdupq_n_f64(1.0), a.val[2]); \ + c.val[3] = vdivq_f64(vdupq_n_f64(1.0), a.val[3]); + +#define vfcvtx4(c, a) \ + c.val[0] = vcvtq_f64_s64(a.val[0]); \ + c.val[1] = vcvtq_f64_s64(a.val[1]); \ + c.val[2] = vcvtq_f64_s64(a.val[2]); \ + c.val[3] = vcvtq_f64_s64(a.val[3]); + +#define vfmlax4(d, c, a, b) \ + vfmla(d.val[0], c.val[0], a.val[0], b.val[0]); \ + vfmla(d.val[1], c.val[1], a.val[1], b.val[1]); \ + vfmla(d.val[2], c.val[2], a.val[2], b.val[2]); \ + vfmla(d.val[3], c.val[3], a.val[3], b.val[3]); + +#define vfmlsx4(d, c, a, b) \ + vfmls(d.val[0], c.val[0], a.val[0], b.val[0]); \ + vfmls(d.val[1], c.val[1], a.val[1], b.val[1]); \ + vfmls(d.val[2], c.val[2], a.val[2], b.val[2]); \ + vfmls(d.val[3], c.val[3], a.val[3], b.val[3]); + +#define vfrintx4(c, a) \ + c.val[0] = vcvtnq_s64_f64(a.val[0]); \ + c.val[1] = vcvtnq_s64_f64(a.val[1]); \ + c.val[2] = vcvtnq_s64_f64(a.val[2]); \ + c.val[3] = vcvtnq_s64_f64(a.val[3]); + +/* + * Wrapper for FFT, split/merge and poly_float.c + */ + +#define FPC_MUL(d_re, d_im, a_re, a_im, b_re, b_im) \ + vfmul(d_re, a_re, b_re); \ + vfmls(d_re, d_re, a_im, b_im); \ + vfmul(d_im, a_re, b_im); \ + vfmla(d_im, d_im, a_im, b_re); + +#define FPC_MULx2(d_re, d_im, a_re, a_im, b_re, b_im) \ + vfmul(d_re.val[0], a_re.val[0], b_re.val[0]); \ + vfmls(d_re.val[0], d_re.val[0], a_im.val[0], b_im.val[0]); \ + vfmul(d_re.val[1], a_re.val[1], b_re.val[1]); \ + vfmls(d_re.val[1], d_re.val[1], a_im.val[1], b_im.val[1]); \ + vfmul(d_im.val[0], a_re.val[0], b_im.val[0]); \ + vfmla(d_im.val[0], d_im.val[0], a_im.val[0], b_re.val[0]); \ + vfmul(d_im.val[1], a_re.val[1], b_im.val[1]); \ + vfmla(d_im.val[1], d_im.val[1], a_im.val[1], b_re.val[1]); + +#define FPC_MULx4(d_re, d_im, a_re, a_im, b_re, b_im) \ + vfmul(d_re.val[0], a_re.val[0], b_re.val[0]); \ + vfmls(d_re.val[0], 
d_re.val[0], a_im.val[0], b_im.val[0]); \ + vfmul(d_re.val[1], a_re.val[1], b_re.val[1]); \ + vfmls(d_re.val[1], d_re.val[1], a_im.val[1], b_im.val[1]); \ + vfmul(d_re.val[2], a_re.val[2], b_re.val[2]); \ + vfmls(d_re.val[2], d_re.val[2], a_im.val[2], b_im.val[2]); \ + vfmul(d_re.val[3], a_re.val[3], b_re.val[3]); \ + vfmls(d_re.val[3], d_re.val[3], a_im.val[3], b_im.val[3]); \ + vfmul(d_im.val[0], a_re.val[0], b_im.val[0]); \ + vfmla(d_im.val[0], d_im.val[0], a_im.val[0], b_re.val[0]); \ + vfmul(d_im.val[1], a_re.val[1], b_im.val[1]); \ + vfmla(d_im.val[1], d_im.val[1], a_im.val[1], b_re.val[1]); \ + vfmul(d_im.val[2], a_re.val[2], b_im.val[2]); \ + vfmla(d_im.val[2], d_im.val[2], a_im.val[2], b_re.val[2]); \ + vfmul(d_im.val[3], a_re.val[3], b_im.val[3]); \ + vfmla(d_im.val[3], d_im.val[3], a_im.val[3], b_re.val[3]); + +#define FPC_MLA(d_re, d_im, a_re, a_im, b_re, b_im) \ + vfmla(d_re, d_re, a_re, b_re); \ + vfmls(d_re, d_re, a_im, b_im); \ + vfmla(d_im, d_im, a_re, b_im); \ + vfmla(d_im, d_im, a_im, b_re); + +#define FPC_MLAx2(d_re, d_im, a_re, a_im, b_re, b_im) \ + vfmla(d_re.val[0], d_re.val[0], a_re.val[0], b_re.val[0]); \ + vfmls(d_re.val[0], d_re.val[0], a_im.val[0], b_im.val[0]); \ + vfmla(d_re.val[1], d_re.val[1], a_re.val[1], b_re.val[1]); \ + vfmls(d_re.val[1], d_re.val[1], a_im.val[1], b_im.val[1]); \ + vfmla(d_im.val[0], d_im.val[0], a_re.val[0], b_im.val[0]); \ + vfmla(d_im.val[0], d_im.val[0], a_im.val[0], b_re.val[0]); \ + vfmla(d_im.val[1], d_im.val[1], a_re.val[1], b_im.val[1]); \ + vfmla(d_im.val[1], d_im.val[1], a_im.val[1], b_re.val[1]); + +#define FPC_MLAx4(d_re, d_im, a_re, a_im, b_re, b_im) \ + vfmla(d_re.val[0], d_re.val[0], a_re.val[0], b_re.val[0]); \ + vfmls(d_re.val[0], d_re.val[0], a_im.val[0], b_im.val[0]); \ + vfmla(d_re.val[1], d_re.val[1], a_re.val[1], b_re.val[1]); \ + vfmls(d_re.val[1], d_re.val[1], a_im.val[1], b_im.val[1]); \ + vfmla(d_re.val[2], d_re.val[2], a_re.val[2], b_re.val[2]); \ + vfmls(d_re.val[2], d_re.val[2], a_im.val[2], b_im.val[2]); \ + vfmla(d_re.val[3], d_re.val[3], a_re.val[3], b_re.val[3]); \ + vfmls(d_re.val[3], d_re.val[3], a_im.val[3], b_im.val[3]); \ + vfmla(d_im.val[0], d_im.val[0], a_re.val[0], b_im.val[0]); \ + vfmla(d_im.val[0], d_im.val[0], a_im.val[0], b_re.val[0]); \ + vfmla(d_im.val[1], d_im.val[1], a_re.val[1], b_im.val[1]); \ + vfmla(d_im.val[1], d_im.val[1], a_im.val[1], b_re.val[1]); \ + vfmla(d_im.val[2], d_im.val[2], a_re.val[2], b_im.val[2]); \ + vfmla(d_im.val[2], d_im.val[2], a_im.val[2], b_re.val[2]); \ + vfmla(d_im.val[3], d_im.val[3], a_re.val[3], b_im.val[3]); \ + vfmla(d_im.val[3], d_im.val[3], a_im.val[3], b_re.val[3]); + +#define FPC_MUL_CONJx4(d_re, d_im, a_re, a_im, b_re, b_im) \ + vfmul(d_re.val[0], b_im.val[0], a_im.val[0]); \ + vfmla(d_re.val[0], d_re.val[0], a_re.val[0], b_re.val[0]); \ + vfmul(d_re.val[1], b_im.val[1], a_im.val[1]); \ + vfmla(d_re.val[1], d_re.val[1], a_re.val[1], b_re.val[1]); \ + vfmul(d_re.val[2], b_im.val[2], a_im.val[2]); \ + vfmla(d_re.val[2], d_re.val[2], a_re.val[2], b_re.val[2]); \ + vfmul(d_re.val[3], b_im.val[3], a_im.val[3]); \ + vfmla(d_re.val[3], d_re.val[3], a_re.val[3], b_re.val[3]); \ + vfmul(d_im.val[0], b_re.val[0], a_im.val[0]); \ + vfmls(d_im.val[0], d_im.val[0], a_re.val[0], b_im.val[0]); \ + vfmul(d_im.val[1], b_re.val[1], a_im.val[1]); \ + vfmls(d_im.val[1], d_im.val[1], a_re.val[1], b_im.val[1]); \ + vfmul(d_im.val[2], b_re.val[2], a_im.val[2]); \ + vfmls(d_im.val[2], d_im.val[2], a_re.val[2], b_im.val[2]); \ + vfmul(d_im.val[3], b_re.val[3], a_im.val[3]); 
\ + vfmls(d_im.val[3], d_im.val[3], a_re.val[3], b_im.val[3]); + +#define FPC_MLA_CONJx4(d_re, d_im, a_re, a_im, b_re, b_im) \ + vfmla(d_re.val[0], d_re.val[0], b_im.val[0], a_im.val[0]); \ + vfmla(d_re.val[0], d_re.val[0], a_re.val[0], b_re.val[0]); \ + vfmla(d_re.val[1], d_re.val[1], b_im.val[1], a_im.val[1]); \ + vfmla(d_re.val[1], d_re.val[1], a_re.val[1], b_re.val[1]); \ + vfmla(d_re.val[2], d_re.val[2], b_im.val[2], a_im.val[2]); \ + vfmla(d_re.val[2], d_re.val[2], a_re.val[2], b_re.val[2]); \ + vfmla(d_re.val[3], d_re.val[3], b_im.val[3], a_im.val[3]); \ + vfmla(d_re.val[3], d_re.val[3], a_re.val[3], b_re.val[3]); \ + vfmla(d_im.val[0], d_im.val[0], b_re.val[0], a_im.val[0]); \ + vfmls(d_im.val[0], d_im.val[0], a_re.val[0], b_im.val[0]); \ + vfmla(d_im.val[1], d_im.val[1], b_re.val[1], a_im.val[1]); \ + vfmls(d_im.val[1], d_im.val[1], a_re.val[1], b_im.val[1]); \ + vfmla(d_im.val[2], d_im.val[2], b_re.val[2], a_im.val[2]); \ + vfmls(d_im.val[2], d_im.val[2], a_re.val[2], b_im.val[2]); \ + vfmla(d_im.val[3], d_im.val[3], b_re.val[3], a_im.val[3]); \ + vfmls(d_im.val[3], d_im.val[3], a_re.val[3], b_im.val[3]); + +#define FPC_MUL_LANE(d_re, d_im, a_re, a_im, b_re_im) \ + vfmul_lane(d_re, a_re, b_re_im, 0); \ + vfmls_lane(d_re, d_re, a_im, b_re_im, 1); \ + vfmul_lane(d_im, a_re, b_re_im, 1); \ + vfmla_lane(d_im, d_im, a_im, b_re_im, 0); + +#define FPC_MUL_LANEx4(d_re, d_im, a_re, a_im, b_re_im) \ + vfmul_lane(d_re.val[0], a_re.val[0], b_re_im, 0); \ + vfmls_lane(d_re.val[0], d_re.val[0], a_im.val[0], b_re_im, 1); \ + vfmul_lane(d_re.val[1], a_re.val[1], b_re_im, 0); \ + vfmls_lane(d_re.val[1], d_re.val[1], a_im.val[1], b_re_im, 1); \ + vfmul_lane(d_re.val[2], a_re.val[2], b_re_im, 0); \ + vfmls_lane(d_re.val[2], d_re.val[2], a_im.val[2], b_re_im, 1); \ + vfmul_lane(d_re.val[3], a_re.val[3], b_re_im, 0); \ + vfmls_lane(d_re.val[3], d_re.val[3], a_im.val[3], b_re_im, 1); \ + vfmul_lane(d_im.val[0], a_re.val[0], b_re_im, 1); \ + vfmla_lane(d_im.val[0], d_im.val[0], a_im.val[0], b_re_im, 0); \ + vfmul_lane(d_im.val[1], a_re.val[1], b_re_im, 1); \ + vfmla_lane(d_im.val[1], d_im.val[1], a_im.val[1], b_re_im, 0); \ + vfmul_lane(d_im.val[2], a_re.val[2], b_re_im, 1); \ + vfmla_lane(d_im.val[2], d_im.val[2], a_im.val[2], b_re_im, 0); \ + vfmul_lane(d_im.val[3], a_re.val[3], b_re_im, 1); \ + vfmla_lane(d_im.val[3], d_im.val[3], a_im.val[3], b_re_im, 0); + +#define FWD_TOP(t_re, t_im, b_re, b_im, zeta_re, zeta_im) FPC_MUL(t_re, t_im, b_re, b_im, zeta_re, zeta_im); + +#define FWD_TOP_LANE(t_re, t_im, b_re, b_im, zeta) FPC_MUL_LANE(t_re, t_im, b_re, b_im, zeta); + +#define FWD_TOP_LANEx4(t_re, t_im, b_re, b_im, zeta) FPC_MUL_LANEx4(t_re, t_im, b_re, b_im, zeta); + +/* + * FPC + */ + +#define FPC_SUB(d_re, d_im, a_re, a_im, b_re, b_im) \ + d_re = vsubq_f64(a_re, b_re); \ + d_im = vsubq_f64(a_im, b_im); + +#define FPC_SUBx4(d_re, d_im, a_re, a_im, b_re, b_im) \ + d_re.val[0] = vsubq_f64(a_re.val[0], b_re.val[0]); \ + d_im.val[0] = vsubq_f64(a_im.val[0], b_im.val[0]); \ + d_re.val[1] = vsubq_f64(a_re.val[1], b_re.val[1]); \ + d_im.val[1] = vsubq_f64(a_im.val[1], b_im.val[1]); \ + d_re.val[2] = vsubq_f64(a_re.val[2], b_re.val[2]); \ + d_im.val[2] = vsubq_f64(a_im.val[2], b_im.val[2]); \ + d_re.val[3] = vsubq_f64(a_re.val[3], b_re.val[3]); \ + d_im.val[3] = vsubq_f64(a_im.val[3], b_im.val[3]); + +#define FPC_ADD(d_re, d_im, a_re, a_im, b_re, b_im) \ + d_re = vaddq_f64(a_re, b_re); \ + d_im = vaddq_f64(a_im, b_im); + +#define FPC_ADDx4(d_re, d_im, a_re, a_im, b_re, b_im) \ + d_re.val[0] = 
vaddq_f64(a_re.val[0], b_re.val[0]); \ + d_im.val[0] = vaddq_f64(a_im.val[0], b_im.val[0]); \ + d_re.val[1] = vaddq_f64(a_re.val[1], b_re.val[1]); \ + d_im.val[1] = vaddq_f64(a_im.val[1], b_im.val[1]); \ + d_re.val[2] = vaddq_f64(a_re.val[2], b_re.val[2]); \ + d_im.val[2] = vaddq_f64(a_im.val[2], b_im.val[2]); \ + d_re.val[3] = vaddq_f64(a_re.val[3], b_re.val[3]); \ + d_im.val[3] = vaddq_f64(a_im.val[3], b_im.val[3]); + +#define FWD_BOT(a_re, a_im, b_re, b_im, t_re, t_im) \ + FPC_SUB(b_re, b_im, a_re, a_im, t_re, t_im); \ + FPC_ADD(a_re, a_im, a_re, a_im, t_re, t_im); + +#define FWD_BOTx4(a_re, a_im, b_re, b_im, t_re, t_im) \ + FPC_SUBx4(b_re, b_im, a_re, a_im, t_re, t_im); \ + FPC_ADDx4(a_re, a_im, a_re, a_im, t_re, t_im); + +/* + * FPC_J + */ + +#define FPC_ADDJ(d_re, d_im, a_re, a_im, b_re, b_im) \ + d_re = vsubq_f64(a_re, b_im); \ + d_im = vaddq_f64(a_im, b_re); + +#define FPC_ADDJx4(d_re, d_im, a_re, a_im, b_re, b_im) \ + d_re.val[0] = vsubq_f64(a_re.val[0], b_im.val[0]); \ + d_im.val[0] = vaddq_f64(a_im.val[0], b_re.val[0]); \ + d_re.val[1] = vsubq_f64(a_re.val[1], b_im.val[1]); \ + d_im.val[1] = vaddq_f64(a_im.val[1], b_re.val[1]); \ + d_re.val[2] = vsubq_f64(a_re.val[2], b_im.val[2]); \ + d_im.val[2] = vaddq_f64(a_im.val[2], b_re.val[2]); \ + d_re.val[3] = vsubq_f64(a_re.val[3], b_im.val[3]); \ + d_im.val[3] = vaddq_f64(a_im.val[3], b_re.val[3]); + +#define FPC_SUBJ(d_re, d_im, a_re, a_im, b_re, b_im) \ + d_re = vaddq_f64(a_re, b_im); \ + d_im = vsubq_f64(a_im, b_re); + +#define FPC_SUBJx4(d_re, d_im, a_re, a_im, b_re, b_im) \ + d_re.val[0] = vaddq_f64(a_re.val[0], b_im.val[0]); \ + d_im.val[0] = vsubq_f64(a_im.val[0], b_re.val[0]); \ + d_re.val[1] = vaddq_f64(a_re.val[1], b_im.val[1]); \ + d_im.val[1] = vsubq_f64(a_im.val[1], b_re.val[1]); \ + d_re.val[2] = vaddq_f64(a_re.val[2], b_im.val[2]); \ + d_im.val[2] = vsubq_f64(a_im.val[2], b_re.val[2]); \ + d_re.val[3] = vaddq_f64(a_re.val[3], b_im.val[3]); \ + d_im.val[3] = vsubq_f64(a_im.val[3], b_re.val[3]); + +#define FWD_BOTJ(a_re, a_im, b_re, b_im, t_re, t_im) \ + FPC_SUBJ(b_re, b_im, a_re, a_im, t_re, t_im); \ + FPC_ADDJ(a_re, a_im, a_re, a_im, t_re, t_im); + +#define FWD_BOTJx4(a_re, a_im, b_re, b_im, t_re, t_im) \ + FPC_SUBJx4(b_re, b_im, a_re, a_im, t_re, t_im); \ + FPC_ADDJx4(a_re, a_im, a_re, a_im, t_re, t_im); + +//============== Inverse FFT +/* + * FPC_J + * a * conj(b) + * Original (without swap): + * d_re = b_im * a_im + a_re * b_re; + * d_im = b_re * a_im - a_re * b_im; + */ +#define FPC_MUL_BOTJ_LANE(d_re, d_im, a_re, a_im, b_re_im) \ + vfmul_lane(d_re, a_re, b_re_im, 0); \ + vfmla_lane(d_re, d_re, a_im, b_re_im, 1); \ + vfmul_lane(d_im, a_im, b_re_im, 0); \ + vfmls_lane(d_im, d_im, a_re, b_re_im, 1); + +#define FPC_MUL_BOTJ_LANEx4(d_re, d_im, a_re, a_im, b_re_im) \ + vfmul_lane(d_re.val[0], a_re.val[0], b_re_im, 0); \ + vfmla_lane(d_re.val[0], d_re.val[0], a_im.val[0], b_re_im, 1); \ + vfmul_lane(d_im.val[0], a_im.val[0], b_re_im, 0); \ + vfmls_lane(d_im.val[0], d_im.val[0], a_re.val[0], b_re_im, 1); \ + vfmul_lane(d_re.val[1], a_re.val[1], b_re_im, 0); \ + vfmla_lane(d_re.val[1], d_re.val[1], a_im.val[1], b_re_im, 1); \ + vfmul_lane(d_im.val[1], a_im.val[1], b_re_im, 0); \ + vfmls_lane(d_im.val[1], d_im.val[1], a_re.val[1], b_re_im, 1); \ + vfmul_lane(d_re.val[2], a_re.val[2], b_re_im, 0); \ + vfmla_lane(d_re.val[2], d_re.val[2], a_im.val[2], b_re_im, 1); \ + vfmul_lane(d_im.val[2], a_im.val[2], b_re_im, 0); \ + vfmls_lane(d_im.val[2], d_im.val[2], a_re.val[2], b_re_im, 1); \ + vfmul_lane(d_re.val[3], a_re.val[3], 
b_re_im, 0); \ + vfmla_lane(d_re.val[3], d_re.val[3], a_im.val[3], b_re_im, 1); \ + vfmul_lane(d_im.val[3], a_im.val[3], b_re_im, 0); \ + vfmls_lane(d_im.val[3], d_im.val[3], a_re.val[3], b_re_im, 1); + +#define FPC_MUL_BOTJ(d_re, d_im, a_re, a_im, b_re, b_im) \ + vfmul(d_re, b_im, a_im); \ + vfmla(d_re, d_re, a_re, b_re); \ + vfmul(d_im, b_re, a_im); \ + vfmls(d_im, d_im, a_re, b_im); + +#define INV_TOPJ(t_re, t_im, a_re, a_im, b_re, b_im) \ + FPC_SUB(t_re, t_im, a_re, a_im, b_re, b_im); \ + FPC_ADD(a_re, a_im, a_re, a_im, b_re, b_im); + +#define INV_TOPJx4(t_re, t_im, a_re, a_im, b_re, b_im) \ + FPC_SUBx4(t_re, t_im, a_re, a_im, b_re, b_im); \ + FPC_ADDx4(a_re, a_im, a_re, a_im, b_re, b_im); + +#define INV_BOTJ(b_re, b_im, t_re, t_im, zeta_re, zeta_im) FPC_MUL_BOTJ(b_re, b_im, t_re, t_im, zeta_re, zeta_im); + +#define INV_BOTJ_LANE(b_re, b_im, t_re, t_im, zeta) FPC_MUL_BOTJ_LANE(b_re, b_im, t_re, t_im, zeta); + +#define INV_BOTJ_LANEx4(b_re, b_im, t_re, t_im, zeta) FPC_MUL_BOTJ_LANEx4(b_re, b_im, t_re, t_im, zeta); + +/* + * FPC_Jm + * a * -conj(b) + * d_re = a_re * b_im - a_im * b_re; + * d_im = a_im * b_im + a_re * b_re; + */ +#define FPC_MUL_BOTJm_LANE(d_re, d_im, a_re, a_im, b_re_im) \ + vfmul_lane(d_re, a_re, b_re_im, 1); \ + vfmls_lane(d_re, d_re, a_im, b_re_im, 0); \ + vfmul_lane(d_im, a_re, b_re_im, 0); \ + vfmla_lane(d_im, d_im, a_im, b_re_im, 1); + +#define FPC_MUL_BOTJm_LANEx4(d_re, d_im, a_re, a_im, b_re_im) \ + vfmul_lane(d_re.val[0], a_re.val[0], b_re_im, 1); \ + vfmls_lane(d_re.val[0], d_re.val[0], a_im.val[0], b_re_im, 0); \ + vfmul_lane(d_im.val[0], a_re.val[0], b_re_im, 0); \ + vfmla_lane(d_im.val[0], d_im.val[0], a_im.val[0], b_re_im, 1); \ + vfmul_lane(d_re.val[1], a_re.val[1], b_re_im, 1); \ + vfmls_lane(d_re.val[1], d_re.val[1], a_im.val[1], b_re_im, 0); \ + vfmul_lane(d_im.val[1], a_re.val[1], b_re_im, 0); \ + vfmla_lane(d_im.val[1], d_im.val[1], a_im.val[1], b_re_im, 1); \ + vfmul_lane(d_re.val[2], a_re.val[2], b_re_im, 1); \ + vfmls_lane(d_re.val[2], d_re.val[2], a_im.val[2], b_re_im, 0); \ + vfmul_lane(d_im.val[2], a_re.val[2], b_re_im, 0); \ + vfmla_lane(d_im.val[2], d_im.val[2], a_im.val[2], b_re_im, 1); \ + vfmul_lane(d_re.val[3], a_re.val[3], b_re_im, 1); \ + vfmls_lane(d_re.val[3], d_re.val[3], a_im.val[3], b_re_im, 0); \ + vfmul_lane(d_im.val[3], a_re.val[3], b_re_im, 0); \ + vfmla_lane(d_im.val[3], d_im.val[3], a_im.val[3], b_re_im, 1); + +#define FPC_MUL_BOTJm(d_re, d_im, a_re, a_im, b_re, b_im) \ + vfmul(d_re, a_re, b_im); \ + vfmls(d_re, d_re, a_im, b_re); \ + vfmul(d_im, a_im, b_im); \ + vfmla(d_im, d_im, a_re, b_re); + +#define INV_TOPJm(t_re, t_im, a_re, a_im, b_re, b_im) \ + FPC_SUB(t_re, t_im, b_re, b_im, a_re, a_im); \ + FPC_ADD(a_re, a_im, a_re, a_im, b_re, b_im); + +#define INV_TOPJmx4(t_re, t_im, a_re, a_im, b_re, b_im) \ + FPC_SUBx4(t_re, t_im, b_re, b_im, a_re, a_im); \ + FPC_ADDx4(a_re, a_im, a_re, a_im, b_re, b_im); + +#define INV_BOTJm(b_re, b_im, t_re, t_im, zeta_re, zeta_im) FPC_MUL_BOTJm(b_re, b_im, t_re, t_im, zeta_re, zeta_im); + +#define INV_BOTJm_LANE(b_re, b_im, t_re, t_im, zeta) FPC_MUL_BOTJm_LANE(b_re, b_im, t_re, t_im, zeta); + +#define INV_BOTJm_LANEx4(b_re, b_im, t_re, t_im, zeta) FPC_MUL_BOTJm_LANEx4(b_re, b_im, t_re, t_im, zeta); + +#endif diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_arithmetic.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_arithmetic.h new file mode 100644 index 0000000..31745d0 --- /dev/null +++ 
b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_arithmetic.h @@ -0,0 +1,115 @@ +#ifndef SPQLIOS_Q120_ARITHMETIC_H +#define SPQLIOS_Q120_ARITHMETIC_H + +#include + +#include "../commons.h" +#include "q120_common.h" + +typedef struct _q120_mat1col_product_baa_precomp q120_mat1col_product_baa_precomp; +typedef struct _q120_mat1col_product_bbb_precomp q120_mat1col_product_bbb_precomp; +typedef struct _q120_mat1col_product_bbc_precomp q120_mat1col_product_bbc_precomp; + +EXPORT q120_mat1col_product_baa_precomp* q120_new_vec_mat1col_product_baa_precomp(); +EXPORT void q120_delete_vec_mat1col_product_baa_precomp(q120_mat1col_product_baa_precomp*); +EXPORT q120_mat1col_product_bbb_precomp* q120_new_vec_mat1col_product_bbb_precomp(); +EXPORT void q120_delete_vec_mat1col_product_bbb_precomp(q120_mat1col_product_bbb_precomp*); +EXPORT q120_mat1col_product_bbc_precomp* q120_new_vec_mat1col_product_bbc_precomp(); +EXPORT void q120_delete_vec_mat1col_product_bbc_precomp(q120_mat1col_product_bbc_precomp*); + +// ell < 10000 +EXPORT void q120_vec_mat1col_product_baa_ref(q120_mat1col_product_baa_precomp*, const uint64_t ell, q120b* const res, + const q120a* const x, const q120a* const y); +EXPORT void q120_vec_mat1col_product_bbb_ref(q120_mat1col_product_bbb_precomp*, const uint64_t ell, q120b* const res, + const q120b* const x, const q120b* const y); +EXPORT void q120_vec_mat1col_product_bbc_ref(q120_mat1col_product_bbc_precomp*, const uint64_t ell, q120b* const res, + const q120b* const x, const q120c* const y); + +EXPORT void q120_vec_mat1col_product_baa_avx2(q120_mat1col_product_baa_precomp*, const uint64_t ell, q120b* const res, + const q120a* const x, const q120a* const y); +EXPORT void q120_vec_mat1col_product_bbb_avx2(q120_mat1col_product_bbb_precomp*, const uint64_t ell, q120b* const res, + const q120b* const x, const q120b* const y); +EXPORT void q120_vec_mat1col_product_bbc_avx2(q120_mat1col_product_bbc_precomp*, const uint64_t ell, q120b* const res, + const q120b* const x, const q120c* const y); + +EXPORT void q120x2_vec_mat1col_product_bbc_ref(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y); +EXPORT void q120x2_vec_mat1col_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y); +EXPORT void q120x2_vec_mat2cols_product_bbc_ref(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y); +EXPORT void q120x2_vec_mat2cols_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y); + +/** + * @brief extract 1 q120x2 block from one q120 ntt vectors + * @param nn the size of each vector + * @param blk the block id to extract ( + +#include "q120_arithmetic.h" +#include "q120_arithmetic_private.h" + +EXPORT void q120_vec_mat1col_product_baa_avx2(q120_mat1col_product_baa_precomp* precomp, const uint64_t ell, + q120b* const res, const q120a* const x, const q120a* const y) { + /** + * Algorithm: + * - res = acc1 + acc2 . 
((2^H) % Q) + * - acc1 is the sum of H LSB of products x[i].y[i] + * - acc2 is the sum of 64-H MSB of products x[i]].y[i] + * - for l < 10k acc1 will have H + log2(10000) and acc2 64 - H + log2(10000) bits + * - final sum has max(H, 64 - H + bit_size((2^H) % Q)) + log2(10000) + 1 bits + */ + + const uint64_t H = precomp->h; + const __m256i MASK = _mm256_set1_epi64x((UINT64_C(1) << H) - 1); + + __m256i acc1 = _mm256_setzero_si256(); + __m256i acc2 = _mm256_setzero_si256(); + + const __m256i* x_ptr = (__m256i*)x; + const __m256i* y_ptr = (__m256i*)y; + + for (uint64_t i = 0; i < ell; ++i) { + __m256i a = _mm256_loadu_si256(x_ptr); + __m256i b = _mm256_loadu_si256(y_ptr); + __m256i t = _mm256_mul_epu32(a, b); + + acc1 = _mm256_add_epi64(acc1, _mm256_and_si256(t, MASK)); + acc2 = _mm256_add_epi64(acc2, _mm256_srli_epi64(t, H)); + + x_ptr++; + y_ptr++; + } + + const __m256i H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->h_pow_red); + + __m256i t = _mm256_add_epi64(acc1, _mm256_mul_epu32(acc2, H_POW_RED)); + _mm256_storeu_si256((__m256i*)res, t); +} + +EXPORT void q120_vec_mat1col_product_bbb_avx2(q120_mat1col_product_bbb_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120b* const y) { + /** + * Algorithm: + * 1. Split x_i and y_i in 2 32-bit parts and compute the cross-products: + * - x_i = xl_i + xh_i . 2^32 + * - y_i = yl_i + yh_i . 2^32 + * - A_i = xl_i . yl_i + * - B_i = xl_i . yh_i + * - C_i = xh_i . yl_i + * - D_i = xh_i . yh_i + * - we have x_i . y_i == A_i + (B_i + C_i) . 2^32 + D_i . 2^64 + * 2. Split A_i, B_i, C_i and D_i into 2 32-bit parts + * - A_i = Al_i + Ah_i . 2^32 + * - B_i = Bl_i + Bh_i . 2^32 + * - C_i = Cl_i + Ch_i . 2^32 + * - D_i = Dl_i + Dh_i . 2^32 + * 3. Compute the sums: + * - S1 = \sum Al_i + * - S2 = \sum (Ah_i + Bl_i + Cl_i) + * - S3 = \sum (Bh_i + Ch_i + Dl_i) + * - S4 = \sum Dh_i + * - here S1, S4 have 32 + log2(ell) bits and S2, S3 have 32 + log2(ell) + + * log2(3) bits + * - for ell == 10000 S2, S3 have < 47 bits + * 4. Split S1, S2, S3 and S4 in 2 24-bit parts (24 = ceil(47/2)) + * - S1 = S1l + S1h . 2^24 + * - S2 = S2l + S2h . 2^24 + * - S3 = S3l + S3h . 2^24 + * - S4 = S4l + S4h . 2^24 + * 5. Compute final result as: + * - \sum x_i . y_i = S1l + S1h . 2^24 + * + S2l . 2^32 + S2h . 2^(32+24) + * + S3l . 2^64 + S3h . 2^(64 + 24) + * + S4l . 2^96 + S4l . 
2^(96+24) + * - here the powers of 2 are reduced modulo the primes Q before + * multiplications + * - the result will be on 24 + 3 + bit size of primes Q + */ + const uint64_t H1 = 32; + const __m256i MASK1 = _mm256_set1_epi64x((UINT64_C(1) << H1) - 1); + + __m256i s1 = _mm256_setzero_si256(); + __m256i s2 = _mm256_setzero_si256(); + __m256i s3 = _mm256_setzero_si256(); + __m256i s4 = _mm256_setzero_si256(); + + const __m256i* x_ptr = (__m256i*)x; + const __m256i* y_ptr = (__m256i*)y; + + for (uint64_t i = 0; i < ell; ++i) { + __m256i x = _mm256_loadu_si256(x_ptr); + __m256i xl = _mm256_and_si256(x, MASK1); + __m256i xh = _mm256_srli_epi64(x, H1); + + __m256i y = _mm256_loadu_si256(y_ptr); + __m256i yl = _mm256_and_si256(y, MASK1); + __m256i yh = _mm256_srli_epi64(y, H1); + + __m256i a = _mm256_mul_epu32(xl, yl); + __m256i b = _mm256_mul_epu32(xl, yh); + __m256i c = _mm256_mul_epu32(xh, yl); + __m256i d = _mm256_mul_epu32(xh, yh); + + s1 = _mm256_add_epi64(s1, _mm256_and_si256(a, MASK1)); + + s2 = _mm256_add_epi64(s2, _mm256_srli_epi64(a, H1)); + s2 = _mm256_add_epi64(s2, _mm256_and_si256(b, MASK1)); + s2 = _mm256_add_epi64(s2, _mm256_and_si256(c, MASK1)); + + s3 = _mm256_add_epi64(s3, _mm256_srli_epi64(b, H1)); + s3 = _mm256_add_epi64(s3, _mm256_srli_epi64(c, H1)); + s3 = _mm256_add_epi64(s3, _mm256_and_si256(d, MASK1)); + + s4 = _mm256_add_epi64(s4, _mm256_srli_epi64(d, H1)); + + x_ptr++; + y_ptr++; + } + + const uint64_t H2 = precomp->h; + const __m256i MASK2 = _mm256_set1_epi64x((UINT64_C(1) << H2) - 1); + + const __m256i S1H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s1h_pow_red); + __m256i s1l = _mm256_and_si256(s1, MASK2); + __m256i s1h = _mm256_srli_epi64(s1, H2); + __m256i t = _mm256_add_epi64(s1l, _mm256_mul_epu32(s1h, S1H_POW_RED)); + + const __m256i S2L_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s2l_pow_red); + const __m256i S2H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s2h_pow_red); + __m256i s2l = _mm256_and_si256(s2, MASK2); + __m256i s2h = _mm256_srli_epi64(s2, H2); + t = _mm256_add_epi64(t, _mm256_mul_epu32(s2l, S2L_POW_RED)); + t = _mm256_add_epi64(t, _mm256_mul_epu32(s2h, S2H_POW_RED)); + + const __m256i S3L_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s3l_pow_red); + const __m256i S3H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s3h_pow_red); + __m256i s3l = _mm256_and_si256(s3, MASK2); + __m256i s3h = _mm256_srli_epi64(s3, H2); + t = _mm256_add_epi64(t, _mm256_mul_epu32(s3l, S3L_POW_RED)); + t = _mm256_add_epi64(t, _mm256_mul_epu32(s3h, S3H_POW_RED)); + + const __m256i S4L_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s4l_pow_red); + const __m256i S4H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s4h_pow_red); + __m256i s4l = _mm256_and_si256(s4, MASK2); + __m256i s4h = _mm256_srli_epi64(s4, H2); + t = _mm256_add_epi64(t, _mm256_mul_epu32(s4l, S4L_POW_RED)); + t = _mm256_add_epi64(t, _mm256_mul_epu32(s4h, S4H_POW_RED)); + + _mm256_storeu_si256((__m256i*)res, t); +} + +EXPORT void q120_vec_mat1col_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y) { + /** + * Algorithm: + * 0. We have + * - y0_i == y_i % Q and y1_i == (y_i . 2^32) % Q + * 1. Split x_i in 2 32-bit parts and compute the cross-products: + * - x_i = xl_i + xh_i . 2^32 + * - A_i = xl_i . y1_i + * - B_i = xh_i . y2_i + * - we have x_i . y_i == A_i + B_i + * 2. Split A_i and B_i into 2 32-bit parts + * - A_i = Al_i + Ah_i . 2^32 + * - B_i = Bl_i + Bh_i . 2^32 + * 3. 
Compute the sums: + * - S1 = \sum Al_i + Bl_i + * - S2 = \sum Ah_i + Bh_i + * - here S1 and S2 have 32 + log2(ell) bits + * - for ell == 10000 S1, S2 have < 46 bits + * 4. Split S2 in 27-bit and 19-bit parts (27+19 == 46) + * - S2 = S2l + S2h . 2^27 + * 5. Compute final result as: + * - \sum x_i . y_i = S1 + S2l . 2^32 + S2h . 2^(32+27) + * - here the powers of 2 are reduced modulo the primes Q before + * multiplications + * - the result will be on < 52 bits + */ + + const uint64_t H1 = 32; + const __m256i MASK1 = _mm256_set1_epi64x((UINT64_C(1) << H1) - 1); + + __m256i s1 = _mm256_setzero_si256(); + __m256i s2 = _mm256_setzero_si256(); + + const __m256i* x_ptr = (__m256i*)x; + const __m256i* y_ptr = (__m256i*)y; + + for (uint64_t i = 0; i < ell; ++i) { + __m256i x = _mm256_loadu_si256(x_ptr); + __m256i xl = _mm256_and_si256(x, MASK1); + __m256i xh = _mm256_srli_epi64(x, H1); + + __m256i y = _mm256_loadu_si256(y_ptr); + __m256i y0 = _mm256_and_si256(y, MASK1); + __m256i y1 = _mm256_srli_epi64(y, H1); + + __m256i a = _mm256_mul_epu32(xl, y0); + __m256i b = _mm256_mul_epu32(xh, y1); + + s1 = _mm256_add_epi64(s1, _mm256_and_si256(a, MASK1)); + s1 = _mm256_add_epi64(s1, _mm256_and_si256(b, MASK1)); + + s2 = _mm256_add_epi64(s2, _mm256_srli_epi64(a, H1)); + s2 = _mm256_add_epi64(s2, _mm256_srli_epi64(b, H1)); + + x_ptr++; + y_ptr++; + } + + const uint64_t H2 = precomp->h; + const __m256i MASK2 = _mm256_set1_epi64x((UINT64_C(1) << H2) - 1); + + __m256i t = s1; + + const __m256i S2L_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s2l_pow_red); + const __m256i S2H_POW_RED = _mm256_loadu_si256((__m256i*)precomp->s2h_pow_red); + __m256i s2l = _mm256_and_si256(s2, MASK2); + __m256i s2h = _mm256_srli_epi64(s2, H2); + t = _mm256_add_epi64(t, _mm256_mul_epu32(s2l, S2L_POW_RED)); + t = _mm256_add_epi64(t, _mm256_mul_epu32(s2h, S2H_POW_RED)); + + _mm256_storeu_si256((__m256i*)res, t); +} + +/** + * @deprecated keeping this one for history only. + * There is a slight register starvation condition on the q120x2_vec_mat2cols + * strategy below sounds better. 
+ */ +EXPORT void q120x2_vec_mat2cols_product_bbc_avx2_old(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y) { + __m256i s0 = _mm256_setzero_si256(); // col 1a + __m256i s1 = _mm256_setzero_si256(); + __m256i s2 = _mm256_setzero_si256(); // col 1b + __m256i s3 = _mm256_setzero_si256(); + __m256i s4 = _mm256_setzero_si256(); // col 2a + __m256i s5 = _mm256_setzero_si256(); + __m256i s6 = _mm256_setzero_si256(); // col 2b + __m256i s7 = _mm256_setzero_si256(); + __m256i s8, s9, s10, s11; + __m256i s12, s13, s14, s15; + + const __m256i* x_ptr = (__m256i*)x; + const __m256i* y_ptr = (__m256i*)y; + __m256i* res_ptr = (__m256i*)res; + for (uint64_t i = 0; i < ell; ++i) { + s8 = _mm256_loadu_si256(x_ptr); + s9 = _mm256_loadu_si256(x_ptr + 1); + s10 = _mm256_srli_epi64(s8, 32); + s11 = _mm256_srli_epi64(s9, 32); + + s12 = _mm256_loadu_si256(y_ptr); + s13 = _mm256_loadu_si256(y_ptr + 1); + s14 = _mm256_srli_epi64(s12, 32); + s15 = _mm256_srli_epi64(s13, 32); + + s12 = _mm256_mul_epu32(s8, s12); // -> s0,s1 + s13 = _mm256_mul_epu32(s9, s13); // -> s2,s3 + s14 = _mm256_mul_epu32(s10, s14); // -> s0,s1 + s15 = _mm256_mul_epu32(s11, s15); // -> s2,s3 + + s10 = _mm256_slli_epi64(s12, 32); // -> s0 + s11 = _mm256_slli_epi64(s13, 32); // -> s2 + s12 = _mm256_srli_epi64(s12, 32); // -> s1 + s13 = _mm256_srli_epi64(s13, 32); // -> s3 + s10 = _mm256_srli_epi64(s10, 32); // -> s0 + s11 = _mm256_srli_epi64(s11, 32); // -> s2 + + s0 = _mm256_add_epi64(s0, s10); + s1 = _mm256_add_epi64(s1, s12); + s2 = _mm256_add_epi64(s2, s11); + s3 = _mm256_add_epi64(s3, s13); + + s10 = _mm256_slli_epi64(s14, 32); // -> s0 + s11 = _mm256_slli_epi64(s15, 32); // -> s2 + s14 = _mm256_srli_epi64(s14, 32); // -> s1 + s15 = _mm256_srli_epi64(s15, 32); // -> s3 + s10 = _mm256_srli_epi64(s10, 32); // -> s0 + s11 = _mm256_srli_epi64(s11, 32); // -> s2 + + s0 = _mm256_add_epi64(s0, s10); + s1 = _mm256_add_epi64(s1, s14); + s2 = _mm256_add_epi64(s2, s11); + s3 = _mm256_add_epi64(s3, s15); + + // deal with the second column + // s8,s9 are still in place! 
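    // Second column, same splitting trick as for the first one: each _mm256_mul_epu32 below is a 32x32 -> 64-bit product whose low 32 bits are accumulated into s4/s6 and whose high 32 bits into s5/s7, mirroring the scalar reference accum_mul_q120_bc() in q120_arithmetic_ref.c, which also keeps separate low/high partial sums until the final modular folding.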
+ s10 = _mm256_srli_epi64(s8, 32); + s11 = _mm256_srli_epi64(s9, 32); + + s12 = _mm256_loadu_si256(y_ptr + 2); + s13 = _mm256_loadu_si256(y_ptr + 3); + s14 = _mm256_srli_epi64(s12, 32); + s15 = _mm256_srli_epi64(s13, 32); + + s12 = _mm256_mul_epu32(s8, s12); // -> s4,s5 + s13 = _mm256_mul_epu32(s9, s13); // -> s6,s7 + s14 = _mm256_mul_epu32(s10, s14); // -> s4,s5 + s15 = _mm256_mul_epu32(s11, s15); // -> s6,s7 + + s10 = _mm256_slli_epi64(s12, 32); // -> s4 + s11 = _mm256_slli_epi64(s13, 32); // -> s6 + s12 = _mm256_srli_epi64(s12, 32); // -> s5 + s13 = _mm256_srli_epi64(s13, 32); // -> s7 + s10 = _mm256_srli_epi64(s10, 32); // -> s4 + s11 = _mm256_srli_epi64(s11, 32); // -> s6 + + s4 = _mm256_add_epi64(s4, s10); + s5 = _mm256_add_epi64(s5, s12); + s6 = _mm256_add_epi64(s6, s11); + s7 = _mm256_add_epi64(s7, s13); + + s10 = _mm256_slli_epi64(s14, 32); // -> s4 + s11 = _mm256_slli_epi64(s15, 32); // -> s6 + s14 = _mm256_srli_epi64(s14, 32); // -> s5 + s15 = _mm256_srli_epi64(s15, 32); // -> s7 + s10 = _mm256_srli_epi64(s10, 32); // -> s4 + s11 = _mm256_srli_epi64(s11, 32); // -> s6 + + s4 = _mm256_add_epi64(s4, s10); + s5 = _mm256_add_epi64(s5, s14); + s6 = _mm256_add_epi64(s6, s11); + s7 = _mm256_add_epi64(s7, s15); + + x_ptr += 2; + y_ptr += 4; + } + // final reduction + const uint64_t H2 = precomp->h; + s8 = _mm256_set1_epi64x((UINT64_C(1) << H2) - 1); // MASK2 + s9 = _mm256_loadu_si256((__m256i*)precomp->s2l_pow_red); // S2L_POW_RED + s10 = _mm256_loadu_si256((__m256i*)precomp->s2h_pow_red); // S2H_POW_RED + //--- s0,s1 + s11 = _mm256_and_si256(s1, s8); + s12 = _mm256_srli_epi64(s1, H2); + s13 = _mm256_mul_epu32(s11, s9); + s14 = _mm256_mul_epu32(s12, s10); + s0 = _mm256_add_epi64(s0, s13); + s0 = _mm256_add_epi64(s0, s14); + _mm256_storeu_si256(res_ptr + 0, s0); + //--- s2,s3 + s11 = _mm256_and_si256(s3, s8); + s12 = _mm256_srli_epi64(s3, H2); + s13 = _mm256_mul_epu32(s11, s9); + s14 = _mm256_mul_epu32(s12, s10); + s2 = _mm256_add_epi64(s2, s13); + s2 = _mm256_add_epi64(s2, s14); + _mm256_storeu_si256(res_ptr + 1, s2); + //--- s4,s5 + s11 = _mm256_and_si256(s5, s8); + s12 = _mm256_srli_epi64(s5, H2); + s13 = _mm256_mul_epu32(s11, s9); + s14 = _mm256_mul_epu32(s12, s10); + s4 = _mm256_add_epi64(s4, s13); + s4 = _mm256_add_epi64(s4, s14); + _mm256_storeu_si256(res_ptr + 2, s4); + //--- s6,s7 + s11 = _mm256_and_si256(s7, s8); + s12 = _mm256_srli_epi64(s7, H2); + s13 = _mm256_mul_epu32(s11, s9); + s14 = _mm256_mul_epu32(s12, s10); + s6 = _mm256_add_epi64(s6, s13); + s6 = _mm256_add_epi64(s6, s14); + _mm256_storeu_si256(res_ptr + 3, s6); +} + +EXPORT void q120x2_vec_mat2cols_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y) { + __m256i s0 = _mm256_setzero_si256(); // col 1a + __m256i s1 = _mm256_setzero_si256(); + __m256i s2 = _mm256_setzero_si256(); // col 1b + __m256i s3 = _mm256_setzero_si256(); + __m256i s4 = _mm256_setzero_si256(); // col 2a + __m256i s5 = _mm256_setzero_si256(); + __m256i s6 = _mm256_setzero_si256(); // col 2b + __m256i s7 = _mm256_setzero_si256(); + __m256i s8, s9, s10, s11; + __m256i s12, s13, s14, s15; + + s11 = _mm256_set1_epi64x(0xFFFFFFFFUL); + const __m256i* x_ptr = (__m256i*)x; + const __m256i* y_ptr = (__m256i*)y; + __m256i* res_ptr = (__m256i*)res; + for (uint64_t i = 0; i < ell; ++i) { + // treat item a + s8 = _mm256_loadu_si256(x_ptr); + s9 = _mm256_srli_epi64(s8, 32); + + s12 = _mm256_loadu_si256(y_ptr); + s13 = _mm256_loadu_si256(y_ptr + 2); + s14 = 
_mm256_srli_epi64(s12, 32); + s15 = _mm256_srli_epi64(s13, 32); + + s12 = _mm256_mul_epu32(s8, s12); // c1a -> s0,s1 + s13 = _mm256_mul_epu32(s8, s13); // c2a -> s4,s5 + s14 = _mm256_mul_epu32(s9, s14); // c1a -> s0,s1 + s15 = _mm256_mul_epu32(s9, s15); // c2a -> s4,s5 + + s8 = _mm256_and_si256(s12, s11); // -> s0 + s9 = _mm256_and_si256(s13, s11); // -> s4 + s12 = _mm256_srli_epi64(s12, 32); // -> s1 + s13 = _mm256_srli_epi64(s13, 32); // -> s5 + s0 = _mm256_add_epi64(s0, s8); + s1 = _mm256_add_epi64(s1, s12); + s4 = _mm256_add_epi64(s4, s9); + s5 = _mm256_add_epi64(s5, s13); + + s8 = _mm256_and_si256(s14, s11); // -> s0 + s9 = _mm256_and_si256(s15, s11); // -> s4 + s14 = _mm256_srli_epi64(s14, 32); // -> s1 + s15 = _mm256_srli_epi64(s15, 32); // -> s5 + s0 = _mm256_add_epi64(s0, s8); + s1 = _mm256_add_epi64(s1, s14); + s4 = _mm256_add_epi64(s4, s9); + s5 = _mm256_add_epi64(s5, s15); + + // treat item b + s8 = _mm256_loadu_si256(x_ptr + 1); + s9 = _mm256_srli_epi64(s8, 32); + + s12 = _mm256_loadu_si256(y_ptr + 1); + s13 = _mm256_loadu_si256(y_ptr + 3); + s14 = _mm256_srli_epi64(s12, 32); + s15 = _mm256_srli_epi64(s13, 32); + + s12 = _mm256_mul_epu32(s8, s12); // c1b -> s2,s3 + s13 = _mm256_mul_epu32(s8, s13); // c2b -> s6,s7 + s14 = _mm256_mul_epu32(s9, s14); // c1b -> s2,s3 + s15 = _mm256_mul_epu32(s9, s15); // c2b -> s6,s7 + + s8 = _mm256_and_si256(s12, s11); // -> s2 + s9 = _mm256_and_si256(s13, s11); // -> s6 + s12 = _mm256_srli_epi64(s12, 32); // -> s3 + s13 = _mm256_srli_epi64(s13, 32); // -> s7 + s2 = _mm256_add_epi64(s2, s8); + s3 = _mm256_add_epi64(s3, s12); + s6 = _mm256_add_epi64(s6, s9); + s7 = _mm256_add_epi64(s7, s13); + + s8 = _mm256_and_si256(s14, s11); // -> s2 + s9 = _mm256_and_si256(s15, s11); // -> s6 + s14 = _mm256_srli_epi64(s14, 32); // -> s3 + s15 = _mm256_srli_epi64(s15, 32); // -> s7 + s2 = _mm256_add_epi64(s2, s8); + s3 = _mm256_add_epi64(s3, s14); + s6 = _mm256_add_epi64(s6, s9); + s7 = _mm256_add_epi64(s7, s15); + + x_ptr += 2; + y_ptr += 4; + } + // final reduction + const uint64_t H2 = precomp->h; + s8 = _mm256_set1_epi64x((UINT64_C(1) << H2) - 1); // MASK2 + s9 = _mm256_loadu_si256((__m256i*)precomp->s2l_pow_red); // S2L_POW_RED + s10 = _mm256_loadu_si256((__m256i*)precomp->s2h_pow_red); // S2H_POW_RED + //--- s0,s1 + s11 = _mm256_and_si256(s1, s8); + s12 = _mm256_srli_epi64(s1, H2); + s13 = _mm256_mul_epu32(s11, s9); + s14 = _mm256_mul_epu32(s12, s10); + s0 = _mm256_add_epi64(s0, s13); + s0 = _mm256_add_epi64(s0, s14); + _mm256_storeu_si256(res_ptr + 0, s0); + //--- s2,s3 + s11 = _mm256_and_si256(s3, s8); + s12 = _mm256_srli_epi64(s3, H2); + s13 = _mm256_mul_epu32(s11, s9); + s14 = _mm256_mul_epu32(s12, s10); + s2 = _mm256_add_epi64(s2, s13); + s2 = _mm256_add_epi64(s2, s14); + _mm256_storeu_si256(res_ptr + 1, s2); + //--- s4,s5 + s11 = _mm256_and_si256(s5, s8); + s12 = _mm256_srli_epi64(s5, H2); + s13 = _mm256_mul_epu32(s11, s9); + s14 = _mm256_mul_epu32(s12, s10); + s4 = _mm256_add_epi64(s4, s13); + s4 = _mm256_add_epi64(s4, s14); + _mm256_storeu_si256(res_ptr + 2, s4); + //--- s6,s7 + s11 = _mm256_and_si256(s7, s8); + s12 = _mm256_srli_epi64(s7, H2); + s13 = _mm256_mul_epu32(s11, s9); + s14 = _mm256_mul_epu32(s12, s10); + s6 = _mm256_add_epi64(s6, s13); + s6 = _mm256_add_epi64(s6, s14); + _mm256_storeu_si256(res_ptr + 3, s6); +} + +EXPORT void q120x2_vec_mat1col_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y) { + __m256i s0 = _mm256_setzero_si256(); // col 1a 
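    // Accumulator layout: s0/s1 hold the low/high 32-bit partial sums for the first q120 element of each x2 pair, and s2/s3 for the second; the "col 1a" / "col 1b" labels match the two-column variant above, but this routine computes a single column (see q120x2_vec_mat1col_product_bbc_ref).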
+ __m256i s1 = _mm256_setzero_si256(); + __m256i s2 = _mm256_setzero_si256(); // col 1b + __m256i s3 = _mm256_setzero_si256(); + __m256i s4 = _mm256_set1_epi64x(0xFFFFFFFFUL); + __m256i s8, s9, s10, s11; + __m256i s12, s13, s14, s15; + + const __m256i* x_ptr = (__m256i*)x; + const __m256i* y_ptr = (__m256i*)y; + __m256i* res_ptr = (__m256i*)res; + for (uint64_t i = 0; i < ell; ++i) { + s8 = _mm256_loadu_si256(x_ptr); + s9 = _mm256_loadu_si256(x_ptr + 1); + s10 = _mm256_srli_epi64(s8, 32); + s11 = _mm256_srli_epi64(s9, 32); + + s12 = _mm256_loadu_si256(y_ptr); + s13 = _mm256_loadu_si256(y_ptr + 1); + s14 = _mm256_srli_epi64(s12, 32); + s15 = _mm256_srli_epi64(s13, 32); + + s12 = _mm256_mul_epu32(s8, s12); // -> s0,s1 + s13 = _mm256_mul_epu32(s9, s13); // -> s2,s3 + s14 = _mm256_mul_epu32(s10, s14); // -> s0,s1 + s15 = _mm256_mul_epu32(s11, s15); // -> s2,s3 + + s8 = _mm256_and_si256(s12, s4); // -> s0 + s9 = _mm256_and_si256(s13, s4); // -> s2 + s10 = _mm256_and_si256(s14, s4); // -> s0 + s11 = _mm256_and_si256(s15, s4); // -> s2 + s12 = _mm256_srli_epi64(s12, 32); // -> s1 + s13 = _mm256_srli_epi64(s13, 32); // -> s3 + s14 = _mm256_srli_epi64(s14, 32); // -> s1 + s15 = _mm256_srli_epi64(s15, 32); // -> s3 + + s0 = _mm256_add_epi64(s0, s8); + s1 = _mm256_add_epi64(s1, s12); + s2 = _mm256_add_epi64(s2, s9); + s3 = _mm256_add_epi64(s3, s13); + s0 = _mm256_add_epi64(s0, s10); + s1 = _mm256_add_epi64(s1, s14); + s2 = _mm256_add_epi64(s2, s11); + s3 = _mm256_add_epi64(s3, s15); + + x_ptr += 2; + y_ptr += 2; + } + // final reduction + const uint64_t H2 = precomp->h; + s8 = _mm256_set1_epi64x((UINT64_C(1) << H2) - 1); // MASK2 + s9 = _mm256_loadu_si256((__m256i*)precomp->s2l_pow_red); // S2L_POW_RED + s10 = _mm256_loadu_si256((__m256i*)precomp->s2h_pow_red); // S2H_POW_RED + //--- s0,s1 + s11 = _mm256_and_si256(s1, s8); + s12 = _mm256_srli_epi64(s1, H2); + s13 = _mm256_mul_epu32(s11, s9); + s14 = _mm256_mul_epu32(s12, s10); + s0 = _mm256_add_epi64(s0, s13); + s0 = _mm256_add_epi64(s0, s14); + _mm256_storeu_si256(res_ptr + 0, s0); + //--- s2,s3 + s11 = _mm256_and_si256(s3, s8); + s12 = _mm256_srli_epi64(s3, H2); + s13 = _mm256_mul_epu32(s11, s9); + s14 = _mm256_mul_epu32(s12, s10); + s2 = _mm256_add_epi64(s2, s13); + s2 = _mm256_add_epi64(s2, s14); + _mm256_storeu_si256(res_ptr + 1, s2); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_arithmetic_private.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_arithmetic_private.h new file mode 100644 index 0000000..399f989 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_arithmetic_private.h @@ -0,0 +1,37 @@ +#ifndef SPQLIOS_Q120_ARITHMETIC_DEF_H +#define SPQLIOS_Q120_ARITHMETIC_DEF_H + +#include + +typedef struct _q120_mat1col_product_baa_precomp { + uint64_t h; + uint64_t h_pow_red[4]; +#ifndef NDEBUG + double res_bit_size; +#endif +} q120_mat1col_product_baa_precomp; + +typedef struct _q120_mat1col_product_bbb_precomp { + uint64_t h; + uint64_t s1h_pow_red[4]; + uint64_t s2l_pow_red[4]; + uint64_t s2h_pow_red[4]; + uint64_t s3l_pow_red[4]; + uint64_t s3h_pow_red[4]; + uint64_t s4l_pow_red[4]; + uint64_t s4h_pow_red[4]; +#ifndef NDEBUG + double res_bit_size; +#endif +} q120_mat1col_product_bbb_precomp; + +typedef struct _q120_mat1col_product_bbc_precomp { + uint64_t h; + uint64_t s2l_pow_red[4]; + uint64_t s2h_pow_red[4]; +#ifndef NDEBUG + double res_bit_size; +#endif +} q120_mat1col_product_bbc_precomp; + +#endif // SPQLIOS_Q120_ARITHMETIC_DEF_H diff --git 
a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_arithmetic_ref.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_arithmetic_ref.c new file mode 100644 index 0000000..cd20f10 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_arithmetic_ref.c @@ -0,0 +1,506 @@ +#include +#include +#include + +#include "q120_arithmetic.h" +#include "q120_arithmetic_private.h" +#include "q120_common.h" + +#define MODQ(val, q) ((val) % (q)) + +double comp_bit_size_red(const uint64_t h, const uint64_t qs[4]) { + assert(h < 128); + double h_pow2_bs = 0; + for (uint64_t k = 0; k < 4; ++k) { + double t = log2((double)MODQ((__uint128_t)1 << h, qs[k])); + if (t > h_pow2_bs) h_pow2_bs = t; + } + return h_pow2_bs; +} + +double comp_bit_size_sum(const uint64_t n, const double* const bs) { + double s = 0; + for (uint64_t i = 0; i < n; ++i) { + s += pow(2, bs[i]); + } + return log2(s); +} + +void vec_mat1col_product_baa_precomp(q120_mat1col_product_baa_precomp* precomp) { + uint64_t qs[4] = {Q1, Q2, Q3, Q4}; + + double min_res_bs = 1000; + uint64_t min_h = -1; + + double ell_bs = log2((double)MAX_ELL); + for (uint64_t h = 1; h < 64; ++h) { + double h_pow2_bs = comp_bit_size_red(h, qs); + + const double bs[] = {h + ell_bs, 64 - h + ell_bs + h_pow2_bs}; + const double res_bs = comp_bit_size_sum(2, bs); + + if (min_res_bs > res_bs) { + min_res_bs = res_bs; + min_h = h; + } + } + + assert(min_res_bs < 64); + precomp->h = min_h; + for (uint64_t k = 0; k < 4; ++k) { + precomp->h_pow_red[k] = MODQ(UINT64_C(1) << precomp->h, qs[k]); + } +#ifndef NDEBUG + precomp->res_bit_size = min_res_bs; +#endif + // printf("AA %lu %lf\n", min_h, min_res_bs); +} + +EXPORT q120_mat1col_product_baa_precomp* q120_new_vec_mat1col_product_baa_precomp() { + q120_mat1col_product_baa_precomp* res = malloc(sizeof(q120_mat1col_product_baa_precomp)); + vec_mat1col_product_baa_precomp(res); + return res; +} + +EXPORT void q120_delete_vec_mat1col_product_baa_precomp(q120_mat1col_product_baa_precomp* addr) { free(addr); } + +EXPORT void q120_vec_mat1col_product_baa_ref(q120_mat1col_product_baa_precomp* precomp, const uint64_t ell, + q120b* const res, const q120a* const x, const q120a* const y) { + /** + * Algorithm: + * - res = acc1 + acc2 . 
((2^H) % Q) + * - acc1 is the sum of H LSB of products x[i].y[i] + * - acc2 is the sum of 64-H MSB of products x[i].y[i] + * - for l < 10k acc1 will have H + log2(10000) and acc2 64 - H + log2(10000) bits + * - final sum has max(H, 64 - H + bit_size((2^H) % Q)) + log2(10000) + 1 bits + */ + const uint64_t H = precomp->h; + const uint64_t MASK = (UINT64_C(1) << H) - 1; + + uint64_t acc1[4] = {0, 0, 0, 0}; // accumulate H least significant bits of product + uint64_t acc2[4] = {0, 0, 0, 0}; // accumulate 64 - H most significant bits of product + + const uint64_t* const x_ptr = (uint64_t*)x; + const uint64_t* const y_ptr = (uint64_t*)y; + + for (uint64_t i = 0; i < 4 * ell; i += 4) { + for (uint64_t j = 0; j < 4; ++j) { + uint64_t t = x_ptr[i + j] * y_ptr[i + j]; + acc1[j] += t & MASK; + acc2[j] += t >> H; + } + } + + uint64_t* const res_ptr = (uint64_t*)res; + for (uint64_t j = 0; j < 4; ++j) { + res_ptr[j] = acc1[j] + acc2[j] * precomp->h_pow_red[j]; + assert(log2(res_ptr[j]) < precomp->res_bit_size); + } +} + +void vec_mat1col_product_bbb_precomp(q120_mat1col_product_bbb_precomp* precomp) { + uint64_t qs[4] = {Q1, Q2, Q3, Q4}; + + double ell_bs = log2((double)MAX_ELL); + double min_res_bs = 1000; + uint64_t min_h = -1; + + const double s1_bs = 32 + ell_bs; + const double s2_bs = 32 + ell_bs + log2(3); + const double s3_bs = 32 + ell_bs + log2(3); + const double s4_bs = 32 + ell_bs; + for (uint64_t h = 16; h < 32; ++h) { + const double s1l_bs = h; + const double s1h_bs = (s1_bs - h) + comp_bit_size_red(h, qs); + const double s2l_bs = h + comp_bit_size_red(32, qs); + const double s2h_bs = (s2_bs - h) + comp_bit_size_red(32 + h, qs); + const double s3l_bs = h + comp_bit_size_red(64, qs); + const double s3h_bs = (s3_bs - h) + comp_bit_size_red(64 + h, qs); + const double s4l_bs = h + comp_bit_size_red(96, qs); + const double s4h_bs = (s4_bs - h) + comp_bit_size_red(96 + h, qs); + + const double bs[] = {s1l_bs, s1h_bs, s2l_bs, s2h_bs, s3l_bs, s3h_bs, s4l_bs, s4h_bs}; + const double res_bs = comp_bit_size_sum(8, bs); + + if (min_res_bs > res_bs) { + min_res_bs = res_bs; + min_h = h; + } + } + + assert(min_res_bs < 64); + precomp->h = min_h; + for (uint64_t k = 0; k < 4; ++k) { + precomp->s1h_pow_red[k] = UINT64_C(1) << precomp->h; // 2^24 + precomp->s2l_pow_red[k] = MODQ(UINT64_C(1) << 32, qs[k]); // 2^32 + precomp->s2h_pow_red[k] = MODQ(precomp->s2l_pow_red[k] * precomp->s1h_pow_red[k], qs[k]); // 2^(32+24) + precomp->s3l_pow_red[k] = MODQ(precomp->s2l_pow_red[k] * precomp->s2l_pow_red[k], qs[k]); // 2^64 = 2^(32+32) + precomp->s3h_pow_red[k] = MODQ(precomp->s3l_pow_red[k] * precomp->s1h_pow_red[k], qs[k]); // 2^(64+24) + precomp->s4l_pow_red[k] = MODQ(precomp->s3l_pow_red[k] * precomp->s2l_pow_red[k], qs[k]); // 2^96 = 2^(64+32) + precomp->s4h_pow_red[k] = MODQ(precomp->s4l_pow_red[k] * precomp->s1h_pow_red[k], qs[k]); // 2^(96+24) + } +// printf("AA %lu %lf\n", min_h, min_res_bs); +#ifndef NDEBUG + precomp->res_bit_size = min_res_bs; +#endif +} + +EXPORT q120_mat1col_product_bbb_precomp* q120_new_vec_mat1col_product_bbb_precomp() { + q120_mat1col_product_bbb_precomp* res = malloc(sizeof(q120_mat1col_product_bbb_precomp)); + vec_mat1col_product_bbb_precomp(res); + return res; +} + +EXPORT void q120_delete_vec_mat1col_product_bbb_precomp(q120_mat1col_product_bbb_precomp* addr) { free(addr); } + +EXPORT void q120_vec_mat1col_product_bbb_ref(q120_mat1col_product_bbb_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120b* const y) { + /** + * Algorithm: + * 1. 
Split x_i and y_i in 2 32-bit parts and compute the cross-products: + * - x_i = xl_i + xh_i . 2^32 + * - y_i = yl_i + yh_i . 2^32 + * - A_i = xl_i . yl_i + * - B_i = xl_i . yh_i + * - C_i = xh_i . yl_i + * - D_i = xh_i . yh_i + * - we have x_i . y_i == A_i + (B_i + C_i) . 2^32 + D_i . 2^64 + * 2. Split A_i, B_i, C_i and D_i into 2 32-bit parts + * - A_i = Al_i + Ah_i . 2^32 + * - B_i = Bl_i + Bh_i . 2^32 + * - C_i = Cl_i + Ch_i . 2^32 + * - D_i = Dl_i + Dh_i . 2^32 + * 3. Compute the sums: + * - S1 = \sum Al_i + * - S2 = \sum (Ah_i + Bl_i + Cl_i) + * - S3 = \sum (Bh_i + Ch_i + Dl_i) + * - S4 = \sum Dh_i + * - here S1, S4 have 32 + log2(ell) bits and S2, S3 have 32 + log2(ell) + + * log2(3) bits + * - for ell == 10000 S2, S3 have < 47 bits + * 4. Split S1, S2, S3 and S4 in 2 24-bit parts (24 = ceil(47/2)) + * - S1 = S1l + S1h . 2^24 + * - S2 = S2l + S2h . 2^24 + * - S3 = S3l + S3h . 2^24 + * - S4 = S4l + S4h . 2^24 + * 5. Compute final result as: + * - \sum x_i . y_i = S1l + S1h . 2^24 + * + S2l . 2^32 + S2h . 2^(32+24) + * + S3l . 2^64 + S3h . 2^(64 + 24) + * + S4l . 2^96 + S4h . 2^(96+24) + * - here the powers of 2 are reduced modulo the primes Q before + * multiplications + * - the result will be on 24 + 3 + bit size of primes Q + */ + const uint64_t H1 = 32; + const uint64_t MASK1 = (UINT64_C(1) << H1) - 1; + + uint64_t s1[4] = {0, 0, 0, 0}; + uint64_t s2[4] = {0, 0, 0, 0}; + uint64_t s3[4] = {0, 0, 0, 0}; + uint64_t s4[4] = {0, 0, 0, 0}; + + const uint64_t* const x_ptr = (uint64_t*)x; + const uint64_t* const y_ptr = (uint64_t*)y; + + for (uint64_t i = 0; i < 4 * ell; i += 4) { + for (uint64_t j = 0; j < 4; ++j) { + const uint64_t xl = x_ptr[i + j] & MASK1; + const uint64_t xh = x_ptr[i + j] >> H1; + const uint64_t yl = y_ptr[i + j] & MASK1; + const uint64_t yh = y_ptr[i + j] >> H1; + + const uint64_t a = xl * yl; + const uint64_t al = a & MASK1; + const uint64_t ah = a >> H1; + + const uint64_t b = xl * yh; + const uint64_t bl = b & MASK1; + const uint64_t bh = b >> H1; + + const uint64_t c = xh * yl; + const uint64_t cl = c & MASK1; + const uint64_t ch = c >> H1; + + const uint64_t d = xh * yh; + const uint64_t dl = d & MASK1; + const uint64_t dh = d >> H1; + + s1[j] += al; + s2[j] += ah + bl + cl; + s3[j] += bh + ch + dl; + s4[j] += dh; + } + } + + const uint64_t H2 = precomp->h; + const uint64_t MASK2 = (UINT64_C(1) << H2) - 1; + + uint64_t* const res_ptr = (uint64_t*)res; + for (uint64_t j = 0; j < 4; ++j) { + const uint64_t s1l = s1[j] & MASK2; + const uint64_t s1h = s1[j] >> H2; + const uint64_t s2l = s2[j] & MASK2; + const uint64_t s2h = s2[j] >> H2; + const uint64_t s3l = s3[j] & MASK2; + const uint64_t s3h = s3[j] >> H2; + const uint64_t s4l = s4[j] & MASK2; + const uint64_t s4h = s4[j] >> H2; + + uint64_t t = s1l; + t += s1h * precomp->s1h_pow_red[j]; + t += s2l * precomp->s2l_pow_red[j]; + t += s2h * precomp->s2h_pow_red[j]; + t += s3l * precomp->s3l_pow_red[j]; + t += s3h * precomp->s3h_pow_red[j]; + t += s4l * precomp->s4l_pow_red[j]; + t += s4h * precomp->s4h_pow_red[j]; + + res_ptr[j] = t; + assert(log2(res_ptr[j]) < precomp->res_bit_size); + } +} + +void vec_mat1col_product_bbc_precomp(q120_mat1col_product_bbc_precomp* precomp) { + uint64_t qs[4] = {Q1, Q2, Q3, Q4}; + + double min_res_bs = 1000; + uint64_t min_h = -1; + + double pow2_32_bs = comp_bit_size_red(32, qs); + + double ell_bs = log2((double)MAX_ELL); + double s1_bs = 32 + ell_bs; + for (uint64_t h = 16; h < 32; ++h) { + double s2l_bs = pow2_32_bs + h; + double s2h_bs = s1_bs - h + comp_bit_size_red(32 + h, 
qs); + + const double bs[] = {s1_bs, s2l_bs, s2h_bs}; + const double res_bs = comp_bit_size_sum(3, bs); + + if (min_res_bs > res_bs) { + min_res_bs = res_bs; + min_h = h; + } + } + + assert(min_res_bs < 64); + precomp->h = min_h; + for (uint64_t k = 0; k < 4; ++k) { + precomp->s2l_pow_red[k] = MODQ(UINT64_C(1) << 32, qs[k]); + precomp->s2h_pow_red[k] = MODQ(UINT64_C(1) << (32 + precomp->h), qs[k]); + } +#ifndef NDEBUG + precomp->res_bit_size = min_res_bs; +#endif + // printf("AA %lu %lf\n", min_h, min_res_bs); +} + +EXPORT q120_mat1col_product_bbc_precomp* q120_new_vec_mat1col_product_bbc_precomp() { + q120_mat1col_product_bbc_precomp* res = malloc(sizeof(q120_mat1col_product_bbc_precomp)); + vec_mat1col_product_bbc_precomp(res); + return res; +} + +EXPORT void q120_delete_vec_mat1col_product_bbc_precomp(q120_mat1col_product_bbc_precomp* addr) { free(addr); } + +EXPORT void q120_vec_mat1col_product_bbc_ref_old(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y) { + /** + * Algorithm: + * 0. We have + * - y0_i == y_i % Q and y1_i == (y_i . 2^32) % Q + * 1. Split x_i in 2 32-bit parts and compute the cross-products: + * - x_i = xl_i + xh_i . 2^32 + * - A_i = xl_i . y0_i + * - B_i = xh_i . y1_i + * - we have x_i . y_i == A_i + B_i + * 2. Split A_i and B_i into 2 32-bit parts + * - A_i = Al_i + Ah_i . 2^32 + * - B_i = Bl_i + Bh_i . 2^32 + * 3. Compute the sums: + * - S1 = \sum Al_i + Bl_i + * - S2 = \sum Ah_i + Bh_i + * - here S1 and S2 have 32 + log2(ell) bits + * - for ell == 10000 S1, S2 have < 46 bits + * 4. Split S2 in 27-bit and 19-bit parts (27+19 == 46) + * - S2 = S2l + S2h . 2^27 + * 5. Compute final result as: + * - \sum x_i . y_i = S1 + S2l . 2^32 + S2h . 2^(32+27) + * - here the powers of 2 are reduced modulo the primes Q before + * multiplications + * - the result will be on < 52 bits + */ + + const uint64_t H1 = 32; + const uint64_t MASK1 = (UINT64_C(1) << H1) - 1; + + uint64_t s1[4] = {0, 0, 0, 0}; + uint64_t s2[4] = {0, 0, 0, 0}; + + const uint64_t* const x_ptr = (uint64_t*)x; + const uint32_t* const y_ptr = (uint32_t*)y; + + for (uint64_t i = 0; i < 4 * ell; i += 4) { + for (uint64_t j = 0; j < 4; ++j) { + const uint64_t xl = x_ptr[i + j] & MASK1; + const uint64_t xh = x_ptr[i + j] >> H1; + const uint64_t y0 = y_ptr[2 * (i + j)]; + const uint64_t y1 = y_ptr[2 * (i + j) + 1]; + + const uint64_t a = xl * y0; + const uint64_t al = a & MASK1; + const uint64_t ah = a >> H1; + + const uint64_t b = xh * y1; + const uint64_t bl = b & MASK1; + const uint64_t bh = b >> H1; + + s1[j] += al + bl; + s2[j] += ah + bh; + } + } + + const uint64_t H2 = precomp->h; + const uint64_t MASK2 = (UINT64_C(1) << H2) - 1; + + uint64_t* const res_ptr = (uint64_t*)res; + for (uint64_t k = 0; k < 4; ++k) { + const uint64_t s2l = s2[k] & MASK2; + const uint64_t s2h = s2[k] >> H2; + + uint64_t t = s1[k]; + t += s2l * precomp->s2l_pow_red[k]; + t += s2h * precomp->s2h_pow_red[k]; + + res_ptr[k] = t; + assert(log2(res_ptr[k]) < precomp->res_bit_size); + } +} + +static __always_inline void accum_mul_q120_bc(uint64_t res[8], // + const uint32_t x_layb[8], const uint32_t y_layc[8]) { + for (uint64_t i = 0; i < 4; ++i) { + static const uint64_t MASK32 = 0xFFFFFFFFUL; + uint64_t x_lo = x_layb[2 * i]; + uint64_t x_hi = x_layb[2 * i + 1]; + uint64_t y_lo = y_layc[2 * i]; + uint64_t y_hi = y_layc[2 * i + 1]; + uint64_t xy_lo = x_lo * y_lo; + uint64_t xy_hi = x_hi * y_hi; + res[2 * i] += (xy_lo & MASK32) + (xy_hi & MASK32); + res[2 * i + 1] += 
(xy_lo >> 32) + (xy_hi >> 32); + } +} + +static __always_inline void accum_to_q120b(uint64_t res[4], // + const uint64_t s[8], const q120_mat1col_product_bbc_precomp* precomp) { + const uint64_t H2 = precomp->h; + const uint64_t MASK2 = (UINT64_C(1) << H2) - 1; + for (uint64_t k = 0; k < 4; ++k) { + const uint64_t s2l = s[2 * k + 1] & MASK2; + const uint64_t s2h = s[2 * k + 1] >> H2; + uint64_t t = s[2 * k]; + t += s2l * precomp->s2l_pow_red[k]; + t += s2h * precomp->s2h_pow_red[k]; + res[k] = t; + assert(log2(res[k]) < precomp->res_bit_size); + } +} + +EXPORT void q120_vec_mat1col_product_bbc_ref(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y) { + uint64_t s[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + + const uint32_t(*const x_ptr)[8] = (const uint32_t(*const)[8])x; + const uint32_t(*const y_ptr)[8] = (const uint32_t(*const)[8])y; + + for (uint64_t i = 0; i < ell; i++) { + accum_mul_q120_bc(s, x_ptr[i], y_ptr[i]); + } + accum_to_q120b((uint64_t*)res, s, precomp); +} + +EXPORT void q120x2_vec_mat1col_product_bbc_ref(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y) { + uint64_t s[2][16] = {0}; + + const uint32_t(*const x_ptr)[2][8] = (const uint32_t(*const)[2][8])x; + const uint32_t(*const y_ptr)[2][8] = (const uint32_t(*const)[2][8])y; + uint64_t(*re)[4] = (uint64_t(*)[4])res; + + for (uint64_t i = 0; i < ell; i++) { + accum_mul_q120_bc(s[0], x_ptr[i][0], y_ptr[i][0]); + accum_mul_q120_bc(s[1], x_ptr[i][1], y_ptr[i][1]); + } + accum_to_q120b(re[0], s[0], precomp); + accum_to_q120b(re[1], s[1], precomp); +} + +EXPORT void q120x2_vec_mat2cols_product_bbc_ref(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y) { + uint64_t s[4][16] = {0}; + + const uint32_t(*const x_ptr)[2][8] = (const uint32_t(*const)[2][8])x; + const uint32_t(*const y_ptr)[4][8] = (const uint32_t(*const)[4][8])y; + uint64_t(*re)[4] = (uint64_t(*)[4])res; + + for (uint64_t i = 0; i < ell; i++) { + accum_mul_q120_bc(s[0], x_ptr[i][0], y_ptr[i][0]); + accum_mul_q120_bc(s[1], x_ptr[i][1], y_ptr[i][1]); + accum_mul_q120_bc(s[2], x_ptr[i][0], y_ptr[i][2]); + accum_mul_q120_bc(s[3], x_ptr[i][1], y_ptr[i][3]); + } + accum_to_q120b(re[0], s[0], precomp); + accum_to_q120b(re[1], s[1], precomp); + accum_to_q120b(re[2], s[2], precomp); + accum_to_q120b(re[3], s[3], precomp); +} + +EXPORT void q120x2_extract_1blk_from_q120b_ref(uint64_t nn, uint64_t blk, + q120x2b* const dst, // 8 doubles + const q120b* const src // a q120b vector +) { + const uint64_t* in = (uint64_t*)src; + uint64_t* out = (uint64_t*)dst; + for (uint64_t i = 0; i < 8; ++i) { + out[i] = in[8 * blk + i]; + } +} + +// function on layout c is the exact same as on layout b +#ifdef __APPLE__ +#pragma weak q120x2_extract_1blk_from_q120c_ref = q120x2_extract_1blk_from_q120b_ref +#else +EXPORT void q120x2_extract_1blk_from_q120c_ref(uint64_t nn, uint64_t blk, + q120x2c* const dst, // 8 doubles + const q120c* const src // a q120c vector + ) __attribute__((alias("q120x2_extract_1blk_from_q120b_ref"))); +#endif + +EXPORT void q120x2_extract_1blk_from_contiguous_q120b_ref( + uint64_t nn, uint64_t nrows, uint64_t blk, + q120x2b* const dst, // nrows * 2 q120 + const q120b* const src // a contiguous array of nrows q120b vectors +) { + const uint64_t* in = (uint64_t*)src; + uint64_t* out = (uint64_t*)dst; + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t 
i = 0; i < 8; ++i) { + out[i] = in[8 * blk + i]; + } + in += 4 * nn; + out += 8; + } +} + +EXPORT void q120x2b_save_1blk_to_q120b_ref(uint64_t nn, uint64_t blk, + q120b* dest, // 1 reim vector of length m + const q120x2b* src // 8 doubles +) { + const uint64_t* in = (uint64_t*)src; + uint64_t* out = (uint64_t*)dest; + for (uint64_t i = 0; i < 8; ++i) { + out[8 * blk + i] = in[i]; + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_arithmetic_simple.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_arithmetic_simple.c new file mode 100644 index 0000000..0753ad5 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_arithmetic_simple.c @@ -0,0 +1,111 @@ +#include +#include +#include + +#include "q120_arithmetic.h" +#include "q120_common.h" + +EXPORT void q120_add_bbb_simple(uint64_t nn, q120b* const res, const q120b* const x, const q120b* const y) { + const uint64_t* x_u64 = (uint64_t*)x; + const uint64_t* y_u64 = (uint64_t*)y; + uint64_t* res_u64 = (uint64_t*)res; + for (uint64_t i = 0; i < 4 * nn; i += 4) { + res_u64[i + 0] = x_u64[i + 0] % ((uint64_t)Q1 << 33) + y_u64[i + 0] % ((uint64_t)Q1 << 33); + res_u64[i + 1] = x_u64[i + 1] % ((uint64_t)Q2 << 33) + y_u64[i + 1] % ((uint64_t)Q2 << 33); + res_u64[i + 2] = x_u64[i + 2] % ((uint64_t)Q3 << 33) + y_u64[i + 2] % ((uint64_t)Q3 << 33); + res_u64[i + 3] = x_u64[i + 3] % ((uint64_t)Q4 << 33) + y_u64[i + 3] % ((uint64_t)Q4 << 33); + } +} + +EXPORT void q120_add_ccc_simple(uint64_t nn, q120c* const res, const q120c* const x, const q120c* const y) { + const uint32_t* x_u32 = (uint32_t*)x; + const uint32_t* y_u32 = (uint32_t*)y; + uint32_t* res_u32 = (uint32_t*)res; + for (uint64_t i = 0; i < 8 * nn; i += 8) { + res_u32[i + 0] = (uint32_t)(((uint64_t)x_u32[i + 0] + (uint64_t)y_u32[i + 0]) % Q1); + res_u32[i + 1] = (uint32_t)(((uint64_t)x_u32[i + 1] + (uint64_t)y_u32[i + 1]) % Q1); + res_u32[i + 2] = (uint32_t)(((uint64_t)x_u32[i + 2] + (uint64_t)y_u32[i + 2]) % Q2); + res_u32[i + 3] = (uint32_t)(((uint64_t)x_u32[i + 3] + (uint64_t)y_u32[i + 3]) % Q2); + res_u32[i + 4] = (uint32_t)(((uint64_t)x_u32[i + 4] + (uint64_t)y_u32[i + 4]) % Q3); + res_u32[i + 5] = (uint32_t)(((uint64_t)x_u32[i + 5] + (uint64_t)y_u32[i + 5]) % Q3); + res_u32[i + 6] = (uint32_t)(((uint64_t)x_u32[i + 6] + (uint64_t)y_u32[i + 6]) % Q4); + res_u32[i + 7] = (uint32_t)(((uint64_t)x_u32[i + 7] + (uint64_t)y_u32[i + 7]) % Q4); + } +} + +EXPORT void q120_c_from_b_simple(uint64_t nn, q120c* const res, const q120b* const x) { + const uint64_t* x_u64 = (uint64_t*)x; + uint32_t* res_u32 = (uint32_t*)res; + for (uint64_t i = 0, j = 0; i < 4 * nn; i += 4, j += 8) { + res_u32[j + 0] = x_u64[i + 0] % Q1; + res_u32[j + 1] = ((uint64_t)res_u32[j + 0] << 32) % Q1; + res_u32[j + 2] = x_u64[i + 1] % Q2; + res_u32[j + 3] = ((uint64_t)res_u32[j + 2] << 32) % Q2; + res_u32[j + 4] = x_u64[i + 2] % Q3; + res_u32[j + 5] = ((uint64_t)res_u32[j + 4] << 32) % Q3; + res_u32[j + 6] = x_u64[i + 3] % Q4; + res_u32[j + 7] = ((uint64_t)res_u32[j + 6] << 32) % Q4; + } +} + +EXPORT void q120_b_from_znx64_simple(uint64_t nn, q120b* const res, const int64_t* const x) { + static const int64_t MASK_HI = INT64_C(0x8000000000000000); + static const int64_t MASK_LO = ~MASK_HI; + static const uint64_t OQ[4] = { + (Q1 - (UINT64_C(0x8000000000000000) % Q1)), + (Q2 - (UINT64_C(0x8000000000000000) % Q2)), + (Q3 - (UINT64_C(0x8000000000000000) % Q3)), + (Q4 - (UINT64_C(0x8000000000000000) % Q4)), + }; + uint64_t* res_u64 = 
(uint64_t*)res; + for (uint64_t i = 0, j = 0; j < nn; i += 4, ++j) { + uint64_t xj_lo = x[j] & MASK_LO; + uint64_t xj_hi = x[j] & MASK_HI; + res_u64[i + 0] = xj_lo + (xj_hi ? OQ[0] : 0); + res_u64[i + 1] = xj_lo + (xj_hi ? OQ[1] : 0); + res_u64[i + 2] = xj_lo + (xj_hi ? OQ[2] : 0); + res_u64[i + 3] = xj_lo + (xj_hi ? OQ[3] : 0); + } +} + +static int64_t posmod(int64_t x, int64_t q) { + int64_t t = x % q; + if (t < 0) + return t + q; + else + return t; +} + +EXPORT void q120_c_from_znx64_simple(uint64_t nn, q120c* const res, const int64_t* const x) { + uint32_t* res_u32 = (uint32_t*)res; + for (uint64_t i = 0, j = 0; j < nn; i += 8, ++j) { + res_u32[i + 0] = posmod(x[j], Q1); + res_u32[i + 1] = ((uint64_t)res_u32[i + 0] << 32) % Q1; + res_u32[i + 2] = posmod(x[j], Q2); + res_u32[i + 3] = ((uint64_t)res_u32[i + 2] << 32) % Q2; + res_u32[i + 4] = posmod(x[j], Q3); + res_u32[i + 5] = ((uint64_t)res_u32[i + 4] << 32) % Q3; + res_u32[i + 6] = posmod(x[j], Q4); + res_u32[i + 7] = ((uint64_t)res_u32[i + 6] << 32) % Q4; + ; + } +} + +EXPORT void q120_b_to_znx128_simple(uint64_t nn, __int128_t* const res, const q120b* const x) { + static const __int128_t Q = (__int128_t)Q1 * Q2 * Q3 * Q4; + static const __int128_t Qm1 = (__int128_t)Q2 * Q3 * Q4; + static const __int128_t Qm2 = (__int128_t)Q1 * Q3 * Q4; + static const __int128_t Qm3 = (__int128_t)Q1 * Q2 * Q4; + static const __int128_t Qm4 = (__int128_t)Q1 * Q2 * Q3; + + const uint64_t* x_u64 = (uint64_t*)x; + for (uint64_t i = 0, j = 0; j < nn; i += 4, ++j) { + __int128_t tmp = 0; + tmp += (((x_u64[i + 0] % Q1) * Q1_CRT_CST) % Q1) * Qm1; + tmp += (((x_u64[i + 1] % Q2) * Q2_CRT_CST) % Q2) * Qm2; + tmp += (((x_u64[i + 2] % Q3) * Q3_CRT_CST) % Q3) * Qm3; + tmp += (((x_u64[i + 3] % Q4) * Q4_CRT_CST) % Q4) * Qm4; + tmp %= Q; + res[j] = (tmp >= (Q + 1) / 2) ? 
tmp - Q : tmp; + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_common.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_common.h new file mode 100644 index 0000000..9acef5e --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_common.h @@ -0,0 +1,94 @@ +#ifndef SPQLIOS_Q120_COMMON_H +#define SPQLIOS_Q120_COMMON_H + +#include + +#if !defined(SPQLIOS_Q120_USE_29_BIT_PRIMES) && !defined(SPQLIOS_Q120_USE_30_BIT_PRIMES) && \ + !defined(SPQLIOS_Q120_USE_31_BIT_PRIMES) +#define SPQLIOS_Q120_USE_30_BIT_PRIMES +#endif + +/** + * 29-bit primes and 2*2^16 roots of unity + */ +#ifdef SPQLIOS_Q120_USE_29_BIT_PRIMES +#define Q1 ((1u << 29) - 2 * (1u << 17) + 1) +#define OMEGA1 78289835 +#define Q1_CRT_CST 301701286 // (Q2*Q3*Q4)^-1 mod Q1 + +#define Q2 ((1u << 29) - 5 * (1u << 17) + 1) +#define OMEGA2 178519192 +#define Q2_CRT_CST 536020447 // (Q1*Q3*Q4)^-1 mod Q2 + +#define Q3 ((1u << 29) - 26 * (1u << 17) + 1) +#define OMEGA3 483889678 +#define Q3_CRT_CST 86367873 // (Q1*Q2*Q4)^-1 mod Q3 + +#define Q4 ((1u << 29) - 35 * (1u << 17) + 1) +#define OMEGA4 239808033 +#define Q4_CRT_CST 147030781 // (Q1*Q2*Q3)^-1 mod Q4 +#endif + +/** + * 30-bit primes and 2*2^16 roots of unity + */ +#ifdef SPQLIOS_Q120_USE_30_BIT_PRIMES +#define Q1 ((1u << 30) - 2 * (1u << 17) + 1) +#define OMEGA1 1070907127 +#define Q1_CRT_CST 43599465 // (Q2*Q3*Q4)^-1 mod Q1 + +#define Q2 ((1u << 30) - 17 * (1u << 17) + 1) +#define OMEGA2 315046632 +#define Q2_CRT_CST 292938863 // (Q1*Q3*Q4)^-1 mod Q2 + +#define Q3 ((1u << 30) - 23 * (1u << 17) + 1) +#define OMEGA3 309185662 +#define Q3_CRT_CST 594011630 // (Q1*Q2*Q4)^-1 mod Q3 + +#define Q4 ((1u << 30) - 42 * (1u << 17) + 1) +#define OMEGA4 846468380 +#define Q4_CRT_CST 140177212 // (Q1*Q2*Q3)^-1 mod Q4 +#endif + +/** + * 31-bit primes and 2*2^16 roots of unity + */ +#ifdef SPQLIOS_Q120_USE_31_BIT_PRIMES +#define Q1 ((1u << 31) - 1 * (1u << 17) + 1) +#define OMEGA1 1615402923 +#define Q1_CRT_CST 1811422063 // (Q2*Q3*Q4)^-1 mod Q1 + +#define Q2 ((1u << 31) - 4 * (1u << 17) + 1) +#define OMEGA2 1137738560 +#define Q2_CRT_CST 2093150204 // (Q1*Q3*Q4)^-1 mod Q2 + +#define Q3 ((1u << 31) - 11 * (1u << 17) + 1) +#define OMEGA3 154880552 +#define Q3_CRT_CST 164149010 // (Q1*Q2*Q4)^-1 mod Q3 + +#define Q4 ((1u << 31) - 23 * (1u << 17) + 1) +#define OMEGA4 558784885 +#define Q4_CRT_CST 225197446 // (Q1*Q2*Q3)^-1 mod Q4 +#endif + +static const uint32_t PRIMES_VEC[4] = {Q1, Q2, Q3, Q4}; +static const uint32_t OMEGAS_VEC[4] = {OMEGA1, OMEGA2, OMEGA3, OMEGA4}; + +#define MAX_ELL 10000 + +// each number x mod Q120 is represented by uint64_t[4] with (non-unique) values (x mod q1, x mod q2,x mod q3,x mod q4), +// each between [0 and 2^32-1] +typedef struct _q120a q120a; + +// each number x mod Q120 is represented by uint64_t[4] with (non-unique) values (x mod q1, x mod q2,x mod q3,x mod q4), +// each between [0 and 2^64-1] +typedef struct _q120b q120b; + +// each number x mod Q120 is represented by uint32_t[8] with values (x mod q1, 2^32x mod q1, x mod q2, 2^32.x mod q2, x +// mod q3, 2^32.x mod q3, x mod q4, 2^32.x mod q4) each between [0 and 2^32-1] +typedef struct _q120c q120c; + +typedef struct _q120x2b q120x2b; +typedef struct _q120x2c q120x2c; + +#endif // SPQLIOS_Q120_COMMON_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_fallbacks_aarch64.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_fallbacks_aarch64.c new file mode 100644 index 
0000000..18db6cb --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_fallbacks_aarch64.c @@ -0,0 +1,5 @@ +#include "q120_ntt_private.h" + +EXPORT void q120_ntt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data) { UNDEFINED(); } + +EXPORT void q120_intt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data) { UNDEFINED(); } diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_ntt.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_ntt.c new file mode 100644 index 0000000..f58b98e --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_ntt.c @@ -0,0 +1,340 @@ +#include +#include +#include +#include + +#include "q120_ntt_private.h" + +q120_ntt_precomp* new_precomp(const uint64_t n) { + q120_ntt_precomp* precomp = malloc(sizeof(*precomp)); + precomp->n = n; + + assert(n && !(n & (n - 1)) && n <= (1 << 16)); // n is a power of 2 smaller than 2^16 + const uint64_t logN = ceil(log2(n)); + precomp->level_metadata = malloc((logN + 2) * sizeof(*precomp->level_metadata)); + + precomp->powomega = spqlios_alloc_custom_align(32, 4 * 2 * n * sizeof(*(precomp->powomega))); + + return precomp; +} + +uint32_t modq_pow(const uint32_t x, const int64_t n, const uint32_t q) { + uint64_t np = (n % (q - 1) + q - 1) % (q - 1); + + uint64_t val_pow = x; + uint64_t res = 1; + while (np != 0) { + if (np & 1) res = (res * val_pow) % q; + val_pow = (val_pow * val_pow) % q; + np >>= 1; + } + return res; +} + +void fill_omegas(const uint64_t n, uint32_t omegas[4]) { + for (uint64_t k = 0; k < 4; ++k) { + omegas[k] = modq_pow(OMEGAS_VEC[k], (1 << 16) / n, PRIMES_VEC[k]); + } + +#ifndef NDEBUG + + const uint64_t logQ = ceil(log2(PRIMES_VEC[0])); + for (int k = 1; k < 4; ++k) { + if (logQ != ceil(log2(PRIMES_VEC[k]))) { + fprintf(stderr, "The 4 primes must have the same bit-size\n"); + exit(-1); + } + } + + // check if each omega is a 2.n primitive root of unity + for (uint64_t k = 0; k < 4; ++k) { + assert(modq_pow(omegas[k], 2 * n, PRIMES_VEC[k]) == 1); + for (uint64_t i = 1; i < 2 * n; ++i) { + assert(modq_pow(omegas[k], i, PRIMES_VEC[k]) != 1); + } + } + + if (logQ > 31) { + fprintf(stderr, "Modulus q bit-size is larger than 30 bit\n"); + exit(-1); + } +#endif +} + +uint64_t fill_reduction_meta(const uint64_t bs_start, q120_ntt_reduc_step_precomp* reduc_metadata) { + // fill reduction metadata + uint64_t bs_after_reduc = -1; + { + uint64_t min_h = -1; + + for (uint64_t h = bs_start / 2; h < bs_start; ++h) { + uint64_t t = 0; + for (uint64_t k = 0; k < 4; ++k) { + const uint64_t t1 = bs_start - h + (uint64_t)ceil(log2((UINT64_C(1) << h) % PRIMES_VEC[k])); + const uint64_t t2 = UINT64_C(1) + ((t1 > h) ? 
t1 : h); + if (t < t2) t = t2; + } + if (t < bs_after_reduc) { + min_h = h; + bs_after_reduc = t; + } + } + + reduc_metadata->h = min_h; + reduc_metadata->mask = (UINT64_C(1) << min_h) - 1; + for (uint64_t k = 0; k < 4; ++k) { + reduc_metadata->modulo_red_cst[k] = (UINT64_C(1) << min_h) % PRIMES_VEC[k]; + } + + assert(bs_after_reduc < 64); + } + + return bs_after_reduc; +} + +uint64_t round_up_half_n(const uint64_t n) { return (n + 1) / 2; } + +EXPORT q120_ntt_precomp* q120_new_ntt_bb_precomp(const uint64_t n) { + uint32_t omega_vec[4]; + fill_omegas(n, omega_vec); + + const uint64_t logQ = ceil(log2(PRIMES_VEC[0])); + + q120_ntt_precomp* precomp = new_precomp(n); + + uint64_t bs = precomp->input_bit_size = 64; + + LOG("NTT parameters:\n"); + LOG("\tsize = %" PRIu64 "\n", n) + LOG("\tlogQ = %" PRIu64 "\n", logQ); + LOG("\tinput bit-size = %" PRIu64 "\n", bs); + + if (n == 1) return precomp; + + // fill reduction metadata + uint64_t bs_after_reduc = fill_reduction_meta(bs, &(precomp->reduc_metadata)); + + // forward metadata + q120_ntt_step_precomp* level_metadata_ptr = precomp->level_metadata; + + // first level a_k.omega^k + { + const uint64_t half_bs = (bs + 1) / 2; + level_metadata_ptr->half_bs = half_bs; + level_metadata_ptr->mask = (UINT64_C(1) << half_bs) - UINT64_C(1); + level_metadata_ptr->bs = bs = half_bs + logQ + 1; + LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 " (a_k.omega^k) \n", n, bs); + level_metadata_ptr++; + } + + for (uint64_t nn = n; nn >= 4; nn /= 2) { + level_metadata_ptr->reduce = (bs == 64); + if (level_metadata_ptr->reduce) { + bs = bs_after_reduc; + LOG("\treduce output bit-size = %" PRIu64 "\n", bs); + } + + for (int k = 0; k < 4; ++k) level_metadata_ptr->q2bs[k] = (uint64_t)PRIMES_VEC[k] << (bs - logQ); + + double bs_1 = bs + 1; // bit-size of term a+b or a-b + + const uint64_t half_bs = round_up_half_n(bs_1); + uint64_t bs_2 = half_bs + logQ + 1; // bit-size of term (a-b).omega^k + bs = (bs_1 > bs_2) ? 
bs_1 : bs_2; + assert(bs <= 64); + + level_metadata_ptr->bs = bs; + level_metadata_ptr->half_bs = half_bs; + level_metadata_ptr->mask = (UINT64_C(1) << half_bs) - UINT64_C(1); + level_metadata_ptr++; + + LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 "\n", nn / 2, bs); + } + + // last level (a-b, a+b) + { + level_metadata_ptr->reduce = (bs == 64); + if (level_metadata_ptr->reduce) { + bs = bs_after_reduc; + LOG("\treduce output bit-size = %" PRIu64 "\n", bs); + } + for (int k = 0; k < 4; ++k) level_metadata_ptr->q2bs[k] = ((uint64_t)PRIMES_VEC[k] << (bs - logQ)); + level_metadata_ptr->bs = ++bs; + level_metadata_ptr->half_bs = level_metadata_ptr->mask = UINT64_C(0); // not used + + LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 "\n", UINT64_C(1), bs); + } + precomp->output_bit_size = bs; + + // omega powers + uint64_t* powomega = malloc(sizeof(*powomega) * 2 * n); + for (uint64_t k = 0; k < 4; ++k) { + const uint64_t q = PRIMES_VEC[k]; + + for (uint64_t i = 0; i < 2 * n; ++i) { + powomega[i] = modq_pow(omega_vec[k], i, q); + } + + uint64_t* powomega_ptr = precomp->powomega + k; + level_metadata_ptr = precomp->level_metadata; + + { + // const uint64_t hpow = UINT64_C(1) << level_metadata_ptr->half_bs; + for (uint64_t i = 0; i < n; ++i) { + uint64_t t = powomega[i]; + uint64_t t1 = (t << level_metadata_ptr->half_bs) % q; + powomega_ptr[4 * i] = (t1 << 32) + t; + } + powomega_ptr += 4 * n; + level_metadata_ptr++; + } + + for (uint64_t nn = n; nn >= 4; nn /= 2) { + const uint64_t halfnn = nn / 2; + const uint64_t m = n / halfnn; + + // const uint64_t hpow = UINT64_C(1) << level_metadata_ptr->half_bs; + for (uint64_t i = 1; i < halfnn; ++i) { + uint64_t t = powomega[i * m]; + uint64_t t1 = (t << level_metadata_ptr->half_bs) % q; + powomega_ptr[4 * (i - 1)] = (t1 << 32) + t; + } + powomega_ptr += 4 * (halfnn - 1); + level_metadata_ptr++; + } + } + free(powomega); + + return precomp; +} + +EXPORT q120_ntt_precomp* q120_new_intt_bb_precomp(const uint64_t n) { + uint32_t omega_vec[4]; + fill_omegas(n, omega_vec); + + const uint64_t logQ = ceil(log2(PRIMES_VEC[0])); + + q120_ntt_precomp* precomp = new_precomp(n); + + uint64_t bs = precomp->input_bit_size = 64; + + LOG("iNTT parameters:\n"); + LOG("\tsize = %" PRIu64 "\n", n) + LOG("\tlogQ = %" PRIu64 "\n", logQ); + LOG("\tinput bit-size = %" PRIu64 "\n", bs); + + if (n == 1) return precomp; + + // fill reduction metadata + uint64_t bs_after_reduc = fill_reduction_meta(bs, &(precomp->reduc_metadata)); + + // backward metadata + q120_ntt_step_precomp* level_metadata_ptr = precomp->level_metadata; + + // first level (a+b, a-b) adds 1-bit + { + level_metadata_ptr->reduce = (bs == 64); + if (level_metadata_ptr->reduce) { + bs = bs_after_reduc; + LOG("\treduce output bit-size = %" PRIu64 "\n", bs); + } + + for (int k = 0; k < 4; ++k) level_metadata_ptr->q2bs[k] = (uint64_t)PRIMES_VEC[k] << (bs - logQ); + + level_metadata_ptr->bs = ++bs; + level_metadata_ptr->half_bs = level_metadata_ptr->mask = UINT64_C(0); // not used + LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 "\n", UINT64_C(1), bs); + level_metadata_ptr++; + } + + for (uint64_t nn = 4; nn <= n; nn *= 2) { + level_metadata_ptr->reduce = (bs == 64); + if (level_metadata_ptr->reduce) { + bs = bs_after_reduc; + LOG("\treduce output bit-size = %" PRIu64 "\n", bs); + } + + const uint64_t half_bs = round_up_half_n(bs); + const uint64_t bs_mult = half_bs + logQ + 1; // bit-size of term b.omega^k + bs = 1 + ((bs > bs_mult) ? 
bs : bs_mult); // bit-size of a+b.omega^k or a-b.omega^k + assert(bs <= 64); + + level_metadata_ptr->bs = bs; + level_metadata_ptr->half_bs = half_bs; + level_metadata_ptr->mask = (UINT64_C(1) << half_bs) - UINT64_C(1); + for (int k = 0; k < 4; ++k) level_metadata_ptr->q2bs[k] = (uint64_t)PRIMES_VEC[k] << (bs_mult - logQ); + level_metadata_ptr++; + + LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 "\n", nn / 2, bs); + } + + // last level a_k.omega^k + { + level_metadata_ptr->reduce = (bs == 64); + if (level_metadata_ptr->reduce) { + bs = bs_after_reduc; + LOG("\treduce output bit-size = %" PRIu64 "\n", bs); + } + + const uint64_t half_bs = round_up_half_n(bs); + + bs = half_bs + logQ + 1; // bit-size of term a.omega^k + assert(bs <= 64); + + level_metadata_ptr->bs = bs; + level_metadata_ptr->half_bs = half_bs; + level_metadata_ptr->mask = (UINT64_C(1) << half_bs) - UINT64_C(1); + for (int k = 0; k < 4; ++k) level_metadata_ptr->q2bs[k] = (uint64_t)PRIMES_VEC[k] << (bs - logQ); + + LOG("\tlevel %6" PRIu64 " output bit-size = %" PRIu64 "\n", n, bs); + } + + // omega powers + uint32_t* powomegabar = malloc(sizeof(*powomegabar) * 2 * n); + for (int k = 0; k < 4; ++k) { + const uint64_t q = PRIMES_VEC[k]; + + for (uint64_t i = 0; i < 2 * n; ++i) { + powomegabar[i] = modq_pow(omega_vec[k], -i, q); + } + + uint64_t* powomega_ptr = precomp->powomega + k; + level_metadata_ptr = precomp->level_metadata + 1; + + for (uint64_t nn = 4; nn <= n; nn *= 2) { + const uint64_t halfnn = nn / 2; + const uint64_t m = n / halfnn; + + for (uint64_t i = 1; i < halfnn; ++i) { + uint64_t t = powomegabar[i * m]; + uint64_t t1 = (t << level_metadata_ptr->half_bs) % q; + powomega_ptr[4 * (i - 1)] = (t1 << 32) + t; + } + powomega_ptr += 4 * (halfnn - 1); + level_metadata_ptr++; + } + + { + const uint64_t invNmod = modq_pow(n, -1, q); + for (uint64_t i = 0; i < n; ++i) { + uint64_t t = (powomegabar[i] * invNmod) % q; + uint64_t t1 = (t << level_metadata_ptr->half_bs) % q; + powomega_ptr[4 * i] = (t1 << 32) + t; + } + } + } + + free(powomegabar); + + return precomp; +} + +void del_precomp(q120_ntt_precomp* precomp) { + spqlios_free(precomp->powomega); + free(precomp->level_metadata); + free(precomp); +} + +EXPORT void q120_del_ntt_bb_precomp(q120_ntt_precomp* precomp) { del_precomp(precomp); } + +EXPORT void q120_del_intt_bb_precomp(q120_ntt_precomp* precomp) { del_precomp(precomp); } diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_ntt.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_ntt.h new file mode 100644 index 0000000..329b54d --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_ntt.h @@ -0,0 +1,25 @@ +#ifndef SPQLIOS_Q120_NTT_H +#define SPQLIOS_Q120_NTT_H + +#include "../commons.h" +#include "q120_common.h" + +typedef struct _q120_ntt_precomp q120_ntt_precomp; + +EXPORT q120_ntt_precomp* q120_new_ntt_bb_precomp(const uint64_t n); +EXPORT void q120_del_ntt_bb_precomp(q120_ntt_precomp* precomp); + +EXPORT q120_ntt_precomp* q120_new_intt_bb_precomp(const uint64_t n); +EXPORT void q120_del_intt_bb_precomp(q120_ntt_precomp* precomp); + +/** + * @brief computes a direct ntt in-place over data. + */ +EXPORT void q120_ntt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data); + +/** + * @brief computes an inverse ntt in-place over data. 
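+ * The input coefficients must respect the bit-size bound recorded in the precomp tables (assert-checked in debug builds).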
+ */ +EXPORT void q120_intt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data); + +#endif // SPQLIOS_Q120_NTT_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_ntt_avx2.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_ntt_avx2.c new file mode 100644 index 0000000..9d0b547 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_ntt_avx2.c @@ -0,0 +1,479 @@ +#include +#include +#include +#include +#include +#include + +#include "q120_common.h" +#include "q120_ntt_private.h" + +// at which level to switch from computations by level to computations by block +#define CHANGE_MODE_N 1024 + +__always_inline __m256i split_precompmul_si256(__m256i inp, __m256i powomega, const uint64_t h, const __m256i mask) { + const __m256i inp_low = _mm256_and_si256(inp, mask); + const __m256i t1 = _mm256_mul_epu32(inp_low, powomega); + + const __m256i inp_high = _mm256_srli_epi64(inp, h); + const __m256i powomega_high = _mm256_srli_epi64(powomega, 32); + const __m256i t2 = _mm256_mul_epu32(inp_high, powomega_high); + + return _mm256_add_epi64(t1, t2); +} + +__always_inline __m256i modq_red(const __m256i x, const uint64_t h, const __m256i mask, const __m256i _2_pow_h_modq) { + __m256i xh = _mm256_srli_epi64(x, h); + __m256i xl = _mm256_and_si256(x, mask); + __m256i xh_1 = _mm256_mul_epu32(xh, _2_pow_h_modq); + return _mm256_add_epi64(xl, xh_1); +} + +void print_data(const uint64_t n, const uint64_t* const data, const uint64_t q) { + for (uint64_t i = 0; i < n; i++) { + printf("%" PRIu64 " ", *(data + i) % q); + } + printf("\n"); +} + +double max_bit_size(const void* const begin, const void* const end) { + double bs = 0; + const uint64_t* data = (uint64_t*)begin; + while (data != end) { + double t = log2(*(data++)); + if (bs < t) { + bs = t; + } + } + return bs; +} + +void ntt_iter_first(__m256i* const begin, const __m256i* const end, const q120_ntt_step_precomp* const itData, + const __m256i* powomega) { + const uint64_t h = itData->half_bs; + const __m256i vmask = _mm256_set1_epi64x(itData->mask); + + __m256i* data = begin; + while (data < end) { + __m256i x = _mm256_loadu_si256(data); + __m256i po = _mm256_loadu_si256(powomega); + __m256i r = split_precompmul_si256(x, po, h, vmask); + _mm256_storeu_si256(data, r); + + data++; + powomega++; + } +} + +void ntt_iter(const uint64_t nn, __m256i* const begin, const __m256i* const end, + const q120_ntt_step_precomp* const itData, const __m256i* const powomega) { + assert(nn % 2 == 0); + const uint64_t halfnn = nn / 2; + + const __m256i vq2bs = _mm256_loadu_si256((__m256i*)itData->q2bs); + const __m256i vmask = _mm256_set1_epi64x(itData->mask); + + __m256i* data = begin; + while (data < end) { + __m256i* ptr1 = data; + __m256i* ptr2 = data + halfnn; + + const __m256i a = _mm256_loadu_si256(ptr1); + const __m256i b = _mm256_loadu_si256(ptr2); + + const __m256i ap = _mm256_add_epi64(a, b); + _mm256_storeu_si256(ptr1, ap); + + const __m256i bp = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), b); + _mm256_storeu_si256(ptr2, bp); + + ptr1++; + ptr2++; + + const __m256i* po_ptr = powomega; + for (uint64_t i = 1; i < halfnn; ++i) { + __m256i a = _mm256_loadu_si256(ptr1); + __m256i b = _mm256_loadu_si256(ptr2); + + __m256i ap = _mm256_add_epi64(a, b); + + _mm256_storeu_si256(ptr1, ap); + + __m256i b1 = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), b); + __m256i po = _mm256_loadu_si256(po_ptr); + + __m256i bp = split_precompmul_si256(b1, po, itData->half_bs, vmask); + + 
_mm256_storeu_si256(ptr2, bp); + + ptr1++; + ptr2++; + po_ptr++; + } + data += nn; + } +} + +void ntt_iter_red(const uint64_t nn, __m256i* const begin, const __m256i* const end, + const q120_ntt_step_precomp* const itData, const __m256i* const powomega, + const q120_ntt_reduc_step_precomp* const reduc_precomp) { + assert(nn % 2 == 0); + const uint64_t halfnn = nn / 2; + + const __m256i vq2bs = _mm256_loadu_si256((__m256i*)itData->q2bs); + const __m256i vmask = _mm256_set1_epi64x(itData->mask); + + const __m256i reduc_mask = _mm256_set1_epi64x(reduc_precomp->mask); + const __m256i reduc_cst = _mm256_loadu_si256((__m256i*)reduc_precomp->modulo_red_cst); + + __m256i* data = begin; + while (data < end) { + __m256i* ptr1 = data; + __m256i* ptr2 = data + halfnn; + + __m256i a = _mm256_loadu_si256(ptr1); + __m256i b = _mm256_loadu_si256(ptr2); + + a = modq_red(a, reduc_precomp->h, reduc_mask, reduc_cst); + b = modq_red(b, reduc_precomp->h, reduc_mask, reduc_cst); + + const __m256i ap = _mm256_add_epi64(a, b); + _mm256_storeu_si256(ptr1, ap); + + const __m256i bp = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), b); + _mm256_storeu_si256(ptr2, bp); + + ptr1++; + ptr2++; + + const __m256i* po_ptr = powomega; + for (uint64_t i = 1; i < halfnn; ++i) { + __m256i a = _mm256_loadu_si256(ptr1); + __m256i b = _mm256_loadu_si256(ptr2); + + a = modq_red(a, reduc_precomp->h, reduc_mask, reduc_cst); + b = modq_red(b, reduc_precomp->h, reduc_mask, reduc_cst); + + __m256i ap = _mm256_add_epi64(a, b); + + _mm256_storeu_si256(ptr1, ap); + + __m256i bp = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), b); + __m256i po = _mm256_loadu_si256(po_ptr); + bp = split_precompmul_si256(bp, po, itData->half_bs, vmask); + + _mm256_storeu_si256(ptr2, bp); + + ptr1++; + ptr2++; + po_ptr++; + } + data += nn; + } +} + +EXPORT void q120_ntt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data_ptr) { + // assert((size_t)data_ptr % 32 == 0); // alignment check + + const uint64_t n = precomp->n; + if (n == 1) return; + + const q120_ntt_step_precomp* itData = precomp->level_metadata; + const __m256i* powomega = (__m256i*)precomp->powomega; + + __m256i* const begin = (__m256i*)data_ptr; + const __m256i* const end = ((__m256i*)data_ptr) + n; + + if (CHECK_BOUNDS) { + double bs __attribute__((unused)) = max_bit_size((void*)begin, (void*)end); + LOG("Input %lf %" PRIu64 "\n", bs, precomp->input_bit_size); + assert(bs <= precomp->input_bit_size); + } + + // first iteration a_k.omega^k + ntt_iter_first(begin, end, itData, powomega); + + if (CHECK_BOUNDS) { + double bs __attribute__((unused)) = max_bit_size((void*)begin, (void*)end); + LOG("Iter %3" PRIu64 " - %lf %" PRIu64 "\n", n, bs, itData->bs); + assert(bs < itData->bs); + } + + powomega += n; + itData++; + + const uint64_t split_nn = (CHANGE_MODE_N > n) ? n : CHANGE_MODE_N; + // const uint64_t split_nn = 2; + + // computations by level + for (uint64_t nn = n; nn > split_nn; nn /= 2) { + const uint64_t halfnn = nn / 2; + + if (itData->reduce) { + ntt_iter_red(nn, begin, end, itData, powomega, &precomp->reduc_metadata); + } else { + ntt_iter(nn, begin, end, itData, powomega); + } + + if (CHECK_BOUNDS) { + double bs __attribute__((unused)) = max_bit_size((void*)begin, (void*)end); + LOG("Iter %3" PRIu64 " - %lf %" PRIu64 " %c\n", nn / 2, bs, itData->bs, itData->reduce ? 
'*' : ' '); + assert(bs < itData->bs); + } + + powomega += halfnn - 1; + itData++; + } + + // computations by memory block + if (split_nn >= 2) { + const q120_ntt_step_precomp* itData1 = itData; + const __m256i* powomega1 = powomega; + for (__m256i* it = begin; it < end; it += split_nn) { + __m256i* const begin1 = it; + const __m256i* const end1 = it + split_nn; + + itData = itData1; + powomega = powomega1; + for (uint64_t nn = split_nn; nn >= 2; nn /= 2) { + const uint64_t halfnn = nn / 2; + + if (itData->reduce) { + ntt_iter_red(nn, begin1, end1, itData, powomega, &precomp->reduc_metadata); + } else { + ntt_iter(nn, begin1, end1, itData, powomega); + } + + if (CHECK_BOUNDS) { + double bs __attribute__((unused)) = max_bit_size((uint64_t*)begin1, (uint64_t*)end1); + // LOG("Iter %3lu - %lf %lu\n", nn / 2, bs, itData->bs); + assert(bs < itData->bs); + } + + powomega += halfnn - 1; + itData++; + } + } + } + + if (CHECK_BOUNDS) { + double bs __attribute__((unused)) = max_bit_size((void*)begin, (void*)end); + LOG("Iter %3" PRIu64 " - %lf %" PRIu64 "\n", UINT64_C(1), bs, precomp->output_bit_size); + assert(bs < precomp->output_bit_size); + } +} + +void intt_iter(const uint64_t nn, __m256i* const begin, const __m256i* const end, + const q120_ntt_step_precomp* const itData, const __m256i* const powomega) { + assert(nn % 2 == 0); + const uint64_t halfnn = nn / 2; + + const __m256i vq2bs = _mm256_loadu_si256((__m256i*)itData->q2bs); + const __m256i vmask = _mm256_set1_epi64x(itData->mask); + + __m256i* data = begin; + while (data < end) { + __m256i* ptr1 = data; + __m256i* ptr2 = data + halfnn; + + const __m256i a = _mm256_loadu_si256(ptr1); + const __m256i b = _mm256_loadu_si256(ptr2); + + const __m256i ap = _mm256_add_epi64(a, b); + _mm256_storeu_si256(ptr1, ap); + + const __m256i bp = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), b); + _mm256_storeu_si256(ptr2, bp); + + ptr1++; + ptr2++; + + const __m256i* po_ptr = powomega; + for (uint64_t i = 1; i < halfnn; ++i) { + __m256i a = _mm256_loadu_si256(ptr1); + __m256i b = _mm256_loadu_si256(ptr2); + + __m256i po = _mm256_loadu_si256(po_ptr); + __m256i bo = split_precompmul_si256(b, po, itData->half_bs, vmask); + + __m256i ap = _mm256_add_epi64(a, bo); + + _mm256_storeu_si256(ptr1, ap); + + __m256i bp = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), bo); + + _mm256_storeu_si256(ptr2, bp); + + ptr1++; + ptr2++; + po_ptr++; + } + data += nn; + } +} + +void intt_iter_red(const uint64_t nn, __m256i* const begin, const __m256i* const end, + const q120_ntt_step_precomp* const itData, const __m256i* const powomega, + const q120_ntt_reduc_step_precomp* const reduc_precomp) { + assert(nn % 2 == 0); + const uint64_t halfnn = nn / 2; + + const __m256i vq2bs = _mm256_loadu_si256((__m256i*)itData->q2bs); + const __m256i vmask = _mm256_set1_epi64x(itData->mask); + + const __m256i reduc_mask = _mm256_set1_epi64x(reduc_precomp->mask); + const __m256i reduc_cst = _mm256_loadu_si256((__m256i*)reduc_precomp->modulo_red_cst); + + __m256i* data = begin; + while (data < end) { + __m256i* ptr1 = data; + __m256i* ptr2 = data + halfnn; + + __m256i a = _mm256_loadu_si256(ptr1); + __m256i b = _mm256_loadu_si256(ptr2); + + a = modq_red(a, reduc_precomp->h, reduc_mask, reduc_cst); + b = modq_red(b, reduc_precomp->h, reduc_mask, reduc_cst); + + const __m256i ap = _mm256_add_epi64(a, b); + _mm256_storeu_si256(ptr1, ap); + + const __m256i bp = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), b); + _mm256_storeu_si256(ptr2, bp); + + ptr1++; + ptr2++; + + const __m256i* po_ptr = powomega; 
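+ // remaining butterflies in this block: reduce both inputs mod q, multiply b by the precomputed omega power, then write a + b.omega and a - b.omega (the latter lifted by a multiple of q to keep it non-negative)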
+ for (uint64_t i = 1; i < halfnn; ++i) { + __m256i a = _mm256_loadu_si256(ptr1); + __m256i b = _mm256_loadu_si256(ptr2); + + a = modq_red(a, reduc_precomp->h, reduc_mask, reduc_cst); + b = modq_red(b, reduc_precomp->h, reduc_mask, reduc_cst); + + __m256i po = _mm256_loadu_si256(po_ptr); + __m256i bo = split_precompmul_si256(b, po, itData->half_bs, vmask); + + __m256i ap = _mm256_add_epi64(a, bo); + + _mm256_storeu_si256(ptr1, ap); + + __m256i bp = _mm256_sub_epi64(_mm256_add_epi64(a, vq2bs), bo); + + _mm256_storeu_si256(ptr2, bp); + + ptr1++; + ptr2++; + po_ptr++; + } + data += nn; + } +} + +void ntt_iter_first_red(__m256i* const begin, const __m256i* const end, const q120_ntt_step_precomp* const itData, + const __m256i* powomega, const q120_ntt_reduc_step_precomp* const reduc_precomp) { + const uint64_t h = itData->half_bs; + const __m256i vmask = _mm256_set1_epi64x(itData->mask); + + const __m256i reduc_mask = _mm256_set1_epi64x(reduc_precomp->mask); + const __m256i reduc_cst = _mm256_loadu_si256((__m256i*)reduc_precomp->modulo_red_cst); + + __m256i* data = begin; + while (data < end) { + __m256i x = _mm256_loadu_si256(data); + x = modq_red(x, reduc_precomp->h, reduc_mask, reduc_cst); + __m256i po = _mm256_loadu_si256(powomega); + __m256i r = split_precompmul_si256(x, po, h, vmask); + _mm256_storeu_si256(data, r); + + data++; + powomega++; + } +} + +EXPORT void q120_intt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data_ptr) { + // assert((size_t)data_ptr % 32 == 0); // alignment check + + const uint64_t n = precomp->n; + if (n == 1) return; + + const q120_ntt_step_precomp* itData = precomp->level_metadata; + const __m256i* powomega = (__m256i*)precomp->powomega; + + __m256i* const begin = (__m256i*)data_ptr; + const __m256i* const end = ((__m256i*)data_ptr) + n; + + if (CHECK_BOUNDS) { + double bs __attribute__((unused)) = max_bit_size((void*)begin, (void*)end); + LOG("Input %lf %" PRIu64 "\n", bs, precomp->input_bit_size); + assert(bs <= precomp->input_bit_size); + } + + const uint64_t split_nn = (CHANGE_MODE_N > n) ? n : CHANGE_MODE_N; + + // computations by memory block + if (split_nn >= 2) { + const q120_ntt_step_precomp* itData1 = itData; + const __m256i* powomega1 = powomega; + for (__m256i* it = begin; it < end; it += split_nn) { + __m256i* const begin1 = it; + const __m256i* const end1 = it + split_nn; + + itData = itData1; + powomega = powomega1; + for (uint64_t nn = 2; nn <= split_nn; nn *= 2) { + const uint64_t halfnn = nn / 2; + + if (itData->reduce) { + intt_iter_red(nn, begin1, end1, itData, powomega, &precomp->reduc_metadata); + } else { + intt_iter(nn, begin1, end1, itData, powomega); + } + + if (CHECK_BOUNDS) { + double bs __attribute__((unused)) = max_bit_size((uint64_t*)begin1, (uint64_t*)end1); + // LOG("Iter %3lu - %lf %lu\n", nn / 2, bs, itData->bs); + assert(bs < itData->bs); + } + + powomega += halfnn - 1; + itData++; + } + } + } + + // computations by level + // for (uint64_t nn = 2; nn <= n; nn *= 2) { + for (uint64_t nn = 2 * split_nn; nn <= n; nn *= 2) { + const uint64_t halfnn = nn / 2; + + if (itData->reduce) { + intt_iter_red(nn, begin, end, itData, powomega, &precomp->reduc_metadata); + } else { + intt_iter(nn, begin, end, itData, powomega); + } + + if (CHECK_BOUNDS) { + double bs __attribute__((unused)) = max_bit_size((void*)begin, (void*)end); + LOG("Iter %3" PRIu64 " - %lf %" PRIu64 " %c\n", nn / 2, bs, itData->bs, itData->reduce ? 
'*' : ' '); + assert(bs < itData->bs); + } + + powomega += halfnn - 1; + itData++; + } + + // last iteration a_k . omega^k . n^-1 + if (itData->reduce) { + ntt_iter_first_red(begin, end, itData, powomega, &precomp->reduc_metadata); + } else { + ntt_iter_first(begin, end, itData, powomega); + } + + if (CHECK_BOUNDS) { + double bs __attribute__((unused)) = max_bit_size((void*)begin, (void*)end); + LOG("Iter %3" PRIu64 " - %lf %" PRIu64 "\n", n, bs, itData->bs); + assert(bs < itData->bs); + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_ntt_private.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_ntt_private.h new file mode 100644 index 0000000..c727ecd --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/q120/q120_ntt_private.h @@ -0,0 +1,39 @@ +#include "q120_ntt.h" + +#ifndef NDEBUG +#define CHECK_BOUNDS 1 +#define VERBOSE +#else +#define CHECK_BOUNDS 0 +#endif + +#ifndef VERBOSE +#define LOG(...) ; +#else +#define LOG(...) printf(__VA_ARGS__); +#endif + +typedef struct _q120_ntt_step_precomp { + uint64_t q2bs[4]; // q2bs = 2^{bs-31}.q[k] + uint64_t bs; // inputs at this iterations must be in Q_n + uint64_t half_bs; // == ceil(bs/2) + uint64_t mask; // (1< +#include + +#include "../commons_private.h" +#include "reim_fft_internal.h" +#include "reim_fft_private.h" + +void reim_from_znx64_ref(const REIM_FROM_ZNX64_PRECOMP* precomp, void* r, const int64_t* x) { + // naive version of the function (just cast) + const uint64_t nn = precomp->m << 1; + double* res = (double*)r; + for (uint64_t i = 0; i < nn; ++i) { + res[i] = (double)x[i]; + } +} + +void* init_reim_from_znx64_precomp(REIM_FROM_ZNX64_PRECOMP* const res, uint32_t m, uint32_t log2bound) { + if (m & (m - 1)) return spqlios_error("m must be a power of 2"); + // currently, we are going to use the trick add 3.2^51, mask the exponent, reinterpret bits as double. + // therefore we need the input values to be < 2^50. + if (log2bound > 50) return spqlios_error("Invalid log2bound error: must be in [0,50]"); + res->m = m; + FROM_ZNX64_FUNC resf = reim_from_znx64_ref; + if (m >= 8) { + if (CPU_SUPPORTS("avx2")) { + resf = reim_from_znx64_bnd50_fma; + } + } + res->function = resf; + return res; +} + +EXPORT REIM_FROM_ZNX64_PRECOMP* new_reim_from_znx64_precomp(uint32_t m, uint32_t log2bound) { + REIM_FROM_ZNX64_PRECOMP* res = malloc(sizeof(*res)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_reim_from_znx64_precomp(res, m, log2bound)); +} + +EXPORT void reim_from_znx64(const REIM_FROM_ZNX64_PRECOMP* tables, void* r, const int64_t* a) { + tables->function(tables, r, a); +} + +/** + * @brief Simpler API for the znx64 to reim conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. 
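+ * (the cache is a static table with one entry per power-of-two dimension, indexed by log2(m))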
+ * It is advised to do one dry-run call per desired dimension before using in a multithread environment + */ +EXPORT void reim_from_znx64_simple(uint32_t m, uint32_t log2bound, void* r, const int64_t* a) { + // not checking for log2bound which is not relevant here + static REIM_FROM_ZNX64_PRECOMP precomp[32]; + REIM_FROM_ZNX64_PRECOMP* p = precomp + log2m(m); + if (!p->function) { + if (!init_reim_from_znx64_precomp(p, m, log2bound)) abort(); + } + p->function(p, r, a); +} + +void reim_from_znx32_ref(const REIM_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) { NOT_IMPLEMENTED(); } +void reim_from_znx32_avx2_fma(const REIM_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) { NOT_IMPLEMENTED(); } + +void* init_reim_from_znx32_precomp(REIM_FROM_ZNX32_PRECOMP* const res, uint32_t m, uint32_t log2bound) { + if (m & (m - 1)) return spqlios_error("m must be a power of 2"); + if (log2bound > 32) return spqlios_error("Invalid log2bound error: must be in [0,32]"); + res->m = m; + // TODO: check selection logic + if (CPU_SUPPORTS("avx2")) { + if (m >= 8) { + res->function = reim_from_znx32_avx2_fma; + } else { + res->function = reim_from_znx32_ref; + } + } else { + res->function = reim_from_znx32_ref; + } + return res; +} + +EXPORT REIM_FROM_ZNX32_PRECOMP* new_reim_from_znx32_precomp(uint32_t m, uint32_t log2bound) { + REIM_FROM_ZNX32_PRECOMP* res = malloc(sizeof(*res)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_reim_from_znx32_precomp(res, m, log2bound)); +} + +void reim_from_tnx32_ref(const REIM_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) { NOT_IMPLEMENTED(); } +void reim_from_tnx32_avx2_fma(const REIM_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) { NOT_IMPLEMENTED(); } + +void* init_reim_from_tnx32_precomp(REIM_FROM_TNX32_PRECOMP* const res, uint32_t m) { + if (m & (m - 1)) return spqlios_error("m must be a power of 2"); + res->m = m; + // TODO: check selection logic + if (CPU_SUPPORTS("avx2")) { + if (m >= 8) { + res->function = reim_from_tnx32_avx2_fma; + } else { + res->function = reim_from_tnx32_ref; + } + } else { + res->function = reim_from_tnx32_ref; + } + return res; +} + +EXPORT REIM_FROM_TNX32_PRECOMP* new_reim_from_tnx32_precomp(uint32_t m) { + REIM_FROM_TNX32_PRECOMP* res = malloc(sizeof(*res)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_reim_from_tnx32_precomp(res, m)); +} + +void reim_to_tnx32_ref(const REIM_TO_TNX32_PRECOMP* precomp, int32_t* r, const void* x) { NOT_IMPLEMENTED(); } +void reim_to_tnx32_avx2_fma(const REIM_TO_TNX32_PRECOMP* precomp, int32_t* r, const void* x) { NOT_IMPLEMENTED(); } + +void* init_reim_to_tnx32_precomp(REIM_TO_TNX32_PRECOMP* const res, uint32_t m, double divisor, uint32_t log2overhead) { + if (m & (m - 1)) return spqlios_error("m must be a power of 2"); + if (is_not_pow2_double(&divisor)) return spqlios_error("divisor must be a power of 2"); + if (log2overhead > 52) return spqlios_error("log2overhead is too large"); + res->m = m; + res->divisor = divisor; + // TODO: check selection logic + if (CPU_SUPPORTS("avx2")) { + if (log2overhead <= 18) { + if (m >= 8) { + res->function = reim_to_tnx32_avx2_fma; + } else { + res->function = reim_to_tnx32_ref; + } + } else { + res->function = reim_to_tnx32_ref; + } + } else { + res->function = reim_to_tnx32_ref; + } + return res; +} + +EXPORT REIM_TO_TNX32_PRECOMP* new_reim_to_tnx32_precomp(uint32_t m, double divisor, uint32_t log2overhead) { + REIM_TO_TNX32_PRECOMP* res = 
malloc(sizeof(*res)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_reim_to_tnx32_precomp(res, m, divisor, log2overhead)); +} + +EXPORT void reim_from_znx32_simple(uint32_t m, uint32_t log2bound, void* r, const int32_t* x) { + static REIM_FROM_ZNX32_PRECOMP* p[31] = {0}; + REIM_FROM_ZNX32_PRECOMP** f = p + log2m(m); + if (!*f) *f = new_reim_from_znx32_precomp(m, log2bound); + (*f)->function(*f, r, x); +} + +EXPORT void reim_from_tnx32_simple(uint32_t m, void* r, const int32_t* x) { + static REIM_FROM_TNX32_PRECOMP* p[31] = {0}; + REIM_FROM_TNX32_PRECOMP** f = p + log2m(m); + if (!*f) *f = new_reim_from_tnx32_precomp(m); + (*f)->function(*f, r, x); +} + +EXPORT void reim_to_tnx32_simple(uint32_t m, double divisor, uint32_t log2overhead, int32_t* r, const void* x) { + static REIM_TO_TNX32_PRECOMP* p[31] = {0}; + REIM_TO_TNX32_PRECOMP** f = p + log2m(m); + if (!*f) *f = new_reim_to_tnx32_precomp(m, divisor, log2overhead); + (*f)->function(*f, r, x); +} + +void reim_to_znx64_ref(const REIM_TO_ZNX64_PRECOMP* precomp, int64_t* r, const void* x) { + // for now, we stick to a slow implem + uint64_t nn = precomp->m << 1; + const double* v = (double*)x; + double invdiv = 1. / precomp->divisor; + for (uint64_t i = 0; i < nn; ++i) { + r[i] = (int64_t)rint(v[i] * invdiv); + } +} + +void* init_reim_to_znx64_precomp(REIM_TO_ZNX64_PRECOMP* const res, uint32_t m, double divisor, uint32_t log2bound) { + if (m & (m - 1)) return spqlios_error("m must be a power of 2"); + if (is_not_pow2_double(&divisor)) return spqlios_error("divisor must be a power of 2"); + if (log2bound > 64) return spqlios_error("log2bound is too large"); + res->m = m; + res->divisor = divisor; + TO_ZNX64_FUNC resf = reim_to_znx64_ref; + if (CPU_SUPPORTS("avx2") && m >= 8) { + if (log2bound <= 50) { + resf = reim_to_znx64_avx2_bnd50_fma; + } else { + resf = reim_to_znx64_avx2_bnd63_fma; + } + } + res->function = resf; // must be the last one set + return res; +} + +EXPORT REIM_TO_ZNX64_PRECOMP* new_reim_to_znx64_precomp(uint32_t m, double divisor, uint32_t log2bound) { + REIM_TO_ZNX64_PRECOMP* res = malloc(sizeof(*res)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_reim_to_znx64_precomp(res, m, divisor, log2bound)); +} + +EXPORT void reim_to_znx64(const REIM_TO_ZNX64_PRECOMP* precomp, int64_t* r, const void* a) { + precomp->function(precomp, r, a); +} + +/** + * @brief Simpler API for the znx64 to reim conversion. 
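+ * The tables are cached per thread and rebuilt whenever m, divisor or log2bound change.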
+ */ +EXPORT void reim_to_znx64_simple(uint32_t m, double divisor, uint32_t log2bound, int64_t* r, const void* a) { + // not checking distinguishing <=50 or not + static __thread REIM_TO_ZNX64_PRECOMP p; + static __thread uint32_t prev_log2bound; + if (!p.function || p.m != m || p.divisor != divisor || prev_log2bound != log2bound) { + if (!init_reim_to_znx64_precomp(&p, m, divisor, log2bound)) abort(); + prev_log2bound = log2bound; + } + p.function(&p, r, a); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_conversions_avx.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_conversions_avx.c new file mode 100644 index 0000000..30f98c9 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_conversions_avx.c @@ -0,0 +1,106 @@ +#include + +#include "reim_fft_private.h" + +void reim_from_znx64_bnd50_fma(const REIM_FROM_ZNX64_PRECOMP* precomp, void* r, const int64_t* x) { + static const double EXPO = INT64_C(1) << 52; + const int64_t ADD_CST = INT64_C(1) << 51; + const double SUB_CST = INT64_C(3) << 51; + + const __m256d SUB_CST_4 = _mm256_set1_pd(SUB_CST); + const __m256i ADD_CST_4 = _mm256_set1_epi64x(ADD_CST); + const __m256d EXPO_4 = _mm256_set1_pd(EXPO); + + double(*out)[4] = (double(*)[4])r; + __m256i* in = (__m256i*)x; + __m256i* inend = (__m256i*)(x + (precomp->m << 1)); + do { + // read the next value + __m256i a = _mm256_loadu_si256(in); + a = _mm256_add_epi64(a, ADD_CST_4); + __m256d ad = _mm256_castsi256_pd(a); + ad = _mm256_or_pd(ad, EXPO_4); + ad = _mm256_sub_pd(ad, SUB_CST_4); + // store the next value + _mm256_storeu_pd(out[0], ad); + ++out; + ++in; + } while (in < inend); +} + +// version where the output norm can be as big as 2^63 +void reim_to_znx64_avx2_bnd63_fma(const REIM_TO_ZNX64_PRECOMP* precomp, int64_t* r, const void* x) { + static const uint64_t SIGN_MASK = 0x8000000000000000UL; + static const uint64_t EXPO_MASK = 0x7FF0000000000000UL; + static const uint64_t MANTISSA_MASK = 0x000FFFFFFFFFFFFFUL; + static const uint64_t MANTISSA_MSB = 0x0010000000000000UL; + const double divisor_bits = precomp->divisor * ((double)(INT64_C(1) << 52)); + const double offset = precomp->divisor / 2.; + + const __m256d SIGN_MASK_4 = _mm256_castsi256_pd(_mm256_set1_epi64x(SIGN_MASK)); + const __m256i EXPO_MASK_4 = _mm256_set1_epi64x(EXPO_MASK); + const __m256i MANTISSA_MASK_4 = _mm256_set1_epi64x(MANTISSA_MASK); + const __m256i MANTISSA_MSB_4 = _mm256_set1_epi64x(MANTISSA_MSB); + const __m256d offset_4 = _mm256_set1_pd(offset); + const __m256i divi_bits_4 = _mm256_castpd_si256(_mm256_set1_pd(divisor_bits)); + + double(*in)[4] = (double(*)[4])x; + __m256i* out = (__m256i*)r; + __m256i* outend = (__m256i*)(r + (precomp->m << 1)); + do { + // read the next value + __m256d a = _mm256_loadu_pd(in[0]); + // a += sign(a) * m/2 + __m256d asign = _mm256_and_pd(a, SIGN_MASK_4); + a = _mm256_add_pd(a, _mm256_or_pd(asign, offset_4)); + // sign: either 0 or -1 + __m256i sign_mask = _mm256_castpd_si256(asign); + sign_mask = _mm256_sub_epi64(_mm256_set1_epi64x(0), _mm256_srli_epi64(sign_mask, 63)); + // compute the exponents + __m256i a0exp = _mm256_and_si256(_mm256_castpd_si256(a), EXPO_MASK_4); + __m256i a0lsh = _mm256_sub_epi64(a0exp, divi_bits_4); + __m256i a0rsh = _mm256_sub_epi64(divi_bits_4, a0exp); + a0lsh = _mm256_srli_epi64(a0lsh, 52); + a0rsh = _mm256_srli_epi64(a0rsh, 52); + // compute the new mantissa + __m256i a0pos = _mm256_and_si256(_mm256_castpd_si256(a), MANTISSA_MASK_4); + a0pos = 
_mm256_or_si256(a0pos, MANTISSA_MSB_4); + a0lsh = _mm256_sllv_epi64(a0pos, a0lsh); + a0rsh = _mm256_srlv_epi64(a0pos, a0rsh); + __m256i final = _mm256_or_si256(a0lsh, a0rsh); + // negate if the sign was negative + final = _mm256_xor_si256(final, sign_mask); + final = _mm256_sub_epi64(final, sign_mask); + // read the next value + _mm256_storeu_si256(out, final); + ++out; + ++in; + } while (out < outend); +} + +// version where the output norm can be as big as 2^50 +void reim_to_znx64_avx2_bnd50_fma(const REIM_TO_ZNX64_PRECOMP* precomp, int64_t* r, const void* x) { + static const uint64_t MANTISSA_MASK = 0x000FFFFFFFFFFFFFUL; + const int64_t SUB_CST = INT64_C(1) << 51; + const double add_cst = precomp->divisor * ((double)(INT64_C(3) << 51)); + + const __m256i SUB_CST_4 = _mm256_set1_epi64x(SUB_CST); + const __m256d add_cst_4 = _mm256_set1_pd(add_cst); + const __m256i MANTISSA_MASK_4 = _mm256_set1_epi64x(MANTISSA_MASK); + + double(*in)[4] = (double(*)[4])x; + __m256i* out = (__m256i*)r; + __m256i* outend = (__m256i*)(r + (precomp->m << 1)); + do { + // read the next value + __m256d a = _mm256_loadu_pd(in[0]); + a = _mm256_add_pd(a, add_cst_4); + __m256i ai = _mm256_castpd_si256(a); + ai = _mm256_and_si256(ai, MANTISSA_MASK_4); + ai = _mm256_sub_epi64(ai, SUB_CST_4); + // store the next value + _mm256_storeu_si256(out, ai); + ++out; + ++in; + } while (out < outend); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_execute.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_execute.c new file mode 100644 index 0000000..ace7721 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_execute.c @@ -0,0 +1,44 @@ +#include "reim_fft_internal.h" +#include "reim_fft_private.h" + +EXPORT void reim_from_znx32(const REIM_FROM_ZNX32_PRECOMP* tables, void* r, const int32_t* a) { + tables->function(tables, r, a); +} + +EXPORT void reim_from_tnx32(const REIM_FROM_TNX32_PRECOMP* tables, void* r, const int32_t* a) { + tables->function(tables, r, a); +} + +EXPORT void reim_to_tnx32(const REIM_TO_TNX32_PRECOMP* tables, int32_t* r, const void* a) { + tables->function(tables, r, a); +} + +EXPORT void reim_fftvec_add(const REIM_FFTVEC_ADD_PRECOMP* tables, double* r, const double* a, const double* b) { + tables->function(tables, r, a, b); +} + +EXPORT void reim_fftvec_sub(const REIM_FFTVEC_SUB_PRECOMP* tables, double* r, const double* a, const double* b) { + tables->function(tables, r, a, b); +} + +EXPORT void reim_fftvec_mul(const REIM_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b) { + tables->function(tables, r, a, b); +} + +EXPORT void reim_fftvec_addmul(const REIM_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, const double* b) { + tables->function(tables, r, a, b); +} + +EXPORT void reim_fftvec_automorphism(const REIM_FFTVEC_AUTOMORPHISM_PRECOMP* tables, int64_t p, double* r, + const double* a, uint64_t a_size) { + tables->function.apply(tables, p, r, a, a_size); +} + +EXPORT void reim_fftvec_automorphism_inplace(const REIM_FFTVEC_AUTOMORPHISM_PRECOMP* tables, int64_t p, double* a, + uint64_t a_size, uint8_t* tmp_bytes) { + tables->function.apply_inplace(tables, p, a, a_size, tmp_bytes); +} + +EXPORT uint64_t reim_fftvec_automorphism_inplace_tmp_bytes(const REIM_FFTVEC_AUTOMORPHISM_PRECOMP* tables) { + return tables->function.apply_inplace_tmp_bytes(tables); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fallbacks_aarch64.c 
b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fallbacks_aarch64.c new file mode 100644 index 0000000..f8c364e --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fallbacks_aarch64.c @@ -0,0 +1,21 @@ +#include "reim_fft_private.h" + +EXPORT void reim_fftvec_addmul_fma(const REIM_FFTVEC_ADDMUL_PRECOMP* precomp, double* r, const double* a, + const double* b) { + UNDEFINED(); +} +EXPORT void reim_fftvec_mul_fma(const REIM_FFTVEC_MUL_PRECOMP* precomp, double* r, const double* a, const double* b) { + UNDEFINED(); +} +EXPORT void reim_fftvec_add_fma(const REIM_FFTVEC_ADD_PRECOMP* precomp, double* r, const double* a, const double* b) { + UNDEFINED(); +} +EXPORT void reim_fftvec_sub_fma(const REIM_FFTVEC_SUB_PRECOMP* precomp, double* r, const double* a, const double* b) { + UNDEFINED(); +} + +EXPORT void reim_fft_avx2_fma(const REIM_FFT_PRECOMP* tables, double* data) { UNDEFINED(); } +EXPORT void reim_ifft_avx2_fma(const REIM_IFFT_PRECOMP* tables, double* data) { UNDEFINED(); } + +// EXPORT void reim_fft(const REIM_FFT_PRECOMP* tables, double* data) { tables->function(tables, data); } +// EXPORT void reim_ifft(const REIM_IFFT_PRECOMP* tables, double* data) { tables->function(tables, data); } diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft.h new file mode 100644 index 0000000..582413b --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft.h @@ -0,0 +1,234 @@ +#ifndef SPQLIOS_REIM_FFT_H +#define SPQLIOS_REIM_FFT_H + +#include "../commons.h" + +typedef struct reim_fft_precomp REIM_FFT_PRECOMP; +typedef struct reim_ifft_precomp REIM_IFFT_PRECOMP; +typedef struct reim_add_precomp REIM_FFTVEC_ADD_PRECOMP; +typedef struct reim_sub_precomp REIM_FFTVEC_SUB_PRECOMP; +typedef struct reim_mul_precomp REIM_FFTVEC_MUL_PRECOMP; +typedef struct reim_addmul_precomp REIM_FFTVEC_ADDMUL_PRECOMP; +typedef struct reim_fftvec_automorphism_precomp REIM_FFTVEC_AUTOMORPHISM_PRECOMP; +typedef struct reim_from_znx32_precomp REIM_FROM_ZNX32_PRECOMP; +typedef struct reim_from_znx64_precomp REIM_FROM_ZNX64_PRECOMP; +typedef struct reim_from_tnx32_precomp REIM_FROM_TNX32_PRECOMP; +typedef struct reim_to_tnx32_precomp REIM_TO_TNX32_PRECOMP; +typedef struct reim_to_tnx_precomp REIM_TO_TNX_PRECOMP; +typedef struct reim_to_znx64_precomp REIM_TO_ZNX64_PRECOMP; + +/** + * @brief precomputes fft tables. + * The FFT tables contains a constant section that is required for efficient FFT operations in dimension nn. + * The resulting pointer is to be passed as "tables" argument to any call to the fft function. + * The user can optionnally allocate zero or more computation buffers, which are scratch spaces that are contiguous to + * the constant tables in memory, and allow for more efficient operations. It is the user's responsibility to ensure + * that each of those buffers are never used simultaneously by two ffts on different threads at the same time. The fft + * table must be deleted by delete_fft_precomp after its last usage. + */ +EXPORT REIM_FFT_PRECOMP* new_reim_fft_precomp(uint32_t m, uint32_t num_buffers); + +/** + * @brief gets the address of a fft buffer allocated during new_fft_precomp. + * This buffer can be used as data pointer in subsequent calls to fft, + * and does not need to be released afterwards. 
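+ * (built-in buffers live inside the precomp allocation and are freed together with it by delete_reim_fft_precomp)
+ * Minimal usage sketch (illustration only, not part of the original header; data layout details elided):
+ *   REIM_FFT_PRECOMP* p = new_reim_fft_precomp(m, 1);
+ *   double* buf = reim_fft_precomp_get_buffer(p, 0);
+ *   // fill buf with the m complex coefficients in the reim layout, then:
+ *   reim_fft(p, buf);
+ *   delete_reim_fft_precomp(p);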
+ */ +EXPORT double* reim_fft_precomp_get_buffer(const REIM_FFT_PRECOMP* tables, uint32_t buffer_index); + +/** + * @brief allocates a new fft buffer. + * This buffer can be used as data pointer in subsequent calls to fft, + * and must be deleted afterwards by calling delete_fft_buffer. + */ +EXPORT double* new_reim_fft_buffer(uint32_t m); + +/** + * @brief allocates a new fft buffer. + * This buffer can be used as data pointer in subsequent calls to fft, + * and must be deleted afterwards by calling delete_fft_buffer. + */ +EXPORT void delete_reim_fft_buffer(double* buffer); + +/** + * @brief deallocates a fft table and all its built-in buffers. + */ +#define delete_reim_fft_precomp free + +/** + * @brief computes a direct fft in-place over data. + */ +EXPORT void reim_fft(const REIM_FFT_PRECOMP* tables, double* data); + +EXPORT REIM_IFFT_PRECOMP* new_reim_ifft_precomp(uint32_t m, uint32_t num_buffers); +EXPORT double* reim_ifft_precomp_get_buffer(const REIM_IFFT_PRECOMP* tables, uint32_t buffer_index); +EXPORT void reim_ifft(const REIM_IFFT_PRECOMP* tables, double* data); +#define delete_reim_ifft_precomp free + +EXPORT REIM_FFTVEC_ADD_PRECOMP* new_reim_fftvec_add_precomp(uint32_t m); +EXPORT void reim_fftvec_add(const REIM_FFTVEC_ADD_PRECOMP* tables, double* r, const double* a, const double* b); +#define delete_reim_fftvec_add_precomp free + +EXPORT REIM_FFTVEC_SUB_PRECOMP* new_reim_fftvec_sub_precomp(uint32_t m); +EXPORT void reim_fftvec_sub(const REIM_FFTVEC_SUB_PRECOMP* tables, double* r, const double* a, const double* b); +#define delete_reim_fftvec_sub_precomp free + +EXPORT REIM_FFTVEC_MUL_PRECOMP* new_reim_fftvec_mul_precomp(uint32_t m); +EXPORT void reim_fftvec_mul(const REIM_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b); +#define delete_reim_fftvec_mul_precomp free + +EXPORT REIM_FFTVEC_ADDMUL_PRECOMP* new_reim_fftvec_addmul_precomp(uint32_t m); +EXPORT void reim_fftvec_addmul(const REIM_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, const double* b); +#define delete_reim_fftvec_addmul_precomp free + +EXPORT REIM_FFTVEC_AUTOMORPHISM_PRECOMP* new_reim_fftvec_automorphism_precomp(uint32_t m); +EXPORT void reim_fftvec_automorphism(const REIM_FFTVEC_AUTOMORPHISM_PRECOMP* tables, int64_t p, double* r, + const double* a, uint64_t a_size); + +EXPORT void reim_fftvec_automorphism_inplace(const REIM_FFTVEC_AUTOMORPHISM_PRECOMP* tables, int64_t p, double* a, + uint64_t a_size, uint8_t* tmp_bytes); + +EXPORT uint64_t reim_fftvec_automorphism_inplace_tmp_bytes(const REIM_FFTVEC_AUTOMORPHISM_PRECOMP* tables); + +#define delete_reim_fftvec_automorphism_precomp free + +/** + * @brief prepares a conversion from ZnX to the cplx layout. + * All the coefficients must be strictly lower than 2^log2bound in absolute value. Any attempt to use + * this function on a larger coefficient is undefined behaviour. The resulting precomputed data must + * be freed with `new_reim_from_znx32_precomp` + * @param m the target complex dimension m from C[X] mod X^m-i. Note that the inputs have n=2m + * int32 coefficients in natural order modulo X^n+1 + * @param log2bound bound on the input coefficients. Must be between 0 and 32 + */ +EXPORT REIM_FROM_ZNX32_PRECOMP* new_reim_from_znx32_precomp(uint32_t m, uint32_t log2bound); + +/** + * @brief converts from ZnX to the cplx layout. + * @param tables precomputed data obtained by new_reim_from_znx32_precomp. 
+ * @param r resulting array of m complexes coefficients mod X^m-i + * @param x input array of n bounded integer coefficients mod X^n+1 + */ +EXPORT void reim_from_znx32(const REIM_FROM_ZNX32_PRECOMP* tables, void* r, const int32_t* a); +/** @brief frees a precomputed conversion data initialized with new_reim_from_znx32_precomp. */ +#define delete_reim_from_znx32_precomp free + +/** + * @brief converts from ZnX to the cplx layout. + * @param tables precomputed data obtained by new_reim_from_znx64_precomp. + * @param r resulting array of m complexes coefficients mod X^m-i + * @param x input array of n bounded integer coefficients mod X^n+1 + */ +EXPORT void reim_from_znx64(const REIM_FROM_ZNX64_PRECOMP* tables, void* r, const int64_t* a); +/** @brief frees a precomputed conversion data initialized with new_reim_from_znx32_precomp. */ +EXPORT REIM_FROM_ZNX64_PRECOMP* new_reim_from_znx64_precomp(uint32_t m, uint32_t maxbnd); +#define delete_reim_from_znx64_precomp free +EXPORT void reim_from_znx64_simple(uint32_t m, uint32_t log2bound, void* r, const int64_t* a); + +/** + * @brief prepares a conversion from TnX to the cplx layout. + * @param m the target complex dimension m from C[X] mod X^m-i. Note that the inputs have n=2m + * torus32 coefficients. The resulting precomputed data must + * be freed with `delete_reim_from_tnx32_precomp` + */ + +EXPORT REIM_FROM_TNX32_PRECOMP* new_reim_from_tnx32_precomp(uint32_t m); +/** + * @brief converts from TnX to the cplx layout. + * @param tables precomputed data obtained by new_reim_from_tnx32_precomp. + * @param r resulting array of m complexes coefficients mod X^m-i + * @param x input array of n torus32 coefficients mod X^n+1 + */ +EXPORT void reim_from_tnx32(const REIM_FROM_TNX32_PRECOMP* tables, void* r, const int32_t* a); +/** @brief frees a precomputed conversion data initialized with new_reim_from_tnx32_precomp. */ +#define delete_reim_from_tnx32_precomp free + +/** + * @brief prepares a rescale and conversion from the cplx layout to TnX. + * @param m the target complex dimension m from C[X] mod X^m-i. Note that the outputs have n=2m + * torus32 coefficients. + * @param divisor must be a power of two. The inputs are rescaled by divisor before being reduced modulo 1. + * Remember that the output of an iFFT must be divided by m. + * @param log2overhead all inputs absolute values must be within divisor.2^log2overhead. + * For any inputs outside of these bounds, the conversion is undefined behaviour. + * The maximum supported log2overhead is 52, and the algorithm is faster for log2overhead=18. + */ +EXPORT REIM_TO_TNX32_PRECOMP* new_reim_to_tnx32_precomp(uint32_t m, double divisor, uint32_t log2overhead); + +/** + * @brief rescale, converts and reduce mod 1 from cplx layout to torus32. + * @param tables precomputed data obtained by new_reim_from_tnx32_precomp. + * @param r resulting array of n torus32 coefficients mod X^n+1 + * @param x input array of m cplx coefficients mod X^m-i + */ +EXPORT void reim_to_tnx32(const REIM_TO_TNX32_PRECOMP* tables, int32_t* r, const void* a); +#define delete_reim_to_tnx32_precomp free + +/** + * @brief prepares a rescale and conversion from the cplx layout to TnX (doubles). + * @param m the target complex dimension m from C[X] mod X^m-i. Note that the outputs have n=2m + * torus32 coefficients. + * @param divisor must be a power of two. The inputs are rescaled by divisor before being reduced modulo 1. + * Remember that the output of an iFFT must be divided by m. 
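+ * (for example, divisor = m exactly cancels the iFFT scaling when no additional rescaling is needed)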
+ * @param log2overhead all inputs absolute values must be within divisor.2^log2overhead. + * For any inputs outside of these bounds, the conversion is undefined behaviour. + * The maximum supported log2overhead is 52, and the algorithm is faster for log2overhead=18. + */ +EXPORT REIM_TO_TNX_PRECOMP* new_reim_to_tnx_precomp(uint32_t m, double divisor, uint32_t log2overhead); +/** + * @brief rescale, converts and reduce mod 1 from cplx layout to torus32. + * @param tables precomputed data obtained by new_reim_from_tnx32_precomp. + * @param r resulting array of n torus32 coefficients mod X^n+1 + * @param x input array of m cplx coefficients mod X^m-i + */ +EXPORT void reim_to_tnx(const REIM_TO_TNX_PRECOMP* tables, double* r, const double* a); +#define delete_reim_to_tnx_precomp free +EXPORT void reim_to_tnx_simple(uint32_t m, double divisor, uint32_t log2overhead, double* r, const double* a); + +EXPORT REIM_TO_ZNX64_PRECOMP* new_reim_to_znx64_precomp(uint32_t m, double divisor, uint32_t log2bound); +#define delete_reim_to_znx64_precomp free +EXPORT void reim_to_znx64(const REIM_TO_ZNX64_PRECOMP* precomp, int64_t* r, const void* a); +EXPORT void reim_to_znx64_simple(uint32_t m, double divisor, uint32_t log2bound, int64_t* r, const void* a); + +/** + * @brief Simpler API for the fft function. + * For each dimension, the precomputed tables for this dimension are generated automatically. + * It is advised to do one dry-run per desired dimension before using in a multithread environment */ +EXPORT void reim_fft_simple(uint32_t m, void* data); +/** + * @brief Simpler API for the ifft function. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension in the main thread before using in a multithread + * environment */ +EXPORT void reim_ifft_simple(uint32_t m, void* data); +/** + * @brief Simpler API for the fftvec addition function. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void reim_fftvec_add_simple(uint32_t m, void* r, const void* a, const void* b); +/** + * @brief Simpler API for the fftvec multiplication function. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void reim_fftvec_mul_simple(uint32_t m, void* r, const void* a, const void* b); +/** + * @brief Simpler API for the fftvec addmul function. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void reim_fftvec_addmul_simple(uint32_t m, void* r, const void* a, const void* b); +/** + * @brief Simpler API for the znx32 to cplx conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void reim_from_znx32_simple(uint32_t m, uint32_t log2bound, void* r, const int32_t* x); +/** + * @brief Simpler API for the tnx32 to cplx conversion. 
+ * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void reim_from_tnx32_simple(uint32_t m, void* r, const int32_t* x); +/** + * @brief Simpler API for the cplx to tnx32 conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void reim_to_tnx32_simple(uint32_t m, double divisor, uint32_t log2overhead, int32_t* r, const void* x); + +#endif // SPQLIOS_REIM_FFT_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft16_avx_fma.s b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft16_avx_fma.s new file mode 100644 index 0000000..e68012c --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft16_avx_fma.s @@ -0,0 +1,167 @@ +#rdi datare ptr +#rsi dataim ptr +#rdx om ptr +.globl reim_fft16_avx_fma +reim_fft16_avx_fma: +vmovupd (%rdi),%ymm0 # ra0 +vmovupd 0x20(%rdi),%ymm1 # ra4 +vmovupd 0x40(%rdi),%ymm2 # ra8 +vmovupd 0x60(%rdi),%ymm3 # ra12 +vmovupd (%rsi),%ymm4 # ia0 +vmovupd 0x20(%rsi),%ymm5 # ia4 +vmovupd 0x40(%rsi),%ymm6 # ia8 +vmovupd 0x60(%rsi),%ymm7 # ia12 + +vmovupd (%rdx),%xmm12 +vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar +vmulpd %ymm6,%ymm13,%ymm8 # ia0.omai +vmulpd %ymm7,%ymm13,%ymm9 # ia4.omai +vmulpd %ymm2,%ymm13,%ymm10 # ra0.omai +vmulpd %ymm3,%ymm13,%ymm11 # ra4.omai +vfmsub231pd %ymm2,%ymm12,%ymm8 # rprod0 +vfmsub231pd %ymm3,%ymm12,%ymm9 # rprod4 +vfmadd231pd %ymm6,%ymm12,%ymm10 # iprod0 +vfmadd231pd %ymm7,%ymm12,%ymm11 # iprod4 +vsubpd %ymm8,%ymm0,%ymm2 +vsubpd %ymm9,%ymm1,%ymm3 +vsubpd %ymm10,%ymm4,%ymm6 +vsubpd %ymm11,%ymm5,%ymm7 +vaddpd %ymm8,%ymm0,%ymm0 +vaddpd %ymm9,%ymm1,%ymm1 +vaddpd %ymm10,%ymm4,%ymm4 +vaddpd %ymm11,%ymm5,%ymm5 + +1: +vmovupd 16(%rdx),%xmm12 +vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar +vmulpd %ymm5,%ymm13,%ymm8 # ia0.omai (tw) +vmulpd %ymm7,%ymm12,%ymm9 # ia4.omar (itw) +vmulpd %ymm1,%ymm13,%ymm10 # ra0.omai (tw) +vmulpd %ymm3,%ymm12,%ymm11 # ra4.omar (itw) +vfmsub231pd %ymm1,%ymm12,%ymm8 # rprod0 (tw) +vfmadd231pd %ymm3,%ymm13,%ymm9 # rprod4 (itw) +vfmadd231pd %ymm5,%ymm12,%ymm10 # iprod0 (tw) +vfmsub231pd %ymm7,%ymm13,%ymm11 # iprod4 (itw) +vsubpd %ymm8,%ymm0,%ymm1 +vaddpd %ymm9,%ymm2,%ymm3 +vsubpd %ymm10,%ymm4,%ymm5 +vaddpd %ymm11,%ymm6,%ymm7 +vaddpd %ymm8,%ymm0,%ymm0 +vsubpd %ymm9,%ymm2,%ymm2 +vaddpd %ymm10,%ymm4,%ymm4 +vsubpd %ymm11,%ymm6,%ymm6 + +2: +vmovupd 0x20(%rdx),%ymm12 +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omaiii'i' +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omarrr'r' + +vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 # ymm8 contains re to mul (tw) +vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 # ymm9 contains re to mul (itw) +vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 # ymm10 contains im to mul (tw) +vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 # ymm11 contains im to mul (itw) +vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 # ymm0 contains re to add (tw) +vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 # ymm1 contains re to add (itw) +vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 # ymm2 contains im to add (tw) +vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 # ymm3 contains im 
to add (itw) + +vmulpd %ymm10,%ymm13,%ymm4 # ia0.omai (tw) +vmulpd %ymm11,%ymm12,%ymm5 # ia4.omar (itw) +vmulpd %ymm8,%ymm13,%ymm6 # ra0.omai (tw) +vmulpd %ymm9,%ymm12,%ymm7 # ra4.omar (itw) +vfmsub231pd %ymm8,%ymm12,%ymm4 # rprod0 (tw) +vfmadd231pd %ymm9,%ymm13,%ymm5 # rprod4 (itw) +vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0 (tw) +vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw) +vsubpd %ymm4,%ymm0,%ymm8 +vaddpd %ymm5,%ymm1,%ymm9 +vsubpd %ymm6,%ymm2,%ymm10 +vaddpd %ymm7,%ymm3,%ymm11 +vaddpd %ymm4,%ymm0,%ymm0 +vsubpd %ymm5,%ymm1,%ymm1 +vaddpd %ymm6,%ymm2,%ymm2 +vsubpd %ymm7,%ymm3,%ymm3 + +#vperm2f128 $0x31,%ymm10,%ymm2,%ymm6 +#vperm2f128 $0x31,%ymm11,%ymm3,%ymm7 +#vperm2f128 $0x20,%ymm10,%ymm2,%ymm4 +#vperm2f128 $0x20,%ymm11,%ymm3,%ymm5 +#vperm2f128 $0x31,%ymm8,%ymm0,%ymm2 +#vperm2f128 $0x31,%ymm9,%ymm1,%ymm3 +#vperm2f128 $0x20,%ymm8,%ymm0,%ymm0 +#vperm2f128 $0x20,%ymm9,%ymm1,%ymm1 + +3: +vmovupd 0x40(%rdx),%ymm12 +vmovupd 0x60(%rdx),%ymm13 + +#vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 # ymm8 contains re to mul (tw) +#vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 # ymm9 contains re to mul (itw) +#vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 # ymm10 contains im to mul (tw) +#vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 # ymm11 contains im to mul (itw) +#vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 # ymm0 contains re to add (tw) +#vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 # ymm1 contains re to add (itw) +#vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 # ymm2 contains im to add (tw) +#vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 # ymm3 contains im to add (itw) + +vunpckhpd %ymm1,%ymm0,%ymm4 # (0,1) -> (0,4) +vunpckhpd %ymm3,%ymm2,%ymm6 # (2,3) -> (2,6) +vunpckhpd %ymm9,%ymm8,%ymm5 # (8,9) -> (1,5) +vunpckhpd %ymm11,%ymm10,%ymm7 # (10,11) -> (3,7) +vunpcklpd %ymm1,%ymm0,%ymm0 +vunpcklpd %ymm3,%ymm2,%ymm2 +vunpcklpd %ymm9,%ymm8,%ymm1 +vunpcklpd %ymm11,%ymm10,%ymm3 + +vmulpd %ymm6,%ymm13,%ymm8 # ia0.omai (tw) +vmulpd %ymm7,%ymm12,%ymm9 # ia4.omar (itw) +vmulpd %ymm4,%ymm13,%ymm10 # ra0.omai (tw) +vmulpd %ymm5,%ymm12,%ymm11 # ra4.omar (itw) +vfmsub231pd %ymm4,%ymm12,%ymm8 # rprod0 (tw) +vfmadd231pd %ymm5,%ymm13,%ymm9 # rprod4 (itw) +vfmadd231pd %ymm6,%ymm12,%ymm10 # iprod0 (tw) +vfmsub231pd %ymm7,%ymm13,%ymm11 # iprod4 (itw) +vsubpd %ymm8,%ymm0,%ymm4 +vaddpd %ymm9,%ymm1,%ymm5 +vsubpd %ymm10,%ymm2,%ymm6 +vaddpd %ymm11,%ymm3,%ymm7 +vaddpd %ymm8,%ymm0,%ymm0 +vsubpd %ymm9,%ymm1,%ymm1 +vaddpd %ymm10,%ymm2,%ymm2 +vsubpd %ymm11,%ymm3,%ymm3 + +vunpckhpd %ymm7,%ymm3,%ymm11 # (0,4) -> (0,1) +vunpckhpd %ymm5,%ymm1,%ymm9 # (2,6) -> (2,3) +vunpcklpd %ymm7,%ymm3,%ymm10 +vunpcklpd %ymm5,%ymm1,%ymm8 +vunpckhpd %ymm6,%ymm2,%ymm3 # (1,5) -> (8,9) +vunpckhpd %ymm4,%ymm0,%ymm1 # (3,7) -> (10,11) +vunpcklpd %ymm6,%ymm2,%ymm2 +vunpcklpd %ymm4,%ymm0,%ymm0 + +vperm2f128 $0x31,%ymm10,%ymm2,%ymm6 +vperm2f128 $0x31,%ymm11,%ymm3,%ymm7 +vperm2f128 $0x20,%ymm10,%ymm2,%ymm4 +vperm2f128 $0x20,%ymm11,%ymm3,%ymm5 +vperm2f128 $0x31,%ymm8,%ymm0,%ymm2 +vperm2f128 $0x31,%ymm9,%ymm1,%ymm3 +vperm2f128 $0x20,%ymm8,%ymm0,%ymm0 +vperm2f128 $0x20,%ymm9,%ymm1,%ymm1 + +4: +vmovupd %ymm0,(%rdi) # ra0 +vmovupd %ymm1,0x20(%rdi) # ra4 +vmovupd %ymm2,0x40(%rdi) # ra8 +vmovupd %ymm3,0x60(%rdi) # ra12 +vmovupd %ymm4,(%rsi) # ia0 +vmovupd %ymm5,0x20(%rsi) # ia4 +vmovupd %ymm6,0x40(%rsi) # ia8 +vmovupd %ymm7,0x60(%rsi) # ia12 +vzeroupper +ret +.size reim_fft16_avx_fma, .-reim_fft16_avx_fma +.section .note.GNU-stack,"",@progbits diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft16_avx_fma_win32.s b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft16_avx_fma_win32.s new file 
mode 100644 index 0000000..add742e --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft16_avx_fma_win32.s @@ -0,0 +1,203 @@ + .text + .p2align 4 + .globl reim_fft16_avx_fma + .def reim_fft16_avx_fma; .scl 2; .type 32; .endef +reim_fft16_avx_fma: + + pushq %rdi + pushq %rsi + movq %rcx,%rdi + movq %rdx,%rsi + movq %r8,%rdx + subq $0x100,%rsp + movdqu %xmm6,(%rsp) + movdqu %xmm7,0x10(%rsp) + movdqu %xmm8,0x20(%rsp) + movdqu %xmm9,0x30(%rsp) + movdqu %xmm10,0x40(%rsp) + movdqu %xmm11,0x50(%rsp) + movdqu %xmm12,0x60(%rsp) + movdqu %xmm13,0x70(%rsp) + movdqu %xmm14,0x80(%rsp) + movdqu %xmm15,0x90(%rsp) + callq reim_fft16_avx_fma_amd64 + movdqu (%rsp),%xmm6 + movdqu 0x10(%rsp),%xmm7 + movdqu 0x20(%rsp),%xmm8 + movdqu 0x30(%rsp),%xmm9 + movdqu 0x40(%rsp),%xmm10 + movdqu 0x50(%rsp),%xmm11 + movdqu 0x60(%rsp),%xmm12 + movdqu 0x70(%rsp),%xmm13 + movdqu 0x80(%rsp),%xmm14 + movdqu 0x90(%rsp),%xmm15 + addq $0x100,%rsp + popq %rsi + popq %rdi + retq + +#rdi datare ptr +#rsi dataim ptr +#rdx om ptr +#.globl reim_fft16_avx_fma_amd64 +reim_fft16_avx_fma_amd64: +vmovupd (%rdi),%ymm0 # ra0 +vmovupd 0x20(%rdi),%ymm1 # ra4 +vmovupd 0x40(%rdi),%ymm2 # ra8 +vmovupd 0x60(%rdi),%ymm3 # ra12 +vmovupd (%rsi),%ymm4 # ia0 +vmovupd 0x20(%rsi),%ymm5 # ia4 +vmovupd 0x40(%rsi),%ymm6 # ia8 +vmovupd 0x60(%rsi),%ymm7 # ia12 + +vmovupd (%rdx),%xmm12 +vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar +vmulpd %ymm6,%ymm13,%ymm8 # ia0.omai +vmulpd %ymm7,%ymm13,%ymm9 # ia4.omai +vmulpd %ymm2,%ymm13,%ymm10 # ra0.omai +vmulpd %ymm3,%ymm13,%ymm11 # ra4.omai +vfmsub231pd %ymm2,%ymm12,%ymm8 # rprod0 +vfmsub231pd %ymm3,%ymm12,%ymm9 # rprod4 +vfmadd231pd %ymm6,%ymm12,%ymm10 # iprod0 +vfmadd231pd %ymm7,%ymm12,%ymm11 # iprod4 +vsubpd %ymm8,%ymm0,%ymm2 +vsubpd %ymm9,%ymm1,%ymm3 +vsubpd %ymm10,%ymm4,%ymm6 +vsubpd %ymm11,%ymm5,%ymm7 +vaddpd %ymm8,%ymm0,%ymm0 +vaddpd %ymm9,%ymm1,%ymm1 +vaddpd %ymm10,%ymm4,%ymm4 +vaddpd %ymm11,%ymm5,%ymm5 + +1: +vmovupd 16(%rdx),%xmm12 +vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar +vmulpd %ymm5,%ymm13,%ymm8 # ia0.omai (tw) +vmulpd %ymm7,%ymm12,%ymm9 # ia4.omar (itw) +vmulpd %ymm1,%ymm13,%ymm10 # ra0.omai (tw) +vmulpd %ymm3,%ymm12,%ymm11 # ra4.omar (itw) +vfmsub231pd %ymm1,%ymm12,%ymm8 # rprod0 (tw) +vfmadd231pd %ymm3,%ymm13,%ymm9 # rprod4 (itw) +vfmadd231pd %ymm5,%ymm12,%ymm10 # iprod0 (tw) +vfmsub231pd %ymm7,%ymm13,%ymm11 # iprod4 (itw) +vsubpd %ymm8,%ymm0,%ymm1 +vaddpd %ymm9,%ymm2,%ymm3 +vsubpd %ymm10,%ymm4,%ymm5 +vaddpd %ymm11,%ymm6,%ymm7 +vaddpd %ymm8,%ymm0,%ymm0 +vsubpd %ymm9,%ymm2,%ymm2 +vaddpd %ymm10,%ymm4,%ymm4 +vsubpd %ymm11,%ymm6,%ymm6 + +2: +vmovupd 0x20(%rdx),%ymm12 +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omaiii'i' +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omarrr'r' + +vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 # ymm8 contains re to mul (tw) +vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 # ymm9 contains re to mul (itw) +vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 # ymm10 contains im to mul (tw) +vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 # ymm11 contains im to mul (itw) +vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 # ymm0 contains re to add (tw) +vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 # ymm1 contains re to add (itw) +vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 # ymm2 contains im to add (tw) +vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 # ymm3 contains im to add (itw) + +vmulpd %ymm10,%ymm13,%ymm4 # ia0.omai (tw) 
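# Note on the comment tags used throughout this kernel (an annotation, not part of the
# original patch): lanes marked (tw) are twiddled by omega itself (ymm12 holds the
# broadcast real part, ymm13 the imaginary part), while lanes marked (itw) are
# effectively twiddled by ±i*omega, a quarter-turn rotation of the twiddle, which is
# why the mul/fma operand roles and the order of the final add/sub are swapped between
# the two columns. For the (tw) lanes each vmulpd + vfmsub231pd/vfmadd231pd group
# computes Re = xr*wr - xi*wi and Im = xr*wi + xi*wr; the (itw) lanes compute the
# i-rotated counterpart before the butterfly add/sub.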
+vmulpd %ymm11,%ymm12,%ymm5 # ia4.omar (itw) +vmulpd %ymm8,%ymm13,%ymm6 # ra0.omai (tw) +vmulpd %ymm9,%ymm12,%ymm7 # ra4.omar (itw) +vfmsub231pd %ymm8,%ymm12,%ymm4 # rprod0 (tw) +vfmadd231pd %ymm9,%ymm13,%ymm5 # rprod4 (itw) +vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0 (tw) +vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw) +vsubpd %ymm4,%ymm0,%ymm8 +vaddpd %ymm5,%ymm1,%ymm9 +vsubpd %ymm6,%ymm2,%ymm10 +vaddpd %ymm7,%ymm3,%ymm11 +vaddpd %ymm4,%ymm0,%ymm0 +vsubpd %ymm5,%ymm1,%ymm1 +vaddpd %ymm6,%ymm2,%ymm2 +vsubpd %ymm7,%ymm3,%ymm3 + +#vperm2f128 $0x31,%ymm10,%ymm2,%ymm6 +#vperm2f128 $0x31,%ymm11,%ymm3,%ymm7 +#vperm2f128 $0x20,%ymm10,%ymm2,%ymm4 +#vperm2f128 $0x20,%ymm11,%ymm3,%ymm5 +#vperm2f128 $0x31,%ymm8,%ymm0,%ymm2 +#vperm2f128 $0x31,%ymm9,%ymm1,%ymm3 +#vperm2f128 $0x20,%ymm8,%ymm0,%ymm0 +#vperm2f128 $0x20,%ymm9,%ymm1,%ymm1 + +3: +vmovupd 0x40(%rdx),%ymm12 +vmovupd 0x60(%rdx),%ymm13 + +#vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 # ymm8 contains re to mul (tw) +#vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 # ymm9 contains re to mul (itw) +#vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 # ymm10 contains im to mul (tw) +#vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 # ymm11 contains im to mul (itw) +#vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 # ymm0 contains re to add (tw) +#vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 # ymm1 contains re to add (itw) +#vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 # ymm2 contains im to add (tw) +#vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 # ymm3 contains im to add (itw) + +vunpckhpd %ymm1,%ymm0,%ymm4 # (0,1) -> (0,4) +vunpckhpd %ymm3,%ymm2,%ymm6 # (2,3) -> (2,6) +vunpckhpd %ymm9,%ymm8,%ymm5 # (8,9) -> (1,5) +vunpckhpd %ymm11,%ymm10,%ymm7 # (10,11) -> (3,7) +vunpcklpd %ymm1,%ymm0,%ymm0 +vunpcklpd %ymm3,%ymm2,%ymm2 +vunpcklpd %ymm9,%ymm8,%ymm1 +vunpcklpd %ymm11,%ymm10,%ymm3 + +vmulpd %ymm6,%ymm13,%ymm8 # ia0.omai (tw) +vmulpd %ymm7,%ymm12,%ymm9 # ia4.omar (itw) +vmulpd %ymm4,%ymm13,%ymm10 # ra0.omai (tw) +vmulpd %ymm5,%ymm12,%ymm11 # ra4.omar (itw) +vfmsub231pd %ymm4,%ymm12,%ymm8 # rprod0 (tw) +vfmadd231pd %ymm5,%ymm13,%ymm9 # rprod4 (itw) +vfmadd231pd %ymm6,%ymm12,%ymm10 # iprod0 (tw) +vfmsub231pd %ymm7,%ymm13,%ymm11 # iprod4 (itw) +vsubpd %ymm8,%ymm0,%ymm4 +vaddpd %ymm9,%ymm1,%ymm5 +vsubpd %ymm10,%ymm2,%ymm6 +vaddpd %ymm11,%ymm3,%ymm7 +vaddpd %ymm8,%ymm0,%ymm0 +vsubpd %ymm9,%ymm1,%ymm1 +vaddpd %ymm10,%ymm2,%ymm2 +vsubpd %ymm11,%ymm3,%ymm3 + +vunpckhpd %ymm7,%ymm3,%ymm11 # (0,4) -> (0,1) +vunpckhpd %ymm5,%ymm1,%ymm9 # (2,6) -> (2,3) +vunpcklpd %ymm7,%ymm3,%ymm10 +vunpcklpd %ymm5,%ymm1,%ymm8 +vunpckhpd %ymm6,%ymm2,%ymm3 # (1,5) -> (8,9) +vunpckhpd %ymm4,%ymm0,%ymm1 # (3,7) -> (10,11) +vunpcklpd %ymm6,%ymm2,%ymm2 +vunpcklpd %ymm4,%ymm0,%ymm0 + +vperm2f128 $0x31,%ymm10,%ymm2,%ymm6 +vperm2f128 $0x31,%ymm11,%ymm3,%ymm7 +vperm2f128 $0x20,%ymm10,%ymm2,%ymm4 +vperm2f128 $0x20,%ymm11,%ymm3,%ymm5 +vperm2f128 $0x31,%ymm8,%ymm0,%ymm2 +vperm2f128 $0x31,%ymm9,%ymm1,%ymm3 +vperm2f128 $0x20,%ymm8,%ymm0,%ymm0 +vperm2f128 $0x20,%ymm9,%ymm1,%ymm1 + +4: +vmovupd %ymm0,(%rdi) # ra0 +vmovupd %ymm1,0x20(%rdi) # ra4 +vmovupd %ymm2,0x40(%rdi) # ra8 +vmovupd %ymm3,0x60(%rdi) # ra12 +vmovupd %ymm4,(%rsi) # ia0 +vmovupd %ymm5,0x20(%rsi) # ia4 +vmovupd %ymm6,0x40(%rsi) # ia8 +vmovupd %ymm7,0x60(%rsi) # ia12 +vzeroupper +ret diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft4_avx_fma.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft4_avx_fma.c new file mode 100644 index 0000000..f7d24f3 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft4_avx_fma.c @@ -0,0 +1,66 @@ 
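The two C kernels that follow (reim_fft4_avx_fma and reim_fft8_avx_fma) are SSE/AVX vectorizations of one scalar butterfly: multiply the second operand by the twiddle omega, then form the sum and difference with the first. A minimal scalar model of that step, mirroring the fmsub/fmadd comments inside the intrinsics (names illustrative, not part of the patch):

/* Scalar model of the reim_ctwiddle_* helpers: given a = (ra, ia), b = (rb, ib)
 * and omega = omre + i*omim, compute prod = omega*b, then b <- a - prod and
 * a <- a + prod, all on split real/imaginary values. */
static void reim_ctwiddle_scalar(double* ra, double* rb, double* ia, double* ib,
                                 double omre, double omim) {
  double rprod = *rb * omre - *ib * omim; /* Re(omega * b) */
  double iprod = *rb * omim + *ib * omre; /* Im(omega * b) */
  *rb = *ra - rprod;
  *ib = *ia - iprod;
  *ra = *ra + rprod;
  *ia = *ia + iprod;
}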
+#include +#include +#include + +#include "reim_fft_private.h" + +__always_inline void reim_ctwiddle_avx_fma(__m128d* ra, __m128d* rb, __m128d* ia, __m128d* ib, const __m128d omre, + const __m128d omim) { + // rb * omre - ib * omim; + __m128d rprod0 = _mm_mul_pd(*ib, omim); + rprod0 = _mm_fmsub_pd(*rb, omre, rprod0); + + // rb * omim + ib * omre; + __m128d iprod0 = _mm_mul_pd(*rb, omim); + iprod0 = _mm_fmadd_pd(*ib, omre, iprod0); + + *rb = _mm_sub_pd(*ra, rprod0); + *ib = _mm_sub_pd(*ia, iprod0); + *ra = _mm_add_pd(*ra, rprod0); + *ia = _mm_add_pd(*ia, iprod0); +} + +EXPORT void reim_fft4_avx_fma(double* dre, double* dim, const void* ompv) { + const double* omp = (const double*)ompv; + + __m128d ra0 = _mm_loadu_pd(dre); + __m128d ra2 = _mm_loadu_pd(dre + 2); + __m128d ia0 = _mm_loadu_pd(dim); + __m128d ia2 = _mm_loadu_pd(dim + 2); + + // 1 + { + // duplicate omegas in precomp? + __m128d om = _mm_loadu_pd(omp); + __m128d omre = _mm_permute_pd(om, 0); + __m128d omim = _mm_permute_pd(om, 3); + + reim_ctwiddle_avx_fma(&ra0, &ra2, &ia0, &ia2, omre, omim); + } + + // 2 + { + const __m128d fft4neg = _mm_castsi128_pd(_mm_set_epi64x(UINT64_C(1) << 63, 0)); + __m128d om = _mm_loadu_pd(omp + 2); // om: r,i + __m128d omim = _mm_permute_pd(om, 1); // omim: i,r + __m128d omre = _mm_xor_pd(om, fft4neg); // omre: r,-i + + __m128d rb = _mm_unpackhi_pd(ra0, ra2); // (r0, r1), (r2, r3) -> (r1, r3) + __m128d ib = _mm_unpackhi_pd(ia0, ia2); // (i0, i1), (i2, i3) -> (i1, i3) + __m128d ra = _mm_unpacklo_pd(ra0, ra2); // (r0, r1), (r2, r3) -> (r0, r2) + __m128d ia = _mm_unpacklo_pd(ia0, ia2); // (i0, i1), (i2, i3) -> (i0, i2) + + reim_ctwiddle_avx_fma(&ra, &rb, &ia, &ib, omre, omim); + + ra0 = _mm_unpacklo_pd(ra, rb); + ia0 = _mm_unpacklo_pd(ia, ib); + ra2 = _mm_unpackhi_pd(ra, rb); + ia2 = _mm_unpackhi_pd(ia, ib); + } + + // 4 + _mm_storeu_pd(dre, ra0); + _mm_storeu_pd(dre + 2, ra2); + _mm_storeu_pd(dim, ia0); + _mm_storeu_pd(dim + 2, ia2); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft8_avx_fma.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft8_avx_fma.c new file mode 100644 index 0000000..a39597f --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft8_avx_fma.c @@ -0,0 +1,89 @@ +#include +#include +#include + +#include "reim_fft_private.h" + +__always_inline void reim_ctwiddle_avx_fma(__m256d* ra, __m256d* rb, __m256d* ia, __m256d* ib, const __m256d omre, + const __m256d omim) { + // rb * omre - ib * omim; + __m256d rprod0 = _mm256_mul_pd(*ib, omim); + rprod0 = _mm256_fmsub_pd(*rb, omre, rprod0); + + // rb * omim + ib * omre; + __m256d iprod0 = _mm256_mul_pd(*rb, omim); + iprod0 = _mm256_fmadd_pd(*ib, omre, iprod0); + + *rb = _mm256_sub_pd(*ra, rprod0); + *ib = _mm256_sub_pd(*ia, iprod0); + *ra = _mm256_add_pd(*ra, rprod0); + *ia = _mm256_add_pd(*ia, iprod0); +} + +EXPORT void reim_fft8_avx_fma(double* dre, double* dim, const void* ompv) { + const double* omp = (const double*)ompv; + + __m256d ra0 = _mm256_loadu_pd(dre); + __m256d ra4 = _mm256_loadu_pd(dre + 4); + __m256d ia0 = _mm256_loadu_pd(dim); + __m256d ia4 = _mm256_loadu_pd(dim + 4); + + // 1 + { + // duplicate omegas in precomp? 
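Like reim_fft4_avx_fma above, this kernel takes separate real/imaginary pointers; in the drivers later in this diff they are simply the two halves of one length-2m buffer (re = dat, im = dat + m). A small illustrative helper for building such a split "reim" buffer (the function name is hypothetical, not part of the library):

#include <stdlib.h>

/* Packs m complex samples (re[k], im[k]) into the split layout used by the
 * reim_* kernels: dat[0..m) = real parts, dat[m..2m) = imaginary parts. */
static double* pack_reim(const double* re, const double* im, size_t m) {
  double* dat = malloc(2 * m * sizeof(double));
  if (!dat) return NULL;
  for (size_t k = 0; k < m; ++k) {
    dat[k] = re[k];     /* real half */
    dat[m + k] = im[k]; /* imaginary half */
  }
  return dat;
}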
+ __m128d omri = _mm_loadu_pd(omp); + __m256d omriri = _mm256_set_m128d(omri, omri); + __m256d omi = _mm256_permute_pd(omriri, 15); + __m256d omr = _mm256_permute_pd(omriri, 0); + + reim_ctwiddle_avx_fma(&ra0, &ra4, &ia0, &ia4, omr, omi); + } + + // 2 + { + const __m128d fft8neg = _mm_castsi128_pd(_mm_set_epi64x(UINT64_C(1) << 63, 0)); + __m128d omri = _mm_loadu_pd(omp + 2); // r,i + __m128d omrmi = _mm_xor_pd(omri, fft8neg); // r,-i + __m256d omrirmi = _mm256_set_m128d(omrmi, omri); // r,i,r,-i + __m256d omi = _mm256_permute_pd(omrirmi, 3); // i,i,r,r + __m256d omr = _mm256_permute_pd(omrirmi, 12); // r,r,-i,-i + + __m256d rb = _mm256_permute2f128_pd(ra0, ra4, 0x31); + __m256d ib = _mm256_permute2f128_pd(ia0, ia4, 0x31); + __m256d ra = _mm256_permute2f128_pd(ra0, ra4, 0x20); + __m256d ia = _mm256_permute2f128_pd(ia0, ia4, 0x20); + + reim_ctwiddle_avx_fma(&ra, &rb, &ia, &ib, omr, omi); + + ra0 = _mm256_permute2f128_pd(ra, rb, 0x20); + ra4 = _mm256_permute2f128_pd(ra, rb, 0x31); + ia0 = _mm256_permute2f128_pd(ia, ib, 0x20); + ia4 = _mm256_permute2f128_pd(ia, ib, 0x31); + } + + // 3 + { + const __m256d fft8neg2 = _mm256_castsi256_pd(_mm256_set_epi64x(UINT64_C(1) << 63, UINT64_C(1) << 63, 0, 0)); + __m256d om = _mm256_loadu_pd(omp + 4); // r0,r1,i0,i1 + __m256d omi = _mm256_permute2f128_pd(om, om, 1); // i0,i1,r0,r1 + __m256d omr = _mm256_xor_pd(om, fft8neg2); // r0,r1,-i0,-i1 + + __m256d rb = _mm256_unpackhi_pd(ra0, ra4); + __m256d ib = _mm256_unpackhi_pd(ia0, ia4); + __m256d ra = _mm256_unpacklo_pd(ra0, ra4); + __m256d ia = _mm256_unpacklo_pd(ia0, ia4); + + reim_ctwiddle_avx_fma(&ra, &rb, &ia, &ib, omr, omi); + + ra4 = _mm256_unpackhi_pd(ra, rb); + ia4 = _mm256_unpackhi_pd(ia, ib); + ra0 = _mm256_unpacklo_pd(ra, rb); + ia0 = _mm256_unpacklo_pd(ia, ib); + } + + // 4 + _mm256_storeu_pd(dre, ra0); + _mm256_storeu_pd(dre + 4, ra4); + _mm256_storeu_pd(dim, ia0); + _mm256_storeu_pd(dim + 4, ia4); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft_avx2.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft_avx2.c new file mode 100644 index 0000000..73e4e16 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft_avx2.c @@ -0,0 +1,162 @@ +#include "immintrin.h" +#include "reim_fft_internal.h" +#include "reim_fft_private.h" + +__always_inline void reim_twiddle_fft_avx2_fma(uint32_t h, double* re, double* im, double om[2]) { + const __m128d omx = _mm_load_pd(om); + const __m256d omra = _mm256_set_m128d(omx, omx); + const __m256d omi = _mm256_unpackhi_pd(omra, omra); + const __m256d omr = _mm256_unpacklo_pd(omra, omra); + double* r0 = re; + double* r1 = re + h; + double* i0 = im; + double* i1 = im + h; + for (uint32_t i = 0; i < h; i += 4) { + __m256d ur0 = _mm256_loadu_pd(r0 + i); + __m256d ur1 = _mm256_loadu_pd(r1 + i); + __m256d ui0 = _mm256_loadu_pd(i0 + i); + __m256d ui1 = _mm256_loadu_pd(i1 + i); + __m256d tra = _mm256_mul_pd(omi, ui1); + __m256d tia = _mm256_mul_pd(omi, ur1); + tra = _mm256_fmsub_pd(omr, ur1, tra); + tia = _mm256_fmadd_pd(omr, ui1, tia); + ur1 = _mm256_sub_pd(ur0, tra); + ui1 = _mm256_sub_pd(ui0, tia); + ur0 = _mm256_add_pd(ur0, tra); + ui0 = _mm256_add_pd(ui0, tia); + _mm256_storeu_pd(r0 + i, ur0); + _mm256_storeu_pd(r1 + i, ur1); + _mm256_storeu_pd(i0 + i, ui0); + _mm256_storeu_pd(i1 + i, ui1); + } +} + +__always_inline void reim_bitwiddle_fft_avx2_fma(uint32_t h, double* re, double* im, double om[4]) { + double* const r0 = re; + double* const r1 = re + h; + double* const r2 = re + 2 
* h; + double* const r3 = re + 3 * h; + double* const i0 = im; + double* const i1 = im + h; + double* const i2 = im + 2 * h; + double* const i3 = im + 3 * h; + const __m256d om0 = _mm256_loadu_pd(om); + const __m256d omb = _mm256_permute2f128_pd(om0, om0, 0x11); + const __m256d oma = _mm256_permute2f128_pd(om0, om0, 0x00); + const __m256d omai = _mm256_unpackhi_pd(oma, oma); + const __m256d omar = _mm256_unpacklo_pd(oma, oma); + const __m256d ombi = _mm256_unpackhi_pd(omb, omb); + const __m256d ombr = _mm256_unpacklo_pd(omb, omb); + for (uint32_t i = 0; i < h; i += 4) { + __m256d ur0 = _mm256_loadu_pd(r0 + i); + __m256d ur1 = _mm256_loadu_pd(r1 + i); + __m256d ur2 = _mm256_loadu_pd(r2 + i); + __m256d ur3 = _mm256_loadu_pd(r3 + i); + __m256d ui0 = _mm256_loadu_pd(i0 + i); + __m256d ui1 = _mm256_loadu_pd(i1 + i); + __m256d ui2 = _mm256_loadu_pd(i2 + i); + __m256d ui3 = _mm256_loadu_pd(i3 + i); + //------ twiddles 1 + __m256d tra = _mm256_mul_pd(omai, ui2); + __m256d trb = _mm256_mul_pd(omai, ui3); + __m256d tia = _mm256_mul_pd(omai, ur2); + __m256d tib = _mm256_mul_pd(omai, ur3); + tra = _mm256_fmsub_pd(omar, ur2, tra); + trb = _mm256_fmsub_pd(omar, ur3, trb); + tia = _mm256_fmadd_pd(omar, ui2, tia); + tib = _mm256_fmadd_pd(omar, ui3, tib); + ur2 = _mm256_sub_pd(ur0, tra); + ur3 = _mm256_sub_pd(ur1, trb); + ui2 = _mm256_sub_pd(ui0, tia); + ui3 = _mm256_sub_pd(ui1, tib); + ur0 = _mm256_add_pd(ur0, tra); + ur1 = _mm256_add_pd(ur1, trb); + ui0 = _mm256_add_pd(ui0, tia); + ui1 = _mm256_add_pd(ui1, tib); + //------ twiddles 1 + tra = _mm256_mul_pd(ombi, ui1); + trb = _mm256_mul_pd(ombr, ui3); // ii + tia = _mm256_mul_pd(ombi, ur1); + tib = _mm256_mul_pd(ombr, ur3); // ri + tra = _mm256_fmsub_pd(ombr, ur1, tra); + trb = _mm256_fmadd_pd(ombi, ur3, trb); //-rr+ii + tia = _mm256_fmadd_pd(ombr, ui1, tia); + tib = _mm256_fmsub_pd(ombi, ui3, tib); //-ir-ri + ur1 = _mm256_sub_pd(ur0, tra); + ur3 = _mm256_add_pd(ur2, trb); + ui1 = _mm256_sub_pd(ui0, tia); + ui3 = _mm256_add_pd(ui2, tib); + ur0 = _mm256_add_pd(ur0, tra); + ur2 = _mm256_sub_pd(ur2, trb); + ui0 = _mm256_add_pd(ui0, tia); + ui2 = _mm256_sub_pd(ui2, tib); + ///--- + _mm256_storeu_pd(r0 + i, ur0); + _mm256_storeu_pd(r1 + i, ur1); + _mm256_storeu_pd(r2 + i, ur2); + _mm256_storeu_pd(r3 + i, ur3); + _mm256_storeu_pd(i0 + i, ui0); + _mm256_storeu_pd(i1 + i, ui1); + _mm256_storeu_pd(i2 + i, ui2); + _mm256_storeu_pd(i3 + i, ui3); + } +} + +void reim_fft_bfs_16_avx2_fma(uint32_t m, double* re, double* im, double** omg) { + uint32_t log2m = _mm_popcnt_u32(m - 1); // log2(m); + uint32_t mm = m; + if ((log2m & 1) != 0) { + uint32_t h = mm >> 1; + // do the first twiddle iteration normally + reim_twiddle_fft_avx2_fma(h, re, im, *omg); + *omg += 2; + mm = h; + } + while (mm > 16) { + uint32_t h = mm >> 2; + for (uint32_t off = 0; off < m; off += mm) { + reim_bitwiddle_fft_avx2_fma(h, re + off, im + off, *omg); + *omg += 4; + } + mm = h; + } + if (mm != 16) abort(); // bug! 
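For reference, the loop structure above reduces the block size mm from m down to 16: one radix-2 twiddle pass when log2(m) is odd, then radix-4 bitwiddle passes, then 16-point leaves. A standalone sketch of that schedule (illustrative only, not part of the patch):

#include <stdint.h>
#include <stdio.h>

/* Prints the pass schedule used by the bfs-16 strategy for a power-of-two m >= 16. */
static void print_bfs16_schedule(uint32_t m) {
  uint32_t log2m = 0;
  for (uint32_t t = m; t > 1; t >>= 1) log2m++;
  uint32_t mm = m;
  if (log2m & 1) { printf("radix-2 twiddle pass, block size %u\n", mm); mm >>= 1; }
  while (mm > 16) { printf("radix-4 bitwiddle pass, block size %u\n", mm); mm >>= 2; }
  printf("%u fft16 leaves\n", m / 16);
}

int main(void) {
  print_bfs16_schedule(512); /* radix-2 (512), radix-4 (256), radix-4 (64), 32 leaves */
  return 0;
}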
+ for (uint32_t off = 0; off < m; off += 16) { + reim_fft16_avx_fma(re + off, im + off, *omg); + *omg += 16; + } +} + +void reim_fft_rec_16_avx2_fma(uint32_t m, double* re, double* im, double** omg) { + if (m <= 2048) return reim_fft_bfs_16_avx2_fma(m, re, im, omg); + const uint32_t h = m / 2; + reim_twiddle_fft_avx2_fma(h, re, im, *omg); + *omg += 2; + reim_fft_rec_16_avx2_fma(h, re, im, omg); + reim_fft_rec_16_avx2_fma(h, re + h, im + h, omg); +} + +void reim_fft_avx2_fma(const REIM_FFT_PRECOMP* precomp, double* dat) { + const int32_t m = precomp->m; + double* omg = precomp->powomegas; + double* re = dat; + double* im = dat + m; + if (m <= 16) { + switch (m) { + case 1: + return; + case 2: + return reim_fft2_ref(re, im, omg); + case 4: + return reim_fft4_avx_fma(re, im, omg); + case 8: + return reim_fft8_avx_fma(re, im, omg); + case 16: + return reim_fft16_avx_fma(re, im, omg); + default: + abort(); // m is not a power of 2 + } + } + if (m <= 2048) return reim_fft_bfs_16_avx2_fma(m, re, im, &omg); + return reim_fft_rec_16_avx2_fma(m, re, im, &omg); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft_core_template.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft_core_template.h new file mode 100644 index 0000000..b82a05b --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft_core_template.h @@ -0,0 +1,162 @@ +#ifndef SPQLIOS_REIM_FFT_CORE_TEMPLATE_H +#define SPQLIOS_REIM_FFT_CORE_TEMPLATE_H + +// this file contains the main template for the fft strategy. +// it is meant to be included once for each specialization (ref, avx, neon) +// all the leaf functions it uses shall be defined in the following macros, +// before including this header. + +#if !defined(reim_fft16_f) || !defined(reim_fft16_pom_offset) || !defined(fill_reim_fft16_omegas_f) +#error "missing reim16 definitions" +#endif +#if !defined(reim_fft8_f) || !defined(reim_fft8_pom_offset) || !defined(fill_reim_fft8_omegas_f) +#error "missing reim8 definitions" +#endif +#if !defined(reim_fft4_f) || !defined(reim_fft4_pom_offset) || !defined(fill_reim_fft4_omegas_f) +#error "missing reim4 definitions" +#endif +#if !defined(reim_fft2_f) || !defined(reim_fft2_pom_offset) || !defined(fill_reim_fft2_omegas_f) +#error "missing reim2 definitions" +#endif +#if !defined(reim_twiddle_fft_f) || !defined(reim_twiddle_fft_pom_offset) || !defined(fill_reim_twiddle_fft_omegas_f) +#error "missing twiddle definitions" +#endif +#if !defined(reim_bitwiddle_fft_f) || !defined(reim_bitwiddle_fft_pom_offset) || \ + !defined(fill_reim_bitwiddle_fft_omegas_f) +#error "missing bitwiddle definitions" +#endif +#if !defined(reim_fft_bfs_16_f) || !defined(fill_reim_fft_bfs_16_omegas_f) +#error "missing bfs_16 definitions" +#endif +#if !defined(reim_fft_rec_16_f) || !defined(fill_reim_fft_rec_16_omegas_f) +#error "missing rec_16 definitions" +#endif +#if !defined(reim_fft_f) || !defined(fill_reim_fft_omegas_f) +#error "missing main definitions" +#endif + +void reim_fft_bfs_16_f(uint64_t m, double* re, double* im, double** omg) { + uint64_t log2m = log2(m); + uint64_t mm = m; + if (log2m & 1) { + uint64_t h = mm >> 1; + // do the first twiddle iteration normally + reim_twiddle_fft_f(h, re, im, *omg); + *omg += reim_twiddle_fft_pom_offset; + mm = h; + } + while (mm > 16) { + uint64_t h = mm >> 2; + for (uint64_t off = 0; off < m; off += mm) { + reim_bitwiddle_fft_f(h, re + off, im + off, *omg); + *omg += reim_bitwiddle_fft_pom_offset; + } + mm = h; + } + if (mm != 16) 
abort(); // bug! + for (uint64_t off = 0; off < m; off += 16) { + reim_fft16_f(re + off, im + off, *omg); + *omg += reim_fft16_pom_offset; + } +} + +void fill_reim_fft_bfs_16_omegas_f(uint64_t m, double entry_pwr, double** omg) { + uint64_t log2m = log2(m); + uint64_t mm = m; + double ss = entry_pwr; + if (log2m % 2 != 0) { + uint64_t h = mm >> 1; + double s = ss / 2.; + // do the first twiddle iteration normally + fill_reim_twiddle_fft_omegas_f(s, omg); + mm = h; + ss = s; + } + while (mm > 16) { + uint64_t h = mm >> 2; + double s = ss / 4.; + for (uint64_t off = 0; off < m; off += mm) { + double rs0 = s + fracrevbits(off / mm) / 4.; + fill_reim_bitwiddle_fft_omegas_f(rs0, omg); + } + mm = h; + ss = s; + } + if (mm != 16) abort(); // bug! + for (uint64_t off = 0; off < m; off += 16) { + double s = ss + fracrevbits(off / 16); + fill_reim_fft16_omegas_f(s, omg); + } +} + +void reim_fft_rec_16_f(uint64_t m, double* re, double* im, double** omg) { + if (m <= 2048) return reim_fft_bfs_16_f(m, re, im, omg); + const uint32_t h = m >> 1; + reim_twiddle_fft_f(h, re, im, *omg); + *omg += reim_twiddle_fft_pom_offset; + reim_fft_rec_16_f(h, re, im, omg); + reim_fft_rec_16_f(h, re + h, im + h, omg); +} + +void fill_reim_fft_rec_16_omegas_f(uint64_t m, double entry_pwr, double** omg) { + if (m <= 2048) return fill_reim_fft_bfs_16_omegas_f(m, entry_pwr, omg); + const uint64_t h = m / 2; + const double s = entry_pwr / 2; + fill_reim_twiddle_fft_omegas_f(s, omg); + fill_reim_fft_rec_16_omegas_f(h, s, omg); + fill_reim_fft_rec_16_omegas_f(h, s + 0.5, omg); +} + +void reim_fft_f(const REIM_FFT_PRECOMP* precomp, double* dat) { + const int32_t m = precomp->m; + double* omg = precomp->powomegas; + double* re = dat; + double* im = dat + m; + if (m <= 16) { + switch (m) { + case 1: + return; + case 2: + return reim_fft2_f(re, im, omg); + case 4: + return reim_fft4_f(re, im, omg); + case 8: + return reim_fft8_f(re, im, omg); + case 16: + return reim_fft16_f(re, im, omg); + default: + abort(); // m is not a power of 2 + } + } + if (m <= 2048) return reim_fft_bfs_16_f(m, re, im, &omg); + return reim_fft_rec_16_f(m, re, im, &omg); +} + +void fill_reim_fft_omegas_f(uint64_t m, double entry_pwr, double** omg) { + if (m <= 16) { + switch (m) { + case 1: + break; + case 2: + fill_reim_fft2_omegas_f(entry_pwr, omg); + break; + case 4: + fill_reim_fft4_omegas_f(entry_pwr, omg); + break; + case 8: + fill_reim_fft8_omegas_f(entry_pwr, omg); + break; + case 16: + fill_reim_fft16_omegas_f(entry_pwr, omg); + break; + default: + abort(); // m is not a power of 2 + } + } else if (m <= 2048) { + fill_reim_fft_bfs_16_omegas_f(m, entry_pwr, omg); + } else { + fill_reim_fft_rec_16_omegas_f(m, entry_pwr, omg); + } +} + +#endif // SPQLIOS_REIM_FFT_CORE_TEMPLATE_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft_ifft.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft_ifft.c new file mode 100644 index 0000000..82bc693 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft_ifft.c @@ -0,0 +1,35 @@ +#include +#include +#include +#include + +#include "../commons_private.h" +#include "reim_fft_internal.h" +#include "reim_fft_private.h" + +double accurate_cos(int32_t i, int32_t n) { // cos(2pi*i/n) + i = ((i % n) + n) % n; + if (i >= 3 * n / 4) return cos(2. * M_PI * (n - i) / (double)(n)); + if (i >= 2 * n / 4) return -cos(2. * M_PI * (i - n / 2) / (double)(n)); + if (i >= 1 * n / 4) return -cos(2. 
* M_PI * (n / 2 - i) / (double)(n)); + return cos(2. * M_PI * (i) / (double)(n)); +} + +double accurate_sin(int32_t i, int32_t n) { // sin(2pi*i/n) + i = ((i % n) + n) % n; + if (i >= 3 * n / 4) return -sin(2. * M_PI * (n - i) / (double)(n)); + if (i >= 2 * n / 4) return -sin(2. * M_PI * (i - n / 2) / (double)(n)); + if (i >= 1 * n / 4) return sin(2. * M_PI * (n / 2 - i) / (double)(n)); + return sin(2. * M_PI * (i) / (double)(n)); +} + +EXPORT double* reim_ifft_precomp_get_buffer(const REIM_IFFT_PRECOMP* tables, uint32_t buffer_index) { + return (double*)((uint8_t*)tables->aligned_buffers + buffer_index * tables->buf_size); +} + +EXPORT double* reim_fft_precomp_get_buffer(const REIM_FFT_PRECOMP* tables, uint32_t buffer_index) { + return (double*)((uint8_t*)tables->aligned_buffers + buffer_index * tables->buf_size); +} + +EXPORT void reim_fft(const REIM_FFT_PRECOMP* tables, double* data) { tables->function(tables, data); } +EXPORT void reim_ifft(const REIM_IFFT_PRECOMP* tables, double* data) { tables->function(tables, data); } diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft_internal.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft_internal.h new file mode 100644 index 0000000..1e9f1aa --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft_internal.h @@ -0,0 +1,147 @@ +#ifndef SPQLIOS_REIM_FFT_INTERNAL_H +#define SPQLIOS_REIM_FFT_INTERNAL_H + +#include +#include +#include + +#include "reim_fft.h" + +EXPORT void reim_fftvec_addmul_fma(const REIM_FFTVEC_ADDMUL_PRECOMP* precomp, double* r, const double* a, + const double* b); +EXPORT void reim_fftvec_addmul_ref(const REIM_FFTVEC_ADDMUL_PRECOMP* precomp, double* r, const double* a, + const double* b); + +EXPORT void reim_fftvec_mul_fma(const REIM_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b); +EXPORT void reim_fftvec_mul_ref(const REIM_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b); + +EXPORT void reim_fftvec_add_ref(const REIM_FFTVEC_ADD_PRECOMP* tables, double* r, const double* a, const double* b); +EXPORT void reim_fftvec_add_fma(const REIM_FFTVEC_ADD_PRECOMP* tables, double* r, const double* a, const double* b); + +EXPORT void reim_fftvec_sub_ref(const REIM_FFTVEC_SUB_PRECOMP* tables, double* r, const double* a, const double* b); +EXPORT void reim_fftvec_sub_fma(const REIM_FFTVEC_SUB_PRECOMP* tables, double* r, const double* a, const double* b); + +/** @brief r = x from ZnX (coeffs as signed int32_t's ) to double */ +EXPORT void reim_from_znx32_ref(const REIM_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x); +EXPORT void reim_from_znx32_avx2_fma(const REIM_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x); + +/** @brief r = x mod 1 (coeffs as double) */ +EXPORT void reim_to_tnx_ref(const REIM_TO_TNX_PRECOMP* tables, double* r, const double* x); +EXPORT void reim_to_tnx_avx(const REIM_TO_TNX_PRECOMP* tables, double* r, const double* x); + +EXPORT void reim_from_znx64_ref(const REIM_FROM_ZNX64_PRECOMP* precomp, void* r, const int64_t* x); +EXPORT void reim_from_znx64_bnd50_fma(const REIM_FROM_ZNX64_PRECOMP* precomp, void* r, const int64_t* x); + +EXPORT void reim_to_znx64_ref(const REIM_TO_ZNX64_PRECOMP* precomp, int64_t* r, const void* x); +EXPORT void reim_to_znx64_avx2_bnd63_fma(const REIM_TO_ZNX64_PRECOMP* precomp, int64_t* r, const void* x); +EXPORT void reim_to_znx64_avx2_bnd50_fma(const REIM_TO_ZNX64_PRECOMP* precomp, int64_t* r, const void* x); + +/** + * @brief compute 
the fft evaluations of P in place + * fft(data) = fft_rec(data, i); + * function fft_rec(data, omega) { + * if #data = 1: return data + * let s = sqrt(omega) w. re(s)>0 + * let (u,v) = merge_fft(data, s) + * return [fft_rec(u, s), fft_rec(v, -s)] + * } + * @param tables precomputed tables (contains all the powers of omega in the order they are used) + * @param data vector of m complexes (coeffs as input, evals as output) + */ +EXPORT void reim_fft_ref(const REIM_FFT_PRECOMP* tables, double* data); +EXPORT void reim_fft_avx2_fma(const REIM_FFT_PRECOMP* tables, double* data); + +/** + * @brief compute the ifft evaluations of P in place + * ifft(data) = ifft_rec(data, i); + * function ifft_rec(data, omega) { + * if #data = 1: return data + * let s = sqrt(omega) w. re(s)>0 + * let (u,v) = data + * return split_fft([ifft_rec(u, s), ifft_rec(v, -s)],s) + * } + * @param itables precomputed tables (contains all the powers of omega in the order they are used) + * @param data vector of m complexes (coeffs as input, evals as output) + */ +EXPORT void reim_ifft_ref(const REIM_IFFT_PRECOMP* itables, double* data); +EXPORT void reim_ifft_avx2_fma(const REIM_IFFT_PRECOMP* itables, double* data); + +/// new compressed implementation + +/** @brief naive FFT code mod X^m-exp(2i.pi.entry_pwr) */ +EXPORT void reim_naive_fft(uint64_t m, double entry_pwr, double* re, double* im); + +/** @brief 16-dimensional FFT with precomputed omegas */ +EXPORT void reim_fft16_neon(double* dre, double* dim, const void* omega); +EXPORT void reim_fft16_avx_fma(double* dre, double* dim, const void* omega); +EXPORT void reim_fft16_ref(double* dre, double* dim, const void* omega); + +/** @brief precompute omegas so that reim_fft16 functions */ +EXPORT void fill_reim_fft16_omegas(const double entry_pwr, double** omg); +EXPORT void fill_reim_fft16_omegas_neon(const double entry_pwr, double** omg); + +/** @brief 8-dimensional FFT with precomputed omegas */ +EXPORT void reim_fft8_avx_fma(double* dre, double* dim, const void* omega); +EXPORT void reim_fft8_ref(double* dre, double* dim, const void* omega); + +/** @brief precompute omegas so that reim_fft8 functions */ +EXPORT void fill_reim_fft8_omegas(const double entry_pwr, double** omg); + +/** @brief 4-dimensional FFT with precomputed omegas */ +EXPORT void reim_fft4_avx_fma(double* dre, double* dim, const void* omega); +EXPORT void reim_fft4_ref(double* dre, double* dim, const void* omega); + +/** @brief precompute omegas so that reim_fft8 functions */ +EXPORT void fill_reim_fft4_omegas(const double entry_pwr, double** omg); + +/** @brief 2-dimensional FFT with precomputed omegas */ +// EXPORT void reim_fft4_avx_fma(double* dre, double* dim, const void* omega); +EXPORT void reim_fft2_ref(double* dre, double* dim, const void* omega); + +/** @brief precompute omegas so that reim_fft8 functions */ +EXPORT void fill_reim_fft2_omegas(const double entry_pwr, double** omg); + +EXPORT void reim_fft_bfs_16_ref(uint64_t m, double* re, double* im, double** omg); +EXPORT void fill_reim_fft_bfs_16_omegas(uint64_t m, double entry_pwr, double** omg); + +EXPORT void reim_fft_rec_16_ref(uint64_t m, double* re, double* im, double** omg); +EXPORT void fill_reim_fft_rec_16_omegas(uint64_t m, double entry_pwr, double** omg); + +/** @brief naive FFT code mod X^m-exp(2i.pi.entry_pwr) */ +EXPORT void reim_naive_ifft(uint64_t m, double entry_pwr, double* re, double* im); + +/** @brief 16-dimensional FFT with precomputed omegas */ +EXPORT void reim_ifft16_avx_fma(double* dre, double* dim, const void* 
omega); +EXPORT void reim_ifft16_ref(double* dre, double* dim, const void* omega); + +/** @brief precompute omegas so that reim_fft16 functions */ +EXPORT void fill_reim_ifft16_omegas(const double entry_pwr, double** omg); + +/** @brief 8-dimensional FFT with precomputed omegas */ +EXPORT void reim_ifft8_avx_fma(double* dre, double* dim, const void* omega); +EXPORT void reim_ifft8_ref(double* dre, double* dim, const void* omega); + +/** @brief precompute omegas so that reim_fft8 functions */ +EXPORT void fill_reim_ifft8_omegas(const double entry_pwr, double** omg); + +/** @brief 4-dimensional FFT with precomputed omegas */ +EXPORT void reim_ifft4_avx_fma(double* dre, double* dim, const void* omega); +EXPORT void reim_ifft4_ref(double* dre, double* dim, const void* omega); + +/** @brief precompute omegas so that reim_fft8 functions */ +EXPORT void fill_reim_ifft4_omegas(const double entry_pwr, double** omg); + +/** @brief 2-dimensional FFT with precomputed omegas */ +// EXPORT void reim_ifft2_avx_fma(double* dre, double* dim, const void* omega); +EXPORT void reim_ifft2_ref(double* dre, double* dim, const void* omega); + +/** @brief precompute omegas so that reim_fft8 functions */ +EXPORT void fill_reim_ifft2_omegas(const double entry_pwr, double** omg); + +EXPORT void reim_ifft_bfs_16_ref(uint64_t m, double* re, double* im, double** omg); +EXPORT void fill_reim_ifft_bfs_16_omegas(uint64_t m, double entry_pwr, double** omg); + +EXPORT void reim_ifft_rec_16_ref(uint64_t m, double* re, double* im, double** omg); +EXPORT void fill_reim_ifft_rec_16_omegas(uint64_t m, double entry_pwr, double** omg); + +#endif // SPQLIOS_REIM_FFT_INTERNAL_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft_neon.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft_neon.c new file mode 100644 index 0000000..43bd6d1 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft_neon.c @@ -0,0 +1,1624 @@ +/* + * This file is adapted from the implementation of the FFT on Arm64/Neon + * available in https://github.com/cothan/Falcon-Arm (neon/fft.c). + * ============================================================================= + * Copyright (c) 2022 by Cryptographic Engineering Research Group (CERG) + * ECE Department, George Mason University + * Fairfax, VA, U.S.A. + * @author: Duc Tri Nguyen dnguye69@gmu.edu, cothannguyen@gmail.com + * Licensed under the Apache License, Version 2.0 (the "License"); + * ============================================================================= + * + * The original source file has been modified by the authors of spqlios-arithmetic + * to be interfaced with the dynamic twiddle factors generator of the spqlios-arithmetic + * library, as well as the recursive dfs strategy for large complex dimensions m>=2048. + */ + +#include + +#include "../commons.h" +#include "../ext/neon_accel/macrof.h" +#include "../ext/neon_accel/macrofx4.h" + +void fill_reim_fft16_omegas_neon(const double entry_pwr, double** omg) { + const double j_pow = 1. / 8.; + const double k_pow = 1. / 16.; + const double pin = entry_pwr / 2.; + const double pin_2 = entry_pwr / 4.; + const double pin_4 = entry_pwr / 8.; + const double pin_8 = entry_pwr / 16.; + // 0 and 1 are real and imag of om + (*omg)[0] = cos(2. * M_PI * pin); + (*omg)[1] = sin(2. * M_PI * pin); + // 2 and 3 are real and imag of om^1/2 + (*omg)[2] = cos(2. * M_PI * (pin_2)); + (*omg)[3] = sin(2. 
* M_PI * (pin_2)); + // (4,5) and (6,7) are real and imag of om^1/4 and j.om^1/4 + (*omg)[4] = cos(2. * M_PI * (pin_4)); + (*omg)[5] = sin(2. * M_PI * (pin_4)); + (*omg)[6] = cos(2. * M_PI * (pin_4 + j_pow)); + (*omg)[7] = sin(2. * M_PI * (pin_4 + j_pow)); + // ((8,9,10,11),(12,13,14,15)) are 4 reals then 4 imag of om^1/8*(1,k,j,kj) + (*omg)[8] = cos(2. * M_PI * (pin_8)); + (*omg)[10] = cos(2. * M_PI * (pin_8 + j_pow)); + (*omg)[12] = cos(2. * M_PI * (pin_8 + k_pow)); + (*omg)[14] = cos(2. * M_PI * (pin_8 + j_pow + k_pow)); + (*omg)[9] = sin(2. * M_PI * (pin_8)); + (*omg)[11] = sin(2. * M_PI * (pin_8 + j_pow)); + (*omg)[13] = sin(2. * M_PI * (pin_8 + k_pow)); + (*omg)[15] = sin(2. * M_PI * (pin_8 + j_pow + k_pow)); + *omg += 16; +} + +EXPORT void reim_fft16_neon(double* dre, double* dim, const void* omega) { + const double* pom = omega; + // Total SIMD register: 28 = 24 + 4 + float64x2x2_t s_re_im; // 2 + float64x2x4_t x_re, x_im, y_re, y_im, t_re, t_im, v_re, v_im; // 32 + + { + /* + Level 2 + ( 8, 24) * ( 0, 1) + ( 9, 25) * ( 0, 1) + ( 10, 26) * ( 0, 1) + ( 11, 27) * ( 0, 1) + ( 12, 28) * ( 0, 1) + ( 13, 29) * ( 0, 1) + ( 14, 30) * ( 0, 1) + ( 15, 31) * ( 0, 1) + + ( 8, 24) = ( 0, 16) - @ + ( 9, 25) = ( 1, 17) - @ + ( 10, 26) = ( 2, 18) - @ + ( 11, 27) = ( 3, 19) - @ + ( 12, 28) = ( 4, 20) - @ + ( 13, 29) = ( 5, 21) - @ + ( 14, 30) = ( 6, 22) - @ + ( 15, 31) = ( 7, 23) - @ + + ( 0, 16) = ( 0, 16) + @ + ( 1, 17) = ( 1, 17) + @ + ( 2, 18) = ( 2, 18) + @ + ( 3, 19) = ( 3, 19) + @ + ( 4, 20) = ( 4, 20) + @ + ( 5, 21) = ( 5, 21) + @ + ( 6, 22) = ( 6, 22) + @ + ( 7, 23) = ( 7, 23) + @ + */ + vload(s_re_im.val[0], pom); + + vloadx4(y_re, dre + 8); + vloadx4(y_im, dim + 8); + + FWD_TOP_LANEx4(v_re, v_im, y_re, y_im, s_re_im.val[0]); + + vloadx4(x_re, dre); + vloadx4(x_im, dim); + + FWD_BOTx4(x_re, x_im, y_re, y_im, v_re, v_im); + + // vstorex4(dre, x_re); + // vstorex4(dim, x_im); + // vstorex4(dre + 8, y_re); + // vstorex4(dim + 8, y_im); + // return; + } + { + /* + Level 3 + + ( 4, 20) * ( 0, 1) + ( 5, 21) * ( 0, 1) + ( 6, 22) * ( 0, 1) + ( 7, 23) * ( 0, 1) + + ( 4, 20) = ( 0, 16) - @ + ( 5, 21) = ( 1, 17) - @ + ( 6, 22) = ( 2, 18) - @ + ( 7, 23) = ( 3, 19) - @ + + ( 0, 16) = ( 0, 16) + @ + ( 1, 17) = ( 1, 17) + @ + ( 2, 18) = ( 2, 18) + @ + ( 3, 19) = ( 3, 19) + @ + + ( 12, 28) * ( 0, 1) + ( 13, 29) * ( 0, 1) + ( 14, 30) * ( 0, 1) + ( 15, 31) * ( 0, 1) + + ( 12, 28) = ( 8, 24) - j@ + ( 13, 29) = ( 9, 25) - j@ + ( 14, 30) = ( 10, 26) - j@ + ( 15, 31) = ( 11, 27) - j@ + + ( 8, 24) = ( 8, 24) + j@ + ( 9, 25) = ( 9, 25) + j@ + ( 10, 26) = ( 10, 26) + j@ + ( 11, 27) = ( 11, 27) + j@ + */ + + // vloadx4(y_re, dre + 8); + // vloadx4(y_im, dim + 8); + // vloadx4(x_re, dre); + // vloadx4(x_im, dim); + + vload(s_re_im.val[0], pom + 2); + + FWD_TOP_LANE(t_re.val[0], t_im.val[0], x_re.val[2], x_im.val[2], s_re_im.val[0]); + FWD_TOP_LANE(t_re.val[1], t_im.val[1], x_re.val[3], x_im.val[3], s_re_im.val[0]); + FWD_TOP_LANE(t_re.val[2], t_im.val[2], y_re.val[2], y_im.val[2], s_re_im.val[0]); + FWD_TOP_LANE(t_re.val[3], t_im.val[3], y_re.val[3], y_im.val[3], s_re_im.val[0]); + + FWD_BOT(x_re.val[0], x_im.val[0], x_re.val[2], x_im.val[2], t_re.val[0], t_im.val[0]); + FWD_BOT(x_re.val[1], x_im.val[1], x_re.val[3], x_im.val[3], t_re.val[1], t_im.val[1]); + FWD_BOTJ(y_re.val[0], y_im.val[0], y_re.val[2], y_im.val[2], t_re.val[2], t_im.val[2]); + FWD_BOTJ(y_re.val[1], y_im.val[1], y_re.val[3], y_im.val[3], t_re.val[3], t_im.val[3]); + + // vstorex4(dre, x_re); + // vstorex4(dim, x_im); + // vstorex4(dre + 8, y_re); 
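/* Annotation (not part of the original patch): omega consumption in this NEON
 * 16-point kernel matches the layout written by fill_reim_fft16_omegas_neon above.
 * pom[0..1] feed level 2, pom[2..3] feed level 3, pom[4..7] feed level 4 (the base
 * twiddle and its j-rotated variant), and pom[8..15] feed level 5 as interleaved
 * re/im pairs that vload2 de-interleaves into the s_re_im registers. */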
+ // vstorex4(dim + 8, y_im); + // return; + } + { + /* + Level 4 + + ( 2, 18) * ( 0, 1) + ( 3, 19) * ( 0, 1) + ( 6, 22) * ( 0, 1) + ( 7, 23) * ( 0, 1) + + ( 2, 18) = ( 0, 16) - @ + ( 3, 19) = ( 1, 17) - @ + ( 0, 16) = ( 0, 16) + @ + ( 1, 17) = ( 1, 17) + @ + + ( 6, 22) = ( 4, 20) - j@ + ( 7, 23) = ( 5, 21) - j@ + ( 4, 20) = ( 4, 20) + j@ + ( 5, 21) = ( 5, 21) + j@ + + ( 10, 26) * ( 2, 3) + ( 11, 27) * ( 2, 3) + ( 14, 30) * ( 2, 3) + ( 15, 31) * ( 2, 3) + + ( 10, 26) = ( 8, 24) - @ + ( 11, 27) = ( 9, 25) - @ + ( 8, 24) = ( 8, 24) + @ + ( 9, 25) = ( 9, 25) + @ + + ( 14, 30) = ( 12, 28) - j@ + ( 15, 31) = ( 13, 29) - j@ + ( 12, 28) = ( 12, 28) + j@ + ( 13, 29) = ( 13, 29) + j@ + */ + // vloadx4(y_re, dre + 8); + // vloadx4(y_im, dim + 8); + // vloadx4(x_re, dre); + // vloadx4(x_im, dim); + + vloadx2(s_re_im, pom + 4); + + FWD_TOP_LANE(t_re.val[0], t_im.val[0], x_re.val[1], x_im.val[1], s_re_im.val[0]); + FWD_TOP_LANE(t_re.val[1], t_im.val[1], x_re.val[3], x_im.val[3], s_re_im.val[0]); + FWD_TOP_LANE(t_re.val[2], t_im.val[2], y_re.val[1], y_im.val[1], s_re_im.val[1]); + FWD_TOP_LANE(t_re.val[3], t_im.val[3], y_re.val[3], y_im.val[3], s_re_im.val[1]); + + FWD_BOT(x_re.val[0], x_im.val[0], x_re.val[1], x_im.val[1], t_re.val[0], t_im.val[0]); + FWD_BOTJ(x_re.val[2], x_im.val[2], x_re.val[3], x_im.val[3], t_re.val[1], t_im.val[1]); + FWD_BOT(y_re.val[0], y_im.val[0], y_re.val[1], y_im.val[1], t_re.val[2], t_im.val[2]); + FWD_BOTJ(y_re.val[2], y_im.val[2], y_re.val[3], y_im.val[3], t_re.val[3], t_im.val[3]); + + // vstorex4(dre, x_re); + // vstorex4(dim, x_im); + // vstorex4(dre + 8, y_re); + // vstorex4(dim + 8, y_im); + // return; + } + { + /* + Level 5 + + ( 1, 17) * ( 0, 1) + ( 5, 21) * ( 2, 3) + ------ + ( 1, 17) = ( 0, 16) - @ + ( 5, 21) = ( 4, 20) - @ + ( 0, 16) = ( 0, 16) + @ + ( 4, 20) = ( 4, 20) + @ + + ( 3, 19) * ( 0, 1) + ( 7, 23) * ( 2, 3) + ------ + ( 3, 19) = ( 2, 18) - j@ + ( 7, 23) = ( 6, 22) - j@ + ( 2, 18) = ( 2, 18) + j@ + ( 6, 22) = ( 6, 22) + j@ + + ( 9, 25) * ( 4, 5) + ( 13, 29) * ( 6, 7) + ------ + ( 9, 25) = ( 8, 24) - @ + ( 13, 29) = ( 12, 28) - @ + ( 8, 24) = ( 8, 24) + @ + ( 12, 28) = ( 12, 28) + @ + + ( 11, 27) * ( 4, 5) + ( 15, 31) * ( 6, 7) + ------ + ( 11, 27) = ( 10, 26) - j@ + ( 15, 31) = ( 14, 30) - j@ + ( 10, 26) = ( 10, 26) + j@ + ( 14, 30) = ( 14, 30) + j@ + + before transpose_f64 + x_re: 0, 1 | 2, 3 | 4, 5 | 6, 7 + y_re: 8, 9 | 10, 11 | 12, 13 | 14, 15 + after transpose_f64 + x_re: 0, 4 | 2, 6 | 1, 5 | 3, 7 + y_re: 8, 12| 9, 13| 10, 14 | 11, 15 + after swap + x_re: 0, 4 | 1, 5 | 2, 6 | 3, 7 + y_re: 8, 12| 10, 14 | 9, 13| 11, 15 + */ + + // vloadx4(y_re, dre + 8); + // vloadx4(y_im, dim + 8); + // vloadx4(x_re, dre); + // vloadx4(x_im, dim); + + transpose_f64(x_re, x_re, v_re, 0, 2, 0); + transpose_f64(x_re, x_re, v_re, 1, 3, 1); + transpose_f64(x_im, x_im, v_im, 0, 2, 0); + transpose_f64(x_im, x_im, v_im, 1, 3, 1); + + v_re.val[0] = x_re.val[2]; + x_re.val[2] = x_re.val[1]; + x_re.val[1] = v_re.val[0]; + + v_im.val[0] = x_im.val[2]; + x_im.val[2] = x_im.val[1]; + x_im.val[1] = v_im.val[0]; + + transpose_f64(y_re, y_re, v_re, 0, 2, 2); + transpose_f64(y_re, y_re, v_re, 1, 3, 3); + transpose_f64(y_im, y_im, v_im, 0, 2, 2); + transpose_f64(y_im, y_im, v_im, 1, 3, 3); + + v_re.val[0] = y_re.val[2]; + y_re.val[2] = y_re.val[1]; + y_re.val[1] = v_re.val[0]; + + v_im.val[0] = y_im.val[2]; + y_im.val[2] = y_im.val[1]; + y_im.val[1] = v_im.val[0]; + + // double pom8[] = {pom[8], pom[12], pom[9], pom[13]}; + vload2(s_re_im, pom + 8); + // vload2(s_re_im, pom8); + + 
FWD_TOP(t_re.val[0], t_im.val[0], x_re.val[1], x_im.val[1], s_re_im.val[0], s_re_im.val[1]); + FWD_TOP(t_re.val[1], t_im.val[1], x_re.val[3], x_im.val[3], s_re_im.val[0], s_re_im.val[1]); + + // double pom12[] = {pom[10], pom[14], pom[11], pom[15]}; + vload2(s_re_im, pom + 12); + // vload2(s_re_im, pom12); + + FWD_TOP(t_re.val[2], t_im.val[2], y_re.val[1], y_im.val[1], s_re_im.val[0], s_re_im.val[1]); + FWD_TOP(t_re.val[3], t_im.val[3], y_re.val[3], y_im.val[3], s_re_im.val[0], s_re_im.val[1]); + + FWD_BOT(x_re.val[0], x_im.val[0], x_re.val[1], x_im.val[1], t_re.val[0], t_im.val[0]); + FWD_BOTJ(x_re.val[2], x_im.val[2], x_re.val[3], x_im.val[3], t_re.val[1], t_im.val[1]); + + vstore4(dre, x_re); + vstore4(dim, x_im); + + FWD_BOT(y_re.val[0], y_im.val[0], y_re.val[1], y_im.val[1], t_re.val[2], t_im.val[2]); + FWD_BOTJ(y_re.val[2], y_im.val[2], y_re.val[3], y_im.val[3], t_re.val[3], t_im.val[3]); + + vstore4(dre + 8, y_re); + vstore4(dim + 8, y_im); + } +} + +void reim_twiddle_fft_neon(uint64_t h, double* re, double* im, double om[2]) { + // Total SIMD register: 28 = 24 + 4 + if (h < 8) abort(); // bug + float64x2_t s_re_im; // 2 + float64x2x4_t x_re, x_im, y_re, y_im, v_re, v_im; // 32 + vload(s_re_im, om); + for (uint64_t blk = 0; blk < h; blk += 8) { + double* dre = re + blk; + double* dim = im + blk; + vloadx4(y_re, dre + h); + vloadx4(y_im, dim + h); + FWD_TOP_LANEx4(v_re, v_im, y_re, y_im, s_re_im); + vloadx4(x_re, dre); + vloadx4(x_im, dim); + FWD_BOTx4(x_re, x_im, y_re, y_im, v_re, v_im); + vstorex4(dre, x_re); + vstorex4(dim, x_im); + vstorex4(dre + h, y_re); + vstorex4(dim + h, y_im); + } +} + +void reim_ctwiddle(double* ra, double* ia, double* rb, double* ib, double omre, double omim); +// i (omre + i omim) = -omim + i omre +void reim_citwiddle(double* ra, double* ia, double* rb, double* ib, double omre, double omim); + +void reim_bitwiddle_fft_neon(uint64_t h, double* re, double* im, double om[4]) { + // Total SIMD register: 28 = 24 + 4 + if (h < 4) abort(); // bug + double* r0 = re; + double* r1 = re + h; + double* r2 = re + 2 * h; + double* r3 = re + 3 * h; + double* i0 = im; + double* i1 = im + h; + double* i2 = im + 2 * h; + double* i3 = im + 3 * h; + float64x2x2_t s_re_im; // 2 + float64x2x4_t v_re, v_im; // 2 + float64x2x2_t x0_re, x0_im, x1_re, x1_im; // 32 + float64x2x2_t x2_re, x2_im, x3_re, x3_im; // 32 + vloadx2(s_re_im, om); + for (uint64_t blk = 0; blk < h; blk += 4) { + { + vloadx2(x2_re, r2 + blk); + vloadx2(x3_re, r3 + blk); + vloadx2(x2_im, i2 + blk); + vloadx2(x3_im, i3 + blk); + FWD_TOP_LANE(v_re.val[0], v_im.val[0], x2_re.val[0], x2_im.val[0], s_re_im.val[0]); + FWD_TOP_LANE(v_re.val[1], v_im.val[1], x2_re.val[1], x2_im.val[1], s_re_im.val[0]); + FWD_TOP_LANE(v_re.val[2], v_im.val[2], x3_re.val[0], x3_im.val[0], s_re_im.val[0]); + FWD_TOP_LANE(v_re.val[3], v_im.val[3], x3_re.val[1], x3_im.val[1], s_re_im.val[0]); + vloadx2(x0_re, r0 + blk); + vloadx2(x1_re, r1 + blk); + vloadx2(x0_im, i0 + blk); + vloadx2(x1_im, i1 + blk); + FWD_BOT(x0_re.val[0], x0_im.val[0], x2_re.val[0], x2_im.val[0], v_re.val[0], v_im.val[0]); + FWD_BOT(x0_re.val[1], x0_im.val[1], x2_re.val[1], x2_im.val[1], v_re.val[1], v_im.val[1]); + FWD_BOT(x1_re.val[0], x1_im.val[0], x3_re.val[0], x3_im.val[0], v_re.val[2], v_im.val[2]); + FWD_BOT(x1_re.val[1], x1_im.val[1], x3_re.val[1], x3_im.val[1], v_re.val[3], v_im.val[3]); + } + { + FWD_TOP_LANE(v_re.val[0], v_im.val[0], x1_re.val[0], x1_im.val[0], s_re_im.val[1]); + FWD_TOP_LANE(v_re.val[1], v_im.val[1], x1_re.val[1], x1_im.val[1], 
s_re_im.val[1]); + FWD_TOP_LANE(v_re.val[2], v_im.val[2], x3_re.val[0], x3_im.val[0], s_re_im.val[1]); + FWD_TOP_LANE(v_re.val[3], v_im.val[3], x3_re.val[1], x3_im.val[1], s_re_im.val[1]); + FWD_BOT(x0_re.val[0], x0_im.val[0], x1_re.val[0], x1_im.val[0], v_re.val[0], v_im.val[0]); + FWD_BOT(x0_re.val[1], x0_im.val[1], x1_re.val[1], x1_im.val[1], v_re.val[1], v_im.val[1]); + FWD_BOTJ(x2_re.val[0], x2_im.val[0], x3_re.val[0], x3_im.val[0], v_re.val[2], v_im.val[2]); + FWD_BOTJ(x2_re.val[1], x2_im.val[1], x3_re.val[1], x3_im.val[1], v_re.val[3], v_im.val[3]); + vstorex2(r0 + blk, x0_re); + vstorex2(r1 + blk, x1_re); + vstorex2(r2 + blk, x2_re); + vstorex2(r3 + blk, x3_re); + vstorex2(i0 + blk, x0_im); + vstorex2(i1 + blk, x1_im); + vstorex2(i2 + blk, x2_im); + vstorex2(i3 + blk, x3_im); + } + } +} + +#if 0 +static void ZfN(iFFT_log2)(fpr *f) +{ + /* + y_re: 1 = (2 - 3) * 5 + (0 - 1) * 4 + y_im: 3 = (2 - 3) * 4 - (0 - 1) * 5 + x_re: 0 = 0 + 1 + x_im: 2 = 2 + 3 + + Turn out this vectorize code is too short to be executed, + the scalar version is consistently faster + + float64x2x2_t tmp; + float64x2_t v, s, t; + + // 0: 0, 2 + // 1: 1, 3 + + vload2(tmp, &f[0]); + vload(s, &fpr_gm_tab[4]); + + vfsub(v, tmp.val[0], tmp.val[1]); + vfadd(tmp.val[0], tmp.val[0], tmp.val[1]); + + // y_im: 3 = (2 - 3) * 4 - (0 - 1) * 5 + // y_re: 1 = (2 - 3) * 5 + (0 - 1) * 4 + vswap(t, v); + + vfmul_lane(tmp.val[1], s, t, 0); + vfcmla_90(tmp.val[1], t, s); + + vfmuln(tmp.val[0], tmp.val[0], 0.5); + vfmuln(tmp.val[1], tmp.val[1], 0.5); + + vswap(tmp.val[1], tmp.val[1]); + + vstore2(&f[0], tmp); + */ + + fpr x_re, x_im, y_re, y_im, s; + x_re = f[0]; + y_re = f[1]; + x_im = f[2]; + y_im = f[3]; + s = fpr_tab_log2[0] * 0.5; + + f[0] = (x_re + y_re) * 0.5; + f[2] = (x_im + y_im) * 0.5; + + x_re = (x_re - y_re) * s; + x_im = (x_im - y_im) * s; + + f[1] = x_im + x_re; + f[3] = x_im - x_re; +} + +static void ZfN(iFFT_log3)(fpr *f) +{ + /* + * Total instructions: 27 + y_re: 1 = (4 - 5) * 9 + (0 - 1) * 8 + y_re: 3 = (6 - 7) * 11 + (2 - 3) * 10 + y_im: 5 = (4 - 5) * 8 - (0 - 1) * 9 + y_im: 7 = (6 - 7) * 10 - (2 - 3) * 11 + x_re: 0 = 0 + 1 + x_re: 2 = 2 + 3 + x_im: 4 = 4 + 5 + x_im: 6 = 6 + 7 + */ + // 0: 0, 2 - 0: 0, 4 + // 1: 1, 3 - 1: 1, 5 + // 2: 4, 6 - 2: 2, 6 + // 3: 5, 7 - 3: 3, 7 + float64x2x4_t tmp; + float64x2x2_t x_re_im, y_re_im, v, s_re_im; + + vload2(x_re_im, &f[0]); + vload2(y_re_im, &f[4]); + + vfsub(v.val[0], x_re_im.val[0], x_re_im.val[1]); + vfsub(v.val[1], y_re_im.val[0], y_re_im.val[1]); + vfadd(x_re_im.val[0], x_re_im.val[0], x_re_im.val[1]); + vfadd(x_re_im.val[1], y_re_im.val[0], y_re_im.val[1]); + + // 0: 8, 10 + // 1: 9, 11 + vload2(s_re_im, &fpr_tab_log3[0]); + + vfmul(y_re_im.val[0], v.val[1], s_re_im.val[1]); + vfmla(y_re_im.val[0], y_re_im.val[0], v.val[0], s_re_im.val[0]); + vfmul(y_re_im.val[1], v.val[1], s_re_im.val[0]); + vfmls(y_re_im.val[1], y_re_im.val[1], v.val[0], s_re_im.val[1]); + + // x: 0,2 | 4,6 + // y: 1,3 | 5,7 + tmp.val[0] = vtrn1q_f64(x_re_im.val[0], y_re_im.val[0]); + tmp.val[1] = vtrn2q_f64(x_re_im.val[0], y_re_im.val[0]); + tmp.val[2] = vtrn1q_f64(x_re_im.val[1], y_re_im.val[1]); + tmp.val[3] = vtrn2q_f64(x_re_im.val[1], y_re_im.val[1]); + // tmp: 0,1 | 2,3 | 4,5 | 6,7 + /* + y_re: 2 = (4 - 6) * 4 + (0 - 2) * 4 + y_re: 3 = (5 - 7) * 4 + (1 - 3) * 4 + y_im: 6 = (4 - 6) * 4 - (0 - 2) * 4 + y_im: 7 = (5 - 7) * 4 - (1 - 3) * 4 + x_re: 0 = 0 + 2 + x_re: 1 = 1 + 3 + x_im: 4 = 4 + 6 + x_im: 5 = 5 + 7 + */ + s_re_im.val[0] = vld1q_dup_f64(&fpr_tab_log2[0]); + + vfadd(x_re_im.val[0], 
tmp.val[0], tmp.val[1]); + vfadd(x_re_im.val[1], tmp.val[2], tmp.val[3]); + vfsub(v.val[0], tmp.val[0], tmp.val[1]); + vfsub(v.val[1], tmp.val[2], tmp.val[3]); + + vfmuln(tmp.val[0], x_re_im.val[0], 0.25); + vfmuln(tmp.val[2], x_re_im.val[1], 0.25); + + vfmuln(s_re_im.val[0], s_re_im.val[0], 0.25); + + vfmul(y_re_im.val[0], v.val[0], s_re_im.val[0]); + vfmul(y_re_im.val[1], v.val[1], s_re_im.val[0]); + + vfadd(tmp.val[1], y_re_im.val[1], y_re_im.val[0]); + vfsub(tmp.val[3], y_re_im.val[1], y_re_im.val[0]); + + vstorex4(&f[0], tmp); +} + +static void ZfN(iFFT_log4)(fpr *f) +{ + /* + * ( 0, 8) - ( 1, 9) + * ( 4, 12) - ( 5, 13) + * + * ( 0, 8) + ( 1, 9) + * ( 4, 12) + ( 5, 13) + * + * ( 3, 11) - ( 2, 10) + * ( 7, 15) - ( 6, 14) + * + * ( 2, 10) + ( 3, 11) + * ( 6, 14) + ( 7, 15) + * + * ( 1, 9) = @ * ( 0, 1) + * ( 5, 13) = @ * ( 2, 3) + * + * ( 3, 11) = j@ * ( 0, 1) + * ( 7, 15) = j@ * ( 2, 3) + */ + + float64x2x4_t re, im, t; + float64x2x2_t t_re, t_im, s_re_im; + + vload4(re, &f[0]); + vload4(im, &f[8]); + + INV_TOPJ (t_re.val[0], t_im.val[0], re.val[0], im.val[0], re.val[1], im.val[1]); + INV_TOPJm(t_re.val[1], t_im.val[1], re.val[2], im.val[2], re.val[3], im.val[3]); + + vload2(s_re_im, &fpr_tab_log4[0]); + + INV_BOTJ (re.val[1], im.val[1], t_re.val[0], t_im.val[0], s_re_im.val[0], s_re_im.val[1]); + INV_BOTJm(re.val[3], im.val[3], t_re.val[1], t_im.val[1], s_re_im.val[0], s_re_im.val[1]); + + /* + * ( 0, 8) - ( 2, 10) + * ( 1, 9) - ( 3, 11) + * + * ( 0, 8) + ( 2, 10) + * ( 1, 9) + ( 3, 11) + * + * ( 2, 10) = @ * ( 0, 1) + * ( 3, 11) = @ * ( 0, 1) + * + * ( 6, 14) - ( 4, 12) + * ( 7, 15) - ( 5, 13) + * + * ( 4, 12) + ( 6, 14) + * ( 5, 13) + ( 7, 15) + * + * ( 6, 14) = j@ * ( 0, 1) + * ( 7, 15) = j@ * ( 0, 1) + */ + + // re: 0, 4 | 1, 5 | 2, 6 | 3, 7 + // im: 8, 12| 9, 13|10, 14|11, 15 + + transpose_f64(re, re, t, 0, 1, 0); + transpose_f64(re, re, t, 2, 3, 1); + transpose_f64(im, im, t, 0, 1, 2); + transpose_f64(im, im, t, 2, 3, 3); + + // re: 0, 1 | 4, 5 | 2, 3 | 6, 7 + // im: 8, 9 | 12, 13|10, 11| 14, 15 + t.val[0] = re.val[1]; + re.val[1] = re.val[2]; + re.val[2] = t.val[0]; + + t.val[1] = im.val[1]; + im.val[1] = im.val[2]; + im.val[2] = t.val[1]; + + // re: 0, 1 | 2, 3| 4, 5 | 6, 7 + // im: 8, 9 | 10, 11| 12, 13| 14, 15 + + INV_TOPJ (t_re.val[0], t_im.val[0], re.val[0], im.val[0], re.val[1], im.val[1]); + INV_TOPJm(t_re.val[1], t_im.val[1], re.val[2], im.val[2], re.val[3], im.val[3]); + + vload(s_re_im.val[0], &fpr_tab_log3[0]); + + INV_BOTJ_LANE (re.val[1], im.val[1], t_re.val[0], t_im.val[0], s_re_im.val[0]); + INV_BOTJm_LANE(re.val[3], im.val[3], t_re.val[1], t_im.val[1], s_re_im.val[0]); + + /* + * ( 0, 8) - ( 4, 12) + * ( 1, 9) - ( 5, 13) + * ( 0, 8) + ( 4, 12) + * ( 1, 9) + ( 5, 13) + * + * ( 2, 10) - ( 6, 14) + * ( 3, 11) - ( 7, 15) + * ( 2, 10) + ( 6, 14) + * ( 3, 11) + ( 7, 15) + * + * ( 4, 12) = @ * ( 0, 1) + * ( 5, 13) = @ * ( 0, 1) + * + * ( 6, 14) = @ * ( 0, 1) + * ( 7, 15) = @ * ( 0, 1) + */ + + INV_TOPJ(t_re.val[0], t_im.val[0], re.val[0], im.val[0], re.val[2], im.val[2]); + INV_TOPJ(t_re.val[1], t_im.val[1], re.val[1], im.val[1], re.val[3], im.val[3]); + + vfmuln(re.val[0], re.val[0], 0.12500000000); + vfmuln(re.val[1], re.val[1], 0.12500000000); + vfmuln(im.val[0], im.val[0], 0.12500000000); + vfmuln(im.val[1], im.val[1], 0.12500000000); + + s_re_im.val[0] = vld1q_dup_f64(&fpr_tab_log2[0]); + + vfmuln(s_re_im.val[0], s_re_im.val[0], 0.12500000000); + + vfmul(t_re.val[0], t_re.val[0], s_re_im.val[0]); + vfmul(t_re.val[1], t_re.val[1], s_re_im.val[0]); + 
vfmul(t_im.val[0], t_im.val[0], s_re_im.val[0]); + vfmul(t_im.val[1], t_im.val[1], s_re_im.val[0]); + + vfsub(im.val[2], t_im.val[0], t_re.val[0]); + vfsub(im.val[3], t_im.val[1], t_re.val[1]); + vfadd(re.val[2], t_im.val[0], t_re.val[0]); + vfadd(re.val[3], t_im.val[1], t_re.val[1]); + + vstorex4(&f[0], re); + vstorex4(&f[8], im); +} + +static + void ZfN(iFFT_log5)(fpr *f, const unsigned logn, const unsigned last) +{ + // Total SIMD register: 26 = 24 + 2 + float64x2x4_t x_re, x_im, y_re, y_im, t_re, t_im; // 24 + float64x2x2_t s_re_im; // 2 + const unsigned n = 1 << logn; + const unsigned hn = n >> 1; + + int level = logn; + const fpr *fpr_tab5 = fpr_table[level--], + *fpr_tab4 = fpr_table[level--], + *fpr_tab3 = fpr_table[level--], + *fpr_tab2 = fpr_table[level]; + int k2 = 0, k3 = 0, k4 = 0, k5 = 0; + + for (unsigned j = 0; j < hn; j += 16) + { + /* + * ( 0, 16) - ( 1, 17) + * ( 4, 20) - ( 5, 21) + * ( 0, 16) + ( 1, 17) + * ( 4, 20) + ( 5, 21) + * ( 1, 17) = @ * ( 0, 1) + * ( 5, 21) = @ * ( 2, 3) + * + * ( 2, 18) - ( 3, 19) + * ( 6, 22) - ( 7, 23) + * ( 2, 18) + ( 3, 19) + * ( 6, 22) + ( 7, 23) + * ( 3, 19) = j@ * ( 0, 1) + * ( 7, 23) = j@ * ( 2, 3) + * + * ( 8, 24) - ( 9, 25) + * ( 12, 28) - ( 13, 29) + * ( 8, 24) + ( 9, 25) + * ( 12, 28) + ( 13, 29) + * ( 9, 25) = @ * ( 4, 5) + * ( 13, 29) = @ * ( 6, 7) + * + * ( 10, 26) - ( 11, 27) + * ( 14, 30) - ( 15, 31) + * ( 10, 26) + ( 11, 27) + * ( 14, 30) + ( 15, 31) + * ( 11, 27) = j@ * ( 4, 5) + * ( 15, 31) = j@ * ( 6, 7) + */ + + vload4(x_re, &f[j]); + vload4(x_im, &f[j + hn]); + + INV_TOPJ(t_re.val[0], t_im.val[0], x_re.val[0], x_im.val[0], x_re.val[1], x_im.val[1]); + INV_TOPJm(t_re.val[2], t_im.val[2], x_re.val[2], x_im.val[2], x_re.val[3], x_im.val[3]); + + vload4(y_re, &f[j + 8]); + vload4(y_im, &f[j + 8 + hn]) + + INV_TOPJ(t_re.val[1], t_im.val[1], y_re.val[0], y_im.val[0], y_re.val[1], y_im.val[1]); + INV_TOPJm(t_re.val[3], t_im.val[3], y_re.val[2], y_im.val[2], y_re.val[3], y_im.val[3]); + + vload2(s_re_im, &fpr_tab5[k5]); + k5 += 4; + + INV_BOTJ (x_re.val[1], x_im.val[1], t_re.val[0], t_im.val[0], s_re_im.val[0], s_re_im.val[1]); + INV_BOTJm(x_re.val[3], x_im.val[3], t_re.val[2], t_im.val[2], s_re_im.val[0], s_re_im.val[1]); + + vload2(s_re_im, &fpr_tab5[k5]); + k5 += 4; + + INV_BOTJ (y_re.val[1], y_im.val[1], t_re.val[1], t_im.val[1], s_re_im.val[0], s_re_im.val[1]); + INV_BOTJm(y_re.val[3], y_im.val[3], t_re.val[3], t_im.val[3], s_re_im.val[0], s_re_im.val[1]); + + + // x_re: 0, 4 | 1, 5 | 2, 6 | 3, 7 + // y_re: 8, 12| 9, 13|10, 14|11, 15 + + transpose_f64(x_re, x_re, t_re, 0, 1, 0); + transpose_f64(x_re, x_re, t_re, 2, 3, 1); + transpose_f64(y_re, y_re, t_re, 0, 1, 2); + transpose_f64(y_re, y_re, t_re, 2, 3, 3); + + transpose_f64(x_im, x_im, t_im, 0, 1, 0); + transpose_f64(x_im, x_im, t_im, 2, 3, 1); + transpose_f64(y_im, y_im, t_im, 0, 1, 2); + transpose_f64(y_im, y_im, t_im, 2, 3, 3); + + // x_re: 0, 1 | 4, 5 | 2, 3 | 6, 7 + // y_re: 8, 9 | 12,13|10,11 |14, 15 + + t_re.val[0] = x_re.val[1]; + x_re.val[1] = x_re.val[2]; + x_re.val[2] = t_re.val[0]; + + t_re.val[1] = y_re.val[1]; + y_re.val[1] = y_re.val[2]; + y_re.val[2] = t_re.val[1]; + + + t_im.val[0] = x_im.val[1]; + x_im.val[1] = x_im.val[2]; + x_im.val[2] = t_im.val[0]; + + t_im.val[1] = y_im.val[1]; + y_im.val[1] = y_im.val[2]; + y_im.val[2] = t_im.val[1]; + // x_re: 0, 1 | 2, 3| 4, 5 | 6, 7 + // y_re: 8, 9 | 10, 11| 12, 13| 14, 15 + + /* + * ( 0, 16) - ( 2, 18) + * ( 1, 17) - ( 3, 19) + * ( 0, 16) + ( 2, 18) + * ( 1, 17) + ( 3, 19) + * ( 2, 18) = @ * ( 0, 1) + * ( 3, 
19) = @ * ( 0, 1) + * + * ( 4, 20) - ( 6, 22) + * ( 5, 21) - ( 7, 23) + * ( 4, 20) + ( 6, 22) + * ( 5, 21) + ( 7, 23) + * ( 6, 22) = j@ * ( 0, 1) + * ( 7, 23) = j@ * ( 0, 1) + * + * ( 8, 24) - ( 10, 26) + * ( 9, 25) - ( 11, 27) + * ( 8, 24) + ( 10, 26) + * ( 9, 25) + ( 11, 27) + * ( 10, 26) = @ * ( 2, 3) + * ( 11, 27) = @ * ( 2, 3) + * + * ( 12, 28) - ( 14, 30) + * ( 13, 29) - ( 15, 31) + * ( 12, 28) + ( 14, 30) + * ( 13, 29) + ( 15, 31) + * ( 14, 30) = j@ * ( 2, 3) + * ( 15, 31) = j@ * ( 2, 3) + */ + + INV_TOPJ (t_re.val[0], t_im.val[0], x_re.val[0], x_im.val[0], x_re.val[1], x_im.val[1]); + INV_TOPJm(t_re.val[1], t_im.val[1], x_re.val[2], x_im.val[2], x_re.val[3], x_im.val[3]); + + INV_TOPJ (t_re.val[2], t_im.val[2], y_re.val[0], y_im.val[0], y_re.val[1], y_im.val[1]); + INV_TOPJm(t_re.val[3], t_im.val[3], y_re.val[2], y_im.val[2], y_re.val[3], y_im.val[3]); + + vloadx2(s_re_im, &fpr_tab4[k4]); + k4 += 4; + + INV_BOTJ_LANE (x_re.val[1], x_im.val[1], t_re.val[0], t_im.val[0], s_re_im.val[0]); + INV_BOTJm_LANE(x_re.val[3], x_im.val[3], t_re.val[1], t_im.val[1], s_re_im.val[0]); + + INV_BOTJ_LANE (y_re.val[1], y_im.val[1], t_re.val[2], t_im.val[2], s_re_im.val[1]); + INV_BOTJm_LANE(y_re.val[3], y_im.val[3], t_re.val[3], t_im.val[3], s_re_im.val[1]); + + /* + * ( 0, 16) - ( 4, 20) + * ( 1, 17) - ( 5, 21) + * ( 0, 16) + ( 4, 20) + * ( 1, 17) + ( 5, 21) + * ( 4, 20) = @ * ( 0, 1) + * ( 5, 21) = @ * ( 0, 1) + * + * ( 2, 18) - ( 6, 22) + * ( 3, 19) - ( 7, 23) + * ( 2, 18) + ( 6, 22) + * ( 3, 19) + ( 7, 23) + * ( 6, 22) = @ * ( 0, 1) + * ( 7, 23) = @ * ( 0, 1) + * + * ( 8, 24) - ( 12, 28) + * ( 9, 25) - ( 13, 29) + * ( 8, 24) + ( 12, 28) + * ( 9, 25) + ( 13, 29) + * ( 12, 28) = j@ * ( 0, 1) + * ( 13, 29) = j@ * ( 0, 1) + * + * ( 10, 26) - ( 14, 30) + * ( 11, 27) - ( 15, 31) + * ( 10, 26) + ( 14, 30) + * ( 11, 27) + ( 15, 31) + * ( 14, 30) = j@ * ( 0, 1) + * ( 15, 31) = j@ * ( 0, 1) + */ + + INV_TOPJ (t_re.val[0], t_im.val[0], x_re.val[0], x_im.val[0], x_re.val[2], x_im.val[2]); + INV_TOPJ (t_re.val[1], t_im.val[1], x_re.val[1], x_im.val[1], x_re.val[3], x_im.val[3]); + + INV_TOPJm(t_re.val[2], t_im.val[2], y_re.val[0], y_im.val[0], y_re.val[2], y_im.val[2]); + INV_TOPJm(t_re.val[3], t_im.val[3], y_re.val[1], y_im.val[1], y_re.val[3], y_im.val[3]); + + vload(s_re_im.val[0], &fpr_tab3[k3]); + k3 += 2; + + INV_BOTJ_LANE(x_re.val[2], x_im.val[2], t_re.val[0], t_im.val[0], s_re_im.val[0]); + INV_BOTJ_LANE(x_re.val[3], x_im.val[3], t_re.val[1], t_im.val[1], s_re_im.val[0]); + + INV_BOTJm_LANE(y_re.val[2], y_im.val[2], t_re.val[2], t_im.val[2], s_re_im.val[0]); + INV_BOTJm_LANE(y_re.val[3], y_im.val[3], t_re.val[3], t_im.val[3], s_re_im.val[0]); + + /* + * ( 0, 16) - ( 8, 24) + * ( 1, 17) - ( 9, 25) + * ( 0, 16) + ( 8, 24) + * ( 1, 17) + ( 9, 25) + * ( 8, 24) = @ * ( 0, 1) + * ( 9, 25) = @ * ( 0, 1) + * + * ( 2, 18) - ( 10, 26) + * ( 3, 19) - ( 11, 27) + * ( 2, 18) + ( 10, 26) + * ( 3, 19) + ( 11, 27) + * ( 10, 26) = @ * ( 0, 1) + * ( 11, 27) = @ * ( 0, 1) + * + * ( 4, 20) - ( 12, 28) + * ( 5, 21) - ( 13, 29) + * ( 4, 20) + ( 12, 28) + * ( 5, 21) + ( 13, 29) + * ( 12, 28) = @ * ( 0, 1) + * ( 13, 29) = @ * ( 0, 1) + * + * ( 6, 22) - ( 14, 30) + * ( 7, 23) - ( 15, 31) + * ( 6, 22) + ( 14, 30) + * ( 7, 23) + ( 15, 31) + * ( 14, 30) = @ * ( 0, 1) + * ( 15, 31) = @ * ( 0, 1) + */ + + + if ( (j >> 4) & 1) + { + INV_TOPJmx4(t_re, t_im, x_re, x_im, y_re, y_im); + } + else + { + INV_TOPJx4(t_re, t_im, x_re, x_im, y_re, y_im); + } + + vload(s_re_im.val[0], &fpr_tab2[k2]); + k2 += 2 * ((j & 31) == 16); + + if 
(last) + { + vfmuln(s_re_im.val[0], s_re_im.val[0], fpr_p2_tab[logn]); + vfmulnx4(x_re, x_re, fpr_p2_tab[logn]); + vfmulnx4(x_im, x_im, fpr_p2_tab[logn]); + } + vstorex4(&f[j], x_re); + vstorex4(&f[j + hn], x_im); + + if (logn == 5) + { + // Special case in fpr_tab_log2 where re == im + vfmulx4_i(t_re, t_re, s_re_im.val[0]); + vfmulx4_i(t_im, t_im, s_re_im.val[0]); + + vfaddx4(y_re, t_im, t_re); + vfsubx4(y_im, t_im, t_re); + } + else + { + if ((j >> 4) & 1) + { + INV_BOTJm_LANEx4(y_re, y_im, t_re, t_im, s_re_im.val[0]); + } + else + { + INV_BOTJ_LANEx4(y_re, y_im, t_re, t_im, s_re_im.val[0]); + } + } + + vstorex4(&f[j + 8], y_re); + vstorex4(&f[j + 8 + hn], y_im); + } +} + +static + void ZfN(iFFT_logn1)(fpr *f, const unsigned logn, const unsigned last) +{ + // Total SIMD register 26 = 24 + 2 + float64x2x4_t a_re, a_im, b_re, b_im, t_re, t_im; // 24 + float64x2_t s_re_im; // 2 + + const unsigned n = 1 << logn; + const unsigned hn = n >> 1; + const unsigned ht = n >> 2; + + for (unsigned j = 0; j < ht; j+= 8) + { + vloadx4(a_re, &f[j]); + vloadx4(a_im, &f[j + hn]); + vloadx4(b_re, &f[j + ht]); + vloadx4(b_im, &f[j + ht + hn]); + + INV_TOPJx4(t_re, t_im, a_re, a_im, b_re, b_im); + + s_re_im = vld1q_dup_f64(&fpr_tab_log2[0]); + + if (last) + { + vfmuln(s_re_im, s_re_im, fpr_p2_tab[logn]); + vfmulnx4(a_re, a_re, fpr_p2_tab[logn]); + vfmulnx4(a_im, a_im, fpr_p2_tab[logn]); + } + + vstorex4(&f[j], a_re); + vstorex4(&f[j + hn], a_im); + + vfmulx4_i(t_re, t_re, s_re_im); + vfmulx4_i(t_im, t_im, s_re_im); + + vfaddx4(b_re, t_im, t_re); + vfsubx4(b_im, t_im, t_re); + + vstorex4(&f[j + ht], b_re); + vstorex4(&f[j + ht + hn], b_im); + } +} + +// static +// void ZfN(iFFT_logn2)(fpr *f, const unsigned logn, const unsigned level, unsigned last) +// { +// const unsigned int falcon_n = 1 << logn; +// const unsigned int hn = falcon_n >> 1; + +// // Total SIMD register: 26 = 16 + 8 + 2 +// float64x2x4_t t_re, t_im; // 8 +// float64x2x2_t x1_re, x2_re, x1_im, x2_im, +// y1_re, y2_re, y1_im, y2_im; // 16 +// float64x2_t s1_re_im, s2_re_im; // 2 + +// const fpr *fpr_inv_tab1 = NULL, *fpr_inv_tab2 = NULL; +// unsigned l, len, start, j, k1, k2; +// unsigned bar = logn - 4 - 2; +// unsigned Jm; + +// for (l = 4; l < logn - level - 1; l += 2) +// { +// len = 1 << l; +// last -= 1; +// fpr_inv_tab1 = fpr_table[bar--]; +// fpr_inv_tab2 = fpr_table[bar--]; +// k1 = 0; k2 = 0; + +// for (start = 0; start < hn; start += 1 << (l + 2)) +// { +// vload(s1_re_im, &fpr_inv_tab1[k1]); +// vload(s2_re_im, &fpr_inv_tab2[k2]); +// k1 += 2; +// k2 += 2 * ((start & 127) == 64); +// if (!last) +// { +// vfmuln(s2_re_im, s2_re_im, fpr_p2_tab[logn]); +// } +// Jm = (start >> (l+ 2)) & 1; +// for (j = start; j < start + len; j += 4) +// { +// /* +// Level 6 +// * ( 0, 64) - ( 16, 80) +// * ( 1, 65) - ( 17, 81) +// * ( 0, 64) + ( 16, 80) +// * ( 1, 65) + ( 17, 81) +// * ( 16, 80) = @ * ( 0, 1) +// * ( 17, 81) = @ * ( 0, 1) +// * +// * ( 2, 66) - ( 18, 82) +// * ( 3, 67) - ( 19, 83) +// * ( 2, 66) + ( 18, 82) +// * ( 3, 67) + ( 19, 83) +// * ( 18, 82) = @ * ( 0, 1) +// * ( 19, 83) = @ * ( 0, 1) +// * +// * ( 32, 96) - ( 48, 112) +// * ( 33, 97) - ( 49, 113) +// * ( 32, 96) + ( 48, 112) +// * ( 33, 97) + ( 49, 113) +// * ( 48, 112) = j@ * ( 0, 1) +// * ( 49, 113) = j@ * ( 0, 1) +// * +// * ( 34, 98) - ( 50, 114) +// * ( 35, 99) - ( 51, 115) +// * ( 34, 98) + ( 50, 114) +// * ( 35, 99) + ( 51, 115) +// * ( 50, 114) = j@ * ( 0, 1) +// * ( 51, 115) = j@ * ( 0, 1) +// */ +// // x1: 0 -> 4 | 64 -> 67 +// // y1: 16 -> 19 | 80 -> 81 +// // x2: 
32 -> 35 | 96 -> 99 +// // y2: 48 -> 51 | 112 -> 115 +// vloadx2(x1_re, &f[j]); +// vloadx2(x1_im, &f[j + hn]); +// vloadx2(y1_re, &f[j + len]); +// vloadx2(y1_im, &f[j + len + hn]); + +// INV_TOPJ (t_re.val[0], t_im.val[0], x1_re.val[0], x1_im.val[0], y1_re.val[0], y1_im.val[0]); +// INV_TOPJ (t_re.val[1], t_im.val[1], x1_re.val[1], x1_im.val[1], y1_re.val[1], y1_im.val[1]); + +// vloadx2(x2_re, &f[j + 2*len]); +// vloadx2(x2_im, &f[j + 2*len + hn]); +// vloadx2(y2_re, &f[j + 3*len]); +// vloadx2(y2_im, &f[j + 3*len + hn]); + +// INV_TOPJm(t_re.val[2], t_im.val[2], x2_re.val[0], x2_im.val[0], y2_re.val[0], y2_im.val[0]); +// INV_TOPJm(t_re.val[3], t_im.val[3], x2_re.val[1], x2_im.val[1], y2_re.val[1], y2_im.val[1]); + +// INV_BOTJ_LANE (y1_re.val[0], y1_im.val[0], t_re.val[0], t_im.val[0], s1_re_im); +// INV_BOTJ_LANE (y1_re.val[1], y1_im.val[1], t_re.val[1], t_im.val[1], s1_re_im); + +// INV_BOTJm_LANE(y2_re.val[0], y2_im.val[0], t_re.val[2], t_im.val[2], s1_re_im); +// INV_BOTJm_LANE(y2_re.val[1], y2_im.val[1], t_re.val[3], t_im.val[3], s1_re_im); +// /* +// * Level 7 +// * ( 0, 64) - ( 32, 96) +// * ( 1, 65) - ( 33, 97) +// * ( 0, 64) + ( 32, 96) +// * ( 1, 65) + ( 33, 97) +// * ( 32, 96) = @ * ( 0, 1) +// * ( 33, 97) = @ * ( 0, 1) +// * +// * ( 2, 66) - ( 34, 98) +// * ( 3, 67) - ( 35, 99) +// * ( 2, 66) + ( 34, 98) +// * ( 3, 67) + ( 35, 99) +// * ( 34, 98) = @ * ( 0, 1) +// * ( 35, 99) = @ * ( 0, 1) +// * ---- +// * ( 16, 80) - ( 48, 112) +// * ( 17, 81) - ( 49, 113) +// * ( 16, 80) + ( 48, 112) +// * ( 17, 81) + ( 49, 113) +// * ( 48, 112) = @ * ( 0, 1) +// * ( 49, 113) = @ * ( 0, 1) +// * +// * ( 18, 82) - ( 50, 114) +// * ( 19, 83) - ( 51, 115) +// * ( 18, 82) + ( 50, 114) +// * ( 19, 83) + ( 51, 115) +// * ( 50, 114) = @ * ( 0, 1) +// * ( 51, 115) = @ * ( 0, 1) +// */ + +// if (Jm) +// { +// INV_TOPJm(t_re.val[0], t_im.val[0], x1_re.val[0], x1_im.val[0], x2_re.val[0], x2_im.val[0]); +// INV_TOPJm(t_re.val[1], t_im.val[1], x1_re.val[1], x1_im.val[1], x2_re.val[1], x2_im.val[1]); + +// INV_TOPJm(t_re.val[2], t_im.val[2], y1_re.val[0], y1_im.val[0], y2_re.val[0], y2_im.val[0]); +// INV_TOPJm(t_re.val[3], t_im.val[3], y1_re.val[1], y1_im.val[1], y2_re.val[1], y2_im.val[1]); + +// INV_BOTJm_LANE(x2_re.val[0], x2_im.val[0], t_re.val[0], t_im.val[0], s2_re_im); +// INV_BOTJm_LANE(x2_re.val[1], x2_im.val[1], t_re.val[1], t_im.val[1], s2_re_im); +// INV_BOTJm_LANE(y2_re.val[0], y2_im.val[0], t_re.val[2], t_im.val[2], s2_re_im); +// INV_BOTJm_LANE(y2_re.val[1], y2_im.val[1], t_re.val[3], t_im.val[3], s2_re_im); +// } +// else +// { +// INV_TOPJ(t_re.val[0], t_im.val[0], x1_re.val[0], x1_im.val[0], x2_re.val[0], x2_im.val[0]); +// INV_TOPJ(t_re.val[1], t_im.val[1], x1_re.val[1], x1_im.val[1], x2_re.val[1], x2_im.val[1]); + +// INV_TOPJ(t_re.val[2], t_im.val[2], y1_re.val[0], y1_im.val[0], y2_re.val[0], y2_im.val[0]); +// INV_TOPJ(t_re.val[3], t_im.val[3], y1_re.val[1], y1_im.val[1], y2_re.val[1], y2_im.val[1]); + +// INV_BOTJ_LANE(x2_re.val[0], x2_im.val[0], t_re.val[0], t_im.val[0], s2_re_im); +// INV_BOTJ_LANE(x2_re.val[1], x2_im.val[1], t_re.val[1], t_im.val[1], s2_re_im); +// INV_BOTJ_LANE(y2_re.val[0], y2_im.val[0], t_re.val[2], t_im.val[2], s2_re_im); +// INV_BOTJ_LANE(y2_re.val[1], y2_im.val[1], t_re.val[3], t_im.val[3], s2_re_im); +// } + +// vstorex2(&f[j + 2*len], x2_re); +// vstorex2(&f[j + 2*len + hn], x2_im); + +// vstorex2(&f[j + 3*len], y2_re); +// vstorex2(&f[j + 3*len + hn], y2_im); + +// if (!last) +// { +// vfmuln(x1_re.val[0], x1_re.val[0], fpr_p2_tab[logn]); +// 
vfmuln(x1_re.val[1], x1_re.val[1], fpr_p2_tab[logn]); +// vfmuln(x1_im.val[0], x1_im.val[0], fpr_p2_tab[logn]); +// vfmuln(x1_im.val[1], x1_im.val[1], fpr_p2_tab[logn]); + +// vfmuln(y1_re.val[0], y1_re.val[0], fpr_p2_tab[logn]); +// vfmuln(y1_re.val[1], y1_re.val[1], fpr_p2_tab[logn]); +// vfmuln(y1_im.val[0], y1_im.val[0], fpr_p2_tab[logn]); +// vfmuln(y1_im.val[1], y1_im.val[1], fpr_p2_tab[logn]); +// } + +// vstorex2(&f[j], x1_re); +// vstorex2(&f[j + hn], x1_im); + +// vstorex2(&f[j + len], y1_re); +// vstorex2(&f[j + len + hn], y1_im); + +// } +// } +// } +// } + + + +static + void ZfN(iFFT_logn2)(fpr *f, const unsigned logn, const unsigned level, unsigned last) +{ + const unsigned int falcon_n = 1 << logn; + const unsigned int hn = falcon_n >> 1; + + // Total SIMD register: 26 = 16 + 8 + 2 + float64x2x4_t t_re, t_im; // 8 + float64x2x2_t x1_re, x2_re, x1_im, x2_im, + y1_re, y2_re, y1_im, y2_im; // 16 + float64x2_t s1_re_im, s2_re_im; // 2 + + const fpr *fpr_inv_tab1 = NULL, *fpr_inv_tab2 = NULL; + unsigned l, len, start, j, k1, k2; + unsigned bar = logn - 4; + + for (l = 4; l < logn - level - 1; l += 2) + { + len = 1 << l; + last -= 1; + fpr_inv_tab1 = fpr_table[bar--]; + fpr_inv_tab2 = fpr_table[bar--]; + k1 = 0; k2 = 0; + + for (start = 0; start < hn; start += 1 << (l + 2)) + { + vload(s1_re_im, &fpr_inv_tab1[k1]); + vload(s2_re_im, &fpr_inv_tab2[k2]); + k1 += 2; + k2 += 2 * ((start & 127) == 64); + if (!last) + { + vfmuln(s2_re_im, s2_re_im, fpr_p2_tab[logn]); + } + for (j = start; j < start + len; j += 4) + { + /* + Level 6 + * ( 0, 64) - ( 16, 80) + * ( 1, 65) - ( 17, 81) + * ( 0, 64) + ( 16, 80) + * ( 1, 65) + ( 17, 81) + * ( 16, 80) = @ * ( 0, 1) + * ( 17, 81) = @ * ( 0, 1) + * + * ( 2, 66) - ( 18, 82) + * ( 3, 67) - ( 19, 83) + * ( 2, 66) + ( 18, 82) + * ( 3, 67) + ( 19, 83) + * ( 18, 82) = @ * ( 0, 1) + * ( 19, 83) = @ * ( 0, 1) + * + * ( 32, 96) - ( 48, 112) + * ( 33, 97) - ( 49, 113) + * ( 32, 96) + ( 48, 112) + * ( 33, 97) + ( 49, 113) + * ( 48, 112) = j@ * ( 0, 1) + * ( 49, 113) = j@ * ( 0, 1) + * + * ( 34, 98) - ( 50, 114) + * ( 35, 99) - ( 51, 115) + * ( 34, 98) + ( 50, 114) + * ( 35, 99) + ( 51, 115) + * ( 50, 114) = j@ * ( 0, 1) + * ( 51, 115) = j@ * ( 0, 1) + */ + // x1: 0 -> 4 | 64 -> 67 + // y1: 16 -> 19 | 80 -> 81 + // x2: 32 -> 35 | 96 -> 99 + // y2: 48 -> 51 | 112 -> 115 + vloadx2(x1_re, &f[j]); + vloadx2(x1_im, &f[j + hn]); + vloadx2(y1_re, &f[j + len]); + vloadx2(y1_im, &f[j + len + hn]); + + INV_TOPJ (t_re.val[0], t_im.val[0], x1_re.val[0], x1_im.val[0], y1_re.val[0], y1_im.val[0]); + INV_TOPJ (t_re.val[1], t_im.val[1], x1_re.val[1], x1_im.val[1], y1_re.val[1], y1_im.val[1]); + + vloadx2(x2_re, &f[j + 2*len]); + vloadx2(x2_im, &f[j + 2*len + hn]); + vloadx2(y2_re, &f[j + 3*len]); + vloadx2(y2_im, &f[j + 3*len + hn]); + + INV_TOPJm(t_re.val[2], t_im.val[2], x2_re.val[0], x2_im.val[0], y2_re.val[0], y2_im.val[0]); + INV_TOPJm(t_re.val[3], t_im.val[3], x2_re.val[1], x2_im.val[1], y2_re.val[1], y2_im.val[1]); + + INV_BOTJ_LANE (y1_re.val[0], y1_im.val[0], t_re.val[0], t_im.val[0], s1_re_im); + INV_BOTJ_LANE (y1_re.val[1], y1_im.val[1], t_re.val[1], t_im.val[1], s1_re_im); + + INV_BOTJm_LANE(y2_re.val[0], y2_im.val[0], t_re.val[2], t_im.val[2], s1_re_im); + INV_BOTJm_LANE(y2_re.val[1], y2_im.val[1], t_re.val[3], t_im.val[3], s1_re_im); + /* + * Level 7 + * ( 0, 64) - ( 32, 96) + * ( 1, 65) - ( 33, 97) + * ( 0, 64) + ( 32, 96) + * ( 1, 65) + ( 33, 97) + * ( 32, 96) = @ * ( 0, 1) + * ( 33, 97) = @ * ( 0, 1) + * + * ( 2, 66) - ( 34, 98) + * ( 3, 67) - ( 35, 99) + * 
( 2, 66) + ( 34, 98) + * ( 3, 67) + ( 35, 99) + * ( 34, 98) = @ * ( 0, 1) + * ( 35, 99) = @ * ( 0, 1) + * ---- + * ( 16, 80) - ( 48, 112) + * ( 17, 81) - ( 49, 113) + * ( 16, 80) + ( 48, 112) + * ( 17, 81) + ( 49, 113) + * ( 48, 112) = @ * ( 0, 1) + * ( 49, 113) = @ * ( 0, 1) + * + * ( 18, 82) - ( 50, 114) + * ( 19, 83) - ( 51, 115) + * ( 18, 82) + ( 50, 114) + * ( 19, 83) + ( 51, 115) + * ( 50, 114) = @ * ( 0, 1) + * ( 51, 115) = @ * ( 0, 1) + */ + + + INV_TOPJ(t_re.val[0], t_im.val[0], x1_re.val[0], x1_im.val[0], x2_re.val[0], x2_im.val[0]); + INV_TOPJ(t_re.val[1], t_im.val[1], x1_re.val[1], x1_im.val[1], x2_re.val[1], x2_im.val[1]); + + INV_TOPJ(t_re.val[2], t_im.val[2], y1_re.val[0], y1_im.val[0], y2_re.val[0], y2_im.val[0]); + INV_TOPJ(t_re.val[3], t_im.val[3], y1_re.val[1], y1_im.val[1], y2_re.val[1], y2_im.val[1]); + + INV_BOTJ_LANE(x2_re.val[0], x2_im.val[0], t_re.val[0], t_im.val[0], s2_re_im); + INV_BOTJ_LANE(x2_re.val[1], x2_im.val[1], t_re.val[1], t_im.val[1], s2_re_im); + INV_BOTJ_LANE(y2_re.val[0], y2_im.val[0], t_re.val[2], t_im.val[2], s2_re_im); + INV_BOTJ_LANE(y2_re.val[1], y2_im.val[1], t_re.val[3], t_im.val[3], s2_re_im); + + vstorex2(&f[j + 2*len], x2_re); + vstorex2(&f[j + 2*len + hn], x2_im); + + vstorex2(&f[j + 3*len], y2_re); + vstorex2(&f[j + 3*len + hn], y2_im); + + if (!last) + { + vfmuln(x1_re.val[0], x1_re.val[0], fpr_p2_tab[logn]); + vfmuln(x1_re.val[1], x1_re.val[1], fpr_p2_tab[logn]); + vfmuln(x1_im.val[0], x1_im.val[0], fpr_p2_tab[logn]); + vfmuln(x1_im.val[1], x1_im.val[1], fpr_p2_tab[logn]); + + vfmuln(y1_re.val[0], y1_re.val[0], fpr_p2_tab[logn]); + vfmuln(y1_re.val[1], y1_re.val[1], fpr_p2_tab[logn]); + vfmuln(y1_im.val[0], y1_im.val[0], fpr_p2_tab[logn]); + vfmuln(y1_im.val[1], y1_im.val[1], fpr_p2_tab[logn]); + } + + vstorex2(&f[j], x1_re); + vstorex2(&f[j + hn], x1_im); + + vstorex2(&f[j + len], y1_re); + vstorex2(&f[j + len + hn], y1_im); + + } + // + start += 1 << (l + 2); + if (start >= hn) break; + + vload(s1_re_im, &fpr_inv_tab1[k1]); + vload(s2_re_im, &fpr_inv_tab2[k2]); + k1 += 2; + k2 += 2 * ((start & 127) == 64); + if (!last) + { + vfmuln(s2_re_im, s2_re_im, fpr_p2_tab[logn]); + } + + for (j = start; j < start + len; j += 4) + { + /* + Level 6 + * ( 0, 64) - ( 16, 80) + * ( 1, 65) - ( 17, 81) + * ( 0, 64) + ( 16, 80) + * ( 1, 65) + ( 17, 81) + * ( 16, 80) = @ * ( 0, 1) + * ( 17, 81) = @ * ( 0, 1) + * + * ( 2, 66) - ( 18, 82) + * ( 3, 67) - ( 19, 83) + * ( 2, 66) + ( 18, 82) + * ( 3, 67) + ( 19, 83) + * ( 18, 82) = @ * ( 0, 1) + * ( 19, 83) = @ * ( 0, 1) + * + * ( 32, 96) - ( 48, 112) + * ( 33, 97) - ( 49, 113) + * ( 32, 96) + ( 48, 112) + * ( 33, 97) + ( 49, 113) + * ( 48, 112) = j@ * ( 0, 1) + * ( 49, 113) = j@ * ( 0, 1) + * + * ( 34, 98) - ( 50, 114) + * ( 35, 99) - ( 51, 115) + * ( 34, 98) + ( 50, 114) + * ( 35, 99) + ( 51, 115) + * ( 50, 114) = j@ * ( 0, 1) + * ( 51, 115) = j@ * ( 0, 1) + */ + // x1: 0 -> 4 | 64 -> 67 + // y1: 16 -> 19 | 80 -> 81 + // x2: 32 -> 35 | 96 -> 99 + // y2: 48 -> 51 | 112 -> 115 + vloadx2(x1_re, &f[j]); + vloadx2(x1_im, &f[j + hn]); + vloadx2(y1_re, &f[j + len]); + vloadx2(y1_im, &f[j + len + hn]); + + INV_TOPJ (t_re.val[0], t_im.val[0], x1_re.val[0], x1_im.val[0], y1_re.val[0], y1_im.val[0]); + INV_TOPJ (t_re.val[1], t_im.val[1], x1_re.val[1], x1_im.val[1], y1_re.val[1], y1_im.val[1]); + + vloadx2(x2_re, &f[j + 2*len]); + vloadx2(x2_im, &f[j + 2*len + hn]); + vloadx2(y2_re, &f[j + 3*len]); + vloadx2(y2_im, &f[j + 3*len + hn]); + + INV_TOPJm(t_re.val[2], t_im.val[2], x2_re.val[0], x2_im.val[0], y2_re.val[0], 
y2_im.val[0]); + INV_TOPJm(t_re.val[3], t_im.val[3], x2_re.val[1], x2_im.val[1], y2_re.val[1], y2_im.val[1]); + + INV_BOTJ_LANE (y1_re.val[0], y1_im.val[0], t_re.val[0], t_im.val[0], s1_re_im); + INV_BOTJ_LANE (y1_re.val[1], y1_im.val[1], t_re.val[1], t_im.val[1], s1_re_im); + + INV_BOTJm_LANE(y2_re.val[0], y2_im.val[0], t_re.val[2], t_im.val[2], s1_re_im); + INV_BOTJm_LANE(y2_re.val[1], y2_im.val[1], t_re.val[3], t_im.val[3], s1_re_im); + /* + * Level 7 + * ( 0, 64) - ( 32, 96) + * ( 1, 65) - ( 33, 97) + * ( 0, 64) + ( 32, 96) + * ( 1, 65) + ( 33, 97) + * ( 32, 96) = @ * ( 0, 1) + * ( 33, 97) = @ * ( 0, 1) + * + * ( 2, 66) - ( 34, 98) + * ( 3, 67) - ( 35, 99) + * ( 2, 66) + ( 34, 98) + * ( 3, 67) + ( 35, 99) + * ( 34, 98) = @ * ( 0, 1) + * ( 35, 99) = @ * ( 0, 1) + * ---- + * ( 16, 80) - ( 48, 112) + * ( 17, 81) - ( 49, 113) + * ( 16, 80) + ( 48, 112) + * ( 17, 81) + ( 49, 113) + * ( 48, 112) = @ * ( 0, 1) + * ( 49, 113) = @ * ( 0, 1) + * + * ( 18, 82) - ( 50, 114) + * ( 19, 83) - ( 51, 115) + * ( 18, 82) + ( 50, 114) + * ( 19, 83) + ( 51, 115) + * ( 50, 114) = @ * ( 0, 1) + * ( 51, 115) = @ * ( 0, 1) + */ + + INV_TOPJm(t_re.val[0], t_im.val[0], x1_re.val[0], x1_im.val[0], x2_re.val[0], x2_im.val[0]); + INV_TOPJm(t_re.val[1], t_im.val[1], x1_re.val[1], x1_im.val[1], x2_re.val[1], x2_im.val[1]); + + INV_TOPJm(t_re.val[2], t_im.val[2], y1_re.val[0], y1_im.val[0], y2_re.val[0], y2_im.val[0]); + INV_TOPJm(t_re.val[3], t_im.val[3], y1_re.val[1], y1_im.val[1], y2_re.val[1], y2_im.val[1]); + + INV_BOTJm_LANE(x2_re.val[0], x2_im.val[0], t_re.val[0], t_im.val[0], s2_re_im); + INV_BOTJm_LANE(x2_re.val[1], x2_im.val[1], t_re.val[1], t_im.val[1], s2_re_im); + INV_BOTJm_LANE(y2_re.val[0], y2_im.val[0], t_re.val[2], t_im.val[2], s2_re_im); + INV_BOTJm_LANE(y2_re.val[1], y2_im.val[1], t_re.val[3], t_im.val[3], s2_re_im); + + vstorex2(&f[j + 2*len], x2_re); + vstorex2(&f[j + 2*len + hn], x2_im); + + vstorex2(&f[j + 3*len], y2_re); + vstorex2(&f[j + 3*len + hn], y2_im); + + if (!last) + { + vfmuln(x1_re.val[0], x1_re.val[0], fpr_p2_tab[logn]); + vfmuln(x1_re.val[1], x1_re.val[1], fpr_p2_tab[logn]); + vfmuln(x1_im.val[0], x1_im.val[0], fpr_p2_tab[logn]); + vfmuln(x1_im.val[1], x1_im.val[1], fpr_p2_tab[logn]); + + vfmuln(y1_re.val[0], y1_re.val[0], fpr_p2_tab[logn]); + vfmuln(y1_re.val[1], y1_re.val[1], fpr_p2_tab[logn]); + vfmuln(y1_im.val[0], y1_im.val[0], fpr_p2_tab[logn]); + vfmuln(y1_im.val[1], y1_im.val[1], fpr_p2_tab[logn]); + } + + vstorex2(&f[j], x1_re); + vstorex2(&f[j + hn], x1_im); + + vstorex2(&f[j + len], y1_re); + vstorex2(&f[j + len + hn], y1_im); + + } + // + } + } +} + + +/* +* Support logn from [1, 10] +* Can be easily extended to logn > 10 +*/ +void ZfN(iFFT)(fpr *f, const unsigned logn) +{ + const unsigned level = (logn - 5) & 1; + + switch (logn) + { + case 2: + ZfN(iFFT_log2)(f); + break; + + case 3: + ZfN(iFFT_log3)(f); + break; + + case 4: + ZfN(iFFT_log4)(f); + break; + + case 5: + ZfN(iFFT_log5)(f, 5, 1); + break; + + case 6: + ZfN(iFFT_log5)(f, logn, 0); + ZfN(iFFT_logn1)(f, logn, 1); + break; + + case 7: + case 9: + ZfN(iFFT_log5)(f, logn, 0); + ZfN(iFFT_logn2)(f, logn, level, 1); + break; + + case 8: + case 10: + ZfN(iFFT_log5)(f, logn, 0); + ZfN(iFFT_logn2)(f, logn, level, 0); + ZfN(iFFT_logn1)(f, logn, 1); + break; + + default: + break; + } +} +#endif + +// generic fft stuff + +#include "../commons_private.h" +#include "reim_fft_internal.h" +#include "reim_fft_private.h" + +void reim_ctwiddle(double* ra, double* ia, double* rb, double* ib, double omre, double omim); +void 
reim_citwiddle(double* ra, double* ia, double* rb, double* ib, double omre, double omim); + +void reim_fft16_ref(double* dre, double* dim, const void* pom); +void fill_reim_fft16_omegas(const double entry_pwr, double** omg); +void reim_fft8_ref(double* dre, double* dim, const void* pom); +void reim_twiddle_fft_ref(uint64_t h, double* re, double* im, double om[2]); +void fill_reim_twiddle_fft_ref(const double s, double** omg); +void reim_bitwiddle_fft_ref(uint64_t h, double* re, double* im, double om[4]); +void fill_reim_bitwiddle_fft_ref(const double s, double** omg); +void reim_fft_rec_16_ref(uint64_t m, double* re, double* im, double** omg); +void fill_reim_fft_rec_16_omegas(uint64_t m, double entry_pwr, double** omg); + +void fill_reim_twiddle_fft_omegas_ref(const double rs0, double** omg) { + (*omg)[0] = cos(2 * M_PI * rs0); + (*omg)[1] = sin(2 * M_PI * rs0); + *omg += 2; +} + +void fill_reim_bitwiddle_fft_omegas_ref(const double rs0, double** omg) { + double rs1 = 2. * rs0; + (*omg)[0] = cos(2 * M_PI * rs1); + (*omg)[1] = sin(2 * M_PI * rs1); + (*omg)[2] = cos(2 * M_PI * rs0); + (*omg)[3] = sin(2 * M_PI * rs0); + *omg += 4; +} + +#define reim_fft16_f reim_fft16_neon +#define reim_fft16_pom_offset 16 +#define fill_reim_fft16_omegas_f fill_reim_fft16_omegas_neon +// currently, m=4 uses the ref implem (corner case with low impact) TODO!! +#define reim_fft8_f reim_fft8_ref +#define reim_fft8_pom_offset 8 +#define fill_reim_fft8_omegas_f fill_reim_fft8_omegas +// currently, m=4 uses the ref implem (corner case with low impact) TODO!! +#define reim_fft4_f reim_fft4_ref +#define reim_fft4_pom_offset 4 +#define fill_reim_fft4_omegas_f fill_reim_fft4_omegas +// m = 2 will use the ref implem, since intrinsics don't provide any speed-up +#define reim_fft2_f reim_fft2_ref +#define reim_fft2_pom_offset 2 +#define fill_reim_fft2_omegas_f fill_reim_fft2_omegas + +// neon twiddle use the same omegas layout as the ref implem +#define reim_twiddle_fft_f reim_twiddle_fft_neon +#define reim_twiddle_fft_pom_offset 2 +#define fill_reim_twiddle_fft_omegas_f fill_reim_twiddle_fft_omegas_ref + +// neon bi-twiddle use the same omegas layout as the ref implem +#define reim_bitwiddle_fft_f reim_bitwiddle_fft_neon +#define reim_bitwiddle_fft_pom_offset 4 +#define fill_reim_bitwiddle_fft_omegas_f fill_reim_bitwiddle_fft_omegas_ref + +// template functions to produce +#define reim_fft_bfs_16_f reim_fft_bfs_16_neon +#define fill_reim_fft_bfs_16_omegas_f fill_reim_fft_bfs_16_omegas_neon +#define reim_fft_rec_16_f reim_fft_rec_16_neon +#define fill_reim_fft_rec_16_omegas_f fill_reim_fft_rec_16_omegas_neon +#define reim_fft_f reim_fft_neon +#define fill_reim_fft_omegas_f fill_reim_fft_omegas_neon + +#include "reim_fft_core_template.h" + +EXPORT REIM_FFT_PRECOMP* new_reim_fft_precomp_neon(uint32_t m, uint32_t num_buffers) { + const uint64_t OMG_SPACE = ceilto64b(2 * m * sizeof(double)); + const uint64_t BUF_SIZE = ceilto64b(2 * m * sizeof(double)); + void* reps = malloc(sizeof(REIM_FFT_PRECOMP) // base + + 63 // padding + + OMG_SPACE // tables //TODO 16? 
+ + num_buffers * BUF_SIZE // buffers + ); + uint64_t aligned_addr = ceilto64b((uint64_t)(reps) + sizeof(REIM_FFT_PRECOMP)); + REIM_FFT_PRECOMP* r = (REIM_FFT_PRECOMP*)reps; + r->m = m; + r->buf_size = BUF_SIZE; + r->powomegas = (double*)aligned_addr; + r->aligned_buffers = (void*)(aligned_addr + OMG_SPACE); + // fill in powomegas + double* omg = (double*)r->powomegas; + fill_reim_fft_omegas_f(m, 0.25, &omg); + if (((uint64_t)omg) - aligned_addr > OMG_SPACE) abort(); + // dispatch the right implementation + //{ + // if (CPU_SUPPORTS("fma")) { + // r->function = reim_fft_avx2_fma; + // } else { + // r->function = reim_fft_ref; + // } + //} + r->function = reim_fft_f; + return reps; +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft_private.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft_private.h new file mode 100644 index 0000000..a38acf2 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft_private.h @@ -0,0 +1,128 @@ +#ifndef SPQLIOS_REIM_FFT_PRIVATE_H +#define SPQLIOS_REIM_FFT_PRIVATE_H + +#include "../commons_private.h" +#include "reim_fft.h" + +#define STATIC_ASSERT(condition) (void)sizeof(char[-1 + 2 * !!(condition)]) + +typedef struct reim_twiddle_precomp REIM_FFTVEC_TWIDDLE_PRECOMP; +typedef struct reim_bitwiddle_precomp REIM_FFTVEC_BITWIDDLE_PRECOMP; + +typedef void (*FFT_FUNC)(const REIM_FFT_PRECOMP*, double*); +typedef void (*IFFT_FUNC)(const REIM_IFFT_PRECOMP*, double*); +typedef void (*FFTVEC_ADD_FUNC)(const REIM_FFTVEC_ADD_PRECOMP*, double*, const double*, const double*); +typedef void (*FFTVEC_SUB_FUNC)(const REIM_FFTVEC_SUB_PRECOMP*, double*, const double*, const double*); +typedef void (*FFTVEC_MUL_FUNC)(const REIM_FFTVEC_MUL_PRECOMP*, double*, const double*, const double*); +typedef void (*FFTVEC_ADDMUL_FUNC)(const REIM_FFTVEC_ADDMUL_PRECOMP*, double*, const double*, const double*); +typedef void (*FROM_ZNX32_FUNC)(const REIM_FROM_ZNX32_PRECOMP*, void*, const int32_t*); +typedef void (*FROM_ZNX64_FUNC)(const REIM_FROM_ZNX64_PRECOMP*, void*, const int64_t*); +typedef void (*FROM_TNX32_FUNC)(const REIM_FROM_TNX32_PRECOMP*, void*, const int32_t*); +typedef void (*TO_TNX32_FUNC)(const REIM_TO_TNX32_PRECOMP*, int32_t*, const void*); +typedef void (*TO_TNX_FUNC)(const REIM_TO_TNX_PRECOMP*, double*, const double*); +typedef void (*TO_ZNX64_FUNC)(const REIM_TO_ZNX64_PRECOMP*, int64_t*, const void*); +typedef void (*FFTVEC_TWIDDLE_FUNC)(const REIM_FFTVEC_TWIDDLE_PRECOMP*, void*, const void*, const void*); +typedef void (*FFTVEC_BITWIDDLE_FUNC)(const REIM_FFTVEC_BITWIDDLE_PRECOMP*, void*, uint64_t, const void*); + +typedef void (*FFTVEC_AUTOMORPHISM_APPLY)(const REIM_FFTVEC_AUTOMORPHISM_PRECOMP* precomp, int64_t p, double* r, + const double* a, uint64_t a_size); +typedef uint64_t (*FFTVEC_AUTOMORPHISM_APPLY_INPLACE_TMP_BYTES)(const REIM_FFTVEC_AUTOMORPHISM_PRECOMP* precomp); +typedef void (*FFTVEC_AUTOMORPHISM_APPLY_INPLACE)(const REIM_FFTVEC_AUTOMORPHISM_PRECOMP* precomp, int64_t p, double* a, + uint64_t a_size, uint8_t* tmp_bytes); +typedef struct reim_fftvec_automorphism_funcs { + FFTVEC_AUTOMORPHISM_APPLY apply; + FFTVEC_AUTOMORPHISM_APPLY_INPLACE apply_inplace; + FFTVEC_AUTOMORPHISM_APPLY_INPLACE_TMP_BYTES apply_inplace_tmp_bytes; +} FFTVEC_AUTOMORPHISM_FUNCS; + +typedef struct reim_fft_precomp { + FFT_FUNC function; + int64_t m; ///< complex dimension warning: reim uses n=2N=4m + uint64_t buf_size; ///< size of aligned_buffers (multiple of 64B) + double* powomegas; ///< 64B 
aligned + void* aligned_buffers; ///< 64B aligned +} REIM_FFT_PRECOMP; + +typedef struct reim_ifft_precomp { + IFFT_FUNC function; + int64_t m; // warning: reim uses n=2N=4m + uint64_t buf_size; ///< size of aligned_buffers (multiple of 64B) + double* powomegas; + void* aligned_buffers; +} REIM_IFFT_PRECOMP; + +typedef struct reim_add_precomp { + FFTVEC_ADD_FUNC function; + int64_t m; +} REIM_FFTVEC_ADD_PRECOMP; + +typedef struct reim_sub_precomp { + FFTVEC_SUB_FUNC function; + int64_t m; +} REIM_FFTVEC_SUB_PRECOMP; + +typedef struct reim_mul_precomp { + FFTVEC_MUL_FUNC function; + int64_t m; +} REIM_FFTVEC_MUL_PRECOMP; + +typedef struct reim_addmul_precomp { + FFTVEC_ADDMUL_FUNC function; + int64_t m; +} REIM_FFTVEC_ADDMUL_PRECOMP; + +typedef struct reim_fftvec_automorphism_precomp { + FFTVEC_AUTOMORPHISM_FUNCS function; + int64_t m; + uint64_t* irev; +} REIM_FFTVEC_AUTOMORPHISM_PRECOMP; + +struct reim_from_znx32_precomp { + FROM_ZNX32_FUNC function; + int64_t m; +}; + +struct reim_from_znx64_precomp { + FROM_ZNX64_FUNC function; + int64_t m; +}; + +struct reim_from_tnx32_precomp { + FROM_TNX32_FUNC function; + int64_t m; +}; + +struct reim_to_tnx32_precomp { + TO_TNX32_FUNC function; + int64_t m; + double divisor; +}; + +struct reim_to_tnx_precomp { + TO_TNX_FUNC function; + int64_t m; + double divisor; + uint32_t log2overhead; + double add_cst; + uint64_t mask_and; + uint64_t mask_or; + double sub_cst; +}; + +struct reim_to_znx64_precomp { + TO_ZNX64_FUNC function; + int64_t m; + double divisor; +}; + +struct reim_twiddle_precomp { + FFTVEC_TWIDDLE_FUNC function; + int64_t m; +}; + +struct reim_bitwiddle_precomp { + FFTVEC_BITWIDDLE_FUNC function; + int64_t m; +}; + +#endif // SPQLIOS_REIM_FFT_PRIVATE_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft_ref.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft_ref.c new file mode 100644 index 0000000..3033fa1 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fft_ref.c @@ -0,0 +1,450 @@ +#include "../commons_private.h" +#include "reim_fft_internal.h" +#include "reim_fft_private.h" + +EXPORT void reim_fft_simple(uint32_t m, void* data) { + static REIM_FFT_PRECOMP* p[31] = {0}; + REIM_FFT_PRECOMP** f = p + log2m(m); + if (!*f) *f = new_reim_fft_precomp(m, 0); + (*f)->function(*f, data); +} + +EXPORT void reim_ifft_simple(uint32_t m, void* data) { + static REIM_IFFT_PRECOMP* p[31] = {0}; + REIM_IFFT_PRECOMP** f = p + log2m(m); + if (!*f) *f = new_reim_ifft_precomp(m, 0); + (*f)->function(*f, data); +} + +EXPORT void reim_fftvec_add_simple(uint32_t m, void* r, const void* a, const void* b) { + static REIM_FFTVEC_ADD_PRECOMP* p[31] = {0}; + REIM_FFTVEC_ADD_PRECOMP** f = p + log2m(m); + if (!*f) *f = new_reim_fftvec_add_precomp(m); + (*f)->function(*f, r, a, b); +} + +EXPORT void reim_fftvec_sub_simple(uint32_t m, void* r, const void* a, const void* b) { + static REIM_FFTVEC_SUB_PRECOMP* p[31] = {0}; + REIM_FFTVEC_SUB_PRECOMP** f = p + log2m(m); + if (!*f) *f = new_reim_fftvec_sub_precomp(m); + (*f)->function(*f, r, a, b); +} + +EXPORT void reim_fftvec_mul_simple(uint32_t m, void* r, const void* a, const void* b) { + static REIM_FFTVEC_MUL_PRECOMP* p[31] = {0}; + REIM_FFTVEC_MUL_PRECOMP** f = p + log2m(m); + if (!*f) *f = new_reim_fftvec_mul_precomp(m); + (*f)->function(*f, r, a, b); +} + +EXPORT void reim_fftvec_addmul_simple(uint32_t m, void* r, const void* a, const void* b) { + static REIM_FFTVEC_ADDMUL_PRECOMP* p[31] = {0}; + 
REIM_FFTVEC_ADDMUL_PRECOMP** f = p + log2m(m); + if (!*f) *f = new_reim_fftvec_addmul_precomp(m); + (*f)->function(*f, r, a, b); +} + +void reim_ctwiddle(double* ra, double* ia, double* rb, double* ib, double omre, double omim) { + double newrt = *rb * omre - *ib * omim; + double newit = *rb * omim + *ib * omre; + *rb = *ra - newrt; + *ib = *ia - newit; + *ra = *ra + newrt; + *ia = *ia + newit; +} + +// i (omre + i omim) = -omim + i omre +void reim_citwiddle(double* ra, double* ia, double* rb, double* ib, double omre, double omim) { + double newrt = -*rb * omim - *ib * omre; + double newit = *rb * omre - *ib * omim; + *rb = *ra - newrt; + *ib = *ia - newit; + *ra = *ra + newrt; + *ia = *ia + newit; +} + +void reim_fft16_ref(double* dre, double* dim, const void* pom) { + const double* om = (const double*)pom; + { + double omre = om[0]; + double omim = om[1]; + reim_ctwiddle(&dre[0], &dim[0], &dre[8], &dim[8], omre, omim); + reim_ctwiddle(&dre[1], &dim[1], &dre[9], &dim[9], omre, omim); + reim_ctwiddle(&dre[2], &dim[2], &dre[10], &dim[10], omre, omim); + reim_ctwiddle(&dre[3], &dim[3], &dre[11], &dim[11], omre, omim); + reim_ctwiddle(&dre[4], &dim[4], &dre[12], &dim[12], omre, omim); + reim_ctwiddle(&dre[5], &dim[5], &dre[13], &dim[13], omre, omim); + reim_ctwiddle(&dre[6], &dim[6], &dre[14], &dim[14], omre, omim); + reim_ctwiddle(&dre[7], &dim[7], &dre[15], &dim[15], omre, omim); + } + { + double omre = om[2]; + double omim = om[3]; + reim_ctwiddle(&dre[0], &dim[0], &dre[4], &dim[4], omre, omim); + reim_ctwiddle(&dre[1], &dim[1], &dre[5], &dim[5], omre, omim); + reim_ctwiddle(&dre[2], &dim[2], &dre[6], &dim[6], omre, omim); + reim_ctwiddle(&dre[3], &dim[3], &dre[7], &dim[7], omre, omim); + reim_citwiddle(&dre[8], &dim[8], &dre[12], &dim[12], omre, omim); + reim_citwiddle(&dre[9], &dim[9], &dre[13], &dim[13], omre, omim); + reim_citwiddle(&dre[10], &dim[10], &dre[14], &dim[14], omre, omim); + reim_citwiddle(&dre[11], &dim[11], &dre[15], &dim[15], omre, omim); + } + { + double omare = om[4]; + double omaim = om[5]; + double ombre = om[6]; + double ombim = om[7]; + reim_ctwiddle(&dre[0], &dim[0], &dre[2], &dim[2], omare, omaim); + reim_ctwiddle(&dre[1], &dim[1], &dre[3], &dim[3], omare, omaim); + reim_citwiddle(&dre[4], &dim[4], &dre[6], &dim[6], omare, omaim); + reim_citwiddle(&dre[5], &dim[5], &dre[7], &dim[7], omare, omaim); + reim_ctwiddle(&dre[8], &dim[8], &dre[10], &dim[10], ombre, ombim); + reim_ctwiddle(&dre[9], &dim[9], &dre[11], &dim[11], ombre, ombim); + reim_citwiddle(&dre[12], &dim[12], &dre[14], &dim[14], ombre, ombim); + reim_citwiddle(&dre[13], &dim[13], &dre[15], &dim[15], ombre, ombim); + } + { + double omare = om[8]; + double ombre = om[9]; + double omcre = om[10]; + double omdre = om[11]; + double omaim = om[12]; + double ombim = om[13]; + double omcim = om[14]; + double omdim = om[15]; + reim_ctwiddle(&dre[0], &dim[0], &dre[1], &dim[1], omare, omaim); + reim_citwiddle(&dre[2], &dim[2], &dre[3], &dim[3], omare, omaim); + reim_ctwiddle(&dre[4], &dim[4], &dre[5], &dim[5], ombre, ombim); + reim_citwiddle(&dre[6], &dim[6], &dre[7], &dim[7], ombre, ombim); + reim_ctwiddle(&dre[8], &dim[8], &dre[9], &dim[9], omcre, omcim); + reim_citwiddle(&dre[10], &dim[10], &dre[11], &dim[11], omcre, omcim); + reim_ctwiddle(&dre[12], &dim[12], &dre[13], &dim[13], omdre, omdim); + reim_citwiddle(&dre[14], &dim[14], &dre[15], &dim[15], omdre, omdim); + } +} + +void fill_reim_fft16_omegas(const double entry_pwr, double** omg) { + const double j_pow = 1. / 8.; + const double k_pow = 1. 
/ 16.; + const double pin = entry_pwr / 2.; + const double pin_2 = entry_pwr / 4.; + const double pin_4 = entry_pwr / 8.; + const double pin_8 = entry_pwr / 16.; + // 0 and 1 are real and imag of om + (*omg)[0] = cos(2. * M_PI * pin); + (*omg)[1] = sin(2. * M_PI * pin); + // 2 and 3 are real and imag of om^1/2 + (*omg)[2] = cos(2. * M_PI * (pin_2)); + (*omg)[3] = sin(2. * M_PI * (pin_2)); + // (4,5) and (6,7) are real and imag of om^1/4 and j.om^1/4 + (*omg)[4] = cos(2. * M_PI * (pin_4)); + (*omg)[5] = sin(2. * M_PI * (pin_4)); + (*omg)[6] = cos(2. * M_PI * (pin_4 + j_pow)); + (*omg)[7] = sin(2. * M_PI * (pin_4 + j_pow)); + // ((8,9,10,11),(12,13,14,15)) are 4 reals then 4 imag of om^1/8*(1,k,j,kj) + (*omg)[8] = cos(2. * M_PI * (pin_8)); + (*omg)[9] = cos(2. * M_PI * (pin_8 + j_pow)); + (*omg)[10] = cos(2. * M_PI * (pin_8 + k_pow)); + (*omg)[11] = cos(2. * M_PI * (pin_8 + j_pow + k_pow)); + (*omg)[12] = sin(2. * M_PI * (pin_8)); + (*omg)[13] = sin(2. * M_PI * (pin_8 + j_pow)); + (*omg)[14] = sin(2. * M_PI * (pin_8 + k_pow)); + (*omg)[15] = sin(2. * M_PI * (pin_8 + j_pow + k_pow)); + *omg += 16; +} + +void reim_fft8_ref(double* dre, double* dim, const void* pom) { + const double* om = (const double*)pom; + { + double omre = om[0]; + double omim = om[1]; + reim_ctwiddle(&dre[0], &dim[0], &dre[4], &dim[4], omre, omim); + reim_ctwiddle(&dre[1], &dim[1], &dre[5], &dim[5], omre, omim); + reim_ctwiddle(&dre[2], &dim[2], &dre[6], &dim[6], omre, omim); + reim_ctwiddle(&dre[3], &dim[3], &dre[7], &dim[7], omre, omim); + } + { + double omare = om[2]; + double omaim = om[3]; + reim_ctwiddle(&dre[0], &dim[0], &dre[2], &dim[2], omare, omaim); + reim_ctwiddle(&dre[1], &dim[1], &dre[3], &dim[3], omare, omaim); + reim_citwiddle(&dre[4], &dim[4], &dre[6], &dim[6], omare, omaim); + reim_citwiddle(&dre[5], &dim[5], &dre[7], &dim[7], omare, omaim); + } + { + double omare = om[4]; + double ombre = om[5]; + double omaim = om[6]; + double ombim = om[7]; + reim_ctwiddle(&dre[0], &dim[0], &dre[1], &dim[1], omare, omaim); + reim_citwiddle(&dre[2], &dim[2], &dre[3], &dim[3], omare, omaim); + reim_ctwiddle(&dre[4], &dim[4], &dre[5], &dim[5], ombre, ombim); + reim_citwiddle(&dre[6], &dim[6], &dre[7], &dim[7], ombre, ombim); + } +} + +void fill_reim_fft8_omegas(const double entry_pwr, double** omg) { + const double j_pow = 1. / 8.; + const double pin = entry_pwr / 2.; + const double pin_2 = entry_pwr / 4.; + const double pin_4 = entry_pwr / 8.; + // 0 and 1 are real and imag of om + (*omg)[0] = cos(2. * M_PI * pin); + (*omg)[1] = sin(2. * M_PI * pin); + // 2 and 3 are real and imag of om^1/2 + (*omg)[2] = cos(2. * M_PI * (pin_2)); + (*omg)[3] = sin(2. * M_PI * (pin_2)); + // (4,5) and (6,7) are real and imag of om^1/4 and j.om^1/4 + (*omg)[4] = cos(2. * M_PI * (pin_4)); + (*omg)[5] = cos(2. * M_PI * (pin_4 + j_pow)); + (*omg)[6] = sin(2. * M_PI * (pin_4)); + (*omg)[7] = sin(2. 
* M_PI * (pin_4 + j_pow)); + *omg += 8; +} + +void reim_fft4_ref(double* dre, double* dim, const void* pom) { + const double* om = (const double*)pom; + { + double omare = om[0]; + double omaim = om[1]; + reim_ctwiddle(&dre[0], &dim[0], &dre[2], &dim[2], omare, omaim); + reim_ctwiddle(&dre[1], &dim[1], &dre[3], &dim[3], omare, omaim); + } + { + double omare = om[2]; + double omaim = om[3]; + reim_ctwiddle(&dre[0], &dim[0], &dre[1], &dim[1], omare, omaim); + reim_citwiddle(&dre[2], &dim[2], &dre[3], &dim[3], omare, omaim); + } +} + +void fill_reim_fft4_omegas(const double entry_pwr, double** omg) { + const double pin = entry_pwr / 2.; + const double pin_2 = entry_pwr / 4.; + // 0 and 1 are real and imag of om + (*omg)[0] = cos(2. * M_PI * pin); + (*omg)[1] = sin(2. * M_PI * pin); + // 2 and 3 are real and imag of om^1/2 + (*omg)[2] = cos(2. * M_PI * (pin_2)); + (*omg)[3] = sin(2. * M_PI * (pin_2)); + *omg += 4; +} + +void reim_fft2_ref(double* dre, double* dim, const void* pom) { + const double* om = (const double*)pom; + { + double omare = om[0]; + double omaim = om[1]; + reim_ctwiddle(&dre[0], &dim[0], &dre[1], &dim[1], omare, omaim); + } +} + +void fill_reim_fft2_omegas(const double entry_pwr, double** omg) { + const double pin = entry_pwr / 2.; + // 0 and 1 are real and imag of om + (*omg)[0] = cos(2. * M_PI * pin); + (*omg)[1] = sin(2. * M_PI * pin); + *omg += 2; +} + +void reim_twiddle_fft_ref(uint64_t h, double* re, double* im, double om[2]) { + for (uint64_t i = 0; i < h; ++i) { + reim_ctwiddle(&re[i], &im[i], &re[h + i], &im[h + i], om[0], om[1]); + } +} + +void reim_bitwiddle_fft_ref(uint64_t h, double* re, double* im, double om[4]) { + double* r0 = re; + double* r1 = re + h; + double* r2 = re + 2 * h; + double* r3 = re + 3 * h; + double* i0 = im; + double* i1 = im + h; + double* i2 = im + 2 * h; + double* i3 = im + 3 * h; + for (uint64_t i = 0; i < h; ++i) { + reim_ctwiddle(&r0[i], &i0[i], &r2[i], &i2[i], om[0], om[1]); + reim_ctwiddle(&r1[i], &i1[i], &r3[i], &i3[i], om[0], om[1]); + } + for (uint64_t i = 0; i < h; ++i) { + reim_ctwiddle(&r0[i], &i0[i], &r1[i], &i1[i], om[2], om[3]); + reim_citwiddle(&r2[i], &i2[i], &r3[i], &i3[i], om[2], om[3]); + } +} + +void reim_fft_bfs_16_ref(uint64_t m, double* re, double* im, double** omg) { + uint64_t log2m = log2(m); + uint64_t mm = m; + if (log2m % 2 != 0) { + uint64_t h = mm >> 1; + // do the first twiddle iteration normally + reim_twiddle_fft_ref(h, re, im, *omg); + *omg += 2; + mm = h; + } + while (mm > 16) { + uint64_t h = mm >> 2; + for (uint64_t off = 0; off < m; off += mm) { + reim_bitwiddle_fft_ref(h, re + off, im + off, *omg); + *omg += 4; + } + mm = h; + } + if (mm != 16) abort(); // bug! + for (uint64_t off = 0; off < m; off += 16) { + reim_fft16_ref(re + off, im + off, *omg); + *omg += 16; + } +} + +void fill_reim_fft_bfs_16_omegas(uint64_t m, double entry_pwr, double** omg) { + uint64_t log2m = log2(m); + uint64_t mm = m; + double ss = entry_pwr; + if (log2m % 2 != 0) { + uint64_t h = mm >> 1; + double s = ss / 2.; + // do the first twiddle iteration normally + (*omg)[0] = cos(2 * M_PI * s); + (*omg)[1] = sin(2 * M_PI * s); + *omg += 2; + mm = h; + ss = s; + } + while (mm > 16) { + uint64_t h = mm >> 2; + double s = ss / 4.; + for (uint64_t off = 0; off < m; off += mm) { + double rs0 = s + fracrevbits(off / mm) / 4.; + double rs1 = 2. 
* rs0; + (*omg)[0] = cos(2 * M_PI * rs1); + (*omg)[1] = sin(2 * M_PI * rs1); + (*omg)[2] = cos(2 * M_PI * rs0); + (*omg)[3] = sin(2 * M_PI * rs0); + *omg += 4; + } + mm = h; + ss = s; + } + if (mm != 16) abort(); // bug! + for (uint64_t off = 0; off < m; off += 16) { + double s = ss + fracrevbits(off / 16); + fill_reim_fft16_omegas(s, omg); + } +} + +void reim_fft_rec_16_ref(uint64_t m, double* re, double* im, double** omg) { + if (m <= 2048) return reim_fft_bfs_16_ref(m, re, im, omg); + const uint32_t h = m / 2; + reim_twiddle_fft_ref(h, re, im, *omg); + *omg += 2; + reim_fft_rec_16_ref(h, re, im, omg); + reim_fft_rec_16_ref(h, re + h, im + h, omg); +} + +void fill_reim_fft_rec_16_omegas(uint64_t m, double entry_pwr, double** omg) { + if (m <= 2048) return fill_reim_fft_bfs_16_omegas(m, entry_pwr, omg); + const uint64_t h = m / 2; + const double s = entry_pwr / 2; + (*omg)[0] = cos(2 * M_PI * s); + (*omg)[1] = sin(2 * M_PI * s); + *omg += 2; + fill_reim_fft_rec_16_omegas(h, s, omg); + fill_reim_fft_rec_16_omegas(h, s + 0.5, omg); +} + +void reim_fft_ref(const REIM_FFT_PRECOMP* precomp, double* dat) { + const int32_t m = precomp->m; + double* omg = precomp->powomegas; + double* re = dat; + double* im = dat + m; + if (m <= 16) { + switch (m) { + case 1: + return; + case 2: + return reim_fft2_ref(re, im, omg); + case 4: + return reim_fft4_ref(re, im, omg); + case 8: + return reim_fft8_ref(re, im, omg); + case 16: + return reim_fft16_ref(re, im, omg); + default: + abort(); // m is not a power of 2 + } + } + if (m <= 2048) return reim_fft_bfs_16_ref(m, re, im, &omg); + return reim_fft_rec_16_ref(m, re, im, &omg); +} + +EXPORT REIM_FFT_PRECOMP* new_reim_fft_precomp(uint32_t m, uint32_t num_buffers) { + const uint64_t OMG_SPACE = ceilto64b(2 * m * sizeof(double)); + const uint64_t BUF_SIZE = ceilto64b(2 * m * sizeof(double)); + void* reps = malloc(sizeof(REIM_FFT_PRECOMP) // base + + 63 // padding + + OMG_SPACE // tables //TODO 16? 
+ + num_buffers * BUF_SIZE // buffers + ); + uint64_t aligned_addr = ceilto64b((uint64_t)(reps) + sizeof(REIM_FFT_PRECOMP)); + REIM_FFT_PRECOMP* r = (REIM_FFT_PRECOMP*)reps; + r->m = m; + r->buf_size = BUF_SIZE; + r->powomegas = (double*)aligned_addr; + r->aligned_buffers = (void*)(aligned_addr + OMG_SPACE); + // fill in powomegas + double* omg = (double*)r->powomegas; + if (m <= 16) { + switch (m) { + case 1: + break; + case 2: + fill_reim_fft2_omegas(0.25, &omg); + break; + case 4: + fill_reim_fft4_omegas(0.25, &omg); + break; + case 8: + fill_reim_fft8_omegas(0.25, &omg); + break; + case 16: + fill_reim_fft16_omegas(0.25, &omg); + break; + default: + abort(); // m is not a power of 2 + } + } else if (m <= 2048) { + fill_reim_fft_bfs_16_omegas(m, 0.25, &omg); + } else { + fill_reim_fft_rec_16_omegas(m, 0.25, &omg); + } + if (((uint64_t)omg) - aligned_addr > OMG_SPACE) abort(); + // dispatch the right implementation + { + if (CPU_SUPPORTS("fma")) { + r->function = reim_fft_avx2_fma; + } else { + r->function = reim_fft_ref; + } + } + return reps; +} + +void reim_naive_fft(uint64_t m, double entry_pwr, double* re, double* im) { + if (m == 1) return; + // twiddle + const uint64_t h = m / 2; + const double s = entry_pwr / 2.; + const double sre = cos(2 * M_PI * s); + const double sim = sin(2 * M_PI * s); + for (uint64_t j = 0; j < h; ++j) { + double pre = re[h + j] * sre - im[h + j] * sim; + double pim = im[h + j] * sre + re[h + j] * sim; + re[h + j] = re[j] - pre; + im[h + j] = im[j] - pim; + re[j] += pre; + im[j] += pim; + } + reim_naive_fft(h, s, re, im); + reim_naive_fft(h, s + 0.5, re + h, im + h); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fftvec_fma.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fftvec_fma.c new file mode 100644 index 0000000..393d883 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fftvec_fma.c @@ -0,0 +1,137 @@ +#include +#include +#include + +#include "reim_fft_private.h" + +EXPORT void reim_fftvec_addmul_fma(const REIM_FFTVEC_ADDMUL_PRECOMP* precomp, double* r, const double* a, + const double* b) { + const uint64_t m = precomp->m; + double* rr_ptr = r; + double* ri_ptr = r + m; + const double* ar_ptr = a; + const double* ai_ptr = a + m; + const double* br_ptr = b; + const double* bi_ptr = b + m; + + const double* const rend_ptr = ri_ptr; + while (rr_ptr != rend_ptr) { + __m256d rr = _mm256_loadu_pd(rr_ptr); + __m256d ri = _mm256_loadu_pd(ri_ptr); + const __m256d ar = _mm256_loadu_pd(ar_ptr); + const __m256d ai = _mm256_loadu_pd(ai_ptr); + const __m256d br = _mm256_loadu_pd(br_ptr); + const __m256d bi = _mm256_loadu_pd(bi_ptr); + + rr = _mm256_fmsub_pd(ai, bi, rr); + rr = _mm256_fmsub_pd(ar, br, rr); + ri = _mm256_fmadd_pd(ar, bi, ri); + ri = _mm256_fmadd_pd(ai, br, ri); + + _mm256_storeu_pd(rr_ptr, rr); + _mm256_storeu_pd(ri_ptr, ri); + + rr_ptr += 4; + ri_ptr += 4; + ar_ptr += 4; + ai_ptr += 4; + br_ptr += 4; + bi_ptr += 4; + } +} + +EXPORT void reim_fftvec_mul_fma(const REIM_FFTVEC_MUL_PRECOMP* precomp, double* r, const double* a, const double* b) { + const uint64_t m = precomp->m; + double* rr_ptr = r; + double* ri_ptr = r + m; + const double* ar_ptr = a; + const double* ai_ptr = a + m; + const double* br_ptr = b; + const double* bi_ptr = b + m; + + const double* const rend_ptr = ri_ptr; + while (rr_ptr != rend_ptr) { + const __m256d ar = _mm256_loadu_pd(ar_ptr); + const __m256d ai = _mm256_loadu_pd(ai_ptr); + const __m256d br = _mm256_loadu_pd(br_ptr); + 
const __m256d bi = _mm256_loadu_pd(bi_ptr); + + const __m256d t1 = _mm256_mul_pd(ai, bi); + const __m256d t2 = _mm256_mul_pd(ar, bi); + + __m256d rr = _mm256_fmsub_pd(ar, br, t1); + __m256d ri = _mm256_fmadd_pd(ai, br, t2); + + _mm256_storeu_pd(rr_ptr, rr); + _mm256_storeu_pd(ri_ptr, ri); + + rr_ptr += 4; + ri_ptr += 4; + ar_ptr += 4; + ai_ptr += 4; + br_ptr += 4; + bi_ptr += 4; + } +} + +EXPORT void reim_fftvec_add_fma(const REIM_FFTVEC_ADD_PRECOMP* precomp, double* r, const double* a, const double* b) { + const uint64_t m = precomp->m; + double* rr_ptr = r; + double* ri_ptr = r + m; + const double* ar_ptr = a; + const double* ai_ptr = a + m; + const double* br_ptr = b; + const double* bi_ptr = b + m; + + const double* const rend_ptr = ri_ptr; + while (rr_ptr != rend_ptr) { + const __m256d ar = _mm256_loadu_pd(ar_ptr); + const __m256d ai = _mm256_loadu_pd(ai_ptr); + const __m256d br = _mm256_loadu_pd(br_ptr); + const __m256d bi = _mm256_loadu_pd(bi_ptr); + + __m256d rr = _mm256_add_pd(ar, br); + __m256d ri = _mm256_add_pd(ai, bi); + + _mm256_storeu_pd(rr_ptr, rr); + _mm256_storeu_pd(ri_ptr, ri); + + rr_ptr += 4; + ri_ptr += 4; + ar_ptr += 4; + ai_ptr += 4; + br_ptr += 4; + bi_ptr += 4; + } +} + +EXPORT void reim_fftvec_sub_fma(const REIM_FFTVEC_ADD_PRECOMP* precomp, double* r, const double* a, const double* b) { + const uint64_t m = precomp->m; + double* rr_ptr = r; + double* ri_ptr = r + m; + const double* ar_ptr = a; + const double* ai_ptr = a + m; + const double* br_ptr = b; + const double* bi_ptr = b + m; + + const double* const rend_ptr = ri_ptr; + while (rr_ptr != rend_ptr) { + const __m256d ar = _mm256_loadu_pd(ar_ptr); + const __m256d ai = _mm256_loadu_pd(ai_ptr); + const __m256d br = _mm256_loadu_pd(br_ptr); + const __m256d bi = _mm256_loadu_pd(bi_ptr); + + __m256d rr = _mm256_sub_pd(ar, br); + __m256d ri = _mm256_sub_pd(ai, bi); + + _mm256_storeu_pd(rr_ptr, rr); + _mm256_storeu_pd(ri_ptr, ri); + + rr_ptr += 4; + ri_ptr += 4; + ar_ptr += 4; + ai_ptr += 4; + br_ptr += 4; + bi_ptr += 4; + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fftvec_ref.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fftvec_ref.c new file mode 100644 index 0000000..9adbe48 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_fftvec_ref.c @@ -0,0 +1,242 @@ +#include + +#include "../commons_private.h" +#include "reim_fft_internal.h" +#include "reim_fft_private.h" + +// Computes X^i -> X^(p*i) in the fourier domain for a reim vector is size 2 * m * size +// This function cannot be evaluated in place. +EXPORT void reim_fftvec_automorphism_ref(const REIM_FFTVEC_AUTOMORPHISM_PRECOMP* tables, int64_t p, double* r, + const double* a, uint64_t a_size) { + const uint64_t m = tables->m; + const uint64_t nn = 2 * m; + const uint64_t* irev = tables->irev; + const uint64_t mask = (4 * m - 1); + const uint64_t conj = !((p & 3) == 1); + p = p & mask; + p *= 1 - (conj << 1); + if (a_size == 1) { + for (uint64_t i = 0; i < m; ++i) { + uint64_t i_rev = 2 * irev[i] + 1; + i_rev = (((p * i_rev) & mask) - 1) >> 1; + uint64_t j = irev[i_rev]; + r[i] = a[j]; + double x = a[j + m]; + r[i + m] = conj ? 
-x : x; + } + } else { + for (uint64_t i = 0; i < m; ++i) { + uint64_t i_rev = 2 * irev[i] + 1; + i_rev = (((p * i_rev) & mask) - 1) >> 1; + uint64_t j = irev[i_rev]; + for (uint64_t k = 0; k < a_size; ++k) { + uint64_t offset_re = k * nn; + uint64_t offset_im = offset_re + m; + r[i + offset_re] = a[j + offset_re]; + double x = a[j + offset_im]; + r[i + offset_im] = conj ? -x : x; + } + } + } +} + +// Computes the permutation index for an automorphism X^{i} -> X^{i*p} in the fourier domain. +EXPORT void reim_fftvec_automorphism_lut_precomp_ref(const REIM_FFTVEC_AUTOMORPHISM_PRECOMP* tables, int64_t p, + uint64_t* precomp // size m +) { + const uint64_t m = tables->m; + const uint64_t* irev = tables->irev; + const uint64_t mask = (4 * m - 1); + const uint64_t conj = !((p & 3) == 1); + p = p & mask; + p *= 1 - (conj << 1); + for (uint64_t i = 0; i < m; ++i) { + uint64_t i_rev = 2 * irev[i] + 1; + i_rev = (((p * i_rev) & mask) - 1) >> 1; + uint64_t j = irev[i_rev]; + precomp[i] = j; + } +} + +// Computes X^{i} -> X^{i*p} in the fourier domain for a reim vector of size m using a precomputed lut permutation. +void reim_fftvec_automorphism_with_lut_ref(const REIM_FFTVEC_AUTOMORPHISM_PRECOMP* tables, uint64_t* precomp, double* r, + const double* a) { + const uint64_t m = tables->m; + for (uint64_t i = 0; i < m; ++i) { + uint64_t j = precomp[i]; + r[i] = a[j]; + r[i + m] = a[j + m]; + } +} + +// Computes X^{i} -> X^{i*-p} in the fourier domain for a reim vector of size m using a precomputed lut permutation. +void reim_fftvec_automorphism_conj_with_lut_ref(const REIM_FFTVEC_AUTOMORPHISM_PRECOMP* tables, uint64_t* precomp, + double* r, const double* a) { + const uint64_t m = tables->m; + for (uint64_t i = 0; i < m; ++i) { + uint64_t j = precomp[i]; + r[i] = a[j]; + r[i + m] = -a[j + m]; + } +} + +// Returns the minimum number of temporary bytes used by reim_fftvec_automorphism_inplace_ref. +EXPORT uint64_t reim_fftvec_automorphism_inplace_tmp_bytes_ref(const REIM_FFTVEC_AUTOMORPHISM_PRECOMP* tables) { + const uint64_t m = tables->m; + return m * (2 * sizeof(double) + sizeof(uint64_t)); +} + +// Computes X^i -> X^(p*i) in the fourier domain for a reim vector is size 2 * m * a_size +// This function cannot be evaluated in place. +EXPORT void reim_fftvec_automorphism_inplace_ref(const REIM_FFTVEC_AUTOMORPHISM_PRECOMP* tables, int64_t p, double* a, + uint64_t a_size, + uint8_t* tmp_bytes // m * (2*sizeof(double) + sizeof(uint64_t)) +) { + const uint64_t m = tables->m; + const uint64_t nn = 2 * m; + const uint64_t* irev = tables->irev; + const uint64_t mask = (4 * m - 1); + const uint64_t conj = !((p & 3) == 1); + + double* tmp = (double*)tmp_bytes; + p = p & mask; + if (a_size == 1) { + p *= 1 - (conj << 1); + for (uint64_t i = 0; i < m; ++i) { + uint64_t i_rev = 2 * irev[i] + 1; + i_rev = (((p * i_rev) & mask) - 1) >> 1; + uint64_t j = irev[i_rev]; + tmp[i] = a[j]; + double x = a[j + m]; + tmp[i + m] = conj ? 
-x : x; + } + memcpy(a, tmp, nn * sizeof(double)); + } else { + uint64_t* lut = (uint64_t*)(tmp_bytes + nn * sizeof(double)); + reim_fftvec_automorphism_lut_precomp_ref(tables, p, lut); + for (uint64_t i = 0; i < a_size; ++i) { + if (conj == 1) { + reim_fftvec_automorphism_conj_with_lut_ref(tables, lut, tmp, a + i * nn); + } else { + reim_fftvec_automorphism_with_lut_ref(tables, lut, tmp, a + i * nn); + } + memcpy(a + i * nn, tmp, nn * sizeof(double)); + } + } +} + +EXPORT void reim_fftvec_addmul_ref(const REIM_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, + const double* b) { + const uint64_t m = tables->m; + for (uint64_t i = 0; i < m; ++i) { + double re = a[i] * b[i] - a[i + m] * b[i + m]; + double im = a[i] * b[i + m] + a[i + m] * b[i]; + r[i] += re; + r[i + m] += im; + } +} + +EXPORT void reim_fftvec_add_ref(const REIM_FFTVEC_ADD_PRECOMP* precomp, double* r, const double* a, const double* b) { + const uint64_t m = precomp->m; + for (uint64_t i = 0; i < m; ++i) { + double re = a[i] + b[i]; + double im = a[i + m] + b[i + m]; + r[i] = re; + r[i + m] = im; + } +} + +EXPORT void reim_fftvec_sub_ref(const REIM_FFTVEC_SUB_PRECOMP* precomp, double* r, const double* a, const double* b) { + const uint64_t m = precomp->m; + for (uint64_t i = 0; i < m; ++i) { + double re = a[i] - b[i]; + double im = a[i + m] - b[i + m]; + r[i] = re; + r[i + m] = im; + } +} + +EXPORT void reim_fftvec_mul_ref(const REIM_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b) { + const uint64_t m = tables->m; + for (uint64_t i = 0; i < m; ++i) { + double re = a[i] * b[i] - a[i + m] * b[i + m]; + double im = a[i] * b[i + m] + a[i + m] * b[i]; + r[i] = re; + r[i + m] = im; + } +} + +EXPORT REIM_FFTVEC_ADDMUL_PRECOMP* new_reim_fftvec_addmul_precomp(uint32_t m) { + REIM_FFTVEC_ADDMUL_PRECOMP* reps = malloc(sizeof(REIM_FFTVEC_ADDMUL_PRECOMP)); + reps->m = m; + if (CPU_SUPPORTS("fma")) { + if (m >= 4) { + reps->function = reim_fftvec_addmul_fma; + } else { + reps->function = reim_fftvec_addmul_ref; + } + } else { + reps->function = reim_fftvec_addmul_ref; + } + return reps; +} + +EXPORT REIM_FFTVEC_ADD_PRECOMP* new_reim_fftvec_add_precomp(uint32_t m) { + REIM_FFTVEC_ADD_PRECOMP* reps = malloc(sizeof(REIM_FFTVEC_ADD_PRECOMP)); + reps->m = m; + if (CPU_SUPPORTS("fma")) { + if (m >= 4) { + reps->function = reim_fftvec_add_fma; + } else { + reps->function = reim_fftvec_add_ref; + } + } else { + reps->function = reim_fftvec_add_ref; + } + return reps; +} + +EXPORT REIM_FFTVEC_SUB_PRECOMP* new_reim_fftvec_sub_precomp(uint32_t m) { + REIM_FFTVEC_SUB_PRECOMP* reps = malloc(sizeof(REIM_FFTVEC_SUB_PRECOMP)); + reps->m = m; + if (CPU_SUPPORTS("fma")) { + if (m >= 4) { + reps->function = reim_fftvec_sub_fma; + } else { + reps->function = reim_fftvec_sub_ref; + } + } else { + reps->function = reim_fftvec_sub_ref; + } + return reps; +} + +EXPORT REIM_FFTVEC_MUL_PRECOMP* new_reim_fftvec_mul_precomp(uint32_t m) { + REIM_FFTVEC_MUL_PRECOMP* reps = malloc(sizeof(REIM_FFTVEC_MUL_PRECOMP)); + reps->m = m; + if (CPU_SUPPORTS("fma")) { + if (m >= 4) { + reps->function = reim_fftvec_mul_fma; + } else { + reps->function = reim_fftvec_mul_ref; + } + } else { + reps->function = reim_fftvec_mul_ref; + } + return reps; +} + +EXPORT REIM_FFTVEC_AUTOMORPHISM_PRECOMP* new_reim_fftvec_automorphism_precomp(uint32_t m) { + REIM_FFTVEC_AUTOMORPHISM_PRECOMP* reps = malloc(sizeof(REIM_FFTVEC_AUTOMORPHISM_PRECOMP)); + reps->m = m; + reps->function.apply = reim_fftvec_automorphism_ref; + reps->function.apply_inplace = 
reim_fftvec_automorphism_inplace_ref; + reps->function.apply_inplace_tmp_bytes = reim_fftvec_automorphism_inplace_tmp_bytes_ref; + const uint64_t nn = 2 * m; + reps->irev = malloc(sizeof(uint64_t) * nn); + uint32_t lognn = log2m(nn); + for (uint32_t i = 0; i < nn; i++) { + reps->irev[i] = (uint64_t)revbits(lognn, i); + } + return reps; +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_ifft16_avx_fma.s b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_ifft16_avx_fma.s new file mode 100644 index 0000000..4657c5d --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_ifft16_avx_fma.s @@ -0,0 +1,192 @@ +#rdi datare ptr +#rsi dataim ptr +#rdx om ptr +.globl reim_ifft16_avx_fma +reim_ifft16_avx_fma: +vmovupd (%rdi),%ymm0 # ra0 +vmovupd 0x20(%rdi),%ymm1 # ra4 +vmovupd 0x40(%rdi),%ymm2 # ra8 +vmovupd 0x60(%rdi),%ymm3 # ra12 +vmovupd (%rsi),%ymm4 # ia0 +vmovupd 0x20(%rsi),%ymm5 # ia4 +vmovupd 0x40(%rsi),%ymm6 # ia8 +vmovupd 0x60(%rsi),%ymm7 # ia12 + +1: +vmovupd 0x00(%rdx),%ymm12 +vmovupd 0x20(%rdx),%ymm13 + +vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 # ymm8 contains re to mul (tw) +vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 # ymm9 contains re to mul (itw) +vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 # ymm10 contains im to mul (tw) +vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 # ymm11 contains im to mul (itw) +vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 # ymm0 contains re to add (tw) +vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 # ymm1 contains re to add (itw) +vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 # ymm2 contains im to add (tw) +vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 # ymm3 contains im to add (itw) + +vunpckhpd %ymm1,%ymm0,%ymm4 # (0,1) -> (0,4) +vunpckhpd %ymm3,%ymm2,%ymm6 # (2,3) -> (2,6) +vunpckhpd %ymm9,%ymm8,%ymm5 # (8,9) -> (1,5) +vunpckhpd %ymm11,%ymm10,%ymm7 # (10,11) -> (3,7) +vunpcklpd %ymm1,%ymm0,%ymm0 +vunpcklpd %ymm3,%ymm2,%ymm2 +vunpcklpd %ymm9,%ymm8,%ymm1 +vunpcklpd %ymm11,%ymm10,%ymm3 + +# invctwiddle Re:(ymm0,ymm4) and Im:(ymm2,ymm6) with omega=(ymm12,ymm13) +# invcitwiddle Re:(ymm1,ymm5) and Im:(ymm3,ymm7) with omega=(ymm12,ymm13) +vsubpd %ymm4,%ymm0,%ymm8 # retw +vsubpd %ymm5,%ymm1,%ymm9 # reitw +vsubpd %ymm6,%ymm2,%ymm10 # imtw +vsubpd %ymm7,%ymm3,%ymm11 # imitw +vaddpd %ymm4,%ymm0,%ymm0 +vaddpd %ymm5,%ymm1,%ymm1 +vaddpd %ymm6,%ymm2,%ymm2 +vaddpd %ymm7,%ymm3,%ymm3 +# multiply 8,9,10,11 by 12,13, result to: 4,5,6,7 +# twiddles use reom=ymm12, imom=ymm13 +# invtwiddles use reom=ymm13, imom=-ymm12 +vmulpd %ymm10,%ymm13,%ymm4 # imtw.omai (tw) +vmulpd %ymm11,%ymm12,%ymm5 # imitw.omar (itw) +vmulpd %ymm8,%ymm13,%ymm6 # retw.omai (tw) +vmulpd %ymm9,%ymm12,%ymm7 # reitw.omar (itw) +vfmsub231pd %ymm8,%ymm12,%ymm4 # rprod0 (tw) +vfmadd231pd %ymm9,%ymm13,%ymm5 # rprod4 (itw) +vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0 (tw) +vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw) + +vunpckhpd %ymm7,%ymm3,%ymm11 # (0,4) -> (0,1) +vunpckhpd %ymm5,%ymm1,%ymm9 # (2,6) -> (2,3) +vunpcklpd %ymm7,%ymm3,%ymm10 +vunpcklpd %ymm5,%ymm1,%ymm8 +vunpckhpd %ymm6,%ymm2,%ymm3 # (1,5) -> (8,9) +vunpckhpd %ymm4,%ymm0,%ymm1 # (3,7) -> (10,11) +vunpcklpd %ymm6,%ymm2,%ymm2 +vunpcklpd %ymm4,%ymm0,%ymm0 + +/* +vperm2f128 $0x31,%ymm10,%ymm2,%ymm6 +vperm2f128 $0x31,%ymm11,%ymm3,%ymm7 +vperm2f128 $0x20,%ymm10,%ymm2,%ymm4 +vperm2f128 $0x20,%ymm11,%ymm3,%ymm5 +vperm2f128 $0x31,%ymm8,%ymm0,%ymm2 +vperm2f128 $0x31,%ymm9,%ymm1,%ymm3 +vperm2f128 $0x20,%ymm8,%ymm0,%ymm0 +vperm2f128 $0x20,%ymm9,%ymm1,%ymm1 +*/ +2: +vmovupd 0x40(%rdx),%ymm12 +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omaiii'i' +vshufpd 
$0, %ymm12, %ymm12, %ymm12 # ymm12: omarrr'r' + +/* +vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 # ymm8 contains re to mul (tw) +vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 # ymm9 contains re to mul (itw) +vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 # ymm10 contains im to mul (tw) +vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 # ymm11 contains im to mul (itw) +vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 # ymm0 contains re to add (tw) +vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 # ymm1 contains re to add (itw) +vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 # ymm2 contains im to add (tw) +vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 # ymm3 contains im to add (itw) +*/ + +# invctwiddle Re:(ymm0,ymm8) and Im:(ymm2,ymm10) with omega=(ymm12,ymm13) +# invcitwiddle Re:(ymm1,ymm9) and Im:(ymm3,ymm11) with omega=(ymm12,ymm13) +vsubpd %ymm8,%ymm0,%ymm4 # retw +vsubpd %ymm9,%ymm1,%ymm5 # reitw +vsubpd %ymm10,%ymm2,%ymm6 # imtw +vsubpd %ymm11,%ymm3,%ymm7 # imitw +vaddpd %ymm8,%ymm0,%ymm0 +vaddpd %ymm9,%ymm1,%ymm1 +vaddpd %ymm10,%ymm2,%ymm2 +vaddpd %ymm11,%ymm3,%ymm3 +# multiply 4,5,6,7 by 12,13, result to 8,9,10,11 +# twiddles use reom=ymm12, imom=ymm13 +# invtwiddles use reom=ymm13, imom=-ymm12 +vmulpd %ymm6,%ymm13,%ymm8 # imtw.omai (tw) +vmulpd %ymm7,%ymm12,%ymm9 # imitw.omar (itw) +vmulpd %ymm4,%ymm13,%ymm10 # retw.omai (tw) +vmulpd %ymm5,%ymm12,%ymm11 # reitw.omar (itw) +vfmsub231pd %ymm4,%ymm12,%ymm8 # rprod0 (tw) +vfmadd231pd %ymm5,%ymm13,%ymm9 # rprod4 (itw) +vfmadd231pd %ymm6,%ymm12,%ymm10 # iprod0 (tw) +vfmsub231pd %ymm7,%ymm13,%ymm11 # iprod4 (itw) + +vperm2f128 $0x31,%ymm10,%ymm2,%ymm6 +vperm2f128 $0x31,%ymm11,%ymm3,%ymm7 +vperm2f128 $0x20,%ymm10,%ymm2,%ymm4 +vperm2f128 $0x20,%ymm11,%ymm3,%ymm5 +vperm2f128 $0x31,%ymm8,%ymm0,%ymm2 +vperm2f128 $0x31,%ymm9,%ymm1,%ymm3 +vperm2f128 $0x20,%ymm8,%ymm0,%ymm0 +vperm2f128 $0x20,%ymm9,%ymm1,%ymm1 + +3: +vmovupd 0x60(%rdx),%xmm12 +vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar + +# invctwiddle Re:(ymm0,ymm1) and Im:(ymm4,ymm5) with omega=(ymm12,ymm13) +# invcitwiddle Re:(ymm2,ymm3) and Im:(ymm6,ymm7) with omega=(ymm12,ymm13) +vsubpd %ymm1,%ymm0,%ymm8 # retw +vsubpd %ymm3,%ymm2,%ymm9 # reitw +vsubpd %ymm5,%ymm4,%ymm10 # imtw +vsubpd %ymm7,%ymm6,%ymm11 # imitw +vaddpd %ymm1,%ymm0,%ymm0 +vaddpd %ymm3,%ymm2,%ymm2 +vaddpd %ymm5,%ymm4,%ymm4 +vaddpd %ymm7,%ymm6,%ymm6 +# multiply 8,9,10,11 by 12,13, result to 1,3,5,7 +# twiddles use reom=ymm12, imom=ymm13 +# invtwiddles use reom=ymm13, imom=-ymm12 +vmulpd %ymm10,%ymm13,%ymm1 # imtw.omai (tw) +vmulpd %ymm11,%ymm12,%ymm3 # imitw.omar (itw) +vmulpd %ymm8,%ymm13,%ymm5 # retw.omai (tw) +vmulpd %ymm9,%ymm12,%ymm7 # reitw.omar (itw) +vfmsub231pd %ymm8,%ymm12,%ymm1 # rprod0 (tw) +vfmadd231pd %ymm9,%ymm13,%ymm3 # rprod4 (itw) +vfmadd231pd %ymm10,%ymm12,%ymm5 # iprod0 (tw) +vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw) + +4: +vmovupd 0x70(%rdx),%xmm12 +vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar + +# invctwiddle Re:(ymm0,ymm2) and Im:(ymm4,ymm6) with omega=(ymm12,ymm13) +# invctwiddle Re:(ymm1,ymm3) and Im:(ymm5,ymm7) with omega=(ymm12,ymm13) +vsubpd %ymm2,%ymm0,%ymm8 # retw1 +vsubpd %ymm3,%ymm1,%ymm9 # retw2 +vsubpd %ymm6,%ymm4,%ymm10 # imtw1 +vsubpd %ymm7,%ymm5,%ymm11 # imtw2 +vaddpd %ymm2,%ymm0,%ymm0 +vaddpd %ymm3,%ymm1,%ymm1 +vaddpd %ymm6,%ymm4,%ymm4 +vaddpd %ymm7,%ymm5,%ymm5 +# multiply 8,9,10,11 by 12,13, result to 2,3,6,7 +# twiddles use reom=ymm12, imom=ymm13 +vmulpd 
%ymm10,%ymm13,%ymm2 # imtw1.omai +vmulpd %ymm11,%ymm13,%ymm3 # imtw2.omai +vmulpd %ymm8,%ymm13,%ymm6 # retw1.omai +vmulpd %ymm9,%ymm13,%ymm7 # retw2.omai +vfmsub231pd %ymm8,%ymm12,%ymm2 # rprod0 +vfmsub231pd %ymm9,%ymm12,%ymm3 # rprod4 +vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0 +vfmadd231pd %ymm11,%ymm12,%ymm7 # iprod4 + +5: +vmovupd %ymm0,(%rdi) # ra0 +vmovupd %ymm1,0x20(%rdi) # ra4 +vmovupd %ymm2,0x40(%rdi) # ra8 +vmovupd %ymm3,0x60(%rdi) # ra12 +vmovupd %ymm4,(%rsi) # ia0 +vmovupd %ymm5,0x20(%rsi) # ia4 +vmovupd %ymm6,0x40(%rsi) # ia8 +vmovupd %ymm7,0x60(%rsi) # ia12 +vzeroupper +ret +.size reim_ifft16_avx_fma, .-reim_ifft16_avx_fma +.section .note.GNU-stack,"",@progbits diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_ifft16_avx_fma_win32.s b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_ifft16_avx_fma_win32.s new file mode 100644 index 0000000..cf554c5 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_ifft16_avx_fma_win32.s @@ -0,0 +1,228 @@ + .text + .p2align 4 + .globl reim_ifft16_avx_fma + .def reim_ifft16_avx_fma; .scl 2; .type 32; .endef +reim_ifft16_avx_fma: + + pushq %rdi + pushq %rsi + movq %rcx,%rdi + movq %rdx,%rsi + movq %r8,%rdx + subq $0x100,%rsp + movdqu %xmm6,(%rsp) + movdqu %xmm7,0x10(%rsp) + movdqu %xmm8,0x20(%rsp) + movdqu %xmm9,0x30(%rsp) + movdqu %xmm10,0x40(%rsp) + movdqu %xmm11,0x50(%rsp) + movdqu %xmm12,0x60(%rsp) + movdqu %xmm13,0x70(%rsp) + movdqu %xmm14,0x80(%rsp) + movdqu %xmm15,0x90(%rsp) + callq reim_ifft16_avx_fma_amd64 + movdqu (%rsp),%xmm6 + movdqu 0x10(%rsp),%xmm7 + movdqu 0x20(%rsp),%xmm8 + movdqu 0x30(%rsp),%xmm9 + movdqu 0x40(%rsp),%xmm10 + movdqu 0x50(%rsp),%xmm11 + movdqu 0x60(%rsp),%xmm12 + movdqu 0x70(%rsp),%xmm13 + movdqu 0x80(%rsp),%xmm14 + movdqu 0x90(%rsp),%xmm15 + addq $0x100,%rsp + popq %rsi + popq %rdi + retq + +#rdi datare ptr +#rsi dataim ptr +#rdx om ptr +#.globl reim_ifft16_avx_fma_amd64 +reim_ifft16_avx_fma_amd64: +vmovupd (%rdi),%ymm0 # ra0 +vmovupd 0x20(%rdi),%ymm1 # ra4 +vmovupd 0x40(%rdi),%ymm2 # ra8 +vmovupd 0x60(%rdi),%ymm3 # ra12 +vmovupd (%rsi),%ymm4 # ia0 +vmovupd 0x20(%rsi),%ymm5 # ia4 +vmovupd 0x40(%rsi),%ymm6 # ia8 +vmovupd 0x60(%rsi),%ymm7 # ia12 + +1: +vmovupd 0x00(%rdx),%ymm12 +vmovupd 0x20(%rdx),%ymm13 + +vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 # ymm8 contains re to mul (tw) +vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 # ymm9 contains re to mul (itw) +vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 # ymm10 contains im to mul (tw) +vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 # ymm11 contains im to mul (itw) +vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 # ymm0 contains re to add (tw) +vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 # ymm1 contains re to add (itw) +vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 # ymm2 contains im to add (tw) +vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 # ymm3 contains im to add (itw) + +vunpckhpd %ymm1,%ymm0,%ymm4 # (0,1) -> (0,4) +vunpckhpd %ymm3,%ymm2,%ymm6 # (2,3) -> (2,6) +vunpckhpd %ymm9,%ymm8,%ymm5 # (8,9) -> (1,5) +vunpckhpd %ymm11,%ymm10,%ymm7 # (10,11) -> (3,7) +vunpcklpd %ymm1,%ymm0,%ymm0 +vunpcklpd %ymm3,%ymm2,%ymm2 +vunpcklpd %ymm9,%ymm8,%ymm1 +vunpcklpd %ymm11,%ymm10,%ymm3 + +# invctwiddle Re:(ymm0,ymm4) and Im:(ymm2,ymm6) with omega=(ymm12,ymm13) +# invcitwiddle Re:(ymm1,ymm5) and Im:(ymm3,ymm7) with omega=(ymm12,ymm13) +vsubpd %ymm4,%ymm0,%ymm8 # retw +vsubpd %ymm5,%ymm1,%ymm9 # reitw +vsubpd %ymm6,%ymm2,%ymm10 # imtw +vsubpd %ymm7,%ymm3,%ymm11 # imitw +vaddpd %ymm4,%ymm0,%ymm0 +vaddpd %ymm5,%ymm1,%ymm1 +vaddpd %ymm6,%ymm2,%ymm2 +vaddpd %ymm7,%ymm3,%ymm3 +# 
multiply 8,9,10,11 by 12,13, result to: 4,5,6,7 +# twiddles use reom=ymm12, imom=ymm13 +# invtwiddles use reom=ymm13, imom=-ymm12 +vmulpd %ymm10,%ymm13,%ymm4 # imtw.omai (tw) +vmulpd %ymm11,%ymm12,%ymm5 # imitw.omar (itw) +vmulpd %ymm8,%ymm13,%ymm6 # retw.omai (tw) +vmulpd %ymm9,%ymm12,%ymm7 # reitw.omar (itw) +vfmsub231pd %ymm8,%ymm12,%ymm4 # rprod0 (tw) +vfmadd231pd %ymm9,%ymm13,%ymm5 # rprod4 (itw) +vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0 (tw) +vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw) + +vunpckhpd %ymm7,%ymm3,%ymm11 # (0,4) -> (0,1) +vunpckhpd %ymm5,%ymm1,%ymm9 # (2,6) -> (2,3) +vunpcklpd %ymm7,%ymm3,%ymm10 +vunpcklpd %ymm5,%ymm1,%ymm8 +vunpckhpd %ymm6,%ymm2,%ymm3 # (1,5) -> (8,9) +vunpckhpd %ymm4,%ymm0,%ymm1 # (3,7) -> (10,11) +vunpcklpd %ymm6,%ymm2,%ymm2 +vunpcklpd %ymm4,%ymm0,%ymm0 + +/* +vperm2f128 $0x31,%ymm10,%ymm2,%ymm6 +vperm2f128 $0x31,%ymm11,%ymm3,%ymm7 +vperm2f128 $0x20,%ymm10,%ymm2,%ymm4 +vperm2f128 $0x20,%ymm11,%ymm3,%ymm5 +vperm2f128 $0x31,%ymm8,%ymm0,%ymm2 +vperm2f128 $0x31,%ymm9,%ymm1,%ymm3 +vperm2f128 $0x20,%ymm8,%ymm0,%ymm0 +vperm2f128 $0x20,%ymm9,%ymm1,%ymm1 +*/ +2: +vmovupd 0x40(%rdx),%ymm12 +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omaiii'i' +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omarrr'r' + +/* +vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 # ymm8 contains re to mul (tw) +vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 # ymm9 contains re to mul (itw) +vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 # ymm10 contains im to mul (tw) +vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 # ymm11 contains im to mul (itw) +vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 # ymm0 contains re to add (tw) +vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 # ymm1 contains re to add (itw) +vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 # ymm2 contains im to add (tw) +vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 # ymm3 contains im to add (itw) +*/ + +# invctwiddle Re:(ymm0,ymm8) and Im:(ymm2,ymm10) with omega=(ymm12,ymm13) +# invcitwiddle Re:(ymm1,ymm9) and Im:(ymm3,ymm11) with omega=(ymm12,ymm13) +vsubpd %ymm8,%ymm0,%ymm4 # retw +vsubpd %ymm9,%ymm1,%ymm5 # reitw +vsubpd %ymm10,%ymm2,%ymm6 # imtw +vsubpd %ymm11,%ymm3,%ymm7 # imitw +vaddpd %ymm8,%ymm0,%ymm0 +vaddpd %ymm9,%ymm1,%ymm1 +vaddpd %ymm10,%ymm2,%ymm2 +vaddpd %ymm11,%ymm3,%ymm3 +# multiply 4,5,6,7 by 12,13, result to 8,9,10,11 +# twiddles use reom=ymm12, imom=ymm13 +# invtwiddles use reom=ymm13, imom=-ymm12 +vmulpd %ymm6,%ymm13,%ymm8 # imtw.omai (tw) +vmulpd %ymm7,%ymm12,%ymm9 # imitw.omar (itw) +vmulpd %ymm4,%ymm13,%ymm10 # retw.omai (tw) +vmulpd %ymm5,%ymm12,%ymm11 # reitw.omar (itw) +vfmsub231pd %ymm4,%ymm12,%ymm8 # rprod0 (tw) +vfmadd231pd %ymm5,%ymm13,%ymm9 # rprod4 (itw) +vfmadd231pd %ymm6,%ymm12,%ymm10 # iprod0 (tw) +vfmsub231pd %ymm7,%ymm13,%ymm11 # iprod4 (itw) + +vperm2f128 $0x31,%ymm10,%ymm2,%ymm6 +vperm2f128 $0x31,%ymm11,%ymm3,%ymm7 +vperm2f128 $0x20,%ymm10,%ymm2,%ymm4 +vperm2f128 $0x20,%ymm11,%ymm3,%ymm5 +vperm2f128 $0x31,%ymm8,%ymm0,%ymm2 +vperm2f128 $0x31,%ymm9,%ymm1,%ymm3 +vperm2f128 $0x20,%ymm8,%ymm0,%ymm0 +vperm2f128 $0x20,%ymm9,%ymm1,%ymm1 + +3: +vmovupd 0x60(%rdx),%xmm12 +vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar + +# invctwiddle Re:(ymm0,ymm1) and Im:(ymm4,ymm5) with omega=(ymm12,ymm13) +# invcitwiddle Re:(ymm2,ymm3) and Im:(ymm6,ymm7) with omega=(ymm12,ymm13) +vsubpd %ymm1,%ymm0,%ymm8 # retw +vsubpd %ymm3,%ymm2,%ymm9 # reitw +vsubpd %ymm5,%ymm4,%ymm10 # imtw +vsubpd %ymm7,%ymm6,%ymm11 # imitw +vaddpd %ymm1,%ymm0,%ymm0 +vaddpd %ymm3,%ymm2,%ymm2 +vaddpd %ymm5,%ymm4,%ymm4 +vaddpd 
%ymm7,%ymm6,%ymm6 +# multiply 8,9,10,11 by 12,13, result to 1,3,5,7 +# twiddles use reom=ymm12, imom=ymm13 +# invtwiddles use reom=ymm13, imom=-ymm12 +vmulpd %ymm10,%ymm13,%ymm1 # imtw.omai (tw) +vmulpd %ymm11,%ymm12,%ymm3 # imitw.omar (itw) +vmulpd %ymm8,%ymm13,%ymm5 # retw.omai (tw) +vmulpd %ymm9,%ymm12,%ymm7 # reitw.omar (itw) +vfmsub231pd %ymm8,%ymm12,%ymm1 # rprod0 (tw) +vfmadd231pd %ymm9,%ymm13,%ymm3 # rprod4 (itw) +vfmadd231pd %ymm10,%ymm12,%ymm5 # iprod0 (tw) +vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw) + +4: +vmovupd 0x70(%rdx),%xmm12 +vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar + +# invctwiddle Re:(ymm0,ymm2) and Im:(ymm4,ymm6) with omega=(ymm12,ymm13) +# invctwiddle Re:(ymm1,ymm3) and Im:(ymm5,ymm7) with omega=(ymm12,ymm13) +vsubpd %ymm2,%ymm0,%ymm8 # retw1 +vsubpd %ymm3,%ymm1,%ymm9 # retw2 +vsubpd %ymm6,%ymm4,%ymm10 # imtw1 +vsubpd %ymm7,%ymm5,%ymm11 # imtw2 +vaddpd %ymm2,%ymm0,%ymm0 +vaddpd %ymm3,%ymm1,%ymm1 +vaddpd %ymm6,%ymm4,%ymm4 +vaddpd %ymm7,%ymm5,%ymm5 +# multiply 8,9,10,11 by 12,13, result to 2,3,6,7 +# twiddles use reom=ymm12, imom=ymm13 +vmulpd %ymm10,%ymm13,%ymm2 # imtw1.omai +vmulpd %ymm11,%ymm13,%ymm3 # imtw2.omai +vmulpd %ymm8,%ymm13,%ymm6 # retw1.omai +vmulpd %ymm9,%ymm13,%ymm7 # retw2.omai +vfmsub231pd %ymm8,%ymm12,%ymm2 # rprod0 +vfmsub231pd %ymm9,%ymm12,%ymm3 # rprod4 +vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0 +vfmadd231pd %ymm11,%ymm12,%ymm7 # iprod4 + +5: +vmovupd %ymm0,(%rdi) # ra0 +vmovupd %ymm1,0x20(%rdi) # ra4 +vmovupd %ymm2,0x40(%rdi) # ra8 +vmovupd %ymm3,0x60(%rdi) # ra12 +vmovupd %ymm4,(%rsi) # ia0 +vmovupd %ymm5,0x20(%rsi) # ia4 +vmovupd %ymm6,0x40(%rsi) # ia8 +vmovupd %ymm7,0x60(%rsi) # ia12 +vzeroupper +ret diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_ifft4_avx_fma.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_ifft4_avx_fma.c new file mode 100644 index 0000000..5266c95 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_ifft4_avx_fma.c @@ -0,0 +1,62 @@ +#include +#include +#include + +#include "reim_fft_private.h" + +__always_inline void reim_invctwiddle_avx_fma(__m128d* ra, __m128d* rb, __m128d* ia, __m128d* ib, const __m128d omre, + const __m128d omim) { + __m128d rdiff = _mm_sub_pd(*ra, *rb); + __m128d idiff = _mm_sub_pd(*ia, *ib); + *ra = _mm_add_pd(*ra, *rb); + *ia = _mm_add_pd(*ia, *ib); + + *rb = _mm_mul_pd(idiff, omim); + *rb = _mm_fmsub_pd(rdiff, omre, *rb); + + *ib = _mm_mul_pd(rdiff, omim); + *ib = _mm_fmadd_pd(idiff, omre, *ib); +} + +EXPORT void reim_ifft4_avx_fma(double* dre, double* dim, const void* ompv) { + const double* omp = (const double*)ompv; + + __m128d ra0 = _mm_loadu_pd(dre); + __m128d ra2 = _mm_loadu_pd(dre + 2); + __m128d ia0 = _mm_loadu_pd(dim); + __m128d ia2 = _mm_loadu_pd(dim + 2); + + // 1 + { + const __m128d ifft4neg = _mm_castsi128_pd(_mm_set_epi64x(UINT64_C(1) << 63, 0)); + __m128d omre = _mm_loadu_pd(omp); // omre: r,i + __m128d omim = _mm_xor_pd(_mm_permute_pd(omre, 1), ifft4neg); // omim: i,-r + + __m128d ra = _mm_unpacklo_pd(ra0, ra2); // (r0, r1), (r2, r3) -> (r0, r2) + __m128d ia = _mm_unpacklo_pd(ia0, ia2); // (i0, i1), (i2, i3) -> (i0, i2) + __m128d rb = _mm_unpackhi_pd(ra0, ra2); // (r0, r1), (r2, r3) -> (r1, r3) + __m128d ib = _mm_unpackhi_pd(ia0, ia2); // (i0, i1), (i2, i3) -> (i1, i3) + + reim_invctwiddle_avx_fma(&ra, &rb, &ia, &ib, omre, omim); + + ra0 = _mm_unpacklo_pd(ra, rb); + ia0 = 
_mm_unpacklo_pd(ia, ib); + ra2 = _mm_unpackhi_pd(ra, rb); + ia2 = _mm_unpackhi_pd(ia, ib); + } + + // 2 + { + __m128d om = _mm_loadu_pd(omp + 2); + __m128d omre = _mm_permute_pd(om, 0); + __m128d omim = _mm_permute_pd(om, 3); + + reim_invctwiddle_avx_fma(&ra0, &ra2, &ia0, &ia2, omre, omim); + } + + // 4 + _mm_storeu_pd(dre, ra0); + _mm_storeu_pd(dre + 2, ra2); + _mm_storeu_pd(dim, ia0); + _mm_storeu_pd(dim + 2, ia2); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_ifft8_avx_fma.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_ifft8_avx_fma.c new file mode 100644 index 0000000..b85fdda --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_ifft8_avx_fma.c @@ -0,0 +1,86 @@ +#include +#include +#include + +#include "reim_fft_private.h" + +__always_inline void reim_invctwiddle_avx_fma(__m256d* ra, __m256d* rb, __m256d* ia, __m256d* ib, const __m256d omre, + const __m256d omim) { + __m256d rdiff = _mm256_sub_pd(*ra, *rb); + __m256d idiff = _mm256_sub_pd(*ia, *ib); + *ra = _mm256_add_pd(*ra, *rb); + *ia = _mm256_add_pd(*ia, *ib); + + *rb = _mm256_mul_pd(idiff, omim); + *rb = _mm256_fmsub_pd(rdiff, omre, *rb); + + *ib = _mm256_mul_pd(rdiff, omim); + *ib = _mm256_fmadd_pd(idiff, omre, *ib); +} + +EXPORT void reim_ifft8_avx_fma(double* dre, double* dim, const void* ompv) { + const double* omp = (const double*)ompv; + + __m256d ra0 = _mm256_loadu_pd(dre); + __m256d ra4 = _mm256_loadu_pd(dre + 4); + __m256d ia0 = _mm256_loadu_pd(dim); + __m256d ia4 = _mm256_loadu_pd(dim + 4); + + // 1 + { + const __m256d fft8neg2 = _mm256_castsi256_pd(_mm256_set_epi64x(UINT64_C(1) << 63, UINT64_C(1) << 63, 0, 0)); + __m256d omr = _mm256_loadu_pd(omp); // r0,r1,i0,i1 + __m256d omiirr = _mm256_permute2f128_pd(omr, omr, 1); // i0,i1,r0,r1 + __m256d omi = _mm256_xor_pd(omiirr, fft8neg2); // i0,i1,-r0,-r1 + + __m256d rb = _mm256_unpackhi_pd(ra0, ra4); + __m256d ib = _mm256_unpackhi_pd(ia0, ia4); + __m256d ra = _mm256_unpacklo_pd(ra0, ra4); + __m256d ia = _mm256_unpacklo_pd(ia0, ia4); + + reim_invctwiddle_avx_fma(&ra, &rb, &ia, &ib, omr, omi); + + ra4 = _mm256_unpackhi_pd(ra, rb); + ia4 = _mm256_unpackhi_pd(ia, ib); + ra0 = _mm256_unpacklo_pd(ra, rb); + ia0 = _mm256_unpacklo_pd(ia, ib); + } + + // 2 + { + const __m128d ifft8neg = _mm_castsi128_pd(_mm_set_epi64x(0, UINT64_C(1) << 63)); + __m128d omri = _mm_loadu_pd(omp + 4); // r,i + __m128d ommri = _mm_xor_pd(omri, ifft8neg); // -r,i + __m256d omrimri = _mm256_set_m128d(ommri, omri); // r,i,-r,i + __m256d omi = _mm256_permute_pd(omrimri, 3); // i,i,-r,-r + __m256d omr = _mm256_permute_pd(omrimri, 12); // r,r,i,i + + __m256d rb = _mm256_permute2f128_pd(ra0, ra4, 0x31); + __m256d ib = _mm256_permute2f128_pd(ia0, ia4, 0x31); + __m256d ra = _mm256_permute2f128_pd(ra0, ra4, 0x20); + __m256d ia = _mm256_permute2f128_pd(ia0, ia4, 0x20); + + reim_invctwiddle_avx_fma(&ra, &rb, &ia, &ib, omr, omi); + + ra0 = _mm256_permute2f128_pd(ra, rb, 0x20); + ra4 = _mm256_permute2f128_pd(ra, rb, 0x31); + ia0 = _mm256_permute2f128_pd(ia, ib, 0x20); + ia4 = _mm256_permute2f128_pd(ia, ib, 0x31); + } + + // 3 + { + __m128d omri = _mm_loadu_pd(omp + 6); // r,i + __m256d omriri = _mm256_set_m128d(omri, omri); // r,i,r,i + __m256d omi = _mm256_permute_pd(omriri, 15); // i,i,i,i + __m256d omr = _mm256_permute_pd(omriri, 0); // r,r,r,r + + reim_invctwiddle_avx_fma(&ra0, &ra4, &ia0, &ia4, omr, omi); + } + + // 4 + _mm256_storeu_pd(dre, ra0); + _mm256_storeu_pd(dre + 4, ra4); + _mm256_storeu_pd(dim, ia0); + 
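  // (editorial note: the three numbered blocks above mirror the three butterfly passes of
  //  reim_ifft8_ref and consume the 4+2+2 omega doubles laid out by fill_reim_ifft8_omegas)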
_mm256_storeu_pd(dim + 4, ia4); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_ifft_avx2.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_ifft_avx2.c new file mode 100644 index 0000000..3a0b0f9 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_ifft_avx2.c @@ -0,0 +1,167 @@ +#include "immintrin.h" +#include "reim_fft_internal.h" +#include "reim_fft_private.h" + +void reim_invtwiddle_ifft_ref(uint64_t h, double* re, double* im, double om[2]); +void reim_invbitwiddle_ifft_ref(uint64_t h, double* re, double* im, double om[4]); + +__always_inline void reim_invtwiddle_ifft_avx2_fma(uint32_t h, double* re, double* im, double om[2]) { + const __m128d omx = _mm_load_pd(om); + const __m256d omra = _mm256_set_m128d(omx, omx); + const __m256d omi = _mm256_unpackhi_pd(omra, omra); + const __m256d omr = _mm256_unpacklo_pd(omra, omra); + double* r0 = re; + double* r1 = re + h; + double* i0 = im; + double* i1 = im + h; + for (uint32_t i = 0; i < h; i += 4) { + __m256d ur0 = _mm256_loadu_pd(r0 + i); + __m256d ur1 = _mm256_loadu_pd(r1 + i); + __m256d ui0 = _mm256_loadu_pd(i0 + i); + __m256d ui1 = _mm256_loadu_pd(i1 + i); + __m256d tra = _mm256_sub_pd(ur0, ur1); + __m256d tia = _mm256_sub_pd(ui0, ui1); + ur0 = _mm256_add_pd(ur0, ur1); + ui0 = _mm256_add_pd(ui0, ui1); + ur1 = _mm256_mul_pd(omi, tia); + ui1 = _mm256_mul_pd(omi, tra); + ur1 = _mm256_fmsub_pd(omr, tra, ur1); + ui1 = _mm256_fmadd_pd(omr, tia, ui1); + _mm256_storeu_pd(r0 + i, ur0); + _mm256_storeu_pd(r1 + i, ur1); + _mm256_storeu_pd(i0 + i, ui0); + _mm256_storeu_pd(i1 + i, ui1); + } +} + +__always_inline void reim_invbitwiddle_ifft_avx2_fma(uint32_t h, double* re, double* im, double om[4]) { + // reim_invbitwiddle_ifft_ref(h,re,im, om); + double* const r0 = re; + double* const r1 = re + h; + double* const r2 = re + 2 * h; + double* const r3 = re + 3 * h; + double* const i0 = im; + double* const i1 = im + h; + double* const i2 = im + 2 * h; + double* const i3 = im + 3 * h; + const __m256d om0 = _mm256_loadu_pd(om); + const __m256d omb = _mm256_permute2f128_pd(om0, om0, 0x11); + const __m256d oma = _mm256_permute2f128_pd(om0, om0, 0x00); + const __m256d omai = _mm256_unpackhi_pd(oma, oma); + const __m256d omar = _mm256_unpacklo_pd(oma, oma); + const __m256d ombi = _mm256_unpackhi_pd(omb, omb); + const __m256d ombr = _mm256_unpacklo_pd(omb, omb); + for (uint32_t i = 0; i < h; i += 4) { + __m256d ur0 = _mm256_loadu_pd(r0 + i); + __m256d ur1 = _mm256_loadu_pd(r1 + i); + __m256d ur2 = _mm256_loadu_pd(r2 + i); + __m256d ur3 = _mm256_loadu_pd(r3 + i); + __m256d ui0 = _mm256_loadu_pd(i0 + i); + __m256d ui1 = _mm256_loadu_pd(i1 + i); + __m256d ui2 = _mm256_loadu_pd(i2 + i); + __m256d ui3 = _mm256_loadu_pd(i3 + i); + __m256d tra, trb, tia, tib; + //------ twiddles 2 + tra = _mm256_sub_pd(ur0, ur1); + trb = _mm256_sub_pd(ur2, ur3); + tia = _mm256_sub_pd(ui0, ui1); + tib = _mm256_sub_pd(ui2, ui3); + ur0 = _mm256_add_pd(ur0, ur1); + ur2 = _mm256_add_pd(ur2, ur3); + ui0 = _mm256_add_pd(ui0, ui1); + ui2 = _mm256_add_pd(ui2, ui3); + ur1 = _mm256_mul_pd(omai, tia); + ur3 = _mm256_mul_pd(omar, tib); + ui1 = _mm256_mul_pd(omai, tra); + ui3 = _mm256_mul_pd(omar, trb); + ur1 = _mm256_fmsub_pd(omar, tra, ur1); + ur3 = _mm256_fmadd_pd(omai, trb, ur3); + ui1 = _mm256_fmadd_pd(omar, tia, ui1); + ui3 = _mm256_fmsub_pd(omai, tib, ui3); + //------ twiddles 1 + tra = _mm256_sub_pd(ur0, ur2); + trb = _mm256_sub_pd(ur1, ur3); + tia = _mm256_sub_pd(ui0, ui2); + tib = 
_mm256_sub_pd(ui1, ui3); + ur0 = _mm256_add_pd(ur0, ur2); + ur1 = _mm256_add_pd(ur1, ur3); + ui0 = _mm256_add_pd(ui0, ui2); + ui1 = _mm256_add_pd(ui1, ui3); + ur2 = _mm256_mul_pd(ombi, tia); + ur3 = _mm256_mul_pd(ombi, tib); + ui2 = _mm256_mul_pd(ombi, tra); + ui3 = _mm256_mul_pd(ombi, trb); + ur2 = _mm256_fmsub_pd(ombr, tra, ur2); + ur3 = _mm256_fmsub_pd(ombr, trb, ur3); + ui2 = _mm256_fmadd_pd(ombr, tia, ui2); + ui3 = _mm256_fmadd_pd(ombr, tib, ui3); + ///--- + _mm256_storeu_pd(r0 + i, ur0); + _mm256_storeu_pd(r1 + i, ur1); + _mm256_storeu_pd(r2 + i, ur2); + _mm256_storeu_pd(r3 + i, ur3); + _mm256_storeu_pd(i0 + i, ui0); + _mm256_storeu_pd(i1 + i, ui1); + _mm256_storeu_pd(i2 + i, ui2); + _mm256_storeu_pd(i3 + i, ui3); + } +} + +void reim_ifft_bfs_16_avx2_fma(uint32_t m, double* re, double* im, double** omg) { + uint32_t log2m = _mm_popcnt_u32(m - 1); // log2(m); + for (uint32_t off = 0; off < m; off += 16) { + reim_ifft16_avx_fma(re + off, im + off, *omg); + *omg += 16; + } + uint32_t h = 16; + uint32_t ms2 = m >> 1; // m/2 + while (h < ms2) { + uint32_t mm = h << 2; + for (uint32_t off = 0; off < m; off += mm) { + reim_invbitwiddle_ifft_avx2_fma(h, re + off, im + off, *omg); + *omg += 4; + } + h = mm; + } + if (log2m & 1) { + // if (h!=ms2) abort(); // bug + // do the first twiddle iteration normally + reim_invtwiddle_ifft_avx2_fma(h, re, im, *omg); + *omg += 2; + // h = m; + } +} + +void reim_ifft_rec_16_avx2_fma(uint32_t m, double* re, double* im, double** omg) { + if (m <= 2048) return reim_ifft_bfs_16_avx2_fma(m, re, im, omg); + const uint32_t h = m >> 1; // m / 2; + reim_ifft_rec_16_avx2_fma(h, re, im, omg); + reim_ifft_rec_16_avx2_fma(h, re + h, im + h, omg); + reim_invtwiddle_ifft_avx2_fma(h, re, im, *omg); + *omg += 2; +} + +void reim_ifft_avx2_fma(const REIM_IFFT_PRECOMP* precomp, double* dat) { + const int32_t m = precomp->m; + double* omg = precomp->powomegas; + double* re = dat; + double* im = dat + m; + if (m <= 16) { + switch (m) { + case 1: + return; + case 2: + return reim_ifft2_ref(re, im, omg); + case 4: + return reim_ifft4_avx_fma(re, im, omg); + case 8: + return reim_ifft8_avx_fma(re, im, omg); + case 16: + return reim_ifft16_avx_fma(re, im, omg); + default: + abort(); // m is not a power of 2 + } + } + if (m <= 2048) return reim_ifft_bfs_16_avx2_fma(m, re, im, &omg); + return reim_ifft_rec_16_avx2_fma(m, re, im, &omg); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_ifft_ref.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_ifft_ref.c new file mode 100644 index 0000000..eea071b --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_ifft_ref.c @@ -0,0 +1,408 @@ +#include "../commons_private.h" +#include "reim_fft_internal.h" +#include "reim_fft_private.h" + +void reim_invctwiddle(double* ra, double* ia, double* rb, double* ib, double omre, double omim) { + double rdiff = *ra - *rb; + double idiff = *ia - *ib; + *ra = *ra + *rb; + *ia = *ia + *ib; + *rb = rdiff * omre - idiff * omim; + *ib = rdiff * omim + idiff * omre; +} + +// i (omre + i omim) = -omim + i omre +void reim_invcitwiddle(double* ra, double* ia, double* rb, double* ib, double omre, double omim) { + double rdiff = *ra - *rb; + double idiff = *ia - *ib; + *ra = *ra + *rb; + *ia = *ia + *ib; + *rb = rdiff * omim + idiff * omre; + *ib = -rdiff * omre + idiff * omim; +} + +void reim_ifft16_ref(double* dre, double* dim, const void* pom) { + const double* om = (const double*)pom; + { + double omare = om[0]; + double 
ombre = om[1]; + double omcre = om[2]; + double omdre = om[3]; + double omaim = om[4]; + double ombim = om[5]; + double omcim = om[6]; + double omdim = om[7]; + reim_invctwiddle(&dre[0], &dim[0], &dre[1], &dim[1], omare, omaim); + reim_invcitwiddle(&dre[2], &dim[2], &dre[3], &dim[3], omare, omaim); + reim_invctwiddle(&dre[4], &dim[4], &dre[5], &dim[5], ombre, ombim); + reim_invcitwiddle(&dre[6], &dim[6], &dre[7], &dim[7], ombre, ombim); + reim_invctwiddle(&dre[8], &dim[8], &dre[9], &dim[9], omcre, omcim); + reim_invcitwiddle(&dre[10], &dim[10], &dre[11], &dim[11], omcre, omcim); + reim_invctwiddle(&dre[12], &dim[12], &dre[13], &dim[13], omdre, omdim); + reim_invcitwiddle(&dre[14], &dim[14], &dre[15], &dim[15], omdre, omdim); + } + { + double omare = om[8]; + double omaim = om[9]; + double ombre = om[10]; + double ombim = om[11]; + reim_invctwiddle(&dre[0], &dim[0], &dre[2], &dim[2], omare, omaim); + reim_invctwiddle(&dre[1], &dim[1], &dre[3], &dim[3], omare, omaim); + reim_invcitwiddle(&dre[4], &dim[4], &dre[6], &dim[6], omare, omaim); + reim_invcitwiddle(&dre[5], &dim[5], &dre[7], &dim[7], omare, omaim); + reim_invctwiddle(&dre[8], &dim[8], &dre[10], &dim[10], ombre, ombim); + reim_invctwiddle(&dre[9], &dim[9], &dre[11], &dim[11], ombre, ombim); + reim_invcitwiddle(&dre[12], &dim[12], &dre[14], &dim[14], ombre, ombim); + reim_invcitwiddle(&dre[13], &dim[13], &dre[15], &dim[15], ombre, ombim); + } + { + double omre = om[12]; + double omim = om[13]; + reim_invctwiddle(&dre[0], &dim[0], &dre[4], &dim[4], omre, omim); + reim_invctwiddle(&dre[1], &dim[1], &dre[5], &dim[5], omre, omim); + reim_invctwiddle(&dre[2], &dim[2], &dre[6], &dim[6], omre, omim); + reim_invctwiddle(&dre[3], &dim[3], &dre[7], &dim[7], omre, omim); + reim_invcitwiddle(&dre[8], &dim[8], &dre[12], &dim[12], omre, omim); + reim_invcitwiddle(&dre[9], &dim[9], &dre[13], &dim[13], omre, omim); + reim_invcitwiddle(&dre[10], &dim[10], &dre[14], &dim[14], omre, omim); + reim_invcitwiddle(&dre[11], &dim[11], &dre[15], &dim[15], omre, omim); + } + { + double omre = om[14]; + double omim = om[15]; + reim_invctwiddle(&dre[0], &dim[0], &dre[8], &dim[8], omre, omim); + reim_invctwiddle(&dre[1], &dim[1], &dre[9], &dim[9], omre, omim); + reim_invctwiddle(&dre[2], &dim[2], &dre[10], &dim[10], omre, omim); + reim_invctwiddle(&dre[3], &dim[3], &dre[11], &dim[11], omre, omim); + reim_invctwiddle(&dre[4], &dim[4], &dre[12], &dim[12], omre, omim); + reim_invctwiddle(&dre[5], &dim[5], &dre[13], &dim[13], omre, omim); + reim_invctwiddle(&dre[6], &dim[6], &dre[14], &dim[14], omre, omim); + reim_invctwiddle(&dre[7], &dim[7], &dre[15], &dim[15], omre, omim); + } +} + +void fill_reim_ifft16_omegas(const double entry_pwr, double** omg) { + const double j_pow = 1. / 8.; + const double k_pow = 1. / 16.; + const double pin = entry_pwr / 2.; + const double pin_2 = entry_pwr / 4.; + const double pin_4 = entry_pwr / 8.; + const double pin_8 = entry_pwr / 16.; + // ((8,9,10,11),(12,13,14,15)) are 4 reals then 4 imag of om^1/8*(1,k,j,kj) + (*omg)[0] = cos(2. * M_PI * (pin_8)); + (*omg)[1] = cos(2. * M_PI * (pin_8 + j_pow)); + (*omg)[2] = cos(2. * M_PI * (pin_8 + k_pow)); + (*omg)[3] = cos(2. * M_PI * (pin_8 + j_pow + k_pow)); + (*omg)[4] = -sin(2. * M_PI * (pin_8)); + (*omg)[5] = -sin(2. * M_PI * (pin_8 + j_pow)); + (*omg)[6] = -sin(2. * M_PI * (pin_8 + k_pow)); + (*omg)[7] = -sin(2. * M_PI * (pin_8 + j_pow + k_pow)); + // (4,5) and (6,7) are real and imag of om^1/4 and j.om^1/4 + (*omg)[8] = cos(2. * M_PI * (pin_4)); + (*omg)[9] = -sin(2. 
* M_PI * (pin_4)); + (*omg)[10] = cos(2. * M_PI * (pin_4 + j_pow)); + (*omg)[11] = -sin(2. * M_PI * (pin_4 + j_pow)); + // 2 and 3 are real and imag of om^1/2 + (*omg)[12] = cos(2. * M_PI * (pin_2)); + (*omg)[13] = -sin(2. * M_PI * (pin_2)); + // 0 and 1 are real and imag of om + (*omg)[14] = cos(2. * M_PI * pin); + (*omg)[15] = -sin(2. * M_PI * pin); + *omg += 16; +} + +void reim_ifft8_ref(double* dre, double* dim, const void* pom) { + const double* om = (const double*)pom; + { + double omare = om[0]; + double ombre = om[1]; + double omaim = om[2]; + double ombim = om[3]; + reim_invctwiddle(&dre[0], &dim[0], &dre[1], &dim[1], omare, omaim); + reim_invcitwiddle(&dre[2], &dim[2], &dre[3], &dim[3], omare, omaim); + reim_invctwiddle(&dre[4], &dim[4], &dre[5], &dim[5], ombre, ombim); + reim_invcitwiddle(&dre[6], &dim[6], &dre[7], &dim[7], ombre, ombim); + } + { + double omare = om[4]; + double omaim = om[5]; + reim_invctwiddle(&dre[0], &dim[0], &dre[2], &dim[2], omare, omaim); + reim_invctwiddle(&dre[1], &dim[1], &dre[3], &dim[3], omare, omaim); + reim_invcitwiddle(&dre[4], &dim[4], &dre[6], &dim[6], omare, omaim); + reim_invcitwiddle(&dre[5], &dim[5], &dre[7], &dim[7], omare, omaim); + } + { + double omre = om[6]; + double omim = om[7]; + reim_invctwiddle(&dre[0], &dim[0], &dre[4], &dim[4], omre, omim); + reim_invctwiddle(&dre[1], &dim[1], &dre[5], &dim[5], omre, omim); + reim_invctwiddle(&dre[2], &dim[2], &dre[6], &dim[6], omre, omim); + reim_invctwiddle(&dre[3], &dim[3], &dre[7], &dim[7], omre, omim); + } +} + +void fill_reim_ifft8_omegas(const double entry_pwr, double** omg) { + const double j_pow = 1. / 8.; + const double pin = entry_pwr / 2.; + const double pin_2 = entry_pwr / 4.; + const double pin_4 = entry_pwr / 8.; + // (4,5) and (6,7) are real and imag of om^1/4 and j.om^1/4 + (*omg)[0] = cos(2. * M_PI * (pin_4)); + (*omg)[1] = cos(2. * M_PI * (pin_4 + j_pow)); + (*omg)[2] = -sin(2. * M_PI * (pin_4)); + (*omg)[3] = -sin(2. * M_PI * (pin_4 + j_pow)); + // 2 and 3 are real and imag of om^1/2 + (*omg)[4] = cos(2. * M_PI * (pin_2)); + (*omg)[5] = -sin(2. * M_PI * (pin_2)); + // 0 and 1 are real and imag of om + (*omg)[6] = cos(2. * M_PI * pin); + (*omg)[7] = -sin(2. * M_PI * pin); + *omg += 8; +} + +void reim_ifft4_ref(double* dre, double* dim, const void* pom) { + const double* om = (const double*)pom; + { + double omare = om[0]; + double omaim = om[1]; + reim_invctwiddle(&dre[0], &dim[0], &dre[1], &dim[1], omare, omaim); + reim_invcitwiddle(&dre[2], &dim[2], &dre[3], &dim[3], omare, omaim); + } + { + double omare = om[2]; + double omaim = om[3]; + reim_invctwiddle(&dre[0], &dim[0], &dre[2], &dim[2], omare, omaim); + reim_invctwiddle(&dre[1], &dim[1], &dre[3], &dim[3], omare, omaim); + } +} + +void fill_reim_ifft4_omegas(const double entry_pwr, double** omg) { + const double pin = entry_pwr / 2.; + const double pin_2 = entry_pwr / 4.; + // 2 and 3 are real and imag of om^1/2 + (*omg)[0] = cos(2. * M_PI * (pin_2)); + (*omg)[1] = -sin(2. * M_PI * (pin_2)); + // 0 and 1 are real and imag of om + (*omg)[2] = cos(2. * M_PI * pin); + (*omg)[3] = -sin(2. * M_PI * pin); + *omg += 4; +} + +void reim_ifft2_ref(double* dre, double* dim, const void* pom) { + const double* om = (const double*)pom; + { + double omare = om[0]; + double omaim = om[1]; + reim_invctwiddle(&dre[0], &dim[0], &dre[1], &dim[1], omare, omaim); + } +} + +void fill_reim_ifft2_omegas(const double entry_pwr, double** omg) { + const double pin = entry_pwr / 2.; + // 0 and 1 are real and imag of om + (*omg)[0] = cos(2. 
* M_PI * pin); + (*omg)[1] = -sin(2. * M_PI * pin); + *omg += 2; +} + +void reim_invtwiddle_ifft_ref(uint64_t h, double* re, double* im, double om[2]) { + for (uint64_t i = 0; i < h; ++i) { + reim_invctwiddle(&re[i], &im[i], &re[h + i], &im[h + i], om[0], om[1]); + } +} + +void reim_invbitwiddle_ifft_ref(uint64_t h, double* re, double* im, double om[4]) { + double* r0 = re; + double* r1 = re + h; + double* r2 = re + 2 * h; + double* r3 = re + 3 * h; + double* i0 = im; + double* i1 = im + h; + double* i2 = im + 2 * h; + double* i3 = im + 3 * h; + for (uint64_t i = 0; i < h; ++i) { + reim_invctwiddle(&r0[i], &i0[i], &r1[i], &i1[i], om[0], om[1]); + reim_invcitwiddle(&r2[i], &i2[i], &r3[i], &i3[i], om[0], om[1]); + } + for (uint64_t i = 0; i < h; ++i) { + reim_invctwiddle(&r0[i], &i0[i], &r2[i], &i2[i], om[2], om[3]); + reim_invctwiddle(&r1[i], &i1[i], &r3[i], &i3[i], om[2], om[3]); + } +} + +void reim_ifft_bfs_16_ref(uint64_t m, double* re, double* im, double** omg) { + uint64_t log2m = log2(m); + for (uint64_t off = 0; off < m; off += 16) { + reim_ifft16_ref(re + off, im + off, *omg); + *omg += 16; + } + uint64_t h = 16; + uint64_t ms2 = m / 2; + while (h < ms2) { + uint64_t mm = h << 2; + for (uint64_t off = 0; off < m; off += mm) { + reim_invbitwiddle_ifft_ref(h, re + off, im + off, *omg); + *omg += 4; + } + h = mm; + } + if (log2m % 2 != 0) { + if (h != ms2) abort(); // bug + // do the first twiddle iteration normally + reim_invtwiddle_ifft_ref(h, re, im, *omg); + *omg += 2; + h = m; + } +} + +void fill_reim_ifft_bfs_16_omegas(uint64_t m, double entry_pwr, double** omg) { + uint64_t log2m = log2(m); + // uint64_t mm = 16; + double ss = entry_pwr * 16. / m; + for (uint64_t off = 0; off < m; off += 16) { + double s = ss + fracrevbits(off / 16); + fill_reim_ifft16_omegas(s, omg); + } + uint64_t h = 16; + uint64_t ms2 = m / 2; + while (h < ms2) { + uint64_t mm = h << 2; + for (uint64_t off = 0; off < m; off += mm) { + double rs0 = ss + fracrevbits(off / mm) / 4.; + double rs1 = 2. 
* rs0; + (*omg)[0] = cos(2 * M_PI * rs0); + (*omg)[1] = -sin(2 * M_PI * rs0); + (*omg)[2] = cos(2 * M_PI * rs1); + (*omg)[3] = -sin(2 * M_PI * rs1); + *omg += 4; + } + ss *= 4.; + h = mm; + } + if (log2m % 2 != 0) { + if (h != ms2) abort(); // bug + // do the first twiddle iteration normally + (*omg)[0] = cos(2 * M_PI * ss); + (*omg)[1] = -sin(2 * M_PI * ss); + *omg += 2; + h = m; + ss *= 2.; + } + if (ss != entry_pwr) abort(); +} + +void reim_ifft_rec_16_ref(uint64_t m, double* re, double* im, double** omg) { + if (m <= 2048) return reim_ifft_bfs_16_ref(m, re, im, omg); + const uint32_t h = m / 2; + reim_ifft_rec_16_ref(h, re, im, omg); + reim_ifft_rec_16_ref(h, re + h, im + h, omg); + reim_invtwiddle_ifft_ref(h, re, im, *omg); + *omg += 2; +} + +void fill_reim_ifft_rec_16_omegas(uint64_t m, double entry_pwr, double** omg) { + if (m <= 2048) return fill_reim_ifft_bfs_16_omegas(m, entry_pwr, omg); + const uint64_t h = m / 2; + const double s = entry_pwr / 2; + fill_reim_ifft_rec_16_omegas(h, s, omg); + fill_reim_ifft_rec_16_omegas(h, s + 0.5, omg); + (*omg)[0] = cos(2 * M_PI * s); + (*omg)[1] = -sin(2 * M_PI * s); + *omg += 2; +} + +void reim_ifft_ref(const REIM_IFFT_PRECOMP* precomp, double* dat) { + const int32_t m = precomp->m; + double* omg = precomp->powomegas; + double* re = dat; + double* im = dat + m; + if (m <= 16) { + switch (m) { + case 1: + return; + case 2: + return reim_ifft2_ref(re, im, omg); + case 4: + return reim_ifft4_ref(re, im, omg); + case 8: + return reim_ifft8_ref(re, im, omg); + case 16: + return reim_ifft16_ref(re, im, omg); + default: + abort(); // m is not a power of 2 + } + } + if (m <= 2048) return reim_ifft_bfs_16_ref(m, re, im, &omg); + return reim_ifft_rec_16_ref(m, re, im, &omg); +} + +EXPORT REIM_IFFT_PRECOMP* new_reim_ifft_precomp(uint32_t m, uint32_t num_buffers) { + const uint64_t OMG_SPACE = ceilto64b(2 * m * sizeof(double)); + const uint64_t BUF_SIZE = ceilto64b(2 * m * sizeof(double)); + void* reps = malloc(sizeof(REIM_IFFT_PRECOMP) // base + + 63 // padding + + OMG_SPACE // tables //TODO 16? 
+ + num_buffers * BUF_SIZE // buffers + ); + uint64_t aligned_addr = ceilto64b((uint64_t)(reps) + sizeof(REIM_IFFT_PRECOMP)); + REIM_IFFT_PRECOMP* r = (REIM_IFFT_PRECOMP*)reps; + r->m = m; + r->buf_size = BUF_SIZE; + r->powomegas = (double*)aligned_addr; + r->aligned_buffers = (void*)(aligned_addr + OMG_SPACE); + // fill in powomegas + double* omg = (double*)r->powomegas; + if (m <= 16) { + switch (m) { + case 1: + break; + case 2: + fill_reim_ifft2_omegas(0.25, &omg); + break; + case 4: + fill_reim_ifft4_omegas(0.25, &omg); + break; + case 8: + fill_reim_ifft8_omegas(0.25, &omg); + break; + case 16: + fill_reim_ifft16_omegas(0.25, &omg); + break; + default: + abort(); // m is not a power of 2 + } + } else if (m <= 2048) { + fill_reim_ifft_bfs_16_omegas(m, 0.25, &omg); + } else { + fill_reim_ifft_rec_16_omegas(m, 0.25, &omg); + } + if (((uint64_t)omg) - aligned_addr > OMG_SPACE) abort(); + // dispatch the right implementation + { + if (CPU_SUPPORTS("fma")) { + r->function = reim_ifft_avx2_fma; + } else { + r->function = reim_ifft_ref; + } + } + return reps; +} + +void reim_naive_ifft(uint64_t m, double entry_pwr, double* re, double* im) { + if (m == 1) return; + // twiddle + const uint64_t h = m / 2; + const double s = entry_pwr / 2.; + reim_naive_ifft(h, s, re, im); + reim_naive_ifft(h, s + 0.5, re + h, im + h); + const double sre = cos(2 * M_PI * s); + const double sim = -sin(2 * M_PI * s); + for (uint64_t j = 0; j < h; ++j) { + double rdiff = re[j] - re[h + j]; + double idiff = im[j] - im[h + j]; + re[j] = re[j] + re[h + j]; + im[j] = im[j] + im[h + j]; + re[h + j] = rdiff * sre - idiff * sim; + im[h + j] = idiff * sre + rdiff * sim; + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_to_tnx_avx.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_to_tnx_avx.c new file mode 100644 index 0000000..a77137a --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_to_tnx_avx.c @@ -0,0 +1,35 @@ +#include +#include +#include + +#include "../commons_private.h" +#include "reim_fft_internal.h" +#include "reim_fft_private.h" + +typedef union { + double d; + uint64_t u; +} dblui64_t; + +EXPORT void reim_to_tnx_avx(const REIM_TO_TNX_PRECOMP* tables, double* r, const double* x) { + const uint64_t n = tables->m << 1; + const __m256d add_cst = _mm256_set1_pd(tables->add_cst); + const __m256d mask_and = _mm256_castsi256_pd(_mm256_set1_epi64x(tables->mask_and)); + const __m256d mask_or = _mm256_castsi256_pd(_mm256_set1_epi64x(tables->mask_or)); + const __m256d sub_cst = _mm256_set1_pd(tables->sub_cst); + __m256d reg0, reg1; + for (uint64_t i = 0; i < n; i += 8) { + reg0 = _mm256_loadu_pd(x + i); + reg1 = _mm256_loadu_pd(x + i + 4); + reg0 = _mm256_add_pd(reg0, add_cst); + reg1 = _mm256_add_pd(reg1, add_cst); + reg0 = _mm256_and_pd(reg0, mask_and); + reg1 = _mm256_and_pd(reg1, mask_and); + reg0 = _mm256_or_pd(reg0, mask_or); + reg1 = _mm256_or_pd(reg1, mask_or); + reg0 = _mm256_sub_pd(reg0, sub_cst); + reg1 = _mm256_sub_pd(reg1, sub_cst); + _mm256_storeu_pd(r + i, reg0); + _mm256_storeu_pd(r + i + 4, reg1); + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_to_tnx_ref.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_to_tnx_ref.c new file mode 100644 index 0000000..3ceeff9 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim/reim_to_tnx_ref.c @@ -0,0 +1,75 @@ +#include +#include + +#include "../commons_private.h" +#include 
"reim_fft_internal.h" +#include "reim_fft_private.h" + +EXPORT void reim_to_tnx_basic_ref(const REIM_TO_TNX_PRECOMP* tables, double* r, const double* x) { + const uint64_t n = tables->m << 1; + double divisor = tables->divisor; + for (uint64_t i = 0; i < n; ++i) { + double ri = x[i] / divisor; + r[i] = ri - rint(ri); + } +} + +typedef union { + double d; + uint64_t u; +} dblui64_t; + +EXPORT void reim_to_tnx_ref(const REIM_TO_TNX_PRECOMP* tables, double* r, const double* x) { + const uint64_t n = tables->m << 1; + double add_cst = tables->add_cst; + uint64_t mask_and = tables->mask_and; + uint64_t mask_or = tables->mask_or; + double sub_cst = tables->sub_cst; + dblui64_t cur; + for (uint64_t i = 0; i < n; ++i) { + cur.d = x[i] + add_cst; + cur.u &= mask_and; + cur.u |= mask_or; + r[i] = cur.d - sub_cst; + } +} + +void* init_reim_to_tnx_precomp(REIM_TO_TNX_PRECOMP* const res, uint32_t m, double divisor, uint32_t log2overhead) { + if (m & (m - 1)) return spqlios_error("m must be a power of 2"); + if (is_not_pow2_double(&divisor)) return spqlios_error("divisor must be a power of 2"); + if (log2overhead > 52) return spqlios_error("log2overhead is too large"); + res->m = m; + res->divisor = divisor; + res->log2overhead = log2overhead; + // 52 + 11 + 1 + // ......1.......01(1)|expo|sign + // .......=========(1)|expo|sign msbbits = log2ovh + 2 + 11 + 1 + uint64_t nbits = 50 - log2overhead; + dblui64_t ovh_cst; + ovh_cst.d = 0.5 + (6 << log2overhead); + res->add_cst = ovh_cst.d * divisor; + res->mask_and = ((UINT64_C(1) << nbits) - 1); + res->mask_or = ovh_cst.u & ((UINT64_C(-1)) << nbits); + res->sub_cst = ovh_cst.d; + // TODO: check selection logic + if (CPU_SUPPORTS("avx2")) { + if (m >= 8) { + res->function = reim_to_tnx_avx; + } else { + res->function = reim_to_tnx_ref; + } + } else { + res->function = reim_to_tnx_ref; + } + return res; +} + +EXPORT REIM_TO_TNX_PRECOMP* new_reim_to_tnx_precomp(uint32_t m, double divisor, uint32_t log2overhead) { + REIM_TO_TNX_PRECOMP* res = malloc(sizeof(*res)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_reim_to_tnx_precomp(res, m, divisor, log2overhead)); +} + +EXPORT void reim_to_tnx(const REIM_TO_TNX_PRECOMP* tables, double* r, const double* x) { + tables->function(tables, r, x); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_arithmetic.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_arithmetic.h new file mode 100644 index 0000000..3bfb5c8 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_arithmetic.h @@ -0,0 +1,177 @@ +#ifndef SPQLIOS_REIM4_ARITHMETIC_H +#define SPQLIOS_REIM4_ARITHMETIC_H + +#include "../commons.h" + +// the reim4 structure represent 4 complex numbers, +// represented as re0,re1,re2,re3,im0,im1,im2,im3 + +// TODO implement these basic arithmetic functions. (only ref is needed) + +/** @brief dest = 0 */ +EXPORT void reim4_zero(double* dest); +/** @brief dest = a + b */ +EXPORT void reim4_add(double* dest, const double* u, const double* v); +/** @brief dest = a * b */ +EXPORT void reim4_mul(double* dest, const double* u, const double* v); +/** @brief dest += a * b */ +EXPORT void reim4_add_mul(double* dest, const double* a, const double* b); + +// TODO add the dot products: vec x mat1col, vec x mat2cols functions here (ref and avx) + +// TODO implement the convolution functions here (ref and avx) +/** + * @brief k-th coefficient of the convolution product a * b. 
+ * + * The k-th coefficient is defined as sum_{i+j=k} a[i].b[j]. + * (in this sum, i and j must remain within the bounds [0,sizea[ and [0,sizeb[) + * + * In practice, accounting for these bounds, the convolution function boils down to + * ``` + * res := 0 + * if (k +#include + +#include "reim4_arithmetic.h" + +void reim4_extract_1blk_from_reim_avx(uint64_t m, uint64_t blk, + double* const dst, // nrows * 8 doubles + const double* const src // a contiguous array of nrows reim vectors +) { + assert(blk < (m >> 2)); + const double* src_ptr = src + (blk << 2); + double* dst_ptr = dst; + _mm256_storeu_pd(dst_ptr, _mm256_loadu_pd(src_ptr)); + _mm256_storeu_pd(dst_ptr + 4, _mm256_loadu_pd(src_ptr + m)); +} + +void reim4_extract_1blk_from_contiguous_reim_avx(uint64_t m, uint64_t nrows, uint64_t blk, double* const dst, + const double* const src) { + assert(blk < (m >> 2)); + const double* src_ptr = src + (blk << 2); + double* dst_ptr = dst; + for (uint64_t i = 0; i < nrows * 2; ++i) { + _mm256_storeu_pd(dst_ptr, _mm256_loadu_pd(src_ptr)); + dst_ptr += 4; + src_ptr += m; + } +} + +EXPORT void reim4_extract_1blk_from_contiguous_reim_sl_avx(uint64_t m, uint64_t sl, uint64_t nrows, uint64_t blk, + double* const dst, const double* const src) { + assert(blk < (m >> 2)); + const double* src_ptr = src + (blk << 2); + double* dst_ptr = dst; + for (uint64_t i = 0; i < nrows; ++i) { + _mm256_storeu_pd(dst_ptr, _mm256_loadu_pd(src_ptr)); + _mm256_storeu_pd(dst_ptr + 4, _mm256_loadu_pd(src_ptr + m)); + dst_ptr += 8; + src_ptr += sl; + } +} + +void reim4_save_1blk_to_reim_avx(uint64_t m, uint64_t blk, + double* dst, // 1 reim vector of length m + const double* src // 8 doubles +) { + assert(blk < (m >> 2)); + const double* src_ptr = src; + double* dst_ptr = dst + (blk << 2); + _mm256_storeu_pd(dst_ptr, _mm256_loadu_pd(src_ptr)); + _mm256_storeu_pd(dst_ptr + m, _mm256_loadu_pd(src_ptr + 4)); +} + +void reim4_add_1blk_to_reim_avx(uint64_t m, uint64_t blk, + double* dst, // 1 reim vector of length m + const double* src // 8 doubles +) { + assert(blk < (m >> 2)); + const double* src_ptr = src; + double* dst_ptr = dst + (blk << 2); + + __m256d s0 = _mm256_loadu_pd(src_ptr); + __m256d d0 = _mm256_loadu_pd(dst_ptr); + _mm256_storeu_pd(dst_ptr, _mm256_add_pd(s0, d0)); + + __m256d s1 = _mm256_loadu_pd(src_ptr + 4); + __m256d d1 = _mm256_loadu_pd(dst_ptr + m); + _mm256_storeu_pd(dst_ptr + m, _mm256_add_pd(s1, d1)); +} + +__always_inline void cplx_prod(__m256d* re1, __m256d* re2, __m256d* im, const double* const u_ptr, + const double* const v_ptr) { + const __m256d a = _mm256_loadu_pd(u_ptr); + const __m256d c = _mm256_loadu_pd(v_ptr); + *re1 = _mm256_fmadd_pd(a, c, *re1); + + const __m256d b = _mm256_loadu_pd(u_ptr + 4); + *im = _mm256_fmadd_pd(b, c, *im); + + const __m256d d = _mm256_loadu_pd(v_ptr + 4); + *re2 = _mm256_fmadd_pd(b, d, *re2); + *im = _mm256_fmadd_pd(a, d, *im); +} + +void reim4_vec_mat1col_product_avx2(const uint64_t nrows, + double* const dst, // 8 doubles + const double* const u, // nrows * 8 doubles + const double* const v // nrows * 8 doubles +) { + __m256d re1 = _mm256_setzero_pd(); + __m256d re2 = _mm256_setzero_pd(); + __m256d im1 = _mm256_setzero_pd(); + __m256d im2 = _mm256_setzero_pd(); + + const double* u_ptr = u; + const double* v_ptr = v; + for (uint64_t i = 0; i < nrows; ++i) { + const __m256d a = _mm256_loadu_pd(u_ptr); + const __m256d c = _mm256_loadu_pd(v_ptr); + re1 = _mm256_fmadd_pd(a, c, re1); + + const __m256d b = _mm256_loadu_pd(u_ptr + 4); + im2 = _mm256_fmadd_pd(b, c, im2); + + 
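    // accumulator scheme: re1 += ur*vr, re2 += ui*vi, im1 += ur*vi, im2 += ui*vr;
    // the complex dot product is recovered after the loop as Re = re1 - re2 and Im = im1 + im2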
const __m256d d = _mm256_loadu_pd(v_ptr + 4); + re2 = _mm256_fmadd_pd(b, d, re2); + im1 = _mm256_fmadd_pd(a, d, im1); + + u_ptr += 8; + v_ptr += 8; + } + + _mm256_storeu_pd(dst, _mm256_sub_pd(re1, re2)); + _mm256_storeu_pd(dst + 4, _mm256_add_pd(im1, im2)); +} + +EXPORT void reim4_vec_mat2cols_product_avx2(const uint64_t nrows, + double* const dst, // 16 doubles + const double* const u, // nrows * 16 doubles + const double* const v // nrows * 16 doubles +) { + __m256d re1 = _mm256_setzero_pd(); + __m256d im1 = _mm256_setzero_pd(); + __m256d re2 = _mm256_setzero_pd(); + __m256d im2 = _mm256_setzero_pd(); + + __m256d ur, ui, ar, ai, br, bi; + for (uint64_t i = 0; i < nrows; ++i) { + ur = _mm256_loadu_pd(u + 8 * i); + ui = _mm256_loadu_pd(u + 8 * i + 4); + ar = _mm256_loadu_pd(v + 16 * i); + ai = _mm256_loadu_pd(v + 16 * i + 4); + br = _mm256_loadu_pd(v + 16 * i + 8); + bi = _mm256_loadu_pd(v + 16 * i + 12); + re1 = _mm256_fmsub_pd(ui, ai, re1); + re2 = _mm256_fmsub_pd(ui, bi, re2); + im1 = _mm256_fmadd_pd(ur, ai, im1); + im2 = _mm256_fmadd_pd(ur, bi, im2); + re1 = _mm256_fmsub_pd(ur, ar, re1); + re2 = _mm256_fmsub_pd(ur, br, re2); + im1 = _mm256_fmadd_pd(ui, ar, im1); + im2 = _mm256_fmadd_pd(ui, br, im2); + } + _mm256_storeu_pd(dst, re1); + _mm256_storeu_pd(dst + 4, im1); + _mm256_storeu_pd(dst + 8, re2); + _mm256_storeu_pd(dst + 12, im2); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_arithmetic_ref.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_arithmetic_ref.c new file mode 100644 index 0000000..78b65eb --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_arithmetic_ref.c @@ -0,0 +1,254 @@ +#include +#include + +#include "reim4_arithmetic.h" + +// Stores the first 4 values (RE) of src + blk*4 and he first 4 values (IM) of src + blk*4 + m +// contiguously into dst +void reim4_extract_1blk_from_reim_ref(uint64_t m, uint64_t blk, + double* const dst, // 8 doubles + const double* const src // one reim4 vector +) { + assert(blk < (m >> 2)); + const double* src_ptr = src + (blk << 2); + // copy the real parts + dst[0] = src_ptr[0]; + dst[1] = src_ptr[1]; + dst[2] = src_ptr[2]; + dst[3] = src_ptr[3]; + src_ptr += m; + // copy the imaginary parts + dst[4] = src_ptr[0]; + dst[5] = src_ptr[1]; + dst[6] = src_ptr[2]; + dst[7] = src_ptr[3]; +} + +void reim4_extract_reim_from_1blk_ref(uint64_t m, uint64_t blk, + double* const dst, // 8 doubles + const double* const src // one reim4 vector +) { + assert(blk < (m >> 2)); + double* dst_ptr = dst + (blk << 2); + // copy the real parts + dst_ptr[0] = src[0]; + dst_ptr[1] = src[1]; + dst_ptr[2] = src[2]; + dst_ptr[3] = src[3]; + dst_ptr += m; + // copy the imaginary parts + dst_ptr[0] = src[4]; + dst_ptr[1] = src[5]; + dst_ptr[2] = src[6]; + dst_ptr[3] = src[7]; +} + +EXPORT void reim4_extract_1blk_from_contiguous_reim_ref(uint64_t m, uint64_t nrows, uint64_t blk, double* const dst, + const double* const src) { + assert(blk < (m >> 2)); + + const double* src_ptr = src + (blk << 2); + double* dst_ptr = dst; + for (uint64_t i = 0; i < nrows * 2; ++i) { + dst_ptr[0] = src_ptr[0]; + dst_ptr[1] = src_ptr[1]; + dst_ptr[2] = src_ptr[2]; + dst_ptr[3] = src_ptr[3]; + dst_ptr += 4; + src_ptr += m; + } +} + +EXPORT void reim4_extract_1blk_from_contiguous_reim_sl_ref(uint64_t m, uint64_t sl, uint64_t nrows, uint64_t blk, + double* const dst, const double* const src) { + assert(blk < (m >> 2)); + + const double* src_ptr = src + (blk << 2); + double* dst_ptr = dst; + 
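  // each source row is a reim vector with stride sl: its 4 real values sit at offset blk*4 and the
  // 4 matching imaginary values m doubles later, so the loop below advances by m and then by sl - m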
const uint64_t sl_minus_m = sl - m; + for (uint64_t i = 0; i < nrows; ++i) { + dst_ptr[0] = src_ptr[0]; + dst_ptr[1] = src_ptr[1]; + dst_ptr[2] = src_ptr[2]; + dst_ptr[3] = src_ptr[3]; + dst_ptr += 4; + src_ptr += m; + dst_ptr[0] = src_ptr[0]; + dst_ptr[1] = src_ptr[1]; + dst_ptr[2] = src_ptr[2]; + dst_ptr[3] = src_ptr[3]; + dst_ptr += 4; + src_ptr += sl_minus_m; + } +} + +// dest(i)=src +// use scalar or sse or avx? Code +// should be the inverse of reim4_extract_1col_from_reim +void reim4_save_1blk_to_reim_ref(uint64_t m, uint64_t blk, + double* dst, // 1 reim vector of length m + const double* src // 8 doubles +) { + assert(blk < (m >> 2)); + double* dst_ptr = dst + (blk << 2); + // save the real part + dst_ptr[0] = src[0]; + dst_ptr[1] = src[1]; + dst_ptr[2] = src[2]; + dst_ptr[3] = src[3]; + dst_ptr += m; + // save the imag part + dst_ptr[0] = src[4]; + dst_ptr[1] = src[5]; + dst_ptr[2] = src[6]; + dst_ptr[3] = src[7]; +} + +void reim4_add_1blk_to_reim_ref(uint64_t m, uint64_t blk, + double* dst, // 1 reim vector of length m + const double* src // 8 doubles +) { + assert(blk < (m >> 2)); + double* dst_ptr = dst + (blk << 2); + // add the real part + dst_ptr[0] += src[0]; + dst_ptr[1] += src[1]; + dst_ptr[2] += src[2]; + dst_ptr[3] += src[3]; + dst_ptr += m; + // add the imag part + dst_ptr[0] += src[4]; + dst_ptr[1] += src[5]; + dst_ptr[2] += src[6]; + dst_ptr[3] += src[7]; +} + +// dest = 0 +void reim4_zero(double* const dst // 8 doubles +) { + for (uint64_t i = 0; i < 8; ++i) dst[i] = 0; +} + +/** @brief dest = a + b */ +void reim4_add(double* const dst, // 8 doubles + const double* const u, // nrows * 8 doubles + const double* const v // nrows * 8 doubles +) { + for (uint64_t k = 0; k < 4; ++k) { + const double a = u[k]; + const double c = v[k]; + const double b = u[k + 4]; + const double d = v[k + 4]; + + dst[k] = a + c; + dst[k + 4] = b + d; + } +} + +/** @brief dest = a * b */ +void reim4_mul(double* const dst, // 8 doubles + const double* const u, // 8 doubles + const double* const v // 8 doubles +) { + for (uint64_t k = 0; k < 4; ++k) { + const double a = u[k]; + const double c = v[k]; + const double b = u[k + 4]; + const double d = v[k + 4]; + + dst[k] = a * c - b * d; + dst[k + 4] = a * d + b * c; + } +} + +/** @brief dest += a * b */ +void reim4_add_mul(double* const dst, // 8 doubles + const double* const u, // 8 doubles + const double* const v // 8 doubles +) { + for (uint64_t k = 0; k < 4; ++k) { + const double a = u[k]; + const double c = v[k]; + const double b = u[k + 4]; + const double d = v[k + 4]; + + dst[k] += a * c - b * d; + dst[k + 4] += a * d + b * c; + } +} + +/** dest = uT * v where u is a vector of size nrows, and v is a nrows x 1 matrix */ +void reim4_vec_mat1col_product_ref(const uint64_t nrows, + double* const dst, // 8 doubles + const double* const u, // nrows * 8 doubles + const double* const v // nrows * 8 doubles +) { + reim4_zero(dst); + + for (uint64_t i = 0, j = 0; i < nrows; ++i, j += 8) { + reim4_add_mul(dst, u + j, v + j); + } +} + +/** dest = uT * v where u is a vector of size nrows, and v is a nrows x 2 matrix */ +void reim4_vec_mat2cols_product_ref(const uint64_t nrows, + double* const dst, // 16 doubles + const double* const u, // nrows * 8 doubles + const double* const v // nrows * 16 doubles +) { + reim4_zero(dst); + reim4_zero(dst + 8); + + double* dst1 = dst; + double* dst2 = dst + 8; + + for (uint64_t i = 0, j = 0; i < nrows; ++i, j += 8) { + uint64_t double_j = j << 1; + reim4_add_mul(dst1, u + j, v + double_j); + reim4_add_mul(dst2, u 
+ j, v + double_j + 8); + } +} + +/** + * @brief k-th coefficient of the convolution product a * b. + * + * The k-th coefficient is defined as sum_{i+j=k} a[i].b[j]. + * (in this sum, i and j must remain within the bounds [0,sizea[ and [0,sizeb[) + * + * In practice, accounting for these bounds, the convolution function boils down to + * ``` + * res := 0 + * if (k= sizea + sizeb) return; + uint64_t jmin = k >= sizea ? k + 1 - sizea : 0; + uint64_t jmax = k < sizeb ? k + 1 : sizeb; + for (uint64_t j = jmin; j < jmax; ++j) { + reim4_add_mul(dest, a + 8 * (k - j), b + 8 * j); + } +} + +/** @brief returns two consecutive convolution coefficients: k and k+1 */ +EXPORT void reim4_convolution_2coeff_ref(uint64_t k, double* dest, const double* a, uint64_t sizea, const double* b, + uint64_t sizeb) { + reim4_convolution_1coeff_ref(k, dest, a, sizea, b, sizeb); + reim4_convolution_1coeff_ref(k + 1, dest + 8, a, sizea, b, sizeb); +} + +/** + * @brief From the convolution a * b, return the coefficients between offset and offset + size + * For the full convolution, use offset=0 and size=sizea+sizeb-1. + */ +EXPORT void reim4_convolution_ref(double* dest, uint64_t dest_size, uint64_t dest_offset, const double* a, + uint64_t sizea, const double* b, uint64_t sizeb) { + for (uint64_t k = 0; k < dest_size; ++k) { + reim4_convolution_1coeff_ref(k + dest_offset, dest + 8 * k, a, sizea, b, sizeb); + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_execute.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_execute.c new file mode 100644 index 0000000..9618393 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_execute.c @@ -0,0 +1,19 @@ +#include "reim4_fftvec_internal.h" +#include "reim4_fftvec_private.h" + +EXPORT void reim4_fftvec_addmul(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, + const double* b) { + tables->function(tables, r, a, b); +} + +EXPORT void reim4_fftvec_mul(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b) { + tables->function(tables, r, a, b); +} + +EXPORT void reim4_from_cplx(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a) { + tables->function(tables, r, a); +} + +EXPORT void reim4_to_cplx(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a) { + tables->function(tables, r, a); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fallbacks_aarch64.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fallbacks_aarch64.c new file mode 100644 index 0000000..595c031 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fallbacks_aarch64.c @@ -0,0 +1,13 @@ +#include "reim4_fftvec_private.h" + +EXPORT void reim4_fftvec_mul_fma(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, + const double* b){UNDEFINED()} + +EXPORT void reim4_fftvec_addmul_fma(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, + const double* b){UNDEFINED()} + +EXPORT void reim4_from_cplx_fma(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a){UNDEFINED()} + +EXPORT void reim4_to_cplx_fma(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a) { + UNDEFINED() +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fftvec_addmul_fma.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fftvec_addmul_fma.c new file mode 100644 index 0000000..833a95b --- 
/dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fftvec_addmul_fma.c @@ -0,0 +1,54 @@ +#include +#include +#include + +#include "reim4_fftvec_private.h" + +EXPORT void reim4_fftvec_addmul_fma(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r_ptr, const double* a_ptr, + const double* b_ptr) { + const double* const rend_ptr = r_ptr + (tables->m << 1); + while (r_ptr != rend_ptr) { + __m256d rr = _mm256_loadu_pd(r_ptr); + __m256d ri = _mm256_loadu_pd(r_ptr + 4); + const __m256d ar = _mm256_loadu_pd(a_ptr); + const __m256d ai = _mm256_loadu_pd(a_ptr + 4); + const __m256d br = _mm256_loadu_pd(b_ptr); + const __m256d bi = _mm256_loadu_pd(b_ptr + 4); + + rr = _mm256_fmsub_pd(ai, bi, rr); + rr = _mm256_fmsub_pd(ar, br, rr); + ri = _mm256_fmadd_pd(ar, bi, ri); + ri = _mm256_fmadd_pd(ai, br, ri); + + _mm256_storeu_pd(r_ptr, rr); + _mm256_storeu_pd(r_ptr + 4, ri); + + r_ptr += 8; + a_ptr += 8; + b_ptr += 8; + } +} + +EXPORT void reim4_fftvec_mul_fma(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r_ptr, const double* a_ptr, + const double* b_ptr) { + const double* const rend_ptr = r_ptr + (tables->m << 1); + while (r_ptr != rend_ptr) { + const __m256d ar = _mm256_loadu_pd(a_ptr); + const __m256d ai = _mm256_loadu_pd(a_ptr + 4); + const __m256d br = _mm256_loadu_pd(b_ptr); + const __m256d bi = _mm256_loadu_pd(b_ptr + 4); + + const __m256d t1 = _mm256_mul_pd(ai, bi); + const __m256d t2 = _mm256_mul_pd(ar, bi); + + __m256d rr = _mm256_fmsub_pd(ar, br, t1); + __m256d ri = _mm256_fmadd_pd(ai, br, t2); + + _mm256_storeu_pd(r_ptr, rr); + _mm256_storeu_pd(r_ptr + 4, ri); + + r_ptr += 8; + a_ptr += 8; + b_ptr += 8; + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fftvec_addmul_ref.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fftvec_addmul_ref.c new file mode 100644 index 0000000..5254fd6 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fftvec_addmul_ref.c @@ -0,0 +1,97 @@ +#include +#include +#include +#include + +#include "../commons_private.h" +#include "reim4_fftvec_internal.h" +#include "reim4_fftvec_private.h" + +void* init_reim4_fftvec_addmul_precomp(REIM4_FFTVEC_ADDMUL_PRECOMP* res, uint32_t m) { + res->m = m; + if (CPU_SUPPORTS("fma")) { + if (m >= 2) { + res->function = reim4_fftvec_addmul_fma; + } else { + res->function = reim4_fftvec_addmul_ref; + } + } else { + res->function = reim4_fftvec_addmul_ref; + } + return res; +} + +EXPORT REIM4_FFTVEC_ADDMUL_PRECOMP* new_reim4_fftvec_addmul_precomp(uint32_t m) { + REIM4_FFTVEC_ADDMUL_PRECOMP* res = malloc(sizeof(*res)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_reim4_fftvec_addmul_precomp(res, m)); +} + +EXPORT void reim4_fftvec_addmul_ref(const REIM4_FFTVEC_ADDMUL_PRECOMP* precomp, double* r, const double* a, + const double* b) { + const uint64_t m = precomp->m; + for (uint64_t j = 0; j < m / 4; ++j) { + for (uint64_t i = 0; i < 4; ++i) { + double re = a[i] * b[i] - a[i + 4] * b[i + 4]; + double im = a[i] * b[i + 4] + a[i + 4] * b[i]; + r[i] += re; + r[i + 4] += im; + } + a += 8; + b += 8; + r += 8; + } +} + +EXPORT void reim4_fftvec_addmul_simple(uint32_t m, double* r, const double* a, const double* b) { + static REIM4_FFTVEC_ADDMUL_PRECOMP precomp[32]; + REIM4_FFTVEC_ADDMUL_PRECOMP* p = precomp + log2m(m); + if (!p->function) { + if (!init_reim4_fftvec_addmul_precomp(p, m)) abort(); + } + p->function(p, r, a, b); +} + +void* 
init_reim4_fftvec_mul_precomp(REIM4_FFTVEC_MUL_PRECOMP* res, uint32_t m) { + res->m = m; + if (CPU_SUPPORTS("fma")) { + if (m >= 4) { + res->function = reim4_fftvec_mul_fma; + } else { + res->function = reim4_fftvec_mul_ref; + } + } else { + res->function = reim4_fftvec_mul_ref; + } + return res; +} + +EXPORT REIM4_FFTVEC_MUL_PRECOMP* new_reim4_fftvec_mul_precomp(uint32_t m) { + REIM4_FFTVEC_MUL_PRECOMP* res = malloc(sizeof(*res)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_reim4_fftvec_mul_precomp(res, m)); +} + +EXPORT void reim4_fftvec_mul_ref(const REIM4_FFTVEC_MUL_PRECOMP* precomp, double* r, const double* a, const double* b) { + const uint64_t m = precomp->m; + for (uint64_t j = 0; j < m / 4; ++j) { + for (uint64_t i = 0; i < 4; ++i) { + double re = a[i] * b[i] - a[i + 4] * b[i + 4]; + double im = a[i] * b[i + 4] + a[i + 4] * b[i]; + r[i] = re; + r[i + 4] = im; + } + a += 8; + b += 8; + r += 8; + } +} + +EXPORT void reim4_fftvec_mul_simple(uint32_t m, double* r, const double* a, const double* b) { + static REIM4_FFTVEC_MUL_PRECOMP precomp[32]; + REIM4_FFTVEC_MUL_PRECOMP* p = precomp + log2m(m); + if (!p->function) { + if (!init_reim4_fftvec_mul_precomp(p, m)) abort(); + } + p->function(p, r, a, b); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fftvec_conv_fma.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fftvec_conv_fma.c new file mode 100644 index 0000000..c175d0e --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fftvec_conv_fma.c @@ -0,0 +1,37 @@ +#include +#include +#include + +#include "reim4_fftvec_private.h" + +EXPORT void reim4_from_cplx_fma(const REIM4_FROM_CPLX_PRECOMP* tables, double* r_ptr, const void* a) { + const double* const rend_ptr = r_ptr + (tables->m << 1); + + const double* a_ptr = (double*)a; + while (r_ptr != rend_ptr) { + __m256d t1 = _mm256_loadu_pd(a_ptr); + __m256d t2 = _mm256_loadu_pd(a_ptr + 4); + + _mm256_storeu_pd(r_ptr, _mm256_unpacklo_pd(t1, t2)); + _mm256_storeu_pd(r_ptr + 4, _mm256_unpackhi_pd(t1, t2)); + + r_ptr += 8; + a_ptr += 8; + } +} + +EXPORT void reim4_to_cplx_fma(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a_ptr) { + const double* const aend_ptr = a_ptr + (tables->m << 1); + double* r_ptr = (double*)r; + + while (a_ptr != aend_ptr) { + __m256d t1 = _mm256_loadu_pd(a_ptr); + __m256d t2 = _mm256_loadu_pd(a_ptr + 4); + + _mm256_storeu_pd(r_ptr, _mm256_unpacklo_pd(t1, t2)); + _mm256_storeu_pd(r_ptr + 4, _mm256_unpackhi_pd(t1, t2)); + + r_ptr += 8; + a_ptr += 8; + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fftvec_conv_ref.c b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fftvec_conv_ref.c new file mode 100644 index 0000000..5bee8ff --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fftvec_conv_ref.c @@ -0,0 +1,116 @@ +#include +#include +#include +#include + +#include "../commons_private.h" +#include "reim4_fftvec_internal.h" +#include "reim4_fftvec_private.h" + +EXPORT void reim4_from_cplx_ref(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a) { + const double* x = (double*)a; + const uint64_t m = tables->m; + for (uint64_t i = 0; i < m / 4; ++i) { + double r0 = x[0]; + double i0 = x[1]; + double r1 = x[2]; + double i1 = x[3]; + double r2 = x[4]; + double i2 = x[5]; + double r3 = x[6]; + double i3 = x[7]; + r[0] = r0; + r[1] = r2; + r[2] = r1; + r[3] = r3; 
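+ // the imaginary parts fill the upper half of the 8-double block, in the same (0,2,1,3) order as the reals above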
+ r[4] = i0; + r[5] = i2; + r[6] = i1; + r[7] = i3; + x += 8; + r += 8; + } +} + +void* init_reim4_from_cplx_precomp(REIM4_FROM_CPLX_PRECOMP* res, uint32_t nn) { + res->m = nn / 2; + if (CPU_SUPPORTS("fma")) { + if (nn >= 4) { + res->function = reim4_from_cplx_fma; + } else { + res->function = reim4_from_cplx_ref; + } + } else { + res->function = reim4_from_cplx_ref; + } + return res; +} + +EXPORT REIM4_FROM_CPLX_PRECOMP* new_reim4_from_cplx_precomp(uint32_t m) { + REIM4_FROM_CPLX_PRECOMP* res = malloc(sizeof(*res)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_reim4_from_cplx_precomp(res, m)); +} + +EXPORT void reim4_from_cplx_simple(uint32_t m, double* r, const void* a) { + static REIM4_FROM_CPLX_PRECOMP precomp[32]; + REIM4_FROM_CPLX_PRECOMP* p = precomp + log2m(m); + if (!p->function) { + if (!init_reim4_from_cplx_precomp(p, m)) abort(); + } + p->function(p, r, a); +} + +EXPORT void reim4_to_cplx_ref(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a) { + double* y = (double*)r; + const uint64_t m = tables->m; + for (uint64_t i = 0; i < m / 4; ++i) { + double r0 = a[0]; + double r2 = a[1]; + double r1 = a[2]; + double r3 = a[3]; + double i0 = a[4]; + double i2 = a[5]; + double i1 = a[6]; + double i3 = a[7]; + y[0] = r0; + y[1] = i0; + y[2] = r1; + y[3] = i1; + y[4] = r2; + y[5] = i2; + y[6] = r3; + y[7] = i3; + a += 8; + y += 8; + } +} + +void* init_reim4_to_cplx_precomp(REIM4_TO_CPLX_PRECOMP* res, uint32_t m) { + res->m = m; + if (CPU_SUPPORTS("fma")) { + if (m >= 2) { + res->function = reim4_to_cplx_fma; + } else { + res->function = reim4_to_cplx_ref; + } + } else { + res->function = reim4_to_cplx_ref; + } + return res; +} + +EXPORT REIM4_TO_CPLX_PRECOMP* new_reim4_to_cplx_precomp(uint32_t m) { + REIM4_TO_CPLX_PRECOMP* res = malloc(sizeof(*res)); + if (!res) return spqlios_error(strerror(errno)); + return spqlios_keep_or_free(res, init_reim4_to_cplx_precomp(res, m)); +} + +EXPORT void reim4_to_cplx_simple(uint32_t m, void* r, const double* a) { + static REIM4_TO_CPLX_PRECOMP precomp[32]; + REIM4_TO_CPLX_PRECOMP* p = precomp + log2m(m); + if (!p->function) { + if (!init_reim4_to_cplx_precomp(p, m)) abort(); + } + p->function(p, r, a); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fftvec_internal.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fftvec_internal.h new file mode 100644 index 0000000..4b076f0 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fftvec_internal.h @@ -0,0 +1,20 @@ +#ifndef SPQLIOS_REIM4_FFTVEC_INTERNAL_H +#define SPQLIOS_REIM4_FFTVEC_INTERNAL_H + +#include "reim4_fftvec_public.h" + +EXPORT void reim4_fftvec_mul_ref(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b); +EXPORT void reim4_fftvec_mul_fma(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b); + +EXPORT void reim4_fftvec_addmul_ref(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, + const double* b); +EXPORT void reim4_fftvec_addmul_fma(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, + const double* b); + +EXPORT void reim4_from_cplx_ref(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a); +EXPORT void reim4_from_cplx_fma(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a); + +EXPORT void reim4_to_cplx_ref(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a); +EXPORT void reim4_to_cplx_fma(const 
REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a); + +#endif // SPQLIOS_REIM4_FFTVEC_INTERNAL_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fftvec_private.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fftvec_private.h new file mode 100644 index 0000000..98a5286 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fftvec_private.h @@ -0,0 +1,33 @@ +#ifndef SPQLIOS_REIM4_FFTVEC_PRIVATE_H +#define SPQLIOS_REIM4_FFTVEC_PRIVATE_H + +#include "reim4_fftvec_public.h" + +#define STATIC_ASSERT(condition) (void)sizeof(char[-1 + 2 * !!(condition)]) + +typedef void (*R4_FFTVEC_MUL_FUNC)(const REIM4_FFTVEC_MUL_PRECOMP*, double*, const double*, const double*); +typedef void (*R4_FFTVEC_ADDMUL_FUNC)(const REIM4_FFTVEC_ADDMUL_PRECOMP*, double*, const double*, const double*); +typedef void (*R4_FROM_CPLX_FUNC)(const REIM4_FROM_CPLX_PRECOMP*, double*, const void*); +typedef void (*R4_TO_CPLX_FUNC)(const REIM4_TO_CPLX_PRECOMP*, void*, const double*); + +struct reim4_mul_precomp { + R4_FFTVEC_MUL_FUNC function; + int64_t m; +}; + +struct reim4_addmul_precomp { + R4_FFTVEC_ADDMUL_FUNC function; + int64_t m; +}; + +struct reim4_from_cplx_precomp { + R4_FROM_CPLX_FUNC function; + int64_t m; +}; + +struct reim4_to_cplx_precomp { + R4_TO_CPLX_FUNC function; + int64_t m; +}; + +#endif // SPQLIOS_REIM4_FFTVEC_PRIVATE_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fftvec_public.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fftvec_public.h new file mode 100644 index 0000000..c833dcf --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/spqlios/reim4/reim4_fftvec_public.h @@ -0,0 +1,59 @@ +#ifndef SPQLIOS_REIM4_FFTVEC_PUBLIC_H +#define SPQLIOS_REIM4_FFTVEC_PUBLIC_H + +#include "../commons.h" + +typedef struct reim4_addmul_precomp REIM4_FFTVEC_ADDMUL_PRECOMP; +typedef struct reim4_mul_precomp REIM4_FFTVEC_MUL_PRECOMP; +typedef struct reim4_from_cplx_precomp REIM4_FROM_CPLX_PRECOMP; +typedef struct reim4_to_cplx_precomp REIM4_TO_CPLX_PRECOMP; + +EXPORT REIM4_FFTVEC_MUL_PRECOMP* new_reim4_fftvec_mul_precomp(uint32_t m); +EXPORT void reim4_fftvec_mul(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b); +#define delete_reim4_fftvec_mul_precomp free + +EXPORT REIM4_FFTVEC_ADDMUL_PRECOMP* new_reim4_fftvec_addmul_precomp(uint32_t nn); +EXPORT void reim4_fftvec_addmul(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, const double* b); +#define delete_reim4_fftvec_addmul_precomp free + +/** + * @brief prepares a conversion from the cplx fftvec layout to the reim4 layout. + * @param m complex dimension m from C[X] mod X^m-i. + */ +EXPORT REIM4_FROM_CPLX_PRECOMP* new_reim4_from_cplx_precomp(uint32_t m); +EXPORT void reim4_from_cplx(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a); +#define delete_reim4_from_cplx_precomp free + +/** + * @brief prepares a conversion from the reim4 fftvec layout to the cplx layout + * @param m the complex dimension m from C[X] mod X^m-i. + */ +EXPORT REIM4_TO_CPLX_PRECOMP* new_reim4_to_cplx_precomp(uint32_t m); +EXPORT void reim4_to_cplx(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a); +#define delete_reim4_to_cplx_precomp free + +/** + * @brief Simpler API for the fftvec multiplication function. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. 
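+ * Illustrative call: reim4_fftvec_mul_simple(m, r, a, b); with m a power of two (the table cache is indexed by log2(m)) + * and r, a, b arrays of 2*m doubles in the reim4 fftvec layout.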
+ * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void reim4_fftvec_mul_simple(uint32_t m, double* r, const double* a, const double* b); + +/** + * @brief Simpler API for the fftvec addmul function. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void reim4_fftvec_addmul_simple(uint32_t m, double* r, const double* a, const double* b); + +/** + * @brief Simpler API for cplx conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void reim4_from_cplx_simple(uint32_t m, double* r, const void* a); + +/** + * @brief Simpler API for to cplx conversion. + * For each dimension, the precomputed tables for this dimension are generated automatically the first time. + * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ +EXPORT void reim4_to_cplx_simple(uint32_t m, void* r, const double* a); + +#endif // SPQLIOS_REIM4_FFTVEC_PUBLIC_H \ No newline at end of file diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/CMakeLists.txt b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/CMakeLists.txt new file mode 100644 index 0000000..09f0b7a --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/CMakeLists.txt @@ -0,0 +1,142 @@ +set(CMAKE_CXX_STANDARD 17) + +set(test_incs ..) +set(gtest_libs) +set(benchmark_libs) +# searching for libgtest +find_path(gtest_inc NAMES gtest/gtest.h) +find_library(gtest NAMES gtest) +find_library(gtest_main REQUIRED NAMES gtest_main) +if (gtest_inc AND gtest AND gtest_main) + message(STATUS "Found gtest: I=${gtest_inc} L=${gtest},${gtest_main}") + set(test_incs ${test_incs} ${gtest_inc}) + set(gtest_libs ${gtest_libs} ${gtest} ${gtest_main} pthread) +else() + message(FATAL_ERROR "Libgtest not found (required if ENABLE_TESTING is on): I=${gtest_inc} L=${gtest},${gtest_main}") +endif() +# searching for libbenchmark +find_path(benchmark_inc NAMES benchmark/benchmark.h) +find_library(benchmark NAMES benchmark) +if (benchmark_inc AND benchmark) + message(STATUS "Found benchmark: I=${benchmark_inc} L=${benchmark}") + set(test_incs ${test_incs} ${benchmark_inc}) + set(benchmark_libs ${benchmark_libs} ${benchmark}) +else() + message(FATAL_ERROR "Libbenchmark not found (required if ENABLE_TESTING is on): I=${benchmark_inc} L=${benchmark}") +endif() +find_path(VALGRIND_DIR NAMES valgrind/valgrind.h) +if (VALGRIND_DIR) + message(STATUS "Found valgrind header ${VALGRIND_DIR}") +else () + # for now, we will fail if we don't find valgrind for tests + message(STATUS "CANNOT FIND valgrind header: ${VALGRIND_DIR}") +endif () + +add_library(spqlios-testlib SHARED + testlib/random.cpp + testlib/test_commons.h + testlib/test_commons.cpp + testlib/mod_q120.h + testlib/mod_q120.cpp + testlib/negacyclic_polynomial.cpp + testlib/negacyclic_polynomial.h + testlib/negacyclic_polynomial_impl.h + testlib/reim4_elem.cpp + testlib/reim4_elem.h + testlib/fft64_dft.cpp + testlib/fft64_dft.h + testlib/fft64_layouts.h + testlib/fft64_layouts.cpp + testlib/ntt120_layouts.cpp + testlib/ntt120_layouts.h + testlib/ntt120_dft.cpp + testlib/ntt120_dft.h + testlib/test_hash.cpp + testlib/sha3.h + 
testlib/sha3.c + testlib/polynomial_vector.h + testlib/polynomial_vector.cpp + testlib/vec_rnx_layout.h + testlib/vec_rnx_layout.cpp + testlib/zn_layouts.h + testlib/zn_layouts.cpp +) +if (VALGRIND_DIR) + target_include_directories(spqlios-testlib PRIVATE ${VALGRIND_DIR}) + target_compile_definitions(spqlios-testlib PRIVATE VALGRIND_MEM_TESTS) +endif () +target_link_libraries(spqlios-testlib libspqlios) + + + +# main unittest file +message(STATUS "${gtest_libs}") +set(UNITTEST_FILES + spqlios_test.cpp + spqlios_reim_conversions_test.cpp + spqlios_reim_test.cpp + spqlios_reim4_arithmetic_test.cpp + spqlios_cplx_test.cpp + spqlios_cplx_conversions_test.cpp + spqlios_q120_ntt_test.cpp + spqlios_q120_arithmetic_test.cpp + spqlios_coeffs_arithmetic_test.cpp + spqlios_vec_znx_big_test.cpp + spqlios_znx_small_test.cpp + spqlios_vmp_product_test.cpp + spqlios_vec_znx_dft_test.cpp + spqlios_svp_apply_test.cpp + spqlios_svp_prepare_test.cpp + spqlios_vec_znx_test.cpp + spqlios_vec_rnx_test.cpp + spqlios_vec_rnx_vmp_test.cpp + spqlios_vec_rnx_conversions_test.cpp + spqlios_vec_rnx_ppol_test.cpp + spqlios_vec_rnx_approxdecomp_tnxdbl_test.cpp + spqlios_zn_approxdecomp_test.cpp + spqlios_zn_conversions_test.cpp + spqlios_zn_vmp_test.cpp + + +) + +add_executable(spqlios-test ${UNITTEST_FILES}) +target_link_libraries(spqlios-test spqlios-testlib libspqlios ${gtest_libs}) +target_include_directories(spqlios-test PRIVATE ${test_incs}) +add_test(NAME spqlios-test COMMAND spqlios-test) +if (WIN32) + # copy the dlls to the test directory + cmake_minimum_required(VERSION 3.26) + add_custom_command( + POST_BUILD + TARGET spqlios-test + COMMAND ${CMAKE_COMMAND} -E copy + -t $ $ $ + COMMAND_EXPAND_LISTS + ) +endif() + +# benchmarks +add_executable(spqlios-cplx-fft-bench spqlios_cplx_fft_bench.cpp) +target_link_libraries(spqlios-cplx-fft-bench libspqlios ${benchmark_libs} pthread) +target_include_directories(spqlios-cplx-fft-bench PRIVATE ${test_incs}) + +if (X86 OR X86_WIN32) + add_executable(spqlios-q120-ntt-bench spqlios_q120_ntt_bench.cpp) + target_link_libraries(spqlios-q120-ntt-bench libspqlios ${benchmark_libs} pthread) + target_include_directories(spqlios-q120-ntt-bench PRIVATE ${test_incs}) + + add_executable(spqlios-q120-arithmetic-bench spqlios_q120_arithmetic_bench.cpp) + target_link_libraries(spqlios-q120-arithmetic-bench libspqlios ${benchmark_libs} pthread) + target_include_directories(spqlios-q120-arithmetic-bench PRIVATE ${test_incs}) +endif () + +if (X86 OR X86_WIN32) + add_executable(spqlios_reim4_arithmetic_bench spqlios_reim4_arithmetic_bench.cpp) + target_link_libraries(spqlios_reim4_arithmetic_bench ${benchmark_libs} libspqlios pthread) + target_include_directories(spqlios_reim4_arithmetic_bench PRIVATE ${test_incs}) +endif () + +if (DEVMODE_INSTALL) + install(TARGETS spqlios-testlib) +endif() diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_coeffs_arithmetic_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_coeffs_arithmetic_test.cpp new file mode 100644 index 0000000..c0cf0e9 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_coeffs_arithmetic_test.cpp @@ -0,0 +1,488 @@ +#include +#include + +#include +#include +#include +#include + +#include "../spqlios/coeffs/coeffs_arithmetic.h" +#include "test/testlib/mod_q120.h" +#include "testlib/negacyclic_polynomial.h" +#include "testlib/test_commons.h" + +/// tests of element-wise operations +template +void test_elemw_op(F elemw_op, G poly_elemw_op) { + for (uint64_t 
n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial a = polynomial::random(n); + polynomial b = polynomial::random(n); + polynomial expect(n); + polynomial actual(n); + // out of place + expect = poly_elemw_op(a, b); + elemw_op(n, actual.data(), a.data(), b.data()); + ASSERT_EQ(actual, expect); + // in place 1 + actual = polynomial::random(n); + expect = poly_elemw_op(actual, b); + elemw_op(n, actual.data(), actual.data(), b.data()); + ASSERT_EQ(actual, expect); + // in place 2 + actual = polynomial::random(n); + expect = poly_elemw_op(a, actual); + elemw_op(n, actual.data(), a.data(), actual.data()); + ASSERT_EQ(actual, expect); + // in place 3 + actual = polynomial::random(n); + expect = poly_elemw_op(actual, actual); + elemw_op(n, actual.data(), actual.data(), actual.data()); + ASSERT_EQ(actual, expect); + } +} + +static polynomial poly_i64_add(const polynomial& u, polynomial& v) { return u + v; } +static polynomial poly_i64_sub(const polynomial& u, polynomial& v) { return u - v; } +TEST(coeffs_arithmetic, znx_add_i64_ref) { test_elemw_op(znx_add_i64_ref, poly_i64_add); } +TEST(coeffs_arithmetic, znx_sub_i64_ref) { test_elemw_op(znx_sub_i64_ref, poly_i64_sub); } +#ifdef __x86_64__ +TEST(coeffs_arithmetic, znx_add_i64_avx) { test_elemw_op(znx_add_i64_avx, poly_i64_add); } +TEST(coeffs_arithmetic, znx_sub_i64_avx) { test_elemw_op(znx_sub_i64_avx, poly_i64_sub); } +#endif + +/// tests of element-wise operations +template +void test_elemw_unary_op(F elemw_op, G poly_elemw_op) { + for (uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial a = polynomial::random(n); + polynomial expect(n); + polynomial actual(n); + // out of place + expect = poly_elemw_op(a); + elemw_op(n, actual.data(), a.data()); + ASSERT_EQ(actual, expect); + // in place + actual = polynomial::random(n); + expect = poly_elemw_op(actual); + elemw_op(n, actual.data(), actual.data()); + ASSERT_EQ(actual, expect); + } +} + +static polynomial poly_i64_neg(const polynomial& u) { return -u; } +static polynomial poly_i64_copy(const polynomial& u) { return u; } +TEST(coeffs_arithmetic, znx_neg_i64_ref) { test_elemw_unary_op(znx_negate_i64_ref, poly_i64_neg); } +TEST(coeffs_arithmetic, znx_copy_i64_ref) { test_elemw_unary_op(znx_copy_i64_ref, poly_i64_copy); } +#ifdef __x86_64__ +TEST(coeffs_arithmetic, znx_neg_i64_avx) { test_elemw_unary_op(znx_negate_i64_avx, poly_i64_neg); } +#endif + +/// tests of the rotations out of place +template +void test_rotation_outplace(F rotate) { + for (uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial poly = polynomial::random(n); + polynomial expect(n); + polynomial actual(n); + for (uint64_t trial = 0; trial < 10; ++trial) { + int64_t p = uniform_i64_bits(32); + // rotate by p + for (uint64_t i = 0; i < n; ++i) { + expect.set_coeff(i, poly.get_coeff(i - p)); + } + // rotate using the function + rotate(n, p, actual.data(), poly.data()); + ASSERT_EQ(actual, expect); + } + } +} + +TEST(coeffs_arithmetic, rnx_rotate_f64) { test_rotation_outplace(rnx_rotate_f64); } +TEST(coeffs_arithmetic, znx_rotate_i64) { test_rotation_outplace(znx_rotate_i64); } + +/// tests of the rotations out of place +template +void test_rotation_inplace(F rotate) { + for (uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial poly = polynomial::random(n); + polynomial expect(n); + for (uint64_t trial = 0; trial < 10; ++trial) { + polynomial actual = poly; + int64_t p = uniform_i64_bits(32); + // rotate by p + for (uint64_t i = 0; i < n; ++i) { + expect.set_coeff(i, poly.get_coeff(i - p)); + } + // rotate 
using the function + rotate(n, p, actual.data()); + ASSERT_EQ(actual, expect); + } + } +} + +TEST(coeffs_arithmetic, rnx_rotate_inplace_f64) { test_rotation_inplace(rnx_rotate_inplace_f64); } + +TEST(coeffs_arithmetic, znx_rotate_inplace_i64) { test_rotation_inplace(znx_rotate_inplace_i64); } + +/// tests of the rotations out of place +template +void test_mul_xp_minus_one_outplace(F rotate) { + for (uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial poly = polynomial::random(n); + polynomial expect(n); + polynomial actual(n); + for (uint64_t trial = 0; trial < 10; ++trial) { + int64_t p = uniform_i64_bits(32); + // rotate by p + for (uint64_t i = 0; i < n; ++i) { + expect.set_coeff(i, poly.get_coeff(i - p) - poly.get_coeff(i)); + } + // rotate using the function + rotate(n, p, actual.data(), poly.data()); + ASSERT_EQ(actual, expect); + } + } +} + +TEST(coeffs_arithmetic, rnx_mul_xp_minus_one_f64) { test_mul_xp_minus_one_outplace(rnx_mul_xp_minus_one_f64); } +TEST(coeffs_arithmetic, znx_mul_xp_minus_one_i64) { test_mul_xp_minus_one_outplace(znx_mul_xp_minus_one_i64); } + +/// tests of the rotations out of place +template +void test_mul_xp_minus_one_inplace(F rotate) { + for (uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial poly = polynomial::random(n); + polynomial expect(n); + for (uint64_t trial = 0; trial < 10; ++trial) { + polynomial actual = poly; + int64_t p = uniform_i64_bits(32); + // rotate by p + for (uint64_t i = 0; i < n; ++i) { + expect.set_coeff(i, poly.get_coeff(i - p) - poly.get_coeff(i)); + } + // rotate using the function + rotate(n, p, actual.data()); + ASSERT_EQ(actual, expect); + } + } +} + +TEST(coeffs_arithmetic, rnx_mul_xp_minus_one_inplace_f64) { + test_mul_xp_minus_one_inplace(rnx_mul_xp_minus_one_inplace_f64); +} + +// TEST(coeffs_arithmetic, znx_mul_xp_minus_one_inplace_i64) { +// test_mul_xp_minus_one_inplace(znx_rotate_inplace_i64); } +/// tests of the automorphisms out of place +template +void test_automorphism_outplace(F automorphism) { + for (uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial poly = polynomial::random(n); + polynomial expect(n); + polynomial actual(n); + for (uint64_t trial = 0; trial < 10; ++trial) { + int64_t p = uniform_i64_bits(32) | int64_t(1); // make it odd + // automorphism p + for (uint64_t i = 0; i < n; ++i) { + expect.set_coeff(i * p, poly.get_coeff(i)); + } + // rotate using the function + automorphism(n, p, actual.data(), poly.data()); + ASSERT_EQ(actual, expect); + } + } +} + +TEST(coeffs_arithmetic, rnx_automorphism_f64) { test_automorphism_outplace(rnx_automorphism_f64); } +TEST(coeffs_arithmetic, znx_automorphism_i64) { test_automorphism_outplace(znx_automorphism_i64); } + +/// tests of the automorphisms out of place +template +void test_automorphism_inplace(F automorphism) { + for (uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096, 16384}) { + polynomial poly = polynomial::random(n); + polynomial expect(n); + for (uint64_t trial = 0; trial < 20; ++trial) { + polynomial actual = poly; + int64_t p = uniform_i64_bits(32) | int64_t(1); // make it odd + // automorphism p + for (uint64_t i = 0; i < n; ++i) { + expect.set_coeff(i * p, poly.get_coeff(i)); + } + automorphism(n, p, actual.data()); + if (!(actual == expect)) { + std::cerr << "automorphism p: " << p << std::endl; + for (uint64_t i = 0; i < n; ++i) { + std::cerr << i << " " << actual.get_coeff(i) << " vs " << expect.get_coeff(i) << " " + << (actual.get_coeff(i) == expect.get_coeff(i)) << std::endl; + } + } + ASSERT_EQ(actual, expect); + } + } 
+} +TEST(coeffs_arithmetic, rnx_automorphism_inplace_f64) { + test_automorphism_inplace(rnx_automorphism_inplace_f64); +} +TEST(coeffs_arithmetic, znx_automorphism_inplace_i64) { + test_automorphism_inplace(znx_automorphism_inplace_i64); +} + +// TODO: write a test later! + +/** + * @brief res = (X^p-1).in + * @param nn the ring dimension + * @param p must be between -2nn <= p <= 2nn + * @param in is a rnx/znx vector of dimension nn + */ +EXPORT void rnx_mul_xp_minus_one(uint64_t nn, int64_t p, double* res, const double* in); +EXPORT void znx_mul_xp_minus_one(uint64_t nn, int64_t p, int64_t* res, const int64_t* in); + +// normalize with no carry in nor carry out +template +void test_znx_normalize(F normalize) { + for (const uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial inp = znx_i64::random_log2bound(n, 62); + if (n >= 2) { + inp.set_coeff(0, -(INT64_C(1) << 62)); + inp.set_coeff(1, (INT64_C(1) << 62)); + } + for (const uint64_t base_k : {2, 3, 19, 35, 62}) { + polynomial out; + int64_t* inp_ptr; + if (inplace_flag == 1) { + out = polynomial(inp); + inp_ptr = out.data(); + } else { + out = polynomial(n); + inp_ptr = inp.data(); + } + + znx_normalize(n, base_k, out.data(), nullptr, inp_ptr, nullptr); + for (uint64_t i = 0; i < n; ++i) { + const int64_t x = inp.get_coeff(i); + const int64_t y = out.get_coeff(i); + const int64_t y_exp = centermod(x, INT64_C(1) << base_k); + ASSERT_EQ(y, y_exp) << n << " " << base_k << " " << i << " " << x << " " << y; + } + } + } +} + +TEST(coeffs_arithmetic, znx_normalize_outplace) { test_znx_normalize<0>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_inplace) { test_znx_normalize<1>(znx_normalize); } + +// normalize with no carry in nor carry out +template +void test_znx_normalize_cout(F normalize) { + static_assert(inplace_flag < 3, "either out or cout can be inplace with inp"); + for (const uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial inp = znx_i64::random_log2bound(n, 62); + if (n >= 2) { + inp.set_coeff(0, -(INT64_C(1) << 62)); + inp.set_coeff(1, (INT64_C(1) << 62)); + } + for (const uint64_t base_k : {2, 3, 19, 35, 62}) { + polynomial out, cout; + int64_t* inp_ptr; + if (inplace_flag == 1) { + // out and inp are the same + out = polynomial(inp); + inp_ptr = out.data(); + cout = polynomial(n); + } else if (inplace_flag == 2) { + // carry out and inp are the same + cout = polynomial(inp); + inp_ptr = cout.data(); + out = polynomial(n); + } else { + // inp, out and carry out are distinct + out = polynomial(n); + cout = polynomial(n); + inp_ptr = inp.data(); + } + + znx_normalize(n, base_k, has_output ? 
out.data() : nullptr, cout.data(), inp_ptr, nullptr); + for (uint64_t i = 0; i < n; ++i) { + const int64_t x = inp.get_coeff(i); + const int64_t co = cout.get_coeff(i); + const int64_t y_exp = centermod((int64_t)x, INT64_C(1) << base_k); + const int64_t co_exp = (x - y_exp) >> base_k; + ASSERT_EQ(co, co_exp); + + if (has_output) { + const int64_t y = out.get_coeff(i); + ASSERT_EQ(y, y_exp); + } + } + } + } +} + +TEST(coeffs_arithmetic, znx_normalize_cout_outplace) { test_znx_normalize_cout<0, false>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_out_cout_outplace) { test_znx_normalize_cout<0, true>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_cout_inplace1) { test_znx_normalize_cout<1, false>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_out_cout_inplace1) { test_znx_normalize_cout<1, true>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_cout_inplace2) { test_znx_normalize_cout<2, false>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_out_cout_inplace2) { test_znx_normalize_cout<2, true>(znx_normalize); } + +// normalize with no carry in nor carry out +template +void test_znx_normalize_cin(F normalize) { + static_assert(inplace_flag < 3, "either inp or cin can be inplace with out"); + for (const uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial inp = znx_i64::random_log2bound(n, 62); + if (n >= 4) { + inp.set_coeff(0, -(INT64_C(1) << 62)); + inp.set_coeff(1, -(INT64_C(1) << 62)); + inp.set_coeff(2, (INT64_C(1) << 62)); + inp.set_coeff(3, (INT64_C(1) << 62)); + } + for (const uint64_t base_k : {2, 3, 19, 35, 62}) { + polynomial cin = znx_i64::random_log2bound(n, 62); + if (n >= 4) { + inp.set_coeff(0, -(INT64_C(1) << 62)); + inp.set_coeff(1, (INT64_C(1) << 62)); + inp.set_coeff(0, -(INT64_C(1) << 62)); + inp.set_coeff(1, (INT64_C(1) << 62)); + } + + polynomial out; + int64_t *inp_ptr, *cin_ptr; + if (inplace_flag == 1) { + // out and inp are the same + out = polynomial(inp); + inp_ptr = out.data(); + cin_ptr = cin.data(); + } else if (inplace_flag == 2) { + // out and carry in are the same + out = polynomial(cin); + inp_ptr = inp.data(); + cin_ptr = out.data(); + } else { + // inp, carry in and out are distinct + out = polynomial(n); + inp_ptr = inp.data(); + cin_ptr = cin.data(); + } + + znx_normalize(n, base_k, out.data(), nullptr, inp_ptr, cin_ptr); + for (uint64_t i = 0; i < n; ++i) { + const int64_t x = inp.get_coeff(i); + const int64_t ci = cin.get_coeff(i); + const int64_t y = out.get_coeff(i); + + const __int128_t xp = (__int128_t)x + ci; + const int64_t y_exp = centermod((int64_t)xp, INT64_C(1) << base_k); + + ASSERT_EQ(y, y_exp) << n << " " << base_k << " " << i << " " << x << " " << y << " " << ci; + } + } + } +} + +TEST(coeffs_arithmetic, znx_normalize_cin_outplace) { test_znx_normalize_cin<0>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_cin_inplace1) { test_znx_normalize_cin<1>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_cin_inplace2) { test_znx_normalize_cin<2>(znx_normalize); } + +// normalize with no carry in nor carry out +template +void test_znx_normalize_cin_cout(F normalize) { + static_assert(inplace_flag < 7, "either inp or cin can be inplace with out"); + for (const uint64_t n : {1, 2, 4, 8, 16, 64, 256, 4096}) { + polynomial inp = znx_i64::random_log2bound(n, 62); + if (n >= 4) { + inp.set_coeff(0, -(INT64_C(1) << 62)); + inp.set_coeff(1, -(INT64_C(1) << 62)); + inp.set_coeff(2, (INT64_C(1) << 62)); + inp.set_coeff(3, (INT64_C(1) << 62)); + } + for (const uint64_t base_k : {2, 3, 
19, 35, 62}) { + polynomial cin = znx_i64::random_log2bound(n, 62); + if (n >= 4) { + inp.set_coeff(0, -(INT64_C(1) << 62)); + inp.set_coeff(1, (INT64_C(1) << 62)); + inp.set_coeff(0, -(INT64_C(1) << 62)); + inp.set_coeff(1, (INT64_C(1) << 62)); + } + + polynomial out, cout; + int64_t *inp_ptr, *cin_ptr; + if (inplace_flag == 1) { + // out == inp + out = polynomial(inp); + cout = polynomial(n); + inp_ptr = out.data(); + cin_ptr = cin.data(); + } else if (inplace_flag == 2) { + // cout == inp + out = polynomial(n); + cout = polynomial(inp); + inp_ptr = cout.data(); + cin_ptr = cin.data(); + } else if (inplace_flag == 3) { + // out == cin + out = polynomial(cin); + cout = polynomial(n); + inp_ptr = inp.data(); + cin_ptr = out.data(); + } else if (inplace_flag == 4) { + // cout == cin + out = polynomial(n); + cout = polynomial(cin); + inp_ptr = inp.data(); + cin_ptr = cout.data(); + } else if (inplace_flag == 5) { + // out == inp, cout == cin + out = polynomial(inp); + cout = polynomial(cin); + inp_ptr = out.data(); + cin_ptr = cout.data(); + } else if (inplace_flag == 6) { + // out == cin, cout == inp + out = polynomial(cin); + cout = polynomial(inp); + inp_ptr = cout.data(); + cin_ptr = out.data(); + } else { + out = polynomial(n); + cout = polynomial(n); + inp_ptr = inp.data(); + cin_ptr = cin.data(); + } + + znx_normalize(n, base_k, has_output ? out.data() : nullptr, cout.data(), inp_ptr, cin_ptr); + for (uint64_t i = 0; i < n; ++i) { + const int64_t x = inp.get_coeff(i); + const int64_t ci = cin.get_coeff(i); + const int64_t co = cout.get_coeff(i); + + const __int128_t xp = (__int128_t)x + ci; + const int64_t y_exp = centermod((int64_t)xp, INT64_C(1) << base_k); + const int64_t co_exp = (xp - y_exp) >> base_k; + ASSERT_EQ(co, co_exp); + + if (has_output) { + const int64_t y = out.get_coeff(i); + ASSERT_EQ(y, y_exp); + } + } + } + } +} + +TEST(coeffs_arithmetic, znx_normalize_cin_cout_outplace) { test_znx_normalize_cin_cout<0, false>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_out_cin_cout_outplace) { test_znx_normalize_cin_cout<0, true>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_cin_cout_inplace1) { test_znx_normalize_cin_cout<1, false>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_out_cin_cout_inplace1) { test_znx_normalize_cin_cout<1, true>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_cin_cout_inplace2) { test_znx_normalize_cin_cout<2, false>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_out_cin_cout_inplace2) { test_znx_normalize_cin_cout<2, true>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_cin_cout_inplace3) { test_znx_normalize_cin_cout<3, false>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_out_cin_cout_inplace3) { test_znx_normalize_cin_cout<3, true>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_cin_cout_inplace4) { test_znx_normalize_cin_cout<4, false>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_out_cin_cout_inplace4) { test_znx_normalize_cin_cout<4, true>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_cin_cout_inplace5) { test_znx_normalize_cin_cout<5, false>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_out_cin_cout_inplace5) { test_znx_normalize_cin_cout<5, true>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_cin_cout_inplace6) { test_znx_normalize_cin_cout<6, false>(znx_normalize); } +TEST(coeffs_arithmetic, znx_normalize_out_cin_cout_inplace6) { test_znx_normalize_cin_cout<6, true>(znx_normalize); } diff --git 
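The normalization tests above pin down the per-coefficient contract of znx_normalize: the output digit is the centered representative of x + carry_in modulo 2^base_k, and the carry-out is the remaining high part, so that digit + 2^base_k * carry_out equals x + carry_in. A compact scalar sketch of that contract follows; centermod_pow2 and normalize_coeff are illustrative helpers (the tie convention at exactly ±2^(base_k-1) is an assumption here; the tests defer to the library's own centermod):

```c
#include <stdint.h>

/* Centered representative of x modulo 2^k (one common tie convention). */
static int64_t centermod_pow2(int64_t x, uint64_t k) {
  int64_t q = INT64_C(1) << k;
  int64_t r = ((x % q) + q) % q;          /* representative in [0, 2^k) */
  return r >= (q >> 1) ? r - q : r;       /* shift into [-2^(k-1), 2^(k-1)) */
}

/* One coefficient of the contract checked above:
 * out + (carry_out << base_k) == x + carry_in, with out centered. */
static void normalize_coeff(uint64_t base_k, int64_t x, int64_t carry_in,
                            int64_t* out, int64_t* carry_out) {
  __int128_t xp = (__int128_t)x + carry_in;          /* may exceed 64 bits */
  int64_t y = centermod_pow2((int64_t)xp, base_k);   /* truncation is safe modulo 2^base_k */
  if (out) *out = y;
  if (carry_out) *carry_out = (int64_t)((xp - y) >> base_k);
}
```

This mirrors the expected values (y_exp and co_exp) computed in the tests for every in-place aliasing combination exercised above.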
a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_cplx_conversions_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_cplx_conversions_test.cpp new file mode 100644 index 0000000..32c9158 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_cplx_conversions_test.cpp @@ -0,0 +1,86 @@ +#include + +#include + +#include "spqlios/cplx/cplx_fft_internal.h" +#include "spqlios/cplx/cplx_fft_private.h" + +#ifdef __x86_64__ +TEST(fft, cplx_from_znx32_ref_vs_fma) { + const uint32_t m = 128; + int32_t* src = (int32_t*)spqlios_alloc_custom_align(32, 10 * m * sizeof(int32_t)); + CPLX* dst1 = (CPLX*)(src + 2 * m); + CPLX* dst2 = (CPLX*)(src + 6 * m); + for (uint64_t i = 0; i < 2 * m; ++i) { + src[i] = rand() - RAND_MAX / 2; + } + CPLX_FROM_ZNX32_PRECOMP precomp; + precomp.m = m; + cplx_from_znx32_ref(&precomp, dst1, src); + // cplx_from_znx32_simple(m, 32, dst1, src); + cplx_from_znx32_avx2_fma(&precomp, dst2, src); + for (uint64_t i = 0; i < m; ++i) { + ASSERT_EQ(dst1[i][0], dst2[i][0]); + ASSERT_EQ(dst1[i][1], dst2[i][1]); + } + spqlios_free(src); +} +#endif + +#ifdef __x86_64__ +TEST(fft, cplx_from_tnx32_ref_vs_fma) { + const uint32_t m = 128; + int32_t* src = (int32_t*)spqlios_alloc_custom_align(32, 10 * m * sizeof(int32_t)); + CPLX* dst1 = (CPLX*)(src + 2 * m); + CPLX* dst2 = (CPLX*)(src + 6 * m); + for (uint64_t i = 0; i < 2 * m; ++i) { + src[i] = rand() + (rand() << 20); + } + CPLX_FROM_TNX32_PRECOMP precomp; + precomp.m = m; + cplx_from_tnx32_ref(&precomp, dst1, src); + // cplx_from_tnx32_simple(m, dst1, src); + cplx_from_tnx32_avx2_fma(&precomp, dst2, src); + for (uint64_t i = 0; i < m; ++i) { + ASSERT_EQ(dst1[i][0], dst2[i][0]); + ASSERT_EQ(dst1[i][1], dst2[i][1]); + } + spqlios_free(src); +} +#endif + +#ifdef __x86_64__ +TEST(fft, cplx_to_tnx32_ref_vs_fma) { + for (const uint32_t m : {8, 128, 1024, 65536}) { + for (const double divisor : {double(1), double(m), double(0.5)}) { + CPLX* src = (CPLX*)spqlios_alloc_custom_align(32, 10 * m * sizeof(int32_t)); + int32_t* dst1 = (int32_t*)(src + m); + int32_t* dst2 = (int32_t*)(src + 2 * m); + for (uint64_t i = 0; i < 2 * m; ++i) { + src[i][0] = (rand() / double(RAND_MAX) - 0.5) * pow(2., 19 - (rand() % 60)) * divisor; + src[i][1] = (rand() / double(RAND_MAX) - 0.5) * pow(2., 19 - (rand() % 60)) * divisor; + } + CPLX_TO_TNX32_PRECOMP precomp; + precomp.m = m; + precomp.divisor = divisor; + cplx_to_tnx32_ref(&precomp, dst1, src); + cplx_to_tnx32_avx2_fma(&precomp, dst2, src); + // cplx_to_tnx32_simple(m, divisor, 18, dst2, src); + for (uint64_t i = 0; i < 2 * m; ++i) { + double truevalue = + (src[i % m][i / m] / divisor - floor(src[i % m][i / m] / divisor + 0.5)) * (INT64_C(1) << 32); + if (fabs(truevalue - floor(truevalue)) == 0.5) { + // ties can differ by 0, 1 or -1 + ASSERT_LE(abs(dst1[i] - dst2[i]), 0) + << i << " " << dst1[i] << " " << dst2[i] << " " << truevalue << std::endl; + } else { + // otherwise, we should have equality + ASSERT_LE(abs(dst1[i] - dst2[i]), 0) + << i << " " << dst1[i] << " " << dst2[i] << " " << truevalue << std::endl; + } + } + spqlios_free(src); + } + } +} +#endif diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_cplx_fft_bench.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_cplx_fft_bench.cpp new file mode 100644 index 0000000..19f4905 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_cplx_fft_bench.cpp @@ -0,0 +1,112 @@ +#include +#include + +#include +#include +#include 
+#include +#include +#include +#include +#include + +#include "../spqlios/cplx/cplx_fft_internal.h" +#include "spqlios/reim/reim_fft.h" + +using namespace std; + +void init_random_values(uint64_t n, double* v) { + for (uint64_t i = 0; i < n; ++i) v[i] = rand() - (RAND_MAX >> 1); +} + +void benchmark_cplx_fft(benchmark::State& state) { + const int32_t nn = state.range(0); + CPLX_FFT_PRECOMP* a = new_cplx_fft_precomp(nn / 2, 1); + double* c = (double*)cplx_fft_precomp_get_buffer(a, 0); + init_random_values(nn, c); + for (auto _ : state) { + // cplx_fft_simple(nn/2, c); + cplx_fft(a, c); + } + delete_cplx_fft_precomp(a); +} + +void benchmark_cplx_ifft(benchmark::State& state) { + const int32_t nn = state.range(0); + CPLX_IFFT_PRECOMP* a = new_cplx_ifft_precomp(nn / 2, 1); + double* c = (double*)cplx_ifft_precomp_get_buffer(a, 0); + init_random_values(nn, c); + for (auto _ : state) { + // cplx_ifft_simple(nn/2, c); + cplx_ifft(a, c); + } + delete_cplx_ifft_precomp(a); +} + +void benchmark_reim_fft(benchmark::State& state) { + const int32_t nn = state.range(0); + const uint32_t m = nn / 2; + REIM_FFT_PRECOMP* a = new_reim_fft_precomp(m, 1); + double* c = reim_fft_precomp_get_buffer(a, 0); + init_random_values(nn, c); + for (auto _ : state) { + // cplx_fft_simple(nn/2, c); + reim_fft(a, c); + } + delete_reim_fft_precomp(a); +} + +#ifdef __aarch64__ +EXPORT REIM_FFT_PRECOMP* new_reim_fft_precomp_neon(uint32_t m, uint32_t num_buffers); +EXPORT void reim_fft_neon(const REIM_FFT_PRECOMP* precomp, double* d); + +void benchmark_reim_fft_neon(benchmark::State& state) { + const int32_t nn = state.range(0); + const uint32_t m = nn / 2; + REIM_FFT_PRECOMP* a = new_reim_fft_precomp_neon(m, 1); + double* c = reim_fft_precomp_get_buffer(a, 0); + init_random_values(nn, c); + for (auto _ : state) { + // cplx_fft_simple(nn/2, c); + reim_fft_neon(a, c); + } + delete_reim_fft_precomp(a); +} +#endif + +void benchmark_reim_ifft(benchmark::State& state) { + const int32_t nn = state.range(0); + const uint32_t m = nn / 2; + REIM_IFFT_PRECOMP* a = new_reim_ifft_precomp(m, 1); + double* c = reim_ifft_precomp_get_buffer(a, 0); + init_random_values(nn, c); + for (auto _ : state) { + // cplx_ifft_simple(nn/2, c); + reim_ifft(a, c); + } + delete_reim_ifft_precomp(a); +} + +// #define ARGS Arg(1024)->Arg(8192)->Arg(32768)->Arg(65536) +#define ARGS Arg(64)->Arg(256)->Arg(1024)->Arg(2048)->Arg(4096)->Arg(8192)->Arg(16384)->Arg(32768)->Arg(65536) + +int main(int argc, char** argv) { + ::benchmark::Initialize(&argc, argv); + if (::benchmark::ReportUnrecognizedArguments(argc, argv)) return 1; + std::cout << "Dimensions n in the benchmark below are in \"real FFT\" modulo X^n+1" << std::endl; + std::cout << "The complex dimension m (modulo X^m-i) is half of it" << std::endl; + BENCHMARK(benchmark_cplx_fft)->ARGS; + BENCHMARK(benchmark_cplx_ifft)->ARGS; + BENCHMARK(benchmark_reim_fft)->ARGS; +#ifdef __aarch64__ + BENCHMARK(benchmark_reim_fft_neon)->ARGS; +#endif + BENCHMARK(benchmark_reim_ifft)->ARGS; + // if (CPU_SUPPORTS("avx512f")) { + // BENCHMARK(bench_cplx_fftvec_twiddle_avx512)->ARGS; + // BENCHMARK(bench_cplx_fftvec_bitwiddle_avx512)->ARGS; + //} + ::benchmark::RunSpecifiedBenchmarks(); + ::benchmark::Shutdown(); + return 0; +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_cplx_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_cplx_test.cpp new file mode 100644 index 0000000..13c5513 --- /dev/null +++ 
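As the banner printed in main() states, the benchmarked dimension n is that of the real negacyclic transform (modulo X^n+1), handled internally as a complex FFT of dimension m = n/2 (modulo X^m - i); in the reim layout, as the name suggests, the buffer stores the m real parts followed by the m imaginary parts. The snippet below is a minimal standalone driver distilled from benchmark_reim_fft, using only calls that appear in it; the dimension and the random filling are illustrative:

```c
#include <stdint.h>
#include <stdlib.h>

#include "spqlios/reim/reim_fft.h"

int main(void) {
  const uint32_t n = 1024;        /* real negacyclic dimension (mod X^n + 1) */
  const uint32_t m = n / 2;       /* complex dimension (mod X^m - i) */
  REIM_FFT_PRECOMP* tables = new_reim_fft_precomp(m, 1);  /* with 1 attached buffer */
  double* buf = reim_fft_precomp_get_buffer(tables, 0);   /* n doubles: re[0..m) then im[0..m) */
  for (uint32_t i = 0; i < n; ++i) buf[i] = (double)(rand() % 2048) - 1024;
  reim_fft(tables, buf);          /* in-place forward transform */
  delete_reim_fft_precomp(tables);
  return 0;
}
```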
b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_cplx_test.cpp @@ -0,0 +1,496 @@ +#include + +#include "gtest/gtest.h" +#include "spqlios/commons_private.h" +#include "spqlios/cplx/cplx_fft.h" +#include "spqlios/cplx/cplx_fft_internal.h" +#include "spqlios/cplx/cplx_fft_private.h" + +#ifdef __x86_64__ +TEST(fft, ifft16_fma_vs_ref) { + CPLX data[16]; + CPLX omega[8]; + for (uint64_t i = 0; i < 32; ++i) ((double*)data)[i] = 2 * i + 1; //(rand()%100)-50; + for (uint64_t i = 0; i < 16; ++i) ((double*)omega)[i] = i + 1; //(rand()%100)-50; + CPLX copydata[16]; + CPLX copyomega[8]; + memcpy(copydata, data, sizeof(copydata)); + memcpy(copyomega, omega, sizeof(copyomega)); + cplx_ifft16_avx_fma(data, omega); + cplx_ifft16_ref(copydata, copyomega); + double distance = 0; + for (uint64_t i = 0; i < 16; ++i) { + double d1 = fabs(data[i][0] - copydata[i][0]); + double d2 = fabs(data[i][0] - copydata[i][0]); + if (d1 > distance) distance = d1; + if (d2 > distance) distance = d2; + } + /* + printf("data:\n"); + for (uint64_t i=0; i<4; ++i) { + for (uint64_t j=0; j<8; ++j) { + printf("%.5lf ", data[4 * i + j / 2][j % 2]); + } + printf("\n"); + } + printf("copydata:\n"); + for (uint64_t i=0; i<4; ++i) { + for (uint64_t j=0; j<8; ++j) { + printf("%5.5lf ", copydata[4 * i + j / 2][j % 2]); + } + printf("\n"); + } + */ + ASSERT_EQ(distance, 0); +} + +#endif + +void cplx_zero(CPLX r) { r[0] = r[1] = 0; } +void cplx_addmul(CPLX r, const CPLX a, const CPLX b) { + double re = r[0] + a[0] * b[0] - a[1] * b[1]; + double im = r[1] + a[0] * b[1] + a[1] * b[0]; + r[0] = re; + r[1] = im; +} + +void halfcfft_eval(CPLX res, uint32_t nn, uint32_t k, const CPLX* coeffs, const CPLX* powomegas) { + const uint32_t N = nn / 2; + cplx_zero(res); + for (uint64_t i = 0; i < N; ++i) { + cplx_addmul(res, coeffs[i], powomegas[(k * i) % (2 * nn)]); + } +} +void halfcfft_naive(uint32_t nn, CPLX* data) { + const uint32_t N = nn / 2; + CPLX* in = (CPLX*)malloc(N * sizeof(CPLX)); + CPLX* powomega = (CPLX*)malloc(2 * nn * sizeof(CPLX)); + for (uint64_t i = 0; i < (2 * nn); ++i) { + powomega[i][0] = m_accurate_cos((M_PI * i) / nn); + powomega[i][1] = m_accurate_sin((M_PI * i) / nn); + } + memcpy(in, data, N * sizeof(CPLX)); + for (uint64_t j = 0; j < N; ++j) { + uint64_t p = rint(log2(N)) + 2; + uint64_t k = revbits(p, j) + 1; + halfcfft_eval(data[j], nn, k, in, powomega); + } + free(powomega); + free(in); +} + +#ifdef __x86_64__ +TEST(fft, fft16_fma_vs_ref) { + CPLX data[16]; + CPLX omega[8]; + for (uint64_t i = 0; i < 32; ++i) ((double*)data)[i] = rand() % 1000; + for (uint64_t i = 0; i < 16; ++i) ((double*)omega)[i] = rand() % 1000; + CPLX copydata[16]; + CPLX copyomega[8]; + memcpy(copydata, data, sizeof(copydata)); + memcpy(copyomega, omega, sizeof(copyomega)); + cplx_fft16_avx_fma(data, omega); + cplx_fft16_ref(copydata, omega); + double distance = 0; + for (uint64_t i = 0; i < 16; ++i) { + double d1 = fabs(data[i][0] - copydata[i][0]); + double d2 = fabs(data[i][0] - copydata[i][0]); + if (d1 > distance) distance = d1; + if (d2 > distance) distance = d2; + } + ASSERT_EQ(distance, 0); +} +#endif + +TEST(fft, citwiddle_then_invcitwiddle) { + CPLX om; + CPLX ombar; + CPLX data[2]; + CPLX copydata[2]; + om[0] = cos(3); + om[1] = sin(3); + ombar[0] = om[0]; + ombar[1] = -om[1]; + data[0][0] = 47; + data[0][1] = 23; + data[1][0] = -12; + data[1][1] = -9; + memcpy(copydata, data, sizeof(copydata)); + citwiddle(data[0], data[1], om); + invcitwiddle(data[0], data[1], ombar); + double distance = 0; + for (uint64_t i = 0; i < 
2; ++i) { + double d1 = fabs(data[i][0] - 2 * copydata[i][0]); + double d2 = fabs(data[i][1] - 2 * copydata[i][1]); + if (d1 > distance) distance = d1; + if (d2 > distance) distance = d2; + } + ASSERT_LE(distance, 1e-9); +} + +TEST(fft, ctwiddle_then_invctwiddle) { + CPLX om; + CPLX ombar; + CPLX data[2]; + CPLX copydata[2]; + om[0] = cos(3); + om[1] = sin(3); + ombar[0] = om[0]; + ombar[1] = -om[1]; + data[0][0] = 47; + data[0][1] = 23; + data[1][0] = -12; + data[1][1] = -9; + memcpy(copydata, data, sizeof(copydata)); + ctwiddle(data[0], data[1], om); + invctwiddle(data[0], data[1], ombar); + double distance = 0; + for (uint64_t i = 0; i < 2; ++i) { + double d1 = fabs(data[i][0] - 2 * copydata[i][0]); + double d2 = fabs(data[i][1] - 2 * copydata[i][1]); + if (d1 > distance) distance = d1; + if (d2 > distance) distance = d2; + } + ASSERT_LE(distance, 1e-9); +} + +TEST(fft, fft16_then_ifft16_ref) { + CPLX full_omegas[64]; + CPLX full_omegabars[64]; + for (uint64_t i = 0; i < 64; ++i) { + full_omegas[i][0] = cos(M_PI * i / 32.); + full_omegas[i][1] = sin(M_PI * i / 32.); + full_omegabars[i][0] = full_omegas[i][0]; + full_omegabars[i][1] = -full_omegas[i][1]; + } + CPLX omega[8]; + CPLX omegabar[8]; + cplx_set(omega[0], full_omegas[8]); // j + cplx_set(omega[1], full_omegas[4]); // k + cplx_set(omega[2], full_omegas[2]); // l + cplx_set(omega[3], full_omegas[10]); // lj + cplx_set(omega[4], full_omegas[1]); // n + cplx_set(omega[5], full_omegas[9]); // nj + cplx_set(omega[6], full_omegas[5]); // nk + cplx_set(omega[7], full_omegas[13]); // njk + cplx_set(omegabar[0], full_omegabars[1]); // n + cplx_set(omegabar[1], full_omegabars[9]); // nj + cplx_set(omegabar[2], full_omegabars[5]); // nk + cplx_set(omegabar[3], full_omegabars[13]); // njk + cplx_set(omegabar[4], full_omegabars[2]); // l + cplx_set(omegabar[5], full_omegabars[10]); // lj + cplx_set(omegabar[6], full_omegabars[4]); // k + cplx_set(omegabar[7], full_omegabars[8]); // j + CPLX data[16]; + CPLX copydata[16]; + for (uint64_t i = 0; i < 32; ++i) ((double*)data)[i] = rand() % 1000; + memcpy(copydata, data, sizeof(copydata)); + cplx_fft16_ref(data, omega); + cplx_ifft16_ref(data, omegabar); + double distance = 0; + for (uint64_t i = 0; i < 16; ++i) { + double d1 = fabs(data[i][0] - 16 * copydata[i][0]); + double d2 = fabs(data[i][0] - 16 * copydata[i][0]); + if (d1 > distance) distance = d1; + if (d2 > distance) distance = d2; + } + ASSERT_LE(distance, 1e-9); +} + +TEST(fft, halfcfft_ref_vs_naive) { + for (uint64_t nn : {4, 8, 16, 64, 256, 8192}) { + uint64_t m = nn / 2; + CPLX_FFT_PRECOMP* tables = new_cplx_fft_precomp(m, 0); + CPLX* a = (CPLX*)spqlios_alloc_custom_align(32, m * sizeof(CPLX)); + CPLX* a1 = (CPLX*)spqlios_alloc_custom_align(32, m * sizeof(CPLX)); + CPLX* a2 = (CPLX*)spqlios_alloc_custom_align(32, m * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < m; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + } + memcpy(a1, a, m * sizeof(CPLX)); + memcpy(a2, a, m * sizeof(CPLX)); + + halfcfft_naive(nn, a1); + cplx_fft_naive(m, 0.25, a2); + + double d = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + double dre = fabs(a1[i][0] - a2[i][0]); + double dim = fabs(a1[i][1] - a2[i][1]); + if (dre > d) d = dre; + if (dim > d) d = dim; + } + ASSERT_LE(d, nn * 1e-10) << nn; + spqlios_free(a); + spqlios_free(a1); + spqlios_free(a2); + delete_cplx_fft_precomp(tables); + } +} + +#ifdef __x86_64__ +TEST(fft, halfcfft_fma_vs_ref) { + typedef void (*FFTF)(const 
CPLX_FFT_PRECOMP*, void* data); + for (FFTF fft : {cplx_fft_ref, cplx_fft_avx2_fma}) { + for (uint64_t nn : {8, 16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + CPLX_FFT_PRECOMP* tables = new_cplx_fft_precomp(m, 0); + CPLX* a = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + CPLX* a1 = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + CPLX* a2 = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn / 2; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + } + memcpy(a1, a, nn / 2 * sizeof(CPLX)); + memcpy(a2, a, nn / 2 * sizeof(CPLX)); + cplx_fft_naive(m, 0.25, a2); + fft(tables, a1); + double d = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + double dre = fabs(a1[i][0] - a2[i][0]); + double dim = fabs(a1[i][1] - a2[i][1]); + if (dre > d) d = dre; + if (dim > d) d = dim; + } + ASSERT_LE(d, nn * 1e-10) << nn; + spqlios_free(a); + spqlios_free(a1); + spqlios_free(a2); + delete_cplx_fft_precomp(tables); + } + } +} +#endif + +TEST(fft, halfcfft_then_ifft_ref) { + for (uint64_t nn : {4, 8, 16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + CPLX_FFT_PRECOMP* tables = new_cplx_fft_precomp(m, 0); + CPLX_IFFT_PRECOMP* itables = new_cplx_ifft_precomp(m, 0); + CPLX* a = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + CPLX* a1 = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn / 2; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + } + memcpy(a1, a, nn / 2 * sizeof(CPLX)); + cplx_fft_ref(tables, a1); + cplx_ifft_ref(itables, a1); + double d = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + double dre = fabs(a[i][0] - a1[i][0] / (nn / 2)); + double dim = fabs(a[i][1] - a1[i][1] / (nn / 2)); + if (dre > d) d = dre; + if (dim > d) d = dim; + } + ASSERT_LE(d, 1e-8); + spqlios_free(a); + spqlios_free(a1); + delete_cplx_fft_precomp(tables); + delete_cplx_ifft_precomp(itables); + } +} + +#ifdef __x86_64__ +TEST(fft, halfcfft_ifft_fma_vs_ref) { + for (IFFT_FUNCTION ifft : {cplx_ifft_ref, cplx_ifft_avx2_fma}) { + for (uint64_t nn : {8, 16, 32, 1024, 4096, 8192, 65536}) { + uint64_t m = nn / 2; + CPLX_IFFT_PRECOMP* itables = new_cplx_ifft_precomp(m, 0); + CPLX* a = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + CPLX* a1 = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + CPLX* a2 = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn / 2; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + } + memcpy(a1, a, nn / 2 * sizeof(CPLX)); + memcpy(a2, a, nn / 2 * sizeof(CPLX)); + cplx_ifft_naive(m, 0.25, a2); + ifft(itables, a1); + double d = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + double dre = fabs(a1[i][0] - a2[i][0]); + double dim = fabs(a1[i][1] - a2[i][1]); + if (dre > d) d = dre; + if (dim > d) d = dim; + } + ASSERT_LE(d, 1e-8); + spqlios_free(a); + spqlios_free(a1); + spqlios_free(a2); + delete_cplx_ifft_precomp(itables); + } + } +} +#endif + +// test the reference and simple implementations of mul on all dimensions +TEST(fftvec, cplx_fftvec_mul_ref) { + for (uint64_t nn : {2, 4, 8, 16, 32, 1024, 4096, 8192, 65536}) { + uint64_t m = nn / 2; + CPLX_FFTVEC_MUL_PRECOMP* precomp = new_cplx_fftvec_mul_precomp(m); + CPLX* a = new CPLX[m]; + CPLX* b = new CPLX[m]; + CPLX* r0 = new CPLX[m]; 
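+ // r0: output of the simple API, r1: output of the ref kernel, r2: hand-computed expected values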
+ CPLX* r1 = new CPLX[m]; + CPLX* r2 = new CPLX[m]; + int64_t p = 1 << 16; + for (uint32_t i = 0; i < m; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + b[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + b[i][1] = (rand() % p) - p / 2; + r2[i][0] = r1[i][0] = r0[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + r2[i][1] = r1[i][1] = r0[i][1] = (rand() % p) - p / 2; + } + cplx_fftvec_mul_simple(m, r0, a, b); + cplx_fftvec_mul_ref(precomp, r1, a, b); + for (uint32_t i = 0; i < m; i++) { + r2[i][0] = a[i][0] * b[i][0] - a[i][1] * b[i][1]; + r2[i][1] = a[i][0] * b[i][1] + a[i][1] * b[i][0]; + ASSERT_LE(fabs(r1[i][0] - r2[i][0]) + fabs(r1[i][1] - r2[i][1]), 1e-8); + ASSERT_LE(fabs(r0[i][0] - r2[i][0]) + fabs(r0[i][1] - r2[i][1]), 1e-8); + } + delete[] a; + delete[] b; + delete[] r0; + delete[] r1; + delete[] r2; + delete_cplx_fftvec_mul_precomp(precomp); + } +} + +// test the reference and simple implementations of addmul on all dimensions +TEST(fftvec, cplx_fftvec_addmul_ref) { + for (uint64_t nn : {2, 4, 8, 16, 32, 1024, 4096, 8192, 65536}) { + uint64_t m = nn / 2; + CPLX_FFTVEC_ADDMUL_PRECOMP* precomp = new_cplx_fftvec_addmul_precomp(m); + CPLX* a = new CPLX[m]; + CPLX* b = new CPLX[m]; + CPLX* r0 = new CPLX[m]; + CPLX* r1 = new CPLX[m]; + CPLX* r2 = new CPLX[m]; + int64_t p = 1 << 16; + for (uint32_t i = 0; i < m; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + b[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + b[i][1] = (rand() % p) - p / 2; + r2[i][0] = r1[i][0] = r0[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + r2[i][1] = r1[i][1] = r0[i][1] = (rand() % p) - p / 2; + } + cplx_fftvec_addmul_simple(m, r0, a, b); + cplx_fftvec_addmul_ref(precomp, r1, a, b); + for (uint32_t i = 0; i < m; i++) { + r2[i][0] += a[i][0] * b[i][0] - a[i][1] * b[i][1]; + r2[i][1] += a[i][0] * b[i][1] + a[i][1] * b[i][0]; + ASSERT_LE(fabs(r1[i][0] - r2[i][0]) + fabs(r1[i][1] - r2[i][1]), 1e-8); + ASSERT_LE(fabs(r0[i][0] - r2[i][0]) + fabs(r0[i][1] - r2[i][1]), 1e-8); + } + delete[] a; + delete[] b; + delete[] r0; + delete[] r1; + delete[] r2; + delete_cplx_fftvec_addmul_precomp(precomp); + } +} + +// comparative tests between mul ref vs. 
optimized (only relevant dimensions) +TEST(fftvec, cplx_fftvec_mul_ref_vs_optim) { + struct totest { + FFTVEC_MUL_FUNCTION f; + uint64_t min_m; + totest(FFTVEC_MUL_FUNCTION f, uint64_t min_m) : f(f), min_m(min_m) {} + }; + std::vector totestset; + totestset.emplace_back(cplx_fftvec_mul, 1); +#ifdef __x86_64__ + totestset.emplace_back(cplx_fftvec_mul_fma, 8); +#endif + for (uint64_t m : {1, 2, 4, 8, 16, 1024, 4096, 8192, 65536}) { + CPLX_FFTVEC_MUL_PRECOMP* precomp = new_cplx_fftvec_mul_precomp(m); + for (const totest& t : totestset) { + if (t.min_m > m) continue; + CPLX* a = new CPLX[m]; + CPLX* b = new CPLX[m]; + CPLX* r1 = new CPLX[m]; + CPLX* r2 = new CPLX[m]; + int64_t p = 1 << 16; + for (uint32_t i = 0; i < m; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + b[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + b[i][1] = (rand() % p) - p / 2; + r2[i][0] = r1[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + r2[i][1] = r1[i][1] = (rand() % p) - p / 2; + } + t.f(precomp, r1, a, b); + cplx_fftvec_mul_ref(precomp, r2, a, b); + for (uint32_t i = 0; i < m; i++) { + double dre = fabs(r1[i][0] - r2[i][0]); + double dim = fabs(r1[i][1] - r2[i][1]); + ASSERT_LE(dre, 1e-8); + ASSERT_LE(dim, 1e-8); + } + delete[] a; + delete[] b; + delete[] r1; + delete[] r2; + } + delete_cplx_fftvec_mul_precomp(precomp); + } +} + +// comparative tests between addmul ref vs. optimized (only relevant dimensions) +TEST(fftvec, cplx_fftvec_addmul_ref_vs_optim) { + struct totest { + FFTVEC_ADDMUL_FUNCTION f; + uint64_t min_m; + totest(FFTVEC_ADDMUL_FUNCTION f, uint64_t min_m) : f(f), min_m(min_m) {} + }; + std::vector totestset; + totestset.emplace_back(cplx_fftvec_addmul, 1); +#ifdef __x86_64__ + totestset.emplace_back(cplx_fftvec_addmul_fma, 8); +#endif + for (uint64_t m : {1, 2, 4, 8, 16, 1024, 4096, 8192, 65536}) { + CPLX_FFTVEC_ADDMUL_PRECOMP* precomp = new_cplx_fftvec_addmul_precomp(m); + for (const totest& t : totestset) { + if (t.min_m > m) continue; + CPLX* a = new CPLX[m]; + CPLX* b = new CPLX[m]; + CPLX* r1 = new CPLX[m]; + CPLX* r2 = new CPLX[m]; + int64_t p = 1 << 16; + for (uint32_t i = 0; i < m; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + b[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + b[i][1] = (rand() % p) - p / 2; + r2[i][0] = r1[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + r2[i][1] = r1[i][1] = (rand() % p) - p / 2; + } + t.f(precomp, r1, a, b); + cplx_fftvec_addmul_ref(precomp, r2, a, b); + for (uint32_t i = 0; i < m; i++) { + double dre = fabs(r1[i][0] - r2[i][0]); + double dim = fabs(r1[i][1] - r2[i][1]); + ASSERT_LE(dre, 1e-8); + ASSERT_LE(dim, 1e-8); + } + delete[] a; + delete[] b; + delete[] r1; + delete[] r2; + } + delete_cplx_fftvec_addmul_precomp(precomp); + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_q120_arithmetic_bench.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_q120_arithmetic_bench.cpp new file mode 100644 index 0000000..b344503 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_q120_arithmetic_bench.cpp @@ -0,0 +1,136 @@ +#include + +#include + +#include "spqlios/q120/q120_arithmetic.h" + +#define ARGS Arg(128)->Arg(4096)->Arg(10000) + +template +void benchmark_baa(benchmark::State& state) { + const uint64_t ell = state.range(0); + q120_mat1col_product_baa_precomp* precomp = q120_new_vec_mat1col_product_baa_precomp(); + + uint64_t* a = new uint64_t[ell 
* 4]; + uint64_t* b = new uint64_t[ell * 4]; + uint64_t* c = new uint64_t[4]; + for (uint64_t i = 0; i < 4 * ell; i++) { + a[i] = rand(); + b[i] = rand(); + } + for (auto _ : state) { + f(precomp, ell, (q120b*)c, (q120a*)a, (q120a*)b); + } + delete[] c; + delete[] b; + delete[] a; + q120_delete_vec_mat1col_product_baa_precomp(precomp); +} + +BENCHMARK(benchmark_baa)->Name("q120_vec_mat1col_product_baa_ref")->ARGS; +BENCHMARK(benchmark_baa)->Name("q120_vec_mat1col_product_baa_avx2")->ARGS; + +template +void benchmark_bbb(benchmark::State& state) { + const uint64_t ell = state.range(0); + q120_mat1col_product_bbb_precomp* precomp = q120_new_vec_mat1col_product_bbb_precomp(); + + uint64_t* a = new uint64_t[ell * 4]; + uint64_t* b = new uint64_t[ell * 4]; + uint64_t* c = new uint64_t[4]; + for (uint64_t i = 0; i < 4 * ell; i++) { + a[i] = rand(); + b[i] = rand(); + } + for (auto _ : state) { + f(precomp, ell, (q120b*)c, (q120b*)a, (q120b*)b); + } + delete[] c; + delete[] b; + delete[] a; + q120_delete_vec_mat1col_product_bbb_precomp(precomp); +} + +BENCHMARK(benchmark_bbb)->Name("q120_vec_mat1col_product_bbb_ref")->ARGS; +BENCHMARK(benchmark_bbb)->Name("q120_vec_mat1col_product_bbb_avx2")->ARGS; + +template +void benchmark_bbc(benchmark::State& state) { + const uint64_t ell = state.range(0); + q120_mat1col_product_bbc_precomp* precomp = q120_new_vec_mat1col_product_bbc_precomp(); + + uint64_t* a = new uint64_t[ell * 4]; + uint64_t* b = new uint64_t[ell * 4]; + uint64_t* c = new uint64_t[4]; + for (uint64_t i = 0; i < 4 * ell; i++) { + a[i] = rand(); + b[i] = rand(); + } + for (auto _ : state) { + f(precomp, ell, (q120b*)c, (q120b*)a, (q120c*)b); + } + delete[] c; + delete[] b; + delete[] a; + q120_delete_vec_mat1col_product_bbc_precomp(precomp); +} + +BENCHMARK(benchmark_bbc)->Name("q120_vec_mat1col_product_bbc_ref")->ARGS; +BENCHMARK(benchmark_bbc)->Name("q120_vec_mat1col_product_bbc_avx2")->ARGS; + +EXPORT void q120x2_vec_mat2cols_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y); +EXPORT void q120x2_vec_mat1col_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y); + +template +void benchmark_x2c2_bbc(benchmark::State& state) { + const uint64_t ell = state.range(0); + q120_mat1col_product_bbc_precomp* precomp = q120_new_vec_mat1col_product_bbc_precomp(); + + uint64_t* a = new uint64_t[ell * 8]; + uint64_t* b = new uint64_t[ell * 16]; + uint64_t* c = new uint64_t[16]; + for (uint64_t i = 0; i < 8 * ell; i++) { + a[i] = rand(); + } + for (uint64_t i = 0; i < 16 * ell; i++) { + b[i] = rand(); + } + for (auto _ : state) { + f(precomp, ell, (q120b*)c, (q120b*)a, (q120c*)b); + } + delete[] c; + delete[] b; + delete[] a; + q120_delete_vec_mat1col_product_bbc_precomp(precomp); +} + +BENCHMARK(benchmark_x2c2_bbc)->Name("q120x2_vec_mat2col_product_bbc_avx2")->ARGS; + +template +void benchmark_x2c1_bbc(benchmark::State& state) { + const uint64_t ell = state.range(0); + q120_mat1col_product_bbc_precomp* precomp = q120_new_vec_mat1col_product_bbc_precomp(); + + uint64_t* a = new uint64_t[ell * 8]; + uint64_t* b = new uint64_t[ell * 8]; + uint64_t* c = new uint64_t[8]; + for (uint64_t i = 0; i < 8 * ell; i++) { + a[i] = rand(); + } + for (uint64_t i = 0; i < 8 * ell; i++) { + b[i] = rand(); + } + for (auto _ : state) { + f(precomp, ell, (q120b*)c, (q120b*)a, (q120c*)b); + } + delete[] c; + delete[] b; + delete[] a; + 
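+  // (editor note) the q120x2 benchmarks above operate on pairs of q120 values
+  // interleaved in one slot (8 uint64 per element); the mat2cols flavour also
+  // consumes two matrix columns at once, which is why it reads ell * 16 input
+  // words and accumulates into a 16-word result (2 columns x 2 interleaved
+  // values x 4 limbs), while the mat1col flavour uses 8-word slots throughout.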
q120_delete_vec_mat1col_product_bbc_precomp(precomp); +} + +BENCHMARK(benchmark_x2c1_bbc)->Name("q120x2_vec_mat1col_product_bbc_avx2")->ARGS; + +BENCHMARK_MAIN(); diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_q120_arithmetic_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_q120_arithmetic_test.cpp new file mode 100644 index 0000000..ca261ff --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_q120_arithmetic_test.cpp @@ -0,0 +1,437 @@ +#include + +#include +#include + +#include "spqlios/q120/q120_arithmetic.h" +#include "test/testlib/negacyclic_polynomial.h" +#include "test/testlib/ntt120_layouts.h" +#include "testlib/mod_q120.h" + +typedef typeof(q120_vec_mat1col_product_baa_ref) vec_mat1col_product_baa_f; + +void test_vec_mat1col_product_baa(vec_mat1col_product_baa_f vec_mat1col_product_baa) { + q120_mat1col_product_baa_precomp* precomp = q120_new_vec_mat1col_product_baa_precomp(); + for (uint64_t ell : {1, 2, 100, 10000}) { + std::vector a(ell * 4); + std::vector b(ell * 4); + std::vector res(4); + uint64_t* pa = a.data(); + uint64_t* pb = b.data(); + uint64_t* pr = res.data(); + // generate some random data + uniform_q120b(pr); + for (uint64_t i = 0; i < ell; ++i) { + uniform_q120a(pa + 4 * i); + uniform_q120a(pb + 4 * i); + } + // compute the expected result + mod_q120 expect_r; + for (uint64_t i = 0; i < ell; ++i) { + expect_r += mod_q120::from_q120a(pa + 4 * i) * mod_q120::from_q120a(pb + 4 * i); + } + // compute the function + vec_mat1col_product_baa(precomp, ell, (q120b*)pr, (q120a*)pa, (q120a*)pb); + mod_q120 comp_r = mod_q120::from_q120b(pr); + // check for equality + ASSERT_EQ(comp_r, expect_r) << ell; + } + q120_delete_vec_mat1col_product_baa_precomp(precomp); +} + +TEST(q120_arithmetic, q120_vec_mat1col_product_baa_ref) { + test_vec_mat1col_product_baa(q120_vec_mat1col_product_baa_ref); +} +#ifdef __x86_64__ +TEST(q120_arithmetic, q120_vec_mat1col_product_baa_avx2) { + test_vec_mat1col_product_baa(q120_vec_mat1col_product_baa_avx2); +} +#endif + +typedef typeof(q120_vec_mat1col_product_bbb_ref) vec_mat1col_product_bbb_f; + +void test_vec_mat1col_product_bbb(vec_mat1col_product_bbb_f vec_mat1col_product_bbb) { + q120_mat1col_product_bbb_precomp* precomp = q120_new_vec_mat1col_product_bbb_precomp(); + for (uint64_t ell : {1, 2, 100, 10000}) { + std::vector a(ell * 4); + std::vector b(ell * 4); + std::vector res(4); + uint64_t* pa = a.data(); + uint64_t* pb = b.data(); + uint64_t* pr = res.data(); + // generate some random data + uniform_q120b(pr); + for (uint64_t i = 0; i < ell; ++i) { + uniform_q120b(pa + 4 * i); + uniform_q120b(pb + 4 * i); + } + // compute the expected result + mod_q120 expect_r; + for (uint64_t i = 0; i < ell; ++i) { + expect_r += mod_q120::from_q120b(pa + 4 * i) * mod_q120::from_q120b(pb + 4 * i); + } + // compute the function + vec_mat1col_product_bbb(precomp, ell, (q120b*)pr, (q120b*)pa, (q120b*)pb); + mod_q120 comp_r = mod_q120::from_q120b(pr); + // check for equality + ASSERT_EQ(comp_r, expect_r); + } + q120_delete_vec_mat1col_product_bbb_precomp(precomp); +} + +TEST(q120_arithmetic, q120_vec_mat1col_product_bbb_ref) { + test_vec_mat1col_product_bbb(q120_vec_mat1col_product_bbb_ref); +} +#ifdef __x86_64__ +TEST(q120_arithmetic, q120_vec_mat1col_product_bbb_avx2) { + test_vec_mat1col_product_bbb(q120_vec_mat1col_product_bbb_avx2); +} +#endif + +typedef typeof(q120_vec_mat1col_product_bbc_ref) vec_mat1col_product_bbc_f; + +void 
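+// (editor note) same check as the baa/bbb tests above, with the column
+// operand supplied in the c layout; q120_c_from_b_simple, exercised further
+// down in this file, converts between the b and c layouts.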
test_vec_mat1col_product_bbc(vec_mat1col_product_bbc_f vec_mat1col_product_bbc) { + q120_mat1col_product_bbc_precomp* precomp = q120_new_vec_mat1col_product_bbc_precomp(); + for (uint64_t ell : {1, 2, 100, 10000}) { + std::vector a(ell * 4); + std::vector b(ell * 4); + std::vector res(4); + uint64_t* pa = a.data(); + uint64_t* pb = b.data(); + uint64_t* pr = res.data(); + // generate some random data + uniform_q120b(pr); + for (uint64_t i = 0; i < ell; ++i) { + uniform_q120b(pa + 4 * i); + uniform_q120c(pb + 4 * i); + } + // compute the expected result + mod_q120 expect_r; + for (uint64_t i = 0; i < ell; ++i) { + expect_r += mod_q120::from_q120b(pa + 4 * i) * mod_q120::from_q120c(pb + 4 * i); + } + // compute the function + vec_mat1col_product_bbc(precomp, ell, (q120b*)pr, (q120b*)pa, (q120c*)pb); + mod_q120 comp_r = mod_q120::from_q120b(pr); + // check for equality + ASSERT_EQ(comp_r, expect_r); + } + q120_delete_vec_mat1col_product_bbc_precomp(precomp); +} + +TEST(q120_arithmetic, q120_vec_mat1col_product_bbc_ref) { + test_vec_mat1col_product_bbc(q120_vec_mat1col_product_bbc_ref); +} +#ifdef __x86_64__ +TEST(q120_arithmetic, q120_vec_mat1col_product_bbc_avx2) { + test_vec_mat1col_product_bbc(q120_vec_mat1col_product_bbc_avx2); +} +#endif + +EXPORT void q120x2_vec_mat2cols_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y); +EXPORT void q120x2_vec_mat1col_product_bbc_avx2(q120_mat1col_product_bbc_precomp* precomp, const uint64_t ell, + q120b* const res, const q120b* const x, const q120c* const y); + +typedef typeof(q120x2_vec_mat2cols_product_bbc_avx2) q120x2_prod_bbc_f; + +void test_q120x2_vec_mat2cols_product_bbc(q120x2_prod_bbc_f q120x2_prod_bbc) { + q120_mat1col_product_bbc_precomp* precomp = q120_new_vec_mat1col_product_bbc_precomp(); + for (uint64_t ell : {1, 2, 100, 10000}) { + std::vector a(ell * 8); + std::vector b(ell * 16); + std::vector res(16); + uint64_t* pa = a.data(); + uint64_t* pb = b.data(); + uint64_t* pr = res.data(); + // generate some random data + uniform_q120b(pr); + for (uint64_t i = 0; i < 2 * ell; ++i) { + uniform_q120b(pa + 4 * i); + } + for (uint64_t i = 0; i < 4 * ell; ++i) { + uniform_q120c(pb + 4 * i); + } + // compute the expected result + mod_q120 expect_r[4]; + for (uint64_t i = 0; i < ell; ++i) { + mod_q120 va = mod_q120::from_q120b(pa + 8 * i); + mod_q120 vb = mod_q120::from_q120b(pa + 8 * i + 4); + mod_q120 m1a = mod_q120::from_q120c(pb + 16 * i); + mod_q120 m1b = mod_q120::from_q120c(pb + 16 * i + 4); + mod_q120 m2a = mod_q120::from_q120c(pb + 16 * i + 8); + mod_q120 m2b = mod_q120::from_q120c(pb + 16 * i + 12); + expect_r[0] += va * m1a; + expect_r[1] += vb * m1b; + expect_r[2] += va * m2a; + expect_r[3] += vb * m2b; + } + // compute the function + q120x2_prod_bbc(precomp, ell, (q120b*)pr, (q120b*)pa, (q120c*)pb); + // check for equality + ASSERT_EQ(mod_q120::from_q120b(pr), expect_r[0]); + ASSERT_EQ(mod_q120::from_q120b(pr + 4), expect_r[1]); + ASSERT_EQ(mod_q120::from_q120b(pr + 8), expect_r[2]); + ASSERT_EQ(mod_q120::from_q120b(pr + 12), expect_r[3]); + } + q120_delete_vec_mat1col_product_bbc_precomp(precomp); +} + +TEST(q120_arithmetic, q120x2_vec_mat2cols_product_bbc_ref) { + test_q120x2_vec_mat2cols_product_bbc(q120x2_vec_mat2cols_product_bbc_ref); +} +#ifdef __x86_64__ +TEST(q120_arithmetic, q120x2_vec_mat2cols_product_bbc_avx2) { + test_q120x2_vec_mat2cols_product_bbc(q120x2_vec_mat2cols_product_bbc_avx2); +} +#endif + +typedef 
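+// (editor note) `typeof` is a GNU extension, used consistently throughout this
+// test suite to alias function types; `decltype(...)` would be the portable
+// C++ spelling.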
typeof(q120x2_vec_mat1col_product_bbc_avx2) q120x2_c1_prod_bbc_f; + +void test_q120x2_vec_mat1col_product_bbc(q120x2_c1_prod_bbc_f q120x2_c1_prod_bbc) { + q120_mat1col_product_bbc_precomp* precomp = q120_new_vec_mat1col_product_bbc_precomp(); + for (uint64_t ell : {1, 2, 100, 10000}) { + std::vector a(ell * 8); + std::vector b(ell * 8); + std::vector res(8); + uint64_t* pa = a.data(); + uint64_t* pb = b.data(); + uint64_t* pr = res.data(); + // generate some random data + uniform_q120b(pr); + for (uint64_t i = 0; i < 2 * ell; ++i) { + uniform_q120b(pa + 4 * i); + } + for (uint64_t i = 0; i < 2 * ell; ++i) { + uniform_q120c(pb + 4 * i); + } + // compute the expected result + mod_q120 expect_r[2]; + for (uint64_t i = 0; i < ell; ++i) { + mod_q120 va = mod_q120::from_q120b(pa + 8 * i); + mod_q120 vb = mod_q120::from_q120b(pa + 8 * i + 4); + mod_q120 m1a = mod_q120::from_q120c(pb + 8 * i); + mod_q120 m1b = mod_q120::from_q120c(pb + 8 * i + 4); + expect_r[0] += va * m1a; + expect_r[1] += vb * m1b; + } + // compute the function + q120x2_c1_prod_bbc(precomp, ell, (q120b*)pr, (q120b*)pa, (q120c*)pb); + // check for equality + ASSERT_EQ(mod_q120::from_q120b(pr), expect_r[0]); + ASSERT_EQ(mod_q120::from_q120b(pr + 4), expect_r[1]); + } + q120_delete_vec_mat1col_product_bbc_precomp(precomp); +} + +TEST(q120_arithmetic, q120x2_vec_mat1col_product_bbc_ref) { + test_q120x2_vec_mat1col_product_bbc(q120x2_vec_mat1col_product_bbc_ref); +} +#ifdef __x86_64__ +TEST(q120_arithmetic, q120x2_vec_mat1col_product_bbc_avx2) { + test_q120x2_vec_mat1col_product_bbc(q120x2_vec_mat1col_product_bbc_avx2); +} +#endif + +typedef typeof(q120x2_extract_1blk_from_q120b_ref) q120x2_extract_f; +void test_q120x2_extract_1blk(q120x2_extract_f q120x2_extract) { + for (uint64_t n : {2, 4, 64}) { + ntt120_vec_znx_dft_layout v(n, 1); + std::vector r(8); + std::vector expect(8); + for (uint64_t blk = 0; blk < n / 2; ++blk) { + for (uint64_t i = 0; i < 8; ++i) { + expect[i] = uniform_u64(); + } + memcpy(v.get_blk(0, blk), expect.data(), 8 * sizeof(uint64_t)); + q120x2_extract_1blk_from_q120b_ref(n, blk, (q120x2b*)r.data(), (q120b*)v.data); + ASSERT_EQ(r, expect); + } + } +} + +TEST(q120_arithmetic, q120x2_extract_1blk_from_q120b_ref) { + test_q120x2_extract_1blk(q120x2_extract_1blk_from_q120b_ref); +} + +typedef typeof(q120x2_extract_1blk_from_contiguous_q120b_ref) q120x2_extract_vec_f; +void test_q120x2_extract_1blk_vec(q120x2_extract_vec_f q120x2_extract) { + for (uint64_t n : {2, 4, 32}) { + for (uint64_t size : {1, 2, 7}) { + ntt120_vec_znx_dft_layout v(n, size); + std::vector r(8 * size); + std::vector expect(8 * size); + for (uint64_t blk = 0; blk < n / 2; ++blk) { + for (uint64_t i = 0; i < 8 * size; ++i) { + expect[i] = uniform_u64(); + } + for (uint64_t i = 0; i < size; ++i) { + memcpy(v.get_blk(i, blk), expect.data() + 8 * i, 8 * sizeof(uint64_t)); + } + q120x2_extract(n, size, blk, (q120x2b*)r.data(), (q120b*)v.data); + ASSERT_EQ(r, expect); + } + } + } +} + +TEST(q120_arithmetic, q120x2_extract_1blk_from_contiguous_q120b_ref) { + test_q120x2_extract_1blk_vec(q120x2_extract_1blk_from_contiguous_q120b_ref); +} + +typedef typeof(q120x2b_save_1blk_to_q120b_ref) q120x2_save_f; +void test_q120x2_save_1blk(q120x2_save_f q120x2_save) { + for (uint64_t n : {2, 4, 64}) { + ntt120_vec_znx_dft_layout v(n, 1); + std::vector r(8); + std::vector expect(8); + for (uint64_t blk = 0; blk < n / 2; ++blk) { + for (uint64_t i = 0; i < 8; ++i) { + expect[i] = uniform_u64(); + } + q120x2_save(n, blk, (q120b*)v.data, (q120x2b*)expect.data()); + 
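+      // (editor note) a "blk" in the ntt120_vec_znx_dft_layout appears to
+      // pack two adjacent coefficients x 4 limbs = 8 uint64 words, which is
+      // why exactly 8 words are copied in and out around each extract/save
+      // call in these tests.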
memcpy(r.data(), v.get_blk(0, blk), 8 * sizeof(uint64_t)); + ASSERT_EQ(r, expect); + } + } +} + +TEST(q120_arithmetic, q120x2b_save_1blk_to_q120b_ref) { test_q120x2_save_1blk(q120x2b_save_1blk_to_q120b_ref); } + +TEST(q120_arithmetic, q120_add_bbb_simple) { + for (const uint64_t n : {2, 4, 1024}) { + std::vector a(n * 4); + std::vector b(n * 4); + std::vector r(n * 4); + uint64_t* pa = a.data(); + uint64_t* pb = b.data(); + uint64_t* pr = r.data(); + + // generate some random data + for (uint64_t i = 0; i < n; ++i) { + uniform_q120b(pa + 4 * i); + uniform_q120b(pb + 4 * i); + } + + // compute the function + q120_add_bbb_simple(n, (q120b*)pr, (q120b*)pa, (q120b*)pb); + + for (uint64_t i = 0; i < n; ++i) { + mod_q120 ae = mod_q120::from_q120b(pa + 4 * i); + mod_q120 be = mod_q120::from_q120b(pb + 4 * i); + mod_q120 re = mod_q120::from_q120b(pr + 4 * i); + + ASSERT_EQ(ae + be, re); + } + } +} + +TEST(q120_arithmetic, q120_add_ccc_simple) { + for (const uint64_t n : {2, 4, 1024}) { + std::vector a(n * 4); + std::vector b(n * 4); + std::vector r(n * 4); + uint64_t* pa = a.data(); + uint64_t* pb = b.data(); + uint64_t* pr = r.data(); + + // generate some random data + for (uint64_t i = 0; i < n; ++i) { + uniform_q120c(pa + 4 * i); + uniform_q120c(pb + 4 * i); + } + + // compute the function + q120_add_ccc_simple(n, (q120c*)pr, (q120c*)pa, (q120c*)pb); + + for (uint64_t i = 0; i < n; ++i) { + mod_q120 ae = mod_q120::from_q120c(pa + 4 * i); + mod_q120 be = mod_q120::from_q120c(pb + 4 * i); + mod_q120 re = mod_q120::from_q120c(pr + 4 * i); + + ASSERT_EQ(ae + be, re); + } + } +} + +TEST(q120_arithmetic, q120_c_from_b_simple) { + for (const uint64_t n : {2, 4, 1024}) { + std::vector a(n * 4); + std::vector r(n * 4); + uint64_t* pa = a.data(); + uint64_t* pr = r.data(); + + // generate some random data + for (uint64_t i = 0; i < n; ++i) { + uniform_q120b(pa + 4 * i); + } + + // compute the function + q120_c_from_b_simple(n, (q120c*)pr, (q120b*)pa); + + for (uint64_t i = 0; i < n; ++i) { + mod_q120 ae = mod_q120::from_q120b(pa + 4 * i); + mod_q120 re = mod_q120::from_q120c(pr + 4 * i); + + ASSERT_EQ(ae, re); + } + } +} + +TEST(q120_arithmetic, q120_b_from_znx64_simple) { + for (const uint64_t n : {2, 4, 1024}) { + znx_i64 x = znx_i64::random_log2bound(n, 62); + std::vector r(n * 4); + uint64_t* pr = r.data(); + + q120_b_from_znx64_simple(n, (q120b*)pr, x.data()); + + for (uint64_t i = 0; i < n; ++i) { + mod_q120 re = mod_q120::from_q120b(pr + 4 * i); + + for (uint64_t k = 0; k < 4; ++k) { + ASSERT_EQ(centermod(x.get_coeff(i), mod_q120::Qi[k]), re.a[k]); + } + } + } +} + +TEST(q120_arithmetic, q120_c_from_znx64_simple) { + for (const uint64_t n : {2, 4, 1024}) { + znx_i64 x = znx_i64::random(n); + std::vector r(n * 4); + uint64_t* pr = r.data(); + + q120_c_from_znx64_simple(n, (q120c*)pr, x.data()); + + for (uint64_t i = 0; i < n; ++i) { + mod_q120 re = mod_q120::from_q120c(pr + 4 * i); + + for (uint64_t k = 0; k < 4; ++k) { + ASSERT_EQ(centermod(x.get_coeff(i), mod_q120::Qi[k]), re.a[k]); + } + } + } +} + +TEST(q120_arithmetic, q120_b_to_znx128_simple) { + for (const uint64_t n : {2, 4, 1024}) { + std::vector x(n * 4); + uint64_t* px = x.data(); + + // generate some random data + for (uint64_t i = 0; i < n; ++i) { + uniform_q120b(px + 4 * i); + } + + znx_i128 r(n); + q120_b_to_znx128_simple(n, r.data(), (q120b*)px); + + for (uint64_t i = 0; i < n; ++i) { + mod_q120 xe = mod_q120::from_q120b(px + 4 * i); + for (uint64_t k = 0; k < 4; ++k) { + ASSERT_EQ(centermod((int64_t)(r.get_coeff(i) % 
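+        // (editor note) the conversion helpers are checked residue by
+        // residue: the *_from_znx64 tests above reduce each signed coefficient
+        // modulo every prime Qi[k], and this b_to_znx128 test lifts back to a
+        // signed 128-bit integer whose centered residues must match, hence the
+        // centermod comparison.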
mod_q120::Qi[k]), mod_q120::Qi[k]), xe.a[k]); + } + } + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_q120_ntt_bench.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_q120_ntt_bench.cpp new file mode 100644 index 0000000..50e1f15 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_q120_ntt_bench.cpp @@ -0,0 +1,44 @@ +#include + +#include + +#include "spqlios/q120/q120_ntt.h" + +#define ARGS Arg(1 << 10)->Arg(1 << 11)->Arg(1 << 12)->Arg(1 << 13)->Arg(1 << 14)->Arg(1 << 15)->Arg(1 << 16) + +template +void benchmark_ntt(benchmark::State& state) { + const uint64_t n = state.range(0); + q120_ntt_precomp* precomp = q120_new_ntt_bb_precomp(n); + + uint64_t* px = new uint64_t[n * 4]; + for (uint64_t i = 0; i < 4 * n; i++) { + px[i] = (rand() << 31) + rand(); + } + for (auto _ : state) { + f(precomp, (q120b*)px); + } + delete[] px; + q120_del_ntt_bb_precomp(precomp); +} + +template +void benchmark_intt(benchmark::State& state) { + const uint64_t n = state.range(0); + q120_ntt_precomp* precomp = q120_new_intt_bb_precomp(n); + + uint64_t* px = new uint64_t[n * 4]; + for (uint64_t i = 0; i < 4 * n; i++) { + px[i] = (rand() << 31) + rand(); + } + for (auto _ : state) { + f(precomp, (q120b*)px); + } + delete[] px; + q120_del_intt_bb_precomp(precomp); +} + +BENCHMARK(benchmark_ntt)->Name("q120_ntt_bb_avx2")->ARGS; +BENCHMARK(benchmark_intt)->Name("q120_intt_bb_avx2")->ARGS; + +BENCHMARK_MAIN(); diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_q120_ntt_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_q120_ntt_test.cpp new file mode 100644 index 0000000..61c9dd6 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_q120_ntt_test.cpp @@ -0,0 +1,174 @@ +#include + +#include +#include +#include +#include +#include + +#include "spqlios/q120/q120_common.h" +#include "spqlios/q120/q120_ntt.h" +#include "testlib/mod_q120.h" + +std::vector q120_ntt(const std::vector& x) { + const uint64_t n = x.size(); + + mod_q120 omega_2pow17{OMEGA1, OMEGA2, OMEGA3, OMEGA4}; + mod_q120 omega = pow(omega_2pow17, (1 << 16) / n); + + std::vector res(n); + for (uint64_t i = 0; i < n; ++i) { + res[i] = x[i]; + } + + for (uint64_t i = 0; i < n; ++i) { + res[i] = res[i] * pow(omega, i); + } + + for (uint64_t nn = n; nn > 1; nn /= 2) { + const uint64_t halfnn = nn / 2; + const uint64_t m = n / halfnn; + + for (uint64_t j = 0; j < n; j += nn) { + for (uint64_t k = 0; k < halfnn; ++k) { + mod_q120 a = res[j + k]; + mod_q120 b = res[j + halfnn + k]; + + res[j + k] = a + b; + res[j + halfnn + k] = (a - b) * pow(omega, k * m); + } + } + } + + return res; +} + +std::vector q120_intt(const std::vector& x) { + const uint64_t n = x.size(); + + mod_q120 omega_2pow17{OMEGA1, OMEGA2, OMEGA3, OMEGA4}; + mod_q120 omega = pow(omega_2pow17, (1 << 16) / n); + + std::vector res(n); + for (uint64_t i = 0; i < n; ++i) { + res[i] = x[i]; + } + + for (uint64_t nn = 2; nn <= n; nn *= 2) { + const uint64_t halfnn = nn / 2; + const uint64_t m = n / halfnn; + + for (uint64_t j = 0; j < n; j += nn) { + for (uint64_t k = 0; k < halfnn; ++k) { + mod_q120 a = res[j + k]; + mod_q120 b = res[j + halfnn + k]; + + mod_q120 bo = b * pow(omega, -k * m); + res[j + k] = a + bo; + res[j + halfnn + k] = a - bo; + } + } + } + + mod_q120 n_q120{(int64_t)n, (int64_t)n, (int64_t)n, (int64_t)n}; + mod_q120 n_inv_q120 = pow(n_q120, -1); + + for (uint64_t i = 0; i < n; ++i) { + mod_q120 po = pow(omega, -i) * n_inv_q120; 
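+    // (editor note) this final step undoes the omega^i pre-twist applied in
+    // q120_ntt above and folds in the 1/n scaling. The reference pair looks
+    // like the usual construction for a length-n negacyclic transform: twist
+    // by a 2n-th root of unity, then run cyclic butterfly passes, with all
+    // arithmetic done componentwise on the four residues (OMEGA1..OMEGA4).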
+ res[i] = res[i] * po; + } + + return res; +} + +class ntt : public testing::TestWithParam {}; + +#ifdef __x86_64__ + +TEST_P(ntt, q120_ntt_bb_avx2) { + const uint64_t n = GetParam(); + q120_ntt_precomp* precomp = q120_new_ntt_bb_precomp(n); + + std::vector x(n * 4); + uint64_t* px = x.data(); + for (uint64_t i = 0; i < 4 * n; i += 4) { + uniform_q120b(px + i); + } + + std::vector x_modq(n); + for (uint64_t i = 0; i < n; ++i) { + x_modq[i] = mod_q120::from_q120b(px + 4 * i); + } + + std::vector y_exp = q120_ntt(x_modq); + + q120_ntt_bb_avx2(precomp, (q120b*)px); + + for (uint64_t i = 0; i < n; ++i) { + mod_q120 comp_r = mod_q120::from_q120b(px + 4 * i); + ASSERT_EQ(comp_r, y_exp[i]) << i; + } + + q120_del_ntt_bb_precomp(precomp); +} + +TEST_P(ntt, q120_intt_bb_avx2) { + const uint64_t n = GetParam(); + q120_ntt_precomp* precomp = q120_new_intt_bb_precomp(n); + + std::vector x(n * 4); + uint64_t* px = x.data(); + for (uint64_t i = 0; i < 4 * n; i += 4) { + uniform_q120b(px + i); + } + + std::vector x_modq(n); + for (uint64_t i = 0; i < n; ++i) { + x_modq[i] = mod_q120::from_q120b(px + 4 * i); + } + + q120_intt_bb_avx2(precomp, (q120b*)px); + + std::vector y_exp = q120_intt(x_modq); + for (uint64_t i = 0; i < n; ++i) { + mod_q120 comp_r = mod_q120::from_q120b(px + 4 * i); + ASSERT_EQ(comp_r, y_exp[i]) << i; + } + + q120_del_intt_bb_precomp(precomp); +} + +TEST_P(ntt, q120_ntt_intt_bb_avx2) { + const uint64_t n = GetParam(); + q120_ntt_precomp* precomp_ntt = q120_new_ntt_bb_precomp(n); + q120_ntt_precomp* precomp_intt = q120_new_intt_bb_precomp(n); + + std::vector x(n * 4); + uint64_t* px = x.data(); + for (uint64_t i = 0; i < 4 * n; i += 4) { + uniform_q120b(px + i); + } + + std::vector x_modq(n); + for (uint64_t i = 0; i < n; ++i) { + x_modq[i] = mod_q120::from_q120b(px + 4 * i); + } + + q120_ntt_bb_avx2(precomp_ntt, (q120b*)px); + q120_intt_bb_avx2(precomp_intt, (q120b*)px); + + for (uint64_t i = 0; i < n; ++i) { + mod_q120 comp_r = mod_q120::from_q120b(px + 4 * i); + ASSERT_EQ(comp_r, x_modq[i]) << i; + } + + q120_del_intt_bb_precomp(precomp_intt); + q120_del_ntt_bb_precomp(precomp_ntt); +} + +INSTANTIATE_TEST_SUITE_P(q120, ntt, + testing::Values(1, 2, 4, 16, 256, UINT64_C(1) << 10, UINT64_C(1) << 11, UINT64_C(1) << 12, + UINT64_C(1) << 13, UINT64_C(1) << 14, UINT64_C(1) << 15, UINT64_C(1) << 16), + testing::PrintToStringParamName()); + +#endif diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_reim4_arithmetic_bench.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_reim4_arithmetic_bench.cpp new file mode 100644 index 0000000..2acd7df --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_reim4_arithmetic_bench.cpp @@ -0,0 +1,52 @@ +#include + +#include "spqlios/reim4/reim4_arithmetic.h" + +void init_random_values(uint64_t n, double* v) { + for (uint64_t i = 0; i < n; ++i) + v[i] = (double(rand() % (UINT64_C(1) << 14)) - (UINT64_C(1) << 13)) / (UINT64_C(1) << 12); +} + +// Run the benchmark +BENCHMARK_MAIN(); + +#undef ARGS +#define ARGS Args({47, 16384})->Args({93, 32768}) + +/* + * reim4_vec_mat1col_product + * reim4_vec_mat2col_product + * reim4_vec_mat3col_product + * reim4_vec_mat4col_product + */ + +template +void benchmark_reim4_vec_matXcols_product(benchmark::State& state) { + const uint64_t nrows = state.range(0); + + double* u = new double[nrows * 8]; + init_random_values(8 * nrows, u); + double* v = new double[nrows * X * 8]; + init_random_values(X * 8 * nrows, v); + double* dst = new double[X * 8]; + + 
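+  // (editor note) a reim4 element appears to be 4 complex slots stored as
+  // 4 real parts followed by 4 imaginary parts (8 doubles); this benchmark
+  // measures X simultaneous length-`nrows` dot products over such elements,
+  // presumably the inner kernel of the FFT-domain vector-matrix products.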
for (auto _ : state) { + fnc(nrows, dst, u, v); + } + + delete[] dst; + delete[] v; + delete[] u; +} + +#undef ARGS +#define ARGS Arg(128)->Arg(1024)->Arg(4096) + +#ifdef __x86_64__ +BENCHMARK(benchmark_reim4_vec_matXcols_product<1, reim4_vec_mat1col_product_avx2>)->ARGS; +// TODO: please remove when fixed: +BENCHMARK(benchmark_reim4_vec_matXcols_product<2, reim4_vec_mat2cols_product_avx2>)->ARGS; +#endif +BENCHMARK(benchmark_reim4_vec_matXcols_product<1, reim4_vec_mat1col_product_ref>)->ARGS; +BENCHMARK(benchmark_reim4_vec_matXcols_product<2, reim4_vec_mat2cols_product_ref>)->ARGS; diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_reim4_arithmetic_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_reim4_arithmetic_test.cpp new file mode 100644 index 0000000..60787f7 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_reim4_arithmetic_test.cpp @@ -0,0 +1,291 @@ +#include + +#include +#include + +#include "../spqlios/reim4/reim4_arithmetic.h" +#include "test/testlib/reim4_elem.h" + +/// Actual tests + +typedef typeof(reim4_extract_1blk_from_reim_ref) reim4_extract_1blk_from_reim_f; +void test_reim4_extract_1blk_from_reim(reim4_extract_1blk_from_reim_f reim4_extract_1blk_from_reim) { + static const uint64_t numtrials = 100; + for (uint64_t m : {4, 8, 16, 1024, 4096, 32768}) { + double* v = (double*)malloc(2 * m * sizeof(double)); + double* w = (double*)malloc(8 * sizeof(double)); + reim_view vv(m, v); + for (uint64_t i = 0; i < numtrials; ++i) { + reim4_elem el = gaussian_reim4(); + uint64_t blk = rand() % (m / 4); + vv.set_blk(blk, el); + reim4_extract_1blk_from_reim(m, blk, w, v); + reim4_elem actual(w); + ASSERT_EQ(el, actual); + } + free(v); + free(w); + } +} + +TEST(reim4_arithmetic, reim4_extract_1blk_from_reim_ref) { + test_reim4_extract_1blk_from_reim(reim4_extract_1blk_from_reim_ref); +} +#ifdef __x86_64__ +TEST(reim4_arithmetic, reim4_extract_1blk_from_reim_avx) { + test_reim4_extract_1blk_from_reim(reim4_extract_1blk_from_reim_avx); +} +#endif + +typedef typeof(reim4_save_1blk_to_reim_ref) reim4_save_1blk_to_reim_f; +void test_reim4_save_1blk_to_reim(reim4_save_1blk_to_reim_f reim4_save_1blk_to_reim) { + static const uint64_t numtrials = 100; + for (uint64_t m : {4, 8, 16, 1024, 4096, 32768}) { + double* v = (double*)malloc(2 * m * sizeof(double)); + double* w = (double*)malloc(8 * sizeof(double)); + reim_view vv(m, v); + for (uint64_t i = 0; i < numtrials; ++i) { + reim4_elem el = gaussian_reim4(); + el.save_as(w); + uint64_t blk = rand() % (m / 4); + reim4_save_1blk_to_reim_ref(m, blk, v, w); + reim4_elem actual = vv.get_blk(blk); + ASSERT_EQ(el, actual); + } + free(v); + free(w); + } +} + +TEST(reim4_arithmetic, reim4_save_1blk_to_reim_ref) { test_reim4_save_1blk_to_reim(reim4_save_1blk_to_reim_ref); } +#ifdef __x86_64__ +TEST(reim4_arithmetic, reim4_save_1blk_to_reim_avx) { test_reim4_save_1blk_to_reim(reim4_save_1blk_to_reim_avx); } +#endif + +typedef typeof(reim4_add_1blk_to_reim_ref) reim4_add_1blk_to_reim_f; +void test_reim4_add_1blk_to_reim(reim4_add_1blk_to_reim_f reim4_add_1blk_to_reim) { + static const uint64_t numtrials = 100; + for (uint64_t m : {4, 8, 16, 1024, 4096, 32768}) { + double* v = (double*)malloc(2 * m * sizeof(double)); + double* w1 = (double*)malloc(8 * sizeof(double)); + double* w2 = (double*)malloc(8 * sizeof(double)); + // double* tmp = (double*)malloc(8 * sizeof(double)); + reim_view vv(m, v); + for (uint64_t i = 0; i < numtrials; ++i) { + reim4_elem el1 = 
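+      // (editor note) in the reim layout the m real parts come first and the
+      // m imaginary parts follow; block `blk` covers complex slots
+      // 4*blk..4*blk+3, and the extract/save/add_1blk helpers move one such
+      // block between that split layout and a contiguous 8-double reim4
+      // element (add_1blk accumulating instead of overwriting).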
gaussian_reim4(); + reim4_elem el2 = gaussian_reim4(); + + // v[blk] = w1 + uint64_t blk = rand() % (m / 4); + el1.save_as(w1); + reim4_save_1blk_to_reim_ref(m, blk, v, w1); + el2.save_as(w2); + + // v[blk] += w2 + reim4_add_1blk_to_reim_ref(m, blk, v, w2); + + reim4_elem w1w2sum = vv.get_blk(blk); + reim4_elem expected_sum = reim4_elem::zero(); + reim4_add(expected_sum.value, w1, w2); + ASSERT_EQ(expected_sum, w1w2sum); + } + free(v); + free(w1); + free(w2); + } +} + +TEST(reim4_arithmetic, reim4_add_1blk_to_reim_ref) { test_reim4_add_1blk_to_reim(reim4_add_1blk_to_reim_ref); } +#ifdef __x86_64__ +TEST(reim4_arithmetic, reim4_add_1blk_to_reim_avx) { test_reim4_add_1blk_to_reim(reim4_add_1blk_to_reim_ref); } +#endif + +typedef typeof(reim4_extract_1blk_from_contiguous_reim_ref) reim4_extract_1blk_from_contiguous_reim_f; +void test_reim4_extract_1blk_from_contiguous_reim( + reim4_extract_1blk_from_contiguous_reim_f reim4_extract_1blk_from_contiguous_reim) { + static const uint64_t numtrials = 20; + for (uint64_t m : {4, 8, 16, 1024, 4096, 32768}) { + for (uint64_t nrows : {1, 2, 5, 128}) { + double* v = (double*)malloc(2 * m * nrows * sizeof(double)); + double* w = (double*)malloc(8 * nrows * sizeof(double)); + reim_vector_view vv(m, nrows, v); + reim4_array_view ww(nrows, w); + for (uint64_t i = 0; i < numtrials; ++i) { + uint64_t blk = rand() % (m / 4); + for (uint64_t j = 0; j < nrows; ++j) { + reim4_elem el = gaussian_reim4(); + vv.row(j).set_blk(blk, el); + } + reim4_extract_1blk_from_contiguous_reim_ref(m, nrows, blk, w, v); + for (uint64_t j = 0; j < nrows; ++j) { + reim4_elem el = vv.row(j).get_blk(blk); + reim4_elem actual = ww.get(j); + ASSERT_EQ(el, actual); + } + } + free(v); + free(w); + } + } +} + +TEST(reim4_arithmetic, reim4_extract_1blk_from_contiguous_reim_ref) { + test_reim4_extract_1blk_from_contiguous_reim(reim4_extract_1blk_from_contiguous_reim_ref); +} +#ifdef __x86_64__ +TEST(reim4_arithmetic, reim4_extract_1blk_from_contiguous_reim_avx) { + test_reim4_extract_1blk_from_contiguous_reim(reim4_extract_1blk_from_contiguous_reim_avx); +} +#endif + +// test of basic arithmetic functions + +TEST(reim4_arithmetic, add) { + reim4_elem x = gaussian_reim4(); + reim4_elem y = gaussian_reim4(); + reim4_elem expect = x + y; + reim4_elem actual; + reim4_add(actual.value, x.value, y.value); + ASSERT_EQ(actual, expect); +} + +TEST(reim4_arithmetic, mul) { + reim4_elem x = gaussian_reim4(); + reim4_elem y = gaussian_reim4(); + reim4_elem expect = x * y; + reim4_elem actual; + reim4_mul(actual.value, x.value, y.value); + ASSERT_EQ(actual, expect); +} + +TEST(reim4_arithmetic, add_mul) { + reim4_elem x = gaussian_reim4(); + reim4_elem y = gaussian_reim4(); + reim4_elem z = gaussian_reim4(); + reim4_elem expect = z; + reim4_elem actual = z; + expect += x * y; + reim4_add_mul(actual.value, x.value, y.value); + ASSERT_EQ(actual, expect) << infty_dist(expect, actual); +} + +// test of dot products + +typedef typeof(reim4_vec_mat1col_product_ref) reim4_vec_mat1col_product_f; +void test_reim4_vec_mat1col_product(reim4_vec_mat1col_product_f product) { + for (uint64_t ell : {1, 2, 5, 13, 69, 129}) { + std::vector actual(8); + std::vector a(ell * 8); + std::vector b(ell * 8); + reim4_array_view va(ell, a.data()); + reim4_array_view vb(ell, b.data()); + reim4_array_view vactual(1, actual.data()); + // initialize random values + for (uint64_t i = 0; i < ell; ++i) { + va.set(i, gaussian_reim4()); + vb.set(i, gaussian_reim4()); + } + // compute the mat1col product + reim4_elem expect; + for 
(uint64_t i = 0; i < ell; ++i) { + expect += va.get(i) * vb.get(i); + } + // compute the actual product + product(ell, actual.data(), a.data(), b.data()); + // compare + ASSERT_LE(infty_dist(vactual.get(0), expect), 1e-10); + } +} + +TEST(reim4_arithmetic, reim4_vec_mat1col_product_ref) { test_reim4_vec_mat1col_product(reim4_vec_mat1col_product_ref); } +#ifdef __x86_64__ +TEST(reim4_arena, reim4_vec_mat1col_product_avx2) { test_reim4_vec_mat1col_product(reim4_vec_mat1col_product_avx2); } +#endif + +typedef typeof(reim4_vec_mat2cols_product_ref) reim4_vec_mat2col_product_f; +void test_reim4_vec_mat2cols_product(reim4_vec_mat2col_product_f product) { + for (uint64_t ell : {1, 2, 5, 13, 69, 129}) { + std::vector actual(16); + std::vector a(ell * 8); + std::vector b(ell * 16); + reim4_array_view va(ell, a.data()); + reim4_matrix_view vb(ell, 2, b.data()); + reim4_array_view vactual(2, actual.data()); + // initialize random values + for (uint64_t i = 0; i < ell; ++i) { + va.set(i, gaussian_reim4()); + vb.set(i, 0, gaussian_reim4()); + vb.set(i, 1, gaussian_reim4()); + } + // compute the mat1col product + reim4_elem expect[2]; + for (uint64_t i = 0; i < ell; ++i) { + expect[0] += va.get(i) * vb.get(i, 0); + expect[1] += va.get(i) * vb.get(i, 1); + } + // compute the actual product + product(ell, actual.data(), a.data(), b.data()); + // compare + ASSERT_LE(infty_dist(vactual.get(0), expect[0]), 1e-10); + ASSERT_LE(infty_dist(vactual.get(1), expect[1]), 1e-10); + } +} + +TEST(reim4_arithmetic, reim4_vec_mat2cols_product_ref) { + test_reim4_vec_mat2cols_product(reim4_vec_mat2cols_product_ref); +} +#ifdef __x86_64__ +TEST(reim4_arithmetic, reim4_vec_mat2cols_product_avx2) { + test_reim4_vec_mat2cols_product(reim4_vec_mat2cols_product_avx2); +} +#endif + +// for now, we do not need avx implementations, +// so we will keep a single test function +TEST(reim4_arithmetic, reim4_vec_convolution_ref) { + for (uint64_t sizea : {1, 2, 3, 5, 8}) { + for (uint64_t sizeb : {1, 3, 6, 9, 13}) { + std::vector a(8 * sizea); + std::vector b(8 * sizeb); + std::vector expect(8 * (sizea + sizeb - 1)); + std::vector actual(8 * (sizea + sizeb - 1)); + reim4_array_view va(sizea, a.data()); + reim4_array_view vb(sizeb, b.data()); + std::vector vexpect(sizea + sizeb + 3); + reim4_array_view vactual(sizea + sizeb - 1, actual.data()); + for (uint64_t i = 0; i < sizea; ++i) { + va.set(i, gaussian_reim4()); + } + for (uint64_t j = 0; j < sizeb; ++j) { + vb.set(j, gaussian_reim4()); + } + // manual convolution + for (uint64_t i = 0; i < sizea; ++i) { + for (uint64_t j = 0; j < sizeb; ++j) { + vexpect[i + j] += va.get(i) * vb.get(j); + } + } + // partial convolution single coeff + for (uint64_t k = 0; k < sizea + sizeb + 3; ++k) { + double dest[8] = {0}; + reim4_convolution_1coeff_ref(k, dest, a.data(), sizea, b.data(), sizeb); + ASSERT_LE(infty_dist(reim4_elem(dest), vexpect[k]), 1e-10); + } + // partial convolution dual coeff + for (uint64_t k = 0; k < sizea + sizeb + 2; ++k) { + double dest[16] = {0}; + reim4_convolution_2coeff_ref(k, dest, a.data(), sizea, b.data(), sizeb); + ASSERT_LE(infty_dist(reim4_elem(dest), vexpect[k]), 1e-10); + ASSERT_LE(infty_dist(reim4_elem(dest + 8), vexpect[k + 1]), 1e-10); + } + // actual convolution + reim4_convolution_ref(actual.data(), sizea + sizeb - 1, 0, a.data(), sizea, b.data(), sizeb); + for (uint64_t k = 0; k < sizea + sizeb - 1; ++k) { + ASSERT_LE(infty_dist(vactual.get(k), vexpect[k]), 1e-10) << k; + } + } + } +} + +EXPORT void reim4_convolution_ref(double* dest, uint64_t dest_size, 
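+// (editor note) in the convolution test above, vexpect holds the acyclic
+// convolution (sizea + sizeb - 1 coefficients) and is deliberately oversized,
+// so the 1coeff/2coeff partial checks past that range compare against
+// zero-valued entries (reim4_elem presumably default-constructs to zero).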
uint64_t dest_offset, const double* a, + uint64_t sizea, const double* b, uint64_t sizeb); diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_reim_conversions_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_reim_conversions_test.cpp new file mode 100644 index 0000000..d266f89 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_reim_conversions_test.cpp @@ -0,0 +1,115 @@ +#include +#include + +#include "testlib/test_commons.h" + +TEST(reim_conversions, reim_to_tnx) { + for (uint32_t m : {1, 2, 64, 128, 512}) { + for (double divisor : {1, 2, int(m)}) { + for (uint32_t log2overhead : {1, 2, 10, 18, 35, 42}) { + double maxdiff = pow(2., log2overhead - 50); + std::vector data(2 * m); + std::vector dout(2 * m); + for (uint64_t i = 0; i < 2 * m; ++i) { + data[i] = (uniform_f64_01() - 0.5) * pow(2., log2overhead + 1) * divisor; + } + REIM_TO_TNX_PRECOMP* p = new_reim_to_tnx_precomp(m, divisor, 18); + reim_to_tnx_ref(p, dout.data(), data.data()); + for (uint64_t i = 0; i < 2 * m; ++i) { + ASSERT_LE(fabs(dout[i]), 0.5); + double diff = dout[i] - data[i] / divisor; + double fracdiff = diff - rint(diff); + ASSERT_LE(fabs(fracdiff), maxdiff); + } + delete_reim_to_tnx_precomp(p); + } + } + } +} + +#ifdef __x86_64__ +TEST(reim_conversions, reim_to_tnx_ref_vs_avx) { + for (uint32_t m : {8, 16, 64, 128, 512}) { + for (double divisor : {1, 2, int(m)}) { + for (uint32_t log2overhead : {1, 2, 10, 18, 35, 42}) { + // double maxdiff = pow(2., log2overhead - 50); + std::vector data(2 * m); + std::vector dout1(2 * m); + std::vector dout2(2 * m); + for (uint64_t i = 0; i < 2 * m; ++i) { + data[i] = (uniform_f64_01() - 0.5) * pow(2., log2overhead + 1) * divisor; + } + REIM_TO_TNX_PRECOMP* p = new_reim_to_tnx_precomp(m, divisor, 18); + reim_to_tnx_ref(p, dout1.data(), data.data()); + reim_to_tnx_avx(p, dout2.data(), data.data()); + for (uint64_t i = 0; i < 2 * m; ++i) { + ASSERT_LE(fabs(dout1[i] - dout2[i]), 0.); + } + delete_reim_to_tnx_precomp(p); + } + } + } +} +#endif + +typedef typeof(reim_from_znx64_ref) reim_from_znx64_f; + +void test_reim_from_znx64(reim_from_znx64_f reim_from_znx64, uint64_t maxbnd) { + for (uint32_t m : {4, 8, 16, 64, 16384}) { + REIM_FROM_ZNX64_PRECOMP* p = new_reim_from_znx64_precomp(m, maxbnd); + std::vector data(2 * m); + std::vector dout(2 * m); + for (uint64_t i = 0; i < 2 * m; ++i) { + int64_t magnitude = int64_t(uniform_u64() % (maxbnd + 1)); + data[i] = uniform_i64() >> (63 - magnitude); + REQUIRE_DRAMATICALLY(abs(data[i]) <= (INT64_C(1) << magnitude), "pb"); + } + reim_from_znx64(p, dout.data(), data.data()); + for (uint64_t i = 0; i < 2 * m; ++i) { + ASSERT_EQ(dout[i], double(data[i])) << dout[i] << " " << data[i]; + } + delete_reim_from_znx64_precomp(p); + } +} + +TEST(reim_conversions, reim_from_znx64) { + for (uint64_t maxbnd : {50}) { + test_reim_from_znx64(reim_from_znx64, maxbnd); + } +} +TEST(reim_conversions, reim_from_znx64_ref) { test_reim_from_znx64(reim_from_znx64_ref, 50); } +#ifdef __x86_64__ +TEST(reim_conversions, reim_from_znx64_avx2_bnd50_fma) { test_reim_from_znx64(reim_from_znx64_bnd50_fma, 50); } +#endif + +typedef typeof(reim_to_znx64_ref) reim_to_znx64_f; + +void test_reim_to_znx64(reim_to_znx64_f reim_to_znx64_fcn, int64_t maxbnd) { + for (uint32_t m : {4, 8, 16, 64, 16384}) { + for (double divisor : {1, 2, int(m)}) { + REIM_TO_ZNX64_PRECOMP* p = new_reim_to_znx64_precomp(m, divisor, maxbnd); + std::vector data(2 * m); + std::vector dout(2 * m); + for (uint64_t i = 0; i < 2 
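+      // (editor note) input magnitudes are spread over roughly 2^-10 up to
+      // 2^maxbnd (scaled by `divisor`), and the assertion below requires
+      // dout[i] - data[i] / divisor to stay below 0.5, i.e. the converted
+      // integer should sit within rounding distance of the exact quotient.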
* m; ++i) { + int64_t magnitude = int64_t(uniform_u64() % (maxbnd + 11)) - 10; + data[i] = (uniform_f64_01() - 0.5) * pow(2., magnitude + 1) * divisor; + } + reim_to_znx64_fcn(p, dout.data(), data.data()); + for (uint64_t i = 0; i < 2 * m; ++i) { + ASSERT_LE(dout[i] - data[i] / divisor, 0.5) << dout[i] << " " << data[i]; + } + delete_reim_to_znx64_precomp(p); + } + } +} + +TEST(reim_conversions, reim_to_znx64) { + for (uint64_t maxbnd : {63, 50}) { + test_reim_to_znx64(reim_to_znx64, maxbnd); + } +} +TEST(reim_conversions, reim_to_znx64_ref) { test_reim_to_znx64(reim_to_znx64_ref, 63); } +#ifdef __x86_64__ +TEST(reim_conversions, reim_to_znx64_avx2_bnd63_fma) { test_reim_to_znx64(reim_to_znx64_avx2_bnd63_fma, 63); } +TEST(reim_conversions, reim_to_znx64_avx2_bnd50_fma) { test_reim_to_znx64(reim_to_znx64_avx2_bnd50_fma, 50); } +#endif diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_reim_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_reim_test.cpp new file mode 100644 index 0000000..5c32d59 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_reim_test.cpp @@ -0,0 +1,513 @@ +#include + +#include + +#include "gtest/gtest.h" +#include "spqlios/commons_private.h" +#include "spqlios/cplx/cplx_fft_internal.h" +#include "spqlios/reim/reim_fft_internal.h" +#include "spqlios/reim/reim_fft_private.h" + +#ifdef __x86_64__ +TEST(fft, reim_fft_avx2_vs_fft_reim_ref) { + for (uint64_t nn : {16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + // CPLX_FFT_PRECOMP* tables = new_cplx_fft_precomp(m, 0); + REIM_FFT_PRECOMP* reimtables = new_reim_fft_precomp(m, 0); + CPLX* a = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn / 2; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + } + memcpy(a1, a, nn / 2 * sizeof(CPLX)); + memcpy(a2, a, nn / 2 * sizeof(CPLX)); + reim_fft_ref(reimtables, a2); + reim_fft_avx2_fma(reimtables, a1); + double d = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + double dre = fabs(a1[i] - a2[i]); + double dim = fabs(a1[nn / 2 + i] - a2[nn / 2 + i]); + if (dre > d) d = dre; + if (dim > d) d = dim; + ASSERT_LE(d, nn * 1e-10) << nn; + } + ASSERT_LE(d, nn * 1e-10) << nn; + spqlios_free(a); + spqlios_free(a1); + spqlios_free(a2); + // delete_cplx_fft_precomp(tables); + delete_reim_fft_precomp(reimtables); + } +} +#endif + +#ifdef __x86_64__ +TEST(fft, reim_ifft_avx2_vs_reim_ifft_ref) { + for (uint64_t nn : {16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + // CPLX_FFT_PRECOMP* tables = new_cplx_fft_precomp(m, 0); + REIM_IFFT_PRECOMP* reimtables = new_reim_ifft_precomp(m, 0); + CPLX* a = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn / 2; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + } + memcpy(a1, a, nn / 2 * sizeof(CPLX)); + memcpy(a2, a, nn / 2 * sizeof(CPLX)); + reim_ifft_ref(reimtables, a2); + reim_ifft_avx2_fma(reimtables, a1); + double d = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + double dre = fabs(a1[i] - a2[i]); + double dim = fabs(a1[nn 
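+      // (editor note) reim storage is split: reals in the first nn/2 doubles,
+      // imaginaries in the second half. Identical copies of the same random
+      // input (magnitude below 2^15) go through the reference and the AVX2
+      // FMA inverse FFT, and the running maximum deviation must stay under
+      // 1e-8.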
/ 2 + i] - a2[nn / 2 + i]); + if (dre > d) d = dre; + if (dim > d) d = dim; + ASSERT_LE(d, 1e-8); + } + ASSERT_LE(d, 1e-8); + spqlios_free(a); + spqlios_free(a1); + spqlios_free(a2); + // delete_cplx_fft_precomp(tables); + delete_reim_fft_precomp(reimtables); + } +} +#endif + +#ifdef __x86_64__ +TEST(fft, reim_vecfft_addmul_fma_vs_ref) { + for (uint64_t nn : {16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + REIM_FFTVEC_ADDMUL_PRECOMP* tbl = new_reim_fftvec_addmul_precomp(m); + ASSERT_TRUE(tbl != nullptr); + double* a1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* b1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* b2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn; i++) { + a1[i] = (rand() % p) - p / 2; // between -p/2 and p/2 + b1[i] = (rand() % p) - p / 2; + r1[i] = (rand() % p) - p / 2; + } + memcpy(a2, a1, nn / 2 * sizeof(CPLX)); + memcpy(b2, b1, nn / 2 * sizeof(CPLX)); + memcpy(r2, r1, nn / 2 * sizeof(CPLX)); + reim_fftvec_addmul_ref(tbl, r1, a1, b1); + reim_fftvec_addmul_fma(tbl, r2, a2, b2); + double d = 0; + for (uint32_t i = 0; i < nn; i++) { + double di = fabs(r1[i] - r2[i]); + if (di > d) d = di; + ASSERT_LE(d, 1e-8); + } + ASSERT_LE(d, 1e-8); + spqlios_free(a1); + spqlios_free(a2); + spqlios_free(b1); + spqlios_free(b2); + spqlios_free(r1); + spqlios_free(r2); + delete_reim_fftvec_addmul_precomp(tbl); + } +} +#endif + +#ifdef __x86_64__ +TEST(fft, reim_vecfft_mul_fma_vs_ref) { + for (uint64_t nn : {16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + REIM_FFTVEC_MUL_PRECOMP* tbl = new_reim_fftvec_mul_precomp(m); + double* a1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* b1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* b2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn; i++) { + a1[i] = (rand() % p) - p / 2; // between -p/2 and p/2 + b1[i] = (rand() % p) - p / 2; + r1[i] = (rand() % p) - p / 2; + } + memcpy(a2, a1, nn / 2 * sizeof(CPLX)); + memcpy(b2, b1, nn / 2 * sizeof(CPLX)); + memcpy(r2, r1, nn / 2 * sizeof(CPLX)); + reim_fftvec_mul_ref(tbl, r1, a1, b1); + reim_fftvec_mul_fma(tbl, r2, a2, b2); + double d = 0; + for (uint32_t i = 0; i < nn; i++) { + double di = fabs(r1[i] - r2[i]); + if (di > d) d = di; + ASSERT_LE(d, 1e-8); + } + ASSERT_LE(d, 1e-8); + spqlios_free(a1); + spqlios_free(a2); + spqlios_free(b1); + spqlios_free(b2); + spqlios_free(r1); + spqlios_free(r2); + delete_reim_fftvec_mul_precomp(tbl); + } +} +#endif + +#ifdef __x86_64__ +TEST(fft, reim_vecfft_add_fma_vs_ref) { + for (uint64_t nn : {16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + REIM_FFTVEC_ADD_PRECOMP* tbl = new_reim_fftvec_add_precomp(m); + double* a1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* b1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * 
sizeof(CPLX)); + double* b2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn; i++) { + a1[i] = (rand() % p) - p / 2; // between -p/2 and p/2 + b1[i] = (rand() % p) - p / 2; + r1[i] = (rand() % p) - p / 2; + } + memcpy(a2, a1, nn / 2 * sizeof(CPLX)); + memcpy(b2, b1, nn / 2 * sizeof(CPLX)); + memcpy(r2, r1, nn / 2 * sizeof(CPLX)); + reim_fftvec_add_ref(tbl, r1, a1, b1); + reim_fftvec_add_fma(tbl, r2, a2, b2); + double d = 0; + for (uint32_t i = 0; i < nn; i++) { + double di = fabs(r1[i] - r2[i]); + if (di > d) d = di; + ASSERT_LE(d, 1e-8); + } + ASSERT_LE(d, 1e-8); + spqlios_free(a1); + spqlios_free(a2); + spqlios_free(b1); + spqlios_free(b2); + spqlios_free(r1); + spqlios_free(r2); + delete_reim_fftvec_add_precomp(tbl); + } +} +#endif + +typedef void (*FILL_REIM_FFT_OMG_F)(const double entry_pwr, double** omg); +typedef void (*REIM_FFT_F)(double* dre, double* dim, const void* omega); + +// template to test a fixed-dimension fft vs. naive +template +void test_reim_fft_ref_vs_naive(FILL_REIM_FFT_OMG_F fill_omega_f, REIM_FFT_F reim_fft_f) { + double om[N]; + double data[2 * N]; + double datacopy[2 * N]; + double* omg = om; + fill_omega_f(0.25, &omg); + ASSERT_EQ(omg - om, ptrdiff_t(N)); // it may depend on N + for (uint64_t i = 0; i < N; ++i) { + datacopy[i] = data[i] = (rand() % 100) - 50; + datacopy[N + i] = data[N + i] = (rand() % 100) - 50; + } + reim_fft_f(datacopy, datacopy + N, om); + reim_naive_fft(N, 0.25, data, data + N); + double d = 0; + for (uint64_t i = 0; i < 2 * N; ++i) { + d += fabs(datacopy[i] - data[i]); + } + ASSERT_LE(d, 1e-7); +} + +template +void test_reim_fft_ref_vs_accel(REIM_FFT_F reim_fft_ref_f, REIM_FFT_F reim_fft_accel_f) { + double om[N]; + double data[2 * N]; + double datacopy[2 * N]; + for (uint64_t i = 0; i < N; ++i) { + om[i] = (rand() % 100) - 50; + datacopy[i] = data[i] = (rand() % 100) - 50; + datacopy[N + i] = data[N + i] = (rand() % 100) - 50; + } + reim_fft_ref_f(datacopy, datacopy + N, om); + reim_fft_accel_f(data, data + N, om); + double d = 0; + for (uint64_t i = 0; i < 2 * N; ++i) { + d += fabs(datacopy[i] - data[i]); + } + if (d > 1e-15) { + for (uint64_t i = 0; i < N; ++i) { + printf("%" PRId64 " %lf %lf %lf %lf\n", i, data[i], data[N + i], datacopy[i], datacopy[N + i]); + } + ASSERT_LE(d, 0); + } +} + +TEST(fft, reim_fft16_ref_vs_naive) { test_reim_fft_ref_vs_naive<16>(fill_reim_fft16_omegas, reim_fft16_ref); } +#ifdef __aarch64__ +TEST(fft, reim_fft16_neon_vs_naive) { test_reim_fft_ref_vs_naive<16>(fill_reim_fft16_omegas_neon, reim_fft16_neon); } +#endif + +#ifdef __x86_64__ +TEST(fft, reim_fft16_ref_vs_fma) { test_reim_fft_ref_vs_accel<16>(reim_fft16_ref, reim_fft16_avx_fma); } +#endif + +#ifdef __aarch64__ +static void reim_fft16_ref_neon_pom(double* dre, double* dim, const void* omega) { + const double* pom = (double*)omega; + // put the omegas in neon order + double x_pom[] = {pom[0], pom[1], pom[2], pom[3], pom[4], pom[5], pom[6], pom[7], + pom[8], pom[10], pom[12], pom[14], pom[9], pom[11], pom[13], pom[15]}; + reim_fft16_ref(dre, dim, x_pom); +} +TEST(fft, reim_fft16_ref_vs_neon) { test_reim_fft_ref_vs_accel<16>(reim_fft16_ref_neon_pom, reim_fft16_neon); } +#endif + +TEST(fft, reim_fft8_ref_vs_naive) { test_reim_fft_ref_vs_naive<8>(fill_reim_fft8_omegas, reim_fft8_ref); } + +#ifdef __x86_64__ +TEST(fft, 
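+// (editor note) reim_fft{2,4,8,16} are fixed-size leaf kernels: each test
+// fills the twiddle table with the matching fill_reim_fftN_omegas helper
+// (asserting how many omegas were consumed), compares against reim_naive_fft,
+// and, where available, checks the AVX-FMA or NEON variant on shared random
+// twiddles; the NEON fft16 reads its omegas in a permuted order, hence the
+// reim_fft16_ref_neon_pom shim above.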
reim_fft8_ref_vs_fma) { test_reim_fft_ref_vs_accel<8>(reim_fft8_ref, reim_fft8_avx_fma); } +#endif + +TEST(fft, reim_fft4_ref_vs_naive) { test_reim_fft_ref_vs_naive<4>(fill_reim_fft4_omegas, reim_fft4_ref); } + +#ifdef __x86_64__ +TEST(fft, reim_fft4_ref_vs_fma) { test_reim_fft_ref_vs_accel<4>(reim_fft4_ref, reim_fft4_avx_fma); } +#endif + +TEST(fft, reim_fft2_ref_vs_naive) { test_reim_fft_ref_vs_naive<2>(fill_reim_fft2_omegas, reim_fft2_ref); } + +TEST(fft, reim_fft_bfs_16_ref_vs_naive) { + for (const uint64_t m : {16, 32, 64, 128, 256, 512, 1024, 2048}) { + std::vector om(2 * m); + std::vector data(2 * m); + std::vector datacopy(2 * m); + double* omg = om.data(); + fill_reim_fft_bfs_16_omegas(m, 0.25, &omg); + ASSERT_LE(omg - om.data(), ptrdiff_t(2 * m)); // it may depend on m + for (uint64_t i = 0; i < m; ++i) { + datacopy[i] = data[i] = (rand() % 100) - 50; + datacopy[m + i] = data[m + i] = (rand() % 100) - 50; + } + omg = om.data(); + reim_fft_bfs_16_ref(m, datacopy.data(), datacopy.data() + m, &omg); + reim_naive_fft(m, 0.25, data.data(), data.data() + m); + double d = 0; + for (uint64_t i = 0; i < 2 * m; ++i) { + d += fabs(datacopy[i] - data[i]); + } + ASSERT_LE(d, 1e-7); + } +} + +TEST(fft, reim_fft_rec_16_ref_vs_naive) { + for (const uint64_t m : {2048, 4096, 8192, 32768, 65536}) { + std::vector om(2 * m); + std::vector data(2 * m); + std::vector datacopy(2 * m); + double* omg = om.data(); + fill_reim_fft_rec_16_omegas(m, 0.25, &omg); + ASSERT_LE(omg - om.data(), ptrdiff_t(2 * m)); // it may depend on m + for (uint64_t i = 0; i < m; ++i) { + datacopy[i] = data[i] = (rand() % 100) - 50; + datacopy[m + i] = data[m + i] = (rand() % 100) - 50; + } + omg = om.data(); + reim_fft_rec_16_ref(m, datacopy.data(), datacopy.data() + m, &omg); + reim_naive_fft(m, 0.25, data.data(), data.data() + m); + double d = 0; + for (uint64_t i = 0; i < 2 * m; ++i) { + d += fabs(datacopy[i] - data[i]); + } + ASSERT_LE(d, 1e-5); + } +} + +TEST(fft, reim_fft_ref_vs_naive) { + for (const uint64_t m : {1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 32768, 65536}) { + std::vector om(2 * m); + std::vector data(2 * m); + std::vector datacopy(2 * m); + REIM_FFT_PRECOMP* precomp = new_reim_fft_precomp(m, 0); + for (uint64_t i = 0; i < m; ++i) { + datacopy[i] = data[i] = (rand() % 100) - 50; + datacopy[m + i] = data[m + i] = (rand() % 100) - 50; + } + reim_fft_ref(precomp, datacopy.data()); + reim_naive_fft(m, 0.25, data.data(), data.data() + m); + double d = 0; + for (uint64_t i = 0; i < 2 * m; ++i) { + d += fabs(datacopy[i] - data[i]); + } + ASSERT_LE(d, 1e-5) << m; + delete_reim_fft_precomp(precomp); + } +} + +#ifdef __aarch64__ +EXPORT REIM_FFT_PRECOMP* new_reim_fft_precomp_neon(uint32_t m, uint32_t num_buffers); +EXPORT void reim_fft_neon(const REIM_FFT_PRECOMP* precomp, double* d); +TEST(fft, reim_fft_neon_vs_naive) { + for (const uint64_t m : {1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 32768, 65536}) { + std::vector om(2 * m); + std::vector data(2 * m); + std::vector datacopy(2 * m); + REIM_FFT_PRECOMP* precomp = new_reim_fft_precomp_neon(m, 0); + for (uint64_t i = 0; i < m; ++i) { + datacopy[i] = data[i] = (rand() % 100) - 50; + datacopy[m + i] = data[m + i] = (rand() % 100) - 50; + } + reim_fft_neon(precomp, datacopy.data()); + reim_naive_fft(m, 0.25, data.data(), data.data() + m); + double d = 0; + for (uint64_t i = 0; i < 2 * m; ++i) { + d += fabs(datacopy[i] - data[i]); + } + ASSERT_LE(d, 1e-5) << m; + delete_reim_fft_precomp(precomp); + } +} +#endif + +typedef void 
(*FILL_REIM_IFFT_OMG_F)(const double entry_pwr, double** omg); +typedef void (*REIM_IFFT_F)(double* dre, double* dim, const void* omega); + +// template to test a fixed-dimension fft vs. naive +template +void test_reim_ifft_ref_vs_naive(FILL_REIM_IFFT_OMG_F fill_omega_f, REIM_IFFT_F reim_ifft_f) { + double om[N]; + double data[2 * N]; + double datacopy[2 * N]; + double* omg = om; + fill_omega_f(0.25, &omg); + ASSERT_EQ(omg - om, ptrdiff_t(N)); // it may depend on N + for (uint64_t i = 0; i < N; ++i) { + datacopy[i] = data[i] = (rand() % 100) - 50; + datacopy[N + i] = data[N + i] = (rand() % 100) - 50; + } + reim_ifft_f(datacopy, datacopy + N, om); + reim_naive_ifft(N, 0.25, data, data + N); + double d = 0; + for (uint64_t i = 0; i < 2 * N; ++i) { + d += fabs(datacopy[i] - data[i]); + } + ASSERT_LE(d, 1e-7); +} + +template +void test_reim_ifft_ref_vs_accel(REIM_IFFT_F reim_ifft_ref_f, REIM_IFFT_F reim_ifft_accel_f) { + double om[N]; + double data[2 * N]; + double datacopy[2 * N]; + for (uint64_t i = 0; i < N; ++i) { + om[i] = (rand() % 100) - 50; + datacopy[i] = data[i] = (rand() % 100) - 50; + datacopy[N + i] = data[N + i] = (rand() % 100) - 50; + } + reim_ifft_ref_f(datacopy, datacopy + N, om); + reim_ifft_accel_f(data, data + N, om); + double d = 0; + for (uint64_t i = 0; i < 2 * N; ++i) { + d += fabs(datacopy[i] - data[i]); + } + if (d > 1e-15) { + for (uint64_t i = 0; i < N; ++i) { + printf("%" PRId64 " %lf %lf %lf %lf\n", i, data[i], data[N + i], datacopy[i], datacopy[N + i]); + } + ASSERT_LE(d, 0); + } +} + +TEST(fft, reim_ifft16_ref_vs_naive) { test_reim_ifft_ref_vs_naive<16>(fill_reim_ifft16_omegas, reim_ifft16_ref); } + +#ifdef __x86_64__ +TEST(fft, reim_ifft16_ref_vs_fma) { test_reim_ifft_ref_vs_accel<16>(reim_ifft16_ref, reim_ifft16_avx_fma); } +#endif + +TEST(fft, reim_ifft8_ref_vs_naive) { test_reim_ifft_ref_vs_naive<8>(fill_reim_ifft8_omegas, reim_ifft8_ref); } + +#ifdef __x86_64__ +TEST(fft, reim_ifft8_ref_vs_fma) { test_reim_ifft_ref_vs_accel<8>(reim_ifft8_ref, reim_ifft8_avx_fma); } +#endif + +TEST(fft, reim_ifft4_ref_vs_naive) { test_reim_ifft_ref_vs_naive<4>(fill_reim_ifft4_omegas, reim_ifft4_ref); } + +#ifdef __x86_64__ +TEST(fft, reim_ifft4_ref_vs_fma) { test_reim_ifft_ref_vs_accel<4>(reim_ifft4_ref, reim_ifft4_avx_fma); } +#endif + +TEST(fft, reim_ifft2_ref_vs_naive) { test_reim_ifft_ref_vs_naive<2>(fill_reim_ifft2_omegas, reim_ifft2_ref); } + +TEST(fft, reim_ifft_bfs_16_ref_vs_naive) { + for (const uint64_t m : {16, 32, 64, 128, 256, 512, 1024, 2048}) { + std::vector om(2 * m); + std::vector data(2 * m); + std::vector datacopy(2 * m); + double* omg = om.data(); + fill_reim_ifft_bfs_16_omegas(m, 0.25, &omg); + ASSERT_LE(omg - om.data(), ptrdiff_t(2 * m)); // it may depend on m + for (uint64_t i = 0; i < m; ++i) { + datacopy[i] = data[i] = (rand() % 100) - 50; + datacopy[m + i] = data[m + i] = (rand() % 100) - 50; + } + omg = om.data(); + reim_ifft_bfs_16_ref(m, datacopy.data(), datacopy.data() + m, &omg); + reim_naive_ifft(m, 0.25, data.data(), data.data() + m); + double d = 0; + for (uint64_t i = 0; i < 2 * m; ++i) { + d += fabs(datacopy[i] - data[i]); + } + ASSERT_LE(d, 1e-7); + } +} + +TEST(fft, reim_ifft_rec_16_ref_vs_naive) { + for (const uint64_t m : {2048, 4096, 8192, 32768, 65536}) { + std::vector om(2 * m); + std::vector data(2 * m); + std::vector datacopy(2 * m); + double* omg = om.data(); + fill_reim_ifft_rec_16_omegas(m, 0.25, &omg); + ASSERT_LE(omg - om.data(), ptrdiff_t(2 * m)); // it may depend on m + for (uint64_t i = 0; i < m; ++i) { + datacopy[i] = 
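+      // (editor note) the bfs_16 ("breadth-first") variants are exercised for
+      // m up to 2048 and the rec_16 (recursive) variants from 2048 upward;
+      // both forward and inverse versions are checked against the naive
+      // reference, with the tolerance loosened from 1e-7 to 1e-5 at the
+      // larger sizes.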
data[i] = (rand() % 100) - 50; + datacopy[m + i] = data[m + i] = (rand() % 100) - 50; + } + omg = om.data(); + reim_ifft_rec_16_ref(m, datacopy.data(), datacopy.data() + m, &omg); + reim_naive_ifft(m, 0.25, data.data(), data.data() + m); + double d = 0; + for (uint64_t i = 0; i < 2 * m; ++i) { + d += fabs(datacopy[i] - data[i]); + } + ASSERT_LE(d, 1e-5); + } +} + +TEST(fft, reim_ifft_ref_vs_naive) { + for (const uint64_t m : {1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 32768, 65536}) { + std::vector om(2 * m); + std::vector data(2 * m); + std::vector datacopy(2 * m); + REIM_IFFT_PRECOMP* precomp = new_reim_ifft_precomp(m, 0); + for (uint64_t i = 0; i < m; ++i) { + datacopy[i] = data[i] = (rand() % 100) - 50; + datacopy[m + i] = data[m + i] = (rand() % 100) - 50; + } + reim_ifft_ref(precomp, datacopy.data()); + reim_naive_ifft(m, 0.25, data.data(), data.data() + m); + double d = 0; + for (uint64_t i = 0; i < 2 * m; ++i) { + d += fabs(datacopy[i] - data[i]); + } + ASSERT_LE(d, 1e-5) << m; + delete_reim_ifft_precomp(precomp); + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_svp_apply_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_svp_apply_test.cpp new file mode 100644 index 0000000..8a6d234 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_svp_apply_test.cpp @@ -0,0 +1,47 @@ +#include + +#include "../spqlios/arithmetic/vec_znx_arithmetic_private.h" +#include "testlib/fft64_dft.h" +#include "testlib/fft64_layouts.h" +#include "testlib/polynomial_vector.h" + +void test_fft64_svp_apply_dft(SVP_APPLY_DFT_F svp) { + for (uint64_t n : {2, 4, 8, 64, 128}) { + MODULE* module = new_module_info(n, FFT64); + // poly 1 to multiply - create and prepare + fft64_svp_ppol_layout ppol(n); + ppol.fill_random(1.); + for (uint64_t sa : {3, 5, 8}) { + for (uint64_t sr : {3, 5, 8}) { + uint64_t a_sl = n + uniform_u64_bits(2); + // poly 2 to multiply + znx_vec_i64_layout a(n, sa, a_sl); + a.fill_random(19); + // original operation result + fft64_vec_znx_dft_layout res(n, sr); + thash hash_a_before = a.content_hash(); + thash hash_ppol_before = ppol.content_hash(); + svp(module, res.data, sr, ppol.data, a.data(), sa, a_sl); + ASSERT_EQ(a.content_hash(), hash_a_before); + ASSERT_EQ(ppol.content_hash(), hash_ppol_before); + // create expected value + reim_fft64vec ppo = ppol.get_copy(); + std::vector expect(sr); + for (uint64_t i = 0; i < sr; ++i) { + expect[i] = ppo * simple_fft64(a.get_copy_zext(i)); + } + // this is the largest precision we can safely expect + double prec_expect = n * pow(2., 19 - 52); + for (uint64_t i = 0; i < sr; ++i) { + reim_fft64vec actual = res.get_copy_zext(i); + ASSERT_LE(infty_dist(actual, expect[i]), prec_expect); + } + } + } + + delete_module_info(module); + } +} + +TEST(fft64_svp_apply_dft, svp_apply_dft) { test_fft64_svp_apply_dft(svp_apply_dft); } +TEST(fft64_svp_apply_dft, fft64_svp_apply_dft_ref) { test_fft64_svp_apply_dft(fft64_svp_apply_dft_ref); } \ No newline at end of file diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_svp_prepare_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_svp_prepare_test.cpp new file mode 100644 index 0000000..db7cb23 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_svp_prepare_test.cpp @@ -0,0 +1,28 @@ +#include + +#include "../spqlios/arithmetic/vec_znx_arithmetic_private.h" +#include "testlib/fft64_dft.h" +#include "testlib/fft64_layouts.h" +#include 
"testlib/polynomial_vector.h" + +// todo: remove when registered +typedef typeof(fft64_svp_prepare_ref) SVP_PREPARE_F; + +void test_fft64_svp_prepare(SVP_PREPARE_F svp_prepare) { + for (uint64_t n : {2, 4, 8, 64, 128}) { + MODULE* module = new_module_info(n, FFT64); + znx_i64 in = znx_i64::random_log2bound(n, 40); + fft64_svp_ppol_layout out(n); + reim_fft64vec expect = simple_fft64(in); + svp_prepare(module, out.data, in.data()); + const double* ed = (double*)expect.data(); + const double* ac = (double*)out.data; + for (uint64_t i = 0; i < n; ++i) { + ASSERT_LE(abs(ed[i] - ac[i]), 1e-10) << i << n; + } + delete_module_info(module); + } +} + +TEST(svp_prepare, fft64_svp_prepare_ref) { test_fft64_svp_prepare(fft64_svp_prepare_ref); } +TEST(svp_prepare, svp_prepare) { test_fft64_svp_prepare(svp_prepare); } diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_test.cpp new file mode 100644 index 0000000..f6048c1 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_test.cpp @@ -0,0 +1,493 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "spqlios/cplx/cplx_fft_internal.h" + +using namespace std; + +/*namespace { +bool very_close(const double& a, const double& b) { + bool reps = (abs(a - b) < 1e-5); + if (!reps) { + cerr << "not close: " << a << " vs. " << b << endl; + } + return reps; +} + +}*/ // namespace + +TEST(fft, fftvec_convolution) { + uint64_t nn = 65536; // vary accross (8192, 16384), 32768, 65536 + static const uint64_t k = 18; // vary from 10 to 20 + // double* buf_fft = fft_precomp_get_buffer(tables, 0); + // double* buf_ifft = ifft_precomp_get_buffer(itables, 0); + double* a = (double*)spqlios_alloc_custom_align(32, nn * 8); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn * 8); + double* b = (double*)spqlios_alloc_custom_align(32, nn * 8); + double* dist_vector = (double*)spqlios_alloc_custom_align(32, nn * 8); + int64_t p = UINT64_C(1) << k; + printf("p size: %" PRId64 "\n", p); + for (uint32_t i = 0; i < nn; i++) { + a[i] = (rand() % p) - p / 2; // between -p/2 and p/2 + b[i] = (rand() % p) - p / 2; + a2[i] = 0; + } + cplx_fft_simple(nn / 2, a); + cplx_fft_simple(nn / 2, b); + cplx_fftvec_addmul_simple(nn / 2, a2, a, b); + cplx_ifft_simple(nn / 2, a2); // normalization is missing + double distance = 0; + // for (int32_t i = 0; i < 10; i++) { + // printf("%lf %lf\n", a2[i], a2[i] / (nn / 2.)); + //} + for (uint32_t i = 0; i < nn; i++) { + double curdist = fabs(a2[i] / (nn / 2.) - rint(a2[i] / (nn / 2.))); + if (distance < curdist) distance = curdist; + dist_vector[i] = a2[i] / (nn / 2.) 
- rint(a2[i] / (nn / 2.)); + } + printf("distance: %lf\n", distance); + ASSERT_LE(distance, 0.5); // switch from previous 0.1 to 0.5 per experiment 1 reqs + // double a3[] = {2,4,4,4,5,5,7,9}; //instead of dist_vector, for test + // nn = 8; + double mean = 0; + for (uint32_t i = 0; i < nn; i++) { + mean = mean + dist_vector[i]; + } + mean = mean / nn; + double variance = 0; + for (uint32_t i = 0; i < nn; i++) { + variance = variance + pow((mean - dist_vector[i]), 2); + } + double stdev = sqrt(variance / nn); + printf("stdev: %lf\n", stdev); + + spqlios_free(a); + spqlios_free(b); + spqlios_free(a2); + spqlios_free(dist_vector); +} + +typedef double CPLX[2]; +EXPORT uint32_t revbits(uint32_t i, uint32_t v); + +void cplx_zero(CPLX r); +void cplx_addmul(CPLX r, const CPLX a, const CPLX b); + +void halfcfft_eval(CPLX res, uint32_t nn, uint32_t k, const CPLX* coeffs, const CPLX* powomegas); +void halfcfft_naive(uint32_t nn, CPLX* data); + +EXPORT void cplx_set(CPLX, const CPLX); +EXPORT void citwiddle(CPLX, CPLX, const CPLX); +EXPORT void invcitwiddle(CPLX, CPLX, const CPLX); +EXPORT void ctwiddle(CPLX, CPLX, const CPLX); +EXPORT void invctwiddle(CPLX, CPLX, const CPLX); + +#include "../spqlios/cplx/cplx_fft_private.h" +#include "../spqlios/reim/reim_fft_internal.h" +#include "../spqlios/reim/reim_fft_private.h" +#include "../spqlios/reim4/reim4_fftvec_internal.h" +#include "../spqlios/reim4/reim4_fftvec_private.h" + +TEST(fft, simple_fft_test) { // test for checking the simple_fft api + uint64_t nn = 8; // vary accross (8192, 16384), 32768, 65536 + // double* buf_fft = fft_precomp_get_buffer(tables, 0); + // double* buf_ifft = ifft_precomp_get_buffer(itables, 0); + + // define the complex coefficients of two polynomials mod X^4-i + double a[4][2] = {{1.1, 2.2}, {3.3, 4.4}, {5.5, 6.6}, {7.7, 8.8}}; + double b[4][2] = {{9., 10.}, {11., 12.}, {13., 14.}, {15., 16.}}; + double c[4][2]; // for the result + double a2[4][2]; // for testing inverse fft + memcpy(a2, a, 8 * nn); + cplx_fft_simple(4, a); + cplx_fft_simple(4, b); + cplx_fftvec_mul_simple(4, c, a, b); + cplx_ifft_simple(4, c); + // c contains the complex coefficients 4.a*b mod X^4-i + cplx_ifft_simple(4, a); + + double distance = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + double dist = fabs(a[i][0] / 4. - a2[i][0]); + if (distance < dist) distance = dist; + dist = fabs(a[i][1] / 4. 
- a2[i][1]); + if (distance < dist) distance = dist; + } + printf("distance: %lf\n", distance); + ASSERT_LE(distance, 0.1); // switch from previous 0.1 to 0.5 per experiment 1 reqs + + for (uint32_t i = 0; i < nn / 4; i++) { + printf("%lf %lf\n", a2[i][0], a[i][0] / (nn / 2.)); + printf("%lf %lf\n", a2[i][1], a[i][1] / (nn / 2.)); + } +} + +TEST(fft, reim_test) { + // double a[16] __attribute__ ((aligned(32)))= {1.1,2.2,3.3,4.4,5.5,6.6,7.7,8.8,9.9,10.,11.,12.,13.,14.,15.,16.}; + // double b[16] __attribute__ ((aligned(32)))= {17.,18.,19.,20.,21.,22.,23.,24.,25.,26.,27.,28.,29.,30., 31.,32.}; + // double c[16] __attribute__ ((aligned(32))); // for the result in reference layout + double a[16] = {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10., 11., 12., 13., 14., 15., 16.}; + double b[16] = {17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32.}; + double c[16]; // for the result in reference layout + reim_fft_simple(8, a); + reim_fft_simple(8, b); + reim_fftvec_mul_simple(8, c, a, b); + reim_ifft_simple(8, c); +} + +TEST(fft, reim_vs_regular_layout_mul_test) { + uint64_t nn = 16; + + // define the complex coefficients of two polynomials mod X^8-i + + double a1[8][2] __attribute__((aligned(32))) = {{1.1, 2.2}, {3.3, 4.4}, {5.5, 6.6}, {7.7, 8.8}, + {9.9, 10.}, {11., 12.}, {13., 14.}, {15., 16.}}; + double b1[8][2] __attribute__((aligned(32))) = {{17., 18.}, {19., 20.}, {21., 22.}, {23., 24.}, + {25., 26.}, {27., 28.}, {29., 30.}, {31., 32.}}; + double c1[8][2] __attribute__((aligned(32))); // for the result + double c2[16] __attribute__((aligned(32))); // for the result + double c3[8][2] __attribute__((aligned(32))); // for the result + + double* a2 = + (double*)spqlios_alloc_custom_align(32, nn / 2 * 2 * sizeof(double)); // for storing the coefs in reim layout + double* b2 = + (double*)spqlios_alloc_custom_align(32, nn / 2 * 2 * sizeof(double)); // for storing the coefs in reim layout + // double* c2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); // for storing the coefs in reim + // layout + + // organise the coefficients in the reim layout + for (uint32_t i = 0; i < nn / 2; i++) { + a2[i] = a1[i][0]; // a1 = a2, b1 = b2 + a2[nn / 2 + i] = a1[i][1]; + b2[i] = b1[i][0]; + b2[nn / 2 + i] = b1[i][1]; + } + + // fft + cplx_fft_simple(8, a1); + reim_fft_simple(8, a2); + + cplx_fft_simple(8, b1); + reim_fft_simple(8, b2); + + cplx_fftvec_mul_simple(8, c1, a1, b1); + reim_fftvec_mul_simple(8, c2, a2, b2); + + cplx_ifft_simple(8, c1); + reim_ifft_simple(8, c2); + + // check base layout and reim layout result in the same values + double d = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + // printf("RE: cplx_result %lf and reim_result %lf \n", c1[i][0], c2[i]); + // printf("IM: cplx_result %lf and reim_result %lf \n", c1[i][1], c2[nn / 2 + i]); + double dre = fabs(c1[i][0] - c2[i]); + double dim = fabs(c1[i][1] - c2[nn / 2 + i]); + if (dre > d) d = dre; + if (dim > d) d = dim; + ASSERT_LE(d, 1e-7); + } + ASSERT_LE(d, 1e-7); + + // check converting back to base layout: + + for (uint32_t i = 0; i < nn / 2; i++) { + c3[i][0] = c2[i]; + c3[i][1] = c2[8 + i]; + } + + d = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + double dre = fabs(c1[i][0] - c3[i][0]); + double dim = fabs(c1[i][1] - c3[i][1]); + if (dre > d) d = dre; + if (dim > d) d = dim; + ASSERT_LE(d, 1e-7); + } + ASSERT_LE(d, 1e-7); + + spqlios_free(a2); + spqlios_free(b2); + // spqlios_free(c2); +} + +TEST(fft, fftvec_convolution_recursiveoverk) { + static const uint64_t nn = 32768; // vary accross 
(8192, 16384), 32768, 65536 + double* a = (double*)spqlios_alloc_custom_align(32, nn * 8); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn * 8); + double* b = (double*)spqlios_alloc_custom_align(32, nn * 8); + double* dist_vector = (double*)spqlios_alloc_custom_align(32, nn * 8); + + printf("N size: %" PRId64 "\n", nn); + + for (uint32_t k = 14; k <= 24; k++) { // vary k + printf("k size: %" PRId32 "\n", k); + int64_t p = UINT64_C(1) << k; + for (uint32_t i = 0; i < nn; i++) { + a[i] = (rand() % p) - p / 2; + b[i] = (rand() % p) - p / 2; + a2[i] = 0; + } + cplx_fft_simple(nn / 2, a); + cplx_fft_simple(nn / 2, b); + cplx_fftvec_addmul_simple(nn / 2, a2, a, b); + cplx_ifft_simple(nn / 2, a2); + double distance = 0; + for (uint32_t i = 0; i < nn; i++) { + double curdist = fabs(a2[i] / (nn / 2.) - rint(a2[i] / (nn / 2.))); + if (distance < curdist) distance = curdist; + dist_vector[i] = a2[i] / (nn / 2.) - rint(a2[i] / (nn / 2.)); + } + printf("distance: %lf\n", distance); + ASSERT_LE(distance, 0.5); // switch from previous 0.1 to 0.5 per experiment 1 reqs + double mean = 0; + for (uint32_t i = 0; i < nn; i++) { + mean = mean + dist_vector[i]; + } + mean = mean / nn; + double variance = 0; + for (uint32_t i = 0; i < nn; i++) { + variance = variance + pow((mean - dist_vector[i]), 2); + } + double stdev = sqrt(variance / nn); + printf("stdev: %lf\n", stdev); + } + + spqlios_free(a); + spqlios_free(b); + spqlios_free(a2); + spqlios_free(dist_vector); +} + +#ifdef __x86_64__ +TEST(fft, cplx_fft_ref_vs_fft_reim_ref) { + for (uint64_t nn : {16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + CPLX_FFT_PRECOMP* tables = new_cplx_fft_precomp(m, 0); + REIM_FFT_PRECOMP* reimtables = new_reim_fft_precomp(m, 0); + CPLX* a = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + CPLX* a1 = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn / 2; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + } + memcpy(a1, a, nn / 2 * sizeof(CPLX)); + for (uint32_t i = 0; i < nn / 2; i++) { + a2[i] = a[i][0]; + a2[nn / 2 + i] = a[i][1]; + } + cplx_fft_ref(tables, a1); + reim_fft_ref(reimtables, a2); + double d = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + double dre = fabs(a1[i][0] - a2[i]); + double dim = fabs(a1[i][1] - a2[nn / 2 + i]); + if (dre > d) d = dre; + if (dim > d) d = dim; + ASSERT_LE(d, 1e-7); + } + ASSERT_LE(d, 1e-7); + spqlios_free(a); + spqlios_free(a1); + spqlios_free(a2); + delete_cplx_fft_precomp(tables); + delete_reim_fft_precomp(reimtables); + } +} +#endif + +TEST(fft, cplx_ifft_ref_vs_reim_ifft_ref) { + for (uint64_t nn : {16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + CPLX_IFFT_PRECOMP* tables = new_cplx_ifft_precomp(m, 0); + REIM_IFFT_PRECOMP* reimtables = new_reim_ifft_precomp(m, 0); + CPLX* a = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + CPLX* a1 = (CPLX*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn / 2; i++) { + a[i][0] = (rand() % p) - p / 2; // between -p/2 and p/2 + a[i][1] = (rand() % p) - p / 2; + } + memcpy(a1, a, nn / 2 * sizeof(CPLX)); + for (uint32_t i = 0; i < nn / 2; i++) { + a2[i] = a[i][0]; + a2[nn / 2 + i] = a[i][1]; + } + cplx_ifft_ref(tables, a1); + reim_ifft_ref(reimtables, a2); + double 
d = 0; + for (uint32_t i = 0; i < nn / 2; i++) { + double dre = fabs(a1[i][0] - a2[i]); + double dim = fabs(a1[i][1] - a2[nn / 2 + i]); + if (dre > d) d = dre; + if (dim > d) d = dim; + ASSERT_LE(d, 1e-7); + } + ASSERT_LE(d, 1e-7); + spqlios_free(a); + spqlios_free(a1); + spqlios_free(a2); + delete_cplx_fft_precomp(tables); + delete_reim_fft_precomp(reimtables); + } +} + +#ifdef __x86_64__ +TEST(fft, reim4_vecfft_addmul_fma_vs_ref) { + for (uint64_t nn : {16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + REIM4_FFTVEC_ADDMUL_PRECOMP* tbl = new_reim4_fftvec_addmul_precomp(m); + ASSERT_TRUE(tbl != nullptr); + double* a1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* b1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* b2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn; i++) { + a1[i] = (rand() % p) - p / 2; // between -p/2 and p/2 + b1[i] = (rand() % p) - p / 2; + r1[i] = (rand() % p) - p / 2; + } + memcpy(a2, a1, nn / 2 * sizeof(CPLX)); + memcpy(b2, b1, nn / 2 * sizeof(CPLX)); + memcpy(r2, r1, nn / 2 * sizeof(CPLX)); + reim4_fftvec_addmul_ref(tbl, r1, a1, b1); + reim4_fftvec_addmul_fma(tbl, r2, a2, b2); + double d = 0; + for (uint32_t i = 0; i < nn; i++) { + double di = fabs(r1[i] - r2[i]); + if (di > d) d = di; + ASSERT_LE(d, 1e-8); + } + ASSERT_LE(d, 1e-8); + spqlios_free(a1); + spqlios_free(a2); + spqlios_free(b1); + spqlios_free(b2); + spqlios_free(r1); + spqlios_free(r2); + delete_reim4_fftvec_addmul_precomp(tbl); + } +} +#endif + +#ifdef __x86_64__ +TEST(fft, reim4_vecfft_mul_fma_vs_ref) { + for (uint64_t nn : {16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + REIM4_FFTVEC_MUL_PRECOMP* tbl = new_reim4_fftvec_mul_precomp(m); + double* a1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* b1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* b2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn; i++) { + a1[i] = (rand() % p) - p / 2; // between -p/2 and p/2 + b1[i] = (rand() % p) - p / 2; + r1[i] = (rand() % p) - p / 2; + } + memcpy(a2, a1, nn / 2 * sizeof(CPLX)); + memcpy(b2, b1, nn / 2 * sizeof(CPLX)); + memcpy(r2, r1, nn / 2 * sizeof(CPLX)); + reim4_fftvec_mul_ref(tbl, r1, a1, b1); + reim4_fftvec_mul_fma(tbl, r2, a2, b2); + double d = 0; + for (uint32_t i = 0; i < nn; i++) { + double di = fabs(r1[i] - r2[i]); + if (di > d) d = di; + ASSERT_LE(d, 1e-8); + } + ASSERT_LE(d, 1e-8); + spqlios_free(a1); + spqlios_free(a2); + spqlios_free(b1); + spqlios_free(b2); + spqlios_free(r1); + spqlios_free(r2); + delete_reim_fftvec_mul_precomp(tbl); + } +} +#endif + +#ifdef __x86_64__ +TEST(fft, reim4_from_cplx_fma_vs_ref) { + for (uint64_t nn : {16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + REIM4_FROM_CPLX_PRECOMP* tbl = new_reim4_from_cplx_precomp(m); + double* a1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a2 = 
(double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn; i++) { + a1[i] = (rand() % p) - p / 2; // between -p/2 and p/2 + r1[i] = (rand() % p) - p / 2; + } + memcpy(a2, a1, nn / 2 * sizeof(CPLX)); + memcpy(r2, r1, nn / 2 * sizeof(CPLX)); + reim4_from_cplx_ref(tbl, r1, a1); + reim4_from_cplx_fma(tbl, r2, a2); + double d = 0; + for (uint32_t i = 0; i < nn; i++) { + double di = fabs(r1[i] - r2[i]); + if (di > d) d = di; + ASSERT_LE(d, 1e-8); + } + ASSERT_LE(d, 1e-8); + spqlios_free(a1); + spqlios_free(a2); + spqlios_free(r1); + spqlios_free(r2); + delete_reim4_from_cplx_precomp(tbl); + } +} +#endif + +#ifdef __x86_64__ +TEST(fft, reim4_to_cplx_fma_vs_ref) { + for (uint64_t nn : {16, 32, 64, 1024, 8192, 65536}) { + uint64_t m = nn / 2; + REIM4_TO_CPLX_PRECOMP* tbl = new_reim4_to_cplx_precomp(m); + double* a1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* a2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r1 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + double* r2 = (double*)spqlios_alloc_custom_align(32, nn / 2 * sizeof(CPLX)); + int64_t p = 1 << 16; + for (uint32_t i = 0; i < nn; i++) { + a1[i] = (rand() % p) - p / 2; // between -p/2 and p/2 + r1[i] = (rand() % p) - p / 2; + } + memcpy(a2, a1, nn / 2 * sizeof(CPLX)); + memcpy(r2, r1, nn / 2 * sizeof(CPLX)); + reim4_to_cplx_ref(tbl, r1, a1); + reim4_to_cplx_fma(tbl, r2, a2); + double d = 0; + for (uint32_t i = 0; i < nn; i++) { + double di = fabs(r1[i] - r2[i]); + if (di > d) d = di; + ASSERT_LE(d, 1e-8); + } + ASSERT_LE(d, 1e-8); + spqlios_free(a1); + spqlios_free(a2); + spqlios_free(r1); + spqlios_free(r2); + delete_reim4_from_cplx_precomp(tbl); + } +} +#endif diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_rnx_approxdecomp_tnxdbl_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_rnx_approxdecomp_tnxdbl_test.cpp new file mode 100644 index 0000000..5b36ed0 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_rnx_approxdecomp_tnxdbl_test.cpp @@ -0,0 +1,42 @@ +#include "gtest/gtest.h" +#include "spqlios/arithmetic/vec_rnx_arithmetic_private.h" +#include "testlib/vec_rnx_layout.h" + +static void test_rnx_approxdecomp(RNX_APPROXDECOMP_FROM_TNXDBL_F approxdec) { + for (const uint64_t nn : {2, 4, 8, 32}) { + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (const uint64_t ell : {1, 2, 7}) { + for (const uint64_t k : {2, 5}) { + TNXDBL_APPROXDECOMP_GADGET* gadget = new_tnxdbl_approxdecomp_gadget(module, k, ell); + for (const uint64_t res_size : {ell, ell - 1, ell + 1}) { + const uint64_t res_sl = nn + uniform_u64_bits(2); + rnx_vec_f64_layout in(nn, 1, nn); + in.fill_random(3); + rnx_vec_f64_layout out(nn, res_size, res_sl); + approxdec(module, gadget, out.data(), res_size, res_sl, in.data()); + // reconstruct the output + uint64_t msize = std::min(res_size, ell); + double err_bnd = msize == ell ? 
pow(2., -double(msize * k) - 1) : pow(2., -double(msize * k)); + for (uint64_t j = 0; j < nn; ++j) { + double in_j = in.data()[j]; + double out_j = 0; + for (uint64_t i = 0; i < res_size; ++i) { + out_j += out.get_copy(i).get_coeff(j) * pow(2., -double((i + 1) * k)); + } + double err = out_j - in_j; + double err_abs = fabs(err - rint(err)); + ASSERT_LE(err_abs, err_bnd); + } + } + delete_tnxdbl_approxdecomp_gadget(gadget); + } + } + delete_rnx_module_info(module); + } +} + +TEST(vec_rnx, rnx_approxdecomp) { test_rnx_approxdecomp(rnx_approxdecomp_from_tnxdbl); } +TEST(vec_rnx, rnx_approxdecomp_ref) { test_rnx_approxdecomp(rnx_approxdecomp_from_tnxdbl_ref); } +#ifdef __x86_64__ +TEST(vec_rnx, rnx_approxdecomp_avx) { test_rnx_approxdecomp(rnx_approxdecomp_from_tnxdbl_avx); } +#endif diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_rnx_conversions_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_rnx_conversions_test.cpp new file mode 100644 index 0000000..3b629e6 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_rnx_conversions_test.cpp @@ -0,0 +1,134 @@ +#include +#include + +#include "testlib/test_commons.h" + +template +static void test_conv(void (*conv_f)(const MOD_RNX*, // + DST_T* res, uint64_t res_size, uint64_t res_sl, // + const SRC_T* a, uint64_t a_size, uint64_t a_sl), // + DST_T (*ideal_conv_f)(SRC_T x), // + SRC_T (*random_f)() // +) { + for (uint64_t nn : {2, 4, 16, 64}) { + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (uint64_t a_size : {0, 1, 2, 5}) { + for (uint64_t res_size : {0, 1, 3, 5}) { + for (uint64_t trials = 0; trials < 20; ++trials) { + uint64_t a_sl = nn + uniform_u64_bits(2); + uint64_t res_sl = nn + uniform_u64_bits(2); + std::vector a(a_sl * a_size); + std::vector res(res_sl * res_size); + uint64_t msize = std::min(a_size, res_size); + for (uint64_t i = 0; i < a_size; ++i) { + for (uint64_t j = 0; j < nn; ++j) { + a[i * a_sl + j] = random_f(); + } + } + conv_f(module, res.data(), res_size, res_sl, a.data(), a_size, a_sl); + for (uint64_t i = 0; i < msize; ++i) { + for (uint64_t j = 0; j < nn; ++j) { + SRC_T aij = a[i * a_sl + j]; + DST_T expect = ideal_conv_f(aij); + DST_T actual = res[i * res_sl + j]; + ASSERT_EQ(expect, actual); + } + } + for (uint64_t i = msize; i < res_size; ++i) { + DST_T expect = 0; + for (uint64_t j = 0; j < nn; ++j) { + SRC_T actual = res[i * res_sl + j]; + ASSERT_EQ(expect, actual); + } + } + } + } + } + delete_rnx_module_info(module); + } +} + +static int32_t ideal_dbl_to_tn32(double a) { + double _2p32 = INT64_C(1) << 32; + double a_mod_1 = a - rint(a); + int64_t t = rint(a_mod_1 * _2p32); + return int32_t(t); +} + +static double random_f64_10() { return uniform_f64_bounds(-10, 10); } + +static void test_vec_rnx_to_tnx32(VEC_RNX_TO_TNX32_F vec_rnx_to_tnx32_f) { + test_conv(vec_rnx_to_tnx32_f, ideal_dbl_to_tn32, random_f64_10); +} + +TEST(vec_rnx_arithmetic, vec_rnx_to_tnx32) { test_vec_rnx_to_tnx32(vec_rnx_to_tnx32); } +TEST(vec_rnx_arithmetic, vec_rnx_to_tnx32_ref) { test_vec_rnx_to_tnx32(vec_rnx_to_tnx32_ref); } + +static double ideal_tn32_to_dbl(int32_t a) { + const double _2p32 = INT64_C(1) << 32; + return double(a) / _2p32; +} + +static int32_t random_t32() { return uniform_i64_bits(32); } + +static void test_vec_rnx_from_tnx32(VEC_RNX_FROM_TNX32_F vec_rnx_from_tnx32_f) { + test_conv(vec_rnx_from_tnx32_f, ideal_tn32_to_dbl, random_t32); +} + +TEST(vec_rnx_arithmetic, vec_rnx_from_tnx32) { 
test_vec_rnx_from_tnx32(vec_rnx_from_tnx32); } +TEST(vec_rnx_arithmetic, vec_rnx_from_tnx32_ref) { test_vec_rnx_from_tnx32(vec_rnx_from_tnx32_ref); } + +static int32_t ideal_dbl_round_to_i32(double a) { return int32_t(rint(a)); } + +static double random_dbl_explaw_18() { return uniform_f64_bounds(-1., 1.) * pow(2., uniform_u64_bits(6) % 19); } + +static void test_vec_rnx_to_znx32(VEC_RNX_TO_ZNX32_F vec_rnx_to_znx32_f) { + test_conv(vec_rnx_to_znx32_f, ideal_dbl_round_to_i32, random_dbl_explaw_18); +} + +TEST(zn_arithmetic, vec_rnx_to_znx32) { test_vec_rnx_to_znx32(vec_rnx_to_znx32); } +TEST(zn_arithmetic, vec_rnx_to_znx32_ref) { test_vec_rnx_to_znx32(vec_rnx_to_znx32_ref); } + +static double ideal_i32_to_dbl(int32_t a) { return double(a); } + +static int32_t random_i32_explaw_18() { return uniform_i64_bits(uniform_u64_bits(6) % 19); } + +static void test_vec_rnx_from_znx32(VEC_RNX_FROM_ZNX32_F vec_rnx_from_znx32_f) { + test_conv(vec_rnx_from_znx32_f, ideal_i32_to_dbl, random_i32_explaw_18); +} + +TEST(zn_arithmetic, vec_rnx_from_znx32) { test_vec_rnx_from_znx32(vec_rnx_from_znx32); } +TEST(zn_arithmetic, vec_rnx_from_znx32_ref) { test_vec_rnx_from_znx32(vec_rnx_from_znx32_ref); } + +static double ideal_dbl_to_tndbl(double a) { return a - rint(a); } + +static void test_vec_rnx_to_tnxdbl(VEC_RNX_TO_TNXDBL_F vec_rnx_to_tnxdbl_f) { + test_conv(vec_rnx_to_tnxdbl_f, ideal_dbl_to_tndbl, random_f64_10); +} + +TEST(zn_arithmetic, vec_rnx_to_tnxdbl) { test_vec_rnx_to_tnxdbl(vec_rnx_to_tnxdbl); } +TEST(zn_arithmetic, vec_rnx_to_tnxdbl_ref) { test_vec_rnx_to_tnxdbl(vec_rnx_to_tnxdbl_ref); } + +#if 0 +static int64_t ideal_dbl_round_to_i64(double a) { return rint(a); } + +static double random_dbl_explaw_50() { return uniform_f64_bounds(-1., 1.) * pow(2., uniform_u64_bits(7) % 51); } + +static void test_dbl_round_to_i64(DBL_ROUND_TO_I64_F dbl_round_to_i64_f) { + test_conv(dbl_round_to_i64_f, ideal_dbl_round_to_i64, random_dbl_explaw_50); +} + +TEST(zn_arithmetic, dbl_round_to_i64) { test_dbl_round_to_i64(dbl_round_to_i64); } +TEST(zn_arithmetic, dbl_round_to_i64_ref) { test_dbl_round_to_i64(dbl_round_to_i64_ref); } + +static double ideal_i64_to_dbl(int64_t a) { return double(a); } + +static int64_t random_i64_explaw_50() { return uniform_i64_bits(uniform_u64_bits(7) % 51); } + +static void test_i64_to_dbl(I64_TO_DBL_F i64_to_dbl_f) { + test_conv(i64_to_dbl_f, ideal_i64_to_dbl, random_i64_explaw_50); +} + +TEST(zn_arithmetic, i64_to_dbl) { test_i64_to_dbl(i64_to_dbl); } +TEST(zn_arithmetic, i64_to_dbl_ref) { test_i64_to_dbl(i64_to_dbl_ref); } +#endif diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_rnx_ppol_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_rnx_ppol_test.cpp new file mode 100644 index 0000000..e5e0cbd --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_rnx_ppol_test.cpp @@ -0,0 +1,73 @@ +#include + +#include "spqlios/arithmetic/vec_rnx_arithmetic_private.h" +#include "spqlios/reim/reim_fft.h" +#include "test/testlib/vec_rnx_layout.h" + +static void test_vec_rnx_svp_prepare(RNX_SVP_PREPARE_F* rnx_svp_prepare, BYTES_OF_RNX_SVP_PPOL_F* tmp_bytes) { + for (uint64_t n : {2, 4, 8, 64}) { + MOD_RNX* mod = new_rnx_module_info(n, FFT64); + const double invm = 1. 
/ mod->m; + + rnx_f64 in = rnx_f64::random_log2bound(n, 40); + rnx_f64 in_divide_by_m = rnx_f64::zero(n); + for (uint64_t i = 0; i < n; ++i) { + in_divide_by_m.set_coeff(i, in.get_coeff(i) * invm); + } + fft64_rnx_svp_ppol_layout out(n); + reim_fft64vec expect = simple_fft64(in_divide_by_m); + rnx_svp_prepare(mod, out.data, in.data()); + const double* ed = (double*)expect.data(); + const double* ac = (double*)out.data; + for (uint64_t i = 0; i < n; ++i) { + ASSERT_LE(abs(ed[i] - ac[i]), 1e-10) << i << n; + } + delete_rnx_module_info(mod); + } +} +TEST(vec_rnx, vec_rnx_svp_prepare) { test_vec_rnx_svp_prepare(rnx_svp_prepare, bytes_of_rnx_svp_ppol); } +TEST(vec_rnx, vec_rnx_svp_prepare_ref) { + test_vec_rnx_svp_prepare(fft64_rnx_svp_prepare_ref, fft64_bytes_of_rnx_svp_ppol); +} + +static void test_vec_rnx_svp_apply(RNX_SVP_APPLY_F* apply) { + for (uint64_t n : {2, 4, 8, 64, 128}) { + MOD_RNX* mod = new_rnx_module_info(n, FFT64); + + // poly 1 to multiply - create and prepare + fft64_rnx_svp_ppol_layout ppol(n); + ppol.fill_random(1.); + for (uint64_t sa : {3, 5, 8}) { + for (uint64_t sr : {3, 5, 8}) { + uint64_t a_sl = n + uniform_u64_bits(2); + uint64_t r_sl = n + uniform_u64_bits(2); + // poly 2 to multiply + rnx_vec_f64_layout a(n, sa, a_sl); + a.fill_random(19); + + // original operation result + rnx_vec_f64_layout res(n, sr, r_sl); + thash hash_a_before = a.content_hash(); + thash hash_ppol_before = ppol.content_hash(); + apply(mod, res.data(), sr, r_sl, ppol.data, a.data(), sa, a_sl); + ASSERT_EQ(a.content_hash(), hash_a_before); + ASSERT_EQ(ppol.content_hash(), hash_ppol_before); + // create expected value + reim_fft64vec ppo = ppol.get_copy(); + std::vector expect(sr); + for (uint64_t i = 0; i < sr; ++i) { + expect[i] = simple_ifft64(ppo * simple_fft64(a.get_copy_zext(i))); + } + // this is the largest precision we can safely expect + double prec_expect = n * pow(2., 19 - 50); + for (uint64_t i = 0; i < sr; ++i) { + rnx_f64 actual = res.get_copy_zext(i); + ASSERT_LE(infty_dist(actual, expect[i]), prec_expect); + } + } + } + delete_rnx_module_info(mod); + } +} +TEST(vec_rnx, vec_rnx_svp_apply) { test_vec_rnx_svp_apply(rnx_svp_apply); } +TEST(vec_rnx, vec_rnx_svp_apply_ref) { test_vec_rnx_svp_apply(fft64_rnx_svp_apply_ref); } diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_rnx_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_rnx_test.cpp new file mode 100644 index 0000000..1ff3389 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_rnx_test.cpp @@ -0,0 +1,417 @@ +#include + +#include "spqlios/arithmetic/vec_rnx_arithmetic_private.h" +#include "spqlios/reim/reim_fft.h" +#include "testlib/vec_rnx_layout.h" + +// disabling this test by default, since it depicts on purpose wrong accesses +#if 0 +TEST(rnx_layout, valgrind_antipattern_test) { + uint64_t n = 4; + rnx_vec_f64_layout v(n, 7, 13); + // this should be ok + v.set(0, rnx_f64::zero(n)); + // this should abort (wrong ring dimension) + ASSERT_DEATH(v.set(3, rnx_f64::zero(2 * n)), ""); + // this should abort (out of bounds) + ASSERT_DEATH(v.set(8, rnx_f64::zero(n)), ""); + // this should be ok + ASSERT_EQ(v.get_copy_zext(0), rnx_f64::zero(n)); + // should be an uninit read + ASSERT_TRUE(!(v.get_copy_zext(2) == rnx_f64::zero(n))); // should be uninit + // should be an invalid read (inter-slice) + ASSERT_NE(v.data()[4], 0); + ASSERT_EQ(v.data()[2], 0); // should be ok + // should be an uninit read + ASSERT_NE(v.data()[13], 0); // should be 
uninit +} +#endif + +// test of binary operations + +// test for out of place calls +template +void test_vec_rnx_elemw_binop_outplace(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + for (uint64_t n : {2, 4, 8, 128}) { + RNX_MODULE_TYPE mtype = FFT64; + MOD_RNX* mod = new_rnx_module_info(n, mtype); + for (uint64_t sa : {7, 13, 15}) { + for (uint64_t sb : {7, 13, 15}) { + for (uint64_t sc : {7, 13, 15}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 5 + n; + uint64_t c_sl = uniform_u64_bits(3) * 5 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + rnx_vec_f64_layout lb(n, sb, b_sl); + rnx_vec_f64_layout lc(n, sc, c_sl); + std::vector expect(sc); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 1.)); + } + for (uint64_t i = 0; i < sb; ++i) { + lb.set(i, rnx_f64::random_log2bound(n, 1.)); + } + for (uint64_t i = 0; i < sc; ++i) { + expect[i] = ref_binop(la.get_copy_zext(i), lb.get_copy_zext(i)); + } + binop(mod, // N + lc.data(), sc, c_sl, // res + la.data(), sa, a_sl, // a + lb.data(), sb, b_sl); + for (uint64_t i = 0; i < sc; ++i) { + ASSERT_EQ(lc.get_copy_zext(i), expect[i]); + } + } + } + } + delete_rnx_module_info(mod); + } +} +// test for inplace1 calls +template +void test_vec_rnx_elemw_binop_inplace1(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + for (uint64_t n : {2, 4, 64}) { + RNX_MODULE_TYPE mtype = FFT64; + MOD_RNX* mod = new_rnx_module_info(n, mtype); + for (uint64_t sa : {3, 9, 12}) { + for (uint64_t sb : {3, 9, 12}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 5 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + rnx_vec_f64_layout lb(n, sb, b_sl); + std::vector expect(sa); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 1.)); + } + for (uint64_t i = 0; i < sb; ++i) { + lb.set(i, rnx_f64::random_log2bound(n, 1.)); + } + for (uint64_t i = 0; i < sa; ++i) { + expect[i] = ref_binop(la.get_copy_zext(i), lb.get_copy_zext(i)); + } + binop(mod, // N + la.data(), sa, a_sl, // res + la.data(), sa, a_sl, // a + lb.data(), sb, b_sl); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), expect[i]); + } + } + } + delete_rnx_module_info(mod); + } +} +// test for inplace2 calls +template +void test_vec_rnx_elemw_binop_inplace2(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + for (uint64_t n : {4, 32, 64}) { + RNX_MODULE_TYPE mtype = FFT64; + MOD_RNX* mod = new_rnx_module_info(n, mtype); + for (uint64_t sa : {3, 9, 12}) { + for (uint64_t sb : {3, 9, 12}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 5 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + rnx_vec_f64_layout lb(n, sb, b_sl); + std::vector expect(sb); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 1.)); + } + for (uint64_t i = 0; i < sb; ++i) { + lb.set(i, rnx_f64::random_log2bound(n, 1.)); + } + for (uint64_t i = 0; i < sb; ++i) { + expect[i] = ref_binop(la.get_copy_zext(i), lb.get_copy_zext(i)); + } + binop(mod, // N + lb.data(), sb, b_sl, // res + la.data(), sa, a_sl, // a + lb.data(), sb, b_sl); + for (uint64_t i = 0; i < sb; ++i) { + ASSERT_EQ(lb.get_copy_zext(i), expect[i]); + } + } + } + delete_rnx_module_info(mod); + } +} +// test for inplace3 calls +template +void test_vec_rnx_elemw_binop_inplace3(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + for (uint64_t n : {2, 16, 1024}) { + RNX_MODULE_TYPE mtype = FFT64; + MOD_RNX* mod = new_rnx_module_info(n, mtype); + for (uint64_t sa : {2, 6, 11}) { + uint64_t a_sl = 
uniform_u64_bits(3) * 5 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + std::vector expect(sa); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 1.)); + } + for (uint64_t i = 0; i < sa; ++i) { + expect[i] = ref_binop(la.get_copy_zext(i), la.get_copy_zext(i)); + } + binop(mod, // N + la.data(), sa, a_sl, // res + la.data(), sa, a_sl, // a + la.data(), sa, a_sl); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), expect[i]); + } + } + delete_rnx_module_info(mod); + } +} +template +void test_vec_rnx_elemw_binop(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + test_vec_rnx_elemw_binop_outplace(binop, ref_binop); + test_vec_rnx_elemw_binop_inplace1(binop, ref_binop); + test_vec_rnx_elemw_binop_inplace2(binop, ref_binop); + test_vec_rnx_elemw_binop_inplace3(binop, ref_binop); +} + +static rnx_f64 poly_add(const rnx_f64& a, const rnx_f64& b) { return a + b; } +static rnx_f64 poly_sub(const rnx_f64& a, const rnx_f64& b) { return a - b; } +TEST(vec_rnx, vec_rnx_add) { test_vec_rnx_elemw_binop(vec_rnx_add, poly_add); } +TEST(vec_rnx, vec_rnx_add_ref) { test_vec_rnx_elemw_binop(vec_rnx_add_ref, poly_add); } +#ifdef __x86_64__ +TEST(vec_rnx, vec_rnx_add_avx) { test_vec_rnx_elemw_binop(vec_rnx_add_avx, poly_add); } +#endif +TEST(vec_rnx, vec_rnx_sub) { test_vec_rnx_elemw_binop(vec_rnx_sub, poly_sub); } +TEST(vec_rnx, vec_rnx_sub_ref) { test_vec_rnx_elemw_binop(vec_rnx_sub_ref, poly_sub); } +#ifdef __x86_64__ +TEST(vec_rnx, vec_rnx_sub_avx) { test_vec_rnx_elemw_binop(vec_rnx_sub_avx, poly_sub); } +#endif + +// test for out of place calls +template +void test_vec_rnx_elemw_unop_param_outplace(ACTUAL_FCN test_mul_xp_minus_one, EXPECT_FCN ref_mul_xp_minus_one, + int64_t (*param_gen)()) { + for (uint64_t n : {2, 4, 8, 128}) { + MOD_RNX* mod = new_rnx_module_info(n, FFT64); + for (uint64_t sa : {7, 13, 15}) { + for (uint64_t sb : {7, 13, 15}) { + { + int64_t p = param_gen(); + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 4 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + rnx_vec_f64_layout lb(n, sb, b_sl); + std::vector expect(sb); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 0)); + } + for (uint64_t i = 0; i < sb; ++i) { + expect[i] = ref_mul_xp_minus_one(p, la.get_copy_zext(i)); + } + test_mul_xp_minus_one(mod, // + p, // + lb.data(), sb, b_sl, // + la.data(), sa, a_sl // + ); + for (uint64_t i = 0; i < sb; ++i) { + ASSERT_EQ(lb.get_copy_zext(i), expect[i]) << n << " " << sa << " " << sb << " " << i; + } + } + } + } + delete_rnx_module_info(mod); + } +} + +// test for inplace calls +template +void test_vec_rnx_elemw_unop_param_inplace(ACTUAL_FCN actual_function, EXPECT_FCN ref_function, + int64_t (*param_gen)()) { + for (uint64_t n : {2, 16, 1024}) { + MOD_RNX* mod = new_rnx_module_info(n, FFT64); + for (uint64_t sa : {2, 6, 11}) { + { + int64_t p = param_gen(); + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + std::vector expect(sa); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 0)); + } + for (uint64_t i = 0; i < sa; ++i) { + expect[i] = ref_function(p, la.get_copy_zext(i)); + } + actual_function(mod, // N + p, //; + la.data(), sa, a_sl, // res + la.data(), sa, a_sl // a + ); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), expect[i]) << n << " " << sa << " " << i; + } + } + } + delete_rnx_module_info(mod); + } +} + +static int64_t random_mul_xp_minus_one_param() { return uniform_i64(); } +static int64_t 
random_automorphism_param() { return 2 * uniform_i64() + 1; } +static int64_t random_rotation_param() { return uniform_i64(); } + +template +void test_vec_rnx_elemw_mul_xp_minus_one(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + test_vec_rnx_elemw_unop_param_outplace(binop, ref_binop, random_mul_xp_minus_one_param); + test_vec_rnx_elemw_unop_param_inplace(binop, ref_binop, random_mul_xp_minus_one_param); +} +template +void test_vec_rnx_elemw_rotate(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + test_vec_rnx_elemw_unop_param_outplace(binop, ref_binop, random_rotation_param); + test_vec_rnx_elemw_unop_param_inplace(binop, ref_binop, random_rotation_param); +} +template +void test_vec_rnx_elemw_automorphism(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + test_vec_rnx_elemw_unop_param_outplace(binop, ref_binop, random_automorphism_param); + test_vec_rnx_elemw_unop_param_inplace(binop, ref_binop, random_automorphism_param); +} + +static rnx_f64 poly_mul_xp_minus_one(const int64_t p, const rnx_f64& a) { + uint64_t n = a.nn(); + rnx_f64 res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, a.get_coeff(i - p) - a.get_coeff(i)); + } + return res; +} +static rnx_f64 poly_rotate(const int64_t p, const rnx_f64& a) { + uint64_t n = a.nn(); + rnx_f64 res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, a.get_coeff(i - p)); + } + return res; +} +static rnx_f64 poly_automorphism(const int64_t p, const rnx_f64& a) { + uint64_t n = a.nn(); + rnx_f64 res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i * p, a.get_coeff(i)); + } + return res; +} + +TEST(vec_rnx, vec_rnx_mul_xp_minus_one) { + test_vec_rnx_elemw_mul_xp_minus_one(vec_rnx_mul_xp_minus_one, poly_mul_xp_minus_one); +} +TEST(vec_rnx, vec_rnx_mul_xp_minus_one_ref) { + test_vec_rnx_elemw_mul_xp_minus_one(vec_rnx_mul_xp_minus_one_ref, poly_mul_xp_minus_one); +} + +TEST(vec_rnx, vec_rnx_rotate) { test_vec_rnx_elemw_rotate(vec_rnx_rotate, poly_rotate); } +TEST(vec_rnx, vec_rnx_rotate_ref) { test_vec_rnx_elemw_rotate(vec_rnx_rotate_ref, poly_rotate); } +TEST(vec_rnx, vec_rnx_automorphism) { test_vec_rnx_elemw_automorphism(vec_rnx_automorphism, poly_automorphism); } +TEST(vec_rnx, vec_rnx_automorphism_ref) { + test_vec_rnx_elemw_automorphism(vec_rnx_automorphism_ref, poly_automorphism); +} + +// test for out of place calls +template +void test_vec_rnx_elemw_unop_outplace(ACTUAL_FCN actual_function, EXPECT_FCN ref_function) { + for (uint64_t n : {2, 4, 8, 128}) { + MOD_RNX* mod = new_rnx_module_info(n, FFT64); + for (uint64_t sa : {7, 13, 15}) { + for (uint64_t sb : {7, 13, 15}) { + { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 4 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + rnx_vec_f64_layout lb(n, sb, b_sl); + std::vector expect(sb); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 0)); + } + for (uint64_t i = 0; i < sb; ++i) { + expect[i] = ref_function(la.get_copy_zext(i)); + } + actual_function(mod, // + lb.data(), sb, b_sl, // + la.data(), sa, a_sl // + ); + for (uint64_t i = 0; i < sb; ++i) { + ASSERT_EQ(lb.get_copy_zext(i), expect[i]) << n << " " << sa << " " << sb << " " << i; + } + } + } + } + delete_rnx_module_info(mod); + } +} + +// test for inplace calls +template +void test_vec_rnx_elemw_unop_inplace(ACTUAL_FCN actual_function, EXPECT_FCN ref_function) { + for (uint64_t n : {2, 16, 1024}) { + MOD_RNX* mod = new_rnx_module_info(n, FFT64); + for (uint64_t sa : {2, 6, 11}) { + { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + rnx_vec_f64_layout la(n, sa, 
a_sl); + std::vector expect(sa); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 0)); + } + for (uint64_t i = 0; i < sa; ++i) { + expect[i] = ref_function(la.get_copy_zext(i)); + } + actual_function(mod, // N + la.data(), sa, a_sl, // res + la.data(), sa, a_sl // a + ); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), expect[i]) << n << " " << sa << " " << i; + } + } + } + delete_rnx_module_info(mod); + } +} +template +void test_vec_rnx_elemw_unop(ACTUAL_FCN unnop, EXPECT_FCN ref_unnop) { + test_vec_rnx_elemw_unop_outplace(unnop, ref_unnop); + test_vec_rnx_elemw_unop_inplace(unnop, ref_unnop); +} + +static rnx_f64 poly_copy(const rnx_f64& a) { return a; } +static rnx_f64 poly_negate(const rnx_f64& a) { return -a; } + +TEST(vec_rnx, vec_rnx_copy) { test_vec_rnx_elemw_unop(vec_rnx_copy, poly_copy); } +TEST(vec_rnx, vec_rnx_copy_ref) { test_vec_rnx_elemw_unop(vec_rnx_copy_ref, poly_copy); } +TEST(vec_rnx, vec_rnx_negate) { test_vec_rnx_elemw_unop(vec_rnx_negate, poly_negate); } +TEST(vec_rnx, vec_rnx_negate_ref) { test_vec_rnx_elemw_unop(vec_rnx_negate_ref, poly_negate); } +#ifdef __x86_64__ +TEST(vec_rnx, vec_rnx_negate_avx) { test_vec_rnx_elemw_unop(vec_rnx_negate_avx, poly_negate); } +#endif + +// test for inplace calls +void test_vec_rnx_zero(VEC_RNX_ZERO_F actual_function) { + for (uint64_t n : {2, 16, 1024}) { + MOD_RNX* mod = new_rnx_module_info(n, FFT64); + for (uint64_t sa : {2, 6, 11}) { + { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + const rnx_f64 ZERO = rnx_f64::zero(n); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 0)); + } + actual_function(mod, // N + la.data(), sa, a_sl // res + ); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), ZERO) << n << " " << sa << " " << i; + } + } + } + delete_rnx_module_info(mod); + } +} + +TEST(vec_rnx, vec_rnx_zero) { test_vec_rnx_zero(vec_rnx_zero); } + +TEST(vec_rnx, vec_rnx_zero_ref) { test_vec_rnx_zero(vec_rnx_zero_ref); } diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_rnx_vmp_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_rnx_vmp_test.cpp new file mode 100644 index 0000000..edaa12c --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_rnx_vmp_test.cpp @@ -0,0 +1,369 @@ +#include "../spqlios/arithmetic/vec_rnx_arithmetic_private.h" +#include "../spqlios/reim/reim_fft.h" +#include "gtest/gtest.h" +#include "testlib/vec_rnx_layout.h" + +static void test_vmp_apply_dft_to_dft_outplace( // + RNX_VMP_APPLY_DFT_TO_DFT_F* apply, // + RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* tmp_bytes) { + for (uint64_t nn : {2, 4, 8, 64}) { + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (uint64_t mat_nrows : {1, 4, 7}) { + for (uint64_t mat_ncols : {1, 2, 5}) { + for (uint64_t in_size : {1, 4, 7}) { + for (uint64_t out_size : {1, 2, 5}) { + const uint64_t in_sl = nn + uniform_u64_bits(2); + const uint64_t out_sl = nn + uniform_u64_bits(2); + rnx_vec_f64_layout in(nn, in_size, in_sl); + fft64_rnx_vmp_pmat_layout pmat(nn, mat_nrows, mat_ncols); + rnx_vec_f64_layout out(nn, out_size, out_sl); + in.fill_random(0); + pmat.fill_random(0); + // naive computation of the product + std::vector expect(out_size, reim_fft64vec(nn)); + for (uint64_t col = 0; col < out_size; ++col) { + reim_fft64vec ex = reim_fft64vec::zero(nn); + for (uint64_t row = 0; row < std::min(mat_nrows, in_size); ++row) { + ex += pmat.get_zext(row, col) * 
in.get_dft_copy(row); + } + expect[col] = ex; + } + // apply the product + std::vector tmp(tmp_bytes(module, out_size, in_size, mat_nrows, mat_ncols)); + apply(module, // + out.data(), out_size, out_sl, // + in.data(), in_size, in_sl, // + pmat.data, mat_nrows, mat_ncols, // + tmp.data()); + // check that the output is close from the expectation + for (uint64_t col = 0; col < out_size; ++col) { + reim_fft64vec actual = out.get_dft_copy_zext(col); + ASSERT_LE(infty_dist(actual, expect[col]), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } +} + +static void test_vmp_apply_dft_to_dft_inplace( // + RNX_VMP_APPLY_DFT_TO_DFT_F* apply, // + RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* tmp_bytes) { + for (uint64_t nn : {2, 4, 8, 64}) { + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (uint64_t mat_nrows : {1, 2, 6}) { + for (uint64_t mat_ncols : {1, 2, 7, 8}) { + for (uint64_t in_size : {1, 3, 6}) { + for (uint64_t out_size : {1, 3, 6}) { + const uint64_t in_out_sl = nn + uniform_u64_bits(2); + rnx_vec_f64_layout in_out(nn, std::max(in_size, out_size), in_out_sl); + fft64_rnx_vmp_pmat_layout pmat(nn, mat_nrows, mat_ncols); + in_out.fill_random(0); + pmat.fill_random(0); + // naive computation of the product + std::vector expect(out_size, reim_fft64vec(nn)); + for (uint64_t col = 0; col < out_size; ++col) { + reim_fft64vec ex = reim_fft64vec::zero(nn); + for (uint64_t row = 0; row < std::min(mat_nrows, in_size); ++row) { + ex += pmat.get_zext(row, col) * in_out.get_dft_copy(row); + } + expect[col] = ex; + } + // apply the product + std::vector tmp(tmp_bytes(module, out_size, in_size, mat_nrows, mat_ncols)); + apply(module, // + in_out.data(), out_size, in_out_sl, // + in_out.data(), in_size, in_out_sl, // + pmat.data, mat_nrows, mat_ncols, // + tmp.data()); + // check that the output is close from the expectation + for (uint64_t col = 0; col < out_size; ++col) { + reim_fft64vec actual = in_out.get_dft_copy_zext(col); + ASSERT_LE(infty_dist(actual, expect[col]), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } +} + +static void test_vmp_apply_dft_to_dft( // + RNX_VMP_APPLY_DFT_TO_DFT_F* apply, // + RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* tmp_bytes) { + test_vmp_apply_dft_to_dft_outplace(apply, tmp_bytes); + test_vmp_apply_dft_to_dft_inplace(apply, tmp_bytes); +} + +TEST(vec_rnx, vmp_apply_to_dft) { + test_vmp_apply_dft_to_dft(rnx_vmp_apply_dft_to_dft, rnx_vmp_apply_dft_to_dft_tmp_bytes); +} +TEST(vec_rnx, fft64_vmp_apply_dft_to_dft_ref) { + test_vmp_apply_dft_to_dft(fft64_rnx_vmp_apply_dft_to_dft_ref, fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref); +} +#ifdef __x86_64__ +TEST(vec_rnx, fft64_vmp_apply_dft_to_dft_avx) { + test_vmp_apply_dft_to_dft(fft64_rnx_vmp_apply_dft_to_dft_avx, fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx); +} +#endif + +/// rnx_vmp_prepare + +static void test_vmp_prepare_contiguous(RNX_VMP_PREPARE_CONTIGUOUS_F* prepare_contiguous, + RNX_VMP_PREPARE_TMP_BYTES_F* tmp_bytes) { + // tests when n < 8 + for (uint64_t nn : {2, 4}) { + const double one_over_m = 2. 
/ nn; + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (uint64_t nrows : {1, 2, 5}) { + for (uint64_t ncols : {2, 6, 7}) { + rnx_vec_f64_layout mat(nn, nrows * ncols, nn); + fft64_rnx_vmp_pmat_layout pmat(nn, nrows, ncols); + mat.fill_random(0); + std::vector tmp_space(tmp_bytes(module)); + thash hash_before = mat.content_hash(); + prepare_contiguous(module, pmat.data, mat.data(), nrows, ncols, tmp_space.data()); + ASSERT_EQ(mat.content_hash(), hash_before); + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + const double* pmatv = (double*)pmat.data + (col * nrows + row) * nn; + reim_fft64vec tmp = one_over_m * simple_fft64(mat.get_copy(row * ncols + col)); + const double* tmpv = tmp.data(); + for (uint64_t i = 0; i < nn; ++i) { + ASSERT_LE(abs(pmatv[i] - tmpv[i]), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } + // tests when n >= 8 + for (uint64_t nn : {8, 32}) { + const double one_over_m = 2. / nn; + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + uint64_t nblk = nn / 8; + for (uint64_t nrows : {1, 2, 5}) { + for (uint64_t ncols : {2, 6, 7}) { + rnx_vec_f64_layout mat(nn, nrows * ncols, nn); + fft64_rnx_vmp_pmat_layout pmat(nn, nrows, ncols); + mat.fill_random(0); + std::vector tmp_space(tmp_bytes(module)); + thash hash_before = mat.content_hash(); + prepare_contiguous(module, pmat.data, mat.data(), nrows, ncols, tmp_space.data()); + ASSERT_EQ(mat.content_hash(), hash_before); + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + reim_fft64vec tmp = one_over_m * simple_fft64(mat.get_copy(row * ncols + col)); + for (uint64_t blk = 0; blk < nblk; ++blk) { + reim4_elem expect = tmp.get_blk(blk); + reim4_elem actual = pmat.get(row, col, blk); + ASSERT_LE(infty_dist(actual, expect), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } +} + +TEST(vec_rnx, vmp_prepare_contiguous) { + test_vmp_prepare_contiguous(rnx_vmp_prepare_contiguous, rnx_vmp_prepare_tmp_bytes); +} +TEST(vec_rnx, fft64_vmp_prepare_contiguous_ref) { + test_vmp_prepare_contiguous(fft64_rnx_vmp_prepare_contiguous_ref, fft64_rnx_vmp_prepare_tmp_bytes_ref); +} +#ifdef __x86_64__ +TEST(vec_rnx, fft64_vmp_prepare_contiguous_avx) { + test_vmp_prepare_contiguous(fft64_rnx_vmp_prepare_contiguous_avx, fft64_rnx_vmp_prepare_tmp_bytes_avx); +} +#endif + +/// rnx_vmp_prepare_dblptr + +static void test_vmp_prepare_dblptr(RNX_VMP_PREPARE_DBLPTR_F* prepare_dblptr, RNX_VMP_PREPARE_TMP_BYTES_F* tmp_bytes) { + // tests when n < 8 + for (uint64_t nn : {2, 4}) { + const double one_over_m = 2. 
/ nn; + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (uint64_t nrows : {1, 2, 5}) { + for (uint64_t ncols : {2, 6, 7}) { + rnx_vec_f64_layout mat(nn, nrows * ncols, nn); + fft64_rnx_vmp_pmat_layout pmat(nn, nrows, ncols); + mat.fill_random(0); + std::vector tmp_space(tmp_bytes(module)); + thash hash_before = mat.content_hash(); + const double** mat_dblptr = (const double**)malloc(nrows * sizeof(double*)); + for (size_t row_i = 0; row_i < nrows; row_i++) { + mat_dblptr[row_i] = &mat.data()[row_i * ncols * nn]; + }; + prepare_dblptr(module, pmat.data, mat_dblptr, nrows, ncols, tmp_space.data()); + ASSERT_EQ(mat.content_hash(), hash_before); + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + const double* pmatv = (double*)pmat.data + (col * nrows + row) * nn; + reim_fft64vec tmp = one_over_m * simple_fft64(mat.get_copy(row * ncols + col)); + const double* tmpv = tmp.data(); + for (uint64_t i = 0; i < nn; ++i) { + ASSERT_LE(abs(pmatv[i] - tmpv[i]), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } + // tests when n >= 8 + for (uint64_t nn : {8, 32}) { + const double one_over_m = 2. / nn; + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + uint64_t nblk = nn / 8; + for (uint64_t nrows : {1, 2, 5}) { + for (uint64_t ncols : {2, 6, 7}) { + rnx_vec_f64_layout mat(nn, nrows * ncols, nn); + fft64_rnx_vmp_pmat_layout pmat(nn, nrows, ncols); + mat.fill_random(0); + std::vector tmp_space(tmp_bytes(module)); + thash hash_before = mat.content_hash(); + const double** mat_dblptr = (const double**)malloc(nrows * sizeof(double*)); + for (size_t row_i = 0; row_i < nrows; row_i++) { + mat_dblptr[row_i] = &mat.data()[row_i * ncols * nn]; + }; + prepare_dblptr(module, pmat.data, mat_dblptr, nrows, ncols, tmp_space.data()); + ASSERT_EQ(mat.content_hash(), hash_before); + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + reim_fft64vec tmp = one_over_m * simple_fft64(mat.get_copy(row * ncols + col)); + for (uint64_t blk = 0; blk < nblk; ++blk) { + reim4_elem expect = tmp.get_blk(blk); + reim4_elem actual = pmat.get(row, col, blk); + ASSERT_LE(infty_dist(actual, expect), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } +} + +TEST(vec_rnx, vmp_prepare_dblptr) { test_vmp_prepare_dblptr(rnx_vmp_prepare_dblptr, rnx_vmp_prepare_tmp_bytes); } +TEST(vec_rnx, fft64_vmp_prepare_dblptr_ref) { + test_vmp_prepare_dblptr(fft64_rnx_vmp_prepare_dblptr_ref, fft64_rnx_vmp_prepare_tmp_bytes_ref); +} +#ifdef __x86_64__ +TEST(vec_rnx, fft64_vmp_prepare_dblptr_avx) { + test_vmp_prepare_dblptr(fft64_rnx_vmp_prepare_dblptr_avx, fft64_rnx_vmp_prepare_tmp_bytes_avx); +} +#endif + +/// rnx_vmp_apply_dft_to_dft + +static void test_vmp_apply_tmp_a_outplace( // + RNX_VMP_APPLY_TMP_A_F* apply, // + RNX_VMP_APPLY_TMP_A_TMP_BYTES_F* tmp_bytes) { + for (uint64_t nn : {2, 4, 8, 64}) { + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (uint64_t mat_nrows : {1, 4, 7}) { + for (uint64_t mat_ncols : {1, 2, 5}) { + for (uint64_t in_size : {1, 4, 7}) { + for (uint64_t out_size : {1, 2, 5}) { + const uint64_t in_sl = nn + uniform_u64_bits(2); + const uint64_t out_sl = nn + uniform_u64_bits(2); + rnx_vec_f64_layout in(nn, in_size, in_sl); + fft64_rnx_vmp_pmat_layout pmat(nn, mat_nrows, mat_ncols); + rnx_vec_f64_layout out(nn, out_size, out_sl); + in.fill_random(0); + pmat.fill_random(0); + // naive computation of the product + std::vector expect(out_size); + for (uint64_t col = 0; col < out_size; ++col) { + 
reim_fft64vec ex = reim_fft64vec::zero(nn); + for (uint64_t row = 0; row < std::min(mat_nrows, in_size); ++row) { + ex += pmat.get_zext(row, col) * simple_fft64(in.get_copy(row)); + } + expect[col] = simple_ifft64(ex); + } + // apply the product + std::vector tmp(tmp_bytes(module, out_size, in_size, mat_nrows, mat_ncols)); + apply(module, // + out.data(), out_size, out_sl, // + in.data(), in_size, in_sl, // + pmat.data, mat_nrows, mat_ncols, // + tmp.data()); + // check that the output is close from the expectation + for (uint64_t col = 0; col < out_size; ++col) { + rnx_f64 actual = out.get_copy_zext(col); + ASSERT_LE(infty_dist(actual, expect[col]), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } +} + +static void test_vmp_apply_tmp_a_inplace( // + RNX_VMP_APPLY_TMP_A_F* apply, // + RNX_VMP_APPLY_TMP_A_TMP_BYTES_F* tmp_bytes) { + for (uint64_t nn : {2, 4, 8, 64}) { + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (uint64_t mat_nrows : {1, 4, 7}) { + for (uint64_t mat_ncols : {1, 2, 5}) { + for (uint64_t in_size : {1, 4, 7}) { + for (uint64_t out_size : {1, 2, 5}) { + const uint64_t in_out_sl = nn + uniform_u64_bits(2); + rnx_vec_f64_layout in_out(nn, std::max(in_size, out_size), in_out_sl); + fft64_rnx_vmp_pmat_layout pmat(nn, mat_nrows, mat_ncols); + in_out.fill_random(0); + pmat.fill_random(0); + // naive computation of the product + std::vector expect(out_size); + for (uint64_t col = 0; col < out_size; ++col) { + reim_fft64vec ex = reim_fft64vec::zero(nn); + for (uint64_t row = 0; row < std::min(mat_nrows, in_size); ++row) { + ex += pmat.get_zext(row, col) * simple_fft64(in_out.get_copy(row)); + } + expect[col] = simple_ifft64(ex); + } + // apply the product + std::vector tmp(tmp_bytes(module, out_size, in_size, mat_nrows, mat_ncols)); + apply(module, // + in_out.data(), out_size, in_out_sl, // + in_out.data(), in_size, in_out_sl, // + pmat.data, mat_nrows, mat_ncols, // + tmp.data()); + // check that the output is close from the expectation + for (uint64_t col = 0; col < out_size; ++col) { + rnx_f64 actual = in_out.get_copy_zext(col); + ASSERT_LE(infty_dist(actual, expect[col]), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } +} + +static void test_vmp_apply_tmp_a( // + RNX_VMP_APPLY_TMP_A_F* apply, // + RNX_VMP_APPLY_TMP_A_TMP_BYTES_F* tmp_bytes) { + test_vmp_apply_tmp_a_outplace(apply, tmp_bytes); + test_vmp_apply_tmp_a_inplace(apply, tmp_bytes); +} + +TEST(vec_znx, fft64_vmp_apply_tmp_a) { test_vmp_apply_tmp_a(rnx_vmp_apply_tmp_a, rnx_vmp_apply_tmp_a_tmp_bytes); } +TEST(vec_znx, fft64_vmp_apply_tmp_a_ref) { + test_vmp_apply_tmp_a(fft64_rnx_vmp_apply_tmp_a_ref, fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref); +} +#ifdef __x86_64__ +TEST(vec_znx, fft64_vmp_apply_tmp_a_avx) { + test_vmp_apply_tmp_a(fft64_rnx_vmp_apply_tmp_a_avx, fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx); +} +#endif diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_znx_big_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_znx_big_test.cpp new file mode 100644 index 0000000..8c51ed8 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_znx_big_test.cpp @@ -0,0 +1,265 @@ +#include + +#include "spqlios/arithmetic/vec_znx_arithmetic_private.h" +#include "test/testlib/polynomial_vector.h" +#include "testlib/fft64_layouts.h" +#include "testlib/test_commons.h" + +#define def_rand_big(varname, ringdim, varsize) \ + fft64_vec_znx_big_layout varname(ringdim, varsize); \ + varname.fill_random() + 
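+// Note on the test helpers in this file: def_rand_big(a, n, sa) above expands
+// roughly to
+//   fft64_vec_znx_big_layout a(n, sa);
+//   a.fill_random();
+// while def_rand_small below does the same with a znx_vec_i64_layout and an
+// explicit slice stride of 2 * ringdim, and test_prelude / test_end wrap each
+// test body in loops over the three given sizes (sa, sb, sr) for a freshly
+// created FFT64 module.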
+#define def_rand_small(varname, ringdim, varsize) \ + znx_vec_i64_layout varname(ringdim, varsize, 2 * ringdim); \ + varname.fill_random() + +#define test_prelude(ringdim, moduletype, dim1, dim2, dim3) \ + uint64_t n = ringdim; \ + MODULE* module = new_module_info(ringdim, moduletype); \ + for (uint64_t sa : {dim1, dim2, dim3}) { \ + for (uint64_t sb : {dim1, dim2, dim3}) { \ + for (uint64_t sr : {dim1, dim2, dim3}) + +#define test_end() \ + } \ + } \ + free(module) + +void test_fft64_vec_znx_big_add(VEC_ZNX_BIG_ADD_F vec_znx_big_add_fcn) { + test_prelude(8, FFT64, 3, 5, 7) { + fft64_vec_znx_big_layout r(n, sr); + def_rand_big(a, n, sa); + def_rand_big(b, n, sb); + vec_znx_big_add_fcn(module, r.data, sr, a.data, sa, b.data, sb); + for (uint64_t i = 0; i < sr; ++i) { + ASSERT_EQ(r.get_copy(i), a.get_copy_zext(i) + b.get_copy_zext(i)); + } + } + test_end(); +} +void test_fft64_vec_znx_big_add_small(VEC_ZNX_BIG_ADD_SMALL_F vec_znx_big_add_fcn) { + test_prelude(16, FFT64, 2, 4, 5) { + fft64_vec_znx_big_layout r(n, sr); + def_rand_big(a, n, sa); + def_rand_small(b, n, sb); + vec_znx_big_add_fcn(module, r.data, sr, a.data, sa, b.data(), sb, 2 * n); + for (uint64_t i = 0; i < sr; ++i) { + ASSERT_EQ(r.get_copy(i), a.get_copy_zext(i) + b.get_copy_zext(i)); + } + } + test_end(); +} +void test_fft64_vec_znx_big_add_small2(VEC_ZNX_BIG_ADD_SMALL2_F vec_znx_big_add_fcn) { + test_prelude(64, FFT64, 3, 6, 7) { + fft64_vec_znx_big_layout r(n, sr); + def_rand_small(a, n, sa); + def_rand_small(b, n, sb); + vec_znx_big_add_fcn(module, r.data, sr, a.data(), sa, 2 * n, b.data(), sb, 2 * n); + for (uint64_t i = 0; i < sr; ++i) { + ASSERT_EQ(r.get_copy(i), a.get_copy_zext(i) + b.get_copy_zext(i)); + } + } + test_end(); +} + +TEST(fft64_vec_znx_big, fft64_vec_znx_big_add) { test_fft64_vec_znx_big_add(fft64_vec_znx_big_add); } +TEST(vec_znx_big, vec_znx_big_add) { test_fft64_vec_znx_big_add(vec_znx_big_add); } + +TEST(fft64_vec_znx_big, fft64_vec_znx_big_add_small) { test_fft64_vec_znx_big_add_small(fft64_vec_znx_big_add_small); } +TEST(vec_znx_big, vec_znx_big_add_small) { test_fft64_vec_znx_big_add_small(vec_znx_big_add_small); } + +TEST(fft64_vec_znx_big, fft64_vec_znx_big_add_small2) { + test_fft64_vec_znx_big_add_small2(fft64_vec_znx_big_add_small2); +} +TEST(vec_znx_big, vec_znx_big_add_small2) { test_fft64_vec_znx_big_add_small2(vec_znx_big_add_small2); } + +void test_fft64_vec_znx_big_sub(VEC_ZNX_BIG_SUB_F vec_znx_big_sub_fcn) { + test_prelude(16, FFT64, 3, 5, 7) { + fft64_vec_znx_big_layout r(n, sr); + def_rand_big(a, n, sa); + def_rand_big(b, n, sb); + vec_znx_big_sub_fcn(module, r.data, sr, a.data, sa, b.data, sb); + for (uint64_t i = 0; i < sr; ++i) { + ASSERT_EQ(r.get_copy(i), a.get_copy_zext(i) - b.get_copy_zext(i)); + } + } + test_end(); +} +void test_fft64_vec_znx_big_sub_small_a(VEC_ZNX_BIG_SUB_SMALL_A_F vec_znx_big_sub_fcn) { + test_prelude(32, FFT64, 2, 4, 5) { + fft64_vec_znx_big_layout r(n, sr); + def_rand_small(a, n, sa); + def_rand_big(b, n, sb); + vec_znx_big_sub_fcn(module, r.data, sr, a.data(), sa, 2 * n, b.data, sb); + for (uint64_t i = 0; i < sr; ++i) { + ASSERT_EQ(r.get_copy(i), a.get_copy_zext(i) - b.get_copy_zext(i)); + } + } + test_end(); +} +void test_fft64_vec_znx_big_sub_small_b(VEC_ZNX_BIG_SUB_SMALL_B_F vec_znx_big_sub_fcn) { + test_prelude(16, FFT64, 2, 4, 5) { + fft64_vec_znx_big_layout r(n, sr); + def_rand_big(a, n, sa); + def_rand_small(b, n, sb); + vec_znx_big_sub_fcn(module, r.data, sr, a.data, sa, b.data(), sb, 2 * n); + for (uint64_t i = 0; i < sr; ++i) { + 
ASSERT_EQ(r.get_copy(i), a.get_copy_zext(i) - b.get_copy_zext(i)); + } + } + test_end(); +} +void test_fft64_vec_znx_big_sub_small2(VEC_ZNX_BIG_SUB_SMALL2_F vec_znx_big_sub_fcn) { + test_prelude(8, FFT64, 3, 6, 7) { + fft64_vec_znx_big_layout r(n, sr); + def_rand_small(a, n, sa); + def_rand_small(b, n, sb); + vec_znx_big_sub_fcn(module, r.data, sr, a.data(), sa, 2 * n, b.data(), sb, 2 * n); + for (uint64_t i = 0; i < sr; ++i) { + ASSERT_EQ(r.get_copy(i), a.get_copy_zext(i) - b.get_copy_zext(i)); + } + } + test_end(); +} + +TEST(fft64_vec_znx_big, fft64_vec_znx_big_sub) { test_fft64_vec_znx_big_sub(fft64_vec_znx_big_sub); } +TEST(vec_znx_big, vec_znx_big_sub) { test_fft64_vec_znx_big_sub(vec_znx_big_sub); } + +TEST(fft64_vec_znx_big, fft64_vec_znx_big_sub_small_a) { + test_fft64_vec_znx_big_sub_small_a(fft64_vec_znx_big_sub_small_a); +} +TEST(vec_znx_big, vec_znx_big_sub_small_a) { test_fft64_vec_znx_big_sub_small_a(vec_znx_big_sub_small_a); } + +TEST(fft64_vec_znx_big, fft64_vec_znx_big_sub_small_b) { + test_fft64_vec_znx_big_sub_small_b(fft64_vec_znx_big_sub_small_b); +} +TEST(vec_znx_big, vec_znx_big_sub_small_b) { test_fft64_vec_znx_big_sub_small_b(vec_znx_big_sub_small_b); } + +TEST(fft64_vec_znx_big, fft64_vec_znx_big_sub_small2) { + test_fft64_vec_znx_big_sub_small2(fft64_vec_znx_big_sub_small2); +} +TEST(vec_znx_big, vec_znx_big_sub_small2) { test_fft64_vec_znx_big_sub_small2(vec_znx_big_sub_small2); } + +static void test_vec_znx_big_normalize(VEC_ZNX_BIG_NORMALIZE_BASE2K_F normalize, + VEC_ZNX_BIG_NORMALIZE_BASE2K_TMP_BYTES_F normalize_tmp_bytes) { + // in the FFT64 case, big_normalize is just a forward. + // we will just test that the functions are callable + uint64_t n = 16; + uint64_t k = 12; + MODULE* module = new_module_info(n, FFT64); + for (uint64_t sa : {3, 5, 7}) { + for (uint64_t sr : {3, 5, 7}) { + uint64_t r_sl = n + 3; + def_rand_big(a, n, sa); + znx_vec_i64_layout r(n, sr, r_sl); + std::vector tmp_space(normalize_tmp_bytes(module, n)); + normalize(module, n, k, r.data(), sr, r_sl, a.data, sa, tmp_space.data()); + } + } + delete_module_info(module); +} + +TEST(vec_znx_big, fft64_vec_znx_big_normalize_base2k) { + test_vec_znx_big_normalize(fft64_vec_znx_big_normalize_base2k, fft64_vec_znx_big_normalize_base2k_tmp_bytes); +} +TEST(vec_znx_big, vec_znx_big_normalize_base2k) { + test_vec_znx_big_normalize(vec_znx_big_normalize_base2k, vec_znx_big_normalize_base2k_tmp_bytes); +} + +static void test_vec_znx_big_range_normalize( // + VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_F normalize, + VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_TMP_BYTES_F normalize_tmp_bytes) { + // in the FFT64 case, big_normalize is just a forward. 
+ // we will just test that the functions are callable + uint64_t n = 16; + uint64_t k = 11; + MODULE* module = new_module_info(n, FFT64); + for (uint64_t sa : {6, 15, 21}) { + for (uint64_t sr : {3, 5, 7}) { + uint64_t r_sl = n + 3; + def_rand_big(a, n, sa); + uint64_t a_start = uniform_u64_bits(30) % (sa / 2); + uint64_t a_end = sa - (uniform_u64_bits(30) % (sa / 2)); + uint64_t a_step = (uniform_u64_bits(30) % 3) + 1; + uint64_t range_size = (a_end + a_step - 1 - a_start) / a_step; + fft64_vec_znx_big_layout aextr(n, range_size); + for (uint64_t i = 0, idx = a_start; idx < a_end; ++i, idx += a_step) { + aextr.set(i, a.get_copy(idx)); + } + znx_vec_i64_layout r(n, sr, r_sl); + znx_vec_i64_layout r2(n, sr, r_sl); + // tmp_space is large-enough for both + std::vector tmp_space(normalize_tmp_bytes(module, n)); + normalize(module, n, k, r.data(), sr, r_sl, a.data, a_start, a_end, a_step, tmp_space.data()); + fft64_vec_znx_big_normalize_base2k(module, n, k, r2.data(), sr, r_sl, aextr.data, range_size, tmp_space.data()); + for (uint64_t i = 0; i < sr; ++i) { + ASSERT_EQ(r.get_copy(i), r2.get_copy(i)); + } + } + } + delete_module_info(module); +} + +TEST(vec_znx_big, fft64_vec_znx_big_range_normalize_base2k) { + test_vec_znx_big_range_normalize(fft64_vec_znx_big_range_normalize_base2k, + fft64_vec_znx_big_range_normalize_base2k_tmp_bytes); +} +TEST(vec_znx_big, vec_znx_big_range_normalize_base2k) { + test_vec_znx_big_range_normalize(vec_znx_big_range_normalize_base2k, vec_znx_big_range_normalize_base2k_tmp_bytes); +} + +static void test_vec_znx_big_rotate(VEC_ZNX_BIG_ROTATE_F rotate) { + // in the FFT64 case, big_normalize is just a forward. + // we will just test that the functions are callable + uint64_t n = 16; + int64_t p = 12; + MODULE* module = new_module_info(n, FFT64); + for (uint64_t sa : {3, 5, 7}) { + for (uint64_t sr : {3, 5, 7}) { + def_rand_big(a, n, sa); + fft64_vec_znx_big_layout r(n, sr); + rotate(module, p, r.data, sr, a.data, sa); + for (uint64_t i = 0; i < sr; ++i) { + znx_i64 aa = a.get_copy_zext(i); + znx_i64 expect(n); + for (uint64_t j = 0; j < n; ++j) { + expect.set_coeff(j, aa.get_coeff(int64_t(j) - p)); + } + znx_i64 actual = r.get_copy(i); + ASSERT_EQ(expect, actual); + } + } + } + delete_module_info(module); +} + +TEST(vec_znx_big, fft64_vec_znx_big_rotate) { test_vec_znx_big_rotate(fft64_vec_znx_big_rotate); } +TEST(vec_znx_big, vec_znx_big_rotate) { test_vec_znx_big_rotate(vec_znx_big_rotate); } + +static void test_vec_znx_big_automorphism(VEC_ZNX_BIG_AUTOMORPHISM_F automorphism) { + // in the FFT64 case, big_normalize is just a forward. 
+ // we will just test that the functions are callable + uint64_t n = 16; + int64_t p = 11; + MODULE* module = new_module_info(n, FFT64); + for (uint64_t sa : {3, 5, 7}) { + for (uint64_t sr : {3, 5, 7}) { + def_rand_big(a, n, sa); + fft64_vec_znx_big_layout r(n, sr); + automorphism(module, p, r.data, sr, a.data, sa); + for (uint64_t i = 0; i < sr; ++i) { + znx_i64 aa = a.get_copy_zext(i); + znx_i64 expect(n); + for (uint64_t j = 0; j < n; ++j) { + expect.set_coeff(p * j, aa.get_coeff(j)); + } + znx_i64 actual = r.get_copy(i); + ASSERT_EQ(expect, actual); + } + } + } + delete_module_info(module); +} + +TEST(vec_znx_big, fft64_vec_znx_big_automorphism) { test_vec_znx_big_automorphism(fft64_vec_znx_big_automorphism); } +TEST(vec_znx_big, vec_znx_big_automorphism) { test_vec_znx_big_automorphism(vec_znx_big_automorphism); } diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_znx_dft_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_znx_dft_test.cpp new file mode 100644 index 0000000..598921d --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_znx_dft_test.cpp @@ -0,0 +1,231 @@ +#include + +#include + +#include "../spqlios/arithmetic/vec_znx_arithmetic_private.h" +#include "spqlios/arithmetic/vec_znx_arithmetic.h" +#include "test/testlib/ntt120_dft.h" +#include "test/testlib/ntt120_layouts.h" +#include "testlib/fft64_dft.h" +#include "testlib/fft64_layouts.h" +#include "testlib/polynomial_vector.h" + +static void test_fft64_vec_znx_dft(VEC_ZNX_DFT_F dft) { + for (uint64_t n : {2, 4, 128}) { + MODULE* module = new_module_info(n, FFT64); + for (uint64_t sa : {3, 5, 8}) { + for (uint64_t sr : {3, 5, 8}) { + uint64_t a_sl = n + uniform_u64_bits(2); + znx_vec_i64_layout a(n, sa, a_sl); + fft64_vec_znx_dft_layout res(n, sr); + a.fill_random(42); + std::vector expect(sr); + for (uint64_t i = 0; i < sr; ++i) { + expect[i] = simple_fft64(a.get_copy_zext(i)); + } + // test the function + thash hash_before = a.content_hash(); + dft(module, res.data, sr, a.data(), sa, a_sl); + ASSERT_EQ(a.content_hash(), hash_before); + for (uint64_t i = 0; i < sr; ++i) { + reim_fft64vec actual = res.get_copy_zext(i); + ASSERT_LE(infty_dist(actual, expect[i]), 1e-10); + } + } + } + delete_module_info(module); + } +} + +static void test_fft64_vec_dft_add(VEC_DFT_ADD_F dft_add) { + for (uint64_t n : {2, 4, 128}) { + MODULE* module = new_module_info(n, FFT64); + for (uint64_t sa : {3, 5, 8}) { + for (uint64_t sr : {3, 5, 8}) { + fft64_vec_znx_dft_layout a(n, sa); + fft64_vec_znx_dft_layout b(n, sa); + fft64_vec_znx_dft_layout res(n, sr); + a.fill_dft_random_log2bound(42); + b.fill_dft_random_log2bound(42); + std::vector expect(sr); + for (uint64_t i = 0; i < sr; ++i) { + std::vector v(n); + if (i < sa) { + reim_fftvec_add(module->mod.fft64.add_fft, v.data(), ((double*)a.data) + i * n, ((double*)b.data) + i * n); + } else { + std::fill(v.begin(), v.end(), 0.0); + } + expect[i] = reim_fft64vec(n, v.data()); + } + // test the function + thash a_hash_before = a.content_hash(); + thash b_hash_before = b.content_hash(); + dft_add(module, res.data, sr, a.data, sa, b.data, sa); + ASSERT_EQ(a.content_hash(), a_hash_before); + ASSERT_EQ(b.content_hash(), b_hash_before); + for (uint64_t i = 0; i < sr; ++i) { + reim_fft64vec actual = res.get_copy_zext(i); + ASSERT_LE(infty_dist(actual, expect[i]), 1e-10); + } + } + } + delete_module_info(module); + } +} + +#ifdef __x86_64__ +// FIXME: currently, it only works on avx +static void 
test_ntt120_vec_znx_dft(VEC_ZNX_DFT_F dft) { + for (uint64_t n : {2, 4, 128}) { + MODULE* module = new_module_info(n, NTT120); + for (uint64_t sa : {3, 5, 8}) { + for (uint64_t sr : {3, 5, 8}) { + uint64_t a_sl = n + uniform_u64_bits(2); + znx_vec_i64_layout a(n, sa, a_sl); + ntt120_vec_znx_dft_layout res(n, sr); + a.fill_random(42); + std::vector expect(sr); + for (uint64_t i = 0; i < sr; ++i) { + expect[i] = simple_ntt120(a.get_copy_zext(i)); + } + // test the function + thash hash_before = a.content_hash(); + dft(module, res.data, sr, a.data(), sa, a_sl); + ASSERT_EQ(a.content_hash(), hash_before); + for (uint64_t i = 0; i < sr; ++i) { + q120_nttvec actual = res.get_copy_zext(i); + if (!(actual == expect[i])) { + for (uint64_t j = 0; j < n; ++j) { + std::cerr << actual.v[j] << " vs " << expect[i].v[j] << std::endl; + } + } + ASSERT_EQ(actual, expect[i]); + } + } + } + delete_module_info(module); + } +} +#endif + +TEST(vec_znx_dft, fft64_vec_znx_dft) { test_fft64_vec_znx_dft(fft64_vec_znx_dft); } +#ifdef __x86_64__ +// FIXME: currently, it only works on avx +TEST(vec_znx_dft, ntt120_vec_znx_dft) { test_ntt120_vec_znx_dft(ntt120_vec_znx_dft_avx); } +#endif +TEST(vec_znx_dft, vec_znx_dft) { + test_fft64_vec_znx_dft(vec_znx_dft); +#ifdef __x86_64__ + // FIXME: currently, it only works on avx + test_ntt120_vec_znx_dft(ntt120_vec_znx_dft_avx); +#endif +} + +TEST(vec_dft_add, fft64_vec_dft_add) { test_fft64_vec_dft_add(fft64_vec_dft_add); } + +static void test_fft64_vec_znx_idft(VEC_ZNX_IDFT_F idft, VEC_ZNX_IDFT_TMP_A_F idft_tmp_a, + VEC_ZNX_IDFT_TMP_BYTES_F idft_tmp_bytes) { + for (uint64_t n : {2, 4, 64, 128}) { + MODULE* module = new_module_info(n, FFT64); + uint64_t tmp_size = idft_tmp_bytes ? idft_tmp_bytes(module, n) : 0; + std::vector tmp(tmp_size); + for (uint64_t sa : {3, 5, 8}) { + for (uint64_t sr : {3, 5, 8}) { + fft64_vec_znx_dft_layout a(n, sa); + fft64_vec_znx_big_layout res(n, sr); + a.fill_dft_random_log2bound(22); + std::vector expect(sr); + for (uint64_t i = 0; i < sr; ++i) { + expect[i] = simple_rint_ifft64(a.get_copy_zext(i)); + } + // test the function + if (idft_tmp_bytes) { + thash hash_before = a.content_hash(); + idft(module, res.data, sr, a.data, sa, tmp.data()); + ASSERT_EQ(a.content_hash(), hash_before); + } else { + idft_tmp_a(module, res.data, sr, a.data, sa); + } + for (uint64_t i = 0; i < sr; ++i) { + znx_i64 actual = res.get_copy_zext(i); + // ASSERT_EQ(res.get_copy_zext(i), expect[i]); + if (!(actual == expect[i])) { + for (uint64_t j = 0; j < n; ++j) { + std::cerr << actual.get_coeff(j) << " dft vs. " << expect[i].get_coeff(j) << std::endl; + } + FAIL(); + } + } + } + } + delete_module_info(module); + } +} + +TEST(vec_znx_dft, fft64_vec_znx_idft) { + test_fft64_vec_znx_idft(fft64_vec_znx_idft, nullptr, fft64_vec_znx_idft_tmp_bytes); +} +TEST(vec_znx_dft, fft64_vec_znx_idft_tmp_a) { test_fft64_vec_znx_idft(nullptr, fft64_vec_znx_idft_tmp_a, nullptr); } + +#ifdef __x86_64__ +// FIXME: currently, it only works on avx +static void test_ntt120_vec_znx_idft(VEC_ZNX_IDFT_F idft, VEC_ZNX_IDFT_TMP_A_F idft_tmp_a, + VEC_ZNX_IDFT_TMP_BYTES_F idft_tmp_bytes) { + for (uint64_t n : {2, 4, 64, 128}) { + MODULE* module = new_module_info(n, NTT120); + uint64_t tmp_size = idft_tmp_bytes ? 
idft_tmp_bytes(module, n) : 0; + std::vector tmp(tmp_size); + for (uint64_t sa : {3, 5, 8}) { + for (uint64_t sr : {3, 5, 8}) { + ntt120_vec_znx_dft_layout a(n, sa); + ntt120_vec_znx_big_layout res(n, sr); + a.fill_random(); + std::vector expect(sr); + for (uint64_t i = 0; i < sr; ++i) { + expect[i] = simple_intt120(a.get_copy_zext(i)); + } + // test the function + if (idft_tmp_bytes) { + thash hash_before = a.content_hash(); + idft(module, res.data, sr, a.data, sa, tmp.data()); + ASSERT_EQ(a.content_hash(), hash_before); + } else { + idft_tmp_a(module, res.data, sr, a.data, sa); + } + for (uint64_t i = 0; i < sr; ++i) { + znx_i128 actual = res.get_copy_zext(i); + ASSERT_EQ(res.get_copy_zext(i), expect[i]); + // if (!(actual == expect[i])) { + // for (uint64_t j = 0; j < n; ++j) { + // std::cerr << actual.get_coeff(j) << " dft vs. " << expect[i].get_coeff(j) << std::endl; + // } + // FAIL(); + // } + } + } + } + delete_module_info(module); + } +} + +TEST(vec_znx_dft, ntt120_vec_znx_idft) { + test_ntt120_vec_znx_idft(ntt120_vec_znx_idft_avx, nullptr, ntt120_vec_znx_idft_tmp_bytes_avx); +} +TEST(vec_znx_dft, ntt120_vec_znx_idft_tmp_a) { + test_ntt120_vec_znx_idft(nullptr, ntt120_vec_znx_idft_tmp_a_avx, nullptr); +} +#endif +TEST(vec_znx_dft, vec_znx_idft) { + test_fft64_vec_znx_idft(vec_znx_idft, nullptr, vec_znx_idft_tmp_bytes); +#ifdef __x86_64__ + // FIXME: currently, only supported on avx + test_ntt120_vec_znx_idft(vec_znx_idft, nullptr, vec_znx_idft_tmp_bytes); +#endif +} +TEST(vec_znx_dft, vec_znx_idft_tmp_a) { + test_fft64_vec_znx_idft(nullptr, vec_znx_idft_tmp_a, nullptr); +#ifdef __x86_64__ + // FIXME: currently, only supported on avx + test_ntt120_vec_znx_idft(nullptr, vec_znx_idft_tmp_a, nullptr); +#endif +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_znx_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_znx_test.cpp new file mode 100644 index 0000000..4c084dd --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vec_znx_test.cpp @@ -0,0 +1,548 @@ +#include +#include + +#include "../spqlios/arithmetic/vec_znx_arithmetic.h" +#include "gtest/gtest.h" +#include "spqlios/arithmetic/vec_znx_arithmetic_private.h" +#include "spqlios/coeffs/coeffs_arithmetic.h" +#include "test/testlib/mod_q120.h" +#include "test/testlib/negacyclic_polynomial.h" +#include "testlib/fft64_dft.h" +#include "testlib/polynomial_vector.h" + +TEST(fft64_layouts, dft_idft_fft64) { + uint64_t n = 128; + // create a random polynomial + znx_i64 p(n); + for (uint64_t i = 0; i < n; ++i) { + p.set_coeff(i, uniform_i64_bits(36)); + } + // call fft + reim_fft64vec q = simple_fft64(p); + // call ifft and round + znx_i64 r = simple_rint_ifft64(q); + ASSERT_EQ(p, r); +} + +TEST(znx_layout, valid_test) { + uint64_t n = 4; + znx_vec_i64_layout v(n, 7, 13); + // this should be ok + v.set(0, znx_i64::zero(n)); + // this should be ok + ASSERT_EQ(v.get_copy_zext(0), znx_i64::zero(n)); + ASSERT_EQ(v.data()[2], 0); // should be ok + // this is also ok (zero extended vector) + ASSERT_EQ(v.get_copy_zext(1000), znx_i64::zero(n)); +} + +// disabling this test by default, since it depicts on purpose wrong accesses +#if 0 +TEST(znx_layout, valgrind_antipattern_test) { + uint64_t n = 4; + znx_vec_i64_layout v(n, 7, 13); + // this should be ok + v.set(0, znx_i64::zero(n)); + // this should abort (wrong ring dimension) + ASSERT_DEATH(v.set(3, znx_i64::zero(2 * n)), ""); + // this should abort (out of bounds) + ASSERT_DEATH(v.set(8, znx_i64::zero(n)), 
""); + // this should be ok + ASSERT_EQ(v.get_copy_zext(0), znx_i64::zero(n)); + // should be an uninit read + ASSERT_TRUE(!(v.get_copy_zext(2) == znx_i64::zero(n))); // should be uninit + // should be an invalid read (inter-slice) + ASSERT_NE(v.data()[4], 0); + ASSERT_EQ(v.data()[2], 0); // should be ok + // should be an uninit read + ASSERT_NE(v.data()[13], 0); // should be uninit +} +#endif + +// test of binary operations + +// test for out of place calls +template +void test_vec_znx_elemw_binop_outplace(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + for (uint64_t n : {2, 4, 8, 128}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {7, 13, 15}) { + for (uint64_t sb : {7, 13, 15}) { + for (uint64_t sc : {7, 13, 15}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 5 + n; + uint64_t c_sl = uniform_u64_bits(3) * 5 + n; + znx_vec_i64_layout la(n, sa, a_sl); + znx_vec_i64_layout lb(n, sb, b_sl); + znx_vec_i64_layout lc(n, sc, c_sl); + std::vector expect(sc); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sb; ++i) { + lb.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sc; ++i) { + expect[i] = ref_binop(la.get_copy_zext(i), lb.get_copy_zext(i)); + } + binop(mod, // N + lc.data(), sc, c_sl, // res + la.data(), sa, a_sl, // a + lb.data(), sb, b_sl); + for (uint64_t i = 0; i < sc; ++i) { + ASSERT_EQ(lc.get_copy_zext(i), expect[i]); + } + } + } + } + delete_module_info(mod); + } +} +// test for inplace1 calls +template +void test_vec_znx_elemw_binop_inplace1(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + for (uint64_t n : {2, 4, 64}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {3, 9, 12}) { + for (uint64_t sb : {3, 9, 12}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 5 + n; + znx_vec_i64_layout la(n, sa, a_sl); + znx_vec_i64_layout lb(n, sb, b_sl); + std::vector expect(sa); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sb; ++i) { + lb.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sa; ++i) { + expect[i] = ref_binop(la.get_copy_zext(i), lb.get_copy_zext(i)); + } + binop(mod, // N + la.data(), sa, a_sl, // res + la.data(), sa, a_sl, // a + lb.data(), sb, b_sl); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), expect[i]); + } + } + } + delete_module_info(mod); + } +} +// test for inplace2 calls +template +void test_vec_znx_elemw_binop_inplace2(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + for (uint64_t n : {4, 32, 64}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? 
FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {3, 9, 12}) { + for (uint64_t sb : {3, 9, 12}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 5 + n; + znx_vec_i64_layout la(n, sa, a_sl); + znx_vec_i64_layout lb(n, sb, b_sl); + std::vector expect(sb); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sb; ++i) { + lb.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sb; ++i) { + expect[i] = ref_binop(la.get_copy_zext(i), lb.get_copy_zext(i)); + } + binop(mod, // N + lb.data(), sb, b_sl, // res + la.data(), sa, a_sl, // a + lb.data(), sb, b_sl); + for (uint64_t i = 0; i < sb; ++i) { + ASSERT_EQ(lb.get_copy_zext(i), expect[i]); + } + } + } + delete_module_info(mod); + } +} +// test for inplace3 calls +template +void test_vec_znx_elemw_binop_inplace3(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + for (uint64_t n : {2, 16, 1024}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {2, 6, 11}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + znx_vec_i64_layout la(n, sa, a_sl); + std::vector expect(sa); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sa; ++i) { + expect[i] = ref_binop(la.get_copy_zext(i), la.get_copy_zext(i)); + } + binop(mod, // N + la.data(), sa, a_sl, // res + la.data(), sa, a_sl, // a + la.data(), sa, a_sl); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), expect[i]); + } + } + delete_module_info(mod); + } +} +template +void test_vec_znx_elemw_binop(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + test_vec_znx_elemw_binop_outplace(binop, ref_binop); + test_vec_znx_elemw_binop_inplace1(binop, ref_binop); + test_vec_znx_elemw_binop_inplace2(binop, ref_binop); + test_vec_znx_elemw_binop_inplace3(binop, ref_binop); +} + +static znx_i64 poly_add(const znx_i64& a, const znx_i64& b) { return a + b; } +TEST(vec_znx, vec_znx_add) { test_vec_znx_elemw_binop(vec_znx_add, poly_add); } +TEST(vec_znx, vec_znx_add_ref) { test_vec_znx_elemw_binop(vec_znx_add_ref, poly_add); } +#ifdef __x86_64__ +TEST(vec_znx, vec_znx_add_avx) { test_vec_znx_elemw_binop(vec_znx_add_avx, poly_add); } +#endif + +static znx_i64 poly_sub(const znx_i64& a, const znx_i64& b) { return a - b; } +TEST(vec_znx, vec_znx_sub) { test_vec_znx_elemw_binop(vec_znx_sub, poly_sub); } +TEST(vec_znx, vec_znx_sub_ref) { test_vec_znx_elemw_binop(vec_znx_sub_ref, poly_sub); } +#ifdef __x86_64__ +TEST(vec_znx, vec_znx_sub_avx) { test_vec_znx_elemw_binop(vec_znx_sub_avx, poly_sub); } +#endif + +// test of rotation operations + +// test for out of place calls +template +void test_vec_znx_elemw_unop_param_outplace(ACTUAL_FCN test_rotate, EXPECT_FCN ref_rotate, int64_t (*param_gen)()) { + for (uint64_t n : {2, 4, 8, 128}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? 
FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {7, 13, 15}) { + for (uint64_t sb : {7, 13, 15}) { + { + int64_t p = param_gen(); + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 4 + n; + znx_vec_i64_layout la(n, sa, a_sl); + znx_vec_i64_layout lb(n, sb, b_sl); + std::vector expect(sb); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sb; ++i) { + expect[i] = ref_rotate(p, la.get_copy_zext(i)); + } + test_rotate(mod, // + p, // + lb.data(), sb, b_sl, // + la.data(), sa, a_sl // + ); + for (uint64_t i = 0; i < sb; ++i) { + ASSERT_EQ(lb.get_copy_zext(i), expect[i]) << n << " " << sa << " " << sb << " " << i; + } + } + } + } + delete_module_info(mod); + } +} + +// test for inplace calls +template +void test_vec_znx_elemw_unop_param_inplace(ACTUAL_FCN test_rotate, EXPECT_FCN ref_rotate, int64_t (*param_gen)()) { + for (uint64_t n : {2, 16, 1024}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {2, 6, 11}) { + { + int64_t p = param_gen(); + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + znx_vec_i64_layout la(n, sa, a_sl); + std::vector expect(sa); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sa; ++i) { + expect[i] = ref_rotate(p, la.get_copy_zext(i)); + } + test_rotate(mod, // N + p, //; + la.data(), sa, a_sl, // res + la.data(), sa, a_sl // a + ); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), expect[i]) << n << " " << sa << " " << i; + } + } + } + delete_module_info(mod); + } +} + +static int64_t random_rotate_param() { return uniform_i64(); } + +template +void test_vec_znx_elemw_rotate(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + test_vec_znx_elemw_unop_param_outplace(binop, ref_binop, random_rotate_param); + test_vec_znx_elemw_unop_param_inplace(binop, ref_binop, random_rotate_param); +} + +static znx_i64 poly_rotate(const int64_t p, const znx_i64& a) { + uint64_t n = a.nn(); + znx_i64 res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, a.get_coeff(i - p)); + } + return res; +} +TEST(vec_znx, vec_znx_rotate) { test_vec_znx_elemw_rotate(vec_znx_rotate, poly_rotate); } +TEST(vec_znx, vec_znx_rotate_ref) { test_vec_znx_elemw_rotate(vec_znx_rotate_ref, poly_rotate); } + +static int64_t random_automorphism_param() { return uniform_i64() | 1; } + +template +void test_vec_znx_elemw_automorphism(ACTUAL_FCN unop, EXPECT_FCN ref_unop) { + test_vec_znx_elemw_unop_param_outplace(unop, ref_unop, random_automorphism_param); + test_vec_znx_elemw_unop_param_inplace(unop, ref_unop, random_automorphism_param); +} + +static znx_i64 poly_automorphism(const int64_t p, const znx_i64& a) { + uint64_t n = a.nn(); + znx_i64 res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i * p, a.get_coeff(i)); + } + return res; +} + +TEST(vec_znx, vec_znx_automorphism) { test_vec_znx_elemw_automorphism(vec_znx_automorphism, poly_automorphism); } +TEST(vec_znx, vec_znx_automorphism_ref) { + test_vec_znx_elemw_automorphism(vec_znx_automorphism_ref, poly_automorphism); +} + +// test for out of place calls +template +void test_vec_znx_elemw_unop_outplace(ACTUAL_FCN test_unop, EXPECT_FCN ref_unop) { + for (uint64_t n : {2, 4, 8, 128}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? 
FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {7, 13, 15}) { + for (uint64_t sb : {7, 13, 15}) { + { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 4 + n; + znx_vec_i64_layout la(n, sa, a_sl); + znx_vec_i64_layout lb(n, sb, b_sl); + std::vector expect(sb); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sb; ++i) { + expect[i] = ref_unop(la.get_copy_zext(i)); + } + test_unop(mod, // + lb.data(), sb, b_sl, // + la.data(), sa, a_sl // + ); + for (uint64_t i = 0; i < sb; ++i) { + ASSERT_EQ(lb.get_copy_zext(i), expect[i]) << n << " " << sa << " " << sb << " " << i; + } + } + } + } + delete_module_info(mod); + } +} + +// test for inplace calls +template +void test_vec_znx_elemw_unop_inplace(ACTUAL_FCN test_unop, EXPECT_FCN ref_unop) { + for (uint64_t n : {2, 16, 1024}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {2, 6, 11}) { + { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + znx_vec_i64_layout la(n, sa, a_sl); + std::vector expect(sa); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + for (uint64_t i = 0; i < sa; ++i) { + expect[i] = ref_unop(la.get_copy_zext(i)); + } + test_unop(mod, // N + la.data(), sa, a_sl, // res + la.data(), sa, a_sl // a + ); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), expect[i]) << n << " " << sa << " " << i; + } + } + } + delete_module_info(mod); + } +} + +template +void test_vec_znx_elemw_unop(ACTUAL_FCN unop, EXPECT_FCN ref_unop) { + test_vec_znx_elemw_unop_outplace(unop, ref_unop); + test_vec_znx_elemw_unop_inplace(unop, ref_unop); +} + +static znx_i64 poly_copy(const znx_i64& a) { return a; } + +TEST(vec_znx, vec_znx_copy) { test_vec_znx_elemw_unop(vec_znx_copy, poly_copy); } +TEST(vec_znx, vec_znx_copy_ref) { test_vec_znx_elemw_unop(vec_znx_copy_ref, poly_copy); } + +static znx_i64 poly_negate(const znx_i64& a) { return -a; } + +TEST(vec_znx, vec_znx_negate) { test_vec_znx_elemw_unop(vec_znx_negate, poly_negate); } +TEST(vec_znx, vec_znx_negate_ref) { test_vec_znx_elemw_unop(vec_znx_negate_ref, poly_negate); } +#ifdef __x86_64__ +TEST(vec_znx, vec_znx_negate_avx) { test_vec_znx_elemw_unop(vec_znx_negate_avx, poly_negate); } +#endif + +static void test_vec_znx_zero(VEC_ZNX_ZERO_F zero) { + for (uint64_t n : {2, 16, 1024}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? 
FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {2, 6, 11}) { + { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + znx_vec_i64_layout la(n, sa, a_sl); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + zero(mod, // N + la.data(), sa, a_sl // res + ); + znx_i64 ZERO = znx_i64::zero(n); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), ZERO) << n << " " << sa << " " << i; + } + } + } + delete_module_info(mod); + } +} + +TEST(vec_znx, vec_znx_zero) { test_vec_znx_zero(vec_znx_zero); } +TEST(vec_znx, vec_znx_zero_ref) { test_vec_znx_zero(vec_znx_zero_ref); } + +static void vec_poly_normalize(const uint64_t base_k, std::vector& in) { + if (in.size() > 0) { + uint64_t n = in.front().nn(); + + znx_i64 out = znx_i64::random_log2bound(n, 62); + znx_i64 cinout(n); + for (int64_t i = in.size() - 1; i >= 0; --i) { + znx_normalize(n, base_k, in[i].data(), cinout.data(), in[i].data(), cinout.data()); + } + } +} + +template +void test_vec_znx_normalize_outplace(ACTUAL_FCN test_normalize, TMP_BYTES_FNC tmp_bytes) { + for (uint64_t n : {2, 16, 1024}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {1, 2, 6, 11}) { + for (uint64_t sb : {1, 2, 6, 11}) { + for (uint64_t base_k : {19}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + znx_vec_i64_layout la(n, sa, a_sl); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + + std::vector la_norm; + for (uint64_t i = 0; i < sa; ++i) { + la_norm.push_back(la.get_copy_zext(i)); + } + vec_poly_normalize(base_k, la_norm); + + uint64_t b_sl = uniform_u64_bits(3) * 5 + n; + znx_vec_i64_layout lb(n, sb, b_sl); + + const uint64_t tmp_size = tmp_bytes(mod, n); + uint8_t* tmp = new uint8_t[tmp_size]; + test_normalize(mod, + n, // N + base_k, // base_k + lb.data(), sb, b_sl, // res + la.data(), sa, a_sl, // a + tmp); + delete[] tmp; + + for (uint64_t i = 0; i < std::min(sa, sb); ++i) { + ASSERT_EQ(lb.get_copy_zext(i), la_norm[i]) << n << " " << sa << " " << sb << " " << i; + } + znx_i64 zero(n); + for (uint64_t i = std::min(sa, sb); i < sb; ++i) { + ASSERT_EQ(lb.get_copy_zext(i), zero) << n << " " << sa << " " << sb << " " << i; + } + } + } + } + delete_module_info(mod); + } +} + +TEST(vec_znx, vec_znx_normalize_outplace) { + test_vec_znx_normalize_outplace(vec_znx_normalize_base2k, vec_znx_normalize_base2k_tmp_bytes); +} +TEST(vec_znx, vec_znx_normalize_outplace_ref) { + test_vec_znx_normalize_outplace(vec_znx_normalize_base2k_ref, vec_znx_normalize_base2k_tmp_bytes_ref); +} + +template +void test_vec_znx_normalize_inplace(ACTUAL_FCN test_normalize, TMP_BYTES_FNC tmp_bytes) { + for (uint64_t n : {2, 16, 1024}) { + MODULE_TYPE mtype = uniform_u64() % 2 == 0 ? 
FFT64 : NTT120; + MODULE* mod = new_module_info(n, mtype); + for (uint64_t sa : {2, 6, 11}) { + for (uint64_t base_k : {19}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + znx_vec_i64_layout la(n, sa, a_sl); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, znx_i64::random_log2bound(n, 62)); + } + + std::vector la_norm; + for (uint64_t i = 0; i < sa; ++i) { + la_norm.push_back(la.get_copy_zext(i)); + } + vec_poly_normalize(base_k, la_norm); + + const uint64_t tmp_size = tmp_bytes(mod, n); + uint8_t* tmp = new uint8_t[tmp_size]; + test_normalize(mod, + n, // N + base_k, // base_k + la.data(), sa, a_sl, // res + la.data(), sa, a_sl, // a + tmp); + delete[] tmp; + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), la_norm[i]) << n << " " << sa << " " << i; + } + } + } + delete_module_info(mod); + } +} + +TEST(vec_znx, vec_znx_normalize_inplace) { + test_vec_znx_normalize_inplace(vec_znx_normalize_base2k, vec_znx_normalize_base2k_tmp_bytes); +} +TEST(vec_znx, vec_znx_normalize_inplace_ref) { + test_vec_znx_normalize_inplace(vec_znx_normalize_base2k_ref, vec_znx_normalize_base2k_tmp_bytes_ref); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vmp_product_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vmp_product_test.cpp new file mode 100644 index 0000000..0eeed1f --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_vmp_product_test.cpp @@ -0,0 +1,180 @@ +#include + +#include "../spqlios/arithmetic/vec_znx_arithmetic_private.h" +#include "../spqlios/reim4/reim4_arithmetic.h" +#include "testlib/fft64_layouts.h" +#include "testlib/polynomial_vector.h" + +static void test_vmp_prepare_contiguous(VMP_PREPARE_CONTIGUOUS_F* prepare_contiguous, + VMP_PREPARE_TMP_BYTES_F* tmp_bytes) { + // tests when n < 8 + for (uint64_t nn : {2, 4}) { + MODULE* module = new_module_info(nn, FFT64); + for (uint64_t nrows : {1, 2, 5}) { + for (uint64_t ncols : {2, 6, 7}) { + znx_vec_i64_layout mat(nn, nrows * ncols, nn); + fft64_vmp_pmat_layout pmat(nn, nrows, ncols); + mat.fill_random(30); + std::vector tmp_space(fft64_vmp_prepare_tmp_bytes(module, nn, nrows, ncols)); + thash hash_before = mat.content_hash(); + prepare_contiguous(module, pmat.data, mat.data(), nrows, ncols, tmp_space.data()); + ASSERT_EQ(mat.content_hash(), hash_before); + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + const double* pmatv = (double*)pmat.data + (col * nrows + row) * nn; + reim_fft64vec tmp = simple_fft64(mat.get_copy(row * ncols + col)); + const double* tmpv = tmp.data(); + for (uint64_t i = 0; i < nn; ++i) { + ASSERT_LE(abs(pmatv[i] - tmpv[i]), 1e-10); + } + } + } + } + } + delete_module_info(module); + } + // tests when n >= 8 + for (uint64_t nn : {8, 32}) { + MODULE* module = new_module_info(nn, FFT64); + uint64_t nblk = nn / 8; + for (uint64_t nrows : {1, 2, 5}) { + for (uint64_t ncols : {2, 6, 7}) { + znx_vec_i64_layout mat(nn, nrows * ncols, nn); + fft64_vmp_pmat_layout pmat(nn, nrows, ncols); + mat.fill_random(30); + std::vector tmp_space(tmp_bytes(module, nn, nrows, ncols)); + thash hash_before = mat.content_hash(); + prepare_contiguous(module, pmat.data, mat.data(), nrows, ncols, tmp_space.data()); + ASSERT_EQ(mat.content_hash(), hash_before); + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + reim_fft64vec tmp = simple_fft64(mat.get_copy(row * ncols + col)); + for (uint64_t blk = 0; blk < nblk; ++blk) { + reim4_elem expect = 
tmp.get_blk(blk); + reim4_elem actual = pmat.get(row, col, blk); + ASSERT_LE(infty_dist(actual, expect), 1e-10); + } + } + } + } + } + delete_module_info(module); + } +} + +TEST(vec_znx, vmp_prepare_contiguous) { test_vmp_prepare_contiguous(vmp_prepare_contiguous, vmp_prepare_tmp_bytes); } +TEST(vec_znx, fft64_vmp_prepare_contiguous_ref) { + test_vmp_prepare_contiguous(fft64_vmp_prepare_contiguous_ref, fft64_vmp_prepare_tmp_bytes); +} +#ifdef __x86_64__ +TEST(vec_znx, fft64_vmp_prepare_contiguous_avx) { + test_vmp_prepare_contiguous(fft64_vmp_prepare_contiguous_avx, fft64_vmp_prepare_tmp_bytes); +} +#endif + +static void test_vmp_apply_add(VMP_APPLY_DFT_TO_DFT_ADD_F* apply, VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* tmp_bytes) { + for (uint64_t nn : {32}) { + MODULE* module = new_module_info(nn, FFT64); + for (uint64_t mat_nrows : {1, 4, 7}) { + for (uint64_t mat_ncols : {1, 2, 5}) { + for (uint64_t in_size : {1, 4, 7}) { + for (uint64_t out_size : {1, 2, 5}) { + for (uint64_t mat_scale : {1, 2}) { + fft64_vec_znx_dft_layout in(nn, in_size); + fft64_vmp_pmat_layout pmat(nn, mat_nrows, mat_ncols); + fft64_vec_znx_dft_layout out(nn, out_size); + in.fill_dft_random(0); + pmat.fill_dft_random(0); + out.fill_dft_random(0); + + // naive computation of the product + std::vector expect(out_size, reim_fft64vec(nn)); + + for (uint64_t col = 0; col < std::min(uint64_t(0), out_size - mat_scale); ++col) { + reim_fft64vec ex = reim_fft64vec::zero(nn); + for (uint64_t row = 0; row < std::min(mat_nrows, in_size); ++row) { + ex += pmat.get_zext(row, col + mat_scale) * in.get_copy_zext(row); + } + expect[col] = ex; + } + + for (uint64_t col = 0; col < out_size; ++col) { + expect[col] += out.get_copy_zext(col); + } + + // apply the product + std::vector tmp(tmp_bytes(module, nn, out_size, in_size, mat_nrows, mat_ncols)); + apply(module, out.data, out_size, in.data, in_size, pmat.data, mat_nrows, mat_ncols, mat_scale, + tmp.data()); + // check that the output is close from the expectation + for (uint64_t col = 0; col < std::min(uint64_t(0), std::min(mat_ncols, out_size) - mat_scale); ++col) { + reim_fft64vec actual = out.get_copy_zext(col); + ASSERT_LE(infty_dist(actual, expect[col]), 1e-10); + } + } + } + } + } + } + delete_module_info(module); + } +} + +TEST(vec_znx, vmp_apply_to_dft_add) { test_vmp_apply_add(vmp_apply_dft_to_dft_add, vmp_apply_dft_to_dft_tmp_bytes); } +TEST(vec_znx, fft64_vmp_apply_dft_to_dft_add_ref) { + test_vmp_apply_add(fft64_vmp_apply_dft_to_dft_add_ref, vmp_apply_dft_to_dft_tmp_bytes); +} +#ifdef __x86_64__ +TEST(vec_znx, fft64_vmp_apply_dft_to_dft_add_avx) { + test_vmp_apply_add(fft64_vmp_apply_dft_to_dft_add_avx, vmp_apply_dft_to_dft_tmp_bytes); +} +#endif + +static void test_vmp_apply(VMP_APPLY_DFT_TO_DFT_F* apply, VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* tmp_bytes) { + for (uint64_t nn : {2, 4, 8, 64}) { + MODULE* module = new_module_info(nn, FFT64); + for (uint64_t mat_nrows : {1, 4, 7}) { + for (uint64_t mat_ncols : {1, 2, 5}) { + for (uint64_t in_size : {1, 4, 7}) { + for (uint64_t out_size : {1, 2, 5}) { + fft64_vec_znx_dft_layout in(nn, in_size); + fft64_vmp_pmat_layout pmat(nn, mat_nrows, mat_ncols); + fft64_vec_znx_dft_layout out(nn, out_size); + in.fill_dft_random(0); + pmat.fill_dft_random(0); + + // naive computation of the product + std::vector expect(out_size, reim_fft64vec(nn)); + for (uint64_t col = 0; col < out_size; ++col) { + reim_fft64vec ex = reim_fft64vec::zero(nn); + for (uint64_t row = 0; row < std::min(mat_nrows, in_size); ++row) { + ex += pmat.get_zext(row, col) * 
in.get_copy_zext(row); + } + expect[col] = ex; + } + + // apply the product + std::vector tmp(tmp_bytes(module, nn, out_size, in_size, mat_nrows, mat_ncols)); + apply(module, out.data, out_size, in.data, in_size, pmat.data, mat_nrows, mat_ncols, tmp.data()); + // check that the output is close from the expectation + for (uint64_t col = 0; col < std::min(mat_ncols, out_size); ++col) { + reim_fft64vec actual = out.get_copy_zext(col); + ASSERT_LE(infty_dist(actual, expect[col]), 1e-10); + } + } + } + } + } + delete_module_info(module); + } +} + +TEST(vec_znx, vmp_apply_to_dft) { test_vmp_apply(vmp_apply_dft_to_dft, vmp_apply_dft_to_dft_tmp_bytes); } +TEST(vec_znx, fft64_vmp_apply_dft_to_dft_ref) { + test_vmp_apply(fft64_vmp_apply_dft_to_dft_ref, vmp_apply_dft_to_dft_tmp_bytes); +} +#ifdef __x86_64__ +TEST(vec_znx, fft64_vmp_apply_dft_to_dft_avx) { + test_vmp_apply(fft64_vmp_apply_dft_to_dft_avx, fft64_vmp_apply_dft_to_dft_tmp_bytes); +} +#endif \ No newline at end of file diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_zn_approxdecomp_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_zn_approxdecomp_test.cpp new file mode 100644 index 0000000..d21f420 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_zn_approxdecomp_test.cpp @@ -0,0 +1,46 @@ +#include "gtest/gtest.h" +#include "spqlios/arithmetic/zn_arithmetic_private.h" +#include "testlib/test_commons.h" + +template +static void test_tndbl_approxdecomp( // + void (*approxdec)(const MOD_Z*, const TNDBL_APPROXDECOMP_GADGET*, INTTYPE*, uint64_t, const double*, uint64_t) // +) { + for (const uint64_t nn : {1, 3, 8, 51}) { + MOD_Z* module = new_z_module_info(DEFAULT); + for (const uint64_t ell : {1, 2, 7}) { + for (const uint64_t k : {2, 5}) { + TNDBL_APPROXDECOMP_GADGET* gadget = new_tndbl_approxdecomp_gadget(module, k, ell); + for (const uint64_t res_size : {ell * nn}) { + std::vector in(nn); + std::vector out(res_size); + for (double& x : in) x = uniform_f64_bounds(-10, 10); + approxdec(module, gadget, out.data(), res_size, in.data(), nn); + // reconstruct the output + double err_bnd = pow(2., -double(ell * k) - 1); + for (uint64_t j = 0; j < nn; ++j) { + double in_j = in[j]; + double out_j = 0; + for (uint64_t i = 0; i < ell; ++i) { + out_j += out[ell * j + i] * pow(2., -double((i + 1) * k)); + } + double err = out_j - in_j; + double err_abs = fabs(err - rint(err)); + ASSERT_LE(err_abs, err_bnd); + } + } + delete_tndbl_approxdecomp_gadget(gadget); + } + } + delete_z_module_info(module); + } +} + +TEST(vec_rnx, i8_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(i8_approxdecomp_from_tndbl); } +TEST(vec_rnx, default_i8_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(default_i8_approxdecomp_from_tndbl_ref); } + +TEST(vec_rnx, i16_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(i16_approxdecomp_from_tndbl); } +TEST(vec_rnx, default_i16_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(default_i16_approxdecomp_from_tndbl_ref); } + +TEST(vec_rnx, i32_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(i32_approxdecomp_from_tndbl); } +TEST(vec_rnx, default_i32_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(default_i32_approxdecomp_from_tndbl_ref); } diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_zn_conversions_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_zn_conversions_test.cpp new file mode 100644 index 0000000..da2b94b --- /dev/null +++ 
b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_zn_conversions_test.cpp @@ -0,0 +1,104 @@ +#include +#include + +#include "testlib/test_commons.h" + +template +static void test_conv(void (*conv_f)(const MOD_Z*, DST_T* res, uint64_t res_size, const SRC_T* a, uint64_t a_size), + DST_T (*ideal_conv_f)(SRC_T x), SRC_T (*random_f)()) { + MOD_Z* module = new_z_module_info(DEFAULT); + for (uint64_t a_size : {0, 1, 2, 42}) { + for (uint64_t res_size : {0, 1, 2, 42}) { + for (uint64_t trials = 0; trials < 100; ++trials) { + std::vector a(a_size); + std::vector res(res_size); + uint64_t msize = std::min(a_size, res_size); + for (SRC_T& x : a) x = random_f(); + conv_f(module, res.data(), res_size, a.data(), a_size); + for (uint64_t i = 0; i < msize; ++i) { + DST_T expect = ideal_conv_f(a[i]); + DST_T actual = res[i]; + ASSERT_EQ(expect, actual); + } + for (uint64_t i = msize; i < res_size; ++i) { + DST_T expect = 0; + SRC_T actual = res[i]; + ASSERT_EQ(expect, actual); + } + } + } + } + delete_z_module_info(module); +} + +static int32_t ideal_dbl_to_tn32(double a) { + double _2p32 = INT64_C(1) << 32; + double a_mod_1 = a - rint(a); + int64_t t = rint(a_mod_1 * _2p32); + return int32_t(t); +} + +static double random_f64_10() { return uniform_f64_bounds(-10, 10); } + +static void test_dbl_to_tn32(DBL_TO_TN32_F dbl_to_tn32_f) { + test_conv(dbl_to_tn32_f, ideal_dbl_to_tn32, random_f64_10); +} + +TEST(zn_arithmetic, dbl_to_tn32) { test_dbl_to_tn32(dbl_to_tn32); } +TEST(zn_arithmetic, dbl_to_tn32_ref) { test_dbl_to_tn32(dbl_to_tn32_ref); } + +static double ideal_tn32_to_dbl(int32_t a) { + const double _2p32 = INT64_C(1) << 32; + return double(a) / _2p32; +} + +static int32_t random_t32() { return uniform_i64_bits(32); } + +static void test_tn32_to_dbl(TN32_TO_DBL_F tn32_to_dbl_f) { test_conv(tn32_to_dbl_f, ideal_tn32_to_dbl, random_t32); } + +TEST(zn_arithmetic, tn32_to_dbl) { test_tn32_to_dbl(tn32_to_dbl); } +TEST(zn_arithmetic, tn32_to_dbl_ref) { test_tn32_to_dbl(tn32_to_dbl_ref); } + +static int32_t ideal_dbl_round_to_i32(double a) { return int32_t(rint(a)); } + +static double random_dbl_explaw_18() { return uniform_f64_bounds(-1., 1.) * pow(2., uniform_u64_bits(6) % 19); } + +static void test_dbl_round_to_i32(DBL_ROUND_TO_I32_F dbl_round_to_i32_f) { + test_conv(dbl_round_to_i32_f, ideal_dbl_round_to_i32, random_dbl_explaw_18); +} + +TEST(zn_arithmetic, dbl_round_to_i32) { test_dbl_round_to_i32(dbl_round_to_i32); } +TEST(zn_arithmetic, dbl_round_to_i32_ref) { test_dbl_round_to_i32(dbl_round_to_i32_ref); } + +static double ideal_i32_to_dbl(int32_t a) { return double(a); } + +static int32_t random_i32_explaw_18() { return uniform_i64_bits(uniform_u64_bits(6) % 19); } + +static void test_i32_to_dbl(I32_TO_DBL_F i32_to_dbl_f) { + test_conv(i32_to_dbl_f, ideal_i32_to_dbl, random_i32_explaw_18); +} + +TEST(zn_arithmetic, i32_to_dbl) { test_i32_to_dbl(i32_to_dbl); } +TEST(zn_arithmetic, i32_to_dbl_ref) { test_i32_to_dbl(i32_to_dbl_ref); } + +static int64_t ideal_dbl_round_to_i64(double a) { return rint(a); } + +static double random_dbl_explaw_50() { return uniform_f64_bounds(-1., 1.) 
* pow(2., uniform_u64_bits(7) % 51); } + +static void test_dbl_round_to_i64(DBL_ROUND_TO_I64_F dbl_round_to_i64_f) { + test_conv(dbl_round_to_i64_f, ideal_dbl_round_to_i64, random_dbl_explaw_50); +} + +TEST(zn_arithmetic, dbl_round_to_i64) { test_dbl_round_to_i64(dbl_round_to_i64); } +TEST(zn_arithmetic, dbl_round_to_i64_ref) { test_dbl_round_to_i64(dbl_round_to_i64_ref); } + +static double ideal_i64_to_dbl(int64_t a) { return double(a); } + +static int64_t random_i64_explaw_50() { return uniform_i64_bits(uniform_u64_bits(7) % 51); } + +static void test_i64_to_dbl(I64_TO_DBL_F i64_to_dbl_f) { + test_conv(i64_to_dbl_f, ideal_i64_to_dbl, random_i64_explaw_50); +} + +TEST(zn_arithmetic, i64_to_dbl) { test_i64_to_dbl(i64_to_dbl); } +TEST(zn_arithmetic, i64_to_dbl_ref) { test_i64_to_dbl(i64_to_dbl_ref); } diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_zn_vmp_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_zn_vmp_test.cpp new file mode 100644 index 0000000..57b0ad0 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_zn_vmp_test.cpp @@ -0,0 +1,94 @@ +#include "gtest/gtest.h" +#include "spqlios/arithmetic/zn_arithmetic_private.h" +#include "testlib/zn_layouts.h" + +static void test_zn_vmp_prepare(ZN32_VMP_PREPARE_CONTIGUOUS_F prepare_contiguous) { + MOD_Z* module = new_z_module_info(DEFAULT); + for (uint64_t nrows : {1, 2, 5, 15}) { + for (uint64_t ncols : {1, 2, 32, 42, 67}) { + std::vector src(nrows * ncols); + zn32_pmat_layout out(nrows, ncols); + for (int32_t& x : src) x = uniform_i64_bits(32); + prepare_contiguous(module, out.data, src.data(), nrows, ncols); + for (uint64_t i = 0; i < nrows; ++i) { + for (uint64_t j = 0; j < ncols; ++j) { + int32_t in = src[i * ncols + j]; + int32_t actual = out.get(i, j); + ASSERT_EQ(actual, in); + } + } + } + } + delete_z_module_info(module); +} + +TEST(zn, zn32_vmp_prepare_contiguous) { test_zn_vmp_prepare(zn32_vmp_prepare_contiguous); } +TEST(zn, default_zn32_vmp_prepare_contiguous_ref) { test_zn_vmp_prepare(default_zn32_vmp_prepare_contiguous_ref); } + +static void test_zn_vmp_prepare(ZN32_VMP_PREPARE_DBLPTR_F prepare_dblptr) { + MOD_Z* module = new_z_module_info(DEFAULT); + for (uint64_t nrows : {1, 2, 5, 15}) { + for (uint64_t ncols : {1, 2, 32, 42, 67}) { + std::vector src(nrows * ncols); + zn32_pmat_layout out(nrows, ncols); + for (int32_t& x : src) x = uniform_i64_bits(32); + const int32_t** mat_dblptr = (const int32_t**)malloc(nrows * sizeof(int32_t*)); + for (size_t row_i = 0; row_i < nrows; row_i++) { + mat_dblptr[row_i] = &src.data()[row_i * ncols]; + }; + prepare_dblptr(module, out.data, mat_dblptr, nrows, ncols); + for (uint64_t i = 0; i < nrows; ++i) { + for (uint64_t j = 0; j < ncols; ++j) { + int32_t in = src[i * ncols + j]; + int32_t actual = out.get(i, j); + ASSERT_EQ(actual, in); + } + } + } + } + delete_z_module_info(module); +} + +TEST(zn, zn32_vmp_prepare_dblptr) { test_zn_vmp_prepare(zn32_vmp_prepare_dblptr); } +TEST(zn, default_zn32_vmp_prepare_dblptr_ref) { test_zn_vmp_prepare(default_zn32_vmp_prepare_dblptr_ref); } + +template +static void test_zn_vmp_apply(void (*apply)(const MOD_Z*, int32_t*, uint64_t, const INTTYPE*, uint64_t, + const ZN32_VMP_PMAT*, uint64_t, uint64_t)) { + MOD_Z* module = new_z_module_info(DEFAULT); + for (uint64_t nrows : {1, 2, 5, 15}) { + for (uint64_t ncols : {1, 2, 32, 42, 67}) { + for (uint64_t a_size : {1, 2, 5, 15}) { + for (uint64_t res_size : {1, 2, 32, 42, 67}) { + std::vector a(a_size); + zn32_pmat_layout 
out(nrows, ncols); + std::vector res(res_size); + for (INTTYPE& x : a) x = uniform_i64_bits(32); + out.fill_random(); + std::vector expect = vmp_product(a.data(), a_size, res_size, out); + apply(module, res.data(), res_size, a.data(), a_size, out.data, nrows, ncols); + for (uint64_t i = 0; i < res_size; ++i) { + int32_t exp = expect[i]; + int32_t actual = res[i]; + ASSERT_EQ(actual, exp); + } + } + } + } + } + delete_z_module_info(module); +} + +TEST(zn, zn32_vmp_apply_i32) { test_zn_vmp_apply(zn32_vmp_apply_i32); } +TEST(zn, zn32_vmp_apply_i16) { test_zn_vmp_apply(zn32_vmp_apply_i16); } +TEST(zn, zn32_vmp_apply_i8) { test_zn_vmp_apply(zn32_vmp_apply_i8); } + +TEST(zn, default_zn32_vmp_apply_i32_ref) { test_zn_vmp_apply(default_zn32_vmp_apply_i32_ref); } +TEST(zn, default_zn32_vmp_apply_i16_ref) { test_zn_vmp_apply(default_zn32_vmp_apply_i16_ref); } +TEST(zn, default_zn32_vmp_apply_i8_ref) { test_zn_vmp_apply(default_zn32_vmp_apply_i8_ref); } + +#ifdef __x86_64__ +TEST(zn, default_zn32_vmp_apply_i32_avx) { test_zn_vmp_apply(default_zn32_vmp_apply_i32_avx); } +TEST(zn, default_zn32_vmp_apply_i16_avx) { test_zn_vmp_apply(default_zn32_vmp_apply_i16_avx); } +TEST(zn, default_zn32_vmp_apply_i8_avx) { test_zn_vmp_apply(default_zn32_vmp_apply_i8_avx); } +#endif diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_znx_small_test.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_znx_small_test.cpp new file mode 100644 index 0000000..2f5e9bf --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/spqlios_znx_small_test.cpp @@ -0,0 +1,26 @@ +#include + +#include "../spqlios/arithmetic/vec_znx_arithmetic_private.h" +#include "testlib/negacyclic_polynomial.h" + +static void test_znx_small_single_product(ZNX_SMALL_SINGLE_PRODUCT_F product, + ZNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F product_tmp_bytes) { + for (const uint64_t nn : {2, 4, 8, 64}) { + MODULE* module = new_module_info(nn, FFT64); + znx_i64 a = znx_i64::random_log2bound(nn, 20); + znx_i64 b = znx_i64::random_log2bound(nn, 20); + znx_i64 expect = naive_product(a, b); + znx_i64 actual(nn); + std::vector tmp(znx_small_single_product_tmp_bytes(module, nn)); + fft64_znx_small_single_product(module, actual.data(), a.data(), b.data(), tmp.data()); + ASSERT_EQ(actual, expect) << actual.get_coeff(0) << " vs. 
" << expect.get_coeff(0); + delete_module_info(module); + } +} + +TEST(znx_small, fft64_znx_small_single_product) { + test_znx_small_single_product(fft64_znx_small_single_product, fft64_znx_small_single_product_tmp_bytes); +} +TEST(znx_small, znx_small_single_product) { + test_znx_small_single_product(znx_small_single_product, znx_small_single_product_tmp_bytes); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/fft64_dft.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/fft64_dft.cpp new file mode 100644 index 0000000..e66b680 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/fft64_dft.cpp @@ -0,0 +1,168 @@ +#include "fft64_dft.h" + +#include + +#include "../../spqlios/reim/reim_fft.h" +#include "../../spqlios/reim/reim_fft_internal.h" + +reim_fft64vec::reim_fft64vec(uint64_t n) : v(n, 0) {} +reim4_elem reim_fft64vec::get_blk(uint64_t blk) const { + return reim_view(v.size() / 2, (double*)v.data()).get_blk(blk); +} +double* reim_fft64vec::data() { return v.data(); } +const double* reim_fft64vec::data() const { return v.data(); } +uint64_t reim_fft64vec::nn() const { return v.size(); } +reim_fft64vec::reim_fft64vec(uint64_t n, const double* data) : v(data, data + n) {} +void reim_fft64vec::save_as(double* dest) const { memcpy(dest, v.data(), nn() * sizeof(double)); } +reim_fft64vec reim_fft64vec::zero(uint64_t n) { return reim_fft64vec(n); } +void reim_fft64vec::set_blk(uint64_t blk, const reim4_elem& value) { + reim_view(v.size() / 2, (double*)v.data()).set_blk(blk, value); +} +reim_fft64vec reim_fft64vec::dft_random(uint64_t n, uint64_t log2bound) { + return simple_fft64(znx_i64::random_log2bound(n, log2bound)); +} +reim_fft64vec reim_fft64vec::random(uint64_t n, double log2bound) { + double bound = pow(2., log2bound); + reim_fft64vec res(n); + for (uint64_t i = 0; i < n; ++i) { + res.v[i] = uniform_f64_bounds(-bound, bound); + } + return res; +} + +reim_fft64vec operator+(const reim_fft64vec& a, const reim_fft64vec& b) { + uint64_t nn = a.nn(); + REQUIRE_DRAMATICALLY(b.nn() == a.nn(), "ring dimension mismatch"); + reim_fft64vec res(nn); + double* rv = res.data(); + const double* av = a.data(); + const double* bv = b.data(); + for (uint64_t i = 0; i < nn; ++i) { + rv[i] = av[i] + bv[i]; + } + return res; +} +reim_fft64vec operator-(const reim_fft64vec& a, const reim_fft64vec& b) { + uint64_t nn = a.nn(); + REQUIRE_DRAMATICALLY(b.nn() == a.nn(), "ring dimension mismatch"); + reim_fft64vec res(nn); + double* rv = res.data(); + const double* av = a.data(); + const double* bv = b.data(); + for (uint64_t i = 0; i < nn; ++i) { + rv[i] = av[i] - bv[i]; + } + return res; +} +reim_fft64vec operator*(const reim_fft64vec& a, const reim_fft64vec& b) { + uint64_t nn = a.nn(); + REQUIRE_DRAMATICALLY(b.nn() == a.nn(), "ring dimension mismatch"); + REQUIRE_DRAMATICALLY(nn >= 2, "test not defined for nn=1"); + uint64_t m = nn / 2; + reim_fft64vec res(nn); + double* rv = res.data(); + const double* av = a.data(); + const double* bv = b.data(); + for (uint64_t i = 0; i < m; ++i) { + rv[i] = av[i] * bv[i] - av[m + i] * bv[m + i]; + rv[m + i] = av[i] * bv[m + i] + av[m + i] * bv[i]; + } + return res; +} +reim_fft64vec& operator+=(reim_fft64vec& a, const reim_fft64vec& b) { + uint64_t nn = a.nn(); + REQUIRE_DRAMATICALLY(b.nn() == a.nn(), "ring dimension mismatch"); + double* av = a.data(); + const double* bv = b.data(); + for (uint64_t i = 0; i < nn; ++i) { + av[i] = av[i] + bv[i]; + } + return a; +} +reim_fft64vec& 
operator-=(reim_fft64vec& a, const reim_fft64vec& b) { + uint64_t nn = a.nn(); + REQUIRE_DRAMATICALLY(b.nn() == a.nn(), "ring dimension mismatch"); + double* av = a.data(); + const double* bv = b.data(); + for (uint64_t i = 0; i < nn; ++i) { + av[i] = av[i] - bv[i]; + } + return a; +} + +reim_fft64vec simple_fft64(const znx_i64& polynomial) { + const uint64_t nn = polynomial.nn(); + const uint64_t m = nn / 2; + reim_fft64vec res(nn); + double* dat = res.data(); + for (uint64_t i = 0; i < nn; ++i) dat[i] = polynomial.get_coeff(i); + reim_fft_simple(m, dat); + return res; +} + +znx_i64 simple_rint_ifft64(const reim_fft64vec& fftvec) { + const uint64_t nn = fftvec.nn(); + const uint64_t m = nn / 2; + std::vector vv(fftvec.data(), fftvec.data() + nn); + double* v = vv.data(); + reim_ifft_simple(m, v); + znx_i64 res(nn); + for (uint64_t i = 0; i < nn; ++i) { + res.set_coeff(i, rint(v[i] / m)); + } + return res; +} + +rnx_f64 naive_ifft64(const reim_fft64vec& fftvec) { + const uint64_t nn = fftvec.nn(); + const uint64_t m = nn / 2; + std::vector vv(fftvec.data(), fftvec.data() + nn); + double* v = vv.data(); + reim_ifft_simple(m, v); + rnx_f64 res(nn); + for (uint64_t i = 0; i < nn; ++i) { + res.set_coeff(i, v[i] / m); + } + return res; +} +double infty_dist(const reim_fft64vec& a, const reim_fft64vec& b) { + const uint64_t n = a.nn(); + REQUIRE_DRAMATICALLY(b.nn() == a.nn(), "dimensions mismatch"); + const double* da = a.data(); + const double* db = b.data(); + double d = 0; + for (uint64_t i = 0; i < n; ++i) { + double di = abs(da[i] - db[i]); + if (di > d) d = di; + } + return d; +} + +reim_fft64vec simple_fft64(const rnx_f64& polynomial) { + const uint64_t nn = polynomial.nn(); + const uint64_t m = nn / 2; + reim_fft64vec res(nn); + double* dat = res.data(); + for (uint64_t i = 0; i < nn; ++i) dat[i] = polynomial.get_coeff(i); + reim_fft_simple(m, dat); + return res; +} + +reim_fft64vec operator*(double coeff, const reim_fft64vec& v) { + const uint64_t nn = v.nn(); + reim_fft64vec res(nn); + double* rr = res.data(); + const double* vv = v.data(); + for (uint64_t i = 0; i < nn; ++i) rr[i] = coeff * vv[i]; + return res; +} + +rnx_f64 simple_ifft64(const reim_fft64vec& v) { + const uint64_t nn = v.nn(); + const uint64_t m = nn / 2; + rnx_f64 res(nn); + double* dat = res.data(); + memcpy(dat, v.data(), nn * sizeof(double)); + reim_ifft_simple(m, dat); + return res; +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/fft64_dft.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/fft64_dft.h new file mode 100644 index 0000000..32ee437 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/fft64_dft.h @@ -0,0 +1,43 @@ +#ifndef SPQLIOS_FFT64_DFT_H +#define SPQLIOS_FFT64_DFT_H + +#include "negacyclic_polynomial.h" +#include "reim4_elem.h" + +class reim_fft64vec { + std::vector v; + + public: + reim_fft64vec() = default; + explicit reim_fft64vec(uint64_t n); + reim_fft64vec(uint64_t n, const double* data); + uint64_t nn() const; + static reim_fft64vec zero(uint64_t n); + /** random complex coefficients (unstructured) */ + static reim_fft64vec random(uint64_t n, double log2bound); + /** random fft of a small int polynomial */ + static reim_fft64vec dft_random(uint64_t n, uint64_t log2bound); + double* data(); + const double* data() const; + void save_as(double* dest) const; + reim4_elem get_blk(uint64_t blk) const; + void set_blk(uint64_t blk, const reim4_elem& value); +}; + +reim_fft64vec operator+(const reim_fft64vec& a, const 
reim_fft64vec& b); +reim_fft64vec operator-(const reim_fft64vec& a, const reim_fft64vec& b); +reim_fft64vec operator*(const reim_fft64vec& a, const reim_fft64vec& b); +reim_fft64vec operator*(double coeff, const reim_fft64vec& v); +reim_fft64vec& operator+=(reim_fft64vec& a, const reim_fft64vec& b); +reim_fft64vec& operator-=(reim_fft64vec& a, const reim_fft64vec& b); + +/** infty distance */ +double infty_dist(const reim_fft64vec& a, const reim_fft64vec& b); + +reim_fft64vec simple_fft64(const znx_i64& polynomial); +znx_i64 simple_rint_ifft64(const reim_fft64vec& fftvec); +rnx_f64 naive_ifft64(const reim_fft64vec& fftvec); +reim_fft64vec simple_fft64(const rnx_f64& polynomial); +rnx_f64 simple_ifft64(const reim_fft64vec& v); + +#endif // SPQLIOS_FFT64_DFT_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/fft64_layouts.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/fft64_layouts.cpp new file mode 100644 index 0000000..a8976b6 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/fft64_layouts.cpp @@ -0,0 +1,238 @@ +#include "fft64_layouts.h" +#ifdef VALGRIND_MEM_TESTS +#include "valgrind/memcheck.h" +#endif + +void* alloc64(uint64_t size) { + static uint64_t _msk64 = -64; + if (size == 0) return nullptr; + uint64_t rsize = (size + 63) & _msk64; + uint8_t* reps = (uint8_t*)spqlios_alloc(rsize); + REQUIRE_DRAMATICALLY(reps != 0, "Out of memory"); +#ifdef VALGRIND_MEM_TESTS + VALGRIND_MAKE_MEM_NOACCESS(reps + size, rsize - size); +#endif + return reps; +} + +fft64_vec_znx_dft_layout::fft64_vec_znx_dft_layout(uint64_t n, uint64_t size) + : nn(n), // + size(size), // + data((VEC_ZNX_DFT*)alloc64(n * size * 8)), // + view(n / 2, size, (double*)data) {} + +fft64_vec_znx_dft_layout::~fft64_vec_znx_dft_layout() { spqlios_free(data); } + +double* fft64_vec_znx_dft_layout::get_addr(uint64_t idx) { + REQUIRE_DRAMATICALLY(idx < size, "index overflow " << idx << " / " << size); + return ((double*)data) + idx * nn; +} +const double* fft64_vec_znx_dft_layout::get_addr(uint64_t idx) const { + REQUIRE_DRAMATICALLY(idx < size, "index overflow " << idx << " / " << size); + return ((double*)data) + idx * nn; +} +reim_fft64vec fft64_vec_znx_dft_layout::get_copy_zext(uint64_t idx) const { + if (idx < size) { + return reim_fft64vec(nn, get_addr(idx)); + } else { + return reim_fft64vec::zero(nn); + } +} +void fft64_vec_znx_dft_layout::fill_dft_random_log2bound(uint64_t bits) { + for (uint64_t i = 0; i < size; ++i) { + set(i, simple_fft64(znx_i64::random_log2bound(nn, bits))); + } +} +void fft64_vec_znx_dft_layout::set(uint64_t idx, const reim_fft64vec& value) { + REQUIRE_DRAMATICALLY(value.nn() == nn, "ring dimension mismatch"); + value.save_as(get_addr(idx)); +} +thash fft64_vec_znx_dft_layout::content_hash() const { return test_hash(data, size * nn * sizeof(double)); } + +reim4_elem fft64_vec_znx_dft_layout::get(uint64_t idx, uint64_t blk) const { + REQUIRE_DRAMATICALLY(idx < size, "index overflow: " << idx << " / " << size); + REQUIRE_DRAMATICALLY(blk < nn / 8, "blk overflow: " << blk << " / " << nn / 8); + double* reim = ((double*)data) + idx * nn; + return reim4_elem(reim + blk * 4, reim + nn / 2 + blk * 4); +} +reim4_elem fft64_vec_znx_dft_layout::get_zext(uint64_t idx, uint64_t blk) const { + REQUIRE_DRAMATICALLY(blk < nn / 8, "blk overflow: " << blk << " / " << nn / 8); + if (idx < size) { + return get(idx, blk); + } else { + return reim4_elem::zero(); + } +} +void fft64_vec_znx_dft_layout::set(uint64_t idx, uint64_t blk, const 
reim4_elem& value) { + REQUIRE_DRAMATICALLY(idx < size, "index overflow: " << idx << " / " << size); + REQUIRE_DRAMATICALLY(blk < nn / 8, "blk overflow: " << blk << " / " << nn / 8); + double* reim = ((double*)data) + idx * nn; + value.save_re_im(reim + blk * 4, reim + nn / 2 + blk * 4); +} +void fft64_vec_znx_dft_layout::fill_random(double log2bound) { + for (uint64_t i = 0; i < size; ++i) { + set(i, reim_fft64vec::random(nn, log2bound)); + } +} +void fft64_vec_znx_dft_layout::fill_dft_random(uint64_t log2bound) { + for (uint64_t i = 0; i < size; ++i) { + set(i, reim_fft64vec::dft_random(nn, log2bound)); + } +} + +fft64_vec_znx_big_layout::fft64_vec_znx_big_layout(uint64_t n, uint64_t size) + : nn(n), // + size(size), // + data((VEC_ZNX_BIG*)alloc64(n * size * 8)) {} + +znx_i64 fft64_vec_znx_big_layout::get_copy(uint64_t index) const { + REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size); + return znx_i64(nn, ((int64_t*)data) + index * nn); +} +znx_i64 fft64_vec_znx_big_layout::get_copy_zext(uint64_t index) const { + if (index < size) { + return znx_i64(nn, ((int64_t*)data) + index * nn); + } else { + return znx_i64::zero(nn); + } +} +void fft64_vec_znx_big_layout::set(uint64_t index, const znx_i64& value) { + REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size); + value.save_as(((int64_t*)data) + index * nn); +} +void fft64_vec_znx_big_layout::fill_random() { + for (uint64_t i = 0; i < size; ++i) { + set(i, znx_i64::random_log2bound(nn, 1)); + } +} +fft64_vec_znx_big_layout::~fft64_vec_znx_big_layout() { spqlios_free(data); } + +fft64_vmp_pmat_layout::fft64_vmp_pmat_layout(uint64_t n, uint64_t nrows, uint64_t ncols) + : nn(n), + nrows(nrows), + ncols(ncols), // + data((VMP_PMAT*)alloc64(nrows * ncols * nn * 8)) {} + +double* fft64_vmp_pmat_layout::get_addr(uint64_t row, uint64_t col, uint64_t blk) const { + REQUIRE_DRAMATICALLY(row < nrows, "row overflow: " << row << " / " << nrows); + REQUIRE_DRAMATICALLY(col < ncols, "col overflow: " << col << " / " << ncols); + REQUIRE_DRAMATICALLY(blk < nn / 8, "block overflow: " << blk << " / " << (nn / 8)); + double* d = (double*)data; + if (col == (ncols - 1) && (ncols % 2 == 1)) { + // special case: last column out of an odd column number + return d + blk * nrows * ncols * 8 // major: blk + + col * nrows * 8 // col == ncols-1 + + row * 8; + } else { + // general case: columns go by pair + return d + blk * nrows * ncols * 8 // major: blk + + (col / 2) * (2 * nrows) * 8 // second: col pair index + + row * 2 * 8 // third: row index + + (col % 2) * 8; // minor: col in colpair + } +} + +reim4_elem fft64_vmp_pmat_layout::get(uint64_t row, uint64_t col, uint64_t blk) const { + return reim4_elem(get_addr(row, col, blk)); +} +reim4_elem fft64_vmp_pmat_layout::get_zext(uint64_t row, uint64_t col, uint64_t blk) const { + REQUIRE_DRAMATICALLY(blk < nn / 8, "block overflow: " << blk << " / " << (nn / 8)); + if (row < nrows && col < ncols) { + return reim4_elem(get_addr(row, col, blk)); + } else { + return reim4_elem::zero(); + } +} +void fft64_vmp_pmat_layout::set(uint64_t row, uint64_t col, uint64_t blk, const reim4_elem& value) const { + value.save_as(get_addr(row, col, blk)); +} + +fft64_vmp_pmat_layout::~fft64_vmp_pmat_layout() { spqlios_free(data); } + +reim_fft64vec fft64_vmp_pmat_layout::get_zext(uint64_t row, uint64_t col) const { + if (row >= nrows || col >= ncols) { + return reim_fft64vec::zero(nn); + } + if (nn < 8) { + // the pmat is just col major + double* addr = (double*)data + (row + 
col * nrows) * nn; + return reim_fft64vec(nn, addr); + } + // otherwise, reconstruct it block by block + reim_fft64vec res(nn); + for (uint64_t blk = 0; blk < nn / 8; ++blk) { + reim4_elem v = get(row, col, blk); + res.set_blk(blk, v); + } + return res; +} +void fft64_vmp_pmat_layout::set(uint64_t row, uint64_t col, const reim_fft64vec& value) { + REQUIRE_DRAMATICALLY(row < nrows, "row overflow: " << row << " / " << nrows); + REQUIRE_DRAMATICALLY(col < ncols, "row overflow: " << col << " / " << ncols); + if (nn < 8) { + // the pmat is just col major + double* addr = (double*)data + (row + col * nrows) * nn; + value.save_as(addr); + return; + } + // otherwise, reconstruct it block by block + for (uint64_t blk = 0; blk < nn / 8; ++blk) { + reim4_elem v = value.get_blk(blk); + set(row, col, blk, v); + } +} +void fft64_vmp_pmat_layout::fill_random(double log2bound) { + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + set(row, col, reim_fft64vec::random(nn, log2bound)); + } + } +} +void fft64_vmp_pmat_layout::fill_dft_random(uint64_t log2bound) { + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + set(row, col, reim_fft64vec::dft_random(nn, log2bound)); + } + } +} + +fft64_svp_ppol_layout::fft64_svp_ppol_layout(uint64_t n) + : nn(n), // + data((SVP_PPOL*)alloc64(nn * 8)) {} + +reim_fft64vec fft64_svp_ppol_layout::get_copy() const { return reim_fft64vec(nn, (double*)data); } + +void fft64_svp_ppol_layout::set(const reim_fft64vec& value) { value.save_as((double*)data); } + +void fft64_svp_ppol_layout::fill_dft_random(uint64_t log2bound) { set(reim_fft64vec::dft_random(nn, log2bound)); } + +void fft64_svp_ppol_layout::fill_random(double log2bound) { set(reim_fft64vec::random(nn, log2bound)); } + +fft64_svp_ppol_layout::~fft64_svp_ppol_layout() { spqlios_free(data); } +thash fft64_svp_ppol_layout::content_hash() const { return test_hash(data, nn * sizeof(double)); } + +fft64_cnv_left_layout::fft64_cnv_left_layout(uint64_t n, uint64_t size) + : nn(n), // + size(size), + data((CNV_PVEC_L*)alloc64(size * nn * 8)) {} + +reim4_elem fft64_cnv_left_layout::get(uint64_t idx, uint64_t blk) { + REQUIRE_DRAMATICALLY(idx < size, "idx overflow: " << idx << " / " << size); + REQUIRE_DRAMATICALLY(blk < nn / 8, "block overflow: " << blk << " / " << (nn / 8)); + return reim4_elem(((double*)data) + blk * size + idx); +} + +fft64_cnv_left_layout::~fft64_cnv_left_layout() { spqlios_free(data); } + +fft64_cnv_right_layout::fft64_cnv_right_layout(uint64_t n, uint64_t size) + : nn(n), // + size(size), + data((CNV_PVEC_R*)alloc64(size * nn * 8)) {} + +reim4_elem fft64_cnv_right_layout::get(uint64_t idx, uint64_t blk) { + REQUIRE_DRAMATICALLY(idx < size, "idx overflow: " << idx << " / " << size); + REQUIRE_DRAMATICALLY(blk < nn / 8, "block overflow: " << blk << " / " << (nn / 8)); + return reim4_elem(((double*)data) + blk * size + idx); +} + +fft64_cnv_right_layout::~fft64_cnv_right_layout() { spqlios_free(data); } diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/fft64_layouts.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/fft64_layouts.h new file mode 100644 index 0000000..ba71448 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/fft64_layouts.h @@ -0,0 +1,109 @@ +#ifndef SPQLIOS_FFT64_LAYOUTS_H +#define SPQLIOS_FFT64_LAYOUTS_H + +#include "../../spqlios/arithmetic/vec_znx_arithmetic.h" +#include "fft64_dft.h" +#include "negacyclic_polynomial.h" +#include 
"reim4_elem.h" + +/** @brief test layout for the VEC_ZNX_DFT */ +struct fft64_vec_znx_dft_layout { + public: + const uint64_t nn; + const uint64_t size; + VEC_ZNX_DFT* const data; + reim_vector_view view; + /** @brief fill with random double values (unstructured) */ + void fill_random(double log2bound); + /** @brief fill with random ffts of small int polynomials */ + void fill_dft_random(uint64_t log2bound); + reim4_elem get(uint64_t idx, uint64_t blk) const; + reim4_elem get_zext(uint64_t idx, uint64_t blk) const; + void set(uint64_t idx, uint64_t blk, const reim4_elem& value); + fft64_vec_znx_dft_layout(uint64_t n, uint64_t size); + void fill_random_log2bound(uint64_t bits); + void fill_dft_random_log2bound(uint64_t bits); + double* get_addr(uint64_t idx); + const double* get_addr(uint64_t idx) const; + reim_fft64vec get_copy_zext(uint64_t idx) const; + void set(uint64_t idx, const reim_fft64vec& value); + thash content_hash() const; + ~fft64_vec_znx_dft_layout(); +}; + +/** @brief test layout for the VEC_ZNX_BIG */ +class fft64_vec_znx_big_layout { + public: + const uint64_t nn; + const uint64_t size; + VEC_ZNX_BIG* const data; + fft64_vec_znx_big_layout(uint64_t n, uint64_t size); + void fill_random(); + znx_i64 get_copy(uint64_t index) const; + znx_i64 get_copy_zext(uint64_t index) const; + void set(uint64_t index, const znx_i64& value); + thash content_hash() const; + ~fft64_vec_znx_big_layout(); +}; + +/** @brief test layout for the VMP_PMAT */ +class fft64_vmp_pmat_layout { + public: + const uint64_t nn; + const uint64_t nrows; + const uint64_t ncols; + VMP_PMAT* const data; + fft64_vmp_pmat_layout(uint64_t n, uint64_t nrows, uint64_t ncols); + double* get_addr(uint64_t row, uint64_t col, uint64_t blk) const; + reim4_elem get(uint64_t row, uint64_t col, uint64_t blk) const; + thash content_hash() const; + reim4_elem get_zext(uint64_t row, uint64_t col, uint64_t blk) const; + reim_fft64vec get_zext(uint64_t row, uint64_t col) const; + void set(uint64_t row, uint64_t col, uint64_t blk, const reim4_elem& v) const; + void set(uint64_t row, uint64_t col, const reim_fft64vec& value); + /** @brief fill with random double values (unstructured) */ + void fill_random(double log2bound); + /** @brief fill with random ffts of small int polynomials */ + void fill_dft_random(uint64_t log2bound); + ~fft64_vmp_pmat_layout(); +}; + +/** @brief test layout for the SVP_PPOL */ +class fft64_svp_ppol_layout { + public: + const uint64_t nn; + SVP_PPOL* const data; + fft64_svp_ppol_layout(uint64_t n); + thash content_hash() const; + reim_fft64vec get_copy() const; + void set(const reim_fft64vec&); + /** @brief fill with random double values (unstructured) */ + void fill_random(double log2bound); + /** @brief fill with random ffts of small int polynomials */ + void fill_dft_random(uint64_t log2bound); + ~fft64_svp_ppol_layout(); +}; + +/** @brief test layout for the CNV_PVEC_L */ +class fft64_cnv_left_layout { + const uint64_t nn; + const uint64_t size; + CNV_PVEC_L* const data; + fft64_cnv_left_layout(uint64_t n, uint64_t size); + reim4_elem get(uint64_t idx, uint64_t blk); + thash content_hash() const; + ~fft64_cnv_left_layout(); +}; + +/** @brief test layout for the CNV_PVEC_R */ +class fft64_cnv_right_layout { + const uint64_t nn; + const uint64_t size; + CNV_PVEC_R* const data; + fft64_cnv_right_layout(uint64_t n, uint64_t size); + reim4_elem get(uint64_t idx, uint64_t blk); + thash content_hash() const; + ~fft64_cnv_right_layout(); +}; + +#endif // SPQLIOS_FFT64_LAYOUTS_H diff --git 
a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/mod_q120.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/mod_q120.cpp new file mode 100644 index 0000000..4fc4bb3 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/mod_q120.cpp @@ -0,0 +1,229 @@ +#include "mod_q120.h" + +#include +#include + +int64_t centermod(int64_t v, int64_t q) { + int64_t t = v % q; + if (t >= (q + 1) / 2) return t - q; + if (t < -q / 2) return t + q; + return t; +} + +int64_t centermod(uint64_t v, int64_t q) { + int64_t t = int64_t(v % uint64_t(q)); + if (t >= q / 2) return t - q; + return t; +} + +mod_q120::mod_q120() { + for (uint64_t i = 0; i < 4; ++i) { + a[i] = 0; + } +} + +mod_q120::mod_q120(int64_t a0, int64_t a1, int64_t a2, int64_t a3) { + a[0] = centermod(a0, Qi[0]); + a[1] = centermod(a1, Qi[1]); + a[2] = centermod(a2, Qi[2]); + a[3] = centermod(a3, Qi[3]); +} + +mod_q120 operator+(const mod_q120& x, const mod_q120& y) { + mod_q120 r; + for (uint64_t i = 0; i < 4; ++i) { + r.a[i] = centermod(x.a[i] + y.a[i], mod_q120::Qi[i]); + } + return r; +} + +mod_q120 operator-(const mod_q120& x, const mod_q120& y) { + mod_q120 r; + for (uint64_t i = 0; i < 4; ++i) { + r.a[i] = centermod(x.a[i] - y.a[i], mod_q120::Qi[i]); + } + return r; +} + +mod_q120 operator*(const mod_q120& x, const mod_q120& y) { + mod_q120 r; + for (uint64_t i = 0; i < 4; ++i) { + r.a[i] = centermod(x.a[i] * y.a[i], mod_q120::Qi[i]); + } + return r; +} + +mod_q120& operator+=(mod_q120& x, const mod_q120& y) { + for (uint64_t i = 0; i < 4; ++i) { + x.a[i] = centermod(x.a[i] + y.a[i], mod_q120::Qi[i]); + } + return x; +} + +mod_q120& operator-=(mod_q120& x, const mod_q120& y) { + for (uint64_t i = 0; i < 4; ++i) { + x.a[i] = centermod(x.a[i] - y.a[i], mod_q120::Qi[i]); + } + return x; +} + +mod_q120& operator*=(mod_q120& x, const mod_q120& y) { + for (uint64_t i = 0; i < 4; ++i) { + x.a[i] = centermod(x.a[i] * y.a[i], mod_q120::Qi[i]); + } + return x; +} + +int64_t modq_pow(int64_t x, int32_t k, int64_t q) { + k = (k % (q - 1) + q - 1) % (q - 1); + + int64_t res = 1; + int64_t x_pow = centermod(x, q); + while (k != 0) { + if (k & 1) res = centermod(res * x_pow, q); + x_pow = centermod(x_pow * x_pow, q); + k >>= 1; + } + return res; +} + +mod_q120 pow(const mod_q120& x, int32_t k) { + const int64_t r0 = modq_pow(x.a[0], k, x.Qi[0]); + const int64_t r1 = modq_pow(x.a[1], k, x.Qi[1]); + const int64_t r2 = modq_pow(x.a[2], k, x.Qi[2]); + const int64_t r3 = modq_pow(x.a[3], k, x.Qi[3]); + return mod_q120{r0, r1, r2, r3}; +} + +static int64_t half_modq(int64_t x, int64_t q) { + // q must be odd in this function + if (x % 2 == 0) return x / 2; + return centermod((x + q) / 2, q); +} + +mod_q120 half(const mod_q120& x) { + const int64_t r0 = half_modq(x.a[0], x.Qi[0]); + const int64_t r1 = half_modq(x.a[1], x.Qi[1]); + const int64_t r2 = half_modq(x.a[2], x.Qi[2]); + const int64_t r3 = half_modq(x.a[3], x.Qi[3]); + return mod_q120{r0, r1, r2, r3}; +} + +bool operator==(const mod_q120& x, const mod_q120& y) { + for (uint64_t i = 0; i < 4; ++i) { + if (x.a[i] != y.a[i]) return false; + } + return true; +} + +std::ostream& operator<<(std::ostream& out, const mod_q120& x) { + return out << "q120{" << x.a[0] << "," << x.a[1] << "," << x.a[2] << "," << x.a[3] << "}"; +} + +mod_q120 mod_q120::from_q120a(const void* addr) { + static const uint64_t _2p32 = UINT64_C(1) << 32; + const uint64_t* in = (const uint64_t*)addr; + mod_q120 r; + for (uint64_t i = 0; i < 4; ++i) { + REQUIRE_DRAMATICALLY(in[i] 
< _2p32, "invalid layout a q120"); + r.a[i] = centermod(in[i], mod_q120::Qi[i]); + } + return r; +} + +mod_q120 mod_q120::from_q120b(const void* addr) { + const uint64_t* in = (const uint64_t*)addr; + mod_q120 r; + for (uint64_t i = 0; i < 4; ++i) { + r.a[i] = centermod(in[i], mod_q120::Qi[i]); + } + return r; +} + +mod_q120 mod_q120::from_q120c(const void* addr) { + // static const uint64_t _mask_2p32 = (uint64_t(1) << 32) - 1; + const uint32_t* in = (const uint32_t*)addr; + mod_q120 r; + for (uint64_t i = 0, k = 0; i < 8; i += 2, ++k) { + const uint64_t q = mod_q120::Qi[k]; + uint64_t u = in[i]; + uint64_t w = in[i + 1]; + REQUIRE_DRAMATICALLY(((u << 32) % q) == (w % q), + "invalid layout q120c: " << u << ".2^32 != " << (w >> 32) << " mod " << q); + r.a[k] = centermod(u, q); + } + return r; +} +__int128_t mod_q120::to_int128() const { + static const __int128_t qm[] = {(__int128_t(Qi[1]) * Qi[2]) * Qi[3], (__int128_t(Qi[0]) * Qi[2]) * Qi[3], + (__int128_t(Qi[0]) * Qi[1]) * Qi[3], (__int128_t(Qi[0]) * Qi[1]) * Qi[2]}; + static const int64_t CRTi[] = {Q1_CRT_CST, Q2_CRT_CST, Q3_CRT_CST, Q4_CRT_CST}; + static const __int128_t q = qm[0] * Qi[0]; + static const __int128_t qs2 = q / 2; + __int128_t res = 0; + for (uint64_t i = 0; i < 4; ++i) { + res += (a[i] * CRTi[i] % Qi[i]) * qm[i]; + } + res = (((res % q) + q + qs2) % q) - qs2; // centermod + return res; +} +void mod_q120::save_as_q120a(void* dest) const { + int64_t* d = (int64_t*)dest; + for (uint64_t i = 0; i < 4; ++i) { + d[i] = a[i] + Qi[i]; + } +} +void mod_q120::save_as_q120b(void* dest) const { + int64_t* d = (int64_t*)dest; + for (uint64_t i = 0; i < 4; ++i) { + d[i] = a[i] + (Qi[i] * (1 + uniform_u64_bits(32))); + } +} +void mod_q120::save_as_q120c(void* dest) const { + int32_t* d = (int32_t*)dest; + for (uint64_t i = 0; i < 4; ++i) { + d[2 * i] = a[i] + 3 * Qi[i]; + d[2 * i + 1] = (uint64_t(d[2 * i]) << 32) % uint64_t(Qi[i]); + } +} + +mod_q120 uniform_q120() { + test_rng& gen = randgen(); + std::uniform_int_distribution dista(0, mod_q120::Qi[0]); + std::uniform_int_distribution distb(0, mod_q120::Qi[1]); + std::uniform_int_distribution distc(0, mod_q120::Qi[2]); + std::uniform_int_distribution distd(0, mod_q120::Qi[3]); + return mod_q120(dista(gen), distb(gen), distc(gen), distd(gen)); +} + +void uniform_q120a(void* dest) { + uint64_t* res = (uint64_t*)dest; + for (uint64_t i = 0; i < 4; ++i) { + res[i] = uniform_u64_bits(32); + } +} + +void uniform_q120b(void* dest) { + uint64_t* res = (uint64_t*)dest; + for (uint64_t i = 0; i < 4; ++i) { + res[i] = uniform_u64(); + } +} + +void uniform_q120c(void* dest) { + uint32_t* res = (uint32_t*)dest; + static const uint64_t _2p32 = uint64_t(1) << 32; + for (uint64_t i = 0, k = 0; i < 8; i += 2, ++k) { + const uint64_t q = mod_q120::Qi[k]; + const uint64_t z = uniform_u64_bits(32); + const uint64_t z_pow_red = (z << 32) % q; + const uint64_t room = (_2p32 - z_pow_red) / q; + const uint64_t z_pow = z_pow_red + (uniform_u64() % room) * q; + REQUIRE_DRAMATICALLY(z < _2p32, "bug!"); + REQUIRE_DRAMATICALLY(z_pow < _2p32, "bug!"); + REQUIRE_DRAMATICALLY(z_pow % q == (z << 32) % q, "bug!"); + + res[i] = (uint32_t)z; + res[i + 1] = (uint32_t)z_pow; + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/mod_q120.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/mod_q120.h new file mode 100644 index 0000000..45c7cc7 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/mod_q120.h @@ -0,0 +1,49 @@ +#ifndef SPQLIOS_MOD_Q120_H 
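// Illustrative worked example (editor note, not from the library): centermod(v, q),
// defined in mod_q120.cpp above, returns the centered representative of v modulo q,
// e.g. centermod(7, 5) == 2, centermod(3, 5) == -2 and centermod(-7, 5) == -2.
// A mod_q120 below holds one such centered residue per prime Q1..Q4, and all of its
// operators act component-wise on those four residues.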
+#define SPQLIOS_MOD_Q120_H + +#include + +#include "../../spqlios/q120/q120_common.h" +#include "test_commons.h" + +/** @brief centered modulo q */ +int64_t centermod(int64_t v, int64_t q); +int64_t centermod(uint64_t v, int64_t q); + +/** @brief this class represents an integer mod Q120 */ +class mod_q120 { + public: + static constexpr int64_t Qi[] = {Q1, Q2, Q3, Q4}; + int64_t a[4]; + mod_q120(int64_t a1, int64_t a2, int64_t a3, int64_t a4); + mod_q120(); + __int128_t to_int128() const; + static mod_q120 from_q120a(const void* addr); + static mod_q120 from_q120b(const void* addr); + static mod_q120 from_q120c(const void* addr); + void save_as_q120a(void* dest) const; + void save_as_q120b(void* dest) const; + void save_as_q120c(void* dest) const; +}; + +mod_q120 operator+(const mod_q120& x, const mod_q120& y); +mod_q120 operator-(const mod_q120& x, const mod_q120& y); +mod_q120 operator*(const mod_q120& x, const mod_q120& y); +mod_q120& operator+=(mod_q120& x, const mod_q120& y); +mod_q120& operator-=(mod_q120& x, const mod_q120& y); +mod_q120& operator*=(mod_q120& x, const mod_q120& y); +std::ostream& operator<<(std::ostream& out, const mod_q120& x); +bool operator==(const mod_q120& x, const mod_q120& y); +mod_q120 pow(const mod_q120& x, int32_t k); +mod_q120 half(const mod_q120& x); + +/** @brief a uniformly drawn number mod Q120 */ +mod_q120 uniform_q120(); +/** @brief a uniformly random mod Q120 layout A (4 integers < 2^32) */ +void uniform_q120a(void* dest); +/** @brief a uniformly random mod Q120 layout B (4 integers < 2^64) */ +void uniform_q120b(void* dest); +/** @brief a uniformly random mod Q120 layout C (4 integers repr. x,2^32x) */ +void uniform_q120c(void* dest); + +#endif // SPQLIOS_MOD_Q120_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/negacyclic_polynomial.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/negacyclic_polynomial.cpp new file mode 100644 index 0000000..ee516c6 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/negacyclic_polynomial.cpp @@ -0,0 +1,18 @@ +#include "negacyclic_polynomial_impl.h" + +// explicit instantiation +EXPLICIT_INSTANTIATE_POLYNOMIAL(__int128_t); +EXPLICIT_INSTANTIATE_POLYNOMIAL(int64_t); +EXPLICIT_INSTANTIATE_POLYNOMIAL(double); + +double infty_dist(const rnx_f64& a, const rnx_f64& b) { + const uint64_t nn = a.nn(); + const double* aa = a.data(); + const double* bb = b.data(); + double res = 0.; + for (uint64_t i = 0; i < nn; ++i) { + double d = fabs(aa[i] - bb[i]); + if (d > res) res = d; + } + return res; +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/negacyclic_polynomial.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/negacyclic_polynomial.h new file mode 100644 index 0000000..8f5b17f --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/negacyclic_polynomial.h @@ -0,0 +1,69 @@ +#ifndef SPQLIOS_NEGACYCLIC_POLYNOMIAL_H +#define SPQLIOS_NEGACYCLIC_POLYNOMIAL_H + +#include + +#include "test_commons.h" + +template +class polynomial; +typedef polynomial<__int128_t> znx_i128; +typedef polynomial znx_i64; +typedef polynomial rnx_f64; + +template +class polynomial { + public: + std::vector coeffs; + /** @brief create a polynomial out of existing coeffs */ + polynomial(uint64_t N, const T* c); + /** @brief zero polynomial of dimension N */ + explicit polynomial(uint64_t N); + /** @brief empty polynomial (dim 0) */ + polynomial(); + + /** @brief ring dimension */ + uint64_t nn() 
const; + /** @brief special setter (accept any indexes, and does the negacyclic translation) */ + void set_coeff(int64_t i, T v); + /** @brief special getter (accept any indexes, and does the negacyclic translation) */ + T get_coeff(int64_t i) const; + /** @brief returns the coefficient layout */ + T* data(); + /** @brief returns the coefficient layout (const version) */ + const T* data() const; + /** @brief saves to the layout */ + void save_as(T* dest) const; + /** @brief zero */ + static polynomial zero(uint64_t n); + /** @brief random polynomial with coefficients in [-2^log2bounds, 2^log2bounds]*/ + static polynomial random_log2bound(uint64_t n, uint64_t log2bound); + /** @brief random polynomial with coefficients in [-2^log2bounds, 2^log2bounds]*/ + static polynomial random(uint64_t n); + /** @brief random polynomial with coefficient in [lb;ub] */ + static polynomial random_bound(uint64_t n, const T lb, const T ub); +}; + +/** @brief equality operator (used during tests) */ +template +bool operator==(const polynomial& a, const polynomial& b); + +/** @brief addition operator (used during tests) */ +template +polynomial operator+(const polynomial& a, const polynomial& b); + +/** @brief subtraction operator (used during tests) */ +template +polynomial operator-(const polynomial& a, const polynomial& b); + +/** @brief negation operator (used during tests) */ +template +polynomial operator-(const polynomial& a); + +template +polynomial naive_product(const polynomial& a, const polynomial& b); + +/** @brief distance between two real polynomials (used during tests) */ +double infty_dist(const rnx_f64& a, const rnx_f64& b); + +#endif // SPQLIOS_NEGACYCLIC_POLYNOMIAL_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/negacyclic_polynomial_impl.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/negacyclic_polynomial_impl.h new file mode 100644 index 0000000..e31c5a5 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/negacyclic_polynomial_impl.h @@ -0,0 +1,247 @@ +#ifndef SPQLIOS_NEGACYCLIC_POLYNOMIAL_IMPL_H +#define SPQLIOS_NEGACYCLIC_POLYNOMIAL_IMPL_H + +#include "negacyclic_polynomial.h" + +template +polynomial::polynomial(uint64_t N, const T* c) : coeffs(N) { + for (uint64_t i = 0; i < N; ++i) coeffs[i] = c[i]; +} +/** @brief zero polynomial of dimension N */ +template +polynomial::polynomial(uint64_t N) : coeffs(N, 0) {} +/** @brief empty polynomial (dim 0) */ +template +polynomial::polynomial() {} + +/** @brief ring dimension */ +template +uint64_t polynomial::nn() const { + uint64_t n = coeffs.size(); + REQUIRE_DRAMATICALLY(is_pow2(n), "polynomial dim is not a pow of 2"); + return n; +} + +/** @brief special setter (accept any indexes, and does the negacyclic translation) */ +template +void polynomial::set_coeff(int64_t i, T v) { + const uint64_t n = nn(); + const uint64_t _2nm = 2 * n - 1; + uint64_t pos = uint64_t(i) & _2nm; + if (pos < n) { + coeffs[pos] = v; + } else { + coeffs[pos - n] = -v; + } +} +/** @brief special getter (accept any indexes, and does the negacyclic translation) */ +template +T polynomial::get_coeff(int64_t i) const { + const uint64_t n = nn(); + const uint64_t _2nm = 2 * n - 1; + uint64_t pos = uint64_t(i) & _2nm; + if (pos < n) { + return coeffs[pos]; + } else { + return -coeffs[pos - n]; + } +} +/** @brief returns the coefficient layout */ +template +T* polynomial::data() { + return coeffs.data(); +} + +template +void polynomial::save_as(T* dest) const { + const uint64_t n = nn(); + for 
(uint64_t i = 0; i < n; ++i) { + dest[i] = coeffs[i]; + } +} + +/** @brief returns the coefficient layout (const version) */ +template +const T* polynomial::data() const { + return coeffs.data(); +} + +/** @brief returns the coefficient layout (const version) */ +template +polynomial polynomial::zero(uint64_t n) { + return polynomial(n); +} + +/** @brief equality operator (used during tests) */ +template +bool operator==(const polynomial& a, const polynomial& b) { + uint64_t n = a.nn(); + REQUIRE_DRAMATICALLY(b.nn() == n, "wrong dimensions"); + for (uint64_t i = 0; i < n; ++i) { + if (a.get_coeff(i) != b.get_coeff(i)) return false; + } + return true; +} + +/** @brief addition operator (used during tests) */ +template +polynomial operator+(const polynomial& a, const polynomial& b) { + uint64_t n = a.nn(); + REQUIRE_DRAMATICALLY(b.nn() == n, "wrong dimensions"); + polynomial res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, a.get_coeff(i) + b.get_coeff(i)); + } + return res; +} + +/** @brief subtraction operator (used during tests) */ +template +polynomial operator-(const polynomial& a, const polynomial& b) { + uint64_t n = a.nn(); + REQUIRE_DRAMATICALLY(b.nn() == n, "wrong dimensions"); + polynomial res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, a.get_coeff(i) - b.get_coeff(i)); + } + return res; +} + +/** @brief subtraction operator (used during tests) */ +template +polynomial operator-(const polynomial& a) { + uint64_t n = a.nn(); + polynomial res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, -a.get_coeff(i)); + } + return res; +} + +/** @brief random polynomial */ +template +polynomial random_polynomial(uint64_t n); + +/** @brief random int64 polynomial */ +template <> +polynomial random_polynomial(uint64_t n) { + polynomial res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, uniform_i64()); + } + return res; +} + +/** @brief random float64 gaussian polynomial */ +template <> +polynomial random_polynomial(uint64_t n) { + polynomial res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, random_f64_gaussian()); + } + return res; +} + +template +polynomial random_polynomial_bounds(uint64_t n, const T lb, const T ub); + +/** @brief random int64 polynomial */ +template <> +polynomial random_polynomial_bounds(uint64_t n, const int64_t lb, const int64_t ub) { + polynomial res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, uniform_i64_bounds(lb, ub)); + } + return res; +} + +/** @brief random float64 gaussian polynomial */ +template <> +polynomial random_polynomial_bounds(uint64_t n, const double lb, const double ub) { + polynomial res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, uniform_f64_bounds(lb, ub)); + } + return res; +} + +/** @brief random int64 polynomial */ +template <> +polynomial<__int128_t> random_polynomial_bounds(uint64_t n, const __int128_t lb, const __int128_t ub) { + polynomial<__int128_t> res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, uniform_i128_bounds(lb, ub)); + } + return res; +} + +template +polynomial random_polynomial_bits(uint64_t n, const uint64_t bits) { + T b = UINT64_C(1) << bits; + return random_polynomial_bounds(n, -b, b); +} + +template <> +polynomial polynomial::random_log2bound(uint64_t n, uint64_t log2bound) { + polynomial res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, uniform_i64_bits(log2bound)); + } + return res; +} + +template <> +polynomial polynomial::random(uint64_t n) { + polynomial res(n); + for (uint64_t i = 0; i < n; 
++i) { + res.set_coeff(i, uniform_u64()); + } + return res; +} + +template <> +polynomial polynomial::random_log2bound(uint64_t n, uint64_t log2bound) { + polynomial res(n); + double bound = pow(2., log2bound); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, uniform_f64_bounds(-bound, bound)); + } + return res; +} + +template <> +polynomial polynomial::random(uint64_t n) { + polynomial res(n); + double bound = 2.; + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, uniform_f64_bounds(-bound, bound)); + } + return res; +} + +template +polynomial naive_product(const polynomial& a, const polynomial& b) { + const int64_t nn = a.nn(); + REQUIRE_DRAMATICALLY(b.nn() == uint64_t(nn), "dimension mismatch!"); + polynomial res(nn); + for (int64_t i = 0; i < nn; ++i) { + T ri = 0; + for (int64_t j = 0; j < nn; ++j) { + ri += a.get_coeff(j) * b.get_coeff(i - j); + } + res.set_coeff(i, ri); + } + return res; +} + +#define EXPLICIT_INSTANTIATE_POLYNOMIAL(TYPE) \ + template class polynomial; \ + template bool operator==(const polynomial& a, const polynomial& b); \ + template polynomial operator+(const polynomial& a, const polynomial& b); \ + template polynomial operator-(const polynomial& a, const polynomial& b); \ + template polynomial operator-(const polynomial& a); \ + template polynomial random_polynomial_bits(uint64_t n, const uint64_t bits); \ + template polynomial naive_product(const polynomial& a, const polynomial& b); \ + // template polynomial random_polynomial(uint64_t n); + +#endif // SPQLIOS_NEGACYCLIC_POLYNOMIAL_IMPL_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/ntt120_dft.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/ntt120_dft.cpp new file mode 100644 index 0000000..5d4b6f5 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/ntt120_dft.cpp @@ -0,0 +1,122 @@ +#include "ntt120_dft.h" + +#include "mod_q120.h" + +// @brief alternative version of the NTT + +/** for all s=k/2^17, root_of_unity(s) = omega_0^k */ +static mod_q120 root_of_unity(double s) { + static mod_q120 omega_2pow17{OMEGA1, OMEGA2, OMEGA3, OMEGA4}; + static double _2pow17 = 1 << 17; + return pow(omega_2pow17, s * _2pow17); +} +static mod_q120 root_of_unity_inv(double s) { + static mod_q120 omega_2pow17{OMEGA1, OMEGA2, OMEGA3, OMEGA4}; + static double _2pow17 = 1 << 17; + return pow(omega_2pow17, -s * _2pow17); +} + +/** recursive naive ntt */ +static void q120_ntt_naive_rec(uint64_t n, double entry_pwr, mod_q120* data) { + if (n == 1) return; + const uint64_t h = n / 2; + const double s = entry_pwr / 2.; + mod_q120 om = root_of_unity(s); + for (uint64_t j = 0; j < h; ++j) { + mod_q120 om_right = data[h + j] * om; + data[h + j] = data[j] - om_right; + data[j] = data[j] + om_right; + } + q120_ntt_naive_rec(h, s, data); + q120_ntt_naive_rec(h, s + 0.5, data + h); +} +static void q120_intt_naive_rec(uint64_t n, double entry_pwr, mod_q120* data) { + if (n == 1) return; + const uint64_t h = n / 2; + const double s = entry_pwr / 2.; + q120_intt_naive_rec(h, s, data); + q120_intt_naive_rec(h, s + 0.5, data + h); + mod_q120 om = root_of_unity_inv(s); + for (uint64_t j = 0; j < h; ++j) { + mod_q120 dat_diff = half(data[j] - data[h + j]); + data[j] = half(data[j] + data[h + j]); + data[h + j] = dat_diff * om; + } +} + +/** user friendly version */ +q120_nttvec simple_ntt120(const znx_i64& polynomial) { + const uint64_t n = polynomial.nn(); + q120_nttvec res(n); + for (uint64_t i = 0; i < n; ++i) { + int64_t xi = polynomial.get_coeff(i); + 
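// descriptive note: each integer coefficient is embedded "diagonally" as
// (xi mod Q1, ..., xi mod Q4); the q120_ntt_naive_rec call below then transforms
// every CRT component independently, since all mod_q120 operators are component-wise.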
res.v[i] = mod_q120(xi, xi, xi, xi); + } + q120_ntt_naive_rec(n, 0.5, res.v.data()); + return res; +} + +znx_i128 simple_intt120(const q120_nttvec& fftvec) { + const uint64_t n = fftvec.nn(); + q120_nttvec copy = fftvec; + znx_i128 res(n); + q120_intt_naive_rec(n, 0.5, copy.v.data()); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, copy.v[i].to_int128()); + } + return res; +} +bool operator==(const q120_nttvec& a, const q120_nttvec& b) { return a.v == b.v; } + +std::vector q120_ntt_naive(const std::vector& x) { + std::vector res = x; + q120_ntt_naive_rec(res.size(), 0.5, res.data()); + return res; +} +q120_nttvec::q120_nttvec(uint64_t n) : v(n) {} +q120_nttvec::q120_nttvec(uint64_t n, const q120b* data) : v(n) { + int64_t* d = (int64_t*)data; + for (uint64_t i = 0; i < n; ++i) { + v[i] = mod_q120::from_q120b(d + 4 * i); + } +} +q120_nttvec::q120_nttvec(uint64_t n, const q120c* data) : v(n) { + int64_t* d = (int64_t*)data; + for (uint64_t i = 0; i < n; ++i) { + v[i] = mod_q120::from_q120c(d + 4 * i); + } +} +uint64_t q120_nttvec::nn() const { return v.size(); } +q120_nttvec q120_nttvec::zero(uint64_t n) { return q120_nttvec(n); } +void q120_nttvec::save_as(q120a* dest) const { + int64_t* const d = (int64_t*)dest; + const uint64_t n = nn(); + for (uint64_t i = 0; i < n; ++i) { + v[i].save_as_q120a(d + 4 * i); + } +} +void q120_nttvec::save_as(q120b* dest) const { + int64_t* const d = (int64_t*)dest; + const uint64_t n = nn(); + for (uint64_t i = 0; i < n; ++i) { + v[i].save_as_q120b(d + 4 * i); + } +} +void q120_nttvec::save_as(q120c* dest) const { + int64_t* const d = (int64_t*)dest; + const uint64_t n = nn(); + for (uint64_t i = 0; i < n; ++i) { + v[i].save_as_q120c(d + 4 * i); + } +} +mod_q120 q120_nttvec::get_blk(uint64_t blk) const { + REQUIRE_DRAMATICALLY(blk < nn(), "blk overflow"); + return v[blk]; +} +q120_nttvec q120_nttvec::random(uint64_t n) { + q120_nttvec res(n); + for (uint64_t i = 0; i < n; ++i) { + res.v[i] = uniform_q120(); + } + return res; +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/ntt120_dft.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/ntt120_dft.h new file mode 100644 index 0000000..80f5679 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/ntt120_dft.h @@ -0,0 +1,31 @@ +#ifndef SPQLIOS_NTT120_DFT_H +#define SPQLIOS_NTT120_DFT_H + +#include + +#include "../../spqlios/q120/q120_arithmetic.h" +#include "mod_q120.h" +#include "negacyclic_polynomial.h" +#include "test_commons.h" + +class q120_nttvec { + public: + std::vector v; + q120_nttvec() = default; + explicit q120_nttvec(uint64_t n); + q120_nttvec(uint64_t n, const q120b* data); + q120_nttvec(uint64_t n, const q120c* data); + uint64_t nn() const; + static q120_nttvec zero(uint64_t n); + static q120_nttvec random(uint64_t n); + void save_as(q120a* dest) const; + void save_as(q120b* dest) const; + void save_as(q120c* dest) const; + mod_q120 get_blk(uint64_t blk) const; +}; + +q120_nttvec simple_ntt120(const znx_i64& polynomial); +znx_i128 simple_intt120(const q120_nttvec& fftvec); +bool operator==(const q120_nttvec& a, const q120_nttvec& b); + +#endif // SPQLIOS_NTT120_DFT_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/ntt120_layouts.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/ntt120_layouts.cpp new file mode 100644 index 0000000..d1e582f --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/ntt120_layouts.cpp @@ -0,0 +1,66 @@ +#include 
"ntt120_layouts.h" + +mod_q120x2::mod_q120x2() {} +mod_q120x2::mod_q120x2(const mod_q120& a, const mod_q120& b) { + value[0] = a; + value[1] = b; +} +mod_q120x2::mod_q120x2(q120x2b* addr) { + uint64_t* p = (uint64_t*)addr; + value[0] = mod_q120::from_q120b(p); + value[1] = mod_q120::from_q120b(p + 4); +} + +ntt120_vec_znx_dft_layout::ntt120_vec_znx_dft_layout(uint64_t n, uint64_t size) + : nn(n), // + size(size), // + data((VEC_ZNX_DFT*)alloc64(n * size * 4 * sizeof(uint64_t))) {} + +mod_q120x2 ntt120_vec_znx_dft_layout::get_copy_zext(uint64_t idx, uint64_t blk) { + return mod_q120x2(get_blk(idx, blk)); +} +q120x2b* ntt120_vec_znx_dft_layout::get_blk(uint64_t idx, uint64_t blk) { + REQUIRE_DRAMATICALLY(idx < size, "idx overflow"); + REQUIRE_DRAMATICALLY(blk < nn / 2, "blk overflow"); + uint64_t* d = (uint64_t*)data; + return (q120x2b*)(d + 4 * nn * idx + 8 * blk); +} +ntt120_vec_znx_dft_layout::~ntt120_vec_znx_dft_layout() { spqlios_free(data); } +q120_nttvec ntt120_vec_znx_dft_layout::get_copy_zext(uint64_t idx) { + int64_t* d = (int64_t*)data; + if (idx < size) { + return q120_nttvec(nn, (q120b*)(d + idx * nn * 4)); + } else { + return q120_nttvec::zero(nn); + } +} +void ntt120_vec_znx_dft_layout::set(uint64_t idx, const q120_nttvec& value) { + REQUIRE_DRAMATICALLY(idx < size, "index overflow: " << idx << " / " << size); + q120b* dest_addr = (q120b*)((int64_t*)data + idx * nn * 4); + value.save_as(dest_addr); +} +void ntt120_vec_znx_dft_layout::fill_random() { + for (uint64_t i = 0; i < size; ++i) { + set(i, q120_nttvec::random(nn)); + } +} +thash ntt120_vec_znx_dft_layout::content_hash() const { return test_hash(data, nn * size * 4 * sizeof(int64_t)); } +ntt120_vec_znx_big_layout::ntt120_vec_znx_big_layout(uint64_t n, uint64_t size) + : nn(n), // + size(size), + data((VEC_ZNX_BIG*)alloc64(n * size * sizeof(__int128_t))) {} + +znx_i128 ntt120_vec_znx_big_layout::get_copy(uint64_t index) const { return znx_i128(nn, get_addr(index)); } +znx_i128 ntt120_vec_znx_big_layout::get_copy_zext(uint64_t index) const { + if (index < size) { + return znx_i128(nn, get_addr(index)); + } else { + return znx_i128::zero(nn); + } +} +__int128* ntt120_vec_znx_big_layout::get_addr(uint64_t index) const { + REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size); + return (__int128_t*)data + index * nn; +} +void ntt120_vec_znx_big_layout::set(uint64_t index, const znx_i128& value) { value.save_as(get_addr(index)); } +ntt120_vec_znx_big_layout::~ntt120_vec_znx_big_layout() { spqlios_free(data); } diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/ntt120_layouts.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/ntt120_layouts.h new file mode 100644 index 0000000..d8fcc08 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/ntt120_layouts.h @@ -0,0 +1,103 @@ +#ifndef SPQLIOS_NTT120_LAYOUTS_H +#define SPQLIOS_NTT120_LAYOUTS_H + +#include "../../spqlios/arithmetic/vec_znx_arithmetic.h" +#include "mod_q120.h" +#include "negacyclic_polynomial.h" +#include "ntt120_dft.h" +#include "test_commons.h" + +struct q120b_vector_view {}; + +struct mod_q120x2 { + mod_q120 value[2]; + mod_q120x2(); + mod_q120x2(const mod_q120& a, const mod_q120& b); + mod_q120x2(__int128_t value); + explicit mod_q120x2(q120x2b* addr); + explicit mod_q120x2(q120x2c* addr); + void save_as(q120x2b* addr) const; + void save_as(q120x2c* addr) const; + static mod_q120x2 random(); +}; +mod_q120x2 operator+(const mod_q120x2& a, const mod_q120x2& b); 
+mod_q120x2 operator-(const mod_q120x2& a, const mod_q120x2& b); +mod_q120x2 operator*(const mod_q120x2& a, const mod_q120x2& b); +bool operator==(const mod_q120x2& a, const mod_q120x2& b); +bool operator!=(const mod_q120x2& a, const mod_q120x2& b); +mod_q120x2& operator+=(mod_q120x2& a, const mod_q120x2& b); +mod_q120x2& operator-=(mod_q120x2& a, const mod_q120x2& b); + +/** @brief test layout for the VEC_ZNX_DFT */ +struct ntt120_vec_znx_dft_layout { + const uint64_t nn; + const uint64_t size; + VEC_ZNX_DFT* const data; + ntt120_vec_znx_dft_layout(uint64_t n, uint64_t size); + mod_q120x2 get_copy_zext(uint64_t idx, uint64_t blk); + q120_nttvec get_copy_zext(uint64_t idx); + void set(uint64_t idx, const q120_nttvec& v); + q120x2b* get_blk(uint64_t idx, uint64_t blk); + thash content_hash() const; + void fill_random(); + ~ntt120_vec_znx_dft_layout(); +}; + +/** @brief test layout for the VEC_ZNX_BIG */ +class ntt120_vec_znx_big_layout { + public: + const uint64_t nn; + const uint64_t size; + VEC_ZNX_BIG* const data; + ntt120_vec_znx_big_layout(uint64_t n, uint64_t size); + + private: + __int128* get_addr(uint64_t index) const; + + public: + znx_i128 get_copy(uint64_t index) const; + znx_i128 get_copy_zext(uint64_t index) const; + void set(uint64_t index, const znx_i128& value); + ~ntt120_vec_znx_big_layout(); +}; + +/** @brief test layout for the VMP_PMAT */ +class ntt120_vmp_pmat_layout { + const uint64_t nn; + const uint64_t nrows; + const uint64_t ncols; + VMP_PMAT* const data; + ntt120_vmp_pmat_layout(uint64_t n, uint64_t nrows, uint64_t ncols); + mod_q120x2 get(uint64_t row, uint64_t col, uint64_t blk) const; + ~ntt120_vmp_pmat_layout(); +}; + +/** @brief test layout for the SVP_PPOL */ +class ntt120_svp_ppol_layout { + const uint64_t nn; + SVP_PPOL* const data; + ntt120_svp_ppol_layout(uint64_t n); + ~ntt120_svp_ppol_layout(); +}; + +/** @brief test layout for the CNV_PVEC_L */ +class ntt120_cnv_left_layout { + const uint64_t nn; + const uint64_t size; + CNV_PVEC_L* const data; + ntt120_cnv_left_layout(uint64_t n, uint64_t size); + mod_q120x2 get(uint64_t idx, uint64_t blk); + ~ntt120_cnv_left_layout(); +}; + +/** @brief test layout for the CNV_PVEC_R */ +class ntt120_cnv_right_layout { + const uint64_t nn; + const uint64_t size; + CNV_PVEC_R* const data; + ntt120_cnv_right_layout(uint64_t n, uint64_t size); + mod_q120x2 get(uint64_t idx, uint64_t blk); + ~ntt120_cnv_right_layout(); +}; + +#endif // SPQLIOS_NTT120_LAYOUTS_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/polynomial_vector.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/polynomial_vector.cpp new file mode 100644 index 0000000..95d4e78 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/polynomial_vector.cpp @@ -0,0 +1,69 @@ +#include "polynomial_vector.h" + +#include + +#ifdef VALGRIND_MEM_TESTS +#include "valgrind/memcheck.h" +#endif + +#define CANARY_PADDING (1024) +#define GARBAGE_VALUE (242) + +znx_vec_i64_layout::znx_vec_i64_layout(uint64_t n, uint64_t size, uint64_t slice) : n(n), size(size), slice(slice) { + REQUIRE_DRAMATICALLY(is_pow2(n), "not a power of 2" << n); + REQUIRE_DRAMATICALLY(slice >= n, "slice too small" << slice << " < " << n); + this->region = (uint8_t*)malloc(size * slice * sizeof(int64_t) + 2 * CANARY_PADDING); + this->data_start = (int64_t*)(region + CANARY_PADDING); + // ensure that any invalid value is kind-of garbage + memset(region, GARBAGE_VALUE, size * slice * sizeof(int64_t) + 2 * CANARY_PADDING); + // mark 
inter-slice memory as non accessible +#ifdef VALGRIND_MEM_TESTS + VALGRIND_MAKE_MEM_NOACCESS(region, CANARY_PADDING); + VALGRIND_MAKE_MEM_NOACCESS(region + size * slice * sizeof(int64_t) + CANARY_PADDING, CANARY_PADDING); + for (uint64_t i = 0; i < size; ++i) { + VALGRIND_MAKE_MEM_UNDEFINED(data_start + i * slice, n * sizeof(int64_t)); + } + if (size != slice) { + for (uint64_t i = 0; i < size; ++i) { + VALGRIND_MAKE_MEM_NOACCESS(data_start + i * slice + n, (slice - n) * sizeof(int64_t)); + } + } +#endif +} + +znx_vec_i64_layout::~znx_vec_i64_layout() { free(region); } + +znx_i64 znx_vec_i64_layout::get_copy_zext(uint64_t index) const { + if (index < size) { + return znx_i64(n, data_start + index * slice); + } else { + return znx_i64::zero(n); + } +} + +znx_i64 znx_vec_i64_layout::get_copy(uint64_t index) const { + REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size); + return znx_i64(n, data_start + index * slice); +} + +void znx_vec_i64_layout::set(uint64_t index, const znx_i64& elem) { + REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size); + REQUIRE_DRAMATICALLY(elem.nn() == n, "incompatible ring dimensions: " << elem.nn() << " / " << n); + elem.save_as(data_start + index * slice); +} + +int64_t* znx_vec_i64_layout::data() { return data_start; } +const int64_t* znx_vec_i64_layout::data() const { return data_start; } + +void znx_vec_i64_layout::fill_random(uint64_t bits) { + for (uint64_t i = 0; i < size; ++i) { + set(i, znx_i64::random_log2bound(n, bits)); + } +} +__uint128_t znx_vec_i64_layout::content_hash() const { + test_hasher hasher; + for (uint64_t i = 0; i < size; ++i) { + hasher.update(data() + i * slice, n * sizeof(int64_t)); + } + return hasher.hash(); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/polynomial_vector.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/polynomial_vector.h new file mode 100644 index 0000000..d821193 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/polynomial_vector.h @@ -0,0 +1,42 @@ +#ifndef SPQLIOS_POLYNOMIAL_VECTOR_H +#define SPQLIOS_POLYNOMIAL_VECTOR_H + +#include "negacyclic_polynomial.h" +#include "test_commons.h" + +/** @brief a test memory layout for znx i64 polynomials vectors */ +class znx_vec_i64_layout { + uint64_t n; + uint64_t size; + uint64_t slice; + int64_t* data_start; + uint8_t* region; + + public: + // NO-COPY structure + znx_vec_i64_layout(const znx_vec_i64_layout&) = delete; + void operator=(const znx_vec_i64_layout&) = delete; + znx_vec_i64_layout(znx_vec_i64_layout&&) = delete; + void operator=(znx_vec_i64_layout&&) = delete; + /** @brief initialises a memory layout */ + znx_vec_i64_layout(uint64_t n, uint64_t size, uint64_t slice); + /** @brief destructor */ + ~znx_vec_i64_layout(); + + /** @brief get a copy of item index index (extended with zeros) */ + znx_i64 get_copy_zext(uint64_t index) const; + /** @brief get a copy of item index index (extended with zeros) */ + znx_i64 get_copy(uint64_t index) const; + /** @brief get a copy of item index index (index +#include + +#include "test_commons.h" + +bool is_pow2(uint64_t n) { return !(n & (n - 1)); } + +test_rng& randgen() { + static test_rng gen; + return gen; +} +uint64_t uniform_u64() { + static std::uniform_int_distribution dist64(0, UINT64_MAX); + return dist64(randgen()); +} + +uint64_t uniform_u64_bits(uint64_t nbits) { + if (nbits >= 64) return uniform_u64(); + return uniform_u64() >> (64 - nbits); +} + +int64_t uniform_i64() { 
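  // note: a default-constructed std::uniform_int_distribution draws from
  // [0, std::numeric_limits<result_type>::max()], so this helper only produces
  // non-negative values; uniform_i64_bits / uniform_i64_bounds below cover
  // signed ranges explicitly.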
+ std::uniform_int_distribution dist; + return dist(randgen()); +} + +int64_t uniform_i64_bits(uint64_t nbits) { + int64_t bound = int64_t(1) << nbits; + std::uniform_int_distribution dist(-bound, bound); + return dist(randgen()); +} + +int64_t uniform_i64_bounds(const int64_t lb, const int64_t ub) { + std::uniform_int_distribution dist(lb, ub); + return dist(randgen()); +} + +__int128_t uniform_i128_bounds(const __int128_t lb, const __int128_t ub) { + std::uniform_int_distribution<__int128_t> dist(lb, ub); + return dist(randgen()); +} + +double random_f64_gaussian(double stdev) { + std::normal_distribution dist(0, stdev); + return dist(randgen()); +} + +double uniform_f64_bounds(const double lb, const double ub) { + std::uniform_real_distribution dist(lb, ub); + return dist(randgen()); +} + +double uniform_f64_01() { return uniform_f64_bounds(0, 1); } diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/reim4_elem.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/reim4_elem.cpp new file mode 100644 index 0000000..2028a31 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/reim4_elem.cpp @@ -0,0 +1,145 @@ +#include "reim4_elem.h" + +reim4_elem::reim4_elem(const double* re, const double* im) { + for (uint64_t i = 0; i < 4; ++i) { + value[i] = re[i]; + value[4 + i] = im[i]; + } +} +reim4_elem::reim4_elem(const double* layout) { + for (uint64_t i = 0; i < 8; ++i) { + value[i] = layout[i]; + } +} +reim4_elem::reim4_elem() { + for (uint64_t i = 0; i < 8; ++i) { + value[i] = 0.; + } +} +void reim4_elem::save_re_im(double* re, double* im) const { + for (uint64_t i = 0; i < 4; ++i) { + re[i] = value[i]; + im[i] = value[4 + i]; + } +} +void reim4_elem::save_as(double* reim4) const { + for (uint64_t i = 0; i < 8; ++i) { + reim4[i] = value[i]; + } +} +reim4_elem reim4_elem::zero() { return reim4_elem(); } + +bool operator==(const reim4_elem& x, const reim4_elem& y) { + for (uint64_t i = 0; i < 8; ++i) { + if (x.value[i] != y.value[i]) return false; + } + return true; +} + +reim4_elem gaussian_reim4() { + test_rng& gen = randgen(); + std::normal_distribution dist(0, 1); + reim4_elem res; + for (uint64_t i = 0; i < 8; ++i) { + res.value[i] = dist(gen); + } + return res; +} + +reim4_array_view::reim4_array_view(uint64_t size, double* data) : size(size), data(data) {} +reim4_elem reim4_array_view::get(uint64_t i) const { + REQUIRE_DRAMATICALLY(i < size, "reim4 array overflow"); + return reim4_elem(data + 8 * i); +} +void reim4_array_view::set(uint64_t i, const reim4_elem& value) { + REQUIRE_DRAMATICALLY(i < size, "reim4 array overflow"); + value.save_as(data + 8 * i); +} + +reim_view::reim_view(uint64_t m, double* data) : m(m), data(data) {} +reim4_elem reim_view::get_blk(uint64_t i) { + REQUIRE_DRAMATICALLY(i < m / 4, "block overflow"); + return reim4_elem(data + 4 * i, data + m + 4 * i); +} +void reim_view::set_blk(uint64_t i, const reim4_elem& value) { + REQUIRE_DRAMATICALLY(i < m / 4, "block overflow"); + value.save_re_im(data + 4 * i, data + m + 4 * i); +} + +reim_vector_view::reim_vector_view(uint64_t m, uint64_t nrows, double* data) : m(m), nrows(nrows), data(data) {} +reim_view reim_vector_view::row(uint64_t row) { + REQUIRE_DRAMATICALLY(row < nrows, "row overflow"); + return reim_view(m, data + 2 * m * row); +} + +/** @brief addition */ +reim4_elem operator+(const reim4_elem& x, const reim4_elem& y) { + reim4_elem reps; + for (uint64_t i = 0; i < 8; ++i) { + reps.value[i] = x.value[i] + y.value[i]; + } + return reps; +} 
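// Illustrative sketch (editor example; the hypothetical helper below only uses
// reim_view and reim4_elem as defined in this file): the "reim" layout stores
// m complex values as m real parts followed by m imaginary parts in one double
// buffer, and reim4 blocks address 4 complexes at a time, i.e. block i covers
// data[4*i .. 4*i+3] (real parts) and data[m+4*i .. m+4*i+3] (imaginary parts).
#include <cstdint>
#include <vector>

static void reim4_block_demo() {
  const uint64_t m = 8;                 // 8 complex coefficients => 2 blocks of 4
  std::vector<double> buf(2 * m, 0.0);  // [re_0..re_7 | im_0..im_7]
  reim_view view(m, buf.data());
  double re[4] = {1, 2, 3, 4};
  double im[4] = {5, 6, 7, 8};
  view.set_blk(1, reim4_elem(re, im));  // writes buf[4..7] and buf[12..15]
  reim4_elem back = view.get_blk(1);    // reads the same block back
  (void)back;                           // back == reim4_elem(re, im)
}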
+reim4_elem& operator+=(reim4_elem& x, const reim4_elem& y) { + for (uint64_t i = 0; i < 8; ++i) { + x.value[i] += y.value[i]; + } + return x; +} +/** @brief subtraction */ +reim4_elem operator-(const reim4_elem& x, const reim4_elem& y) { + reim4_elem reps; + for (uint64_t i = 0; i < 8; ++i) { + reps.value[i] = x.value[i] + y.value[i]; + } + return reps; +} +reim4_elem& operator-=(reim4_elem& x, const reim4_elem& y) { + for (uint64_t i = 0; i < 8; ++i) { + x.value[i] -= y.value[i]; + } + return x; +} +/** @brief product */ +reim4_elem operator*(const reim4_elem& x, const reim4_elem& y) { + reim4_elem reps; + for (uint64_t i = 0; i < 4; ++i) { + double xre = x.value[i]; + double yre = y.value[i]; + double xim = x.value[i + 4]; + double yim = y.value[i + 4]; + reps.value[i] = xre * yre - xim * yim; + reps.value[i + 4] = xre * yim + xim * yre; + } + return reps; +} +/** @brief distance in infty norm */ +double infty_dist(const reim4_elem& x, const reim4_elem& y) { + double dist = 0; + for (uint64_t i = 0; i < 8; ++i) { + double d = fabs(x.value[i] - y.value[i]); + if (d > dist) dist = d; + } + return dist; +} + +std::ostream& operator<<(std::ostream& out, const reim4_elem& x) { + out << "[\n"; + for (uint64_t i = 0; i < 4; ++i) { + out << " re=" << x.value[i] << ", im=" << x.value[i + 4] << "\n"; + } + return out << "]"; +} + +reim4_matrix_view::reim4_matrix_view(uint64_t nrows, uint64_t ncols, double* data) + : nrows(nrows), ncols(ncols), data(data) {} +reim4_elem reim4_matrix_view::get(uint64_t row, uint64_t col) const { + REQUIRE_DRAMATICALLY(row < nrows, "rows out of bounds" << row << " / " << nrows); + REQUIRE_DRAMATICALLY(col < ncols, "cols out of bounds" << col << " / " << ncols); + return reim4_elem(data + 8 * (row * ncols + col)); +} +void reim4_matrix_view::set(uint64_t row, uint64_t col, const reim4_elem& value) { + REQUIRE_DRAMATICALLY(row < nrows, "rows out of bounds" << row << " / " << nrows); + REQUIRE_DRAMATICALLY(col < ncols, "cols out of bounds" << col << " / " << ncols); + value.save_as(data + 8 * (row * ncols + col)); +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/reim4_elem.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/reim4_elem.h new file mode 100644 index 0000000..68d9430 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/reim4_elem.h @@ -0,0 +1,95 @@ +#ifndef SPQLIOS_REIM4_ELEM_H +#define SPQLIOS_REIM4_ELEM_H + +#include "test_commons.h" + +/** @brief test class representing one single reim4 element */ +class reim4_elem { + public: + /** @brief 8 components (4 real parts followed by 4 imag parts) */ + double value[8]; + /** @brief constructs from 4 real parts and 4 imaginary parts */ + reim4_elem(const double* re, const double* im); + /** @brief constructs from 8 components */ + explicit reim4_elem(const double* layout); + /** @brief zero */ + reim4_elem(); + /** @brief saves the real parts to re and the 4 imag to im */ + void save_re_im(double* re, double* im) const; + /** @brief saves the 8 components to reim4 */ + void save_as(double* reim4) const; + static reim4_elem zero(); +}; + +/** @brief checks for equality */ +bool operator==(const reim4_elem& x, const reim4_elem& y); +/** @brief random gaussian reim4 of stdev 1 and mean 0 */ +reim4_elem gaussian_reim4(); +/** @brief addition */ +reim4_elem operator+(const reim4_elem& x, const reim4_elem& y); +reim4_elem& operator+=(reim4_elem& x, const reim4_elem& y); +/** @brief subtraction */ +reim4_elem operator-(const reim4_elem& x, 
const reim4_elem& y); +reim4_elem& operator-=(reim4_elem& x, const reim4_elem& y); +/** @brief product */ +reim4_elem operator*(const reim4_elem& x, const reim4_elem& y); +std::ostream& operator<<(std::ostream& out, const reim4_elem& x); +/** @brief distance in infty norm */ +double infty_dist(const reim4_elem& x, const reim4_elem& y); + +/** @brief test class representing the view of one reim of m complexes */ +class reim4_array_view { + uint64_t size; ///< size of the reim array + double* data; ///< pointer to the start of the array + public: + /** @brief ininitializes a view at an existing given address */ + reim4_array_view(uint64_t size, double* data); + ; + /** @brief gets the i-th element */ + reim4_elem get(uint64_t i) const; + /** @brief sets the i-th element */ + void set(uint64_t i, const reim4_elem& value); +}; + +/** @brief test class representing the view of one matrix of nrowsxncols reim4's */ +class reim4_matrix_view { + uint64_t nrows; ///< number of rows + uint64_t ncols; ///< number of columns + double* data; ///< pointer to the start of the matrix + public: + /** @brief ininitializes a view at an existing given address */ + reim4_matrix_view(uint64_t nrows, uint64_t ncols, double* data); + /** @brief gets the i-th element */ + reim4_elem get(uint64_t row, uint64_t col) const; + /** @brief sets the i-th element */ + void set(uint64_t row, uint64_t col, const reim4_elem& value); +}; + +/** @brief test class representing the view of one reim of m complexes */ +class reim_view { + uint64_t m; ///< (complex) dimension of the reim polynomial + double* data; ///< address of the start of the reim polynomial + public: + /** @brief ininitializes a view at an existing given address */ + reim_view(uint64_t m, double* data); + ; + /** @brief extracts the i-th reim4 block (i +// https://github.com/mjosaarinen/tiny_sha3 +// LICENSE: MIT + +// Revised 07-Aug-15 to match with official release of FIPS PUB 202 "SHA3" +// Revised 03-Sep-15 for portability + OpenSSL - style API + +#include "sha3.h" + +// update the state with given number of rounds + +void sha3_keccakf(uint64_t st[25]) { + // constants + const uint64_t keccakf_rndc[24] = {0x0000000000000001, 0x0000000000008082, 0x800000000000808a, 0x8000000080008000, + 0x000000000000808b, 0x0000000080000001, 0x8000000080008081, 0x8000000000008009, + 0x000000000000008a, 0x0000000000000088, 0x0000000080008009, 0x000000008000000a, + 0x000000008000808b, 0x800000000000008b, 0x8000000000008089, 0x8000000000008003, + 0x8000000000008002, 0x8000000000000080, 0x000000000000800a, 0x800000008000000a, + 0x8000000080008081, 0x8000000000008080, 0x0000000080000001, 0x8000000080008008}; + const int keccakf_rotc[24] = {1, 3, 6, 10, 15, 21, 28, 36, 45, 55, 2, 14, + 27, 41, 56, 8, 25, 43, 62, 18, 39, 61, 20, 44}; + const int keccakf_piln[24] = {10, 7, 11, 17, 18, 3, 5, 16, 8, 21, 24, 4, 15, 23, 19, 13, 12, 2, 20, 14, 22, 9, 6, 1}; + + // variables + int i, j, r; + uint64_t t, bc[5]; + +#if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__ + uint8_t* v; + + // endianess conversion. 
this is redundant on little-endian targets + for (i = 0; i < 25; i++) { + v = (uint8_t*)&st[i]; + st[i] = ((uint64_t)v[0]) | (((uint64_t)v[1]) << 8) | (((uint64_t)v[2]) << 16) | (((uint64_t)v[3]) << 24) | + (((uint64_t)v[4]) << 32) | (((uint64_t)v[5]) << 40) | (((uint64_t)v[6]) << 48) | (((uint64_t)v[7]) << 56); + } +#endif + + // actual iteration + for (r = 0; r < KECCAKF_ROUNDS; r++) { + // Theta + for (i = 0; i < 5; i++) bc[i] = st[i] ^ st[i + 5] ^ st[i + 10] ^ st[i + 15] ^ st[i + 20]; + + for (i = 0; i < 5; i++) { + t = bc[(i + 4) % 5] ^ ROTL64(bc[(i + 1) % 5], 1); + for (j = 0; j < 25; j += 5) st[j + i] ^= t; + } + + // Rho Pi + t = st[1]; + for (i = 0; i < 24; i++) { + j = keccakf_piln[i]; + bc[0] = st[j]; + st[j] = ROTL64(t, keccakf_rotc[i]); + t = bc[0]; + } + + // Chi + for (j = 0; j < 25; j += 5) { + for (i = 0; i < 5; i++) bc[i] = st[j + i]; + for (i = 0; i < 5; i++) st[j + i] ^= (~bc[(i + 1) % 5]) & bc[(i + 2) % 5]; + } + + // Iota + st[0] ^= keccakf_rndc[r]; + } + +#if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__ + // endianess conversion. this is redundant on little-endian targets + for (i = 0; i < 25; i++) { + v = (uint8_t*)&st[i]; + t = st[i]; + v[0] = t & 0xFF; + v[1] = (t >> 8) & 0xFF; + v[2] = (t >> 16) & 0xFF; + v[3] = (t >> 24) & 0xFF; + v[4] = (t >> 32) & 0xFF; + v[5] = (t >> 40) & 0xFF; + v[6] = (t >> 48) & 0xFF; + v[7] = (t >> 56) & 0xFF; + } +#endif +} + +// Initialize the context for SHA3 + +int sha3_init(sha3_ctx_t* c, int mdlen) { + int i; + + for (i = 0; i < 25; i++) c->st.q[i] = 0; + c->mdlen = mdlen; + c->rsiz = 200 - 2 * mdlen; + c->pt = 0; + + return 1; +} + +// update state with more data + +int sha3_update(sha3_ctx_t* c, const void* data, size_t len) { + size_t i; + int j; + + j = c->pt; + for (i = 0; i < len; i++) { + c->st.b[j++] ^= ((const uint8_t*)data)[i]; + if (j >= c->rsiz) { + sha3_keccakf(c->st.q); + j = 0; + } + } + c->pt = j; + + return 1; +} + +// finalize and output a hash + +int sha3_final(void* md, sha3_ctx_t* c) { + int i; + + c->st.b[c->pt] ^= 0x06; + c->st.b[c->rsiz - 1] ^= 0x80; + sha3_keccakf(c->st.q); + + for (i = 0; i < c->mdlen; i++) { + ((uint8_t*)md)[i] = c->st.b[i]; + } + + return 1; +} + +// compute a SHA-3 hash (md) of given byte length from "in" + +void* sha3(const void* in, size_t inlen, void* md, int mdlen) { + sha3_ctx_t sha3; + + sha3_init(&sha3, mdlen); + sha3_update(&sha3, in, inlen); + sha3_final(md, &sha3); + + return md; +} + +// SHAKE128 and SHAKE256 extensible-output functionality + +void shake_xof(sha3_ctx_t* c) { + c->st.b[c->pt] ^= 0x1F; + c->st.b[c->rsiz - 1] ^= 0x80; + sha3_keccakf(c->st.q); + c->pt = 0; +} + +void shake_out(sha3_ctx_t* c, void* out, size_t len) { + size_t i; + int j; + + j = c->pt; + for (i = 0; i < len; i++) { + if (j >= c->rsiz) { + sha3_keccakf(c->st.q); + j = 0; + } + ((uint8_t*)out)[i] = c->st.b[j++]; + } + c->pt = j; +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/sha3.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/sha3.h new file mode 100644 index 0000000..08a7c86 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/sha3.h @@ -0,0 +1,56 @@ +// sha3.h +// 19-Nov-11 Markku-Juhani O. 
Saarinen
+// https://github.com/mjosaarinen/tiny_sha3
+// License: MIT
+
+#ifndef SHA3_H
+#define SHA3_H
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#include <stddef.h>
+#include <stdint.h>
+
+#ifndef KECCAKF_ROUNDS
+#define KECCAKF_ROUNDS 24
+#endif
+
+#ifndef ROTL64
+#define ROTL64(x, y) (((x) << (y)) | ((x) >> (64 - (y))))
+#endif
+
+// state context
+typedef struct {
+  union {  // state:
+    uint8_t b[200];  // 8-bit bytes
+    uint64_t q[25];  // 64-bit words
+  } st;
+  int pt, rsiz, mdlen;  // these don't overflow
+} sha3_ctx_t;
+
+// Compression function.
+void sha3_keccakf(uint64_t st[25]);
+
+// OpenSSL - like interface
+int sha3_init(sha3_ctx_t* c, int mdlen);  // mdlen = hash output in bytes
+int sha3_update(sha3_ctx_t* c, const void* data, size_t len);
+int sha3_final(void* md, sha3_ctx_t* c);  // digest goes to md
+
+// compute a sha3 hash (md) of given byte length from "in"
+void* sha3(const void* in, size_t inlen, void* md, int mdlen);
+
+// SHAKE128 and SHAKE256 extensible-output functions
+#define shake128_init(c) sha3_init(c, 16)
+#define shake256_init(c) sha3_init(c, 32)
+#define shake_update sha3_update
+
+void shake_xof(sha3_ctx_t* c);
+void shake_out(sha3_ctx_t* c, void* out, size_t len);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // SHA3_H
diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/test_commons.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/test_commons.cpp
new file mode 100644
index 0000000..90f3606
--- /dev/null
+++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/test_commons.cpp
@@ -0,0 +1,10 @@
+#include "test_commons.h"
+
+#include <cinttypes>
+
+std::ostream& operator<<(std::ostream& out, __int128_t x) {
+  char c[35] = {0};
+  snprintf(c, 35, "0x%016" PRIx64 "%016" PRIx64, uint64_t(x >> 64), uint64_t(x));
+  return out << c;
+}
+std::ostream& operator<<(std::ostream& out, __uint128_t x) { return out << __int128_t(x); }
diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/test_commons.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/test_commons.h
new file mode 100644
index 0000000..d94f672
--- /dev/null
+++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/test_commons.h
@@ -0,0 +1,75 @@
+#ifndef SPQLIOS_TEST_COMMONS_H
+#define SPQLIOS_TEST_COMMONS_H
+
+#include <iostream>
+#include <random>
+
+#include "../../spqlios/commons.h"
+
+/** @brief macro that crashes if the condition is not met */
+#define REQUIRE_DRAMATICALLY(req_condition, error_msg) \
+  do { \
+    if (!(req_condition)) { \
+      std::cerr << "REQUIREMENT FAILED at " << __FILE__ << ":" << __LINE__ << ": " << error_msg << std::endl; \
+      abort(); \
+    } \
+  } while (0)
+
+typedef std::default_random_engine test_rng;
+/** @brief reference to the default test rng */
+test_rng& randgen();
+/** @brief uniformly random 64-bit uint */
+uint64_t uniform_u64();
+/** @brief uniformly random number <= 2^nbits-1 */
+uint64_t uniform_u64_bits(uint64_t nbits);
+/** @brief uniformly random signed 64-bit number */
+int64_t uniform_i64();
+/** @brief uniformly random signed |number| <= 2^nbits */
+int64_t uniform_i64_bits(uint64_t nbits);
+/** @brief uniformly random signed lb <= number <= ub */
+int64_t uniform_i64_bounds(const int64_t lb, const int64_t ub);
+/** @brief uniformly random signed lb <= number <= ub */
+__int128_t uniform_i128_bounds(const __int128_t lb, const __int128_t ub);
+/** @brief random gaussian float64 of mean 0 and given stdev */
+double random_f64_gaussian(double stdev = 1);
+/** @brief uniformly random float64 with lb <= number <= ub */
+double
uniform_f64_bounds(const double lb, const double ub); +/** @brief uniformly random float64 in [0,1] */ +double uniform_f64_01(); +/** @brief random gaussian float64 */ +double random_f64_gaussian(double stdev); + +bool is_pow2(uint64_t n); + +void* alloc64(uint64_t size); + +typedef __uint128_t thash; +/** @brief returns some pseudorandom hash of a contiguous content */ +thash test_hash(const void* data, uint64_t size); +/** @brief class to return a pseudorandom hash of a piecewise-defined content */ +class test_hasher { + void* md; + + public: + test_hasher(); + test_hasher(const test_hasher&) = delete; + void operator=(const test_hasher&) = delete; + /** + * @brief append input bytes. + * The final hash only depends on the concatenation of bytes, not on the + * way the content was split into multiple calls to update. + */ + void update(const void* data, uint64_t size); + /** + * @brief returns the final hash. + * no more calls to update(...) shall be issued after this call. + */ + thash hash(); + ~test_hasher(); +}; + +// not included by default, since it makes some versions of gtest not compile +// std::ostream& operator<<(std::ostream& out, __int128_t x); +// std::ostream& operator<<(std::ostream& out, __uint128_t x); + +#endif // SPQLIOS_TEST_COMMONS_H diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/test_hash.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/test_hash.cpp new file mode 100644 index 0000000..2e065af --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/test_hash.cpp @@ -0,0 +1,24 @@ +#include "sha3.h" +#include "test_commons.h" + +/** @brief returns some pseudorandom hash of the content */ +thash test_hash(const void* data, uint64_t size) { + thash res; + sha3(data, size, &res, sizeof(res)); + return res; +} +/** @brief class to return a pseudorandom hash of the content */ +test_hasher::test_hasher() { + md = malloc(sizeof(sha3_ctx_t)); + sha3_init((sha3_ctx_t*)md, 16); +} + +void test_hasher::update(const void* data, uint64_t size) { sha3_update((sha3_ctx_t*)md, data, size); } + +thash test_hasher::hash() { + thash res; + sha3_final(&res, (sha3_ctx_t*)md); + return res; +} + +test_hasher::~test_hasher() { free(md); } diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/vec_rnx_layout.cpp b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/vec_rnx_layout.cpp new file mode 100644 index 0000000..2a61e81 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/vec_rnx_layout.cpp @@ -0,0 +1,182 @@ +#include "vec_rnx_layout.h" + +#include + +#include "../../spqlios/arithmetic/vec_rnx_arithmetic.h" + +#ifdef VALGRIND_MEM_TESTS +#include "valgrind/memcheck.h" +#endif + +#define CANARY_PADDING (1024) +#define GARBAGE_VALUE (242) + +rnx_vec_f64_layout::rnx_vec_f64_layout(uint64_t n, uint64_t size, uint64_t slice) : n(n), size(size), slice(slice) { + REQUIRE_DRAMATICALLY(is_pow2(n), "not a power of 2" << n); + REQUIRE_DRAMATICALLY(slice >= n, "slice too small" << slice << " < " << n); + this->region = (uint8_t*)malloc(size * slice * sizeof(int64_t) + 2 * CANARY_PADDING); + this->data_start = (double*)(region + CANARY_PADDING); + // ensure that any invalid value is kind-of garbage + memset(region, GARBAGE_VALUE, size * slice * sizeof(int64_t) + 2 * CANARY_PADDING); + // mark inter-slice memory as not accessible +#ifdef VALGRIND_MEM_TESTS + VALGRIND_MAKE_MEM_NOACCESS(region, CANARY_PADDING); + VALGRIND_MAKE_MEM_NOACCESS(region + size * slice * 
sizeof(int64_t) + CANARY_PADDING, CANARY_PADDING); + for (uint64_t i = 0; i < size; ++i) { + VALGRIND_MAKE_MEM_UNDEFINED(data_start + i * slice, n * sizeof(int64_t)); + } + if (size != slice) { + for (uint64_t i = 0; i < size; ++i) { + VALGRIND_MAKE_MEM_NOACCESS(data_start + i * slice + n, (slice - n) * sizeof(int64_t)); + } + } +#endif +} + +rnx_vec_f64_layout::~rnx_vec_f64_layout() { free(region); } + +rnx_f64 rnx_vec_f64_layout::get_copy_zext(uint64_t index) const { + if (index < size) { + return rnx_f64(n, data_start + index * slice); + } else { + return rnx_f64::zero(n); + } +} + +rnx_f64 rnx_vec_f64_layout::get_copy(uint64_t index) const { + REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size); + return rnx_f64(n, data_start + index * slice); +} + +reim_fft64vec rnx_vec_f64_layout::get_dft_copy_zext(uint64_t index) const { + if (index < size) { + return reim_fft64vec(n, data_start + index * slice); + } else { + return reim_fft64vec::zero(n); + } +} + +reim_fft64vec rnx_vec_f64_layout::get_dft_copy(uint64_t index) const { + REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size); + return reim_fft64vec(n, data_start + index * slice); +} + +void rnx_vec_f64_layout::set(uint64_t index, const rnx_f64& elem) { + REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size); + REQUIRE_DRAMATICALLY(elem.nn() == n, "incompatible ring dimensions: " << elem.nn() << " / " << n); + elem.save_as(data_start + index * slice); +} + +double* rnx_vec_f64_layout::data() { return data_start; } +const double* rnx_vec_f64_layout::data() const { return data_start; } + +void rnx_vec_f64_layout::fill_random(double log2bound) { + for (uint64_t i = 0; i < size; ++i) { + set(i, rnx_f64::random_log2bound(n, log2bound)); + } +} + +thash rnx_vec_f64_layout::content_hash() const { + test_hasher hasher; + for (uint64_t i = 0; i < size; ++i) { + hasher.update(data() + i * slice, n * sizeof(int64_t)); + } + return hasher.hash(); +} + +fft64_rnx_vmp_pmat_layout::fft64_rnx_vmp_pmat_layout(uint64_t n, uint64_t nrows, uint64_t ncols) + : nn(n), + nrows(nrows), + ncols(ncols), // + data((RNX_VMP_PMAT*)alloc64(nrows * ncols * nn * 8)) {} + +double* fft64_rnx_vmp_pmat_layout::get_addr(uint64_t row, uint64_t col, uint64_t blk) const { + REQUIRE_DRAMATICALLY(row < nrows, "row overflow: " << row << " / " << nrows); + REQUIRE_DRAMATICALLY(col < ncols, "col overflow: " << col << " / " << ncols); + REQUIRE_DRAMATICALLY(blk < nn / 8, "block overflow: " << blk << " / " << (nn / 8)); + double* d = (double*)data; + if (col == (ncols - 1) && (ncols % 2 == 1)) { + // special case: last column out of an odd column number + return d + blk * nrows * ncols * 8 // major: blk + + col * nrows * 8 // col == ncols-1 + + row * 8; + } else { + // general case: columns go by pair + return d + blk * nrows * ncols * 8 // major: blk + + (col / 2) * (2 * nrows) * 8 // second: col pair index + + row * 2 * 8 // third: row index + + (col % 2) * 8; // minor: col in colpair + } +} + +reim4_elem fft64_rnx_vmp_pmat_layout::get(uint64_t row, uint64_t col, uint64_t blk) const { + return reim4_elem(get_addr(row, col, blk)); +} +reim4_elem fft64_rnx_vmp_pmat_layout::get_zext(uint64_t row, uint64_t col, uint64_t blk) const { + REQUIRE_DRAMATICALLY(blk < nn / 8, "block overflow: " << blk << " / " << (nn / 8)); + if (row < nrows && col < ncols) { + return reim4_elem(get_addr(row, col, blk)); + } else { + return reim4_elem::zero(); + } +} +void fft64_rnx_vmp_pmat_layout::set(uint64_t row, 
uint64_t col, uint64_t blk, const reim4_elem& value) const { + value.save_as(get_addr(row, col, blk)); +} + +fft64_rnx_vmp_pmat_layout::~fft64_rnx_vmp_pmat_layout() { spqlios_free(data); } + +reim_fft64vec fft64_rnx_vmp_pmat_layout::get_zext(uint64_t row, uint64_t col) const { + if (row >= nrows || col >= ncols) { + return reim_fft64vec::zero(nn); + } + if (nn < 8) { + // the pmat is just col major + double* addr = (double*)data + (row + col * nrows) * nn; + return reim_fft64vec(nn, addr); + } + // otherwise, reconstruct it block by block + reim_fft64vec res(nn); + for (uint64_t blk = 0; blk < nn / 8; ++blk) { + reim4_elem v = get(row, col, blk); + res.set_blk(blk, v); + } + return res; +} +void fft64_rnx_vmp_pmat_layout::set(uint64_t row, uint64_t col, const reim_fft64vec& value) { + REQUIRE_DRAMATICALLY(row < nrows, "row overflow: " << row << " / " << nrows); + REQUIRE_DRAMATICALLY(col < ncols, "row overflow: " << col << " / " << ncols); + if (nn < 8) { + // the pmat is just col major + double* addr = (double*)data + (row + col * nrows) * nn; + value.save_as(addr); + return; + } + // otherwise, reconstruct it block by block + for (uint64_t blk = 0; blk < nn / 8; ++blk) { + reim4_elem v = value.get_blk(blk); + set(row, col, blk, v); + } +} +void fft64_rnx_vmp_pmat_layout::fill_random(double log2bound) { + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + set(row, col, reim_fft64vec::random(nn, log2bound)); + } + } +} + +fft64_rnx_svp_ppol_layout::fft64_rnx_svp_ppol_layout(uint64_t n) + : nn(n), // + data((RNX_SVP_PPOL*)alloc64(nn * 8)) {} + +reim_fft64vec fft64_rnx_svp_ppol_layout::get_copy() const { return reim_fft64vec(nn, (double*)data); } + +void fft64_rnx_svp_ppol_layout::set(const reim_fft64vec& value) { value.save_as((double*)data); } + +void fft64_rnx_svp_ppol_layout::fill_dft_random(uint64_t log2bound) { set(reim_fft64vec::dft_random(nn, log2bound)); } + +void fft64_rnx_svp_ppol_layout::fill_random(double log2bound) { set(reim_fft64vec::random(nn, log2bound)); } + +fft64_rnx_svp_ppol_layout::~fft64_rnx_svp_ppol_layout() { spqlios_free(data); } +thash fft64_rnx_svp_ppol_layout::content_hash() const { return test_hash(data, nn * sizeof(double)); } \ No newline at end of file diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/vec_rnx_layout.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/vec_rnx_layout.h new file mode 100644 index 0000000..a92bc04 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/vec_rnx_layout.h @@ -0,0 +1,85 @@ +#ifndef SPQLIOS_EXT_VEC_RNX_LAYOUT_H +#define SPQLIOS_EXT_VEC_RNX_LAYOUT_H + +#include "../../spqlios/arithmetic/vec_rnx_arithmetic.h" +#include "fft64_dft.h" +#include "negacyclic_polynomial.h" +#include "reim4_elem.h" +#include "test_commons.h" + +/** @brief a test memory layout for rnx i64 polynomials vectors */ +class rnx_vec_f64_layout { + uint64_t n; + uint64_t size; + uint64_t slice; + double* data_start; + uint8_t* region; + + public: + // NO-COPY structure + rnx_vec_f64_layout(const rnx_vec_f64_layout&) = delete; + void operator=(const rnx_vec_f64_layout&) = delete; + rnx_vec_f64_layout(rnx_vec_f64_layout&&) = delete; + void operator=(rnx_vec_f64_layout&&) = delete; + /** @brief initialises a memory layout */ + rnx_vec_f64_layout(uint64_t n, uint64_t size, uint64_t slice); + /** @brief destructor */ + ~rnx_vec_f64_layout(); + + /** @brief get a copy of item index index (extended with zeros) */ + rnx_f64 get_copy_zext(uint64_t 
index) const; + /** @brief get a copy of item index index (extended with zeros) */ + rnx_f64 get_copy(uint64_t index) const; + /** @brief get a copy of item index index (extended with zeros) */ + reim_fft64vec get_dft_copy_zext(uint64_t index) const; + /** @brief get a copy of item index index (extended with zeros) */ + reim_fft64vec get_dft_copy(uint64_t index) const; + + /** @brief get a copy of item index index (index> 5; + const uint64_t rem_ncols = ncols & 31; + uint64_t blk = col >> 5; + uint64_t col_rem = col & 31; + if (blk < nblk) { + // column is part of a full block + return (int32_t*)data + blk * nrows * 32 + row * 32 + col_rem; + } else { + // column is part of the last block + return (int32_t*)data + blk * nrows * 32 + row * rem_ncols + col_rem; + } +} +int32_t zn32_pmat_layout::get(uint64_t row, uint64_t col) const { return *get_addr(row, col); } +int32_t zn32_pmat_layout::get_zext(uint64_t row, uint64_t col) const { + if (row >= nrows || col >= ncols) return 0; + return *get_addr(row, col); +} +void zn32_pmat_layout::set(uint64_t row, uint64_t col, int32_t value) { *get_addr(row, col) = value; } +void zn32_pmat_layout::fill_random() { + int32_t* d = (int32_t*)data; + for (uint64_t i = 0; i < nrows * ncols; ++i) d[i] = uniform_i64_bits(32); +} +thash zn32_pmat_layout::content_hash() const { return test_hash(data, nrows * ncols * sizeof(int32_t)); } + +template +std::vector vmp_product(const T* vec, uint64_t vec_size, uint64_t out_size, const zn32_pmat_layout& mat) { + uint64_t rows = std::min(vec_size, mat.nrows); + uint64_t cols = std::min(out_size, mat.ncols); + std::vector res(out_size, 0); + for (uint64_t j = 0; j < cols; ++j) { + for (uint64_t i = 0; i < rows; ++i) { + res[j] += vec[i] * mat.get(i, j); + } + } + return res; +} + +template std::vector vmp_product(const int8_t* vec, uint64_t vec_size, uint64_t out_size, + const zn32_pmat_layout& mat); +template std::vector vmp_product(const int16_t* vec, uint64_t vec_size, uint64_t out_size, + const zn32_pmat_layout& mat); +template std::vector vmp_product(const int32_t* vec, uint64_t vec_size, uint64_t out_size, + const zn32_pmat_layout& mat); diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/zn_layouts.h b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/zn_layouts.h new file mode 100644 index 0000000..b36ce3e --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/test/testlib/zn_layouts.h @@ -0,0 +1,29 @@ +#ifndef SPQLIOS_EXT_ZN_LAYOUTS_H +#define SPQLIOS_EXT_ZN_LAYOUTS_H + +#include "../../spqlios/arithmetic/zn_arithmetic.h" +#include "test_commons.h" + +class zn32_pmat_layout { + public: + const uint64_t nrows; + const uint64_t ncols; + ZN32_VMP_PMAT* const data; + zn32_pmat_layout(uint64_t nrows, uint64_t ncols); + + private: + int32_t* get_addr(uint64_t row, uint64_t col) const; + + public: + int32_t get(uint64_t row, uint64_t col) const; + int32_t get_zext(uint64_t row, uint64_t col) const; + void set(uint64_t row, uint64_t col, int32_t value); + void fill_random(); + thash content_hash() const; + ~zn32_pmat_layout(); +}; + +template +std::vector vmp_product(const T* vec, uint64_t vec_size, uint64_t out_size, const zn32_pmat_layout& mat); + +#endif // SPQLIOS_EXT_ZN_LAYOUTS_H diff --git a/poulpy-backend/src/cpu_spqlios/test/mod.rs b/poulpy-backend/src/cpu_spqlios/test/mod.rs new file mode 100644 index 0000000..3146d6e --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/test/mod.rs @@ -0,0 +1,2 @@ +mod vec_znx_fft64; +mod vmp_pmat_fft64; diff --git 
a/poulpy-backend/src/implementation/cpu_spqlios/test/vec_znx_fft64.rs b/poulpy-backend/src/cpu_spqlios/test/vec_znx_fft64.rs similarity index 60% rename from poulpy-backend/src/implementation/cpu_spqlios/test/vec_znx_fft64.rs rename to poulpy-backend/src/cpu_spqlios/test/vec_znx_fft64.rs index 35043f1..9e378dc 100644 --- a/poulpy-backend/src/implementation/cpu_spqlios/test/vec_znx_fft64.rs +++ b/poulpy-backend/src/cpu_spqlios/test/vec_znx_fft64.rs @@ -1,12 +1,11 @@ -use crate::{ - hal::{ - api::ModuleNew, - layouts::Module, - tests::vec_znx::{test_vec_znx_add_normal, test_vec_znx_fill_uniform}, - }, - implementation::cpu_spqlios::FFT64, +use poulpy_hal::{ + api::ModuleNew, + layouts::Module, + tests::vec_znx::{test_vec_znx_add_normal, test_vec_znx_fill_uniform}, }; +use crate::cpu_spqlios::FFT64; + #[test] fn test_vec_znx_fill_uniform_fft64() { let module: Module = Module::::new(1 << 12); diff --git a/poulpy-backend/src/cpu_spqlios/test/vmp_pmat_fft64.rs b/poulpy-backend/src/cpu_spqlios/test/vmp_pmat_fft64.rs new file mode 100644 index 0000000..7354d73 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/test/vmp_pmat_fft64.rs @@ -0,0 +1,8 @@ +use poulpy_hal::tests::vmp_pmat::test_vmp_apply; + +use crate::cpu_spqlios::FFT64; + +#[test] +fn vmp_apply() { + test_vmp_apply::(); +} diff --git a/poulpy-backend/src/hal/api/module.rs b/poulpy-backend/src/hal/api/module.rs deleted file mode 100644 index 7ab8672..0000000 --- a/poulpy-backend/src/hal/api/module.rs +++ /dev/null @@ -1,6 +0,0 @@ -use crate::hal::layouts::Backend; - -/// Instantiate a new [crate::hal::layouts::Module]. -pub trait ModuleNew { - fn new(n: u64) -> Self; -} diff --git a/poulpy-backend/src/hal/layouts/scratch.rs b/poulpy-backend/src/hal/layouts/scratch.rs deleted file mode 100644 index 0562939..0000000 --- a/poulpy-backend/src/hal/layouts/scratch.rs +++ /dev/null @@ -1,13 +0,0 @@ -use std::marker::PhantomData; - -use crate::hal::layouts::Backend; - -pub struct ScratchOwned { - pub(crate) data: Vec, - pub(crate) _phantom: PhantomData, -} - -pub struct Scratch { - pub(crate) _phantom: PhantomData, - pub(crate) data: [u8], -} diff --git a/poulpy-backend/src/hal/mod.rs b/poulpy-backend/src/hal/mod.rs deleted file mode 100644 index e94b4e2..0000000 --- a/poulpy-backend/src/hal/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -pub mod api; -pub mod delegates; -pub mod layouts; -pub mod oep; -pub mod source; -pub mod tests; diff --git a/poulpy-backend/src/implementation/cpu_spqlios/ffi/mod.rs b/poulpy-backend/src/implementation/cpu_spqlios/ffi/mod.rs deleted file mode 100644 index 4213310..0000000 --- a/poulpy-backend/src/implementation/cpu_spqlios/ffi/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -pub mod module; -pub mod reim; -pub mod svp; -pub mod vec_znx; -pub mod vec_znx_big; -pub mod vec_znx_dft; -pub mod vmp; -pub mod znx; diff --git a/poulpy-backend/src/implementation/cpu_spqlios/ffi/reim.rs b/poulpy-backend/src/implementation/cpu_spqlios/ffi/reim.rs deleted file mode 100644 index a3ce548..0000000 --- a/poulpy-backend/src/implementation/cpu_spqlios/ffi/reim.rs +++ /dev/null @@ -1,172 +0,0 @@ -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct reim_fft_precomp { - _unused: [u8; 0], -} -pub type REIM_FFT_PRECOMP = reim_fft_precomp; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct reim_ifft_precomp { - _unused: [u8; 0], -} -pub type REIM_IFFT_PRECOMP = reim_ifft_precomp; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct reim_mul_precomp { - _unused: [u8; 0], -} -pub type REIM_FFTVEC_MUL_PRECOMP = reim_mul_precomp; -#[repr(C)] 
-#[derive(Debug, Copy, Clone)] -pub struct reim_addmul_precomp { - _unused: [u8; 0], -} -pub type REIM_FFTVEC_ADDMUL_PRECOMP = reim_addmul_precomp; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct reim_from_znx32_precomp { - _unused: [u8; 0], -} -pub type REIM_FROM_ZNX32_PRECOMP = reim_from_znx32_precomp; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct reim_from_znx64_precomp { - _unused: [u8; 0], -} -pub type REIM_FROM_ZNX64_PRECOMP = reim_from_znx64_precomp; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct reim_from_tnx32_precomp { - _unused: [u8; 0], -} -pub type REIM_FROM_TNX32_PRECOMP = reim_from_tnx32_precomp; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct reim_to_tnx32_precomp { - _unused: [u8; 0], -} -pub type REIM_TO_TNX32_PRECOMP = reim_to_tnx32_precomp; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct reim_to_tnx_precomp { - _unused: [u8; 0], -} -pub type REIM_TO_TNX_PRECOMP = reim_to_tnx_precomp; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct reim_to_znx64_precomp { - _unused: [u8; 0], -} -pub type REIM_TO_ZNX64_PRECOMP = reim_to_znx64_precomp; -unsafe extern "C" { - pub unsafe fn new_reim_fft_precomp(m: u32, num_buffers: u32) -> *mut REIM_FFT_PRECOMP; -} -unsafe extern "C" { - pub unsafe fn reim_fft_precomp_get_buffer(tables: *const REIM_FFT_PRECOMP, buffer_index: u32) -> *mut f64; -} -unsafe extern "C" { - pub unsafe fn new_reim_fft_buffer(m: u32) -> *mut f64; -} -unsafe extern "C" { - pub unsafe fn delete_reim_fft_buffer(buffer: *mut f64); -} -unsafe extern "C" { - pub unsafe fn reim_fft(tables: *const REIM_FFT_PRECOMP, data: *mut f64); -} -unsafe extern "C" { - pub unsafe fn new_reim_ifft_precomp(m: u32, num_buffers: u32) -> *mut REIM_IFFT_PRECOMP; -} -unsafe extern "C" { - pub unsafe fn reim_ifft_precomp_get_buffer(tables: *const REIM_IFFT_PRECOMP, buffer_index: u32) -> *mut f64; -} -unsafe extern "C" { - pub unsafe fn reim_ifft(tables: *const REIM_IFFT_PRECOMP, data: *mut f64); -} -unsafe extern "C" { - pub unsafe fn new_reim_fftvec_mul_precomp(m: u32) -> *mut REIM_FFTVEC_MUL_PRECOMP; -} -unsafe extern "C" { - pub unsafe fn reim_fftvec_mul(tables: *const REIM_FFTVEC_MUL_PRECOMP, r: *mut f64, a: *const f64, b: *const f64); -} -unsafe extern "C" { - pub unsafe fn new_reim_fftvec_addmul_precomp(m: u32) -> *mut REIM_FFTVEC_ADDMUL_PRECOMP; -} -unsafe extern "C" { - pub unsafe fn reim_fftvec_addmul(tables: *const REIM_FFTVEC_ADDMUL_PRECOMP, r: *mut f64, a: *const f64, b: *const f64); -} -unsafe extern "C" { - pub unsafe fn new_reim_from_znx32_precomp(m: u32, log2bound: u32) -> *mut REIM_FROM_ZNX32_PRECOMP; -} -unsafe extern "C" { - pub unsafe fn reim_from_znx32(tables: *const REIM_FROM_ZNX32_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i32); -} -unsafe extern "C" { - pub unsafe fn reim_from_znx64(tables: *const REIM_FROM_ZNX64_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i64); -} -unsafe extern "C" { - pub unsafe fn new_reim_from_znx64_precomp(m: u32, maxbnd: u32) -> *mut REIM_FROM_ZNX64_PRECOMP; -} -unsafe extern "C" { - pub unsafe fn reim_from_znx64_simple(m: u32, log2bound: u32, r: *mut ::std::os::raw::c_void, a: *const i64); -} -unsafe extern "C" { - pub unsafe fn new_reim_from_tnx32_precomp(m: u32) -> *mut REIM_FROM_TNX32_PRECOMP; -} -unsafe extern "C" { - pub unsafe fn reim_from_tnx32(tables: *const REIM_FROM_TNX32_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i32); -} -unsafe extern "C" { - pub unsafe fn new_reim_to_tnx32_precomp(m: u32, divisor: f64, log2overhead: u32) -> *mut REIM_TO_TNX32_PRECOMP; -} -unsafe 
extern "C" { - pub unsafe fn reim_to_tnx32(tables: *const REIM_TO_TNX32_PRECOMP, r: *mut i32, a: *const ::std::os::raw::c_void); -} -unsafe extern "C" { - pub unsafe fn new_reim_to_tnx_precomp(m: u32, divisor: f64, log2overhead: u32) -> *mut REIM_TO_TNX_PRECOMP; -} -unsafe extern "C" { - pub unsafe fn reim_to_tnx(tables: *const REIM_TO_TNX_PRECOMP, r: *mut f64, a: *const f64); -} -unsafe extern "C" { - pub unsafe fn reim_to_tnx_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut f64, a: *const f64); -} -unsafe extern "C" { - pub unsafe fn new_reim_to_znx64_precomp(m: u32, divisor: f64, log2bound: u32) -> *mut REIM_TO_ZNX64_PRECOMP; -} -unsafe extern "C" { - pub unsafe fn reim_to_znx64(precomp: *const REIM_TO_ZNX64_PRECOMP, r: *mut i64, a: *const ::std::os::raw::c_void); -} -unsafe extern "C" { - pub unsafe fn reim_to_znx64_simple(m: u32, divisor: f64, log2bound: u32, r: *mut i64, a: *const ::std::os::raw::c_void); -} -unsafe extern "C" { - pub unsafe fn reim_fft_simple(m: u32, data: *mut ::std::os::raw::c_void); -} -unsafe extern "C" { - pub unsafe fn reim_ifft_simple(m: u32, data: *mut ::std::os::raw::c_void); -} -unsafe extern "C" { - pub unsafe fn reim_fftvec_mul_simple( - m: u32, - r: *mut ::std::os::raw::c_void, - a: *const ::std::os::raw::c_void, - b: *const ::std::os::raw::c_void, - ); -} -unsafe extern "C" { - pub unsafe fn reim_fftvec_addmul_simple( - m: u32, - r: *mut ::std::os::raw::c_void, - a: *const ::std::os::raw::c_void, - b: *const ::std::os::raw::c_void, - ); -} -unsafe extern "C" { - pub unsafe fn reim_from_znx32_simple(m: u32, log2bound: u32, r: *mut ::std::os::raw::c_void, x: *const i32); -} -unsafe extern "C" { - pub unsafe fn reim_from_tnx32_simple(m: u32, r: *mut ::std::os::raw::c_void, x: *const i32); -} -unsafe extern "C" { - pub unsafe fn reim_to_tnx32_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut i32, x: *const ::std::os::raw::c_void); -} diff --git a/poulpy-backend/src/implementation/cpu_spqlios/ffi/znx.rs b/poulpy-backend/src/implementation/cpu_spqlios/ffi/znx.rs deleted file mode 100644 index f03da0a..0000000 --- a/poulpy-backend/src/implementation/cpu_spqlios/ffi/znx.rs +++ /dev/null @@ -1,79 +0,0 @@ -use crate::implementation::cpu_spqlios::ffi::module::MODULE; - -unsafe extern "C" { - pub unsafe fn znx_add_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64); -} -unsafe extern "C" { - pub unsafe fn znx_add_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64); -} -unsafe extern "C" { - pub unsafe fn znx_sub_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64); -} -unsafe extern "C" { - pub unsafe fn znx_sub_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64); -} -unsafe extern "C" { - pub unsafe fn znx_negate_i64_ref(nn: u64, res: *mut i64, a: *const i64); -} -unsafe extern "C" { - pub unsafe fn znx_negate_i64_avx(nn: u64, res: *mut i64, a: *const i64); -} -unsafe extern "C" { - pub unsafe fn znx_copy_i64_ref(nn: u64, res: *mut i64, a: *const i64); -} -unsafe extern "C" { - pub unsafe fn znx_zero_i64_ref(nn: u64, res: *mut i64); -} -unsafe extern "C" { - pub unsafe fn rnx_divide_by_m_ref(nn: u64, m: f64, res: *mut f64, a: *const f64); -} -unsafe extern "C" { - pub unsafe fn rnx_divide_by_m_avx(nn: u64, m: f64, res: *mut f64, a: *const f64); -} -unsafe extern "C" { - pub unsafe fn rnx_rotate_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64); -} -unsafe extern "C" { - pub unsafe fn znx_rotate_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64); -} -unsafe extern "C" { - pub unsafe fn 
rnx_rotate_inplace_f64(nn: u64, p: i64, res: *mut f64); -} -unsafe extern "C" { - pub unsafe fn znx_rotate_inplace_i64(nn: u64, p: i64, res: *mut i64); -} -unsafe extern "C" { - pub unsafe fn rnx_automorphism_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64); -} -unsafe extern "C" { - pub unsafe fn znx_automorphism_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64); -} -unsafe extern "C" { - pub unsafe fn rnx_automorphism_inplace_f64(nn: u64, p: i64, res: *mut f64); -} -unsafe extern "C" { - pub unsafe fn znx_automorphism_inplace_i64(nn: u64, p: i64, res: *mut i64); -} -unsafe extern "C" { - pub unsafe fn rnx_mul_xp_minus_one_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64); -} -unsafe extern "C" { - pub unsafe fn znx_mul_xp_minus_one_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64); -} -unsafe extern "C" { - pub unsafe fn rnx_mul_xp_minus_one_inplace_f64(nn: u64, p: i64, res: *mut f64); -} -unsafe extern "C" { - pub unsafe fn znx_mul_xp_minus_one_inplace_i64(nn: u64, p: i64, res: *mut i64); -} -unsafe extern "C" { - pub unsafe fn znx_normalize(nn: u64, base_k: u64, out: *mut i64, carry_out: *mut i64, in_: *const i64, carry_in: *const i64); -} - -unsafe extern "C" { - pub unsafe fn znx_small_single_product(module: *const MODULE, res: *mut i64, a: *const i64, b: *const i64, tmp: *mut u8); -} - -unsafe extern "C" { - pub unsafe fn znx_small_single_product_tmp_bytes(module: *const MODULE) -> u64; -} diff --git a/poulpy-backend/src/implementation/cpu_spqlios/mod.rs b/poulpy-backend/src/implementation/cpu_spqlios/mod.rs deleted file mode 100644 index 570a23b..0000000 --- a/poulpy-backend/src/implementation/cpu_spqlios/mod.rs +++ /dev/null @@ -1,27 +0,0 @@ -mod ffi; -mod module_fft64; -mod module_ntt120; -mod scratch; -mod svp_ppol_fft64; -mod svp_ppol_ntt120; -mod vec_znx; -mod vec_znx_big_fft64; -mod vec_znx_big_ntt120; -mod vec_znx_dft_fft64; -mod vec_znx_dft_ntt120; -mod vmp_pmat_fft64; -mod vmp_pmat_ntt120; - -#[cfg(test)] -mod test; - -pub use module_fft64::*; -pub use module_ntt120::*; - -/// For external documentation -pub use vec_znx::{ - vec_znx_copy_ref, vec_znx_lsh_inplace_ref, vec_znx_merge_ref, vec_znx_rsh_inplace_ref, vec_znx_split_ref, - vec_znx_switch_degree_ref, -}; - -pub trait CPUAVX {} diff --git a/poulpy-backend/src/implementation/cpu_spqlios/spqlios-arithmetic b/poulpy-backend/src/implementation/cpu_spqlios/spqlios-arithmetic deleted file mode 160000 index de62af3..0000000 --- a/poulpy-backend/src/implementation/cpu_spqlios/spqlios-arithmetic +++ /dev/null @@ -1 +0,0 @@ -Subproject commit de62af3507776597231e0c0d2b26495a0c92d207 diff --git a/poulpy-backend/src/implementation/cpu_spqlios/svp_ppol_ntt120.rs b/poulpy-backend/src/implementation/cpu_spqlios/svp_ppol_ntt120.rs deleted file mode 100644 index 39a84f9..0000000 --- a/poulpy-backend/src/implementation/cpu_spqlios/svp_ppol_ntt120.rs +++ /dev/null @@ -1,44 +0,0 @@ -use crate::{ - hal::{ - api::{ZnxInfos, ZnxSliceSize, ZnxView}, - layouts::{Data, DataRef, SvpPPol, SvpPPolBytesOf, SvpPPolOwned}, - oep::{SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl}, - }, - implementation::cpu_spqlios::module_ntt120::NTT120, -}; - -const SVP_PPOL_NTT120_WORD_SIZE: usize = 4; - -impl SvpPPolBytesOf for SvpPPol { - fn bytes_of(n: usize, cols: usize) -> usize { - SVP_PPOL_NTT120_WORD_SIZE * n * cols * size_of::() - } -} - -impl ZnxSliceSize for SvpPPol { - fn sl(&self) -> usize { - SVP_PPOL_NTT120_WORD_SIZE * self.n() - } -} - -impl ZnxView for SvpPPol { - type Scalar = i64; -} - -unsafe impl 
SvpPPolFromBytesImpl for NTT120 { - fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec) -> SvpPPolOwned { - SvpPPolOwned::from_bytes(n, cols, bytes) - } -} - -unsafe impl SvpPPolAllocImpl for NTT120 { - fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned { - SvpPPolOwned::alloc(n, cols) - } -} - -unsafe impl SvpPPolAllocBytesImpl for NTT120 { - fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize { - SvpPPol::, Self>::bytes_of(n, cols) - } -} diff --git a/poulpy-backend/src/implementation/cpu_spqlios/test/mod.rs b/poulpy-backend/src/implementation/cpu_spqlios/test/mod.rs deleted file mode 100644 index 636f5a4..0000000 --- a/poulpy-backend/src/implementation/cpu_spqlios/test/mod.rs +++ /dev/null @@ -1 +0,0 @@ -mod vec_znx_fft64; diff --git a/poulpy-backend/src/implementation/cpu_spqlios/vec_znx_big_ntt120.rs b/poulpy-backend/src/implementation/cpu_spqlios/vec_znx_big_ntt120.rs deleted file mode 100644 index 42e632a..0000000 --- a/poulpy-backend/src/implementation/cpu_spqlios/vec_znx_big_ntt120.rs +++ /dev/null @@ -1,32 +0,0 @@ -use crate::{ - hal::{ - api::{ZnxInfos, ZnxSliceSize, ZnxView}, - layouts::{Data, DataRef, VecZnxBig, VecZnxBigBytesOf}, - oep::VecZnxBigAllocBytesImpl, - }, - implementation::cpu_spqlios::module_ntt120::NTT120, -}; - -const VEC_ZNX_BIG_NTT120_WORDSIZE: usize = 4; - -impl ZnxView for VecZnxBig { - type Scalar = i128; -} - -impl VecZnxBigBytesOf for VecZnxBig { - fn bytes_of(n: usize, cols: usize, size: usize) -> usize { - VEC_ZNX_BIG_NTT120_WORDSIZE * n * cols * size * size_of::() - } -} - -impl ZnxSliceSize for VecZnxBig { - fn sl(&self) -> usize { - VEC_ZNX_BIG_NTT120_WORDSIZE * self.n() * self.cols() - } -} - -unsafe impl VecZnxBigAllocBytesImpl for NTT120 { - fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { - VecZnxBig::, NTT120>::bytes_of(n, cols, size) - } -} diff --git a/poulpy-backend/src/implementation/cpu_spqlios/vec_znx_dft_ntt120.rs b/poulpy-backend/src/implementation/cpu_spqlios/vec_znx_dft_ntt120.rs deleted file mode 100644 index 3378fb0..0000000 --- a/poulpy-backend/src/implementation/cpu_spqlios/vec_znx_dft_ntt120.rs +++ /dev/null @@ -1,38 +0,0 @@ -use crate::{ - hal::{ - api::{ZnxInfos, ZnxSliceSize, ZnxView}, - layouts::{Data, DataRef, VecZnxDft, VecZnxDftBytesOf, VecZnxDftOwned}, - oep::{VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl}, - }, - implementation::cpu_spqlios::module_ntt120::NTT120, -}; - -const VEC_ZNX_DFT_NTT120_WORDSIZE: usize = 4; - -impl ZnxSliceSize for VecZnxDft { - fn sl(&self) -> usize { - VEC_ZNX_DFT_NTT120_WORDSIZE * self.n() * self.cols() - } -} - -impl VecZnxDftBytesOf for VecZnxDft { - fn bytes_of(n: usize, cols: usize, size: usize) -> usize { - VEC_ZNX_DFT_NTT120_WORDSIZE * n * cols * size * size_of::() - } -} - -impl ZnxView for VecZnxDft { - type Scalar = i64; -} - -unsafe impl VecZnxDftAllocBytesImpl for NTT120 { - fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { - VecZnxDft::, NTT120>::bytes_of(n, cols, size) - } -} - -unsafe impl VecZnxDftAllocImpl for NTT120 { - fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned { - VecZnxDftOwned::alloc(n, cols, size) - } -} diff --git a/poulpy-backend/src/implementation/cpu_spqlios/vmp_pmat_ntt120.rs b/poulpy-backend/src/implementation/cpu_spqlios/vmp_pmat_ntt120.rs deleted file mode 100644 index af135bf..0000000 --- a/poulpy-backend/src/implementation/cpu_spqlios/vmp_pmat_ntt120.rs +++ /dev/null @@ -1,11 +0,0 @@ -use crate::{ - hal::{ - api::ZnxView, - 
layouts::{DataRef, VmpPMat}, - }, - implementation::cpu_spqlios::module_ntt120::NTT120, -}; - -impl ZnxView for VmpPMat { - type Scalar = i64; -} diff --git a/poulpy-backend/src/implementation/mod.rs b/poulpy-backend/src/implementation/mod.rs deleted file mode 100644 index 15632e0..0000000 --- a/poulpy-backend/src/implementation/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod cpu_spqlios; diff --git a/poulpy-backend/src/lib.rs b/poulpy-backend/src/lib.rs index 981c679..15632e0 100644 --- a/poulpy-backend/src/lib.rs +++ b/poulpy-backend/src/lib.rs @@ -1,106 +1 @@ -#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)] -#![deny(rustdoc::broken_intra_doc_links)] -#![cfg_attr(docsrs, feature(doc_cfg))] -#![feature(trait_alias)] - -pub mod hal; -pub mod implementation; - -pub mod doc { - #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/backend_safety_contract.md"))] - pub mod backend_safety { - pub const _PLACEHOLDER: () = (); - } -} - -pub const GALOISGENERATOR: u64 = 5; -pub const DEFAULTALIGN: usize = 64; - -fn is_aligned_custom(ptr: *const T, align: usize) -> bool { - (ptr as usize).is_multiple_of(align) -} - -pub fn is_aligned(ptr: *const T) -> bool { - is_aligned_custom(ptr, DEFAULTALIGN) -} - -pub fn assert_alignement(ptr: *const T) { - assert!( - is_aligned(ptr), - "invalid alignement: ensure passed bytes have been allocated with [alloc_aligned_u8] or [alloc_aligned]" - ) -} - -pub fn cast(data: &[T]) -> &[V] { - let ptr: *const V = data.as_ptr() as *const V; - let len: usize = data.len() / size_of::(); - unsafe { std::slice::from_raw_parts(ptr, len) } -} - -#[allow(clippy::mut_from_ref)] -pub fn cast_mut(data: &[T]) -> &mut [V] { - let ptr: *mut V = data.as_ptr() as *mut V; - let len: usize = data.len() / size_of::(); - unsafe { std::slice::from_raw_parts_mut(ptr, len) } -} - -/// Allocates a block of bytes with a custom alignement. -/// Alignement must be a power of two and size a multiple of the alignement. -/// Allocated memory is initialized to zero. -fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { - assert!( - align.is_power_of_two(), - "Alignment must be a power of two but is {}", - align - ); - assert_eq!( - (size * size_of::()) % align, - 0, - "size={} must be a multiple of align={}", - size, - align - ); - unsafe { - let layout: std::alloc::Layout = std::alloc::Layout::from_size_align(size, align).expect("Invalid alignment"); - let ptr: *mut u8 = std::alloc::alloc(layout); - if ptr.is_null() { - panic!("Memory allocation failed"); - } - assert!( - is_aligned_custom(ptr, align), - "Memory allocation at {:p} is not aligned to {} bytes", - ptr, - align - ); - // Init allocated memory to zero - std::ptr::write_bytes(ptr, 0, size); - Vec::from_raw_parts(ptr, size, size) - } -} - -/// Allocates a block of T aligned with [DEFAULTALIGN]. -/// Size of T * size msut be a multiple of [DEFAULTALIGN]. 
-pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { - assert_eq!( - (size * size_of::()) % (align / size_of::()), - 0, - "size={} must be a multiple of align={}", - size, - align - ); - let mut vec_u8: Vec = alloc_aligned_custom_u8(size_of::() * size, align); - let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T; - let len: usize = vec_u8.len() / size_of::(); - let cap: usize = vec_u8.capacity() / size_of::(); - std::mem::forget(vec_u8); - unsafe { Vec::from_raw_parts(ptr, len, cap) } -} - -/// Allocates an aligned vector of size equal to the smallest multiple -/// of [DEFAULTALIGN]/`size_of::`() that is equal or greater to `size`. -pub fn alloc_aligned(size: usize) -> Vec { - alloc_aligned_custom::( - size + (DEFAULTALIGN - (size % (DEFAULTALIGN / size_of::()))) % DEFAULTALIGN, - DEFAULTALIGN, - ) -} +pub mod cpu_spqlios; diff --git a/poulpy-core/Cargo.toml b/poulpy-core/Cargo.toml index c0297d1..bf871ad 100644 --- a/poulpy-core/Cargo.toml +++ b/poulpy-core/Cargo.toml @@ -1,9 +1,9 @@ [package] name = "poulpy-core" -version = "0.1.0" +version = "0.1.1" edition = "2024" license = "Apache-2.0" -description = "A crate implementing RLWE-based encrypted arithmetic" +description = "A backend agnostic crate implementing RLWE-based encryption & arithmetic." repository = "https://github.com/phantomzone-org/poulpy" homepage = "https://github.com/phantomzone-org/poulpy" documentation = "https://docs.rs/poulpy" @@ -11,7 +11,8 @@ documentation = "https://docs.rs/poulpy" [dependencies] rug = {workspace = true} criterion = {workspace = true} -poulpy-backend = "0.1.0" +poulpy-hal = "0.1.2" +poulpy-backend = "0.1.2" itertools = {workspace = true} byteorder = {workspace = true} diff --git a/poulpy-core/benches/external_product_glwe_fft64.rs b/poulpy-core/benches/external_product_glwe_fft64.rs index 121e748..2adc110 100644 --- a/poulpy-core/benches/external_product_glwe_fft64.rs +++ b/poulpy-core/benches/external_product_glwe_fft64.rs @@ -5,13 +5,12 @@ use poulpy_core::layouts::{ use std::hint::black_box; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; -use poulpy_backend::{ - hal::{ - api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow}, - layouts::{Module, ScalarZnx, ScratchOwned}, - source::Source, - }, - implementation::cpu_spqlios::FFT64, + +use poulpy_backend::cpu_spqlios::FFT64; +use poulpy_hal::{ + api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Module, ScalarZnx, ScratchOwned}, + source::Source, }; fn bench_external_product_glwe_fft64(c: &mut Criterion) { diff --git a/poulpy-core/benches/keyswitch_glwe_fft64.rs b/poulpy-core/benches/keyswitch_glwe_fft64.rs index e731c33..97168d8 100644 --- a/poulpy-core/benches/keyswitch_glwe_fft64.rs +++ b/poulpy-core/benches/keyswitch_glwe_fft64.rs @@ -5,13 +5,11 @@ use poulpy_core::layouts::{ use std::{hint::black_box, time::Duration}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; -use poulpy_backend::{ - hal::{ - api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow}, - layouts::{Module, ScratchOwned}, - source::Source, - }, - implementation::cpu_spqlios::FFT64, +use poulpy_backend::cpu_spqlios::FFT64; +use poulpy_hal::{ + api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Module, ScratchOwned}, + source::Source, }; fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { diff --git a/poulpy-core/src/automorphism/gglwe_atk.rs b/poulpy-core/src/automorphism/gglwe_atk.rs index 1243738..21934c1 100644 --- a/poulpy-core/src/automorphism/gglwe_atk.rs +++ 
b/poulpy-core/src/automorphism/gglwe_atk.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, diff --git a/poulpy-core/src/automorphism/ggsw_ct.rs b/poulpy-core/src/automorphism/ggsw_ct.rs index 355b5c0..7cd1a63 100644 --- a/poulpy-core/src/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/automorphism/ggsw_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftCopy, diff --git a/poulpy-core/src/automorphism/glwe_ct.rs b/poulpy-core/src/automorphism/glwe_ct.rs index f4781d5..1b4c344 100644 --- a/poulpy-core/src/automorphism/glwe_ct.rs +++ b/poulpy-core/src/automorphism/glwe_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallAInplace, VecZnxBigSubSmallBInplace, diff --git a/poulpy-core/src/conversion/glwe_to_lwe.rs b/poulpy-core/src/conversion/glwe_to_lwe.rs index 557698c..c99afda 100644 --- a/poulpy-core/src/conversion/glwe_to_lwe.rs +++ b/poulpy-core/src/conversion/glwe_to_lwe.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, ZnxView, diff --git a/poulpy-core/src/conversion/lwe_to_glwe.rs b/poulpy-core/src/conversion/lwe_to_glwe.rs index 7e4cc7b..83d8b99 100644 --- a/poulpy-core/src/conversion/lwe_to_glwe.rs +++ b/poulpy-core/src/conversion/lwe_to_glwe.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, ZnxView, diff --git a/poulpy-core/src/decryption/glwe_ct.rs b/poulpy-core/src/decryption/glwe_ct.rs index 89eb894..2d9ee17 100644 --- a/poulpy-core/src/decryption/glwe_ct.rs +++ b/poulpy-core/src/decryption/glwe_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ DataViewMut, SvpApplyInplace, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxNormalizeTmpBytes, diff --git a/poulpy-core/src/decryption/lwe_ct.rs b/poulpy-core/src/decryption/lwe_ct.rs index d257726..50fb4d7 100644 --- a/poulpy-core/src/decryption/lwe_ct.rs +++ b/poulpy-core/src/decryption/lwe_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalizeInplace, ZnxView, ZnxViewMut}, layouts::{Backend, DataMut, DataRef, Module, ScratchOwned}, oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, diff --git a/poulpy-core/src/encryption/compressed/gglwe_atk.rs b/poulpy-core/src/encryption/compressed/gglwe_atk.rs index 14f6dc5..41031de 100644 --- 
a/poulpy-core/src/encryption/compressed/gglwe_atk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_atk.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftAllocBytes, diff --git a/poulpy-core/src/encryption/compressed/gglwe_ct.rs b/poulpy-core/src/encryption/compressed/gglwe_ct.rs index 4166497..20a2440 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_ct.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxFillUniform, @@ -53,7 +53,7 @@ impl GGLWECiphertextCompressed { { #[cfg(debug_assertions)] { - use poulpy_backend::hal::api::ZnxInfos; + use poulpy_hal::api::ZnxInfos; assert_eq!( self.rank_in(), diff --git a/poulpy-core/src/encryption/compressed/gglwe_ksk.rs b/poulpy-core/src/encryption/compressed/gglwe_ksk.rs index 971e9e4..403ad89 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_ksk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_ksk.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, diff --git a/poulpy-core/src/encryption/compressed/gglwe_tsk.rs b/poulpy-core/src/encryption/compressed/gglwe_tsk.rs index f2373aa..9ae1621 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_tsk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_tsk.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes, diff --git a/poulpy-core/src/encryption/compressed/ggsw_ct.rs b/poulpy-core/src/encryption/compressed/ggsw_ct.rs index bec3abf..01ec988 100644 --- a/poulpy-core/src/encryption/compressed/ggsw_ct.rs +++ b/poulpy-core/src/encryption/compressed/ggsw_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxFillUniform, @@ -53,7 +53,7 @@ impl GGSWCiphertextCompressed { { #[cfg(debug_assertions)] { - use poulpy_backend::hal::api::ZnxInfos; + use poulpy_hal::api::ZnxInfos; assert_eq!(self.rank(), sk.rank()); assert_eq!(self.n(), sk.n()); diff --git a/poulpy-core/src/encryption/compressed/glwe_ct.rs b/poulpy-core/src/encryption/compressed/glwe_ct.rs index bfaa081..9c34692 100644 --- a/poulpy-core/src/encryption/compressed/glwe_ct.rs +++ b/poulpy-core/src/encryption/compressed/glwe_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, 
VecZnxFillUniform, VecZnxNormalize, diff --git a/poulpy-core/src/encryption/gglwe_atk.rs b/poulpy-core/src/encryption/gglwe_atk.rs index bfb9da2..a3ffa72 100644 --- a/poulpy-core/src/encryption/gglwe_atk.rs +++ b/poulpy-core/src/encryption/gglwe_atk.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftAllocBytes, diff --git a/poulpy-core/src/encryption/gglwe_ct.rs b/poulpy-core/src/encryption/gglwe_ct.rs index 386b938..066bf75 100644 --- a/poulpy-core/src/encryption/gglwe_ct.rs +++ b/poulpy-core/src/encryption/gglwe_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxFillUniform, @@ -57,7 +57,7 @@ impl GGLWECiphertext { { #[cfg(debug_assertions)] { - use poulpy_backend::hal::api::ZnxInfos; + use poulpy_hal::api::ZnxInfos; assert_eq!( self.rank_in(), diff --git a/poulpy-core/src/encryption/gglwe_ksk.rs b/poulpy-core/src/encryption/gglwe_ksk.rs index 5230d99..158c947 100644 --- a/poulpy-core/src/encryption/gglwe_ksk.rs +++ b/poulpy-core/src/encryption/gglwe_ksk.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, diff --git a/poulpy-core/src/encryption/gglwe_tsk.rs b/poulpy-core/src/encryption/gglwe_tsk.rs index cb794dc..8190ccb 100644 --- a/poulpy-core/src/encryption/gglwe_tsk.rs +++ b/poulpy-core/src/encryption/gglwe_tsk.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, diff --git a/poulpy-core/src/encryption/ggsw_ct.rs b/poulpy-core/src/encryption/ggsw_ct.rs index 9d2b4dd..ee228cf 100644 --- a/poulpy-core/src/encryption/ggsw_ct.rs +++ b/poulpy-core/src/encryption/ggsw_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxFillUniform, @@ -56,7 +56,7 @@ impl GGSWCiphertext { { #[cfg(debug_assertions)] { - use poulpy_backend::hal::api::ZnxInfos; + use poulpy_hal::api::ZnxInfos; assert_eq!(self.rank(), sk.rank()); assert_eq!(self.n(), sk.n()); diff --git a/poulpy-core/src/encryption/glwe_ct.rs b/poulpy-core/src/encryption/glwe_ct.rs index 3b5cf12..f5faf57 100644 --- a/poulpy-core/src/encryption/glwe_ct.rs +++ b/poulpy-core/src/encryption/glwe_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSvpPPol, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddNormal, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, @@ -356,7 +356,7 @@ pub(crate) fn 
glwe_encrypt_sk_internal GLWECiphertext { #[cfg(debug_assertions)] { - use poulpy_backend::hal::api::ScratchAvailable; + use poulpy_hal::api::ScratchAvailable; assert_eq!(rhs.rank(), lhs.rank()); assert_eq!(rhs.rank(), self.rank()); diff --git a/poulpy-core/src/glwe_packing.rs b/poulpy-core/src/glwe_packing.rs index 488f54b..e93ce11 100644 --- a/poulpy-core/src/glwe_packing.rs +++ b/poulpy-core/src/glwe_packing.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy, diff --git a/poulpy-core/src/glwe_trace.rs b/poulpy-core/src/glwe_trace.rs index 6a34045..c06f53c 100644 --- a/poulpy-core/src/glwe_trace.rs +++ b/poulpy-core/src/glwe_trace.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, diff --git a/poulpy-core/src/keyswitching/gglwe_ct.rs b/poulpy-core/src/keyswitching/gglwe_ct.rs index 092967f..bf9bc70 100644 --- a/poulpy-core/src/keyswitching/gglwe_ct.rs +++ b/poulpy-core/src/keyswitching/gglwe_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, ZnxZero, diff --git a/poulpy-core/src/keyswitching/ggsw_ct.rs b/poulpy-core/src/keyswitching/ggsw_ct.rs index 7fd40a7..1d8a4c0 100644 --- a/poulpy-core/src/keyswitching/ggsw_ct.rs +++ b/poulpy-core/src/keyswitching/ggsw_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxDftFromVecZnx, diff --git a/poulpy-core/src/keyswitching/glwe_ct.rs b/poulpy-core/src/keyswitching/glwe_ct.rs index 4395128..5dd402b 100644 --- a/poulpy-core/src/keyswitching/glwe_ct.rs +++ b/poulpy-core/src/keyswitching/glwe_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ DataViewMut, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, ZnxInfos, diff --git a/poulpy-core/src/keyswitching/lwe_ct.rs b/poulpy-core/src/keyswitching/lwe_ct.rs index 5daa9ce..d0aca35 100644 --- a/poulpy-core/src/keyswitching/lwe_ct.rs +++ b/poulpy-core/src/keyswitching/lwe_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, ZnxView, diff --git a/poulpy-core/src/layouts/compressed/gglwe_atk.rs b/poulpy-core/src/layouts/compressed/gglwe_atk.rs index 819c150..27c32cc 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_atk.rs +++ 
b/poulpy-core/src/layouts/compressed/gglwe_atk.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{FillUniform, Reset, VecZnxCopy, VecZnxFillUniform}, layouts::{Backend, Data, DataMut, DataRef, MatZnx, Module, ReaderFrom, WriterTo}, source::Source, diff --git a/poulpy-core/src/layouts/compressed/gglwe_ct.rs b/poulpy-core/src/layouts/compressed/gglwe_ct.rs index 8265350..2216937 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_ct.rs +++ b/poulpy-core/src/layouts/compressed/gglwe_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{FillUniform, Reset, VecZnxCopy, VecZnxFillUniform}, layouts::{Backend, Data, DataMut, DataRef, MatZnx, Module, ReaderFrom, WriterTo}, source::Source, @@ -201,7 +201,7 @@ impl Decompress Decompress Decompress PrepareAlloc, B>> for where Module: SvpPrepare + SvpPPolAlloc, { - fn prepare_alloc( - &self, - module: &Module, - scratch: &mut poulpy_backend::hal::layouts::Scratch, - ) -> GLWESecretPrepared, B> { + fn prepare_alloc(&self, module: &Module, scratch: &mut poulpy_hal::layouts::Scratch) -> GLWESecretPrepared, B> { let mut sk_dft: GLWESecretPrepared, B> = GLWESecretPrepared::alloc(module, self.n(), self.rank()); sk_dft.prepare(module, self, scratch); sk_dft @@ -68,7 +64,7 @@ impl Prepare> for GLWESe where Module: SvpPrepare, { - fn prepare(&mut self, module: &Module, other: &GLWESecret, _scratch: &mut poulpy_backend::hal::layouts::Scratch) { + fn prepare(&mut self, module: &Module, other: &GLWESecret, _scratch: &mut poulpy_hal::layouts::Scratch) { (0..self.rank()).for_each(|i| { module.svp_prepare(&mut self.data, i, &other.data, i); }); diff --git a/poulpy-core/src/layouts/prepared/glwe_to_lwe_ksk.rs b/poulpy-core/src/layouts/prepared/glwe_to_lwe_ksk.rs index e3fc88d..bce3954 100644 --- a/poulpy-core/src/layouts/prepared/glwe_to_lwe_ksk.rs +++ b/poulpy-core/src/layouts/prepared/glwe_to_lwe_ksk.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat}, }; diff --git a/poulpy-core/src/layouts/prepared/lwe_ksk.rs b/poulpy-core/src/layouts/prepared/lwe_ksk.rs index 522c713..3177671 100644 --- a/poulpy-core/src/layouts/prepared/lwe_ksk.rs +++ b/poulpy-core/src/layouts/prepared/lwe_ksk.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat}, }; diff --git a/poulpy-core/src/layouts/prepared/lwe_to_glwe_ksk.rs b/poulpy-core/src/layouts/prepared/lwe_to_glwe_ksk.rs index e7692a1..806fa7f 100644 --- a/poulpy-core/src/layouts/prepared/lwe_to_glwe_ksk.rs +++ b/poulpy-core/src/layouts/prepared/lwe_to_glwe_ksk.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat}, }; diff --git a/poulpy-core/src/layouts/prepared/mod.rs b/poulpy-core/src/layouts/prepared/mod.rs index 240c11f..9da8f27 100644 --- a/poulpy-core/src/layouts/prepared/mod.rs +++ b/poulpy-core/src/layouts/prepared/mod.rs @@ -19,7 +19,7 @@ pub use glwe_sk::*; pub use glwe_to_lwe_ksk::*; pub use lwe_ksk::*; pub use lwe_to_glwe_ksk::*; -use poulpy_backend::hal::layouts::{Backend, Module, Scratch}; +use poulpy_hal::layouts::{Backend, Module, Scratch}; pub trait PrepareAlloc { fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> T; diff --git a/poulpy-core/src/noise/gglwe_ct.rs 
b/poulpy-core/src/noise/gglwe_ct.rs index 572c409..6fe8af2 100644 --- a/poulpy-core/src/noise/gglwe_ct.rs +++ b/poulpy-core/src/noise/gglwe_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, diff --git a/poulpy-core/src/noise/ggsw_ct.rs b/poulpy-core/src/noise/ggsw_ct.rs index 0405dc1..4b5bfb8 100644 --- a/poulpy-core/src/noise/ggsw_ct.rs +++ b/poulpy-core/src/noise/ggsw_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, diff --git a/poulpy-core/src/noise/glwe_ct.rs b/poulpy-core/src/noise/glwe_ct.rs index d675e22..6f9b0e3 100644 --- a/poulpy-core/src/noise/glwe_ct.rs +++ b/poulpy-core/src/noise/glwe_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, diff --git a/poulpy-core/src/operations/glwe.rs b/poulpy-core/src/operations/glwe.rs index 8df1d19..3b45636 100644 --- a/poulpy-core/src/operations/glwe.rs +++ b/poulpy-core/src/operations/glwe.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, diff --git a/poulpy-core/src/scratch.rs b/poulpy-core/src/scratch.rs index f21c080..cc1dda8 100644 --- a/poulpy-core/src/scratch.rs +++ b/poulpy-core/src/scratch.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{TakeMatZnx, TakeScalarZnx, TakeSvpPPol, TakeVecZnx, TakeVecZnxDft, TakeVmpPMat}, layouts::{Backend, DataRef, Scratch}, oep::{TakeMatZnxImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, TakeVmpPMatImpl}, diff --git a/poulpy-core/src/tests/generics/automorphism/gglwe_atk.rs b/poulpy-core/src/tests/generics/automorphism/gglwe_atk.rs index 43022c4..bcfd261 100644 --- a/poulpy-core/src/tests/generics/automorphism/gglwe_atk.rs +++ b/poulpy-core/src/tests/generics/automorphism/gglwe_atk.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddInplace, diff --git a/poulpy-core/src/tests/generics/automorphism/ggsw_ct.rs b/poulpy-core/src/tests/generics/automorphism/ggsw_ct.rs index e8aeca5..3498e81 100644 --- a/poulpy-core/src/tests/generics/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/tests/generics/automorphism/ggsw_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, diff --git a/poulpy-core/src/tests/generics/automorphism/glwe_ct.rs 
b/poulpy-core/src/tests/generics/automorphism/glwe_ct.rs index 30367c3..000ac2b 100644 --- a/poulpy-core/src/tests/generics/automorphism/glwe_ct.rs +++ b/poulpy-core/src/tests/generics/automorphism/glwe_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddInplace, diff --git a/poulpy-core/src/tests/generics/conversion.rs b/poulpy-core/src/tests/generics/conversion.rs index beca378..90b1cfc 100644 --- a/poulpy-core/src/tests/generics/conversion.rs +++ b/poulpy-core/src/tests/generics/conversion.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, diff --git a/poulpy-core/src/tests/generics/encryption/gglwe_atk.rs b/poulpy-core/src/tests/generics/encryption/gglwe_atk.rs index 8e494cf..a28d083 100644 --- a/poulpy-core/src/tests/generics/encryption/gglwe_atk.rs +++ b/poulpy-core/src/tests/generics/encryption/gglwe_atk.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddInplace, diff --git a/poulpy-core/src/tests/generics/encryption/gglwe_ct.rs b/poulpy-core/src/tests/generics/encryption/gglwe_ct.rs index 56c3e67..0ce46f9 100644 --- a/poulpy-core/src/tests/generics/encryption/gglwe_ct.rs +++ b/poulpy-core/src/tests/generics/encryption/gglwe_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, diff --git a/poulpy-core/src/tests/generics/encryption/ggsw_ct.rs b/poulpy-core/src/tests/generics/encryption/ggsw_ct.rs index 66b4abc..60271b1 100644 --- a/poulpy-core/src/tests/generics/encryption/ggsw_ct.rs +++ b/poulpy-core/src/tests/generics/encryption/ggsw_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, diff --git a/poulpy-core/src/tests/generics/encryption/glwe_ct.rs b/poulpy-core/src/tests/generics/encryption/glwe_ct.rs index 71de183..8d679e9 100644 --- a/poulpy-core/src/tests/generics/encryption/glwe_ct.rs +++ b/poulpy-core/src/tests/generics/encryption/glwe_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmallInplace, diff --git a/poulpy-core/src/tests/generics/encryption/glwe_tsk.rs b/poulpy-core/src/tests/generics/encryption/glwe_tsk.rs index 9cf4592..dc83ca8 100644 --- a/poulpy-core/src/tests/generics/encryption/glwe_tsk.rs +++ 
b/poulpy-core/src/tests/generics/encryption/glwe_tsk.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, diff --git a/poulpy-core/src/tests/generics/external_product/gglwe_ksk.rs b/poulpy-core/src/tests/generics/external_product/gglwe_ksk.rs index 3c07b27..2641893 100644 --- a/poulpy-core/src/tests/generics/external_product/gglwe_ksk.rs +++ b/poulpy-core/src/tests/generics/external_product/gglwe_ksk.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, diff --git a/poulpy-core/src/tests/generics/external_product/ggsw_ct.rs b/poulpy-core/src/tests/generics/external_product/ggsw_ct.rs index 8ae50e7..30ba710 100644 --- a/poulpy-core/src/tests/generics/external_product/ggsw_ct.rs +++ b/poulpy-core/src/tests/generics/external_product/ggsw_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, diff --git a/poulpy-core/src/tests/generics/external_product/glwe_ct.rs b/poulpy-core/src/tests/generics/external_product/glwe_ct.rs index d99a2d9..c83bc17 100644 --- a/poulpy-core/src/tests/generics/external_product/glwe_ct.rs +++ b/poulpy-core/src/tests/generics/external_product/glwe_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, diff --git a/poulpy-core/src/tests/generics/keyswitch/gglwe_ct.rs b/poulpy-core/src/tests/generics/keyswitch/gglwe_ct.rs index f854d0d..fbb57bb 100644 --- a/poulpy-core/src/tests/generics/keyswitch/gglwe_ct.rs +++ b/poulpy-core/src/tests/generics/keyswitch/gglwe_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, diff --git a/poulpy-core/src/tests/generics/keyswitch/ggsw_ct.rs b/poulpy-core/src/tests/generics/keyswitch/ggsw_ct.rs index 92b1eb9..9315e06 100644 --- a/poulpy-core/src/tests/generics/keyswitch/ggsw_ct.rs +++ b/poulpy-core/src/tests/generics/keyswitch/ggsw_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, diff --git a/poulpy-core/src/tests/generics/keyswitch/glwe_ct.rs b/poulpy-core/src/tests/generics/keyswitch/glwe_ct.rs index 1daf445..5f648fd 100644 --- a/poulpy-core/src/tests/generics/keyswitch/glwe_ct.rs +++ b/poulpy-core/src/tests/generics/keyswitch/glwe_ct.rs @@ -1,4 +1,4 @@ 
-use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, diff --git a/poulpy-core/src/tests/generics/keyswitch/lwe_ct.rs b/poulpy-core/src/tests/generics/keyswitch/lwe_ct.rs index fd028a8..da469a9 100644 --- a/poulpy-core/src/tests/generics/keyswitch/lwe_ct.rs +++ b/poulpy-core/src/tests/generics/keyswitch/lwe_ct.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, diff --git a/poulpy-core/src/tests/generics/packing.rs b/poulpy-core/src/tests/generics/packing.rs index fb9f834..27921a9 100644 --- a/poulpy-core/src/tests/generics/packing.rs +++ b/poulpy-core/src/tests/generics/packing.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddInplace, diff --git a/poulpy-core/src/tests/generics/trace.rs b/poulpy-core/src/tests/generics/trace.rs index 84f369f..aba261e 100644 --- a/poulpy-core/src/tests/generics/trace.rs +++ b/poulpy-core/src/tests/generics/trace.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, diff --git a/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/gglwe.rs b/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/gglwe.rs index 2397b0b..ec250e6 100644 --- a/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/gglwe.rs +++ b/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/gglwe.rs @@ -1,7 +1,5 @@ -use poulpy_backend::{ - hal::{api::ModuleNew, layouts::Module}, - implementation::cpu_spqlios::FFT64, -}; +use poulpy_backend::cpu_spqlios::FFT64; +use poulpy_hal::{api::ModuleNew, layouts::Module}; use crate::tests::generics::{ automorphism::{test_gglwe_automorphism_key_automorphism, test_gglwe_automorphism_key_automorphism_inplace}, @@ -17,7 +15,7 @@ use crate::tests::generics::{ #[test] fn gglwe_switching_key_encrypt_sk() { let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); + let module = Module::::new(1 << log_n); let basek: usize = 12; let k_ksk: usize = 54; let digits: usize = k_ksk / basek; diff --git a/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/ggws.rs b/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/ggws.rs index fa62380..ab95ea3 100644 --- a/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/ggws.rs +++ b/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/ggws.rs @@ -1,7 +1,5 @@ -use poulpy_backend::{ - hal::{api::ModuleNew, layouts::Module}, - implementation::cpu_spqlios::FFT64, -}; +use poulpy_backend::cpu_spqlios::FFT64; +use poulpy_hal::{api::ModuleNew, layouts::Module}; use crate::tests::generics::{ automorphism::{test_ggsw_automorphism, test_ggsw_automorphism_inplace}, diff --git 
a/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/glwe.rs b/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/glwe.rs index 95b7717..f0f3b6c 100644 --- a/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/glwe.rs +++ b/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/glwe.rs @@ -1,7 +1,5 @@ -use poulpy_backend::{ - hal::{api::ModuleNew, layouts::Module}, - implementation::cpu_spqlios::FFT64, -}; +use poulpy_backend::cpu_spqlios::FFT64; +use poulpy_hal::{api::ModuleNew, layouts::Module}; use crate::tests::generics::{ automorphism::{test_glwe_automorphism, test_glwe_automorphism_inplace}, diff --git a/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/lwe.rs b/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/lwe.rs index c8e0a7f..5213d0c 100644 --- a/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/lwe.rs +++ b/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/lwe.rs @@ -1,9 +1,6 @@ -use poulpy_backend::{ - hal::{api::ModuleNew, layouts::Module}, - implementation::cpu_spqlios::FFT64, -}; - use crate::tests::generics::{keyswitch::test_lwe_keyswitch, test_glwe_to_lwe, test_lwe_to_glwe}; +use poulpy_backend::cpu_spqlios::FFT64; +use poulpy_hal::{api::ModuleNew, layouts::Module}; #[test] fn lwe_to_glwe() { diff --git a/poulpy-core/src/tests/serialization.rs b/poulpy-core/src/tests/serialization.rs index 8c4a359..14f8177 100644 --- a/poulpy-core/src/tests/serialization.rs +++ b/poulpy-core/src/tests/serialization.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::tests::serialization::test_reader_writer_interface; +use poulpy_hal::tests::serialization::test_reader_writer_interface; use crate::layouts::{ GGLWEAutomorphismKey, GGLWECiphertext, GGLWESwitchingKey, GGLWETensorKey, GGSWCiphertext, GLWECiphertext, @@ -21,7 +21,7 @@ const DIGITS: usize = 1; #[test] fn glwe_serialization() { let original: GLWECiphertext> = GLWECiphertext::alloc(N_GLWE, BASEK, K, RANK); - poulpy_backend::hal::tests::serialization::test_reader_writer_interface(original); + poulpy_hal::tests::serialization::test_reader_writer_interface(original); } #[test] diff --git a/poulpy-core/src/utils.rs b/poulpy-core/src/utils.rs index 883b57d..3d54753 100644 --- a/poulpy-core/src/utils.rs +++ b/poulpy-core/src/utils.rs @@ -1,5 +1,5 @@ use crate::layouts::{GLWEPlaintext, Infos, LWEPlaintext}; -use poulpy_backend::hal::layouts::{DataMut, DataRef}; +use poulpy_hal::layouts::{DataMut, DataRef}; use rug::Float; impl GLWEPlaintext { diff --git a/poulpy-hal/Cargo.toml b/poulpy-hal/Cargo.toml new file mode 100644 index 0000000..07d01f3 --- /dev/null +++ b/poulpy-hal/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "poulpy-hal" +version = "0.1.2" +edition = "2024" +license = "Apache-2.0" +readme = "README.md" +description = "A crate providing layouts and a trait-based hardware acceleration layer with open extension points, matching the API and types of spqlios-arithmetic." 
+repository = "https://github.com/phantomzone-org/poulpy" +homepage = "https://github.com/phantomzone-org/poulpy" +documentation = "https://docs.rs/poulpy" + +[dependencies] +rug = {workspace = true} +criterion = {workspace = true} +itertools = {workspace = true} +rand = {workspace = true} +rand_distr = {workspace = true} +rand_core = {workspace = true} +byteorder = {workspace = true} +rand_chacha = "0.9.0" + +[build-dependencies] +cmake = "0.1.54" + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] \ No newline at end of file diff --git a/poulpy-hal/README.md b/poulpy-hal/README.md new file mode 100644 index 0000000..e69de29 diff --git a/poulpy-backend/docs/backend_safety_contract.md b/poulpy-hal/docs/backend_safety_contract.md similarity index 100% rename from poulpy-backend/docs/backend_safety_contract.md rename to poulpy-hal/docs/backend_safety_contract.md diff --git a/poulpy-backend/src/hal/api/mod.rs b/poulpy-hal/src/api/mod.rs similarity index 100% rename from poulpy-backend/src/hal/api/mod.rs rename to poulpy-hal/src/api/mod.rs diff --git a/poulpy-hal/src/api/module.rs b/poulpy-hal/src/api/module.rs new file mode 100644 index 0000000..6e5faed --- /dev/null +++ b/poulpy-hal/src/api/module.rs @@ -0,0 +1,6 @@ +use crate::layouts::Backend; + +/// Instantiate a new [crate::layouts::Module]. +pub trait ModuleNew { + fn new(n: u64) -> Self; +} diff --git a/poulpy-backend/src/hal/api/scratch.rs b/poulpy-hal/src/api/scratch.rs similarity index 95% rename from poulpy-backend/src/hal/api/scratch.rs rename to poulpy-hal/src/api/scratch.rs index 812ed30..8ef66eb 100644 --- a/poulpy-backend/src/hal/api/scratch.rs +++ b/poulpy-hal/src/api/scratch.rs @@ -1,6 +1,6 @@ -use crate::hal::layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}; +use crate::layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}; -/// Allocates a new [crate::hal::layouts::ScratchOwned] of `size` aligned bytes. +/// Allocates a new [crate::layouts::ScratchOwned] of `size` aligned bytes. pub trait ScratchOwnedAlloc { fn alloc(size: usize) -> Self; } diff --git a/poulpy-backend/src/hal/api/svp_ppol.rs b/poulpy-hal/src/api/svp_ppol.rs similarity index 76% rename from poulpy-backend/src/hal/api/svp_ppol.rs rename to poulpy-hal/src/api/svp_ppol.rs index 05aa189..f8a1cfc 100644 --- a/poulpy-backend/src/hal/api/svp_ppol.rs +++ b/poulpy-hal/src/api/svp_ppol.rs @@ -1,22 +1,22 @@ -use crate::hal::layouts::{Backend, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef}; +use crate::layouts::{Backend, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef}; -/// Allocates as [crate::hal::layouts::SvpPPol]. +/// Allocates as [crate::layouts::SvpPPol]. pub trait SvpPPolAlloc { fn svp_ppol_alloc(&self, n: usize, cols: usize) -> SvpPPolOwned; } -/// Returns the size in bytes to allocate a [crate::hal::layouts::SvpPPol]. +/// Returns the size in bytes to allocate a [crate::layouts::SvpPPol]. pub trait SvpPPolAllocBytes { fn svp_ppol_alloc_bytes(&self, n: usize, cols: usize) -> usize; } -/// Consume a vector of bytes into a [crate::hal::layouts::MatZnx]. +/// Consume a vector of bytes into a [crate::layouts::MatZnx]. /// User must ensure that bytes is memory aligned and that it length is equal to [SvpPPolAllocBytes]. 
pub trait SvpPPolFromBytes { fn svp_ppol_from_bytes(&self, n: usize, cols: usize, bytes: Vec) -> SvpPPolOwned; } -/// Prepare a [crate::hal::layouts::ScalarZnx] into an [crate::hal::layouts::SvpPPol]. +/// Prepare a [crate::layouts::ScalarZnx] into an [crate::layouts::SvpPPol]. pub trait SvpPrepare { fn svp_prepare(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where diff --git a/poulpy-backend/src/hal/api/vec_znx.rs b/poulpy-hal/src/api/vec_znx.rs similarity index 97% rename from poulpy-backend/src/hal/api/vec_znx.rs rename to poulpy-hal/src/api/vec_znx.rs index cc1525f..d5ab4bc 100644 --- a/poulpy-backend/src/hal/api/vec_znx.rs +++ b/poulpy-hal/src/api/vec_znx.rs @@ -1,6 +1,6 @@ use rand_distr::Distribution; -use crate::hal::{ +use crate::{ layouts::{Backend, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef}, source::Source, }; @@ -164,7 +164,7 @@ pub trait VecZnxSplit { /// /// # Panics /// - /// This method requires that all [crate::hal::layouts::VecZnx] of b have the same ring degree + /// This method requires that all [crate::layouts::VecZnx] of b have the same ring degree /// and that b.n() * b.len() <= a.n() fn vec_znx_split(&self, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) where @@ -177,7 +177,7 @@ pub trait VecZnxMerge { /// /// # Panics /// - /// This method requires that all [crate::hal::layouts::VecZnx] of a have the same ring degree + /// This method requires that all [crate::layouts::VecZnx] of a have the same ring degree /// and that a.n() * a.len() <= b.n() fn vec_znx_merge(&self, res: &mut R, res_col: usize, a: &[A], a_col: usize) where diff --git a/poulpy-backend/src/hal/api/vec_znx_big.rs b/poulpy-hal/src/api/vec_znx_big.rs similarity index 96% rename from poulpy-backend/src/hal/api/vec_znx_big.rs rename to poulpy-hal/src/api/vec_znx_big.rs index 2f6ee4c..11b52b2 100644 --- a/poulpy-backend/src/hal/api/vec_znx_big.rs +++ b/poulpy-hal/src/api/vec_znx_big.rs @@ -1,21 +1,21 @@ use rand_distr::Distribution; -use crate::hal::{ +use crate::{ layouts::{Backend, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef}, source::Source, }; -/// Allocates as [crate::hal::layouts::VecZnxBig]. +/// Allocates as [crate::layouts::VecZnxBig]. pub trait VecZnxBigAlloc { fn vec_znx_big_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxBigOwned; } -/// Returns the size in bytes to allocate a [crate::hal::layouts::VecZnxBig]. +/// Returns the size in bytes to allocate a [crate::layouts::VecZnxBig]. pub trait VecZnxBigAllocBytes { fn vec_znx_big_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize; } -/// Consume a vector of bytes into a [crate::hal::layouts::VecZnxBig]. +/// Consume a vector of bytes into a [crate::layouts::VecZnxBig]. /// User must ensure that bytes is memory aligned and that it length is equal to [VecZnxBigAllocBytes]. 
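// Illustrative sketch (not part of the patch): the prepare step these SVP traits support,
// written in the same bound style as the poulpy-core hunks (e.g. the GLWESecret prepare
// impl earlier in this diff). The trait type parameters and the DataRef bound are
// assumptions; the exact where-clauses in poulpy-hal may differ.
use poulpy_hal::{
    api::{SvpPPolAlloc, SvpPrepare},
    layouts::{Backend, DataRef, Module, ScalarZnx, SvpPPolOwned},
};

fn prepare_scalar<D: DataRef, B: Backend>(module: &Module<B>, a: &ScalarZnx<D>, n: usize, cols: usize) -> SvpPPolOwned<B>
where
    Module<B>: SvpPPolAlloc<B> + SvpPrepare<B>,
{
    // Allocate the prepared (backend-specific) representation, then fill each column
    // from the corresponding column of the coefficient-domain scalar polynomial.
    let mut ppol = module.svp_ppol_alloc(n, cols);
    (0..cols).for_each(|i| module.svp_prepare(&mut ppol, i, a, i));
    ppol
}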
pub trait VecZnxBigFromBytes { fn vec_znx_big_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned; diff --git a/poulpy-backend/src/hal/api/vec_znx_dft.rs b/poulpy-hal/src/api/vec_znx_dft.rs similarity index 99% rename from poulpy-backend/src/hal/api/vec_znx_dft.rs rename to poulpy-hal/src/api/vec_znx_dft.rs index 8efe892..8fef8d7 100644 --- a/poulpy-backend/src/hal/api/vec_znx_dft.rs +++ b/poulpy-hal/src/api/vec_znx_dft.rs @@ -1,4 +1,4 @@ -use crate::hal::layouts::{ +use crate::layouts::{ Backend, Data, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, }; diff --git a/poulpy-backend/src/hal/api/vmp_pmat.rs b/poulpy-hal/src/api/vmp_pmat.rs similarity index 71% rename from poulpy-backend/src/hal/api/vmp_pmat.rs rename to poulpy-hal/src/api/vmp_pmat.rs index 9f06ea1..7b3b732 100644 --- a/poulpy-backend/src/hal/api/vmp_pmat.rs +++ b/poulpy-hal/src/api/vmp_pmat.rs @@ -1,6 +1,4 @@ -use crate::hal::layouts::{ - Backend, MatZnxToRef, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef, -}; +use crate::layouts::{Backend, MatZnxToRef, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef}; pub trait VmpPMatAlloc { fn vmp_pmat_alloc(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned; @@ -48,14 +46,14 @@ pub trait VmpApplyTmpBytes { } pub trait VmpApply { - /// Applies the vector matrix product [crate::hal::layouts::VecZnxDft] x [crate::hal::layouts::VmpPMat]. + /// Applies the vector matrix product [crate::layouts::VecZnxDft] x [crate::layouts::VmpPMat]. /// - /// A vector matrix product numerically equivalent to a sum of [crate::hal::api::SvpApply], - /// where each [crate::hal::layouts::SvpPPol] is a limb of the input [crate::hal::layouts::VecZnx] in DFT, - /// and each vector a [crate::hal::layouts::VecZnxDft] (row) of the [crate::hal::layouts::VmpPMat]. + /// A vector matrix product numerically equivalent to a sum of [crate::api::SvpApply], + /// where each [crate::layouts::SvpPPol] is a limb of the input [crate::layouts::VecZnx] in DFT, + /// and each vector a [crate::layouts::VecZnxDft] (row) of the [crate::layouts::VmpPMat]. /// - /// As such, given an input [crate::hal::layouts::VecZnx] of `i` size and a [crate::hal::layouts::VmpPMat] of `i` rows and - /// `j` size, the output is a [crate::hal::layouts::VecZnx] of `j` size. + /// As such, given an input [crate::layouts::VecZnx] of `i` size and a [crate::layouts::VmpPMat] of `i` rows and + /// `j` size, the output is a [crate::layouts::VecZnx] of `j` size. /// /// If there is a mismatch between the dimensions the largest valid ones are used. /// @@ -64,13 +62,13 @@ pub trait VmpApply { /// |h i j| /// |k l m| /// ``` - /// where each element is a [crate::hal::layouts::VecZnxDft]. + /// where each element is a [crate::layouts::VecZnxDft]. /// /// # Arguments /// - /// * `c`: the output of the vector matrix product, as a [crate::hal::layouts::VecZnxDft]. - /// * `a`: the left operand [crate::hal::layouts::VecZnxDft] of the vector matrix product. - /// * `b`: the right operand [crate::hal::layouts::VmpPMat] of the vector matrix product. + /// * `c`: the output of the vector matrix product, as a [crate::layouts::VecZnxDft]. + /// * `a`: the left operand [crate::layouts::VecZnxDft] of the vector matrix product. + /// * `b`: the right operand [crate::layouts::VmpPMat] of the vector matrix product. 
/// * `buf`: scratch space, the size can be obtained with [VmpApplyTmpBytes::vmp_apply_tmp_bytes]. fn vmp_apply(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch) where diff --git a/poulpy-backend/src/hal/api/znx_base.rs b/poulpy-hal/src/api/znx_base.rs similarity index 96% rename from poulpy-backend/src/hal/api/znx_base.rs rename to poulpy-hal/src/api/znx_base.rs index bc15acc..deab5ea 100644 --- a/poulpy-backend/src/hal/api/znx_base.rs +++ b/poulpy-hal/src/api/znx_base.rs @@ -1,5 +1,5 @@ -use crate::hal::{ - layouts::{Data, DataMut, DataRef}, +use crate::{ + layouts::{Backend, Data, DataMut, DataRef}, source::Source, }; use rand_distr::num_traits::Zero; @@ -28,6 +28,10 @@ pub trait ZnxInfos { } } +pub trait ZnxSliceSizeImpl { + fn slice_size(&self) -> usize; +} + pub trait ZnxSliceSize { /// Returns the slice size, which is the offset between /// two size of the same column. diff --git a/poulpy-backend/src/hal/delegates/mod.rs b/poulpy-hal/src/delegates/mod.rs similarity index 100% rename from poulpy-backend/src/hal/delegates/mod.rs rename to poulpy-hal/src/delegates/mod.rs diff --git a/poulpy-backend/src/hal/delegates/module.rs b/poulpy-hal/src/delegates/module.rs similarity index 92% rename from poulpy-backend/src/hal/delegates/module.rs rename to poulpy-hal/src/delegates/module.rs index a9a8d24..0e3a455 100644 --- a/poulpy-backend/src/hal/delegates/module.rs +++ b/poulpy-hal/src/delegates/module.rs @@ -1,4 +1,4 @@ -use crate::hal::{ +use crate::{ api::ModuleNew, layouts::{Backend, Module}, oep::ModuleNewImpl, diff --git a/poulpy-backend/src/hal/delegates/scratch.rs b/poulpy-hal/src/delegates/scratch.rs similarity index 99% rename from poulpy-backend/src/hal/delegates/scratch.rs rename to poulpy-hal/src/delegates/scratch.rs index a5f6b58..95c5b92 100644 --- a/poulpy-backend/src/hal/delegates/scratch.rs +++ b/poulpy-hal/src/delegates/scratch.rs @@ -1,4 +1,4 @@ -use crate::hal::{ +use crate::{ api::{ ScratchAvailable, ScratchFromBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, TakeLike, TakeMatZnx, TakeScalarZnx, TakeSlice, TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice, TakeVmpPMat, diff --git a/poulpy-backend/src/hal/delegates/svp_ppol.rs b/poulpy-hal/src/delegates/svp_ppol.rs similarity index 99% rename from poulpy-backend/src/hal/delegates/svp_ppol.rs rename to poulpy-hal/src/delegates/svp_ppol.rs index e47e474..af76dd7 100644 --- a/poulpy-backend/src/hal/delegates/svp_ppol.rs +++ b/poulpy-hal/src/delegates/svp_ppol.rs @@ -1,4 +1,4 @@ -use crate::hal::{ +use crate::{ api::{SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPPolFromBytes, SvpPrepare}, layouts::{Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef}, oep::{SvpApplyImpl, SvpApplyInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl}, diff --git a/poulpy-backend/src/hal/delegates/vec_znx.rs b/poulpy-hal/src/delegates/vec_znx.rs similarity index 99% rename from poulpy-backend/src/hal/delegates/vec_znx.rs rename to poulpy-hal/src/delegates/vec_znx.rs index db53467..4ff3964 100644 --- a/poulpy-backend/src/hal/delegates/vec_znx.rs +++ b/poulpy-hal/src/delegates/vec_znx.rs @@ -1,4 +1,4 @@ -use crate::hal::{ +use crate::{ api::{ VecZnxAdd, VecZnxAddDistF64, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxCopy, VecZnxFillDistF64, VecZnxFillNormal, VecZnxFillUniform, VecZnxLshInplace, diff --git 
a/poulpy-backend/src/hal/delegates/vec_znx_big.rs b/poulpy-hal/src/delegates/vec_znx_big.rs similarity index 99% rename from poulpy-backend/src/hal/delegates/vec_znx_big.rs rename to poulpy-hal/src/delegates/vec_znx_big.rs index 1d00c72..e78092a 100644 --- a/poulpy-backend/src/hal/delegates/vec_znx_big.rs +++ b/poulpy-hal/src/delegates/vec_znx_big.rs @@ -1,6 +1,6 @@ use rand_distr::Distribution; -use crate::hal::{ +use crate::{ api::{ VecZnxBigAdd, VecZnxBigAddDistF64, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmall, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigFillDistF64, diff --git a/poulpy-backend/src/hal/delegates/vec_znx_dft.rs b/poulpy-hal/src/delegates/vec_znx_dft.rs similarity index 99% rename from poulpy-backend/src/hal/delegates/vec_znx_dft.rs rename to poulpy-hal/src/delegates/vec_znx_dft.rs index 75744f4..7cf602b 100644 --- a/poulpy-backend/src/hal/delegates/vec_znx_dft.rs +++ b/poulpy-hal/src/delegates/vec_znx_dft.rs @@ -1,4 +1,4 @@ -use crate::hal::{ +use crate::{ api::{ VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxDftFromBytes, VecZnxDftFromVecZnx, VecZnxDftSub, VecZnxDftSubABInplace, VecZnxDftSubBAInplace, VecZnxDftToVecZnxBig, diff --git a/poulpy-backend/src/hal/delegates/vmp_pmat.rs b/poulpy-hal/src/delegates/vmp_pmat.rs similarity index 99% rename from poulpy-backend/src/hal/delegates/vmp_pmat.rs rename to poulpy-hal/src/delegates/vmp_pmat.rs index 10343eb..9465d9e 100644 --- a/poulpy-backend/src/hal/delegates/vmp_pmat.rs +++ b/poulpy-hal/src/delegates/vmp_pmat.rs @@ -1,4 +1,4 @@ -use crate::hal::{ +use crate::{ api::{ VmpApply, VmpApplyAdd, VmpApplyAddTmpBytes, VmpApplyTmpBytes, VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes, diff --git a/poulpy-backend/src/hal/layouts/encoding.rs b/poulpy-hal/src/layouts/encoding.rs similarity index 99% rename from poulpy-backend/src/hal/layouts/encoding.rs rename to poulpy-hal/src/layouts/encoding.rs index abf27a6..717d90a 100644 --- a/poulpy-backend/src/hal/layouts/encoding.rs +++ b/poulpy-hal/src/layouts/encoding.rs @@ -1,7 +1,7 @@ use itertools::izip; use rug::{Assign, Float}; -use crate::hal::{ +use crate::{ api::{ZnxInfos, ZnxView, ZnxViewMut, ZnxZero}, layouts::{DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef}, }; diff --git a/poulpy-backend/src/hal/layouts/mat_znx.rs b/poulpy-hal/src/layouts/mat_znx.rs similarity index 95% rename from poulpy-backend/src/hal/layouts/mat_znx.rs rename to poulpy-hal/src/layouts/mat_znx.rs index acd4ec3..db2400b 100644 --- a/poulpy-backend/src/hal/layouts/mat_znx.rs +++ b/poulpy-hal/src/layouts/mat_znx.rs @@ -1,10 +1,8 @@ use crate::{ alloc_aligned, - hal::{ - api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, - layouts::{Data, DataMut, DataRef, ReaderFrom, ToOwnedDeep, VecZnx, WriterTo}, - source::Source, - }, + api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, + layouts::{Data, DataMut, DataRef, ReaderFrom, ToOwnedDeep, VecZnx, WriterTo}, + source::Source, }; use std::fmt; @@ -230,7 +228,7 @@ impl MatZnxToMut for MatZnx { } impl MatZnx { - pub(crate) fn from_data(data: D, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { + pub fn from_data(data: D, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { Self { data, n, diff --git 
a/poulpy-backend/src/hal/layouts/mod.rs b/poulpy-hal/src/layouts/mod.rs similarity index 100% rename from poulpy-backend/src/hal/layouts/mod.rs rename to poulpy-hal/src/layouts/mod.rs diff --git a/poulpy-backend/src/hal/layouts/module.rs b/poulpy-hal/src/layouts/module.rs similarity index 89% rename from poulpy-backend/src/hal/layouts/module.rs rename to poulpy-hal/src/layouts/module.rs index be55cb8..e885b13 100644 --- a/poulpy-backend/src/hal/layouts/module.rs +++ b/poulpy-hal/src/layouts/module.rs @@ -1,10 +1,16 @@ -use std::{marker::PhantomData, ptr::NonNull}; +use std::{fmt::Display, marker::PhantomData, ptr::NonNull}; + +use rand_distr::num_traits::Zero; use crate::GALOISGENERATOR; #[allow(clippy::missing_safety_doc)] pub trait Backend: Sized { + type ScalarBig: Copy + Zero + Display; + type ScalarPrep: Copy + Zero + Display; type Handle: 'static; + fn layout_prep_word_count() -> usize; + fn layout_big_word_count() -> usize; unsafe fn destroy(handle: NonNull); } diff --git a/poulpy-backend/src/hal/layouts/scalar_znx.rs b/poulpy-hal/src/layouts/scalar_znx.rs similarity index 94% rename from poulpy-backend/src/hal/layouts/scalar_znx.rs rename to poulpy-hal/src/layouts/scalar_znx.rs index 137eaae..679acb3 100644 --- a/poulpy-backend/src/hal/layouts/scalar_znx.rs +++ b/poulpy-hal/src/layouts/scalar_znx.rs @@ -4,18 +4,16 @@ use rand_distr::{Distribution, weighted::WeightedIndex}; use crate::{ alloc_aligned, - hal::{ - api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, - layouts::{Data, DataMut, DataRef, ReaderFrom, ToOwnedDeep, VecZnx, WriterTo}, - source::Source, - }, + api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, + layouts::{Data, DataMut, DataRef, ReaderFrom, ToOwnedDeep, VecZnx, WriterTo}, + source::Source, }; #[derive(PartialEq, Eq, Debug, Clone)] pub struct ScalarZnx { - pub(crate) data: D, - pub(crate) n: usize, - pub(crate) cols: usize, + pub data: D, + pub n: usize, + pub cols: usize, } impl ToOwnedDeep for ScalarZnx { @@ -161,7 +159,7 @@ impl Reset for ScalarZnx { pub type ScalarZnxOwned = ScalarZnx>; impl ScalarZnx { - pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self { + pub fn from_data(data: D, n: usize, cols: usize) -> Self { Self { data, n, cols } } } diff --git a/poulpy-hal/src/layouts/scratch.rs b/poulpy-hal/src/layouts/scratch.rs new file mode 100644 index 0000000..695883a --- /dev/null +++ b/poulpy-hal/src/layouts/scratch.rs @@ -0,0 +1,13 @@ +use std::marker::PhantomData; + +use crate::layouts::Backend; + +pub struct ScratchOwned { + pub data: Vec, + pub _phantom: PhantomData, +} + +pub struct Scratch { + pub _phantom: PhantomData, + pub data: [u8], +} diff --git a/poulpy-backend/src/hal/layouts/serialization.rs b/poulpy-hal/src/layouts/serialization.rs similarity index 100% rename from poulpy-backend/src/hal/layouts/serialization.rs rename to poulpy-hal/src/layouts/serialization.rs diff --git a/poulpy-backend/src/hal/layouts/stats.rs b/poulpy-hal/src/layouts/stats.rs similarity index 98% rename from poulpy-backend/src/hal/layouts/stats.rs rename to poulpy-hal/src/layouts/stats.rs index b67b20d..5930c57 100644 --- a/poulpy-backend/src/hal/layouts/stats.rs +++ b/poulpy-hal/src/layouts/stats.rs @@ -4,7 +4,7 @@ use rug::{ ops::{AddAssignRound, DivAssignRound, SubAssignRound}, }; -use crate::hal::{ +use crate::{ api::ZnxInfos, layouts::{DataRef, VecZnx}, }; diff --git a/poulpy-backend/src/hal/layouts/svp_ppol.rs b/poulpy-hal/src/layouts/svp_ppol.rs 
similarity index 78% rename from poulpy-backend/src/hal/layouts/svp_ppol.rs rename to poulpy-hal/src/layouts/svp_ppol.rs index 877c470..7565ae8 100644 --- a/poulpy-backend/src/hal/layouts/svp_ppol.rs +++ b/poulpy-hal/src/layouts/svp_ppol.rs @@ -2,18 +2,27 @@ use std::marker::PhantomData; use crate::{ alloc_aligned, - hal::{ - api::{DataView, DataViewMut, ZnxInfos}, - layouts::{Backend, Data, DataMut, DataRef, ReaderFrom, WriterTo}, - }, + api::{DataView, DataViewMut, ZnxInfos, ZnxSliceSize, ZnxView}, + layouts::{Backend, Data, DataMut, DataRef, ReaderFrom, WriterTo}, + oep::SvpPPolAllocBytesImpl, }; #[derive(PartialEq, Eq)] pub struct SvpPPol { - data: D, - n: usize, - cols: usize, - _phantom: PhantomData, + pub data: D, + pub n: usize, + pub cols: usize, + pub _phantom: PhantomData, +} + +impl ZnxSliceSize for SvpPPol { + fn sl(&self) -> usize { + B::layout_prep_word_count() * self.n() + } +} + +impl ZnxView for SvpPPol { + type Scalar = B::ScalarPrep; } impl ZnxInfos for SvpPPol { @@ -47,16 +56,12 @@ impl DataViewMut for SvpPPol { } } -pub trait SvpPPolBytesOf { - fn bytes_of(n: usize, cols: usize) -> usize; -} - impl>, B: Backend> SvpPPol where - SvpPPol: SvpPPolBytesOf, + B: SvpPPolAllocBytesImpl, { - pub(crate) fn alloc(n: usize, cols: usize) -> Self { - let data: Vec = alloc_aligned::(Self::bytes_of(n, cols)); + pub fn alloc(n: usize, cols: usize) -> Self { + let data: Vec = alloc_aligned::(B::svp_ppol_alloc_bytes_impl(n, cols)); Self { data: data.into(), n, @@ -65,9 +70,9 @@ where } } - pub(crate) fn from_bytes(n: usize, cols: usize, bytes: impl Into>) -> Self { + pub fn from_bytes(n: usize, cols: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == Self::bytes_of(n, cols)); + assert!(data.len() == B::svp_ppol_alloc_bytes_impl(n, cols)); Self { data: data.into(), n, @@ -110,7 +115,7 @@ impl SvpPPolToMut for SvpPPol { } impl SvpPPol { - pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self { + pub fn from_data(data: D, n: usize, cols: usize) -> Self { Self { data, n, diff --git a/poulpy-backend/src/hal/layouts/vec_znx.rs b/poulpy-hal/src/layouts/vec_znx.rs similarity index 94% rename from poulpy-backend/src/hal/layouts/vec_znx.rs rename to poulpy-hal/src/layouts/vec_znx.rs index 765f095..1bd9129 100644 --- a/poulpy-backend/src/hal/layouts/vec_znx.rs +++ b/poulpy-hal/src/layouts/vec_znx.rs @@ -2,11 +2,9 @@ use std::fmt; use crate::{ alloc_aligned, - hal::{ - api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, - layouts::{Data, DataMut, DataRef, ReaderFrom, ToOwnedDeep, WriterTo}, - source::Source, - }, + api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, + layouts::{Data, DataMut, DataRef, ReaderFrom, ToOwnedDeep, WriterTo}, + source::Source, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; @@ -14,11 +12,11 @@ use rand::RngCore; #[derive(PartialEq, Eq, Clone, Copy)] pub struct VecZnx { - pub(crate) data: D, - pub(crate) n: usize, - pub(crate) cols: usize, - pub(crate) size: usize, - pub(crate) max_size: usize, + pub data: D, + pub n: usize, + pub cols: usize, + pub size: usize, + pub max_size: usize, } impl ToOwnedDeep for VecZnx { diff --git a/poulpy-backend/src/hal/layouts/vec_znx_big.rs b/poulpy-hal/src/layouts/vec_znx_big.rs similarity index 56% rename from poulpy-backend/src/hal/layouts/vec_znx_big.rs rename to poulpy-hal/src/layouts/vec_znx_big.rs index 0a48727..1a4e063 100644 --- 
a/poulpy-backend/src/hal/layouts/vec_znx_big.rs +++ b/poulpy-hal/src/layouts/vec_znx_big.rs @@ -1,23 +1,33 @@ use std::marker::PhantomData; use rand_distr::num_traits::Zero; +use std::fmt; use crate::{ alloc_aligned, - hal::{ - api::{DataView, DataViewMut, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero}, - layouts::{Backend, Data, DataMut, DataRef}, - }, + api::{DataView, DataViewMut, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, + layouts::{Backend, Data, DataMut, DataRef}, + oep::VecZnxBigAllocBytesImpl, }; #[derive(PartialEq, Eq)] pub struct VecZnxBig { - pub(crate) data: D, - pub(crate) n: usize, - pub(crate) cols: usize, - pub(crate) size: usize, - pub(crate) max_size: usize, - pub(crate) _phantom: PhantomData, + pub data: D, + pub n: usize, + pub cols: usize, + pub size: usize, + pub max_size: usize, + pub _phantom: PhantomData, +} + +impl ZnxSliceSize for VecZnxBig { + fn sl(&self) -> usize { + B::layout_big_word_count() * self.n() * self.cols() + } +} + +impl ZnxView for VecZnxBig { + type Scalar = B::ScalarBig; } impl ZnxInfos for VecZnxBig { @@ -51,10 +61,6 @@ impl DataViewMut for VecZnxBig { } } -pub trait VecZnxBigBytesOf { - fn bytes_of(n: usize, cols: usize, size: usize) -> usize; -} - impl ZnxZero for VecZnxBig where Self: ZnxViewMut, @@ -70,10 +76,10 @@ where impl>, B: Backend> VecZnxBig where - VecZnxBig: VecZnxBigBytesOf, + B: VecZnxBigAllocBytesImpl, { - pub(crate) fn new(n: usize, cols: usize, size: usize) -> Self { - let data = alloc_aligned::(Self::bytes_of(n, cols, size)); + pub fn alloc(n: usize, cols: usize, size: usize) -> Self { + let data = alloc_aligned::(B::vec_znx_big_alloc_bytes_impl(n, cols, size)); Self { data: data.into(), n, @@ -84,9 +90,9 @@ where } } - pub(crate) fn new_from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { + pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == Self::bytes_of(n, cols, size)); + assert!(data.len() == B::vec_znx_big_alloc_bytes_impl(n, cols, size)); Self { data: data.into(), n, @@ -99,7 +105,7 @@ where } impl VecZnxBig { - pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { + pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { Self { data, n, @@ -146,3 +152,38 @@ impl VecZnxBigToMut for VecZnxBig { } } } + +impl fmt::Display for VecZnxBig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "VecZnxBig(n={}, cols={}, size={})", + self.n, self.cols, self.size + )?; + + for col in 0..self.cols { + writeln!(f, "Column {}:", col)?; + for size in 0..self.size { + let coeffs = self.at(col, size); + write!(f, " Size {}: [", size)?; + + let max_show = 100; + let show_count = coeffs.len().min(max_show); + + for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", coeff)?; + } + + if coeffs.len() > max_show { + write!(f, ", ... 
({} more)", coeffs.len() - max_show)?; + } + + writeln!(f, "]")?; + } + } + Ok(()) + } +} diff --git a/poulpy-backend/src/hal/layouts/vec_znx_dft.rs b/poulpy-hal/src/layouts/vec_znx_dft.rs similarity index 59% rename from poulpy-backend/src/hal/layouts/vec_znx_dft.rs rename to poulpy-hal/src/layouts/vec_znx_dft.rs index 709a52a..bc9ce34 100644 --- a/poulpy-backend/src/hal/layouts/vec_znx_dft.rs +++ b/poulpy-hal/src/layouts/vec_znx_dft.rs @@ -1,22 +1,31 @@ -use std::marker::PhantomData; +use std::{fmt, marker::PhantomData}; use rand_distr::num_traits::Zero; use crate::{ alloc_aligned, - hal::{ - api::{DataView, DataViewMut, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero}, - layouts::{Backend, Data, DataMut, DataRef, VecZnxBig}, - }, + api::{DataView, DataViewMut, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, + layouts::{Backend, Data, DataMut, DataRef, VecZnxBig}, + oep::VecZnxBigAllocBytesImpl, }; #[derive(PartialEq, Eq)] pub struct VecZnxDft { - pub(crate) data: D, - pub(crate) n: usize, - pub(crate) cols: usize, - pub(crate) size: usize, - pub(crate) max_size: usize, - pub(crate) _phantom: PhantomData, + pub data: D, + pub n: usize, + pub cols: usize, + pub size: usize, + pub max_size: usize, + pub _phantom: PhantomData, +} + +impl ZnxSliceSize for VecZnxDft { + fn sl(&self) -> usize { + B::layout_prep_word_count() * self.n() * self.cols() + } +} + +impl ZnxView for VecZnxDft { + type Scalar = B::ScalarPrep; } impl VecZnxDft { @@ -82,16 +91,12 @@ where } } -pub trait VecZnxDftBytesOf { - fn bytes_of(n: usize, cols: usize, size: usize) -> usize; -} - impl>, B: Backend> VecZnxDft where - VecZnxDft: VecZnxDftBytesOf, + B: VecZnxBigAllocBytesImpl, { - pub(crate) fn alloc(n: usize, cols: usize, size: usize) -> Self { - let data: Vec = alloc_aligned::(Self::bytes_of(n, cols, size)); + pub fn alloc(n: usize, cols: usize, size: usize) -> Self { + let data: Vec = alloc_aligned::(B::vec_znx_big_alloc_bytes_impl(n, cols, size)); Self { data: data.into(), n, @@ -102,9 +107,9 @@ where } } - pub(crate) fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { + pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == Self::bytes_of(n, cols, size)); + assert!(data.len() == B::vec_znx_big_alloc_bytes_impl(n, cols, size)); Self { data: data.into(), n, @@ -119,7 +124,7 @@ where pub type VecZnxDftOwned = VecZnxDft, B>; impl VecZnxDft { - pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { + pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { Self { data, n, @@ -164,3 +169,38 @@ impl VecZnxDftToMut for VecZnxDft { } } } + +impl fmt::Display for VecZnxDft { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "VecZnxDft(n={}, cols={}, size={})", + self.n, self.cols, self.size + )?; + + for col in 0..self.cols { + writeln!(f, "Column {}:", col)?; + for size in 0..self.size { + let coeffs = self.at(col, size); + write!(f, " Size {}: [", size)?; + + let max_show = 100; + let show_count = coeffs.len().min(max_show); + + for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", coeff)?; + } + + if coeffs.len() > max_show { + write!(f, ", ... 
({} more)", coeffs.len() - max_show)?; + } + + writeln!(f, "]")?; + } + } + Ok(()) + } +} diff --git a/poulpy-backend/src/hal/layouts/vmp_pmat.rs b/poulpy-hal/src/layouts/vmp_pmat.rs similarity index 76% rename from poulpy-backend/src/hal/layouts/vmp_pmat.rs rename to poulpy-hal/src/layouts/vmp_pmat.rs index 4a0c387..2cfd0e2 100644 --- a/poulpy-backend/src/hal/layouts/vmp_pmat.rs +++ b/poulpy-hal/src/layouts/vmp_pmat.rs @@ -2,10 +2,9 @@ use std::marker::PhantomData; use crate::{ alloc_aligned, - hal::{ - api::{DataView, DataViewMut, ZnxInfos}, - layouts::{Backend, Data, DataMut, DataRef}, - }, + api::{DataView, DataViewMut, ZnxInfos, ZnxView}, + layouts::{Backend, Data, DataMut, DataRef}, + oep::VmpPMatAllocBytesImpl, }; #[derive(PartialEq, Eq)] @@ -19,6 +18,10 @@ pub struct VmpPMat { _phantom: PhantomData, } +impl ZnxView for VmpPMat { + type Scalar = B::ScalarPrep; +} + impl ZnxInfos for VmpPMat { fn cols(&self) -> usize { self.cols_in @@ -60,16 +63,14 @@ impl VmpPMat { } } -pub trait VmpPMatBytesOf { - fn vmp_pmat_bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; -} - impl>, B: Backend> VmpPMat where - B: VmpPMatBytesOf, + B: VmpPMatAllocBytesImpl, { - pub(crate) fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { - let data: Vec = alloc_aligned(B::vmp_pmat_bytes_of(n, rows, cols_in, cols_out, size)); + pub fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { + let data: Vec = alloc_aligned(B::vmp_pmat_alloc_bytes_impl( + n, rows, cols_in, cols_out, size, + )); Self { data: data.into(), n, @@ -81,16 +82,9 @@ where } } - pub(crate) fn from_bytes( - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - bytes: impl Into>, - ) -> Self { + pub fn from_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == B::vmp_pmat_bytes_of(n, rows, cols_in, cols_out, size)); + assert!(data.len() == B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size)); Self { data: data.into(), n, @@ -143,7 +137,7 @@ impl VmpPMatToMut for VmpPMat { } impl VmpPMat { - pub(crate) fn from_data(data: D, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { + pub fn from_data(data: D, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { Self { data, n, diff --git a/poulpy-hal/src/lib.rs b/poulpy-hal/src/lib.rs new file mode 100644 index 0000000..2199c9f --- /dev/null +++ b/poulpy-hal/src/lib.rs @@ -0,0 +1,110 @@ +#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)] +#![deny(rustdoc::broken_intra_doc_links)] +#![cfg_attr(docsrs, feature(doc_cfg))] +#![feature(trait_alias)] + +pub mod api; +pub mod delegates; +pub mod layouts; +pub mod oep; +pub mod source; +pub mod tests; + +pub mod doc { + #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/backend_safety_contract.md"))] + pub mod backend_safety { + pub const _PLACEHOLDER: () = (); + } +} + +pub const GALOISGENERATOR: u64 = 5; +pub const DEFAULTALIGN: usize = 64; + +fn is_aligned_custom(ptr: *const T, align: usize) -> bool { + (ptr as usize).is_multiple_of(align) +} + +pub fn is_aligned(ptr: *const T) -> bool { + is_aligned_custom(ptr, DEFAULTALIGN) +} + +pub fn assert_alignement(ptr: *const T) { + assert!( + is_aligned(ptr), + "invalid alignement: ensure passed bytes have been allocated with [alloc_aligned_u8] or 
[alloc_aligned]" + ) +} + +pub fn cast(data: &[T]) -> &[V] { + let ptr: *const V = data.as_ptr() as *const V; + let len: usize = data.len() / size_of::(); + unsafe { std::slice::from_raw_parts(ptr, len) } +} + +#[allow(clippy::mut_from_ref)] +pub fn cast_mut(data: &[T]) -> &mut [V] { + let ptr: *mut V = data.as_ptr() as *mut V; + let len: usize = data.len() / size_of::(); + unsafe { std::slice::from_raw_parts_mut(ptr, len) } +} + +/// Allocates a block of bytes with a custom alignement. +/// Alignement must be a power of two and size a multiple of the alignement. +/// Allocated memory is initialized to zero. +fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { + assert!( + align.is_power_of_two(), + "Alignment must be a power of two but is {}", + align + ); + assert_eq!( + (size * size_of::()) % align, + 0, + "size={} must be a multiple of align={}", + size, + align + ); + unsafe { + let layout: std::alloc::Layout = std::alloc::Layout::from_size_align(size, align).expect("Invalid alignment"); + let ptr: *mut u8 = std::alloc::alloc(layout); + if ptr.is_null() { + panic!("Memory allocation failed"); + } + assert!( + is_aligned_custom(ptr, align), + "Memory allocation at {:p} is not aligned to {} bytes", + ptr, + align + ); + // Init allocated memory to zero + std::ptr::write_bytes(ptr, 0, size); + Vec::from_raw_parts(ptr, size, size) + } +} + +/// Allocates a block of T aligned with [DEFAULTALIGN]. +/// Size of T * size msut be a multiple of [DEFAULTALIGN]. +pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { + assert_eq!( + (size * size_of::()) % (align / size_of::()), + 0, + "size={} must be a multiple of align={}", + size, + align + ); + let mut vec_u8: Vec = alloc_aligned_custom_u8(size_of::() * size, align); + let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T; + let len: usize = vec_u8.len() / size_of::(); + let cap: usize = vec_u8.capacity() / size_of::(); + std::mem::forget(vec_u8); + unsafe { Vec::from_raw_parts(ptr, len, cap) } +} + +/// Allocates an aligned vector of size equal to the smallest multiple +/// of [DEFAULTALIGN]/`size_of::`() that is equal or greater to `size`. +pub fn alloc_aligned(size: usize) -> Vec { + alloc_aligned_custom::( + size + (DEFAULTALIGN - (size % (DEFAULTALIGN / size_of::()))) % DEFAULTALIGN, + DEFAULTALIGN, + ) +} diff --git a/poulpy-backend/src/hal/oep/mod.rs b/poulpy-hal/src/oep/mod.rs similarity index 100% rename from poulpy-backend/src/hal/oep/mod.rs rename to poulpy-hal/src/oep/mod.rs diff --git a/poulpy-backend/src/hal/oep/module.rs b/poulpy-hal/src/oep/module.rs similarity index 86% rename from poulpy-backend/src/hal/oep/module.rs rename to poulpy-hal/src/oep/module.rs index df01329..760a872 100644 --- a/poulpy-backend/src/hal/oep/module.rs +++ b/poulpy-hal/src/oep/module.rs @@ -1,4 +1,4 @@ -use crate::hal::layouts::{Backend, Module}; +use crate::layouts::{Backend, Module}; /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See TODO for reference code. 
diff --git a/poulpy-backend/src/hal/oep/scratch.rs b/poulpy-hal/src/oep/scratch.rs similarity index 99% rename from poulpy-backend/src/hal/oep/scratch.rs rename to poulpy-hal/src/oep/scratch.rs index cd7f92c..da8d28e 100644 --- a/poulpy-backend/src/hal/oep/scratch.rs +++ b/poulpy-hal/src/oep/scratch.rs @@ -1,4 +1,4 @@ -use crate::hal::{ +use crate::{ api::ZnxInfos, layouts::{Backend, DataRef, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, }; diff --git a/poulpy-backend/src/hal/oep/svp_ppol.rs b/poulpy-hal/src/oep/svp_ppol.rs similarity index 94% rename from poulpy-backend/src/hal/oep/svp_ppol.rs rename to poulpy-hal/src/oep/svp_ppol.rs index 7509532..81668ca 100644 --- a/poulpy-backend/src/hal/oep/svp_ppol.rs +++ b/poulpy-hal/src/oep/svp_ppol.rs @@ -1,6 +1,4 @@ -use crate::hal::layouts::{ - Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef, -}; +use crate::layouts::{Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef}; /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See TODO for reference code. diff --git a/poulpy-backend/src/hal/oep/vec_znx.rs b/poulpy-hal/src/oep/vec_znx.rs similarity index 83% rename from poulpy-backend/src/hal/oep/vec_znx.rs rename to poulpy-hal/src/oep/vec_znx.rs index 975a8b8..0aac0dc 100644 --- a/poulpy-backend/src/hal/oep/vec_znx.rs +++ b/poulpy-hal/src/oep/vec_znx.rs @@ -1,13 +1,13 @@ use rand_distr::Distribution; -use crate::hal::{ +use crate::{ layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef}, source::Source, }; /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_normalize_base2k_tmp_bytes_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L245C17-L245C55) for reference code. -/// * See [crate::hal::api::VecZnxNormalizeTmpBytes] for corresponding public API. +/// * See [crate::api::VecZnxNormalizeTmpBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxNormalizeTmpBytesImpl { fn vec_znx_normalize_tmp_bytes_impl(module: &Module, n: usize) -> usize; @@ -15,7 +15,7 @@ pub unsafe trait VecZnxNormalizeTmpBytesImpl { /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L212) for reference code. -/// * See [crate::hal::api::VecZnxNormalize] for corresponding public API. +/// * See [crate::api::VecZnxNormalize] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxNormalizeImpl { fn vec_znx_normalize_impl( @@ -33,7 +33,7 @@ pub unsafe trait VecZnxNormalizeImpl { /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L212) for reference code. -/// * See [crate::hal::api::VecZnxNormalizeInplace] for corresponding public API. +/// * See [crate::api::VecZnxNormalizeInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. 
pub unsafe trait VecZnxNormalizeInplaceImpl { fn vec_znx_normalize_inplace_impl(module: &Module, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) @@ -43,7 +43,7 @@ pub unsafe trait VecZnxNormalizeInplaceImpl { /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code. -/// * See [crate::hal::api::VecZnxAdd] for corresponding public API. +/// * See [crate::api::VecZnxAdd] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxAddImpl { fn vec_znx_add_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) @@ -55,7 +55,7 @@ pub unsafe trait VecZnxAddImpl { /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code. -/// * See [crate::hal::api::VecZnxAddInplace] for corresponding public API. +/// * See [crate::api::VecZnxAddInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxAddInplaceImpl { fn vec_znx_add_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) @@ -66,7 +66,7 @@ pub unsafe trait VecZnxAddInplaceImpl { /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code. -/// * See [crate::hal::api::VecZnxAddScalarInplace] for corresponding public API. +/// * See [crate::api::VecZnxAddScalarInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxAddScalarInplaceImpl { fn vec_znx_add_scalar_inplace_impl( @@ -83,7 +83,7 @@ pub unsafe trait VecZnxAddScalarInplaceImpl { /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code. -/// * See [crate::hal::api::VecZnxSub] for corresponding public API. +/// * See [crate::api::VecZnxSub] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxSubImpl { fn vec_znx_sub_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) @@ -95,7 +95,7 @@ pub unsafe trait VecZnxSubImpl { /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code. -/// * See [crate::hal::api::VecZnxSubABInplace] for corresponding public API. +/// * See [crate::api::VecZnxSubABInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. 
pub unsafe trait VecZnxSubABInplaceImpl { fn vec_znx_sub_ab_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) @@ -106,7 +106,7 @@ pub unsafe trait VecZnxSubABInplaceImpl { /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code. -/// * See [crate::hal::api::VecZnxSubBAInplace] for corresponding public API. +/// * See [crate::api::VecZnxSubBAInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxSubBAInplaceImpl { fn vec_znx_sub_ba_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) @@ -117,7 +117,7 @@ pub unsafe trait VecZnxSubBAInplaceImpl { /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code. -/// * See [crate::hal::api::VecZnxSubScalarInplace] for corresponding public API. +/// * See [crate::api::VecZnxSubScalarInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxSubScalarInplaceImpl { fn vec_znx_sub_scalar_inplace_impl( @@ -134,7 +134,7 @@ pub unsafe trait VecZnxSubScalarInplaceImpl { /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_negate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L322C13-L322C31) for reference code. -/// * See [crate::hal::api::VecZnxNegate] for corresponding public API. +/// * See [crate::api::VecZnxNegate] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxNegateImpl { fn vec_znx_negate_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) @@ -145,7 +145,7 @@ pub unsafe trait VecZnxNegateImpl { /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_negate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L322C13-L322C31) for reference code. -/// * See [crate::hal::api::VecZnxNegateInplace] for corresponding public API. +/// * See [crate::api::VecZnxNegateInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxNegateInplaceImpl { fn vec_znx_negate_inplace_impl(module: &Module, a: &mut A, a_col: usize) @@ -154,8 +154,8 @@ pub unsafe trait VecZnxNegateInplaceImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_rsh_inplace_ref] for reference code. -/// * See [crate::hal::api::VecZnxRshInplace] for corresponding public API. +/// * See [crate::cpu_spqlios::vec_znx::vec_znx_rsh_inplace_ref] for reference code. +/// * See [crate::api::VecZnxRshInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxRshInplaceImpl { fn vec_znx_rsh_inplace_impl(module: &Module, basek: usize, k: usize, a: &mut A) @@ -164,8 +164,8 @@ pub unsafe trait VecZnxRshInplaceImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_lsh_inplace_ref] for reference code. 
-/// * See [crate::hal::api::VecZnxLshInplace] for corresponding public API. +/// * See [crate::cpu_spqlios::vec_znx::vec_znx_lsh_inplace_ref] for reference code. +/// * See [crate::api::VecZnxLshInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxLshInplaceImpl { fn vec_znx_lsh_inplace_impl(module: &Module, basek: usize, k: usize, a: &mut A) @@ -175,7 +175,7 @@ pub unsafe trait VecZnxLshInplaceImpl { /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_rotate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L164) for reference code. -/// * See [crate::hal::api::VecZnxRotate] for corresponding public API. +/// * See [crate::api::VecZnxRotate] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxRotateImpl { fn vec_znx_rotate_impl(module: &Module, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) @@ -186,7 +186,7 @@ pub unsafe trait VecZnxRotateImpl { /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_rotate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L164) for reference code. -/// * See [crate::hal::api::VecZnxRotateInplace] for corresponding public API. +/// * See [crate::api::VecZnxRotateInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxRotateInplaceImpl { fn vec_znx_rotate_inplace_impl(module: &Module, k: i64, a: &mut A, a_col: usize) @@ -196,7 +196,7 @@ pub unsafe trait VecZnxRotateInplaceImpl { /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_automorphism_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L188) for reference code. -/// * See [crate::hal::api::VecZnxAutomorphism] for corresponding public API. +/// * See [crate::api::VecZnxAutomorphism] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxAutomorphismImpl { fn vec_znx_automorphism_impl(module: &Module, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) @@ -207,7 +207,7 @@ pub unsafe trait VecZnxAutomorphismImpl { /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_automorphism_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L188) for reference code. -/// * See [crate::hal::api::VecZnxAutomorphismInplace] for corresponding public API. +/// * See [crate::api::VecZnxAutomorphismInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxAutomorphismInplaceImpl { fn vec_znx_automorphism_inplace_impl(module: &Module, k: i64, a: &mut A, a_col: usize) @@ -217,7 +217,7 @@ pub unsafe trait VecZnxAutomorphismInplaceImpl { /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_mul_xp_minus_one_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/7160f588da49712a042931ea247b4259b95cefcc/spqlios/arithmetic/vec_znx.c#L200C13-L200C41) for reference code. -/// * See [crate::hal::api::VecZnxMulXpMinusOne] for corresponding public API. +/// * See [crate::api::VecZnxMulXpMinusOne] for corresponding public API. 
/// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxMulXpMinusOneImpl { fn vec_znx_mul_xp_minus_one_impl(module: &Module, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) @@ -228,7 +228,7 @@ pub unsafe trait VecZnxMulXpMinusOneImpl { /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_mul_xp_minus_one_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/7160f588da49712a042931ea247b4259b95cefcc/spqlios/arithmetic/vec_znx.c#L200C13-L200C41) for reference code. -/// * See [crate::hal::api::VecZnxMulXpMinusOneInplace] for corresponding public API. +/// * See [crate::api::VecZnxMulXpMinusOneInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxMulXpMinusOneInplaceImpl { fn vec_znx_mul_xp_minus_one_inplace_impl(module: &Module, p: i64, res: &mut R, res_col: usize) @@ -237,8 +237,8 @@ pub unsafe trait VecZnxMulXpMinusOneInplaceImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_split_ref] for reference code. -/// * See [crate::hal::api::VecZnxSplit] for corresponding public API. +/// * See [crate::cpu_spqlios::vec_znx::vec_znx_split_ref] for reference code. +/// * See [crate::api::VecZnxSplit] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxSplitImpl { fn vec_znx_split_impl(module: &Module, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) @@ -248,8 +248,8 @@ pub unsafe trait VecZnxSplitImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_merge_ref] for reference code. -/// * See [crate::hal::api::VecZnxMerge] for corresponding public API. +/// * See [crate::cpu_spqlios::vec_znx::vec_znx_merge_ref] for reference code. +/// * See [crate::api::VecZnxMerge] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxMergeImpl { fn vec_znx_merge_impl(module: &Module, res: &mut R, res_col: usize, a: &[A], a_col: usize) @@ -259,8 +259,8 @@ pub unsafe trait VecZnxMergeImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_switch_degree_ref] for reference code. -/// * See [crate::hal::api::VecZnxSwithcDegree] for corresponding public API. +/// * See [crate::cpu_spqlios::vec_znx::vec_znx_switch_degree_ref] for reference code. +/// * See [crate::api::VecZnxSwithcDegree] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxSwithcDegreeImpl { fn vec_znx_switch_degree_impl( @@ -273,8 +273,8 @@ pub unsafe trait VecZnxSwithcDegreeImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_copy_ref] for reference code. -/// * See [crate::hal::api::VecZnxCopy] for corresponding public API. +/// * See [crate::cpu_spqlios::vec_znx::vec_znx_copy_ref] for reference code. +/// * See [crate::api::VecZnxCopy] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. 
pub unsafe trait VecZnxCopyImpl { fn vec_znx_copy_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) @@ -284,7 +284,7 @@ pub unsafe trait VecZnxCopyImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::hal::api::VecZnxFillUniform] for corresponding public API. +/// * See [crate::api::VecZnxFillUniform] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxFillUniformImpl { fn vec_znx_fill_uniform_impl(module: &Module, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source) @@ -294,7 +294,7 @@ pub unsafe trait VecZnxFillUniformImpl { #[allow(clippy::too_many_arguments)] /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::hal::api::VecZnxFillDistF64] for corresponding public API. +/// * See [crate::api::VecZnxFillDistF64] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxFillDistF64Impl { fn vec_znx_fill_dist_f64_impl>( @@ -312,7 +312,7 @@ pub unsafe trait VecZnxFillDistF64Impl { #[allow(clippy::too_many_arguments)] /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::hal::api::VecZnxAddDistF64] for corresponding public API. +/// * See [crate::api::VecZnxAddDistF64] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxAddDistF64Impl { fn vec_znx_add_dist_f64_impl>( @@ -330,7 +330,7 @@ pub unsafe trait VecZnxAddDistF64Impl { #[allow(clippy::too_many_arguments)] /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::hal::api::VecZnxFillNormal] for corresponding public API. +/// * See [crate::api::VecZnxFillNormal] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxFillNormalImpl { fn vec_znx_fill_normal_impl( @@ -348,7 +348,7 @@ pub unsafe trait VecZnxFillNormalImpl { #[allow(clippy::too_many_arguments)] /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::hal::api::VecZnxAddNormal] for corresponding public API. +/// * See [crate::api::VecZnxAddNormal] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. 
pub unsafe trait VecZnxAddNormalImpl { fn vec_znx_add_normal_impl( diff --git a/poulpy-backend/src/hal/oep/vec_znx_big.rs b/poulpy-hal/src/oep/vec_znx_big.rs similarity index 99% rename from poulpy-backend/src/hal/oep/vec_znx_big.rs rename to poulpy-hal/src/oep/vec_znx_big.rs index 7420be8..3c13393 100644 --- a/poulpy-backend/src/hal/oep/vec_znx_big.rs +++ b/poulpy-hal/src/oep/vec_znx_big.rs @@ -1,6 +1,6 @@ use rand_distr::Distribution; -use crate::hal::{ +use crate::{ layouts::{Backend, Module, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef}, source::Source, }; diff --git a/poulpy-backend/src/hal/oep/vec_znx_dft.rs b/poulpy-hal/src/oep/vec_znx_dft.rs similarity index 99% rename from poulpy-backend/src/hal/oep/vec_znx_dft.rs rename to poulpy-hal/src/oep/vec_znx_dft.rs index 3ab7a7e..4a962ea 100644 --- a/poulpy-backend/src/hal/oep/vec_znx_dft.rs +++ b/poulpy-hal/src/oep/vec_znx_dft.rs @@ -1,4 +1,4 @@ -use crate::hal::layouts::{ +use crate::layouts::{ Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, }; diff --git a/poulpy-backend/src/hal/oep/vmp_pmat.rs b/poulpy-hal/src/oep/vmp_pmat.rs similarity index 99% rename from poulpy-backend/src/hal/oep/vmp_pmat.rs rename to poulpy-hal/src/oep/vmp_pmat.rs index 1f7aeb6..b7971cf 100644 --- a/poulpy-backend/src/hal/oep/vmp_pmat.rs +++ b/poulpy-hal/src/oep/vmp_pmat.rs @@ -1,4 +1,4 @@ -use crate::hal::layouts::{ +use crate::layouts::{ Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef, }; diff --git a/poulpy-backend/src/hal/source.rs b/poulpy-hal/src/source.rs similarity index 100% rename from poulpy-backend/src/hal/source.rs rename to poulpy-hal/src/source.rs diff --git a/poulpy-backend/src/hal/tests/mod.rs b/poulpy-hal/src/tests/mod.rs similarity index 68% rename from poulpy-backend/src/hal/tests/mod.rs rename to poulpy-hal/src/tests/mod.rs index e48d734..a805d82 100644 --- a/poulpy-backend/src/hal/tests/mod.rs +++ b/poulpy-hal/src/tests/mod.rs @@ -1,2 +1,3 @@ pub mod serialization; pub mod vec_znx; +pub mod vmp_pmat; diff --git a/poulpy-backend/src/hal/tests/serialization.rs b/poulpy-hal/src/tests/serialization.rs similarity index 79% rename from poulpy-backend/src/hal/tests/serialization.rs rename to poulpy-hal/src/tests/serialization.rs index 2943726..ba54677 100644 --- a/poulpy-backend/src/hal/tests/serialization.rs +++ b/poulpy-hal/src/tests/serialization.rs @@ -1,6 +1,6 @@ use std::fmt::Debug; -use crate::hal::{ +use crate::{ api::{FillUniform, Reset}, layouts::{ReaderFrom, WriterTo}, source::Source, @@ -38,18 +38,18 @@ where #[test] fn scalar_znx_serialize() { - let original: crate::hal::layouts::ScalarZnx> = crate::hal::layouts::ScalarZnx::alloc(1024, 3); + let original: crate::layouts::ScalarZnx> = crate::layouts::ScalarZnx::alloc(1024, 3); test_reader_writer_interface(original); } #[test] fn vec_znx_serialize() { - let original: crate::hal::layouts::VecZnx> = crate::hal::layouts::VecZnx::alloc(1024, 3, 4); + let original: crate::layouts::VecZnx> = crate::layouts::VecZnx::alloc(1024, 3, 4); test_reader_writer_interface(original); } #[test] fn mat_znx_serialize() { - let original: crate::hal::layouts::MatZnx> = crate::hal::layouts::MatZnx::alloc(1024, 3, 2, 2, 4); + let original: crate::layouts::MatZnx> = crate::layouts::MatZnx::alloc(1024, 3, 2, 2, 4); test_reader_writer_interface(original); } diff --git a/poulpy-backend/src/hal/tests/vec_znx/encoding.rs 
b/poulpy-hal/src/tests/vec_znx/encoding.rs similarity index 99% rename from poulpy-backend/src/hal/tests/vec_znx/encoding.rs rename to poulpy-hal/src/tests/vec_znx/encoding.rs index 1885a82..eac449c 100644 --- a/poulpy-backend/src/hal/tests/vec_znx/encoding.rs +++ b/poulpy-hal/src/tests/vec_znx/encoding.rs @@ -1,4 +1,4 @@ -use crate::hal::{ +use crate::{ api::{ZnxInfos, ZnxViewMut}, layouts::VecZnx, source::Source, diff --git a/poulpy-backend/src/hal/tests/vec_znx/generics.rs b/poulpy-hal/src/tests/vec_znx/generics.rs similarity index 99% rename from poulpy-backend/src/hal/tests/vec_znx/generics.rs rename to poulpy-hal/src/tests/vec_znx/generics.rs index 1d2f714..498abfe 100644 --- a/poulpy-backend/src/hal/tests/vec_znx/generics.rs +++ b/poulpy-hal/src/tests/vec_znx/generics.rs @@ -1,4 +1,4 @@ -use crate::hal::{ +use crate::{ api::{VecZnxAddNormal, VecZnxFillUniform, ZnxView}, layouts::{Backend, Module, VecZnx}, source::Source, diff --git a/poulpy-backend/src/hal/tests/vec_znx/mod.rs b/poulpy-hal/src/tests/vec_znx/mod.rs similarity index 100% rename from poulpy-backend/src/hal/tests/vec_znx/mod.rs rename to poulpy-hal/src/tests/vec_znx/mod.rs diff --git a/poulpy-hal/src/tests/vmp_pmat/mod.rs b/poulpy-hal/src/tests/vmp_pmat/mod.rs new file mode 100644 index 0000000..2847cc3 --- /dev/null +++ b/poulpy-hal/src/tests/vmp_pmat/mod.rs @@ -0,0 +1,3 @@ +mod vmp_apply; + +pub use vmp_apply::*; diff --git a/poulpy-hal/src/tests/vmp_pmat/vmp_apply.rs b/poulpy-hal/src/tests/vmp_pmat/vmp_apply.rs new file mode 100644 index 0000000..a81d7be --- /dev/null +++ b/poulpy-hal/src/tests/vmp_pmat/vmp_apply.rs @@ -0,0 +1,114 @@ +use crate::{ + api::{ + ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftAlloc, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigTmpA, VmpApply, VmpApplyTmpBytes, VmpPMatAlloc, VmpPrepare, + ZnxInfos, ZnxViewMut, + }, + layouts::{MatZnx, Module, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, + oep::{ + ModuleNewImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, VecZnxBigAllocImpl, VecZnxBigNormalizeImpl, + VecZnxBigNormalizeTmpBytesImpl, VecZnxDftAllocImpl, VecZnxDftFromVecZnxImpl, VecZnxDftToVecZnxBigTmpAImpl, VmpApplyImpl, + VmpApplyTmpBytesImpl, VmpPMatAllocImpl, VmpPMatPrepareImpl, + }, +}; + +use crate::layouts::Backend; + +pub fn test_vmp_apply() +where + B: Backend + + ModuleNewImpl + + VmpApplyTmpBytesImpl + + VecZnxBigNormalizeTmpBytesImpl + + VmpPMatAllocImpl + + VecZnxDftAllocImpl + + VecZnxBigAllocImpl + + VmpPMatPrepareImpl + + VecZnxDftFromVecZnxImpl + + VmpApplyImpl + + VecZnxDftToVecZnxBigTmpAImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + VecZnxBigNormalizeImpl, +{ + let log_n: i32 = 5; + let n: usize = 1 << log_n; + + let module: Module = Module::::new(n as u64); + let basek: usize = 15; + let a_size: usize = 5; + let mat_size: usize = 6; + let res_size: usize = a_size; + + [1, 2].iter().for_each(|cols_in| { + [1, 2].iter().for_each(|cols_out| { + let a_cols: usize = *cols_in; + let res_cols: usize = *cols_out; + + let mat_rows: usize = a_size; + let mat_cols_in: usize = a_cols; + let mat_cols_out: usize = res_cols; + + let mut scratch = ScratchOwned::alloc( + module.vmp_apply_tmp_bytes( + n, + res_size, + a_size, + mat_rows, + mat_cols_in, + mat_cols_out, + mat_size, + ) | module.vec_znx_big_normalize_tmp_bytes(n), + ); + + let mut a: VecZnx> = VecZnx::alloc(n, a_cols, a_size); + + (0..a_cols).for_each(|i| { + a.at_mut(i, a_size - 1)[i + 1] = 1; + }); + + let mut vmp: VmpPMat, 
B> = module.vmp_pmat_alloc(n, mat_rows, mat_cols_in, mat_cols_out, mat_size); + + let mut c_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(n, mat_cols_out, mat_size); + let mut c_big: VecZnxBig, B> = module.vec_znx_big_alloc(n, mat_cols_out, mat_size); + + let mut mat: MatZnx> = MatZnx::alloc(n, mat_rows, mat_cols_in, mat_cols_out, mat_size); + + // Constructs a [VecZnxMatDft] that performs cyclic rotations on each submatrix. + (0..a.size()).for_each(|row_i| { + (0..mat_cols_in).for_each(|col_in_i| { + (0..mat_cols_out).for_each(|col_out_i| { + let idx = 1 + col_in_i * mat_cols_out + col_out_i; + mat.at_mut(row_i, col_in_i).at_mut(col_out_i, row_i)[idx] = 1_i64; // X^{idx} + }); + }); + }); + + module.vmp_prepare(&mut vmp, &mat, scratch.borrow()); + + let mut a_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(n, a_cols, a_size); + (0..a_cols).for_each(|i| { + module.vec_znx_dft_from_vec_znx(1, 0, &mut a_dft, i, &a, i); + }); + + module.vmp_apply(&mut c_dft, &a_dft, &vmp, scratch.borrow()); + + let mut res_have_vi64: Vec = vec![i64::default(); n]; + + let mut res_have: VecZnx> = VecZnx::alloc(n, res_cols, res_size); + (0..mat_cols_out).for_each(|i| { + module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut c_big, i, &mut c_dft, i); + module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow()); + }); + + (0..mat_cols_out).for_each(|col_i| { + let mut res_want_vi64: Vec = vec![i64::default(); n]; + (0..a_cols).for_each(|i| { + res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1; + }); + res_have.decode_vec_i64(basek, col_i, basek * a_size, &mut res_have_vi64); + assert_eq!(res_have_vi64, res_want_vi64); + }); + }); + }); +} diff --git a/poulpy-schemes/Cargo.toml b/poulpy-schemes/Cargo.toml index cb6a11a..b446c07 100644 --- a/poulpy-schemes/Cargo.toml +++ b/poulpy-schemes/Cargo.toml @@ -1,16 +1,17 @@ [package] name = "poulpy-schemes" -version = "0.1.0" +version = "0.1.1" edition = "2024" license = "Apache-2.0" readme = "README.md" -description = "A crate implementing FHE schemes" +description = "A backend-agnostic crate implementing mainstream RLWE-based FHE schemes" repository = "https://github.com/phantomzone-org/poulpy" homepage = "https://github.com/phantomzone-org/poulpy" documentation = "https://docs.rs/poulpy" [dependencies] -poulpy-backend = "0.1.0" -poulpy-core = "0.1.0" +poulpy-backend = "0.1.2" +poulpy-hal = "0.1.2" +poulpy-core = "0.1.1" itertools = "0.14.0" byteorder = "1.5.0" \ No newline at end of file diff --git a/poulpy-schemes/examples/circuit_bootstrapping.rs b/poulpy-schemes/examples/circuit_bootstrapping.rs index 7812d95..b05ec56 100644 --- a/poulpy-schemes/examples/circuit_bootstrapping.rs +++ b/poulpy-schemes/examples/circuit_bootstrapping.rs @@ -7,15 +7,14 @@ use poulpy_core::{ }; use std::time::Instant; -use poulpy_backend::{ - hal::{ - api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalizeInplace, ZnxView, ZnxViewMut}, - layouts::{Module, ScalarZnx, ScratchOwned}, - source::Source, - }, - implementation::cpu_spqlios::FFT64, +use poulpy_hal::{ + api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalizeInplace, ZnxView, ZnxViewMut}, + layouts::{Module, ScalarZnx, ScratchOwned}, + source::Source, }; +use poulpy_backend::cpu_spqlios::FFT64; + use poulpy_schemes::tfhe::{ blind_rotation::CGGI, circuit_bootstrapping::{ diff --git a/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs b/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs index 7ef921a..aabc58f 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs +++ 
b/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs @@ -1,5 +1,5 @@ use itertools::izip; -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, SvpApply, SvpPPolAllocBytes, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice, VecZnxAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, @@ -513,7 +513,7 @@ fn execute_standard( out_mut.normalize_inplace(module, scratch1); } -pub(crate) fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>, rot_dir: LookUpTableRotationDirection) { +pub fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>, rot_dir: LookUpTableRotationDirection) { let basek: usize = lwe.basek(); let log2n: usize = usize::BITS as usize - (n - 1).leading_zeros() as usize + 1; diff --git a/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs b/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs index db30ad0..cf144cb 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxFillUniform, diff --git a/poulpy-schemes/src/tfhe/blind_rotation/key.rs b/poulpy-schemes/src/tfhe/blind_rotation/key.rs index 9fc0fa5..b530893 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/key.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/key.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{FillUniform, Reset}, layouts::{Backend, Data, DataMut, DataRef, Module, ReaderFrom, Scratch, WriterTo}, source::Source, diff --git a/poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs b/poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs index fe8fdf2..41a6f1e 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{FillUniform, Reset}, layouts::{Data, DataMut, DataRef, ReaderFrom, WriterTo}, source::Source, diff --git a/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs b/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs index d90a137..5c61867 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{SvpPPolAlloc, SvpPrepare, VmpPMatAlloc, VmpPrepare}, layouts::{Backend, Data, DataMut, DataRef, Module, ScalarZnx, Scratch, SvpPPol}, }; diff --git a/poulpy-schemes/src/tfhe/blind_rotation/lut.rs b/poulpy-schemes/src/tfhe/blind_rotation/lut.rs index c955ed6..52f590c 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/lut.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/lut.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxViewMut, diff --git a/poulpy-schemes/src/tfhe/blind_rotation/mod.rs b/poulpy-schemes/src/tfhe/blind_rotation/mod.rs index 25c31bd..bd83a08 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/mod.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/mod.rs @@ -14,8 +14,8 @@ pub use lut::*; pub mod tests; -use poulpy_backend::hal::layouts::{Backend, DataMut, DataRef, 
Module, Scratch}; use poulpy_core::layouts::{GLWECiphertext, LWECiphertext}; +use poulpy_hal::layouts::{Backend, DataMut, DataRef, Module, Scratch}; pub trait BlindRotationAlgo {} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs index c344a7d..e05cf1d 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs index 9516e84..d4ef98e 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs @@ -1,6 +1,6 @@ use std::vec; -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSwithcDegree}, layouts::{Backend, Module}, oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_serialization.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_serialization.rs index 16f5b66..7310750 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_serialization.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_serialization.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::tests::serialization::test_reader_writer_interface; +use poulpy_hal::tests::serialization::test_reader_writer_interface; use crate::tfhe::blind_rotation::{BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyCompressed, CGGI}; diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/cpu_spqlios/fft64.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/cpu_spqlios/fft64.rs index 9003e89..feeee06 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/cpu_spqlios/fft64.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/cpu_spqlios/fft64.rs @@ -1,7 +1,5 @@ -use poulpy_backend::{ - hal::{api::ModuleNew, layouts::Module}, - implementation::cpu_spqlios::FFT64, -}; +use poulpy_backend::cpu_spqlios::FFT64; +use poulpy_hal::{api::ModuleNew, layouts::Module}; use crate::tfhe::blind_rotation::tests::{ generic_blind_rotation::test_blind_rotation, diff --git a/poulpy-schemes/src/tfhe/blind_rotation/utils.rs b/poulpy-schemes/src/tfhe/blind_rotation/utils.rs index cfdd309..993f665 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/utils.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/utils.rs @@ -1,4 +1,4 @@ -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{SvpPrepare, ZnxInfos, ZnxViewMut}, layouts::{Backend, DataMut, Module, ScalarZnx, SvpPPol}, }; diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs index afda315..331ab98 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, TakeMatZnx, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, 
TakeVecZnxDftSlice, TakeVecZnxSlice, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace, diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs index 0ec79ec..2de4a33 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs @@ -4,7 +4,7 @@ use poulpy_core::layouts::{ }; use std::collections::HashMap; -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/mod.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/mod.rs index e3a79f6..86dcf5e 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/mod.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/mod.rs @@ -7,7 +7,7 @@ pub use key::*; use poulpy_core::layouts::{GGSWCiphertext, LWECiphertext}; -use poulpy_backend::hal::layouts::{Backend, DataMut, DataRef, Module, Scratch}; +use poulpy_hal::layouts::{Backend, DataMut, DataRef, Module, Scratch}; pub trait CirtuitBootstrappingExecute { fn execute_to_constant( diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs index 2763e5c..48bf585 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs @@ -1,6 +1,6 @@ use std::time::Instant; -use poulpy_backend::hal::{ +use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/implementation/cpu_spqlios/fft64.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/implementation/cpu_spqlios/fft64.rs index 8319d92..688e28e 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/implementation/cpu_spqlios/fft64.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/implementation/cpu_spqlios/fft64.rs @@ -1,7 +1,5 @@ -use poulpy_backend::{ - hal::{api::ModuleNew, layouts::Module}, - implementation::cpu_spqlios::FFT64, -}; +use poulpy_backend::cpu_spqlios::FFT64; +use poulpy_hal::{api::ModuleNew, layouts::Module}; use crate::tfhe::{ blind_rotation::CGGI,