updated repo for publishing (#74)

Jean-Philippe Bossuat
2025-08-17 14:57:39 +02:00
committed by GitHub
parent 0be569eca0
commit 62eb87cc07
244 changed files with 374 additions and 539 deletions

poulpy-backend/Cargo.toml Normal file

@@ -0,0 +1,27 @@
[package]
name = "poulpy-backend"
version = "0.1.0"
edition = "2024"
license = "Apache-2.0"
readme = "README.md"
description = "A crate implementing bivariate polynomial arithmetic"
repository = "https://github.com/phantomzone-org/poulpy"
homepage = "https://github.com/phantomzone-org/poulpy"
documentation = "https://docs.rs/poulpy"

[dependencies]
rug = { workspace = true }
criterion = { workspace = true }
itertools = { workspace = true }
rand = { workspace = true }
rand_distr = { workspace = true }
rand_core = { workspace = true }
byteorder = { workspace = true }
rand_chacha = "0.9.0"

[build-dependencies]
cmake = "0.1.54"

[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--cfg", "docsrs"]

poulpy-backend/README.md Normal file

@@ -0,0 +1,12 @@
## WSL/Ubuntu
To use this crate you need to build spqlios-arithmetic, which is provided as a git submodule:
1) Initialize the submodule (e.g. `git submodule update --init --recursive`)
2) `cd backend/spqlios-arithmetic`
3) `mkdir build`
4) `cd build`
5) `cmake ..`
6) `make`
## Others
Steps 3 to 6 might change depending on your platform. See [spqlios-arithmetic/wiki/build](https://github.com/tfhe/spqlios-arithmetic/wiki/build) for additional information and build options.

poulpy-backend/build.rs Normal file

@@ -0,0 +1,7 @@
mod builds {
pub mod cpu_spqlios;
}
fn main() {
builds::cpu_spqlios::build()
}


@@ -0,0 +1,12 @@
use std::path::PathBuf;
pub fn build() {
let dst: PathBuf = cmake::Config::new("src/implementation/cpu_spqlios/spqlios-arithmetic")
.define("ENABLE_TESTING", "FALSE")
.build();
let lib_dir: PathBuf = dst.join("lib");
println!("cargo:rustc-link-search=native={}", lib_dir.display());
println!("cargo:rustc-link-lib=static=spqlios");
}


@@ -0,0 +1,27 @@
Implementors must uphold all of the following for **every** call:
* **Memory domains**: Pointers produced by to_ref() / to_mut() must be valid
in the target execution domain for Self (e.g., CPU host memory for CPU,
device memory for a specific GPU). If host↔device transfers are required,
perform them inside the implementation; do not assume the caller has synchronized.
* **Alignment & layout**: All data must match the layout, stride, and element
size expected by the kernel. size(), rows(), cols_in(), cols_out(),
n(), etc. must be interpreted identically to the reference CPU implementation.
* **Scratch lifetime**: Any scratch obtained from scratch.tmp_slice(...) (or a
backend-specific variant) must remain valid for the duration of the call; it
may be reused by the caller afterwards. Do not retain pointers past return.
* **Synchronization**: The call must appear **logically synchronous** to the
caller. If you enqueue asynchronous work (e.g., CUDA streams), you must
ensure completion before returning, or clearly document and implement a
synchronization contract used by all backends consistently.
* **Aliasing & overlaps**: If res, a, b, etc. alias or overlap in ways
that violate your kernel's requirements, you must either handle them safely or
reject them with a defined error path (e.g., a debug assert). Never trigger UB.
* **Numerical contract**: For modular/integer arithmetic, results must be
bit-exact to the specification. For floating-point, any permitted tolerance
must be documented and consistent with the crate's guarantees.
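The aliasing rule is the easiest one to get wrong. Below is a minimal, illustrative sketch of a "defined error path" for overlapping buffers, assuming raw pointers such as those produced by a backend's to_ref()/to_mut(); the function name and signature are hypothetical, not part of the crate's API.

```rust
// Hypothetical kernel: rejects overlap with a debug assert instead of UB.
unsafe fn add_inplace_i64(res: *mut i64, a: *const i64, len: usize) {
    let bytes = len * std::mem::size_of::<i64>();
    let (r, s) = (res as usize, a as usize);
    // Defined error path on overlap (debug builds only), per the contract.
    debug_assert!(r + bytes <= s || s + bytes <= r, "res and a overlap");
    for i in 0..len {
        unsafe { *res.add(i) = (*res.add(i)).wrapping_add(*a.add(i)) };
    }
}

fn main() {
    let mut res = [1i64, 2, 3];
    let a = [10i64, 20, 30];
    unsafe { add_inplace_i64(res.as_mut_ptr(), a.as_ptr(), 3) };
    assert_eq!(res, [11, 22, 33]);
}
```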


@@ -0,0 +1,141 @@
use itertools::izip;
use poulpy_backend::{
hal::{
api::{
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPrepare, VecZnxAddNormal,
VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace,
VecZnxDftAlloc, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigTmpA, VecZnxFillUniform, VecZnxNormalizeInplace, ZnxInfos,
},
layouts::{Module, ScalarZnx, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft},
source::Source,
},
implementation::cpu_spqlios::FFT64,
};
fn main() {
let n: usize = 16;
let basek: usize = 18;
let ct_size: usize = 3;
let msg_size: usize = 2;
let log_scale: usize = msg_size * basek - 5;
let module: Module<FFT64> = Module::<FFT64>::new(n as u64);
let mut scratch: ScratchOwned<FFT64> = ScratchOwned::<FFT64>::alloc(module.vec_znx_big_normalize_tmp_bytes(n));
let seed: [u8; 32] = [0; 32];
let mut source: Source = Source::new(seed);
// s <- Z_{-1, 0, 1}[X]/(X^{N}+1)
let mut s: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(module.n(), 1);
s.fill_ternary_prob(0, 0.5, &mut source);
// Buffer to store s in the DFT domain
let mut s_dft: SvpPPol<Vec<u8>, FFT64> = module.svp_ppol_alloc(n, s.cols());
// s_dft <- DFT(s)
module.svp_prepare(&mut s_dft, 0, &s, 0);
// Allocates a VecZnx with two columns: ct=(0, 0)
let mut ct: VecZnx<Vec<u8>> = VecZnx::alloc(
module.n(),
2, // Number of columns
ct_size, // Number of small poly per column
);
// Fill the second column with random values: ct = (0, a)
module.vec_znx_fill_uniform(basek, &mut ct, 1, ct_size * basek, &mut source);
let mut buf_dft: VecZnxDft<Vec<u8>, FFT64> = module.vec_znx_dft_alloc(n, 1, ct_size);
module.vec_znx_dft_from_vec_znx(1, 0, &mut buf_dft, 0, &ct, 1);
// Applies DFT(ct[1]) * DFT(s)
module.svp_apply_inplace(
&mut buf_dft, // DFT(ct[1] * s)
0, // Selects the first column of res
&s_dft, // DFT(s)
0, // Selects the first column of s_dft
);
// Alias scratch space (VecZnxDft<B> is always at least as big as VecZnxBig<B>)
// BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized)
let mut buf_big: VecZnxBig<Vec<u8>, FFT64> = module.vec_znx_big_alloc(n, 1, ct_size);
module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
// Creates a plaintext: VecZnx with 1 column
let mut m = VecZnx::alloc(
module.n(),
1, // Number of columns
msg_size, // Number of small polynomials
);
let mut want: Vec<i64> = vec![0; n];
want.iter_mut()
.for_each(|x| *x = source.next_u64n(16, 15) as i64);
m.encode_vec_i64(basek, 0, log_scale, &want, 4);
module.vec_znx_normalize_inplace(basek, &mut m, 0, scratch.borrow());
// m - BIG(ct[1] * s)
module.vec_znx_big_sub_small_b_inplace(
&mut buf_big,
0, // Selects the first column of the receiver
&m,
0, // Selects the first column of the message
);
// Normalizes back to VecZnx
// ct[0] <- m - BIG(c1 * s)
module.vec_znx_big_normalize(
basek,
&mut ct,
0, // Selects the first column of ct (ct[0])
&buf_big,
0, // Selects the first column of buf_big
scratch.borrow(),
);
// Add noise to ct[0]
// ct[0] <- ct[0] + e
module.vec_znx_add_normal(
basek,
&mut ct,
0, // Selects the first column of ct (ct[0])
basek * ct_size, // Scaling of the noise: 2^{-basek * limbs}
&mut source,
3.2, // Standard deviation
3.2 * 6.0, // Truncation bound
);
// Final ciphertext: ct = (-a * s + m + e, a)
// Decryption
// DFT(ct[1] * s)
module.vec_znx_dft_from_vec_znx(1, 0, &mut buf_dft, 0, &ct, 1);
module.svp_apply_inplace(
&mut buf_dft,
0, // Selects the first column of res.
&s_dft,
0,
);
// BIG(c1 * s) = IDFT(DFT(c1 * s))
module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
// BIG(c1 * s) + ct[0]
module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0);
// m + e <- BIG(ct[1] * s + ct[0])
let mut res = VecZnx::alloc(module.n(), 1, ct_size);
module.vec_znx_big_normalize(basek, &mut res, 0, &buf_big, 0, scratch.borrow());
// have = m * 2^{log_scale} + e
let mut have: Vec<i64> = vec![i64::default(); n];
res.decode_vec_i64(basek, 0, ct_size * basek, &mut have);
let scale: f64 = (1 << (res.size() * basek - log_scale)) as f64;
izip!(want.iter(), have.iter())
.enumerate()
.for_each(|(i, (a, b))| {
println!("{}: {} {}", i, a, (*b as f64) / scale);
});
}


@@ -0,0 +1,17 @@
mod module;
mod scratch;
mod svp_ppol;
mod vec_znx;
mod vec_znx_big;
mod vec_znx_dft;
mod vmp_pmat;
mod znx_base;
pub use module::*;
pub use scratch::*;
pub use svp_ppol::*;
pub use vec_znx::*;
pub use vec_znx_big::*;
pub use vec_znx_dft::*;
pub use vmp_pmat::*;
pub use znx_base::*;


@@ -0,0 +1,6 @@
use crate::hal::layouts::Backend;
/// Instantiate a new [crate::hal::layouts::Module].
pub trait ModuleNew<B: Backend> {
fn new(n: u64) -> Self;
}
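A minimal construction sketch, assuming the FFT64 CPU backend shipped in this commit:

```rust
use poulpy_backend::{
    hal::{api::ModuleNew, layouts::Module},
    implementation::cpu_spqlios::FFT64,
};

fn main() {
    // n is the ring degree of Z[X]/(X^n + 1); FFT64 selects the CPU spqlios backend.
    let module: Module<FFT64> = Module::<FFT64>::new(16);
    let _ = module;
}
```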


@@ -0,0 +1,107 @@
use crate::hal::layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat};
/// Allocates a new [crate::hal::layouts::ScratchOwned] of `size` aligned bytes.
pub trait ScratchOwnedAlloc<B: Backend> {
fn alloc(size: usize) -> Self;
}
/// Borrows a slice of bytes into a [Scratch].
pub trait ScratchOwnedBorrow<B: Backend> {
fn borrow(&mut self) -> &mut Scratch<B>;
}
/// Wrap an array of mutable borrowed bytes into a [Scratch].
pub trait ScratchFromBytes<B: Backend> {
fn from_bytes(data: &mut [u8]) -> &mut Scratch<B>;
}
/// Returns how many bytes can still be taken from the scratch.
pub trait ScratchAvailable {
fn available(&self) -> usize;
}
/// Takes a slice of bytes from a [Scratch] and returns a new [Scratch] minus the taken array of bytes.
pub trait TakeSlice {
fn take_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self);
}
/// Take a slice of bytes from a [Scratch], wraps it into a [ScalarZnx] and returns it
/// as well as a new [Scratch] minus the taken array of bytes.
pub trait TakeScalarZnx {
fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self);
}
/// Take a slice of bytes from a [Scratch], wraps it into a [SvpPPol] and returns it
/// as well as a new [Scratch] minus the taken array of bytes.
pub trait TakeSvpPPol<B: Backend> {
fn take_svp_ppol(&mut self, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self);
}
/// Take a slice of bytes from a [Scratch], wraps it into a [VecZnx] and returns it
/// as well as a new [Scratch] minus the taken array of bytes.
pub trait TakeVecZnx {
fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self);
}
/// Take a slice of bytes from a [Scratch], slices it into a vector of [VecZnx] and returns it
/// as well as a new [Scratch] minus the taken array of bytes.
pub trait TakeVecZnxSlice {
fn take_vec_znx_slice(&mut self, len: usize, n: usize, cols: usize, size: usize) -> (Vec<VecZnx<&mut [u8]>>, &mut Self);
}
/// Take a slice of bytes from a [Scratch], wraps it into a [VecZnxBig] and returns it
/// as well as a new [Scratch] minus the taken array of bytes.
pub trait TakeVecZnxBig<B: Backend> {
fn take_vec_znx_big(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self);
}
/// Take a slice of bytes from a [Scratch], wraps it into a [VecZnxDft] and returns it
/// as well as a new [Scratch] minus the taken array of bytes.
pub trait TakeVecZnxDft<B: Backend> {
fn take_vec_znx_dft(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self);
}
/// Take a slice of bytes from a [Scratch], slices it into a vector of [VecZnxDft] and returns it
/// as well as a new [Scratch] minus the taken array of bytes.
pub trait TakeVecZnxDftSlice<B: Backend> {
fn take_vec_znx_dft_slice(
&mut self,
len: usize,
n: usize,
cols: usize,
size: usize,
) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Self);
}
/// Take a slice of bytes from a [Scratch], wraps it into a [VmpPMat] and returns it
/// as well as a new [Scratch] minus the taken array of bytes.
pub trait TakeVmpPMat<B: Backend> {
fn take_vmp_pmat(
&mut self,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
) -> (VmpPMat<&mut [u8], B>, &mut Self);
}
/// Take a slice of bytes from a [Scratch], wraps it into a [MatZnx] and returns it
/// as well as a new [Scratch] minus the taken array of bytes.
pub trait TakeMatZnx {
fn take_mat_znx(
&mut self,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
) -> (MatZnx<&mut [u8]>, &mut Self);
}
/// Take a slice of bytes from a [Scratch], wraps it into the template's type and returns it
/// as well as a new [Scratch] minus the taken array of bytes.
pub trait TakeLike<'a, B: Backend, T> {
type Output;
fn take_like(&'a mut self, template: &T) -> (Self::Output, &'a mut Self);
}
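Taken together, these traits form a bump-allocator style API: each take_* call carves a typed view off the front of the scratch and hands back the remainder for further takes. A minimal sketch; the scratch size and shapes are arbitrary, and real callers size the scratch with the *_tmp_bytes / *_alloc_bytes helpers:

```rust
use poulpy_backend::{
    hal::{
        api::{ScratchOwnedAlloc, ScratchOwnedBorrow, TakeSlice, TakeVecZnx},
        layouts::ScratchOwned,
    },
    implementation::cpu_spqlios::FFT64,
};

fn main() {
    // 64 KiB is an arbitrary budget for this sketch.
    let mut scratch: ScratchOwned<FFT64> = ScratchOwned::<FFT64>::alloc(1 << 16);
    let borrowed = scratch.borrow();
    // Each take_* splits off a typed view and returns the shrunken scratch.
    let (tmp, rest) = borrowed.take_slice::<i64>(64);
    let (vec, _rest) = rest.take_vec_znx(16, 1, 3);
    let _ = (tmp, vec);
}
```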


@@ -0,0 +1,42 @@
use crate::hal::layouts::{Backend, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef};
/// Allocates a [crate::hal::layouts::SvpPPol].
pub trait SvpPPolAlloc<B: Backend> {
fn svp_ppol_alloc(&self, n: usize, cols: usize) -> SvpPPolOwned<B>;
}
/// Returns the size in bytes to allocate a [crate::hal::layouts::SvpPPol].
pub trait SvpPPolAllocBytes {
fn svp_ppol_alloc_bytes(&self, n: usize, cols: usize) -> usize;
}
/// Consumes a vector of bytes into a [crate::hal::layouts::SvpPPol].
/// The user must ensure that `bytes` is memory-aligned and that its length is equal to [SvpPPolAllocBytes].
pub trait SvpPPolFromBytes<B: Backend> {
fn svp_ppol_from_bytes(&self, n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B>;
}
/// Prepare a [crate::hal::layouts::ScalarZnx] into an [crate::hal::layouts::SvpPPol].
pub trait SvpPrepare<B: Backend> {
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: SvpPPolToMut<B>,
A: ScalarZnxToRef;
}
/// Apply a scalar-vector product between `a[a_col]` and `b[b_col]` and stores the result on `res[res_col]`.
pub trait SvpApply<B: Backend> {
fn svp_apply<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: SvpPPolToRef<B>,
C: VecZnxDftToRef<B>;
}
/// Apply a scalar-vector product between `res[res_col]` and `a[a_col]` and stores the result on `res[res_col]`.
pub trait SvpApplyInplace<B: Backend> {
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: SvpPPolToRef<B>;
}
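A condensed sketch of the prepare-then-apply flow, drawn from the encryption example earlier in this commit (parameters are arbitrary):

```rust
use poulpy_backend::{
    hal::{
        api::{
            ModuleNew, SvpApplyInplace, SvpPPolAlloc, SvpPrepare, VecZnxDftAlloc,
            VecZnxDftFromVecZnx, ZnxInfos,
        },
        layouts::{Module, ScalarZnx, SvpPPol, VecZnx, VecZnxDft},
        source::Source,
    },
    implementation::cpu_spqlios::FFT64,
};

fn main() {
    let module: Module<FFT64> = Module::<FFT64>::new(16);
    let mut source: Source = Source::new([0u8; 32]);

    // Secret s, prepared once into the DFT domain.
    let mut s: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(module.n(), 1);
    s.fill_ternary_prob(0, 0.5, &mut source);
    let mut s_dft: SvpPPol<Vec<u8>, FFT64> = module.svp_ppol_alloc(module.n(), 1);
    module.svp_prepare(&mut s_dft, 0, &s, 0);

    // A one-column vector taken to the DFT domain, then multiplied by s in place.
    let a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, 3);
    let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.vec_znx_dft_alloc(module.n(), 1, 3);
    module.vec_znx_dft_from_vec_znx(1, 0, &mut a_dft, 0, &a, 0);
    module.svp_apply_inplace(&mut a_dft, 0, &s_dft, 0);
}
```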


@@ -0,0 +1,269 @@
use rand_distr::Distribution;
use crate::hal::{
layouts::{Backend, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
source::Source,
};
pub trait VecZnxNormalizeTmpBytes {
/// Returns the minimum number of bytes necessary for normalization.
fn vec_znx_normalize_tmp_bytes(&self, n: usize) -> usize;
}
pub trait VecZnxNormalize<B: Backend> {
/// Normalizes the selected column of `a` and stores the result into the selected column of `res`.
fn vec_znx_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxNormalizeInplace<B: Backend> {
/// Normalizes the selected column of `a`.
fn vec_znx_normalize_inplace<A>(&self, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut;
}
pub trait VecZnxAdd {
/// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`.
fn vec_znx_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
B: VecZnxToRef;
}
pub trait VecZnxAddInplace {
/// Adds the selected column of `a` to the selected column of `res` and writes the result on the selected column of `res`.
fn vec_znx_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxAddScalarInplace {
/// Adds the selected column of `a` on the selected column and limb of `res`.
fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef;
}
pub trait VecZnxSub {
/// Subtracts the selected column of `b` from the selected column of `a` and writes the result on the selected column of `res`.
fn vec_znx_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
B: VecZnxToRef;
}
pub trait VecZnxSubABInplace {
/// Subtracts the selected column of `a` from the selected column of `res` inplace.
///
/// res\[res_col\] -= a\[a_col\]
fn vec_znx_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxSubBAInplace {
/// Subtracts the selected column of `res` from the selected column of `a` and inplace mutates `res`
///
/// res\[res_col\] = a\[a_col\] - res\[res_col\]
fn vec_znx_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxSubScalarInplace {
/// Subtracts the selected column of `a` on the selected column and limb of `res`.
fn vec_znx_sub_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef;
}
pub trait VecZnxNegate {
/// Negates the selected column of `a` and stores the result in `res_col` of `res`.
fn vec_znx_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxNegateInplace {
/// Negates the selected column of `a`.
fn vec_znx_negate_inplace<A>(&self, a: &mut A, a_col: usize)
where
A: VecZnxToMut;
}
pub trait VecZnxLshInplace {
/// Left shift by k bits all columns of `a`.
fn vec_znx_lsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
where
A: VecZnxToMut;
}
pub trait VecZnxRshInplace {
/// Right shift by k bits all columns of `a`.
fn vec_znx_rsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
where
A: VecZnxToMut;
}
pub trait VecZnxRotate {
/// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`.
fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxRotateInplace {
/// Multiplies the selected column of `a` by X^k.
fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxToMut;
}
pub trait VecZnxAutomorphism {
/// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result in `res_col` column of `res`.
fn vec_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxAutomorphismInplace {
/// Applies the automorphism X^i -> X^ik on the selected column of `a`.
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxToMut;
}
pub trait VecZnxMulXpMinusOne {
fn vec_znx_mul_xp_minus_one<R, A>(&self, p: i64, r: &mut R, r_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxMulXpMinusOneInplace {
fn vec_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, r: &mut R, r_col: usize)
where
R: VecZnxToMut;
}
pub trait VecZnxSplit<B: Backend> {
/// Splits the selected column of `a` into subrings and copies them into the selected column of each [crate::hal::layouts::VecZnx] of `res`.
///
/// # Panics
///
/// This method requires that all [crate::hal::layouts::VecZnx] of `res` have the same ring degree
/// and that res.n() * res.len() <= a.n()
fn vec_znx_split<R, A>(&self, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxMerge {
/// Merges the subrings of the selected column of `a` into the selected column of `res`.
///
/// # Panics
///
/// This method requires that all [crate::hal::layouts::VecZnx] of `a` have the same ring degree
/// and that a.n() * a.len() <= res.n()
fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: &[A], a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxSwithcDegree {
fn vec_znx_switch_degree<R, A>(&self, res: &mut R, res_col: usize, a: &A, col_a: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxCopy {
fn vec_znx_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxFillUniform {
/// Fills the first `k/basek` limbs of the selected column with uniform values in \[-2^{basek-1}, 2^{basek-1}\]
fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
where
R: VecZnxToMut;
}
#[allow(clippy::too_many_arguments)]
pub trait VecZnxFillDistF64 {
fn vec_znx_fill_dist_f64<R, D: Distribution<f64>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: VecZnxToMut;
}
#[allow(clippy::too_many_arguments)]
pub trait VecZnxAddDistF64 {
/// Adds vector sampled according to the provided distribution, scaled by 2^{-k} and bounded to \[-bound, bound\].
fn vec_znx_add_dist_f64<R, D: Distribution<f64>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: VecZnxToMut;
}
#[allow(clippy::too_many_arguments)]
pub trait VecZnxFillNormal {
fn vec_znx_fill_normal<R>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: VecZnxToMut;
}
#[allow(clippy::too_many_arguments)]
pub trait VecZnxAddNormal {
/// Adds a discrete normal vector scaled by 2^{-k} with the provided standard deviation and bounded to \[-bound, bound\].
fn vec_znx_add_normal<R>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: VecZnxToMut;
}
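A minimal sketch combining the sampling and normalization traits above, in the same parameter style as the bundled example:

```rust
use poulpy_backend::{
    hal::{
        api::{
            ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxFillUniform,
            VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes,
        },
        layouts::{Module, ScratchOwned, VecZnx},
        source::Source,
    },
    implementation::cpu_spqlios::FFT64,
};

fn main() {
    let (n, basek, size) = (16usize, 18usize, 3usize);
    let module: Module<FFT64> = Module::<FFT64>::new(n as u64);
    let mut scratch: ScratchOwned<FFT64> =
        ScratchOwned::<FFT64>::alloc(module.vec_znx_normalize_tmp_bytes(n));
    let mut source: Source = Source::new([0u8; 32]);

    // Fill column 0 with k = size * basek bits of uniform randomness,
    // then normalize it in place.
    let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, size);
    module.vec_znx_fill_uniform(basek, &mut a, 0, size * basek, &mut source);
    module.vec_znx_normalize_inplace(basek, &mut a, 0, scratch.borrow());
}
```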


@@ -0,0 +1,220 @@
use rand_distr::Distribution;
use crate::hal::{
layouts::{Backend, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef},
source::Source,
};
/// Allocates a [crate::hal::layouts::VecZnxBig].
pub trait VecZnxBigAlloc<B: Backend> {
fn vec_znx_big_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxBigOwned<B>;
}
/// Returns the size in bytes to allocate a [crate::hal::layouts::VecZnxBig].
pub trait VecZnxBigAllocBytes {
fn vec_znx_big_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize;
}
/// Consumes a vector of bytes into a [crate::hal::layouts::VecZnxBig].
/// The user must ensure that `bytes` is memory-aligned and that its length is equal to [VecZnxBigAllocBytes].
pub trait VecZnxBigFromBytes<B: Backend> {
fn vec_znx_big_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B>;
}
#[allow(clippy::too_many_arguments)]
/// Adds a discrete normal distribution on `res`.
///
/// # Arguments
/// * `basek`: base two logarithm of the bivariate representation.
/// * `res`: receiver.
/// * `res_col`: column of the receiver on which the operation is performed/stored.
/// * `k`: the sampled noise is scaled by 2^{-k}.
/// * `source`: random coin source.
/// * `sigma`: standard deviation of the discrete normal distribution.
/// * `bound`: rejection sampling bound.
pub trait VecZnxBigAddNormal<B: Backend> {
fn vec_znx_big_add_normal<R: VecZnxBigToMut<B>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
);
}
#[allow(clippy::too_many_arguments)]
pub trait VecZnxBigFillNormal<B: Backend> {
fn vec_znx_big_fill_normal<R: VecZnxBigToMut<B>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
);
}
#[allow(clippy::too_many_arguments)]
pub trait VecZnxBigFillDistF64<B: Backend> {
fn vec_znx_big_fill_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
);
}
#[allow(clippy::too_many_arguments)]
pub trait VecZnxBigAddDistF64<B: Backend> {
fn vec_znx_big_add_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
);
}
pub trait VecZnxBigAdd<B: Backend> {
/// Adds `a` to `b` and stores the result on `res`.
fn vec_znx_big_add<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>,
C: VecZnxBigToRef<B>;
}
pub trait VecZnxBigAddInplace<B: Backend> {
/// Adds `a` to `res` and stores the result on `res`.
fn vec_znx_big_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>;
}
pub trait VecZnxBigAddSmall<B: Backend> {
/// Adds `a` to `b` and stores the result on `res`.
fn vec_znx_big_add_small<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>,
C: VecZnxToRef;
}
pub trait VecZnxBigAddSmallInplace<B: Backend> {
/// Adds `a` to `res` and stores the result on `res`.
fn vec_znx_big_add_small_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxToRef;
}
pub trait VecZnxBigSub<B: Backend> {
/// Subtracts `b` from `a` and stores the result on `res`.
fn vec_znx_big_sub<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>,
C: VecZnxBigToRef<B>;
}
pub trait VecZnxBigSubABInplace<B: Backend> {
/// Subtracts `a` from `res` and stores the result on `res`.
fn vec_znx_big_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>;
}
pub trait VecZnxBigSubBAInplace<B: Backend> {
/// Subtracts `res` from `a` and stores the result on `res` (res = a - res).
fn vec_znx_big_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>;
}
pub trait VecZnxBigSubSmallA<B: Backend> {
/// Subtracts `b` from `a` and stores the result on `res`.
fn vec_znx_big_sub_small_a<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxToRef,
C: VecZnxBigToRef<B>;
}
pub trait VecZnxBigSubSmallAInplace<B: Backend> {
/// Subtracts `a` from `res` and stores the result on `res`.
fn vec_znx_big_sub_small_a_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxToRef;
}
pub trait VecZnxBigSubSmallB<B: Backend> {
/// Subtracts `b` from `a` and stores the result on `res`.
fn vec_znx_big_sub_small_b<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>,
C: VecZnxToRef;
}
pub trait VecZnxBigSubSmallBInplace<B: Backend> {
/// Subtracts `res` from `a` and stores the result on `res`.
fn vec_znx_big_sub_small_b_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxToRef;
}
pub trait VecZnxBigNegateInplace<B: Backend> {
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize)
where
A: VecZnxBigToMut<B>;
}
pub trait VecZnxBigNormalizeTmpBytes {
fn vec_znx_big_normalize_tmp_bytes(&self, n: usize) -> usize;
}
pub trait VecZnxBigNormalize<B: Backend> {
fn vec_znx_big_normalize<R, A>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut,
A: VecZnxBigToRef<B>;
}
pub trait VecZnxBigAutomorphism<B: Backend> {
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `res`.
fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>;
}
pub trait VecZnxBigAutomorphismInplace<B: Backend> {
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxBigToMut<B>;
}
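A minimal sketch of the add-small/normalize pattern from the decryption path of the bundled example; parameters are arbitrary, and `small`/`res` simply stay zero here:

```rust
use poulpy_backend::{
    hal::{
        api::{
            ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAddSmallInplace,
            VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
        },
        layouts::{Module, ScratchOwned, VecZnx, VecZnxBig},
    },
    implementation::cpu_spqlios::FFT64,
};

fn main() {
    let (n, basek, size) = (16usize, 18usize, 3usize);
    let module: Module<FFT64> = Module::<FFT64>::new(n as u64);
    let mut scratch: ScratchOwned<FFT64> =
        ScratchOwned::<FFT64>::alloc(module.vec_znx_big_normalize_tmp_bytes(n));
    let small: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, size);
    let mut big: VecZnxBig<Vec<u8>, FFT64> = module.vec_znx_big_alloc(n, 1, size);
    let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, size);
    // Accumulate the small vector into the big accumulator, then normalize
    // the accumulator back into the small representation.
    module.vec_znx_big_add_small_inplace(&mut big, 0, &small, 0);
    module.vec_znx_big_normalize(basek, &mut res, 0, &big, 0, scratch.borrow());
}
```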


@@ -0,0 +1,96 @@
use crate::hal::layouts::{
Backend, Data, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
};
pub trait VecZnxDftAlloc<B: Backend> {
fn vec_znx_dft_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxDftOwned<B>;
}
pub trait VecZnxDftFromBytes<B: Backend> {
fn vec_znx_dft_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;
}
pub trait VecZnxDftAllocBytes {
fn vec_znx_dft_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize;
}
pub trait VecZnxDftToVecZnxBigTmpBytes {
fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self, n: usize) -> usize;
}
pub trait VecZnxDftToVecZnxBig<B: Backend> {
fn vec_znx_dft_to_vec_znx_big<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxBigToMut<B>,
A: VecZnxDftToRef<B>;
}
pub trait VecZnxDftToVecZnxBigTmpA<B: Backend> {
fn vec_znx_dft_to_vec_znx_big_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxDftToMut<B>;
}
pub trait VecZnxDftToVecZnxBigConsume<B: Backend> {
fn vec_znx_dft_to_vec_znx_big_consume<D: Data>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
where
VecZnxDft<D, B>: VecZnxDftToMut<B>;
}
pub trait VecZnxDftAdd<B: Backend> {
fn vec_znx_dft_add<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
D: VecZnxDftToRef<B>;
}
pub trait VecZnxDftAddInplace<B: Backend> {
fn vec_znx_dft_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>;
}
pub trait VecZnxDftSub<B: Backend> {
fn vec_znx_dft_sub<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
D: VecZnxDftToRef<B>;
}
pub trait VecZnxDftSubABInplace<B: Backend> {
fn vec_znx_dft_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>;
}
pub trait VecZnxDftSubBAInplace<B: Backend> {
fn vec_znx_dft_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>;
}
pub trait VecZnxDftCopy<B: Backend> {
fn vec_znx_dft_copy<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>;
}
pub trait VecZnxDftFromVecZnx<B: Backend> {
fn vec_znx_dft_from_vec_znx<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxToRef;
}
pub trait VecZnxDftZero<B: Backend> {
fn vec_znx_dft_zero<R>(&self, res: &mut R)
where
R: VecZnxDftToMut<B>;
}
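A minimal round-trip sketch between the three representations, following the calls made in the bundled example (parameters are arbitrary):

```rust
use poulpy_backend::{
    hal::{
        api::{ModuleNew, VecZnxBigAlloc, VecZnxDftAlloc, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigTmpA},
        layouts::{Module, VecZnx, VecZnxBig, VecZnxDft},
    },
    implementation::cpu_spqlios::FFT64,
};

fn main() {
    let (n, size) = (16usize, 3usize);
    let module: Module<FFT64> = Module::<FFT64>::new(n as u64);
    let a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, size);
    let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.vec_znx_dft_alloc(n, 1, size);
    let mut a_big: VecZnxBig<Vec<u8>, FFT64> = module.vec_znx_big_alloc(n, 1, size);
    // Forward: VecZnx -> VecZnxDft (step = 1, offset = 0, as in the bundled example).
    module.vec_znx_dft_from_vec_znx(1, 0, &mut a_dft, 0, &a, 0);
    // Backward: VecZnxDft -> VecZnxBig, reusing `a_dft` as the temporary (tmp_a variant).
    module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut a_big, 0, &mut a_dft, 0);
}
```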


@@ -0,0 +1,102 @@
use crate::hal::layouts::{
Backend, MatZnxToRef, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef,
};
pub trait VmpPMatAlloc<B: Backend> {
fn vmp_pmat_alloc(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B>;
}
pub trait VmpPMatAllocBytes {
fn vmp_pmat_alloc_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
}
pub trait VmpPMatFromBytes<B: Backend> {
fn vmp_pmat_from_bytes(
&self,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: Vec<u8>,
) -> VmpPMatOwned<B>;
}
pub trait VmpPrepareTmpBytes {
fn vmp_prepare_tmp_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
}
pub trait VmpPrepare<B: Backend> {
fn vmp_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<B>)
where
R: VmpPMatToMut<B>,
A: MatZnxToRef;
}
#[allow(clippy::too_many_arguments)]
pub trait VmpApplyTmpBytes {
fn vmp_apply_tmp_bytes(
&self,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize;
}
pub trait VmpApply<B: Backend> {
/// Applies the vector-matrix product [crate::hal::layouts::VecZnxDft] x [crate::hal::layouts::VmpPMat].
///
/// A vector-matrix product is numerically equivalent to a sum of [crate::hal::api::SvpApply] calls,
/// where each [crate::hal::layouts::SvpPPol] is a limb of the input [crate::hal::layouts::VecZnx] in DFT,
/// and each vector is a [crate::hal::layouts::VecZnxDft] (row) of the [crate::hal::layouts::VmpPMat].
///
/// As such, given an input [crate::hal::layouts::VecZnx] of size `i` and a [crate::hal::layouts::VmpPMat] of `i` rows and
/// size `j`, the output is a [crate::hal::layouts::VecZnx] of size `j`.
///
/// If there is a mismatch between the dimensions, the largest valid ones are used.
///
/// ```text
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
///             |h i j|
///             |k l m|
/// ```
/// where each element is a [crate::hal::layouts::VecZnxDft].
///
/// # Arguments
///
/// * `res`: the output of the vector-matrix product, as a [crate::hal::layouts::VecZnxDft].
/// * `a`: the left operand [crate::hal::layouts::VecZnxDft] of the vector-matrix product.
/// * `b`: the right operand [crate::hal::layouts::VmpPMat] of the vector-matrix product.
/// * `scratch`: scratch space, whose size can be obtained with [VmpApplyTmpBytes::vmp_apply_tmp_bytes].
fn vmp_apply<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
C: VmpPMatToRef<B>;
}
#[allow(clippy::too_many_arguments)]
pub trait VmpApplyAddTmpBytes {
fn vmp_apply_add_tmp_bytes(
&self,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize;
}
pub trait VmpApplyAdd<B: Backend> {
fn vmp_apply_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
C: VmpPMatToRef<B>;
}
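A shape-level sketch of how these pieces fit together, using the scratch take_* API from earlier in this commit. The scratch budget is arbitrary and the taken buffers are left uninitialized, so this only illustrates dimensions and call order; a real caller fills the matrix via VmpPrepare and the input via VecZnxDftFromVecZnx first, and sizes the scratch with vmp_apply_tmp_bytes:

```rust
use poulpy_backend::{
    hal::{
        api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, TakeVecZnxDft, TakeVmpPMat, VmpApply},
        layouts::{Module, ScratchOwned},
    },
    implementation::cpu_spqlios::FFT64,
};

fn main() {
    let (n, rows, cols_in, cols_out, size) = (16usize, 2usize, 1usize, 1usize, 2usize);
    let module: Module<FFT64> = Module::<FFT64>::new(n as u64);
    // Arbitrary scratch budget for this sketch.
    let mut scratch: ScratchOwned<FFT64> = ScratchOwned::<FFT64>::alloc(1 << 20);
    let borrowed = scratch.borrow();
    let (pmat, rest) = borrowed.take_vmp_pmat(n, rows, cols_in, cols_out, size);
    let (a_dft, rest) = rest.take_vec_znx_dft(n, cols_in, rows);
    let (mut res_dft, rest) = rest.take_vec_znx_dft(n, cols_out, size);
    // res <- a x pmat: an input of `rows` limbs against a matrix of `rows` rows
    // yields an output of `size` limbs.
    module.vmp_apply(&mut res_dft, &a_dft, &pmat, rest);
}
```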


@@ -0,0 +1,121 @@
use crate::hal::{
layouts::{Data, DataMut, DataRef},
source::Source,
};
use rand_distr::num_traits::Zero;
pub trait ZnxInfos {
/// Returns the ring degree of the polynomials.
fn n(&self) -> usize;
/// Returns the base two logarithm of the ring dimension of the polynomials.
fn log_n(&self) -> usize {
(usize::BITS - (self.n() - 1).leading_zeros()) as _
}
/// Returns the number of rows.
fn rows(&self) -> usize;
/// Returns the number of polynomials in each row.
fn cols(&self) -> usize;
/// Returns the number of small polynomials per column (the `size`).
fn size(&self) -> usize;
/// Returns the total number of small polynomials.
fn poly_count(&self) -> usize {
self.rows() * self.cols() * self.size()
}
}
pub trait ZnxSliceSize {
/// Returns the slice size, which is the offset between
/// two consecutive small polynomials of the same column.
fn sl(&self) -> usize;
}
pub trait DataView {
type D: Data;
fn data(&self) -> &Self::D;
}
pub trait DataViewMut: DataView {
fn data_mut(&mut self) -> &mut Self::D;
}
pub trait ZnxView: ZnxInfos + DataView<D: DataRef> {
type Scalar: Copy + Zero;
/// Returns a non-mutable pointer to the underlying coefficients array.
fn as_ptr(&self) -> *const Self::Scalar {
self.data().as_ref().as_ptr() as *const Self::Scalar
}
/// Returns a non-mutable reference to the entire underlying coefficient array.
fn raw(&self) -> &[Self::Scalar] {
unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) }
}
/// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column.
fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
#[cfg(debug_assertions)]
{
assert!(i < self.cols(), "cols: {} >= {}", i, self.cols());
assert!(j < self.size(), "size: {} >= {}", j, self.size());
}
let offset: usize = self.n() * (j * self.cols() + i);
unsafe { self.as_ptr().add(offset) }
}
/// Returns non-mutable reference to the (i, j)-th small polynomial.
fn at(&self, i: usize, j: usize) -> &[Self::Scalar] {
unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) }
}
}
pub trait ZnxViewMut: ZnxView + DataViewMut<D: DataMut> {
/// Returns a mutable pointer to the underlying coefficients array.
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
self.data_mut().as_mut().as_mut_ptr() as *mut Self::Scalar
}
/// Returns a mutable reference to the entire underlying coefficient array.
fn raw_mut(&mut self) -> &mut [Self::Scalar] {
unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) }
}
/// Returns a mutable pointer starting at the j-th small polynomial of the i-th column.
fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
#[cfg(debug_assertions)]
{
assert!(i < self.cols(), "cols: {} >= {}", i, self.cols());
assert!(j < self.size(), "size: {} >= {}", j, self.size());
}
let offset: usize = self.n() * (j * self.cols() + i);
unsafe { self.as_mut_ptr().add(offset) }
}
/// Returns mutable reference to the (i, j)-th small polynomial.
fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] {
unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) }
}
}
// (Jay) Note: can't provide a blanket impl of ZnxView because Scalar is not known
impl<T> ZnxViewMut for T where T: ZnxView + DataViewMut<D: DataMut> {}
pub trait ZnxZero
where
Self: Sized,
{
fn zero(&mut self);
fn zero_at(&mut self, i: usize, j: usize);
}
pub trait FillUniform {
fn fill_uniform(&mut self, source: &mut Source);
}
pub trait Reset {
fn reset(&mut self);
}
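The pointer arithmetic in at_ptr/at_mut_ptr fixes the memory layout: limb j of column i starts at coefficient offset n * (j * cols + i), i.e. data is stored limb-major, with the columns of each limb contiguous. A tiny self-contained check of that formula:

```rust
// Worked check of the at_ptr offset formula, independent of the crate.
fn offset(n: usize, cols: usize, i: usize, j: usize) -> usize {
    n * (j * cols + i)
}

fn main() {
    let (n, cols) = (4, 2);
    assert_eq!(offset(n, cols, 0, 0), 0); // column 0, limb 0
    assert_eq!(offset(n, cols, 1, 0), 4); // column 1, limb 0
    assert_eq!(offset(n, cols, 0, 1), 8); // column 0, limb 1: one n * cols stride later
}
```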


@@ -0,0 +1,7 @@
mod module;
mod scratch;
mod svp_ppol;
mod vec_znx;
mod vec_znx_big;
mod vec_znx_dft;
mod vmp_pmat;


@@ -0,0 +1,14 @@
use crate::hal::{
api::ModuleNew,
layouts::{Backend, Module},
oep::ModuleNewImpl,
};
impl<B> ModuleNew<B> for Module<B>
where
B: Backend + ModuleNewImpl<B>,
{
fn new(n: u64) -> Self {
B::new_impl(n)
}
}


@@ -0,0 +1,235 @@
use crate::hal::{
api::{
ScratchAvailable, ScratchFromBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, TakeLike, TakeMatZnx, TakeScalarZnx,
TakeSlice, TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice, TakeVmpPMat,
},
layouts::{Backend, DataRef, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
oep::{
ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeLikeImpl, TakeMatZnxImpl,
TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl,
TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl,
},
};
impl<B> ScratchOwnedAlloc<B> for ScratchOwned<B>
where
B: Backend + ScratchOwnedAllocImpl<B>,
{
fn alloc(size: usize) -> Self {
B::scratch_owned_alloc_impl(size)
}
}
impl<B> ScratchOwnedBorrow<B> for ScratchOwned<B>
where
B: Backend + ScratchOwnedBorrowImpl<B>,
{
fn borrow(&mut self) -> &mut Scratch<B> {
B::scratch_owned_borrow_impl(self)
}
}
impl<B> ScratchFromBytes<B> for Scratch<B>
where
B: Backend + ScratchFromBytesImpl<B>,
{
fn from_bytes(data: &mut [u8]) -> &mut Scratch<B> {
B::scratch_from_bytes_impl(data)
}
}
impl<B> ScratchAvailable for Scratch<B>
where
B: Backend + ScratchAvailableImpl<B>,
{
fn available(&self) -> usize {
B::scratch_available_impl(self)
}
}
impl<B> TakeSlice for Scratch<B>
where
B: Backend + TakeSliceImpl<B>,
{
fn take_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self) {
B::take_slice_impl(self, len)
}
}
impl<B> TakeScalarZnx for Scratch<B>
where
B: Backend + TakeScalarZnxImpl<B>,
{
fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
B::take_scalar_znx_impl(self, n, cols)
}
}
impl<B> TakeSvpPPol<B> for Scratch<B>
where
B: Backend + TakeSvpPPolImpl<B>,
{
fn take_svp_ppol(&mut self, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self) {
B::take_svp_ppol_impl(self, n, cols)
}
}
impl<B> TakeVecZnx for Scratch<B>
where
B: Backend + TakeVecZnxImpl<B>,
{
fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
B::take_vec_znx_impl(self, n, cols, size)
}
}
impl<B> TakeVecZnxSlice for Scratch<B>
where
B: Backend + TakeVecZnxSliceImpl<B>,
{
fn take_vec_znx_slice(&mut self, len: usize, n: usize, cols: usize, size: usize) -> (Vec<VecZnx<&mut [u8]>>, &mut Self) {
B::take_vec_znx_slice_impl(self, len, n, cols, size)
}
}
impl<B> TakeVecZnxBig<B> for Scratch<B>
where
B: Backend + TakeVecZnxBigImpl<B>,
{
fn take_vec_znx_big(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self) {
B::take_vec_znx_big_impl(self, n, cols, size)
}
}
impl<B> TakeVecZnxDft<B> for Scratch<B>
where
B: Backend + TakeVecZnxDftImpl<B>,
{
fn take_vec_znx_dft(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self) {
B::take_vec_znx_dft_impl(self, n, cols, size)
}
}
impl<B> TakeVecZnxDftSlice<B> for Scratch<B>
where
B: Backend + TakeVecZnxDftSliceImpl<B>,
{
fn take_vec_znx_dft_slice(
&mut self,
len: usize,
n: usize,
cols: usize,
size: usize,
) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Self) {
B::take_vec_znx_dft_slice_impl(self, len, n, cols, size)
}
}
impl<B> TakeVmpPMat<B> for Scratch<B>
where
B: Backend + TakeVmpPMatImpl<B>,
{
fn take_vmp_pmat(
&mut self,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
) -> (VmpPMat<&mut [u8], B>, &mut Self) {
B::take_vmp_pmat_impl(self, n, rows, cols_in, cols_out, size)
}
}
impl<B> TakeMatZnx for Scratch<B>
where
B: Backend + TakeMatZnxImpl<B>,
{
fn take_mat_znx(
&mut self,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
) -> (MatZnx<&mut [u8]>, &mut Self) {
B::take_mat_znx_impl(self, n, rows, cols_in, cols_out, size)
}
}
impl<'a, B: Backend, D> TakeLike<'a, B, ScalarZnx<D>> for Scratch<B>
where
B: TakeLikeImpl<'a, B, ScalarZnx<D>, Output = ScalarZnx<&'a mut [u8]>>,
D: DataRef,
{
type Output = ScalarZnx<&'a mut [u8]>;
fn take_like(&'a mut self, template: &ScalarZnx<D>) -> (Self::Output, &'a mut Self) {
B::take_like_impl(self, template)
}
}
impl<'a, B: Backend, D> TakeLike<'a, B, SvpPPol<D, B>> for Scratch<B>
where
B: TakeLikeImpl<'a, B, SvpPPol<D, B>, Output = SvpPPol<&'a mut [u8], B>>,
D: DataRef,
{
type Output = SvpPPol<&'a mut [u8], B>;
fn take_like(&'a mut self, template: &SvpPPol<D, B>) -> (Self::Output, &'a mut Self) {
B::take_like_impl(self, template)
}
}
impl<'a, B: Backend, D> TakeLike<'a, B, VecZnx<D>> for Scratch<B>
where
B: TakeLikeImpl<'a, B, VecZnx<D>, Output = VecZnx<&'a mut [u8]>>,
D: DataRef,
{
type Output = VecZnx<&'a mut [u8]>;
fn take_like(&'a mut self, template: &VecZnx<D>) -> (Self::Output, &'a mut Self) {
B::take_like_impl(self, template)
}
}
impl<'a, B: Backend, D> TakeLike<'a, B, VecZnxBig<D, B>> for Scratch<B>
where
B: TakeLikeImpl<'a, B, VecZnxBig<D, B>, Output = VecZnxBig<&'a mut [u8], B>>,
D: DataRef,
{
type Output = VecZnxBig<&'a mut [u8], B>;
fn take_like(&'a mut self, template: &VecZnxBig<D, B>) -> (Self::Output, &'a mut Self) {
B::take_like_impl(self, template)
}
}
impl<'a, B: Backend, D> TakeLike<'a, B, VecZnxDft<D, B>> for Scratch<B>
where
B: TakeLikeImpl<'a, B, VecZnxDft<D, B>, Output = VecZnxDft<&'a mut [u8], B>>,
D: DataRef,
{
type Output = VecZnxDft<&'a mut [u8], B>;
fn take_like(&'a mut self, template: &VecZnxDft<D, B>) -> (Self::Output, &'a mut Self) {
B::take_like_impl(self, template)
}
}
impl<'a, B: Backend, D> TakeLike<'a, B, MatZnx<D>> for Scratch<B>
where
B: TakeLikeImpl<'a, B, MatZnx<D>, Output = MatZnx<&'a mut [u8]>>,
D: DataRef,
{
type Output = MatZnx<&'a mut [u8]>;
fn take_like(&'a mut self, template: &MatZnx<D>) -> (Self::Output, &'a mut Self) {
B::take_like_impl(self, template)
}
}
impl<'a, B: Backend, D> TakeLike<'a, B, VmpPMat<D, B>> for Scratch<B>
where
B: TakeLikeImpl<'a, B, VmpPMat<D, B>, Output = VmpPMat<&'a mut [u8], B>>,
D: DataRef,
{
type Output = VmpPMat<&'a mut [u8], B>;
fn take_like(&'a mut self, template: &VmpPMat<D, B>) -> (Self::Output, &'a mut Self) {
B::take_like_impl(self, template)
}
}


@@ -0,0 +1,72 @@
use crate::hal::{
api::{SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPPolFromBytes, SvpPrepare},
layouts::{Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef},
oep::{SvpApplyImpl, SvpApplyInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl},
};
impl<B> SvpPPolFromBytes<B> for Module<B>
where
B: Backend + SvpPPolFromBytesImpl<B>,
{
fn svp_ppol_from_bytes(&self, n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B> {
B::svp_ppol_from_bytes_impl(n, cols, bytes)
}
}
impl<B> SvpPPolAlloc<B> for Module<B>
where
B: Backend + SvpPPolAllocImpl<B>,
{
fn svp_ppol_alloc(&self, n: usize, cols: usize) -> SvpPPolOwned<B> {
B::svp_ppol_alloc_impl(n, cols)
}
}
impl<B> SvpPPolAllocBytes for Module<B>
where
B: Backend + SvpPPolAllocBytesImpl<B>,
{
fn svp_ppol_alloc_bytes(&self, n: usize, cols: usize) -> usize {
B::svp_ppol_alloc_bytes_impl(n, cols)
}
}
impl<B> SvpPrepare<B> for Module<B>
where
B: Backend + SvpPrepareImpl<B>,
{
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: SvpPPolToMut<B>,
A: ScalarZnxToRef,
{
B::svp_prepare_impl(self, res, res_col, a, a_col);
}
}
impl<B> SvpApply<B> for Module<B>
where
B: Backend + SvpApplyImpl<B>,
{
fn svp_apply<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: SvpPPolToRef<B>,
C: VecZnxDftToRef<B>,
{
B::svp_apply_impl(self, res, res_col, a, a_col, b, b_col);
}
}
impl<B> SvpApplyInplace<B> for Module<B>
where
B: Backend + SvpApplyInplaceImpl,
{
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: SvpPPolToRef<B>,
{
B::svp_apply_inplace_impl(self, res, res_col, a, a_col);
}
}


@@ -0,0 +1,414 @@
use crate::hal::{
api::{
VecZnxAdd, VecZnxAddDistF64, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism,
VecZnxAutomorphismInplace, VecZnxCopy, VecZnxFillDistF64, VecZnxFillNormal, VecZnxFillUniform, VecZnxLshInplace,
VecZnxMerge, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSplit,
VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalarInplace, VecZnxSwithcDegree,
},
layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
oep::{
VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl,
VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl,
VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl,
VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxSubABInplaceImpl,
VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
},
source::Source,
};
impl<B> VecZnxNormalizeTmpBytes for Module<B>
where
B: Backend + VecZnxNormalizeTmpBytesImpl<B>,
{
fn vec_znx_normalize_tmp_bytes(&self, n: usize) -> usize {
B::vec_znx_normalize_tmp_bytes_impl(self, n)
}
}
impl<B> VecZnxNormalize<B> for Module<B>
where
B: Backend + VecZnxNormalizeImpl<B>,
{
fn vec_znx_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_normalize_impl(self, basek, res, res_col, a, a_col, scratch)
}
}
impl<B> VecZnxNormalizeInplace<B> for Module<B>
where
B: Backend + VecZnxNormalizeInplaceImpl<B>,
{
fn vec_znx_normalize_inplace<A>(&self, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut,
{
B::vec_znx_normalize_inplace_impl(self, basek, a, a_col, scratch)
}
}
impl<B> VecZnxAdd for Module<B>
where
B: Backend + VecZnxAddImpl<B>,
{
fn vec_znx_add<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
C: VecZnxToRef,
{
B::vec_znx_add_impl(self, res, res_col, a, a_col, b, b_col)
}
}
impl<B> VecZnxAddInplace for Module<B>
where
B: Backend + VecZnxAddInplaceImpl<B>,
{
fn vec_znx_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_add_inplace_impl(self, res, res_col, a, a_col)
}
}
impl<B> VecZnxAddScalarInplace for Module<B>
where
B: Backend + VecZnxAddScalarInplaceImpl<B>,
{
fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
{
B::vec_znx_add_scalar_inplace_impl(self, res, res_col, res_limb, a, a_col)
}
}
impl<B> VecZnxSub for Module<B>
where
B: Backend + VecZnxSubImpl<B>,
{
fn vec_znx_sub<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
C: VecZnxToRef,
{
B::vec_znx_sub_impl(self, res, res_col, a, a_col, b, b_col)
}
}
impl<B> VecZnxSubABInplace for Module<B>
where
B: Backend + VecZnxSubABInplaceImpl<B>,
{
fn vec_znx_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_sub_ab_inplace_impl(self, res, res_col, a, a_col)
}
}
impl<B> VecZnxSubBAInplace for Module<B>
where
B: Backend + VecZnxSubBAInplaceImpl<B>,
{
fn vec_znx_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_sub_ba_inplace_impl(self, res, res_col, a, a_col)
}
}
impl<B> VecZnxSubScalarInplace for Module<B>
where
B: Backend + VecZnxSubScalarInplaceImpl<B>,
{
fn vec_znx_sub_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
{
B::vec_znx_sub_scalar_inplace_impl(self, res, res_col, res_limb, a, a_col)
}
}
impl<B> VecZnxNegate for Module<B>
where
B: Backend + VecZnxNegateImpl<B>,
{
fn vec_znx_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_negate_impl(self, res, res_col, a, a_col)
}
}
impl<B> VecZnxNegateInplace for Module<B>
where
B: Backend + VecZnxNegateInplaceImpl<B>,
{
fn vec_znx_negate_inplace<A>(&self, a: &mut A, a_col: usize)
where
A: VecZnxToMut,
{
B::vec_znx_negate_inplace_impl(self, a, a_col)
}
}
impl<B> VecZnxLshInplace for Module<B>
where
B: Backend + VecZnxLshInplaceImpl<B>,
{
fn vec_znx_lsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
where
A: VecZnxToMut,
{
B::vec_znx_lsh_inplace_impl(self, basek, k, a)
}
}
impl<B> VecZnxRshInplace for Module<B>
where
B: Backend + VecZnxRshInplaceImpl<B>,
{
fn vec_znx_rsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
where
A: VecZnxToMut,
{
B::vec_znx_rsh_inplace_impl(self, basek, k, a)
}
}
impl<B> VecZnxRotate for Module<B>
where
B: Backend + VecZnxRotateImpl<B>,
{
fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_rotate_impl(self, k, res, res_col, a, a_col)
}
}
impl<B> VecZnxRotateInplace for Module<B>
where
B: Backend + VecZnxRotateInplaceImpl<B>,
{
fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxToMut,
{
B::vec_znx_rotate_inplace_impl(self, k, a, a_col)
}
}
impl<B> VecZnxAutomorphism for Module<B>
where
B: Backend + VecZnxAutomorphismImpl<B>,
{
fn vec_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_automorphism_impl(self, k, res, res_col, a, a_col)
}
}
impl<B> VecZnxAutomorphismInplace for Module<B>
where
B: Backend + VecZnxAutomorphismInplaceImpl<B>,
{
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxToMut,
{
B::vec_znx_automorphism_inplace_impl(self, k, a, a_col)
}
}
impl<B> VecZnxMulXpMinusOne for Module<B>
where
B: Backend + VecZnxMulXpMinusOneImpl<B>,
{
fn vec_znx_mul_xp_minus_one<R, A>(&self, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_mul_xp_minus_one_impl(self, p, res, res_col, a, a_col);
}
}
impl<B> VecZnxMulXpMinusOneInplace for Module<B>
where
B: Backend + VecZnxMulXpMinusOneInplaceImpl<B>,
{
fn vec_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, res: &mut R, res_col: usize)
where
R: VecZnxToMut,
{
B::vec_znx_mul_xp_minus_one_inplace_impl(self, p, res, res_col);
}
}
impl<B> VecZnxSplit<B> for Module<B>
where
B: Backend + VecZnxSplitImpl<B>,
{
fn vec_znx_split<R, A>(&self, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_split_impl(self, res, res_col, a, a_col, scratch)
}
}
impl<B> VecZnxMerge for Module<B>
where
B: Backend + VecZnxMergeImpl<B>,
{
fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: &[A], a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_merge_impl(self, res, res_col, a, a_col)
}
}
impl<B> VecZnxSwithcDegree for Module<B>
where
B: Backend + VecZnxSwithcDegreeImpl<B>,
{
fn vec_znx_switch_degree<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_switch_degree_impl(self, res, res_col, a, a_col)
}
}
impl<B> VecZnxCopy for Module<B>
where
B: Backend + VecZnxCopyImpl<B>,
{
fn vec_znx_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_copy_impl(self, res, res_col, a, a_col)
}
}
impl<B> VecZnxFillUniform for Module<B>
where
B: Backend + VecZnxFillUniformImpl<B>,
{
fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
where
R: VecZnxToMut,
{
B::vec_znx_fill_uniform_impl(self, basek, res, res_col, k, source);
}
}
impl<B> VecZnxFillDistF64 for Module<B>
where
B: Backend + VecZnxFillDistF64Impl<B>,
{
fn vec_znx_fill_dist_f64<R, D: rand::prelude::Distribution<f64>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: VecZnxToMut,
{
B::vec_znx_fill_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
}
}
impl<B> VecZnxAddDistF64 for Module<B>
where
B: Backend + VecZnxAddDistF64Impl<B>,
{
fn vec_znx_add_dist_f64<R, D: rand::prelude::Distribution<f64>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: VecZnxToMut,
{
B::vec_znx_add_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
}
}
impl<B> VecZnxFillNormal for Module<B>
where
B: Backend + VecZnxFillNormalImpl<B>,
{
fn vec_znx_fill_normal<R>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: VecZnxToMut,
{
B::vec_znx_fill_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
}
}
impl<B> VecZnxAddNormal for Module<B>
where
B: Backend + VecZnxAddNormalImpl<B>,
{
fn vec_znx_add_normal<R>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: VecZnxToMut,
{
B::vec_znx_add_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
}
}


@@ -0,0 +1,334 @@
use rand_distr::Distribution;
use crate::hal::{
api::{
VecZnxBigAdd, VecZnxBigAddDistF64, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmall, VecZnxBigAddSmallInplace,
VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigFillDistF64,
VecZnxBigFillNormal, VecZnxBigFromBytes, VecZnxBigNegateInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA, VecZnxBigSubSmallAInplace,
VecZnxBigSubSmallB, VecZnxBigSubSmallBInplace,
},
layouts::{Backend, Module, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef},
oep::{
VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl,
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl, VecZnxBigFromBytesImpl,
VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl,
VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl,
VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
},
source::Source,
};
impl<B> VecZnxBigAlloc<B> for Module<B>
where
B: Backend + VecZnxBigAllocImpl<B>,
{
fn vec_znx_big_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxBigOwned<B> {
B::vec_znx_big_alloc_impl(n, cols, size)
}
}
impl<B> VecZnxBigFromBytes<B> for Module<B>
where
B: Backend + VecZnxBigFromBytesImpl<B>,
{
fn vec_znx_big_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B> {
B::vec_znx_big_from_bytes_impl(n, cols, size, bytes)
}
}
impl<B> VecZnxBigAllocBytes for Module<B>
where
B: Backend + VecZnxBigAllocBytesImpl<B>,
{
fn vec_znx_big_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize {
B::vec_znx_big_alloc_bytes_impl(n, cols, size)
}
}
impl<B> VecZnxBigAddDistF64<B> for Module<B>
where
B: Backend + VecZnxBigAddDistF64Impl<B>,
{
fn vec_znx_big_add_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) {
B::add_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
}
}
impl<B> VecZnxBigAddNormal<B> for Module<B>
where
B: Backend + VecZnxBigAddNormalImpl<B>,
{
fn vec_znx_big_add_normal<R: VecZnxBigToMut<B>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) {
B::add_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
}
}
impl<B> VecZnxBigFillDistF64<B> for Module<B>
where
B: Backend + VecZnxBigFillDistF64Impl<B>,
{
fn vec_znx_big_fill_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) {
B::fill_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
}
}
impl<B> VecZnxBigFillNormal<B> for Module<B>
where
B: Backend + VecZnxBigFillNormalImpl<B>,
{
fn vec_znx_big_fill_normal<R: VecZnxBigToMut<B>>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) {
B::fill_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
}
}
impl<B> VecZnxBigAdd<B> for Module<B>
where
B: Backend + VecZnxBigAddImpl<B>,
{
fn vec_znx_big_add<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>,
C: VecZnxBigToRef<B>,
{
B::vec_znx_big_add_impl(self, res, res_col, a, a_col, b, b_col);
}
}
impl<B> VecZnxBigAddInplace<B> for Module<B>
where
B: Backend + VecZnxBigAddInplaceImpl<B>,
{
fn vec_znx_big_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>,
{
B::vec_znx_big_add_inplace_impl(self, res, res_col, a, a_col);
}
}
impl<B> VecZnxBigAddSmall<B> for Module<B>
where
B: Backend + VecZnxBigAddSmallImpl<B>,
{
fn vec_znx_big_add_small<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>,
C: VecZnxToRef,
{
B::vec_znx_big_add_small_impl(self, res, res_col, a, a_col, b, b_col);
}
}
impl<B> VecZnxBigAddSmallInplace<B> for Module<B>
where
B: Backend + VecZnxBigAddSmallInplaceImpl<B>,
{
fn vec_znx_big_add_small_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxToRef,
{
B::vec_znx_big_add_small_inplace_impl(self, res, res_col, a, a_col);
}
}
impl<B> VecZnxBigSub<B> for Module<B>
where
B: Backend + VecZnxBigSubImpl<B>,
{
fn vec_znx_big_sub<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>,
C: VecZnxBigToRef<B>,
{
B::vec_znx_big_sub_impl(self, res, res_col, a, a_col, b, b_col);
}
}
impl<B> VecZnxBigSubABInplace<B> for Module<B>
where
B: Backend + VecZnxBigSubABInplaceImpl<B>,
{
fn vec_znx_big_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>,
{
B::vec_znx_big_sub_ab_inplace_impl(self, res, res_col, a, a_col);
}
}
impl<B> VecZnxBigSubBAInplace<B> for Module<B>
where
B: Backend + VecZnxBigSubBAInplaceImpl<B>,
{
fn vec_znx_big_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>,
{
B::vec_znx_big_sub_ba_inplace_impl(self, res, res_col, a, a_col);
}
}
impl<B> VecZnxBigSubSmallA<B> for Module<B>
where
B: Backend + VecZnxBigSubSmallAImpl<B>,
{
fn vec_znx_big_sub_small_a<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxToRef,
C: VecZnxBigToRef<B>,
{
B::vec_znx_big_sub_small_a_impl(self, res, res_col, a, a_col, b, b_col);
}
}
impl<B> VecZnxBigSubSmallAInplace<B> for Module<B>
where
B: Backend + VecZnxBigSubSmallAInplaceImpl<B>,
{
fn vec_znx_big_sub_small_a_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxToRef,
{
B::vec_znx_big_sub_small_a_inplace_impl(self, res, res_col, a, a_col);
}
}
impl<B> VecZnxBigSubSmallB<B> for Module<B>
where
B: Backend + VecZnxBigSubSmallBImpl<B>,
{
fn vec_znx_big_sub_small_b<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>,
C: VecZnxToRef,
{
B::vec_znx_big_sub_small_b_impl(self, res, res_col, a, a_col, b, b_col);
}
}
impl<B> VecZnxBigSubSmallBInplace<B> for Module<B>
where
B: Backend + VecZnxBigSubSmallBInplaceImpl<B>,
{
fn vec_znx_big_sub_small_b_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxToRef,
{
B::vec_znx_big_sub_small_b_inplace_impl(self, res, res_col, a, a_col);
}
}
impl<B> VecZnxBigNegateInplace<B> for Module<B>
where
B: Backend + VecZnxBigNegateInplaceImpl<B>,
{
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize)
where
A: VecZnxBigToMut<B>,
{
B::vec_znx_big_negate_inplace_impl(self, a, a_col);
}
}
impl<B> VecZnxBigNormalizeTmpBytes for Module<B>
where
B: Backend + VecZnxBigNormalizeTmpBytesImpl<B>,
{
fn vec_znx_big_normalize_tmp_bytes(&self, n: usize) -> usize {
B::vec_znx_big_normalize_tmp_bytes_impl(self, n)
}
}
impl<B> VecZnxBigNormalize<B> for Module<B>
where
B: Backend + VecZnxBigNormalizeImpl<B>,
{
fn vec_znx_big_normalize<R, A>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut,
A: VecZnxBigToRef<B>,
{
B::vec_znx_big_normalize_impl(self, basek, res, res_col, a, a_col, scratch);
}
}
impl<B> VecZnxBigAutomorphism<B> for Module<B>
where
B: Backend + VecZnxBigAutomorphismImpl<B>,
{
fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>,
{
B::vec_znx_big_automorphism_impl(self, k, res, res_col, a, a_col);
}
}
impl<B> VecZnxBigAutomorphismInplace<B> for Module<B>
where
B: Backend + VecZnxBigAutomorphismInplaceImpl<B>,
{
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxBigToMut<B>,
{
B::vec_znx_big_automorphism_inplace_impl(self, k, a, a_col);
}
}

View File

@@ -0,0 +1,196 @@
use crate::hal::{
api::{
VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxDftFromBytes,
VecZnxDftFromVecZnx, VecZnxDftSub, VecZnxDftSubABInplace, VecZnxDftSubBAInplace, VecZnxDftToVecZnxBig,
VecZnxDftToVecZnxBigConsume, VecZnxDftToVecZnxBigTmpA, VecZnxDftToVecZnxBigTmpBytes, VecZnxDftZero,
},
layouts::{
Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef,
VecZnxToRef,
},
oep::{
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftCopyImpl,
VecZnxDftFromBytesImpl, VecZnxDftFromVecZnxImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl,
VecZnxDftToVecZnxBigConsumeImpl, VecZnxDftToVecZnxBigImpl, VecZnxDftToVecZnxBigTmpAImpl,
VecZnxDftToVecZnxBigTmpBytesImpl, VecZnxDftZeroImpl,
},
};
impl<B> VecZnxDftFromBytes<B> for Module<B>
where
B: Backend + VecZnxDftFromBytesImpl<B>,
{
fn vec_znx_dft_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B> {
B::vec_znx_dft_from_bytes_impl(n, cols, size, bytes)
}
}
impl<B> VecZnxDftAllocBytes for Module<B>
where
B: Backend + VecZnxDftAllocBytesImpl<B>,
{
fn vec_znx_dft_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize {
B::vec_znx_dft_alloc_bytes_impl(n, cols, size)
}
}
impl<B> VecZnxDftAlloc<B> for Module<B>
where
B: Backend + VecZnxDftAllocImpl<B>,
{
fn vec_znx_dft_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxDftOwned<B> {
B::vec_znx_dft_alloc_impl(n, cols, size)
}
}
impl<B> VecZnxDftToVecZnxBigTmpBytes for Module<B>
where
B: Backend + VecZnxDftToVecZnxBigTmpBytesImpl<B>,
{
fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self, n: usize) -> usize {
B::vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(self, n)
}
}
impl<B> VecZnxDftToVecZnxBig<B> for Module<B>
where
B: Backend + VecZnxDftToVecZnxBigImpl<B>,
{
fn vec_znx_dft_to_vec_znx_big<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxBigToMut<B>,
A: VecZnxDftToRef<B>,
{
B::vec_znx_dft_to_vec_znx_big_impl(self, res, res_col, a, a_col, scratch);
}
}
impl<B> VecZnxDftToVecZnxBigTmpA<B> for Module<B>
where
B: Backend + VecZnxDftToVecZnxBigTmpAImpl<B>,
{
fn vec_znx_dft_to_vec_znx_big_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxDftToMut<B>,
{
B::vec_znx_dft_to_vec_znx_big_tmp_a_impl(self, res, res_col, a, a_col);
}
}
impl<B> VecZnxDftToVecZnxBigConsume<B> for Module<B>
where
B: Backend + VecZnxDftToVecZnxBigConsumeImpl<B>,
{
fn vec_znx_dft_to_vec_znx_big_consume<D: Data>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
where
VecZnxDft<D, B>: VecZnxDftToMut<B>,
{
B::vec_znx_dft_to_vec_znx_big_consume_impl(self, a)
}
}
impl<B> VecZnxDftFromVecZnx<B> for Module<B>
where
B: Backend + VecZnxDftFromVecZnxImpl<B>,
{
fn vec_znx_dft_from_vec_znx<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxToRef,
{
B::vec_znx_dft_from_vec_znx_impl(self, step, offset, res, res_col, a, a_col);
}
}
impl<B> VecZnxDftAdd<B> for Module<B>
where
B: Backend + VecZnxDftAddImpl<B>,
{
fn vec_znx_dft_add<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
D: VecZnxDftToRef<B>,
{
B::vec_znx_dft_add_impl(self, res, res_col, a, a_col, b, b_col);
}
}
impl<B> VecZnxDftAddInplace<B> for Module<B>
where
B: Backend + VecZnxDftAddInplaceImpl<B>,
{
fn vec_znx_dft_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
{
B::vec_znx_dft_add_inplace_impl(self, res, res_col, a, a_col);
}
}
impl<B> VecZnxDftSub<B> for Module<B>
where
B: Backend + VecZnxDftSubImpl<B>,
{
fn vec_znx_dft_sub<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
D: VecZnxDftToRef<B>,
{
B::vec_znx_dft_sub_impl(self, res, res_col, a, a_col, b, b_col);
}
}
impl<B> VecZnxDftSubABInplace<B> for Module<B>
where
B: Backend + VecZnxDftSubABInplaceImpl<B>,
{
fn vec_znx_dft_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
{
B::vec_znx_dft_sub_ab_inplace_impl(self, res, res_col, a, a_col);
}
}
impl<B> VecZnxDftSubBAInplace<B> for Module<B>
where
B: Backend + VecZnxDftSubBAInplaceImpl<B>,
{
fn vec_znx_dft_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
{
B::vec_znx_dft_sub_ba_inplace_impl(self, res, res_col, a, a_col);
}
}
impl<B> VecZnxDftCopy<B> for Module<B>
where
B: Backend + VecZnxDftCopyImpl<B>,
{
fn vec_znx_dft_copy<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
{
B::vec_znx_dft_copy_impl(self, step, offset, res, res_col, a, a_col);
}
}
impl<B> VecZnxDftZero<B> for Module<B>
where
B: Backend + VecZnxDftZeroImpl<B>,
{
fn vec_znx_dft_zero<R>(&self, res: &mut R)
where
R: VecZnxDftToMut<B>,
{
B::vec_znx_dft_zero_impl(self, res);
}
}

View File

@@ -0,0 +1,136 @@
use crate::hal::{
api::{
VmpApply, VmpApplyAdd, VmpApplyAddTmpBytes, VmpApplyTmpBytes, VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatFromBytes,
VmpPrepare, VmpPrepareTmpBytes,
},
layouts::{Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef},
oep::{
VmpApplyAddImpl, VmpApplyAddTmpBytesImpl, VmpApplyImpl, VmpApplyTmpBytesImpl, VmpPMatAllocBytesImpl, VmpPMatAllocImpl,
VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
},
};
impl<B> VmpPMatAlloc<B> for Module<B>
where
B: Backend + VmpPMatAllocImpl<B>,
{
fn vmp_pmat_alloc(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B> {
B::vmp_pmat_alloc_impl(n, rows, cols_in, cols_out, size)
}
}
impl<B> VmpPMatAllocBytes for Module<B>
where
B: Backend + VmpPMatAllocBytesImpl<B>,
{
fn vmp_pmat_alloc_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size)
}
}
impl<B> VmpPMatFromBytes<B> for Module<B>
where
B: Backend + VmpPMatFromBytesImpl<B>,
{
fn vmp_pmat_from_bytes(
&self,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: Vec<u8>,
) -> VmpPMatOwned<B> {
B::vmp_pmat_from_bytes_impl(n, rows, cols_in, cols_out, size, bytes)
}
}
impl<B> VmpPrepareTmpBytes for Module<B>
where
B: Backend + VmpPrepareTmpBytesImpl<B>,
{
fn vmp_prepare_tmp_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
B::vmp_prepare_tmp_bytes_impl(self, n, rows, cols_in, cols_out, size)
}
}
impl<B> VmpPrepare<B> for Module<B>
where
B: Backend + VmpPMatPrepareImpl<B>,
{
fn vmp_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<B>)
where
R: VmpPMatToMut<B>,
A: MatZnxToRef,
{
B::vmp_prepare_impl(self, res, a, scratch)
}
}
impl<B> VmpApplyTmpBytes for Module<B>
where
B: Backend + VmpApplyTmpBytesImpl<B>,
{
fn vmp_apply_tmp_bytes(
&self,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize {
B::vmp_apply_tmp_bytes_impl(
self, n, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
)
}
}
impl<B> VmpApply<B> for Module<B>
where
B: Backend + VmpApplyImpl<B>,
{
fn vmp_apply<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
C: VmpPMatToRef<B>,
{
B::vmp_apply_impl(self, res, a, b, scratch);
}
}
impl<B> VmpApplyAddTmpBytes for Module<B>
where
B: Backend + VmpApplyAddTmpBytesImpl<B>,
{
fn vmp_apply_add_tmp_bytes(
&self,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize {
B::vmp_apply_add_tmp_bytes_impl(
self, n, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
)
}
}
impl<B> VmpApplyAdd<B> for Module<B>
where
B: Backend + VmpApplyAddImpl<B>,
{
fn vmp_apply_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
C: VmpPMatToRef<B>,
{
B::vmp_apply_add_impl(self, res, a, b, scale, scratch);
}
}

View File

@@ -0,0 +1,204 @@
use itertools::izip;
use rug::{Assign, Float};
use crate::hal::{
api::{ZnxInfos, ZnxView, ZnxViewMut, ZnxZero},
layouts::{DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef},
};
impl<D: DataMut> VecZnx<D> {
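/// Encodes `data` into column `col` with `k` bits of precision, spread over
/// `k.div_ceil(basek)` limbs in base `2^basek` (most-significant limb first).
///
/// A sketch of the decomposition path with illustrative numbers: for
/// `basek = 8` and `x = 0x1234`, the digits `(x >> 8) & 0xff = 0x12` and
/// `x & 0xff = 0x34` land in the two least-significant limbs. When the values
/// already fit after the `k_rem` shift (`log_max + k_rem < 63`), or `k` is a
/// multiple of `basek`, the raw values are instead copied verbatim into the
/// least-significant limb.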
pub fn encode_vec_i64(&mut self, basek: usize, col: usize, k: usize, data: &[i64], log_max: usize) {
let size: usize = k.div_ceil(basek);
#[cfg(debug_assertions)]
{
let a: VecZnx<&mut [u8]> = self.to_mut();
assert!(
size <= a.size(),
"invalid argument k.div_ceil(basek)={} > a.size()={}",
size,
a.size()
);
assert!(col < a.cols());
assert!(data.len() <= a.n())
}
let data_len: usize = data.len();
let mut a: VecZnx<&mut [u8]> = self.to_mut();
let k_rem: usize = basek - (k % basek);
// Zeroes all limbs of the target column
(0..a.size()).for_each(|i| {
a.zero_at(col, i);
});
// If 2^{log_max} * 2^{k_rem} < 2^{63} (i.e. log_max + k_rem < 63), or if k is
// a multiple of basek (k_rem == basek), we can simply copy the values onto the
// last limb. Otherwise we decompose the values in base 2^{basek}.
if log_max + k_rem < 63 || k_rem == basek {
a.at_mut(col, size - 1)[..data_len].copy_from_slice(&data[..data_len]);
} else {
let mask: i64 = (1 << basek) - 1;
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size)
.rev()
.enumerate()
.for_each(|(i, i_rev)| {
let shift: usize = i * basek;
izip!(a.at_mut(col, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask);
})
}
// Case where k % basek != 0: shift left by k_rem to account for the partial last limb.
if k_rem != basek {
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size).rev().for_each(|i| {
a.at_mut(col, i)[..data_len]
.iter_mut()
.for_each(|x| *x <<= k_rem);
})
}
}
pub fn encode_coeff_i64(&mut self, basek: usize, col: usize, k: usize, idx: usize, data: i64, log_max: usize) {
let size: usize = k.div_ceil(basek);
#[cfg(debug_assertions)]
{
let a: VecZnx<&mut [u8]> = self.to_mut();
assert!(idx < a.n());
assert!(
size <= a.size(),
"invalid argument k.div_ceil(basek)={} > a.size()={}",
size,
a.size()
);
assert!(col < a.cols());
}
let k_rem: usize = basek - (k % basek);
let mut a: VecZnx<&mut [u8]> = self.to_mut();
(0..a.size()).for_each(|j| a.at_mut(col, j)[idx] = 0);
// If 2^{log_max} * 2^{k_rem} < 2^{63} (i.e. log_max + k_rem < 63), or if k is
// a multiple of basek (k_rem == basek), we can simply copy the value onto the
// last limb. Otherwise we decompose the value in base 2^{basek}.
if log_max + k_rem < 63 || k_rem == basek {
a.at_mut(col, size - 1)[idx] = data;
} else {
let mask: i64 = (1 << basek) - 1;
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size)
.rev()
.enumerate()
.for_each(|(j, j_rev)| {
a.at_mut(col, j_rev)[idx] = (data >> (j * basek)) & mask;
})
}
// Case where k % basek != 0: shift left by k_rem to account for the partial last limb.
if k_rem != basek {
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size).rev().for_each(|j| {
a.at_mut(col, j)[idx] <<= k_rem;
})
}
}
}
impl<D: DataRef> VecZnx<D> {
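/// Decodes column `col` back into `data`, reconstructing each coefficient by
/// Horner's rule as `x = sum_j limb_j * 2^{basek*(size-1-j)}`; when `k` is not
/// a multiple of `basek`, the least-significant limb is first right-shifted to
/// drop its `basek - k % basek` padding bits.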
pub fn decode_vec_i64(&self, basek: usize, col: usize, k: usize, data: &mut [i64]) {
let size: usize = k.div_ceil(basek);
#[cfg(debug_assertions)]
{
let a: VecZnx<&[u8]> = self.to_ref();
assert!(
data.len() >= a.n(),
"invalid data: data.len()={} < a.n()={}",
data.len(),
a.n()
);
assert!(col < a.cols());
}
let a: VecZnx<&[u8]> = self.to_ref();
data[..a.n()].copy_from_slice(a.at(col, 0)); // copy_from_slice requires equal lengths; data may be longer than n
let rem: usize = basek - (k % basek);
if k < basek {
data.iter_mut().for_each(|x| *x >>= rem);
} else {
(1..size).for_each(|i| {
if i == size - 1 && rem != basek {
let k_rem: usize = basek - rem;
izip!(a.at(col, i).iter(), data.iter_mut()).for_each(|(x, y)| {
*y = (*y << k_rem) + (x >> rem);
});
} else {
izip!(a.at(col, i).iter(), data.iter_mut()).for_each(|(x, y)| {
*y = (*y << basek) + x;
});
}
})
}
}
pub fn decode_coeff_i64(&self, basek: usize, col: usize, k: usize, idx: usize) -> i64 {
#[cfg(debug_assertions)]
{
let a: VecZnx<&[u8]> = self.to_ref();
assert!(idx < a.n());
assert!(col < a.cols())
}
let a: VecZnx<&[u8]> = self.to_ref();
let size: usize = k.div_ceil(basek);
let mut res: i64 = 0;
let rem: usize = basek - (k % basek);
(0..size).for_each(|j| {
let x: i64 = a.at(col, j)[idx];
if j == size - 1 && rem != basek {
let k_rem: usize = basek - rem;
res = (res << k_rem) + (x >> rem);
} else {
res = (res << basek) + x;
}
});
res
}
pub fn decode_vec_float(&self, basek: usize, col: usize, data: &mut [Float]) {
#[cfg(debug_assertions)]
{
let a: VecZnx<&[u8]> = self.to_ref();
assert!(
data.len() >= a.n(),
"invalid data: data.len()={} < a.n()={}",
data.len(),
a.n()
);
assert!(col < a.cols());
}
let a: VecZnx<&[u8]> = self.to_ref();
let size: usize = a.size();
let prec: u32 = (basek * size) as u32;
// 2^{basek}
let base = Float::with_val(prec, (1u64 << basek) as f64);
// y[i] = sum_{j} x[j][i] * 2^{-basek*(j+1)}, evaluated by Horner's rule
(0..size).for_each(|i| {
if i == 0 {
izip!(a.at(col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
y.assign(*x);
*y /= &base;
});
} else {
izip!(a.at(col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
*y += Float::with_val(prec, *x);
*y /= &base;
});
}
});
}
}

View File

@@ -0,0 +1,306 @@
use crate::{
alloc_aligned,
hal::{
api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
layouts::{Data, DataMut, DataRef, ReaderFrom, ToOwnedDeep, VecZnx, WriterTo},
source::Source,
},
};
use std::fmt;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use rand::RngCore;
#[derive(PartialEq, Eq, Clone)]
pub struct MatZnx<D: Data> {
data: D,
n: usize,
size: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
}
impl<D: DataRef> ToOwnedDeep for MatZnx<D> {
type Owned = MatZnx<Vec<u8>>;
fn to_owned_deep(&self) -> Self::Owned {
MatZnx {
data: self.data.as_ref().to_vec(),
n: self.n,
size: self.size,
rows: self.rows,
cols_in: self.cols_in,
cols_out: self.cols_out,
}
}
}
impl<D: DataRef> fmt::Debug for MatZnx<D> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self)
}
}
impl<D: Data> ZnxInfos for MatZnx<D> {
fn cols(&self) -> usize {
self.cols_in
}
fn rows(&self) -> usize {
self.rows
}
fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
self.size
}
}
impl<D: Data> ZnxSliceSize for MatZnx<D> {
fn sl(&self) -> usize {
self.n() * self.cols_out()
}
}
impl<D: Data> DataView for MatZnx<D> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D: Data> DataViewMut for MatZnx<D> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
impl<D: DataRef> ZnxView for MatZnx<D> {
type Scalar = i64;
}
impl<D: Data> MatZnx<D> {
pub fn cols_in(&self) -> usize {
self.cols_in
}
pub fn cols_out(&self) -> usize {
self.cols_out
}
}
impl MatZnx<Vec<u8>> {
pub fn alloc_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
rows * cols_in * VecZnx::<Vec<u8>>::alloc_bytes(n, cols_out, size)
}
pub fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
let data: Vec<u8> = alloc_aligned(Self::alloc_bytes(n, rows, cols_in, cols_out, size));
Self {
data,
n,
size,
rows,
cols_in,
cols_out,
}
}
pub fn from_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::alloc_bytes(n, rows, cols_in, cols_out, size));
Self {
data,
n,
size,
rows,
cols_in,
cols_out,
}
}
}
impl<D: DataRef> MatZnx<D> {
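/// Returns the `(row, col)` entry as a borrowed `VecZnx` with `cols_out`
/// columns and `size` limbs. Entries are laid out row-major: `cols_in`
/// consecutive `VecZnx` blocks per row.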
pub fn at(&self, row: usize, col: usize) -> VecZnx<&[u8]> {
#[cfg(debug_assertions)]
{
assert!(row < self.rows(), "rows: {} >= {}", row, self.rows());
assert!(col < self.cols_in(), "cols: {} >= {}", col, self.cols_in());
}
let self_ref: MatZnx<&[u8]> = self.to_ref();
let nb_bytes: usize = VecZnx::<Vec<u8>>::alloc_bytes(self.n, self.cols_out, self.size);
let start: usize = nb_bytes * self.cols() * row + col * nb_bytes;
let end: usize = start + nb_bytes;
VecZnx {
data: &self_ref.data[start..end],
n: self.n,
cols: self.cols_out,
size: self.size,
max_size: self.size,
}
}
}
impl<D: DataMut> MatZnx<D> {
pub fn at_mut(&mut self, row: usize, col: usize) -> VecZnx<&mut [u8]> {
#[cfg(debug_assertions)]
{
assert!(row < self.rows(), "rows: {} >= {}", row, self.rows());
assert!(col < self.cols_in(), "cols: {} >= {}", col, self.cols_in());
}
let n: usize = self.n();
let cols_out: usize = self.cols_out();
let cols_in: usize = self.cols_in();
let size: usize = self.size();
let self_ref: MatZnx<&mut [u8]> = self.to_mut();
let nb_bytes: usize = VecZnx::<Vec<u8>>::alloc_bytes(n, cols_out, size);
let start: usize = nb_bytes * cols_in * row + col * nb_bytes;
let end: usize = start + nb_bytes;
VecZnx {
data: &mut self_ref.data[start..end],
n,
cols: cols_out,
size,
max_size: size,
}
}
}
impl<D: DataMut> FillUniform for MatZnx<D> {
fn fill_uniform(&mut self, source: &mut Source) {
source.fill_bytes(self.data.as_mut());
}
}
impl<D: DataMut> Reset for MatZnx<D> {
fn reset(&mut self) {
self.zero();
self.n = 0;
self.size = 0;
self.rows = 0;
self.cols_in = 0;
self.cols_out = 0;
}
}
pub type MatZnxOwned = MatZnx<Vec<u8>>;
pub type MatZnxMut<'a> = MatZnx<&'a mut [u8]>;
pub type MatZnxRef<'a> = MatZnx<&'a [u8]>;
pub trait MatZnxToRef {
fn to_ref(&self) -> MatZnx<&[u8]>;
}
impl<D: DataRef> MatZnxToRef for MatZnx<D> {
fn to_ref(&self) -> MatZnx<&[u8]> {
MatZnx {
data: self.data.as_ref(),
n: self.n,
rows: self.rows,
cols_in: self.cols_in,
cols_out: self.cols_out,
size: self.size,
}
}
}
pub trait MatZnxToMut {
fn to_mut(&mut self) -> MatZnx<&mut [u8]>;
}
impl<D: DataMut> MatZnxToMut for MatZnx<D> {
fn to_mut(&mut self) -> MatZnx<&mut [u8]> {
MatZnx {
data: self.data.as_mut(),
n: self.n,
rows: self.rows,
cols_in: self.cols_in,
cols_out: self.cols_out,
size: self.size,
}
}
}
impl<D: Data> MatZnx<D> {
pub(crate) fn from_data(data: D, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
Self {
data,
n,
rows,
cols_in,
cols_out,
size,
}
}
}
impl<D: DataMut> ReaderFrom for MatZnx<D> {
fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
self.n = reader.read_u64::<LittleEndian>()? as usize;
self.size = reader.read_u64::<LittleEndian>()? as usize;
self.rows = reader.read_u64::<LittleEndian>()? as usize;
self.cols_in = reader.read_u64::<LittleEndian>()? as usize;
self.cols_out = reader.read_u64::<LittleEndian>()? as usize;
let len: usize = reader.read_u64::<LittleEndian>()? as usize;
let buf: &mut [u8] = self.data.as_mut();
if buf.len() != len {
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
format!("self.data.len()={} != read len={}", buf.len(), len),
));
}
reader.read_exact(&mut buf[..len])?;
Ok(())
}
}
impl<D: DataRef> WriterTo for MatZnx<D> {
fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
writer.write_u64::<LittleEndian>(self.n as u64)?;
writer.write_u64::<LittleEndian>(self.size as u64)?;
writer.write_u64::<LittleEndian>(self.rows as u64)?;
writer.write_u64::<LittleEndian>(self.cols_in as u64)?;
writer.write_u64::<LittleEndian>(self.cols_out as u64)?;
let buf: &[u8] = self.data.as_ref();
writer.write_u64::<LittleEndian>(buf.len() as u64)?;
writer.write_all(buf)?;
Ok(())
}
}
impl<D: DataRef> fmt::Display for MatZnx<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"MatZnx(n={}, rows={}, cols_in={}, cols_out={}, size={})",
self.n, self.rows, self.cols_in, self.cols_out, self.size
)?;
for row_i in 0..self.rows {
writeln!(f, "Row {}:", row_i)?;
for col_i in 0..self.cols_in {
writeln!(f, "cols_in {}:", col_i)?;
writeln!(f, "{}:", self.at(row_i, col_i))?;
}
}
Ok(())
}
}
impl<D: DataMut> ZnxZero for MatZnx<D> {
fn zero(&mut self) {
self.raw_mut().fill(0)
}
fn zero_at(&mut self, i: usize, j: usize) {
self.at_mut(i, j).zero();
}
}

View File

@@ -0,0 +1,32 @@
mod encoding;
mod mat_znx;
mod module;
mod scalar_znx;
mod scratch;
mod serialization;
mod stats;
mod svp_ppol;
mod vec_znx;
mod vec_znx_big;
mod vec_znx_dft;
mod vmp_pmat;
pub use mat_znx::*;
pub use module::*;
pub use scalar_znx::*;
pub use scratch::*;
pub use serialization::*;
pub use svp_ppol::*;
pub use vec_znx::*;
pub use vec_znx_big::*;
pub use vec_znx_dft::*;
pub use vmp_pmat::*;
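// Capability hierarchy of backing storage: `Data` is any comparable container,
// `DataRef` additionally exposes its bytes read-only, and `DataMut` read-write.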
pub trait Data = PartialEq + Eq + Sized;
pub trait DataRef = Data + AsRef<[u8]>;
pub trait DataMut = DataRef + AsMut<[u8]>;
pub trait ToOwnedDeep {
type Owned;
fn to_owned_deep(&self) -> Self::Owned;
}

View File

@@ -0,0 +1,97 @@
use std::{marker::PhantomData, ptr::NonNull};
use crate::GALOISGENERATOR;
#[allow(clippy::missing_safety_doc)]
pub trait Backend: Sized {
type Handle: 'static;
unsafe fn destroy(handle: NonNull<Self::Handle>);
}
pub struct Module<B: Backend> {
ptr: NonNull<B::Handle>,
n: u64,
_marker: PhantomData<B>,
}
impl<B: Backend> Module<B> {
/// Construct from a raw pointer managed elsewhere.
/// SAFETY: `ptr` must be non-null and remain valid for the lifetime of this Module.
#[inline]
#[allow(clippy::missing_safety_doc)]
pub unsafe fn from_raw_parts(ptr: *mut B::Handle, n: u64) -> Self {
Self {
ptr: NonNull::new(ptr).expect("null module ptr"),
n,
_marker: PhantomData,
}
}
#[allow(clippy::missing_safety_doc)]
#[inline]
pub unsafe fn ptr(&self) -> *mut <B as Backend>::Handle {
self.ptr.as_ptr()
}
#[inline]
pub fn n(&self) -> usize {
self.n as usize
}
#[inline]
pub fn as_mut_ptr(&self) -> *mut B::Handle {
self.ptr.as_ptr()
}
#[inline]
pub fn log_n(&self) -> usize {
(usize::BITS - (self.n() - 1).leading_zeros()) as _
}
#[inline]
pub fn cyclotomic_order(&self) -> u64 {
(self.n() << 1) as _
}
// Returns GALOISGENERATOR^{|generator|} * sign(generator) mod 2n
#[inline]
pub fn galois_element(&self, generator: i64) -> i64 {
if generator == 0 {
return 1;
}
((mod_exp_u64(GALOISGENERATOR, generator.unsigned_abs() as usize) & (self.cyclotomic_order() - 1)) as i64)
* generator.signum()
}
// Returns gal_el^{-1} mod 2n, preserving the sign of gal_el
#[inline]
pub fn galois_element_inv(&self, gal_el: i64) -> i64 {
if gal_el == 0 {
panic!("cannot invert 0")
}
((mod_exp_u64(
gal_el.unsigned_abs(),
(self.cyclotomic_order() - 1) as usize,
) & (self.cyclotomic_order() - 1)) as i64)
* gal_el.signum()
}
}
impl<B: Backend> Drop for Module<B> {
fn drop(&mut self) {
unsafe { B::destroy(self.ptr) }
}
}
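/// Square-and-multiply exponentiation, wrapping modulo 2^64; callers mask the
/// result down to the power-of-two modulus they need. For example,
/// `mod_exp_u64(3, 5) == 243`.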
pub fn mod_exp_u64(x: u64, e: usize) -> u64 {
let mut y: u64 = 1;
let mut x_pow: u64 = x;
let mut exp = e;
while exp > 0 {
if exp & 1 == 1 {
y = y.wrapping_mul(x_pow);
}
x_pow = x_pow.wrapping_mul(x_pow);
exp >>= 1;
}
y
}

View File

@@ -0,0 +1,249 @@
use rand::seq::SliceRandom;
use rand_core::RngCore;
use rand_distr::{Distribution, weighted::WeightedIndex};
use crate::{
alloc_aligned,
hal::{
api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
layouts::{Data, DataMut, DataRef, ReaderFrom, ToOwnedDeep, VecZnx, WriterTo},
source::Source,
},
};
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct ScalarZnx<D: Data> {
pub(crate) data: D,
pub(crate) n: usize,
pub(crate) cols: usize,
}
impl<D: DataRef> ToOwnedDeep for ScalarZnx<D> {
type Owned = ScalarZnx<Vec<u8>>;
fn to_owned_deep(&self) -> Self::Owned {
ScalarZnx {
data: self.data.as_ref().to_vec(),
n: self.n,
cols: self.cols,
}
}
}
impl<D: Data> ZnxInfos for ScalarZnx<D> {
fn cols(&self) -> usize {
self.cols
}
fn rows(&self) -> usize {
1
}
fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
1
}
}
impl<D: Data> ZnxSliceSize for ScalarZnx<D> {
fn sl(&self) -> usize {
self.n()
}
}
impl<D: Data> DataView for ScalarZnx<D> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D: Data> DataViewMut for ScalarZnx<D> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
impl<D: DataRef> ZnxView for ScalarZnx<D> {
type Scalar = i64;
}
impl<D: DataMut> ScalarZnx<D> {
pub fn fill_ternary_prob(&mut self, col: usize, prob: f64, source: &mut Source) {
let choices: [i64; 3] = [-1, 0, 1];
let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0];
let dist: WeightedIndex<f64> = WeightedIndex::new(weights).unwrap();
self.at_mut(col, 0)
.iter_mut()
.for_each(|x: &mut i64| *x = choices[dist.sample(source)]);
}
pub fn fill_ternary_hw(&mut self, col: usize, hw: usize, source: &mut Source) {
assert!(hw <= self.n());
self.at_mut(col, 0)[..hw]
.iter_mut()
.for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
self.at_mut(col, 0).shuffle(source);
}
pub fn fill_binary_prob(&mut self, col: usize, prob: f64, source: &mut Source) {
let choices: [i64; 2] = [0, 1];
let weights: [f64; 2] = [1.0 - prob, prob];
let dist: WeightedIndex<f64> = WeightedIndex::new(weights).unwrap();
self.at_mut(col, 0)
.iter_mut()
.for_each(|x: &mut i64| *x = choices[dist.sample(source)]);
}
pub fn fill_binary_hw(&mut self, col: usize, hw: usize, source: &mut Source) {
assert!(hw <= self.n());
self.at_mut(col, 0)[..hw]
.iter_mut()
.for_each(|x: &mut i64| *x = (source.next_u32() & 1) as i64);
self.at_mut(col, 0).shuffle(source);
}
pub fn fill_binary_block(&mut self, col: usize, block_size: usize, source: &mut Source) {
assert!(self.n().is_multiple_of(block_size));
let max_idx: u64 = (block_size + 1) as u64;
let mask_idx: u64 = (1 << ((u64::BITS - max_idx.leading_zeros()) as u64)) - 1;
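// Draw an index in [0, block_size] (uniform, assuming `next_u64n` rejects
// values >= its bound): with probability 1/(block_size+1) the block stays
// all-zero, otherwise exactly one coefficient of the block is set to 1.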
for block in self.at_mut(col, 0).chunks_mut(block_size) {
let idx: usize = source.next_u64n(max_idx, mask_idx) as usize;
if idx != block_size {
block[idx] = 1;
}
}
}
}
impl ScalarZnx<Vec<u8>> {
pub fn alloc_bytes(n: usize, cols: usize) -> usize {
n * cols * size_of::<i64>()
}
pub fn alloc(n: usize, cols: usize) -> Self {
let data: Vec<u8> = alloc_aligned::<u8>(Self::alloc_bytes(n, cols));
Self { data, n, cols }
}
pub fn from_bytes(n: usize, cols: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::alloc_bytes(n, cols));
Self { data, n, cols }
}
}
impl<D: DataMut> ZnxZero for ScalarZnx<D> {
fn zero(&mut self) {
self.raw_mut().fill(0)
}
fn zero_at(&mut self, i: usize, j: usize) {
self.at_mut(i, j).fill(0);
}
}
impl<D: DataMut> FillUniform for ScalarZnx<D> {
fn fill_uniform(&mut self, source: &mut Source) {
source.fill_bytes(self.data.as_mut());
}
}
impl<D: DataMut> Reset for ScalarZnx<D> {
fn reset(&mut self) {
self.zero();
self.n = 0;
self.cols = 0;
}
}
pub type ScalarZnxOwned = ScalarZnx<Vec<u8>>;
impl<D: Data> ScalarZnx<D> {
pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self {
Self { data, n, cols }
}
}
pub trait ScalarZnxToRef {
fn to_ref(&self) -> ScalarZnx<&[u8]>;
}
impl<D: DataRef> ScalarZnxToRef for ScalarZnx<D> {
fn to_ref(&self) -> ScalarZnx<&[u8]> {
ScalarZnx {
data: self.data.as_ref(),
n: self.n,
cols: self.cols,
}
}
}
pub trait ScalarZnxToMut {
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]>;
}
impl<D: DataMut> ScalarZnxToMut for ScalarZnx<D> {
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
ScalarZnx {
data: self.data.as_mut(),
n: self.n,
cols: self.cols,
}
}
}
impl<D: DataRef> ScalarZnx<D> {
pub fn as_vec_znx(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data.as_ref(),
n: self.n,
cols: self.cols,
size: 1,
max_size: 1,
}
}
}
impl<D: DataMut> ScalarZnx<D> {
pub fn as_vec_znx_mut(&mut self) -> VecZnx<&mut [u8]> {
VecZnx {
data: self.data.as_mut(),
n: self.n,
cols: self.cols,
size: 1,
max_size: 1,
}
}
}
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
impl<D: DataMut> ReaderFrom for ScalarZnx<D> {
fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
self.n = reader.read_u64::<LittleEndian>()? as usize;
self.cols = reader.read_u64::<LittleEndian>()? as usize;
let len: usize = reader.read_u64::<LittleEndian>()? as usize;
let buf: &mut [u8] = self.data.as_mut();
if buf.len() != len {
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
format!("self.data.len()={} != read len={}", buf.len(), len),
));
}
reader.read_exact(&mut buf[..len])?;
Ok(())
}
}
impl<D: DataRef> WriterTo for ScalarZnx<D> {
fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
writer.write_u64::<LittleEndian>(self.n as u64)?;
writer.write_u64::<LittleEndian>(self.cols as u64)?;
let buf: &[u8] = self.data.as_ref();
writer.write_u64::<LittleEndian>(buf.len() as u64)?;
writer.write_all(buf)?;
Ok(())
}
}

View File

@@ -0,0 +1,13 @@
use std::marker::PhantomData;
use crate::hal::layouts::Backend;
pub struct ScratchOwned<B: Backend> {
pub(crate) data: Vec<u8>,
pub(crate) _phantom: PhantomData<B>,
}
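/// Borrowed scratch space: a dynamically sized type over a raw byte buffer,
/// only ever handed out as `&mut Scratch<B>`, carved from a `ScratchOwned<B>`
/// or re-interpreted from caller-provided bytes.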
pub struct Scratch<B: Backend> {
pub(crate) _phantom: PhantomData<B>,
pub(crate) data: [u8],
}

View File

@@ -0,0 +1,9 @@
use std::io::{Read, Result, Write};
pub trait WriterTo {
fn write_to<W: Write>(&self, writer: &mut W) -> Result<()>;
}
pub trait ReaderFrom {
fn read_from<R: Read>(&mut self, reader: &mut R) -> Result<()>;
}

View File

@@ -0,0 +1,32 @@
use rug::{
Float,
float::Round,
ops::{AddAssignRound, DivAssignRound, SubAssignRound},
};
use crate::hal::{
api::ZnxInfos,
layouts::{DataRef, VecZnx},
};
impl<D: DataRef> VecZnx<D> {
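/// Returns the empirical (population) standard deviation of the decoded
/// column `col`, computed with `size * basek` bits of float precision.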
pub fn std(&self, basek: usize, col: usize) -> f64 {
let prec: u32 = (self.size() * basek) as u32;
let mut data: Vec<Float> = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect();
self.decode_vec_float(basek, col, &mut data);
// std = sqrt(sum((xi - avg)^2) / n)
let mut avg: Float = Float::with_val(prec, 0);
data.iter().for_each(|x| {
avg.add_assign_round(x, Round::Nearest);
});
avg.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
data.iter_mut().for_each(|x| {
x.sub_assign_round(&avg, Round::Nearest);
});
let mut std: Float = Float::with_val(prec, 0);
data.iter().for_each(|x| std += x * x);
std.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
std = std.sqrt();
std.to_f64()
}
}

View File

@@ -0,0 +1,151 @@
use std::marker::PhantomData;
use crate::{
alloc_aligned,
hal::{
api::{DataView, DataViewMut, ZnxInfos},
layouts::{Backend, Data, DataMut, DataRef, ReaderFrom, WriterTo},
},
};
#[derive(PartialEq, Eq)]
pub struct SvpPPol<D: Data, B: Backend> {
data: D,
n: usize,
cols: usize,
_phantom: PhantomData<B>,
}
impl<D: Data, B: Backend> ZnxInfos for SvpPPol<D, B> {
fn cols(&self) -> usize {
self.cols
}
fn rows(&self) -> usize {
1
}
fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
1
}
}
impl<D: Data, B: Backend> DataView for SvpPPol<D, B> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D: Data, B: Backend> DataViewMut for SvpPPol<D, B> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
pub trait SvpPPolBytesOf {
fn bytes_of(n: usize, cols: usize) -> usize;
}
impl<D: Data + From<Vec<u8>>, B: Backend> SvpPPol<D, B>
where
SvpPPol<D, B>: SvpPPolBytesOf,
{
pub(crate) fn alloc(n: usize, cols: usize) -> Self {
let data: Vec<u8> = alloc_aligned::<u8>(Self::bytes_of(n, cols));
Self {
data: data.into(),
n,
cols,
_phantom: PhantomData,
}
}
pub(crate) fn from_bytes(n: usize, cols: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::bytes_of(n, cols));
Self {
data: data.into(),
n,
cols,
_phantom: PhantomData,
}
}
}
pub type SvpPPolOwned<B> = SvpPPol<Vec<u8>, B>;
pub trait SvpPPolToRef<B: Backend> {
fn to_ref(&self) -> SvpPPol<&[u8], B>;
}
impl<D: DataRef, B: Backend> SvpPPolToRef<B> for SvpPPol<D, B> {
fn to_ref(&self) -> SvpPPol<&[u8], B> {
SvpPPol {
data: self.data.as_ref(),
n: self.n,
cols: self.cols,
_phantom: PhantomData,
}
}
}
pub trait SvpPPolToMut<B: Backend> {
fn to_mut(&mut self) -> SvpPPol<&mut [u8], B>;
}
impl<D: DataMut, B: Backend> SvpPPolToMut<B> for SvpPPol<D, B> {
fn to_mut(&mut self) -> SvpPPol<&mut [u8], B> {
SvpPPol {
data: self.data.as_mut(),
n: self.n,
cols: self.cols,
_phantom: PhantomData,
}
}
}
impl<D: Data, B: Backend> SvpPPol<D, B> {
pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self {
Self {
data,
n,
cols,
_phantom: PhantomData,
}
}
}
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
impl<D: DataMut, B: Backend> ReaderFrom for SvpPPol<D, B> {
fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
self.n = reader.read_u64::<LittleEndian>()? as usize;
self.cols = reader.read_u64::<LittleEndian>()? as usize;
let len: usize = reader.read_u64::<LittleEndian>()? as usize;
let buf: &mut [u8] = self.data.as_mut();
if buf.len() != len {
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
format!("self.data.len()={} != read len={}", buf.len(), len),
));
}
reader.read_exact(&mut buf[..len])?;
Ok(())
}
}
impl<D: DataRef, B: Backend> WriterTo for SvpPPol<D, B> {
fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
writer.write_u64::<LittleEndian>(self.n as u64)?;
writer.write_u64::<LittleEndian>(self.cols as u64)?;
let buf: &[u8] = self.data.as_ref();
writer.write_u64::<LittleEndian>(buf.len() as u64)?;
writer.write_all(buf)?;
Ok(())
}
}

View File

@@ -0,0 +1,257 @@
use std::fmt;
use crate::{
alloc_aligned,
hal::{
api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
layouts::{Data, DataMut, DataRef, ReaderFrom, ToOwnedDeep, WriterTo},
source::Source,
},
};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use rand::RngCore;
#[derive(PartialEq, Eq, Clone, Copy)]
pub struct VecZnx<D: Data> {
pub(crate) data: D,
pub(crate) n: usize,
pub(crate) cols: usize,
pub(crate) size: usize,
pub(crate) max_size: usize,
}
impl<D: DataRef> ToOwnedDeep for VecZnx<D> {
type Owned = VecZnx<Vec<u8>>;
fn to_owned_deep(&self) -> Self::Owned {
VecZnx {
data: self.data.as_ref().to_vec(),
n: self.n,
cols: self.cols,
size: self.size,
max_size: self.max_size,
}
}
}
impl<D: DataRef> fmt::Debug for VecZnx<D> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self)
}
}
impl<D: Data> ZnxInfos for VecZnx<D> {
fn cols(&self) -> usize {
self.cols
}
fn rows(&self) -> usize {
1
}
fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
self.size
}
}
impl<D: Data> ZnxSliceSize for VecZnx<D> {
fn sl(&self) -> usize {
self.n() * self.cols()
}
}
impl<D: Data> DataView for VecZnx<D> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D: Data> DataViewMut for VecZnx<D> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
impl<D: DataRef> ZnxView for VecZnx<D> {
type Scalar = i64;
}
impl VecZnx<Vec<u8>> {
pub fn rsh_scratch_space(n: usize) -> usize {
n * std::mem::size_of::<i64>()
}
}
impl<D: DataMut> ZnxZero for VecZnx<D> {
fn zero(&mut self) {
self.raw_mut().fill(0)
}
fn zero_at(&mut self, i: usize, j: usize) {
self.at_mut(i, j).fill(0);
}
}
impl VecZnx<Vec<u8>> {
pub fn alloc_bytes(n: usize, cols: usize, size: usize) -> usize {
n * cols * size * size_of::<i64>()
}
pub fn alloc(n: usize, cols: usize, size: usize) -> Self {
let data: Vec<u8> = alloc_aligned::<u8>(Self::alloc_bytes(n, cols, size));
Self {
data,
n,
cols,
size,
max_size: size,
}
}
pub fn from_bytes<Scalar: Sized>(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::alloc_bytes(n, cols, size));
Self {
data,
n,
cols,
size,
max_size: size,
}
}
}
impl<D: Data> VecZnx<D> {
pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
Self {
data,
n,
cols,
size,
max_size: size,
}
}
}
impl<D: DataRef> fmt::Display for VecZnx<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"VecZnx(n={}, cols={}, size={})",
self.n, self.cols, self.size
)?;
for col in 0..self.cols {
writeln!(f, "Column {}:", col)?;
for size in 0..self.size {
let coeffs = self.at(col, size);
write!(f, " Size {}: [", size)?;
let max_show = 100;
let show_count = coeffs.len().min(max_show);
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", coeff)?;
}
if coeffs.len() > max_show {
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
}
writeln!(f, "]")?;
}
}
Ok(())
}
}
impl<D: DataMut> FillUniform for VecZnx<D> {
fn fill_uniform(&mut self, source: &mut Source) {
source.fill_bytes(self.data.as_mut());
}
}
impl<D: DataMut> Reset for VecZnx<D> {
fn reset(&mut self) {
self.zero();
self.n = 0;
self.cols = 0;
self.size = 0;
self.max_size = 0;
}
}
pub type VecZnxOwned = VecZnx<Vec<u8>>;
pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>;
pub type VecZnxRef<'a> = VecZnx<&'a [u8]>;
pub trait VecZnxToRef {
fn to_ref(&self) -> VecZnx<&[u8]>;
}
impl<D: DataRef> VecZnxToRef for VecZnx<D> {
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data.as_ref(),
n: self.n,
cols: self.cols,
size: self.size,
max_size: self.max_size,
}
}
}
pub trait VecZnxToMut {
fn to_mut(&mut self) -> VecZnx<&mut [u8]>;
}
impl<D: DataMut> VecZnxToMut for VecZnx<D> {
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
VecZnx {
data: self.data.as_mut(),
n: self.n,
cols: self.cols,
size: self.size,
max_size: self.max_size,
}
}
}
impl<D: DataMut> ReaderFrom for VecZnx<D> {
fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
self.n = reader.read_u64::<LittleEndian>()? as usize;
self.cols = reader.read_u64::<LittleEndian>()? as usize;
self.size = reader.read_u64::<LittleEndian>()? as usize;
self.max_size = reader.read_u64::<LittleEndian>()? as usize;
let len: usize = reader.read_u64::<LittleEndian>()? as usize;
let buf: &mut [u8] = self.data.as_mut();
if buf.len() != len {
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
format!("self.data.len()={} != read len={}", buf.len(), len),
));
}
reader.read_exact(&mut buf[..len])?;
Ok(())
}
}
impl<D: DataRef> WriterTo for VecZnx<D> {
fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
writer.write_u64::<LittleEndian>(self.n as u64)?;
writer.write_u64::<LittleEndian>(self.cols as u64)?;
writer.write_u64::<LittleEndian>(self.size as u64)?;
writer.write_u64::<LittleEndian>(self.max_size as u64)?;
let buf: &[u8] = self.data.as_ref();
writer.write_u64::<LittleEndian>(buf.len() as u64)?;
writer.write_all(buf)?;
Ok(())
}
}

View File

@@ -0,0 +1,148 @@
use std::marker::PhantomData;
use rand_distr::num_traits::Zero;
use crate::{
alloc_aligned,
hal::{
api::{DataView, DataViewMut, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero},
layouts::{Backend, Data, DataMut, DataRef},
},
};
#[derive(PartialEq, Eq)]
pub struct VecZnxBig<D: Data, B: Backend> {
pub(crate) data: D,
pub(crate) n: usize,
pub(crate) cols: usize,
pub(crate) size: usize,
pub(crate) max_size: usize,
pub(crate) _phantom: PhantomData<B>,
}
impl<D: Data, B: Backend> ZnxInfos for VecZnxBig<D, B> {
fn cols(&self) -> usize {
self.cols
}
fn rows(&self) -> usize {
1
}
fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
self.size
}
}
impl<D: Data, B: Backend> DataView for VecZnxBig<D, B> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D: Data, B: Backend> DataViewMut for VecZnxBig<D, B> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
pub trait VecZnxBigBytesOf {
fn bytes_of(n: usize, cols: usize, size: usize) -> usize;
}
impl<D: DataMut, B: Backend> ZnxZero for VecZnxBig<D, B>
where
Self: ZnxViewMut,
<Self as ZnxView>::Scalar: Zero + Copy,
{
fn zero(&mut self) {
self.raw_mut().fill(<Self as ZnxView>::Scalar::zero())
}
fn zero_at(&mut self, i: usize, j: usize) {
self.at_mut(i, j).fill(<Self as ZnxView>::Scalar::zero());
}
}
impl<D: DataRef + From<Vec<u8>>, B: Backend> VecZnxBig<D, B>
where
VecZnxBig<D, B>: VecZnxBigBytesOf,
{
pub(crate) fn new(n: usize, cols: usize, size: usize) -> Self {
let data = alloc_aligned::<u8>(Self::bytes_of(n, cols, size));
Self {
data: data.into(),
n,
cols,
size,
max_size: size,
_phantom: PhantomData,
}
}
pub(crate) fn new_from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::bytes_of(n, cols, size));
Self {
data: data.into(),
n,
cols,
size,
max_size: size,
_phantom: PhantomData,
}
}
}
impl<D: Data, B: Backend> VecZnxBig<D, B> {
pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
Self {
data,
n,
cols,
size,
max_size: size,
_phantom: PhantomData,
}
}
}
pub type VecZnxBigOwned<B> = VecZnxBig<Vec<u8>, B>;
pub trait VecZnxBigToRef<B: Backend> {
fn to_ref(&self) -> VecZnxBig<&[u8], B>;
}
impl<D: DataRef, B: Backend> VecZnxBigToRef<B> for VecZnxBig<D, B> {
fn to_ref(&self) -> VecZnxBig<&[u8], B> {
VecZnxBig {
data: self.data.as_ref(),
n: self.n,
cols: self.cols,
size: self.size,
max_size: self.max_size,
_phantom: std::marker::PhantomData,
}
}
}
pub trait VecZnxBigToMut<B: Backend> {
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B>;
}
impl<D: DataMut, B: Backend> VecZnxBigToMut<B> for VecZnxBig<D, B> {
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
VecZnxBig {
data: self.data.as_mut(),
n: self.n,
cols: self.cols,
size: self.size,
max_size: self.max_size,
_phantom: std::marker::PhantomData,
}
}
}

View File

@@ -0,0 +1,166 @@
use std::marker::PhantomData;
use rand_distr::num_traits::Zero;
use crate::{
alloc_aligned,
hal::{
api::{DataView, DataViewMut, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero},
layouts::{Backend, Data, DataMut, DataRef, VecZnxBig},
},
};
#[derive(PartialEq, Eq)]
pub struct VecZnxDft<D: Data, B: Backend> {
pub(crate) data: D,
pub(crate) n: usize,
pub(crate) cols: usize,
pub(crate) size: usize,
pub(crate) max_size: usize,
pub(crate) _phantom: PhantomData<B>,
}
impl<D: Data, B: Backend> VecZnxDft<D, B> {
pub fn into_big(self) -> VecZnxBig<D, B> {
VecZnxBig::<D, B>::from_data(self.data, self.n, self.cols, self.size)
}
}
impl<D: Data, B: Backend> ZnxInfos for VecZnxDft<D, B> {
fn cols(&self) -> usize {
self.cols
}
fn rows(&self) -> usize {
1
}
fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
self.size
}
}
impl<D: Data, B: Backend> DataView for VecZnxDft<D, B> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D: Data, B: Backend> DataViewMut for VecZnxDft<D, B> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
impl<D: DataRef, B: Backend> VecZnxDft<D, B> {
pub fn max_size(&self) -> usize {
self.max_size
}
}
impl<D: DataMut, B: Backend> VecZnxDft<D, B> {
pub fn set_size(&mut self, size: usize) {
assert!(size <= self.max_size);
self.size = size
}
}
impl<D: DataMut, B: Backend> ZnxZero for VecZnxDft<D, B>
where
Self: ZnxViewMut,
<Self as ZnxView>::Scalar: Zero + Copy,
{
fn zero(&mut self) {
self.raw_mut().fill(<Self as ZnxView>::Scalar::zero())
}
fn zero_at(&mut self, i: usize, j: usize) {
self.at_mut(i, j).fill(<Self as ZnxView>::Scalar::zero());
}
}
pub trait VecZnxDftBytesOf {
fn bytes_of(n: usize, cols: usize, size: usize) -> usize;
}
impl<D: DataRef + From<Vec<u8>>, B: Backend> VecZnxDft<D, B>
where
VecZnxDft<D, B>: VecZnxDftBytesOf,
{
pub(crate) fn alloc(n: usize, cols: usize, size: usize) -> Self {
let data: Vec<u8> = alloc_aligned::<u8>(Self::bytes_of(n, cols, size));
Self {
data: data.into(),
n,
cols,
size,
max_size: size,
_phantom: PhantomData,
}
}
pub(crate) fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::bytes_of(n, cols, size));
Self {
data: data.into(),
n,
cols,
size,
max_size: size,
_phantom: PhantomData,
}
}
}
pub type VecZnxDftOwned<B> = VecZnxDft<Vec<u8>, B>;
impl<D: Data, B: Backend> VecZnxDft<D, B> {
pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
Self {
data,
n,
cols,
size,
max_size: size,
_phantom: PhantomData,
}
}
}
pub trait VecZnxDftToRef<B: Backend> {
fn to_ref(&self) -> VecZnxDft<&[u8], B>;
}
impl<D: DataRef, B: Backend> VecZnxDftToRef<B> for VecZnxDft<D, B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data.as_ref(),
n: self.n,
cols: self.cols,
size: self.size,
max_size: self.max_size,
_phantom: std::marker::PhantomData,
}
}
}
pub trait VecZnxDftToMut<B: Backend> {
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B>;
}
impl<D: DataMut, B: Backend> VecZnxDftToMut<B> for VecZnxDft<D, B> {
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
VecZnxDft {
data: self.data.as_mut(),
n: self.n,
cols: self.cols,
size: self.size,
max_size: self.max_size,
_phantom: std::marker::PhantomData,
}
}
}

View File

@@ -0,0 +1,157 @@
use std::marker::PhantomData;
use crate::{
alloc_aligned,
hal::{
api::{DataView, DataViewMut, ZnxInfos},
layouts::{Backend, Data, DataMut, DataRef},
},
};
#[derive(PartialEq, Eq)]
pub struct VmpPMat<D: Data, B: Backend> {
data: D,
n: usize,
size: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
_phantom: PhantomData<B>,
}
impl<D: Data, B: Backend> ZnxInfos for VmpPMat<D, B> {
fn cols(&self) -> usize {
self.cols_in
}
fn rows(&self) -> usize {
self.rows
}
fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
self.size
}
}
impl<D: Data, B: Backend> DataView for VmpPMat<D, B> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D: Data, B: Backend> DataViewMut for VmpPMat<D, B> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
impl<D: Data, B: Backend> VmpPMat<D, B> {
pub fn cols_in(&self) -> usize {
self.cols_in
}
pub fn cols_out(&self) -> usize {
self.cols_out
}
}
pub trait VmpPMatBytesOf {
fn vmp_pmat_bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
}
impl<D: DataRef + From<Vec<u8>>, B: Backend> VmpPMat<D, B>
where
B: VmpPMatBytesOf,
{
pub(crate) fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
let data: Vec<u8> = alloc_aligned(B::vmp_pmat_bytes_of(n, rows, cols_in, cols_out, size));
Self {
data: data.into(),
n,
size,
rows,
cols_in,
cols_out,
_phantom: PhantomData,
}
}
pub(crate) fn from_bytes(
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: impl Into<Vec<u8>>,
) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == B::vmp_pmat_bytes_of(n, rows, cols_in, cols_out, size));
Self {
data: data.into(),
n,
size,
rows,
cols_in,
cols_out,
_phantom: PhantomData,
}
}
}
pub type VmpPMatOwned<B> = VmpPMat<Vec<u8>, B>;
pub type VmpPMatRef<'a, B> = VmpPMat<&'a [u8], B>;
pub trait VmpPMatToRef<B: Backend> {
fn to_ref(&self) -> VmpPMat<&[u8], B>;
}
impl<D: DataRef, B: Backend> VmpPMatToRef<B> for VmpPMat<D, B> {
fn to_ref(&self) -> VmpPMat<&[u8], B> {
VmpPMat {
data: self.data.as_ref(),
n: self.n,
rows: self.rows,
cols_in: self.cols_in,
cols_out: self.cols_out,
size: self.size,
_phantom: std::marker::PhantomData,
}
}
}
pub trait VmpPMatToMut<B: Backend> {
fn to_mut(&mut self) -> VmpPMat<&mut [u8], B>;
}
impl<D: DataMut, B: Backend> VmpPMatToMut<B> for VmpPMat<D, B> {
fn to_mut(&mut self) -> VmpPMat<&mut [u8], B> {
VmpPMat {
data: self.data.as_mut(),
n: self.n,
rows: self.rows,
cols_in: self.cols_in,
cols_out: self.cols_out,
size: self.size,
_phantom: std::marker::PhantomData,
}
}
}
impl<D: Data, B: Backend> VmpPMat<D, B> {
pub(crate) fn from_data(data: D, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
Self {
data,
n,
rows,
cols_in,
cols_out,
size,
_phantom: PhantomData,
}
}
}

View File

@@ -0,0 +1,6 @@
pub mod api;
pub mod delegates;
pub mod layouts;
pub mod oep;
pub mod source;
pub mod tests;

View File

@@ -0,0 +1,15 @@
mod module;
mod scratch;
mod svp_ppol;
mod vec_znx;
mod vec_znx_big;
mod vec_znx_dft;
mod vmp_pmat;
pub use module::*;
pub use scratch::*;
pub use svp_ppol::*;
pub use vec_znx::*;
pub use vec_znx_big::*;
pub use vec_znx_dft::*;
pub use vmp_pmat::*;

View File

@@ -0,0 +1,9 @@
use crate::hal::layouts::{Backend, Module};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety
/// See [crate::doc::backend_safety] for the safety contract.
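/// A minimal wiring sketch, assuming a hypothetical `MyCpuBackend` that
/// already implements [`Backend`] (name and body are illustrative, not part
/// of the crate):
///
/// ```ignore
/// pub struct MyCpuBackend;
///
/// unsafe impl ModuleNewImpl<MyCpuBackend> for MyCpuBackend {
///     fn new_impl(n: u64) -> Module<MyCpuBackend> {
///         // Allocate the backend-specific handle for ring degree `n`
///         // (elided), then wrap it; the handle must outlive the Module.
///         let handle: *mut <MyCpuBackend as Backend>::Handle = todo!();
///         unsafe { Module::from_raw_parts(handle, n) }
///     }
/// }
/// ```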
pub unsafe trait ModuleNewImpl<B: Backend> {
fn new_impl(n: u64) -> Module<B>;
}

View File

@@ -0,0 +1,259 @@
use crate::hal::{
api::ZnxInfos,
layouts::{Backend, DataRef, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety
/// See [crate::doc::backend_safety] for the safety contract.
pub unsafe trait ScratchOwnedAllocImpl<B: Backend> {
fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety
/// See [crate::doc::backend_safety] for the safety contract.
pub unsafe trait ScratchOwnedBorrowImpl<B: Backend> {
fn scratch_owned_borrow_impl(scratch: &mut ScratchOwned<B>) -> &mut Scratch<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety
/// See [crate::doc::backend_safety] for the safety contract.
pub unsafe trait ScratchFromBytesImpl<B: Backend> {
fn scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety
/// See [crate::doc::backend_safety] for the safety contract.
pub unsafe trait ScratchAvailableImpl<B: Backend> {
fn scratch_available_impl(scratch: &Scratch<B>) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety
/// See [crate::doc::backend_safety] for the safety contract.
pub unsafe trait TakeSliceImpl<B: Backend> {
fn take_slice_impl<T>(scratch: &mut Scratch<B>, len: usize) -> (&mut [T], &mut Scratch<B>);
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety
/// See [crate::doc::backend_safety] for the safety contract.
pub unsafe trait TakeScalarZnxImpl<B: Backend> {
fn take_scalar_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch<B>);
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety
/// See [crate::doc::backend_safety] for the safety contract.
pub unsafe trait TakeSvpPPolImpl<B: Backend> {
fn take_svp_ppol_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch<B>);
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety
/// See [crate::doc::backend_safety] for the safety contract.
pub unsafe trait TakeVecZnxImpl<B: Backend> {
fn take_vec_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch<B>);
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety
/// See [crate::doc::backend_safety] for the safety contract.
pub unsafe trait TakeVecZnxSliceImpl<B: Backend> {
fn take_vec_znx_slice_impl(
scratch: &mut Scratch<B>,
len: usize,
n: usize,
cols: usize,
size: usize,
) -> (Vec<VecZnx<&mut [u8]>>, &mut Scratch<B>);
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety
/// See [crate::doc::backend_safety] for the safety contract.
pub unsafe trait TakeVecZnxBigImpl<B: Backend> {
fn take_vec_znx_big_impl(
scratch: &mut Scratch<B>,
n: usize,
cols: usize,
size: usize,
) -> (VecZnxBig<&mut [u8], B>, &mut Scratch<B>);
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait TakeVecZnxDftImpl<B: Backend> {
fn take_vec_znx_dft_impl(
scratch: &mut Scratch<B>,
n: usize,
cols: usize,
size: usize,
) -> (VecZnxDft<&mut [u8], B>, &mut Scratch<B>);
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait TakeVecZnxDftSliceImpl<B: Backend> {
fn take_vec_znx_dft_slice_impl(
scratch: &mut Scratch<B>,
len: usize,
n: usize,
cols: usize,
size: usize,
) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Scratch<B>);
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait TakeVmpPMatImpl<B: Backend> {
fn take_vmp_pmat_impl(
scratch: &mut Scratch<B>,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
) -> (VmpPMat<&mut [u8], B>, &mut Scratch<B>);
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait TakeMatZnxImpl<B: Backend> {
fn take_mat_znx_impl(
scratch: &mut Scratch<B>,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
) -> (MatZnx<&mut [u8]>, &mut Scratch<B>);
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub trait TakeLikeImpl<'a, B: Backend, T> {
type Output;
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &T) -> (Self::Output, &'a mut Scratch<B>);
}
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VmpPMat<D, B>> for B
where
B: TakeVmpPMatImpl<B>,
D: DataRef,
{
type Output = VmpPMat<&'a mut [u8], B>;
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &VmpPMat<D, B>) -> (Self::Output, &'a mut Scratch<B>) {
B::take_vmp_pmat_impl(
scratch,
template.n(),
template.rows(),
template.cols_in(),
template.cols_out(),
template.size(),
)
}
}
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, MatZnx<D>> for B
where
B: TakeMatZnxImpl<B>,
D: DataRef,
{
type Output = MatZnx<&'a mut [u8]>;
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &MatZnx<D>) -> (Self::Output, &'a mut Scratch<B>) {
B::take_mat_znx_impl(
scratch,
template.n(),
template.rows(),
template.cols_in(),
template.cols_out(),
template.size(),
)
}
}
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VecZnxDft<D, B>> for B
where
B: TakeVecZnxDftImpl<B>,
D: DataRef,
{
type Output = VecZnxDft<&'a mut [u8], B>;
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &VecZnxDft<D, B>) -> (Self::Output, &'a mut Scratch<B>) {
B::take_vec_znx_dft_impl(scratch, template.n(), template.cols(), template.size())
}
}
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VecZnxBig<D, B>> for B
where
B: TakeVecZnxBigImpl<B>,
D: DataRef,
{
type Output = VecZnxBig<&'a mut [u8], B>;
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &VecZnxBig<D, B>) -> (Self::Output, &'a mut Scratch<B>) {
B::take_vec_znx_big_impl(scratch, template.n(), template.cols(), template.size())
}
}
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, SvpPPol<D, B>> for B
where
B: TakeSvpPPolImpl<B>,
D: DataRef,
{
type Output = SvpPPol<&'a mut [u8], B>;
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &SvpPPol<D, B>) -> (Self::Output, &'a mut Scratch<B>) {
B::take_svp_ppol_impl(scratch, template.n(), template.cols())
}
}
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VecZnx<D>> for B
where
B: TakeVecZnxImpl<B>,
D: DataRef,
{
type Output = VecZnx<&'a mut [u8]>;
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &VecZnx<D>) -> (Self::Output, &'a mut Scratch<B>) {
B::take_vec_znx_impl(scratch, template.n(), template.cols(), template.size())
}
}
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, ScalarZnx<D>> for B
where
B: TakeScalarZnxImpl<B>,
D: DataRef,
{
type Output = ScalarZnx<&'a mut [u8]>;
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &ScalarZnx<D>) -> (Self::Output, &'a mut Scratch<B>) {
B::take_scalar_znx_impl(scratch, template.n(), template.cols())
}
}
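// Illustrative sketch: how `TakeLikeImpl` is meant to be driven — borrow
// scratch shaped like an existing value, then continue with the remaining
// scratch. The helper name is hypothetical; the traits and layout types are
// the ones defined above.
#[allow(dead_code)]
fn take_like_example<'a, B, T>(
    scratch: &'a mut Scratch<B>,
    template: &T,
) -> (<B as TakeLikeImpl<'a, B, T>>::Output, &'a mut Scratch<B>)
where
    B: Backend + TakeLikeImpl<'a, B, T>,
{
    // Delegates to the backend's shape-copying scratch borrow.
    B::take_like_impl(scratch, template)
}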

View File

@@ -0,0 +1,61 @@
use crate::hal::layouts::{
Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef,
};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait SvpPPolFromBytesImpl<B: Backend> {
fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait SvpPPolAllocImpl<B: Backend> {
fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait SvpPPolAllocBytesImpl<B: Backend> {
fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait SvpPrepareImpl<B: Backend> {
fn svp_prepare_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: SvpPPolToMut<B>,
A: ScalarZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait SvpApplyImpl<B: Backend> {
fn svp_apply_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: SvpPPolToRef<B>,
C: VecZnxDftToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait SvpApplyInplaceImpl: Backend {
fn svp_apply_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: SvpPPolToRef<Self>;
}
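// Illustrative sketch: chaining `svp_prepare` and `svp_apply` through a
// generic backend `B` — preprocess a scalar polynomial once, then multiply it
// into a DFT-domain vector. The helper name and the single shared `col` index
// are hypothetical.
#[allow(dead_code)]
fn svp_prepare_then_apply<B, P, S, R, A>(
    module: &Module<B>,
    ppol: &mut P,
    scalar: &S,
    res: &mut R,
    a: &A,
    col: usize,
) where
    B: Backend + SvpPrepareImpl<B> + SvpApplyImpl<B>,
    P: SvpPPolToMut<B> + SvpPPolToRef<B>,
    S: ScalarZnxToRef,
    R: VecZnxDftToMut<B>,
    A: VecZnxDftToRef<B>,
{
    // One-time preprocessing into the backend's SVP representation.
    B::svp_prepare_impl(module, ppol, col, scalar, col);
    // res[col] = ppol[col] * a[col] in the DFT domain.
    B::svp_apply_impl(module, res, col, &*ppol, col, a, col);
}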

View File

@@ -0,0 +1,365 @@
use rand_distr::Distribution;
use crate::hal::{
layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
source::Source,
};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_normalize_base2k_tmp_bytes_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L245C17-L245C55) for reference code.
/// * See [crate::hal::api::VecZnxNormalizeTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxNormalizeTmpBytesImpl<B: Backend> {
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<B>, n: usize) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L212) for reference code.
/// * See [crate::hal::api::VecZnxNormalize] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxNormalizeImpl<B: Backend> {
fn vec_znx_normalize_impl<R, A>(
module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L212) for reference code.
/// * See [crate::hal::api::VecZnxNormalizeInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxNormalizeInplaceImpl<B: Backend> {
fn vec_znx_normalize_inplace_impl<A>(module: &Module<B>, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code.
/// * See [crate::hal::api::VecZnxAdd] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxAddImpl<B: Backend> {
fn vec_znx_add_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
C: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code.
/// * See [crate::hal::api::VecZnxAddInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxAddInplaceImpl<B: Backend> {
fn vec_znx_add_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code.
/// * See [crate::hal::api::VecZnxAddScalarInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxAddScalarInplaceImpl<B: Backend> {
fn vec_znx_add_scalar_inplace_impl<R, A>(
module: &Module<B>,
res: &mut R,
res_col: usize,
res_limb: usize,
a: &A,
a_col: usize,
) where
R: VecZnxToMut,
A: ScalarZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code.
/// * See [crate::hal::api::VecZnxSub] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxSubImpl<B: Backend> {
fn vec_znx_sub_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
C: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code.
/// * See [crate::hal::api::VecZnxSubABInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxSubABInplaceImpl<B: Backend> {
fn vec_znx_sub_ab_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code.
/// * See [crate::hal::api::VecZnxSubBAInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxSubBAInplaceImpl<B: Backend> {
fn vec_znx_sub_ba_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code.
/// * See [crate::hal::api::VecZnxSubScalarInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxSubScalarInplaceImpl<B: Backend> {
fn vec_znx_sub_scalar_inplace_impl<R, A>(
module: &Module<B>,
res: &mut R,
res_col: usize,
res_limb: usize,
a: &A,
a_col: usize,
) where
R: VecZnxToMut,
A: ScalarZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_negate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L322C13-L322C31) for reference code.
/// * See [crate::hal::api::VecZnxNegate] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxNegateImpl<B: Backend> {
fn vec_znx_negate_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_negate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L322C13-L322C31) for reference code.
/// * See [crate::hal::api::VecZnxNegateInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxNegateInplaceImpl<B: Backend> {
fn vec_znx_negate_inplace_impl<A>(module: &Module<B>, a: &mut A, a_col: usize)
where
A: VecZnxToMut;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_rsh_inplace_ref] for reference code.
/// * See [crate::hal::api::VecZnxRshInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxRshInplaceImpl<B: Backend> {
fn vec_znx_rsh_inplace_impl<A>(module: &Module<B>, basek: usize, k: usize, a: &mut A)
where
A: VecZnxToMut;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_lsh_inplace_ref] for reference code.
/// * See [crate::hal::api::VecZnxLshInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxLshInplaceImpl<B: Backend> {
fn vec_znx_lsh_inplace_impl<A>(module: &Module<B>, basek: usize, k: usize, a: &mut A)
where
A: VecZnxToMut;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_rotate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L164) for reference code.
/// * See [crate::hal::api::VecZnxRotate] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxRotateImpl<B: Backend> {
fn vec_znx_rotate_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_rotate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L164) for reference code.
/// * See [crate::hal::api::VecZnxRotateInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxRotateInplaceImpl<B: Backend> {
fn vec_znx_rotate_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxToMut;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_automorphism_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L188) for reference code.
/// * See [crate::hal::api::VecZnxAutomorphism] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxAutomorphismImpl<B: Backend> {
fn vec_znx_automorphism_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_automorphism_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L188) for reference code.
/// * See [crate::hal::api::VecZnxAutomorphismInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxAutomorphismInplaceImpl<B: Backend> {
fn vec_znx_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxToMut;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_mul_xp_minus_one_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/7160f588da49712a042931ea247b4259b95cefcc/spqlios/arithmetic/vec_znx.c#L200C13-L200C41) for reference code.
/// * See [crate::hal::api::VecZnxMulXpMinusOne] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxMulXpMinusOneImpl<B: Backend> {
fn vec_znx_mul_xp_minus_one_impl<R, A>(module: &Module<B>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_mul_xp_minus_one_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/7160f588da49712a042931ea247b4259b95cefcc/spqlios/arithmetic/vec_znx.c#L200C13-L200C41) for reference code.
/// * See [crate::hal::api::VecZnxMulXpMinusOneInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxMulXpMinusOneInplaceImpl<B: Backend> {
fn vec_znx_mul_xp_minus_one_inplace_impl<R>(module: &Module<B>, p: i64, res: &mut R, res_col: usize)
where
R: VecZnxToMut;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_split_ref] for reference code.
/// * See [crate::hal::api::VecZnxSplit] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxSplitImpl<B: Backend> {
fn vec_znx_split_impl<R, A>(module: &Module<B>, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_merge_ref] for reference code.
/// * See [crate::hal::api::VecZnxMerge] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxMergeImpl<B: Backend> {
fn vec_znx_merge_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &[A], a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_switch_degree_ref] for reference code.
/// * See [crate::hal::api::VecZnxSwithcDegree] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxSwithcDegreeImpl<B: Backend> {
fn vec_znx_switch_degree_impl<R: VecZnxToMut, A: VecZnxToRef>(
module: &Module<B>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
);
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_copy_ref] for reference code.
/// * See [crate::hal::api::VecZnxCopy] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxCopyImpl<B: Backend> {
fn vec_znx_copy_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::hal::api::VecZnxFillUniform] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxFillUniformImpl<B: Backend> {
fn vec_znx_fill_uniform_impl<R>(module: &Module<B>, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
where
R: VecZnxToMut;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::hal::api::VecZnxFillDistF64] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxFillDistF64Impl<B: Backend> {
fn vec_znx_fill_dist_f64_impl<R, D: Distribution<f64>>(
module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: VecZnxToMut;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::hal::api::VecZnxAddDistF64] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxAddDistF64Impl<B: Backend> {
fn vec_znx_add_dist_f64_impl<R, D: Distribution<f64>>(
module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: VecZnxToMut;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::hal::api::VecZnxFillNormal] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxFillNormalImpl<B: Backend> {
fn vec_znx_fill_normal_impl<R>(
module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: VecZnxToMut;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::hal::api::VecZnxAddNormal] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxAddNormalImpl<B: Backend> {
fn vec_znx_add_normal_impl<R>(
module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: VecZnxToMut;
}
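// Illustrative sketch: drawing a fresh Gaussian error into one column via
// `VecZnxFillNormalImpl`. The helper name is hypothetical, and the
// sigma = 3.2 / 6-sigma bound constants are placeholders, not values mandated
// by this crate.
#[allow(dead_code)]
fn sample_gaussian_column<B, R>(
    module: &Module<B>,
    basek: usize,
    res: &mut R,
    res_col: usize,
    k: usize,
    source: &mut Source,
) where
    B: Backend + VecZnxFillNormalImpl<B>,
    R: VecZnxToMut,
{
    let sigma: f64 = 3.2;
    // `bound` caps the sampled magnitude; exact semantics are backend-defined.
    B::vec_znx_fill_normal_impl(module, basek, res, res_col, k, source, sigma, 6.0 * sigma);
}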

View File

@@ -0,0 +1,306 @@
use rand_distr::Distribution;
use crate::hal::{
layouts::{Backend, Module, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef},
source::Source,
};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAllocImpl<B: Backend> {
fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigFromBytesImpl<B: Backend> {
fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAllocBytesImpl<B: Backend> {
fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAddNormalImpl<B: Backend> {
fn add_normal_impl<R: VecZnxBigToMut<B>>(
module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
);
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigFillNormalImpl<B: Backend> {
fn fill_normal_impl<R: VecZnxBigToMut<B>>(
module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
);
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigFillDistF64Impl<B: Backend> {
fn fill_dist_f64_impl<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
);
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAddDistF64Impl<B: Backend> {
fn add_dist_f64_impl<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
);
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAddImpl<B: Backend> {
fn vec_znx_big_add_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>,
C: VecZnxBigToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAddInplaceImpl<B: Backend> {
fn vec_znx_big_add_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAddSmallImpl<B: Backend> {
fn vec_znx_big_add_small_impl<R, A, C>(
module: &Module<B>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &C,
b_col: usize,
) where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>,
C: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAddSmallInplaceImpl<B: Backend> {
fn vec_znx_big_add_small_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigSubImpl<B: Backend> {
fn vec_znx_big_sub_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>,
C: VecZnxBigToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigSubABInplaceImpl<B: Backend> {
fn vec_znx_big_sub_ab_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigSubBAInplaceImpl<B: Backend> {
fn vec_znx_big_sub_ba_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigSubSmallAImpl<B: Backend> {
fn vec_znx_big_sub_small_a_impl<R, A, C>(
module: &Module<B>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &C,
b_col: usize,
) where
R: VecZnxBigToMut<B>,
A: VecZnxToRef,
C: VecZnxBigToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigSubSmallAInplaceImpl<B: Backend> {
fn vec_znx_big_sub_small_a_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigSubSmallBImpl<B: Backend> {
fn vec_znx_big_sub_small_b_impl<R, A, C>(
module: &Module<B>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &C,
b_col: usize,
) where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>,
C: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigSubSmallBInplaceImpl<B: Backend> {
fn vec_znx_big_sub_small_b_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigNegateInplaceImpl<B: Backend> {
fn vec_znx_big_negate_inplace_impl<A>(module: &Module<B>, a: &mut A, a_col: usize)
where
A: VecZnxBigToMut<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigNormalizeTmpBytesImpl<B: Backend> {
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<B>, n: usize) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigNormalizeImpl<B: Backend> {
fn vec_znx_big_normalize_impl<R, A>(
module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut,
A: VecZnxBigToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAutomorphismImpl<B: Backend> {
fn vec_znx_big_automorphism_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAutomorphismInplaceImpl<B: Backend> {
fn vec_znx_big_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxBigToMut<B>;
}
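// Illustrative sketch: the typical accumulate-then-normalize pattern — add in
// the big (carry-free) domain, then propagate carries back into a base-2^k
// `VecZnx`. The helper name is hypothetical; `scratch` is assumed to hold at
// least `vec_znx_big_normalize_tmp_bytes` bytes.
#[allow(dead_code, clippy::too_many_arguments)]
fn big_add_then_normalize<B, T, A, C, R>(
    module: &Module<B>,
    basek: usize,
    acc: &mut T,
    a: &A,
    b: &C,
    res: &mut R,
    col: usize,
    scratch: &mut Scratch<B>,
) where
    B: Backend + VecZnxBigAddImpl<B> + VecZnxBigNormalizeImpl<B>,
    T: VecZnxBigToMut<B> + VecZnxBigToRef<B>,
    A: VecZnxBigToRef<B>,
    C: VecZnxBigToRef<B>,
    R: VecZnxToMut,
{
    // acc[col] = a[col] + b[col], without base-2^k carry propagation.
    B::vec_znx_big_add_impl(module, acc, col, a, col, b, col);
    // res[col] = normalize(acc[col]) back into balanced base-2^k limbs.
    B::vec_znx_big_normalize_impl(module, basek, res, col, &*acc, col, scratch);
}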

View File

@@ -0,0 +1,177 @@
use crate::hal::layouts::{
Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef,
VecZnxToRef,
};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftAllocImpl<B: Backend> {
fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftFromBytesImpl<B: Backend> {
fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftAllocBytesImpl<B: Backend> {
fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftToVecZnxBigTmpBytesImpl<B: Backend> {
fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<B>, n: usize) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftToVecZnxBigImpl<B: Backend> {
fn vec_znx_dft_to_vec_znx_big_impl<R, A>(
module: &Module<B>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxBigToMut<B>,
A: VecZnxDftToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftToVecZnxBigTmpAImpl<B: Backend> {
fn vec_znx_dft_to_vec_znx_big_tmp_a_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxDftToMut<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftToVecZnxBigConsumeImpl<B: Backend> {
fn vec_znx_dft_to_vec_znx_big_consume_impl<D: Data>(module: &Module<B>, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
where
VecZnxDft<D, B>: VecZnxDftToMut<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftAddImpl<B: Backend> {
fn vec_znx_dft_add_impl<R, A, D>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
D: VecZnxDftToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftAddInplaceImpl<B: Backend> {
fn vec_znx_dft_add_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftSubImpl<B: Backend> {
fn vec_znx_dft_sub_impl<R, A, D>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
D: VecZnxDftToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftSubABInplaceImpl<B: Backend> {
fn vec_znx_dft_sub_ab_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftSubBAInplaceImpl<B: Backend> {
fn vec_znx_dft_sub_ba_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftCopyImpl<B: Backend> {
fn vec_znx_dft_copy_impl<R, A>(
module: &Module<B>,
step: usize,
offset: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
) where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftFromVecZnxImpl<B: Backend> {
fn vec_znx_dft_from_vec_znx_impl<R, A>(
module: &Module<B>,
step: usize,
offset: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
) where
R: VecZnxDftToMut<B>,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftZeroImpl<B: Backend> {
fn vec_znx_dft_zero_impl<R>(module: &Module<B>, res: &mut R)
where
R: VecZnxDftToMut<B>;
}
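// Illustrative sketch: the consuming DFT -> big conversion reuses the DFT
// buffer in place, so it needs no scratch; prefer it when the DFT operand is
// no longer needed. The wrapper name is hypothetical.
#[allow(dead_code)]
fn dft_into_big<B, D>(module: &Module<B>, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
where
    B: Backend + VecZnxDftToVecZnxBigConsumeImpl<B>,
    D: Data,
    VecZnxDft<D, B>: VecZnxDftToMut<B>,
{
    // Ownership of the underlying buffer moves into the returned `VecZnxBig`.
    B::vec_znx_dft_to_vec_znx_big_consume_impl(module, a)
}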

View File

@@ -0,0 +1,121 @@
use crate::hal::layouts::{
Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef,
};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpPMatAllocImpl<B: Backend> {
fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpPMatAllocBytesImpl<B: Backend> {
fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpPMatFromBytesImpl<B: Backend> {
fn vmp_pmat_from_bytes_impl(
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: Vec<u8>,
) -> VmpPMatOwned<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpPrepareTmpBytesImpl<B: Backend> {
fn vmp_prepare_tmp_bytes_impl(
module: &Module<B>,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpPMatPrepareImpl<B: Backend> {
fn vmp_prepare_impl<R, A>(module: &Module<B>, res: &mut R, a: &A, scratch: &mut Scratch<B>)
where
R: VmpPMatToMut<B>,
A: MatZnxToRef;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpApplyTmpBytesImpl<B: Backend> {
fn vmp_apply_tmp_bytes_impl(
module: &Module<B>,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpApplyImpl<B: Backend> {
fn vmp_apply_impl<R, A, C>(module: &Module<B>, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
C: VmpPMatToRef<B>;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpApplyAddTmpBytesImpl<B: Backend> {
fn vmp_apply_add_tmp_bytes_impl(
module: &Module<B>,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpApplyAddImpl<B: Backend> {
/// Same as [MatZnxDftOps::vmp_apply], except the result is added to `res` instead of overwriting it.
fn vmp_apply_add_impl<R, A, C>(module: &Module<B>, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
C: VmpPMatToRef<B>;
}
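// Illustrative sketch: prepare-once / apply-many is the intended flow for
// vector-matrix products — `vmp_prepare` converts an integer matrix into the
// backend layout, after which `vmp_apply` can be called repeatedly. The
// helper name is hypothetical; `scratch` is assumed to cover both
// `vmp_prepare_tmp_bytes` and `vmp_apply_tmp_bytes`.
#[allow(dead_code)]
fn vmp_prepare_then_apply<B, P, M, R, A>(
    module: &Module<B>,
    pmat: &mut P,
    mat: &M,
    res: &mut R,
    a: &A,
    scratch: &mut Scratch<B>,
) where
    B: Backend + VmpPMatPrepareImpl<B> + VmpApplyImpl<B>,
    P: VmpPMatToMut<B> + VmpPMatToRef<B>,
    M: MatZnxToRef,
    R: VecZnxDftToMut<B>,
    A: VecZnxDftToRef<B>,
{
    // Layout conversion: MatZnx -> VmpPMat.
    B::vmp_prepare_impl(module, pmat, mat, scratch);
    // res = a x pmat in the DFT domain.
    B::vmp_apply_impl(module, res, a, &*pmat, scratch);
}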

View File

@@ -0,0 +1,62 @@
use rand_chacha::{ChaCha8Rng, rand_core::SeedableRng};
use rand_core::RngCore;
/// 2^53: all integers in `[0, 2^53]` are exactly representable as `f64`.
const MAXF64: f64 = 9007199254740992.0;
pub struct Source {
source: ChaCha8Rng,
}
impl Source {
pub fn new(seed: [u8; 32]) -> Source {
Source {
source: ChaCha8Rng::from_seed(seed),
}
}
pub fn branch(&mut self) -> ([u8; 32], Self) {
let seed: [u8; 32] = self.new_seed();
(seed, Source::new(seed))
}
pub fn new_seed(&mut self) -> [u8; 32] {
let mut seed: [u8; 32] = [0u8; 32];
self.fill_bytes(&mut seed);
seed
}
/// Returns a uniform integer in `[0, max)` by rejection sampling.
/// `mask` should be `2^t - 1` with `2^t >= max`; the smallest such `t` maximizes acceptance.
#[inline(always)]
pub fn next_u64n(&mut self, max: u64, mask: u64) -> u64 {
let mut x: u64 = self.next_u64() & mask;
while x >= max {
x = self.next_u64() & mask;
}
x
}
/// Returns a uniform `f64` in `[min, max)` with 53 bits of precision.
#[inline(always)]
pub fn next_f64(&mut self, min: f64, max: f64) -> f64 {
// `<< 11 >> 11` keeps the low 53 bits, so the quotient lies in [0, 1).
min + ((self.next_u64() << 11 >> 11) as f64) / MAXF64 * (max - min)
}
/// Returns a uniform `i64` (a uniform `u64` reinterpreted).
pub fn next_i64(&mut self) -> i64 {
self.next_u64() as i64
}
}
impl RngCore for Source {
#[inline(always)]
fn next_u32(&mut self) -> u32 {
self.source.next_u32()
}
#[inline(always)]
fn next_u64(&mut self) -> u64 {
self.source.next_u64()
}
#[inline(always)]
fn fill_bytes(&mut self, bytes: &mut [u8]) {
self.source.fill_bytes(bytes)
}
}
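// Illustrative sketch: `branch` deterministically forks a `Source`, and
// re-seeding from the returned seed replays the child stream. This is how
// independent sub-generators can be derived reproducibly. The test below is a
// hypothetical addition.
#[cfg(test)]
mod branch_test {
    use super::Source;

    #[test]
    fn branch_is_deterministic() {
        let mut root = Source::new([1u8; 32]);
        let (seed, mut child) = root.branch();
        let mut replay = Source::new(seed);
        // Child and replay start from the same seed, so their streams match.
        for _ in 0..8 {
            assert_eq!(child.next_i64(), replay.next_i64());
        }
    }
}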

View File

@@ -0,0 +1,2 @@
pub mod serialization;
pub mod vec_znx;

View File

@@ -0,0 +1,55 @@
use std::fmt::Debug;
use crate::hal::{
api::{FillUniform, Reset},
layouts::{ReaderFrom, WriterTo},
source::Source,
};
/// Generic test for serialization and deserialization.
///
/// - `T` must implement I/O traits, zeroing, cloning, and random filling.
pub fn test_reader_writer_interface<T>(mut original: T)
where
T: WriterTo + ReaderFrom + PartialEq + Eq + Debug + Clone + Reset + FillUniform,
{
// Fill original with uniform random data
let mut source = Source::new([0u8; 32]);
original.fill_uniform(&mut source);
// Serialize into a buffer
let mut buffer = Vec::new();
original.write_to(&mut buffer).expect("write_to failed");
// Prepare receiver: same shape, but zeroed
let mut receiver = original.clone();
receiver.reset();
// Deserialize from buffer
let mut reader: &[u8] = &buffer;
receiver.read_from(&mut reader).expect("read_from failed");
// Ensure serialization round-trip correctness
assert_eq!(
&original, &receiver,
"Deserialized object does not match the original"
);
}
#[test]
fn scalar_znx_serialize() {
let original: crate::hal::layouts::ScalarZnx<Vec<u8>> = crate::hal::layouts::ScalarZnx::alloc(1024, 3);
test_reader_writer_interface(original);
}
#[test]
fn vec_znx_serialize() {
let original: crate::hal::layouts::VecZnx<Vec<u8>> = crate::hal::layouts::VecZnx::alloc(1024, 3, 4);
test_reader_writer_interface(original);
}
#[test]
fn mat_znx_serialize() {
let original: crate::hal::layouts::MatZnx<Vec<u8>> = crate::hal::layouts::MatZnx::alloc(1024, 3, 2, 2, 4);
test_reader_writer_interface(original);
}

View File

@@ -0,0 +1,51 @@
use crate::hal::{
api::{ZnxInfos, ZnxViewMut},
layouts::VecZnx,
source::Source,
};
pub fn test_vec_znx_encode_vec_i64_lo_norm() {
let n: usize = 32;
let basek: usize = 17;
let size: usize = 5;
let k: usize = size * basek - 5;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
let mut source: Source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
(0..a.cols()).for_each(|col_i| {
let mut have: Vec<i64> = vec![i64::default(); n];
have.iter_mut()
.for_each(|x| *x = (source.next_i64() << 56) >> 56);
a.encode_vec_i64(basek, col_i, k, &have, 10);
let mut want: Vec<i64> = vec![i64::default(); n];
a.decode_vec_i64(basek, col_i, k, &mut want);
assert_eq!(have, want, "{:?} != {:?}", &have, &want);
});
}
pub fn test_vec_znx_encode_vec_i64_hi_norm() {
let n: usize = 32;
let basek: usize = 17;
let size: usize = 5;
for k in [1, basek / 2, size * basek - 5] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
let mut source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
(0..a.cols()).for_each(|col_i| {
let mut have: Vec<i64> = vec![i64::default(); n];
have.iter_mut().for_each(|x| {
if k < 64 {
*x = source.next_u64n(1 << k, (1 << k) - 1) as i64;
} else {
*x = source.next_i64();
}
});
a.encode_vec_i64(basek, col_i, k, &have, 63);
let mut want: Vec<i64> = vec![i64::default(); n];
a.decode_vec_i64(basek, col_i, k, &mut want);
assert_eq!(have, want, "{:?} != {:?}", &have, &want);
})
}
}

View File

@@ -0,0 +1,67 @@
use crate::hal::{
api::{VecZnxAddNormal, VecZnxFillUniform, ZnxView},
layouts::{Backend, Module, VecZnx},
source::Source,
};
pub fn test_vec_znx_fill_uniform<B: Backend>(module: &Module<B>)
where
Module<B>: VecZnxFillUniform,
{
let n: usize = module.n();
let basek: usize = 17;
let size: usize = 5;
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let zero: Vec<i64> = vec![0; n];
let one_12_sqrt: f64 = 0.28867513459481287;
(0..cols).for_each(|col_i| {
let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size);
module.vec_znx_fill_uniform(basek, &mut a, col_i, size * basek, &mut source);
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..size).for_each(|limb_i| {
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(basek, col_i);
assert!(
(std - one_12_sqrt).abs() < 0.01,
"std={} ~!= {}",
std,
one_12_sqrt
);
}
})
});
}
pub fn test_vec_znx_add_normal<B: Backend>(module: &Module<B>)
where
Module<B>: VecZnxAddNormal,
{
let n: usize = module.n();
let basek: usize = 17;
let k: usize = 2 * 17;
let size: usize = 5;
let sigma: f64 = 3.2;
let bound: f64 = 6.0 * sigma;
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let zero: Vec<i64> = vec![0; n];
let k_f64: f64 = (1u64 << k as u64) as f64;
(0..cols).for_each(|col_i| {
let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size);
module.vec_znx_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..size).for_each(|limb_i| {
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(basek, col_i) * k_f64;
assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma);
}
})
});
}

View File

@@ -0,0 +1,5 @@
mod generics;
pub use generics::*;
#[cfg(test)]
mod encoding;

View File

@@ -0,0 +1,7 @@
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct cnv_pvec_l_t {
_unused: [u8; 0],
}
pub type CNV_PVEC_L = cnv_pvec_l_t;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct cnv_pvec_r_t {
_unused: [u8; 0],
}
pub type CNV_PVEC_R = cnv_pvec_r_t;

View File

@@ -0,0 +1,8 @@
pub mod module;
pub mod reim;
pub mod svp;
pub mod vec_znx;
pub mod vec_znx_big;
pub mod vec_znx_dft;
pub mod vmp;
pub mod znx;

View File

@@ -0,0 +1,19 @@
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct module_info_t {
_unused: [u8; 0],
}
pub type module_type_t = ::std::os::raw::c_uint;
pub use self::module_type_t as MODULE_TYPE;
#[allow(clippy::upper_case_acronyms)]
pub type MODULE = module_info_t;
unsafe extern "C" {
pub unsafe fn new_module_info(N: u64, mode: MODULE_TYPE) -> *mut MODULE;
}
unsafe extern "C" {
pub unsafe fn delete_module_info(module_info: *mut MODULE);
}
unsafe extern "C" {
pub unsafe fn module_get_n(module: *const MODULE) -> u64;
}

View File

@@ -0,0 +1,172 @@
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_fft_precomp {
_unused: [u8; 0],
}
pub type REIM_FFT_PRECOMP = reim_fft_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_ifft_precomp {
_unused: [u8; 0],
}
pub type REIM_IFFT_PRECOMP = reim_ifft_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_mul_precomp {
_unused: [u8; 0],
}
pub type REIM_FFTVEC_MUL_PRECOMP = reim_mul_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_addmul_precomp {
_unused: [u8; 0],
}
pub type REIM_FFTVEC_ADDMUL_PRECOMP = reim_addmul_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_from_znx32_precomp {
_unused: [u8; 0],
}
pub type REIM_FROM_ZNX32_PRECOMP = reim_from_znx32_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_from_znx64_precomp {
_unused: [u8; 0],
}
pub type REIM_FROM_ZNX64_PRECOMP = reim_from_znx64_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_from_tnx32_precomp {
_unused: [u8; 0],
}
pub type REIM_FROM_TNX32_PRECOMP = reim_from_tnx32_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_to_tnx32_precomp {
_unused: [u8; 0],
}
pub type REIM_TO_TNX32_PRECOMP = reim_to_tnx32_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_to_tnx_precomp {
_unused: [u8; 0],
}
pub type REIM_TO_TNX_PRECOMP = reim_to_tnx_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_to_znx64_precomp {
_unused: [u8; 0],
}
pub type REIM_TO_ZNX64_PRECOMP = reim_to_znx64_precomp;
unsafe extern "C" {
pub unsafe fn new_reim_fft_precomp(m: u32, num_buffers: u32) -> *mut REIM_FFT_PRECOMP;
}
unsafe extern "C" {
pub unsafe fn reim_fft_precomp_get_buffer(tables: *const REIM_FFT_PRECOMP, buffer_index: u32) -> *mut f64;
}
unsafe extern "C" {
pub unsafe fn new_reim_fft_buffer(m: u32) -> *mut f64;
}
unsafe extern "C" {
pub unsafe fn delete_reim_fft_buffer(buffer: *mut f64);
}
unsafe extern "C" {
pub unsafe fn reim_fft(tables: *const REIM_FFT_PRECOMP, data: *mut f64);
}
unsafe extern "C" {
pub unsafe fn new_reim_ifft_precomp(m: u32, num_buffers: u32) -> *mut REIM_IFFT_PRECOMP;
}
unsafe extern "C" {
pub unsafe fn reim_ifft_precomp_get_buffer(tables: *const REIM_IFFT_PRECOMP, buffer_index: u32) -> *mut f64;
}
unsafe extern "C" {
pub unsafe fn reim_ifft(tables: *const REIM_IFFT_PRECOMP, data: *mut f64);
}
unsafe extern "C" {
pub unsafe fn new_reim_fftvec_mul_precomp(m: u32) -> *mut REIM_FFTVEC_MUL_PRECOMP;
}
unsafe extern "C" {
pub unsafe fn reim_fftvec_mul(tables: *const REIM_FFTVEC_MUL_PRECOMP, r: *mut f64, a: *const f64, b: *const f64);
}
unsafe extern "C" {
pub unsafe fn new_reim_fftvec_addmul_precomp(m: u32) -> *mut REIM_FFTVEC_ADDMUL_PRECOMP;
}
unsafe extern "C" {
pub unsafe fn reim_fftvec_addmul(tables: *const REIM_FFTVEC_ADDMUL_PRECOMP, r: *mut f64, a: *const f64, b: *const f64);
}
unsafe extern "C" {
pub unsafe fn new_reim_from_znx32_precomp(m: u32, log2bound: u32) -> *mut REIM_FROM_ZNX32_PRECOMP;
}
unsafe extern "C" {
pub unsafe fn reim_from_znx32(tables: *const REIM_FROM_ZNX32_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i32);
}
unsafe extern "C" {
pub unsafe fn reim_from_znx64(tables: *const REIM_FROM_ZNX64_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i64);
}
unsafe extern "C" {
pub unsafe fn new_reim_from_znx64_precomp(m: u32, maxbnd: u32) -> *mut REIM_FROM_ZNX64_PRECOMP;
}
unsafe extern "C" {
pub unsafe fn reim_from_znx64_simple(m: u32, log2bound: u32, r: *mut ::std::os::raw::c_void, a: *const i64);
}
unsafe extern "C" {
pub unsafe fn new_reim_from_tnx32_precomp(m: u32) -> *mut REIM_FROM_TNX32_PRECOMP;
}
unsafe extern "C" {
pub unsafe fn reim_from_tnx32(tables: *const REIM_FROM_TNX32_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i32);
}
unsafe extern "C" {
pub unsafe fn new_reim_to_tnx32_precomp(m: u32, divisor: f64, log2overhead: u32) -> *mut REIM_TO_TNX32_PRECOMP;
}
unsafe extern "C" {
pub unsafe fn reim_to_tnx32(tables: *const REIM_TO_TNX32_PRECOMP, r: *mut i32, a: *const ::std::os::raw::c_void);
}
unsafe extern "C" {
pub unsafe fn new_reim_to_tnx_precomp(m: u32, divisor: f64, log2overhead: u32) -> *mut REIM_TO_TNX_PRECOMP;
}
unsafe extern "C" {
pub unsafe fn reim_to_tnx(tables: *const REIM_TO_TNX_PRECOMP, r: *mut f64, a: *const f64);
}
unsafe extern "C" {
pub unsafe fn reim_to_tnx_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut f64, a: *const f64);
}
unsafe extern "C" {
pub unsafe fn new_reim_to_znx64_precomp(m: u32, divisor: f64, log2bound: u32) -> *mut REIM_TO_ZNX64_PRECOMP;
}
unsafe extern "C" {
pub unsafe fn reim_to_znx64(precomp: *const REIM_TO_ZNX64_PRECOMP, r: *mut i64, a: *const ::std::os::raw::c_void);
}
unsafe extern "C" {
pub unsafe fn reim_to_znx64_simple(m: u32, divisor: f64, log2bound: u32, r: *mut i64, a: *const ::std::os::raw::c_void);
}
unsafe extern "C" {
pub unsafe fn reim_fft_simple(m: u32, data: *mut ::std::os::raw::c_void);
}
unsafe extern "C" {
pub unsafe fn reim_ifft_simple(m: u32, data: *mut ::std::os::raw::c_void);
}
unsafe extern "C" {
pub unsafe fn reim_fftvec_mul_simple(
m: u32,
r: *mut ::std::os::raw::c_void,
a: *const ::std::os::raw::c_void,
b: *const ::std::os::raw::c_void,
);
}
unsafe extern "C" {
pub unsafe fn reim_fftvec_addmul_simple(
m: u32,
r: *mut ::std::os::raw::c_void,
a: *const ::std::os::raw::c_void,
b: *const ::std::os::raw::c_void,
);
}
unsafe extern "C" {
pub unsafe fn reim_from_znx32_simple(m: u32, log2bound: u32, r: *mut ::std::os::raw::c_void, x: *const i32);
}
unsafe extern "C" {
pub unsafe fn reim_from_tnx32_simple(m: u32, r: *mut ::std::os::raw::c_void, x: *const i32);
}
unsafe extern "C" {
pub unsafe fn reim_to_tnx32_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut i32, x: *const ::std::os::raw::c_void);
}
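// Illustrative sketch: round-tripping a buffer through the "simple" FFT entry
// points. Assumptions to check against the spqlios-arithmetic documentation:
// the reim layout packs the m real parts followed by the m imaginary parts
// (2*m f64 total), and the inverse transform is unnormalized, so the
// round-trip returns the input scaled by m.
#[allow(dead_code)]
fn reim_fft_roundtrip(m: u32, data: &mut [f64]) {
    assert_eq!(data.len(), 2 * m as usize);
    let ptr = data.as_mut_ptr() as *mut ::std::os::raw::c_void;
    unsafe {
        reim_fft_simple(m, ptr);
        reim_ifft_simple(m, ptr);
    }
}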

View File

@@ -0,0 +1,47 @@
use crate::implementation::cpu_spqlios::ffi::{module::MODULE, vec_znx_dft::VEC_ZNX_DFT};
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct svp_ppol_t {
_unused: [u8; 0],
}
pub type SVP_PPOL = svp_ppol_t;
unsafe extern "C" {
pub unsafe fn bytes_of_svp_ppol(module: *const MODULE) -> u64;
}
unsafe extern "C" {
pub unsafe fn new_svp_ppol(module: *const MODULE) -> *mut SVP_PPOL;
}
unsafe extern "C" {
pub unsafe fn delete_svp_ppol(res: *mut SVP_PPOL);
}
unsafe extern "C" {
pub unsafe fn svp_prepare(module: *const MODULE, ppol: *mut SVP_PPOL, pol: *const i64);
}
unsafe extern "C" {
pub unsafe fn svp_apply_dft(
module: *const MODULE,
res: *const VEC_ZNX_DFT,
res_size: u64,
ppol: *const SVP_PPOL,
a: *const i64,
a_size: u64,
a_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn svp_apply_dft_to_dft(
module: *const MODULE,
res: *const VEC_ZNX_DFT,
res_size: u64,
res_cols: u64,
ppol: *const SVP_PPOL,
a: *const VEC_ZNX_DFT,
a_size: u64,
a_cols: u64,
);
}

View File

@@ -0,0 +1,115 @@
use crate::implementation::cpu_spqlios::ffi::module::MODULE;
unsafe extern "C" {
pub unsafe fn vec_znx_add(
module: *const MODULE,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
b: *const i64,
b_size: u64,
b_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_automorphism(
module: *const MODULE,
p: i64,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_mul_xp_minus_one(
module: *const MODULE,
p: i64,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_negate(
module: *const MODULE,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_rotate(
module: *const MODULE,
p: i64,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_sub(
module: *const MODULE,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
b: *const i64,
b_size: u64,
b_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_zero(module: *const MODULE, res: *mut i64, res_size: u64, res_sl: u64);
}
unsafe extern "C" {
pub unsafe fn vec_znx_copy(
module: *const MODULE,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_normalize_base2k(
module: *const MODULE,
n: u64,
base2k: u64,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64;
}
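#[cfg(test)]
mod vec_znx_smoke {
    use super::*;
    use crate::implementation::cpu_spqlios::ffi::module::{delete_module_info, new_module_info};

    // A hedged stride sketch: a vec_znx of `size` limbs of degree n is assumed
    // to live in one flat i64 buffer, with `sl` the i64 distance between
    // consecutive limbs (here sl == n, a single densely packed column), and
    // vec_znx_add is assumed to be a limb-wise, carry-free i64 addition. The
    // module-type constant 0 is assumed to select FFT64.
    #[test]
    fn add_with_dense_layout() {
        let (n, size): (u64, u64) = (64, 2);
        let len: usize = (n * size) as usize;
        unsafe {
            let module = new_module_info(n, 0);
            let a = vec![1i64; len];
            let b = vec![2i64; len];
            let mut r = vec![0i64; len];
            vec_znx_add(
                module, r.as_mut_ptr(), size, n, a.as_ptr(), size, n, b.as_ptr(), size, n,
            );
            assert!(r.iter().all(|&x| x == 3));
            delete_module_info(module);
        }
    }
}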

View File

@@ -0,0 +1,163 @@
use crate::implementation::cpu_spqlios::ffi::module::MODULE;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct vec_znx_big_t {
_unused: [u8; 0],
}
pub type VEC_ZNX_BIG = vec_znx_big_t;
unsafe extern "C" {
pub unsafe fn bytes_of_vec_znx_big(module: *const MODULE, size: u64) -> u64;
}
unsafe extern "C" {
pub unsafe fn new_vec_znx_big(module: *const MODULE, size: u64) -> *mut VEC_ZNX_BIG;
}
unsafe extern "C" {
pub unsafe fn delete_vec_znx_big(res: *mut VEC_ZNX_BIG);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_add(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
b: *const VEC_ZNX_BIG,
b_size: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_add_small(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
b: *const i64,
b_size: u64,
b_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_add_small2(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
b: *const i64,
b_size: u64,
b_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_sub(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
b: *const VEC_ZNX_BIG,
b_size: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_sub_small_b(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
b: *const i64,
b_size: u64,
b_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_sub_small_a(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
b: *const VEC_ZNX_BIG,
b_size: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_sub_small2(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
b: *const i64,
b_size: u64,
b_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64;
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_normalize_base2k(
module: *const MODULE,
n: u64,
log2_base2k: u64,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_range_normalize_base2k(
module: *const MODULE,
n: u64,
log2_base2k: u64,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const VEC_ZNX_BIG,
a_range_begin: u64,
a_range_xend: u64,
a_range_step: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64;
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_automorphism(
module: *const MODULE,
p: i64,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_rotate(
module: *const MODULE,
p: i64,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
);
}

View File

@@ -0,0 +1,69 @@
use crate::implementation::cpu_spqlios::ffi::{module::MODULE, vec_znx_big::VEC_ZNX_BIG};
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct vec_znx_dft_t {
_unused: [u8; 0],
}
pub type VEC_ZNX_DFT = vec_znx_dft_t;
unsafe extern "C" {
pub unsafe fn bytes_of_vec_znx_dft(module: *const MODULE, size: u64) -> u64;
}
unsafe extern "C" {
pub unsafe fn new_vec_znx_dft(module: *const MODULE, size: u64) -> *mut VEC_ZNX_DFT;
}
unsafe extern "C" {
pub unsafe fn delete_vec_znx_dft(res: *mut VEC_ZNX_DFT);
}
unsafe extern "C" {
pub unsafe fn vec_dft_zero(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64);
}
unsafe extern "C" {
pub unsafe fn vec_dft_add(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64,
a: *const VEC_ZNX_DFT,
a_size: u64,
b: *const VEC_ZNX_DFT,
b_size: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_dft_sub(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64,
a: *const VEC_ZNX_DFT,
a_size: u64,
b: *const VEC_ZNX_DFT,
b_size: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_dft(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64, a: *const i64, a_size: u64, a_sl: u64);
}
unsafe extern "C" {
pub unsafe fn vec_znx_idft(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a_dft: *const VEC_ZNX_DFT,
a_size: u64,
tmp: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_idft_tmp_bytes(module: *const MODULE, n: u64) -> u64;
}
unsafe extern "C" {
pub unsafe fn vec_znx_idft_tmp_a(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a_dft: *mut VEC_ZNX_DFT,
a_size: u64,
);
}
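#[cfg(test)]
mod dft_roundtrip {
    use super::*;
    use crate::implementation::cpu_spqlios::ffi::module::{delete_module_info, new_module_info};
    use crate::implementation::cpu_spqlios::ffi::vec_znx_big::{delete_vec_znx_big, new_vec_znx_big};

    // A hedged pipeline sketch of the coefficient -> DFT -> big-coefficient
    // path: vec_znx_dft lifts i64 limbs into an opaque VEC_ZNX_DFT and
    // vec_znx_idft_tmp_a folds them back into a VEC_ZNX_BIG, destroying
    // `a_dft` in the process (hence its *mut pointer). The module-type
    // constant 0 is assumed to select FFT64; no numerical check is made here.
    #[test]
    fn dft_then_idft() {
        let (n, size): (u64, u64) = (64, 1);
        unsafe {
            let module = new_module_info(n, 0);
            let a = vec![1i64; n as usize];
            let dft = new_vec_znx_dft(module, size);
            let big = new_vec_znx_big(module, size);
            vec_znx_dft(module, dft, size, a.as_ptr(), size, n);
            vec_znx_idft_tmp_a(module, big, size, dft, size);
            delete_vec_znx_big(big);
            delete_vec_znx_dft(dft);
            delete_module_info(module);
        }
    }
}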

View File

@@ -0,0 +1,114 @@
use crate::implementation::cpu_spqlios::ffi::{module::MODULE, vec_znx_dft::VEC_ZNX_DFT};
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct vmp_pmat_t {
_unused: [u8; 0],
}
// [rows][cols] = [#Decomposition][#Limbs]
pub type VMP_PMAT = vmp_pmat_t;
unsafe extern "C" {
pub unsafe fn bytes_of_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
}
unsafe extern "C" {
pub unsafe fn new_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> *mut VMP_PMAT;
}
unsafe extern "C" {
pub unsafe fn delete_vmp_pmat(res: *mut VMP_PMAT);
}
unsafe extern "C" {
pub unsafe fn vmp_apply_dft(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
pmat: *const VMP_PMAT,
nrows: u64,
ncols: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vmp_apply_dft_add(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
pmat: *const VMP_PMAT,
nrows: u64,
ncols: u64,
pmat_scale: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vmp_apply_dft_tmp_bytes(module: *const MODULE, res_size: u64, a_size: u64, nrows: u64, ncols: u64) -> u64;
}
unsafe extern "C" {
pub unsafe fn vmp_apply_dft_to_dft(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64,
a_dft: *const VEC_ZNX_DFT,
a_size: u64,
pmat: *const VMP_PMAT,
nrows: u64,
ncols: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vmp_apply_dft_to_dft_add(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64,
a_dft: *const VEC_ZNX_DFT,
a_size: u64,
pmat: *const VMP_PMAT,
nrows: u64,
ncols: u64,
pmat_scale: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vmp_apply_dft_to_dft_tmp_bytes(
module: *const MODULE,
nn: u64,
res_size: u64,
a_size: u64,
nrows: u64,
ncols: u64,
) -> u64;
}
unsafe extern "C" {
pub unsafe fn vmp_prepare_contiguous(
module: *const MODULE,
pmat: *mut VMP_PMAT,
mat: *const i64,
nrows: u64,
ncols: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vmp_prepare_contiguous_dft(module: *const MODULE, pmat: *mut VMP_PMAT, mat: *const f64, nrows: u64, ncols: u64);
}
unsafe extern "C" {
pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nn: u64, nrows: u64, ncols: u64) -> u64;
}
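#[cfg(test)]
mod vmp_prepare_smoke {
    use super::*;
    use crate::implementation::cpu_spqlios::ffi::module::{delete_module_info, new_module_info};

    // A hedged preparation sketch: `mat` is assumed to be nrows x ncols limbs
    // of degree-n i64 coefficients stored contiguously, vmp_prepare_tmp_bytes
    // is assumed to return the scratch requirement of vmp_prepare_contiguous,
    // and any alignment requirement on tmp_space (cf. DEFAULTALIGN elsewhere
    // in the crate) is glossed over. The apply path is omitted for brevity.
    #[test]
    fn prepare_contiguous_smoke() {
        let (n, nrows, ncols): (u64, u64, u64) = (64, 2, 2);
        unsafe {
            let module = new_module_info(n, 0);
            let pmat = new_vmp_pmat(module, nrows, ncols);
            let mat = vec![0i64; (n * nrows * ncols) as usize];
            let mut tmp = vec![0u8; vmp_prepare_tmp_bytes(module, n, nrows, ncols) as usize];
            vmp_prepare_contiguous(module, pmat, mat.as_ptr(), nrows, ncols, tmp.as_mut_ptr());
            delete_vmp_pmat(pmat);
            delete_module_info(module);
        }
    }
}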

View File

@@ -0,0 +1,79 @@
use crate::implementation::cpu_spqlios::ffi::module::MODULE;
unsafe extern "C" {
pub unsafe fn znx_add_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_add_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_sub_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_sub_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_negate_i64_ref(nn: u64, res: *mut i64, a: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_negate_i64_avx(nn: u64, res: *mut i64, a: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_copy_i64_ref(nn: u64, res: *mut i64, a: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_zero_i64_ref(nn: u64, res: *mut i64);
}
unsafe extern "C" {
pub unsafe fn rnx_divide_by_m_ref(nn: u64, m: f64, res: *mut f64, a: *const f64);
}
unsafe extern "C" {
pub unsafe fn rnx_divide_by_m_avx(nn: u64, m: f64, res: *mut f64, a: *const f64);
}
unsafe extern "C" {
pub unsafe fn rnx_rotate_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64);
}
unsafe extern "C" {
pub unsafe fn znx_rotate_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
}
unsafe extern "C" {
pub unsafe fn rnx_rotate_inplace_f64(nn: u64, p: i64, res: *mut f64);
}
unsafe extern "C" {
pub unsafe fn znx_rotate_inplace_i64(nn: u64, p: i64, res: *mut i64);
}
unsafe extern "C" {
pub unsafe fn rnx_automorphism_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64);
}
unsafe extern "C" {
pub unsafe fn znx_automorphism_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
}
unsafe extern "C" {
pub unsafe fn rnx_automorphism_inplace_f64(nn: u64, p: i64, res: *mut f64);
}
unsafe extern "C" {
pub unsafe fn znx_automorphism_inplace_i64(nn: u64, p: i64, res: *mut i64);
}
unsafe extern "C" {
pub unsafe fn rnx_mul_xp_minus_one_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64);
}
unsafe extern "C" {
pub unsafe fn znx_mul_xp_minus_one_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
}
unsafe extern "C" {
pub unsafe fn rnx_mul_xp_minus_one_inplace_f64(nn: u64, p: i64, res: *mut f64);
}
unsafe extern "C" {
pub unsafe fn znx_mul_xp_minus_one_inplace_i64(nn: u64, p: i64, res: *mut i64);
}
unsafe extern "C" {
pub unsafe fn znx_normalize(nn: u64, base_k: u64, out: *mut i64, carry_out: *mut i64, in_: *const i64, carry_in: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_small_single_product(module: *const MODULE, res: *mut i64, a: *const i64, b: *const i64, tmp: *mut u8);
}
unsafe extern "C" {
pub unsafe fn znx_small_single_product_tmp_bytes(module: *const MODULE) -> u64;
}
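// For orientation: each `_ref`/`_avx` pair above is assumed to implement the
// same contract, with dispatch left to the caller. A pure-Rust model of the
// assumed contract of znx_add_i64_ref (coefficient-wise wrapping addition of
// two degree-nn polynomials); an illustration, not part of the bindings:
#[allow(dead_code)]
fn znx_add_i64_model(res: &mut [i64], a: &[i64], b: &[i64]) {
    for ((r, &x), &y) in res.iter_mut().zip(a).zip(b) {
        *r = x.wrapping_add(y);
    }
}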

View File

@@ -0,0 +1,27 @@
mod ffi;
mod module_fft64;
mod module_ntt120;
mod scratch;
mod svp_ppol_fft64;
mod svp_ppol_ntt120;
mod vec_znx;
mod vec_znx_big_fft64;
mod vec_znx_big_ntt120;
mod vec_znx_dft_fft64;
mod vec_znx_dft_ntt120;
mod vmp_pmat_fft64;
mod vmp_pmat_ntt120;
#[cfg(test)]
mod test;
pub use module_fft64::*;
pub use module_ntt120::*;
/// Reference implementations re-exported for external documentation.
pub use vec_znx::{
vec_znx_copy_ref, vec_znx_lsh_inplace_ref, vec_znx_merge_ref, vec_znx_rsh_inplace_ref, vec_znx_split_ref,
vec_znx_switch_degree_ref,
};
pub trait CPUAVX {}

View File

@@ -0,0 +1,29 @@
use std::ptr::NonNull;
use crate::{
hal::{
layouts::{Backend, Module},
oep::ModuleNewImpl,
},
implementation::cpu_spqlios::{
CPUAVX,
ffi::module::{MODULE, delete_module_info, new_module_info},
},
};
pub struct FFT64;
impl CPUAVX for FFT64 {}
impl Backend for FFT64 {
type Handle = MODULE;
unsafe fn destroy(handle: NonNull<Self::Handle>) {
unsafe { delete_module_info(handle.as_ptr()) }
}
}
unsafe impl ModuleNewImpl<Self> for FFT64 {
fn new_impl(n: u64) -> Module<Self> {
unsafe { Module::from_raw_parts(new_module_info(n, 0), n) }
}
}

View File

@@ -0,0 +1,29 @@
use std::ptr::NonNull;
use crate::{
hal::{
layouts::{Backend, Module},
oep::ModuleNewImpl,
},
implementation::cpu_spqlios::{
CPUAVX,
ffi::module::{MODULE, delete_module_info, new_module_info},
},
};
pub struct NTT120;
impl CPUAVX for NTT120 {}
impl Backend for NTT120 {
type Handle = MODULE;
unsafe fn destroy(handle: NonNull<Self::Handle>) {
unsafe { delete_module_info(handle.as_ptr()) }
}
}
unsafe impl ModuleNewImpl<Self> for NTT120 {
fn new_impl(n: u64) -> Module<Self> {
unsafe { Module::from_raw_parts(new_module_info(n, 1), n) }
}
}

View File

@@ -0,0 +1,271 @@
use std::marker::PhantomData;
use crate::{
DEFAULTALIGN, alloc_aligned,
hal::{
api::ScratchFromBytes,
layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
oep::{
ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, SvpPPolAllocBytesImpl,
TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl,
TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, VecZnxBigAllocBytesImpl,
VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl,
},
},
implementation::cpu_spqlios::CPUAVX,
};
unsafe impl<B: Backend> ScratchOwnedAllocImpl<B> for B
where
B: CPUAVX,
{
fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned<B> {
let data: Vec<u8> = alloc_aligned(size);
ScratchOwned {
data,
_phantom: PhantomData,
}
}
}
unsafe impl<B: Backend> ScratchOwnedBorrowImpl<B> for B
where
B: CPUAVX,
{
fn scratch_owned_borrow_impl(scratch: &mut ScratchOwned<B>) -> &mut Scratch<B> {
Scratch::from_bytes(&mut scratch.data)
}
}
unsafe impl<B: Backend> ScratchFromBytesImpl<B> for B
where
B: CPUAVX,
{
fn scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch<B> {
unsafe { &mut *(data as *mut [u8] as *mut Scratch<B>) }
}
}
unsafe impl<B: Backend> ScratchAvailableImpl<B> for B
where
B: CPUAVX,
{
fn scratch_available_impl(scratch: &Scratch<B>) -> usize {
let ptr: *const u8 = scratch.data.as_ptr();
let self_len: usize = scratch.data.len();
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
self_len.saturating_sub(aligned_offset)
}
}
unsafe impl<B: Backend> TakeSliceImpl<B> for B
where
B: CPUAVX,
{
fn take_slice_impl<T>(scratch: &mut Scratch<B>, len: usize) -> (&mut [T], &mut Scratch<B>) {
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, len * std::mem::size_of::<T>());
unsafe {
(
&mut *(std::ptr::slice_from_raw_parts_mut(take_slice.as_mut_ptr() as *mut T, len)),
Scratch::from_bytes(rem_slice),
)
}
}
}
unsafe impl<B: Backend> TakeScalarZnxImpl<B> for B
where
B: CPUAVX,
{
fn take_scalar_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch<B>) {
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, ScalarZnx::alloc_bytes(n, cols));
(
ScalarZnx::from_data(take_slice, n, cols),
Scratch::from_bytes(rem_slice),
)
}
}
unsafe impl<B: Backend> TakeSvpPPolImpl<B> for B
where
B: CPUAVX + SvpPPolAllocBytesImpl<B>,
{
fn take_svp_ppol_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch<B>) {
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::svp_ppol_alloc_bytes_impl(n, cols));
(
SvpPPol::from_data(take_slice, n, cols),
Scratch::from_bytes(rem_slice),
)
}
}
unsafe impl<B: Backend> TakeVecZnxImpl<B> for B
where
B: CPUAVX,
{
fn take_vec_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch<B>) {
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, VecZnx::alloc_bytes(n, cols, size));
(
VecZnx::from_data(take_slice, n, cols, size),
Scratch::from_bytes(rem_slice),
)
}
}
unsafe impl<B: Backend> TakeVecZnxBigImpl<B> for B
where
B: CPUAVX + VecZnxBigAllocBytesImpl<B>,
{
fn take_vec_znx_big_impl(
scratch: &mut Scratch<B>,
n: usize,
cols: usize,
size: usize,
) -> (VecZnxBig<&mut [u8], B>, &mut Scratch<B>) {
let (take_slice, rem_slice) = take_slice_aligned(
&mut scratch.data,
B::vec_znx_big_alloc_bytes_impl(n, cols, size),
);
(
VecZnxBig::from_data(take_slice, n, cols, size),
Scratch::from_bytes(rem_slice),
)
}
}
unsafe impl<B: Backend> TakeVecZnxDftImpl<B> for B
where
B: CPUAVX + VecZnxDftAllocBytesImpl<B>,
{
fn take_vec_znx_dft_impl(
scratch: &mut Scratch<B>,
n: usize,
cols: usize,
size: usize,
) -> (VecZnxDft<&mut [u8], B>, &mut Scratch<B>) {
let (take_slice, rem_slice) = take_slice_aligned(
&mut scratch.data,
B::vec_znx_dft_alloc_bytes_impl(n, cols, size),
);
(
VecZnxDft::from_data(take_slice, n, cols, size),
Scratch::from_bytes(rem_slice),
)
}
}
unsafe impl<B: Backend> TakeVecZnxDftSliceImpl<B> for B
where
B: CPUAVX + VecZnxDftAllocBytesImpl<B>,
{
fn take_vec_znx_dft_slice_impl(
scratch: &mut Scratch<B>,
len: usize,
n: usize,
cols: usize,
size: usize,
) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Scratch<B>) {
let mut scratch: &mut Scratch<B> = scratch;
let mut slice: Vec<VecZnxDft<&mut [u8], B>> = Vec::with_capacity(len);
for _ in 0..len {
let (znx, new_scratch) = B::take_vec_znx_dft_impl(scratch, n, cols, size);
scratch = new_scratch;
slice.push(znx);
}
(slice, scratch)
}
}
unsafe impl<B: Backend> TakeVecZnxSliceImpl<B> for B
where
B: CPUAVX,
{
fn take_vec_znx_slice_impl(
scratch: &mut Scratch<B>,
len: usize,
n: usize,
cols: usize,
size: usize,
) -> (Vec<VecZnx<&mut [u8]>>, &mut Scratch<B>) {
let mut scratch: &mut Scratch<B> = scratch;
let mut slice: Vec<VecZnx<&mut [u8]>> = Vec::with_capacity(len);
for _ in 0..len {
let (znx, new_scratch) = B::take_vec_znx_impl(scratch, n, cols, size);
scratch = new_scratch;
slice.push(znx);
}
(slice, scratch)
}
}
unsafe impl<B: Backend> TakeVmpPMatImpl<B> for B
where
B: CPUAVX + VmpPMatAllocBytesImpl<B>,
{
fn take_vmp_pmat_impl(
scratch: &mut Scratch<B>,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
) -> (VmpPMat<&mut [u8], B>, &mut Scratch<B>) {
let (take_slice, rem_slice) = take_slice_aligned(
&mut scratch.data,
B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size),
);
(
VmpPMat::from_data(take_slice, n, rows, cols_in, cols_out, size),
Scratch::from_bytes(rem_slice),
)
}
}
unsafe impl<B: Backend> TakeMatZnxImpl<B> for B
where
B: CPUAVX,
{
fn take_mat_znx_impl(
scratch: &mut Scratch<B>,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
) -> (MatZnx<&mut [u8]>, &mut Scratch<B>) {
let (take_slice, rem_slice) = take_slice_aligned(
&mut scratch.data,
MatZnx::alloc_bytes(n, rows, cols_in, cols_out, size),
);
(
MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size),
Scratch::from_bytes(rem_slice),
)
}
}
fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) {
let ptr: *mut u8 = data.as_mut_ptr();
let self_len: usize = data.len();
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
let aligned_len: usize = self_len.saturating_sub(aligned_offset);
if let Some(rem_len) = aligned_len.checked_sub(take_len) {
unsafe {
let rem_ptr: *mut u8 = ptr.add(aligned_offset).add(take_len);
let rem_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len);
let take_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len);
(take_slice, rem_slice)
}
} else {
        panic!(
            "Attempted to take {} bytes from scratch with {} aligned bytes left",
            take_len, aligned_len,
        );
}
}
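#[cfg(test)]
mod scratch_usage {
    use crate::implementation::cpu_spqlios::FFT64;

    // A hedged usage sketch: it assumes the hal api re-exposes the *Impl
    // methods above as ScratchOwned::alloc, ScratchOwned::borrow and
    // Scratch::take_slice (the exact public surface lives in hal::api and is
    // not part of this file). Each take consumes an aligned prefix and returns
    // the remainder, so requests are served linearly until the buffer is spent.
    #[test]
    fn take_two_slices() {
        use crate::hal::{
            api::TakeSlice,
            layouts::{Scratch, ScratchOwned},
        };
        let mut owned: ScratchOwned<FFT64> = ScratchOwned::alloc(1 << 10);
        let scratch: &mut Scratch<FFT64> = owned.borrow();
        let (a, rest): (&mut [i64], &mut Scratch<FFT64>) = scratch.take_slice(32);
        let (b, _): (&mut [i64], &mut Scratch<FFT64>) = rest.take_slice(32);
        assert_eq!(a.len() + b.len(), 64);
    }
}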

View File

@@ -0,0 +1,114 @@
use crate::{
hal::{
api::{ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut},
layouts::{
Data, DataRef, Module, ScalarZnxToRef, SvpPPol, SvpPPolBytesOf, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDft,
VecZnxDftToMut, VecZnxDftToRef,
},
oep::{SvpApplyImpl, SvpApplyInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl},
},
implementation::cpu_spqlios::{
ffi::{svp, vec_znx_dft::vec_znx_dft_t},
module_fft64::FFT64,
},
};
const SVP_PPOL_FFT64_WORD_SIZE: usize = 1;
impl<D: Data> SvpPPolBytesOf for SvpPPol<D, FFT64> {
fn bytes_of(n: usize, cols: usize) -> usize {
SVP_PPOL_FFT64_WORD_SIZE * n * cols * size_of::<f64>()
}
}
impl<D: Data> ZnxSliceSize for SvpPPol<D, FFT64> {
fn sl(&self) -> usize {
SVP_PPOL_FFT64_WORD_SIZE * self.n()
}
}
impl<D: DataRef> ZnxView for SvpPPol<D, FFT64> {
type Scalar = f64;
}
unsafe impl SvpPPolFromBytesImpl<Self> for FFT64 {
fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<Self> {
SvpPPolOwned::from_bytes(n, cols, bytes)
}
}
unsafe impl SvpPPolAllocImpl<Self> for FFT64 {
fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned<Self> {
SvpPPolOwned::alloc(n, cols)
}
}
unsafe impl SvpPPolAllocBytesImpl<Self> for FFT64 {
fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize {
SvpPPol::<Vec<u8>, Self>::bytes_of(n, cols)
}
}
unsafe impl SvpPrepareImpl<Self> for FFT64 {
fn svp_prepare_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: SvpPPolToMut<Self>,
A: ScalarZnxToRef,
{
unsafe {
svp::svp_prepare(
module.ptr(),
res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t,
a.to_ref().at_ptr(a_col, 0),
)
}
}
}
unsafe impl SvpApplyImpl<Self> for FFT64 {
fn svp_apply_impl<R, A, B>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxDftToMut<Self>,
A: SvpPPolToRef<Self>,
B: VecZnxDftToRef<Self>,
{
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
let a: SvpPPol<&[u8], Self> = a.to_ref();
let b: VecZnxDft<&[u8], Self> = b.to_ref();
unsafe {
svp::svp_apply_dft_to_dft(
module.ptr(),
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
res.size() as u64,
res.cols() as u64,
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
b.at_ptr(b_col, 0) as *const vec_znx_dft_t,
b.size() as u64,
b.cols() as u64,
)
}
}
}
unsafe impl SvpApplyInplaceImpl for FFT64 {
fn svp_apply_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<Self>,
A: SvpPPolToRef<Self>,
{
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
let a: SvpPPol<&[u8], Self> = a.to_ref();
unsafe {
svp::svp_apply_dft_to_dft(
module.ptr(),
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
res.size() as u64,
res.cols() as u64,
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
res.at_ptr(res_col, 0) as *const vec_znx_dft_t,
res.size() as u64,
res.cols() as u64,
)
}
}
}

View File

@@ -0,0 +1,44 @@
use crate::{
hal::{
api::{ZnxInfos, ZnxSliceSize, ZnxView},
layouts::{Data, DataRef, SvpPPol, SvpPPolBytesOf, SvpPPolOwned},
oep::{SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl},
},
implementation::cpu_spqlios::module_ntt120::NTT120,
};
const SVP_PPOL_NTT120_WORD_SIZE: usize = 4;
impl<D: Data> SvpPPolBytesOf for SvpPPol<D, NTT120> {
fn bytes_of(n: usize, cols: usize) -> usize {
SVP_PPOL_NTT120_WORD_SIZE * n * cols * size_of::<i64>()
}
}
impl<D: Data> ZnxSliceSize for SvpPPol<D, NTT120> {
fn sl(&self) -> usize {
SVP_PPOL_NTT120_WORD_SIZE * self.n()
}
}
impl<D: DataRef> ZnxView for SvpPPol<D, NTT120> {
type Scalar = i64;
}
unsafe impl SvpPPolFromBytesImpl<Self> for NTT120 {
fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<NTT120> {
SvpPPolOwned::from_bytes(n, cols, bytes)
}
}
unsafe impl SvpPPolAllocImpl<Self> for NTT120 {
fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned<NTT120> {
SvpPPolOwned::alloc(n, cols)
}
}
unsafe impl SvpPPolAllocBytesImpl<Self> for NTT120 {
fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize {
SvpPPol::<Vec<u8>, Self>::bytes_of(n, cols)
}
}

View File

@@ -0,0 +1 @@
mod vec_znx_fft64;

View File

@@ -0,0 +1,20 @@
use crate::{
hal::{
api::ModuleNew,
layouts::Module,
tests::vec_znx::{test_vec_znx_add_normal, test_vec_znx_fill_uniform},
},
implementation::cpu_spqlios::FFT64,
};
#[test]
fn test_vec_znx_fill_uniform_fft64() {
let module: Module<FFT64> = Module::<FFT64>::new(1 << 12);
test_vec_znx_fill_uniform(&module);
}
#[test]
fn test_vec_znx_add_normal_fft64() {
let module: Module<FFT64> = Module::<FFT64>::new(1 << 12);
test_vec_znx_add_normal(&module);
}

View File

@@ -0,0 +1,929 @@
use itertools::izip;
use rand_distr::Normal;
use crate::{
hal::{
api::{
TakeSlice, TakeVecZnx, VecZnxAddDistF64, VecZnxCopy, VecZnxFillDistF64, VecZnxNormalizeTmpBytes, VecZnxRotate,
VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
},
layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef},
oep::{
VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl,
VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl,
VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl,
VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
},
source::Source,
},
implementation::cpu_spqlios::{
CPUAVX,
ffi::{module::module_info_t, vec_znx, znx},
},
};
unsafe impl<B: Backend> VecZnxNormalizeTmpBytesImpl<B> for B
where
B: CPUAVX,
{
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<B>, n: usize) -> usize {
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr() as *const module_info_t, n as u64) as usize }
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxNormalizeImpl<B> for B {
fn vec_znx_normalize_impl<R, A>(
module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), a.n());
}
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes(a.n()));
unsafe {
vec_znx::vec_znx_normalize_base2k(
module.ptr() as *const module_info_t,
a.n() as u64,
basek as u64,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxNormalizeInplaceImpl<B> for B {
fn vec_znx_normalize_inplace_impl<A>(module: &Module<B>, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = a.to_mut();
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes(a.n()));
unsafe {
vec_znx::vec_znx_normalize_base2k(
module.ptr() as *const module_info_t,
a.n() as u64,
basek as u64,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxAddImpl<B> for B {
fn vec_znx_add_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
C: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_add(
module.ptr() as *const module_info_t,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxAddInplaceImpl<B> for B {
fn vec_znx_add_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_add(
module.ptr() as *const module_info_t,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
)
}
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxAddScalarInplaceImpl<B> for B {
fn vec_znx_add_scalar_inplace_impl<R, A>(
module: &Module<B>,
res: &mut R,
res_col: usize,
res_limb: usize,
a: &A,
a_col: usize,
) where
R: VecZnxToMut,
A: ScalarZnxToRef,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: ScalarZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_add(
module.ptr() as *const module_info_t,
res.at_mut_ptr(res_col, res_limb),
1_u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, res_limb),
1_u64,
res.sl() as u64,
)
}
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxSubImpl<B> for B {
fn vec_znx_sub_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
C: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_sub(
module.ptr() as *const module_info_t,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxSubABInplaceImpl<B> for B {
fn vec_znx_sub_ab_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_sub(
module.ptr() as *const module_info_t,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxSubBAInplaceImpl<B> for B {
fn vec_znx_sub_ba_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_sub(
module.ptr() as *const module_info_t,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
)
}
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxSubScalarInplaceImpl<B> for B {
fn vec_znx_sub_scalar_inplace_impl<R, A>(
module: &Module<B>,
res: &mut R,
res_col: usize,
res_limb: usize,
a: &A,
a_col: usize,
) where
R: VecZnxToMut,
A: ScalarZnxToRef,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: ScalarZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_sub(
module.ptr() as *const module_info_t,
res.at_mut_ptr(res_col, res_limb),
1_u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, res_limb),
1_u64,
res.sl() as u64,
)
}
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxNegateImpl<B> for B {
fn vec_znx_negate_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_negate(
module.ptr() as *const module_info_t,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxNegateInplaceImpl<B> for B {
fn vec_znx_negate_inplace_impl<A>(module: &Module<B>, a: &mut A, a_col: usize)
where
A: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = a.to_mut();
unsafe {
vec_znx::vec_znx_negate(
module.ptr() as *const module_info_t,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxLshInplaceImpl<B> for B {
fn vec_znx_lsh_inplace_impl<A>(_module: &Module<B>, basek: usize, k: usize, a: &mut A)
where
A: VecZnxToMut,
{
vec_znx_lsh_inplace_ref(basek, k, a)
}
}
pub fn vec_znx_lsh_inplace_ref<A>(basek: usize, k: usize, a: &mut A)
where
A: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = a.to_mut();
let n: usize = a.n();
let cols: usize = a.cols();
let size: usize = a.size();
let steps: usize = k / basek;
a.raw_mut().rotate_left(n * steps * cols);
(0..cols).for_each(|i| {
(size - steps..size).for_each(|j| {
a.zero_at(i, j);
})
});
let k_rem: usize = k % basek;
if k_rem != 0 {
let shift: usize = i64::BITS as usize - k_rem;
(0..cols).for_each(|i| {
(0..steps).for_each(|j| {
a.at_mut(i, j).iter_mut().for_each(|xi| {
*xi <<= shift;
});
});
});
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxRshInplaceImpl<B> for B {
fn vec_znx_rsh_inplace_impl<A>(_module: &Module<B>, basek: usize, k: usize, a: &mut A)
where
A: VecZnxToMut,
{
vec_znx_rsh_inplace_ref(basek, k, a)
}
}
pub fn vec_znx_rsh_inplace_ref<A>(basek: usize, k: usize, a: &mut A)
where
A: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = a.to_mut();
let n: usize = a.n();
let cols: usize = a.cols();
let size: usize = a.size();
let steps: usize = k / basek;
a.raw_mut().rotate_right(n * steps * cols);
(0..cols).for_each(|i| {
(0..steps).for_each(|j| {
a.zero_at(i, j);
})
});
let k_rem: usize = k % basek;
if k_rem != 0 {
        let mut carry: Vec<i64> = vec![0i64; n]; // heap allocation, but only n elements, so acceptable
let shift: usize = i64::BITS as usize - k_rem;
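        // Per limb, top to bottom: fold the carry from the limb above back in
        // (scaled by 2^basek), extract the low k_rem bits sign-extended via
        // (x << shift) >> shift as the next carry, then arithmetically shift
        // the remainder down by k_rem bits.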
(0..cols).for_each(|i| {
carry.fill(0);
(steps..size).for_each(|j| {
izip!(carry.iter_mut(), a.at_mut(i, j).iter_mut()).for_each(|(ci, xi)| {
*xi += *ci << basek;
*ci = (*xi << shift) >> shift;
*xi = (*xi - *ci) >> k_rem;
});
});
})
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxRotateImpl<B> for B {
fn vec_znx_rotate_impl<R, A>(_module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), a.n());
}
unsafe {
(0..a.size()).for_each(|j| {
znx::znx_rotate_i64(
a.n() as u64,
k,
res.at_mut_ptr(res_col, j),
a.at_ptr(a_col, j),
);
});
}
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxRotateInplaceImpl<B> for B {
fn vec_znx_rotate_inplace_impl<A>(_module: &Module<B>, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = a.to_mut();
unsafe {
(0..a.size()).for_each(|j| {
znx::znx_rotate_inplace_i64(a.n() as u64, k, a.at_mut_ptr(a_col, j));
});
}
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxAutomorphismImpl<B> for B {
fn vec_znx_automorphism_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_automorphism(
module.ptr() as *const module_info_t,
k,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxAutomorphismInplaceImpl<B> for B {
fn vec_znx_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = a.to_mut();
#[cfg(debug_assertions)]
{
assert!(
k & 1 != 0,
"invalid galois element: must be odd but is {}",
k
);
}
unsafe {
vec_znx::vec_znx_automorphism(
module.ptr() as *const module_info_t,
k,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxMulXpMinusOneImpl<B> for B {
fn vec_znx_mul_xp_minus_one_impl<R, A>(module: &Module<B>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_mul_xp_minus_one(
module.ptr() as *const module_info_t,
p,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxMulXpMinusOneInplaceImpl<B> for B {
fn vec_znx_mul_xp_minus_one_inplace_impl<R>(module: &Module<B>, p: i64, res: &mut R, res_col: usize)
where
R: VecZnxToMut,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
unsafe {
vec_znx::vec_znx_mul_xp_minus_one(
module.ptr() as *const module_info_t,
p,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
)
}
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxSplitImpl<B> for B {
fn vec_znx_split_impl<R, A>(module: &Module<B>, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_split_ref(module, res, res_col, a, a_col, scratch)
}
}
pub fn vec_znx_split_ref<R, A, B>(
module: &Module<B>,
res: &mut [R],
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
) where
B: Backend + CPUAVX,
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let (n_in, n_out) = (a.n(), res[0].to_mut().n());
let (mut buf, _) = scratch.take_vec_znx(n_in.max(n_out), 1, a.size());
debug_assert!(
n_out < n_in,
"invalid a: output ring degree should be smaller"
);
res[1..].iter_mut().for_each(|bi| {
debug_assert_eq!(
bi.to_mut().n(),
n_out,
"invalid input a: all VecZnx must have the same degree"
)
});
res.iter_mut().enumerate().for_each(|(i, bi)| {
if i == 0 {
module.vec_znx_switch_degree(bi, res_col, &a, a_col);
module.vec_znx_rotate(-1, &mut buf, 0, &a, a_col);
} else {
            // buf was taken with a single column, so it is always addressed at column 0
            module.vec_znx_switch_degree(bi, res_col, &buf, 0);
            module.vec_znx_rotate_inplace(-1, &mut buf, 0);
}
})
}
unsafe impl<B: Backend + CPUAVX> VecZnxMergeImpl<B> for B {
fn vec_znx_merge_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &[A], a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_merge_ref(module, res, res_col, a, a_col)
}
}
pub fn vec_znx_merge_ref<R, A, B>(module: &Module<B>, res: &mut R, res_col: usize, a: &[A], a_col: usize)
where
B: Backend + CPUAVX,
R: VecZnxToMut,
A: VecZnxToRef,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
    let (n_res, n_a) = (res.n(), a[0].to_ref().n());
    debug_assert!(
        n_a < n_res,
        "invalid a: input ring degree must be smaller than the output ring degree"
    );
    a[1..].iter().for_each(|ai| {
        debug_assert_eq!(
            ai.to_ref().n(),
            n_a,
            "invalid input a: all VecZnx must have the same degree"
        )
    });
a.iter().for_each(|ai| {
module.vec_znx_switch_degree(&mut res, res_col, ai, a_col);
module.vec_znx_rotate_inplace(-1, &mut res, res_col);
});
module.vec_znx_rotate_inplace(a.len() as i64, &mut res, res_col);
}
unsafe impl<B: Backend + CPUAVX> VecZnxSwithcDegreeImpl<B> for B {
fn vec_znx_switch_degree_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_switch_degree_ref(module, res, res_col, a, a_col)
}
}
pub fn vec_znx_switch_degree_ref<R, A, B>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
B: Backend + CPUAVX,
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let (n_in, n_out) = (a.n(), res.n());
if n_in == n_out {
module.vec_znx_copy(&mut res, res_col, &a, a_col);
return;
}
let (gap_in, gap_out): (usize, usize);
if n_in > n_out {
(gap_in, gap_out) = (n_in / n_out, 1)
} else {
(gap_in, gap_out) = (1, n_out / n_in);
res.zero();
}
let size: usize = a.size().min(res.size());
(0..size).for_each(|i| {
izip!(
a.at(a_col, i).iter().step_by(gap_in),
res.at_mut(res_col, i).iter_mut().step_by(gap_out)
)
.for_each(|(x_in, x_out)| *x_out = *x_in);
});
}
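// Worked example for the degree switch above: with n_in = 8 and n_out = 4 we
// get gap_in = 2, gap_out = 1, so a[0], a[2], a[4], a[6] are decimated into
// res[0..4]; with n_in = 4 and n_out = 8 the gaps flip, res is zeroed, and
// a[0..4] are embedded at res[0], res[2], res[4], res[6].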
unsafe impl<B: Backend + CPUAVX> VecZnxCopyImpl<B> for B {
fn vec_znx_copy_impl<R, A>(_module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_copy_ref(res, res_col, a, a_col)
}
}
pub fn vec_znx_copy_ref<R, A>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let mut res_mut: VecZnx<&mut [u8]> = res.to_mut();
let a_ref: VecZnx<&[u8]> = a.to_ref();
let min_size: usize = res_mut.size().min(a_ref.size());
(0..min_size).for_each(|j| {
res_mut
.at_mut(res_col, j)
.copy_from_slice(a_ref.at(a_col, j));
});
(min_size..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
})
}
unsafe impl<B: Backend + CPUAVX> VecZnxFillUniformImpl<B> for B {
fn vec_znx_fill_uniform_impl<R>(_module: &Module<B>, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
where
R: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = res.to_mut();
let base2k: u64 = 1 << basek;
let mask: u64 = base2k - 1;
let base2k_half: i64 = (base2k >> 1) as i64;
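        // Sample uniformly in [0, 2^basek) (next_u64n with a power-of-two
        // bound and its mask), then center into [-2^(basek-1), 2^(basek-1))
        // by subtracting base2k_half.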
(0..k.div_ceil(basek)).for_each(|j| {
a.at_mut(res_col, j)
.iter_mut()
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
})
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxFillDistF64Impl<B> for B {
fn vec_znx_fill_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
_module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = res.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let basek_rem: usize = (limb + 1) * basek - k;
if basek_rem != 0 {
a.at_mut(res_col, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a = (dist_f64.round() as i64) << basek_rem;
});
} else {
a.at_mut(res_col, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a = dist_f64.round() as i64
});
}
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxAddDistF64Impl<B> for B {
fn vec_znx_add_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
_module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = res.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let basek_rem: usize = (limb + 1) * basek - k;
if basek_rem != 0 {
a.at_mut(res_col, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a += (dist_f64.round() as i64) << basek_rem;
});
} else {
a.at_mut(res_col, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a += dist_f64.round() as i64
});
}
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxFillNormalImpl<B> for B {
fn vec_znx_fill_normal_impl<R>(
module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: VecZnxToMut,
{
module.vec_znx_fill_dist_f64(
basek,
res,
res_col,
k,
source,
Normal::new(0.0, sigma).unwrap(),
bound,
);
}
}
unsafe impl<B: Backend + CPUAVX> VecZnxAddNormalImpl<B> for B {
fn vec_znx_add_normal_impl<R>(
module: &Module<B>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: VecZnxToMut,
{
module.vec_znx_add_dist_f64(
basek,
res,
res_col,
k,
source,
Normal::new(0.0, sigma).unwrap(),
bound,
);
}
}

View File

@@ -0,0 +1,737 @@
use std::fmt;
use rand_distr::{Distribution, Normal};
use crate::{
hal::{
api::{
TakeSlice, VecZnxBigAddDistF64, VecZnxBigFillDistF64, VecZnxBigNormalizeTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView,
ZnxViewMut,
},
layouts::{
Data, DataRef, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigBytesOf, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef,
VecZnxToMut, VecZnxToRef,
},
oep::{
VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl,
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl, VecZnxBigFromBytesImpl,
VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl,
VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl,
VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
},
source::Source,
},
implementation::cpu_spqlios::{ffi::vec_znx, module_fft64::FFT64},
};
const VEC_ZNX_BIG_FFT64_WORDSIZE: usize = 1;
impl<D: DataRef> ZnxView for VecZnxBig<D, FFT64> {
type Scalar = i64;
}
impl<D: Data> VecZnxBigBytesOf for VecZnxBig<D, FFT64> {
fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
VEC_ZNX_BIG_FFT64_WORDSIZE * n * cols * size * size_of::<f64>()
}
}
impl<D: Data> ZnxSliceSize for VecZnxBig<D, FFT64> {
fn sl(&self) -> usize {
VEC_ZNX_BIG_FFT64_WORDSIZE * self.n() * self.cols()
}
}
unsafe impl VecZnxBigAllocImpl<FFT64> for FFT64 {
fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned<FFT64> {
VecZnxBig::<Vec<u8>, FFT64>::new(n, cols, size)
}
}
unsafe impl VecZnxBigFromBytesImpl<FFT64> for FFT64 {
fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<FFT64> {
VecZnxBig::<Vec<u8>, FFT64>::new_from_bytes(n, cols, size, bytes)
}
}
unsafe impl VecZnxBigAllocBytesImpl<FFT64> for FFT64 {
fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
VecZnxBig::<Vec<u8>, FFT64>::bytes_of(n, cols, size)
}
}
unsafe impl VecZnxBigAddDistF64Impl<FFT64> for FFT64 {
fn add_dist_f64_impl<R: VecZnxBigToMut<FFT64>, D: Distribution<f64>>(
_module: &Module<FFT64>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) {
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let basek_rem: usize = (limb + 1) * basek - k;
if basek_rem != 0 {
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*x += (dist_f64.round() as i64) << basek_rem;
});
} else {
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*x += dist_f64.round() as i64
});
}
}
}
unsafe impl VecZnxBigAddNormalImpl<FFT64> for FFT64 {
fn add_normal_impl<R: VecZnxBigToMut<FFT64>>(
module: &Module<FFT64>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) {
module.vec_znx_big_add_dist_f64(
basek,
res,
res_col,
k,
source,
Normal::new(0.0, sigma).unwrap(),
bound,
);
}
}
unsafe impl VecZnxBigFillDistF64Impl<FFT64> for FFT64 {
fn fill_dist_f64_impl<R: VecZnxBigToMut<FFT64>, D: Distribution<f64>>(
_module: &Module<FFT64>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) {
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let basek_rem: usize = (limb + 1) * basek - k;
if basek_rem != 0 {
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*x = (dist_f64.round() as i64) << basek_rem;
});
} else {
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*x = dist_f64.round() as i64
});
}
}
}
unsafe impl VecZnxBigFillNormalImpl<FFT64> for FFT64 {
fn fill_normal_impl<R: VecZnxBigToMut<FFT64>>(
module: &Module<FFT64>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) {
module.vec_znx_big_fill_dist_f64(
basek,
res,
res_col,
k,
source,
Normal::new(0.0, sigma).unwrap(),
bound,
);
}
}
unsafe impl VecZnxBigAddImpl<FFT64> for FFT64 {
    /// Adds `a` to `b` and stores the result in `res` (res = a + b).
fn vec_znx_big_add_impl<R, A, B>(
module: &Module<FFT64>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
) where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
B: VecZnxBigToRef<FFT64>,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_add(
module.ptr(),
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
}
unsafe impl VecZnxBigAddInplaceImpl<FFT64> for FFT64 {
    /// Adds `a` to `res` and stores the result in `res` (res = res + a).
fn vec_znx_big_add_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_add(
module.ptr(),
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
)
}
}
}
unsafe impl VecZnxBigAddSmallImpl<FFT64> for FFT64 {
    /// Adds small `b` to big `a` and stores the result in `res` (res = a + b).
fn vec_znx_big_add_small_impl<R, A, B>(
module: &Module<FFT64>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
) where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
B: VecZnxToRef,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_add(
module.ptr(),
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
}
unsafe impl VecZnxBigAddSmallInplaceImpl<FFT64> for FFT64 {
    /// Adds small `a` to `res` and stores the result in `res` (res = res + a).
fn vec_znx_big_add_small_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_add(
module.ptr(),
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
}
unsafe impl VecZnxBigSubImpl<FFT64> for FFT64 {
    /// Subtracts `b` from `a` and stores the result in `res` (res = a - b).
fn vec_znx_big_sub_impl<R, A, B>(
module: &Module<FFT64>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
) where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
B: VecZnxBigToRef<FFT64>,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_sub(
module.ptr(),
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
}
unsafe impl VecZnxBigSubABInplaceImpl<FFT64> for FFT64 {
    /// Subtracts `a` from `res` and stores the result in `res` (res = res - a).
fn vec_znx_big_sub_ab_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_sub(
module.ptr(),
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
}
unsafe impl VecZnxBigSubBAInplaceImpl<FFT64> for FFT64 {
    /// Subtracts `res` from `a` and stores the result in `res` (res = a - res).
fn vec_znx_big_sub_ba_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_sub(
module.ptr(),
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
)
}
}
}
unsafe impl VecZnxBigSubSmallAImpl<FFT64> for FFT64 {
    /// Subtracts big `b` from small `a` and stores the result in `res` (res = a - b).
fn vec_znx_big_sub_small_a_impl<R, A, B>(
module: &Module<FFT64>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
) where
R: VecZnxBigToMut<FFT64>,
A: VecZnxToRef,
B: VecZnxBigToRef<FFT64>,
{
let a: VecZnx<&[u8]> = a.to_ref();
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_sub(
module.ptr(),
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
}
unsafe impl VecZnxBigSubSmallAInplaceImpl<FFT64> for FFT64 {
    /// Subtracts small `a` from `res` and stores the result in `res` (res = res - a).
fn vec_znx_big_sub_small_a_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_sub(
module.ptr(),
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
}
unsafe impl VecZnxBigSubSmallBImpl<FFT64> for FFT64 {
    /// Subtracts small `b` from big `a` and stores the result in `res` (res = a - b).
fn vec_znx_big_sub_small_b_impl<R, A, B>(
module: &Module<FFT64>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
) where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
B: VecZnxToRef,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_sub(
module.ptr(),
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
}
unsafe impl VecZnxBigSubSmallBInplaceImpl<FFT64> for FFT64 {
    /// Subtracts `res` from `a` and stores the result in `res` (res = a - res).
fn vec_znx_big_sub_small_b_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_sub(
module.ptr(),
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
)
}
}
}
unsafe impl VecZnxBigNegateInplaceImpl<FFT64> for FFT64 {
fn vec_znx_big_negate_inplace_impl<A>(module: &Module<FFT64>, a: &mut A, a_col: usize)
where
A: VecZnxBigToMut<FFT64>,
{
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
unsafe {
vec_znx::vec_znx_negate(
module.ptr(),
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
}
unsafe impl VecZnxBigNormalizeTmpBytesImpl<FFT64> for FFT64 {
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<FFT64>, n: usize) -> usize {
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr(), n as u64) as usize }
}
}
unsafe impl VecZnxBigNormalizeImpl<FFT64> for FFT64 {
fn vec_znx_big_normalize_impl<R, A>(
module: &Module<FFT64>,
basek: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<FFT64>,
) where
R: VecZnxToMut,
A: VecZnxBigToRef<FFT64>,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), a.n());
}
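// Per the backend safety contract, the caller must have provisioned `scratch`
// with at least `vec_znx_big_normalize_tmp_bytes(a.n())` bytes; the slice taken
// below is only borrowed for the duration of this call.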
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes(a.n()));
unsafe {
vec_znx::vec_znx_normalize_base2k(
module.ptr(),
a.n() as u64,
basek as u64,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
}
unsafe impl VecZnxBigAutomorphismImpl<FFT64> for FFT64 {
/// Applies the automorphism X^i -> X^(i*k) on `a` and stores the result in `res`.
fn vec_znx_big_automorphism_impl<R, A>(module: &Module<FFT64>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_automorphism(
module.ptr(),
k,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
}
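// Illustrative reference model (a naive sketch, not the spqlios kernel): on
// Z[X]/(X^n + 1) the automorphism X^i -> X^(i*k) sends the coefficient at
// index i to index i*k mod 2n, negated when the destination wraps past n.
#[cfg(test)]
mod automorphism_reference {
    // Plain-integer model of the negacyclic automorphism for a single limb,
    // with k odd so that the map is invertible.
    fn automorphism_naive(k: usize, a: &[i64]) -> Vec<i64> {
        let n = a.len();
        let mut res = vec![0i64; n];
        for (i, &c) in a.iter().enumerate() {
            let j = (i * k) % (2 * n);
            if j < n {
                res[j] = c;
            } else {
                res[j - n] = -c; // X^n = -1 in the negacyclic ring
            }
        }
        res
    }

    #[test]
    fn maps_x_to_x_pow_k() {
        // a(X) = X over Z[X]/(X^4 + 1); applying k = 5 gives X^5 = -X.
        let a = [0i64, 1, 0, 0];
        assert_eq!(automorphism_naive(5, &a), vec![0, -1, 0, 0]);
    }
}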
unsafe impl VecZnxBigAutomorphismInplaceImpl<FFT64> for FFT64 {
/// Applies the automorphism X^i -> X^(i*k) on `a` in place.
fn vec_znx_big_automorphism_inplace_impl<A>(module: &Module<FFT64>, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxBigToMut<FFT64>,
{
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
unsafe {
vec_znx::vec_znx_automorphism(
module.ptr(),
k,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
}
impl<D: DataRef> fmt::Display for VecZnxBig<D, FFT64> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"VecZnxBig(n={}, cols={}, size={})",
self.n, self.cols, self.size
)?;
for col in 0..self.cols {
writeln!(f, "Column {}:", col)?;
for size in 0..self.size {
let coeffs = self.at(col, size);
write!(f, " Size {}: [", size)?;
let max_show = 100;
let show_count = coeffs.len().min(max_show);
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", coeff)?;
}
if coeffs.len() > max_show {
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
}
writeln!(f, "]")?;
}
}
Ok(())
}
}

@@ -0,0 +1,32 @@
use crate::{
hal::{
api::{ZnxInfos, ZnxSliceSize, ZnxView},
layouts::{Data, DataRef, VecZnxBig, VecZnxBigBytesOf},
oep::VecZnxBigAllocBytesImpl,
},
implementation::cpu_spqlios::module_ntt120::NTT120,
};
const VEC_ZNX_BIG_NTT120_WORDSIZE: usize = 4;
impl<D: DataRef> ZnxView for VecZnxBig<D, NTT120> {
type Scalar = i128;
}
impl<D: Data> VecZnxBigBytesOf for VecZnxBig<D, NTT120> {
fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
// 4 x 64-bit words per coefficient.
VEC_ZNX_BIG_NTT120_WORDSIZE * n * cols * size * size_of::<i64>()
}
}
impl<D: Data> ZnxSliceSize for VecZnxBig<D, NTT120> {
fn sl(&self) -> usize {
VEC_ZNX_BIG_NTT120_WORDSIZE * self.n() * self.cols()
}
}
unsafe impl VecZnxBigAllocBytesImpl<NTT120> for NTT120 {
fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
VecZnxBig::<Vec<u8>, NTT120>::bytes_of(n, cols, size)
}
}

@@ -0,0 +1,433 @@
use std::fmt;
use crate::{
hal::{
api::{TakeSlice, VecZnxDftToVecZnxBigTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
layouts::{
Data, DataRef, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftBytesOf, VecZnxDftOwned,
VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
},
oep::{
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftCopyImpl,
VecZnxDftFromBytesImpl, VecZnxDftFromVecZnxImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl,
VecZnxDftSubImpl, VecZnxDftToVecZnxBigConsumeImpl, VecZnxDftToVecZnxBigImpl, VecZnxDftToVecZnxBigTmpAImpl,
VecZnxDftToVecZnxBigTmpBytesImpl, VecZnxDftZeroImpl,
},
},
implementation::cpu_spqlios::{
ffi::{vec_znx_big, vec_znx_dft},
module_fft64::FFT64,
},
};
const VEC_ZNX_DFT_FFT64_WORDSIZE: usize = 1;
impl<D: Data> ZnxSliceSize for VecZnxDft<D, FFT64> {
fn sl(&self) -> usize {
VEC_ZNX_DFT_FFT64_WORDSIZE * self.n() * self.cols()
}
}
impl<D: Data> VecZnxDftBytesOf for VecZnxDft<D, FFT64> {
fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
VEC_ZNX_DFT_FFT64_WORDSIZE * n * cols * size * size_of::<f64>()
}
}
impl<D: DataRef> ZnxView for VecZnxDft<D, FFT64> {
type Scalar = f64;
}
unsafe impl VecZnxDftFromBytesImpl<FFT64> for FFT64 {
fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<FFT64> {
VecZnxDft::<Vec<u8>, FFT64>::from_bytes(n, cols, size, bytes)
}
}
unsafe impl VecZnxDftAllocBytesImpl<FFT64> for FFT64 {
fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
VecZnxDft::<Vec<u8>, FFT64>::bytes_of(n, cols, size)
}
}
unsafe impl VecZnxDftAllocImpl<FFT64> for FFT64 {
fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<FFT64> {
VecZnxDftOwned::alloc(n, cols, size)
}
}
unsafe impl VecZnxDftToVecZnxBigTmpBytesImpl<FFT64> for FFT64 {
fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<FFT64>, n: usize) -> usize {
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(module.ptr(), n as u64) as usize }
}
}
unsafe impl VecZnxDftToVecZnxBigImpl<FFT64> for FFT64 {
fn vec_znx_dft_to_vec_znx_big_impl<R, A>(
module: &Module<FFT64>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<FFT64>,
) where
R: VecZnxBigToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
{
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
let a: VecZnxDft<&[u8], FFT64> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), a.n())
}
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_dft_to_vec_znx_big_tmp_bytes(a.n()));
let min_size: usize = res.size().min(a.size());
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_znx_idft(
module.ptr(),
res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
1_u64,
a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1_u64,
tmp_bytes.as_mut_ptr(),
)
});
(min_size..res.size()).for_each(|j| {
res.zero_at(res_col, j);
});
}
}
}
unsafe impl VecZnxDftToVecZnxBigTmpAImpl<FFT64> for FFT64 {
fn vec_znx_dft_to_vec_znx_big_tmp_a_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxDftToMut<FFT64>,
{
let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
let min_size: usize = res_mut.size().min(a_mut.size());
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_znx_idft_tmp_a(
module.ptr(),
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
1_u64,
a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1_u64,
)
});
(min_size..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
})
}
}
}
unsafe impl VecZnxDftToVecZnxBigConsumeImpl<FFT64> for FFT64 {
fn vec_znx_dft_to_vec_znx_big_consume_impl<D: Data>(module: &Module<FFT64>, mut a: VecZnxDft<D, FFT64>) -> VecZnxBig<D, FFT64>
where
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>,
{
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
unsafe {
// Iterate sizes (outer) then columns (inner): VecZnxDft.sl() >= VecZnxBig.sl(),
// so each in-place IDFT writes at or before the region it reads.
(0..a_mut.size()).for_each(|j| {
(0..a_mut.cols()).for_each(|i| {
vec_znx_dft::vec_znx_idft_tmp_a(
module.ptr(),
a_mut.at_mut_ptr(i, j) as *mut vec_znx_big::vec_znx_big_t,
1_u64,
a_mut.at_mut_ptr(i, j) as *mut vec_znx_dft::vec_znx_dft_t,
1_u64,
)
});
});
}
a.into_big()
}
}
unsafe impl VecZnxDftFromVecZnxImpl<FFT64> for FFT64 {
fn vec_znx_dft_from_vec_znx_impl<R, A>(
module: &Module<FFT64>,
step: usize,
offset: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
) where
R: VecZnxDftToMut<FFT64>,
A: VecZnxToRef,
{
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a_ref: VecZnx<&[u8]> = a.to_ref();
let steps: usize = a_ref.size().div_ceil(step);
let min_steps: usize = res_mut.size().min(steps);
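// Gather pattern (illustration): output limb j reads input limb offset + j*step.
// E.g. with a.size() = 6, step = 2, offset = 0 and res.size() >= 3, output limbs
// 0, 1, 2 read input limbs 0, 2, 4; any remaining output limbs are zeroed below.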
unsafe {
(0..min_steps).for_each(|j| {
let limb: usize = offset + j * step;
if limb < a_ref.size() {
vec_znx_dft::vec_znx_dft(
module.ptr(),
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1_u64,
a_ref.at_ptr(a_col, limb),
1_u64,
a_ref.sl() as u64,
)
}
});
(min_steps..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
});
}
}
}
unsafe impl VecZnxDftAddImpl<FFT64> for FFT64 {
fn vec_znx_dft_add_impl<R, A, D>(
module: &Module<FFT64>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &D,
b_col: usize,
) where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
D: VecZnxDftToRef<FFT64>,
{
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref();
let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size());
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_dft_add(
module.ptr(),
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1,
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
);
});
}
(min_size..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
})
}
}
unsafe impl VecZnxDftAddInplaceImpl<FFT64> for FFT64 {
fn vec_znx_dft_add_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
{
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
let min_size: usize = res_mut.size().min(a_ref.size());
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_dft_add(
module.ptr(),
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1,
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
);
});
}
}
}
unsafe impl VecZnxDftSubImpl<FFT64> for FFT64 {
fn vec_znx_dft_sub_impl<R, A, D>(
module: &Module<FFT64>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &D,
b_col: usize,
) where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
D: VecZnxDftToRef<FFT64>,
{
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref();
let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size());
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_dft_sub(
module.ptr(),
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1,
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
);
});
}
(min_size..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
})
}
}
unsafe impl VecZnxDftSubABInplaceImpl<FFT64> for FFT64 {
fn vec_znx_dft_sub_ab_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
{
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
let min_size: usize = res_mut.size().min(a_ref.size());
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_dft_sub(
module.ptr(),
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1,
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
);
});
}
}
}
unsafe impl VecZnxDftSubBAInplaceImpl<FFT64> for FFT64 {
fn vec_znx_dft_sub_ba_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
{
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
let min_size: usize = res_mut.size().min(a_ref.size());
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_dft_sub(
module.ptr(),
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1,
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
);
});
}
}
}
unsafe impl VecZnxDftCopyImpl<FFT64> for FFT64 {
fn vec_znx_dft_copy_impl<R, A>(
_module: &Module<FFT64>,
step: usize,
offset: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
) where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
{
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
let steps: usize = a_ref.size().div_ceil(step);
let min_steps: usize = res_mut.size().min(steps);
(0..min_steps).for_each(|j| {
let limb: usize = offset + j * step;
if limb < a_ref.size() {
res_mut
.at_mut(res_col, j)
.copy_from_slice(a_ref.at(a_col, limb));
}
});
(min_steps..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
})
}
}
unsafe impl VecZnxDftZeroImpl<FFT64> for FFT64 {
fn vec_znx_dft_zero_impl<R>(_module: &Module<FFT64>, res: &mut R)
where
R: VecZnxDftToMut<FFT64>,
{
res.to_mut().data.fill(0);
}
}
impl<D: DataRef> fmt::Display for VecZnxDft<D, FFT64> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"VecZnxDft(n={}, cols={}, size={})",
self.n, self.cols, self.size
)?;
for col in 0..self.cols {
writeln!(f, "Column {}:", col)?;
for size in 0..self.size {
let coeffs = self.at(col, size);
write!(f, " Size {}: [", size)?;
let max_show = 100;
let show_count = coeffs.len().min(max_show);
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", coeff)?;
}
if coeffs.len() > max_show {
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
}
writeln!(f, "]")?;
}
}
Ok(())
}
}

@@ -0,0 +1,38 @@
use crate::{
hal::{
api::{ZnxInfos, ZnxSliceSize, ZnxView},
layouts::{Data, DataRef, VecZnxDft, VecZnxDftBytesOf, VecZnxDftOwned},
oep::{VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl},
},
implementation::cpu_spqlios::module_ntt120::NTT120,
};
const VEC_ZNX_DFT_NTT120_WORDSIZE: usize = 4;
impl<D: Data> ZnxSliceSize for VecZnxDft<D, NTT120> {
fn sl(&self) -> usize {
VEC_ZNX_DFT_NTT120_WORDSIZE * self.n() * self.cols()
}
}
impl<D: Data> VecZnxDftBytesOf for VecZnxDft<D, NTT120> {
fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
VEC_ZNX_DFT_NTT120_WORDSIZE * n * cols * size * size_of::<i64>()
}
}
impl<D: DataRef> ZnxView for VecZnxDft<D, NTT120> {
type Scalar = i64;
}
unsafe impl VecZnxDftAllocBytesImpl<NTT120> for NTT120 {
fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
VecZnxDft::<Vec<u8>, NTT120>::bytes_of(n, cols, size)
}
}
unsafe impl VecZnxDftAllocImpl<NTT120> for NTT120 {
fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<NTT120> {
VecZnxDftOwned::alloc(n, cols, size)
}
}

@@ -0,0 +1,298 @@
use crate::{
hal::{
api::{TakeSlice, VmpApplyTmpBytes, VmpPrepareTmpBytes, ZnxInfos, ZnxView, ZnxViewMut},
layouts::{
DataRef, MatZnx, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatBytesOf,
VmpPMatOwned, VmpPMatToMut, VmpPMatToRef,
},
oep::{
VmpApplyAddImpl, VmpApplyAddTmpBytesImpl, VmpApplyImpl, VmpApplyTmpBytesImpl, VmpPMatAllocBytesImpl,
VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
},
},
implementation::cpu_spqlios::{
ffi::{vec_znx_dft::vec_znx_dft_t, vmp},
module_fft64::FFT64,
},
};
const VMP_PMAT_FFT64_WORDSIZE: usize = 1;
impl<D: DataRef> ZnxView for VmpPMat<D, FFT64> {
type Scalar = f64;
}
impl VmpPMatBytesOf for FFT64 {
fn vmp_pmat_bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
VMP_PMAT_FFT64_WORDSIZE * n * rows * cols_in * cols_out * size * size_of::<f64>()
}
}
unsafe impl VmpPMatAllocBytesImpl<FFT64> for FFT64
where
FFT64: VmpPMatBytesOf,
{
fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
FFT64::vmp_pmat_bytes_of(n, rows, cols_in, cols_out, size)
}
}
unsafe impl VmpPMatFromBytesImpl<FFT64> for FFT64 {
fn vmp_pmat_from_bytes_impl(
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: Vec<u8>,
) -> VmpPMatOwned<FFT64> {
VmpPMatOwned::from_bytes(n, rows, cols_in, cols_out, size, bytes)
}
}
unsafe impl VmpPMatAllocImpl<FFT64> for FFT64 {
fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<FFT64> {
VmpPMatOwned::alloc(n, rows, cols_in, cols_out, size)
}
}
unsafe impl VmpPrepareTmpBytesImpl<FFT64> for FFT64 {
fn vmp_prepare_tmp_bytes_impl(
module: &Module<FFT64>,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
) -> usize {
unsafe {
vmp::vmp_prepare_tmp_bytes(
module.ptr(),
n as u64,
(rows * cols_in) as u64,
(cols_out * size) as u64,
) as usize
}
}
}
unsafe impl VmpPMatPrepareImpl<FFT64> for FFT64 {
fn vmp_prepare_impl<R, A>(module: &Module<FFT64>, res: &mut R, a: &A, scratch: &mut Scratch<FFT64>)
where
R: VmpPMatToMut<FFT64>,
A: MatZnxToRef,
{
let mut res: VmpPMat<&mut [u8], FFT64> = res.to_mut();
let a: MatZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
assert_eq!(
res.cols_in(),
a.cols_in(),
"res.cols_in: {} != a.cols_in: {}",
res.cols_in(),
a.cols_in()
);
assert_eq!(
res.rows(),
a.rows(),
"res.rows: {} != a.rows: {}",
res.rows(),
a.rows()
);
assert_eq!(
res.cols_out(),
a.cols_out(),
"res.cols_out: {} != a.cols_out: {}",
res.cols_out(),
a.cols_out()
);
assert_eq!(
res.size(),
a.size(),
"res.size: {} != a.size: {}",
res.size(),
a.size()
);
}
let (tmp_bytes, _) =
scratch.take_slice(module.vmp_prepare_tmp_bytes(res.n(), a.rows(), a.cols_in(), a.cols_out(), a.size()));
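// The kernel consumes `a` as one contiguous (rows * cols_in) x (size * cols_out)
// matrix of degree-n polynomials, matching the two extents passed below.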
unsafe {
vmp::vmp_prepare_contiguous(
module.ptr(),
res.as_mut_ptr() as *mut vmp::vmp_pmat_t,
a.as_ptr(),
(a.rows() * a.cols_in()) as u64,
(a.size() * a.cols_out()) as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
}
unsafe impl VmpApplyTmpBytesImpl<FFT64> for FFT64 {
fn vmp_apply_tmp_bytes_impl(
module: &Module<FFT64>,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize {
unsafe {
vmp::vmp_apply_dft_to_dft_tmp_bytes(
module.ptr(),
n as u64,
(res_size * b_cols_out) as u64,
(a_size * b_cols_in) as u64,
(b_rows * b_cols_in) as u64,
(b_size * b_cols_out) as u64,
) as usize
}
}
}
unsafe impl VmpApplyImpl<FFT64> for FFT64 {
fn vmp_apply_impl<R, A, C>(module: &Module<FFT64>, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<FFT64>)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
C: VmpPMatToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
let a: VecZnxDft<&[u8], _> = a.to_ref();
let b: VmpPMat<&[u8], _> = b.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(b.n(), res.n());
assert_eq!(a.n(), res.n());
assert_eq!(
res.cols(),
b.cols_out(),
"res.cols(): {} != b.cols_out: {}",
res.cols(),
b.cols_out()
);
assert_eq!(
a.cols(),
b.cols_in(),
"a.cols(): {} != b.cols_in: {}",
a.cols(),
b.cols_in()
);
}
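// Shapes (reading off the extents passed to the kernel): `a` acts as a row
// vector of (a.size() * a.cols()) limbs, `b` as a prepared
// (b.rows() * b.cols_in()) x (b.size() * b.cols_out()) matrix, and `res`
// receives (res.size() * res.cols()) output limbs.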
let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes(
res.n(),
res.size(),
a.size(),
b.rows(),
b.cols_in(),
b.cols_out(),
b.size(),
));
unsafe {
vmp::vmp_apply_dft_to_dft(
module.ptr(),
res.as_mut_ptr() as *mut vec_znx_dft_t,
(res.size() * res.cols()) as u64,
a.as_ptr() as *const vec_znx_dft_t,
(a.size() * a.cols()) as u64,
b.as_ptr() as *const vmp::vmp_pmat_t,
(b.rows() * b.cols_in()) as u64,
(b.size() * b.cols_out()) as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
}
unsafe impl VmpApplyAddTmpBytesImpl<FFT64> for FFT64 {
fn vmp_apply_add_tmp_bytes_impl(
module: &Module<FFT64>,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize {
unsafe {
vmp::vmp_apply_dft_to_dft_tmp_bytes(
module.ptr(),
n as u64,
(res_size * b_cols_out) as u64,
(a_size * b_cols_in) as u64,
(b_rows * b_cols_in) as u64,
(b_size * b_cols_out) as u64,
) as usize
}
}
}
unsafe impl VmpApplyAddImpl<FFT64> for FFT64 {
fn vmp_apply_add_impl<R, A, C>(module: &Module<FFT64>, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<FFT64>)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
C: VmpPMatToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
let a: VecZnxDft<&[u8], _> = a.to_ref();
let b: VmpPMat<&[u8], _> = b.to_ref();
#[cfg(debug_assertions)]
{
use crate::hal::api::ZnxInfos;
assert_eq!(b.n(), res.n());
assert_eq!(a.n(), res.n());
assert_eq!(
res.cols(),
b.cols_out(),
"res.cols(): {} != b.cols_out: {}",
res.cols(),
b.cols_out()
);
assert_eq!(
a.cols(),
b.cols_in(),
"a.cols(): {} != b.cols_in: {}",
a.cols(),
b.cols_in()
);
}
let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes(
res.n(),
res.size(),
a.size(),
b.rows(),
b.cols_in(),
b.cols_out(),
b.size(),
));
unsafe {
vmp::vmp_apply_dft_to_dft_add(
module.ptr(),
res.as_mut_ptr() as *mut vec_znx_dft_t,
(res.size() * res.cols()) as u64,
a.as_ptr() as *const vec_znx_dft_t,
(a.size() * a.cols()) as u64,
b.as_ptr() as *const vmp::vmp_pmat_t,
(b.rows() * b.cols_in()) as u64,
(b.size() * b.cols_out()) as u64,
(scale * b.cols_out()) as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
}

View File

@@ -0,0 +1,11 @@
use crate::{
hal::{
api::ZnxView,
layouts::{DataRef, VmpPMat},
},
implementation::cpu_spqlios::module_ntt120::NTT120,
};
impl<D: DataRef> ZnxView for VmpPMat<D, NTT120> {
type Scalar = i64;
}

@@ -0,0 +1 @@
pub mod cpu_spqlios;

poulpy-backend/src/lib.rs

@@ -0,0 +1,106 @@
#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)]
#![deny(rustdoc::broken_intra_doc_links)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![feature(trait_alias)]
pub mod hal;
pub mod implementation;
pub mod doc {
#[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/backend_safety_contract.md"))]
pub mod backend_safety {
pub const _PLACEHOLDER: () = ();
}
}
pub const GALOISGENERATOR: u64 = 5;
pub const DEFAULTALIGN: usize = 64;
fn is_aligned_custom<T>(ptr: *const T, align: usize) -> bool {
(ptr as usize).is_multiple_of(align)
}
pub fn is_aligned<T>(ptr: *const T) -> bool {
is_aligned_custom(ptr, DEFAULTALIGN)
}
pub fn assert_alignement<T>(ptr: *const T) {
assert!(
is_aligned(ptr),
"invalid alignment: ensure passed bytes have been allocated with [alloc_aligned_u8] or [alloc_aligned]"
)
}
pub fn cast<T, V>(data: &[T]) -> &[V] {
let ptr: *const V = data.as_ptr() as *const V;
// Convert the length from `T` elements to `V` elements via the total byte count.
let len: usize = (data.len() * size_of::<T>()) / size_of::<V>();
unsafe { std::slice::from_raw_parts(ptr, len) }
}
#[allow(clippy::mut_from_ref)]
pub fn cast_mut<T, V>(data: &[T]) -> &mut [V] {
let ptr: *mut V = data.as_ptr() as *mut V;
let len: usize = (data.len() * size_of::<T>()) / size_of::<V>();
unsafe { std::slice::from_raw_parts_mut(ptr, len) }
}
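// A minimal sketch illustrating the length conversion in `cast`: 16 bytes
// reinterpreted as i64 yield two elements. The test is illustrative and
// assumes the source slice is suitably aligned for the target type.
#[cfg(test)]
mod cast_tests {
    use super::*;

    #[test]
    fn cast_converts_length_via_byte_size() {
        // alloc_aligned returns a 64-byte aligned buffer, so the i64 view is valid.
        let bytes: Vec<u8> = alloc_aligned::<u8>(64);
        let words: &[i64] = cast::<u8, i64>(&bytes[..16]);
        assert_eq!(words.len(), 2);
    }
}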
/// Allocates a block of bytes with a custom alignment.
/// The alignment must be a power of two and the size a multiple of the alignment.
/// Allocated memory is initialized to zero.
fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
assert!(
align.is_power_of_two(),
"Alignment must be a power of two but is {}",
align
);
assert_eq!(
(size * size_of::<u8>()) % align,
0,
"size={} must be a multiple of align={}",
size,
align
);
unsafe {
let layout: std::alloc::Layout = std::alloc::Layout::from_size_align(size, align).expect("Invalid alignment");
let ptr: *mut u8 = std::alloc::alloc(layout);
if ptr.is_null() {
panic!("Memory allocation failed");
}
assert!(
is_aligned_custom(ptr, align),
"Memory allocation at {:p} is not aligned to {} bytes",
ptr,
align
);
// Init allocated memory to zero
std::ptr::write_bytes(ptr, 0, size);
Vec::from_raw_parts(ptr, size, size)
}
}
/// Allocates a block of `T` with a custom alignment.
/// `size * size_of::<T>()` must be a multiple of `align`.
pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
assert_eq!(
(size * size_of::<T>()) % align,
0,
"size={} * size_of::<T>()={} must be a multiple of align={}",
size,
size_of::<T>(),
align
);
let mut vec_u8: Vec<u8> = alloc_aligned_custom_u8(size_of::<T>() * size, align);
let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T;
let len: usize = vec_u8.len() / size_of::<T>();
let cap: usize = vec_u8.capacity() / size_of::<T>();
std::mem::forget(vec_u8);
unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
/// Allocates an aligned vector whose length is the smallest multiple
/// of [DEFAULTALIGN]/`size_of::<T>()` that is greater than or equal to `size`.
pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
// Round the element count up to a whole number of DEFAULTALIGN-byte chunks
// (assumes size_of::<T>() divides DEFAULTALIGN).
let chunk: usize = DEFAULTALIGN / size_of::<T>();
alloc_aligned_custom::<T>(size.div_ceil(chunk) * chunk, DEFAULTALIGN)
}
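// Illustrative sketch of the rounding behaviour of `alloc_aligned`: with
// DEFAULTALIGN = 64 and `size_of::<i64>()` = 8, a request for 3 elements is
// rounded up to 8, and the returned buffer is zeroed and 64-byte aligned.
#[cfg(test)]
mod alloc_tests {
    use super::*;

    #[test]
    fn alloc_aligned_rounds_up_and_aligns() {
        let v: Vec<i64> = alloc_aligned::<i64>(3);
        assert_eq!(v.len(), 8);
        assert!(is_aligned(v.as_ptr()));
        assert!(v.iter().all(|&x| x == 0));
    }
}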