Mirror of https://github.com/arnaucube/poulpy.git (synced 2026-02-10 05:06:44 +01:00)

Commit a1de248567 (parent dce4d82706), committed via GitHub — Crates io (#76)

* crates re-organisation
* fixed typo in layout & added test for vmp_apply
* updated dependencies

Cargo.lock (generated) — 77 changed lines
Four hunks (@@ -347,12 +347,13 @@, @@ -362,9 +363,51 @@, @@ -378,26 +421,17 @@ and @@ -405,12 +439,13 @@) record the crate re-organisation and version bumps:

- poulpy-backend: 0.1.0 -> 0.1.2 for both the workspace entry and the crates.io entry (new checksum e0c6c0ad35bd5399e72a7d51b8bad5aa03e54bfd63bf1a09c4a595bd51145ca6, previously d47fbc27d0c03c2bfffd972795c62a243e4a3a3068acdb95ef55fb335a58d00f). Both entries gain "poulpy-hal 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)"; the registry entry's dependency list is byteorder, cmake, criterion, itertools 0.14.0, poulpy-hal 0.1.2, rand, rand_chacha, rand_core, rand_distr, rug.
- poulpy-core: the 0.1.0 entries (registry checksum 4ff4e1acd3f4a84e861b07184fd28fe3143a57360bd51e923aeadbc94b8b38d0, depending on poulpy-backend 0.1.0) are replaced by 0.1.1 entries (registry checksum 34afc307c185e288395d9f298a3261177dc850229e2bd6d53aa4059ae7e98cab) depending on byteorder, criterion, itertools 0.14.0, poulpy-backend 0.1.2, poulpy-hal 0.1.2 and rug.
- poulpy-hal 0.1.2 appears as a new package, with a workspace entry and a crates.io entry (checksum 63312a7be7c5fd91e1f5151735d646294a4592d80027d8e90778076b2070a0ec); its registry dependency list is byteorder, cmake, criterion, itertools 0.14.0, rand, rand_chacha, rand_core, rand_distr, rug.
- poulpy-schemes: 0.1.0 -> 0.1.1; its dependencies move from poulpy-backend 0.1.0 / poulpy-core 0.1.0 to poulpy-backend 0.1.2 / poulpy-core 0.1.1, and poulpy-hal 0.1.2 is added.
Workspace manifest (Cargo.toml) — @@ -1,5 +1,5 @@: poulpy-hal joins the workspace members.

 [workspace]
-members = ["poulpy-backend", "poulpy-core", "poulpy-schemes"]
+members = ["poulpy-hal", "poulpy-core", "poulpy-backend", "poulpy-schemes"]
 resolver = "3"

 [workspace.dependencies]
README.md — 125 changed lines (@@ -1,3 +1,4 @@, @@ -8,6 +9,19 @@ and @@ -30,114 +44,13 @@). Aside from a new blank line at the top of the file, the title and introduction are unchanged, and a new "## Library Overview" section is inserted right after the introduction:

# 🐙 Poulpy

<p align="center">

**Poulpy** is a fast & modular FHE library that implements Ring-Learning-With-Errors based homomorphic encryption. It adopts the bivariate polynomial representation proposed in [Revisiting Key Decomposition Techniques for FHE: Simpler, Faster and More Generic](https://eprint.iacr.org/2023/771). In addition to simpler and more efficient arithmetic than the residue number system (RNS), this representation provides a common plaintext space for all schemes and allows easy switching between any two schemes. Poulpy also decouples the scheme implementations from the polynomial arithmetic backend by being built around a hardware abstraction layer (HAL). This enables users to easily provide or use a custom backend.

## Library Overview

- **`poulpy-hal`**: a crate providing layouts and a trait-based hardware acceleration layer with open extension points, matching the API and types of spqlios-arithmetic.
  - **`api`**: fixed public low-level polynomial arithmetic API closely matching spqlios-arithmetic.
  - **`delegates`**: link between the user-facing API and the implementation OEP. Each trait of `api` is implemented by calling its corresponding trait on the `oep`.
  - **`layouts`**: layouts of the front-end algebraic structs matching spqlios-arithmetic types, such as `ScalarZnx`, `VecZnx`, or opaque backend-prepared structs such as `SvpPPol` and `VmpPMat`.
  - **`oep`**: open extension points, which can be (re-)implemented by the user to provide a concrete backend (the sketch after this list shows the layering).
  - **`tests`**: backend-agnostic & generic tests for the OEP/layouts.
- **`poulpy-backend`**: a crate providing concrete implementations of **`poulpy-hal`**.
  - **`cpu_spqlios`**: CPU implementation of **`poulpy-hal`** through the `oep` using bindings on spqlios-arithmetic. This implementation currently supports the `FFT64` backend and will be extended to support the `NTT120` backend once it is available in spqlios-arithmetic.
- **`poulpy-core`**: a backend-agnostic crate implementing scheme-agnostic RLWE arithmetic for LWE, GLWE, GGLWE and GGSW ciphertexts using **`poulpy-hal`**.
- **`poulpy-schemes`**: a backend-agnostic crate implementing mainstream FHE schemes using **`poulpy-core`** and **`poulpy-hal`**.
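The `api` / `delegates` / `oep` split above is easiest to see in code. The following self-contained toy model is not taken from the repository (the real traits are generic over columns, layouts and scratch space); it only mirrors the naming of `SvpPrepare` / `SvpPrepareImpl` to show how a user-facing trait is implemented once on `Module<B>` and forwarded to whichever backend implements the open extension point:

```rust
use std::marker::PhantomData;

pub trait Backend {}

pub struct Module<B: Backend> {
    _phantom: PhantomData<B>,
}

// `api`: the fixed, user-facing trait.
pub trait SvpPrepare<B: Backend> {
    fn svp_prepare(&self, res: &mut Vec<f64>, a: &[i64]);
}

// `oep`: the open extension point a backend provides.
pub trait SvpPrepareImpl<B: Backend> {
    fn svp_prepare_impl(module: &Module<B>, res: &mut Vec<f64>, a: &[i64]);
}

// `delegates`: the api trait is implemented exactly once, by forwarding to the oep.
impl<B> SvpPrepare<B> for Module<B>
where
    B: Backend + SvpPrepareImpl<B>,
{
    fn svp_prepare(&self, res: &mut Vec<f64>, a: &[i64]) {
        B::svp_prepare_impl(self, res, a);
    }
}

// A concrete backend only has to supply the oep implementation.
pub struct CpuBackend;

impl Backend for CpuBackend {}

impl SvpPrepareImpl<CpuBackend> for CpuBackend {
    fn svp_prepare_impl(_module: &Module<CpuBackend>, res: &mut Vec<f64>, a: &[i64]) {
        // Stand-in for the real "prepare" step (which converts to the backend's prepared form).
        *res = a.iter().map(|&x| x as f64).collect();
    }
}

fn main() {
    let module = Module::<CpuBackend> { _phantom: PhantomData };
    let mut prepared = Vec::new();
    module.svp_prepare(&mut prepared, &[1, -2, 3]);
    assert_eq!(prepared, vec![1.0, -2.0, 3.0]);
}
```

In the real crates the delegate lives in `poulpy-hal`, while `poulpy-backend`'s `cpu_spqlios::FFT64` plays the role of `CpuBackend`.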
### Bivariate Polynomial Representation

Existing FHE implementations (such as [Lattigo](https://github.com/tuneinsight/lattigo) or [OpenFHE](https://github.com/openfheorg/openfhe-development)) use the [residue-number-system](https://en.wikipedia.org/wiki/Residue_number_system) (RNS) to represent large integers. Although the parallelism and carry-less arithmetic of the RNS representation give very efficient modular arithmetic over large integers, it suffers from various drawbacks when used in the context of FHE. The main idea behind the bivariate representation is to decouple the cyclotomic arithmetic from the large-number arithmetic: instead of using the RNS representation for large integers, integers are decomposed in base $2^{-K}$ over the Torus $\mathbb{T}_{N}[X]$.
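To make the digit decomposition concrete, here is a small self-contained sketch (illustrative only, not Poulpy code): the integer analogue of splitting a value into signed base-2^K limbs, which the bivariate representation applies coefficient-wise over the torus. K = 17 and the input are arbitrary example values.

```rust
// Decompose an i64 into signed (balanced) base-2^K digits and recompose it.
fn decompose_base2k(mut x: i64, k: u32, limbs: usize) -> Vec<i64> {
    let base = 1i64 << k;
    let half = base >> 1;
    let mut digits = Vec::with_capacity(limbs);
    for _ in 0..limbs {
        // Centered remainder in [-2^(K-1), 2^(K-1)).
        let mut d = x % base;
        if d >= half {
            d -= base;
        } else if d < -half {
            d += base;
        }
        digits.push(d);
        x = (x - d) >> k; // exact: x - d is a multiple of 2^K
    }
    digits
}

fn recompose_base2k(digits: &[i64], k: u32) -> i64 {
    digits.iter().rev().fold(0i64, |acc, &d| (acc << k) + d)
}

fn main() {
    let (k, x) = (17, -123_456_789_012_i64);
    let digits = decompose_base2k(x, k, 4);
    assert_eq!(recompose_base2k(&digits, k), x);
    println!("{x} = {digits:?} in signed base 2^{k}");
}
```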
(The hunk @@ -30,114 +44,13 @@ omits unchanged lines here, ending with "This provides the following benefits:" and the list that follows it.)

In addition to providing a general purpose FHE library over a unified plaintext space, Poulpy is also designed from the ground up around a **hardware abstraction layer** that closely matches the API of [spqlios-arithmetic](https://github.com/tfhe/spqlios-arithmetic). The bivariate representation is by itself hardware friendly as it uses a flat, aligned & vectorized memory layout. Finally, generic opaque write-only structs (prepared versions) are provided, making it easy for developers to provide hardware-focused/optimized operations. This makes it possible for anyone to provide or use a custom backend.

Removed by this commit: the previous "## Library Overview", which documented the old module layout (`backend/hal`, `backend/implementation`, `core`) with inline Rust examples. Its content was:

- **`backend/hal`**: hardware abstraction layer. This layer targets users that want to provide their own backend or use a third-party backend.

- **`api`**: fixed public low-level polynomial arithmetic API closely matching spqlios-arithmetic. The goal is to eventually freeze this API, in order to decouple it from the OEP traits, ensuring that changes to implementations do not affect the front-end API.

  ```rust
  pub trait SvpPrepare<B: Backend> {
      fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
      where
          R: SvpPPolToMut<B>,
          A: ScalarZnxToRef;
  }
  ```

- **`delegates`**: link between the user-facing API and the implementation OEP. Each trait of `api` is implemented by calling its corresponding trait on the `oep`.

  ```rust
  impl<B> SvpPrepare<B> for Module<B>
  where
      B: Backend + SvpPrepareImpl<B>,
  {
      fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
      where
          R: SvpPPolToMut<B>,
          A: ScalarZnxToRef,
      {
          B::svp_prepare_impl(self, res, res_col, a, a_col);
      }
  }
  ```

- **`layouts`**: defines the layouts of the front-end algebraic structs matching spqlios-arithmetic definitions, such as `ScalarZnx`, `VecZnx`, or opaque backend-prepared structs such as `SvpPPol` and `VmpPMat`.

  ```rust
  pub struct SvpPPol<D: Data, B: Backend> {
      data: D,
      n: usize,
      cols: usize,
      _phantom: PhantomData<B>,
  }
  ```

- **`oep`**: open extension points, which can be implemented by the user to provide a custom backend.

  ```rust
  pub unsafe trait SvpPrepareImpl<B: Backend> {
      fn svp_prepare_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
      where
          R: SvpPPolToMut<B>,
          A: ScalarZnxToRef;
  }
  ```

- **`tests`**: exported generic tests for the OEP/structs. Their goal is to let a user automatically test their backend implementation, without having to re-implement any tests.

- **`backend/implementation`**:
  - **`cpu_spqlios`**: concrete CPU implementation of the HAL through the OEP using bindings on spqlios-arithmetic. This implementation currently supports the `FFT64` backend and will be extended to support the `NTT120` backend once it is available in spqlios-arithmetic.

  ```rust
  unsafe impl SvpPrepareImpl<Self> for FFT64 {
      fn svp_prepare_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
      where
          R: SvpPPolToMut<Self>,
          A: ScalarZnxToRef,
      {
          unsafe {
              svp::svp_prepare(
                  module.ptr(),
                  res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t,
                  a.to_ref().at_ptr(a_col, 0),
              )
          }
      }
  }
  ```

- **`core`**: core of the FHE library, implementing scheme-agnostic RLWE arithmetic for LWE, GLWE, GGLWE and GGSW ciphertexts. It notably includes all possible cross-ciphertext operations, for example applying an external product on a GGLWE or an automorphism on a GGSW, as well as blind rotation. This crate is entirely implemented using the hardware abstraction layer API, and is thus solely defined over generics and traits (including tests). As such it will work over any backend, as long as it implements the necessary traits defined in the OEP.

  ```rust
  pub struct GLWESecret<D: Data> {
      pub(crate) data: ScalarZnx<D>,
      pub(crate) dist: Distribution,
  }

  pub struct GLWESecretPrepared<D: Data, B: Backend> {
      pub(crate) data: SvpPPol<D, B>,
      pub(crate) dist: Distribution,
  }

  impl<D: DataMut, B: Backend> GLWESecretPrepared<D, B> {
      pub fn prepare<O>(&mut self, module: &Module<B>, sk: &GLWESecret<O>)
      where
          O: DataRef,
          Module<B>: SvpPrepare<B>,
      {
          (0..self.rank()).for_each(|i| {
              module.svp_prepare(&mut self.data, i, &sk.data, i);
          });
          self.dist = sk.dist
      }
  }
  ```

## Installation

The Installation section no longer reads "TBD — currently not published on crates.io. Clone the repository and use via path-based dependencies." and instead lists the published crates:

- **`poulpy-hal`**: https://crates.io/crates/poulpy-hal/0.1.0
- **`poulpy-backend`**: https://crates.io/crates/poulpy-backend/0.1.0
- **`poulpy-core`**: https://crates.io/crates/poulpy-core/0.1.0
- **`poulpy-schemes`**: https://crates.io/crates/poulpy-schemes/0.1.0

## Documentation

* Full `cargo doc` documentation is coming soon. (unchanged)
poulpy-backend package manifest (Cargo.toml) — @@ -1,27 +1,28 @@:

-version = "0.1.0"
+version = "0.1.2"

-description = "A crate implementing bivariate polynomial arithmetic"
+description = "A crate providing concrete implementations of poulpy-hal through its open extension points"

and `poulpy-hal = "0.1.2"` is added at the top of `[dependencies]`. The rest of the manifest is unchanged: edition 2024, Apache-2.0 license, readme/repository/homepage/documentation metadata, the workspace dependencies (rug, criterion, itertools, rand, rand_distr, rand_core, byteorder) plus `rand_chacha = "0.9.0"`, the build dependency `cmake = "0.1.54"`, and the docs.rs metadata (`all-features = true`, `rustdoc-args = ["--cfg", "docsrs"]`).
Backend build instructions (README, file name not shown in this view) — @@ -1,12 +1,15 @@: the existing steps are grouped under a new "## spqlios-arithmetic" heading, with the platform sections demoted to "### WSL/Ubuntu" and "### Others":

## spqlios-arithmetic

### WSL/Ubuntu

To use this crate you need to build spqlios-arithmetic, which is provided as a git submodule:

1) Initialize the submodule
2) $ cd backend/spqlios-arithmetic
3) mkdir build
4) cd build
5) cmake ..
6) make

### Others

Steps 3 to 6 might change depending on your platform. See [spqlios-arithmetic/wiki/build](https://github.com/tfhe/spqlios-arithmetic/wiki/build) for additional information and build options.
The cmake build helper — @@ -1,7 +1,7 @@: the spqlios-arithmetic source path follows the module re-organisation.

 use std::path::PathBuf;

 pub fn build() {
-    let dst: PathBuf = cmake::Config::new("src/implementation/cpu_spqlios/spqlios-arithmetic")
+    let dst: PathBuf = cmake::Config::new("src/cpu_spqlios/spqlios-arithmetic")
         .define("ENABLE_TESTING", "FALSE")
         .build();
An example/benchmark source file (name not shown in this view) — @@ -1,15 +1,13 @@: the imports follow the crate split. HAL items now come from `poulpy_hal` instead of `poulpy_backend::hal`, and the backend is imported as `poulpy_backend::cpu_spqlios::FFT64` instead of `poulpy_backend::implementation::cpu_spqlios::FFT64`:

```rust
use itertools::izip;
use poulpy_backend::cpu_spqlios::FFT64;
use poulpy_hal::{
    api::{
        ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPrepare, VecZnxAddNormal,
        VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace,
        VecZnxDftAlloc, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigTmpA, VecZnxFillUniform, VecZnxNormalizeInplace, ZnxInfos,
    },
    layouts::{Module, ScalarZnx, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft},
    source::Source,
};

fn main() {
    // (body unchanged and not shown in this view)
```
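For orientation, a hedged sketch of how the `SvpPrepare` delegate from these imports is called. Only the call shape (`module.svp_prepare(&mut prepared, res_col, &scalar, a_col)`) is taken from this commit; the concrete buffer types and how they are allocated are assumptions, so treat this as pseudocode rather than the exact API:

```rust
use poulpy_backend::cpu_spqlios::FFT64;
use poulpy_hal::{
    api::SvpPrepare,
    layouts::{Module, ScalarZnx, SvpPPol},
};

// Prepare column 0 of `scalar` into column 0 of `prepared` on the FFT64 backend.
// Allocating `module`, `prepared` and `scalar` is left to the caller on purpose.
fn prepare_first_column(
    module: &Module<FFT64>,
    prepared: &mut SvpPPol<Vec<u8>, FFT64>,
    scalar: &ScalarZnx<Vec<u8>>,
) {
    module.svp_prepare(prepared, 0, scalar, 0);
}
```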
poulpy-backend/src/cpu_spqlios/ffi/mod.rs — new file (15 lines):

```rust
#[allow(non_camel_case_types)]
pub mod module;
#[allow(non_camel_case_types)]
pub mod svp;
#[allow(non_camel_case_types)]
pub mod vec_znx;
#[allow(dead_code)]
#[allow(non_camel_case_types)]
pub mod vec_znx_big;
#[allow(non_camel_case_types)]
pub mod vec_znx_dft;
#[allow(non_camel_case_types)]
pub mod vmp;
#[allow(non_camel_case_types)]
pub mod znx;
```
FFI bindings for the spqlios `module` type — @@ -1,19 +1,17 @@: `module_info_t` gains `#[repr(C)]`, and the unused `module_get_n` declaration is dropped. The new content:

```rust
#[repr(C)]
pub struct module_info_t {
    _unused: [u8; 0],
}

pub type module_type_t = ::std::os::raw::c_uint;
pub use self::module_type_t as MODULE_TYPE;

#[allow(clippy::upper_case_acronyms)]
pub type MODULE = module_info_t;

unsafe extern "C" {
    pub unsafe fn new_module_info(N: u64, mode: MODULE_TYPE) -> *mut MODULE;
}
unsafe extern "C" {
    pub unsafe fn delete_module_info(module_info: *mut MODULE);
}
```
FFI bindings for `svp` — @@ -1,4 +1,4 @@ and @@ -7,20 +7,11 @@: the import moves from `crate::implementation::cpu_spqlios::ffi` to `crate::cpu_spqlios::ffi`, the unused allocation helpers `bytes_of_svp_ppol`, `new_svp_ppol` and `delete_svp_ppol` are removed, and `#[allow(dead_code)]` is added on the `svp_apply_dft` block:

```rust
use crate::cpu_spqlios::ffi::{module::MODULE, vec_znx_dft::VEC_ZNX_DFT};

// ... svp_ppol_t and SVP_PPOL unchanged ...

unsafe extern "C" {
    pub unsafe fn svp_prepare(module: *const MODULE, ppol: *mut SVP_PPOL, pol: *const i64);
}

#[allow(dead_code)]
unsafe extern "C" {
    pub unsafe fn svp_apply_dft(
        module: *const MODULE,
        // ... remaining arguments unchanged, outside the hunk ...
```
FFI bindings for `vec_znx` — @@ -1,4 +1,4 @@, @@ -53,6 +53,7 @@ and @@ -81,9 +82,12 @@: the import becomes `use crate::cpu_spqlios::ffi::module::MODULE;`, and `#[allow(dead_code)]` is added in front of the `vec_znx_rotate`, `vec_znx_zero` and `vec_znx_copy` extern blocks; the declarations themselves are unchanged.
FFI bindings for `vec_znx_big` — @@ -1,163 +1,153 @@: the import becomes `use crate::cpu_spqlios::ffi::module::MODULE;`, and the unused allocation helpers `bytes_of_vec_znx_big`, `new_vec_znx_big` and `delete_vec_znx_big` are removed. The remaining declarations are carried over unchanged (shifted up in the file): `vec_znx_big_add`, `vec_znx_big_add_small`, `vec_znx_big_add_small2`, `vec_znx_big_sub`, `vec_znx_big_sub_small_b`, `vec_znx_big_sub_small_a`, `vec_znx_big_sub_small2`, `vec_znx_big_normalize_base2k_tmp_bytes`, `vec_znx_big_normalize_base2k`, `vec_znx_big_range_normalize_base2k`, `vec_znx_big_range_normalize_base2k_tmp_bytes`, `vec_znx_big_automorphism` and `vec_znx_big_rotate`, all over `*const MODULE` / `VEC_ZNX_BIG` pointers.
FFI bindings for `vec_znx_dft` — @@ -1,4 +1,4 @@ and @@ -7,19 +7,6 @@: the import becomes `use crate::cpu_spqlios::ffi::{module::MODULE, vec_znx_big::VEC_ZNX_BIG};`, and the unused declarations `bytes_of_vec_znx_dft`, `new_vec_znx_dft`, `delete_vec_znx_dft` and `vec_dft_zero` are removed; `vec_dft_add` and the rest of the file are unchanged.
FFI bindings for `vmp` — @@ -1,4 +1,4 @@, @@ -9,16 +9,7 @@, @@ -34,6 +25,7 @@, @@ -50,6 +42,7 @@ and @@ -105,10 +98,6 @@: the import becomes `use crate::cpu_spqlios::ffi::{module::MODULE, vec_znx_dft::VEC_ZNX_DFT};`; the unused `bytes_of_vmp_pmat`, `new_vmp_pmat`, `delete_vmp_pmat` and `vmp_prepare_contiguous_dft` declarations are removed; and `#[allow(dead_code)]` is added in front of the `vmp_apply_dft`, `vmp_apply_dft_add` and `vmp_apply_dft_tmp_bytes` blocks. The `VMP_PMAT` layout comment (`// [rows][cols] = [#Decomposition][#Limbs]`) and `vmp_prepare_tmp_bytes` are unchanged.
poulpy-backend/src/cpu_spqlios/ffi/znx.rs — new file (7 lines):

```rust
unsafe extern "C" {
    pub unsafe fn znx_rotate_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
}

unsafe extern "C" {
    pub unsafe fn znx_rotate_inplace_i64(nn: u64, p: i64, res: *mut i64);
}
```
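These two declarations only expose spqlios' rotation kernels. As a point of reference, the following self-contained sketch (not Poulpy code) shows what multiplying a polynomial of Z[X]/(X^n + 1) by X^p looks like, assuming the usual negacyclic sign convention; whether `znx_rotate_i64` uses exactly this convention is not shown in this commit.

```rust
// Negacyclic rotation: coefficients that wrap past X^n pick up a sign flip.
fn negacyclic_rotate(coeffs: &[i64], p: i64) -> Vec<i64> {
    let n = coeffs.len() as i64;
    let mut out = vec![0i64; coeffs.len()];
    for (i, &c) in coeffs.iter().enumerate() {
        // X^i * X^p = X^(i+p); reduce the exponent mod 2n, then mod n with a sign.
        let e = (i as i64 + p).rem_euclid(2 * n);
        if e < n {
            out[e as usize] += c;
        } else {
            out[(e - n) as usize] -= c;
        }
    }
    out
}

fn main() {
    // (1 + 2X + 3X^2 + 4X^3) * X^2 = -3 - 4X + X^2 + 2X^3  (mod X^4 + 1)
    assert_eq!(negacyclic_rotate(&[1, 2, 3, 4], 2), vec![-3, -4, 1, 2]);
}
```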
poulpy-backend/src/cpu_spqlios/fft64/mod.rs — new file (15 lines):

```rust
mod module;
mod scratch;
mod svp_ppol;
mod vec_znx;
mod vec_znx_big;
mod vec_znx_dft;
mod vmp_pmat;

pub use module::FFT64;

/// For external documentation
pub use vec_znx::{
    vec_znx_copy_ref, vec_znx_lsh_inplace_ref, vec_znx_merge_ref, vec_znx_rsh_inplace_ref, vec_znx_split_ref,
    vec_znx_switch_degree_ref,
};
```
FFT64 backend `module` implementation — @@ -1,25 +1,29 @@: the imports move to `poulpy_hal`, the `CPUAVX` marker-trait impl is dropped, and the `Backend` impl for `FFT64` gains the new associated scalar types and layout word counts:

```rust
use std::ptr::NonNull;

use poulpy_hal::{
    layouts::{Backend, Module},
    oep::ModuleNewImpl,
};

use crate::cpu_spqlios::ffi::module::{MODULE, delete_module_info, new_module_info};

pub struct FFT64;

impl Backend for FFT64 {
    type ScalarPrep = f64;
    type ScalarBig = i64;
    type Handle = MODULE;
    unsafe fn destroy(handle: NonNull<Self::Handle>) {
        unsafe { delete_module_info(handle.as_ptr()) }
    }

    fn layout_big_word_count() -> usize {
        1
    }

    fn layout_prep_word_count() -> usize {
        1
    }
}

unsafe impl ModuleNewImpl<Self> for FFT64 {
    // ... unchanged below this point in the hunk ...
```
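The `Handle` / `destroy` pair above is the standard way of owning an opaque C handle from Rust: the raw pointer returned by the foreign library is kept behind `NonNull` and released exactly once. The self-contained toy below (not Poulpy code) shows the shape of that pattern, with a fake handle type standing in for spqlios' `MODULE` and `new_module_info` / `delete_module_info`:

```rust
use std::ptr::NonNull;

struct FakeHandle {
    n: u64,
}

// Stand-ins for the foreign allocation/free functions.
fn fake_new(n: u64) -> *mut FakeHandle {
    Box::into_raw(Box::new(FakeHandle { n }))
}

unsafe fn fake_delete(ptr: *mut FakeHandle) {
    drop(unsafe { Box::from_raw(ptr) });
}

struct ModuleWrapper {
    handle: NonNull<FakeHandle>,
}

impl ModuleWrapper {
    fn new(n: u64) -> Self {
        let raw = fake_new(n);
        Self {
            handle: NonNull::new(raw).expect("allocation failed"),
        }
    }

    fn n(&self) -> u64 {
        unsafe { self.handle.as_ref().n }
    }
}

impl Drop for ModuleWrapper {
    fn drop(&mut self) {
        // Released exactly once, mirroring Backend::destroy above.
        unsafe { fake_delete(self.handle.as_ptr()) }
    }
}

fn main() {
    let m = ModuleWrapper::new(1024);
    assert_eq!(m.n(), 1024);
}
```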
FFT64 backend scratch-space implementations — @@ -1,24 +1,20 @@ and the following hunks: the imports move from `crate::hal` to `poulpy_hal`, and every blanket impl of the form

```rust
unsafe impl<B: Backend> ScratchOwnedAllocImpl<B> for B
where
    B: CPUAVX,
```

becomes a concrete impl on the FFT64 backend, e.g.

```rust
unsafe impl<B: Backend> ScratchOwnedAllocImpl<B> for FFT64 {
```

with the `CPUAVX` bound replaced by the precise OEP bounds each impl actually needs (`ScratchFromBytesImpl<B>`, plus `SvpPPolAllocBytesImpl<B>`, `VecZnxBigAllocBytesImpl<B>`, `VecZnxDftAllocBytesImpl<B>`, `VmpPMatAllocBytesImpl<B>`, `TakeVecZnxImpl<B>` or `TakeVecZnxDftImpl<B>` where relevant). The affected impls are `ScratchOwnedAllocImpl`, `ScratchOwnedBorrowImpl`, `ScratchFromBytesImpl`, `ScratchAvailableImpl`, `TakeSliceImpl`, `TakeScalarZnxImpl`, `TakeSvpPPolImpl`, `TakeVecZnxImpl`, `TakeVecZnxBigImpl`, `TakeVecZnxDftImpl`, `TakeVecZnxDftSliceImpl`, `TakeVecZnxSliceImpl`, `TakeVmpPMatImpl` and `TakeMatZnxImpl`; their method bodies (aligned allocation via `alloc_aligned`, `take_slice_aligned`, and the various `*::alloc_bytes` size computations) are unchanged.
FFT64 backend `SvpPPol`/SVP implementations — @@ -1,35 +1,16 @@ and @@ -45,7 +26,7 @@: the imports move to `poulpy_hal` and `crate::cpu_spqlios::{FFT64, ffi::{svp, vec_znx_dft::vec_znx_dft_t}}`; the local `SVP_PPOL_FFT64_WORD_SIZE` constant and the `SvpPPolBytesOf`, `ZnxSliceSize` and `ZnxView` impls for `SvpPPol<D, FFT64>` are removed from this file; and `svp_ppol_alloc_bytes_impl` now derives the size from the backend's layout constant:

```rust
unsafe impl SvpPPolAllocBytesImpl<Self> for FFT64 {
    fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize {
        FFT64::layout_prep_word_count() * n * cols * size_of::<f64>()
    }
}
```

The remaining impls listed in the `oep` import (`SvpPPolFromBytesImpl`, `SvpPPolAllocImpl`, `SvpApplyImpl`, `SvpApplyInplaceImpl`, `SvpPrepareImpl`) fall outside the hunks shown in this view.
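Since `layout_prep_word_count()` returns 1 for FFT64 (see the `Backend` impl above) and the prepared limbs are `f64`, the buffer size works out to n * cols * 8 bytes. A tiny standalone check (illustrative only, not Poulpy code):

```rust
// Mirrors svp_ppol_alloc_bytes_impl for FFT64 outside the library.
fn fft64_svp_ppol_bytes(n: usize, cols: usize) -> usize {
    let prep_words = 1; // FFT64::layout_prep_word_count()
    prep_words * n * cols * std::mem::size_of::<f64>()
}

fn main() {
    assert_eq!(fft64_svp_ppol_bytes(1024, 2), 16 * 1024);
}
```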
@@ -1,48 +1,47 @@
|
|||||||
use itertools::izip;
|
use itertools::izip;
|
||||||
use rand_distr::Normal;
|
use rand_distr::Normal;
|
||||||
|
|
||||||
use crate::{
|
use poulpy_hal::{
|
||||||
hal::{
|
api::{
|
||||||
api::{
|
TakeSlice, TakeVecZnx, VecZnxAddDistF64, VecZnxCopy, VecZnxFillDistF64, VecZnxNormalizeTmpBytes, VecZnxRotate,
|
||||||
TakeSlice, TakeVecZnx, VecZnxAddDistF64, VecZnxCopy, VecZnxFillDistF64, VecZnxNormalizeTmpBytes, VecZnxRotate,
|
VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
|
||||||
VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
|
|
||||||
},
|
|
||||||
layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef},
|
|
||||||
oep::{
|
|
||||||
VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
|
|
||||||
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl,
|
|
||||||
VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl,
|
|
||||||
VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
|
|
||||||
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
|
|
||||||
VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl,
|
|
||||||
VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
|
|
||||||
},
|
|
||||||
source::Source,
|
|
||||||
},
|
},
|
||||||
implementation::cpu_spqlios::{
|
layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef},
|
||||||
CPUAVX,
|
oep::{
|
||||||
ffi::{module::module_info_t, vec_znx, znx},
|
TakeSliceImpl, TakeVecZnxImpl, VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl,
|
||||||
|
VecZnxAddScalarInplaceImpl, VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl,
|
||||||
|
VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl,
|
||||||
|
VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
|
||||||
|
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl,
|
||||||
|
VecZnxSplitImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl,
|
||||||
|
VecZnxSwithcDegreeImpl,
|
||||||
},
|
},
|
||||||
|
source::Source,
|
||||||
};
|
};
|
||||||
|
|
||||||
unsafe impl<B: Backend> VecZnxNormalizeTmpBytesImpl<B> for B
|
use crate::cpu_spqlios::{
|
||||||
where
|
FFT64,
|
||||||
B: CPUAVX,
|
ffi::{module::module_info_t, vec_znx, znx},
|
||||||
{
|
};
|
||||||
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<B>, n: usize) -> usize {
|
|
||||||
|
unsafe impl VecZnxNormalizeTmpBytesImpl<Self> for FFT64 {
|
||||||
|
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<Self>, n: usize) -> usize {
|
||||||
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr() as *const module_info_t, n as u64) as usize }
|
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr() as *const module_info_t, n as u64) as usize }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl<B: Backend + CPUAVX> VecZnxNormalizeImpl<B> for B {
|
unsafe impl VecZnxNormalizeImpl<Self> for FFT64
|
||||||
|
where
|
||||||
|
Self: TakeSliceImpl<Self> + VecZnxNormalizeTmpBytesImpl<Self>,
|
||||||
|
{
|
||||||
fn vec_znx_normalize_impl<R, A>(
|
fn vec_znx_normalize_impl<R, A>(
|
||||||
module: &Module<B>,
|
module: &Module<Self>,
|
||||||
basek: usize,
|
basek: usize,
|
||||||
res: &mut R,
|
res: &mut R,
|
||||||
res_col: usize,
|
res_col: usize,
|
||||||
a: &A,
|
a: &A,
|
||||||
a_col: usize,
|
a_col: usize,
|
||||||
scratch: &mut Scratch<B>,
|
scratch: &mut Scratch<Self>,
|
||||||
) where
|
) where
|
||||||
R: VecZnxToMut,
|
R: VecZnxToMut,
|
||||||
A: VecZnxToRef,
|
A: VecZnxToRef,
|
||||||
@@ -74,9 +73,17 @@ unsafe impl<B: Backend + CPUAVX> VecZnxNormalizeImpl<B> for B {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl<B: Backend + CPUAVX> VecZnxNormalizeInplaceImpl<B> for B {
|
unsafe impl VecZnxNormalizeInplaceImpl<Self> for FFT64
|
||||||
fn vec_znx_normalize_inplace_impl<A>(module: &Module<B>, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
|
where
|
||||||
where
|
Self: TakeSliceImpl<Self> + VecZnxNormalizeTmpBytesImpl<Self>,
|
||||||
|
{
|
||||||
|
fn vec_znx_normalize_inplace_impl<A>(
|
||||||
|
module: &Module<Self>,
|
||||||
|
basek: usize,
|
||||||
|
a: &mut A,
|
||||||
|
a_col: usize,
|
||||||
|
scratch: &mut Scratch<Self>,
|
||||||
|
) where
|
||||||
A: VecZnxToMut,
|
A: VecZnxToMut,
|
||||||
{
|
{
|
||||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||||
@@ -100,8 +107,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxNormalizeInplaceImpl<B> for B {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
 
-unsafe impl<B: Backend + CPUAVX> VecZnxAddImpl<B> for B {
-    fn vec_znx_add_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
+unsafe impl VecZnxAddImpl<Self> for FFT64 {
+    fn vec_znx_add_impl<R, A, C>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
     where
         R: VecZnxToMut,
         A: VecZnxToRef,
@@ -134,8 +141,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxAddImpl<B> for B {
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxAddInplaceImpl<B> for B {
-    fn vec_znx_add_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+unsafe impl VecZnxAddInplaceImpl<Self> for FFT64 {
+    fn vec_znx_add_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
         R: VecZnxToMut,
         A: VecZnxToRef,
@@ -164,9 +171,9 @@ unsafe impl<B: Backend + CPUAVX> VecZnxAddInplaceImpl<B> for B {
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxAddScalarInplaceImpl<B> for B {
+unsafe impl VecZnxAddScalarInplaceImpl<Self> for FFT64 {
     fn vec_znx_add_scalar_inplace_impl<R, A>(
-        module: &Module<B>,
+        module: &Module<Self>,
         res: &mut R,
         res_col: usize,
         res_limb: usize,
@@ -201,8 +208,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxAddScalarInplaceImpl<B> for B {
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxSubImpl<B> for B {
-    fn vec_znx_sub_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
+unsafe impl VecZnxSubImpl<Self> for FFT64 {
+    fn vec_znx_sub_impl<R, A, C>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
     where
         R: VecZnxToMut,
         A: VecZnxToRef,
@@ -235,8 +242,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxSubImpl<B> for B {
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxSubABInplaceImpl<B> for B {
-    fn vec_znx_sub_ab_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+unsafe impl VecZnxSubABInplaceImpl<Self> for FFT64 {
+    fn vec_znx_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
         R: VecZnxToMut,
         A: VecZnxToRef,
@@ -264,8 +271,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxSubABInplaceImpl<B> for B {
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxSubBAInplaceImpl<B> for B {
-    fn vec_znx_sub_ba_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+unsafe impl VecZnxSubBAInplaceImpl<Self> for FFT64 {
+    fn vec_znx_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
         R: VecZnxToMut,
         A: VecZnxToRef,
@@ -293,9 +300,9 @@ unsafe impl<B: Backend + CPUAVX> VecZnxSubBAInplaceImpl<B> for B {
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxSubScalarInplaceImpl<B> for B {
+unsafe impl VecZnxSubScalarInplaceImpl<Self> for FFT64 {
     fn vec_znx_sub_scalar_inplace_impl<R, A>(
-        module: &Module<B>,
+        module: &Module<Self>,
         res: &mut R,
         res_col: usize,
         res_limb: usize,
@@ -330,8 +337,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxSubScalarInplaceImpl<B> for B {
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxNegateImpl<B> for B {
-    fn vec_znx_negate_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+unsafe impl VecZnxNegateImpl<Self> for FFT64 {
+    fn vec_znx_negate_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
@@ -356,8 +363,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxNegateImpl<B> for B {
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxNegateInplaceImpl<B> for B {
-    fn vec_znx_negate_inplace_impl<A>(module: &Module<B>, a: &mut A, a_col: usize)
+unsafe impl VecZnxNegateInplaceImpl<Self> for FFT64 {
+    fn vec_znx_negate_inplace_impl<A>(module: &Module<Self>, a: &mut A, a_col: usize)
@@ -376,8 +383,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxNegateInplaceImpl<B> for B {
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxLshInplaceImpl<B> for B {
-    fn vec_znx_lsh_inplace_impl<A>(_module: &Module<B>, basek: usize, k: usize, a: &mut A)
+unsafe impl VecZnxLshInplaceImpl<Self> for FFT64 {
+    fn vec_znx_lsh_inplace_impl<A>(_module: &Module<Self>, basek: usize, k: usize, a: &mut A)
@@ -417,8 +424,8 @@ where
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxRshInplaceImpl<B> for B {
-    fn vec_znx_rsh_inplace_impl<A>(_module: &Module<B>, basek: usize, k: usize, a: &mut A)
+unsafe impl VecZnxRshInplaceImpl<Self> for FFT64 {
+    fn vec_znx_rsh_inplace_impl<A>(_module: &Module<Self>, basek: usize, k: usize, a: &mut A)
@@ -461,8 +468,8 @@ where
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxRotateImpl<B> for B {
-    fn vec_znx_rotate_impl<R, A>(_module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
+unsafe impl VecZnxRotateImpl<Self> for FFT64 {
+    fn vec_znx_rotate_impl<R, A>(_module: &Module<Self>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
@@ -486,8 +493,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxRotateImpl<B> for B {
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxRotateInplaceImpl<B> for B {
-    fn vec_znx_rotate_inplace_impl<A>(_module: &Module<B>, k: i64, a: &mut A, a_col: usize)
+unsafe impl VecZnxRotateInplaceImpl<Self> for FFT64 {
+    fn vec_znx_rotate_inplace_impl<A>(_module: &Module<Self>, k: i64, a: &mut A, a_col: usize)
@@ -500,8 +507,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxRotateInplaceImpl<B> for B {
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxAutomorphismImpl<B> for B {
-    fn vec_znx_automorphism_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
+unsafe impl VecZnxAutomorphismImpl<Self> for FFT64 {
+    fn vec_znx_automorphism_impl<R, A>(module: &Module<Self>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
@@ -527,8 +534,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxAutomorphismImpl<B> for B {
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxAutomorphismInplaceImpl<B> for B {
-    fn vec_znx_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
+unsafe impl VecZnxAutomorphismInplaceImpl<Self> for FFT64 {
+    fn vec_znx_automorphism_inplace_impl<A>(module: &Module<Self>, k: i64, a: &mut A, a_col: usize)
@@ -556,8 +563,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxAutomorphismInplaceImpl<B> for B {
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxMulXpMinusOneImpl<B> for B {
-    fn vec_znx_mul_xp_minus_one_impl<R, A>(module: &Module<B>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
+unsafe impl VecZnxMulXpMinusOneImpl<Self> for FFT64 {
+    fn vec_znx_mul_xp_minus_one_impl<R, A>(module: &Module<Self>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
@@ -584,8 +591,8 @@ unsafe impl<B: Backend + CPUAVX> VecZnxMulXpMinusOneImpl<B> for B {
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxMulXpMinusOneInplaceImpl<B> for B {
-    fn vec_znx_mul_xp_minus_one_inplace_impl<R>(module: &Module<B>, p: i64, res: &mut R, res_col: usize)
+unsafe impl VecZnxMulXpMinusOneInplaceImpl<Self> for FFT64 {
+    fn vec_znx_mul_xp_minus_one_inplace_impl<R>(module: &Module<Self>, p: i64, res: &mut R, res_col: usize)
@@ -609,9 +616,22 @@ unsafe impl<B: Backend + CPUAVX> VecZnxMulXpMinusOneInplaceImpl<B> for B {
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxSplitImpl<B> for B {
-    fn vec_znx_split_impl<R, A>(module: &Module<B>, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
-    where
+unsafe impl VecZnxSplitImpl<Self> for FFT64
+where
+    Self: TakeVecZnxImpl<Self>
+        + TakeVecZnxImpl<Self>
+        + VecZnxSwithcDegreeImpl<Self>
+        + VecZnxRotateImpl<Self>
+        + VecZnxRotateInplaceImpl<Self>,
+{
+    fn vec_znx_split_impl<R, A>(
+        module: &Module<Self>,
+        res: &mut [R],
+        res_col: usize,
+        a: &A,
+        a_col: usize,
+        scratch: &mut Scratch<Self>,
+    ) where
         R: VecZnxToMut,
         A: VecZnxToRef,
     {
@@ -627,7 +647,7 @@ pub fn vec_znx_split_ref<R, A, B>(
     a_col: usize,
     scratch: &mut Scratch<B>,
 ) where
-    B: Backend + CPUAVX,
+    B: Backend + TakeVecZnxImpl<B> + VecZnxSwithcDegreeImpl<B> + VecZnxRotateImpl<B> + VecZnxRotateInplaceImpl<B>,
     R: VecZnxToMut,
     A: VecZnxToRef,
 {
@@ -660,8 +680,11 @@ pub fn vec_znx_split_ref<R, A, B>(
     })
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxMergeImpl<B> for B {
-    fn vec_znx_merge_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &[A], a_col: usize)
+unsafe impl VecZnxMergeImpl<Self> for FFT64
+where
+    Self: VecZnxSwithcDegreeImpl<Self> + VecZnxRotateInplaceImpl<Self>,
+{
+    fn vec_znx_merge_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &[A], a_col: usize)
     where
         R: VecZnxToMut,
         A: VecZnxToRef,
@@ -672,7 +695,7 @@ unsafe impl<B: Backend + CPUAVX> VecZnxMergeImpl<B> for B {
 
 pub fn vec_znx_merge_ref<R, A, B>(module: &Module<B>, res: &mut R, res_col: usize, a: &[A], a_col: usize)
 where
-    B: Backend + CPUAVX,
+    B: Backend + VecZnxSwithcDegreeImpl<B> + VecZnxRotateInplaceImpl<B>,
     R: VecZnxToMut,
     A: VecZnxToRef,
 {
@@ -700,8 +723,11 @@ where
     module.vec_znx_rotate_inplace(a.len() as i64, &mut res, res_col);
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxSwithcDegreeImpl<B> for B {
-    fn vec_znx_switch_degree_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+unsafe impl VecZnxSwithcDegreeImpl<Self> for FFT64
+where
+    Self: VecZnxCopyImpl<Self>,
+{
+    fn vec_znx_switch_degree_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
         R: VecZnxToMut,
         A: VecZnxToRef,
@@ -712,7 +738,7 @@ unsafe impl<B: Backend + CPUAVX> VecZnxSwithcDegreeImpl<B> for B {
 
 pub fn vec_znx_switch_degree_ref<R, A, B>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
 where
-    B: Backend + CPUAVX,
+    B: Backend + VecZnxCopyImpl<B>,
     R: VecZnxToMut,
     A: VecZnxToRef,
 {
@@ -745,8 +771,8 @@ where
     });
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxCopyImpl<B> for B {
-    fn vec_znx_copy_impl<R, A>(_module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+unsafe impl VecZnxCopyImpl<Self> for FFT64 {
+    fn vec_znx_copy_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
         R: VecZnxToMut,
         A: VecZnxToRef,
@@ -775,9 +801,15 @@ where
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxFillUniformImpl<B> for B {
-    fn vec_znx_fill_uniform_impl<R>(_module: &Module<B>, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
-    where
+unsafe impl VecZnxFillUniformImpl<Self> for FFT64 {
+    fn vec_znx_fill_uniform_impl<R>(
+        _module: &Module<Self>,
+        basek: usize,
+        res: &mut R,
+        res_col: usize,
+        k: usize,
+        source: &mut Source,
+    ) where
         R: VecZnxToMut,
     {
         let mut a: VecZnx<&mut [u8]> = res.to_mut();
@@ -792,9 +824,9 @@ unsafe impl<B: Backend + CPUAVX> VecZnxFillUniformImpl<B> for B {
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxFillDistF64Impl<B> for B {
+unsafe impl VecZnxFillDistF64Impl<Self> for FFT64 {
     fn vec_znx_fill_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
-        _module: &Module<B>,
+        _module: &Module<Self>,
         basek: usize,
         res: &mut R,
         res_col: usize,
@@ -835,9 +867,9 @@ unsafe impl<B: Backend + CPUAVX> VecZnxFillDistF64Impl<B> for B {
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxAddDistF64Impl<B> for B {
+unsafe impl VecZnxAddDistF64Impl<Self> for FFT64 {
     fn vec_znx_add_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
-        _module: &Module<B>,
+        _module: &Module<Self>,
         basek: usize,
         res: &mut R,
         res_col: usize,
@@ -878,9 +910,12 @@ unsafe impl<B: Backend + CPUAVX> VecZnxAddDistF64Impl<B> for B {
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxFillNormalImpl<B> for B {
+unsafe impl VecZnxFillNormalImpl<Self> for FFT64
+where
+    Self: VecZnxFillDistF64Impl<Self>,
+{
     fn vec_znx_fill_normal_impl<R>(
-        module: &Module<B>,
+        module: &Module<Self>,
         basek: usize,
         res: &mut R,
         res_col: usize,
@@ -903,9 +938,12 @@ unsafe impl<B: Backend + CPUAVX> VecZnxFillNormalImpl<B> for B {
     }
 }
 
-unsafe impl<B: Backend + CPUAVX> VecZnxAddNormalImpl<B> for B {
+unsafe impl VecZnxAddNormalImpl<Self> for FFT64
+where
+    Self: VecZnxAddDistF64Impl<Self>,
+{
     fn vec_znx_add_normal_impl<R>(
-        module: &Module<B>,
+        module: &Module<Self>,
         basek: usize,
         res: &mut R,
         res_col: usize,
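The last two hunks above gate the Gaussian samplers on the generic f64-distribution primitives (`where Self: VecZnxFillDistF64Impl<Self>` and `VecZnxAddDistF64Impl<Self>`), i.e. normal filling is layered on top of "fill from any Distribution<f64>". A standalone sketch of that layering, using the real rand/rand_distr crates but otherwise placeholder names:

// Toy sketch: a generic "fill from any Distribution<f64>" primitive, with a
// normal-noise helper built on top of it. Buffer handling is simplified to a
// plain f64 slice; only the layering of the two functions mirrors the hunks above.
use rand::{SeedableRng, rngs::StdRng};
use rand_distr::{Distribution, Normal};

fn fill_dist_f64<D: Distribution<f64>>(buf: &mut [f64], rng: &mut StdRng, dist: D) {
    buf.iter_mut().for_each(|x| *x = dist.sample(rng));
}

fn fill_normal(buf: &mut [f64], rng: &mut StdRng, sigma: f64) {
    // The generic primitive does the work; the normal variant only picks the distribution.
    fill_dist_f64(buf, rng, Normal::new(0.0, sigma).expect("sigma must be finite and non-negative"));
}

fn main() {
    let mut buf = [0.0f64; 8];
    let mut rng = StdRng::seed_from_u64(0);
    fill_normal(&mut buf, &mut rng, 3.2);
    assert!(buf.iter().any(|&x| x != 0.0));
}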
@@ -1,69 +1,46 @@
-use std::fmt;
-
 use rand_distr::{Distribution, Normal};
 
-use crate::{
-    hal::{
-        api::{
-            TakeSlice, VecZnxBigAddDistF64, VecZnxBigFillDistF64, VecZnxBigNormalizeTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView,
-            ZnxViewMut,
-        },
-        layouts::{
-            Data, DataRef, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigBytesOf, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef,
-            VecZnxToMut, VecZnxToRef,
-        },
-        oep::{
-            VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl,
-            VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
-            VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl, VecZnxBigFromBytesImpl,
-            VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl,
-            VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl,
-            VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
-        },
-        source::Source,
-    },
-    implementation::cpu_spqlios::{ffi::vec_znx, module_fft64::FFT64},
-};
+use crate::cpu_spqlios::{FFT64, ffi::vec_znx};
+use poulpy_hal::{
+    api::{
+        TakeSlice, VecZnxBigAddDistF64, VecZnxBigFillDistF64, VecZnxBigNormalizeTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView,
+        ZnxViewMut,
+    },
+    layouts::{
+        Backend, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef,
+    },
+    oep::{
+        TakeSliceImpl, VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl,
+        VecZnxBigAddSmallImpl, VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl,
+        VecZnxBigAutomorphismImpl, VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl,
+        VecZnxBigFromBytesImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl,
+        VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl,
+        VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
+    },
+    source::Source,
+};
 
-const VEC_ZNX_BIG_FFT64_WORDSIZE: usize = 1;
-
-impl<D: DataRef> ZnxView for VecZnxBig<D, FFT64> {
-    type Scalar = i64;
-}
-
-impl<D: Data> VecZnxBigBytesOf for VecZnxBig<D, FFT64> {
-    fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
-        VEC_ZNX_BIG_FFT64_WORDSIZE * n * cols * size * size_of::<f64>()
-    }
-}
-
-impl<D: Data> ZnxSliceSize for VecZnxBig<D, FFT64> {
-    fn sl(&self) -> usize {
-        VEC_ZNX_BIG_FFT64_WORDSIZE * self.n() * self.cols()
-    }
-}
-
-unsafe impl VecZnxBigAllocImpl<FFT64> for FFT64 {
-    fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned<FFT64> {
-        VecZnxBig::<Vec<u8>, FFT64>::new(n, cols, size)
-    }
-}
-
-unsafe impl VecZnxBigFromBytesImpl<FFT64> for FFT64 {
-    fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<FFT64> {
-        VecZnxBig::<Vec<u8>, FFT64>::new_from_bytes(n, cols, size, bytes)
-    }
-}
-
-unsafe impl VecZnxBigAllocBytesImpl<FFT64> for FFT64 {
+unsafe impl VecZnxBigAllocBytesImpl<Self> for FFT64 {
     fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
-        VecZnxBig::<Vec<u8>, FFT64>::bytes_of(n, cols, size)
+        Self::layout_big_word_count() * n * cols * size * size_of::<f64>()
     }
 }
 
+unsafe impl VecZnxBigAllocImpl<Self> for FFT64 {
+    fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned<Self> {
+        VecZnxBig::alloc(n, cols, size)
+    }
+}
+
+unsafe impl VecZnxBigFromBytesImpl<Self> for FFT64 {
+    fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<Self> {
+        VecZnxBig::from_bytes(n, cols, size, bytes)
+    }
+}
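Replacing the call to `VecZnxBig::<Vec<u8>, FFT64>::bytes_of` with an explicit formula makes the buffer-size rule visible at the implementation site. A small sketch of that rule (assuming, as the old VEC_ZNX_BIG_FFT64_WORDSIZE constant suggests, a word count of 1 for the FFT64 backend; the function name is a placeholder):

// Bytes needed for a big-coefficient vector: word_count words of f64 per
// coefficient, n coefficients per limb, `cols` columns and `size` limbs.
fn vec_znx_big_bytes(word_count: usize, n: usize, cols: usize, size: usize) -> usize {
    word_count * n * cols * size * std::mem::size_of::<f64>()
}

fn main() {
    // e.g. n = 1024, 2 columns, 3 limbs -> 1 * 1024 * 2 * 3 * 8 = 49_152 bytes
    assert_eq!(vec_znx_big_bytes(1, 1024, 2, 3), 49_152);
}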
 
-unsafe impl VecZnxBigAddDistF64Impl<FFT64> for FFT64 {
-    fn add_dist_f64_impl<R: VecZnxBigToMut<FFT64>, D: Distribution<f64>>(
-        _module: &Module<FFT64>,
+unsafe impl VecZnxBigAddDistF64Impl<Self> for FFT64 {
+    fn add_dist_f64_impl<R: VecZnxBigToMut<Self>, D: Distribution<f64>>(
+        _module: &Module<Self>,
         basek: usize,
         res: &mut R,
         res_col: usize,
@@ -72,7 +49,7 @@ unsafe impl VecZnxBigAddDistF64Impl<FFT64> for FFT64 {
         dist: D,
         bound: f64,
     ) {
-        let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
+        let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
         assert!(
             (bound.log2().ceil() as i64) < 64,
             "invalid bound: ceil(log2(bound))={} > 63",
@@ -102,9 +79,9 @@ unsafe impl VecZnxBigAddDistF64Impl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxBigAddNormalImpl<FFT64> for FFT64 {
-    fn add_normal_impl<R: VecZnxBigToMut<FFT64>>(
-        module: &Module<FFT64>,
+unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64 {
+    fn add_normal_impl<R: VecZnxBigToMut<Self>>(
+        module: &Module<Self>,
         basek: usize,
         res: &mut R,
         res_col: usize,
@@ -125,9 +102,9 @@ unsafe impl VecZnxBigAddNormalImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxBigFillDistF64Impl<FFT64> for FFT64 {
-    fn fill_dist_f64_impl<R: VecZnxBigToMut<FFT64>, D: Distribution<f64>>(
-        _module: &Module<FFT64>,
+unsafe impl VecZnxBigFillDistF64Impl<Self> for FFT64 {
+    fn fill_dist_f64_impl<R: VecZnxBigToMut<Self>, D: Distribution<f64>>(
+        _module: &Module<Self>,
         basek: usize,
         res: &mut R,
         res_col: usize,
@@ -136,7 +113,7 @@ unsafe impl VecZnxBigFillDistF64Impl<FFT64> for FFT64 {
         dist: D,
         bound: f64,
     ) {
-        let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
+        let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
@@ -166,9 +143,9 @@ unsafe impl VecZnxBigFillDistF64Impl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxBigFillNormalImpl<FFT64> for FFT64 {
-    fn fill_normal_impl<R: VecZnxBigToMut<FFT64>>(
-        module: &Module<FFT64>,
+unsafe impl VecZnxBigFillNormalImpl<Self> for FFT64 {
+    fn fill_normal_impl<R: VecZnxBigToMut<Self>>(
+        module: &Module<Self>,
         basek: usize,
         res: &mut R,
         res_col: usize,
@@ -189,24 +166,17 @@ unsafe impl VecZnxBigFillNormalImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxBigAddImpl<FFT64> for FFT64 {
+unsafe impl VecZnxBigAddImpl<Self> for FFT64 {
     /// Adds `a` to `b` and stores the result on `c`.
-    fn vec_znx_big_add_impl<R, A, B>(
-        module: &Module<FFT64>,
-        res: &mut R,
-        res_col: usize,
-        a: &A,
-        a_col: usize,
-        b: &B,
-        b_col: usize,
-    ) where
-        R: VecZnxBigToMut<FFT64>,
-        A: VecZnxBigToRef<FFT64>,
-        B: VecZnxBigToRef<FFT64>,
+    fn vec_znx_big_add_impl<R, A, B>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxBigToMut<Self>,
+        A: VecZnxBigToRef<Self>,
+        B: VecZnxBigToRef<Self>,
     {
-        let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
-        let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
-        let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
+        let a: VecZnxBig<&[u8], Self> = a.to_ref();
+        let b: VecZnxBig<&[u8], Self> = b.to_ref();
+        let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
 
         #[cfg(debug_assertions)]
         {
@@ -231,15 +201,15 @@ unsafe impl VecZnxBigAddImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxBigAddInplaceImpl<FFT64> for FFT64 {
+unsafe impl VecZnxBigAddInplaceImpl<Self> for FFT64 {
     /// Adds `a` to `b` and stores the result on `b`.
-    fn vec_znx_big_add_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    fn vec_znx_big_add_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
-        R: VecZnxBigToMut<FFT64>,
-        A: VecZnxBigToRef<FFT64>,
+        R: VecZnxBigToMut<Self>,
+        A: VecZnxBigToRef<Self>,
@@ -262,10 +232,10 @@ unsafe impl VecZnxBigAddInplaceImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxBigAddSmallImpl<FFT64> for FFT64 {
+unsafe impl VecZnxBigAddSmallImpl<Self> for FFT64 {
     /// Adds `a` to `b` and stores the result on `c`.
     fn vec_znx_big_add_small_impl<R, A, B>(
-        module: &Module<FFT64>,
+        module: &Module<Self>,
@@ -273,13 +243,13 @@ unsafe impl VecZnxBigAddSmallImpl<FFT64> for FFT64 {
     ) where
-        R: VecZnxBigToMut<FFT64>,
-        A: VecZnxBigToRef<FFT64>,
+        R: VecZnxBigToMut<Self>,
+        A: VecZnxBigToRef<Self>,
         B: VecZnxToRef,
@@ -304,15 +274,15 @@ unsafe impl VecZnxBigAddSmallImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxBigAddSmallInplaceImpl<FFT64> for FFT64 {
+unsafe impl VecZnxBigAddSmallInplaceImpl<Self> for FFT64 {
     /// Adds `a` to `b` and stores the result on `b`.
-    fn vec_znx_big_add_small_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    fn vec_znx_big_add_small_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
-        R: VecZnxBigToMut<FFT64>,
+        R: VecZnxBigToMut<Self>,
         A: VecZnxToRef,
@@ -335,24 +305,17 @@ unsafe impl VecZnxBigAddSmallInplaceImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxBigSubImpl<FFT64> for FFT64 {
+unsafe impl VecZnxBigSubImpl<Self> for FFT64 {
     /// Subtracts `b` from `a` and stores the result on `c`.
-    fn vec_znx_big_sub_impl<R, A, B>(
-        module: &Module<FFT64>,
-        res: &mut R,
-        res_col: usize,
-        a: &A,
-        a_col: usize,
-        b: &B,
-        b_col: usize,
-    ) where
-        R: VecZnxBigToMut<FFT64>,
-        A: VecZnxBigToRef<FFT64>,
-        B: VecZnxBigToRef<FFT64>,
+    fn vec_znx_big_sub_impl<R, A, B>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxBigToMut<Self>,
+        A: VecZnxBigToRef<Self>,
+        B: VecZnxBigToRef<Self>,
     {
@@ -377,15 +340,15 @@ unsafe impl VecZnxBigSubImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxBigSubABInplaceImpl<FFT64> for FFT64 {
+unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64 {
     /// Subtracts `a` from `b` and stores the result on `b`.
-    fn vec_znx_big_sub_ab_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    fn vec_znx_big_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
@@ -408,15 +371,15 @@ unsafe impl VecZnxBigSubABInplaceImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxBigSubBAInplaceImpl<FFT64> for FFT64 {
+unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64 {
     /// Subtracts `b` from `a` and stores the result on `b`.
-    fn vec_znx_big_sub_ba_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    fn vec_znx_big_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
@@ -439,10 +402,10 @@ unsafe impl VecZnxBigSubBAInplaceImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxBigSubSmallAImpl<FFT64> for FFT64 {
+unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64 {
     /// Subtracts `b` from `a` and stores the result on `c`.
     fn vec_znx_big_sub_small_a_impl<R, A, B>(
-        module: &Module<FFT64>,
+        module: &Module<Self>,
@@ -450,13 +413,13 @@ unsafe impl VecZnxBigSubSmallAImpl<FFT64> for FFT64 {
     ) where
-        R: VecZnxBigToMut<FFT64>,
+        R: VecZnxBigToMut<Self>,
         A: VecZnxToRef,
-        B: VecZnxBigToRef<FFT64>,
+        B: VecZnxBigToRef<Self>,
@@ -481,15 +444,15 @@ unsafe impl VecZnxBigSubSmallAImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxBigSubSmallAInplaceImpl<FFT64> for FFT64 {
+unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64 {
     /// Subtracts `a` from `res` and stores the result on `res`.
-    fn vec_znx_big_sub_small_a_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    fn vec_znx_big_sub_small_a_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
@@ -512,10 +475,10 @@ unsafe impl VecZnxBigSubSmallAInplaceImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxBigSubSmallBImpl<FFT64> for FFT64 {
+unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64 {
     /// Subtracts `b` from `a` and stores the result on `c`.
     fn vec_znx_big_sub_small_b_impl<R, A, B>(
-        module: &Module<FFT64>,
+        module: &Module<Self>,
@@ -523,13 +486,13 @@ unsafe impl VecZnxBigSubSmallBImpl<FFT64> for FFT64 {
     ) where
-        R: VecZnxBigToMut<FFT64>,
-        A: VecZnxBigToRef<FFT64>,
+        R: VecZnxBigToMut<Self>,
+        A: VecZnxBigToRef<Self>,
         B: VecZnxToRef,
@@ -554,15 +517,15 @@ unsafe impl VecZnxBigSubSmallBImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxBigSubSmallBInplaceImpl<FFT64> for FFT64 {
+unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64 {
     /// Subtracts `res` from `a` and stores the result on `res`.
-    fn vec_znx_big_sub_small_b_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    fn vec_znx_big_sub_small_b_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
@@ -585,12 +548,12 @@ unsafe impl VecZnxBigSubSmallBInplaceImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxBigNegateInplaceImpl<FFT64> for FFT64 {
-    fn vec_znx_big_negate_inplace_impl<A>(module: &Module<FFT64>, a: &mut A, a_col: usize)
+unsafe impl VecZnxBigNegateInplaceImpl<Self> for FFT64 {
+    fn vec_znx_big_negate_inplace_impl<A>(module: &Module<Self>, a: &mut A, a_col: usize)
     where
-        A: VecZnxBigToMut<FFT64>,
+        A: VecZnxBigToMut<Self>,
     {
-        let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
+        let mut a: VecZnxBig<&mut [u8], Self> = a.to_mut();
         unsafe {
             vec_znx::vec_znx_negate(
                 module.ptr(),
@@ -605,26 +568,29 @@ unsafe impl VecZnxBigNegateInplaceImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxBigNormalizeTmpBytesImpl<FFT64> for FFT64 {
-    fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<FFT64>, n: usize) -> usize {
+unsafe impl VecZnxBigNormalizeTmpBytesImpl<Self> for FFT64 {
+    fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<Self>, n: usize) -> usize {
         unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr(), n as u64) as usize }
     }
 }
 
-unsafe impl VecZnxBigNormalizeImpl<FFT64> for FFT64 {
+unsafe impl VecZnxBigNormalizeImpl<Self> for FFT64
+where
+    Self: TakeSliceImpl<Self>,
+{
     fn vec_znx_big_normalize_impl<R, A>(
-        module: &Module<FFT64>,
+        module: &Module<Self>,
         basek: usize,
         res: &mut R,
         res_col: usize,
         a: &A,
         a_col: usize,
-        scratch: &mut Scratch<FFT64>,
+        scratch: &mut Scratch<Self>,
     ) where
         R: VecZnxToMut,
-        A: VecZnxBigToRef<FFT64>,
+        A: VecZnxBigToRef<Self>,
     {
-        let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
+        let a: VecZnxBig<&[u8], Self> = a.to_ref();
         let mut res: VecZnx<&mut [u8]> = res.to_mut();
 
         #[cfg(debug_assertions)]
@@ -650,15 +616,15 @@ unsafe impl VecZnxBigNormalizeImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxBigAutomorphismImpl<FFT64> for FFT64 {
+unsafe impl VecZnxBigAutomorphismImpl<Self> for FFT64 {
     /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
-    fn vec_znx_big_automorphism_impl<R, A>(module: &Module<FFT64>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    fn vec_znx_big_automorphism_impl<R, A>(module: &Module<Self>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
-        R: VecZnxBigToMut<FFT64>,
-        A: VecZnxBigToRef<FFT64>,
+        R: VecZnxBigToMut<Self>,
+        A: VecZnxBigToRef<Self>,
@@ -679,13 +645,13 @@ unsafe impl VecZnxBigAutomorphismImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxBigAutomorphismInplaceImpl<FFT64> for FFT64 {
+unsafe impl VecZnxBigAutomorphismInplaceImpl<Self> for FFT64 {
     /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
-    fn vec_znx_big_automorphism_inplace_impl<A>(module: &Module<FFT64>, k: i64, a: &mut A, a_col: usize)
+    fn vec_znx_big_automorphism_inplace_impl<A>(module: &Module<Self>, k: i64, a: &mut A, a_col: usize)
     where
-        A: VecZnxBigToMut<FFT64>,
+        A: VecZnxBigToMut<Self>,
     {
-        let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
+        let mut a: VecZnxBig<&mut [u8], Self> = a.to_mut();
         unsafe {
             vec_znx::vec_znx_automorphism(
                 module.ptr(),
@@ -700,38 +666,3 @@ unsafe impl VecZnxBigAutomorphismInplaceImpl<FFT64> for FFT64 {
     }
 }
-
-impl<D: DataRef> fmt::Display for VecZnxBig<D, FFT64> {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        writeln!(
-            f,
-            "VecZnxBig(n={}, cols={}, size={})",
-            self.n, self.cols, self.size
-        )?;
-
-        for col in 0..self.cols {
-            writeln!(f, "Column {}:", col)?;
-            for size in 0..self.size {
-                let coeffs = self.at(col, size);
-                write!(f, "  Size {}: [", size)?;
-
-                let max_show = 100;
-                let show_count = coeffs.len().min(max_show);
-
-                for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
-                    if i > 0 {
-                        write!(f, ", ")?;
-                    }
-                    write!(f, "{}", coeff)?;
-                }
-
-                if coeffs.len() > max_show {
-                    write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
-                }
-
-                writeln!(f, "]")?;
-            }
-        }
-        Ok(())
-    }
-}
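The Display implementation deleted above printed at most 100 coefficients per limb, then an ellipsis with the remaining count. The truncation idiom it used, as a standalone sketch (placeholder function name, not part of the crate):

use std::fmt::Write;

// Render a coefficient slice as "[c0, c1, ..., (k more)]" once it exceeds max_show.
fn preview(coeffs: &[i64], max_show: usize) -> String {
    let mut out = String::from("[");
    for (i, c) in coeffs.iter().take(max_show).enumerate() {
        if i > 0 {
            out.push_str(", ");
        }
        let _ = write!(out, "{}", c);
    }
    if coeffs.len() > max_show {
        let _ = write!(out, ", ... ({} more)", coeffs.len() - max_show);
    }
    out.push(']');
    out
}

fn main() {
    assert_eq!(preview(&[1, 2, 3, 4, 5], 3), "[1, 2, 3, ... (2 more)]");
}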
@@ -1,78 +1,57 @@
-use std::fmt;
-
-use crate::{
-    hal::{
-        api::{TakeSlice, VecZnxDftToVecZnxBigTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
-        layouts::{
-            Data, DataRef, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftBytesOf, VecZnxDftOwned,
-            VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
-        },
-        oep::{
-            VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftCopyImpl,
-            VecZnxDftFromBytesImpl, VecZnxDftFromVecZnxImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl,
-            VecZnxDftSubImpl, VecZnxDftToVecZnxBigConsumeImpl, VecZnxDftToVecZnxBigImpl, VecZnxDftToVecZnxBigTmpAImpl,
-            VecZnxDftToVecZnxBigTmpBytesImpl, VecZnxDftZeroImpl,
-        },
-    },
-    implementation::cpu_spqlios::{
-        ffi::{vec_znx_big, vec_znx_dft},
-        module_fft64::FFT64,
-    },
-};
-
-const VEC_ZNX_DFT_FFT64_WORDSIZE: usize = 1;
+use poulpy_hal::{
+    api::{TakeSlice, VecZnxDftToVecZnxBigTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
+    layouts::{
+        Backend, Data, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut,
+        VecZnxDftToRef, VecZnxToRef,
+    },
+    oep::{
+        VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftCopyImpl,
+        VecZnxDftFromBytesImpl, VecZnxDftFromVecZnxImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl,
+        VecZnxDftToVecZnxBigConsumeImpl, VecZnxDftToVecZnxBigImpl, VecZnxDftToVecZnxBigTmpAImpl,
+        VecZnxDftToVecZnxBigTmpBytesImpl, VecZnxDftZeroImpl,
+    },
+};
+
+use crate::cpu_spqlios::{
+    FFT64,
+    ffi::{vec_znx_big, vec_znx_dft},
+};
 
-impl<D: Data> ZnxSliceSize for VecZnxDft<D, FFT64> {
-    fn sl(&self) -> usize {
-        VEC_ZNX_DFT_FFT64_WORDSIZE * self.n() * self.cols()
-    }
-}
-
-impl<D: Data> VecZnxDftBytesOf for VecZnxDft<D, FFT64> {
-    fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
-        VEC_ZNX_DFT_FFT64_WORDSIZE * n * cols * size * size_of::<f64>()
-    }
-}
-
-impl<D: DataRef> ZnxView for VecZnxDft<D, FFT64> {
-    type Scalar = f64;
-}
-
-unsafe impl VecZnxDftFromBytesImpl<FFT64> for FFT64 {
-    fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<FFT64> {
+unsafe impl VecZnxDftFromBytesImpl<Self> for FFT64 {
+    fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<Self> {
         VecZnxDft::<Vec<u8>, FFT64>::from_bytes(n, cols, size, bytes)
     }
 }
 
-unsafe impl VecZnxDftAllocBytesImpl<FFT64> for FFT64 {
+unsafe impl VecZnxDftAllocBytesImpl<Self> for FFT64 {
     fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
-        VecZnxDft::<Vec<u8>, FFT64>::bytes_of(n, cols, size)
+        FFT64::layout_prep_word_count() * n * cols * size * size_of::<<FFT64 as Backend>::ScalarPrep>()
    }
 }
 
-unsafe impl VecZnxDftAllocImpl<FFT64> for FFT64 {
-    fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<FFT64> {
+unsafe impl VecZnxDftAllocImpl<Self> for FFT64 {
+    fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<Self> {
         VecZnxDftOwned::alloc(n, cols, size)
     }
 }
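The prepared (DFT) layout follows the same sizing rule, but counted in the backend's prep scalar; for the spqlios FFT64 backend that scalar is f64, as the removed `type Scalar = f64;` shows. A generic sketch of the rule (placeholder names, not the poulpy-hal API):

// Bytes needed for a prepared DFT vector, parameterised by the prep scalar type.
fn vec_znx_dft_bytes<Scalar>(word_count: usize, n: usize, cols: usize, size: usize) -> usize {
    word_count * n * cols * size * std::mem::size_of::<Scalar>()
}

fn main() {
    // FFT64: word_count = 1 and Scalar = f64, so an (n = 2048, cols = 1, size = 4)
    // vector prepares into 1 * 2048 * 1 * 4 * 8 = 65_536 bytes.
    assert_eq!(vec_znx_dft_bytes::<f64>(1, 2048, 1, 4), 65_536);
}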
 
-unsafe impl VecZnxDftToVecZnxBigTmpBytesImpl<FFT64> for FFT64 {
-    fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<FFT64>, n: usize) -> usize {
+unsafe impl VecZnxDftToVecZnxBigTmpBytesImpl<Self> for FFT64 {
+    fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<Self>, n: usize) -> usize {
         unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(module.ptr(), n as u64) as usize }
     }
 }
 
-unsafe impl VecZnxDftToVecZnxBigImpl<FFT64> for FFT64 {
+unsafe impl VecZnxDftToVecZnxBigImpl<Self> for FFT64 {
     fn vec_znx_dft_to_vec_znx_big_impl<R, A>(
-        module: &Module<FFT64>,
+        module: &Module<Self>,
         res: &mut R,
         res_col: usize,
         a: &A,
         a_col: usize,
-        scratch: &mut Scratch<FFT64>,
+        scratch: &mut Scratch<Self>,
     ) where
-        R: VecZnxBigToMut<FFT64>,
-        A: VecZnxDftToRef<FFT64>,
+        R: VecZnxBigToMut<Self>,
+        A: VecZnxDftToRef<Self>,
     {
         let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
         let a: VecZnxDft<&[u8], FFT64> = a.to_ref();
@@ -104,11 +83,11 @@ unsafe impl VecZnxDftToVecZnxBigImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxDftToVecZnxBigTmpAImpl<FFT64> for FFT64 {
-    fn vec_znx_dft_to_vec_znx_big_tmp_a_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
+unsafe impl VecZnxDftToVecZnxBigTmpAImpl<Self> for FFT64 {
+    fn vec_znx_dft_to_vec_znx_big_tmp_a_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
     where
-        R: VecZnxBigToMut<FFT64>,
-        A: VecZnxDftToMut<FFT64>,
+        R: VecZnxBigToMut<Self>,
+        A: VecZnxDftToMut<Self>,
     {
         let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
         let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
@@ -132,10 +111,10 @@ unsafe impl VecZnxDftToVecZnxBigTmpAImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxDftToVecZnxBigConsumeImpl<FFT64> for FFT64 {
-    fn vec_znx_dft_to_vec_znx_big_consume_impl<D: Data>(module: &Module<FFT64>, mut a: VecZnxDft<D, FFT64>) -> VecZnxBig<D, FFT64>
+unsafe impl VecZnxDftToVecZnxBigConsumeImpl<Self> for FFT64 {
+    fn vec_znx_dft_to_vec_znx_big_consume_impl<D: Data>(module: &Module<Self>, mut a: VecZnxDft<D, FFT64>) -> VecZnxBig<D, FFT64>
     where
-        VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>,
+        VecZnxDft<D, FFT64>: VecZnxDftToMut<Self>,
     {
         let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
@@ -158,9 +137,9 @@ unsafe impl VecZnxDftToVecZnxBigConsumeImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxDftFromVecZnxImpl<FFT64> for FFT64 {
+unsafe impl VecZnxDftFromVecZnxImpl<Self> for FFT64 {
     fn vec_znx_dft_from_vec_znx_impl<R, A>(
-        module: &Module<FFT64>,
+        module: &Module<Self>,
         step: usize,
         offset: usize,
         res: &mut R,
@@ -168,7 +147,7 @@ unsafe impl VecZnxDftFromVecZnxImpl<FFT64> for FFT64 {
         a: &A,
         a_col: usize,
     ) where
-        R: VecZnxDftToMut<FFT64>,
+        R: VecZnxDftToMut<Self>,
         A: VecZnxToRef,
     {
         let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
@@ -196,19 +175,12 @@ unsafe impl VecZnxDftFromVecZnxImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxDftAddImpl<FFT64> for FFT64 {
-    fn vec_znx_dft_add_impl<R, A, D>(
-        module: &Module<FFT64>,
-        res: &mut R,
-        res_col: usize,
-        a: &A,
-        a_col: usize,
-        b: &D,
-        b_col: usize,
-    ) where
-        R: VecZnxDftToMut<FFT64>,
-        A: VecZnxDftToRef<FFT64>,
-        D: VecZnxDftToRef<FFT64>,
+unsafe impl VecZnxDftAddImpl<Self> for FFT64 {
+    fn vec_znx_dft_add_impl<R, A, D>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
+    where
+        R: VecZnxDftToMut<Self>,
+        A: VecZnxDftToRef<Self>,
+        D: VecZnxDftToRef<Self>,
     {
         let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
         let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
@@ -235,11 +207,11 @@ unsafe impl VecZnxDftAddImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxDftAddInplaceImpl<FFT64> for FFT64 {
-    fn vec_znx_dft_add_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+unsafe impl VecZnxDftAddInplaceImpl<Self> for FFT64 {
+    fn vec_znx_dft_add_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
-        R: VecZnxDftToMut<FFT64>,
-        A: VecZnxDftToRef<FFT64>,
+        R: VecZnxDftToMut<Self>,
+        A: VecZnxDftToRef<Self>,
     {
         let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
         let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
@@ -262,19 +234,12 @@ unsafe impl VecZnxDftAddInplaceImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxDftSubImpl<FFT64> for FFT64 {
-    fn vec_znx_dft_sub_impl<R, A, D>(
-        module: &Module<FFT64>,
-        res: &mut R,
-        res_col: usize,
-        a: &A,
-        a_col: usize,
-        b: &D,
-        b_col: usize,
-    ) where
-        R: VecZnxDftToMut<FFT64>,
-        A: VecZnxDftToRef<FFT64>,
-        D: VecZnxDftToRef<FFT64>,
+unsafe impl VecZnxDftSubImpl<Self> for FFT64 {
+    fn vec_znx_dft_sub_impl<R, A, D>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
+    where
+        R: VecZnxDftToMut<Self>,
+        A: VecZnxDftToRef<Self>,
+        D: VecZnxDftToRef<Self>,
     {
         let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
         let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
@@ -301,11 +266,11 @@ unsafe impl VecZnxDftSubImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxDftSubABInplaceImpl<FFT64> for FFT64 {
-    fn vec_znx_dft_sub_ab_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+unsafe impl VecZnxDftSubABInplaceImpl<Self> for FFT64 {
+    fn vec_znx_dft_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
-        R: VecZnxDftToMut<FFT64>,
-        A: VecZnxDftToRef<FFT64>,
+        R: VecZnxDftToMut<Self>,
+        A: VecZnxDftToRef<Self>,
@@ -328,11 +293,11 @@ unsafe impl VecZnxDftSubABInplaceImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxDftSubBAInplaceImpl<FFT64> for FFT64 {
-    fn vec_znx_dft_sub_ba_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+unsafe impl VecZnxDftSubBAInplaceImpl<Self> for FFT64 {
+    fn vec_znx_dft_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
-        R: VecZnxDftToMut<FFT64>,
-        A: VecZnxDftToRef<FFT64>,
+        R: VecZnxDftToMut<Self>,
+        A: VecZnxDftToRef<Self>,
@@ -355,9 +320,9 @@ unsafe impl VecZnxDftSubBAInplaceImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxDftCopyImpl<FFT64> for FFT64 {
+unsafe impl VecZnxDftCopyImpl<Self> for FFT64 {
     fn vec_znx_dft_copy_impl<R, A>(
-        _module: &Module<FFT64>,
+        _module: &Module<Self>,
         step: usize,
         offset: usize,
         res: &mut R,
@@ -365,8 +330,8 @@ unsafe impl VecZnxDftCopyImpl<FFT64> for FFT64 {
         a: &A,
         a_col: usize,
     ) where
-        R: VecZnxDftToMut<FFT64>,
-        A: VecZnxDftToRef<FFT64>,
+        R: VecZnxDftToMut<Self>,
+        A: VecZnxDftToRef<Self>,
@@ -388,46 +353,11 @@ unsafe impl VecZnxDftCopyImpl<FFT64> for FFT64 {
     }
 }
 
-unsafe impl VecZnxDftZeroImpl<FFT64> for FFT64 {
-    fn vec_znx_dft_zero_impl<R>(_module: &Module<FFT64>, res: &mut R)
+unsafe impl VecZnxDftZeroImpl<Self> for FFT64 {
+    fn vec_znx_dft_zero_impl<R>(_module: &Module<Self>, res: &mut R)
     where
-        R: VecZnxDftToMut<FFT64>,
+        R: VecZnxDftToMut<Self>,
     {
         res.to_mut().data.fill(0);
     }
 }
impl<D: DataRef> fmt::Display for VecZnxDft<D, FFT64> {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
writeln!(
|
|
||||||
f,
|
|
||||||
"VecZnxDft(n={}, cols={}, size={})",
|
|
||||||
self.n, self.cols, self.size
|
|
||||||
)?;
|
|
||||||
|
|
||||||
for col in 0..self.cols {
|
|
||||||
writeln!(f, "Column {}:", col)?;
|
|
||||||
for size in 0..self.size {
|
|
||||||
let coeffs = self.at(col, size);
|
|
||||||
write!(f, " Size {}: [", size)?;
|
|
||||||
|
|
||||||
let max_show = 100;
|
|
||||||
let show_count = coeffs.len().min(max_show);
|
|
||||||
|
|
||||||
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
|
|
||||||
if i > 0 {
|
|
||||||
write!(f, ", ")?;
|
|
||||||
}
|
|
||||||
write!(f, "{}", coeff)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
if coeffs.len() > max_show {
|
|
||||||
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
writeln!(f, "]")?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
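All of the FFT64 hunks above share one delegation pattern: the OEP impl converts its generic `res`/`a` arguments into concrete `VecZnxDft` views with `to_mut()`/`to_ref()` and only then touches the underlying data. A minimal, self-contained Rust sketch of that pattern follows; the `View`, `ToMutView`, `ToRefView` and `Owned` names are simplified stand-ins for illustration, not the actual poulpy-hal types.

// Sketch only: simplified stand-ins for the traits used by the impls above.
struct View<D> {
    data: D,
}

trait ToMutView {
    fn to_mut(&mut self) -> View<&mut [f64]>;
}

trait ToRefView {
    fn to_ref(&self) -> View<&[f64]>;
}

// Analogue of vec_znx_dft_zero_impl: take the mutable view, then clear it.
fn zero_impl<R: ToMutView>(res: &mut R) {
    res.to_mut().data.fill(0.0);
}

// Analogue of the sub-in-place impls: one mutable view, one borrowed view,
// then an element-wise pass over the underlying slices.
fn sub_inplace_impl<R: ToMutView, A: ToRefView>(res: &mut R, a: &A) {
    let res_mut = res.to_mut();
    let a_ref = a.to_ref();
    for (r, &x) in res_mut.data.iter_mut().zip(a_ref.data.iter()) {
        *r -= x;
    }
}

// A toy owner of coefficient storage, standing in for VecZnxDft.
struct Owned {
    data: Vec<f64>,
}

impl ToMutView for Owned {
    fn to_mut(&mut self) -> View<&mut [f64]> {
        View { data: self.data.as_mut_slice() }
    }
}

impl ToRefView for Owned {
    fn to_ref(&self) -> View<&[f64]> {
        View { data: self.data.as_slice() }
    }
}

fn main() {
    let mut res = Owned { data: vec![3.0, 4.0] };
    let a = Owned { data: vec![1.0, 2.0] };
    sub_inplace_impl(&mut res, &a);
    assert_eq!(res.data, vec![2.0, 2.0]);
    zero_impl(&mut res);
    assert_eq!(res.data, vec![0.0, 0.0]);
}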
@@ -1,39 +1,23 @@
-use crate::{
-    hal::{
-        api::{TakeSlice, VmpApplyTmpBytes, VmpPrepareTmpBytes, ZnxInfos, ZnxView, ZnxViewMut},
-        layouts::{
-            DataRef, MatZnx, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatBytesOf,
-            VmpPMatOwned, VmpPMatToMut, VmpPMatToRef,
-        },
-        oep::{
-            VmpApplyAddImpl, VmpApplyAddTmpBytesImpl, VmpApplyImpl, VmpApplyTmpBytesImpl, VmpPMatAllocBytesImpl,
-            VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
-        },
-    },
-    implementation::cpu_spqlios::{
-        ffi::{vec_znx_dft::vec_znx_dft_t, vmp},
-        module_fft64::FFT64,
-    },
-};
-
-const VMP_PMAT_FFT64_WORDSIZE: usize = 1;
-
-impl<D: DataRef> ZnxView for VmpPMat<D, FFT64> {
-    type Scalar = f64;
-}
-
-impl VmpPMatBytesOf for FFT64 {
-    fn vmp_pmat_bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
-        VMP_PMAT_FFT64_WORDSIZE * n * rows * cols_in * cols_out * size * size_of::<f64>()
-    }
-}
-
-unsafe impl VmpPMatAllocBytesImpl<FFT64> for FFT64
-where
-    FFT64: VmpPMatBytesOf,
-{
+use poulpy_hal::{
+    api::{TakeSlice, VmpApplyTmpBytes, VmpPrepareTmpBytes, ZnxInfos, ZnxView, ZnxViewMut},
+    layouts::{
+        Backend, MatZnx, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatOwned,
+        VmpPMatToMut, VmpPMatToRef,
+    },
+    oep::{
+        VmpApplyAddImpl, VmpApplyAddTmpBytesImpl, VmpApplyImpl, VmpApplyTmpBytesImpl, VmpPMatAllocBytesImpl, VmpPMatAllocImpl,
+        VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
+    },
+};
+
+use crate::cpu_spqlios::{
+    FFT64,
+    ffi::{vec_znx_dft::vec_znx_dft_t, vmp},
+};
+
+unsafe impl VmpPMatAllocBytesImpl<FFT64> for FFT64 {
     fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
-        FFT64::vmp_pmat_bytes_of(n, rows, cols_in, cols_out, size)
+        FFT64::layout_prep_word_count() * n * rows * cols_in * cols_out * size * size_of::<f64>()
     }
 }
@@ -251,8 +235,6 @@ unsafe impl VmpApplyAddImpl<FFT64> for FFT64 {
 
         #[cfg(debug_assertions)]
         {
-            use crate::hal::api::ZnxInfos;
-
            assert_eq!(b.n(), res.n());
            assert_eq!(a.n(), res.n());
            assert_eq!(
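In the new version of the vmp_pmat code above, the prepared-matrix byte count no longer goes through the `VmpPMatBytesOf` helper: it is the direct product of the backend's prepared-layout word count (1 for FFT64, formerly the `VMP_PMAT_FFT64_WORDSIZE` constant), the matrix dimensions, and the `f64` word size. A small stand-alone sketch of that arithmetic follows; the free function and its parameter names are illustrative, not part of the poulpy-hal API.

// Illustrative only: the same product as in the hunk above, written as a free
// function so the size computation can be checked in isolation.
fn vmp_pmat_alloc_bytes(word_count: usize, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
    word_count * n * rows * cols_in * cols_out * size * std::mem::size_of::<f64>()
}

fn main() {
    // Ring degree 1024, a 4-row gadget matrix with 2 input columns, 2 output
    // columns and 4 limbs, using the FFT64 word count of 1:
    // 1 * 1024 * 4 * 2 * 2 * 4 * 8 bytes = 524_288 bytes.
    assert_eq!(vmp_pmat_alloc_bytes(1, 1024, 4, 2, 2, 4), 524_288);
}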
poulpy-backend/src/cpu_spqlios/mod.rs (new file, 9 lines)
@@ -0,0 +1,9 @@
mod ffi;
mod fft64;
mod ntt120;

#[cfg(test)]
mod test;

pub use fft64::*;
pub use ntt120::*;
poulpy-backend/src/cpu_spqlios/ntt120/mod.rs (new file, 7 lines)
@@ -0,0 +1,7 @@
mod module;
mod svp_ppol;
mod vec_znx_big;
mod vec_znx_dft;
mod vmp_pmat;

pub use module::NTT120;
@@ -1,25 +1,29 @@
 use std::ptr::NonNull;
 
-use crate::{
-    hal::{
-        layouts::{Backend, Module},
-        oep::ModuleNewImpl,
-    },
-    implementation::cpu_spqlios::{
-        CPUAVX,
-        ffi::module::{MODULE, delete_module_info, new_module_info},
-    },
-};
+use poulpy_hal::{
+    layouts::{Backend, Module},
+    oep::ModuleNewImpl,
+};
+
+use crate::cpu_spqlios::ffi::module::{MODULE, delete_module_info, new_module_info};
 
 pub struct NTT120;
 
-impl CPUAVX for NTT120 {}
-
 impl Backend for NTT120 {
+    type ScalarPrep = i64;
+    type ScalarBig = i128;
     type Handle = MODULE;
     unsafe fn destroy(handle: NonNull<Self::Handle>) {
         unsafe { delete_module_info(handle.as_ptr()) }
     }
+
+    fn layout_big_word_count() -> usize {
+        4
+    }
+
+    fn layout_prep_word_count() -> usize {
+        1
+    }
 }
 
 unsafe impl ModuleNewImpl<Self> for NTT120 {
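The `Backend` impl above pins down the NTT120 scalar types (`i64` for the prepared layout, `i128` for the big layout) and the per-coefficient word counts (1 and 4). The allocation helpers added in the files below multiply these word counts by the ring degree, the column count, the limb count and the scalar size. A hedged, stand-alone restatement of those size formulas; the constants are copied from the diff and the function names are illustrative only.

const PREP_WORDS: usize = 1; // NTT120::layout_prep_word_count()
const BIG_WORDS: usize = 4; // NTT120::layout_big_word_count()

// Mirrors vec_znx_dft_alloc_bytes_impl below (svp_ppol_alloc_bytes_impl is the
// same product without the `size` factor).
fn prep_bytes(n: usize, cols: usize, size: usize) -> usize {
    PREP_WORDS * n * cols * size * std::mem::size_of::<i64>()
}

// Mirrors vec_znx_big_alloc_bytes_impl below.
fn big_bytes(n: usize, cols: usize, size: usize) -> usize {
    BIG_WORDS * n * cols * size * std::mem::size_of::<i128>()
}

fn main() {
    // Example: n = 2048, 2 columns, 3 limbs.
    assert_eq!(prep_bytes(2048, 2, 3), 98_304); // 1 * 2048 * 2 * 3 * 8
    assert_eq!(big_bytes(2048, 2, 3), 786_432); // 4 * 2048 * 2 * 3 * 16
}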
poulpy-backend/src/cpu_spqlios/ntt120/svp_ppol.rs (new file, 24 lines)
@@ -0,0 +1,24 @@
use poulpy_hal::{
    layouts::{Backend, SvpPPolOwned},
    oep::{SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl},
};

use crate::cpu_spqlios::NTT120;

unsafe impl SvpPPolFromBytesImpl<Self> for NTT120 {
    fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<NTT120> {
        SvpPPolOwned::from_bytes(n, cols, bytes)
    }
}

unsafe impl SvpPPolAllocImpl<Self> for NTT120 {
    fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned<NTT120> {
        SvpPPolOwned::alloc(n, cols)
    }
}

unsafe impl SvpPPolAllocBytesImpl<Self> for NTT120 {
    fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize {
        NTT120::layout_prep_word_count() * n * cols * size_of::<i64>()
    }
}
poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_big.rs (new file, 9 lines)
@@ -0,0 +1,9 @@
use poulpy_hal::{layouts::Backend, oep::VecZnxBigAllocBytesImpl};

use crate::cpu_spqlios::NTT120;

unsafe impl VecZnxBigAllocBytesImpl<NTT120> for NTT120 {
    fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
        NTT120::layout_big_word_count() * n * cols * size * size_of::<i128>()
    }
}
poulpy-backend/src/cpu_spqlios/ntt120/vec_znx_dft.rs (new file, 18 lines)
@@ -0,0 +1,18 @@
use poulpy_hal::{
    layouts::{Backend, VecZnxDftOwned},
    oep::{VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl},
};

use crate::cpu_spqlios::NTT120;

unsafe impl VecZnxDftAllocBytesImpl<NTT120> for NTT120 {
    fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
        NTT120::layout_prep_word_count() * n * cols * size * size_of::<i64>()
    }
}

unsafe impl VecZnxDftAllocImpl<NTT120> for NTT120 {
    fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<NTT120> {
        VecZnxDftOwned::alloc(n, cols, size)
    }
}
poulpy-backend/src/cpu_spqlios/ntt120/vmp_pmat.rs (new file, 1 empty line)
@@ -0,0 +1 @@
@@ -0,0 +1,14 @@
# Use the Google style in this project.
BasedOnStyle: Google

# Some folks prefer to write "int& foo" while others prefer "int &foo". The
# Google Style Guide only asks for consistency within a project, we chose
# "int& foo" for this project:
DerivePointerAlignment: false
PointerAlignment: Left

# The Google Style Guide only asks for consistency w.r.t. "east const" vs.
# "const west" alignment of cv-qualifiers. In this project we use "east const".
QualifierAlignment: Left

ColumnLimit: 120
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/.github/workflows/auto-release.yml (new vendored file, 20 lines)
@@ -0,0 +1,20 @@
name: Auto-Release

on:
  workflow_dispatch:
  push:
    branches: [ "main" ]

jobs:
  build:
    name: Auto-Release
    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v3
      with:
        fetch-depth: 3
        # sparse-checkout: manifest.yaml scripts/auto-release.sh

    - run:
        ${{github.workspace}}/scripts/auto-release.sh
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/.gitignore (new vendored file, 6 lines)
@@ -0,0 +1,6 @@
cmake-build-*
.idea

build
.vscode
.*.sh
@@ -0,0 +1,69 @@
cmake_minimum_required(VERSION 3.8)
project(spqlios)

# read the current version from the manifest file
file(READ "manifest.yaml" manifest)
string(REGEX MATCH "version: +(([0-9]+)\\.([0-9]+)\\.([0-9]+))" SPQLIOS_VERSION_BLAH ${manifest})
#message(STATUS "Version: ${SPQLIOS_VERSION_BLAH}")
set(SPQLIOS_VERSION ${CMAKE_MATCH_1})
set(SPQLIOS_VERSION_MAJOR ${CMAKE_MATCH_2})
set(SPQLIOS_VERSION_MINOR ${CMAKE_MATCH_3})
set(SPQLIOS_VERSION_PATCH ${CMAKE_MATCH_4})
message(STATUS "Compiling spqlios-fft version: ${SPQLIOS_VERSION_MAJOR}.${SPQLIOS_VERSION_MINOR}.${SPQLIOS_VERSION_PATCH}")

#set(ENABLE_SPQLIOS_F128 ON CACHE BOOL "Enable float128 via libquadmath")
set(WARNING_PARANOID ON CACHE BOOL "Treat all warnings as errors")
set(ENABLE_TESTING ON CACHE BOOL "Compiles unittests and integration tests")
set(DEVMODE_INSTALL OFF CACHE BOOL "Install private headers and testlib (mainly for CI)")

if (NOT CMAKE_BUILD_TYPE OR CMAKE_BUILD_TYPE STREQUAL "")
    set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Build type: Release or Debug" FORCE)
endif()
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")

if (WARNING_PARANOID)
    add_compile_options(-Wall -Werror -Wno-unused-command-line-argument)
endif()

message(STATUS "CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}")
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")

if (CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)")
    set(X86 ON)
    set(AARCH64 OFF)
else ()
    set(X86 OFF)
    # set(ENABLE_SPQLIOS_F128 OFF) # float128 are only supported for x86 targets
endif ()
if (CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64)|(arm64)")
    set(AARCH64 ON)
endif ()

if (CMAKE_SYSTEM_NAME MATCHES "(Windows)|(MSYS)")
    set(WIN32 ON)
endif ()
if (WIN32)
    #overrides for win32
    set(X86 OFF)
    set(AARCH64 OFF)
    set(X86_WIN32 ON)
else()
    set(X86_WIN32 OFF)
    set(WIN32 OFF)
endif (WIN32)

message(STATUS "--> WIN32: ${WIN32}")
message(STATUS "--> X86_WIN32: ${X86_WIN32}")
message(STATUS "--> X86_LINUX: ${X86}")
message(STATUS "--> AARCH64: ${AARCH64}")

# compiles the main library in spqlios
add_subdirectory(spqlios)

# compiles and activates unittests and itests
if (${ENABLE_TESTING})
    enable_testing()
    add_subdirectory(test)
endif()
@@ -0,0 +1,77 @@
# Contributing to SPQlios-fft

The spqlios-fft team encourages contributions.
We encourage users to fix bugs, improve the documentation, write tests and to enhance the code, or ask for new features.
We encourage researchers to contribute with implementations of their FFT or NTT algorithms.
In the following we are trying to give some guidance on how to contribute effectively.

## Communication ##

Communication in the spqlios-fft project happens mainly on [GitHub](https://github.com/tfhe/spqlios-fft/issues).

All communications are public, so please make sure to maintain professional behaviour in
all published comments. See [Code of Conduct](https://www.contributor-covenant.org/version/2/1/code_of_conduct/) for
guidelines.

## Reporting Bugs or Requesting features ##

Bugs should be filed at [https://github.com/tfhe/spqlios-fft/issues](https://github.com/tfhe/spqlios-fft/issues).

Features can also be requested there; in this case, please ensure that the features you request are self-contained,
easy to define, and generic enough to be used in different use-cases. Please provide an example of use-cases if
possible.

## Setting up topic branches and generating pull requests

This section applies to people that already have write access to the repository. Specific instructions for pull requests
from public forks will be given later.

To implement some changes, please follow these steps:

- Create a "topic branch". Usually, the branch name should be `username/small-title`
  or better `username/issuenumber-small-title` where `issuenumber` is the number of
  the GitHub issue that is tackled.
- Push any needed commits to your branch. Make sure it compiles in `CMAKE_BUILD_TYPE=Debug` and `=Release`, with `-DWARNING_PARANOID=ON`.
- When the branch is nearly ready for review, please open a pull request, and add the label `check-on-arm`.
- Do as many commits as necessary until all CI checks pass and all PR comments have been resolved.

> _During the process, you may optionally use `git rebase -i` to clean up your commit history. If you elect to do so,
please at the very least make sure that nobody else is working on or has forked from your branch: the conflicts it would generate
and the human hours to fix them are not worth it. `git merge` remains the preferred option._

- Finally, when all reviews are positive and all CI checks pass, you may merge your branch via the GitHub webpage.

### Keep your pull requests limited to a single issue

Pull requests should be as small/atomic as possible.

### Coding Conventions

* Please make sure that your code is formatted according to the `.clang-format` file and
  that all files end with a newline character.
* Please make sure that all the functions declared in the public API have relevant doxygen comments.
  Preferably, functions in the private APIs should also contain a brief doxygen description.

### Versions and History

* **Stable API** The project uses semantic versioning on the functions that are listed as `stable` in the documentation. A version has
  the form `x.y.z`
  * a patch release that increments `z` does not modify the stable API.
  * a minor release that increments `y` adds a new feature to the stable API.
  * In the unlikely case where we need to change or remove a feature, we will trigger a major release that
    increments `x`.

> _If any, we will mark those features as deprecated at least six months before the major release._

* **Experimental API** Features that are not part of the stable section in the documentation are experimental features: you may test them at
  your own risk,
  but keep in mind that semantic versioning does not apply to them.

> _If you have a use-case that uses an experimental feature, we encourage
> you to tell us about it, so that this feature reaches the stable section faster!_

* **Version history** The current version is reported in `manifest.yaml`, any change of version comes with a tag on the main branch, and the history between releases is summarized in `Changelog.md`. It is the main source of truth for anyone who wishes to
  get insight about
  the history of the repository (not the commit graph).

> Note: _The commit graph of git is for git's internal use only. Its main purpose is to reduce potential merge conflicts to a minimum, even in scenarios where multiple features are developed in parallel: it may therefore be non-linear. If, as humans, we like to see a linear history, please read `Changelog.md` instead!_
@@ -0,0 +1,18 @@
# Changelog

All notable changes to this project will be documented in this file.
This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [2.0.0] - 2024-08-21

- Initial release of the `vec_znx` (except convolution products), `vec_rnx` and `zn` apis.
- Hardware acceleration available: AVX2 (most parts)
- APIs are documented in the wiki and are in "beta mode": during the 2.x -> 3.x transition, functions whose API is satisfactory in test projects will pass in "stable mode".

## [1.0.0] - 2023-07-18

- Initial release of the double precision fft on the reim and cplx backends
- Coeffs-space conversions cplx <-> znx32 and tnx32
- FFT-space conversions cplx <-> reim4 layouts
- FFT-space multiplications on the cplx, reim and reim4 layouts.
- In this first release, the only platform supported is linux x86_64 (generic C code, and avx2/fma). It compiles on arm64, but without any acceleration.
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/LICENSE (new file, 201 lines)
@@ -0,0 +1,201 @@
(Standard, unmodified Apache License, Version 2.0, January 2004 — http://www.apache.org/licenses/ — full terms and conditions, sections 1 through 9, followed by the appendix describing how to apply the license to a work.)
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/README.md (new file, 65 lines)
@@ -0,0 +1,65 @@
# SPQlios library

The SPQlios library provides fast arithmetic for Fully Homomorphic Encryption, and other lattice constructions that arise in post-quantum cryptography.

<img src="docs/api-full.svg">

Namely, it is divided into 4 sections:

* The low-level DFT section supports FFT over 64-bit floats, as well as NTT modulo one fixed 120-bit modulus. It is an upgrade of the original spqlios-fft module embedded in the TFHE library since 2016. The DFT section exposes the traditional DFT, inverse-DFT, and coefficient-wise multiplications in DFT space.
* The VEC_ZNX section exposes fast algebra over vectors of small integer polynomials modulo $X^N+1$. It provides in particular efficient (prepared) vector-matrix products, scalar-vector products, convolution products, and element-wise products, operations that naturally occur on gadget-decomposed Ring-LWE coordinates.
* The RNX section is a simpler variant of VEC_ZNX, to represent single polynomials modulo $X^N+1$ (over the reals or over the torus) when the coefficient precision fits on 64-bit doubles. The small vector-matrix API of the RNX section is particularly adapted to reproducing the fastest CGGI-based bootstrappings.
* The ZN section focuses on vector and matrix algebra over scalars (used by scalar LWE, or scalar key-switches, but also by non-ring schemes like Frodo, FrodoPIR, and SimplePIR).

### A high-value target for hardware accelerations

SPQlios is more than a library, it is also a good target for hardware developers.
On one hand, the arithmetic operations that are defined in the library have a clear standalone mathematical definition. At the same time, the amount of work in each operation is sufficiently large that meaningful functions only require a few of them.

This makes the SPQlios API a high-value target for hardware acceleration aimed at FHE.

### SPQlios is not an FHE library, but a huge enabler

SPQlios itself is not an FHE library: there is no ciphertext, plaintext or key. It is a mathematical library that exposes efficient algebra over polynomials. Using the functions exposed, it is possible to quickly build efficient FHE libraries, with support for the main schemes based on Ring-LWE: BFV, BGV, CGGI, DM, CKKS.

## Dependencies

The SPQLIOS-FFT library is a C library that can be compiled with a standard C compiler, and depends only on libc and libm. The API
interface can be used from regular C code, and from any other language via classical foreign-function APIs.

The unittests and integration tests are an optional part of the code, and are written in C++. These tests rely on the
[```benchmark```](https://github.com/google/benchmark) and [```gtest```](https://github.com/google/googletest) libraries, and therefore require a C++17 compiler.

Currently, the project has been tested with the gcc, g++ >= 11.3.0 compilers under Linux (x86_64). In the future, we plan to
extend the compatibility to other compilers, platforms and operating systems.

## Installation

The library uses a classical ```cmake``` build mechanism: use ```cmake``` to create a ```build``` folder in the top-level directory and run ```make``` from inside it. This assumes that the standard tool ```cmake``` is already installed on the system, together with an up-to-date C++ compiler (i.e. g++ >= 11.3.0).

It will compile the shared library in optimized mode, and ```make install``` installs it to the desired prefix folder (by default ```/usr/local/lib```).

If you want to choose additional compile options (i.e. other installation folder, debug mode, tests), you need to run cmake manually and pass the desired options:
```
mkdir build
cd build
cmake ../src -DCMAKE_INSTALL_PREFIX=/usr/
make
```
The available options are the following:

| Variable Name        | values                                                                      |
| -------------------- | --------------------------------------------------------------------------- |
| CMAKE_INSTALL_PREFIX | */usr/local* installation folder (libs go in lib/ and headers in include/)  |
| WARNING_PARANOID     | All warnings are shown and treated as errors. Off by default                |
| ENABLE_TESTING       | Compiles unit tests and integration tests                                   |

------

<img src="docs/logo-sandboxaq-black.svg">

<img src="docs/logo-inpher1.png">
(New file, diff suppressed because one or more lines are too long — 550 KiB.)
(New binary file, not shown — 24 KiB.)
@@ -0,0 +1,139 @@
(New file: logo-sandboxaq-black.svg — 139 lines of SVG source for the SandboxAQ logo, 5.0 KiB; vector path data and Inkscape/Sodipodi metadata omitted here.)
@@ -0,0 +1,133 @@
(New file: logo-sandboxaq-white.svg — 133 lines of SVG source for the white variant of the SandboxAQ logo, 4.7 KiB; vector path data and Inkscape/Sodipodi metadata omitted here.)
@@ -0,0 +1,2 @@
library: spqlios-fft
version: 2.0.0
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/scripts/auto-release.sh (new executable file, 27 lines)
@@ -0,0 +1,27 @@
#!/bin/sh

# this script generates one tag if there is a version change in manifest.yaml
cd `dirname $0`/..
if [ "v$1" = "v-y" ]; then
  echo "production mode!";
fi
changes=`git diff HEAD~1..HEAD -- manifest.yaml | grep 'version:'`
oldversion=$(echo "$changes" | grep '^-version:' | cut '-d ' -f2)
version=$(echo "$changes" | grep '^+version:' | cut '-d ' -f2)
echo "Versions: $oldversion --> $version"
if [ "v$oldversion" = "v$version" ]; then
  echo "Same version - nothing to do"; exit 0;
fi
if [ "v$1" = "v-y" ]; then
  git config user.name github-actions
  git config user.email github-actions@github.com
  git tag -a "v$version" -m "Version $version"
  git push origin "v$version"
else
  cat <<EOF
# the script would do:
git tag -a "v$version" -m "Version $version"
git push origin "v$version"
EOF
fi
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/scripts/ci-pkg (new executable file, 102 lines)
@@ -0,0 +1,102 @@
#!/bin/sh

# ONLY USE A PREFIX YOU ARE CONFIDENT YOU CAN WIPE OUT ENTIRELY
CI_INSTALL_PREFIX=/opt/spqlios
CI_REPO_URL=https://spq-dav.algonics.net/ci
WORKDIR=`pwd`
if [ "x$DESTDIR" = "x" ]; then
  DESTDIR=/
else
  mkdir -p $DESTDIR
  DESTDIR=`realpath $DESTDIR`
fi
DIR=`dirname "$0"`
cd $DIR/..
DIR=`pwd`

FULL_UNAME=`uname -a | tr '[A-Z]' '[a-z]'`
HOST=`echo $FULL_UNAME | sed 's/ .*//'`
ARCH=none
case "$HOST" in
  *linux*)
    DISTRIB=`lsb_release -c | awk '{print $2}' | tr '[A-Z]' '[a-z]'`
    HOST=linux-$DISTRIB
    ;;
  *darwin*)
    HOST=darwin
    ;;
  *mingw*|*msys*)
    DISTRIB=`echo $MSYSTEM | tr '[A-Z]' '[a-z]'`
    HOST=msys64-$DISTRIB
    ;;
  *)
    echo "Host unknown: $HOST";
    exit 1
esac
case "$FULL_UNAME" in
  *x86_64*)
    ARCH=x86_64
    ;;
  *aarch64*)
    ARCH=aarch64
    ;;
  *arm64*)
    ARCH=arm64
    ;;
  *)
    echo "Architecture unknown: $FULL_UNAME";
    exit 1
esac
UNAME="$HOST-$ARCH"
CMH=
if [ -d lib/spqlios/.git ]; then
  CMH=`git submodule status | sed 's/\(..........\).*/\1/'`
else
  CMH=`git rev-parse HEAD | sed 's/\(..........\).*/\1/'`
fi
FNAME=spqlios-arithmetic-$CMH-$UNAME.tar.gz

cat <<EOF
================= CI MINI-PACKAGER ==================
Work Dir:        WORKDIR=$WORKDIR
Spq Dir:         DIR=$DIR
Install Root:    DESTDIR=$DESTDIR
Install Prefix:  CI_INSTALL_PREFIX=$CI_INSTALL_PREFIX
Archive Name:    FNAME=$FNAME
CI WebDav:       CI_REPO_URL=$CI_REPO_URL
=====================================================
EOF

if [ "x$1" = "xcreate" ]; then
  rm -rf dist
  cmake -B build -S . -DCMAKE_INSTALL_PREFIX="$CI_INSTALL_PREFIX" -DCMAKE_BUILD_TYPE=Release -DENABLE_TESTING=ON -DWARNING_PARANOID=ON -DDEVMODE_INSTALL=ON || exit 1
  cmake --build build || exit 1
  rm -rf "$DIR/dist" 2>/dev/null
  rm -f "$DIR/$FNAME" 2>/dev/null
  DESTDIR="$DIR/dist" cmake --install build || exit 1
  if [ -d "$DIR/dist$CI_INSTALL_PREFIX" ]; then
    tar -C "$DIR/dist" -cvzf "$DIR/$FNAME" .
  else
    # fix since msys can mess up the paths
    REAL_DEST=`find "$DIR/dist" -type d -exec test -d "{}$CI_INSTALL_PREFIX" \; -print`
    echo "REAL_DEST: $REAL_DEST"
    [ -d "$REAL_DEST$CI_INSTALL_PREFIX" ] && tar -C "$REAL_DEST" -cvzf "$DIR/$FNAME" .
  fi
  [ -f "$DIR/$FNAME" ] || { echo "failed to create $DIR/$FNAME"; exit 1; }
  [ "x$CI_CREDS" = "x" ] && { echo "CI_CREDS is not set: not uploading"; exit 1; }
  curl -u "$CI_CREDS" -T "$DIR/$FNAME" "$CI_REPO_URL/$FNAME"
fi

if [ "x$1" = "xinstall" ]; then
  [ "x$CI_CREDS" = "x" ] && { echo "CI_CREDS is not set: not downloading"; exit 1; }
  # cleaning
  rm -rf "$DESTDIR$CI_INSTALL_PREFIX"/* 2>/dev/null
  rm -f "$DIR/$FNAME" 2>/dev/null
  # downloading
  curl -u "$CI_CREDS" -o "$DIR/$FNAME" "$CI_REPO_URL/$FNAME"
  [ -f "$DIR/$FNAME" ] || { echo "failed to download $DIR/$FNAME"; exit 0; }
  # installing
  mkdir -p $DESTDIR
  tar -C "$DESTDIR" -xvzf "$DIR/$FNAME"
  exit 0
fi
181
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/scripts/prepare-release
Executable file
181
poulpy-backend/src/cpu_spqlios/spqlios-arithmetic/scripts/prepare-release
Executable file
@@ -0,0 +1,181 @@
|
|||||||
|
#!/usr/bin/perl
|
||||||
|
##
|
||||||
|
## This script will help update manifest.yaml and Changelog.md before a release
|
||||||
|
## Any merge to master that changes the version line in manifest.yaml
|
||||||
|
## is considered as a new release.
|
||||||
|
##
|
||||||
|
## When ready to make a release, please run ./scripts/prepare-release
|
||||||
|
## and commit push the final result!
|
||||||
|
use File::Basename;
|
||||||
|
use Cwd 'abs_path';
|
||||||
|
|
||||||
|
# find its way to the root of git's repository
|
||||||
|
my $scriptsdirname = dirname(abs_path(__FILE__));
|
||||||
|
chdir "$scriptsdirname/..";
|
||||||
|
print "✓ Entering directory:".`pwd`;
|
||||||
|
|
||||||
|
# ensures that the current branch is ahead of origin/main
|
||||||
|
my $diff= `git diff`;
|
||||||
|
chop $diff;
|
||||||
|
if ($diff =~ /./) {
|
||||||
|
die("ERROR: Please commit all the changes before calling the prepare-release script.");
|
||||||
|
} else {
|
||||||
|
print("✓ All changes are comitted.\n");
|
||||||
|
}
|
||||||
|
system("git fetch origin");
|
||||||
|
my $vcount = `git rev-list --left-right --count origin/main...HEAD`;
|
||||||
|
$vcount =~ /^([0-9]+)[ \t]*([0-9]+)$/;
|
||||||
|
if ($2>0) {
|
||||||
|
die("ERROR: the current HEAD is not ahead of origin/main\n. Please use git merge origin/main.");
|
||||||
|
} else {
|
||||||
|
print("✓ Current HEAD is up to date with origin/main.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
mkdir ".changes";
|
||||||
|
my $currentbranch = `git rev-parse --abbrev-ref HEAD`;
|
||||||
|
chop $currentbranch;
|
||||||
|
$currentbranch =~ s/[^a-zA-Z._-]+/-/g;
|
||||||
|
my $changefile=".changes/$currentbranch.md";
|
||||||
|
my $origmanifestfile=".changes/$currentbranch--manifest.yaml";
|
||||||
|
my $origchangelogfile=".changes/$currentbranch--Changelog.md";
|
||||||
|
|
||||||
|
my $exit_code=system("wget -O $origmanifestfile https://raw.githubusercontent.com/tfhe/spqlios-fft/main/manifest.yaml");
|
||||||
|
if ($exit_code!=0 or ! -f $origmanifestfile) {
|
||||||
|
die("ERROR: failed to download manifest.yaml");
|
||||||
|
}
|
||||||
|
$exit_code=system("wget -O $origchangelogfile https://raw.githubusercontent.com/tfhe/spqlios-fft/main/Changelog.md");
|
||||||
|
if ($exit_code!=0 or ! -f $origchangelogfile) {
|
||||||
|
die("ERROR: failed to download Changelog.md");
|
||||||
|
}
|
||||||
|
|
||||||
|
# read the current version (from origin/main manifest)
|
||||||
|
my $vmajor = 0;
|
||||||
|
my $vminor = 0;
|
||||||
|
my $vpatch = 0;
|
||||||
|
my $versionline = `grep '^version: ' $origmanifestfile | cut -d" " -f2`;
|
||||||
|
chop $versionline;
|
||||||
|
if (not $versionline =~ /^([0-9]+)\.([0-9]+)\.([0-9]+)$/) {
|
||||||
|
die("ERROR: invalid version in manifest file: $versionline\n");
|
||||||
|
} else {
|
||||||
|
$vmajor = int($1);
|
||||||
|
$vminor = int($2);
|
||||||
|
$vpatch = int($3);
|
||||||
|
}
|
||||||
|
print "Version in manifest file: $vmajor.$vminor.$vpatch\n";
|
||||||
|
|
||||||
|
if (not -f $changefile) {
|
||||||
|
## create a changes file
|
||||||
|
open F,">$changefile";
|
||||||
|
print F "# Changefile for branch $currentbranch\n\n";
|
||||||
|
print F "## Type of release (major,minor,patch)?\n\n";
|
||||||
|
print F "releasetype: patch\n\n";
|
||||||
|
print F "## What has changed (please edit)?\n\n";
|
||||||
|
print F "- This has changed.\n";
|
||||||
|
close F;
|
||||||
|
}
|
||||||
|
|
||||||
|
system("editor $changefile");
|
||||||
|
|
||||||
|
# compute the new version
|
||||||
|
my $nvmajor;
|
||||||
|
my $nvminor;
|
||||||
|
my $nvpatch;
|
||||||
|
my $changelog;
|
||||||
|
my $recordchangelog=0;
|
||||||
|
open F,"$changefile";
|
||||||
|
while ($line=<F>) {
|
||||||
|
chop $line;
|
||||||
|
if ($recordchangelog) {
|
||||||
|
($line =~ /^$/) and next;
|
||||||
|
$changelog .= "$line\n";
|
||||||
|
next;
|
||||||
|
}
|
||||||
|
if ($line =~ /^releasetype *: *patch *$/) {
|
||||||
|
$nvmajor=$vmajor;
|
||||||
|
$nvminor=$vminor;
|
||||||
|
$nvpatch=$vpatch+1;
|
||||||
|
}
|
||||||
|
if ($line =~ /^releasetype *: *minor *$/) {
|
||||||
|
$nvmajor=$vmajor;
|
||||||
|
$nvminor=$vminor+1;
|
||||||
|
$nvpatch=0;
|
||||||
|
}
|
||||||
|
if ($line =~ /^releasetype *: *major *$/) {
|
||||||
|
$nvmajor=$vmajor+1;
|
||||||
|
$nvminor=0;
|
||||||
|
$nvpatch=0;
|
||||||
|
}
|
||||||
|
if ($line =~ /^## What has changed/) {
|
||||||
|
$recordchangelog=1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
close F;
|
||||||
|
print "New version: $nvmajor.$nvminor.$nvpatch\n";
|
||||||
|
print "Changes:\n$changelog";
|
||||||
|
|
||||||
|
# updating manifest.yaml
|
||||||
|
open F,"manifest.yaml";
|
||||||
|
open G,">.changes/manifest.yaml";
|
||||||
|
while ($line=<F>) {
|
||||||
|
if ($line =~ /^version *: */) {
|
||||||
|
print G "version: $nvmajor.$nvminor.$nvpatch\n";
|
||||||
|
next;
|
||||||
|
}
|
||||||
|
print G $line;
|
||||||
|
}
|
||||||
|
close F;
|
||||||
|
close G;
|
||||||
|
# updating Changelog.md
|
||||||
|
open F,"$origchangelogfile";
|
||||||
|
open G,">.changes/Changelog.md";
|
||||||
|
print G <<EOF
|
||||||
|
# Changelog
|
||||||
|
|
||||||
|
All notable changes to this project will be documented in this file.
|
||||||
|
This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
|
EOF
|
||||||
|
;
|
||||||
|
print G "## [$nvmajor.$nvminor.$nvpatch] - ".`date '+%Y-%m-%d'`."\n";
|
||||||
|
print G "$changelog\n";
|
||||||
|
my $skip_section=1;
|
||||||
|
while ($line=<F>) {
|
||||||
|
if ($line =~ /^## +\[([0-9]+)\.([0-9]+)\.([0-9]+)\] +/) {
|
||||||
|
if ($1>$nvmajor) {
|
||||||
|
die("ERROR: found larger version $1.$2.$3 in the Changelog.md\n");
|
||||||
|
} elsif ($1<$nvmajor) {
|
||||||
|
$skip_section=0;
|
||||||
|
} elsif ($2>$nvminor) {
|
||||||
|
die("ERROR: found larger version $1.$2.$3 in the Changelog.md\n");
|
||||||
|
} elsif ($2<$nvminor) {
|
||||||
|
$skip_section=0;
|
||||||
|
} elsif ($3>$nvpatch) {
|
||||||
|
die("ERROR: found larger version $1.$2.$3 in the Changelog.md\n");
|
||||||
|
} elsif ($3<$nvpatch) {
|
||||||
|
$skip_section=0;
|
||||||
|
} else {
|
||||||
|
$skip_section=1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
($skip_section) and next;
|
||||||
|
print G $line;
|
||||||
|
}
|
||||||
|
close F;
|
||||||
|
close G;
|
||||||
|
|
||||||
|
print "-------------------------------------\n";
|
||||||
|
print "THIS WILL BE UPDATED:\n";
|
||||||
|
print "-------------------------------------\n";
|
||||||
|
system("diff -u manifest.yaml .changes/manifest.yaml");
|
||||||
|
system("diff -u Changelog.md .changes/Changelog.md");
|
||||||
|
print "-------------------------------------\n";
|
||||||
|
print "To proceed: press <enter> otherwise <CTRL+C>\n";
|
||||||
|
my $bla;
|
||||||
|
$bla=<STDIN>;
|
||||||
|
system("cp -vf .changes/manifest.yaml manifest.yaml");
|
||||||
|
system("cp -vf .changes/Changelog.md Changelog.md");
|
||||||
|
system("git commit -a -m \"Update version and changelog.\"");
|
||||||
|
system("git push");
|
||||||
|
print("✓ Changes have been committed and pushed!\n");
|
||||||
|
print("✓ A new release will be created when this branch is merged to main.\n");
|
||||||
|
|
||||||
@@ -0,0 +1,223 @@
|
|||||||
|
enable_language(ASM)
|
||||||
|
|
||||||
|
# C source files that are compiled for all targets (i.e. reference code)
|
||||||
|
set(SRCS_GENERIC
|
||||||
|
commons.c
|
||||||
|
commons_private.c
|
||||||
|
coeffs/coeffs_arithmetic.c
|
||||||
|
arithmetic/vec_znx.c
|
||||||
|
arithmetic/vec_znx_dft.c
|
||||||
|
arithmetic/vector_matrix_product.c
|
||||||
|
cplx/cplx_common.c
|
||||||
|
cplx/cplx_conversions.c
|
||||||
|
cplx/cplx_fft_asserts.c
|
||||||
|
cplx/cplx_fft_ref.c
|
||||||
|
cplx/cplx_fftvec_ref.c
|
||||||
|
cplx/cplx_ifft_ref.c
|
||||||
|
cplx/spqlios_cplx_fft.c
|
||||||
|
reim4/reim4_arithmetic_ref.c
|
||||||
|
reim4/reim4_fftvec_addmul_ref.c
|
||||||
|
reim4/reim4_fftvec_conv_ref.c
|
||||||
|
reim/reim_conversions.c
|
||||||
|
reim/reim_fft_ifft.c
|
||||||
|
reim/reim_fft_ref.c
|
||||||
|
reim/reim_fftvec_ref.c
|
||||||
|
reim/reim_ifft_ref.c
|
||||||
|
|
||||||
|
reim/reim_to_tnx_ref.c
|
||||||
|
q120/q120_ntt.c
|
||||||
|
q120/q120_arithmetic_ref.c
|
||||||
|
q120/q120_arithmetic_simple.c
|
||||||
|
arithmetic/scalar_vector_product.c
|
||||||
|
arithmetic/vec_znx_big.c
|
||||||
|
arithmetic/znx_small.c
|
||||||
|
arithmetic/module_api.c
|
||||||
|
arithmetic/zn_vmp_int8_ref.c
|
||||||
|
arithmetic/zn_vmp_int16_ref.c
|
||||||
|
arithmetic/zn_vmp_int32_ref.c
|
||||||
|
arithmetic/zn_vmp_ref.c
|
||||||
|
arithmetic/zn_api.c
|
||||||
|
arithmetic/zn_conversions_ref.c
|
||||||
|
arithmetic/zn_approxdecomp_ref.c
|
||||||
|
arithmetic/vec_rnx_api.c
|
||||||
|
arithmetic/vec_rnx_conversions_ref.c
|
||||||
|
arithmetic/vec_rnx_svp_ref.c
|
||||||
|
reim/reim_execute.c
|
||||||
|
cplx/cplx_execute.c
|
||||||
|
reim4/reim4_execute.c
|
||||||
|
arithmetic/vec_rnx_arithmetic.c
|
||||||
|
arithmetic/vec_rnx_approxdecomp_ref.c
|
||||||
|
arithmetic/vec_rnx_vmp_ref.c
|
||||||
|
)
|
||||||
|
# C or assembly source files compiled only on x86 targets
|
||||||
|
set(SRCS_X86
|
||||||
|
)
|
||||||
|
# C or assembly source files compiled only on aarch64 targets
|
||||||
|
set(SRCS_AARCH64
|
||||||
|
cplx/cplx_fallbacks_aarch64.c
|
||||||
|
reim/reim_fallbacks_aarch64.c
|
||||||
|
reim4/reim4_fallbacks_aarch64.c
|
||||||
|
q120/q120_fallbacks_aarch64.c
|
||||||
|
reim/reim_fft_neon.c
|
||||||
|
)
|
||||||
|
|
||||||
|
# C or assembly source files compiled only on x86: avx, avx2, fma targets
|
||||||
|
set(SRCS_FMA_C
|
||||||
|
arithmetic/vector_matrix_product_avx.c
|
||||||
|
cplx/cplx_conversions_avx2_fma.c
|
||||||
|
cplx/cplx_fft_avx2_fma.c
|
||||||
|
cplx/cplx_fft_sse.c
|
||||||
|
cplx/cplx_fftvec_avx2_fma.c
|
||||||
|
cplx/cplx_ifft_avx2_fma.c
|
||||||
|
reim4/reim4_arithmetic_avx2.c
|
||||||
|
reim4/reim4_fftvec_conv_fma.c
|
||||||
|
reim4/reim4_fftvec_addmul_fma.c
|
||||||
|
reim/reim_conversions_avx.c
|
||||||
|
reim/reim_fft4_avx_fma.c
|
||||||
|
reim/reim_fft8_avx_fma.c
|
||||||
|
reim/reim_ifft4_avx_fma.c
|
||||||
|
reim/reim_ifft8_avx_fma.c
|
||||||
|
reim/reim_fft_avx2.c
|
||||||
|
reim/reim_ifft_avx2.c
|
||||||
|
reim/reim_to_tnx_avx.c
|
||||||
|
reim/reim_fftvec_fma.c
|
||||||
|
)
|
||||||
|
set(SRCS_FMA_ASM
|
||||||
|
cplx/cplx_fft16_avx_fma.s
|
||||||
|
cplx/cplx_ifft16_avx_fma.s
|
||||||
|
reim/reim_fft16_avx_fma.s
|
||||||
|
reim/reim_ifft16_avx_fma.s
|
||||||
|
)
|
||||||
|
set(SRCS_FMA_WIN32_ASM
|
||||||
|
cplx/cplx_fft16_avx_fma_win32.s
|
||||||
|
cplx/cplx_ifft16_avx_fma_win32.s
|
||||||
|
reim/reim_fft16_avx_fma_win32.s
|
||||||
|
reim/reim_ifft16_avx_fma_win32.s
|
||||||
|
)
|
||||||
|
set_source_files_properties(${SRCS_FMA_C} PROPERTIES COMPILE_OPTIONS "-mfma;-mavx;-mavx2")
|
||||||
|
set_source_files_properties(${SRCS_FMA_ASM} PROPERTIES COMPILE_OPTIONS "-mfma;-mavx;-mavx2")
|
||||||
|
|
||||||
|
# C or assembly source files compiled only on x86: avx512f/vl/dq + fma targets
|
||||||
|
set(SRCS_AVX512
|
||||||
|
cplx/cplx_fft_avx512.c
|
||||||
|
)
|
||||||
|
set_source_files_properties(${SRCS_AVX512} PROPERTIES COMPILE_OPTIONS "-mfma;-mavx512f;-mavx512vl;-mavx512dq")
|
||||||
|
|
||||||
|
# C or assembly source files compiled only on x86: avx2 + bmi targets
|
||||||
|
set(SRCS_AVX2
|
||||||
|
arithmetic/vec_znx_avx.c
|
||||||
|
coeffs/coeffs_arithmetic_avx.c
|
||||||
|
arithmetic/vec_znx_dft_avx2.c
|
||||||
|
arithmetic/zn_vmp_int8_avx.c
|
||||||
|
arithmetic/zn_vmp_int16_avx.c
|
||||||
|
arithmetic/zn_vmp_int32_avx.c
|
||||||
|
q120/q120_arithmetic_avx2.c
|
||||||
|
q120/q120_ntt_avx2.c
|
||||||
|
arithmetic/vec_rnx_arithmetic_avx.c
|
||||||
|
arithmetic/vec_rnx_approxdecomp_avx.c
|
||||||
|
arithmetic/vec_rnx_vmp_avx.c
|
||||||
|
|
||||||
|
)
|
||||||
|
set_source_files_properties(${SRCS_AVX2} PROPERTIES COMPILE_OPTIONS "-mbmi2;-mavx2")
|
||||||
|
|
||||||
|
# C source files using float128 via libquadmath on x86 targets
|
||||||
|
set(SRCS_F128
|
||||||
|
cplx_f128/cplx_fft_f128.c
|
||||||
|
cplx_f128/cplx_fft_f128.h
|
||||||
|
)
|
||||||
|
|
||||||
|
# H header files containing the public API (these headers are installed)
|
||||||
|
set(HEADERSPUBLIC
|
||||||
|
commons.h
|
||||||
|
arithmetic/vec_znx_arithmetic.h
|
||||||
|
arithmetic/vec_rnx_arithmetic.h
|
||||||
|
arithmetic/zn_arithmetic.h
|
||||||
|
cplx/cplx_fft.h
|
||||||
|
reim/reim_fft.h
|
||||||
|
q120/q120_common.h
|
||||||
|
q120/q120_arithmetic.h
|
||||||
|
q120/q120_ntt.h
|
||||||
|
)
|
||||||
|
|
||||||
|
# H header files containing the private API (these headers are used internally)
|
||||||
|
set(HEADERSPRIVATE
|
||||||
|
commons_private.h
|
||||||
|
cplx/cplx_fft_internal.h
|
||||||
|
cplx/cplx_fft_private.h
|
||||||
|
reim4/reim4_arithmetic.h
|
||||||
|
reim4/reim4_fftvec_internal.h
|
||||||
|
reim4/reim4_fftvec_private.h
|
||||||
|
reim4/reim4_fftvec_public.h
|
||||||
|
reim/reim_fft_internal.h
|
||||||
|
reim/reim_fft_private.h
|
||||||
|
q120/q120_arithmetic_private.h
|
||||||
|
q120/q120_ntt_private.h
|
||||||
|
arithmetic/vec_znx_arithmetic.h
|
||||||
|
arithmetic/vec_rnx_arithmetic_private.h
|
||||||
|
arithmetic/vec_rnx_arithmetic_plugin.h
|
||||||
|
arithmetic/zn_arithmetic_private.h
|
||||||
|
arithmetic/zn_arithmetic_plugin.h
|
||||||
|
coeffs/coeffs_arithmetic.h
|
||||||
|
reim/reim_fft_core_template.h
|
||||||
|
)
|
||||||
|
|
||||||
|
set(SPQLIOSSOURCES
|
||||||
|
${SRCS_GENERIC}
|
||||||
|
${HEADERSPUBLIC}
|
||||||
|
${HEADERSPRIVATE}
|
||||||
|
)
|
||||||
|
if (${X86})
|
||||||
|
set(SPQLIOSSOURCES ${SPQLIOSSOURCES}
|
||||||
|
${SRCS_X86}
|
||||||
|
${SRCS_FMA_C}
|
||||||
|
${SRCS_FMA_ASM}
|
||||||
|
${SRCS_AVX2}
|
||||||
|
${SRCS_AVX512}
|
||||||
|
)
|
||||||
|
elseif (${X86_WIN32})
|
||||||
|
set(SPQLIOSSOURCES ${SPQLIOSSOURCES}
|
||||||
|
#${SRCS_X86}
|
||||||
|
${SRCS_FMA_C}
|
||||||
|
${SRCS_FMA_WIN32_ASM}
|
||||||
|
${SRCS_AVX2}
|
||||||
|
${SRCS_AVX512}
|
||||||
|
)
|
||||||
|
elseif (${AARCH64})
|
||||||
|
set(SPQLIOSSOURCES ${SPQLIOSSOURCES}
|
||||||
|
${SRCS_AARCH64}
|
||||||
|
)
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
|
||||||
|
set(SPQLIOSLIBDEP
|
||||||
|
m # libm dependency for cosine/sine functions
|
||||||
|
)
|
||||||
|
|
||||||
|
if (ENABLE_SPQLIOS_F128)
|
||||||
|
find_library(quadmath REQUIRED NAMES quadmath)
|
||||||
|
set(SPQLIOSSOURCES ${SPQLIOSSOURCES} ${SRCS_F128})
|
||||||
|
set(SPQLIOSLIBDEP ${SPQLIOSLIBDEP} quadmath)
|
||||||
|
endif (ENABLE_SPQLIOS_F128)
|
||||||
|
|
||||||
|
add_library(libspqlios-static STATIC ${SPQLIOSSOURCES})
|
||||||
|
add_library(libspqlios SHARED ${SPQLIOSSOURCES})
|
||||||
|
set_property(TARGET libspqlios-static PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||||
|
set_property(TARGET libspqlios PROPERTY OUTPUT_NAME spqlios)
|
||||||
|
set_property(TARGET libspqlios-static PROPERTY OUTPUT_NAME spqlios)
|
||||||
|
set_property(TARGET libspqlios PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||||
|
set_property(TARGET libspqlios PROPERTY SOVERSION ${SPQLIOS_VERSION_MAJOR})
|
||||||
|
set_property(TARGET libspqlios PROPERTY VERSION ${SPQLIOS_VERSION})
|
||||||
|
if (NOT APPLE)
|
||||||
|
target_link_options(libspqlios-static PUBLIC -Wl,--no-undefined)
|
||||||
|
target_link_options(libspqlios PUBLIC -Wl,--no-undefined)
|
||||||
|
endif()
|
||||||
|
target_link_libraries(libspqlios ${SPQLIOSLIBDEP})
|
||||||
|
target_link_libraries(libspqlios-static ${SPQLIOSLIBDEP})
|
||||||
|
install(TARGETS libspqlios-static)
|
||||||
|
install(TARGETS libspqlios)
|
||||||
|
|
||||||
|
# install the public headers only
|
||||||
|
foreach (file ${HEADERSPUBLIC})
|
||||||
|
get_filename_component(dir ${file} DIRECTORY)
|
||||||
|
install(FILES ${file} DESTINATION include/spqlios/${dir})
|
||||||
|
endforeach ()
|
||||||
@@ -0,0 +1,172 @@
|
|||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "vec_znx_arithmetic_private.h"
|
||||||
|
|
||||||
|
static void fill_generic_virtual_table(MODULE* module) {
|
||||||
|
// TODO add default ref handler here
|
||||||
|
module->func.vec_znx_zero = vec_znx_zero_ref;
|
||||||
|
module->func.vec_znx_copy = vec_znx_copy_ref;
|
||||||
|
module->func.vec_znx_negate = vec_znx_negate_ref;
|
||||||
|
module->func.vec_znx_add = vec_znx_add_ref;
|
||||||
|
module->func.vec_znx_sub = vec_znx_sub_ref;
|
||||||
|
module->func.vec_znx_rotate = vec_znx_rotate_ref;
|
||||||
|
module->func.vec_znx_mul_xp_minus_one = vec_znx_mul_xp_minus_one_ref;
|
||||||
|
module->func.vec_znx_automorphism = vec_znx_automorphism_ref;
|
||||||
|
module->func.vec_znx_normalize_base2k = vec_znx_normalize_base2k_ref;
|
||||||
|
module->func.vec_znx_normalize_base2k_tmp_bytes = vec_znx_normalize_base2k_tmp_bytes_ref;
|
||||||
|
if (CPU_SUPPORTS("avx2")) {
|
||||||
|
// TODO add avx handlers here
|
||||||
|
module->func.vec_znx_negate = vec_znx_negate_avx;
|
||||||
|
module->func.vec_znx_add = vec_znx_add_avx;
|
||||||
|
module->func.vec_znx_sub = vec_znx_sub_avx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void fill_fft64_virtual_table(MODULE* module) {
|
||||||
|
// TODO add default ref handler here
|
||||||
|
// module->func.vec_znx_dft = ...;
|
||||||
|
module->func.vec_znx_big_normalize_base2k = fft64_vec_znx_big_normalize_base2k;
|
||||||
|
module->func.vec_znx_big_normalize_base2k_tmp_bytes = fft64_vec_znx_big_normalize_base2k_tmp_bytes;
|
||||||
|
module->func.vec_znx_big_range_normalize_base2k = fft64_vec_znx_big_range_normalize_base2k;
|
||||||
|
module->func.vec_znx_big_range_normalize_base2k_tmp_bytes = fft64_vec_znx_big_range_normalize_base2k_tmp_bytes;
|
||||||
|
module->func.vec_znx_dft = fft64_vec_znx_dft;
|
||||||
|
module->func.vec_znx_idft = fft64_vec_znx_idft;
|
||||||
|
module->func.vec_dft_add = fft64_vec_dft_add;
|
||||||
|
module->func.vec_dft_sub = fft64_vec_dft_sub;
|
||||||
|
module->func.vec_znx_idft_tmp_bytes = fft64_vec_znx_idft_tmp_bytes;
|
||||||
|
module->func.vec_znx_idft_tmp_a = fft64_vec_znx_idft_tmp_a;
|
||||||
|
module->func.vec_znx_big_add = fft64_vec_znx_big_add;
|
||||||
|
module->func.vec_znx_big_add_small = fft64_vec_znx_big_add_small;
|
||||||
|
module->func.vec_znx_big_add_small2 = fft64_vec_znx_big_add_small2;
|
||||||
|
module->func.vec_znx_big_sub = fft64_vec_znx_big_sub;
|
||||||
|
module->func.vec_znx_big_sub_small_a = fft64_vec_znx_big_sub_small_a;
|
||||||
|
module->func.vec_znx_big_sub_small_b = fft64_vec_znx_big_sub_small_b;
|
||||||
|
module->func.vec_znx_big_sub_small2 = fft64_vec_znx_big_sub_small2;
|
||||||
|
module->func.vec_znx_big_rotate = fft64_vec_znx_big_rotate;
|
||||||
|
module->func.vec_znx_big_automorphism = fft64_vec_znx_big_automorphism;
|
||||||
|
module->func.svp_prepare = fft64_svp_prepare_ref;
|
||||||
|
module->func.svp_apply_dft = fft64_svp_apply_dft_ref;
|
||||||
|
module->func.svp_apply_dft_to_dft = fft64_svp_apply_dft_to_dft_ref;
|
||||||
|
module->func.znx_small_single_product = fft64_znx_small_single_product;
|
||||||
|
module->func.znx_small_single_product_tmp_bytes = fft64_znx_small_single_product_tmp_bytes;
|
||||||
|
module->func.vmp_prepare_contiguous = fft64_vmp_prepare_contiguous_ref;
|
||||||
|
module->func.vmp_prepare_tmp_bytes = fft64_vmp_prepare_tmp_bytes;
|
||||||
|
module->func.vmp_apply_dft = fft64_vmp_apply_dft_ref;
|
||||||
|
module->func.vmp_apply_dft_add = fft64_vmp_apply_dft_add_ref;
|
||||||
|
module->func.vmp_apply_dft_tmp_bytes = fft64_vmp_apply_dft_tmp_bytes;
|
||||||
|
module->func.vmp_apply_dft_to_dft = fft64_vmp_apply_dft_to_dft_ref;
|
||||||
|
module->func.vmp_apply_dft_to_dft_add = fft64_vmp_apply_dft_to_dft_add_ref;
|
||||||
|
module->func.vmp_apply_dft_to_dft_tmp_bytes = fft64_vmp_apply_dft_to_dft_tmp_bytes;
|
||||||
|
module->func.bytes_of_vec_znx_dft = fft64_bytes_of_vec_znx_dft;
|
||||||
|
module->func.bytes_of_vec_znx_big = fft64_bytes_of_vec_znx_big;
|
||||||
|
module->func.bytes_of_svp_ppol = fft64_bytes_of_svp_ppol;
|
||||||
|
module->func.bytes_of_vmp_pmat = fft64_bytes_of_vmp_pmat;
|
||||||
|
if (CPU_SUPPORTS("avx2")) {
|
||||||
|
// TODO add avx handlers here
|
||||||
|
// TODO: enable when avx implementation is done
|
||||||
|
module->func.vmp_prepare_contiguous = fft64_vmp_prepare_contiguous_avx;
|
||||||
|
module->func.vmp_apply_dft = fft64_vmp_apply_dft_avx;
|
||||||
|
module->func.vmp_apply_dft_add = fft64_vmp_apply_dft_add_avx;
|
||||||
|
module->func.vmp_apply_dft_to_dft = fft64_vmp_apply_dft_to_dft_avx;
|
||||||
|
module->func.vmp_apply_dft_to_dft_add = fft64_vmp_apply_dft_to_dft_add_avx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void fill_ntt120_virtual_table(MODULE* module) {
|
||||||
|
// TODO add default ref handler here
|
||||||
|
// module->func.vec_znx_dft = ...;
|
||||||
|
if (CPU_SUPPORTS("avx2")) {
|
||||||
|
// TODO add avx handlers here
|
||||||
|
module->func.vec_znx_dft = ntt120_vec_znx_dft_avx;
|
||||||
|
module->func.vec_znx_idft = ntt120_vec_znx_idft_avx;
|
||||||
|
module->func.vec_znx_idft_tmp_bytes = ntt120_vec_znx_idft_tmp_bytes_avx;
|
||||||
|
module->func.vec_znx_idft_tmp_a = ntt120_vec_znx_idft_tmp_a_avx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void fill_virtual_table(MODULE* module) {
|
||||||
|
fill_generic_virtual_table(module);
|
||||||
|
switch (module->module_type) {
|
||||||
|
case FFT64:
|
||||||
|
fill_fft64_virtual_table(module);
|
||||||
|
break;
|
||||||
|
case NTT120:
|
||||||
|
fill_ntt120_virtual_table(module);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
NOT_SUPPORTED(); // invalid type
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void fill_fft64_precomp(MODULE* module) {
|
||||||
|
// fill any necessary precomp stuff
|
||||||
|
module->mod.fft64.p_conv = new_reim_from_znx64_precomp(module->m, 50);
|
||||||
|
module->mod.fft64.p_fft = new_reim_fft_precomp(module->m, 0);
|
||||||
|
module->mod.fft64.p_reim_to_znx = new_reim_to_znx64_precomp(module->m, module->m, 63);
|
||||||
|
module->mod.fft64.p_ifft = new_reim_ifft_precomp(module->m, 0);
|
||||||
|
module->mod.fft64.p_addmul = new_reim_fftvec_addmul_precomp(module->m);
|
||||||
|
module->mod.fft64.mul_fft = new_reim_fftvec_mul_precomp(module->m);
|
||||||
|
module->mod.fft64.add_fft = new_reim_fftvec_add_precomp(module->m);
|
||||||
|
module->mod.fft64.sub_fft = new_reim_fftvec_sub_precomp(module->m);
|
||||||
|
}
|
||||||
|
static void fill_ntt120_precomp(MODULE* module) {
|
||||||
|
// fill any necessary precomp stuff
|
||||||
|
if (CPU_SUPPORTS("avx2")) {
|
||||||
|
module->mod.q120.p_ntt = q120_new_ntt_bb_precomp(module->nn);
|
||||||
|
module->mod.q120.p_intt = q120_new_intt_bb_precomp(module->nn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void fill_module_precomp(MODULE* module) {
|
||||||
|
switch (module->module_type) {
|
||||||
|
case FFT64:
|
||||||
|
fill_fft64_precomp(module);
|
||||||
|
break;
|
||||||
|
case NTT120:
|
||||||
|
fill_ntt120_precomp(module);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
NOT_SUPPORTED(); // invalid type
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void fill_module(MODULE* module, uint64_t nn, MODULE_TYPE mtype) {
|
||||||
|
// init to zero to ensure that any non-initialized field bug is detected
|
||||||
|
// by at least a "proper" segfault
|
||||||
|
memset(module, 0, sizeof(MODULE));
|
||||||
|
module->module_type = mtype;
|
||||||
|
module->nn = nn;
|
||||||
|
module->m = nn >> 1;
|
||||||
|
fill_module_precomp(module);
|
||||||
|
fill_virtual_table(module);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT MODULE* new_module_info(uint64_t N, MODULE_TYPE mtype) {
|
||||||
|
MODULE* m = (MODULE*)malloc(sizeof(MODULE));
|
||||||
|
fill_module(m, N, mtype);
|
||||||
|
return m;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void delete_module_info(MODULE* mod) {
|
||||||
|
switch (mod->module_type) {
|
||||||
|
case FFT64:
|
||||||
|
free(mod->mod.fft64.p_conv);
|
||||||
|
free(mod->mod.fft64.p_fft);
|
||||||
|
free(mod->mod.fft64.p_ifft);
|
||||||
|
free(mod->mod.fft64.p_reim_to_znx);
|
||||||
|
free(mod->mod.fft64.mul_fft);
|
||||||
|
free(mod->mod.fft64.p_addmul);
|
||||||
|
break;
|
||||||
|
case NTT120:
|
||||||
|
if (CPU_SUPPORTS("avx2")) {
|
||||||
|
q120_del_ntt_bb_precomp(mod->mod.q120.p_ntt);
|
||||||
|
q120_del_intt_bb_precomp(mod->mod.q120.p_intt);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
free(mod);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint64_t module_get_n(const MODULE* module) { return module->nn; }
|
||||||
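For orientation, here is a minimal usage sketch of the module lifecycle defined above (illustrative only, not part of this commit). It assumes the public headers are installed under the include/spqlios prefix declared in the CMake rules and that 1024 is a supported ring dimension:

#include <spqlios/arithmetic/vec_znx_arithmetic.h>  /* hypothetical include path, per the install rules */

int demo_module_lifecycle(void) {
  /* allocate and fully initialize a module for the FFT64 backend */
  MODULE* m = new_module_info(1024, FFT64);
  if (!m) return 1;
  /* the virtual table and FFT precomputations are now filled in */
  uint64_t n = module_get_n(m);  /* returns 1024 */
  /* ... call the vec_znx / svp / vmp entry points through this module ... */
  delete_module_info(m);         /* releases the precomputed tables */
  return (int)(n != 1024);
}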
@@ -0,0 +1,102 @@
|
|||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "vec_znx_arithmetic_private.h"
|
||||||
|
|
||||||
|
EXPORT uint64_t bytes_of_svp_ppol(const MODULE* module) { return module->func.bytes_of_svp_ppol(module); }
|
||||||
|
|
||||||
|
EXPORT uint64_t fft64_bytes_of_svp_ppol(const MODULE* module) { return module->nn * sizeof(double); }
|
||||||
|
|
||||||
|
EXPORT SVP_PPOL* new_svp_ppol(const MODULE* module) { return spqlios_alloc(bytes_of_svp_ppol(module)); }
|
||||||
|
|
||||||
|
EXPORT void delete_svp_ppol(SVP_PPOL* ppol) { spqlios_free(ppol); }
|
||||||
|
|
||||||
|
// public wrappers
|
||||||
|
EXPORT void svp_prepare(const MODULE* module, // N
|
||||||
|
SVP_PPOL* ppol, // output
|
||||||
|
const int64_t* pol // a
|
||||||
|
) {
|
||||||
|
module->func.svp_prepare(module, ppol, pol);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief prepares a svp polynomial */
|
||||||
|
EXPORT void fft64_svp_prepare_ref(const MODULE* module, // N
|
||||||
|
SVP_PPOL* ppol, // output
|
||||||
|
const int64_t* pol // a
|
||||||
|
) {
|
||||||
|
reim_from_znx64(module->mod.fft64.p_conv, ppol, pol);
|
||||||
|
reim_fft(module->mod.fft64.p_fft, (double*)ppol);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void svp_apply_dft(const MODULE* module, // N
|
||||||
|
const VEC_ZNX_DFT* res, uint64_t res_size, // output
|
||||||
|
const SVP_PPOL* ppol, // prepared pol
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl) {
|
||||||
|
module->func.svp_apply_dft(module, // N
|
||||||
|
res,
|
||||||
|
res_size, // output
|
||||||
|
ppol, // prepared pol
|
||||||
|
a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void svp_apply_dft_to_dft(const MODULE* module, // N
|
||||||
|
const VEC_ZNX_DFT* res, uint64_t res_size,
|
||||||
|
uint64_t res_cols, // output
|
||||||
|
const SVP_PPOL* ppol, // prepared pol
|
||||||
|
const VEC_ZNX_DFT* a, uint64_t a_size, uint64_t a_cols) {
|
||||||
|
module->func.svp_apply_dft_to_dft(module, // N
|
||||||
|
res, res_size, res_cols, // output
|
||||||
|
ppol, a, a_size, a_cols // prepared pol
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// result = ppol * a
|
||||||
|
EXPORT void fft64_svp_apply_dft_ref(const MODULE* module, // N
|
||||||
|
const VEC_ZNX_DFT* res, uint64_t res_size, // output
|
||||||
|
const SVP_PPOL* ppol, // prepared pol
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
double* const dres = (double*)res;
|
||||||
|
double* const dppol = (double*)ppol;
|
||||||
|
|
||||||
|
const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < auto_end_idx; ++i) {
|
||||||
|
const int64_t* a_ptr = a + i * a_sl;
|
||||||
|
double* const res_ptr = dres + i * nn;
|
||||||
|
// copy the polynomial to res, apply fft in place, call fftvec_mul in place.
|
||||||
|
reim_from_znx64(module->mod.fft64.p_conv, res_ptr, a_ptr);
|
||||||
|
reim_fft(module->mod.fft64.p_fft, res_ptr);
|
||||||
|
reim_fftvec_mul(module->mod.fft64.mul_fft, res_ptr, res_ptr, dppol);
|
||||||
|
}
|
||||||
|
|
||||||
|
// then extend with zeros
|
||||||
|
memset(dres + auto_end_idx * nn, 0, (res_size - auto_end_idx) * nn * sizeof(double));
|
||||||
|
}
|
||||||
|
|
||||||
|
// result = ppol * a
|
||||||
|
EXPORT void fft64_svp_apply_dft_to_dft_ref(const MODULE* module, // N
|
||||||
|
const VEC_ZNX_DFT* res, uint64_t res_size,
|
||||||
|
uint64_t res_cols, // output
|
||||||
|
const SVP_PPOL* ppol, // prepared pol
|
||||||
|
const VEC_ZNX_DFT* a, uint64_t a_size,
|
||||||
|
uint64_t a_cols // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
const uint64_t res_sl = nn * res_cols;
|
||||||
|
const uint64_t a_sl = nn * a_cols;
|
||||||
|
double* const dres = (double*)res;
|
||||||
|
double* const da = (double*)a;
|
||||||
|
double* const dppol = (double*)ppol;
|
||||||
|
|
||||||
|
const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < auto_end_idx; ++i) {
|
||||||
|
const double* a_ptr = da + i * a_sl;
|
||||||
|
double* const res_ptr = dres + i * res_sl;
|
||||||
|
reim_fftvec_mul(module->mod.fft64.mul_fft, res_ptr, a_ptr, dppol);
|
||||||
|
}
|
||||||
|
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = auto_end_idx; i < res_size; i++) {
|
||||||
|
memset(dres + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
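A short usage sketch of the SVP entry points above (illustrative, not part of the diff): the prepared polynomial is converted to the FFT domain once and can then be reused for many products. Allocation of the VEC_ZNX_DFT result buffer is elided since its allocator is defined elsewhere, and the ring dimension 1024 is an assumption:

void demo_svp(MODULE* m, VEC_ZNX_DFT* res, uint64_t res_size,
              const int64_t* a, uint64_t a_size, uint64_t a_sl) {
  SVP_PPOL* ppol = new_svp_ppol(m);   /* nn doubles for the FFT64 backend */
  int64_t pol[1024] = {0};            /* assumes module_get_n(m) == 1024 */
  pol[0] = 3;                         /* the constant polynomial 3 */
  svp_prepare(m, ppol, pol);          /* one-time conversion to the FFT domain */
  svp_apply_dft(m, res, res_size, ppol, a, a_size, a_sl);  /* res = ppol * a in DFT domain */
  delete_svp_ppol(ppol);
}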
@@ -0,0 +1,344 @@
|
|||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "vec_rnx_arithmetic_private.h"
|
||||||
|
|
||||||
|
void fft64_init_rnx_module_precomp(MOD_RNX* module) {
|
||||||
|
// Add here initialization of items that are in the precomp
|
||||||
|
const uint64_t m = module->m;
|
||||||
|
module->precomp.fft64.p_fft = new_reim_fft_precomp(m, 0);
|
||||||
|
module->precomp.fft64.p_ifft = new_reim_ifft_precomp(m, 0);
|
||||||
|
module->precomp.fft64.p_fftvec_add = new_reim_fftvec_add_precomp(m);
|
||||||
|
module->precomp.fft64.p_fftvec_mul = new_reim_fftvec_mul_precomp(m);
|
||||||
|
module->precomp.fft64.p_fftvec_addmul = new_reim_fftvec_addmul_precomp(m);
|
||||||
|
}
|
||||||
|
|
||||||
|
void fft64_finalize_rnx_module_precomp(MOD_RNX* module) {
|
||||||
|
// Add here deleters for items that are in the precomp
|
||||||
|
delete_reim_fft_precomp(module->precomp.fft64.p_fft);
|
||||||
|
delete_reim_ifft_precomp(module->precomp.fft64.p_ifft);
|
||||||
|
delete_reim_fftvec_add_precomp(module->precomp.fft64.p_fftvec_add);
|
||||||
|
delete_reim_fftvec_mul_precomp(module->precomp.fft64.p_fftvec_mul);
|
||||||
|
delete_reim_fftvec_addmul_precomp(module->precomp.fft64.p_fftvec_addmul);
|
||||||
|
}
|
||||||
|
|
||||||
|
void fft64_init_rnx_module_vtable(MOD_RNX* module) {
|
||||||
|
// Add function pointers here
|
||||||
|
module->vtable.vec_rnx_add = vec_rnx_add_ref;
|
||||||
|
module->vtable.vec_rnx_zero = vec_rnx_zero_ref;
|
||||||
|
module->vtable.vec_rnx_copy = vec_rnx_copy_ref;
|
||||||
|
module->vtable.vec_rnx_negate = vec_rnx_negate_ref;
|
||||||
|
module->vtable.vec_rnx_sub = vec_rnx_sub_ref;
|
||||||
|
module->vtable.vec_rnx_rotate = vec_rnx_rotate_ref;
|
||||||
|
module->vtable.vec_rnx_automorphism = vec_rnx_automorphism_ref;
|
||||||
|
module->vtable.vec_rnx_mul_xp_minus_one = vec_rnx_mul_xp_minus_one_ref;
|
||||||
|
module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref;
|
||||||
|
module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_ref;
|
||||||
|
module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref;
|
||||||
|
module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_ref;
|
||||||
|
module->vtable.rnx_vmp_prepare_tmp_bytes = fft64_rnx_vmp_prepare_tmp_bytes_ref;
|
||||||
|
module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_ref;
|
||||||
|
module->vtable.rnx_vmp_prepare_dblptr = fft64_rnx_vmp_prepare_dblptr_ref;
|
||||||
|
module->vtable.rnx_vmp_prepare_row = fft64_rnx_vmp_prepare_row_ref;
|
||||||
|
module->vtable.bytes_of_rnx_vmp_pmat = fft64_bytes_of_rnx_vmp_pmat;
|
||||||
|
module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_ref;
|
||||||
|
module->vtable.vec_rnx_to_znx32 = vec_rnx_to_znx32_ref;
|
||||||
|
module->vtable.vec_rnx_from_znx32 = vec_rnx_from_znx32_ref;
|
||||||
|
module->vtable.vec_rnx_to_tnx32 = vec_rnx_to_tnx32_ref;
|
||||||
|
module->vtable.vec_rnx_from_tnx32 = vec_rnx_from_tnx32_ref;
|
||||||
|
module->vtable.vec_rnx_to_tnxdbl = vec_rnx_to_tnxdbl_ref;
|
||||||
|
module->vtable.bytes_of_rnx_svp_ppol = fft64_bytes_of_rnx_svp_ppol;
|
||||||
|
module->vtable.rnx_svp_prepare = fft64_rnx_svp_prepare_ref;
|
||||||
|
module->vtable.rnx_svp_apply = fft64_rnx_svp_apply_ref;
|
||||||
|
|
||||||
|
// Add optimized function pointers here
|
||||||
|
if (CPU_SUPPORTS("avx")) {
|
||||||
|
module->vtable.vec_rnx_add = vec_rnx_add_avx;
|
||||||
|
module->vtable.vec_rnx_sub = vec_rnx_sub_avx;
|
||||||
|
module->vtable.vec_rnx_negate = vec_rnx_negate_avx;
|
||||||
|
module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx;
|
||||||
|
module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_avx;
|
||||||
|
module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx;
|
||||||
|
module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_avx;
|
||||||
|
module->vtable.rnx_vmp_prepare_tmp_bytes = fft64_rnx_vmp_prepare_tmp_bytes_avx;
|
||||||
|
module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_avx;
|
||||||
|
module->vtable.rnx_vmp_prepare_dblptr = fft64_rnx_vmp_prepare_dblptr_avx;
|
||||||
|
module->vtable.rnx_vmp_prepare_row = fft64_rnx_vmp_prepare_row_avx;
|
||||||
|
module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_avx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void init_rnx_module_info(MOD_RNX* module, //
|
||||||
|
uint64_t n, RNX_MODULE_TYPE mtype) {
|
||||||
|
memset(module, 0, sizeof(MOD_RNX));
|
||||||
|
module->n = n;
|
||||||
|
module->m = n >> 1;
|
||||||
|
module->mtype = mtype;
|
||||||
|
switch (mtype) {
|
||||||
|
case FFT64:
|
||||||
|
fft64_init_rnx_module_precomp(module);
|
||||||
|
fft64_init_rnx_module_vtable(module);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
NOT_SUPPORTED(); // unknown mtype
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void finalize_rnx_module_info(MOD_RNX* module) {
|
||||||
|
if (module->custom) module->custom_deleter(module->custom);
|
||||||
|
switch (module->mtype) {
|
||||||
|
case FFT64:
|
||||||
|
fft64_finalize_rnx_module_precomp(module);
|
||||||
|
// fft64_finalize_rnx_module_vtable(module); // nothing to finalize
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
NOT_SUPPORTED(); // unknown mtype
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT MOD_RNX* new_rnx_module_info(uint64_t nn, RNX_MODULE_TYPE mtype) {
|
||||||
|
MOD_RNX* res = (MOD_RNX*)malloc(sizeof(MOD_RNX));
|
||||||
|
init_rnx_module_info(res, nn, mtype);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void delete_rnx_module_info(MOD_RNX* module_info) {
|
||||||
|
finalize_rnx_module_info(module_info);
|
||||||
|
free(module_info);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint64_t rnx_module_get_n(const MOD_RNX* module) { return module->n; }
|
||||||
|
|
||||||
|
/** @brief allocates a prepared matrix (release with delete_rnx_vmp_pmat) */
|
||||||
|
EXPORT RNX_VMP_PMAT* new_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols) { // dimensions
|
||||||
|
return (RNX_VMP_PMAT*)spqlios_alloc(bytes_of_rnx_vmp_pmat(module, nrows, ncols));
|
||||||
|
}
|
||||||
|
EXPORT void delete_rnx_vmp_pmat(RNX_VMP_PMAT* ptr) { spqlios_free(ptr); }
|
||||||
|
|
||||||
|
//////////////// wrappers //////////////////
|
||||||
|
|
||||||
|
/** @brief sets res = a + b */
|
||||||
|
EXPORT void vec_rnx_add( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_add(module, res, res_size, res_sl, a, a_size, a_sl, b, b_size, b_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = 0 */
|
||||||
|
EXPORT void vec_rnx_zero( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl // res
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_zero(module, res, res_size, res_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a */
|
||||||
|
EXPORT void vec_rnx_copy( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_copy(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = -a */
|
||||||
|
EXPORT void vec_rnx_negate( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_negate(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a - b */
|
||||||
|
EXPORT void vec_rnx_sub( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_sub(module, res, res_size, res_sl, a, a_size, a_sl, b, b_size, b_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a . X^p */
|
||||||
|
EXPORT void vec_rnx_rotate( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_rotate(module, p, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a(X^p) */
|
||||||
|
EXPORT void vec_rnx_automorphism( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int64_t p, // X -> X^p
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_automorphism(module, p, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_mul_xp_minus_one( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_mul_xp_minus_one(module, p, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */
|
||||||
|
EXPORT uint64_t bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols) { // dimensions
|
||||||
|
return module->vtable.bytes_of_rnx_vmp_pmat(module, nrows, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void rnx_vmp_prepare_contiguous( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double* a, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
module->vtable.rnx_vmp_prepare_contiguous(module, pmat, a, nrows, ncols, tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||||
|
EXPORT void rnx_vmp_prepare_dblptr( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double** a, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
module->vtable.rnx_vmp_prepare_dblptr(module, pmat, a, nrows, ncols, tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||||
|
EXPORT void rnx_vmp_prepare_row( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double* a, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
module->vtable.rnx_vmp_prepare_row(module, pmat, a, row_i, nrows, ncols, tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief number of scratch bytes necessary to prepare a matrix */
|
||||||
|
EXPORT uint64_t rnx_vmp_prepare_tmp_bytes(const MOD_RNX* module) {
|
||||||
|
return module->vtable.rnx_vmp_prepare_tmp_bytes(module);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief applies a vmp product res = a x pmat */
|
||||||
|
EXPORT void rnx_vmp_apply_tmp_a( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
module->vtable.rnx_vmp_apply_tmp_a(module, res, res_size, res_sl, tmpa, a_size, a_sl, pmat, nrows, ncols, tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint64_t rnx_vmp_apply_tmp_a_tmp_bytes( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res size
|
||||||
|
uint64_t a_size, // a size
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix dims
|
||||||
|
) {
|
||||||
|
return module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes(module, res_size, a_size, nrows, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT void rnx_vmp_apply_dft_to_dft( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
) {
|
||||||
|
module->vtable.rnx_vmp_apply_dft_to_dft(module, res, res_size, res_sl, a_dft, a_size, a_sl, pmat, nrows, ncols,
|
||||||
|
tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT uint64_t rnx_vmp_apply_dft_to_dft_tmp_bytes( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
) {
|
||||||
|
return module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes(module, res_size, a_size, nrows, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint64_t bytes_of_rnx_svp_ppol(const MOD_RNX* module) { return module->vtable.bytes_of_rnx_svp_ppol(module); }
|
||||||
|
|
||||||
|
EXPORT void rnx_svp_prepare(const MOD_RNX* module, // N
|
||||||
|
RNX_SVP_PPOL* ppol, // output
|
||||||
|
const double* pol // a
|
||||||
|
) {
|
||||||
|
module->vtable.rnx_svp_prepare(module, ppol, pol);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void rnx_svp_apply( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // output
|
||||||
|
const RNX_SVP_PPOL* ppol, // prepared pol
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.rnx_svp_apply(module, // N
|
||||||
|
res, res_size, res_sl, // output
|
||||||
|
ppol, // prepared pol
|
||||||
|
a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void rnx_approxdecomp_from_tnxdbl( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a) { // a
|
||||||
|
module->vtable.rnx_approxdecomp_from_tnxdbl(module, gadget, res, res_size, res_sl, a);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_znx32( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_to_znx32(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_from_znx32( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_from_znx32(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_tnx32( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_to_tnx32(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_from_tnx32( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_from_tnx32(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_tnxdbl( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->vtable.vec_rnx_to_tnxdbl(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
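To make the calling convention of these wrappers concrete, here is a small illustrative sketch (not part of the commit). Every vector argument is passed as (pointer, size, stride in doubles), the sizes of res, a and b may all differ, and nn is assumed to be a supported power-of-two ring dimension:

#include <stdlib.h>

void demo_vec_rnx(uint64_t nn) {
  MOD_RNX* mod = new_rnx_module_info(nn, FFT64);
  double* a = calloc(2 * nn, sizeof(double));  /* 2 slices of size nn */
  double* b = calloc(3 * nn, sizeof(double));  /* 3 slices of size nn */
  double* r = calloc(3 * nn, sizeof(double));
  /* r = a + b: slices 0..1 are added, slice 2 is copied from b */
  vec_rnx_add(mod, r, 3, nn, a, 2, nn, b, 3, nn);
  /* r = r . X^5 on every slice (in-place is handled by the backend) */
  vec_rnx_rotate(mod, 5, r, 3, nn, r, 3, nn);
  free(a); free(b); free(r);
  delete_rnx_module_info(mod);
}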
@@ -0,0 +1,59 @@
|
|||||||
|
#include <memory.h>
|
||||||
|
|
||||||
|
#include "immintrin.h"
|
||||||
|
#include "vec_rnx_arithmetic_private.h"
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) */
|
||||||
|
EXPORT void rnx_approxdecomp_from_tnxdbl_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
if (nn < 4) return rnx_approxdecomp_from_tnxdbl_ref(module, gadget, res, res_size, res_sl, a);
|
||||||
|
const uint64_t ell = gadget->ell;
|
||||||
|
const __m256i k = _mm256_set1_epi64x(gadget->k);
|
||||||
|
const __m256d add_cst = _mm256_set1_pd(gadget->add_cst);
|
||||||
|
const __m256i and_mask = _mm256_set1_epi64x(gadget->and_mask);
|
||||||
|
const __m256i or_mask = _mm256_set1_epi64x(gadget->or_mask);
|
||||||
|
const __m256d sub_cst = _mm256_set1_pd(gadget->sub_cst);
|
||||||
|
const uint64_t msize = res_size <= ell ? res_size : ell;
|
||||||
|
// gadget decompose column by column
|
||||||
|
if (msize == ell) {
|
||||||
|
// this is the main scenario when msize == ell
|
||||||
|
double* const last_r = res + (msize - 1) * res_sl;
|
||||||
|
for (uint64_t j = 0; j < nn; j += 4) {
|
||||||
|
double* rr = last_r + j;
|
||||||
|
const double* aa = a + j;
|
||||||
|
__m256d t_dbl = _mm256_add_pd(_mm256_loadu_pd(aa), add_cst);
|
||||||
|
__m256i t_int = _mm256_castpd_si256(t_dbl);
|
||||||
|
do {
|
||||||
|
__m256i u_int = _mm256_or_si256(_mm256_and_si256(t_int, and_mask), or_mask);
|
||||||
|
_mm256_storeu_pd(rr, _mm256_sub_pd(_mm256_castsi256_pd(u_int), sub_cst));
|
||||||
|
t_int = _mm256_srlv_epi64(t_int, k);
|
||||||
|
rr -= res_sl;
|
||||||
|
} while (rr >= res);
|
||||||
|
}
|
||||||
|
} else if (msize > 0) {
|
||||||
|
// otherwise, if msize < ell: there is one additional rshift
|
||||||
|
const __m256i first_rsh = _mm256_set1_epi64x((ell - msize) * gadget->k);
|
||||||
|
double* const last_r = res + (msize - 1) * res_sl;
|
||||||
|
for (uint64_t j = 0; j < nn; j += 4) {
|
||||||
|
double* rr = last_r + j;
|
||||||
|
const double* aa = a + j;
|
||||||
|
__m256d t_dbl = _mm256_add_pd(_mm256_loadu_pd(aa), add_cst);
|
||||||
|
__m256i t_int = _mm256_srlv_epi64(_mm256_castpd_si256(t_dbl), first_rsh);
|
||||||
|
do {
|
||||||
|
__m256i u_int = _mm256_or_si256(_mm256_and_si256(t_int, and_mask), or_mask);
|
||||||
|
_mm256_storeu_pd(rr, _mm256_sub_pd(_mm256_castsi256_pd(u_int), sub_cst));
|
||||||
|
t_int = _mm256_srlv_epi64(t_int, k);
|
||||||
|
rr -= res_sl;
|
||||||
|
} while (rr >= res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// zero-out the last slices (if any)
|
||||||
|
for (uint64_t i = msize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,75 @@
|
|||||||
|
#include <memory.h>
|
||||||
|
|
||||||
|
#include "vec_rnx_arithmetic_private.h"
|
||||||
|
|
||||||
|
typedef union di {
|
||||||
|
double dv;
|
||||||
|
uint64_t uv;
|
||||||
|
} di_t;
|
||||||
|
|
||||||
|
/** @brief new gadget: delete with delete_tnxdbl_approxdecomp_gadget */
|
||||||
|
EXPORT TNXDBL_APPROXDECOMP_GADGET* new_tnxdbl_approxdecomp_gadget( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t k, uint64_t ell // base 2^K and size
|
||||||
|
) {
|
||||||
|
if (k * ell > 50) return spqlios_error("gadget requires too large an fp precision (k*ell must be <= 50)");
|
||||||
|
TNXDBL_APPROXDECOMP_GADGET* res = spqlios_alloc(sizeof(TNXDBL_APPROXDECOMP_GADGET));
|
||||||
|
res->k = k;
|
||||||
|
res->ell = ell;
|
||||||
|
// double add_cst; // double(3*2^(51-ell*K) + (1/2)*sum_{i in [0,ell)} 2^(-i*K))
|
||||||
|
union di add_cst;
|
||||||
|
add_cst.dv = UINT64_C(3) << (51 - ell * k);
|
||||||
|
for (uint64_t i = 0; i < ell; ++i) {
|
||||||
|
add_cst.uv |= UINT64_C(1) << ((i + 1) * k - 1);
|
||||||
|
}
|
||||||
|
res->add_cst = add_cst.dv;
|
||||||
|
// uint64_t and_mask; // uint64(2^(K)-1)
|
||||||
|
res->and_mask = (UINT64_C(1) << k) - 1;
|
||||||
|
// uint64_t or_mask; // double(2^52)
|
||||||
|
union di or_mask;
|
||||||
|
or_mask.dv = (UINT64_C(1) << 52);
|
||||||
|
res->or_mask = or_mask.uv;
|
||||||
|
// double sub_cst; // double(2^52 + 2^(K-1))
|
||||||
|
res->sub_cst = ((UINT64_C(1) << 52) + (UINT64_C(1) << (k - 1)));
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void delete_tnxdbl_approxdecomp_gadget(TNXDBL_APPROXDECOMP_GADGET* gadget) { spqlios_free(gadget); }
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) */
|
||||||
|
EXPORT void rnx_approxdecomp_from_tnxdbl_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
const uint64_t k = gadget->k;
|
||||||
|
const uint64_t ell = gadget->ell;
|
||||||
|
const double add_cst = gadget->add_cst;
|
||||||
|
const uint64_t and_mask = gadget->and_mask;
|
||||||
|
const uint64_t or_mask = gadget->or_mask;
|
||||||
|
const double sub_cst = gadget->sub_cst;
|
||||||
|
const uint64_t msize = res_size <= ell ? res_size : ell;
|
||||||
|
const uint64_t first_rsh = (ell - msize) * k;
|
||||||
|
// gadget decompose column by column
|
||||||
|
if (msize > 0) {
|
||||||
|
double* const last_r = res + (msize - 1) * res_sl;
|
||||||
|
for (uint64_t j = 0; j < nn; ++j) {
|
||||||
|
double* rr = last_r + j;
|
||||||
|
di_t t = {.dv = a[j] + add_cst};
|
||||||
|
if (msize < ell) t.uv >>= first_rsh;
|
||||||
|
do {
|
||||||
|
di_t u;
|
||||||
|
u.uv = (t.uv & and_mask) | or_mask;
|
||||||
|
*rr = u.dv - sub_cst;
|
||||||
|
t.uv >>= k;
|
||||||
|
rr -= res_sl;
|
||||||
|
} while (rr >= res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// zero-out the last slices (if any)
|
||||||
|
for (uint64_t i = msize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
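For context, an illustrative usage sketch of the approximate gadget decomposition above (not part of the commit). It assumes a MOD_RNX* mod for ring dimension nn and decomposes a double-precision torus polynomial a into ell base-2^k digit polynomials:

#include <stdlib.h>

void demo_approxdecomp(const MOD_RNX* mod, uint64_t nn, const double* a) {
  const uint64_t k = 10, ell = 4;                     /* requires k*ell <= 50 */
  TNXDBL_APPROXDECOMP_GADGET* g = new_tnxdbl_approxdecomp_gadget(mod, k, ell);
  double* digits = calloc(ell * nn, sizeof(double));  /* ell slices of size nn */
  /* slice i holds the i-th digit polynomial, entries in [-2^(k-1), 2^(k-1)) */
  rnx_approxdecomp_from_tnxdbl(mod, g, digits, ell, nn, a);
  /* reconstruction: sum_i digits[i] * 2^(-(i+1)*k) approximates a modulo 1 */
  free(digits);
  delete_tnxdbl_approxdecomp_gadget(g);
}

The whole decomposition rides on the add_cst/and_mask/or_mask/sub_cst constants built above: adding add_cst aligns the ell*k digit bits into the low mantissa bits of the double, and each k-bit chunk is re-centered by the final subtraction.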
@@ -0,0 +1,223 @@
|
|||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "../coeffs/coeffs_arithmetic.h"
|
||||||
|
#include "vec_rnx_arithmetic_private.h"
|
||||||
|
|
||||||
|
void rnx_add_ref(uint64_t nn, double* res, const double* a, const double* b) {
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
res[i] = a[i] + b[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void rnx_sub_ref(uint64_t nn, double* res, const double* a, const double* b) {
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
res[i] = a[i] - b[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void rnx_negate_ref(uint64_t nn, double* res, const double* a) {
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
res[i] = -a[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a + b */
|
||||||
|
EXPORT void vec_rnx_add_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
if (a_size < b_size) {
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
const uint64_t nsize = res_size < b_size ? res_size : b_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
rnx_add_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
for (uint64_t i = msize; i < nsize; ++i) {
|
||||||
|
memcpy(res + i * res_sl, b + i * b_sl, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const uint64_t msize = res_size < b_size ? res_size : b_size;
|
||||||
|
const uint64_t nsize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
rnx_add_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
for (uint64_t i = msize; i < nsize; ++i) {
|
||||||
|
memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = 0 */
|
||||||
|
EXPORT void vec_rnx_zero_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl // res
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
for (uint64_t i = 0; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a */
|
||||||
|
EXPORT void vec_rnx_copy_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
|
||||||
|
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
// copy up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||||
|
double* res_ptr = res + i * res_sl;
|
||||||
|
const double* a_ptr = a + i * a_sl;
|
||||||
|
memcpy(res_ptr, a_ptr, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = -a */
|
||||||
|
EXPORT void vec_rnx_negate_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
|
||||||
|
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
// negate up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||||
|
double* res_ptr = res + i * res_sl;
|
||||||
|
const double* a_ptr = a + i * a_sl;
|
||||||
|
rnx_negate_ref(nn, res_ptr, a_ptr);
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a - b */
|
||||||
|
EXPORT void vec_rnx_sub_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
if (a_size < b_size) {
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
const uint64_t nsize = res_size < b_size ? res_size : b_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
rnx_sub_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
for (uint64_t i = msize; i < nsize; ++i) {
|
||||||
|
rnx_negate_ref(nn, res + i * res_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const uint64_t msize = res_size < b_size ? res_size : b_size;
|
||||||
|
const uint64_t nsize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
rnx_sub_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
for (uint64_t i = msize; i < nsize; ++i) {
|
||||||
|
memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
for (uint64_t i = nsize; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a . X^p */
|
||||||
|
EXPORT void vec_rnx_rotate_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
|
||||||
|
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
// rotate up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||||
|
double* res_ptr = res + i * res_sl;
|
||||||
|
const double* a_ptr = a + i * a_sl;
|
||||||
|
if (res_ptr == a_ptr) {
|
||||||
|
rnx_rotate_inplace_f64(nn, p, res_ptr);
|
||||||
|
} else {
|
||||||
|
rnx_rotate_f64(nn, p, res_ptr, a_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a(X^p) */
|
||||||
|
EXPORT void vec_rnx_automorphism_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int64_t p, // X -> X^p
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
|
||||||
|
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
// rotate up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||||
|
double* res_ptr = res + i * res_sl;
|
||||||
|
const double* a_ptr = a + i * a_sl;
|
||||||
|
if (res_ptr == a_ptr) {
|
||||||
|
rnx_automorphism_inplace_f64(nn, p, res_ptr);
|
||||||
|
} else {
|
||||||
|
rnx_automorphism_f64(nn, p, res_ptr, a_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a . (X^p - 1) */
|
||||||
|
EXPORT void vec_rnx_mul_xp_minus_one_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
|
||||||
|
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
// rotate up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||||
|
double* res_ptr = res + i * res_sl;
|
||||||
|
const double* a_ptr = a + i * a_sl;
|
||||||
|
if (res_ptr == a_ptr) {
|
||||||
|
rnx_mul_xp_minus_one_inplace_f64(nn, p, res_ptr);
|
||||||
|
} else {
|
||||||
|
rnx_mul_xp_minus_one_f64(nn, p, res_ptr, a_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
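All of the vec_rnx_* routines above share one calling convention: a vector of polynomials is a base pointer plus an explicit limb count (res_size / a_size) and a stride in doubles between consecutive limbs (res_sl / a_sl); output limbs beyond the input size are zero-filled. The fragment below is only an illustrative usage sketch of that convention (the ring dimension N = 8 and the buffers are placeholder assumptions, not taken from this change):

  // hypothetical usage sketch, assuming an FFT64 module of ring dimension 8
  MOD_RNX* module = new_rnx_module_info(8, FFT64);
  double a[2 * 8] = {0};  // 2 limbs, contiguous layout: a_sl == 8
  double r[3 * 8];        // 3 limbs, contiguous layout: res_sl == 8
  a[0] = 1.0;             // a_0(X) = 1, a_1(X) = 0
  vec_rnx_rotate_ref(module, 3, r, 3, 8, a, 2, 8);
  // now r_0(X) = X^3, r_1(X) = 0 (rotated), r_2(X) = 0 (zero-extended)
  delete_rnx_module_info(module);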
@@ -0,0 +1,356 @@
|
|||||||
|
#ifndef SPQLIOS_VEC_RNX_ARITHMETIC_H
|
||||||
|
#define SPQLIOS_VEC_RNX_ARITHMETIC_H
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "../commons.h"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* We support the following module families:
|
||||||
|
* - FFT64:
|
||||||
|
* the overall precision should fit within 52 bits at all times.
|
||||||
|
*/
|
||||||
|
typedef enum rnx_module_type_t { FFT64 } RNX_MODULE_TYPE;
|
||||||
|
|
||||||
|
/** @brief opaque structure that describes the modules (RnX,ZnX,TnX) and the hardware */
|
||||||
|
typedef struct rnx_module_info_t MOD_RNX;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief obtain a module info for ring dimension N
|
||||||
|
* the module-info knows about:
|
||||||
|
* - the dimension N (or the complex dimension m=N/2)
|
||||||
|
* - any precomputed fft or ntt items
|
||||||
|
* - the hardware (avx, arm64, x86, ...)
|
||||||
|
*/
|
||||||
|
EXPORT MOD_RNX* new_rnx_module_info(uint64_t N, RNX_MODULE_TYPE mode);
|
||||||
|
EXPORT void delete_rnx_module_info(MOD_RNX* module_info);
|
||||||
|
EXPORT uint64_t rnx_module_get_n(const MOD_RNX* module);
|
||||||
|
|
||||||
|
// basic arithmetic
|
||||||
|
|
||||||
|
/** @brief sets res = 0 */
|
||||||
|
EXPORT void vec_rnx_zero( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl // res
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a */
|
||||||
|
EXPORT void vec_rnx_copy( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = -a */
|
||||||
|
EXPORT void vec_rnx_negate( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a + b */
|
||||||
|
EXPORT void vec_rnx_add( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a - b */
|
||||||
|
EXPORT void vec_rnx_sub( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a . X^p */
|
||||||
|
EXPORT void vec_rnx_rotate( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a . (X^p - 1) */
|
||||||
|
EXPORT void vec_rnx_mul_xp_minus_one( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a(X^p) */
|
||||||
|
EXPORT void vec_rnx_automorphism( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int64_t p, // X -> X^p
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
// conversions //
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_znx32( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_from_znx32( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_tnx32( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_from_tnx32( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_tnx32x2( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_from_tnx32x2( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_tnxdbl( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
// isolated products (n.log(n), but not particularly optimized) //
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
/** @brief res = a * b : small polynomial product */
|
||||||
|
EXPORT void rnx_small_single_product( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, // output
|
||||||
|
const double* a, // a
|
||||||
|
const double* b, // b
|
||||||
|
uint8_t* tmp); // scratch space
|
||||||
|
|
||||||
|
EXPORT uint64_t rnx_small_single_product_tmp_bytes(const MOD_RNX* module);
|
||||||
|
|
||||||
|
/** @brief res = a * b centermod 1: small polynomial product */
|
||||||
|
EXPORT void tnxdbl_small_single_product( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* torus_res, // output
|
||||||
|
const double* int_a, // a
|
||||||
|
const double* torus_b, // b
|
||||||
|
uint8_t* tmp); // scratch space
|
||||||
|
|
||||||
|
EXPORT uint64_t tnxdbl_small_single_product_tmp_bytes(const MOD_RNX* module);
|
||||||
|
|
||||||
|
/** @brief res = a * b: small polynomial product */
|
||||||
|
EXPORT void znx32_small_single_product( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int32_t* int_res, // output
|
||||||
|
const int32_t* int_a, // a
|
||||||
|
const int32_t* int_b, // b
|
||||||
|
uint8_t* tmp); // scratch space
|
||||||
|
|
||||||
|
EXPORT uint64_t znx32_small_single_product_tmp_bytes(const MOD_RNX* module);
|
||||||
|
|
||||||
|
/** @brief res = a * b centermod 1: small polynomial product */
|
||||||
|
EXPORT void tnx32_small_single_product( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int32_t* torus_res, // output
|
||||||
|
const int32_t* int_a, // a
|
||||||
|
const int32_t* torus_b, // b
|
||||||
|
uint8_t* tmp); // scratch space
|
||||||
|
|
||||||
|
EXPORT uint64_t tnx32_small_single_product_tmp_bytes(const MOD_RNX* module);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
// prepared gadget decompositions (optimized) //
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// decompose from tnx32
|
||||||
|
|
||||||
|
typedef struct tnx32_approxdecomp_gadget_t TNX32_APPROXDECOMP_GADGET;
|
||||||
|
|
||||||
|
/** @brief new gadget: delete with delete_tnx32_approxdecomp_gadget */
|
||||||
|
EXPORT TNX32_APPROXDECOMP_GADGET* new_tnx32_approxdecomp_gadget( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t k, uint64_t ell // base 2^K and size
|
||||||
|
);
|
||||||
|
EXPORT void delete_tnx32_approxdecomp_gadget(TNX32_APPROXDECOMP_GADGET* gadget);
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) */
|
||||||
|
EXPORT void rnx_approxdecomp_from_tnx32( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const TNX32_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int32_t* a // a
|
||||||
|
);
|
||||||
|
|
||||||
|
// decompose from tnx32x2
|
||||||
|
|
||||||
|
typedef struct tnx32x2_approxdecomp_gadget_t TNX32X2_APPROXDECOMP_GADGET;
|
||||||
|
|
||||||
|
/** @brief new gadget: delete with delete_tnx32x2_approxdecomp_gadget */
|
||||||
|
EXPORT TNX32X2_APPROXDECOMP_GADGET* new_tnx32x2_approxdecomp_gadget(const MOD_RNX* module, uint64_t ka, uint64_t ella,
|
||||||
|
uint64_t kb, uint64_t ellb);
|
||||||
|
EXPORT void delete_tnx32x2_approxdecomp_gadget(TNX32X2_APPROXDECOMP_GADGET* gadget);
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) */
|
||||||
|
EXPORT void rnx_approxdecomp_from_tnx32x2( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const TNX32X2_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int32_t* a // a
|
||||||
|
);
|
||||||
|
|
||||||
|
// decompose from tnxdbl
|
||||||
|
|
||||||
|
typedef struct tnxdbl_approxdecomp_gadget_t TNXDBL_APPROXDECOMP_GADGET;
|
||||||
|
|
||||||
|
/** @brief new gadget: delete with delete_tnxdbl_approxdecomp_gadget */
|
||||||
|
EXPORT TNXDBL_APPROXDECOMP_GADGET* new_tnxdbl_approxdecomp_gadget( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t k, uint64_t ell // base 2^K and size
|
||||||
|
);
|
||||||
|
EXPORT void delete_tnxdbl_approxdecomp_gadget(TNXDBL_APPROXDECOMP_GADGET* gadget);
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) */
|
||||||
|
EXPORT void rnx_approxdecomp_from_tnxdbl( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a); // a
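The gadget declarations above describe an approximate signed decomposition in base 2^K with ell limbs: a torus coefficient a is replaced by digits a_1, ..., a_ell in [-2^(K-1), 2^(K-1)) such that a is close to sum_i a_i . 2^(-i.K), with error at most 2^(-ell.K - 1). As a purely illustrative example (the exact rounding constants live in the private gadget structs, not in this header): with K = 2 and ell = 3, the coefficient a = 0.3 rounds to 19/64 and decomposes as (a_1, a_2, a_3) = (1, 1, -1), since 1/4 + 1/16 - 1/64 = 19/64, roughly 0.2969; each digit is then stored as one output limb of res.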
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
// prepared scalar-vector product (optimized) //
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
/** @brief opaque type that represents a polynomial of RnX prepared for a scalar-vector product */
|
||||||
|
typedef struct rnx_svp_ppol_t RNX_SVP_PPOL;
|
||||||
|
|
||||||
|
/** @brief number of bytes in a RNX_SVP_PPOL (for manual allocation) */
|
||||||
|
EXPORT uint64_t bytes_of_rnx_svp_ppol(const MOD_RNX* module); // N
|
||||||
|
|
||||||
|
/** @brief allocates a prepared vector (release with delete_rnx_svp_ppol) */
|
||||||
|
EXPORT RNX_SVP_PPOL* new_rnx_svp_ppol(const MOD_RNX* module); // N
|
||||||
|
|
||||||
|
/** @brief frees memory for a prepared vector */
|
||||||
|
EXPORT void delete_rnx_svp_ppol(RNX_SVP_PPOL* res);
|
||||||
|
|
||||||
|
/** @brief prepares a svp polynomial */
|
||||||
|
EXPORT void rnx_svp_prepare(const MOD_RNX* module, // N
|
||||||
|
RNX_SVP_PPOL* ppol, // output
|
||||||
|
const double* pol // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
|
||||||
|
EXPORT void rnx_svp_apply( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // output
|
||||||
|
const RNX_SVP_PPOL* ppol, // prepared pol
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
// prepared vector-matrix product (optimized) //
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
typedef struct rnx_vmp_pmat_t RNX_VMP_PMAT;
|
||||||
|
|
||||||
|
/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */
|
||||||
|
EXPORT uint64_t bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols); // dimensions
|
||||||
|
|
||||||
|
/** @brief allocates a prepared matrix (release with delete_rnx_vmp_pmat) */
|
||||||
|
EXPORT RNX_VMP_PMAT* new_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols); // dimensions
|
||||||
|
EXPORT void delete_rnx_vmp_pmat(RNX_VMP_PMAT* ptr);
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void rnx_vmp_prepare_contiguous( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double* a, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||||
|
EXPORT void rnx_vmp_prepare_dblptr( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double** a, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||||
|
EXPORT void rnx_vmp_prepare_row( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double* a, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief number of scratch bytes necessary to prepare a matrix */
|
||||||
|
EXPORT uint64_t rnx_vmp_prepare_tmp_bytes(const MOD_RNX* module);
|
||||||
|
|
||||||
|
/** @brief applies a vmp product res = a x pmat */
|
||||||
|
EXPORT void rnx_vmp_apply_tmp_a( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT uint64_t rnx_vmp_apply_tmp_a_tmp_bytes( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res size
|
||||||
|
uint64_t a_size, // a size
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix dims
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief applies a vmp product res = a_dft x pmat, in DFT space */
|
||||||
|
EXPORT void rnx_vmp_apply_dft_to_dft( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT uint64_t rnx_vmp_apply_dft_to_dft_tmp_bytes( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = DFT(a) */
|
||||||
|
EXPORT void vec_rnx_dft(const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = iDFT(a_dft) -- idft is not normalized */
|
||||||
|
EXPORT void vec_rnx_idft(const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a_dft, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
#endif // SPQLIOS_VEC_RNX_ARITHMETIC_H
|
||||||
@@ -0,0 +1,189 @@
#include <immintrin.h>
#include <string.h>

#include "vec_rnx_arithmetic_private.h"

void rnx_add_avx(uint64_t nn, double* res, const double* a, const double* b) {
  if (nn < 8) {
    if (nn == 4) {
      _mm256_storeu_pd(res, _mm256_add_pd(_mm256_loadu_pd(a), _mm256_loadu_pd(b)));
    } else if (nn == 2) {
      _mm_storeu_pd(res, _mm_add_pd(_mm_loadu_pd(a), _mm_loadu_pd(b)));
    } else if (nn == 1) {
      *res = *a + *b;
    } else {
      NOT_SUPPORTED();  // not a power of 2
    }
    return;
  }
  // general case: nn >= 8
  __m256d x0, x1, x2, x3, x4, x5;
  const double* aa = a;
  const double* bb = b;
  double* rr = res;
  double* const rrend = res + nn;
  do {
    x0 = _mm256_loadu_pd(aa);
    x1 = _mm256_loadu_pd(aa + 4);
    x2 = _mm256_loadu_pd(bb);
    x3 = _mm256_loadu_pd(bb + 4);
    x4 = _mm256_add_pd(x0, x2);
    x5 = _mm256_add_pd(x1, x3);
    _mm256_storeu_pd(rr, x4);
    _mm256_storeu_pd(rr + 4, x5);
    aa += 8;
    bb += 8;
    rr += 8;
  } while (rr < rrend);
}

void rnx_sub_avx(uint64_t nn, double* res, const double* a, const double* b) {
  if (nn < 8) {
    if (nn == 4) {
      _mm256_storeu_pd(res, _mm256_sub_pd(_mm256_loadu_pd(a), _mm256_loadu_pd(b)));
    } else if (nn == 2) {
      _mm_storeu_pd(res, _mm_sub_pd(_mm_loadu_pd(a), _mm_loadu_pd(b)));
    } else if (nn == 1) {
      *res = *a - *b;
    } else {
      NOT_SUPPORTED();  // not a power of 2
    }
    return;
  }
  // general case: nn >= 8
  __m256d x0, x1, x2, x3, x4, x5;
  const double* aa = a;
  const double* bb = b;
  double* rr = res;
  double* const rrend = res + nn;
  do {
    x0 = _mm256_loadu_pd(aa);
    x1 = _mm256_loadu_pd(aa + 4);
    x2 = _mm256_loadu_pd(bb);
    x3 = _mm256_loadu_pd(bb + 4);
    x4 = _mm256_sub_pd(x0, x2);
    x5 = _mm256_sub_pd(x1, x3);
    _mm256_storeu_pd(rr, x4);
    _mm256_storeu_pd(rr + 4, x5);
    aa += 8;
    bb += 8;
    rr += 8;
  } while (rr < rrend);
}

void rnx_negate_avx(uint64_t nn, double* res, const double* b) {
  if (nn < 8) {
    if (nn == 4) {
      _mm256_storeu_pd(res, _mm256_sub_pd(_mm256_set1_pd(0), _mm256_loadu_pd(b)));
    } else if (nn == 2) {
      _mm_storeu_pd(res, _mm_sub_pd(_mm_set1_pd(0), _mm_loadu_pd(b)));
    } else if (nn == 1) {
      *res = -*b;
    } else {
      NOT_SUPPORTED();  // not a power of 2
    }
    return;
  }
  // general case: nn >= 8
  __m256d x2, x3, x4, x5;
  const __m256d ZERO = _mm256_set1_pd(0);
  const double* bb = b;
  double* rr = res;
  double* const rrend = res + nn;
  do {
    x2 = _mm256_loadu_pd(bb);
    x3 = _mm256_loadu_pd(bb + 4);
    x4 = _mm256_sub_pd(ZERO, x2);
    x5 = _mm256_sub_pd(ZERO, x3);
    _mm256_storeu_pd(rr, x4);
    _mm256_storeu_pd(rr + 4, x5);
    bb += 8;
    rr += 8;
  } while (rr < rrend);
}

/** @brief sets res = a + b */
EXPORT void vec_rnx_add_avx(                          //
    const MOD_RNX* module,                            // N
    double* res, uint64_t res_size, uint64_t res_sl,  // res
    const double* a, uint64_t a_size, uint64_t a_sl,  // a
    const double* b, uint64_t b_size, uint64_t b_sl   // b
) {
  const uint64_t nn = module->n;
  if (a_size < b_size) {
    const uint64_t msize = res_size < a_size ? res_size : a_size;
    const uint64_t nsize = res_size < b_size ? res_size : b_size;
    for (uint64_t i = 0; i < msize; ++i) {
      rnx_add_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
    }
    for (uint64_t i = msize; i < nsize; ++i) {
      memcpy(res + i * res_sl, b + i * b_sl, nn * sizeof(double));
    }
    for (uint64_t i = nsize; i < res_size; ++i) {
      memset(res + i * res_sl, 0, nn * sizeof(double));
    }
  } else {
    const uint64_t msize = res_size < b_size ? res_size : b_size;
    const uint64_t nsize = res_size < a_size ? res_size : a_size;
    for (uint64_t i = 0; i < msize; ++i) {
      rnx_add_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
    }
    for (uint64_t i = msize; i < nsize; ++i) {
      memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double));
    }
    for (uint64_t i = nsize; i < res_size; ++i) {
      memset(res + i * res_sl, 0, nn * sizeof(double));
    }
  }
}

/** @brief sets res = -a */
EXPORT void vec_rnx_negate_avx(                       //
    const MOD_RNX* module,                            // N
    double* res, uint64_t res_size, uint64_t res_sl,  // res
    const double* a, uint64_t a_size, uint64_t a_sl   // a
) {
  const uint64_t nn = module->n;
  const uint64_t msize = res_size < a_size ? res_size : a_size;
  for (uint64_t i = 0; i < msize; ++i) {
    rnx_negate_avx(nn, res + i * res_sl, a + i * a_sl);
  }
  for (uint64_t i = msize; i < res_size; ++i) {
    memset(res + i * res_sl, 0, nn * sizeof(double));
  }
}

/** @brief sets res = a - b */
EXPORT void vec_rnx_sub_avx(                          //
    const MOD_RNX* module,                            // N
    double* res, uint64_t res_size, uint64_t res_sl,  // res
    const double* a, uint64_t a_size, uint64_t a_sl,  // a
    const double* b, uint64_t b_size, uint64_t b_sl   // b
) {
  const uint64_t nn = module->n;
  if (a_size < b_size) {
    const uint64_t msize = res_size < a_size ? res_size : a_size;
    const uint64_t nsize = res_size < b_size ? res_size : b_size;
    for (uint64_t i = 0; i < msize; ++i) {
      rnx_sub_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
    }
    for (uint64_t i = msize; i < nsize; ++i) {
      rnx_negate_avx(nn, res + i * res_sl, b + i * b_sl);
    }
    for (uint64_t i = nsize; i < res_size; ++i) {
      memset(res + i * res_sl, 0, nn * sizeof(double));
    }
  } else {
    const uint64_t msize = res_size < b_size ? res_size : b_size;
    const uint64_t nsize = res_size < a_size ? res_size : a_size;
    for (uint64_t i = 0; i < msize; ++i) {
      rnx_sub_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
    }
    for (uint64_t i = msize; i < nsize; ++i) {
      memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double));
    }
    for (uint64_t i = nsize; i < res_size; ++i) {
      memset(res + i * res_sl, 0, nn * sizeof(double));
    }
  }
}
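For reference, each of the AVX kernels above computes the same coefficient-wise operation as a plain scalar loop; the intrinsics simply process 8 doubles per iteration and special-case the ring dimensions 1, 2 and 4 that are smaller than one unrolled step. A minimal scalar sketch of the addition kernel (illustrative only, not part of this file):

  // scalar equivalent of rnx_add_avx for any power-of-two nn
  static void rnx_add_scalar_sketch(uint64_t nn, double* res, const double* a, const double* b) {
    for (uint64_t j = 0; j < nn; ++j) res[j] = a[j] + b[j];
  }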
@@ -0,0 +1,92 @@
#ifndef SPQLIOS_VEC_RNX_ARITHMETIC_PLUGIN_H
#define SPQLIOS_VEC_RNX_ARITHMETIC_PLUGIN_H

#include "vec_rnx_arithmetic.h"

typedef typeof(vec_rnx_zero) VEC_RNX_ZERO_F;
typedef typeof(vec_rnx_copy) VEC_RNX_COPY_F;
typedef typeof(vec_rnx_negate) VEC_RNX_NEGATE_F;
typedef typeof(vec_rnx_add) VEC_RNX_ADD_F;
typedef typeof(vec_rnx_sub) VEC_RNX_SUB_F;
typedef typeof(vec_rnx_rotate) VEC_RNX_ROTATE_F;
typedef typeof(vec_rnx_mul_xp_minus_one) VEC_RNX_MUL_XP_MINUS_ONE_F;
typedef typeof(vec_rnx_automorphism) VEC_RNX_AUTOMORPHISM_F;
typedef typeof(vec_rnx_to_znx32) VEC_RNX_TO_ZNX32_F;
typedef typeof(vec_rnx_from_znx32) VEC_RNX_FROM_ZNX32_F;
typedef typeof(vec_rnx_to_tnx32) VEC_RNX_TO_TNX32_F;
typedef typeof(vec_rnx_from_tnx32) VEC_RNX_FROM_TNX32_F;
typedef typeof(vec_rnx_to_tnx32x2) VEC_RNX_TO_TNX32X2_F;
typedef typeof(vec_rnx_from_tnx32x2) VEC_RNX_FROM_TNX32X2_F;
typedef typeof(vec_rnx_to_tnxdbl) VEC_RNX_TO_TNXDBL_F;
// typedef typeof(vec_rnx_from_tnxdbl) VEC_RNX_FROM_TNXDBL_F;
typedef typeof(rnx_small_single_product) RNX_SMALL_SINGLE_PRODUCT_F;
typedef typeof(rnx_small_single_product_tmp_bytes) RNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F;
typedef typeof(tnxdbl_small_single_product) TNXDBL_SMALL_SINGLE_PRODUCT_F;
typedef typeof(tnxdbl_small_single_product_tmp_bytes) TNXDBL_SMALL_SINGLE_PRODUCT_TMP_BYTES_F;
typedef typeof(znx32_small_single_product) ZNX32_SMALL_SINGLE_PRODUCT_F;
typedef typeof(znx32_small_single_product_tmp_bytes) ZNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F;
typedef typeof(tnx32_small_single_product) TNX32_SMALL_SINGLE_PRODUCT_F;
typedef typeof(tnx32_small_single_product_tmp_bytes) TNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F;
typedef typeof(rnx_approxdecomp_from_tnx32) RNX_APPROXDECOMP_FROM_TNX32_F;
typedef typeof(rnx_approxdecomp_from_tnx32x2) RNX_APPROXDECOMP_FROM_TNX32X2_F;
typedef typeof(rnx_approxdecomp_from_tnxdbl) RNX_APPROXDECOMP_FROM_TNXDBL_F;
typedef typeof(bytes_of_rnx_svp_ppol) BYTES_OF_RNX_SVP_PPOL_F;
typedef typeof(rnx_svp_prepare) RNX_SVP_PREPARE_F;
typedef typeof(rnx_svp_apply) RNX_SVP_APPLY_F;
typedef typeof(bytes_of_rnx_vmp_pmat) BYTES_OF_RNX_VMP_PMAT_F;
typedef typeof(rnx_vmp_prepare_contiguous) RNX_VMP_PREPARE_CONTIGUOUS_F;
typedef typeof(rnx_vmp_prepare_dblptr) RNX_VMP_PREPARE_DBLPTR_F;
typedef typeof(rnx_vmp_prepare_row) RNX_VMP_PREPARE_ROW_F;
typedef typeof(rnx_vmp_prepare_tmp_bytes) RNX_VMP_PREPARE_TMP_BYTES_F;
typedef typeof(rnx_vmp_apply_tmp_a) RNX_VMP_APPLY_TMP_A_F;
typedef typeof(rnx_vmp_apply_tmp_a_tmp_bytes) RNX_VMP_APPLY_TMP_A_TMP_BYTES_F;
typedef typeof(rnx_vmp_apply_dft_to_dft) RNX_VMP_APPLY_DFT_TO_DFT_F;
typedef typeof(rnx_vmp_apply_dft_to_dft_tmp_bytes) RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F;
typedef typeof(vec_rnx_dft) VEC_RNX_DFT_F;
typedef typeof(vec_rnx_idft) VEC_RNX_IDFT_F;

typedef struct rnx_module_vtable_t RNX_MODULE_VTABLE;
struct rnx_module_vtable_t {
  VEC_RNX_ZERO_F* vec_rnx_zero;
  VEC_RNX_COPY_F* vec_rnx_copy;
  VEC_RNX_NEGATE_F* vec_rnx_negate;
  VEC_RNX_ADD_F* vec_rnx_add;
  VEC_RNX_SUB_F* vec_rnx_sub;
  VEC_RNX_ROTATE_F* vec_rnx_rotate;
  VEC_RNX_MUL_XP_MINUS_ONE_F* vec_rnx_mul_xp_minus_one;
  VEC_RNX_AUTOMORPHISM_F* vec_rnx_automorphism;
  VEC_RNX_TO_ZNX32_F* vec_rnx_to_znx32;
  VEC_RNX_FROM_ZNX32_F* vec_rnx_from_znx32;
  VEC_RNX_TO_TNX32_F* vec_rnx_to_tnx32;
  VEC_RNX_FROM_TNX32_F* vec_rnx_from_tnx32;
  VEC_RNX_TO_TNX32X2_F* vec_rnx_to_tnx32x2;
  VEC_RNX_FROM_TNX32X2_F* vec_rnx_from_tnx32x2;
  VEC_RNX_TO_TNXDBL_F* vec_rnx_to_tnxdbl;
  RNX_SMALL_SINGLE_PRODUCT_F* rnx_small_single_product;
  RNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* rnx_small_single_product_tmp_bytes;
  TNXDBL_SMALL_SINGLE_PRODUCT_F* tnxdbl_small_single_product;
  TNXDBL_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* tnxdbl_small_single_product_tmp_bytes;
  ZNX32_SMALL_SINGLE_PRODUCT_F* znx32_small_single_product;
  ZNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* znx32_small_single_product_tmp_bytes;
  TNX32_SMALL_SINGLE_PRODUCT_F* tnx32_small_single_product;
  TNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* tnx32_small_single_product_tmp_bytes;
  RNX_APPROXDECOMP_FROM_TNX32_F* rnx_approxdecomp_from_tnx32;
  RNX_APPROXDECOMP_FROM_TNX32X2_F* rnx_approxdecomp_from_tnx32x2;
  RNX_APPROXDECOMP_FROM_TNXDBL_F* rnx_approxdecomp_from_tnxdbl;
  BYTES_OF_RNX_SVP_PPOL_F* bytes_of_rnx_svp_ppol;
  RNX_SVP_PREPARE_F* rnx_svp_prepare;
  RNX_SVP_APPLY_F* rnx_svp_apply;
  BYTES_OF_RNX_VMP_PMAT_F* bytes_of_rnx_vmp_pmat;
  RNX_VMP_PREPARE_CONTIGUOUS_F* rnx_vmp_prepare_contiguous;
  RNX_VMP_PREPARE_DBLPTR_F* rnx_vmp_prepare_dblptr;
  RNX_VMP_PREPARE_ROW_F* rnx_vmp_prepare_row;
  RNX_VMP_PREPARE_TMP_BYTES_F* rnx_vmp_prepare_tmp_bytes;
  RNX_VMP_APPLY_TMP_A_F* rnx_vmp_apply_tmp_a;
  RNX_VMP_APPLY_TMP_A_TMP_BYTES_F* rnx_vmp_apply_tmp_a_tmp_bytes;
  RNX_VMP_APPLY_DFT_TO_DFT_F* rnx_vmp_apply_dft_to_dft;
  RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* rnx_vmp_apply_dft_to_dft_tmp_bytes;
  VEC_RNX_DFT_F* vec_rnx_dft;
  VEC_RNX_IDFT_F* vec_rnx_idft;
};

#endif  // SPQLIOS_VEC_RNX_ARITHMETIC_PLUGIN_H
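The vtable above is the plugin point between the generic vec_rnx API and a concrete backend: each module instance carries one table of function pointers, and the public entry points dispatch through it. Below is a minimal wiring sketch, assuming the MOD_RNX layout declared in the private header; the real initialization is fft64_init_rnx_module_vtable, whose definition is not part of this hunk, and may differ:

  // hypothetical sketch: bind a few entries to the reference implementations
  static void example_init_vtable_sketch(MOD_RNX* module) {
    module->vtable.vec_rnx_add = vec_rnx_add_ref;
    module->vtable.vec_rnx_copy = vec_rnx_copy_ref;
    module->vtable.vec_rnx_negate = vec_rnx_negate_ref;
  }
  // a public wrapper can then call, e.g.:
  //   module->vtable.vec_rnx_add(module, res, res_size, res_sl, a, a_size, a_sl, b, b_size, b_sl);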
@@ -0,0 +1,309 @@
|
|||||||
|
#ifndef SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H
|
||||||
|
#define SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H
|
||||||
|
|
||||||
|
#include "../commons_private.h"
|
||||||
|
#include "../reim/reim_fft.h"
|
||||||
|
#include "vec_rnx_arithmetic.h"
|
||||||
|
#include "vec_rnx_arithmetic_plugin.h"
|
||||||
|
|
||||||
|
typedef struct fft64_rnx_module_precomp_t FFT64_RNX_MODULE_PRECOMP;
|
||||||
|
struct fft64_rnx_module_precomp_t {
|
||||||
|
REIM_FFT_PRECOMP* p_fft;
|
||||||
|
REIM_IFFT_PRECOMP* p_ifft;
|
||||||
|
REIM_FFTVEC_ADD_PRECOMP* p_fftvec_add;
|
||||||
|
REIM_FFTVEC_MUL_PRECOMP* p_fftvec_mul;
|
||||||
|
REIM_FFTVEC_ADDMUL_PRECOMP* p_fftvec_addmul;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef union rnx_module_precomp_t RNX_MODULE_PRECOMP;
|
||||||
|
union rnx_module_precomp_t {
|
||||||
|
FFT64_RNX_MODULE_PRECOMP fft64;
|
||||||
|
};
|
||||||
|
|
||||||
|
void fft64_init_rnx_module_precomp(MOD_RNX* module);
|
||||||
|
|
||||||
|
void fft64_finalize_rnx_module_precomp(MOD_RNX* module);
|
||||||
|
|
||||||
|
/** @brief opaque structure that describes the modules (RnX,ZnX,TnX) and the hardware */
|
||||||
|
struct rnx_module_info_t {
|
||||||
|
uint64_t n;
|
||||||
|
uint64_t m;
|
||||||
|
RNX_MODULE_TYPE mtype;
|
||||||
|
RNX_MODULE_VTABLE vtable;
|
||||||
|
RNX_MODULE_PRECOMP precomp;
|
||||||
|
void* custom;
|
||||||
|
void (*custom_deleter)(void*);
|
||||||
|
};
|
||||||
|
|
||||||
|
void init_rnx_module_info(MOD_RNX* module, //
|
||||||
|
uint64_t, RNX_MODULE_TYPE mtype);
|
||||||
|
|
||||||
|
void finalize_rnx_module_info(MOD_RNX* module);
|
||||||
|
|
||||||
|
void fft64_init_rnx_module_vtable(MOD_RNX* module);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
// prepared gadget decompositions (optimized) //
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
struct tnx32_approxdecomp_gadget_t {
|
||||||
|
uint64_t k;
|
||||||
|
uint64_t ell;
|
||||||
|
int32_t add_cst; // 1/2.(sum 2^-(i+1)K)
|
||||||
|
int32_t rshift_base; // 32 - K
|
||||||
|
int64_t and_mask; // 2^K-1
|
||||||
|
int64_t or_mask; // double(2^52)
|
||||||
|
double sub_cst; // double(2^52 + 2^(K-1))
|
||||||
|
uint8_t rshifts[8]; // 32 - (i+1).K
|
||||||
|
};
|
||||||
|
|
||||||
|
struct tnx32x2_approxdecomp_gadget_t {
|
||||||
|
// TODO
|
||||||
|
};
|
||||||
|
|
||||||
|
struct tnxdbl_approxdecomp_gadget_t {
|
||||||
|
uint64_t k;
|
||||||
|
uint64_t ell;
|
||||||
|
double add_cst; // double(3.2^(51-ell.K) + 1/2.(sum 2^(-iK)) for i=[0,ell[)
|
||||||
|
uint64_t and_mask; // uint64(2^(K)-1)
|
||||||
|
uint64_t or_mask; // double(2^52)
|
||||||
|
double sub_cst; // double(2^52 + 2^(K-1))
|
||||||
|
};
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_add_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
EXPORT void vec_rnx_add_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = 0 */
|
||||||
|
EXPORT void vec_rnx_zero_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl // res
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a */
|
||||||
|
EXPORT void vec_rnx_copy_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = -a */
|
||||||
|
EXPORT void vec_rnx_negate_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = -a */
|
||||||
|
EXPORT void vec_rnx_negate_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a - b */
|
||||||
|
EXPORT void vec_rnx_sub_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a - b */
|
||||||
|
EXPORT void vec_rnx_sub_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const double* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a . X^p */
|
||||||
|
EXPORT void vec_rnx_rotate_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a(X^p) */
|
||||||
|
EXPORT void vec_rnx_automorphism_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int64_t p, // X -> X^p
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */
|
||||||
|
EXPORT uint64_t fft64_bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols);
|
||||||
|
|
||||||
|
EXPORT void fft64_rnx_vmp_apply_dft_to_dft_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
);
|
||||||
|
EXPORT void fft64_rnx_vmp_apply_dft_to_dft_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
);
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
);
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
);
|
||||||
|
EXPORT void fft64_rnx_vmp_prepare_contiguous_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double* mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
EXPORT void fft64_rnx_vmp_prepare_contiguous_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double* mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
EXPORT void fft64_rnx_vmp_prepare_dblptr_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double** mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
EXPORT void fft64_rnx_vmp_prepare_dblptr_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double** mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
EXPORT void fft64_rnx_vmp_prepare_row_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double* mat, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
EXPORT void fft64_rnx_vmp_prepare_row_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double* mat, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_ref(const MOD_RNX* module);
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_avx(const MOD_RNX* module);
|
||||||
|
|
||||||
|
EXPORT void fft64_rnx_vmp_apply_tmp_a_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a)
|
||||||
|
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
EXPORT void fft64_rnx_vmp_apply_tmp_a_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a)
|
||||||
|
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
);
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
);
|
||||||
|
|
||||||
|
/// gadget decompositions
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) */
|
||||||
|
EXPORT void rnx_approxdecomp_from_tnxdbl_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a); // a
|
||||||
|
EXPORT void rnx_approxdecomp_from_tnxdbl_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a); // a
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_mul_xp_minus_one_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_znx32_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_from_znx32_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_tnx32_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_from_tnx32_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT void vec_rnx_to_tnxdbl_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
EXPORT uint64_t fft64_bytes_of_rnx_svp_ppol(const MOD_RNX* module); // N
|
||||||
|
|
||||||
|
/** @brief prepares a svp polynomial */
|
||||||
|
EXPORT void fft64_rnx_svp_prepare_ref(const MOD_RNX* module, // N
|
||||||
|
RNX_SVP_PPOL* ppol, // output
|
||||||
|
const double* pol // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
|
||||||
|
EXPORT void fft64_rnx_svp_apply_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // output
|
||||||
|
const RNX_SVP_PPOL* ppol, // prepared pol
|
||||||
|
const double* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
#endif // SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H
|
||||||
@@ -0,0 +1,91 @@
#include <memory.h>

#include "vec_rnx_arithmetic_private.h"
#include "zn_arithmetic_private.h"

EXPORT void vec_rnx_to_znx32_ref(                      //
    const MOD_RNX* module,                             // N
    int32_t* res, uint64_t res_size, uint64_t res_sl,  // res
    const double* a, uint64_t a_size, uint64_t a_sl    // a
) {
  const uint64_t nn = module->n;
  const uint64_t msize = res_size < a_size ? res_size : a_size;
  for (uint64_t i = 0; i < msize; ++i) {
    dbl_round_to_i32_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
  }
  for (uint64_t i = msize; i < res_size; ++i) {
    memset(res + i * res_sl, 0, nn * sizeof(int32_t));
  }
}

EXPORT void vec_rnx_from_znx32_ref(                   //
    const MOD_RNX* module,                            // N
    double* res, uint64_t res_size, uint64_t res_sl,  // res
    const int32_t* a, uint64_t a_size, uint64_t a_sl  // a
) {
  const uint64_t nn = module->n;
  const uint64_t msize = res_size < a_size ? res_size : a_size;
  for (uint64_t i = 0; i < msize; ++i) {
    i32_to_dbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
  }
  for (uint64_t i = msize; i < res_size; ++i) {
    memset(res + i * res_sl, 0, nn * sizeof(double));
  }
}

EXPORT void vec_rnx_to_tnx32_ref(                      //
    const MOD_RNX* module,                             // N
    int32_t* res, uint64_t res_size, uint64_t res_sl,  // res
    const double* a, uint64_t a_size, uint64_t a_sl    // a
) {
  const uint64_t nn = module->n;
  const uint64_t msize = res_size < a_size ? res_size : a_size;
  for (uint64_t i = 0; i < msize; ++i) {
    dbl_to_tn32_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
  }
  for (uint64_t i = msize; i < res_size; ++i) {
    memset(res + i * res_sl, 0, nn * sizeof(int32_t));
  }
}

EXPORT void vec_rnx_from_tnx32_ref(                   //
    const MOD_RNX* module,                            // N
    double* res, uint64_t res_size, uint64_t res_sl,  // res
    const int32_t* a, uint64_t a_size, uint64_t a_sl  // a
) {
  const uint64_t nn = module->n;
  const uint64_t msize = res_size < a_size ? res_size : a_size;
  for (uint64_t i = 0; i < msize; ++i) {
    tn32_to_dbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
  }
  for (uint64_t i = msize; i < res_size; ++i) {
    memset(res + i * res_sl, 0, nn * sizeof(double));
  }
}

static void dbl_to_tndbl_ref(        //
    const void* UNUSED,              // N
    double* res, uint64_t res_size,  // res
    const double* a, uint64_t a_size // a
) {
  static const double OFF_CST = INT64_C(3) << 51;
  const uint64_t msize = res_size < a_size ? res_size : a_size;
  for (uint64_t i = 0; i < msize; ++i) {
    double ai = a[i] + OFF_CST;
    res[i] = a[i] - (ai - OFF_CST);
  }
  memset(res + msize, 0, (res_size - msize) * sizeof(double));
}

EXPORT void vec_rnx_to_tnxdbl_ref(                    //
    const MOD_RNX* module,                            // N
    double* res, uint64_t res_size, uint64_t res_sl,  // res
    const double* a, uint64_t a_size, uint64_t a_sl   // a
) {
  const uint64_t nn = module->n;
  const uint64_t msize = res_size < a_size ? res_size : a_size;
  for (uint64_t i = 0; i < msize; ++i) {
    dbl_to_tndbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
  }
  for (uint64_t i = msize; i < res_size; ++i) {
    memset(res + i * res_sl, 0, nn * sizeof(double));
  }
}
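dbl_to_tndbl_ref above reduces each coefficient modulo 1 into [-1/2, 1/2) with the usual double-precision rounding trick: OFF_CST = 3*2^51 lies in [2^52, 2^53), where the spacing between consecutive doubles is exactly 1, so for |x| < 2^51 the expression (x + OFF_CST) - OFF_CST evaluates to the integer nearest to x, and x minus that integer is the centered fractional part. For example, x = 2.7 gives (2.7 + OFF_CST) - OFF_CST = 3.0 and the stored value is -0.3, while x = -0.2 is stored unchanged. This assumes the default round-to-nearest-even floating-point mode.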
@@ -0,0 +1,47 @@
#include <string.h>

#include "../coeffs/coeffs_arithmetic.h"
#include "vec_rnx_arithmetic_private.h"

EXPORT uint64_t fft64_bytes_of_rnx_svp_ppol(const MOD_RNX* module) { return module->n * sizeof(double); }

EXPORT RNX_SVP_PPOL* new_rnx_svp_ppol(const MOD_RNX* module) { return spqlios_alloc(bytes_of_rnx_svp_ppol(module)); }

EXPORT void delete_rnx_svp_ppol(RNX_SVP_PPOL* ppol) { spqlios_free(ppol); }

/** @brief prepares a svp polynomial */
EXPORT void fft64_rnx_svp_prepare_ref(const MOD_RNX* module,  // N
                                      RNX_SVP_PPOL* ppol,     // output
                                      const double* pol       // a
) {
  double* const dppol = (double*)ppol;
  rnx_divide_by_m_ref(module->n, module->m, dppol, pol);
  reim_fft(module->precomp.fft64.p_fft, dppol);
}

EXPORT void fft64_rnx_svp_apply_ref(                  //
    const MOD_RNX* module,                            // N
    double* res, uint64_t res_size, uint64_t res_sl,  // output
    const RNX_SVP_PPOL* ppol,                         // prepared pol
    const double* a, uint64_t a_size, uint64_t a_sl   // a
) {
  const uint64_t nn = module->n;
  double* const dppol = (double*)ppol;

  const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size;
  for (uint64_t i = 0; i < auto_end_idx; ++i) {
    const double* a_ptr = a + i * a_sl;
    double* const res_ptr = res + i * res_sl;
    // copy the polynomial to res, apply fft in place, call fftvec_mul, apply ifft in place.
    memcpy(res_ptr, a_ptr, nn * sizeof(double));
    reim_fft(module->precomp.fft64.p_fft, (double*)res_ptr);
    reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, res_ptr, res_ptr, dppol);
    reim_ifft(module->precomp.fft64.p_ifft, res_ptr);
  }

  // then extend with zeros
  for (uint64_t i = auto_end_idx; i < res_size; ++i) {
    memset(res + i * res_sl, 0, nn * sizeof(double));
  }
}
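A minimal usage sketch of the scalar-vector product path defined above (illustrative only; the module, buffer names and sizes are placeholders, and the public rnx_svp_* wrappers are expected to reach these fft64_*_ref implementations through the module vtable):

  // multiply every limb of the vector a by one fixed polynomial p of nn doubles
  RNX_SVP_PPOL* pp = new_rnx_svp_ppol(module);
  rnx_svp_prepare(module, pp, p);
  rnx_svp_apply(module, res, res_size, nn, pp, a, a_size, nn);
  delete_rnx_svp_ppol(pp);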
@@ -0,0 +1,254 @@
|
|||||||
|
#include <assert.h>
|
||||||
|
#include <immintrin.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "../coeffs/coeffs_arithmetic.h"
|
||||||
|
#include "../reim/reim_fft.h"
|
||||||
|
#include "../reim4/reim4_arithmetic.h"
|
||||||
|
#include "vec_rnx_arithmetic_private.h"
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void fft64_rnx_vmp_prepare_contiguous_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double* mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
// there is an edge case if nn < 8
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
const uint64_t m = module->m;
|
||||||
|
|
||||||
|
double* const dtmp = (double*)tmp_space;
|
||||||
|
double* const output_mat = (double*)pmat;
|
||||||
|
double* start_addr = (double*)pmat;
|
||||||
|
uint64_t offset = nrows * ncols * 8;
|
||||||
|
|
||||||
|
if (nn >= 8) {
|
||||||
|
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||||
|
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||||
|
rnx_divide_by_m_avx(nn, m, dtmp, mat + (row_i * ncols + col_i) * nn);
|
||||||
|
reim_fft(module->precomp.fft64.p_fft, dtmp);
|
||||||
|
|
||||||
|
if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
|
||||||
|
// special case: last column out of an odd column number
|
||||||
|
start_addr = output_mat + col_i * nrows * 8 // col == ncols-1
|
||||||
|
+ row_i * 8;
|
||||||
|
} else {
|
||||||
|
// general case: columns go by pair
|
||||||
|
start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
|
||||||
|
+ row_i * 2 * 8 // third: row index
|
||||||
|
+ (col_i % 2) * 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||||
|
// extract blk from tmp and save it
|
||||||
|
reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, dtmp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||||
|
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||||
|
double* res = output_mat + (col_i * nrows + row_i) * nn;
|
||||||
|
rnx_divide_by_m_avx(nn, m, res, mat + (row_i * ncols + col_i) * nn);
|
||||||
|
reim_fft(module->precomp.fft64.p_fft, res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
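/* Layout recap, derived from the addressing code above: for nn >= 8, entry (row_i, col_i) of
 * the prepared matrix is split into m/4 reim4 blocks of 8 doubles, and block blk_i starts at
 * double index
 *     blk_i * (nrows * ncols * 8)            // reim4 blocks are the outer dimension
 *   + (col_i / 2) * (2 * nrows * 8)          // columns are stored by pairs
 *   + row_i * 16 + (col_i % 2) * 8           // rows interleave the two columns of a pair
 * except for the last column when ncols is odd, which starts at
 *     blk_i * (nrows * ncols * 8) + col_i * (nrows * 8) + row_i * 8.
 * For nn < 8, the matrix is simply stored as column-major polynomials of nn doubles. */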
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||||
|
EXPORT void fft64_rnx_vmp_prepare_dblptr_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double** mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||||
|
fft64_rnx_vmp_prepare_row_avx(module, pmat, mat[row_i], row_i, nrows, ncols, tmp_space);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||||
|
EXPORT void fft64_rnx_vmp_prepare_row_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double* row, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
// there is an edge case if nn < 8
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
const uint64_t m = module->m;
|
||||||
|
|
||||||
|
double* const dtmp = (double*)tmp_space;
|
||||||
|
double* const output_mat = (double*)pmat;
|
||||||
|
double* start_addr = (double*)pmat;
|
||||||
|
uint64_t offset = nrows * ncols * 8;
|
||||||
|
|
||||||
|
if (nn >= 8) {
|
||||||
|
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||||
|
rnx_divide_by_m_avx(nn, m, dtmp, row + col_i * nn);
|
||||||
|
reim_fft(module->precomp.fft64.p_fft, dtmp);
|
||||||
|
|
||||||
|
if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
|
||||||
|
// special case: last column out of an odd column number
|
||||||
|
start_addr = output_mat + col_i * nrows * 8 // col == ncols-1
|
||||||
|
+ row_i * 8;
|
||||||
|
} else {
|
||||||
|
// general case: columns go by pair
|
||||||
|
start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
|
||||||
|
+ row_i * 2 * 8 // third: row index
|
||||||
|
+ (col_i % 2) * 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||||
|
// extract blk from tmp and save it
|
||||||
|
reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, dtmp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||||
|
double* res = output_mat + (col_i * nrows + row_i) * nn;
|
||||||
|
rnx_divide_by_m_avx(nn, m, res, row + col_i * nn);
|
||||||
|
reim_fft(module->precomp.fft64.p_fft, res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief applies a vmp product res = a_dft x pmat, in DFT space */
|
||||||
|
EXPORT void fft64_rnx_vmp_apply_dft_to_dft_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
) {
|
||||||
|
const uint64_t m = module->m;
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
|
||||||
|
double* mat2cols_output = (double*)tmp_space; // 128 bytes
|
||||||
|
double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes
|
||||||
|
|
||||||
|
double* mat_input = (double*)pmat;
|
||||||
|
|
||||||
|
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||||
|
const uint64_t col_max = ncols < res_size ? ncols : res_size;
|
||||||
|
|
||||||
|
if (row_max > 0 && col_max > 0) {
|
||||||
|
if (nn >= 8) {
|
||||||
|
// let's do some prefetching of the GSW key, since on some cpus,
|
||||||
|
// it helps
|
||||||
|
const uint64_t ms4 = m >> 2; // m/4
|
||||||
|
const uint64_t gsw_iter_doubles = 8 * nrows * ncols;
|
||||||
|
const uint64_t pref_doubles = 1200;
|
||||||
|
const double* gsw_pref_ptr = mat_input;
|
||||||
|
const double* const gsw_ptr_end = mat_input + ms4 * gsw_iter_doubles;
|
||||||
|
const double* gsw_pref_ptr_target = mat_input + pref_doubles;
|
||||||
|
for (; gsw_pref_ptr < gsw_pref_ptr_target; gsw_pref_ptr += 8) {
|
||||||
|
__builtin_prefetch(gsw_pref_ptr, 0, _MM_HINT_T0);
|
||||||
|
}
|
||||||
|
const double* mat_blk_start;
|
||||||
|
uint64_t blk_i;
|
||||||
|
for (blk_i = 0, mat_blk_start = mat_input; blk_i < ms4; blk_i++, mat_blk_start += gsw_iter_doubles) {
|
||||||
|
// prefetch the next iteration
|
||||||
|
if (gsw_pref_ptr_target < gsw_ptr_end) {
|
||||||
|
gsw_pref_ptr_target += gsw_iter_doubles;
|
||||||
|
if (gsw_pref_ptr_target > gsw_ptr_end) gsw_pref_ptr_target = gsw_ptr_end;
|
||||||
|
for (; gsw_pref_ptr < gsw_pref_ptr_target; gsw_pref_ptr += 8) {
|
||||||
|
__builtin_prefetch(gsw_pref_ptr, 0, _MM_HINT_T0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reim4_extract_1blk_from_contiguous_reim_sl_avx(m, a_sl, row_max, blk_i, extracted_blk, a_dft);
|
||||||
|
// apply mat2cols
|
||||||
|
for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) {
|
||||||
|
uint64_t col_offset = col_i * (8 * nrows);
|
||||||
|
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
|
||||||
|
reim4_save_1blk_to_reim_avx(m, blk_i, res + col_i * res_sl, mat2cols_output);
|
||||||
|
reim4_save_1blk_to_reim_avx(m, blk_i, res + (col_i + 1) * res_sl, mat2cols_output + 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if col_max is odd, then special case
|
||||||
|
if (col_max % 2 == 1) {
|
||||||
|
uint64_t last_col = col_max - 1;
|
||||||
|
uint64_t col_offset = last_col * (8 * nrows);
|
||||||
|
|
||||||
|
// the last column is alone in the pmat: vec_mat1col
|
||||||
|
if (ncols == col_max) {
|
||||||
|
reim4_vec_mat1col_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
} else {
|
||||||
|
// the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
|
||||||
|
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
}
|
||||||
|
reim4_save_1blk_to_reim_avx(m, blk_i, res + last_col * res_sl, mat2cols_output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const double* in;
|
||||||
|
uint64_t in_sl;
|
||||||
|
if (res == a_dft) {
|
||||||
|
// it is in place: copy the input vector
|
||||||
|
in = (double*)tmp_space;
|
||||||
|
in_sl = nn;
|
||||||
|
// vec_rnx_copy(module, (double*)tmp_space, row_max, nn, a_dft, row_max, a_sl);
|
||||||
|
for (uint64_t row_i = 0; row_i < row_max; row_i++) {
|
||||||
|
memcpy((double*)tmp_space + row_i * nn, a_dft + row_i * a_sl, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// it is out of place: do the product directly
|
||||||
|
in = a_dft;
|
||||||
|
in_sl = a_sl;
|
||||||
|
}
|
||||||
|
for (uint64_t col_i = 0; col_i < col_max; col_i++) {
|
||||||
|
double* pmat_col = mat_input + col_i * nrows * nn;
|
||||||
|
{
|
||||||
|
reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, //
|
||||||
|
res + col_i * res_sl, //
|
||||||
|
in, //
|
||||||
|
pmat_col);
|
||||||
|
}
|
||||||
|
for (uint64_t row_i = 1; row_i < row_max; row_i++) {
|
||||||
|
reim_fftvec_addmul(module->precomp.fft64.p_fftvec_addmul, //
|
||||||
|
res + col_i * res_sl, //
|
||||||
|
in + row_i * in_sl, //
|
||||||
|
pmat_col + row_i * nn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// zero out remaining bytes (if any)
|
||||||
|
for (uint64_t i = col_max; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief applies a vmp product res = a x pmat */
|
||||||
|
EXPORT void fft64_rnx_vmp_apply_tmp_a_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a)
|
||||||
|
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
const uint64_t rows = nrows < a_size ? nrows : a_size;
|
||||||
|
const uint64_t cols = ncols < res_size ? ncols : res_size;
|
||||||
|
|
||||||
|
// fft is done in place on the input (tmpa is destroyed)
|
||||||
|
for (uint64_t i = 0; i < rows; ++i) {
|
||||||
|
reim_fft(module->precomp.fft64.p_fft, tmpa + i * a_sl);
|
||||||
|
}
|
||||||
|
fft64_rnx_vmp_apply_dft_to_dft_avx(module, //
|
||||||
|
res, cols, res_sl, //
|
||||||
|
tmpa, rows, a_sl, //
|
||||||
|
pmat, nrows, ncols, //
|
||||||
|
tmp_space);
|
||||||
|
// ifft is done in place on the output
|
||||||
|
for (uint64_t i = 0; i < cols; ++i) {
|
||||||
|
reim_ifft(module->precomp.fft64.p_ifft, res + i * res_sl);
|
||||||
|
}
|
||||||
|
// zero out the remaining positions
|
||||||
|
for (uint64_t i = cols; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
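For orientation, here is an editor's sketch (not part of this commit) of how a caller would drive the AVX entry points above: prepare the matrix once, then apply it to a row-major vector of polynomials. It only uses the fft64_rnx_vmp_* symbols shown in this file and in the reference file that follows; the module construction, the ring dimension nn and any buffer-alignment handling are assumed to be supplied by the caller.

#include <stdint.h>
#include <stdlib.h>

// res = tmpa x pmat, where mat[row_i] points to a row of ncols polynomials of size nn (hypothetical caller)
static void rnx_vmp_sketch(const MOD_RNX* module, uint64_t nn,
                           const double** mat, uint64_t nrows, uint64_t ncols,
                           double* tmpa, uint64_t a_size,  // overwritten by the in-place fft
                           double* res, uint64_t res_size) {
  // allocate the prepared matrix and the two scratch areas (alignment handling elided)
  RNX_VMP_PMAT* pmat = malloc(fft64_bytes_of_rnx_vmp_pmat(module, nrows, ncols));
  uint8_t* prep_tmp = malloc(fft64_rnx_vmp_prepare_tmp_bytes_avx(module));
  uint8_t* apply_tmp = malloc(fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx(module, res_size, a_size, nrows, ncols));

  // prepare once, then apply; both input and output use slice nn between consecutive polynomials
  fft64_rnx_vmp_prepare_dblptr_avx(module, pmat, mat, nrows, ncols, prep_tmp);
  fft64_rnx_vmp_apply_tmp_a_avx(module, res, res_size, nn, tmpa, a_size, nn, pmat, nrows, ncols, apply_tmp);

  free(apply_tmp);
  free(prep_tmp);
  free(pmat);
}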
@@ -0,0 +1,309 @@
#include <assert.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "../coeffs/coeffs_arithmetic.h"
|
||||||
|
#include "../reim/reim_fft.h"
|
||||||
|
#include "../reim4/reim4_arithmetic.h"
|
||||||
|
#include "vec_rnx_arithmetic_private.h"
|
||||||
|
|
||||||
|
/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */
|
||||||
|
EXPORT uint64_t fft64_bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols) { // dimensions
|
||||||
|
return nrows * ncols * module->n * sizeof(double);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void fft64_rnx_vmp_prepare_contiguous_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double* mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
// there is an edge case if nn < 8
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
const uint64_t m = module->m;
|
||||||
|
|
||||||
|
double* const dtmp = (double*)tmp_space;
|
||||||
|
double* const output_mat = (double*)pmat;
|
||||||
|
double* start_addr = (double*)pmat;
|
||||||
|
uint64_t offset = nrows * ncols * 8;
|
||||||
|
|
||||||
|
if (nn >= 8) {
|
||||||
|
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||||
|
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||||
|
rnx_divide_by_m_ref(nn, m, dtmp, mat + (row_i * ncols + col_i) * nn);
|
||||||
|
reim_fft(module->precomp.fft64.p_fft, dtmp);
|
||||||
|
|
||||||
|
if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
|
||||||
|
// special case: last column out of an odd column number
|
||||||
|
start_addr = output_mat + col_i * nrows * 8 // col == ncols-1
|
||||||
|
+ row_i * 8;
|
||||||
|
} else {
|
||||||
|
// general case: columns go by pair
|
||||||
|
start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
|
||||||
|
+ row_i * 2 * 8 // third: row index
|
||||||
|
+ (col_i % 2) * 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||||
|
// extract blk from tmp and save it
|
||||||
|
reim4_extract_1blk_from_reim_ref(m, blk_i, start_addr + blk_i * offset, dtmp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||||
|
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||||
|
double* res = output_mat + (col_i * nrows + row_i) * nn;
|
||||||
|
rnx_divide_by_m_ref(nn, m, res, mat + (row_i * ncols + col_i) * nn);
|
||||||
|
reim_fft(module->precomp.fft64.p_fft, res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||||
|
EXPORT void fft64_rnx_vmp_prepare_dblptr_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double** mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||||
|
fft64_rnx_vmp_prepare_row_ref(module, pmat, mat[row_i], row_i, nrows, ncols, tmp_space);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||||
|
EXPORT void fft64_rnx_vmp_prepare_row_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
RNX_VMP_PMAT* pmat, // output
|
||||||
|
const double* row, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
// there is an edge case if nn < 8
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
const uint64_t m = module->m;
|
||||||
|
|
||||||
|
double* const dtmp = (double*)tmp_space;
|
||||||
|
double* const output_mat = (double*)pmat;
|
||||||
|
double* start_addr = (double*)pmat;
|
||||||
|
uint64_t offset = nrows * ncols * 8;
|
||||||
|
|
||||||
|
if (nn >= 8) {
|
||||||
|
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||||
|
rnx_divide_by_m_ref(nn, m, dtmp, row + col_i * nn);
|
||||||
|
reim_fft(module->precomp.fft64.p_fft, dtmp);
|
||||||
|
|
||||||
|
if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
|
||||||
|
// special case: last column out of an odd column number
|
||||||
|
start_addr = output_mat + col_i * nrows * 8 // col == ncols-1
|
||||||
|
+ row_i * 8;
|
||||||
|
} else {
|
||||||
|
// general case: columns go by pair
|
||||||
|
start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
|
||||||
|
+ row_i * 2 * 8 // third: row index
|
||||||
|
+ (col_i % 2) * 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||||
|
// extract blk from tmp and save it
|
||||||
|
reim4_extract_1blk_from_reim_ref(m, blk_i, start_addr + blk_i * offset, dtmp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||||
|
double* res = output_mat + (col_i * nrows + row_i) * nn;
|
||||||
|
rnx_divide_by_m_ref(nn, m, res, row + col_i * nn);
|
||||||
|
reim_fft(module->precomp.fft64.p_fft, res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief number of scratch bytes necessary to prepare a matrix */
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_ref(const MOD_RNX* module) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
return nn * sizeof(int64_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief applies a vmp product res = a x pmat (a and res given in DFT space) */
|
||||||
|
EXPORT void fft64_rnx_vmp_apply_dft_to_dft_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
) {
|
||||||
|
const uint64_t m = module->m;
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
|
||||||
|
double* mat2cols_output = (double*)tmp_space; // 128 bytes
|
||||||
|
double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes
|
||||||
|
|
||||||
|
double* mat_input = (double*)pmat;
|
||||||
|
|
||||||
|
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||||
|
const uint64_t col_max = ncols < res_size ? ncols : res_size;
|
||||||
|
|
||||||
|
if (row_max > 0 && col_max > 0) {
|
||||||
|
if (nn >= 8) {
|
||||||
|
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||||
|
double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols);
|
||||||
|
|
||||||
|
reim4_extract_1blk_from_contiguous_reim_sl_ref(m, a_sl, row_max, blk_i, extracted_blk, a_dft);
|
||||||
|
// apply mat2cols
|
||||||
|
for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) {
|
||||||
|
uint64_t col_offset = col_i * (8 * nrows);
|
||||||
|
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
|
||||||
|
reim4_save_1blk_to_reim_ref(m, blk_i, res + col_i * res_sl, mat2cols_output);
|
||||||
|
reim4_save_1blk_to_reim_ref(m, blk_i, res + (col_i + 1) * res_sl, mat2cols_output + 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if col_max is odd, then special case
|
||||||
|
if (col_max % 2 == 1) {
|
||||||
|
uint64_t last_col = col_max - 1;
|
||||||
|
uint64_t col_offset = last_col * (8 * nrows);
|
||||||
|
|
||||||
|
// the last column is alone in the pmat: vec_mat1col
|
||||||
|
if (ncols == col_max) {
|
||||||
|
reim4_vec_mat1col_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
} else {
|
||||||
|
// the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
|
||||||
|
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
}
|
||||||
|
reim4_save_1blk_to_reim_ref(m, blk_i, res + last_col * res_sl, mat2cols_output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const double* in;
|
||||||
|
uint64_t in_sl;
|
||||||
|
if (res == a_dft) {
|
||||||
|
// it is in place: copy the input vector
|
||||||
|
in = (double*)tmp_space;
|
||||||
|
in_sl = nn;
|
||||||
|
// vec_rnx_copy(module, (double*)tmp_space, row_max, nn, a_dft, row_max, a_sl);
|
||||||
|
for (uint64_t row_i = 0; row_i < row_max; row_i++) {
|
||||||
|
memcpy((double*)tmp_space + row_i * nn, a_dft + row_i * a_sl, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// it is out of place: do the product directly
|
||||||
|
in = a_dft;
|
||||||
|
in_sl = a_sl;
|
||||||
|
}
|
||||||
|
for (uint64_t col_i = 0; col_i < col_max; col_i++) {
|
||||||
|
double* pmat_col = mat_input + col_i * nrows * nn;
|
||||||
|
{
|
||||||
|
reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, //
|
||||||
|
res + col_i * res_sl, //
|
||||||
|
in, //
|
||||||
|
pmat_col);
|
||||||
|
}
|
||||||
|
for (uint64_t row_i = 1; row_i < row_max; row_i++) {
|
||||||
|
reim_fftvec_addmul(module->precomp.fft64.p_fftvec_addmul, //
|
||||||
|
res + col_i * res_sl, //
|
||||||
|
in + row_i * in_sl, //
|
||||||
|
pmat_col + row_i * nn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// zero out the remaining output positions (if any)
|
||||||
|
for (uint64_t i = col_max; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief applies a vmp product res = a x pmat */
|
||||||
|
EXPORT void fft64_rnx_vmp_apply_tmp_a_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a)
|
||||||
|
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
|
||||||
|
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->n;
|
||||||
|
const uint64_t rows = nrows < a_size ? nrows : a_size;
|
||||||
|
const uint64_t cols = ncols < res_size ? ncols : res_size;
|
||||||
|
|
||||||
|
// fft is done in place on the input (tmpa is destroyed)
|
||||||
|
for (uint64_t i = 0; i < rows; ++i) {
|
||||||
|
reim_fft(module->precomp.fft64.p_fft, tmpa + i * a_sl);
|
||||||
|
}
|
||||||
|
fft64_rnx_vmp_apply_dft_to_dft_ref(module, //
|
||||||
|
res, cols, res_sl, //
|
||||||
|
tmpa, rows, a_sl, //
|
||||||
|
pmat, nrows, ncols, //
|
||||||
|
tmp_space);
|
||||||
|
// ifft is done in place on the output
|
||||||
|
for (uint64_t i = 0; i < cols; ++i) {
|
||||||
|
reim_ifft(module->precomp.fft64.p_ifft, res + i * res_sl);
|
||||||
|
}
|
||||||
|
// zero out the remaining positions
|
||||||
|
for (uint64_t i = cols; i < res_size; ++i) {
|
||||||
|
memset(res + i * res_sl, 0, nn * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
) {
|
||||||
|
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||||
|
|
||||||
|
return (128) + (64 * row_max);
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef __APPLE__
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
) {
|
||||||
|
return fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref(module, res_size, a_size, nrows, ncols);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
) __attribute((alias("fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref")));
|
||||||
|
#endif
|
||||||
|
// avx aliases that need to be defined in the same .c file
|
||||||
|
|
||||||
|
/** @brief number of scratch bytes necessary to prepare a matrix */
|
||||||
|
#ifdef __APPLE__
|
||||||
|
#pragma weak fft64_rnx_vmp_prepare_tmp_bytes_avx = fft64_rnx_vmp_prepare_tmp_bytes_ref
|
||||||
|
#else
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_avx(const MOD_RNX* module)
|
||||||
|
__attribute((alias("fft64_rnx_vmp_prepare_tmp_bytes_ref")));
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
#ifdef __APPLE__
|
||||||
|
#pragma weak fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref
|
||||||
|
#else
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
) __attribute((alias("fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref")));
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef __APPLE__
|
||||||
|
#pragma weak fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref
|
||||||
|
#else
|
||||||
|
EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx( //
|
||||||
|
const MOD_RNX* module, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
) __attribute((alias("fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref")));
|
||||||
|
#endif
|
||||||
|
// wrappers
|
||||||
@@ -0,0 +1,369 @@
#include <assert.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "../coeffs/coeffs_arithmetic.h"
|
||||||
|
#include "../q120/q120_arithmetic.h"
|
||||||
|
#include "../q120/q120_ntt.h"
|
||||||
|
#include "../reim/reim_fft_internal.h"
|
||||||
|
#include "../reim4/reim4_arithmetic.h"
|
||||||
|
#include "vec_znx_arithmetic.h"
|
||||||
|
#include "vec_znx_arithmetic_private.h"
|
||||||
|
|
||||||
|
// general function (virtual dispatch)
|
||||||
|
|
||||||
|
EXPORT void vec_znx_add(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_add(module, // N
|
||||||
|
res, res_size, res_sl, // res
|
||||||
|
a, a_size, a_sl, // a
|
||||||
|
b, b_size, b_sl // b
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_sub(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_sub(module, // N
|
||||||
|
res, res_size, res_sl, // res
|
||||||
|
a, a_size, a_sl, // a
|
||||||
|
b, b_size, b_sl // b
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_rotate(const MODULE* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_rotate(module, // N
|
||||||
|
p, // p
|
||||||
|
res, res_size, res_sl, // res
|
||||||
|
a, a_size, a_sl // a
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_mul_xp_minus_one(const MODULE* module, // N
|
||||||
|
const int64_t p, // p
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_mul_xp_minus_one(module, // N
|
||||||
|
p, // p
|
||||||
|
res, res_size, res_sl, // res
|
||||||
|
a, a_size, a_sl // a
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_automorphism(const MODULE* module, // N
|
||||||
|
const int64_t p, // X->X^p
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_automorphism(module, // N
|
||||||
|
p, // p
|
||||||
|
res, res_size, res_sl, // res
|
||||||
|
a, a_size, a_sl // a
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_normalize_base2k(const MODULE* module, uint64_t nn, // N
|
||||||
|
uint64_t log2_base2k, // output base 2^K
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
uint8_t* tmp_space // scratch space of size >= N
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_normalize_base2k(module, nn, // N
|
||||||
|
log2_base2k, // log2_base2k
|
||||||
|
res, res_size, res_sl, // res
|
||||||
|
a, a_size, a_sl, // a
|
||||||
|
tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N
|
||||||
|
) {
|
||||||
|
return module->func.vec_znx_normalize_base2k_tmp_bytes(module, nn // N
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// specialized function (ref)
|
||||||
|
|
||||||
|
EXPORT void vec_znx_add_ref(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
if (a_size <= b_size) {
|
||||||
|
const uint64_t sum_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
|
||||||
|
// add up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < sum_idx; ++i) {
|
||||||
|
znx_add_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
// then copy to the largest dimension
|
||||||
|
for (uint64_t i = sum_idx; i < copy_idx; ++i) {
|
||||||
|
znx_copy_i64_ref(nn, res + i * res_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const uint64_t sum_idx = res_size < b_size ? res_size : b_size;
|
||||||
|
const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
// add up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < sum_idx; ++i) {
|
||||||
|
znx_add_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
// then copy to the largest dimension
|
||||||
|
for (uint64_t i = sum_idx; i < copy_idx; ++i) {
|
||||||
|
znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl);
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_sub_ref(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
if (a_size <= b_size) {
|
||||||
|
const uint64_t sub_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
|
||||||
|
// subtract up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < sub_idx; ++i) {
|
||||||
|
znx_sub_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
// then negate to the largest dimension
|
||||||
|
for (uint64_t i = sub_idx; i < copy_idx; ++i) {
|
||||||
|
znx_negate_i64_ref(nn, res + i * res_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const uint64_t sub_idx = res_size < b_size ? res_size : b_size;
|
||||||
|
const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
// subtract up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < sub_idx; ++i) {
|
||||||
|
znx_sub_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
|
||||||
|
}
|
||||||
|
// then copy to the largest dimension
|
||||||
|
for (uint64_t i = sub_idx; i < copy_idx; ++i) {
|
||||||
|
znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl);
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = copy_idx; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_rotate_ref(const MODULE* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
|
||||||
|
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
// rotate up to the smallest dimension
|
||||||
|
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||||
|
int64_t* res_ptr = res + i * res_sl;
|
||||||
|
const int64_t* a_ptr = a + i * a_sl;
|
||||||
|
if (res_ptr == a_ptr) {
|
||||||
|
znx_rotate_inplace_i64(nn, p, res_ptr);
|
||||||
|
} else {
|
||||||
|
znx_rotate_i64(nn, p, res_ptr, a_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_mul_xp_minus_one_ref(const MODULE* module, // N
|
||||||
|
const int64_t p, // p
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
|
||||||
|
const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < rot_end_idx; ++i) {
|
||||||
|
int64_t* res_ptr = res + i * res_sl;
|
||||||
|
const int64_t* a_ptr = a + i * a_sl;
|
||||||
|
if (res_ptr == a_ptr) {
|
||||||
|
znx_mul_xp_minus_one_inplace_i64(nn, p, res_ptr);
|
||||||
|
} else {
|
||||||
|
znx_mul_xp_minus_one_i64(nn, p, res_ptr, a_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = rot_end_idx; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_automorphism_ref(const MODULE* module, // N
|
||||||
|
const int64_t p, // X->X^p
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
|
||||||
|
const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size;
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < auto_end_idx; ++i) {
|
||||||
|
int64_t* res_ptr = res + i * res_sl;
|
||||||
|
const int64_t* a_ptr = a + i * a_sl;
|
||||||
|
if (res_ptr == a_ptr) {
|
||||||
|
znx_automorphism_inplace_i64(nn, p, res_ptr);
|
||||||
|
} else {
|
||||||
|
znx_automorphism_i64(nn, p, res_ptr, a_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// then extend with zeros
|
||||||
|
for (uint64_t i = auto_end_idx; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_normalize_base2k_ref(const MODULE* module, uint64_t nn, // N
|
||||||
|
uint64_t log2_base2k, // output base 2^K
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
uint8_t* tmp_space // scratch space of size >= N
|
||||||
|
) {
|
||||||
|
|
||||||
|
// use MSB limb of res for carry propagation
|
||||||
|
int64_t* cout = (int64_t*)tmp_space;
|
||||||
|
int64_t* cin = 0x0;
|
||||||
|
|
||||||
|
// propagate carry until first limb of res
|
||||||
|
int64_t i = a_size - 1;
|
||||||
|
for (; i >= res_size; --i) {
|
||||||
|
znx_normalize(nn, log2_base2k, 0x0, cout, a + i * a_sl, cin);
|
||||||
|
cin = cout;
|
||||||
|
}
|
||||||
|
|
||||||
|
// propagate carry and normalize
|
||||||
|
for (; i >= 1; --i) {
|
||||||
|
znx_normalize(nn, log2_base2k, res + i * res_sl, cout, a + i * a_sl, cin);
|
||||||
|
cin = cout;
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalize last limb
|
||||||
|
znx_normalize(nn, log2_base2k, res, 0x0, a, cin);
|
||||||
|
|
||||||
|
// extend result with zeros
|
||||||
|
for (uint64_t i = a_size; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes_ref(const MODULE* module, uint64_t nn // N
|
||||||
|
) {
|
||||||
|
return nn * sizeof(int64_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
// alias have to be defined in this unit: do not move
|
||||||
|
#ifdef __APPLE__
|
||||||
|
EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes( //
|
||||||
|
const MODULE* module, // N
|
||||||
|
uint64_t nn
|
||||||
|
) {
|
||||||
|
return vec_znx_normalize_base2k_tmp_bytes_ref(module, nn);
|
||||||
|
}
|
||||||
|
EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes( //
|
||||||
|
const MODULE* module, // N
|
||||||
|
uint64_t nn
|
||||||
|
) {
|
||||||
|
return vec_znx_normalize_base2k_tmp_bytes_ref(module, nn);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes( //
|
||||||
|
const MODULE* module, // N
|
||||||
|
uint64_t nn
|
||||||
|
) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref")));
|
||||||
|
|
||||||
|
EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes( //
|
||||||
|
const MODULE* module, // N
|
||||||
|
uint64_t nn
|
||||||
|
) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref")));
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/** @brief sets res = 0 */
|
||||||
|
EXPORT void vec_znx_zero(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl // res
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_zero(module, res, res_size, res_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a */
|
||||||
|
EXPORT void vec_znx_copy(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_copy(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = -a */
|
||||||
|
EXPORT void vec_znx_negate(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_negate(module, res, res_size, res_sl, a, a_size, a_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_zero_ref(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl // res
|
||||||
|
) {
|
||||||
|
uint64_t nn = module->nn;
|
||||||
|
for (uint64_t i = 0; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_copy_ref(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
uint64_t nn = module->nn;
|
||||||
|
uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < smin; ++i) {
|
||||||
|
znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl);
|
||||||
|
}
|
||||||
|
for (uint64_t i = smin; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_negate_ref(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
) {
|
||||||
|
uint64_t nn = module->nn;
|
||||||
|
uint64_t smin = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < smin; ++i) {
|
||||||
|
znx_negate_i64_ref(nn, res + i * res_sl, a + i * a_sl);
|
||||||
|
}
|
||||||
|
for (uint64_t i = smin; i < res_size; ++i) {
|
||||||
|
znx_zero_i64_ref(nn, res + i * res_sl);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,370 @@
#ifndef SPQLIOS_VEC_ZNX_ARITHMETIC_H
|
||||||
|
#define SPQLIOS_VEC_ZNX_ARITHMETIC_H
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "../commons.h"
|
||||||
|
#include "../reim/reim_fft.h"
|
||||||
|
|
||||||
|
/**
 * We support the following module families:
 *  - FFT64:
 *      all the polynomials should fit at all times within 52 bits.
 *      for FHE implementations, the recommended limb-sizes are
 *      between K=10 and 20, which is good for low multiplicative depths.
 *  - NTT120:
 *      all the polynomials should fit at all times within 119 bits.
 *      for FHE implementations, the recommended limb-sizes are
 *      between K=20 and 40, which is good for large multiplicative depths.
 */
typedef enum module_type_t { FFT64, NTT120 } MODULE_TYPE;
|
||||||
|
|
||||||
|
/** @brief opaque structure that describes the modules (ZnX,TnX) and the hardware */
|
||||||
|
typedef struct module_info_t MODULE;
|
||||||
|
/** @brief opaque type that represents a prepared matrix */
|
||||||
|
typedef struct vmp_pmat_t VMP_PMAT;
|
||||||
|
/** @brief opaque type that represents a vector of znx in DFT space */
|
||||||
|
typedef struct vec_znx_dft_t VEC_ZNX_DFT;
|
||||||
|
/** @brief opaque type that represents a vector of znx in large coeffs space */
|
||||||
|
typedef struct vec_znx_bigcoeff_t VEC_ZNX_BIG;
|
||||||
|
/** @brief opaque type that represents a prepared scalar vector product */
|
||||||
|
typedef struct svp_ppol_t SVP_PPOL;
|
||||||
|
/** @brief opaque type that represents a prepared left convolution vector product */
|
||||||
|
typedef struct cnv_pvec_l_t CNV_PVEC_L;
|
||||||
|
/** @brief opaque type that represents a prepared right convolution vector product */
|
||||||
|
typedef struct cnv_pvec_r_t CNV_PVEC_R;
|
||||||
|
|
||||||
|
/** @brief bytes needed for a vec_znx in DFT space */
|
||||||
|
EXPORT uint64_t bytes_of_vec_znx_dft(const MODULE* module, // N
|
||||||
|
uint64_t size);
|
||||||
|
|
||||||
|
/** @brief allocates a vec_znx in DFT space */
|
||||||
|
EXPORT VEC_ZNX_DFT* new_vec_znx_dft(const MODULE* module, // N
|
||||||
|
uint64_t size);
|
||||||
|
|
||||||
|
/** @brief frees memory from a vec_znx in DFT space */
|
||||||
|
EXPORT void delete_vec_znx_dft(VEC_ZNX_DFT* res);
|
||||||
|
|
||||||
|
/** @brief bytes needed for a vec_znx_big */
|
||||||
|
EXPORT uint64_t bytes_of_vec_znx_big(const MODULE* module, // N
|
||||||
|
uint64_t size);
|
||||||
|
|
||||||
|
/** @brief allocates a vec_znx_big */
|
||||||
|
EXPORT VEC_ZNX_BIG* new_vec_znx_big(const MODULE* module, // N
|
||||||
|
uint64_t size);
|
||||||
|
/** @brief frees memory from a vec_znx_big */
|
||||||
|
EXPORT void delete_vec_znx_big(VEC_ZNX_BIG* res);
|
||||||
|
|
||||||
|
/** @brief bytes needed for a prepared vector */
|
||||||
|
EXPORT uint64_t bytes_of_svp_ppol(const MODULE* module); // N
|
||||||
|
|
||||||
|
/** @brief allocates a prepared vector */
|
||||||
|
EXPORT SVP_PPOL* new_svp_ppol(const MODULE* module); // N
|
||||||
|
|
||||||
|
/** @brief frees memory for a prepared vector */
|
||||||
|
EXPORT void delete_svp_ppol(SVP_PPOL* res);
|
||||||
|
|
||||||
|
/** @brief bytes needed for a prepared matrix */
|
||||||
|
EXPORT uint64_t bytes_of_vmp_pmat(const MODULE* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols);
|
||||||
|
|
||||||
|
/** @brief allocates a prepared matrix */
|
||||||
|
EXPORT VMP_PMAT* new_vmp_pmat(const MODULE* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols);
|
||||||
|
|
||||||
|
/** @brief frees memory for a prepared matrix */
|
||||||
|
EXPORT void delete_vmp_pmat(VMP_PMAT* res);
|
||||||
|
|
||||||
|
/**
 * @brief obtain a module info for ring dimension N
 * the module-info knows about:
 *  - the dimension N (or the complex dimension m=N/2)
 *  - any precomputed fft or ntt items
 *  - the hardware (avx, arm64, x86, ...)
 */
|
||||||
|
EXPORT MODULE* new_module_info(uint64_t N, MODULE_TYPE mode);
|
||||||
|
EXPORT void delete_module_info(MODULE* module_info);
|
||||||
|
EXPORT uint64_t module_get_n(const MODULE* module);
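A minimal lifecycle sketch (editorial illustration, not part of the header): create a module, add two small vectors limb-wise, then release it. The ring dimension, limb count and the choice of FFT64 below are arbitrary assumptions.

#include <stdint.h>
#include <stdlib.h>

static void module_lifecycle_sketch(void) {
  const uint64_t nn = 64;  // ring dimension N (assumption)
  MODULE* module = new_module_info(nn, FFT64);

  const uint64_t size = 3;  // limbs per vector
  int64_t* a = calloc(size * nn, sizeof(int64_t));
  int64_t* b = calloc(size * nn, sizeof(int64_t));
  int64_t* r = calloc(size * nn, sizeof(int64_t));
  a[0] = 1;  // a = 1 in the first limb
  b[1] = 1;  // b = X in the first limb

  // r = a + b, all three vectors stored contiguously with slice nn
  vec_znx_add(module, r, size, nn, a, size, nn, b, size, nn);

  free(r);
  free(b);
  free(a);
  delete_module_info(module);
}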
|
||||||
|
|
||||||
|
/** @brief sets res = 0 */
|
||||||
|
EXPORT void vec_znx_zero(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl // res
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a */
|
||||||
|
EXPORT void vec_znx_copy(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = -a */
|
||||||
|
EXPORT void vec_znx_negate(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a + b */
|
||||||
|
EXPORT void vec_znx_add(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a - b */
|
||||||
|
EXPORT void vec_znx_sub(const MODULE* module, // N
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = k-normalize-reduce(a) */
|
||||||
|
EXPORT void vec_znx_normalize_base2k(const MODULE* module, // MODULE
|
||||||
|
uint64_t nn, // N
|
||||||
|
uint64_t log2_base2k, // output base 2^K
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
uint8_t* tmp_space // scratch space (size >= N)
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief returns the minimal byte length of scratch space for vec_znx_normalize_base2k */
|
||||||
|
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N
|
||||||
|
);
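Illustrative only (an editor's sketch with assumed parameters, not part of the header): reducing a vector whose digits overflow the base back to a normalized base-2^K representation, with the scratch buffer sized by the helper above.

#include <stdint.h>
#include <stdlib.h>

static void normalize_sketch(const MODULE* module, uint64_t nn) {
  const uint64_t k = 12;    // output base 2^K (assumption)
  const uint64_t size = 2;  // limbs per vector
  int64_t* a = calloc(size * nn, sizeof(int64_t));
  int64_t* r = calloc(size * nn, sizeof(int64_t));
  a[1 * nn] = (int64_t)1 << 20;  // constant term of the last limb, larger than 2^K

  uint8_t* tmp = malloc(vec_znx_normalize_base2k_tmp_bytes(module, nn));
  // r holds the same value with every digit reduced to the base-2^K range,
  // carries being propagated toward the first limb
  vec_znx_normalize_base2k(module, nn, k, r, size, nn, a, size, nn, tmp);

  free(tmp);
  free(r);
  free(a);
}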
|
||||||
|
|
||||||
|
/** @brief sets res = a . X^p */
|
||||||
|
EXPORT void vec_znx_rotate(const MODULE* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a * (X^{p} - 1) */
|
||||||
|
EXPORT void vec_znx_mul_xp_minus_one(const MODULE* module, // N
|
||||||
|
const int64_t p, // rotation value
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a(X^p) */
|
||||||
|
EXPORT void vec_znx_automorphism(const MODULE* module, // N
|
||||||
|
const int64_t p, // X->X^p
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = 0 */
|
||||||
|
EXPORT void vec_dft_zero(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size // res
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a+b */
|
||||||
|
EXPORT void vec_dft_add(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||||
|
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a-b */
|
||||||
|
EXPORT void vec_dft_sub(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a, uint64_t a_size, // a
|
||||||
|
const VEC_ZNX_DFT* b, uint64_t b_size // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = DFT(a) */
|
||||||
|
EXPORT void vec_znx_dft(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = iDFT(a_dft) -- output in big coeffs space */
|
||||||
|
EXPORT void vec_znx_idft(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||||
|
uint8_t* tmp // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief tmp bytes required for vec_znx_idft */
|
||||||
|
EXPORT uint64_t vec_znx_idft_tmp_bytes(const MODULE* module, uint64_t nn);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief sets res = iDFT(a_dft) -- output in big coeffs space
|
||||||
|
*
|
||||||
|
* @note a_dft is overwritten
|
||||||
|
*/
|
||||||
|
EXPORT void vec_znx_idft_tmp_a(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
VEC_ZNX_DFT* a_dft, uint64_t a_size // a is overwritten
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a+b */
|
||||||
|
EXPORT void vec_znx_big_add(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||||
|
);
|
||||||
|
/** @brief sets res = a+b */
|
||||||
|
EXPORT void vec_znx_big_add_small(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
EXPORT void vec_znx_big_add_small2(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a-b */
|
||||||
|
EXPORT void vec_znx_big_sub(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||||
|
);
|
||||||
|
EXPORT void vec_znx_big_sub_small_b(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
EXPORT void vec_znx_big_sub_small_a(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||||
|
);
|
||||||
|
EXPORT void vec_znx_big_sub_small2(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = k-normalize(a) -- output in int64 coeffs space */
|
||||||
|
EXPORT void vec_znx_big_normalize_base2k(const MODULE* module, // MODULE
|
||||||
|
uint64_t nn, // N
|
||||||
|
uint64_t log2_base2k, // base-2^k
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
uint8_t* tmp_space // temp space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief returns the minimal byte length of scratch space for vec_znx_big_normalize_base2k */
|
||||||
|
EXPORT uint64_t vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = k-normalize(a.subrange) -- output in int64 coeffs space */
|
||||||
|
EXPORT void vec_znx_big_range_normalize_base2k( //
|
||||||
|
const MODULE* module, // MODULE
|
||||||
|
uint64_t nn,
|
||||||
|
uint64_t log2_base2k, // base-2^k
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_range_begin, uint64_t a_range_xend, uint64_t a_range_step, // range
|
||||||
|
uint8_t* tmp_space // temp space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */
|
||||||
|
EXPORT uint64_t vec_znx_big_range_normalize_base2k_tmp_bytes( //
|
||||||
|
const MODULE* module, uint64_t nn // N
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a . X^p */
|
||||||
|
EXPORT void vec_znx_big_rotate(const MODULE* module, // N
|
||||||
|
int64_t p, // rotation value
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief sets res = a(X^p) */
|
||||||
|
EXPORT void vec_znx_big_automorphism(const MODULE* module, // N
|
||||||
|
int64_t p, // X->X^p
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
|
||||||
|
EXPORT void svp_apply_dft(const MODULE* module, // N
|
||||||
|
const VEC_ZNX_DFT* res, uint64_t res_size, // output
|
||||||
|
const SVP_PPOL* ppol, // prepared pol
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
|
||||||
|
EXPORT void svp_apply_dft_to_dft(const MODULE* module, // N
|
||||||
|
const VEC_ZNX_DFT* res, uint64_t res_size,
|
||||||
|
uint64_t res_cols, // output
|
||||||
|
const SVP_PPOL* ppol, // prepared pol
|
||||||
|
const VEC_ZNX_DFT* a, uint64_t a_size, uint64_t a_cols // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief prepares a svp polynomial */
|
||||||
|
EXPORT void svp_prepare(const MODULE* module, // N
|
||||||
|
SVP_PPOL* ppol, // output
|
||||||
|
const int64_t* pol // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief res = a * b : small integer polynomial product */
|
||||||
|
EXPORT void znx_small_single_product(const MODULE* module, // N
|
||||||
|
int64_t* res, // output
|
||||||
|
const int64_t* a, // a
|
||||||
|
const int64_t* b, // b
|
||||||
|
uint8_t* tmp);
|
||||||
|
|
||||||
|
/** @brief tmp bytes required for znx_small_single_product */
|
||||||
|
EXPORT uint64_t znx_small_single_product_tmp_bytes(const MODULE* module, uint64_t nn);
|
||||||
|
|
||||||
|
/** @brief minimal scratch space byte-size required for the vmp_prepare function */
|
||||||
|
EXPORT uint64_t vmp_prepare_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||||
|
uint64_t nrows, uint64_t ncols);
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void vmp_prepare_contiguous(const MODULE* module, // N
|
||||||
|
VMP_PMAT* pmat, // output
|
||||||
|
const int64_t* mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (result in DFT space) adds to res inplace */
|
||||||
|
EXPORT void vmp_apply_dft_add(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, uint64_t pmat_scale, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (result in DFT space) */
|
||||||
|
EXPORT void vmp_apply_dft(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT uint64_t vmp_apply_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
);
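A hedged end-to-end sketch (editorial, with caller-provided sizes assumed): prepare an integer matrix into a VMP_PMAT, apply it to a small vector in DFT space, map the result back to big-coefficient space and normalize it to base 2^K. Only declarations appearing earlier in this header are used.

#include <stdint.h>
#include <stdlib.h>

static void znx_vmp_sketch(const MODULE* module, uint64_t nn,
                           const int64_t* mat, uint64_t nrows, uint64_t ncols,
                           const int64_t* a, uint64_t a_size,
                           int64_t* res, uint64_t res_size, uint64_t k) {
  VMP_PMAT* pmat = new_vmp_pmat(module, nrows, ncols);
  VEC_ZNX_DFT* res_dft = new_vec_znx_dft(module, res_size);
  VEC_ZNX_BIG* res_big = new_vec_znx_big(module, res_size);

  uint8_t* prep_tmp = malloc(vmp_prepare_tmp_bytes(module, nn, nrows, ncols));
  uint8_t* apply_tmp = malloc(vmp_apply_dft_tmp_bytes(module, nn, res_size, a_size, nrows, ncols));
  uint8_t* idft_tmp = malloc(vec_znx_idft_tmp_bytes(module, nn));
  uint8_t* norm_tmp = malloc(vec_znx_big_normalize_base2k_tmp_bytes(module, nn));

  vmp_prepare_contiguous(module, pmat, mat, nrows, ncols, prep_tmp);                 // matrix into DFT space
  vmp_apply_dft(module, res_dft, res_size, a, a_size, nn, pmat, nrows, ncols, apply_tmp);
  vec_znx_idft(module, res_big, res_size, res_dft, res_size, idft_tmp);              // back to big coeffs
  vec_znx_big_normalize_base2k(module, nn, k, res, res_size, nn, res_big, res_size, norm_tmp);

  free(norm_tmp);
  free(idft_tmp);
  free(apply_tmp);
  free(prep_tmp);
  delete_vec_znx_big(res_big);
  delete_vec_znx_dft(res_dft);
  delete_vmp_pmat(pmat);
}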
|
||||||
|
|
||||||
|
/** @brief applies vmp product */
|
||||||
|
EXPORT void vmp_apply_dft_to_dft(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||||
|
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||||
|
const uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief applies vmp product and adds to res inplace */
|
||||||
|
EXPORT void vmp_apply_dft_to_dft_add(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||||
|
const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols,
|
||||||
|
const uint64_t pmat_scale, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT uint64_t vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
);
|
||||||
|
#endif // SPQLIOS_VEC_ZNX_ARITHMETIC_H
|
||||||
@@ -0,0 +1,563 @@
#ifndef SPQLIOS_VEC_ZNX_ARITHMETIC_PRIVATE_H
|
||||||
|
#define SPQLIOS_VEC_ZNX_ARITHMETIC_PRIVATE_H
|
||||||
|
|
||||||
|
#include "../commons_private.h"
|
||||||
|
#include "../q120/q120_ntt.h"
|
||||||
|
#include "vec_znx_arithmetic.h"
|
||||||
|
|
||||||
|
/**
 * Layouts families:
 *
 * fft64:
 *    K: <= 20, N: <= 65536, ell: <= 200
 *    vec<ZnX> normalized: represented by int64
 *    vec<ZnX> large: represented by int64 (expect <=52 bits)
 *    vec<ZnX> DFT: represented by double (reim_fft space)
 *    On AVX2 infrastructure, PMAT, LCNV, RCNV use a special reim4_fft space
 *
 * ntt120:
 *    K: <= 50, N: <= 65536, ell: <= 80
 *    vec<ZnX> normalized: represented by int64
 *    vec<ZnX> large: represented by int128 (expect <=120 bits)
 *    vec<ZnX> DFT: represented by int64x4 (ntt120 space)
 *    On AVX2 infrastructure, PMAT, LCNV, RCNV use a special ntt120 space
 *
 * ntt104:
 *    K: <= 40, N: <= 65536, ell: <= 80
 *    vec<ZnX> normalized: represented by int64
 *    vec<ZnX> large: represented by int128 (expect <=120 bits)
 *    vec<ZnX> DFT: represented by int64x4 (ntt120 space)
 *    On AVX512 infrastructure, PMAT, LCNV, RCNV use a special ntt104 space
 */
|
||||||
|
|
||||||
|
struct fft64_module_info_t {
|
||||||
|
// pre-computation for reim_fft
|
||||||
|
REIM_FFT_PRECOMP* p_fft;
|
||||||
|
// pre-computation for add_fft
|
||||||
|
REIM_FFTVEC_ADD_PRECOMP* add_fft;
|
||||||
|
// pre-computation for add_fft
|
||||||
|
REIM_FFTVEC_SUB_PRECOMP* sub_fft;
|
||||||
|
// pre-computation for mul_fft
|
||||||
|
REIM_FFTVEC_MUL_PRECOMP* mul_fft;
|
||||||
|
// pre-computation for reim_from_znx6
|
||||||
|
REIM_FROM_ZNX64_PRECOMP* p_conv;
|
||||||
|
// pre-computation for reim_tp_znx6
|
||||||
|
REIM_TO_ZNX64_PRECOMP* p_reim_to_znx;
|
||||||
|
// pre-computation for reim_fft
|
||||||
|
REIM_IFFT_PRECOMP* p_ifft;
|
||||||
|
// pre-computation for reim_fftvec_addmul
|
||||||
|
REIM_FFTVEC_ADDMUL_PRECOMP* p_addmul;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct q120_module_info_t {
|
||||||
|
// pre-computation for q120b to q120b ntt
|
||||||
|
q120_ntt_precomp* p_ntt;
|
||||||
|
// pre-computation for q120b to q120b intt
|
||||||
|
q120_ntt_precomp* p_intt;
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO add function types here
typedef typeof(vec_znx_zero) VEC_ZNX_ZERO_F;
typedef typeof(vec_znx_copy) VEC_ZNX_COPY_F;
typedef typeof(vec_znx_negate) VEC_ZNX_NEGATE_F;
typedef typeof(vec_znx_add) VEC_ZNX_ADD_F;
typedef typeof(vec_znx_dft) VEC_ZNX_DFT_F;
typedef typeof(vec_dft_add) VEC_DFT_ADD_F;
typedef typeof(vec_dft_sub) VEC_DFT_SUB_F;
typedef typeof(vec_znx_idft) VEC_ZNX_IDFT_F;
typedef typeof(vec_znx_idft_tmp_bytes) VEC_ZNX_IDFT_TMP_BYTES_F;
typedef typeof(vec_znx_idft_tmp_a) VEC_ZNX_IDFT_TMP_A_F;
typedef typeof(vec_znx_sub) VEC_ZNX_SUB_F;
typedef typeof(vec_znx_rotate) VEC_ZNX_ROTATE_F;
typedef typeof(vec_znx_mul_xp_minus_one) VEC_ZNX_MUL_XP_MINUS_ONE_F;
typedef typeof(vec_znx_automorphism) VEC_ZNX_AUTOMORPHISM_F;
typedef typeof(vec_znx_normalize_base2k) VEC_ZNX_NORMALIZE_BASE2K_F;
typedef typeof(vec_znx_normalize_base2k_tmp_bytes) VEC_ZNX_NORMALIZE_BASE2K_TMP_BYTES_F;
typedef typeof(vec_znx_big_normalize_base2k) VEC_ZNX_BIG_NORMALIZE_BASE2K_F;
typedef typeof(vec_znx_big_normalize_base2k_tmp_bytes) VEC_ZNX_BIG_NORMALIZE_BASE2K_TMP_BYTES_F;
typedef typeof(vec_znx_big_range_normalize_base2k) VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_F;
typedef typeof(vec_znx_big_range_normalize_base2k_tmp_bytes) VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_TMP_BYTES_F;
typedef typeof(vec_znx_big_add) VEC_ZNX_BIG_ADD_F;
typedef typeof(vec_znx_big_add_small) VEC_ZNX_BIG_ADD_SMALL_F;
typedef typeof(vec_znx_big_add_small2) VEC_ZNX_BIG_ADD_SMALL2_F;
typedef typeof(vec_znx_big_sub) VEC_ZNX_BIG_SUB_F;
typedef typeof(vec_znx_big_sub_small_a) VEC_ZNX_BIG_SUB_SMALL_A_F;
typedef typeof(vec_znx_big_sub_small_b) VEC_ZNX_BIG_SUB_SMALL_B_F;
typedef typeof(vec_znx_big_sub_small2) VEC_ZNX_BIG_SUB_SMALL2_F;
typedef typeof(vec_znx_big_rotate) VEC_ZNX_BIG_ROTATE_F;
typedef typeof(vec_znx_big_automorphism) VEC_ZNX_BIG_AUTOMORPHISM_F;
typedef typeof(svp_prepare) SVP_PREPARE;
typedef typeof(svp_apply_dft) SVP_APPLY_DFT_F;
typedef typeof(svp_apply_dft_to_dft) SVP_APPLY_DFT_TO_DFT_F;
typedef typeof(znx_small_single_product) ZNX_SMALL_SINGLE_PRODUCT_F;
typedef typeof(znx_small_single_product_tmp_bytes) ZNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F;
typedef typeof(vmp_prepare_contiguous) VMP_PREPARE_CONTIGUOUS_F;
typedef typeof(vmp_prepare_tmp_bytes) VMP_PREPARE_TMP_BYTES_F;
typedef typeof(vmp_apply_dft) VMP_APPLY_DFT_F;
typedef typeof(vmp_apply_dft_add) VMP_APPLY_DFT_ADD_F;
typedef typeof(vmp_apply_dft_tmp_bytes) VMP_APPLY_DFT_TMP_BYTES_F;
typedef typeof(vmp_apply_dft_to_dft) VMP_APPLY_DFT_TO_DFT_F;
typedef typeof(vmp_apply_dft_to_dft_add) VMP_APPLY_DFT_TO_DFT_ADD_F;
typedef typeof(vmp_apply_dft_to_dft_tmp_bytes) VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F;
typedef typeof(bytes_of_vec_znx_dft) BYTES_OF_VEC_ZNX_DFT_F;
typedef typeof(bytes_of_vec_znx_big) BYTES_OF_VEC_ZNX_BIG_F;
typedef typeof(bytes_of_svp_ppol) BYTES_OF_SVP_PPOL_F;
typedef typeof(bytes_of_vmp_pmat) BYTES_OF_VMP_PMAT_F;

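The block above relies on one compact idiom: `typedef typeof(f) F_T;` names the *function type* of an existing declaration, so the vtable that follows can declare matching pointers without restating every parameter list. A minimal, self-contained illustration of the idiom (all names below are made up, and `typeof` is a GNU extension standardized in C23):

#include <stdint.h>

// toy function whose signature we want to reuse
static uint64_t toy_bytes_of(uint64_t nn, uint64_t size) { return nn * size * sizeof(double); }

// name the function type (not a pointer type), exactly like the typedefs above
typedef typeof(toy_bytes_of) TOY_BYTES_OF_F;

// a vtable entry is then simply a pointer to that type
struct toy_vtable {
  TOY_BYTES_OF_F* bytes_of;
};

static struct toy_vtable vt = {.bytes_of = toy_bytes_of};
// usage: vt.bytes_of(1024, 3);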
struct module_virtual_functions_t {
  // TODO add functions here
  VEC_ZNX_ZERO_F* vec_znx_zero;
  VEC_ZNX_COPY_F* vec_znx_copy;
  VEC_ZNX_NEGATE_F* vec_znx_negate;
  VEC_ZNX_ADD_F* vec_znx_add;
  VEC_ZNX_DFT_F* vec_znx_dft;
  VEC_DFT_ADD_F* vec_dft_add;
  VEC_DFT_SUB_F* vec_dft_sub;
  VEC_ZNX_IDFT_F* vec_znx_idft;
  VEC_ZNX_IDFT_TMP_BYTES_F* vec_znx_idft_tmp_bytes;
  VEC_ZNX_IDFT_TMP_A_F* vec_znx_idft_tmp_a;
  VEC_ZNX_SUB_F* vec_znx_sub;
  VEC_ZNX_ROTATE_F* vec_znx_rotate;
  VEC_ZNX_MUL_XP_MINUS_ONE_F* vec_znx_mul_xp_minus_one;
  VEC_ZNX_AUTOMORPHISM_F* vec_znx_automorphism;
  VEC_ZNX_NORMALIZE_BASE2K_F* vec_znx_normalize_base2k;
  VEC_ZNX_NORMALIZE_BASE2K_TMP_BYTES_F* vec_znx_normalize_base2k_tmp_bytes;
  VEC_ZNX_BIG_NORMALIZE_BASE2K_F* vec_znx_big_normalize_base2k;
  VEC_ZNX_BIG_NORMALIZE_BASE2K_TMP_BYTES_F* vec_znx_big_normalize_base2k_tmp_bytes;
  VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_F* vec_znx_big_range_normalize_base2k;
  VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_TMP_BYTES_F* vec_znx_big_range_normalize_base2k_tmp_bytes;
  VEC_ZNX_BIG_ADD_F* vec_znx_big_add;
  VEC_ZNX_BIG_ADD_SMALL_F* vec_znx_big_add_small;
  VEC_ZNX_BIG_ADD_SMALL2_F* vec_znx_big_add_small2;
  VEC_ZNX_BIG_SUB_F* vec_znx_big_sub;
  VEC_ZNX_BIG_SUB_SMALL_A_F* vec_znx_big_sub_small_a;
  VEC_ZNX_BIG_SUB_SMALL_B_F* vec_znx_big_sub_small_b;
  VEC_ZNX_BIG_SUB_SMALL2_F* vec_znx_big_sub_small2;
  VEC_ZNX_BIG_ROTATE_F* vec_znx_big_rotate;
  VEC_ZNX_BIG_AUTOMORPHISM_F* vec_znx_big_automorphism;
  SVP_PREPARE* svp_prepare;
  SVP_APPLY_DFT_F* svp_apply_dft;
  SVP_APPLY_DFT_TO_DFT_F* svp_apply_dft_to_dft;
  ZNX_SMALL_SINGLE_PRODUCT_F* znx_small_single_product;
  ZNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* znx_small_single_product_tmp_bytes;
  VMP_PREPARE_CONTIGUOUS_F* vmp_prepare_contiguous;
  VMP_PREPARE_TMP_BYTES_F* vmp_prepare_tmp_bytes;
  VMP_APPLY_DFT_F* vmp_apply_dft;
  VMP_APPLY_DFT_ADD_F* vmp_apply_dft_add;
  VMP_APPLY_DFT_TMP_BYTES_F* vmp_apply_dft_tmp_bytes;
  VMP_APPLY_DFT_TO_DFT_F* vmp_apply_dft_to_dft;
  VMP_APPLY_DFT_TO_DFT_ADD_F* vmp_apply_dft_to_dft_add;
  VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* vmp_apply_dft_to_dft_tmp_bytes;
  BYTES_OF_VEC_ZNX_DFT_F* bytes_of_vec_znx_dft;
  BYTES_OF_VEC_ZNX_BIG_F* bytes_of_vec_znx_big;
  BYTES_OF_SVP_PPOL_F* bytes_of_svp_ppol;
  BYTES_OF_VMP_PMAT_F* bytes_of_vmp_pmat;
};

union backend_module_info_t {
  struct fft64_module_info_t fft64;
  struct q120_module_info_t q120;
};

struct module_info_t {
  // generic parameters
  MODULE_TYPE module_type;
  uint64_t nn;
  uint64_t m;
  // backend_dependent functions
  union backend_module_info_t mod;
  // virtual functions
  struct module_virtual_functions_t func;
};

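For orientation: a backend is expected to populate `module_info_t.func` with its own implementations when a module is created. The constructor below is hypothetical and only illustrates how the vtable ties the generic API to the fft64 functions declared in this header; the real initialization lives elsewhere in the library.

// Hypothetical sketch (not the library's constructor): filling part of an fft64 vtable.
static void example_init_fft64_vtable(struct module_info_t* m) {
  m->func.vec_znx_add = vec_znx_add_ref;  // or vec_znx_add_avx when AVX2 is available
  m->func.vec_znx_sub = vec_znx_sub_ref;
  m->func.vec_znx_dft = fft64_vec_znx_dft;
  m->func.vec_znx_idft = fft64_vec_znx_idft;
  m->func.vec_znx_big_normalize_base2k = fft64_vec_znx_big_normalize_base2k;
  m->func.vmp_prepare_contiguous = fft64_vmp_prepare_contiguous_ref;
  m->func.vmp_apply_dft = fft64_vmp_apply_dft_ref;
  m->func.bytes_of_vec_znx_dft = fft64_bytes_of_vec_znx_dft;
}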
EXPORT uint64_t fft64_bytes_of_vec_znx_dft(const MODULE* module,  // N
    uint64_t size);

EXPORT uint64_t fft64_bytes_of_vec_znx_big(const MODULE* module,  // N
    uint64_t size);

EXPORT uint64_t fft64_bytes_of_svp_ppol(const MODULE* module);  // N

EXPORT uint64_t fft64_bytes_of_vmp_pmat(const MODULE* module,  // N
    uint64_t nrows, uint64_t ncols);

EXPORT void vec_znx_zero_ref(const MODULE* module,  // N
    int64_t* res, uint64_t res_size, uint64_t res_sl  // res
);

EXPORT void vec_znx_copy_ref(const MODULE* module,  // N
    int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl  // a
);

EXPORT void vec_znx_negate_ref(const MODULE* module,  // N
    int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl  // a
);

EXPORT void vec_znx_negate_avx(const MODULE* module,  // N
    int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl  // a
);

EXPORT void vec_znx_add_ref(const MODULE* module,  // N
    int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
    const int64_t* b, uint64_t b_size, uint64_t b_sl  // b
);

EXPORT void vec_znx_add_avx(const MODULE* module,  // N
    int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
    const int64_t* b, uint64_t b_size, uint64_t b_sl  // b
);

EXPORT void vec_znx_sub_ref(const MODULE* module,  // N
    int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
    const int64_t* b, uint64_t b_size, uint64_t b_sl  // b
);

EXPORT void vec_znx_sub_avx(const MODULE* module,  // N
    int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
    const int64_t* b, uint64_t b_size, uint64_t b_sl  // b
);

EXPORT void vec_znx_normalize_base2k_ref(const MODULE* module, uint64_t nn,  // N
    uint64_t log2_base2k,  // output base 2^K
    int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl,  // inp
    uint8_t* tmp_space  // scratch space
);

EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes_ref(const MODULE* module, uint64_t nn  // N
);

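Conceptually, base-2^k normalization propagates carries so that every limb of a coefficient ends up as a small, centered digit. The toy routine below is only an illustration of that idea under one common convention (most-significant limb first, digits in [-2^(k-1), 2^(k-1))); the library's actual limb ordering, torus scaling and kernels may differ.

#include <stdint.h>

// Toy illustration (not the library kernel): carry-propagate the limbs of one
// coefficient so that every limb lies in [-2^(k-1), 2^(k-1)).
// Assumes limb[0] is the most significant limb and values stay far from INT64 limits.
static void toy_normalize_base2k(uint64_t k, int64_t* limb, uint64_t size) {
  const int64_t base = (int64_t)1 << k;
  const int64_t half = base >> 1;
  int64_t carry = 0;
  for (uint64_t i = size; i-- > 0;) {  // from least to most significant limb
    int64_t v = limb[i] + carry;
    int64_t digit = ((v + half) & (base - 1)) - half;  // centered remainder mod 2^k
    carry = (v - digit) / base;                        // exact quotient
    limb[i] = digit;
  }
  // any carry left above the most-significant limb wraps around mod 1 on the torus and is dropped
}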
EXPORT void vec_znx_rotate_ref(const MODULE* module,  // N
    const int64_t p,  // rotation value
    int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl  // a
);

EXPORT void vec_znx_mul_xp_minus_one_ref(const MODULE* module,  // N
    const int64_t p,  // rotation value
    int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl  // a
);

EXPORT void vec_znx_automorphism_ref(const MODULE* module,  // N
    const int64_t p,  // X->X^p
    int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl  // a
);

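The rotation above multiplies each limb by X^p in Z[X]/(X^N + 1), so coefficients that wrap past degree N pick up a sign flip. A toy single-polynomial reference, for illustration only (not the library kernel; res and a must not alias):

#include <stdint.h>

// res(X) = a(X) * X^p in Z[X]/(X^N + 1), negacyclic wrap-around.
static void toy_znx_rotate(uint64_t nn, int64_t p, int64_t* res, const int64_t* a) {
  const int64_t two_n = 2 * (int64_t)nn;
  for (uint64_t i = 0; i < nn; ++i) {
    int64_t j = ((int64_t)i + p) % two_n;  // destination index mod 2N
    if (j < 0) j += two_n;
    if ((uint64_t)j < nn) {
      res[j] = a[i];          // stays below degree N: same sign
    } else {
      res[(uint64_t)j - nn] = -a[i];  // wrapped past X^N: sign flips
    }
  }
}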
EXPORT void vmp_prepare_ref(const MODULE* module,  // N
    VMP_PMAT* pmat,  // output
    const int64_t* mat, uint64_t nrows, uint64_t ncols  // a
);

EXPORT void vmp_apply_dft_ref(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
    const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols  // prep matrix
);

EXPORT void vec_dft_zero_ref(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size  // res
);

EXPORT void vec_dft_add_ref(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size,  // res
    const VEC_ZNX_DFT* a, uint64_t a_size,  // a
    const VEC_ZNX_DFT* b, uint64_t b_size  // b
);

EXPORT void vec_dft_sub_ref(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size,  // res
    const VEC_ZNX_DFT* a, uint64_t a_size,  // a
    const VEC_ZNX_DFT* b, uint64_t b_size  // b
);

EXPORT void vec_dft_ref(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl  // a
);

EXPORT void vec_idft_ref(const MODULE* module,  // N
    VEC_ZNX_BIG* res, uint64_t res_size,  // res
    const VEC_ZNX_DFT* a_dft, uint64_t a_size);

EXPORT void vec_znx_big_normalize_ref(const MODULE* module,  // MODULE
    uint64_t nn,  // N
    uint64_t k,  // base-2^k
    int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
    const VEC_ZNX_BIG* a, uint64_t a_size  // a
);

/** @brief apply a svp product, result = ppol * a, presented in DFT space */
EXPORT void fft64_svp_apply_dft_ref(const MODULE* module,  // N
    const VEC_ZNX_DFT* res, uint64_t res_size,  // output
    const SVP_PPOL* ppol,  // prepared pol
    const int64_t* a, uint64_t a_size, uint64_t a_sl  // a
);

/** @brief apply a svp product, result = ppol * a, presented in DFT space */
EXPORT void fft64_svp_apply_dft_to_dft_ref(const MODULE* module,  // N
    const VEC_ZNX_DFT* res, uint64_t res_size, uint64_t res_cols,  // output
    const SVP_PPOL* ppol,  // prepared pol
    const VEC_ZNX_DFT* a, uint64_t a_size, uint64_t a_cols  // a
);

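A hedged usage sketch of the SVP (scalar-vector product) path: prepare the fixed polynomial once, then multiply a vec<ZnX> by it in DFT space. The signatures follow the declarations above, and the sketch assumes `spqlios_alloc`/`spqlios_free` behave like malloc/free, as their use elsewhere in this commit suggests; error handling is omitted.

// Illustration only: res = DFT(fixed_pol * a), with a stored contiguously (stride nn).
void example_svp_product(const MODULE* module, uint64_t nn,
                         const int64_t* fixed_pol,           // nn coefficients
                         const int64_t* a, uint64_t a_size,  // a_size limbs
                         VEC_ZNX_DFT* res, uint64_t res_size) {
  SVP_PPOL* ppol = (SVP_PPOL*)spqlios_alloc(bytes_of_svp_ppol(module));
  svp_prepare(module, ppol, fixed_pol);                        // one-time preparation
  svp_apply_dft(module, res, res_size, ppol, a, a_size, nn);   // product in DFT space
  spqlios_free(ppol);
}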
/** @brief sets res = k-normalize(a) -- output in int64 coeffs space */
EXPORT void fft64_vec_znx_big_normalize_base2k(const MODULE* module,  // MODULE
    uint64_t nn,  // N
    uint64_t k,  // base-2^k
    int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
    const VEC_ZNX_BIG* a, uint64_t a_size,  // a
    uint8_t* tmp_space  // temp space
);

/** @brief returns the minimal byte length of scratch space for vec_znx_big_normalize_base2k */
EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn  // N
);

/** @brief sets res = k-normalize(a.subrange) -- output in int64 coeffs space */
EXPORT void fft64_vec_znx_big_range_normalize_base2k(const MODULE* module,  // MODULE
    uint64_t nn,
    uint64_t log2_base2k,  // base-2^k
    int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
    const VEC_ZNX_BIG* a, uint64_t a_range_begin,  // a
    uint64_t a_range_xend, uint64_t a_range_step,  // range
    uint8_t* tmp_space  // temp space
);

/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */
EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn  // N
);

EXPORT void fft64_vec_znx_dft(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl  // a
);

EXPORT void fft64_vec_dft_add(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size,  // res
    const VEC_ZNX_DFT* a, uint64_t a_size,  // a
    const VEC_ZNX_DFT* b, uint64_t b_size  // b
);

EXPORT void fft64_vec_dft_sub(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size,  // res
    const VEC_ZNX_DFT* a, uint64_t a_size,  // a
    const VEC_ZNX_DFT* b, uint64_t b_size  // b
);

EXPORT void fft64_vec_znx_idft(const MODULE* module,  // N
    VEC_ZNX_BIG* res, uint64_t res_size,  // res
    const VEC_ZNX_DFT* a_dft, uint64_t a_size,  // a
    uint8_t* tmp  // scratch space
);

EXPORT uint64_t fft64_vec_znx_idft_tmp_bytes(const MODULE* module, uint64_t nn);

EXPORT void fft64_vec_znx_idft_tmp_a(const MODULE* module,  // N
    VEC_ZNX_BIG* res, uint64_t res_size,  // res
    VEC_ZNX_DFT* a_dft, uint64_t a_size  // a is overwritten
);

EXPORT void ntt120_vec_znx_dft_avx(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl  // a
);

EXPORT void ntt120_vec_znx_idft_avx(const MODULE* module,  // N
    VEC_ZNX_BIG* res, uint64_t res_size,  // res
    const VEC_ZNX_DFT* a_dft, uint64_t a_size,  // a
    uint8_t* tmp  // scratch space
);

EXPORT uint64_t ntt120_vec_znx_idft_tmp_bytes_avx(const MODULE* module, uint64_t nn);

EXPORT void ntt120_vec_znx_idft_tmp_a_avx(const MODULE* module,  // N
    VEC_ZNX_BIG* res, uint64_t res_size,  // res
    VEC_ZNX_DFT* a_dft, uint64_t a_size  // a is overwritten
);

// big additions/subtractions

/** @brief sets res = a+b */
EXPORT void fft64_vec_znx_big_add(const MODULE* module,  // N
    VEC_ZNX_BIG* res, uint64_t res_size,  // res
    const VEC_ZNX_BIG* a, uint64_t a_size,  // a
    const VEC_ZNX_BIG* b, uint64_t b_size  // b
);

/** @brief sets res = a+b */
EXPORT void fft64_vec_znx_big_add_small(const MODULE* module,  // N
    VEC_ZNX_BIG* res, uint64_t res_size,  // res
    const VEC_ZNX_BIG* a, uint64_t a_size,  // a
    const int64_t* b, uint64_t b_size, uint64_t b_sl  // b
);

EXPORT void fft64_vec_znx_big_add_small2(const MODULE* module,  // N
    VEC_ZNX_BIG* res, uint64_t res_size,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
    const int64_t* b, uint64_t b_size, uint64_t b_sl  // b
);

/** @brief sets res = a-b */
EXPORT void fft64_vec_znx_big_sub(const MODULE* module,  // N
    VEC_ZNX_BIG* res, uint64_t res_size,  // res
    const VEC_ZNX_BIG* a, uint64_t a_size,  // a
    const VEC_ZNX_BIG* b, uint64_t b_size  // b
);

EXPORT void fft64_vec_znx_big_sub_small_b(const MODULE* module,  // N
    VEC_ZNX_BIG* res, uint64_t res_size,  // res
    const VEC_ZNX_BIG* a, uint64_t a_size,  // a
    const int64_t* b, uint64_t b_size, uint64_t b_sl  // b
);

EXPORT void fft64_vec_znx_big_sub_small_a(const MODULE* module,  // N
    VEC_ZNX_BIG* res, uint64_t res_size,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
    const VEC_ZNX_BIG* b, uint64_t b_size  // b
);

EXPORT void fft64_vec_znx_big_sub_small2(const MODULE* module,  // N
    VEC_ZNX_BIG* res, uint64_t res_size,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
    const int64_t* b, uint64_t b_size, uint64_t b_sl  // b
);

/** @brief sets res = a . X^p */
EXPORT void fft64_vec_znx_big_rotate(const MODULE* module,  // N
    int64_t p,  // rotation value
    VEC_ZNX_BIG* res, uint64_t res_size,  // res
    const VEC_ZNX_BIG* a, uint64_t a_size  // a
);

/** @brief sets res = a(X^p) */
EXPORT void fft64_vec_znx_big_automorphism(const MODULE* module,  // N
    int64_t p,  // X->X^p
    VEC_ZNX_BIG* res, uint64_t res_size,  // res
    const VEC_ZNX_BIG* a, uint64_t a_size  // a
);

/** @brief prepares a svp polynomial */
EXPORT void fft64_svp_prepare_ref(const MODULE* module,  // N
    SVP_PPOL* ppol,  // output
    const int64_t* pol  // a
);

/** @brief res = a * b : small integer polynomial product */
EXPORT void fft64_znx_small_single_product(const MODULE* module,  // N
    int64_t* res,  // output
    const int64_t* a,  // a
    const int64_t* b,  // b
    uint8_t* tmp);

/** @brief tmp bytes required for znx_small_single_product */
EXPORT uint64_t fft64_znx_small_single_product_tmp_bytes(const MODULE* module, uint64_t nn);

/** @brief prepares a vmp matrix (contiguous row-major version) */
EXPORT void fft64_vmp_prepare_contiguous_ref(const MODULE* module,  // N
    VMP_PMAT* pmat,  // output
    const int64_t* mat, uint64_t nrows, uint64_t ncols,  // a
    uint8_t* tmp_space  // scratch space
);

/** @brief prepares a vmp matrix (contiguous row-major version) */
EXPORT void fft64_vmp_prepare_contiguous_avx(const MODULE* module,  // N
    VMP_PMAT* pmat,  // output
    const int64_t* mat, uint64_t nrows, uint64_t ncols,  // a
    uint8_t* tmp_space  // scratch space
);

/** @brief minimal scratch space byte-size required for the vmp_prepare function */
EXPORT uint64_t fft64_vmp_prepare_tmp_bytes(const MODULE* module, uint64_t nn,  // N
    uint64_t nrows, uint64_t ncols);

/** @brief applies a vmp product (result in DFT space) and adds to res inplace */
EXPORT void fft64_vmp_apply_dft_add_ref(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
    const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,
    uint64_t pmat_scale,  // prep matrix
    uint8_t* tmp_space  // scratch space
);

/** @brief applies a vmp product (result in DFT space) */
EXPORT void fft64_vmp_apply_dft_ref(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
    const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,  // prep matrix
    uint8_t* tmp_space  // scratch space
);

/** @brief applies a vmp product (result in DFT space) */
EXPORT void fft64_vmp_apply_dft_avx(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
    const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,  // prep matrix
    uint8_t* tmp_space  // scratch space
);

/** @brief applies a vmp product (result in DFT space) and adds to res inplace */
EXPORT void fft64_vmp_apply_dft_add_avx(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
    const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,
    uint64_t pmat_scale,  // prep matrix
    uint8_t* tmp_space  // scratch space
);

/** @brief applies a vmp product to an operand that is already in DFT space */
EXPORT void fft64_vmp_apply_dft_to_dft_ref(const MODULE* module,  // N
    VEC_ZNX_DFT* res, const uint64_t res_size,  // res
    const VEC_ZNX_DFT* a_dft, uint64_t a_size,  // a
    const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols,  // prep matrix
    uint8_t* tmp_space  // scratch space (a_size*sizeof(reim4) bytes)
);

/** @brief applies a vmp product to an operand already in DFT space and adds to res inplace */
EXPORT void fft64_vmp_apply_dft_to_dft_add_ref(const MODULE* module,  // N
    VEC_ZNX_DFT* res, const uint64_t res_size,  // res
    const VEC_ZNX_DFT* a_dft, uint64_t a_size,  // a
    const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols,
    uint64_t pmat_scale,  // prep matrix
    uint8_t* tmp_space  // scratch space (a_size*sizeof(reim4) bytes)
);

/** @brief applies a vmp product to an operand that is already in DFT space */
EXPORT void fft64_vmp_apply_dft_to_dft_avx(const MODULE* module,  // N
    VEC_ZNX_DFT* res, const uint64_t res_size,  // res
    const VEC_ZNX_DFT* a_dft, uint64_t a_size,  // a
    const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols,  // prep matrix
    uint8_t* tmp_space  // scratch space (a_size*sizeof(reim4) bytes)
);

/** @brief applies a vmp product to an operand already in DFT space and adds to res inplace */
EXPORT void fft64_vmp_apply_dft_to_dft_add_avx(const MODULE* module,  // N
    VEC_ZNX_DFT* res, const uint64_t res_size,  // res
    const VEC_ZNX_DFT* a_dft, uint64_t a_size,  // a
    const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols,
    uint64_t pmat_scale,  // prep matrix
    uint8_t* tmp_space  // scratch space (a_size*sizeof(reim4) bytes)
);

/** @brief minimal size of the tmp_space */
EXPORT uint64_t fft64_vmp_apply_dft_tmp_bytes(const MODULE* module, uint64_t nn,  // N
    uint64_t res_size,  // res
    uint64_t a_size,  // a
    uint64_t nrows, uint64_t ncols  // prep matrix
);

/** @brief minimal size of the tmp_space */
EXPORT uint64_t fft64_vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, uint64_t nn,  // N
    uint64_t res_size,  // res
    uint64_t a_size,  // a
    uint64_t nrows, uint64_t ncols  // prep matrix
);

#endif  // SPQLIOS_VEC_ZNX_ARITHMETIC_PRIVATE_H
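A hedged end-to-end sketch of the vector-matrix product path declared above: prepare the matrix once, then apply it to a small input vector, with scratch sizes taken from the *_tmp_bytes helpers. It assumes `spqlios_alloc`/`spqlios_free` behave like malloc/free and omits error handling; it is an illustration, not the library's own test code.

// Illustration only: res = a * mat, result kept in DFT space.
void example_vmp_product(const MODULE* module, uint64_t nn,
                         const int64_t* mat, uint64_t nrows, uint64_t ncols,  // row-major, nrows*ncols polynomials
                         const int64_t* a, uint64_t a_size,                   // input limbs, stride nn
                         VEC_ZNX_DFT* res, uint64_t res_size) {
  VMP_PMAT* pmat = (VMP_PMAT*)spqlios_alloc(fft64_bytes_of_vmp_pmat(module, nrows, ncols));
  uint8_t* tmp1 = (uint8_t*)spqlios_alloc(fft64_vmp_prepare_tmp_bytes(module, nn, nrows, ncols));
  fft64_vmp_prepare_contiguous_ref(module, pmat, mat, nrows, ncols, tmp1);

  uint8_t* tmp2 = (uint8_t*)spqlios_alloc(fft64_vmp_apply_dft_tmp_bytes(module, nn, res_size, a_size, nrows, ncols));
  fft64_vmp_apply_dft_ref(module, res, res_size, a, a_size, nn, pmat, nrows, ncols, tmp2);

  spqlios_free(tmp2);
  spqlios_free(tmp1);
  spqlios_free(pmat);
}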
@@ -0,0 +1,103 @@
#include <string.h>

#include "../coeffs/coeffs_arithmetic.h"
#include "../reim4/reim4_arithmetic.h"
#include "vec_znx_arithmetic_private.h"

// specialized function (ref)

// Note: these functions do not have an avx variant.
#define znx_copy_i64_avx znx_copy_i64_ref
#define znx_zero_i64_avx znx_zero_i64_ref

EXPORT void vec_znx_add_avx(const MODULE* module,  // N
    int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
    const int64_t* b, uint64_t b_size, uint64_t b_sl  // b
) {
  const uint64_t nn = module->nn;
  if (a_size <= b_size) {
    const uint64_t sum_idx = res_size < a_size ? res_size : a_size;
    const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
    // add up to the smallest dimension
    for (uint64_t i = 0; i < sum_idx; ++i) {
      znx_add_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
    }
    // then copy to the largest dimension
    for (uint64_t i = sum_idx; i < copy_idx; ++i) {
      znx_copy_i64_avx(nn, res + i * res_sl, b + i * b_sl);
    }
    // then extend with zeros
    for (uint64_t i = copy_idx; i < res_size; ++i) {
      znx_zero_i64_avx(nn, res + i * res_sl);
    }
  } else {
    const uint64_t sum_idx = res_size < b_size ? res_size : b_size;
    const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
    // add up to the smallest dimension
    for (uint64_t i = 0; i < sum_idx; ++i) {
      znx_add_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
    }
    // then copy to the largest dimension
    for (uint64_t i = sum_idx; i < copy_idx; ++i) {
      znx_copy_i64_avx(nn, res + i * res_sl, a + i * a_sl);
    }
    // then extend with zeros
    for (uint64_t i = copy_idx; i < res_size; ++i) {
      znx_zero_i64_avx(nn, res + i * res_sl);
    }
  }
}

EXPORT void vec_znx_sub_avx(const MODULE* module,  // N
    int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
    const int64_t* b, uint64_t b_size, uint64_t b_sl  // b
) {
  const uint64_t nn = module->nn;
  if (a_size <= b_size) {
    const uint64_t sub_idx = res_size < a_size ? res_size : a_size;
    const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
    // subtract up to the smallest dimension
    for (uint64_t i = 0; i < sub_idx; ++i) {
      znx_sub_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
    }
    // then negate to the largest dimension
    for (uint64_t i = sub_idx; i < copy_idx; ++i) {
      znx_negate_i64_avx(nn, res + i * res_sl, b + i * b_sl);
    }
    // then extend with zeros
    for (uint64_t i = copy_idx; i < res_size; ++i) {
      znx_zero_i64_avx(nn, res + i * res_sl);
    }
  } else {
    const uint64_t sub_idx = res_size < b_size ? res_size : b_size;
    const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
    // subtract up to the smallest dimension
    for (uint64_t i = 0; i < sub_idx; ++i) {
      znx_sub_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
    }
    // then copy to the largest dimension
    for (uint64_t i = sub_idx; i < copy_idx; ++i) {
      znx_copy_i64_avx(nn, res + i * res_sl, a + i * a_sl);
    }
    // then extend with zeros
    for (uint64_t i = copy_idx; i < res_size; ++i) {
      znx_zero_i64_avx(nn, res + i * res_sl);
    }
  }
}

EXPORT void vec_znx_negate_avx(const MODULE* module,  // N
    int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl  // a
) {
  uint64_t nn = module->nn;
  uint64_t smin = res_size < a_size ? res_size : a_size;
  for (uint64_t i = 0; i < smin; ++i) {
    znx_negate_i64_avx(nn, res + i * res_sl, a + i * a_sl);
  }
  for (uint64_t i = smin; i < res_size; ++i) {
    znx_zero_i64_ref(nn, res + i * res_sl);
  }
}
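A hedged usage sketch of the size-extension semantics implemented above: limbs present in both inputs are added, limbs present in only one input are copied through, and any remaining output limbs are zeroed. The buffer sizes below are arbitrary illustration values and must match the module's ring dimension.

// Illustration only: adding vectors with different limb counts.
void example_vec_znx_add(const MODULE* module) {
  const uint64_t nn = 64;  // must match the module's ring dimension
  int64_t a[2 * 64] = {0}, b[3 * 64] = {0}, res[4 * 64];
  a[0] = 1;  // a = 1 in limb 0
  b[64] = 1; // b = 1 in limb 1
  vec_znx_add(module,
              res, 4, nn,  // 4 output limbs: limb 3 is zero-extended
              a, 2, nn,    // 2 input limbs
              b, 3, nn);   // 3 input limbs: limb 2 of res is a copy of b's limb 2
}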
@@ -0,0 +1,278 @@
#include "vec_znx_arithmetic_private.h"
|
||||||
|
|
||||||
|
EXPORT uint64_t bytes_of_vec_znx_big(const MODULE* module, // N
|
||||||
|
uint64_t size) {
|
||||||
|
return module->func.bytes_of_vec_znx_big(module, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// public wrappers
|
||||||
|
|
||||||
|
/** @brief sets res = a+b */
|
||||||
|
EXPORT void vec_znx_big_add(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_add(module, res, res_size, a, a_size, b, b_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a+b */
|
||||||
|
EXPORT void vec_znx_big_add_small(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_add_small(module, res, res_size, a, a_size, b, b_size, b_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_big_add_small2(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_add_small2(module, res, res_size, a, a_size, a_sl, b, b_size, b_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a-b */
|
||||||
|
EXPORT void vec_znx_big_sub(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_sub(module, res, res_size, a, a_size, b, b_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_big_sub_small_b(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_sub_small_b(module, res, res_size, a, a_size, b, b_size, b_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_big_sub_small_a(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_sub_small_a(module, res, res_size, a, a_size, a_sl, b, b_size);
|
||||||
|
}
|
||||||
|
EXPORT void vec_znx_big_sub_small2(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_sub_small2(module, res, res_size, a, a_size, a_sl, b, b_size, b_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a . X^p */
|
||||||
|
EXPORT void vec_znx_big_rotate(const MODULE* module, // N
|
||||||
|
int64_t p, // rotation value
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_rotate(module, p, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a(X^p) */
|
||||||
|
EXPORT void vec_znx_big_automorphism(const MODULE* module, // N
|
||||||
|
int64_t p, // X-X^p
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_automorphism(module, p, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// private wrappers
|
||||||
|
|
||||||
|
EXPORT uint64_t fft64_bytes_of_vec_znx_big(const MODULE* module, // N
|
||||||
|
uint64_t size) {
|
||||||
|
return module->nn * size * sizeof(double);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT VEC_ZNX_BIG* new_vec_znx_big(const MODULE* module, // N
|
||||||
|
uint64_t size) {
|
||||||
|
return spqlios_alloc(bytes_of_vec_znx_big(module, size));
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void delete_vec_znx_big(VEC_ZNX_BIG* res) { spqlios_free(res); }
|
||||||
|
|
||||||
|
/** @brief sets res = a+b */
|
||||||
|
EXPORT void fft64_vec_znx_big_add(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||||
|
) {
|
||||||
|
const uint64_t n = module->nn;
|
||||||
|
vec_znx_add(module, //
|
||||||
|
(int64_t*)res, res_size, n, //
|
||||||
|
(int64_t*)a, a_size, n, //
|
||||||
|
(int64_t*)b, b_size, n);
|
||||||
|
}
|
||||||
|
/** @brief sets res = a+b */
|
||||||
|
EXPORT void fft64_vec_znx_big_add_small(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
const uint64_t n = module->nn;
|
||||||
|
vec_znx_add(module, //
|
||||||
|
(int64_t*)res, res_size, n, //
|
||||||
|
(int64_t*)a, a_size, n, //
|
||||||
|
b, b_size, b_sl);
|
||||||
|
}
|
||||||
|
EXPORT void fft64_vec_znx_big_add_small2(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
const uint64_t n = module->nn;
|
||||||
|
vec_znx_add(module, //
|
||||||
|
(int64_t*)res, res_size, n, //
|
||||||
|
a, a_size, a_sl, //
|
||||||
|
b, b_size, b_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a-b */
|
||||||
|
EXPORT void fft64_vec_znx_big_sub(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||||
|
) {
|
||||||
|
const uint64_t n = module->nn;
|
||||||
|
vec_znx_sub(module, //
|
||||||
|
(int64_t*)res, res_size, n, //
|
||||||
|
(int64_t*)a, a_size, n, //
|
||||||
|
(int64_t*)b, b_size, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void fft64_vec_znx_big_sub_small_b(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
const uint64_t n = module->nn;
|
||||||
|
vec_znx_sub(module, //
|
||||||
|
(int64_t*)res, res_size, n, //
|
||||||
|
(int64_t*)a, a_size, //
|
||||||
|
n, b, b_size, b_sl);
|
||||||
|
}
|
||||||
|
EXPORT void fft64_vec_znx_big_sub_small_a(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const VEC_ZNX_BIG* b, uint64_t b_size // b
|
||||||
|
) {
|
||||||
|
const uint64_t n = module->nn;
|
||||||
|
vec_znx_sub(module, //
|
||||||
|
(int64_t*)res, res_size, n, //
|
||||||
|
a, a_size, a_sl, //
|
||||||
|
(int64_t*)b, b_size, n);
|
||||||
|
}
|
||||||
|
EXPORT void fft64_vec_znx_big_sub_small2(const MODULE* module, // N
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const int64_t* b, uint64_t b_size, uint64_t b_sl // b
|
||||||
|
) {
|
||||||
|
const uint64_t n = module->nn;
|
||||||
|
vec_znx_sub(module, //
|
||||||
|
(int64_t*)res, res_size, //
|
||||||
|
n, a, a_size, //
|
||||||
|
a_sl, b, b_size, b_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a . X^p */
|
||||||
|
EXPORT void fft64_vec_znx_big_rotate(const MODULE* module, // N
|
||||||
|
int64_t p, // rotation value
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
uint64_t nn = module->nn;
|
||||||
|
vec_znx_rotate(module, p, (int64_t*)res, res_size, nn, (int64_t*)a, a_size, nn);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = a(X^p) */
|
||||||
|
EXPORT void fft64_vec_znx_big_automorphism(const MODULE* module, // N
|
||||||
|
int64_t p, // X-X^p
|
||||||
|
VEC_ZNX_BIG* res, uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
uint64_t nn = module->nn;
|
||||||
|
vec_znx_automorphism(module, p, (int64_t*)res, res_size, nn, (int64_t*)a, a_size, nn);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vec_znx_big_normalize_base2k(const MODULE* module, // MODULE
|
||||||
|
uint64_t nn, // N
|
||||||
|
uint64_t k, // base-2^k
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
uint8_t* tmp_space // temp space
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_normalize_base2k(module, // MODULE
|
||||||
|
nn, // N
|
||||||
|
k, // base-2^k
|
||||||
|
res, res_size, res_sl, // res
|
||||||
|
a, a_size, // a
|
||||||
|
tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint64_t vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module, uint64_t nn // N
|
||||||
|
) {
|
||||||
|
return module->func.vec_znx_big_normalize_base2k_tmp_bytes(module, nn // N
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = k-normalize(a.subrange) -- output in int64 coeffs space */
|
||||||
|
EXPORT void vec_znx_big_range_normalize_base2k( //
|
||||||
|
const MODULE* module, // MODULE
|
||||||
|
uint64_t nn, // N
|
||||||
|
uint64_t log2_base2k, // base-2^k
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_range_begin, uint64_t a_range_xend, uint64_t a_range_step, // range
|
||||||
|
uint8_t* tmp_space // temp space
|
||||||
|
) {
|
||||||
|
module->func.vec_znx_big_range_normalize_base2k(module, nn, log2_base2k, res, res_size, res_sl, a, a_range_begin,
|
||||||
|
a_range_xend, a_range_step, tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */
|
||||||
|
EXPORT uint64_t vec_znx_big_range_normalize_base2k_tmp_bytes( //
|
||||||
|
const MODULE* module, // MODULE
|
||||||
|
uint64_t nn // N
|
||||||
|
) {
|
||||||
|
return module->func.vec_znx_big_range_normalize_base2k_tmp_bytes(module, nn);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void fft64_vec_znx_big_normalize_base2k(const MODULE* module, // MODULE
|
||||||
|
uint64_t nn, // N
|
||||||
|
uint64_t k, // base-2^k
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_size, // a
|
||||||
|
uint8_t* tmp_space) {
|
||||||
|
uint64_t a_sl = nn;
|
||||||
|
module->func.vec_znx_normalize_base2k(module, // N
|
||||||
|
nn,
|
||||||
|
k, // log2_base2k
|
||||||
|
res, res_size, res_sl, // res
|
||||||
|
(int64_t*)a, a_size, a_sl, // a
|
||||||
|
tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void fft64_vec_znx_big_range_normalize_base2k( //
|
||||||
|
const MODULE* module, // MODULE
|
||||||
|
uint64_t nn, // N
|
||||||
|
uint64_t k, // base-2^k
|
||||||
|
int64_t* res, uint64_t res_size, uint64_t res_sl, // res
|
||||||
|
const VEC_ZNX_BIG* a, uint64_t a_begin, uint64_t a_end, uint64_t a_step, // a
|
||||||
|
uint8_t* tmp_space) {
|
||||||
|
// convert the range indexes to int64[] slices
|
||||||
|
const int64_t* a_st = ((int64_t*)a) + nn * a_begin;
|
||||||
|
const uint64_t a_size = (a_end + a_step - 1 - a_begin) / a_step;
|
||||||
|
const uint64_t a_sl = nn * a_step;
|
||||||
|
// forward the call
|
||||||
|
module->func.vec_znx_normalize_base2k(module, // MODULE
|
||||||
|
nn, // N
|
||||||
|
k, // log2_base2k
|
||||||
|
res, res_size, res_sl, // res
|
||||||
|
a_st, a_size, a_sl, // a
|
||||||
|
tmp_space);
|
||||||
|
}
|
||||||
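To make the range-to-slice conversion above concrete, here is a small worked example computed from those formulas (illustration only):

// With a_begin = 1, a_end = 7, a_step = 2, the range selects limbs 1, 3 and 5 of `a`,
// so the slice forwarded to vec_znx_normalize_base2k is:
//   a_st   = (int64_t*)a + nn * 1      (start of limb 1)
//   a_size = (7 + 2 - 1 - 1) / 2 = 3   (three limbs)
//   a_sl   = 2 * nn                    (skip every other limb)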
@@ -0,0 +1,214 @@
#include <string.h>

#include "../q120/q120_arithmetic.h"
#include "vec_znx_arithmetic_private.h"

EXPORT void vec_znx_dft(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl  // a
) {
  return module->func.vec_znx_dft(module, res, res_size, a, a_size, a_sl);
}

EXPORT void vec_dft_add(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size,  // res
    const VEC_ZNX_DFT* a, uint64_t a_size,  // a
    const VEC_ZNX_DFT* b, uint64_t b_size  // b
) {
  return module->func.vec_dft_add(module, res, res_size, a, a_size, b, b_size);
}

EXPORT void vec_dft_sub(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size,  // res
    const VEC_ZNX_DFT* a, uint64_t a_size,  // a
    const VEC_ZNX_DFT* b, uint64_t b_size  // b
) {
  return module->func.vec_dft_sub(module, res, res_size, a, a_size, b, b_size);
}

EXPORT void vec_znx_idft(const MODULE* module,  // N
    VEC_ZNX_BIG* res, uint64_t res_size,  // res
    const VEC_ZNX_DFT* a_dft, uint64_t a_size,  // a
    uint8_t* tmp  // scratch space
) {
  return module->func.vec_znx_idft(module, res, res_size, a_dft, a_size, tmp);
}

EXPORT uint64_t vec_znx_idft_tmp_bytes(const MODULE* module, uint64_t nn) { return module->func.vec_znx_idft_tmp_bytes(module, nn); }

EXPORT void vec_znx_idft_tmp_a(const MODULE* module,  // N
    VEC_ZNX_BIG* res, uint64_t res_size,  // res
    VEC_ZNX_DFT* a_dft, uint64_t a_size  // a is overwritten
) {
  return module->func.vec_znx_idft_tmp_a(module, res, res_size, a_dft, a_size);
}

EXPORT uint64_t bytes_of_vec_znx_dft(const MODULE* module,  // N
    uint64_t size) {
  return module->func.bytes_of_vec_znx_dft(module, size);
}

// fft64 backend
EXPORT uint64_t fft64_bytes_of_vec_znx_dft(const MODULE* module,  // N
    uint64_t size) {
  return module->nn * size * sizeof(double);
}

EXPORT VEC_ZNX_DFT* new_vec_znx_dft(const MODULE* module,  // N
    uint64_t size) {
  return spqlios_alloc(bytes_of_vec_znx_dft(module, size));
}

EXPORT void delete_vec_znx_dft(VEC_ZNX_DFT* res) { spqlios_free(res); }

EXPORT void fft64_vec_znx_dft(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl  // a
) {
  const uint64_t smin = res_size < a_size ? res_size : a_size;
  const uint64_t nn = module->nn;

  for (uint64_t i = 0; i < smin; i++) {
    reim_from_znx64(module->mod.fft64.p_conv, ((double*)res) + i * nn, a + i * a_sl);
    reim_fft(module->mod.fft64.p_fft, ((double*)res) + i * nn);
  }

  // fill up remaining part with 0's
  double* const dres = (double*)res;
  memset(dres + smin * nn, 0, (res_size - smin) * nn * sizeof(double));
}

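A hedged usage sketch of the DFT round trip through the wrappers defined in this file: forward transform into DFT space, then the `tmp_a` inverse, which overwrites its DFT input and needs no extra scratch buffer. Allocation helpers come from this file and from the vec_znx_big wrappers above; error handling is omitted.

// Illustration only: coeffs -> DFT -> big coeffs.
void example_dft_roundtrip(const MODULE* module, uint64_t nn,
                           const int64_t* a, uint64_t a_size) {
  VEC_ZNX_DFT* a_dft = new_vec_znx_dft(module, a_size);
  VEC_ZNX_BIG* a_big = new_vec_znx_big(module, a_size);
  vec_znx_dft(module, a_dft, a_size, a, a_size, nn);         // forward DFT
  vec_znx_idft_tmp_a(module, a_big, a_size, a_dft, a_size);  // inverse DFT, a_dft is clobbered
  delete_vec_znx_big(a_big);
  delete_vec_znx_dft(a_dft);
}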
EXPORT void fft64_vec_dft_add(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size,  // res
    const VEC_ZNX_DFT* a, uint64_t a_size,  // a
    const VEC_ZNX_DFT* b, uint64_t b_size  // b
) {
  const uint64_t smin0 = a_size < b_size ? a_size : b_size;
  const uint64_t smin = res_size < smin0 ? res_size : smin0;
  const uint64_t nn = module->nn;

  for (uint64_t i = 0; i < smin; i++) {
    reim_fftvec_add(module->mod.fft64.add_fft, ((double*)res) + i * nn, ((double*)a) + i * nn, ((double*)b) + i * nn);
  }

  // fill the remaining part of `res` with 0's
  double* const dres = (double*)res;
  memset(dres + smin * nn, 0, (res_size - smin) * nn * sizeof(double));
}

EXPORT void fft64_vec_dft_sub(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size,  // res
    const VEC_ZNX_DFT* a, uint64_t a_size,  // a
    const VEC_ZNX_DFT* b, uint64_t b_size  // b
) {
  const uint64_t smin0 = a_size < b_size ? a_size : b_size;
  const uint64_t smin = res_size < smin0 ? res_size : smin0;
  const uint64_t nn = module->nn;

  for (uint64_t i = 0; i < smin; i++) {
    reim_fftvec_sub(module->mod.fft64.sub_fft, ((double*)res) + i * nn, ((double*)a) + i * nn, ((double*)b) + i * nn);
  }

  // fill the remaining part of `res` with 0's
  double* const dres = (double*)res;
  memset(dres + smin * nn, 0, (res_size - smin) * nn * sizeof(double));
}

EXPORT void fft64_vec_znx_idft(const MODULE* module,  // N
    VEC_ZNX_BIG* res, uint64_t res_size,  // res
    const VEC_ZNX_DFT* a_dft, uint64_t a_size,  // a
    uint8_t* tmp  // unused
) {
  const uint64_t nn = module->nn;
  const uint64_t smin = res_size < a_size ? res_size : a_size;
  if ((double*)res != (double*)a_dft) {
    memcpy(res, a_dft, smin * nn * sizeof(double));
  }

  for (uint64_t i = 0; i < smin; i++) {
    reim_ifft(module->mod.fft64.p_ifft, ((double*)res) + i * nn);
    reim_to_znx64(module->mod.fft64.p_reim_to_znx, ((int64_t*)res) + i * nn, ((int64_t*)res) + i * nn);
  }

  // fill up remaining part with 0's
  int64_t* const dres = (int64_t*)res;
  memset(dres + smin * nn, 0, (res_size - smin) * nn * sizeof(double));
}

EXPORT uint64_t fft64_vec_znx_idft_tmp_bytes(const MODULE* module, uint64_t nn) { return 0; }

EXPORT void fft64_vec_znx_idft_tmp_a(const MODULE* module,  // N
    VEC_ZNX_BIG* res, uint64_t res_size,  // res
    VEC_ZNX_DFT* a_dft, uint64_t a_size  // a is overwritten
) {
  const uint64_t nn = module->nn;
  const uint64_t smin = res_size < a_size ? res_size : a_size;

  int64_t* const tres = (int64_t*)res;
  double* const ta = (double*)a_dft;
  for (uint64_t i = 0; i < smin; i++) {
    reim_ifft(module->mod.fft64.p_ifft, ta + i * nn);
    reim_to_znx64(module->mod.fft64.p_reim_to_znx, tres + i * nn, ta + i * nn);
  }

  // fill up remaining part with 0's
  memset(tres + smin * nn, 0, (res_size - smin) * nn * sizeof(double));
}

// ntt120 backend

EXPORT void ntt120_vec_znx_dft_avx(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl  // a
) {
  const uint64_t nn = module->nn;
  const uint64_t smin = res_size < a_size ? res_size : a_size;

  int64_t* tres = (int64_t*)res;
  for (uint64_t i = 0; i < smin; i++) {
    q120_b_from_znx64_simple(nn, (q120b*)(tres + i * nn * 4), a + i * a_sl);
    q120_ntt_bb_avx2(module->mod.q120.p_ntt, (q120b*)(tres + i * nn * 4));
  }

  // fill up remaining part with 0's
  memset(tres + smin * nn * 4, 0, (res_size - smin) * nn * 4 * sizeof(int64_t));
}

EXPORT void ntt120_vec_znx_idft_avx(const MODULE* module,  // N
    VEC_ZNX_BIG* res, uint64_t res_size,  // res
    const VEC_ZNX_DFT* a_dft, uint64_t a_size,  // a
    uint8_t* tmp) {
  const uint64_t nn = module->nn;
  const uint64_t smin = res_size < a_size ? res_size : a_size;

  __int128_t* const tres = (__int128_t*)res;
  const int64_t* const ta = (int64_t*)a_dft;
  for (uint64_t i = 0; i < smin; i++) {
    memcpy(tmp, ta + i * nn * 4, nn * 4 * sizeof(uint64_t));
    q120_intt_bb_avx2(module->mod.q120.p_intt, (q120b*)tmp);
    q120_b_to_znx128_simple(nn, tres + i * nn, (q120b*)tmp);
  }

  // fill up remaining part with 0's
  memset(tres + smin * nn, 0, (res_size - smin) * nn * sizeof(*tres));
}

EXPORT uint64_t ntt120_vec_znx_idft_tmp_bytes_avx(const MODULE* module, uint64_t nn) { return nn * 4 * sizeof(uint64_t); }

EXPORT void ntt120_vec_znx_idft_tmp_a_avx(const MODULE* module,  // N
    VEC_ZNX_BIG* res, uint64_t res_size,  // res
    VEC_ZNX_DFT* a_dft, uint64_t a_size  // a is overwritten
) {
  const uint64_t nn = module->nn;
  const uint64_t smin = res_size < a_size ? res_size : a_size;

  __int128_t* const tres = (__int128_t*)res;
  int64_t* const ta = (int64_t*)a_dft;
  for (uint64_t i = 0; i < smin; i++) {
    q120_intt_bb_avx2(module->mod.q120.p_intt, (q120b*)(ta + i * nn * 4));
    q120_b_to_znx128_simple(nn, tres + i * nn, (q120b*)(ta + i * nn * 4));
  }

  // fill up remaining part with 0's
  memset(tres + smin * nn, 0, (res_size - smin) * nn * sizeof(*tres));
}
@@ -0,0 +1 @@
#include "vec_znx_arithmetic_private.h"
@@ -0,0 +1,369 @@
#include <assert.h>
#include <string.h>

#include "../reim4/reim4_arithmetic.h"
#include "vec_znx_arithmetic_private.h"

EXPORT uint64_t bytes_of_vmp_pmat(const MODULE* module,  // N
    uint64_t nrows, uint64_t ncols  // dimensions
) {
  return module->func.bytes_of_vmp_pmat(module, nrows, ncols);
}

// fft64
EXPORT uint64_t fft64_bytes_of_vmp_pmat(const MODULE* module,  // N
    uint64_t nrows, uint64_t ncols  // dimensions
) {
  return module->nn * nrows * ncols * sizeof(double);
}

EXPORT VMP_PMAT* new_vmp_pmat(const MODULE* module,  // N
    uint64_t nrows, uint64_t ncols  // dimensions
) {
  return spqlios_alloc(bytes_of_vmp_pmat(module, nrows, ncols));
}

EXPORT void delete_vmp_pmat(VMP_PMAT* res) { spqlios_free(res); }

/** @brief prepares a vmp matrix (contiguous row-major version) */
EXPORT void vmp_prepare_contiguous(const MODULE* module,  // N
    VMP_PMAT* pmat,  // output
    const int64_t* mat, uint64_t nrows, uint64_t ncols,  // a
    uint8_t* tmp_space  // scratch space
) {
  module->func.vmp_prepare_contiguous(module, pmat, mat, nrows, ncols, tmp_space);
}

/** @brief minimal scratch space byte-size required for the vmp_prepare function */
EXPORT uint64_t vmp_prepare_tmp_bytes(const MODULE* module, uint64_t nn,  // N
    uint64_t nrows, uint64_t ncols) {
  return module->func.vmp_prepare_tmp_bytes(module, nn, nrows, ncols);
}

EXPORT double* get_blk_addr(uint64_t row_i, uint64_t col_i, uint64_t nrows, uint64_t ncols, const VMP_PMAT* pmat) {
  double* output_mat = (double*)pmat;

  if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
    // special case: last column out of an odd column number
    return output_mat + col_i * nrows * 8  // col == ncols-1
           + row_i * 8;
  } else {
    // general case: columns go by pair
    return output_mat + (col_i / 2) * (2 * nrows) * 8  // second: col pair index
           + row_i * 2 * 8                             // third: row index
           + (col_i % 2) * 8;
  }
}

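To make the block addressing above concrete, here is a small worked example computed from the same formulas (illustration only):

// With nrows = 4, ncols = 3, each reim4 block stores column pair {0,1}
// interleaved row by row, and the odd last column 2 on its own:
//   get_blk_addr(row 1, col 0) = (0/2)*(2*4)*8 + 1*2*8 + 0*8 = 16 doubles
//   get_blk_addr(row 1, col 1) = 0             + 16     + 8  = 24 doubles
//   get_blk_addr(row 1, col 2) = 2*4*8 + 1*8              = 72 doubles  (odd-column case)
// Consecutive blocks for the same (row, col) are nrows*ncols*8 doubles apart,
// matching the `offset` used in fft64_store_svp_ppol_into_vmp_pmat_row_blk_ref below.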
void fft64_store_svp_ppol_into_vmp_pmat_row_blk_ref(uint64_t nn, uint64_t m, const SVP_PPOL* svp_ppol, uint64_t row_i,
                                                    uint64_t col_i, uint64_t nrows, uint64_t ncols, VMP_PMAT* pmat) {
  double* start_addr = get_blk_addr(row_i, col_i, nrows, ncols, pmat);
  uint64_t offset = nrows * ncols * 8;
  for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
    reim4_extract_1blk_from_reim_ref(m, blk_i, start_addr + blk_i * offset, (double*)svp_ppol);
  }
}

/** @brief prepares a vmp matrix (contiguous row-major version) */
EXPORT void fft64_vmp_prepare_contiguous_ref(const MODULE* module,  // N
    VMP_PMAT* pmat,  // output
    const int64_t* mat, uint64_t nrows, uint64_t ncols,  // a
    uint8_t* tmp_space  // scratch space
) {
  // there is an edge case if nn < 8
  const uint64_t nn = module->nn;
  const uint64_t m = module->m;

  if (nn >= 8) {
    for (uint64_t row_i = 0; row_i < nrows; row_i++) {
      for (uint64_t col_i = 0; col_i < ncols; col_i++) {
        reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)tmp_space, mat + (row_i * ncols + col_i) * nn);
        reim_fft(module->mod.fft64.p_fft, (double*)tmp_space);
        fft64_store_svp_ppol_into_vmp_pmat_row_blk_ref(nn, m, (SVP_PPOL*)tmp_space, row_i, col_i, nrows, ncols, pmat);
      }
    }
  } else {
    for (uint64_t row_i = 0; row_i < nrows; row_i++) {
      for (uint64_t col_i = 0; col_i < ncols; col_i++) {
        double* res = (double*)pmat + (col_i * nrows + row_i) * nn;
        reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)res, mat + (row_i * ncols + col_i) * nn);
        reim_fft(module->mod.fft64.p_fft, res);
      }
    }
  }
}

/** @brief minimal scratch space byte-size required for the vmp_prepare function */
EXPORT uint64_t fft64_vmp_prepare_tmp_bytes(const MODULE* module, uint64_t nn,  // N
    uint64_t nrows, uint64_t ncols) {
  return nn * sizeof(int64_t);
}

/** @brief applies a vmp product (result in DFT space) and adds to res inplace */
EXPORT void fft64_vmp_apply_dft_add_ref(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
    const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,
    uint64_t pmat_scale,  // prep matrix
    uint8_t* tmp_space  // scratch space
) {
  const uint64_t nn = module->nn;
  const uint64_t rows = nrows < a_size ? nrows : a_size;

  VEC_ZNX_DFT* a_dft = (VEC_ZNX_DFT*)tmp_space;
  uint8_t* new_tmp_space = (uint8_t*)tmp_space + rows * nn * sizeof(double);

  fft64_vec_znx_dft(module, a_dft, rows, a, a_size, a_sl);
  fft64_vmp_apply_dft_to_dft_add_ref(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, pmat_scale,
                                     new_tmp_space);
}

/** @brief applies a vmp product (result in DFT space) */
EXPORT void fft64_vmp_apply_dft_ref(const MODULE* module,  // N
    VEC_ZNX_DFT* res, uint64_t res_size,  // res
    const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
    const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,  // prep matrix
    uint8_t* tmp_space  // scratch space
) {
  const uint64_t nn = module->nn;
  const uint64_t rows = nrows < a_size ? nrows : a_size;

  VEC_ZNX_DFT* a_dft = (VEC_ZNX_DFT*)tmp_space;
  uint8_t* new_tmp_space = (uint8_t*)tmp_space + rows * nn * sizeof(double);

  fft64_vec_znx_dft(module, a_dft, rows, a, a_size, a_sl);
  fft64_vmp_apply_dft_to_dft_ref(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, new_tmp_space);
}

/** @brief like fft64_vmp_apply_dft_to_dft_ref but adds in place */
EXPORT void fft64_vmp_apply_dft_to_dft_add_ref(const MODULE* module,  // N
    VEC_ZNX_DFT* res, const uint64_t res_size,  // res
    const VEC_ZNX_DFT* a_dft, uint64_t a_size,  // a
    const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols,
    uint64_t pmat_scale,  // prep matrix
    uint8_t* tmp_space  // scratch space (a_size*sizeof(reim4) bytes)
) {
  const uint64_t m = module->m;
  const uint64_t nn = module->nn;
  assert(nn >= 8);

  double* mat2cols_output = (double*)tmp_space;     // 128 bytes
  double* extracted_blk = (double*)tmp_space + 16;  // 64*min(nrows,a_size) bytes

  double* mat_input = (double*)pmat;
  double* vec_input = (double*)a_dft;
  double* vec_output = (double*)res;

  // const uint64_t row_max0 = res_size < a_size ? res_size: a_size;
  const uint64_t row_max = nrows < a_size ? nrows : a_size;
  const uint64_t col_max = ncols < res_size ? ncols : res_size;

  if (nn >= 8) {
    for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
      double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols);

      reim4_extract_1blk_from_contiguous_reim_ref(m, row_max, blk_i, (double*)extracted_blk, (double*)a_dft);

      if (pmat_scale % 2 == 0) {
        // apply mat2cols
        for (uint64_t col_res = 0, col_pmat = pmat_scale; col_pmat < col_max - 1; col_res += 2, col_pmat += 2) {
          uint64_t col_offset = col_pmat * (8 * nrows);
          reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
          reim4_add_1blk_to_reim_ref(m, blk_i, vec_output + col_res * nn, mat2cols_output);
          reim4_add_1blk_to_reim_ref(m, blk_i, vec_output + (col_res + 1) * nn, mat2cols_output + 8);
        }
      } else {
        uint64_t col_offset = (pmat_scale - 1) * (8 * nrows);
        reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
        reim4_add_1blk_to_reim_ref(m, blk_i, vec_output, mat2cols_output + 8);

        // apply mat2cols
        for (uint64_t col_res = 1, col_pmat = pmat_scale + 1; col_pmat < col_max - 1; col_res += 2, col_pmat += 2) {
          uint64_t col_offset = col_pmat * (8 * nrows);
          reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);

          reim4_add_1blk_to_reim_ref(m, blk_i, vec_output + col_res * nn, mat2cols_output);
          reim4_add_1blk_to_reim_ref(m, blk_i, vec_output + (col_res + 1) * nn, mat2cols_output + 8);
        }
      }

      // check if col_max is odd, then special case
      if (col_max % 2 == 1) {
        uint64_t last_col = col_max - 1;
        uint64_t col_offset = last_col * (8 * nrows);

        if (last_col >= pmat_scale) {
          // the last column is alone in the pmat: vec_mat1col
          if (ncols == col_max) {
            reim4_vec_mat1col_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
          } else {
            // the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
            reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
          }

          reim4_add_1blk_to_reim_ref(m, blk_i, vec_output + (last_col - pmat_scale) * nn, mat2cols_output);
        }
      }
    }
  } else {
    for (uint64_t col_res = 0, col_pmat = pmat_scale; col_pmat < col_max; col_res += 1, col_pmat += 1) {
      double* pmat_col = mat_input + col_pmat * nrows * nn;
      for (uint64_t row_i = 0; row_i < row_max; row_i++) {
        reim_fftvec_addmul(module->mod.fft64.p_addmul, vec_output + col_res * nn, vec_input + row_i * nn,
                           pmat_col + row_i * nn);
      }
    }
  }

  // zero out remaining bytes
  memset(vec_output + col_max * nn, 0, (res_size - col_max) * nn * sizeof(double));
|
||||||
|
}
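// pmat_scale note (informal reading of the loops above): result column col_res
// accumulates the product of a_dft with prepared-matrix column col_res + pmat_scale,
// i.e. the pmat is consumed starting at column pmat_scale; the even/odd split on
// pmat_scale only keeps the accesses aligned with the column-pair packing of the
// reim4 layout.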
/** @brief applies a vmp product between a vector in DFT space and a prepared matrix (result stays in DFT space) */
|
||||||
|
EXPORT void fft64_vmp_apply_dft_to_dft_ref(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||||
|
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||||
|
const uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
) {
|
||||||
|
const uint64_t m = module->m;
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
assert(nn >= 8);
|
||||||
|
|
||||||
|
double* mat2cols_output = (double*)tmp_space; // 128 bytes
|
||||||
|
double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes
|
||||||
|
|
||||||
|
double* mat_input = (double*)pmat;
|
||||||
|
double* vec_input = (double*)a_dft;
|
||||||
|
double* vec_output = (double*)res;
|
||||||
|
|
||||||
|
// const uint64_t row_max0 = res_size < a_size ? res_size: a_size;
|
||||||
|
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||||
|
const uint64_t col_max = ncols < res_size ? ncols : res_size;
|
||||||
|
|
||||||
|
if (nn >= 8) {
|
||||||
|
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||||
|
double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols);
|
||||||
|
|
||||||
|
reim4_extract_1blk_from_contiguous_reim_ref(m, row_max, blk_i, (double*)extracted_blk, (double*)a_dft);
|
||||||
|
// apply mat2cols
|
||||||
|
for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) {
|
||||||
|
uint64_t col_offset = col_i * (8 * nrows);
|
||||||
|
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
|
||||||
|
reim4_save_1blk_to_reim_ref(m, blk_i, vec_output + col_i * nn, mat2cols_output);
|
||||||
|
reim4_save_1blk_to_reim_ref(m, blk_i, vec_output + (col_i + 1) * nn, mat2cols_output + 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if col_max is odd, then special case
|
||||||
|
if (col_max % 2 == 1) {
|
||||||
|
uint64_t last_col = col_max - 1;
|
||||||
|
uint64_t col_offset = last_col * (8 * nrows);
|
||||||
|
|
||||||
|
// the last column is alone in the pmat: vec_mat1col
|
||||||
|
if (ncols == col_max) {
|
||||||
|
reim4_vec_mat1col_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
} else {
|
||||||
|
// the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
|
||||||
|
reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
}
|
||||||
|
reim4_save_1blk_to_reim_ref(m, blk_i, vec_output + last_col * nn, mat2cols_output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (uint64_t col_i = 0; col_i < col_max; col_i++) {
|
||||||
|
double* pmat_col = mat_input + col_i * nrows * nn;
|
||||||
|
for (uint64_t row_i = 0; row_i < 1; row_i++) {
|
||||||
|
reim_fftvec_mul(module->mod.fft64.mul_fft, vec_output + col_i * nn, vec_input + row_i * nn,
|
||||||
|
pmat_col + row_i * nn);
|
||||||
|
}
|
||||||
|
for (uint64_t row_i = 1; row_i < row_max; row_i++) {
|
||||||
|
reim_fftvec_addmul(module->mod.fft64.p_addmul, vec_output + col_i * nn, vec_input + row_i * nn,
|
||||||
|
pmat_col + row_i * nn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// zero out remaining bytes
|
||||||
|
memset(vec_output + col_max * nn, 0, (res_size - col_max) * nn * sizeof(double));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT uint64_t fft64_vmp_apply_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
) {
|
||||||
|
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||||
|
return (row_max * nn * sizeof(double)) + (128) + (64 * row_max);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT uint64_t fft64_vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
) {
|
||||||
|
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||||
|
|
||||||
|
return (128) + (64 * row_max);
|
||||||
|
}
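// Scratch accounting: 128 bytes hold one mat2cols output (two reim4 values of
// 8 doubles each) and 64 bytes per used row hold the extracted reim4 block of
// a_dft; fft64_vmp_apply_dft_tmp_bytes adds row_max * nn doubles on top of this
// for a_dft itself.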
EXPORT void vmp_apply_dft_to_dft(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||||
|
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||||
|
const uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
) {
|
||||||
|
module->func.vmp_apply_dft_to_dft(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void vmp_apply_dft_to_dft_add(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||||
|
const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols,
|
||||||
|
uint64_t pmat_scale, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
) {
|
||||||
|
module->func.vmp_apply_dft_to_dft_add(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, pmat_scale,
|
||||||
|
tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT uint64_t vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
) {
|
||||||
|
return module->func.vmp_apply_dft_to_dft_tmp_bytes(module, nn, res_size, a_size, nrows, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (result in DFT space) and adds to res inplace */
|
||||||
|
EXPORT void vmp_apply_dft_add(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, uint64_t pmat_scale, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
module->func.vmp_apply_dft_add(module, res, res_size, a, a_size, a_sl, pmat, nrows, ncols, pmat_scale, tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (result in DFT space) */
|
||||||
|
EXPORT void vmp_apply_dft(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
module->func.vmp_apply_dft(module, res, res_size, a, a_size, a_sl, pmat, nrows, ncols, tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief minimal size of the tmp_space */
|
||||||
|
EXPORT uint64_t vmp_apply_dft_tmp_bytes(const MODULE* module, uint64_t nn, // N
|
||||||
|
uint64_t res_size, // res
|
||||||
|
uint64_t a_size, // a
|
||||||
|
uint64_t nrows, uint64_t ncols // prep matrix
|
||||||
|
) {
|
||||||
|
return module->func.vmp_apply_dft_tmp_bytes(module, nn, res_size, a_size, nrows, ncols);
|
||||||
|
}
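/*
 * Usage sketch (illustrative only): sizing the scratch buffer with the
 * *_tmp_bytes helper before calling the generic wrapper. The MODULE, the
 * prepared VMP_PMAT and the input/output buffers are assumed to have been
 * created elsewhere with the library's constructors (not shown here).
 *
 *   uint64_t tmp_bytes = vmp_apply_dft_tmp_bytes(module, nn, res_size, a_size, nrows, ncols);
 *   uint8_t* tmp = (uint8_t*)malloc(tmp_bytes);
 *   vmp_apply_dft(module, res, res_size, a, a_size, a_sl, pmat, nrows, ncols, tmp);
 *   free(tmp);
 */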
@@ -0,0 +1,244 @@
|
|||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "../reim4/reim4_arithmetic.h"
|
||||||
|
#include "vec_znx_arithmetic_private.h"
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void fft64_vmp_prepare_contiguous_avx(const MODULE* module, // N
|
||||||
|
VMP_PMAT* pmat, // output
|
||||||
|
const int64_t* mat, uint64_t nrows, uint64_t ncols, // a
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
// there is an edge case if nn < 8
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
const uint64_t m = module->m;
|
||||||
|
|
||||||
|
double* output_mat = (double*)pmat;
|
||||||
|
double* start_addr = (double*)pmat;
|
||||||
|
uint64_t offset = nrows * ncols * 8;
|
||||||
|
|
||||||
|
if (nn >= 8) {
|
||||||
|
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||||
|
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||||
|
reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)tmp_space, mat + (row_i * ncols + col_i) * nn);
|
||||||
|
reim_fft(module->mod.fft64.p_fft, (double*)tmp_space);
|
||||||
|
|
||||||
|
if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
|
||||||
|
// special case: last column out of an odd column number
|
||||||
|
start_addr = output_mat + col_i * nrows * 8 // col == ncols-1
|
||||||
|
+ row_i * 8;
|
||||||
|
} else {
|
||||||
|
// general case: columns go by pair
|
||||||
|
start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
|
||||||
|
+ row_i * 2 * 8 // third: row index
|
||||||
|
+ (col_i % 2) * 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||||
|
// extract blk from tmp and save it
|
||||||
|
reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, (double*)tmp_space);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
|
||||||
|
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
|
||||||
|
double* res = (double*)pmat + (col_i * nrows + row_i) * nn;
|
||||||
|
reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)res, mat + (row_i * ncols + col_i) * nn);
|
||||||
|
reim_fft(module->mod.fft64.p_fft, res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
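// Note: this AVX variant only swaps the block-extraction primitive
// (reim4_extract_1blk_from_reim_avx instead of the _ref one); the addressing of
// the prepared matrix is the same, so both variants are expected to produce the
// same layout.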
double* get_blk_addr(int row, int col, int nrows, int ncols, VMP_PMAT* pmat);
|
||||||
|
|
||||||
|
void fft64_store_svp_ppol_into_vmp_pmat_row_blk_avx(uint64_t nn, uint64_t m, const SVP_PPOL* svp_ppol, uint64_t row_i,
|
||||||
|
uint64_t col_i, uint64_t nrows, uint64_t ncols, VMP_PMAT* pmat) {
|
||||||
|
double* start_addr = get_blk_addr(row_i, col_i, nrows, ncols, pmat);
|
||||||
|
uint64_t offset = nrows * ncols * 8;
|
||||||
|
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||||
|
reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, (double*)svp_ppol);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (result in DFT space) and adds to res inplace */
|
||||||
|
EXPORT void fft64_vmp_apply_dft_add_avx(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,
|
||||||
|
uint64_t pmat_scale, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
const uint64_t rows = nrows < a_size ? nrows : a_size;
|
||||||
|
|
||||||
|
VEC_ZNX_DFT* a_dft = (VEC_ZNX_DFT*)tmp_space;
|
||||||
|
uint8_t* new_tmp_space = (uint8_t*)tmp_space + rows * nn * sizeof(double);
|
||||||
|
|
||||||
|
fft64_vec_znx_dft(module, a_dft, rows, a, a_size, a_sl);
|
||||||
|
fft64_vmp_apply_dft_to_dft_add_avx(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, pmat_scale,
|
||||||
|
new_tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (result in DFT space) */
|
||||||
|
EXPORT void fft64_vmp_apply_dft_avx(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
|
||||||
|
const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space
|
||||||
|
) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
const uint64_t rows = nrows < a_size ? nrows : a_size;
|
||||||
|
|
||||||
|
VEC_ZNX_DFT* a_dft = (VEC_ZNX_DFT*)tmp_space;
|
||||||
|
uint8_t* new_tmp_space = (uint8_t*)tmp_space + rows * nn * sizeof(double);
|
||||||
|
|
||||||
|
fft64_vec_znx_dft(module, a_dft, rows, a, a_size, a_sl);
|
||||||
|
fft64_vmp_apply_dft_to_dft_avx(module, res, res_size, a_dft, a_size, pmat, nrows, ncols, new_tmp_space);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void fft64_vmp_apply_dft_to_dft_add_avx(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||||
|
const VMP_PMAT* pmat, const uint64_t nrows, const uint64_t ncols,
|
||||||
|
uint64_t pmat_scale, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
) {
|
||||||
|
const uint64_t m = module->m;
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
|
||||||
|
double* mat2cols_output = (double*)tmp_space; // 128 bytes
|
||||||
|
double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes
|
||||||
|
|
||||||
|
double* mat_input = (double*)pmat;
|
||||||
|
double* vec_input = (double*)a_dft;
|
||||||
|
double* vec_output = (double*)res;
|
||||||
|
|
||||||
|
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||||
|
const uint64_t col_max = ncols < res_size ? ncols : res_size;
|
||||||
|
|
||||||
|
if (nn >= 8) {
|
||||||
|
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||||
|
double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols);
|
||||||
|
|
||||||
|
reim4_extract_1blk_from_contiguous_reim_avx(m, row_max, blk_i, (double*)extracted_blk, (double*)a_dft);
|
||||||
|
|
||||||
|
if (pmat_scale % 2 == 0) {
|
||||||
|
for (uint64_t col_res = 0, col_pmat = pmat_scale; col_pmat < col_max - 1; col_res += 2, col_pmat += 2) {
|
||||||
|
uint64_t col_offset = col_pmat * (8 * nrows);
|
||||||
|
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
reim4_add_1blk_to_reim_avx(m, blk_i, vec_output + col_res * nn, mat2cols_output);
|
||||||
|
reim4_add_1blk_to_reim_avx(m, blk_i, vec_output + (col_res + 1) * nn, mat2cols_output + 8);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
uint64_t col_offset = (pmat_scale - 1) * (8 * nrows);
|
||||||
|
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
reim4_add_1blk_to_reim_avx(m, blk_i, vec_output, mat2cols_output + 8);
|
||||||
|
|
||||||
|
for (uint64_t col_res = 1, col_pmat = pmat_scale + 1; col_pmat < col_max - 1; col_res += 2, col_pmat += 2) {
|
||||||
|
uint64_t col_offset = col_pmat * (8 * nrows);
|
||||||
|
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
reim4_add_1blk_to_reim_avx(m, blk_i, vec_output + col_res * nn, mat2cols_output);
|
||||||
|
reim4_add_1blk_to_reim_avx(m, blk_i, vec_output + (col_res + 1) * nn, mat2cols_output + 8);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if col_max is odd, then special case
|
||||||
|
if (col_max % 2 == 1) {
|
||||||
|
uint64_t last_col = col_max - 1;
|
||||||
|
uint64_t col_offset = last_col * (8 * nrows);
|
||||||
|
|
||||||
|
if (last_col >= pmat_scale) {
|
||||||
|
// the last column is alone in the pmat: vec_mat1col
|
||||||
|
if (ncols == col_max)
|
||||||
|
reim4_vec_mat1col_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
else {
|
||||||
|
// the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
|
||||||
|
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
}
|
||||||
|
reim4_add_1blk_to_reim_avx(m, blk_i, vec_output + (last_col - pmat_scale) * nn, mat2cols_output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (uint64_t col_res = 0, col_pmat = pmat_scale; col_pmat < col_max; col_res += 1, col_pmat += 1) {
|
||||||
|
double* pmat_col = mat_input + col_pmat * nrows * nn;
|
||||||
|
for (uint64_t row_i = 0; row_i < row_max; row_i++) {
|
||||||
|
reim_fftvec_addmul(module->mod.fft64.p_addmul, vec_output + col_res * nn, vec_input + row_i * nn,
|
||||||
|
pmat_col + row_i * nn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// zero out remaining bytes
|
||||||
|
memset(vec_output + col_max * nn, 0, (res_size - col_max) * nn * sizeof(double));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief applies a vmp product between a vector in DFT space and a prepared matrix (AVX variant of fft64_vmp_apply_dft_to_dft_ref) */
|
||||||
|
EXPORT void fft64_vmp_apply_dft_to_dft_avx(const MODULE* module, // N
|
||||||
|
VEC_ZNX_DFT* res, const uint64_t res_size, // res
|
||||||
|
const VEC_ZNX_DFT* a_dft, uint64_t a_size, // a
|
||||||
|
const VMP_PMAT* pmat, const uint64_t nrows,
|
||||||
|
const uint64_t ncols, // prep matrix
|
||||||
|
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
|
||||||
|
) {
|
||||||
|
const uint64_t m = module->m;
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
|
||||||
|
double* mat2cols_output = (double*)tmp_space; // 128 bytes
|
||||||
|
double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes
|
||||||
|
|
||||||
|
double* mat_input = (double*)pmat;
|
||||||
|
double* vec_input = (double*)a_dft;
|
||||||
|
double* vec_output = (double*)res;
|
||||||
|
|
||||||
|
const uint64_t row_max = nrows < a_size ? nrows : a_size;
|
||||||
|
const uint64_t col_max = ncols < res_size ? ncols : res_size;
|
||||||
|
|
||||||
|
if (nn >= 8) {
|
||||||
|
for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
|
||||||
|
double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols);
|
||||||
|
|
||||||
|
reim4_extract_1blk_from_contiguous_reim_avx(m, row_max, blk_i, (double*)extracted_blk, (double*)a_dft);
|
||||||
|
// apply mat2cols
|
||||||
|
for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) {
|
||||||
|
uint64_t col_offset = col_i * (8 * nrows);
|
||||||
|
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
|
||||||
|
reim4_save_1blk_to_reim_avx(m, blk_i, vec_output + col_i * nn, mat2cols_output);
|
||||||
|
reim4_save_1blk_to_reim_avx(m, blk_i, vec_output + (col_i + 1) * nn, mat2cols_output + 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if col_max is odd, then special case
|
||||||
|
if (col_max % 2 == 1) {
|
||||||
|
uint64_t last_col = col_max - 1;
|
||||||
|
uint64_t col_offset = last_col * (8 * nrows);
|
||||||
|
|
||||||
|
// the last column is alone in the pmat: vec_mat1col
|
||||||
|
if (ncols == col_max)
|
||||||
|
reim4_vec_mat1col_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
else {
|
||||||
|
// the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position
|
||||||
|
reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset);
|
||||||
|
}
|
||||||
|
reim4_save_1blk_to_reim_avx(m, blk_i, vec_output + last_col * nn, mat2cols_output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (uint64_t col_i = 0; col_i < col_max; col_i++) {
|
||||||
|
double* pmat_col = mat_input + col_i * nrows * nn;
|
||||||
|
for (uint64_t row_i = 0; row_i < 1; row_i++) {
|
||||||
|
reim_fftvec_mul(module->mod.fft64.mul_fft, vec_output + col_i * nn, vec_input + row_i * nn,
|
||||||
|
pmat_col + row_i * nn);
|
||||||
|
}
|
||||||
|
for (uint64_t row_i = 1; row_i < row_max; row_i++) {
|
||||||
|
reim_fftvec_addmul(module->mod.fft64.p_addmul, vec_output + col_i * nn, vec_input + row_i * nn,
|
||||||
|
pmat_col + row_i * nn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// zero out remaining bytes
|
||||||
|
memset(vec_output + col_max * nn, 0, (res_size - col_max) * nn * sizeof(double));
|
||||||
|
}
|
||||||
@@ -0,0 +1,185 @@
|
|||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "zn_arithmetic_private.h"
|
||||||
|
|
||||||
|
void default_init_z_module_precomp(MOD_Z* module) {
|
||||||
|
// Add here initialization of items that are in the precomp
|
||||||
|
}
|
||||||
|
|
||||||
|
void default_finalize_z_module_precomp(MOD_Z* module) {
|
||||||
|
// Add here deleters for items that are in the precomp
|
||||||
|
}
|
||||||
|
|
||||||
|
void default_init_z_module_vtable(MOD_Z* module) {
|
||||||
|
// Add function pointers here
|
||||||
|
module->vtable.i8_approxdecomp_from_tndbl = default_i8_approxdecomp_from_tndbl_ref;
|
||||||
|
module->vtable.i16_approxdecomp_from_tndbl = default_i16_approxdecomp_from_tndbl_ref;
|
||||||
|
module->vtable.i32_approxdecomp_from_tndbl = default_i32_approxdecomp_from_tndbl_ref;
|
||||||
|
module->vtable.zn32_vmp_prepare_contiguous = default_zn32_vmp_prepare_contiguous_ref;
|
||||||
|
module->vtable.zn32_vmp_prepare_dblptr = default_zn32_vmp_prepare_dblptr_ref;
|
||||||
|
module->vtable.zn32_vmp_prepare_row = default_zn32_vmp_prepare_row_ref;
|
||||||
|
module->vtable.zn32_vmp_apply_i8 = default_zn32_vmp_apply_i8_ref;
|
||||||
|
module->vtable.zn32_vmp_apply_i16 = default_zn32_vmp_apply_i16_ref;
|
||||||
|
module->vtable.zn32_vmp_apply_i32 = default_zn32_vmp_apply_i32_ref;
|
||||||
|
module->vtable.dbl_to_tn32 = dbl_to_tn32_ref;
|
||||||
|
module->vtable.tn32_to_dbl = tn32_to_dbl_ref;
|
||||||
|
module->vtable.dbl_round_to_i32 = dbl_round_to_i32_ref;
|
||||||
|
module->vtable.i32_to_dbl = i32_to_dbl_ref;
|
||||||
|
module->vtable.dbl_round_to_i64 = dbl_round_to_i64_ref;
|
||||||
|
module->vtable.i64_to_dbl = i64_to_dbl_ref;
|
||||||
|
|
||||||
|
// Add optimized function pointers here
|
||||||
|
if (CPU_SUPPORTS("avx")) {
|
||||||
|
module->vtable.zn32_vmp_apply_i8 = default_zn32_vmp_apply_i8_avx;
|
||||||
|
module->vtable.zn32_vmp_apply_i16 = default_zn32_vmp_apply_i16_avx;
|
||||||
|
module->vtable.zn32_vmp_apply_i32 = default_zn32_vmp_apply_i32_avx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void init_z_module_info(MOD_Z* module, //
|
||||||
|
Z_MODULE_TYPE mtype) {
|
||||||
|
memset(module, 0, sizeof(MOD_Z));
|
||||||
|
module->mtype = mtype;
|
||||||
|
switch (mtype) {
|
||||||
|
case DEFAULT:
|
||||||
|
default_init_z_module_precomp(module);
|
||||||
|
default_init_z_module_vtable(module);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
NOT_SUPPORTED(); // unknown mtype
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void finalize_z_module_info(MOD_Z* module) {
|
||||||
|
if (module->custom) module->custom_deleter(module->custom);
|
||||||
|
switch (module->mtype) {
|
||||||
|
case DEFAULT:
|
||||||
|
default_finalize_z_module_precomp(module);
|
||||||
|
// fft64_finalize_rnx_module_vtable(module); // nothing to finalize
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
NOT_SUPPORTED(); // unknown mtype
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT MOD_Z* new_z_module_info(Z_MODULE_TYPE mtype) {
|
||||||
|
MOD_Z* res = (MOD_Z*)malloc(sizeof(MOD_Z));
|
||||||
|
init_z_module_info(res, mtype);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void delete_z_module_info(MOD_Z* module_info) {
|
||||||
|
finalize_z_module_info(module_info);
|
||||||
|
free(module_info);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////// wrappers //////////////////
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int8_t* output) */
|
||||||
|
EXPORT void i8_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int8_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||||
|
const double* a, uint64_t a_size) { // a
|
||||||
|
module->vtable.i8_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int16_t* output) */
|
||||||
|
EXPORT void i16_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int16_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||||
|
const double* a, uint64_t a_size) { // a
|
||||||
|
module->vtable.i16_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int32_t* output) */
|
||||||
|
EXPORT void i32_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int32_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||||
|
const double* a, uint64_t a_size) { // a
|
||||||
|
module->vtable.i32_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void zn32_vmp_prepare_contiguous(const MOD_Z* module,
|
||||||
|
ZN32_VMP_PMAT* pmat, // output
|
||||||
|
const int32_t* mat, uint64_t nrows, uint64_t ncols) { // a
|
||||||
|
module->vtable.zn32_vmp_prepare_contiguous(module, pmat, mat, nrows, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||||
|
EXPORT void zn32_vmp_prepare_dblptr(const MOD_Z* module,
|
||||||
|
ZN32_VMP_PMAT* pmat, // output
|
||||||
|
const int32_t** mat, uint64_t nrows, uint64_t ncols) { // a
|
||||||
|
module->vtable.zn32_vmp_prepare_dblptr(module, pmat, mat, nrows, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||||
|
EXPORT void zn32_vmp_prepare_row(const MOD_Z* module,
|
||||||
|
ZN32_VMP_PMAT* pmat, // output
|
||||||
|
const int32_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols) { // a
|
||||||
|
module->vtable.zn32_vmp_prepare_row(module, pmat, row, row_i, nrows, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int32_t* input) */
|
||||||
|
EXPORT void zn32_vmp_apply_i32(const MOD_Z* module, int32_t* res, uint64_t res_size, const int32_t* a, uint64_t a_size,
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||||
|
module->vtable.zn32_vmp_apply_i32(module, res, res_size, a, a_size, pmat, nrows, ncols);
|
||||||
|
}
|
||||||
|
/** @brief applies a vmp product (int16_t* input) */
|
||||||
|
EXPORT void zn32_vmp_apply_i16(const MOD_Z* module, int32_t* res, uint64_t res_size, const int16_t* a, uint64_t a_size,
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||||
|
module->vtable.zn32_vmp_apply_i16(module, res, res_size, a, a_size, pmat, nrows, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int8_t* input) */
|
||||||
|
EXPORT void zn32_vmp_apply_i8(const MOD_Z* module, int32_t* res, uint64_t res_size, const int8_t* a, uint64_t a_size,
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||||
|
module->vtable.zn32_vmp_apply_i8(module, res, res_size, a, a_size, pmat, nrows, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** reduction mod 1, output in torus32 space */
|
||||||
|
EXPORT void dbl_to_tn32(const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
module->vtable.dbl_to_tn32(module, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** real centerlift mod 1, output in double space */
|
||||||
|
EXPORT void tn32_to_dbl(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
module->vtable.tn32_to_dbl(module, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** round to the nearest int, output in i32 space */
|
||||||
|
EXPORT void dbl_round_to_i32(const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
module->vtable.dbl_round_to_i32(module, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** small int (int32 space) to double */
|
||||||
|
EXPORT void i32_to_dbl(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
module->vtable.i32_to_dbl(module, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** round to the nearest int, output in int64 space */
|
||||||
|
EXPORT void dbl_round_to_i64(const MOD_Z* module, //
|
||||||
|
int64_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
module->vtable.dbl_round_to_i64(module, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** small int (int64 space, <= 2^50) to double */
|
||||||
|
EXPORT void i64_to_dbl(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
module->vtable.i64_to_dbl(module, res, res_size, a, a_size);
|
||||||
|
}
|
||||||
@@ -0,0 +1,81 @@
|
|||||||
|
#include <memory.h>
|
||||||
|
|
||||||
|
#include "zn_arithmetic_private.h"
|
||||||
|
|
||||||
|
EXPORT TNDBL_APPROXDECOMP_GADGET* new_tndbl_approxdecomp_gadget(const MOD_Z* module, //
|
||||||
|
uint64_t k, uint64_t ell) {
|
||||||
|
if (k * ell > 50) {
|
||||||
|
return spqlios_error("approx decomposition requested is too precise for doubles");
|
||||||
|
}
|
||||||
|
if (k < 1) {
|
||||||
|
return spqlios_error("approx decomposition supports k>=1");
|
||||||
|
}
|
||||||
|
TNDBL_APPROXDECOMP_GADGET* res = malloc(sizeof(TNDBL_APPROXDECOMP_GADGET));
|
||||||
|
memset(res, 0, sizeof(TNDBL_APPROXDECOMP_GADGET));
|
||||||
|
res->k = k;
|
||||||
|
res->ell = ell;
|
||||||
|
double add_cst = INT64_C(3) << (51 - k * ell);
|
||||||
|
for (uint64_t i = 0; i < ell; ++i) {
|
||||||
|
add_cst += pow(2., -(double)(i * k + 1));
|
||||||
|
}
|
||||||
|
res->add_cst = add_cst;
|
||||||
|
res->and_mask = (UINT64_C(1) << k) - 1;
|
||||||
|
res->sub_cst = UINT64_C(1) << (k - 1);
|
||||||
|
for (uint64_t i = 0; i < ell; ++i) res->rshifts[i] = (ell - 1 - i) * k;
|
||||||
|
return res;
|
||||||
|
}
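// Constant construction (informal): adding add_cst = 3*2^(51 - k*ell) + sum_i 2^-(i*k + 1)
// to a torus value pins the double exponent, so the low k*ell mantissa bits of the sum
// contain the input rounded to k*ell fractional bits (the small 2^-(i*k+1) terms provide
// the per-digit rounding). Each base-2^k digit is then read off by a right shift
// (rshifts[i]), masked with and_mask, and recentred into [-2^(k-1), 2^(k-1)) by
// subtracting sub_cst; see the IMPL_* macro below.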
EXPORT void delete_tndbl_approxdecomp_gadget(TNDBL_APPROXDECOMP_GADGET* ptr) { free(ptr); }
|
||||||
|
|
||||||
|
EXPORT int default_init_tndbl_approxdecomp_gadget(const MOD_Z* module, //
|
||||||
|
TNDBL_APPROXDECOMP_GADGET* res, //
|
||||||
|
uint64_t k, uint64_t ell) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef union {
|
||||||
|
double dv;
|
||||||
|
uint64_t uv;
|
||||||
|
} du_t;
|
||||||
|
|
||||||
|
#define IMPL_ixx_approxdecomp_from_tndbl_ref(ITYPE) \
|
||||||
|
if (res_size != a_size * gadget->ell) NOT_IMPLEMENTED(); \
|
||||||
|
const uint64_t ell = gadget->ell; \
|
||||||
|
const double add_cst = gadget->add_cst; \
|
||||||
|
const uint8_t* const rshifts = gadget->rshifts; \
|
||||||
|
const ITYPE and_mask = gadget->and_mask; \
|
||||||
|
const ITYPE sub_cst = gadget->sub_cst; \
|
||||||
|
ITYPE* rr = res; \
|
||||||
|
const double* aa = a; \
|
||||||
|
const double* aaend = a + a_size; \
|
||||||
|
while (aa < aaend) { \
|
||||||
|
du_t t = {.dv = *aa + add_cst}; \
|
||||||
|
for (uint64_t i = 0; i < ell; ++i) { \
|
||||||
|
ITYPE v = (ITYPE)(t.uv >> rshifts[i]); \
|
||||||
|
*rr = (v & and_mask) - sub_cst; \
|
||||||
|
++rr; \
|
||||||
|
} \
|
||||||
|
++aa; \
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int8_t* output) */
|
||||||
|
EXPORT void default_i8_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int8_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||||
|
const double* a, uint64_t a_size //
|
||||||
|
){IMPL_ixx_approxdecomp_from_tndbl_ref(int8_t)}
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int16_t* output) */
|
||||||
|
EXPORT void default_i16_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int16_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
){IMPL_ixx_approxdecomp_from_tndbl_ref(int16_t)}
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int32_t* output) */
|
||||||
|
EXPORT void default_i32_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
IMPL_ixx_approxdecomp_from_tndbl_ref(int32_t)
|
||||||
|
}
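/*
 * Usage sketch (illustrative only, sizes are placeholders): decomposing a vector of
 * torus doubles (taken mod 1) into ell signed base-2^k digits per coefficient. All
 * functions used here are declared in zn_arithmetic.h.
 *
 *   MOD_Z* mod = new_z_module_info(DEFAULT);
 *   TNDBL_APPROXDECOMP_GADGET* g = new_tndbl_approxdecomp_gadget(mod, 4, 8);  // k=4, ell=8
 *   double a[256];            // torus inputs
 *   int8_t digits[8 * 256];   // res_size must equal ell * a_size
 *   i8_approxdecomp_from_tndbl(mod, g, digits, 8 * 256, a, 256);
 *   delete_tndbl_approxdecomp_gadget(g);
 *   delete_z_module_info(mod);
 */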
@@ -0,0 +1,147 @@
|
|||||||
|
#ifndef SPQLIOS_ZN_ARITHMETIC_H
|
||||||
|
#define SPQLIOS_ZN_ARITHMETIC_H
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "../commons.h"
|
||||||
|
|
||||||
|
typedef enum z_module_type_t { DEFAULT } Z_MODULE_TYPE;
|
||||||
|
|
||||||
|
/** @brief opaque structure that describes the module and the hardware */
|
||||||
|
typedef struct z_module_info_t MOD_Z;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief obtain a module info for ring dimension N
|
||||||
|
* the module-info knows about:
|
||||||
|
* - the dimension N (or the complex dimension m=N/2)
|
||||||
|
* - any precomputed fft or ntt items
|
||||||
|
* - the hardware (avx, arm64, x86, ...)
|
||||||
|
*/
|
||||||
|
EXPORT MOD_Z* new_z_module_info(Z_MODULE_TYPE mtype);
|
||||||
|
EXPORT void delete_z_module_info(MOD_Z* module_info);
|
||||||
|
|
||||||
|
typedef struct tndbl_approxdecomp_gadget_t TNDBL_APPROXDECOMP_GADGET;
|
||||||
|
|
||||||
|
EXPORT TNDBL_APPROXDECOMP_GADGET* new_tndbl_approxdecomp_gadget(const MOD_Z* module, //
|
||||||
|
uint64_t k,
|
||||||
|
uint64_t ell); // base 2^k, and size
|
||||||
|
|
||||||
|
EXPORT void delete_tndbl_approxdecomp_gadget(TNDBL_APPROXDECOMP_GADGET* ptr);
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int8_t* output) */
|
||||||
|
EXPORT void i8_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int8_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||||
|
const double* a, uint64_t a_size); // a
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int16_t* output) */
|
||||||
|
EXPORT void i16_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int16_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||||
|
const double* a, uint64_t a_size); // a
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int32_t* output) */
|
||||||
|
EXPORT void i32_approxdecomp_from_tndbl(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int32_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||||
|
const double* a, uint64_t a_size); // a
|
||||||
|
|
||||||
|
/** @brief opaque type that represents a prepared matrix */
|
||||||
|
typedef struct zn32_vmp_pmat_t ZN32_VMP_PMAT;
|
||||||
|
|
||||||
|
/** @brief size in bytes of a prepared matrix (for custom allocation) */
|
||||||
|
EXPORT uint64_t bytes_of_zn32_vmp_pmat(const MOD_Z* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols); // dimensions
|
||||||
|
|
||||||
|
/** @brief allocates a prepared matrix (release with delete_zn32_vmp_pmat) */
|
||||||
|
EXPORT ZN32_VMP_PMAT* new_zn32_vmp_pmat(const MOD_Z* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols); // dimensions
|
||||||
|
|
||||||
|
/** @brief deletes a prepared matrix (release with free) */
|
||||||
|
EXPORT void delete_zn32_vmp_pmat(ZN32_VMP_PMAT* ptr);
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void zn32_vmp_prepare_contiguous( //
|
||||||
|
const MOD_Z* module,
|
||||||
|
ZN32_VMP_PMAT* pmat, // output
|
||||||
|
const int32_t* mat, uint64_t nrows, uint64_t ncols); // a
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||||
|
EXPORT void zn32_vmp_prepare_dblptr( //
|
||||||
|
const MOD_Z* module,
|
||||||
|
ZN32_VMP_PMAT* pmat, // output
|
||||||
|
const int32_t** mat, uint64_t nrows, uint64_t ncols); // a
|
||||||
|
|
||||||
|
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||||
|
EXPORT void zn32_vmp_prepare_row( //
|
||||||
|
const MOD_Z* module,
|
||||||
|
ZN32_VMP_PMAT* pmat, // output
|
||||||
|
const int32_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols); // a
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int32_t* input) */
|
||||||
|
EXPORT void zn32_vmp_apply_i32( //
|
||||||
|
const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size, // a
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int16_t* input) */
|
||||||
|
EXPORT void zn32_vmp_apply_i16( //
|
||||||
|
const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const int16_t* a, uint64_t a_size, // a
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int8_t* input) */
|
||||||
|
EXPORT void zn32_vmp_apply_i8( //
|
||||||
|
const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const int8_t* a, uint64_t a_size, // a
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
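/*
 * Usage sketch (illustrative only, dimensions are placeholders): preparing a small
 * integer matrix once and applying it as a vector-matrix product res[1 x ncols] =
 * a[1 x nrows] * mat[nrows x ncols]. Everything used here is declared above.
 *
 *   MOD_Z* mod = new_z_module_info(DEFAULT);
 *   ZN32_VMP_PMAT* pm = new_zn32_vmp_pmat(mod, nrows, ncols);
 *   zn32_vmp_prepare_contiguous(mod, pm, mat, nrows, ncols);  // mat: int32_t[nrows*ncols], row-major
 *   zn32_vmp_apply_i32(mod, res, ncols, a, nrows, pm, nrows, ncols);
 *   delete_zn32_vmp_pmat(pm);
 *   delete_z_module_info(mod);
 */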
// explicit conversions
|
||||||
|
|
||||||
|
/** reduction mod 1, output in torus32 space */
|
||||||
|
EXPORT void dbl_to_tn32(const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** real centerlift mod 1, output in double space */
|
||||||
|
EXPORT void tn32_to_dbl(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** round to the nearest int, output in i32 space.
|
||||||
|
* WARNING: ||a||_inf must be <= 2^18 in this function
|
||||||
|
*/
|
||||||
|
EXPORT void dbl_round_to_i32(const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** small int (int32 space) to double
|
||||||
|
* WARNING: ||a||_inf must be <= 2^18 in this function
|
||||||
|
*/
|
||||||
|
EXPORT void i32_to_dbl(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** round to the nearest int, output in int64 space
|
||||||
|
* WARNING: ||a||_inf must be <= 2^50 in this function
|
||||||
|
*/
|
||||||
|
EXPORT void dbl_round_to_i64(const MOD_Z* module, //
|
||||||
|
int64_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** small int (int64 space, <= 2^50) to double
|
||||||
|
* WARNING: ||a||_inf must be <= 2^50 in this function
|
||||||
|
*/
|
||||||
|
EXPORT void i64_to_dbl(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
#endif // SPQLIOS_ZN_ARITHMETIC_H
|
||||||
@@ -0,0 +1,43 @@
#ifndef SPQLIOS_ZN_ARITHMETIC_PLUGIN_H
#define SPQLIOS_ZN_ARITHMETIC_PLUGIN_H

#include "zn_arithmetic.h"

typedef typeof(i8_approxdecomp_from_tndbl) I8_APPROXDECOMP_FROM_TNDBL_F;
typedef typeof(i16_approxdecomp_from_tndbl) I16_APPROXDECOMP_FROM_TNDBL_F;
typedef typeof(i32_approxdecomp_from_tndbl) I32_APPROXDECOMP_FROM_TNDBL_F;
typedef typeof(bytes_of_zn32_vmp_pmat) BYTES_OF_ZN32_VMP_PMAT_F;
typedef typeof(zn32_vmp_prepare_contiguous) ZN32_VMP_PREPARE_CONTIGUOUS_F;
typedef typeof(zn32_vmp_prepare_dblptr) ZN32_VMP_PREPARE_DBLPTR_F;
typedef typeof(zn32_vmp_prepare_row) ZN32_VMP_PREPARE_ROW_F;
typedef typeof(zn32_vmp_apply_i32) ZN32_VMP_APPLY_I32_F;
typedef typeof(zn32_vmp_apply_i16) ZN32_VMP_APPLY_I16_F;
typedef typeof(zn32_vmp_apply_i8) ZN32_VMP_APPLY_I8_F;
typedef typeof(dbl_to_tn32) DBL_TO_TN32_F;
typedef typeof(tn32_to_dbl) TN32_TO_DBL_F;
typedef typeof(dbl_round_to_i32) DBL_ROUND_TO_I32_F;
typedef typeof(i32_to_dbl) I32_TO_DBL_F;
typedef typeof(dbl_round_to_i64) DBL_ROUND_TO_I64_F;
typedef typeof(i64_to_dbl) I64_TO_DBL_F;

typedef struct z_module_vtable_t Z_MODULE_VTABLE;
struct z_module_vtable_t {
  I8_APPROXDECOMP_FROM_TNDBL_F* i8_approxdecomp_from_tndbl;
  I16_APPROXDECOMP_FROM_TNDBL_F* i16_approxdecomp_from_tndbl;
  I32_APPROXDECOMP_FROM_TNDBL_F* i32_approxdecomp_from_tndbl;
  BYTES_OF_ZN32_VMP_PMAT_F* bytes_of_zn32_vmp_pmat;
  ZN32_VMP_PREPARE_CONTIGUOUS_F* zn32_vmp_prepare_contiguous;
  ZN32_VMP_PREPARE_DBLPTR_F* zn32_vmp_prepare_dblptr;
  ZN32_VMP_PREPARE_ROW_F* zn32_vmp_prepare_row;
  ZN32_VMP_APPLY_I32_F* zn32_vmp_apply_i32;
  ZN32_VMP_APPLY_I16_F* zn32_vmp_apply_i16;
  ZN32_VMP_APPLY_I8_F* zn32_vmp_apply_i8;
  DBL_TO_TN32_F* dbl_to_tn32;
  TN32_TO_DBL_F* tn32_to_dbl;
  DBL_ROUND_TO_I32_F* dbl_round_to_i32;
  I32_TO_DBL_F* i32_to_dbl;
  DBL_ROUND_TO_I64_F* dbl_round_to_i64;
  I64_TO_DBL_F* i64_to_dbl;
};

#endif // SPQLIOS_ZN_ARITHMETIC_PLUGIN_H
@@ -0,0 +1,164 @@
|
|||||||
|
#ifndef SPQLIOS_ZN_ARITHMETIC_PRIVATE_H
|
||||||
|
#define SPQLIOS_ZN_ARITHMETIC_PRIVATE_H
|
||||||
|
|
||||||
|
#include "../commons_private.h"
|
||||||
|
#include "zn_arithmetic.h"
|
||||||
|
#include "zn_arithmetic_plugin.h"
|
||||||
|
|
||||||
|
typedef struct main_z_module_precomp_t MAIN_Z_MODULE_PRECOMP;
|
||||||
|
struct main_z_module_precomp_t {
|
||||||
|
// TODO
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef union z_module_precomp_t Z_MODULE_PRECOMP;
|
||||||
|
union z_module_precomp_t {
|
||||||
|
MAIN_Z_MODULE_PRECOMP main;
|
||||||
|
};
|
||||||
|
|
||||||
|
void main_init_z_module_precomp(MOD_Z* module);
|
||||||
|
|
||||||
|
void main_finalize_z_module_precomp(MOD_Z* module);
|
||||||
|
|
||||||
|
/** @brief opaque structure that describes the modules (RnX,ZnX,TnX) and the hardware */
|
||||||
|
struct z_module_info_t {
|
||||||
|
Z_MODULE_TYPE mtype;
|
||||||
|
Z_MODULE_VTABLE vtable;
|
||||||
|
Z_MODULE_PRECOMP precomp;
|
||||||
|
void* custom;
|
||||||
|
void (*custom_deleter)(void*);
|
||||||
|
};
|
||||||
|
|
||||||
|
void init_z_module_info(MOD_Z* module, Z_MODULE_TYPE mtype);
|
||||||
|
|
||||||
|
void main_init_z_module_vtable(MOD_Z* module);
|
||||||
|
|
||||||
|
struct tndbl_approxdecomp_gadget_t {
|
||||||
|
uint64_t k;
|
||||||
|
uint64_t ell;
|
||||||
|
double add_cst; // 3.2^51-(K.ell) + 1/2.(sum 2^-(i+1)K)
|
||||||
|
int64_t and_mask; // (2^K)-1
|
||||||
|
int64_t sub_cst; // 2^(K-1)
|
||||||
|
uint8_t rshifts[64]; // 2^(ell-1-i).K for i in [0:ell-1]
|
||||||
|
};
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int8_t* output) */
|
||||||
|
EXPORT void default_i8_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int8_t* res, uint64_t res_size, // res (in general, size ell.a_size)
|
||||||
|
const double* a, uint64_t a_size); // a
|
||||||
|
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int16_t* output) */
|
||||||
|
EXPORT void default_i16_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int16_t* res,
|
||||||
|
uint64_t res_size, // res (in general, size ell.a_size)
|
||||||
|
const double* a, uint64_t a_size); // a
|
||||||
|
/** @brief sets res = gadget_decompose(a) (int32_t* output) */
|
||||||
|
EXPORT void default_i32_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
|
||||||
|
const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
|
||||||
|
int32_t* res,
|
||||||
|
uint64_t res_size, // res (in general, size ell.a_size)
|
||||||
|
const double* a, uint64_t a_size); // a
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void default_zn32_vmp_prepare_contiguous_ref( //
|
||||||
|
const MOD_Z* module,
|
||||||
|
ZN32_VMP_PMAT* pmat, // output
|
||||||
|
const int32_t* mat, uint64_t nrows, uint64_t ncols // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||||
|
EXPORT void default_zn32_vmp_prepare_dblptr_ref( //
|
||||||
|
const MOD_Z* module,
|
||||||
|
ZN32_VMP_PMAT* pmat, // output
|
||||||
|
const int32_t** mat, uint64_t nrows, uint64_t ncols // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||||
|
EXPORT void default_zn32_vmp_prepare_row_ref( //
|
||||||
|
const MOD_Z* module,
|
||||||
|
ZN32_VMP_PMAT* pmat, // output
|
||||||
|
const int32_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int32_t* input) */
|
||||||
|
EXPORT void default_zn32_vmp_apply_i32_ref( //
|
||||||
|
const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size, // a
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int16_t* input) */
|
||||||
|
EXPORT void default_zn32_vmp_apply_i16_ref( //
|
||||||
|
const MOD_Z* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const int16_t* a, uint64_t a_size, // a
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int8_t* input) */
|
||||||
|
EXPORT void default_zn32_vmp_apply_i8_ref( //
|
||||||
|
const MOD_Z* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const int8_t* a, uint64_t a_size, // a
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int32_t* input) */
|
||||||
|
EXPORT void default_zn32_vmp_apply_i32_avx( //
|
||||||
|
const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size, // a
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int16_t* input) */
|
||||||
|
EXPORT void default_zn32_vmp_apply_i16_avx( //
|
||||||
|
const MOD_Z* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const int16_t* a, uint64_t a_size, // a
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int8_t* input) */
|
||||||
|
EXPORT void default_zn32_vmp_apply_i8_avx( //
|
||||||
|
const MOD_Z* module, // N
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const int8_t* a, uint64_t a_size, // a
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix
|
||||||
|
|
||||||
|
// explicit conversions
|
||||||
|
|
||||||
|
/** reduction mod 1, output in torus32 space */
|
||||||
|
EXPORT void dbl_to_tn32_ref(const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** real centerlift mod 1, output in double space */
|
||||||
|
EXPORT void tn32_to_dbl_ref(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** round to the nearest int, output in i32 space */
|
||||||
|
EXPORT void dbl_round_to_i32_ref(const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** small int (int32 space) to double */
|
||||||
|
EXPORT void i32_to_dbl_ref(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** round to the nearest int, output in int64 space */
|
||||||
|
EXPORT void dbl_round_to_i64_ref(const MOD_Z* module, //
|
||||||
|
int64_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
/** small int (int64 space) to double */
|
||||||
|
EXPORT void i64_to_dbl_ref(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size // a
|
||||||
|
);
|
||||||
|
|
||||||
|
#endif // SPQLIOS_ZN_ARITHMETIC_PRIVATE_H
|
||||||
@@ -0,0 +1,108 @@
|
|||||||
|
#include <memory.h>
|
||||||
|
|
||||||
|
#include "zn_arithmetic_private.h"
|
||||||
|
|
||||||
|
typedef union {
|
||||||
|
double dv;
|
||||||
|
int64_t s64v;
|
||||||
|
int32_t s32v;
|
||||||
|
uint64_t u64v;
|
||||||
|
uint32_t u32v;
|
||||||
|
} di_t;
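// Conversion trick used throughout this file (informal): adding a constant of the
// form c * 2^51 / 2^52 to a double pins its exponent, so the rounded integer value
// of the input appears directly in the low mantissa bits of the binary64 encoding,
// which is then read back through the di_t union; the XOR with 2^31 switches between
// the biased (unsigned) mantissa view and the signed int32/torus32 representation.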
/** reduction mod 1, output in torus32 space */
|
||||||
|
EXPORT void dbl_to_tn32_ref(const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
static const double ADD_CST = 0.5 + (double)(INT64_C(3) << (51 - 32));
|
||||||
|
static const int32_t XOR_CST = (INT32_C(1) << 31);
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
di_t t = {.dv = a[i] + ADD_CST};
|
||||||
|
res[i] = t.s32v ^ XOR_CST;
|
||||||
|
}
|
||||||
|
memset(res + msize, 0, (res_size - msize) * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** real centerlift mod 1, output in double space */
|
||||||
|
EXPORT void tn32_to_dbl_ref(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
static const uint32_t XOR_CST = (UINT32_C(1) << 31);
|
||||||
|
static const di_t OR_CST = {.dv = (double)(INT64_C(1) << (52 - 32))};
|
||||||
|
static const double SUB_CST = 0.5 + (double)(INT64_C(1) << (52 - 32));
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
uint32_t ai = a[i] ^ XOR_CST;
|
||||||
|
di_t t = {.u64v = OR_CST.u64v | (uint64_t)ai};
|
||||||
|
res[i] = t.dv - SUB_CST;
|
||||||
|
}
|
||||||
|
memset(res + msize, 0, (res_size - msize) * sizeof(double));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** round to the nearest int, output in i32 space */
|
||||||
|
EXPORT void dbl_round_to_i32_ref(const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
static const double ADD_CST = (double)((INT64_C(3) << (51)) + (INT64_C(1) << (31)));
|
||||||
|
static const int32_t XOR_CST = INT32_C(1) << 31;
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
di_t t = {.dv = a[i] + ADD_CST};
|
||||||
|
res[i] = t.s32v ^ XOR_CST;
|
||||||
|
}
|
||||||
|
memset(res + msize, 0, (res_size - msize) * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** small int (int32 space) to double */
|
||||||
|
EXPORT void i32_to_dbl_ref(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int32_t* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
static const uint32_t XOR_CST = (UINT32_C(1) << 31);
|
||||||
|
static const di_t OR_CST = {.dv = (double)(INT64_C(1) << 52)};
|
||||||
|
static const double SUB_CST = (double)((INT64_C(1) << 52) + (INT64_C(1) << 31));
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
uint32_t ai = a[i] ^ XOR_CST;
|
||||||
|
di_t t = {.u64v = OR_CST.u64v | (uint64_t)ai};
|
||||||
|
res[i] = t.dv - SUB_CST;
|
||||||
|
}
|
||||||
|
memset(res + msize, 0, (res_size - msize) * sizeof(double));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** round to the nearest int, output in int64 space */
|
||||||
|
EXPORT void dbl_round_to_i64_ref(const MOD_Z* module, //
|
||||||
|
int64_t* res, uint64_t res_size, // res
|
||||||
|
const double* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
static const double ADD_CST = (double)(INT64_C(3) << (51));
|
||||||
|
static const int64_t AND_CST = (INT64_C(1) << 52) - 1;
|
||||||
|
static const int64_t SUB_CST = INT64_C(1) << 51;
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
di_t t = {.dv = a[i] + ADD_CST};
|
||||||
|
res[i] = (t.s64v & AND_CST) - SUB_CST;
|
||||||
|
}
|
||||||
|
memset(res + msize, 0, (res_size - msize) * sizeof(int64_t));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** small int (int64 space) to double */
|
||||||
|
EXPORT void i64_to_dbl_ref(const MOD_Z* module, //
|
||||||
|
double* res, uint64_t res_size, // res
|
||||||
|
const int64_t* a, uint64_t a_size // a
|
||||||
|
) {
|
||||||
|
static const uint64_t ADD_CST = UINT64_C(1) << 51;
|
||||||
|
static const uint64_t AND_CST = (UINT64_C(1) << 52) - 1;
|
||||||
|
static const di_t OR_CST = {.dv = (INT64_C(1) << 52)};
|
||||||
|
static const double SUB_CST = INT64_C(3) << 51;
|
||||||
|
const uint64_t msize = res_size < a_size ? res_size : a_size;
|
||||||
|
for (uint64_t i = 0; i < msize; ++i) {
|
||||||
|
di_t t = {.u64v = ((a[i] + ADD_CST) & AND_CST) | OR_CST.u64v};
|
||||||
|
res[i] = t.dv - SUB_CST;
|
||||||
|
}
|
||||||
|
memset(res + msize, 0, (res_size - msize) * sizeof(double));
|
||||||
|
}
|
||||||
@@ -0,0 +1,4 @@
#define INTTYPE int16_t
#define INTSN i16

#include "zn_vmp_int32_avx.c"
@@ -0,0 +1,4 @@
#define INTTYPE int16_t
#define INTSN i16

#include "zn_vmp_int32_ref.c"
|
||||||
@@ -0,0 +1,223 @@
|
|||||||
|
// This file is actually a template: it will be compiled multiple times with
|
||||||
|
// different INTTYPES
|
||||||
|
#ifndef INTTYPE
|
||||||
|
#define INTTYPE int32_t
|
||||||
|
#define INTSN i32
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <immintrin.h>
|
||||||
|
#include <memory.h>
|
||||||
|
|
||||||
|
#include "zn_arithmetic_private.h"
|
||||||
|
|
||||||
|
#define concat_inner(aa, bb, cc) aa##_##bb##_##cc
|
||||||
|
#define concat(aa, bb, cc) concat_inner(aa, bb, cc)
|
||||||
|
#define zn32_vec_fn(cc) concat(zn32_vec, INTSN, cc)
|
||||||
|
|
||||||
|
static void zn32_vec_mat32cols_avx_prefetch(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b) {
|
||||||
|
if (nrows == 0) {
|
||||||
|
memset(res, 0, 32 * sizeof(int32_t));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const int32_t* bb = b;
|
||||||
|
const int32_t* pref_bb = b;
|
||||||
|
const uint64_t pref_iters = 128;
|
||||||
|
const uint64_t pref_start = pref_iters < nrows ? pref_iters : nrows;
|
||||||
|
const uint64_t pref_last = pref_iters > nrows ? 0 : nrows - pref_iters;
|
||||||
|
// let's do some prefetching of the GSW key, since on some cpus,
|
||||||
|
// it helps
|
||||||
|
for (uint64_t i = 0; i < pref_start; ++i) {
|
||||||
|
__builtin_prefetch(pref_bb, 0, _MM_HINT_T0);
|
||||||
|
__builtin_prefetch(pref_bb + 16, 0, _MM_HINT_T0);
|
||||||
|
pref_bb += 32;
|
||||||
|
}
|
||||||
|
// we do the first iteration
|
||||||
|
__m256i x = _mm256_set1_epi32(a[0]);
|
||||||
|
__m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)));
|
||||||
|
__m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)));
|
||||||
|
__m256i r2 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)));
|
||||||
|
__m256i r3 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24)));
|
||||||
|
bb += 32;
|
||||||
|
uint64_t row = 1;
|
||||||
|
for (; //
|
||||||
|
row < pref_last; //
|
||||||
|
++row, bb += 32) {
|
||||||
|
// prefetch the next iteration
|
||||||
|
__builtin_prefetch(pref_bb, 0, _MM_HINT_T0);
|
||||||
|
__builtin_prefetch(pref_bb + 16, 0, _MM_HINT_T0);
|
||||||
|
pref_bb += 32;
|
||||||
|
INTTYPE ai = a[row];
|
||||||
|
if (ai == 0) continue;
|
||||||
|
x = _mm256_set1_epi32(ai);
|
||||||
|
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||||
|
r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))));
|
||||||
|
r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))));
|
||||||
|
r3 = _mm256_add_epi32(r3, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24))));
|
||||||
|
}
|
||||||
|
for (; //
|
||||||
|
row < nrows; //
|
||||||
|
++row, bb += 32) {
|
||||||
|
INTTYPE ai = a[row];
|
||||||
|
if (ai == 0) continue;
|
||||||
|
x = _mm256_set1_epi32(ai);
|
||||||
|
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||||
|
r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))));
|
||||||
|
r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))));
|
||||||
|
r3 = _mm256_add_epi32(r3, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24))));
|
||||||
|
}
|
||||||
|
_mm256_storeu_si256((__m256i*)(res), r0);
|
||||||
|
_mm256_storeu_si256((__m256i*)(res + 8), r1);
|
||||||
|
_mm256_storeu_si256((__m256i*)(res + 16), r2);
|
||||||
|
_mm256_storeu_si256((__m256i*)(res + 24), r3);
|
||||||
|
}
|
||||||
|
|
||||||
|
void zn32_vec_fn(mat32cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
if (nrows == 0) {
|
||||||
|
memset(res, 0, 32 * sizeof(int32_t));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const INTTYPE* aa = a;
|
||||||
|
const INTTYPE* const aaend = a + nrows;
|
||||||
|
const int32_t* bb = b;
|
||||||
|
__m256i x = _mm256_set1_epi32(*aa);
|
||||||
|
__m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)));
|
||||||
|
__m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)));
|
||||||
|
__m256i r2 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)));
|
||||||
|
__m256i r3 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24)));
|
||||||
|
bb += b_sl;
|
||||||
|
++aa;
|
||||||
|
for (; //
|
||||||
|
aa < aaend; //
|
||||||
|
bb += b_sl, ++aa) {
|
||||||
|
INTTYPE ai = *aa;
|
||||||
|
if (ai == 0) continue;
|
||||||
|
x = _mm256_set1_epi32(ai);
|
||||||
|
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||||
|
r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))));
|
||||||
|
r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))));
|
||||||
|
r3 = _mm256_add_epi32(r3, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24))));
|
||||||
|
}
|
||||||
|
_mm256_storeu_si256((__m256i*)(res), r0);
|
||||||
|
_mm256_storeu_si256((__m256i*)(res + 8), r1);
|
||||||
|
_mm256_storeu_si256((__m256i*)(res + 16), r2);
|
||||||
|
_mm256_storeu_si256((__m256i*)(res + 24), r3);
|
||||||
|
}
|
||||||
|
|
||||||
|
void zn32_vec_fn(mat24cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
if (nrows == 0) {
|
||||||
|
memset(res, 0, 24 * sizeof(int32_t));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const INTTYPE* aa = a;
|
||||||
|
const INTTYPE* const aaend = a + nrows;
|
||||||
|
const int32_t* bb = b;
|
||||||
|
__m256i x = _mm256_set1_epi32(*aa);
|
||||||
|
__m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)));
|
||||||
|
__m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)));
|
||||||
|
__m256i r2 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)));
|
||||||
|
bb += b_sl;
|
||||||
|
++aa;
|
||||||
|
for (; //
|
||||||
|
aa < aaend; //
|
||||||
|
bb += b_sl, ++aa) {
|
||||||
|
INTTYPE ai = *aa;
|
||||||
|
if (ai == 0) continue;
|
||||||
|
x = _mm256_set1_epi32(ai);
|
||||||
|
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||||
|
r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))));
|
||||||
|
r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))));
|
||||||
|
}
|
||||||
|
_mm256_storeu_si256((__m256i*)(res), r0);
|
||||||
|
_mm256_storeu_si256((__m256i*)(res + 8), r1);
|
||||||
|
_mm256_storeu_si256((__m256i*)(res + 16), r2);
|
||||||
|
}
|
||||||
|
void zn32_vec_fn(mat16cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
if (nrows == 0) {
|
||||||
|
memset(res, 0, 16 * sizeof(int32_t));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const INTTYPE* aa = a;
|
||||||
|
const INTTYPE* const aaend = a + nrows;
|
||||||
|
const int32_t* bb = b;
|
||||||
|
__m256i x = _mm256_set1_epi32(*aa);
|
||||||
|
__m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)));
|
||||||
|
__m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)));
|
||||||
|
bb += b_sl;
|
||||||
|
++aa;
|
||||||
|
for (; //
|
||||||
|
aa < aaend; //
|
||||||
|
bb += b_sl, ++aa) {
|
||||||
|
INTTYPE ai = *aa;
|
||||||
|
if (ai == 0) continue;
|
||||||
|
x = _mm256_set1_epi32(ai);
|
||||||
|
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||||
|
r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))));
|
||||||
|
}
|
||||||
|
_mm256_storeu_si256((__m256i*)(res), r0);
|
||||||
|
_mm256_storeu_si256((__m256i*)(res + 8), r1);
|
||||||
|
}
|
||||||
|
|
||||||
|
void zn32_vec_fn(mat8cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
if (nrows == 0) {
|
||||||
|
memset(res, 0, 8 * sizeof(int32_t));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const INTTYPE* aa = a;
|
||||||
|
const INTTYPE* const aaend = a + nrows;
|
||||||
|
const int32_t* bb = b;
|
||||||
|
__m256i x = _mm256_set1_epi32(*aa);
|
||||||
|
__m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)));
|
||||||
|
bb += b_sl;
|
||||||
|
++aa;
|
||||||
|
for (; //
|
||||||
|
aa < aaend; //
|
||||||
|
bb += b_sl, ++aa) {
|
||||||
|
INTTYPE ai = *aa;
|
||||||
|
if (ai == 0) continue;
|
||||||
|
x = _mm256_set1_epi32(ai);
|
||||||
|
r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))));
|
||||||
|
}
|
||||||
|
_mm256_storeu_si256((__m256i*)(res), r0);
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef void (*vm_f)(uint64_t nrows, //
|
||||||
|
int32_t* res, //
|
||||||
|
const INTTYPE* a, //
|
||||||
|
const int32_t* b, uint64_t b_sl //
|
||||||
|
);
|
||||||
|
static const vm_f zn32_vec_mat8kcols_avx[4] = { //
|
||||||
|
zn32_vec_fn(mat8cols_avx), //
|
||||||
|
zn32_vec_fn(mat16cols_avx), //
|
||||||
|
zn32_vec_fn(mat24cols_avx), //
|
||||||
|
zn32_vec_fn(mat32cols_avx)};
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int32_t* input) */
|
||||||
|
EXPORT void concat(default_zn32_vmp_apply, INTSN, avx)( //
|
||||||
|
const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, //
|
||||||
|
const INTTYPE* a, uint64_t a_size, //
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||||
|
const uint64_t rows = a_size < nrows ? a_size : nrows;
|
||||||
|
const uint64_t cols = res_size < ncols ? res_size : ncols;
|
||||||
|
const uint64_t ncolblk = cols >> 5;
|
||||||
|
const uint64_t ncolrem = cols & 31;
|
||||||
|
// copy the first full blocks
|
||||||
|
const uint64_t full_blk_size = nrows * 32;
|
||||||
|
const int32_t* mat = (int32_t*)pmat;
|
||||||
|
int32_t* rr = res;
|
||||||
|
for (uint64_t blk = 0; //
|
||||||
|
blk < ncolblk; //
|
||||||
|
++blk, mat += full_blk_size, rr += 32) {
|
||||||
|
zn32_vec_mat32cols_avx_prefetch(rows, rr, a, mat);
|
||||||
|
}
|
||||||
|
// last block
|
||||||
|
if (ncolrem) {
|
||||||
|
uint64_t orig_rem = ncols - (ncolblk << 5);
|
||||||
|
uint64_t b_sl = orig_rem >= 32 ? 32 : orig_rem;
|
||||||
|
int32_t tmp[32];
|
||||||
|
zn32_vec_mat8kcols_avx[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl);
|
||||||
|
memcpy(rr, tmp, ncolrem * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
// trailing bytes
|
||||||
|
memset(res + cols, 0, (res_size - cols) * sizeof(int32_t));
|
||||||
|
}
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
// This file is actually a template: it will be compiled multiple times with
|
||||||
|
// different INTTYPES
|
||||||
|
#ifndef INTTYPE
|
||||||
|
#define INTTYPE int32_t
|
||||||
|
#define INTSN i32
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <memory.h>
|
||||||
|
|
||||||
|
#include "zn_arithmetic_private.h"
|
||||||
|
|
||||||
|
#define concat_inner(aa, bb, cc) aa##_##bb##_##cc
|
||||||
|
#define concat(aa, bb, cc) concat_inner(aa, bb, cc)
|
||||||
|
#define zn32_vec_fn(cc) concat(zn32_vec, INTSN, cc)
|
||||||
|
|
||||||
|
// the ref version shares the same implementation for each fixed column size
|
||||||
|
// optimized implementations may do something different.
|
||||||
|
static __always_inline void IMPL_zn32_vec_matcols_ref(
|
||||||
|
const uint64_t NCOLS, // fixed number of columns
|
||||||
|
uint64_t nrows, // nrows of b
|
||||||
|
int32_t* res, // result: size NCOLS, only the first min(b_sl, NCOLS) are relevant
|
||||||
|
const INTTYPE* a, // a: nrows-sized vector
|
||||||
|
const int32_t* b, uint64_t b_sl // b: nrows * min(b_sl, NCOLS) matrix
|
||||||
|
) {
|
||||||
|
memset(res, 0, NCOLS * sizeof(int32_t));
|
||||||
|
for (uint64_t row = 0; row < nrows; ++row) {
|
||||||
|
int32_t ai = a[row];
|
||||||
|
const int32_t* bb = b + row * b_sl;
|
||||||
|
for (uint64_t i = 0; i < NCOLS; ++i) {
|
||||||
|
res[i] += ai * bb[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void zn32_vec_fn(mat32cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
IMPL_zn32_vec_matcols_ref(32, nrows, res, a, b, b_sl);
|
||||||
|
}
|
||||||
|
void zn32_vec_fn(mat24cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
IMPL_zn32_vec_matcols_ref(24, nrows, res, a, b, b_sl);
|
||||||
|
}
|
||||||
|
void zn32_vec_fn(mat16cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
IMPL_zn32_vec_matcols_ref(16, nrows, res, a, b, b_sl);
|
||||||
|
}
|
||||||
|
void zn32_vec_fn(mat8cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
IMPL_zn32_vec_matcols_ref(8, nrows, res, a, b, b_sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef void (*vm_f)(uint64_t nrows, //
|
||||||
|
int32_t* res, //
|
||||||
|
const INTTYPE* a, //
|
||||||
|
const int32_t* b, uint64_t b_sl //
|
||||||
|
);
|
||||||
|
static const vm_f zn32_vec_mat8kcols_ref[4] = { //
|
||||||
|
zn32_vec_fn(mat8cols_ref), //
|
||||||
|
zn32_vec_fn(mat16cols_ref), //
|
||||||
|
zn32_vec_fn(mat24cols_ref), //
|
||||||
|
zn32_vec_fn(mat32cols_ref)};
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int32_t* input) */
|
||||||
|
EXPORT void concat(default_zn32_vmp_apply, INTSN, ref)( //
|
||||||
|
const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, //
|
||||||
|
const INTTYPE* a, uint64_t a_size, //
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||||
|
const uint64_t rows = a_size < nrows ? a_size : nrows;
|
||||||
|
const uint64_t cols = res_size < ncols ? res_size : ncols;
|
||||||
|
const uint64_t ncolblk = cols >> 5;
|
||||||
|
const uint64_t ncolrem = cols & 31;
|
||||||
|
// copy the first full blocks
|
||||||
|
const uint32_t full_blk_size = nrows * 32;
|
||||||
|
const int32_t* mat = (int32_t*)pmat;
|
||||||
|
int32_t* rr = res;
|
||||||
|
for (uint64_t blk = 0; //
|
||||||
|
blk < ncolblk; //
|
||||||
|
++blk, mat += full_blk_size, rr += 32) {
|
||||||
|
zn32_vec_fn(mat32cols_ref)(rows, rr, a, mat, 32);
|
||||||
|
}
|
||||||
|
// last block
|
||||||
|
if (ncolrem) {
|
||||||
|
uint64_t orig_rem = ncols - (ncolblk << 5);
|
||||||
|
uint64_t b_sl = orig_rem >= 32 ? 32 : orig_rem;
|
||||||
|
int32_t tmp[32];
|
||||||
|
zn32_vec_mat8kcols_ref[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl);
|
||||||
|
memcpy(rr, tmp, ncolrem * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
// trailing bytes
|
||||||
|
memset(res + cols, 0, (res_size - cols) * sizeof(int32_t));
|
||||||
|
}
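// A minimal sketch (assumption: not part of the library) of the column
// dispatch used by default_zn32_vmp_apply_*: columns are consumed in blocks
// of 32, and the 1..31 trailing columns are routed to the narrowest
// 8/16/24/32-column kernel that covers them, whose output is copied out of a
// 32-entry temporary.
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
  const uint64_t ncols = 83;            // hypothetical matrix width
  const uint64_t ncolblk = ncols >> 5;  // 2 full 32-column blocks
  const uint64_t ncolrem = ncols & 31;  // 19 trailing columns
  const uint64_t kernel_cols = (((ncolrem - 1) >> 3) + 1) * 8;  // -> 24-column kernel
  printf("blocks=%" PRIu64 " rem=%" PRIu64 " kernel=%" PRIu64 " cols\n", ncolblk, ncolrem, kernel_cols);
  return 0;
}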
|
||||||
@@ -0,0 +1,4 @@
#define INTTYPE int8_t
#define INTSN i8

#include "zn_vmp_int32_avx.c"
@@ -0,0 +1,4 @@
#define INTTYPE int8_t
#define INTSN i8

#include "zn_vmp_int32_ref.c"
|
||||||
@@ -0,0 +1,185 @@
|
|||||||
|
#include <memory.h>
|
||||||
|
|
||||||
|
#include "zn_arithmetic_private.h"
|
||||||
|
|
||||||
|
/** @brief size in bytes of a prepared matrix (for custom allocation) */
|
||||||
|
EXPORT uint64_t bytes_of_zn32_vmp_pmat(const MOD_Z* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols // dimensions
|
||||||
|
) {
|
||||||
|
return (nrows * ncols + 7) * sizeof(int32_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief allocates a prepared matrix (release with delete_zn32_vmp_pmat) */
|
||||||
|
EXPORT ZN32_VMP_PMAT* new_zn32_vmp_pmat(const MOD_Z* module, // N
|
||||||
|
uint64_t nrows, uint64_t ncols) {
|
||||||
|
return (ZN32_VMP_PMAT*)spqlios_alloc(bytes_of_zn32_vmp_pmat(module, nrows, ncols));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief deletes a prepared matrix (allocated with new_zn32_vmp_pmat) */
EXPORT void delete_zn32_vmp_pmat(ZN32_VMP_PMAT* ptr) { spqlios_free(ptr); }
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (contiguous row-major version) */
|
||||||
|
EXPORT void default_zn32_vmp_prepare_contiguous_ref( //
|
||||||
|
const MOD_Z* module,
|
||||||
|
ZN32_VMP_PMAT* pmat, // output
|
||||||
|
const int32_t* mat, uint64_t nrows, uint64_t ncols // a
|
||||||
|
) {
|
||||||
|
int32_t* const out = (int32_t*)pmat;
|
||||||
|
const uint64_t nblk = ncols >> 5;
|
||||||
|
const uint64_t ncols_rem = ncols & 31;
|
||||||
|
const uint64_t final_elems = (8 - nrows * ncols) & 7;
|
||||||
|
for (uint64_t blk = 0; blk < nblk; ++blk) {
|
||||||
|
int32_t* outblk = out + blk * nrows * 32;
|
||||||
|
const int32_t* srcblk = mat + blk * 32;
|
||||||
|
for (uint64_t row = 0; row < nrows; ++row) {
|
||||||
|
int32_t* dest = outblk + row * 32;
|
||||||
|
const int32_t* src = srcblk + row * ncols;
|
||||||
|
for (uint64_t i = 0; i < 32; ++i) {
|
||||||
|
dest[i] = src[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// copy the last block if any
|
||||||
|
if (ncols_rem) {
|
||||||
|
int32_t* outblk = out + nblk * nrows * 32;
|
||||||
|
const int32_t* srcblk = mat + nblk * 32;
|
||||||
|
for (uint64_t row = 0; row < nrows; ++row) {
|
||||||
|
int32_t* dest = outblk + row * ncols_rem;
|
||||||
|
const int32_t* src = srcblk + row * ncols;
|
||||||
|
for (uint64_t i = 0; i < ncols_rem; ++i) {
|
||||||
|
dest[i] = src[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// zero-out the final elements that may be accessed
|
||||||
|
if (final_elems) {
|
||||||
|
int32_t* f = out + nrows * ncols;
|
||||||
|
for (uint64_t i = 0; i < final_elems; ++i) {
|
||||||
|
f[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
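// A minimal sketch (assumption: the helper below is hypothetical, not part of
// the library) of the prepared-matrix layout produced above: columns are
// grouped in blocks of 32, each block stores its rows contiguously, and the
// final partial block uses ncols % 32 as its row stride.
#include <stdint.h>

static uint64_t pmat_index(uint64_t nrows, uint64_t ncols, uint64_t row, uint64_t col) {
  const uint64_t blk = col >> 5;
  const uint64_t nblk = ncols >> 5;                          // number of full 32-column blocks
  const uint64_t stride = (blk < nblk) ? 32 : (ncols & 31);  // partial last block
  return blk * nrows * 32 + row * stride + (col & 31);       // offset in int32 elements
}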
|
||||||
|
|
||||||
|
/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
|
||||||
|
EXPORT void default_zn32_vmp_prepare_dblptr_ref( //
|
||||||
|
const MOD_Z* module,
|
||||||
|
ZN32_VMP_PMAT* pmat, // output
|
||||||
|
const int32_t** mat, uint64_t nrows, uint64_t ncols // a
|
||||||
|
) {
|
||||||
|
for (uint64_t row_i = 0; row_i < nrows; ++row_i) {
|
||||||
|
default_zn32_vmp_prepare_row_ref(module, pmat, mat[row_i], row_i, nrows, ncols);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
|
||||||
|
EXPORT void default_zn32_vmp_prepare_row_ref( //
|
||||||
|
const MOD_Z* module,
|
||||||
|
ZN32_VMP_PMAT* pmat, // output
|
||||||
|
const int32_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols // a
|
||||||
|
) {
|
||||||
|
int32_t* const out = (int32_t*)pmat;
|
||||||
|
const uint64_t nblk = ncols >> 5;
|
||||||
|
const uint64_t ncols_rem = ncols & 31;
|
||||||
|
const uint64_t final_elems = (row_i == nrows - 1) && (8 - nrows * ncols) & 7;
|
||||||
|
for (uint64_t blk = 0; blk < nblk; ++blk) {
|
||||||
|
int32_t* outblk = out + blk * nrows * 32;
|
||||||
|
int32_t* dest = outblk + row_i * 32;
|
||||||
|
const int32_t* src = row + blk * 32;
|
||||||
|
for (uint64_t i = 0; i < 32; ++i) {
|
||||||
|
dest[i] = src[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// copy the last block if any
|
||||||
|
if (ncols_rem) {
|
||||||
|
int32_t* outblk = out + nblk * nrows * 32;
|
||||||
|
int32_t* dest = outblk + row_i * ncols_rem;
|
||||||
|
const int32_t* src = row + nblk * 32;
|
||||||
|
for (uint64_t i = 0; i < ncols_rem; ++i) {
|
||||||
|
dest[i] = src[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// zero-out the final elements that may be accessed
|
||||||
|
if (final_elems) {
|
||||||
|
int32_t* f = out + nrows * ncols;
|
||||||
|
for (uint64_t i = 0; i < final_elems; ++i) {
|
||||||
|
f[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#if 0
|
||||||
|
|
||||||
|
#define IMPL_zn32_vec_ixxx_matyyycols_ref(NCOLS) \
|
||||||
|
memset(res, 0, NCOLS * sizeof(int32_t)); \
|
||||||
|
for (uint64_t row = 0; row < nrows; ++row) { \
|
||||||
|
int32_t ai = a[row]; \
|
||||||
|
const int32_t* bb = b + row * b_sl; \
|
||||||
|
for (uint64_t i = 0; i < NCOLS; ++i) { \
|
||||||
|
res[i] += ai * bb[i]; \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define IMPL_zn32_vec_ixxx_mat8cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(8)
|
||||||
|
#define IMPL_zn32_vec_ixxx_mat16cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(16)
|
||||||
|
#define IMPL_zn32_vec_ixxx_mat24cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(24)
|
||||||
|
#define IMPL_zn32_vec_ixxx_mat32cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(32)
|
||||||
|
|
||||||
|
void zn32_vec_i8_mat32cols_ref(uint64_t nrows, int32_t* res, const int8_t* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
IMPL_zn32_vec_ixxx_mat32cols_ref()
|
||||||
|
}
|
||||||
|
void zn32_vec_i16_mat32cols_ref(uint64_t nrows, int32_t* res, const int16_t* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
IMPL_zn32_vec_ixxx_mat32cols_ref()
|
||||||
|
}
|
||||||
|
|
||||||
|
void zn32_vec_i32_mat32cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
IMPL_zn32_vec_ixxx_mat32cols_ref()
|
||||||
|
}
|
||||||
|
void zn32_vec_i32_mat24cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
IMPL_zn32_vec_ixxx_mat24cols_ref()
|
||||||
|
}
|
||||||
|
void zn32_vec_i32_mat16cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
IMPL_zn32_vec_ixxx_mat16cols_ref()
|
||||||
|
}
|
||||||
|
void zn32_vec_i32_mat8cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) {
|
||||||
|
IMPL_zn32_vec_ixxx_mat8cols_ref()
|
||||||
|
}
|
||||||
|
typedef void (*zn32_vec_i32_mat8kcols_ref_f)(uint64_t nrows, //
|
||||||
|
int32_t* res, //
|
||||||
|
const int32_t* a, //
|
||||||
|
const int32_t* b, uint64_t b_sl //
|
||||||
|
);
|
||||||
|
zn32_vec_i32_mat8kcols_ref_f zn32_vec_i32_mat8kcols_ref[4] = { //
|
||||||
|
zn32_vec_i32_mat8cols_ref, zn32_vec_i32_mat16cols_ref, //
|
||||||
|
zn32_vec_i32_mat24cols_ref, zn32_vec_i32_mat32cols_ref};
|
||||||
|
|
||||||
|
/** @brief applies a vmp product (int32_t* input) */
|
||||||
|
EXPORT void default_zn32_vmp_apply_i32_ref(const MOD_Z* module, //
|
||||||
|
int32_t* res, uint64_t res_size, //
|
||||||
|
const int32_t* a, uint64_t a_size, //
|
||||||
|
const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
|
||||||
|
const uint64_t rows = a_size < nrows ? a_size : nrows;
|
||||||
|
const uint64_t cols = res_size < ncols ? res_size : ncols;
|
||||||
|
const uint64_t ncolblk = cols >> 5;
|
||||||
|
const uint64_t ncolrem = cols & 31;
|
||||||
|
// copy the first full blocks
|
||||||
|
const uint32_t full_blk_size = nrows * 32;
|
||||||
|
const int32_t* mat = (int32_t*)pmat;
|
||||||
|
int32_t* rr = res;
|
||||||
|
for (uint64_t blk = 0; //
|
||||||
|
blk < ncolblk; //
|
||||||
|
++blk, mat += full_blk_size, rr += 32) {
|
||||||
|
zn32_vec_i32_mat32cols_ref(rows, rr, a, mat, 32);
|
||||||
|
}
|
||||||
|
// last block
|
||||||
|
if (ncolrem) {
|
||||||
|
uint64_t orig_rem = ncols - (ncolblk << 5);
|
||||||
|
uint64_t b_sl = orig_rem >= 32 ? 32 : orig_rem;
|
||||||
|
int32_t tmp[32];
|
||||||
|
zn32_vec_i32_mat8kcols_ref[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl);
|
||||||
|
memcpy(rr, tmp, ncolrem * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
// trailing bytes
|
||||||
|
memset(res + cols, 0, (res_size - cols) * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
#include "vec_znx_arithmetic_private.h"
|
||||||
|
|
||||||
|
/** @brief res = a * b : small integer polynomial product */
|
||||||
|
EXPORT void fft64_znx_small_single_product(const MODULE* module, // N
|
||||||
|
int64_t* res, // output
|
||||||
|
const int64_t* a, // a
|
||||||
|
const int64_t* b, // b
|
||||||
|
uint8_t* tmp) {
|
||||||
|
const uint64_t nn = module->nn;
|
||||||
|
double* const ffta = (double*)tmp;
|
||||||
|
double* const fftb = ((double*)tmp) + nn;
|
||||||
|
reim_from_znx64(module->mod.fft64.p_conv, ffta, a);
|
||||||
|
reim_from_znx64(module->mod.fft64.p_conv, fftb, b);
|
||||||
|
reim_fft(module->mod.fft64.p_fft, ffta);
|
||||||
|
reim_fft(module->mod.fft64.p_fft, fftb);
|
||||||
|
reim_fftvec_mul_simple(module->m, ffta, ffta, fftb);
|
||||||
|
reim_ifft(module->mod.fft64.p_ifft, ffta);
|
||||||
|
reim_to_znx64(module->mod.fft64.p_reim_to_znx, res, ffta);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief tmp bytes required for znx_small_single_product */
|
||||||
|
EXPORT uint64_t fft64_znx_small_single_product_tmp_bytes(const MODULE* module, uint64_t nn) {
|
||||||
|
return 2 * nn * sizeof(double);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief res = a * b : small integer polynomial product */
|
||||||
|
EXPORT void znx_small_single_product(const MODULE* module, // N
|
||||||
|
int64_t* res, // output
|
||||||
|
const int64_t* a, // a
|
||||||
|
const int64_t* b, // b
|
||||||
|
uint8_t* tmp) {
|
||||||
|
module->func.znx_small_single_product(module, res, a, b, tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief tmp bytes required for znx_small_single_product */
|
||||||
|
EXPORT uint64_t znx_small_single_product_tmp_bytes(const MODULE* module, uint64_t nn) {
|
||||||
|
return module->func.znx_small_single_product_tmp_bytes(module, nn);
|
||||||
|
}
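// A hypothetical usage sketch (the surrounding module setup, header include
// and the demo_small_product name are assumptions): the caller provides a
// scratch buffer of znx_small_single_product_tmp_bytes() bytes, which the
// fft64 backend splits into the two double[nn] FFT buffers used above.
#include <stdint.h>
#include <stdlib.h>

static void demo_small_product(const MODULE* module, uint64_t nn,  //
                               int64_t* res, const int64_t* a, const int64_t* b) {
  uint8_t* tmp = (uint8_t*)malloc(znx_small_single_product_tmp_bytes(module, nn));
  znx_small_single_product(module, res, a, b, tmp);
  free(tmp);
}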
|
||||||
@@ -0,0 +1,524 @@
|
|||||||
|
#include "coeffs_arithmetic.h"
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include <memory.h>
|
||||||
|
|
||||||
|
/** res = a + b */
|
||||||
|
EXPORT void znx_add_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
res[i] = a[i] + b[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/** res = a - b */
|
||||||
|
EXPORT void znx_sub_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
res[i] = a[i] - b[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void znx_negate_i64_ref(uint64_t nn, int64_t* res, const int64_t* a) {
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
res[i] = -a[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EXPORT void znx_copy_i64_ref(uint64_t nn, int64_t* res, const int64_t* a) { memcpy(res, a, nn * sizeof(int64_t)); }
|
||||||
|
|
||||||
|
EXPORT void znx_zero_i64_ref(uint64_t nn, int64_t* res) { memset(res, 0, nn * sizeof(int64_t)); }
|
||||||
|
|
||||||
|
EXPORT void rnx_divide_by_m_ref(uint64_t n, double m, double* res, const double* a) {
|
||||||
|
const double invm = 1. / m;
|
||||||
|
for (uint64_t i = 0; i < n; ++i) {
|
||||||
|
res[i] = a[i] * invm;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void rnx_rotate_f64(uint64_t nn, int64_t p, double* res, const double* in) {
|
||||||
|
uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn)
|
||||||
|
|
||||||
|
if (a < nn) { // rotate to the left
|
||||||
|
uint64_t nma = nn - a;
|
||||||
|
// rotate first half
|
||||||
|
for (uint64_t j = 0; j < nma; j++) {
|
||||||
|
res[j] = in[j + a];
|
||||||
|
}
|
||||||
|
for (uint64_t j = nma; j < nn; j++) {
|
||||||
|
res[j] = -in[j - nma];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
a -= nn;
|
||||||
|
uint64_t nma = nn - a;
|
||||||
|
for (uint64_t j = 0; j < nma; j++) {
|
||||||
|
res[j] = -in[j + a];
|
||||||
|
}
|
||||||
|
for (uint64_t j = nma; j < nn; j++) {
|
||||||
|
// rotate first half
|
||||||
|
res[j] = in[j - nma];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void znx_rotate_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) {
|
||||||
|
uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn)
|
||||||
|
|
||||||
|
if (a < nn) { // rotate to the left
|
||||||
|
uint64_t nma = nn - a;
|
||||||
|
// rotate first half
|
||||||
|
for (uint64_t j = 0; j < nma; j++) {
|
||||||
|
res[j] = in[j + a];
|
||||||
|
}
|
||||||
|
for (uint64_t j = nma; j < nn; j++) {
|
||||||
|
res[j] = -in[j - nma];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
a -= nn;
|
||||||
|
uint64_t nma = nn - a;
|
||||||
|
for (uint64_t j = 0; j < nma; j++) {
|
||||||
|
res[j] = -in[j + a];
|
||||||
|
}
|
||||||
|
for (uint64_t j = nma; j < nn; j++) {
|
||||||
|
// rotate first half
|
||||||
|
res[j] = in[j - nma];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
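// A small standalone check (assumption: not part of the test suite) of the
// negacyclic rotation semantics res = X^p * in mod (X^nn + 1): for nn = 4 and
// p = 1, the coefficient pushed past degree 3 wraps around with a sign flip.
#include <stdint.h>
#include <stdio.h>

#include "coeffs_arithmetic.h"

int main(void) {
  int64_t in[4] = {1, 2, 3, 4};
  int64_t res[4];
  znx_rotate_i64(4, 1, res, in);
  // expected output: -4 1 2 3
  printf("%lld %lld %lld %lld\n", (long long)res[0], (long long)res[1], (long long)res[2], (long long)res[3]);
  return 0;
}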
|
||||||
|
|
||||||
|
EXPORT void rnx_mul_xp_minus_one_f64(uint64_t nn, int64_t p, double* res, const double* in) {
|
||||||
|
uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn)
|
||||||
|
if (a < nn) { // rotate to the left
|
||||||
|
uint64_t nma = nn - a;
|
||||||
|
// rotate first half
|
||||||
|
for (uint64_t j = 0; j < nma; j++) {
|
||||||
|
res[j] = in[j + a] - in[j];
|
||||||
|
}
|
||||||
|
for (uint64_t j = nma; j < nn; j++) {
|
||||||
|
res[j] = -in[j - nma] - in[j];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
a -= nn;
|
||||||
|
uint64_t nma = nn - a;
|
||||||
|
for (uint64_t j = 0; j < nma; j++) {
|
||||||
|
res[j] = -in[j + a] - in[j];
|
||||||
|
}
|
||||||
|
for (uint64_t j = nma; j < nn; j++) {
|
||||||
|
// rotate first half
|
||||||
|
res[j] = in[j - nma] - in[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void znx_mul_xp_minus_one_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) {
|
||||||
|
uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn)
|
||||||
|
if (a < nn) { // rotate to the left
|
||||||
|
uint64_t nma = nn - a;
|
||||||
|
// rotate first half
|
||||||
|
for (uint64_t j = 0; j < nma; j++) {
|
||||||
|
res[j] = in[j + a] - in[j];
|
||||||
|
}
|
||||||
|
for (uint64_t j = nma; j < nn; j++) {
|
||||||
|
res[j] = -in[j - nma] - in[j];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
a -= nn;
|
||||||
|
uint64_t nma = nn - a;
|
||||||
|
for (uint64_t j = 0; j < nma; j++) {
|
||||||
|
res[j] = -in[j + a] - in[j];
|
||||||
|
}
|
||||||
|
for (uint64_t j = nma; j < nn; j++) {
|
||||||
|
// rotate first half
|
||||||
|
res[j] = in[j - nma] - in[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
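// A small standalone check (assumption: not part of the test suite) of the
// (X^p - 1) multiplication above: it is the rotated vector minus the input,
// so for nn = 4, p = 1 and in = 1 + 2X + 3X^2 + 4X^3 the result is
// -5 - X - X^2 - X^3.
#include <stdint.h>
#include <stdio.h>

#include "coeffs_arithmetic.h"

int main(void) {
  int64_t in[4] = {1, 2, 3, 4};
  int64_t res[4];
  znx_mul_xp_minus_one_i64(4, 1, res, in);
  // expected output: -5 -1 -1 -1
  printf("%lld %lld %lld %lld\n", (long long)res[0], (long long)res[1], (long long)res[2], (long long)res[3]);
  return 0;
}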
|
||||||
|
|
||||||
|
EXPORT void znx_mul_xp_minus_one_inplace_i64(uint64_t nn, int64_t p, int64_t* res) {
|
||||||
|
const uint64_t _2mn = 2 * nn - 1;
|
||||||
|
const uint64_t _mn = nn - 1;
|
||||||
|
uint64_t nb_modif = 0;
|
||||||
|
uint64_t j_start = 0;
|
||||||
|
while (nb_modif < nn) {
|
||||||
|
// follow the cycle that start with j_start
|
||||||
|
uint64_t j = j_start;
|
||||||
|
int64_t tmp1 = res[j];
|
||||||
|
do {
|
||||||
|
// find where the value should go, and with which sign
|
||||||
|
uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign
|
||||||
|
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||||
|
// exchange this position with tmp1 (and take care of the sign)
|
||||||
|
int64_t tmp2 = res[new_j_n];
|
||||||
|
res[new_j_n] = ((new_j < nn) ? tmp1 : -tmp1) - res[new_j_n];
|
||||||
|
tmp1 = tmp2;
|
||||||
|
// move to the new location, and store the number of items modified
|
||||||
|
++nb_modif;
|
||||||
|
j = new_j_n;
|
||||||
|
} while (j != j_start);
|
||||||
|
// move to the start of the next cycle:
|
||||||
|
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||||
|
// in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator.
|
||||||
|
++j_start;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 0 < p < 2nn
|
||||||
|
EXPORT void rnx_automorphism_f64(uint64_t nn, int64_t p, double* res, const double* in) {
|
||||||
|
res[0] = in[0];
|
||||||
|
uint64_t a = 0;
|
||||||
|
uint64_t _2mn = 2 * nn - 1;
|
||||||
|
for (uint64_t i = 1; i < nn; i++) {
|
||||||
|
a = (a + p) & _2mn; // i*p mod 2n
|
||||||
|
if (a < nn) {
|
||||||
|
res[a] = in[i]; // res[ip mod 2n] = res[i]
|
||||||
|
} else {
|
||||||
|
res[a - nn] = -in[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void znx_automorphism_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) {
|
||||||
|
res[0] = in[0];
|
||||||
|
uint64_t a = 0;
|
||||||
|
uint64_t _2mn = 2 * nn - 1;
|
||||||
|
for (uint64_t i = 1; i < nn; i++) {
|
||||||
|
a = (a + p) & _2mn;
|
||||||
|
if (a < nn) {
|
||||||
|
res[a] = in[i]; // res[ip mod 2n] = res[i]
|
||||||
|
} else {
|
||||||
|
res[a - nn] = -in[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
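// A small standalone check (assumption: not part of the test suite) of the
// automorphism res(X) = in(X^p): for nn = 4, p = 3 and
// in = 1 + 2X + 3X^2 + 4X^3, the images of X, X^2, X^3 are X^3, -X^2, X
// respectively, so the result is 1 + 4X - 3X^2 + 2X^3.
#include <stdint.h>
#include <stdio.h>

#include "coeffs_arithmetic.h"

int main(void) {
  int64_t in[4] = {1, 2, 3, 4};
  int64_t res[4];
  znx_automorphism_i64(4, 3, res, in);
  // expected output: 1 4 -3 2
  printf("%lld %lld %lld %lld\n", (long long)res[0], (long long)res[1], (long long)res[2], (long long)res[3]);
  return 0;
}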
|
||||||
|
|
||||||
|
EXPORT void rnx_rotate_inplace_f64(uint64_t nn, int64_t p, double* res) {
|
||||||
|
const uint64_t _2mn = 2 * nn - 1;
|
||||||
|
const uint64_t _mn = nn - 1;
|
||||||
|
uint64_t nb_modif = 0;
|
||||||
|
uint64_t j_start = 0;
|
||||||
|
while (nb_modif < nn) {
|
||||||
|
// follow the cycle that start with j_start
|
||||||
|
uint64_t j = j_start;
|
||||||
|
double tmp1 = res[j];
|
||||||
|
do {
|
||||||
|
// find where the value should go, and with which sign
|
||||||
|
uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign
|
||||||
|
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||||
|
// exchange this position with tmp1 (and take care of the sign)
|
||||||
|
double tmp2 = res[new_j_n];
|
||||||
|
res[new_j_n] = (new_j < nn) ? tmp1 : -tmp1;
|
||||||
|
tmp1 = tmp2;
|
||||||
|
// move to the new location, and store the number of items modified
|
||||||
|
++nb_modif;
|
||||||
|
j = new_j_n;
|
||||||
|
} while (j != j_start);
|
||||||
|
// move to the start of the next cycle:
|
||||||
|
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||||
|
// in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator.
|
||||||
|
++j_start;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void znx_rotate_inplace_i64(uint64_t nn, int64_t p, int64_t* res) {
|
||||||
|
const uint64_t _2mn = 2 * nn - 1;
|
||||||
|
const uint64_t _mn = nn - 1;
|
||||||
|
uint64_t nb_modif = 0;
|
||||||
|
uint64_t j_start = 0;
|
||||||
|
while (nb_modif < nn) {
|
||||||
|
// follow the cycle that start with j_start
|
||||||
|
uint64_t j = j_start;
|
||||||
|
int64_t tmp1 = res[j];
|
||||||
|
do {
|
||||||
|
// find where the value should go, and with which sign
|
||||||
|
uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign
|
||||||
|
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||||
|
// exchange this position with tmp1 (and take care of the sign)
|
||||||
|
int64_t tmp2 = res[new_j_n];
|
||||||
|
res[new_j_n] = (new_j < nn) ? tmp1 : -tmp1;
|
||||||
|
tmp1 = tmp2;
|
||||||
|
// move to the new location, and store the number of items modified
|
||||||
|
++nb_modif;
|
||||||
|
j = new_j_n;
|
||||||
|
} while (j != j_start);
|
||||||
|
// move to the start of the next cycle:
|
||||||
|
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||||
|
// in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator.
|
||||||
|
++j_start;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void rnx_mul_xp_minus_one_inplace_f64(uint64_t nn, int64_t p, double* res) {
|
||||||
|
const uint64_t _2mn = 2 * nn - 1;
|
||||||
|
const uint64_t _mn = nn - 1;
|
||||||
|
uint64_t nb_modif = 0;
|
||||||
|
uint64_t j_start = 0;
|
||||||
|
while (nb_modif < nn) {
|
||||||
|
// follow the cycle that start with j_start
|
||||||
|
uint64_t j = j_start;
|
||||||
|
double tmp1 = res[j];
|
||||||
|
do {
|
||||||
|
// find where the value should go, and with which sign
|
||||||
|
uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign
|
||||||
|
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||||
|
// exchange this position with tmp1 (and take care of the sign)
|
||||||
|
double tmp2 = res[new_j_n];
|
||||||
|
res[new_j_n] = ((new_j < nn) ? tmp1 : -tmp1) - res[new_j_n];
|
||||||
|
tmp1 = tmp2;
|
||||||
|
// move to the new location, and store the number of items modified
|
||||||
|
++nb_modif;
|
||||||
|
j = new_j_n;
|
||||||
|
} while (j != j_start);
|
||||||
|
// move to the start of the next cycle:
|
||||||
|
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||||
|
// in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator.
|
||||||
|
++j_start;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** sign-extends the low base_k bits of x: the centered digit in [ -2^(base_k-1), 2^(base_k-1) [ */
__always_inline int64_t get_base_k_digit(const int64_t x, const uint64_t base_k) {
  return (x << (64 - base_k)) >> (64 - base_k);
}

/** remaining high part: the carry such that x == digit + (carry << base_k) */
__always_inline int64_t get_base_k_carry(const int64_t x, const int64_t digit, const uint64_t base_k) {
  return (x - digit) >> base_k;
}
|
||||||
|
|
||||||
|
EXPORT void znx_normalize(uint64_t nn, uint64_t base_k, int64_t* out, int64_t* carry_out, const int64_t* in,
|
||||||
|
const int64_t* carry_in) {
|
||||||
|
assert(in);
|
||||||
|
if (out != 0) {
|
||||||
|
if (carry_in != 0x0 && carry_out != 0x0) {
|
||||||
|
// with carry in and carry out is computed
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
const int64_t x = in[i];
|
||||||
|
const int64_t cin = carry_in[i];
|
||||||
|
|
||||||
|
int64_t digit = get_base_k_digit(x, base_k);
|
||||||
|
int64_t carry = get_base_k_carry(x, digit, base_k);
|
||||||
|
int64_t digit_plus_cin = digit + cin;
|
||||||
|
int64_t y = get_base_k_digit(digit_plus_cin, base_k);
|
||||||
|
int64_t cout = carry + get_base_k_carry(digit_plus_cin, y, base_k);
|
||||||
|
|
||||||
|
out[i] = y;
|
||||||
|
carry_out[i] = cout;
|
||||||
|
}
|
||||||
|
} else if (carry_in != 0) {
|
||||||
|
// with carry in and carry out is dropped
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
const int64_t x = in[i];
|
||||||
|
const int64_t cin = carry_in[i];
|
||||||
|
|
||||||
|
int64_t digit = get_base_k_digit(x, base_k);
|
||||||
|
int64_t digit_plus_cin = digit + cin;
|
||||||
|
int64_t y = get_base_k_digit(digit_plus_cin, base_k);
|
||||||
|
|
||||||
|
out[i] = y;
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if (carry_out != 0) {
|
||||||
|
// no carry in and carry out is computed
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
const int64_t x = in[i];
|
||||||
|
|
||||||
|
int64_t y = get_base_k_digit(x, base_k);
|
||||||
|
int64_t cout = get_base_k_carry(x, y, base_k);
|
||||||
|
|
||||||
|
out[i] = y;
|
||||||
|
carry_out[i] = cout;
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// no carry in and carry out is dropped
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
out[i] = get_base_k_digit(in[i], base_k);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert(carry_out);
|
||||||
|
if (carry_in != 0x0) {
|
||||||
|
// with carry in and carry out is computed
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
const int64_t x = in[i];
|
||||||
|
const int64_t cin = carry_in[i];
|
||||||
|
|
||||||
|
int64_t digit = get_base_k_digit(x, base_k);
|
||||||
|
int64_t carry = get_base_k_carry(x, digit, base_k);
|
||||||
|
int64_t digit_plus_cin = digit + cin;
|
||||||
|
int64_t y = get_base_k_digit(digit_plus_cin, base_k);
|
||||||
|
int64_t cout = carry + get_base_k_carry(digit_plus_cin, y, base_k);
|
||||||
|
|
||||||
|
carry_out[i] = cout;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// no carry in and carry out is computed
|
||||||
|
for (uint64_t i = 0; i < nn; ++i) {
|
||||||
|
const int64_t x = in[i];
|
||||||
|
|
||||||
|
int64_t y = get_base_k_digit(x, base_k);
|
||||||
|
int64_t cout = get_base_k_carry(x, y, base_k);
|
||||||
|
|
||||||
|
carry_out[i] = cout;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
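// A small standalone check (assumption: not part of the test suite) of the
// invariant behind znx_normalize: every coefficient splits as
// x == digit + carry * 2^base_k with the digit centered in
// [ -2^(base_k-1), 2^(base_k-1) [.
#include <assert.h>
#include <stdint.h>

int main(void) {
  const uint64_t base_k = 12;
  const int64_t x = -1234567;
  const int64_t digit = (x << (64 - base_k)) >> (64 - base_k);  // same expression as get_base_k_digit
  const int64_t carry = (x - digit) >> base_k;                  // same expression as get_base_k_carry
  assert(x == digit + (carry << base_k));
  assert(digit >= -(INT64_C(1) << (base_k - 1)) && digit < (INT64_C(1) << (base_k - 1)));
  return 0;
}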
|
||||||
|
|
||||||
|
void znx_automorphism_inplace_i64(uint64_t nn, int64_t p, int64_t* res) {
|
||||||
|
const uint64_t _2mn = 2 * nn - 1;
|
||||||
|
const uint64_t _mn = nn - 1;
|
||||||
|
const uint64_t m = nn >> 1;
|
||||||
|
// reduce p mod 2n
|
||||||
|
p &= _2mn;
|
||||||
|
// uint64_t vp = p & _2mn;
|
||||||
|
/// uint64_t target_modifs = m >> 1;
|
||||||
|
// we proceed by increasing binary valuation
|
||||||
|
for (uint64_t binval = 1, vp = p & _2mn, orb_size = m; binval < nn;
|
||||||
|
binval <<= 1, vp = (vp << 1) & _2mn, orb_size >>= 1) {
|
||||||
|
// In this loop, we are going to treat the orbit of indexes = binval mod 2.binval.
|
||||||
|
// At the beginning of this loop we have:
|
||||||
|
// vp = binval * p mod 2n
|
||||||
|
// target_modif = m / binval (i.e. order of the orbit binval % 2.binval)
|
||||||
|
|
||||||
|
// first, handle the orders 1 and 2.
|
||||||
|
// if p*binval == binval % 2n: we're done!
|
||||||
|
if (vp == binval) return;
|
||||||
|
// if p*binval == -binval % 2n: nega-mirror the orbit and all the sub-orbits and exit!
|
||||||
|
if (((vp + binval) & _2mn) == 0) {
|
||||||
|
for (uint64_t j = binval; j < m; j += binval) {
|
||||||
|
int64_t tmp = res[j];
|
||||||
|
res[j] = -res[nn - j];
|
||||||
|
res[nn - j] = -tmp;
|
||||||
|
}
|
||||||
|
res[m] = -res[m];
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// if p*binval == binval + n % 2n: negate the orbit and exit
|
||||||
|
if (((vp - binval) & _mn) == 0) {
|
||||||
|
for (uint64_t j = binval; j < nn; j += 2 * binval) {
|
||||||
|
res[j] = -res[j];
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// if p*binval == n - binval % 2n: mirror the orbit and continue!
|
||||||
|
if (((vp + binval) & _mn) == 0) {
|
||||||
|
for (uint64_t j = binval; j < m; j += 2 * binval) {
|
||||||
|
int64_t tmp = res[j];
|
||||||
|
res[j] = res[nn - j];
|
||||||
|
res[nn - j] = tmp;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// otherwise we will follow the orbit cycles,
|
||||||
|
// starting from binval and -binval in parallel
|
||||||
|
uint64_t j_start = binval;
|
||||||
|
uint64_t nb_modif = 0;
|
||||||
|
while (nb_modif < orb_size) {
|
||||||
|
// follow the cycle that start with j_start
|
||||||
|
uint64_t j = j_start;
|
||||||
|
int64_t tmp1 = res[j];
|
||||||
|
int64_t tmp2 = res[nn - j];
|
||||||
|
do {
|
||||||
|
// find where the value should go, and with which sign
|
||||||
|
uint64_t new_j = (j * p) & _2mn; // mod 2n to get the position and sign
|
||||||
|
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||||
|
// exchange this position with tmp1 (and take care of the sign)
|
||||||
|
int64_t tmp1a = res[new_j_n];
|
||||||
|
int64_t tmp2a = res[nn - new_j_n];
|
||||||
|
if (new_j < nn) {
|
||||||
|
res[new_j_n] = tmp1;
|
||||||
|
res[nn - new_j_n] = tmp2;
|
||||||
|
} else {
|
||||||
|
res[new_j_n] = -tmp1;
|
||||||
|
res[nn - new_j_n] = -tmp2;
|
||||||
|
}
|
||||||
|
tmp1 = tmp1a;
|
||||||
|
tmp2 = tmp2a;
|
||||||
|
// move to the new location, and store the number of items modified
|
||||||
|
nb_modif += 2;
|
||||||
|
j = new_j_n;
|
||||||
|
} while (j != j_start);
|
||||||
|
// move to the start of the next cycle:
|
||||||
|
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||||
|
// in practice, it is enough to do *5, because 5 is a generator.
|
||||||
|
j_start = (5 * j_start) & _mn;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void rnx_automorphism_inplace_f64(uint64_t nn, int64_t p, double* res) {
|
||||||
|
const uint64_t _2mn = 2 * nn - 1;
|
||||||
|
const uint64_t _mn = nn - 1;
|
||||||
|
const uint64_t m = nn >> 1;
|
||||||
|
// reduce p mod 2n
|
||||||
|
p &= _2mn;
|
||||||
|
// uint64_t vp = p & _2mn;
|
||||||
|
/// uint64_t target_modifs = m >> 1;
|
||||||
|
// we proceed by increasing binary valuation
|
||||||
|
for (uint64_t binval = 1, vp = p & _2mn, orb_size = m; binval < nn;
|
||||||
|
binval <<= 1, vp = (vp << 1) & _2mn, orb_size >>= 1) {
|
||||||
|
// In this loop, we are going to treat the orbit of indexes = binval mod 2.binval.
|
||||||
|
// At the beginning of this loop we have:
|
||||||
|
// vp = binval * p mod 2n
|
||||||
|
// target_modif = m / binval (i.e. order of the orbit binval % 2.binval)
|
||||||
|
|
||||||
|
// first, handle the orders 1 and 2.
|
||||||
|
// if p*binval == binval % 2n: we're done!
|
||||||
|
if (vp == binval) return;
|
||||||
|
// if p*binval == -binval % 2n: nega-mirror the orbit and all the sub-orbits and exit!
|
||||||
|
if (((vp + binval) & _2mn) == 0) {
|
||||||
|
for (uint64_t j = binval; j < m; j += binval) {
|
||||||
|
double tmp = res[j];
|
||||||
|
res[j] = -res[nn - j];
|
||||||
|
res[nn - j] = -tmp;
|
||||||
|
}
|
||||||
|
res[m] = -res[m];
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// if p*binval == binval + n % 2n: negate the orbit and exit
|
||||||
|
if (((vp - binval) & _mn) == 0) {
|
||||||
|
for (uint64_t j = binval; j < nn; j += 2 * binval) {
|
||||||
|
res[j] = -res[j];
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// if p*binval == n - binval % 2n: mirror the orbit and continue!
|
||||||
|
if (((vp + binval) & _mn) == 0) {
|
||||||
|
for (uint64_t j = binval; j < m; j += 2 * binval) {
|
||||||
|
double tmp = res[j];
|
||||||
|
res[j] = res[nn - j];
|
||||||
|
res[nn - j] = tmp;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// otherwise we will follow the orbit cycles,
|
||||||
|
// starting from binval and -binval in parallel
|
||||||
|
uint64_t j_start = binval;
|
||||||
|
uint64_t nb_modif = 0;
|
||||||
|
while (nb_modif < orb_size) {
|
||||||
|
// follow the cycle that start with j_start
|
||||||
|
uint64_t j = j_start;
|
||||||
|
double tmp1 = res[j];
|
||||||
|
double tmp2 = res[nn - j];
|
||||||
|
do {
|
||||||
|
// find where the value should go, and with which sign
|
||||||
|
uint64_t new_j = (j * p) & _2mn; // mod 2n to get the position and sign
|
||||||
|
uint64_t new_j_n = new_j & _mn; // mod n to get just the position
|
||||||
|
// exchange this position with tmp1 (and take care of the sign)
|
||||||
|
double tmp1a = res[new_j_n];
|
||||||
|
double tmp2a = res[nn - new_j_n];
|
||||||
|
if (new_j < nn) {
|
||||||
|
res[new_j_n] = tmp1;
|
||||||
|
res[nn - new_j_n] = tmp2;
|
||||||
|
} else {
|
||||||
|
res[new_j_n] = -tmp1;
|
||||||
|
res[nn - new_j_n] = -tmp2;
|
||||||
|
}
|
||||||
|
tmp1 = tmp1a;
|
||||||
|
tmp2 = tmp2a;
|
||||||
|
// move to the new location, and store the number of items modified
|
||||||
|
nb_modif += 2;
|
||||||
|
j = new_j_n;
|
||||||
|
} while (j != j_start);
|
||||||
|
// move to the start of the next cycle:
|
||||||
|
// we need to find an index that has not been touched yet, and pick it as next j_start.
|
||||||
|
// in practice, it is enough to do *5, because 5 is a generator.
|
||||||
|
j_start = (5 * j_start) & _mn;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,79 @@
|
|||||||
|
#ifndef SPQLIOS_COEFFS_ARITHMETIC_H
|
||||||
|
#define SPQLIOS_COEFFS_ARITHMETIC_H
|
||||||
|
|
||||||
|
#include "../commons.h"
|
||||||
|
|
||||||
|
/** res = a + b */
|
||||||
|
EXPORT void znx_add_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
|
||||||
|
EXPORT void znx_add_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
|
||||||
|
/** res = a - b */
|
||||||
|
EXPORT void znx_sub_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
|
||||||
|
EXPORT void znx_sub_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
|
||||||
|
/** res = -a */
|
||||||
|
EXPORT void znx_negate_i64_ref(uint64_t nn, int64_t* res, const int64_t* a);
|
||||||
|
EXPORT void znx_negate_i64_avx(uint64_t nn, int64_t* res, const int64_t* a);
|
||||||
|
/** res = a */
|
||||||
|
EXPORT void znx_copy_i64_ref(uint64_t nn, int64_t* res, const int64_t* a);
|
||||||
|
/** res = 0 */
|
||||||
|
EXPORT void znx_zero_i64_ref(uint64_t nn, int64_t* res);
|
||||||
|
|
||||||
|
/** res = a / m where m is a power of 2 */
|
||||||
|
EXPORT void rnx_divide_by_m_ref(uint64_t nn, double m, double* res, const double* a);
|
||||||
|
EXPORT void rnx_divide_by_m_avx(uint64_t nn, double m, double* res, const double* a);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param res = X^p * in mod X^nn + 1
|
||||||
|
* @param nn the ring dimension
|
||||||
|
* @param p a power for the rotation -2nn <= p <= 2nn
|
||||||
|
* @param in is a rnx/znx vector of dimension nn
|
||||||
|
*/
|
||||||
|
EXPORT void rnx_rotate_f64(uint64_t nn, int64_t p, double* res, const double* in);
|
||||||
|
EXPORT void znx_rotate_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in);
|
||||||
|
EXPORT void rnx_rotate_inplace_f64(uint64_t nn, int64_t p, double* res);
|
||||||
|
EXPORT void znx_rotate_inplace_i64(uint64_t nn, int64_t p, int64_t* res);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief res(X) = in(X^p)
|
||||||
|
* @param nn the ring dimension
|
||||||
|
* @param p an odd integer with 0 < p < 2nn
|
||||||
|
* @param in is a rnx/znx vector of dimension nn
|
||||||
|
*/
|
||||||
|
EXPORT void rnx_automorphism_f64(uint64_t nn, int64_t p, double* res, const double* in);
|
||||||
|
EXPORT void znx_automorphism_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in);
|
||||||
|
EXPORT void rnx_automorphism_inplace_f64(uint64_t nn, int64_t p, double* res);
|
||||||
|
EXPORT void znx_automorphism_inplace_i64(uint64_t nn, int64_t p, int64_t* res);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief res = (X^p-1).in
|
||||||
|
* @param nn the ring dimension
|
||||||
|
* @param p must be between -2nn <= p <= 2nn
|
||||||
|
* @param in is a rnx/znx vector of dimension nn
|
||||||
|
*/
|
||||||
|
EXPORT void rnx_mul_xp_minus_one_f64(uint64_t nn, int64_t p, double* res, const double* in);
|
||||||
|
EXPORT void znx_mul_xp_minus_one_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in);
|
||||||
|
EXPORT void rnx_mul_xp_minus_one_inplace_f64(uint64_t nn, int64_t p, double* res);
|
||||||
|
EXPORT void znx_mul_xp_minus_one_inplace_i64(uint64_t nn, int64_t p, int64_t* res);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Normalize input plus carry mod-2^k. The following
|
||||||
|
* equality holds @c {in + carry_in == out + carry_out . 2^k}.
|
||||||
|
*
|
||||||
|
* @c in must be in [-2^62 .. 2^62]
|
||||||
|
*
|
||||||
|
* @c out is in [ -2^(base_k-1), 2^(base_k-1) [.
|
||||||
|
*
|
||||||
|
* @c carry_in and @c carry_out have at most 64+1-k bits.
|
||||||
|
*
|
||||||
|
* Null @c carry_in or @c carry_out are ignored.
|
||||||
|
*
|
||||||
|
* @param[in] nn the ring dimension
|
||||||
|
* @param[in] base_k the base k
|
||||||
|
* @param out output normalized znx
|
||||||
|
* @param carry_out output carry znx
|
||||||
|
* @param[in] in input znx
|
||||||
|
* @param[in] carry_in input carry znx
|
||||||
|
*/
|
||||||
|
EXPORT void znx_normalize(uint64_t nn, uint64_t base_k, int64_t* out, int64_t* carry_out, const int64_t* in,
|
||||||
|
const int64_t* carry_in);
|
||||||
|
|
||||||
|
#endif // SPQLIOS_COEFFS_ARITHMETIC_H
|
||||||
@@ -0,0 +1,124 @@
|
|||||||
|
#include <immintrin.h>
|
||||||
|
|
||||||
|
#include "../commons_private.h"
|
||||||
|
#include "coeffs_arithmetic.h"
|
||||||
|
|
||||||
|
// res = a + b. dimension n must be a power of 2
|
||||||
|
EXPORT void znx_add_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
|
||||||
|
if (nn <= 2) {
|
||||||
|
if (nn == 1) {
|
||||||
|
res[0] = a[0] + b[0];
|
||||||
|
} else {
|
||||||
|
_mm_storeu_si128((__m128i*)res, //
|
||||||
|
_mm_add_epi64( //
|
||||||
|
_mm_loadu_si128((__m128i*)a), //
|
||||||
|
_mm_loadu_si128((__m128i*)b)));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const __m256i* aa = (__m256i*)a;
|
||||||
|
const __m256i* bb = (__m256i*)b;
|
||||||
|
__m256i* rr = (__m256i*)res;
|
||||||
|
__m256i* const rrend = (__m256i*)(res + nn);
|
||||||
|
do {
|
||||||
|
_mm256_storeu_si256(rr, //
|
||||||
|
_mm256_add_epi64( //
|
||||||
|
_mm256_loadu_si256(aa), //
|
||||||
|
_mm256_loadu_si256(bb)));
|
||||||
|
++rr;
|
||||||
|
++aa;
|
||||||
|
++bb;
|
||||||
|
} while (rr < rrend);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// res = a - b. dimension n must be a power of 2
|
||||||
|
EXPORT void znx_sub_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
|
||||||
|
if (nn <= 2) {
|
||||||
|
if (nn == 1) {
|
||||||
|
res[0] = a[0] - b[0];
|
||||||
|
} else {
|
||||||
|
_mm_storeu_si128((__m128i*)res, //
|
||||||
|
_mm_sub_epi64( //
|
||||||
|
_mm_loadu_si128((__m128i*)a), //
|
||||||
|
_mm_loadu_si128((__m128i*)b)));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const __m256i* aa = (__m256i*)a;
|
||||||
|
const __m256i* bb = (__m256i*)b;
|
||||||
|
__m256i* rr = (__m256i*)res;
|
||||||
|
__m256i* const rrend = (__m256i*)(res + nn);
|
||||||
|
do {
|
||||||
|
_mm256_storeu_si256(rr, //
|
||||||
|
_mm256_sub_epi64( //
|
||||||
|
_mm256_loadu_si256(aa), //
|
||||||
|
_mm256_loadu_si256(bb)));
|
||||||
|
++rr;
|
||||||
|
++aa;
|
||||||
|
++bb;
|
||||||
|
} while (rr < rrend);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void znx_negate_i64_avx(uint64_t nn, int64_t* res, const int64_t* a) {
|
||||||
|
if (nn <= 2) {
|
||||||
|
if (nn == 1) {
|
||||||
|
res[0] = -a[0];
|
||||||
|
} else {
|
||||||
|
_mm_storeu_si128((__m128i*)res, //
|
||||||
|
_mm_sub_epi64( //
|
||||||
|
_mm_set1_epi64x(0), //
|
||||||
|
_mm_loadu_si128((__m128i*)a)));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const __m256i* aa = (__m256i*)a;
|
||||||
|
__m256i* rr = (__m256i*)res;
|
||||||
|
__m256i* const rrend = (__m256i*)(res + nn);
|
||||||
|
do {
|
||||||
|
_mm256_storeu_si256(rr, //
|
||||||
|
_mm256_sub_epi64( //
|
||||||
|
_mm256_set1_epi64x(0), //
|
||||||
|
_mm256_loadu_si256(aa)));
|
||||||
|
++rr;
|
||||||
|
++aa;
|
||||||
|
} while (rr < rrend);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT void rnx_divide_by_m_avx(uint64_t n, double m, double* res, const double* a) {
|
||||||
|
// TODO: see if there is a faster way of dividing by a power of 2?
|
||||||
|
const double invm = 1. / m;
|
||||||
|
if (n < 8) {
|
||||||
|
switch (n) {
|
||||||
|
case 1:
|
||||||
|
*res = *a * invm;
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
_mm_storeu_pd(res, //
|
||||||
|
_mm_mul_pd(_mm_loadu_pd(a), //
|
||||||
|
_mm_set1_pd(invm)));
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
_mm256_storeu_pd(res, //
|
||||||
|
_mm256_mul_pd(_mm256_loadu_pd(a), //
|
||||||
|
_mm256_set1_pd(invm)));
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
NOT_SUPPORTED(); // non-power of 2
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const __m256d invm256 = _mm256_set1_pd(invm);
|
||||||
|
double* rr = res;
|
||||||
|
const double* aa = a;
|
||||||
|
const double* const aaend = a + n;
|
||||||
|
do {
|
||||||
|
_mm256_storeu_pd(rr, //
|
||||||
|
_mm256_mul_pd(_mm256_loadu_pd(aa), //
|
||||||
|
invm256));
|
||||||
|
_mm256_storeu_pd(rr + 4, //
|
||||||
|
_mm256_mul_pd(_mm256_loadu_pd(aa + 4), //
|
||||||
|
invm256));
|
||||||
|
rr += 8;
|
||||||
|
aa += 8;
|
||||||
|
} while (aa < aaend);
|
||||||
|
}
|
||||||
@@ -0,0 +1,165 @@
#include "commons.h"

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

EXPORT void* UNDEFINED_p_ii(int32_t n, int32_t m) { UNDEFINED(); }
EXPORT void* UNDEFINED_p_uu(uint32_t n, uint32_t m) { UNDEFINED(); }
EXPORT double* UNDEFINED_dp_pi(const void* p, int32_t n) { UNDEFINED(); }
EXPORT void* UNDEFINED_vp_pi(const void* p, int32_t n) { UNDEFINED(); }
EXPORT void* UNDEFINED_vp_pu(const void* p, uint32_t n) { UNDEFINED(); }
EXPORT void UNDEFINED_v_vpdp(const void* p, double* a) { UNDEFINED(); }
EXPORT void UNDEFINED_v_vpvp(const void* p, void* a) { UNDEFINED(); }
EXPORT double* NOT_IMPLEMENTED_dp_i(int32_t n) { NOT_IMPLEMENTED(); }
EXPORT void* NOT_IMPLEMENTED_vp_i(int32_t n) { NOT_IMPLEMENTED(); }
EXPORT void* NOT_IMPLEMENTED_vp_u(uint32_t n) { NOT_IMPLEMENTED(); }
EXPORT void NOT_IMPLEMENTED_v_dp(double* a) { NOT_IMPLEMENTED(); }
EXPORT void NOT_IMPLEMENTED_v_vp(void* p) { NOT_IMPLEMENTED(); }
EXPORT void NOT_IMPLEMENTED_v_idpdpdp(int32_t n, double* a, const double* b, const double* c) { NOT_IMPLEMENTED(); }
EXPORT void NOT_IMPLEMENTED_v_uvpcvpcvp(uint32_t n, void* r, const void* a, const void* b) { NOT_IMPLEMENTED(); }
EXPORT void NOT_IMPLEMENTED_v_uvpvpcvp(uint32_t n, void* a, void* b, const void* o) { NOT_IMPLEMENTED(); }

#ifdef _WIN32
#define __always_inline inline __attribute((always_inline))
#endif

void internal_accurate_sincos(double* rcos, double* rsin, double x) {
  double _4_x_over_pi = 4 * x / M_PI;
  int64_t int_part = ((int64_t)rint(_4_x_over_pi)) & 7;
  double frac_part = _4_x_over_pi - (double)(int_part);
  double frac_x = M_PI * frac_part / 4.;
  // compute the taylor series
  double cosp = 1.;
  double sinp = 0.;
  double powx = 1.;
  int64_t nn = 0;
  while (fabs(powx) > 1e-20) {
    ++nn;
    powx = powx * frac_x / (double)(nn);  // x^n/n!
    switch (nn & 3) {
      case 0:
        cosp += powx;
        break;
      case 1:
        sinp += powx;
        break;
      case 2:
        cosp -= powx;
        break;
      case 3:
        sinp -= powx;
        break;
      default:
        abort();  // impossible
    }
  }
  // final multiplication
  switch (int_part) {
    case 0:
      *rcos = cosp;
      *rsin = sinp;
      break;
    case 1:
      *rcos = M_SQRT1_2 * (cosp - sinp);
      *rsin = M_SQRT1_2 * (cosp + sinp);
      break;
    case 2:
      *rcos = -sinp;
      *rsin = cosp;
      break;
    case 3:
      *rcos = -M_SQRT1_2 * (cosp + sinp);
      *rsin = M_SQRT1_2 * (cosp - sinp);
      break;
    case 4:
      *rcos = -cosp;
      *rsin = -sinp;
      break;
    case 5:
      *rcos = -M_SQRT1_2 * (cosp - sinp);
      *rsin = -M_SQRT1_2 * (cosp + sinp);
      break;
    case 6:
      *rcos = sinp;
      *rsin = -cosp;
      break;
    case 7:
      *rcos = M_SQRT1_2 * (cosp + sinp);
      *rsin = -M_SQRT1_2 * (cosp - sinp);
      break;
    default:
      abort();  // impossible
  }
  if (fabs(cos(x) - *rcos) > 1e-10 || fabs(sin(x) - *rsin) > 1e-10) {
    printf("cos(%.17lf) =? %.17lf instead of %.17lf\n", x, *rcos, cos(x));
    printf("sin(%.17lf) =? %.17lf instead of %.17lf\n", x, *rsin, sin(x));
    printf("fracx = %.17lf\n", frac_x);
    printf("cosp = %.17lf\n", cosp);
    printf("sinp = %.17lf\n", sinp);
    printf("nn = %d\n", (int)(nn));
  }
}

double internal_accurate_cos(double x) {
  double rcos, rsin;
  internal_accurate_sincos(&rcos, &rsin, x);
  return rcos;
}
double internal_accurate_sin(double x) {
  double rcos, rsin;
  internal_accurate_sincos(&rcos, &rsin, x);
  return rsin;
}

EXPORT void spqlios_debug_free(void* addr) { free((uint8_t*)addr - 64); }

EXPORT void* spqlios_debug_alloc(uint64_t size) { return (uint8_t*)malloc(size + 64) + 64; }

EXPORT void spqlios_free(void* addr) {
#ifndef NDEBUG
  // in debug mode, we deallocate with spqlios_debug_free()
  spqlios_debug_free(addr);
#else
  // in release mode, the function will free aligned memory
#ifdef _WIN32
  _aligned_free(addr);
#else
  free(addr);
#endif
#endif
}

EXPORT void* spqlios_alloc(uint64_t size) {
#ifndef NDEBUG
  // in debug mode, the function will not necessarily have any particular alignment
  // it will also ensure that memory can only be deallocated with spqlios_free()
  return spqlios_debug_alloc(size);
#else
  // in release mode, the function will return 64-bytes aligned memory
#ifdef _WIN32
  void* reps = _aligned_malloc((size + 63) & (UINT64_C(-64)), 64);
#else
  void* reps = aligned_alloc(64, (size + 63) & (UINT64_C(-64)));
#endif
  if (reps == 0) FATAL_ERROR("Out of memory");
  return reps;
#endif
}

EXPORT void* spqlios_alloc_custom_align(uint64_t align, uint64_t size) {
#ifndef NDEBUG
  // in debug mode, the function will not necessarily have any particular alignment
  // it will also ensure that memory can only be deallocated with spqlios_free()
  return spqlios_debug_alloc(size);
#else
  // in release mode, the function will return aligned memory
#ifdef _WIN32
  void* reps = _aligned_malloc(size, align);
#else
  void* reps = aligned_alloc(align, size);
#endif
  if (reps == 0) FATAL_ERROR("Out of memory");
  return reps;
#endif
}

@@ -0,0 +1,77 @@
#ifndef SPQLIOS_COMMONS_H
#define SPQLIOS_COMMONS_H

#ifdef __cplusplus
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#define EXPORT extern "C"
#define EXPORT_DECL extern "C"
#else
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#define EXPORT
#define EXPORT_DECL extern
#define nullptr 0x0;
#endif

#define UNDEFINED()                    \
  {                                    \
    fprintf(stderr, "UNDEFINED!!!\n"); \
    abort();                           \
  }
#define NOT_IMPLEMENTED()                    \
  {                                          \
    fprintf(stderr, "NOT IMPLEMENTED!!!\n"); \
    abort();                                 \
  }
#define FATAL_ERROR(MESSAGE)                   \
  {                                            \
    fprintf(stderr, "ERROR: %s\n", (MESSAGE)); \
    abort();                                   \
  }

EXPORT void* UNDEFINED_p_ii(int32_t n, int32_t m);
EXPORT void* UNDEFINED_p_uu(uint32_t n, uint32_t m);
EXPORT double* UNDEFINED_dp_pi(const void* p, int32_t n);
EXPORT void* UNDEFINED_vp_pi(const void* p, int32_t n);
EXPORT void* UNDEFINED_vp_pu(const void* p, uint32_t n);
EXPORT void UNDEFINED_v_vpdp(const void* p, double* a);
EXPORT void UNDEFINED_v_vpvp(const void* p, void* a);
EXPORT double* NOT_IMPLEMENTED_dp_i(int32_t n);
EXPORT void* NOT_IMPLEMENTED_vp_i(int32_t n);
EXPORT void* NOT_IMPLEMENTED_vp_u(uint32_t n);
EXPORT void NOT_IMPLEMENTED_v_dp(double* a);
EXPORT void NOT_IMPLEMENTED_v_vp(void* p);
EXPORT void NOT_IMPLEMENTED_v_idpdpdp(int32_t n, double* a, const double* b, const double* c);
EXPORT void NOT_IMPLEMENTED_v_uvpcvpcvp(uint32_t n, void* r, const void* a, const void* b);
EXPORT void NOT_IMPLEMENTED_v_uvpvpcvp(uint32_t n, void* a, void* b, const void* o);

// windows / apple: provide __always_inline

#if defined(_WIN32) || defined(__APPLE__)
#define __always_inline inline __attribute((always_inline))
#endif

EXPORT void spqlios_free(void* address);

EXPORT void* spqlios_alloc(uint64_t size);
EXPORT void* spqlios_alloc_custom_align(uint64_t align, uint64_t size);

#define USE_LIBM_SIN_COS
#ifndef USE_LIBM_SIN_COS
// if at some point, we want to remove the libm dependency, we can
// consider this:
EXPORT double internal_accurate_cos(double x);
EXPORT double internal_accurate_sin(double x);
EXPORT void internal_accurate_sincos(double* rcos, double* rsin, double x);
#define m_accurate_cos internal_accurate_cos
#define m_accurate_sin internal_accurate_sin
#else
// let's use libm sin and cos
#define m_accurate_cos cos
#define m_accurate_sin sin
#endif

#endif  // SPQLIOS_COMMONS_H

@@ -0,0 +1,55 @@
#include "commons_private.h"

#include <stdio.h>
#include <stdlib.h>

#include "commons.h"

EXPORT void* spqlios_error(const char* error) {
  fputs(error, stderr);
  abort();
  return nullptr;
}
EXPORT void* spqlios_keep_or_free(void* ptr, void* ptr2) {
  if (!ptr2) {
    free(ptr);
  }
  return ptr2;
}

EXPORT uint32_t log2m(uint32_t m) {
  uint32_t a = m - 1;
  if (m & a) FATAL_ERROR("m must be a power of two");
  // popcount of m-1 (which has log2(m) low bits set) via parallel bit summation
  a = (a & 0x55555555u) + ((a >> 1) & 0x55555555u);
  a = (a & 0x33333333u) + ((a >> 2) & 0x33333333u);
  a = (a & 0x0F0F0F0Fu) + ((a >> 4) & 0x0F0F0F0Fu);
  a = (a & 0x00FF00FFu) + ((a >> 8) & 0x00FF00FFu);
  return (a & 0x0000FFFFu) + ((a >> 16) & 0x0000FFFFu);
}

EXPORT uint64_t is_not_pow2_double(void* doublevalue) { return (*(uint64_t*)doublevalue) & 0x7FFFFFFFFFFFFUL; }

uint32_t revbits(uint32_t nbits, uint32_t value) {
  uint32_t res = 0;
  for (uint32_t i = 0; i < nbits; ++i) {
    res = (res << 1) + (value & 1);
    value >>= 1;
  }
  return res;
}

/**
 * @brief this computes the sequence: 0,1/2,1/4,3/4,1/8,5/8,3/8,7/8,...
 * essentially: the bits of i in lsb order on the basis (1/2^k) mod 1*/
double fracrevbits(uint32_t i) {
  if (i == 0) return 0;
  if (i == 1) return 0.5;
  if (i % 2 == 0)
    return fracrevbits(i / 2) / 2.;
  else
    return fracrevbits((i - 1) / 2) / 2. + 0.5;
}

uint64_t ceilto64b(uint64_t size) { return (size + UINT64_C(63)) & (UINT64_C(-64)); }

uint64_t ceilto32b(uint64_t size) { return (size + UINT64_C(31)) & (UINT64_C(-32)); }

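As a quick reference for the two bit-reversal helpers above, here is a standalone sanity check we added for illustration only (it is not part of the commit; it assumes it is compiled against commons_private.h and linked with the file above):

#include <assert.h>
#include "commons_private.h"

int main(void) {
  // revbits reverses the lowest nbits bits: 0b110 -> 0b011
  assert(revbits(3, 6) == 3);
  // fracrevbits enumerates 0, 1/2, 1/4, 3/4, 1/8, 5/8, 3/8, 7/8, ...
  assert(fracrevbits(2) == 0.25);
  assert(fracrevbits(3) == 0.75);
  assert(fracrevbits(5) == 0.625);
  return 0;
}
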
@@ -0,0 +1,72 @@
#ifndef SPQLIOS_COMMONS_PRIVATE_H
#define SPQLIOS_COMMONS_PRIVATE_H

#include "commons.h"

#ifdef __cplusplus
#include <cmath>
#include <cstdio>
#include <cstdlib>
#else
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#define nullptr 0x0;
#endif

/** @brief log2 of a power of two (UB if m is not a power of two) */
EXPORT uint32_t log2m(uint32_t m);

/** @brief checks if the doublevalue is a power of two */
EXPORT uint64_t is_not_pow2_double(void* doublevalue);

#define UNDEFINED()                    \
  {                                    \
    fprintf(stderr, "UNDEFINED!!!\n"); \
    abort();                           \
  }
#define NOT_IMPLEMENTED()                    \
  {                                          \
    fprintf(stderr, "NOT IMPLEMENTED!!!\n"); \
    abort();                                 \
  }
#define NOT_SUPPORTED()                    \
  {                                        \
    fprintf(stderr, "NOT SUPPORTED!!!\n"); \
    abort();                               \
  }
#define FATAL_ERROR(MESSAGE)                   \
  {                                            \
    fprintf(stderr, "ERROR: %s\n", (MESSAGE)); \
    abort();                                   \
  }

#define STATIC_ASSERT(condition) (void)sizeof(char[-1 + 2 * !!(condition)])

/** @brief reports the error and returns nullptr */
EXPORT void* spqlios_error(const char* error);
/** @brief if ptr2 is not null, returns ptr, otherwise free ptr and return null */
EXPORT void* spqlios_keep_or_free(void* ptr, void* ptr2);

#ifdef __x86_64__
#define CPU_SUPPORTS __builtin_cpu_supports
#else
// TODO for now, we do not have any optimization for non x86 targets
#define CPU_SUPPORTS(xxxx) 0
#endif

/** @brief returns the n bits of value in reversed order */
EXPORT uint32_t revbits(uint32_t nbits, uint32_t value);

/**
 * @brief this computes the sequence: 0,1/2,1/4,3/4,1/8,5/8,3/8,7/8,...
 * essentially: the bits of i in lsb order on the basis (1/2^k) mod 1*/
EXPORT double fracrevbits(uint32_t i);

/** @brief smallest multiple of 64 higher or equal to size */
EXPORT uint64_t ceilto64b(uint64_t size);

/** @brief smallest multiple of 32 higher or equal to size */
EXPORT uint64_t ceilto32b(uint64_t size);

#endif  // SPQLIOS_COMMONS_PRIVATE_H

@@ -0,0 +1,22 @@
In this folder, we deal with the full complex FFT in `C[X] mod X^M-i`.
One complex is represented by two consecutive doubles `(real,imag)`.
Note that a real polynomial sum_{j=0}^{N-1} p_j.X^j mod X^N+1
corresponds to the complex polynomial of half degree `M=N/2`:
`sum_{j=0}^{M-1} (p_{j} + i.p_{j+M}) X^j mod X^M-i`

For a complex polynomial A(X) = sum c_i X^i of degree M-1,
or a real polynomial sum a_i X^i of degree N-1:

coefficient space:
a_0,a_M,a_1,a_{M+1},...,a_{M-1},a_{2M-1}
or equivalently
Re(c_0),Im(c_0),Re(c_1),Im(c_1),...,Re(c_{M-1}),Im(c_{M-1})

eval space:
c(omega_{0}),...,c(omega_{M-1})

where
omega_j = omega^{1+rev_{2N}(j)}
and omega = exp(i.pi/N)

rev_{2N}(j) is the number that has the log2(2N) bits of j in reverse order.

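For illustration only (this helper is ours, not part of the library), the coefficient-space packing described above is a plain interleaving of the two halves of the real polynomial:

#include <stdint.h>

/* Pack a real polynomial p[0..N-1] mod X^N+1 (N = 2M) into the interleaved
 * complex layout Re(c_0),Im(c_0),...,Re(c_{M-1}),Im(c_{M-1}), where
 * c_j = p_j + i.p_{j+M}. */
static void pack_real_to_cplx(uint32_t N, const double* p, double* c) {
  uint32_t M = N / 2;
  for (uint32_t j = 0; j < M; ++j) {
    c[2 * j] = p[j];          // real part: p_j
    c[2 * j + 1] = p[j + M];  // imaginary part: p_{j+M}
  }
}
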
@@ -0,0 +1,80 @@
#include "cplx_fft_internal.h"

void cplx_set(CPLX r, const CPLX a) {
  r[0] = a[0];
  r[1] = a[1];
}
void cplx_neg(CPLX r, const CPLX a) {
  r[0] = -a[0];
  r[1] = -a[1];
}
void cplx_add(CPLX r, const CPLX a, const CPLX b) {
  r[0] = a[0] + b[0];
  r[1] = a[1] + b[1];
}
void cplx_sub(CPLX r, const CPLX a, const CPLX b) {
  r[0] = a[0] - b[0];
  r[1] = a[1] - b[1];
}
void cplx_mul(CPLX r, const CPLX a, const CPLX b) {
  double re = a[0] * b[0] - a[1] * b[1];
  r[1] = a[0] * b[1] + a[1] * b[0];
  r[0] = re;
}

/**
 * @brief splits 2h evaluations of one polynomial into 2 times h evaluations of its even/odd polynomials
 * Input: Q_0(y),...,Q_{h-1}(y),Q_0(-y),...,Q_{h-1}(-y)
 * Output: P_0(z),...,P_{h-1}(z),P_h(z),...,P_{2h-1}(z)
 * where Q_i(X)=P_i(X^2)+X.P_{h+i}(X^2) and y^2 = z
 * @param h number of "coefficients" h >= 1
 * @param data 2h complex coefficients interleaved and 256b aligned
 * @param powom y represented as (yre,yim)
 */
EXPORT void cplx_split_fft_ref(int32_t h, CPLX* data, const CPLX powom) {
  CPLX* d0 = data;
  CPLX* d1 = data + h;
  for (uint64_t i = 0; i < h; ++i) {
    CPLX diff;
    cplx_sub(diff, d0[i], d1[i]);
    cplx_add(d0[i], d0[i], d1[i]);
    cplx_mul(d1[i], diff, powom);
  }
}

/**
 * @brief Do two layers of itwiddle (i.e. split).
 * Input/output: d0,d1,d2,d3 of length h
 * Algo:
 * itwiddle(d0,d1,om[0]),itwiddle(d2,d3,i.om[0])
 * itwiddle(d0,d2,om[1]),itwiddle(d1,d3,om[1])
 * @param h number of "coefficients" h >= 1
 * @param data 4h complex coefficients interleaved and 256b aligned
 * @param powom om[0] (re,im) and om[1] where om[1]=om[0]^2
 */
EXPORT void cplx_bisplit_fft_ref(int32_t h, CPLX* data, const CPLX powom[2]) {
  CPLX* d0 = data;
  CPLX* d2 = data + 2 * h;
  const CPLX* om0 = powom;
  CPLX iom0;
  iom0[0] = powom[0][1];
  iom0[1] = -powom[0][0];
  const CPLX* om1 = powom + 1;
  cplx_split_fft_ref(h, d0, *om0);
  cplx_split_fft_ref(h, d2, iom0);
  cplx_split_fft_ref(2 * h, d0, *om1);
}

/**
 * Input: Q(y),Q(-y)
 * Output: P_0(z),P_1(z)
 * where Q(X)=P_0(X^2)+X.P_1(X^2) and y^2 = z
 * @param data 2 complex coefficients interleaved and 256b aligned
 * @param powom (z,-z) interleaved: (zre,zim,-zre,-zim)
 */
void split_fft_last_ref(CPLX* data, const CPLX powom) {
  CPLX diff;
  cplx_sub(diff, data[0], data[1]);
  cplx_add(data[0], data[0], data[1]);
  cplx_mul(data[1], diff, powom);
}

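A minimal usage sketch for the CPLX helpers above (our example; it assumes CPLX is the double[2] type declared in cplx_fft_internal.h, as the code suggests):

// (1+2i) * (3+4i) = -5 + 10i with the [re, im] layout used above.
CPLX a = {1.0, 2.0}, b = {3.0, 4.0}, r;
cplx_mul(r, a, b);
// now r[0] == -5.0 and r[1] == 10.0
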
@@ -0,0 +1,158 @@
#include <errno.h>
#include <string.h>

#include "../commons_private.h"
#include "cplx_fft_internal.h"
#include "cplx_fft_private.h"

EXPORT void cplx_from_znx32_ref(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) {
  const uint32_t m = precomp->m;
  const int32_t* inre = x;
  const int32_t* inim = x + m;
  CPLX* out = r;
  for (uint32_t i = 0; i < m; ++i) {
    out[i][0] = (double)inre[i];
    out[i][1] = (double)inim[i];
  }
}

EXPORT void cplx_from_tnx32_ref(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) {
  static const double _2p32 = 1. / (INT64_C(1) << 32);
  const uint32_t m = precomp->m;
  const int32_t* inre = x;
  const int32_t* inim = x + m;
  CPLX* out = r;
  for (uint32_t i = 0; i < m; ++i) {
    out[i][0] = ((double)inre[i]) * _2p32;
    out[i][1] = ((double)inim[i]) * _2p32;
  }
}

EXPORT void cplx_to_tnx32_ref(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* r, const void* x) {
  static const double _2p32 = (INT64_C(1) << 32);
  const uint32_t m = precomp->m;
  double factor = _2p32 / precomp->divisor;
  int32_t* outre = r;
  int32_t* outim = r + m;
  const CPLX* in = x;
  // Note: this formula will only work if abs(in) < 2^32
  for (uint32_t i = 0; i < m; ++i) {
    outre[i] = (int32_t)(int64_t)(rint(in[i][0] * factor));
    outim[i] = (int32_t)(int64_t)(rint(in[i][1] * factor));
  }
}

void* init_cplx_from_znx32_precomp(CPLX_FROM_ZNX32_PRECOMP* res, uint32_t m) {
  res->m = m;
  if (CPU_SUPPORTS("avx2")) {
    if (m >= 8) {
      res->function = cplx_from_znx32_avx2_fma;
    } else {
      res->function = cplx_from_znx32_ref;
    }
  } else {
    res->function = cplx_from_znx32_ref;
  }
  return res;
}

CPLX_FROM_ZNX32_PRECOMP* new_cplx_from_znx32_precomp(uint32_t m) {
  CPLX_FROM_ZNX32_PRECOMP* res = malloc(sizeof(CPLX_FROM_ZNX32_PRECOMP));
  if (!res) return spqlios_error(strerror(errno));
  return spqlios_keep_or_free(res, init_cplx_from_znx32_precomp(res, m));
}

void* init_cplx_from_tnx32_precomp(CPLX_FROM_TNX32_PRECOMP* res, uint32_t m) {
  res->m = m;
  if (CPU_SUPPORTS("avx2")) {
    if (m >= 8) {
      res->function = cplx_from_tnx32_avx2_fma;
    } else {
      res->function = cplx_from_tnx32_ref;
    }
  } else {
    res->function = cplx_from_tnx32_ref;
  }
  return res;
}

CPLX_FROM_TNX32_PRECOMP* new_cplx_from_tnx32_precomp(uint32_t m) {
  CPLX_FROM_TNX32_PRECOMP* res = malloc(sizeof(CPLX_FROM_TNX32_PRECOMP));
  if (!res) return spqlios_error(strerror(errno));
  return spqlios_keep_or_free(res, init_cplx_from_tnx32_precomp(res, m));
}

void* init_cplx_to_tnx32_precomp(CPLX_TO_TNX32_PRECOMP* res, uint32_t m, double divisor, uint32_t log2overhead) {
  if (is_not_pow2_double(&divisor)) return spqlios_error("divisor must be a power of 2");
  if (m & (m - 1)) return spqlios_error("m must be a power of 2");
  if (log2overhead > 52) return spqlios_error("log2overhead is too large");
  res->m = m;
  res->divisor = divisor;
  if (CPU_SUPPORTS("avx2")) {
    if (log2overhead <= 18) {
      if (m >= 8) {
        res->function = cplx_to_tnx32_avx2_fma;
      } else {
        res->function = cplx_to_tnx32_ref;
      }
    } else {
      res->function = cplx_to_tnx32_ref;
    }
  } else {
    res->function = cplx_to_tnx32_ref;
  }
  return res;
}

EXPORT CPLX_TO_TNX32_PRECOMP* new_cplx_to_tnx32_precomp(uint32_t m, double divisor, uint32_t log2overhead) {
  CPLX_TO_TNX32_PRECOMP* res = malloc(sizeof(CPLX_TO_TNX32_PRECOMP));
  if (!res) return spqlios_error(strerror(errno));
  return spqlios_keep_or_free(res, init_cplx_to_tnx32_precomp(res, m, divisor, log2overhead));
}

/**
 * @brief Simpler API for the znx32 to cplx conversion.
 * For each dimension, the precomputed tables for this dimension are generated automatically the first time.
 * It is advised to do one dry-run call per desired dimension before using in a multithread environment */
EXPORT void cplx_from_znx32_simple(uint32_t m, void* r, const int32_t* x) {
  // not checking for log2bound which is not relevant here
  static CPLX_FROM_ZNX32_PRECOMP precomp[32];
  CPLX_FROM_ZNX32_PRECOMP* p = precomp + log2m(m);
  if (!p->function) {
    if (!init_cplx_from_znx32_precomp(p, m)) abort();
  }
  p->function(p, r, x);
}

/**
 * @brief Simpler API for the tnx32 to cplx conversion.
 * For each dimension, the precomputed tables for this dimension are generated automatically the first time.
 * It is advised to do one dry-run call per desired dimension before using in a multithread environment */
EXPORT void cplx_from_tnx32_simple(uint32_t m, void* r, const int32_t* x) {
  static CPLX_FROM_TNX32_PRECOMP precomp[32];
  CPLX_FROM_TNX32_PRECOMP* p = precomp + log2m(m);
  if (!p->function) {
    if (!init_cplx_from_tnx32_precomp(p, m)) abort();
  }
  p->function(p, r, x);
}
/**
 * @brief Simpler API for the cplx to tnx32 conversion.
 * For each dimension, the precomputed tables for this dimension are generated automatically the first time.
 * It is advised to do one dry-run call per desired dimension before using in a multithread environment */
EXPORT void cplx_to_tnx32_simple(uint32_t m, double divisor, uint32_t log2overhead, int32_t* r, const void* x) {
  struct LAST_CPLX_TO_TNX32_PRECOMP {
    CPLX_TO_TNX32_PRECOMP p;
    double last_divisor;
    double last_log2over;
  };
  static __thread struct LAST_CPLX_TO_TNX32_PRECOMP precomp[32];
  struct LAST_CPLX_TO_TNX32_PRECOMP* p = precomp + log2m(m);
  if (!p->p.function || divisor != p->last_divisor || log2overhead != p->last_log2over) {
    memset(p, 0, sizeof(*p));
    if (!init_cplx_to_tnx32_precomp(&p->p, m, divisor, log2overhead)) abort();
    p->last_divisor = divisor;
    p->last_log2over = log2overhead;
  }
  p->p.function(&p->p, r, x);
}

@@ -0,0 +1,104 @@
#include <immintrin.h>
#include <stdio.h>

#include "cplx_fft_internal.h"
#include "cplx_fft_private.h"

typedef int32_t I8MEM[8];
typedef double D4MEM[4];

__always_inline void cplx_from_any_fma(uint64_t m, void* r, const int32_t* x, const __m256i C, const __m256d R) {
  const __m256i S = _mm256_set1_epi32(0x80000000);
  const I8MEM* inre = (I8MEM*)(x);
  const I8MEM* inim = (I8MEM*)(x + m);
  D4MEM* out = (D4MEM*)r;
  const uint64_t ms8 = m / 8;
  for (uint32_t i = 0; i < ms8; ++i) {
    __m256i rea = _mm256_loadu_si256((__m256i*)inre[0]);
    __m256i ima = _mm256_loadu_si256((__m256i*)inim[0]);
    rea = _mm256_add_epi32(rea, S);
    ima = _mm256_add_epi32(ima, S);
    __m256i tmpa = _mm256_unpacklo_epi32(rea, ima);
    __m256i tmpc = _mm256_unpackhi_epi32(rea, ima);
    __m256i cpla = _mm256_permute2x128_si256(tmpa, tmpc, 0x20);
    __m256i cplc = _mm256_permute2x128_si256(tmpa, tmpc, 0x31);
    tmpa = _mm256_unpacklo_epi32(cpla, C);
    __m256i tmpb = _mm256_unpackhi_epi32(cpla, C);
    tmpc = _mm256_unpacklo_epi32(cplc, C);
    __m256i tmpd = _mm256_unpackhi_epi32(cplc, C);
    cpla = _mm256_permute2x128_si256(tmpa, tmpb, 0x20);
    __m256i cplb = _mm256_permute2x128_si256(tmpa, tmpb, 0x31);
    cplc = _mm256_permute2x128_si256(tmpc, tmpd, 0x20);
    __m256i cpld = _mm256_permute2x128_si256(tmpc, tmpd, 0x31);
    __m256d dcpla = _mm256_sub_pd(_mm256_castsi256_pd(cpla), R);
    __m256d dcplb = _mm256_sub_pd(_mm256_castsi256_pd(cplb), R);
    __m256d dcplc = _mm256_sub_pd(_mm256_castsi256_pd(cplc), R);
    __m256d dcpld = _mm256_sub_pd(_mm256_castsi256_pd(cpld), R);
    _mm256_storeu_pd(out[0], dcpla);
    _mm256_storeu_pd(out[1], dcplb);
    _mm256_storeu_pd(out[2], dcplc);
    _mm256_storeu_pd(out[3], dcpld);
    inre += 1;
    inim += 1;
    out += 4;
  }
}

EXPORT void cplx_from_znx32_avx2_fma(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) {
  // note: the hex code of 2^31 + 2^52 is 0x4330000080000000
  const __m256i C = _mm256_set1_epi32(0x43300000);
  const __m256d R = _mm256_set1_pd((INT64_C(1) << 31) + (INT64_C(1) << 52));
  // double XX = INT64_C(1) + (INT64_C(1)<<31) + (INT64_C(1)<<52);
  // printf("\n\n%016lx\n", *(uint64_t*)&XX);
  // abort();
  const uint64_t m = precomp->m;
  cplx_from_any_fma(m, r, x, C, R);
}

EXPORT void cplx_from_tnx32_avx2_fma(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) {
  // note: the hex code of 2^-1 + 2^20 is 0x4130000080000000
  const __m256i C = _mm256_set1_epi32(0x41300000);
  const __m256d R = _mm256_set1_pd(0.5 + (INT64_C(1) << 20));
  // double XX = (double)(INT64_C(1) + (INT64_C(1)<<31) + (INT64_C(1)<<52))/(INT64_C(1)<<32);
  // printf("\n\n%016lx\n", *(uint64_t*)&XX);
  // abort();
  const uint64_t m = precomp->m;
  cplx_from_any_fma(m, r, x, C, R);
}

EXPORT void cplx_to_tnx32_avx2_fma(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* r, const void* x) {
  const __m256d R = _mm256_set1_pd((0.5 + (INT64_C(3) << 19)) * precomp->divisor);
  const __m256i MASK = _mm256_set1_epi64x(0xFFFFFFFFUL);
  const __m256i S = _mm256_set1_epi32(0x80000000);
  // const __m256i IDX = _mm256_set_epi32(0,4,1,5,2,6,3,7);
  const __m256i IDX = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
  const uint64_t m = precomp->m;
  const uint64_t ms8 = m / 8;
  I8MEM* outre = (I8MEM*)r;
  I8MEM* outim = (I8MEM*)(r + m);
  const D4MEM* in = x;
  // Note: this formula will only work if abs(in) < 2^32
  for (uint32_t i = 0; i < ms8; ++i) {
    __m256d cpla = _mm256_loadu_pd(in[0]);
    __m256d cplb = _mm256_loadu_pd(in[1]);
    __m256d cplc = _mm256_loadu_pd(in[2]);
    __m256d cpld = _mm256_loadu_pd(in[3]);
    __m256i icpla = _mm256_castpd_si256(_mm256_add_pd(cpla, R));
    __m256i icplb = _mm256_castpd_si256(_mm256_add_pd(cplb, R));
    __m256i icplc = _mm256_castpd_si256(_mm256_add_pd(cplc, R));
    __m256i icpld = _mm256_castpd_si256(_mm256_add_pd(cpld, R));
    icpla = _mm256_or_si256(_mm256_and_si256(icpla, MASK), _mm256_slli_epi64(icplb, 32));
    icplc = _mm256_or_si256(_mm256_and_si256(icplc, MASK), _mm256_slli_epi64(icpld, 32));
    icpla = _mm256_xor_si256(icpla, S);
    icplc = _mm256_xor_si256(icplc, S);
    __m256i re = _mm256_unpacklo_epi64(icpla, icplc);
    __m256i im = _mm256_unpackhi_epi64(icpla, icplc);
    re = _mm256_permutevar8x32_epi32(re, IDX);
    im = _mm256_permutevar8x32_epi32(im, IDX);
    _mm256_storeu_si256((__m256i*)outre[0], re);
    _mm256_storeu_si256((__m256i*)outim[0], im);
    outre += 1;
    outim += 1;
    in += 4;
  }
}

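The conversions above avoid int-to-float instructions by or-ing the 32-bit payload into the mantissa of a large constant; this scalar sketch (ours, for illustration only) shows the same trick that the 0x4330000080000000 comment refers to:

#include <stdint.h>
#include <string.h>

// Build the double 2^52 + 2^31 + x from raw bits, then subtract 2^52 + 2^31:
// the result is exactly (double)x for any int32_t x.
static double int32_to_double_magic(int32_t x) {
  uint64_t bits = (UINT64_C(0x43300000) << 32) | (uint32_t)((uint32_t)x + UINT32_C(0x80000000));
  double d;
  memcpy(&d, &bits, sizeof d);                      // d == 2^52 + 2^31 + x, exactly representable
  return d - (4503599627370496.0 + 2147483648.0);   // subtract 2^52 + 2^31
}
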
@@ -0,0 +1,18 @@
#include "cplx_fft_internal.h"
#include "cplx_fft_private.h"

EXPORT void cplx_from_znx32(const CPLX_FROM_ZNX32_PRECOMP* tables, void* r, const int32_t* a) {
  tables->function(tables, r, a);
}
EXPORT void cplx_from_tnx32(const CPLX_FROM_TNX32_PRECOMP* tables, void* r, const int32_t* a) {
  tables->function(tables, r, a);
}
EXPORT void cplx_to_tnx32(const CPLX_TO_TNX32_PRECOMP* tables, int32_t* r, const void* a) {
  tables->function(tables, r, a);
}
EXPORT void cplx_fftvec_mul(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b) {
  tables->function(tables, r, a, b);
}
EXPORT void cplx_fftvec_addmul(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b) {
  tables->function(tables, r, a, b);
}

@@ -0,0 +1,43 @@
#include "cplx_fft_internal.h"
#include "cplx_fft_private.h"

EXPORT void cplx_fftvec_addmul_fma(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b) {
  UNDEFINED();  // not defined for non x86 targets
}
EXPORT void cplx_fftvec_mul_fma(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b) {
  UNDEFINED();
}
EXPORT void cplx_fftvec_addmul_sse(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) {
  UNDEFINED();
}
EXPORT void cplx_fftvec_addmul_avx512(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a,
                                      const void* b) {
  UNDEFINED();
}
EXPORT void cplx_fft16_avx_fma(void* data, const void* omega) { UNDEFINED(); }
EXPORT void cplx_ifft16_avx_fma(void* data, const void* omega) { UNDEFINED(); }
EXPORT void cplx_from_znx32_avx2_fma(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) { UNDEFINED(); }
EXPORT void cplx_from_tnx32_avx2_fma(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) { UNDEFINED(); }
EXPORT void cplx_to_tnx32_avx2_fma(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* x, const void* c) { UNDEFINED(); }
EXPORT void cplx_fft_avx2_fma(const CPLX_FFT_PRECOMP* tables, void* data) { UNDEFINED(); }
EXPORT void cplx_ifft_avx2_fma(const CPLX_IFFT_PRECOMP* itables, void* data) { UNDEFINED(); }
EXPORT void cplx_fftvec_twiddle_fma(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om) {
  UNDEFINED();
}
EXPORT void cplx_fftvec_twiddle_avx512(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om) {
  UNDEFINED();
}
EXPORT void cplx_fftvec_bitwiddle_fma(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice,
                                      const void* om) {
  UNDEFINED();
}
EXPORT void cplx_fftvec_bitwiddle_avx512(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice,
                                         const void* om) {
  UNDEFINED();
}

// DEPRECATED?
EXPORT void cplx_fftvec_add_fma(uint32_t m, void* r, const void* a, const void* b) { UNDEFINED(); }
EXPORT void cplx_fftvec_sub2_to_fma(uint32_t m, void* r, const void* a, const void* b) { UNDEFINED(); }
EXPORT void cplx_fftvec_copy_fma(uint32_t m, void* r, const void* a) { UNDEFINED(); }

// executors
// EXPORT void cplx_ifft(const CPLX_IFFT_PRECOMP* itables, void* data) {
//   itables->function(itables, data);
// }
// EXPORT void cplx_fft(const CPLX_FFT_PRECOMP* tables, void* data) { tables->function(tables, data); }

@@ -0,0 +1,221 @@
#ifndef SPQLIOS_CPLX_FFT_H
#define SPQLIOS_CPLX_FFT_H

#include "../commons.h"

typedef struct cplx_fft_precomp CPLX_FFT_PRECOMP;
typedef struct cplx_ifft_precomp CPLX_IFFT_PRECOMP;
typedef struct cplx_mul_precomp CPLX_FFTVEC_MUL_PRECOMP;
typedef struct cplx_addmul_precomp CPLX_FFTVEC_ADDMUL_PRECOMP;
typedef struct cplx_from_znx32_precomp CPLX_FROM_ZNX32_PRECOMP;
typedef struct cplx_from_tnx32_precomp CPLX_FROM_TNX32_PRECOMP;
typedef struct cplx_to_tnx32_precomp CPLX_TO_TNX32_PRECOMP;
typedef struct cplx_to_znx32_precomp CPLX_TO_ZNX32_PRECOMP;
typedef struct cplx_from_rnx64_precomp CPLX_FROM_RNX64_PRECOMP;
typedef struct cplx_to_rnx64_precomp CPLX_TO_RNX64_PRECOMP;
typedef struct cplx_round_to_rnx64_precomp CPLX_ROUND_TO_RNX64_PRECOMP;

/**
 * @brief precomputes fft tables.
 * The FFT tables contain a constant section that is required for efficient FFT operations in dimension m.
 * The resulting pointer is to be passed as "tables" argument to any call to the fft function.
 * The user can optionally allocate zero or more computation buffers, which are scratch spaces that are contiguous to
 * the constant tables in memory, and allow for more efficient operations. It is the user's responsibility to ensure
 * that each of those buffers is never used simultaneously by two ffts on different threads. The fft
 * table must be deleted by delete_fft_precomp after its last usage.
 */
EXPORT CPLX_FFT_PRECOMP* new_cplx_fft_precomp(uint32_t m, uint32_t num_buffers);

/**
 * @brief gets the address of a fft buffer allocated during new_fft_precomp.
 * This buffer can be used as data pointer in subsequent calls to fft,
 * and does not need to be released afterwards.
 */
EXPORT void* cplx_fft_precomp_get_buffer(const CPLX_FFT_PRECOMP* tables, uint32_t buffer_index);

/**
 * @brief allocates a new fft buffer.
 * This buffer can be used as data pointer in subsequent calls to fft,
 * and must be deleted afterwards by calling delete_fft_buffer.
 */
EXPORT void* new_cplx_fft_buffer(uint32_t m);

/**
 * @brief deletes a fft buffer allocated with new_cplx_fft_buffer.
 */
EXPORT void delete_cplx_fft_buffer(void* buffer);

/**
 * @brief deallocates a fft table and all its built-in buffers.
 */
#define delete_cplx_fft_precomp free

/**
 * @brief computes a direct fft in-place over data.
 */
EXPORT void cplx_fft(const CPLX_FFT_PRECOMP* tables, void* data);

EXPORT CPLX_IFFT_PRECOMP* new_cplx_ifft_precomp(uint32_t m, uint32_t num_buffers);
EXPORT void* cplx_ifft_precomp_get_buffer(const CPLX_IFFT_PRECOMP* tables, uint32_t buffer_index);
EXPORT void cplx_ifft(const CPLX_IFFT_PRECOMP* tables, void* data);
#define delete_cplx_ifft_precomp free

EXPORT CPLX_FFTVEC_MUL_PRECOMP* new_cplx_fftvec_mul_precomp(uint32_t m);
EXPORT void cplx_fftvec_mul(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b);
#define delete_cplx_fftvec_mul_precomp free

EXPORT CPLX_FFTVEC_ADDMUL_PRECOMP* new_cplx_fftvec_addmul_precomp(uint32_t m);
EXPORT void cplx_fftvec_addmul(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b);
#define delete_cplx_fftvec_addmul_precomp free

/**
 * @brief prepares a conversion from ZnX to the cplx layout.
 * All the coefficients must be strictly lower than 2^log2bound in absolute value. Any attempt to use
 * this function on a larger coefficient is undefined behaviour. The resulting precomputed data must
 * be freed with `delete_cplx_from_znx32_precomp`
 * @param m the target complex dimension m from C[X] mod X^m-i. Note that the inputs have n=2m
 * int32 coefficients in natural order modulo X^n+1
 * @param log2bound bound on the input coefficients. Must be between 0 and 32
 */
EXPORT CPLX_FROM_ZNX32_PRECOMP* new_cplx_from_znx32_precomp(uint32_t m);
/**
 * @brief converts from ZnX to the cplx layout.
 * @param tables precomputed data obtained by new_cplx_from_znx32_precomp.
 * @param r resulting array of m complex coefficients mod X^m-i
 * @param x input array of n bounded integer coefficients mod X^n+1
 */
EXPORT void cplx_from_znx32(const CPLX_FROM_ZNX32_PRECOMP* tables, void* r, const int32_t* a);
/** @brief frees a precomputed conversion data initialized with new_cplx_from_znx32_precomp. */
#define delete_cplx_from_znx32_precomp free

/**
 * @brief prepares a conversion from TnX to the cplx layout.
 * @param m the target complex dimension m from C[X] mod X^m-i. Note that the inputs have n=2m
 * torus32 coefficients. The resulting precomputed data must
 * be freed with `delete_cplx_from_tnx32_precomp`
 */
EXPORT CPLX_FROM_TNX32_PRECOMP* new_cplx_from_tnx32_precomp(uint32_t m);
/**
 * @brief converts from TnX to the cplx layout.
 * @param tables precomputed data obtained by new_cplx_from_tnx32_precomp.
 * @param r resulting array of m complex coefficients mod X^m-i
 * @param x input array of n torus32 coefficients mod X^n+1
 */
EXPORT void cplx_from_tnx32(const CPLX_FROM_TNX32_PRECOMP* tables, void* r, const int32_t* a);
/** @brief frees a precomputed conversion data initialized with new_cplx_from_tnx32_precomp. */
#define delete_cplx_from_tnx32_precomp free

/**
 * @brief prepares a rescale and conversion from the cplx layout to TnX.
 * @param m the target complex dimension m from C[X] mod X^m-i. Note that the outputs have n=2m
 * torus32 coefficients.
 * @param divisor must be a power of two. The inputs are rescaled by divisor before being reduced modulo 1.
 * Remember that the output of an iFFT must be divided by m.
 * @param log2overhead all inputs absolute values must be within divisor.2^log2overhead.
 * For any inputs outside of these bounds, the conversion is undefined behaviour.
 * The maximum supported log2overhead is 52, and the algorithm is faster for log2overhead<=18.
 */
EXPORT CPLX_TO_TNX32_PRECOMP* new_cplx_to_tnx32_precomp(uint32_t m, double divisor, uint32_t log2overhead);
/**
 * @brief rescales, converts and reduces mod 1 from cplx layout to torus32.
 * @param tables precomputed data obtained by new_cplx_from_tnx32_precomp.
 * @param r resulting array of n torus32 coefficients mod X^n+1
 * @param x input array of m cplx coefficients mod X^m-i
 */
EXPORT void cplx_to_tnx32(const CPLX_TO_TNX32_PRECOMP* tables, int32_t* r, const void* a);
#define delete_cplx_to_tnx32_precomp free

EXPORT CPLX_TO_ZNX32_PRECOMP* new_cplx_to_znx32_precomp(uint32_t m, double divisor);
EXPORT void cplx_to_znx32(const CPLX_TO_ZNX32_PRECOMP* precomp, int32_t* r, const void* x);
#define delete_cplx_to_znx32_simple free

EXPORT CPLX_FROM_RNX64_PRECOMP* new_cplx_from_rnx64_simple(uint32_t m);
EXPORT void cplx_from_rnx64(const CPLX_FROM_RNX64_PRECOMP* precomp, void* r, const double* x);
#define delete_cplx_from_rnx64_simple free

EXPORT CPLX_TO_RNX64_PRECOMP* new_cplx_to_rnx64(uint32_t m, double divisor);
EXPORT void cplx_to_rnx64(const CPLX_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
#define delete_cplx_round_to_rnx64_simple free

EXPORT CPLX_ROUND_TO_RNX64_PRECOMP* new_cplx_round_to_rnx64(uint32_t m, double divisor, uint32_t log2bound);
EXPORT void cplx_round_to_rnx64(const CPLX_ROUND_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
#define delete_cplx_round_to_rnx64_simple free

/**
 * @brief Simpler API for the fft function.
 * For each dimension, the precomputed tables for this dimension are generated automatically.
 * It is advised to do one dry-run per desired dimension before using in a multithread environment */
EXPORT void cplx_fft_simple(uint32_t m, void* data);
/**
 * @brief Simpler API for the ifft function.
 * For each dimension, the precomputed tables for this dimension are generated automatically the first time.
 * It is advised to do one dry-run call per desired dimension in the main thread before using in a multithread
 * environment */
EXPORT void cplx_ifft_simple(uint32_t m, void* data);
/**
 * @brief Simpler API for the fftvec multiplication function.
 * For each dimension, the precomputed tables for this dimension are generated automatically the first time.
 * It is advised to do one dry-run call per desired dimension before using in a multithread environment */
EXPORT void cplx_fftvec_mul_simple(uint32_t m, void* r, const void* a, const void* b);
/**
 * @brief Simpler API for the fftvec addmul function.
 * For each dimension, the precomputed tables for this dimension are generated automatically the first time.
 * It is advised to do one dry-run call per desired dimension before using in a multithread environment */
EXPORT void cplx_fftvec_addmul_simple(uint32_t m, void* r, const void* a, const void* b);
/**
 * @brief Simpler API for the znx32 to cplx conversion.
 * For each dimension, the precomputed tables for this dimension are generated automatically the first time.
 * It is advised to do one dry-run call per desired dimension before using in a multithread environment */
EXPORT void cplx_from_znx32_simple(uint32_t m, void* r, const int32_t* x);
/**
 * @brief Simpler API for the tnx32 to cplx conversion.
 * For each dimension, the precomputed tables for this dimension are generated automatically the first time.
 * It is advised to do one dry-run call per desired dimension before using in a multithread environment */
EXPORT void cplx_from_tnx32_simple(uint32_t m, void* r, const int32_t* x);
/**
 * @brief Simpler API for the cplx to tnx32 conversion.
 * For each dimension, the precomputed tables for this dimension are generated automatically the first time.
 * It is advised to do one dry-run call per desired dimension before using in a multithread environment */
EXPORT void cplx_to_tnx32_simple(uint32_t m, double divisor, uint32_t log2overhead, int32_t* r, const void* x);

/**
 * @brief converts, divides and rounds from cplx to znx32 (simple API)
 * @param m the complex dimension
 * @param divisor the divisor: a power of two, often m after an ifft
 * @param r the result: must be an int32 array of size 2m. r must be distinct from x
 * @param x the input: must hold m complex numbers.
 */
EXPORT void cplx_to_znx32_simple(uint32_t m, double divisor, int32_t* r, const void* x);

/**
 * @brief converts from rnx64 to cplx (simple API)
 * The bound on the output is assumed to be within ]2^-31,2^31[.
 * Any coefficient that would fall outside this range is undefined behaviour.
 * @param m the complex dimension
 * @param r the result: must be an array of m complex numbers. r must be distinct from x
 * @param x the input: must be an array of 2m doubles.
 */
EXPORT void cplx_from_rnx64_simple(uint32_t m, void* r, const double* x);

/**
 * @brief converts and divides from cplx to rnx64 (simple API)
 * @param m the complex dimension
 * @param divisor the divisor: a power of two, often m after an ifft
 * @param r the result: must be a double array of size 2m. r must be distinct from x
 * @param x the input: must hold m complex numbers.
 */
EXPORT void cplx_to_rnx64_simple(uint32_t m, double divisor, double* r, const void* x);

/**
 * @brief converts, divides and rounds to integer from cplx to rnx64 (simple API)
 * @param m the complex dimension
 * @param divisor the divisor: a power of two, often m after an ifft
 * @param log2bound a guarantee on the log2bound of the output. log2bound<=48 will use a more efficient algorithm.
 * @param r the result: must be a double array of size 2m. r must be distinct from x
 * @param x the input: must hold m complex numbers.
 */
EXPORT void cplx_round_to_rnx64_simple(uint32_t m, double divisor, uint32_t log2bound, double* r, const void* x);

#endif  // SPQLIOS_CPLX_FFT_H

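To tie the simple API above together, here is a hedged usage sketch (ours, not part of this commit): negacyclic multiplication of two small integer polynomials mod X^n+1 with n = 2m. The divisor m and log2overhead 18 follow the comments above, but the right values depend on the magnitude of your data.

#include <stdint.h>
#include "cplx_fft.h"

// Multiply a and b mod X^n+1 (n = 2m) and write the result as torus32 coefficients.
// Assumes the coefficients are small enough that the ifft output stays below divisor*2^18;
// otherwise pick a larger log2overhead (up to 52).
void negacyclic_mul_simple(uint32_t m, int32_t* res, const int32_t* a, const int32_t* b) {
  double* fa = new_cplx_fft_buffer(m);    // m complex values = 2m doubles
  double* fb = new_cplx_fft_buffer(m);
  cplx_from_znx32_simple(m, fa, a);       // int32 coefficients -> cplx layout
  cplx_from_znx32_simple(m, fb, b);
  cplx_fft_simple(m, fa);                 // coefficient space -> evaluation space
  cplx_fft_simple(m, fb);
  cplx_fftvec_mul_simple(m, fa, fa, fb);  // pointwise product (we assume r == a is allowed)
  cplx_ifft_simple(m, fa);                // back to coefficient space, scaled by m
  cplx_to_tnx32_simple(m, (double)m, 18, res, fa);  // rescale by 1/m, reduce mod 1
  delete_cplx_fft_buffer(fa);
  delete_cplx_fft_buffer(fb);
}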