Mirror of https://github.com/arnaucube/poulpy.git (synced 2026-02-09 20:56:47 +01:00)

Add Hardware Abstraction Layer (#56)

Committed by GitHub
Parent: 833520b163
Commit: 0e0745065e
.github/workflows/ci.yml (vendored, new file, 63 lines)
@@ -0,0 +1,63 @@
name: CI

on:
  push:
  pull_request:

env:
  CARGO_TERM_COLOR: always

jobs:
  build-and-test:
    name: Build & Test (stable, beta, nightly)
    runs-on: ubuntu-latest

    strategy:
      matrix:
        toolchain: [stable, beta, nightly]

    steps:
      - uses: actions/checkout@v4

      - name: Install toolchain
        uses: actions/setup-rust@v1
        with:
          toolchain: ${{ matrix.toolchain }}
          components: clippy

      - name: Cache cargo registry
        uses: actions/cache@v3
        with:
          path: |
            ~/.cargo/registry
            ~/.cargo/git
          key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}
          restore-keys: |
            ${{ runner.os }}-cargo-registry-

      - name: Cargo fmt check
        run: cargo fmt --all -- --check

      - name: Build
        run: cargo build --all-targets --verbose

      - name: Run tests
        run: cargo test --all-targets --verbose

      - name: Run Clippy
        run: cargo clippy --all-targets -- -D warnings

  docs:
    name: Check Documentation
    runs-on: ubuntu-latest
    needs: build-and-test

    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-rust@v1
        with:
          toolchain: stable

      - name: Build docs
        run: cargo doc --all --no-deps --verbose
.gitmodules (vendored, 4 lines changed)
@@ -1,3 +1,3 @@
-[submodule "backend/spqlios-arithmetic"]
-	path = backend/spqlios-arithmetic
+[submodule "backend/src/implementation/cpu_spqlios/spqlios-arithmetic"]
+	path = backend/src/implementation/cpu_spqlios/spqlios-arithmetic
 	url = https://github.com/phantomzone-org/spqlios-arithmetic
Cargo.lock (generated, 61 lines changed)
@@ -39,8 +39,11 @@ checksum = "7b7e4c2464d97fe331d41de9d5db0def0a96f4d823b8b32a2efd503578988973"
|
||||
name = "backend"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"cmake",
|
||||
"criterion",
|
||||
"itertools 0.14.0",
|
||||
"paste",
|
||||
"rand",
|
||||
"rand_core",
|
||||
"rand_distr",
|
||||
@@ -73,6 +76,15 @@ version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.2.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c3a42d84bb6b69d3a8b3eaacf0d88f179e1929695e1ad012b6cf64d9caaa5fd2"
|
||||
dependencies = [
|
||||
"shlex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.0"
|
||||
@@ -131,11 +143,21 @@ version = "0.7.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6"
|
||||
|
||||
[[package]]
|
||||
name = "cmake"
|
||||
version = "0.1.54"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"backend",
|
||||
"byteorder",
|
||||
"criterion",
|
||||
"itertools 0.14.0",
|
||||
"rand_distr",
|
||||
@@ -145,9 +167,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "criterion"
|
||||
version = "0.6.0"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3bf7af66b0989381bd0be551bd7cc91912a655a58c6918420c9527b1fd8b4679"
|
||||
checksum = "e1c047a62b0cc3e145fa84415a3191f628e980b194c2755aa12300a4e6cbd928"
|
||||
dependencies = [
|
||||
"anes",
|
||||
"cast",
|
||||
@@ -168,12 +190,12 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "criterion-plot"
|
||||
version = "0.5.0"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
|
||||
checksum = "9b1bcc0dc7dfae599d84ad0b1a55f80cde8af3725da8313b528da95ef783e338"
|
||||
dependencies = [
|
||||
"cast",
|
||||
"itertools 0.10.5",
|
||||
"itertools 0.13.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -251,15 +273,6 @@ dependencies = [
|
||||
"crunchy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.10.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
|
||||
dependencies = [
|
||||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.13.0"
|
||||
@@ -340,6 +353,12 @@ version = "11.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9"
|
||||
|
||||
[[package]]
|
||||
name = "paste"
|
||||
version = "1.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
|
||||
|
||||
[[package]]
|
||||
name = "plotters"
|
||||
version = "0.3.7"
|
||||
@@ -527,18 +546,18 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.216"
|
||||
version = "1.0.219"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b9781016e935a97e8beecf0c933758c97a5520d32930e460142b4cd80c6338e"
|
||||
checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.216"
|
||||
version = "1.0.219"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e"
|
||||
checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -557,6 +576,12 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "shlex"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.96"
|
||||
|
||||
@@ -9,4 +9,5 @@ rand_chacha = "0.9.0"
 rand_core = "0.9.3"
 rand_distr = "0.5.1"
 itertools = "0.14.0"
-criterion = "0.6.0"
+criterion = "0.7.0"
+byteorder = "1.5.0"
NOTICE (new file, 12 lines)
@@ -0,0 +1,12 @@
Copyright 2025 Phantom Zone, Jean-Philippe Bossuat

Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of the
License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
README.md (168 lines changed)
@@ -1,3 +1,165 @@
<p align="center">
  <img src="poulpy.png" />
</p>

# 🐙 Poulpy

**Poulpy** is a fast & modular FHE library that implements Ring-Learning-With-Errors based homomorphic encryption. It adopts the bivariate polynomial representation proposed in [Revisiting Key Decomposition Techniques for FHE: Simpler, Faster and More Generic](https://eprint.iacr.org/2023/771). In addition to simpler and more efficient arithmetic than the residue number system (RNS), this representation provides a common plaintext space for all schemes and allows easy switching between any two schemes. Poulpy also decouples the scheme implementations from the polynomial arithmetic backend by being built around a hardware abstraction layer (HAL). This enables users to easily provide or use a custom backend.

### Bivariate Polynomial Representation

Existing FHE implementations (such as [Lattigo](https://github.com/tuneinsight/lattigo) or [OpenFHE](https://github.com/openfheorg/openfhe-development)) use the [residue-number-system](https://en.wikipedia.org/wiki/Residue_number_system) (RNS) to represent large integers. Although the parallelism and carry-less arithmetic provided by the RNS representation enable very efficient modular arithmetic over large integers, it suffers from various drawbacks when used in the context of FHE. The main idea behind the bivariate representation is to decouple the cyclotomic arithmetic from the large-number arithmetic. Instead of using the RNS representation for large integers, integers are decomposed in base $2^{-K}$ over the Torus $\mathbb{T}_{N}[X]$ (a small sketch of this limb decomposition is given after the list below).

This provides the following benefits:

- **Intuitive, efficient and reusable parameterization & instances:** Only the bit-size of the modulus is required from the user (i.e. the Torus precision). As such, parameterization is natural and generic, and instances can be reused for any circuit consuming the same homomorphic capacity, without loss of efficiency. With the RNS representation, individual NTT-friendly primes need to be specified for each level, making parameterization circuit-specific and less user-friendly.

- **Optimal and granular rescaling:** Ciphertext rescaling is carried out with bit-shifting, enabling bit-level granular rescaling and optimal noise/homomorphic capacity management. In the RNS representation, ciphertext division can only be done by one of the primes composing the modulus, leading to difficult scaling management and frequently inefficient noise/homomorphic capacity management.

- **Linear number of DFTs in the half external product:** The bivariate representation of the coefficients implicitly provides the digit decomposition; as such, the number of DFTs is linear in the number of limbs, contrary to the RNS representation where it is quadratic due to the RNS basis conversion. This enables much more efficient key-switching, which is the **most used and expensive** FHE operation.

- **Unified plaintext space:** The bivariate polynomial representation is in essence a high-precision discretized representation of the Torus $\mathbb{T}_{N}[X]$. Using the Torus as the common plaintext space for all schemes achieves the vision of [CHIMERA: Combining Ring-LWE-based Fully Homomorphic Encryption Schemes](https://eprint.iacr.org/2018/758), which is to unify all RLWE-based FHE schemes (TFHE, FHEW, BGV, BFV, CLPX, GBFV, CKKS, ...) under a single scheme with different encodings, enabling native and efficient scheme-switching functionalities.

- **Simpler implementation**: Since the cyclotomic arithmetic is decoupled from the coefficient representation, the same pipeline (including the DFT) can be reused for all limbs (unlike in the RNS representation), making this representation a prime target for hardware acceleration.

- **Deterministic computation**: Although defined on the Torus, bivariate arithmetic remains integer polynomial arithmetic, ensuring all computations are deterministic; the contract is that outputs are reproducible and identical, regardless of the backend or hardware.
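The following is a minimal, standalone sketch of this limb decomposition (plain Rust for illustration only, not Poulpy's API), assuming a non-negative value that fits in `basek * limbs` bits:

```rust
// Illustration only (not Poulpy's API): split a non-negative integer into
// base-2^K digits, most-significant limb first, and recompose it by shifting.
fn decompose(mut x: i64, basek: usize, limbs: usize) -> Vec<i64> {
    let mask = (1i64 << basek) - 1;
    let mut out = vec![0i64; limbs];
    for l in (0..limbs).rev() {
        out[l] = x & mask; // low K bits become the least-significant remaining limb
        x >>= basek;
    }
    out
}

fn recompose(limbs: &[i64], basek: usize) -> i64 {
    limbs.iter().fold(0, |acc, &d| (acc << basek) + d)
}

fn main() {
    let (basek, limbs) = (17, 3);
    let x = 123_456_789i64;
    let digits = decompose(x, basek, limbs);
    assert_eq!(recompose(&digits, basek), x);
    println!("{x} -> {digits:?} (base 2^{basek})");
}
```

In Poulpy itself the limbs are carried by the "small polynomials" of a column and are re-balanced during normalization; the sketch only shows the plain positional decomposition.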

### Hardware Abstraction Layer

In addition to providing a general-purpose FHE library over a unified plaintext space, Poulpy is also designed from the ground up around a **hardware abstraction layer** that closely matches the API of [spqlios-arithmetic](https://github.com/tfhe/spqlios-arithmetic). The bivariate representation is by itself hardware friendly, as it uses a flat, aligned & vectorized memory layout. Finally, generic, opaque, write-only structs (prepared versions) are provided, making it easy for developers to provide hardware-focused/optimized operations. This makes it possible for anyone to provide or use a custom backend.

## Library Overview

- **`backend/hal`**: hardware abstraction layer. This layer targets users who want to provide their own backend or use a third-party backend.

- **`api`**: fixed, public, low-level polynomial arithmetic API closely matching spqlios-arithmetic. The goal is to eventually freeze this API in order to decouple it from the OEP traits, ensuring that changes to implementations do not affect the front-end API.

  ```rust
  pub trait SvpPrepare<B: Backend> {
      fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
      where
          R: SvpPPolToMut<B>,
          A: ScalarZnxToRef;
  }
  ```

- **`delegates`**: the link between the user-facing API and the implementation OEP. Each trait of `api` is implemented by calling its corresponding trait on the `oep`.

  ```rust
  impl<B> SvpPrepare<B> for Module<B>
  where
      B: Backend + SvpPrepareImpl<B>,
  {
      fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
      where
          R: SvpPPolToMut<B>,
          A: ScalarZnxToRef,
      {
          B::svp_prepare_impl(self, res, res_col, a, a_col);
      }
  }
  ```

- **`layouts`**: defines the layouts of the front-end algebraic structs matching the spqlios-arithmetic definitions, such as `ScalarZnx` and `VecZnx`, as well as opaque backend-prepared structs such as `SvpPPol` and `VmpPMat`.

  ```rust
  pub struct SvpPPol<D: Data, B: Backend> {
      data: D,
      n: usize,
      cols: usize,
      _phantom: PhantomData<B>,
  }
  ```

- **`oep`**: open extension points, which can be implemented by the user to provide a custom backend.

  ```rust
  pub unsafe trait SvpPrepareImpl<B: Backend> {
      fn svp_prepare_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
      where
          R: SvpPPolToMut<B>,
          A: ScalarZnxToRef;
  }
  ```

- **`tests`**: exported generic tests for the OEP/structs. Their goal is to let users automatically test their backend implementation without having to re-implement any tests.

- **`backend/implementation`**:
  - **`cpu_spqlios`**: concrete CPU implementation of the HAL through the OEP, using bindings to spqlios-arithmetic. This implementation currently supports the `FFT64` backend and will be extended to support the `NTT120` backend once it is available in spqlios-arithmetic.

    ```rust
    unsafe impl SvpPrepareImpl<Self> for FFT64 {
        fn svp_prepare_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
        where
            R: SvpPPolToMut<Self>,
            A: ScalarZnxToRef,
        {
            unsafe {
                svp::svp_prepare(
                    module.ptr(),
                    res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t,
                    a.to_ref().at_ptr(a_col, 0),
                )
            }
        }
    }
    ```

- **`core`**: core of the FHE library, implementing scheme-agnostic RLWE arithmetic for LWE, GLWE, GGLWE and GGSW ciphertexts. It notably includes all possible cross-ciphertext operations, for example applying an external product on a GGLWE or an automorphism on a GGSW, as well as blind rotation. This crate is entirely implemented using the hardware abstraction layer API, and is thus solely defined over generics and traits (including tests). As such, it works with any backend, as long as that backend implements the necessary traits defined in the OEP.

  ```rust
  pub struct GLWESecret<D: Data> {
      pub(crate) data: ScalarZnx<D>,
      pub(crate) dist: Distribution,
  }

  pub struct GLWESecretExec<D: Data, B: Backend> {
      pub(crate) data: SvpPPol<D, B>,
      pub(crate) dist: Distribution,
  }

  impl<D: DataMut, B: Backend> GLWESecretExec<D, B> {
      pub fn prepare<O>(&mut self, module: &Module<B>, sk: &GLWESecret<O>)
      where
          O: DataRef,
          Module<B>: SvpPrepare<B>,
      {
          (0..self.rank()).for_each(|i| {
              module.svp_prepare(&mut self.data, i, &sk.data, i);
          });
          self.dist = sk.dist
      }
  }
  ```

## Installation

TBD — currently not published on crates.io. Clone the repository and use it via path-based dependencies, as sketched below.
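A minimal sketch of such a path-based dependency (illustrative only: the crate names match the workspace members, but the relative paths depend on where the repository is cloned):

```toml
# Your project's Cargo.toml (illustrative; adjust paths to your checkout).
[dependencies]
core    = { path = "../poulpy/core" }
backend = { path = "../poulpy/backend" }
```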
## Documentation

* Full `cargo doc` documentation is coming soon.
* Architecture diagrams and design notes will be added in the [`/doc`](./doc) folder.

## Contributing

We welcome external contributions; please see [CONTRIBUTING](./CONTRIBUTING.md).

## Security

Please see [SECURITY](./SECURITY.md).

## License

Poulpy is licensed under the Apache 2.0 License. See [NOTICE](./NOTICE) & [LICENSE](./LICENSE).

## Acknowledgement

**Poulpy** is inspired by the modular architecture of [Lattigo](https://github.com/tuneinsight/lattigo) and [TFHE-go](https://github.com/sp301415/tfhe-go), and its development is led by Lattigo’s co-author and main contributor [@Pro7ech](https://github.com/Pro7ech). Poulpy reflects the experience gained from over five years of designing and maintaining Lattigo, and represents the next evolution in architecture, performance, and backend philosophy.

## Citing

Please use the following BibTeX entry for citing Poulpy:

```
@misc{poulpy,
    title        = {Poulpy v0.1.0},
    howpublished = {Online: \url{https://github.com/phantomzone-org/poulpy}},
    month        = Aug,
    year         = 2025,
    note         = {Phantom Zone}
}
```
@@ -13,7 +13,12 @@ rand_distr = {workspace = true}
rand_core = {workspace = true}
sampling = { path = "../sampling" }
utils = { path = "../utils" }
paste = "1.0.15"
byteorder = {workspace = true}

[[bench]]
name = "fft"
harness = false
[build-dependencies]
cmake = "0.1.54"

[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--cfg", "docsrs"]
@@ -1,56 +0,0 @@
|
||||
use backend::ffi::reim::*;
|
||||
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
|
||||
use std::ffi::c_void;
|
||||
|
||||
fn fft(c: &mut Criterion) {
|
||||
fn forward<'a>(m: u32, log_bound: u32, reim_fft_precomp: *mut reim_fft_precomp, a: &'a [i64]) -> Box<dyn FnMut() + 'a> {
|
||||
unsafe {
|
||||
let buf_a: *mut f64 = reim_fft_precomp_get_buffer(reim_fft_precomp, 0);
|
||||
reim_from_znx64_simple(m as u32, log_bound as u32, buf_a as *mut c_void, a.as_ptr());
|
||||
Box::new(move || reim_fft(reim_fft_precomp, buf_a))
|
||||
}
|
||||
}
|
||||
|
||||
fn backward<'a>(m: u32, log_bound: u32, reim_ifft_precomp: *mut reim_ifft_precomp, a: &'a [i64]) -> Box<dyn FnMut() + 'a> {
|
||||
Box::new(move || unsafe {
|
||||
let buf_a: *mut f64 = reim_ifft_precomp_get_buffer(reim_ifft_precomp, 0);
|
||||
reim_from_znx64_simple(m as u32, log_bound as u32, buf_a as *mut c_void, a.as_ptr());
|
||||
reim_ifft(reim_ifft_precomp, buf_a);
|
||||
})
|
||||
}
|
||||
|
||||
let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("fft");
|
||||
|
||||
for log_n in 10..17 {
|
||||
let n: usize = 1 << log_n;
|
||||
let m: usize = n >> 1;
|
||||
let log_bound: u32 = 19;
|
||||
|
||||
let mut a: Vec<i64> = vec![i64::default(); n];
|
||||
a.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
|
||||
unsafe {
|
||||
let reim_fft_precomp: *mut reim_fft_precomp = new_reim_fft_precomp(m as u32, 1);
|
||||
let reim_ifft_precomp: *mut reim_ifft_precomp = new_reim_ifft_precomp(m as u32, 1);
|
||||
|
||||
let runners: [(String, Box<dyn FnMut()>); 2] = [
|
||||
(format!("forward"), {
|
||||
forward(m as u32, log_bound, reim_fft_precomp, &a)
|
||||
}),
|
||||
(format!("backward"), {
|
||||
backward(m as u32, log_bound, reim_ifft_precomp, &a)
|
||||
}),
|
||||
];
|
||||
|
||||
for (name, mut runner) in runners {
|
||||
let id: BenchmarkId = BenchmarkId::new(name, format!("n={}", 1 << log_n));
|
||||
b.bench_with_input(id, &(), |b: &mut criterion::Bencher<'_>, _| {
|
||||
b.iter(&mut runner)
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, fft,);
|
||||
criterion_main!(benches);
|
||||
@@ -1,13 +1,7 @@
-use std::path::absolute;
-
-fn main() {
-    println!(
-        "cargo:rustc-link-search=native={}",
-        absolute("spqlios-arithmetic/build/spqlios")
-            .unwrap()
-            .to_str()
-            .unwrap()
-    );
-    println!("cargo:rustc-link-lib=static=spqlios");
-    // println!("cargo:rustc-link-lib=dylib=spqlios")
-}
+mod builds {
+    pub mod cpu_spqlios;
+}
+
+fn main() {
+    builds::cpu_spqlios::build()
+}
backend/builds/cpu_spqlios.rs (new file, 10 lines)
@@ -0,0 +1,10 @@
use std::path::PathBuf;

pub fn build() {
    let dst: PathBuf = cmake::Config::new("src/implementation/cpu_spqlios/spqlios-arithmetic").build();

    let lib_dir: PathBuf = dst.join("lib");

    println!("cargo:rustc-link-search=native={}", lib_dir.display());
    println!("cargo:rustc-link-lib=static=spqlios");
}
backend/docs/backend_safety_contract.md (new file, 27 lines)
@@ -0,0 +1,27 @@
Implementors must uphold all of the following for **every** call:

* **Memory domains**: Pointers produced by to_ref() / to_mut() must be valid
  in the target execution domain for Self (e.g., CPU host memory for CPU,
  device memory for a specific GPU). If host↔device transfers are required,
  perform them inside the implementation; do not assume the caller synchronized.

* **Alignment & layout**: All data must match the layout, stride, and element
  size expected by the kernel. size(), rows(), cols_in(), cols_out(),
  n(), etc. must be interpreted identically to the reference CPU implementation.

* **Scratch lifetime**: Any scratch obtained from scratch.tmp_slice(...) (or a
  backend-specific variant) must remain valid for the duration of the call; it
  may be reused by the caller afterwards. Do not retain pointers past return.

* **Synchronization**: The call must appear **logically synchronous** to the
  caller. If you enqueue asynchronous work (e.g., CUDA streams), you must
  ensure completion before returning, or clearly document and implement a
  synchronization contract used consistently by all backends.

* **Aliasing & overlaps**: If res, a, b, etc. alias or overlap in ways
  that violate your kernel's requirements, you must either handle this safely or
  reject it with a defined error path (e.g., a debug assert). Never trigger UB;
  a small sketch of such a check follows this document.

* **Numerical contract**: For modular/integer arithmetic, results must be
  bit-exact to the specification. For floating-point, any permitted tolerance
  must be documented and consistent with the crate's guarantees.
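A minimal sketch of such a defensive check, assuming a plain CPU kernel over raw `i64` coefficient buffers (standalone illustration with a hypothetical function name, not part of Poulpy's HAL):

```rust
/// Standalone illustration of the aliasing check suggested above.
/// `res` may fully alias `a` or `b` (in-place use); partial overlap is rejected
/// in debug builds instead of silently invoking undefined behaviour.
unsafe fn znx_add(n: usize, res: *mut i64, a: *const i64, b: *const i64) {
    #[cfg(debug_assertions)]
    {
        let bytes = n * core::mem::size_of::<i64>();
        let overlaps = |p: *const i64, q: *const i64| {
            let (p0, q0) = (p as usize, q as usize);
            p0 != q0 && p0 < q0 + bytes && q0 < p0 + bytes
        };
        assert!(!overlaps(res, a), "res and a partially overlap");
        assert!(!overlaps(res, b), "res and b partially overlap");
    }
    // The kernel itself: coefficient-wise wrapping addition over Z.
    unsafe {
        for i in 0..n {
            *res.add(i) = (*a.add(i)).wrapping_add(*b.add(i));
        }
    }
}
```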
@@ -1,56 +0,0 @@
|
||||
use backend::ffi::reim::*;
|
||||
use std::ffi::c_void;
|
||||
use std::time::Instant;
|
||||
|
||||
fn main() {
|
||||
let log_bound: usize = 19;
|
||||
|
||||
let n: usize = 2048;
|
||||
let m: usize = n >> 1;
|
||||
|
||||
let mut a: Vec<i64> = vec![i64::default(); n];
|
||||
let mut b: Vec<i64> = vec![i64::default(); n];
|
||||
let mut c: Vec<i64> = vec![i64::default(); n];
|
||||
|
||||
a.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
b[1] = 1;
|
||||
|
||||
println!("{:?}", b);
|
||||
|
||||
unsafe {
|
||||
let reim_fft_precomp = new_reim_fft_precomp(m as u32, 2);
|
||||
let reim_ifft_precomp = new_reim_ifft_precomp(m as u32, 1);
|
||||
|
||||
let buf_a = reim_fft_precomp_get_buffer(reim_fft_precomp, 0);
|
||||
let buf_b = reim_fft_precomp_get_buffer(reim_fft_precomp, 1);
|
||||
let buf_c = reim_ifft_precomp_get_buffer(reim_ifft_precomp, 0);
|
||||
|
||||
let now = Instant::now();
|
||||
(0..1024).for_each(|_| {
|
||||
reim_from_znx64_simple(m as u32, log_bound as u32, buf_a as *mut c_void, a.as_ptr());
|
||||
reim_fft(reim_fft_precomp, buf_a);
|
||||
|
||||
reim_from_znx64_simple(m as u32, log_bound as u32, buf_b as *mut c_void, b.as_ptr());
|
||||
reim_fft(reim_fft_precomp, buf_b);
|
||||
|
||||
reim_fftvec_mul_simple(
|
||||
m as u32,
|
||||
buf_c as *mut c_void,
|
||||
buf_a as *mut c_void,
|
||||
buf_b as *mut c_void,
|
||||
);
|
||||
reim_ifft(reim_ifft_precomp, buf_c);
|
||||
|
||||
reim_to_znx64_simple(
|
||||
m as u32,
|
||||
m as f64,
|
||||
log_bound as u32,
|
||||
c.as_mut_ptr(),
|
||||
buf_c as *mut c_void,
|
||||
)
|
||||
});
|
||||
|
||||
println!("time: {}us", now.elapsed().as_micros());
|
||||
println!("{:?}", &c[..16]);
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,14 @@
|
||||
use backend::{
|
||||
AddNormal, Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc,
|
||||
ScalarZnxDftOps, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
|
||||
VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxInfos,
|
||||
hal::{
|
||||
api::{
|
||||
ModuleNew, ScalarZnxAlloc, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPrepare,
|
||||
VecZnxAddNormal, VecZnxAlloc, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize,
|
||||
VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxDecodeVeci64, VecZnxDftAlloc, VecZnxDftFromVecZnx,
|
||||
VecZnxDftToVecZnxBigTmpA, VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxNormalizeInplace, ZnxInfos,
|
||||
},
|
||||
layouts::{Module, ScalarZnx, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft},
|
||||
},
|
||||
implementation::cpu_spqlios::FFT64,
|
||||
};
|
||||
use itertools::izip;
|
||||
use sampling::source::Source;
|
||||
@@ -12,35 +19,35 @@ fn main() {
|
||||
let ct_size: usize = 3;
|
||||
let msg_size: usize = 2;
|
||||
let log_scale: usize = msg_size * basek - 5;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n as u64);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(module.vec_znx_big_normalize_tmp_bytes());
|
||||
let mut scratch: ScratchOwned<FFT64> = ScratchOwned::<FFT64>::alloc(module.vec_znx_big_normalize_tmp_bytes(n));
|
||||
|
||||
let seed: [u8; 32] = [0; 32];
|
||||
let mut source: Source = Source::new(seed);
|
||||
|
||||
// s <- Z_{-1, 0, 1}[X]/(X^{N}+1)
|
||||
let mut s: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
let mut s: ScalarZnx<Vec<u8>> = module.scalar_znx_alloc(1);
|
||||
s.fill_ternary_prob(0, 0.5, &mut source);
|
||||
|
||||
// Buffer to store s in the DFT domain
|
||||
let mut s_dft: ScalarZnxDft<Vec<u8>, FFT64> = module.new_scalar_znx_dft(s.cols());
|
||||
let mut s_dft: SvpPPol<Vec<u8>, FFT64> = module.svp_ppol_alloc(s.cols());
|
||||
|
||||
// s_dft <- DFT(s)
|
||||
module.svp_prepare(&mut s_dft, 0, &s, 0);
|
||||
|
||||
// Allocates a VecZnx with two columns: ct=(0, 0)
|
||||
let mut ct: VecZnx<Vec<u8>> = module.new_vec_znx(
|
||||
let mut ct: VecZnx<Vec<u8>> = module.vec_znx_alloc(
|
||||
2, // Number of columns
|
||||
ct_size, // Number of small poly per column
|
||||
);
|
||||
|
||||
// Fill the second column with random values: ct = (0, a)
|
||||
ct.fill_uniform(basek, 1, ct_size, &mut source);
|
||||
module.vec_znx_fill_uniform(basek, &mut ct, 1, ct_size * basek, &mut source);
|
||||
|
||||
let mut buf_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_size);
|
||||
let mut buf_dft: VecZnxDft<Vec<u8>, FFT64> = module.vec_znx_dft_alloc(1, ct_size);
|
||||
|
||||
module.vec_znx_dft(1, 0, &mut buf_dft, 0, &ct, 1);
|
||||
module.vec_znx_dft_from_vec_znx(1, 0, &mut buf_dft, 0, &ct, 1);
|
||||
|
||||
// Applies DFT(ct[1]) * DFT(s)
|
||||
module.svp_apply_inplace(
|
||||
@@ -53,18 +60,18 @@ fn main() {
|
||||
// Alias scratch space (VecZnxDft<B> is always at least as big as VecZnxBig<B>)
|
||||
|
||||
// BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized)
|
||||
let mut buf_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_size);
|
||||
module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
|
||||
let mut buf_big: VecZnxBig<Vec<u8>, FFT64> = module.vec_znx_big_alloc(1, ct_size);
|
||||
module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
|
||||
|
||||
// Creates a plaintext: VecZnx with 1 column
|
||||
let mut m = module.new_vec_znx(
|
||||
let mut m = module.vec_znx_alloc(
|
||||
1, // Number of columns
|
||||
msg_size, // Number of small polynomials
|
||||
);
|
||||
let mut want: Vec<i64> = vec![0; n];
|
||||
want.iter_mut()
|
||||
.for_each(|x| *x = source.next_u64n(16, 15) as i64);
|
||||
m.encode_vec_i64(0, basek, log_scale, &want, 4);
|
||||
module.encode_vec_i64(basek, &mut m, 0, log_scale, &want, 4);
|
||||
module.vec_znx_normalize_inplace(basek, &mut m, 0, scratch.borrow());
|
||||
|
||||
// m - BIG(ct[1] * s)
|
||||
@@ -88,13 +95,14 @@ fn main() {
|
||||
|
||||
// Add noise to ct[0]
|
||||
// ct[0] <- ct[0] + e
|
||||
ct.add_normal(
|
||||
module.vec_znx_add_normal(
|
||||
basek,
|
||||
&mut ct,
|
||||
0, // Selects the first column of ct (ct[0])
|
||||
basek * ct_size, // Scaling of the noise: 2^{-basek * limbs}
|
||||
&mut source,
|
||||
3.2, // Standard deviation
|
||||
19.0, // Truncation bound
|
||||
3.2, // Standard deviation
|
||||
3.2 * 6.0, // Truncation bound
|
||||
);
|
||||
|
||||
// Final ciphertext: ct = (-a * s + m + e, a)
|
||||
@@ -102,7 +110,7 @@ fn main() {
|
||||
// Decryption
|
||||
|
||||
// DFT(ct[1] * s)
|
||||
module.vec_znx_dft(1, 0, &mut buf_dft, 0, &ct, 1);
|
||||
module.vec_znx_dft_from_vec_znx(1, 0, &mut buf_dft, 0, &ct, 1);
|
||||
module.svp_apply_inplace(
|
||||
&mut buf_dft,
|
||||
0, // Selects the first column of res.
|
||||
@@ -111,18 +119,18 @@ fn main() {
|
||||
);
|
||||
|
||||
// BIG(c1 * s) = IDFT(DFT(c1 * s))
|
||||
module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
|
||||
module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
|
||||
|
||||
// BIG(c1 * s) + ct[0]
|
||||
module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0);
|
||||
|
||||
// m + e <- BIG(ct[1] * s + ct[0])
|
||||
let mut res = module.new_vec_znx(1, ct_size);
|
||||
let mut res = module.vec_znx_alloc(1, ct_size);
|
||||
module.vec_znx_big_normalize(basek, &mut res, 0, &buf_big, 0, scratch.borrow());
|
||||
|
||||
// have = m * 2^{log_scale} + e
|
||||
let mut have: Vec<i64> = vec![i64::default(); n];
|
||||
res.decode_vec_i64(0, basek, res.size() * basek, &mut have);
|
||||
module.decode_vec_i64(basek, &mut res, 0, ct_size * basek, &mut have);
|
||||
|
||||
let scale: f64 = (1 << (res.size() * basek - log_scale)) as f64;
|
||||
izip!(want.iter(), have.iter())
|
||||
|
||||
Submodule backend/spqlios-arithmetic deleted from 0ae9a7b5ad
@@ -1,344 +0,0 @@
|
||||
use crate::ffi::znx::znx_zero_i64_ref;
|
||||
use crate::znx_base::{ZnxView, ZnxViewMut};
|
||||
use crate::{VecZnx, znx_base::ZnxInfos};
|
||||
use itertools::izip;
|
||||
use rug::{Assign, Float};
|
||||
use std::cmp::min;
|
||||
|
||||
pub trait Encoding {
|
||||
/// encode a vector of i64 on the receiver.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `col_i`: the index of the poly where to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `k`: base two negative logarithm of the scaling of the data.
|
||||
/// * `data`: data to encode on the receiver.
|
||||
/// * `log_max`: base two logarithm of the infinity norm of the input data.
|
||||
fn encode_vec_i64(&mut self, col_i: usize, basek: usize, k: usize, data: &[i64], log_max: usize);
|
||||
|
||||
/// encodes a single i64 on the receiver at the given index.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `col_i`: the index of the poly where to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `k`: base two negative logarithm of the scaling of the data.
|
||||
/// * `i`: index of the coefficient on which to encode the data.
|
||||
/// * `data`: data to encode on the receiver.
|
||||
/// * `log_max`: base two logarithm of the infinity norm of the input data.
|
||||
fn encode_coeff_i64(&mut self, col_i: usize, basek: usize, k: usize, i: usize, data: i64, log_max: usize);
|
||||
}
|
||||
|
||||
pub trait Decoding {
|
||||
/// decode a vector of i64 from the receiver.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `col_i`: the index of the poly where to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `k`: base two logarithm of the scaling of the data.
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_vec_i64(&self, col_i: usize, basek: usize, k: usize, data: &mut [i64]);
|
||||
|
||||
/// decode a vector of Float from the receiver.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `col_i`: the index of the poly where to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_vec_float(&self, col_i: usize, basek: usize, data: &mut [Float]);
|
||||
|
||||
/// decode a single of i64 from the receiver at the given index.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `col_i`: the index of the poly where to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `k`: base two negative logarithm of the scaling of the data.
|
||||
/// * `i`: index of the coefficient to decode.
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_coeff_i64(&self, col_i: usize, basek: usize, k: usize, i: usize) -> i64;
|
||||
}
|
||||
|
||||
impl<D: AsMut<[u8]> + AsRef<[u8]>> Encoding for VecZnx<D> {
|
||||
fn encode_vec_i64(&mut self, col_i: usize, basek: usize, k: usize, data: &[i64], log_max: usize) {
|
||||
encode_vec_i64(self, col_i, basek, k, data, log_max)
|
||||
}
|
||||
|
||||
fn encode_coeff_i64(&mut self, col_i: usize, basek: usize, k: usize, i: usize, value: i64, log_max: usize) {
|
||||
encode_coeff_i64(self, col_i, basek, k, i, value, log_max)
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> Decoding for VecZnx<D> {
|
||||
fn decode_vec_i64(&self, col_i: usize, basek: usize, k: usize, data: &mut [i64]) {
|
||||
decode_vec_i64(self, col_i, basek, k, data)
|
||||
}
|
||||
|
||||
fn decode_vec_float(&self, col_i: usize, basek: usize, data: &mut [Float]) {
|
||||
decode_vec_float(self, col_i, basek, data)
|
||||
}
|
||||
|
||||
fn decode_coeff_i64(&self, col_i: usize, basek: usize, k: usize, i: usize) -> i64 {
|
||||
decode_coeff_i64(self, col_i, basek, k, i)
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_vec_i64<D: AsMut<[u8]> + AsRef<[u8]>>(
|
||||
a: &mut VecZnx<D>,
|
||||
col_i: usize,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
data: &[i64],
|
||||
log_max: usize,
|
||||
) {
|
||||
let size: usize = (k + basek - 1) / basek;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
size <= a.size(),
|
||||
"invalid argument k: (k + a.basek - 1)/a.basek={} > a.size()={}",
|
||||
size,
|
||||
a.size()
|
||||
);
|
||||
assert!(col_i < a.cols());
|
||||
assert!(data.len() <= a.n())
|
||||
}
|
||||
|
||||
let data_len: usize = data.len();
|
||||
let k_rem: usize = basek - (k % basek);
|
||||
|
||||
// Zeroes coefficients of the i-th column
|
||||
(0..a.size()).for_each(|i| unsafe {
|
||||
znx_zero_i64_ref(a.n() as u64, a.at_mut_ptr(col_i, i));
|
||||
});
|
||||
|
||||
// If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
|
||||
// values on the last limb.
|
||||
// Else we decompose values base2k.
|
||||
if log_max + k_rem < 63 || k_rem == basek {
|
||||
a.at_mut(col_i, size - 1)[..data_len].copy_from_slice(&data[..data_len]);
|
||||
} else {
|
||||
let mask: i64 = (1 << basek) - 1;
|
||||
let steps: usize = min(size, (log_max + basek - 1) / basek);
|
||||
(size - steps..size)
|
||||
.rev()
|
||||
.enumerate()
|
||||
.for_each(|(i, i_rev)| {
|
||||
let shift: usize = i * basek;
|
||||
izip!(a.at_mut(col_i, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask);
|
||||
})
|
||||
}
|
||||
|
||||
// Case where self.prec % self.k != 0.
|
||||
if k_rem != basek {
|
||||
let steps: usize = min(size, (log_max + basek - 1) / basek);
|
||||
(size - steps..size).rev().for_each(|i| {
|
||||
a.at_mut(col_i, i)[..data_len]
|
||||
.iter_mut()
|
||||
.for_each(|x| *x <<= k_rem);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_vec_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, k: usize, data: &mut [i64]) {
|
||||
let size: usize = (k + basek - 1) / basek;
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
data.len() >= a.n(),
|
||||
"invalid data: data.len()={} < a.n()={}",
|
||||
data.len(),
|
||||
a.n()
|
||||
);
|
||||
assert!(col_i < a.cols());
|
||||
}
|
||||
data.copy_from_slice(a.at(col_i, 0));
|
||||
let rem: usize = basek - (k % basek);
|
||||
if k < basek {
|
||||
data.iter_mut().for_each(|x| *x >>= rem);
|
||||
} else {
|
||||
(1..size).for_each(|i| {
|
||||
if i == size - 1 && rem != basek {
|
||||
let k_rem: usize = basek - rem;
|
||||
izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
*y = (*y << k_rem) + (x >> rem);
|
||||
});
|
||||
} else {
|
||||
izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
*y = (*y << basek) + x;
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_vec_float<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, data: &mut [Float]) {
|
||||
let size: usize = a.size();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
data.len() >= a.n(),
|
||||
"invalid data: data.len()={} < a.n()={}",
|
||||
data.len(),
|
||||
a.n()
|
||||
);
|
||||
assert!(col_i < a.cols());
|
||||
}
|
||||
|
||||
let prec: u32 = (basek * size) as u32;
|
||||
|
||||
// 2^{basek}
|
||||
let base = Float::with_val(prec, (1 << basek) as f64);
|
||||
|
||||
// y[i] = sum x[j][i] * 2^{-basek*j}
|
||||
(0..size).for_each(|i| {
|
||||
if i == 0 {
|
||||
izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
y.assign(*x);
|
||||
*y /= &base;
|
||||
});
|
||||
} else {
|
||||
izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
*y += Float::with_val(prec, *x);
|
||||
*y /= &base;
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn encode_coeff_i64<D: AsMut<[u8]> + AsRef<[u8]>>(
|
||||
a: &mut VecZnx<D>,
|
||||
col_i: usize,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
i: usize,
|
||||
value: i64,
|
||||
log_max: usize,
|
||||
) {
|
||||
let size: usize = (k + basek - 1) / basek;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(i < a.n());
|
||||
assert!(
|
||||
size <= a.size(),
|
||||
"invalid argument k: (k + a.basek - 1)/a.basek={} > a.size()={}",
|
||||
size,
|
||||
a.size()
|
||||
);
|
||||
assert!(col_i < a.cols());
|
||||
}
|
||||
|
||||
let k_rem: usize = basek - (k % basek);
|
||||
(0..a.size()).for_each(|j| a.at_mut(col_i, j)[i] = 0);
|
||||
|
||||
// If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
|
||||
// values on the last limb.
|
||||
// Else we decompose values base2k.
|
||||
if log_max + k_rem < 63 || k_rem == basek {
|
||||
a.at_mut(col_i, size - 1)[i] = value;
|
||||
} else {
|
||||
let mask: i64 = (1 << basek) - 1;
|
||||
let steps: usize = min(size, (log_max + basek - 1) / basek);
|
||||
(size - steps..size)
|
||||
.rev()
|
||||
.enumerate()
|
||||
.for_each(|(j, j_rev)| {
|
||||
a.at_mut(col_i, j_rev)[i] = (value >> (j * basek)) & mask;
|
||||
})
|
||||
}
|
||||
|
||||
// Case where prec % k != 0.
|
||||
if k_rem != basek {
|
||||
let steps: usize = min(size, (log_max + basek - 1) / basek);
|
||||
(size - steps..size).rev().for_each(|j| {
|
||||
a.at_mut(col_i, j)[i] <<= k_rem;
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_coeff_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, k: usize, i: usize) -> i64 {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(i < a.n());
|
||||
assert!(col_i < a.cols())
|
||||
}
|
||||
|
||||
let size: usize = (k + basek - 1) / basek;
|
||||
let data: &[i64] = a.raw();
|
||||
let mut res: i64 = 0;
|
||||
let rem: usize = basek - (k % basek);
|
||||
let slice_size: usize = a.n() * a.cols();
|
||||
(0..size).for_each(|j| {
|
||||
let x: i64 = data[j * slice_size + i];
|
||||
if j == size - 1 && rem != basek {
|
||||
let k_rem: usize = basek - rem;
|
||||
res = (res << k_rem) + (x >> rem);
|
||||
} else {
|
||||
res = (res << basek) + x;
|
||||
}
|
||||
});
|
||||
res
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::vec_znx_ops::*;
|
||||
use crate::znx_base::*;
|
||||
use crate::{Decoding, Encoding, FFT64, Module, VecZnx, znx_base::ZnxInfos};
|
||||
use itertools::izip;
|
||||
use sampling::source::Source;
|
||||
|
||||
#[test]
|
||||
fn test_set_get_i64_lo_norm() {
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 17;
|
||||
let size: usize = 5;
|
||||
let k: usize = size * basek - 5;
|
||||
let mut a: VecZnx<_> = module.new_vec_znx(2, size);
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let raw: &mut [i64] = a.raw_mut();
|
||||
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
(0..a.cols()).for_each(|col_i| {
|
||||
let mut have: Vec<i64> = vec![i64::default(); n];
|
||||
have.iter_mut()
|
||||
.for_each(|x| *x = (source.next_i64() << 56) >> 56);
|
||||
a.encode_vec_i64(col_i, basek, k, &have, 10);
|
||||
let mut want: Vec<i64> = vec![i64::default(); n];
|
||||
a.decode_vec_i64(col_i, basek, k, &mut want);
|
||||
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_set_get_i64_hi_norm() {
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 17;
|
||||
let size: usize = 5;
|
||||
for k in [1, basek / 2, size * basek - 5] {
|
||||
let mut a: VecZnx<_> = module.new_vec_znx(2, size);
|
||||
let mut source = Source::new([0u8; 32]);
|
||||
let raw: &mut [i64] = a.raw_mut();
|
||||
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
(0..a.cols()).for_each(|col_i| {
|
||||
let mut have: Vec<i64> = vec![i64::default(); n];
|
||||
have.iter_mut().for_each(|x| {
|
||||
if k < 64 {
|
||||
*x = source.next_u64n(1 << k, (1 << k) - 1) as i64;
|
||||
} else {
|
||||
*x = source.next_i64();
|
||||
}
|
||||
});
|
||||
a.encode_vec_i64(col_i, basek, k, &have, std::cmp::min(k, 64));
|
||||
let mut want = vec![i64::default(); n];
|
||||
a.decode_vec_i64(col_i, basek, k, &mut want);
|
||||
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
backend/src/hal/api/mat_znx.rs (new file, 17 lines)
@@ -0,0 +1,17 @@
use crate::hal::layouts::MatZnxOwned;

/// Allocates a [crate::hal::layouts::MatZnx].
pub trait MatZnxAlloc {
    fn mat_znx_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned;
}

/// Returns the size in bytes to allocate a [crate::hal::layouts::MatZnx].
pub trait MatZnxAllocBytes {
    fn mat_znx_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
}

/// Consumes a vector of bytes into a [crate::hal::layouts::MatZnx].
/// The user must ensure that the bytes are memory aligned and that their length is equal to [MatZnxAllocBytes].
pub trait MatZnxFromBytes {
    fn mat_znx_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> MatZnxOwned;
}
backend/src/hal/api/mod.rs (new file, 21 lines)
@@ -0,0 +1,21 @@
mod mat_znx;
mod module;
mod scalar_znx;
mod scratch;
mod svp_ppol;
mod vec_znx;
mod vec_znx_big;
mod vec_znx_dft;
mod vmp_pmat;
mod znx_base;

pub use mat_znx::*;
pub use module::*;
pub use scalar_znx::*;
pub use scratch::*;
pub use svp_ppol::*;
pub use vec_znx::*;
pub use vec_znx_big::*;
pub use vec_znx_dft::*;
pub use vmp_pmat::*;
pub use znx_base::*;
backend/src/hal/api/module.rs (new file, 6 lines)
@@ -0,0 +1,6 @@
use crate::hal::layouts::{Backend, Module};

/// Instantiate a new [crate::hal::layouts::Module].
pub trait ModuleNew<B: Backend> {
    fn new(n: u64) -> Module<B>;
}
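For reference, the FFT64 example updated elsewhere in this commit instantiates a module through this trait from the ring degree:

```rust
// Excerpt from the FFT64 example updated in this commit:
// `n` is the ring degree; `ModuleNew` takes it as a u64.
let module: Module<FFT64> = Module::<FFT64>::new(n as u64);
```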
backend/src/hal/api/scalar_znx.rs (new file, 47 lines)
@@ -0,0 +1,47 @@
use crate::hal::layouts::{ScalarZnxOwned, ScalarZnxToMut, ScalarZnxToRef};

/// Allocates a [crate::hal::layouts::ScalarZnx].
pub trait ScalarZnxAlloc {
    fn scalar_znx_alloc(&self, cols: usize) -> ScalarZnxOwned;
}

/// Returns the size in bytes to allocate a [crate::hal::layouts::ScalarZnx].
pub trait ScalarZnxAllocBytes {
    fn scalar_znx_alloc_bytes(&self, cols: usize) -> usize;
}

/// Consumes a vector of bytes into a [crate::hal::layouts::ScalarZnx].
/// The user must ensure that the bytes are memory aligned and that their length is equal to [ScalarZnxAllocBytes].
pub trait ScalarZnxFromBytes {
    fn scalar_znx_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned;
}

/// Applies the mapping X -> X^k to a\[a_col\] and writes the result on res\[res_col\].
pub trait ScalarZnxAutomorphism {
    fn scalar_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: ScalarZnxToMut,
        A: ScalarZnxToRef;
}

/// Applies the mapping X -> X^k on res\[res_col\].
pub trait ScalarZnxAutomorphismInplace {
    fn scalar_znx_automorphism_inplace<R>(&self, k: i64, res: &mut R, res_col: usize)
    where
        R: ScalarZnxToMut;
}

/// Multiplies a\[a_col\] by (X^p - 1) and writes the result on res\[res_col\].
pub trait ScalarZnxMulXpMinusOne {
    fn scalar_znx_mul_xp_minus_one<R, A>(&self, p: i64, r: &mut R, r_col: usize, a: &A, a_col: usize)
    where
        R: ScalarZnxToMut,
        A: ScalarZnxToRef;
}

/// Multiplies res\[res_col\] by (X^p - 1).
pub trait ScalarZnxMulXpMinusOneInplace {
    fn scalar_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, res: &mut R, res_col: usize)
    where
        R: ScalarZnxToMut;
}
113
backend/src/hal/api/scratch.rs
Normal file
113
backend/src/hal/api/scratch.rs
Normal file
@@ -0,0 +1,113 @@
|
||||
use crate::hal::layouts::{Backend, MatZnx, Module, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat};
|
||||
|
||||
/// Allocates a new [crate::hal::layouts::ScratchOwned] of `size` aligned bytes.
|
||||
pub trait ScratchOwnedAlloc<B: Backend> {
|
||||
fn alloc(size: usize) -> Self;
|
||||
}
|
||||
|
||||
/// Borrows a slice of bytes into a [Scratch].
|
||||
pub trait ScratchOwnedBorrow<B: Backend> {
|
||||
fn borrow(&mut self) -> &mut Scratch<B>;
|
||||
}
|
||||
|
||||
/// Wrap an array of mutable borrowed bytes into a [Scratch].
|
||||
pub trait ScratchFromBytes<B: Backend> {
|
||||
fn from_bytes(data: &mut [u8]) -> &mut Scratch<B>;
|
||||
}
|
||||
|
||||
/// Returns how many bytes left can be taken from the scratch.
|
||||
pub trait ScratchAvailable {
|
||||
fn available(&self) -> usize;
|
||||
}
|
||||
|
||||
/// Takes a slice of bytes from a [Scratch] and returns a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeSlice {
|
||||
fn take_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self);
|
||||
}
|
||||
|
||||
/// Take a slice of bytes from a [Scratch], wraps it into a [ScalarZnx] and returns it
|
||||
/// as well as a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeScalarZnx<B: Backend> {
|
||||
fn take_scalar_znx(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self);
|
||||
}
|
||||
|
||||
/// Take a slice of bytes from a [Scratch], wraps it into a [SvpPPol] and returns it
|
||||
/// as well as a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeSvpPPol<B: Backend> {
|
||||
fn take_svp_ppol(&mut self, module: &Module<B>, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self);
|
||||
}
|
||||
|
||||
/// Take a slice of bytes from a [Scratch], wraps it into a [VecZnx] and returns it
|
||||
/// as well as a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeVecZnx<B: Backend> {
|
||||
fn take_vec_znx(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self);
|
||||
}
|
||||
|
||||
/// Take a slice of bytes from a [Scratch], slices it into a vector of [VecZnx] and returns it
|
||||
/// as well as a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeVecZnxSlice<B: Backend> {
|
||||
fn take_vec_znx_slice(
|
||||
&mut self,
|
||||
len: usize,
|
||||
module: &Module<B>,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnx<&mut [u8]>>, &mut Self);
|
||||
}
|
||||
|
||||
/// Take a slice of bytes from a [Scratch], wraps it into a [VecZnxBig] and returns it
|
||||
/// as well as a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeVecZnxBig<B: Backend> {
|
||||
fn take_vec_znx_big(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self);
|
||||
}
|
||||
|
||||
/// Take a slice of bytes from a [Scratch], wraps it into a [VecZnxDft] and returns it
|
||||
/// as well as a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeVecZnxDft<B: Backend> {
|
||||
fn take_vec_znx_dft(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self);
|
||||
}
|
||||
|
||||
/// Take a slice of bytes from a [Scratch], slices it into a vector of [VecZnxDft] and returns it
|
||||
/// as well as a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeVecZnxDftSlice<B: Backend> {
|
||||
fn take_vec_znx_dft_slice(
|
||||
&mut self,
|
||||
len: usize,
|
||||
module: &Module<B>,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Self);
|
||||
}
|
||||
|
||||
/// Take a slice of bytes from a [Scratch], wraps it into a [VmpPMat] and returns it
|
||||
/// as well as a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeVmpPMat<B: Backend> {
|
||||
fn take_vmp_pmat(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (VmpPMat<&mut [u8], B>, &mut Self);
|
||||
}
|
||||
|
||||
/// Take a slice of bytes from a [Scratch], wraps it into a [MatZnx] and returns it
|
||||
/// as well as a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeMatZnx<B: Backend> {
|
||||
fn take_mat_znx(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (MatZnx<&mut [u8]>, &mut Self);
|
||||
}
|
||||
|
||||
/// Take a slice of bytes from a [Scratch], wraps it into the template's type and returns it
|
||||
/// as well as a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeLike<'a, B: Backend, T> {
|
||||
type Output;
|
||||
fn take_like(&'a mut self, template: &T) -> (Self::Output, &'a mut Self);
|
||||
}
|
||||
backend/src/hal/api/svp_ppol.rs (new file, 42 lines)
@@ -0,0 +1,42 @@
use crate::hal::layouts::{Backend, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef};

/// Allocates a [crate::hal::layouts::SvpPPol].
pub trait SvpPPolAlloc<B: Backend> {
    fn svp_ppol_alloc(&self, cols: usize) -> SvpPPolOwned<B>;
}

/// Returns the size in bytes to allocate a [crate::hal::layouts::SvpPPol].
pub trait SvpPPolAllocBytes {
    fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize;
}

/// Consumes a vector of bytes into a [crate::hal::layouts::SvpPPol].
/// The user must ensure that the bytes are memory aligned and that their length is equal to [SvpPPolAllocBytes].
pub trait SvpPPolFromBytes<B: Backend> {
    fn svp_ppol_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B>;
}

/// Prepares a [crate::hal::layouts::ScalarZnx] into an [crate::hal::layouts::SvpPPol].
pub trait SvpPrepare<B: Backend> {
    fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: SvpPPolToMut<B>,
        A: ScalarZnxToRef;
}

/// Applies a scalar-vector product between `a[a_col]` and `b[b_col]` and stores the result on `res[res_col]`.
pub trait SvpApply<B: Backend> {
    fn svp_apply<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
    where
        R: VecZnxDftToMut<B>,
        A: SvpPPolToRef<B>,
        C: VecZnxDftToRef<B>;
}

/// Applies a scalar-vector product between `res[res_col]` and `a[a_col]` and stores the result on `res[res_col]`.
pub trait SvpApplyInplace<B: Backend> {
    fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxDftToMut<B>,
        A: SvpPPolToRef<B>;
}
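For reference, the FFT64 example updated in this commit exercises these traits roughly as follows (condensed excerpt; `s` is a ternary `ScalarZnx` secret and `buf_dft` a `VecZnxDft` buffer):

```rust
// Condensed excerpt from the FFT64 example updated in this commit.
let mut s_dft: SvpPPol<Vec<u8>, FFT64> = module.svp_ppol_alloc(s.cols());
module.svp_prepare(&mut s_dft, 0, &s, 0);             // s_dft <- DFT(s)
module.svp_apply_inplace(&mut buf_dft, 0, &s_dft, 0); // buf_dft[0] *= s_dft[0]
```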
369
backend/src/hal/api/vec_znx.rs
Normal file
369
backend/src/hal/api/vec_znx.rs
Normal file
@@ -0,0 +1,369 @@
|
||||
use rand_distr::Distribution;
|
||||
use rug::Float;
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::hal::layouts::{Backend, ScalarZnxToRef, Scratch, VecZnxOwned, VecZnxToMut, VecZnxToRef};
|
||||
|
||||
pub trait VecZnxAlloc {
|
||||
/// Allocates a new [crate::hal::layouts::VecZnx].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of polynomials.
|
||||
/// * `size`: the number small polynomials per column.
|
||||
fn vec_znx_alloc(&self, cols: usize, size: usize) -> VecZnxOwned;
|
||||
}
|
||||
|
||||
pub trait VecZnxFromBytes {
|
||||
/// Instantiates a new [crate::hal::layouts::VecZnx] from a slice of bytes.
|
||||
/// The returned [crate::hal::layouts::VecZnx] takes ownership of the slice of bytes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of polynomials.
|
||||
/// * `size`: the number small polynomials per column.
|
||||
fn vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned;
|
||||
}
|
||||
|
||||
pub trait VecZnxAllocBytes {
|
||||
/// Returns the number of bytes necessary to allocate a new [crate::hal::layouts::VecZnx].
|
||||
fn vec_znx_alloc_bytes(&self, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxNormalizeTmpBytes {
|
||||
/// Returns the minimum number of bytes necessary for normalization.
|
||||
fn vec_znx_normalize_tmp_bytes(&self, n: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxNormalize<B: Backend> {
|
||||
/// Normalizes the selected column of `a` and stores the result into the selected column of `res`.
|
||||
fn vec_znx_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxNormalizeInplace<B: Backend> {
|
||||
/// Normalizes the selected column of `a`.
|
||||
fn vec_znx_normalize_inplace<A>(&self, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
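// Illustrative sketch, not part of this commit: normalization needs a scratch buffer whose
// size is given by [VecZnxNormalizeTmpBytes]. ScratchOwned and its alloc/borrow traits are
// part of the HAL delegates added by this commit; `module`, `basek` and `a` are assumed to
// exist, and `module.n()` is the ring degree.
//
// let mut scratch = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes(module.n()));
// module.vec_znx_normalize_inplace(basek, &mut a, 0, scratch.borrow());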
|
||||
|
||||
pub trait VecZnxAdd {
|
||||
/// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`.
|
||||
fn vec_znx_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxAddInplace {
|
||||
/// Adds the selected column of `a` to the selected column of `res` and writes the result on the selected column of `res`.
|
||||
fn vec_znx_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxAddScalarInplace {
|
||||
/// Adds the selected column of `a` to the selected column and limb of `res`.
|
||||
fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxSub {
|
||||
/// Subtracts the selected column of `b` from the selected column of `a` and writes the result on the selected column of `res`.
|
||||
fn vec_znx_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxSubABInplace {
|
||||
/// Subtracts the selected column of `a` from the selected column of `res` inplace.
|
||||
///
|
||||
/// res\[res_col\] -= a\[a_col\]
|
||||
fn vec_znx_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxSubBAInplace {
|
||||
/// Subtracts the selected column of `res` from the selected column of `a` and writes the result back to `res`.
|
||||
///
|
||||
/// res\[res_col\] = a\[a_col\] - res\[res_col\]
|
||||
fn vec_znx_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxSubScalarInplace {
|
||||
/// Subtracts the selected column of `a` from the selected column and limb of `res`.
|
||||
fn vec_znx_sub_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxNegate {
|
||||
/// Negates the selected column of `a` and stores the result in `res_col` of `res`.
|
||||
fn vec_znx_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxNegateInplace {
|
||||
/// Negates the selected column of `a`.
|
||||
fn vec_znx_negate_inplace<A>(&self, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxLshInplace {
|
||||
/// Left-shifts all columns of `a` by `k` bits.
|
||||
fn vec_znx_lsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxRshInplace {
|
||||
/// Right-shifts all columns of `a` by `k` bits.
|
||||
fn vec_znx_rsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxRotate {
|
||||
/// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`.
|
||||
fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxRotateInplace {
|
||||
/// Multiplies the selected column of `a` by X^k.
|
||||
fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxAutomorphism {
|
||||
/// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result in `res_col` column of `res`.
|
||||
fn vec_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxAutomorphismInplace {
|
||||
/// Applies the automorphism X^i -> X^ik on the selected column of `a`.
|
||||
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxMulXpMinusOne {
|
||||
fn vec_znx_mul_xp_minus_one<R, A>(&self, p: i64, r: &mut R, r_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxMulXpMinusOneInplace {
|
||||
fn vec_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, r: &mut R, r_col: usize)
|
||||
where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxSplit<B: Backend> {
|
||||
/// Splits the selected column of `a` into subrings and copies them into the selected columns of `res`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This method requires that all [crate::hal::layouts::VecZnx] of `res` have the same ring degree
/// and that res.n() * res.len() <= a.n()
|
||||
fn vec_znx_split<R, A>(&self, res: &mut Vec<R>, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxMerge {
|
||||
/// Merges the subrings of the selected column of `a` into the selected column of `res`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This method requires that all [crate::hal::layouts::VecZnx] of `a` have the same ring degree
/// and that a.n() * a.len() <= res.n()
|
||||
fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: Vec<A>, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
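// Illustrative sketch, not part of this commit: a split/merge round trip between a ring of
// degree n and two subrings of degree n/2. `module_n` and `module_half` are assumed to be
// Modules of the matching degrees, `a` a VecZnx over the large ring and `scratch` a scratch
// buffer of sufficient size.
//
// let mut parts = vec![
//     module_half.vec_znx_alloc(a.cols(), a.size()),
//     module_half.vec_znx_alloc(a.cols(), a.size()),
// ];
// module_n.vec_znx_split(&mut parts, 0, &a, 0, scratch); // large ring -> 2 subrings
// module_n.vec_znx_merge(&mut a, 0, parts, 0);           // subrings -> large ring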
|
||||
|
||||
pub trait VecZnxSwithcDegree {
|
||||
fn vec_znx_switch_degree<R, A>(&self, res: &mut R, res_col: usize, a: &A, col_a: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxCopy {
|
||||
fn vec_znx_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxStd {
|
||||
/// Returns the standard deviation of the `a_col`-th polynomial of `a`.
|
||||
fn vec_znx_std<A>(&self, basek: usize, a: &A, a_col: usize) -> f64
|
||||
where
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxFillUniform {
|
||||
/// Fills the selected column of `res` with uniform values in \[-2^{basek-1}, 2^{basek-1}\].
|
||||
fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
|
||||
where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxFillDistF64 {
|
||||
fn vec_znx_fill_dist_f64<R, D: Distribution<f64>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxAddDistF64 {
|
||||
/// Adds a vector sampled according to the provided distribution, scaled by 2^{-k} and bounded to \[-bound, bound\].
|
||||
fn vec_znx_add_dist_f64<R, D: Distribution<f64>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxFillNormal {
|
||||
fn vec_znx_fill_normal<R>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxAddNormal {
|
||||
/// Adds a discrete normal vector scaled by 2^{-k} with the provided standard deviation and bounded to \[-bound, bound\].
|
||||
fn vec_znx_add_normal<R>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
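// Illustrative sketch, not part of this commit: a fresh error term is typically added to a
// column with [VecZnxAddNormal]. `module`, `basek`, `k`, `res` and `source` are assumed to
// exist; the sigma and rejection bound below are illustrative values only.
//
// module.vec_znx_add_normal(basek, &mut res, 0, k, &mut source, 3.2, 19.0);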
|
||||
|
||||
pub trait VecZnxEncodeVeci64 {
|
||||
/// Encodes a vector of i64 on the receiver.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `res_col`: the index of the poly on which to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `k`: base two negative logarithm of the scaling of the data.
|
||||
/// * `data`: data to encode on the receiver.
|
||||
/// * `log_max`: base two logarithm of the infinity norm of the input data.
|
||||
fn encode_vec_i64<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, data: &[i64], log_max: usize)
|
||||
where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxEncodeCoeffsi64 {
|
||||
/// Encodes a single i64 on the receiver at the given index.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `res_col`: the index of the poly where to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `k`: base two negative logarithm of the scaling of the data.
|
||||
/// * `i`: index of the coefficient on which to encode the data.
|
||||
/// * `data`: data to encode on the receiver.
|
||||
/// * `log_max`: base two logarithm of the infinity norm of the input data.
|
||||
fn encode_coeff_i64<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, i: usize, data: i64, log_max: usize)
|
||||
where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxDecodeVeci64 {
|
||||
/// Decodes a vector of i64 from the receiver.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `res_col`: the index of the poly from which to decode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `k`: base two logarithm of the scaling of the data.
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_vec_i64<R>(&self, basek: usize, res: &R, res_col: usize, k: usize, data: &mut [i64])
|
||||
where
|
||||
R: VecZnxToRef;
|
||||
}
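// Illustrative sketch, not part of this commit: round-tripping a small message through the
// encode/decode traits above. `module` is assumed to implement VecZnxAlloc,
// VecZnxEncodeVeci64 and VecZnxDecodeVeci64; basek, k and log_max are illustrative values.
//
// let (basek, k): (usize, usize) = (17, 34);
// let data: Vec<i64> = vec![1, -2, 3]; // at most module.n() coefficients
// let mut a = module.vec_znx_alloc(1, k.div_ceil(basek));
// module.encode_vec_i64(basek, &mut a, 0, k, &data, 2); // log_max = 2 since |data|_inf <= 3
// let mut out = vec![0i64; data.len()];
// module.decode_vec_i64(basek, &a, 0, k, &mut out);     // out == data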
|
||||
|
||||
pub trait VecZnxDecodeCoeffsi64 {
|
||||
/// Decodes a single i64 from the receiver at the given index.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `res_col`: the index of the poly from which to decode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `k`: base two negative logarithm of the scaling of the data.
|
||||
/// * `i`: index of the coefficient to decode.
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_coeff_i64<R>(&self, basek: usize, res: &R, res_col: usize, k: usize, i: usize) -> i64
|
||||
where
|
||||
R: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxDecodeVecFloat {
|
||||
/// Decodes a vector of Float from the receiver.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `col_i`: the index of the poly from which to decode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_vec_float<R>(&self, basek: usize, res: &R, col_i: usize, data: &mut [Float])
|
||||
where
|
||||
R: VecZnxToRef;
|
||||
}
|
||||
214
backend/src/hal/api/vec_znx_big.rs
Normal file
@@ -0,0 +1,214 @@
|
||||
use rand_distr::Distribution;
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::hal::layouts::{Backend, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef};
|
||||
|
||||
/// Allocates a [crate::hal::layouts::VecZnxBig].
|
||||
pub trait VecZnxBigAlloc<B: Backend> {
|
||||
fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned<B>;
|
||||
}
|
||||
|
||||
/// Returns the size in bytes to allocate a [crate::hal::layouts::VecZnxBig].
|
||||
pub trait VecZnxBigAllocBytes {
|
||||
fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
/// Consume a vector of bytes into a [crate::hal::layouts::VecZnxBig].
|
||||
/// User must ensure that bytes is memory aligned and that its length is equal to [VecZnxBigAllocBytes].
|
||||
pub trait VecZnxBigFromBytes<B: Backend> {
|
||||
fn vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B>;
|
||||
}
|
||||
|
||||
/// Adds a discrete normal sample to `res`.
///
/// # Arguments
/// * `basek`: base two logarithm of the bivariate representation.
/// * `res`: receiver.
/// * `res_col`: column of the receiver on which the operation is performed/stored.
/// * `k`: base two negative logarithm of the scaling of the sample (the sample is scaled by 2^{-k}).
/// * `source`: random coin source.
/// * `sigma`: standard deviation of the discrete normal distribution.
/// * `bound`: rejection sampling bound.
|
||||
pub trait VecZnxBigAddNormal<B: Backend> {
|
||||
fn vec_znx_big_add_normal<R: VecZnxBigToMut<B>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
pub trait VecZnxBigFillNormal<B: Backend> {
|
||||
fn vec_znx_big_fill_normal<R: VecZnxBigToMut<B>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
pub trait VecZnxBigFillDistF64<B: Backend> {
|
||||
fn vec_znx_big_fill_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
pub trait VecZnxBigAddDistF64<B: Backend> {
|
||||
fn vec_znx_big_add_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
pub trait VecZnxBigAdd<B: Backend> {
|
||||
/// Adds `a` to `b` and stores the result in `res`.
|
||||
fn vec_znx_big_add<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
C: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigAddInplace<B: Backend> {
|
||||
/// Adds `a` to `res` and stores the result in `res`.
|
||||
fn vec_znx_big_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigAddSmall<B: Backend> {
|
||||
/// Adds `a` to `b` and stores the result in `res`.
|
||||
fn vec_znx_big_add_small<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
C: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigAddSmallInplace<B: Backend> {
|
||||
/// Adds `a` to `res` and stores the result in `res`.
|
||||
fn vec_znx_big_add_small_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigSub<B: Backend> {
|
||||
/// Subtracts `b` from `a` and stores the result in `res`.
|
||||
fn vec_znx_big_sub<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
C: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigSubABInplace<B: Backend> {
|
||||
/// Subtracts `a` from `res` and stores the result in `res`.
|
||||
fn vec_znx_big_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigSubBAInplace<B: Backend> {
|
||||
/// Subtracts `res` from `a` and stores the result in `res`.
|
||||
fn vec_znx_big_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigSubSmallA<B: Backend> {
|
||||
/// Subtracts `b` from `a` and stores the result in `res`.
|
||||
fn vec_znx_big_sub_small_a<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef,
|
||||
C: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigSubSmallAInplace<B: Backend> {
|
||||
/// Subtracts `a` from `res` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_a_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigSubSmallB<B: Backend> {
|
||||
/// Subtracts `b` from `a` and stores the result in `res`.
|
||||
fn vec_znx_big_sub_small_b<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
C: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigSubSmallBInplace<B: Backend> {
|
||||
/// Subtracts `res` from `a` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_b_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigNegateInplace<B: Backend> {
|
||||
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigNormalizeTmpBytes {
|
||||
fn vec_znx_big_normalize_tmp_bytes(&self, n: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigNormalize<B: Backend> {
|
||||
fn vec_znx_big_normalize<R, A>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigAutomorphism<B: Backend> {
|
||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result in `res`.
|
||||
fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigAutomorphismInplace<B: Backend> {
|
||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
|
||||
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<B>;
|
||||
}
|
||||
96
backend/src/hal/api/vec_znx_dft.rs
Normal file
@@ -0,0 +1,96 @@
|
||||
use crate::hal::layouts::{
|
||||
Backend, Data, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
|
||||
};
|
||||
|
||||
pub trait VecZnxDftAlloc<B: Backend> {
|
||||
fn vec_znx_dft_alloc(&self, cols: usize, size: usize) -> VecZnxDftOwned<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftFromBytes<B: Backend> {
|
||||
fn vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftAllocBytes {
|
||||
fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftToVecZnxBigTmpBytes {
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftToVecZnxBig<B: Backend> {
|
||||
fn vec_znx_dft_to_vec_znx_big<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftToVecZnxBigTmpA<B: Backend> {
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxDftToMut<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftToVecZnxBigConsume<B: Backend> {
|
||||
fn vec_znx_dft_to_vec_znx_big_consume<D: Data>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
|
||||
where
|
||||
VecZnxDft<D, B>: VecZnxDftToMut<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftAdd<B: Backend> {
|
||||
fn vec_znx_dft_add<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
D: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftAddInplace<B: Backend> {
|
||||
fn vec_znx_dft_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftSub<B: Backend> {
|
||||
fn vec_znx_dft_sub<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
D: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftSubABInplace<B: Backend> {
|
||||
fn vec_znx_dft_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftSubBAInplace<B: Backend> {
|
||||
fn vec_znx_dft_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftCopy<B: Backend> {
|
||||
fn vec_znx_dft_copy<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftFromVecZnx<B: Backend> {
|
||||
fn vec_znx_dft_from_vec_znx<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxToRef;
|
||||
}
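// Illustrative sketch, not part of this commit: the usual DFT pipeline built from the traits
// above together with those in vec_znx_big.rs. A VecZnx is brought into the DFT domain,
// operated on there, mapped back to a VecZnxBig, and finally renormalized to a VecZnx.
// `module`, `a`, `res`, `scratch` and the dimension variables are assumed to exist with
// matching sizes.
//
// let mut a_dft = module.vec_znx_dft_alloc(cols, size);
// module.vec_znx_dft_from_vec_znx(1, 0, &mut a_dft, 0, &a, 0); // step = 1, offset = 0
// // ... SvpApply / VmpApply operations on `a_dft` ...
// let a_big = module.vec_znx_dft_to_vec_znx_big_consume(a_dft);
// module.vec_znx_big_normalize(basek, &mut res, 0, &a_big, 0, scratch);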
|
||||
|
||||
pub trait VecZnxDftZero<B: Backend> {
|
||||
fn vec_znx_dft_zero<R>(&self, res: &mut R)
|
||||
where
|
||||
R: VecZnxDftToMut<B>;
|
||||
}
|
||||
90
backend/src/hal/api/vmp_pmat.rs
Normal file
@@ -0,0 +1,90 @@
|
||||
use crate::hal::layouts::{
|
||||
Backend, MatZnxToRef, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef,
|
||||
};
|
||||
|
||||
pub trait VmpPMatAlloc<B: Backend> {
|
||||
fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B>;
|
||||
}
|
||||
|
||||
pub trait VmpPMatAllocBytes {
|
||||
fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VmpPMatFromBytes<B: Backend> {
|
||||
fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> VmpPMatOwned<B>;
|
||||
}
|
||||
|
||||
pub trait VmpPrepareTmpBytes {
|
||||
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VmpPMatPrepare<B: Backend> {
|
||||
fn vmp_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VmpPMatToMut<B>,
|
||||
A: MatZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VmpApplyTmpBytes {
|
||||
fn vmp_apply_tmp_bytes(
|
||||
&self,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize;
|
||||
}
|
||||
|
||||
pub trait VmpApply<B: Backend> {
|
||||
/// Applies the vector matrix product [crate::hal::layouts::VecZnxDft] x [crate::hal::layouts::VmpPMat].
|
||||
///
|
||||
/// A vector matrix product is numerically equivalent to a sum of [crate::hal::api::SvpApply] calls,
/// where each [crate::hal::layouts::SvpPPol] is a limb of the input [crate::hal::layouts::VecZnx] in DFT,
/// and each vector is a [crate::hal::layouts::VecZnxDft] (row) of the [crate::hal::layouts::VmpPMat].
|
||||
///
|
||||
/// As such, given an input [crate::hal::layouts::VecZnx] of `i` size and a [crate::hal::layouts::VmpPMat] of `i` rows and
|
||||
/// `j` size, the output is a [crate::hal::layouts::VecZnx] of `j` size.
|
||||
///
|
||||
/// If there is a mismatch between the dimensions, the largest valid ones are used.
|
||||
///
|
||||
/// ```text
|
||||
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
|
||||
/// |h i j|
|
||||
/// |k l m|
|
||||
/// ```
|
||||
/// where each element is a [crate::hal::layouts::VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `res`: the output of the vector matrix product, as a [crate::hal::layouts::VecZnxDft].
|
||||
/// * `a`: the left operand [crate::hal::layouts::VecZnxDft] of the vector matrix product.
|
||||
/// * `b`: the right operand [crate::hal::layouts::VmpPMat] of the vector matrix product.
|
||||
/// * `scratch`: scratch space, the size can be obtained with [VmpApplyTmpBytes::vmp_apply_tmp_bytes].
|
||||
fn vmp_apply<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
C: VmpPMatToRef<B>;
|
||||
}
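// Illustrative sketch, not part of this commit: preparing a MatZnx into a VmpPMat and
// applying it to a VecZnxDft, using the scratch-sizing traits above. `module`, `mat`,
// `a_dft`, `res_dft` and the dimension variables are assumed to exist and to be consistent
// with each other.
//
// let mut pmat = module.vmp_pmat_alloc(rows, cols_in, cols_out, size);
// let mut scratch = ScratchOwned::alloc(
//     module.vmp_prepare_tmp_bytes(rows, cols_in, cols_out, size)
//         .max(module.vmp_apply_tmp_bytes(res_size, a_size, rows, cols_in, cols_out, size)),
// );
// module.vmp_prepare(&mut pmat, &mat, scratch.borrow());
// module.vmp_apply(&mut res_dft, &a_dft, &pmat, scratch.borrow());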
|
||||
|
||||
pub trait VmpApplyAddTmpBytes {
|
||||
fn vmp_apply_add_tmp_bytes(
|
||||
&self,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize;
|
||||
}
|
||||
|
||||
pub trait VmpApplyAdd<B: Backend> {
|
||||
fn vmp_apply_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
C: VmpPMatToRef<B>;
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
use itertools::izip;
|
||||
use crate::hal::layouts::{Data, DataMut, DataRef};
|
||||
use rand_distr::num_traits::Zero;
|
||||
|
||||
pub trait ZnxInfos {
|
||||
@@ -32,7 +32,7 @@ pub trait ZnxSliceSize {
|
||||
}
|
||||
|
||||
pub trait DataView {
|
||||
type D;
|
||||
type D: Data;
|
||||
fn data(&self) -> &Self::D;
|
||||
}
|
||||
|
||||
@@ -40,8 +40,8 @@ pub trait DataViewMut: DataView {
|
||||
fn data_mut(&mut self) -> &mut Self::D;
|
||||
}
|
||||
|
||||
pub trait ZnxView: ZnxInfos + DataView<D: AsRef<[u8]>> {
|
||||
type Scalar: Copy;
|
||||
pub trait ZnxView: ZnxInfos + DataView<D: DataRef> {
|
||||
type Scalar: Copy + Zero;
|
||||
|
||||
/// Returns a non-mutable pointer to the underlying coefficients array.
|
||||
fn as_ptr(&self) -> *const Self::Scalar {
|
||||
@@ -57,8 +57,8 @@ pub trait ZnxView: ZnxInfos + DataView<D: AsRef<[u8]>> {
|
||||
fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(i < self.cols(), "{} >= {}", i, self.cols());
|
||||
assert!(j < self.size(), "{} >= {}", j, self.size());
|
||||
assert!(i < self.cols(), "cols: {} >= {}", i, self.cols());
|
||||
assert!(j < self.size(), "size: {} >= {}", j, self.size());
|
||||
}
|
||||
let offset: usize = self.n() * (j * self.cols() + i);
|
||||
unsafe { self.as_ptr().add(offset) }
|
||||
@@ -70,7 +70,7 @@ pub trait ZnxView: ZnxInfos + DataView<D: AsRef<[u8]>> {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ZnxViewMut: ZnxView + DataViewMut<D: AsMut<[u8]>> {
|
||||
pub trait ZnxViewMut: ZnxView + DataViewMut<D: DataMut> {
|
||||
/// Returns a mutable pointer to the underlying coefficients array.
|
||||
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
|
||||
self.data_mut().as_mut().as_mut_ptr() as *mut Self::Scalar
|
||||
@@ -85,8 +85,8 @@ pub trait ZnxViewMut: ZnxView + DataViewMut<D: AsMut<[u8]>> {
|
||||
fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(i < self.cols(), "{} >= {}", i, self.cols());
|
||||
assert!(j < self.size(), "{} >= {}", j, self.size());
|
||||
assert!(i < self.cols(), "cols: {} >= {}", i, self.cols());
|
||||
assert!(j < self.size(), "size: {} >= {}", j, self.size());
|
||||
}
|
||||
let offset: usize = self.n() * (j * self.cols() + i);
|
||||
unsafe { self.as_mut_ptr().add(offset) }
|
||||
@@ -99,101 +99,12 @@ pub trait ZnxViewMut: ZnxView + DataViewMut<D: AsMut<[u8]>> {
|
||||
}
|
||||
|
||||
//(Jay)Note: Can't provide blanket impl. of ZnxView because Scalar is not known
|
||||
impl<T> ZnxViewMut for T where T: ZnxView + DataViewMut<D: AsMut<[u8]>> {}
|
||||
impl<T> ZnxViewMut for T where T: ZnxView + DataViewMut<D: DataMut> {}
|
||||
|
||||
pub trait ZnxZero: ZnxViewMut + ZnxSliceSize
|
||||
pub trait ZnxZero
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
fn zero(&mut self) {
|
||||
unsafe {
|
||||
std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * self.poly_count());
|
||||
}
|
||||
}
|
||||
|
||||
fn zero_at(&mut self, i: usize, j: usize) {
|
||||
unsafe {
|
||||
std::ptr::write_bytes(self.at_mut_ptr(i, j), 0, self.n());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Blanket implementations
|
||||
impl<T> ZnxZero for T where T: ZnxViewMut + ZnxSliceSize {} // WARNING should not work for mat_znx_dft but it does
|
||||
|
||||
use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub};
|
||||
|
||||
use crate::Scratch;
|
||||
pub trait Integer:
|
||||
Copy
|
||||
+ Default
|
||||
+ PartialEq
|
||||
+ PartialOrd
|
||||
+ Add<Output = Self>
|
||||
+ Sub<Output = Self>
|
||||
+ Mul<Output = Self>
|
||||
+ Div<Output = Self>
|
||||
+ Neg<Output = Self>
|
||||
+ Shl<Output = Self>
|
||||
+ Shr<Output = Self>
|
||||
+ AddAssign
|
||||
{
|
||||
const BITS: u32;
|
||||
}
|
||||
|
||||
impl Integer for i64 {
|
||||
const BITS: u32 = 64;
|
||||
}
|
||||
|
||||
impl Integer for i128 {
|
||||
const BITS: u32 = 128;
|
||||
}
|
||||
|
||||
//(Jay)Note: `rsh` impl. ignores the column
|
||||
pub fn rsh<V: ZnxZero>(k: usize, basek: usize, a: &mut V, _a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
V::Scalar: From<usize> + Integer + Zero,
|
||||
{
|
||||
let n: usize = a.n();
|
||||
let _size: usize = a.size();
|
||||
let cols: usize = a.cols();
|
||||
|
||||
let size: usize = a.size();
|
||||
let steps: usize = k / basek;
|
||||
|
||||
a.raw_mut().rotate_right(n * steps * cols);
|
||||
(0..cols).for_each(|i| {
|
||||
(0..steps).for_each(|j| {
|
||||
a.zero_at(i, j);
|
||||
})
|
||||
});
|
||||
|
||||
let k_rem: usize = k % basek;
|
||||
|
||||
if k_rem != 0 {
|
||||
let (carry, _) = scratch.tmp_slice::<V::Scalar>(rsh_tmp_bytes::<V::Scalar>(n));
|
||||
|
||||
unsafe {
|
||||
std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::<V::Scalar>());
|
||||
}
|
||||
|
||||
let basek_t = V::Scalar::from(basek);
|
||||
let shift = V::Scalar::from(V::Scalar::BITS as usize - k_rem);
|
||||
let k_rem_t = V::Scalar::from(k_rem);
|
||||
|
||||
(0..cols).for_each(|i| {
|
||||
(steps..size).for_each(|j| {
|
||||
izip!(carry.iter_mut(), a.at_mut(i, j).iter_mut()).for_each(|(ci, xi)| {
|
||||
*xi += *ci << basek_t;
|
||||
*ci = (*xi << shift) >> shift;
|
||||
*xi = (*xi - *ci) >> k_rem_t;
|
||||
});
|
||||
});
|
||||
carry.iter_mut().for_each(|r| *r = V::Scalar::zero());
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rsh_tmp_bytes<T>(n: usize) -> usize {
|
||||
n * std::mem::size_of::<T>()
|
||||
fn zero(&mut self);
|
||||
fn zero_at(&mut self, i: usize, j: usize);
|
||||
}
|
||||
32
backend/src/hal/delegates/mat_znx.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use crate::hal::{
|
||||
api::{MatZnxAlloc, MatZnxAllocBytes, MatZnxFromBytes},
|
||||
layouts::{Backend, MatZnxOwned, Module},
|
||||
oep::{MatZnxAllocBytesImpl, MatZnxAllocImpl, MatZnxFromBytesImpl},
|
||||
};
|
||||
|
||||
impl<B> MatZnxAlloc for Module<B>
|
||||
where
|
||||
B: Backend + MatZnxAllocImpl<B>,
|
||||
{
|
||||
fn mat_znx_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned {
|
||||
B::mat_znx_alloc_impl(self, rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> MatZnxAllocBytes for Module<B>
|
||||
where
|
||||
B: Backend + MatZnxAllocBytesImpl<B>,
|
||||
{
|
||||
fn mat_znx_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
B::mat_znx_alloc_bytes_impl(self, rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> MatZnxFromBytes for Module<B>
|
||||
where
|
||||
B: Backend + MatZnxFromBytesImpl<B>,
|
||||
{
|
||||
fn mat_znx_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> MatZnxOwned {
|
||||
B::mat_znx_from_bytes_impl(self, rows, cols_in, cols_out, size, bytes)
|
||||
}
|
||||
}
|
||||
9
backend/src/hal/delegates/mod.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
mod mat_znx;
|
||||
mod module;
|
||||
mod scalar_znx;
|
||||
mod scratch;
|
||||
mod svp_ppol;
|
||||
mod vec_znx;
|
||||
mod vec_znx_big;
|
||||
mod vec_znx_dft;
|
||||
mod vmp_pmat;
|
||||
14
backend/src/hal/delegates/module.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
use crate::hal::{
|
||||
api::ModuleNew,
|
||||
layouts::{Backend, Module},
|
||||
oep::ModuleNewImpl,
|
||||
};
|
||||
|
||||
impl<B> ModuleNew<B> for Module<B>
|
||||
where
|
||||
B: Backend + ModuleNewImpl<B>,
|
||||
{
|
||||
fn new(n: u64) -> Self {
|
||||
B::new_impl(n)
|
||||
}
|
||||
}
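// Illustrative sketch, not part of this commit: with this delegate in place, a concrete
// backend only has to implement ModuleNewImpl and callers construct it through the HAL.
// `FFT64` is a placeholder backend name, not necessarily one shipped by this crate.
//
// let module = Module::<FFT64>::new(1 << 10); // ring degree n = 1024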
|
||||
88
backend/src/hal/delegates/scalar_znx.rs
Normal file
@@ -0,0 +1,88 @@
|
||||
use crate::hal::{
|
||||
api::{
|
||||
ScalarZnxAlloc, ScalarZnxAllocBytes, ScalarZnxAutomorphism, ScalarZnxAutomorphismInplace, ScalarZnxFromBytes,
|
||||
ScalarZnxMulXpMinusOne, ScalarZnxMulXpMinusOneInplace,
|
||||
},
|
||||
layouts::{Backend, Module, ScalarZnxOwned, ScalarZnxToMut, ScalarZnxToRef},
|
||||
oep::{
|
||||
ScalarZnxAllocBytesImpl, ScalarZnxAllocImpl, ScalarZnxAutomorphismImpl, ScalarZnxAutomorphismInplaceIml,
|
||||
ScalarZnxFromBytesImpl, ScalarZnxMulXpMinusOneImpl, ScalarZnxMulXpMinusOneInplaceImpl,
|
||||
},
|
||||
};
|
||||
|
||||
impl<B> ScalarZnxAllocBytes for Module<B>
|
||||
where
|
||||
B: Backend + ScalarZnxAllocBytesImpl<B>,
|
||||
{
|
||||
fn scalar_znx_alloc_bytes(&self, cols: usize) -> usize {
|
||||
B::scalar_znx_alloc_bytes_impl(self.n(), cols)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ScalarZnxAlloc for Module<B>
|
||||
where
|
||||
B: Backend + ScalarZnxAllocImpl<B>,
|
||||
{
|
||||
fn scalar_znx_alloc(&self, cols: usize) -> ScalarZnxOwned {
|
||||
B::scalar_znx_alloc_impl(self.n(), cols)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ScalarZnxFromBytes for Module<B>
|
||||
where
|
||||
B: Backend + ScalarZnxFromBytesImpl<B>,
|
||||
{
|
||||
fn scalar_znx_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned {
|
||||
B::scalar_znx_from_bytes_impl(self.n(), cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ScalarZnxAutomorphism for Module<B>
|
||||
where
|
||||
B: Backend + ScalarZnxAutomorphismImpl<B>,
|
||||
{
|
||||
fn scalar_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: ScalarZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
B::scalar_znx_automorphism_impl(self, k, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ScalarZnxAutomorphismInplace for Module<B>
|
||||
where
|
||||
B: Backend + ScalarZnxAutomorphismInplaceIml<B>,
|
||||
{
|
||||
fn scalar_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: ScalarZnxToMut,
|
||||
{
|
||||
B::scalar_znx_automorphism_inplace_impl(self, k, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ScalarZnxMulXpMinusOne for Module<B>
|
||||
where
|
||||
B: Backend + ScalarZnxMulXpMinusOneImpl<B>,
|
||||
{
|
||||
fn scalar_znx_mul_xp_minus_one<R, A>(&self, p: i64, r: &mut R, r_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: ScalarZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
B::scalar_znx_mul_xp_minus_one_impl(self, p, r, r_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ScalarZnxMulXpMinusOneInplace for Module<B>
|
||||
where
|
||||
B: Backend + ScalarZnxMulXpMinusOneInplaceImpl<B>,
|
||||
{
|
||||
fn scalar_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, r: &mut R, r_col: usize)
|
||||
where
|
||||
R: ScalarZnxToMut,
|
||||
{
|
||||
B::scalar_znx_mul_xp_minus_one_inplace_impl(self, p, r, r_col);
|
||||
}
|
||||
}
|
||||
243
backend/src/hal/delegates/scratch.rs
Normal file
@@ -0,0 +1,243 @@
|
||||
use crate::hal::{
|
||||
api::{
|
||||
ScratchAvailable, ScratchFromBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, TakeLike, TakeMatZnx, TakeScalarZnx,
|
||||
TakeSlice, TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice, TakeVmpPMat,
|
||||
},
|
||||
layouts::{
|
||||
Backend, DataRef, MatZnx, Module, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat,
|
||||
},
|
||||
oep::{
|
||||
ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeLikeImpl, TakeMatZnxImpl,
|
||||
TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl,
|
||||
TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl,
|
||||
},
|
||||
};
|
||||
|
||||
impl<B> ScratchOwnedAlloc<B> for ScratchOwned<B>
|
||||
where
|
||||
B: Backend + ScratchOwnedAllocImpl<B>,
|
||||
{
|
||||
fn alloc(size: usize) -> Self {
|
||||
B::scratch_owned_alloc_impl(size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ScratchOwnedBorrow<B> for ScratchOwned<B>
|
||||
where
|
||||
B: Backend + ScratchOwnedBorrowImpl<B>,
|
||||
{
|
||||
fn borrow(&mut self) -> &mut Scratch<B> {
|
||||
B::scratch_owned_borrow_impl(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ScratchFromBytes<B> for Scratch<B>
|
||||
where
|
||||
B: Backend + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn from_bytes(data: &mut [u8]) -> &mut Scratch<B> {
|
||||
B::scratch_from_bytes_impl(data)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ScratchAvailable for Scratch<B>
|
||||
where
|
||||
B: Backend + ScratchAvailableImpl<B>,
|
||||
{
|
||||
fn available(&self) -> usize {
|
||||
B::scratch_available_impl(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> TakeSlice for Scratch<B>
|
||||
where
|
||||
B: Backend + TakeSliceImpl<B>,
|
||||
{
|
||||
fn take_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self) {
|
||||
B::take_slice_impl(self, len)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> TakeScalarZnx<B> for Scratch<B>
|
||||
where
|
||||
B: Backend + TakeScalarZnxImpl<B>,
|
||||
{
|
||||
fn take_scalar_znx(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
|
||||
B::take_scalar_znx_impl(self, module.n(), cols)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> TakeSvpPPol<B> for Scratch<B>
|
||||
where
|
||||
B: Backend + TakeSvpPPolImpl<B>,
|
||||
{
|
||||
fn take_svp_ppol(&mut self, module: &Module<B>, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self) {
|
||||
B::take_svp_ppol_impl(self, module.n(), cols)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> TakeVecZnx<B> for Scratch<B>
|
||||
where
|
||||
B: Backend + TakeVecZnxImpl<B>,
|
||||
{
|
||||
fn take_vec_znx(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
|
||||
B::take_vec_znx_impl(self, module.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> TakeVecZnxSlice<B> for Scratch<B>
|
||||
where
|
||||
B: Backend + TakeVecZnxSliceImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_slice(
|
||||
&mut self,
|
||||
len: usize,
|
||||
module: &Module<B>,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnx<&mut [u8]>>, &mut Self) {
|
||||
B::take_vec_znx_slice_impl(self, len, module.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> TakeVecZnxBig<B> for Scratch<B>
|
||||
where
|
||||
B: Backend + TakeVecZnxBigImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_big(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self) {
|
||||
B::take_vec_znx_big_impl(self, module.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> TakeVecZnxDft<B> for Scratch<B>
|
||||
where
|
||||
B: Backend + TakeVecZnxDftImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_dft(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self) {
|
||||
B::take_vec_znx_dft_impl(self, module.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> TakeVecZnxDftSlice<B> for Scratch<B>
|
||||
where
|
||||
B: Backend + TakeVecZnxDftSliceImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_dft_slice(
|
||||
&mut self,
|
||||
len: usize,
|
||||
module: &Module<B>,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Self) {
|
||||
B::take_vec_znx_dft_slice_impl(self, len, module.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> TakeVmpPMat<B> for Scratch<B>
|
||||
where
|
||||
B: Backend + TakeVmpPMatImpl<B>,
|
||||
{
|
||||
fn take_vmp_pmat(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (VmpPMat<&mut [u8], B>, &mut Self) {
|
||||
B::take_vmp_pmat_impl(self, module.n(), rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> TakeMatZnx<B> for Scratch<B>
|
||||
where
|
||||
B: Backend + TakeMatZnxImpl<B>,
|
||||
{
|
||||
fn take_mat_znx(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (MatZnx<&mut [u8]>, &mut Self) {
|
||||
B::take_mat_znx_impl(self, module.n(), rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLike<'a, B, ScalarZnx<D>> for Scratch<B>
|
||||
where
|
||||
B: TakeLikeImpl<'a, B, ScalarZnx<D>, Output = ScalarZnx<&'a mut [u8]>>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = ScalarZnx<&'a mut [u8]>;
|
||||
fn take_like(&'a mut self, template: &ScalarZnx<D>) -> (Self::Output, &'a mut Self) {
|
||||
B::take_like_impl(self, template)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLike<'a, B, SvpPPol<D, B>> for Scratch<B>
|
||||
where
|
||||
B: TakeLikeImpl<'a, B, SvpPPol<D, B>, Output = SvpPPol<&'a mut [u8], B>>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = SvpPPol<&'a mut [u8], B>;
|
||||
fn take_like(&'a mut self, template: &SvpPPol<D, B>) -> (Self::Output, &'a mut Self) {
|
||||
B::take_like_impl(self, template)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLike<'a, B, VecZnx<D>> for Scratch<B>
|
||||
where
|
||||
B: TakeLikeImpl<'a, B, VecZnx<D>, Output = VecZnx<&'a mut [u8]>>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = VecZnx<&'a mut [u8]>;
|
||||
fn take_like(&'a mut self, template: &VecZnx<D>) -> (Self::Output, &'a mut Self) {
|
||||
B::take_like_impl(self, template)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLike<'a, B, VecZnxBig<D, B>> for Scratch<B>
|
||||
where
|
||||
B: TakeLikeImpl<'a, B, VecZnxBig<D, B>, Output = VecZnxBig<&'a mut [u8], B>>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = VecZnxBig<&'a mut [u8], B>;
|
||||
fn take_like(&'a mut self, template: &VecZnxBig<D, B>) -> (Self::Output, &'a mut Self) {
|
||||
B::take_like_impl(self, template)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLike<'a, B, VecZnxDft<D, B>> for Scratch<B>
|
||||
where
|
||||
B: TakeLikeImpl<'a, B, VecZnxDft<D, B>, Output = VecZnxDft<&'a mut [u8], B>>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = VecZnxDft<&'a mut [u8], B>;
|
||||
fn take_like(&'a mut self, template: &VecZnxDft<D, B>) -> (Self::Output, &'a mut Self) {
|
||||
B::take_like_impl(self, template)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLike<'a, B, MatZnx<D>> for Scratch<B>
|
||||
where
|
||||
B: TakeLikeImpl<'a, B, MatZnx<D>, Output = MatZnx<&'a mut [u8]>>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = MatZnx<&'a mut [u8]>;
|
||||
fn take_like(&'a mut self, template: &MatZnx<D>) -> (Self::Output, &'a mut Self) {
|
||||
B::take_like_impl(self, template)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLike<'a, B, VmpPMat<D, B>> for Scratch<B>
|
||||
where
|
||||
B: TakeLikeImpl<'a, B, VmpPMat<D, B>, Output = VmpPMat<&'a mut [u8], B>>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = VmpPMat<&'a mut [u8], B>;
|
||||
fn take_like(&'a mut self, template: &VmpPMat<D, B>) -> (Self::Output, &'a mut Self) {
|
||||
B::take_like_impl(self, template)
|
||||
}
|
||||
}
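// Illustrative sketch, not part of this commit: the "take" pattern implemented above. Each
// take_* call carves a temporary out of the scratch buffer and returns the remaining
// scratch, so temporaries can be stacked without further allocation. `module` is assumed to
// exist and 1 << 20 bytes is an arbitrary size.
//
// let mut owned = ScratchOwned::alloc(1 << 20);
// let scratch = owned.borrow();
// let (dft_tmp, scratch) = scratch.take_vec_znx_dft(&module, 2, 4);
// let (znx_tmp, _scratch) = scratch.take_vec_znx(&module, 2, 4);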
|
||||
72
backend/src/hal/delegates/svp_ppol.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
use crate::hal::{
|
||||
api::{SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPPolFromBytes, SvpPrepare},
|
||||
layouts::{Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef},
|
||||
oep::{SvpApplyImpl, SvpApplyInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl},
|
||||
};
|
||||
|
||||
impl<B> SvpPPolFromBytes<B> for Module<B>
|
||||
where
|
||||
B: Backend + SvpPPolFromBytesImpl<B>,
|
||||
{
|
||||
fn svp_ppol_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B> {
|
||||
B::svp_ppol_from_bytes_impl(self.n(), cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> SvpPPolAlloc<B> for Module<B>
|
||||
where
|
||||
B: Backend + SvpPPolAllocImpl<B>,
|
||||
{
|
||||
fn svp_ppol_alloc(&self, cols: usize) -> SvpPPolOwned<B> {
|
||||
B::svp_ppol_alloc_impl(self.n(), cols)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> SvpPPolAllocBytes for Module<B>
|
||||
where
|
||||
B: Backend + SvpPPolAllocBytesImpl<B>,
|
||||
{
|
||||
fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize {
|
||||
B::svp_ppol_alloc_bytes_impl(self.n(), cols)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> SvpPrepare<B> for Module<B>
|
||||
where
|
||||
B: Backend + SvpPrepareImpl<B>,
|
||||
{
|
||||
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: SvpPPolToMut<B>,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
B::svp_prepare_impl(self, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> SvpApply<B> for Module<B>
|
||||
where
|
||||
B: Backend + SvpApplyImpl<B>,
|
||||
{
|
||||
fn svp_apply<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: SvpPPolToRef<B>,
|
||||
C: VecZnxDftToRef<B>,
|
||||
{
|
||||
B::svp_apply_impl(self, res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> SvpApplyInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + SvpApplyInplaceImpl,
|
||||
{
|
||||
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: SvpPPolToRef<B>,
|
||||
{
|
||||
B::svp_apply_inplace_impl(self, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
518
backend/src/hal/delegates/vec_znx.rs
Normal file
@@ -0,0 +1,518 @@
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::hal::{
|
||||
api::{
|
||||
VecZnxAdd, VecZnxAddDistF64, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAlloc, VecZnxAllocBytes,
|
||||
VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxCopy, VecZnxDecodeCoeffsi64, VecZnxDecodeVecFloat,
|
||||
VecZnxDecodeVeci64, VecZnxEncodeCoeffsi64, VecZnxEncodeVeci64, VecZnxFillDistF64, VecZnxFillNormal, VecZnxFillUniform,
|
||||
VecZnxFromBytes, VecZnxLshInplace, VecZnxMerge, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate,
|
||||
VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace,
|
||||
VecZnxRshInplace, VecZnxSplit, VecZnxStd, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalarInplace,
|
||||
VecZnxSwithcDegree,
|
||||
},
|
||||
layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxOwned, VecZnxToMut, VecZnxToRef},
|
||||
oep::{
|
||||
VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
|
||||
VecZnxAllocBytesImpl, VecZnxAllocImpl, VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl,
|
||||
VecZnxDecodeCoeffsi64Impl, VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl,
|
||||
VecZnxEncodeVeci64Impl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxFromBytesImpl,
|
||||
VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl,
|
||||
VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl,
|
||||
VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxStdImpl, VecZnxSubABInplaceImpl,
|
||||
VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
|
||||
},
|
||||
};
|
||||
|
||||
impl<B> VecZnxAlloc for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxAllocImpl<B>,
|
||||
{
|
||||
fn vec_znx_alloc(&self, cols: usize, size: usize) -> VecZnxOwned {
|
||||
B::vec_znx_alloc_impl(self.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxFromBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxFromBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned {
|
||||
B::vec_znx_from_bytes_impl(self.n(), cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxAllocBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxAllocBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_alloc_bytes(&self, cols: usize, size: usize) -> usize {
|
||||
B::vec_znx_alloc_bytes_impl(self.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxNormalizeTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxNormalizeTmpBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_normalize_tmp_bytes(&self, n: usize) -> usize {
|
||||
B::vec_znx_normalize_tmp_bytes_impl(self, n)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxNormalize<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxNormalizeImpl<B>,
|
||||
{
|
||||
fn vec_znx_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_normalize_impl(self, basek, res, res_col, a, a_col, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxNormalizeInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxNormalizeInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_normalize_inplace<A>(&self, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_normalize_inplace_impl(self, basek, a, a_col, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxAdd for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxAddImpl<B>,
|
||||
{
|
||||
fn vec_znx_add<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
C: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_add_impl(self, res, res_col, a, a_col, b, b_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxAddInplace for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxAddInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_add_inplace_impl(self, res, res_col, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxAddScalarInplace for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxAddScalarInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
B::vec_znx_add_scalar_inplace_impl(self, res, res_col, res_limb, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxSub for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxSubImpl<B>,
|
||||
{
|
||||
fn vec_znx_sub<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
C: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_sub_impl(self, res, res_col, a, a_col, b, b_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxSubABInplace for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxSubABInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_sub_ab_inplace_impl(self, res, res_col, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxSubBAInplace for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxSubBAInplaceImpl<B>,
|
||||
{
|
||||
    fn vec_znx_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxToMut,
        A: VecZnxToRef,
    {
        B::vec_znx_sub_ba_inplace_impl(self, res, res_col, a, a_col)
    }
}

impl<B> VecZnxSubScalarInplace for Module<B>
where
    B: Backend + VecZnxSubScalarInplaceImpl<B>,
{
    fn vec_znx_sub_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
    where
        R: VecZnxToMut,
        A: ScalarZnxToRef,
    {
        B::vec_znx_sub_scalar_inplace_impl(self, res, res_col, res_limb, a, a_col)
    }
}

impl<B> VecZnxNegate for Module<B>
where
    B: Backend + VecZnxNegateImpl<B>,
{
    fn vec_znx_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxToMut,
        A: VecZnxToRef,
    {
        B::vec_znx_negate_impl(self, res, res_col, a, a_col)
    }
}

impl<B> VecZnxNegateInplace for Module<B>
where
    B: Backend + VecZnxNegateInplaceImpl<B>,
{
    fn vec_znx_negate_inplace<A>(&self, a: &mut A, a_col: usize)
    where
        A: VecZnxToMut,
    {
        B::vec_znx_negate_inplace_impl(self, a, a_col)
    }
}

impl<B> VecZnxLshInplace for Module<B>
where
    B: Backend + VecZnxLshInplaceImpl<B>,
{
    fn vec_znx_lsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
    where
        A: VecZnxToMut,
    {
        B::vec_znx_lsh_inplace_impl(self, basek, k, a)
    }
}

impl<B> VecZnxRshInplace for Module<B>
where
    B: Backend + VecZnxRshInplaceImpl<B>,
{
    fn vec_znx_rsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
    where
        A: VecZnxToMut,
    {
        B::vec_znx_rsh_inplace_impl(self, basek, k, a)
    }
}

impl<B> VecZnxRotate for Module<B>
where
    B: Backend + VecZnxRotateImpl<B>,
{
    fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxToMut,
        A: VecZnxToRef,
    {
        B::vec_znx_rotate_impl(self, k, res, res_col, a, a_col)
    }
}

impl<B> VecZnxRotateInplace for Module<B>
where
    B: Backend + VecZnxRotateInplaceImpl<B>,
{
    fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
    where
        A: VecZnxToMut,
    {
        B::vec_znx_rotate_inplace_impl(self, k, a, a_col)
    }
}

impl<B> VecZnxAutomorphism for Module<B>
where
    B: Backend + VecZnxAutomorphismImpl<B>,
{
    fn vec_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxToMut,
        A: VecZnxToRef,
    {
        B::vec_znx_automorphism_impl(self, k, res, res_col, a, a_col)
    }
}

impl<B> VecZnxAutomorphismInplace for Module<B>
where
    B: Backend + VecZnxAutomorphismInplaceImpl<B>,
{
    fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
    where
        A: VecZnxToMut,
    {
        B::vec_znx_automorphism_inplace_impl(self, k, a, a_col)
    }
}

impl<B> VecZnxMulXpMinusOne for Module<B>
where
    B: Backend + VecZnxMulXpMinusOneImpl<B>,
{
    fn vec_znx_mul_xp_minus_one<R, A>(&self, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxToMut,
        A: VecZnxToRef,
    {
        B::vec_znx_mul_xp_minus_one_impl(self, p, res, res_col, a, a_col);
    }
}

impl<B> VecZnxMulXpMinusOneInplace for Module<B>
where
    B: Backend + VecZnxMulXpMinusOneInplaceImpl<B>,
{
    fn vec_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, res: &mut R, res_col: usize)
    where
        R: VecZnxToMut,
    {
        B::vec_znx_mul_xp_minus_one_inplace_impl(self, p, res, res_col);
    }
}

impl<B> VecZnxSplit<B> for Module<B>
where
    B: Backend + VecZnxSplitImpl<B>,
{
    fn vec_znx_split<R, A>(&self, res: &mut Vec<R>, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
    where
        R: VecZnxToMut,
        A: VecZnxToRef,
    {
        B::vec_znx_split_impl(self, res, res_col, a, a_col, scratch)
    }
}

impl<B> VecZnxMerge for Module<B>
where
    B: Backend + VecZnxMergeImpl<B>,
{
    fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: Vec<A>, a_col: usize)
    where
        R: VecZnxToMut,
        A: VecZnxToRef,
    {
        B::vec_znx_merge_impl(self, res, res_col, a, a_col)
    }
}

impl<B> VecZnxSwithcDegree for Module<B>
where
    B: Backend + VecZnxSwithcDegreeImpl<B>,
{
    fn vec_znx_switch_degree<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxToMut,
        A: VecZnxToRef,
    {
        B::vec_znx_switch_degree_impl(self, res, res_col, a, a_col)
    }
}

impl<B> VecZnxCopy for Module<B>
where
    B: Backend + VecZnxCopyImpl<B>,
{
    fn vec_znx_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxToMut,
        A: VecZnxToRef,
    {
        B::vec_znx_copy_impl(self, res, res_col, a, a_col)
    }
}

impl<B> VecZnxStd for Module<B>
where
    B: Backend + VecZnxStdImpl<B>,
{
    fn vec_znx_std<A>(&self, basek: usize, a: &A, a_col: usize) -> f64
    where
        A: VecZnxToRef,
    {
        B::vec_znx_std_impl(self, basek, a, a_col)
    }
}

impl<B> VecZnxFillUniform for Module<B>
where
    B: Backend + VecZnxFillUniformImpl<B>,
{
    fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
    where
        R: VecZnxToMut,
    {
        B::vec_znx_fill_uniform_impl(self, basek, res, res_col, k, source);
    }
}

impl<B> VecZnxFillDistF64 for Module<B>
where
    B: Backend + VecZnxFillDistF64Impl<B>,
{
    fn vec_znx_fill_dist_f64<R, D: rand::prelude::Distribution<f64>>(
        &self,
        basek: usize,
        res: &mut R,
        res_col: usize,
        k: usize,
        source: &mut Source,
        dist: D,
        bound: f64,
    ) where
        R: VecZnxToMut,
    {
        B::vec_znx_fill_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
    }
}

impl<B> VecZnxAddDistF64 for Module<B>
where
    B: Backend + VecZnxAddDistF64Impl<B>,
{
    fn vec_znx_add_dist_f64<R, D: rand::prelude::Distribution<f64>>(
        &self,
        basek: usize,
        res: &mut R,
        res_col: usize,
        k: usize,
        source: &mut Source,
        dist: D,
        bound: f64,
    ) where
        R: VecZnxToMut,
    {
        B::vec_znx_add_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
    }
}

impl<B> VecZnxFillNormal for Module<B>
where
    B: Backend + VecZnxFillNormalImpl<B>,
{
    fn vec_znx_fill_normal<R>(
        &self,
        basek: usize,
        res: &mut R,
        res_col: usize,
        k: usize,
        source: &mut Source,
        sigma: f64,
        bound: f64,
    ) where
        R: VecZnxToMut,
    {
        B::vec_znx_fill_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
    }
}

impl<B> VecZnxAddNormal for Module<B>
where
    B: Backend + VecZnxAddNormalImpl<B>,
{
    fn vec_znx_add_normal<R>(
        &self,
        basek: usize,
        res: &mut R,
        res_col: usize,
        k: usize,
        source: &mut Source,
        sigma: f64,
        bound: f64,
    ) where
        R: VecZnxToMut,
    {
        B::vec_znx_add_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
    }
}

impl<B> VecZnxEncodeVeci64 for Module<B>
where
    B: Backend + VecZnxEncodeVeci64Impl<B>,
{
    fn encode_vec_i64<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, data: &[i64], log_max: usize)
    where
        R: VecZnxToMut,
    {
        B::encode_vec_i64_impl(self, basek, res, res_col, k, data, log_max);
    }
}

impl<B> VecZnxEncodeCoeffsi64 for Module<B>
where
    B: Backend + VecZnxEncodeCoeffsi64Impl<B>,
{
    fn encode_coeff_i64<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, i: usize, data: i64, log_max: usize)
    where
        R: VecZnxToMut,
    {
        B::encode_coeff_i64_impl(self, basek, res, res_col, k, i, data, log_max);
    }
}

impl<B> VecZnxDecodeVeci64 for Module<B>
where
    B: Backend + VecZnxDecodeVeci64Impl<B>,
{
    fn decode_vec_i64<R>(&self, basek: usize, res: &R, res_col: usize, k: usize, data: &mut [i64])
    where
        R: VecZnxToRef,
    {
        B::decode_vec_i64_impl(self, basek, res, res_col, k, data);
    }
}

impl<B> VecZnxDecodeCoeffsi64 for Module<B>
where
    B: Backend + VecZnxDecodeCoeffsi64Impl<B>,
{
    fn decode_coeff_i64<R>(&self, basek: usize, res: &R, res_col: usize, k: usize, i: usize) -> i64
    where
        R: VecZnxToRef,
    {
        B::decode_coeff_i64_impl(self, basek, res, res_col, k, i)
    }
}

impl<B> VecZnxDecodeVecFloat for Module<B>
where
    B: Backend + VecZnxDecodeVecFloatImpl<B>,
{
    fn decode_vec_float<R>(&self, basek: usize, res: &R, col_i: usize, data: &mut [rug::Float])
    where
        R: VecZnxToRef,
    {
        B::decode_vec_float_impl(self, basek, res, col_i, data);
    }
}
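Every delegate above is a thin forwarder: the `Module<B>` method simply calls the matching `*_impl` entry point of the backend `B` named in the `oep` bound, so downstream code can stay generic over the backend. A minimal sketch of what such generic code might look like; the import paths and the way `module`, `res` and `a` are obtained are assumptions for illustration, not part of this commit.

// Hedged sketch: code written only against the delegate traits above.
use crate::hal::{
    api::{VecZnxNegate, VecZnxRotateInplace},
    layouts::{Backend, Module, VecZnxToMut, VecZnxToRef},
};

fn negate_then_rotate<B, R, A>(module: &Module<B>, k: i64, res: &mut R, a: &A)
where
    B: Backend,
    Module<B>: VecZnxNegate + VecZnxRotateInplace,
    R: VecZnxToMut,
    A: VecZnxToRef,
{
    // Both calls dispatch to B::vec_znx_*_impl through the blanket impls above.
    module.vec_znx_negate(res, 0, a, 0);
    module.vec_znx_rotate_inplace(k, res, 0);
}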
334
backend/src/hal/delegates/vec_znx_big.rs
Normal file
@@ -0,0 +1,334 @@
use rand_distr::Distribution;
use sampling::source::Source;

use crate::hal::{
    api::{
        VecZnxBigAdd, VecZnxBigAddDistF64, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmall, VecZnxBigAddSmallInplace,
        VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigFillDistF64,
        VecZnxBigFillNormal, VecZnxBigFromBytes, VecZnxBigNegateInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
        VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA, VecZnxBigSubSmallAInplace,
        VecZnxBigSubSmallB, VecZnxBigSubSmallBInplace,
    },
    layouts::{Backend, Module, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef},
    oep::{
        VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl,
        VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
        VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl, VecZnxBigFromBytesImpl,
        VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl,
        VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl,
        VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
    },
};

impl<B> VecZnxBigAlloc<B> for Module<B>
where
    B: Backend + VecZnxBigAllocImpl<B>,
{
    fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned<B> {
        B::vec_znx_big_alloc_impl(self.n(), cols, size)
    }
}

impl<B> VecZnxBigFromBytes<B> for Module<B>
where
    B: Backend + VecZnxBigFromBytesImpl<B>,
{
    fn vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B> {
        B::vec_znx_big_from_bytes_impl(self.n(), cols, size, bytes)
    }
}

impl<B> VecZnxBigAllocBytes for Module<B>
where
    B: Backend + VecZnxBigAllocBytesImpl<B>,
{
    fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize {
        B::vec_znx_big_alloc_bytes_impl(self.n(), cols, size)
    }
}

impl<B> VecZnxBigAddDistF64<B> for Module<B>
where
    B: Backend + VecZnxBigAddDistF64Impl<B>,
{
    fn vec_znx_big_add_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
        &self,
        basek: usize,
        res: &mut R,
        res_col: usize,
        k: usize,
        source: &mut Source,
        dist: D,
        bound: f64,
    ) {
        B::add_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
    }
}

impl<B> VecZnxBigAddNormal<B> for Module<B>
where
    B: Backend + VecZnxBigAddNormalImpl<B>,
{
    fn vec_znx_big_add_normal<R: VecZnxBigToMut<B>>(
        &self,
        basek: usize,
        res: &mut R,
        res_col: usize,
        k: usize,
        source: &mut Source,
        sigma: f64,
        bound: f64,
    ) {
        B::add_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
    }
}

impl<B> VecZnxBigFillDistF64<B> for Module<B>
where
    B: Backend + VecZnxBigFillDistF64Impl<B>,
{
    fn vec_znx_big_fill_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
        &self,
        basek: usize,
        res: &mut R,
        res_col: usize,
        k: usize,
        source: &mut Source,
        dist: D,
        bound: f64,
    ) {
        B::fill_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
    }
}

impl<B> VecZnxBigFillNormal<B> for Module<B>
where
    B: Backend + VecZnxBigFillNormalImpl<B>,
{
    fn vec_znx_big_fill_normal<R: VecZnxBigToMut<B>>(
        &self,
        basek: usize,
        res: &mut R,
        res_col: usize,
        k: usize,
        source: &mut Source,
        sigma: f64,
        bound: f64,
    ) {
        B::fill_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
    }
}

impl<B> VecZnxBigAdd<B> for Module<B>
where
    B: Backend + VecZnxBigAddImpl<B>,
{
    fn vec_znx_big_add<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
    where
        R: VecZnxBigToMut<B>,
        A: VecZnxBigToRef<B>,
        C: VecZnxBigToRef<B>,
    {
        B::vec_znx_big_add_impl(self, res, res_col, a, a_col, b, b_col);
    }
}

impl<B> VecZnxBigAddInplace<B> for Module<B>
where
    B: Backend + VecZnxBigAddInplaceImpl<B>,
{
    fn vec_znx_big_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxBigToMut<B>,
        A: VecZnxBigToRef<B>,
    {
        B::vec_znx_big_add_inplace_impl(self, res, res_col, a, a_col);
    }
}

impl<B> VecZnxBigAddSmall<B> for Module<B>
where
    B: Backend + VecZnxBigAddSmallImpl<B>,
{
    fn vec_znx_big_add_small<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
    where
        R: VecZnxBigToMut<B>,
        A: VecZnxBigToRef<B>,
        C: VecZnxToRef,
    {
        B::vec_znx_big_add_small_impl(self, res, res_col, a, a_col, b, b_col);
    }
}

impl<B> VecZnxBigAddSmallInplace<B> for Module<B>
where
    B: Backend + VecZnxBigAddSmallInplaceImpl<B>,
{
    fn vec_znx_big_add_small_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxBigToMut<B>,
        A: VecZnxToRef,
    {
        B::vec_znx_big_add_small_inplace_impl(self, res, res_col, a, a_col);
    }
}

impl<B> VecZnxBigSub<B> for Module<B>
where
    B: Backend + VecZnxBigSubImpl<B>,
{
    fn vec_znx_big_sub<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
    where
        R: VecZnxBigToMut<B>,
        A: VecZnxBigToRef<B>,
        C: VecZnxBigToRef<B>,
    {
        B::vec_znx_big_sub_impl(self, res, res_col, a, a_col, b, b_col);
    }
}

impl<B> VecZnxBigSubABInplace<B> for Module<B>
where
    B: Backend + VecZnxBigSubABInplaceImpl<B>,
{
    fn vec_znx_big_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxBigToMut<B>,
        A: VecZnxBigToRef<B>,
    {
        B::vec_znx_big_sub_ab_inplace_impl(self, res, res_col, a, a_col);
    }
}

impl<B> VecZnxBigSubBAInplace<B> for Module<B>
where
    B: Backend + VecZnxBigSubBAInplaceImpl<B>,
{
    fn vec_znx_big_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxBigToMut<B>,
        A: VecZnxBigToRef<B>,
    {
        B::vec_znx_big_sub_ba_inplace_impl(self, res, res_col, a, a_col);
    }
}

impl<B> VecZnxBigSubSmallA<B> for Module<B>
where
    B: Backend + VecZnxBigSubSmallAImpl<B>,
{
    fn vec_znx_big_sub_small_a<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
    where
        R: VecZnxBigToMut<B>,
        A: VecZnxToRef,
        C: VecZnxBigToRef<B>,
    {
        B::vec_znx_big_sub_small_a_impl(self, res, res_col, a, a_col, b, b_col);
    }
}

impl<B> VecZnxBigSubSmallAInplace<B> for Module<B>
where
    B: Backend + VecZnxBigSubSmallAInplaceImpl<B>,
{
    fn vec_znx_big_sub_small_a_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxBigToMut<B>,
        A: VecZnxToRef,
    {
        B::vec_znx_big_sub_small_a_inplace_impl(self, res, res_col, a, a_col);
    }
}

impl<B> VecZnxBigSubSmallB<B> for Module<B>
where
    B: Backend + VecZnxBigSubSmallBImpl<B>,
{
    fn vec_znx_big_sub_small_b<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
    where
        R: VecZnxBigToMut<B>,
        A: VecZnxBigToRef<B>,
        C: VecZnxToRef,
    {
        B::vec_znx_big_sub_small_b_impl(self, res, res_col, a, a_col, b, b_col);
    }
}

impl<B> VecZnxBigSubSmallBInplace<B> for Module<B>
where
    B: Backend + VecZnxBigSubSmallBInplaceImpl<B>,
{
    fn vec_znx_big_sub_small_b_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxBigToMut<B>,
        A: VecZnxToRef,
    {
        B::vec_znx_big_sub_small_b_inplace_impl(self, res, res_col, a, a_col);
    }
}

impl<B> VecZnxBigNegateInplace<B> for Module<B>
where
    B: Backend + VecZnxBigNegateInplaceImpl<B>,
{
    fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize)
    where
        A: VecZnxBigToMut<B>,
    {
        B::vec_znx_big_negate_inplace_impl(self, a, a_col);
    }
}

impl<B> VecZnxBigNormalizeTmpBytes for Module<B>
where
    B: Backend + VecZnxBigNormalizeTmpBytesImpl<B>,
{
    fn vec_znx_big_normalize_tmp_bytes(&self, n: usize) -> usize {
        B::vec_znx_big_normalize_tmp_bytes_impl(self, n)
    }
}

impl<B> VecZnxBigNormalize<B> for Module<B>
where
    B: Backend + VecZnxBigNormalizeImpl<B>,
{
    fn vec_znx_big_normalize<R, A>(
        &self,
        basek: usize,
        res: &mut R,
        res_col: usize,
        a: &A,
        a_col: usize,
        scratch: &mut Scratch<B>,
    ) where
        R: VecZnxToMut,
        A: VecZnxBigToRef<B>,
    {
        B::vec_znx_big_normalize_impl(self, basek, res, res_col, a, a_col, scratch);
    }
}

impl<B> VecZnxBigAutomorphism<B> for Module<B>
where
    B: Backend + VecZnxBigAutomorphismImpl<B>,
{
    fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxBigToMut<B>,
        A: VecZnxBigToRef<B>,
    {
        B::vec_znx_big_automorphism_impl(self, k, res, res_col, a, a_col);
    }
}

impl<B> VecZnxBigAutomorphismInplace<B> for Module<B>
where
    B: Backend + VecZnxBigAutomorphismInplaceImpl<B>,
{
    fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
    where
        A: VecZnxBigToMut<B>,
    {
        B::vec_znx_big_automorphism_inplace_impl(self, k, a, a_col);
    }
}
196
backend/src/hal/delegates/vec_znx_dft.rs
Normal file
@@ -0,0 +1,196 @@
use crate::hal::{
    api::{
        VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxDftFromBytes,
        VecZnxDftFromVecZnx, VecZnxDftSub, VecZnxDftSubABInplace, VecZnxDftSubBAInplace, VecZnxDftToVecZnxBig,
        VecZnxDftToVecZnxBigConsume, VecZnxDftToVecZnxBigTmpA, VecZnxDftToVecZnxBigTmpBytes, VecZnxDftZero,
    },
    layouts::{
        Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef,
        VecZnxToRef,
    },
    oep::{
        VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftCopyImpl,
        VecZnxDftFromBytesImpl, VecZnxDftFromVecZnxImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl,
        VecZnxDftToVecZnxBigConsumeImpl, VecZnxDftToVecZnxBigImpl, VecZnxDftToVecZnxBigTmpAImpl,
        VecZnxDftToVecZnxBigTmpBytesImpl, VecZnxDftZeroImpl,
    },
};

impl<B> VecZnxDftFromBytes<B> for Module<B>
where
    B: Backend + VecZnxDftFromBytesImpl<B>,
{
    fn vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B> {
        B::vec_znx_dft_from_bytes_impl(self.n(), cols, size, bytes)
    }
}

impl<B> VecZnxDftAllocBytes for Module<B>
where
    B: Backend + VecZnxDftAllocBytesImpl<B>,
{
    fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize {
        B::vec_znx_dft_alloc_bytes_impl(self.n(), cols, size)
    }
}

impl<B> VecZnxDftAlloc<B> for Module<B>
where
    B: Backend + VecZnxDftAllocImpl<B>,
{
    fn vec_znx_dft_alloc(&self, cols: usize, size: usize) -> VecZnxDftOwned<B> {
        B::vec_znx_dft_alloc_impl(self.n(), cols, size)
    }
}

impl<B> VecZnxDftToVecZnxBigTmpBytes for Module<B>
where
    B: Backend + VecZnxDftToVecZnxBigTmpBytesImpl<B>,
{
    fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self) -> usize {
        B::vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(self)
    }
}

impl<B> VecZnxDftToVecZnxBig<B> for Module<B>
where
    B: Backend + VecZnxDftToVecZnxBigImpl<B>,
{
    fn vec_znx_dft_to_vec_znx_big<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
    where
        R: VecZnxBigToMut<B>,
        A: VecZnxDftToRef<B>,
    {
        B::vec_znx_dft_to_vec_znx_big_impl(self, res, res_col, a, a_col, scratch);
    }
}

impl<B> VecZnxDftToVecZnxBigTmpA<B> for Module<B>
where
    B: Backend + VecZnxDftToVecZnxBigTmpAImpl<B>,
{
    fn vec_znx_dft_to_vec_znx_big_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
    where
        R: VecZnxBigToMut<B>,
        A: VecZnxDftToMut<B>,
    {
        B::vec_znx_dft_to_vec_znx_big_tmp_a_impl(self, res, res_col, a, a_col);
    }
}

impl<B> VecZnxDftToVecZnxBigConsume<B> for Module<B>
where
    B: Backend + VecZnxDftToVecZnxBigConsumeImpl<B>,
{
    fn vec_znx_dft_to_vec_znx_big_consume<D: Data>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
    where
        VecZnxDft<D, B>: VecZnxDftToMut<B>,
    {
        B::vec_znx_dft_to_vec_znx_big_consume_impl(self, a)
    }
}

impl<B> VecZnxDftFromVecZnx<B> for Module<B>
where
    B: Backend + VecZnxDftFromVecZnxImpl<B>,
{
    fn vec_znx_dft_from_vec_znx<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxDftToMut<B>,
        A: VecZnxToRef,
    {
        B::vec_znx_dft_from_vec_znx_impl(self, step, offset, res, res_col, a, a_col);
    }
}

impl<B> VecZnxDftAdd<B> for Module<B>
where
    B: Backend + VecZnxDftAddImpl<B>,
{
    fn vec_znx_dft_add<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
    where
        R: VecZnxDftToMut<B>,
        A: VecZnxDftToRef<B>,
        D: VecZnxDftToRef<B>,
    {
        B::vec_znx_dft_add_impl(self, res, res_col, a, a_col, b, b_col);
    }
}

impl<B> VecZnxDftAddInplace<B> for Module<B>
where
    B: Backend + VecZnxDftAddInplaceImpl<B>,
{
    fn vec_znx_dft_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxDftToMut<B>,
        A: VecZnxDftToRef<B>,
    {
        B::vec_znx_dft_add_inplace_impl(self, res, res_col, a, a_col);
    }
}

impl<B> VecZnxDftSub<B> for Module<B>
where
    B: Backend + VecZnxDftSubImpl<B>,
{
    fn vec_znx_dft_sub<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
    where
        R: VecZnxDftToMut<B>,
        A: VecZnxDftToRef<B>,
        D: VecZnxDftToRef<B>,
    {
        B::vec_znx_dft_sub_impl(self, res, res_col, a, a_col, b, b_col);
    }
}

impl<B> VecZnxDftSubABInplace<B> for Module<B>
where
    B: Backend + VecZnxDftSubABInplaceImpl<B>,
{
    fn vec_znx_dft_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxDftToMut<B>,
        A: VecZnxDftToRef<B>,
    {
        B::vec_znx_dft_sub_ab_inplace_impl(self, res, res_col, a, a_col);
    }
}

impl<B> VecZnxDftSubBAInplace<B> for Module<B>
where
    B: Backend + VecZnxDftSubBAInplaceImpl<B>,
{
    fn vec_znx_dft_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxDftToMut<B>,
        A: VecZnxDftToRef<B>,
    {
        B::vec_znx_dft_sub_ba_inplace_impl(self, res, res_col, a, a_col);
    }
}

impl<B> VecZnxDftCopy<B> for Module<B>
where
    B: Backend + VecZnxDftCopyImpl<B>,
{
    fn vec_znx_dft_copy<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxDftToMut<B>,
        A: VecZnxDftToRef<B>,
    {
        B::vec_znx_dft_copy_impl(self, step, offset, res, res_col, a, a_col);
    }
}

impl<B> VecZnxDftZero<B> for Module<B>
where
    B: Backend + VecZnxDftZeroImpl<B>,
{
    fn vec_znx_dft_zero<R>(&self, res: &mut R)
    where
        R: VecZnxDftToMut<B>,
    {
        B::vec_znx_dft_zero_impl(self, res);
    }
}
126
backend/src/hal/delegates/vmp_pmat.rs
Normal file
@@ -0,0 +1,126 @@
use crate::hal::{
    api::{
        VmpApply, VmpApplyAdd, VmpApplyAddTmpBytes, VmpApplyTmpBytes, VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatFromBytes,
        VmpPMatPrepare, VmpPrepareTmpBytes,
    },
    layouts::{Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef},
    oep::{
        VmpApplyAddImpl, VmpApplyAddTmpBytesImpl, VmpApplyImpl, VmpApplyTmpBytesImpl, VmpPMatAllocBytesImpl, VmpPMatAllocImpl,
        VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
    },
};

impl<B> VmpPMatAlloc<B> for Module<B>
where
    B: Backend + VmpPMatAllocImpl<B>,
{
    fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B> {
        B::vmp_pmat_alloc_impl(self.n(), rows, cols_in, cols_out, size)
    }
}

impl<B> VmpPMatAllocBytes for Module<B>
where
    B: Backend + VmpPMatAllocBytesImpl<B>,
{
    fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
        B::vmp_pmat_alloc_bytes_impl(self.n(), rows, cols_in, cols_out, size)
    }
}

impl<B> VmpPMatFromBytes<B> for Module<B>
where
    B: Backend + VmpPMatFromBytesImpl<B>,
{
    fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> VmpPMatOwned<B> {
        B::vmp_pmat_from_bytes_impl(self.n(), rows, cols_in, cols_out, size, bytes)
    }
}

impl<B> VmpPrepareTmpBytes for Module<B>
where
    B: Backend + VmpPrepareTmpBytesImpl<B>,
{
    fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
        B::vmp_prepare_tmp_bytes_impl(self, rows, cols_in, cols_out, size)
    }
}

impl<B> VmpPMatPrepare<B> for Module<B>
where
    B: Backend + VmpPMatPrepareImpl<B>,
{
    fn vmp_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<B>)
    where
        R: VmpPMatToMut<B>,
        A: MatZnxToRef,
    {
        B::vmp_prepare_impl(self, res, a, scratch)
    }
}

impl<B> VmpApplyTmpBytes for Module<B>
where
    B: Backend + VmpApplyTmpBytesImpl<B>,
{
    fn vmp_apply_tmp_bytes(
        &self,
        res_size: usize,
        a_size: usize,
        b_rows: usize,
        b_cols_in: usize,
        b_cols_out: usize,
        b_size: usize,
    ) -> usize {
        B::vmp_apply_tmp_bytes_impl(
            self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
        )
    }
}

impl<B> VmpApply<B> for Module<B>
where
    B: Backend + VmpApplyImpl<B>,
{
    fn vmp_apply<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
    where
        R: VecZnxDftToMut<B>,
        A: VecZnxDftToRef<B>,
        C: VmpPMatToRef<B>,
    {
        B::vmp_apply_impl(self, res, a, b, scratch);
    }
}

impl<B> VmpApplyAddTmpBytes for Module<B>
where
    B: Backend + VmpApplyAddTmpBytesImpl<B>,
{
    fn vmp_apply_add_tmp_bytes(
        &self,
        res_size: usize,
        a_size: usize,
        b_rows: usize,
        b_cols_in: usize,
        b_cols_out: usize,
        b_size: usize,
    ) -> usize {
        B::vmp_apply_add_tmp_bytes_impl(
            self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
        )
    }
}

impl<B> VmpApplyAdd<B> for Module<B>
where
    B: Backend + VmpApplyAddImpl<B>,
{
    fn vmp_apply_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
    where
        R: VecZnxDftToMut<B>,
        A: VecZnxDftToRef<B>,
        C: VmpPMatToRef<B>,
    {
        B::vmp_apply_add_impl(self, res, a, b, scale, scratch);
    }
}
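Taken together, the VmpPMat delegates describe a prepare-once, apply-many flow: a plain MatZnx is converted into the backend's prepared layout with vmp_prepare, and later contracted against a VecZnxDft with vmp_apply. A hedged sketch of that flow follows; how the Scratch<B> and the owned DFT vectors are constructed is not shown in this hunk, so they are simply taken as parameters, and the bounds assume the usual ToMut/ToRef impls on the owned types.

// Hedged sketch of the vector-matrix-product pipeline (prepare once, apply many).
use crate::hal::{
    api::{VmpApply, VmpPMatAlloc, VmpPMatPrepare, ZnxInfos},
    layouts::{Backend, MatZnxOwned, Module, Scratch, VecZnxDftOwned},
};

fn apply_prepared_matrix<B>(
    module: &Module<B>,
    mat: &MatZnxOwned,
    input: &VecZnxDftOwned<B>,
    output: &mut VecZnxDftOwned<B>,
    scratch: &mut Scratch<B>,
) where
    B: Backend,
    Module<B>: VmpPMatAlloc<B> + VmpPMatPrepare<B> + VmpApply<B>,
{
    // One-time conversion of the matrix into the backend's prepared layout.
    let mut pmat = module.vmp_pmat_alloc(mat.rows(), mat.cols_in(), mat.cols_out(), mat.size());
    module.vmp_prepare(&mut pmat, mat, scratch);
    // output <- input x pmat, carried out in the DFT domain.
    module.vmp_apply(output, input, &pmat, scratch);
}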
246
backend/src/hal/layouts/mat_znx.rs
Normal file
@@ -0,0 +1,246 @@
use crate::{
    alloc_aligned,
    hal::{
        api::{DataView, DataViewMut, ZnxInfos, ZnxSliceSize, ZnxView},
        layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, WriterTo},
    },
};

#[derive(PartialEq, Eq)]
pub struct MatZnx<D: Data> {
    data: D,
    n: usize,
    size: usize,
    rows: usize,
    cols_in: usize,
    cols_out: usize,
}

impl<D: Data> ZnxInfos for MatZnx<D> {
    fn cols(&self) -> usize {
        self.cols_in
    }

    fn rows(&self) -> usize {
        self.rows
    }

    fn n(&self) -> usize {
        self.n
    }

    fn size(&self) -> usize {
        self.size
    }
}

impl<D: Data> ZnxSliceSize for MatZnx<D> {
    fn sl(&self) -> usize {
        self.n() * self.cols_out()
    }
}

impl<D: Data> DataView for MatZnx<D> {
    type D = D;
    fn data(&self) -> &Self::D {
        &self.data
    }
}

impl<D: Data> DataViewMut for MatZnx<D> {
    fn data_mut(&mut self) -> &mut Self::D {
        &mut self.data
    }
}

impl<D: DataRef> ZnxView for MatZnx<D> {
    type Scalar = i64;
}

impl<D: Data> MatZnx<D> {
    pub fn cols_in(&self) -> usize {
        self.cols_in
    }

    pub fn cols_out(&self) -> usize {
        self.cols_out
    }
}

impl<D: DataRef> MatZnx<D> {
    pub fn bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
        rows * cols_in * VecZnx::<Vec<u8>>::alloc_bytes::<i64>(n, cols_out, size)
    }
}

impl<D: DataRef + From<Vec<u8>>> MatZnx<D> {
    pub(crate) fn new(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
        let data: Vec<u8> = alloc_aligned(Self::bytes_of(n, rows, cols_in, cols_out, size));
        Self {
            data: data.into(),
            n,
            size,
            rows,
            cols_in,
            cols_out,
        }
    }

    pub(crate) fn new_from_bytes(
        n: usize,
        rows: usize,
        cols_in: usize,
        cols_out: usize,
        size: usize,
        bytes: impl Into<Vec<u8>>,
    ) -> Self {
        let data: Vec<u8> = bytes.into();
        assert!(data.len() == Self::bytes_of(n, rows, cols_in, cols_out, size));
        Self {
            data: data.into(),
            n,
            size,
            rows,
            cols_in,
            cols_out,
        }
    }
}

impl<D: DataRef> MatZnx<D> {
    pub fn at(&self, row: usize, col: usize) -> VecZnx<&[u8]> {
        #[cfg(debug_assertions)]
        {
            assert!(row < self.rows(), "rows: {} >= {}", row, self.rows());
            assert!(col < self.cols_in(), "cols: {} >= {}", col, self.cols_in());
        }

        let self_ref: MatZnx<&[u8]> = self.to_ref();
        let nb_bytes: usize = VecZnx::<Vec<u8>>::alloc_bytes::<i64>(self.n, self.cols_out, self.size);
        let start: usize = nb_bytes * self.cols() * row + col * nb_bytes;
        let end: usize = start + nb_bytes;

        VecZnx {
            data: &self_ref.data[start..end],
            n: self.n,
            cols: self.cols_out,
            size: self.size,
            max_size: self.size,
        }
    }
}

impl<D: DataMut> MatZnx<D> {
    pub fn at_mut(&mut self, row: usize, col: usize) -> VecZnx<&mut [u8]> {
        #[cfg(debug_assertions)]
        {
            assert!(row < self.rows(), "rows: {} >= {}", row, self.rows());
            assert!(col < self.cols_in(), "cols: {} >= {}", col, self.cols_in());
        }

        let n: usize = self.n();
        let cols_out: usize = self.cols_out();
        let cols_in: usize = self.cols_in();
        let size: usize = self.size();

        let self_ref: MatZnx<&mut [u8]> = self.to_mut();
        let nb_bytes: usize = VecZnx::<Vec<u8>>::alloc_bytes::<i64>(n, cols_out, size);
        let start: usize = nb_bytes * cols_in * row + col * nb_bytes;
        let end: usize = start + nb_bytes;

        VecZnx {
            data: &mut self_ref.data[start..end],
            n,
            cols: cols_out,
            size,
            max_size: size,
        }
    }
}

pub type MatZnxOwned = MatZnx<Vec<u8>>;
pub type MatZnxMut<'a> = MatZnx<&'a mut [u8]>;
pub type MatZnxRef<'a> = MatZnx<&'a [u8]>;

pub trait MatZnxToRef {
    fn to_ref(&self) -> MatZnx<&[u8]>;
}

impl<D: DataRef> MatZnxToRef for MatZnx<D> {
    fn to_ref(&self) -> MatZnx<&[u8]> {
        MatZnx {
            data: self.data.as_ref(),
            n: self.n,
            rows: self.rows,
            cols_in: self.cols_in,
            cols_out: self.cols_out,
            size: self.size,
        }
    }
}

pub trait MatZnxToMut {
    fn to_mut(&mut self) -> MatZnx<&mut [u8]>;
}

impl<D: DataMut> MatZnxToMut for MatZnx<D> {
    fn to_mut(&mut self) -> MatZnx<&mut [u8]> {
        MatZnx {
            data: self.data.as_mut(),
            n: self.n,
            rows: self.rows,
            cols_in: self.cols_in,
            cols_out: self.cols_out,
            size: self.size,
        }
    }
}

impl<D: Data> MatZnx<D> {
    pub(crate) fn from_data(data: D, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
        Self {
            data,
            n,
            rows,
            cols_in,
            cols_out,
            size,
        }
    }
}

use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};

impl<D: DataMut> ReaderFrom for MatZnx<D> {
    fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
        self.n = reader.read_u64::<LittleEndian>()? as usize;
        self.size = reader.read_u64::<LittleEndian>()? as usize;
        self.rows = reader.read_u64::<LittleEndian>()? as usize;
        self.cols_in = reader.read_u64::<LittleEndian>()? as usize;
        self.cols_out = reader.read_u64::<LittleEndian>()? as usize;
        let len: usize = reader.read_u64::<LittleEndian>()? as usize;
        let buf: &mut [u8] = self.data.as_mut();
        if buf.len() != len {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                format!("self.data.len()={} != read len={}", buf.len(), len),
            ));
        }
        reader.read_exact(&mut buf[..len])?;
        Ok(())
    }
}

impl<D: DataRef> WriterTo for MatZnx<D> {
    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
        writer.write_u64::<LittleEndian>(self.n as u64)?;
        writer.write_u64::<LittleEndian>(self.size as u64)?;
        writer.write_u64::<LittleEndian>(self.rows as u64)?;
        writer.write_u64::<LittleEndian>(self.cols_in as u64)?;
        writer.write_u64::<LittleEndian>(self.cols_out as u64)?;
        let buf: &[u8] = self.data.as_ref();
        writer.write_u64::<LittleEndian>(buf.len() as u64)?;
        writer.write_all(buf)?;
        Ok(())
    }
}
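The at/at_mut accessors above carve one (row, col) entry out of the flat byte buffer: entries are laid out row-major, and each entry is a full VecZnx with cols_out columns and size limbs. A worked example of the offset arithmetic, using illustrative parameters that are not taken from this commit:

// Worked example of the byte offsets computed in MatZnx::at / at_mut.
fn main() {
    // Illustrative dimensions only: n = 1024, cols_in = 3, cols_out = 2, size = 4.
    let (n, cols_in, cols_out, size) = (1024usize, 3usize, 2usize, 4usize);
    // Bytes of one stored entry, mirroring VecZnx::alloc_bytes::<i64>(n, cols_out, size).
    let nb_bytes = n * cols_out * size * std::mem::size_of::<i64>();
    // Entry (row = 1, col = 2) starts after one full row (cols_in entries) plus two entries.
    let (row, col) = (1usize, 2usize);
    let start = nb_bytes * cols_in * row + col * nb_bytes;
    let end = start + nb_bytes;
    assert_eq!(start, 5 * nb_bytes);
    assert_eq!(end - start, nb_bytes);
}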
25
backend/src/hal/layouts/mod.rs
Normal file
@@ -0,0 +1,25 @@
mod mat_znx;
mod module;
mod scalar_znx;
mod scratch;
mod serialization;
mod svp_ppol;
mod vec_znx;
mod vec_znx_big;
mod vec_znx_dft;
mod vmp_pmat;

pub use mat_znx::*;
pub use module::*;
pub use scalar_znx::*;
pub use scratch::*;
pub use serialization::*;
pub use svp_ppol::*;
pub use vec_znx::*;
pub use vec_znx_big::*;
pub use vec_znx_dft::*;
pub use vmp_pmat::*;

pub trait Data = PartialEq + Eq + Sized;
pub trait DataRef = Data + AsRef<[u8]>;
pub trait DataMut = DataRef + AsMut<[u8]>;
@@ -1,71 +1,56 @@
use std::{marker::PhantomData, ptr::NonNull};

use crate::GALOISGENERATOR;
use crate::ffi::module::{MODULE, delete_module_info, module_info_t, new_module_info};
use std::marker::PhantomData;

#[derive(Copy, Clone)]
#[repr(u8)]
pub enum BACKEND {
    FFT64,
    NTT120,
}

pub trait Backend {
    const KIND: BACKEND;
    fn module_type() -> u32;
}

pub struct FFT64;
pub struct NTT120;

impl Backend for FFT64 {
    const KIND: BACKEND = BACKEND::FFT64;
    fn module_type() -> u32 {
        0
    }
}

impl Backend for NTT120 {
    const KIND: BACKEND = BACKEND::NTT120;
    fn module_type() -> u32 {
        1
    }
pub trait Backend: Sized {
    type Handle: 'static;
    unsafe fn destroy(handle: NonNull<Self::Handle>);
}

pub struct Module<B: Backend> {
    pub ptr: *mut MODULE,
    n: usize,
    ptr: NonNull<B::Handle>,
    n: u64,
    _marker: PhantomData<B>,
}

impl<B: Backend> Module<B> {
    // Instantiates a new module.
    pub fn new(n: usize) -> Self {
        unsafe {
            let m: *mut module_info_t = new_module_info(n as u64, B::module_type());
            if m.is_null() {
                panic!("Failed to create module.");
            }
            Self {
                ptr: m,
                n: n,
                _marker: PhantomData,
            }
    /// Construct from a raw pointer managed elsewhere.
    /// SAFETY: `ptr` must be non-null and remain valid for the lifetime of this Module.
    #[inline]
    pub unsafe fn from_raw_parts(ptr: *mut B::Handle, n: u64) -> Self {
        Self {
            ptr: NonNull::new(ptr).expect("null module ptr"),
            n,
            _marker: PhantomData,
        }
    }

    pub fn n(&self) -> usize {
        self.n
    #[inline]
    pub unsafe fn ptr(&self) -> *mut <B as Backend>::Handle {
        self.ptr.as_ptr()
    }

    #[inline]
    pub fn n(&self) -> usize {
        self.n as usize
    }
    #[inline]
    pub fn as_mut_ptr(&self) -> *mut B::Handle {
        self.ptr.as_ptr()
    }

    #[inline]
    pub fn log_n(&self) -> usize {
        (usize::BITS - (self.n() - 1).leading_zeros()) as _
    }

    #[inline]
    pub fn cyclotomic_order(&self) -> u64 {
        (self.n() << 1) as _
    }

    // Returns GALOISGENERATOR^|generator| * sign(generator)
    #[inline]
    pub fn galois_element(&self, generator: i64) -> i64 {
        if generator == 0 {
            return 1;
@@ -74,6 +59,7 @@ impl<B: Backend> Module<B> {
    }

    // Returns gen^-1
    #[inline]
    pub fn galois_element_inv(&self, gal_el: i64) -> i64 {
        if gal_el == 0 {
            panic!("cannot invert 0")
@@ -85,11 +71,11 @@ impl<B: Backend> Module<B> {

impl<B: Backend> Drop for Module<B> {
    fn drop(&mut self) {
        unsafe { delete_module_info(self.ptr) }
        unsafe { B::destroy(self.ptr) }
    }
}

fn mod_exp_u64(x: u64, e: usize) -> u64 {
pub fn mod_exp_u64(x: u64, e: usize) -> u64 {
    let mut y: u64 = 1;
    let mut x_pow: u64 = x;
    let mut exp = e;
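With this rewrite a backend is no longer tied to the spqlios MODULE pointer: it only has to name an opaque Handle type and provide destroy, which Module::drop calls through B::destroy. A hedged sketch of the shape of such a backend; the handle struct and the two extern functions are placeholders standing in for a real FFI binding and are not part of this commit.

// Hedged sketch of a backend under the new Handle/destroy contract.
use std::ptr::NonNull;

#[repr(C)]
pub struct DummyHandle {
    _private: [u8; 0], // opaque FFI handle, placeholder only
}

extern "C" {
    fn dummy_new_module(n: u64) -> *mut DummyHandle; // placeholder binding
    fn dummy_free_module(ptr: *mut DummyHandle); // placeholder binding
}

pub struct DummyBackend;

impl Backend for DummyBackend {
    type Handle = DummyHandle;

    unsafe fn destroy(handle: NonNull<Self::Handle>) {
        // Invoked from Module::<DummyBackend>::drop.
        dummy_free_module(handle.as_ptr());
    }
}

// A module would then be built from the raw handle, e.g.:
// let module = unsafe { Module::<DummyBackend>::from_raw_parts(dummy_new_module(1024), 1024) };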
@@ -1,20 +1,24 @@
use crate::ffi::vec_znx;
use crate::znx_base::ZnxInfos;
use crate::{
    Backend, DataView, DataViewMut, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned,
};
use rand::seq::SliceRandom;
use rand_core::RngCore;
use rand_distr::{Distribution, weighted::WeightedIndex};
use sampling::source::Source;

pub struct ScalarZnx<D> {
use crate::{
    alloc_aligned,
    hal::{
        api::{DataView, DataViewMut, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
        layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo},
    },
};

#[derive(PartialEq, Eq)]
pub struct ScalarZnx<D: Data> {
    pub(crate) data: D,
    pub(crate) n: usize,
    pub(crate) cols: usize,
}

impl<D> ZnxInfos for ScalarZnx<D> {
impl<D: Data> ZnxInfos for ScalarZnx<D> {
    fn cols(&self) -> usize {
        self.cols
    }
@@ -32,30 +36,30 @@ impl<D> ZnxInfos for ScalarZnx<D> {
    }
}

impl<D> ZnxSliceSize for ScalarZnx<D> {
impl<D: Data> ZnxSliceSize for ScalarZnx<D> {
    fn sl(&self) -> usize {
        self.n()
    }
}

impl<D> DataView for ScalarZnx<D> {
impl<D: Data> DataView for ScalarZnx<D> {
    type D = D;
    fn data(&self) -> &Self::D {
        &self.data
    }
}

impl<D> DataViewMut for ScalarZnx<D> {
impl<D: Data> DataViewMut for ScalarZnx<D> {
    fn data_mut(&mut self) -> &mut Self::D {
        &mut self.data
    }
}

impl<D: AsRef<[u8]>> ZnxView for ScalarZnx<D> {
impl<D: DataRef> ZnxView for ScalarZnx<D> {
    type Scalar = i64;
}

impl<D: AsMut<[u8]> + AsRef<[u8]>> ScalarZnx<D> {
impl<D: DataMut> ScalarZnx<D> {
    pub fn fill_ternary_prob(&mut self, col: usize, prob: f64, source: &mut Source) {
        let choices: [i64; 3] = [-1, 0, 1];
        let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0];
@@ -103,11 +107,13 @@ impl<D: AsMut<[u8]> + AsRef<[u8]>> ScalarZnx<D> {
    }
}

impl<D: From<Vec<u8>>> ScalarZnx<D> {
    pub(crate) fn bytes_of(n: usize, cols: usize) -> usize {
impl<D: DataRef> ScalarZnx<D> {
    pub fn bytes_of(n: usize, cols: usize) -> usize {
        n * cols * size_of::<i64>()
    }
}

impl<D: DataRef + From<Vec<u8>>> ScalarZnx<D> {
    pub fn new(n: usize, cols: usize) -> Self {
        let data: Vec<u8> = alloc_aligned::<u8>(Self::bytes_of(n, cols));
        Self {
@@ -128,94 +134,18 @@ impl<D: From<Vec<u8>>> ScalarZnx<D> {
    }
}

impl<D: DataMut> ZnxZero for ScalarZnx<D> {
    fn zero(&mut self) {
        self.raw_mut().fill(0)
    }
    fn zero_at(&mut self, i: usize, j: usize) {
        self.at_mut(i, j).fill(0);
    }
}

pub type ScalarZnxOwned = ScalarZnx<Vec<u8>>;

pub(crate) fn bytes_of_scalar_znx<B: Backend>(module: &Module<B>, cols: usize) -> usize {
    ScalarZnxOwned::bytes_of(module.n(), cols)
}

pub trait ScalarZnxAlloc {
    fn bytes_of_scalar_znx(&self, cols: usize) -> usize;
    fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned;
    fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned;
}

impl<B: Backend> ScalarZnxAlloc for Module<B> {
    fn bytes_of_scalar_znx(&self, cols: usize) -> usize {
        ScalarZnxOwned::bytes_of(self.n(), cols)
    }
    fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned {
        ScalarZnxOwned::new(self.n(), cols)
    }
    fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned {
        ScalarZnxOwned::new_from_bytes(self.n(), cols, bytes)
    }
}

pub trait ScalarZnxOps {
    fn scalar_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: ScalarZnxToMut,
        A: ScalarZnxToRef;

    /// Applies the automorphism X^i -> X^ik on the selected column of `a`.
    fn scalar_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
    where
        A: ScalarZnxToMut;
}

impl<B: Backend> ScalarZnxOps for Module<B> {
    fn scalar_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: ScalarZnxToMut,
        A: ScalarZnxToRef,
    {
        let a: ScalarZnx<&[u8]> = a.to_ref();
        let mut res: ScalarZnx<&mut [u8]> = res.to_mut();
        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), self.n());
            assert_eq!(res.n(), self.n());
        }
        unsafe {
            vec_znx::vec_znx_automorphism(
                self.ptr,
                k,
                res.at_mut_ptr(res_col, 0),
                res.size() as u64,
                res.sl() as u64,
                a.at_ptr(a_col, 0),
                a.size() as u64,
                a.sl() as u64,
            )
        }
    }

    fn scalar_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
    where
        A: ScalarZnxToMut,
    {
        let mut a: ScalarZnx<&mut [u8]> = a.to_mut();
        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), self.n());
        }
        unsafe {
            vec_znx::vec_znx_automorphism(
                self.ptr,
                k,
                a.at_mut_ptr(a_col, 0),
                a.size() as u64,
                a.sl() as u64,
                a.at_ptr(a_col, 0),
                a.size() as u64,
                a.sl() as u64,
            )
        }
    }
}

impl<D> ScalarZnx<D> {
impl<D: Data> ScalarZnx<D> {
    pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self {
        Self { data, n, cols }
    }
@@ -225,10 +155,7 @@ pub trait ScalarZnxToRef {
    fn to_ref(&self) -> ScalarZnx<&[u8]>;
}

impl<D> ScalarZnxToRef for ScalarZnx<D>
where
    D: AsRef<[u8]>,
{
impl<D: DataRef> ScalarZnxToRef for ScalarZnx<D> {
    fn to_ref(&self) -> ScalarZnx<&[u8]> {
        ScalarZnx {
            data: self.data.as_ref(),
@@ -242,10 +169,7 @@ pub trait ScalarZnxToMut {
    fn to_mut(&mut self) -> ScalarZnx<&mut [u8]>;
}

impl<D> ScalarZnxToMut for ScalarZnx<D>
where
    D: AsRef<[u8]> + AsMut<[u8]>,
{
impl<D: DataMut> ScalarZnxToMut for ScalarZnx<D> {
    fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
        ScalarZnx {
            data: self.data.as_mut(),
@@ -255,30 +179,56 @@ where
    }
}

impl<D> VecZnxToRef for ScalarZnx<D>
where
    D: AsRef<[u8]>,
{
impl<D: DataRef> VecZnxToRef for ScalarZnx<D> {
    fn to_ref(&self) -> VecZnx<&[u8]> {
        VecZnx {
            data: self.data.as_ref(),
            n: self.n,
            cols: self.cols,
            size: 1,
            max_size: 1,
        }
    }
}

impl<D> VecZnxToMut for ScalarZnx<D>
where
    D: AsRef<[u8]> + AsMut<[u8]>,
{
impl<D: DataMut> VecZnxToMut for ScalarZnx<D> {
    fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
        VecZnx {
            data: self.data.as_mut(),
            n: self.n,
            cols: self.cols,
            size: 1,
            max_size: 1,
        }
    }
}

use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};

impl<D: DataMut> ReaderFrom for ScalarZnx<D> {
    fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
        self.n = reader.read_u64::<LittleEndian>()? as usize;
        self.cols = reader.read_u64::<LittleEndian>()? as usize;
        let len: usize = reader.read_u64::<LittleEndian>()? as usize;
        let buf: &mut [u8] = self.data.as_mut();
        if buf.len() != len {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                format!("self.data.len()={} != read len={}", buf.len(), len),
            ));
        }
        reader.read_exact(&mut buf[..len])?;
        Ok(())
    }
}

impl<D: DataRef> WriterTo for ScalarZnx<D> {
    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
        writer.write_u64::<LittleEndian>(self.n as u64)?;
        writer.write_u64::<LittleEndian>(self.cols as u64)?;
        let buf: &[u8] = self.data.as_ref();
        writer.write_u64::<LittleEndian>(buf.len() as u64)?;
        writer.write_all(buf)?;
        Ok(())
    }
}
13
backend/src/hal/layouts/scratch.rs
Normal file
@@ -0,0 +1,13 @@
use std::marker::PhantomData;

use crate::hal::layouts::Backend;

pub struct ScratchOwned<B: Backend> {
    pub(crate) data: Vec<u8>,
    pub(crate) _phantom: PhantomData<B>,
}

pub struct Scratch<B: Backend> {
    pub(crate) _phantom: PhantomData<B>,
    pub(crate) data: [u8],
}
9
backend/src/hal/layouts/serialization.rs
Normal file
@@ -0,0 +1,9 @@
use std::io::{Read, Result, Write};

pub trait WriterTo {
    fn write_to<W: Write>(&self, writer: &mut W) -> Result<()>;
}

pub trait ReaderFrom {
    fn read_from<R: Read>(&mut self, reader: &mut R) -> Result<()>;
}
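These two traits are the crate-wide (de)serialization contract; the layout types in this commit implement them with little-endian, length-prefixed fields. A hedged sketch of implementing the pair for a small illustrative type that is not part of this commit:

// Hedged sketch: WriterTo/ReaderFrom for an illustrative single-field type.
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{Read, Result, Write};

struct Tag {
    id: u64,
}

impl WriterTo for Tag {
    fn write_to<W: Write>(&self, writer: &mut W) -> Result<()> {
        writer.write_u64::<LittleEndian>(self.id)
    }
}

impl ReaderFrom for Tag {
    fn read_from<R: Read>(&mut self, reader: &mut R) -> Result<()> {
        self.id = reader.read_u64::<LittleEndian>()?;
        Ok(())
    }
}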
151
backend/src/hal/layouts/svp_ppol.rs
Normal file
@@ -0,0 +1,151 @@
use std::marker::PhantomData;

use crate::{
    alloc_aligned,
    hal::{
        api::{DataView, DataViewMut, ZnxInfos},
        layouts::{Backend, Data, DataMut, DataRef, ReaderFrom, WriterTo},
    },
};

#[derive(PartialEq, Eq)]
pub struct SvpPPol<D: Data, B: Backend> {
    data: D,
    n: usize,
    cols: usize,
    _phantom: PhantomData<B>,
}

impl<D: Data, B: Backend> ZnxInfos for SvpPPol<D, B> {
    fn cols(&self) -> usize {
        self.cols
    }

    fn rows(&self) -> usize {
        1
    }

    fn n(&self) -> usize {
        self.n
    }

    fn size(&self) -> usize {
        1
    }
}

impl<D: Data, B: Backend> DataView for SvpPPol<D, B> {
    type D = D;
    fn data(&self) -> &Self::D {
        &self.data
    }
}

impl<D: Data, B: Backend> DataViewMut for SvpPPol<D, B> {
    fn data_mut(&mut self) -> &mut Self::D {
        &mut self.data
    }
}

pub trait SvpPPolBytesOf {
    fn bytes_of(n: usize, cols: usize) -> usize;
}

impl<D: Data + From<Vec<u8>>, B: Backend> SvpPPol<D, B>
where
    SvpPPol<D, B>: SvpPPolBytesOf,
{
    pub(crate) fn alloc(n: usize, cols: usize) -> Self {
        let data: Vec<u8> = alloc_aligned::<u8>(Self::bytes_of(n, cols));
        Self {
            data: data.into(),
            n,
            cols,
            _phantom: PhantomData,
        }
    }

    pub(crate) fn from_bytes(n: usize, cols: usize, bytes: impl Into<Vec<u8>>) -> Self {
        let data: Vec<u8> = bytes.into();
        assert!(data.len() == Self::bytes_of(n, cols));
        Self {
            data: data.into(),
            n,
            cols,
            _phantom: PhantomData,
        }
    }
}

pub type SvpPPolOwned<B> = SvpPPol<Vec<u8>, B>;

pub trait SvpPPolToRef<B: Backend> {
    fn to_ref(&self) -> SvpPPol<&[u8], B>;
}

impl<D: DataRef, B: Backend> SvpPPolToRef<B> for SvpPPol<D, B> {
    fn to_ref(&self) -> SvpPPol<&[u8], B> {
        SvpPPol {
            data: self.data.as_ref(),
            n: self.n,
            cols: self.cols,
            _phantom: PhantomData,
        }
    }
}

pub trait SvpPPolToMut<B: Backend> {
    fn to_mut(&mut self) -> SvpPPol<&mut [u8], B>;
}

impl<D: DataMut, B: Backend> SvpPPolToMut<B> for SvpPPol<D, B> {
    fn to_mut(&mut self) -> SvpPPol<&mut [u8], B> {
        SvpPPol {
            data: self.data.as_mut(),
            n: self.n,
            cols: self.cols,
            _phantom: PhantomData,
        }
    }
}

impl<D: Data, B: Backend> SvpPPol<D, B> {
    pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self {
        Self {
            data,
            n,
            cols,
            _phantom: PhantomData,
        }
    }
}

use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};

impl<D: DataMut, B: Backend> ReaderFrom for SvpPPol<D, B> {
    fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
        self.n = reader.read_u64::<LittleEndian>()? as usize;
        self.cols = reader.read_u64::<LittleEndian>()? as usize;
        let len: usize = reader.read_u64::<LittleEndian>()? as usize;
        let buf: &mut [u8] = self.data.as_mut();
        if buf.len() != len {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                format!("self.data.len()={} != read len={}", buf.len(), len),
            ));
        }
        reader.read_exact(&mut buf[..len])?;
        Ok(())
    }
}

impl<D: DataRef, B: Backend> WriterTo for SvpPPol<D, B> {
    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
        writer.write_u64::<LittleEndian>(self.n as u64)?;
        writer.write_u64::<LittleEndian>(self.cols as u64)?;
        let buf: &[u8] = self.data.as_ref();
        writer.write_u64::<LittleEndian>(buf.len() as u64)?;
        writer.write_all(buf)?;
        Ok(())
    }
}
241
backend/src/hal/layouts/vec_znx.rs
Normal file
@@ -0,0 +1,241 @@
use std::fmt;
|
||||
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
hal::{
|
||||
api::{DataView, DataViewMut, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
|
||||
layouts::{Data, DataMut, DataRef, ReaderFrom, WriterTo},
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct VecZnx<D: Data> {
|
||||
pub(crate) data: D,
|
||||
pub(crate) n: usize,
|
||||
pub(crate) cols: usize,
|
||||
pub(crate) size: usize,
|
||||
pub(crate) max_size: usize,
|
||||
}
|
||||
|
||||
impl<D: DataRef> fmt::Debug for VecZnx<D> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxInfos for VecZnx<D> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxSliceSize for VecZnx<D> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> DataView for VecZnx<D> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> DataViewMut for VecZnx<D> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> ZnxView for VecZnx<D> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
impl VecZnx<Vec<u8>> {
|
||||
pub fn rsh_scratch_space(n: usize) -> usize {
|
||||
n * std::mem::size_of::<i64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataMut> ZnxZero for VecZnx<D> {
|
||||
fn zero(&mut self) {
|
||||
self.raw_mut().fill(0)
|
||||
}
|
||||
fn zero_at(&mut self, i: usize, j: usize) {
|
||||
self.at_mut(i, j).fill(0);
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> VecZnx<D> {
|
||||
pub fn alloc_bytes<Scalar: Sized>(n: usize, cols: usize, size: usize) -> usize {
|
||||
n * cols * size * size_of::<Scalar>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef + From<Vec<u8>>> VecZnx<D> {
|
||||
pub fn new<Scalar: Sized>(n: usize, cols: usize, size: usize) -> Self {
|
||||
let data: Vec<u8> = alloc_aligned::<u8>(Self::alloc_bytes::<Scalar>(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
max_size: size,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_bytes<Scalar: Sized>(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == Self::alloc_bytes::<Scalar>(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
max_size: size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> VecZnx<D> {
|
||||
pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
max_size: size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> fmt::Display for VecZnx<D> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"VecZnx(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {}:", col)?;
|
||||
for size in 0..self.size {
|
||||
let coeffs = self.at(col, size);
|
||||
write!(f, " Size {}: [", size)?;
|
||||
|
||||
let max_show = 100;
|
||||
let show_count = coeffs.len().min(max_show);
|
||||
|
||||
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "{}", coeff)?;
|
||||
}
|
||||
|
||||
if coeffs.len() > max_show {
|
||||
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
|
||||
}
|
||||
|
||||
writeln!(f, "]")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub type VecZnxOwned = VecZnx<Vec<u8>>;
|
||||
pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>;
|
||||
pub type VecZnxRef<'a> = VecZnx<&'a [u8]>;
|
||||
|
||||
pub trait VecZnxToRef {
|
||||
fn to_ref(&self) -> VecZnx<&[u8]>;
|
||||
}
|
||||
|
||||
impl<D: DataRef> VecZnxToRef for VecZnx<D> {
|
||||
fn to_ref(&self) -> VecZnx<&[u8]> {
|
||||
VecZnx {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
max_size: self.max_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxToMut {
|
||||
fn to_mut(&mut self) -> VecZnx<&mut [u8]>;
|
||||
}
|
||||
|
||||
impl<D: DataMut> VecZnxToMut for VecZnx<D> {
|
||||
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
|
||||
VecZnx {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
max_size: self.max_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> VecZnx<D> {
|
||||
pub fn clone(&self) -> VecZnx<Vec<u8>> {
|
||||
let self_ref: VecZnx<&[u8]> = self.to_ref();
|
||||
VecZnx {
|
||||
data: self_ref.data.to_vec(),
|
||||
n: self_ref.n,
|
||||
cols: self_ref.cols,
|
||||
size: self_ref.size,
|
||||
max_size: self_ref.max_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
|
||||
impl<D: DataMut> ReaderFrom for VecZnx<D> {
|
||||
fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
|
||||
self.n = reader.read_u64::<LittleEndian>()? as usize;
|
||||
self.cols = reader.read_u64::<LittleEndian>()? as usize;
|
||||
self.size = reader.read_u64::<LittleEndian>()? as usize;
|
||||
self.max_size = reader.read_u64::<LittleEndian>()? as usize;
|
||||
let len: usize = reader.read_u64::<LittleEndian>()? as usize;
|
||||
let buf: &mut [u8] = self.data.as_mut();
|
||||
if buf.len() != len {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
format!("self.data.len()={} != read len={}", buf.len(), len),
|
||||
));
|
||||
}
|
||||
reader.read_exact(&mut buf[..len])?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> WriterTo for VecZnx<D> {
|
||||
fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
|
||||
writer.write_u64::<LittleEndian>(self.n as u64)?;
|
||||
writer.write_u64::<LittleEndian>(self.cols as u64)?;
|
||||
writer.write_u64::<LittleEndian>(self.size as u64)?;
|
||||
writer.write_u64::<LittleEndian>(self.max_size as u64)?;
|
||||
let buf: &[u8] = self.data.as_ref();
|
||||
writer.write_u64::<LittleEndian>(buf.len() as u64)?;
|
||||
writer.write_all(buf)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
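// Illustrative usage sketch (assumed public re-export paths, not part of the committed
// file): round-tripping a VecZnx through the WriterTo / ReaderFrom impls above.
use backend::hal::api::ZnxZero;
use backend::hal::layouts::{ReaderFrom, VecZnxOwned, WriterTo};

fn vec_znx_roundtrip() -> std::io::Result<()> {
    let mut a: VecZnxOwned = VecZnxOwned::new::<i64>(8, 2, 3);
    a.zero();

    // Serialize into a byte buffer, then read back into an identically shaped VecZnx.
    let mut bytes: Vec<u8> = Vec::new();
    a.write_to(&mut bytes)?;

    let mut b: VecZnxOwned = VecZnxOwned::new::<i64>(8, 2, 3);
    b.read_from(&mut bytes.as_slice())?;
    assert!(a == b);
    Ok(())
}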
148
backend/src/hal/layouts/vec_znx_big.rs
Normal file
@@ -0,0 +1,148 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use rand_distr::num_traits::Zero;
|
||||
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
hal::{
|
||||
api::{DataView, DataViewMut, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero},
|
||||
layouts::{Backend, Data, DataMut, DataRef},
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct VecZnxBig<D: Data, B: Backend> {
|
||||
pub(crate) data: D,
|
||||
pub(crate) n: usize,
|
||||
pub(crate) cols: usize,
|
||||
pub(crate) size: usize,
|
||||
pub(crate) max_size: usize,
|
||||
pub(crate) _phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> ZnxInfos for VecZnxBig<D, B> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> DataView for VecZnxBig<D, B> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> DataViewMut for VecZnxBig<D, B> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxBigBytesOf {
|
||||
fn bytes_of(n: usize, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
impl<D: DataMut, B: Backend> ZnxZero for VecZnxBig<D, B>
|
||||
where
|
||||
Self: ZnxViewMut,
|
||||
<Self as ZnxView>::Scalar: Zero + Copy,
|
||||
{
|
||||
fn zero(&mut self) {
|
||||
self.raw_mut().fill(<Self as ZnxView>::Scalar::zero())
|
||||
}
|
||||
fn zero_at(&mut self, i: usize, j: usize) {
|
||||
self.at_mut(i, j).fill(<Self as ZnxView>::Scalar::zero());
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef + From<Vec<u8>>, B: Backend> VecZnxBig<D, B>
|
||||
where
|
||||
VecZnxBig<D, B>: VecZnxBigBytesOf,
|
||||
{
|
||||
pub(crate) fn new(n: usize, cols: usize, size: usize) -> Self {
|
||||
let data = alloc_aligned::<u8>(Self::bytes_of(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
max_size: size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == Self::bytes_of(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
max_size: size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> VecZnxBig<D, B> {
|
||||
pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
max_size: size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type VecZnxBigOwned<B> = VecZnxBig<Vec<u8>, B>;
|
||||
|
||||
pub trait VecZnxBigToRef<B: Backend> {
|
||||
fn to_ref(&self) -> VecZnxBig<&[u8], B>;
|
||||
}
|
||||
|
||||
impl<D: DataRef, B: Backend> VecZnxBigToRef<B> for VecZnxBig<D, B> {
|
||||
fn to_ref(&self) -> VecZnxBig<&[u8], B> {
|
||||
VecZnxBig {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
max_size: self.max_size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxBigToMut<B: Backend> {
|
||||
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B>;
|
||||
}
|
||||
|
||||
impl<D: DataMut, B: Backend> VecZnxBigToMut<B> for VecZnxBig<D, B> {
|
||||
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
|
||||
VecZnxBig {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
max_size: self.max_size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
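// Illustrative usage sketch (assumed paths): backend-generic code can accept any
// big-coefficient container by bounding on VecZnxBigToRef and reading its shape
// through ZnxInfos.
use backend::hal::api::ZnxInfos;
use backend::hal::layouts::{Backend, VecZnxBigToRef};

fn big_limb_count<B: Backend, A: VecZnxBigToRef<B>>(a: &A) -> usize {
    let a_ref = a.to_ref();
    // One limb per (column, size) pair.
    a_ref.cols() * a_ref.size()
}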
166
backend/src/hal/layouts/vec_znx_dft.rs
Normal file
@@ -0,0 +1,166 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use rand_distr::num_traits::Zero;
|
||||
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
hal::{
|
||||
api::{DataView, DataViewMut, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero},
|
||||
layouts::{Backend, Data, DataMut, DataRef, VecZnxBig},
|
||||
},
|
||||
};
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct VecZnxDft<D: Data, B: Backend> {
|
||||
pub(crate) data: D,
|
||||
pub(crate) n: usize,
|
||||
pub(crate) cols: usize,
|
||||
pub(crate) size: usize,
|
||||
pub(crate) max_size: usize,
|
||||
pub(crate) _phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> VecZnxDft<D, B> {
|
||||
pub fn into_big(self) -> VecZnxBig<D, B> {
|
||||
VecZnxBig::<D, B>::from_data(self.data, self.n, self.cols, self.size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> ZnxInfos for VecZnxDft<D, B> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> DataView for VecZnxDft<D, B> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> DataViewMut for VecZnxDft<D, B> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef, B: Backend> VecZnxDft<D, B> {
|
||||
pub fn max_size(&self) -> usize {
|
||||
self.max_size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataMut, B: Backend> VecZnxDft<D, B> {
|
||||
pub fn set_size(&mut self, size: usize) {
|
||||
assert!(size <= self.max_size);
|
||||
self.size = size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataMut, B: Backend> ZnxZero for VecZnxDft<D, B>
|
||||
where
|
||||
Self: ZnxViewMut,
|
||||
<Self as ZnxView>::Scalar: Zero + Copy,
|
||||
{
|
||||
fn zero(&mut self) {
|
||||
self.raw_mut().fill(<Self as ZnxView>::Scalar::zero())
|
||||
}
|
||||
fn zero_at(&mut self, i: usize, j: usize) {
|
||||
self.at_mut(i, j).fill(<Self as ZnxView>::Scalar::zero());
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxDftBytesOf {
|
||||
fn bytes_of(n: usize, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
impl<D: DataRef + From<Vec<u8>>, B: Backend> VecZnxDft<D, B>
|
||||
where
|
||||
VecZnxDft<D, B>: VecZnxDftBytesOf,
|
||||
{
|
||||
pub(crate) fn alloc(n: usize, cols: usize, size: usize) -> Self {
|
||||
let data: Vec<u8> = alloc_aligned::<u8>(Self::bytes_of(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
max_size: size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == Self::bytes_of(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
max_size: size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type VecZnxDftOwned<B> = VecZnxDft<Vec<u8>, B>;
|
||||
|
||||
impl<D: Data, B: Backend> VecZnxDft<D, B> {
|
||||
pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
max_size: size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxDftToRef<B: Backend> {
|
||||
fn to_ref(&self) -> VecZnxDft<&[u8], B>;
|
||||
}
|
||||
|
||||
impl<D: DataRef, B: Backend> VecZnxDftToRef<B> for VecZnxDft<D, B> {
|
||||
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
|
||||
VecZnxDft {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
max_size: self.max_size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxDftToMut<B: Backend> {
|
||||
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B>;
|
||||
}
|
||||
|
||||
impl<D: DataMut, B: Backend> VecZnxDftToMut<B> for VecZnxDft<D, B> {
|
||||
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
|
||||
VecZnxDft {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
max_size: self.max_size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
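// Illustrative usage sketch (assumed paths): `into_big` reinterprets a DFT-domain
// vector as a big-coefficient vector by moving the same backing buffer; no copy occurs.
use backend::hal::layouts::{Backend, Data, VecZnxBig, VecZnxDft};

fn reinterpret_as_big<D: Data, B: Backend>(dft: VecZnxDft<D, B>) -> VecZnxBig<D, B> {
    // n, cols and size carry over; only the type-level domain changes.
    dft.into_big()
}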
157
backend/src/hal/layouts/vmp_pmat.rs
Normal file
@@ -0,0 +1,157 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
hal::{
|
||||
api::{DataView, DataViewMut, ZnxInfos},
|
||||
layouts::{Backend, Data, DataMut, DataRef},
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct VmpPMat<D: Data, B: Backend> {
|
||||
data: D,
|
||||
n: usize,
|
||||
size: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
_phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> ZnxInfos for VmpPMat<D, B> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols_in
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
self.rows
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> DataView for VmpPMat<D, B> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> DataViewMut for VmpPMat<D, B> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> VmpPMat<D, B> {
|
||||
pub fn cols_in(&self) -> usize {
|
||||
self.cols_in
|
||||
}
|
||||
|
||||
pub fn cols_out(&self) -> usize {
|
||||
self.cols_out
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VmpPMatBytesOf {
|
||||
fn vmp_pmat_bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
impl<D: DataRef + From<Vec<u8>>, B: Backend> VmpPMat<D, B>
|
||||
where
|
||||
B: VmpPMatBytesOf,
|
||||
{
|
||||
pub(crate) fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
|
||||
let data: Vec<u8> = alloc_aligned(B::vmp_pmat_bytes_of(n, rows, cols_in, cols_out, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
size,
|
||||
rows,
|
||||
cols_in,
|
||||
cols_out,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn from_bytes(
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: impl Into<Vec<u8>>,
|
||||
) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == B::vmp_pmat_bytes_of(n, rows, cols_in, cols_out, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
size,
|
||||
rows,
|
||||
cols_in,
|
||||
cols_out,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type VmpPMatOwned<B> = VmpPMat<Vec<u8>, B>;
|
||||
pub type VmpPMatRef<'a, B> = VmpPMat<&'a [u8], B>;
|
||||
|
||||
pub trait VmpPMatToRef<B: Backend> {
|
||||
fn to_ref(&self) -> VmpPMat<&[u8], B>;
|
||||
}
|
||||
|
||||
impl<D: DataRef, B: Backend> VmpPMatToRef<B> for VmpPMat<D, B> {
|
||||
fn to_ref(&self) -> VmpPMat<&[u8], B> {
|
||||
VmpPMat {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
rows: self.rows,
|
||||
cols_in: self.cols_in,
|
||||
cols_out: self.cols_out,
|
||||
size: self.size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VmpPMatToMut<B: Backend> {
|
||||
fn to_mut(&mut self) -> VmpPMat<&mut [u8], B>;
|
||||
}
|
||||
|
||||
impl<D: DataMut, B: Backend> VmpPMatToMut<B> for VmpPMat<D, B> {
|
||||
fn to_mut(&mut self) -> VmpPMat<&mut [u8], B> {
|
||||
VmpPMat {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
rows: self.rows,
|
||||
cols_in: self.cols_in,
|
||||
cols_out: self.cols_out,
|
||||
size: self.size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> VmpPMat<D, B> {
|
||||
pub(crate) fn from_data(data: D, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
rows,
|
||||
cols_in,
|
||||
cols_out,
|
||||
size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
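// Illustrative usage sketch (assumed paths): the logical shape of a VmpPMat is
// n x rows x cols_in x cols_out x size; the actual byte layout is backend-defined
// through VmpPMatBytesOf.
use backend::hal::api::ZnxInfos;
use backend::hal::layouts::{Backend, Data, VmpPMat};

fn vmp_pmat_logical_elems<D: Data, B: Backend>(m: &VmpPMat<D, B>) -> usize {
    m.n() * m.rows() * m.cols_in() * m.cols_out() * m.size()
}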
5
backend/src/hal/mod.rs
Normal file
@@ -0,0 +1,5 @@
pub mod api;
pub mod delegates;
pub mod layouts;
pub mod oep;
pub mod tests;
20
backend/src/hal/oep/mat_znx.rs
Normal file
@@ -0,0 +1,20 @@
use crate::hal::layouts::{Backend, MatZnxOwned, Module};

pub unsafe trait MatZnxAllocImpl<B: Backend> {
    fn mat_znx_alloc_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned;
}

pub unsafe trait MatZnxAllocBytesImpl<B: Backend> {
    fn mat_znx_alloc_bytes_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
}

pub unsafe trait MatZnxFromBytesImpl<B: Backend> {
    fn mat_znx_from_bytes_impl(
        module: &Module<B>,
        rows: usize,
        cols_in: usize,
        cols_out: usize,
        size: usize,
        bytes: Vec<u8>,
    ) -> MatZnxOwned;
}
19
backend/src/hal/oep/mod.rs
Normal file
@@ -0,0 +1,19 @@
mod mat_znx;
mod module;
mod scalar_znx;
mod scratch;
mod svp_ppol;
mod vec_znx;
mod vec_znx_big;
mod vec_znx_dft;
mod vmp_pmat;

pub use mat_znx::*;
pub use module::*;
pub use scalar_znx::*;
pub use scratch::*;
pub use svp_ppol::*;
pub use vec_znx::*;
pub use vec_znx_big::*;
pub use vec_znx_dft::*;
pub use vmp_pmat::*;
5
backend/src/hal/oep/module.rs
Normal file
@@ -0,0 +1,5 @@
use crate::hal::layouts::{Backend, Module};

pub unsafe trait ModuleNewImpl<B: Backend> {
    fn new_impl(n: u64) -> Module<B>;
}
39
backend/src/hal/oep/scalar_znx.rs
Normal file
@@ -0,0 +1,39 @@
use crate::hal::layouts::{Backend, Module, ScalarZnxOwned, ScalarZnxToMut, ScalarZnxToRef};

pub unsafe trait ScalarZnxFromBytesImpl<B: Backend> {
    fn scalar_znx_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned;
}

pub unsafe trait ScalarZnxAllocBytesImpl<B: Backend> {
    fn scalar_znx_alloc_bytes_impl(n: usize, cols: usize) -> usize;
}

pub unsafe trait ScalarZnxAllocImpl<B: Backend> {
    fn scalar_znx_alloc_impl(n: usize, cols: usize) -> ScalarZnxOwned;
}

pub unsafe trait ScalarZnxAutomorphismImpl<B: Backend> {
    fn scalar_znx_automorphism_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: ScalarZnxToMut,
        A: ScalarZnxToRef;
}

pub unsafe trait ScalarZnxAutomorphismInplaceIml<B: Backend> {
    fn scalar_znx_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
    where
        A: ScalarZnxToMut;
}

pub unsafe trait ScalarZnxMulXpMinusOneImpl<B: Backend> {
    fn scalar_znx_mul_xp_minus_one_impl<R, A>(module: &Module<B>, p: i64, r: &mut R, r_col: usize, a: &A, a_col: usize)
    where
        R: ScalarZnxToMut,
        A: ScalarZnxToRef;
}

pub unsafe trait ScalarZnxMulXpMinusOneInplaceImpl<B: Backend> {
    fn scalar_znx_mul_xp_minus_one_inplace_impl<R>(module: &Module<B>, p: i64, r: &mut R, r_col: usize)
    where
        R: ScalarZnxToMut;
}
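// Illustrative reference sketch (not the backend code): the mapping the
// ScalarZnxAutomorphism* extension points are expected to realize is the ring
// automorphism X -> X^k on Z[X]/(X^n + 1), written here naively on a raw
// coefficient slice.
fn automorphism_naive(k: i64, a: &[i64]) -> Vec<i64> {
    let n = a.len();
    let mut res = vec![0i64; n];
    for (i, &c) in a.iter().enumerate() {
        // X^i maps to X^(i*k); reduce the exponent modulo 2n, using X^n = -1.
        let e = (i as i64 * k).rem_euclid(2 * n as i64) as usize;
        if e < n {
            res[e] += c;
        } else {
            res[e - n] -= c;
        }
    }
    res
}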
199
backend/src/hal/oep/scratch.rs
Normal file
@@ -0,0 +1,199 @@
|
||||
use crate::hal::{
|
||||
api::ZnxInfos,
|
||||
layouts::{Backend, DataRef, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
|
||||
};
|
||||
|
||||
pub unsafe trait ScratchOwnedAllocImpl<B: Backend> {
|
||||
fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait ScratchOwnedBorrowImpl<B: Backend> {
|
||||
fn scratch_owned_borrow_impl(scratch: &mut ScratchOwned<B>) -> &mut Scratch<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait ScratchFromBytesImpl<B: Backend> {
|
||||
fn scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait ScratchAvailableImpl<B: Backend> {
|
||||
fn scratch_available_impl(scratch: &Scratch<B>) -> usize;
|
||||
}
|
||||
|
||||
pub unsafe trait TakeSliceImpl<B: Backend> {
|
||||
fn take_slice_impl<T>(scratch: &mut Scratch<B>, len: usize) -> (&mut [T], &mut Scratch<B>);
|
||||
}
|
||||
|
||||
pub unsafe trait TakeScalarZnxImpl<B: Backend> {
|
||||
fn take_scalar_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch<B>);
|
||||
}
|
||||
|
||||
pub unsafe trait TakeSvpPPolImpl<B: Backend> {
|
||||
fn take_svp_ppol_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch<B>);
|
||||
}
|
||||
|
||||
pub unsafe trait TakeVecZnxImpl<B: Backend> {
|
||||
fn take_vec_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch<B>);
|
||||
}
|
||||
|
||||
pub unsafe trait TakeVecZnxSliceImpl<B: Backend> {
|
||||
fn take_vec_znx_slice_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
len: usize,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnx<&mut [u8]>>, &mut Scratch<B>);
|
||||
}
|
||||
|
||||
pub unsafe trait TakeVecZnxBigImpl<B: Backend> {
|
||||
fn take_vec_znx_big_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (VecZnxBig<&mut [u8], B>, &mut Scratch<B>);
|
||||
}
|
||||
|
||||
pub unsafe trait TakeVecZnxDftImpl<B: Backend> {
|
||||
fn take_vec_znx_dft_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (VecZnxDft<&mut [u8], B>, &mut Scratch<B>);
|
||||
}
|
||||
|
||||
pub unsafe trait TakeVecZnxDftSliceImpl<B: Backend> {
|
||||
fn take_vec_znx_dft_slice_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
len: usize,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Scratch<B>);
|
||||
}
|
||||
|
||||
pub unsafe trait TakeVmpPMatImpl<B: Backend> {
|
||||
fn take_vmp_pmat_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (VmpPMat<&mut [u8], B>, &mut Scratch<B>);
|
||||
}
|
||||
|
||||
pub unsafe trait TakeMatZnxImpl<B: Backend> {
|
||||
fn take_mat_znx_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (MatZnx<&mut [u8]>, &mut Scratch<B>);
|
||||
}
|
||||
|
||||
pub trait TakeLikeImpl<'a, B: Backend, T> {
|
||||
type Output;
|
||||
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &T) -> (Self::Output, &'a mut Scratch<B>);
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VmpPMat<D, B>> for B
|
||||
where
|
||||
B: TakeVmpPMatImpl<B>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = VmpPMat<&'a mut [u8], B>;
|
||||
|
||||
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &VmpPMat<D, B>) -> (Self::Output, &'a mut Scratch<B>) {
|
||||
B::take_vmp_pmat_impl(
|
||||
scratch,
|
||||
template.n(),
|
||||
template.rows(),
|
||||
template.cols_in(),
|
||||
template.cols_out(),
|
||||
template.size(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, MatZnx<D>> for B
|
||||
where
|
||||
B: TakeMatZnxImpl<B>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = MatZnx<&'a mut [u8]>;
|
||||
|
||||
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &MatZnx<D>) -> (Self::Output, &'a mut Scratch<B>) {
|
||||
B::take_mat_znx_impl(
|
||||
scratch,
|
||||
template.n(),
|
||||
template.rows(),
|
||||
template.cols_in(),
|
||||
template.cols_out(),
|
||||
template.size(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VecZnxDft<D, B>> for B
|
||||
where
|
||||
B: TakeVecZnxDftImpl<B>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = VecZnxDft<&'a mut [u8], B>;
|
||||
|
||||
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &VecZnxDft<D, B>) -> (Self::Output, &'a mut Scratch<B>) {
|
||||
B::take_vec_znx_dft_impl(scratch, template.n(), template.cols(), template.size())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VecZnxBig<D, B>> for B
|
||||
where
|
||||
B: TakeVecZnxBigImpl<B>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = VecZnxBig<&'a mut [u8], B>;
|
||||
|
||||
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &VecZnxBig<D, B>) -> (Self::Output, &'a mut Scratch<B>) {
|
||||
B::take_vec_znx_big_impl(scratch, template.n(), template.cols(), template.size())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, SvpPPol<D, B>> for B
|
||||
where
|
||||
B: TakeSvpPPolImpl<B>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = SvpPPol<&'a mut [u8], B>;
|
||||
|
||||
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &SvpPPol<D, B>) -> (Self::Output, &'a mut Scratch<B>) {
|
||||
B::take_svp_ppol_impl(scratch, template.n(), template.cols())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VecZnx<D>> for B
|
||||
where
|
||||
B: TakeVecZnxImpl<B>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = VecZnx<&'a mut [u8]>;
|
||||
|
||||
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &VecZnx<D>) -> (Self::Output, &'a mut Scratch<B>) {
|
||||
B::take_vec_znx_impl(scratch, template.n(), template.cols(), template.size())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, ScalarZnx<D>> for B
|
||||
where
|
||||
B: TakeScalarZnxImpl<B>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = ScalarZnx<&'a mut [u8]>;
|
||||
|
||||
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &ScalarZnx<D>) -> (Self::Output, &'a mut Scratch<B>) {
|
||||
B::take_scalar_znx_impl(scratch, template.n(), template.cols())
|
||||
}
|
||||
}
|
||||
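// Illustrative usage sketch (assumed paths): TakeLikeImpl lets backend-generic code
// carve a temporary container out of scratch with the same shape as a template.
use backend::hal::layouts::{Backend, DataRef, Scratch, VecZnx};
use backend::hal::oep::TakeLikeImpl;

fn take_tmp_like<'a, B, D>(
    scratch: &'a mut Scratch<B>,
    template: &VecZnx<D>,
) -> (VecZnx<&'a mut [u8]>, &'a mut Scratch<B>)
where
    B: Backend + TakeLikeImpl<'a, B, VecZnx<D>, Output = VecZnx<&'a mut [u8]>>,
    D: DataRef,
{
    B::take_like_impl(scratch, template)
}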
37
backend/src/hal/oep/svp_ppol.rs
Normal file
@@ -0,0 +1,37 @@
use crate::hal::layouts::{
    Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef,
};

pub unsafe trait SvpPPolFromBytesImpl<B: Backend> {
    fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B>;
}

pub unsafe trait SvpPPolAllocImpl<B: Backend> {
    fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned<B>;
}

pub unsafe trait SvpPPolAllocBytesImpl<B: Backend> {
    fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize;
}

pub unsafe trait SvpPrepareImpl<B: Backend> {
    fn svp_prepare_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: SvpPPolToMut<B>,
        A: ScalarZnxToRef;
}

pub unsafe trait SvpApplyImpl<B: Backend> {
    fn svp_apply_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
    where
        R: VecZnxDftToMut<B>,
        A: SvpPPolToRef<B>,
        C: VecZnxDftToRef<B>;
}

pub unsafe trait SvpApplyInplaceImpl: Backend {
    fn svp_apply_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxDftToMut<Self>,
        A: SvpPPolToRef<Self>;
}
465
backend/src/hal/oep/vec_znx.rs
Normal file
@@ -0,0 +1,465 @@
|
||||
use rand_distr::Distribution;
|
||||
use rug::Float;
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::hal::layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxOwned, VecZnxToMut, VecZnxToRef};
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::hal::layouts::VecZnx::new] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxAlloc] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
/// * See test \[TODO\]
|
||||
pub unsafe trait VecZnxAllocImpl<B: Backend> {
|
||||
fn vec_znx_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxOwned;
|
||||
}
|
||||
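// Illustrative sketch (hypothetical backend type, kept as a comment): one way to
// satisfy this extension point is to delegate to the reference code in VecZnx::new,
// as the doc comment above suggests.
//
// pub struct MyBackend; // hypothetical; a real backend must also implement `Backend`
//
// unsafe impl VecZnxAllocImpl<MyBackend> for MyBackend {
//     fn vec_znx_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxOwned {
//         VecZnxOwned::new::<i64>(n, cols, size)
//     }
// }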
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::hal::layouts::VecZnx::from_bytes] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxFromBytes] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxFromBytesImpl<B: Backend> {
|
||||
fn vec_znx_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::hal::layouts::VecZnx::alloc_bytes] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxAllocBytes] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxAllocBytesImpl<B: Backend> {
|
||||
fn vec_znx_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_normalize_base2k_tmp_bytes_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L245C17-L245C55) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxNormalizeTmpBytes] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxNormalizeTmpBytesImpl<B: Backend> {
|
||||
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<B>, n: usize) -> usize;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L212) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxNormalize] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxNormalizeImpl<B: Backend> {
|
||||
fn vec_znx_normalize_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L212) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxNormalizeInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxNormalizeInplaceImpl<B: Backend> {
|
||||
fn vec_znx_normalize_inplace_impl<A>(module: &Module<B>, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxAdd] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxAddImpl<B: Backend> {
|
||||
fn vec_znx_add_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
C: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxAddInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxAddInplaceImpl<B: Backend> {
|
||||
fn vec_znx_add_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxAddScalarInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxAddScalarInplaceImpl<B: Backend> {
|
||||
fn vec_znx_add_scalar_inplace_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
res_limb: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxSub] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxSubImpl<B: Backend> {
|
||||
fn vec_znx_sub_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
C: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxSubABInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxSubABInplaceImpl<B: Backend> {
|
||||
fn vec_znx_sub_ab_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxSubBAInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxSubBAInplaceImpl<B: Backend> {
|
||||
fn vec_znx_sub_ba_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxSubScalarInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxSubScalarInplaceImpl<B: Backend> {
|
||||
fn vec_znx_sub_scalar_inplace_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
res_limb: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_negate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L322C13-L322C31) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxNegate] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxNegateImpl<B: Backend> {
|
||||
fn vec_znx_negate_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_negate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L322C13-L322C31) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxNegateInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxNegateInplaceImpl<B: Backend> {
|
||||
fn vec_znx_negate_inplace_impl<A>(module: &Module<B>, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_rsh_inplace_ref] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxRshInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxRshInplaceImpl<B: Backend> {
|
||||
fn vec_znx_rsh_inplace_impl<A>(module: &Module<B>, basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_lsh_inplace_ref] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxLshInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxLshInplaceImpl<B: Backend> {
|
||||
fn vec_znx_lsh_inplace_impl<A>(module: &Module<B>, basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_rotate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L164) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxRotate] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxRotateImpl<B: Backend> {
|
||||
fn vec_znx_rotate_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_rotate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L164) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxRotateInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxRotateInplaceImpl<B: Backend> {
|
||||
fn vec_znx_rotate_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_automorphism_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L188) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxAutomorphism] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxAutomorphismImpl<B: Backend> {
|
||||
fn vec_znx_automorphism_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_automorphism_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L188) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxAutomorphismInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxAutomorphismInplaceImpl<B: Backend> {
|
||||
fn vec_znx_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_mul_xp_minus_one_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/7160f588da49712a042931ea247b4259b95cefcc/spqlios/arithmetic/vec_znx.c#L200C13-L200C41) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxMulXpMinusOne] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxMulXpMinusOneImpl<B: Backend> {
|
||||
fn vec_znx_mul_xp_minus_one_impl<R, A>(module: &Module<B>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_mul_xp_minus_one_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/7160f588da49712a042931ea247b4259b95cefcc/spqlios/arithmetic/vec_znx.c#L200C13-L200C41) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxMulXpMinusOneInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxMulXpMinusOneInplaceImpl<B: Backend> {
|
||||
fn vec_znx_mul_xp_minus_one_inplace_impl<R>(module: &Module<B>, p: i64, res: &mut R, res_col: usize)
|
||||
where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_split_ref] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxSplit] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxSplitImpl<B: Backend> {
|
||||
fn vec_znx_split_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
res: &mut Vec<R>,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_merge_ref] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxMerge] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxMergeImpl<B: Backend> {
|
||||
fn vec_znx_merge_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: Vec<A>, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_switch_degree_ref] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxSwithcDegree] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxSwithcDegreeImpl<B: Backend> {
|
||||
fn vec_znx_switch_degree_impl<R: VecZnxToMut, A: VecZnxToRef>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
);
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_copy_ref] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxCopy] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxCopyImpl<B: Backend> {
|
||||
fn vec_znx_copy_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::hal::api::VecZnxStd] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxStdImpl<B: Backend> {
|
||||
fn vec_znx_std_impl<A>(module: &Module<B>, basek: usize, a: &A, a_col: usize) -> f64
|
||||
where
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::hal::api::VecZnxFillUniform] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxFillUniformImpl<B: Backend> {
|
||||
fn vec_znx_fill_uniform_impl<R>(module: &Module<B>, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
|
||||
where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::hal::api::VecZnxFillDistF64] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxFillDistF64Impl<B: Backend> {
|
||||
fn vec_znx_fill_dist_f64_impl<R, D: Distribution<f64>>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::hal::api::VecZnxAddDistF64] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxAddDistF64Impl<B: Backend> {
|
||||
fn vec_znx_add_dist_f64_impl<R, D: Distribution<f64>>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::hal::api::VecZnxFillNormal] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxFillNormalImpl<B: Backend> {
|
||||
fn vec_znx_fill_normal_impl<R>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::hal::api::VecZnxAddNormal] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxAddNormalImpl<B: Backend> {
|
||||
fn vec_znx_add_normal_impl<R>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See \[TODO\] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxEncodeVeci64] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxEncodeVeci64Impl<B: Backend> {
|
||||
fn encode_vec_i64_impl<R>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
data: &[i64],
|
||||
log_max: usize,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See \[TODO\] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxEncodeCoeffsi64] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxEncodeCoeffsi64Impl<B: Backend> {
|
||||
fn encode_coeff_i64_impl<R>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
i: usize,
|
||||
data: i64,
|
||||
log_max: usize,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See \[TODO\] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxDecodeVeci64] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxDecodeVeci64Impl<B: Backend> {
|
||||
fn decode_vec_i64_impl<R>(module: &Module<B>, basek: usize, res: &R, res_col: usize, k: usize, data: &mut [i64])
|
||||
where
|
||||
R: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See \[TODO\] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxDecodeCoeffsi64] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxDecodeCoeffsi64Impl<B: Backend> {
|
||||
fn decode_coeff_i64_impl<R>(module: &Module<B>, basek: usize, res: &R, res_col: usize, k: usize, i: usize) -> i64
|
||||
where
|
||||
R: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See \[TODO\] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxDecodeVecFloat] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxDecodeVecFloatImpl<B: Backend> {
|
||||
fn decode_vec_float_impl<R>(module: &Module<B>, basek: usize, res: &R, res_col: usize, data: &mut [Float])
|
||||
where
|
||||
R: VecZnxToRef;
|
||||
}
|
||||
208
backend/src/hal/oep/vec_znx_big.rs
Normal file
@@ -0,0 +1,208 @@
|
||||
use rand_distr::Distribution;
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::hal::layouts::{Backend, Module, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef};
|
||||
|
||||
pub unsafe trait VecZnxBigAllocImpl<B: Backend> {
|
||||
fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigFromBytesImpl<B: Backend> {
|
||||
fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigAllocBytesImpl<B: Backend> {
|
||||
fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigAddNormalImpl<B: Backend> {
|
||||
fn add_normal_impl<R: VecZnxBigToMut<B>>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigFillNormalImpl<B: Backend> {
|
||||
fn fill_normal_impl<R: VecZnxBigToMut<B>>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigFillDistF64Impl<B: Backend> {
|
||||
fn fill_dist_f64_impl<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigAddDistF64Impl<B: Backend> {
|
||||
fn add_dist_f64_impl<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigAddImpl<B: Backend> {
|
||||
fn vec_znx_big_add_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
C: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigAddInplaceImpl<B: Backend> {
|
||||
fn vec_znx_big_add_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigAddSmallImpl<B: Backend> {
|
||||
fn vec_znx_big_add_small_impl<R, A, C>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &C,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
C: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigAddSmallInplaceImpl<B: Backend> {
|
||||
fn vec_znx_big_add_small_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigSubImpl<B: Backend> {
|
||||
fn vec_znx_big_sub_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
C: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigSubABInplaceImpl<B: Backend> {
|
||||
fn vec_znx_big_sub_ab_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigSubBAInplaceImpl<B: Backend> {
|
||||
fn vec_znx_big_sub_ba_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigSubSmallAImpl<B: Backend> {
|
||||
fn vec_znx_big_sub_small_a_impl<R, A, C>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &C,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef,
|
||||
C: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigSubSmallAInplaceImpl<B: Backend> {
|
||||
fn vec_znx_big_sub_small_a_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigSubSmallBImpl<B: Backend> {
|
||||
fn vec_znx_big_sub_small_b_impl<R, A, C>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &C,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
C: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigSubSmallBInplaceImpl<B: Backend> {
|
||||
fn vec_znx_big_sub_small_b_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigNegateInplaceImpl<B: Backend> {
|
||||
fn vec_znx_big_negate_inplace_impl<A>(module: &Module<B>, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigNormalizeTmpBytesImpl<B: Backend> {
|
||||
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<B>, n: usize) -> usize;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigNormalizeImpl<B: Backend> {
|
||||
fn vec_znx_big_normalize_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigAutomorphismImpl<B: Backend> {
|
||||
fn vec_znx_big_automorphism_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigAutomorphismInplaceImpl<B: Backend> {
|
||||
fn vec_znx_big_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<B>;
|
||||
}
|
||||
117
backend/src/hal/oep/vec_znx_dft.rs
Normal file
@@ -0,0 +1,117 @@
|
||||
use crate::hal::layouts::{
|
||||
Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef,
|
||||
VecZnxToRef,
|
||||
};
|
||||
|
||||
pub unsafe trait VecZnxDftAllocImpl<B: Backend> {
|
||||
fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftFromBytesImpl<B: Backend> {
|
||||
fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftAllocBytesImpl<B: Backend> {
|
||||
fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftToVecZnxBigTmpBytesImpl<B: Backend> {
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<B>) -> usize;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftToVecZnxBigImpl<B: Backend> {
|
||||
fn vec_znx_dft_to_vec_znx_big_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftToVecZnxBigTmpAImpl<B: Backend> {
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_a_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxDftToMut<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftToVecZnxBigConsumeImpl<B: Backend> {
|
||||
fn vec_znx_dft_to_vec_znx_big_consume_impl<D: Data>(module: &Module<B>, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
|
||||
where
|
||||
VecZnxDft<D, B>: VecZnxDftToMut<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftAddImpl<B: Backend> {
|
||||
fn vec_znx_dft_add_impl<R, A, D>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
D: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftAddInplaceImpl<B: Backend> {
|
||||
fn vec_znx_dft_add_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftSubImpl<B: Backend> {
|
||||
fn vec_znx_dft_sub_impl<R, A, D>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
D: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftSubABInplaceImpl<B: Backend> {
|
||||
fn vec_znx_dft_sub_ab_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftSubBAInplaceImpl<B: Backend> {
|
||||
fn vec_znx_dft_sub_ba_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftCopyImpl<B: Backend> {
|
||||
fn vec_znx_dft_copy_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
step: usize,
|
||||
offset: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftFromVecZnxImpl<B: Backend> {
|
||||
fn vec_znx_dft_from_vec_znx_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
step: usize,
|
||||
offset: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftZeroImpl<B: Backend> {
|
||||
fn vec_znx_dft_zero_impl<R>(module: &Module<B>, res: &mut R)
|
||||
where
|
||||
R: VecZnxDftToMut<B>;
|
||||
}
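A minimal, self-contained sketch (names simplified and partly invented for illustration) of how backend-level `*Impl` traits like the ones above are presumably consumed: a method-level API trait on Module<B> gets a blanket implementation that forwards to the corresponding `*Impl` trait, so each backend only has to provide the `*Impl` side. The actual wiring in hal::api may differ; this only illustrates the delegation pattern.

use std::marker::PhantomData;

// Simplified stand-ins for the real Backend/Module types (illustrative only).
pub trait Backend: Sized {}

pub struct Module<B: Backend> {
    _marker: PhantomData<B>,
}

// Backend-level extension point, analogous to the `*Impl` traits above.
pub unsafe trait VecZnxDftZeroImpl<B: Backend> {
    fn vec_znx_dft_zero_impl(module: &Module<B>, res: &mut [f64]);
}

// User-facing API trait, analogous to what hal::api is assumed to expose.
pub trait VecZnxDftZero {
    fn vec_znx_dft_zero(&self, res: &mut [f64]);
}

// Blanket forwarding impl: any backend that provides the `*Impl` trait
// automatically gives Module<B> the method-level API.
impl<B: Backend + VecZnxDftZeroImpl<B>> VecZnxDftZero for Module<B> {
    fn vec_znx_dft_zero(&self, res: &mut [f64]) {
        B::vec_znx_dft_zero_impl(self, res);
    }
}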
|
||||
74
backend/src/hal/oep/vmp_pmat.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
use crate::hal::layouts::{
|
||||
Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef,
|
||||
};
|
||||
|
||||
pub unsafe trait VmpPMatAllocImpl<B: Backend> {
|
||||
fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VmpPMatAllocBytesImpl<B: Backend> {
|
||||
fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub unsafe trait VmpPMatFromBytesImpl<B: Backend> {
|
||||
fn vmp_pmat_from_bytes_impl(
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: Vec<u8>,
|
||||
) -> VmpPMatOwned<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VmpPrepareTmpBytesImpl<B: Backend> {
|
||||
fn vmp_prepare_tmp_bytes_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub unsafe trait VmpPMatPrepareImpl<B: Backend> {
|
||||
fn vmp_prepare_impl<R, A>(module: &Module<B>, res: &mut R, a: &A, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VmpPMatToMut<B>,
|
||||
A: MatZnxToRef;
|
||||
}
|
||||
|
||||
pub unsafe trait VmpApplyTmpBytesImpl<B: Backend> {
|
||||
fn vmp_apply_tmp_bytes_impl(
|
||||
module: &Module<B>,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize;
|
||||
}
|
||||
|
||||
pub unsafe trait VmpApplyImpl<B: Backend> {
|
||||
fn vmp_apply_impl<R, A, C>(module: &Module<B>, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
C: VmpPMatToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VmpApplyAddTmpBytesImpl<B: Backend> {
|
||||
fn vmp_apply_add_tmp_bytes_impl(
|
||||
module: &Module<B>,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize;
|
||||
}
|
||||
|
||||
pub unsafe trait VmpApplyAddImpl<B: Backend> {
|
||||
// Same as [MatZnxDftOps::vmp_apply] except the result is added onto R instead of overwriting R.
|
||||
fn vmp_apply_add_impl<R, A, C>(module: &Module<B>, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
C: VmpPMatToRef<B>;
|
||||
}
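To make the overwrite-vs-accumulate distinction noted in the comment above concrete, here is a tiny standalone sketch using plain i64 slices as stand-ins for the DFT-domain vector/matrix types. The real vmp_apply/vmp_apply_add operate on VecZnxDft and VmpPMat with backend-specific scaling; this only illustrates the res = ... versus res += ... semantics.

// Overwrite semantics: res <- weight * a
fn apply(res: &mut [i64], a: &[i64], weight: i64) {
    for (r, &x) in res.iter_mut().zip(a) {
        *r = weight * x;
    }
}

// Accumulate semantics: res <- res + weight * a
fn apply_add(res: &mut [i64], a: &[i64], weight: i64) {
    for (r, &x) in res.iter_mut().zip(a) {
        *r += weight * x;
    }
}

fn main() {
    let a = [1, 2, 3];
    let mut res = [10, 10, 10];
    apply(&mut res, &a, 2);
    assert_eq!(res, [2, 4, 6]); // previous contents are discarded
    apply_add(&mut res, &a, 2);
    assert_eq!(res, [4, 8, 12]); // previous contents are kept and accumulated onto
}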
|
||||
1
backend/src/hal/tests/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod vec_znx;
|
||||
120
backend/src/hal/tests/vec_znx/generics.rs
Normal file
@@ -0,0 +1,120 @@
|
||||
use itertools::izip;
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::hal::{
|
||||
api::{
|
||||
VecZnxAddNormal, VecZnxAlloc, VecZnxDecodeVeci64, VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxStd, ZnxInfos, ZnxView,
|
||||
ZnxViewMut,
|
||||
},
|
||||
layouts::{Backend, Module, VecZnx},
|
||||
};
|
||||
|
||||
pub fn test_vec_znx_fill_uniform<B: Backend>(module: &Module<B>)
|
||||
where
|
||||
Module<B>: VecZnxFillUniform + VecZnxStd + VecZnxAlloc,
|
||||
{
|
||||
let basek: usize = 17;
|
||||
let size: usize = 5;
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
let zero: Vec<i64> = vec![0; module.n()];
|
||||
let one_12_sqrt: f64 = 0.28867513459481287;
|
||||
(0..cols).for_each(|col_i| {
|
||||
let mut a: VecZnx<_> = module.vec_znx_alloc(cols, size);
|
||||
module.vec_znx_fill_uniform(basek, &mut a, col_i, size * basek, &mut source);
|
||||
(0..cols).for_each(|col_j| {
|
||||
if col_j != col_i {
|
||||
(0..size).for_each(|limb_i| {
|
||||
assert_eq!(a.at(col_j, limb_i), zero);
|
||||
})
|
||||
} else {
|
||||
let std: f64 = module.vec_znx_std(basek, &a, col_i);
|
||||
assert!(
|
||||
(std - one_12_sqrt).abs() < 0.01,
|
||||
"std={} ~!= {}",
|
||||
std,
|
||||
one_12_sqrt
|
||||
);
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
pub fn test_vec_znx_add_normal<B: Backend>(module: &Module<B>)
|
||||
where
|
||||
Module<B>: VecZnxAddNormal + VecZnxStd + VecZnxAlloc,
|
||||
{
|
||||
let basek: usize = 17;
|
||||
let k: usize = 2 * 17;
|
||||
let size: usize = 5;
|
||||
let sigma: f64 = 3.2;
|
||||
let bound: f64 = 6.0 * sigma;
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
let zero: Vec<i64> = vec![0; module.n()];
|
||||
let k_f64: f64 = (1u64 << k as u64) as f64;
|
||||
(0..cols).for_each(|col_i| {
|
||||
let mut a: VecZnx<_> = module.vec_znx_alloc(cols, size);
|
||||
module.vec_znx_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
|
||||
(0..cols).for_each(|col_j| {
|
||||
if col_j != col_i {
|
||||
(0..size).for_each(|limb_i| {
|
||||
assert_eq!(a.at(col_j, limb_i), zero);
|
||||
})
|
||||
} else {
|
||||
let std: f64 = module.vec_znx_std(basek, &a, col_i) * k_f64;
|
||||
assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma);
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
pub fn test_vec_znx_encode_vec_i64_lo_norm<B: Backend>(module: &Module<B>)
|
||||
where
|
||||
Module<B>: VecZnxEncodeVeci64 + VecZnxDecodeVeci64 + VecZnxAlloc,
|
||||
{
|
||||
let basek: usize = 17;
|
||||
let size: usize = 5;
|
||||
let k: usize = size * basek - 5;
|
||||
let mut a: VecZnx<_> = module.vec_znx_alloc(2, size);
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let raw: &mut [i64] = a.raw_mut();
|
||||
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
(0..a.cols()).for_each(|col_i| {
|
||||
let mut have: Vec<i64> = vec![i64::default(); module.n()];
|
||||
have.iter_mut()
|
||||
.for_each(|x| *x = (source.next_i64() << 56) >> 56);
|
||||
module.encode_vec_i64(basek, &mut a, col_i, k, &have, 10);
|
||||
let mut want: Vec<i64> = vec![i64::default(); module.n()];
|
||||
module.decode_vec_i64(basek, &a, col_i, k, &mut want);
|
||||
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
|
||||
});
|
||||
}
|
||||
|
||||
pub fn test_vec_znx_encode_vec_i64_hi_norm<B: Backend>(module: &Module<B>)
|
||||
where
|
||||
Module<B>: VecZnxEncodeVeci64 + VecZnxDecodeVeci64 + VecZnxAlloc,
|
||||
{
|
||||
let basek: usize = 17;
|
||||
let size: usize = 5;
|
||||
for k in [1, basek / 2, size * basek - 5] {
|
||||
let mut a: VecZnx<_> = module.vec_znx_alloc(2, size);
|
||||
let mut source = Source::new([0u8; 32]);
|
||||
let raw: &mut [i64] = a.raw_mut();
|
||||
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
(0..a.cols()).for_each(|col_i| {
|
||||
let mut have: Vec<i64> = vec![i64::default(); module.n()];
|
||||
have.iter_mut().for_each(|x| {
|
||||
if k < 64 {
|
||||
*x = source.next_u64n(1 << k, (1 << k) - 1) as i64;
|
||||
} else {
|
||||
*x = source.next_i64();
|
||||
}
|
||||
});
|
||||
module.encode_vec_i64(basek, &mut a, col_i, k, &have, 63);
|
||||
let mut want: Vec<i64> = vec![i64::default(); module.n()];
|
||||
module.decode_vec_i64(basek, &a, col_i, k, &mut want);
|
||||
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
|
||||
})
|
||||
}
|
||||
}
|
||||
2
backend/src/hal/tests/vec_znx/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
mod generics;
|
||||
pub use generics::*;
|
||||
@@ -1,48 +1,47 @@
|
||||
use crate::ffi::module::MODULE;
|
||||
use crate::ffi::vec_znx_dft::VEC_ZNX_DFT;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct svp_ppol_t {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub type SVP_PPOL = svp_ppol_t;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn bytes_of_svp_ppol(module: *const MODULE) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_svp_ppol(module: *const MODULE) -> *mut SVP_PPOL;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn delete_svp_ppol(res: *mut SVP_PPOL);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn svp_prepare(module: *const MODULE, ppol: *mut SVP_PPOL, pol: *const i64);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn svp_apply_dft(
|
||||
module: *const MODULE,
|
||||
res: *const VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
ppol: *const SVP_PPOL,
|
||||
a: *const i64,
|
||||
a_size: u64,
|
||||
a_sl: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn svp_apply_dft_to_dft(
|
||||
module: *const MODULE,
|
||||
res: *const VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
res_cols: u64,
|
||||
ppol: *const SVP_PPOL,
|
||||
a: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
a_cols: u64,
|
||||
);
|
||||
}
|
||||
use crate::implementation::cpu_spqlios::ffi::{module::MODULE, vec_znx_dft::VEC_ZNX_DFT};
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct svp_ppol_t {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub type SVP_PPOL = svp_ppol_t;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn bytes_of_svp_ppol(module: *const MODULE) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_svp_ppol(module: *const MODULE) -> *mut SVP_PPOL;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn delete_svp_ppol(res: *mut SVP_PPOL);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn svp_prepare(module: *const MODULE, ppol: *mut SVP_PPOL, pol: *const i64);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn svp_apply_dft(
|
||||
module: *const MODULE,
|
||||
res: *const VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
ppol: *const SVP_PPOL,
|
||||
a: *const i64,
|
||||
a_size: u64,
|
||||
a_sl: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn svp_apply_dft_to_dft(
|
||||
module: *const MODULE,
|
||||
res: *const VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
res_cols: u64,
|
||||
ppol: *const SVP_PPOL,
|
||||
a: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
a_cols: u64,
|
||||
);
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::ffi::module::MODULE;
|
||||
use crate::implementation::cpu_spqlios::ffi::module::MODULE;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_add(
|
||||
@@ -28,6 +28,19 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_mul_xp_minus_one(
|
||||
module: *const MODULE,
|
||||
p: i64,
|
||||
res: *mut i64,
|
||||
res_size: u64,
|
||||
res_sl: u64,
|
||||
a: *const i64,
|
||||
a_size: u64,
|
||||
a_sl: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_negate(
|
||||
module: *const MODULE,
|
||||
@@ -86,6 +99,7 @@ unsafe extern "C" {
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_normalize_base2k(
|
||||
module: *const MODULE,
|
||||
n: u64,
|
||||
base2k: u64,
|
||||
res: *mut i64,
|
||||
res_size: u64,
|
||||
@@ -97,5 +111,5 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
|
||||
pub unsafe fn vec_znx_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64;
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::ffi::module::MODULE;
|
||||
use crate::implementation::cpu_spqlios::ffi::module::MODULE;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
@@ -103,12 +103,13 @@ unsafe extern "C" {
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
|
||||
pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_normalize_base2k(
|
||||
module: *const MODULE,
|
||||
n: u64,
|
||||
log2_base2k: u64,
|
||||
res: *mut i64,
|
||||
res_size: u64,
|
||||
@@ -122,6 +123,7 @@ unsafe extern "C" {
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_range_normalize_base2k(
|
||||
module: *const MODULE,
|
||||
n: u64,
|
||||
log2_base2k: u64,
|
||||
res: *mut i64,
|
||||
res_size: u64,
|
||||
@@ -135,7 +137,7 @@ unsafe extern "C" {
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
|
||||
pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
@@ -1,86 +1,85 @@
|
||||
use crate::ffi::module::MODULE;
|
||||
use crate::ffi::vec_znx_big::VEC_ZNX_BIG;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct vec_znx_dft_t {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub type VEC_ZNX_DFT = vec_znx_dft_t;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn bytes_of_vec_znx_dft(module: *const MODULE, size: u64) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_vec_znx_dft(module: *const MODULE, size: u64) -> *mut VEC_ZNX_DFT;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn delete_vec_znx_dft(res: *mut VEC_ZNX_DFT);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_dft_zero(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_dft_add(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
b: *const VEC_ZNX_DFT,
|
||||
b_size: u64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_dft_sub(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
b: *const VEC_ZNX_DFT,
|
||||
b_size: u64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_dft(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64, a: *const i64, a_size: u64, a_sl: u64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_idft(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a_dft: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
tmp: *mut u8,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_idft_tmp_bytes(module: *const MODULE) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_idft_tmp_a(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a_dft: *mut VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_dft_automorphism(
|
||||
module: *const MODULE,
|
||||
d: i64,
|
||||
res_dft: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a_dft: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
tmp: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_dft_automorphism_tmp_bytes(module: *const MODULE) -> u64;
|
||||
}
|
||||
use crate::implementation::cpu_spqlios::ffi::{module::MODULE, vec_znx_big::VEC_ZNX_BIG};
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct vec_znx_dft_t {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub type VEC_ZNX_DFT = vec_znx_dft_t;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn bytes_of_vec_znx_dft(module: *const MODULE, size: u64) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_vec_znx_dft(module: *const MODULE, size: u64) -> *mut VEC_ZNX_DFT;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn delete_vec_znx_dft(res: *mut VEC_ZNX_DFT);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_dft_zero(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_dft_add(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
b: *const VEC_ZNX_DFT,
|
||||
b_size: u64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_dft_sub(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
b: *const VEC_ZNX_DFT,
|
||||
b_size: u64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_dft(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64, a: *const i64, a_size: u64, a_sl: u64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_idft(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a_dft: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
tmp: *mut u8,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_idft_tmp_bytes(module: *const MODULE) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_idft_tmp_a(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a_dft: *mut VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_dft_automorphism(
|
||||
module: *const MODULE,
|
||||
d: i64,
|
||||
res_dft: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a_dft: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
tmp: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_dft_automorphism_tmp_bytes(module: *const MODULE) -> u64;
|
||||
}
|
||||
@@ -1,167 +1,113 @@
|
||||
use crate::ffi::module::MODULE;
|
||||
use crate::ffi::vec_znx_big::VEC_ZNX_BIG;
|
||||
use crate::ffi::vec_znx_dft::VEC_ZNX_DFT;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct vmp_pmat_t {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
|
||||
// [rows][cols] = [#Decomposition][#Limbs]
|
||||
pub type VMP_PMAT = vmp_pmat_t;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn bytes_of_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> *mut VMP_PMAT;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn delete_vmp_pmat(res: *mut VMP_PMAT);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a: *const i64,
|
||||
a_size: u64,
|
||||
a_sl: u64,
|
||||
pmat: *const VMP_PMAT,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_add(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a: *const i64,
|
||||
a_size: u64,
|
||||
a_sl: u64,
|
||||
pmat: *const VMP_PMAT,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
pmat_scale: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_tmp_bytes(module: *const MODULE, res_size: u64, a_size: u64, nrows: u64, ncols: u64) -> u64;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_to_dft(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a_dft: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
pmat: *const VMP_PMAT,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_to_dft_add(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a_dft: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
pmat: *const VMP_PMAT,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
pmat_scale: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_to_dft_tmp_bytes(
|
||||
module: *const MODULE,
|
||||
res_size: u64,
|
||||
a_size: u64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
) -> u64;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_contiguous(
|
||||
module: *const MODULE,
|
||||
pmat: *mut VMP_PMAT,
|
||||
mat: *const i64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_dblptr(
|
||||
module: *const MODULE,
|
||||
pmat: *mut VMP_PMAT,
|
||||
mat: *const *const i64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_row(
|
||||
module: *const MODULE,
|
||||
pmat: *mut VMP_PMAT,
|
||||
row: *const i64,
|
||||
row_i: u64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_row_dft(
|
||||
module: *const MODULE,
|
||||
pmat: *mut VMP_PMAT,
|
||||
row: *const VEC_ZNX_DFT,
|
||||
row_i: u64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_extract_row_dft(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
pmat: *const VMP_PMAT,
|
||||
row_i: u64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_extract_row(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
pmat: *const VMP_PMAT,
|
||||
row_i: u64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
|
||||
}
|
||||
use crate::implementation::cpu_spqlios::ffi::{module::MODULE, vec_znx_dft::VEC_ZNX_DFT};
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct vmp_pmat_t {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
|
||||
// [rows][cols] = [#Decomposition][#Limbs]
|
||||
pub type VMP_PMAT = vmp_pmat_t;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn bytes_of_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> *mut VMP_PMAT;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn delete_vmp_pmat(res: *mut VMP_PMAT);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a: *const i64,
|
||||
a_size: u64,
|
||||
a_sl: u64,
|
||||
pmat: *const VMP_PMAT,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_add(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a: *const i64,
|
||||
a_size: u64,
|
||||
a_sl: u64,
|
||||
pmat: *const VMP_PMAT,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
pmat_scale: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_tmp_bytes(module: *const MODULE, res_size: u64, a_size: u64, nrows: u64, ncols: u64) -> u64;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_to_dft(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a_dft: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
pmat: *const VMP_PMAT,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_to_dft_add(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a_dft: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
pmat: *const VMP_PMAT,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
pmat_scale: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_to_dft_tmp_bytes(
|
||||
module: *const MODULE,
|
||||
res_size: u64,
|
||||
a_size: u64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
) -> u64;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_contiguous(
|
||||
module: *const MODULE,
|
||||
pmat: *mut VMP_PMAT,
|
||||
mat: *const i64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_contiguous_dft(module: *const MODULE, pmat: *mut VMP_PMAT, mat: *const f64, nrows: u64, ncols: u64);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
|
||||
}
|
||||
@@ -1,76 +1,79 @@
|
||||
use crate::ffi::module::MODULE;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_add_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_add_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_sub_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_sub_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_negate_i64_ref(nn: u64, res: *mut i64, a: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_negate_i64_avx(nn: u64, res: *mut i64, a: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_copy_i64_ref(nn: u64, res: *mut i64, a: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_zero_i64_ref(nn: u64, res: *mut i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_divide_by_m_ref(nn: u64, m: f64, res: *mut f64, a: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_divide_by_m_avx(nn: u64, m: f64, res: *mut f64, a: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_rotate_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_rotate_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_rotate_inplace_f64(nn: u64, p: i64, res: *mut f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_rotate_inplace_i64(nn: u64, p: i64, res: *mut i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_automorphism_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_automorphism_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_automorphism_inplace_f64(nn: u64, p: i64, res: *mut f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_automorphism_inplace_i64(nn: u64, p: i64, res: *mut i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_mul_xp_minus_one(nn: u64, p: i64, res: *mut f64, in_: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_mul_xp_minus_one(nn: u64, p: i64, res: *mut i64, in_: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_mul_xp_minus_one_inplace(nn: u64, p: i64, res: *mut f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_normalize(nn: u64, base_k: u64, out: *mut i64, carry_out: *mut i64, in_: *const i64, carry_in: *const i64);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_small_single_product(module: *const MODULE, res: *mut i64, a: *const i64, b: *const i64, tmp: *mut u8);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_small_single_product_tmp_bytes(module: *const MODULE) -> u64;
|
||||
}
|
||||
use crate::implementation::cpu_spqlios::ffi::module::MODULE;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_add_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_add_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_sub_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_sub_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_negate_i64_ref(nn: u64, res: *mut i64, a: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_negate_i64_avx(nn: u64, res: *mut i64, a: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_copy_i64_ref(nn: u64, res: *mut i64, a: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_zero_i64_ref(nn: u64, res: *mut i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_divide_by_m_ref(nn: u64, m: f64, res: *mut f64, a: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_divide_by_m_avx(nn: u64, m: f64, res: *mut f64, a: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_rotate_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_rotate_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_rotate_inplace_f64(nn: u64, p: i64, res: *mut f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_rotate_inplace_i64(nn: u64, p: i64, res: *mut i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_automorphism_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_automorphism_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_automorphism_inplace_f64(nn: u64, p: i64, res: *mut f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_automorphism_inplace_i64(nn: u64, p: i64, res: *mut i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_mul_xp_minus_one_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_mul_xp_minus_one_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_mul_xp_minus_one_inplace_f64(nn: u64, p: i64, res: *mut f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_mul_xp_minus_one_inplace_i64(nn: u64, p: i64, res: *mut i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_normalize(nn: u64, base_k: u64, out: *mut i64, carry_out: *mut i64, in_: *const i64, carry_in: *const i64);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_small_single_product(module: *const MODULE, res: *mut i64, a: *const i64, b: *const i64, tmp: *mut u8);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_small_single_product_tmp_bytes(module: *const MODULE) -> u64;
|
||||
}
|
||||
41
backend/src/implementation/cpu_spqlios/mat_znx.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
layouts::{Backend, MatZnxOwned, Module},
|
||||
oep::{MatZnxAllocBytesImpl, MatZnxAllocImpl, MatZnxFromBytesImpl},
|
||||
},
|
||||
implementation::cpu_spqlios::CPUAVX,
|
||||
};
|
||||
|
||||
unsafe impl<B: Backend> MatZnxAllocImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn mat_znx_alloc_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned {
|
||||
MatZnxOwned::new(module.n(), rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> MatZnxAllocBytesImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn mat_znx_alloc_bytes_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
MatZnxOwned::bytes_of(module.n(), rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> MatZnxFromBytesImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn mat_znx_from_bytes_impl(
|
||||
module: &Module<B>,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: Vec<u8>,
|
||||
) -> MatZnxOwned {
|
||||
MatZnxOwned::new_from_bytes(module.n(), rows, cols_in, cols_out, size, bytes)
|
||||
}
|
||||
}
|
||||
26
backend/src/implementation/cpu_spqlios/mod.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
mod ffi;
|
||||
mod mat_znx;
|
||||
mod module_fft64;
|
||||
mod module_ntt120;
|
||||
mod scalar_znx;
|
||||
mod scratch;
|
||||
mod svp_ppol_fft64;
|
||||
mod svp_ppol_ntt120;
|
||||
mod vec_znx;
|
||||
mod vec_znx_big_fft64;
|
||||
mod vec_znx_big_ntt120;
|
||||
mod vec_znx_dft_fft64;
|
||||
mod vec_znx_dft_ntt120;
|
||||
mod vmp_pmat_fft64;
|
||||
mod vmp_pmat_ntt120;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
pub use module_fft64::*;
|
||||
pub use module_ntt120::*;
|
||||
|
||||
/// Re-exported for external documentation.
|
||||
pub use vec_znx::{
    vec_znx_copy_ref, vec_znx_lsh_inplace_ref, vec_znx_merge_ref, vec_znx_rsh_inplace_ref, vec_znx_split_ref,
    vec_znx_switch_degree_ref,
};
|
||||
|
||||
pub trait CPUAVX {}
|
||||
29
backend/src/implementation/cpu_spqlios/module_fft64.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use std::ptr::NonNull;
|
||||
|
||||
use crate::{
|
||||
hal::{
|
||||
layouts::{Backend, Module},
|
||||
oep::ModuleNewImpl,
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
CPUAVX,
|
||||
ffi::module::{MODULE, delete_module_info, new_module_info},
|
||||
},
|
||||
};
|
||||
|
||||
pub struct FFT64;
|
||||
|
||||
impl CPUAVX for FFT64 {}
|
||||
|
||||
impl Backend for FFT64 {
|
||||
type Handle = MODULE;
|
||||
unsafe fn destroy(handle: NonNull<Self::Handle>) {
|
||||
unsafe { delete_module_info(handle.as_ptr()) }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ModuleNewImpl<Self> for FFT64 {
|
||||
fn new_impl(n: u64) -> Module<Self> {
|
||||
unsafe { Module::from_raw_parts(new_module_info(n, 0), n) }
|
||||
}
|
||||
}
|
||||
29
backend/src/implementation/cpu_spqlios/module_ntt120.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use std::ptr::NonNull;
|
||||
|
||||
use crate::{
|
||||
hal::{
|
||||
layouts::{Backend, Module},
|
||||
oep::ModuleNewImpl,
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
CPUAVX,
|
||||
ffi::module::{MODULE, delete_module_info, new_module_info},
|
||||
},
|
||||
};
|
||||
|
||||
pub struct NTT120;
|
||||
|
||||
impl CPUAVX for NTT120 {}
|
||||
|
||||
impl Backend for NTT120 {
|
||||
type Handle = MODULE;
|
||||
unsafe fn destroy(handle: NonNull<Self::Handle>) {
|
||||
unsafe { delete_module_info(handle.as_ptr()) }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ModuleNewImpl<Self> for NTT120 {
|
||||
fn new_impl(n: u64) -> Module<Self> {
|
||||
unsafe { Module::from_raw_parts(new_module_info(n, 1), n) }
|
||||
}
|
||||
}
|
||||
100
backend/src/implementation/cpu_spqlios/scalar_znx.rs
Normal file
@@ -0,0 +1,100 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut},
|
||||
layouts::{Backend, Module, ScalarZnx, ScalarZnxOwned, ScalarZnxToMut, ScalarZnxToRef},
|
||||
oep::{
|
||||
ScalarZnxAllocBytesImpl, ScalarZnxAllocImpl, ScalarZnxAutomorphismImpl, ScalarZnxAutomorphismInplaceIml,
|
||||
ScalarZnxFromBytesImpl,
|
||||
},
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
CPUAVX,
|
||||
ffi::{module::module_info_t, vec_znx},
|
||||
},
|
||||
};
|
||||
|
||||
unsafe impl<B: Backend> ScalarZnxAllocBytesImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn scalar_znx_alloc_bytes_impl(n: usize, cols: usize) -> usize {
|
||||
ScalarZnxOwned::bytes_of(n, cols)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScalarZnxAllocImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn scalar_znx_alloc_impl(n: usize, cols: usize) -> ScalarZnxOwned {
|
||||
ScalarZnxOwned::new(n, cols)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScalarZnxFromBytesImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn scalar_znx_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned {
|
||||
ScalarZnxOwned::new_from_bytes(n, cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScalarZnxAutomorphismImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn scalar_znx_automorphism_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: ScalarZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
let a: ScalarZnx<&[u8]> = a.to_ref();
|
||||
let mut res: ScalarZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
module.ptr() as *const module_info_t,
|
||||
k,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScalarZnxAutomorphismInplaceIml<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn scalar_znx_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: ScalarZnxToMut,
|
||||
{
|
||||
let mut a: ScalarZnx<&mut [u8]> = a.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
module.ptr() as *const module_info_t,
|
||||
k,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
274
backend/src/implementation/cpu_spqlios/scratch.rs
Normal file
@@ -0,0 +1,274 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
DEFAULTALIGN, alloc_aligned,
|
||||
hal::{
|
||||
api::ScratchFromBytes,
|
||||
layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
|
||||
oep::{
|
||||
ScalarZnxAllocBytesImpl, ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl,
|
||||
SvpPPolAllocBytesImpl, TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl,
|
||||
TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl,
|
||||
VecZnxAllocBytesImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl,
|
||||
},
|
||||
},
|
||||
implementation::cpu_spqlios::CPUAVX,
|
||||
};
|
||||
|
||||
unsafe impl<B: Backend> ScratchOwnedAllocImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned<B> {
|
||||
let data: Vec<u8> = alloc_aligned(size);
|
||||
ScratchOwned {
|
||||
data,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScratchOwnedBorrowImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn scratch_owned_borrow_impl(scratch: &mut ScratchOwned<B>) -> &mut Scratch<B> {
|
||||
Scratch::from_bytes(&mut scratch.data)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScratchFromBytesImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch<B> {
|
||||
unsafe { &mut *(data as *mut [u8] as *mut Scratch<B>) }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScratchAvailableImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn scratch_available_impl(scratch: &Scratch<B>) -> usize {
|
||||
let ptr: *const u8 = scratch.data.as_ptr();
|
||||
let self_len: usize = scratch.data.len();
|
||||
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
|
||||
self_len.saturating_sub(aligned_offset)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeSliceImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn take_slice_impl<T>(scratch: &mut Scratch<B>, len: usize) -> (&mut [T], &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, len * std::mem::size_of::<T>());
|
||||
|
||||
unsafe {
|
||||
(
|
||||
&mut *(std::ptr::slice_from_raw_parts_mut(take_slice.as_mut_ptr() as *mut T, len)),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeScalarZnxImpl<B> for B
|
||||
where
|
||||
B: CPUAVX + ScalarZnxAllocBytesImpl<B>,
|
||||
{
|
||||
fn take_scalar_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::scalar_znx_alloc_bytes_impl(n, cols));
|
||||
(
|
||||
ScalarZnx::from_data(take_slice, n, cols),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeSvpPPolImpl<B> for B
|
||||
where
|
||||
B: CPUAVX + SvpPPolAllocBytesImpl<B>,
|
||||
{
|
||||
fn take_svp_ppol_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::svp_ppol_alloc_bytes_impl(n, cols));
|
||||
(
|
||||
SvpPPol::from_data(take_slice, n, cols),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxImpl<B> for B
|
||||
where
|
||||
B: CPUAVX + VecZnxAllocBytesImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(
|
||||
&mut scratch.data,
|
||||
B::vec_znx_alloc_bytes_impl(n, cols, size),
|
||||
);
|
||||
(
|
||||
VecZnx::from_data(take_slice, n, cols, size),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxBigImpl<B> for B
|
||||
where
|
||||
B: CPUAVX + VecZnxBigAllocBytesImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_big_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (VecZnxBig<&mut [u8], B>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(
|
||||
&mut scratch.data,
|
||||
B::vec_znx_big_alloc_bytes_impl(n, cols, size),
|
||||
);
|
||||
(
|
||||
VecZnxBig::from_data(take_slice, n, cols, size),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxDftImpl<B> for B
|
||||
where
|
||||
B: CPUAVX + VecZnxDftAllocBytesImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_dft_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (VecZnxDft<&mut [u8], B>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(
|
||||
&mut scratch.data,
|
||||
B::vec_znx_dft_alloc_bytes_impl(n, cols, size),
|
||||
);
|
||||
|
||||
(
|
||||
VecZnxDft::from_data(take_slice, n, cols, size),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxDftSliceImpl<B> for B
|
||||
where
|
||||
B: CPUAVX + VecZnxDftAllocBytesImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_dft_slice_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
len: usize,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Scratch<B>) {
|
||||
let mut scratch: &mut Scratch<B> = scratch;
|
||||
let mut slice: Vec<VecZnxDft<&mut [u8], B>> = Vec::with_capacity(len);
|
||||
for _ in 0..len {
|
||||
let (znx, new_scratch) = B::take_vec_znx_dft_impl(scratch, n, cols, size);
|
||||
scratch = new_scratch;
|
||||
slice.push(znx);
|
||||
}
|
||||
(slice, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxSliceImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn take_vec_znx_slice_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
len: usize,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnx<&mut [u8]>>, &mut Scratch<B>) {
|
||||
let mut scratch: &mut Scratch<B> = scratch;
|
||||
let mut slice: Vec<VecZnx<&mut [u8]>> = Vec::with_capacity(len);
|
||||
for _ in 0..len {
|
||||
let (znx, new_scratch) = B::take_vec_znx_impl(scratch, n, cols, size);
|
||||
scratch = new_scratch;
|
||||
slice.push(znx);
|
||||
}
|
||||
(slice, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVmpPMatImpl<B> for B
|
||||
where
|
||||
B: CPUAVX + VmpPMatAllocBytesImpl<B>,
|
||||
{
|
||||
fn take_vmp_pmat_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (VmpPMat<&mut [u8], B>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(
|
||||
&mut scratch.data,
|
||||
B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size),
|
||||
);
|
||||
(
|
||||
VmpPMat::from_data(take_slice, n, rows, cols_in, cols_out, size),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeMatZnxImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn take_mat_znx_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (MatZnx<&mut [u8]>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(
|
||||
&mut scratch.data,
|
||||
MatZnx::<Vec<u8>>::bytes_of(n, rows, cols_in, cols_out, size),
|
||||
);
|
||||
(
|
||||
MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) {
|
||||
let ptr: *mut u8 = data.as_mut_ptr();
|
||||
let self_len: usize = data.len();
|
||||
|
||||
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
|
||||
let aligned_len: usize = self_len.saturating_sub(aligned_offset);
|
||||
|
||||
if let Some(rem_len) = aligned_len.checked_sub(take_len) {
|
||||
unsafe {
|
||||
let rem_ptr: *mut u8 = ptr.add(aligned_offset).add(take_len);
|
||||
let rem_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len);
|
||||
|
||||
let take_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len);
|
||||
|
||||
return (take_slice, rem_slice);
|
||||
}
|
||||
} else {
|
||||
panic!(
|
||||
"Attempted to take {} from scratch with {} aligned bytes left",
|
||||
take_len, aligned_len,
|
||||
);
|
||||
}
|
||||
}
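For reference, a standalone sketch of the align-then-carve pattern that take_slice_aligned implements: skip ahead to the next alignment boundary, hand out take_len bytes, and keep the remainder for subsequent take_* calls. The 64-byte alignment constant and the safe split_at_mut formulation are illustrative stand-ins, not the crate's actual DEFAULTALIGN or implementation.

const ALIGN: usize = 64; // illustrative stand-in for DEFAULTALIGN

// Carve `take_len` bytes out of `data`, starting at the next ALIGN boundary,
// and return (taken, remainder). Panics if not enough aligned bytes remain,
// mirroring the behaviour of take_slice_aligned above.
fn carve(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) {
    let offset = data.as_ptr().align_offset(ALIGN);
    let aligned = &mut data[offset..];
    assert!(
        take_len <= aligned.len(),
        "Attempted to take {} bytes from scratch with {} aligned bytes left",
        take_len,
        aligned.len()
    );
    aligned.split_at_mut(take_len)
}

fn main() {
    let mut buf = vec![0u8; 1024];
    let (first, rest) = carve(&mut buf, 128);
    assert_eq!(first.len(), 128);
    assert_eq!(first.as_ptr() as usize % ALIGN, 0); // taken slice starts on an ALIGN boundary
    assert!(rest.len() <= 1024 - 128); // remainder shrinks by the taken bytes plus any alignment skip
}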
|
||||
Submodule backend/src/implementation/cpu_spqlios/spqlios-arithmetic added at 7160f588da
114
backend/src/implementation/cpu_spqlios/svp_ppol_fft64.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut},
|
||||
layouts::{
|
||||
Data, DataRef, Module, ScalarZnxToRef, SvpPPol, SvpPPolBytesOf, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDft,
|
||||
VecZnxDftToMut, VecZnxDftToRef,
|
||||
},
|
||||
oep::{SvpApplyImpl, SvpApplyInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl},
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
ffi::{svp, vec_znx_dft::vec_znx_dft_t},
|
||||
module_fft64::FFT64,
|
||||
},
|
||||
};
|
||||
|
||||
const SVP_PPOL_FFT64_WORD_SIZE: usize = 1;
|
||||
|
||||
impl<D: Data> SvpPPolBytesOf for SvpPPol<D, FFT64> {
|
||||
fn bytes_of(n: usize, cols: usize) -> usize {
|
||||
SVP_PPOL_FFT64_WORD_SIZE * n * cols * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxSliceSize for SvpPPol<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
SVP_PPOL_FFT64_WORD_SIZE * self.n()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> ZnxView for SvpPPol<D, FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolFromBytesImpl<Self> for FFT64 {
|
||||
fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<Self> {
|
||||
SvpPPolOwned::from_bytes(n, cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolAllocImpl<Self> for FFT64 {
|
||||
fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned<Self> {
|
||||
SvpPPolOwned::alloc(n, cols)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolAllocBytesImpl<Self> for FFT64 {
|
||||
fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize {
|
||||
SvpPPol::<Vec<u8>, Self>::bytes_of(n, cols)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPrepareImpl<Self> for FFT64 {
|
||||
fn svp_prepare_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: SvpPPolToMut<Self>,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
unsafe {
|
||||
svp::svp_prepare(
|
||||
module.ptr(),
|
||||
res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t,
|
||||
a.to_ref().at_ptr(a_col, 0),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpApplyImpl<Self> for FFT64 {
|
||||
fn svp_apply_impl<R, A, B>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: SvpPPolToRef<Self>,
|
||||
B: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
|
||||
let a: SvpPPol<&[u8], Self> = a.to_ref();
|
||||
let b: VecZnxDft<&[u8], Self> = b.to_ref();
|
||||
unsafe {
|
||||
svp::svp_apply_dft_to_dft(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
|
||||
res.size() as u64,
|
||||
res.cols() as u64,
|
||||
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
|
||||
b.at_ptr(b_col, 0) as *const vec_znx_dft_t,
|
||||
b.size() as u64,
|
||||
b.cols() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpApplyInplaceImpl for FFT64 {
|
||||
fn svp_apply_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: SvpPPolToRef<Self>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
|
||||
let a: SvpPPol<&[u8], Self> = a.to_ref();
|
||||
unsafe {
|
||||
svp::svp_apply_dft_to_dft(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
|
||||
res.size() as u64,
|
||||
res.cols() as u64,
|
||||
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
|
||||
res.at_ptr(res_col, 0) as *const vec_znx_dft_t,
|
||||
res.size() as u64,
|
||||
res.cols() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
44
backend/src/implementation/cpu_spqlios/svp_ppol_ntt120.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{ZnxInfos, ZnxSliceSize, ZnxView},
|
||||
layouts::{Data, DataRef, SvpPPol, SvpPPolBytesOf, SvpPPolOwned},
|
||||
oep::{SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl},
|
||||
},
|
||||
implementation::cpu_spqlios::module_ntt120::NTT120,
|
||||
};
|
||||
|
||||
const SVP_PPOL_NTT120_WORD_SIZE: usize = 4;
|
||||
|
||||
impl<D: Data> SvpPPolBytesOf for SvpPPol<D, NTT120> {
|
||||
fn bytes_of(n: usize, cols: usize) -> usize {
|
||||
SVP_PPOL_NTT120_WORD_SIZE * n * cols * size_of::<i64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxSliceSize for SvpPPol<D, NTT120> {
|
||||
fn sl(&self) -> usize {
|
||||
SVP_PPOL_NTT120_WORD_SIZE * self.n()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> ZnxView for SvpPPol<D, NTT120> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolFromBytesImpl<Self> for NTT120 {
|
||||
fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<NTT120> {
|
||||
SvpPPolOwned::from_bytes(n, cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolAllocImpl<Self> for NTT120 {
|
||||
fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned<NTT120> {
|
||||
SvpPPolOwned::alloc(n, cols)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolAllocBytesImpl<Self> for NTT120 {
|
||||
fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize {
|
||||
SvpPPol::<Vec<u8>, Self>::bytes_of(n, cols)
|
||||
}
|
||||
}
|
||||
1
backend/src/implementation/cpu_spqlios/test/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
mod vec_znx_fft64;
|
||||
35
backend/src/implementation/cpu_spqlios/test/vec_znx_fft64.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::ModuleNew,
|
||||
layouts::Module,
|
||||
tests::vec_znx::{
|
||||
test_vec_znx_add_normal, test_vec_znx_encode_vec_i64_hi_norm, test_vec_znx_encode_vec_i64_lo_norm,
|
||||
test_vec_znx_fill_uniform,
|
||||
},
|
||||
},
|
||||
implementation::cpu_spqlios::FFT64,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_vec_znx_fill_uniform_fft64() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << 12);
|
||||
test_vec_znx_fill_uniform(&module);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vec_znx_add_normal_fft64() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << 12);
|
||||
test_vec_znx_add_normal(&module);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vec_znx_encode_vec_lo_norm_fft64() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << 8);
|
||||
test_vec_znx_encode_vec_i64_lo_norm(&module);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vec_znx_encode_vec_hi_norm_fft64() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << 8);
|
||||
test_vec_znx_encode_vec_i64_hi_norm(&module);
|
||||
}
|
||||
1344
backend/src/implementation/cpu_spqlios/vec_znx.rs
Normal file
File diff suppressed because it is too large
758
backend/src/implementation/cpu_spqlios/vec_znx_big_fft64.rs
Normal file
@@ -0,0 +1,758 @@
|
||||
use std::fmt;
|
||||
|
||||
use rand_distr::{Distribution, Normal};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{
|
||||
TakeSlice, VecZnxBigAddDistF64, VecZnxBigFillDistF64, VecZnxBigNormalizeTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView,
|
||||
ZnxViewMut,
|
||||
},
|
||||
layouts::{
|
||||
Data, DataRef, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigBytesOf, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef,
|
||||
VecZnxToMut, VecZnxToRef,
|
||||
},
|
||||
oep::{
|
||||
VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl,
|
||||
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
|
||||
VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl, VecZnxBigFromBytesImpl,
|
||||
VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl,
|
||||
VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl,
|
||||
VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
|
||||
},
|
||||
},
|
||||
implementation::cpu_spqlios::{ffi::vec_znx, module_fft64::FFT64},
|
||||
};
|
||||
|
||||
const VEC_ZNX_BIG_FFT64_WORDSIZE: usize = 1;
|
||||
|
||||
impl<D: DataRef> ZnxView for VecZnxBig<D, FFT64> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
impl<D: Data> VecZnxBigBytesOf for VecZnxBig<D, FFT64> {
|
||||
fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
|
||||
VEC_ZNX_BIG_FFT64_WORDSIZE * n * cols * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxSliceSize for VecZnxBig<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
VEC_ZNX_BIG_FFT64_WORDSIZE * self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAllocImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned<FFT64> {
|
||||
VecZnxBig::<Vec<u8>, FFT64>::new(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFromBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<FFT64> {
|
||||
VecZnxBig::<Vec<u8>, FFT64>::new_from_bytes(n, cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAllocBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
VecZnxBig::<Vec<u8>, FFT64>::bytes_of(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddDistF64Impl<FFT64> for FFT64 {
|
||||
fn add_dist_f64_impl<R: VecZnxBigToMut<FFT64>, D: Distribution<f64>>(
|
||||
_module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*x += (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*x += dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
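The limb index and left-shift used above can be checked by hand; a minimal sketch with arbitrary basek and k (not part of the crate):

fn main() {
    let (basek, k) = (16usize, 40usize);
    let limb = k.div_ceil(basek) - 1;       // limb receiving the noise term
    let basek_rem = (limb + 1) * basek - k; // shift applied when k is not a multiple of basek
    assert_eq!((limb, basek_rem), (2, 8));
}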
unsafe impl VecZnxBigAddNormalImpl<FFT64> for FFT64 {
|
||||
fn add_normal_impl<R: VecZnxBigToMut<FFT64>>(
|
||||
module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) {
|
||||
module.vec_znx_big_add_dist_f64(
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFillDistF64Impl<FFT64> for FFT64 {
|
||||
fn fill_dist_f64_impl<R: VecZnxBigToMut<FFT64>, D: Distribution<f64>>(
|
||||
_module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*x = (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*x = dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFillNormalImpl<FFT64> for FFT64 {
|
||||
fn fill_normal_impl<R: VecZnxBigToMut<FFT64>>(
|
||||
module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) {
|
||||
module.vec_znx_big_fill_dist_f64(
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddImpl<FFT64> for FFT64 {
|
||||
/// Adds `a` to `b` and stores the result in `res`.
|
||||
fn vec_znx_big_add_impl<R, A, B>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(b.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddInplaceImpl<FFT64> for FFT64 {
|
||||
/// Adds `a` to `res` and stores the result in `res`.
|
||||
fn vec_znx_big_add_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddSmallImpl<FFT64> for FFT64 {
|
||||
/// Adds the small vector `b` to the big vector `a` and stores the result in `res`.
|
||||
fn vec_znx_big_add_small_impl<R, A, B>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(b.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddSmallInplaceImpl<FFT64> for FFT64 {
|
||||
/// Adds the small vector `a` to `res` and stores the result in `res`.
|
||||
fn vec_znx_big_add_small_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubImpl<FFT64> for FFT64 {
|
||||
/// Subtracts `b` from `a` and stores the result in `res`.
|
||||
fn vec_znx_big_sub_impl<R, A, B>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(b.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubABInplaceImpl<FFT64> for FFT64 {
|
||||
/// Subtracts `a` from `res` and stores the result in `res`.
|
||||
fn vec_znx_big_sub_ab_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubBAInplaceImpl<FFT64> for FFT64 {
|
||||
/// Subtracts `res` from `a` and stores the result in `res`.
|
||||
fn vec_znx_big_sub_ba_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
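To keep the two in-place subtraction variants apart: reading the vec_znx_sub argument order above, sub_ab computes res <- res - a and sub_ba computes res <- a - res. A plain-integer sketch of those semantics (illustrative only):

fn main() {
    let a = 7i64;
    let mut res = 10i64;

    // vec_znx_big_sub_ab_inplace: res <- res - a
    res -= a;
    assert_eq!(res, 3);

    // vec_znx_big_sub_ba_inplace: res <- a - res
    res = a - res;
    assert_eq!(res, 4);
}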
unsafe impl VecZnxBigSubSmallAImpl<FFT64> for FFT64 {
|
||||
/// Subtracts the big vector `b` from the small vector `a` and stores the result in `res`.
|
||||
fn vec_znx_big_sub_small_a_impl<R, A, B>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(b.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallAInplaceImpl<FFT64> for FFT64 {
|
||||
/// Subtracts the small vector `a` from `res` and stores the result in `res`.
|
||||
fn vec_znx_big_sub_small_a_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallBImpl<FFT64> for FFT64 {
|
||||
/// Subtracts the small vector `b` from the big vector `a` and stores the result in `res`.
|
||||
fn vec_znx_big_sub_small_b_impl<R, A, B>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(b.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallBInplaceImpl<FFT64> for FFT64 {
|
||||
/// Subtracts `res` from the small vector `a` and stores the result in `res`.
|
||||
fn vec_znx_big_sub_small_b_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNegateInplaceImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_negate_inplace_impl<A>(module: &Module<FFT64>, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_negate(
|
||||
module.ptr(),
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNormalizeTmpBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<FFT64>, n: usize) -> usize {
|
||||
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr(), n as u64) as usize }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNormalizeImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_normalize_impl<R, A>(
|
||||
module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<FFT64>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), a.n());
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes(a.n()));
|
||||
unsafe {
|
||||
vec_znx::vec_znx_normalize_base2k(
|
||||
module.ptr(),
|
||||
a.n() as u64,
|
||||
basek as u64,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAutomorphismImpl<FFT64> for FFT64 {
|
||||
/// Applies the automorphism X^i -> X^(i*k) to `a` and stores the result in `res`.
|
||||
fn vec_znx_big_automorphism_impl<R, A>(module: &Module<FFT64>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
module.ptr(),
|
||||
k,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAutomorphismInplaceImpl<FFT64> for FFT64 {
|
||||
/// Applies the automorphism X^i -> X^(i*k) to `a` in place.
|
||||
fn vec_znx_big_automorphism_inplace_impl<A>(module: &Module<FFT64>, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
module.ptr(),
|
||||
k,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
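For reference, a hand-rolled coefficient-level model of the negacyclic automorphism X^i -> X^(i*k) on Z[X]/(X^N + 1) that the two impls above delegate to spqlios (k odd; a sketch, not the library routine):

fn automorphism(a: &[i64], k: usize) -> Vec<i64> {
    let n = a.len();
    let mut out = vec![0i64; n];
    for (i, &c) in a.iter().enumerate() {
        let j = (i * k) % (2 * n);
        if j < n {
            out[j] += c;
        } else {
            out[j - n] -= c; // X^N = -1, so wrapping past N flips the sign
        }
    }
    out
}

fn main() {
    // (1 + X) under X -> X^3 in Z[X]/(X^4 + 1) becomes 1 + X^3.
    assert_eq!(automorphism(&[1, 1, 0, 0], 3), vec![1, 0, 0, 1]);
}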
impl<D: DataRef> fmt::Display for VecZnxBig<D, FFT64> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"VecZnxBig(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {}:", col)?;
|
||||
for size in 0..self.size {
|
||||
let coeffs = self.at(col, size);
|
||||
write!(f, " Size {}: [", size)?;
|
||||
|
||||
let max_show = 100;
|
||||
let show_count = coeffs.len().min(max_show);
|
||||
|
||||
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "{}", coeff)?;
|
||||
}
|
||||
|
||||
if coeffs.len() > max_show {
|
||||
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
|
||||
}
|
||||
|
||||
writeln!(f, "]")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
32
backend/src/implementation/cpu_spqlios/vec_znx_big_ntt120.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{ZnxInfos, ZnxSliceSize, ZnxView},
|
||||
layouts::{Data, DataRef, VecZnxBig, VecZnxBigBytesOf},
|
||||
oep::VecZnxBigAllocBytesImpl,
|
||||
},
|
||||
implementation::cpu_spqlios::module_ntt120::NTT120,
|
||||
};
|
||||
|
||||
const VEC_ZNX_BIG_NTT120_WORDSIZE: usize = 4;
|
||||
|
||||
impl<D: DataRef> ZnxView for VecZnxBig<D, NTT120> {
|
||||
type Scalar = i128;
|
||||
}
|
||||
|
||||
impl<D: Data> VecZnxBigBytesOf for VecZnxBig<D, NTT120> {
|
||||
fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
|
||||
VEC_ZNX_BIG_NTT120_WORDSIZE * n * cols * size * size_of::<i64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxSliceSize for VecZnxBig<D, NTT120> {
|
||||
fn sl(&self) -> usize {
|
||||
VEC_ZNX_BIG_NTT120_WORDSIZE * self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAllocBytesImpl<NTT120> for NTT120 {
|
||||
fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
VecZnxBig::<Vec<u8>, NTT120>::bytes_of(n, cols, size)
|
||||
}
|
||||
}
|
||||
@@ -1,375 +1,90 @@
|
||||
use crate::ffi::{vec_znx_big, vec_znx_dft};
|
||||
use crate::vec_znx_dft::bytes_of_vec_znx_dft;
|
||||
use crate::znx_base::ZnxInfos;
|
||||
use std::fmt;
|
||||
|
||||
use crate::{
|
||||
Backend, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
|
||||
ZnxSliceSize,
|
||||
hal::{
|
||||
api::{TakeSlice, VecZnxDftToVecZnxBigTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
|
||||
layouts::{
|
||||
Data, DataRef, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftBytesOf, VecZnxDftOwned,
|
||||
VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
|
||||
},
|
||||
oep::{
|
||||
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftCopyImpl,
|
||||
VecZnxDftFromBytesImpl, VecZnxDftFromVecZnxImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl,
|
||||
VecZnxDftSubImpl, VecZnxDftToVecZnxBigConsumeImpl, VecZnxDftToVecZnxBigImpl, VecZnxDftToVecZnxBigTmpAImpl,
|
||||
VecZnxDftToVecZnxBigTmpBytesImpl, VecZnxDftZeroImpl,
|
||||
},
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
ffi::{vec_znx_big, vec_znx_dft},
|
||||
module_fft64::FFT64,
|
||||
},
|
||||
};
|
||||
use crate::{FFT64, Module, ZnxView, ZnxViewMut, ZnxZero};
|
||||
use std::cmp::min;
|
||||
|
||||
pub trait VecZnxDftAlloc<B: Backend> {
|
||||
/// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
|
||||
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned<B>;
|
||||
const VEC_ZNX_DFT_FFT64_WORDSIZE: usize = 1;
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||
///
|
||||
/// Behavior: takes ownership of the backing array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of cols of the [VecZnxDft].
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of cols of the [VecZnxDft].
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftOps<B: Backend> {
|
||||
/// Returns the minimum number of bytes necessary to allocate
|
||||
/// a new [VecZnxDft] through [VecZnxDft::from_bytes].
|
||||
fn vec_znx_idft_tmp_bytes(&self) -> usize;
|
||||
|
||||
fn vec_znx_dft_add<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
D: VecZnxDftToRef<B>;
|
||||
|
||||
fn vec_znx_dft_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
|
||||
fn vec_znx_dft_sub<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
D: VecZnxDftToRef<B>;
|
||||
|
||||
fn vec_znx_dft_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
|
||||
fn vec_znx_dft_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
|
||||
fn vec_znx_dft_copy<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
|
||||
/// b <- IDFT(a), uses a as scratch space.
|
||||
fn vec_znx_idft_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxDftToMut<B>;
|
||||
|
||||
/// Consumes a to return IDFT(a) in big coeff space.
|
||||
fn vec_znx_idft_consume<D>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, FFT64>
|
||||
where
|
||||
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>;
|
||||
|
||||
fn vec_znx_idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
|
||||
fn vec_znx_dft<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxDftAlloc<B> for Module<B> {
|
||||
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned<B> {
|
||||
VecZnxDftOwned::new(&self, cols, size)
|
||||
}
|
||||
|
||||
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B> {
|
||||
VecZnxDftOwned::new_from_bytes(self, cols, size, bytes)
|
||||
}
|
||||
|
||||
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize {
|
||||
bytes_of_vec_znx_dft(self, cols, size)
|
||||
impl<D: Data> ZnxSliceSize for VecZnxDft<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
VEC_ZNX_DFT_FFT64_WORDSIZE * self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
impl VecZnxDftOps<FFT64> for Module<FFT64> {
|
||||
fn vec_znx_dft_add<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
D: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_add(
|
||||
self.ptr,
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
impl<D: Data> VecZnxDftBytesOf for VecZnxDft<D, FFT64> {
|
||||
fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
|
||||
VEC_ZNX_DFT_FFT64_WORDSIZE * n * cols * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_dft_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
impl<D: DataRef> ZnxView for VecZnxDft<D, FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_add(
|
||||
self.ptr,
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
unsafe impl VecZnxDftFromBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<FFT64> {
|
||||
VecZnxDft::<Vec<u8>, FFT64>::from_bytes(n, cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_dft_sub<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
D: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
self.ptr,
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
unsafe impl VecZnxDftAllocBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
VecZnxDft::<Vec<u8>, FFT64>::bytes_of(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_dft_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
self.ptr,
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
unsafe impl VecZnxDftAllocImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<FFT64> {
|
||||
VecZnxDftOwned::alloc(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_dft_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
self.ptr,
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
unsafe impl VecZnxDftToVecZnxBigTmpBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<FFT64>) -> usize {
|
||||
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(module.ptr()) as usize }
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_dft_copy<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let steps: usize = (a_ref.size() + step - 1) / step;
|
||||
let min_steps: usize = min(res_mut.size(), steps);
|
||||
|
||||
(0..min_steps).for_each(|j| {
|
||||
let limb: usize = offset + j * step;
|
||||
if limb < a_ref.size() {
|
||||
res_mut
|
||||
.at_mut(res_col, j)
|
||||
.copy_from_slice(a_ref.at(a_col, limb));
|
||||
}
|
||||
});
|
||||
(min_steps..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
|
||||
fn vec_znx_idft_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
|
||||
|
||||
let min_size: usize = min(res_mut.size(), a_mut.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_znx_idft_tmp_a(
|
||||
self.ptr,
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
|
||||
1 as u64,
|
||||
a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1 as u64,
|
||||
)
|
||||
});
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_idft_consume<D>(&self, mut a: VecZnxDft<D, FFT64>) -> VecZnxBig<D, FFT64>
|
||||
where
|
||||
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
|
||||
|
||||
unsafe {
|
||||
// Rev col and rows because ZnxDft.sl() >= ZnxBig.sl()
|
||||
(0..a_mut.size()).for_each(|j| {
|
||||
(0..a_mut.cols()).for_each(|i| {
|
||||
vec_znx_dft::vec_znx_idft_tmp_a(
|
||||
self.ptr,
|
||||
a_mut.at_mut_ptr(i, j) as *mut vec_znx_big::vec_znx_big_t,
|
||||
1 as u64,
|
||||
a_mut.at_mut_ptr(i, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1 as u64,
|
||||
)
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
a.into_big()
|
||||
}
|
||||
|
||||
fn vec_znx_idft_tmp_bytes(&self) -> usize {
|
||||
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize }
|
||||
}
|
||||
|
||||
fn vec_znx_dft<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: crate::VecZnx<&[u8]> = a.to_ref();
|
||||
let steps: usize = (a_ref.size() + step - 1) / step;
|
||||
let min_steps: usize = min(res_mut.size(), steps);
|
||||
unsafe {
|
||||
(0..min_steps).for_each(|j| {
|
||||
let limb: usize = offset + j * step;
|
||||
if limb < a_ref.size() {
|
||||
vec_znx_dft::vec_znx_dft(
|
||||
self.ptr,
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1 as u64,
|
||||
a_ref.at_ptr(a_col, limb),
|
||||
1 as u64,
|
||||
a_ref.sl() as u64,
|
||||
)
|
||||
}
|
||||
});
|
||||
(min_steps..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
|
||||
fn vec_znx_idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
unsafe impl VecZnxDftToVecZnxBigImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_to_vec_znx_big_impl<R, A>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<FFT64>,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_idft_tmp_bytes());
|
||||
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_dft_to_vec_znx_big_tmp_bytes());
|
||||
|
||||
let min_size: usize = min(res_mut.size(), a_ref.size());
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_znx_idft(
|
||||
self.ptr,
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
|
||||
1 as u64,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
@@ -383,3 +98,331 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftToVecZnxBigTmpAImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_a_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_mut.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_znx_idft_tmp_a(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
|
||||
1 as u64,
|
||||
a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1 as u64,
|
||||
)
|
||||
});
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftToVecZnxBigConsumeImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_to_vec_znx_big_consume_impl<D: Data>(module: &Module<FFT64>, mut a: VecZnxDft<D, FFT64>) -> VecZnxBig<D, FFT64>
|
||||
where
|
||||
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
|
||||
|
||||
unsafe {
|
||||
// Rev col and rows because ZnxDft.sl() >= ZnxBig.sl()
|
||||
(0..a_mut.size()).for_each(|j| {
|
||||
(0..a_mut.cols()).for_each(|i| {
|
||||
vec_znx_dft::vec_znx_idft_tmp_a(
|
||||
module.ptr(),
|
||||
a_mut.at_mut_ptr(i, j) as *mut vec_znx_big::vec_znx_big_t,
|
||||
1 as u64,
|
||||
a_mut.at_mut_ptr(i, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1 as u64,
|
||||
)
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
a.into_big()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftFromVecZnxImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_from_vec_znx_impl<R, A>(
|
||||
module: &Module<FFT64>,
|
||||
step: usize,
|
||||
offset: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnx<&[u8]> = a.to_ref();
|
||||
let steps: usize = a_ref.size().div_ceil(step);
|
||||
let min_steps: usize = res_mut.size().min(steps);
|
||||
unsafe {
|
||||
(0..min_steps).for_each(|j| {
|
||||
let limb: usize = offset + j * step;
|
||||
if limb < a_ref.size() {
|
||||
vec_znx_dft::vec_znx_dft(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1 as u64,
|
||||
a_ref.at_ptr(a_col, limb),
|
||||
1 as u64,
|
||||
a_ref.sl() as u64,
|
||||
)
|
||||
}
|
||||
});
|
||||
(min_steps..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAddImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_add_impl<R, A, D>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &D,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
D: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_add(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAddInplaceImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_add_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_add(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftSubImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_sub_impl<R, A, D>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &D,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
D: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftSubABInplaceImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_sub_ab_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftSubBAInplaceImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_sub_ba_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftCopyImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_copy_impl<R, A>(
|
||||
_module: &Module<FFT64>,
|
||||
step: usize,
|
||||
offset: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let steps: usize = a_ref.size().div_ceil(step);
|
||||
let min_steps: usize = res_mut.size().min(steps);
|
||||
|
||||
(0..min_steps).for_each(|j| {
|
||||
let limb: usize = offset + j * step;
|
||||
if limb < a_ref.size() {
|
||||
res_mut
|
||||
.at_mut(res_col, j)
|
||||
.copy_from_slice(a_ref.at(a_col, limb));
|
||||
}
|
||||
});
|
||||
(min_steps..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
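The step/offset copy above takes every step-th limb of `a` starting at `offset` and zeroes any leftover limbs of the result. A small self-contained sketch of that indexing, with plain integers standing in for limbs (arbitrary sizes):

fn main() {
    let a: Vec<i32> = (0..7).collect();
    let (step, offset) = (2usize, 1usize);
    let mut res = vec![0i32; 5];

    let steps = a.len().div_ceil(step);
    let min_steps = res.len().min(steps);
    for j in 0..min_steps {
        let limb = offset + j * step;
        if limb < a.len() {
            res[j] = a[limb];
        }
    }
    // limbs 1, 3 and 5 are copied; the remaining output limbs stay zero.
    assert_eq!(res, vec![1, 3, 5, 0, 0]);
}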
unsafe impl VecZnxDftZeroImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_zero_impl<R>(_module: &Module<FFT64>, res: &mut R)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
res.to_mut().data.fill(0);
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> fmt::Display for VecZnxDft<D, FFT64> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"VecZnxDft(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {}:", col)?;
|
||||
for size in 0..self.size {
|
||||
let coeffs = self.at(col, size);
|
||||
write!(f, " Size {}: [", size)?;
|
||||
|
||||
let max_show = 100;
|
||||
let show_count = coeffs.len().min(max_show);
|
||||
|
||||
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "{}", coeff)?;
|
||||
}
|
||||
|
||||
if coeffs.len() > max_show {
|
||||
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
|
||||
}
|
||||
|
||||
writeln!(f, "]")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
38
backend/src/implementation/cpu_spqlios/vec_znx_dft_ntt120.rs
Normal file
@@ -0,0 +1,38 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{ZnxInfos, ZnxSliceSize, ZnxView},
|
||||
layouts::{Data, DataRef, VecZnxDft, VecZnxDftBytesOf, VecZnxDftOwned},
|
||||
oep::{VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl},
|
||||
},
|
||||
implementation::cpu_spqlios::module_ntt120::NTT120,
|
||||
};
|
||||
|
||||
const VEC_ZNX_DFT_NTT120_WORDSIZE: usize = 4;
|
||||
|
||||
impl<D: Data> ZnxSliceSize for VecZnxDft<D, NTT120> {
|
||||
fn sl(&self) -> usize {
|
||||
VEC_ZNX_DFT_NTT120_WORDSIZE * self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> VecZnxDftBytesOf for VecZnxDft<D, NTT120> {
|
||||
fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
|
||||
VEC_ZNX_DFT_NTT120_WORDSIZE * n * cols * size * size_of::<i64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> ZnxView for VecZnxDft<D, NTT120> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAllocBytesImpl<NTT120> for NTT120 {
|
||||
fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
VecZnxDft::<Vec<u8>, NTT120>::bytes_of(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAllocImpl<NTT120> for NTT120 {
|
||||
fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<NTT120> {
|
||||
VecZnxDftOwned::alloc(n, cols, size)
|
||||
}
|
||||
}
|
||||
286
backend/src/implementation/cpu_spqlios/vmp_pmat_fft64.rs
Normal file
@@ -0,0 +1,286 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{TakeSlice, VmpApplyTmpBytes, VmpPrepareTmpBytes, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
layouts::{
|
||||
DataRef, MatZnx, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatBytesOf,
|
||||
VmpPMatOwned, VmpPMatToMut, VmpPMatToRef,
|
||||
},
|
||||
oep::{
|
||||
VmpApplyAddImpl, VmpApplyAddTmpBytesImpl, VmpApplyImpl, VmpApplyTmpBytesImpl, VmpPMatAllocBytesImpl,
|
||||
VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
|
||||
},
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
ffi::{vec_znx_dft::vec_znx_dft_t, vmp},
|
||||
module_fft64::FFT64,
|
||||
},
|
||||
};
|
||||
|
||||
const VMP_PMAT_FFT64_WORDSIZE: usize = 1;
|
||||
|
||||
impl<D: DataRef> ZnxView for VmpPMat<D, FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
impl VmpPMatBytesOf for FFT64 {
|
||||
fn vmp_pmat_bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
VMP_PMAT_FFT64_WORDSIZE * n * rows * cols_in * cols_out * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPMatAllocBytesImpl<FFT64> for FFT64
|
||||
where
|
||||
FFT64: VmpPMatBytesOf,
|
||||
{
|
||||
fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
FFT64::vmp_pmat_bytes_of(n, rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPMatFromBytesImpl<FFT64> for FFT64 {
|
||||
fn vmp_pmat_from_bytes_impl(
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: Vec<u8>,
|
||||
) -> VmpPMatOwned<FFT64> {
|
||||
VmpPMatOwned::from_bytes(n, rows, cols_in, cols_out, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPMatAllocImpl<FFT64> for FFT64 {
|
||||
fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<FFT64> {
|
||||
VmpPMatOwned::alloc(n, rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPrepareTmpBytesImpl<FFT64> for FFT64 {
|
||||
fn vmp_prepare_tmp_bytes_impl(module: &Module<FFT64>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_prepare_tmp_bytes(
|
||||
module.ptr(),
|
||||
(rows * cols_in) as u64,
|
||||
(cols_out * size) as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPMatPrepareImpl<FFT64> for FFT64 {
|
||||
fn vmp_prepare_impl<R, A>(module: &Module<FFT64>, res: &mut R, a: &A, scratch: &mut Scratch<FFT64>)
|
||||
where
|
||||
R: VmpPMatToMut<FFT64>,
|
||||
A: MatZnxToRef,
|
||||
{
|
||||
let mut res: VmpPMat<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: MatZnx<&[u8]> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(
|
||||
res.cols_in(),
|
||||
a.cols_in(),
|
||||
"res.cols_in: {} != a.cols_in: {}",
|
||||
res.cols_in(),
|
||||
a.cols_in()
|
||||
);
|
||||
assert_eq!(
|
||||
res.rows(),
|
||||
a.rows(),
|
||||
"res.rows: {} != a.rows: {}",
|
||||
res.rows(),
|
||||
a.rows()
|
||||
);
|
||||
assert_eq!(
|
||||
res.cols_out(),
|
||||
a.cols_out(),
|
||||
"res.cols_out: {} != a.cols_out: {}",
|
||||
res.cols_out(),
|
||||
a.cols_out()
|
||||
);
|
||||
assert_eq!(
|
||||
res.size(),
|
||||
a.size(),
|
||||
"res.size: {} != a.size: {}",
|
||||
res.size(),
|
||||
a.size()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.take_slice(module.vmp_prepare_tmp_bytes(a.rows(), a.cols_in(), a.cols_out(), a.size()));
|
||||
|
||||
unsafe {
|
||||
vmp::vmp_prepare_contiguous(
|
||||
module.ptr(),
|
||||
res.as_mut_ptr() as *mut vmp::vmp_pmat_t,
|
||||
a.as_ptr(),
|
||||
(a.rows() * a.cols_in()) as u64,
|
||||
(a.size() * a.cols_out()) as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyTmpBytesImpl<FFT64> for FFT64 {
|
||||
fn vmp_apply_tmp_bytes_impl(
|
||||
module: &Module<FFT64>,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_tmp_bytes(
|
||||
module.ptr(),
|
||||
(res_size * b_cols_out) as u64,
|
||||
(a_size * b_cols_in) as u64,
|
||||
(b_rows * b_cols_in) as u64,
|
||||
(b_size * b_cols_out) as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyImpl<FFT64> for FFT64 {
|
||||
fn vmp_apply_impl<R, A, C>(module: &Module<FFT64>, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<FFT64>)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
C: VmpPMatToRef<FFT64>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], _> = a.to_ref();
|
||||
let b: VmpPMat<&[u8], _> = b.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(b.n(), module.n());
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(
|
||||
res.cols(),
|
||||
b.cols_out(),
|
||||
"res.cols(): {} != b.cols_out: {}",
|
||||
res.cols(),
|
||||
b.cols_out()
|
||||
);
|
||||
assert_eq!(
|
||||
a.cols(),
|
||||
b.cols_in(),
|
||||
"a.cols(): {} != b.cols_in: {}",
|
||||
a.cols(),
|
||||
b.cols_in()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes(
|
||||
res.size(),
|
||||
a.size(),
|
||||
b.rows(),
|
||||
b.cols_in(),
|
||||
b.cols_out(),
|
||||
b.size(),
|
||||
));
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft(
|
||||
module.ptr(),
|
||||
res.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
(res.size() * res.cols()) as u64,
|
||||
a.as_ptr() as *const vec_znx_dft_t,
|
||||
(a.size() * a.cols()) as u64,
|
||||
b.as_ptr() as *const vmp::vmp_pmat_t,
|
||||
(b.rows() * b.cols_in()) as u64,
|
||||
(b.size() * b.cols_out()) as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyAddTmpBytesImpl<FFT64> for FFT64 {
|
||||
fn vmp_apply_add_tmp_bytes_impl(
|
||||
module: &Module<FFT64>,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_tmp_bytes(
|
||||
module.ptr(),
|
||||
(res_size * b_cols_out) as u64,
|
||||
(a_size * b_cols_in) as u64,
|
||||
(b_rows * b_cols_in) as u64,
|
||||
(b_size * b_cols_out) as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyAddImpl<FFT64> for FFT64 {
|
||||
fn vmp_apply_add_impl<R, A, C>(module: &Module<FFT64>, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<FFT64>)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
C: VmpPMatToRef<FFT64>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], _> = a.to_ref();
|
||||
let b: VmpPMat<&[u8], _> = b.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
use crate::hal::api::ZnxInfos;
|
||||
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(b.n(), module.n());
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(
|
||||
res.cols(),
|
||||
b.cols_out(),
|
||||
"res.cols(): {} != b.cols_out: {}",
|
||||
res.cols(),
|
||||
b.cols_out()
|
||||
);
|
||||
assert_eq!(
|
||||
a.cols(),
|
||||
b.cols_in(),
|
||||
"a.cols(): {} != b.cols_in: {}",
|
||||
a.cols(),
|
||||
b.cols_in()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes(
|
||||
res.size(),
|
||||
a.size(),
|
||||
b.rows(),
|
||||
b.cols_in(),
|
||||
b.cols_out(),
|
||||
b.size(),
|
||||
));
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_add(
|
||||
module.ptr(),
|
||||
res.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
(res.size() * res.cols()) as u64,
|
||||
a.as_ptr() as *const vec_znx_dft_t,
|
||||
(a.size() * a.cols()) as u64,
|
||||
b.as_ptr() as *const vmp::vmp_pmat_t,
|
||||
(b.rows() * b.cols_in()) as u64,
|
||||
(b.size() * b.cols_out()) as u64,
|
||||
(scale * b.cols_out()) as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
11
backend/src/implementation/cpu_spqlios/vmp_pmat_ntt120.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::ZnxView,
|
||||
layouts::{DataRef, VmpPMat},
|
||||
},
|
||||
implementation::cpu_spqlios::module_ntt120::NTT120,
|
||||
};
|
||||
|
||||
impl<D: DataRef> ZnxView for VmpPMat<D, NTT120> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
1
backend/src/implementation/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod cpu_spqlios;
|
||||
@@ -1,39 +1,17 @@
|
||||
pub mod encoding;
|
||||
#[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)]
|
||||
// Other modules and exports
|
||||
pub mod ffi;
|
||||
pub mod mat_znx_dft;
|
||||
pub mod mat_znx_dft_ops;
|
||||
pub mod module;
|
||||
pub mod sampling;
|
||||
pub mod scalar_znx;
|
||||
pub mod scalar_znx_dft;
|
||||
pub mod scalar_znx_dft_ops;
|
||||
pub mod stats;
|
||||
pub mod vec_znx;
|
||||
pub mod vec_znx_big;
|
||||
pub mod vec_znx_big_ops;
|
||||
pub mod vec_znx_dft;
|
||||
pub mod vec_znx_dft_ops;
|
||||
pub mod vec_znx_ops;
|
||||
pub mod znx_base;
|
||||
#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)]
|
||||
#![deny(rustdoc::broken_intra_doc_links)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
#![feature(trait_alias)]
|
||||
|
||||
pub use encoding::*;
|
||||
pub use mat_znx_dft::*;
|
||||
pub use mat_znx_dft_ops::*;
|
||||
pub use module::*;
|
||||
pub use sampling::*;
|
||||
pub use scalar_znx::*;
|
||||
pub use scalar_znx_dft::*;
|
||||
pub use scalar_znx_dft_ops::*;
|
||||
pub use stats::*;
|
||||
pub use vec_znx::*;
|
||||
pub use vec_znx_big::*;
|
||||
pub use vec_znx_big_ops::*;
|
||||
pub use vec_znx_dft::*;
|
||||
pub use vec_znx_dft_ops::*;
|
||||
pub use vec_znx_ops::*;
|
||||
pub use znx_base::*;
|
||||
pub mod hal;
|
||||
pub mod implementation;
|
||||
|
||||
pub mod doc {
|
||||
#[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/backend_safety_contract.md"))]
|
||||
pub mod backend_safety {
|
||||
pub const _PLACEHOLDER: () = ();
|
||||
}
|
||||
}
|
||||
|
||||
pub const GALOISGENERATOR: u64 = 5;
|
||||
pub const DEFAULTALIGN: usize = 64;
|
||||
@@ -118,190 +96,10 @@ pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
|
||||
}
|
||||
|
||||
/// Allocates an aligned vector of size equal to the smallest multiple
|
||||
/// of [DEFAULTALIGN]/size_of::<T>() that is equal or greater to `size`.
|
||||
/// of [DEFAULTALIGN]/`size_of::<T>`() that is equal or greater to `size`.
|
||||
pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
|
||||
alloc_aligned_custom::<T>(
|
||||
size + (DEFAULTALIGN - (size % (DEFAULTALIGN / size_of::<T>()))),
|
||||
size + (DEFAULTALIGN - (size % (DEFAULTALIGN / size_of::<T>()))) % DEFAULTALIGN,
|
||||
DEFAULTALIGN,
|
||||
)
|
||||
}
|
||||
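To make the rounding above concrete: for `i64` (8 bytes) the allocation is padded to the next multiple of 64 / 8 = 8 elements. The sketch below is illustrative only; `padded_len` is a hypothetical helper that mirrors the behaviour described by the doc comment, not a function of this crate.

```rust
// Illustration only: `padded_len` is a hypothetical helper mirroring the doc
// comment above (round `size` up to the next multiple of
// DEFAULTALIGN / size_of::<T>()); it is not part of this crate.
const DEFAULTALIGN: usize = 64;

fn padded_len<T>(size: usize) -> usize {
    let step = DEFAULTALIGN / std::mem::size_of::<T>();
    size.div_ceil(step) * step
}

fn main() {
    assert_eq!(padded_len::<i64>(100), 104); // step = 64 / 8 = 8 elements
    assert_eq!(padded_len::<u8>(100), 128);  // step = 64 / 1 = 64 elements
    assert_eq!(padded_len::<f64>(64), 64);   // already aligned, unchanged
}
```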
|
||||
// Scratch implementation below
|
||||
|
||||
pub struct ScratchOwned(Vec<u8>);
|
||||
|
||||
impl ScratchOwned {
|
||||
pub fn new(byte_count: usize) -> Self {
|
||||
let data: Vec<u8> = alloc_aligned(byte_count);
|
||||
Self(data)
|
||||
}
|
||||
|
||||
pub fn borrow(&mut self) -> &mut Scratch {
|
||||
Scratch::new(&mut self.0)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Scratch {
|
||||
data: [u8],
|
||||
}
|
||||
|
||||
impl Scratch {
|
||||
fn new(data: &mut [u8]) -> &mut Self {
|
||||
unsafe { &mut *(data as *mut [u8] as *mut Self) }
|
||||
}
|
||||
|
||||
pub fn zero(&mut self) {
|
||||
self.data.fill(0);
|
||||
}
|
||||
|
||||
pub fn available(&self) -> usize {
|
||||
let ptr: *const u8 = self.data.as_ptr();
|
||||
let self_len: usize = self.data.len();
|
||||
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
|
||||
self_len.saturating_sub(aligned_offset)
|
||||
}
|
||||
|
||||
fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) {
|
||||
let ptr: *mut u8 = data.as_mut_ptr();
|
||||
let self_len: usize = data.len();
|
||||
|
||||
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
|
||||
let aligned_len: usize = self_len.saturating_sub(aligned_offset);
|
||||
|
||||
if let Some(rem_len) = aligned_len.checked_sub(take_len) {
|
||||
unsafe {
|
||||
let rem_ptr: *mut u8 = ptr.add(aligned_offset).add(take_len);
|
||||
let rem_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len);
|
||||
|
||||
let take_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len);
|
||||
|
||||
return (take_slice, rem_slice);
|
||||
}
|
||||
} else {
|
||||
panic!(
|
||||
"Attempted to take {} from scratch with {} aligned bytes left",
|
||||
take_len,
|
||||
aligned_len,
|
||||
// type_name::<T>(),
|
||||
// aligned_len
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tmp_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self) {
|
||||
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, len * std::mem::size_of::<T>());
|
||||
|
||||
unsafe {
|
||||
(
|
||||
&mut *(std::ptr::slice_from_raw_parts_mut(take_slice.as_mut_ptr() as *mut T, len)),
|
||||
Self::new(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tmp_scalar_znx<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
|
||||
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx(module, cols));
|
||||
|
||||
(
|
||||
ScalarZnx::from_data(take_slice, module.n(), cols),
|
||||
Self::new(rem_slice),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn tmp_scalar_znx_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnxDft<&mut [u8], B>, &mut Self) {
|
||||
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx_dft(module, cols));
|
||||
|
||||
(
|
||||
ScalarZnxDft::from_data(take_slice, module.n(), cols),
|
||||
Self::new(rem_slice),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn tmp_vec_znx_dft<B: Backend>(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (VecZnxDft<&mut [u8], B>, &mut Self) {
|
||||
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_dft(module, cols, size));
|
||||
|
||||
(
|
||||
VecZnxDft::from_data(take_slice, module.n(), cols, size),
|
||||
Self::new(rem_slice),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn tmp_slice_vec_znx_dft<B: Backend>(
|
||||
&mut self,
|
||||
slice_size: usize,
|
||||
module: &Module<B>,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Self) {
|
||||
let mut scratch: &mut Scratch = self;
|
||||
let mut slice: Vec<VecZnxDft<&mut [u8], B>> = Vec::with_capacity(slice_size);
|
||||
for _ in 0..slice_size {
|
||||
let (znx, new_scratch) = scratch.tmp_vec_znx_dft(module, cols, size);
|
||||
scratch = new_scratch;
|
||||
slice.push(znx);
|
||||
}
|
||||
(slice, scratch)
|
||||
}
|
||||
|
||||
pub fn tmp_vec_znx_big<B: Backend>(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (VecZnxBig<&mut [u8], B>, &mut Self) {
|
||||
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_big(module, cols, size));
|
||||
|
||||
(
|
||||
VecZnxBig::from_data(take_slice, module.n(), cols, size),
|
||||
Self::new(rem_slice),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn tmp_vec_znx<B: Backend>(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
|
||||
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, module.bytes_of_vec_znx(cols, size));
|
||||
(
|
||||
VecZnx::from_data(take_slice, module.n(), cols, size),
|
||||
Self::new(rem_slice),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn tmp_slice_vec_znx<B: Backend>(
|
||||
&mut self,
|
||||
slice_size: usize,
|
||||
module: &Module<B>,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnx<&mut [u8]>>, &mut Self) {
|
||||
let mut scratch: &mut Scratch = self;
|
||||
let mut slice: Vec<VecZnx<&mut [u8]>> = Vec::with_capacity(slice_size);
|
||||
for _ in 0..slice_size {
|
||||
let (znx, new_scratch) = scratch.tmp_vec_znx(module, cols, size);
|
||||
scratch = new_scratch;
|
||||
slice.push(znx);
|
||||
}
|
||||
(slice, scratch)
|
||||
}
|
||||
|
||||
pub fn tmp_mat_znx_dft<B: Backend>(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (MatZnxDft<&mut [u8], B>, &mut Self) {
|
||||
let (take_slice, rem_slice) = Self::take_slice_aligned(
|
||||
&mut self.data,
|
||||
module.bytes_of_mat_znx_dft(rows, cols_in, cols_out, size),
|
||||
);
|
||||
(
|
||||
MatZnxDft::from_data(take_slice, module.n(), rows, cols_in, cols_out, size),
|
||||
Self::new(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
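A minimal sketch of how the scratch allocator above is intended to be used: one aligned byte buffer backs every temporary, and each `tmp_*` call hands back a typed slice plus the remaining scratch. The crate path and the byte count are assumptions for illustration (shown as if `ScratchOwned` were still exported from the crate root, as in this pre-HAL layout).

```rust
// Sketch under assumptions: `backend` is the crate name from Cargo.lock and
// ScratchOwned/Scratch are exported from the crate root as in this pre-HAL layout.
use backend::ScratchOwned;

fn main() {
    // One flat, aligned allocation backs every temporary taken below.
    let mut owned = ScratchOwned::new(1 << 16);
    let scratch = owned.borrow();

    // Take 512 aligned f64s; `rest` is the remaining scratch space.
    let (coeffs, rest) = scratch.tmp_slice::<f64>(512);
    coeffs.fill(0.0);

    // Take another slice from what is left; no further heap allocation happens.
    let (indices, rest) = rest.tmp_slice::<u64>(128);
    indices.fill(0);

    assert!(rest.available() < 1 << 16);
}
```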
|
||||
@@ -1,214 +0,0 @@
|
||||
use crate::znx_base::ZnxInfos;
|
||||
use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
|
||||
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
|
||||
/// Each col of the [MatZnxDft] can be seen as a collection of [VecZnxDft].
|
||||
///
|
||||
/// [MatZnxDft] is used to perform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft].
|
||||
/// See the trait [MatZnxDftOps] for additional information.
|
||||
pub struct MatZnxDft<D, B: Backend> {
|
||||
data: D,
|
||||
n: usize,
|
||||
size: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
_phantom: PhantomData<B>,
|
||||
}
|
||||
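A small sketch of allocating such a matrix and reading back its dimensions, assuming the pre-HAL re-exports (`Module`, `FFT64`, `MatZnxDftAlloc`, `ZnxInfos`) and arbitrary dimensions; it only illustrates the accessors defined below.

```rust
// Sketch under assumptions: crate path and dimensions are arbitrary; it only
// exercises the accessors defined below (pre-HAL re-exports assumed).
use backend::{FFT64, MatZnxDftAlloc, Module, ZnxInfos};

fn main() {
    let module: Module<FFT64> = Module::<FFT64>::new(16);

    // rows x cols_in prepared rows, each a VecZnxDft with cols_out columns and `size` limbs.
    let mat = module.new_mat_znx_dft(4, 2, 2, 5);

    assert_eq!(mat.n(), 16);
    assert_eq!(mat.rows(), 4);
    assert_eq!(mat.cols_in(), 2);
    assert_eq!(mat.cols_out(), 2);
    assert_eq!(mat.size(), 5);
}
```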
|
||||
impl<D, B: Backend> ZnxInfos for MatZnxDft<D, B> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols_in
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
self.rows
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> ZnxSliceSize for MatZnxDft<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n() * self.cols_out()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> DataView for MatZnxDft<D, B> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> DataViewMut for MatZnxDft<D, B> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> ZnxView for MatZnxDft<D, FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
impl<D, B: Backend> MatZnxDft<D, B> {
|
||||
pub fn cols_in(&self) -> usize {
|
||||
self.cols_in
|
||||
}
|
||||
|
||||
pub fn cols_out(&self) -> usize {
|
||||
self.cols_out
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: From<Vec<u8>>, B: Backend> MatZnxDft<D, B> {
|
||||
pub(crate) fn bytes_of(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
unsafe {
|
||||
crate::ffi::vmp::bytes_of_vmp_pmat(
|
||||
module.ptr,
|
||||
(rows * cols_in) as u64,
|
||||
(size * cols_out) as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
|
||||
let data: Vec<u8> = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
size,
|
||||
rows,
|
||||
cols_in,
|
||||
cols_out,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_from_bytes(
|
||||
module: &Module<B>,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: impl Into<Vec<u8>>,
|
||||
) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == Self::bytes_of(module, rows, cols_in, cols_out, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
size,
|
||||
rows,
|
||||
cols_in,
|
||||
cols_out,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> MatZnxDft<D, FFT64> {
|
||||
/// Returns a copy of the backend array at index (i, j) of the [MatZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `row`: row index (i).
|
||||
/// * `col`: col index (j).
|
||||
#[allow(dead_code)]
|
||||
fn at(&self, row: usize, col: usize) -> Vec<f64> {
|
||||
let n: usize = self.n();
|
||||
|
||||
let mut res: Vec<f64> = alloc_aligned(n);
|
||||
|
||||
if n < 8 {
|
||||
res.copy_from_slice(&self.raw()[(row + col * self.rows()) * n..(row + col * self.rows() + 1) * n]);
|
||||
} else {
|
||||
(0..n >> 3).for_each(|blk| {
|
||||
res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]);
|
||||
});
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] {
|
||||
let nrows: usize = self.rows();
|
||||
let nsize: usize = self.size();
|
||||
if col == (nsize - 1) && (nsize & 1 == 1) {
|
||||
&self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..]
|
||||
} else {
|
||||
&self.raw()[blk * nrows * nsize * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type MatZnxDftOwned<B> = MatZnxDft<Vec<u8>, B>;
|
||||
pub type MatZnxDftMut<'a, B> = MatZnxDft<&'a mut [u8], B>;
|
||||
pub type MatZnxDftRef<'a, B> = MatZnxDft<&'a [u8], B>;
|
||||
|
||||
pub trait MatZnxToRef<B: Backend> {
|
||||
fn to_ref(&self) -> MatZnxDft<&[u8], B>;
|
||||
}
|
||||
|
||||
impl<D, B: Backend> MatZnxToRef<B> for MatZnxDft<D, B>
|
||||
where
|
||||
D: AsRef<[u8]>,
|
||||
B: Backend,
|
||||
{
|
||||
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
|
||||
MatZnxDft {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
rows: self.rows,
|
||||
cols_in: self.cols_in,
|
||||
cols_out: self.cols_out,
|
||||
size: self.size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait MatZnxToMut<B: Backend> {
|
||||
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B>;
|
||||
}
|
||||
|
||||
impl<D, B: Backend> MatZnxToMut<B> for MatZnxDft<D, B>
|
||||
where
|
||||
D: AsRef<[u8]> + AsMut<[u8]>,
|
||||
B: Backend,
|
||||
{
|
||||
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
|
||||
MatZnxDft {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
rows: self.rows,
|
||||
cols_in: self.cols_in,
|
||||
cols_out: self.cols_out,
|
||||
size: self.size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> MatZnxDft<D, B> {
|
||||
pub(crate) fn from_data(data: D, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
rows,
|
||||
cols_in,
|
||||
cols_out,
|
||||
size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,996 +0,0 @@
|
||||
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
|
||||
use crate::ffi::vmp;
|
||||
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
|
||||
use crate::{
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftOwned, MatZnxToMut, MatZnxToRef, Module, ScalarZnxAlloc, ScalarZnxDftAlloc,
|
||||
ScalarZnxDftOps, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps,
|
||||
};
|
||||
|
||||
pub trait MatZnxDftAlloc<B: Backend> {
|
||||
/// Allocates a new [MatZnxDft] with the given number of rows and columns.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `rows`: number of rows (number of [VecZnxDft]).
|
||||
/// * `size`: size (number of limbs) of each [VecZnxDft].
|
||||
fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftOwned<B>;
|
||||
|
||||
fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
|
||||
|
||||
fn new_mat_znx_dft_from_bytes(
|
||||
&self,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: Vec<u8>,
|
||||
) -> MatZnxDftOwned<B>;
|
||||
}
|
||||
|
||||
pub trait MatZnxDftScratch {
|
||||
/// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply].
|
||||
fn vmp_apply_tmp_bytes(
|
||||
&self,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize;
|
||||
|
||||
fn mat_znx_dft_mul_x_pow_minus_one_scratch_space(&self, size: usize, cols_out: usize) -> usize;
|
||||
}
|
||||
|
||||
/// This trait implements methods for vector matrix product,
|
||||
/// that is, multiplying a [VecZnx] with a [MatZnxDft].
|
||||
pub trait MatZnxDftOps<BACKEND: Backend> {
|
||||
/// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `res`: [MatZnxDft] on which the values are encoded.
|
||||
/// * `a`: the [VecZnxDft] to encode on the [MatZnxDft].
|
||||
/// * `row_i`: the index of the row to prepare.
|
||||
///
|
||||
/// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
|
||||
fn mat_znx_dft_set_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
|
||||
where
|
||||
R: MatZnxToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>;
|
||||
|
||||
/// Extracts the ith-row of [MatZnxDft] into a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `res`: the [VecZnxDft] into which to extract the row of the [MatZnxDft].
|
||||
/// * `a`: [MatZnxDft] on which the values are encoded.
|
||||
/// * `row_i`: the index of the row to extract.
|
||||
fn mat_znx_dft_get_row<R, A>(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: MatZnxToRef<FFT64>;
|
||||
|
||||
/// Multiplies A by (X^{k} - 1) and stores the result on R.
|
||||
fn mat_znx_dft_mul_x_pow_minus_one<R, A>(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch)
|
||||
where
|
||||
R: MatZnxToMut<FFT64>,
|
||||
A: MatZnxToRef<FFT64>;
|
||||
|
||||
/// Multiplies A by (X^{k} - 1).
|
||||
fn mat_znx_dft_mul_x_pow_minus_one_inplace<A>(&self, k: i64, a: &mut A, scratch: &mut Scratch)
|
||||
where
|
||||
A: MatZnxToMut<FFT64>;
|
||||
|
||||
/// Multiplies A by (X^{k} - 1) and adds the result on R.
|
||||
fn mat_znx_dft_mul_x_pow_minus_one_add_inplace<R, A>(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch)
|
||||
where
|
||||
R: MatZnxToMut<FFT64>,
|
||||
A: MatZnxToRef<FFT64>;
|
||||
|
||||
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
|
||||
/// The size of `scratch` is given by [MatZnxDftScratch::vmp_apply_tmp_bytes].
|
||||
///
|
||||
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
|
||||
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
|
||||
/// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
|
||||
///
|
||||
/// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
|
||||
/// `j` size, the output is a [VecZnx] of `j` size.
|
||||
///
|
||||
/// If there is a mismatch between the dimensions the largest valid ones are used.
|
||||
///
|
||||
/// ```text
|
||||
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
|
||||
/// |h i j|
|
||||
/// |k l m|
|
||||
/// ```
|
||||
/// where each element is a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `c`: the output of the vector matrix product, as a [VecZnxDft].
|
||||
/// * `a`: the left operand [VecZnxDft] of the vector matrix product.
|
||||
/// * `b`: the right operand [MatZnxDft] of the vector matrix product.
|
||||
/// * `scratch`: scratch space, whose size can be obtained with [MatZnxDftScratch::vmp_apply_tmp_bytes].
|
||||
fn vmp_apply<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
B: MatZnxToRef<FFT64>;
|
||||
|
||||
/// Same as [MatZnxDftOps::vmp_apply], except the result is added to R instead of overwriting R.
|
||||
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scale: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
B: MatZnxToRef<FFT64>;
|
||||
}
|
||||
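A condensed, hedged sketch of the vector-matrix product described above, modelled on the `vmp_apply` test at the bottom of this file: build a prepared matrix row by row from DFT vectors, then apply it to an input entirely in the DFT domain. The crate path, import list and dimensions are assumptions for illustration.

```rust
// Sketch under assumptions: condensed from the `vmp_apply` test below; crate
// path, imports and dimensions are illustration choices only.
use backend::{
    FFT64, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, Module, ScratchOwned,
    VecZnxAlloc, VecZnxDftAlloc, VecZnxDftOps, ZnxViewMut,
};

fn main() {
    let module: Module<FFT64> = Module::<FFT64>::new(16);
    let (rows, cols_in, cols_out, size) = (4usize, 1usize, 1usize, 4usize);

    // Scratch sized for res = a x mat, all operands in the DFT domain.
    let mut scratch = ScratchOwned::new(module.vmp_apply_tmp_bytes(size, size, rows, cols_in, cols_out, size));

    // Input vector a = X, and its DFT.
    let mut a = module.new_vec_znx(cols_in, size);
    a.at_mut(0, 0)[1] = 1;
    let mut a_dft = module.new_vec_znx_dft(cols_in, size);
    module.vec_znx_dft(1, 0, &mut a_dft, 0, &a, 0);

    // Prepared matrix: every row holds the same DFT vector here, for brevity.
    let mut mat = module.new_mat_znx_dft(rows, cols_in, cols_out, size);
    for row_i in 0..rows {
        module.mat_znx_dft_set_row(&mut mat, row_i, 0, &a_dft);
    }

    // res_dft = a_dft x mat.
    let mut res_dft = module.new_vec_znx_dft(cols_out, size);
    module.vmp_apply(&mut res_dft, &a_dft, &mat, scratch.borrow());
}
```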
|
||||
impl<B: Backend> MatZnxDftAlloc<B> for Module<B> {
|
||||
fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
MatZnxDftOwned::bytes_of(self, rows, cols_in, cols_out, size)
|
||||
}
|
||||
|
||||
fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftOwned<B> {
|
||||
MatZnxDftOwned::new(self, rows, cols_in, cols_out, size)
|
||||
}
|
||||
|
||||
fn new_mat_znx_dft_from_bytes(
|
||||
&self,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: Vec<u8>,
|
||||
) -> MatZnxDftOwned<B> {
|
||||
MatZnxDftOwned::new_from_bytes(self, rows, cols_in, cols_out, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<BACKEND: Backend> MatZnxDftScratch for Module<BACKEND> {
|
||||
fn vmp_apply_tmp_bytes(
|
||||
&self,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_tmp_bytes(
|
||||
self.ptr,
|
||||
(res_size * b_cols_out) as u64,
|
||||
(a_size * b_cols_in) as u64,
|
||||
(b_rows * b_cols_in) as u64,
|
||||
(b_size * b_cols_out) as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
|
||||
fn mat_znx_dft_mul_x_pow_minus_one_scratch_space(&self, size: usize, cols_out: usize) -> usize {
|
||||
let xpm1_dft: usize = self.bytes_of_scalar_znx(1);
|
||||
let xpm1: usize = self.bytes_of_scalar_znx_dft(1);
|
||||
let tmp: usize = self.bytes_of_vec_znx_dft(cols_out, size);
|
||||
xpm1_dft + (xpm1 | 2 * tmp)
|
||||
}
|
||||
}
|
||||
|
||||
impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
fn mat_znx_dft_mul_x_pow_minus_one<R, A>(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch)
|
||||
where
|
||||
R: MatZnxToMut<FFT64>,
|
||||
A: MatZnxToRef<FFT64>,
|
||||
{
|
||||
let mut res: MatZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: MatZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.rows(), a.rows());
|
||||
assert_eq!(res.cols_in(), a.cols_in());
|
||||
assert_eq!(res.cols_out(), a.cols_out());
|
||||
}
|
||||
|
||||
let (mut xpm1_dft, scratch1) = scratch.tmp_scalar_znx_dft(self, 1);
|
||||
|
||||
{
|
||||
let (mut xpm1, _) = scratch1.tmp_scalar_znx(self, 1);
|
||||
xpm1.data[0] = 1;
|
||||
self.vec_znx_rotate_inplace(k, &mut xpm1, 0);
|
||||
self.svp_prepare(&mut xpm1_dft, 0, &xpm1, 0);
|
||||
}
|
||||
|
||||
let (mut tmp_0, scratch2) = scratch1.tmp_vec_znx_dft(self, res.cols_out(), res.size());
|
||||
let (mut tmp_1, _) = scratch2.tmp_vec_znx_dft(self, res.cols_out(), res.size());
|
||||
|
||||
(0..res.rows()).for_each(|row_i| {
|
||||
(0..res.cols_in()).for_each(|col_j| {
|
||||
self.mat_znx_dft_get_row(&mut tmp_0, &a, row_i, col_j);
|
||||
|
||||
(0..tmp_0.cols()).for_each(|i| {
|
||||
self.svp_apply(&mut tmp_1, i, &xpm1_dft, 0, &tmp_0, i);
|
||||
self.vec_znx_dft_sub_ab_inplace(&mut tmp_1, i, &tmp_0, i);
|
||||
});
|
||||
|
||||
self.mat_znx_dft_set_row(&mut res, row_i, col_j, &tmp_1);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn mat_znx_dft_mul_x_pow_minus_one_inplace<A>(&self, k: i64, a: &mut A, scratch: &mut Scratch)
|
||||
where
|
||||
A: MatZnxToMut<FFT64>,
|
||||
{
|
||||
let mut a: MatZnxDft<&mut [u8], FFT64> = a.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
}
|
||||
|
||||
let (mut xpm1_dft, scratch1) = scratch.tmp_scalar_znx_dft(self, 1);
|
||||
|
||||
{
|
||||
let (mut xpm1, _) = scratch1.tmp_scalar_znx(self, 1);
|
||||
xpm1.data[0] = 1;
|
||||
self.vec_znx_rotate_inplace(k, &mut xpm1, 0);
|
||||
self.svp_prepare(&mut xpm1_dft, 0, &xpm1, 0);
|
||||
}
|
||||
|
||||
let (mut tmp_0, scratch2) = scratch1.tmp_vec_znx_dft(self, a.cols_out(), a.size());
|
||||
let (mut tmp_1, _) = scratch2.tmp_vec_znx_dft(self, a.cols_out(), a.size());
|
||||
|
||||
(0..a.rows()).for_each(|row_i| {
|
||||
(0..a.cols_in()).for_each(|col_j| {
|
||||
self.mat_znx_dft_get_row(&mut tmp_0, &a, row_i, col_j);
|
||||
|
||||
(0..tmp_0.cols()).for_each(|i| {
|
||||
self.svp_apply(&mut tmp_1, i, &xpm1_dft, 0, &tmp_0, i);
|
||||
self.vec_znx_dft_sub_ab_inplace(&mut tmp_1, i, &tmp_0, i);
|
||||
});
|
||||
|
||||
self.mat_znx_dft_set_row(&mut a, row_i, col_j, &tmp_1);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn mat_znx_dft_mul_x_pow_minus_one_add_inplace<R, A>(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch)
|
||||
where
|
||||
R: MatZnxToMut<FFT64>,
|
||||
A: MatZnxToRef<FFT64>,
|
||||
{
|
||||
let mut res: MatZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: MatZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
}
|
||||
|
||||
let (mut xpm1_dft, scratch1) = scratch.tmp_scalar_znx_dft(self, 1);
|
||||
|
||||
{
|
||||
let (mut xpm1, _) = scratch1.tmp_scalar_znx(self, 1);
|
||||
xpm1.data[0] = 1;
|
||||
self.vec_znx_rotate_inplace(k, &mut xpm1, 0);
|
||||
self.svp_prepare(&mut xpm1_dft, 0, &xpm1, 0);
|
||||
}
|
||||
|
||||
let (mut tmp_0, scratch2) = scratch1.tmp_vec_znx_dft(self, a.cols_out(), a.size());
|
||||
let (mut tmp_1, _) = scratch2.tmp_vec_znx_dft(self, a.cols_out(), a.size());
|
||||
|
||||
(0..a.rows()).for_each(|row_i| {
|
||||
(0..a.cols_in()).for_each(|col_j| {
|
||||
self.mat_znx_dft_get_row(&mut tmp_0, &a, row_i, col_j);
|
||||
|
||||
(0..tmp_0.cols()).for_each(|i| {
|
||||
self.svp_apply(&mut tmp_1, i, &xpm1_dft, 0, &tmp_0, i);
|
||||
self.vec_znx_dft_sub_ab_inplace(&mut tmp_1, i, &tmp_0, i);
|
||||
});
|
||||
|
||||
self.mat_znx_dft_get_row(&mut tmp_0, &res, row_i, col_j);
|
||||
|
||||
(0..tmp_0.cols()).for_each(|i| {
|
||||
self.vec_znx_dft_add_inplace(&mut tmp_0, i, &tmp_1, i);
|
||||
});
|
||||
|
||||
self.mat_znx_dft_set_row(&mut res, row_i, col_j, &tmp_0);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn mat_znx_dft_set_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
|
||||
where
|
||||
R: MatZnxToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res: MatZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], _> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(
|
||||
a.cols(),
|
||||
res.cols_out(),
|
||||
"a.cols(): {} != res.cols_out(): {}",
|
||||
a.cols(),
|
||||
res.cols_out()
|
||||
);
|
||||
assert!(
|
||||
res_row < res.rows(),
|
||||
"res_row: {} >= res.rows(): {}",
|
||||
res_row,
|
||||
res.rows()
|
||||
);
|
||||
assert!(
|
||||
res_col_in < res.cols_in(),
|
||||
"res_col_in: {} >= res.cols_in(): {}",
|
||||
res_col_in,
|
||||
res.cols_in()
|
||||
);
|
||||
assert_eq!(
|
||||
res.size(),
|
||||
a.size(),
|
||||
"res.size(): {} != a.size(): {}",
|
||||
res.size(),
|
||||
a.size()
|
||||
);
|
||||
}
|
||||
|
||||
unsafe {
|
||||
vmp::vmp_prepare_row_dft(
|
||||
self.ptr,
|
||||
res.as_mut_ptr() as *mut vmp::vmp_pmat_t,
|
||||
a.as_ptr() as *const vec_znx_dft_t,
|
||||
(res_row * res.cols_in() + res_col_in) as u64,
|
||||
(res.rows() * res.cols_in()) as u64,
|
||||
(res.size() * res.cols_out()) as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn mat_znx_dft_get_row<R, A>(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: MatZnxToRef<FFT64>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: MatZnxDft<&[u8], _> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(
|
||||
res.cols(),
|
||||
a.cols_out(),
|
||||
"res.cols(): {} != a.cols_out(): {}",
|
||||
res.cols(),
|
||||
a.cols_out()
|
||||
);
|
||||
assert!(
|
||||
a_row < a.rows(),
|
||||
"a_row: {} >= a.rows(): {}",
|
||||
a_row,
|
||||
a.rows()
|
||||
);
|
||||
assert!(
|
||||
a_col_in < a.cols_in(),
|
||||
"a_col_in: {} >= a.cols_in(): {}",
|
||||
a_col_in,
|
||||
a.cols_in()
|
||||
);
|
||||
assert_eq!(
|
||||
res.size(),
|
||||
a.size(),
|
||||
"res.size(): {} != a.size(): {}",
|
||||
res.size(),
|
||||
a.size()
|
||||
);
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_extract_row_dft(
|
||||
self.ptr,
|
||||
res.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
a.as_ptr() as *const vmp::vmp_pmat_t,
|
||||
(a_row * a.cols_in() + a_col_in) as u64,
|
||||
(a.rows() * a.cols_in()) as u64,
|
||||
(a.size() * a.cols_out()) as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
B: MatZnxToRef<FFT64>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], _> = a.to_ref();
|
||||
let b: MatZnxDft<&[u8], _> = b.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(
|
||||
res.cols(),
|
||||
b.cols_out(),
|
||||
"res.cols(): {} != b.cols_out: {}",
|
||||
res.cols(),
|
||||
b.cols_out()
|
||||
);
|
||||
assert_eq!(
|
||||
a.cols(),
|
||||
b.cols_in(),
|
||||
"a.cols(): {} != b.cols_in: {}",
|
||||
a.cols(),
|
||||
b.cols_in()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.tmp_slice(self.vmp_apply_tmp_bytes(
|
||||
res.size(),
|
||||
a.size(),
|
||||
b.rows(),
|
||||
b.cols_in(),
|
||||
b.cols_out(),
|
||||
b.size(),
|
||||
));
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft(
|
||||
self.ptr,
|
||||
res.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
(res.size() * res.cols()) as u64,
|
||||
a.as_ptr() as *const vec_znx_dft_t,
|
||||
(a.size() * a.cols()) as u64,
|
||||
b.as_ptr() as *const vmp::vmp_pmat_t,
|
||||
(b.rows() * b.cols_in()) as u64,
|
||||
(b.size() * b.cols_out()) as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scale: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
B: MatZnxToRef<FFT64>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], _> = a.to_ref();
|
||||
let b: MatZnxDft<&[u8], _> = b.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(
|
||||
res.cols(),
|
||||
b.cols_out(),
|
||||
"res.cols(): {} != b.cols_out: {}",
|
||||
res.cols(),
|
||||
b.cols_out()
|
||||
);
|
||||
assert_eq!(
|
||||
a.cols(),
|
||||
b.cols_in(),
|
||||
"a.cols(): {} != b.cols_in: {}",
|
||||
a.cols(),
|
||||
b.cols_in()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.tmp_slice(self.vmp_apply_tmp_bytes(
|
||||
res.size(),
|
||||
a.size(),
|
||||
b.rows(),
|
||||
b.cols_in(),
|
||||
b.cols_out(),
|
||||
b.size(),
|
||||
));
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_add(
|
||||
self.ptr,
|
||||
res.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
(res.size() * res.cols()) as u64,
|
||||
a.as_ptr() as *const vec_znx_dft_t,
|
||||
(a.size() * a.cols()) as u64,
|
||||
b.as_ptr() as *const vmp::vmp_pmat_t,
|
||||
(b.rows() * b.cols_in()) as u64,
|
||||
(b.size() * b.cols_out()) as u64,
|
||||
(scale * b.cols_out()) as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
Decoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig,
|
||||
VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxInfos, ZnxView,
|
||||
ZnxViewMut, ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use super::{MatZnxDftAlloc, MatZnxDftScratch};
|
||||
|
||||
#[test]
|
||||
fn vmp_set_row() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(16);
|
||||
let basek: usize = 8;
|
||||
let mat_rows: usize = 4;
|
||||
let mat_cols_in: usize = 2;
|
||||
let mat_cols_out: usize = 2;
|
||||
let mat_size: usize = 5;
|
||||
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
|
||||
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
|
||||
let mut b_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
|
||||
let mut mat: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
|
||||
|
||||
for col_in in 0..mat_cols_in {
|
||||
for row_i in 0..mat_rows {
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
(0..mat_cols_out).for_each(|col_out| {
|
||||
a.fill_uniform(basek, col_out, mat_size, &mut source);
|
||||
module.vec_znx_dft(1, 0, &mut a_dft, col_out, &a, col_out);
|
||||
});
|
||||
module.mat_znx_dft_set_row(&mut mat, row_i, col_in, &a_dft);
|
||||
module.mat_znx_dft_get_row(&mut b_dft, &mat, row_i, col_in);
|
||||
assert_eq!(a_dft.raw(), b_dft.raw());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vmp_apply() {
|
||||
let log_n: i32 = 5;
|
||||
let n: usize = 1 << log_n;
|
||||
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 15;
|
||||
let a_size: usize = 5;
|
||||
let mat_size: usize = 6;
|
||||
let res_size: usize = a_size;
|
||||
|
||||
[1, 2].iter().for_each(|cols_in| {
|
||||
[1, 2].iter().for_each(|cols_out| {
|
||||
let a_cols: usize = *cols_in;
|
||||
let res_cols: usize = *cols_out;
|
||||
|
||||
let mat_rows: usize = a_size;
|
||||
let mat_cols_in: usize = a_cols;
|
||||
let mat_cols_out: usize = res_cols;
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
module.vmp_apply_tmp_bytes(
|
||||
res_size,
|
||||
a_size,
|
||||
mat_rows,
|
||||
mat_cols_in,
|
||||
mat_cols_out,
|
||||
mat_size,
|
||||
) | module.vec_znx_big_normalize_tmp_bytes(),
|
||||
);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
|
||||
|
||||
(0..a_cols).for_each(|i| {
|
||||
a.at_mut(i, a_size - 1)[i + 1] = 1;
|
||||
});
|
||||
|
||||
let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
|
||||
module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
|
||||
|
||||
let mut c_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
|
||||
let mut c_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
|
||||
|
||||
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
|
||||
|
||||
// Constructs a [MatZnxDft] that performs cyclic rotations on each submatrix.
|
||||
(0..a.size()).for_each(|row_i| {
|
||||
(0..mat_cols_in).for_each(|col_in_i| {
|
||||
(0..mat_cols_out).for_each(|col_out_i| {
|
||||
let idx = 1 + col_in_i * mat_cols_out + col_out_i;
|
||||
tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx}
|
||||
module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i);
|
||||
tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64;
|
||||
});
|
||||
module.mat_znx_dft_set_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
|
||||
});
|
||||
});
|
||||
|
||||
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, a_size);
|
||||
(0..a_cols).for_each(|i| {
|
||||
module.vec_znx_dft(1, 0, &mut a_dft, i, &a, i);
|
||||
});
|
||||
|
||||
module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow());
|
||||
|
||||
let mut res_have_vi64: Vec<i64> = vec![i64::default(); n];
|
||||
|
||||
let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, res_size);
|
||||
(0..mat_cols_out).for_each(|i| {
|
||||
module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
|
||||
module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow());
|
||||
});
|
||||
|
||||
(0..mat_cols_out).for_each(|col_i| {
|
||||
let mut res_want_vi64: Vec<i64> = vec![i64::default(); n];
|
||||
(0..a_cols).for_each(|i| {
|
||||
res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1;
|
||||
});
|
||||
res_have.decode_vec_i64(col_i, basek, basek * a_size, &mut res_have_vi64);
|
||||
assert_eq!(res_have_vi64, res_want_vi64);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vmp_apply_add() {
|
||||
let log_n: i32 = 4;
|
||||
let n: usize = 1 << log_n;
|
||||
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 8;
|
||||
let a_size: usize = 5;
|
||||
let mat_size: usize = 5;
|
||||
let res_size: usize = a_size;
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
[1, 2].iter().for_each(|cols_in| {
|
||||
[1, 2].iter().for_each(|cols_out| {
|
||||
(0..res_size).for_each(|shift| {
|
||||
let a_cols: usize = *cols_in;
|
||||
let res_cols: usize = *cols_out;
|
||||
|
||||
let mat_rows: usize = a_size;
|
||||
let mat_cols_in: usize = a_cols;
|
||||
let mat_cols_out: usize = res_cols;
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
module.vmp_apply_tmp_bytes(
|
||||
res_size,
|
||||
a_size,
|
||||
mat_rows,
|
||||
mat_cols_in,
|
||||
mat_cols_out,
|
||||
mat_size,
|
||||
) | module.vec_znx_big_normalize_tmp_bytes(),
|
||||
);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
|
||||
|
||||
(0..a_cols).for_each(|col_i| {
|
||||
a.fill_uniform(basek, col_i, a.size(), &mut source);
|
||||
});
|
||||
|
||||
let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
|
||||
module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
|
||||
|
||||
let mut c_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
|
||||
let mut c_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
|
||||
|
||||
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
|
||||
|
||||
// Constructs a [MatZnxDft] that performs cyclic rotations on each submatrix.
|
||||
(0..a.size()).for_each(|row_i| {
|
||||
(0..mat_cols_in).for_each(|col_in_i| {
|
||||
(0..mat_cols_out).for_each(|col_out_i| {
|
||||
let idx: usize = 1 + col_in_i * mat_cols_out + col_out_i;
|
||||
tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx}
|
||||
module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i);
|
||||
tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64;
|
||||
});
|
||||
module.mat_znx_dft_set_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
|
||||
});
|
||||
});
|
||||
|
||||
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, a_size);
|
||||
(0..a_cols).for_each(|i| {
|
||||
module.vec_znx_dft(1, 0, &mut a_dft, i, &a, i);
|
||||
});
|
||||
|
||||
c_dft.zero();
|
||||
(0..c_dft.cols()).for_each(|i| {
|
||||
module.vec_znx_dft(1, 0, &mut c_dft, i, &a, 0);
|
||||
});
|
||||
|
||||
module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, shift, scratch.borrow());
|
||||
|
||||
let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
|
||||
(0..mat_cols_out).for_each(|i| {
|
||||
module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
|
||||
module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow());
|
||||
});
|
||||
|
||||
let mut res_want: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
|
||||
|
||||
// Equivalent to vmp_add & scale
|
||||
module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow());
|
||||
(0..mat_cols_out).for_each(|i| {
|
||||
module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
|
||||
module.vec_znx_big_normalize(basek, &mut res_want, i, &c_big, i, scratch.borrow());
|
||||
});
|
||||
module.vec_znx_shift_inplace(
|
||||
basek,
|
||||
(shift * basek) as i64,
|
||||
&mut res_want,
|
||||
scratch.borrow(),
|
||||
);
|
||||
(0..res_cols).for_each(|i| {
|
||||
module.vec_znx_add_inplace(&mut res_want, i, &a, 0);
|
||||
module.vec_znx_normalize_inplace(basek, &mut res_want, i, scratch.borrow());
|
||||
});
|
||||
|
||||
assert_eq!(res_want, res_have);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vmp_apply_digits() {
|
||||
let log_n: i32 = 4;
|
||||
let n: usize = 1 << log_n;
|
||||
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 8;
|
||||
let a_size: usize = 6;
|
||||
let mat_size: usize = 6;
|
||||
let res_size: usize = a_size;
|
||||
|
||||
[1, 2].iter().for_each(|cols_in| {
|
||||
[1, 2].iter().for_each(|cols_out| {
|
||||
[1, 3, 6].iter().for_each(|digits| {
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let a_cols: usize = *cols_in;
|
||||
let res_cols: usize = *cols_out;
|
||||
|
||||
let mat_rows: usize = a_size;
|
||||
let mat_cols_in: usize = a_cols;
|
||||
let mat_cols_out: usize = res_cols;
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
module.vmp_apply_tmp_bytes(
|
||||
res_size,
|
||||
a_size,
|
||||
mat_rows,
|
||||
mat_cols_in,
|
||||
mat_cols_out,
|
||||
mat_size,
|
||||
) | module.vec_znx_big_normalize_tmp_bytes(),
|
||||
);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
|
||||
|
||||
(0..a_cols).for_each(|col_i| {
|
||||
a.fill_uniform(basek, col_i, a.size(), &mut source);
|
||||
});
|
||||
|
||||
let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
|
||||
module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
|
||||
|
||||
let mut c_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
|
||||
let mut c_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
|
||||
|
||||
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
|
||||
|
||||
let rows: usize = a.size() / digits;
|
||||
|
||||
let shift: usize = 1;
|
||||
|
||||
// Constructs a [MatZnxDft] that performs cyclic rotations on each submatrix.
|
||||
(0..rows).for_each(|row_i| {
|
||||
(0..mat_cols_in).for_each(|col_in_i| {
|
||||
(0..mat_cols_out).for_each(|col_out_i| {
|
||||
let idx: usize = shift + col_in_i * mat_cols_out + col_out_i;
|
||||
let limb: usize = (digits - 1) + row_i * digits;
|
||||
tmp.at_mut(col_out_i, limb)[idx] = 1 as i64; // X^{idx}
|
||||
module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i);
|
||||
tmp.at_mut(col_out_i, limb)[idx] = 0 as i64;
|
||||
});
|
||||
module.mat_znx_dft_set_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
|
||||
});
|
||||
});
|
||||
|
||||
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, (a_size + digits - 1) / digits);
|
||||
|
||||
(0..*digits).for_each(|di| {
|
||||
(0..a_cols).for_each(|col_i| {
|
||||
module.vec_znx_dft(*digits, digits - 1 - di, &mut a_dft, col_i, &a, col_i);
|
||||
});
|
||||
|
||||
if di == 0 {
|
||||
module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow());
|
||||
} else {
|
||||
module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, di, scratch.borrow());
|
||||
}
|
||||
});
|
||||
|
||||
let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
|
||||
(0..mat_cols_out).for_each(|i| {
|
||||
module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
|
||||
module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow());
|
||||
});
|
||||
|
||||
let mut res_want: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
|
||||
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
|
||||
(0..res_cols).for_each(|col_i| {
|
||||
(0..a_cols).for_each(|j| {
|
||||
module.vec_znx_rotate(
|
||||
(col_i + j * mat_cols_out + shift) as i64,
|
||||
&mut tmp,
|
||||
0,
|
||||
&a,
|
||||
j,
|
||||
);
|
||||
module.vec_znx_add_inplace(&mut res_want, col_i, &tmp, 0);
|
||||
});
|
||||
module.vec_znx_normalize_inplace(basek, &mut res_want, col_i, scratch.borrow());
|
||||
});
|
||||
|
||||
assert_eq!(res_have, res_want)
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mat_znx_dft_mul_x_pow_minus_one() {
|
||||
let log_n: i32 = 5;
|
||||
let n: usize = 1 << log_n;
|
||||
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 8;
|
||||
let rows: usize = 2;
|
||||
let cols_in: usize = 2;
|
||||
let cols_out: usize = 2;
|
||||
let size: usize = 4;
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(module.mat_znx_dft_mul_x_pow_minus_one_scratch_space(size, cols_out));
|
||||
|
||||
let mut mat_want: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size);
|
||||
let mut mat_have: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size);
|
||||
|
||||
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(1, size);
|
||||
let mut tmp_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(cols_out, size);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
(0..mat_want.rows()).for_each(|row_i| {
|
||||
(0..mat_want.cols_in()).for_each(|col_i| {
|
||||
(0..cols_out).for_each(|j| {
|
||||
tmp.fill_uniform(basek, 0, size, &mut source);
|
||||
module.vec_znx_dft(1, 0, &mut tmp_dft, j, &tmp, 0);
|
||||
});
|
||||
|
||||
module.mat_znx_dft_set_row(&mut mat_want, row_i, col_i, &tmp_dft);
|
||||
});
|
||||
});
|
||||
|
||||
let k: i64 = 1;
|
||||
|
||||
module.mat_znx_dft_mul_x_pow_minus_one(k, &mut mat_have, &mat_want, scratch.borrow());
|
||||
|
||||
let mut have: VecZnx<Vec<u8>> = module.new_vec_znx(cols_out, size);
|
||||
let mut want: VecZnx<Vec<u8>> = module.new_vec_znx(cols_out, size);
|
||||
let mut tmp_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, size);
|
||||
|
||||
(0..mat_want.rows()).for_each(|row_i| {
|
||||
(0..mat_want.cols_in()).for_each(|col_i| {
|
||||
module.mat_znx_dft_get_row(&mut tmp_dft, &mat_want, row_i, col_i);
|
||||
|
||||
(0..cols_out).for_each(|j| {
|
||||
module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow());
|
||||
module.vec_znx_big_normalize(basek, &mut tmp, 0, &tmp_big, 0, scratch.borrow());
|
||||
module.vec_znx_rotate(k, &mut want, j, &tmp, 0);
|
||||
module.vec_znx_sub_ab_inplace(&mut want, j, &tmp, 0);
|
||||
module.vec_znx_normalize_inplace(basek, &mut want, j, scratch.borrow());
|
||||
});
|
||||
|
||||
module.mat_znx_dft_get_row(&mut tmp_dft, &mat_have, row_i, col_i);
|
||||
|
||||
(0..cols_out).for_each(|j| {
|
||||
module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow());
|
||||
module.vec_znx_big_normalize(basek, &mut have, j, &tmp_big, 0, scratch.borrow());
|
||||
});
|
||||
|
||||
assert_eq!(have, want)
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mat_znx_dft_mul_x_pow_minus_one_add_inplace() {
|
||||
let log_n: i32 = 5;
|
||||
let n: usize = 1 << log_n;
|
||||
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 8;
|
||||
let rows: usize = 2;
|
||||
let cols_in: usize = 2;
|
||||
let cols_out: usize = 2;
|
||||
let size: usize = 4;
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(module.mat_znx_dft_mul_x_pow_minus_one_scratch_space(size, cols_out));
|
||||
|
||||
let mut mat_want: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size);
|
||||
let mut mat_have: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size);
|
||||
|
||||
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(1, size);
|
||||
let mut tmp_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(cols_out, size);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
(0..mat_have.rows()).for_each(|row_i| {
|
||||
(0..mat_have.cols_in()).for_each(|col_i| {
|
||||
(0..cols_out).for_each(|j| {
|
||||
tmp.fill_uniform(basek, 0, size, &mut source);
|
||||
module.vec_znx_dft(1, 0, &mut tmp_dft, j, &tmp, 0);
|
||||
});
|
||||
|
||||
module.mat_znx_dft_set_row(&mut mat_have, row_i, col_i, &tmp_dft);
|
||||
});
|
||||
});
|
||||
|
||||
(0..mat_want.rows()).for_each(|row_i| {
|
||||
(0..mat_want.cols_in()).for_each(|col_i| {
|
||||
(0..cols_out).for_each(|j| {
|
||||
tmp.fill_uniform(basek, 0, size, &mut source);
|
||||
module.vec_znx_dft(1, 0, &mut tmp_dft, j, &tmp, 0);
|
||||
});
|
||||
|
||||
module.mat_znx_dft_set_row(&mut mat_want, row_i, col_i, &tmp_dft);
|
||||
});
|
||||
});
|
||||
|
||||
let k: i64 = 1;
|
||||
|
||||
module.mat_znx_dft_mul_x_pow_minus_one_add_inplace(k, &mut mat_have, &mat_want, scratch.borrow());
|
||||
|
||||
let mut have: VecZnx<Vec<u8>> = module.new_vec_znx(cols_out, size);
|
||||
let mut want: VecZnx<Vec<u8>> = module.new_vec_znx(cols_out, size);
|
||||
let mut tmp_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, size);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
(0..mat_want.rows()).for_each(|row_i| {
|
||||
(0..mat_want.cols_in()).for_each(|col_i| {
|
||||
module.mat_znx_dft_get_row(&mut tmp_dft, &mat_want, row_i, col_i);
|
||||
|
||||
(0..cols_out).for_each(|j| {
|
||||
module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow());
|
||||
module.vec_znx_big_normalize(basek, &mut tmp, 0, &tmp_big, 0, scratch.borrow());
|
||||
module.vec_znx_rotate(k, &mut want, j, &tmp, 0);
|
||||
module.vec_znx_sub_ab_inplace(&mut want, j, &tmp, 0);
|
||||
|
||||
tmp.fill_uniform(basek, 0, size, &mut source);
|
||||
module.vec_znx_add_inplace(&mut want, j, &tmp, 0);
|
||||
module.vec_znx_normalize_inplace(basek, &mut want, j, scratch.borrow());
|
||||
});
|
||||
|
||||
module.mat_znx_dft_get_row(&mut tmp_dft, &mat_have, row_i, col_i);
|
||||
|
||||
(0..cols_out).for_each(|j| {
|
||||
module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow());
|
||||
module.vec_znx_big_normalize(basek, &mut have, j, &tmp_big, 0, scratch.borrow());
|
||||
});
|
||||
|
||||
assert_eq!(have, want)
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1,365 +0,0 @@
|
||||
use crate::znx_base::ZnxViewMut;
|
||||
use crate::{FFT64, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxToMut};
|
||||
use rand_distr::{Distribution, Normal};
|
||||
use sampling::source::Source;
|
||||
|
||||
pub trait FillUniform {
|
||||
/// Fills the first `size` limbs with uniform values in \[-2^{basek-1}, 2^{basek-1}\]
|
||||
fn fill_uniform(&mut self, basek: usize, col_i: usize, size: usize, source: &mut Source);
|
||||
}
|
||||
|
||||
pub trait FillDistF64 {
|
||||
fn fill_dist_f64<D: Distribution<f64>>(
|
||||
&mut self,
|
||||
basek: usize,
|
||||
col_i: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
pub trait AddDistF64 {
|
||||
/// Adds vector sampled according to the provided distribution, scaled by 2^{-k} and bounded to \[-bound, bound\].
|
||||
fn add_dist_f64<D: Distribution<f64>>(
|
||||
&mut self,
|
||||
basek: usize,
|
||||
col_i: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
pub trait FillNormal {
|
||||
fn fill_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64);
|
||||
}
|
||||
|
||||
pub trait AddNormal {
|
||||
/// Adds a discrete normal vector scaled by 2^{-k} with the provided standard deviation and bounded to \[-bound, bound\].
|
||||
fn add_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64);
|
||||
}
|
||||
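A short usage sketch for the sampling traits above, modelled on the tests elsewhere in this commit: fill a column with uniform limbs, then add bounded Gaussian noise on top. The crate path, `sigma` and `bound` values are arbitrary illustration choices.

```rust
// Sketch under assumptions: crate path, sigma and bound are illustration choices;
// modelled on the tests elsewhere in this commit.
use backend::{AddNormal, FFT64, FillUniform, Module, VecZnxAlloc};
use sampling::source::Source;

fn main() {
    let module: Module<FFT64> = Module::<FFT64>::new(16);
    let (basek, size) = (8usize, 4usize);

    let mut source = Source::new([0u8; 32]);
    let mut a = module.new_vec_znx(1, size);

    // Uniform limbs on column 0 ...
    a.fill_uniform(basek, 0, size, &mut source);
    // ... then a bounded discrete Gaussian, scaled by 2^{-k}, added on top.
    a.add_normal(basek, 0, basek * size, &mut source, 3.2, 19.0);
}
```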
|
||||
impl<T: AsMut<[u8]> + AsRef<[u8]>> FillUniform for VecZnx<T>
|
||||
where
|
||||
VecZnx<T>: VecZnxToMut,
|
||||
{
|
||||
fn fill_uniform(&mut self, basek: usize, col_i: usize, size: usize, source: &mut Source) {
|
||||
let mut a: VecZnx<&mut [u8]> = self.to_mut();
|
||||
let base2k: u64 = 1 << basek;
|
||||
let mask: u64 = base2k - 1;
|
||||
let base2k_half: i64 = (base2k >> 1) as i64;
|
||||
(0..size).for_each(|j| {
|
||||
a.at_mut(col_i, j)
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsMut<[u8]> + AsRef<[u8]>> FillDistF64 for VecZnx<T>
|
||||
where
|
||||
VecZnx<T>: VecZnxToMut,
|
||||
{
|
||||
fn fill_dist_f64<D: Distribution<f64>>(
|
||||
&mut self,
|
||||
basek: usize,
|
||||
col_i: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut a: VecZnx<&mut [u8]> = self.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = (k + basek - 1) / basek - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsMut<[u8]> + AsRef<[u8]>> AddDistF64 for VecZnx<T>
|
||||
where
|
||||
VecZnx<T>: VecZnxToMut,
|
||||
{
|
||||
fn add_dist_f64<D: Distribution<f64>>(
|
||||
&mut self,
|
||||
basek: usize,
|
||||
col_i: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut a: VecZnx<&mut [u8]> = self.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = (k + basek - 1) / basek - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsMut<[u8]> + AsRef<[u8]>> FillNormal for VecZnx<T>
|
||||
where
|
||||
VecZnx<T>: VecZnxToMut,
|
||||
{
|
||||
fn fill_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64) {
|
||||
self.fill_dist_f64(
|
||||
basek,
|
||||
col_i,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsMut<[u8]> + AsRef<[u8]>> AddNormal for VecZnx<T>
|
||||
where
|
||||
VecZnx<T>: VecZnxToMut,
|
||||
{
|
||||
fn add_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64) {
|
||||
self.add_dist_f64(
|
||||
basek,
|
||||
col_i,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsMut<[u8]> + AsRef<[u8]>> FillDistF64 for VecZnxBig<T, FFT64>
|
||||
where
|
||||
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
fn fill_dist_f64<D: Distribution<f64>>(
|
||||
&mut self,
|
||||
basek: usize,
|
||||
col_i: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut a: VecZnxBig<&mut [u8], FFT64> = self.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = (k + basek - 1) / basek - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsMut<[u8]> + AsRef<[u8]>> AddDistF64 for VecZnxBig<T, FFT64>
|
||||
where
|
||||
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
fn add_dist_f64<D: Distribution<f64>>(
|
||||
&mut self,
|
||||
basek: usize,
|
||||
col_i: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut a: VecZnxBig<&mut [u8], FFT64> = self.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = (k + basek - 1) / basek - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsMut<[u8]> + AsRef<[u8]>> FillNormal for VecZnxBig<T, FFT64>
|
||||
where
|
||||
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
fn fill_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64) {
|
||||
self.fill_dist_f64(
|
||||
basek,
|
||||
col_i,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsMut<[u8]> + AsRef<[u8]>> AddNormal for VecZnxBig<T, FFT64>
|
||||
where
|
||||
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
fn add_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64) {
|
||||
self.add_dist_f64(
|
||||
basek,
|
||||
col_i,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{AddNormal, FillUniform};
|
||||
use crate::vec_znx_ops::*;
|
||||
use crate::znx_base::*;
|
||||
use crate::{FFT64, Module, Stats, VecZnx};
|
||||
use sampling::source::Source;
|
||||
|
||||
#[test]
|
||||
fn vec_znx_fill_uniform() {
|
||||
let n: usize = 4096;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 17;
|
||||
let size: usize = 5;
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
let zero: Vec<i64> = vec![0; n];
|
||||
let one_12_sqrt: f64 = 0.28867513459481287;
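// one_12_sqrt = 1/sqrt(12), the standard deviation of a uniform variable on [-1/2, 1/2).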
|
||||
(0..cols).for_each(|col_i| {
|
||||
let mut a: VecZnx<_> = module.new_vec_znx(cols, size);
|
||||
a.fill_uniform(basek, col_i, size, &mut source);
|
||||
(0..cols).for_each(|col_j| {
|
||||
if col_j != col_i {
|
||||
(0..size).for_each(|limb_i| {
|
||||
assert_eq!(a.at(col_j, limb_i), zero);
|
||||
})
|
||||
} else {
|
||||
let std: f64 = a.std(col_i, basek);
|
||||
assert!(
|
||||
(std - one_12_sqrt).abs() < 0.01,
|
||||
"std={} ~!= {}",
|
||||
std,
|
||||
one_12_sqrt
|
||||
);
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vec_znx_add_normal() {
|
||||
let n: usize = 4096;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 17;
|
||||
let k: usize = 2 * 17;
|
||||
let size: usize = 5;
|
||||
let sigma: f64 = 3.2;
|
||||
let bound: f64 = 6.0 * sigma;
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
let zero: Vec<i64> = vec![0; n];
|
||||
let k_f64: f64 = (1u64 << k as u64) as f64;
|
||||
(0..cols).for_each(|col_i| {
|
||||
let mut a: VecZnx<_> = module.new_vec_znx(cols, size);
|
||||
a.add_normal(basek, col_i, k, &mut source, sigma, bound);
|
||||
(0..cols).for_each(|col_j| {
|
||||
if col_j != col_i {
|
||||
(0..size).for_each(|limb_i| {
|
||||
assert_eq!(a.at(col_j, limb_i), zero);
|
||||
})
|
||||
} else {
|
||||
let std: f64 = a.std(col_i, basek) * k_f64;
|
||||
assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma);
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1,180 +0,0 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::ffi::svp;
|
||||
use crate::znx_base::ZnxInfos;
|
||||
use crate::{
|
||||
Backend, DataView, DataViewMut, FFT64, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxSliceSize, ZnxView,
|
||||
alloc_aligned,
|
||||
};
|
||||
|
||||
pub struct ScalarZnxDft<D, B: Backend> {
|
||||
data: D,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
_phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<D, B: Backend> ZnxInfos for ScalarZnxDft<D, B> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> ZnxSliceSize for ScalarZnxDft<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> DataView for ScalarZnxDft<D, B> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> DataViewMut for ScalarZnxDft<D, B> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> ZnxView for ScalarZnxDft<D, FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
pub(crate) fn bytes_of_scalar_znx_dft<B: Backend>(module: &Module<B>, cols: usize) -> usize {
|
||||
ScalarZnxDftOwned::bytes_of(module, cols)
|
||||
}
|
||||
|
||||
impl<D: From<Vec<u8>>, B: Backend> ScalarZnxDft<D, B> {
|
||||
pub(crate) fn bytes_of(module: &Module<B>, cols: usize) -> usize {
|
||||
unsafe { svp::bytes_of_svp_ppol(module.ptr) as usize * cols }
|
||||
}
|
||||
|
||||
pub(crate) fn new(module: &Module<B>, cols: usize) -> Self {
|
||||
let data = alloc_aligned::<u8>(Self::bytes_of(module, cols));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_from_bytes(module: &Module<B>, cols: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == Self::bytes_of(module, cols));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> ScalarZnxDft<D, B> {
|
||||
pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_vec_znx_dft(self) -> VecZnxDft<D, B> {
|
||||
VecZnxDft {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: 1,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type ScalarZnxDftOwned<B> = ScalarZnxDft<Vec<u8>, B>;
|
||||
|
||||
pub trait ScalarZnxDftToRef<B: Backend> {
|
||||
fn to_ref(&self) -> ScalarZnxDft<&[u8], B>;
|
||||
}
|
||||
|
||||
impl<D, B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<D, B>
|
||||
where
|
||||
D: AsRef<[u8]>,
|
||||
B: Backend,
|
||||
{
|
||||
fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
|
||||
ScalarZnxDft {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ScalarZnxDftToMut<B: Backend> {
|
||||
fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B>;
|
||||
}
|
||||
|
||||
impl<D, B: Backend> ScalarZnxDftToMut<B> for ScalarZnxDft<D, B>
|
||||
where
|
||||
D: AsMut<[u8]> + AsRef<[u8]>,
|
||||
B: Backend,
|
||||
{
|
||||
fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> {
|
||||
ScalarZnxDft {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<D, B>
|
||||
where
|
||||
D: AsRef<[u8]>,
|
||||
B: Backend,
|
||||
{
|
||||
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
|
||||
VecZnxDft {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: 1,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> VecZnxDftToMut<B> for ScalarZnxDft<D, B>
|
||||
where
|
||||
D: AsRef<[u8]> + AsMut<[u8]>,
|
||||
B: Backend,
|
||||
{
|
||||
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
|
||||
VecZnxDft {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: 1,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,122 +0,0 @@
|
||||
use crate::ffi::svp;
|
||||
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
|
||||
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
|
||||
use crate::{
|
||||
Backend, FFT64, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToMut,
|
||||
ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps,
|
||||
};
|
||||
|
||||
pub trait ScalarZnxDftAlloc<B: Backend> {
|
||||
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B>;
|
||||
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize;
|
||||
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B>;
|
||||
}
|
||||
|
||||
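/// Operations on prepared scalar polynomials: `svp_prepare` converts a [ScalarZnx]
/// column into the backend's prepared (`svp_ppol`) form, `svp_apply` multiplies a
/// [VecZnxDft] column by such a prepared scalar (`svp_apply_inplace` reuses `res`
/// as the input), and `scalar_znx_idft` maps a DFT-domain scalar back to the
/// coefficient domain.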
pub trait ScalarZnxDftOps<BACKEND: Backend> {
|
||||
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: ScalarZnxDftToMut<BACKEND>,
|
||||
A: ScalarZnxToRef;
|
||||
|
||||
fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<BACKEND>,
|
||||
A: ScalarZnxDftToRef<BACKEND>,
|
||||
B: VecZnxDftToRef<BACKEND>;
|
||||
|
||||
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<BACKEND>,
|
||||
A: ScalarZnxDftToRef<BACKEND>;
|
||||
|
||||
fn scalar_znx_idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: ScalarZnxToMut,
|
||||
A: ScalarZnxDftToRef<BACKEND>;
|
||||
}
|
||||
|
||||
impl<B: Backend> ScalarZnxDftAlloc<B> for Module<B> {
|
||||
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B> {
|
||||
ScalarZnxDftOwned::new(self, cols)
|
||||
}
|
||||
|
||||
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize {
|
||||
ScalarZnxDftOwned::bytes_of(self, cols)
|
||||
}
|
||||
|
||||
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B> {
|
||||
ScalarZnxDftOwned::new_from_bytes(self, cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
|
||||
fn scalar_znx_idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: ScalarZnxToMut,
|
||||
A: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
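// Compute the IDFT into a temporary one-limb VecZnxBig, then copy the result into `res` via `to_vec_znx_small` (no normalization step).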
let res_mut: &mut ScalarZnx<&mut [u8]> = &mut res.to_mut();
|
||||
let a_ref: &ScalarZnxDft<&[u8], FFT64> = &a.to_ref();
|
||||
let (mut vec_znx_big, scratch1) = scratch.tmp_vec_znx_big(self, 1, 1);
|
||||
self.vec_znx_idft(&mut vec_znx_big, 0, a_ref, a_col, scratch1);
|
||||
self.vec_znx_copy(res_mut, res_col, &vec_znx_big.to_vec_znx_small(), 0);
|
||||
}
|
||||
|
||||
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: ScalarZnxDftToMut<FFT64>,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
unsafe {
|
||||
svp::svp_prepare(
|
||||
self.ptr,
|
||||
res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t,
|
||||
a.to_ref().at_ptr(a_col, 0),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: ScalarZnxDftToRef<FFT64>,
|
||||
B: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnxDft<&[u8], FFT64> = b.to_ref();
|
||||
unsafe {
|
||||
svp::svp_apply_dft_to_dft(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
|
||||
res.size() as u64,
|
||||
res.cols() as u64,
|
||||
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
|
||||
b.at_ptr(b_col, 0) as *const vec_znx_dft_t,
|
||||
b.size() as u64,
|
||||
b.cols() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
unsafe {
|
||||
svp::svp_apply_dft_to_dft(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
|
||||
res.size() as u64,
|
||||
res.cols() as u64,
|
||||
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
|
||||
res.at_ptr(res_col, 0) as *const vec_znx_dft_t,
|
||||
res.size() as u64,
|
||||
res.cols() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
use crate::znx_base::ZnxInfos;
|
||||
use crate::{Decoding, VecZnx};
|
||||
use rug::Float;
|
||||
use rug::float::Round;
|
||||
use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound};
|
||||
|
||||
pub trait Stats {
|
||||
/// Returns the standard deviation of the i-th polynomial.
|
||||
fn std(&self, col_i: usize, basek: usize) -> f64;
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> Stats for VecZnx<D> {
|
||||
fn std(&self, col_i: usize, basek: usize) -> f64 {
|
||||
let prec: u32 = (self.size() * basek) as u32;
|
||||
let mut data: Vec<Float> = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect();
|
||||
self.decode_vec_float(col_i, basek, &mut data);
|
||||
// std = sqrt(sum((xi - avg)^2) / n)
|
||||
let mut avg: Float = Float::with_val(prec, 0);
|
||||
data.iter().for_each(|x| {
|
||||
avg.add_assign_round(x, Round::Nearest);
|
||||
});
|
||||
avg.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
|
||||
data.iter_mut().for_each(|x| {
|
||||
x.sub_assign_round(&avg, Round::Nearest);
|
||||
});
|
||||
let mut std: Float = Float::with_val(prec, 0);
|
||||
data.iter().for_each(|x| std += x * x);
|
||||
std.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
|
||||
std = std.sqrt();
|
||||
std.to_f64()
|
||||
}
|
||||
}
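// Illustrative sketch (not part of the API): the same population standard deviation
// formula in plain f64, for reference; the implementation above uses rug::Float to
// keep the full size*basek bits of precision.
fn std_f64_sketch(data: &[f64]) -> f64 {
    let n = data.len() as f64;
    let avg: f64 = data.iter().sum::<f64>() / n;
    (data.iter().map(|x| (x - avg) * (x - avg)).sum::<f64>() / n).sqrt()
}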
|
||||
@@ -1,413 +0,0 @@
|
||||
use itertools::izip;
|
||||
|
||||
use crate::DataView;
|
||||
use crate::DataViewMut;
|
||||
use crate::ScalarZnx;
|
||||
use crate::Scratch;
|
||||
use crate::ZnxSliceSize;
|
||||
use crate::ZnxZero;
|
||||
use crate::alloc_aligned;
|
||||
use crate::assert_alignement;
|
||||
use crate::cast_mut;
|
||||
use crate::ffi::znx;
|
||||
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
|
||||
use std::{cmp::min, fmt};
|
||||
|
||||
/// [VecZnx] represents a collection of contiguously stacked vectors of small-norm polynomials of
|
||||
/// Zn\[X\] with [i64] coefficients.
|
||||
/// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single contiguous array
|
||||
/// in the memory.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// Given 3 polynomials (a, b, c) of Zn\[X\], each of size 4 (4 limbs), the memory
|
||||
/// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci
|
||||
/// are small polynomials of Zn\[X\].
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct VecZnx<D> {
|
||||
pub data: D,
|
||||
pub n: usize,
|
||||
pub cols: usize,
|
||||
pub size: usize,
|
||||
}
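// Illustration of the layout described above (assumed helper, not part of the API):
// the j-th limb of the i-th column starts at coefficient offset (j * cols + i) * n,
// which is consistent with sl() = n * cols below.
fn coeff_offset_sketch(n: usize, cols: usize, col_i: usize, limb_j: usize) -> usize {
    (limb_j * cols + col_i) * n
}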
|
||||
|
||||
impl<D> fmt::Debug for VecZnx<D>
|
||||
where
|
||||
D: AsRef<[u8]>,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> ZnxInfos for VecZnx<D> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> ZnxSliceSize for VecZnx<D> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> DataView for VecZnx<D> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> DataViewMut for VecZnx<D> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> ZnxView for VecZnx<D> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
impl VecZnx<Vec<u8>> {
|
||||
pub fn rsh_scratch_space(n: usize) -> usize {
|
||||
n * std::mem::size_of::<i64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnx<D> {
|
||||
/// Truncates the precision of the [VecZnx] by k bits.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `basek`: the base-two logarithm of the coefficient decomposition basis.
|
||||
/// * `k`: the number of bits of precision to drop.
|
||||
pub fn trunc_pow2(&mut self, basek: usize, k: usize, col: usize) {
|
||||
if k == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
self.size -= k / basek;
|
||||
|
||||
let k_rem: usize = k % basek;
|
||||
|
||||
if k_rem != 0 {
|
||||
let mask: i64 = ((1 << (basek - k_rem - 1)) - 1) << k_rem;
|
||||
self.at_mut(col, self.size() - 1)
|
||||
.iter_mut()
|
||||
.for_each(|x: &mut i64| *x &= mask)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rotate(&mut self, k: i64) {
|
||||
unsafe {
|
||||
(0..self.cols()).for_each(|i| {
|
||||
(0..self.size()).for_each(|j| {
|
||||
znx::znx_rotate_inplace_i64(self.n() as u64, k, self.at_mut_ptr(i, j));
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rsh(&mut self, basek: usize, k: usize, scratch: &mut Scratch) {
|
||||
let n: usize = self.n();
|
||||
let cols: usize = self.cols();
|
||||
let size: usize = self.size();
|
||||
let steps: usize = k / basek;
|
||||
|
||||
self.raw_mut().rotate_right(n * steps * cols);
|
||||
(0..cols).for_each(|i| {
|
||||
(0..steps).for_each(|j| {
|
||||
self.zero_at(i, j);
|
||||
})
|
||||
});
|
||||
|
||||
let k_rem: usize = k % basek;
|
||||
|
||||
if k_rem != 0 {
|
||||
let (carry, _) = scratch.tmp_slice::<i64>(n);
|
||||
let shift = i64::BITS as usize - k_rem;
|
||||
(0..cols).for_each(|i| {
|
||||
carry.fill(0);
|
||||
(steps..size).for_each(|j| {
|
||||
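// Carry propagation for the fractional part of the shift: `ci` holds the sign-extended
// low k_rem bits left over from the previous limb. Adding `ci << basek` and then
// shifting the sum right by k_rem injects it as ci * 2^(basek - k_rem) into this limb;
// the new low k_rem bits are extracted (sign-extended via the 64 - k_rem shift pair)
// and carried to the next limb.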
izip!(carry.iter_mut(), self.at_mut(i, j).iter_mut()).for_each(|(ci, xi)| {
|
||||
*xi += *ci << basek;
|
||||
*ci = (*xi << shift) >> shift;
|
||||
*xi = (*xi - *ci) >> k_rem;
|
||||
});
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn lsh(&mut self, basek: usize, k: usize, scratch: &mut Scratch) {
|
||||
let n: usize = self.n();
|
||||
let cols: usize = self.cols();
|
||||
let size: usize = self.size();
|
||||
let steps: usize = k / basek;
|
||||
|
||||
self.raw_mut().rotate_left(n * steps * cols);
|
||||
(0..cols).for_each(|i| {
|
||||
(size - steps..size).for_each(|j| {
|
||||
self.zero_at(i, j);
|
||||
})
|
||||
});
|
||||
|
||||
let k_rem: usize = k % basek;
|
||||
|
||||
if k_rem != 0 {
|
||||
let shift: usize = i64::BITS as usize - k_rem;
|
||||
let (tmp_bytes, _) = scratch.tmp_slice::<u8>(n * size_of::<i64>());
|
||||
(0..cols).for_each(|i| {
|
||||
(0..steps).for_each(|j| {
|
||||
self.at_mut(i, j).iter_mut().for_each(|xi| {
|
||||
*xi <<= shift;
|
||||
});
|
||||
});
|
||||
normalize(basek, self, i, tmp_bytes);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: From<Vec<u8>>> VecZnx<D> {
|
||||
pub(crate) fn bytes_of<Scalar: Sized>(n: usize, cols: usize, size: usize) -> usize {
|
||||
n * cols * size * size_of::<Scalar>()
|
||||
}
|
||||
|
||||
pub fn new<Scalar: Sized>(n: usize, cols: usize, size: usize) -> Self {
|
||||
let data = alloc_aligned::<u8>(Self::bytes_of::<Scalar>(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_from_bytes<Scalar: Sized>(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == Self::bytes_of::<Scalar>(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> VecZnx<D> {
|
||||
pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_scalar_znx(self) -> ScalarZnx<D> {
|
||||
debug_assert_eq!(
|
||||
self.size, 1,
|
||||
"cannot convert VecZnx to ScalarZnx if cols: {} != 1",
|
||||
self.cols
|
||||
);
|
||||
ScalarZnx {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Copies the coefficients of `a` into `b`.
/// The copy length is the minimum of the two backing arrays.
|
||||
/// Panics if the cols do not match.
|
||||
pub fn copy_vec_znx_from<DataMut, Data>(b: &mut VecZnx<DataMut>, a: &VecZnx<Data>)
|
||||
where
|
||||
DataMut: AsMut<[u8]> + AsRef<[u8]>,
|
||||
Data: AsRef<[u8]>,
|
||||
{
|
||||
assert_eq!(b.cols(), a.cols());
|
||||
let data_a: &[i64] = a.raw();
|
||||
let data_b: &mut [i64] = b.raw_mut();
|
||||
let size = min(data_b.len(), data_a.len());
|
||||
data_b[..size].copy_from_slice(&data_a[..size])
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn normalize_tmp_bytes(n: usize) -> usize {
|
||||
n * std::mem::size_of::<i64>()
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]> + AsMut<[u8]>> VecZnx<D> {
|
||||
pub fn normalize(&mut self, basek: usize, a_col: usize, tmp_bytes: &mut [u8]) {
|
||||
normalize(basek, self, a_col, tmp_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize<D: AsMut<[u8]> + AsRef<[u8]>>(basek: usize, a: &mut VecZnx<D>, a_col: usize, tmp_bytes: &mut [u8]) {
|
||||
let n: usize = a.n();
|
||||
|
||||
debug_assert!(
|
||||
tmp_bytes.len() >= normalize_tmp_bytes(n),
|
||||
"invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({})",
|
||||
tmp_bytes.len(),
|
||||
n,
|
||||
);
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr())
|
||||
}
|
||||
|
||||
let carry_i64: &mut [i64] = cast_mut(tmp_bytes);
|
||||
|
||||
unsafe {
|
||||
znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr());
|
||||
(0..a.size()).rev().for_each(|i| {
|
||||
znx::znx_normalize(
|
||||
n as u64,
|
||||
basek as u64,
|
||||
a.at_mut_ptr(a_col, i),
|
||||
carry_i64.as_mut_ptr(),
|
||||
a.at_mut_ptr(a_col, i),
|
||||
carry_i64.as_mut_ptr(),
|
||||
)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnx<D>
|
||||
where
|
||||
VecZnx<D>: VecZnxToMut + ZnxInfos,
|
||||
{
|
||||
/// Extracts the a_col-th column of `a` and stores it in the self_col-th column of [Self].
|
||||
pub fn extract_column<R>(&mut self, self_col: usize, a: &VecZnx<R>, a_col: usize)
|
||||
where
|
||||
R: AsRef<[u8]>,
|
||||
VecZnx<R>: VecZnxToRef + ZnxInfos,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(self_col < self.cols());
|
||||
assert!(a_col < a.cols());
|
||||
}
|
||||
|
||||
let min_size: usize = self.size.min(a.size());
|
||||
let max_size: usize = self.size;
|
||||
|
||||
let mut self_mut: VecZnx<&mut [u8]> = self.to_mut();
|
||||
let a_ref: VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
(0..min_size).for_each(|i: usize| {
|
||||
self_mut
|
||||
.at_mut(self_col, i)
|
||||
.copy_from_slice(a_ref.at(a_col, i));
|
||||
});
|
||||
|
||||
(min_size..max_size).for_each(|i| {
|
||||
self_mut.zero_at(self_col, i);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> fmt::Display for VecZnx<D> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"VecZnx(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {}:", col)?;
|
||||
for size in 0..self.size {
|
||||
let coeffs = self.at(col, size);
|
||||
write!(f, " Size {}: [", size)?;
|
||||
|
||||
let max_show = 100;
|
||||
let show_count = coeffs.len().min(max_show);
|
||||
|
||||
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "{}", coeff)?;
|
||||
}
|
||||
|
||||
if coeffs.len() > max_show {
|
||||
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
|
||||
}
|
||||
|
||||
writeln!(f, "]")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub type VecZnxOwned = VecZnx<Vec<u8>>;
|
||||
pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>;
|
||||
pub type VecZnxRef<'a> = VecZnx<&'a [u8]>;
|
||||
|
||||
pub trait VecZnxToRef {
|
||||
fn to_ref(&self) -> VecZnx<&[u8]>;
|
||||
}
|
||||
|
||||
impl<D> VecZnxToRef for VecZnx<D>
|
||||
where
|
||||
D: AsRef<[u8]>,
|
||||
{
|
||||
fn to_ref(&self) -> VecZnx<&[u8]> {
|
||||
VecZnx {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxToMut {
|
||||
fn to_mut(&mut self) -> VecZnx<&mut [u8]>;
|
||||
}
|
||||
|
||||
impl<D> VecZnxToMut for VecZnx<D>
|
||||
where
|
||||
D: AsRef<[u8]> + AsMut<[u8]>,
|
||||
{
|
||||
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
|
||||
VecZnx {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf: AsRef<[u8]>> VecZnx<DataSelf> {
|
||||
pub fn clone(&self) -> VecZnx<Vec<u8>> {
|
||||
let self_ref: VecZnx<&[u8]> = self.to_ref();
|
||||
VecZnx {
|
||||
data: self_ref.data.to_vec(),
|
||||
n: self_ref.n,
|
||||
cols: self_ref.cols,
|
||||
size: self_ref.size,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,216 +0,0 @@
|
||||
use crate::ffi::vec_znx_big;
|
||||
use crate::znx_base::{ZnxInfos, ZnxView};
|
||||
use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnx, ZnxSliceSize, ZnxViewMut, ZnxZero, alloc_aligned};
|
||||
use std::fmt;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
pub struct VecZnxBig<D, B: Backend> {
|
||||
data: D,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
_phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<D, B: Backend> ZnxInfos for VecZnxBig<D, B> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> ZnxSliceSize for VecZnxBig<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> DataView for VecZnxBig<D, B> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> DataViewMut for VecZnxBig<D, B> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> ZnxView for VecZnxBig<D, FFT64> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
pub(crate) fn bytes_of_vec_znx_big<B: Backend>(module: &Module<B>, cols: usize, size: usize) -> usize {
|
||||
unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols }
|
||||
}
|
||||
|
||||
impl<D: From<Vec<u8>>, B: Backend> VecZnxBig<D, B> {
|
||||
pub(crate) fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
|
||||
let data = alloc_aligned::<u8>(bytes_of_vec_znx_big(module, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
cols,
|
||||
size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == bytes_of_vec_znx_big(module, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
cols,
|
||||
size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> VecZnxBig<D, B> {
|
||||
pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnxBig<D, FFT64>
|
||||
where
|
||||
VecZnxBig<D, FFT64>: VecZnxBigToMut<FFT64> + ZnxInfos,
|
||||
{
|
||||
// Consumes the VecZnxBig to return a VecZnx.
|
||||
// Useful when no normalization is needed.
|
||||
pub fn to_vec_znx_small(self) -> VecZnx<D> {
|
||||
VecZnx {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Extracts the a_col-th column of `a` and stores it in the self_col-th column of [Self].
|
||||
pub fn extract_column<C>(&mut self, self_col: usize, a: &C, a_col: usize)
|
||||
where
|
||||
C: VecZnxBigToRef<FFT64> + ZnxInfos,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(self_col < self.cols());
|
||||
assert!(a_col < a.cols());
|
||||
}
|
||||
|
||||
let min_size: usize = self.size.min(a.size());
|
||||
let max_size: usize = self.size;
|
||||
|
||||
let mut self_mut: VecZnxBig<&mut [u8], FFT64> = self.to_mut();
|
||||
let a_ref: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
(0..min_size).for_each(|i: usize| {
|
||||
self_mut
|
||||
.at_mut(self_col, i)
|
||||
.copy_from_slice(a_ref.at(a_col, i));
|
||||
});
|
||||
|
||||
(min_size..max_size).for_each(|i| {
|
||||
self_mut.zero_at(self_col, i);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub type VecZnxBigOwned<B> = VecZnxBig<Vec<u8>, B>;
|
||||
|
||||
pub trait VecZnxBigToRef<B: Backend> {
|
||||
fn to_ref(&self) -> VecZnxBig<&[u8], B>;
|
||||
}
|
||||
|
||||
impl<D, B: Backend> VecZnxBigToRef<B> for VecZnxBig<D, B>
|
||||
where
|
||||
D: AsRef<[u8]>,
|
||||
B: Backend,
|
||||
{
|
||||
fn to_ref(&self) -> VecZnxBig<&[u8], B> {
|
||||
VecZnxBig {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxBigToMut<B: Backend> {
|
||||
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B>;
|
||||
}
|
||||
|
||||
impl<D, B: Backend> VecZnxBigToMut<B> for VecZnxBig<D, B>
|
||||
where
|
||||
D: AsRef<[u8]> + AsMut<[u8]>,
|
||||
B: Backend,
|
||||
{
|
||||
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
|
||||
VecZnxBig {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> fmt::Display for VecZnxBig<D, FFT64> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"VecZnxBig(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {}:", col)?;
|
||||
for size in 0..self.size {
|
||||
let coeffs = self.at(col, size);
|
||||
write!(f, " Size {}: [", size)?;
|
||||
|
||||
let max_show = 100;
|
||||
let show_count = coeffs.len().min(max_show);
|
||||
|
||||
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "{}", coeff)?;
|
||||
}
|
||||
|
||||
if coeffs.len() > max_show {
|
||||
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
|
||||
}
|
||||
|
||||
writeln!(f, "]")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,618 +0,0 @@
|
||||
use crate::ffi::vec_znx;
|
||||
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
|
||||
use crate::{
|
||||
Backend, FFT64, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxScratch,
|
||||
VecZnxToMut, VecZnxToRef, ZnxSliceSize, bytes_of_vec_znx_big,
|
||||
};
|
||||
|
||||
pub trait VecZnxBigAlloc<B: Backend> {
|
||||
/// Allocates a vector over Z[X]/(X^N+1) that stores non-normalized values.
|
||||
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned<B>;
|
||||
|
||||
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
|
||||
///
|
||||
/// Behavior: takes ownership of the backing array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of polynomials.
|
||||
/// * `size`: the number of polynomials per column.
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` != [Module::bytes_of_vec_znx_big].
|
||||
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B>;
|
||||
|
||||
// /// Returns a new [VecZnxBig] with the provided bytes array as backing array.
|
||||
// ///
|
||||
// /// Behavior: the backing array is only borrowed.
|
||||
// ///
|
||||
// /// # Arguments
|
||||
// ///
|
||||
// /// * `cols`: the number of polynomials..
|
||||
// /// * `size`: the number of polynomials per column.
|
||||
// /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
|
||||
// ///
|
||||
// /// # Panics
|
||||
// /// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
|
||||
// fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<B>;
|
||||
|
||||
/// Returns the minimum number of bytes necessary to allocate
|
||||
/// a new [VecZnxBig] through [VecZnxBig::from_bytes].
|
||||
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigOps<BACKEND: Backend> {
|
||||
/// Adds `a` to `b` and stores the result in `res`.
|
||||
fn vec_znx_big_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>,
|
||||
B: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Adds `a` to `res` and stores the result in `res`.
|
||||
fn vec_znx_big_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Adds `a` to `b` and stores the result in `res`.
|
||||
fn vec_znx_big_add_small<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>,
|
||||
B: VecZnxToRef;
|
||||
|
||||
/// Adds `a` to `res` and stores the result in `res`.
|
||||
fn vec_znx_big_add_small_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Subtracts `b` from `a` and stores the result in `res`.
|
||||
fn vec_znx_big_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>,
|
||||
B: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Subtracts `a` from `res` and stores the result in `res`.
|
||||
fn vec_znx_big_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Subtracts `res` from `a` and stores the result in `res`.
|
||||
fn vec_znx_big_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Subtracts `b` from `a` and stores the result in `res`.
|
||||
fn vec_znx_big_sub_small_a<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Subtracts `a` from `res` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_a_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Subtracts `b` from `a` and stores the result in `res`.
|
||||
fn vec_znx_big_sub_small_b<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>,
|
||||
B: VecZnxToRef;
|
||||
|
||||
/// Subtracts `res` from `a` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_b_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Negates `a` inplace.
|
||||
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<BACKEND>;
|
||||
|
||||
/// Normalizes `a` and stores the result in `res`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `basek`: normalization basis.
|
||||
/// * `scratch`: scratch space providing at least [VecZnxBigScratch::vec_znx_big_normalize_tmp_bytes] bytes.
|
||||
fn vec_znx_big_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxBigToRef<FFT64>;
|
||||
|
||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result in `res`.
|
||||
fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
|
||||
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<BACKEND>;
|
||||
}
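// A minimal usage sketch of the trait above, assuming a Module<FFT64> `module`, a
// Scratch `scratch`, and a small VecZnx `e_small` with matching dimensions (these
// names are assumptions; only the trait methods come from this file):
//
// let mut acc = module.new_vec_znx_big(cols, size);               // filled elsewhere, e.g. by an IDFT
// module.vec_znx_big_add_small_inplace(&mut acc, 0, &e_small, 0); // acc += e_small on column 0
// let mut out = module.new_vec_znx(cols, size);
// module.vec_znx_big_normalize(basek, &mut out, 0, &acc, 0, &mut scratch);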
|
||||
|
||||
pub trait VecZnxBigScratch {
|
||||
/// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize].
|
||||
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxBigAlloc<B> for Module<B> {
|
||||
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned<B> {
|
||||
VecZnxBig::new(self, cols, size)
|
||||
}
|
||||
|
||||
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B> {
|
||||
VecZnxBig::new_from_bytes(self, cols, size, bytes)
|
||||
}
|
||||
|
||||
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
|
||||
bytes_of_vec_znx_big(self, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl VecZnxBigOps<FFT64> for Module<FFT64> {
|
||||
fn vec_znx_big_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_small_b<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_small_b_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_small_a<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_small_a_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_add_small<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_add_small_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_negate(
|
||||
self.ptr,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
//(Jay) Note: This is calling VecZnxOps::vec_znx_normalize_tmp_bytes and not VecZnxBigOps::vec_znx_big_normalize_tmp_bytes.
// In the FFT backend the tmp sizes are the same, but they will differ in the NTT backend.
|
||||
// assert!(tmp_bytes.len() >= <Self as VecZnxOps<&mut [u8], & [u8]>>::vec_znx_normalize_tmp_bytes(&self));
|
||||
// assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.tmp_slice(<Self as VecZnxBigScratch>::vec_znx_big_normalize_tmp_bytes(
|
||||
&self,
|
||||
));
|
||||
unsafe {
|
||||
vec_znx::vec_znx_normalize_base2k(
|
||||
self.ptr,
|
||||
basek as u64,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
self.ptr,
|
||||
k,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
self.ptr,
|
||||
k,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxBigScratch for Module<B> {
|
||||
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
|
||||
<Self as VecZnxScratch>::vec_znx_normalize_tmp_bytes(self)
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff.