diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..9d9ef63
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,63 @@
+name: CI
+
+on:
+  push:
+  pull_request:
+
+env:
+  CARGO_TERM_COLOR: always
+
+jobs:
+  build-and-test:
+    name: Build & Test (stable, beta, nightly)
+    runs-on: ubuntu-latest
+
+    strategy:
+      matrix:
+        toolchain: [stable, beta, nightly]
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Install toolchain
+        uses: dtolnay/rust-toolchain@master
+        with:
+          toolchain: ${{ matrix.toolchain }}
+          components: clippy, rustfmt
+
+      - name: Cache cargo registry
+        uses: actions/cache@v3
+        with:
+          path: |
+            ~/.cargo/registry
+            ~/.cargo/git
+          key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}
+          restore-keys: |
+            ${{ runner.os }}-cargo-registry-
+
+      - name: Cargo fmt check
+        run: cargo fmt --all -- --check
+
+      - name: Build
+        run: cargo build --all-targets --verbose
+
+      - name: Run tests
+        run: cargo test --all-targets --verbose
+
+      - name: Run Clippy
+        run: cargo clippy --all-targets -- -D warnings
+
+  docs:
+    name: Check Documentation
+    runs-on: ubuntu-latest
+    needs: build-and-test
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - uses: dtolnay/rust-toolchain@master
+        with:
+          toolchain: stable
+
+      - name: Build docs
+        run: cargo doc --all --no-deps --verbose
\ No newline at end of file
diff --git a/.gitmodules b/.gitmodules
index a7f9b93..58601f7 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,3 @@
-[submodule "backend/spqlios-arithmetic"]
-	path = backend/spqlios-arithmetic
+[submodule "backend/src/implementation/cpu_spqlios/spqlios-arithmetic"]
+	path = backend/src/implementation/cpu_spqlios/spqlios-arithmetic
 	url = https://github.com/phantomzone-org/spqlios-arithmetic
diff --git a/Cargo.lock b/Cargo.lock
index ab139e4..3da78e3 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -39,8 +39,11 @@ checksum = "7b7e4c2464d97fe331d41de9d5db0def0a96f4d823b8b32a2efd503578988973"
 name = "backend"
 version = "0.1.0"
 dependencies = [
+ "byteorder",
+ "cmake",
  "criterion",
  "itertools 0.14.0",
+ "paste",
  "rand",
  "rand_core",
  "rand_distr",
@@ -73,6 +76,15 @@ version = "0.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
 
+[[package]]
+name = "cc"
+version = "1.2.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c3a42d84bb6b69d3a8b3eaacf0d88f179e1929695e1ad012b6cf64d9caaa5fd2"
+dependencies = [
+ "shlex",
+]
+
 [[package]]
 name = "cfg-if"
 version = "1.0.0"
@@ -131,11 +143,21 @@ version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6"
 
+[[package]]
+name = "cmake"
+version = "0.1.54"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
+dependencies = [
+ "cc",
+]
+
 [[package]]
 name = "core"
 version = "0.1.0"
 dependencies = [
  "backend",
+ "byteorder",
  "criterion",
  "itertools 0.14.0",
  "rand_distr",
@@ -145,9 +167,9 @@ dependencies = [
 
 [[package]]
 name = "criterion"
-version = "0.6.0"
+version = "0.7.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3bf7af66b0989381bd0be551bd7cc91912a655a58c6918420c9527b1fd8b4679"
+checksum = "e1c047a62b0cc3e145fa84415a3191f628e980b194c2755aa12300a4e6cbd928"
 dependencies = [
  "anes",
  "cast",
@@ -168,12 +190,12 @@ dependencies = [
 
 [[package]]
 name = "criterion-plot"
-version = "0.5.0"
+version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +checksum = "9b1bcc0dc7dfae599d84ad0b1a55f80cde8af3725da8313b528da95ef783e338" dependencies = [ "cast", - "itertools 0.10.5", + "itertools 0.13.0", ] [[package]] @@ -251,15 +273,6 @@ dependencies = [ "crunchy", ] -[[package]] -name = "itertools" -version = "0.10.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.13.0" @@ -340,6 +353,12 @@ version = "11.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "plotters" version = "0.3.7" @@ -527,18 +546,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.216" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b9781016e935a97e8beecf0c933758c97a5520d32930e460142b4cd80c6338e" +checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.216" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e" +checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", @@ -557,6 +576,12 @@ dependencies = [ "serde", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "syn" version = "2.0.96" diff --git a/Cargo.toml b/Cargo.toml index 4852656..3e68e2d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,4 +9,5 @@ rand_chacha = "0.9.0" rand_core = "0.9.3" rand_distr = "0.5.1" itertools = "0.14.0" -criterion = "0.6.0" +criterion = "0.7.0" +byteorder = "1.5.0" diff --git a/NOTICE b/NOTICE new file mode 100644 index 0000000..44031cf --- /dev/null +++ b/NOTICE @@ -0,0 +1,12 @@ +Copyright 2025 Phantom Zone, Jean-Philippe Bossuat + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use +this file except in compliance with the License. You may obtain a copy of the +License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software distributed +under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +CONDITIONS OF ANY KIND, either express or implied. See the License for the +specific language governing permissions and limitations under the License. \ No newline at end of file diff --git a/README.md b/README.md index 52bc74e..f331783 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,165 @@ -

- -

\ No newline at end of file
+# šŸ™ Poulpy
+
+**Poulpy** is a fast & modular FHE library that implements Ring-Learning-With-Errors-based homomorphic encryption. It adopts the bivariate polynomial representation proposed in [Revisiting Key Decomposition Techniques for FHE: Simpler, Faster and More Generic](https://eprint.iacr.org/2023/771). In addition to simpler and more efficient arithmetic than the residue number system (RNS), this representation provides a common plaintext space for all schemes and allows easy switching between any two schemes. Poulpy also decouples the scheme implementations from the polynomial-arithmetic backend by being built around a hardware abstraction layer (HAL). This enables users to easily provide or use a custom backend.
+
+### Bivariate Polynomial Representation
+
+Existing FHE implementations (such as [Lattigo](https://github.com/tuneinsight/lattigo) or [OpenFHE](https://github.com/openfheorg/openfhe-development)) use the [residue number system](https://en.wikipedia.org/wiki/Residue_number_system) (RNS) to represent large integers. Although the parallelism and carry-less arithmetic provided by the RNS representation enable very efficient modular arithmetic over large integers, it suffers from several drawbacks when used in the context of FHE. The main idea behind the bivariate representation is to decouple the cyclotomic arithmetic from the large-number arithmetic. Instead of using the RNS representation for large integers, integers are decomposed in base $2^{-K}$ over the Torus $\mathbb{T}_{N}[X]$ (a small standalone sketch of this digit decomposition is given after the list below).
+
+This provides the following benefits:
+
+- **Intuitive, efficient and reusable parameterization & instances:** Only the bit-size of the modulus (i.e. the Torus precision) is required from the user. As such, parameterization is natural and generic, and instances can be reused for any circuit consuming the same homomorphic capacity, without loss of efficiency. With the RNS representation, individual NTT-friendly primes need to be specified for each level, making the parameterization circuit-specific and far less user-friendly.
+
+- **Optimal and granular rescaling:** Ciphertext rescaling is carried out with bit-shifting, enabling bit-level granular rescaling and optimal noise/homomorphic-capacity management. In the RNS representation, a ciphertext can only be divided by one of the primes composing the modulus, leading to difficult scale management and frequently inefficient noise/homomorphic-capacity management.
+
+- **Linear number of DFTs in the half external product:** The bivariate representation of the coefficients implicitly provides the digit decomposition; as such, the number of DFTs is linear in the number of limbs, contrary to the RNS representation where it is quadratic due to the RNS basis conversion. This enables much more efficient key-switching, which is the **most used and expensive** FHE operation.
+
+- **Unified plaintext space:** The bivariate polynomial representation is in essence a high-precision discretized representation of the Torus $\mathbb{T}_{N}[X]$. Using the Torus as the common plaintext space for all schemes achieves the vision of [CHIMERA: Combining Ring-LWE-based Fully Homomorphic Encryption Schemes](https://eprint.iacr.org/2018/758), which is to unify all RLWE-based FHE schemes (TFHE, FHEW, BGV, BFV, CLPX, GBFV, CKKS, ...) under a single scheme with different encodings, enabling native and efficient scheme-switching functionalities.
+
+- **Simpler implementation**: Since the cyclotomic arithmetic is decoupled from the coefficient representation, the same pipeline (including the DFT) can be reused for all limbs (unlike in the RNS representation), making this representation a prime target for hardware acceleration.
+
+- **Deterministic computation**: Although it is defined over the Torus, bivariate arithmetic remains integer polynomial arithmetic, ensuring that all computations are deterministic; the contract is that outputs are reproducible and identical, regardless of the backend or hardware.
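+
+For intuition, here is a small self-contained sketch of such a signed base-$2^K$ digit decomposition of a single coefficient (illustrative only; the library's actual encoding entry point is the `VecZnxEncodeVeci64` trait shown later in this PR):
+
+```rust
+/// Illustrative only (not the Poulpy API): decomposes `x` into `limbs` signed
+/// digits in [-2^{K-1}, 2^{K-1}), most significant digit first.
+fn decompose_base2k(x: i64, basek: usize, limbs: usize) -> Vec<i64> {
+    let base: i64 = 1 << basek;
+    let half: i64 = base >> 1;
+    let mut digits: Vec<i64> = vec![0; limbs];
+    let mut carry: i64 = x;
+    for d in digits.iter_mut().rev() {
+        let mut r = carry & (base - 1); // low K bits
+        carry >>= basek; // arithmetic shift keeps the sign
+        if r >= half {
+            r -= base; // center the digit and propagate a carry
+            carry += 1;
+        }
+        *d = r;
+    }
+    digits
+}
+
+/// Inverse of `decompose_base2k` (exact when `x` fits in `limbs * basek` bits).
+fn recompose_base2k(digits: &[i64], basek: usize) -> i64 {
+    digits.iter().fold(0, |acc, &d| (acc << basek) + d)
+}
+```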
+
+### Hardware Abstraction Layer
+
+In addition to providing a general-purpose FHE library over a unified plaintext space, Poulpy is designed from the ground up around a **hardware abstraction layer** that closely matches the API of [spqlios-arithmetic](https://github.com/tfhe/spqlios-arithmetic). The bivariate representation is itself hardware-friendly, as it uses a flat, aligned & vectorized memory layout. Finally, generic opaque write-only structs (prepared versions) are provided, making it easy for developers to supply hardware-focused/optimized operations. This makes it possible for anyone to provide or use a custom backend.
+
+## Library Overview
+
+- **`backend/hal`**: hardware abstraction layer. This layer targets users that want to provide their own backend or use a third-party backend.
+
+  - **`api`**: fixed public low-level polynomial arithmetic API closely matching spqlios-arithmetic. The goal is to eventually freeze this API in order to decouple it from the OEP traits, ensuring that changes to implementations do not affect the front-end API.
+
+    ```rust
+    pub trait SvpPrepare<B: Backend> {
+        fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+        where
+            R: SvpPPolToMut<B>,
+            A: ScalarZnxToRef;
+    }
+    ```
+
+  - **`delegates`**: link between the user-facing API and the implementation OEP. Each trait of `api` is implemented by calling its corresponding trait on the `oep` (see the generic-usage sketch at the end of this list).
+
+    ```rust
+    impl<B> SvpPrepare<B> for Module<B>
+    where
+        B: Backend + SvpPrepareImpl<B>,
+    {
+        fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+        where
+            R: SvpPPolToMut<B>,
+            A: ScalarZnxToRef,
+        {
+            B::svp_prepare_impl(self, res, res_col, a, a_col);
+        }
+    }
+    ```
+
+  - **`layouts`**: defines the layouts of the front-end algebraic structs matching the spqlios-arithmetic definitions, such as `ScalarZnx` and `VecZnx`, as well as opaque backend-prepared structs such as `SvpPPol` and `VmpPMat`.
+
+    ```rust
+    pub struct SvpPPol<D, B: Backend> {
+        data: D,
+        n: usize,
+        cols: usize,
+        _phantom: PhantomData<B>,
+    }
+    ```
+
+  - **`oep`**: open extension points, which can be implemented by the user to provide a custom backend.
+
+    ```rust
+    pub unsafe trait SvpPrepareImpl<B: Backend> {
+        fn svp_prepare_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+        where
+            R: SvpPPolToMut<B>,
+            A: ScalarZnxToRef;
+    }
+    ```
+
+  - **`tests`**: exported generic tests for the OEP/structs. Their goal is to enable users to automatically test their backend implementation without having to re-implement any tests.
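+
+  Because the front end is written purely against these traits, user code can stay generic over the backend. A minimal sketch (the exact bounds are illustrative; `cols()` is assumed to come from the `ZnxInfos` trait):
+
+  ```rust
+  // Prepares a secret for any backend B: allocate the prepared layout,
+  // then run the backend's prepare kernel column by column.
+  fn prepare_secret<B>(module: &Module<B>, s: &ScalarZnx<Vec<u8>>) -> SvpPPolOwned<B>
+  where
+      B: Backend,
+      Module<B>: SvpPPolAlloc<B> + SvpPrepare<B>,
+  {
+      let mut s_dft = module.svp_ppol_alloc(s.cols());
+      for col in 0..s.cols() {
+          module.svp_prepare(&mut s_dft, col, s, col);
+      }
+      s_dft
+  }
+  ```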
+
+- **`backend/implementation`**:
+  - **`cpu_spqlios`**: concrete CPU implementation of the HAL, through the OEP, using bindings to spqlios-arithmetic. This implementation currently supports the `FFT64` backend and will be extended to support the `NTT120` backend once it is available in spqlios-arithmetic.
+
+    ```rust
+    unsafe impl SvpPrepareImpl<FFT64> for FFT64 {
+        fn svp_prepare_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+        where
+            R: SvpPPolToMut<FFT64>,
+            A: ScalarZnxToRef,
+        {
+            unsafe {
+                svp::svp_prepare(
+                    module.ptr(),
+                    res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t,
+                    a.to_ref().at_ptr(a_col, 0),
+                )
+            }
+        }
+    }
+    ```
+
+- **`core`**: core of the FHE library, implementing scheme-agnostic RLWE arithmetic for LWE, GLWE, GGLWE and GGSW ciphertexts. It notably includes all possible cross-ciphertext operations, for example applying an external product on a GGLWE or an automorphism on a GGSW, as well as blind rotation. This crate is entirely implemented using the hardware abstraction layer API, and is thus solely defined over generics and traits (including tests). As such, it will work with any backend, as long as that backend implements the necessary traits defined in the OEP.
+
+  ```rust
+  pub struct GLWESecret<D> {
+      pub(crate) data: ScalarZnx<D>,
+      pub(crate) dist: Distribution,
+  }
+
+  pub struct GLWESecretExec<D, B: Backend> {
+      pub(crate) data: SvpPPol<D, B>,
+      pub(crate) dist: Distribution,
+  }
+
+  impl<D: DataMut, B: Backend> GLWESecretExec<D, B> {
+      pub fn prepare<O>(&mut self, module: &Module<B>, sk: &GLWESecret<O>)
+      where
+          O: DataRef,
+          Module<B>: SvpPrepare<B>,
+      {
+          (0..self.rank()).for_each(|i| {
+              module.svp_prepare(&mut self.data, i, &sk.data, i);
+          });
+          self.dist = sk.dist
+      }
+  }
+  ```
+
+## Installation
+
+TBD — currently not published on crates.io. Clone the repository and use the crates via path-based dependencies, as sketched below.
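+
+```toml
+# Cargo.toml of a consuming project; the relative paths assume the repository
+# was cloned next to it (hypothetical layout).
+[dependencies]
+backend = { path = "../poulpy/backend" }
+poulpy-core = { package = "core", path = "../poulpy/core" }
+```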
+
+## Documentation
+
+* Full `cargo doc` documentation is coming soon.
+* Architecture diagrams and design notes will be added in the [`/doc`](./doc) folder.
+
+## Contributing
+
+We welcome external contributions; please see [CONTRIBUTING](./CONTRIBUTING.md).
+
+## Security
+
+Please see [SECURITY](./SECURITY.md).
+
+## License
+
+Poulpy is licensed under the Apache 2.0 License. See [NOTICE](./NOTICE) & [LICENSE](./LICENSE).
+
+## Acknowledgement
+
+**Poulpy** is inspired by the modular architecture of [Lattigo](https://github.com/tuneinsight/lattigo) and [TFHE-go](https://github.com/sp301415/tfhe-go), and its development is led by Lattigo's co-author and main contributor [@Pro7ech](https://github.com/Pro7ech). Poulpy reflects the experience gained from over five years of designing and maintaining Lattigo, and represents the next evolution in architecture, performance, and backend philosophy.
+
+## Citing
+
+Please use the following BibTeX entry to cite Poulpy:
+
+    @misc{poulpy,
+        title = {Poulpy v0.1.0},
+        howpublished = {Online: \url{https://github.com/phantomzone-org/poulpy}},
+        month = aug,
+        year = 2025,
+        note = {Phantom Zone}
+    }
diff --git a/backend/Cargo.toml b/backend/Cargo.toml
index 1a74bf6..cef70a0 100644
--- a/backend/Cargo.toml
+++ b/backend/Cargo.toml
@@ -13,7 +13,12 @@ rand_distr = {workspace = true}
 rand_core = {workspace = true}
 sampling = { path = "../sampling" }
 utils = { path = "../utils" }
+paste = "1.0.15"
+byteorder = {workspace = true}
 
-[[bench]]
-name = "fft"
-harness = false
\ No newline at end of file
+[build-dependencies]
+cmake = "0.1.54"
+
+[package.metadata.docs.rs]
+all-features = true
+rustdoc-args = ["--cfg", "docsrs"]
\ No newline at end of file
diff --git a/backend/benches/fft.rs b/backend/benches/fft.rs
deleted file mode 100644
index 8106a2d..0000000
--- a/backend/benches/fft.rs
+++ /dev/null
@@ -1,56 +0,0 @@
-use backend::ffi::reim::*;
-use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
-use std::ffi::c_void;
-
-fn fft(c: &mut Criterion) {
-    fn forward<'a>(m: u32, log_bound: u32, reim_fft_precomp: *mut reim_fft_precomp, a: &'a [i64]) -> Box<dyn FnMut() + 'a> {
-        unsafe {
-            let buf_a: *mut f64 = reim_fft_precomp_get_buffer(reim_fft_precomp, 0);
-            reim_from_znx64_simple(m as u32, log_bound as u32, buf_a as *mut c_void, a.as_ptr());
-            Box::new(move || reim_fft(reim_fft_precomp, buf_a))
-        }
-    }
-
-    fn backward<'a>(m: u32, log_bound: u32, reim_ifft_precomp: *mut reim_ifft_precomp, a: &'a [i64]) -> Box<dyn FnMut() + 'a> {
-        Box::new(move || unsafe {
-            let buf_a: *mut f64 = reim_ifft_precomp_get_buffer(reim_ifft_precomp, 0);
-            reim_from_znx64_simple(m as u32, log_bound as u32, buf_a as *mut c_void, a.as_ptr());
-            reim_ifft(reim_ifft_precomp, buf_a);
-        })
-    }
-
-    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("fft");
-
-    for log_n in 10..17 {
-        let n: usize = 1 << log_n;
-        let m: usize = n >> 1;
-        let log_bound: u32 = 19;
-
-        let mut a: Vec<i64> = vec![i64::default(); n];
-        a.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
-
-        unsafe {
-            let reim_fft_precomp: *mut reim_fft_precomp = new_reim_fft_precomp(m as u32, 1);
-            let reim_ifft_precomp: *mut reim_ifft_precomp = new_reim_ifft_precomp(m as u32, 1);
-
-            let runners: [(String, Box<dyn FnMut() + '_>); 2] = [
-                (format!("forward"), {
-                    forward(m as u32, log_bound, reim_fft_precomp, &a)
-                }),
-                (format!("backward"), {
-                    backward(m as u32, log_bound, reim_ifft_precomp, &a)
-                }),
-            ];
-
-            for (name, mut runner) in runners {
-                let id: BenchmarkId = BenchmarkId::new(name, format!("n={}", 1 << log_n));
-                b.bench_with_input(id, &(), |b: &mut criterion::Bencher<'_>, _| {
-                    b.iter(&mut runner)
-                });
-            }
-        }
-    }
-}
-
-criterion_group!(benches, fft,);
-criterion_main!(benches);
diff --git a/backend/build.rs b/backend/build.rs
index 93e4098..e6d82be 100644
--- a/backend/build.rs
+++ b/backend/build.rs
@@ -1,13 +1,7 @@
-use std::path::absolute;
-
-fn main() {
-    println!(
-        "cargo:rustc-link-search=native={}",
-        absolute("spqlios-arithmetic/build/spqlios")
-            .unwrap()
-            .to_str()
-            .unwrap()
-    );
-    println!("cargo:rustc-link-lib=static=spqlios");
-    // println!("cargo:rustc-link-lib=dylib=spqlios")
-}
+mod builds {
+    pub mod cpu_spqlios;
+}
+
+fn main() {
+    builds::cpu_spqlios::build()
+}
diff --git a/backend/builds/cpu_spqlios.rs b/backend/builds/cpu_spqlios.rs
new file mode 100644
index 0000000..160af9e
--- /dev/null
+++ b/backend/builds/cpu_spqlios.rs
@@ -0,0 +1,10 @@
+use std::path::PathBuf;
+
+pub fn build() {
+    let dst: PathBuf = cmake::Config::new("src/implementation/cpu_spqlios/spqlios-arithmetic").build();
+
+    let lib_dir: PathBuf = dst.join("lib");
+
+    println!("cargo:rustc-link-search=native={}", lib_dir.display());
+    println!("cargo:rustc-link-lib=static=spqlios");
+}
diff --git a/backend/docs/backend_safety_contract.md b/backend/docs/backend_safety_contract.md
new file mode 100644
index 0000000..ec27ba3
--- /dev/null
+++ b/backend/docs/backend_safety_contract.md
@@ -0,0 +1,27 @@
+Implementors must uphold all of the following for **every** call:
+
+* **Memory domains**: Pointers produced by `to_ref()` / `to_mut()` must be valid
+  in the target execution domain for `Self` (e.g., CPU host memory for CPU,
+  device memory for a specific GPU). If host↔device transfers are required,
+  perform them inside the implementation; do not assume the caller synchronized.
+
+* **Alignment & layout**: All data must match the layout, stride, and element
+  size expected by the kernel. `size()`, `rows()`, `cols_in()`, `cols_out()`,
+  `n()`, etc. must be interpreted identically to the reference CPU implementation.
+
+* **Scratch lifetime**: Any scratch obtained from `scratch.tmp_slice(...)` (or a
+  backend-specific variant) must remain valid for the duration of the call; it
+  may be reused by the caller afterwards. Do not retain pointers past return.
+
+* **Synchronization**: The call must appear **logically synchronous** to the
+  caller. If you enqueue asynchronous work (e.g., CUDA streams), you must
+  ensure completion before returning, or clearly document and implement a
+  synchronization contract used by all backends consistently.
+
+* **Aliasing & overlaps**: If `res`, `a`, `b`, etc. alias or overlap in ways
+  that violate your kernel's requirements, you must either handle them safely or
+  reject them with a defined error path (e.g., a debug assertion). Never trigger UB.
+
+* **Numerical contract**: For modular/integer arithmetic, results must be
+  bit-exact to the specification. For floating-point, any permitted tolerance
+  must be documented and consistent with the crate's guarantees.
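+
+As a point of reference, a compliant CPU implementation typically has the following shape. This is a sketch modeled on the `SvpPrepareImpl` example from the README; the `MyCpu` backend and the assertions are illustrative, not part of the crate:
+
+```rust
+unsafe impl SvpPrepareImpl<MyCpu> for MyCpu {
+    fn svp_prepare_impl<R, A>(module: &Module<MyCpu>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: SvpPPolToMut<MyCpu>,
+        A: ScalarZnxToRef,
+    {
+        let mut res = res.to_mut();
+        let a = a.to_ref();
+        // Alignment & layout: interpret sizes exactly like the reference CPU backend.
+        debug_assert_eq!(res.n(), a.n());
+        // Aliasing & overlaps: reject invalid inputs on a defined error path, never UB.
+        debug_assert!(res_col < res.cols() && a_col < a.cols());
+        // ... run the kernel here; if it is asynchronous, block before returning so
+        // the call appears logically synchronous to the caller ...
+        let _ = module;
+    }
+}
+```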
\ No newline at end of file
diff --git a/backend/examples/fft.rs b/backend/examples/fft.rs
deleted file mode 100644
index 63e243c..0000000
--- a/backend/examples/fft.rs
+++ /dev/null
@@ -1,56 +0,0 @@
-use backend::ffi::reim::*;
-use std::ffi::c_void;
-use std::time::Instant;
-
-fn main() {
-    let log_bound: usize = 19;
-
-    let n: usize = 2048;
-    let m: usize = n >> 1;
-
-    let mut a: Vec<i64> = vec![i64::default(); n];
-    let mut b: Vec<i64> = vec![i64::default(); n];
-    let mut c: Vec<i64> = vec![i64::default(); n];
-
-    a.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
-    b[1] = 1;
-
-    println!("{:?}", b);
-
-    unsafe {
-        let reim_fft_precomp = new_reim_fft_precomp(m as u32, 2);
-        let reim_ifft_precomp = new_reim_ifft_precomp(m as u32, 1);
-
-        let buf_a = reim_fft_precomp_get_buffer(reim_fft_precomp, 0);
-        let buf_b = reim_fft_precomp_get_buffer(reim_fft_precomp, 1);
-        let buf_c = reim_ifft_precomp_get_buffer(reim_ifft_precomp, 0);
-
-        let now = Instant::now();
-        (0..1024).for_each(|_| {
-            reim_from_znx64_simple(m as u32, log_bound as u32, buf_a as *mut c_void, a.as_ptr());
-            reim_fft(reim_fft_precomp, buf_a);
-
-            reim_from_znx64_simple(m as u32, log_bound as u32, buf_b as *mut c_void, b.as_ptr());
-            reim_fft(reim_fft_precomp, buf_b);
-
-            reim_fftvec_mul_simple(
-                m as u32,
-                buf_c as *mut c_void,
-                buf_a as *mut c_void,
-                buf_b as *mut c_void,
-            );
-            reim_ifft(reim_ifft_precomp, buf_c);
-
-            reim_to_znx64_simple(
-                m as u32,
-                m as f64,
-                log_bound as u32,
-                c.as_mut_ptr(),
-                buf_c as *mut c_void,
-            )
-        });
-
-        println!("time: {}us", now.elapsed().as_micros());
-        println!("{:?}", &c[..16]);
-    }
-}
diff --git a/backend/examples/rlwe_encrypt.rs b/backend/examples/rlwe_encrypt.rs
index c56496c..56eedf8 100644
--- a/backend/examples/rlwe_encrypt.rs
+++ b/backend/examples/rlwe_encrypt.rs
@@ -1,7 +1,14 @@
 use backend::{
-    AddNormal, Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc,
-    ScalarZnxDftOps, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
-    VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxInfos,
+    hal::{
+        api::{
+            ModuleNew, ScalarZnxAlloc, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPrepare,
+            VecZnxAddNormal, VecZnxAlloc, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize,
+            VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxDecodeVeci64, VecZnxDftAlloc, VecZnxDftFromVecZnx,
+            VecZnxDftToVecZnxBigTmpA, VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxNormalizeInplace, ZnxInfos,
+        },
+        layouts::{Module, ScalarZnx, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft},
+    },
+    implementation::cpu_spqlios::FFT64,
 };
 use itertools::izip;
 use sampling::source::Source;
@@ -12,35 +19,35 @@ fn main() {
     let ct_size: usize = 3;
     let msg_size: usize = 2;
     let log_scale: usize = msg_size * basek - 5;
-    let module: Module<FFT64> = Module::<FFT64>::new(n);
+    let module: Module<FFT64> = Module::<FFT64>::new(n as u64);
 
-    let mut scratch: ScratchOwned = ScratchOwned::new(module.vec_znx_big_normalize_tmp_bytes());
+    let mut scratch: ScratchOwned<FFT64> = ScratchOwned::<FFT64>::alloc(module.vec_znx_big_normalize_tmp_bytes(n));
 
     let seed: [u8; 32] = [0; 32];
     let mut source: Source = Source::new(seed);
 
     // s <- Z_{-1, 0, 1}[X]/(X^{N}+1)
-    let mut s: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
+    let mut s: ScalarZnx<Vec<u8>> = module.scalar_znx_alloc(1);
     s.fill_ternary_prob(0, 0.5, &mut source);
 
     // Buffer to store s in the DFT domain
-    let mut s_dft: ScalarZnxDft<Vec<u8>, FFT64> = module.new_scalar_znx_dft(s.cols());
+    let mut s_dft: SvpPPol<Vec<u8>, FFT64> = module.svp_ppol_alloc(s.cols());
 
     // s_dft <- DFT(s)
     module.svp_prepare(&mut s_dft, 0, &s, 0);
 
     // Allocates a VecZnx with two columns: ct=(0, 0)
-    let mut ct: VecZnx<Vec<u8>> = module.new_vec_znx(
+    let mut ct: VecZnx<Vec<u8>> = module.vec_znx_alloc(
         2,       // Number of columns
         ct_size, // Number of small poly per column
     );
 
     // Fill the second column with random values: ct = (0, a)
-    ct.fill_uniform(basek, 1, ct_size, &mut source);
+    module.vec_znx_fill_uniform(basek, &mut ct, 1, ct_size * basek, &mut source);
 
-    let mut buf_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_size);
+    let mut buf_dft: VecZnxDft<Vec<u8>, FFT64> = module.vec_znx_dft_alloc(1, ct_size);
 
-    module.vec_znx_dft(1, 0, &mut buf_dft, 0, &ct, 1);
+    module.vec_znx_dft_from_vec_znx(1, 0, &mut buf_dft, 0, &ct, 1);
 
     // Applies DFT(ct[1]) * DFT(s)
     module.svp_apply_inplace(
         &mut buf_dft,
         0,
         &s_dft,
         0,
     );
 
     // Alias scratch space (VecZnxDft is always at least as big as VecZnxBig)
     // BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized)
-    let mut buf_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_size);
-    module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
+    let mut buf_big: VecZnxBig<Vec<u8>, FFT64> = module.vec_znx_big_alloc(1, ct_size);
+    module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
 
     // Creates a plaintext: VecZnx with 1 column
-    let mut m = module.new_vec_znx(
+    let mut m = module.vec_znx_alloc(
         1,        // Number of columns
         msg_size, // Number of small polynomials
     );
     let mut want: Vec<i64> = vec![0; n];
     want.iter_mut()
         .for_each(|x| *x = source.next_u64n(16, 15) as i64);
-    m.encode_vec_i64(0, basek, log_scale, &want, 4);
+    module.encode_vec_i64(basek, &mut m, 0, log_scale, &want, 4);
     module.vec_znx_normalize_inplace(basek, &mut m, 0, scratch.borrow());
 
     // m - BIG(ct[1] * s)
@@ -88,13 +95,14 @@ fn main() {
 
     // Add noise to ct[0]
     // ct[0] <- ct[0] + e
-    ct.add_normal(
+    module.vec_znx_add_normal(
         basek,
+        &mut ct,
         0,               // Selects the first column of ct (ct[0])
         basek * ct_size, // Scaling of the noise: 2^{-basek * limbs}
         &mut source,
-        3.2,  // Standard deviation
-        19.0, // Truncatation bound
+        3.2,       // Standard deviation
+        3.2 * 6.0, // Truncation bound
    );
 
     // Final ciphertext: ct = (-a * s + m + e, a)
 
     // Decryption
 
     // DFT(ct[1] * s)
-    module.vec_znx_dft(1, 0, &mut buf_dft, 0, &ct, 1);
+    module.vec_znx_dft_from_vec_znx(1, 0, &mut buf_dft, 0, &ct, 1);
 
     module.svp_apply_inplace(
         &mut buf_dft,
         0, // Selects the first column of res.
@@ -111,18 +119,18 @@ fn main() {
     );
 
     // BIG(c1 * s) = IDFT(DFT(c1 * s))
-    module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
+    module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
 
     // BIG(c1 * s) + ct[0]
     module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0);
 
     // m + e <- BIG(ct[1] * s + ct[0])
-    let mut res = module.new_vec_znx(1, ct_size);
+    let mut res = module.vec_znx_alloc(1, ct_size);
     module.vec_znx_big_normalize(basek, &mut res, 0, &buf_big, 0, scratch.borrow());
 
     // have = m * 2^{log_scale} + e
     let mut have: Vec<i64> = vec![i64::default(); n];
-    res.decode_vec_i64(0, basek, res.size() * basek, &mut have);
+    module.decode_vec_i64(basek, &mut res, 0, ct_size * basek, &mut have);
 
     let scale: f64 = (1 << (res.size() * basek - log_scale)) as f64;
     izip!(want.iter(), have.iter())
diff --git a/backend/spqlios-arithmetic b/backend/spqlios-arithmetic
deleted file mode 160000
index 0ae9a7b..0000000
--- a/backend/spqlios-arithmetic
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 0ae9a7b5adf07ce0b1797562528dab8e28192238
diff --git a/backend/src/encoding.rs b/backend/src/encoding.rs
deleted file mode 100644
index 27d4b8c..0000000
--- a/backend/src/encoding.rs
+++ /dev/null
@@ -1,344 +0,0 @@
-use crate::ffi::znx::znx_zero_i64_ref;
-use crate::znx_base::{ZnxView, ZnxViewMut};
-use crate::{VecZnx, znx_base::ZnxInfos};
-use itertools::izip;
-use rug::{Assign, Float};
-use std::cmp::min;
-
-pub trait Encoding {
-    /// encode a vector of i64 on the receiver.
-    ///
-    /// # Arguments
-    ///
-    /// * `col_i`: the index of the poly where to encode the data.
-    /// * `basek`: base two negative logarithm decomposition of the receiver.
-    /// * `k`: base two negative logarithm of the scaling of the data.
-    /// * `data`: data to encode on the receiver.
-    /// * `log_max`: base two logarithm of the infinity norm of the input data.
-    fn encode_vec_i64(&mut self, col_i: usize, basek: usize, k: usize, data: &[i64], log_max: usize);
-
-    /// encodes a single i64 on the receiver at the given index.
-    ///
-    /// # Arguments
-    ///
-    /// * `col_i`: the index of the poly where to encode the data.
-    /// * `basek`: base two negative logarithm decomposition of the receiver.
-    /// * `k`: base two negative logarithm of the scaling of the data.
-    /// * `i`: index of the coefficient on which to encode the data.
-    /// * `data`: data to encode on the receiver.
-    /// * `log_max`: base two logarithm of the infinity norm of the input data.
-    fn encode_coeff_i64(&mut self, col_i: usize, basek: usize, k: usize, i: usize, data: i64, log_max: usize);
-}
-
-pub trait Decoding {
-    /// decode a vector of i64 from the receiver.
-    ///
-    /// # Arguments
-    ///
-    /// * `col_i`: the index of the poly where to encode the data.
-    /// * `basek`: base two negative logarithm decomposition of the receiver.
-    /// * `k`: base two logarithm of the scaling of the data.
-    /// * `data`: data to decode from the receiver.
-    fn decode_vec_i64(&self, col_i: usize, basek: usize, k: usize, data: &mut [i64]);
-
-    /// decode a vector of Float from the receiver.
-    ///
-    /// # Arguments
-    /// * `col_i`: the index of the poly where to encode the data.
-    /// * `basek`: base two negative logarithm decomposition of the receiver.
-    /// * `data`: data to decode from the receiver.
-    fn decode_vec_float(&self, col_i: usize, basek: usize, data: &mut [Float]);
-
-    /// decode a single of i64 from the receiver at the given index.
-    ///
-    /// # Arguments
-    ///
-    /// * `col_i`: the index of the poly where to encode the data.
-    /// * `basek`: base two negative logarithm decomposition of the receiver.
-    /// * `k`: base two negative logarithm of the scaling of the data.
-    /// * `i`: index of the coefficient to decode.
-    /// * `data`: data to decode from the receiver.
-    fn decode_coeff_i64(&self, col_i: usize, basek: usize, k: usize, i: usize) -> i64;
-}
-
-impl<D: AsMut<[u8]> + AsRef<[u8]>> Encoding for VecZnx<D> {
-    fn encode_vec_i64(&mut self, col_i: usize, basek: usize, k: usize, data: &[i64], log_max: usize) {
-        encode_vec_i64(self, col_i, basek, k, data, log_max)
-    }
-
-    fn encode_coeff_i64(&mut self, col_i: usize, basek: usize, k: usize, i: usize, value: i64, log_max: usize) {
-        encode_coeff_i64(self, col_i, basek, k, i, value, log_max)
-    }
-}
-
-impl<D: AsRef<[u8]>> Decoding for VecZnx<D> {
-    fn decode_vec_i64(&self, col_i: usize, basek: usize, k: usize, data: &mut [i64]) {
-        decode_vec_i64(self, col_i, basek, k, data)
-    }
-
-    fn decode_vec_float(&self, col_i: usize, basek: usize, data: &mut [Float]) {
-        decode_vec_float(self, col_i, basek, data)
-    }
-
-    fn decode_coeff_i64(&self, col_i: usize, basek: usize, k: usize, i: usize) -> i64 {
-        decode_coeff_i64(self, col_i, basek, k, i)
-    }
-}
-
-fn encode_vec_i64<D: AsMut<[u8]> + AsRef<[u8]>>(
-    a: &mut VecZnx<D>,
-    col_i: usize,
-    basek: usize,
-    k: usize,
-    data: &[i64],
-    log_max: usize,
-) {
-    let size: usize = (k + basek - 1) / basek;
-
-    #[cfg(debug_assertions)]
-    {
-        assert!(
-            size <= a.size(),
-            "invalid argument k: (k + a.basek - 1)/a.basek={} > a.size()={}",
-            size,
-            a.size()
-        );
-        assert!(col_i < a.cols());
-        assert!(data.len() <= a.n())
-    }
-
-    let data_len: usize = data.len();
-    let k_rem: usize = basek - (k % basek);
-
-    // Zeroes coefficients of the i-th column
-    (0..a.size()).for_each(|i| unsafe {
-        znx_zero_i64_ref(a.n() as u64, a.at_mut_ptr(col_i, i));
-    });
-
-    // If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
-    // values on the last limb.
-    // Else we decompose values base2k.
-    if log_max + k_rem < 63 || k_rem == basek {
-        a.at_mut(col_i, size - 1)[..data_len].copy_from_slice(&data[..data_len]);
-    } else {
-        let mask: i64 = (1 << basek) - 1;
-        let steps: usize = min(size, (log_max + basek - 1) / basek);
-        (size - steps..size)
-            .rev()
-            .enumerate()
-            .for_each(|(i, i_rev)| {
-                let shift: usize = i * basek;
-                izip!(a.at_mut(col_i, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask);
-            })
-    }
-
-    // Case where self.prec % self.k != 0.
-    if k_rem != basek {
-        let steps: usize = min(size, (log_max + basek - 1) / basek);
-        (size - steps..size).rev().for_each(|i| {
-            a.at_mut(col_i, i)[..data_len]
-                .iter_mut()
-                .for_each(|x| *x <<= k_rem);
-        })
-    }
-}
-
-fn decode_vec_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, k: usize, data: &mut [i64]) {
-    let size: usize = (k + basek - 1) / basek;
-    #[cfg(debug_assertions)]
-    {
-        assert!(
-            data.len() >= a.n(),
-            "invalid data: data.len()={} < a.n()={}",
-            data.len(),
-            a.n()
-        );
-        assert!(col_i < a.cols());
-    }
-    data.copy_from_slice(a.at(col_i, 0));
-    let rem: usize = basek - (k % basek);
-    if k < basek {
-        data.iter_mut().for_each(|x| *x >>= rem);
-    } else {
-        (1..size).for_each(|i| {
-            if i == size - 1 && rem != basek {
-                let k_rem: usize = basek - rem;
-                izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
-                    *y = (*y << k_rem) + (x >> rem);
-                });
-            } else {
-                izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
-                    *y = (*y << basek) + x;
-                });
-            }
-        })
-    }
-}
-
-fn decode_vec_float<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, data: &mut [Float]) {
-    let size: usize = a.size();
-    #[cfg(debug_assertions)]
-    {
-        assert!(
-            data.len() >= a.n(),
-            "invalid data: data.len()={} < a.n()={}",
-            data.len(),
-            a.n()
-        );
-        assert!(col_i < a.cols());
-    }
-
-    let prec: u32 = (basek * size) as u32;
-
-    // 2^{basek}
-    let base = Float::with_val(prec, (1 << basek) as f64);
-
-    // y[i] = sum x[j][i] * 2^{-basek*j}
-    (0..size).for_each(|i| {
-        if i == 0 {
-            izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
-                y.assign(*x);
-                *y /= &base;
-            });
-        } else {
-            izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
-                *y += Float::with_val(prec, *x);
-                *y /= &base;
-            });
-        }
-    });
-}
-
-fn encode_coeff_i64<D: AsMut<[u8]> + AsRef<[u8]>>(
-    a: &mut VecZnx<D>,
-    col_i: usize,
-    basek: usize,
-    k: usize,
-    i: usize,
-    value: i64,
-    log_max: usize,
-) {
-    let size: usize = (k + basek - 1) / basek;
-
-    #[cfg(debug_assertions)]
-    {
-        assert!(i < a.n());
-        assert!(
-            size <= a.size(),
-            "invalid argument k: (k + a.basek - 1)/a.basek={} > a.size()={}",
-            size,
-            a.size()
-        );
-        assert!(col_i < a.cols());
-    }
-
-    let k_rem: usize = basek - (k % basek);
-    (0..a.size()).for_each(|j| a.at_mut(col_i, j)[i] = 0);
-
-    // If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
-    // values on the last limb.
-    // Else we decompose values base2k.
-    if log_max + k_rem < 63 || k_rem == basek {
-        a.at_mut(col_i, size - 1)[i] = value;
-    } else {
-        let mask: i64 = (1 << basek) - 1;
-        let steps: usize = min(size, (log_max + basek - 1) / basek);
-        (size - steps..size)
-            .rev()
-            .enumerate()
-            .for_each(|(j, j_rev)| {
-                a.at_mut(col_i, j_rev)[i] = (value >> (j * basek)) & mask;
-            })
-    }
-
-    // Case where prec % k != 0.
-    if k_rem != basek {
-        let steps: usize = min(size, (log_max + basek - 1) / basek);
-        (size - steps..size).rev().for_each(|j| {
-            a.at_mut(col_i, j)[i] <<= k_rem;
-        })
-    }
-}
-
-fn decode_coeff_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, k: usize, i: usize) -> i64 {
-    #[cfg(debug_assertions)]
-    {
-        assert!(i < a.n());
-        assert!(col_i < a.cols())
-    }
-
-    let size: usize = (k + basek - 1) / basek;
-    let data: &[i64] = a.raw();
-    let mut res: i64 = 0;
-    let rem: usize = basek - (k % basek);
-    let slice_size: usize = a.n() * a.cols();
-    (0..size).for_each(|j| {
-        let x: i64 = data[j * slice_size + i];
-        if j == size - 1 && rem != basek {
-            let k_rem: usize = basek - rem;
-            res = (res << k_rem) + (x >> rem);
-        } else {
-            res = (res << basek) + x;
-        }
-    });
-    res
-}
-
-#[cfg(test)]
-mod tests {
-    use crate::vec_znx_ops::*;
-    use crate::znx_base::*;
-    use crate::{Decoding, Encoding, FFT64, Module, VecZnx, znx_base::ZnxInfos};
-    use itertools::izip;
-    use sampling::source::Source;
-
-    #[test]
-    fn test_set_get_i64_lo_norm() {
-        let n: usize = 8;
-        let module: Module<FFT64> = Module::<FFT64>::new(n);
-        let basek: usize = 17;
-        let size: usize = 5;
-        let k: usize = size * basek - 5;
-        let mut a: VecZnx<_> = module.new_vec_znx(2, size);
-        let mut source: Source = Source::new([0u8; 32]);
-        let raw: &mut [i64] = a.raw_mut();
-        raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
-        (0..a.cols()).for_each(|col_i| {
-            let mut have: Vec<i64> = vec![i64::default(); n];
-            have.iter_mut()
-                .for_each(|x| *x = (source.next_i64() << 56) >> 56);
-            a.encode_vec_i64(col_i, basek, k, &have, 10);
-            let mut want: Vec<i64> = vec![i64::default(); n];
-            a.decode_vec_i64(col_i, basek, k, &mut want);
-            izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
-        });
-    }
-
-    #[test]
-    fn test_set_get_i64_hi_norm() {
-        let n: usize = 8;
-        let module: Module<FFT64> = Module::<FFT64>::new(n);
-        let basek: usize = 17;
-        let size: usize = 5;
-        for k in [1, basek / 2, size * basek - 5] {
-            let mut a: VecZnx<_> = module.new_vec_znx(2, size);
-            let mut source = Source::new([0u8; 32]);
-            let raw: &mut [i64] = a.raw_mut();
-            raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
-            (0..a.cols()).for_each(|col_i| {
-                let mut have: Vec<i64> = vec![i64::default(); n];
-                have.iter_mut().for_each(|x| {
-                    if k < 64 {
-                        *x = source.next_u64n(1 << k, (1 << k) - 1) as i64;
-                    } else {
-                        *x = source.next_i64();
-                    }
-                });
-                a.encode_vec_i64(col_i, basek, k, &have, std::cmp::min(k, 64));
-                let mut want = vec![i64::default(); n];
-                a.decode_vec_i64(col_i, basek, k, &mut want);
-                izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
-            })
-        }
-    }
-}
diff --git a/backend/src/hal/api/mat_znx.rs b/backend/src/hal/api/mat_znx.rs
new file mode 100644
index 0000000..3579e7d
--- /dev/null
+++ b/backend/src/hal/api/mat_znx.rs
@@ -0,0 +1,17 @@
+use crate::hal::layouts::MatZnxOwned;
+
+/// Allocates a [crate::hal::layouts::MatZnx].
+pub trait MatZnxAlloc {
+    fn mat_znx_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned;
+}
+
+/// Returns the size in bytes to allocate a [crate::hal::layouts::MatZnx].
+pub trait MatZnxAllocBytes {
+    fn mat_znx_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
+}
+
+/// Consumes a vector of bytes into a [crate::hal::layouts::MatZnx].
+/// The user must ensure that the bytes are memory-aligned and that their length matches [MatZnxAllocBytes].
+pub trait MatZnxFromBytes {
+    fn mat_znx_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> MatZnxOwned;
+}
diff --git a/backend/src/hal/api/mod.rs b/backend/src/hal/api/mod.rs
new file mode 100644
index 0000000..cb806c9
--- /dev/null
+++ b/backend/src/hal/api/mod.rs
@@ -0,0 +1,21 @@
+mod mat_znx;
+mod module;
+mod scalar_znx;
+mod scratch;
+mod svp_ppol;
+mod vec_znx;
+mod vec_znx_big;
+mod vec_znx_dft;
+mod vmp_pmat;
+mod znx_base;
+
+pub use mat_znx::*;
+pub use module::*;
+pub use scalar_znx::*;
+pub use scratch::*;
+pub use svp_ppol::*;
+pub use vec_znx::*;
+pub use vec_znx_big::*;
+pub use vec_znx_dft::*;
+pub use vmp_pmat::*;
+pub use znx_base::*;
diff --git a/backend/src/hal/api/module.rs b/backend/src/hal/api/module.rs
new file mode 100644
index 0000000..412c70f
--- /dev/null
+++ b/backend/src/hal/api/module.rs
@@ -0,0 +1,6 @@
+use crate::hal::layouts::{Backend, Module};
+
+/// Instantiates a new [crate::hal::layouts::Module].
+pub trait ModuleNew<B: Backend> {
+    fn new(n: u64) -> Module<B>;
+}
diff --git a/backend/src/hal/api/scalar_znx.rs b/backend/src/hal/api/scalar_znx.rs
new file mode 100644
index 0000000..cacb93b
--- /dev/null
+++ b/backend/src/hal/api/scalar_znx.rs
@@ -0,0 +1,47 @@
+use crate::hal::layouts::{ScalarZnxOwned, ScalarZnxToMut, ScalarZnxToRef};
+
+/// Allocates a [crate::hal::layouts::ScalarZnx].
+pub trait ScalarZnxAlloc {
+    fn scalar_znx_alloc(&self, cols: usize) -> ScalarZnxOwned;
+}
+
+/// Returns the size in bytes to allocate a [crate::hal::layouts::ScalarZnx].
+pub trait ScalarZnxAllocBytes {
+    fn scalar_znx_alloc_bytes(&self, cols: usize) -> usize;
+}
+
+/// Consumes a vector of bytes into a [crate::hal::layouts::ScalarZnx].
+/// The user must ensure that the bytes are memory-aligned and that their length matches [ScalarZnxAllocBytes].
+pub trait ScalarZnxFromBytes {
+    fn scalar_znx_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned;
+}
+
+/// Applies the mapping X -> X^k to a\[a_col\] and writes the result to res\[res_col\].
+pub trait ScalarZnxAutomorphism {
+    fn scalar_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: ScalarZnxToMut,
+        A: ScalarZnxToRef;
+}
+
+/// Applies the mapping X -> X^k to res\[res_col\].
+pub trait ScalarZnxAutomorphismInplace {
+    fn scalar_znx_automorphism_inplace<R>(&self, k: i64, res: &mut R, res_col: usize)
+    where
+        R: ScalarZnxToMut;
+}
+
+/// Multiplies a\[a_col\] by (X^p - 1) and writes the result to r\[r_col\].
+pub trait ScalarZnxMulXpMinusOne {
+    fn scalar_znx_mul_xp_minus_one<R, A>(&self, p: i64, r: &mut R, r_col: usize, a: &A, a_col: usize)
+    where
+        R: ScalarZnxToMut,
+        A: ScalarZnxToRef;
+}
+
+/// Multiplies res\[res_col\] by (X^p - 1).
+pub trait ScalarZnxMulXpMinusOneInplace {
+    fn scalar_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, res: &mut R, res_col: usize)
+    where
+        R: ScalarZnxToMut;
+}
diff --git a/backend/src/hal/api/scratch.rs b/backend/src/hal/api/scratch.rs
new file mode 100644
index 0000000..12b856f
--- /dev/null
+++ b/backend/src/hal/api/scratch.rs
@@ -0,0 +1,113 @@
+use crate::hal::layouts::{Backend, MatZnx, Module, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat};
+
+/// Allocates a new [crate::hal::layouts::ScratchOwned] of `size` aligned bytes.
+pub trait ScratchOwnedAlloc<B: Backend> {
+    fn alloc(size: usize) -> Self;
+}
+
+/// Borrows a slice of bytes into a [Scratch].
+pub trait ScratchOwnedBorrow<B: Backend> {
+    fn borrow(&mut self) -> &mut Scratch<B>;
+}
+
+/// Wraps a mutably borrowed byte slice into a [Scratch].
+pub trait ScratchFromBytes<B: Backend> {
+    fn from_bytes(data: &mut [u8]) -> &mut Scratch<B>;
+}
+
+/// Returns how many bytes can still be taken from the scratch.
+pub trait ScratchAvailable {
+    fn available(&self) -> usize;
+}
+
+/// Takes a slice of bytes from a [Scratch] and returns it, along with a new [Scratch] minus the taken bytes.
+pub trait TakeSlice {
+    fn take_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self);
+}
+
+/// Takes a slice of bytes from a [Scratch], wraps it into a [ScalarZnx] and returns it,
+/// as well as a new [Scratch] minus the taken bytes.
+pub trait TakeScalarZnx<B: Backend> {
+    fn take_scalar_znx(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self);
+}
+
+/// Takes a slice of bytes from a [Scratch], wraps it into a [SvpPPol] and returns it,
+/// as well as a new [Scratch] minus the taken bytes.
+pub trait TakeSvpPPol<B: Backend> {
+    fn take_svp_ppol(&mut self, module: &Module<B>, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self);
+}
+
+/// Takes a slice of bytes from a [Scratch], wraps it into a [VecZnx] and returns it,
+/// as well as a new [Scratch] minus the taken bytes.
+pub trait TakeVecZnx<B: Backend> {
+    fn take_vec_znx(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self);
+}
+
+/// Takes a slice of bytes from a [Scratch], slices it into a vector of [VecZnx] and returns it,
+/// as well as a new [Scratch] minus the taken bytes.
+pub trait TakeVecZnxSlice<B: Backend> {
+    fn take_vec_znx_slice(
+        &mut self,
+        len: usize,
+        module: &Module<B>,
+        cols: usize,
+        size: usize,
+    ) -> (Vec<VecZnx<&mut [u8]>>, &mut Self);
+}
+
+/// Takes a slice of bytes from a [Scratch], wraps it into a [VecZnxBig] and returns it,
+/// as well as a new [Scratch] minus the taken bytes.
+pub trait TakeVecZnxBig<B: Backend> {
+    fn take_vec_znx_big(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self);
+}
+
+/// Takes a slice of bytes from a [Scratch], wraps it into a [VecZnxDft] and returns it,
+/// as well as a new [Scratch] minus the taken bytes.
+pub trait TakeVecZnxDft<B: Backend> {
+    fn take_vec_znx_dft(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self);
+}
+
+/// Takes a slice of bytes from a [Scratch], slices it into a vector of [VecZnxDft] and returns it,
+/// as well as a new [Scratch] minus the taken bytes.
+pub trait TakeVecZnxDftSlice<B: Backend> {
+    fn take_vec_znx_dft_slice(
+        &mut self,
+        len: usize,
+        module: &Module<B>,
+        cols: usize,
+        size: usize,
+    ) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Self);
+}
+
+/// Takes a slice of bytes from a [Scratch], wraps it into a [VmpPMat] and returns it,
+/// as well as a new [Scratch] minus the taken bytes.
+pub trait TakeVmpPMat<B: Backend> {
+    fn take_vmp_pmat(
+        &mut self,
+        module: &Module<B>,
+        rows: usize,
+        cols_in: usize,
+        cols_out: usize,
+        size: usize,
+    ) -> (VmpPMat<&mut [u8], B>, &mut Self);
+}
+
+/// Takes a slice of bytes from a [Scratch], wraps it into a [MatZnx] and returns it,
+/// as well as a new [Scratch] minus the taken bytes.
+pub trait TakeMatZnx<B: Backend> {
+    fn take_mat_znx(
+        &mut self,
+        module: &Module<B>,
+        rows: usize,
+        cols_in: usize,
+        cols_out: usize,
+        size: usize,
+    ) -> (MatZnx<&mut [u8]>, &mut Self);
+}
+
+/// Takes a slice of bytes from a [Scratch], wraps it into the template's type and returns it,
+/// as well as a new [Scratch] minus the taken bytes.
+pub trait TakeLike<'a, B: Backend, T> {
+    type Output;
+    fn take_like(&'a mut self, template: &T) -> (Self::Output, &'a mut Self);
+}
diff --git a/backend/src/hal/api/svp_ppol.rs b/backend/src/hal/api/svp_ppol.rs
new file mode 100644
index 0000000..f500923
--- /dev/null
+++ b/backend/src/hal/api/svp_ppol.rs
@@ -0,0 +1,42 @@
+use crate::hal::layouts::{Backend, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef};
+
+/// Allocates a [crate::hal::layouts::SvpPPol].
+pub trait SvpPPolAlloc<B: Backend> {
+    fn svp_ppol_alloc(&self, cols: usize) -> SvpPPolOwned<B>;
+}
+
+/// Returns the size in bytes to allocate a [crate::hal::layouts::SvpPPol].
+pub trait SvpPPolAllocBytes {
+    fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize;
+}
+
+/// Consumes a vector of bytes into a [crate::hal::layouts::SvpPPol].
+/// The user must ensure that the bytes are memory-aligned and that their length matches [SvpPPolAllocBytes].
+pub trait SvpPPolFromBytes<B: Backend> {
+    fn svp_ppol_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B>;
+}
+
+/// Prepares a [crate::hal::layouts::ScalarZnx] into an [crate::hal::layouts::SvpPPol].
+pub trait SvpPrepare<B: Backend> {
+    fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: SvpPPolToMut<B>,
+        A: ScalarZnxToRef;
+}
+
+/// Applies a scalar-vector product between `a[a_col]` and `b[b_col]` and stores the result in `res[res_col]`.
+pub trait SvpApply<B: Backend> {
+    fn svp_apply<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
+    where
+        R: VecZnxDftToMut<B>,
+        A: SvpPPolToRef<B>,
+        C: VecZnxDftToRef<B>;
+}
+
+/// Applies a scalar-vector product between `res[res_col]` and `a[a_col]` and stores the result in `res[res_col]`.
+pub trait SvpApplyInplace<B: Backend> {
+    fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxDftToMut<B>,
+        A: SvpPPolToRef<B>;
+}
diff --git a/backend/src/hal/api/vec_znx.rs b/backend/src/hal/api/vec_znx.rs
new file mode 100644
index 0000000..413b90b
--- /dev/null
+++ b/backend/src/hal/api/vec_znx.rs
@@ -0,0 +1,369 @@
+use rand_distr::Distribution;
+use rug::Float;
+use sampling::source::Source;
+
+use crate::hal::layouts::{Backend, ScalarZnxToRef, Scratch, VecZnxOwned, VecZnxToMut, VecZnxToRef};
+
+pub trait VecZnxAlloc {
+    /// Allocates a new [crate::hal::layouts::VecZnx].
+    ///
+    /// # Arguments
+    ///
+    /// * `cols`: the number of polynomials.
+    /// * `size`: the number of small polynomials per column.
+    fn vec_znx_alloc(&self, cols: usize, size: usize) -> VecZnxOwned;
+}
+
+pub trait VecZnxFromBytes {
+    /// Instantiates a new [crate::hal::layouts::VecZnx] from a slice of bytes.
+    /// The returned [crate::hal::layouts::VecZnx] takes ownership of the slice of bytes.
+    ///
+    /// # Arguments
+    ///
+    /// * `cols`: the number of polynomials.
+    /// * `size`: the number of small polynomials per column.
+    fn vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned;
+}
+
+pub trait VecZnxAllocBytes {
+    /// Returns the number of bytes necessary to allocate a new [crate::hal::layouts::VecZnx].
+    fn vec_znx_alloc_bytes(&self, cols: usize, size: usize) -> usize;
+}
+
+pub trait VecZnxNormalizeTmpBytes {
+    /// Returns the minimum number of bytes necessary for normalization.
+    fn vec_znx_normalize_tmp_bytes(&self, n: usize) -> usize;
+}
+
+pub trait VecZnxNormalize<B: Backend> {
+    /// Normalizes the selected column of `a` and stores the result in the selected column of `res`.
+    fn vec_znx_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+}
+
+pub trait VecZnxNormalizeInplace<B: Backend> {
+    /// Normalizes the selected column of `a`.
+    fn vec_znx_normalize_inplace<A>(&self, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
+    where
+        A: VecZnxToMut;
+}
+
+pub trait VecZnxAdd {
+    /// Adds the selected column of `a` to the selected column of `b` and writes the result to the selected column of `res`.
+    fn vec_znx_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef,
+        B: VecZnxToRef;
+}
+
+pub trait VecZnxAddInplace {
+    /// Adds the selected column of `a` to the selected column of `res` and writes the result to the selected column of `res`.
+    fn vec_znx_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+}
+
+pub trait VecZnxAddScalarInplace {
+    /// Adds the selected column of `a` to the selected column and limb of `res`.
+    fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: ScalarZnxToRef;
+}
+
+pub trait VecZnxSub {
+    /// Subtracts the selected column of `b` from the selected column of `a` and writes the result to the selected column of `res`.
+    fn vec_znx_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef,
+        B: VecZnxToRef;
+}
+
+pub trait VecZnxSubABInplace {
+    /// Subtracts the selected column of `a` from the selected column of `res` in place:
+    ///
+    /// res\[res_col\] -= a\[a_col\]
+    fn vec_znx_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+}
+
+pub trait VecZnxSubBAInplace {
+    /// Subtracts the selected column of `res` from the selected column of `a` and mutates `res` in place:
+    ///
+    /// res\[res_col\] = a\[a_col\] - res\[res_col\]
+    fn vec_znx_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+}
+
+pub trait VecZnxSubScalarInplace {
+    /// Subtracts the selected column of `a` from the selected column and limb of `res`.
+    fn vec_znx_sub_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: ScalarZnxToRef;
+}
+
+pub trait VecZnxNegate {
+    /// Negates the selected column of `a` and stores the result in `res_col` of `res`.
+    fn vec_znx_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+}
+
+pub trait VecZnxNegateInplace {
+    /// Negates the selected column of `a`.
+    fn vec_znx_negate_inplace<A>(&self, a: &mut A, a_col: usize)
+    where
+        A: VecZnxToMut;
+}
+
+pub trait VecZnxLshInplace {
+    /// Left-shifts all columns of `a` by k bits.
+    fn vec_znx_lsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
+    where
+        A: VecZnxToMut;
+}
+
+pub trait VecZnxRshInplace {
+    /// Right-shifts all columns of `a` by k bits.
+    fn vec_znx_rsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
+    where
+        A: VecZnxToMut;
+}
+
+pub trait VecZnxRotate {
+    /// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`.
+    fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+}
+
+pub trait VecZnxRotateInplace {
+    /// Multiplies the selected column of `a` by X^k.
+    fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
+    where
+        A: VecZnxToMut;
+}
+
+pub trait VecZnxAutomorphism {
+    /// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result in the `res_col` column of `res`.
+    fn vec_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+}
+
+pub trait VecZnxAutomorphismInplace {
+    /// Applies the automorphism X^i -> X^ik on the selected column of `a`.
+    fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
+    where
+        A: VecZnxToMut;
+}
+
+pub trait VecZnxMulXpMinusOne {
+    fn vec_znx_mul_xp_minus_one<R, A>(&self, p: i64, r: &mut R, r_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+}
+
+pub trait VecZnxMulXpMinusOneInplace {
+    fn vec_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, r: &mut R, r_col: usize)
+    where
+        R: VecZnxToMut;
+}
+
+pub trait VecZnxSplit<B: Backend> {
+    /// Splits the selected column of `a` into subrings and copies them into the selected column of each [crate::hal::layouts::VecZnx] in `res`.
+    ///
+    /// # Panics
+    ///
+    /// This method requires that all [crate::hal::layouts::VecZnx] of `res` have the same ring degree
+    /// and that res.n() * res.len() <= a.n()
+    fn vec_znx_split<R, A>(&self, res: &mut Vec<R>, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+}
+
+pub trait VecZnxMerge {
+    /// Merges the subrings of the selected column of each [crate::hal::layouts::VecZnx] in `a` into the selected column of `res`.
+    ///
+    /// # Panics
+    ///
+    /// This method requires that all [crate::hal::layouts::VecZnx] of `a` have the same ring degree
+    /// and that a.n() * a.len() <= res.n()
+    fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: Vec<A>, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+}
+
+pub trait VecZnxSwithcDegree {
+    fn vec_znx_switch_degree<R, A>(&self, res: &mut R, res_col: usize, a: &A, col_a: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+}
+
+pub trait VecZnxCopy {
+    fn vec_znx_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+}
+
+pub trait VecZnxStd {
+    /// Returns the standard deviation of the selected column of `a`.
+    fn vec_znx_std<A>(&self, basek: usize, a: &A, a_col: usize) -> f64
+    where
+        A: VecZnxToRef;
+}
+
+pub trait VecZnxFillUniform {
+    /// Fills the first `k/basek` limbs of the selected column of `res` with uniform values in \[-2^{basek-1}, 2^{basek-1}\].
+    fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
+    where
+        R: VecZnxToMut;
+}
+
+pub trait VecZnxFillDistF64 {
+    fn vec_znx_fill_dist_f64<R, D: Distribution<f64>>(
+        &self,
+        basek: usize,
+        res: &mut R,
+        res_col: usize,
+        k: usize,
+        source: &mut Source,
+        dist: D,
+        bound: f64,
+    ) where
+        R: VecZnxToMut;
+}
+
+pub trait VecZnxAddDistF64 {
+    /// Adds a vector sampled according to the provided distribution, scaled by 2^{-k} and bounded to \[-bound, bound\].
+    fn vec_znx_add_dist_f64<R, D: Distribution<f64>>(
+        &self,
+        basek: usize,
+        res: &mut R,
+        res_col: usize,
+        k: usize,
+        source: &mut Source,
+        dist: D,
+        bound: f64,
+    ) where
+        R: VecZnxToMut;
+}
+
+pub trait VecZnxFillNormal {
+    fn vec_znx_fill_normal<R>(
+        &self,
+        basek: usize,
+        res: &mut R,
+        res_col: usize,
+        k: usize,
+        source: &mut Source,
+        sigma: f64,
+        bound: f64,
+    ) where
+        R: VecZnxToMut;
+}
+
+pub trait VecZnxAddNormal {
+    /// Adds a discrete normal vector scaled by 2^{-k}, with the provided standard deviation and bounded to \[-bound, bound\].
+    fn vec_znx_add_normal(
+        &self,
+        basek: usize,
+        res: &mut R,
+        res_col: usize,
+        k: usize,
+        source: &mut Source,
+        sigma: f64,
+        bound: f64,
+    ) where
+        R: VecZnxToMut;
+}
+
+pub trait VecZnxEncodeVeci64 {
+    /// Encodes a vector of i64 on the receiver.
+    ///
+    /// # Arguments
+    ///
+    /// * `res_col`: the index of the poly on which to encode the data.
+    /// * `basek`: base two negative logarithm decomposition of the receiver.
+    /// * `k`: base two negative logarithm of the scaling of the data.
+    /// * `data`: data to encode on the receiver.
+    /// * `log_max`: base two logarithm of the infinity norm of the input data.
+    fn encode_vec_i64(&self, basek: usize, res: &mut R, res_col: usize, k: usize, data: &[i64], log_max: usize)
+    where
+        R: VecZnxToMut;
+}
+
+pub trait VecZnxEncodeCoeffsi64 {
+    /// Encodes a single i64 on the receiver at the given index.
+    ///
+    /// # Arguments
+    ///
+    /// * `res_col`: the index of the poly on which to encode the data.
+    /// * `basek`: base two negative logarithm decomposition of the receiver.
+    /// * `k`: base two negative logarithm of the scaling of the data.
+    /// * `i`: index of the coefficient on which to encode the data.
+    /// * `data`: data to encode on the receiver.
+    /// * `log_max`: base two logarithm of the infinity norm of the input data.
+    fn encode_coeff_i64(&self, basek: usize, res: &mut R, res_col: usize, k: usize, i: usize, data: i64, log_max: usize)
+    where
+        R: VecZnxToMut;
+}
+
+pub trait VecZnxDecodeVeci64 {
+    /// Decodes a vector of i64 from the receiver.
+    ///
+    /// # Arguments
+    ///
+    /// * `res_col`: the index of the poly from which to decode the data.
+    /// * `basek`: base two negative logarithm decomposition of the receiver.
+    /// * `k`: base two negative logarithm of the scaling of the data.
+    /// * `data`: buffer receiving the decoded data.
+    fn decode_vec_i64(&self, basek: usize, res: &R, res_col: usize, k: usize, data: &mut [i64])
+    where
+        R: VecZnxToRef;
+}
+
+pub trait VecZnxDecodeCoeffsi64 {
+    /// Decodes a single i64 from the receiver at the given index.
+    ///
+    /// # Arguments
+    ///
+    /// * `res_col`: the index of the poly from which to decode the data.
+    /// * `basek`: base two negative logarithm decomposition of the receiver.
+    /// * `k`: base two negative logarithm of the scaling of the data.
+    /// * `i`: index of the coefficient to decode.
+    fn decode_coeff_i64(&self, basek: usize, res: &R, res_col: usize, k: usize, i: usize) -> i64
+    where
+        R: VecZnxToRef;
+}
+
+pub trait VecZnxDecodeVecFloat {
+    /// Decodes a vector of Float from the receiver.
+    ///
+    /// # Arguments
+    /// * `col_i`: the index of the poly from which to decode the data.
+    /// * `basek`: base two negative logarithm decomposition of the receiver.
+    /// * `data`: buffer receiving the decoded data.
+    fn decode_vec_float(&self, basek: usize, res: &R, col_i: usize, data: &mut [Float])
+    where
+        R: VecZnxToRef;
+}
diff --git a/backend/src/hal/api/vec_znx_big.rs b/backend/src/hal/api/vec_znx_big.rs
new file mode 100644
index 0000000..7a56dc6
--- /dev/null
+++ b/backend/src/hal/api/vec_znx_big.rs
@@ -0,0 +1,214 @@
+use rand_distr::Distribution;
+use sampling::source::Source;
+
+use crate::hal::layouts::{Backend, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef};
+
+/// Allocates a [crate::hal::layouts::VecZnxBig].
+pub trait VecZnxBigAlloc {
+    fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned;
+}
+
+/// Returns the size in bytes to allocate a [crate::hal::layouts::VecZnxBig].
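+///
+/// Together with [VecZnxBigFromBytes] this supports caller-managed buffers.
+/// Sketch (illustrative; the caller must uphold the alignment contract stated
+/// on [VecZnxBigFromBytes]):
+///
+/// ```ignore
+/// let len = module.vec_znx_big_alloc_bytes(cols, size);
+/// let big = module.vec_znx_big_from_bytes(cols, size, vec![0u8; len]);
+/// ```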
+pub trait VecZnxBigAllocBytes {
+    fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize;
+}
+
+/// Consumes a vector of bytes into a [crate::hal::layouts::VecZnxBig].
+/// The user must ensure that `bytes` is memory-aligned and that its length is equal to [VecZnxBigAllocBytes].
+pub trait VecZnxBigFromBytes {
+    fn vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned;
+}
+
+/// Adds a discrete normal distribution on `res`.
+///
+/// # Arguments
+/// * `basek`: base two logarithm of the bivariate representation.
+/// * `res`: receiver.
+/// * `res_col`: column of the receiver on which the operation is performed/stored.
+/// * `k`: base two negative logarithm of the scaling of the added noise.
+/// * `source`: random coin source.
+/// * `sigma`: standard deviation of the discrete normal distribution.
+/// * `bound`: rejection sampling bound.
+pub trait VecZnxBigAddNormal {
+    fn vec_znx_big_add_normal>(
+        &self,
+        basek: usize,
+        res: &mut R,
+        res_col: usize,
+        k: usize,
+        source: &mut Source,
+        sigma: f64,
+        bound: f64,
+    );
+}
+
+pub trait VecZnxBigFillNormal {
+    fn vec_znx_big_fill_normal>(
+        &self,
+        basek: usize,
+        res: &mut R,
+        res_col: usize,
+        k: usize,
+        source: &mut Source,
+        sigma: f64,
+        bound: f64,
+    );
+}
+
+pub trait VecZnxBigFillDistF64 {
+    fn vec_znx_big_fill_dist_f64, D: Distribution>(
+        &self,
+        basek: usize,
+        res: &mut R,
+        res_col: usize,
+        k: usize,
+        source: &mut Source,
+        dist: D,
+        bound: f64,
+    );
+}
+
+pub trait VecZnxBigAddDistF64 {
+    fn vec_znx_big_add_dist_f64, D: Distribution>(
+        &self,
+        basek: usize,
+        res: &mut R,
+        res_col: usize,
+        k: usize,
+        source: &mut Source,
+        dist: D,
+        bound: f64,
+    );
+}
+
+pub trait VecZnxBigAdd {
+    /// Adds the selected column of `a` to the selected column of `b` and stores the result in `res`.
+    fn vec_znx_big_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxBigToRef,
+        C: VecZnxBigToRef;
+}
+
+pub trait VecZnxBigAddInplace {
+    /// Adds the selected column of `a` to the selected column of `res` and stores the result in `res`.
+    fn vec_znx_big_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxBigToRef;
+}
+
+pub trait VecZnxBigAddSmall {
+    /// Adds the selected column of the small `b` to the selected column of the big `a` and stores the result in `res`.
+    fn vec_znx_big_add_small(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxBigToRef,
+        C: VecZnxToRef;
+}
+
+pub trait VecZnxBigAddSmallInplace {
+    /// Adds the selected column of the small `a` to the selected column of `res` and stores the result in `res`.
+    fn vec_znx_big_add_small_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxToRef;
+}
+
+pub trait VecZnxBigSub {
+    /// Subtracts the selected column of `b` from the selected column of `a` and stores the result in `res`.
+    fn vec_znx_big_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxBigToRef,
+        C: VecZnxBigToRef;
+}
+
+pub trait VecZnxBigSubABInplace {
+    /// Subtracts the selected column of `a` from the selected column of `res` and stores the result in `res`.
+    fn vec_znx_big_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxBigToRef;
+}
+
+pub trait VecZnxBigSubBAInplace {
+    /// Subtracts the selected column of `res` from the selected column of `a` and stores the result in `res`.
+    fn vec_znx_big_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxBigToRef;
+}
+
+pub trait VecZnxBigSubSmallA {
+    /// Subtracts the selected column of the big `b` from the selected column of the small `a` and stores the result in `res`.
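+    ///
+    /// Shape sketch (names illustrative; `a_small` is a [crate::hal::layouts::VecZnx],
+    /// `b_big` and `res` are [crate::hal::layouts::VecZnxBig]):
+    ///
+    /// ```ignore
+    /// // res[0] = a_small[0] - b_big[0]
+    /// module.vec_znx_big_sub_small_a(&mut res, 0, &a_small, 0, &b_big, 0);
+    /// ```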
+    fn vec_znx_big_sub_small_a(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxToRef,
+        C: VecZnxBigToRef;
+}
+
+pub trait VecZnxBigSubSmallAInplace {
+    /// Subtracts the selected column of the small `a` from the selected column of `res` and stores the result in `res`.
+    fn vec_znx_big_sub_small_a_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxToRef;
+}
+
+pub trait VecZnxBigSubSmallB {
+    /// Subtracts the selected column of the small `b` from the selected column of the big `a` and stores the result in `res`.
+    fn vec_znx_big_sub_small_b(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxBigToRef,
+        C: VecZnxToRef;
+}
+
+pub trait VecZnxBigSubSmallBInplace {
+    /// Subtracts the selected column of `res` from the selected column of the small `a` and stores the result in `res`.
+    fn vec_znx_big_sub_small_b_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxToRef;
+}
+
+pub trait VecZnxBigNegateInplace {
+    /// Negates the selected column of `a`.
+    fn vec_znx_big_negate_inplace(&self, a: &mut A, a_col: usize)
+    where
+        A: VecZnxBigToMut;
+}
+
+pub trait VecZnxBigNormalizeTmpBytes {
+    /// Returns the scratch space size in bytes required by [VecZnxBigNormalize].
+    fn vec_znx_big_normalize_tmp_bytes(&self, n: usize) -> usize;
+}
+
+pub trait VecZnxBigNormalize {
+    /// Normalizes the selected column of `a` and stores the result in the selected column of `res`.
+    fn vec_znx_big_normalize(
+        &self,
+        basek: usize,
+        res: &mut R,
+        res_col: usize,
+        a: &A,
+        a_col: usize,
+        scratch: &mut Scratch,
+    ) where
+        R: VecZnxToMut,
+        A: VecZnxBigToRef;
+}
+
+pub trait VecZnxBigAutomorphism {
+    /// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result in `res`.
+    fn vec_znx_big_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxBigToRef;
+}
+
+pub trait VecZnxBigAutomorphismInplace {
+    /// Applies the automorphism X^i -> X^ik on the selected column of `a`, in place.
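+    ///
+    /// For power-of-two cyclotomics, X -> X^k is a ring automorphism only for
+    /// odd `k`; k = -1 is the usual conjugation. Usage sketch (names illustrative):
+    ///
+    /// ```ignore
+    /// // Map X -> X^{-1} on column 0 of `a_big`.
+    /// module.vec_znx_big_automorphism_inplace(-1, &mut a_big, 0);
+    /// ```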
+ fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxBigToMut; +} diff --git a/backend/src/hal/api/vec_znx_dft.rs b/backend/src/hal/api/vec_znx_dft.rs new file mode 100644 index 0000000..4aeac2f --- /dev/null +++ b/backend/src/hal/api/vec_znx_dft.rs @@ -0,0 +1,96 @@ +use crate::hal::layouts::{ + Backend, Data, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, +}; + +pub trait VecZnxDftAlloc { + fn vec_znx_dft_alloc(&self, cols: usize, size: usize) -> VecZnxDftOwned; +} + +pub trait VecZnxDftFromBytes { + fn vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned; +} + +pub trait VecZnxDftAllocBytes { + fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize; +} + +pub trait VecZnxDftToVecZnxBigTmpBytes { + fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self) -> usize; +} + +pub trait VecZnxDftToVecZnxBig { + fn vec_znx_dft_to_vec_znx_big(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where + R: VecZnxBigToMut, + A: VecZnxDftToRef; +} + +pub trait VecZnxDftToVecZnxBigTmpA { + fn vec_znx_dft_to_vec_znx_big_tmp_a(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxDftToMut; +} + +pub trait VecZnxDftToVecZnxBigConsume { + fn vec_znx_dft_to_vec_znx_big_consume(&self, a: VecZnxDft) -> VecZnxBig + where + VecZnxDft: VecZnxDftToMut; +} + +pub trait VecZnxDftAdd { + fn vec_znx_dft_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + D: VecZnxDftToRef; +} + +pub trait VecZnxDftAddInplace { + fn vec_znx_dft_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef; +} + +pub trait VecZnxDftSub { + fn vec_znx_dft_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + D: VecZnxDftToRef; +} + +pub trait VecZnxDftSubABInplace { + fn vec_znx_dft_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef; +} + +pub trait VecZnxDftSubBAInplace { + fn vec_znx_dft_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef; +} + +pub trait VecZnxDftCopy { + fn vec_znx_dft_copy(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef; +} + +pub trait VecZnxDftFromVecZnx { + fn vec_znx_dft_from_vec_znx(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxToRef; +} + +pub trait VecZnxDftZero { + fn vec_znx_dft_zero(&self, res: &mut R) + where + R: VecZnxDftToMut; +} diff --git a/backend/src/hal/api/vmp_pmat.rs b/backend/src/hal/api/vmp_pmat.rs new file mode 100644 index 0000000..8b0ade7 --- /dev/null +++ b/backend/src/hal/api/vmp_pmat.rs @@ -0,0 +1,90 @@ +use crate::hal::layouts::{ + Backend, MatZnxToRef, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef, +}; + +pub trait VmpPMatAlloc { + fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned; +} + +pub trait VmpPMatAllocBytes { + fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; +} + +pub trait VmpPMatFromBytes { + fn 
vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec) -> VmpPMatOwned;
+}
+
+pub trait VmpPrepareTmpBytes {
+    /// Returns the scratch space size in bytes required by [VmpPMatPrepare].
+    fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
+}
+
+pub trait VmpPMatPrepare {
+    /// Prepares a [crate::hal::layouts::MatZnx] into a [crate::hal::layouts::VmpPMat].
+    fn vmp_prepare(&self, res: &mut R, a: &A, scratch: &mut Scratch)
+    where
+        R: VmpPMatToMut,
+        A: MatZnxToRef;
+}
+
+pub trait VmpApplyTmpBytes {
+    fn vmp_apply_tmp_bytes(
+        &self,
+        res_size: usize,
+        a_size: usize,
+        b_rows: usize,
+        b_cols_in: usize,
+        b_cols_out: usize,
+        b_size: usize,
+    ) -> usize;
+}
+
+pub trait VmpApply {
+    /// Applies the vector matrix product [crate::hal::layouts::VecZnxDft] x [crate::hal::layouts::VmpPMat].
+    ///
+    /// A vector matrix product is numerically equivalent to a sum of [crate::hal::api::SvpApply],
+    /// where each [crate::hal::layouts::SvpPPol] is a limb of the input [crate::hal::layouts::VecZnx] in DFT,
+    /// and each vector a [crate::hal::layouts::VecZnxDft] (row) of the [crate::hal::layouts::VmpPMat].
+    ///
+    /// As such, given an input [crate::hal::layouts::VecZnx] of size `i` and a [crate::hal::layouts::VmpPMat] of `i` rows and
+    /// size `j`, the output is a [crate::hal::layouts::VecZnx] of size `j`.
+    ///
+    /// If there is a mismatch between the dimensions, the largest valid ones are used.
+    ///
+    /// ```text
+    /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
+    ///             |h i j|
+    ///             |k l m|
+    /// ```
+    /// where each element is a [crate::hal::layouts::VecZnxDft].
+    ///
+    /// # Arguments
+    ///
+    /// * `res`: the output of the vector matrix product, as a [crate::hal::layouts::VecZnxDft].
+    /// * `a`: the left operand [crate::hal::layouts::VecZnxDft] of the vector matrix product.
+    /// * `b`: the right operand [crate::hal::layouts::VmpPMat] of the vector matrix product.
+    /// * `scratch`: scratch space, whose size can be obtained with [VmpApplyTmpBytes::vmp_apply_tmp_bytes].
+    fn vmp_apply(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch)
+    where
+        R: VecZnxDftToMut,
+        A: VecZnxDftToRef,
+        C: VmpPMatToRef;
+}
+
+pub trait VmpApplyAddTmpBytes {
+    fn vmp_apply_add_tmp_bytes(
+        &self,
+        res_size: usize,
+        a_size: usize,
+        b_rows: usize,
+        b_cols_in: usize,
+        b_cols_out: usize,
+        b_size: usize,
+    ) -> usize;
+}
+
+pub trait VmpApplyAdd {
+    /// Same as [VmpApply::vmp_apply], but accumulates the result onto `res` instead of overwriting it.
+    fn vmp_apply_add(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch)
+    where
+        R: VecZnxDftToMut,
+        A: VecZnxDftToRef,
+        C: VmpPMatToRef;
+}
diff --git a/backend/src/znx_base.rs b/backend/src/hal/api/znx_base.rs
similarity index 51%
rename from backend/src/znx_base.rs
rename to backend/src/hal/api/znx_base.rs
index fa7ec49..f354fe7 100644
--- a/backend/src/znx_base.rs
+++ b/backend/src/hal/api/znx_base.rs
@@ -1,4 +1,4 @@
-use itertools::izip;
+use crate::hal::layouts::{Data, DataMut, DataRef};
 use rand_distr::num_traits::Zero;
 
 pub trait ZnxInfos {
@@ -32,7 +32,7 @@ pub trait ZnxSliceSize {
 }
 
 pub trait DataView {
-    type D;
+    type D: Data;
     fn data(&self) -> &Self::D;
 }
@@ -40,8 +40,8 @@ pub trait DataViewMut: DataView {
     fn data_mut(&mut self) -> &mut Self::D;
 }
 
-pub trait ZnxView: ZnxInfos + DataView> {
-    type Scalar: Copy;
+pub trait ZnxView: ZnxInfos + DataView {
+    type Scalar: Copy + Zero;
 
     /// Returns a non-mutable pointer to the underlying coefficients array.
fn as_ptr(&self) -> *const Self::Scalar { @@ -57,8 +57,8 @@ pub trait ZnxView: ZnxInfos + DataView> { fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar { #[cfg(debug_assertions)] { - assert!(i < self.cols(), "{} >= {}", i, self.cols()); - assert!(j < self.size(), "{} >= {}", j, self.size()); + assert!(i < self.cols(), "cols: {} >= {}", i, self.cols()); + assert!(j < self.size(), "size: {} >= {}", j, self.size()); } let offset: usize = self.n() * (j * self.cols() + i); unsafe { self.as_ptr().add(offset) } @@ -70,7 +70,7 @@ pub trait ZnxView: ZnxInfos + DataView> { } } -pub trait ZnxViewMut: ZnxView + DataViewMut> { +pub trait ZnxViewMut: ZnxView + DataViewMut { /// Returns a mutable pointer to the underlying coefficients array. fn as_mut_ptr(&mut self) -> *mut Self::Scalar { self.data_mut().as_mut().as_mut_ptr() as *mut Self::Scalar @@ -85,8 +85,8 @@ pub trait ZnxViewMut: ZnxView + DataViewMut> { fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar { #[cfg(debug_assertions)] { - assert!(i < self.cols(), "{} >= {}", i, self.cols()); - assert!(j < self.size(), "{} >= {}", j, self.size()); + assert!(i < self.cols(), "cols: {} >= {}", i, self.cols()); + assert!(j < self.size(), "size: {} >= {}", j, self.size()); } let offset: usize = self.n() * (j * self.cols() + i); unsafe { self.as_mut_ptr().add(offset) } @@ -99,101 +99,12 @@ pub trait ZnxViewMut: ZnxView + DataViewMut> { } //(Jay)Note: Can't provide blanket impl. of ZnxView because Scalar is not known -impl ZnxViewMut for T where T: ZnxView + DataViewMut> {} +impl ZnxViewMut for T where T: ZnxView + DataViewMut {} -pub trait ZnxZero: ZnxViewMut + ZnxSliceSize +pub trait ZnxZero where Self: Sized, { - fn zero(&mut self) { - unsafe { - std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * self.poly_count()); - } - } - - fn zero_at(&mut self, i: usize, j: usize) { - unsafe { - std::ptr::write_bytes(self.at_mut_ptr(i, j), 0, self.n()); - } - } -} - -// Blanket implementations -impl ZnxZero for T where T: ZnxViewMut + ZnxSliceSize {} // WARNING should not work for mat_znx_dft but it does - -use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub}; - -use crate::Scratch; -pub trait Integer: - Copy - + Default - + PartialEq - + PartialOrd - + Add - + Sub - + Mul - + Div - + Neg - + Shl - + Shr - + AddAssign -{ - const BITS: u32; -} - -impl Integer for i64 { - const BITS: u32 = 64; -} - -impl Integer for i128 { - const BITS: u32 = 128; -} - -//(Jay)Note: `rsh` impl. 
ignores the column -pub fn rsh(k: usize, basek: usize, a: &mut V, _a_col: usize, scratch: &mut Scratch) -where - V::Scalar: From + Integer + Zero, -{ - let n: usize = a.n(); - let _size: usize = a.size(); - let cols: usize = a.cols(); - - let size: usize = a.size(); - let steps: usize = k / basek; - - a.raw_mut().rotate_right(n * steps * cols); - (0..cols).for_each(|i| { - (0..steps).for_each(|j| { - a.zero_at(i, j); - }) - }); - - let k_rem: usize = k % basek; - - if k_rem != 0 { - let (carry, _) = scratch.tmp_slice::(rsh_tmp_bytes::(n)); - - unsafe { - std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::()); - } - - let basek_t = V::Scalar::from(basek); - let shift = V::Scalar::from(V::Scalar::BITS as usize - k_rem); - let k_rem_t = V::Scalar::from(k_rem); - - (0..cols).for_each(|i| { - (steps..size).for_each(|j| { - izip!(carry.iter_mut(), a.at_mut(i, j).iter_mut()).for_each(|(ci, xi)| { - *xi += *ci << basek_t; - *ci = (*xi << shift) >> shift; - *xi = (*xi - *ci) >> k_rem_t; - }); - }); - carry.iter_mut().for_each(|r| *r = V::Scalar::zero()); - }) - } -} - -pub fn rsh_tmp_bytes(n: usize) -> usize { - n * std::mem::size_of::() + fn zero(&mut self); + fn zero_at(&mut self, i: usize, j: usize); } diff --git a/backend/src/hal/delegates/mat_znx.rs b/backend/src/hal/delegates/mat_znx.rs new file mode 100644 index 0000000..1f63cae --- /dev/null +++ b/backend/src/hal/delegates/mat_znx.rs @@ -0,0 +1,32 @@ +use crate::hal::{ + api::{MatZnxAlloc, MatZnxAllocBytes, MatZnxFromBytes}, + layouts::{Backend, MatZnxOwned, Module}, + oep::{MatZnxAllocBytesImpl, MatZnxAllocImpl, MatZnxFromBytesImpl}, +}; + +impl MatZnxAlloc for Module +where + B: Backend + MatZnxAllocImpl, +{ + fn mat_znx_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned { + B::mat_znx_alloc_impl(self, rows, cols_in, cols_out, size) + } +} + +impl MatZnxAllocBytes for Module +where + B: Backend + MatZnxAllocBytesImpl, +{ + fn mat_znx_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + B::mat_znx_alloc_bytes_impl(self, rows, cols_in, cols_out, size) + } +} + +impl MatZnxFromBytes for Module +where + B: Backend + MatZnxFromBytesImpl, +{ + fn mat_znx_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec) -> MatZnxOwned { + B::mat_znx_from_bytes_impl(self, rows, cols_in, cols_out, size, bytes) + } +} diff --git a/backend/src/hal/delegates/mod.rs b/backend/src/hal/delegates/mod.rs new file mode 100644 index 0000000..f02a59b --- /dev/null +++ b/backend/src/hal/delegates/mod.rs @@ -0,0 +1,9 @@ +mod mat_znx; +mod module; +mod scalar_znx; +mod scratch; +mod svp_ppol; +mod vec_znx; +mod vec_znx_big; +mod vec_znx_dft; +mod vmp_pmat; diff --git a/backend/src/hal/delegates/module.rs b/backend/src/hal/delegates/module.rs new file mode 100644 index 0000000..a9a8d24 --- /dev/null +++ b/backend/src/hal/delegates/module.rs @@ -0,0 +1,14 @@ +use crate::hal::{ + api::ModuleNew, + layouts::{Backend, Module}, + oep::ModuleNewImpl, +}; + +impl ModuleNew for Module +where + B: Backend + ModuleNewImpl, +{ + fn new(n: u64) -> Self { + B::new_impl(n) + } +} diff --git a/backend/src/hal/delegates/scalar_znx.rs b/backend/src/hal/delegates/scalar_znx.rs new file mode 100644 index 0000000..e8310a4 --- /dev/null +++ b/backend/src/hal/delegates/scalar_znx.rs @@ -0,0 +1,88 @@ +use crate::hal::{ + api::{ + ScalarZnxAlloc, ScalarZnxAllocBytes, ScalarZnxAutomorphism, ScalarZnxAutomorphismInplace, ScalarZnxFromBytes, + ScalarZnxMulXpMinusOne, 
ScalarZnxMulXpMinusOneInplace, + }, + layouts::{Backend, Module, ScalarZnxOwned, ScalarZnxToMut, ScalarZnxToRef}, + oep::{ + ScalarZnxAllocBytesImpl, ScalarZnxAllocImpl, ScalarZnxAutomorphismImpl, ScalarZnxAutomorphismInplaceIml, + ScalarZnxFromBytesImpl, ScalarZnxMulXpMinusOneImpl, ScalarZnxMulXpMinusOneInplaceImpl, + }, +}; + +impl ScalarZnxAllocBytes for Module +where + B: Backend + ScalarZnxAllocBytesImpl, +{ + fn scalar_znx_alloc_bytes(&self, cols: usize) -> usize { + B::scalar_znx_alloc_bytes_impl(self.n(), cols) + } +} + +impl ScalarZnxAlloc for Module +where + B: Backend + ScalarZnxAllocImpl, +{ + fn scalar_znx_alloc(&self, cols: usize) -> ScalarZnxOwned { + B::scalar_znx_alloc_impl(self.n(), cols) + } +} + +impl ScalarZnxFromBytes for Module +where + B: Backend + ScalarZnxFromBytesImpl, +{ + fn scalar_znx_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxOwned { + B::scalar_znx_from_bytes_impl(self.n(), cols, bytes) + } +} + +impl ScalarZnxAutomorphism for Module +where + B: Backend + ScalarZnxAutomorphismImpl, +{ + fn scalar_znx_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: ScalarZnxToMut, + A: ScalarZnxToRef, + { + B::scalar_znx_automorphism_impl(self, k, res, res_col, a, a_col); + } +} + +impl ScalarZnxAutomorphismInplace for Module +where + B: Backend + ScalarZnxAutomorphismInplaceIml, +{ + fn scalar_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: ScalarZnxToMut, + { + B::scalar_znx_automorphism_inplace_impl(self, k, a, a_col); + } +} + +impl ScalarZnxMulXpMinusOne for Module +where + B: Backend + ScalarZnxMulXpMinusOneImpl, +{ + fn scalar_znx_mul_xp_minus_one(&self, p: i64, r: &mut R, r_col: usize, a: &A, a_col: usize) + where + R: ScalarZnxToMut, + A: ScalarZnxToRef, + { + B::scalar_znx_mul_xp_minus_one_impl(self, p, r, r_col, a, a_col); + } +} + +impl ScalarZnxMulXpMinusOneInplace for Module +where + B: Backend + ScalarZnxMulXpMinusOneInplaceImpl, +{ + fn scalar_znx_mul_xp_minus_one_inplace(&self, p: i64, r: &mut R, r_col: usize) + where + R: ScalarZnxToMut, + { + B::scalar_znx_mul_xp_minus_one_inplace_impl(self, p, r, r_col); + } +} diff --git a/backend/src/hal/delegates/scratch.rs b/backend/src/hal/delegates/scratch.rs new file mode 100644 index 0000000..350c6a9 --- /dev/null +++ b/backend/src/hal/delegates/scratch.rs @@ -0,0 +1,243 @@ +use crate::hal::{ + api::{ + ScratchAvailable, ScratchFromBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, TakeLike, TakeMatZnx, TakeScalarZnx, + TakeSlice, TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice, TakeVmpPMat, + }, + layouts::{ + Backend, DataRef, MatZnx, Module, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, + }, + oep::{ + ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeLikeImpl, TakeMatZnxImpl, + TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl, + TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, + }, +}; + +impl ScratchOwnedAlloc for ScratchOwned +where + B: Backend + ScratchOwnedAllocImpl, +{ + fn alloc(size: usize) -> Self { + B::scratch_owned_alloc_impl(size) + } +} + +impl ScratchOwnedBorrow for ScratchOwned +where + B: Backend + ScratchOwnedBorrowImpl, +{ + fn borrow(&mut self) -> &mut Scratch { + B::scratch_owned_borrow_impl(self) + } +} + +impl ScratchFromBytes for Scratch +where + B: Backend + ScratchFromBytesImpl, +{ + fn from_bytes(data: &mut [u8]) -> 
&mut Scratch { + B::scratch_from_bytes_impl(data) + } +} + +impl ScratchAvailable for Scratch +where + B: Backend + ScratchAvailableImpl, +{ + fn available(&self) -> usize { + B::scratch_available_impl(self) + } +} + +impl TakeSlice for Scratch +where + B: Backend + TakeSliceImpl, +{ + fn take_slice(&mut self, len: usize) -> (&mut [T], &mut Self) { + B::take_slice_impl(self, len) + } +} + +impl TakeScalarZnx for Scratch +where + B: Backend + TakeScalarZnxImpl, +{ + fn take_scalar_znx(&mut self, module: &Module, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) { + B::take_scalar_znx_impl(self, module.n(), cols) + } +} + +impl TakeSvpPPol for Scratch +where + B: Backend + TakeSvpPPolImpl, +{ + fn take_svp_ppol(&mut self, module: &Module, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self) { + B::take_svp_ppol_impl(self, module.n(), cols) + } +} + +impl TakeVecZnx for Scratch +where + B: Backend + TakeVecZnxImpl, +{ + fn take_vec_znx(&mut self, module: &Module, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) { + B::take_vec_znx_impl(self, module.n(), cols, size) + } +} + +impl TakeVecZnxSlice for Scratch +where + B: Backend + TakeVecZnxSliceImpl, +{ + fn take_vec_znx_slice( + &mut self, + len: usize, + module: &Module, + cols: usize, + size: usize, + ) -> (Vec>, &mut Self) { + B::take_vec_znx_slice_impl(self, len, module.n(), cols, size) + } +} + +impl TakeVecZnxBig for Scratch +where + B: Backend + TakeVecZnxBigImpl, +{ + fn take_vec_znx_big(&mut self, module: &Module, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self) { + B::take_vec_znx_big_impl(self, module.n(), cols, size) + } +} + +impl TakeVecZnxDft for Scratch +where + B: Backend + TakeVecZnxDftImpl, +{ + fn take_vec_znx_dft(&mut self, module: &Module, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self) { + B::take_vec_znx_dft_impl(self, module.n(), cols, size) + } +} + +impl TakeVecZnxDftSlice for Scratch +where + B: Backend + TakeVecZnxDftSliceImpl, +{ + fn take_vec_znx_dft_slice( + &mut self, + len: usize, + module: &Module, + cols: usize, + size: usize, + ) -> (Vec>, &mut Self) { + B::take_vec_znx_dft_slice_impl(self, len, module.n(), cols, size) + } +} + +impl TakeVmpPMat for Scratch +where + B: Backend + TakeVmpPMatImpl, +{ + fn take_vmp_pmat( + &mut self, + module: &Module, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + ) -> (VmpPMat<&mut [u8], B>, &mut Self) { + B::take_vmp_pmat_impl(self, module.n(), rows, cols_in, cols_out, size) + } +} + +impl TakeMatZnx for Scratch +where + B: Backend + TakeMatZnxImpl, +{ + fn take_mat_znx( + &mut self, + module: &Module, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + ) -> (MatZnx<&mut [u8]>, &mut Self) { + B::take_mat_znx_impl(self, module.n(), rows, cols_in, cols_out, size) + } +} + +impl<'a, B: Backend, D> TakeLike<'a, B, ScalarZnx> for Scratch +where + B: TakeLikeImpl<'a, B, ScalarZnx, Output = ScalarZnx<&'a mut [u8]>>, + D: DataRef, +{ + type Output = ScalarZnx<&'a mut [u8]>; + fn take_like(&'a mut self, template: &ScalarZnx) -> (Self::Output, &'a mut Self) { + B::take_like_impl(self, template) + } +} + +impl<'a, B: Backend, D> TakeLike<'a, B, SvpPPol> for Scratch +where + B: TakeLikeImpl<'a, B, SvpPPol, Output = SvpPPol<&'a mut [u8], B>>, + D: DataRef, +{ + type Output = SvpPPol<&'a mut [u8], B>; + fn take_like(&'a mut self, template: &SvpPPol) -> (Self::Output, &'a mut Self) { + B::take_like_impl(self, template) + } +} + +impl<'a, B: Backend, D> TakeLike<'a, B, VecZnx> for Scratch +where + B: 
TakeLikeImpl<'a, B, VecZnx, Output = VecZnx<&'a mut [u8]>>, + D: DataRef, +{ + type Output = VecZnx<&'a mut [u8]>; + fn take_like(&'a mut self, template: &VecZnx) -> (Self::Output, &'a mut Self) { + B::take_like_impl(self, template) + } +} + +impl<'a, B: Backend, D> TakeLike<'a, B, VecZnxBig> for Scratch +where + B: TakeLikeImpl<'a, B, VecZnxBig, Output = VecZnxBig<&'a mut [u8], B>>, + D: DataRef, +{ + type Output = VecZnxBig<&'a mut [u8], B>; + fn take_like(&'a mut self, template: &VecZnxBig) -> (Self::Output, &'a mut Self) { + B::take_like_impl(self, template) + } +} + +impl<'a, B: Backend, D> TakeLike<'a, B, VecZnxDft> for Scratch +where + B: TakeLikeImpl<'a, B, VecZnxDft, Output = VecZnxDft<&'a mut [u8], B>>, + D: DataRef, +{ + type Output = VecZnxDft<&'a mut [u8], B>; + fn take_like(&'a mut self, template: &VecZnxDft) -> (Self::Output, &'a mut Self) { + B::take_like_impl(self, template) + } +} + +impl<'a, B: Backend, D> TakeLike<'a, B, MatZnx> for Scratch +where + B: TakeLikeImpl<'a, B, MatZnx, Output = MatZnx<&'a mut [u8]>>, + D: DataRef, +{ + type Output = MatZnx<&'a mut [u8]>; + fn take_like(&'a mut self, template: &MatZnx) -> (Self::Output, &'a mut Self) { + B::take_like_impl(self, template) + } +} + +impl<'a, B: Backend, D> TakeLike<'a, B, VmpPMat> for Scratch +where + B: TakeLikeImpl<'a, B, VmpPMat, Output = VmpPMat<&'a mut [u8], B>>, + D: DataRef, +{ + type Output = VmpPMat<&'a mut [u8], B>; + fn take_like(&'a mut self, template: &VmpPMat) -> (Self::Output, &'a mut Self) { + B::take_like_impl(self, template) + } +} diff --git a/backend/src/hal/delegates/svp_ppol.rs b/backend/src/hal/delegates/svp_ppol.rs new file mode 100644 index 0000000..e968f8d --- /dev/null +++ b/backend/src/hal/delegates/svp_ppol.rs @@ -0,0 +1,72 @@ +use crate::hal::{ + api::{SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPPolFromBytes, SvpPrepare}, + layouts::{Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef}, + oep::{SvpApplyImpl, SvpApplyInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl}, +}; + +impl SvpPPolFromBytes for Module +where + B: Backend + SvpPPolFromBytesImpl, +{ + fn svp_ppol_from_bytes(&self, cols: usize, bytes: Vec) -> SvpPPolOwned { + B::svp_ppol_from_bytes_impl(self.n(), cols, bytes) + } +} + +impl SvpPPolAlloc for Module +where + B: Backend + SvpPPolAllocImpl, +{ + fn svp_ppol_alloc(&self, cols: usize) -> SvpPPolOwned { + B::svp_ppol_alloc_impl(self.n(), cols) + } +} + +impl SvpPPolAllocBytes for Module +where + B: Backend + SvpPPolAllocBytesImpl, +{ + fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize { + B::svp_ppol_alloc_bytes_impl(self.n(), cols) + } +} + +impl SvpPrepare for Module +where + B: Backend + SvpPrepareImpl, +{ + fn svp_prepare(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: SvpPPolToMut, + A: ScalarZnxToRef, + { + B::svp_prepare_impl(self, res, res_col, a, a_col); + } +} + +impl SvpApply for Module +where + B: Backend + SvpApplyImpl, +{ + fn svp_apply(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) + where + R: VecZnxDftToMut, + A: SvpPPolToRef, + C: VecZnxDftToRef, + { + B::svp_apply_impl(self, res, res_col, a, a_col, b, b_col); + } +} + +impl SvpApplyInplace for Module +where + B: Backend + SvpApplyInplaceImpl, +{ + fn svp_apply_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: SvpPPolToRef, + { + B::svp_apply_inplace_impl(self, res, res_col, a, a_col); + 
} +} diff --git a/backend/src/hal/delegates/vec_znx.rs b/backend/src/hal/delegates/vec_znx.rs new file mode 100644 index 0000000..b10d8cc --- /dev/null +++ b/backend/src/hal/delegates/vec_znx.rs @@ -0,0 +1,518 @@ +use sampling::source::Source; + +use crate::hal::{ + api::{ + VecZnxAdd, VecZnxAddDistF64, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAlloc, VecZnxAllocBytes, + VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxCopy, VecZnxDecodeCoeffsi64, VecZnxDecodeVecFloat, + VecZnxDecodeVeci64, VecZnxEncodeCoeffsi64, VecZnxEncodeVeci64, VecZnxFillDistF64, VecZnxFillNormal, VecZnxFillUniform, + VecZnxFromBytes, VecZnxLshInplace, VecZnxMerge, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, + VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, + VecZnxRshInplace, VecZnxSplit, VecZnxStd, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalarInplace, + VecZnxSwithcDegree, + }, + layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxOwned, VecZnxToMut, VecZnxToRef}, + oep::{ + VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl, + VecZnxAllocBytesImpl, VecZnxAllocImpl, VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, + VecZnxDecodeCoeffsi64Impl, VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl, + VecZnxEncodeVeci64Impl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxFromBytesImpl, + VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, + VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, + VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxStdImpl, VecZnxSubABInplaceImpl, + VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl, + }, +}; + +impl VecZnxAlloc for Module +where + B: Backend + VecZnxAllocImpl, +{ + fn vec_znx_alloc(&self, cols: usize, size: usize) -> VecZnxOwned { + B::vec_znx_alloc_impl(self.n(), cols, size) + } +} + +impl VecZnxFromBytes for Module +where + B: Backend + VecZnxFromBytesImpl, +{ + fn vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxOwned { + B::vec_znx_from_bytes_impl(self.n(), cols, size, bytes) + } +} + +impl VecZnxAllocBytes for Module +where + B: Backend + VecZnxAllocBytesImpl, +{ + fn vec_znx_alloc_bytes(&self, cols: usize, size: usize) -> usize { + B::vec_znx_alloc_bytes_impl(self.n(), cols, size) + } +} + +impl VecZnxNormalizeTmpBytes for Module +where + B: Backend + VecZnxNormalizeTmpBytesImpl, +{ + fn vec_znx_normalize_tmp_bytes(&self, n: usize) -> usize { + B::vec_znx_normalize_tmp_bytes_impl(self, n) + } +} + +impl VecZnxNormalize for Module +where + B: Backend + VecZnxNormalizeImpl, +{ + fn vec_znx_normalize(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + B::vec_znx_normalize_impl(self, basek, res, res_col, a, a_col, scratch) + } +} + +impl VecZnxNormalizeInplace for Module +where + B: Backend + VecZnxNormalizeInplaceImpl, +{ + fn vec_znx_normalize_inplace(&self, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + where + A: VecZnxToMut, + { + B::vec_znx_normalize_inplace_impl(self, basek, a, a_col, scratch) + } +} + +impl VecZnxAdd for Module +where + B: Backend + VecZnxAddImpl, +{ + fn 
vec_znx_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + C: VecZnxToRef, + { + B::vec_znx_add_impl(self, res, res_col, a, a_col, b, b_col) + } +} + +impl VecZnxAddInplace for Module +where + B: Backend + VecZnxAddInplaceImpl, +{ + fn vec_znx_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + B::vec_znx_add_inplace_impl(self, res, res_col, a, a_col) + } +} + +impl VecZnxAddScalarInplace for Module +where + B: Backend + VecZnxAddScalarInplaceImpl, +{ + fn vec_znx_add_scalar_inplace(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: ScalarZnxToRef, + { + B::vec_znx_add_scalar_inplace_impl(self, res, res_col, res_limb, a, a_col) + } +} + +impl VecZnxSub for Module +where + B: Backend + VecZnxSubImpl, +{ + fn vec_znx_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + C: VecZnxToRef, + { + B::vec_znx_sub_impl(self, res, res_col, a, a_col, b, b_col) + } +} + +impl VecZnxSubABInplace for Module +where + B: Backend + VecZnxSubABInplaceImpl, +{ + fn vec_znx_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + B::vec_znx_sub_ab_inplace_impl(self, res, res_col, a, a_col) + } +} + +impl VecZnxSubBAInplace for Module +where + B: Backend + VecZnxSubBAInplaceImpl, +{ + fn vec_znx_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + B::vec_znx_sub_ba_inplace_impl(self, res, res_col, a, a_col) + } +} + +impl VecZnxSubScalarInplace for Module +where + B: Backend + VecZnxSubScalarInplaceImpl, +{ + fn vec_znx_sub_scalar_inplace(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: ScalarZnxToRef, + { + B::vec_znx_sub_scalar_inplace_impl(self, res, res_col, res_limb, a, a_col) + } +} + +impl VecZnxNegate for Module +where + B: Backend + VecZnxNegateImpl, +{ + fn vec_znx_negate(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + B::vec_znx_negate_impl(self, res, res_col, a, a_col) + } +} + +impl VecZnxNegateInplace for Module +where + B: Backend + VecZnxNegateInplaceImpl, +{ + fn vec_znx_negate_inplace(&self, a: &mut A, a_col: usize) + where + A: VecZnxToMut, + { + B::vec_znx_negate_inplace_impl(self, a, a_col) + } +} + +impl VecZnxLshInplace for Module +where + B: Backend + VecZnxLshInplaceImpl, +{ + fn vec_znx_lsh_inplace(&self, basek: usize, k: usize, a: &mut A) + where + A: VecZnxToMut, + { + B::vec_znx_lsh_inplace_impl(self, basek, k, a) + } +} + +impl VecZnxRshInplace for Module +where + B: Backend + VecZnxRshInplaceImpl, +{ + fn vec_znx_rsh_inplace(&self, basek: usize, k: usize, a: &mut A) + where + A: VecZnxToMut, + { + B::vec_znx_rsh_inplace_impl(self, basek, k, a) + } +} + +impl VecZnxRotate for Module +where + B: Backend + VecZnxRotateImpl, +{ + fn vec_znx_rotate(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + B::vec_znx_rotate_impl(self, k, res, res_col, a, a_col) + } +} + +impl VecZnxRotateInplace for Module +where + B: Backend + VecZnxRotateInplaceImpl, +{ + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxToMut, + { + B::vec_znx_rotate_inplace_impl(self, k, a, a_col) + } +} + +impl 
VecZnxAutomorphism for Module +where + B: Backend + VecZnxAutomorphismImpl, +{ + fn vec_znx_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + B::vec_znx_automorphism_impl(self, k, res, res_col, a, a_col) + } +} + +impl VecZnxAutomorphismInplace for Module +where + B: Backend + VecZnxAutomorphismInplaceImpl, +{ + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxToMut, + { + B::vec_znx_automorphism_inplace_impl(self, k, a, a_col) + } +} + +impl VecZnxMulXpMinusOne for Module +where + B: Backend + VecZnxMulXpMinusOneImpl, +{ + fn vec_znx_mul_xp_minus_one(&self, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + B::vec_znx_mul_xp_minus_one_impl(self, p, res, res_col, a, a_col); + } +} + +impl VecZnxMulXpMinusOneInplace for Module +where + B: Backend + VecZnxMulXpMinusOneInplaceImpl, +{ + fn vec_znx_mul_xp_minus_one_inplace(&self, p: i64, res: &mut R, res_col: usize) + where + R: VecZnxToMut, + { + B::vec_znx_mul_xp_minus_one_inplace_impl(self, p, res, res_col); + } +} + +impl VecZnxSplit for Module +where + B: Backend + VecZnxSplitImpl, +{ + fn vec_znx_split(&self, res: &mut Vec, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + B::vec_znx_split_impl(self, res, res_col, a, a_col, scratch) + } +} + +impl VecZnxMerge for Module +where + B: Backend + VecZnxMergeImpl, +{ + fn vec_znx_merge(&self, res: &mut R, res_col: usize, a: Vec, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + B::vec_znx_merge_impl(self, res, res_col, a, a_col) + } +} + +impl VecZnxSwithcDegree for Module +where + B: Backend + VecZnxSwithcDegreeImpl, +{ + fn vec_znx_switch_degree(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + B::vec_znx_switch_degree_impl(self, res, res_col, a, a_col) + } +} + +impl VecZnxCopy for Module +where + B: Backend + VecZnxCopyImpl, +{ + fn vec_znx_copy(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + B::vec_znx_copy_impl(self, res, res_col, a, a_col) + } +} + +impl VecZnxStd for Module +where + B: Backend + VecZnxStdImpl, +{ + fn vec_znx_std(&self, basek: usize, a: &A, a_col: usize) -> f64 + where + A: VecZnxToRef, + { + B::vec_znx_std_impl(self, basek, a, a_col) + } +} + +impl VecZnxFillUniform for Module +where + B: Backend + VecZnxFillUniformImpl, +{ + fn vec_znx_fill_uniform(&self, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source) + where + R: VecZnxToMut, + { + B::vec_znx_fill_uniform_impl(self, basek, res, res_col, k, source); + } +} + +impl VecZnxFillDistF64 for Module +where + B: Backend + VecZnxFillDistF64Impl, +{ + fn vec_znx_fill_dist_f64>( + &self, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) where + R: VecZnxToMut, + { + B::vec_znx_fill_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound); + } +} + +impl VecZnxAddDistF64 for Module +where + B: Backend + VecZnxAddDistF64Impl, +{ + fn vec_znx_add_dist_f64>( + &self, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) where + R: VecZnxToMut, + { + B::vec_znx_add_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound); + } +} + +impl VecZnxFillNormal for Module +where + B: Backend + VecZnxFillNormalImpl, +{ 
+ fn vec_znx_fill_normal( + &self, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) where + R: VecZnxToMut, + { + B::vec_znx_fill_normal_impl(self, basek, res, res_col, k, source, sigma, bound); + } +} + +impl VecZnxAddNormal for Module +where + B: Backend + VecZnxAddNormalImpl, +{ + fn vec_znx_add_normal( + &self, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) where + R: VecZnxToMut, + { + B::vec_znx_add_normal_impl(self, basek, res, res_col, k, source, sigma, bound); + } +} + +impl VecZnxEncodeVeci64 for Module +where + B: Backend + VecZnxEncodeVeci64Impl, +{ + fn encode_vec_i64(&self, basek: usize, res: &mut R, res_col: usize, k: usize, data: &[i64], log_max: usize) + where + R: VecZnxToMut, + { + B::encode_vec_i64_impl(self, basek, res, res_col, k, data, log_max); + } +} + +impl VecZnxEncodeCoeffsi64 for Module +where + B: Backend + VecZnxEncodeCoeffsi64Impl, +{ + fn encode_coeff_i64(&self, basek: usize, res: &mut R, res_col: usize, k: usize, i: usize, data: i64, log_max: usize) + where + R: VecZnxToMut, + { + B::encode_coeff_i64_impl(self, basek, res, res_col, k, i, data, log_max); + } +} + +impl VecZnxDecodeVeci64 for Module +where + B: Backend + VecZnxDecodeVeci64Impl, +{ + fn decode_vec_i64(&self, basek: usize, res: &R, res_col: usize, k: usize, data: &mut [i64]) + where + R: VecZnxToRef, + { + B::decode_vec_i64_impl(self, basek, res, res_col, k, data); + } +} + +impl VecZnxDecodeCoeffsi64 for Module +where + B: Backend + VecZnxDecodeCoeffsi64Impl, +{ + fn decode_coeff_i64(&self, basek: usize, res: &R, res_col: usize, k: usize, i: usize) -> i64 + where + R: VecZnxToRef, + { + B::decode_coeff_i64_impl(self, basek, res, res_col, k, i) + } +} + +impl VecZnxDecodeVecFloat for Module +where + B: Backend + VecZnxDecodeVecFloatImpl, +{ + fn decode_vec_float(&self, basek: usize, res: &R, col_i: usize, data: &mut [rug::Float]) + where + R: VecZnxToRef, + { + B::decode_vec_float_impl(self, basek, res, col_i, data); + } +} diff --git a/backend/src/hal/delegates/vec_znx_big.rs b/backend/src/hal/delegates/vec_znx_big.rs new file mode 100644 index 0000000..378e718 --- /dev/null +++ b/backend/src/hal/delegates/vec_znx_big.rs @@ -0,0 +1,334 @@ +use rand_distr::Distribution; +use sampling::source::Source; + +use crate::hal::{ + api::{ + VecZnxBigAdd, VecZnxBigAddDistF64, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmall, VecZnxBigAddSmallInplace, + VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigFillDistF64, + VecZnxBigFillNormal, VecZnxBigFromBytes, VecZnxBigNegateInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA, VecZnxBigSubSmallAInplace, + VecZnxBigSubSmallB, VecZnxBigSubSmallBInplace, + }, + layouts::{Backend, Module, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef}, + oep::{ + VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl, + VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl, + VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl, VecZnxBigFromBytesImpl, + VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, + VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl, 
VecZnxBigSubSmallAInplaceImpl, + VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl, + }, +}; + +impl VecZnxBigAlloc for Module +where + B: Backend + VecZnxBigAllocImpl, +{ + fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned { + B::vec_znx_big_alloc_impl(self.n(), cols, size) + } +} + +impl VecZnxBigFromBytes for Module +where + B: Backend + VecZnxBigFromBytesImpl, +{ + fn vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned { + B::vec_znx_big_from_bytes_impl(self.n(), cols, size, bytes) + } +} + +impl VecZnxBigAllocBytes for Module +where + B: Backend + VecZnxBigAllocBytesImpl, +{ + fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize { + B::vec_znx_big_alloc_bytes_impl(self.n(), cols, size) + } +} + +impl VecZnxBigAddDistF64 for Module +where + B: Backend + VecZnxBigAddDistF64Impl, +{ + fn vec_znx_big_add_dist_f64, D: Distribution>( + &self, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) { + B::add_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound); + } +} + +impl VecZnxBigAddNormal for Module +where + B: Backend + VecZnxBigAddNormalImpl, +{ + fn vec_znx_big_add_normal>( + &self, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) { + B::add_normal_impl(self, basek, res, res_col, k, source, sigma, bound); + } +} + +impl VecZnxBigFillDistF64 for Module +where + B: Backend + VecZnxBigFillDistF64Impl, +{ + fn vec_znx_big_fill_dist_f64, D: Distribution>( + &self, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) { + B::fill_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound); + } +} + +impl VecZnxBigFillNormal for Module +where + B: Backend + VecZnxBigFillNormalImpl, +{ + fn vec_znx_big_fill_normal>( + &self, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) { + B::fill_normal_impl(self, basek, res, res_col, k, source, sigma, bound); + } +} + +impl VecZnxBigAdd for Module +where + B: Backend + VecZnxBigAddImpl, +{ + fn vec_znx_big_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + C: VecZnxBigToRef, + { + B::vec_znx_big_add_impl(self, res, res_col, a, a_col, b, b_col); + } +} + +impl VecZnxBigAddInplace for Module +where + B: Backend + VecZnxBigAddInplaceImpl, +{ + fn vec_znx_big_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + B::vec_znx_big_add_inplace_impl(self, res, res_col, a, a_col); + } +} + +impl VecZnxBigAddSmall for Module +where + B: Backend + VecZnxBigAddSmallImpl, +{ + fn vec_znx_big_add_small(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + C: VecZnxToRef, + { + B::vec_znx_big_add_small_impl(self, res, res_col, a, a_col, b, b_col); + } +} + +impl VecZnxBigAddSmallInplace for Module +where + B: Backend + VecZnxBigAddSmallInplaceImpl, +{ + fn vec_znx_big_add_small_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + B::vec_znx_big_add_small_inplace_impl(self, res, res_col, a, a_col); + } +} + +impl VecZnxBigSub for Module +where + B: Backend + VecZnxBigSubImpl, +{ + fn vec_znx_big_sub(&self, res: &mut R, res_col: usize, a: &A, 
a_col: usize, b: &C, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + C: VecZnxBigToRef, + { + B::vec_znx_big_sub_impl(self, res, res_col, a, a_col, b, b_col); + } +} + +impl VecZnxBigSubABInplace for Module +where + B: Backend + VecZnxBigSubABInplaceImpl, +{ + fn vec_znx_big_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + B::vec_znx_big_sub_ab_inplace_impl(self, res, res_col, a, a_col); + } +} + +impl VecZnxBigSubBAInplace for Module +where + B: Backend + VecZnxBigSubBAInplaceImpl, +{ + fn vec_znx_big_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + B::vec_znx_big_sub_ba_inplace_impl(self, res, res_col, a, a_col); + } +} + +impl VecZnxBigSubSmallA for Module +where + B: Backend + VecZnxBigSubSmallAImpl, +{ + fn vec_znx_big_sub_small_a(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + C: VecZnxBigToRef, + { + B::vec_znx_big_sub_small_a_impl(self, res, res_col, a, a_col, b, b_col); + } +} + +impl VecZnxBigSubSmallAInplace for Module +where + B: Backend + VecZnxBigSubSmallAInplaceImpl, +{ + fn vec_znx_big_sub_small_a_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + B::vec_znx_big_sub_small_a_inplace_impl(self, res, res_col, a, a_col); + } +} + +impl VecZnxBigSubSmallB for Module +where + B: Backend + VecZnxBigSubSmallBImpl, +{ + fn vec_znx_big_sub_small_b(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + C: VecZnxToRef, + { + B::vec_znx_big_sub_small_b_impl(self, res, res_col, a, a_col, b, b_col); + } +} + +impl VecZnxBigSubSmallBInplace for Module +where + B: Backend + VecZnxBigSubSmallBInplaceImpl, +{ + fn vec_znx_big_sub_small_b_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + B::vec_znx_big_sub_small_b_inplace_impl(self, res, res_col, a, a_col); + } +} + +impl VecZnxBigNegateInplace for Module +where + B: Backend + VecZnxBigNegateInplaceImpl, +{ + fn vec_znx_big_negate_inplace(&self, a: &mut A, a_col: usize) + where + A: VecZnxBigToMut, + { + B::vec_znx_big_negate_inplace_impl(self, a, a_col); + } +} + +impl VecZnxBigNormalizeTmpBytes for Module +where + B: Backend + VecZnxBigNormalizeTmpBytesImpl, +{ + fn vec_znx_big_normalize_tmp_bytes(&self, n: usize) -> usize { + B::vec_znx_big_normalize_tmp_bytes_impl(self, n) + } +} + +impl VecZnxBigNormalize for Module +where + B: Backend + VecZnxBigNormalizeImpl, +{ + fn vec_znx_big_normalize( + &self, + basek: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxBigToRef, + { + B::vec_znx_big_normalize_impl(self, basek, res, res_col, a, a_col, scratch); + } +} + +impl VecZnxBigAutomorphism for Module +where + B: Backend + VecZnxBigAutomorphismImpl, +{ + fn vec_znx_big_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + B::vec_znx_big_automorphism_impl(self, k, res, res_col, a, a_col); + } +} + +impl VecZnxBigAutomorphismInplace for Module +where + B: Backend + VecZnxBigAutomorphismInplaceImpl, +{ + fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxBigToMut, + { + 
B::vec_znx_big_automorphism_inplace_impl(self, k, a, a_col); + } +} diff --git a/backend/src/hal/delegates/vec_znx_dft.rs b/backend/src/hal/delegates/vec_znx_dft.rs new file mode 100644 index 0000000..6877dad --- /dev/null +++ b/backend/src/hal/delegates/vec_znx_dft.rs @@ -0,0 +1,196 @@ +use crate::hal::{ + api::{ + VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxDftFromBytes, + VecZnxDftFromVecZnx, VecZnxDftSub, VecZnxDftSubABInplace, VecZnxDftSubBAInplace, VecZnxDftToVecZnxBig, + VecZnxDftToVecZnxBigConsume, VecZnxDftToVecZnxBigTmpA, VecZnxDftToVecZnxBigTmpBytes, VecZnxDftZero, + }, + layouts::{ + Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, + VecZnxToRef, + }, + oep::{ + VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftCopyImpl, + VecZnxDftFromBytesImpl, VecZnxDftFromVecZnxImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl, + VecZnxDftToVecZnxBigConsumeImpl, VecZnxDftToVecZnxBigImpl, VecZnxDftToVecZnxBigTmpAImpl, + VecZnxDftToVecZnxBigTmpBytesImpl, VecZnxDftZeroImpl, + }, +}; + +impl VecZnxDftFromBytes for Module +where + B: Backend + VecZnxDftFromBytesImpl, +{ + fn vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned { + B::vec_znx_dft_from_bytes_impl(self.n(), cols, size, bytes) + } +} + +impl VecZnxDftAllocBytes for Module +where + B: Backend + VecZnxDftAllocBytesImpl, +{ + fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize { + B::vec_znx_dft_alloc_bytes_impl(self.n(), cols, size) + } +} + +impl VecZnxDftAlloc for Module +where + B: Backend + VecZnxDftAllocImpl, +{ + fn vec_znx_dft_alloc(&self, cols: usize, size: usize) -> VecZnxDftOwned { + B::vec_znx_dft_alloc_impl(self.n(), cols, size) + } +} + +impl VecZnxDftToVecZnxBigTmpBytes for Module +where + B: Backend + VecZnxDftToVecZnxBigTmpBytesImpl, +{ + fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self) -> usize { + B::vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(self) + } +} + +impl VecZnxDftToVecZnxBig for Module +where + B: Backend + VecZnxDftToVecZnxBigImpl, +{ + fn vec_znx_dft_to_vec_znx_big(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where + R: VecZnxBigToMut, + A: VecZnxDftToRef, + { + B::vec_znx_dft_to_vec_znx_big_impl(self, res, res_col, a, a_col, scratch); + } +} + +impl VecZnxDftToVecZnxBigTmpA for Module +where + B: Backend + VecZnxDftToVecZnxBigTmpAImpl, +{ + fn vec_znx_dft_to_vec_znx_big_tmp_a(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxDftToMut, + { + B::vec_znx_dft_to_vec_znx_big_tmp_a_impl(self, res, res_col, a, a_col); + } +} + +impl VecZnxDftToVecZnxBigConsume for Module +where + B: Backend + VecZnxDftToVecZnxBigConsumeImpl, +{ + fn vec_znx_dft_to_vec_znx_big_consume(&self, a: VecZnxDft) -> VecZnxBig + where + VecZnxDft: VecZnxDftToMut, + { + B::vec_znx_dft_to_vec_znx_big_consume_impl(self, a) + } +} + +impl VecZnxDftFromVecZnx for Module +where + B: Backend + VecZnxDftFromVecZnxImpl, +{ + fn vec_znx_dft_from_vec_znx(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxToRef, + { + B::vec_znx_dft_from_vec_znx_impl(self, step, offset, res, res_col, a, a_col); + } +} + +impl VecZnxDftAdd for Module +where + B: Backend + VecZnxDftAddImpl, +{ + fn vec_znx_dft_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: 
&D, b_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + D: VecZnxDftToRef, + { + B::vec_znx_dft_add_impl(self, res, res_col, a, a_col, b, b_col); + } +} + +impl VecZnxDftAddInplace for Module +where + B: Backend + VecZnxDftAddInplaceImpl, +{ + fn vec_znx_dft_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + B::vec_znx_dft_add_inplace_impl(self, res, res_col, a, a_col); + } +} + +impl VecZnxDftSub for Module +where + B: Backend + VecZnxDftSubImpl, +{ + fn vec_znx_dft_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + D: VecZnxDftToRef, + { + B::vec_znx_dft_sub_impl(self, res, res_col, a, a_col, b, b_col); + } +} + +impl VecZnxDftSubABInplace for Module +where + B: Backend + VecZnxDftSubABInplaceImpl, +{ + fn vec_znx_dft_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + B::vec_znx_dft_sub_ab_inplace_impl(self, res, res_col, a, a_col); + } +} + +impl VecZnxDftSubBAInplace for Module +where + B: Backend + VecZnxDftSubBAInplaceImpl, +{ + fn vec_znx_dft_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + B::vec_znx_dft_sub_ba_inplace_impl(self, res, res_col, a, a_col); + } +} + +impl VecZnxDftCopy for Module +where + B: Backend + VecZnxDftCopyImpl, +{ + fn vec_znx_dft_copy(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + B::vec_znx_dft_copy_impl(self, step, offset, res, res_col, a, a_col); + } +} + +impl VecZnxDftZero for Module +where + B: Backend + VecZnxDftZeroImpl, +{ + fn vec_znx_dft_zero(&self, res: &mut R) + where + R: VecZnxDftToMut, + { + B::vec_znx_dft_zero_impl(self, res); + } +} diff --git a/backend/src/hal/delegates/vmp_pmat.rs b/backend/src/hal/delegates/vmp_pmat.rs new file mode 100644 index 0000000..89fca8f --- /dev/null +++ b/backend/src/hal/delegates/vmp_pmat.rs @@ -0,0 +1,126 @@ +use crate::hal::{ + api::{ + VmpApply, VmpApplyAdd, VmpApplyAddTmpBytes, VmpApplyTmpBytes, VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatFromBytes, + VmpPMatPrepare, VmpPrepareTmpBytes, + }, + layouts::{Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef}, + oep::{ + VmpApplyAddImpl, VmpApplyAddTmpBytesImpl, VmpApplyImpl, VmpApplyTmpBytesImpl, VmpPMatAllocBytesImpl, VmpPMatAllocImpl, + VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl, + }, +}; + +impl VmpPMatAlloc for Module +where + B: Backend + VmpPMatAllocImpl, +{ + fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned { + B::vmp_pmat_alloc_impl(self.n(), rows, cols_in, cols_out, size) + } +} + +impl VmpPMatAllocBytes for Module +where + B: Backend + VmpPMatAllocBytesImpl, +{ + fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + B::vmp_pmat_alloc_bytes_impl(self.n(), rows, cols_in, cols_out, size) + } +} + +impl VmpPMatFromBytes for Module +where + B: Backend + VmpPMatFromBytesImpl, +{ + fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec) -> VmpPMatOwned { + B::vmp_pmat_from_bytes_impl(self.n(), rows, cols_in, cols_out, size, bytes) + } +} + +impl VmpPrepareTmpBytes for Module +where + B: Backend + VmpPrepareTmpBytesImpl, +{ + 
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + B::vmp_prepare_tmp_bytes_impl(self, rows, cols_in, cols_out, size) + } +} + +impl VmpPMatPrepare for Module +where + B: Backend + VmpPMatPrepareImpl, +{ + fn vmp_prepare(&self, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: VmpPMatToMut, + A: MatZnxToRef, + { + B::vmp_prepare_impl(self, res, a, scratch) + } +} + +impl VmpApplyTmpBytes for Module +where + B: Backend + VmpApplyTmpBytesImpl, +{ + fn vmp_apply_tmp_bytes( + &self, + res_size: usize, + a_size: usize, + b_rows: usize, + b_cols_in: usize, + b_cols_out: usize, + b_size: usize, + ) -> usize { + B::vmp_apply_tmp_bytes_impl( + self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size, + ) + } +} + +impl VmpApply for Module +where + B: Backend + VmpApplyImpl, +{ + fn vmp_apply(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + C: VmpPMatToRef, + { + B::vmp_apply_impl(self, res, a, b, scratch); + } +} + +impl VmpApplyAddTmpBytes for Module +where + B: Backend + VmpApplyAddTmpBytesImpl, +{ + fn vmp_apply_add_tmp_bytes( + &self, + res_size: usize, + a_size: usize, + b_rows: usize, + b_cols_in: usize, + b_cols_out: usize, + b_size: usize, + ) -> usize { + B::vmp_apply_add_tmp_bytes_impl( + self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size, + ) + } +} + +impl VmpApplyAdd for Module +where + B: Backend + VmpApplyAddImpl, +{ + fn vmp_apply_add(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + C: VmpPMatToRef, + { + B::vmp_apply_add_impl(self, res, a, b, scale, scratch); + } +} diff --git a/backend/src/hal/layouts/mat_znx.rs b/backend/src/hal/layouts/mat_znx.rs new file mode 100644 index 0000000..912d98d --- /dev/null +++ b/backend/src/hal/layouts/mat_znx.rs @@ -0,0 +1,246 @@ +use crate::{ + alloc_aligned, + hal::{ + api::{DataView, DataViewMut, ZnxInfos, ZnxSliceSize, ZnxView}, + layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, WriterTo}, + }, +}; + +#[derive(PartialEq, Eq)] +pub struct MatZnx { + data: D, + n: usize, + size: usize, + rows: usize, + cols_in: usize, + cols_out: usize, +} + +impl ZnxInfos for MatZnx { + fn cols(&self) -> usize { + self.cols_in + } + + fn rows(&self) -> usize { + self.rows + } + + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + self.size + } +} + +impl ZnxSliceSize for MatZnx { + fn sl(&self) -> usize { + self.n() * self.cols_out() + } +} + +impl DataView for MatZnx { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } +} + +impl DataViewMut for MatZnx { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +impl ZnxView for MatZnx { + type Scalar = i64; +} + +impl MatZnx { + pub fn cols_in(&self) -> usize { + self.cols_in + } + + pub fn cols_out(&self) -> usize { + self.cols_out + } +} + +impl MatZnx { + pub fn bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + rows * cols_in * VecZnx::>::alloc_bytes::(n, cols_out, size) + } +} + +impl>> MatZnx { + pub(crate) fn new(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { + let data: Vec = alloc_aligned(Self::bytes_of(n, rows, cols_in, cols_out, size)); + Self { + data: data.into(), + n, + size, + rows, + cols_in, + cols_out, + } + } + + pub(crate) fn new_from_bytes( + n: usize, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + bytes: impl Into>, + ) -> Self { + let data: Vec = 
bytes.into(); + assert!(data.len() == Self::bytes_of(n, rows, cols_in, cols_out, size)); + Self { + data: data.into(), + n, + size, + rows, + cols_in, + cols_out, + } + } +} + +impl MatZnx { + pub fn at(&self, row: usize, col: usize) -> VecZnx<&[u8]> { + #[cfg(debug_assertions)] + { + assert!(row < self.rows(), "rows: {} >= {}", row, self.rows()); + assert!(col < self.cols_in(), "cols: {} >= {}", col, self.cols_in()); + } + + let self_ref: MatZnx<&[u8]> = self.to_ref(); + let nb_bytes: usize = VecZnx::>::alloc_bytes::(self.n, self.cols_out, self.size); + let start: usize = nb_bytes * self.cols() * row + col * nb_bytes; + let end: usize = start + nb_bytes; + + VecZnx { + data: &self_ref.data[start..end], + n: self.n, + cols: self.cols_out, + size: self.size, + max_size: self.size, + } + } +} + +impl MatZnx { + pub fn at_mut(&mut self, row: usize, col: usize) -> VecZnx<&mut [u8]> { + #[cfg(debug_assertions)] + { + assert!(row < self.rows(), "rows: {} >= {}", row, self.rows()); + assert!(col < self.cols_in(), "cols: {} >= {}", col, self.cols_in()); + } + + let n: usize = self.n(); + let cols_out: usize = self.cols_out(); + let cols_in: usize = self.cols_in(); + let size: usize = self.size(); + + let self_ref: MatZnx<&mut [u8]> = self.to_mut(); + let nb_bytes: usize = VecZnx::>::alloc_bytes::(n, cols_out, size); + let start: usize = nb_bytes * cols_in * row + col * nb_bytes; + let end: usize = start + nb_bytes; + + VecZnx { + data: &mut self_ref.data[start..end], + n, + cols: cols_out, + size, + max_size: size, + } + } +} + +pub type MatZnxOwned = MatZnx>; +pub type MatZnxMut<'a> = MatZnx<&'a mut [u8]>; +pub type MatZnxRef<'a> = MatZnx<&'a [u8]>; + +pub trait MatZnxToRef { + fn to_ref(&self) -> MatZnx<&[u8]>; +} + +impl MatZnxToRef for MatZnx { + fn to_ref(&self) -> MatZnx<&[u8]> { + MatZnx { + data: self.data.as_ref(), + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + } + } +} + +pub trait MatZnxToMut { + fn to_mut(&mut self) -> MatZnx<&mut [u8]>; +} + +impl MatZnxToMut for MatZnx { + fn to_mut(&mut self) -> MatZnx<&mut [u8]> { + MatZnx { + data: self.data.as_mut(), + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + } + } +} + +impl MatZnx { + pub(crate) fn from_data(data: D, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { + Self { + data, + n, + rows, + cols_in, + cols_out, + size, + } + } +} + +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +impl ReaderFrom for MatZnx { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.n = reader.read_u64::()? as usize; + self.size = reader.read_u64::()? as usize; + self.rows = reader.read_u64::()? as usize; + self.cols_in = reader.read_u64::()? as usize; + self.cols_out = reader.read_u64::()? as usize; + let len: usize = reader.read_u64::()? 
as usize; + let buf: &mut [u8] = self.data.as_mut(); + if buf.len() != len { + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + format!("self.data.len()={} != read len={}", buf.len(), len), + )); + } + reader.read_exact(&mut buf[..len])?; + Ok(()) + } +} + +impl WriterTo for MatZnx { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u64::(self.n as u64)?; + writer.write_u64::(self.size as u64)?; + writer.write_u64::(self.rows as u64)?; + writer.write_u64::(self.cols_in as u64)?; + writer.write_u64::(self.cols_out as u64)?; + let buf: &[u8] = self.data.as_ref(); + writer.write_u64::(buf.len() as u64)?; + writer.write_all(buf)?; + Ok(()) + } +} diff --git a/backend/src/hal/layouts/mod.rs b/backend/src/hal/layouts/mod.rs new file mode 100644 index 0000000..13f138f --- /dev/null +++ b/backend/src/hal/layouts/mod.rs @@ -0,0 +1,25 @@ +mod mat_znx; +mod module; +mod scalar_znx; +mod scratch; +mod serialization; +mod svp_ppol; +mod vec_znx; +mod vec_znx_big; +mod vec_znx_dft; +mod vmp_pmat; + +pub use mat_znx::*; +pub use module::*; +pub use scalar_znx::*; +pub use scratch::*; +pub use serialization::*; +pub use svp_ppol::*; +pub use vec_znx::*; +pub use vec_znx_big::*; +pub use vec_znx_dft::*; +pub use vmp_pmat::*; + +pub trait Data = PartialEq + Eq + Sized; +pub trait DataRef = Data + AsRef<[u8]>; +pub trait DataMut = DataRef + AsMut<[u8]>; diff --git a/backend/src/module.rs b/backend/src/hal/layouts/module.rs similarity index 55% rename from backend/src/module.rs rename to backend/src/hal/layouts/module.rs index f6d0e0e..14d2d21 100644 --- a/backend/src/module.rs +++ b/backend/src/hal/layouts/module.rs @@ -1,71 +1,56 @@ +use std::{marker::PhantomData, ptr::NonNull}; + use crate::GALOISGENERATOR; -use crate::ffi::module::{MODULE, delete_module_info, module_info_t, new_module_info}; -use std::marker::PhantomData; -#[derive(Copy, Clone)] -#[repr(u8)] -pub enum BACKEND { - FFT64, - NTT120, -} - -pub trait Backend { - const KIND: BACKEND; - fn module_type() -> u32; -} - -pub struct FFT64; -pub struct NTT120; - -impl Backend for FFT64 { - const KIND: BACKEND = BACKEND::FFT64; - fn module_type() -> u32 { - 0 - } -} - -impl Backend for NTT120 { - const KIND: BACKEND = BACKEND::NTT120; - fn module_type() -> u32 { - 1 - } +pub trait Backend: Sized { + type Handle: 'static; + unsafe fn destroy(handle: NonNull); } pub struct Module { - pub ptr: *mut MODULE, - n: usize, + ptr: NonNull, + n: u64, _marker: PhantomData, } impl Module { - // Instantiates a new module. - pub fn new(n: usize) -> Self { - unsafe { - let m: *mut module_info_t = new_module_info(n as u64, B::module_type()); - if m.is_null() { - panic!("Failed to create module."); - } - Self { - ptr: m, - n: n, - _marker: PhantomData, - } + /// Construct from a raw pointer managed elsewhere. + /// SAFETY: `ptr` must be non-null and remain valid for the lifetime of this Module. 
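The old `BACKEND` enum is gone: a backend is now any type implementing the new `Backend` trait, supplying an opaque handle type plus the destructor that `Module`'s `Drop` forwards to. A minimal sketch of what an implementor provides — `CpuHandle` and `CpuBackend` are illustrative stand-ins, not part of this patch:

```rust
use std::ptr::NonNull;

// Mirrors the new Backend trait: an opaque handle plus its destructor.
trait Backend: Sized {
    type Handle: 'static;
    unsafe fn destroy(handle: NonNull<Self::Handle>);
}

struct CpuHandle; // FFI/module tables would live here.
struct CpuBackend;

impl Backend for CpuBackend {
    type Handle = CpuHandle;
    unsafe fn destroy(handle: NonNull<Self::Handle>) {
        // Assumes the handle was produced by Box::into_raw at creation.
        unsafe { drop(Box::from_raw(handle.as_ptr())) }
    }
}

fn main() {
    let raw = Box::into_raw(Box::new(CpuHandle));
    let ptr = NonNull::new(raw).expect("null handle");
    // What Module's Drop does: defer cleanup to the backend.
    unsafe { CpuBackend::destroy(ptr) };
}
```

With this shape, `from_raw_parts` below only has to track the `NonNull` handle; ownership and cleanup semantics stay with the backend.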
+ #[inline] + pub unsafe fn from_raw_parts(ptr: *mut B::Handle, n: u64) -> Self { + Self { + ptr: NonNull::new(ptr).expect("null module ptr"), + n, + _marker: PhantomData, } } - pub fn n(&self) -> usize { - self.n + #[inline] + pub unsafe fn ptr(&self) -> *mut ::Handle { + self.ptr.as_ptr() } + #[inline] + pub fn n(&self) -> usize { + self.n as usize + } + #[inline] + pub fn as_mut_ptr(&self) -> *mut B::Handle { + self.ptr.as_ptr() + } + + #[inline] pub fn log_n(&self) -> usize { (usize::BITS - (self.n() - 1).leading_zeros()) as _ } + #[inline] pub fn cyclotomic_order(&self) -> u64 { (self.n() << 1) as _ } // Returns GALOISGENERATOR^|generator| * sign(generator) + #[inline] pub fn galois_element(&self, generator: i64) -> i64 { if generator == 0 { return 1; @@ -74,6 +59,7 @@ impl Module { } // Returns gen^-1 + #[inline] pub fn galois_element_inv(&self, gal_el: i64) -> i64 { if gal_el == 0 { panic!("cannot invert 0") @@ -85,11 +71,11 @@ impl Module { impl Drop for Module { fn drop(&mut self) { - unsafe { delete_module_info(self.ptr) } + unsafe { B::destroy(self.ptr) } } } -fn mod_exp_u64(x: u64, e: usize) -> u64 { +pub fn mod_exp_u64(x: u64, e: usize) -> u64 { let mut y: u64 = 1; let mut x_pow: u64 = x; let mut exp = e; diff --git a/backend/src/scalar_znx.rs b/backend/src/hal/layouts/scalar_znx.rs similarity index 55% rename from backend/src/scalar_znx.rs rename to backend/src/hal/layouts/scalar_znx.rs index 4acedb5..d3c5287 100644 --- a/backend/src/scalar_znx.rs +++ b/backend/src/hal/layouts/scalar_znx.rs @@ -1,20 +1,24 @@ -use crate::ffi::vec_znx; -use crate::znx_base::ZnxInfos; -use crate::{ - Backend, DataView, DataViewMut, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned, -}; use rand::seq::SliceRandom; use rand_core::RngCore; use rand_distr::{Distribution, weighted::WeightedIndex}; use sampling::source::Source; -pub struct ScalarZnx { +use crate::{ + alloc_aligned, + hal::{ + api::{DataView, DataViewMut, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, + layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo}, + }, +}; + +#[derive(PartialEq, Eq)] +pub struct ScalarZnx { pub(crate) data: D, pub(crate) n: usize, pub(crate) cols: usize, } -impl ZnxInfos for ScalarZnx { +impl ZnxInfos for ScalarZnx { fn cols(&self) -> usize { self.cols } @@ -32,30 +36,30 @@ impl ZnxInfos for ScalarZnx { } } -impl ZnxSliceSize for ScalarZnx { +impl ZnxSliceSize for ScalarZnx { fn sl(&self) -> usize { self.n() } } -impl DataView for ScalarZnx { +impl DataView for ScalarZnx { type D = D; fn data(&self) -> &Self::D { &self.data } } -impl DataViewMut for ScalarZnx { +impl DataViewMut for ScalarZnx { fn data_mut(&mut self) -> &mut Self::D { &mut self.data } } -impl> ZnxView for ScalarZnx { +impl ZnxView for ScalarZnx { type Scalar = i64; } -impl + AsRef<[u8]>> ScalarZnx { +impl ScalarZnx { pub fn fill_ternary_prob(&mut self, col: usize, prob: f64, source: &mut Source) { let choices: [i64; 3] = [-1, 0, 1]; let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0]; @@ -103,11 +107,13 @@ impl + AsRef<[u8]>> ScalarZnx { } } -impl>> ScalarZnx { - pub(crate) fn bytes_of(n: usize, cols: usize) -> usize { +impl ScalarZnx { + pub fn bytes_of(n: usize, cols: usize) -> usize { n * cols * size_of::() } +} +impl>> ScalarZnx { pub fn new(n: usize, cols: usize) -> Self { let data: Vec = alloc_aligned::(Self::bytes_of(n, cols)); Self { @@ -128,94 +134,18 @@ impl>> ScalarZnx { } } +impl ZnxZero for ScalarZnx { + fn zero(&mut self) { + 
self.raw_mut().fill(0) + } + fn zero_at(&mut self, i: usize, j: usize) { + self.at_mut(i, j).fill(0); + } +} + pub type ScalarZnxOwned = ScalarZnx>; -pub(crate) fn bytes_of_scalar_znx(module: &Module, cols: usize) -> usize { - ScalarZnxOwned::bytes_of(module.n(), cols) -} - -pub trait ScalarZnxAlloc { - fn bytes_of_scalar_znx(&self, cols: usize) -> usize; - fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned; - fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxOwned; -} - -impl ScalarZnxAlloc for Module { - fn bytes_of_scalar_znx(&self, cols: usize) -> usize { - ScalarZnxOwned::bytes_of(self.n(), cols) - } - fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned { - ScalarZnxOwned::new(self.n(), cols) - } - fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxOwned { - ScalarZnxOwned::new_from_bytes(self.n(), cols, bytes) - } -} - -pub trait ScalarZnxOps { - fn scalar_znx_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: ScalarZnxToMut, - A: ScalarZnxToRef; - - /// Applies the automorphism X^i -> X^ik on the selected column of `a`. - fn scalar_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) - where - A: ScalarZnxToMut; -} - -impl ScalarZnxOps for Module { - fn scalar_znx_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: ScalarZnxToMut, - A: ScalarZnxToRef, - { - let a: ScalarZnx<&[u8]> = a.to_ref(); - let mut res: ScalarZnx<&mut [u8]> = res.to_mut(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(res.n(), self.n()); - } - unsafe { - vec_znx::vec_znx_automorphism( - self.ptr, - k, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - ) - } - } - - fn scalar_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) - where - A: ScalarZnxToMut, - { - let mut a: ScalarZnx<&mut [u8]> = a.to_mut(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - } - unsafe { - vec_znx::vec_znx_automorphism( - self.ptr, - k, - a.at_mut_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - ) - } - } -} - -impl ScalarZnx { +impl ScalarZnx { pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self { Self { data, n, cols } } @@ -225,10 +155,7 @@ pub trait ScalarZnxToRef { fn to_ref(&self) -> ScalarZnx<&[u8]>; } -impl ScalarZnxToRef for ScalarZnx -where - D: AsRef<[u8]>, -{ +impl ScalarZnxToRef for ScalarZnx { fn to_ref(&self) -> ScalarZnx<&[u8]> { ScalarZnx { data: self.data.as_ref(), @@ -242,10 +169,7 @@ pub trait ScalarZnxToMut { fn to_mut(&mut self) -> ScalarZnx<&mut [u8]>; } -impl ScalarZnxToMut for ScalarZnx -where - D: AsRef<[u8]> + AsMut<[u8]>, -{ +impl ScalarZnxToMut for ScalarZnx { fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> { ScalarZnx { data: self.data.as_mut(), @@ -255,30 +179,56 @@ where } } -impl VecZnxToRef for ScalarZnx -where - D: AsRef<[u8]>, -{ +impl VecZnxToRef for ScalarZnx { fn to_ref(&self) -> VecZnx<&[u8]> { VecZnx { data: self.data.as_ref(), n: self.n, cols: self.cols, size: 1, + max_size: 1, } } } -impl VecZnxToMut for ScalarZnx -where - D: AsRef<[u8]> + AsMut<[u8]>, -{ +impl VecZnxToMut for ScalarZnx { fn to_mut(&mut self) -> VecZnx<&mut [u8]> { VecZnx { data: self.data.as_mut(), n: self.n, cols: self.cols, size: 1, + max_size: 1, } } } + +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +impl ReaderFrom for ScalarZnx { + fn read_from(&mut self, 
reader: &mut R) -> std::io::Result<()> { + self.n = reader.read_u64::()? as usize; + self.cols = reader.read_u64::()? as usize; + let len: usize = reader.read_u64::()? as usize; + let buf: &mut [u8] = self.data.as_mut(); + if buf.len() != len { + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + format!("self.data.len()={} != read len={}", buf.len(), len), + )); + } + reader.read_exact(&mut buf[..len])?; + Ok(()) + } +} + +impl WriterTo for ScalarZnx { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u64::(self.n as u64)?; + writer.write_u64::(self.cols as u64)?; + let buf: &[u8] = self.data.as_ref(); + writer.write_u64::(buf.len() as u64)?; + writer.write_all(buf)?; + Ok(()) + } +} diff --git a/backend/src/hal/layouts/scratch.rs b/backend/src/hal/layouts/scratch.rs new file mode 100644 index 0000000..0562939 --- /dev/null +++ b/backend/src/hal/layouts/scratch.rs @@ -0,0 +1,13 @@ +use std::marker::PhantomData; + +use crate::hal::layouts::Backend; + +pub struct ScratchOwned { + pub(crate) data: Vec, + pub(crate) _phantom: PhantomData, +} + +pub struct Scratch { + pub(crate) _phantom: PhantomData, + pub(crate) data: [u8], +} diff --git a/backend/src/hal/layouts/serialization.rs b/backend/src/hal/layouts/serialization.rs new file mode 100644 index 0000000..a50925f --- /dev/null +++ b/backend/src/hal/layouts/serialization.rs @@ -0,0 +1,9 @@ +use std::io::{Read, Result, Write}; + +pub trait WriterTo { + fn write_to(&self, writer: &mut W) -> Result<()>; +} + +pub trait ReaderFrom { + fn read_from(&mut self, reader: &mut R) -> Result<()>; +} diff --git a/backend/src/hal/layouts/svp_ppol.rs b/backend/src/hal/layouts/svp_ppol.rs new file mode 100644 index 0000000..877c470 --- /dev/null +++ b/backend/src/hal/layouts/svp_ppol.rs @@ -0,0 +1,151 @@ +use std::marker::PhantomData; + +use crate::{ + alloc_aligned, + hal::{ + api::{DataView, DataViewMut, ZnxInfos}, + layouts::{Backend, Data, DataMut, DataRef, ReaderFrom, WriterTo}, + }, +}; + +#[derive(PartialEq, Eq)] +pub struct SvpPPol { + data: D, + n: usize, + cols: usize, + _phantom: PhantomData, +} + +impl ZnxInfos for SvpPPol { + fn cols(&self) -> usize { + self.cols + } + + fn rows(&self) -> usize { + 1 + } + + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + 1 + } +} + +impl DataView for SvpPPol { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } +} + +impl DataViewMut for SvpPPol { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +pub trait SvpPPolBytesOf { + fn bytes_of(n: usize, cols: usize) -> usize; +} + +impl>, B: Backend> SvpPPol +where + SvpPPol: SvpPPolBytesOf, +{ + pub(crate) fn alloc(n: usize, cols: usize) -> Self { + let data: Vec = alloc_aligned::(Self::bytes_of(n, cols)); + Self { + data: data.into(), + n, + cols, + _phantom: PhantomData, + } + } + + pub(crate) fn from_bytes(n: usize, cols: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of(n, cols)); + Self { + data: data.into(), + n, + cols, + _phantom: PhantomData, + } + } +} + +pub type SvpPPolOwned = SvpPPol, B>; + +pub trait SvpPPolToRef { + fn to_ref(&self) -> SvpPPol<&[u8], B>; +} + +impl SvpPPolToRef for SvpPPol { + fn to_ref(&self) -> SvpPPol<&[u8], B> { + SvpPPol { + data: self.data.as_ref(), + n: self.n, + cols: self.cols, + _phantom: PhantomData, + } + } +} + +pub trait SvpPPolToMut { + fn to_mut(&mut self) -> SvpPPol<&mut [u8], B>; +} + +impl SvpPPolToMut for SvpPPol { + fn to_mut(&mut self) -> SvpPPol<&mut [u8], B> 
{ + SvpPPol { + data: self.data.as_mut(), + n: self.n, + cols: self.cols, + _phantom: PhantomData, + } + } +} + +impl SvpPPol { + pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self { + Self { + data, + n, + cols, + _phantom: PhantomData, + } + } +} + +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +impl ReaderFrom for SvpPPol { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.n = reader.read_u64::()? as usize; + self.cols = reader.read_u64::()? as usize; + let len: usize = reader.read_u64::()? as usize; + let buf: &mut [u8] = self.data.as_mut(); + if buf.len() != len { + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + format!("self.data.len()={} != read len={}", buf.len(), len), + )); + } + reader.read_exact(&mut buf[..len])?; + Ok(()) + } +} + +impl WriterTo for SvpPPol { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u64::(self.n as u64)?; + writer.write_u64::(self.cols as u64)?; + let buf: &[u8] = self.data.as_ref(); + writer.write_u64::(buf.len() as u64)?; + writer.write_all(buf)?; + Ok(()) + } +} diff --git a/backend/src/hal/layouts/vec_znx.rs b/backend/src/hal/layouts/vec_znx.rs new file mode 100644 index 0000000..b7d69dc --- /dev/null +++ b/backend/src/hal/layouts/vec_znx.rs @@ -0,0 +1,241 @@ +use std::fmt; + +use crate::{ + alloc_aligned, + hal::{ + api::{DataView, DataViewMut, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, + layouts::{Data, DataMut, DataRef, ReaderFrom, WriterTo}, + }, +}; + +#[derive(PartialEq, Eq)] +pub struct VecZnx { + pub(crate) data: D, + pub(crate) n: usize, + pub(crate) cols: usize, + pub(crate) size: usize, + pub(crate) max_size: usize, +} + +impl fmt::Debug for VecZnx { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +impl ZnxInfos for VecZnx { + fn cols(&self) -> usize { + self.cols + } + + fn rows(&self) -> usize { + 1 + } + + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + self.size + } +} + +impl ZnxSliceSize for VecZnx { + fn sl(&self) -> usize { + self.n() * self.cols() + } +} + +impl DataView for VecZnx { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } +} + +impl DataViewMut for VecZnx { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +impl ZnxView for VecZnx { + type Scalar = i64; +} + +impl VecZnx> { + pub fn rsh_scratch_space(n: usize) -> usize { + n * std::mem::size_of::() + } +} + +impl ZnxZero for VecZnx { + fn zero(&mut self) { + self.raw_mut().fill(0) + } + fn zero_at(&mut self, i: usize, j: usize) { + self.at_mut(i, j).fill(0); + } +} + +impl VecZnx { + pub fn alloc_bytes(n: usize, cols: usize, size: usize) -> usize { + n * cols * size * size_of::() + } +} + +impl>> VecZnx { + pub fn new(n: usize, cols: usize, size: usize) -> Self { + let data: Vec = alloc_aligned::(Self::alloc_bytes::(n, cols, size)); + Self { + data: data.into(), + n, + cols, + size, + max_size: size, + } + } + + pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::alloc_bytes::(n, cols, size)); + Self { + data: data.into(), + n, + cols, + size, + max_size: size, + } + } +} + +impl VecZnx { + pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { + Self { + data, + n, + cols, + size, + max_size: size, + } + } +} + +impl fmt::Display for VecZnx { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "VecZnx(n={}, cols={}, 
size={})", + self.n, self.cols, self.size + )?; + + for col in 0..self.cols { + writeln!(f, "Column {}:", col)?; + for size in 0..self.size { + let coeffs = self.at(col, size); + write!(f, " Size {}: [", size)?; + + let max_show = 100; + let show_count = coeffs.len().min(max_show); + + for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", coeff)?; + } + + if coeffs.len() > max_show { + write!(f, ", ... ({} more)", coeffs.len() - max_show)?; + } + + writeln!(f, "]")?; + } + } + Ok(()) + } +} + +pub type VecZnxOwned = VecZnx>; +pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>; +pub type VecZnxRef<'a> = VecZnx<&'a [u8]>; + +pub trait VecZnxToRef { + fn to_ref(&self) -> VecZnx<&[u8]>; +} + +impl VecZnxToRef for VecZnx { + fn to_ref(&self) -> VecZnx<&[u8]> { + VecZnx { + data: self.data.as_ref(), + n: self.n, + cols: self.cols, + size: self.size, + max_size: self.max_size, + } + } +} + +pub trait VecZnxToMut { + fn to_mut(&mut self) -> VecZnx<&mut [u8]>; +} + +impl VecZnxToMut for VecZnx { + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + VecZnx { + data: self.data.as_mut(), + n: self.n, + cols: self.cols, + size: self.size, + max_size: self.max_size, + } + } +} + +impl VecZnx { + pub fn clone(&self) -> VecZnx> { + let self_ref: VecZnx<&[u8]> = self.to_ref(); + VecZnx { + data: self_ref.data.to_vec(), + n: self_ref.n, + cols: self_ref.cols, + size: self_ref.size, + max_size: self_ref.max_size, + } + } +} + +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +impl ReaderFrom for VecZnx { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.n = reader.read_u64::()? as usize; + self.cols = reader.read_u64::()? as usize; + self.size = reader.read_u64::()? as usize; + self.max_size = reader.read_u64::()? as usize; + let len: usize = reader.read_u64::()? 
as usize; + let buf: &mut [u8] = self.data.as_mut(); + if buf.len() != len { + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + format!("self.data.len()={} != read len={}", buf.len(), len), + )); + } + reader.read_exact(&mut buf[..len])?; + Ok(()) + } +} + +impl WriterTo for VecZnx { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u64::(self.n as u64)?; + writer.write_u64::(self.cols as u64)?; + writer.write_u64::(self.size as u64)?; + writer.write_u64::(self.max_size as u64)?; + let buf: &[u8] = self.data.as_ref(); + writer.write_u64::(buf.len() as u64)?; + writer.write_all(buf)?; + Ok(()) + } +} diff --git a/backend/src/hal/layouts/vec_znx_big.rs b/backend/src/hal/layouts/vec_znx_big.rs new file mode 100644 index 0000000..0a48727 --- /dev/null +++ b/backend/src/hal/layouts/vec_znx_big.rs @@ -0,0 +1,148 @@ +use std::marker::PhantomData; + +use rand_distr::num_traits::Zero; + +use crate::{ + alloc_aligned, + hal::{ + api::{DataView, DataViewMut, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero}, + layouts::{Backend, Data, DataMut, DataRef}, + }, +}; + +#[derive(PartialEq, Eq)] +pub struct VecZnxBig { + pub(crate) data: D, + pub(crate) n: usize, + pub(crate) cols: usize, + pub(crate) size: usize, + pub(crate) max_size: usize, + pub(crate) _phantom: PhantomData, +} + +impl ZnxInfos for VecZnxBig { + fn cols(&self) -> usize { + self.cols + } + + fn rows(&self) -> usize { + 1 + } + + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + self.size + } +} + +impl DataView for VecZnxBig { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } +} + +impl DataViewMut for VecZnxBig { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +pub trait VecZnxBigBytesOf { + fn bytes_of(n: usize, cols: usize, size: usize) -> usize; +} + +impl ZnxZero for VecZnxBig +where + Self: ZnxViewMut, + ::Scalar: Zero + Copy, +{ + fn zero(&mut self) { + self.raw_mut().fill(::Scalar::zero()) + } + fn zero_at(&mut self, i: usize, j: usize) { + self.at_mut(i, j).fill(::Scalar::zero()); + } +} + +impl>, B: Backend> VecZnxBig +where + VecZnxBig: VecZnxBigBytesOf, +{ + pub(crate) fn new(n: usize, cols: usize, size: usize) -> Self { + let data = alloc_aligned::(Self::bytes_of(n, cols, size)); + Self { + data: data.into(), + n, + cols, + size, + max_size: size, + _phantom: PhantomData, + } + } + + pub(crate) fn new_from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of(n, cols, size)); + Self { + data: data.into(), + n, + cols, + size, + max_size: size, + _phantom: PhantomData, + } + } +} + +impl VecZnxBig { + pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { + Self { + data, + n, + cols, + size, + max_size: size, + _phantom: PhantomData, + } + } +} + +pub type VecZnxBigOwned = VecZnxBig, B>; + +pub trait VecZnxBigToRef { + fn to_ref(&self) -> VecZnxBig<&[u8], B>; +} + +impl VecZnxBigToRef for VecZnxBig { + fn to_ref(&self) -> VecZnxBig<&[u8], B> { + VecZnxBig { + data: self.data.as_ref(), + n: self.n, + cols: self.cols, + size: self.size, + max_size: self.max_size, + _phantom: std::marker::PhantomData, + } + } +} + +pub trait VecZnxBigToMut { + fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B>; +} + +impl VecZnxBigToMut for VecZnxBig { + fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> { + VecZnxBig { + data: self.data.as_mut(), + n: self.n, + cols: self.cols, + size: self.size, + max_size: self.max_size, + _phantom: 
std::marker::PhantomData, + } + } +} diff --git a/backend/src/hal/layouts/vec_znx_dft.rs b/backend/src/hal/layouts/vec_znx_dft.rs new file mode 100644 index 0000000..c814532 --- /dev/null +++ b/backend/src/hal/layouts/vec_znx_dft.rs @@ -0,0 +1,166 @@ +use std::marker::PhantomData; + +use rand_distr::num_traits::Zero; + +use crate::{ + alloc_aligned, + hal::{ + api::{DataView, DataViewMut, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero}, + layouts::{Backend, Data, DataMut, DataRef, VecZnxBig}, + }, +}; +#[derive(PartialEq, Eq)] +pub struct VecZnxDft { + pub(crate) data: D, + pub(crate) n: usize, + pub(crate) cols: usize, + pub(crate) size: usize, + pub(crate) max_size: usize, + pub(crate) _phantom: PhantomData, +} + +impl VecZnxDft { + pub fn into_big(self) -> VecZnxBig { + VecZnxBig::::from_data(self.data, self.n, self.cols, self.size) + } +} + +impl ZnxInfos for VecZnxDft { + fn cols(&self) -> usize { + self.cols + } + + fn rows(&self) -> usize { + 1 + } + + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + self.size + } +} + +impl DataView for VecZnxDft { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } +} + +impl DataViewMut for VecZnxDft { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +impl VecZnxDft { + pub fn max_size(&self) -> usize { + self.max_size + } +} + +impl VecZnxDft { + pub fn set_size(&mut self, size: usize) { + assert!(size <= self.max_size); + self.size = size + } +} + +impl ZnxZero for VecZnxDft +where + Self: ZnxViewMut, + ::Scalar: Zero + Copy, +{ + fn zero(&mut self) { + self.raw_mut().fill(::Scalar::zero()) + } + fn zero_at(&mut self, i: usize, j: usize) { + self.at_mut(i, j).fill(::Scalar::zero()); + } +} + +pub trait VecZnxDftBytesOf { + fn bytes_of(n: usize, cols: usize, size: usize) -> usize; +} + +impl>, B: Backend> VecZnxDft +where + VecZnxDft: VecZnxDftBytesOf, +{ + pub(crate) fn alloc(n: usize, cols: usize, size: usize) -> Self { + let data: Vec = alloc_aligned::(Self::bytes_of(n, cols, size)); + Self { + data: data.into(), + n: n, + cols, + size, + max_size: size, + _phantom: PhantomData, + } + } + + pub(crate) fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of(n, cols, size)); + Self { + data: data.into(), + n: n, + cols, + size, + max_size: size, + _phantom: PhantomData, + } + } +} + +pub type VecZnxDftOwned = VecZnxDft, B>; + +impl VecZnxDft { + pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { + Self { + data, + n, + cols, + size, + max_size: size, + _phantom: PhantomData, + } + } +} + +pub trait VecZnxDftToRef { + fn to_ref(&self) -> VecZnxDft<&[u8], B>; +} + +impl VecZnxDftToRef for VecZnxDft { + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + VecZnxDft { + data: self.data.as_ref(), + n: self.n, + cols: self.cols, + size: self.size, + max_size: self.max_size, + _phantom: std::marker::PhantomData, + } + } +} + +pub trait VecZnxDftToMut { + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B>; +} + +impl VecZnxDftToMut for VecZnxDft { + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + VecZnxDft { + data: self.data.as_mut(), + n: self.n, + cols: self.cols, + size: self.size, + max_size: self.max_size, + _phantom: std::marker::PhantomData, + } + } +} diff --git a/backend/src/hal/layouts/vmp_pmat.rs b/backend/src/hal/layouts/vmp_pmat.rs new file mode 100644 index 0000000..4a0c387 --- /dev/null +++ b/backend/src/hal/layouts/vmp_pmat.rs @@ -0,0 +1,157 @@ +use std::marker::PhantomData; + +use 
crate::{ + alloc_aligned, + hal::{ + api::{DataView, DataViewMut, ZnxInfos}, + layouts::{Backend, Data, DataMut, DataRef}, + }, +}; + +#[derive(PartialEq, Eq)] +pub struct VmpPMat { + data: D, + n: usize, + size: usize, + rows: usize, + cols_in: usize, + cols_out: usize, + _phantom: PhantomData, +} + +impl ZnxInfos for VmpPMat { + fn cols(&self) -> usize { + self.cols_in + } + + fn rows(&self) -> usize { + self.rows + } + + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + self.size + } +} + +impl DataView for VmpPMat { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } +} + +impl DataViewMut for VmpPMat { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +impl VmpPMat { + pub fn cols_in(&self) -> usize { + self.cols_in + } + + pub fn cols_out(&self) -> usize { + self.cols_out + } +} + +pub trait VmpPMatBytesOf { + fn vmp_pmat_bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; +} + +impl>, B: Backend> VmpPMat +where + B: VmpPMatBytesOf, +{ + pub(crate) fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { + let data: Vec = alloc_aligned(B::vmp_pmat_bytes_of(n, rows, cols_in, cols_out, size)); + Self { + data: data.into(), + n, + size, + rows, + cols_in, + cols_out, + _phantom: PhantomData, + } + } + + pub(crate) fn from_bytes( + n: usize, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + bytes: impl Into>, + ) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == B::vmp_pmat_bytes_of(n, rows, cols_in, cols_out, size)); + Self { + data: data.into(), + n, + size, + rows, + cols_in, + cols_out, + _phantom: PhantomData, + } + } +} + +pub type VmpPMatOwned = VmpPMat, B>; +pub type VmpPMatRef<'a, B> = VmpPMat<&'a [u8], B>; + +pub trait VmpPMatToRef { + fn to_ref(&self) -> VmpPMat<&[u8], B>; +} + +impl VmpPMatToRef for VmpPMat { + fn to_ref(&self) -> VmpPMat<&[u8], B> { + VmpPMat { + data: self.data.as_ref(), + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: std::marker::PhantomData, + } + } +} + +pub trait VmpPMatToMut { + fn to_mut(&mut self) -> VmpPMat<&mut [u8], B>; +} + +impl VmpPMatToMut for VmpPMat { + fn to_mut(&mut self) -> VmpPMat<&mut [u8], B> { + VmpPMat { + data: self.data.as_mut(), + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: std::marker::PhantomData, + } + } +} + +impl VmpPMat { + pub(crate) fn from_data(data: D, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { + Self { + data, + n, + rows, + cols_in, + cols_out, + size, + _phantom: PhantomData, + } + } +} diff --git a/backend/src/hal/mod.rs b/backend/src/hal/mod.rs new file mode 100644 index 0000000..5ecf38d --- /dev/null +++ b/backend/src/hal/mod.rs @@ -0,0 +1,5 @@ +pub mod api; +pub mod delegates; +pub mod layouts; +pub mod oep; +pub mod tests; diff --git a/backend/src/hal/oep/mat_znx.rs b/backend/src/hal/oep/mat_znx.rs new file mode 100644 index 0000000..87fd16a --- /dev/null +++ b/backend/src/hal/oep/mat_znx.rs @@ -0,0 +1,20 @@ +use crate::hal::layouts::{Backend, MatZnxOwned, Module}; + +pub unsafe trait MatZnxAllocImpl { + fn mat_znx_alloc_impl(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned; +} + +pub unsafe trait MatZnxAllocBytesImpl { + fn mat_znx_alloc_bytes_impl(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; +} + +pub unsafe trait 
MatZnxFromBytesImpl { + fn mat_znx_from_bytes_impl( + module: &Module, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + bytes: Vec, + ) -> MatZnxOwned; +} diff --git a/backend/src/hal/oep/mod.rs new file mode 100644 index 0000000..ef1ee02 --- /dev/null +++ b/backend/src/hal/oep/mod.rs @@ -0,0 +1,19 @@ +mod mat_znx; +mod module; +mod scalar_znx; +mod scratch; +mod svp_ppol; +mod vec_znx; +mod vec_znx_big; +mod vec_znx_dft; +mod vmp_pmat; + +pub use mat_znx::*; +pub use module::*; +pub use scalar_znx::*; +pub use scratch::*; +pub use svp_ppol::*; +pub use vec_znx::*; +pub use vec_znx_big::*; +pub use vec_znx_dft::*; +pub use vmp_pmat::*; diff --git a/backend/src/hal/oep/module.rs new file mode 100644 index 0000000..f2daa9b --- /dev/null +++ b/backend/src/hal/oep/module.rs @@ -0,0 +1,5 @@ +use crate::hal::layouts::{Backend, Module}; + +pub unsafe trait ModuleNewImpl { + fn new_impl(n: u64) -> Module; +} diff --git a/backend/src/hal/oep/scalar_znx.rs new file mode 100644 index 0000000..0c636e7 --- /dev/null +++ b/backend/src/hal/oep/scalar_znx.rs @@ -0,0 +1,39 @@ +use crate::hal::layouts::{Backend, Module, ScalarZnxOwned, ScalarZnxToMut, ScalarZnxToRef}; + +pub unsafe trait ScalarZnxFromBytesImpl { + fn scalar_znx_from_bytes_impl(n: usize, cols: usize, bytes: Vec) -> ScalarZnxOwned; +} + +pub unsafe trait ScalarZnxAllocBytesImpl { + fn scalar_znx_alloc_bytes_impl(n: usize, cols: usize) -> usize; +} + +pub unsafe trait ScalarZnxAllocImpl { + fn scalar_znx_alloc_impl(n: usize, cols: usize) -> ScalarZnxOwned; +} + +pub unsafe trait ScalarZnxAutomorphismImpl { + fn scalar_znx_automorphism_impl(module: &Module, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: ScalarZnxToMut, + A: ScalarZnxToRef; +} + +pub unsafe trait ScalarZnxAutomorphismInplaceImpl { + fn scalar_znx_automorphism_inplace_impl(module: &Module, k: i64, a: &mut A, a_col: usize) + where + A: ScalarZnxToMut; +} + +pub unsafe trait ScalarZnxMulXpMinusOneImpl { + fn scalar_znx_mul_xp_minus_one_impl(module: &Module, p: i64, r: &mut R, r_col: usize, a: &A, a_col: usize) + where + R: ScalarZnxToMut, + A: ScalarZnxToRef; +} + +pub unsafe trait ScalarZnxMulXpMinusOneInplaceImpl { + fn scalar_znx_mul_xp_minus_one_inplace_impl(module: &Module, p: i64, r: &mut R, r_col: usize) + where + R: ScalarZnxToMut; +} diff --git a/backend/src/hal/oep/scratch.rs new file mode 100644 index 0000000..894f530 --- /dev/null +++ b/backend/src/hal/oep/scratch.rs @@ -0,0 +1,199 @@ +use crate::hal::{ + api::ZnxInfos, + layouts::{Backend, DataRef, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, +}; + +pub unsafe trait ScratchOwnedAllocImpl { + fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned; +} + +pub unsafe trait ScratchOwnedBorrowImpl { + fn scratch_owned_borrow_impl(scratch: &mut ScratchOwned) -> &mut Scratch; +} + +pub unsafe trait ScratchFromBytesImpl { + fn scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch; +} + +pub unsafe trait ScratchAvailableImpl { + fn scratch_available_impl(scratch: &Scratch) -> usize; +} + +pub unsafe trait TakeSliceImpl { + fn take_slice_impl(scratch: &mut Scratch, len: usize) -> (&mut [T], &mut Scratch); +} + +pub unsafe trait TakeScalarZnxImpl { + fn take_scalar_znx_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch); +} + +pub unsafe trait TakeSvpPPolImpl { + fn
take_svp_ppol_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch); +} + +pub unsafe trait TakeVecZnxImpl { + fn take_vec_znx_impl(scratch: &mut Scratch, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch); +} + +pub unsafe trait TakeVecZnxSliceImpl { + fn take_vec_znx_slice_impl( + scratch: &mut Scratch, + len: usize, + n: usize, + cols: usize, + size: usize, + ) -> (Vec>, &mut Scratch); +} + +pub unsafe trait TakeVecZnxBigImpl { + fn take_vec_znx_big_impl( + scratch: &mut Scratch, + n: usize, + cols: usize, + size: usize, + ) -> (VecZnxBig<&mut [u8], B>, &mut Scratch); +} + +pub unsafe trait TakeVecZnxDftImpl { + fn take_vec_znx_dft_impl( + scratch: &mut Scratch, + n: usize, + cols: usize, + size: usize, + ) -> (VecZnxDft<&mut [u8], B>, &mut Scratch); +} + +pub unsafe trait TakeVecZnxDftSliceImpl { + fn take_vec_znx_dft_slice_impl( + scratch: &mut Scratch, + len: usize, + n: usize, + cols: usize, + size: usize, + ) -> (Vec>, &mut Scratch); +} + +pub unsafe trait TakeVmpPMatImpl { + fn take_vmp_pmat_impl( + scratch: &mut Scratch, + n: usize, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + ) -> (VmpPMat<&mut [u8], B>, &mut Scratch); +} + +pub unsafe trait TakeMatZnxImpl { + fn take_mat_znx_impl( + scratch: &mut Scratch, + n: usize, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + ) -> (MatZnx<&mut [u8]>, &mut Scratch); +} + +pub trait TakeLikeImpl<'a, B: Backend, T> { + type Output; + fn take_like_impl(scratch: &'a mut Scratch, template: &T) -> (Self::Output, &'a mut Scratch); +} + +impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VmpPMat> for B +where + B: TakeVmpPMatImpl, + D: DataRef, +{ + type Output = VmpPMat<&'a mut [u8], B>; + + fn take_like_impl(scratch: &'a mut Scratch, template: &VmpPMat) -> (Self::Output, &'a mut Scratch) { + B::take_vmp_pmat_impl( + scratch, + template.n(), + template.rows(), + template.cols_in(), + template.cols_out(), + template.size(), + ) + } +} + +impl<'a, B: Backend, D> TakeLikeImpl<'a, B, MatZnx> for B +where + B: TakeMatZnxImpl, + D: DataRef, +{ + type Output = MatZnx<&'a mut [u8]>; + + fn take_like_impl(scratch: &'a mut Scratch, template: &MatZnx) -> (Self::Output, &'a mut Scratch) { + B::take_mat_znx_impl( + scratch, + template.n(), + template.rows(), + template.cols_in(), + template.cols_out(), + template.size(), + ) + } +} + +impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VecZnxDft> for B +where + B: TakeVecZnxDftImpl, + D: DataRef, +{ + type Output = VecZnxDft<&'a mut [u8], B>; + + fn take_like_impl(scratch: &'a mut Scratch, template: &VecZnxDft) -> (Self::Output, &'a mut Scratch) { + B::take_vec_znx_dft_impl(scratch, template.n(), template.cols(), template.size()) + } +} + +impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VecZnxBig> for B +where + B: TakeVecZnxBigImpl, + D: DataRef, +{ + type Output = VecZnxBig<&'a mut [u8], B>; + + fn take_like_impl(scratch: &'a mut Scratch, template: &VecZnxBig) -> (Self::Output, &'a mut Scratch) { + B::take_vec_znx_big_impl(scratch, template.n(), template.cols(), template.size()) + } +} + +impl<'a, B: Backend, D> TakeLikeImpl<'a, B, SvpPPol> for B +where + B: TakeSvpPPolImpl, + D: DataRef, +{ + type Output = SvpPPol<&'a mut [u8], B>; + + fn take_like_impl(scratch: &'a mut Scratch, template: &SvpPPol) -> (Self::Output, &'a mut Scratch) { + B::take_svp_ppol_impl(scratch, template.n(), template.cols()) + } +} + +impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VecZnx> for B +where + B: TakeVecZnxImpl, + D: DataRef, +{ + type Output 
= VecZnx<&'a mut [u8]>; + + fn take_like_impl(scratch: &'a mut Scratch, template: &VecZnx) -> (Self::Output, &'a mut Scratch) { + B::take_vec_znx_impl(scratch, template.n(), template.cols(), template.size()) + } +} + +impl<'a, B: Backend, D> TakeLikeImpl<'a, B, ScalarZnx> for B +where + B: TakeScalarZnxImpl, + D: DataRef, +{ + type Output = ScalarZnx<&'a mut [u8]>; + + fn take_like_impl(scratch: &'a mut Scratch, template: &ScalarZnx) -> (Self::Output, &'a mut Scratch) { + B::take_scalar_znx_impl(scratch, template.n(), template.cols()) + } +} diff --git a/backend/src/hal/oep/svp_ppol.rs b/backend/src/hal/oep/svp_ppol.rs new file mode 100644 index 0000000..aea822d --- /dev/null +++ b/backend/src/hal/oep/svp_ppol.rs @@ -0,0 +1,37 @@ +use crate::hal::layouts::{ + Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef, +}; + +pub unsafe trait SvpPPolFromBytesImpl { + fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec) -> SvpPPolOwned; +} + +pub unsafe trait SvpPPolAllocImpl { + fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned; +} + +pub unsafe trait SvpPPolAllocBytesImpl { + fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize; +} + +pub unsafe trait SvpPrepareImpl { + fn svp_prepare_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: SvpPPolToMut, + A: ScalarZnxToRef; +} + +pub unsafe trait SvpApplyImpl { + fn svp_apply_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) + where + R: VecZnxDftToMut, + A: SvpPPolToRef, + C: VecZnxDftToRef; +} + +pub unsafe trait SvpApplyInplaceImpl: Backend { + fn svp_apply_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: SvpPPolToRef; +} diff --git a/backend/src/hal/oep/vec_znx.rs b/backend/src/hal/oep/vec_znx.rs new file mode 100644 index 0000000..3d61fc0 --- /dev/null +++ b/backend/src/hal/oep/vec_znx.rs @@ -0,0 +1,465 @@ +use rand_distr::Distribution; +use rug::Float; +use sampling::source::Source; + +use crate::hal::layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxOwned, VecZnxToMut, VecZnxToRef}; + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [crate::hal::layouts::VecZnx::new] for reference code. +/// * See [crate::hal::api::VecZnxAlloc] for corresponding public API. +/// * See [crate::doc::backend_safety] for safety contract. +/// * See test \[TODO\] +pub unsafe trait VecZnxAllocImpl { + fn vec_znx_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxOwned; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [crate::hal::layouts::VecZnx::from_bytes] for reference code. +/// * See [crate::hal::api::VecZnxFromBytes] for corresponding public API. +/// * See [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxFromBytesImpl { + fn vec_znx_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxOwned; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [crate::hal::layouts::VecZnx::alloc_bytes] for reference code. +/// * See [crate::hal::api::VecZnxAllocBytes] for corresponding public API. +/// * See [crate::doc::backend_safety] for safety contract. 
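These `*Impl` traits are deliberately declared `unsafe trait`: writing `unsafe impl` is the backend author's promise to uphold the linked safety contract, which is what lets the safe delegate methods on `Module` call them unchecked. A minimal sketch of a backend opting into one extension point — `DummyBackend` is hypothetical, and `VecZnxOwned` is simplified (the real layout code lives in `VecZnx::alloc_bytes` above):

```rust
// Simplified stand-ins; the crate's real traits carry a backend parameter.
struct VecZnxOwned {
    n: usize,
    cols: usize,
    size: usize,
    data: Vec<u8>,
}

unsafe trait VecZnxAllocImpl {
    fn vec_znx_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxOwned;
}

struct DummyBackend;

// `unsafe impl` is the backend's promise that the produced buffer honors
// the contract: n * cols * size i64 limbs, zero-initialized.
unsafe impl VecZnxAllocImpl for DummyBackend {
    fn vec_znx_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxOwned {
        let data = vec![0u8; n * cols * size * std::mem::size_of::<i64>()];
        VecZnxOwned { n, cols, size, data }
    }
}

fn main() {
    let v = DummyBackend::vec_znx_alloc_impl(1024, 2, 3);
    assert_eq!(v.data.len(), 1024 * 2 * 3 * 8);
    assert_eq!((v.n, v.cols, v.size), (1024, 2, 3));
}
```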
+pub unsafe trait VecZnxAllocBytesImpl { + fn vec_znx_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [vec_znx_normalize_base2k_tmp_bytes_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L245C17-L245C55) for reference code. +/// * See [crate::hal::api::VecZnxNormalizeTmpBytes] for corresponding public API. +/// * See [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxNormalizeTmpBytesImpl { + fn vec_znx_normalize_tmp_bytes_impl(module: &Module, n: usize) -> usize; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [vec_znx_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L212) for reference code. +/// * See [crate::hal::api::VecZnxNormalize] for corresponding public API. +/// * See [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxNormalizeImpl { + fn vec_znx_normalize_impl( + module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxToRef; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [vec_znx_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L212) for reference code. +/// * See [crate::hal::api::VecZnxNormalizeInplace] for corresponding public API. +/// * See [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxNormalizeInplaceImpl { + fn vec_znx_normalize_inplace_impl(module: &Module, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + where + A: VecZnxToMut; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code. +/// * See [crate::hal::api::VecZnxAdd] for corresponding public API. +/// * See [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxAddImpl { + fn vec_znx_add_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + C: VecZnxToRef; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code. +/// * See [crate::hal::api::VecZnxAddInplace] for corresponding public API. +/// * See [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxAddInplaceImpl { + fn vec_znx_add_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code. +/// * See [crate::hal::api::VecZnxAddScalarInplace] for corresponding public API. +/// * See [crate::doc::backend_safety] for safety contract. 
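The normalize family above reduces each coefficient's base-2^k limb decomposition to a canonical balanced form. A toy scalar version of the carry propagation involved, assuming limbs are stored most-significant first (an assumption; the real routine also handles scratch space and column strides):

```rust
// Rewrite limbs so each digit lies in the balanced range [-2^(k-1), 2^(k-1)).
fn normalize_limbs(basek: u32, limbs: &mut [i64]) {
    let full = 1i64 << basek;
    let half = 1i64 << (basek - 1);
    let mut carry = 0i64;
    for limb in limbs.iter_mut().rev() {
        let v = *limb + carry;
        // Centered remainder of v modulo 2^k.
        let mut r = v.rem_euclid(full);
        if r >= half {
            r -= full;
        }
        carry = (v - r) >> basek; // exact: v - r is a multiple of 2^k
        *limb = r;
    }
}

fn main() {
    let mut limbs = [0i64, 37];
    normalize_limbs(4, &mut limbs);
    // 37 = 2 * 16 + 5, with both digits now in [-8, 8).
    assert_eq!(limbs, [2, 5]);
}
```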
+pub unsafe trait VecZnxAddScalarInplaceImpl { + fn vec_znx_add_scalar_inplace_impl( + module: &Module, + res: &mut R, + res_col: usize, + res_limb: usize, + a: &A, + a_col: usize, + ) where + R: VecZnxToMut, + A: ScalarZnxToRef; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code. +/// * See [crate::hal::api::VecZnxSub] for corresponding public API. +/// * See [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxSubImpl { + fn vec_znx_sub_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + C: VecZnxToRef; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code. +/// * See [crate::hal::api::VecZnxSubABInplace] for corresponding public API. +/// * See [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxSubABInplaceImpl { + fn vec_znx_sub_ab_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code. +/// * See [crate::hal::api::VecZnxSubBAInplace] for corresponding public API. +/// * See [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxSubBAInplaceImpl { + fn vec_znx_sub_ba_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code. +/// * See [crate::hal::api::VecZnxSubScalarInplace] for corresponding public API. +/// * See [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxSubScalarInplaceImpl { + fn vec_znx_sub_scalar_inplace_impl( + module: &Module, + res: &mut R, + res_col: usize, + res_limb: usize, + a: &A, + a_col: usize, + ) where + R: VecZnxToMut, + A: ScalarZnxToRef; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [vec_znx_negate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L322C13-L322C31) for reference code. +/// * See [crate::hal::api::VecZnxNegate] for corresponding public API. +/// * See [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxNegateImpl { + fn vec_znx_negate_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [vec_znx_negate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L322C13-L322C31) for reference code. +/// * See [crate::hal::api::VecZnxNegateInplace] for corresponding public API. +/// * See [crate::doc::backend_safety] for safety contract. 
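The `SubAB`/`SubBA` suffixes above distinguish which operand slot the in-place receiver occupies; the patch does not spell the convention out. One consistent reading, given that `vec_znx_sub(res, a, b)` computes `a - b`, sketched on plain slices (an assumption, not a confirmed semantics):

```rust
fn sub_ab_inplace(res: &mut [i64], a: &[i64]) {
    for (r, x) in res.iter_mut().zip(a) {
        *r -= *x; // receiver occupies the "a" slot: res <- res - a
    }
}

fn sub_ba_inplace(res: &mut [i64], a: &[i64]) {
    for (r, x) in res.iter_mut().zip(a) {
        *r = *x - *r; // receiver occupies the "b" slot: res <- a - res
    }
}

fn main() {
    let (mut r1, mut r2) = ([5i64, 5], [5i64, 5]);
    sub_ab_inplace(&mut r1, &[2, 7]);
    sub_ba_inplace(&mut r2, &[2, 7]);
    assert_eq!(r1, [3, -2]); // res - a
    assert_eq!(r2, [-3, 2]); // a - res
}
```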
+pub unsafe trait VecZnxNegateInplaceImpl<B: Backend> {
+    fn vec_znx_negate_inplace_impl<A>(module: &Module<B>, a: &mut A, a_col: usize)
+    where
+        A: VecZnxToMut;
+}
+
+/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
+/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_rsh_inplace_ref] for reference code.
+/// * See [crate::hal::api::VecZnxRshInplace] for corresponding public API.
+/// * See [crate::doc::backend_safety] for safety contract.
+pub unsafe trait VecZnxRshInplaceImpl<B: Backend> {
+    fn vec_znx_rsh_inplace_impl<A>(module: &Module<B>, basek: usize, k: usize, a: &mut A)
+    where
+        A: VecZnxToMut;
+}
+
+/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
+/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_lsh_inplace_ref] for reference code.
+/// * See [crate::hal::api::VecZnxLshInplace] for corresponding public API.
+/// * See [crate::doc::backend_safety] for safety contract.
+pub unsafe trait VecZnxLshInplaceImpl<B: Backend> {
+    fn vec_znx_lsh_inplace_impl<A>(module: &Module<B>, basek: usize, k: usize, a: &mut A)
+    where
+        A: VecZnxToMut;
+}
+
+/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
+/// * See [vec_znx_rotate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L164) for reference code.
+/// * See [crate::hal::api::VecZnxRotate] for corresponding public API.
+/// * See [crate::doc::backend_safety] for safety contract.
+pub unsafe trait VecZnxRotateImpl<B: Backend> {
+    fn vec_znx_rotate_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+}
+
+/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
+/// * See [vec_znx_rotate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L164) for reference code.
+/// * See [crate::hal::api::VecZnxRotateInplace] for corresponding public API.
+/// * See [crate::doc::backend_safety] for safety contract.
+pub unsafe trait VecZnxRotateInplaceImpl<B: Backend> {
+    fn vec_znx_rotate_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
+    where
+        A: VecZnxToMut;
+}
+
+/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
+/// * See [vec_znx_automorphism_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L188) for reference code.
+/// * See [crate::hal::api::VecZnxAutomorphism] for corresponding public API.
+/// * See [crate::doc::backend_safety] for safety contract.
+pub unsafe trait VecZnxAutomorphismImpl<B: Backend> {
+    fn vec_znx_automorphism_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+}
+
+/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
+/// * See [vec_znx_automorphism_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L188) for reference code.
+/// * See [crate::hal::api::VecZnxAutomorphismInplace] for corresponding public API.
+/// * See [crate::doc::backend_safety] for safety contract.
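+///
+/// # Example (editorial sketch)
+/// The automorphism maps X -> X^k in Z[X]/(X^N + 1) (k odd). A call through
+/// the corresponding public API could look like the following; the method
+/// name is assumed, not taken from this patch:
+/// ```ignore
+/// // Map X -> X^5 in column 0 of `a`, in place.
+/// module.vec_znx_automorphism_inplace(5, &mut a, 0);
+/// ```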
+pub unsafe trait VecZnxAutomorphismInplaceImpl<B: Backend> {
+    fn vec_znx_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
+    where
+        A: VecZnxToMut;
+}
+
+/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
+/// * See [vec_znx_mul_xp_minus_one_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/7160f588da49712a042931ea247b4259b95cefcc/spqlios/arithmetic/vec_znx.c#L200C13-L200C41) for reference code.
+/// * See [crate::hal::api::VecZnxMulXpMinusOne] for corresponding public API.
+/// * See [crate::doc::backend_safety] for safety contract.
+pub unsafe trait VecZnxMulXpMinusOneImpl<B: Backend> {
+    fn vec_znx_mul_xp_minus_one_impl<R, A>(module: &Module<B>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+}
+
+/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
+/// * See [vec_znx_mul_xp_minus_one_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/7160f588da49712a042931ea247b4259b95cefcc/spqlios/arithmetic/vec_znx.c#L200C13-L200C41) for reference code.
+/// * See [crate::hal::api::VecZnxMulXpMinusOneInplace] for corresponding public API.
+/// * See [crate::doc::backend_safety] for safety contract.
+pub unsafe trait VecZnxMulXpMinusOneInplaceImpl<B: Backend> {
+    fn vec_znx_mul_xp_minus_one_inplace_impl<R>(module: &Module<B>, p: i64, res: &mut R, res_col: usize)
+    where
+        R: VecZnxToMut;
+}
+
+/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
+/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_split_ref] for reference code.
+/// * See [crate::hal::api::VecZnxSplit] for corresponding public API.
+/// * See [crate::doc::backend_safety] for safety contract.
+pub unsafe trait VecZnxSplitImpl<B: Backend> {
+    fn vec_znx_split_impl<R, A>(
+        module: &Module<B>,
+        res: &mut Vec<R>,
+        res_col: usize,
+        a: &A,
+        a_col: usize,
+        scratch: &mut Scratch<B>,
+    ) where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+}
+
+/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
+/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_merge_ref] for reference code.
+/// * See [crate::hal::api::VecZnxMerge] for corresponding public API.
+/// * See [crate::doc::backend_safety] for safety contract.
+pub unsafe trait VecZnxMergeImpl<B: Backend> {
+    fn vec_znx_merge_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: Vec<A>, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+}
+
+/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
+/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_switch_degree_ref] for reference code.
+/// * See [crate::hal::api::VecZnxSwithcDegree] for corresponding public API.
+/// * See [crate::doc::backend_safety] for safety contract.
+pub unsafe trait VecZnxSwithcDegreeImpl<B: Backend> {
+    fn vec_znx_switch_degree_impl<R: VecZnxToMut, A: VecZnxToRef>(
+        module: &Module<B>,
+        res: &mut R,
+        res_col: usize,
+        a: &A,
+        a_col: usize,
+    );
+}
+
+/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
+/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_copy_ref] for reference code.
+/// * See [crate::hal::api::VecZnxCopy] for corresponding public API.
+/// * See [crate::doc::backend_safety] for safety contract.
+pub unsafe trait VecZnxCopyImpl<B: Backend> {
+    fn vec_znx_copy_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+}
+
+/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
+/// * See [crate::hal::api::VecZnxStd] for corresponding public API.
+/// * See [crate::doc::backend_safety] for safety contract.
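+///
+/// # Example (editorial sketch)
+/// The tests added later in this patch treat the returned value as the
+/// empirical standard deviation of the column's coefficients viewed as
+/// reals (uniform noise gives roughly 1/sqrt(12) ~= 0.2887). A plain
+/// reference computation, independent of any backend:
+/// ```ignore
+/// fn std_f64(coeffs: &[f64]) -> f64 {
+///     let n = coeffs.len() as f64;
+///     let mean = coeffs.iter().sum::<f64>() / n;
+///     (coeffs.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / n).sqrt()
+/// }
+/// ```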
+pub unsafe trait VecZnxStdImpl<B: Backend> {
+    fn vec_znx_std_impl<A>(module: &Module<B>, basek: usize, a: &A, a_col: usize) -> f64
+    where
+        A: VecZnxToRef;
+}
+
+/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
+/// * See [crate::hal::api::VecZnxFillUniform] for corresponding public API.
+/// * See [crate::doc::backend_safety] for safety contract.
+pub unsafe trait VecZnxFillUniformImpl<B: Backend> {
+    fn vec_znx_fill_uniform_impl<R>(module: &Module<B>, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
+    where
+        R: VecZnxToMut;
+}
+
+/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
+/// * See [crate::hal::api::VecZnxFillDistF64] for corresponding public API.
+/// * See [crate::doc::backend_safety] for safety contract.
+pub unsafe trait VecZnxFillDistF64Impl<B: Backend> {
+    fn vec_znx_fill_dist_f64_impl<R, D: Distribution<f64>>(
+        module: &Module<B>,
+        basek: usize,
+        res: &mut R,
+        res_col: usize,
+        k: usize,
+        source: &mut Source,
+        dist: D,
+        bound: f64,
+    ) where
+        R: VecZnxToMut;
+}
+
+/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
+/// * See [crate::hal::api::VecZnxAddDistF64] for corresponding public API.
+/// * See [crate::doc::backend_safety] for safety contract.
+pub unsafe trait VecZnxAddDistF64Impl<B: Backend> {
+    fn vec_znx_add_dist_f64_impl<R, D: Distribution<f64>>(
+        module: &Module<B>,
+        basek: usize,
+        res: &mut R,
+        res_col: usize,
+        k: usize,
+        source: &mut Source,
+        dist: D,
+        bound: f64,
+    ) where
+        R: VecZnxToMut;
+}
+
+/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
+/// * See [crate::hal::api::VecZnxFillNormal] for corresponding public API.
+/// * See [crate::doc::backend_safety] for safety contract.
+pub unsafe trait VecZnxFillNormalImpl<B: Backend> {
+    fn vec_znx_fill_normal_impl<R>(
+        module: &Module<B>,
+        basek: usize,
+        res: &mut R,
+        res_col: usize,
+        k: usize,
+        source: &mut Source,
+        sigma: f64,
+        bound: f64,
+    ) where
+        R: VecZnxToMut;
+}
+
+/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
+/// * See [crate::hal::api::VecZnxAddNormal] for corresponding public API.
+/// * See [crate::doc::backend_safety] for safety contract.
+pub unsafe trait VecZnxAddNormalImpl<B: Backend> {
+    fn vec_znx_add_normal_impl<R>(
+        module: &Module<B>,
+        basek: usize,
+        res: &mut R,
+        res_col: usize,
+        k: usize,
+        source: &mut Source,
+        sigma: f64,
+        bound: f64,
+    ) where
+        R: VecZnxToMut;
+}
+
+/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
+/// * See \[TODO\] for reference code.
+/// * See [crate::hal::api::VecZnxEncodeVeci64] for corresponding public API.
+/// * See [crate::doc::backend_safety] for safety contract.
+pub unsafe trait VecZnxEncodeVeci64Impl<B: Backend> {
+    fn encode_vec_i64_impl<R>(
+        module: &Module<B>,
+        basek: usize,
+        res: &mut R,
+        res_col: usize,
+        k: usize,
+        data: &[i64],
+        log_max: usize,
+    ) where
+        R: VecZnxToMut;
+}
+
+/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
+/// * See \[TODO\] for reference code.
+/// * See [crate::hal::api::VecZnxEncodeCoeffsi64] for corresponding public API.
+/// * See [crate::doc::backend_safety] for safety contract.
+pub unsafe trait VecZnxEncodeCoeffsi64Impl<B: Backend> {
+    fn encode_coeff_i64_impl<R>(
+        module: &Module<B>,
+        basek: usize,
+        res: &mut R,
+        res_col: usize,
+        k: usize,
+        i: usize,
+        data: i64,
+        log_max: usize,
+    ) where
+        R: VecZnxToMut;
+}
+
+/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
+/// * See \[TODO\] for reference code.
+/// * See [crate::hal::api::VecZnxDecodeVeci64] for corresponding public API.
+/// * See [crate::doc::backend_safety] for safety contract.
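+///
+/// # Example (editorial sketch)
+/// Mirrors the encode/decode round trip exercised by the tests added later
+/// in this patch (`data` holds `module.n()` small signed values):
+/// ```ignore
+/// let mut a: VecZnx<_> = module.vec_znx_alloc(1, size);
+/// module.encode_vec_i64(basek, &mut a, 0, k, &data, log_max);
+/// let mut out: Vec<i64> = vec![0; module.n()];
+/// module.decode_vec_i64(basek, &a, 0, k, &mut out);
+/// assert_eq!(out, data);
+/// ```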
+pub unsafe trait VecZnxDecodeVeci64Impl { + fn decode_vec_i64_impl(module: &Module, basek: usize, res: &R, res_col: usize, k: usize, data: &mut [i64]) + where + R: VecZnxToRef; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See \[TODO\] for reference code. +/// * See [crate::hal::api::VecZnxDecodeCoeffsi64] for corresponding public API. +/// * See [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxDecodeCoeffsi64Impl { + fn decode_coeff_i64_impl(module: &Module, basek: usize, res: &R, res_col: usize, k: usize, i: usize) -> i64 + where + R: VecZnxToRef; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See \[TODO\] for reference code. +/// * See [crate::hal::api::VecZnxDecodeVecFloat] for corresponding public API. +/// * See [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxDecodeVecFloatImpl { + fn decode_vec_float_impl(module: &Module, basek: usize, res: &R, res_col: usize, data: &mut [Float]) + where + R: VecZnxToRef; +} diff --git a/backend/src/hal/oep/vec_znx_big.rs b/backend/src/hal/oep/vec_znx_big.rs new file mode 100644 index 0000000..8ff0564 --- /dev/null +++ b/backend/src/hal/oep/vec_znx_big.rs @@ -0,0 +1,208 @@ +use rand_distr::Distribution; +use sampling::source::Source; + +use crate::hal::layouts::{Backend, Module, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef}; + +pub unsafe trait VecZnxBigAllocImpl { + fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned; +} + +pub unsafe trait VecZnxBigFromBytesImpl { + fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned; +} + +pub unsafe trait VecZnxBigAllocBytesImpl { + fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize; +} + +pub unsafe trait VecZnxBigAddNormalImpl { + fn add_normal_impl>( + module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ); +} + +pub unsafe trait VecZnxBigFillNormalImpl { + fn fill_normal_impl>( + module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ); +} + +pub unsafe trait VecZnxBigFillDistF64Impl { + fn fill_dist_f64_impl, D: Distribution>( + module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + dist: D, + bound: f64, + ); +} + +pub unsafe trait VecZnxBigAddDistF64Impl { + fn add_dist_f64_impl, D: Distribution>( + module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + dist: D, + bound: f64, + ); +} + +pub unsafe trait VecZnxBigAddImpl { + fn vec_znx_big_add_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + C: VecZnxBigToRef; +} + +pub unsafe trait VecZnxBigAddInplaceImpl { + fn vec_znx_big_add_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef; +} + +pub unsafe trait VecZnxBigAddSmallImpl { + fn vec_znx_big_add_small_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &C, + b_col: usize, + ) where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + C: VecZnxToRef; +} + +pub unsafe trait VecZnxBigAddSmallInplaceImpl { + fn vec_znx_big_add_small_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: 
VecZnxToRef; +} + +pub unsafe trait VecZnxBigSubImpl { + fn vec_znx_big_sub_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + C: VecZnxBigToRef; +} + +pub unsafe trait VecZnxBigSubABInplaceImpl { + fn vec_znx_big_sub_ab_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef; +} + +pub unsafe trait VecZnxBigSubBAInplaceImpl { + fn vec_znx_big_sub_ba_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef; +} + +pub unsafe trait VecZnxBigSubSmallAImpl { + fn vec_znx_big_sub_small_a_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &C, + b_col: usize, + ) where + R: VecZnxBigToMut, + A: VecZnxToRef, + C: VecZnxBigToRef; +} + +pub unsafe trait VecZnxBigSubSmallAInplaceImpl { + fn vec_znx_big_sub_small_a_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef; +} + +pub unsafe trait VecZnxBigSubSmallBImpl { + fn vec_znx_big_sub_small_b_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &C, + b_col: usize, + ) where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + C: VecZnxToRef; +} + +pub unsafe trait VecZnxBigSubSmallBInplaceImpl { + fn vec_znx_big_sub_small_b_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef; +} + +pub unsafe trait VecZnxBigNegateInplaceImpl { + fn vec_znx_big_negate_inplace_impl(module: &Module, a: &mut A, a_col: usize) + where + A: VecZnxBigToMut; +} + +pub unsafe trait VecZnxBigNormalizeTmpBytesImpl { + fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module, n: usize) -> usize; +} + +pub unsafe trait VecZnxBigNormalizeImpl { + fn vec_znx_big_normalize_impl( + module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxBigToRef; +} + +pub unsafe trait VecZnxBigAutomorphismImpl { + fn vec_znx_big_automorphism_impl(module: &Module, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef; +} + +pub unsafe trait VecZnxBigAutomorphismInplaceImpl { + fn vec_znx_big_automorphism_inplace_impl(module: &Module, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxBigToMut; +} diff --git a/backend/src/hal/oep/vec_znx_dft.rs b/backend/src/hal/oep/vec_znx_dft.rs new file mode 100644 index 0000000..3f2aa0a --- /dev/null +++ b/backend/src/hal/oep/vec_znx_dft.rs @@ -0,0 +1,117 @@ +use crate::hal::layouts::{ + Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, + VecZnxToRef, +}; + +pub unsafe trait VecZnxDftAllocImpl { + fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned; +} + +pub unsafe trait VecZnxDftFromBytesImpl { + fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned; +} + +pub unsafe trait VecZnxDftAllocBytesImpl { + fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize; +} + +pub unsafe trait VecZnxDftToVecZnxBigTmpBytesImpl { + fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module) -> usize; +} + +pub unsafe trait VecZnxDftToVecZnxBigImpl { + fn vec_znx_dft_to_vec_znx_big_impl( + module: &Module, + res: &mut R, + res_col: 
usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxBigToMut, + A: VecZnxDftToRef; +} + +pub unsafe trait VecZnxDftToVecZnxBigTmpAImpl { + fn vec_znx_dft_to_vec_znx_big_tmp_a_impl(module: &Module, res: &mut R, res_col: usize, a: &mut A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxDftToMut; +} + +pub unsafe trait VecZnxDftToVecZnxBigConsumeImpl { + fn vec_znx_dft_to_vec_znx_big_consume_impl(module: &Module, a: VecZnxDft) -> VecZnxBig + where + VecZnxDft: VecZnxDftToMut; +} + +pub unsafe trait VecZnxDftAddImpl { + fn vec_znx_dft_add_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + D: VecZnxDftToRef; +} + +pub unsafe trait VecZnxDftAddInplaceImpl { + fn vec_znx_dft_add_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef; +} + +pub unsafe trait VecZnxDftSubImpl { + fn vec_znx_dft_sub_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + D: VecZnxDftToRef; +} + +pub unsafe trait VecZnxDftSubABInplaceImpl { + fn vec_znx_dft_sub_ab_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef; +} + +pub unsafe trait VecZnxDftSubBAInplaceImpl { + fn vec_znx_dft_sub_ba_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef; +} + +pub unsafe trait VecZnxDftCopyImpl { + fn vec_znx_dft_copy_impl( + module: &Module, + step: usize, + offset: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + ) where + R: VecZnxDftToMut, + A: VecZnxDftToRef; +} + +pub unsafe trait VecZnxDftFromVecZnxImpl { + fn vec_znx_dft_from_vec_znx_impl( + module: &Module, + step: usize, + offset: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + ) where + R: VecZnxDftToMut, + A: VecZnxToRef; +} + +pub unsafe trait VecZnxDftZeroImpl { + fn vec_znx_dft_zero_impl(module: &Module, res: &mut R) + where + R: VecZnxDftToMut; +} diff --git a/backend/src/hal/oep/vmp_pmat.rs b/backend/src/hal/oep/vmp_pmat.rs new file mode 100644 index 0000000..56e7299 --- /dev/null +++ b/backend/src/hal/oep/vmp_pmat.rs @@ -0,0 +1,74 @@ +use crate::hal::layouts::{ + Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef, +}; + +pub unsafe trait VmpPMatAllocImpl { + fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned; +} + +pub unsafe trait VmpPMatAllocBytesImpl { + fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; +} + +pub unsafe trait VmpPMatFromBytesImpl { + fn vmp_pmat_from_bytes_impl( + n: usize, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + bytes: Vec, + ) -> VmpPMatOwned; +} + +pub unsafe trait VmpPrepareTmpBytesImpl { + fn vmp_prepare_tmp_bytes_impl(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; +} + +pub unsafe trait VmpPMatPrepareImpl { + fn vmp_prepare_impl(module: &Module, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: VmpPMatToMut, + A: MatZnxToRef; +} + +pub unsafe trait VmpApplyTmpBytesImpl { + fn vmp_apply_tmp_bytes_impl( + module: &Module, + res_size: usize, + a_size: usize, + b_rows: usize, + b_cols_in: usize, + b_cols_out: usize, + b_size: 
usize,
+    ) -> usize;
+}
+
+pub unsafe trait VmpApplyImpl<B: Backend> {
+    fn vmp_apply_impl<R, A, C>(module: &Module<B>, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
+    where
+        R: VecZnxDftToMut<B>,
+        A: VecZnxDftToRef<B>,
+        C: VmpPMatToRef<B>;
+}
+
+pub unsafe trait VmpApplyAddTmpBytesImpl<B: Backend> {
+    fn vmp_apply_add_tmp_bytes_impl(
+        module: &Module<B>,
+        res_size: usize,
+        a_size: usize,
+        b_rows: usize,
+        b_cols_in: usize,
+        b_cols_out: usize,
+        b_size: usize,
+    ) -> usize;
+}
+
+pub unsafe trait VmpApplyAddImpl<B: Backend> {
+    // Same as [MatZnxDftOps::vmp_apply], except that the result is added to `res` instead of overwriting it.
+    fn vmp_apply_add_impl<R, A, C>(module: &Module<B>, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
+    where
+        R: VecZnxDftToMut<B>,
+        A: VecZnxDftToRef<B>,
+        C: VmpPMatToRef<B>;
+}
diff --git a/backend/src/hal/tests/mod.rs b/backend/src/hal/tests/mod.rs
new file mode 100644
index 0000000..d7dd017
--- /dev/null
+++ b/backend/src/hal/tests/mod.rs
@@ -0,0 +1 @@
+pub mod vec_znx;
diff --git a/backend/src/hal/tests/vec_znx/generics.rs b/backend/src/hal/tests/vec_znx/generics.rs
new file mode 100644
index 0000000..626c170
--- /dev/null
+++ b/backend/src/hal/tests/vec_znx/generics.rs
@@ -0,0 +1,120 @@
+use itertools::izip;
+use sampling::source::Source;
+
+use crate::hal::{
+    api::{
+        VecZnxAddNormal, VecZnxAlloc, VecZnxDecodeVeci64, VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxStd, ZnxInfos, ZnxView,
+        ZnxViewMut,
+    },
+    layouts::{Backend, Module, VecZnx},
+};
+
+pub fn test_vec_znx_fill_uniform<B: Backend>(module: &Module<B>)
+where
+    Module<B>: VecZnxFillUniform + VecZnxStd + VecZnxAlloc,
+{
+    let basek: usize = 17;
+    let size: usize = 5;
+    let mut source: Source = Source::new([0u8; 32]);
+    let cols: usize = 2;
+    let zero: Vec<i64> = vec![0; module.n()];
+    let one_12_sqrt: f64 = 0.28867513459481287;
+    (0..cols).for_each(|col_i| {
+        let mut a: VecZnx<_> = module.vec_znx_alloc(cols, size);
+        module.vec_znx_fill_uniform(basek, &mut a, col_i, size * basek, &mut source);
+        (0..cols).for_each(|col_j| {
+            if col_j != col_i {
+                (0..size).for_each(|limb_i| {
+                    assert_eq!(a.at(col_j, limb_i), zero);
+                })
+            } else {
+                let std: f64 = module.vec_znx_std(basek, &a, col_i);
+                assert!(
+                    (std - one_12_sqrt).abs() < 0.01,
+                    "std={} ~!= {}",
+                    std,
+                    one_12_sqrt
+                );
+            }
+        })
+    });
+}
+
+pub fn test_vec_znx_add_normal<B: Backend>(module: &Module<B>)
+where
+    Module<B>: VecZnxAddNormal + VecZnxStd + VecZnxAlloc,
+{
+    let basek: usize = 17;
+    let k: usize = 2 * 17;
+    let size: usize = 5;
+    let sigma: f64 = 3.2;
+    let bound: f64 = 6.0 * sigma;
+    let mut source: Source = Source::new([0u8; 32]);
+    let cols: usize = 2;
+    let zero: Vec<i64> = vec![0; module.n()];
+    let k_f64: f64 = (1u64 << k as u64) as f64;
+    (0..cols).for_each(|col_i| {
+        let mut a: VecZnx<_> = module.vec_znx_alloc(cols, size);
+        module.vec_znx_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
+        (0..cols).for_each(|col_j| {
+            if col_j != col_i {
+                (0..size).for_each(|limb_i| {
+                    assert_eq!(a.at(col_j, limb_i), zero);
+                })
+            } else {
+                let std: f64 = module.vec_znx_std(basek, &a, col_i) * k_f64;
+                assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma);
+            }
+        })
+    });
+}
+
+pub fn test_vec_znx_encode_vec_i64_lo_norm<B: Backend>(module: &Module<B>)
+where
+    Module<B>: VecZnxEncodeVeci64 + VecZnxDecodeVeci64 + VecZnxAlloc,
+{
+    let basek: usize = 17;
+    let size: usize = 5;
+    let k: usize = size * basek - 5;
+    let mut a: VecZnx<_> = module.vec_znx_alloc(2, size);
+    let mut source: Source = Source::new([0u8; 32]);
+    let raw: &mut [i64] = a.raw_mut();
+    raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
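+    // For each column: draw sign-extended 8-bit values, encode them at
+    // precision k (well within the declared log_max of 10 bits), decode
+    // them back, and check that the round trip is exact.
+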
(0..a.cols()).for_each(|col_i| { + let mut have: Vec = vec![i64::default(); module.n()]; + have.iter_mut() + .for_each(|x| *x = (source.next_i64() << 56) >> 56); + module.encode_vec_i64(basek, &mut a, col_i, k, &have, 10); + let mut want: Vec = vec![i64::default(); module.n()]; + module.decode_vec_i64(basek, &a, col_i, k, &mut want); + izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); + }); +} + +pub fn test_vec_znx_encode_vec_i64_hi_norm(module: &Module) +where + Module: VecZnxEncodeVeci64 + VecZnxDecodeVeci64 + VecZnxAlloc, +{ + let basek: usize = 17; + let size: usize = 5; + for k in [1, basek / 2, size * basek - 5] { + let mut a: VecZnx<_> = module.vec_znx_alloc(2, size); + let mut source = Source::new([0u8; 32]); + let raw: &mut [i64] = a.raw_mut(); + raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); + (0..a.cols()).for_each(|col_i| { + let mut have: Vec = vec![i64::default(); module.n()]; + have.iter_mut().for_each(|x| { + if k < 64 { + *x = source.next_u64n(1 << k, (1 << k) - 1) as i64; + } else { + *x = source.next_i64(); + } + }); + module.encode_vec_i64(basek, &mut a, col_i, k, &have, 63); + let mut want: Vec = vec![i64::default(); module.n()]; + module.decode_vec_i64(basek, &a, col_i, k, &mut want); + izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); + }) + } +} diff --git a/backend/src/hal/tests/vec_znx/mod.rs b/backend/src/hal/tests/vec_znx/mod.rs new file mode 100644 index 0000000..7a179f3 --- /dev/null +++ b/backend/src/hal/tests/vec_znx/mod.rs @@ -0,0 +1,2 @@ +mod generics; +pub use generics::*; diff --git a/backend/src/ffi/cnv.rs b/backend/src/implementation/cpu_spqlios/ffi/cnv.rs similarity index 100% rename from backend/src/ffi/cnv.rs rename to backend/src/implementation/cpu_spqlios/ffi/cnv.rs diff --git a/backend/src/ffi/mod.rs b/backend/src/implementation/cpu_spqlios/ffi/mod.rs similarity index 100% rename from backend/src/ffi/mod.rs rename to backend/src/implementation/cpu_spqlios/ffi/mod.rs diff --git a/backend/src/ffi/module.rs b/backend/src/implementation/cpu_spqlios/ffi/module.rs similarity index 100% rename from backend/src/ffi/module.rs rename to backend/src/implementation/cpu_spqlios/ffi/module.rs diff --git a/backend/src/ffi/reim.rs b/backend/src/implementation/cpu_spqlios/ffi/reim.rs similarity index 100% rename from backend/src/ffi/reim.rs rename to backend/src/implementation/cpu_spqlios/ffi/reim.rs diff --git a/backend/src/ffi/svp.rs b/backend/src/implementation/cpu_spqlios/ffi/svp.rs similarity index 89% rename from backend/src/ffi/svp.rs rename to backend/src/implementation/cpu_spqlios/ffi/svp.rs index 8c994c9..f9db97f 100644 --- a/backend/src/ffi/svp.rs +++ b/backend/src/implementation/cpu_spqlios/ffi/svp.rs @@ -1,48 +1,47 @@ -use crate::ffi::module::MODULE; -use crate::ffi::vec_znx_dft::VEC_ZNX_DFT; - -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct svp_ppol_t { - _unused: [u8; 0], -} -pub type SVP_PPOL = svp_ppol_t; - -unsafe extern "C" { - pub unsafe fn bytes_of_svp_ppol(module: *const MODULE) -> u64; -} -unsafe extern "C" { - pub unsafe fn new_svp_ppol(module: *const MODULE) -> *mut SVP_PPOL; -} -unsafe extern "C" { - pub unsafe fn delete_svp_ppol(res: *mut SVP_PPOL); -} - -unsafe extern "C" { - pub unsafe fn svp_prepare(module: *const MODULE, ppol: *mut SVP_PPOL, pol: *const i64); -} - -unsafe extern "C" { - pub unsafe fn svp_apply_dft( - module: *const MODULE, - res: *const VEC_ZNX_DFT, - res_size: u64, - ppol: *const SVP_PPOL, - a: *const i64, - a_size: u64, - a_sl: u64, - ); -} - 
-unsafe extern "C" { - pub unsafe fn svp_apply_dft_to_dft( - module: *const MODULE, - res: *const VEC_ZNX_DFT, - res_size: u64, - res_cols: u64, - ppol: *const SVP_PPOL, - a: *const VEC_ZNX_DFT, - a_size: u64, - a_cols: u64, - ); -} +use crate::implementation::cpu_spqlios::ffi::{module::MODULE, vec_znx_dft::VEC_ZNX_DFT}; + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct svp_ppol_t { + _unused: [u8; 0], +} +pub type SVP_PPOL = svp_ppol_t; + +unsafe extern "C" { + pub unsafe fn bytes_of_svp_ppol(module: *const MODULE) -> u64; +} +unsafe extern "C" { + pub unsafe fn new_svp_ppol(module: *const MODULE) -> *mut SVP_PPOL; +} +unsafe extern "C" { + pub unsafe fn delete_svp_ppol(res: *mut SVP_PPOL); +} + +unsafe extern "C" { + pub unsafe fn svp_prepare(module: *const MODULE, ppol: *mut SVP_PPOL, pol: *const i64); +} + +unsafe extern "C" { + pub unsafe fn svp_apply_dft( + module: *const MODULE, + res: *const VEC_ZNX_DFT, + res_size: u64, + ppol: *const SVP_PPOL, + a: *const i64, + a_size: u64, + a_sl: u64, + ); +} + +unsafe extern "C" { + pub unsafe fn svp_apply_dft_to_dft( + module: *const MODULE, + res: *const VEC_ZNX_DFT, + res_size: u64, + res_cols: u64, + ppol: *const SVP_PPOL, + a: *const VEC_ZNX_DFT, + a_size: u64, + a_cols: u64, + ); +} diff --git a/backend/src/ffi/vec_znx.rs b/backend/src/implementation/cpu_spqlios/ffi/vec_znx.rs similarity index 80% rename from backend/src/ffi/vec_znx.rs rename to backend/src/implementation/cpu_spqlios/ffi/vec_znx.rs index b377fd1..f4ea531 100644 --- a/backend/src/ffi/vec_znx.rs +++ b/backend/src/implementation/cpu_spqlios/ffi/vec_znx.rs @@ -1,4 +1,4 @@ -use crate::ffi::module::MODULE; +use crate::implementation::cpu_spqlios::ffi::module::MODULE; unsafe extern "C" { pub unsafe fn vec_znx_add( @@ -28,6 +28,19 @@ unsafe extern "C" { ); } +unsafe extern "C" { + pub unsafe fn vec_znx_mul_xp_minus_one( + module: *const MODULE, + p: i64, + res: *mut i64, + res_size: u64, + res_sl: u64, + a: *const i64, + a_size: u64, + a_sl: u64, + ); +} + unsafe extern "C" { pub unsafe fn vec_znx_negate( module: *const MODULE, @@ -86,6 +99,7 @@ unsafe extern "C" { unsafe extern "C" { pub unsafe fn vec_znx_normalize_base2k( module: *const MODULE, + n: u64, base2k: u64, res: *mut i64, res_size: u64, @@ -97,5 +111,5 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub unsafe fn vec_znx_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64; + pub unsafe fn vec_znx_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64; } diff --git a/backend/src/ffi/vec_znx_big.rs b/backend/src/implementation/cpu_spqlios/ffi/vec_znx_big.rs similarity index 91% rename from backend/src/ffi/vec_znx_big.rs rename to backend/src/implementation/cpu_spqlios/ffi/vec_znx_big.rs index 1353051..16d5647 100644 --- a/backend/src/ffi/vec_znx_big.rs +++ b/backend/src/implementation/cpu_spqlios/ffi/vec_znx_big.rs @@ -1,4 +1,4 @@ -use crate::ffi::module::MODULE; +use crate::implementation::cpu_spqlios::ffi::module::MODULE; #[repr(C)] #[derive(Debug, Copy, Clone)] @@ -103,12 +103,13 @@ unsafe extern "C" { } unsafe extern "C" { - pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64; + pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64; } unsafe extern "C" { pub unsafe fn vec_znx_big_normalize_base2k( module: *const MODULE, + n: u64, log2_base2k: u64, res: *mut i64, res_size: u64, @@ -122,6 +123,7 @@ unsafe extern "C" { unsafe extern "C" { pub unsafe fn vec_znx_big_range_normalize_base2k( module: *const MODULE, + n: u64, 
log2_base2k: u64, res: *mut i64, res_size: u64, @@ -135,7 +137,7 @@ unsafe extern "C" { } unsafe extern "C" { - pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64; + pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64; } unsafe extern "C" { diff --git a/backend/src/ffi/vec_znx_dft.rs b/backend/src/implementation/cpu_spqlios/ffi/vec_znx_dft.rs similarity index 92% rename from backend/src/ffi/vec_znx_dft.rs rename to backend/src/implementation/cpu_spqlios/ffi/vec_znx_dft.rs index 8f427bd..00bb2cd 100644 --- a/backend/src/ffi/vec_znx_dft.rs +++ b/backend/src/implementation/cpu_spqlios/ffi/vec_znx_dft.rs @@ -1,86 +1,85 @@ -use crate::ffi::module::MODULE; -use crate::ffi::vec_znx_big::VEC_ZNX_BIG; - -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct vec_znx_dft_t { - _unused: [u8; 0], -} -pub type VEC_ZNX_DFT = vec_znx_dft_t; - -unsafe extern "C" { - pub unsafe fn bytes_of_vec_znx_dft(module: *const MODULE, size: u64) -> u64; -} -unsafe extern "C" { - pub unsafe fn new_vec_znx_dft(module: *const MODULE, size: u64) -> *mut VEC_ZNX_DFT; -} -unsafe extern "C" { - pub unsafe fn delete_vec_znx_dft(res: *mut VEC_ZNX_DFT); -} - -unsafe extern "C" { - pub unsafe fn vec_dft_zero(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64); -} -unsafe extern "C" { - pub unsafe fn vec_dft_add( - module: *const MODULE, - res: *mut VEC_ZNX_DFT, - res_size: u64, - a: *const VEC_ZNX_DFT, - a_size: u64, - b: *const VEC_ZNX_DFT, - b_size: u64, - ); -} -unsafe extern "C" { - pub unsafe fn vec_dft_sub( - module: *const MODULE, - res: *mut VEC_ZNX_DFT, - res_size: u64, - a: *const VEC_ZNX_DFT, - a_size: u64, - b: *const VEC_ZNX_DFT, - b_size: u64, - ); -} -unsafe extern "C" { - pub unsafe fn vec_znx_dft(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64, a: *const i64, a_size: u64, a_sl: u64); -} -unsafe extern "C" { - pub unsafe fn vec_znx_idft( - module: *const MODULE, - res: *mut VEC_ZNX_BIG, - res_size: u64, - a_dft: *const VEC_ZNX_DFT, - a_size: u64, - tmp: *mut u8, - ); -} -unsafe extern "C" { - pub unsafe fn vec_znx_idft_tmp_bytes(module: *const MODULE) -> u64; -} -unsafe extern "C" { - pub unsafe fn vec_znx_idft_tmp_a( - module: *const MODULE, - res: *mut VEC_ZNX_BIG, - res_size: u64, - a_dft: *mut VEC_ZNX_DFT, - a_size: u64, - ); -} - -unsafe extern "C" { - pub unsafe fn vec_znx_dft_automorphism( - module: *const MODULE, - d: i64, - res_dft: *mut VEC_ZNX_DFT, - res_size: u64, - a_dft: *const VEC_ZNX_DFT, - a_size: u64, - tmp: *mut u8, - ); -} - -unsafe extern "C" { - pub unsafe fn vec_znx_dft_automorphism_tmp_bytes(module: *const MODULE) -> u64; -} +use crate::implementation::cpu_spqlios::ffi::{module::MODULE, vec_znx_big::VEC_ZNX_BIG}; + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct vec_znx_dft_t { + _unused: [u8; 0], +} +pub type VEC_ZNX_DFT = vec_znx_dft_t; + +unsafe extern "C" { + pub unsafe fn bytes_of_vec_znx_dft(module: *const MODULE, size: u64) -> u64; +} +unsafe extern "C" { + pub unsafe fn new_vec_znx_dft(module: *const MODULE, size: u64) -> *mut VEC_ZNX_DFT; +} +unsafe extern "C" { + pub unsafe fn delete_vec_znx_dft(res: *mut VEC_ZNX_DFT); +} + +unsafe extern "C" { + pub unsafe fn vec_dft_zero(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64); +} +unsafe extern "C" { + pub unsafe fn vec_dft_add( + module: *const MODULE, + res: *mut VEC_ZNX_DFT, + res_size: u64, + a: *const VEC_ZNX_DFT, + a_size: u64, + b: *const VEC_ZNX_DFT, + b_size: u64, + ); +} +unsafe extern "C" { + pub unsafe 
fn vec_dft_sub( + module: *const MODULE, + res: *mut VEC_ZNX_DFT, + res_size: u64, + a: *const VEC_ZNX_DFT, + a_size: u64, + b: *const VEC_ZNX_DFT, + b_size: u64, + ); +} +unsafe extern "C" { + pub unsafe fn vec_znx_dft(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64, a: *const i64, a_size: u64, a_sl: u64); +} +unsafe extern "C" { + pub unsafe fn vec_znx_idft( + module: *const MODULE, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a_dft: *const VEC_ZNX_DFT, + a_size: u64, + tmp: *mut u8, + ); +} +unsafe extern "C" { + pub unsafe fn vec_znx_idft_tmp_bytes(module: *const MODULE) -> u64; +} +unsafe extern "C" { + pub unsafe fn vec_znx_idft_tmp_a( + module: *const MODULE, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a_dft: *mut VEC_ZNX_DFT, + a_size: u64, + ); +} + +unsafe extern "C" { + pub unsafe fn vec_znx_dft_automorphism( + module: *const MODULE, + d: i64, + res_dft: *mut VEC_ZNX_DFT, + res_size: u64, + a_dft: *const VEC_ZNX_DFT, + a_size: u64, + tmp: *mut u8, + ); +} + +unsafe extern "C" { + pub unsafe fn vec_znx_dft_automorphism_tmp_bytes(module: *const MODULE) -> u64; +} diff --git a/backend/src/ffi/vmp.rs b/backend/src/implementation/cpu_spqlios/ffi/vmp.rs similarity index 64% rename from backend/src/ffi/vmp.rs rename to backend/src/implementation/cpu_spqlios/ffi/vmp.rs index 4f58e9b..c742cea 100644 --- a/backend/src/ffi/vmp.rs +++ b/backend/src/implementation/cpu_spqlios/ffi/vmp.rs @@ -1,167 +1,113 @@ -use crate::ffi::module::MODULE; -use crate::ffi::vec_znx_big::VEC_ZNX_BIG; -use crate::ffi::vec_znx_dft::VEC_ZNX_DFT; - -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct vmp_pmat_t { - _unused: [u8; 0], -} - -// [rows][cols] = [#Decomposition][#Limbs] -pub type VMP_PMAT = vmp_pmat_t; - -unsafe extern "C" { - pub unsafe fn bytes_of_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> u64; -} -unsafe extern "C" { - pub unsafe fn new_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> *mut VMP_PMAT; -} -unsafe extern "C" { - pub unsafe fn delete_vmp_pmat(res: *mut VMP_PMAT); -} - -unsafe extern "C" { - pub unsafe fn vmp_apply_dft( - module: *const MODULE, - res: *mut VEC_ZNX_DFT, - res_size: u64, - a: *const i64, - a_size: u64, - a_sl: u64, - pmat: *const VMP_PMAT, - nrows: u64, - ncols: u64, - tmp_space: *mut u8, - ); -} - -unsafe extern "C" { - pub unsafe fn vmp_apply_dft_add( - module: *const MODULE, - res: *mut VEC_ZNX_DFT, - res_size: u64, - a: *const i64, - a_size: u64, - a_sl: u64, - pmat: *const VMP_PMAT, - nrows: u64, - ncols: u64, - pmat_scale: u64, - tmp_space: *mut u8, - ); -} - -unsafe extern "C" { - pub unsafe fn vmp_apply_dft_tmp_bytes(module: *const MODULE, res_size: u64, a_size: u64, nrows: u64, ncols: u64) -> u64; -} - -unsafe extern "C" { - pub unsafe fn vmp_apply_dft_to_dft( - module: *const MODULE, - res: *mut VEC_ZNX_DFT, - res_size: u64, - a_dft: *const VEC_ZNX_DFT, - a_size: u64, - pmat: *const VMP_PMAT, - nrows: u64, - ncols: u64, - tmp_space: *mut u8, - ); -} - -unsafe extern "C" { - pub unsafe fn vmp_apply_dft_to_dft_add( - module: *const MODULE, - res: *mut VEC_ZNX_DFT, - res_size: u64, - a_dft: *const VEC_ZNX_DFT, - a_size: u64, - pmat: *const VMP_PMAT, - nrows: u64, - ncols: u64, - pmat_scale: u64, - tmp_space: *mut u8, - ); -} - -unsafe extern "C" { - pub unsafe fn vmp_apply_dft_to_dft_tmp_bytes( - module: *const MODULE, - res_size: u64, - a_size: u64, - nrows: u64, - ncols: u64, - ) -> u64; -} - -unsafe extern "C" { - pub unsafe fn vmp_prepare_contiguous( - module: *const MODULE, - pmat: *mut VMP_PMAT, - mat: *const i64, - nrows: 
u64, - ncols: u64, - tmp_space: *mut u8, - ); -} - -unsafe extern "C" { - pub unsafe fn vmp_prepare_dblptr( - module: *const MODULE, - pmat: *mut VMP_PMAT, - mat: *const *const i64, - nrows: u64, - ncols: u64, - tmp_space: *mut u8, - ); -} - -unsafe extern "C" { - pub unsafe fn vmp_prepare_row( - module: *const MODULE, - pmat: *mut VMP_PMAT, - row: *const i64, - row_i: u64, - nrows: u64, - ncols: u64, - tmp_space: *mut u8, - ); -} - -unsafe extern "C" { - pub unsafe fn vmp_prepare_row_dft( - module: *const MODULE, - pmat: *mut VMP_PMAT, - row: *const VEC_ZNX_DFT, - row_i: u64, - nrows: u64, - ncols: u64, - ); -} - -unsafe extern "C" { - pub unsafe fn vmp_extract_row_dft( - module: *const MODULE, - res: *mut VEC_ZNX_DFT, - pmat: *const VMP_PMAT, - row_i: u64, - nrows: u64, - ncols: u64, - ); -} - -unsafe extern "C" { - pub unsafe fn vmp_extract_row( - module: *const MODULE, - res: *mut VEC_ZNX_BIG, - pmat: *const VMP_PMAT, - row_i: u64, - nrows: u64, - ncols: u64, - ); -} - -unsafe extern "C" { - pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64; -} +use crate::implementation::cpu_spqlios::ffi::{module::MODULE, vec_znx_dft::VEC_ZNX_DFT}; + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct vmp_pmat_t { + _unused: [u8; 0], +} + +// [rows][cols] = [#Decomposition][#Limbs] +pub type VMP_PMAT = vmp_pmat_t; + +unsafe extern "C" { + pub unsafe fn bytes_of_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> u64; +} +unsafe extern "C" { + pub unsafe fn new_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> *mut VMP_PMAT; +} +unsafe extern "C" { + pub unsafe fn delete_vmp_pmat(res: *mut VMP_PMAT); +} + +unsafe extern "C" { + pub unsafe fn vmp_apply_dft( + module: *const MODULE, + res: *mut VEC_ZNX_DFT, + res_size: u64, + a: *const i64, + a_size: u64, + a_sl: u64, + pmat: *const VMP_PMAT, + nrows: u64, + ncols: u64, + tmp_space: *mut u8, + ); +} + +unsafe extern "C" { + pub unsafe fn vmp_apply_dft_add( + module: *const MODULE, + res: *mut VEC_ZNX_DFT, + res_size: u64, + a: *const i64, + a_size: u64, + a_sl: u64, + pmat: *const VMP_PMAT, + nrows: u64, + ncols: u64, + pmat_scale: u64, + tmp_space: *mut u8, + ); +} + +unsafe extern "C" { + pub unsafe fn vmp_apply_dft_tmp_bytes(module: *const MODULE, res_size: u64, a_size: u64, nrows: u64, ncols: u64) -> u64; +} + +unsafe extern "C" { + pub unsafe fn vmp_apply_dft_to_dft( + module: *const MODULE, + res: *mut VEC_ZNX_DFT, + res_size: u64, + a_dft: *const VEC_ZNX_DFT, + a_size: u64, + pmat: *const VMP_PMAT, + nrows: u64, + ncols: u64, + tmp_space: *mut u8, + ); +} + +unsafe extern "C" { + pub unsafe fn vmp_apply_dft_to_dft_add( + module: *const MODULE, + res: *mut VEC_ZNX_DFT, + res_size: u64, + a_dft: *const VEC_ZNX_DFT, + a_size: u64, + pmat: *const VMP_PMAT, + nrows: u64, + ncols: u64, + pmat_scale: u64, + tmp_space: *mut u8, + ); +} + +unsafe extern "C" { + pub unsafe fn vmp_apply_dft_to_dft_tmp_bytes( + module: *const MODULE, + res_size: u64, + a_size: u64, + nrows: u64, + ncols: u64, + ) -> u64; +} + +unsafe extern "C" { + pub unsafe fn vmp_prepare_contiguous( + module: *const MODULE, + pmat: *mut VMP_PMAT, + mat: *const i64, + nrows: u64, + ncols: u64, + tmp_space: *mut u8, + ); +} + +unsafe extern "C" { + pub unsafe fn vmp_prepare_contiguous_dft(module: *const MODULE, pmat: *mut VMP_PMAT, mat: *const f64, nrows: u64, ncols: u64); +} + +unsafe extern "C" { + pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64; +} diff --git a/backend/src/ffi/znx.rs 
b/backend/src/implementation/cpu_spqlios/ffi/znx.rs similarity index 84% rename from backend/src/ffi/znx.rs rename to backend/src/implementation/cpu_spqlios/ffi/znx.rs index dc30db6..f03da0a 100644 --- a/backend/src/ffi/znx.rs +++ b/backend/src/implementation/cpu_spqlios/ffi/znx.rs @@ -1,76 +1,79 @@ -use crate::ffi::module::MODULE; - -unsafe extern "C" { - pub unsafe fn znx_add_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64); -} -unsafe extern "C" { - pub unsafe fn znx_add_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64); -} -unsafe extern "C" { - pub unsafe fn znx_sub_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64); -} -unsafe extern "C" { - pub unsafe fn znx_sub_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64); -} -unsafe extern "C" { - pub unsafe fn znx_negate_i64_ref(nn: u64, res: *mut i64, a: *const i64); -} -unsafe extern "C" { - pub unsafe fn znx_negate_i64_avx(nn: u64, res: *mut i64, a: *const i64); -} -unsafe extern "C" { - pub unsafe fn znx_copy_i64_ref(nn: u64, res: *mut i64, a: *const i64); -} -unsafe extern "C" { - pub unsafe fn znx_zero_i64_ref(nn: u64, res: *mut i64); -} -unsafe extern "C" { - pub unsafe fn rnx_divide_by_m_ref(nn: u64, m: f64, res: *mut f64, a: *const f64); -} -unsafe extern "C" { - pub unsafe fn rnx_divide_by_m_avx(nn: u64, m: f64, res: *mut f64, a: *const f64); -} -unsafe extern "C" { - pub unsafe fn rnx_rotate_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64); -} -unsafe extern "C" { - pub unsafe fn znx_rotate_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64); -} -unsafe extern "C" { - pub unsafe fn rnx_rotate_inplace_f64(nn: u64, p: i64, res: *mut f64); -} -unsafe extern "C" { - pub unsafe fn znx_rotate_inplace_i64(nn: u64, p: i64, res: *mut i64); -} -unsafe extern "C" { - pub unsafe fn rnx_automorphism_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64); -} -unsafe extern "C" { - pub unsafe fn znx_automorphism_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64); -} -unsafe extern "C" { - pub unsafe fn rnx_automorphism_inplace_f64(nn: u64, p: i64, res: *mut f64); -} -unsafe extern "C" { - pub unsafe fn znx_automorphism_inplace_i64(nn: u64, p: i64, res: *mut i64); -} -unsafe extern "C" { - pub unsafe fn rnx_mul_xp_minus_one(nn: u64, p: i64, res: *mut f64, in_: *const f64); -} -unsafe extern "C" { - pub unsafe fn znx_mul_xp_minus_one(nn: u64, p: i64, res: *mut i64, in_: *const i64); -} -unsafe extern "C" { - pub unsafe fn rnx_mul_xp_minus_one_inplace(nn: u64, p: i64, res: *mut f64); -} -unsafe extern "C" { - pub unsafe fn znx_normalize(nn: u64, base_k: u64, out: *mut i64, carry_out: *mut i64, in_: *const i64, carry_in: *const i64); -} - -unsafe extern "C" { - pub unsafe fn znx_small_single_product(module: *const MODULE, res: *mut i64, a: *const i64, b: *const i64, tmp: *mut u8); -} - -unsafe extern "C" { - pub unsafe fn znx_small_single_product_tmp_bytes(module: *const MODULE) -> u64; -} +use crate::implementation::cpu_spqlios::ffi::module::MODULE; + +unsafe extern "C" { + pub unsafe fn znx_add_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64); +} +unsafe extern "C" { + pub unsafe fn znx_add_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64); +} +unsafe extern "C" { + pub unsafe fn znx_sub_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64); +} +unsafe extern "C" { + pub unsafe fn znx_sub_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64); +} +unsafe extern "C" { + pub unsafe fn znx_negate_i64_ref(nn: u64, res: *mut i64, a: *const i64); +} +unsafe 
extern "C" { + pub unsafe fn znx_negate_i64_avx(nn: u64, res: *mut i64, a: *const i64); +} +unsafe extern "C" { + pub unsafe fn znx_copy_i64_ref(nn: u64, res: *mut i64, a: *const i64); +} +unsafe extern "C" { + pub unsafe fn znx_zero_i64_ref(nn: u64, res: *mut i64); +} +unsafe extern "C" { + pub unsafe fn rnx_divide_by_m_ref(nn: u64, m: f64, res: *mut f64, a: *const f64); +} +unsafe extern "C" { + pub unsafe fn rnx_divide_by_m_avx(nn: u64, m: f64, res: *mut f64, a: *const f64); +} +unsafe extern "C" { + pub unsafe fn rnx_rotate_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64); +} +unsafe extern "C" { + pub unsafe fn znx_rotate_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64); +} +unsafe extern "C" { + pub unsafe fn rnx_rotate_inplace_f64(nn: u64, p: i64, res: *mut f64); +} +unsafe extern "C" { + pub unsafe fn znx_rotate_inplace_i64(nn: u64, p: i64, res: *mut i64); +} +unsafe extern "C" { + pub unsafe fn rnx_automorphism_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64); +} +unsafe extern "C" { + pub unsafe fn znx_automorphism_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64); +} +unsafe extern "C" { + pub unsafe fn rnx_automorphism_inplace_f64(nn: u64, p: i64, res: *mut f64); +} +unsafe extern "C" { + pub unsafe fn znx_automorphism_inplace_i64(nn: u64, p: i64, res: *mut i64); +} +unsafe extern "C" { + pub unsafe fn rnx_mul_xp_minus_one_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64); +} +unsafe extern "C" { + pub unsafe fn znx_mul_xp_minus_one_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64); +} +unsafe extern "C" { + pub unsafe fn rnx_mul_xp_minus_one_inplace_f64(nn: u64, p: i64, res: *mut f64); +} +unsafe extern "C" { + pub unsafe fn znx_mul_xp_minus_one_inplace_i64(nn: u64, p: i64, res: *mut i64); +} +unsafe extern "C" { + pub unsafe fn znx_normalize(nn: u64, base_k: u64, out: *mut i64, carry_out: *mut i64, in_: *const i64, carry_in: *const i64); +} + +unsafe extern "C" { + pub unsafe fn znx_small_single_product(module: *const MODULE, res: *mut i64, a: *const i64, b: *const i64, tmp: *mut u8); +} + +unsafe extern "C" { + pub unsafe fn znx_small_single_product_tmp_bytes(module: *const MODULE) -> u64; +} diff --git a/backend/src/implementation/cpu_spqlios/mat_znx.rs b/backend/src/implementation/cpu_spqlios/mat_znx.rs new file mode 100644 index 0000000..4e91184 --- /dev/null +++ b/backend/src/implementation/cpu_spqlios/mat_znx.rs @@ -0,0 +1,41 @@ +use crate::{ + hal::{ + layouts::{Backend, MatZnxOwned, Module}, + oep::{MatZnxAllocBytesImpl, MatZnxAllocImpl, MatZnxFromBytesImpl}, + }, + implementation::cpu_spqlios::CPUAVX, +}; + +unsafe impl MatZnxAllocImpl for B +where + B: CPUAVX, +{ + fn mat_znx_alloc_impl(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned { + MatZnxOwned::new(module.n(), rows, cols_in, cols_out, size) + } +} + +unsafe impl MatZnxAllocBytesImpl for B +where + B: CPUAVX, +{ + fn mat_znx_alloc_bytes_impl(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + MatZnxOwned::bytes_of(module.n(), rows, cols_in, cols_out, size) + } +} + +unsafe impl MatZnxFromBytesImpl for B +where + B: CPUAVX, +{ + fn mat_znx_from_bytes_impl( + module: &Module, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + bytes: Vec, + ) -> MatZnxOwned { + MatZnxOwned::new_from_bytes(module.n(), rows, cols_in, cols_out, size, bytes) + } +} diff --git a/backend/src/implementation/cpu_spqlios/mod.rs b/backend/src/implementation/cpu_spqlios/mod.rs new file mode 100644 index 0000000..0863ae1 --- 
/dev/null +++ b/backend/src/implementation/cpu_spqlios/mod.rs @@ -0,0 +1,26 @@ +mod ffi; +mod mat_znx; +mod module_fft64; +mod module_ntt120; +mod scalar_znx; +mod scratch; +mod svp_ppol_fft64; +mod svp_ppol_ntt120; +mod vec_znx; +mod vec_znx_big_fft64; +mod vec_znx_big_ntt120; +mod vec_znx_dft_fft64; +mod vec_znx_dft_ntt120; +mod vmp_pmat_fft64; +mod vmp_pmat_ntt120; + +#[cfg(test)] +mod test; + +pub use module_fft64::*; +pub use module_ntt120::*; + +/// For external documentation +pub use vec_znx::{vec_znx_copy_ref, vec_znx_lsh_inplace_ref, vec_znx_merge_ref, vec_znx_rsh_inplace_ref, vec_znx_split_ref, vec_znx_switch_degree_ref}; + +pub trait CPUAVX {} diff --git a/backend/src/implementation/cpu_spqlios/module_fft64.rs b/backend/src/implementation/cpu_spqlios/module_fft64.rs new file mode 100644 index 0000000..3f86c6c --- /dev/null +++ b/backend/src/implementation/cpu_spqlios/module_fft64.rs @@ -0,0 +1,29 @@ +use std::ptr::NonNull; + +use crate::{ + hal::{ + layouts::{Backend, Module}, + oep::ModuleNewImpl, + }, + implementation::cpu_spqlios::{ + CPUAVX, + ffi::module::{MODULE, delete_module_info, new_module_info}, + }, +}; + +pub struct FFT64; + +impl CPUAVX for FFT64 {} + +impl Backend for FFT64 { + type Handle = MODULE; + unsafe fn destroy(handle: NonNull) { + unsafe { delete_module_info(handle.as_ptr()) } + } +} + +unsafe impl ModuleNewImpl for FFT64 { + fn new_impl(n: u64) -> Module { + unsafe { Module::from_raw_parts(new_module_info(n, 0), n) } + } +} diff --git a/backend/src/implementation/cpu_spqlios/module_ntt120.rs b/backend/src/implementation/cpu_spqlios/module_ntt120.rs new file mode 100644 index 0000000..94e0bdb --- /dev/null +++ b/backend/src/implementation/cpu_spqlios/module_ntt120.rs @@ -0,0 +1,29 @@ +use std::ptr::NonNull; + +use crate::{ + hal::{ + layouts::{Backend, Module}, + oep::ModuleNewImpl, + }, + implementation::cpu_spqlios::{ + CPUAVX, + ffi::module::{MODULE, delete_module_info, new_module_info}, + }, +}; + +pub struct NTT120; + +impl CPUAVX for NTT120 {} + +impl Backend for NTT120 { + type Handle = MODULE; + unsafe fn destroy(handle: NonNull) { + unsafe { delete_module_info(handle.as_ptr()) } + } +} + +unsafe impl ModuleNewImpl for NTT120 { + fn new_impl(n: u64) -> Module { + unsafe { Module::from_raw_parts(new_module_info(n, 1), n) } + } +} diff --git a/backend/src/implementation/cpu_spqlios/scalar_znx.rs b/backend/src/implementation/cpu_spqlios/scalar_znx.rs new file mode 100644 index 0000000..3b39958 --- /dev/null +++ b/backend/src/implementation/cpu_spqlios/scalar_znx.rs @@ -0,0 +1,100 @@ +use crate::{ + hal::{ + api::{ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut}, + layouts::{Backend, Module, ScalarZnx, ScalarZnxOwned, ScalarZnxToMut, ScalarZnxToRef}, + oep::{ + ScalarZnxAllocBytesImpl, ScalarZnxAllocImpl, ScalarZnxAutomorphismImpl, ScalarZnxAutomorphismInplaceIml, + ScalarZnxFromBytesImpl, + }, + }, + implementation::cpu_spqlios::{ + CPUAVX, + ffi::{module::module_info_t, vec_znx}, + }, +}; + +unsafe impl ScalarZnxAllocBytesImpl for B +where + B: CPUAVX, +{ + fn scalar_znx_alloc_bytes_impl(n: usize, cols: usize) -> usize { + ScalarZnxOwned::bytes_of(n, cols) + } +} + +unsafe impl ScalarZnxAllocImpl for B +where + B: CPUAVX, +{ + fn scalar_znx_alloc_impl(n: usize, cols: usize) -> ScalarZnxOwned { + ScalarZnxOwned::new(n, cols) + } +} + +unsafe impl ScalarZnxFromBytesImpl for B +where + B: CPUAVX, +{ + fn scalar_znx_from_bytes_impl(n: usize, cols: usize, bytes: Vec) -> ScalarZnxOwned { + ScalarZnxOwned::new_from_bytes(n, cols, bytes) + } +} + +unsafe impl 
ScalarZnxAutomorphismImpl for B +where + B: CPUAVX, +{ + fn scalar_znx_automorphism_impl(module: &Module, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: ScalarZnxToMut, + A: ScalarZnxToRef, + { + let a: ScalarZnx<&[u8]> = a.to_ref(); + let mut res: ScalarZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + unsafe { + vec_znx::vec_znx_automorphism( + module.ptr() as *const module_info_t, + k, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } +} + +unsafe impl ScalarZnxAutomorphismInplaceIml for B +where + B: CPUAVX, +{ + fn scalar_znx_automorphism_inplace_impl(module: &Module, k: i64, a: &mut A, a_col: usize) + where + A: ScalarZnxToMut, + { + let mut a: ScalarZnx<&mut [u8]> = a.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + } + unsafe { + vec_znx::vec_znx_automorphism( + module.ptr() as *const module_info_t, + k, + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } +} diff --git a/backend/src/implementation/cpu_spqlios/scratch.rs b/backend/src/implementation/cpu_spqlios/scratch.rs new file mode 100644 index 0000000..1f234c4 --- /dev/null +++ b/backend/src/implementation/cpu_spqlios/scratch.rs @@ -0,0 +1,274 @@ +use std::marker::PhantomData; + +use crate::{ + DEFAULTALIGN, alloc_aligned, + hal::{ + api::ScratchFromBytes, + layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, + oep::{ + ScalarZnxAllocBytesImpl, ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, + SvpPPolAllocBytesImpl, TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, + TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, + VecZnxAllocBytesImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl, + }, + }, + implementation::cpu_spqlios::CPUAVX, +}; + +unsafe impl ScratchOwnedAllocImpl for B +where + B: CPUAVX, +{ + fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned { + let data: Vec = alloc_aligned(size); + ScratchOwned { + data, + _phantom: PhantomData, + } + } +} + +unsafe impl ScratchOwnedBorrowImpl for B +where + B: CPUAVX, +{ + fn scratch_owned_borrow_impl(scratch: &mut ScratchOwned) -> &mut Scratch { + Scratch::from_bytes(&mut scratch.data) + } +} + +unsafe impl ScratchFromBytesImpl for B +where + B: CPUAVX, +{ + fn scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch { + unsafe { &mut *(data as *mut [u8] as *mut Scratch) } + } +} + +unsafe impl ScratchAvailableImpl for B +where + B: CPUAVX, +{ + fn scratch_available_impl(scratch: &Scratch) -> usize { + let ptr: *const u8 = scratch.data.as_ptr(); + let self_len: usize = scratch.data.len(); + let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN); + self_len.saturating_sub(aligned_offset) + } +} + +unsafe impl TakeSliceImpl for B +where + B: CPUAVX, +{ + fn take_slice_impl(scratch: &mut Scratch, len: usize) -> (&mut [T], &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, len * std::mem::size_of::()); + + unsafe { + ( + &mut *(std::ptr::slice_from_raw_parts_mut(take_slice.as_mut_ptr() as *mut T, len)), + Scratch::from_bytes(rem_slice), + ) + } + } +} + +unsafe impl TakeScalarZnxImpl for B +where + B: CPUAVX + 
ScalarZnxAllocBytesImpl, +{ + fn take_scalar_znx_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::scalar_znx_alloc_bytes_impl(n, cols)); + ( + ScalarZnx::from_data(take_slice, n, cols), + Scratch::from_bytes(rem_slice), + ) + } +} + +unsafe impl TakeSvpPPolImpl for B +where + B: CPUAVX + SvpPPolAllocBytesImpl, +{ + fn take_svp_ppol_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::svp_ppol_alloc_bytes_impl(n, cols)); + ( + SvpPPol::from_data(take_slice, n, cols), + Scratch::from_bytes(rem_slice), + ) + } +} + +unsafe impl TakeVecZnxImpl for B +where + B: CPUAVX + VecZnxAllocBytesImpl, +{ + fn take_vec_znx_impl(scratch: &mut Scratch, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned( + &mut scratch.data, + B::vec_znx_alloc_bytes_impl(n, cols, size), + ); + ( + VecZnx::from_data(take_slice, n, cols, size), + Scratch::from_bytes(rem_slice), + ) + } +} + +unsafe impl TakeVecZnxBigImpl for B +where + B: CPUAVX + VecZnxBigAllocBytesImpl, +{ + fn take_vec_znx_big_impl( + scratch: &mut Scratch, + n: usize, + cols: usize, + size: usize, + ) -> (VecZnxBig<&mut [u8], B>, &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned( + &mut scratch.data, + B::vec_znx_big_alloc_bytes_impl(n, cols, size), + ); + ( + VecZnxBig::from_data(take_slice, n, cols, size), + Scratch::from_bytes(rem_slice), + ) + } +} + +unsafe impl TakeVecZnxDftImpl for B +where + B: CPUAVX + VecZnxDftAllocBytesImpl, +{ + fn take_vec_znx_dft_impl( + scratch: &mut Scratch, + n: usize, + cols: usize, + size: usize, + ) -> (VecZnxDft<&mut [u8], B>, &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned( + &mut scratch.data, + B::vec_znx_dft_alloc_bytes_impl(n, cols, size), + ); + + ( + VecZnxDft::from_data(take_slice, n, cols, size), + Scratch::from_bytes(rem_slice), + ) + } +} + +unsafe impl TakeVecZnxDftSliceImpl for B +where + B: CPUAVX + VecZnxDftAllocBytesImpl, +{ + fn take_vec_znx_dft_slice_impl( + scratch: &mut Scratch, + len: usize, + n: usize, + cols: usize, + size: usize, + ) -> (Vec>, &mut Scratch) { + let mut scratch: &mut Scratch = scratch; + let mut slice: Vec> = Vec::with_capacity(len); + for _ in 0..len { + let (znx, new_scratch) = B::take_vec_znx_dft_impl(scratch, n, cols, size); + scratch = new_scratch; + slice.push(znx); + } + (slice, scratch) + } +} + +unsafe impl TakeVecZnxSliceImpl for B +where + B: CPUAVX, +{ + fn take_vec_znx_slice_impl( + scratch: &mut Scratch, + len: usize, + n: usize, + cols: usize, + size: usize, + ) -> (Vec>, &mut Scratch) { + let mut scratch: &mut Scratch = scratch; + let mut slice: Vec> = Vec::with_capacity(len); + for _ in 0..len { + let (znx, new_scratch) = B::take_vec_znx_impl(scratch, n, cols, size); + scratch = new_scratch; + slice.push(znx); + } + (slice, scratch) + } +} + +unsafe impl TakeVmpPMatImpl for B +where + B: CPUAVX + VmpPMatAllocBytesImpl, +{ + fn take_vmp_pmat_impl( + scratch: &mut Scratch, + n: usize, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + ) -> (VmpPMat<&mut [u8], B>, &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned( + &mut scratch.data, + B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size), + ); + ( + VmpPMat::from_data(take_slice, n, rows, cols_in, cols_out, size), + 
Scratch::from_bytes(rem_slice), + ) + } +} + +unsafe impl TakeMatZnxImpl for B +where + B: CPUAVX, +{ + fn take_mat_znx_impl( + scratch: &mut Scratch, + n: usize, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + ) -> (MatZnx<&mut [u8]>, &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned( + &mut scratch.data, + MatZnx::>::bytes_of(n, rows, cols_in, cols_out, size), + ); + ( + MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size), + Scratch::from_bytes(rem_slice), + ) + } +} + +fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) { + let ptr: *mut u8 = data.as_mut_ptr(); + let self_len: usize = data.len(); + + let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN); + let aligned_len: usize = self_len.saturating_sub(aligned_offset); + + if let Some(rem_len) = aligned_len.checked_sub(take_len) { + unsafe { + let rem_ptr: *mut u8 = ptr.add(aligned_offset).add(take_len); + let rem_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len); + + let take_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len); + + return (take_slice, rem_slice); + } + } else { + panic!( + "Attempted to take {} from scratch with {} aligned bytes left", + take_len, aligned_len, + ); + } +} diff --git a/backend/src/implementation/cpu_spqlios/spqlios-arithmetic b/backend/src/implementation/cpu_spqlios/spqlios-arithmetic new file mode 160000 index 0000000..7160f58 --- /dev/null +++ b/backend/src/implementation/cpu_spqlios/spqlios-arithmetic @@ -0,0 +1 @@ +Subproject commit 7160f588da49712a042931ea247b4259b95cefcc diff --git a/backend/src/implementation/cpu_spqlios/svp_ppol_fft64.rs b/backend/src/implementation/cpu_spqlios/svp_ppol_fft64.rs new file mode 100644 index 0000000..265840a --- /dev/null +++ b/backend/src/implementation/cpu_spqlios/svp_ppol_fft64.rs @@ -0,0 +1,114 @@ +use crate::{ + hal::{ + api::{ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut}, + layouts::{ + Data, DataRef, Module, ScalarZnxToRef, SvpPPol, SvpPPolBytesOf, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDft, + VecZnxDftToMut, VecZnxDftToRef, + }, + oep::{SvpApplyImpl, SvpApplyInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl}, + }, + implementation::cpu_spqlios::{ + ffi::{svp, vec_znx_dft::vec_znx_dft_t}, + module_fft64::FFT64, + }, +}; + +const SVP_PPOL_FFT64_WORD_SIZE: usize = 1; + +impl SvpPPolBytesOf for SvpPPol { + fn bytes_of(n: usize, cols: usize) -> usize { + SVP_PPOL_FFT64_WORD_SIZE * n * cols * size_of::() + } +} + +impl ZnxSliceSize for SvpPPol { + fn sl(&self) -> usize { + SVP_PPOL_FFT64_WORD_SIZE * self.n() + } +} + +impl ZnxView for SvpPPol { + type Scalar = f64; +} + +unsafe impl SvpPPolFromBytesImpl for FFT64 { + fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec) -> SvpPPolOwned { + SvpPPolOwned::from_bytes(n, cols, bytes) + } +} + +unsafe impl SvpPPolAllocImpl for FFT64 { + fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned { + SvpPPolOwned::alloc(n, cols) + } +} + +unsafe impl SvpPPolAllocBytesImpl for FFT64 { + fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize { + SvpPPol::, Self>::bytes_of(n, cols) + } +} + +unsafe impl SvpPrepareImpl for FFT64 { + fn svp_prepare_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: SvpPPolToMut, + A: ScalarZnxToRef, + { + unsafe { + svp::svp_prepare( + module.ptr(), + res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t, + a.to_ref().at_ptr(a_col, 
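// Editor's note: a standalone, safe-Rust model of the alignment logic in
// `take_slice_aligned` above. `align_offset` reports how many bytes to
// skip so the returned slice starts on an aligned address; requesting
// more than the aligned remainder must fail loudly rather than hand out
// an undersized buffer. ALIGN = 64 is an assumed stand-in for DEFAULTALIGN.
const ALIGN: usize = 64;

fn take_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) {
    let off = data.as_ptr().align_offset(ALIGN);
    assert!(
        take_len <= data.len().saturating_sub(off),
        "attempted to take {} bytes with too few aligned bytes left",
        take_len
    );
    data[off..].split_at_mut(take_len)
}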
0), + ) + } + } +} + +unsafe impl SvpApplyImpl for FFT64 { + fn svp_apply_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxDftToMut, + A: SvpPPolToRef, + B: VecZnxDftToRef, + { + let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut(); + let a: SvpPPol<&[u8], Self> = a.to_ref(); + let b: VecZnxDft<&[u8], Self> = b.to_ref(); + unsafe { + svp::svp_apply_dft_to_dft( + module.ptr(), + res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, + res.size() as u64, + res.cols() as u64, + a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, + b.at_ptr(b_col, 0) as *const vec_znx_dft_t, + b.size() as u64, + b.cols() as u64, + ) + } + } +} + +unsafe impl SvpApplyInplaceImpl for FFT64 { + fn svp_apply_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: SvpPPolToRef, + { + let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut(); + let a: SvpPPol<&[u8], Self> = a.to_ref(); + unsafe { + svp::svp_apply_dft_to_dft( + module.ptr(), + res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, + res.size() as u64, + res.cols() as u64, + a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, + res.at_ptr(res_col, 0) as *const vec_znx_dft_t, + res.size() as u64, + res.cols() as u64, + ) + } + } +} diff --git a/backend/src/implementation/cpu_spqlios/svp_ppol_ntt120.rs b/backend/src/implementation/cpu_spqlios/svp_ppol_ntt120.rs new file mode 100644 index 0000000..39a84f9 --- /dev/null +++ b/backend/src/implementation/cpu_spqlios/svp_ppol_ntt120.rs @@ -0,0 +1,44 @@ +use crate::{ + hal::{ + api::{ZnxInfos, ZnxSliceSize, ZnxView}, + layouts::{Data, DataRef, SvpPPol, SvpPPolBytesOf, SvpPPolOwned}, + oep::{SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl}, + }, + implementation::cpu_spqlios::module_ntt120::NTT120, +}; + +const SVP_PPOL_NTT120_WORD_SIZE: usize = 4; + +impl SvpPPolBytesOf for SvpPPol { + fn bytes_of(n: usize, cols: usize) -> usize { + SVP_PPOL_NTT120_WORD_SIZE * n * cols * size_of::() + } +} + +impl ZnxSliceSize for SvpPPol { + fn sl(&self) -> usize { + SVP_PPOL_NTT120_WORD_SIZE * self.n() + } +} + +impl ZnxView for SvpPPol { + type Scalar = i64; +} + +unsafe impl SvpPPolFromBytesImpl for NTT120 { + fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec) -> SvpPPolOwned { + SvpPPolOwned::from_bytes(n, cols, bytes) + } +} + +unsafe impl SvpPPolAllocImpl for NTT120 { + fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned { + SvpPPolOwned::alloc(n, cols) + } +} + +unsafe impl SvpPPolAllocBytesImpl for NTT120 { + fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize { + SvpPPol::, Self>::bytes_of(n, cols) + } +} diff --git a/backend/src/implementation/cpu_spqlios/test/mod.rs b/backend/src/implementation/cpu_spqlios/test/mod.rs new file mode 100644 index 0000000..636f5a4 --- /dev/null +++ b/backend/src/implementation/cpu_spqlios/test/mod.rs @@ -0,0 +1 @@ +mod vec_znx_fft64; diff --git a/backend/src/implementation/cpu_spqlios/test/vec_znx_fft64.rs b/backend/src/implementation/cpu_spqlios/test/vec_znx_fft64.rs new file mode 100644 index 0000000..a289439 --- /dev/null +++ b/backend/src/implementation/cpu_spqlios/test/vec_znx_fft64.rs @@ -0,0 +1,35 @@ +use crate::{ + hal::{ + api::ModuleNew, + layouts::Module, + tests::vec_znx::{ + test_vec_znx_add_normal, test_vec_znx_encode_vec_i64_hi_norm, test_vec_znx_encode_vec_i64_lo_norm, + test_vec_znx_fill_uniform, + }, + }, + implementation::cpu_spqlios::FFT64, +}; + +#[test] +fn test_vec_znx_fill_uniform_fft64() { + let module: Module = 
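// Editor's note: in the FFT64 backend an SvpPPol holds a single prepared
// polynomial in evaluation form, so `svp_apply_dft_to_dft` is, morally, a
// per-slot product applied to every limb of the VecZnxDft. The real
// layout packs complex evaluations into f64 words; this sketch simplifies
// the slots to plain f64 for intuition:
fn svp_apply_ref(ppol: &[f64], limbs: &mut [Vec<f64>]) {
    for limb in limbs.iter_mut() {
        for (x, &s) in limb.iter_mut().zip(ppol.iter()) {
            *x *= s; // product in evaluation domain == negacyclic convolution
        }
    }
}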
Module::::new(1 << 12); + test_vec_znx_fill_uniform(&module); +} + +#[test] +fn test_vec_znx_add_normal_fft64() { + let module: Module = Module::::new(1 << 12); + test_vec_znx_add_normal(&module); +} + +#[test] +fn test_vec_znx_encode_vec_lo_norm_fft64() { + let module: Module = Module::::new(1 << 8); + test_vec_znx_encode_vec_i64_lo_norm(&module); +} + +#[test] +fn test_vec_znx_encode_vec_hi_norm_fft64() { + let module: Module = Module::::new(1 << 8); + test_vec_znx_encode_vec_i64_hi_norm(&module); +} diff --git a/backend/src/implementation/cpu_spqlios/vec_znx.rs b/backend/src/implementation/cpu_spqlios/vec_znx.rs new file mode 100644 index 0000000..d446ea4 --- /dev/null +++ b/backend/src/implementation/cpu_spqlios/vec_znx.rs @@ -0,0 +1,1344 @@ +use itertools::izip; +use rand_distr::Normal; +use rug::{ + Assign, Float, + float::Round, + ops::{AddAssignRound, DivAssignRound, SubAssignRound}, +}; +use sampling::source::Source; + +use crate::{ + hal::{ + api::{ + TakeSlice, TakeVecZnx, VecZnxAddDistF64, VecZnxCopy, VecZnxDecodeVecFloat, VecZnxFillDistF64, + VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView, + ZnxViewMut, ZnxZero, + }, + layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxOwned, VecZnxToMut, VecZnxToRef}, + oep::{ + VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl, + VecZnxAllocBytesImpl, VecZnxAllocImpl, VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, + VecZnxDecodeCoeffsi64Impl, VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl, + VecZnxEncodeVeci64Impl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxFromBytesImpl, + VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, + VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, + VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxStdImpl, + VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl, + }, + }, + implementation::cpu_spqlios::{ + CPUAVX, + ffi::{module::module_info_t, vec_znx, znx}, + }, +}; + +unsafe impl VecZnxAllocImpl for B +where + B: CPUAVX, +{ + fn vec_znx_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxOwned { + VecZnxOwned::new::(n, cols, size) + } +} + +unsafe impl VecZnxFromBytesImpl for B +where + B: CPUAVX, +{ + fn vec_znx_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxOwned { + VecZnxOwned::from_bytes::(n, cols, size, bytes) + } +} + +unsafe impl VecZnxAllocBytesImpl for B +where + B: CPUAVX, +{ + fn vec_znx_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + VecZnxOwned::alloc_bytes::(n, cols, size) + } +} + +unsafe impl VecZnxNormalizeTmpBytesImpl for B +where + B: CPUAVX, +{ + fn vec_znx_normalize_tmp_bytes_impl(module: &Module, n: usize) -> usize { + unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr() as *const module_info_t, n as u64) as usize } + } +} + +unsafe impl VecZnxNormalizeImpl for B +where + B: CPUAVX, +{ + fn vec_znx_normalize_impl( + module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + 
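// Editor's note: the fft64 tests above all share one shape: instantiate a
// Module for a power-of-two ring degree, then delegate to the
// backend-generic test in hal::tests. A test for another op would follow
// the same template (test_vec_znx_rotate is a hypothetical name, not part
// of this patch):
//
//     #[test]
//     fn test_vec_znx_rotate_fft64() {
//         let module: Module<FFT64> = Module::<FFT64>::new(1 << 12);
//         test_vec_znx_rotate(&module);
//     }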
#[cfg(debug_assertions)] + { + assert_eq!(res.n(), a.n()); + } + + let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes(a.n())); + + unsafe { + vec_znx::vec_znx_normalize_base2k( + module.ptr() as *const module_info_t, + a.n() as u64, + basek as u64, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + tmp_bytes.as_mut_ptr(), + ); + } + } +} + +unsafe impl VecZnxNormalizeInplaceImpl for B +where + B: CPUAVX, +{ + fn vec_znx_normalize_inplace_impl(module: &Module, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + where + A: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = a.to_mut(); + + let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes(a.n())); + + unsafe { + vec_znx::vec_znx_normalize_base2k( + module.ptr() as *const module_info_t, + a.n() as u64, + basek as u64, + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + tmp_bytes.as_mut_ptr(), + ); + } + } +} + +unsafe impl VecZnxAddImpl for B +where + B: CPUAVX, +{ + fn vec_znx_add_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + C: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(b.n(), module.n()); + assert_eq!(res.n(), module.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + unsafe { + vec_znx::vec_znx_add( + module.ptr() as *const module_info_t, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(b_col, 0), + b.size() as u64, + b.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxAddInplaceImpl for B +where + B: CPUAVX, +{ + fn vec_znx_add_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + unsafe { + vec_znx::vec_znx_add( + module.ptr() as *const module_info_t, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxAddScalarInplaceImpl for B +where + B: CPUAVX, +{ + fn vec_znx_add_scalar_inplace_impl( + module: &Module, + res: &mut R, + res_col: usize, + res_limb: usize, + a: &A, + a_col: usize, + ) where + R: VecZnxToMut, + A: ScalarZnxToRef, + { + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + let a: ScalarZnx<&[u8]> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + + unsafe { + vec_znx::vec_znx_add( + module.ptr() as *const module_info_t, + res.at_mut_ptr(res_col, res_limb), + 1 as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, res_limb), + 1 as u64, + res.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxSubImpl for B +where + B: CPUAVX, +{ + fn vec_znx_sub_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + C: VecZnxToRef, + { + let a: 
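// Editor's note: reference model of the base-2^k normalization that
// `vec_znx_normalize_base2k` performs per coefficient: propagate carries
// so every limb lands in [-2^(k-1), 2^(k-1)) without changing the value
// sum_j limb[j] * 2^(k*(size-1-j)), assuming limb 0 is most significant.
// A sketch of the arithmetic, not the spqlios kernel:
fn normalize_base2k_ref(basek: u32, limbs: &mut [i64]) {
    let base = 1i64 << basek;
    let half = base >> 1;
    let mut carry = 0i64;
    for x in limbs.iter_mut().rev() {
        // least-significant limb first
        let v = *x + carry;
        let mut r = v & (base - 1); // v mod 2^basek, in [0, 2^basek)
        if r >= half {
            r -= base; // center into [-2^(basek-1), 2^(basek-1))
        }
        carry = (v - r) >> basek;
        *x = r;
    }
    // a leftover carry would overflow the most significant limb; the real
    // routine's contract bounds the inputs so that this does not occur
}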
VecZnx<&[u8]> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(b.n(), module.n()); + assert_eq!(res.n(), module.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + unsafe { + vec_znx::vec_znx_sub( + module.ptr() as *const module_info_t, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(b_col, 0), + b.size() as u64, + b.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxSubABInplaceImpl for B +where + B: CPUAVX, +{ + fn vec_znx_sub_ab_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + unsafe { + vec_znx::vec_znx_sub( + module.ptr() as *const module_info_t, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxSubBAInplaceImpl for B +where + B: CPUAVX, +{ + fn vec_znx_sub_ba_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + unsafe { + vec_znx::vec_znx_sub( + module.ptr() as *const module_info_t, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxSubScalarInplaceImpl for B +where + B: CPUAVX, +{ + fn vec_znx_sub_scalar_inplace_impl( + module: &Module, + res: &mut R, + res_col: usize, + res_limb: usize, + a: &A, + a_col: usize, + ) where + R: VecZnxToMut, + A: ScalarZnxToRef, + { + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + let a: ScalarZnx<&[u8]> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + + unsafe { + vec_znx::vec_znx_sub( + module.ptr() as *const module_info_t, + res.at_mut_ptr(res_col, res_limb), + 1 as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, res_limb), + 1 as u64, + res.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxNegateImpl for B +where + B: CPUAVX, +{ + fn vec_znx_negate_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + unsafe { + vec_znx::vec_znx_negate( + module.ptr() as *const module_info_t, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxNegateInplaceImpl for B +where + B: CPUAVX, +{ + fn vec_znx_negate_inplace_impl(module: &Module, a: &mut A, a_col: usize) + where + A: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = a.to_mut(); + #[cfg(debug_assertions)] + { + 
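// Editor's note: the *_inplace variants above deliberately pass the same
// `res` pointer as both an input and the output of vec_znx_add /
// vec_znx_sub / vec_znx_negate. This relies on the spqlios kernels
// tolerating fully-aliased operands; the debug asserts only reject the
// distinct-operand case a == b (assert_ne!(a.as_ptr(), b.as_ptr())) in
// the three-operand versions.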
assert_eq!(a.n(), module.n()); + } + unsafe { + vec_znx::vec_znx_negate( + module.ptr() as *const module_info_t, + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxLshInplaceImpl for B +where + B: CPUAVX, +{ + fn vec_znx_lsh_inplace_impl(_module: &Module, basek: usize, k: usize, a: &mut A) + where + A: VecZnxToMut, + { + vec_znx_lsh_inplace_ref(basek, k, a) + } +} + +pub fn vec_znx_lsh_inplace_ref(basek: usize, k: usize, a: &mut A) +where + A: VecZnxToMut, +{ + let mut a: VecZnx<&mut [u8]> = a.to_mut(); + + let n: usize = a.n(); + let cols: usize = a.cols(); + let size: usize = a.size(); + let steps: usize = k / basek; + + a.raw_mut().rotate_left(n * steps * cols); + (0..cols).for_each(|i| { + (size - steps..size).for_each(|j| { + a.zero_at(i, j); + }) + }); + + let k_rem: usize = k % basek; + + if k_rem != 0 { + let shift: usize = i64::BITS as usize - k_rem; + (0..cols).for_each(|i| { + (0..steps).for_each(|j| { + a.at_mut(i, j).iter_mut().for_each(|xi| { + *xi <<= shift; + }); + }); + }); + } +} + +unsafe impl VecZnxRshInplaceImpl for B +where + B: CPUAVX, +{ + fn vec_znx_rsh_inplace_impl(_module: &Module, basek: usize, k: usize, a: &mut A) + where + A: VecZnxToMut, + { + vec_znx_rsh_inplace_ref(basek, k, a) + } +} + +pub fn vec_znx_rsh_inplace_ref(basek: usize, k: usize, a: &mut A) +where + A: VecZnxToMut, +{ + let mut a: VecZnx<&mut [u8]> = a.to_mut(); + let n: usize = a.n(); + let cols: usize = a.cols(); + let size: usize = a.size(); + let steps: usize = k / basek; + + a.raw_mut().rotate_right(n * steps * cols); + (0..cols).for_each(|i| { + (0..steps).for_each(|j| { + a.zero_at(i, j); + }) + }); + + let k_rem: usize = k % basek; + + if k_rem != 0 { + let mut carry: Vec = vec![0i64; n]; // ALLOC (but small so OK) + let shift: usize = i64::BITS as usize - k_rem; + (0..cols).for_each(|i| { + carry.fill(0); + (steps..size).for_each(|j| { + izip!(carry.iter_mut(), a.at_mut(i, j).iter_mut()).for_each(|(ci, xi)| { + *xi += *ci << basek; + *ci = (*xi << shift) >> shift; + *xi = (*xi - *ci) >> k_rem; + }); + }); + }) + } +} + +unsafe impl VecZnxRotateImpl for B +where + B: CPUAVX, +{ + fn vec_znx_rotate_impl(_module: &Module, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), a.n()); + } + unsafe { + (0..a.size()).for_each(|j| { + znx::znx_rotate_i64( + a.n() as u64, + k, + res.at_mut_ptr(res_col, j), + a.at_ptr(a_col, j), + ); + }); + } + } +} + +unsafe impl VecZnxRotateInplaceImpl for B +where + B: CPUAVX, +{ + fn vec_znx_rotate_inplace_impl(_module: &Module, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = a.to_mut(); + unsafe { + (0..a.size()).for_each(|j| { + znx::znx_rotate_inplace_i64(a.n() as u64, k, a.at_mut_ptr(a_col, j)); + }); + } + } +} + +unsafe impl VecZnxAutomorphismImpl for B +where + B: CPUAVX, +{ + fn vec_znx_automorphism_impl(module: &Module, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + unsafe { + vec_znx::vec_znx_automorphism( + module.ptr() as *const module_info_t, + k, + 
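// Editor's note: intuition for the limb-rotation trick in
// vec_znx_lsh_inplace_ref above. With limbs stored most-significant
// first, multiplying by 2^(basek * steps) simply moves every limb up
// `steps` slots, so a raw rotate_left plus zeroing the vacated
// least-significant limbs handles the whole-limb part with no
// per-coefficient arithmetic; only the k % basek remainder needs bit
// shifts. One-column sketch:
fn lsh_whole_limbs(limbs: &mut [i64], steps: usize) {
    limbs.rotate_left(steps); // limb j takes the value of old limb j + steps
    let size = limbs.len();
    limbs[size - steps..].fill(0); // vacated least-significant limbs
}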
res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxAutomorphismInplaceImpl for B +where + B: CPUAVX, +{ + fn vec_znx_automorphism_inplace_impl(module: &Module, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = a.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert!( + k & 1 != 0, + "invalid galois element: must be odd but is {}", + k + ); + } + unsafe { + vec_znx::vec_znx_automorphism( + module.ptr() as *const module_info_t, + k, + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxMulXpMinusOneImpl for B +where + B: CPUAVX, +{ + fn vec_znx_mul_xp_minus_one_impl(module: &Module, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + unsafe { + vec_znx::vec_znx_mul_xp_minus_one( + module.ptr() as *const module_info_t, + p, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxMulXpMinusOneInplaceImpl for B +where + B: CPUAVX, +{ + fn vec_znx_mul_xp_minus_one_inplace_impl(module: &Module, p: i64, res: &mut R, res_col: usize) + where + R: VecZnxToMut, + { + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), module.n()); + } + unsafe { + vec_znx::vec_znx_mul_xp_minus_one( + module.ptr() as *const module_info_t, + p, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxSplitImpl for B +where + B: CPUAVX, +{ + fn vec_znx_split_impl( + module: &Module, + res: &mut Vec, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_split_ref(module, res, res_col, a, a_col, scratch) + } +} + +pub fn vec_znx_split_ref( + module: &Module, + res: &mut Vec, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, +) where + B: CPUAVX, + R: VecZnxToMut, + A: VecZnxToRef, +{ + let a: VecZnx<&[u8]> = a.to_ref(); + + let (n_in, n_out) = (a.n(), res[0].to_mut().n()); + + let (mut buf, _) = scratch.take_vec_znx(module, 1, a.size()); + + debug_assert!( + n_out < n_in, + "invalid a: output ring degree should be smaller" + ); + res[1..].iter_mut().for_each(|bi| { + debug_assert_eq!( + bi.to_mut().n(), + n_out, + "invalid input a: all VecZnx must have the same degree" + ) + }); + + res.iter_mut().enumerate().for_each(|(i, bi)| { + if i == 0 { + module.vec_znx_switch_degree(bi, res_col, &a, a_col); + module.vec_znx_rotate(-1, &mut buf, 0, &a, a_col); + } else { + module.vec_znx_switch_degree(bi, res_col, &mut buf, a_col); + module.vec_znx_rotate_inplace(-1, &mut buf, a_col); + } + }) +} + +unsafe impl VecZnxMergeImpl for B +where + B: CPUAVX, +{ + fn vec_znx_merge_impl(module: &Module, res: &mut R, res_col: usize, a: Vec, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_merge_ref(module, res, res_col, a, a_col) + } +} + +pub fn vec_znx_merge_ref(module: &Module, res: &mut R, res_col: usize, a: Vec, a_col: usize) +where + 
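// Editor's note: why the in-place automorphism asserts that k is odd.
// X -> X^k permutes the exponents mod 2N only when gcd(k, 2N) = 1, and
// with N a power of two that is exactly "k is odd". A quick check of the
// injectivity argument:
fn is_exponent_permutation(k: i64, n: i64) -> bool {
    let two_n = 2 * n;
    let mut seen = vec![false; two_n as usize];
    for i in 0..two_n {
        let e = (((i * k) % two_n) + two_n) % two_n;
        if seen[e as usize] {
            return false; // collision: the map is not injective mod 2N
        }
        seen[e as usize] = true;
    }
    true
}
// is_exponent_permutation(3, 4) == true; is_exponent_permutation(2, 4) == false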
B: CPUAVX, + R: VecZnxToMut, + A: VecZnxToRef, +{ + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + let (n_in, n_out) = (res.n(), a[0].to_ref().n()); + + debug_assert!( + n_out < n_in, + "invalid a: output ring degree should be smaller" + ); + a[1..].iter().for_each(|ai| { + debug_assert_eq!( + ai.to_ref().n(), + n_out, + "invalid input a: all VecZnx must have the same degree" + ) + }); + + a.iter().enumerate().for_each(|(_, ai)| { + module.vec_znx_switch_degree(&mut res, res_col, ai, a_col); + module.vec_znx_rotate_inplace(-1, &mut res, res_col); + }); + + module.vec_znx_rotate_inplace(a.len() as i64, &mut res, res_col); +} + +unsafe impl VecZnxSwithcDegreeImpl for B +where + B: CPUAVX, +{ + fn vec_znx_switch_degree_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_switch_degree_ref(module, res, res_col, a, a_col) + } +} + +pub fn vec_znx_switch_degree_ref(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +where + B: CPUAVX, + R: VecZnxToMut, + A: VecZnxToRef, +{ + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + let (n_in, n_out) = (a.n(), res.n()); + + if n_in == n_out { + module.vec_znx_copy(&mut res, res_col, &a, a_col); + return; + } + + let (gap_in, gap_out): (usize, usize); + if n_in > n_out { + (gap_in, gap_out) = (n_in / n_out, 1) + } else { + (gap_in, gap_out) = (1, n_out / n_in); + res.zero(); + } + + let size: usize = a.size().min(res.size()); + + (0..size).for_each(|i| { + izip!( + a.at(a_col, i).iter().step_by(gap_in), + res.at_mut(res_col, i).iter_mut().step_by(gap_out) + ) + .for_each(|(x_in, x_out)| *x_out = *x_in); + }); +} + +unsafe impl VecZnxCopyImpl for B +where + B: CPUAVX, +{ + fn vec_znx_copy_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_copy_ref(res, res_col, a, a_col) + } +} + +pub fn vec_znx_copy_ref(res: &mut R, res_col: usize, a: &A, a_col: usize) +where + R: VecZnxToMut, + A: VecZnxToRef, +{ + let mut res_mut: VecZnx<&mut [u8]> = res.to_mut(); + let a_ref: VecZnx<&[u8]> = a.to_ref(); + + let min_size: usize = res_mut.size().min(a_ref.size()); + + (0..min_size).for_each(|j| { + res_mut + .at_mut(res_col, j) + .copy_from_slice(a_ref.at(a_col, j)); + }); + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); + }) +} + +unsafe impl VecZnxStdImpl for B +where + B: CPUAVX, +{ + fn vec_znx_std_impl(module: &Module, basek: usize, a: &A, a_col: usize) -> f64 + where + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let prec: u32 = (a.size() * basek) as u32; + let mut data: Vec = (0..a.n()).map(|_| Float::with_val(prec, 0)).collect(); + module.decode_vec_float(basek, &a, a_col, &mut data); + // std = sqrt(sum((xi - avg)^2) / n) + let mut avg: Float = Float::with_val(prec, 0); + data.iter().for_each(|x| { + avg.add_assign_round(x, Round::Nearest); + }); + avg.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest); + data.iter_mut().for_each(|x| { + x.sub_assign_round(&avg, Round::Nearest); + }); + let mut std: Float = Float::with_val(prec, 0); + data.iter().for_each(|x| std += x * x); + std.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest); + std = std.sqrt(); + std.to_f64() + } +} + +unsafe impl VecZnxFillUniformImpl for B +where + B: CPUAVX, +{ + fn vec_znx_fill_uniform_impl(_module: &Module, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source) + where + R: 
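// Editor's note: toy model of `vec_znx_switch_degree_ref` above for one
// limb, assuming n_in and n_out are powers of two so one divides the
// other. Going down in degree keeps every (n_in / n_out)-th coefficient;
// going up spreads the input with stride n_out / n_in and zero-fills the
// rest, matching the subring embedding Y = X^(n_out / n_in):
fn switch_degree_ref(a: &[i64], n_out: usize) -> Vec<i64> {
    let n_in = a.len();
    let mut out = vec![0i64; n_out];
    if n_in >= n_out {
        let gap = n_in / n_out;
        for (o, &x) in out.iter_mut().zip(a.iter().step_by(gap)) {
            *o = x;
        }
    } else {
        let gap = n_out / n_in;
        for (i, &x) in a.iter().enumerate() {
            out[i * gap] = x;
        }
    }
    out
}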
VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = res.to_mut(); + let base2k: u64 = 1 << basek; + let mask: u64 = base2k - 1; + let base2k_half: i64 = (base2k >> 1) as i64; + (0..k.div_ceil(basek)).for_each(|j| { + a.at_mut(res_col, j) + .iter_mut() + .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); + }) + } +} + +unsafe impl VecZnxFillDistF64Impl for B +where + B: CPUAVX, +{ + fn vec_znx_fill_dist_f64_impl>( + _module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) where + R: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = res.to_mut(); + assert!( + (bound.log2().ceil() as i64) < 64, + "invalid bound: ceil(log2(bound))={} > 63", + (bound.log2().ceil() as i64) + ); + + let limb: usize = k.div_ceil(basek) - 1; + let basek_rem: usize = (limb + 1) * basek - k; + + if basek_rem != 0 { + a.at_mut(res_col, limb).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a = (dist_f64.round() as i64) << basek_rem; + }); + } else { + a.at_mut(res_col, limb).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a = dist_f64.round() as i64 + }); + } + } +} + +unsafe impl VecZnxAddDistF64Impl for B +where + B: CPUAVX, +{ + fn vec_znx_add_dist_f64_impl>( + _module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) where + R: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = res.to_mut(); + assert!( + (bound.log2().ceil() as i64) < 64, + "invalid bound: ceil(log2(bound))={} > 63", + (bound.log2().ceil() as i64) + ); + + let limb: usize = k.div_ceil(basek) - 1; + let basek_rem: usize = (limb + 1) * basek - k; + + if basek_rem != 0 { + a.at_mut(res_col, limb).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a += (dist_f64.round() as i64) << basek_rem; + }); + } else { + a.at_mut(res_col, limb).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a += dist_f64.round() as i64 + }); + } + } +} + +unsafe impl VecZnxFillNormalImpl for B +where + B: CPUAVX, +{ + fn vec_znx_fill_normal_impl( + module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) where + R: VecZnxToMut, + { + module.vec_znx_fill_dist_f64( + basek, + res, + res_col, + k, + source, + Normal::new(0.0, sigma).unwrap(), + bound, + ); + } +} + +unsafe impl VecZnxAddNormalImpl for B +where + B: CPUAVX, +{ + fn vec_znx_add_normal_impl( + module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) where + R: VecZnxToMut, + { + module.vec_znx_add_dist_f64( + basek, + res, + res_col, + k, + source, + Normal::new(0.0, sigma).unwrap(), + bound, + ); + } +} + +unsafe impl VecZnxEncodeVeci64Impl for B +where + B: CPUAVX, +{ + fn encode_vec_i64_impl( + _module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + data: &[i64], + log_max: usize, + ) where + R: VecZnxToMut, + { + let size: usize = k.div_ceil(basek); + + #[cfg(debug_assertions)] + { + let a: VecZnx<&mut [u8]> = res.to_mut(); + assert!( + size <= a.size(), + "invalid argument k: 
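// Editor's note: the fill/add "dist" impls above all share the same
// bounded rejection loop, sketched here with a closure standing in for
// the rand_distr sampler (`sample` and `bound` are the assumed inputs):
fn sample_bounded(mut sample: impl FnMut() -> f64, bound: f64) -> i64 {
    let mut x = sample();
    while x.abs() > bound {
        x = sample(); // reject tail samples outside [-bound, bound]
    }
    x.round() as i64
}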
k.div_ceil(basek)={} > a.size()={}", + size, + a.size() + ); + assert!(res_col < a.cols()); + assert!(data.len() <= a.n()) + } + + let data_len: usize = data.len(); + let mut a: VecZnx<&mut [u8]> = res.to_mut(); + let k_rem: usize = basek - (k % basek); + + // Zeroes coefficients of the i-th column + (0..a.size()).for_each(|i| { + a.zero_at(res_col, i); + }); + + // If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy + // values on the last limb. + // Else we decompose values base2k. + if log_max + k_rem < 63 || k_rem == basek { + a.at_mut(res_col, size - 1)[..data_len].copy_from_slice(&data[..data_len]); + } else { + let mask: i64 = (1 << basek) - 1; + let steps: usize = size.min(log_max.div_ceil(basek)); + (size - steps..size) + .rev() + .enumerate() + .for_each(|(i, i_rev)| { + let shift: usize = i * basek; + izip!(a.at_mut(res_col, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask); + }) + } + + // Case where self.prec % self.k != 0. + if k_rem != basek { + let steps: usize = size.min(log_max.div_ceil(basek)); + (size - steps..size).rev().for_each(|i| { + a.at_mut(res_col, i)[..data_len] + .iter_mut() + .for_each(|x| *x <<= k_rem); + }) + } + } +} + +unsafe impl VecZnxEncodeCoeffsi64Impl for B +where + B: CPUAVX, +{ + fn encode_coeff_i64_impl( + _module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + i: usize, + data: i64, + log_max: usize, + ) where + R: VecZnxToMut, + { + let size: usize = k.div_ceil(basek); + + #[cfg(debug_assertions)] + { + let a: VecZnx<&mut [u8]> = res.to_mut(); + assert!(i < a.n()); + assert!( + size <= a.size(), + "invalid argument k: k.div_ceil(basek)={} > a.size()={}", + size, + a.size() + ); + assert!(res_col < a.cols()); + } + + let k_rem: usize = basek - (k % basek); + let mut a: VecZnx<&mut [u8]> = res.to_mut(); + (0..a.size()).for_each(|j| a.at_mut(res_col, j)[i] = 0); + + // If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy + // values on the last limb. + // Else we decompose values base2k. + if log_max + k_rem < 63 || k_rem == basek { + a.at_mut(res_col, size - 1)[i] = data; + } else { + let mask: i64 = (1 << basek) - 1; + let steps: usize = size.min(log_max.div_ceil(basek)); + (size - steps..size) + .rev() + .enumerate() + .for_each(|(j, j_rev)| { + a.at_mut(res_col, j_rev)[i] = (data >> (j * basek)) & mask; + }) + } + + // Case where prec % k != 0. 
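// Editor's note: reference for the "decompose values base2k" branch of
// the encoders: split a signed value into basek-bit digits, most
// significant limb first, so that value == sum_j digit[j] << (basek *
// (size - 1 - j)). Digits here are the raw masked ones, before any final
// normalization:
fn decompose_base2k(value: i64, basek: usize, size: usize) -> Vec<i64> {
    let mask = (1i64 << basek) - 1;
    (0..size)
        .rev() // limb 0 is most significant
        .map(|j| (value >> (j * basek)) & mask)
        .collect()
}
// decompose_base2k(86, 3, 3) == [0b1, 0b010, 0b110]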
+ if k_rem != basek { + let steps: usize = size.min(log_max.div_ceil(basek)); + (size - steps..size).rev().for_each(|j| { + a.at_mut(res_col, j)[i] <<= k_rem; + }) + } + } +} + +unsafe impl VecZnxDecodeVeci64Impl for B +where + B: CPUAVX, +{ + fn decode_vec_i64_impl(_module: &Module, basek: usize, res: &R, res_col: usize, k: usize, data: &mut [i64]) + where + R: VecZnxToRef, + { + let size: usize = k.div_ceil(basek); + #[cfg(debug_assertions)] + { + let a: VecZnx<&[u8]> = res.to_ref(); + assert!( + data.len() >= a.n(), + "invalid data: data.len()={} < a.n()={}", + data.len(), + a.n() + ); + assert!(res_col < a.cols()); + } + + let a: VecZnx<&[u8]> = res.to_ref(); + data.copy_from_slice(a.at(res_col, 0)); + let rem: usize = basek - (k % basek); + if k < basek { + data.iter_mut().for_each(|x| *x >>= rem); + } else { + (1..size).for_each(|i| { + if i == size - 1 && rem != basek { + let k_rem: usize = basek - rem; + izip!(a.at(res_col, i).iter(), data.iter_mut()).for_each(|(x, y)| { + *y = (*y << k_rem) + (x >> rem); + }); + } else { + izip!(a.at(res_col, i).iter(), data.iter_mut()).for_each(|(x, y)| { + *y = (*y << basek) + x; + }); + } + }) + } + } +} + +unsafe impl VecZnxDecodeCoeffsi64Impl for B +where + B: CPUAVX, +{ + fn decode_coeff_i64_impl(_module: &Module, basek: usize, res: &R, res_col: usize, k: usize, i: usize) -> i64 + where + R: VecZnxToRef, + { + #[cfg(debug_assertions)] + { + let a: VecZnx<&[u8]> = res.to_ref(); + assert!(i < a.n()); + assert!(res_col < a.cols()) + } + + let a: VecZnx<&[u8]> = res.to_ref(); + let size: usize = k.div_ceil(basek); + let mut res: i64 = 0; + let rem: usize = basek - (k % basek); + (0..size).for_each(|j| { + let x: i64 = a.at(res_col, j)[i]; + if j == size - 1 && rem != basek { + let k_rem: usize = basek - rem; + res = (res << k_rem) + (x >> rem); + } else { + res = (res << basek) + x; + } + }); + res + } +} + +unsafe impl VecZnxDecodeVecFloatImpl for B +where + B: CPUAVX, +{ + fn decode_vec_float_impl(_module: &Module, basek: usize, res: &R, res_col: usize, data: &mut [Float]) + where + R: VecZnxToRef, + { + #[cfg(debug_assertions)] + { + let a: VecZnx<&[u8]> = res.to_ref(); + assert!( + data.len() >= a.n(), + "invalid data: data.len()={} < a.n()={}", + data.len(), + a.n() + ); + assert!(res_col < a.cols()); + } + + let a: VecZnx<&[u8]> = res.to_ref(); + let size: usize = a.size(); + let prec: u32 = (basek * size) as u32; + + // 2^{basek} + let base = Float::with_val(prec, (1 << basek) as f64); + + // y[i] = sum x[j][i] * 2^{-basek*j} + (0..size).for_each(|i| { + if i == 0 { + izip!(a.at(res_col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + y.assign(*x); + *y /= &base; + }); + } else { + izip!(a.at(res_col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + *y += Float::with_val(prec, *x); + *y /= &base; + }); + } + }); + } +} diff --git a/backend/src/implementation/cpu_spqlios/vec_znx_big_fft64.rs b/backend/src/implementation/cpu_spqlios/vec_znx_big_fft64.rs new file mode 100644 index 0000000..86fe6cc --- /dev/null +++ b/backend/src/implementation/cpu_spqlios/vec_znx_big_fft64.rs @@ -0,0 +1,758 @@ +use std::fmt; + +use rand_distr::{Distribution, Normal}; +use sampling::source::Source; + +use crate::{ + hal::{ + api::{ + TakeSlice, VecZnxBigAddDistF64, VecZnxBigFillDistF64, VecZnxBigNormalizeTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView, + ZnxViewMut, + }, + layouts::{ + Data, DataRef, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigBytesOf, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, + VecZnxToMut, VecZnxToRef, + }, + 
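// Editor's note: the decoders invert the encoding by Horner
// recomposition: start from the most significant limb and fold in basek
// bits at a time. A minimal model that ignores the partial-last-limb
// (k % basek != 0) case handled above:
fn recompose_base2k(limbs: &[i64], basek: usize) -> i64 {
    limbs.iter().fold(0i64, |acc, &x| (acc << basek) + x)
}
// recompose_base2k(&[1, 2, 6], 3) == 86, inverting the decomposition sketch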
oep::{ + VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl, + VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl, + VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl, VecZnxBigFromBytesImpl, + VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, + VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, + VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl, + }, + }, + implementation::cpu_spqlios::{ffi::vec_znx, module_fft64::FFT64}, +}; + +const VEC_ZNX_BIG_FFT64_WORDSIZE: usize = 1; + +impl ZnxView for VecZnxBig { + type Scalar = i64; +} + +impl VecZnxBigBytesOf for VecZnxBig { + fn bytes_of(n: usize, cols: usize, size: usize) -> usize { + VEC_ZNX_BIG_FFT64_WORDSIZE * n * cols * size * size_of::() + } +} + +impl ZnxSliceSize for VecZnxBig { + fn sl(&self) -> usize { + VEC_ZNX_BIG_FFT64_WORDSIZE * self.n() * self.cols() + } +} + +unsafe impl VecZnxBigAllocImpl for FFT64 { + fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned { + VecZnxBig::, FFT64>::new(n, cols, size) + } +} + +unsafe impl VecZnxBigFromBytesImpl for FFT64 { + fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned { + VecZnxBig::, FFT64>::new_from_bytes(n, cols, size, bytes) + } +} + +unsafe impl VecZnxBigAllocBytesImpl for FFT64 { + fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + VecZnxBig::, FFT64>::bytes_of(n, cols, size) + } +} + +unsafe impl VecZnxBigAddDistF64Impl for FFT64 { + fn add_dist_f64_impl, D: Distribution>( + _module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) { + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + assert!( + (bound.log2().ceil() as i64) < 64, + "invalid bound: ceil(log2(bound))={} > 63", + (bound.log2().ceil() as i64) + ); + + let limb: usize = k.div_ceil(basek) - 1; + let basek_rem: usize = (limb + 1) * basek - k; + + if basek_rem != 0 { + res.at_mut(res_col, limb).iter_mut().for_each(|x| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *x += (dist_f64.round() as i64) << basek_rem; + }); + } else { + res.at_mut(res_col, limb).iter_mut().for_each(|x| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *x += dist_f64.round() as i64 + }); + } + } +} + +unsafe impl VecZnxBigAddNormalImpl for FFT64 { + fn add_normal_impl>( + module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) { + module.vec_znx_big_add_dist_f64( + basek, + res, + res_col, + k, + source, + Normal::new(0.0, sigma).unwrap(), + bound, + ); + } +} + +unsafe impl VecZnxBigFillDistF64Impl for FFT64 { + fn fill_dist_f64_impl, D: Distribution>( + _module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) { + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + assert!( + (bound.log2().ceil() as i64) < 64, + "invalid bound: ceil(log2(bound))={} > 63", + (bound.log2().ceil() as i64) + ); + + let limb: usize = k.div_ceil(basek) - 1; + let basek_rem: usize = (limb + 1) * basek - k; + + if basek_rem != 0 { + 
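// Editor's note: all bytes_of impls share the same layout arithmetic:
// n * cols * size coefficients, each spanning WORDSIZE machine scalars.
// For FFT64 (WORDSIZE = 1, f64 scalars), n = 4096, cols = 2, size = 3
// gives 1 * 4096 * 2 * 3 * 8 = 196_608 bytes; NTT120 spends 4 words per
// coefficient instead.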
res.at_mut(res_col, limb).iter_mut().for_each(|x| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *x = (dist_f64.round() as i64) << basek_rem; + }); + } else { + res.at_mut(res_col, limb).iter_mut().for_each(|x| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *x = dist_f64.round() as i64 + }); + } + } +} + +unsafe impl VecZnxBigFillNormalImpl for FFT64 { + fn fill_normal_impl>( + module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) { + module.vec_znx_big_fill_dist_f64( + basek, + res, + res_col, + k, + source, + Normal::new(0.0, sigma).unwrap(), + bound, + ); + } +} + +unsafe impl VecZnxBigAddImpl for FFT64 { + /// Adds `a` to `b` and stores the result on `c`. + fn vec_znx_big_add_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + ) where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let b: VecZnxBig<&[u8], FFT64> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(b.n(), module.n()); + assert_eq!(res.n(), module.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + unsafe { + vec_znx::vec_znx_add( + module.ptr(), + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(b_col, 0), + b.size() as u64, + b.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxBigAddInplaceImpl for FFT64 { + /// Adds `a` to `b` and stores the result on `b`. + fn vec_znx_big_add_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + unsafe { + vec_znx::vec_znx_add( + module.ptr(), + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxBigAddSmallImpl for FFT64 { + /// Adds `a` to `b` and stores the result on `c`. + fn vec_znx_big_add_small_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + ) where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(b.n(), module.n()); + assert_eq!(res.n(), module.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + unsafe { + vec_znx::vec_znx_add( + module.ptr(), + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(b_col, 0), + b.size() as u64, + b.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxBigAddSmallInplaceImpl for FFT64 { + /// Adds `a` to `b` and stores the result on `b`. 
+ fn vec_znx_big_add_small_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + unsafe { + vec_znx::vec_znx_add( + module.ptr(), + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxBigSubImpl for FFT64 { + /// Subtracts `a` to `b` and stores the result on `c`. + fn vec_znx_big_sub_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + ) where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let b: VecZnxBig<&[u8], FFT64> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(b.n(), module.n()); + assert_eq!(res.n(), module.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + unsafe { + vec_znx::vec_znx_sub( + module.ptr(), + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(b_col, 0), + b.size() as u64, + b.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxBigSubABInplaceImpl for FFT64 { + /// Subtracts `a` from `b` and stores the result on `b`. + fn vec_znx_big_sub_ab_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + unsafe { + vec_znx::vec_znx_sub( + module.ptr(), + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxBigSubBAInplaceImpl for FFT64 { + /// Subtracts `b` from `a` and stores the result on `b`. + fn vec_znx_big_sub_ba_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + unsafe { + vec_znx::vec_znx_sub( + module.ptr(), + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxBigSubSmallAImpl for FFT64 { + /// Subtracts `b` from `a` and stores the result on `c`. 
+ fn vec_znx_big_sub_small_a_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + ) where + R: VecZnxBigToMut, + A: VecZnxToRef, + B: VecZnxBigToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let b: VecZnxBig<&[u8], FFT64> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(b.n(), module.n()); + assert_eq!(res.n(), module.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + unsafe { + vec_znx::vec_znx_sub( + module.ptr(), + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(b_col, 0), + b.size() as u64, + b.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxBigSubSmallAInplaceImpl for FFT64 { + /// Subtracts `a` from `res` and stores the result on `res`. + fn vec_znx_big_sub_small_a_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + unsafe { + vec_znx::vec_znx_sub( + module.ptr(), + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxBigSubSmallBImpl for FFT64 { + /// Subtracts `b` from `a` and stores the result on `c`. + fn vec_znx_big_sub_small_b_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + ) where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(b.n(), module.n()); + assert_eq!(res.n(), module.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + unsafe { + vec_znx::vec_znx_sub( + module.ptr(), + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(b_col, 0), + b.size() as u64, + b.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxBigSubSmallBInplaceImpl for FFT64 { + /// Subtracts `res` from `a` and stores the result on `res`. 
+ fn vec_znx_big_sub_small_b_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + unsafe { + vec_znx::vec_znx_sub( + module.ptr(), + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxBigNegateInplaceImpl for FFT64 { + fn vec_znx_big_negate_inplace_impl(module: &Module, a: &mut A, a_col: usize) + where + A: VecZnxBigToMut, + { + let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + } + unsafe { + vec_znx::vec_znx_negate( + module.ptr(), + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxBigNormalizeTmpBytesImpl for FFT64 { + fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module, n: usize) -> usize { + unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr(), n as u64) as usize } + } +} + +unsafe impl VecZnxBigNormalizeImpl for FFT64 { + fn vec_znx_big_normalize_impl( + module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), a.n()); + } + + let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes(a.n())); + unsafe { + vec_znx::vec_znx_normalize_base2k( + module.ptr(), + a.n() as u64, + basek as u64, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + tmp_bytes.as_mut_ptr(), + ); + } + } +} + +unsafe impl VecZnxBigAutomorphismImpl for FFT64 { + /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. + fn vec_znx_big_automorphism_impl(module: &Module, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + unsafe { + vec_znx::vec_znx_automorphism( + module.ptr(), + k, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxBigAutomorphismInplaceImpl for FFT64 { + /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. 
+ fn vec_znx_big_automorphism_inplace_impl(module: &Module, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxBigToMut, + { + let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + } + unsafe { + vec_znx::vec_znx_automorphism( + module.ptr(), + k, + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } +} + +impl fmt::Display for VecZnxBig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "VecZnxBig(n={}, cols={}, size={})", + self.n, self.cols, self.size + )?; + + for col in 0..self.cols { + writeln!(f, "Column {}:", col)?; + for size in 0..self.size { + let coeffs = self.at(col, size); + write!(f, " Size {}: [", size)?; + + let max_show = 100; + let show_count = coeffs.len().min(max_show); + + for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", coeff)?; + } + + if coeffs.len() > max_show { + write!(f, ", ... ({} more)", coeffs.len() - max_show)?; + } + + writeln!(f, "]")?; + } + } + Ok(()) + } +} diff --git a/backend/src/implementation/cpu_spqlios/vec_znx_big_ntt120.rs b/backend/src/implementation/cpu_spqlios/vec_znx_big_ntt120.rs new file mode 100644 index 0000000..42e632a --- /dev/null +++ b/backend/src/implementation/cpu_spqlios/vec_znx_big_ntt120.rs @@ -0,0 +1,32 @@ +use crate::{ + hal::{ + api::{ZnxInfos, ZnxSliceSize, ZnxView}, + layouts::{Data, DataRef, VecZnxBig, VecZnxBigBytesOf}, + oep::VecZnxBigAllocBytesImpl, + }, + implementation::cpu_spqlios::module_ntt120::NTT120, +}; + +const VEC_ZNX_BIG_NTT120_WORDSIZE: usize = 4; + +impl ZnxView for VecZnxBig { + type Scalar = i128; +} + +impl VecZnxBigBytesOf for VecZnxBig { + fn bytes_of(n: usize, cols: usize, size: usize) -> usize { + VEC_ZNX_BIG_NTT120_WORDSIZE * n * cols * size * size_of::() + } +} + +impl ZnxSliceSize for VecZnxBig { + fn sl(&self) -> usize { + VEC_ZNX_BIG_NTT120_WORDSIZE * self.n() * self.cols() + } +} + +unsafe impl VecZnxBigAllocBytesImpl for NTT120 { + fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + VecZnxBig::, NTT120>::bytes_of(n, cols, size) + } +} diff --git a/backend/src/vec_znx_dft_ops.rs b/backend/src/implementation/cpu_spqlios/vec_znx_dft_fft64.rs similarity index 54% rename from backend/src/vec_znx_dft_ops.rs rename to backend/src/implementation/cpu_spqlios/vec_znx_dft_fft64.rs index 5892155..2d9559b 100644 --- a/backend/src/vec_znx_dft_ops.rs +++ b/backend/src/implementation/cpu_spqlios/vec_znx_dft_fft64.rs @@ -1,375 +1,90 @@ -use crate::ffi::{vec_znx_big, vec_znx_dft}; -use crate::vec_znx_dft::bytes_of_vec_znx_dft; -use crate::znx_base::ZnxInfos; +use std::fmt; + use crate::{ - Backend, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, - ZnxSliceSize, + hal::{ + api::{TakeSlice, VecZnxDftToVecZnxBigTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, + layouts::{ + Data, DataRef, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftBytesOf, VecZnxDftOwned, + VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, + }, + oep::{ + VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftCopyImpl, + VecZnxDftFromBytesImpl, VecZnxDftFromVecZnxImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, + VecZnxDftSubImpl, VecZnxDftToVecZnxBigConsumeImpl, VecZnxDftToVecZnxBigImpl, VecZnxDftToVecZnxBigTmpAImpl, + 
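// Editor's note: hypothetical output shape of the Display impl above for
// a tiny VecZnxBig (values illustrative; limbs longer than 100
// coefficients are elided with "... (m more)"):
//
//     VecZnxBig(n=4, cols=1, size=2)
//     Column 0:
//       Size 0: [1, 2, 3, 4]
//       Size 1: [0, 0, 0, 0]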
VecZnxDftToVecZnxBigTmpBytesImpl, VecZnxDftZeroImpl, + }, + }, + implementation::cpu_spqlios::{ + ffi::{vec_znx_big, vec_znx_dft}, + module_fft64::FFT64, + }, }; -use crate::{FFT64, Module, ZnxView, ZnxViewMut, ZnxZero}; -use std::cmp::min; -pub trait VecZnxDftAlloc { - /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. - fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned; +const VEC_ZNX_DFT_FFT64_WORDSIZE: usize = 1; - /// Returns a new [VecZnxDft] with the provided bytes array as backing array. - /// - /// Behavior: takes ownership of the backing array. - /// - /// # Arguments - /// - /// * `cols`: the number of cols of the [VecZnxDft]. - /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. - /// - /// # Panics - /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned; - - /// Returns a new [VecZnxDft] with the provided bytes array as backing array. - /// - /// # Arguments - /// - /// * `cols`: the number of cols of the [VecZnxDft]. - /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. - /// - /// # Panics - /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize; -} - -pub trait VecZnxDftOps { - /// Returns the minimum number of bytes necessary to allocate - /// a new [VecZnxDft] through [VecZnxDft::from_bytes]. - fn vec_znx_idft_tmp_bytes(&self) -> usize; - - fn vec_znx_dft_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) - where - R: VecZnxDftToMut, - A: VecZnxDftToRef, - D: VecZnxDftToRef; - - fn vec_znx_dft_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxDftToMut, - A: VecZnxDftToRef; - - fn vec_znx_dft_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) - where - R: VecZnxDftToMut, - A: VecZnxDftToRef, - D: VecZnxDftToRef; - - fn vec_znx_dft_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxDftToMut, - A: VecZnxDftToRef; - - fn vec_znx_dft_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxDftToMut, - A: VecZnxDftToRef; - - fn vec_znx_dft_copy(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxDftToMut, - A: VecZnxDftToRef; - - /// b <- IDFT(a), uses a as scratch space. - fn vec_znx_idft_tmp_a(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxDftToMut; - - /// Consumes a to return IDFT(a) in big coeff space. 
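A note for orientation while reading this removed trait (its functionality reappears as the backend `*Impl` traits below): the reason vectors are kept in the DFT domain at all is that the negacyclic product of Z[X]/(X^N + 1) becomes coefficient-wise there. A schoolbook reference of that product, written as a hypothetical stand-alone helper rather than anything the crate exposes:

```rust
/// Schoolbook negacyclic product in Z[X]/(X^N + 1). The DFT/IDFT pair
/// replaces this O(N^2) loop with two transforms plus a pointwise multiply.
fn negacyclic_mul_ref(a: &[i64], b: &[i64]) -> Vec<i64> {
    let n = a.len();
    assert_eq!(n, b.len());
    let mut res = vec![0i64; n];
    for i in 0..n {
        for j in 0..n {
            if i + j < n {
                res[i + j] += a[i] * b[j];
            } else {
                res[i + j - n] -= a[i] * b[j]; // wrap-around picks up X^N = -1
            }
        }
    }
    res
}
```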
- fn vec_znx_idft_consume(&self, a: VecZnxDft) -> VecZnxBig - where - VecZnxDft: VecZnxDftToMut; - - fn vec_znx_idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) - where - R: VecZnxBigToMut, - A: VecZnxDftToRef; - - fn vec_znx_dft(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxDftToMut, - A: VecZnxToRef; -} - -impl VecZnxDftAlloc for Module { - fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned { - VecZnxDftOwned::new(&self, cols, size) - } - - fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned { - VecZnxDftOwned::new_from_bytes(self, cols, size, bytes) - } - - fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize { - bytes_of_vec_znx_dft(self, cols, size) +impl ZnxSliceSize for VecZnxDft { + fn sl(&self) -> usize { + VEC_ZNX_DFT_FFT64_WORDSIZE * self.n() * self.cols() } } -impl VecZnxDftOps for Module { - fn vec_znx_dft_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) - where - R: VecZnxDftToMut, - A: VecZnxDftToRef, - D: VecZnxDftToRef, - { - let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); - let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref(); - - let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size()); - - unsafe { - (0..min_size).for_each(|j| { - vec_znx_dft::vec_dft_add( - self.ptr, - res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, - 1, - a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, - 1, - b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t, - 1, - ); - }); - } - (min_size..res_mut.size()).for_each(|j| { - res_mut.zero_at(res_col, j); - }) +impl VecZnxDftBytesOf for VecZnxDft { + fn bytes_of(n: usize, cols: usize, size: usize) -> usize { + VEC_ZNX_DFT_FFT64_WORDSIZE * n * cols * size * size_of::() } +} - fn vec_znx_dft_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxDftToMut, - A: VecZnxDftToRef, - { - let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); +impl ZnxView for VecZnxDft { + type Scalar = f64; +} - let min_size: usize = res_mut.size().min(a_ref.size()); - - unsafe { - (0..min_size).for_each(|j| { - vec_znx_dft::vec_dft_add( - self.ptr, - res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, - 1, - res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t, - 1, - a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, - 1, - ); - }); - } +unsafe impl VecZnxDftFromBytesImpl for FFT64 { + fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned { + VecZnxDft::, FFT64>::from_bytes(n, cols, size, bytes) } +} - fn vec_znx_dft_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) - where - R: VecZnxDftToMut, - A: VecZnxDftToRef, - D: VecZnxDftToRef, - { - let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); - let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref(); - - let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size()); - - unsafe { - (0..min_size).for_each(|j| { - vec_znx_dft::vec_dft_sub( - self.ptr, - res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, - 1, - a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, - 1, - b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t, - 1, - 
); - }); - } - (min_size..res_mut.size()).for_each(|j| { - res_mut.zero_at(res_col, j); - }) +unsafe impl VecZnxDftAllocBytesImpl for FFT64 { + fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + VecZnxDft::, FFT64>::bytes_of(n, cols, size) } +} - fn vec_znx_dft_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxDftToMut, - A: VecZnxDftToRef, - { - let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); - - let min_size: usize = res_mut.size().min(a_ref.size()); - - unsafe { - (0..min_size).for_each(|j| { - vec_znx_dft::vec_dft_sub( - self.ptr, - res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, - 1, - res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t, - 1, - a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, - 1, - ); - }); - } +unsafe impl VecZnxDftAllocImpl for FFT64 { + fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned { + VecZnxDftOwned::alloc(n, cols, size) } +} - fn vec_znx_dft_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxDftToMut, - A: VecZnxDftToRef, - { - let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); - - let min_size: usize = res_mut.size().min(a_ref.size()); - - unsafe { - (0..min_size).for_each(|j| { - vec_znx_dft::vec_dft_sub( - self.ptr, - res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, - 1, - a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, - 1, - res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t, - 1, - ); - }); - } +unsafe impl VecZnxDftToVecZnxBigTmpBytesImpl for FFT64 { + fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module) -> usize { + unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(module.ptr()) as usize } } +} - fn vec_znx_dft_copy(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxDftToMut, - A: VecZnxDftToRef, - { - let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); - - let steps: usize = (a_ref.size() + step - 1) / step; - let min_steps: usize = min(res_mut.size(), steps); - - (0..min_steps).for_each(|j| { - let limb: usize = offset + j * step; - if limb < a_ref.size() { - res_mut - .at_mut(res_col, j) - .copy_from_slice(a_ref.at(a_col, limb)); - } - }); - (min_steps..res_mut.size()).for_each(|j| { - res_mut.zero_at(res_col, j); - }) - } - - fn vec_znx_idft_tmp_a(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxDftToMut, - { - let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); - let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut(); - - let min_size: usize = min(res_mut.size(), a_mut.size()); - - unsafe { - (0..min_size).for_each(|j| { - vec_znx_dft::vec_znx_idft_tmp_a( - self.ptr, - res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t, - 1 as u64, - a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t, - 1 as u64, - ) - }); - (min_size..res_mut.size()).for_each(|j| { - res_mut.zero_at(res_col, j); - }) - } - } - - fn vec_znx_idft_consume(&self, mut a: VecZnxDft) -> VecZnxBig - where - VecZnxDft: VecZnxDftToMut, - { - let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut(); - - unsafe { - // Rev col and rows because ZnxDft.sl() >= ZnxBig.sl() - (0..a_mut.size()).for_each(|j| { - (0..a_mut.cols()).for_each(|i| { - 
vec_znx_dft::vec_znx_idft_tmp_a( - self.ptr, - a_mut.at_mut_ptr(i, j) as *mut vec_znx_big::vec_znx_big_t, - 1 as u64, - a_mut.at_mut_ptr(i, j) as *mut vec_znx_dft::vec_znx_dft_t, - 1 as u64, - ) - }); - }); - } - - a.into_big() - } - - fn vec_znx_idft_tmp_bytes(&self) -> usize { - unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize } - } - - fn vec_znx_dft(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxDftToMut, - A: VecZnxToRef, - { - let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a_ref: crate::VecZnx<&[u8]> = a.to_ref(); - let steps: usize = (a_ref.size() + step - 1) / step; - let min_steps: usize = min(res_mut.size(), steps); - unsafe { - (0..min_steps).for_each(|j| { - let limb: usize = offset + j * step; - if limb < a_ref.size() { - vec_znx_dft::vec_znx_dft( - self.ptr, - res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, - 1 as u64, - a_ref.at_ptr(a_col, limb), - 1 as u64, - a_ref.sl() as u64, - ) - } - }); - (min_steps..res_mut.size()).for_each(|j| { - res_mut.zero_at(res_col, j); - }); - } - } - - // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes]. - fn vec_znx_idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) - where +unsafe impl VecZnxDftToVecZnxBigImpl for FFT64 { + fn vec_znx_dft_to_vec_znx_big_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where R: VecZnxBigToMut, A: VecZnxDftToRef, { let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); - let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_idft_tmp_bytes()); + let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_dft_to_vec_znx_big_tmp_bytes()); - let min_size: usize = min(res_mut.size(), a_ref.size()); + let min_size: usize = res_mut.size().min(a_ref.size()); unsafe { (0..min_size).for_each(|j| { vec_znx_dft::vec_znx_idft( - self.ptr, + module.ptr(), res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t, 1 as u64, a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, @@ -383,3 +98,331 @@ impl VecZnxDftOps for Module { } } } + +unsafe impl VecZnxDftToVecZnxBigTmpAImpl for FFT64 { + fn vec_znx_dft_to_vec_znx_big_tmp_a_impl(module: &Module, res: &mut R, res_col: usize, a: &mut A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxDftToMut, + { + let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut(); + + let min_size: usize = res_mut.size().min(a_mut.size()); + + unsafe { + (0..min_size).for_each(|j| { + vec_znx_dft::vec_znx_idft_tmp_a( + module.ptr(), + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t, + 1 as u64, + a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1 as u64, + ) + }); + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); + }) + } + } +} + +unsafe impl VecZnxDftToVecZnxBigConsumeImpl for FFT64 { + fn vec_znx_dft_to_vec_znx_big_consume_impl(module: &Module, mut a: VecZnxDft) -> VecZnxBig + where + VecZnxDft: VecZnxDftToMut, + { + let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut(); + + unsafe { + // Rev col and rows because ZnxDft.sl() >= ZnxBig.sl() + (0..a_mut.size()).for_each(|j| { + (0..a_mut.cols()).for_each(|i| { + vec_znx_dft::vec_znx_idft_tmp_a( + module.ptr(), + a_mut.at_mut_ptr(i, j) as *mut vec_znx_big::vec_znx_big_t, + 1 as u64, + a_mut.at_mut_ptr(i, j) as *mut 
vec_znx_dft::vec_znx_dft_t, + 1 as u64, + ) + }); + }); + } + + a.into_big() + } +} + +unsafe impl VecZnxDftFromVecZnxImpl for FFT64 { + fn vec_znx_dft_from_vec_znx_impl( + module: &Module, + step: usize, + offset: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + ) where + R: VecZnxDftToMut, + A: VecZnxToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnx<&[u8]> = a.to_ref(); + let steps: usize = a_ref.size().div_ceil(step); + let min_steps: usize = res_mut.size().min(steps); + unsafe { + (0..min_steps).for_each(|j| { + let limb: usize = offset + j * step; + if limb < a_ref.size() { + vec_znx_dft::vec_znx_dft( + module.ptr(), + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1 as u64, + a_ref.at_ptr(a_col, limb), + 1 as u64, + a_ref.sl() as u64, + ) + } + }); + (min_steps..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); + }); + } + } +} + +unsafe impl VecZnxDftAddImpl for FFT64 { + fn vec_znx_dft_add_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &D, + b_col: usize, + ) where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + D: VecZnxDftToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref(); + + let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size()); + + unsafe { + (0..min_size).for_each(|j| { + vec_znx_dft::vec_dft_add( + module.ptr(), + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1, + a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + ); + }); + } + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); + }) + } +} + +unsafe impl VecZnxDftAddInplaceImpl for FFT64 { + fn vec_znx_dft_add_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + + let min_size: usize = res_mut.size().min(a_ref.size()); + + unsafe { + (0..min_size).for_each(|j| { + vec_znx_dft::vec_dft_add( + module.ptr(), + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1, + res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + ); + }); + } + } +} + +unsafe impl VecZnxDftSubImpl for FFT64 { + fn vec_znx_dft_sub_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &D, + b_col: usize, + ) where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + D: VecZnxDftToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref(); + + let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size()); + + unsafe { + (0..min_size).for_each(|j| { + vec_znx_dft::vec_dft_sub( + module.ptr(), + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1, + a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + ); + }); + } + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); + }) + } +} + +unsafe impl VecZnxDftSubABInplaceImpl for FFT64 { + fn vec_znx_dft_sub_ab_inplace_impl(module: &Module, res: &mut R, 
res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + + let min_size: usize = res_mut.size().min(a_ref.size()); + + unsafe { + (0..min_size).for_each(|j| { + vec_znx_dft::vec_dft_sub( + module.ptr(), + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1, + res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + ); + }); + } + } +} + +unsafe impl VecZnxDftSubBAInplaceImpl for FFT64 { + fn vec_znx_dft_sub_ba_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + + let min_size: usize = res_mut.size().min(a_ref.size()); + + unsafe { + (0..min_size).for_each(|j| { + vec_znx_dft::vec_dft_sub( + module.ptr(), + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1, + a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + ); + }); + } + } +} + +unsafe impl VecZnxDftCopyImpl for FFT64 { + fn vec_znx_dft_copy_impl( + _module: &Module, + step: usize, + offset: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + ) where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + + let steps: usize = a_ref.size().div_ceil(step); + let min_steps: usize = res_mut.size().min(steps); + + (0..min_steps).for_each(|j| { + let limb: usize = offset + j * step; + if limb < a_ref.size() { + res_mut + .at_mut(res_col, j) + .copy_from_slice(a_ref.at(a_col, limb)); + } + }); + (min_steps..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); + }) + } +} + +unsafe impl VecZnxDftZeroImpl for FFT64 { + fn vec_znx_dft_zero_impl(_module: &Module, res: &mut R) + where + R: VecZnxDftToMut, + { + res.to_mut().data.fill(0); + } +} + +impl fmt::Display for VecZnxDft { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "VecZnxDft(n={}, cols={}, size={})", + self.n, self.cols, self.size + )?; + + for col in 0..self.cols { + writeln!(f, "Column {}:", col)?; + for size in 0..self.size { + let coeffs = self.at(col, size); + write!(f, " Size {}: [", size)?; + + let max_show = 100; + let show_count = coeffs.len().min(max_show); + + for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", coeff)?; + } + + if coeffs.len() > max_show { + write!(f, ", ... 
({} more)", coeffs.len() - max_show)?; + } + + writeln!(f, "]")?; + } + } + Ok(()) + } +} diff --git a/backend/src/implementation/cpu_spqlios/vec_znx_dft_ntt120.rs b/backend/src/implementation/cpu_spqlios/vec_znx_dft_ntt120.rs new file mode 100644 index 0000000..3378fb0 --- /dev/null +++ b/backend/src/implementation/cpu_spqlios/vec_znx_dft_ntt120.rs @@ -0,0 +1,38 @@ +use crate::{ + hal::{ + api::{ZnxInfos, ZnxSliceSize, ZnxView}, + layouts::{Data, DataRef, VecZnxDft, VecZnxDftBytesOf, VecZnxDftOwned}, + oep::{VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl}, + }, + implementation::cpu_spqlios::module_ntt120::NTT120, +}; + +const VEC_ZNX_DFT_NTT120_WORDSIZE: usize = 4; + +impl ZnxSliceSize for VecZnxDft { + fn sl(&self) -> usize { + VEC_ZNX_DFT_NTT120_WORDSIZE * self.n() * self.cols() + } +} + +impl VecZnxDftBytesOf for VecZnxDft { + fn bytes_of(n: usize, cols: usize, size: usize) -> usize { + VEC_ZNX_DFT_NTT120_WORDSIZE * n * cols * size * size_of::() + } +} + +impl ZnxView for VecZnxDft { + type Scalar = i64; +} + +unsafe impl VecZnxDftAllocBytesImpl for NTT120 { + fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + VecZnxDft::, NTT120>::bytes_of(n, cols, size) + } +} + +unsafe impl VecZnxDftAllocImpl for NTT120 { + fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned { + VecZnxDftOwned::alloc(n, cols, size) + } +} diff --git a/backend/src/implementation/cpu_spqlios/vmp_pmat_fft64.rs b/backend/src/implementation/cpu_spqlios/vmp_pmat_fft64.rs new file mode 100644 index 0000000..8b1106f --- /dev/null +++ b/backend/src/implementation/cpu_spqlios/vmp_pmat_fft64.rs @@ -0,0 +1,286 @@ +use crate::{ + hal::{ + api::{TakeSlice, VmpApplyTmpBytes, VmpPrepareTmpBytes, ZnxInfos, ZnxView, ZnxViewMut}, + layouts::{ + DataRef, MatZnx, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatBytesOf, + VmpPMatOwned, VmpPMatToMut, VmpPMatToRef, + }, + oep::{ + VmpApplyAddImpl, VmpApplyAddTmpBytesImpl, VmpApplyImpl, VmpApplyTmpBytesImpl, VmpPMatAllocBytesImpl, + VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl, + }, + }, + implementation::cpu_spqlios::{ + ffi::{vec_znx_dft::vec_znx_dft_t, vmp}, + module_fft64::FFT64, + }, +}; + +const VMP_PMAT_FFT64_WORDSIZE: usize = 1; + +impl ZnxView for VmpPMat { + type Scalar = f64; +} + +impl VmpPMatBytesOf for FFT64 { + fn vmp_pmat_bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + VMP_PMAT_FFT64_WORDSIZE * n * rows * cols_in * cols_out * size * size_of::() + } +} + +unsafe impl VmpPMatAllocBytesImpl for FFT64 +where + FFT64: VmpPMatBytesOf, +{ + fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + FFT64::vmp_pmat_bytes_of(n, rows, cols_in, cols_out, size) + } +} + +unsafe impl VmpPMatFromBytesImpl for FFT64 { + fn vmp_pmat_from_bytes_impl( + n: usize, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + bytes: Vec, + ) -> VmpPMatOwned { + VmpPMatOwned::from_bytes(n, rows, cols_in, cols_out, size, bytes) + } +} + +unsafe impl VmpPMatAllocImpl for FFT64 { + fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned { + VmpPMatOwned::alloc(n, rows, cols_in, cols_out, size) + } +} + +unsafe impl VmpPrepareTmpBytesImpl for FFT64 { + fn vmp_prepare_tmp_bytes_impl(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + unsafe { + vmp::vmp_prepare_tmp_bytes( + 
module.ptr(), + (rows * cols_in) as u64, + (cols_out * size) as u64, + ) as usize + } + } +} + +unsafe impl VmpPMatPrepareImpl for FFT64 { + fn vmp_prepare_impl(module: &Module, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: VmpPMatToMut, + A: MatZnxToRef, + { + let mut res: VmpPMat<&mut [u8], FFT64> = res.to_mut(); + let a: MatZnx<&[u8]> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), module.n()); + assert_eq!( + res.cols_in(), + a.cols_in(), + "res.cols_in: {} != a.cols_in: {}", + res.cols_in(), + a.cols_in() + ); + assert_eq!( + res.rows(), + a.rows(), + "res.rows: {} != a.rows: {}", + res.rows(), + a.rows() + ); + assert_eq!( + res.cols_out(), + a.cols_out(), + "res.cols_out: {} != a.cols_out: {}", + res.cols_out(), + a.cols_out() + ); + assert_eq!( + res.size(), + a.size(), + "res.size: {} != a.size: {}", + res.size(), + a.size() + ); + } + + let (tmp_bytes, _) = scratch.take_slice(module.vmp_prepare_tmp_bytes(a.rows(), a.cols_in(), a.cols_out(), a.size())); + + unsafe { + vmp::vmp_prepare_contiguous( + module.ptr(), + res.as_mut_ptr() as *mut vmp::vmp_pmat_t, + a.as_ptr(), + (a.rows() * a.cols_in()) as u64, + (a.size() * a.cols_out()) as u64, + tmp_bytes.as_mut_ptr(), + ); + } + } +} + +unsafe impl VmpApplyTmpBytesImpl for FFT64 { + fn vmp_apply_tmp_bytes_impl( + module: &Module, + res_size: usize, + a_size: usize, + b_rows: usize, + b_cols_in: usize, + b_cols_out: usize, + b_size: usize, + ) -> usize { + unsafe { + vmp::vmp_apply_dft_to_dft_tmp_bytes( + module.ptr(), + (res_size * b_cols_out) as u64, + (a_size * b_cols_in) as u64, + (b_rows * b_cols_in) as u64, + (b_size * b_cols_out) as u64, + ) as usize + } + } +} + +unsafe impl VmpApplyImpl for FFT64 { + fn vmp_apply_impl(module: &Module, res: &mut R, a: &A, b: &C, scratch: &mut Scratch) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + C: VmpPMatToRef, + { + let mut res: VecZnxDft<&mut [u8], _> = res.to_mut(); + let a: VecZnxDft<&[u8], _> = a.to_ref(); + let b: VmpPMat<&[u8], _> = b.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), module.n()); + assert_eq!(b.n(), module.n()); + assert_eq!(a.n(), module.n()); + assert_eq!( + res.cols(), + b.cols_out(), + "res.cols(): {} != b.cols_out: {}", + res.cols(), + b.cols_out() + ); + assert_eq!( + a.cols(), + b.cols_in(), + "a.cols(): {} != b.cols_in: {}", + a.cols(), + b.cols_in() + ); + } + + let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes( + res.size(), + a.size(), + b.rows(), + b.cols_in(), + b.cols_out(), + b.size(), + )); + unsafe { + vmp::vmp_apply_dft_to_dft( + module.ptr(), + res.as_mut_ptr() as *mut vec_znx_dft_t, + (res.size() * res.cols()) as u64, + a.as_ptr() as *const vec_znx_dft_t, + (a.size() * a.cols()) as u64, + b.as_ptr() as *const vmp::vmp_pmat_t, + (b.rows() * b.cols_in()) as u64, + (b.size() * b.cols_out()) as u64, + tmp_bytes.as_mut_ptr(), + ) + } + } +} + +unsafe impl VmpApplyAddTmpBytesImpl for FFT64 { + fn vmp_apply_add_tmp_bytes_impl( + module: &Module, + res_size: usize, + a_size: usize, + b_rows: usize, + b_cols_in: usize, + b_cols_out: usize, + b_size: usize, + ) -> usize { + unsafe { + vmp::vmp_apply_dft_to_dft_tmp_bytes( + module.ptr(), + (res_size * b_cols_out) as u64, + (a_size * b_cols_in) as u64, + (b_rows * b_cols_in) as u64, + (b_size * b_cols_out) as u64, + ) as usize + } + } +} + +unsafe impl VmpApplyAddImpl for FFT64 { + fn vmp_apply_add_impl(module: &Module, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch) + where + R: VecZnxDftToMut, 
+ A: VecZnxDftToRef, + C: VmpPMatToRef, + { + let mut res: VecZnxDft<&mut [u8], _> = res.to_mut(); + let a: VecZnxDft<&[u8], _> = a.to_ref(); + let b: VmpPMat<&[u8], _> = b.to_ref(); + + #[cfg(debug_assertions)] + { + use crate::hal::api::ZnxInfos; + + assert_eq!(res.n(), module.n()); + assert_eq!(b.n(), module.n()); + assert_eq!(a.n(), module.n()); + assert_eq!( + res.cols(), + b.cols_out(), + "res.cols(): {} != b.cols_out: {}", + res.cols(), + b.cols_out() + ); + assert_eq!( + a.cols(), + b.cols_in(), + "a.cols(): {} != b.cols_in: {}", + a.cols(), + b.cols_in() + ); + } + + let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes( + res.size(), + a.size(), + b.rows(), + b.cols_in(), + b.cols_out(), + b.size(), + )); + unsafe { + vmp::vmp_apply_dft_to_dft_add( + module.ptr(), + res.as_mut_ptr() as *mut vec_znx_dft_t, + (res.size() * res.cols()) as u64, + a.as_ptr() as *const vec_znx_dft_t, + (a.size() * a.cols()) as u64, + b.as_ptr() as *const vmp::vmp_pmat_t, + (b.rows() * b.cols_in()) as u64, + (b.size() * b.cols_out()) as u64, + (scale * b.cols_out()) as u64, + tmp_bytes.as_mut_ptr(), + ) + } + } +} diff --git a/backend/src/implementation/cpu_spqlios/vmp_pmat_ntt120.rs b/backend/src/implementation/cpu_spqlios/vmp_pmat_ntt120.rs new file mode 100644 index 0000000..af135bf --- /dev/null +++ b/backend/src/implementation/cpu_spqlios/vmp_pmat_ntt120.rs @@ -0,0 +1,11 @@ +use crate::{ + hal::{ + api::ZnxView, + layouts::{DataRef, VmpPMat}, + }, + implementation::cpu_spqlios::module_ntt120::NTT120, +}; + +impl ZnxView for VmpPMat { + type Scalar = i64; +} diff --git a/backend/src/implementation/mod.rs b/backend/src/implementation/mod.rs new file mode 100644 index 0000000..15632e0 --- /dev/null +++ b/backend/src/implementation/mod.rs @@ -0,0 +1 @@ +pub mod cpu_spqlios; diff --git a/backend/src/lib.rs b/backend/src/lib.rs index 8ac50ce..8923775 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -1,39 +1,17 @@ -pub mod encoding; -#[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)] -// Other modules and exports -pub mod ffi; -pub mod mat_znx_dft; -pub mod mat_znx_dft_ops; -pub mod module; -pub mod sampling; -pub mod scalar_znx; -pub mod scalar_znx_dft; -pub mod scalar_znx_dft_ops; -pub mod stats; -pub mod vec_znx; -pub mod vec_znx_big; -pub mod vec_znx_big_ops; -pub mod vec_znx_dft; -pub mod vec_znx_dft_ops; -pub mod vec_znx_ops; -pub mod znx_base; +#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)] +#![deny(rustdoc::broken_intra_doc_links)] +#![cfg_attr(docsrs, feature(doc_cfg))] +#![feature(trait_alias)] -pub use encoding::*; -pub use mat_znx_dft::*; -pub use mat_znx_dft_ops::*; -pub use module::*; -pub use sampling::*; -pub use scalar_znx::*; -pub use scalar_znx_dft::*; -pub use scalar_znx_dft_ops::*; -pub use stats::*; -pub use vec_znx::*; -pub use vec_znx_big::*; -pub use vec_znx_big_ops::*; -pub use vec_znx_dft::*; -pub use vec_znx_dft_ops::*; -pub use vec_znx_ops::*; -pub use znx_base::*; +pub mod hal; +pub mod implementation; + +pub mod doc { + #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/backend_safety_contract.md"))] + pub mod backend_safety { + pub const _PLACEHOLDER: () = (); + } +} pub const GALOISGENERATOR: u64 = 5; pub const DEFAULTALIGN: usize = 64; @@ -118,190 +96,10 @@ pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { } /// Allocates an aligned vector of size equal to the smallest multiple -/// of [DEFAULTALIGN]/size_of::() that is 
equal or greater to `size`. +/// of [DEFAULTALIGN]/`size_of::`() that is equal or greater to `size`. pub fn alloc_aligned(size: usize) -> Vec { alloc_aligned_custom::( - size + (DEFAULTALIGN - (size % (DEFAULTALIGN / size_of::()))), + size + (DEFAULTALIGN - (size % (DEFAULTALIGN / size_of::()))) % DEFAULTALIGN, DEFAULTALIGN, ) } - -// Scratch implementation below - -pub struct ScratchOwned(Vec); - -impl ScratchOwned { - pub fn new(byte_count: usize) -> Self { - let data: Vec = alloc_aligned(byte_count); - Self(data) - } - - pub fn borrow(&mut self) -> &mut Scratch { - Scratch::new(&mut self.0) - } -} - -pub struct Scratch { - data: [u8], -} - -impl Scratch { - fn new(data: &mut [u8]) -> &mut Self { - unsafe { &mut *(data as *mut [u8] as *mut Self) } - } - - pub fn zero(&mut self) { - self.data.fill(0); - } - - pub fn available(&self) -> usize { - let ptr: *const u8 = self.data.as_ptr(); - let self_len: usize = self.data.len(); - let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN); - self_len.saturating_sub(aligned_offset) - } - - fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) { - let ptr: *mut u8 = data.as_mut_ptr(); - let self_len: usize = data.len(); - - let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN); - let aligned_len: usize = self_len.saturating_sub(aligned_offset); - - if let Some(rem_len) = aligned_len.checked_sub(take_len) { - unsafe { - let rem_ptr: *mut u8 = ptr.add(aligned_offset).add(take_len); - let rem_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len); - - let take_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len); - - return (take_slice, rem_slice); - } - } else { - panic!( - "Attempted to take {} from scratch with {} aligned bytes left", - take_len, - aligned_len, - // type_name::(), - // aligned_len - ); - } - } - - pub fn tmp_slice(&mut self, len: usize) -> (&mut [T], &mut Self) { - let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, len * std::mem::size_of::()); - - unsafe { - ( - &mut *(std::ptr::slice_from_raw_parts_mut(take_slice.as_mut_ptr() as *mut T, len)), - Self::new(rem_slice), - ) - } - } - - pub fn tmp_scalar_znx(&mut self, module: &Module, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) { - let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx(module, cols)); - - ( - ScalarZnx::from_data(take_slice, module.n(), cols), - Self::new(rem_slice), - ) - } - - pub fn tmp_scalar_znx_dft(&mut self, module: &Module, cols: usize) -> (ScalarZnxDft<&mut [u8], B>, &mut Self) { - let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx_dft(module, cols)); - - ( - ScalarZnxDft::from_data(take_slice, module.n(), cols), - Self::new(rem_slice), - ) - } - - pub fn tmp_vec_znx_dft( - &mut self, - module: &Module, - cols: usize, - size: usize, - ) -> (VecZnxDft<&mut [u8], B>, &mut Self) { - let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_dft(module, cols, size)); - - ( - VecZnxDft::from_data(take_slice, module.n(), cols, size), - Self::new(rem_slice), - ) - } - - pub fn tmp_slice_vec_znx_dft( - &mut self, - slice_size: usize, - module: &Module, - cols: usize, - size: usize, - ) -> (Vec>, &mut Self) { - let mut scratch: &mut Scratch = self; - let mut slice: Vec> = Vec::with_capacity(slice_size); - for _ in 0..slice_size { - let (znx, new_scratch) = scratch.tmp_vec_znx_dft(module, cols, size); - scratch = new_scratch; - 
slice.push(znx); - } - (slice, scratch) - } - - pub fn tmp_vec_znx_big( - &mut self, - module: &Module, - cols: usize, - size: usize, - ) -> (VecZnxBig<&mut [u8], B>, &mut Self) { - let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_big(module, cols, size)); - - ( - VecZnxBig::from_data(take_slice, module.n(), cols, size), - Self::new(rem_slice), - ) - } - - pub fn tmp_vec_znx(&mut self, module: &Module, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) { - let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, module.bytes_of_vec_znx(cols, size)); - ( - VecZnx::from_data(take_slice, module.n(), cols, size), - Self::new(rem_slice), - ) - } - - pub fn tmp_slice_vec_znx( - &mut self, - slice_size: usize, - module: &Module, - cols: usize, - size: usize, - ) -> (Vec>, &mut Self) { - let mut scratch: &mut Scratch = self; - let mut slice: Vec> = Vec::with_capacity(slice_size); - for _ in 0..slice_size { - let (znx, new_scratch) = scratch.tmp_vec_znx(module, cols, size); - scratch = new_scratch; - slice.push(znx); - } - (slice, scratch) - } - - pub fn tmp_mat_znx_dft( - &mut self, - module: &Module, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> (MatZnxDft<&mut [u8], B>, &mut Self) { - let (take_slice, rem_slice) = Self::take_slice_aligned( - &mut self.data, - module.bytes_of_mat_znx_dft(rows, cols_in, cols_out, size), - ); - ( - MatZnxDft::from_data(take_slice, module.n(), rows, cols_in, cols_out, size), - Self::new(rem_slice), - ) - } -} diff --git a/backend/src/mat_znx_dft.rs b/backend/src/mat_znx_dft.rs deleted file mode 100644 index e9d6737..0000000 --- a/backend/src/mat_znx_dft.rs +++ /dev/null @@ -1,214 +0,0 @@ -use crate::znx_base::ZnxInfos; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; -use std::marker::PhantomData; - -/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], -/// stored as a 3D matrix in the DFT domain in a single contiguous array. -/// Each col of the [MatZnxDft] can be seen as a collection of [VecZnxDft]. -/// -/// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft]. -/// See the trait [MatZnxDftOps] for additional information. 
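One aside before the MatZnxDft layout that this doc comment introduces: the `Scratch` type removed just above is, at its core, an aligned bump allocator over a single backing buffer, with each `tmp_*` call carving a view off the front and handing back the remainder. A reduced, safe-Rust sketch of that splitting pattern (hypothetical names and a plain byte-slice result, not the crate's typed API):

```rust
const ALIGN: usize = 64; // mirrors the role of DEFAULTALIGN above

/// Carve an aligned `take`-byte prefix off `buf`, returning (taken, rest).
/// Chaining calls on `rest` emulates the bump-style Scratch splitting.
fn split_aligned(buf: &mut [u8], take: usize) -> (&mut [u8], &mut [u8]) {
    let off = buf.as_ptr().align_offset(ALIGN);
    assert!(off + take <= buf.len(), "scratch exhausted");
    let (_, aligned) = buf.split_at_mut(off);
    aligned.split_at_mut(take)
}

fn main() {
    let mut backing = vec![0u8; 4096];
    let (tmp_a, rest) = split_aligned(&mut backing, 256);
    let (tmp_b, _) = split_aligned(rest, 128);
    assert_eq!((tmp_a.len(), tmp_b.len()), (256, 128));
}
```

Returning the remainder as a fresh mutable borrow is what lets one buffer be threaded through nested calls without aliasing at the call sites.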
-pub struct MatZnxDft { - data: D, - n: usize, - size: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - _phantom: PhantomData, -} - -impl ZnxInfos for MatZnxDft { - fn cols(&self) -> usize { - self.cols_in - } - - fn rows(&self) -> usize { - self.rows - } - - fn n(&self) -> usize { - self.n - } - - fn size(&self) -> usize { - self.size - } -} - -impl ZnxSliceSize for MatZnxDft { - fn sl(&self) -> usize { - self.n() * self.cols_out() - } -} - -impl DataView for MatZnxDft { - type D = D; - fn data(&self) -> &Self::D { - &self.data - } -} - -impl DataViewMut for MatZnxDft { - fn data_mut(&mut self) -> &mut Self::D { - &mut self.data - } -} - -impl> ZnxView for MatZnxDft { - type Scalar = f64; -} - -impl MatZnxDft { - pub fn cols_in(&self) -> usize { - self.cols_in - } - - pub fn cols_out(&self) -> usize { - self.cols_out - } -} - -impl>, B: Backend> MatZnxDft { - pub(crate) fn bytes_of(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - unsafe { - crate::ffi::vmp::bytes_of_vmp_pmat( - module.ptr, - (rows * cols_in) as u64, - (size * cols_out) as u64, - ) as usize - } - } - - pub(crate) fn new(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { - let data: Vec = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size)); - Self { - data: data.into(), - n: module.n(), - size, - rows, - cols_in, - cols_out, - _phantom: PhantomData, - } - } - - pub(crate) fn new_from_bytes( - module: &Module, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - bytes: impl Into>, - ) -> Self { - let data: Vec = bytes.into(); - assert!(data.len() == Self::bytes_of(module, rows, cols_in, cols_out, size)); - Self { - data: data.into(), - n: module.n(), - size, - rows, - cols_in, - cols_out, - _phantom: PhantomData, - } - } -} - -impl> MatZnxDft { - /// Returns a copy of the backend array at index (i, j) of the [MatZnxDft]. - /// - /// # Arguments - /// - /// * `row`: row index (i). - /// * `col`: col index (j). - #[allow(dead_code)] - fn at(&self, row: usize, col: usize) -> Vec { - let n: usize = self.n(); - - let mut res: Vec = alloc_aligned(n); - - if n < 8 { - res.copy_from_slice(&self.raw()[(row + col * self.rows()) * n..(row + col * self.rows()) * (n + 1)]); - } else { - (0..n >> 3).for_each(|blk| { - res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]); - }); - } - - res - } - - #[allow(dead_code)] - fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] { - let nrows: usize = self.rows(); - let nsize: usize = self.size(); - if col == (nsize - 1) && (nsize & 1 == 1) { - &self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..] - } else { - &self.raw()[blk * nrows * nsize * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..] 
- } - } -} - -pub type MatZnxDftOwned = MatZnxDft, B>; -pub type MatZnxDftMut<'a, B> = MatZnxDft<&'a mut [u8], B>; -pub type MatZnxDftRef<'a, B> = MatZnxDft<&'a [u8], B>; - -pub trait MatZnxToRef { - fn to_ref(&self) -> MatZnxDft<&[u8], B>; -} - -impl MatZnxToRef for MatZnxDft -where - D: AsRef<[u8]>, - B: Backend, -{ - fn to_ref(&self) -> MatZnxDft<&[u8], B> { - MatZnxDft { - data: self.data.as_ref(), - n: self.n, - rows: self.rows, - cols_in: self.cols_in, - cols_out: self.cols_out, - size: self.size, - _phantom: std::marker::PhantomData, - } - } -} - -pub trait MatZnxToMut { - fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B>; -} - -impl MatZnxToMut for MatZnxDft -where - D: AsRef<[u8]> + AsMut<[u8]>, - B: Backend, -{ - fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { - MatZnxDft { - data: self.data.as_mut(), - n: self.n, - rows: self.rows, - cols_in: self.cols_in, - cols_out: self.cols_out, - size: self.size, - _phantom: std::marker::PhantomData, - } - } -} - -impl MatZnxDft { - pub(crate) fn from_data(data: D, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { - Self { - data, - n, - rows, - cols_in, - cols_out, - size, - _phantom: PhantomData, - } - } -} diff --git a/backend/src/mat_znx_dft_ops.rs b/backend/src/mat_znx_dft_ops.rs deleted file mode 100644 index b48cb1a..0000000 --- a/backend/src/mat_znx_dft_ops.rs +++ /dev/null @@ -1,996 +0,0 @@ -use crate::ffi::vec_znx_dft::vec_znx_dft_t; -use crate::ffi::vmp; -use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; -use crate::{ - Backend, FFT64, MatZnxDft, MatZnxDftOwned, MatZnxToMut, MatZnxToRef, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, - ScalarZnxDftOps, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, -}; - -pub trait MatZnxDftAlloc { - /// Allocates a new [MatZnxDft] with the given number of rows and columns. - /// - /// # Arguments - /// - /// * `rows`: number of rows (number of [VecZnxDft]). - /// * `size`: number of size (number of size of each [VecZnxDft]). - fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftOwned; - - fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; - - fn new_mat_znx_dft_from_bytes( - &self, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - bytes: Vec, - ) -> MatZnxDftOwned; -} - -pub trait MatZnxDftScratch { - /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft]. - fn vmp_apply_tmp_bytes( - &self, - res_size: usize, - a_size: usize, - b_rows: usize, - b_cols_in: usize, - b_cols_out: usize, - b_size: usize, - ) -> usize; - - fn mat_znx_dft_mul_x_pow_minus_one_scratch_space(&self, size: usize, cols_out: usize) -> usize; -} - -/// This trait implements methods for vector matrix product, -/// that is, multiplying a [VecZnx] with a [MatZnxDft]. -pub trait MatZnxDftOps { - /// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft]. - /// - /// # Arguments - /// - /// * `res`: [MatZnxDft] on which the values are encoded. - /// * `a`: the [VecZnxDft] to encode on the [MatZnxDft]. - /// * `row_i`: the index of the row to prepare. - /// - /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. - fn mat_znx_dft_set_row(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A) - where - R: MatZnxToMut, - A: VecZnxDftToRef; - - /// Extracts the ith-row of [MatZnxDft] into a [VecZnxDft]. 
- /// - /// # Arguments - /// - /// * `res`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft]. - /// * `a`: [MatZnxDft] on which the values are encoded. - /// * `row_i`: the index of the row to extract. - fn mat_znx_dft_get_row(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize) - where - R: VecZnxDftToMut, - A: MatZnxToRef; - - /// Multiplies A by (X^{k} - 1) and stores the result on R. - fn mat_znx_dft_mul_x_pow_minus_one(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch) - where - R: MatZnxToMut, - A: MatZnxToRef; - - /// Multiplies A by (X^{k} - 1). - fn mat_znx_dft_mul_x_pow_minus_one_inplace(&self, k: i64, a: &mut A, scratch: &mut Scratch) - where - A: MatZnxToMut; - - /// Multiplies A by (X^{k} - 1). - fn mat_znx_dft_mul_x_pow_minus_one_add_inplace(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch) - where - R: MatZnxToMut, - A: MatZnxToRef; - - /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. - /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. - /// - /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] - /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [MatZnxDft]. - /// - /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and - /// `j` size, the output is a [VecZnx] of `j` size. - /// - /// If there is a mismatch between the dimensions the largest valid ones are used. - /// - /// ```text - /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| - /// |h i j| - /// |k l m| - /// ``` - /// where each element is a [VecZnxDft]. - /// - /// # Arguments - /// - /// * `c`: the output of the vector matrix product, as a [VecZnxDft]. - /// * `a`: the left operand [VecZnxDft] of the vector matrix product. - /// * `b`: the right operand [MatZnxDft] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. - fn vmp_apply(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) - where - R: VecZnxDftToMut, - A: VecZnxDftToRef, - B: MatZnxToRef; - - // Same as [MatZnxDftOps::vmp_apply] except result is added on R instead of overwritting R. 
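Both `vmp_apply` and the accumulating `vmp_apply_add` below realize the row-times-matrix contract drawn in the ASCII diagram above. A plain-coefficient reference of that contract, reusing the hypothetical `negacyclic_mul_ref` sketched earlier (illustrative only, not the crate's code path):

```rust
/// res[col] = sum over rows of a[row] * mat[row][col] in Z[X]/(X^N + 1),
/// with `a` holding one polynomial per row and `mat` a rows-by-cols grid.
fn vmp_ref(a: &[Vec<i64>], mat: &[Vec<Vec<i64>>]) -> Vec<Vec<i64>> {
    let n = a[0].len();
    let cols = mat[0].len();
    let mut res = vec![vec![0i64; n]; cols];
    for (row, a_row) in a.iter().enumerate() {
        for (col, m) in mat[row].iter().enumerate() {
            for (r, p) in res[col].iter_mut().zip(negacyclic_mul_ref(a_row, m)) {
                *r += p;
            }
        }
    }
    res
}
```

Unlike this sketch, the real implementation tolerates shape mismatches by falling back to the largest valid dimensions, as the doc comment above spells out.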
- fn vmp_apply_add(&self, res: &mut R, a: &A, b: &B, scale: usize, scratch: &mut Scratch) - where - R: VecZnxDftToMut, - A: VecZnxDftToRef, - B: MatZnxToRef; -} - -impl MatZnxDftAlloc for Module { - fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - MatZnxDftOwned::bytes_of(self, rows, cols_in, cols_out, size) - } - - fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftOwned { - MatZnxDftOwned::new(self, rows, cols_in, cols_out, size) - } - - fn new_mat_znx_dft_from_bytes( - &self, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - bytes: Vec, - ) -> MatZnxDftOwned { - MatZnxDftOwned::new_from_bytes(self, rows, cols_in, cols_out, size, bytes) - } -} - -impl MatZnxDftScratch for Module { - fn vmp_apply_tmp_bytes( - &self, - res_size: usize, - a_size: usize, - b_rows: usize, - b_cols_in: usize, - b_cols_out: usize, - b_size: usize, - ) -> usize { - unsafe { - vmp::vmp_apply_dft_to_dft_tmp_bytes( - self.ptr, - (res_size * b_cols_out) as u64, - (a_size * b_cols_in) as u64, - (b_rows * b_cols_in) as u64, - (b_size * b_cols_out) as u64, - ) as usize - } - } - - fn mat_znx_dft_mul_x_pow_minus_one_scratch_space(&self, size: usize, cols_out: usize) -> usize { - let xpm1_dft: usize = self.bytes_of_scalar_znx(1); - let xpm1: usize = self.bytes_of_scalar_znx_dft(1); - let tmp: usize = self.bytes_of_vec_znx_dft(cols_out, size); - xpm1_dft + (xpm1 | 2 * tmp) - } -} - -impl MatZnxDftOps for Module { - fn mat_znx_dft_mul_x_pow_minus_one(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch) - where - R: MatZnxToMut, - A: MatZnxToRef, - { - let mut res: MatZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a: MatZnxDft<&[u8], FFT64> = a.to_ref(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.n(), self.n()); - assert_eq!(a.n(), self.n()); - assert_eq!(res.rows(), a.rows()); - assert_eq!(res.cols_in(), a.cols_in()); - assert_eq!(res.cols_out(), a.cols_out()); - } - - let (mut xpm1_dft, scratch1) = scratch.tmp_scalar_znx_dft(self, 1); - - { - let (mut xpm1, _) = scratch1.tmp_scalar_znx(self, 1); - xpm1.data[0] = 1; - self.vec_znx_rotate_inplace(k, &mut xpm1, 0); - self.svp_prepare(&mut xpm1_dft, 0, &xpm1, 0); - } - - let (mut tmp_0, scratch2) = scratch1.tmp_vec_znx_dft(self, res.cols_out(), res.size()); - let (mut tmp_1, _) = scratch2.tmp_vec_znx_dft(self, res.cols_out(), res.size()); - - (0..res.rows()).for_each(|row_i| { - (0..res.cols_in()).for_each(|col_j| { - self.mat_znx_dft_get_row(&mut tmp_0, &a, row_i, col_j); - - (0..tmp_0.cols()).for_each(|i| { - self.svp_apply(&mut tmp_1, i, &xpm1_dft, 0, &tmp_0, i); - self.vec_znx_dft_sub_ab_inplace(&mut tmp_1, i, &tmp_0, i); - }); - - self.mat_znx_dft_set_row(&mut res, row_i, col_j, &tmp_1); - }); - }); - } - - fn mat_znx_dft_mul_x_pow_minus_one_inplace(&self, k: i64, a: &mut A, scratch: &mut Scratch) - where - A: MatZnxToMut, - { - let mut a: MatZnxDft<&mut [u8], FFT64> = a.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - } - - let (mut xpm1_dft, scratch1) = scratch.tmp_scalar_znx_dft(self, 1); - - { - let (mut xpm1, _) = scratch1.tmp_scalar_znx(self, 1); - xpm1.data[0] = 1; - self.vec_znx_rotate_inplace(k, &mut xpm1, 0); - self.svp_prepare(&mut xpm1_dft, 0, &xpm1, 0); - } - - let (mut tmp_0, scratch2) = scratch1.tmp_vec_znx_dft(self, a.cols_out(), a.size()); - let (mut tmp_1, _) = scratch2.tmp_vec_znx_dft(self, a.cols_out(), a.size()); - - (0..a.rows()).for_each(|row_i| { - (0..a.cols_in()).for_each(|col_j| { - 
self.mat_znx_dft_get_row(&mut tmp_0, &a, row_i, col_j); - - (0..tmp_0.cols()).for_each(|i| { - self.svp_apply(&mut tmp_1, i, &xpm1_dft, 0, &tmp_0, i); - self.vec_znx_dft_sub_ab_inplace(&mut tmp_1, i, &tmp_0, i); - }); - - self.mat_znx_dft_set_row(&mut a, row_i, col_j, &tmp_1); - }); - }); - } - - fn mat_znx_dft_mul_x_pow_minus_one_add_inplace(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch) - where - R: MatZnxToMut, - A: MatZnxToRef, - { - let mut res: MatZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a: MatZnxDft<&[u8], FFT64> = a.to_ref(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - } - - let (mut xpm1_dft, scratch1) = scratch.tmp_scalar_znx_dft(self, 1); - - { - let (mut xpm1, _) = scratch1.tmp_scalar_znx(self, 1); - xpm1.data[0] = 1; - self.vec_znx_rotate_inplace(k, &mut xpm1, 0); - self.svp_prepare(&mut xpm1_dft, 0, &xpm1, 0); - } - - let (mut tmp_0, scratch2) = scratch1.tmp_vec_znx_dft(self, a.cols_out(), a.size()); - let (mut tmp_1, _) = scratch2.tmp_vec_znx_dft(self, a.cols_out(), a.size()); - - (0..a.rows()).for_each(|row_i| { - (0..a.cols_in()).for_each(|col_j| { - self.mat_znx_dft_get_row(&mut tmp_0, &a, row_i, col_j); - - (0..tmp_0.cols()).for_each(|i| { - self.svp_apply(&mut tmp_1, i, &xpm1_dft, 0, &tmp_0, i); - self.vec_znx_dft_sub_ab_inplace(&mut tmp_1, i, &tmp_0, i); - }); - - self.mat_znx_dft_get_row(&mut tmp_0, &res, row_i, col_j); - - (0..tmp_0.cols()).for_each(|i| { - self.vec_znx_dft_add_inplace(&mut tmp_0, i, &tmp_1, i); - }); - - self.mat_znx_dft_set_row(&mut res, row_i, col_j, &tmp_0); - }); - }); - } - - fn mat_znx_dft_set_row(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A) - where - R: MatZnxToMut, - A: VecZnxDftToRef, - { - let mut res: MatZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a: VecZnxDft<&[u8], _> = a.to_ref(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.n(), self.n()); - assert_eq!(a.n(), self.n()); - assert_eq!( - a.cols(), - res.cols_out(), - "a.cols(): {} != res.cols_out(): {}", - a.cols(), - res.cols_out() - ); - assert!( - res_row < res.rows(), - "res_row: {} >= res.rows(): {}", - res_row, - res.rows() - ); - assert!( - res_col_in < res.cols_in(), - "res_col_in: {} >= res.cols_in(): {}", - res_col_in, - res.cols_in() - ); - assert_eq!( - res.size(), - a.size(), - "res.size(): {} != a.size(): {}", - res.size(), - a.size() - ); - } - - unsafe { - vmp::vmp_prepare_row_dft( - self.ptr, - res.as_mut_ptr() as *mut vmp::vmp_pmat_t, - a.as_ptr() as *const vec_znx_dft_t, - (res_row * res.cols_in() + res_col_in) as u64, - (res.rows() * res.cols_in()) as u64, - (res.size() * res.cols_out()) as u64, - ); - } - } - - fn mat_znx_dft_get_row(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize) - where - R: VecZnxDftToMut, - A: MatZnxToRef, - { - let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a: MatZnxDft<&[u8], _> = a.to_ref(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.n(), self.n()); - assert_eq!(a.n(), self.n()); - assert_eq!( - res.cols(), - a.cols_out(), - "res.cols(): {} != a.cols_out(): {}", - res.cols(), - a.cols_out() - ); - assert!( - a_row < a.rows(), - "a_row: {} >= a.rows(): {}", - a_row, - a.rows() - ); - assert!( - a_col_in < a.cols_in(), - "a_col_in: {} >= a.cols_in(): {}", - a_col_in, - a.cols_in() - ); - assert_eq!( - res.size(), - a.size(), - "res.size(): {} != a.size(): {}", - res.size(), - a.size() - ); - } - unsafe { - vmp::vmp_extract_row_dft( - self.ptr, - res.as_mut_ptr() as *mut vec_znx_dft_t, - a.as_ptr() as *const vmp::vmp_pmat_t, - (a_row * a.cols_in() + 
a_col_in) as u64, - (a.rows() * a.cols_in()) as u64, - (a.size() * a.cols_out()) as u64, - ); - } - } - - fn vmp_apply(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) - where - R: VecZnxDftToMut, - A: VecZnxDftToRef, - B: MatZnxToRef, - { - let mut res: VecZnxDft<&mut [u8], _> = res.to_mut(); - let a: VecZnxDft<&[u8], _> = a.to_ref(); - let b: MatZnxDft<&[u8], _> = b.to_ref(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.n(), self.n()); - assert_eq!(b.n(), self.n()); - assert_eq!(a.n(), self.n()); - assert_eq!( - res.cols(), - b.cols_out(), - "res.cols(): {} != b.cols_out: {}", - res.cols(), - b.cols_out() - ); - assert_eq!( - a.cols(), - b.cols_in(), - "a.cols(): {} != b.cols_in: {}", - a.cols(), - b.cols_in() - ); - } - - let (tmp_bytes, _) = scratch.tmp_slice(self.vmp_apply_tmp_bytes( - res.size(), - a.size(), - b.rows(), - b.cols_in(), - b.cols_out(), - b.size(), - )); - unsafe { - vmp::vmp_apply_dft_to_dft( - self.ptr, - res.as_mut_ptr() as *mut vec_znx_dft_t, - (res.size() * res.cols()) as u64, - a.as_ptr() as *const vec_znx_dft_t, - (a.size() * a.cols()) as u64, - b.as_ptr() as *const vmp::vmp_pmat_t, - (b.rows() * b.cols_in()) as u64, - (b.size() * b.cols_out()) as u64, - tmp_bytes.as_mut_ptr(), - ) - } - } - - fn vmp_apply_add(&self, res: &mut R, a: &A, b: &B, scale: usize, scratch: &mut Scratch) - where - R: VecZnxDftToMut, - A: VecZnxDftToRef, - B: MatZnxToRef, - { - let mut res: VecZnxDft<&mut [u8], _> = res.to_mut(); - let a: VecZnxDft<&[u8], _> = a.to_ref(); - let b: MatZnxDft<&[u8], _> = b.to_ref(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.n(), self.n()); - assert_eq!(b.n(), self.n()); - assert_eq!(a.n(), self.n()); - assert_eq!( - res.cols(), - b.cols_out(), - "res.cols(): {} != b.cols_out: {}", - res.cols(), - b.cols_out() - ); - assert_eq!( - a.cols(), - b.cols_in(), - "a.cols(): {} != b.cols_in: {}", - a.cols(), - b.cols_in() - ); - } - - let (tmp_bytes, _) = scratch.tmp_slice(self.vmp_apply_tmp_bytes( - res.size(), - a.size(), - b.rows(), - b.cols_in(), - b.cols_out(), - b.size(), - )); - unsafe { - vmp::vmp_apply_dft_to_dft_add( - self.ptr, - res.as_mut_ptr() as *mut vec_znx_dft_t, - (res.size() * res.cols()) as u64, - a.as_ptr() as *const vec_znx_dft_t, - (a.size() * a.cols()) as u64, - b.as_ptr() as *const vmp::vmp_pmat_t, - (b.rows() * b.cols_in()) as u64, - (b.size() * b.cols_out()) as u64, - (scale * b.cols_out()) as u64, - tmp_bytes.as_mut_ptr(), - ) - } - } -} -#[cfg(test)] -mod tests { - use crate::{ - Decoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, - VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxInfos, ZnxView, - ZnxViewMut, ZnxZero, - }; - use sampling::source::Source; - - use super::{MatZnxDftAlloc, MatZnxDftScratch}; - - #[test] - fn vmp_set_row() { - let module: Module = Module::::new(16); - let basek: usize = 8; - let mat_rows: usize = 4; - let mat_cols_in: usize = 2; - let mat_cols_out: usize = 2; - let mat_size: usize = 5; - let mut a: VecZnx> = module.new_vec_znx(mat_cols_out, mat_size); - let mut a_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); - let mut b_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); - let mut mat: MatZnxDft, FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); - - for col_in in 0..mat_cols_in { - for row_i in 0..mat_rows { - let mut source: Source = Source::new([0u8; 32]); - (0..mat_cols_out).for_each(|col_out| { - 
a.fill_uniform(basek, col_out, mat_size, &mut source); - module.vec_znx_dft(1, 0, &mut a_dft, col_out, &a, col_out); - }); - module.mat_znx_dft_set_row(&mut mat, row_i, col_in, &a_dft); - module.mat_znx_dft_get_row(&mut b_dft, &mat, row_i, col_in); - assert_eq!(a_dft.raw(), b_dft.raw()); - } - } - } - - #[test] - fn vmp_apply() { - let log_n: i32 = 5; - let n: usize = 1 << log_n; - - let module: Module = Module::::new(n); - let basek: usize = 15; - let a_size: usize = 5; - let mat_size: usize = 6; - let res_size: usize = a_size; - - [1, 2].iter().for_each(|cols_in| { - [1, 2].iter().for_each(|cols_out| { - let a_cols: usize = *cols_in; - let res_cols: usize = *cols_out; - - let mat_rows: usize = a_size; - let mat_cols_in: usize = a_cols; - let mat_cols_out: usize = res_cols; - - let mut scratch: ScratchOwned = ScratchOwned::new( - module.vmp_apply_tmp_bytes( - res_size, - a_size, - mat_rows, - mat_cols_in, - mat_cols_out, - mat_size, - ) | module.vec_znx_big_normalize_tmp_bytes(), - ); - - let mut a: VecZnx> = module.new_vec_znx(a_cols, a_size); - - (0..a_cols).for_each(|i| { - a.at_mut(i, a_size - 1)[i + 1] = 1; - }); - - let mut mat_znx_dft: MatZnxDft, FFT64> = - module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); - - let mut c_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); - let mut c_big: VecZnxBig, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size); - - let mut tmp: VecZnx> = module.new_vec_znx(mat_cols_out, mat_size); - - // Construts a [VecZnxMatDft] that performs cyclic rotations on each submatrix. - (0..a.size()).for_each(|row_i| { - (0..mat_cols_in).for_each(|col_in_i| { - (0..mat_cols_out).for_each(|col_out_i| { - let idx = 1 + col_in_i * mat_cols_out + col_out_i; - tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx} - module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i); - tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64; - }); - module.mat_znx_dft_set_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); - }); - }); - - let mut a_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(a_cols, a_size); - (0..a_cols).for_each(|i| { - module.vec_znx_dft(1, 0, &mut a_dft, i, &a, i); - }); - - module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow()); - - let mut res_have_vi64: Vec = vec![i64::default(); n]; - - let mut res_have: VecZnx> = module.new_vec_znx(res_cols, res_size); - (0..mat_cols_out).for_each(|i| { - module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i); - module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow()); - }); - - (0..mat_cols_out).for_each(|col_i| { - let mut res_want_vi64: Vec = vec![i64::default(); n]; - (0..a_cols).for_each(|i| { - res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1; - }); - res_have.decode_vec_i64(col_i, basek, basek * a_size, &mut res_have_vi64); - assert_eq!(res_have_vi64, res_want_vi64); - }); - }); - }); - } - - #[test] - fn vmp_apply_add() { - let log_n: i32 = 4; - let n: usize = 1 << log_n; - - let module: Module = Module::::new(n); - let basek: usize = 8; - let a_size: usize = 5; - let mat_size: usize = 5; - let res_size: usize = a_size; - let mut source: Source = Source::new([0u8; 32]); - - [1, 2].iter().for_each(|cols_in| { - [1, 2].iter().for_each(|cols_out| { - (0..res_size).for_each(|shift| { - let a_cols: usize = *cols_in; - let res_cols: usize = *cols_out; - - let mat_rows: usize = a_size; - let mat_cols_in: usize = a_cols; - let mat_cols_out: usize = res_cols; - - let mut scratch: ScratchOwned = ScratchOwned::new( - 
module.vmp_apply_tmp_bytes( - res_size, - a_size, - mat_rows, - mat_cols_in, - mat_cols_out, - mat_size, - ) | module.vec_znx_big_normalize_tmp_bytes(), - ); - - let mut a: VecZnx> = module.new_vec_znx(a_cols, a_size); - - (0..a_cols).for_each(|col_i| { - a.fill_uniform(basek, col_i, a.size(), &mut source); - }); - - let mut mat_znx_dft: MatZnxDft, FFT64> = - module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); - - let mut c_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); - let mut c_big: VecZnxBig, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size); - - let mut tmp: VecZnx> = module.new_vec_znx(mat_cols_out, mat_size); - - // Construts a [VecZnxMatDft] that performs cyclic rotations on each submatrix. - (0..a.size()).for_each(|row_i| { - (0..mat_cols_in).for_each(|col_in_i| { - (0..mat_cols_out).for_each(|col_out_i| { - let idx: usize = 1 + col_in_i * mat_cols_out + col_out_i; - tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx} - module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i); - tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64; - }); - module.mat_znx_dft_set_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); - }); - }); - - let mut a_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(a_cols, a_size); - (0..a_cols).for_each(|i| { - module.vec_znx_dft(1, 0, &mut a_dft, i, &a, i); - }); - - c_dft.zero(); - (0..c_dft.cols()).for_each(|i| { - module.vec_znx_dft(1, 0, &mut c_dft, i, &a, 0); - }); - - module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, shift, scratch.borrow()); - - let mut res_have: VecZnx> = module.new_vec_znx(res_cols, mat_size); - (0..mat_cols_out).for_each(|i| { - module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i); - module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow()); - }); - - let mut res_want: VecZnx> = module.new_vec_znx(res_cols, mat_size); - - // Equivalent to vmp_add & scale - module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow()); - (0..mat_cols_out).for_each(|i| { - module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i); - module.vec_znx_big_normalize(basek, &mut res_want, i, &c_big, i, scratch.borrow()); - }); - module.vec_znx_shift_inplace( - basek, - (shift * basek) as i64, - &mut res_want, - scratch.borrow(), - ); - (0..res_cols).for_each(|i| { - module.vec_znx_add_inplace(&mut res_want, i, &a, 0); - module.vec_znx_normalize_inplace(basek, &mut res_want, i, scratch.borrow()); - }); - - assert_eq!(res_want, res_have); - }); - }); - }); - } - - #[test] - fn vmp_apply_digits() { - let log_n: i32 = 4; - let n: usize = 1 << log_n; - - let module: Module = Module::::new(n); - let basek: usize = 8; - let a_size: usize = 6; - let mat_size: usize = 6; - let res_size: usize = a_size; - - [1, 2].iter().for_each(|cols_in| { - [1, 2].iter().for_each(|cols_out| { - [1, 3, 6].iter().for_each(|digits| { - let mut source: Source = Source::new([0u8; 32]); - - let a_cols: usize = *cols_in; - let res_cols: usize = *cols_out; - - let mat_rows: usize = a_size; - let mat_cols_in: usize = a_cols; - let mat_cols_out: usize = res_cols; - - let mut scratch: ScratchOwned = ScratchOwned::new( - module.vmp_apply_tmp_bytes( - res_size, - a_size, - mat_rows, - mat_cols_in, - mat_cols_out, - mat_size, - ) | module.vec_znx_big_normalize_tmp_bytes(), - ); - - let mut a: VecZnx> = module.new_vec_znx(a_cols, a_size); - - (0..a_cols).for_each(|col_i| { - a.fill_uniform(basek, col_i, a.size(), &mut source); - }); - - let mut mat_znx_dft: MatZnxDft, FFT64> = - 
module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); - - let mut c_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); - let mut c_big: VecZnxBig, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size); - - let mut tmp: VecZnx> = module.new_vec_znx(mat_cols_out, mat_size); - - let rows: usize = a.size() / digits; - - let shift: usize = 1; - - // Construts a [VecZnxMatDft] that performs cyclic rotations on each submatrix. - (0..rows).for_each(|row_i| { - (0..mat_cols_in).for_each(|col_in_i| { - (0..mat_cols_out).for_each(|col_out_i| { - let idx: usize = shift + col_in_i * mat_cols_out + col_out_i; - let limb: usize = (digits - 1) + row_i * digits; - tmp.at_mut(col_out_i, limb)[idx] = 1 as i64; // X^{idx} - module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i); - tmp.at_mut(col_out_i, limb)[idx] = 0 as i64; - }); - module.mat_znx_dft_set_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); - }); - }); - - let mut a_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(a_cols, (a_size + digits - 1) / digits); - - (0..*digits).for_each(|di| { - (0..a_cols).for_each(|col_i| { - module.vec_znx_dft(*digits, digits - 1 - di, &mut a_dft, col_i, &a, col_i); - }); - - if di == 0 { - module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow()); - } else { - module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, di, scratch.borrow()); - } - }); - - let mut res_have: VecZnx> = module.new_vec_znx(res_cols, mat_size); - (0..mat_cols_out).for_each(|i| { - module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i); - module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow()); - }); - - let mut res_want: VecZnx> = module.new_vec_znx(res_cols, mat_size); - let mut tmp: VecZnx> = module.new_vec_znx(res_cols, mat_size); - (0..res_cols).for_each(|col_i| { - (0..a_cols).for_each(|j| { - module.vec_znx_rotate( - (col_i + j * mat_cols_out + shift) as i64, - &mut tmp, - 0, - &a, - j, - ); - module.vec_znx_add_inplace(&mut res_want, col_i, &tmp, 0); - }); - module.vec_znx_normalize_inplace(basek, &mut res_want, col_i, scratch.borrow()); - }); - - assert_eq!(res_have, res_want) - }); - }); - }); - } - - #[test] - fn mat_znx_dft_mul_x_pow_minus_one() { - let log_n: i32 = 5; - let n: usize = 1 << log_n; - - let module: Module = Module::::new(n); - let basek: usize = 8; - let rows: usize = 2; - let cols_in: usize = 2; - let cols_out: usize = 2; - let size: usize = 4; - - let mut scratch: ScratchOwned = ScratchOwned::new(module.mat_znx_dft_mul_x_pow_minus_one_scratch_space(size, cols_out)); - - let mut mat_want: MatZnxDft, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size); - let mut mat_have: MatZnxDft, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size); - - let mut tmp: VecZnx> = module.new_vec_znx(1, size); - let mut tmp_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(cols_out, size); - - let mut source: Source = Source::new([0u8; 32]); - - (0..mat_want.rows()).for_each(|row_i| { - (0..mat_want.cols_in()).for_each(|col_i| { - (0..cols_out).for_each(|j| { - tmp.fill_uniform(basek, 0, size, &mut source); - module.vec_znx_dft(1, 0, &mut tmp_dft, j, &tmp, 0); - }); - - module.mat_znx_dft_set_row(&mut mat_want, row_i, col_i, &tmp_dft); - }); - }); - - let k: i64 = 1; - - module.mat_znx_dft_mul_x_pow_minus_one(k, &mut mat_have, &mat_want, scratch.borrow()); - - let mut have: VecZnx> = module.new_vec_znx(cols_out, size); - let mut want: VecZnx> = module.new_vec_znx(cols_out, size); - let mut tmp_big: VecZnxBig, FFT64> = 
module.new_vec_znx_big(1, size); - - (0..mat_want.rows()).for_each(|row_i| { - (0..mat_want.cols_in()).for_each(|col_i| { - module.mat_znx_dft_get_row(&mut tmp_dft, &mat_want, row_i, col_i); - - (0..cols_out).for_each(|j| { - module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow()); - module.vec_znx_big_normalize(basek, &mut tmp, 0, &tmp_big, 0, scratch.borrow()); - module.vec_znx_rotate(k, &mut want, j, &tmp, 0); - module.vec_znx_sub_ab_inplace(&mut want, j, &tmp, 0); - module.vec_znx_normalize_inplace(basek, &mut want, j, scratch.borrow()); - }); - - module.mat_znx_dft_get_row(&mut tmp_dft, &mat_have, row_i, col_i); - - (0..cols_out).for_each(|j| { - module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow()); - module.vec_znx_big_normalize(basek, &mut have, j, &tmp_big, 0, scratch.borrow()); - }); - - assert_eq!(have, want) - }); - }); - } - - #[test] - fn mat_znx_dft_mul_x_pow_minus_one_add_inplace() { - let log_n: i32 = 5; - let n: usize = 1 << log_n; - - let module: Module = Module::::new(n); - let basek: usize = 8; - let rows: usize = 2; - let cols_in: usize = 2; - let cols_out: usize = 2; - let size: usize = 4; - - let mut scratch: ScratchOwned = ScratchOwned::new(module.mat_znx_dft_mul_x_pow_minus_one_scratch_space(size, cols_out)); - - let mut mat_want: MatZnxDft, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size); - let mut mat_have: MatZnxDft, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size); - - let mut tmp: VecZnx> = module.new_vec_znx(1, size); - let mut tmp_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(cols_out, size); - - let mut source: Source = Source::new([0u8; 32]); - - (0..mat_have.rows()).for_each(|row_i| { - (0..mat_have.cols_in()).for_each(|col_i| { - (0..cols_out).for_each(|j| { - tmp.fill_uniform(basek, 0, size, &mut source); - module.vec_znx_dft(1, 0, &mut tmp_dft, j, &tmp, 0); - }); - - module.mat_znx_dft_set_row(&mut mat_have, row_i, col_i, &tmp_dft); - }); - }); - - (0..mat_want.rows()).for_each(|row_i| { - (0..mat_want.cols_in()).for_each(|col_i| { - (0..cols_out).for_each(|j| { - tmp.fill_uniform(basek, 0, size, &mut source); - module.vec_znx_dft(1, 0, &mut tmp_dft, j, &tmp, 0); - }); - - module.mat_znx_dft_set_row(&mut mat_want, row_i, col_i, &tmp_dft); - }); - }); - - let k: i64 = 1; - - module.mat_znx_dft_mul_x_pow_minus_one_add_inplace(k, &mut mat_have, &mat_want, scratch.borrow()); - - let mut have: VecZnx> = module.new_vec_znx(cols_out, size); - let mut want: VecZnx> = module.new_vec_znx(cols_out, size); - let mut tmp_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, size); - - let mut source: Source = Source::new([0u8; 32]); - (0..mat_want.rows()).for_each(|row_i| { - (0..mat_want.cols_in()).for_each(|col_i| { - module.mat_znx_dft_get_row(&mut tmp_dft, &mat_want, row_i, col_i); - - (0..cols_out).for_each(|j| { - module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow()); - module.vec_znx_big_normalize(basek, &mut tmp, 0, &tmp_big, 0, scratch.borrow()); - module.vec_znx_rotate(k, &mut want, j, &tmp, 0); - module.vec_znx_sub_ab_inplace(&mut want, j, &tmp, 0); - - tmp.fill_uniform(basek, 0, size, &mut source); - module.vec_znx_add_inplace(&mut want, j, &tmp, 0); - module.vec_znx_normalize_inplace(basek, &mut want, j, scratch.borrow()); - }); - - module.mat_znx_dft_get_row(&mut tmp_dft, &mat_have, row_i, col_i); - - (0..cols_out).for_each(|j| { - module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow()); - module.vec_znx_big_normalize(basek, &mut have, j, &tmp_big, 0, scratch.borrow()); - 
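The check being assembled here is want = X^k * row - row, built limb by limb with rotate-then-subtract. A standalone model of multiplication by (X^k - 1) in Z[X]/(X^n + 1), under the same negacyclic sign convention:

// Multiply a polynomial by (X^k - 1) in Z[X]/(X^n + 1):
// negacyclic rotation by k, then subtraction of the original.
fn mul_x_pow_minus_one(a: &[i64], k: usize) -> Vec<i64> {
    let n = a.len();
    let mut res = vec![0i64; n];
    for (i, &c) in a.iter().enumerate() {
        let j = (i + k) % (2 * n);
        if j < n {
            res[j] += c;
        } else {
            res[j - n] -= c; // wrapping past X^{n-1} flips the sign
        }
        res[i] -= c; // the "- 1" term
    }
    res
}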
}); - - assert_eq!(have, want) - }); - }); - } -} diff --git a/backend/src/sampling.rs b/backend/src/sampling.rs deleted file mode 100644 index 071adcf..0000000 --- a/backend/src/sampling.rs +++ /dev/null @@ -1,365 +0,0 @@ -use crate::znx_base::ZnxViewMut; -use crate::{FFT64, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxToMut}; -use rand_distr::{Distribution, Normal}; -use sampling::source::Source; - -pub trait FillUniform { - /// Fills the first `size` size with uniform values in \[-2^{basek-1}, 2^{basek-1}\] - fn fill_uniform(&mut self, basek: usize, col_i: usize, size: usize, source: &mut Source); -} - -pub trait FillDistF64 { - fn fill_dist_f64>( - &mut self, - basek: usize, - col_i: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ); -} - -pub trait AddDistF64 { - /// Adds vector sampled according to the provided distribution, scaled by 2^{-k} and bounded to \[-bound, bound\]. - fn add_dist_f64>( - &mut self, - basek: usize, - col_i: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ); -} - -pub trait FillNormal { - fn fill_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64); -} - -pub trait AddNormal { - /// Adds a discrete normal vector scaled by 2^{-k} with the provided standard deviation and bounded to \[-bound, bound\]. - fn add_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64); -} - -impl + AsRef<[u8]>> FillUniform for VecZnx -where - VecZnx: VecZnxToMut, -{ - fn fill_uniform(&mut self, basek: usize, col_i: usize, size: usize, source: &mut Source) { - let mut a: VecZnx<&mut [u8]> = self.to_mut(); - let base2k: u64 = 1 << basek; - let mask: u64 = base2k - 1; - let base2k_half: i64 = (base2k >> 1) as i64; - (0..size).for_each(|j| { - a.at_mut(col_i, j) - .iter_mut() - .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); - }) - } -} - -impl + AsRef<[u8]>> FillDistF64 for VecZnx -where - VecZnx: VecZnxToMut, -{ - fn fill_dist_f64>( - &mut self, - basek: usize, - col_i: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) { - let mut a: VecZnx<&mut [u8]> = self.to_mut(); - assert!( - (bound.log2().ceil() as i64) < 64, - "invalid bound: ceil(log2(bound))={} > 63", - (bound.log2().ceil() as i64) - ); - - let limb: usize = (k + basek - 1) / basek - 1; - let basek_rem: usize = (limb + 1) * basek - k; - - if basek_rem != 0 { - a.at_mut(col_i, limb).iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a = (dist_f64.round() as i64) << basek_rem; - }); - } else { - a.at_mut(col_i, limb).iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a = dist_f64.round() as i64 - }); - } - } -} - -impl + AsRef<[u8]>> AddDistF64 for VecZnx -where - VecZnx: VecZnxToMut, -{ - fn add_dist_f64>( - &mut self, - basek: usize, - col_i: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) { - let mut a: VecZnx<&mut [u8]> = self.to_mut(); - assert!( - (bound.log2().ceil() as i64) < 64, - "invalid bound: ceil(log2(bound))={} > 63", - (bound.log2().ceil() as i64) - ); - - let limb: usize = (k + basek - 1) / basek - 1; - let basek_rem: usize = (limb + 1) * basek - k; - - if basek_rem != 0 { - a.at_mut(col_i, limb).iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = 
dist.sample(source) - } - *a += (dist_f64.round() as i64) << basek_rem; - }); - } else { - a.at_mut(col_i, limb).iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a += dist_f64.round() as i64 - }); - } - } -} - -impl + AsRef<[u8]>> FillNormal for VecZnx -where - VecZnx: VecZnxToMut, -{ - fn fill_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64) { - self.fill_dist_f64( - basek, - col_i, - k, - source, - Normal::new(0.0, sigma).unwrap(), - bound, - ); - } -} - -impl + AsRef<[u8]>> AddNormal for VecZnx -where - VecZnx: VecZnxToMut, -{ - fn add_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64) { - self.add_dist_f64( - basek, - col_i, - k, - source, - Normal::new(0.0, sigma).unwrap(), - bound, - ); - } -} - -impl + AsRef<[u8]>> FillDistF64 for VecZnxBig -where - VecZnxBig: VecZnxBigToMut, -{ - fn fill_dist_f64>( - &mut self, - basek: usize, - col_i: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) { - let mut a: VecZnxBig<&mut [u8], FFT64> = self.to_mut(); - assert!( - (bound.log2().ceil() as i64) < 64, - "invalid bound: ceil(log2(bound))={} > 63", - (bound.log2().ceil() as i64) - ); - - let limb: usize = (k + basek - 1) / basek - 1; - let basek_rem: usize = (limb + 1) * basek - k; - - if basek_rem != 0 { - a.at_mut(col_i, limb).iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a = (dist_f64.round() as i64) << basek_rem; - }); - } else { - a.at_mut(col_i, limb).iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a = dist_f64.round() as i64 - }); - } - } -} - -impl + AsRef<[u8]>> AddDistF64 for VecZnxBig -where - VecZnxBig: VecZnxBigToMut, -{ - fn add_dist_f64>( - &mut self, - basek: usize, - col_i: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) { - let mut a: VecZnxBig<&mut [u8], FFT64> = self.to_mut(); - assert!( - (bound.log2().ceil() as i64) < 64, - "invalid bound: ceil(log2(bound))={} > 63", - (bound.log2().ceil() as i64) - ); - - let limb: usize = (k + basek - 1) / basek - 1; - let basek_rem: usize = (limb + 1) * basek - k; - - if basek_rem != 0 { - a.at_mut(col_i, limb).iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a += (dist_f64.round() as i64) << basek_rem; - }); - } else { - a.at_mut(col_i, limb).iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a += dist_f64.round() as i64 - }); - } - } -} - -impl + AsRef<[u8]>> FillNormal for VecZnxBig -where - VecZnxBig: VecZnxBigToMut, -{ - fn fill_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64) { - self.fill_dist_f64( - basek, - col_i, - k, - source, - Normal::new(0.0, sigma).unwrap(), - bound, - ); - } -} - -impl + AsRef<[u8]>> AddNormal for VecZnxBig -where - VecZnxBig: VecZnxBigToMut, -{ - fn add_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64) { - self.add_dist_f64( - basek, - col_i, - k, - source, - Normal::new(0.0, sigma).unwrap(), - bound, - ); - } -} - -#[cfg(test)] -mod tests { - use super::{AddNormal, FillUniform}; - 
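The samplers deleted above all share one pattern: draw from the distribution, reject while |x| > bound, round, then shift the value into the limb addressed by precision k. A minimal standalone sketch of that pattern with rand_distr (the limb and shift arithmetic mirrors the code above; the function name and the values in main are illustrative):

use rand::rngs::StdRng;
use rand::SeedableRng;
use rand_distr::{Distribution, Normal};

/// Sample a normal value, rejecting anything with |x| > bound, then
/// scale it into the limb that holds the 2^{-k} position.
fn sample_bounded(sigma: f64, bound: f64, basek: usize, k: usize, rng: &mut StdRng) -> (usize, i64) {
    let dist = Normal::new(0.0, sigma).unwrap();
    let mut x = dist.sample(rng);
    while x.abs() > bound {
        x = dist.sample(rng); // rejection step, as in the deleted impls
    }
    let limb = (k + basek - 1) / basek - 1; // limb containing bit k
    let rem = (limb + 1) * basek - k;       // left-shift when k is not limb-aligned
    (limb, (x.round() as i64) << rem)
}

fn main() {
    let mut rng = StdRng::seed_from_u64(0);
    let (limb, val) = sample_bounded(3.2, 19.2, 17, 34, &mut rng);
    println!("limb {limb}, value {val}");
}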
use crate::vec_znx_ops::*; - use crate::znx_base::*; - use crate::{FFT64, Module, Stats, VecZnx}; - use sampling::source::Source; - - #[test] - fn vec_znx_fill_uniform() { - let n: usize = 4096; - let module: Module = Module::::new(n); - let basek: usize = 17; - let size: usize = 5; - let mut source: Source = Source::new([0u8; 32]); - let cols: usize = 2; - let zero: Vec = vec![0; n]; - let one_12_sqrt: f64 = 0.28867513459481287; - (0..cols).for_each(|col_i| { - let mut a: VecZnx<_> = module.new_vec_znx(cols, size); - a.fill_uniform(basek, col_i, size, &mut source); - (0..cols).for_each(|col_j| { - if col_j != col_i { - (0..size).for_each(|limb_i| { - assert_eq!(a.at(col_j, limb_i), zero); - }) - } else { - let std: f64 = a.std(col_i, basek); - assert!( - (std - one_12_sqrt).abs() < 0.01, - "std={} ~!= {}", - std, - one_12_sqrt - ); - } - }) - }); - } - - #[test] - fn vec_znx_add_normal() { - let n: usize = 4096; - let module: Module = Module::::new(n); - let basek: usize = 17; - let k: usize = 2 * 17; - let size: usize = 5; - let sigma: f64 = 3.2; - let bound: f64 = 6.0 * sigma; - let mut source: Source = Source::new([0u8; 32]); - let cols: usize = 2; - let zero: Vec = vec![0; n]; - let k_f64: f64 = (1u64 << k as u64) as f64; - (0..cols).for_each(|col_i| { - let mut a: VecZnx<_> = module.new_vec_znx(cols, size); - a.add_normal(basek, col_i, k, &mut source, sigma, bound); - (0..cols).for_each(|col_j| { - if col_j != col_i { - (0..size).for_each(|limb_i| { - assert_eq!(a.at(col_j, limb_i), zero); - }) - } else { - let std: f64 = a.std(col_i, basek) * k_f64; - assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma); - } - }) - }); - } -} diff --git a/backend/src/scalar_znx_dft.rs b/backend/src/scalar_znx_dft.rs deleted file mode 100644 index d2ecb4f..0000000 --- a/backend/src/scalar_znx_dft.rs +++ /dev/null @@ -1,180 +0,0 @@ -use std::marker::PhantomData; - -use crate::ffi::svp; -use crate::znx_base::ZnxInfos; -use crate::{ - Backend, DataView, DataViewMut, FFT64, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxSliceSize, ZnxView, - alloc_aligned, -}; - -pub struct ScalarZnxDft { - data: D, - n: usize, - cols: usize, - _phantom: PhantomData, -} - -impl ZnxInfos for ScalarZnxDft { - fn cols(&self) -> usize { - self.cols - } - - fn rows(&self) -> usize { - 1 - } - - fn n(&self) -> usize { - self.n - } - - fn size(&self) -> usize { - 1 - } -} - -impl ZnxSliceSize for ScalarZnxDft { - fn sl(&self) -> usize { - self.n() - } -} - -impl DataView for ScalarZnxDft { - type D = D; - fn data(&self) -> &Self::D { - &self.data - } -} - -impl DataViewMut for ScalarZnxDft { - fn data_mut(&mut self) -> &mut Self::D { - &mut self.data - } -} - -impl> ZnxView for ScalarZnxDft { - type Scalar = f64; -} - -pub(crate) fn bytes_of_scalar_znx_dft(module: &Module, cols: usize) -> usize { - ScalarZnxDftOwned::bytes_of(module, cols) -} - -impl>, B: Backend> ScalarZnxDft { - pub(crate) fn bytes_of(module: &Module, cols: usize) -> usize { - unsafe { svp::bytes_of_svp_ppol(module.ptr) as usize * cols } - } - - pub(crate) fn new(module: &Module, cols: usize) -> Self { - let data = alloc_aligned::(Self::bytes_of(module, cols)); - Self { - data: data.into(), - n: module.n(), - cols, - _phantom: PhantomData, - } - } - - pub(crate) fn new_from_bytes(module: &Module, cols: usize, bytes: impl Into>) -> Self { - let data: Vec = bytes.into(); - assert!(data.len() == Self::bytes_of(module, cols)); - Self { - data: data.into(), - n: module.n(), - cols, - _phantom: PhantomData, - } - } -} - -impl ScalarZnxDft { - 
pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self { - Self { - data, - n, - cols, - _phantom: PhantomData, - } - } - - pub fn as_vec_znx_dft(self) -> VecZnxDft { - VecZnxDft { - data: self.data, - n: self.n, - cols: self.cols, - size: 1, - _phantom: PhantomData, - } - } -} - -pub type ScalarZnxDftOwned = ScalarZnxDft, B>; - -pub trait ScalarZnxDftToRef { - fn to_ref(&self) -> ScalarZnxDft<&[u8], B>; -} - -impl ScalarZnxDftToRef for ScalarZnxDft -where - D: AsRef<[u8]>, - B: Backend, -{ - fn to_ref(&self) -> ScalarZnxDft<&[u8], B> { - ScalarZnxDft { - data: self.data.as_ref(), - n: self.n, - cols: self.cols, - _phantom: PhantomData, - } - } -} - -pub trait ScalarZnxDftToMut { - fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B>; -} - -impl ScalarZnxDftToMut for ScalarZnxDft -where - D: AsMut<[u8]> + AsRef<[u8]>, - B: Backend, -{ - fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> { - ScalarZnxDft { - data: self.data.as_mut(), - n: self.n, - cols: self.cols, - _phantom: PhantomData, - } - } -} - -impl VecZnxDftToRef for ScalarZnxDft -where - D: AsRef<[u8]>, - B: Backend, -{ - fn to_ref(&self) -> VecZnxDft<&[u8], B> { - VecZnxDft { - data: self.data.as_ref(), - n: self.n, - cols: self.cols, - size: 1, - _phantom: std::marker::PhantomData, - } - } -} - -impl VecZnxDftToMut for ScalarZnxDft -where - D: AsRef<[u8]> + AsMut<[u8]>, - B: Backend, -{ - fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { - VecZnxDft { - data: self.data.as_mut(), - n: self.n, - cols: self.cols, - size: 1, - _phantom: std::marker::PhantomData, - } - } -} diff --git a/backend/src/scalar_znx_dft_ops.rs b/backend/src/scalar_znx_dft_ops.rs deleted file mode 100644 index c89808d..0000000 --- a/backend/src/scalar_znx_dft_ops.rs +++ /dev/null @@ -1,122 +0,0 @@ -use crate::ffi::svp; -use crate::ffi::vec_znx_dft::vec_znx_dft_t; -use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; -use crate::{ - Backend, FFT64, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToMut, - ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, -}; - -pub trait ScalarZnxDftAlloc { - fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned; - fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize; - fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDftOwned; -} - -pub trait ScalarZnxDftOps { - fn svp_prepare(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: ScalarZnxDftToMut, - A: ScalarZnxToRef; - - fn svp_apply(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) - where - R: VecZnxDftToMut, - A: ScalarZnxDftToRef, - B: VecZnxDftToRef; - - fn svp_apply_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxDftToMut, - A: ScalarZnxDftToRef; - - fn scalar_znx_idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) - where - R: ScalarZnxToMut, - A: ScalarZnxDftToRef; -} - -impl ScalarZnxDftAlloc for Module { - fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned { - ScalarZnxDftOwned::new(self, cols) - } - - fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize { - ScalarZnxDftOwned::bytes_of(self, cols) - } - - fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDftOwned { - ScalarZnxDftOwned::new_from_bytes(self, cols, bytes) - } -} - -impl ScalarZnxDftOps for Module { - fn scalar_znx_idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) - where - 
R: ScalarZnxToMut, - A: ScalarZnxDftToRef, - { - let res_mut: &mut ScalarZnx<&mut [u8]> = &mut res.to_mut(); - let a_ref: &ScalarZnxDft<&[u8], FFT64> = &a.to_ref(); - let (mut vec_znx_big, scratch1) = scratch.tmp_vec_znx_big(self, 1, 1); - self.vec_znx_idft(&mut vec_znx_big, 0, a_ref, a_col, scratch1); - self.vec_znx_copy(res_mut, res_col, &vec_znx_big.to_vec_znx_small(), 0); - } - - fn svp_prepare(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: ScalarZnxDftToMut, - A: ScalarZnxToRef, - { - unsafe { - svp::svp_prepare( - self.ptr, - res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t, - a.to_ref().at_ptr(a_col, 0), - ) - } - } - - fn svp_apply(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) - where - R: VecZnxDftToMut, - A: ScalarZnxDftToRef, - B: VecZnxDftToRef, - { - let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref(); - let b: VecZnxDft<&[u8], FFT64> = b.to_ref(); - unsafe { - svp::svp_apply_dft_to_dft( - self.ptr, - res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, - res.size() as u64, - res.cols() as u64, - a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, - b.at_ptr(b_col, 0) as *const vec_znx_dft_t, - b.size() as u64, - b.cols() as u64, - ) - } - } - - fn svp_apply_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxDftToMut, - A: ScalarZnxDftToRef, - { - let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref(); - unsafe { - svp::svp_apply_dft_to_dft( - self.ptr, - res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, - res.size() as u64, - res.cols() as u64, - a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, - res.at_ptr(res_col, 0) as *const vec_znx_dft_t, - res.size() as u64, - res.cols() as u64, - ) - } - } -} diff --git a/backend/src/stats.rs b/backend/src/stats.rs deleted file mode 100644 index f5aa26a..0000000 --- a/backend/src/stats.rs +++ /dev/null @@ -1,32 +0,0 @@ -use crate::znx_base::ZnxInfos; -use crate::{Decoding, VecZnx}; -use rug::Float; -use rug::float::Round; -use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound}; - -pub trait Stats { - /// Returns the standard deviation of the i-th polynomial.
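The std method that follows computes the population standard deviation at arbitrary precision via rug; for intuition, the same formula in plain f64 (a lossy sketch over already-decoded coefficients, not a substitute for the rug version):

/// Population standard deviation: sqrt(sum((x - mean)^2) / n).
fn std_f64(data: &[f64]) -> f64 {
    let n = data.len() as f64;
    let mean = data.iter().sum::<f64>() / n;
    let var = data.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / n;
    var.sqrt()
}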
- fn std(&self, col_i: usize, basek: usize) -> f64; -} - -impl> Stats for VecZnx { - fn std(&self, col_i: usize, basek: usize) -> f64 { - let prec: u32 = (self.size() * basek) as u32; - let mut data: Vec = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect(); - self.decode_vec_float(col_i, basek, &mut data); - // std = sqrt(sum((xi - avg)^2) / n) - let mut avg: Float = Float::with_val(prec, 0); - data.iter().for_each(|x| { - avg.add_assign_round(x, Round::Nearest); - }); - avg.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest); - data.iter_mut().for_each(|x| { - x.sub_assign_round(&avg, Round::Nearest); - }); - let mut std: Float = Float::with_val(prec, 0); - data.iter().for_each(|x| std += x * x); - std.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest); - std = std.sqrt(); - std.to_f64() - } -} diff --git a/backend/src/vec_znx.rs b/backend/src/vec_znx.rs deleted file mode 100644 index 74f9f86..0000000 --- a/backend/src/vec_znx.rs +++ /dev/null @@ -1,413 +0,0 @@ -use itertools::izip; - -use crate::DataView; -use crate::DataViewMut; -use crate::ScalarZnx; -use crate::Scratch; -use crate::ZnxSliceSize; -use crate::ZnxZero; -use crate::alloc_aligned; -use crate::assert_alignement; -use crate::cast_mut; -use crate::ffi::znx; -use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; -use std::{cmp::min, fmt}; - -/// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of -/// Zn\[X\] with [i64] coefficients. -/// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single contiguous array -/// in the memory. -/// -/// # Example -/// -/// Given 3 polynomials (a, b, c) of Zn\[X\], each with 4 columns, then the memory -/// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci -/// are small polynomials of Zn\[X\]. -#[derive(PartialEq, Eq)] -pub struct VecZnx { - pub data: D, - pub n: usize, - pub cols: usize, - pub size: usize, -} - -impl fmt::Debug for VecZnx -where - D: AsRef<[u8]>, -{ - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) - } -} - -impl ZnxInfos for VecZnx { - fn cols(&self) -> usize { - self.cols - } - - fn rows(&self) -> usize { - 1 - } - - fn n(&self) -> usize { - self.n - } - - fn size(&self) -> usize { - self.size - } -} - -impl ZnxSliceSize for VecZnx { - fn sl(&self) -> usize { - self.n() * self.cols() - } -} - -impl DataView for VecZnx { - type D = D; - fn data(&self) -> &Self::D { - &self.data - } -} - -impl DataViewMut for VecZnx { - fn data_mut(&mut self) -> &mut Self::D { - &mut self.data - } -} - -impl> ZnxView for VecZnx { - type Scalar = i64; -} - -impl VecZnx> { - pub fn rsh_scratch_space(n: usize) -> usize { - n * std::mem::size_of::() - } -} - -impl + AsRef<[u8]>> VecZnx { - /// Truncates the precision of the [VecZnx] by k bits. - /// - /// # Arguments - /// - /// * `basek`: the base two logarithm of the coefficients decomposition. - /// * `k`: the number of bits of precision to drop. 
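In other words, dropping k bits removes k/basek whole limbs and masks the k % basek low bits of the new last limb. A worked example of that split with illustrative numbers (this sketch keeps every bit at or above k_rem; the implementation below applies a slightly narrower mask):

fn main() {
    let basek = 8usize;
    let k = 20usize;                // drop 20 bits of precision
    let whole_limbs = k / basek;    // 2 whole limbs removed
    let k_rem = k % basek;          // 4 bits masked in the new last limb
    assert_eq!((whole_limbs, k_rem), (2, 4));
    // Coefficients of the last remaining limb keep only bits above k_rem:
    let x: i64 = 0b1011_0110;
    let kept = x & !((1 << k_rem) - 1);
    assert_eq!(kept, 0b1011_0000);
}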
- pub fn trunc_pow2(&mut self, basek: usize, k: usize, col: usize) { - if k == 0 { - return; - } - - self.size -= k / basek; - - let k_rem: usize = k % basek; - - if k_rem != 0 { - let mask: i64 = ((1 << (basek - k_rem - 1)) - 1) << k_rem; - self.at_mut(col, self.size() - 1) - .iter_mut() - .for_each(|x: &mut i64| *x &= mask) - } - } - - pub fn rotate(&mut self, k: i64) { - unsafe { - (0..self.cols()).for_each(|i| { - (0..self.size()).for_each(|j| { - znx::znx_rotate_inplace_i64(self.n() as u64, k, self.at_mut_ptr(i, j)); - }); - }) - } - } - - pub fn rsh(&mut self, basek: usize, k: usize, scratch: &mut Scratch) { - let n: usize = self.n(); - let cols: usize = self.cols(); - let size: usize = self.size(); - let steps: usize = k / basek; - - self.raw_mut().rotate_right(n * steps * cols); - (0..cols).for_each(|i| { - (0..steps).for_each(|j| { - self.zero_at(i, j); - }) - }); - - let k_rem: usize = k % basek; - - if k_rem != 0 { - let (carry, _) = scratch.tmp_slice::(n); - let shift = i64::BITS as usize - k_rem; - (0..cols).for_each(|i| { - carry.fill(0); - (steps..size).for_each(|j| { - izip!(carry.iter_mut(), self.at_mut(i, j).iter_mut()).for_each(|(ci, xi)| { - *xi += *ci << basek; - *ci = (*xi << shift) >> shift; - *xi = (*xi - *ci) >> k_rem; - }); - }); - }) - } - } - - pub fn lsh(&mut self, basek: usize, k: usize, scratch: &mut Scratch) { - let n: usize = self.n(); - let cols: usize = self.cols(); - let size: usize = self.size(); - let steps: usize = k / basek; - - self.raw_mut().rotate_left(n * steps * cols); - (0..cols).for_each(|i| { - (size - steps..size).for_each(|j| { - self.zero_at(i, j); - }) - }); - - let k_rem: usize = k % basek; - - if k_rem != 0 { - let shift: usize = i64::BITS as usize - k_rem; - let (tmp_bytes, _) = scratch.tmp_slice::(n * size_of::()); - (0..cols).for_each(|i| { - (0..steps).for_each(|j| { - self.at_mut(i, j).iter_mut().for_each(|xi| { - *xi <<= shift; - }); - }); - normalize(basek, self, i, tmp_bytes); - }); - } - } -} - -impl>> VecZnx { - pub(crate) fn bytes_of(n: usize, cols: usize, size: usize) -> usize { - n * cols * size * size_of::() - } - - pub fn new(n: usize, cols: usize, size: usize) -> Self { - let data = alloc_aligned::(Self::bytes_of::(n, cols, size)); - Self { - data: data.into(), - n, - cols, - size, - } - } - - pub(crate) fn new_from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { - let data: Vec = bytes.into(); - assert!(data.len() == Self::bytes_of::(n, cols, size)); - Self { - data: data.into(), - n, - cols, - size, - } - } -} - -impl VecZnx { - pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { - Self { - data, - n, - cols, - size, - } - } - - pub fn to_scalar_znx(self) -> ScalarZnx { - debug_assert_eq!( - self.size, 1, - "cannot convert VecZnx to ScalarZnx if cols: {} != 1", - self.cols - ); - ScalarZnx { - data: self.data, - n: self.n, - cols: self.cols, - } - } -} - -/// Copies the coefficients of `a` on the receiver. -/// Copy is done with the minimum size matching both backing arrays. -/// Panics if the cols do not match. 
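The rsh routine above splits a right shift by k into a whole-limb rotation (k / basek limbs) followed by a bitwise pass that pushes the dropped k % basek bits into the next limb as a signed carry. A scalar model of that carry pass, checked on a two-limb value:

/// Right-shift a base-2^basek limb vector (most-significant limb first)
/// by k_rem < basek bits, carrying the dropped low bits downward.
fn rsh_bits(limbs: &mut [i64], basek: usize, k_rem: usize) {
    let shift = 64 - k_rem; // sign-extending extraction of the low k_rem bits
    let mut carry = 0i64;
    for x in limbs.iter_mut() {
        *x += carry << basek;
        carry = (*x << shift) >> shift;
        *x = (*x - carry) >> k_rem;
    }
}

fn main() {
    // value = 3 * 2^{-8} + 5 * 2^{-16} = 773 * 2^{-16}, shifted right by 2 bits
    let mut limbs = vec![3i64, 5];
    rsh_bits(&mut limbs, 8, 2);
    // 773 >> 2 = 193, stored in balanced form as 1 * 2^{-8} - 63 * 2^{-16}
    assert_eq!(limbs, vec![1, -63]);
}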
-pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) -where - DataMut: AsMut<[u8]> + AsRef<[u8]>, - Data: AsRef<[u8]>, -{ - assert_eq!(b.cols(), a.cols()); - let data_a: &[i64] = a.raw(); - let data_b: &mut [i64] = b.raw_mut(); - let size = min(data_b.len(), data_a.len()); - data_b[..size].copy_from_slice(&data_a[..size]) -} - -#[allow(dead_code)] -fn normalize_tmp_bytes(n: usize) -> usize { - n * std::mem::size_of::() -} - -impl + AsMut<[u8]>> VecZnx { - pub fn normalize(&mut self, basek: usize, a_col: usize, tmp_bytes: &mut [u8]) { - normalize(basek, self, a_col, tmp_bytes); - } -} - -fn normalize + AsRef<[u8]>>(basek: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { - let n: usize = a.n(); - - debug_assert!( - tmp_bytes.len() >= normalize_tmp_bytes(n), - "invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({})", - tmp_bytes.len(), - n, - ); - - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()) - } - - let carry_i64: &mut [i64] = cast_mut(tmp_bytes); - - unsafe { - znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr()); - (0..a.size()).rev().for_each(|i| { - znx::znx_normalize( - n as u64, - basek as u64, - a.at_mut_ptr(a_col, i), - carry_i64.as_mut_ptr(), - a.at_mut_ptr(a_col, i), - carry_i64.as_mut_ptr(), - ) - }); - } -} - -impl + AsRef<[u8]>> VecZnx -where - VecZnx: VecZnxToMut + ZnxInfos, -{ - /// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self]. - pub fn extract_column(&mut self, self_col: usize, a: &VecZnx, a_col: usize) - where - R: AsRef<[u8]>, - VecZnx: VecZnxToRef + ZnxInfos, - { - #[cfg(debug_assertions)] - { - assert!(self_col < self.cols()); - assert!(a_col < a.cols()); - } - - let min_size: usize = self.size.min(a.size()); - let max_size: usize = self.size; - - let mut self_mut: VecZnx<&mut [u8]> = self.to_mut(); - let a_ref: VecZnx<&[u8]> = a.to_ref(); - - (0..min_size).for_each(|i: usize| { - self_mut - .at_mut(self_col, i) - .copy_from_slice(a_ref.at(a_col, i)); - }); - - (min_size..max_size).for_each(|i| { - self_mut.zero_at(self_col, i); - }); - } -} - -impl> fmt::Display for VecZnx { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!( - f, - "VecZnx(n={}, cols={}, size={})", - self.n, self.cols, self.size - )?; - - for col in 0..self.cols { - writeln!(f, "Column {}:", col)?; - for size in 0..self.size { - let coeffs = self.at(col, size); - write!(f, " Size {}: [", size)?; - - let max_show = 100; - let show_count = coeffs.len().min(max_show); - - for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{}", coeff)?; - } - - if coeffs.len() > max_show { - write!(f, ", ... 
({} more)", coeffs.len() - max_show)?; - } - - writeln!(f, "]")?; - } - } - Ok(()) - } -} - -pub type VecZnxOwned = VecZnx>; -pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>; -pub type VecZnxRef<'a> = VecZnx<&'a [u8]>; - -pub trait VecZnxToRef { - fn to_ref(&self) -> VecZnx<&[u8]>; -} - -impl VecZnxToRef for VecZnx -where - D: AsRef<[u8]>, -{ - fn to_ref(&self) -> VecZnx<&[u8]> { - VecZnx { - data: self.data.as_ref(), - n: self.n, - cols: self.cols, - size: self.size, - } - } -} - -pub trait VecZnxToMut { - fn to_mut(&mut self) -> VecZnx<&mut [u8]>; -} - -impl VecZnxToMut for VecZnx -where - D: AsRef<[u8]> + AsMut<[u8]>, -{ - fn to_mut(&mut self) -> VecZnx<&mut [u8]> { - VecZnx { - data: self.data.as_mut(), - n: self.n, - cols: self.cols, - size: self.size, - } - } -} - -impl> VecZnx { - pub fn clone(&self) -> VecZnx> { - let self_ref: VecZnx<&[u8]> = self.to_ref(); - VecZnx { - data: self_ref.data.to_vec(), - n: self_ref.n, - cols: self_ref.cols, - size: self_ref.size, - } - } -} diff --git a/backend/src/vec_znx_big.rs b/backend/src/vec_znx_big.rs deleted file mode 100644 index 90c3de2..0000000 --- a/backend/src/vec_znx_big.rs +++ /dev/null @@ -1,216 +0,0 @@ -use crate::ffi::vec_znx_big; -use crate::znx_base::{ZnxInfos, ZnxView}; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnx, ZnxSliceSize, ZnxViewMut, ZnxZero, alloc_aligned}; -use std::fmt; -use std::marker::PhantomData; - -pub struct VecZnxBig { - data: D, - n: usize, - cols: usize, - size: usize, - _phantom: PhantomData, -} - -impl ZnxInfos for VecZnxBig { - fn cols(&self) -> usize { - self.cols - } - - fn rows(&self) -> usize { - 1 - } - - fn n(&self) -> usize { - self.n - } - - fn size(&self) -> usize { - self.size - } -} - -impl ZnxSliceSize for VecZnxBig { - fn sl(&self) -> usize { - self.n() * self.cols() - } -} - -impl DataView for VecZnxBig { - type D = D; - fn data(&self) -> &Self::D { - &self.data - } -} - -impl DataViewMut for VecZnxBig { - fn data_mut(&mut self) -> &mut Self::D { - &mut self.data - } -} - -impl> ZnxView for VecZnxBig { - type Scalar = i64; -} - -pub(crate) fn bytes_of_vec_znx_big(module: &Module, cols: usize, size: usize) -> usize { - unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols } -} - -impl>, B: Backend> VecZnxBig { - pub(crate) fn new(module: &Module, cols: usize, size: usize) -> Self { - let data = alloc_aligned::(bytes_of_vec_znx_big(module, cols, size)); - Self { - data: data.into(), - n: module.n(), - cols, - size, - _phantom: PhantomData, - } - } - - pub(crate) fn new_from_bytes(module: &Module, cols: usize, size: usize, bytes: impl Into>) -> Self { - let data: Vec = bytes.into(); - assert!(data.len() == bytes_of_vec_znx_big(module, cols, size)); - Self { - data: data.into(), - n: module.n(), - cols, - size, - _phantom: PhantomData, - } - } -} - -impl VecZnxBig { - pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { - Self { - data, - n, - cols, - size, - _phantom: PhantomData, - } - } -} - -impl + AsRef<[u8]>> VecZnxBig -where - VecZnxBig: VecZnxBigToMut + ZnxInfos, -{ - // Consumes the VecZnxBig to return a VecZnx. - // Useful when no normalization is needed. - pub fn to_vec_znx_small(self) -> VecZnx { - VecZnx { - data: self.data, - n: self.n, - cols: self.cols, - size: self.size, - } - } - - /// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self]. 
- pub fn extract_column(&mut self, self_col: usize, a: &C, a_col: usize) - where - C: VecZnxBigToRef + ZnxInfos, - { - #[cfg(debug_assertions)] - { - assert!(self_col < self.cols()); - assert!(a_col < a.cols()); - } - - let min_size: usize = self.size.min(a.size()); - let max_size: usize = self.size; - - let mut self_mut: VecZnxBig<&mut [u8], FFT64> = self.to_mut(); - let a_ref: VecZnxBig<&[u8], FFT64> = a.to_ref(); - - (0..min_size).for_each(|i: usize| { - self_mut - .at_mut(self_col, i) - .copy_from_slice(a_ref.at(a_col, i)); - }); - - (min_size..max_size).for_each(|i| { - self_mut.zero_at(self_col, i); - }); - } -} - -pub type VecZnxBigOwned = VecZnxBig, B>; - -pub trait VecZnxBigToRef { - fn to_ref(&self) -> VecZnxBig<&[u8], B>; -} - -impl VecZnxBigToRef for VecZnxBig -where - D: AsRef<[u8]>, - B: Backend, -{ - fn to_ref(&self) -> VecZnxBig<&[u8], B> { - VecZnxBig { - data: self.data.as_ref(), - n: self.n, - cols: self.cols, - size: self.size, - _phantom: std::marker::PhantomData, - } - } -} - -pub trait VecZnxBigToMut { - fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B>; -} - -impl VecZnxBigToMut for VecZnxBig -where - D: AsRef<[u8]> + AsMut<[u8]>, - B: Backend, -{ - fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> { - VecZnxBig { - data: self.data.as_mut(), - n: self.n, - cols: self.cols, - size: self.size, - _phantom: std::marker::PhantomData, - } - } -} - -impl> fmt::Display for VecZnxBig { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!( - f, - "VecZnxBig(n={}, cols={}, size={})", - self.n, self.cols, self.size - )?; - - for col in 0..self.cols { - writeln!(f, "Column {}:", col)?; - for size in 0..self.size { - let coeffs = self.at(col, size); - write!(f, " Size {}: [", size)?; - - let max_show = 100; - let show_count = coeffs.len().min(max_show); - - for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{}", coeff)?; - } - - if coeffs.len() > max_show { - write!(f, ", ... ({} more)", coeffs.len() - max_show)?; - } - - writeln!(f, "]")?; - } - } - Ok(()) - } -} diff --git a/backend/src/vec_znx_big_ops.rs b/backend/src/vec_znx_big_ops.rs deleted file mode 100644 index b0b09e7..0000000 --- a/backend/src/vec_znx_big_ops.rs +++ /dev/null @@ -1,618 +0,0 @@ -use crate::ffi::vec_znx; -use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; -use crate::{ - Backend, FFT64, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxScratch, - VecZnxToMut, VecZnxToRef, ZnxSliceSize, bytes_of_vec_znx_big, -}; - -pub trait VecZnxBigAlloc { - /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. - fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned; - - /// Returns a new [VecZnxBig] with the provided bytes array as backing array. - /// - /// Behavior: takes ownership of the backing array. - /// - /// # Arguments - /// - /// * `cols`: the number of polynomials.. - /// * `size`: the number of polynomials per column. - /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. - /// - /// # Panics - /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned; - - // /// Returns a new [VecZnxBig] with the provided bytes array as backing array. - // /// - // /// Behavior: the backing array is only borrowed. - // /// - // /// # Arguments - // /// - // /// * `cols`: the number of polynomials.. - // /// * `size`: the number of polynomials per column. 
- // /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. - // /// - // /// # Panics - // /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - // fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig; - - /// Returns the minimum number of bytes necessary to allocate - /// a new [VecZnxBig] through [VecZnxBig::from_bytes]. - fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize; -} - -pub trait VecZnxBigOps { - /// Adds `a` to `b` and stores the result on `c`. - fn vec_znx_big_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxBigToRef, - B: VecZnxBigToRef; - - /// Adds `a` to `b` and stores the result on `b`. - fn vec_znx_big_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxBigToRef; - - /// Adds `a` to `b` and stores the result on `c`. - fn vec_znx_big_add_small(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxBigToRef, - B: VecZnxToRef; - - /// Adds `a` to `b` and stores the result on `b`. - fn vec_znx_big_add_small_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxToRef; - - /// Subtracts `a` to `b` and stores the result on `c`. - fn vec_znx_big_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxBigToRef, - B: VecZnxBigToRef; - - /// Subtracts `a` from `b` and stores the result on `b`. - fn vec_znx_big_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxBigToRef; - - /// Subtracts `b` from `a` and stores the result on `b`. - fn vec_znx_big_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxBigToRef; - - /// Subtracts `b` from `a` and stores the result on `c`. - fn vec_znx_big_sub_small_a(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxToRef, - B: VecZnxBigToRef; - - /// Subtracts `a` from `res` and stores the result on `res`. - fn vec_znx_big_sub_small_a_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxToRef; - - /// Subtracts `b` from `a` and stores the result on `c`. - fn vec_znx_big_sub_small_b(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxBigToRef, - B: VecZnxToRef; - - /// Subtracts `res` from `a` and stores the result on `res`. - fn vec_znx_big_sub_small_b_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxToRef; - - /// Negates `a` inplace. - fn vec_znx_big_negate_inplace(&self, a: &mut A, a_col: usize) - where - A: VecZnxBigToMut; - - /// Normalizes `a` and stores the result on `b`. - /// - /// # Arguments - /// - /// * `basek`: normalization basis. - /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize]. - fn vec_znx_big_normalize(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) - where - R: VecZnxToMut, - A: VecZnxBigToRef; - - /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. 
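The automorphism X^i -> X^{ik} permutes coefficient indices modulo 2n, with a sign flip when the image lands in the upper half (X^{n+t} = -X^t). A standalone reference on raw coefficients, assuming odd k so the map is invertible:

/// Apply X^i -> X^{ik} on a polynomial in Z[X]/(X^n + 1); k should be odd.
fn automorphism(a: &[i64], k: i64) -> Vec<i64> {
    let n = a.len() as i64;
    let mut res = vec![0i64; a.len()];
    for (i, &c) in a.iter().enumerate() {
        let j = (i as i64 * k).rem_euclid(2 * n); // exponent mod 2n
        if j < n {
            res[j as usize] += c;
        } else {
            res[(j - n) as usize] -= c; // X^{n+t} = -X^t
        }
    }
    res
}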
- fn vec_znx_big_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxBigToRef; - - /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. - fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) - where - A: VecZnxBigToMut; -} - -pub trait VecZnxBigScratch { - /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize]. - fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; -} - -impl VecZnxBigAlloc for Module { - fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned { - VecZnxBig::new(self, cols, size) - } - - fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned { - VecZnxBig::new_from_bytes(self, cols, size, bytes) - } - - fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize { - bytes_of_vec_znx_big(self, cols, size) - } -} - -impl VecZnxBigOps for Module { - fn vec_znx_big_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxBigToRef, - B: VecZnxBigToRef, - { - let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); - let b: VecZnxBig<&[u8], FFT64> = b.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(b.n(), self.n()); - assert_eq!(res.n(), self.n()); - assert_ne!(a.as_ptr(), b.as_ptr()); - } - unsafe { - vec_znx::vec_znx_add( - self.ptr, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - b.at_ptr(b_col, 0), - b.size() as u64, - b.sl() as u64, - ) - } - } - - fn vec_znx_big_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxBigToRef, - { - let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(res.n(), self.n()); - } - unsafe { - vec_znx::vec_znx_add( - self.ptr, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - res.at_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - ) - } - } - - fn vec_znx_big_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxBigToRef, - B: VecZnxBigToRef, - { - let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); - let b: VecZnxBig<&[u8], FFT64> = b.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(b.n(), self.n()); - assert_eq!(res.n(), self.n()); - assert_ne!(a.as_ptr(), b.as_ptr()); - } - unsafe { - vec_znx::vec_znx_sub( - self.ptr, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - b.at_ptr(b_col, 0), - b.size() as u64, - b.sl() as u64, - ) - } - } - - fn vec_znx_big_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxBigToRef, - { - let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(res.n(), self.n()); - } - unsafe { - vec_znx::vec_znx_sub( - self.ptr, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - res.at_ptr(res_col, 
0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - ) - } - } - - fn vec_znx_big_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxBigToRef, - { - let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(res.n(), self.n()); - } - unsafe { - vec_znx::vec_znx_sub( - self.ptr, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - res.at_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - ) - } - } - - fn vec_znx_big_sub_small_b(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxBigToRef, - B: VecZnxToRef, - { - let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); - let b: VecZnx<&[u8]> = b.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(b.n(), self.n()); - assert_eq!(res.n(), self.n()); - assert_ne!(a.as_ptr(), b.as_ptr()); - } - unsafe { - vec_znx::vec_znx_sub( - self.ptr, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - b.at_ptr(b_col, 0), - b.size() as u64, - b.sl() as u64, - ) - } - } - - fn vec_znx_big_sub_small_b_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxToRef, - { - let a: VecZnx<&[u8]> = a.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(res.n(), self.n()); - } - unsafe { - vec_znx::vec_znx_sub( - self.ptr, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - res.at_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - ) - } - } - - fn vec_znx_big_sub_small_a(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxToRef, - B: VecZnxBigToRef, - { - let a: VecZnx<&[u8]> = a.to_ref(); - let b: VecZnxBig<&[u8], FFT64> = b.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(b.n(), self.n()); - assert_eq!(res.n(), self.n()); - assert_ne!(a.as_ptr(), b.as_ptr()); - } - unsafe { - vec_znx::vec_znx_sub( - self.ptr, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - b.at_ptr(b_col, 0), - b.size() as u64, - b.sl() as u64, - ) - } - } - - fn vec_znx_big_sub_small_a_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxToRef, - { - let a: VecZnx<&[u8]> = a.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(res.n(), self.n()); - } - unsafe { - vec_znx::vec_znx_sub( - self.ptr, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - res.at_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - ) - } - } - - fn vec_znx_big_add_small(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxBigToRef, - B: 
VecZnxToRef, - { - let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); - let b: VecZnx<&[u8]> = b.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(b.n(), self.n()); - assert_eq!(res.n(), self.n()); - assert_ne!(a.as_ptr(), b.as_ptr()); - } - unsafe { - vec_znx::vec_znx_add( - self.ptr, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - b.at_ptr(b_col, 0), - b.size() as u64, - b.sl() as u64, - ) - } - } - - fn vec_znx_big_add_small_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxToRef, - { - let a: VecZnx<&[u8]> = a.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(res.n(), self.n()); - } - unsafe { - vec_znx::vec_znx_add( - self.ptr, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - res.at_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - ) - } - } - - fn vec_znx_big_negate_inplace(&self, a: &mut A, a_col: usize) - where - A: VecZnxBigToMut, - { - let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - } - unsafe { - vec_znx::vec_znx_negate( - self.ptr, - a.at_mut_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - ) - } - } - - fn vec_znx_big_normalize(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) - where - R: VecZnxToMut, - A: VecZnxBigToRef, - { - let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); - let mut res: VecZnx<&mut [u8]> = res.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(res.n(), self.n()); - //(Jay)Note: This is calling VezZnxOps::vec_znx_normalize_tmp_bytes and not VecZnxBigOps::vec_znx_big_normalize_tmp_bytes. 
- // In the FFT backend the tmp sizes are same but will be different in the NTT backend - // assert!(tmp_bytes.len() >= >::vec_znx_normalize_tmp_bytes(&self)); - // assert_alignement(tmp_bytes.as_ptr()); - } - - let (tmp_bytes, _) = scratch.tmp_slice(::vec_znx_big_normalize_tmp_bytes( - &self, - )); - unsafe { - vec_znx::vec_znx_normalize_base2k( - self.ptr, - basek as u64, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - tmp_bytes.as_mut_ptr(), - ); - } - } - - fn vec_znx_big_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxBigToMut, - A: VecZnxBigToRef, - { - let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); - let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(res.n(), self.n()); - } - unsafe { - vec_znx::vec_znx_automorphism( - self.ptr, - k, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - ) - } - } - - fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) - where - A: VecZnxBigToMut, - { - let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - } - unsafe { - vec_znx::vec_znx_automorphism( - self.ptr, - k, - a.at_mut_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - ) - } - } -} - -impl VecZnxBigScratch for Module { - fn vec_znx_big_normalize_tmp_bytes(&self) -> usize { - ::vec_znx_normalize_tmp_bytes(self) - } -} diff --git a/backend/src/vec_znx_dft.rs b/backend/src/vec_znx_dft.rs deleted file mode 100644 index d5c0ad5..0000000 --- a/backend/src/vec_znx_dft.rs +++ /dev/null @@ -1,190 +0,0 @@ -use std::marker::PhantomData; - -use crate::ffi::vec_znx_dft; -use crate::znx_base::ZnxInfos; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, alloc_aligned}; -use std::fmt; - -pub struct VecZnxDft { - pub(crate) data: D, - pub(crate) n: usize, - pub(crate) cols: usize, - pub(crate) size: usize, - pub(crate) _phantom: PhantomData, -} - -impl VecZnxDft { - pub fn into_big(self) -> VecZnxBig { - VecZnxBig::::from_data(self.data, self.n, self.cols, self.size) - } -} - -impl ZnxInfos for VecZnxDft { - fn cols(&self) -> usize { - self.cols - } - - fn rows(&self) -> usize { - 1 - } - - fn n(&self) -> usize { - self.n - } - - fn size(&self) -> usize { - self.size - } -} - -impl ZnxSliceSize for VecZnxDft { - fn sl(&self) -> usize { - self.n() * self.cols() - } -} - -impl DataView for VecZnxDft { - type D = D; - fn data(&self) -> &Self::D { - &self.data - } -} - -impl DataViewMut for VecZnxDft { - fn data_mut(&mut self) -> &mut Self::D { - &mut self.data - } -} - -impl> ZnxView for VecZnxDft { - type Scalar = f64; -} - -impl + AsRef<[u8]>> VecZnxDft { - pub fn set_size(&mut self, size: usize) { - assert!(size <= self.data.as_ref().len() / (self.n * self.cols())); - self.size = size - } - - pub fn max_size(&mut self) -> usize { - self.data.as_ref().len() / (self.n * self.cols) - } -} - -pub(crate) fn bytes_of_vec_znx_dft(module: &Module, cols: usize, size: usize) -> usize { - unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols } -} - -impl>, B: Backend> VecZnxDft { - pub(crate) fn new(module: &Module, cols: usize, size: usize) -> Self { - let data = alloc_aligned::(bytes_of_vec_znx_dft(module, cols, 
size)); - Self { - data: data.into(), - n: module.n(), - cols, - size, - _phantom: PhantomData, - } - } - - pub(crate) fn new_from_bytes(module: &Module, cols: usize, size: usize, bytes: impl Into>) -> Self { - let data: Vec = bytes.into(); - assert!(data.len() == bytes_of_vec_znx_dft(module, cols, size)); - Self { - data: data.into(), - n: module.n(), - cols, - size, - _phantom: PhantomData, - } - } -} - -pub type VecZnxDftOwned = VecZnxDft, B>; - -impl VecZnxDft { - pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { - Self { - data, - n, - cols, - size, - _phantom: PhantomData, - } - } -} - -pub trait VecZnxDftToRef { - fn to_ref(&self) -> VecZnxDft<&[u8], B>; -} - -impl VecZnxDftToRef for VecZnxDft -where - D: AsRef<[u8]>, - B: Backend, -{ - fn to_ref(&self) -> VecZnxDft<&[u8], B> { - VecZnxDft { - data: self.data.as_ref(), - n: self.n, - cols: self.cols, - size: self.size, - _phantom: std::marker::PhantomData, - } - } -} - -pub trait VecZnxDftToMut { - fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B>; -} - -impl VecZnxDftToMut for VecZnxDft -where - D: AsRef<[u8]> + AsMut<[u8]>, - B: Backend, -{ - fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { - VecZnxDft { - data: self.data.as_mut(), - n: self.n, - cols: self.cols, - size: self.size, - _phantom: std::marker::PhantomData, - } - } -} - -impl> fmt::Display for VecZnxDft { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!( - f, - "VecZnxDft(n={}, cols={}, size={})", - self.n, self.cols, self.size - )?; - - for col in 0..self.cols { - writeln!(f, "Column {}:", col)?; - for size in 0..self.size { - let coeffs = self.at(col, size); - write!(f, " Size {}: [", size)?; - - let max_show = 100; - let show_count = coeffs.len().min(max_show); - - for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{}", coeff)?; - } - - if coeffs.len() > max_show { - write!(f, ", ... ({} more)", coeffs.len() - max_show)?; - } - - writeln!(f, "]")?; - } - } - Ok(()) - } -} diff --git a/backend/src/vec_znx_ops.rs b/backend/src/vec_znx_ops.rs deleted file mode 100644 index 2bde61c..0000000 --- a/backend/src/vec_znx_ops.rs +++ /dev/null @@ -1,736 +0,0 @@ -use crate::ffi::vec_znx; -use crate::{ - Backend, Module, ScalarZnxToRef, Scratch, VecZnx, VecZnxOwned, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, - ZnxViewMut, ZnxZero, -}; -use itertools::izip; -use std::cmp::min; - -pub trait VecZnxAlloc { - /// Allocates a new [VecZnx]. - /// - /// # Arguments - /// - /// * `cols`: the number of polynomials. - /// * `size`: the number of small polynomials per column. - fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned; - - /// Instantiates a new [VecZnx] from a slice of bytes. - /// The returned [VecZnx] takes ownership of the slice of bytes. - /// - /// # Arguments - /// - /// * `cols`: the number of polynomials. - /// * `size`: the number of small polynomials per column. - /// - /// # Panic - /// Requires the length of the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx]. - fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxOwned; - - /// Returns the number of bytes necessary to allocate - /// a new [VecZnx] through [VecZnxOps::new_vec_znx_from_bytes] - /// or [VecZnxOps::new_vec_znx_from_bytes_borrow]. - fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize; -} - -pub trait VecZnxOps { - /// Normalizes the selected column of `a` and stores the result into the selected column of `res`.
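// --- Editor's sketch (not part of the diff): what base-2^k normalization does. ---
// A minimal, self-contained model of the guarantee documented above, assuming the
// usual balanced-digit convention; `normalize_base2k` is a hypothetical helper, not
// the crate's API. Each coefficient of a VecZnx column is split across limbs (most
// significant first); normalization propagates carries so that every digit ends up
// in [-2^(basek-1), 2^(basek-1)).
fn normalize_base2k(basek: u32, limbs: &mut [i64]) -> i64 {
    let base = 1i64 << basek;
    let half = base >> 1;
    let mut carry = 0i64;
    // Walk from the least significant limb (last) to the most significant (first).
    for limb in limbs.iter_mut().rev() {
        let v = *limb + carry;
        // Centered remainder: digit in [-2^(basek-1), 2^(basek-1)).
        let mut digit = v.rem_euclid(base);
        if digit >= half {
            digit -= base;
        }
        *limb = digit;
        carry = (v - digit) >> basek;
    }
    carry // leftover carry, i.e. overflow past the most significant limb
}

fn main() {
    // basek = 8: digits must land in [-128, 128).
    let mut limbs = vec![0i64, 300, -70];
    let overflow = normalize_base2k(8, &mut limbs);
    assert_eq!(limbs, vec![1, 44, -70]); // same value: 1*2^16 + 44*2^8 - 70 = 76730
    assert_eq!(overflow, 0);
}
// --- End of editor's sketch. ---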
- fn vec_znx_normalize(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) - where - R: VecZnxToMut, - A: VecZnxToRef; - - /// Normalizes the selected column of `a`. - fn vec_znx_normalize_inplace(&self, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) - where - A: VecZnxToMut; - - /// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`. - fn vec_znx_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) - where - R: VecZnxToMut, - A: VecZnxToRef, - B: VecZnxToRef; - - /// Adds the selected column of `a` to the selected column of `res` and writes the result on the selected column of `res`. - fn vec_znx_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxToMut, - A: VecZnxToRef; - - /// Adds the selected column of `a` on the selected column and limb of `res`. - fn vec_znx_add_scalar_inplace(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize) - where - R: VecZnxToMut, - A: ScalarZnxToRef; - - /// Subtracts the selected column of `b` from the selected column of `a` and writes the result on the selected column of `res`. - fn vec_znx_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) - where - R: VecZnxToMut, - A: VecZnxToRef, - B: VecZnxToRef; - - /// Subtracts the selected column of `a` from the selected column of `res` inplace. - /// - /// res[res_col] -= a[a_col] - fn vec_znx_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxToMut, - A: VecZnxToRef; - - /// Subtracts the selected column of `res` from the selected column of `a` and inplace mutates `res` - /// - /// res[res_col] = a[a_col] - res[res_col] - fn vec_znx_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxToMut, - A: VecZnxToRef; - - /// Subtracts the selected column of `a` on the selected column and limb of `res`. - fn vec_znx_sub_scalar_inplace(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, b_col: usize) - where - R: VecZnxToMut, - A: ScalarZnxToRef; - - // Negates the selected column of `a` and stores the result in `res_col` of `res`. - fn vec_znx_negate(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxToMut, - A: VecZnxToRef; - - /// Negates the selected column of `a`. - fn vec_znx_negate_inplace(&self, a: &mut A, a_col: usize) - where - A: VecZnxToMut; - - /// Shifts by k bits all columns of `a`. - /// A positive k applies a left shift, while a negative k applies a right shift. - fn vec_znx_shift_inplace(&self, basek: usize, k: i64, a: &mut A, scratch: &mut Scratch) - where - A: VecZnxToMut; - - /// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`. - fn vec_znx_rotate(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxToMut, - A: VecZnxToRef; - - /// Multiplies the selected column of `a` by X^k. - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut A, a_col: usize) - where - A: VecZnxToMut; - - /// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result in `res_col` column of `res`. - fn vec_znx_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxToMut, - A: VecZnxToRef; - - /// Applies the automorphism X^i -> X^ik on the selected column of `a`. 
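// --- Editor's sketch (not part of the diff): the map X^i -> X^(ik) in Z[X]/(X^N + 1). ---
// A standalone plaintext model of the automorphism documented above; `automorphism`
// is a hypothetical helper, not the crate's API. k must be odd: the odd residues are
// exactly the units of Z_{2N} when N is a power of two, so only odd k give a ring
// automorphism.
fn automorphism(k: i64, a: &[i64]) -> Vec<i64> {
    let n = a.len();
    let two_n = (2 * n) as i64;
    let mut res = vec![0i64; n];
    for (i, &c) in a.iter().enumerate() {
        // X^i maps to X^(i*k mod 2N); since X^N = -1, exponents in [N, 2N) flip sign.
        let e = (i as i64 * k).rem_euclid(two_n) as usize;
        if e < n {
            res[e] += c;
        } else {
            res[e - n] -= c;
        }
    }
    res
}

fn main() {
    // In Z[X]/(X^4 + 1): applying X -> X^3 to 1 + 2X yields 1 + 2X^3.
    assert_eq!(automorphism(3, &[1, 2, 0, 0]), vec![1, 0, 0, 2]);
    // k = 3 applied twice is k = 9 = 1 mod 8, i.e. the identity.
    assert_eq!(automorphism(3, &automorphism(3, &[1, 2, 3, 4])), vec![1, 2, 3, 4]);
}
// --- End of editor's sketch. ---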
- fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) - where - A: VecZnxToMut; - - /// Splits the selected columns of `b` into subrings and copies them them into the selected column of `res`. - /// - /// # Panics - /// - /// This method requires that all [VecZnx] of b have the same ring degree - /// and that b.n() * b.len() <= a.n() - fn vec_znx_split(&self, res: &mut Vec, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) - where - R: VecZnxToMut, - A: VecZnxToRef; - - /// Merges the subrings of the selected column of `a` into the selected column of `res`. - /// - /// # Panics - /// - /// This method requires that all [VecZnx] of a have the same ring degree - /// and that a.n() * a.len() <= b.n() - fn vec_znx_merge(&self, res: &mut R, res_col: usize, a: Vec, a_col: usize) - where - R: VecZnxToMut, - A: VecZnxToRef; - - fn switch_degree(&self, r: &mut R, col_b: usize, a: &A, col_a: usize) - where - R: VecZnxToMut, - A: VecZnxToRef; - - fn vec_znx_copy(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxToMut, - A: VecZnxToRef; -} - -pub trait VecZnxScratch { - /// Returns the minimum number of bytes necessary for normalization. - fn vec_znx_normalize_tmp_bytes(&self) -> usize; -} - -impl VecZnxAlloc for Module { - fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned { - VecZnxOwned::new::(self.n(), cols, size) - } - - fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize { - VecZnxOwned::bytes_of::(self.n(), cols, size) - } - - fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxOwned { - VecZnxOwned::new_from_bytes::(self.n(), cols, size, bytes) - } -} - -impl VecZnxOps for Module { - fn vec_znx_shift_inplace(&self, basek: usize, k: i64, a: &mut A, scratch: &mut Scratch) - where - A: VecZnxToMut, - { - if k > 0 { - a.to_mut().lsh(basek, k as usize, scratch); - } else { - a.to_mut().rsh(basek, k.abs() as usize, scratch); - } - } - - fn vec_znx_copy(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxToMut, - A: VecZnxToRef, - { - let mut res_mut: VecZnx<&mut [u8]> = res.to_mut(); - let a_ref: VecZnx<&[u8]> = a.to_ref(); - - let min_size: usize = min(res_mut.size(), a_ref.size()); - - (0..min_size).for_each(|j| { - res_mut - .at_mut(res_col, j) - .copy_from_slice(a_ref.at(a_col, j)); - }); - (min_size..res_mut.size()).for_each(|j| { - res_mut.zero_at(res_col, j); - }) - } - - fn vec_znx_normalize(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) - where - R: VecZnxToMut, - A: VecZnxToRef, - { - let a: VecZnx<&[u8]> = a.to_ref(); - let mut res: VecZnx<&mut [u8]> = res.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(res.n(), self.n()); - } - - let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_normalize_tmp_bytes()); - - unsafe { - vec_znx::vec_znx_normalize_base2k( - self.ptr, - basek as u64, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - tmp_bytes.as_mut_ptr(), - ); - } - } - - fn vec_znx_normalize_inplace(&self, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) - where - A: VecZnxToMut, - { - let mut a: VecZnx<&mut [u8]> = a.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - } - - let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_normalize_tmp_bytes()); - - unsafe { - vec_znx::vec_znx_normalize_base2k( - self.ptr, - basek as u64, - a.at_mut_ptr(a_col, 0), - a.size() 
as u64, - a.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - tmp_bytes.as_mut_ptr(), - ); - } - } - - fn vec_znx_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) - where - R: VecZnxToMut, - A: VecZnxToRef, - B: VecZnxToRef, - { - let a: VecZnx<&[u8]> = a.to_ref(); - let b: VecZnx<&[u8]> = b.to_ref(); - let mut res: VecZnx<&mut [u8]> = res.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(b.n(), self.n()); - assert_eq!(res.n(), self.n()); - assert_ne!(a.as_ptr(), b.as_ptr()); - } - unsafe { - vec_znx::vec_znx_add( - self.ptr, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - b.at_ptr(b_col, 0), - b.size() as u64, - b.sl() as u64, - ) - } - } - - fn vec_znx_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxToMut, - A: VecZnxToRef, - { - let a: VecZnx<&[u8]> = a.to_ref(); - let mut res: VecZnx<&mut [u8]> = res.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(res.n(), self.n()); - } - unsafe { - vec_znx::vec_znx_add( - self.ptr, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - res.at_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - ) - } - } - - fn vec_znx_add_scalar_inplace(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize) - where - R: VecZnxToMut, - A: ScalarZnxToRef, - { - let mut res: VecZnx<&mut [u8]> = res.to_mut(); - let a: crate::ScalarZnx<&[u8]> = a.to_ref(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(res.n(), self.n()); - } - - unsafe { - vec_znx::vec_znx_add( - self.ptr, - res.at_mut_ptr(res_col, res_limb), - 1 as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - res.at_ptr(res_col, res_limb), - 1 as u64, - res.sl() as u64, - ) - } - } - - fn vec_znx_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) - where - R: VecZnxToMut, - A: VecZnxToRef, - B: VecZnxToRef, - { - let a: VecZnx<&[u8]> = a.to_ref(); - let b: VecZnx<&[u8]> = b.to_ref(); - let mut res: VecZnx<&mut [u8]> = res.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(b.n(), self.n()); - assert_eq!(res.n(), self.n()); - assert_ne!(a.as_ptr(), b.as_ptr()); - } - unsafe { - vec_znx::vec_znx_sub( - self.ptr, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - b.at_ptr(b_col, 0), - b.size() as u64, - b.sl() as u64, - ) - } - } - - fn vec_znx_sub_scalar_inplace(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize) - where - R: VecZnxToMut, - A: ScalarZnxToRef, - { - let mut res: VecZnx<&mut [u8]> = res.to_mut(); - let a: crate::ScalarZnx<&[u8]> = a.to_ref(); - - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(res.n(), self.n()); - } - - unsafe { - vec_znx::vec_znx_sub( - self.ptr, - res.at_mut_ptr(res_col, res_limb), - 1 as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - res.at_ptr(res_col, res_limb), - 1 as u64, - res.sl() as u64, - ) - } - } - - fn vec_znx_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxToMut, - A: VecZnxToRef, - { - let a: VecZnx<&[u8]> = a.to_ref(); - let mut res: VecZnx<&mut [u8]> = res.to_mut(); - 
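// --- Editor's aside (not part of the diff): operand order of the in-place subtractions. ---
// The `_ab` / `_ba` suffixes on this pair of methods encode which operand survives on
// the left; a tiny standalone model on plain slices (hypothetical helpers, not the
// crate's API):
fn sub_ab_inplace(res: &mut [i64], a: &[i64]) {
    // res[i] -= a[i], matching vec_znx_sub_ab_inplace.
    for (r, x) in res.iter_mut().zip(a) {
        *r -= *x;
    }
}

fn sub_ba_inplace(res: &mut [i64], a: &[i64]) {
    // res[i] = a[i] - res[i], matching vec_znx_sub_ba_inplace.
    for (r, x) in res.iter_mut().zip(a) {
        *r = *x - *r;
    }
}

fn main() {
    let (mut r1, mut r2) = (vec![5i64, 5], vec![5i64, 5]);
    sub_ab_inplace(&mut r1, &[2, 3]);
    sub_ba_inplace(&mut r2, &[2, 3]);
    assert_eq!((r1, r2), (vec![3, 2], vec![-3, -2]));
}
// --- End of editor's aside. ---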
#[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(res.n(), self.n()); - } - unsafe { - vec_znx::vec_znx_sub( - self.ptr, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - res.at_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - ) - } - } - - fn vec_znx_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxToMut, - A: VecZnxToRef, - { - let a: VecZnx<&[u8]> = a.to_ref(); - let mut res: VecZnx<&mut [u8]> = res.to_mut(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(res.n(), self.n()); - } - unsafe { - vec_znx::vec_znx_sub( - self.ptr, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - res.at_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - ) - } - } - - fn vec_znx_negate(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxToMut, - A: VecZnxToRef, - { - let a: VecZnx<&[u8]> = a.to_ref(); - let mut res: VecZnx<&mut [u8]> = res.to_mut(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(res.n(), self.n()); - } - unsafe { - vec_znx::vec_znx_negate( - self.ptr, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - ) - } - } - - fn vec_znx_negate_inplace(&self, a: &mut A, a_col: usize) - where - A: VecZnxToMut, - { - let mut a: VecZnx<&mut [u8]> = a.to_mut(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - } - unsafe { - vec_znx::vec_znx_negate( - self.ptr, - a.at_mut_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - ) - } - } - - fn vec_znx_rotate(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxToMut, - A: VecZnxToRef, - { - let a: VecZnx<&[u8]> = a.to_ref(); - let mut res: VecZnx<&mut [u8]> = res.to_mut(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(res.n(), self.n()); - } - unsafe { - vec_znx::vec_znx_rotate( - self.ptr, - k, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - ) - } - } - - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut A, a_col: usize) - where - A: VecZnxToMut, - { - let mut a: VecZnx<&mut [u8]> = a.to_mut(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - } - unsafe { - vec_znx::vec_znx_rotate( - self.ptr, - k, - a.at_mut_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - ) - } - } - - fn vec_znx_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxToMut, - A: VecZnxToRef, - { - let a: VecZnx<&[u8]> = a.to_ref(); - let mut res: VecZnx<&mut [u8]> = res.to_mut(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(res.n(), self.n()); - } - unsafe { - vec_znx::vec_znx_automorphism( - self.ptr, - k, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - ) - } - } - - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) - where - A: VecZnxToMut, - { - let mut a: VecZnx<&mut [u8]> = a.to_mut(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert!( - k & 1 != 0, - "invalid galois element: must be odd but is {}", - k - ); - } - 
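// --- Editor's aside (not part of the diff): why the assert above requires an odd k. ---
// X^i -> X^(ik) is invertible exactly when k is a unit modulo 2N, and for N a power
// of two the units of Z_{2N} are the odd residues. A standalone check (hypothetical
// helper; exhaustive search is fine at this scale):
fn galois_inverse(k: i64, two_n: i64) -> Option<i64> {
    (1..two_n).find(|j| (k.rem_euclid(two_n) * j).rem_euclid(two_n) == 1)
}

fn main() {
    let two_n = 16; // N = 8
    assert_eq!(galois_inverse(3, two_n), Some(11)); // 3 * 11 = 33 = 1 mod 16
    assert_eq!(galois_inverse(4, two_n), None); // even k is not invertible mod 2N
}
// --- End of editor's aside. ---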
unsafe { - vec_znx::vec_znx_automorphism( - self.ptr, - k, - a.at_mut_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - ) - } - } - - fn vec_znx_split(&self, res: &mut Vec, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) - where - R: VecZnxToMut, - A: VecZnxToRef, - { - let a: VecZnx<&[u8]> = a.to_ref(); - - let (n_in, n_out) = (a.n(), res[0].to_mut().n()); - - let (mut buf, _) = scratch.tmp_vec_znx(self, 1, a.size()); - - debug_assert!( - n_out < n_in, - "invalid a: output ring degree should be smaller" - ); - res[1..].iter_mut().for_each(|bi| { - debug_assert_eq!( - bi.to_mut().n(), - n_out, - "invalid input a: all VecZnx must have the same degree" - ) - }); - - res.iter_mut().enumerate().for_each(|(i, bi)| { - if i == 0 { - self.switch_degree(bi, res_col, &a, a_col); - self.vec_znx_rotate(-1, &mut buf, 0, &a, a_col); - } else { - self.switch_degree(bi, res_col, &mut buf, a_col); - self.vec_znx_rotate_inplace(-1, &mut buf, a_col); - } - }) - } - - fn vec_znx_merge(&self, res: &mut R, res_col: usize, a: Vec, a_col: usize) - where - R: VecZnxToMut, - A: VecZnxToRef, - { - let mut res: VecZnx<&mut [u8]> = res.to_mut(); - - let (n_in, n_out) = (res.n(), a[0].to_ref().n()); - - debug_assert!( - n_out < n_in, - "invalid a: output ring degree should be smaller" - ); - a[1..].iter().for_each(|ai| { - debug_assert_eq!( - ai.to_ref().n(), - n_out, - "invalid input a: all VecZnx must have the same degree" - ) - }); - - a.iter().enumerate().for_each(|(_, ai)| { - self.switch_degree(&mut res, res_col, ai, a_col); - self.vec_znx_rotate_inplace(-1, &mut res, res_col); - }); - - self.vec_znx_rotate_inplace(a.len() as i64, &mut res, res_col); - } - - fn switch_degree(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) - where - R: VecZnxToMut, - A: VecZnxToRef, - { - let a: VecZnx<&[u8]> = a.to_ref(); - let mut res: VecZnx<&mut [u8]> = res.to_mut(); - - let (n_in, n_out) = (a.n(), res.n()); - let (gap_in, gap_out): (usize, usize); - - if n_in > n_out { - (gap_in, gap_out) = (n_in / n_out, 1) - } else { - (gap_in, gap_out) = (1, n_out / n_in); - res.zero(); - } - - let size: usize = min(a.size(), res.size()); - - (0..size).for_each(|i| { - izip!( - a.at(a_col, i).iter().step_by(gap_in), - res.at_mut(res_col, i).iter_mut().step_by(gap_out) - ) - .for_each(|(x_in, x_out)| *x_out = *x_in); - }); - } -} - -impl VecZnxScratch for Module { - fn vec_znx_normalize_tmp_bytes(&self) -> usize { - unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize } - } -} diff --git a/core/Cargo.toml b/core/Cargo.toml index a7c0fb3..f85f1b3 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -11,6 +11,7 @@ backend = {path="../backend"} sampling = {path="../sampling"} rand_distr = {workspace = true} itertools = {workspace = true} +byteorder = {workspace = true} [[bench]] name = "external_product_glwe_fft64" diff --git a/core/benches/external_product_glwe_fft64.rs b/core/benches/external_product_glwe_fft64.rs index fd6508a..6fa0f0d 100644 --- a/core/benches/external_product_glwe_fft64.rs +++ b/core/benches/external_product_glwe_fft64.rs @@ -1,8 +1,15 @@ -use backend::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned}; -use core::{FourierGLWESecret, GGSWCiphertext, GLWECiphertext, GLWESecret, Infos}; +use core::{GGSWCiphertext, GGSWCiphertextExec, GLWECiphertext, GLWESecret, GLWESecretExec, Infos}; +use std::hint::black_box; + +use backend::{ + hal::{ + api::{ModuleNew, ScalarZnxAlloc, ScratchOwnedAlloc, ScratchOwnedBorrow}, + 
layouts::{Module, ScalarZnx, ScratchOwned}, + }, + implementation::cpu_spqlios::FFT64, +}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use sampling::source::Source; -use std::hint::black_box; fn bench_external_product_glwe_fft64(c: &mut Criterion) { let mut group = c.benchmark_group("external_product_glwe_fft64"); @@ -26,15 +33,15 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { let rank: usize = p.rank; let digits: usize = 1; - let rows: usize = 1; //(p.k_ct_in + p.basek - 1) / p.basek; + let rows: usize = 1; //(p.k_ct_in.div_ceil(p.basek); let sigma: f64 = 3.2; - let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); + let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct_in, rank); let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct_out, rank); - let pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let pt_rgsw: ScalarZnx> = module.scalar_znx_alloc(1); - let mut scratch = ScratchOwned::new( + let mut scratch: ScratchOwned = ScratchOwned::alloc( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe_in.k()) | GLWECiphertext::external_product_scratch_space( @@ -54,7 +61,7 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); + let sk_dft: GLWESecretExec, FFT64> = GLWESecretExec::from(&module, &sk); ct_ggsw.encrypt_sk( &module, @@ -75,8 +82,10 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { scratch.borrow(), ); + let ggsw_exec: GGSWCiphertextExec, FFT64> = GGSWCiphertextExec::from(&module, &ct_ggsw, scratch.borrow()); + move || { - black_box(ct_glwe_out.external_product(&module, &ct_glwe_in, &ct_ggsw, scratch.borrow())); + black_box(ct_glwe_out.external_product(&module, &ct_glwe_in, &ggsw_exec, scratch.borrow())); } } @@ -118,14 +127,14 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { let rank: usize = p.rank; let digits: usize = 1; - let rows: usize = (p.k_ct + p.basek - 1) / p.basek; + let rows: usize = p.k_ct.div_ceil(p.basek); let sigma: f64 = 3.2; - let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); + let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_glwe, rank); - let pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let pt_rgsw: ScalarZnx> = module.scalar_znx_alloc(1); - let mut scratch = ScratchOwned::new( + let mut scratch: ScratchOwned = ScratchOwned::alloc( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) | GLWECiphertext::external_product_inplace_scratch_space(&module, basek, ct_glwe.k(), ct_ggsw.k(), digits, rank), @@ -137,7 +146,7 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); + let sk_dft: GLWESecretExec, FFT64> = 
GLWESecretExec::from(&module, &sk); ct_ggsw.encrypt_sk( &module, @@ -158,9 +167,11 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { scratch.borrow(), ); + let ggsw_exec: GGSWCiphertextExec, FFT64> = GGSWCiphertextExec::from(&module, &ct_ggsw, scratch.borrow()); + move || { let scratch_borrow = scratch.borrow(); - black_box(ct_glwe.external_product_inplace(&module, &ct_ggsw, scratch_borrow)); + black_box(ct_glwe.external_product_inplace(&module, &ggsw_exec, scratch_borrow)); } } diff --git a/core/benches/keyswitch_glwe_fft64.rs b/core/benches/keyswitch_glwe_fft64.rs index 4acc754..696615d 100644 --- a/core/benches/keyswitch_glwe_fft64.rs +++ b/core/benches/keyswitch_glwe_fft64.rs @@ -1,8 +1,18 @@ -use backend::{FFT64, Module, ScratchOwned}; -use core::{FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWESecret, GLWESwitchingKey, Infos}; +use core::{ + AutomorphismKey, AutomorphismKeyExec, GLWECiphertext, GLWESecret, GLWESecretExec, GLWESwitchingKey, GLWESwitchingKeyExec, + Infos, +}; +use std::{hint::black_box, time::Duration}; + +use backend::{ + hal::{ + api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Module, ScratchOwned}, + }, + implementation::cpu_spqlios::FFT64, +}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use sampling::source::Source; -use std::{hint::black_box, time::Duration}; fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let mut group = c.benchmark_group("keyswitch_glwe_fft64"); @@ -29,15 +39,14 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let rank_out: usize = p.rank_out; let digits: usize = p.digits; - let rows: usize = (p.k_ct_in + (p.basek * digits) - 1) / (p.basek * digits); + let rows: usize = p.k_ct_in.div_ceil(p.basek * digits); let sigma: f64 = 3.2; - let mut ksk: GLWEAutomorphismKey, FFT64> = - GLWEAutomorphismKey::alloc(&module, basek, k_grlwe, rows, digits, rank_out); + let mut ksk: AutomorphismKey> = AutomorphismKey::alloc(&module, basek, k_grlwe, rows, digits, rank_out); let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_rlwe_in, rank_in); let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_rlwe_out, rank_out); - let mut scratch = ScratchOwned::new( + let mut scratch: ScratchOwned = ScratchOwned::alloc( GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank_in, rank_out) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) | GLWECiphertext::keyswitch_scratch_space( @@ -58,7 +67,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank_in); sk_in.fill_ternary_prob(0.5, &mut source_xs); - let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); + let sk_in_dft: GLWESecretExec, FFT64> = GLWESecretExec::from(&module, &sk_in); let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank_out); sk_out.fill_ternary_prob(0.5, &mut source_xs); @@ -82,8 +91,10 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { scratch.borrow(), ); + let ksk_exec: AutomorphismKeyExec, _> = AutomorphismKeyExec::from(&module, &ksk, scratch.borrow()); + move || { - black_box(ct_out.automorphism(&module, &ct_in, &ksk, scratch.borrow())); + black_box(ct_out.automorphism(&module, &ct_in, &ksk_exec, scratch.borrow())); } } @@ -132,13 +143,13 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { let rank: usize = p.rank; let digits: usize = 1; - let rows: usize = (p.k_ct + p.basek - 1) / p.basek; + let rows: usize = 
p.k_ct.div_ceil(p.basek); let sigma: f64 = 3.2; - let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut scratch = ScratchOwned::new( + let mut scratch: ScratchOwned = ScratchOwned::alloc( GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank, rank) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) | GLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct.k(), ksk.k(), digits, rank), @@ -150,16 +161,15 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank); sk_in.fill_ternary_prob(0.5, &mut source_xs); - let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); + let sk_in_dft: GLWESecretExec, FFT64> = GLWESecretExec::from(&module, &sk_in); let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank); sk_out.fill_ternary_prob(0.5, &mut source_xs); - let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); ksk.encrypt_sk( &module, &sk_in, - &sk_out_dft, + &sk_out, &mut source_xa, &mut source_xe, sigma, @@ -175,8 +185,10 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { scratch.borrow(), ); + let ksk_exec: GLWESwitchingKeyExec, FFT64> = GLWESwitchingKeyExec::from(&module, &ksk, scratch.borrow()); + move || { - black_box(ct.keyswitch_inplace(&module, &ksk, scratch.borrow())); + black_box(ct.keyswitch_inplace(&module, &ksk_exec, scratch.borrow())); } } diff --git a/core/src/automorphism.rs b/core/src/automorphism.rs deleted file mode 100644 index 3032532..0000000 --- a/core/src/automorphism.rs +++ /dev/null @@ -1,375 +0,0 @@ -use backend::{Backend, FFT64, MatZnxDft, MatZnxDftOps, Module, ScalarZnxOps, Scratch, VecZnx, VecZnxDftOps, VecZnxOps, ZnxZero}; -use sampling::source::Source; - -use crate::{ - FourierGLWECiphertext, GGLWECiphertext, GGSWCiphertext, GLWECiphertext, GLWESecret, GLWESwitchingKey, GetRow, Infos, - ScratchCore, SetRow, -}; - -pub struct AutomorphismKey { - pub(crate) key: GLWESwitchingKey, - pub(crate) p: i64, -} - -impl AutomorphismKey, FFT64> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { - AutomorphismKey { - key: GLWESwitchingKey::alloc(module, basek, k, rows, digits, rank, rank), - p: 0, - } - } - - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { - GLWESwitchingKey::, FFT64>::bytes_of(module, basek, k, rows, digits, rank, rank) - } -} - -impl Infos for AutomorphismKey { - type Inner = MatZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.key.inner() - } - - fn basek(&self) -> usize { - self.key.basek() - } - - fn k(&self) -> usize { - self.key.k() - } -} - -impl AutomorphismKey { - pub fn p(&self) -> i64 { - self.p - } - - pub fn digits(&self) -> usize { - self.key.digits() - } - - pub fn rank(&self) -> usize { - self.key.rank() - } - - pub fn rank_in(&self) -> usize { - self.key.rank_in() - } - - pub fn rank_out(&self) -> usize { - self.key.rank_out() - } -} - -impl> GetRow for AutomorphismKey { - fn get_row + AsRef<[u8]>>( - &self, - module: &Module, - row_i: usize, - col_j: usize, - res: &mut FourierGLWECiphertext, - ) { - module.mat_znx_dft_get_row(&mut res.data, &self.key.0.data, row_i, col_j); - } -} - 
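// --- Editor's aside (not part of the diff): how many rows a gadget key has. ---
// `get_row`/`set_row` here index the gadget decomposition of the key: a ciphertext
// carrying k bits of precision in base 2^basek, grouped `digits` limbs at a time,
// needs ceil(k / (basek * digits)) rows, which is what the benches above compute
// with `div_ceil`. A standalone model (hypothetical helper):
fn gadget_rows(k: usize, basek: usize, digits: usize) -> usize {
    k.div_ceil(basek * digits)
}

fn main() {
    assert_eq!(gadget_rows(54, 18, 1), 3); // k = 54, basek = 18: exactly 3 rows
    assert_eq!(gadget_rows(55, 18, 1), 4); // one extra bit forces a fourth row
    assert_eq!(gadget_rows(54, 18, 2), 2); // digits = 2 halves the row count
}
// --- End of editor's aside. ---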
-impl + AsRef<[u8]>> SetRow for AutomorphismKey { - fn set_row>( - &mut self, - module: &Module, - row_i: usize, - col_j: usize, - a: &FourierGLWECiphertext, - ) { - module.mat_znx_dft_set_row(&mut self.key.0.data, row_i, col_j, &a.data); - } -} - -impl AutomorphismKey, FFT64> { - pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - GGLWECiphertext::generate_from_sk_scratch_space(module, basek, k, rank) + GLWESecret::bytes_of(module, rank) - } - - pub fn generate_from_pk_scratch_space(module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { - GGLWECiphertext::generate_from_pk_scratch_space(module, _basek, _k, _rank) - } - - pub fn keyswitch_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize { - GLWESwitchingKey::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) - } - - pub fn keyswitch_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize { - GLWESwitchingKey::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank) - } - - pub fn automorphism_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize { - let tmp_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_in, rank); - let tmp_idft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); - let idft: usize = module.vec_znx_idft_tmp_bytes(); - let keyswitch: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank); - tmp_dft + tmp_idft + idft + keyswitch - } - - pub fn automorphism_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize { - AutomorphismKey::automorphism_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank) - } - - pub fn external_product_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - ggsw_k: usize, - digits: usize, - rank: usize, - ) -> usize { - GLWESwitchingKey::external_product_scratch_space(module, basek, k_out, k_in, ggsw_k, digits, rank) - } - - pub fn external_product_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - ggsw_k: usize, - digits: usize, - rank: usize, - ) -> usize { - GLWESwitchingKey::external_product_inplace_scratch_space(module, basek, k_out, ggsw_k, digits, rank) - } -} - -impl + AsRef<[u8]>> AutomorphismKey { - pub fn generate_from_sk>( - &mut self, - module: &Module, - p: i64, - sk: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(self.n(), module.n()); - assert_eq!(sk.n(), module.n()); - assert_eq!(self.rank_out(), self.rank_in()); - assert_eq!(sk.rank(), self.rank()); - assert!( - scratch.available() - >= AutomorphismKey::generate_from_sk_scratch_space(module, self.basek(), self.k(), self.rank()), - "scratch.available(): {} < AutomorphismKey::generate_from_sk_scratch_space(module, self.rank()={}, \ - self.size()={}): {}", - scratch.available(), - self.rank(), - self.size(), - AutomorphismKey::generate_from_sk_scratch_space(module, self.basek(), self.k(), self.rank()) - ) - } - - let (mut sk_out, scratch_1) = scratch.tmp_sk(module, sk.rank()); - (0..self.rank()).for_each(|i| { - module.scalar_znx_automorphism( - 
module.galois_element_inv(p), - &mut sk_out.data, - i, - &sk.data, - i, - ); - }); - - sk_out.prep_fourier(module); - - self.key - .generate_from_sk(module, &sk, &sk_out, source_xa, source_xe, sigma, scratch_1); - - self.p = p; - } -} - -impl + AsRef<[u8]>> AutomorphismKey { - pub fn automorphism, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &AutomorphismKey, - rhs: &AutomorphismKey, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!( - self.rank_in(), - lhs.rank_in(), - "ksk_out input rank: {} != ksk_in input rank: {}", - self.rank_in(), - lhs.rank_in() - ); - assert_eq!( - lhs.rank_out(), - rhs.rank_in(), - "ksk_in output rank: {} != ksk_apply input rank: {}", - self.rank_out(), - rhs.rank_in() - ); - assert_eq!( - self.rank_out(), - rhs.rank_out(), - "ksk_out output rank: {} != ksk_apply output rank: {}", - self.rank_out(), - rhs.rank_out() - ); - assert!( - self.k() <= lhs.k(), - "output k={} cannot be greater than input k={}", - self.k(), - lhs.k() - ) - } - - let cols_out: usize = rhs.rank_out() + 1; - - (0..self.rank_in()).for_each(|col_i| { - (0..self.rows()).for_each(|row_j| { - let (mut tmp_idft_data, scratct1) = scratch.tmp_vec_znx_big(module, cols_out, self.size()); - - { - let (mut tmp_dft, scratch2) = scratct1.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); - - // Extracts relevant row - lhs.get_row(module, row_j, col_i, &mut tmp_dft); - - // Get a VecZnxBig from scratch space - - // Switches input outside of DFT - (0..cols_out).for_each(|i| { - module.vec_znx_idft(&mut tmp_idft_data, i, &tmp_dft.data, i, scratch2); - }); - } - - // Consumes to small vec znx - let mut tmp_idft_small_data: VecZnx<&mut [u8]> = tmp_idft_data.to_vec_znx_small(); - - // Reverts the automorphis key from (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) - (0..cols_out).for_each(|i| { - module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft_small_data, i); - }); - - // Wraps into ciphertext - let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { - data: tmp_idft_small_data, - basek: self.basek(), - k: self.k(), - }; - - // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) - tmp_idft.keyswitch_inplace(module, &rhs.key, scratct1); - - { - let (mut tmp_dft, _) = scratct1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); - - // Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a) - // and switches back to DFT domain - (0..self.rank_out() + 1).for_each(|i| { - module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft.data, i); - module.vec_znx_dft(1, 0, &mut tmp_dft.data, i, &tmp_idft.data, i); - }); - - // Sets back the relevant row - self.set_row(module, row_j, col_i, &tmp_dft); - } - }); - }); - - let (mut tmp_dft, _) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); - tmp_dft.data.zero(); - - (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { - (0..self.rank_in()).for_each(|col_j| { - self.set_row(module, row_i, col_j, &tmp_dft); - }); - }); - - self.p = (lhs.p * rhs.p) % (module.cyclotomic_order() as i64); - } - - pub fn automorphism_inplace>( - &mut self, - module: &Module, - rhs: &AutomorphismKey, - scratch: &mut Scratch, - ) { - unsafe { - let self_ptr: *mut AutomorphismKey = self as *mut AutomorphismKey; - self.automorphism(&module, &*self_ptr, rhs, scratch); - } - } - - pub fn keyswitch, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &AutomorphismKey, - rhs: &GLWESwitchingKey, - 
scratch: &mut Scratch, - ) { - self.key.keyswitch(module, &lhs.key, rhs, scratch); - } - - pub fn keyswitch_inplace>( - &mut self, - module: &Module, - rhs: &AutomorphismKey, - scratch: &mut Scratch, - ) { - self.key.keyswitch_inplace(module, &rhs.key, scratch); - } - - pub fn external_product, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &AutomorphismKey, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { - self.key.external_product(module, &lhs.key, rhs, scratch); - } - - pub fn external_product_inplace>( - &mut self, - module: &Module, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { - self.key.external_product_inplace(module, rhs, scratch); - } -} diff --git a/core/src/blind_rotation/cggi.rs b/core/src/blind_rotation/cggi.rs index d9ada60..796daef 100644 --- a/core/src/blind_rotation/cggi.rs +++ b/core/src/blind_rotation/cggi.rs @@ -1,18 +1,47 @@ -use backend::{ - FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnxAlloc, - VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxView, ZnxViewMut, ZnxZero, +use backend::hal::{ + api::{ + ScratchAvailable, SvpApply, SvpPPolAllocBytes, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, + TakeVecZnxSlice, VecZnxAddInplace, VecZnxAllocBytes, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, + VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, + VecZnxDftSubABInplace, VecZnxDftToVecZnxBig, VecZnxDftToVecZnxBigTmpBytes, VecZnxDftZero, VecZnxMulXpMinusOneInplace, + VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxSubABInplace, VmpApplyTmpBytes, ZnxView, ZnxZero, + }, + layouts::{Backend, DataMut, DataRef, Module, Scratch, SvpPPol}, }; use itertools::izip; use crate::{ - GLWECiphertext, GLWECiphertextToMut, GLWEOps, Infos, LWECiphertext, ScratchCore, - blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable}, + GLWECiphertext, GLWECiphertextToMut, GLWEExternalProductFamily, GLWEOps, Infos, LWECiphertext, TakeGLWECt, + blind_rotation::{key::BlindRotationKeyCGGIExec, lut::LookUpTable}, dist::Distribution, lwe::ciphertext::LWECiphertextToRef, }; -pub fn cggi_blind_rotate_scratch_space( - module: &Module, +pub trait CCGIBlindRotationFamily = VecZnxBigAllocBytes + + VecZnxDftAllocBytes + + SvpPPolAllocBytes + + VmpApplyTmpBytes + + VecZnxBigNormalizeTmpBytes + + VecZnxDftToVecZnxBigTmpBytes + + VecZnxDftToVecZnxBig + + VecZnxDftAdd + + VecZnxDftAddInplace + + VecZnxDftFromVecZnx + + VecZnxDftZero + + SvpApply + + VecZnxDftSubABInplace + + VecZnxBigAddSmallInplace + + GLWEExternalProductFamily + + VecZnxRotate + + VecZnxAddInplace + + VecZnxSubABInplace + + VecZnxNormalize + + VecZnxNormalizeInplace + + VecZnxCopy + + VecZnxMulXpMinusOneInplace; + +pub fn cggi_blind_rotate_scratch_space( + module: &Module, block_size: usize, extension_factor: usize, basek: usize, @@ -20,22 +49,24 @@ pub fn cggi_blind_rotate_scratch_space( k_brk: usize, rows: usize, rank: usize, -) -> usize { +) -> usize +where + Module: CCGIBlindRotationFamily + VecZnxAllocBytes, +{ let brk_size: usize = k_brk.div_ceil(basek); if block_size > 1 { let cols: usize = rank + 1; - let acc_dft: usize = module.bytes_of_vec_znx_dft(cols, rows) * extension_factor; - let acc_big: usize = module.bytes_of_vec_znx_big(1, brk_size); - let vmp_res: usize = module.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor; + let acc_dft: usize = module.vec_znx_dft_alloc_bytes(cols, rows) * 
extension_factor; + let acc_big: usize = module.vec_znx_big_alloc_bytes(1, brk_size); + let vmp_res: usize = module.vec_znx_dft_alloc_bytes(cols, brk_size) * extension_factor; + let vmp_xai: usize = module.vec_znx_dft_alloc_bytes(1, brk_size); let acc_dft_add: usize = vmp_res; - let xai_plus_y: usize = module.bytes_of_scalar_znx_dft(1); - let xai_plus_y_dft: usize = module.bytes_of_scalar_znx_dft(1); let vmp: usize = module.vmp_apply_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2) let acc: usize; if extension_factor > 1 { - acc = module.bytes_of_vec_znx(cols, k_res.div_ceil(basek)) * extension_factor; + acc = module.vec_znx_alloc_bytes(cols, k_res.div_ceil(basek)) * extension_factor; } else { acc = 0; } @@ -44,26 +75,30 @@ pub fn cggi_blind_rotate_scratch_space( + acc_dft + acc_dft_add + vmp_res - + xai_plus_y - + xai_plus_y_dft - + (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()))); + + vmp_xai + + (vmp + | (acc_big + + (module.vec_znx_big_normalize_tmp_bytes(module.n()) | module.vec_znx_dft_to_vec_znx_big_tmp_bytes()))); } else { - 2 * GLWECiphertext::bytes_of(module, basek, k_res, rank) + GLWECiphertext::bytes_of(module, basek, k_res, rank) + GLWECiphertext::external_product_scratch_space(module, basek, k_res, k_res, k_brk, 1, rank) } } -pub fn cggi_blind_rotate( - module: &Module, +pub fn cggi_blind_rotate( + module: &Module, res: &mut GLWECiphertext, lwe: &LWECiphertext, lut: &LookUpTable, - brk: &BlindRotationKeyCGGI, - scratch: &mut Scratch, + brk: &BlindRotationKeyCGGIExec, + scratch: &mut Scratch, ) where - DataRes: AsRef<[u8]> + AsMut<[u8]>, - DataIn: AsRef<[u8]>, - DataBrk: AsRef<[u8]>, + DataRes: DataMut, + DataIn: DataRef, + DataBrk: DataRef, + Module: CCGIBlindRotationFamily, + Scratch: + TakeVecZnxDftSlice + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnx + ScratchAvailable + TakeVecZnxSlice, { match brk.dist { Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) | Distribution::ZERO => { @@ -82,37 +117,36 @@ pub fn cggi_blind_rotate( } } -pub(crate) fn cggi_blind_rotate_block_binary_extended( - module: &Module, +pub(crate) fn cggi_blind_rotate_block_binary_extended( + module: &Module, res: &mut GLWECiphertext, lwe: &LWECiphertext, lut: &LookUpTable, - brk: &BlindRotationKeyCGGI, - scratch: &mut Scratch, + brk: &BlindRotationKeyCGGIExec, + scratch: &mut Scratch, ) where - DataRes: AsRef<[u8]> + AsMut<[u8]>, - DataIn: AsRef<[u8]>, - DataBrk: AsRef<[u8]>, + DataRes: DataMut, + DataIn: DataRef, + DataBrk: DataRef, + Module: CCGIBlindRotationFamily, + Scratch: TakeVecZnxDftSlice + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnxSlice, { let extension_factor: usize = lut.extension_factor(); let basek: usize = res.basek(); let rows: usize = brk.rows(); let cols: usize = res.rank() + 1; - let (mut acc, scratch1) = scratch.tmp_slice_vec_znx(extension_factor, module, cols, res.size()); - let (mut acc_dft, scratch2) = scratch1.tmp_slice_vec_znx_dft(extension_factor, module, cols, rows); - let (mut vmp_res, scratch3) = scratch2.tmp_slice_vec_znx_dft(extension_factor, module, cols, brk.size()); - let (mut acc_add_dft, scratch4) = scratch3.tmp_slice_vec_znx_dft(extension_factor, module, cols, brk.size()); - let (mut minus_one, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1); - let (mut xai_plus_y_dft, scratch6) = scratch5.tmp_scalar_znx_dft(module, 1); - - minus_one.raw_mut()[..module.n() >> 1].fill(-1.0); + let (mut acc, scratch1) = scratch.take_vec_znx_slice(extension_factor, 
module, cols, res.size()); + let (mut acc_dft, scratch2) = scratch1.take_vec_znx_dft_slice(extension_factor, module, cols, rows); + let (mut vmp_res, scratch3) = scratch2.take_vec_znx_dft_slice(extension_factor, module, cols, brk.size()); + let (mut acc_add_dft, scratch4) = scratch3.take_vec_znx_dft_slice(extension_factor, module, cols, brk.size()); + let (mut vmp_xai, scratch5) = scratch4.take_vec_znx_dft(module, 1, brk.size()); (0..extension_factor).for_each(|i| { acc[i].zero(); }); - let x_pow_a: &Vec, FFT64>>; + let x_pow_a: &Vec, B>>; if let Some(b) = &brk.x_pow_a { x_pow_a = b } else { @@ -149,9 +183,9 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended( .for_each(|(ai, ski)| { (0..extension_factor).for_each(|i| { (0..cols).for_each(|j| { - module.vec_znx_dft(1, 0, &mut acc_dft[i], j, &acc[i], j); + module.vec_znx_dft_from_vec_znx(1, 0, &mut acc_dft[i], j, &acc[i], j); }); - acc_add_dft[i].zero(); + module.vec_znx_dft_zero(&mut acc_add_dft[i]) }); // TODO: first & last iterations can be optimized @@ -162,19 +196,19 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended( // vmp_res = DFT(acc) * BRK[i] (0..extension_factor).for_each(|i| { - module.vmp_apply(&mut vmp_res[i], &acc_dft[i], &skii.data, scratch6); + module.vmp_apply(&mut vmp_res[i], &acc_dft[i], &skii.data, scratch5); }); // Trivial case: no rotation between polynomials, we can directly multiply with (X^{-ai} - 1) if ai_lo == 0 { - // Sets acc_add_dft[i] = (acc[i] * sk) * (X^{-ai} - 1) + // Sets acc_add_dft[i] = (acc[i] * sk) * X^{-ai} - (acc[i] * sk) if ai_hi != 0 { // DFT X^{-ai} - module.vec_znx_dft_add(&mut xai_plus_y_dft, 0, &x_pow_a[ai_hi], 0, &minus_one, 0); (0..extension_factor).for_each(|j| { (0..cols).for_each(|i| { - module.svp_apply_inplace(&mut vmp_res[j], i, &xai_plus_y_dft, 0); - module.vec_znx_dft_add_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i); + module.svp_apply(&mut vmp_xai, 0, &x_pow_a[ai_hi], 0, &vmp_res[j], i); + module.vec_znx_dft_add_inplace(&mut acc_add_dft[j], i, &vmp_xai, 0); + module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i); }); }); } @@ -184,32 +218,13 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended( // ring homomorphism R^{N} -> prod R^{N/extension_factor}, so we split the // computation in two steps: acc_add_dft = (acc * sk) * (-1) + (acc * sk) * X^{-ai} } else { - // Sets acc_add_dft[i] = acc[i] * sk - - // Sets acc_add_dft[0..ai_lo] -= acc[..ai_lo] * sk - if (ai_hi + 1) & (two_n - 1) != 0 { - for i in 0..ai_lo { - (0..cols).for_each(|k| { - module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k); - }); - } - } - - // Sets acc_add_dft[ai_lo..extension_factor] -= acc[ai_lo..extension_factor] * sk - if ai_hi != 0 { - for i in ai_lo..extension_factor { - (0..cols).for_each(|k: usize| { - module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k); - }); - } - } - // Sets acc_add_dft[0..ai_lo] += (acc[extension_factor - ai_lo..extension_factor] * sk) * X^{-ai+1} if (ai_hi + 1) & (two_n - 1) != 0 { for (i, j) in (0..ai_lo).zip(extension_factor - ai_lo..extension_factor) { (0..cols).for_each(|k| { - module.svp_apply_inplace(&mut vmp_res[j], k, &x_pow_a[ai_hi + 1], 0); - module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k); + module.svp_apply(&mut vmp_xai, 0, &x_pow_a[ai_hi + 1], 0, &vmp_res[j], k); + module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_xai, 0); + module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k); }); } } @@ -219,8 +234,9 @@ pub(crate) fn 
cggi_blind_rotate_block_binary_extended( // Sets acc_add_dft[ai_lo..extension_factor] += (acc[0..extension_factor - ai_lo] * sk) * X^{-ai} for (i, j) in (ai_lo..extension_factor).zip(0..extension_factor - ai_lo) { (0..cols).for_each(|k| { - module.svp_apply_inplace(&mut vmp_res[j], k, &x_pow_a[ai_hi], 0); - module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k); + module.svp_apply(&mut vmp_xai, 0, &x_pow_a[ai_hi], 0, &vmp_res[j], k); + module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_xai, 0); + module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k); }); } } @@ -228,11 +244,11 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended( }); { - let (mut acc_add_big, scratch7) = scratch6.tmp_vec_znx_big(module, 1, brk.size()); + let (mut acc_add_big, scratch7) = scratch5.take_vec_znx_big(module, 1, brk.size()); (0..extension_factor).for_each(|j| { (0..cols).for_each(|i| { - module.vec_znx_idft(&mut acc_add_big, 0, &acc_add_dft[j], i, scratch7); + module.vec_znx_dft_to_vec_znx_big(&mut acc_add_big, 0, &acc_add_dft[j], i, scratch7); module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, &acc[j], i); module.vec_znx_big_normalize(basek, &mut acc[j], i, &acc_add_big, 0, scratch7); }); @@ -245,17 +261,19 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended( }); } -pub(crate) fn cggi_blind_rotate_block_binary( - module: &Module, +pub(crate) fn cggi_blind_rotate_block_binary( + module: &Module, res: &mut GLWECiphertext, lwe: &LWECiphertext, lut: &LookUpTable, - brk: &BlindRotationKeyCGGI, - scratch: &mut Scratch, + brk: &BlindRotationKeyCGGIExec, + scratch: &mut Scratch, ) where - DataRes: AsRef<[u8]> + AsMut<[u8]>, - DataIn: AsRef<[u8]>, - DataBrk: AsRef<[u8]>, + DataRes: DataMut, + DataIn: DataRef, + DataBrk: DataRef, + Module: CCGIBlindRotationFamily, + Scratch: TakeVecZnxDft + TakeVecZnxBig, { let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); @@ -280,15 +298,12 @@ pub(crate) fn cggi_blind_rotate_block_binary( // ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)] - let (mut acc_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rows); - let (mut vmp_res, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, brk.size()); - let (mut acc_add_dft, scratch3) = scratch2.tmp_vec_znx_dft(module, cols, brk.size()); - let (mut minus_one, scratch4) = scratch3.tmp_scalar_znx_dft(module, 1); - let (mut xai_plus_y_dft, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1); + let (mut acc_dft, scratch1) = scratch.take_vec_znx_dft(module, cols, rows); + let (mut vmp_res, scratch2) = scratch1.take_vec_znx_dft(module, cols, brk.size()); + let (mut acc_add_dft, scratch3) = scratch2.take_vec_znx_dft(module, cols, brk.size()); + let (mut vmp_xai, scratch4) = scratch3.take_vec_znx_dft(module, 1, brk.size()); - minus_one.raw_mut()[..module.n() >> 1].fill(-1.0); - - let x_pow_a: &Vec, FFT64>>; + let x_pow_a: &Vec, B>>; if let Some(b) = &brk.x_pow_a { x_pow_a = b } else { @@ -301,50 +316,50 @@ pub(crate) fn cggi_blind_rotate_block_binary( ) .for_each(|(ai, ski)| { (0..cols).for_each(|j| { - module.vec_znx_dft(1, 0, &mut acc_dft, j, &out_mut.data, j); + module.vec_znx_dft_from_vec_znx(1, 0, &mut acc_dft, j, &out_mut.data, j); }); - acc_add_dft.zero(); + module.vec_znx_dft_zero(&mut acc_add_dft); izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| { let ai_pos: usize = ((aii + two_n as i64) & (two_n - 1) as i64) as usize; // vmp_res = DFT(acc) * BRK[i] - module.vmp_apply(&mut vmp_res, &acc_dft, 
&skii.data, scratch5); - - // DFT(X^ai -1) - module.vec_znx_dft_add(&mut xai_plus_y_dft, 0, &x_pow_a[ai_pos], 0, &minus_one, 0); + module.vmp_apply(&mut vmp_res, &acc_dft, &skii.data, scratch4); // DFT(X^ai -1) * (DFT(acc) * BRK[i]) (0..cols).for_each(|i| { - module.svp_apply_inplace(&mut vmp_res, i, &xai_plus_y_dft, 0); - module.vec_znx_dft_add_inplace(&mut acc_add_dft, i, &vmp_res, i); + module.svp_apply(&mut vmp_xai, 0, &x_pow_a[ai_pos], 0, &vmp_res, i); + module.vec_znx_dft_add_inplace(&mut acc_add_dft, i, &vmp_xai, 0); + module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft, i, &vmp_res, i); }); }); { - let (mut acc_add_big, scratch6) = scratch5.tmp_vec_znx_big(module, 1, brk.size()); + let (mut acc_add_big, scratch5) = scratch4.take_vec_znx_big(module, 1, brk.size()); (0..cols).for_each(|i| { - module.vec_znx_idft(&mut acc_add_big, 0, &acc_add_dft, i, scratch6); + module.vec_znx_dft_to_vec_znx_big(&mut acc_add_big, 0, &acc_add_dft, i, scratch5); module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, &out_mut.data, i); - module.vec_znx_big_normalize(basek, &mut out_mut.data, i, &acc_add_big, 0, scratch6); + module.vec_znx_big_normalize(basek, &mut out_mut.data, i, &acc_add_big, 0, scratch5); }); } }); } -pub(crate) fn cggi_blind_rotate_binary_standard( - module: &Module, +pub(crate) fn cggi_blind_rotate_binary_standard( + module: &Module, res: &mut GLWECiphertext, lwe: &LWECiphertext, lut: &LookUpTable, - brk: &BlindRotationKeyCGGI, - scratch: &mut Scratch, + brk: &BlindRotationKeyCGGIExec, + scratch: &mut Scratch, ) where - DataRes: AsRef<[u8]> + AsMut<[u8]>, - DataIn: AsRef<[u8]>, - DataBrk: AsRef<[u8]>, + DataRes: DataMut, + DataIn: DataRef, + DataBrk: DataRef, + Module: CCGIBlindRotationFamily, + Scratch: TakeVecZnxDft + TakeVecZnxBig + TakeVecZnx + ScratchAvailable, { #[cfg(debug_assertions)] { @@ -401,28 +416,24 @@ pub(crate) fn cggi_blind_rotate_binary_standard( module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut.data[0], 0); // ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)] - let (mut acc_tmp, scratch1) = scratch.tmp_glwe_ct(module, basek, out_mut.k(), out_mut.rank()); - let (mut acc_tmp_rot, scratch2) = scratch1.tmp_glwe_ct(module, basek, out_mut.k(), out_mut.rank()); + let (mut acc_tmp, scratch1) = scratch.take_glwe_ct(module, basek, out_mut.k(), out_mut.rank()); // TODO: see if faster by skipping normalization in external product and keeping acc in big coeffs // TODO: first iteration can be optimized to be a gglwe product izip!(a.iter(), brk.data.iter()).for_each(|(ai, ski)| { // acc_tmp = sk[i] * acc - acc_tmp.external_product(module, &out_mut, ski, scratch2); + acc_tmp.external_product(module, &out_mut, ski, scratch1); - // acc_tmp = (sk[i] * acc) * X^{ai} - acc_tmp_rot.rotate(module, *ai, &acc_tmp); + // acc_tmp = (sk[i] * acc) * (X^{ai} - 1) + acc_tmp.mul_xp_minus_one_inplace(module, *ai); - // acc = acc + (sk[i] * acc) * X^{ai} - out_mut.add_inplace(module, &acc_tmp_rot); - - // acc = acc + (sk[i] * acc) * X^{ai} - (sk[i] * acc) = acc + (sk[i] * acc) * (X^{ai} - 1) - out_mut.sub_inplace_ab(module, &acc_tmp); + // acc = acc + (sk[i] * acc) * (X^{ai} - 1) + out_mut.add_inplace(module, &acc_tmp); }); // We can normalize only at the end because we add normalized values in [-2^{basek-1}, 2^{basek-1}] // on top of each others, thus ~ 2^{63-basek} additions are supported before overflow. 
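// --- Editor's aside (not part of the diff): the identity behind the loop above. ---
// For a binary secret digit s in {0, 1}, acc + s * acc * (X^a - 1) = acc * X^(s*a):
// the accumulator is rotated by a exactly when s = 1 and untouched when s = 0. A
// standalone plaintext model over Z[X]/(X^N + 1) (hypothetical helpers mirroring
// `rotate` and `mul_xp_minus_one_inplace`):
fn rotate(a: &[i64], k: i64) -> Vec<i64> {
    // Multiplication by X^k in the negacyclic ring: shift with sign wrap-around.
    let n = a.len();
    let two_n = (2 * n) as i64;
    let mut res = vec![0i64; n];
    for (i, &c) in a.iter().enumerate() {
        let e = (i as i64 + k).rem_euclid(two_n) as usize;
        if e < n {
            res[e] += c;
        } else {
            res[e - n] -= c;
        }
    }
    res
}

fn mul_xp_minus_one(a: &[i64], k: i64) -> Vec<i64> {
    // a * (X^k - 1) = rotate(a, k) - a.
    let mut res = rotate(a, k);
    for (r, &x) in res.iter_mut().zip(a) {
        *r -= x;
    }
    res
}

fn main() {
    let acc = vec![1i64, 2, 3, 4];
    let a = 3i64;
    // s = 1: acc + acc * (X^a - 1) collapses to the pure rotation acc * X^a.
    let mut updated = acc.clone();
    for (u, d) in updated.iter_mut().zip(mul_xp_minus_one(&acc, a)) {
        *u += d;
    }
    assert_eq!(updated, rotate(&acc, a));
}
// --- End of editor's aside. ---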
- out_mut.normalize_inplace(module, scratch2); + out_mut.normalize_inplace(module, scratch1); } pub(crate) fn negate_and_mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) { diff --git a/core/src/blind_rotation/key.rs b/core/src/blind_rotation/key.rs index 01511c3..d46b21b 100644 --- a/core/src/blind_rotation/key.rs +++ b/core/src/blind_rotation/key.rs @@ -1,39 +1,185 @@ -use backend::{ - Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxToRef, Scratch, - ZnxView, ZnxViewMut, +use backend::hal::{ + api::{ + MatZnxAlloc, ScalarZnxAlloc, ScratchAvailable, SvpPPolAlloc, SvpPrepare, TakeVecZnx, TakeVecZnxDft, + VecZnxAddScalarInplace, VecZnxAllocBytes, ZnxView, ZnxViewMut, + }, + layouts::{Backend, Data, DataMut, DataRef, Module, ReaderFrom, ScalarZnx, ScalarZnxToRef, Scratch, SvpPPol, WriterTo}, }; use sampling::source::Source; -use crate::{Distribution, FourierGLWESecret, GGSWCiphertext, Infos, LWESecret}; +use crate::{ + Distribution, GGSWCiphertext, GGSWCiphertextExec, GGSWEncryptSkFamily, GGSWLayoutFamily, GLWESecretExec, Infos, LWESecret, +}; -pub struct BlindRotationKeyCGGI { - pub(crate) data: Vec>, +pub struct BlindRotationKeyCGGI { + pub(crate) keys: Vec>, pub(crate) dist: Distribution, - pub(crate) x_pow_a: Option, B>>>, } -// pub struct BlindRotationKeyFHEW { -// pub(crate) data: Vec, B>>, -// pub(crate) auto: Vec, B>>, -//} +impl PartialEq for BlindRotationKeyCGGI { + fn eq(&self, other: &Self) -> bool { + if self.keys.len() != other.keys.len() { + return false; + } + for (a, b) in self.keys.iter().zip(other.keys.iter()) { + if a != b { + return false; + } + } + self.dist == other.dist + } +} -impl BlindRotationKeyCGGI, FFT64> { - pub fn allocate(module: &Module, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self { - let mut data: Vec, FFT64>> = Vec::with_capacity(n_lwe); +impl Eq for BlindRotationKeyCGGI {} + +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +impl ReaderFrom for BlindRotationKeyCGGI { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + match Distribution::read_from(reader) { + Ok(dist) => self.dist = dist, + Err(e) => return Err(e), + } + let len: usize = reader.read_u64::()? 
as usize; + if self.keys.len() != len { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("self.keys.len()={} != read len={}", self.keys.len(), len), + )); + } + for key in &mut self.keys { + key.read_from(reader)?; + } + Ok(()) + } +} + +impl WriterTo for BlindRotationKeyCGGI { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + match self.dist.write_to(writer) { + Ok(()) => {} + Err(e) => return Err(e), + } + writer.write_u64::(self.keys.len() as u64)?; + for key in &self.keys { + key.write_to(writer)?; + } + Ok(()) + } +} + +impl BlindRotationKeyCGGI> { + pub fn alloc(module: &Module, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self + where + Module: MatZnxAlloc, + { + let mut data: Vec>> = Vec::with_capacity(n_lwe); (0..n_lwe).for_each(|_| data.push(GGSWCiphertext::alloc(module, basek, k, rows, 1, rank))); Self { - data, + keys: data, dist: Distribution::NONE, - x_pow_a: None::, FFT64>>>, } } - pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize + where + Module: GGSWEncryptSkFamily + VecZnxAllocBytes, + { GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k, rank) } } -impl> BlindRotationKeyCGGI { +impl BlindRotationKeyCGGI { + #[allow(dead_code)] + pub(crate) fn n(&self) -> usize { + self.keys[0].n() + } + + #[allow(dead_code)] + pub(crate) fn rows(&self) -> usize { + self.keys[0].rows() + } + + #[allow(dead_code)] + pub(crate) fn k(&self) -> usize { + self.keys[0].k() + } + + #[allow(dead_code)] + pub(crate) fn size(&self) -> usize { + self.keys[0].size() + } + + #[allow(dead_code)] + pub(crate) fn rank(&self) -> usize { + self.keys[0].rank() + } + + pub(crate) fn basek(&self) -> usize { + self.keys[0].basek() + } + + #[allow(dead_code)] + pub(crate) fn block_size(&self) -> usize { + match self.dist { + Distribution::BinaryBlock(value) => value, + _ => 1, + } + } +} + +impl BlindRotationKeyCGGI { + pub fn generate_from_sk( + &mut self, + module: &Module, + sk_glwe: &GLWESecretExec, + sk_lwe: &LWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + DataSkGLWE: DataRef, + DataSkLWE: DataRef, + Module: GGSWEncryptSkFamily + ScalarZnxAlloc + VecZnxAddScalarInplace, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.keys.len(), sk_lwe.n()); + assert_eq!(sk_glwe.n(), module.n()); + assert_eq!(sk_glwe.rank(), self.keys[0].rank()); + match sk_lwe.dist { + Distribution::BinaryBlock(_) + | Distribution::BinaryFixed(_) + | Distribution::BinaryProb(_) + | Distribution::ZERO => {} + _ => panic!( + "invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)" + ), + } + } + + self.dist = sk_lwe.dist; + + let mut pt: ScalarZnx> = module.scalar_znx_alloc(1); + let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data.to_ref(); + + self.keys.iter_mut().enumerate().for_each(|(i, ggsw)| { + pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i]; + ggsw.encrypt_sk(module, &pt, sk_glwe, source_xa, source_xe, sigma, scratch); + }); + } +} + +#[derive(PartialEq, Eq)] +pub struct BlindRotationKeyCGGIExec { + pub(crate) data: Vec>, + pub(crate) dist: Distribution, + pub(crate) x_pow_a: Option, B>>>, +} + +impl BlindRotationKeyCGGIExec { #[allow(dead_code)] pub(crate) fn n(&self) -> usize { self.data[0].n() @@ -71,52 +217,66 @@ impl> BlindRotationKeyCGGI { } } 
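// (Editorial sketch, not part of the patch:) a crate-internal round-trip through
// the WriterTo/ReaderFrom impls above, inside a fn returning std::io::Result<()>.
// read_from fills a pre-allocated key and rejects a length mismatch with
// ErrorKind::InvalidData; `brk`, `module` and the size parameters are assumed to
// be in scope, as in the tests added later in this diff.
//
//     use std::io::Cursor;
//
//     let mut bytes: Vec<u8> = Vec::new();
//     brk.write_to(&mut bytes)?;
//
//     // The receiver must allocate a key of identical shape before reading.
//     let mut brk2 = BlindRotationKeyCGGI::alloc(module, n_lwe, basek, k_brk, rows_brk, rank);
//     brk2.read_from(&mut Cursor::new(bytes.as_slice()))?;
//     assert!(brk2 == brk); // PartialEq compares dist and each GGSW entry pairwise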
-impl + AsMut<[u8]>> BlindRotationKeyCGGI { - pub fn generate_from_sk( - &mut self, - module: &Module, - sk_glwe: &FourierGLWESecret, - sk_lwe: &LWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) where - DataSkGLWE: AsRef<[u8]>, - DataSkLWE: AsRef<[u8]>, +pub trait BlindRotationKeyCGGIExecLayoutFamily = GGSWLayoutFamily + SvpPPolAlloc + SvpPrepare; + +impl BlindRotationKeyCGGIExec, B> { + pub fn alloc(module: &Module, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self + where + Module: BlindRotationKeyCGGIExecLayoutFamily, + { + let mut data: Vec, B>> = Vec::with_capacity(n_lwe); + (0..n_lwe).for_each(|_| data.push(GGSWCiphertextExec::alloc(module, basek, k, rows, 1, rank))); + Self { + data, + dist: Distribution::NONE, + x_pow_a: None, + } + } + + pub fn from(module: &Module, other: &BlindRotationKeyCGGI, scratch: &mut Scratch) -> Self + where + DataOther: DataRef, + Module: BlindRotationKeyCGGIExecLayoutFamily + ScalarZnxAlloc, + { + let mut brk: BlindRotationKeyCGGIExec, B> = Self::alloc( + module, + other.keys.len(), + other.basek(), + other.k(), + other.rows(), + other.rank(), + ); + brk.prepare(module, other, scratch); + brk + } +} + +impl BlindRotationKeyCGGIExec { + pub fn prepare(&mut self, module: &Module, other: &BlindRotationKeyCGGI, scratch: &mut Scratch) + where + DataOther: DataRef, + Module: BlindRotationKeyCGGIExecLayoutFamily + ScalarZnxAlloc, { #[cfg(debug_assertions)] { - assert_eq!(self.data.len(), sk_lwe.n()); - assert_eq!(sk_glwe.n(), module.n()); - assert_eq!(sk_glwe.rank(), self.data[0].rank()); - match sk_lwe.dist { - Distribution::BinaryBlock(_) - | Distribution::BinaryFixed(_) - | Distribution::BinaryProb(_) - | Distribution::ZERO => {} - _ => panic!( - "invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)" - ), - } + assert_eq!(self.data.len(), other.keys.len()); } - self.dist = sk_lwe.dist; + self.data + .iter_mut() + .zip(other.keys.iter()) + .for_each(|(ggsw_exec, other)| { + ggsw_exec.prepare(module, other, scratch); + }); - let mut pt: ScalarZnx> = module.new_scalar_znx(1); - let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data.to_ref(); + self.dist = other.dist; - self.data.iter_mut().enumerate().for_each(|(i, ggsw)| { - pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i]; - ggsw.encrypt_sk(module, &pt, sk_glwe, source_xa, source_xe, sigma, scratch); - }); - - match sk_lwe.dist { + match other.dist { Distribution::BinaryBlock(_) => { - let mut x_pow_a: Vec, FFT64>> = Vec::with_capacity(module.n() << 1); - let mut buf: ScalarZnx> = module.new_scalar_znx(1); + let mut x_pow_a: Vec, B>> = Vec::with_capacity(module.n() << 1); + let mut buf: ScalarZnx> = module.scalar_znx_alloc(1); (0..module.n() << 1).for_each(|i| { - let mut res: ScalarZnxDft, FFT64> = module.new_scalar_znx_dft(1); + let mut res: SvpPPol, B> = module.svp_ppol_alloc(1); set_xai_plus_y(module, i, 0, &mut res, &mut buf); x_pow_a.push(res); }); @@ -127,10 +287,11 @@ impl + AsMut<[u8]>> BlindRotationKeyCGGI { } } -pub fn set_xai_plus_y(module: &Module, ai: usize, y: i64, res: &mut ScalarZnxDft, buf: &mut ScalarZnx) +pub fn set_xai_plus_y(module: &Module, ai: usize, y: i64, res: &mut SvpPPol, buf: &mut ScalarZnx) where - A: AsRef<[u8]> + AsMut<[u8]>, - B: AsRef<[u8]> + AsMut<[u8]>, + A: DataMut, + C: DataMut, + Module: SvpPrepare, { let n: usize = module.n(); diff --git a/core/src/blind_rotation/lut.rs b/core/src/blind_rotation/lut.rs index 300aa6e..90bc807 100644 --- 
a/core/src/blind_rotation/lut.rs +++ b/core/src/blind_rotation/lut.rs @@ -1,4 +1,11 @@ -use backend::{FFT64, Module, VecZnx, VecZnxAlloc, VecZnxOps, ZnxInfos, ZnxViewMut, alloc_aligned}; +use backend::hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAlloc, VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, + VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxViewMut, + }, + layouts::{Backend, Module, ScratchOwned, VecZnx}, + oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, +}; pub struct LookUpTable { pub(crate) data: Vec>>, @@ -7,7 +14,10 @@ pub struct LookUpTable { } impl LookUpTable { - pub fn alloc(module: &Module, basek: usize, k: usize, extension_factor: usize) -> Self { + pub fn alloc(module: &Module, basek: usize, k: usize, extension_factor: usize) -> Self + where + Module: VecZnxAlloc, + { #[cfg(debug_assertions)] { assert!( @@ -19,7 +29,7 @@ impl LookUpTable { let size: usize = k.div_ceil(basek); let mut data: Vec>> = Vec::with_capacity(extension_factor); (0..extension_factor).for_each(|_| { - data.push(module.new_vec_znx(1, size)); + data.push(module.vec_znx_alloc(1, size)); }); Self { data, basek, k } } @@ -36,7 +46,11 @@ impl LookUpTable { self.data.len() * self.data[0].n() } - pub fn set(&mut self, module: &Module, f: &Vec, k: usize) { + pub fn set(&mut self, module: &Module, f: &Vec, k: usize) + where + Module: VecZnxRotateInplace + VecZnxNormalizeInplace + VecZnxNormalizeTmpBytes + VecZnxSwithcDegree + VecZnxCopy, + B: ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, + { assert!(f.len() <= module.n()); let basek: usize = self.basek; @@ -74,16 +88,22 @@ impl LookUpTable { // Rotates half the step to the left let half_step: usize = domain_size.div_round(f_len << 1); - lut_full.rotate(-(half_step as i64)); + module.vec_znx_rotate_inplace(-(half_step as i64), &mut lut_full, 0); - let mut tmp_bytes: Vec = alloc_aligned(lut_full.n() * size_of::()); - lut_full.normalize(self.basek, 0, &mut tmp_bytes); + let n_large: usize = lut_full.n(); + + module.vec_znx_normalize_inplace( + self.basek, + &mut lut_full, + 0, + ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes(n_large)).borrow(), + ); if self.extension_factor() > 1 { (0..self.extension_factor()).for_each(|i| { - module.switch_degree(&mut self.data[i], 0, &lut_full, 0); + module.vec_znx_switch_degree(&mut self.data[i], 0, &lut_full, 0); if i < self.extension_factor() { - lut_full.rotate(-1); + module.vec_znx_rotate_inplace(-1, &mut lut_full, 0); } }); } else { @@ -92,7 +112,10 @@ impl LookUpTable { } #[allow(dead_code)] - pub(crate) fn rotate(&mut self, k: i64) { + pub(crate) fn rotate(&mut self, module: &Module, k: i64) + where + Module: VecZnxRotateInplace, + { let extension_factor: usize = self.extension_factor(); let two_n: usize = 2 * self.data[0].n(); let two_n_ext: usize = two_n * extension_factor; @@ -103,11 +126,11 @@ impl LookUpTable { let k_lo: usize = k_pos % extension_factor; (0..extension_factor - k_lo).for_each(|i| { - self.data[i].rotate(k_hi as i64); + module.vec_znx_rotate_inplace(k_hi as i64, &mut self.data[i], 0); }); (extension_factor - k_lo..extension_factor).for_each(|i| { - self.data[i].rotate(k_hi as i64 + 1); + module.vec_znx_rotate_inplace(k_hi as i64 + 1, &mut self.data[i], 0); }); self.data.rotate_right(k_lo as usize); diff --git a/core/src/blind_rotation/mod.rs b/core/src/blind_rotation/mod.rs index bbbdd2c..6a454e6 100644 --- a/core/src/blind_rotation/mod.rs +++ b/core/src/blind_rotation/mod.rs @@ -2,9 +2,9 @@ pub mod cggi; pub mod key; pub mod lut; -pub use 
cggi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space}; -pub use key::BlindRotationKeyCGGI; +pub use cggi::{CCGIBlindRotationFamily, cggi_blind_rotate, cggi_blind_rotate_scratch_space}; +pub use key::{BlindRotationKeyCGGI, BlindRotationKeyCGGIExec, BlindRotationKeyCGGIExecLayoutFamily}; pub use lut::LookUpTable; #[cfg(test)] -pub mod test_fft64; +mod test; diff --git a/core/src/blind_rotation/test/cggi.rs b/core/src/blind_rotation/test/cggi.rs new file mode 100644 index 0000000..a04d5c1 --- /dev/null +++ b/core/src/blind_rotation/test/cggi.rs @@ -0,0 +1,179 @@ +use backend::{ + hal::{ + api::{ + MatZnxAlloc, ModuleNew, ScalarZnxAlloc, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddNormal, + VecZnxAddScalarInplace, VecZnxAlloc, VecZnxAllocBytes, VecZnxEncodeCoeffsi64, VecZnxFillUniform, VecZnxRotateInplace, + VecZnxSwithcDegree, ZnxView, + }, + layouts::{Backend, Module, ScratchOwned}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, + TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, + }, + }, + implementation::cpu_spqlios::FFT64, }; +use sampling::source::Source; + +use crate::{ + BlindRotationKeyCGGIExecLayoutFamily, CCGIBlindRotationFamily, GLWECiphertext, GLWEDecryptFamily, GLWEPlaintext, GLWESecret, + GLWESecretExec, GLWESecretFamily, Infos, LWECiphertext, LWESecret, + blind_rotation::{ + cggi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space, negate_and_mod_switch_2n}, + key::{BlindRotationKeyCGGI, BlindRotationKeyCGGIExec}, + lut::LookUpTable, + }, + lwe::{LWEPlaintext, ciphertext::LWECiphertextToRef}, }; + +#[test] +fn standard() { + let module: Module = Module::::new(512); + blind_rotation_test(&module, 224, 1, 1); } + +#[test] +fn block_binary() { + let module: Module = Module::::new(512); + blind_rotation_test(&module, 224, 7, 1); } + +#[test] +fn block_binary_extended() { + let module: Module = Module::::new(512); + blind_rotation_test(&module, 224, 7, 2); } + +pub(crate) trait CGGITestModuleFamily = CCGIBlindRotationFamily + + GLWESecretFamily + + GLWEDecryptFamily + + BlindRotationKeyCGGIExecLayoutFamily + + VecZnxAlloc + + ScalarZnxAlloc + + VecZnxFillUniform + + VecZnxAddNormal + + VecZnxAllocBytes + + VecZnxAddScalarInplace + + VecZnxEncodeCoeffsi64 + + VecZnxRotateInplace + + VecZnxSwithcDegree + + MatZnxAlloc; +pub(crate) trait CGGITestScratchFamily = VecZnxDftAllocBytesImpl + + VecZnxBigAllocBytesImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeVecZnxDftSliceImpl + + ScratchAvailableImpl + + TakeVecZnxImpl + + TakeVecZnxSliceImpl; + +fn blind_rotation_test(module: &Module, n_lwe: usize, block_size: usize, extension_factor: usize) +where + Module: CGGITestModuleFamily, + B: CGGITestScratchFamily, +{ + let basek: usize = 19; + + let k_lwe: usize = 24; + let k_brk: usize = 3 * basek; + let rows_brk: usize = 2; // Ensures first limb is noise-free. 
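// (Editorial note on the parameter choices, inferred from this file:) the
// precision ladder is expressed in multiples of basek = 19: k_brk = 3 * basek
// bits for the blind-rotation key, and, just below, k_res = 2 * basek for the
// output GLWE and k_lut = basek for the LUT. rows_brk = 2 decomposition rows
// keep the accumulated blind-rotation noise out of the first 19-bit limb, which
// is what allows the final assert_eq! to demand bit-exact equality on that limb.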
+ let k_lut: usize = 1 * basek; + let k_res: usize = 2 * basek; + let rank: usize = 1; + + let message_modulus: usize = 1 << 4; + + let mut source_xs: Source = Source::new([2u8; 32]); + let mut source_xe: Source = Source::new([2u8; 32]); + let mut source_xa: Source = Source::new([1u8; 32]); + + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(module, rank); + sk_glwe.fill_ternary_prob(0.5, &mut source_xs); + let sk_glwe_dft: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_glwe); + + let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe); + sk_lwe.fill_binary_block(block_size, &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::::alloc(BlindRotationKeyCGGI::generate_from_sk_scratch_space( + module, basek, k_brk, rank, + )); + + let mut scratch_br: ScratchOwned = ScratchOwned::::alloc(cggi_blind_rotate_scratch_space( + module, + block_size, + extension_factor, + basek, + k_res, + k_brk, + rows_brk, + rank, + )); + + let mut brk: BlindRotationKeyCGGI> = BlindRotationKeyCGGI::alloc(module, n_lwe, basek, k_brk, rows_brk, rank); + + brk.generate_from_sk( + module, + &sk_glwe_dft, + &sk_lwe, + &mut source_xa, + &mut source_xe, + 3.2, + scratch.borrow(), + ); + + let mut lwe: LWECiphertext> = LWECiphertext::alloc(n_lwe, basek, k_lwe); + + let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe); + + let x: i64 = 2; + let bits: usize = 8; + + module.encode_coeff_i64(basek, &mut pt_lwe.data, 0, bits, 0, x, bits); + + lwe.encrypt_sk( + module, + &pt_lwe, + &sk_lwe, + &mut source_xa, + &mut source_xe, + 3.2, + ); + + let mut f: Vec = vec![0i64; message_modulus]; + f.iter_mut() + .enumerate() + .for_each(|(i, x)| *x = 2 * (i as i64) + 1); + + let mut lut: LookUpTable = LookUpTable::alloc(module, basek, k_lut, extension_factor); + lut.set(module, &f, message_modulus); + + let mut res: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_res, rank); + + let brk_exec: BlindRotationKeyCGGIExec, B> = BlindRotationKeyCGGIExec::from(module, &brk, scratch_br.borrow()); + + cggi_blind_rotate(module, &mut res, &lwe, &lut, &brk_exec, scratch_br.borrow()); + + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_res); + + res.decrypt(module, &mut pt_have, &sk_glwe_dft, scratch.borrow()); + + let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space + + negate_and_mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe.to_ref()); + + let pt_want: i64 = (lwe_2n[0] + + lwe_2n[1..] 
+ .iter() + .zip(sk_lwe.data.at(0, 0)) + .map(|(x, y)| x * y) + .sum::()) + & (2 * lut.domain_size() - 1) as i64; + + lut.rotate(module, pt_want); + + // First limb should be exactly equal (tests are parameterized such that the noise does not reach + the first limb) + assert_eq!(pt_have.data.at(0, 0), lut.data[0].at(0, 0)); +} diff --git a/core/src/blind_rotation/test_fft64/lut.rs b/core/src/blind_rotation/test/lut.rs similarity index 87% rename from core/src/blind_rotation/test_fft64/lut.rs rename to core/src/blind_rotation/test/lut.rs index 02f710d..bd893fc 100644 --- a/core/src/blind_rotation/test_fft64/lut.rs +++ b/core/src/blind_rotation/test/lut.rs @@ -1,6 +1,12 @@ use std::vec; -use backend::{FFT64, Module, ZnxView}; +use backend::{ + hal::{ + api::{ModuleNew, ZnxView}, + layouts::Module, + }, + implementation::cpu_spqlios::FFT64, +}; + use crate::blind_rotation::lut::{DivRound, LookUpTable}; @@ -23,7 +29,7 @@ fn standard() { lut.set(&module, &f, log_scale); let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; - lut.rotate(half_step); + lut.rotate(&module, half_step); let step: usize = lut.domain_size().div_round(message_modulus); @@ -33,7 +39,7 @@ fn standard() { f[i / step] % message_modulus as i64, lut.data[0].raw()[0] / (1 << (log_scale % basek)) as i64 ); - lut.rotate(-1); + lut.rotate(&module, -1); }); }); } @@ -57,7 +63,7 @@ fn extended() { lut.set(&module, &f, log_scale); let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; - lut.rotate(half_step); + lut.rotate(&module, half_step); let step: usize = lut.domain_size().div_round(message_modulus); @@ -67,7 +73,7 @@ fn extended() { f[i / step] % message_modulus as i64, lut.data[0].raw()[0] / (1 << (log_scale % basek)) as i64 ); - lut.rotate(-1); + lut.rotate(&module, -1); }); }); } diff --git a/core/src/blind_rotation/test_fft64/mod.rs b/core/src/blind_rotation/test/mod.rs similarity index 100% rename from core/src/blind_rotation/test_fft64/mod.rs rename to core/src/blind_rotation/test/mod.rs diff --git a/core/src/blind_rotation/test_fft64/cggi.rs b/core/src/blind_rotation/test_fft64/cggi.rs deleted file mode 100644 index 2fbad48..0000000 --- a/core/src/blind_rotation/test_fft64/cggi.rs +++ /dev/null @@ -1,125 +0,0 @@ -use backend::{Encoding, FFT64, Module, ScratchOwned, ZnxView}; -use sampling::source::Source; - -use crate::{ - FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, LWECiphertext, LWESecret, - blind_rotation::{ - cggi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space, negate_and_mod_switch_2n}, - key::BlindRotationKeyCGGI, - lut::LookUpTable, - }, - lwe::{LWEPlaintext, ciphertext::LWECiphertextToRef}, -}; - -#[test] -fn standard() { - blind_rotatio_test(224, 1, 1); -} - -#[test] -fn block_binary() { - blind_rotatio_test(224, 7, 1); -} - -#[test] -fn block_binary_extended() { - blind_rotatio_test(224, 7, 2); -} - -fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize) { - let module: Module = Module::::new(512); - let basek: usize = 19; - - let k_lwe: usize = 24; - let k_brk: usize = 3 * basek; - let rows_brk: usize = 2; // Ensures first limb is noise-free. 
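// (Editorial note:) the old test_fft64/cggi.rs being deleted here is superseded
// by the backend-generic test/cggi.rs added above. Call-site migrations visible
// in this diff include:
//     lut.rotate(half_step)                      ->  lut.rotate(&module, half_step)
//     pt_lwe.data.encode_coeff_i64(0, basek, ..) ->  module.encode_coeff_i64(basek, &mut pt_lwe.data, ..)
//     ScratchOwned::new(..)                      ->  ScratchOwned::alloc(..)
//     BlindRotationKeyCGGI::allocate(..)         ->  BlindRotationKeyCGGI::alloc(..)
//     lwe.encrypt_sk(&pt_lwe, ..)                ->  lwe.encrypt_sk(module, &pt_lwe, ..)
//     cggi_blind_rotate(.., &brk, ..)            ->  cggi_blind_rotate(.., &BlindRotationKeyCGGIExec::from(module, &brk, scratch), ..)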
- let k_lut: usize = 1 * basek; - let k_res: usize = 2 * basek; - let rank: usize = 1; - - let message_modulus: usize = 1 << 4; - - let mut source_xs: Source = Source::new([2u8; 32]); - let mut source_xe: Source = Source::new([2u8; 32]); - let mut source_xa: Source = Source::new([1u8; 32]); - - let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&module, rank); - sk_glwe.fill_ternary_prob(0.5, &mut source_xs); - let sk_glwe_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_glwe); - - let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe); - sk_lwe.fill_binary_block(block_size, &mut source_xs); - - let mut scratch: ScratchOwned = ScratchOwned::new(BlindRotationKeyCGGI::generate_from_sk_scratch_space( - &module, basek, k_brk, rank, - )); - - let mut scratch_br: ScratchOwned = ScratchOwned::new(cggi_blind_rotate_scratch_space( - &module, - block_size, - extension_factor, - basek, - k_res, - k_brk, - rows_brk, - rank, - )); - - let mut brk: BlindRotationKeyCGGI, FFT64> = - BlindRotationKeyCGGI::allocate(&module, n_lwe, basek, k_brk, rows_brk, rank); - - brk.generate_from_sk( - &module, - &sk_glwe_dft, - &sk_lwe, - &mut source_xa, - &mut source_xe, - 3.2, - scratch.borrow(), - ); - - let mut lwe: LWECiphertext> = LWECiphertext::alloc(n_lwe, basek, k_lwe); - - let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe); - - let x: i64 = 2; - let bits: usize = 8; - - pt_lwe.data.encode_coeff_i64(0, basek, bits, 0, x, bits); - - lwe.encrypt_sk(&pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe, 3.2); - - let mut f: Vec = vec![0i64; message_modulus]; - f.iter_mut() - .enumerate() - .for_each(|(i, x)| *x = 2 * (i as i64) + 1); - - let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor); - lut.set(&module, &f, message_modulus); - - let mut res: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_res, rank); - - cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow()); - - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_res); - - res.decrypt(&module, &mut pt_have, &sk_glwe_dft, scratch.borrow()); - - let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space - - negate_and_mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe.to_ref()); - - let pt_want: i64 = (lwe_2n[0] - + lwe_2n[1..] 
- .iter() - .zip(sk_lwe.data.at(0, 0)) - .map(|(x, y)| x * y) - .sum::()) - & (2 * lut.domain_size() - 1) as i64; - - lut.rotate(pt_want); - - // First limb should be exactly equal (test are parameterized such that the noise does not reach - // the first limb) - assert_eq!(pt_have.data.at(0, 0), lut.data[0].at(0, 0)); -} diff --git a/core/src/dist.rs b/core/src/dist.rs index 4a97369..a168530 100644 --- a/core/src/dist.rs +++ b/core/src/dist.rs @@ -1,3 +1,5 @@ +use std::io::{Read, Result, Write}; + #[derive(Clone, Copy, Debug)] pub(crate) enum Distribution { TernaryFixed(usize), // Ternary with fixed Hamming weight @@ -8,3 +10,75 @@ pub(crate) enum Distribution { ZERO, // Debug mod NONE, // Unitialized } + +const TAG_TERNARY_FIXED: u8 = 0; +const TAG_TERNARY_PROB: u8 = 1; +const TAG_BINARY_FIXED: u8 = 2; +const TAG_BINARY_PROB: u8 = 3; +const TAG_BINARY_BLOCK: u8 = 4; +const TAG_ZERO: u8 = 5; +const TAG_NONE: u8 = 6; + +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +impl Distribution { + pub fn write_to(&self, writer: &mut W) -> Result<()> { + let word: u64 = match self { + Distribution::TernaryFixed(v) => (TAG_TERNARY_FIXED as u64) << 56 | (*v as u64), + Distribution::TernaryProb(p) => { + let bits = p.to_bits(); // f64 -> u64 bit representation + (TAG_TERNARY_PROB as u64) << 56 | (bits & 0x00FF_FFFF_FFFF_FFFF) + } + Distribution::BinaryFixed(v) => (TAG_BINARY_FIXED as u64) << 56 | (*v as u64), + Distribution::BinaryProb(p) => { + let bits = p.to_bits(); + (TAG_BINARY_PROB as u64) << 56 | (bits & 0x00FF_FFFF_FFFF_FFFF) + } + Distribution::BinaryBlock(v) => (TAG_BINARY_BLOCK as u64) << 56 | (*v as u64), + Distribution::ZERO => (TAG_ZERO as u64) << 56, + Distribution::NONE => (TAG_NONE as u64) << 56, + }; + writer.write_u64::(word) + } + + pub fn read_from(reader: &mut R) -> Result { + let word = reader.read_u64::()?; + let tag = (word >> 56) as u8; + let payload = word & 0x00FF_FFFF_FFFF_FFFF; + + let dist = match tag { + TAG_TERNARY_FIXED => Distribution::TernaryFixed(payload as usize), + TAG_TERNARY_PROB => Distribution::TernaryProb(f64::from_bits(payload)), + TAG_BINARY_FIXED => Distribution::BinaryFixed(payload as usize), + TAG_BINARY_PROB => Distribution::BinaryProb(f64::from_bits(payload)), + TAG_BINARY_BLOCK => Distribution::BinaryBlock(payload as usize), + TAG_ZERO => Distribution::ZERO, + TAG_NONE => Distribution::NONE, + _ => { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Invalid tag", + )); + } + }; + Ok(dist) + } +} + +impl PartialEq for Distribution { + fn eq(&self, other: &Self) -> bool { + use Distribution::*; + match (self, other) { + (TernaryFixed(a), TernaryFixed(b)) => a == b, + (TernaryProb(a), TernaryProb(b)) => a.to_bits() == b.to_bits(), + (BinaryFixed(a), BinaryFixed(b)) => a == b, + (BinaryProb(a), BinaryProb(b)) => a.to_bits() == b.to_bits(), + (BinaryBlock(a), BinaryBlock(b)) => a == b, + (ZERO, ZERO) => true, + (NONE, NONE) => true, + _ => false, + } + } +} + +impl Eq for Distribution {} diff --git a/core/src/elem.rs b/core/src/elem.rs index 6e15616..6de038b 100644 --- a/core/src/elem.rs +++ b/core/src/elem.rs @@ -1,6 +1,4 @@ -use backend::{Backend, Module, ZnxInfos}; - -use crate::FourierGLWECiphertext; +use backend::hal::api::ZnxInfos; pub trait Infos { type Inner: ZnxInfos; @@ -54,15 +52,3 @@ pub trait SetMetaData { fn set_basek(&mut self, basek: usize); fn set_k(&mut self, k: usize); } - -pub trait GetRow { - fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut FourierGLWECiphertext) - where - R: 
AsMut<[u8]> + AsRef<[u8]>; -} - -pub trait SetRow { - fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &FourierGLWECiphertext) - where - R: AsRef<[u8]>; -} diff --git a/core/src/fourier_glwe/ciphertext.rs b/core/src/fourier_glwe/ciphertext.rs deleted file mode 100644 index a742e31..0000000 --- a/core/src/fourier_glwe/ciphertext.rs +++ /dev/null @@ -1,45 +0,0 @@ -use backend::{Backend, Module, VecZnxDft, VecZnxDftAlloc}; - -use crate::Infos; - -pub struct FourierGLWECiphertext { - pub data: VecZnxDft, - pub basek: usize, - pub k: usize, -} - -impl FourierGLWECiphertext, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { - Self { - data: module.new_vec_znx_dft(rank + 1, k.div_ceil(basek)), - basek: basek, - k: k, - } - } - - pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - module.bytes_of_vec_znx_dft(rank + 1, k.div_ceil(basek)) - } -} - -impl Infos for FourierGLWECiphertext { - type Inner = VecZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { - self.k - } -} - -impl FourierGLWECiphertext { - pub fn rank(&self) -> usize { - self.cols() - 1 - } -} diff --git a/core/src/fourier_glwe/decryption.rs b/core/src/fourier_glwe/decryption.rs deleted file mode 100644 index 6c18383..0000000 --- a/core/src/fourier_glwe/decryption.rs +++ /dev/null @@ -1,84 +0,0 @@ -use backend::{ - FFT64, Module, ScalarZnxDftOps, Scratch, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, - VecZnxDftOps, ZnxZero, -}; - -use crate::{FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, GLWEPlaintext, Infos}; - -impl FourierGLWECiphertext, FFT64> { - pub fn decrypt_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - let size: usize = k.div_ceil(basek); - (module.vec_znx_big_normalize_tmp_bytes() - | module.bytes_of_vec_znx_dft(1, size) - | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) - + module.bytes_of_vec_znx_big(1, size) - } -} - -impl> FourierGLWECiphertext { - pub fn decrypt + AsMut<[u8]>, DataSk: AsRef<[u8]>>( - &self, - module: &Module, - pt: &mut GLWEPlaintext, - sk: &FourierGLWESecret, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), sk.rank()); - assert_eq!(self.n(), module.n()); - assert_eq!(pt.n(), module.n()); - assert_eq!(sk.n(), module.n()); - } - - let cols = self.rank() + 1; - - let (mut pt_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct - pt_big.zero(); - - { - (1..cols).for_each(|i| { - let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct - module.svp_apply(&mut ci_dft, 0, &sk.data, i - 1, &self.data, i); - let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft); - module.vec_znx_big_add_inplace(&mut pt_big, 0, &ci_big, 0); - }); - } - - { - let (mut c0_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, self.size()); - // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) - module.vec_znx_idft(&mut c0_big, 0, &self.data, 0, scratch_2); - module.vec_znx_big_add_inplace(&mut pt_big, 0, &c0_big, 0); - } - - // pt = norm(BIG(m + e)) - module.vec_znx_big_normalize(self.basek(), &mut pt.data, 0, &mut pt_big, 0, scratch_1); - - pt.basek = self.basek(); - pt.k = pt.k().min(self.k()); - } - - #[allow(dead_code)] - pub(crate) fn idft + AsMut<[u8]>>( - &self, - module: &Module, - res: &mut 
GLWECiphertext, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), res.rank()); - assert_eq!(self.basek(), res.basek()) - } - - let min_size: usize = self.size().min(res.size()); - - let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 1, min_size); - - (0..self.rank() + 1).for_each(|i| { - module.vec_znx_idft(&mut res_big, 0, &self.data, i, scratch1); - module.vec_znx_big_normalize(self.basek(), &mut res.data, i, &res_big, 0, scratch1); - }); - } -} diff --git a/core/src/fourier_glwe/encryption.rs b/core/src/fourier_glwe/encryption.rs deleted file mode 100644 index fd08709..0000000 --- a/core/src/fourier_glwe/encryption.rs +++ /dev/null @@ -1,32 +0,0 @@ -use backend::{FFT64, Module, Scratch, VecZnxAlloc, VecZnxBigScratch, VecZnxDftOps}; -use sampling::source::Source; - -use crate::{FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, Infos, ScratchCore}; - -impl FourierGLWECiphertext, FFT64> { - #[allow(dead_code)] - pub(crate) fn idft_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - module.bytes_of_vec_znx(1, k.div_ceil(basek)) - + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) - } - - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - module.bytes_of_vec_znx(rank + 1, k.div_ceil(basek)) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) - } -} - -impl + AsRef<[u8]>> FourierGLWECiphertext { - pub fn encrypt_zero_sk>( - &mut self, - module: &Module, - sk: &FourierGLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - let (mut tmp_ct, scratch1) = scratch.tmp_glwe_ct(module, self.basek(), self.k(), self.rank()); - tmp_ct.encrypt_zero_sk(module, sk, source_xa, source_xe, sigma, scratch1); - tmp_ct.dft(module, self); - } -} diff --git a/core/src/fourier_glwe/external_product.rs b/core/src/fourier_glwe/external_product.rs deleted file mode 100644 index 01a7371..0000000 --- a/core/src/fourier_glwe/external_product.rs +++ /dev/null @@ -1,129 +0,0 @@ -use backend::{ - FFT64, MatZnxDftOps, MatZnxDftScratch, Module, Scratch, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, - VecZnxDftAlloc, VecZnxDftOps, -}; - -use crate::{FourierGLWECiphertext, GGSWCiphertext, Infos}; - -impl FourierGLWECiphertext, FFT64> { - // WARNING TODO: UPDATE - pub fn external_product_scratch_space( - module: &Module, - basek: usize, - _k_out: usize, - k_in: usize, - k_ggsw: usize, - digits: usize, - rank: usize, - ) -> usize { - let ggsw_size: usize = k_ggsw.div_ceil(basek); - let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, ggsw_size); - let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); - let ggsw_size: usize = k_ggsw.div_ceil(basek); - let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) - + module.vmp_apply_tmp_bytes(ggsw_size, in_size, in_size, rank + 1, rank + 1, ggsw_size); - let res_small: usize = module.bytes_of_vec_znx(rank + 1, ggsw_size); - let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); - res_dft + (vmp | (res_small + normalize)) - } - - pub fn external_product_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ggsw: usize, - digits: usize, - rank: usize, - ) -> usize { - Self::external_product_scratch_space(module, basek, k_out, k_out, k_ggsw, digits, rank) - } -} - -impl + AsRef<[u8]>> FourierGLWECiphertext { - pub fn external_product, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: 
&FourierGLWECiphertext, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { - let basek: usize = self.basek(); - - #[cfg(debug_assertions)] - { - assert_eq!(rhs.rank(), lhs.rank()); - assert_eq!(rhs.rank(), self.rank()); - assert_eq!(self.basek(), basek); - assert_eq!(lhs.basek(), basek); - assert_eq!(rhs.n(), module.n()); - assert_eq!(self.n(), module.n()); - assert_eq!(lhs.n(), module.n()); - assert!( - scratch.available() - >= FourierGLWECiphertext::external_product_scratch_space( - module, - self.basek(), - self.k(), - lhs.k(), - rhs.k(), - rhs.digits(), - rhs.rank(), - ) - ); - } - - let cols: usize = rhs.rank() + 1; - let digits = rhs.digits(); - - // Space for VMP result in DFT domain and high precision - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); - let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits - 1) / digits); - - { - (0..digits).for_each(|di| { - a_dft.set_size((lhs.size() + di) / digits); - - // Small optimization for digits > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. - // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last digits-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); - - (0..cols).for_each(|col_i| { - module.vec_znx_dft_copy(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); - }); - - if di == 0 { - module.vmp_apply(&mut res_dft, &a_dft, &rhs.data, scratch2); - } else { - module.vmp_apply_add(&mut res_dft, &a_dft, &rhs.data, di, scratch2); - } - }); - } - - // VMP result in high precision - let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); - - // Space for VMP result normalized - let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols, rhs.size()); - (0..cols).for_each(|i| { - module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2); - module.vec_znx_dft(1, 0, &mut self.data, i, &res_small, i); - }); - } - - pub fn external_product_inplace>( - &mut self, - module: &Module, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { - unsafe { - let self_ptr: *mut FourierGLWECiphertext = self as *mut FourierGLWECiphertext; - self.external_product(&module, &*self_ptr, rhs, scratch); - } - } -} diff --git a/core/src/fourier_glwe/keyswitch.rs b/core/src/fourier_glwe/keyswitch.rs deleted file mode 100644 index 3abb26e..0000000 --- a/core/src/fourier_glwe/keyswitch.rs +++ /dev/null @@ -1,56 +0,0 @@ -use backend::{FFT64, Module, Scratch}; - -use crate::{FourierGLWECiphertext, GLWECiphertext, GLWESwitchingKey, Infos, ScratchCore}; - -impl FourierGLWECiphertext, FFT64> { - pub fn keyswitch_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - GLWECiphertext::bytes_of(module, basek, k_out, rank_out) - + GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out) - } - - pub fn keyswitch_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize { - 
Self::keyswitch_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank, rank) - } -} - -impl + AsRef<[u8]>> FourierGLWECiphertext { - pub fn keyswitch, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &FourierGLWECiphertext, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - let (mut tmp_ct, scratch1) = scratch.tmp_glwe_ct(module, self.basek(), self.k(), self.rank()); - tmp_ct.keyswitch_from_fourier(module, lhs, rhs, scratch1); - tmp_ct.dft(module, self); - } - - pub fn keyswitch_inplace>( - &mut self, - module: &Module, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - unsafe { - let self_ptr: *mut FourierGLWECiphertext = self as *mut FourierGLWECiphertext; - self.keyswitch(&module, &*self_ptr, rhs, scratch); - } - } -} diff --git a/core/src/fourier_glwe/mod.rs b/core/src/fourier_glwe/mod.rs deleted file mode 100644 index 35c9905..0000000 --- a/core/src/fourier_glwe/mod.rs +++ /dev/null @@ -1,12 +0,0 @@ -pub mod ciphertext; -pub mod decryption; -pub mod encryption; -pub mod external_product; -pub mod keyswitch; -pub mod secret; - -pub use ciphertext::FourierGLWECiphertext; -pub use secret::FourierGLWESecret; - -#[cfg(test)] -pub mod test_fft64; diff --git a/core/src/fourier_glwe/secret.rs b/core/src/fourier_glwe/secret.rs deleted file mode 100644 index 0f28939..0000000 --- a/core/src/fourier_glwe/secret.rs +++ /dev/null @@ -1,58 +0,0 @@ -use backend::{Backend, FFT64, Module, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ZnxInfos}; - -use crate::{GLWESecret, dist::Distribution}; - -pub struct FourierGLWESecret { - pub(crate) data: ScalarZnxDft, - pub(crate) dist: Distribution, -} - -impl FourierGLWESecret, B> { - pub fn alloc(module: &Module, rank: usize) -> Self { - Self { - data: module.new_scalar_znx_dft(rank), - dist: Distribution::NONE, - } - } - - pub fn bytes_of(module: &Module, rank: usize) -> usize { - module.bytes_of_scalar_znx_dft(rank) - } -} - -impl FourierGLWESecret, FFT64> { - pub fn from(module: &Module, sk: &GLWESecret) -> Self - where - D: AsRef<[u8]>, - { - let mut sk_dft: FourierGLWESecret, FFT64> = Self::alloc(module, sk.rank()); - sk_dft.set(module, sk); - sk_dft - } -} - -impl FourierGLWESecret { - pub fn n(&self) -> usize { - self.data.n() - } - - pub fn log_n(&self) -> usize { - self.data.log_n() - } - - pub fn rank(&self) -> usize { - self.data.cols() - } -} - -impl + AsRef<[u8]>> FourierGLWESecret { - pub(crate) fn set(&mut self, module: &Module, sk: &GLWESecret) - where - D: AsRef<[u8]>, - { - (0..self.rank()).for_each(|i| { - module.svp_prepare(&mut self.data, i, &sk.data, i); - }); - self.dist = sk.dist - } -} diff --git a/core/src/fourier_glwe/test_fft64/external_product.rs b/core/src/fourier_glwe/test_fft64/external_product.rs deleted file mode 100644 index 80c9c9a..0000000 --- a/core/src/fourier_glwe/test_fft64/external_product.rs +++ /dev/null @@ -1,246 +0,0 @@ -use crate::{ - FourierGLWECiphertext, FourierGLWESecret, GGSWCiphertext, GLWECiphertext, GLWEOps, GLWEPlaintext, GLWESecret, Infos, - noise::noise_ggsw_product, -}; -use backend::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; -use sampling::source::Source; - -#[test] -fn apply() { - let log_n: usize = 8; - let basek: usize = 12; - let k_in: usize = 45; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ggsw: usize = k_in + basek * di; - println!("test external_product digits: {} rank: {}", di, rank); - let k_out: usize = k_ggsw; // Better 
capture noise. - test_apply(log_n, basek, k_out, k_in, k_ggsw, di, rank, 3.2); - }); - }); -} - -#[test] -fn apply_inplace() { - let log_n: usize = 8; - let basek: usize = 12; - let k_ct: usize = 60; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ggsw: usize = k_ct + basek * di; - println!("test external_product digits: {} rank: {}", di, rank); - test_apply_inplace(log_n, basek, k_ct, k_ggsw, di, rank, 3.2); - }); - }); -} - -fn test_apply(log_n: usize, basek: usize, k_out: usize, k_in: usize, k_ggsw: usize, digits: usize, rank: usize, sigma: f64) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k_in.div_ceil(digits * basek); - - let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank); - let mut ct_in_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_in, rank); - let mut ct_out_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_out, rank); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - pt_want.data.at_mut(0, 0)[1] = 1; - - let k: i64 = 1; - - pt_rgsw.raw_mut()[0] = 1; // X^{0} - module.vec_znx_rotate_inplace(k, &mut pt_rgsw, 0); // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) - | FourierGLWECiphertext::external_product_scratch_space( - &module, - basek, - ct_out.k(), - ct_in.k(), - ct_ggsw.k(), - digits, - rank, - ), - ); - - let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); - - ct_ggsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_in.encrypt_sk( - &module, - &pt_want, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_in.dft(&module, &mut ct_in_dft); - ct_out_dft.external_product(&module, &ct_in_dft, &ct_ggsw, scratch.borrow()); - ct_out_dft.idft(&module, &mut ct_out, scratch.borrow()); - - ct_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - pt_want.rotate_inplace(&module, k); - pt_have.sub_inplace_ab(&module, &pt_want); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_ggsw_product( - module.n() as f64, - basek * digits, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - rank as f64, - 
k_in, - k_ggsw, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); -} - -fn test_apply_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw: usize, digits: usize, rank: usize, sigma: f64) { - let module: Module = Module::::new(1 << log_n); - let rows: usize = k_ct.div_ceil(digits * basek); - - let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut ct_rlwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - pt_want.data.at_mut(0, 0)[1] = 1; - - let k: i64 = 1; - - pt_rgsw.raw_mut()[0] = 1; // X^{0} - module.vec_znx_rotate_inplace(k, &mut pt_rgsw, 0); // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) - | FourierGLWECiphertext::external_product_inplace_scratch_space(&module, basek, ct.k(), ct_ggsw.k(), digits, rank), - ); - - let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); - - ct_ggsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct.encrypt_sk( - &module, - &pt_want, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct.dft(&module, &mut ct_rlwe_dft); - ct_rlwe_dft.external_product_inplace(&module, &ct_ggsw, scratch.borrow()); - ct_rlwe_dft.idft(&module, &mut ct, scratch.borrow()); - - ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - pt_want.rotate_inplace(&module, k); - pt_have.sub_inplace_ab(&module, &pt_want); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_ggsw_product( - module.n() as f64, - basek * digits, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - rank as f64, - k_ct, - k_ggsw, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); - - println!("{} {}", noise_have, noise_want); -} diff --git a/core/src/fourier_glwe/test_fft64/keyswitch.rs b/core/src/fourier_glwe/test_fft64/keyswitch.rs deleted file mode 100644 index 61c8b27..0000000 --- a/core/src/fourier_glwe/test_fft64/keyswitch.rs +++ /dev/null @@ -1,235 +0,0 @@ -use crate::{ - FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, - noise::log2_std_noise_gglwe_product, -}; -use backend::{FFT64, FillUniform, Module, ScratchOwned, 
Stats, VecZnxOps}; -use sampling::source::Source; - -#[test] -fn apply() { - let log_n: usize = 8; - let basek: usize = 12; - let k_in: usize = 45; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank_in| { - (1..4).for_each(|rank_out| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_in + basek * di; - println!( - "test keyswitch digits: {} rank_in: {} rank_out: {}", - di, rank_in, rank_out - ); - let k_out: usize = k_ksk; // Better capture noise. - test_apply(log_n, basek, k_in, k_out, k_ksk, di, rank_in, rank_out, 3.2); - }) - }); - }); -} - -#[test] -fn apply_inplace() { - let log_n: usize = 8; - let basek: usize = 12; - let k_ct: usize = 45; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_ct + basek * di; - println!("test keyswitch_inplace digits: {} rank: {}", di, rank); - test_apply_inplace(log_n, basek, k_ct, k_ksk, di, rank, 3.2); - }); - }); -} - -fn test_apply( - log_n: usize, - basek: usize, - k_in: usize, - k_out: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k_in.div_ceil(basek * digits); - - let mut ksk: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank_in, rank_out); - let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank_in); - let mut ct_glwe_dft_in: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_in, rank_in); - let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank_out); - let mut ct_glwe_dft_out: FourierGLWECiphertext, FFT64> = - FourierGLWECiphertext::alloc(&module, basek, k_out, rank_out); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_in, rank_out) - | GLWECiphertext::decrypt_scratch_space(&module, basek, k_out) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, k_in) - | FourierGLWECiphertext::keyswitch_scratch_space( - &module, - basek, - ct_glwe_out.k(), - ksk.k(), - ct_glwe_in.k(), - digits, - rank_in, - rank_out, - ), - ); - - let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank_in); - sk_in.fill_ternary_prob(0.5, &mut source_xs); - let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); - - let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank_out); - sk_out.fill_ternary_prob(0.5, &mut source_xs); - let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); - - ksk.encrypt_sk( - &module, - &sk_in, - &sk_out_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_glwe_in.encrypt_sk( - &module, - &pt_want, - &sk_in_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_glwe_in.dft(&module, &mut ct_glwe_dft_in); - ct_glwe_dft_out.keyswitch(&module, &ct_glwe_dft_in, &ksk, scratch.borrow()); - ct_glwe_dft_out.idft(&module, &mut ct_glwe_out, scratch.borrow()); - - 
ct_glwe_out.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( - module.n() as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - rank_in as f64, - k_in, - k_ksk, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); -} - -fn test_apply_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, digits: usize, rank: usize, sigma: f64) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k_ct.div_ceil(basek * digits); - - let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut ct_rlwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank, rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) - | FourierGLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct_rlwe_dft.k(), ksk.k(), digits, rank), - ); - - let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank); - sk_in.fill_ternary_prob(0.5, &mut source_xs); - let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); - - let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank); - sk_out.fill_ternary_prob(0.5, &mut source_xs); - let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); - - ksk.encrypt_sk( - &module, - &sk_in, - &sk_out_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_glwe.encrypt_sk( - &module, - &pt_want, - &sk_in_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_glwe.dft(&module, &mut ct_rlwe_dft); - ct_rlwe_dft.keyswitch_inplace(&module, &ksk, scratch.borrow()); - ct_rlwe_dft.idft(&module, &mut ct_glwe, scratch.borrow()); - - ct_glwe.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( - module.n() as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - rank as f64, - k_ct, - k_ksk, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); -} diff --git a/core/src/fourier_glwe/test_fft64/mod.rs b/core/src/fourier_glwe/test_fft64/mod.rs deleted file mode 100644 index 784c37c..0000000 --- a/core/src/fourier_glwe/test_fft64/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod external_product; -pub mod keyswitch; diff --git a/core/src/gglwe/automorphism.rs b/core/src/gglwe/automorphism.rs index 
07fcb14..06c4f63 100644 --- a/core/src/gglwe/automorphism.rs +++ b/core/src/gglwe/automorphism.rs @@ -1,44 +1,52 @@ -use backend::{FFT64, Module, Scratch, VecZnx, VecZnxDftOps, VecZnxOps, ZnxZero}; +use backend::hal::{ + api::{ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphism, VecZnxAutomorphismInplace, ZnxZero}, + layouts::{Backend, DataMut, DataRef, Module, Scratch}, +}; -use crate::{FourierGLWECiphertext, GLWEAutomorphismKey, GLWECiphertext, GetRow, Infos, ScratchCore, SetRow}; +use crate::{AutomorphismKey, AutomorphismKeyExec, GLWECiphertext, GLWEKeyswitchFamily, Infos}; -impl GLWEAutomorphismKey, FFT64> { - pub fn automorphism_scratch_space( - module: &Module, +impl AutomorphismKey> { + pub fn automorphism_scratch_space( + module: &Module, basek: usize, k_out: usize, k_in: usize, k_ksk: usize, digits: usize, rank: usize, - ) -> usize { - let tmp_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_in, rank); - let tmp_idft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); - let idft: usize = module.vec_znx_idft_tmp_bytes(); - let keyswitch: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank); - tmp_dft + tmp_idft + idft + keyswitch + ) -> usize + where + Module: GLWEKeyswitchFamily, + { + GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) } - pub fn automorphism_inplace_scratch_space( - module: &Module, + pub fn automorphism_inplace_scratch_space( + module: &Module, basek: usize, k_out: usize, k_ksk: usize, digits: usize, rank: usize, - ) -> usize { - GLWEAutomorphismKey::automorphism_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank) + ) -> usize + where + Module: GLWEKeyswitchFamily, + { + AutomorphismKey::automorphism_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank) } } -impl + AsRef<[u8]>> GLWEAutomorphismKey { - pub fn automorphism, DataRhs: AsRef<[u8]>>( +impl AutomorphismKey { + pub fn automorphism<'a, DataLhs: DataRef, DataRhs: DataRef, B: Backend>( &mut self, - module: &Module, - lhs: &GLWEAutomorphismKey, - rhs: &GLWEAutomorphismKey, - scratch: &mut Scratch, - ) { + module: &Module, + lhs: &AutomorphismKey, + rhs: &AutomorphismKeyExec, + scratch: &mut Scratch, + ) where + Module: GLWEKeyswitchFamily + VecZnxAutomorphism + VecZnxAutomorphismInplace, + Scratch: TakeVecZnxDft + ScratchAvailable, + { #[cfg(debug_assertions)] { assert_eq!( @@ -72,78 +80,49 @@ impl + AsRef<[u8]>> GLWEAutomorphismKey { let cols_out: usize = rhs.rank_out() + 1; + let p: i64 = lhs.p(); + let p_inv = module.galois_element_inv(p); + (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { - let (mut tmp_idft_data, scratct1) = scratch.tmp_vec_znx_big(module, cols_out, self.size()); + let mut res_ct: GLWECiphertext<&mut [u8]> = self.at_mut(row_j, col_i); + let lhs_ct: GLWECiphertext<&[u8]> = lhs.at(row_j, col_i); - { - let (mut tmp_dft, scratch2) = scratct1.tmp_fourier_glwe_ct(module, lhs.basek(), lhs.k(), lhs.rank()); - - // Extracts relevant row - lhs.get_row(module, row_j, col_i, &mut tmp_dft); - - // Get a VecZnxBig from scratch space - - // Switches input outside of DFT - (0..cols_out).for_each(|i| { - module.vec_znx_idft(&mut tmp_idft_data, i, &tmp_dft.data, i, scratch2); - }); - } - - // Consumes to small vec znx - let mut tmp_idft_small_data: VecZnx<&mut [u8]> = tmp_idft_data.to_vec_znx_small(); - - // Reverts the automorphis key from (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) + // Reverts the automorphism X^{-k}: 
(-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) (0..cols_out).for_each(|i| { - module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft_small_data, i); + module.vec_znx_automorphism(lhs.p(), &mut res_ct.data, i, &lhs_ct.data, i); }); - // Wraps into ciphertext - let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { - data: tmp_idft_small_data, - basek: self.basek(), - k: self.k(), - }; - // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) - tmp_idft.keyswitch_inplace(module, &rhs.key, scratct1); + res_ct.keyswitch_inplace(module, &rhs.key, scratch); - { - let (mut tmp_dft, _) = scratct1.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); - - // Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a) - // and switches back to DFT domain - (0..self.rank_out() + 1).for_each(|i| { - module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft.data, i); - module.vec_znx_dft(1, 0, &mut tmp_dft.data, i, &tmp_idft.data, i); - }); - - // Sets back the relevant row - self.set_row(module, row_j, col_i, &tmp_dft); - } + // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a) + (0..cols_out).for_each(|i| { + module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i); + }); }); }); - let (mut tmp_dft, _) = scratch.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); - tmp_dft.data.zero(); - (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { (0..self.rank_in()).for_each(|col_j| { - self.set_row(module, row_i, col_j, &tmp_dft); + self.at_mut(row_i, col_j).data.zero(); }); }); self.p = (lhs.p * rhs.p) % (module.cyclotomic_order() as i64); } - pub fn automorphism_inplace>( + pub fn automorphism_inplace( &mut self, - module: &Module, - rhs: &GLWEAutomorphismKey, - scratch: &mut Scratch, - ) { + module: &Module, + rhs: &AutomorphismKeyExec, + scratch: &mut Scratch, + ) where + Module: GLWEKeyswitchFamily + VecZnxAutomorphism + VecZnxAutomorphismInplace, + Scratch: TakeVecZnxDft + ScratchAvailable, + { unsafe { - let self_ptr: *mut GLWEAutomorphismKey = self as *mut GLWEAutomorphismKey; + let self_ptr: *mut AutomorphismKey = self as *mut AutomorphismKey; self.automorphism(&module, &*self_ptr, rhs, scratch); } } diff --git a/core/src/gglwe/automorphism_key.rs b/core/src/gglwe/automorphism_key.rs index 38a7f6e..a671ad9 100644 --- a/core/src/gglwe/automorphism_key.rs +++ b/core/src/gglwe/automorphism_key.rs @@ -1,27 +1,37 @@ -use backend::{Backend, FFT64, MatZnxDft, MatZnxDftOps, Module}; +use backend::hal::{ + api::{MatZnxAlloc, MatZnxAllocBytes}, + layouts::{Backend, Data, DataMut, DataRef, MatZnx, Module, ReaderFrom, Scratch, VmpPMat, WriterTo}, +}; -use crate::{FourierGLWECiphertext, GLWESwitchingKey, GetRow, Infos, SetRow}; +use crate::{GGLWEExecLayoutFamily, GLWECiphertext, GLWESwitchingKey, GLWESwitchingKeyExec, Infos}; -pub struct GLWEAutomorphismKey { - pub(crate) key: GLWESwitchingKey, +#[derive(PartialEq, Eq)] +pub struct AutomorphismKey { + pub(crate) key: GLWESwitchingKey, pub(crate) p: i64, } -impl GLWEAutomorphismKey, FFT64> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { - GLWEAutomorphismKey { +impl AutomorphismKey> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self + where + Module: MatZnxAlloc, + { + AutomorphismKey { key: GLWESwitchingKey::alloc(module, basek, k, rows, digits, rank, rank), p: 
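The rewritten `automorphism` applies X^p, key-switches, applies X^{p^{-1}}, and then records the composed key as `(lhs.p * rhs.p) % cyclotomic_order`. That update is sound because Galois automorphisms X -> X^p of Z[X]/(X^N + 1) compose multiplicatively modulo the cyclotomic order 2N. A self-contained toy model (not the crate's API) that verifies this on coefficient vectors:

```rust
// Negacyclic automorphism X -> X^p on Z[X]/(X^N + 1), modeled on raw coefficients.
fn automorphism(poly: &[i64], p: usize) -> Vec<i64> {
    let n = poly.len();
    let two_n = 2 * n;
    let mut out = vec![0i64; n];
    for (i, &c) in poly.iter().enumerate() {
        let j = (i * p) % two_n; // X^i maps to X^{i*p}
        if j < n {
            out[j] += c;
        } else {
            out[j - n] -= c; // X^j = -X^{j-n} since X^N = -1
        }
    }
    out
}

fn main() {
    let n = 8;
    let poly: Vec<i64> = (0..n as i64).collect();
    let (p1, p2) = (3usize, 5usize);
    // Applying p1 then p2 equals applying p1*p2 mod 2N, matching
    // `self.p = (lhs.p * rhs.p) % module.cyclotomic_order()` above.
    let composed = automorphism(&automorphism(&poly, p1), p2);
    let direct = automorphism(&poly, (p1 * p2) % (2 * n));
    assert_eq!(composed, direct);
}
```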
0, } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { - GLWESwitchingKey::, FFT64>::bytes_of(module, basek, k, rows, digits, rank, rank) + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize + where + Module: MatZnxAllocBytes, + { + GLWESwitchingKey::>::bytes_of(module, basek, k, rows, digits, rank, rank) } } -impl Infos for GLWEAutomorphismKey { - type Inner = MatZnxDft; +impl Infos for AutomorphismKey { + type Inner = MatZnx; fn inner(&self) -> &Self::Inner { &self.key.inner() @@ -36,7 +46,7 @@ impl Infos for GLWEAutomorphismKey { } } -impl GLWEAutomorphismKey { +impl AutomorphismKey { pub fn p(&self) -> i64 { self.p } @@ -58,26 +68,120 @@ impl GLWEAutomorphismKey { } } -impl> GetRow for GLWEAutomorphismKey { - fn get_row + AsRef<[u8]>>( - &self, - module: &Module, - row_i: usize, - col_j: usize, - res: &mut FourierGLWECiphertext, - ) { - module.mat_znx_dft_get_row(&mut res.data, &self.key.key.data, row_i, col_j); +impl AutomorphismKey { + pub fn at(&self, row: usize, col: usize) -> GLWECiphertext<&[u8]> { + self.key.at(row, col) } } -impl + AsRef<[u8]>> SetRow for GLWEAutomorphismKey { - fn set_row>( - &mut self, - module: &Module, - row_i: usize, - col_j: usize, - a: &FourierGLWECiphertext, - ) { - module.mat_znx_dft_set_row(&mut self.key.key.data, row_i, col_j, &a.data); +impl AutomorphismKey { + pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertext<&mut [u8]> { + self.key.at_mut(row, col) + } +} + +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +impl ReaderFrom for AutomorphismKey { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.p = reader.read_u64::()? as i64; + self.key.read_from(reader) + } +} + +impl WriterTo for AutomorphismKey { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u64::(self.p as u64)?; + self.key.write_to(writer) + } +} + +#[derive(PartialEq, Eq)] +pub struct AutomorphismKeyExec { + pub(crate) key: GLWESwitchingKeyExec, + pub(crate) p: i64, +} + +impl AutomorphismKeyExec, B> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self + where + Module: GGLWEExecLayoutFamily, + { + AutomorphismKeyExec::, B> { + key: GLWESwitchingKeyExec::alloc(module, basek, k, rows, digits, rank, rank), + p: 0, + } + } + + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize + where + Module: GGLWEExecLayoutFamily, + { + GLWESwitchingKeyExec::, B>::bytes_of(module, basek, k, rows, digits, rank, rank) + } + + pub fn from(module: &Module, other: &AutomorphismKey, scratch: &mut Scratch) -> Self + where + Module: GGLWEExecLayoutFamily, + { + let mut atk_exec: AutomorphismKeyExec, B> = Self::alloc( + module, + other.basek(), + other.k(), + other.rows(), + other.digits(), + other.rank(), + ); + atk_exec.prepare(module, other, scratch); + atk_exec + } +} + +impl AutomorphismKeyExec { + pub fn prepare(&mut self, module: &Module, other: &AutomorphismKey, scratch: &mut Scratch) + where + DataOther: DataRef, + Module: GGLWEExecLayoutFamily, + { + self.key.prepare(module, &other.key, scratch); + self.p = other.p; + } +} + +impl Infos for AutomorphismKeyExec { + type Inner = VmpPMat; + + fn inner(&self) -> &Self::Inner { + &self.key.inner() + } + + fn basek(&self) -> usize { + self.key.basek() + } + + fn k(&self) -> usize { + self.key.k() + } +} + +impl AutomorphismKeyExec { + pub fn p(&self) -> 
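The new `ReaderFrom`/`WriterTo` impls serialize the exponent `p` as a fixed-width little-endian `u64` header before delegating to the inner switching key, using the `byteorder` dependency this PR adds. A self-contained sketch of the same pattern (the `Header` struct and helper names are illustrative, not the crate's types); note that round-tripping an `i64` through `u64` preserves negative values:

```rust
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{Cursor, Read, Write};

struct Header { p: i64, len: u64 }

fn write_header<W: Write>(w: &mut W, h: &Header) -> std::io::Result<()> {
    w.write_u64::<LittleEndian>(h.p as u64)?; // i64 stored through u64, as in the diff
    w.write_u64::<LittleEndian>(h.len)
}

fn read_header<R: Read>(r: &mut R) -> std::io::Result<Header> {
    Ok(Header {
        p: r.read_u64::<LittleEndian>()? as i64,
        len: r.read_u64::<LittleEndian>()?,
    })
}

fn main() -> std::io::Result<()> {
    let mut buf = Vec::new();
    write_header(&mut buf, &Header { p: -5, len: 42 })?;
    let h = read_header(&mut Cursor::new(buf))?;
    assert_eq!((h.p, h.len), (-5, 42)); // the u64 round-trip is lossless for i64
    Ok(())
}
```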
i64 { + self.p + } + + pub fn digits(&self) -> usize { + self.key.digits() + } + + pub fn rank(&self) -> usize { + self.key.rank() + } + + pub fn rank_in(&self) -> usize { + self.key.rank_in() + } + + pub fn rank_out(&self) -> usize { + self.key.rank_out() } } diff --git a/core/src/gglwe/ciphertext.rs b/core/src/gglwe/ciphertext.rs deleted file mode 100644 index 340b897..0000000 --- a/core/src/gglwe/ciphertext.rs +++ /dev/null @@ -1,131 +0,0 @@ -use backend::{Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, Module}; - -use crate::{FourierGLWECiphertext, GetRow, Infos, SetRow}; - -pub struct GGLWECiphertext { - pub(crate) data: MatZnxDft, - pub(crate) basek: usize, - pub(crate) k: usize, - pub(crate) digits: usize, -} - -impl GGLWECiphertext, B> { - pub fn alloc( - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> Self { - let size: usize = k.div_ceil(basek); - debug_assert!( - size > digits, - "invalid gglwe: ceil(k/basek): {} <= digits: {}", - size, - digits - ); - - assert!( - rows * digits <= size, - "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}", - rows, - digits, - size - ); - - Self { - data: module.new_mat_znx_dft(rows, rank_in, rank_out + 1, size), - basek: basek, - k, - digits, - } - } - - pub fn bytes_of( - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - let size: usize = k.div_ceil(basek); - debug_assert!( - size > digits, - "invalid gglwe: ceil(k/basek): {} <= digits: {}", - size, - digits - ); - - assert!( - rows * digits <= size, - "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}", - rows, - digits, - size - ); - - module.bytes_of_mat_znx_dft(rows, rank_in, rank_out + 1, rows) - } -} - -impl Infos for GGLWECiphertext { - type Inner = MatZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { - self.k - } -} - -impl GGLWECiphertext { - pub fn rank(&self) -> usize { - self.data.cols_out() - 1 - } - - pub fn digits(&self) -> usize { - self.digits - } - - pub fn rank_in(&self) -> usize { - self.data.cols_in() - } - - pub fn rank_out(&self) -> usize { - self.data.cols_out() - 1 - } -} - -impl> GetRow for GGLWECiphertext { - fn get_row + AsRef<[u8]>>( - &self, - module: &Module, - row_i: usize, - col_j: usize, - res: &mut FourierGLWECiphertext, - ) { - module.mat_znx_dft_get_row(&mut res.data, &self.data, row_i, col_j); - } -} - -impl + AsRef<[u8]>> SetRow for GGLWECiphertext { - fn set_row>( - &mut self, - module: &Module, - row_i: usize, - col_j: usize, - a: &FourierGLWECiphertext, - ) { - module.mat_znx_dft_set_row(&mut self.data, row_i, col_j, &a.data); - } -} diff --git a/core/src/gglwe/encryption.rs b/core/src/gglwe/encryption.rs index 3e0a7f4..dc54cce 100644 --- a/core/src/gglwe/encryption.rs +++ b/core/src/gglwe/encryption.rs @@ -1,41 +1,52 @@ -use backend::{ - FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScalarZnxOps, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxOps, - ZnxInfos, ZnxView, ZnxViewMut, ZnxZero, +use backend::hal::{ + api::{ + ScalarZnxAllocBytes, ScalarZnxAutomorphism, ScratchAvailable, SvpApply, TakeScalarZnx, TakeVecZnx, TakeVecZnxBig, + TakeVecZnxDft, VecZnxAddScalarInplace, VecZnxAllocBytes, VecZnxBigAllocBytes, VecZnxDftToVecZnxBigTmpA, + VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSwithcDegree, ZnxZero, + }, + layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, 
Scratch}, }; use sampling::source::Source; use crate::{ - FourierGLWESecret, GGLWECiphertext, GLWEAutomorphismKey, GLWECiphertext, GLWESecret, GLWESwitchingKey, GLWETensorKey, Infos, - ScratchCore, SetRow, + AutomorphismKey, GGLWECiphertext, GLWECiphertext, GLWEDecryptFamily, GLWEEncryptSkFamily, GLWEPlaintext, GLWESecret, + GLWESecretExec, GLWESecretFamily, GLWESwitchingKey, GLWETensorKey, Infos, TakeGLWEPt, TakeGLWESecret, TakeGLWESecretExec, }; -impl GGLWECiphertext, FFT64> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - let size = k.div_ceil(basek); +pub trait GGLWEEncryptSkFamily = GLWEEncryptSkFamily + GLWESecretFamily; + +impl GGLWECiphertext> { + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize + where + Module: GGLWEEncryptSkFamily + VecZnxAllocBytes, + { GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) - + module.bytes_of_vec_znx(rank + 1, size) - + module.bytes_of_vec_znx(1, size) - + module.bytes_of_vec_znx_dft(rank + 1, size) + + (GLWEPlaintext::byte_of(module, basek, k) | module.vec_znx_normalize_tmp_bytes(module.n())) } - pub fn encrypt_pk_scratch_space(_module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { + pub fn encrypt_pk_scratch_space(_module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { unimplemented!() } } -impl + AsRef<[u8]>> GGLWECiphertext { - pub fn encrypt_sk, DataSk: AsRef<[u8]>>( +impl GGLWECiphertext { + pub fn encrypt_sk( &mut self, - module: &Module, + module: &Module, pt: &ScalarZnx, - sk: &FourierGLWESecret, + sk: &GLWESecretExec, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, - scratch: &mut Scratch, - ) { + scratch: &mut Scratch, + ) where + Module: GGLWEEncryptSkFamily + VecZnxAllocBytes + VecZnxAddScalarInplace, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + { #[cfg(debug_assertions)] { + use backend::hal::api::ZnxInfos; + assert_eq!( self.rank_in(), pt.cols(), @@ -54,12 +65,12 @@ impl + AsRef<[u8]>> GGLWECiphertext { assert_eq!(sk.n(), module.n()); assert_eq!(pt.n(), module.n()); assert!( - scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()), + scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()), "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}", scratch.available(), self.rank(), self.size(), - GGLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()) + GGLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()) ); assert!( self.rows() * self.digits() * self.basek() <= self.k(), @@ -77,12 +88,8 @@ impl + AsRef<[u8]>> GGLWECiphertext { let basek: usize = self.basek(); let k: usize = self.k(); let rank_in: usize = self.rank_in(); - let rank_out: usize = self.rank_out(); - - let (mut tmp_pt, scrach_1) = scratch.tmp_glwe_pt(module, basek, k); - let (mut tmp_ct, scrach_2) = scrach_1.tmp_glwe_ct(module, basek, k, rank_out); - let (mut tmp_ct_dft, scratch_3) = scrach_2.tmp_fourier_glwe_ct(module, basek, k, rank_out); + let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(module, basek, k); // For each input column (i.e. 
rank) produces a GGLWE ciphertext of rank_out+1 columns // // Example for ksk rank 2 to rank 3: @@ -105,30 +112,36 @@ impl + AsRef<[u8]>> GGLWECiphertext { pt, col_i, ); - module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch_3); + module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scrach_1); // rlwe encrypt of vec_znx_pt into vec_znx_ct - tmp_ct.encrypt_sk(module, &tmp_pt, sk, source_xa, source_xe, sigma, scratch_3); - - // Switch vec_znx_ct into DFT domain - tmp_ct.dft(module, &mut tmp_ct_dft); - - // Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft - self.set_row(module, row_i, col_i, &tmp_ct_dft); + self.at_mut(row_i, col_i) + .encrypt_sk(module, &tmp_pt, sk, source_xa, source_xe, sigma, scrach_1); }); }); } } -impl GLWESwitchingKey, FFT64> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank_in: usize, rank_out: usize) -> usize { - GGLWECiphertext::encrypt_sk_scratch_space(module, basek, k, rank_out) - + module.bytes_of_scalar_znx(rank_in) - + FourierGLWESecret::bytes_of(module, rank_out) +pub trait GLWESwitchingKeyEncryptSkFamily = GGLWEEncryptSkFamily; + +impl GLWESwitchingKey> { + pub fn encrypt_sk_scratch_space( + module: &Module, + basek: usize, + k: usize, + rank_in: usize, + rank_out: usize, + ) -> usize + where + Module: GLWESwitchingKeyEncryptSkFamily + ScalarZnxAllocBytes + VecZnxAllocBytes, + { + (GGLWECiphertext::encrypt_sk_scratch_space(module, basek, k) | module.scalar_znx_alloc_bytes(1)) + + module.scalar_znx_alloc_bytes(rank_in) + + GLWESecretExec::bytes_of(module, rank_out) } - pub fn encrypt_pk_scratch_space( - module: &Module, + pub fn encrypt_pk_scratch_space( + module: &Module, _basek: usize, _k: usize, _rank_in: usize, @@ -138,46 +151,63 @@ impl GLWESwitchingKey, FFT64> { } } -impl + AsRef<[u8]>> GLWESwitchingKey { - pub fn encrypt_sk, DataSkOut: AsRef<[u8]>>( +impl GLWESwitchingKey { + pub fn encrypt_sk( &mut self, - module: &Module, + module: &Module, sk_in: &GLWESecret, - sk_out: &FourierGLWESecret, + sk_out: &GLWESecret, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, - scratch: &mut Scratch, - ) { + scratch: &mut Scratch, + ) where + Module: GLWESwitchingKeyEncryptSkFamily + + ScalarZnxAllocBytes + + VecZnxSwithcDegree + + VecZnxAllocBytes + + VecZnxAddScalarInplace, + Scratch: + ScratchAvailable + TakeScalarZnx + TakeVecZnxDft + TakeGLWESecretExec + ScratchAvailable + TakeVecZnx, + { #[cfg(debug_assertions)] { assert!(sk_in.n() <= module.n()); assert!(sk_out.n() <= module.n()); + assert!( + scratch.available() + >= GLWESwitchingKey::encrypt_sk_scratch_space( + module, + self.basek(), + self.k(), + self.rank_in(), + self.rank_out() + ), + "scratch.available()={} < GLWESwitchingKey::encrypt_sk_scratch_space={}", + scratch.available(), + GLWESwitchingKey::encrypt_sk_scratch_space( + module, + self.basek(), + self.k(), + self.rank_in(), + self.rank_out() + ) + ) } - let (mut sk_in_tmp, scratch1) = scratch.tmp_scalar_znx(module, sk_in.rank()); - sk_in_tmp.zero(); - + let (mut sk_in_tmp, scratch1) = scratch.take_scalar_znx(module, sk_in.rank()); (0..sk_in.rank()).for_each(|i| { - sk_in_tmp - .at_mut(i, 0) - .iter_mut() - .step_by(module.n() / sk_in.n()) - .zip(sk_in.data.at(i, 0).iter()) - .for_each(|(x, y)| *x = *y); + module.vec_znx_switch_degree(&mut sk_in_tmp, i, &sk_in.data, i); }); - let (mut sk_out_tmp, scratch2) = scratch1.tmp_fourier_glwe_secret(module, sk_out.rank()); - (0..sk_out.rank()).for_each(|i| { - sk_out_tmp - .data - .at_mut(i, 0) - .chunks_exact_mut(module.n() / 
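The deleted manual loop embedded a degree-n secret into the degree-N working ring (n dividing N) by writing its coefficients every N/n slots, i.e. mapping s(X) to s(X^{N/n}); the new code delegates this to `vec_znx_switch_degree`. A standalone sketch of what that `step_by` loop computed (assuming `vec_znx_switch_degree` performs the same embedding when increasing the degree):

```rust
// Embed a degree-n polynomial into degree big_n by spacing coefficients
// every big_n/n slots: s(X) -> s(X^{big_n/n}).
fn switch_degree_up(small: &[i64], big_n: usize) -> Vec<i64> {
    assert_eq!(big_n % small.len(), 0);
    let gap = big_n / small.len();
    let mut big = vec![0i64; big_n];
    for (slot, &c) in big.iter_mut().step_by(gap).zip(small.iter()) {
        *slot = c;
    }
    big
}

fn main() {
    // 1 + 2X in degree 2 becomes 1 + 2X^4 in degree 8.
    assert_eq!(switch_degree_up(&[1, 2], 8), vec![1, 0, 0, 0, 2, 0, 0, 0]);
}
```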
sk_out.n()) - .zip(sk_out.data.at(i, 0).iter()) - .for_each(|(a_chunk, &b_elem)| { - a_chunk.fill(b_elem); - }); - }); + let (mut sk_out_tmp, scratch2) = scratch1.take_glwe_secret_exec(module, sk_out.rank()); + { + let (mut tmp, _) = scratch2.take_scalar_znx(module, 1); + (0..sk_out.rank()).for_each(|i| { + module.vec_znx_switch_degree(&mut tmp, 0, &sk_out.data, i); + module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0); + }); + } self.key.encrypt_sk( module, @@ -193,27 +223,40 @@ impl + AsRef<[u8]>> GLWESwitchingKey { } } -impl GLWEAutomorphismKey, FFT64> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { +pub trait AutomorphismKeyEncryptSkFamily = GGLWEEncryptSkFamily; + +impl AutomorphismKey> { + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize + where + Module: AutomorphismKeyEncryptSkFamily + ScalarZnxAllocBytes + VecZnxAllocBytes, + { GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank, rank) + GLWESecret::bytes_of(module, rank) } - pub fn encrypt_pk_scratch_space(module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { + pub fn encrypt_pk_scratch_space(module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { GLWESwitchingKey::encrypt_pk_scratch_space(module, _basek, _k, _rank, _rank) } } -impl + AsRef<[u8]>> GLWEAutomorphismKey { - pub fn encrypt_sk>( +impl AutomorphismKey { + pub fn encrypt_sk( &mut self, - module: &Module, + module: &Module, p: i64, sk: &GLWESecret, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, - scratch: &mut Scratch, - ) { + scratch: &mut Scratch, + ) where + Module: AutomorphismKeyEncryptSkFamily + + ScalarZnxAutomorphism + + ScalarZnxAllocBytes + + VecZnxAllocBytes + + VecZnxSwithcDegree + + VecZnxAddScalarInplace, + Scratch: ScratchAvailable + TakeScalarZnx + TakeVecZnxDft + TakeGLWESecretExec + TakeVecZnx, + { #[cfg(debug_assertions)] { assert_eq!(self.n(), module.n()); @@ -221,19 +264,18 @@ impl + AsRef<[u8]>> GLWEAutomorphismKey { assert_eq!(self.rank_out(), self.rank_in()); assert_eq!(sk.rank(), self.rank()); assert!( - scratch.available() >= GLWEAutomorphismKey::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()), + scratch.available() >= AutomorphismKey::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()), "scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}", scratch.available(), self.rank(), self.size(), - GLWEAutomorphismKey::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()) + AutomorphismKey::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()) ) } - let (mut sk_out_dft, scratch_1) = scratch.tmp_fourier_glwe_secret(module, sk.rank()); + let (mut sk_out, scratch_1) = scratch.take_glwe_secret(module, sk.rank()); { - let (mut sk_out, _) = scratch_1.tmp_glwe_secret(module, sk.rank()); (0..self.rank()).for_each(|i| { module.scalar_znx_automorphism( module.galois_element_inv(p), @@ -243,41 +285,50 @@ impl + AsRef<[u8]>> GLWEAutomorphismKey { i, ); }); - sk_out_dft.set(module, &sk_out); } - self.key.encrypt_sk( - module, - &sk, - &sk_out_dft, - source_xa, - source_xe, - sigma, - scratch_1, - ); + self.key + .encrypt_sk(module, &sk, &sk_out, source_xa, source_xe, sigma, scratch_1); self.p = p; } } -impl GLWETensorKey, FFT64> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - GLWESecret::bytes_of(module, 1) - + 
FourierGLWESecret::bytes_of(module, 1) +pub trait GLWETensorKeyEncryptSkFamily = + GGLWEEncryptSkFamily + VecZnxBigAllocBytes + VecZnxDftToVecZnxBigTmpA + SvpApply; + +impl GLWETensorKey> { + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize + where + Module: GLWETensorKeyEncryptSkFamily + ScalarZnxAllocBytes + VecZnxAllocBytes, + { + GLWESecretExec::bytes_of(module, rank) + + module.vec_znx_dft_alloc_bytes(rank, 1) + + module.vec_znx_big_alloc_bytes(1, 1) + + module.vec_znx_dft_alloc_bytes(1, 1) + + GLWESecret::bytes_of(module, 1) + GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank, rank) } } -impl + AsRef<[u8]>> GLWETensorKey { - pub fn encrypt_sk>( +impl GLWETensorKey { + pub fn encrypt_sk( &mut self, - module: &Module, - sk: &FourierGLWESecret, + module: &Module, + sk: &GLWESecret, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, - scratch: &mut Scratch, - ) { + scratch: &mut Scratch, + ) where + Module: GLWETensorKeyEncryptSkFamily + + ScalarZnxAllocBytes + + VecZnxSwithcDegree + + VecZnxAllocBytes + + VecZnxAddScalarInplace, + Scratch: + ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig + TakeGLWESecretExec + TakeScalarZnx + TakeVecZnx, + { #[cfg(debug_assertions)] { assert_eq!(self.rank(), sk.rank()); @@ -287,15 +338,28 @@ impl + AsRef<[u8]>> GLWETensorKey { let rank: usize = self.rank(); - let (mut sk_ij, scratch1) = scratch.tmp_glwe_secret(module, 1); - let (mut sk_ij_dft, scratch2) = scratch1.tmp_fourier_glwe_secret(module, 1); + let (mut sk_dft_prep, scratch1) = scratch.take_glwe_secret_exec(module, rank); + sk_dft_prep.prepare(module, &sk); + + let (mut sk_dft, scratch2) = scratch1.take_vec_znx_dft(module, rank, 1); + + (0..rank).for_each(|i| { + module.vec_znx_dft_from_vec_znx(1, 0, &mut sk_dft, i, &sk.data, i); + }); + + let (mut sk_ij_big, scratch3) = scratch2.take_vec_znx_big(module, 1, 1); + let (mut sk_ij, scratch4) = scratch3.take_glwe_secret(module, 1); + let (mut sk_ij_dft, scratch5) = scratch4.take_vec_znx_dft(module, 1, 1); (0..rank).for_each(|i| { (i..rank).for_each(|j| { - module.svp_apply(&mut sk_ij_dft.data, 0, &sk.data, i, &sk.data, j); - module.scalar_znx_idft(&mut sk_ij.data, 0, &sk_ij_dft.data, 0, scratch2); + module.svp_apply(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i); + + module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); + module.vec_znx_big_normalize(self.basek(), &mut sk_ij.data, 0, &sk_ij_big, 0, scratch5); + self.at_mut(i, j) - .encrypt_sk(module, &sk_ij, sk, source_xa, source_xe, sigma, scratch2); + .encrypt_sk(module, &sk_ij, sk, source_xa, source_xe, sigma, scratch5); }); }) } diff --git a/core/src/gglwe/external_product.rs b/core/src/gglwe/external_product.rs index 26a8c92..b067a58 100644 --- a/core/src/gglwe/external_product.rs +++ b/core/src/gglwe/external_product.rs @@ -1,46 +1,52 @@ -use backend::{FFT64, Module, Scratch, ZnxZero}; +use backend::hal::{ + api::{ScratchAvailable, TakeVecZnxDft, ZnxZero}, + layouts::{Backend, DataMut, DataRef, Module, Scratch}, +}; -use crate::{FourierGLWECiphertext, GGSWCiphertext, GLWEAutomorphismKey, GLWESwitchingKey, GetRow, Infos, ScratchCore, SetRow}; +use crate::{AutomorphismKey, GGSWCiphertextExec, GLWECiphertext, GLWEExternalProductFamily, GLWESwitchingKey, Infos}; -impl GLWESwitchingKey, FFT64> { - pub fn external_product_scratch_space( - module: &Module, +impl GLWESwitchingKey> { + pub fn external_product_scratch_space( + module: &Module, basek: usize, k_out: usize, k_in: usize, k_ggsw: usize, 
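Each tensor-key entry is a switching key encrypting the product s[i] * s[j]; the hunk above computes that product in the DFT domain (`svp_apply`, then `vec_znx_dft_to_vec_znx_big_tmp_a` and `vec_znx_big_normalize` to come back). The quantity itself is just a negacyclic polynomial product, shown here with a toy schoolbook version (illustrative only, not the crate's representation):

```rust
// Schoolbook product in Z[X]/(X^N + 1): wrap-around terms pick up a sign.
fn negacyclic_mul(a: &[i64], b: &[i64]) -> Vec<i64> {
    let n = a.len();
    let mut out = vec![0i64; n];
    for i in 0..n {
        for j in 0..n {
            let k = i + j;
            if k < n {
                out[k] += a[i] * b[j];
            } else {
                out[k - n] -= a[i] * b[j]; // X^N = -1
            }
        }
    }
    out
}

fn main() {
    // (1 + X) * (1 - X) = 1 - X^2 in Z[X]/(X^4 + 1)
    assert_eq!(negacyclic_mul(&[1, 1, 0, 0], &[1, -1, 0, 0]), vec![1, 0, -1, 0]);
}
```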
digits: usize, rank: usize, - ) -> usize { - let tmp_in: usize = FourierGLWECiphertext::bytes_of(module, basek, k_in, rank); - let tmp_out: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); - let ggsw: usize = FourierGLWECiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank); - tmp_in + tmp_out + ggsw + ) -> usize + where + Module: GLWEExternalProductFamily, + { + GLWECiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank) } - pub fn external_product_inplace_scratch_space( - module: &Module, + pub fn external_product_inplace_scratch_space( + module: &Module, basek: usize, k_out: usize, k_ggsw: usize, digits: usize, rank: usize, - ) -> usize { - let tmp: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); - let ggsw: usize = - FourierGLWECiphertext::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank); - tmp + ggsw + ) -> usize + where + Module: GLWEExternalProductFamily, + { + GLWECiphertext::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank) } } -impl + AsRef<[u8]>> GLWESwitchingKey { - pub fn external_product, DataRhs: AsRef<[u8]>>( +impl GLWESwitchingKey { + pub fn external_product( &mut self, - module: &Module, - lhs: &GLWESwitchingKey, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { + module: &Module, + lhs: &GLWESwitchingKey, + rhs: &GGSWCiphertextExec, + scratch: &mut Scratch, + ) where + Module: GLWEExternalProductFamily, + Scratch: TakeVecZnxDft + ScratchAvailable, + { #[cfg(debug_assertions)] { assert_eq!( @@ -66,32 +72,29 @@ impl + AsRef<[u8]>> GLWESwitchingKey { ); } - let (mut tmp_in, scratch1) = scratch.tmp_fourier_glwe_ct(module, lhs.basek(), lhs.k(), lhs.rank()); - let (mut tmp_out, scratch2) = scratch1.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); - (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { - lhs.get_row(module, row_j, col_i, &mut tmp_in); - tmp_out.external_product(module, &tmp_in, rhs, scratch2); - self.set_row(module, row_j, col_i, &tmp_out); + self.at_mut(row_j, col_i) + .external_product(module, &lhs.at(row_j, col_i), rhs, scratch); }); }); - tmp_out.data.zero(); - (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { (0..self.rank_in()).for_each(|col_j| { - self.set_row(module, row_i, col_j, &tmp_out); + self.at_mut(row_i, col_j).data.zero(); }); }); } - pub fn external_product_inplace>( + pub fn external_product_inplace( &mut self, - module: &Module, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { + module: &Module, + rhs: &GGSWCiphertextExec, + scratch: &mut Scratch, + ) where + Module: GLWEExternalProductFamily, + Scratch: TakeVecZnxDft + ScratchAvailable, + { #[cfg(debug_assertions)] { assert_eq!( @@ -103,60 +106,69 @@ impl + AsRef<[u8]>> GLWESwitchingKey { ); } - let (mut tmp, scratch1) = scratch.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); - println!("tmp: {}", tmp.size()); (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { - self.get_row(module, row_j, col_i, &mut tmp); - tmp.external_product_inplace(module, rhs, scratch1); - self.set_row(module, row_j, col_i, &tmp); + self.at_mut(row_j, col_i) + .external_product_inplace(module, rhs, scratch); }); }); } } -impl GLWEAutomorphismKey, FFT64> { - pub fn external_product_scratch_space( - module: &Module, +impl AutomorphismKey> { + pub fn external_product_scratch_space( + module: &Module, basek: usize, k_out: usize, k_in: 
usize, ggsw_k: usize, digits: usize, rank: usize, - ) -> usize { + ) -> usize + where + Module: GLWEExternalProductFamily, + { GLWESwitchingKey::external_product_scratch_space(module, basek, k_out, k_in, ggsw_k, digits, rank) } - pub fn external_product_inplace_scratch_space( - module: &Module, + pub fn external_product_inplace_scratch_space( + module: &Module, basek: usize, k_out: usize, ggsw_k: usize, digits: usize, rank: usize, - ) -> usize { + ) -> usize + where + Module: GLWEExternalProductFamily, + { GLWESwitchingKey::external_product_inplace_scratch_space(module, basek, k_out, ggsw_k, digits, rank) } } -impl + AsRef<[u8]>> GLWEAutomorphismKey { - pub fn external_product, DataRhs: AsRef<[u8]>>( +impl AutomorphismKey { + pub fn external_product( &mut self, - module: &Module, - lhs: &GLWEAutomorphismKey, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { + module: &Module, + lhs: &AutomorphismKey, + rhs: &GGSWCiphertextExec, + scratch: &mut Scratch, + ) where + Module: GLWEExternalProductFamily, + Scratch: TakeVecZnxDft + ScratchAvailable, + { self.key.external_product(module, &lhs.key, rhs, scratch); } - pub fn external_product_inplace>( + pub fn external_product_inplace( &mut self, - module: &Module, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { + module: &Module, + rhs: &GGSWCiphertextExec, + scratch: &mut Scratch, + ) where + Module: GLWEExternalProductFamily, + Scratch: TakeVecZnxDft + ScratchAvailable, + { self.key.external_product_inplace(module, rhs, scratch); } } diff --git a/core/src/gglwe/keyswitch.rs b/core/src/gglwe/keyswitch.rs index fe4a3f6..0ddbb64 100644 --- a/core/src/gglwe/keyswitch.rs +++ b/core/src/gglwe/keyswitch.rs @@ -1,56 +1,73 @@ -use backend::{FFT64, Module, Scratch, ZnxZero}; +use backend::hal::{ + api::{ScratchAvailable, TakeVecZnxDft, ZnxZero}, + layouts::{Backend, DataMut, DataRef, Module, Scratch}, +}; -use crate::{FourierGLWECiphertext, GLWEAutomorphismKey, GLWESwitchingKey, GetRow, Infos, ScratchCore, SetRow}; +use crate::{ + AutomorphismKey, AutomorphismKeyExec, GLWECiphertext, GLWEKeyswitchFamily, GLWESwitchingKey, GLWESwitchingKeyExec, Infos, +}; -impl GLWEAutomorphismKey, FFT64> { - pub fn keyswitch_scratch_space( - module: &Module, +impl AutomorphismKey> { + pub fn keyswitch_scratch_space( + module: &Module, basek: usize, k_out: usize, k_in: usize, k_ksk: usize, digits: usize, rank: usize, - ) -> usize { + ) -> usize + where + Module: GLWEKeyswitchFamily, + { GLWESwitchingKey::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) } - pub fn keyswitch_inplace_scratch_space( - module: &Module, + pub fn keyswitch_inplace_scratch_space( + module: &Module, basek: usize, k_out: usize, k_ksk: usize, digits: usize, rank: usize, - ) -> usize { + ) -> usize + where + Module: GLWEKeyswitchFamily, + { GLWESwitchingKey::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank) } } -impl + AsRef<[u8]>> GLWEAutomorphismKey { - pub fn keyswitch, DataRhs: AsRef<[u8]>>( +impl AutomorphismKey { + pub fn keyswitch( &mut self, - module: &Module, - lhs: &GLWEAutomorphismKey, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { + module: &Module, + lhs: &AutomorphismKey, + rhs: &GLWESwitchingKeyExec, + scratch: &mut Scratch, + ) where + Module: GLWEKeyswitchFamily, + Scratch: TakeVecZnxDft + ScratchAvailable, + { self.key.keyswitch(module, &lhs.key, rhs, scratch); } - pub fn keyswitch_inplace>( + pub fn keyswitch_inplace( &mut self, - module: &Module, - rhs: &GLWEAutomorphismKey, - scratch: &mut Scratch, - ) { + 
module: &Module, + rhs: &AutomorphismKeyExec, + scratch: &mut Scratch, + ) where + Module: GLWEKeyswitchFamily, + Scratch: TakeVecZnxDft + ScratchAvailable, + { self.key.keyswitch_inplace(module, &rhs.key, scratch); } } -impl GLWESwitchingKey, FFT64> { - pub fn keyswitch_scratch_space( - module: &Module, +impl GLWESwitchingKey> { + pub fn keyswitch_scratch_space( + module: &Module, basek: usize, k_out: usize, k_in: usize, @@ -58,36 +75,39 @@ impl GLWESwitchingKey, FFT64> { digits: usize, rank_in: usize, rank_out: usize, - ) -> usize { - let tmp_in: usize = FourierGLWECiphertext::bytes_of(module, basek, k_in, rank_in); - let tmp_out: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank_out); - let ksk: usize = - FourierGLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out); - tmp_in + tmp_out + ksk + ) -> usize + where + Module: GLWEKeyswitchFamily, + { + GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out) } - pub fn keyswitch_inplace_scratch_space( - module: &Module, + pub fn keyswitch_inplace_scratch_space( + module: &Module, basek: usize, k_out: usize, k_ksk: usize, digits: usize, rank: usize, - ) -> usize { - let tmp: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); - let ksk: usize = FourierGLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank); - tmp + ksk + ) -> usize + where + Module: GLWEKeyswitchFamily, + { + GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank) } } -impl + AsRef<[u8]>> GLWESwitchingKey { - pub fn keyswitch, DataRhs: AsRef<[u8]>>( +impl GLWESwitchingKey { + pub fn keyswitch( &mut self, - module: &Module, - lhs: &GLWESwitchingKey, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { + module: &Module, + lhs: &GLWESwitchingKey, + rhs: &GLWESwitchingKeyExec, + scratch: &mut Scratch, + ) where + Module: GLWEKeyswitchFamily, + Scratch: TakeVecZnxDft + ScratchAvailable, + { #[cfg(debug_assertions)] { assert_eq!( @@ -113,32 +133,29 @@ impl + AsRef<[u8]>> GLWESwitchingKey { ); } - let (mut tmp_in, scratch1) = scratch.tmp_fourier_glwe_ct(module, lhs.basek(), lhs.k(), lhs.rank()); - let (mut tmp_out, scratch2) = scratch1.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); - (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { - lhs.get_row(module, row_j, col_i, &mut tmp_in); - tmp_out.keyswitch(module, &tmp_in, rhs, scratch2); - self.set_row(module, row_j, col_i, &tmp_out); + self.at_mut(row_j, col_i) + .keyswitch(module, &lhs.at(row_j, col_i), rhs, scratch); }); }); - tmp_out.data.zero(); - (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { (0..self.rank_in()).for_each(|col_j| { - self.set_row(module, row_i, col_j, &tmp_out); + self.at_mut(row_i, col_j).data.zero(); }); }); } - pub fn keyswitch_inplace>( + pub fn keyswitch_inplace( &mut self, - module: &Module, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { + module: &Module, + rhs: &GLWESwitchingKeyExec, + scratch: &mut Scratch, + ) where + Module: GLWEKeyswitchFamily, + Scratch: TakeVecZnxDft + ScratchAvailable, + { #[cfg(debug_assertions)] { assert_eq!( @@ -150,13 +167,10 @@ impl + AsRef<[u8]>> GLWESwitchingKey { ); } - let (mut tmp, scratch1) = scratch.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); - (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { - self.get_row(module, row_j, col_i, &mut tmp); - 
tmp.keyswitch_inplace(module, rhs, scratch1); - self.set_row(module, row_j, col_i, &tmp); + self.at_mut(row_j, col_i) + .keyswitch_inplace(module, rhs, scratch) }); }); } diff --git a/core/src/gglwe/keyswitch_key.rs b/core/src/gglwe/keyswitch_key.rs index cb1c1fb..b742c37 100644 --- a/core/src/gglwe/keyswitch_key.rs +++ b/core/src/gglwe/keyswitch_key.rs @@ -1,23 +1,30 @@ -use backend::{Backend, FFT64, MatZnxDft, MatZnxDftOps, Module}; +use backend::hal::{ + api::{MatZnxAlloc, MatZnxAllocBytes}, + layouts::{Backend, Data, DataMut, DataRef, MatZnx, Module, ReaderFrom, Scratch, VmpPMat, WriterTo}, +}; -use crate::{FourierGLWECiphertext, GGLWECiphertext, GetRow, Infos, SetRow}; +use crate::{GGLWECiphertext, GGLWECiphertextExec, GGLWEExecLayoutFamily, GLWECiphertext, Infos}; -pub struct GLWESwitchingKey { - pub(crate) key: GGLWECiphertext, +#[derive(PartialEq, Eq)] +pub struct GLWESwitchingKey { + pub(crate) key: GGLWECiphertext, pub(crate) sk_in_n: usize, // Degree of sk_in pub(crate) sk_out_n: usize, // Degree of sk_out } -impl GLWESwitchingKey, FFT64> { - pub fn alloc( - module: &Module, +impl GLWESwitchingKey> { + pub fn alloc( + module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize, - ) -> Self { + ) -> Self + where + Module: MatZnxAlloc, + { GLWESwitchingKey { key: GGLWECiphertext::alloc(module, basek, k, rows, digits, rank_in, rank_out), sk_in_n: 0, @@ -25,21 +32,24 @@ impl GLWESwitchingKey, FFT64> { } } - pub fn bytes_of( - module: &Module, + pub fn bytes_of( + module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize, - ) -> usize { - GGLWECiphertext::, FFT64>::bytes_of(module, basek, k, rows, digits, rank_in, rank_out) + ) -> usize + where + Module: MatZnxAllocBytes, + { + GGLWECiphertext::>::bytes_of(module, basek, k, rows, digits, rank_in, rank_out) } } -impl Infos for GLWESwitchingKey { - type Inner = MatZnxDft; +impl Infos for GLWESwitchingKey { + type Inner = MatZnx; fn inner(&self) -> &Self::Inner { self.key.inner() @@ -54,7 +64,7 @@ impl Infos for GLWESwitchingKey { } } -impl GLWESwitchingKey { +impl GLWESwitchingKey { pub fn rank(&self) -> usize { self.key.data.cols_out() - 1 } @@ -80,26 +90,138 @@ impl GLWESwitchingKey { } } -impl> GetRow for GLWESwitchingKey { - fn get_row + AsRef<[u8]>>( - &self, - module: &Module, - row_i: usize, - col_j: usize, - res: &mut FourierGLWECiphertext, - ) { - module.mat_znx_dft_get_row(&mut res.data, &self.key.data, row_i, col_j); +impl GLWESwitchingKey { + pub fn at(&self, row: usize, col: usize) -> GLWECiphertext<&[u8]> { + self.key.at(row, col) } } -impl + AsRef<[u8]>> SetRow for GLWESwitchingKey { - fn set_row>( - &mut self, - module: &Module, - row_i: usize, - col_j: usize, - a: &FourierGLWECiphertext, - ) { - module.mat_znx_dft_set_row(&mut self.key.data, row_i, col_j, &a.data); +impl GLWESwitchingKey { + pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertext<&mut [u8]> { + self.key.at_mut(row, col) + } +} + +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +impl ReaderFrom for GLWESwitchingKey { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.sk_in_n = reader.read_u64::()? as usize; + self.sk_out_n = reader.read_u64::()? 
as usize; + self.key.read_from(reader) + } +} + +impl WriterTo for GLWESwitchingKey { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u64::(self.sk_in_n as u64)?; + writer.write_u64::(self.sk_out_n as u64)?; + self.key.write_to(writer) + } +} + +#[derive(PartialEq, Eq)] +pub struct GLWESwitchingKeyExec { + pub(crate) key: GGLWECiphertextExec, + pub(crate) sk_in_n: usize, // Degree of sk_in + pub(crate) sk_out_n: usize, // Degree of sk_out +} + +impl GLWESwitchingKeyExec, B> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self + where + Module: GGLWEExecLayoutFamily, + { + GLWESwitchingKeyExec::, B> { + key: GGLWECiphertextExec::alloc(module, basek, k, rows, digits, rank_in, rank_out), + sk_in_n: 0, + sk_out_n: 0, + } + } + + pub fn bytes_of( + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> usize + where + Module: GGLWEExecLayoutFamily, + { + GGLWECiphertextExec::bytes_of(module, basek, k, rows, digits, rank_in, rank_out) + } + + pub fn from(module: &Module, other: &GLWESwitchingKey, scratch: &mut Scratch) -> Self + where + Module: GGLWEExecLayoutFamily, + { + let mut ksk_exec: GLWESwitchingKeyExec, B> = Self::alloc( + module, + other.basek(), + other.k(), + other.rows(), + other.digits(), + other.rank_in(), + other.rank_out(), + ); + ksk_exec.prepare(module, other, scratch); + ksk_exec + } +} + +impl Infos for GLWESwitchingKeyExec { + type Inner = VmpPMat; + + fn inner(&self) -> &Self::Inner { + self.key.inner() + } + + fn basek(&self) -> usize { + self.key.basek() + } + + fn k(&self) -> usize { + self.key.k() + } +} + +impl GLWESwitchingKeyExec { + pub fn rank(&self) -> usize { + self.key.data.cols_out() - 1 + } + + pub fn rank_in(&self) -> usize { + self.key.data.cols_in() + } + + pub fn rank_out(&self) -> usize { + self.key.data.cols_out() - 1 + } + + pub fn digits(&self) -> usize { + self.key.digits() + } + + pub fn sk_degree_in(&self) -> usize { + self.sk_in_n + } + + pub fn sk_degree_out(&self) -> usize { + self.sk_out_n + } +} + +impl GLWESwitchingKeyExec { + pub fn prepare(&mut self, module: &Module, other: &GLWESwitchingKey, scratch: &mut Scratch) + where + DataOther: DataRef, + Module: GGLWEExecLayoutFamily, + { + self.key.prepare(module, &other.key, scratch); + self.sk_in_n = other.sk_in_n; + self.sk_out_n = other.sk_out_n; } } diff --git a/core/src/gglwe/layout.rs b/core/src/gglwe/layout.rs new file mode 100644 index 0000000..a905074 --- /dev/null +++ b/core/src/gglwe/layout.rs @@ -0,0 +1,275 @@ +use backend::hal::{ + api::{MatZnxAlloc, MatZnxAllocBytes, VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatPrepare}, + layouts::{Backend, Data, DataMut, DataRef, MatZnx, Module, ReaderFrom, Scratch, VmpPMat, WriterTo}, +}; + +use crate::{GLWECiphertext, Infos}; + +pub trait GGLWEExecLayoutFamily = VmpPMatAlloc + VmpPMatAllocBytes + VmpPMatPrepare; + +#[derive(PartialEq, Eq)] +pub struct GGLWECiphertext { + pub(crate) data: MatZnx, + pub(crate) basek: usize, + pub(crate) k: usize, + pub(crate) digits: usize, +} + +impl GGLWECiphertext { + pub fn at(&self, row: usize, col: usize) -> GLWECiphertext<&[u8]> { + GLWECiphertext { + data: self.data.at(row, col), + basek: self.basek, + k: self.k, + } + } +} + +impl GGLWECiphertext { + pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertext<&mut [u8]> { + GLWECiphertext { + data: self.data.at_mut(row, col), + basek: self.basek, + k: self.k, + } + } +} + +impl 
GGLWECiphertext> { + pub fn alloc( + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> Self + where + Module: MatZnxAlloc, + { + let size: usize = k.div_ceil(basek); + debug_assert!( + size > digits, + "invalid gglwe: ceil(k/basek): {} <= digits: {}", + size, + digits + ); + + assert!( + rows * digits <= size, + "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}", + rows, + digits, + size + ); + + Self { + data: module.mat_znx_alloc(rows, rank_in, rank_out + 1, size), + basek: basek, + k, + digits, + } + } + + pub fn bytes_of( + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> usize + where + Module: MatZnxAllocBytes, + { + let size: usize = k.div_ceil(basek); + debug_assert!( + size > digits, + "invalid gglwe: ceil(k/basek): {} <= digits: {}", + size, + digits + ); + + assert!( + rows * digits <= size, + "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}", + rows, + digits, + size + ); + + module.mat_znx_alloc_bytes(rows, rank_in, rank_out + 1, rows) + } +} + +impl Infos for GGLWECiphertext { + type Inner = MatZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl GGLWECiphertext { + pub fn rank(&self) -> usize { + self.data.cols_out() - 1 + } + + pub fn digits(&self) -> usize { + self.digits + } + + pub fn rank_in(&self) -> usize { + self.data.cols_in() + } + + pub fn rank_out(&self) -> usize { + self.data.cols_out() - 1 + } +} + +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +impl ReaderFrom for GGLWECiphertext { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.k = reader.read_u64::()? as usize; + self.basek = reader.read_u64::()? as usize; + self.digits = reader.read_u64::()? 
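The `alloc`/`bytes_of` pair above enforce the GGLWE sizing invariants: with `size = ceil(k / basek)` limbs per ciphertext, `size` must exceed `digits` and `rows * digits` must fit inside `size`. (Note that `bytes_of` passes `rows` where `alloc` passes `size` as the final limb-count argument, which may be unintended; it predates this PR, as the deleted `ciphertext.rs` had the same call.) A small standalone helper illustrating one valid way to derive the row count under these invariants (illustrative, not the crate's API):

```rust
// Returns (rows, size) for a GGLWE of k bits at basek bits per limb,
// with `digits` limbs consumed per decomposition digit.
fn gglwe_dims(k: usize, basek: usize, digits: usize) -> (usize, usize) {
    let size = k.div_ceil(basek); // limbs of basek bits each
    assert!(size > digits, "ceil(k/basek): {size} <= digits: {digits}");
    let rows = size / digits; // maximal row count
    assert!(rows * digits <= size); // the invariant alloc enforces
    (rows, size)
}

fn main() {
    // k = 54 bits at basek = 12 gives 5 limbs; digits = 2 allows at most 2 rows.
    assert_eq!(gglwe_dims(54, 12, 2), (2, 5));
}
```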
as usize; + self.data.read_from(reader) + } +} + +impl WriterTo for GGLWECiphertext { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u64::(self.k as u64)?; + writer.write_u64::(self.basek as u64)?; + writer.write_u64::(self.digits as u64)?; + self.data.write_to(writer) + } +} + +#[derive(PartialEq, Eq)] +pub struct GGLWECiphertextExec { + pub(crate) data: VmpPMat, + pub(crate) basek: usize, + pub(crate) k: usize, + pub(crate) digits: usize, +} + +impl GGLWECiphertextExec, B> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self + where + Module: GGLWEExecLayoutFamily, + { + let size: usize = k.div_ceil(basek); + debug_assert!( + size > digits, + "invalid gglwe: ceil(k/basek): {} <= digits: {}", + size, + digits + ); + + assert!( + rows * digits <= size, + "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}", + rows, + digits, + size + ); + + Self { + data: module.vmp_pmat_alloc(rows, rank_in, rank_out + 1, size), + basek: basek, + k, + digits, + } + } + + pub fn bytes_of( + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> usize + where + Module: GGLWEExecLayoutFamily, + { + let size: usize = k.div_ceil(basek); + debug_assert!( + size > digits, + "invalid gglwe: ceil(k/basek): {} <= digits: {}", + size, + digits + ); + + assert!( + rows * digits <= size, + "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}", + rows, + digits, + size + ); + + module.vmp_pmat_alloc_bytes(rows, rank_in, rank_out + 1, rows) + } +} + +impl Infos for GGLWECiphertextExec { + type Inner = VmpPMat; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl GGLWECiphertextExec { + pub fn rank(&self) -> usize { + self.data.cols_out() - 1 + } + + pub fn digits(&self) -> usize { + self.digits + } + + pub fn rank_in(&self) -> usize { + self.data.cols_in() + } + + pub fn rank_out(&self) -> usize { + self.data.cols_out() - 1 + } +} + +impl GGLWECiphertextExec { + pub fn prepare(&mut self, module: &Module, other: &GGLWECiphertext, scratch: &mut Scratch) + where + DataOther: DataRef, + Module: GGLWEExecLayoutFamily, + { + module.vmp_prepare(&mut self.data, &other.data, scratch); + self.basek = other.basek; + self.k = other.k; + self.digits = other.digits; + } +} diff --git a/core/src/gglwe/mod.rs b/core/src/gglwe/mod.rs index 4c2d20a..7ab6f0d 100644 --- a/core/src/gglwe/mod.rs +++ b/core/src/gglwe/mod.rs @@ -1,16 +1,20 @@ -pub mod automorphism; -pub mod automorphism_key; -pub mod ciphertext; -pub mod encryption; -pub mod external_product; -pub mod keyswitch; -pub mod keyswitch_key; -pub mod tensor_key; +mod automorphism; +mod automorphism_key; +mod encryption; +mod external_product; +mod keyswitch; +mod keyswitch_key; +mod layout; +mod noise; +mod tensor_key; -pub use automorphism_key::GLWEAutomorphismKey; -pub use ciphertext::GGLWECiphertext; -pub use keyswitch_key::GLWESwitchingKey; -pub use tensor_key::GLWETensorKey; +pub use automorphism_key::{AutomorphismKey, AutomorphismKeyExec}; +pub use encryption::{ + AutomorphismKeyEncryptSkFamily, GGLWEEncryptSkFamily, GLWESwitchingKeyEncryptSkFamily, GLWETensorKeyEncryptSkFamily, +}; +pub use keyswitch_key::{GLWESwitchingKey, GLWESwitchingKeyExec}; +pub use layout::{GGLWECiphertext, GGLWECiphertextExec, GGLWEExecLayoutFamily}; +pub use tensor_key::{GLWETensorKey, GLWETensorKeyExec}; #[cfg(test)] -mod 
test_fft64; +mod test; diff --git a/core/src/gglwe/noise.rs b/core/src/gglwe/noise.rs new file mode 100644 index 0000000..def18c2 --- /dev/null +++ b/core/src/gglwe/noise.rs @@ -0,0 +1,55 @@ +use backend::hal::{ + api::{ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAlloc, VecZnxStd, VecZnxSubScalarInplace, ZnxZero}, + layouts::{Backend, DataRef, Module, ScalarZnx, ScratchOwned}, + oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl}, +}; + +use crate::{GGLWECiphertext, GLWECiphertext, GLWEDecryptFamily, GLWEPlaintext, GLWESecretExec, Infos}; + +impl GGLWECiphertext { + pub fn assert_noise( + self, + module: &Module, + sk: &GLWESecretExec, + pt_want: &ScalarZnx, + max_noise: f64, + ) where + DataSk: DataRef, + DataWant: DataRef, + Module: GLWEDecryptFamily + VecZnxStd + VecZnxAlloc + VecZnxSubScalarInplace, + B: TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, + { + let digits: usize = self.digits(); + let basek: usize = self.basek(); + let k: usize = self.k(); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, basek, k)); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k); + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_i| { + self.at(row_i, col_i) + .decrypt(&module, &mut pt, &sk, scratch.borrow()); + + module.vec_znx_sub_scalar_inplace( + &mut pt.data, + 0, + (digits - 1) + row_i * digits, + pt_want, + col_i, + ); + + let noise_have: f64 = module.vec_znx_std(basek, &pt.data, 0).log2(); + + assert!( + noise_have <= max_noise, + "noise_have: {} > max_noise: {}", + noise_have, + max_noise + ); + + pt.data.zero(); + }); + }); + } +} diff --git a/core/src/gglwe/tensor_key.rs b/core/src/gglwe/tensor_key.rs index c12c1f5..5ebcb70 100644 --- a/core/src/gglwe/tensor_key.rs +++ b/core/src/gglwe/tensor_key.rs @@ -1,14 +1,21 @@ -use backend::{Backend, FFT64, MatZnxDft, Module}; +use backend::hal::{ + api::{MatZnxAlloc, MatZnxAllocBytes}, + layouts::{Backend, Data, DataMut, DataRef, MatZnx, Module, ReaderFrom, Scratch, VmpPMat, WriterTo}, +}; -use crate::{GLWESwitchingKey, Infos}; +use crate::{GGLWEExecLayoutFamily, GLWESwitchingKey, GLWESwitchingKeyExec, Infos}; -pub struct GLWETensorKey { - pub(crate) keys: Vec>, +#[derive(PartialEq, Eq)] +pub struct GLWETensorKey { + pub(crate) keys: Vec>, } -impl GLWETensorKey, FFT64> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { - let mut keys: Vec, FFT64>> = Vec::new(); +impl GLWETensorKey> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self + where + Module: MatZnxAlloc, + { + let mut keys: Vec>> = Vec::new(); let pairs: usize = (((rank + 1) * rank) >> 1).max(1); (0..pairs).for_each(|_| { keys.push(GLWESwitchingKey::alloc( @@ -18,14 +25,17 @@ impl GLWETensorKey, FFT64> { Self { keys: keys } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize + where + Module: MatZnxAllocBytes, + { let pairs: usize = (((rank + 1) * rank) >> 1).max(1); - pairs * GLWESwitchingKey::, FFT64>::bytes_of(module, basek, k, rows, digits, 1, rank) + pairs * GLWESwitchingKey::>::bytes_of(module, basek, k, rows, digits, 1, rank) } } -impl Infos for GLWETensorKey { - type Inner = MatZnxDft; +impl Infos for GLWETensorKey { + type Inner = 
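In the new `assert_noise`, the scalar plaintext is subtracted from limb `(digits - 1) + row_i * digits` before measuring the residual noise, i.e. row `row_i` carries the plaintext digit at that base-2^basek limb, consistent with the `rows * digits <= size` layout invariant. A trivial standalone check of that indexing (illustrative helper):

```rust
// Limb index holding the plaintext for a given row of a GGLWE ciphertext.
fn pt_limb(row_i: usize, digits: usize) -> usize {
    (digits - 1) + row_i * digits
}

fn main() {
    // digits = 2: rows 0, 1, 2 carry the plaintext in limbs 1, 3, 5.
    let limbs: Vec<usize> = (0..3).map(|r| pt_limb(r, 2)).collect();
    assert_eq!(limbs, vec![1, 3, 5]);
}
```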
MatZnx; fn inner(&self) -> &Self::Inner { &self.keys[0].inner() @@ -40,7 +50,7 @@ impl Infos for GLWETensorKey { } } -impl GLWETensorKey { +impl GLWETensorKey { pub fn rank(&self) -> usize { self.keys[0].rank() } @@ -58,9 +68,9 @@ impl GLWETensorKey { } } -impl + AsRef<[u8]>> GLWETensorKey { +impl GLWETensorKey { // Returns a mutable reference to GLWESwitchingKey_{s}(s[i] * s[j]) - pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GLWESwitchingKey { + pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GLWESwitchingKey { if i > j { std::mem::swap(&mut i, &mut j); }; @@ -69,9 +79,9 @@ impl + AsRef<[u8]>> GLWETensorKey { } } -impl> GLWETensorKey { +impl GLWETensorKey { // Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j]) - pub fn at(&self, mut i: usize, mut j: usize) -> &GLWESwitchingKey { + pub fn at(&self, mut i: usize, mut j: usize) -> &GLWESwitchingKey { if i > j { std::mem::swap(&mut i, &mut j); }; @@ -79,3 +89,135 @@ impl> GLWETensorKey { &self.keys[i * rank + j - (i * (i + 1) / 2)] } } + +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +impl ReaderFrom for GLWETensorKey { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + let len: usize = reader.read_u64::()? as usize; + if self.keys.len() != len { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("self.keys.len()={} != read len={}", self.keys.len(), len), + )); + } + for key in &mut self.keys { + key.read_from(reader)?; + } + Ok(()) + } +} + +impl WriterTo for GLWETensorKey { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u64::(self.keys.len() as u64)?; + for key in &self.keys { + key.write_to(writer)?; + } + Ok(()) + } +} + +#[derive(PartialEq, Eq)] +pub struct GLWETensorKeyExec { + pub(crate) keys: Vec>, +} + +impl GLWETensorKeyExec, B> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self + where + Module: GGLWEExecLayoutFamily, + { + let mut keys: Vec, B>> = Vec::new(); + let pairs: usize = (((rank + 1) * rank) >> 1).max(1); + (0..pairs).for_each(|_| { + keys.push(GLWESwitchingKeyExec::alloc( + module, basek, k, rows, digits, 1, rank, + )); + }); + Self { keys } + } + + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize + where + Module: GGLWEExecLayoutFamily, + { + let pairs: usize = (((rank + 1) * rank) >> 1).max(1); + pairs * GLWESwitchingKeyExec::, B>::bytes_of(module, basek, k, rows, digits, 1, rank) + } +} + +impl Infos for GLWETensorKeyExec { + type Inner = VmpPMat; + + fn inner(&self) -> &Self::Inner { + &self.keys[0].inner() + } + + fn basek(&self) -> usize { + self.keys[0].basek() + } + + fn k(&self) -> usize { + self.keys[0].k() + } +} + +impl GLWETensorKeyExec { + pub fn rank(&self) -> usize { + self.keys[0].rank() + } + + pub fn rank_in(&self) -> usize { + self.keys[0].rank_in() + } + + pub fn rank_out(&self) -> usize { + self.keys[0].rank_out() + } + + pub fn digits(&self) -> usize { + self.keys[0].digits() + } +} + +impl GLWETensorKeyExec { + // Returns a mutable reference to GLWESwitchingKey_{s}(s[i] * s[j]) + pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GLWESwitchingKeyExec { + if i > j { + std::mem::swap(&mut i, &mut j); + }; + let rank: usize = self.rank(); + &mut self.keys[i * rank + j - (i * (i + 1) / 2)] + } +} + +impl GLWETensorKeyExec { + // Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j]) + pub fn at(&self, mut i: usize, mut j: usize) -> &GLWESwitchingKeyExec 
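Both `GLWETensorKey` and `GLWETensorKeyExec` store the keys for the unordered pairs (i, j), i <= j < rank, in a flat `Vec` of length rank*(rank+1)/2, addressed by the triangular index `i*rank + j - i*(i+1)/2` used in `at`/`at_mut` above. A standalone check (not the crate's API) that this indexing is a bijection onto the flat storage:

```rust
fn pair_index(rank: usize, mut i: usize, mut j: usize) -> usize {
    if i > j {
        std::mem::swap(&mut i, &mut j); // (i, j) and (j, i) share one key
    }
    i * rank + j - (i * (i + 1)) / 2
}

fn main() {
    let rank = 4;
    let pairs = (rank + 1) * rank / 2;
    let mut seen = vec![false; pairs];
    for i in 0..rank {
        for j in i..rank {
            let idx = pair_index(rank, i, j);
            assert!(!seen[idx], "index collision at ({i},{j})");
            seen[idx] = true;
        }
    }
    assert!(seen.iter().all(|&s| s)); // every slot hit exactly once
}
```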
{ + if i > j { + std::mem::swap(&mut i, &mut j); + }; + let rank: usize = self.rank(); + &self.keys[i * rank + j - (i * (i + 1) / 2)] + } +} + +impl GLWETensorKeyExec { + pub fn prepare(&mut self, module: &Module, other: &GLWETensorKey, scratch: &mut Scratch) + where + DataOther: DataRef, + Module: GGLWEExecLayoutFamily, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.keys.len(), other.keys.len()); + } + self.keys + .iter_mut() + .zip(other.keys.iter()) + .for_each(|(a, b)| { + a.prepare(module, b, scratch); + }); + } +} diff --git a/core/src/gglwe/test_fft64/automorphism_key.rs b/core/src/gglwe/test/automorphism_key.rs similarity index 68% rename from core/src/gglwe/test_fft64/automorphism_key.rs rename to core/src/gglwe/test/automorphism_key.rs index c6dc212..1c8cca2 100644 --- a/core/src/gglwe/test_fft64/automorphism_key.rs +++ b/core/src/gglwe/test/automorphism_key.rs @@ -1,9 +1,14 @@ -use backend::{FFT64, Module, ScalarZnxOps, ScratchOwned, Stats, VecZnxOps}; +use backend::{ + hal::{ + api::{ModuleNew, ScalarZnxAutomorphism, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxStd, VecZnxSubScalarInplace}, + layouts::{Module, ScratchOwned}, + }, + implementation::cpu_spqlios::FFT64, +}; use sampling::source::Source; use crate::{ - FourierGLWECiphertext, FourierGLWESecret, GLWEAutomorphismKey, GLWEPlaintext, GLWESecret, GetRow, Infos, - noise::log2_std_noise_gglwe_product, + AutomorphismKey, AutomorphismKeyExec, GLWEPlaintext, GLWESecret, GLWESecretExec, Infos, noise::log2_std_noise_gglwe_product, }; #[test] @@ -58,21 +63,17 @@ fn test_automorphism( let rows_in: usize = k_in / (basek * digits); let rows_apply: usize = k_in.div_ceil(basek * digits); - let mut auto_key_in: GLWEAutomorphismKey, FFT64> = - GLWEAutomorphismKey::alloc(&module, basek, k_in, rows_in, digits_in, rank); - let mut auto_key_out: GLWEAutomorphismKey, FFT64> = - GLWEAutomorphismKey::alloc(&module, basek, k_out, rows_in, digits_in, rank); - let mut auto_key_apply: GLWEAutomorphismKey, FFT64> = - GLWEAutomorphismKey::alloc(&module, basek, k_apply, rows_apply, digits, rank); + let mut auto_key_in: AutomorphismKey> = AutomorphismKey::alloc(&module, basek, k_in, rows_in, digits_in, rank); + let mut auto_key_out: AutomorphismKey> = AutomorphismKey::alloc(&module, basek, k_out, rows_in, digits_in, rank); + let mut auto_key_apply: AutomorphismKey> = AutomorphismKey::alloc(&module, basek, k_apply, rows_apply, digits, rank); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_apply, rank) - | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) - | GLWEAutomorphismKey::automorphism_scratch_space(&module, basek, k_out, k_in, k_apply, digits, rank), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + AutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_apply, rank) + | AutomorphismKey::automorphism_scratch_space(&module, basek, k_out, k_in, k_apply, digits, rank), ); let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); @@ -100,10 +101,19 @@ fn test_automorphism( scratch.borrow(), ); - // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) - auto_key_out.automorphism(&module, &auto_key_in, &auto_key_apply, scratch.borrow()); + let mut auto_key_apply_exec: AutomorphismKeyExec, FFT64> = + AutomorphismKeyExec::alloc(&module, basek, k_apply, rows_apply, digits, rank); + + 
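The tests combine their individual scratch requirements with `|` on `usize` byte counts (as in `ScratchOwned::alloc(encrypt_sk_scratch_space(..) | automorphism_scratch_space(..))` above), which appears to serve as a cheap upper bound for the maximum: for unsigned integers, max(a, b) <= a | b <= a + b, so the OR-ed budget is always large enough for either consumer without paying for the full sum. A quick standalone check of that bound:

```rust
fn main() {
    for &(a, b) in &[(100usize, 37), (4096, 4096), (5, 8)] {
        assert!((a | b) >= a.max(b)); // large enough for either requirement
        assert!((a | b) <= a + b);    // never more than allocating both
    }
}
```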
auto_key_apply_exec.prepare(&module, &auto_key_apply, scratch.borrow()); + + // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) + auto_key_out.automorphism( + &module, + &auto_key_in, + &auto_key_apply_exec, + scratch.borrow(), + ); - let mut ct_glwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_out, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); let mut sk_auto: GLWESecret> = GLWESecret::alloc(&module, rank); @@ -118,12 +128,13 @@ fn test_automorphism( ); }); - let sk_auto_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_auto); + let sk_auto_dft: GLWESecretExec, FFT64> = GLWESecretExec::from(&module, &sk_auto); (0..auto_key_out.rank_in()).for_each(|col_i| { (0..auto_key_out.rows()).for_each(|row_i| { - auto_key_out.get_row(&module, row_i, col_i, &mut ct_glwe_dft); - ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow()); + auto_key_out + .at(row_i, col_i) + .decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace( &mut pt.data, @@ -133,7 +144,7 @@ fn test_automorphism( col_i, ); - let noise_have: f64 = pt.data.std(0, basek).log2(); + let noise_have: f64 = module.vec_znx_std(basek, &pt.data, 0).log2(); let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, basek * digits, @@ -175,19 +186,16 @@ fn test_automorphism_inplace( let rows_in: usize = k_in / (basek * digits); let rows_apply: usize = k_in.div_ceil(basek * digits); - let mut auto_key: GLWEAutomorphismKey, FFT64> = - GLWEAutomorphismKey::alloc(&module, basek, k_in, rows_in, digits_in, rank); - let mut auto_key_apply: GLWEAutomorphismKey, FFT64> = - GLWEAutomorphismKey::alloc(&module, basek, k_apply, rows_apply, digits, rank); + let mut auto_key: AutomorphismKey> = AutomorphismKey::alloc(&module, basek, k_in, rows_in, digits_in, rank); + let mut auto_key_apply: AutomorphismKey> = AutomorphismKey::alloc(&module, basek, k_apply, rows_apply, digits, rank); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_apply, rank) - | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_in) - | GLWEAutomorphismKey::automorphism_inplace_scratch_space(&module, basek, k_in, k_apply, digits, rank), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + AutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_apply, rank) + | AutomorphismKey::automorphism_inplace_scratch_space(&module, basek, k_in, k_apply, digits, rank), ); let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); @@ -215,10 +223,14 @@ fn test_automorphism_inplace( scratch.borrow(), ); - // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) - auto_key.automorphism_inplace(&module, &auto_key_apply, scratch.borrow()); + let mut auto_key_apply_exec: AutomorphismKeyExec, FFT64> = + AutomorphismKeyExec::alloc(&module, basek, k_apply, rows_apply, digits, rank); + + auto_key_apply_exec.prepare(&module, &auto_key_apply, scratch.borrow()); + + // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) + auto_key.automorphism_inplace(&module, &auto_key_apply_exec, scratch.borrow()); - let mut ct_glwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_in, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); let mut sk_auto: GLWESecret> = 
GLWESecret::alloc(&module, rank); @@ -234,13 +246,13 @@ fn test_automorphism_inplace( ); }); - let sk_auto_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_auto); + let sk_auto_dft: GLWESecretExec, FFT64> = GLWESecretExec::from(&module, &sk_auto); (0..auto_key.rank_in()).for_each(|col_i| { (0..auto_key.rows()).for_each(|row_i| { - auto_key.get_row(&module, row_i, col_i, &mut ct_glwe_dft); - - ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow()); + auto_key + .at(row_i, col_i) + .decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace( &mut pt.data, 0, @@ -249,7 +261,7 @@ fn test_automorphism_inplace( col_i, ); - let noise_have: f64 = pt.data.std(0, basek).log2(); + let noise_have: f64 = module.vec_znx_std(basek, &pt.data, 0).log2(); let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, basek * digits, diff --git a/core/src/gglwe/test/gglwe_fft64.rs b/core/src/gglwe/test/gglwe_fft64.rs new file mode 100644 index 0000000..ba28663 --- /dev/null +++ b/core/src/gglwe/test/gglwe_fft64.rs @@ -0,0 +1,138 @@ +use backend::{ + hal::{api::ModuleNew, layouts::Module}, + implementation::cpu_spqlios::FFT64, +}; + +use crate::gglwe::test::gglwe_generic::{ + test_encrypt_sk, test_external_product, test_external_product_inplace, test_keyswitch, test_keyswitch_inplace, +}; + +#[test] +fn encrypt_sk() { + let log_n: usize = 8; + let module: Module = Module::::new(1 << log_n); + let basek: usize = 12; + let k_ksk: usize = 54; + let digits: usize = k_ksk / basek; + (1..4).for_each(|rank_in| { + (1..4).for_each(|rank_out| { + (1..digits + 1).for_each(|di| { + println!( + "test encrypt_sk digits: {} ranks: ({} {})", + di, rank_in, rank_out + ); + test_encrypt_sk(&module, basek, k_ksk, di, rank_in, rank_out, 3.2); + }); + }); + }); +} + +#[test] +fn keyswitch() { + let log_n: usize = 8; + let module: Module = Module::::new(1 << log_n); + let basek: usize = 12; + let k_in: usize = 60; + let digits: usize = k_in.div_ceil(basek); + (1..4).for_each(|rank_in_s0s1| { + (1..4).for_each(|rank_out_s0s1| { + (1..4).for_each(|rank_out_s1s2| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_in + basek * di; + println!( + "test key_switch digits: {} ranks: ({},{},{})", + di, rank_in_s0s1, rank_out_s0s1, rank_out_s1s2 + ); + let k_out: usize = k_ksk; // Better capture noise. 
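Aside on the parameter arithmetic these tests loop over: a ciphertext of k bits at base 2^basek has ceil(k/basek) limbs, grouping `digits` limbs per gadget digit gives ceil(k/(basek*digits)) rows, and the tests grow k_ksk by one extra digit (basek * di), presumably so the switching key has spare precision and the measured noise reflects the key switch itself (the "Better capture noise" comments). A standalone sketch of that arithmetic, illustrative only, not library code:

    fn main() {
        let basek: usize = 12;
        let k_in: usize = 60; // 5 limbs of 12 bits, as in the tests above
        for digits in 1..=k_in.div_ceil(basek) {
            let rows = k_in.div_ceil(basek * digits);
            let k_ksk = k_in + basek * digits; // one spare digit, as in the tests
            println!(
                "digits={} -> rows={}, k_ksk={} ({} limbs)",
                digits,
                rows,
                k_ksk,
                k_ksk.div_ceil(basek)
            );
        }
    }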
+ test_keyswitch( + &module, + basek, + k_out, + k_in, + k_ksk, + di, + rank_in_s0s1, + rank_out_s0s1, + rank_out_s1s2, + 3.2, + ); + }) + }) + }); + }); +} + +#[test] +fn keyswitch_inplace() { + let log_n: usize = 8; + let module: Module = Module::::new(1 << log_n); + let basek: usize = 12; + let k_ct: usize = 60; + let digits: usize = k_ct.div_ceil(basek); + (1..4).for_each(|rank_in_s0s1| { + (1..4).for_each(|rank_out_s0s1| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_ct + basek * di; + println!( + "test key_switch_inplace digits: {} ranks: ({},{})", + di, rank_in_s0s1, rank_out_s0s1 + ); + test_keyswitch_inplace( + &module, + basek, + k_ct, + k_ksk, + di, + rank_in_s0s1, + rank_out_s0s1, + 3.2, + ); + }); + }); + }); +} + +#[test] +fn external_product() { + let log_n: usize = 8; + let module: Module = Module::::new(1 << log_n); + let basek: usize = 12; + let k_in: usize = 60; + let digits: usize = k_in.div_ceil(basek); + (1..4).for_each(|rank_in| { + (1..4).for_each(|rank_out| { + (1..digits + 1).for_each(|di| { + let k_ggsw: usize = k_in + basek * di; + println!( + "test external_product digits: {} ranks: ({} {})", + di, rank_in, rank_out + ); + let k_out: usize = k_in; // Better capture noise. + test_external_product( + &module, basek, k_out, k_in, k_ggsw, di, rank_in, rank_out, 3.2, + ); + }); + }); + }); +} + +#[test] +fn external_product_inplace() { + let log_n: usize = 5; + let module: Module = Module::::new(1 << log_n); + let basek: usize = 12; + let k_ct: usize = 60; + let digits: usize = k_ct.div_ceil(basek); + (1..4).for_each(|rank_in| { + (1..4).for_each(|rank_out| { + (1..digits).for_each(|di| { + let k_ggsw: usize = k_ct + basek * di; + println!( + "test external_product_inplace digits: {} ranks: ({} {})", + di, rank_in, rank_out + ); + test_external_product_inplace(&module, basek, k_ct, k_ggsw, di, rank_in, rank_out, 3.2); + }); + }); + }); +} diff --git a/core/src/gglwe/test/gglwe_generic.rs b/core/src/gglwe/test/gglwe_generic.rs new file mode 100644 index 0000000..3b34f5e --- /dev/null +++ b/core/src/gglwe/test/gglwe_generic.rs @@ -0,0 +1,540 @@ +use backend::hal::{ + api::{ + MatZnxAlloc, ScalarZnxAlloc, ScalarZnxAllocBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, + VecZnxAlloc, VecZnxAllocBytes, VecZnxRotateInplace, VecZnxStd, VecZnxSubScalarInplace, VecZnxSwithcDegree, ZnxViewMut, + }, + layouts::{Backend, Module, ScalarZnx, ScalarZnxToMut, ScratchOwned}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, + }, +}; +use sampling::source::Source; + +use crate::{ + GGLWEEncryptSkFamily, GGLWEExecLayoutFamily, GGSWCiphertext, GGSWCiphertextExec, GGSWLayoutFamily, GLWEDecryptFamily, + GLWEExternalProductFamily, GLWEKeyswitchFamily, GLWESecret, GLWESecretExec, GLWESwitchingKey, + GLWESwitchingKeyEncryptSkFamily, GLWESwitchingKeyExec, + noise::{log2_std_noise_gglwe_product, noise_ggsw_product}, +}; + +pub(crate) trait TestModuleFamily = GGLWEEncryptSkFamily + + GLWEDecryptFamily + + MatZnxAlloc + + ScalarZnxAlloc + + ScalarZnxAllocBytes + + VecZnxAllocBytes + + VecZnxSwithcDegree + + VecZnxAddScalarInplace + + VecZnxStd + + VecZnxAlloc + + VecZnxSubScalarInplace; + +pub(crate) trait TestScratchFamily = TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + 
TakeVecZnxImpl + + VecZnxDftAllocBytesImpl + + VecZnxBigAllocBytesImpl + + TakeSvpPPolImpl; + +pub(crate) fn test_encrypt_sk( + module: &Module, + basek: usize, + k_ksk: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + sigma: f64, +) where + Module: TestModuleFamily, + B: TestScratchFamily, +{ + let rows: usize = (k_ksk - digits * basek) / (digits * basek); + + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc(module, basek, k_ksk, rows, digits, rank_in, rank_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::encrypt_sk_scratch_space( + module, basek, k_ksk, rank_in, rank_out, + )); + + let mut sk_in: GLWESecret> = GLWESecret::alloc(module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let sk_out_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_out); + + ksk.encrypt_sk( + module, + &sk_in, + &sk_out, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ksk.key + .assert_noise(module, &sk_out_exec, &sk_in.data, sigma); +} + +pub(crate) fn test_keyswitch( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank_in_s0s1: usize, + rank_out_s0s1: usize, + rank_out_s1s2: usize, + sigma: f64, +) where + Module: + TestModuleFamily + GGLWEEncryptSkFamily + GLWEDecryptFamily + GLWEKeyswitchFamily + GGLWEExecLayoutFamily, + B: TestScratchFamily, +{ + let rows: usize = k_in.div_ceil(basek * digits); + let digits_in: usize = 1; + + let mut ct_gglwe_s0s1: GLWESwitchingKey> = GLWESwitchingKey::alloc( + module, + basek, + k_in, + rows, + digits_in, + rank_in_s0s1, + rank_out_s0s1, + ); + let mut ct_gglwe_s1s2: GLWESwitchingKey> = GLWESwitchingKey::alloc( + module, + basek, + k_ksk, + rows, + digits, + rank_out_s0s1, + rank_out_s1s2, + ); + let mut ct_gglwe_s0s2: GLWESwitchingKey> = GLWESwitchingKey::alloc( + module, + basek, + k_out, + rows, + digits_in, + rank_in_s0s1, + rank_out_s1s2, + ); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch_enc: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::encrypt_sk_scratch_space( + module, + basek, + k_ksk, + rank_in_s0s1 | rank_out_s0s1, + rank_out_s0s1 | rank_out_s1s2, + )); + let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::keyswitch_scratch_space( + module, + basek, + k_out, + k_in, + k_ksk, + digits, + ct_gglwe_s1s2.rank_in(), + ct_gglwe_s1s2.rank_out(), + )); + + let mut sk0: GLWESecret> = GLWESecret::alloc(module, rank_in_s0s1); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1: GLWESecret> = GLWESecret::alloc(module, rank_out_s0s1); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk2: GLWESecret> = GLWESecret::alloc(module, rank_out_s1s2); + sk2.fill_ternary_prob(0.5, &mut source_xs); + let sk2_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk2); + + // gglwe_{s1}(s0) = s0 -> s1 + ct_gglwe_s0s1.encrypt_sk( + module, + &sk0, + &sk1, + &mut source_xa, + &mut source_xe, + sigma, + scratch_enc.borrow(), + ); + + // gglwe_{s2}(s1) -> s1 -> s2 + ct_gglwe_s1s2.encrypt_sk( + module, + &sk1, + &sk2, + &mut source_xa, + &mut source_xe, + sigma, + 
scratch_enc.borrow(), + ); + + let ct_gglwe_s1s2_exec: GLWESwitchingKeyExec, B> = + GLWESwitchingKeyExec::from(module, &ct_gglwe_s1s2, scratch_apply.borrow()); + + // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) + ct_gglwe_s0s2.keyswitch( + module, + &ct_gglwe_s0s1, + &ct_gglwe_s1s2_exec, + scratch_apply.borrow(), + ); + + let max_noise: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek * digits, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank_out_s0s1 as f64, + k_in, + k_ksk, + ); + + ct_gglwe_s0s2 + .key + .assert_noise(module, &sk2_exec, &sk0.data, max_noise + 0.5); +} + +pub(crate) fn test_keyswitch_inplace( + module: &Module, + basek: usize, + k_ct: usize, + k_ksk: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + sigma: f64, +) where + Module: TestModuleFamily + + GLWESwitchingKeyEncryptSkFamily + + GLWEKeyswitchFamily + + GGLWEExecLayoutFamily + + GLWEDecryptFamily, + B: TestScratchFamily, +{ + let rows: usize = k_ct.div_ceil(basek * digits); + let digits_in: usize = 1; + + let mut ct_gglwe_s0s1: GLWESwitchingKey> = + GLWESwitchingKey::alloc(module, basek, k_ct, rows, digits_in, rank_in, rank_out); + let mut ct_gglwe_s1s2: GLWESwitchingKey> = + GLWESwitchingKey::alloc(module, basek, k_ksk, rows, digits, rank_out, rank_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch_enc: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::encrypt_sk_scratch_space( + module, + basek, + k_ksk, + rank_in | rank_out, + rank_out, + )); + let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::keyswitch_inplace_scratch_space( + module, basek, k_ct, k_ksk, digits, rank_out, + )); + + let var_xs: f64 = 0.5; + + let mut sk0: GLWESecret> = GLWESecret::alloc(module, rank_in); + sk0.fill_ternary_prob(var_xs, &mut source_xs); + + let mut sk1: GLWESecret> = GLWESecret::alloc(module, rank_out); + sk1.fill_ternary_prob(var_xs, &mut source_xs); + + let mut sk2: GLWESecret> = GLWESecret::alloc(module, rank_out); + sk2.fill_ternary_prob(var_xs, &mut source_xs); + let sk2_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk2); + + // gglwe_{s1}(s0) = s0 -> s1 + ct_gglwe_s0s1.encrypt_sk( + module, + &sk0, + &sk1, + &mut source_xa, + &mut source_xe, + sigma, + scratch_enc.borrow(), + ); + + // gglwe_{s2}(s1) -> s1 -> s2 + ct_gglwe_s1s2.encrypt_sk( + module, + &sk1, + &sk2, + &mut source_xa, + &mut source_xe, + sigma, + scratch_enc.borrow(), + ); + + let ct_gglwe_s1s2_exec: GLWESwitchingKeyExec, B> = + GLWESwitchingKeyExec::from(module, &ct_gglwe_s1s2, scratch_apply.borrow()); + + // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) + ct_gglwe_s0s1.keyswitch_inplace(module, &ct_gglwe_s1s2_exec, scratch_apply.borrow()); + + let ct_gglwe_s0s2: GLWESwitchingKey> = ct_gglwe_s0s1; + + let max_noise: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek * digits, + var_xs, + var_xs, + 0f64, + sigma * sigma, + 0f64, + rank_out as f64, + k_ct, + k_ksk, + ); + + ct_gglwe_s0s2 + .key + .assert_noise(module, &sk2_exec, &sk0.data, max_noise + 0.5); +} + +pub(crate) fn test_external_product( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ggsw: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + sigma: f64, +) where + Module: TestModuleFamily + + GLWESwitchingKeyEncryptSkFamily + + GLWEExternalProductFamily + + GGSWLayoutFamily + + GLWEDecryptFamily + + VecZnxRotateInplace, + B: 
TestScratchFamily, +{ + let rows: usize = k_in.div_ceil(basek * digits); + let digits_in: usize = 1; + + let mut ct_gglwe_in: GLWESwitchingKey> = + GLWESwitchingKey::alloc(module, basek, k_in, rows, digits_in, rank_in, rank_out); + let mut ct_gglwe_out: GLWESwitchingKey> = + GLWESwitchingKey::alloc(module, basek, k_out, rows, digits_in, rank_in, rank_out); + let mut ct_rgsw: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_ggsw, rows, digits, rank_out); + + let mut pt_rgsw: ScalarZnx> = module.scalar_znx_alloc(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_in, rank_in, rank_out) + | GLWESwitchingKey::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank_out) + | GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank_out), + ); + + let r: usize = 1; + + pt_rgsw.to_mut().raw_mut()[r] = 1; // X^{r} + + let var_xs: f64 = 0.5; + + let mut sk_in: GLWESecret> = GLWESecret::alloc(module, rank_in); + sk_in.fill_ternary_prob(var_xs, &mut source_xs); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(module, rank_out); + sk_out.fill_ternary_prob(var_xs, &mut source_xs); + let sk_out_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_out); + + // gglwe_{s1}(s0) = s0 -> s1 + ct_gglwe_in.encrypt_sk( + module, + &sk_in, + &sk_out, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_rgsw.encrypt_sk( + module, + &pt_rgsw, + &sk_out_exec, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut ct_rgsw_exec: GGSWCiphertextExec, B> = + GGSWCiphertextExec::alloc(module, basek, k_ggsw, rows, digits, rank_out); + + ct_rgsw_exec.prepare(module, &ct_rgsw, scratch.borrow()); + + // gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) + ct_gglwe_out.external_product(module, &ct_gglwe_in, &ct_rgsw_exec, scratch.borrow()); + + (0..rank_in).for_each(|i| { + module.vec_znx_rotate_inplace(r as i64, &mut sk_in.data, i); // * X^{r} + }); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let max_noise: f64 = noise_ggsw_product( + module.n() as f64, + basek * digits, + var_xs, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank_out as f64, + k_in, + k_ggsw, + ); + + ct_gglwe_out + .key + .assert_noise(module, &sk_out_exec, &sk_in.data, max_noise + 0.5); +} + +pub(crate) fn test_external_product_inplace( + module: &Module, + basek: usize, + k_ct: usize, + k_ggsw: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + sigma: f64, +) where + Module: TestModuleFamily + + GLWESwitchingKeyEncryptSkFamily + + GLWEExternalProductFamily + + GGSWLayoutFamily + + GLWEDecryptFamily + + VecZnxRotateInplace, + B: TestScratchFamily, +{ + let rows: usize = k_ct.div_ceil(basek * digits); + + let digits_in: usize = 1; + + let mut ct_gglwe: GLWESwitchingKey> = + GLWESwitchingKey::alloc(module, basek, k_ct, rows, digits_in, rank_in, rank_out); + let mut ct_rgsw: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_ggsw, rows, digits, rank_out); + + let mut pt_rgsw: ScalarZnx> = module.scalar_znx_alloc(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = 
Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ct, rank_in, rank_out) + | GLWESwitchingKey::external_product_inplace_scratch_space(module, basek, k_ct, k_ggsw, digits, rank_out) + | GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank_out), + ); + + let r: usize = 1; + + pt_rgsw.to_mut().raw_mut()[r] = 1; // X^{r} + + let var_xs: f64 = 0.5; + + let mut sk_in: GLWESecret> = GLWESecret::alloc(module, rank_in); + sk_in.fill_ternary_prob(var_xs, &mut source_xs); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(module, rank_out); + sk_out.fill_ternary_prob(var_xs, &mut source_xs); + let sk_out_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_out); + + // gglwe_{s1}(s0) = s0 -> s1 + ct_gglwe.encrypt_sk( + module, + &sk_in, + &sk_out, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_rgsw.encrypt_sk( + module, + &pt_rgsw, + &sk_out_exec, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut ct_rgsw_exec: GGSWCiphertextExec, B> = + GGSWCiphertextExec::alloc(module, basek, k_ggsw, rows, digits, rank_out); + + ct_rgsw_exec.prepare(module, &ct_rgsw, scratch.borrow()); + + // gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) + ct_gglwe.external_product_inplace(module, &ct_rgsw_exec, scratch.borrow()); + + (0..rank_in).for_each(|i| { + module.vec_znx_rotate_inplace(r as i64, &mut sk_in.data, i); // * X^{r} + }); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let max_noise: f64 = noise_ggsw_product( + module.n() as f64, + basek * digits, + var_xs, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank_out as f64, + k_ct, + k_ggsw, + ); + + ct_gglwe + .key + .assert_noise(module, &sk_out_exec, &sk_in.data, max_noise + 0.5); +} diff --git a/core/src/gglwe/test/mod.rs b/core/src/gglwe/test/mod.rs new file mode 100644 index 0000000..6c46bd5 --- /dev/null +++ b/core/src/gglwe/test/mod.rs @@ -0,0 +1,5 @@ +mod automorphism_key; +mod gglwe_fft64; +mod gglwe_generic; +mod tensor_key_fft64; +mod tensor_key_generic; diff --git a/core/src/gglwe/test/tensor_key_fft64.rs b/core/src/gglwe/test/tensor_key_fft64.rs new file mode 100644 index 0000000..d610928 --- /dev/null +++ b/core/src/gglwe/test/tensor_key_fft64.rs @@ -0,0 +1,16 @@ +use backend::{ + hal::{api::ModuleNew, layouts::Module}, + implementation::cpu_spqlios::FFT64, +}; + +use crate::gglwe::test::tensor_key_generic::test_encrypt_sk; + +#[test] +fn encrypt_sk() { + let log_n: usize = 8; + let module: Module = Module::::new(1 << log_n); + (1..4).for_each(|rank| { + println!("test encrypt_sk rank: {}", rank); + test_encrypt_sk(&module, 16, 54, 3.2, rank); + }); +} diff --git a/core/src/gglwe/test/tensor_key_generic.rs b/core/src/gglwe/test/tensor_key_generic.rs new file mode 100644 index 0000000..fb524bf --- /dev/null +++ b/core/src/gglwe/test/tensor_key_generic.rs @@ -0,0 +1,113 @@ +use backend::hal::{ + api::{ + MatZnxAlloc, ScalarZnxAlloc, ScalarZnxAllocBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, + VecZnxAlloc, VecZnxAllocBytes, VecZnxBigAlloc, VecZnxDftAlloc, VecZnxStd, VecZnxSubScalarInplace, VecZnxSwithcDegree, + }, + layouts::{Backend, Module, ScratchOwned, VecZnxDft}, + oep::{ + ScratchAvailableImpl, 
ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, + }, +}; +use sampling::source::Source; + +use crate::{ + GGLWEEncryptSkFamily, GGLWEExecLayoutFamily, GLWEDecryptFamily, GLWEPlaintext, GLWESecret, GLWESecretExec, GLWETensorKey, + GLWETensorKeyEncryptSkFamily, Infos, +}; + +pub(crate) trait TestModuleFamily = GGLWEEncryptSkFamily + + GLWEDecryptFamily + + MatZnxAlloc + + ScalarZnxAlloc + + ScalarZnxAllocBytes + + VecZnxAllocBytes + + VecZnxSwithcDegree + + VecZnxAddScalarInplace + + VecZnxStd + + VecZnxAlloc + + VecZnxSubScalarInplace; + +pub(crate) trait TestScratchFamily = TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl + + VecZnxDftAllocBytesImpl + + VecZnxBigAllocBytesImpl + + TakeSvpPPolImpl; + +pub(crate) fn test_encrypt_sk(module: &Module, basek: usize, k: usize, sigma: f64, rank: usize) +where + Module: TestModuleFamily + + GGLWEExecLayoutFamily + + GLWETensorKeyEncryptSkFamily + + GLWEDecryptFamily + + VecZnxDftAlloc + + VecZnxBigAlloc, + B: TestScratchFamily, +{ + let rows: usize = k / basek; + + let mut tensor_key: GLWETensorKey> = GLWETensorKey::alloc(&module, basek, k, rows, 1, rank); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWETensorKey::encrypt_sk_scratch_space( + module, + basek, + tensor_key.k(), + rank, + )); + + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_exec: GLWESecretExec, B> = GLWESecretExec::from(&module, &sk); + sk_exec.prepare(module, &sk); + + tensor_key.encrypt_sk( + module, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); + + let mut sk_ij_dft = module.vec_znx_dft_alloc(1, 1); + let mut sk_ij_big = module.vec_znx_big_alloc(1, 1); + let mut sk_ij: GLWESecret> = GLWESecret::alloc(&module, 1); + let mut sk_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(rank, 1); + + (0..rank).for_each(|i| { + module.vec_znx_dft_from_vec_znx(1, 0, &mut sk_dft, i, &sk.data, i); + }); + + (0..rank).for_each(|i| { + (0..rank).for_each(|j| { + module.svp_apply(&mut sk_ij_dft, 0, &sk_exec.data, j, &sk_dft, i); + module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); + module.vec_znx_big_normalize(basek, &mut sk_ij.data, 0, &sk_ij_big, 0, scratch.borrow()); + (0..tensor_key.rank_in()).for_each(|col_i| { + (0..tensor_key.rows()).for_each(|row_i| { + tensor_key + .at(i, j) + .at(row_i, col_i) + .decrypt(&module, &mut pt, &sk_exec, scratch.borrow()); + + module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, col_i); + + let std_pt: f64 = module.vec_znx_std(basek, &pt.data, 0) * (k as f64).exp2(); + assert!((sigma - std_pt).abs() <= 0.5, "{} {}", sigma, std_pt); + }); + }); + }) + }) +} diff --git a/core/src/gglwe/test_fft64/gglwe.rs b/core/src/gglwe/test_fft64/gglwe.rs deleted file mode 100644 index 492d3b8..0000000 --- a/core/src/gglwe/test_fft64/gglwe.rs +++ /dev/null @@ -1,680 +0,0 @@ -use backend::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxToMut, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; -use sampling::source::Source; - 
-use crate::{ - FourierGLWECiphertext, FourierGLWESecret, GGSWCiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, GetRow, Infos, - noise::{log2_std_noise_gglwe_product, noise_ggsw_product}, -}; - -#[test] -fn encrypt_sk() { - let log_n: usize = 8; - let basek: usize = 12; - let k_ksk: usize = 54; - let digits: usize = k_ksk / basek; - (1..4).for_each(|rank_in| { - (1..4).for_each(|rank_out| { - (1..digits + 1).for_each(|di| { - println!( - "test encrypt_sk digits: {} ranks: ({} {})", - di, rank_in, rank_out - ); - test_encrypt_sk(log_n, basek, k_ksk, di, rank_in, rank_out, 3.2); - }); - }); - }); -} - -#[test] -fn key_switch() { - let log_n: usize = 8; - let basek: usize = 12; - let k_in: usize = 60; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank_in_s0s1| { - (1..4).for_each(|rank_out_s0s1| { - (1..4).for_each(|rank_out_s1s2| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_in + basek * di; - println!( - "test key_switch digits: {} ranks: ({},{},{})", - di, rank_in_s0s1, rank_out_s0s1, rank_out_s1s2 - ); - let k_out: usize = k_ksk; // Better capture noise. - test_key_switch( - log_n, - basek, - k_out, - k_in, - k_ksk, - di, - rank_in_s0s1, - rank_out_s0s1, - rank_out_s1s2, - 3.2, - ); - }) - }) - }); - }); -} - -#[test] -fn key_switch_inplace() { - let log_n: usize = 8; - let basek: usize = 12; - let k_ct: usize = 60; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank_in_s0s1| { - (1..4).for_each(|rank_out_s0s1| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_ct + basek * di; - println!( - "test key_switch_inplace digits: {} ranks: ({},{})", - di, rank_in_s0s1, rank_out_s0s1 - ); - test_key_switch_inplace( - log_n, - basek, - k_ct, - k_ksk, - di, - rank_in_s0s1, - rank_out_s0s1, - 3.2, - ); - }); - }); - }); -} - -#[test] -fn external_product() { - let log_n: usize = 8; - let basek: usize = 12; - let k_in: usize = 60; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank_in| { - (1..4).for_each(|rank_out| { - (1..digits + 1).for_each(|di| { - let k_ggsw: usize = k_in + basek * di; - println!( - "test external_product digits: {} ranks: ({} {})", - di, rank_in, rank_out - ); - let k_out: usize = k_in; // Better capture noise. 
- test_external_product( - log_n, basek, k_out, k_in, k_ggsw, di, rank_in, rank_out, 3.2, - ); - }); - }); - }); -} - -#[test] -fn external_product_inplace() { - let log_n: usize = 5; - let basek: usize = 12; - let k_ct: usize = 60; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank_in| { - (1..4).for_each(|rank_out| { - (1..digits).for_each(|di| { - let k_ggsw: usize = k_ct + basek * di; - println!( - "test external_product_inplace digits: {} ranks: ({} {})", - di, rank_in, rank_out - ); - test_external_product_inplace(log_n, basek, k_ct, k_ggsw, di, rank_in, rank_out, 3.2); - }); - }); - }); -} - -fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, digits: usize, rank_in: usize, rank_out: usize, sigma: f64) { - let module: Module = Module::::new(1 << log_n); - let rows: usize = (k_ksk - digits * basek) / (digits * basek); - - let mut ksk: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank_in, rank_out); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ksk); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_in, rank_out) - | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ksk), - ); - - let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank_in); - sk_in.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank_out); - sk_out.fill_ternary_prob(0.5, &mut source_xs); - let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); - - ksk.encrypt_sk( - &module, - &sk_in, - &sk_out_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = - FourierGLWECiphertext::alloc(&module, basek, k_ksk, rank_out); - - (0..ksk.rank_in()).for_each(|col_i| { - (0..ksk.rows()).for_each(|row_i| { - ksk.get_row(&module, row_i, col_i, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace( - &mut pt.data, - 0, - (digits - 1) + row_i * digits, - &sk_in.data, - col_i, - ); - let std_pt: f64 = pt.data.std(0, basek) * (k_ksk as f64).exp2(); - assert!((sigma - std_pt).abs() <= 0.5, "{} {}", sigma, std_pt); - }); - }); -} - -fn test_key_switch( - log_n: usize, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank_in_s0s1: usize, - rank_out_s0s1: usize, - rank_out_s1s2: usize, - sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - let rows: usize = k_in.div_ceil(basek * digits); - let digits_in: usize = 1; - - let mut ct_gglwe_s0s1: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc( - &module, - basek, - k_in, - rows, - digits_in, - rank_in_s0s1, - rank_out_s0s1, - ); - let mut ct_gglwe_s1s2: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc( - &module, - basek, - k_ksk, - rows, - digits, - rank_out_s0s1, - rank_out_s1s2, - ); - let mut ct_gglwe_s0s2: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc( - &module, - basek, - k_out, - rows, - digits_in, - rank_in_s0s1, - rank_out_s1s2, - ); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = 
ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space( - &module, - basek, - k_ksk, - rank_in_s0s1, - rank_in_s0s1 | rank_out_s0s1, - ) | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) - | GLWESwitchingKey::keyswitch_scratch_space( - &module, - basek, - k_out, - k_in, - k_ksk, - digits, - ct_gglwe_s1s2.rank_in(), - ct_gglwe_s1s2.rank_out(), - ), - ); - - let mut sk0: GLWESecret> = GLWESecret::alloc(&module, rank_in_s0s1); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1: GLWESecret> = GLWESecret::alloc(&module, rank_out_s0s1); - sk1.fill_ternary_prob(0.5, &mut source_xs); - let sk1_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk1); - - let mut sk2: GLWESecret> = GLWESecret::alloc(&module, rank_out_s1s2); - sk2.fill_ternary_prob(0.5, &mut source_xs); - let sk2_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk2); - - // gglwe_{s1}(s0) = s0 -> s1 - ct_gglwe_s0s1.encrypt_sk( - &module, - &sk0, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - // gglwe_{s2}(s1) -> s1 -> s2 - ct_gglwe_s1s2.encrypt_sk( - &module, - &sk1, - &sk2_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) - ct_gglwe_s0s2.keyswitch(&module, &ct_gglwe_s0s1, &ct_gglwe_s1s2, scratch.borrow()); - - let mut ct_glwe_dft: FourierGLWECiphertext, FFT64> = - FourierGLWECiphertext::alloc(&module, basek, k_out, rank_out_s1s2); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); - - (0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| { - (0..ct_gglwe_s0s2.rows()).for_each(|row_i| { - ct_gglwe_s0s2.get_row(&module, row_i, col_i, &mut ct_glwe_dft); - ct_glwe_dft.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace( - &mut pt.data, - 0, - (digits_in - 1) + row_i * digits_in, - &sk0.data, - col_i, - ); - - let noise_have: f64 = pt.data.std(0, basek).log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( - module.n() as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - rank_out_s0s1 as f64, - k_in, - k_ksk, - ); - - assert!( - (noise_have - noise_want).abs() <= 1.0, - "{} {}", - noise_have, - noise_want - ); - }); - }); -} - -fn test_key_switch_inplace( - log_n: usize, - basek: usize, - k_ct: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - let rows: usize = k_ct.div_ceil(basek * digits); - let digits_in: usize = 1; - - let mut ct_gglwe_s0s1: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::alloc(&module, basek, k_ct, rows, digits_in, rank_in, rank_out); - let mut ct_gglwe_s1s2: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank_out, rank_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_in, rank_out) - | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ksk) - | GLWESwitchingKey::keyswitch_inplace_scratch_space(&module, basek, k_ct, k_ksk, digits, rank_out), - ); - - let mut sk0: GLWESecret> = GLWESecret::alloc(&module, rank_in); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1: GLWESecret> = GLWESecret::alloc(&module, rank_out); - 
sk1.fill_ternary_prob(0.5, &mut source_xs); - let sk1_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk1); - - let mut sk2: GLWESecret> = GLWESecret::alloc(&module, rank_out); - sk2.fill_ternary_prob(0.5, &mut source_xs); - let sk2_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk2); - - // gglwe_{s1}(s0) = s0 -> s1 - ct_gglwe_s0s1.encrypt_sk( - &module, - &sk0, - &sk1_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - // gglwe_{s2}(s1) -> s1 -> s2 - ct_gglwe_s1s2.encrypt_sk( - &module, - &sk1, - &sk2_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) - ct_gglwe_s0s1.keyswitch_inplace(&module, &ct_gglwe_s1s2, scratch.borrow()); - - let ct_gglwe_s0s2: GLWESwitchingKey, FFT64> = ct_gglwe_s0s1; - - let mut ct_glwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank_out); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - - (0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| { - (0..ct_gglwe_s0s2.rows()).for_each(|row_i| { - ct_gglwe_s0s2.get_row(&module, row_i, col_i, &mut ct_glwe_dft); - ct_glwe_dft.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace( - &mut pt.data, - 0, - (digits_in - 1) + row_i * digits_in, - &sk0.data, - col_i, - ); - - let noise_have: f64 = pt.data.std(0, basek).log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( - module.n() as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - rank_out as f64, - k_ct, - k_ksk, - ); - - assert!( - (noise_have - noise_want).abs() <= 1.0, - "{} {}", - noise_have, - noise_want - ); - }); - }); -} - -fn test_external_product( - log_n: usize, - basek: usize, - k_out: usize, - k_in: usize, - k_ggsw: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k_in.div_ceil(basek * digits); - let digits_in: usize = 1; - - let mut ct_gglwe_in: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::alloc(&module, basek, k_in, rows, digits_in, rank_in, rank_out); - let mut ct_gglwe_out: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::alloc(&module, basek, k_out, rows, digits_in, rank_in, rank_out); - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank_out); - - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_in, rank_in, rank_out) - | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) - | GLWESwitchingKey::external_product_scratch_space(&module, basek, k_out, k_in, k_ggsw, digits, rank_out) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ggsw, rank_out), - ); - - let r: usize = 1; - - pt_rgsw.to_mut().raw_mut()[r] = 1; // X^{r} - - let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank_in); - sk_in.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank_out); - sk_out.fill_ternary_prob(0.5, &mut source_xs); - let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); - - // gglwe_{s1}(s0) = s0 -> s1 - ct_gglwe_in.encrypt_sk( - &module, - 
&sk_in, - &sk_out_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_out_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - // gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) - ct_gglwe_out.external_product(&module, &ct_gglwe_in, &ct_rgsw, scratch.borrow()); - - let mut ct_glwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_out, rank_out); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); - - (0..rank_in).for_each(|i| { - module.vec_znx_rotate_inplace(r as i64, &mut sk_in.data, i); // * X^{r} - }); - - (0..rank_in).for_each(|col_i| { - (0..ct_gglwe_out.rows()).for_each(|row_i| { - ct_gglwe_out.get_row(&module, row_i, col_i, &mut ct_glwe_dft); - ct_glwe_dft.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow()); - - module.vec_znx_sub_scalar_inplace( - &mut pt.data, - 0, - (digits_in - 1) + row_i * digits_in, - &sk_in.data, - col_i, - ); - - let noise_have: f64 = pt.data.std(0, basek).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_ggsw_product( - module.n() as f64, - basek * digits, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - rank_out as f64, - k_in, - k_ggsw, - ); - - assert!( - (noise_have - noise_want).abs() <= 1.0, - "{} {}", - noise_have, - noise_want - ); - }); - }); -} - -fn test_external_product_inplace( - log_n: usize, - basek: usize, - k_ct: usize, - k_ggsw: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k_ct.div_ceil(basek * digits); - - let digits_in: usize = 1; - - let mut ct_gglwe: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::alloc(&module, basek, k_ct, rows, digits_in, rank_in, rank_out); - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank_out); - - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ct, rank_in, rank_out) - | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) - | GLWESwitchingKey::external_product_inplace_scratch_space(&module, basek, k_ct, k_ggsw, digits, rank_out) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ggsw, rank_out), - ); - - let r: usize = 1; - - pt_rgsw.to_mut().raw_mut()[r] = 1; // X^{r} - - let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank_in); - sk_in.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank_out); - sk_out.fill_ternary_prob(0.5, &mut source_xs); - let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); - - // gglwe_{s1}(s0) = s0 -> s1 - ct_gglwe.encrypt_sk( - &module, - &sk_in, - &sk_out_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_rgsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_out_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - // gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) - 
ct_gglwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); - - let mut ct_glwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank_out); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - - (0..rank_in).for_each(|i| { - module.vec_znx_rotate_inplace(r as i64, &mut sk_in.data, i); // * X^{r} - }); - - (0..rank_in).for_each(|col_i| { - (0..ct_gglwe.rows()).for_each(|row_i| { - ct_gglwe.get_row(&module, row_i, col_i, &mut ct_glwe_dft); - ct_glwe_dft.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow()); - - module.vec_znx_sub_scalar_inplace( - &mut pt.data, - 0, - (digits_in - 1) + row_i * digits_in, - &sk_in.data, - col_i, - ); - - let noise_have: f64 = pt.data.std(0, basek).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_ggsw_product( - module.n() as f64, - basek * digits, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - rank_out as f64, - k_ct, - k_ggsw, - ); - - assert!( - (noise_have - noise_want).abs() <= 1.0, - "{} {}", - noise_have, - noise_want - ); - }); - }); -} diff --git a/core/src/gglwe/test_fft64/mod.rs b/core/src/gglwe/test_fft64/mod.rs deleted file mode 100644 index 49d23cd..0000000 --- a/core/src/gglwe/test_fft64/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod automorphism_key; -pub mod gglwe; -pub mod tensor_key; diff --git a/core/src/gglwe/test_fft64/tensor_key.rs b/core/src/gglwe/test_fft64/tensor_key.rs deleted file mode 100644 index ab1d191..0000000 --- a/core/src/gglwe/test_fft64/tensor_key.rs +++ /dev/null @@ -1,69 +0,0 @@ -use backend::{FFT64, Module, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxOps}; -use sampling::source::Source; - -use crate::{FourierGLWECiphertext, FourierGLWESecret, GLWEPlaintext, GLWESecret, GLWETensorKey, GetRow, Infos}; - -#[test] -fn encrypt_sk() { - let log_n: usize = 8; - (1..4).for_each(|rank| { - println!("test encrypt_sk rank: {}", rank); - test_encrypt_sk(log_n, 16, 54, 3.2, rank); - }); -} - -fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k / basek; - - let mut tensor_key: GLWETensorKey, FFT64> = GLWETensorKey::alloc(&module, basek, k, rows, 1, rank); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new(GLWETensorKey::encrypt_sk_scratch_space( - &module, - basek, - tensor_key.k(), - rank, - )); - - let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); - - tensor_key.encrypt_sk( - &module, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); - - let mut sk_ij_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::alloc(&module, 1); - let mut sk_ij: GLWESecret> = GLWESecret::alloc(&module, 1); - - (0..rank).for_each(|i| { - (0..rank).for_each(|j| { - module.svp_apply(&mut sk_ij_dft.data, 0, &sk_dft.data, i, 
&sk_dft.data, j); - module.scalar_znx_idft(&mut sk_ij.data, 0, &sk_ij_dft.data, 0, scratch.borrow()); - (0..tensor_key.rank_in()).for_each(|col_i| { - (0..tensor_key.rows()).for_each(|row_i| { - tensor_key - .at(i, j) - .get_row(&module, row_i, col_i, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, col_i); - let std_pt: f64 = pt.data.std(0, basek) * (k as f64).exp2(); - assert!((sigma - std_pt).abs() <= 0.5, "{} {}", sigma, std_pt); - }); - }); - }) - }) -} diff --git a/core/src/ggsw/automorphism.rs b/core/src/ggsw/automorphism.rs new file mode 100644 index 0000000..c983976 --- /dev/null +++ b/core/src/ggsw/automorphism.rs @@ -0,0 +1,147 @@ +use backend::hal::{ + api::{ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxNormalizeTmpBytes}, + layouts::{Backend, DataMut, DataRef, Module, Scratch}, +}; + +use crate::{ + AutomorphismKeyExec, GGSWCiphertext, GGSWKeySwitchFamily, GLWECiphertext, GLWEKeyswitchFamily, GLWETensorKeyExec, Infos, +}; + +impl GGSWCiphertext> { + pub fn automorphism_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits_ksk: usize, + k_tsk: usize, + digits_tsk: usize, + rank: usize, + ) -> usize + where + Module: GLWEKeyswitchFamily + GGSWKeySwitchFamily + VecZnxNormalizeTmpBytes, + { + let out_size: usize = k_out.div_ceil(basek); + let ci_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, out_size); + let ks_internal: usize = + GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits_ksk, rank, rank); + let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, k_out, k_tsk, digits_tsk, rank); + ci_dft + (ks_internal | expand) + } + + pub fn automorphism_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_ksk: usize, + digits_ksk: usize, + k_tsk: usize, + digits_tsk: usize, + rank: usize, + ) -> usize + where + Module: GLWEKeyswitchFamily + GGSWKeySwitchFamily + VecZnxNormalizeTmpBytes, + { + GGSWCiphertext::automorphism_scratch_space( + module, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank, + ) + } +} + +impl GGSWCiphertext { + pub fn automorphism( + &mut self, + module: &Module, + lhs: &GGSWCiphertext, + auto_key: &AutomorphismKeyExec, + tensor_key: &GLWETensorKeyExec, + scratch: &mut Scratch, + ) where + Module: GLWEKeyswitchFamily + GGSWKeySwitchFamily + VecZnxAutomorphismInplace + VecZnxNormalizeTmpBytes, + Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig, + { + #[cfg(debug_assertions)] + { + use crate::Infos; + + assert_eq!( + self.rank(), + lhs.rank(), + "ggsw_out rank: {} != ggsw_in rank: {}", + self.rank(), + lhs.rank() + ); + assert_eq!( + self.rank(), + auto_key.rank(), + "ggsw_in rank: {} != auto_key rank: {}", + self.rank(), + auto_key.rank() + ); + assert_eq!( + self.rank(), + tensor_key.rank(), + "ggsw_in rank: {} != tensor_key rank: {}", + self.rank(), + tensor_key.rank() + ); + assert!( + scratch.available() + >= GGSWCiphertext::automorphism_scratch_space( + module, + self.basek(), + self.k(), + lhs.k(), + auto_key.k(), + auto_key.digits(), + tensor_key.k(), + tensor_key.digits(), + self.rank(), + ) + ) + }; + + let rank: usize = self.rank(); + let cols: usize = rank + 1; + + // Keyswitch the j-th row of the col 0 + (0..lhs.rows()).for_each(|row_i| { + // Key-switch column 0, i.e. 
+ // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) + self.at_mut(row_i, 0) + .automorphism(module, &lhs.at(row_i, 0), auto_key, scratch); + + // Isolates DFT(AUTO(a[i])) + let (mut ci_dft, scratch1) = scratch.take_vec_znx_dft(module, cols, self.size()); + (0..cols).for_each(|i| { + module.vec_znx_dft_from_vec_znx(1, 0, &mut ci_dft, i, &self.at(row_i, 0).data, i); + }); + + // Generates + // + // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + pi(M[i]), b1 , b2 ) + // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + pi(M[i]), c2 ) + // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + pi(M[i])) + (1..cols).for_each(|col_j| { + self.expand_row(module, row_i, col_j, &ci_dft, tensor_key, scratch1); + }); + }) + } + + pub fn automorphism_inplace( + &mut self, + module: &Module, + auto_key: &AutomorphismKeyExec, + tensor_key: &GLWETensorKeyExec, + scratch: &mut Scratch, + ) where + Module: GLWEKeyswitchFamily + GGSWKeySwitchFamily + VecZnxAutomorphismInplace + VecZnxNormalizeTmpBytes, + Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig, + { + unsafe { + let self_ptr: *mut GGSWCiphertext = self as *mut GGSWCiphertext; + self.automorphism(module, &*self_ptr, auto_key, tensor_key, scratch); + } + } +} diff --git a/core/src/ggsw/ciphertext.rs b/core/src/ggsw/ciphertext.rs deleted file mode 100644 index abeedac..0000000 --- a/core/src/ggsw/ciphertext.rs +++ /dev/null @@ -1,707 +0,0 @@ -use backend::{ - Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnx, Scratch, VecZnxAlloc, - VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxInfos, - ZnxZero, -}; -use sampling::source::Source; - -use crate::{ - FourierGLWECiphertext, FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWESwitchingKey, GLWETensorKey, GetRow, - Infos, ScratchCore, SetRow, -}; - -pub struct GGSWCiphertext { - pub(crate) data: MatZnxDft, - pub(crate) basek: usize, - pub(crate) k: usize, - pub(crate) digits: usize, -} - -impl GGSWCiphertext, FFT64> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { - let size: usize = k.div_ceil(basek); - debug_assert!(digits > 0, "invalid ggsw: `digits` == 0"); - - debug_assert!( - size > digits, - "invalid ggsw: ceil(k/basek): {} <= digits: {}", - size, - digits - ); - - assert!( - rows * digits <= size, - "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}", - rows, - digits, - size - ); - - Self { - data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, k.div_ceil(basek)), - basek, - k: k, - digits, - } - } - - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { - let size: usize = k.div_ceil(basek); - debug_assert!( - size > digits, - "invalid ggsw: ceil(k/basek): {} <= digits: {}", - size, - digits - ); - - assert!( - rows * digits <= size, - "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}", - rows, - digits, - size - ); - - module.bytes_of_mat_znx_dft(rows, rank + 1, rank + 1, size) - } -} - -impl Infos for GGSWCiphertext { - type Inner = MatZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { - self.k - } -} - -impl GGSWCiphertext { - pub fn rank(&self) -> usize { - self.data.cols_out() - 1 - } - - pub fn digits(&self) -> usize { - self.digits - } -} - -impl GGSWCiphertext, FFT64> { - pub fn 
encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - let size = k.div_ceil(basek); - GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) - + module.bytes_of_vec_znx(rank + 1, size) - + module.bytes_of_vec_znx(1, size) - + module.bytes_of_vec_znx_dft(rank + 1, size) - } - - pub(crate) fn expand_row_scratch_space( - module: &Module, - basek: usize, - self_k: usize, - k_tsk: usize, - digits: usize, - rank: usize, - ) -> usize { - let tsk_size: usize = k_tsk.div_ceil(basek); - let self_size_out: usize = self_k.div_ceil(basek); - let self_size_in: usize = self_size_out.div_ceil(digits); - let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tsk_size); - let tmp_a: usize = module.bytes_of_vec_znx_dft(1, self_size_in); - let vmp: usize = module.vmp_apply_tmp_bytes( - self_size_out, - self_size_in, - self_size_in, - rank, - rank, - tsk_size, - ); - let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tsk_size); - let norm: usize = module.vec_znx_big_normalize_tmp_bytes(); - tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm)) - } - - pub(crate) fn keyswitch_internal_col0_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize { - GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) - + module.bytes_of_vec_znx_dft(rank + 1, k_in.div_ceil(basek)) - } - - pub fn keyswitch_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits_ksk: usize, - k_tsk: usize, - digits_tsk: usize, - rank: usize, - ) -> usize { - let out_size: usize = k_out.div_ceil(basek); - let res_znx: usize = module.bytes_of_vec_znx(rank + 1, out_size); - let ci_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); - let ks: usize = - GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, basek, k_out, k_in, k_ksk, digits_ksk, rank); - let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, k_out, k_tsk, digits_tsk, rank); - let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); - res_znx + ci_dft + (ks | expand_rows | res_dft) - } - - pub fn keyswitch_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ksk: usize, - digits_ksk: usize, - k_tsk: usize, - digits_tsk: usize, - rank: usize, - ) -> usize { - GGSWCiphertext::keyswitch_scratch_space( - module, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank, - ) - } - - pub fn automorphism_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits_ksk: usize, - k_tsk: usize, - digits_tsk: usize, - rank: usize, - ) -> usize { - let cols: usize = rank + 1; - let out_size: usize = k_out.div_ceil(basek); - let res: usize = module.bytes_of_vec_znx(cols, out_size); - let res_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size); - let ci_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size); - let ks_internal: usize = - GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, basek, k_out, k_in, k_ksk, digits_ksk, rank); - let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, k_out, k_tsk, digits_tsk, rank); - res + ci_dft + (ks_internal | expand | res_dft) - } - - pub fn automorphism_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ksk: usize, - digits_ksk: usize, - k_tsk: usize, - digits_tsk: usize, - rank: usize, - ) -> usize { - 
GGSWCiphertext::automorphism_scratch_space( - module, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank, - ) - } - - pub fn external_product_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ggsw: usize, - digits: usize, - rank: usize, - ) -> usize { - let tmp_in: usize = FourierGLWECiphertext::bytes_of(module, basek, k_in, rank); - let tmp_out: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); - let ggsw: usize = FourierGLWECiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank); - tmp_in + tmp_out + ggsw - } - - pub fn external_product_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ggsw: usize, - digits: usize, - rank: usize, - ) -> usize { - let tmp: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); - let ggsw: usize = - FourierGLWECiphertext::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank); - tmp + ggsw - } -} - -impl + AsRef<[u8]>> GGSWCiphertext { - pub fn encrypt_sk, DataSk: AsRef<[u8]>>( - &mut self, - module: &Module, - pt: &ScalarZnx, - sk: &FourierGLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), sk.rank()); - assert_eq!(self.n(), module.n()); - assert_eq!(pt.n(), module.n()); - assert_eq!(sk.n(), module.n()); - } - - let basek: usize = self.basek(); - let k: usize = self.k(); - let rank: usize = self.rank(); - let digits: usize = self.digits(); - - let (mut tmp_pt, scratch1) = scratch.tmp_glwe_pt(module, basek, k); - let (mut tmp_ct, scratch2) = scratch1.tmp_glwe_ct(module, basek, k, rank); - - (0..self.rows()).for_each(|row_i| { - tmp_pt.data.zero(); - - // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (digits - 1) + row_i * digits, pt, 0); - module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch2); - - (0..rank + 1).for_each(|col_j| { - // rlwe encrypt of vec_znx_pt into vec_znx_ct - - tmp_ct.encrypt_sk_private( - module, - Some((&tmp_pt, col_j)), - sk, - source_xa, - source_xe, - sigma, - scratch2, - ); - - // Switch vec_znx_ct into DFT domain - { - let (mut tmp_ct_dft, _) = scratch2.tmp_fourier_glwe_ct(module, basek, k, rank); - tmp_ct.dft(module, &mut tmp_ct_dft); - self.set_row(module, row_i, col_j, &tmp_ct_dft); - } - }); - }); - } - - pub(crate) fn expand_row, DataTsk: AsRef<[u8]>>( - &mut self, - module: &Module, - col_j: usize, - res: &mut R, - ci_dft: &VecZnxDft, - tsk: &GLWETensorKey, - scratch: &mut Scratch, - ) where - R: VecZnxToMut, - { - let cols: usize = self.rank() + 1; - - assert!( - scratch.available() - >= GGSWCiphertext::expand_row_scratch_space( - module, - self.basek(), - self.k(), - tsk.k(), - tsk.digits(), - tsk.rank() - ) - ); - - // Example for rank 3: - // - // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is - // actually composed of that many rows and we focus on a specific row here - // implicitely given ci_dft. 
- // - // # Input - // - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) - // col 1: (0, 0, 0, 0) - // col 2: (0, 0, 0, 0) - // col 3: (0, 0, 0, 0) - // - // # Output - // - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) - // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 ) - // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) - // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) - - let digits: usize = tsk.digits(); - - let (mut tmp_dft_i, scratch1) = scratch.tmp_vec_znx_dft(module, cols, tsk.size()); - let (mut tmp_a, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, (ci_dft.size() + digits - 1) / digits); - - { - // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 - // - // # Example for col=1 - // - // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) - // + - // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) - // + - // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) - // = - // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) - (1..cols).for_each(|col_i| { - let pmat: &MatZnxDft = &tsk.at(col_i - 1, col_j - 1).key.data; // Selects Enc(s[i]s[j]) - - // Extracts a[i] and multipies with Enc(s[i]s[j]) - (0..digits).for_each(|di| { - tmp_a.set_size((ci_dft.size() + di) / digits); - - // Small optimization for digits > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. - // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last digits-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. 
- tmp_dft_i.set_size(tsk.size() - ((digits - di) as isize - 2).max(0) as usize); - - module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, ci_dft, col_i); - if di == 0 && col_i == 1 { - module.vmp_apply(&mut tmp_dft_i, &tmp_a, pmat, scratch2); - } else { - module.vmp_apply_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch2); - } - }); - }); - } - - // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i - // - // (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2) - // + - // (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0) - // = - // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2) - // = - // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) - module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, ci_dft, 0); - let (mut tmp_idft, scratch2) = scratch1.tmp_vec_znx_big(module, 1, tsk.size()); - (0..cols).for_each(|i| { - module.vec_znx_idft_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i); - module.vec_znx_big_normalize(self.basek(), res, i, &tmp_idft, 0, scratch2); - }); - } - - pub fn keyswitch, DataKsk: AsRef<[u8]>, DataTsk: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GGSWCiphertext, - ksk: &GLWESwitchingKey, - tsk: &GLWETensorKey, - scratch: &mut Scratch, - ) { - let rank: usize = self.rank(); - let cols: usize = rank + 1; - let basek: usize = self.basek(); - - let (mut tmp_res, scratch1) = scratch.tmp_glwe_ct(module, basek, self.k(), rank); - let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size()); - - // Keyswitch the j-th row of the col 0 - (0..lhs.rows()).for_each(|row_i| { - // Key-switch column 0, i.e. - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) - lhs.keyswitch_internal_col0(module, row_i, &mut tmp_res, ksk, scratch2); - - // Isolates DFT(a[i]) - (0..cols).for_each(|col_i| { - module.vec_znx_dft(1, 0, &mut ci_dft, col_i, &tmp_res.data, col_i); - }); - - module.mat_znx_dft_set_row(&mut self.data, row_i, 0, &ci_dft); - - // Generates - // - // col 1: (-(b0s0' + b1s1' + b2s2') , b0 + M[i], b1 , b2 ) - // col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M[i], c2 ) - // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i]) - (1..cols).for_each(|col_j| { - self.expand_row(module, col_j, &mut tmp_res.data, &ci_dft, tsk, scratch2); - let (mut tmp_res_dft, _) = scratch2.tmp_fourier_glwe_ct(module, basek, self.k(), rank); - tmp_res.dft(module, &mut tmp_res_dft); - self.set_row(module, row_i, col_j, &tmp_res_dft); - }); - }) - } - - pub fn keyswitch_inplace, DataTsk: AsRef<[u8]>>( - &mut self, - module: &Module, - ksk: &GLWESwitchingKey, - tsk: &GLWETensorKey, - scratch: &mut Scratch, - ) { - unsafe { - let self_ptr: *mut GGSWCiphertext = self as *mut GGSWCiphertext; - self.keyswitch(module, &*self_ptr, ksk, tsk, scratch); - } - } - - pub fn automorphism, DataAk: AsRef<[u8]>, DataTsk: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GGSWCiphertext, - auto_key: &GLWEAutomorphismKey, - tensor_key: &GLWETensorKey, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!( - self.rank(), - lhs.rank(), - "ggsw_out rank: {} != ggsw_in rank: {}", - self.rank(), - lhs.rank() - ); - assert_eq!( - self.rank(), - auto_key.rank(), - "ggsw_in rank: {} != auto_key rank: {}", - self.rank(), - auto_key.rank() - ); - assert_eq!( - self.rank(), - tensor_key.rank(), - "ggsw_in rank: {} != tensor_key rank: {}", - self.rank(), - tensor_key.rank() - ); - assert!( - scratch.available() - >= GGSWCiphertext::automorphism_scratch_space( - module, - 
self.basek(), - self.k(), - lhs.k(), - auto_key.k(), - auto_key.digits(), - tensor_key.k(), - tensor_key.digits(), - self.rank(), - ) - ) - }; - - let rank: usize = self.rank(); - let cols: usize = rank + 1; - let basek: usize = self.basek(); - - let (mut tmp_res, scratch1) = scratch.tmp_glwe_ct(module, basek, self.k(), rank); - let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size()); - - // Keyswitch the j-th row of the col 0 - (0..lhs.rows()).for_each(|row_i| { - // Key-switch column 0, i.e. - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) - lhs.keyswitch_internal_col0(module, row_i, &mut tmp_res, &auto_key.key, scratch2); - - // Isolates DFT(AUTO(a[i])) - (0..cols).for_each(|col_i| { - // (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) -> (-(a0s0 + a1s1 + a2s2) + pi(M[i]), a0, a1, a2) - module.vec_znx_automorphism_inplace(auto_key.p(), &mut tmp_res.data, col_i); - module.vec_znx_dft(1, 0, &mut ci_dft, col_i, &tmp_res.data, col_i); - }); - - module.mat_znx_dft_set_row(&mut self.data, row_i, 0, &ci_dft); - - // Generates - // - // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + pi(M[i]), b1 , b2 ) - // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + pi(M[i]), c2 ) - // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + pi(M[i])) - (1..cols).for_each(|col_j| { - self.expand_row( - module, - col_j, - &mut tmp_res.data, - &ci_dft, - tensor_key, - scratch2, - ); - let (mut tmp_res_dft, _) = scratch2.tmp_fourier_glwe_ct(module, basek, self.k(), rank); - tmp_res.dft(module, &mut tmp_res_dft); - self.set_row(module, row_i, col_j, &tmp_res_dft); - }); - }) - } - - pub fn automorphism_inplace, DataTsk: AsRef<[u8]>>( - &mut self, - module: &Module, - auto_key: &GLWEAutomorphismKey, - tensor_key: &GLWETensorKey, - scratch: &mut Scratch, - ) { - unsafe { - let self_ptr: *mut GGSWCiphertext = self as *mut GGSWCiphertext; - self.automorphism(module, &*self_ptr, auto_key, tensor_key, scratch); - } - } - - pub fn external_product, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GGSWCiphertext, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!( - self.rank(), - lhs.rank(), - "ggsw_out rank: {} != ggsw_in rank: {}", - self.rank(), - lhs.rank() - ); - assert_eq!( - self.rank(), - rhs.rank(), - "ggsw_in rank: {} != ggsw_apply rank: {}", - self.rank(), - rhs.rank() - ); - - assert!( - scratch.available() - >= GGSWCiphertext::external_product_scratch_space( - module, - self.basek(), - self.k(), - lhs.k(), - rhs.k(), - rhs.digits(), - rhs.rank() - ) - ) - } - - let (mut tmp_ct_in, scratch1) = scratch.tmp_fourier_glwe_ct(module, lhs.basek(), lhs.k(), lhs.rank()); - let (mut tmp_ct_out, scratch2) = scratch1.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); - - (0..self.rank() + 1).for_each(|col_i| { - (0..self.rows()).for_each(|row_j| { - lhs.get_row(module, row_j, col_i, &mut tmp_ct_in); - tmp_ct_out.external_product(module, &tmp_ct_in, rhs, scratch2); - self.set_row(module, row_j, col_i, &tmp_ct_out); - }); - }); - - tmp_ct_out.data.zero(); - - (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { - (0..self.rank() + 1).for_each(|col_j| { - self.set_row(module, row_i, col_j, &tmp_ct_out); - }); - }); - } - - pub fn external_product_inplace>( - &mut self, - module: &Module, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!( - self.rank(), - rhs.rank(), - "ggsw_out rank: {} != 
ggsw_apply: {}", - self.rank(), - rhs.rank() - ); - } - - let (mut tmp_ct, scratch1) = scratch.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); - - (0..self.rank() + 1).for_each(|col_i| { - (0..self.rows()).for_each(|row_j| { - self.get_row(module, row_j, col_i, &mut tmp_ct); - tmp_ct.external_product_inplace(module, rhs, scratch1); - self.set_row(module, row_j, col_i, &tmp_ct); - }); - }); - } -} - -impl> GGSWCiphertext { - pub(crate) fn keyswitch_internal_col0 + AsRef<[u8]>, DataKsk: AsRef<[u8]>>( - &self, - module: &Module, - row_i: usize, - res: &mut GLWECiphertext, - ksk: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), ksk.rank()); - assert_eq!(res.rank(), ksk.rank()); - assert!( - scratch.available() - >= GGSWCiphertext::keyswitch_internal_col0_scratch_space( - module, - self.basek(), - res.k(), - self.k(), - ksk.k(), - ksk.digits(), - ksk.rank() - ) - ) - } - let (mut tmp_dft_dft, scratch1) = scratch.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); - self.get_row(module, row_i, 0, &mut tmp_dft_dft); - res.keyswitch_from_fourier(module, &tmp_dft_dft, ksk, scratch1); - } -} - -impl> GetRow for GGSWCiphertext { - fn get_row + AsRef<[u8]>>( - &self, - module: &Module, - row_i: usize, - col_j: usize, - res: &mut FourierGLWECiphertext, - ) { - module.mat_znx_dft_get_row(&mut res.data, &self.data, row_i, col_j); - } -} - -impl + AsRef<[u8]>> SetRow for GGSWCiphertext { - fn set_row>( - &mut self, - module: &Module, - row_i: usize, - col_j: usize, - a: &FourierGLWECiphertext, - ) { - module.mat_znx_dft_set_row(&mut self.data, row_i, col_j, &a.data); - } -} diff --git a/core/src/ggsw/encryption.rs b/core/src/ggsw/encryption.rs new file mode 100644 index 0000000..b2c019a --- /dev/null +++ b/core/src/ggsw/encryption.rs @@ -0,0 +1,79 @@ +use backend::hal::{ + api::{ + ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAddScalarInplace, VecZnxAllocBytes, VecZnxNormalizeInplace, ZnxZero, + }, + layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch}, +}; +use sampling::source::Source; + +use crate::{GGSWCiphertext, GLWECiphertext, GLWEEncryptSkFamily, GLWESecretExec, Infos, TakeGLWEPt}; + +pub trait GGSWEncryptSkFamily = GLWEEncryptSkFamily; + +impl GGSWCiphertext> { + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize + where + Module: GGSWEncryptSkFamily + VecZnxAllocBytes, + { + let size = k.div_ceil(basek); + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) + + module.vec_znx_alloc_bytes(rank + 1, size) + + module.vec_znx_alloc_bytes(1, size) + + module.vec_znx_dft_alloc_bytes(rank + 1, size) + } +} + +impl GGSWCiphertext { + pub fn encrypt_sk( + &mut self, + module: &Module, + pt: &ScalarZnx, + sk: &GLWESecretExec, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + Module: GGSWEncryptSkFamily + VecZnxAddScalarInplace, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + { + #[cfg(debug_assertions)] + { + use backend::hal::api::ZnxInfos; + + assert_eq!(self.rank(), sk.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(pt.n(), module.n()); + assert_eq!(sk.n(), module.n()); + } + + let basek: usize = self.basek(); + let k: usize = self.k(); + let rank: usize = self.rank(); + let digits: usize = self.digits(); + + let (mut tmp_pt, scratch1) = scratch.take_glwe_pt(module, basek, k); + + (0..self.rows()).for_each(|row_i| { + tmp_pt.data.zero(); + + // Adds the scalar_znx_pt to 
the i-th limb of the vec_znx_pt + module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (digits - 1) + row_i * digits, pt, 0); + module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch1); + + (0..rank + 1).for_each(|col_j| { + // rlwe encrypt of vec_znx_pt into vec_znx_ct + + self.at_mut(row_i, col_j).encrypt_sk_private( + module, + Some((&tmp_pt, col_j)), + sk, + source_xa, + source_xe, + sigma, + scratch1, + ); + }); + }); + } +} diff --git a/core/src/ggsw/external_product.rs b/core/src/ggsw/external_product.rs new file mode 100644 index 0000000..c279ac3 --- /dev/null +++ b/core/src/ggsw/external_product.rs @@ -0,0 +1,123 @@ +use backend::hal::{ + api::{ScratchAvailable, TakeVecZnxDft, ZnxZero}, + layouts::{Backend, DataMut, DataRef, Module, Scratch}, +}; + +use crate::{GGSWCiphertext, GGSWCiphertextExec, GLWECiphertext, GLWEExternalProductFamily, Infos}; + +impl GGSWCiphertext> { + pub fn external_product_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ggsw: usize, + digits: usize, + rank: usize, + ) -> usize + where + Module: GLWEExternalProductFamily, + { + GLWECiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank) + } + + pub fn external_product_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_ggsw: usize, + digits: usize, + rank: usize, + ) -> usize + where + Module: GLWEExternalProductFamily, + { + GLWECiphertext::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank) + } +} + +impl GGSWCiphertext { + pub fn external_product( + &mut self, + module: &Module, + lhs: &GGSWCiphertext, + rhs: &GGSWCiphertextExec, + scratch: &mut Scratch, + ) where + Module: GLWEExternalProductFamily, + Scratch: ScratchAvailable + TakeVecZnxDft, + { + #[cfg(debug_assertions)] + { + use crate::{GGSWCiphertext, Infos}; + + assert_eq!( + self.rank(), + lhs.rank(), + "ggsw_out rank: {} != ggsw_in rank: {}", + self.rank(), + lhs.rank() + ); + assert_eq!( + self.rank(), + rhs.rank(), + "ggsw_in rank: {} != ggsw_apply rank: {}", + self.rank(), + rhs.rank() + ); + + assert!( + scratch.available() + >= GGSWCiphertext::external_product_scratch_space( + module, + self.basek(), + self.k(), + lhs.k(), + rhs.k(), + rhs.digits(), + rhs.rank() + ) + ) + } + + let min_rows: usize = self.rows().min(lhs.rows()); + + (0..self.rank() + 1).for_each(|col_i| { + (0..min_rows).for_each(|row_j| { + self.at_mut(row_j, col_i) + .external_product(module, &lhs.at(row_j, col_i), rhs, scratch); + }); + (min_rows..self.rows()).for_each(|row_i| { + self.at_mut(row_i, col_i).data.zero(); + }); + }); + } + + pub fn external_product_inplace( + &mut self, + module: &Module, + rhs: &GGSWCiphertextExec, + scratch: &mut Scratch, + ) where + Module: GLWEExternalProductFamily, + Scratch: TakeVecZnxDft + ScratchAvailable, + { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank(), + rhs.rank(), + "ggsw_out rank: {} != ggsw_apply: {}", + self.rank(), + rhs.rank() + ); + } + + (0..self.rank() + 1).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + self.at_mut(row_j, col_i) + .external_product_inplace(module, rhs, scratch); + }); + }); + } +} diff --git a/core/src/ggsw/keyswitch.rs b/core/src/ggsw/keyswitch.rs new file mode 100644 index 0000000..ac84e28 --- /dev/null +++ b/core/src/ggsw/keyswitch.rs @@ -0,0 +1,253 @@ +use backend::hal::{ + api::{ + ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxAllocBytes, VecZnxBigAllocBytes, VecZnxDftAddInplace, + VecZnxDftCopy, 
VecZnxDftToVecZnxBigTmpA, VecZnxNormalizeTmpBytes, ZnxInfos,
+    },
+    layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnxDft, VmpPMat},
+};
+
+use crate::{GGSWCiphertext, GLWECiphertext, GLWEKeyswitchFamily, GLWESwitchingKeyExec, GLWETensorKeyExec, Infos};
+
+pub trait GGSWKeySwitchFamily =
+    GLWEKeyswitchFamily + VecZnxBigAllocBytes + VecZnxDftCopy + VecZnxDftAddInplace + VecZnxDftToVecZnxBigTmpA;
+
+impl GGSWCiphertext> {
+    pub(crate) fn expand_row_scratch_space(
+        module: &Module,
+        basek: usize,
+        self_k: usize,
+        k_tsk: usize,
+        digits: usize,
+        rank: usize,
+    ) -> usize
+    where
+        Module: GGSWKeySwitchFamily + VecZnxNormalizeTmpBytes,
+    {
+        let tsk_size: usize = k_tsk.div_ceil(basek);
+        let self_size_out: usize = self_k.div_ceil(basek);
+        let self_size_in: usize = self_size_out.div_ceil(digits);
+
+        let tmp_dft_i: usize = module.vec_znx_dft_alloc_bytes(rank + 1, tsk_size);
+        let tmp_a: usize = module.vec_znx_dft_alloc_bytes(1, self_size_in);
+        let vmp: usize = module.vmp_apply_tmp_bytes(
+            self_size_out,
+            self_size_in,
+            self_size_in,
+            rank,
+            rank,
+            tsk_size,
+        );
+        let tmp_idft: usize = module.vec_znx_big_alloc_bytes(1, tsk_size);
+        let norm: usize = module.vec_znx_normalize_tmp_bytes(module.n());
+        tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm))
+    }
+
+    pub fn keyswitch_scratch_space(
+        module: &Module,
+        basek: usize,
+        k_out: usize,
+        k_in: usize,
+        k_ksk: usize,
+        digits_ksk: usize,
+        k_tsk: usize,
+        digits_tsk: usize,
+        rank: usize,
+    ) -> usize
+    where
+        Module: GLWEKeyswitchFamily + GGSWKeySwitchFamily + VecZnxAllocBytes + VecZnxNormalizeTmpBytes,
+    {
+        let out_size: usize = k_out.div_ceil(basek);
+        let res_znx: usize = module.vec_znx_alloc_bytes(rank + 1, out_size);
+        let ci_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, out_size);
+        let ks: usize = GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits_ksk, rank, rank);
+        let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, k_out, k_tsk, digits_tsk, rank);
+        let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, out_size);
+        res_znx + ci_dft + (ks | expand_rows | res_dft)
+    }
+
+    pub fn keyswitch_inplace_scratch_space(
+        module: &Module,
+        basek: usize,
+        k_out: usize,
+        k_ksk: usize,
+        digits_ksk: usize,
+        k_tsk: usize,
+        digits_tsk: usize,
+        rank: usize,
+    ) -> usize
+    where
+        Module: GLWEKeyswitchFamily + GGSWKeySwitchFamily + VecZnxAllocBytes + VecZnxNormalizeTmpBytes,
+    {
+        GGSWCiphertext::keyswitch_scratch_space(
+            module, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank,
+        )
+    }
+}
+
+impl GGSWCiphertext {
+    pub(crate) fn expand_row(
+        &mut self,
+        module: &Module,
+        row_i: usize,
+        col_j: usize,
+        ci_dft: &VecZnxDft,
+        tsk: &GLWETensorKeyExec,
+        scratch: &mut Scratch,
+    ) where
+        Module: GGSWKeySwitchFamily + VecZnxNormalizeTmpBytes,
+        Scratch: TakeVecZnxDft + TakeVecZnxBig + ScratchAvailable,
+    {
+        let cols: usize = self.rank() + 1;
+
+        assert!(
+            scratch.available()
+                >= GGSWCiphertext::expand_row_scratch_space(
+                    module,
+                    self.basek(),
+                    self.k(),
+                    tsk.k(),
+                    tsk.digits(),
+                    tsk.rank()
+                )
+        );
+
+        // Example for rank 3:
+        //
+        // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is
+        // actually composed of that many rows; here we focus on a specific row,
+        // given implicitly by ci_dft.
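+        //
+        // As a concrete sketch of that gadget vector (following `encrypt_sk`
+        // in encryption.rs): row i of the GGSW places the scalar message m in
+        // limb (digits - 1) + i * digits of the plaintext, which is what
+        // realizes the successive powers-of-B scalings (m, Bm, B^2m, ...)
+        // described above.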
+        //
+        // # Input
+        //
+        // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0       , a1       , a2       )
+        // col 1: (0, 0, 0, 0)
+        // col 2: (0, 0, 0, 0)
+        // col 3: (0, 0, 0, 0)
+        //
+        // # Output
+        //
+        // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0       , a1       , a2       )
+        // col 1: (-(b0s0 + b1s1 + b2s2)       , b0 + M[i], b1       , b2       )
+        // col 2: (-(c0s0 + c1s1 + c2s2)       , c0       , c1 + M[i], c2       )
+        // col 3: (-(d0s0 + d1s1 + d2s2)       , d0       , d1       , d2 + M[i])
+
+        let digits: usize = tsk.digits();
+
+        let (mut tmp_dft_i, scratch1) = scratch.take_vec_znx_dft(module, cols, tsk.size());
+        let (mut tmp_a, scratch2) = scratch1.take_vec_znx_dft(module, 1, ci_dft.size().div_ceil(digits));
+
+        {
+            // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2
+            //
+            // # Example for col=1
+            //
+            // a0 * (-(f0s0 + f1s1 + f2s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f2s2) + a0s0^2, a0f0, a0f1, a0f2)
+            // +
+            // a1 * (-(g0s0 + g1s1 + g2s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g2s2) + a1s0s1, a1g0, a1g1, a1g2)
+            // +
+            // a2 * (-(h0s0 + h1s1 + h2s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h2s2) + a2s0s2, a2h0, a2h1, a2h2)
+            // =
+            // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2)
+            (1..cols).for_each(|col_i| {
+                let pmat: &VmpPMat = &tsk.at(col_i - 1, col_j - 1).key.data; // Selects Enc(s[i]s[j])
+
+                // Extracts a[i] and multiplies it with Enc(s[i]s[j])
+                (0..digits).for_each(|di| {
+                    tmp_a.set_size((ci_dft.size() + di) / digits);
+
+                    // Small optimization for digits > 2: the VMP produces some
+                    // error e, and since we aggregate vmp * 2^{di * B}, we also
+                    // aggregate ei * 2^{di * B}, with the largest error being
+                    // ei * 2^{(digits-1) * B}. As such we can safely ignore the
+                    // last digits-2 limbs of the sum of VMP products. It is
+                    // possible to further ignore the last digits-1 limbs, but
+                    // this introduces ~0.5 to 1 bit of additional noise, so it
+                    // is not done here, keeping the noise identical to the
+                    // ideal functionality.
+                    tmp_dft_i.set_size(tsk.size() - ((digits - di) as isize - 2).max(0) as usize);
+
+                    module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, ci_dft, col_i);
+                    if di == 0 && col_i == 1 {
+                        module.vmp_apply(&mut tmp_dft_i, &tmp_a, pmat, scratch2);
+                    } else {
+                        module.vmp_apply_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch2);
+                    }
+                });
+            });
+        }
+
+        // Adds (-(sum a[i] * s[i]) + m) onto the col_j-th column of tmp_dft_i
+        //
+        // (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2)
+        // +
+        // (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0)
+        // =
+        // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2)
+        // =
+        // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2)
+        module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, ci_dft, 0);
+        let (mut tmp_idft, scratch2) = scratch1.take_vec_znx_big(module, 1, tsk.size());
+        (0..cols).for_each(|i| {
+            module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i);
+            module.vec_znx_big_normalize(
+                self.basek(),
+                &mut self.at_mut(row_i, col_j).data,
+                i,
+                &tmp_idft,
+                0,
+                scratch2,
+            );
+        });
+    }
+
+    pub fn keyswitch(
+        &mut self,
+        module: &Module,
+        lhs: &GGSWCiphertext,
+        ksk: &GLWESwitchingKeyExec,
+        tsk: &GLWETensorKeyExec,
+        scratch: &mut Scratch,
+    ) where
+        Module: GLWEKeyswitchFamily + GGSWKeySwitchFamily + VecZnxNormalizeTmpBytes,
+        Scratch: TakeVecZnxDft + TakeVecZnxBig + ScratchAvailable,
+    {
+        let rank: usize = self.rank();
+        let cols: usize = rank + 1;
+
+        // Key-switch each row of column 0
+        (0..lhs.rows()).for_each(|row_i| {
+            // Key-switch column 0, i.e.
+ // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) + self.at_mut(row_i, 0) + .keyswitch(module, &lhs.at(row_i, 0), ksk, scratch); + + // Pre-compute DFT of (a0, a1, a2) + let (mut ci_dft, scratch1) = scratch.take_vec_znx_dft(module, cols, self.size()); + (0..cols).for_each(|i| { + module.vec_znx_dft_from_vec_znx(1, 0, &mut ci_dft, i, &self.at(row_i, 0).data, i); + }); + // Generates + // + // col 1: (-(b0s0' + b1s1' + b2s2') , b0 + M[i], b1 , b2 ) + // col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M[i], c2 ) + // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i]) + (1..cols).for_each(|col_j| { + self.expand_row(module, row_i, col_j, &ci_dft, tsk, scratch1); + }); + }) + } + + pub fn keyswitch_inplace( + &mut self, + module: &Module, + ksk: &GLWESwitchingKeyExec, + tsk: &GLWETensorKeyExec, + scratch: &mut Scratch, + ) where + Module: GLWEKeyswitchFamily + GGSWKeySwitchFamily + VecZnxNormalizeTmpBytes, + Scratch: TakeVecZnxDft + TakeVecZnxBig + ScratchAvailable, + { + unsafe { + let self_ptr: *mut GGSWCiphertext = self as *mut GGSWCiphertext; + self.keyswitch(module, &*self_ptr, ksk, tsk, scratch); + } + } +} diff --git a/core/src/ggsw/layout.rs b/core/src/ggsw/layout.rs new file mode 100644 index 0000000..d8713b5 --- /dev/null +++ b/core/src/ggsw/layout.rs @@ -0,0 +1,259 @@ +use backend::hal::{ + api::{MatZnxAlloc, MatZnxAllocBytes, VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatPrepare}, + layouts::{Backend, Data, DataMut, DataRef, MatZnx, Module, ReaderFrom, Scratch, VmpPMat, WriterTo}, +}; + +use crate::{GLWECiphertext, Infos}; + +pub trait GGSWLayoutFamily = VmpPMatAlloc + VmpPMatAllocBytes + VmpPMatPrepare; + +#[derive(PartialEq, Eq)] +pub struct GGSWCiphertext { + pub(crate) data: MatZnx, + pub(crate) basek: usize, + pub(crate) k: usize, + pub(crate) digits: usize, +} + +impl GGSWCiphertext { + pub fn at(&self, row: usize, col: usize) -> GLWECiphertext<&[u8]> { + GLWECiphertext { + data: self.data.at(row, col), + basek: self.basek, + k: self.k, + } + } +} + +impl GGSWCiphertext { + pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertext<&mut [u8]> { + GLWECiphertext { + data: self.data.at_mut(row, col), + basek: self.basek, + k: self.k, + } + } +} + +impl GGSWCiphertext> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self + where + Module: MatZnxAlloc, + { + let size: usize = k.div_ceil(basek); + debug_assert!(digits > 0, "invalid ggsw: `digits` == 0"); + + debug_assert!( + size > digits, + "invalid ggsw: ceil(k/basek): {} <= digits: {}", + size, + digits + ); + + assert!( + rows * digits <= size, + "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}", + rows, + digits, + size + ); + + Self { + data: module.mat_znx_alloc(rows, rank + 1, rank + 1, k.div_ceil(basek)), + basek, + k: k, + digits, + } + } + + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize + where + Module: MatZnxAllocBytes, + { + let size: usize = k.div_ceil(basek); + debug_assert!( + size > digits, + "invalid ggsw: ceil(k/basek): {} <= digits: {}", + size, + digits + ); + + assert!( + rows * digits <= size, + "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}", + rows, + digits, + size + ); + + module.mat_znx_alloc_bytes(rows, rank + 1, rank + 1, size) + } +} + +impl Infos for GGSWCiphertext { + type Inner = MatZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> 
usize { + self.k + } +} + +impl GGSWCiphertext { + pub fn rank(&self) -> usize { + self.data.cols_out() - 1 + } + + pub fn digits(&self) -> usize { + self.digits + } +} + +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +impl ReaderFrom for GGSWCiphertext { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.k = reader.read_u64::()? as usize; + self.basek = reader.read_u64::()? as usize; + self.digits = reader.read_u64::()? as usize; + self.data.read_from(reader) + } +} + +impl WriterTo for GGSWCiphertext { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u64::(self.k as u64)?; + writer.write_u64::(self.basek as u64)?; + writer.write_u64::(self.digits as u64)?; + self.data.write_to(writer) + } +} + +#[derive(PartialEq, Eq)] +pub struct GGSWCiphertextExec { + pub(crate) data: VmpPMat, + pub(crate) basek: usize, + pub(crate) k: usize, + pub(crate) digits: usize, +} + +impl GGSWCiphertextExec, B> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self + where + Module: GGSWLayoutFamily, + { + let size: usize = k.div_ceil(basek); + debug_assert!(digits > 0, "invalid ggsw: `digits` == 0"); + + debug_assert!( + size > digits, + "invalid ggsw: ceil(k/basek): {} <= digits: {}", + size, + digits + ); + + assert!( + rows * digits <= size, + "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}", + rows, + digits, + size + ); + + Self { + data: module.vmp_pmat_alloc(rows, rank + 1, rank + 1, k.div_ceil(basek)), + basek, + k: k, + digits, + } + } + + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize + where + Module: GGSWLayoutFamily, + { + let size: usize = k.div_ceil(basek); + debug_assert!( + size > digits, + "invalid ggsw: ceil(k/basek): {} <= digits: {}", + size, + digits + ); + + assert!( + rows * digits <= size, + "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}", + rows, + digits, + size + ); + + module.vmp_pmat_alloc_bytes(rows, rank + 1, rank + 1, size) + } + + pub fn from( + module: &Module, + other: &GGSWCiphertext, + scratch: &mut Scratch, + ) -> GGSWCiphertextExec, B> + where + Module: GGSWLayoutFamily, + { + let mut ggsw_exec: GGSWCiphertextExec, B> = Self::alloc( + module, + other.basek(), + other.k(), + other.rows(), + other.digits(), + other.rank(), + ); + ggsw_exec.prepare(module, other, scratch); + ggsw_exec + } +} + +impl Infos for GGSWCiphertextExec { + type Inner = VmpPMat; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl GGSWCiphertextExec { + pub fn rank(&self) -> usize { + self.data.cols_out() - 1 + } + + pub fn digits(&self) -> usize { + self.digits + } +} + +impl GGSWCiphertextExec { + pub fn prepare(&mut self, module: &Module, other: &GGSWCiphertext, scratch: &mut Scratch) + where + DataOther: DataRef, + Module: GGSWLayoutFamily, + { + module.vmp_prepare(&mut self.data, &other.data, scratch); + self.k = other.k; + self.basek = other.basek; + self.digits = other.digits; + } +} diff --git a/core/src/ggsw/mod.rs b/core/src/ggsw/mod.rs index f27b96b..a5883f2 100644 --- a/core/src/ggsw/mod.rs +++ b/core/src/ggsw/mod.rs @@ -1,6 +1,14 @@ -pub mod ciphertext; +mod automorphism; +mod encryption; +mod external_product; +mod keyswitch; +mod layout; +mod noise; -pub use ciphertext::GGSWCiphertext; +pub use encryption::GGSWEncryptSkFamily; +pub use keyswitch::GGSWKeySwitchFamily; +pub use 
layout::{GGSWCiphertext, GGSWCiphertextExec, GGSWLayoutFamily}; +pub use noise::GGSWAssertNoiseFamily; #[cfg(test)] -mod test_fft64; +mod test; diff --git a/core/src/ggsw/noise.rs b/core/src/ggsw/noise.rs new file mode 100644 index 0000000..e98636c --- /dev/null +++ b/core/src/ggsw/noise.rs @@ -0,0 +1,73 @@ +use backend::hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VecZnxDftToVecZnxBigTmpA, VecZnxNormalizeTmpBytes, VecZnxStd, + VecZnxSubABInplace, ZnxZero, + }, + layouts::{Backend, DataRef, Module, ScalarZnx, ScratchOwned, VecZnxBig, VecZnxDft}, + oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl}, +}; + +use crate::{GGSWCiphertext, GLWECiphertext, GLWEDecryptFamily, GLWEPlaintext, GLWESecretExec, Infos}; + +pub trait GGSWAssertNoiseFamily = GLWEDecryptFamily + + VecZnxBigAlloc + + VecZnxDftAlloc + + VecZnxBigNormalizeTmpBytes + + VecZnxBigNormalize + + VecZnxDftToVecZnxBigTmpA; + +impl GGSWCiphertext { + pub fn assert_noise( + &self, + module: &Module, + sk_exec: &GLWESecretExec, + pt_want: &ScalarZnx, + max_noise: F, + ) where + DataSk: DataRef, + DataScalar: DataRef, + Module: GGSWAssertNoiseFamily + VecZnxAlloc + VecZnxAddScalarInplace + VecZnxSubABInplace + VecZnxStd, + B: TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, + F: Fn(usize) -> f64, + { + let basek: usize = self.basek(); + let k: usize = self.k(); + let digits: usize = self.digits(); + + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k); + let mut pt_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(1, self.size()); + let mut pt_big: VecZnxBig, B> = module.vec_znx_big_alloc(1, self.size()); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWECiphertext::decrypt_scratch_space(module, basek, k) | module.vec_znx_normalize_tmp_bytes(module.n()), + ); + + (0..self.rank() + 1).for_each(|col_j| { + (0..self.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt.data, 0, (digits - 1) + row_i * digits, pt_want, 0); + + // mul with sk[col_j-1] + if col_j > 0 { + module.vec_znx_dft_from_vec_znx(1, 0, &mut pt_dft, 0, &pt.data, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_exec.data, col_j - 1); + module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(basek, &mut pt.data, 0, &pt_big, 0, scratch.borrow()); + } + + self.at(row_i, col_j) + .decrypt(module, &mut pt_have, &sk_exec, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt.data, 0); + + let std_pt: f64 = module.vec_znx_std(basek, &pt_have.data, 0).log2(); + let noise: f64 = max_noise(col_j); + println!("{} {}", std_pt, noise); + assert!(std_pt <= noise, "{} > {}", std_pt, noise); + + pt.data.zero(); + }); + }); + } +} diff --git a/core/src/ggsw/test/generic_tests.rs b/core/src/ggsw/test/generic_tests.rs new file mode 100644 index 0000000..a03db42 --- /dev/null +++ b/core/src/ggsw/test/generic_tests.rs @@ -0,0 +1,724 @@ +use backend::hal::{ + api::{ + MatZnxAlloc, ScalarZnxAlloc, ScalarZnxAllocBytes, ScalarZnxAutomorphism, ScalarZnxAutomorphismInplace, ScratchOwnedAlloc, + ScratchOwnedBorrow, VecZnxAddScalarInplace, VecZnxAlloc, VecZnxAllocBytes, VecZnxAutomorphismInplace, + VecZnxRotateInplace, VecZnxStd, VecZnxSubABInplace, VecZnxSwithcDegree, ZnxViewMut, + }, + layouts::{Backend, 
Module, ScalarZnx, ScalarZnxToMut, ScratchOwned}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, + }, +}; +use sampling::source::Source; + +use crate::{ + AutomorphismKey, AutomorphismKeyExec, GGLWEExecLayoutFamily, GGSWAssertNoiseFamily, GGSWCiphertext, GGSWCiphertextExec, + GGSWEncryptSkFamily, GGSWKeySwitchFamily, GLWESecret, GLWESecretExec, GLWESecretFamily, GLWESwitchingKey, + GLWESwitchingKeyEncryptSkFamily, GLWESwitchingKeyExec, GLWETensorKey, GLWETensorKeyEncryptSkFamily, GLWETensorKeyExec, + noise::{noise_ggsw_keyswitch, noise_ggsw_product}, +}; + +pub(crate) trait TestModuleFamily = GLWESecretFamily + + GGSWEncryptSkFamily + + GGSWAssertNoiseFamily + + VecZnxAlloc + + ScalarZnxAlloc + + VecZnxAllocBytes + + MatZnxAlloc + + VecZnxAddScalarInplace + + VecZnxSubABInplace + + VecZnxStd + + ScalarZnxAllocBytes; +pub(crate) trait TestScratchFamily = TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl + + VecZnxDftAllocBytesImpl + + VecZnxBigAllocBytesImpl + + TakeSvpPPolImpl; + +pub(crate) fn test_encrypt_sk(module: &Module, basek: usize, k: usize, digits: usize, rank: usize, sigma: f64) +where + Module: TestModuleFamily, + B: TestScratchFamily, +{ + let rows: usize = (k - digits * basek) / (digits * basek); + + let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k, rows, digits, rank); + + let mut pt_scalar: ScalarZnx> = module.scalar_znx_alloc(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSWCiphertext::encrypt_sk_scratch_space( + module, basek, k, rank, + )); + + let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); + sk_exec.prepare(module, &sk); + + ct.encrypt_sk( + module, + &pt_scalar, + &sk_exec, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let noise_f = |_col_i: usize| -(k as f64) + sigma.log2() + 0.5; + + ct.assert_noise(module, &sk_exec, &pt_scalar, &noise_f); +} + +pub(crate) fn test_keyswitch( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + k_tsk: usize, + digits: usize, + rank: usize, + sigma: f64, +) where + Module: TestModuleFamily + + GGSWAssertNoiseFamily + + GGSWKeySwitchFamily + + GLWESwitchingKeyEncryptSkFamily + + GLWETensorKeyEncryptSkFamily + + GGLWEExecLayoutFamily + + VecZnxSwithcDegree, + B: TestScratchFamily + VecZnxDftAllocBytesImpl + VecZnxBigAllocBytesImpl + TakeSvpPPolImpl, +{ + let rows: usize = k_in.div_ceil(digits * basek); + + let digits_in: usize = 1; + + let mut ct_in: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_in, rows, digits_in, rank); + let mut ct_out: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_out, rows, digits_in, rank); + let mut tsk: GLWETensorKey> = GLWETensorKey::alloc(module, basek, k_ksk, rows, digits, rank); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc(module, basek, k_ksk, rows, digits, rank, rank); + let mut pt_scalar: ScalarZnx> = module.scalar_znx_alloc(1); 
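+
+    // Sanity check on the row count (illustrative values, as used by the
+    // FFT64 tests below): with basek = 12, k_in = 54 and digits = 2, rows =
+    // ceil(54 / 24) = 3, and since k_ksk = k_in + basek * digits = 78, we get
+    // rows * digits = 6 <= ceil(78 / 12) = 7, satisfying the
+    // `rows * digits <= size` assertion in `alloc`.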
+ + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_in, rank) + | GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank, rank) + | GLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) + | GGSWCiphertext::keyswitch_scratch_space( + module, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, + ), + ); + + let var_xs: f64 = 0.5; + + let mut sk_in: GLWESecret> = GLWESecret::alloc(module, rank); + sk_in.fill_ternary_prob(var_xs, &mut source_xs); + let sk_in_dft: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_in); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(module, rank); + sk_out.fill_ternary_prob(var_xs, &mut source_xs); + let sk_out_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_out); + + ksk.encrypt_sk( + module, + &sk_in, + &sk_out, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + tsk.encrypt_sk( + module, + &sk_out, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + ct_in.encrypt_sk( + module, + &pt_scalar, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut ksk_exec: GLWESwitchingKeyExec, B> = + GLWESwitchingKeyExec::alloc(module, basek, k_ksk, rows, digits, rank, rank); + let mut tsk_exec: GLWETensorKeyExec, B> = GLWETensorKeyExec::alloc(module, basek, k_ksk, rows, digits, rank); + + ksk_exec.prepare(module, &ksk, scratch.borrow()); + tsk_exec.prepare(module, &tsk, scratch.borrow()); + + ct_out.keyswitch(module, &ct_in, &ksk_exec, &tsk_exec, scratch.borrow()); + + let max_noise = |col_j: usize| -> f64 { + noise_ggsw_keyswitch( + module.n() as f64, + basek * digits, + col_j, + var_xs, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_in, + k_ksk, + k_tsk, + ) + 0.5 + }; + + ct_out.assert_noise(module, &sk_out_exec, &pt_scalar, &max_noise); +} + +pub(crate) fn test_keyswitch_inplace( + module: &Module, + basek: usize, + k_ct: usize, + k_ksk: usize, + k_tsk: usize, + digits: usize, + rank: usize, + sigma: f64, +) where + Module: TestModuleFamily + + GGSWAssertNoiseFamily + + GGSWKeySwitchFamily + + GLWESwitchingKeyEncryptSkFamily + + GLWETensorKeyEncryptSkFamily + + GGLWEExecLayoutFamily + + VecZnxSwithcDegree, + B: TestScratchFamily, +{ + let rows: usize = k_ct.div_ceil(digits * basek); + + let digits_in: usize = 1; + + let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_ct, rows, digits_in, rank); + let mut tsk: GLWETensorKey> = GLWETensorKey::alloc(module, basek, k_tsk, rows, digits, rank); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc(module, basek, k_ksk, rows, digits, rank, rank); + let mut pt_scalar: ScalarZnx> = module.scalar_znx_alloc(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ct, rank) + | GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank, rank) + | GLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) + | GGSWCiphertext::keyswitch_inplace_scratch_space(module, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), + ); + + let var_xs: f64 = 0.5; + 
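+
+    // Note on the ScratchOwned::alloc call above: the individual scratch
+    // requirements are combined with `|` (bitwise OR on usize). Since
+    // a | b >= max(a, b) always holds, this yields a cheap upper bound on the
+    // peak scratch needed when the buffers are borrowed one at a time, at the
+    // cost of possibly over-allocating (a | b can exceed max(a, b)).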
+ let mut sk_in: GLWESecret> = GLWESecret::alloc(module, rank); + sk_in.fill_ternary_prob(var_xs, &mut source_xs); + let sk_in_dft: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_in); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(module, rank); + sk_out.fill_ternary_prob(var_xs, &mut source_xs); + let sk_out_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_out); + + ksk.encrypt_sk( + module, + &sk_in, + &sk_out, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + tsk.encrypt_sk( + module, + &sk_out, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + ct.encrypt_sk( + module, + &pt_scalar, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut ksk_exec: GLWESwitchingKeyExec, B> = + GLWESwitchingKeyExec::alloc(module, basek, k_ksk, rows, digits, rank, rank); + let mut tsk_exec: GLWETensorKeyExec, B> = GLWETensorKeyExec::alloc(module, basek, k_ksk, rows, digits, rank); + + ksk_exec.prepare(module, &ksk, scratch.borrow()); + tsk_exec.prepare(module, &tsk, scratch.borrow()); + + ct.keyswitch_inplace(module, &ksk_exec, &tsk_exec, scratch.borrow()); + + let max_noise = |col_j: usize| -> f64 { + noise_ggsw_keyswitch( + module.n() as f64, + basek * digits, + col_j, + var_xs, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_ct, + k_ksk, + k_tsk, + ) + 0.5 + }; + + ct.assert_noise(module, &sk_out_exec, &pt_scalar, &max_noise); +} + +pub(crate) fn test_automorphism( + p: i64, + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + k_tsk: usize, + digits: usize, + rank: usize, + sigma: f64, +) where + Module: TestModuleFamily + + GGSWAssertNoiseFamily + + GGSWKeySwitchFamily + + GLWESwitchingKeyEncryptSkFamily + + GLWETensorKeyEncryptSkFamily + + GGLWEExecLayoutFamily + + VecZnxSwithcDegree + + VecZnxAutomorphismInplace + + ScalarZnxAutomorphismInplace + + ScalarZnxAutomorphism, + B: TestScratchFamily, +{ + let rows: usize = k_in.div_ceil(basek * digits); + let rows_in: usize = k_in.div_euclid(basek * digits); + + let digits_in: usize = 1; + + let mut ct_in: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_in, rows_in, digits_in, rank); + let mut ct_out: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_out, rows_in, digits_in, rank); + let mut tensor_key: GLWETensorKey> = GLWETensorKey::alloc(module, basek, k_tsk, rows, digits, rank); + let mut auto_key: AutomorphismKey> = AutomorphismKey::alloc(module, basek, k_ksk, rows, digits, rank); + let mut pt_scalar: ScalarZnx> = module.scalar_znx_alloc(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_in, rank) + | AutomorphismKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) + | GLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) + | GGSWCiphertext::automorphism_scratch_space( + module, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, + ), + ); + + let var_xs: f64 = 0.5; + + let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + sk.fill_ternary_prob(var_xs, &mut source_xs); + let sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); + + auto_key.encrypt_sk( + module, + p, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + 
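+    // The automorphism key enables mapping Enc_s(M) to Enc_s(pi(M)) for
+    // pi: X -> X^p (here p = -5); the tensor key encrypted next is what
+    // `automorphism` uses to re-expand columns 1..rank of the output GGSW
+    // after the column-0 key switch (see ggsw/automorphism.rs).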
tensor_key.encrypt_sk( + module, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + ct_in.encrypt_sk( + module, + &pt_scalar, + &sk_exec, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut auto_key_exec: AutomorphismKeyExec, B> = AutomorphismKeyExec::alloc(module, basek, k_ksk, rows, digits, rank); + auto_key_exec.prepare(module, &auto_key, scratch.borrow()); + + let mut tsk_exec: GLWETensorKeyExec, B> = GLWETensorKeyExec::alloc(module, basek, k_tsk, rows, digits, rank); + tsk_exec.prepare(module, &tensor_key, scratch.borrow()); + + ct_out.automorphism(module, &ct_in, &auto_key_exec, &tsk_exec, scratch.borrow()); + + module.scalar_znx_automorphism_inplace(p, &mut pt_scalar, 0); + + let max_noise = |col_j: usize| -> f64 { + noise_ggsw_keyswitch( + module.n() as f64, + basek * digits, + col_j, + var_xs, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_in, + k_ksk, + k_tsk, + ) + 0.5 + }; + + ct_out.assert_noise(module, &sk_exec, &pt_scalar, &max_noise); +} + +pub(crate) fn test_automorphism_inplace( + p: i64, + module: &Module, + basek: usize, + k_ct: usize, + k_ksk: usize, + k_tsk: usize, + digits: usize, + rank: usize, + sigma: f64, +) where + Module: TestModuleFamily + + GGSWAssertNoiseFamily + + GGSWKeySwitchFamily + + GLWESwitchingKeyEncryptSkFamily + + GLWETensorKeyEncryptSkFamily + + GGLWEExecLayoutFamily + + VecZnxSwithcDegree + + VecZnxAutomorphismInplace + + ScalarZnxAutomorphismInplace + + ScalarZnxAutomorphism, + B: TestScratchFamily, +{ + let rows: usize = k_ct.div_ceil(digits * basek); + let rows_in: usize = k_ct.div_euclid(basek * digits); + let digits_in: usize = 1; + + let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_ct, rows_in, digits_in, rank); + let mut tensor_key: GLWETensorKey> = GLWETensorKey::alloc(module, basek, k_tsk, rows, digits, rank); + let mut auto_key: AutomorphismKey> = AutomorphismKey::alloc(module, basek, k_ksk, rows, digits, rank); + let mut pt_scalar: ScalarZnx> = module.scalar_znx_alloc(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ct, rank) + | AutomorphismKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) + | GLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) + | GGSWCiphertext::automorphism_inplace_scratch_space(module, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), + ); + + let var_xs: f64 = 0.5; + + let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + sk.fill_ternary_prob(var_xs, &mut source_xs); + let sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); + + auto_key.encrypt_sk( + module, + p, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + tensor_key.encrypt_sk( + module, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + ct.encrypt_sk( + module, + &pt_scalar, + &sk_exec, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut auto_key_exec: AutomorphismKeyExec, B> = AutomorphismKeyExec::alloc(module, basek, k_ksk, rows, digits, rank); + auto_key_exec.prepare(module, &auto_key, scratch.borrow()); + + let mut tsk_exec: GLWETensorKeyExec, B> = GLWETensorKeyExec::alloc(module, 
basek, k_tsk, rows, digits, rank); + tsk_exec.prepare(module, &tensor_key, scratch.borrow()); + + ct.automorphism_inplace(module, &auto_key_exec, &tsk_exec, scratch.borrow()); + + module.scalar_znx_automorphism_inplace(p, &mut pt_scalar, 0); + + let max_noise = |col_j: usize| -> f64 { + noise_ggsw_keyswitch( + module.n() as f64, + basek * digits, + col_j, + var_xs, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_ct, + k_ksk, + k_tsk, + ) + 0.5 + }; + + ct.assert_noise(module, &sk_exec, &pt_scalar, &max_noise); +} + +pub(crate) fn test_external_product( + module: &Module, + basek: usize, + k_in: usize, + k_out: usize, + k_ggsw: usize, + digits: usize, + rank: usize, + sigma: f64, +) where + Module: TestModuleFamily + + GGSWAssertNoiseFamily + + GGSWKeySwitchFamily + + GLWESwitchingKeyEncryptSkFamily + + GLWETensorKeyEncryptSkFamily + + GGLWEExecLayoutFamily + + VecZnxRotateInplace, + B: TestScratchFamily, +{ + let rows: usize = k_in.div_ceil(basek * digits); + let rows_in: usize = k_in.div_euclid(basek * digits); + let digits_in: usize = 1; + + let mut ct_ggsw_lhs_in: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_in, rows_in, digits_in, rank); + let mut ct_ggsw_lhs_out: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_out, rows_in, digits_in, rank); + let mut ct_ggsw_rhs: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_ggsw, rows, digits, rank); + let mut pt_ggsw_lhs: ScalarZnx> = module.scalar_znx_alloc(1); + let mut pt_ggsw_rhs: ScalarZnx> = module.scalar_znx_alloc(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); + + let k: usize = 1; + + pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank) + | GGSWCiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); + + ct_ggsw_rhs.encrypt_sk( + module, + &pt_ggsw_rhs, + &sk_exec, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_ggsw_lhs_in.encrypt_sk( + module, + &pt_ggsw_lhs, + &sk_exec, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut ct_rhs_exec: GGSWCiphertextExec, B> = GGSWCiphertextExec::alloc(module, basek, k_ggsw, rows, digits, rank); + ct_rhs_exec.prepare(module, &ct_ggsw_rhs, scratch.borrow()); + + ct_ggsw_lhs_out.external_product(module, &ct_ggsw_lhs_in, &ct_rhs_exec, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let max_noise = |_col_j: usize| -> f64 { + noise_ggsw_product( + module.n() as f64, + basek * digits, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as f64, + k_in, + k_ggsw, + ) + 0.5 + }; + + ct_ggsw_lhs_out.assert_noise(module, &sk_exec, &pt_ggsw_lhs, &max_noise); +} + +pub(crate) fn test_external_product_inplace( + module: &Module, + basek: usize, + k_ct: usize, + k_ggsw: usize, + digits: usize, + rank: usize, + sigma: f64, 
+) where + Module: TestModuleFamily + + GGSWAssertNoiseFamily + + GGSWKeySwitchFamily + + GLWESwitchingKeyEncryptSkFamily + + GLWETensorKeyEncryptSkFamily + + GGLWEExecLayoutFamily + + VecZnxRotateInplace, + B: TestScratchFamily, +{ + let rows: usize = k_ct.div_ceil(digits * basek); + let rows_in: usize = k_ct.div_euclid(basek * digits); + let digits_in: usize = 1; + + let mut ct_ggsw_lhs: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_ct, rows_in, digits_in, rank); + let mut ct_ggsw_rhs: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_ggsw, rows, digits, rank); + + let mut pt_ggsw_lhs: ScalarZnx> = module.scalar_znx_alloc(1); + let mut pt_ggsw_rhs: ScalarZnx> = module.scalar_znx_alloc(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); + + let k: usize = 1; + + pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank) + | GGSWCiphertext::external_product_inplace_scratch_space(module, basek, k_ct, k_ggsw, digits, rank), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); + + ct_ggsw_rhs.encrypt_sk( + module, + &pt_ggsw_rhs, + &sk_exec, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_ggsw_lhs.encrypt_sk( + module, + &pt_ggsw_lhs, + &sk_exec, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut ct_rhs_exec: GGSWCiphertextExec, B> = GGSWCiphertextExec::alloc(module, basek, k_ggsw, rows, digits, rank); + ct_rhs_exec.prepare(module, &ct_ggsw_rhs, scratch.borrow()); + + ct_ggsw_lhs.external_product_inplace(module, &ct_rhs_exec, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let max_noise = |_col_j: usize| -> f64 { + noise_ggsw_product( + module.n() as f64, + basek * digits, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as f64, + k_ct, + k_ggsw, + ) + 0.5 + }; + + ct_ggsw_lhs.assert_noise(module, &sk_exec, &pt_ggsw_lhs, &max_noise); +} diff --git a/core/src/ggsw/test/mod.rs b/core/src/ggsw/test/mod.rs new file mode 100644 index 0000000..bf72215 --- /dev/null +++ b/core/src/ggsw/test/mod.rs @@ -0,0 +1,2 @@ +mod generic_tests; +mod test_fft64; diff --git a/core/src/ggsw/test/test_fft64.rs b/core/src/ggsw/test/test_fft64.rs new file mode 100644 index 0000000..5ec265d --- /dev/null +++ b/core/src/ggsw/test/test_fft64.rs @@ -0,0 +1,127 @@ +use backend::{ + hal::{api::ModuleNew, layouts::Module}, + implementation::cpu_spqlios::FFT64, +}; + +use crate::ggsw::test::generic_tests::{ + test_automorphism, test_automorphism_inplace, test_encrypt_sk, test_external_product, test_external_product_inplace, + test_keyswitch, test_keyswitch_inplace, +}; + +#[test] +fn encrypt_sk() { + let log_n: usize = 8; + let module: Module = Module::::new(1 << log_n); + let basek: usize = 12; + let k_ct: usize = 54; + let digits: usize = k_ct / basek; + (1..4).for_each(|rank| { + (1..digits + 1).for_each(|di| { + println!("test 
encrypt_sk digits: {} rank: {}", di, rank);
+            test_encrypt_sk(&module, basek, k_ct, di, rank, 3.2);
+        });
+    });
+}
+
+#[test]
+fn keyswitch() {
+    let log_n: usize = 8;
+    let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
+    let basek: usize = 12;
+    let k_in: usize = 54;
+    let digits: usize = k_in.div_ceil(basek);
+    (1..4).for_each(|rank| {
+        (1..digits + 1).for_each(|di| {
+            let k_ksk: usize = k_in + basek * di;
+            let k_tsk: usize = k_ksk;
+            println!("test keyswitch digits: {} rank: {}", di, rank);
+            let k_out: usize = k_ksk; // Better capture noise.
+            test_keyswitch(&module, basek, k_out, k_in, k_ksk, k_tsk, di, rank, 3.2);
+        });
+    });
+}
+
+#[test]
+fn keyswitch_inplace() {
+    let log_n: usize = 8;
+    let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
+    let basek: usize = 12;
+    let k_ct: usize = 54;
+    let digits: usize = k_ct.div_ceil(basek);
+    (1..4).for_each(|rank| {
+        (1..digits + 1).for_each(|di| {
+            let k_ksk: usize = k_ct + basek * di;
+            let k_tsk: usize = k_ksk;
+            println!("test keyswitch_inplace digits: {} rank: {}", di, rank);
+            test_keyswitch_inplace(&module, basek, k_ct, k_ksk, k_tsk, di, rank, 3.2);
+        });
+    });
+}
+
+#[test]
+fn automorphism() {
+    let log_n: usize = 8;
+    let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
+    let basek: usize = 12;
+    let k_in: usize = 54;
+    let digits: usize = k_in.div_ceil(basek);
+    (1..4).for_each(|rank| {
+        (1..digits + 1).for_each(|di| {
+            let k_ksk: usize = k_in + basek * di;
+            let k_tsk: usize = k_ksk;
+            println!("test automorphism digits: {} rank: {}", di, rank);
+            let k_out: usize = k_ksk; // Better capture noise.
+            test_automorphism(-5, &module, basek, k_out, k_in, k_ksk, k_tsk, di, rank, 3.2);
+        });
+    });
+}
+
+#[test]
+fn automorphism_inplace() {
+    let log_n: usize = 8;
+    let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
+    let basek: usize = 12;
+    let k_ct: usize = 54;
+    let digits: usize = k_ct.div_ceil(basek);
+    (1..4).for_each(|rank| {
+        (1..digits + 1).for_each(|di| {
+            let k_ksk: usize = k_ct + basek * di;
+            let k_tsk: usize = k_ksk;
+            println!("test automorphism_inplace digits: {} rank: {}", di, rank);
+            test_automorphism_inplace(-5, &module, basek, k_ct, k_ksk, k_tsk, di, rank, 3.2);
+        });
+    });
+}
+
+#[test]
+fn external_product() {
+    let log_n: usize = 8;
+    let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
+    let basek: usize = 12;
+    let k_in: usize = 60;
+    let digits: usize = k_in.div_ceil(basek);
+    (1..4).for_each(|rank| {
+        (1..digits + 1).for_each(|di| {
+            let k_ggsw: usize = k_in + basek * di;
+            println!("test external_product digits: {} rank: {}", di, rank);
+            let k_out: usize = k_in; // Better capture noise.
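+            // Illustrative arithmetic (annotation, derived from the constants
+            // above; not part of the original patch): basek = 12 and k_in = 60
+            // give digits = 5, so at di = 2 the GGSW operand carries
+            // k_ggsw = 60 + 12 * 2 = 84 bits, and the generic test allocates
+            // rows = k_in.div_ceil(basek * di) = 60 / 24 -> 3 gadget rows, with
+            // rows_in = k_in.div_euclid(basek * di) = 2 rows for the digit-1 inputs.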
+            test_external_product(&module, basek, k_in, k_out, k_ggsw, di, rank, 3.2);
+        });
+    });
+}
+
+#[test]
+fn external_product_inplace() {
+    let log_n: usize = 8;
+    let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
+    let basek: usize = 12;
+    let k_ct: usize = 60;
+    let digits: usize = k_ct.div_ceil(basek);
+    (1..4).for_each(|rank| {
+        (1..digits).for_each(|di| {
+            let k_ggsw: usize = k_ct + basek * di;
+            println!("test external_product_inplace digits: {} rank: {}", di, rank);
+            test_external_product_inplace(&module, basek, k_ct, k_ggsw, di, rank, 3.2);
+        });
+    });
+}
diff --git a/core/src/ggsw/test_fft64/ggsw.rs b/core/src/ggsw/test_fft64/ggsw.rs
deleted file mode 100644
index dc84eb6..0000000
--- a/core/src/ggsw/test_fft64/ggsw.rs
+++ /dev/null
@@ -1,962 +0,0 @@
-use backend::{
-    FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScalarZnxOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc,
-    VecZnxBigOps, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero,
-};
-use sampling::source::Source;
-
-use crate::{
-    FourierGLWECiphertext, FourierGLWESecret, GGSWCiphertext, GLWEAutomorphismKey, GLWEPlaintext, GLWESecret, GLWESwitchingKey,
-    GLWETensorKey, GetRow, Infos,
-    noise::{noise_ggsw_keyswitch, noise_ggsw_product},
-};
-
-#[test]
-fn encrypt_sk() {
-    let log_n: usize = 8;
-    let basek: usize = 12;
-    let k_ct: usize = 54;
-    let digits: usize = k_ct / basek;
-    (1..4).for_each(|rank| {
-        (1..digits + 1).for_each(|di| {
-            println!("test encrypt_sk digits: {} rank: {}", di, rank);
-            test_encrypt_sk(log_n, basek, k_ct, di, rank, 3.2);
-        });
-    });
-}
-
-#[test]
-fn keyswitch() {
-    let log_n: usize = 8;
-    let basek: usize = 12;
-    let k_in: usize = 54;
-    let digits: usize = k_in.div_ceil(basek);
-    (1..4).for_each(|rank| {
-        (1..digits + 1).for_each(|di| {
-            let k_ksk: usize = k_in + basek * di;
-            let k_tsk: usize = k_ksk;
-            println!("test keyswitch digits: {} rank: {}", di, rank);
-            let k_out: usize = k_ksk; // Better capture noise.
-            test_keyswitch(log_n, basek, k_out, k_in, k_ksk, k_tsk, di, rank, 3.2);
-        });
-    });
-}
-
-#[test]
-fn keyswitch_inplace() {
-    let log_n: usize = 8;
-    let basek: usize = 12;
-    let k_ct: usize = 54;
-    let digits: usize = k_ct.div_ceil(basek);
-    (1..4).for_each(|rank| {
-        (1..digits + 1).for_each(|di| {
-            let k_ksk: usize = k_ct + basek * di;
-            let k_tsk: usize = k_ksk;
-            println!("test keyswitch_inplace digits: {} rank: {}", di, rank);
-            test_keyswitch_inplace(log_n, basek, k_ct, k_ksk, k_tsk, di, rank, 3.2);
-        });
-    });
-}
-
-#[test]
-fn automorphism() {
-    let log_n: usize = 8;
-    let basek: usize = 12;
-    let k_in: usize = 54;
-    let digits: usize = k_in.div_ceil(basek);
-    (1..4).for_each(|rank| {
-        (1..digits + 1).for_each(|di| {
-            let k_ksk: usize = k_in + basek * di;
-            let k_tsk: usize = k_ksk;
-            println!("test automorphism rank: {}", rank);
-            let k_out: usize = k_ksk; // Better capture noise.
- test_automorphism(-5, log_n, basek, k_out, k_in, k_ksk, k_tsk, di, rank, 3.2); - }); - }); -} - -#[test] -fn automorphism_inplace() { - let log_n: usize = 8; - let basek: usize = 12; - let k_ct: usize = 54; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_ct + basek * di; - let k_tsk: usize = k_ksk; - println!("test automorphism_inplace rank: {}", rank); - test_automorphism_inplace(-5, log_n, basek, k_ct, k_ksk, k_tsk, di, rank, 3.2); - }); - }); -} - -#[test] -fn external_product() { - let log_n: usize = 8; - let basek: usize = 12; - let k_in: usize = 60; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ggsw: usize = k_in + basek * di; - println!("test external_product digits: {} ranks: {}", di, rank); - let k_out: usize = k_in; // Better capture noise. - test_external_product(log_n, basek, k_in, k_out, k_ggsw, di, rank, 3.2); - }); - }); -} - -#[test] -fn external_product_inplace() { - let log_n: usize = 5; - let basek: usize = 12; - let k_ct: usize = 60; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits).for_each(|di| { - let k_ggsw: usize = k_ct + basek * di; - println!("test external_product digits: {} rank: {}", di, rank); - test_external_product_inplace(log_n, basek, k_ct, k_ggsw, di, rank, 3.2); - }); - }); -} - -fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, digits: usize, rank: usize, sigma: f64) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = (k - digits * basek) / (digits * basek); - - let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k, rows, digits, rank); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); - let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k, rank) - | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k), - ); - - let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); - - ct.encrypt_sk( - &module, - &pt_scalar, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k, rank); - let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); - let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); - - (0..ct.rank() + 1).for_each(|col_j| { - (0..ct.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace( - &mut pt_want.data, - 0, - (digits - 1) + row_i * digits, - &pt_scalar, - 0, - ); - - // mul with sk[col_j-1] - if col_j > 0 { - module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft.data, col_j - 1); - module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); - } - - ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - - 
ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - - let std_pt: f64 = pt_have.data.std(0, basek) * (k as f64).exp2(); - assert!((sigma - std_pt).abs() <= 0.5, "{} {}", sigma, std_pt); - - pt_want.data.zero(); - }); - }); -} - -fn test_keyswitch( - log_n: usize, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - k_tsk: usize, - digits: usize, - rank: usize, - sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - let rows: usize = k_in.div_ceil(digits * basek); - - let digits_in: usize = 1; - - let mut ct_in: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_in, rows, digits_in, rank); - let mut ct_out: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_out, rows, digits_in, rank); - let mut tsk: GLWETensorKey, FFT64> = GLWETensorKey::alloc(&module, basek, k_ksk, rows, digits, rank); - let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); - let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_in, rank) - | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) - | GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank, rank) - | GLWETensorKey::encrypt_sk_scratch_space(&module, basek, k_tsk, rank) - | GGSWCiphertext::keyswitch_scratch_space( - &module, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, - ), - ); - - let var_xs: f64 = 0.5; - - let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank); - sk_in.fill_ternary_prob(var_xs, &mut source_xs); - let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); - - let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank); - sk_out.fill_ternary_prob(var_xs, &mut source_xs); - let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); - - ksk.encrypt_sk( - &module, - &sk_in, - &sk_out_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - tsk.encrypt_sk( - &module, - &sk_out_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); - - ct_in.encrypt_sk( - &module, - &pt_scalar, - &sk_in_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_out.keyswitch(&module, &ct_in, &ksk, &tsk, scratch.borrow()); - - let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_out, rank); - let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_out.size()); - let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_out.size()); - - (0..ct_out.rank() + 1).for_each(|col_j| { - (0..ct_out.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt_want.data, 0, row_i, &pt_scalar, 0); - - // mul with sk[col_j-1] - if col_j > 0 { - module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft.data, col_j - 1); - module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut 
pt_dft, 0); - module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); - } - - ct_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - - ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = noise_ggsw_keyswitch( - module.n() as f64, - basek * digits, - col_j, - var_xs, - 0f64, - sigma * sigma, - 0f64, - rank as f64, - k_in, - k_ksk, - k_tsk, - ); - - println!("{} {}", noise_have, noise_want); - - assert!( - noise_have < noise_want + 0.5, - "{} {}", - noise_have, - noise_want - ); - - pt_want.data.zero(); - }); - }); -} - -fn test_keyswitch_inplace( - log_n: usize, - basek: usize, - k_ct: usize, - k_ksk: usize, - k_tsk: usize, - digits: usize, - rank: usize, - sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - let rows: usize = k_ct.div_ceil(digits * basek); - - let digits_in: usize = 1; - - let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ct, rows, digits_in, rank); - let mut tsk: GLWETensorKey, FFT64> = GLWETensorKey::alloc(&module, basek, k_tsk, rows, digits, rank); - let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ct, rank) - | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) - | GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank, rank) - | GLWETensorKey::encrypt_sk_scratch_space(&module, basek, k_tsk, rank) - | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), - ); - - let var_xs: f64 = 0.5; - - let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank); - sk_in.fill_ternary_prob(var_xs, &mut source_xs); - let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); - - let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank); - sk_out.fill_ternary_prob(var_xs, &mut source_xs); - let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); - - ksk.encrypt_sk( - &module, - &sk_in, - &sk_out_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - tsk.encrypt_sk( - &module, - &sk_out_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); - - ct.encrypt_sk( - &module, - &pt_scalar, - &sk_in_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct.keyswitch_inplace(&module, &ksk, &tsk, scratch.borrow()); - - let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); - let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); - - (0..ct.rank() + 1).for_each(|col_j| { - (0..ct.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace( - &mut 
pt_want.data, - 0, - (digits_in - 1) + row_i * digits_in, - &pt_scalar, - 0, - ); - - // mul with sk[col_j-1] - if col_j > 0 { - module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft.data, col_j - 1); - module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); - } - - ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - - ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = noise_ggsw_keyswitch( - module.n() as f64, - basek * digits, - col_j, - var_xs, - 0f64, - sigma * sigma, - 0f64, - rank as f64, - k_ct, - k_ksk, - k_tsk, - ); - - println!("{} {}", noise_have, noise_want); - - assert!( - noise_have < noise_want + 0.5, - "{} {}", - noise_have, - noise_want - ); - - pt_want.data.zero(); - }); - }); -} - -fn test_automorphism( - p: i64, - log_n: usize, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - k_tsk: usize, - digits: usize, - rank: usize, - sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - let rows: usize = k_in.div_ceil(basek * digits); - let rows_in: usize = k_in.div_euclid(basek * digits); - - let digits_in: usize = 1; - - let mut ct_in: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_in, rows_in, digits_in, rank); - let mut ct_out: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_out, rows_in, digits_in, rank); - let mut tensor_key: GLWETensorKey, FFT64> = GLWETensorKey::alloc(&module, basek, k_tsk, rows, digits, rank); - let mut auto_key: GLWEAutomorphismKey, FFT64> = GLWEAutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); - let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_in, rank) - | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) - | GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) - | GLWETensorKey::encrypt_sk_scratch_space(&module, basek, k_tsk, rank) - | GGSWCiphertext::automorphism_scratch_space( - &module, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, - ), - ); - - let var_xs: f64 = 0.5; - - let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(var_xs, &mut source_xs); - let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); - - auto_key.encrypt_sk( - &module, - p, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - tensor_key.encrypt_sk( - &module, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); - - ct_in.encrypt_sk( - &module, - &pt_scalar, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_out.automorphism(&module, &ct_in, &auto_key, &tensor_key, scratch.borrow()); - - module.scalar_znx_automorphism_inplace(p, &mut pt_scalar, 0); - - let mut 
ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_out, rank); - let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_out.size()); - let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_out.size()); - - (0..ct_out.rank() + 1).for_each(|col_j| { - (0..ct_out.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt_want.data, 0, row_i, &pt_scalar, 0); - - // mul with sk[col_j-1] - if col_j > 0 { - module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft.data, col_j - 1); - module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); - } - - ct_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - - ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = noise_ggsw_keyswitch( - module.n() as f64, - basek * digits, - col_j, - var_xs, - 0f64, - sigma * sigma, - 0f64, - rank as f64, - k_in, - k_ksk, - k_tsk, - ); - - assert!( - noise_have < noise_want + 0.5, - "{} {}", - noise_have, - noise_want - ); - - pt_want.data.zero(); - }); - }); -} - -fn test_automorphism_inplace( - p: i64, - log_n: usize, - basek: usize, - k_ct: usize, - k_ksk: usize, - k_tsk: usize, - digits: usize, - rank: usize, - sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - let rows: usize = k_ct.div_ceil(digits * basek); - let rows_in: usize = k_ct.div_euclid(basek * digits); - let digits_in: usize = 1; - - let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ct, rows_in, digits_in, rank); - let mut tensor_key: GLWETensorKey, FFT64> = GLWETensorKey::alloc(&module, basek, k_tsk, rows, digits, rank); - let mut auto_key: GLWEAutomorphismKey, FFT64> = GLWEAutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ct, rank) - | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) - | GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) - | GLWETensorKey::encrypt_sk_scratch_space(&module, basek, k_tsk, rank) - | GGSWCiphertext::automorphism_inplace_scratch_space(&module, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), - ); - - let var_xs: f64 = 0.5; - - let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(var_xs, &mut source_xs); - let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); - - auto_key.encrypt_sk( - &module, - p, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - tensor_key.encrypt_sk( - &module, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); - - ct.encrypt_sk( - &module, - &pt_scalar, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); 
- - ct.automorphism_inplace(&module, &auto_key, &tensor_key, scratch.borrow()); - - module.scalar_znx_automorphism_inplace(p, &mut pt_scalar, 0); - - let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); - let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); - - (0..ct.rank() + 1).for_each(|col_j| { - (0..ct.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt_want.data, 0, row_i, &pt_scalar, 0); - - // mul with sk[col_j-1] - if col_j > 0 { - module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft.data, col_j - 1); - module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); - } - - ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - - ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = noise_ggsw_keyswitch( - module.n() as f64, - basek * digits, - col_j, - var_xs, - 0f64, - sigma * sigma, - 0f64, - rank as f64, - k_ct, - k_ksk, - k_tsk, - ); - - assert!( - noise_have <= noise_want + 0.5, - "{} {}", - noise_have, - noise_want - ); - - pt_want.data.zero(); - }); - }); -} - -fn test_external_product( - log_n: usize, - basek: usize, - k_in: usize, - k_out: usize, - k_ggsw: usize, - digits: usize, - rank: usize, - sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k_in.div_ceil(basek * digits); - let rows_in: usize = k_in.div_euclid(basek * digits); - let digits_in: usize = 1; - - let mut ct_ggsw_lhs_in: GGSWCiphertext, FFT64> = - GGSWCiphertext::alloc(&module, basek, k_in, rows_in, digits_in, rank); - let mut ct_ggsw_lhs_out: GGSWCiphertext, FFT64> = - GGSWCiphertext::alloc(&module, basek, k_out, rows_in, digits_in, rank); - let mut ct_ggsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); - let mut pt_ggsw_lhs: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_ggsw_rhs: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); - - let k: usize = 1; - - pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ggsw, rank) - | GGSWCiphertext::external_product_scratch_space(&module, basek, k_out, k_in, k_ggsw, digits, rank), - ); - - let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); - - ct_ggsw_rhs.encrypt_sk( - &module, - &pt_ggsw_rhs, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_ggsw_lhs_in.encrypt_sk( - &module, - &pt_ggsw_lhs, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_ggsw_lhs_out.external_product(&module, &ct_ggsw_lhs_in, &ct_ggsw_rhs, scratch.borrow()); - - let mut ct_glwe_fourier: FourierGLWECiphertext, 
FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_out, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); - let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs_out.size()); - let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs_out.size()); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0); - - (0..ct_ggsw_lhs_out.rank() + 1).for_each(|col_j| { - (0..ct_ggsw_lhs_out.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace( - &mut pt_want.data, - 0, - (digits_in - 1) + row_i * digits_in, - &pt_ggsw_lhs, - 0, - ); - - if col_j > 0 { - module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft.data, col_j - 1); - module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); - } - - ct_ggsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt.data, 0, &pt_want.data, 0); - - let noise_have: f64 = pt.data.std(0, basek).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_ggsw_product( - module.n() as f64, - basek * digits, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - rank as f64, - k_in, - k_ggsw, - ); - - assert!( - noise_have <= noise_want + 0.5, - "have: {} want: {}", - noise_have, - noise_want - ); - - println!("{} {}", noise_have, noise_want); - - pt_want.data.zero(); - }); - }); -} - -fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw: usize, digits: usize, rank: usize, sigma: f64) { - let module: Module = Module::::new(1 << log_n); - let rows: usize = k_ct.div_ceil(digits * basek); - let rows_in: usize = k_ct.div_euclid(basek * digits); - let digits_in: usize = 1; - - let mut ct_ggsw_lhs: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ct, rows_in, digits_in, rank); - let mut ct_ggsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); - - let mut pt_ggsw_lhs: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_ggsw_rhs: ScalarZnx> = module.new_scalar_znx(1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); - - let k: usize = 1; - - pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ggsw, rank) - | GGSWCiphertext::external_product_inplace_scratch_space(&module, basek, k_ct, k_ggsw, digits, rank), - ); - - let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); - - ct_ggsw_rhs.encrypt_sk( - &module, - &pt_ggsw_rhs, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_ggsw_lhs.encrypt_sk( - &module, - 
&pt_ggsw_lhs, - &sk_dft, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_ggsw_lhs.external_product_inplace(&module, &ct_ggsw_rhs, scratch.borrow()); - - let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs.size()); - let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs.size()); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0); - - (0..ct_ggsw_lhs.rank() + 1).for_each(|col_j| { - (0..ct_ggsw_lhs.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace( - &mut pt_want.data, - 0, - (digits_in - 1) + row_i * digits_in, - &pt_ggsw_lhs, - 0, - ); - - if col_j > 0 { - module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft.data, col_j - 1); - module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); - module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); - } - - ct_ggsw_lhs.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt.data, 0, &pt_want.data, 0); - - let noise_have: f64 = pt.data.std(0, basek).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_ggsw_product( - module.n() as f64, - basek * digits, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - rank as f64, - k_ct, - k_ggsw, - ); - - assert!( - noise_have <= noise_want + 0.5, - "have: {} want: {}", - noise_have, - noise_want - ); - - pt_want.data.zero(); - }); - }); -} diff --git a/core/src/ggsw/test_fft64/mod.rs b/core/src/ggsw/test_fft64/mod.rs deleted file mode 100644 index 3326f10..0000000 --- a/core/src/ggsw/test_fft64/mod.rs +++ /dev/null @@ -1 +0,0 @@ -mod ggsw; diff --git a/core/src/glwe/automorphism.rs b/core/src/glwe/automorphism.rs index 1513362..ac731e2 100644 --- a/core/src/glwe/automorphism.rs +++ b/core/src/glwe/automorphism.rs @@ -1,121 +1,187 @@ -use backend::{FFT64, Module, Scratch, VecZnxOps}; +use backend::hal::{ + api::{ + ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAutomorphismInplace, VecZnxBigSubSmallAInplace, + VecZnxBigSubSmallBInplace, + }, + layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnxBig}, +}; -use crate::{GLWEAutomorphismKey, GLWECiphertext}; +use crate::{AutomorphismKeyExec, GLWECiphertext, GLWEKeyswitchFamily, Infos, glwe::keyswitch::keyswitch}; impl GLWECiphertext> { - pub fn automorphism_scratch_space( - module: &Module, + pub fn automorphism_scratch_space( + module: &Module, basek: usize, k_out: usize, k_in: usize, k_ksk: usize, digits: usize, rank: usize, - ) -> usize { + ) -> usize + where + Module: GLWEKeyswitchFamily, + { Self::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) } - pub fn automorphism_inplace_scratch_space( - module: &Module, + pub fn automorphism_inplace_scratch_space( + module: &Module, basek: usize, k_out: usize, k_ksk: usize, digits: usize, rank: usize, - ) -> usize { + ) -> usize + where + Module: 
GLWEKeyswitchFamily, + { Self::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank) } } -impl + AsMut<[u8]>> GLWECiphertext { - pub fn automorphism, DataRhs: AsRef<[u8]>>( +impl GLWECiphertext { + pub fn automorphism( &mut self, - module: &Module, + module: &Module, lhs: &GLWECiphertext, - rhs: &GLWEAutomorphismKey, - scratch: &mut Scratch, - ) { + rhs: &AutomorphismKeyExec, + scratch: &mut Scratch, + ) where + Module: GLWEKeyswitchFamily + VecZnxAutomorphismInplace, + Scratch: TakeVecZnxDft + ScratchAvailable, + { self.keyswitch(module, lhs, &rhs.key, scratch); (0..self.rank() + 1).for_each(|i| { module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i); }) } - pub fn automorphism_inplace>( + pub fn automorphism_inplace( &mut self, - module: &Module, - rhs: &GLWEAutomorphismKey, - scratch: &mut Scratch, - ) { + module: &Module, + rhs: &AutomorphismKeyExec, + scratch: &mut Scratch, + ) where + Module: GLWEKeyswitchFamily + VecZnxAutomorphismInplace, + Scratch: TakeVecZnxDft + ScratchAvailable, + { self.keyswitch_inplace(module, &rhs.key, scratch); (0..self.rank() + 1).for_each(|i| { module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i); }) } - pub fn automorphism_add, DataRhs: AsRef<[u8]>>( + pub fn automorphism_add( &mut self, - module: &Module, + module: &Module, lhs: &GLWECiphertext, - rhs: &GLWEAutomorphismKey, - scratch: &mut Scratch, - ) { - Self::keyswitch_private::<_, _, 1>(self, rhs.p(), module, lhs, &rhs.key, scratch); + rhs: &AutomorphismKeyExec, + scratch: &mut Scratch, + ) where + Module: GLWEKeyswitchFamily + VecZnxBigAutomorphismInplace, + Scratch: TakeVecZnxDft + ScratchAvailable, + { + #[cfg(debug_assertions)] + { + self.assert_keyswitch(module, lhs, &rhs.key, scratch); + } + let (res_dft, scratch1) = scratch.take_vec_znx_dft(module, self.cols(), rhs.size()); // TODO: optimise size + let mut res_big: VecZnxBig<_, B> = keyswitch(module, res_dft, lhs, &rhs.key, scratch1); + (0..self.cols()).for_each(|i| { + module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i); + module.vec_znx_big_add_small_inplace(&mut res_big, i, &lhs.data, i); + module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1); + }) } - pub fn automorphism_add_inplace>( + pub fn automorphism_add_inplace( &mut self, - module: &Module, - rhs: &GLWEAutomorphismKey, - scratch: &mut Scratch, - ) { + module: &Module, + rhs: &AutomorphismKeyExec, + scratch: &mut Scratch, + ) where + Module: GLWEKeyswitchFamily + VecZnxBigAutomorphismInplace, + Scratch: TakeVecZnxDft + ScratchAvailable, + { unsafe { let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; - Self::keyswitch_private::<_, _, 1>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch); + self.automorphism_add(module, &*self_ptr, rhs, scratch); } } - pub fn automorphism_sub_ab, DataRhs: AsRef<[u8]>>( + pub fn automorphism_sub_ab( &mut self, - module: &Module, + module: &Module, lhs: &GLWECiphertext, - rhs: &GLWEAutomorphismKey, - scratch: &mut Scratch, - ) { - Self::keyswitch_private::<_, _, 2>(self, rhs.p(), module, lhs, &rhs.key, scratch); + rhs: &AutomorphismKeyExec, + scratch: &mut Scratch, + ) where + Module: GLWEKeyswitchFamily + VecZnxBigAutomorphismInplace + VecZnxBigSubSmallAInplace, + Scratch: TakeVecZnxDft + ScratchAvailable, + { + #[cfg(debug_assertions)] + { + self.assert_keyswitch(module, lhs, &rhs.key, scratch); + } + let (res_dft, scratch1) = scratch.take_vec_znx_dft(module, self.cols(), rhs.size()); // TODO: optimise size + let mut res_big: VecZnxBig<_, B> = 
keyswitch(module, res_dft, lhs, &rhs.key, scratch1); + (0..self.cols()).for_each(|i| { + module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i); + module.vec_znx_big_sub_small_a_inplace(&mut res_big, i, &lhs.data, i); + module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1); + }) } - pub fn automorphism_sub_ab_inplace>( + pub fn automorphism_sub_ab_inplace( &mut self, - module: &Module, - rhs: &GLWEAutomorphismKey, - scratch: &mut Scratch, - ) { + module: &Module, + rhs: &AutomorphismKeyExec, + scratch: &mut Scratch, + ) where + Module: GLWEKeyswitchFamily + VecZnxBigAutomorphismInplace + VecZnxBigSubSmallAInplace, + Scratch: TakeVecZnxDft + ScratchAvailable, + { unsafe { let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; - Self::keyswitch_private::<_, _, 2>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch); + self.automorphism_sub_ab(module, &*self_ptr, rhs, scratch); } } - pub fn automorphism_sub_ba, DataRhs: AsRef<[u8]>>( + pub fn automorphism_sub_ba( &mut self, - module: &Module, + module: &Module, lhs: &GLWECiphertext, - rhs: &GLWEAutomorphismKey, - scratch: &mut Scratch, - ) { - Self::keyswitch_private::<_, _, 3>(self, rhs.p(), module, lhs, &rhs.key, scratch); + rhs: &AutomorphismKeyExec, + scratch: &mut Scratch, + ) where + Module: GLWEKeyswitchFamily + VecZnxBigAutomorphismInplace + VecZnxBigSubSmallBInplace, + Scratch: TakeVecZnxDft + ScratchAvailable, + { + #[cfg(debug_assertions)] + { + self.assert_keyswitch(module, lhs, &rhs.key, scratch); + } + let (res_dft, scratch1) = scratch.take_vec_znx_dft(module, self.cols(), rhs.size()); // TODO: optimise size + let mut res_big: VecZnxBig<_, B> = keyswitch(module, res_dft, lhs, &rhs.key, scratch1); + (0..self.cols()).for_each(|i| { + module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i); + module.vec_znx_big_sub_small_b_inplace(&mut res_big, i, &lhs.data, i); + module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1); + }) } - pub fn automorphism_sub_ba_inplace>( + pub fn automorphism_sub_ba_inplace( &mut self, - module: &Module, - rhs: &GLWEAutomorphismKey, - scratch: &mut Scratch, - ) { + module: &Module, + rhs: &AutomorphismKeyExec, + scratch: &mut Scratch, + ) where + Module: GLWEKeyswitchFamily + VecZnxBigAutomorphismInplace + VecZnxBigSubSmallBInplace, + Scratch: TakeVecZnxDft + ScratchAvailable, + { unsafe { let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; - Self::keyswitch_private::<_, _, 3>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch); + self.automorphism_sub_ba(module, &*self_ptr, rhs, scratch); } } } diff --git a/core/src/glwe/ciphertext.rs b/core/src/glwe/ciphertext.rs deleted file mode 100644 index 07b4264..0000000 --- a/core/src/glwe/ciphertext.rs +++ /dev/null @@ -1,115 +0,0 @@ -use backend::{Backend, FFT64, Module, VecZnx, VecZnxAlloc, VecZnxDftOps, VecZnxToMut, VecZnxToRef}; - -use crate::{FourierGLWECiphertext, GLWEOps, Infos, SetMetaData}; - -pub struct GLWECiphertext { - pub data: VecZnx, - pub basek: usize, - pub k: usize, -} - -impl GLWECiphertext> { - pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { - Self { - data: module.new_vec_znx(rank + 1, k.div_ceil(basek)), - basek, - k, - } - } - - pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - module.bytes_of_vec_znx(rank + 1, k.div_ceil(basek)) - } -} - -impl Infos for GLWECiphertext { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> 
usize { - self.basek - } - - fn k(&self) -> usize { - self.k - } -} - -impl GLWECiphertext { - pub fn rank(&self) -> usize { - self.cols() - 1 - } -} - -impl> GLWECiphertext { - #[allow(dead_code)] - pub(crate) fn dft + AsRef<[u8]>>(&self, module: &Module, res: &mut FourierGLWECiphertext) { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), res.rank()); - assert_eq!(self.basek(), res.basek()) - } - - (0..self.rank() + 1).for_each(|i| { - module.vec_znx_dft(1, 0, &mut res.data, i, &self.data, i); - }) - } -} - -impl> GLWECiphertext { - pub fn clone(&self) -> GLWECiphertext> { - GLWECiphertext { - data: self.data.clone(), - basek: self.basek(), - k: self.k(), - } - } -} - -impl + AsRef<[u8]>> SetMetaData for GLWECiphertext { - fn set_k(&mut self, k: usize) { - self.k = k - } - - fn set_basek(&mut self, basek: usize) { - self.basek = basek - } -} - -pub trait GLWECiphertextToRef: Infos { - fn to_ref(&self) -> GLWECiphertext<&[u8]>; -} - -impl> GLWECiphertextToRef for GLWECiphertext { - fn to_ref(&self) -> GLWECiphertext<&[u8]> { - GLWECiphertext { - data: self.data.to_ref(), - basek: self.basek, - k: self.k, - } - } -} - -pub trait GLWECiphertextToMut: Infos { - fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]>; -} - -impl + AsRef<[u8]>> GLWECiphertextToMut for GLWECiphertext { - fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]> { - GLWECiphertext { - data: self.data.to_mut(), - basek: self.basek, - k: self.k, - } - } -} - -impl GLWEOps for GLWECiphertext -where - D: AsRef<[u8]> + AsMut<[u8]>, - GLWECiphertext: GLWECiphertextToMut + Infos + SetMetaData, -{ -} diff --git a/core/src/glwe/decryption.rs b/core/src/glwe/decryption.rs index e543963..fcd2e59 100644 --- a/core/src/glwe/decryption.rs +++ b/core/src/glwe/decryption.rs @@ -1,25 +1,46 @@ -use backend::{ - FFT64, Module, ScalarZnxDftOps, Scratch, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, - ZnxZero, +use backend::hal::{ + api::{ + DataViewMut, SvpApplyInplace, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, + VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, + VecZnxNormalizeTmpBytes, + }, + layouts::{Backend, DataMut, DataRef, Module, Scratch}, }; -use crate::{FourierGLWESecret, GLWECiphertext, GLWEPlaintext, Infos}; +use crate::{GLWECiphertext, GLWEPlaintext, GLWESecretExec, Infos}; + +pub trait GLWEDecryptFamily = VecZnxDftAllocBytes + + VecZnxBigAllocBytes + + VecZnxDftFromVecZnx + + SvpApplyInplace + + VecZnxDftToVecZnxBigConsume + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + VecZnxNormalizeTmpBytes; impl GLWECiphertext> { - pub fn decrypt_scratch_space(module: &Module, basek: usize, k: usize) -> usize { + pub fn decrypt_scratch_space(module: &Module, basek: usize, k: usize) -> usize + where + Module: GLWEDecryptFamily, + { let size: usize = k.div_ceil(basek); - (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) + (module.vec_znx_normalize_tmp_bytes(module.n()) | module.vec_znx_dft_alloc_bytes(1, size)) + + module.vec_znx_dft_alloc_bytes(1, size) } } -impl> GLWECiphertext { - pub fn decrypt + AsRef<[u8]>, DataSk: AsRef<[u8]>>( +impl GLWECiphertext { + pub fn decrypt( &self, - module: &Module, + module: &Module, pt: &mut GLWEPlaintext, - sk: &FourierGLWESecret, - scratch: &mut Scratch, - ) { + sk: &GLWESecretExec, + scratch: &mut Scratch, + ) where + Module: GLWEDecryptFamily, + Scratch: TakeVecZnxDft + 
TakeVecZnxBig, + { #[cfg(debug_assertions)] { assert_eq!(self.rank(), sk.rank()); @@ -30,16 +51,16 @@ impl> GLWECiphertext { let cols: usize = self.rank() + 1; - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct - c0_big.zero(); + let (mut c0_big, scratch_1) = scratch.take_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct + c0_big.data_mut().fill(0); { (1..cols).for_each(|i| { // ci_dft = DFT(a[i]) * DFT(s[i]) - let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct - module.vec_znx_dft(1, 0, &mut ci_dft, 0, &self.data, i); + let (mut ci_dft, _) = scratch_1.take_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct + module.vec_znx_dft_from_vec_znx(1, 0, &mut ci_dft, 0, &self.data, i); module.svp_apply_inplace(&mut ci_dft, 0, &sk.data, i - 1); - let ci_big = module.vec_znx_idft_consume(ci_dft); + let ci_big = module.vec_znx_dft_to_vec_znx_big_consume(ci_dft); // c0_big += a[i] * s[i] module.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0); diff --git a/core/src/glwe/encryption.rs b/core/src/glwe/encryption.rs index b0a7615..67f0de8 100644 --- a/core/src/glwe/encryption.rs +++ b/core/src/glwe/encryption.rs @@ -1,35 +1,76 @@ -use backend::{ - AddNormal, FFT64, FillUniform, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnxAlloc, VecZnxBig, - VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxZero, +use backend::hal::{ + api::{ + ScalarZnxAllocBytes, ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, + TakeSvpPPol, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddNormal, VecZnxBigAddSmallInplace, + VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, + VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSubABInplace, ZnxZero, + }, + layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnxBig}, }; use sampling::source::Source; -use crate::{FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWEPublicKey, Infos, SIX_SIGMA, dist::Distribution}; +use crate::{GLWECiphertext, GLWEPlaintext, GLWEPublicKeyExec, GLWESecretExec, Infos, SIX_SIGMA, dist::Distribution}; + +pub trait GLWEEncryptSkFamily = VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftFromVecZnx + + SvpApplyInplace + + VecZnxDftToVecZnxBigConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize; + +pub trait GLWEEncryptPkFamily = VecZnxDftAllocBytes + + VecZnxBigAllocBytes + + SvpPPolAllocBytes + + SvpPrepare + + SvpApply + + VecZnxDftToVecZnxBigConsume + + VecZnxBigAddNormal + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + ScalarZnxAllocBytes + + VecZnxNormalizeTmpBytes; impl GLWECiphertext> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize { + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize + where + Module: GLWEEncryptSkFamily, + { let size: usize = k.div_ceil(basek); - module.vec_znx_big_normalize_tmp_bytes() + module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx(1, size) + module.vec_znx_normalize_tmp_bytes(module.n()) + + module.vec_znx_dft_alloc_bytes(1, size) + + module.vec_znx_dft_alloc_bytes(1, size) } - pub fn 
encrypt_pk_scratch_space(module: &Module, basek: usize, k: usize) -> usize { + pub fn encrypt_pk_scratch_space(module: &Module, basek: usize, k: usize) -> usize + where + Module: GLWEEncryptPkFamily, + { let size: usize = k.div_ceil(basek); - ((module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx_big(1, size)) | module.bytes_of_scalar_znx(1)) - + module.bytes_of_scalar_znx_dft(1) - + module.vec_znx_big_normalize_tmp_bytes() + ((module.vec_znx_dft_alloc_bytes(1, size) + module.vec_znx_big_alloc_bytes(1, size)) | module.scalar_znx_alloc_bytes(1)) + + module.svp_ppol_alloc_bytes(1) + + module.vec_znx_normalize_tmp_bytes(module.n()) } } -impl + AsMut<[u8]>> GLWECiphertext { - pub fn encrypt_sk, DataSk: AsRef<[u8]>>( +impl GLWECiphertext { + pub fn encrypt_sk( &mut self, - module: &Module, + module: &Module, pt: &GLWEPlaintext, - sk: &FourierGLWESecret, + sk: &GLWESecretExec, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, - scratch: &mut Scratch, - ) { + scratch: &mut Scratch, + ) where + Module: GLWEEncryptSkFamily, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + { self.encrypt_sk_private( module, Some((pt, 0)), @@ -41,15 +82,18 @@ impl + AsMut<[u8]>> GLWECiphertext { ); } - pub fn encrypt_zero_sk>( + pub fn encrypt_zero_sk( &mut self, - module: &Module, - sk: &FourierGLWESecret, + module: &Module, + sk: &GLWESecretExec, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, - scratch: &mut Scratch, - ) { + scratch: &mut Scratch, + ) where + Module: GLWEEncryptSkFamily, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + { self.encrypt_sk_private( module, None::<(&GLWEPlaintext>, usize)>, @@ -61,17 +105,20 @@ impl + AsMut<[u8]>> GLWECiphertext { ); } - pub fn encrypt_pk, DataPk: AsRef<[u8]>>( + pub fn encrypt_pk( &mut self, - module: &Module, + module: &Module, pt: &GLWEPlaintext, - pk: &GLWEPublicKey, + pk: &GLWEPublicKeyExec, source_xu: &mut Source, source_xe: &mut Source, sigma: f64, - scratch: &mut Scratch, - ) { - self.encrypt_pk_private::( + scratch: &mut Scratch, + ) where + Module: GLWEEncryptPkFamily, + Scratch: TakeVecZnxDft + TakeSvpPPol + TakeScalarZnx, + { + self.encrypt_pk_private::( module, Some((pt, 0)), pk, @@ -82,16 +129,19 @@ impl + AsMut<[u8]>> GLWECiphertext { ); } - pub fn encrypt_zero_pk>( + pub fn encrypt_zero_pk( &mut self, - module: &Module, - pk: &GLWEPublicKey, + module: &Module, + pk: &GLWEPublicKeyExec, source_xu: &mut Source, source_xe: &mut Source, sigma: f64, - scratch: &mut Scratch, - ) { - self.encrypt_pk_private::, DataPk>( + scratch: &mut Scratch, + ) where + Module: GLWEEncryptPkFamily, + Scratch: TakeVecZnxDft + TakeSvpPPol + TakeScalarZnx, + { + self.encrypt_pk_private::, DataPk, B>( module, None::<(&GLWEPlaintext>, usize)>, pk, @@ -102,16 +152,19 @@ impl + AsMut<[u8]>> GLWECiphertext { ); } - pub(crate) fn encrypt_sk_private, DataSk: AsRef<[u8]>>( + pub(crate) fn encrypt_sk_private( &mut self, - module: &Module, + module: &Module, pt: Option<(&GLWEPlaintext, usize)>, - sk: &FourierGLWESecret, + sk: &GLWESecretExec, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, - scratch: &mut Scratch, - ) { + scratch: &mut Scratch, + ) where + Module: GLWEEncryptSkFamily, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + { #[cfg(debug_assertions)] { assert_eq!(self.rank(), sk.rank()); @@ -134,28 +187,28 @@ impl + AsMut<[u8]>> GLWECiphertext { let size: usize = self.size(); let cols: usize = self.rank() + 1; - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx(module, 1, size); - c0_big.zero(); + let 
(mut c0, scratch_1) = scratch.take_vec_znx(module, 1, size); + c0.zero(); { // c[i] = uniform // c[0] -= c[i] * s[i], (1..cols).for_each(|i| { - let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size); + let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(module, 1, size); // c[i] = uniform - self.data.fill_uniform(basek, i, size, source_xa); + module.vec_znx_fill_uniform(basek, &mut self.data, i, k, source_xa); // c[i] = norm(IDFT(DFT(c[i]) * DFT(s[i]))) - module.vec_znx_dft(1, 0, &mut ci_dft, 0, &self.data, i); + module.vec_znx_dft_from_vec_znx(1, 0, &mut ci_dft, 0, &self.data, i); module.svp_apply_inplace(&mut ci_dft, 0, &sk.data, i - 1); - let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft); + let ci_big: VecZnxBig<&mut [u8], B> = module.vec_znx_dft_to_vec_znx_big_consume(ci_dft); // use c[0] as buffer, which is overwritten later by the normalization step module.vec_znx_big_normalize(basek, &mut self.data, 0, &ci_big, 0, scratch_2); // c0_tmp = -c[i] * s[i] (use c[0] as buffer) - module.vec_znx_sub_ab_inplace(&mut c0_big, 0, &self.data, 0); + module.vec_znx_sub_ab_inplace(&mut c0, 0, &self.data, 0); // c[i] += m if col = i if let Some((pt, col)) = pt { @@ -168,29 +221,39 @@ impl + AsMut<[u8]>> GLWECiphertext { } // c[0] += e - c0_big.add_normal(basek, 0, k, source_xe, sigma, sigma * SIX_SIGMA); + module.vec_znx_add_normal(basek, &mut c0, 0, k, source_xe, sigma, sigma * SIX_SIGMA); // c[0] += m if col = 0 if let Some((pt, col)) = pt { if col == 0 { - module.vec_znx_add_inplace(&mut c0_big, 0, &pt.data, 0); + module.vec_znx_add_inplace(&mut c0, 0, &pt.data, 0); } } // c[0] = norm(c[0]) - module.vec_znx_normalize(basek, &mut self.data, 0, &c0_big, 0, scratch_1); + module.vec_znx_normalize(basek, &mut self.data, 0, &c0, 0, scratch_1); } - pub(crate) fn encrypt_pk_private, DataPk: AsRef<[u8]>>( + pub(crate) fn encrypt_pk_private( &mut self, - module: &Module, + module: &Module, pt: Option<(&GLWEPlaintext, usize)>, - pk: &GLWEPublicKey, + pk: &GLWEPublicKeyExec, source_xu: &mut Source, source_xe: &mut Source, sigma: f64, - scratch: &mut Scratch, - ) { + scratch: &mut Scratch, + ) where + Module: VecZnxDftAllocBytes + + SvpPPolAllocBytes + + SvpPrepare + + SvpApply + + VecZnxDftToVecZnxBigConsume + + VecZnxBigAddNormal + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize, + Scratch: TakeVecZnxDft + TakeSvpPPol + TakeScalarZnx, + { #[cfg(debug_assertions)] { assert_eq!(self.basek(), pk.basek()); @@ -208,10 +271,10 @@ impl + AsMut<[u8]>> GLWECiphertext { let cols: usize = self.rank() + 1; // Generates u according to the underlying secret distribution. 
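+        // Annotation (sketch of the remainder of this routine, assuming `pk` is a
+        // GLWE encryption of zero under `sk`): a small u is sampled from the
+        // secret distribution of the public key and prepared as an SVP
+        // polynomial; every column is then set to ct[i] = u * pk[i] + e_i, with
+        // the plaintext m folded into the column selected by `col`, before the
+        // big-domain result is renormalized to base 2^basek.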
- let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1); + let (mut u_dft, scratch_1) = scratch.take_svp_ppol(module, 1); { - let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1); + let (mut u, _) = scratch_1.take_scalar_znx(module, 1); match pk.dist { Distribution::NONE => panic!( "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \ @@ -230,15 +293,23 @@ impl + AsMut<[u8]>> GLWECiphertext { // ct[i] = pk[i] * u + ei (+ m if col = i) (0..cols).for_each(|i| { - let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size_pk); + let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(module, 1, size_pk); // ci_dft = DFT(u) * DFT(pk[i]) - module.svp_apply(&mut ci_dft, 0, &u_dft, 0, &pk.data.data, i); + module.svp_apply(&mut ci_dft, 0, &u_dft, 0, &pk.data, i); // ci_big = u * p[i] - let mut ci_big = module.vec_znx_idft_consume(ci_dft); + let mut ci_big = module.vec_znx_dft_to_vec_znx_big_consume(ci_dft); // ci_big = u * pk[i] + e - ci_big.add_normal(basek, 0, pk.k(), source_xe, sigma, sigma * SIX_SIGMA); + module.vec_znx_big_add_normal( + basek, + &mut ci_big, + 0, + pk.k(), + source_xe, + sigma, + sigma * SIX_SIGMA, + ); // ci_big = u * pk[i] + e + m (if col = i) if let Some((pt, col)) = pt { diff --git a/core/src/glwe/external_product.rs b/core/src/glwe/external_product.rs index a95c5ac..52a18bd 100644 --- a/core/src/glwe/external_product.rs +++ b/core/src/glwe/external_product.rs @@ -1,24 +1,40 @@ -use backend::{ - FFT64, MatZnxDftOps, MatZnxDftScratch, Module, Scratch, VecZnxBig, VecZnxBigOps, VecZnxDftAlloc, VecZnxDftOps, VecZnxScratch, +use backend::hal::{ + api::{ + DataViewMut, ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, + VecZnxDftToVecZnxBigConsume, VecZnxNormalizeTmpBytes, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, + }, + layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnxBig}, }; -use crate::{GGSWCiphertext, GLWECiphertext, Infos}; +use crate::{GGSWCiphertextExec, GLWECiphertext, Infos}; + +pub trait GLWEExternalProductFamily = VecZnxDftAllocBytes + + VmpApplyTmpBytes + + VmpApply + + VmpApplyAdd + + VecZnxDftFromVecZnx + + VecZnxDftToVecZnxBigConsume + + VecZnxBigNormalize + + VecZnxNormalizeTmpBytes; impl GLWECiphertext> { - pub fn external_product_scratch_space( - module: &Module, + pub fn external_product_scratch_space( + module: &Module, basek: usize, k_out: usize, k_in: usize, k_ggsw: usize, digits: usize, rank: usize, - ) -> usize { + ) -> usize + where + Module: GLWEExternalProductFamily, + { let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); let out_size: usize = k_out.div_ceil(basek); let ggsw_size: usize = k_ggsw.div_ceil(basek); - let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, ggsw_size); - let a_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size); + let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, ggsw_size); + let a_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, in_size); let vmp: usize = module.vmp_apply_tmp_bytes( out_size, in_size, @@ -27,34 +43,42 @@ impl GLWECiphertext> { rank + 1, // cols out ggsw_size, ); - let normalize: usize = module.vec_znx_normalize_tmp_bytes(); + let normalize: usize = module.vec_znx_normalize_tmp_bytes(module.n()); res_dft + a_dft + (vmp | normalize) } - pub fn external_product_inplace_scratch_space( - module: &Module, + pub fn external_product_inplace_scratch_space( + module: &Module, basek: usize, k_out: usize, k_ggsw: usize, digits: usize, rank: usize, - ) -> 
usize { + ) -> usize + where + Module: GLWEExternalProductFamily, + { Self::external_product_scratch_space(module, basek, k_out, k_out, k_ggsw, digits, rank) } } -impl + AsMut<[u8]>> GLWECiphertext { - pub fn external_product, DataRhs: AsRef<[u8]>>( +impl GLWECiphertext { + pub fn external_product( &mut self, - module: &Module, + module: &Module, lhs: &GLWECiphertext, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { + rhs: &GGSWCiphertextExec, + scratch: &mut Scratch, + ) where + Module: GLWEExternalProductFamily, + Scratch: TakeVecZnxDft + ScratchAvailable, + { let basek: usize = self.basek(); #[cfg(debug_assertions)] { + use backend::hal::api::ScratchAvailable; + assert_eq!(rhs.rank(), lhs.rank()); assert_eq!(rhs.rank(), self.rank()); assert_eq!(self.basek(), basek); @@ -79,12 +103,14 @@ impl + AsMut<[u8]>> GLWECiphertext { let cols: usize = rhs.rank() + 1; let digits: usize = rhs.digits(); - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); // Todo optimise - let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits - 1) / digits); + let (mut res_dft, scratch1) = scratch.take_vec_znx_dft(module, cols, rhs.size()); // Todo optimise + let (mut a_dft, scratch2) = scratch1.take_vec_znx_dft(module, cols, lhs.size().div_ceil(digits)); + + a_dft.data_mut().fill(0); { (0..digits).for_each(|di| { - // (lhs.size() + di) / digits = (a - (digit - di - 1) + digit - 1) / digits + // (lhs.size() + di) / digits = (a - (digit - di - 1)).div_ceil(digits) a_dft.set_size((lhs.size() + di) / digits); // Small optimization for digits > 2 @@ -97,7 +123,7 @@ impl + AsMut<[u8]>> GLWECiphertext { res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); (0..cols).for_each(|col_i| { - module.vec_znx_dft(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); + module.vec_znx_dft_from_vec_znx(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); }); if di == 0 { @@ -108,19 +134,22 @@ impl + AsMut<[u8]>> GLWECiphertext { }); } - let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); + let res_big: VecZnxBig<&mut [u8], B> = module.vec_znx_dft_to_vec_znx_big_consume(res_dft); (0..cols).for_each(|i| { module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch1); }); } - pub fn external_product_inplace>( + pub fn external_product_inplace( &mut self, - module: &Module, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { + module: &Module, + rhs: &GGSWCiphertextExec, + scratch: &mut Scratch, + ) where + Module: GLWEExternalProductFamily, + Scratch: TakeVecZnxDft + ScratchAvailable, + { unsafe { let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; self.external_product(&module, &*self_ptr, rhs, scratch); diff --git a/core/src/glwe/keyswitch.rs b/core/src/glwe/keyswitch.rs index ca6dcad..87e494d 100644 --- a/core/src/glwe/keyswitch.rs +++ b/core/src/glwe/keyswitch.rs @@ -1,13 +1,27 @@ -use backend::{ - FFT64, MatZnxDftOps, MatZnxDftScratch, Module, Scratch, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, - VecZnxDftOps, ZnxZero, +use backend::hal::{ + api::{ + DataViewMut, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, ZnxInfos, + }, + layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, }; -use crate::{FourierGLWECiphertext, GLWECiphertext, 
GLWESwitchingKey, Infos}; +use crate::{GLWECiphertext, GLWESwitchingKeyExec, Infos}; + +pub trait GLWEKeyswitchFamily = VecZnxDftAllocBytes + + VmpApplyTmpBytes + + VecZnxBigNormalizeTmpBytes + + VmpApplyTmpBytes + + VmpApply + + VmpApplyAdd + + VecZnxDftFromVecZnx + + VecZnxDftToVecZnxBigConsume + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize; impl GLWECiphertext> { - pub fn keyswitch_scratch_space( - module: &Module, + pub fn keyswitch_scratch_space( + module: &Module, basek: usize, k_out: usize, k_in: usize, @@ -15,20 +29,23 @@ impl GLWECiphertext> { digits: usize, rank_in: usize, rank_out: usize, - ) -> usize { - let res_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank_out + 1); + ) -> usize + where + Module: GLWEKeyswitchFamily, + { let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); let out_size: usize = k_out.div_ceil(basek); let ksk_size: usize = k_ksk.div_ceil(basek); - let ai_dft: usize = module.bytes_of_vec_znx_dft(rank_in, in_size); + let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank_out + 1, ksk_size); // TODO OPTIMIZE + let ai_dft: usize = module.vec_znx_dft_alloc_bytes(rank_in, in_size); let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank_in, rank_out + 1, ksk_size) - + module.bytes_of_vec_znx_dft(rank_in, in_size); - let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); + + module.vec_znx_dft_alloc_bytes(rank_in, in_size); + let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(module.n()); return res_dft + ((ai_dft + vmp) | normalize); } - pub fn keyswitch_from_fourier_scratch_space( - module: &Module, + pub fn keyswitch_from_fourier_scratch_space( + module: &Module, basek: usize, k_out: usize, k_in: usize, @@ -36,221 +53,220 @@ impl GLWECiphertext> { digits: usize, rank_in: usize, rank_out: usize, - ) -> usize { + ) -> usize + where + Module: GLWEKeyswitchFamily, + { Self::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out) } - pub fn keyswitch_inplace_scratch_space( - module: &Module, + pub fn keyswitch_inplace_scratch_space( + module: &Module, basek: usize, k_out: usize, k_ksk: usize, digits: usize, rank: usize, - ) -> usize { + ) -> usize + where + Module: GLWEKeyswitchFamily, + { Self::keyswitch_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank, rank) } } -impl + AsMut<[u8]>> GLWECiphertext { - pub fn keyswitch, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, +impl GLWECiphertext { + pub(crate) fn assert_keyswitch( + &self, + module: &Module, lhs: &GLWECiphertext, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - Self::keyswitch_private::<_, _, 0>(self, 0, module, lhs, rhs, scratch); + rhs: &GLWESwitchingKeyExec, + scratch: &Scratch, + ) where + DataLhs: DataRef, + DataRhs: DataRef, + Module: GLWEKeyswitchFamily, + Scratch: ScratchAvailable, + { + let basek: usize = self.basek(); + assert_eq!( + lhs.rank(), + rhs.rank_in(), + "lhs.rank(): {} != rhs.rank_in(): {}", + lhs.rank(), + rhs.rank_in() + ); + assert_eq!( + self.rank(), + rhs.rank_out(), + "self.rank(): {} != rhs.rank_out(): {}", + self.rank(), + rhs.rank_out() + ); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + assert!( + scratch.available() + >= GLWECiphertext::keyswitch_scratch_space( + module, + self.basek(), + self.k(), + lhs.k(), + rhs.k(), + rhs.digits(), + rhs.rank_in(), + rhs.rank_out(), + ) + ); + } +} + +impl GLWECiphertext { 
+ pub fn keyswitch( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &GLWESwitchingKeyExec, + scratch: &mut Scratch, + ) where + Module: GLWEKeyswitchFamily, + Scratch: TakeVecZnxDft + ScratchAvailable, + { + #[cfg(debug_assertions)] + { + self.assert_keyswitch(module, lhs, rhs, scratch); + } + let (res_dft, scratch1) = scratch.take_vec_znx_dft(module, self.cols(), rhs.size()); // Todo optimise + let res_big: VecZnxBig<_, B> = keyswitch(module, res_dft, lhs, rhs, scratch1); + (0..self.cols()).for_each(|i| { + module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1); + }) } - pub fn keyswitch_inplace>( + pub fn keyswitch_inplace( &mut self, - module: &Module, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { + module: &Module, + rhs: &GLWESwitchingKeyExec, + scratch: &mut Scratch, + ) where + Module: GLWEKeyswitchFamily, + Scratch: TakeVecZnxDft + ScratchAvailable, + { unsafe { let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; self.keyswitch(&module, &*self_ptr, rhs, scratch); } } - - pub(crate) fn keyswitch_private, DataRhs: AsRef<[u8]>, const OP: u8>( - &mut self, - apply_auto: i64, - module: &Module, - lhs: &GLWECiphertext, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - let basek: usize = self.basek(); - - #[cfg(debug_assertions)] - { - assert_eq!( - lhs.rank(), - rhs.rank_in(), - "lhs.rank(): {} != rhs.rank_in(): {}", - lhs.rank(), - rhs.rank_in() - ); - assert_eq!( - self.rank(), - rhs.rank_out(), - "self.rank(): {} != rhs.rank_out(): {}", - self.rank(), - rhs.rank_out() - ); - assert_eq!(self.basek(), basek); - assert_eq!(lhs.basek(), basek); - assert_eq!(rhs.n(), module.n()); - assert_eq!(self.n(), module.n()); - assert_eq!(lhs.n(), module.n()); - assert!( - scratch.available() - >= GLWECiphertext::keyswitch_scratch_space( - module, - self.basek(), - self.k(), - lhs.k(), - rhs.k(), - rhs.digits(), - rhs.rank_in(), - rhs.rank_out(), - ) - ); - } - - let cols_in: usize = rhs.rank_in(); - let cols_out: usize = rhs.rank_out() + 1; - let digits: usize = rhs.digits(); - - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise - let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, (lhs.size() + digits - 1) / digits); - ai_dft.zero(); - { - (0..digits).for_each(|di| { - ai_dft.set_size((lhs.size() + di) / digits); - - // Small optimization for digits > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. - // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last digits-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. 
- res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); - - (0..cols_in).for_each(|col_i| { - module.vec_znx_dft( - digits, - digits - di - 1, - &mut ai_dft, - col_i, - &lhs.data, - col_i + 1, - ); - }); - - if di == 0 { - module.vmp_apply(&mut res_dft, &ai_dft, &rhs.key.data, scratch2); - } else { - module.vmp_apply_add(&mut res_dft, &ai_dft, &rhs.key.data, di, scratch2); - } - }); - } - - let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); - - module.vec_znx_big_add_small_inplace(&mut res_big, 0, &lhs.data, 0); - - (0..cols_out).for_each(|i| { - if apply_auto != 0 { - module.vec_znx_big_automorphism_inplace(apply_auto, &mut res_big, i); - } - - match OP { - 1 => module.vec_znx_big_add_small_inplace(&mut res_big, i, &lhs.data, i), - 2 => module.vec_znx_big_sub_small_a_inplace(&mut res_big, i, &lhs.data, i), - 3 => module.vec_znx_big_sub_small_b_inplace(&mut res_big, i, &lhs.data, i), - _ => {} - } - module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch1); - }); - } - - pub(crate) fn keyswitch_from_fourier, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &FourierGLWECiphertext, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - let basek: usize = self.basek(); - - #[cfg(debug_assertions)] - { - assert_eq!(lhs.rank(), rhs.rank_in()); - assert_eq!(self.rank(), rhs.rank_out()); - assert_eq!(self.basek(), basek); - assert_eq!(lhs.basek(), basek); - assert_eq!(rhs.n(), module.n()); - assert_eq!(self.n(), module.n()); - assert_eq!(lhs.n(), module.n()); - assert!( - scratch.available() - >= GLWECiphertext::keyswitch_from_fourier_scratch_space( - module, - self.basek(), - self.k(), - lhs.k(), - rhs.k(), - rhs.digits(), - rhs.rank_in(), - rhs.rank_out(), - ) - ); - } - - let cols_in: usize = rhs.rank_in(); - let cols_out: usize = rhs.rank_out() + 1; - - // Buffer of the result of VMP in DFT - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise - - { - let digits = rhs.digits(); - - (0..digits).for_each(|di| { - // (lhs.size() + di) / digits = (a - (digit - di - 1) + digit - 1) / digits - let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, (lhs.size() + di) / digits); - - (0..cols_in).for_each(|col_i| { - module.vec_znx_dft_copy( - digits, - digits - 1 - di, - &mut ai_dft, - col_i, - &lhs.data, - col_i + 1, - ); - }); - - if di == 0 { - module.vmp_apply(&mut res_dft, &ai_dft, &rhs.key.data, scratch2); - } else { - module.vmp_apply_add(&mut res_dft, &ai_dft, &rhs.key.data, di, scratch2); - } - }); - } - - module.vec_znx_dft_add_inplace(&mut res_dft, 0, &lhs.data, 0); - - // Switches result of VMP outside of DFT - let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); - - (0..cols_out).for_each(|i| { - module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch1); - }); - } +} + +pub(crate) fn keyswitch( + module: &Module, + res_dft: VecZnxDft, + lhs: &GLWECiphertext, + rhs: &GLWESwitchingKeyExec, + scratch: &mut Scratch, +) -> VecZnxBig +where + DataRes: DataMut, + DataIn: DataRef, + DataKey: DataRef, + Module: GLWEKeyswitchFamily, + Scratch: TakeVecZnxDft, +{ + if rhs.digits() == 1 { + return keyswitch_vmp_one_digit(module, res_dft, &lhs.data, &rhs.key.data, scratch); + } + + keyswitch_vmp_multiple_digits( + module, + res_dft, + &lhs.data, + &rhs.key.data, + rhs.digits(), + scratch, + ) +} + +fn keyswitch_vmp_one_digit( + module: &Module, + mut res_dft: VecZnxDft, + a: &VecZnx, + 
mat: &VmpPMat,
+    scratch: &mut Scratch,
+) -> VecZnxBig
+where
+    DataRes: DataMut,
+    DataIn: DataRef,
+    DataVmp: DataRef,
+    Module:
+        VecZnxDftAllocBytes + VecZnxDftFromVecZnx + VmpApply + VecZnxDftToVecZnxBigConsume + VecZnxBigAddSmallInplace,
+    Scratch: TakeVecZnxDft,
+{
+    let cols: usize = a.cols();
+    let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(module, cols - 1, a.size());
+    (0..cols - 1).for_each(|col_i| {
+        module.vec_znx_dft_from_vec_znx(1, 0, &mut ai_dft, col_i, a, col_i + 1);
+    });
+    module.vmp_apply(&mut res_dft, &ai_dft, mat, scratch1);
+    let mut res_big: VecZnxBig = module.vec_znx_dft_to_vec_znx_big_consume(res_dft);
+    module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
+    res_big
+}
+
+fn keyswitch_vmp_multiple_digits(
+    module: &Module,
+    mut res_dft: VecZnxDft,
+    a: &VecZnx,
+    mat: &VmpPMat,
+    digits: usize,
+    scratch: &mut Scratch,
+) -> VecZnxBig
+where
+    DataRes: DataMut,
+    DataIn: DataRef,
+    DataVmp: DataRef,
+    Module: VecZnxDftAllocBytes
+        + VecZnxDftFromVecZnx
+        + VmpApply
+        + VmpApplyAdd
+        + VecZnxDftToVecZnxBigConsume
+        + VecZnxBigAddSmallInplace,
+    Scratch: TakeVecZnxDft,
+{
+    let cols: usize = a.cols();
+    let size: usize = a.size();
+    let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(module, cols - 1, size.div_ceil(digits));
+
+    ai_dft.data_mut().fill(0);
+
+    (0..digits).for_each(|di| {
+        ai_dft.set_size((size + di) / digits);
+
+        // Small optimization for digits > 2:
+        // VMP produces some error e, and since we aggregate vmp * 2^{di * B}, we
+        // also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
+        // As such we can safely ignore the last digits-2 limbs of the sum of vmp products.
+        // It is possible to further ignore the last digits-1 limbs, but this introduces
+        // ~0.5 to 1 bit of additional noise, so it is not done here, ensuring that the same
+        // noise is kept with respect to the ideal functionality.
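+        // Illustrative example (values ours, not from the code): with digits = 4,
+        // the truncation below yields
+        //   di = 0 -> mat.size() - 2 limbs,  di = 1 -> mat.size() - 1 limbs,
+        //   di = 2 -> mat.size() limbs,      di = 3 -> mat.size() limbs,
+        // i.e. the products whose error enters at the smallest scale 2^{di * B}
+        // may drop the most low-order limbs without affecting the result.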
+ res_dft.set_size(mat.size() - ((digits - di) as isize - 2).max(0) as usize); + + (0..cols - 1).for_each(|col_i| { + module.vec_znx_dft_from_vec_znx(digits, digits - di - 1, &mut ai_dft, col_i, a, col_i + 1); + }); + + if di == 0 { + module.vmp_apply(&mut res_dft, &ai_dft, mat, scratch1); + } else { + module.vmp_apply_add(&mut res_dft, &ai_dft, mat, di, scratch1); + } + }); + + res_dft.set_size(res_dft.max_size()); + let mut res_big: VecZnxBig = module.vec_znx_dft_to_vec_znx_big_consume(res_dft); + module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0); + res_big } diff --git a/core/src/glwe/layout.rs b/core/src/glwe/layout.rs new file mode 100644 index 0000000..9c35912 --- /dev/null +++ b/core/src/glwe/layout.rs @@ -0,0 +1,123 @@ +use backend::hal::{ + api::{VecZnxAlloc, VecZnxAllocBytes}, + layouts::{Backend, Data, DataMut, DataRef, Module, ReaderFrom, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo}, +}; + +use crate::{GLWEOps, Infos, SetMetaData}; + +#[derive(PartialEq, Eq)] +pub struct GLWECiphertext { + pub data: VecZnx, + pub basek: usize, + pub k: usize, +} + +impl GLWECiphertext> { + pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self + where + Module: VecZnxAlloc, + { + Self { + data: module.vec_znx_alloc(rank + 1, k.div_ceil(basek)), + basek, + k, + } + } + + pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize + where + Module: VecZnxAllocBytes, + { + module.vec_znx_alloc_bytes(rank + 1, k.div_ceil(basek)) + } +} + +impl Infos for GLWECiphertext { + type Inner = VecZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl GLWECiphertext { + pub fn rank(&self) -> usize { + self.cols() - 1 + } +} + +impl GLWECiphertext { + pub fn clone(&self) -> GLWECiphertext> { + GLWECiphertext { + data: self.data.clone(), + basek: self.basek(), + k: self.k(), + } + } +} + +impl SetMetaData for GLWECiphertext { + fn set_k(&mut self, k: usize) { + self.k = k + } + + fn set_basek(&mut self, basek: usize) { + self.basek = basek + } +} + +pub trait GLWECiphertextToRef: Infos { + fn to_ref(&self) -> GLWECiphertext<&[u8]>; +} + +impl GLWECiphertextToRef for GLWECiphertext { + fn to_ref(&self) -> GLWECiphertext<&[u8]> { + GLWECiphertext { + data: self.data.to_ref(), + basek: self.basek, + k: self.k, + } + } +} + +pub trait GLWECiphertextToMut: Infos { + fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]>; +} + +impl GLWECiphertextToMut for GLWECiphertext { + fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]> { + GLWECiphertext { + data: self.data.to_mut(), + basek: self.basek, + k: self.k, + } + } +} + +impl GLWEOps for GLWECiphertext where GLWECiphertext: GLWECiphertextToMut + Infos + SetMetaData {} + +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +impl ReaderFrom for GLWECiphertext { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.k = reader.read_u64::()? as usize; + self.basek = reader.read_u64::()? 
as usize; + self.data.read_from(reader) + } +} + +impl WriterTo for GLWECiphertext { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u64::(self.k as u64)?; + writer.write_u64::(self.basek as u64)?; + self.data.write_to(writer) + } +} diff --git a/core/src/glwe/mod.rs b/core/src/glwe/mod.rs index e3879cd..60f576d 100644 --- a/core/src/glwe/mod.rs +++ b/core/src/glwe/mod.rs @@ -1,23 +1,27 @@ -pub mod automorphism; -pub mod ciphertext; -pub mod decryption; -pub mod encryption; -pub mod external_product; -pub mod keyswitch; -pub mod ops; -pub mod packing; -pub mod plaintext; -pub mod public_key; -pub mod secret; -pub mod trace; +mod automorphism; +mod decryption; +mod encryption; +mod external_product; +mod keyswitch; +mod layout; +mod noise; +mod ops; +mod packing; +mod plaintext; +mod public_key; +mod secret; +mod trace; -pub use ciphertext::GLWECiphertext; -pub(crate) use ciphertext::{GLWECiphertextToMut, GLWECiphertextToRef}; +pub use decryption::*; +pub use encryption::*; +pub use external_product::*; +pub use keyswitch::*; +pub use layout::*; pub use ops::GLWEOps; -pub use packing::GLWEPacker; -pub use plaintext::GLWEPlaintext; -pub use public_key::GLWEPublicKey; -pub use secret::GLWESecret; +pub use packing::*; +pub use plaintext::*; +pub use public_key::*; +pub use secret::*; #[cfg(test)] mod test_fft64; diff --git a/core/src/glwe/noise.rs b/core/src/glwe/noise.rs new file mode 100644 index 0000000..a8bb86f --- /dev/null +++ b/core/src/glwe/noise.rs @@ -0,0 +1,38 @@ +use backend::hal::{ + api::{ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAlloc, VecZnxNormalizeInplace, VecZnxStd, VecZnxSubABInplace}, + layouts::{Backend, DataRef, Module, ScratchOwned}, + oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl}, +}; + +use crate::{GLWECiphertext, GLWEDecryptFamily, GLWEPlaintext, GLWESecretExec, Infos}; + +impl GLWECiphertext { + pub fn assert_noise( + &self, + module: &Module, + sk_exec: &GLWESecretExec, + pt_want: &GLWEPlaintext, + max_noise: f64, + ) where + DataSk: DataRef, + DataPt: DataRef, + Module: GLWEDecryptFamily + VecZnxSubABInplace + VecZnxNormalizeInplace + VecZnxStd + VecZnxAlloc, + B: TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, + { + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(module, self.basek(), self.k()); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space( + module, + self.basek(), + self.k(), + )); + + self.decrypt(module, &mut pt_have, &sk_exec, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); + module.vec_znx_normalize_inplace(self.basek(), &mut pt_have.data, 0, scratch.borrow()); + + let noise_have: f64 = module.vec_znx_std(self.basek(), &pt_have.data, 0).log2(); + assert!(noise_have <= max_noise, "{} {}", noise_have, max_noise); + } +} diff --git a/core/src/glwe/ops.rs b/core/src/glwe/ops.rs index 48034a5..d6ede6d 100644 --- a/core/src/glwe/ops.rs +++ b/core/src/glwe/ops.rs @@ -1,12 +1,20 @@ -use backend::{FFT64, Module, Scratch, VecZnx, VecZnxOps, ZnxZero}; +use backend::hal::{ + api::{ + VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace, + VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, + VecZnxSubABInplace, VecZnxSubBAInplace, ZnxZero, + }, + layouts::{Backend, Module, Scratch, VecZnx}, +}; use crate::{GLWECiphertext, GLWECiphertextToMut, 
GLWECiphertextToRef, Infos, SetMetaData}; pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { - fn add(&mut self, module: &Module, a: &A, b: &B) + fn add(&mut self, module: &Module, a: &A, b: &B) where A: GLWECiphertextToRef, B: GLWECiphertextToRef, + Module: VecZnxAdd + VecZnxCopy, { #[cfg(debug_assertions)] { @@ -50,9 +58,10 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { self.set_k(set_k_binary(self, a, b)); } - fn add_inplace(&mut self, module: &Module, a: &A) + fn add_inplace(&mut self, module: &Module, a: &A) where A: GLWECiphertextToRef + Infos, + Module: VecZnxAddInplace, { #[cfg(debug_assertions)] { @@ -72,10 +81,11 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { self.set_k(set_k_unary(self, a)) } - fn sub(&mut self, module: &Module, a: &A, b: &B) + fn sub(&mut self, module: &Module, a: &A, b: &B) where A: GLWECiphertextToRef, B: GLWECiphertextToRef, + Module: VecZnxSub + VecZnxCopy + VecZnxNegateInplace, { #[cfg(debug_assertions)] { @@ -120,9 +130,10 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { self.set_k(set_k_binary(self, a, b)); } - fn sub_inplace_ab(&mut self, module: &Module, a: &A) + fn sub_inplace_ab(&mut self, module: &Module, a: &A) where A: GLWECiphertextToRef + Infos, + Module: VecZnxSubABInplace, { #[cfg(debug_assertions)] { @@ -142,9 +153,10 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { self.set_k(set_k_unary(self, a)) } - fn sub_inplace_ba(&mut self, module: &Module, a: &A) + fn sub_inplace_ba(&mut self, module: &Module, a: &A) where A: GLWECiphertextToRef + Infos, + Module: VecZnxSubBAInplace, { #[cfg(debug_assertions)] { @@ -164,9 +176,10 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { self.set_k(set_k_unary(self, a)) } - fn rotate(&mut self, module: &Module, k: i64, a: &A) + fn rotate(&mut self, module: &Module, k: i64, a: &A) where A: GLWECiphertextToRef + Infos, + Module: VecZnxRotate, { #[cfg(debug_assertions)] { @@ -186,7 +199,10 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { self.set_k(set_k_unary(self, a)) } - fn rotate_inplace(&mut self, module: &Module, k: i64) { + fn rotate_inplace(&mut self, module: &Module, k: i64) + where + Module: VecZnxRotateInplace, + { #[cfg(debug_assertions)] { assert_eq!(self.n(), module.n()); @@ -199,9 +215,49 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized { }); } - fn copy(&mut self, module: &Module, a: &A) + fn mul_xp_minus_one(&mut self, module: &Module, k: i64, a: &A) where A: GLWECiphertextToRef + Infos, + Module: VecZnxMulXpMinusOne, + { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(self.rank(), a.rank()) + } + + let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); + let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref(); + + (0..a.rank() + 1).for_each(|i| { + module.vec_znx_mul_xp_minus_one(k, &mut self_mut.data, i, &a_ref.data, i); + }); + + self.set_basek(a.basek()); + self.set_k(set_k_unary(self, a)) + } + + fn mul_xp_minus_one_inplace(&mut self, module: &Module, k: i64) + where + Module: VecZnxMulXpMinusOneInplace, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.n(), module.n()); + } + + let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); + + (0..self_mut.rank() + 1).for_each(|i| { + module.vec_znx_mul_xp_minus_one_inplace(k, &mut self_mut.data, i); + }); + } + + fn copy(&mut self, module: &Module, a: &A) + where + A: GLWECiphertextToRef + Infos, + Module: VecZnxCopy, { 
        #[cfg(debug_assertions)]
        {
@@ -221,15 +277,18 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized {
         self.set_basek(a.basek());
     }
 
-    fn rsh(&mut self, k: usize, scratch: &mut Scratch) {
+    fn rsh(&mut self, module: &Module, k: usize)
+    where
+        Module: VecZnxRshInplace,
+    {
         let basek: usize = self.basek();
-        let mut self_mut: GLWECiphertext<&mut [u8]> = self.to_mut();
-        self_mut.data.rsh(basek, k, scratch);
+        module.vec_znx_rsh_inplace(basek, k, &mut self.to_mut().data);
     }
 
-    fn normalize(&mut self, module: &Module, a: &A, scratch: &mut Scratch)
+    fn normalize(&mut self, module: &Module, a: &A, scratch: &mut Scratch)
     where
         A: GLWECiphertextToRef,
+        Module: VecZnxNormalize,
     {
         #[cfg(debug_assertions)]
         {
@@ -248,7 +307,10 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized {
         self.set_k(a.k().min(self.k()));
     }
 
-    fn normalize_inplace(&mut self, module: &Module, scratch: &mut Scratch) {
+    fn normalize_inplace(&mut self, module: &Module, scratch: &mut Scratch)
+    where
+        Module: VecZnxNormalizeInplace,
+    {
         #[cfg(debug_assertions)]
         {
             assert_eq!(self.n(), module.n());
@@ -261,7 +323,7 @@ pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized {
 }
 
 impl GLWECiphertext<Vec<u8>> {
-    pub fn rsh_scratch_space(module: &Module) -> usize {
+    pub fn rsh_scratch_space(module: &Module) -> usize {
         VecZnx::rsh_scratch_space(module.n())
     }
 }
diff --git a/core/src/glwe/packing.rs b/core/src/glwe/packing.rs
index f38504b..e4ba401 100644
--- a/core/src/glwe/packing.rs
+++ b/core/src/glwe/packing.rs
@@ -1,9 +1,31 @@
-use crate::{GLWEAutomorphismKey, GLWECiphertext, GLWEOps, Infos, ScratchCore};
 use std::collections::HashMap;
 
-use backend::{FFT64, Module, Scratch};
+use backend::hal::{
+    api::{
+        ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAlloc, VecZnxAllocBytes, VecZnxAutomorphismInplace,
+        VecZnxBigAutomorphismInplace, VecZnxBigSubSmallBInplace, VecZnxCopy, VecZnxNegateInplace, VecZnxNormalizeInplace,
+        VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace,
+    },
+    layouts::{Backend, DataMut, DataRef, Module, Scratch},
+};
 
-/// [StreamPacker] enables only the fly GLWE packing
+use crate::{AutomorphismKeyExec, GLWECiphertext, GLWEKeyswitchFamily, GLWEOps, Infos, TakeGLWECt};
+
+pub trait GLWEPackingFamily = GLWEKeyswitchFamily
+    + VecZnxCopy
+    + VecZnxRotateInplace
+    + VecZnxSub
+    + VecZnxNegateInplace
+    + VecZnxRshInplace
+    + VecZnxAddInplace
+    + VecZnxNormalizeInplace
+    + VecZnxSubABInplace
+    + VecZnxRotate
+    + VecZnxAutomorphismInplace
+    + VecZnxBigSubSmallBInplace
+    + VecZnxBigAutomorphismInplace;
+
+/// [GLWEPacker] enables on-the-fly GLWE packing
 /// with constant memory of Log(N) ciphertexts.
 /// Main difference with usual GLWE packing is that
 /// the output is bit-reversed.
@@ -14,7 +36,7 @@ pub struct GLWEPacker {
 }
 
 /// [Accumulator] stores intermediate packing result.
-/// There are Log(N) such accumulators in a [StreamPacker].
+/// There are Log(N) such accumulators in a [GLWEPacker].
 struct Accumulator {
     data: GLWECiphertext<Vec<u8>>,
     value: bool, // Implicit flag for zero ciphertext
@@ -30,7 +52,10 @@ impl Accumulator {
     /// * `basek`: base 2 logarithm of the GLWE ciphertext in memory digit representation.
     /// * `k`: base 2 precision of the GLWE ciphertext precision over the Torus.
     /// * `rank`: rank of the GLWE ciphertext.
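+    /// A minimal allocation sketch (parameter values illustrative; `module` is
+    /// assumed to be a module of the right degree):
+    ///
+    /// ```ignore
+    /// let acc = Accumulator::alloc(&module, 18, 54, 1);
+    /// assert!(!acc.value); // starts as an implicit zero ciphertext
+    /// ```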
- pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { + pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self + where + Module: VecZnxAlloc, + { Self { data: GLWECiphertext::alloc(module, basek, k, rank), value: false, @@ -40,7 +65,7 @@ impl Accumulator { } impl GLWEPacker { - /// Instantiates a new [StreamPacker]. + /// Instantiates a new [GLWEPacker]. /// /// #Arguments /// @@ -53,7 +78,10 @@ impl GLWEPacker { /// * `basek`: base 2 logarithm of the GLWE ciphertext in memory digit representation. /// * `k`: base 2 precision of the GLWE ciphertext precision over the Torus. /// * `rank`: rank of the GLWE ciphertext. - pub fn new(module: &Module, log_batch: usize, basek: usize, k: usize, rank: usize) -> Self { + pub fn new(module: &Module, log_batch: usize, basek: usize, k: usize, rank: usize) -> Self + where + Module: VecZnxAlloc, + { let mut accumulators: Vec = Vec::::new(); let log_n: usize = module.log_n(); (0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(module, basek, k, rank))); @@ -74,30 +102,43 @@ impl GLWEPacker { } /// Number of scratch space bytes required to call [Self::add]. - pub fn scratch_space(module: &Module, basek: usize, ct_k: usize, k_ksk: usize, digits: usize, rank: usize) -> usize { + pub fn scratch_space( + module: &Module, + basek: usize, + ct_k: usize, + k_ksk: usize, + digits: usize, + rank: usize, + ) -> usize + where + Module: GLWEKeyswitchFamily + VecZnxAllocBytes, + { pack_core_scratch_space(module, basek, ct_k, k_ksk, digits, rank) } - pub fn galois_elements(module: &Module) -> Vec { + pub fn galois_elements(module: &Module) -> Vec { GLWECiphertext::trace_galois_elements(module) } - /// Adds a GLWE ciphertext to the [StreamPacker]. + /// Adds a GLWE ciphertext to the [GLWEPacker]. /// #Arguments /// /// * `module`: static backend FFT tables. /// * `res`: space to append fully packed ciphertext. Only when the number /// of packed ciphertexts reaches N/2^log_batch is a result written. /// * `a`: ciphertext to pack. Can optionally give None to pack a 0 ciphertext. - /// * `auto_keys`: a [HashMap] containing the [AutomorphismKey]s. - /// * `scratch`: scratch space of size at least [Self::add_scratch_space]. - pub fn add, DataAK: AsRef<[u8]>>( + /// * `auto_keys`: a [HashMap] containing the [AutomorphismKeyExec]s. + /// * `scratch`: scratch space of size at least [Self::scratch_space]. + pub fn add( &mut self, - module: &Module, + module: &Module, a: Option<&GLWECiphertext>, - auto_keys: &HashMap>, - scratch: &mut Scratch, - ) { + auto_keys: &HashMap>, + scratch: &mut Scratch, + ) where + Module: GLWEPackingFamily, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + { assert!( self.counter < module.n(), "Packing limit of {} reached", @@ -116,7 +157,10 @@ impl GLWEPacker { } /// Flush result to`res`. 
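+    /// A usage sketch (names illustrative; `ct` is some GLWE ciphertext and
+    /// `log_batch = 0`): call [Self::add] `module.n()` times before flushing,
+    /// as `flush` asserts that the packer is full:
+    ///
+    /// ```ignore
+    /// (0..module.n()).for_each(|_| packer.add(&module, Some(&ct), &auto_keys, scratch.borrow()));
+    /// packer.flush(&module, &mut res);
+    /// ```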
- pub fn flush + AsRef<[u8]>>(&mut self, module: &Module, res: &mut GLWECiphertext) { + pub fn flush(&mut self, module: &Module, res: &mut GLWECiphertext) + where + Module: VecZnxCopy, + { assert!(self.counter == module.n()); // Copy result GLWE into res GLWE res.copy( @@ -128,18 +172,31 @@ impl GLWEPacker { } } -fn pack_core_scratch_space(module: &Module, basek: usize, ct_k: usize, k_ksk: usize, digits: usize, rank: usize) -> usize { +fn pack_core_scratch_space( + module: &Module, + basek: usize, + ct_k: usize, + k_ksk: usize, + digits: usize, + rank: usize, +) -> usize +where + Module: GLWEKeyswitchFamily + VecZnxAllocBytes, +{ combine_scratch_space(module, basek, ct_k, k_ksk, digits, rank) } -fn pack_core, DataAK: AsRef<[u8]>>( - module: &Module, +fn pack_core( + module: &Module, a: Option<&GLWECiphertext>, accumulators: &mut [Accumulator], i: usize, - auto_keys: &HashMap>, - scratch: &mut Scratch, -) { + auto_keys: &HashMap>, + scratch: &mut Scratch, +) where + Module: GLWEPackingFamily, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, +{ let log_n: usize = module.log_n(); if i == log_n { @@ -189,21 +246,34 @@ fn pack_core, DataAK: AsRef<[u8]>>( } } -fn combine_scratch_space(module: &Module, basek: usize, ct_k: usize, k_ksk: usize, digits: usize, rank: usize) -> usize { +fn combine_scratch_space( + module: &Module, + basek: usize, + ct_k: usize, + k_ksk: usize, + digits: usize, + rank: usize, +) -> usize +where + Module: GLWEKeyswitchFamily + VecZnxAllocBytes, +{ GLWECiphertext::bytes_of(module, basek, ct_k, rank) + (GLWECiphertext::rsh_scratch_space(module) | GLWECiphertext::automorphism_scratch_space(module, basek, ct_k, ct_k, k_ksk, digits, rank)) } /// [combine] merges two ciphertexts together. -fn combine, DataAK: AsRef<[u8]>>( - module: &Module, +fn combine( + module: &Module, acc: &mut Accumulator, b: Option<&GLWECiphertext>, i: usize, - auto_keys: &HashMap>, - scratch: &mut Scratch, -) { + auto_keys: &HashMap>, + scratch: &mut Scratch, +) where + Module: GLWEPackingFamily, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, +{ let log_n: usize = module.log_n(); let a: &mut GLWECiphertext> = &mut acc.data; let basek: usize = a.basek(); @@ -232,18 +302,18 @@ fn combine, DataAK: AsRef<[u8]>>( // since 2*(I(X) * Q/2) = I(X) * Q = 0 mod Q. 
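+    // Sketch of the step below (notation ours): with t = 2^(log_n - i - 1) and
+    // phi the automorphism selected by gal_el, the accumulator is updated as
+    //
+    //     acc' = (acc + b * X^t)/2 + phi((acc - b * X^t)/2),
+    //
+    // matching the `a + b * X^t + phi(a - b * X^t)` comment further down: the
+    // halving keeps the scale, the terms fixed by phi survive from acc, and
+    // the terms negated by phi survive from b * X^t.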
if acc.value { if let Some(b) = b { - let (mut tmp_b, scratch_1) = scratch.tmp_glwe_ct(module, basek, k, rank); + let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(module, basek, k, rank); // a = a * X^-t a.rotate_inplace(module, -t); // tmp_b = a * X^-t - b tmp_b.sub(module, a, b); - tmp_b.rsh(1, scratch_1); + tmp_b.rsh(module, 1); // a = a * X^-t + b a.add_inplace(module, b); - a.rsh(1, scratch_1); + a.rsh(module, 1); tmp_b.normalize_inplace(module, scratch_1); @@ -263,7 +333,7 @@ fn combine, DataAK: AsRef<[u8]>>( // = a + b * X^t + phi(a - b * X^t) a.rotate_inplace(module, t); } else { - a.rsh(1, scratch); + a.rsh(module, 1); // a = a + phi(a) if let Some(key) = auto_keys.get(&gal_el) { a.automorphism_add_inplace(module, key, scratch); @@ -273,13 +343,13 @@ fn combine, DataAK: AsRef<[u8]>>( } } else { if let Some(b) = b { - let (mut tmp_b, scratch_1) = scratch.tmp_glwe_ct(module, basek, k, rank); + let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(module, basek, k, rank); tmp_b.rotate(module, 1 << (log_n - i - 1), b); - tmp_b.rsh(1, scratch_1); + tmp_b.rsh(module, 1); // a = (b* X^t - phi(b* X^t)) if let Some(key) = auto_keys.get(&gal_el) { - a.automorphism_sub_ba::<&mut [u8], _>(module, &tmp_b, key, scratch_1); + a.automorphism_sub_ba(module, &tmp_b, key, scratch_1); } else { panic!("auto_key[{}] not found", gal_el); } diff --git a/core/src/glwe/plaintext.rs b/core/src/glwe/plaintext.rs index 5bebc68..114c488 100644 --- a/core/src/glwe/plaintext.rs +++ b/core/src/glwe/plaintext.rs @@ -1,15 +1,18 @@ -use backend::{Backend, FFT64, Module, VecZnx, VecZnxAlloc, VecZnxToMut, VecZnxToRef}; +use backend::hal::{ + api::{VecZnxAlloc, VecZnxAllocBytes}, + layouts::{Backend, Data, DataMut, DataRef, Module, VecZnx, VecZnxToMut, VecZnxToRef}, +}; use crate::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEOps, Infos, SetMetaData}; -pub struct GLWEPlaintext { - pub data: VecZnx, +pub struct GLWEPlaintext { + pub data: VecZnx, pub basek: usize, pub k: usize, } -impl Infos for GLWEPlaintext { - type Inner = VecZnx; +impl Infos for GLWEPlaintext { + type Inner = VecZnx; fn inner(&self) -> &Self::Inner { &self.data @@ -24,7 +27,7 @@ impl Infos for GLWEPlaintext { } } -impl + AsRef<[u8]>> SetMetaData for GLWEPlaintext { +impl SetMetaData for GLWEPlaintext { fn set_k(&mut self, k: usize) { self.k = k } @@ -35,20 +38,26 @@ impl + AsRef<[u8]>> SetMetaData for GLWEPlaintext> { - pub fn alloc(module: &Module, basek: usize, k: usize) -> Self { + pub fn alloc(module: &Module, basek: usize, k: usize) -> Self + where + Module: VecZnxAlloc, + { Self { - data: module.new_vec_znx(1, k.div_ceil(basek)), + data: module.vec_znx_alloc(1, k.div_ceil(basek)), basek: basek, k, } } - pub fn byte_of(module: &Module, basek: usize, k: usize) -> usize { - module.bytes_of_vec_znx(1, k.div_ceil(basek)) + pub fn byte_of(module: &Module, basek: usize, k: usize) -> usize + where + Module: VecZnxAllocBytes, + { + module.vec_znx_alloc_bytes(1, k.div_ceil(basek)) } } -impl> GLWECiphertextToRef for GLWEPlaintext { +impl GLWECiphertextToRef for GLWEPlaintext { fn to_ref(&self) -> GLWECiphertext<&[u8]> { GLWECiphertext { data: self.data.to_ref(), @@ -58,7 +67,7 @@ impl> GLWECiphertextToRef for GLWEPlaintext { } } -impl + AsRef<[u8]>> GLWECiphertextToMut for GLWEPlaintext { +impl GLWECiphertextToMut for GLWEPlaintext { fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]> { GLWECiphertext { data: self.data.to_mut(), @@ -70,7 +79,7 @@ impl + AsRef<[u8]>> GLWECiphertextToMut for GLWEPlaintext { impl GLWEOps for GLWEPlaintext where 
- D: AsRef<[u8]> + AsMut<[u8]>, + D: DataMut, GLWEPlaintext: GLWECiphertextToMut + Infos + SetMetaData, { } diff --git a/core/src/glwe/public_key.rs b/core/src/glwe/public_key.rs index f4871ad..cd9bb34 100644 --- a/core/src/glwe/public_key.rs +++ b/core/src/glwe/public_key.rs @@ -1,57 +1,84 @@ -use backend::{Backend, FFT64, Module, ScratchOwned, VecZnxDft}; +use backend::hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAlloc, VecZnxAllocBytes, VecZnxDftAlloc, VecZnxDftAllocBytes, + VecZnxDftFromVecZnx, + }, + layouts::{Backend, Data, DataMut, DataRef, Module, ReaderFrom, Scratch, ScratchOwned, VecZnx, VecZnxDft, WriterTo}, + oep::{ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxDftImpl, TakeVecZnxImpl}, +}; use sampling::source::Source; -use crate::{FourierGLWECiphertext, FourierGLWESecret, Infos, dist::Distribution}; +use crate::{GLWECiphertext, GLWEEncryptSkFamily, GLWESecretExec, Infos, dist::Distribution}; -pub struct GLWEPublicKey { - pub(crate) data: FourierGLWECiphertext, +pub trait GLWEPublicKeyFamily = GLWEEncryptSkFamily; + +#[derive(PartialEq, Eq)] +pub struct GLWEPublicKey { + pub(crate) data: VecZnx, + pub(crate) basek: usize, + pub(crate) k: usize, pub(crate) dist: Distribution, } -impl GLWEPublicKey, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { +impl GLWEPublicKey> { + pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self + where + Module: VecZnxAlloc, + { Self { - data: FourierGLWECiphertext::alloc(module, basek, k, rank), + data: module.vec_znx_alloc(rank + 1, k.div_ceil(basek)), + basek: basek, + k: k, dist: Distribution::NONE, } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - FourierGLWECiphertext::, B>::bytes_of(module, basek, k, rank) + pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize + where + Module: VecZnxAllocBytes, + { + module.vec_znx_alloc_bytes(rank + 1, k.div_ceil(basek)) } } -impl Infos for GLWEPublicKey { - type Inner = VecZnxDft; +impl Infos for GLWEPublicKey { + type Inner = VecZnx; fn inner(&self) -> &Self::Inner { - &self.data.data + &self.data } fn basek(&self) -> usize { - self.data.basek + self.basek } fn k(&self) -> usize { - self.data.k + self.k } } -impl GLWEPublicKey { +impl GLWEPublicKey { pub fn rank(&self) -> usize { self.cols() - 1 } } -impl + AsMut<[u8]>> GLWEPublicKey { - pub fn generate_from_sk>( +impl GLWEPublicKey { + pub fn generate_from_sk( &mut self, - module: &Module, - sk: &FourierGLWESecret, + module: &Module, + sk: &GLWESecretExec, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, - ) { + ) where + Module: GLWEPublicKeyFamily + VecZnxAlloc, + B: ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + TakeVecZnxDftImpl + + ScratchAvailableImpl + + TakeVecZnxImpl, + { #[cfg(debug_assertions)] { match sk.dist { @@ -61,15 +88,123 @@ impl + AsMut<[u8]>> GLWEPublicKey { } // Its ok to allocate scratch space here since pk is usually generated only once. 
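+        // Conceptually (a sketch in our notation, not the crate's): the public
+        // key generated here is a fresh encryption of zero under sk, i.e. for
+        // rank 1, pk = (-a * s + e, a) with a uniform, so that anyone can later
+        // rerandomize it as u * pk + noise to encrypt without knowing sk.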
-        let mut scratch: ScratchOwned = ScratchOwned::new(FourierGLWECiphertext::encrypt_sk_scratch_space(
+        let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::encrypt_sk_scratch_space(
             module,
             self.basek(),
             self.k(),
-            self.rank(),
         ));
 
-        self.data
-            .encrypt_zero_sk(module, sk, source_xa, source_xe, sigma, scratch.borrow());
+        let mut tmp: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(module, self.basek(), self.k(), self.rank());
+        tmp.encrypt_zero_sk(module, sk, source_xa, source_xe, sigma, scratch.borrow());
+
+        // The fresh encryption of zero must be copied back into the public
+        // key's storage, otherwise it is dropped with `tmp`. The copy below is
+        // a sketch assuming the `vec_znx_copy` helper used elsewhere in this PR
+        // (and a corresponding `VecZnxCopy` bound on the module).
+        (0..self.rank() + 1).for_each(|i| {
+            module.vec_znx_copy(&mut self.data, i, &tmp.data, i);
+        });
 
         self.dist = sk.dist;
     }
 }
+
+use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
+
+impl ReaderFrom for GLWEPublicKey {
+    fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> {
+        self.k = reader.read_u64::<LittleEndian>()? as usize;
+        self.basek = reader.read_u64::<LittleEndian>()? as usize;
+        match Distribution::read_from(reader) {
+            Ok(dist) => self.dist = dist,
+            Err(e) => return Err(e),
+        }
+        self.data.read_from(reader)
+    }
+}
+
+impl WriterTo for GLWEPublicKey {
+    fn write_to(&self, writer: &mut W) -> std::io::Result<()> {
+        writer.write_u64::<LittleEndian>(self.k as u64)?;
+        writer.write_u64::<LittleEndian>(self.basek as u64)?;
+        match self.dist.write_to(writer) {
+            Ok(()) => {}
+            Err(e) => return Err(e),
+        }
+        self.data.write_to(writer)
+    }
+}
+
+#[derive(PartialEq, Eq)]
+pub struct GLWEPublicKeyExec {
+    pub(crate) data: VecZnxDft,
+    pub(crate) basek: usize,
+    pub(crate) k: usize,
+    pub(crate) dist: Distribution,
+}
+
+impl Infos for GLWEPublicKeyExec {
+    type Inner = VecZnxDft;
+
+    fn inner(&self) -> &Self::Inner {
+        &self.data
+    }
+
+    fn basek(&self) -> usize {
+        self.basek
+    }
+
+    fn k(&self) -> usize {
+        self.k
+    }
+}
+
+impl GLWEPublicKeyExec {
+    pub fn rank(&self) -> usize {
+        self.cols() - 1
+    }
+}
+
+impl GLWEPublicKeyExec<Vec<u8>, B> {
+    pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self
+    where
+        Module: VecZnxDftAlloc,
+    {
+        Self {
+            data: module.vec_znx_dft_alloc(rank + 1, k.div_ceil(basek)),
+            basek: basek,
+            k: k,
+            dist: Distribution::NONE,
+        }
+    }
+
+    pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize
+    where
+        Module: VecZnxDftAllocBytes,
+    {
+        module.vec_znx_dft_alloc_bytes(rank + 1, k.div_ceil(basek))
+    }
+
+    pub fn from(module: &Module, other: &GLWEPublicKey, scratch: &mut Scratch) -> Self
+    where
+        DataOther: DataRef,
+        Module: VecZnxDftAlloc + VecZnxDftFromVecZnx,
+    {
+        let mut pk_exec: GLWEPublicKeyExec<Vec<u8>, B> = GLWEPublicKeyExec::alloc(module, other.basek(), other.k(), other.rank());
+        pk_exec.prepare(module, other, scratch);
+        pk_exec
+    }
+}
+
+impl GLWEPublicKeyExec {
+    pub fn prepare(&mut self, module: &Module, other: &GLWEPublicKey, _scratch: &mut Scratch)
+    where
+        DataOther: DataRef,
+        Module: VecZnxDftFromVecZnx,
+    {
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(self.n(), module.n());
+            assert_eq!(other.n(), module.n());
+            assert_eq!(self.size(), other.size());
+        }
+
+        (0..self.cols()).for_each(|i| {
+            module.vec_znx_dft_from_vec_znx(1, 0, &mut self.data, i, &other.data, i);
+        });
+        self.k = other.k;
+        self.basek = other.basek;
+        self.dist = other.dist;
+    }
+}
diff --git a/core/src/glwe/secret.rs b/core/src/glwe/secret.rs
index 5073d2b..8d0bc3d 100644
--- a/core/src/glwe/secret.rs
+++ b/core/src/glwe/secret.rs
@@ -1,27 +1,39 @@
-use backend::{Backend, Module, ScalarZnx, ScalarZnxAlloc, ZnxInfos, ZnxZero};
+use backend::hal::{
+    api::{ScalarZnxAlloc, ScalarZnxAllocBytes, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, ZnxInfos, ZnxZero},
+    layouts::{Backend, Data, DataMut, DataRef, Module, ReaderFrom, ScalarZnx, SvpPPol, WriterTo},
+};
 use
sampling::source::Source; use crate::dist::Distribution; -pub struct GLWESecret { - pub(crate) data: ScalarZnx, +pub trait GLWESecretFamily = SvpPrepare + SvpPPolAllocBytes + SvpPPolAlloc; + +#[derive(PartialEq, Eq)] +pub struct GLWESecret { + pub(crate) data: ScalarZnx, pub(crate) dist: Distribution, } impl GLWESecret> { - pub fn alloc(module: &Module, rank: usize) -> Self { + pub fn alloc(module: &Module, rank: usize) -> Self + where + Module: ScalarZnxAlloc, + { Self { - data: module.new_scalar_znx(rank), + data: module.scalar_znx_alloc(rank), dist: Distribution::NONE, } } - pub fn bytes_of(module: &Module, rank: usize) -> usize { - module.bytes_of_scalar_znx(rank) + pub fn bytes_of(module: &Module, rank: usize) -> usize + where + Module: ScalarZnxAllocBytes, + { + module.scalar_znx_alloc_bytes(rank) } } -impl GLWESecret { +impl GLWESecret { pub fn n(&self) -> usize { self.data.n() } @@ -35,7 +47,7 @@ impl GLWESecret { } } -impl + AsRef<[u8]>> GLWESecret { +impl GLWESecret { pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { (0..self.rank()).for_each(|i| { self.data.fill_ternary_prob(i, prob, source); @@ -75,10 +87,87 @@ impl + AsRef<[u8]>> GLWESecret { self.data.zero(); self.dist = Distribution::ZERO; } - - // pub(crate) fn prep_fourier(&mut self, module: &Module) { - // (0..self.rank()).for_each(|i| { - // module.svp_prepare(&mut self.data_fourier, i, &self.data, i); - // }); - // } +} + +impl ReaderFrom for GLWESecret { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + match Distribution::read_from(reader) { + Ok(dist) => self.dist = dist, + Err(e) => return Err(e), + } + self.data.read_from(reader) + } +} + +impl WriterTo for GLWESecret { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + match self.dist.write_to(writer) { + Ok(()) => {} + Err(e) => return Err(e), + } + self.data.write_to(writer) + } +} + +pub struct GLWESecretExec { + pub(crate) data: SvpPPol, + pub(crate) dist: Distribution, +} + +impl GLWESecretExec, B> { + pub fn alloc(module: &Module, rank: usize) -> Self + where + Module: GLWESecretFamily, + { + Self { + data: module.svp_ppol_alloc(rank), + dist: Distribution::NONE, + } + } + + pub fn bytes_of(module: &Module, rank: usize) -> usize + where + Module: GLWESecretFamily, + { + module.svp_ppol_alloc_bytes(rank) + } +} + +impl GLWESecretExec, B> { + pub fn from(module: &Module, sk: &GLWESecret) -> Self + where + D: DataRef, + Module: GLWESecretFamily, + { + let mut sk_dft: GLWESecretExec, B> = Self::alloc(module, sk.rank()); + sk_dft.prepare(module, sk); + sk_dft + } +} + +impl GLWESecretExec { + pub fn n(&self) -> usize { + self.data.n() + } + + pub fn log_n(&self) -> usize { + self.data.log_n() + } + + pub fn rank(&self) -> usize { + self.data.cols() + } +} + +impl GLWESecretExec { + pub(crate) fn prepare(&mut self, module: &Module, sk: &GLWESecret) + where + O: DataRef, + Module: GLWESecretFamily, + { + (0..self.rank()).for_each(|i| { + module.svp_prepare(&mut self.data, i, &sk.data, i); + }); + self.dist = sk.dist + } } diff --git a/core/src/glwe/test_fft64/automorphism.rs b/core/src/glwe/test_fft64/automorphism.rs index 0b917ef..ef608d4 100644 --- a/core/src/glwe/test_fft64/automorphism.rs +++ b/core/src/glwe/test_fft64/automorphism.rs @@ -1,14 +1,30 @@ -use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps}; - +use backend::{ + hal::{ + api::{ + MatZnxAlloc, ModuleNew, ScalarZnxAlloc, ScalarZnxAllocBytes, ScalarZnxAutomorphism, ScratchOwnedAlloc, + ScratchOwnedBorrow, VecZnxAddScalarInplace, 
VecZnxAlloc, VecZnxAllocBytes, VecZnxAutomorphismInplace, + VecZnxFillUniform, VecZnxStd, VecZnxSwithcDegree, + }, + layouts::{Backend, Module, ScratchOwned}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, + }, + }, + implementation::cpu_spqlios::FFT64, +}; use sampling::source::Source; use crate::{ - FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, noise::log2_std_noise_gglwe_product, + AutomorphismKey, AutomorphismKeyEncryptSkFamily, AutomorphismKeyExec, GGLWEExecLayoutFamily, GLWECiphertext, + GLWEDecryptFamily, GLWEKeyswitchFamily, GLWEPlaintext, GLWESecret, GLWESecretExec, Infos, + noise::log2_std_noise_gglwe_product, }; #[test] fn apply_inplace() { let log_n: usize = 8; + let module: Module = Module::::new(1 << log_n); let basek: usize = 12; let k_ct: usize = 60; let digits: usize = k_ct.div_ceil(basek); @@ -16,7 +32,7 @@ fn apply_inplace() { (1..digits + 1).for_each(|di| { let k_ksk: usize = k_ct + basek * di; println!("test automorphism_inplace digits: {} rank: {}", di, rank); - test_automorphism_inplace(log_n, basek, -5, k_ct, k_ksk, di, rank, 3.2); + test_automorphism_inplace(&module, basek, -5, k_ct, k_ksk, di, rank, 3.2); }); }); } @@ -24,6 +40,7 @@ fn apply_inplace() { #[test] fn apply() { let log_n: usize = 8; + let module: Module = Module::::new(1 << log_n); let basek: usize = 12; let k_in: usize = 60; let digits: usize = k_in.div_ceil(basek); @@ -32,13 +49,36 @@ fn apply() { let k_ksk: usize = k_in + basek * di; let k_out: usize = k_ksk; // Better capture noise. println!("test automorphism digits: {} rank: {}", di, rank); - test_automorphism(log_n, basek, -5, k_out, k_in, k_ksk, di, rank, 3.2); + test_automorphism(&module, basek, -5, k_out, k_in, k_ksk, di, rank, 3.2); }) }); } -fn test_automorphism( - log_n: usize, +pub(crate) trait AutomorphismTestModuleFamily = AutomorphismKeyEncryptSkFamily + + GLWEDecryptFamily + + GGLWEExecLayoutFamily + + GLWEKeyswitchFamily + + MatZnxAlloc + + VecZnxAlloc + + ScalarZnxAllocBytes + + VecZnxAllocBytes + + ScalarZnxAutomorphism + + VecZnxSwithcDegree + + ScalarZnxAlloc + + VecZnxAddScalarInplace + + VecZnxAutomorphismInplace + + VecZnxStd; +pub(crate) trait AutomorphismTestScratchFamily = TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl; + +fn test_automorphism( + module: &Module, basek: usize, p: i64, k_out: usize, @@ -47,31 +87,29 @@ fn test_automorphism( digits: usize, rank: usize, sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - +) where + Module: AutomorphismTestModuleFamily, + B: AutomorphismTestScratchFamily, +{ let rows: usize = k_in.div_ceil(basek * digits); - let mut autokey: GLWEAutomorphismKey, FFT64> = GLWEAutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); + let mut autokey: AutomorphismKey> = AutomorphismKey::alloc(module, basek, k_ksk, rows, digits, rank); + let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_in, rank); + let mut ct_out: 
GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_out, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_in); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_in, &mut source_xa); - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, autokey.k(), rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) + let mut scratch: ScratchOwned = ScratchOwned::alloc( + AutomorphismKey::encrypt_sk_scratch_space(module, basek, autokey.k(), rank) + | GLWECiphertext::decrypt_scratch_space(module, basek, ct_out.k()) + | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_in.k()) | GLWECiphertext::automorphism_scratch_space( - &module, + module, basek, ct_out.k(), ct_in.k(), @@ -81,12 +119,12 @@ fn test_automorphism( ), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); + let sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); autokey.encrypt_sk( - &module, + module, p, &sk, &mut source_xa, @@ -96,26 +134,21 @@ fn test_automorphism( ); ct_in.encrypt_sk( - &module, + module, &pt_want, - &sk_dft, + &sk_exec, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - ct_out.automorphism(&module, &ct_in, &autokey, scratch.borrow()); - ct_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0); - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - module.vec_znx_normalize_inplace(basek, &mut pt_have.data, 0, scratch.borrow()); + let mut autokey_exec: AutomorphismKeyExec, B> = AutomorphismKeyExec::alloc(module, basek, k_ksk, rows, digits, rank); + autokey_exec.prepare(module, &autokey, scratch.borrow()); - let noise_have: f64 = pt_have.data.std(0, basek).log2(); + ct_out.automorphism(module, &ct_in, &autokey_exec, scratch.borrow()); - println!("{}", noise_have); - - let noise_want: f64 = log2_std_noise_gglwe_product( + let max_noise: f64 = log2_std_noise_gglwe_product( module.n() as f64, basek * digits, 0.5, @@ -128,16 +161,13 @@ fn test_automorphism( k_ksk, ); - assert!( - noise_have <= noise_want + 1.0, - "{} {}", - noise_have, - noise_want - ); + module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0); + + ct_out.assert_noise(module, &sk_exec, &pt_want, max_noise + 1.0); } -fn test_automorphism_inplace( - log_n: usize, +fn test_automorphism_inplace( + module: &Module, basek: usize, p: i64, k_ct: usize, @@ -145,37 +175,35 @@ fn test_automorphism_inplace( digits: usize, rank: usize, sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - +) where + Module: AutomorphismTestModuleFamily, + B: AutomorphismTestScratchFamily, +{ let rows: usize = k_ct.div_ceil(basek * digits); - let mut autokey: GLWEAutomorphismKey, FFT64> = GLWEAutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - let 
mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + let mut autokey: AutomorphismKey> = AutomorphismKey::alloc(module, basek, k_ksk, rows, digits, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_ct, &mut source_xa); - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, autokey.k(), rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) - | GLWECiphertext::automorphism_inplace_scratch_space(&module, basek, ct.k(), autokey.k(), digits, rank), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + AutomorphismKey::encrypt_sk_scratch_space(module, basek, autokey.k(), rank) + | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()) + | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) + | GLWECiphertext::automorphism_inplace_scratch_space(module, basek, ct.k(), autokey.k(), digits, rank), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); + let sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); autokey.encrypt_sk( - &module, + module, p, &sk, &mut source_xa, @@ -185,23 +213,21 @@ fn test_automorphism_inplace( ); ct.encrypt_sk( - &module, + module, &pt_want, - &sk_dft, + &sk_exec, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - ct.automorphism_inplace(&module, &autokey, scratch.borrow()); - ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); - module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0); - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - module.vec_znx_normalize_inplace(basek, &mut pt_have.data, 0, scratch.borrow()); + let mut autokey_exec: AutomorphismKeyExec, B> = AutomorphismKeyExec::alloc(module, basek, k_ksk, rows, digits, rank); + autokey_exec.prepare(module, &autokey, scratch.borrow()); - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( + ct.automorphism_inplace(module, &autokey_exec, scratch.borrow()); + + let max_noise: f64 = log2_std_noise_gglwe_product( module.n() as f64, basek * digits, 0.5, @@ -214,10 +240,7 @@ fn test_automorphism_inplace( k_ksk, ); - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); + module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0); + + ct.assert_noise(module, &sk_exec, &pt_want, max_noise + 1.0); } diff --git a/core/src/glwe/test_fft64/encryption.rs b/core/src/glwe/test_fft64/encryption.rs index 7909cb3..fd49612 100644 --- a/core/src/glwe/test_fft64/encryption.rs +++ b/core/src/glwe/test_fft64/encryption.rs @@ -1,160 +1,200 @@ -use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats}; +use backend::{ + hal::{ + api::{ + ModuleNew, ScalarZnxAlloc, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAlloc, VecZnxDftAlloc, VecZnxFillUniform, + VecZnxStd, VecZnxSubABInplace, + }, + 
diff --git a/core/src/glwe/test_fft64/encryption.rs b/core/src/glwe/test_fft64/encryption.rs
index 7909cb3..fd49612 100644
--- a/core/src/glwe/test_fft64/encryption.rs
+++ b/core/src/glwe/test_fft64/encryption.rs
@@ -1,160 +1,200 @@
-use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats};
+use backend::{
+    hal::{
+        api::{
+            ModuleNew, ScalarZnxAlloc, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAlloc, VecZnxDftAlloc, VecZnxFillUniform,
+            VecZnxStd, VecZnxSubABInplace,
+        },
+        layouts::{Backend, Module, ScratchOwned},
+        oep::{
+            ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl,
+            TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl,
+        },
+    },
+    implementation::cpu_spqlios::FFT64,
+};
 use sampling::source::Source;
 
-use crate::{FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, GLWEOps, GLWEPlaintext, GLWEPublicKey, GLWESecret, Infos};
+use crate::{
+    GLWECiphertext, GLWEDecryptFamily, GLWEEncryptPkFamily, GLWEEncryptSkFamily, GLWEOps, GLWEPlaintext, GLWEPublicKey,
+    GLWEPublicKeyExec, GLWESecret, GLWESecretExec, GLWESecretFamily, Infos,
+};
 
 #[test]
 fn encrypt_sk() {
     let log_n: usize = 8;
+    let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
     (1..4).for_each(|rank| {
         println!("test encrypt_sk rank: {}", rank);
-        test_encrypt_sk(log_n, 8, 54, 30, 3.2, rank);
+        test_encrypt_sk(&module, 8, 54, 30, 3.2, rank);
     });
 }
 
 #[test]
 fn encrypt_zero_sk() {
     let log_n: usize = 8;
+    let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
     (1..4).for_each(|rank| {
         println!("test encrypt_zero_sk rank: {}", rank);
-        test_encrypt_zero_sk(log_n, 8, 64, 3.2, rank);
+        test_encrypt_zero_sk(&module, 8, 64, 3.2, rank);
     });
 }
 
 #[test]
 fn encrypt_pk() {
     let log_n: usize = 8;
+    let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
     (1..4).for_each(|rank| {
         println!("test encrypt_pk rank: {}", rank);
-        test_encrypt_pk(log_n, 8, 64, 64, 3.2, rank)
+        test_encrypt_pk(&module, 8, 64, 64, 3.2, rank)
     });
 }
 
-fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) {
-    let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
+pub(crate) trait EncryptionTestModuleFamily<B: Backend> =
+    GLWEDecryptFamily<B> + GLWESecretFamily<B> + VecZnxAlloc + ScalarZnxAlloc + VecZnxStd;
 
-    let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_ct, rank);
-    let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_pt);
-    let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_pt);
+pub(crate) trait EncryptionTestScratchFamily<B: Backend> = TakeVecZnxDftImpl<B>
+    + TakeVecZnxBigImpl<B>
+    + TakeSvpPPolImpl<B>
+    + ScratchOwnedAllocImpl<B>
+    + ScratchOwnedBorrowImpl<B>
+    + ScratchAvailableImpl<B>
+    + TakeScalarZnxImpl<B>
+    + TakeVecZnxImpl<B>;
+
+fn test_encrypt_sk<B: Backend>(module: &Module<B>, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize)
+where
+    Module<B>: EncryptionTestModuleFamily<B> + GLWEEncryptSkFamily<B>,
+    B: EncryptionTestScratchFamily<B>,
+{
+    let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(module, basek, k_ct, rank);
+    let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(module, basek, k_pt);
+    let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(module, basek, k_pt);
 
     let mut source_xs: Source = Source::new([0u8; 32]);
     let mut source_xe: Source = Source::new([0u8; 32]);
     let mut source_xa: Source = Source::new([0u8; 32]);
 
-    let mut scratch: ScratchOwned = ScratchOwned::new(
-        GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k())
-            | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()),
+    let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(
+        GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k())
+            | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()),
     );
 
-    let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc(&module, rank);
+    let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc(module, rank);
     sk.fill_ternary_prob(0.5, &mut source_xs);
-    let sk_dft: FourierGLWESecret<Vec<u8>, FFT64> = FourierGLWESecret::from(&module, &sk);
+    let sk_exec: GLWESecretExec<Vec<u8>, B> = GLWESecretExec::from(module, &sk);
 
-    pt_want
-        .data
-        .fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
+    module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_pt, &mut source_xa);
 
     ct.encrypt_sk(
-        &module,
+        module,
         &pt_want,
-        &sk_dft,
+        &sk_exec,
         &mut source_xa,
         &mut source_xe,
         sigma,
         scratch.borrow(),
    );
 
-    ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
+    ct.decrypt(module, &mut pt_have, &sk_exec, scratch.borrow());
 
-    pt_want.sub_inplace_ab(&module, &pt_have);
+    pt_want.sub_inplace_ab(module, &pt_have);
 
-    let noise_have: f64 = pt_want.data.std(0, basek) * (ct.k() as f64).exp2();
+    let noise_have: f64 = module.vec_znx_std(basek, &pt_want.data, 0) * (ct.k() as f64).exp2();
     let noise_want: f64 = sigma;
 
     assert!(noise_have <= noise_want + 0.2);
 }
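
Aside on the `|` in the scratch-space allocations above and below: the per-operation `*_scratch_space` functions return byte counts, and for unsigned integers `a | b >= max(a, b)`, so OR-ing the counts yields one buffer that is large enough for every operation sharing it, at the cost of a modest over-allocation. This reading of the pattern is not spelled out in the diff itself; a self-contained illustration of the invariant it relies on:

    // Every bit set in a or in b is set in a | b, hence a | b >= a and a | b >= b.
    let enc_bytes: usize = 1 << 20;
    let dec_bytes: usize = 3 << 18;
    let combined: usize = enc_bytes | dec_bytes;
    assert!(combined >= enc_bytes.max(dec_bytes));
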
 
-fn test_encrypt_zero_sk(log_n: usize, basek: usize, k_ct: usize, sigma: f64, rank: usize) {
-    let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
-
-    let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_ct);
+fn test_encrypt_zero_sk<B: Backend>(module: &Module<B>, basek: usize, k_ct: usize, sigma: f64, rank: usize)
+where
+    Module<B>: EncryptionTestModuleFamily<B> + GLWEEncryptSkFamily<B>,
+    B: EncryptionTestScratchFamily<B>,
+{
+    let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(module, basek, k_ct);
 
     let mut source_xs: Source = Source::new([0u8; 32]);
     let mut source_xe: Source = Source::new([1u8; 32]);
     let mut source_xa: Source = Source::new([0u8; 32]);
 
-    let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc(&module, rank);
+    let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc(module, rank);
     sk.fill_ternary_prob(0.5, &mut source_xs);
-    let sk_dft: FourierGLWESecret<Vec<u8>, FFT64> = FourierGLWESecret::from(&module, &sk);
+    let sk_exec: GLWESecretExec<Vec<u8>, B> = GLWESecretExec::from(module, &sk);
 
-    let mut ct_dft: FourierGLWECiphertext<Vec<u8>, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank);
+    let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(module, basek, k_ct, rank);
 
-    let mut scratch: ScratchOwned = ScratchOwned::new(
-        FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct)
-            | FourierGLWECiphertext::encrypt_sk_scratch_space(&module, basek, k_ct, rank),
+    let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(
+        GLWECiphertext::decrypt_scratch_space(module, basek, k_ct)
+            | GLWECiphertext::encrypt_sk_scratch_space(module, basek, k_ct),
     );
 
-    ct_dft.encrypt_zero_sk(
-        &module,
-        &sk_dft,
+    ct.encrypt_zero_sk(
+        module,
+        &sk_exec,
         &mut source_xa,
         &mut source_xe,
         sigma,
         scratch.borrow(),
     );
 
-    ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
+    ct.decrypt(module, &mut pt, &sk_exec, scratch.borrow());
 
-    assert!((sigma - pt.data.std(0, basek) * (k_ct as f64).exp2()) <= 0.2);
+    assert!((sigma - module.vec_znx_std(basek, &pt.data, 0) * (k_ct as f64).exp2()) <= 0.2);
 }
 
-fn test_encrypt_pk(log_n: usize, basek: usize, k_ct: usize, k_pk: usize, sigma: f64, rank: usize) {
-    let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
-
-    let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_ct, rank);
-    let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_ct);
-    let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_ct);
+fn test_encrypt_pk<B: Backend>(module: &Module<B>, basek: usize, k_ct: usize, k_pk: usize, sigma: f64, rank: usize)
+where
+    Module<B>: EncryptionTestModuleFamily<B>
+        + GLWEEncryptPkFamily<B>
+        + GLWEEncryptSkFamily<B>
+        + VecZnxDftAlloc<B>
+        + VecZnxFillUniform
+        + VecZnxSubABInplace,
+    B: EncryptionTestScratchFamily<B>,
+{
+    let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(module, basek, k_ct, rank);
+    let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(module, basek, k_ct);
+    let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(module, basek, k_ct);
 
     let mut source_xs: Source = Source::new([0u8; 32]);
     let mut source_xe: Source = Source::new([0u8; 32]);
     let mut source_xa: Source = Source::new([0u8; 32]);
     let mut source_xu: Source = Source::new([0u8; 32]);
 
-    let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc(&module, rank);
+    let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc(module, rank);
     sk.fill_ternary_prob(0.5, &mut source_xs);
-    let sk_dft: FourierGLWESecret<Vec<u8>, FFT64> = FourierGLWESecret::from(&module, &sk);
+    let sk_exec: GLWESecretExec<Vec<u8>, B> = GLWESecretExec::from(module, &sk);
 
-    let mut pk: GLWEPublicKey<Vec<u8>, FFT64> = GLWEPublicKey::alloc(&module, basek, k_pk, rank);
-    pk.generate_from_sk(&module, &sk_dft, &mut source_xa, &mut source_xe, sigma);
+    let mut pk: GLWEPublicKey<Vec<u8>> = GLWEPublicKey::alloc(module, basek, k_pk, rank);
+    pk.generate_from_sk(module, &sk_exec, &mut source_xa, &mut source_xe, sigma);
 
-    let mut scratch: ScratchOwned = ScratchOwned::new(
-        GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k())
-            | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k())
-            | GLWECiphertext::encrypt_pk_scratch_space(&module, basek, pk.k()),
+    let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(
+        GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k())
+            | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k())
+            | GLWECiphertext::encrypt_pk_scratch_space(module, basek, pk.k()),
     );
 
-    pt_want
-        .data
-        .fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
+    module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_ct, &mut source_xa);
+
+    let pk_exec: GLWEPublicKeyExec<Vec<u8>, B> = GLWEPublicKeyExec::from(module, &pk, scratch.borrow());
 
     ct.encrypt_pk(
-        &module,
+        module,
         &pt_want,
-        &pk,
+        &pk_exec,
         &mut source_xu,
         &mut source_xe,
         sigma,
         scratch.borrow(),
     );
 
-    ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
+    ct.decrypt(module, &mut pt_have, &sk_exec, scratch.borrow());
 
-    pt_want.sub_inplace_ab(&module, &pt_have);
+    pt_want.sub_inplace_ab(module, &pt_have);
 
-    let noise_have: f64 = pt_want.data.std(0, basek).log2();
+    let noise_have: f64 = module.vec_znx_std(basek, &pt_want.data, 0).log2();
     let noise_want: f64 = ((((rank as f64) + 1.0) * module.n() as f64 * 0.5 * sigma * sigma).sqrt()).log2() - (k_ct as f64);
 
     assert!(
-        (noise_have - noise_want).abs() < 0.2,
+        noise_have <= noise_want + 0.2,
         "{} {}",
         noise_have,
         noise_want
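
The checks above all follow one recipe: subtract the known plaintext from the decryption, measure the standard deviation of what remains, and rescale from the fixed-point representation back to integer units. A compressed sketch of that recipe, with names taken from the tests (the reading of `exp2` as mapping an error at scale 2^-k back to sigma units is an interpretation, not stated in the diff):

    // pt_want now holds pt_want - pt_have, i.e. the decryption error.
    let std_torus: f64 = module.vec_znx_std(basek, &pt_want.data, 0); // error at scale 2^-k
    let noise_have: f64 = std_torus * (ct.k() as f64).exp2();         // rescaled to integer units
    assert!(noise_have <= sigma + 0.2);                               // fresh encryption: close to sigma
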
diff --git a/core/src/glwe/test_fft64/external_product.rs b/core/src/glwe/test_fft64/external_product.rs
index e1f6b19..80f84ef 100644
--- a/core/src/glwe/test_fft64/external_product.rs
+++ b/core/src/glwe/test_fft64/external_product.rs
@@ -1,11 +1,28 @@
-use backend::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut};
+use backend::{
+    hal::{
+        api::{
+            MatZnxAlloc, ModuleNew, ScalarZnxAlloc, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, VecZnxAlloc,
+            VecZnxAllocBytes, VecZnxFillUniform, VecZnxRotateInplace, VecZnxStd, ZnxViewMut,
+        },
+        layouts::{Backend, Module, ScalarZnx, ScratchOwned},
+        oep::{
+            ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl,
+            TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl,
+        },
+    },
+    implementation::cpu_spqlios::FFT64,
+};
 use sampling::source::Source;
 
-use crate::{FourierGLWESecret, GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, noise::noise_ggsw_product};
+use crate::{
+    GGSWCiphertext, GGSWCiphertextExec, GGSWLayoutFamily, GLWECiphertext,
GLWEDecryptFamily, GLWEEncryptSkFamily, + GLWEExternalProductFamily, GLWEPlaintext, GLWESecret, GLWESecretExec, GLWESecretFamily, Infos, noise::noise_ggsw_product, +}; #[test] fn apply() { let log_n: usize = 8; + let module: Module = Module::::new(1 << log_n); let basek: usize = 12; let k_in: usize = 45; let digits: usize = k_in.div_ceil(basek); @@ -14,7 +31,7 @@ fn apply() { let k_ggsw: usize = k_in + basek * di; let k_out: usize = k_ggsw; // Better capture noise println!("test external_product digits: {} rank: {}", di, rank); - test_external_product(log_n, basek, k_out, k_in, k_ggsw, di, rank, 3.2); + test_external_product(&module, basek, k_out, k_in, k_ggsw, di, rank, 3.2); }); }); } @@ -22,6 +39,7 @@ fn apply() { #[test] fn apply_inplace() { let log_n: usize = 8; + let module: Module = Module::::new(1 << log_n); let basek: usize = 12; let k_ct: usize = 60; let digits: usize = k_ct.div_ceil(basek); @@ -29,13 +47,35 @@ fn apply_inplace() { (1..digits + 1).for_each(|di| { let k_ggsw: usize = k_ct + basek * di; println!("test external_product digits: {} rank: {}", di, rank); - test_external_product_inplace(log_n, basek, k_ct, k_ggsw, di, rank, 3.2); + test_external_product_inplace(&module, basek, k_ct, k_ggsw, di, rank, 3.2); }); }); } -fn test_external_product( - log_n: usize, +pub(crate) trait ExternalProductTestModuleFamily = GLWEEncryptSkFamily + + GLWEDecryptFamily + + GLWESecretFamily + + GLWEExternalProductFamily + + GGSWLayoutFamily + + MatZnxAlloc + + VecZnxAlloc + + ScalarZnxAlloc + + VecZnxAllocBytes + + VecZnxAddScalarInplace + + VecZnxRotateInplace + + VecZnxStd; + +pub(crate) trait ExternalProductTestScratchFamily = TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl; + +fn test_external_product( + module: &Module, basek: usize, k_out: usize, k_in: usize, @@ -43,26 +83,24 @@ fn test_external_product( digits: usize, rank: usize, sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - +) where + Module: ExternalProductTestModuleFamily, + B: ExternalProductTestScratchFamily, +{ let rows: usize = k_in.div_ceil(basek * digits); - let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); - let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank); - let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); + let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_ggsw, rows, digits, rank); + let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_in, rank); + let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_out, rank); + let mut pt_rgsw: ScalarZnx> = module.scalar_znx_alloc(1); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_in); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); // Random input plaintext - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_in, &mut source_xa); pt_want.data.at_mut(0, 0)[1] = 1; @@ -70,12 +108,11 @@ fn 
test_external_product( pt_rgsw.raw_mut()[k] = 1; // X^{k} - let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe_out.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe_in.k()) + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, ct_ggsw.k(), rank) + | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_glwe_in.k()) | GLWECiphertext::external_product_scratch_space( - &module, + module, basek, ct_glwe_out.k(), ct_glwe_in.k(), @@ -85,14 +122,14 @@ fn test_external_product( ), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); + let sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); ct_ggsw.encrypt_sk( - &module, + module, &pt_rgsw, - &sk_dft, + &sk_exec, &mut source_xa, &mut source_xe, sigma, @@ -100,25 +137,21 @@ fn test_external_product( ); ct_glwe_in.encrypt_sk( - &module, + module, &pt_want, - &sk_dft, + &sk_exec, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - ct_glwe_out.external_product(&module, &ct_glwe_in, &ct_ggsw, scratch.borrow()); + let ct_ggsw_exec: GGSWCiphertextExec, B> = GGSWCiphertextExec::from(module, &ct_ggsw, scratch.borrow()); - ct_glwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + ct_glwe_out.external_product(module, &ct_glwe_in, &ct_ggsw_exec, scratch.borrow()); module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0); - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let var_gct_err_lhs: f64 = sigma * sigma; let var_gct_err_rhs: f64 = 0f64; @@ -126,7 +159,7 @@ fn test_external_product( let var_a0_err: f64 = sigma * sigma; let var_a1_err: f64 = 1f64 / 12f64; - let noise_want: f64 = noise_ggsw_product( + let max_noise: f64 = noise_ggsw_product( module.n() as f64, basek * digits, 0.5, @@ -140,32 +173,34 @@ fn test_external_product( k_ggsw, ); - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); + ct_glwe_out.assert_noise(module, &sk_exec, &pt_want, max_noise + 0.5); } -fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw: usize, digits: usize, rank: usize, sigma: f64) { - let module: Module = Module::::new(1 << log_n); +fn test_external_product_inplace( + module: &Module, + basek: usize, + k_ct: usize, + k_ggsw: usize, + digits: usize, + rank: usize, + sigma: f64, +) where + Module: ExternalProductTestModuleFamily, + B: ExternalProductTestScratchFamily, +{ let rows: usize = k_ct.div_ceil(basek * digits); - let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(module, basek, k_ggsw, rows, digits, rank); + let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_ct, rank); + let mut pt_rgsw: ScalarZnx> = 
module.scalar_znx_alloc(1); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); // Random input plaintext - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_ct, &mut source_xa); pt_want.data.at_mut(0, 0)[1] = 1; @@ -173,21 +208,20 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw pt_rgsw.raw_mut()[k] = 1; // X^{k} - let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) - | GLWECiphertext::external_product_inplace_scratch_space(&module, basek, ct_glwe.k(), ct_ggsw.k(), digits, rank), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, ct_ggsw.k(), rank) + | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_glwe.k()) + | GLWECiphertext::external_product_inplace_scratch_space(module, basek, ct_glwe.k(), ct_ggsw.k(), digits, rank), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); + let sk_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); ct_ggsw.encrypt_sk( - &module, + module, &pt_rgsw, - &sk_dft, + &sk_exec, &mut source_xa, &mut source_xe, sigma, @@ -195,25 +229,21 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw ); ct_glwe.encrypt_sk( - &module, + module, &pt_want, - &sk_dft, + &sk_exec, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - ct_glwe.external_product_inplace(&module, &ct_ggsw, scratch.borrow()); + let ct_ggsw_exec: GGSWCiphertextExec, B> = GGSWCiphertextExec::from(module, &ct_ggsw, scratch.borrow()); - ct_glwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + ct_glwe.external_product_inplace(module, &ct_ggsw_exec, scratch.borrow()); module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0); - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let var_gct_err_lhs: f64 = sigma * sigma; let var_gct_err_rhs: f64 = 0f64; @@ -221,7 +251,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw let var_a0_err: f64 = sigma * sigma; let var_a1_err: f64 = 1f64 / 12f64; - let noise_want: f64 = noise_ggsw_product( + let max_noise: f64 = noise_ggsw_product( module.n() as f64, basek * digits, 0.5, @@ -235,10 +265,5 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw k_ggsw, ); - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); + ct_glwe.assert_noise(module, &sk_exec, &pt_want, max_noise + 0.5); } diff --git a/core/src/glwe/test_fft64/keyswitch.rs b/core/src/glwe/test_fft64/keyswitch.rs index 9142292..df13af1 100644 --- a/core/src/glwe/test_fft64/keyswitch.rs +++ b/core/src/glwe/test_fft64/keyswitch.rs @@ -1,13 +1,29 @@ -use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps}; +use backend::{ + hal::{ + api::{ + MatZnxAlloc, ModuleNew, 
ScalarZnxAlloc, ScalarZnxAllocBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, + VecZnxAddScalarInplace, VecZnxAlloc, VecZnxAllocBytes, VecZnxFillUniform, VecZnxStd, VecZnxSwithcDegree, + }, + layouts::{Backend, Module, ScratchOwned}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, + }, + }, + implementation::cpu_spqlios::FFT64, +}; use sampling::source::Source; use crate::{ - FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, noise::log2_std_noise_gglwe_product, + GGLWEExecLayoutFamily, GLWECiphertext, GLWEDecryptFamily, GLWEKeyswitchFamily, GLWEPlaintext, GLWESecret, GLWESecretExec, + GLWESecretFamily, GLWESwitchingKey, GLWESwitchingKeyEncryptSkFamily, GLWESwitchingKeyExec, Infos, + noise::log2_std_noise_gglwe_product, }; #[test] fn apply() { let log_n: usize = 8; + let module: Module = Module::::new(1 << log_n); let basek: usize = 12; let k_in: usize = 45; let digits: usize = k_in.div_ceil(basek); @@ -20,7 +36,9 @@ fn apply() { "test keyswitch digits: {} rank_in: {} rank_out: {}", di, rank_in, rank_out ); - test_keyswitch(log_n, basek, k_out, k_in, k_ksk, di, rank_in, rank_out, 3.2); + test_keyswitch( + &module, basek, k_out, k_in, k_ksk, di, rank_in, rank_out, 3.2, + ); }) }); }); @@ -29,6 +47,7 @@ fn apply() { #[test] fn apply_inplace() { let log_n: usize = 8; + let module: Module = Module::::new(1 << log_n); let basek: usize = 12; let k_ct: usize = 45; let digits: usize = k_ct.div_ceil(basek); @@ -36,13 +55,36 @@ fn apply_inplace() { (1..digits + 1).for_each(|di| { let k_ksk: usize = k_ct + basek * di; println!("test keyswitch_inplace digits: {} rank: {}", di, rank); - test_keyswitch_inplace(log_n, basek, k_ct, k_ksk, di, rank, 3.2); + test_keyswitch_inplace(&module, basek, k_ct, k_ksk, di, rank, 3.2); }); }); } -fn test_keyswitch( - log_n: usize, +pub(crate) trait KeySwitchTestModuleFamily = GLWESecretFamily + + GLWESwitchingKeyEncryptSkFamily + + GLWEKeyswitchFamily + + GLWEDecryptFamily + + GGLWEExecLayoutFamily + + MatZnxAlloc + + VecZnxAlloc + + ScalarZnxAlloc + + ScalarZnxAllocBytes + + VecZnxAllocBytes + + VecZnxStd + + VecZnxSwithcDegree + + VecZnxAddScalarInplace; + +pub(crate) trait KeySwitchTestScratchFamily = TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl; + +fn test_keyswitch( + module: &Module, basek: usize, k_out: usize, k_in: usize, @@ -51,32 +93,28 @@ fn test_keyswitch( rank_in: usize, rank_out: usize, sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - +) where + Module: KeySwitchTestModuleFamily, + B: KeySwitchTestScratchFamily, +{ let rows: usize = k_in.div_ceil(basek * digits); - let mut ksk: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank_in, rank_out); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank_in); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank_out); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc(module, basek, k_ksk, rows, digits, rank_in, rank_out); + let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_in, rank_in); + let mut ct_out: 
GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_out, rank_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_in); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_in, &mut source_xa); - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank_in, rank_out) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, ksk.k(), rank_in, rank_out) + | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_in.k()) | GLWECiphertext::keyswitch_scratch_space( - &module, + module, basek, ct_out.k(), ct_in.k(), @@ -87,18 +125,18 @@ fn test_keyswitch( ), ); - let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank_in); + let mut sk_in: GLWESecret> = GLWESecret::alloc(module, rank_in); sk_in.fill_ternary_prob(0.5, &mut source_xs); - let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); + let sk_in_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_in); - let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank_out); + let mut sk_out: GLWESecret> = GLWESecret::alloc(module, rank_out); sk_out.fill_ternary_prob(0.5, &mut source_xs); - let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); + let sk_out_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_out); ksk.encrypt_sk( - &module, + module, &sk_in, - &sk_out_dft, + &sk_out, &mut source_xa, &mut source_xe, sigma, @@ -106,22 +144,20 @@ fn test_keyswitch( ); ct_in.encrypt_sk( - &module, + module, &pt_want, - &sk_in_dft, + &sk_in_exec, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - ct_out.keyswitch(&module, &ct_in, &ksk, scratch.borrow()); - ct_out.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + let ksk_exec: GLWESwitchingKeyExec, B> = GLWESwitchingKeyExec::from(module, &ksk, scratch.borrow()); - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); + ct_out.keyswitch(module, &ct_in, &ksk_exec, scratch.borrow()); - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( + let max_noise: f64 = log2_std_noise_gglwe_product( module.n() as f64, basek * digits, 0.5, @@ -134,53 +170,51 @@ fn test_keyswitch( k_ksk, ); - println!("{} vs. 
{}", noise_have, noise_want); - - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); + ct_out.assert_noise(module, &sk_out_exec, &pt_want, max_noise + 0.5); } -fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, digits: usize, rank: usize, sigma: f64) { - let module: Module = Module::::new(1 << log_n); - +fn test_keyswitch_inplace( + module: &Module, + basek: usize, + k_ct: usize, + k_ksk: usize, + digits: usize, + rank: usize, + sigma: f64, +) where + Module: KeySwitchTestModuleFamily, + B: KeySwitchTestScratchFamily, +{ let rows: usize = k_ct.div_ceil(basek * digits); - let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc(module, basek, k_ksk, rows, digits, rank, rank); + let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_ct, &mut source_xa); - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ct_grlwe.k(), rank, rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) - | GLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct_glwe.k(), ct_grlwe.k(), digits, rank), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, ksk.k(), rank, rank) + | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_glwe.k()) + | GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, ct_glwe.k(), ksk.k(), digits, rank), ); - let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank); + let mut sk_in: GLWESecret> = GLWESecret::alloc(module, rank); sk_in.fill_ternary_prob(0.5, &mut source_xs); - let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); + let sk_in_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_in); - let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank); + let mut sk_out: GLWESecret> = GLWESecret::alloc(module, rank); sk_out.fill_ternary_prob(0.5, &mut source_xs); - let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); + let sk_out_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_out); - ct_grlwe.encrypt_sk( - &module, + ksk.encrypt_sk( + module, &sk_in, - &sk_out_dft, + &sk_out, &mut source_xa, &mut source_xe, sigma, @@ -188,23 +222,20 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, ); ct_glwe.encrypt_sk( - &module, + module, &pt_want, - &sk_in_dft, + &sk_in_exec, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - ct_glwe.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); + let ksk_exec: GLWESwitchingKeyExec, B> = GLWESwitchingKeyExec::from(module, &ksk, scratch.borrow()); 
- ct_glwe.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + ct_glwe.keyswitch_inplace(module, &ksk_exec, scratch.borrow()); - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( + let max_noise: f64 = log2_std_noise_gglwe_product( module.n() as f64, basek * digits, 0.5, @@ -217,10 +248,5 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, k_ksk, ); - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); + ct_glwe.assert_noise(module, &sk_out_exec, &pt_want, max_noise + 0.5); } diff --git a/core/src/glwe/test_fft64/packing.rs b/core/src/glwe/test_fft64/packing.rs index ff2cfbd..b6743b6 100644 --- a/core/src/glwe/test_fft64/packing.rs +++ b/core/src/glwe/test_fft64/packing.rs @@ -1,14 +1,67 @@ -use crate::{FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWEOps, GLWEPacker, GLWEPlaintext, GLWESecret}; use std::collections::HashMap; -use backend::{Encoding, FFT64, Module, ScratchOwned, Stats}; +use backend::{ + hal::{ + api::{ + MatZnxAlloc, ModuleNew, ScalarZnxAlloc, ScalarZnxAllocBytes, ScalarZnxAutomorphism, ScratchOwnedAlloc, + ScratchOwnedBorrow, VecZnxAddScalarInplace, VecZnxAlloc, VecZnxAllocBytes, VecZnxBigSubSmallBInplace, + VecZnxEncodeVeci64, VecZnxRotateInplace, VecZnxStd, VecZnxSwithcDegree, + }, + layouts::{Backend, Module, ScratchOwned}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, + }, + }, + implementation::cpu_spqlios::FFT64, +}; use sampling::source::Source; +use crate::{ + AutomorphismKey, AutomorphismKeyExec, GGLWEExecLayoutFamily, GLWECiphertext, GLWEDecryptFamily, GLWEKeyswitchFamily, GLWEOps, + GLWEPacker, GLWEPackingFamily, GLWEPlaintext, GLWESecret, GLWESecretExec, GLWESecretFamily, GLWESwitchingKeyEncryptSkFamily, +}; + #[test] -fn apply() { +fn trace() { let log_n: usize = 5; let module: Module = Module::::new(1 << log_n); + test_packing(&module); +} +pub(crate) trait PackingTestModuleFamily = GLWEPackingFamily + + GLWESecretFamily + + GLWESwitchingKeyEncryptSkFamily + + GLWEKeyswitchFamily + + GLWEDecryptFamily + + GGLWEExecLayoutFamily + + MatZnxAlloc + + VecZnxAlloc + + ScalarZnxAlloc + + ScalarZnxAllocBytes + + VecZnxAllocBytes + + VecZnxStd + + VecZnxSwithcDegree + + VecZnxAddScalarInplace + + VecZnxEncodeVeci64 + + ScalarZnxAutomorphism + + VecZnxRotateInplace + + VecZnxBigSubSmallBInplace; + +pub(crate) trait PackingTestScratchFamily = TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl; + +pub(crate) fn test_packing(module: &Module) +where + Module: PackingTestModuleFamily, + B: PackingTestScratchFamily, +{ let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); @@ -23,31 +76,31 @@ fn apply() { let rows: usize = k_ct.div_ceil(basek * digits); - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWECiphertext::encrypt_sk_scratch_space(&module, basek, k_ct) - | GLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) - | GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) - | GLWEPacker::scratch_space(&module, basek, k_ct, k_ksk, digits, rank), + let mut 
scratch: ScratchOwned = ScratchOwned::alloc( + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k_ct) + | AutomorphismKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) + | GLWEPacker::scratch_space(module, basek, k_ct, k_ksk, digits, rank), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); + let sk_dft: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_ct); let mut data: Vec = vec![0i64; module.n()]; data.iter_mut().enumerate().for_each(|(i, x)| { *x = i as i64; }); - pt.data.encode_vec_i64(0, basek, pt_k, &data, 32); - let gal_els: Vec = GLWEPacker::galois_elements(&module); + module.encode_vec_i64(basek, &mut pt.data, 0, pt_k, &data, 32); - let mut auto_keys: HashMap, FFT64>> = HashMap::new(); + let gal_els: Vec = GLWEPacker::galois_elements(module); + + let mut auto_keys: HashMap, B>> = HashMap::new(); + let mut tmp: AutomorphismKey> = AutomorphismKey::alloc(module, basek, k_ksk, rows, digits, rank); gal_els.iter().for_each(|gal_el| { - let mut key: GLWEAutomorphismKey, FFT64> = GLWEAutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); - key.encrypt_sk( - &module, + tmp.encrypt_sk( + module, *gal_el, &sk, &mut source_xa, @@ -55,17 +108,18 @@ fn apply() { sigma, scratch.borrow(), ); - auto_keys.insert(*gal_el, key); + let atk_exec: AutomorphismKeyExec, B> = AutomorphismKeyExec::from(module, &tmp, scratch.borrow()); + auto_keys.insert(*gal_el, atk_exec); }); let log_batch: usize = 0; - let mut packer: GLWEPacker = GLWEPacker::new(&module, log_batch, basek, k_ct, rank); + let mut packer: GLWEPacker = GLWEPacker::new(module, log_batch, basek, k_ct, rank); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_ct, rank); ct.encrypt_sk( - &module, + module, &pt, &sk_dft, &mut source_xa, @@ -74,9 +128,11 @@ fn apply() { scratch.borrow(), ); + let log_n: usize = module.log_n(); + (0..module.n() >> log_batch).for_each(|i| { ct.encrypt_sk( - &module, + module, &pt, &sk_dft, &mut source_xa, @@ -85,13 +141,13 @@ fn apply() { scratch.borrow(), ); - pt.rotate_inplace(&module, -(1 << log_batch)); // X^-batch * pt + pt.rotate_inplace(module, -(1 << log_batch)); // X^-batch * pt if reverse_bits_msb(i, log_n as u32) % 5 == 0 { - packer.add(&module, Some(&ct), &auto_keys, scratch.borrow()); + packer.add(module, Some(&ct), &auto_keys, scratch.borrow()); } else { packer.add( - &module, + module, None::<&GLWECiphertext>>, &auto_keys, scratch.borrow(), @@ -99,23 +155,24 @@ fn apply() { } }); - let mut res = GLWECiphertext::alloc(&module, basek, k_ct, rank); - packer.flush(&module, &mut res); + let mut res = GLWECiphertext::alloc(module, basek, k_ct, rank); + packer.flush(module, &mut res); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_ct); let mut data: Vec = vec![0i64; module.n()]; data.iter_mut().enumerate().for_each(|(i, x)| { if i % 5 == 0 { *x = reverse_bits_msb(i, log_n as u32) as i64; } }); - pt_want.data.encode_vec_i64(0, basek, pt_k, &data, 32); - res.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + 
module.encode_vec_i64(basek, &mut pt_want.data, 0, pt_k, &data, 32); - pt.sub_inplace_ab(&module, &pt_want); + res.decrypt(module, &mut pt, &sk_dft, scratch.borrow()); - let noise_have = pt.data.std(0, basek).log2(); + pt.sub_inplace_ab(module, &pt_want); + + let noise_have: f64 = module.vec_znx_std(basek, &pt.data, 0).log2(); // println!("noise_have: {}", noise_have); assert!( noise_have < -((k_ct - basek) as f64), diff --git a/core/src/glwe/test_fft64/trace.rs b/core/src/glwe/test_fft64/trace.rs index e34e260..234f173 100644 --- a/core/src/glwe/test_fft64/trace.rs +++ b/core/src/glwe/test_fft64/trace.rs @@ -1,47 +1,97 @@ use std::collections::HashMap; -use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps, ZnxView, ZnxViewMut}; +use backend::{ + hal::{ + api::{ + MatZnxAlloc, ModuleNew, ScalarZnxAlloc, ScalarZnxAllocBytes, ScalarZnxAutomorphism, ScratchOwnedAlloc, + ScratchOwnedBorrow, VecZnxAddScalarInplace, VecZnxAlloc, VecZnxAllocBytes, VecZnxBigAutomorphismInplace, + VecZnxBigSubSmallBInplace, VecZnxCopy, VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxNormalizeInplace, + VecZnxRotateInplace, VecZnxRshInplace, VecZnxStd, VecZnxSubABInplace, VecZnxSwithcDegree, ZnxView, ZnxViewMut, + }, + layouts::{Backend, Module, ScratchOwned}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, + }, + }, + implementation::cpu_spqlios::FFT64, +}; use sampling::source::Source; use crate::{ - FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, noise::var_noise_gglwe_product, + AutomorphismKey, AutomorphismKeyExec, GGLWEExecLayoutFamily, GLWECiphertext, GLWEDecryptFamily, GLWEKeyswitchFamily, + GLWEPlaintext, GLWESecret, GLWESecretExec, GLWESecretFamily, GLWESwitchingKeyEncryptSkFamily, Infos, + noise::var_noise_gglwe_product, }; #[test] fn apply_inplace() { let log_n: usize = 8; + let module: Module = Module::::new(1 << log_n); (1..4).for_each(|rank| { println!("test trace_inplace rank: {}", rank); - test_trace_inplace(log_n, 8, 54, 3.2, rank); + test_trace_inplace(&module, 8, 54, 3.2, rank); }); } -fn test_trace_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize) { - let module: Module = Module::::new(1 << log_n); +pub(crate) trait TraceTestModuleFamily = GLWESecretFamily + + GLWESwitchingKeyEncryptSkFamily + + GLWEKeyswitchFamily + + GLWEDecryptFamily + + GGLWEExecLayoutFamily + + MatZnxAlloc + + VecZnxAlloc + + ScalarZnxAlloc + + ScalarZnxAllocBytes + + VecZnxAllocBytes + + VecZnxStd + + VecZnxSwithcDegree + + VecZnxAddScalarInplace + + VecZnxEncodeVeci64 + + ScalarZnxAutomorphism + + VecZnxRotateInplace + + VecZnxBigSubSmallBInplace + + VecZnxBigAutomorphismInplace + + VecZnxCopy + + VecZnxRshInplace; +pub(crate) trait TraceTestScratchFamily = TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl; + +fn test_trace_inplace(module: &Module, basek: usize, k: usize, sigma: f64, rank: usize) +where + Module: TraceTestModuleFamily, + B: TraceTestScratchFamily, +{ let k_autokey: usize = k + basek; let digits: usize = 1; let rows: usize = k.div_ceil(basek * digits); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, 
k); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) - | GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_autokey, rank) - | GLWECiphertext::trace_inplace_scratch_space(&module, basek, ct.k(), k_autokey, digits, rank), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) + | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()) + | AutomorphismKey::encrypt_sk_scratch_space(module, basek, k_autokey, rank) + | GLWECiphertext::trace_inplace_scratch_space(module, basek, ct.k(), k_autokey, digits, rank), ); - let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); + let sk_dft: GLWESecretExec, B> = GLWESecretExec::from(module, &sk); let mut data_want: Vec = vec![0i64; module.n()]; @@ -49,12 +99,10 @@ fn test_trace_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank: us .iter_mut() .for_each(|x| *x = source_xa.next_i64() & 0xFF); - pt_have - .data - .fill_uniform(basek, 0, pt_have.size(), &mut source_xa); + module.vec_znx_fill_uniform(basek, &mut pt_have.data, 0, k, &mut source_xa); ct.encrypt_sk( - &module, + module, &pt_have, &sk_dft, &mut source_xa, @@ -63,13 +111,12 @@ fn test_trace_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank: us scratch.borrow(), ); - let mut auto_keys: HashMap, FFT64>> = HashMap::new(); - let gal_els: Vec = GLWECiphertext::trace_galois_elements(&module); + let mut auto_keys: HashMap, B>> = HashMap::new(); + let gal_els: Vec = GLWECiphertext::trace_galois_elements(module); + let mut tmp: AutomorphismKey> = AutomorphismKey::alloc(module, basek, k_autokey, rows, digits, rank); gal_els.iter().for_each(|gal_el| { - let mut key: GLWEAutomorphismKey, FFT64> = - GLWEAutomorphismKey::alloc(&module, basek, k_autokey, rows, digits, rank); - key.encrypt_sk( - &module, + tmp.encrypt_sk( + module, *gal_el, &sk, &mut source_xa, @@ -77,20 +124,21 @@ fn test_trace_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank: us sigma, scratch.borrow(), ); - auto_keys.insert(*gal_el, key); + let atk_exec: AutomorphismKeyExec, B> = AutomorphismKeyExec::from(module, &tmp, scratch.borrow()); + auto_keys.insert(*gal_el, atk_exec); }); - ct.trace_inplace(&module, 0, 5, &auto_keys, scratch.borrow()); - ct.trace_inplace(&module, 5, log_n, &auto_keys, scratch.borrow()); + ct.trace_inplace(module, 0, 5, &auto_keys, scratch.borrow()); + ct.trace_inplace(module, 5, module.log_n(), &auto_keys, scratch.borrow()); (0..pt_want.size()).for_each(|i| pt_want.data.at_mut(0, i)[0] = pt_have.data.at(0, i)[0]); - ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + ct.decrypt(module, &mut pt_have, &sk_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_want.data, 0, &pt_have.data, 0); module.vec_znx_normalize_inplace(basek, &mut pt_want.data, 0, scratch.borrow()); - let 
noise_have = pt_want.data.std(0, basek).log2(); + let noise_have: f64 = module.vec_znx_std(basek, &pt_want.data, 0).log2(); let mut noise_want: f64 = var_noise_gglwe_product( module.n() as f64, diff --git a/core/src/glwe/trace.rs b/core/src/glwe/trace.rs index c702489..407c960 100644 --- a/core/src/glwe/trace.rs +++ b/core/src/glwe/trace.rs @@ -1,11 +1,19 @@ use std::collections::HashMap; -use backend::{FFT64, Module, Scratch}; +use backend::hal::{ + api::{ScratchAvailable, TakeVecZnxDft, VecZnxBigAutomorphismInplace, VecZnxCopy, VecZnxRshInplace}, + layouts::{Backend, DataMut, DataRef, Module, Scratch}, +}; -use crate::{GLWEAutomorphismKey, GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEOps, Infos, SetMetaData}; +use crate::{ + AutomorphismKeyExec, GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEKeyswitchFamily, GLWEOps, Infos, + SetMetaData, +}; + +pub trait GLWETraceFamily = GLWEKeyswitchFamily + VecZnxCopy + VecZnxRshInplace + VecZnxBigAutomorphismInplace; impl GLWECiphertext> { - pub fn trace_galois_elements(module: &Module) -> Vec { + pub fn trace_galois_elements(module: &Module) -> Vec { let mut gal_els: Vec = Vec::new(); (0..module.log_n()).for_each(|i| { if i == 0 { @@ -17,59 +25,70 @@ impl GLWECiphertext> { gal_els } - pub fn trace_scratch_space( - module: &Module, + pub fn trace_scratch_space( + module: &Module, basek: usize, out_k: usize, in_k: usize, ksk_k: usize, digits: usize, rank: usize, - ) -> usize { + ) -> usize + where + Module: GLWEKeyswitchFamily, + { Self::automorphism_inplace_scratch_space(module, basek, out_k.min(in_k), ksk_k, digits, rank) } - pub fn trace_inplace_scratch_space( - module: &Module, + pub fn trace_inplace_scratch_space( + module: &Module, basek: usize, out_k: usize, ksk_k: usize, digits: usize, rank: usize, - ) -> usize { + ) -> usize + where + Module: GLWEKeyswitchFamily, + { Self::automorphism_inplace_scratch_space(module, basek, out_k, ksk_k, digits, rank) } } -impl + AsMut<[u8]>> GLWECiphertext +impl GLWECiphertext where GLWECiphertext: GLWECiphertextToMut + Infos + SetMetaData, { - pub fn trace, DataAK: AsRef<[u8]>>( + pub fn trace( &mut self, - module: &Module, + module: &Module, start: usize, end: usize, lhs: &GLWECiphertext, - auto_keys: &HashMap>, - scratch: &mut Scratch, + auto_keys: &HashMap>, + scratch: &mut Scratch, ) where - GLWECiphertext: GLWECiphertextToRef + Infos, + GLWECiphertext: GLWECiphertextToRef + Infos + VecZnxRshInplace, + Module: GLWETraceFamily, + Scratch: TakeVecZnxDft + ScratchAvailable, { self.copy(module, lhs); self.trace_inplace(module, start, end, auto_keys, scratch); } - pub fn trace_inplace>( + pub fn trace_inplace( &mut self, - module: &Module, + module: &Module, start: usize, end: usize, - auto_keys: &HashMap>, - scratch: &mut Scratch, - ) { + auto_keys: &HashMap>, + scratch: &mut Scratch, + ) where + Module: GLWETraceFamily, + Scratch: TakeVecZnxDft + ScratchAvailable, + { (start..end).for_each(|i| { - self.rsh(1, scratch); + self.rsh(module, 1); let p: i64; if i == 0 { diff --git a/core/src/lib.rs b/core/src/lib.rs index 2c38575..fa3b87f 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,337 +1,22 @@ -pub mod blind_rotation; -pub mod dist; -pub mod elem; -pub mod fourier_glwe; -pub mod gglwe; -pub mod ggsw; -pub mod glwe; -pub mod lwe; -pub mod noise; - -use backend::Backend; -use backend::FFT64; -use backend::Module; -pub use blind_rotation::{BlindRotationKeyCGGI, LookUpTable, cggi_blind_rotate, cggi_blind_rotate_scratch_space}; -pub use elem::{GetRow, Infos, SetMetaData, 
SetRow}; -pub use fourier_glwe::{FourierGLWECiphertext, FourierGLWESecret}; -pub use gglwe::{GGLWECiphertext, GLWEAutomorphismKey, GLWESwitchingKey, GLWETensorKey}; -pub use ggsw::GGSWCiphertext; -pub use glwe::{GLWECiphertext, GLWEOps, GLWEPacker, GLWEPlaintext, GLWEPublicKey, GLWESecret}; -pub use lwe::{LWECiphertext, LWESecret}; - -pub use backend; -pub use backend::Scratch; -pub use backend::ScratchOwned; -pub(crate) use glwe::{GLWECiphertextToMut, GLWECiphertextToRef}; +#![feature(trait_alias)] +mod blind_rotation; +mod dist; +mod elem; +mod gglwe; +mod ggsw; +mod glwe; +mod lwe; +mod noise; +mod scratch; use crate::dist::Distribution; +pub use blind_rotation::*; +pub use elem::*; +pub use gglwe::*; +pub use ggsw::*; +pub use glwe::*; +pub use lwe::*; +pub use scratch::*; + pub(crate) const SIX_SIGMA: f64 = 6.0; - -pub trait ScratchCore { - fn tmp_glwe_ct(&mut self, module: &Module, basek: usize, k: usize, rank: usize) -> (GLWECiphertext<&mut [u8]>, &mut Self); - fn tmp_vec_glwe_ct( - &mut self, - size: usize, - module: &Module, - basek: usize, - k: usize, - rank: usize, - ) -> (Vec>, &mut Self); - fn tmp_glwe_pt(&mut self, module: &Module, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self); - fn tmp_gglwe( - &mut self, - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> (GGLWECiphertext<&mut [u8], B>, &mut Self); - fn tmp_ggsw( - &mut self, - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank: usize, - ) -> (GGSWCiphertext<&mut [u8], B>, &mut Self); - fn tmp_fourier_glwe_ct( - &mut self, - module: &Module, - basek: usize, - k: usize, - rank: usize, - ) -> (FourierGLWECiphertext<&mut [u8], B>, &mut Self); - fn tmp_slice_fourier_glwe_ct( - &mut self, - size: usize, - module: &Module, - basek: usize, - k: usize, - rank: usize, - ) -> (Vec>, &mut Self); - fn tmp_glwe_secret(&mut self, module: &Module, rank: usize) -> (GLWESecret<&mut [u8]>, &mut Self); - fn tmp_fourier_glwe_secret(&mut self, module: &Module, rank: usize) -> (FourierGLWESecret<&mut [u8], B>, &mut Self); - fn tmp_glwe_pk( - &mut self, - module: &Module, - basek: usize, - k: usize, - rank: usize, - ) -> (GLWEPublicKey<&mut [u8], B>, &mut Self); - fn tmp_glwe_ksk( - &mut self, - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> (GLWESwitchingKey<&mut [u8], B>, &mut Self); - fn tmp_tsk( - &mut self, - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank: usize, - ) -> (GLWETensorKey<&mut [u8], B>, &mut Self); - fn tmp_autokey( - &mut self, - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank: usize, - ) -> (GLWEAutomorphismKey<&mut [u8], B>, &mut Self); -} - -impl ScratchCore for Scratch { - fn tmp_glwe_ct( - &mut self, - module: &Module, - basek: usize, - k: usize, - rank: usize, - ) -> (GLWECiphertext<&mut [u8]>, &mut Self) { - let (data, scratch) = self.tmp_vec_znx(module, rank + 1, k.div_ceil(basek)); - (GLWECiphertext { data, basek, k }, scratch) - } - - fn tmp_vec_glwe_ct( - &mut self, - size: usize, - module: &Module, - basek: usize, - k: usize, - rank: usize, - ) -> (Vec>, &mut Self) { - let mut scratch: &mut Scratch = self; - let mut cts: Vec> = Vec::with_capacity(size); - for _ in 0..size { - let (ct, new_scratch) = scratch.tmp_glwe_ct(module, basek, k, rank); - scratch = new_scratch; - cts.push(ct); - } - (cts, scratch) - } - - fn tmp_glwe_pt(&mut self, module: 
&Module, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self) { - let (data, scratch) = self.tmp_vec_znx(module, 1, k.div_ceil(basek)); - (GLWEPlaintext { data, basek, k }, scratch) - } - - fn tmp_gglwe( - &mut self, - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> (GGLWECiphertext<&mut [u8], FFT64>, &mut Self) { - let (data, scratch) = self.tmp_mat_znx_dft( - module, - rows.div_ceil(digits), - rank_in, - rank_out + 1, - k.div_ceil(basek), - ); - ( - GGLWECiphertext { - data: data, - basek: basek, - k, - digits, - }, - scratch, - ) - } - - fn tmp_ggsw( - &mut self, - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank: usize, - ) -> (GGSWCiphertext<&mut [u8], FFT64>, &mut Self) { - let (data, scratch) = self.tmp_mat_znx_dft( - module, - rows.div_ceil(digits), - rank + 1, - rank + 1, - k.div_ceil(basek), - ); - ( - GGSWCiphertext { - data, - basek, - k, - digits, - }, - scratch, - ) - } - - fn tmp_fourier_glwe_ct( - &mut self, - module: &Module, - basek: usize, - k: usize, - rank: usize, - ) -> (FourierGLWECiphertext<&mut [u8], FFT64>, &mut Self) { - let (data, scratch) = self.tmp_vec_znx_dft(module, rank + 1, k.div_ceil(basek)); - (FourierGLWECiphertext { data, basek, k }, scratch) - } - - fn tmp_slice_fourier_glwe_ct( - &mut self, - size: usize, - module: &Module, - basek: usize, - k: usize, - rank: usize, - ) -> (Vec>, &mut Self) { - let mut scratch: &mut Scratch = self; - let mut cts: Vec> = Vec::with_capacity(size); - for _ in 0..size { - let (ct, new_scratch) = scratch.tmp_fourier_glwe_ct(module, basek, k, rank); - scratch = new_scratch; - cts.push(ct); - } - (cts, scratch) - } - - fn tmp_glwe_pk( - &mut self, - module: &Module, - basek: usize, - k: usize, - rank: usize, - ) -> (GLWEPublicKey<&mut [u8], FFT64>, &mut Self) { - let (data, scratch) = self.tmp_fourier_glwe_ct(module, basek, k, rank); - ( - GLWEPublicKey { - data, - dist: Distribution::NONE, - }, - scratch, - ) - } - - fn tmp_glwe_secret(&mut self, module: &Module, rank: usize) -> (GLWESecret<&mut [u8]>, &mut Self) { - let (data, scratch) = self.tmp_scalar_znx(module, rank); - ( - GLWESecret { - data, - dist: Distribution::NONE, - }, - scratch, - ) - } - - fn tmp_fourier_glwe_secret( - &mut self, - module: &Module, - rank: usize, - ) -> (FourierGLWESecret<&mut [u8], FFT64>, &mut Self) { - let (data, scratch) = self.tmp_scalar_znx_dft(module, rank); - ( - FourierGLWESecret { - data, - dist: Distribution::NONE, - }, - scratch, - ) - } - - fn tmp_glwe_ksk( - &mut self, - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> (GLWESwitchingKey<&mut [u8], FFT64>, &mut Self) { - let (data, scratch) = self.tmp_gglwe(module, basek, k, rows, digits, rank_in, rank_out); - ( - GLWESwitchingKey { - key: data, - sk_in_n: 0, - sk_out_n: 0, - }, - scratch, - ) - } - - fn tmp_autokey( - &mut self, - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank: usize, - ) -> (GLWEAutomorphismKey<&mut [u8], FFT64>, &mut Self) { - let (data, scratch) = self.tmp_glwe_ksk(module, basek, k, rows, digits, rank, rank); - (GLWEAutomorphismKey { key: data, p: 0 }, scratch) - } - - fn tmp_tsk( - &mut self, - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank: usize, - ) -> (GLWETensorKey<&mut [u8], FFT64>, &mut Self) { - let mut keys: Vec> = Vec::new(); - let pairs: usize = (((rank + 1) * rank) >> 1).max(1); - - 
let mut scratch: &mut Scratch = self;
-
-        if pairs != 0 {
-            let (gglwe, s) = scratch.tmp_glwe_ksk(module, basek, k, rows, digits, 1, rank);
-            scratch = s;
-            keys.push(gglwe);
-        }
-        for _ in 1..pairs {
-            let (gglwe, s) = scratch.tmp_glwe_ksk(module, basek, k, rows, digits, 1, rank);
-            scratch = s;
-            keys.push(gglwe);
-        }
-        (GLWETensorKey { keys }, scratch)
-    }
-}
diff --git a/core/src/lwe/ciphertext.rs b/core/src/lwe/ciphertext.rs
index 1e97eb4..4e81a02 100644
--- a/core/src/lwe/ciphertext.rs
+++ b/core/src/lwe/ciphertext.rs
@@ -1,8 +1,11 @@
-use backend::{VecZnx, VecZnxToMut, VecZnxToRef};
+use backend::hal::{
+    api::ZnxInfos,
+    layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo},
+};
 
 use crate::{Infos, SetMetaData};
 
-pub struct LWECiphertext<D> {
+pub struct LWECiphertext<D: Data> {
     pub(crate) data: VecZnx<D>,
     pub(crate) k: usize,
     pub(crate) basek: usize,
@@ -18,11 +21,14 @@ impl LWECiphertext<Vec<u8>> {
     }
 }
 
-impl<D> Infos for LWECiphertext<D> {
-    type Inner = VecZnx<D>;
+impl<D: Data> Infos for LWECiphertext<D>
+where
+    VecZnx<D>: ZnxInfos,
+{
+    type Inner = VecZnx<D>;
 
     fn n(&self) -> usize {
-        &self.inner().n - 1
+        &self.inner().n() - 1
     }
 
     fn inner(&self) -> &Self::Inner {
@@ -38,7 +44,7 @@ impl<D> Infos for LWECiphertext<D> {
     }
 }
 
-impl<D: AsMut<[u8]> + AsRef<[u8]>> SetMetaData for LWECiphertext<D> {
+impl<D: DataMut> SetMetaData for LWECiphertext<D> {
     fn set_k(&mut self, k: usize) {
         self.k = k
     }
@@ -52,7 +58,7 @@ pub trait LWECiphertextToRef {
     fn to_ref(&self) -> LWECiphertext<&[u8]>;
 }
 
-impl<D: AsRef<[u8]>> LWECiphertextToRef for LWECiphertext<D> {
+impl<D: DataRef> LWECiphertextToRef for LWECiphertext<D> {
     fn to_ref(&self) -> LWECiphertext<&[u8]> {
         LWECiphertext {
             data: self.data.to_ref(),
@@ -63,10 +69,11 @@ impl<D: AsRef<[u8]>> LWECiphertextToRef for LWECiphertext<D> {
 }
 
 pub trait LWECiphertextToMut {
+    #[allow(dead_code)]
     fn to_mut(&mut self) -> LWECiphertext<&mut [u8]>;
 }
 
-impl<D: AsMut<[u8]> + AsRef<[u8]>> LWECiphertextToMut for LWECiphertext<D> {
+impl<D: DataMut> LWECiphertextToMut for LWECiphertext<D> {
     fn to_mut(&mut self) -> LWECiphertext<&mut [u8]> {
         LWECiphertext {
             data: self.data.to_mut(),
@@ -75,3 +82,21 @@ impl<D: AsMut<[u8]> + AsRef<[u8]>> LWECiphertextToMut for LWECiphertext<D> {
         }
     }
 }
+
+use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
+
+impl<D: DataMut> ReaderFrom for LWECiphertext<D> {
+    fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
+        self.k = reader.read_u64::<LittleEndian>()? as usize;
+        self.basek = reader.read_u64::<LittleEndian>()? as usize;
+        self.data.read_from(reader)
+    }
+}
+
+impl<D: DataRef> WriterTo for LWECiphertext<D> {
+    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
+        writer.write_u64::<LittleEndian>(self.k as u64)?;
+        writer.write_u64::<LittleEndian>(self.basek as u64)?;
+        self.data.write_to(writer)
+    }
+}
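
The new `ReaderFrom`/`WriterTo` impls fix the on-disk layout of an LWE ciphertext: two little-endian `u64` header words (`k`, then `basek`) followed by the `VecZnx` payload. A free-standing sketch of the same header convention, using the real `byteorder` API (the helper names are illustrative, not part of the crate):

    use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};

    // Writes the (k, basek) header exactly as the impl above does.
    fn write_header<W: std::io::Write>(w: &mut W, k: usize, basek: usize) -> std::io::Result<()> {
        w.write_u64::<LittleEndian>(k as u64)?;
        w.write_u64::<LittleEndian>(basek as u64)
    }

    // Reads it back; field order must match the writer.
    fn read_header<R: std::io::Read>(r: &mut R) -> std::io::Result<(usize, usize)> {
        let k = r.read_u64::<LittleEndian>()? as usize;
        let basek = r.read_u64::<LittleEndian>()? as usize;
        Ok((k, basek))
    }
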
as usize; + self.data.read_from(reader) + } +} + +impl WriterTo for LWECiphertext { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u64::(self.k as u64)?; + writer.write_u64::(self.basek as u64)?; + self.data.write_to(writer) + } +} diff --git a/core/src/lwe/decryption.rs b/core/src/lwe/decryption.rs index 3ed9d2b..e27799b 100644 --- a/core/src/lwe/decryption.rs +++ b/core/src/lwe/decryption.rs @@ -1,15 +1,21 @@ -use backend::{ZnxView, ZnxViewMut, alloc_aligned}; +use backend::hal::{ + api::{ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalizeInplace, ZnxView, ZnxViewMut}, + layouts::{Backend, DataMut, DataRef, Module, ScratchOwned}, + oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, +}; use crate::{Infos, LWECiphertext, LWESecret, SetMetaData, lwe::LWEPlaintext}; impl LWECiphertext where - DataSelf: AsRef<[u8]>, + DataSelf: DataRef, { - pub fn decrypt(&self, pt: &mut LWEPlaintext, sk: &LWESecret) + pub fn decrypt(&self, module: &Module, pt: &mut LWEPlaintext, sk: &LWESecret) where - DataPt: AsRef<[u8]> + AsMut<[u8]>, - DataSk: AsRef<[u8]>, + DataPt: DataMut, + DataSk: DataRef, + Module: VecZnxNormalizeInplace, + B: ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { #[cfg(debug_assertions)] { @@ -24,10 +30,12 @@ where .map(|(x, y)| x * y) .sum::(); }); - - let mut tmp_bytes: Vec = alloc_aligned(size_of::()); - pt.data.normalize(self.basek(), 0, &mut tmp_bytes); - + module.vec_znx_normalize_inplace( + self.basek(), + &mut pt.data, + 0, + ScratchOwned::alloc(size_of::()).borrow(), + ); pt.set_basek(self.basek()); pt.set_k(self.k().min(pt.size() * self.basek())); } diff --git a/core/src/lwe/encryption.rs b/core/src/lwe/encryption.rs index 00d814f..6331d3b 100644 --- a/core/src/lwe/encryption.rs +++ b/core/src/lwe/encryption.rs @@ -1,22 +1,28 @@ -use backend::{AddNormal, FillUniform, VecZnx, ZnxView, ZnxViewMut, alloc_aligned}; +use backend::hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddNormal, VecZnxFillUniform, VecZnxNormalizeInplace, ZnxView, ZnxViewMut, + }, + layouts::{Backend, DataMut, DataRef, Module, ScratchOwned, VecZnx}, + oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, +}; use sampling::source::Source; use crate::{Infos, LWECiphertext, LWESecret, SIX_SIGMA, lwe::LWEPlaintext}; -impl LWECiphertext -where - DataSelf: AsMut<[u8]> + AsRef<[u8]>, -{ - pub fn encrypt_sk( +impl LWECiphertext { + pub fn encrypt_sk( &mut self, + module: &Module, pt: &LWEPlaintext, sk: &LWESecret, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, ) where - DataPt: AsRef<[u8]>, - DataSk: AsRef<[u8]>, + DataPt: DataRef, + DataSk: DataRef, + Module: VecZnxFillUniform + VecZnxAddNormal + VecZnxNormalizeInplace, + B: ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { #[cfg(debug_assertions)] { @@ -24,8 +30,10 @@ where } let basek: usize = self.basek(); + let k: usize = self.k(); + + module.vec_znx_fill_uniform(basek, &mut self.data, 0, k, source_xa); - self.data.fill_uniform(basek, 0, self.size(), source_xa); let mut tmp_znx: VecZnx> = VecZnx::>::new::(1, 1, self.size()); let min_size = self.size().min(pt.size()); @@ -47,11 +55,22 @@ where .sum::(); }); - tmp_znx.add_normal(basek, 0, self.k(), source_xe, sigma, sigma * SIX_SIGMA); + module.vec_znx_add_normal( + basek, + &mut self.data, + 0, + k, + source_xe, + sigma, + sigma * SIX_SIGMA, + ); - let mut tmp_bytes: Vec = alloc_aligned(size_of::()); - - tmp_znx.normalize(basek, 0, &mut tmp_bytes); + module.vec_znx_normalize_inplace( + basek, + &mut tmp_znx, + 0, + 
ScratchOwned::alloc(size_of::()).borrow(), + ); (0..self.size()).for_each(|i| { self.data.at_mut(0, i)[0] = tmp_znx.at(0, i)[0]; diff --git a/core/src/lwe/keyswtich.rs b/core/src/lwe/keyswtich.rs index d06b7aa..4e88f85 100644 --- a/core/src/lwe/keyswtich.rs +++ b/core/src/lwe/keyswtich.rs @@ -1,55 +1,137 @@ -use backend::{Backend, FFT64, Module, Scratch, VecZnxOps, ZnxView, ZnxViewMut, ZnxZero}; +use backend::hal::{ + api::{ + MatZnxAlloc, ScalarZnxAllocBytes, ScratchAvailable, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, VecZnxAddScalarInplace, + VecZnxAllocBytes, VecZnxAutomorphismInplace, VecZnxSwithcDegree, ZnxView, ZnxViewMut, ZnxZero, + }, + layouts::{Backend, Data, DataMut, DataRef, Module, ReaderFrom, Scratch, WriterTo}, +}; use sampling::source::Source; -use crate::{FourierGLWESecret, GLWECiphertext, GLWESecret, GLWESwitchingKey, Infos, LWECiphertext, LWESecret, ScratchCore}; +use crate::{ + GGLWEEncryptSkFamily, GGLWEExecLayoutFamily, GLWECiphertext, GLWEKeyswitchFamily, GLWESecret, GLWESecretExec, + GLWESwitchingKey, GLWESwitchingKeyExec, Infos, LWECiphertext, LWESecret, TakeGLWECt, TakeGLWESecret, TakeGLWESecretExec, +}; /// A special [GLWESwitchingKey] required for the conversion from [GLWECiphertext] to [LWECiphertext]. -pub struct GLWEToLWESwitchingKey(GLWESwitchingKey); +#[derive(PartialEq, Eq)] +pub struct GLWEToLWESwitchingKey(GLWESwitchingKey); -impl GLWEToLWESwitchingKey, FFT64> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank: usize) -> Self { - Self(GLWESwitchingKey::alloc(module, basek, k, rows, 1, rank, 1)) - } - - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - FourierGLWESecret::bytes_of(module, rank) - + (GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank, rank) | GLWESecret::bytes_of(module, rank)) +impl ReaderFrom for GLWEToLWESwitchingKey { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.0.read_from(reader) } } -impl + AsRef<[u8]>> GLWEToLWESwitchingKey { - pub fn encrypt_sk( +impl WriterTo for GLWEToLWESwitchingKey { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + self.0.write_to(writer) + } +} + +#[derive(PartialEq, Eq)] +pub struct GLWEToLWESwitchingKeyExec(GLWESwitchingKeyExec); + +impl GLWEToLWESwitchingKeyExec, B> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank_in: usize) -> Self + where + Module: GGLWEExecLayoutFamily, + { + Self(GLWESwitchingKeyExec::alloc( + module, basek, k, rows, 1, rank_in, 1, + )) + } + + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize) -> usize + where + Module: GGLWEExecLayoutFamily, + { + GLWESwitchingKeyExec::, B>::bytes_of(module, basek, k, rows, digits, rank_in, 1) + } + + pub fn from( + module: &Module, + other: &GLWEToLWESwitchingKey, + scratch: &mut Scratch, + ) -> Self + where + Module: GGLWEExecLayoutFamily, + { + let mut ksk_exec: GLWEToLWESwitchingKeyExec, B> = Self::alloc( + module, + other.0.basek(), + other.0.k(), + other.0.rows(), + other.0.rank_in(), + ); + ksk_exec.prepare(module, other, scratch); + ksk_exec + } +} + +impl GLWEToLWESwitchingKeyExec { + pub fn prepare(&mut self, module: &Module, other: &GLWEToLWESwitchingKey, scratch: &mut Scratch) + where + DataOther: DataRef, + Module: GGLWEExecLayoutFamily, + { + self.0.prepare(module, &other.0, scratch); + } +} + +impl GLWEToLWESwitchingKey> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank_in: usize) -> Self + where +
Module: MatZnxAlloc, + { + Self(GLWESwitchingKey::alloc( + module, basek, k, rows, 1, rank_in, 1, + )) + } + + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank_in: usize) -> usize + where + Module: GGLWEEncryptSkFamily + ScalarZnxAllocBytes + VecZnxAllocBytes, + { + GLWESecretExec::bytes_of(module, rank_in) + + (GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank_in, 1) | GLWESecret::bytes_of(module, rank_in)) + } +} + +impl GLWEToLWESwitchingKey { + pub fn encrypt_sk( &mut self, - module: &Module, + module: &Module, sk_lwe: &LWESecret, sk_glwe: &GLWESecret, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where - DLwe: AsRef<[u8]>, - DGlwe: AsRef<[u8]>, + DLwe: DataRef, + DGlwe: DataRef, + Module: GGLWEEncryptSkFamily + + VecZnxAutomorphismInplace + + ScalarZnxAllocBytes + + VecZnxSwithcDegree + + VecZnxAllocBytes + + VecZnxAddScalarInplace, + Scratch: ScratchAvailable + TakeScalarZnx + TakeVecZnxDft + TakeGLWESecretExec + TakeVecZnx, { #[cfg(debug_assertions)] { assert!(sk_lwe.n() <= module.n()); } - let (mut sk_lwe_as_glwe_dft, scratch1) = scratch.tmp_fourier_glwe_secret(module, 1); - - { - let (mut sk_lwe_as_glwe, _) = scratch1.tmp_glwe_secret(module, 1); - sk_lwe_as_glwe.data.zero(); - sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n()].copy_from_slice(sk_lwe.data.at(0, 0)); - module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data, 0); - sk_lwe_as_glwe_dft.set(module, &sk_lwe_as_glwe); - } + let (mut sk_lwe_as_glwe, scratch1) = scratch.take_glwe_secret(module, 1); + sk_lwe_as_glwe.data.zero(); + sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n()].copy_from_slice(sk_lwe.data.at(0, 0)); + module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data, 0); self.0.encrypt_sk( module, sk_glwe, - &sk_lwe_as_glwe_dft, + &sk_lwe_as_glwe, source_xa, source_xe, sigma, @@ -59,38 +141,115 @@ impl + AsRef<[u8]>> GLWEToLWESwitchingKey { } /// A special [GLWESwitchingKey] required for the conversion from [LWECiphertext] to [GLWECiphertext].
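[Editor's note] For orientation before the refactored key types below: this patch splits each switching key into a serializable form (`LWEToGLWESwitchingKey`) and a prepared `Exec` form used at evaluation time. A minimal sketch of the LWE-to-GLWE flow follows, mirroring the `lwe_to_glwe` test later in this patch; the `FFT64` backend and the elided generic parameters (`Vec<u8>`) are assumptions drawn from that test, and `basek`, `k_ksk`, `k_lwe_ct`, `k_glwe_ct`, `rank`, `sigma`, the secrets, and the randomness sources are parameters set up exactly as the test does, not shown here.

```rust
// Sketch only: condensed from the lwe_to_glwe test below (FFT64 backend assumed).
let module = Module::<FFT64>::new(1 << 5);
let mut scratch = ScratchOwned::alloc(
    LWEToGLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank)
        | GLWECiphertext::from_lwe_scratch_space(&module, basek, k_lwe_ct, k_glwe_ct, k_ksk, rank),
);

// Encrypt the switching key once under the LWE and GLWE secrets...
let mut ksk = LWEToGLWESwitchingKey::alloc(&module, basek, k_ksk, lwe_ct.size(), rank);
ksk.encrypt_sk(&module, &sk_lwe, &sk_glwe, &mut source_xa, &mut source_xe, sigma, scratch.borrow());

// ...then prepare it into the Exec layout and key-switch LWE -> GLWE.
let ksk_exec = LWEToGLWESwitchingKeyExec::from(&module, &ksk, scratch.borrow());
glwe_ct.from_lwe(&module, &lwe_ct, &ksk_exec, scratch.borrow());
```

The same two-step pattern (serializable key, then `Exec::from` + `prepare`) applies to `GLWEToLWESwitchingKeyExec` above and `LWESwitchingKeyExec` further down.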
-pub struct LWEToGLWESwitchingKey(GLWESwitchingKey); +#[derive(PartialEq, Eq)] +pub struct LWEToGLWESwitchingKeyExec(GLWESwitchingKeyExec); -impl LWEToGLWESwitchingKey, FFT64> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank: usize) -> Self { - Self(GLWESwitchingKey::alloc(module, basek, k, rows, 1, 1, rank)) +impl LWEToGLWESwitchingKeyExec, B> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank_out: usize) -> Self + where + Module: GGLWEExecLayoutFamily, + { + Self(GLWESwitchingKeyExec::alloc( + module, basek, k, rows, 1, 1, rank_out, + )) } - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, 1, rank) + GLWESecret::bytes_of(module, 1) + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_out: usize) -> usize + where + Module: GGLWEExecLayoutFamily, + { + GLWESwitchingKeyExec::, B>::bytes_of(module, basek, k, rows, digits, 1, rank_out) + } + + pub fn from( + module: &Module, + other: &LWEToGLWESwitchingKey, + scratch: &mut Scratch, + ) -> Self + where + Module: GGLWEExecLayoutFamily, + { + let mut ksk_exec: LWEToGLWESwitchingKeyExec, B> = Self::alloc( + module, + other.0.basek(), + other.0.k(), + other.0.rows(), + other.0.rank(), + ); + ksk_exec.prepare(module, other, scratch); + ksk_exec } } -impl + AsRef<[u8]>> LWEToGLWESwitchingKey { - pub fn encrypt_sk( +impl LWEToGLWESwitchingKeyExec { + pub fn prepare(&mut self, module: &Module, other: &LWEToGLWESwitchingKey, scratch: &mut Scratch) + where + DataOther: DataRef, + Module: GGLWEExecLayoutFamily, + { + self.0.prepare(module, &other.0, scratch); + } +} +#[derive(PartialEq, Eq)] +pub struct LWEToGLWESwitchingKey(GLWESwitchingKey); + +impl ReaderFrom for LWEToGLWESwitchingKey { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.0.read_from(reader) + } +} + +impl WriterTo for LWEToGLWESwitchingKey { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + self.0.write_to(writer) + } +} + +impl LWEToGLWESwitchingKey> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank_out: usize) -> Self + where + Module: MatZnxAlloc, + { + Self(GLWESwitchingKey::alloc( + module, basek, k, rows, 1, 1, rank_out, + )) + } + + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank_out: usize) -> usize + where + Module: GGLWEEncryptSkFamily + ScalarZnxAllocBytes + VecZnxAllocBytes, + { + GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, 1, rank_out) + GLWESecret::bytes_of(module, 1) + } +} + +impl LWEToGLWESwitchingKey { + pub fn encrypt_sk( &mut self, - module: &Module, + module: &Module, sk_lwe: &LWESecret, - sk_glwe: &FourierGLWESecret, + sk_glwe: &GLWESecret, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where - DLwe: AsRef<[u8]>, - DGlwe: AsRef<[u8]>, + DLwe: DataRef, + DGlwe: DataRef, + Module: GGLWEEncryptSkFamily + + VecZnxAutomorphismInplace + + ScalarZnxAllocBytes + + VecZnxSwithcDegree + + VecZnxAllocBytes + + VecZnxAddScalarInplace, + Scratch: ScratchAvailable + TakeScalarZnx + TakeVecZnxDft + TakeGLWESecretExec + TakeVecZnx, { #[cfg(debug_assertions)] { assert!(sk_lwe.n() <= module.n()); } - let (mut sk_lwe_as_glwe, scratch1) = scratch.tmp_glwe_secret(module, 1); + let (mut sk_lwe_as_glwe, scratch1) = scratch.take_glwe_secret(module, 1); sk_lwe_as_glwe.data.at_mut(0, 
0)[..sk_lwe.n()].copy_from_slice(sk_lwe.data.at(0, 0)); sk_lwe_as_glwe.data.at_mut(0, 0)[sk_lwe.n()..].fill(0); module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data, 0); @@ -107,33 +266,96 @@ impl + AsRef<[u8]>> LWEToGLWESwitchingKey { } } -pub struct LWESwitchingKey(GLWESwitchingKey); +#[derive(PartialEq, Eq)] +pub struct LWESwitchingKeyExec(GLWESwitchingKeyExec); -impl LWESwitchingKey, FFT64> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize) -> Self { +impl LWESwitchingKeyExec, B> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize) -> Self + where + Module: GGLWEExecLayoutFamily, + { + Self(GLWESwitchingKeyExec::alloc(module, basek, k, rows, 1, 1, 1)) + } + + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize) -> usize + where + Module: GGLWEExecLayoutFamily, + { + GLWESwitchingKeyExec::, B>::bytes_of(module, basek, k, rows, digits, 1, 1) + } + + pub fn from(module: &Module, other: &LWESwitchingKey, scratch: &mut Scratch) -> Self + where + Module: GGLWEExecLayoutFamily, + { + let mut ksk_exec: LWESwitchingKeyExec, B> = Self::alloc(module, other.0.basek(), other.0.k(), other.0.rows()); + ksk_exec.prepare(module, other, scratch); + ksk_exec + } +} + +impl LWESwitchingKeyExec { + pub fn prepare(&mut self, module: &Module, other: &LWESwitchingKey, scratch: &mut Scratch) + where + DataOther: DataRef, + Module: GGLWEExecLayoutFamily, + { + self.0.prepare(module, &other.0, scratch); + } +} +#[derive(PartialEq, Eq)] +pub struct LWESwitchingKey(GLWESwitchingKey); + +impl ReaderFrom for LWESwitchingKey { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.0.read_from(reader) + } +} + +impl WriterTo for LWESwitchingKey { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + self.0.write_to(writer) + } +} + +impl LWESwitchingKey> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize) -> Self + where + Module: MatZnxAlloc, + { Self(GLWESwitchingKey::alloc(module, basek, k, rows, 1, 1, 1)) } - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize { + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize + where + Module: GGLWEEncryptSkFamily + ScalarZnxAllocBytes + VecZnxAllocBytes, + { GLWESecret::bytes_of(module, 1) - + FourierGLWESecret::bytes_of(module, 1) + + GLWESecretExec::bytes_of(module, 1) + GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, 1, 1) } } -impl + AsRef<[u8]>> LWESwitchingKey { - pub fn encrypt_sk( +impl LWESwitchingKey { + pub fn encrypt_sk( &mut self, - module: &Module, + module: &Module, sk_lwe_in: &LWESecret, sk_lwe_out: &LWESecret, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where - DIn: AsRef<[u8]>, - DOut: AsRef<[u8]>, + DIn: DataRef, + DOut: DataRef, + Module: GGLWEEncryptSkFamily + + VecZnxAutomorphismInplace + + ScalarZnxAllocBytes + + VecZnxSwithcDegree + + VecZnxAllocBytes + + VecZnxAddScalarInplace, + Scratch: ScratchAvailable + TakeScalarZnx + TakeVecZnxDft + TakeGLWESecretExec + TakeVecZnx, { #[cfg(debug_assertions)] { @@ -141,13 +363,13 @@ impl + AsRef<[u8]>> LWESwitchingKey { assert!(sk_lwe_out.n() <= module.n()); } - let (mut sk_in_glwe, scratch1) = scratch.tmp_glwe_secret(module, 1); - let (mut sk_out_glwe, scratch2) = scratch1.tmp_fourier_glwe_secret(module, 1); + let (mut sk_in_glwe, scratch1) = scratch.take_glwe_secret(module, 1); + let (mut sk_out_glwe, scratch2) = 
scratch1.take_glwe_secret(module, 1); + + sk_out_glwe.data.at_mut(0, 0)[..sk_lwe_out.n()].copy_from_slice(sk_lwe_out.data.at(0, 0)); + sk_out_glwe.data.at_mut(0, 0)[sk_lwe_out.n()..].fill(0); + module.vec_znx_automorphism_inplace(-1, &mut sk_out_glwe.data, 0); - sk_in_glwe.data.at_mut(0, 0)[..sk_lwe_out.n()].copy_from_slice(sk_lwe_out.data.at(0, 0)); - sk_in_glwe.data.at_mut(0, 0)[sk_lwe_out.n()..].fill(0); - module.vec_znx_automorphism_inplace(-1, &mut sk_in_glwe.data, 0); - sk_out_glwe.set(module, &sk_in_glwe); sk_in_glwe.data.at_mut(0, 0)[..sk_lwe_in.n()].copy_from_slice(sk_lwe_in.data.at(0, 0)); sk_in_glwe.data.at_mut(0, 0)[sk_lwe_in.n()..].fill(0); module.vec_znx_automorphism_inplace(-1, &mut sk_in_glwe.data, 0); @@ -165,35 +387,38 @@ impl + AsRef<[u8]>> LWESwitchingKey { } impl LWECiphertext> { - pub fn from_glwe_scratch_space( - module: &Module, + pub fn from_glwe_scratch_space( + module: &Module, basek: usize, k_lwe: usize, k_glwe: usize, k_ksk: usize, rank: usize, - ) -> usize { + ) -> usize + where + Module: GLWEKeyswitchFamily + VecZnxAllocBytes, + { GLWECiphertext::bytes_of(module, basek, k_lwe, 1) + GLWECiphertext::keyswitch_scratch_space(module, basek, k_lwe, k_glwe, k_ksk, 1, rank, 1) } - pub fn keyswitch_scratch_space( - module: &Module, + pub fn keyswitch_scratch_space( + module: &Module, basek: usize, k_lwe_out: usize, k_lwe_in: usize, k_ksk: usize, - ) -> usize { + ) -> usize + where + Module: GLWEKeyswitchFamily + ScalarZnxAllocBytes + VecZnxAllocBytes, + { GLWECiphertext::bytes_of(module, basek, k_lwe_out.max(k_lwe_in), 1) + GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_lwe_out, k_ksk, 1, 1) } } -impl + AsMut<[u8]>> LWECiphertext { - pub fn sample_extract(&mut self, a: &GLWECiphertext) - where - DGlwe: AsRef<[u8]>, - { +impl LWECiphertext { + pub fn sample_extract(&mut self, a: &GLWECiphertext) { #[cfg(debug_assertions)] { assert!(self.n() <= a.n()); @@ -210,34 +435,38 @@ impl + AsMut<[u8]>> LWECiphertext { }); } - pub fn from_glwe( + pub fn from_glwe( &mut self, - module: &Module, + module: &Module, a: &GLWECiphertext, - ks: &GLWEToLWESwitchingKey, - scratch: &mut Scratch, + ks: &GLWEToLWESwitchingKeyExec, + scratch: &mut Scratch, ) where - DGlwe: AsRef<[u8]>, - DKs: AsRef<[u8]>, + DGlwe: DataRef, + DKs: DataRef, + Module: GLWEKeyswitchFamily, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { assert_eq!(self.basek(), a.basek()); } - let (mut tmp_glwe, scratch1) = scratch.tmp_glwe_ct(module, a.basek(), self.k(), 1); + let (mut tmp_glwe, scratch1) = scratch.take_glwe_ct(module, a.basek(), self.k(), 1); tmp_glwe.keyswitch(module, a, &ks.0, scratch1); self.sample_extract(&tmp_glwe); } - pub fn keyswitch( + pub fn keyswitch( &mut self, - module: &Module, + module: &Module, a: &LWECiphertext, - ksk: &LWESwitchingKey, - scratch: &mut Scratch, + ksk: &LWESwitchingKeyExec, + scratch: &mut Scratch, ) where - A: AsRef<[u8]>, - DKs: AsRef<[u8]>, + A: DataRef, + DKs: DataRef, + Module: GLWEKeyswitchFamily, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { @@ -249,7 +478,7 @@ impl + AsMut<[u8]>> LWECiphertext { let max_k: usize = self.k().max(a.k()); let basek: usize = self.basek(); - let (mut glwe, scratch1) = scratch.tmp_glwe_ct(&module, basek, max_k, 1); + let (mut glwe, scratch1) = scratch.take_glwe_ct(&module, basek, max_k, 1); glwe.data.zero(); let n_lwe: usize = a.n(); @@ -267,29 +496,34 @@ impl + AsMut<[u8]>> LWECiphertext { } impl GLWECiphertext> { - pub fn from_lwe_scratch_space( 
- module: &Module, + pub fn from_lwe_scratch_space( + module: &Module, basek: usize, k_lwe: usize, k_glwe: usize, k_ksk: usize, rank: usize, - ) -> usize { + ) -> usize + where + Module: GLWEKeyswitchFamily + VecZnxAllocBytes, + { GLWECiphertext::keyswitch_scratch_space(module, basek, k_glwe, k_lwe, k_ksk, 1, 1, rank) + GLWECiphertext::bytes_of(module, basek, k_lwe, 1) } } -impl + AsMut<[u8]>> GLWECiphertext { - pub fn from_lwe( +impl GLWECiphertext { + pub fn from_lwe( &mut self, - module: &Module, + module: &Module, lwe: &LWECiphertext, - ksk: &LWEToGLWESwitchingKey, - scratch: &mut Scratch, + ksk: &LWEToGLWESwitchingKeyExec, + scratch: &mut Scratch, ) where - DLwe: AsRef<[u8]>, - DKsk: AsRef<[u8]>, + DLwe: DataRef, + DKsk: DataRef, + Module: GLWEKeyswitchFamily, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { @@ -297,7 +531,7 @@ impl + AsMut<[u8]>> GLWECiphertext { assert_eq!(self.basek(), self.basek()); } - let (mut glwe, scratch1) = scratch.tmp_glwe_ct(module, lwe.basek(), lwe.k(), 1); + let (mut glwe, scratch1) = scratch.take_glwe_ct(module, lwe.basek(), lwe.k(), 1); glwe.data.zero(); let n_lwe: usize = lwe.n(); diff --git a/core/src/lwe/plaintext.rs b/core/src/lwe/plaintext.rs index 7c73351..3b76d0e 100644 --- a/core/src/lwe/plaintext.rs +++ b/core/src/lwe/plaintext.rs @@ -1,8 +1,8 @@ -use backend::{VecZnx, VecZnxToMut, VecZnxToRef}; +use backend::hal::layouts::{Data, DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef}; use crate::{Infos, SetMetaData}; -pub struct LWEPlaintext { +pub struct LWEPlaintext { pub(crate) data: VecZnx, pub(crate) k: usize, pub(crate) basek: usize, @@ -18,8 +18,8 @@ impl LWEPlaintext> { } } -impl Infos for LWEPlaintext { - type Inner = VecZnx; +impl Infos for LWEPlaintext { + type Inner = VecZnx; fn inner(&self) -> &Self::Inner { &self.data @@ -34,7 +34,7 @@ impl Infos for LWEPlaintext { } } -impl + AsRef<[u8]>> SetMetaData for LWEPlaintext { +impl SetMetaData for LWEPlaintext { fn set_k(&mut self, k: usize) { self.k = k } @@ -45,10 +45,11 @@ impl + AsRef<[u8]>> SetMetaData for LWEPlaintext } pub trait LWEPlaintextToRef { + #[allow(dead_code)] fn to_ref(&self) -> LWEPlaintext<&[u8]>; } -impl> LWEPlaintextToRef for LWEPlaintext { +impl LWEPlaintextToRef for LWEPlaintext { fn to_ref(&self) -> LWEPlaintext<&[u8]> { LWEPlaintext { data: self.data.to_ref(), @@ -59,10 +60,11 @@ impl> LWEPlaintextToRef for LWEPlaintext { } pub trait LWEPlaintextToMut { + #[allow(dead_code)] fn to_mut(&mut self) -> LWEPlaintext<&mut [u8]>; } -impl + AsRef<[u8]>> LWEPlaintextToMut for LWEPlaintext { +impl LWEPlaintextToMut for LWEPlaintext { fn to_mut(&mut self) -> LWEPlaintext<&mut [u8]> { LWEPlaintext { data: self.data.to_mut(), diff --git a/core/src/lwe/secret.rs b/core/src/lwe/secret.rs index 90776a7..0c9aea1 100644 --- a/core/src/lwe/secret.rs +++ b/core/src/lwe/secret.rs @@ -1,10 +1,13 @@ -use backend::{ScalarZnx, ZnxInfos, ZnxZero}; +use backend::hal::{ + api::{ZnxInfos, ZnxZero}, + layouts::{Data, DataMut, ScalarZnx}, +}; use sampling::source::Source; use crate::Distribution; -pub struct LWESecret { - pub(crate) data: ScalarZnx, +pub struct LWESecret { + pub(crate) data: ScalarZnx, pub(crate) dist: Distribution, } @@ -17,7 +20,7 @@ impl LWESecret> { } } -impl LWESecret { +impl LWESecret { pub fn n(&self) -> usize { self.data.n() } @@ -31,7 +34,7 @@ impl LWESecret { } } -impl + AsMut<[u8]>> LWESecret { +impl LWESecret { pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { self.data.fill_ternary_prob(0, prob, source); 
self.dist = Distribution::TernaryProb(prob); diff --git a/core/src/lwe/test_fft64/conversion.rs b/core/src/lwe/test_fft64/conversion.rs index 1fbd4cb..b403146 100644 --- a/core/src/lwe/test_fft64/conversion.rs +++ b/core/src/lwe/test_fft64/conversion.rs @@ -1,18 +1,67 @@ -use backend::{Encoding, FFT64, Module, ScratchOwned, ZnxView}; +use backend::{ + hal::{ + api::{ + MatZnxAlloc, ModuleNew, ScalarZnxAlloc, ScalarZnxAllocBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, + VecZnxAddScalarInplace, VecZnxAlloc, VecZnxAllocBytes, VecZnxAutomorphismInplace, VecZnxEncodeCoeffsi64, + VecZnxSwithcDegree, ZnxView, + }, + layouts::{Backend, Module, ScratchOwned}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, + }, + }, + implementation::cpu_spqlios::FFT64, +}; use sampling::source::Source; use crate::{ - FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, LWECiphertext, LWESecret, + GGLWEEncryptSkFamily, GGLWEExecLayoutFamily, GLWECiphertext, GLWEDecryptFamily, GLWEKeyswitchFamily, GLWEPlaintext, + GLWESecret, GLWESecretExec, Infos, LWECiphertext, LWESecret, lwe::{ LWEPlaintext, - keyswtich::{GLWEToLWESwitchingKey, LWESwitchingKey, LWEToGLWESwitchingKey}, + keyswtich::{ + GLWEToLWESwitchingKey, GLWEToLWESwitchingKeyExec, LWESwitchingKey, LWESwitchingKeyExec, LWEToGLWESwitchingKey, + LWEToGLWESwitchingKeyExec, + }, }, }; #[test] fn lwe_to_glwe() { - let n: usize = 1 << 5; - let module: Module = Module::::new(n); + let log_n: usize = 5; + let module: Module = Module::::new(1 << log_n); + test_lwe_to_glwe(&module) +} + +pub(crate) trait LWETestModuleFamily = GGLWEEncryptSkFamily + + GLWEDecryptFamily + + VecZnxSwithcDegree + + VecZnxAddScalarInplace + + VecZnxAlloc + + GGLWEExecLayoutFamily + + GLWEKeyswitchFamily + + ScalarZnxAllocBytes + + VecZnxAllocBytes + + ScalarZnxAlloc + + VecZnxEncodeCoeffsi64 + + MatZnxAlloc + + VecZnxAutomorphismInplace; + +pub(crate) trait LWETestScratchFamily = TakeScalarZnxImpl + + TakeVecZnxDftImpl + + ScratchAvailableImpl + + TakeVecZnxImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl; + +pub(crate) fn test_lwe_to_glwe(module: &Module) +where + Module: LWETestModuleFamily, + B: LWETestScratchFamily, +{ let basek: usize = 17; let sigma: f64 = 3.2; @@ -30,56 +79,71 @@ fn lwe_to_glwe() { let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::new( - LWEToGLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) - | GLWECiphertext::from_lwe_scratch_space(&module, basek, k_lwe_ct, k_glwe_ct, k_ksk, rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, k_glwe_ct), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + LWEToGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) + | GLWECiphertext::from_lwe_scratch_space(module, basek, k_lwe_ct, k_glwe_ct, k_ksk, rank) + | GLWECiphertext::decrypt_scratch_space(module, basek, k_glwe_ct), ); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&module, rank); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(module, rank); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_glwe_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::alloc(&module, rank); - sk_glwe_dft.set(&module, &sk_glwe); + let sk_glwe_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_glwe); - let mut sk_lwe = 
LWESecret::alloc(n_lwe); + let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe); sk_lwe.fill_ternary_prob(0.5, &mut source_xs); let data: i64 = 17; let mut lwe_pt: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe_pt); - lwe_pt - .data - .encode_coeff_i64(0, basek, k_lwe_pt, 0, data, k_lwe_pt); + module.encode_coeff_i64(basek, &mut lwe_pt.data, 0, k_lwe_pt, 0, data, k_lwe_pt); let mut lwe_ct: LWECiphertext> = LWECiphertext::alloc(n_lwe, basek, k_lwe_ct); - lwe_ct.encrypt_sk(&lwe_pt, &sk_lwe, &mut source_xa, &mut source_xe, sigma); + lwe_ct.encrypt_sk( + module, + &lwe_pt, + &sk_lwe, + &mut source_xa, + &mut source_xe, + sigma, + ); - let mut ksk: LWEToGLWESwitchingKey, FFT64> = LWEToGLWESwitchingKey::alloc(&module, basek, k_ksk, lwe_ct.size(), rank); + let mut ksk: LWEToGLWESwitchingKey> = LWEToGLWESwitchingKey::alloc(module, basek, k_ksk, lwe_ct.size(), rank); ksk.encrypt_sk( - &module, + module, &sk_lwe, - &sk_glwe_dft, + &sk_glwe, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - let mut glwe_ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_glwe_ct, rank); - glwe_ct.from_lwe(&module, &lwe_ct, &ksk, scratch.borrow()); + let mut glwe_ct: GLWECiphertext> = GLWECiphertext::alloc(module, basek, k_glwe_ct, rank); - let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_glwe_ct); - glwe_ct.decrypt(&module, &mut glwe_pt, &sk_glwe_dft, scratch.borrow()); + let ksk_exec: LWEToGLWESwitchingKeyExec, B> = LWEToGLWESwitchingKeyExec::from(module, &ksk, scratch.borrow()); + + glwe_ct.from_lwe(module, &lwe_ct, &ksk_exec, scratch.borrow()); + + let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc(module, basek, k_glwe_ct); + glwe_ct.decrypt(module, &mut glwe_pt, &sk_glwe_exec, scratch.borrow()); assert_eq!(glwe_pt.data.at(0, 0)[0], lwe_pt.data.at(0, 0)[0]); } #[test] fn glwe_to_lwe() { - let n: usize = 1 << 5; - let module: Module = Module::::new(n); + let log_n: usize = 5; + let module: Module = Module::::new(1 << log_n); + test_glwe_to_lwe(&module) +} + +fn test_glwe_to_lwe(module: &Module) +where + Module: LWETestModuleFamily, + B: LWETestScratchFamily, +{ let basek: usize = 17; let sigma: f64 = 3.2; @@ -97,44 +161,39 @@ fn glwe_to_lwe() { let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::new( - LWEToGLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) - | GLWECiphertext::from_lwe_scratch_space(&module, basek, k_lwe_ct, k_glwe_ct, k_ksk, rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, k_glwe_ct), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + LWEToGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) + | GLWECiphertext::from_lwe_scratch_space(module, basek, k_lwe_ct, k_glwe_ct, k_ksk, rank) + | GLWECiphertext::decrypt_scratch_space(module, basek, k_glwe_ct), ); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&module, rank); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(module, rank); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_glwe_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::alloc(&module, rank); - sk_glwe_dft.set(&module, &sk_glwe); + let sk_glwe_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_glwe); let mut sk_lwe = LWESecret::alloc(n_lwe); sk_lwe.fill_ternary_prob(0.5, &mut source_xs); let data: i64 = 17; - let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_glwe_ct); + let mut glwe_pt: GLWEPlaintext> = 
GLWEPlaintext::alloc(module, basek, k_glwe_ct); + module.encode_coeff_i64(basek, &mut glwe_pt.data, 0, k_lwe_pt, 0, data, k_lwe_pt); - glwe_pt - .data - .encode_coeff_i64(0, basek, k_lwe_pt, 0, data, k_lwe_pt); - - let mut glwe_ct = GLWECiphertext::alloc(&module, basek, k_glwe_ct, rank); + let mut glwe_ct = GLWECiphertext::alloc(module, basek, k_glwe_ct, rank); glwe_ct.encrypt_sk( - &module, + module, &glwe_pt, - &sk_glwe_dft, + &sk_glwe_exec, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - let mut ksk: GLWEToLWESwitchingKey, FFT64> = - GLWEToLWESwitchingKey::alloc(&module, basek, k_ksk, glwe_ct.size(), rank); + let mut ksk: GLWEToLWESwitchingKey> = GLWEToLWESwitchingKey::alloc(module, basek, k_ksk, glwe_ct.size(), rank); ksk.encrypt_sk( - &module, + module, &sk_lwe, &sk_glwe, &mut source_xa, @@ -144,18 +203,29 @@ fn glwe_to_lwe() { ); let mut lwe_ct: LWECiphertext> = LWECiphertext::alloc(n_lwe, basek, k_lwe_ct); - lwe_ct.from_glwe(&module, &glwe_ct, &ksk, scratch.borrow()); + + let ksk_exec: GLWEToLWESwitchingKeyExec, B> = GLWEToLWESwitchingKeyExec::from(module, &ksk, scratch.borrow()); + + lwe_ct.from_glwe(module, &glwe_ct, &ksk_exec, scratch.borrow()); let mut lwe_pt: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe_ct); - lwe_ct.decrypt(&mut lwe_pt, &sk_lwe); + lwe_ct.decrypt(module, &mut lwe_pt, &sk_lwe); assert_eq!(glwe_pt.data.at(0, 0)[0], lwe_pt.data.at(0, 0)[0]); } #[test] fn keyswitch() { - let n: usize = 1 << 5; - let module: Module = Module::::new(n); + let log_n: usize = 5; + let module: Module = Module::::new(1 << log_n); + test_keyswitch(&module) +} + +fn test_keyswitch(module: &Module) +where + Module: LWETestModuleFamily, + B: LWETestScratchFamily, +{ let basek: usize = 17; let sigma: f64 = 3.2; @@ -170,9 +240,9 @@ fn keyswitch() { let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::new( - LWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk) - | LWECiphertext::keyswitch_scratch_space(&module, basek, k_lwe_ct, k_lwe_ct, k_ksk), + let mut scratch: ScratchOwned = ScratchOwned::alloc( + LWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk) + | LWECiphertext::keyswitch_scratch_space(module, basek, k_lwe_ct, k_lwe_ct, k_ksk), ); let mut sk_lwe_in: LWESecret> = LWESecret::alloc(n_lwe_in); @@ -184,12 +254,11 @@ fn keyswitch() { let data: i64 = 17; let mut lwe_pt_in: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe_pt); - lwe_pt_in - .data - .encode_coeff_i64(0, basek, k_lwe_pt, 0, data, k_lwe_pt); + module.encode_coeff_i64(basek, &mut lwe_pt_in.data, 0, k_lwe_pt, 0, data, k_lwe_pt); let mut lwe_ct_in: LWECiphertext> = LWECiphertext::alloc(n_lwe_in, basek, k_lwe_ct); lwe_ct_in.encrypt_sk( + module, &lwe_pt_in, &sk_lwe_in, &mut source_xa, @@ -197,10 +266,10 @@ fn keyswitch() { sigma, ); - let mut ksk: LWESwitchingKey, FFT64> = LWESwitchingKey::alloc(&module, basek, k_ksk, lwe_ct_in.size()); + let mut ksk: LWESwitchingKey> = LWESwitchingKey::alloc(module, basek, k_ksk, lwe_ct_in.size()); ksk.encrypt_sk( - &module, + module, &sk_lwe_in, &sk_lwe_out, &mut source_xa, @@ -211,10 +280,12 @@ fn keyswitch() { let mut lwe_ct_out: LWECiphertext> = LWECiphertext::alloc(n_lwe_out, basek, k_lwe_ct); - lwe_ct_out.keyswitch(&module, &lwe_ct_in, &ksk, scratch.borrow()); + let ksk_exec: LWESwitchingKeyExec, B> = LWESwitchingKeyExec::from(module, &ksk, scratch.borrow()); + + lwe_ct_out.keyswitch(module, &lwe_ct_in, &ksk_exec, scratch.borrow()); let mut lwe_pt_out: 
LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe_ct); - lwe_ct_out.decrypt(&mut lwe_pt_out, &sk_lwe_out); + lwe_ct_out.decrypt(module, &mut lwe_pt_out, &sk_lwe_out); assert_eq!(lwe_pt_in.data.at(0, 0)[0], lwe_pt_out.data.at(0, 0)[0]); } diff --git a/core/src/noise.rs b/core/src/noise.rs index cfc7698..e6901ac 100644 --- a/core/src/noise.rs +++ b/core/src/noise.rs @@ -12,7 +12,7 @@ pub(crate) fn var_noise_gglwe_product( b_logq: usize, ) -> f64 { let a_logq: usize = a_logq.min(b_logq); - let a_cols: usize = (a_logq + basek - 1) / basek; + let a_cols: usize = a_logq.div_ceil(basek); let b_scale: f64 = (b_logq as f64).exp2(); let a_scale: f64 = ((b_logq - a_logq) as f64).exp2(); @@ -73,7 +73,7 @@ pub(crate) fn noise_ggsw_product( k_ggsw: usize, ) -> f64 { let a_logq: usize = k_in.min(k_ggsw); - let a_cols: usize = (a_logq + basek - 1) / basek; + let a_cols: usize = a_logq.div_ceil(basek); let b_scale: f64 = (k_ggsw as f64).exp2(); let a_scale: f64 = ((k_ggsw - a_logq) as f64).exp2(); diff --git a/core/src/scratch.rs b/core/src/scratch.rs new file mode 100644 index 0000000..43d18e4 --- /dev/null +++ b/core/src/scratch.rs @@ -0,0 +1,947 @@ +use backend::hal::{ + api::{TakeMatZnx, TakeScalarZnx, TakeSvpPPol, TakeVecZnx, TakeVecZnxDft, TakeVmpPMat}, + layouts::{Backend, DataRef, Module, Scratch}, + oep::{TakeMatZnxImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, TakeVmpPMatImpl}, +}; + +use crate::{ + AutomorphismKey, AutomorphismKeyExec, GGLWECiphertext, GGLWECiphertextExec, GGSWCiphertext, GGSWCiphertextExec, + GLWECiphertext, GLWEPlaintext, GLWEPublicKey, GLWEPublicKeyExec, GLWESecret, GLWESecretExec, GLWESwitchingKey, + GLWESwitchingKeyExec, GLWETensorKey, GLWETensorKeyExec, Infos, dist::Distribution, +}; + +pub trait TakeLike<'a, B: Backend, T> { + type Output; + fn take_like(&'a mut self, template: &T) -> (Self::Output, &'a mut Self); +} + +pub trait TakeGLWECt { + fn take_glwe_ct(&mut self, module: &Module, basek: usize, k: usize, rank: usize) + -> (GLWECiphertext<&mut [u8]>, &mut Self); +} + +pub trait TakeGLWECtSlice { + fn take_glwe_ct_slice( + &mut self, + size: usize, + module: &Module, + basek: usize, + k: usize, + rank: usize, + ) -> (Vec>, &mut Self); +} + +pub trait TakeGLWEPt { + fn take_glwe_pt(&mut self, module: &Module, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self); +} + +pub trait TakeGGLWE { + fn take_gglwe( + &mut self, + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> (GGLWECiphertext<&mut [u8]>, &mut Self); +} + +pub trait TakeGGLWEExec { + fn take_gglwe_exec( + &mut self, + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> (GGLWECiphertextExec<&mut [u8], B>, &mut Self); +} + +pub trait TakeGGSW { + fn take_ggsw( + &mut self, + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank: usize, + ) -> (GGSWCiphertext<&mut [u8]>, &mut Self); +} + +pub trait TakeGGSWExec { + fn take_ggsw_exec( + &mut self, + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank: usize, + ) -> (GGSWCiphertextExec<&mut [u8], B>, &mut Self); +} + +pub trait TakeGLWESecret { + fn take_glwe_secret(&mut self, module: &Module, rank: usize) -> (GLWESecret<&mut [u8]>, &mut Self); +} + +pub trait TakeGLWESecretExec { + fn take_glwe_secret_exec(&mut self, module: &Module, rank: usize) -> (GLWESecretExec<&mut [u8], B>, &mut Self); +} + +pub trait TakeGLWEPk { + 
fn take_glwe_pk(&mut self, module: &Module, basek: usize, k: usize, rank: usize) -> (GLWEPublicKey<&mut [u8]>, &mut Self); +} + +pub trait TakeGLWEPkExec { + fn take_glwe_pk_exec( + &mut self, + module: &Module, + basek: usize, + k: usize, + rank: usize, + ) -> (GLWEPublicKeyExec<&mut [u8], B>, &mut Self); +} + +pub trait TakeGLWESwitchingKey { + fn take_glwe_switching_key( + &mut self, + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> (GLWESwitchingKey<&mut [u8]>, &mut Self); +} + +pub trait TakeGLWESwitchingKeyExec { + fn take_glwe_switching_key_exec( + &mut self, + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> (GLWESwitchingKeyExec<&mut [u8], B>, &mut Self); +} + +pub trait TakeTensorKey { + fn take_tensor_key( + &mut self, + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank: usize, + ) -> (GLWETensorKey<&mut [u8]>, &mut Self); +} + +pub trait TakeTensorKeyExec { + fn take_tensor_key_exec( + &mut self, + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank: usize, + ) -> (GLWETensorKeyExec<&mut [u8], B>, &mut Self); +} + +pub trait TakeAutomorphismKey { + fn take_automorphism_key( + &mut self, + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank: usize, + ) -> (AutomorphismKey<&mut [u8]>, &mut Self); +} + +pub trait TakeAutomorphismKeyExec { + fn take_automorphism_key_exec( + &mut self, + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank: usize, + ) -> (AutomorphismKeyExec<&mut [u8], B>, &mut Self); +} + +impl TakeGLWECt for Scratch +where + Scratch: TakeVecZnx, +{ + fn take_glwe_ct( + &mut self, + module: &Module, + basek: usize, + k: usize, + rank: usize, + ) -> (GLWECiphertext<&mut [u8]>, &mut Self) { + let (data, scratch) = self.take_vec_znx(module, rank + 1, k.div_ceil(basek)); + (GLWECiphertext { data, basek, k }, scratch) + } +} + +impl<'a, B, D> TakeLike<'a, B, GLWECiphertext> for Scratch +where + B: Backend + TakeVecZnxImpl, + D: DataRef, +{ + type Output = GLWECiphertext<&'a mut [u8]>; + + fn take_like(&'a mut self, template: &GLWECiphertext) -> (Self::Output, &'a mut Self) { + let (data, scratch) = B::take_vec_znx_impl(self, template.n(), template.cols(), template.size()); + ( + GLWECiphertext { + data, + basek: template.basek(), + k: template.k(), + }, + scratch, + ) + } +} + +impl TakeGLWECtSlice for Scratch +where + Scratch: TakeVecZnx, +{ + fn take_glwe_ct_slice( + &mut self, + size: usize, + module: &Module, + basek: usize, + k: usize, + rank: usize, + ) -> (Vec>, &mut Self) { + let mut scratch: &mut Scratch = self; + let mut cts: Vec> = Vec::with_capacity(size); + for _ in 0..size { + let (ct, new_scratch) = scratch.take_glwe_ct(module, basek, k, rank); + scratch = new_scratch; + cts.push(ct); + } + (cts, scratch) + } +} + +impl TakeGLWEPt for Scratch +where + Scratch: TakeVecZnx, +{ + fn take_glwe_pt(&mut self, module: &Module, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self) { + let (data, scratch) = self.take_vec_znx(module, 1, k.div_ceil(basek)); + (GLWEPlaintext { data, basek, k }, scratch) + } +} + +impl<'a, B, D> TakeLike<'a, B, GLWEPlaintext> for Scratch +where + B: Backend + TakeVecZnxImpl, + D: DataRef, +{ + type Output = GLWEPlaintext<&'a mut [u8]>; + + fn take_like(&'a mut self, template: &GLWEPlaintext) -> (Self::Output, &'a mut Self) { + let (data, scratch) = 
B::take_vec_znx_impl(self, template.n(), template.cols(), template.size()); + ( + GLWEPlaintext { + data, + basek: template.basek(), + k: template.k(), + }, + scratch, + ) + } +} + +impl TakeGGLWE for Scratch +where + Scratch: TakeMatZnx, +{ + fn take_gglwe( + &mut self, + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> (GGLWECiphertext<&mut [u8]>, &mut Self) { + let (data, scratch) = self.take_mat_znx( + module, + rows.div_ceil(digits), + rank_in, + rank_out + 1, + k.div_ceil(basek), + ); + ( + GGLWECiphertext { + data: data, + basek: basek, + k, + digits, + }, + scratch, + ) + } +} + +impl<'a, B, D> TakeLike<'a, B, GGLWECiphertext> for Scratch +where + B: Backend + TakeMatZnxImpl, + D: DataRef, +{ + type Output = GGLWECiphertext<&'a mut [u8]>; + + fn take_like(&'a mut self, template: &GGLWECiphertext) -> (Self::Output, &'a mut Self) { + let (data, scratch) = B::take_mat_znx_impl( + self, + template.n(), + template.rows(), + template.data.cols_in(), + template.data.cols_out(), + template.size(), + ); + ( + GGLWECiphertext { + data, + basek: template.basek(), + k: template.k(), + digits: template.digits(), + }, + scratch, + ) + } +} + +impl TakeGGLWEExec for Scratch +where + Scratch: TakeVmpPMat, +{ + fn take_gglwe_exec( + &mut self, + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> (GGLWECiphertextExec<&mut [u8], B>, &mut Self) { + let (data, scratch) = self.take_vmp_pmat( + module, + rows.div_ceil(digits), + rank_in, + rank_out + 1, + k.div_ceil(basek), + ); + ( + GGLWECiphertextExec { + data: data, + basek: basek, + k, + digits, + }, + scratch, + ) + } +} + +impl<'a, B, D> TakeLike<'a, B, GGLWECiphertextExec> for Scratch +where + B: Backend + TakeVmpPMatImpl, + D: DataRef, +{ + type Output = GGLWECiphertextExec<&'a mut [u8], B>; + + fn take_like(&'a mut self, template: &GGLWECiphertextExec) -> (Self::Output, &'a mut Self) { + let (data, scratch) = B::take_vmp_pmat_impl( + self, + template.n(), + template.rows(), + template.data.cols_in(), + template.data.cols_out(), + template.size(), + ); + ( + GGLWECiphertextExec { + data, + basek: template.basek(), + k: template.k(), + digits: template.digits(), + }, + scratch, + ) + } +} + +impl TakeGGSW for Scratch +where + Scratch: TakeMatZnx, +{ + fn take_ggsw( + &mut self, + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank: usize, + ) -> (GGSWCiphertext<&mut [u8]>, &mut Self) { + let (data, scratch) = self.take_mat_znx( + module, + rows.div_ceil(digits), + rank + 1, + rank + 1, + k.div_ceil(basek), + ); + ( + GGSWCiphertext { + data, + basek, + k, + digits, + }, + scratch, + ) + } +} + +impl<'a, B, D> TakeLike<'a, B, GGSWCiphertext> for Scratch +where + B: Backend + TakeMatZnxImpl, + D: DataRef, +{ + type Output = GGSWCiphertext<&'a mut [u8]>; + + fn take_like(&'a mut self, template: &GGSWCiphertext) -> (Self::Output, &'a mut Self) { + let (data, scratch) = B::take_mat_znx_impl( + self, + template.n(), + template.rows(), + template.data.cols_in(), + template.data.cols_out(), + template.size(), + ); + ( + GGSWCiphertext { + data, + basek: template.basek(), + k: template.k(), + digits: template.digits(), + }, + scratch, + ) + } +} + +impl TakeGGSWExec for Scratch +where + Scratch: TakeVmpPMat, +{ + fn take_ggsw_exec( + &mut self, + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank: usize, + ) -> (GGSWCiphertextExec<&mut [u8], B>, &mut 
Self) { + let (data, scratch) = self.take_vmp_pmat( + module, + rows.div_ceil(digits), + rank + 1, + rank + 1, + k.div_ceil(basek), + ); + ( + GGSWCiphertextExec { + data, + basek, + k, + digits, + }, + scratch, + ) + } +} + +impl<'a, B, D> TakeLike<'a, B, GGSWCiphertextExec> for Scratch +where + B: Backend + TakeVmpPMatImpl, + D: DataRef, +{ + type Output = GGSWCiphertextExec<&'a mut [u8], B>; + + fn take_like(&'a mut self, template: &GGSWCiphertextExec) -> (Self::Output, &'a mut Self) { + let (data, scratch) = B::take_vmp_pmat_impl( + self, + template.n(), + template.rows(), + template.data.cols_in(), + template.data.cols_out(), + template.size(), + ); + ( + GGSWCiphertextExec { + data, + basek: template.basek(), + k: template.k(), + digits: template.digits(), + }, + scratch, + ) + } +} + +impl TakeGLWEPk for Scratch +where + Scratch: TakeVecZnx, +{ + fn take_glwe_pk(&mut self, module: &Module, basek: usize, k: usize, rank: usize) -> (GLWEPublicKey<&mut [u8]>, &mut Self) { + let (data, scratch) = self.take_vec_znx(module, rank + 1, k.div_ceil(basek)); + ( + GLWEPublicKey { + data, + k, + basek, + dist: Distribution::NONE, + }, + scratch, + ) + } +} + +impl<'a, B, D> TakeLike<'a, B, GLWEPublicKey> for Scratch +where + B: Backend + TakeVecZnxImpl, + D: DataRef, +{ + type Output = GLWEPublicKey<&'a mut [u8]>; + + fn take_like(&'a mut self, template: &GLWEPublicKey) -> (Self::Output, &'a mut Self) { + let (data, scratch) = B::take_vec_znx_impl(self, template.n(), template.cols(), template.size()); + ( + GLWEPublicKey { + data, + basek: template.basek(), + k: template.k(), + dist: template.dist, + }, + scratch, + ) + } +} + +impl TakeGLWEPkExec for Scratch +where + Scratch: TakeVecZnxDft, +{ + fn take_glwe_pk_exec( + &mut self, + module: &Module, + basek: usize, + k: usize, + rank: usize, + ) -> (GLWEPublicKeyExec<&mut [u8], B>, &mut Self) { + let (data, scratch) = self.take_vec_znx_dft(module, rank + 1, k.div_ceil(basek)); + ( + GLWEPublicKeyExec { + data, + k, + basek, + dist: Distribution::NONE, + }, + scratch, + ) + } +} + +impl<'a, B, D> TakeLike<'a, B, GLWEPublicKeyExec> for Scratch +where + B: Backend + TakeVecZnxDftImpl, + D: DataRef, +{ + type Output = GLWEPublicKeyExec<&'a mut [u8], B>; + + fn take_like(&'a mut self, template: &GLWEPublicKeyExec) -> (Self::Output, &'a mut Self) { + let (data, scratch) = B::take_vec_znx_dft_impl(self, template.n(), template.cols(), template.size()); + ( + GLWEPublicKeyExec { + data, + basek: template.basek(), + k: template.k(), + dist: template.dist, + }, + scratch, + ) + } +} + +impl TakeGLWESecret for Scratch +where + Scratch: TakeScalarZnx, +{ + fn take_glwe_secret(&mut self, module: &Module, rank: usize) -> (GLWESecret<&mut [u8]>, &mut Self) { + let (data, scratch) = self.take_scalar_znx(module, rank); + ( + GLWESecret { + data, + dist: Distribution::NONE, + }, + scratch, + ) + } +} + +impl<'a, B, D> TakeLike<'a, B, GLWESecret> for Scratch +where + B: Backend + TakeScalarZnxImpl, + D: DataRef, +{ + type Output = GLWESecret<&'a mut [u8]>; + + fn take_like(&'a mut self, template: &GLWESecret) -> (Self::Output, &'a mut Self) { + let (data, scratch) = B::take_scalar_znx_impl(self, template.n(), template.rank()); + ( + GLWESecret { + data, + dist: template.dist, + }, + scratch, + ) + } +} + +impl TakeGLWESecretExec for Scratch +where + Scratch: TakeSvpPPol, +{ + fn take_glwe_secret_exec(&mut self, module: &Module, rank: usize) -> (GLWESecretExec<&mut [u8], B>, &mut Self) { + let (data, scratch) = self.take_svp_ppol(module, rank); + ( + GLWESecretExec 
{ + data, + dist: Distribution::NONE, + }, + scratch, + ) + } +} + +impl<'a, B, D> TakeLike<'a, B, GLWESecretExec> for Scratch +where + B: Backend + TakeSvpPPolImpl, + D: DataRef, +{ + type Output = GLWESecretExec<&'a mut [u8], B>; + + fn take_like(&'a mut self, template: &GLWESecretExec) -> (Self::Output, &'a mut Self) { + let (data, scratch) = B::take_svp_ppol_impl(self, template.n(), template.rank()); + ( + GLWESecretExec { + data, + dist: template.dist, + }, + scratch, + ) + } +} + +impl TakeGLWESwitchingKey for Scratch +where + Scratch: TakeMatZnx, +{ + fn take_glwe_switching_key( + &mut self, + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> (GLWESwitchingKey<&mut [u8]>, &mut Self) { + let (data, scratch) = self.take_gglwe(module, basek, k, rows, digits, rank_in, rank_out); + ( + GLWESwitchingKey { + key: data, + sk_in_n: 0, + sk_out_n: 0, + }, + scratch, + ) + } +} + +impl<'a, B, D> TakeLike<'a, B, GLWESwitchingKey> for Scratch +where + Scratch: TakeLike<'a, B, GGLWECiphertext, Output = GGLWECiphertext<&'a mut [u8]>>, + B: Backend + TakeMatZnxImpl, + D: DataRef, +{ + type Output = GLWESwitchingKey<&'a mut [u8]>; + + fn take_like(&'a mut self, template: &GLWESwitchingKey) -> (Self::Output, &'a mut Self) { + let (key, scratch) = self.take_like(&template.key); + ( + GLWESwitchingKey { + key, + sk_in_n: template.sk_in_n, + sk_out_n: template.sk_out_n, + }, + scratch, + ) + } +} + +impl TakeGLWESwitchingKeyExec for Scratch +where + Scratch: TakeGGLWEExec, +{ + fn take_glwe_switching_key_exec( + &mut self, + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> (GLWESwitchingKeyExec<&mut [u8], B>, &mut Self) { + let (data, scratch) = self.take_gglwe_exec(module, basek, k, rows, digits, rank_in, rank_out); + ( + GLWESwitchingKeyExec { + key: data, + sk_in_n: 0, + sk_out_n: 0, + }, + scratch, + ) + } +} + +impl<'a, B, D> TakeLike<'a, B, GLWESwitchingKeyExec> for Scratch +where + Scratch: TakeLike<'a, B, GGLWECiphertextExec, Output = GGLWECiphertextExec<&'a mut [u8], B>>, + B: Backend + TakeMatZnxImpl, + D: DataRef, +{ + type Output = GLWESwitchingKeyExec<&'a mut [u8], B>; + + fn take_like(&'a mut self, template: &GLWESwitchingKeyExec) -> (Self::Output, &'a mut Self) { + let (key, scratch) = self.take_like(&template.key); + ( + GLWESwitchingKeyExec { + key, + sk_in_n: template.sk_in_n, + sk_out_n: template.sk_out_n, + }, + scratch, + ) + } +} + +impl TakeAutomorphismKey for Scratch +where + Scratch: TakeMatZnx, +{ + fn take_automorphism_key( + &mut self, + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank: usize, + ) -> (AutomorphismKey<&mut [u8]>, &mut Self) { + let (data, scratch) = self.take_glwe_switching_key(module, basek, k, rows, digits, rank, rank); + (AutomorphismKey { key: data, p: 0 }, scratch) + } +} + +impl<'a, B, D> TakeLike<'a, B, AutomorphismKey> for Scratch +where + Scratch: TakeLike<'a, B, GLWESwitchingKey, Output = GLWESwitchingKey<&'a mut [u8]>>, + B: Backend + TakeMatZnxImpl, + D: DataRef, +{ + type Output = AutomorphismKey<&'a mut [u8]>; + + fn take_like(&'a mut self, template: &AutomorphismKey) -> (Self::Output, &'a mut Self) { + let (key, scratch) = self.take_like(&template.key); + (AutomorphismKey { key, p: template.p }, scratch) + } +} + +impl TakeAutomorphismKeyExec for Scratch +where + Scratch: TakeGLWESwitchingKeyExec, +{ + fn take_automorphism_key_exec( + &mut self, + module: &Module, + 
basek: usize, + k: usize, + rows: usize, + digits: usize, + rank: usize, + ) -> (AutomorphismKeyExec<&mut [u8], B>, &mut Self) { + let (data, scratch) = self.take_glwe_switching_key_exec(module, basek, k, rows, digits, rank, rank); + (AutomorphismKeyExec { key: data, p: 0 }, scratch) + } +} + +impl<'a, B, D> TakeLike<'a, B, AutomorphismKeyExec> for Scratch +where + Scratch: TakeLike<'a, B, GLWESwitchingKeyExec, Output = GLWESwitchingKeyExec<&'a mut [u8], B>>, + B: Backend + TakeMatZnxImpl, + D: DataRef, +{ + type Output = AutomorphismKeyExec<&'a mut [u8], B>; + + fn take_like(&'a mut self, template: &AutomorphismKeyExec) -> (Self::Output, &'a mut Self) { + let (key, scratch) = self.take_like(&template.key); + (AutomorphismKeyExec { key, p: template.p }, scratch) + } +} + +impl TakeTensorKey for Scratch +where + Scratch: TakeMatZnx, +{ + fn take_tensor_key( + &mut self, + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank: usize, + ) -> (GLWETensorKey<&mut [u8]>, &mut Self) { + let mut keys: Vec> = Vec::new(); + let pairs: usize = (((rank + 1) * rank) >> 1).max(1); + + let mut scratch: &mut Scratch = self; + + if pairs != 0 { + let (gglwe, s) = scratch.take_glwe_switching_key(module, basek, k, rows, digits, 1, rank); + scratch = s; + keys.push(gglwe); + } + for _ in 1..pairs { + let (gglwe, s) = scratch.take_glwe_switching_key(module, basek, k, rows, digits, 1, rank); + scratch = s; + keys.push(gglwe); + } + (GLWETensorKey { keys }, scratch) + } +} + +impl<'a, B, D> TakeLike<'a, B, GLWETensorKey> for Scratch +where + Scratch: TakeLike<'a, B, GLWESwitchingKey, Output = GLWESwitchingKey<&'a mut [u8]>>, + B: Backend + TakeMatZnxImpl, + D: DataRef, +{ + type Output = GLWETensorKey<&'a mut [u8]>; + + fn take_like(&'a mut self, template: &GLWETensorKey) -> (Self::Output, &'a mut Self) { + let mut keys: Vec> = Vec::new(); + let pairs: usize = template.keys.len(); + + let mut scratch: &mut Scratch = self; + + if pairs != 0 { + let (gglwe, s) = scratch.take_like(template.at(0, 0)); + scratch = s; + keys.push(gglwe); + } + for _ in 1..pairs { + let (gglwe, s) = scratch.take_like(template.at(0, 0)); + scratch = s; + keys.push(gglwe); + } + + (GLWETensorKey { keys }, scratch) + } +} + +impl TakeTensorKeyExec for Scratch +where + Scratch: TakeVmpPMat, +{ + fn take_tensor_key_exec( + &mut self, + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank: usize, + ) -> (GLWETensorKeyExec<&mut [u8], B>, &mut Self) { + let mut keys: Vec> = Vec::new(); + let pairs: usize = (((rank + 1) * rank) >> 1).max(1); + + let mut scratch: &mut Scratch = self; + + if pairs != 0 { + let (gglwe, s) = scratch.take_glwe_switching_key_exec(module, basek, k, rows, digits, 1, rank); + scratch = s; + keys.push(gglwe); + } + for _ in 1..pairs { + let (gglwe, s) = scratch.take_glwe_switching_key_exec(module, basek, k, rows, digits, 1, rank); + scratch = s; + keys.push(gglwe); + } + (GLWETensorKeyExec { keys }, scratch) + } +} + +impl<'a, B, D> TakeLike<'a, B, GLWETensorKeyExec> for Scratch +where + Scratch: TakeLike<'a, B, GLWESwitchingKeyExec, Output = GLWESwitchingKeyExec<&'a mut [u8], B>>, + B: Backend + TakeMatZnxImpl, + D: DataRef, +{ + type Output = GLWETensorKeyExec<&'a mut [u8], B>; + + fn take_like(&'a mut self, template: &GLWETensorKeyExec) -> (Self::Output, &'a mut Self) { + let mut keys: Vec> = Vec::new(); + let pairs: usize = template.keys.len(); + + let mut scratch: &mut Scratch = self; + + if pairs != 0 { + let (gglwe, s) = scratch.take_like(template.at(0, 
0)); + scratch = s; + keys.push(gglwe); + } + for _ in 1..pairs { + let (gglwe, s) = scratch.take_like(template.at(0, 0)); + scratch = s; + keys.push(gglwe); + } + + (GLWETensorKeyExec { keys }, scratch) + } +} diff --git a/poulpy.png b/poulpy.png index 264f947..23b4935 100644 Binary files a/poulpy.png and b/poulpy.png differ diff --git a/sampling/src/distributions.rs b/sampling/src/distributions.rs deleted file mode 100644 index 2093bc0..0000000 --- a/sampling/src/distributions.rs +++ /dev/null @@ -1,7 +0,0 @@ -use rand_distr::{Distribution, Normal, Binomial}; - -pub enum Distributions{ - Binonial(Binomial), - Normal(Normal), - Ternary() -} \ No newline at end of file diff --git a/sampling/src/source.rs b/sampling/src/source.rs index fe5f641..d352bfe 100644 --- a/sampling/src/source.rs +++ b/sampling/src/source.rs @@ -1,5 +1,4 @@ -use rand_chacha::ChaCha8Rng; -use rand_chacha::rand_core::SeedableRng; +use rand_chacha::{ChaCha8Rng, rand_core::SeedableRng}; use rand_core::RngCore; const MAXF64: f64 = 9007199254740992.0; diff --git a/utils/src/map.rs b/utils/src/map.rs index 9709ada..94b824f 100644 --- a/utils/src/map.rs +++ b/utils/src/map.rs @@ -1,6 +1,7 @@ -use fnv::FnvHashMap; use std::hash::Hash; +use fnv::FnvHashMap; + pub struct Map(pub FnvHashMap); impl Map {