mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
Add Hardware Abstraction Layer (#56)
This commit is contained in:
committed by
GitHub
parent
833520b163
commit
0e0745065e
@@ -13,7 +13,12 @@ rand_distr = {workspace = true}
|
||||
rand_core = {workspace = true}
|
||||
sampling = { path = "../sampling" }
|
||||
utils = { path = "../utils" }
|
||||
paste = "1.0.15"
|
||||
byteorder = {workspace = true}
|
||||
|
||||
[[bench]]
|
||||
name = "fft"
|
||||
harness = false
|
||||
[build-dependencies]
|
||||
cmake = "0.1.54"
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
all-features = true
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
@@ -1,56 +0,0 @@
|
||||
use backend::ffi::reim::*;
|
||||
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
|
||||
use std::ffi::c_void;
|
||||
|
||||
fn fft(c: &mut Criterion) {
|
||||
fn forward<'a>(m: u32, log_bound: u32, reim_fft_precomp: *mut reim_fft_precomp, a: &'a [i64]) -> Box<dyn FnMut() + 'a> {
|
||||
unsafe {
|
||||
let buf_a: *mut f64 = reim_fft_precomp_get_buffer(reim_fft_precomp, 0);
|
||||
reim_from_znx64_simple(m as u32, log_bound as u32, buf_a as *mut c_void, a.as_ptr());
|
||||
Box::new(move || reim_fft(reim_fft_precomp, buf_a))
|
||||
}
|
||||
}
|
||||
|
||||
fn backward<'a>(m: u32, log_bound: u32, reim_ifft_precomp: *mut reim_ifft_precomp, a: &'a [i64]) -> Box<dyn FnMut() + 'a> {
|
||||
Box::new(move || unsafe {
|
||||
let buf_a: *mut f64 = reim_ifft_precomp_get_buffer(reim_ifft_precomp, 0);
|
||||
reim_from_znx64_simple(m as u32, log_bound as u32, buf_a as *mut c_void, a.as_ptr());
|
||||
reim_ifft(reim_ifft_precomp, buf_a);
|
||||
})
|
||||
}
|
||||
|
||||
let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("fft");
|
||||
|
||||
for log_n in 10..17 {
|
||||
let n: usize = 1 << log_n;
|
||||
let m: usize = n >> 1;
|
||||
let log_bound: u32 = 19;
|
||||
|
||||
let mut a: Vec<i64> = vec![i64::default(); n];
|
||||
a.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
|
||||
unsafe {
|
||||
let reim_fft_precomp: *mut reim_fft_precomp = new_reim_fft_precomp(m as u32, 1);
|
||||
let reim_ifft_precomp: *mut reim_ifft_precomp = new_reim_ifft_precomp(m as u32, 1);
|
||||
|
||||
let runners: [(String, Box<dyn FnMut()>); 2] = [
|
||||
(format!("forward"), {
|
||||
forward(m as u32, log_bound, reim_fft_precomp, &a)
|
||||
}),
|
||||
(format!("backward"), {
|
||||
backward(m as u32, log_bound, reim_ifft_precomp, &a)
|
||||
}),
|
||||
];
|
||||
|
||||
for (name, mut runner) in runners {
|
||||
let id: BenchmarkId = BenchmarkId::new(name, format!("n={}", 1 << log_n));
|
||||
b.bench_with_input(id, &(), |b: &mut criterion::Bencher<'_>, _| {
|
||||
b.iter(&mut runner)
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, fft,);
|
||||
criterion_main!(benches);
|
||||
@@ -1,13 +1,7 @@
|
||||
use std::path::absolute;
|
||||
|
||||
fn main() {
|
||||
println!(
|
||||
"cargo:rustc-link-search=native={}",
|
||||
absolute("spqlios-arithmetic/build/spqlios")
|
||||
.unwrap()
|
||||
.to_str()
|
||||
.unwrap()
|
||||
);
|
||||
println!("cargo:rustc-link-lib=static=spqlios");
|
||||
// println!("cargo:rustc-link-lib=dylib=spqlios")
|
||||
}
|
||||
mod builds {
|
||||
pub mod cpu_spqlios;
|
||||
}
|
||||
|
||||
fn main() {
|
||||
builds::cpu_spqlios::build()
|
||||
}
|
||||
|
||||
10
backend/builds/cpu_spqlios.rs
Normal file
10
backend/builds/cpu_spqlios.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
pub fn build() {
|
||||
let dst: PathBuf = cmake::Config::new("src/implementation/cpu_spqlios/spqlios-arithmetic").build();
|
||||
|
||||
let lib_dir: PathBuf = dst.join("lib");
|
||||
|
||||
println!("cargo:rustc-link-search=native={}", lib_dir.display());
|
||||
println!("cargo:rustc-link-lib=static=spqlios");
|
||||
}
|
||||
27
backend/docs/backend_safety_contract.md
Normal file
27
backend/docs/backend_safety_contract.md
Normal file
@@ -0,0 +1,27 @@
|
||||
Implementors must uphold all of the following for **every** call:
|
||||
|
||||
* **Memory domains**: Pointers produced by to_ref() / to_mut() must be valid
|
||||
in the target execution domain for Self (e.g., CPU host memory for CPU,
|
||||
device memory for a specific GPU). If host↔device transfers are required,
|
||||
perform them inside the implementation; do not assume the caller synchronized.
|
||||
|
||||
* **Alignment & layout**: All data must match the layout, stride, and element
|
||||
size expected by the kernel. size(), rows(), cols_in(), cols_out(),
|
||||
n(), etc... must be interpreted identically to the reference CPU implementation.
|
||||
|
||||
* **Scratch lifetime**: Any scratch obtained from scratch.tmp_slice(...) (or a
|
||||
backend-specific variant) must remain valid for the duration of the call; it
|
||||
may be reused by the caller afterwards. Do not retain pointers past return.
|
||||
|
||||
* **Synchronization**: The call must appear **logically synchronous** to the
|
||||
caller. If you enqueue asynchronous work (e.g., CUDA streams), you must
|
||||
ensure completion before returning or clearly document and implement a
|
||||
synchronization contract used by all backends consistently.
|
||||
|
||||
* **Aliasing & overlaps**: If res, a, b, etc... alias or overlap in ways
|
||||
that violate your kernel’s requirements, you must either handle safely or reject
|
||||
with a defined error path (e.g., debug assert). Never trigger UB.
|
||||
|
||||
* **Numerical contract**: For modular/integer arithmetic, results must be
|
||||
bit-exact to the specification. For floating-point, any permitted tolerance
|
||||
must be documented and consistent with the crate’s guarantees.
|
||||
@@ -1,56 +0,0 @@
|
||||
use backend::ffi::reim::*;
|
||||
use std::ffi::c_void;
|
||||
use std::time::Instant;
|
||||
|
||||
fn main() {
|
||||
let log_bound: usize = 19;
|
||||
|
||||
let n: usize = 2048;
|
||||
let m: usize = n >> 1;
|
||||
|
||||
let mut a: Vec<i64> = vec![i64::default(); n];
|
||||
let mut b: Vec<i64> = vec![i64::default(); n];
|
||||
let mut c: Vec<i64> = vec![i64::default(); n];
|
||||
|
||||
a.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
b[1] = 1;
|
||||
|
||||
println!("{:?}", b);
|
||||
|
||||
unsafe {
|
||||
let reim_fft_precomp = new_reim_fft_precomp(m as u32, 2);
|
||||
let reim_ifft_precomp = new_reim_ifft_precomp(m as u32, 1);
|
||||
|
||||
let buf_a = reim_fft_precomp_get_buffer(reim_fft_precomp, 0);
|
||||
let buf_b = reim_fft_precomp_get_buffer(reim_fft_precomp, 1);
|
||||
let buf_c = reim_ifft_precomp_get_buffer(reim_ifft_precomp, 0);
|
||||
|
||||
let now = Instant::now();
|
||||
(0..1024).for_each(|_| {
|
||||
reim_from_znx64_simple(m as u32, log_bound as u32, buf_a as *mut c_void, a.as_ptr());
|
||||
reim_fft(reim_fft_precomp, buf_a);
|
||||
|
||||
reim_from_znx64_simple(m as u32, log_bound as u32, buf_b as *mut c_void, b.as_ptr());
|
||||
reim_fft(reim_fft_precomp, buf_b);
|
||||
|
||||
reim_fftvec_mul_simple(
|
||||
m as u32,
|
||||
buf_c as *mut c_void,
|
||||
buf_a as *mut c_void,
|
||||
buf_b as *mut c_void,
|
||||
);
|
||||
reim_ifft(reim_ifft_precomp, buf_c);
|
||||
|
||||
reim_to_znx64_simple(
|
||||
m as u32,
|
||||
m as f64,
|
||||
log_bound as u32,
|
||||
c.as_mut_ptr(),
|
||||
buf_c as *mut c_void,
|
||||
)
|
||||
});
|
||||
|
||||
println!("time: {}us", now.elapsed().as_micros());
|
||||
println!("{:?}", &c[..16]);
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,14 @@
|
||||
use backend::{
|
||||
AddNormal, Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc,
|
||||
ScalarZnxDftOps, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
|
||||
VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxInfos,
|
||||
hal::{
|
||||
api::{
|
||||
ModuleNew, ScalarZnxAlloc, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPrepare,
|
||||
VecZnxAddNormal, VecZnxAlloc, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize,
|
||||
VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxDecodeVeci64, VecZnxDftAlloc, VecZnxDftFromVecZnx,
|
||||
VecZnxDftToVecZnxBigTmpA, VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxNormalizeInplace, ZnxInfos,
|
||||
},
|
||||
layouts::{Module, ScalarZnx, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft},
|
||||
},
|
||||
implementation::cpu_spqlios::FFT64,
|
||||
};
|
||||
use itertools::izip;
|
||||
use sampling::source::Source;
|
||||
@@ -12,35 +19,35 @@ fn main() {
|
||||
let ct_size: usize = 3;
|
||||
let msg_size: usize = 2;
|
||||
let log_scale: usize = msg_size * basek - 5;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n as u64);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(module.vec_znx_big_normalize_tmp_bytes());
|
||||
let mut scratch: ScratchOwned<FFT64> = ScratchOwned::<FFT64>::alloc(module.vec_znx_big_normalize_tmp_bytes(n));
|
||||
|
||||
let seed: [u8; 32] = [0; 32];
|
||||
let mut source: Source = Source::new(seed);
|
||||
|
||||
// s <- Z_{-1, 0, 1}[X]/(X^{N}+1)
|
||||
let mut s: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
let mut s: ScalarZnx<Vec<u8>> = module.scalar_znx_alloc(1);
|
||||
s.fill_ternary_prob(0, 0.5, &mut source);
|
||||
|
||||
// Buffer to store s in the DFT domain
|
||||
let mut s_dft: ScalarZnxDft<Vec<u8>, FFT64> = module.new_scalar_znx_dft(s.cols());
|
||||
let mut s_dft: SvpPPol<Vec<u8>, FFT64> = module.svp_ppol_alloc(s.cols());
|
||||
|
||||
// s_dft <- DFT(s)
|
||||
module.svp_prepare(&mut s_dft, 0, &s, 0);
|
||||
|
||||
// Allocates a VecZnx with two columns: ct=(0, 0)
|
||||
let mut ct: VecZnx<Vec<u8>> = module.new_vec_znx(
|
||||
let mut ct: VecZnx<Vec<u8>> = module.vec_znx_alloc(
|
||||
2, // Number of columns
|
||||
ct_size, // Number of small poly per column
|
||||
);
|
||||
|
||||
// Fill the second column with random values: ct = (0, a)
|
||||
ct.fill_uniform(basek, 1, ct_size, &mut source);
|
||||
module.vec_znx_fill_uniform(basek, &mut ct, 1, ct_size * basek, &mut source);
|
||||
|
||||
let mut buf_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_size);
|
||||
let mut buf_dft: VecZnxDft<Vec<u8>, FFT64> = module.vec_znx_dft_alloc(1, ct_size);
|
||||
|
||||
module.vec_znx_dft(1, 0, &mut buf_dft, 0, &ct, 1);
|
||||
module.vec_znx_dft_from_vec_znx(1, 0, &mut buf_dft, 0, &ct, 1);
|
||||
|
||||
// Applies DFT(ct[1]) * DFT(s)
|
||||
module.svp_apply_inplace(
|
||||
@@ -53,18 +60,18 @@ fn main() {
|
||||
// Alias scratch space (VecZnxDft<B> is always at least as big as VecZnxBig<B>)
|
||||
|
||||
// BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized)
|
||||
let mut buf_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_size);
|
||||
module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
|
||||
let mut buf_big: VecZnxBig<Vec<u8>, FFT64> = module.vec_znx_big_alloc(1, ct_size);
|
||||
module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
|
||||
|
||||
// Creates a plaintext: VecZnx with 1 column
|
||||
let mut m = module.new_vec_znx(
|
||||
let mut m = module.vec_znx_alloc(
|
||||
1, // Number of columns
|
||||
msg_size, // Number of small polynomials
|
||||
);
|
||||
let mut want: Vec<i64> = vec![0; n];
|
||||
want.iter_mut()
|
||||
.for_each(|x| *x = source.next_u64n(16, 15) as i64);
|
||||
m.encode_vec_i64(0, basek, log_scale, &want, 4);
|
||||
module.encode_vec_i64(basek, &mut m, 0, log_scale, &want, 4);
|
||||
module.vec_znx_normalize_inplace(basek, &mut m, 0, scratch.borrow());
|
||||
|
||||
// m - BIG(ct[1] * s)
|
||||
@@ -88,13 +95,14 @@ fn main() {
|
||||
|
||||
// Add noise to ct[0]
|
||||
// ct[0] <- ct[0] + e
|
||||
ct.add_normal(
|
||||
module.vec_znx_add_normal(
|
||||
basek,
|
||||
&mut ct,
|
||||
0, // Selects the first column of ct (ct[0])
|
||||
basek * ct_size, // Scaling of the noise: 2^{-basek * limbs}
|
||||
&mut source,
|
||||
3.2, // Standard deviation
|
||||
19.0, // Truncatation bound
|
||||
3.2, // Standard deviation
|
||||
3.2 * 6.0, // Truncatation bound
|
||||
);
|
||||
|
||||
// Final ciphertext: ct = (-a * s + m + e, a)
|
||||
@@ -102,7 +110,7 @@ fn main() {
|
||||
// Decryption
|
||||
|
||||
// DFT(ct[1] * s)
|
||||
module.vec_znx_dft(1, 0, &mut buf_dft, 0, &ct, 1);
|
||||
module.vec_znx_dft_from_vec_znx(1, 0, &mut buf_dft, 0, &ct, 1);
|
||||
module.svp_apply_inplace(
|
||||
&mut buf_dft,
|
||||
0, // Selects the first column of res.
|
||||
@@ -111,18 +119,18 @@ fn main() {
|
||||
);
|
||||
|
||||
// BIG(c1 * s) = IDFT(DFT(c1 * s))
|
||||
module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
|
||||
module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
|
||||
|
||||
// BIG(c1 * s) + ct[0]
|
||||
module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0);
|
||||
|
||||
// m + e <- BIG(ct[1] * s + ct[0])
|
||||
let mut res = module.new_vec_znx(1, ct_size);
|
||||
let mut res = module.vec_znx_alloc(1, ct_size);
|
||||
module.vec_znx_big_normalize(basek, &mut res, 0, &buf_big, 0, scratch.borrow());
|
||||
|
||||
// have = m * 2^{log_scale} + e
|
||||
let mut have: Vec<i64> = vec![i64::default(); n];
|
||||
res.decode_vec_i64(0, basek, res.size() * basek, &mut have);
|
||||
module.decode_vec_i64(basek, &mut res, 0, ct_size * basek, &mut have);
|
||||
|
||||
let scale: f64 = (1 << (res.size() * basek - log_scale)) as f64;
|
||||
izip!(want.iter(), have.iter())
|
||||
|
||||
Submodule backend/spqlios-arithmetic deleted from 0ae9a7b5ad
@@ -1,344 +0,0 @@
|
||||
use crate::ffi::znx::znx_zero_i64_ref;
|
||||
use crate::znx_base::{ZnxView, ZnxViewMut};
|
||||
use crate::{VecZnx, znx_base::ZnxInfos};
|
||||
use itertools::izip;
|
||||
use rug::{Assign, Float};
|
||||
use std::cmp::min;
|
||||
|
||||
pub trait Encoding {
|
||||
/// encode a vector of i64 on the receiver.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `col_i`: the index of the poly where to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `k`: base two negative logarithm of the scaling of the data.
|
||||
/// * `data`: data to encode on the receiver.
|
||||
/// * `log_max`: base two logarithm of the infinity norm of the input data.
|
||||
fn encode_vec_i64(&mut self, col_i: usize, basek: usize, k: usize, data: &[i64], log_max: usize);
|
||||
|
||||
/// encodes a single i64 on the receiver at the given index.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `col_i`: the index of the poly where to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `k`: base two negative logarithm of the scaling of the data.
|
||||
/// * `i`: index of the coefficient on which to encode the data.
|
||||
/// * `data`: data to encode on the receiver.
|
||||
/// * `log_max`: base two logarithm of the infinity norm of the input data.
|
||||
fn encode_coeff_i64(&mut self, col_i: usize, basek: usize, k: usize, i: usize, data: i64, log_max: usize);
|
||||
}
|
||||
|
||||
pub trait Decoding {
|
||||
/// decode a vector of i64 from the receiver.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `col_i`: the index of the poly where to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `k`: base two logarithm of the scaling of the data.
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_vec_i64(&self, col_i: usize, basek: usize, k: usize, data: &mut [i64]);
|
||||
|
||||
/// decode a vector of Float from the receiver.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `col_i`: the index of the poly where to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_vec_float(&self, col_i: usize, basek: usize, data: &mut [Float]);
|
||||
|
||||
/// decode a single of i64 from the receiver at the given index.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `col_i`: the index of the poly where to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `k`: base two negative logarithm of the scaling of the data.
|
||||
/// * `i`: index of the coefficient to decode.
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_coeff_i64(&self, col_i: usize, basek: usize, k: usize, i: usize) -> i64;
|
||||
}
|
||||
|
||||
impl<D: AsMut<[u8]> + AsRef<[u8]>> Encoding for VecZnx<D> {
|
||||
fn encode_vec_i64(&mut self, col_i: usize, basek: usize, k: usize, data: &[i64], log_max: usize) {
|
||||
encode_vec_i64(self, col_i, basek, k, data, log_max)
|
||||
}
|
||||
|
||||
fn encode_coeff_i64(&mut self, col_i: usize, basek: usize, k: usize, i: usize, value: i64, log_max: usize) {
|
||||
encode_coeff_i64(self, col_i, basek, k, i, value, log_max)
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> Decoding for VecZnx<D> {
|
||||
fn decode_vec_i64(&self, col_i: usize, basek: usize, k: usize, data: &mut [i64]) {
|
||||
decode_vec_i64(self, col_i, basek, k, data)
|
||||
}
|
||||
|
||||
fn decode_vec_float(&self, col_i: usize, basek: usize, data: &mut [Float]) {
|
||||
decode_vec_float(self, col_i, basek, data)
|
||||
}
|
||||
|
||||
fn decode_coeff_i64(&self, col_i: usize, basek: usize, k: usize, i: usize) -> i64 {
|
||||
decode_coeff_i64(self, col_i, basek, k, i)
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_vec_i64<D: AsMut<[u8]> + AsRef<[u8]>>(
|
||||
a: &mut VecZnx<D>,
|
||||
col_i: usize,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
data: &[i64],
|
||||
log_max: usize,
|
||||
) {
|
||||
let size: usize = (k + basek - 1) / basek;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
size <= a.size(),
|
||||
"invalid argument k: (k + a.basek - 1)/a.basek={} > a.size()={}",
|
||||
size,
|
||||
a.size()
|
||||
);
|
||||
assert!(col_i < a.cols());
|
||||
assert!(data.len() <= a.n())
|
||||
}
|
||||
|
||||
let data_len: usize = data.len();
|
||||
let k_rem: usize = basek - (k % basek);
|
||||
|
||||
// Zeroes coefficients of the i-th column
|
||||
(0..a.size()).for_each(|i| unsafe {
|
||||
znx_zero_i64_ref(a.n() as u64, a.at_mut_ptr(col_i, i));
|
||||
});
|
||||
|
||||
// If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
|
||||
// values on the last limb.
|
||||
// Else we decompose values base2k.
|
||||
if log_max + k_rem < 63 || k_rem == basek {
|
||||
a.at_mut(col_i, size - 1)[..data_len].copy_from_slice(&data[..data_len]);
|
||||
} else {
|
||||
let mask: i64 = (1 << basek) - 1;
|
||||
let steps: usize = min(size, (log_max + basek - 1) / basek);
|
||||
(size - steps..size)
|
||||
.rev()
|
||||
.enumerate()
|
||||
.for_each(|(i, i_rev)| {
|
||||
let shift: usize = i * basek;
|
||||
izip!(a.at_mut(col_i, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask);
|
||||
})
|
||||
}
|
||||
|
||||
// Case where self.prec % self.k != 0.
|
||||
if k_rem != basek {
|
||||
let steps: usize = min(size, (log_max + basek - 1) / basek);
|
||||
(size - steps..size).rev().for_each(|i| {
|
||||
a.at_mut(col_i, i)[..data_len]
|
||||
.iter_mut()
|
||||
.for_each(|x| *x <<= k_rem);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_vec_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, k: usize, data: &mut [i64]) {
|
||||
let size: usize = (k + basek - 1) / basek;
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
data.len() >= a.n(),
|
||||
"invalid data: data.len()={} < a.n()={}",
|
||||
data.len(),
|
||||
a.n()
|
||||
);
|
||||
assert!(col_i < a.cols());
|
||||
}
|
||||
data.copy_from_slice(a.at(col_i, 0));
|
||||
let rem: usize = basek - (k % basek);
|
||||
if k < basek {
|
||||
data.iter_mut().for_each(|x| *x >>= rem);
|
||||
} else {
|
||||
(1..size).for_each(|i| {
|
||||
if i == size - 1 && rem != basek {
|
||||
let k_rem: usize = basek - rem;
|
||||
izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
*y = (*y << k_rem) + (x >> rem);
|
||||
});
|
||||
} else {
|
||||
izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
*y = (*y << basek) + x;
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_vec_float<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, data: &mut [Float]) {
|
||||
let size: usize = a.size();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
data.len() >= a.n(),
|
||||
"invalid data: data.len()={} < a.n()={}",
|
||||
data.len(),
|
||||
a.n()
|
||||
);
|
||||
assert!(col_i < a.cols());
|
||||
}
|
||||
|
||||
let prec: u32 = (basek * size) as u32;
|
||||
|
||||
// 2^{basek}
|
||||
let base = Float::with_val(prec, (1 << basek) as f64);
|
||||
|
||||
// y[i] = sum x[j][i] * 2^{-basek*j}
|
||||
(0..size).for_each(|i| {
|
||||
if i == 0 {
|
||||
izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
y.assign(*x);
|
||||
*y /= &base;
|
||||
});
|
||||
} else {
|
||||
izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
*y += Float::with_val(prec, *x);
|
||||
*y /= &base;
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn encode_coeff_i64<D: AsMut<[u8]> + AsRef<[u8]>>(
|
||||
a: &mut VecZnx<D>,
|
||||
col_i: usize,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
i: usize,
|
||||
value: i64,
|
||||
log_max: usize,
|
||||
) {
|
||||
let size: usize = (k + basek - 1) / basek;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(i < a.n());
|
||||
assert!(
|
||||
size <= a.size(),
|
||||
"invalid argument k: (k + a.basek - 1)/a.basek={} > a.size()={}",
|
||||
size,
|
||||
a.size()
|
||||
);
|
||||
assert!(col_i < a.cols());
|
||||
}
|
||||
|
||||
let k_rem: usize = basek - (k % basek);
|
||||
(0..a.size()).for_each(|j| a.at_mut(col_i, j)[i] = 0);
|
||||
|
||||
// If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
|
||||
// values on the last limb.
|
||||
// Else we decompose values base2k.
|
||||
if log_max + k_rem < 63 || k_rem == basek {
|
||||
a.at_mut(col_i, size - 1)[i] = value;
|
||||
} else {
|
||||
let mask: i64 = (1 << basek) - 1;
|
||||
let steps: usize = min(size, (log_max + basek - 1) / basek);
|
||||
(size - steps..size)
|
||||
.rev()
|
||||
.enumerate()
|
||||
.for_each(|(j, j_rev)| {
|
||||
a.at_mut(col_i, j_rev)[i] = (value >> (j * basek)) & mask;
|
||||
})
|
||||
}
|
||||
|
||||
// Case where prec % k != 0.
|
||||
if k_rem != basek {
|
||||
let steps: usize = min(size, (log_max + basek - 1) / basek);
|
||||
(size - steps..size).rev().for_each(|j| {
|
||||
a.at_mut(col_i, j)[i] <<= k_rem;
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_coeff_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, k: usize, i: usize) -> i64 {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(i < a.n());
|
||||
assert!(col_i < a.cols())
|
||||
}
|
||||
|
||||
let size: usize = (k + basek - 1) / basek;
|
||||
let data: &[i64] = a.raw();
|
||||
let mut res: i64 = 0;
|
||||
let rem: usize = basek - (k % basek);
|
||||
let slice_size: usize = a.n() * a.cols();
|
||||
(0..size).for_each(|j| {
|
||||
let x: i64 = data[j * slice_size + i];
|
||||
if j == size - 1 && rem != basek {
|
||||
let k_rem: usize = basek - rem;
|
||||
res = (res << k_rem) + (x >> rem);
|
||||
} else {
|
||||
res = (res << basek) + x;
|
||||
}
|
||||
});
|
||||
res
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::vec_znx_ops::*;
|
||||
use crate::znx_base::*;
|
||||
use crate::{Decoding, Encoding, FFT64, Module, VecZnx, znx_base::ZnxInfos};
|
||||
use itertools::izip;
|
||||
use sampling::source::Source;
|
||||
|
||||
#[test]
|
||||
fn test_set_get_i64_lo_norm() {
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 17;
|
||||
let size: usize = 5;
|
||||
let k: usize = size * basek - 5;
|
||||
let mut a: VecZnx<_> = module.new_vec_znx(2, size);
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let raw: &mut [i64] = a.raw_mut();
|
||||
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
(0..a.cols()).for_each(|col_i| {
|
||||
let mut have: Vec<i64> = vec![i64::default(); n];
|
||||
have.iter_mut()
|
||||
.for_each(|x| *x = (source.next_i64() << 56) >> 56);
|
||||
a.encode_vec_i64(col_i, basek, k, &have, 10);
|
||||
let mut want: Vec<i64> = vec![i64::default(); n];
|
||||
a.decode_vec_i64(col_i, basek, k, &mut want);
|
||||
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_set_get_i64_hi_norm() {
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 17;
|
||||
let size: usize = 5;
|
||||
for k in [1, basek / 2, size * basek - 5] {
|
||||
let mut a: VecZnx<_> = module.new_vec_znx(2, size);
|
||||
let mut source = Source::new([0u8; 32]);
|
||||
let raw: &mut [i64] = a.raw_mut();
|
||||
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
(0..a.cols()).for_each(|col_i| {
|
||||
let mut have: Vec<i64> = vec![i64::default(); n];
|
||||
have.iter_mut().for_each(|x| {
|
||||
if k < 64 {
|
||||
*x = source.next_u64n(1 << k, (1 << k) - 1) as i64;
|
||||
} else {
|
||||
*x = source.next_i64();
|
||||
}
|
||||
});
|
||||
a.encode_vec_i64(col_i, basek, k, &have, std::cmp::min(k, 64));
|
||||
let mut want = vec![i64::default(); n];
|
||||
a.decode_vec_i64(col_i, basek, k, &mut want);
|
||||
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
17
backend/src/hal/api/mat_znx.rs
Normal file
17
backend/src/hal/api/mat_znx.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
use crate::hal::layouts::MatZnxOwned;
|
||||
|
||||
/// Allocates as [crate::hal::layouts::MatZnx].
|
||||
pub trait MatZnxAlloc {
|
||||
fn mat_znx_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned;
|
||||
}
|
||||
|
||||
/// Returns the size in bytes to allocate a [crate::hal::layouts::MatZnx].
|
||||
pub trait MatZnxAllocBytes {
|
||||
fn mat_znx_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
/// Consume a vector of bytes into a [crate::hal::layouts::MatZnx].
|
||||
/// User must ensure that bytes is memory aligned and that it length is equal to [MatZnxAllocBytes].
|
||||
pub trait MatZnxFromBytes {
|
||||
fn mat_znx_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> MatZnxOwned;
|
||||
}
|
||||
21
backend/src/hal/api/mod.rs
Normal file
21
backend/src/hal/api/mod.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
mod mat_znx;
|
||||
mod module;
|
||||
mod scalar_znx;
|
||||
mod scratch;
|
||||
mod svp_ppol;
|
||||
mod vec_znx;
|
||||
mod vec_znx_big;
|
||||
mod vec_znx_dft;
|
||||
mod vmp_pmat;
|
||||
mod znx_base;
|
||||
|
||||
pub use mat_znx::*;
|
||||
pub use module::*;
|
||||
pub use scalar_znx::*;
|
||||
pub use scratch::*;
|
||||
pub use svp_ppol::*;
|
||||
pub use vec_znx::*;
|
||||
pub use vec_znx_big::*;
|
||||
pub use vec_znx_dft::*;
|
||||
pub use vmp_pmat::*;
|
||||
pub use znx_base::*;
|
||||
6
backend/src/hal/api/module.rs
Normal file
6
backend/src/hal/api/module.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
use crate::hal::layouts::{Backend, Module};
|
||||
|
||||
/// Instantiate a new [crate::hal::layouts::Module].
|
||||
pub trait ModuleNew<B: Backend> {
|
||||
fn new(n: u64) -> Module<B>;
|
||||
}
|
||||
47
backend/src/hal/api/scalar_znx.rs
Normal file
47
backend/src/hal/api/scalar_znx.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
use crate::hal::layouts::{ScalarZnxOwned, ScalarZnxToMut, ScalarZnxToRef};
|
||||
|
||||
/// Allocates as [crate::hal::layouts::ScalarZnx].
|
||||
pub trait ScalarZnxAlloc {
|
||||
fn scalar_znx_alloc(&self, cols: usize) -> ScalarZnxOwned;
|
||||
}
|
||||
|
||||
/// Returns the size in bytes to allocate a [crate::hal::layouts::ScalarZnx].
|
||||
pub trait ScalarZnxAllocBytes {
|
||||
fn scalar_znx_alloc_bytes(&self, cols: usize) -> usize;
|
||||
}
|
||||
|
||||
/// Consume a vector of bytes into a [crate::hal::layouts::ScalarZnx].
|
||||
/// User must ensure that bytes is memory aligned and that it length is equal to [ScalarZnxAllocBytes].
|
||||
pub trait ScalarZnxFromBytes {
|
||||
fn scalar_znx_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned;
|
||||
}
|
||||
|
||||
/// Applies the mapping X -> X^k to a\[a_col\] and write the result on res\[res_col\].
|
||||
pub trait ScalarZnxAutomorphism {
|
||||
fn scalar_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: ScalarZnxToMut,
|
||||
A: ScalarZnxToRef;
|
||||
}
|
||||
|
||||
/// Applies the mapping X -> X^k on res\[res_col\].
|
||||
pub trait ScalarZnxAutomorphismInplace {
|
||||
fn scalar_znx_automorphism_inplace<R>(&self, k: i64, res: &mut R, res_col: usize)
|
||||
where
|
||||
R: ScalarZnxToMut;
|
||||
}
|
||||
|
||||
/// Multiply a\[a_col\] with (X^p - 1) and write the result on res\[res_col\].
|
||||
pub trait ScalarZnxMulXpMinusOne {
|
||||
fn scalar_znx_mul_xp_minus_one<R, A>(&self, p: i64, r: &mut R, r_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: ScalarZnxToMut,
|
||||
A: ScalarZnxToRef;
|
||||
}
|
||||
|
||||
/// Multiply res\[res_col\] with (X^p - 1).
|
||||
pub trait ScalarZnxMulXpMinusOneInplace {
|
||||
fn scalar_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, res: &mut R, res_col: usize)
|
||||
where
|
||||
R: ScalarZnxToMut;
|
||||
}
|
||||
113
backend/src/hal/api/scratch.rs
Normal file
113
backend/src/hal/api/scratch.rs
Normal file
@@ -0,0 +1,113 @@
|
||||
use crate::hal::layouts::{Backend, MatZnx, Module, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat};
|
||||
|
||||
/// Allocates a new [crate::hal::layouts::ScratchOwned] of `size` aligned bytes.
|
||||
pub trait ScratchOwnedAlloc<B: Backend> {
|
||||
fn alloc(size: usize) -> Self;
|
||||
}
|
||||
|
||||
/// Borrows a slice of bytes into a [Scratch].
|
||||
pub trait ScratchOwnedBorrow<B: Backend> {
|
||||
fn borrow(&mut self) -> &mut Scratch<B>;
|
||||
}
|
||||
|
||||
/// Wrap an array of mutable borrowed bytes into a [Scratch].
|
||||
pub trait ScratchFromBytes<B: Backend> {
|
||||
fn from_bytes(data: &mut [u8]) -> &mut Scratch<B>;
|
||||
}
|
||||
|
||||
/// Returns how many bytes left can be taken from the scratch.
|
||||
pub trait ScratchAvailable {
|
||||
fn available(&self) -> usize;
|
||||
}
|
||||
|
||||
/// Takes a slice of bytes from a [Scratch] and return a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeSlice {
|
||||
fn take_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self);
|
||||
}
|
||||
|
||||
/// Take a slice of bytes from a [Scratch], wraps it into a [ScalarZnx] and returns it
|
||||
/// as well as a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeScalarZnx<B: Backend> {
|
||||
fn take_scalar_znx(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self);
|
||||
}
|
||||
|
||||
/// Take a slice of bytes from a [Scratch], wraps it into a [SvpPPol] and returns it
|
||||
/// as well as a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeSvpPPol<B: Backend> {
|
||||
fn take_svp_ppol(&mut self, module: &Module<B>, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self);
|
||||
}
|
||||
|
||||
/// Take a slice of bytes from a [Scratch], wraps it into a [VecZnx] and returns it
|
||||
/// as well as a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeVecZnx<B: Backend> {
|
||||
fn take_vec_znx(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self);
|
||||
}
|
||||
|
||||
/// Take a slice of bytes from a [Scratch], slices it into a vector of [VecZnx] aand returns it
|
||||
/// as well as a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeVecZnxSlice<B: Backend> {
|
||||
fn take_vec_znx_slice(
|
||||
&mut self,
|
||||
len: usize,
|
||||
module: &Module<B>,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnx<&mut [u8]>>, &mut Self);
|
||||
}
|
||||
|
||||
/// Take a slice of bytes from a [Scratch], wraps it into a [VecZnxBig] and returns it
|
||||
/// as well as a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeVecZnxBig<B: Backend> {
|
||||
fn take_vec_znx_big(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self);
|
||||
}
|
||||
|
||||
/// Take a slice of bytes from a [Scratch], wraps it into a [VecZnxDft] and returns it
|
||||
/// as well as a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeVecZnxDft<B: Backend> {
|
||||
fn take_vec_znx_dft(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self);
|
||||
}
|
||||
|
||||
/// Take a slice of bytes from a [Scratch], slices it into a vector of [VecZnxDft] and returns it
|
||||
/// as well as a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeVecZnxDftSlice<B: Backend> {
|
||||
fn take_vec_znx_dft_slice(
|
||||
&mut self,
|
||||
len: usize,
|
||||
module: &Module<B>,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Self);
|
||||
}
|
||||
|
||||
/// Take a slice of bytes from a [Scratch], wraps it into a [VmpPMat] and returns it
|
||||
/// as well as a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeVmpPMat<B: Backend> {
|
||||
fn take_vmp_pmat(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (VmpPMat<&mut [u8], B>, &mut Self);
|
||||
}
|
||||
|
||||
/// Take a slice of bytes from a [Scratch], wraps it into a [MatZnx] and returns it
|
||||
/// as well as a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeMatZnx<B: Backend> {
|
||||
fn take_mat_znx(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (MatZnx<&mut [u8]>, &mut Self);
|
||||
}
|
||||
|
||||
/// Take a slice of bytes from a [Scratch], wraps it into the template's type and returns it
|
||||
/// as well as a new [Scratch] minus the taken array of bytes.
|
||||
pub trait TakeLike<'a, B: Backend, T> {
|
||||
type Output;
|
||||
fn take_like(&'a mut self, template: &T) -> (Self::Output, &'a mut Self);
|
||||
}
|
||||
42
backend/src/hal/api/svp_ppol.rs
Normal file
42
backend/src/hal/api/svp_ppol.rs
Normal file
@@ -0,0 +1,42 @@
|
||||
use crate::hal::layouts::{Backend, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef};
|
||||
|
||||
/// Allocates as [crate::hal::layouts::SvpPPol].
|
||||
pub trait SvpPPolAlloc<B: Backend> {
|
||||
fn svp_ppol_alloc(&self, cols: usize) -> SvpPPolOwned<B>;
|
||||
}
|
||||
|
||||
/// Returns the size in bytes to allocate a [crate::hal::layouts::SvpPPol].
|
||||
pub trait SvpPPolAllocBytes {
|
||||
fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize;
|
||||
}
|
||||
|
||||
/// Consume a vector of bytes into a [crate::hal::layouts::MatZnx].
|
||||
/// User must ensure that bytes is memory aligned and that it length is equal to [SvpPPolAllocBytes].
|
||||
pub trait SvpPPolFromBytes<B: Backend> {
|
||||
fn svp_ppol_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B>;
|
||||
}
|
||||
|
||||
/// Prepare a [crate::hal::layouts::ScalarZnx] into an [crate::hal::layouts::SvpPPol].
|
||||
pub trait SvpPrepare<B: Backend> {
|
||||
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: SvpPPolToMut<B>,
|
||||
A: ScalarZnxToRef;
|
||||
}
|
||||
|
||||
/// Apply a scalar-vector product between `a[a_col]` and `b[b_col]` and stores the result on `res[res_col]`.
|
||||
pub trait SvpApply<B: Backend> {
|
||||
fn svp_apply<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: SvpPPolToRef<B>,
|
||||
C: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
/// Apply a scalar-vector product between `res[res_col]` and `a[a_col]` and stores the result on `res[res_col]`.
|
||||
pub trait SvpApplyInplace<B: Backend> {
|
||||
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: SvpPPolToRef<B>;
|
||||
}
|
||||
369
backend/src/hal/api/vec_znx.rs
Normal file
369
backend/src/hal/api/vec_znx.rs
Normal file
@@ -0,0 +1,369 @@
|
||||
use rand_distr::Distribution;
|
||||
use rug::Float;
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::hal::layouts::{Backend, ScalarZnxToRef, Scratch, VecZnxOwned, VecZnxToMut, VecZnxToRef};
|
||||
|
||||
pub trait VecZnxAlloc {
|
||||
/// Allocates a new [crate::hal::layouts::VecZnx].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of polynomials.
|
||||
/// * `size`: the number small polynomials per column.
|
||||
fn vec_znx_alloc(&self, cols: usize, size: usize) -> VecZnxOwned;
|
||||
}
|
||||
|
||||
pub trait VecZnxFromBytes {
|
||||
/// Instantiates a new [crate::hal::layouts::VecZnx] from a slice of bytes.
|
||||
/// The returned [crate::hal::layouts::VecZnx] takes ownership of the slice of bytes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of polynomials.
|
||||
/// * `size`: the number small polynomials per column.
|
||||
fn vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned;
|
||||
}
|
||||
|
||||
pub trait VecZnxAllocBytes {
|
||||
/// Returns the number of bytes necessary to allocate a new [crate::hal::layouts::VecZnx].
|
||||
fn vec_znx_alloc_bytes(&self, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxNormalizeTmpBytes {
|
||||
/// Returns the minimum number of bytes necessary for normalization.
|
||||
fn vec_znx_normalize_tmp_bytes(&self, n: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxNormalize<B: Backend> {
|
||||
/// Normalizes the selected column of `a` and stores the result into the selected column of `res`.
|
||||
fn vec_znx_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxNormalizeInplace<B: Backend> {
|
||||
/// Normalizes the selected column of `a`.
|
||||
fn vec_znx_normalize_inplace<A>(&self, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxAdd {
|
||||
/// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`.
|
||||
fn vec_znx_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxAddInplace {
|
||||
/// Adds the selected column of `a` to the selected column of `res` and writes the result on the selected column of `res`.
|
||||
fn vec_znx_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxAddScalarInplace {
|
||||
/// Adds the selected column of `a` on the selected column and limb of `res`.
|
||||
fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxSub {
|
||||
/// Subtracts the selected column of `b` from the selected column of `a` and writes the result on the selected column of `res`.
|
||||
fn vec_znx_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxSubABInplace {
|
||||
/// Subtracts the selected column of `a` from the selected column of `res` inplace.
|
||||
///
|
||||
/// res\[res_col\] -= a\[a_col\]
|
||||
fn vec_znx_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxSubBAInplace {
|
||||
/// Subtracts the selected column of `res` from the selected column of `a` and inplace mutates `res`
|
||||
///
|
||||
/// res\[res_col\] = a\[a_col\] - res\[res_col\]
|
||||
fn vec_znx_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxSubScalarInplace {
|
||||
/// Subtracts the selected column of `a` on the selected column and limb of `res`.
|
||||
fn vec_znx_sub_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxNegate {
|
||||
// Negates the selected column of `a` and stores the result in `res_col` of `res`.
|
||||
fn vec_znx_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxNegateInplace {
|
||||
/// Negates the selected column of `a`.
|
||||
fn vec_znx_negate_inplace<A>(&self, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxLshInplace {
|
||||
/// Left shift by k bits all columns of `a`.
|
||||
fn vec_znx_lsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxRshInplace {
|
||||
/// Right shift by k bits all columns of `a`.
|
||||
fn vec_znx_rsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxRotate {
|
||||
/// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`.
|
||||
fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxRotateInplace {
|
||||
/// Multiplies the selected column of `a` by X^k.
|
||||
fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxAutomorphism {
|
||||
/// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result in `res_col` column of `res`.
|
||||
fn vec_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxAutomorphismInplace {
|
||||
/// Applies the automorphism X^i -> X^ik on the selected column of `a`.
|
||||
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxMulXpMinusOne {
|
||||
fn vec_znx_mul_xp_minus_one<R, A>(&self, p: i64, r: &mut R, r_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxMulXpMinusOneInplace {
|
||||
fn vec_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, r: &mut R, r_col: usize)
|
||||
where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxSplit<B: Backend> {
|
||||
/// Splits the selected columns of `b` into subrings and copies them them into the selected column of `res`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This method requires that all [crate::hal::layouts::VecZnx] of b have the same ring degree
|
||||
/// and that b.n() * b.len() <= a.n()
|
||||
fn vec_znx_split<R, A>(&self, res: &mut Vec<R>, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxMerge {
|
||||
/// Merges the subrings of the selected column of `a` into the selected column of `res`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This method requires that all [crate::hal::layouts::VecZnx] of a have the same ring degree
|
||||
/// and that a.n() * a.len() <= b.n()
|
||||
fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: Vec<A>, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxSwithcDegree {
|
||||
fn vec_znx_switch_degree<R, A>(&self, res: &mut R, res_col: usize, a: &A, col_a: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxCopy {
|
||||
fn vec_znx_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxStd {
|
||||
/// Returns the standard devaition of the i-th polynomial.
|
||||
fn vec_znx_std<A>(&self, basek: usize, a: &A, a_col: usize) -> f64
|
||||
where
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxFillUniform {
|
||||
/// Fills the first `size` size with uniform values in \[-2^{basek-1}, 2^{basek-1}\]
|
||||
fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
|
||||
where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxFillDistF64 {
|
||||
fn vec_znx_fill_dist_f64<R, D: Distribution<f64>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxAddDistF64 {
|
||||
/// Adds vector sampled according to the provided distribution, scaled by 2^{-k} and bounded to \[-bound, bound\].
|
||||
fn vec_znx_add_dist_f64<R, D: Distribution<f64>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxFillNormal {
|
||||
fn vec_znx_fill_normal<R>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxAddNormal {
|
||||
/// Adds a discrete normal vector scaled by 2^{-k} with the provided standard deviation and bounded to \[-bound, bound\].
|
||||
fn vec_znx_add_normal<R>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxEncodeVeci64 {
|
||||
/// encode a vector of i64 on the receiver.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `col_i`: the index of the poly where to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `k`: base two negative logarithm of the scaling of the data.
|
||||
/// * `data`: data to encode on the receiver.
|
||||
/// * `log_max`: base two logarithm of the infinity norm of the input data.
|
||||
fn encode_vec_i64<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, data: &[i64], log_max: usize)
|
||||
where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxEncodeCoeffsi64 {
|
||||
/// encodes a single i64 on the receiver at the given index.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `res_col`: the index of the poly where to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `k`: base two negative logarithm of the scaling of the data.
|
||||
/// * `i`: index of the coefficient on which to encode the data.
|
||||
/// * `data`: data to encode on the receiver.
|
||||
/// * `log_max`: base two logarithm of the infinity norm of the input data.
|
||||
fn encode_coeff_i64<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, i: usize, data: i64, log_max: usize)
|
||||
where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
pub trait VecZnxDecodeVeci64 {
|
||||
/// decode a vector of i64 from the receiver.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `res_col`: the index of the poly where to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `k`: base two logarithm of the scaling of the data.
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_vec_i64<R>(&self, basek: usize, res: &R, res_col: usize, k: usize, data: &mut [i64])
|
||||
where
|
||||
R: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxDecodeCoeffsi64 {
|
||||
/// decode a single of i64 from the receiver at the given index.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `res_col`: the index of the poly where to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `k`: base two negative logarithm of the scaling of the data.
|
||||
/// * `i`: index of the coefficient to decode.
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_coeff_i64<R>(&self, basek: usize, res: &R, res_col: usize, k: usize, i: usize) -> i64
|
||||
where
|
||||
R: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxDecodeVecFloat {
|
||||
/// decode a vector of Float from the receiver.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `col_i`: the index of the poly where to encode the data.
|
||||
/// * `basek`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_vec_float<R>(&self, basek: usize, res: &R, col_i: usize, data: &mut [Float])
|
||||
where
|
||||
R: VecZnxToRef;
|
||||
}
|
||||
214
backend/src/hal/api/vec_znx_big.rs
Normal file
214
backend/src/hal/api/vec_znx_big.rs
Normal file
@@ -0,0 +1,214 @@
|
||||
use rand_distr::Distribution;
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::hal::layouts::{Backend, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef};
|
||||
|
||||
/// Allocates as [crate::hal::layouts::VecZnxBig].
|
||||
pub trait VecZnxBigAlloc<B: Backend> {
|
||||
fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned<B>;
|
||||
}
|
||||
|
||||
/// Returns the size in bytes to allocate a [crate::hal::layouts::VecZnxBig].
|
||||
pub trait VecZnxBigAllocBytes {
|
||||
fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
/// Consume a vector of bytes into a [crate::hal::layouts::VecZnxBig].
|
||||
/// User must ensure that bytes is memory aligned and that it length is equal to [VecZnxBigAllocBytes].
|
||||
pub trait VecZnxBigFromBytes<B: Backend> {
|
||||
fn vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B>;
|
||||
}
|
||||
|
||||
/// Add a discrete normal distribution on res.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `basek`: base two logarithm of the bivariate representation
|
||||
/// * `res`: receiver.
|
||||
/// * `res_col`: column of the receiver on which the operation is performed/stored.
|
||||
/// * `k`:
|
||||
/// * `source`: random coin source.
|
||||
/// * `sigma`: standard deviation of the discrete normal distribution.
|
||||
/// * `bound`: rejection sampling bound.
|
||||
pub trait VecZnxBigAddNormal<B: Backend> {
|
||||
fn vec_znx_big_add_normal<R: VecZnxBigToMut<B>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
pub trait VecZnxBigFillNormal<B: Backend> {
|
||||
fn vec_znx_big_fill_normal<R: VecZnxBigToMut<B>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
pub trait VecZnxBigFillDistF64<B: Backend> {
|
||||
fn vec_znx_big_fill_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
pub trait VecZnxBigAddDistF64<B: Backend> {
|
||||
fn vec_znx_big_add_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
pub trait VecZnxBigAdd<B: Backend> {
|
||||
/// Adds `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_add<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
C: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigAddInplace<B: Backend> {
|
||||
/// Adds `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigAddSmall<B: Backend> {
|
||||
/// Adds `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_add_small<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
C: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigAddSmallInplace<B: Backend> {
|
||||
/// Adds `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_add_small_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigSub<B: Backend> {
|
||||
/// Subtracts `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_sub<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
C: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigSubABInplace<B: Backend> {
|
||||
/// Subtracts `a` from `b` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigSubBAInplace<B: Backend> {
|
||||
/// Subtracts `b` from `a` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigSubSmallA<B: Backend> {
|
||||
/// Subtracts `b` from `a` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_small_a<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef,
|
||||
C: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigSubSmallAInplace<B: Backend> {
|
||||
/// Subtracts `a` from `res` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_a_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigSubSmallB<B: Backend> {
|
||||
/// Subtracts `b` from `a` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_small_b<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
C: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigSubSmallBInplace<B: Backend> {
|
||||
/// Subtracts `res` from `a` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_b_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigNegateInplace<B: Backend> {
|
||||
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigNormalizeTmpBytes {
|
||||
fn vec_znx_big_normalize_tmp_bytes(&self, n: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigNormalize<B: Backend> {
|
||||
fn vec_znx_big_normalize<R, A>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigAutomorphism<B: Backend> {
|
||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
|
||||
fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigAutomorphismInplace<B: Backend> {
|
||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
|
||||
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<B>;
|
||||
}
|
||||
96
backend/src/hal/api/vec_znx_dft.rs
Normal file
96
backend/src/hal/api/vec_znx_dft.rs
Normal file
@@ -0,0 +1,96 @@
|
||||
use crate::hal::layouts::{
|
||||
Backend, Data, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
|
||||
};
|
||||
|
||||
pub trait VecZnxDftAlloc<B: Backend> {
|
||||
fn vec_znx_dft_alloc(&self, cols: usize, size: usize) -> VecZnxDftOwned<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftFromBytes<B: Backend> {
|
||||
fn vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftAllocBytes {
|
||||
fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftToVecZnxBigTmpBytes {
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftToVecZnxBig<B: Backend> {
|
||||
fn vec_znx_dft_to_vec_znx_big<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftToVecZnxBigTmpA<B: Backend> {
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxDftToMut<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftToVecZnxBigConsume<B: Backend> {
|
||||
fn vec_znx_dft_to_vec_znx_big_consume<D: Data>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
|
||||
where
|
||||
VecZnxDft<D, B>: VecZnxDftToMut<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftAdd<B: Backend> {
|
||||
fn vec_znx_dft_add<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
D: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftAddInplace<B: Backend> {
|
||||
fn vec_znx_dft_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftSub<B: Backend> {
|
||||
fn vec_znx_dft_sub<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
D: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftSubABInplace<B: Backend> {
|
||||
fn vec_znx_dft_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftSubBAInplace<B: Backend> {
|
||||
fn vec_znx_dft_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftCopy<B: Backend> {
|
||||
fn vec_znx_dft_copy<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftFromVecZnx<B: Backend> {
|
||||
fn vec_znx_dft_from_vec_znx<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftZero<B: Backend> {
|
||||
fn vec_znx_dft_zero<R>(&self, res: &mut R)
|
||||
where
|
||||
R: VecZnxDftToMut<B>;
|
||||
}
|
||||
90
backend/src/hal/api/vmp_pmat.rs
Normal file
90
backend/src/hal/api/vmp_pmat.rs
Normal file
@@ -0,0 +1,90 @@
|
||||
use crate::hal::layouts::{
|
||||
Backend, MatZnxToRef, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef,
|
||||
};
|
||||
|
||||
pub trait VmpPMatAlloc<B: Backend> {
|
||||
fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B>;
|
||||
}
|
||||
|
||||
pub trait VmpPMatAllocBytes {
|
||||
fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VmpPMatFromBytes<B: Backend> {
|
||||
fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> VmpPMatOwned<B>;
|
||||
}
|
||||
|
||||
pub trait VmpPrepareTmpBytes {
|
||||
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VmpPMatPrepare<B: Backend> {
|
||||
fn vmp_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VmpPMatToMut<B>,
|
||||
A: MatZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VmpApplyTmpBytes {
|
||||
fn vmp_apply_tmp_bytes(
|
||||
&self,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize;
|
||||
}
|
||||
|
||||
pub trait VmpApply<B: Backend> {
|
||||
/// Applies the vector matrix product [crate::hal::layouts::VecZnxDft] x [crate::hal::layouts::VmpPMat].
|
||||
///
|
||||
/// A vector matrix product numerically equivalent to a sum of [crate::hal::api::SvpApply],
|
||||
/// where each [crate::hal::layouts::SvpPPol] is a limb of the input [crate::hal::layouts::VecZnx] in DFT,
|
||||
/// and each vector a [crate::hal::layouts::VecZnxDft] (row) of the [crate::hal::layouts::VmpPMat].
|
||||
///
|
||||
/// As such, given an input [crate::hal::layouts::VecZnx] of `i` size and a [crate::hal::layouts::VmpPMat] of `i` rows and
|
||||
/// `j` size, the output is a [crate::hal::layouts::VecZnx] of `j` size.
|
||||
///
|
||||
/// If there is a mismatch between the dimensions the largest valid ones are used.
|
||||
///
|
||||
/// ```text
|
||||
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
|
||||
/// |h i j|
|
||||
/// |k l m|
|
||||
/// ```
|
||||
/// where each element is a [crate::hal::layouts::VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `c`: the output of the vector matrix product, as a [crate::hal::layouts::VecZnxDft].
|
||||
/// * `a`: the left operand [crate::hal::layouts::VecZnxDft] of the vector matrix product.
|
||||
/// * `b`: the right operand [crate::hal::layouts::VmpPMat] of the vector matrix product.
|
||||
/// * `buf`: scratch space, the size can be obtained with [VmpApplyTmpBytes::vmp_apply_tmp_bytes].
|
||||
fn vmp_apply<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
C: VmpPMatToRef<B>;
|
||||
}
|
||||
|
||||
pub trait VmpApplyAddTmpBytes {
|
||||
fn vmp_apply_add_tmp_bytes(
|
||||
&self,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize;
|
||||
}
|
||||
|
||||
pub trait VmpApplyAdd<B: Backend> {
|
||||
fn vmp_apply_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
C: VmpPMatToRef<B>;
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
use itertools::izip;
|
||||
use crate::hal::layouts::{Data, DataMut, DataRef};
|
||||
use rand_distr::num_traits::Zero;
|
||||
|
||||
pub trait ZnxInfos {
|
||||
@@ -32,7 +32,7 @@ pub trait ZnxSliceSize {
|
||||
}
|
||||
|
||||
pub trait DataView {
|
||||
type D;
|
||||
type D: Data;
|
||||
fn data(&self) -> &Self::D;
|
||||
}
|
||||
|
||||
@@ -40,8 +40,8 @@ pub trait DataViewMut: DataView {
|
||||
fn data_mut(&mut self) -> &mut Self::D;
|
||||
}
|
||||
|
||||
pub trait ZnxView: ZnxInfos + DataView<D: AsRef<[u8]>> {
|
||||
type Scalar: Copy;
|
||||
pub trait ZnxView: ZnxInfos + DataView<D: DataRef> {
|
||||
type Scalar: Copy + Zero;
|
||||
|
||||
/// Returns a non-mutable pointer to the underlying coefficients array.
|
||||
fn as_ptr(&self) -> *const Self::Scalar {
|
||||
@@ -57,8 +57,8 @@ pub trait ZnxView: ZnxInfos + DataView<D: AsRef<[u8]>> {
|
||||
fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(i < self.cols(), "{} >= {}", i, self.cols());
|
||||
assert!(j < self.size(), "{} >= {}", j, self.size());
|
||||
assert!(i < self.cols(), "cols: {} >= {}", i, self.cols());
|
||||
assert!(j < self.size(), "size: {} >= {}", j, self.size());
|
||||
}
|
||||
let offset: usize = self.n() * (j * self.cols() + i);
|
||||
unsafe { self.as_ptr().add(offset) }
|
||||
@@ -70,7 +70,7 @@ pub trait ZnxView: ZnxInfos + DataView<D: AsRef<[u8]>> {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ZnxViewMut: ZnxView + DataViewMut<D: AsMut<[u8]>> {
|
||||
pub trait ZnxViewMut: ZnxView + DataViewMut<D: DataMut> {
|
||||
/// Returns a mutable pointer to the underlying coefficients array.
|
||||
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
|
||||
self.data_mut().as_mut().as_mut_ptr() as *mut Self::Scalar
|
||||
@@ -85,8 +85,8 @@ pub trait ZnxViewMut: ZnxView + DataViewMut<D: AsMut<[u8]>> {
|
||||
fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(i < self.cols(), "{} >= {}", i, self.cols());
|
||||
assert!(j < self.size(), "{} >= {}", j, self.size());
|
||||
assert!(i < self.cols(), "cols: {} >= {}", i, self.cols());
|
||||
assert!(j < self.size(), "size: {} >= {}", j, self.size());
|
||||
}
|
||||
let offset: usize = self.n() * (j * self.cols() + i);
|
||||
unsafe { self.as_mut_ptr().add(offset) }
|
||||
@@ -99,101 +99,12 @@ pub trait ZnxViewMut: ZnxView + DataViewMut<D: AsMut<[u8]>> {
|
||||
}
|
||||
|
||||
//(Jay)Note: Can't provide blanket impl. of ZnxView because Scalar is not known
|
||||
impl<T> ZnxViewMut for T where T: ZnxView + DataViewMut<D: AsMut<[u8]>> {}
|
||||
impl<T> ZnxViewMut for T where T: ZnxView + DataViewMut<D: DataMut> {}
|
||||
|
||||
pub trait ZnxZero: ZnxViewMut + ZnxSliceSize
|
||||
pub trait ZnxZero
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
fn zero(&mut self) {
|
||||
unsafe {
|
||||
std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * self.poly_count());
|
||||
}
|
||||
}
|
||||
|
||||
fn zero_at(&mut self, i: usize, j: usize) {
|
||||
unsafe {
|
||||
std::ptr::write_bytes(self.at_mut_ptr(i, j), 0, self.n());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Blanket implementations
|
||||
impl<T> ZnxZero for T where T: ZnxViewMut + ZnxSliceSize {} // WARNING should not work for mat_znx_dft but it does
|
||||
|
||||
use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub};
|
||||
|
||||
use crate::Scratch;
|
||||
pub trait Integer:
|
||||
Copy
|
||||
+ Default
|
||||
+ PartialEq
|
||||
+ PartialOrd
|
||||
+ Add<Output = Self>
|
||||
+ Sub<Output = Self>
|
||||
+ Mul<Output = Self>
|
||||
+ Div<Output = Self>
|
||||
+ Neg<Output = Self>
|
||||
+ Shl<Output = Self>
|
||||
+ Shr<Output = Self>
|
||||
+ AddAssign
|
||||
{
|
||||
const BITS: u32;
|
||||
}
|
||||
|
||||
impl Integer for i64 {
|
||||
const BITS: u32 = 64;
|
||||
}
|
||||
|
||||
impl Integer for i128 {
|
||||
const BITS: u32 = 128;
|
||||
}
|
||||
|
||||
//(Jay)Note: `rsh` impl. ignores the column
|
||||
pub fn rsh<V: ZnxZero>(k: usize, basek: usize, a: &mut V, _a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
V::Scalar: From<usize> + Integer + Zero,
|
||||
{
|
||||
let n: usize = a.n();
|
||||
let _size: usize = a.size();
|
||||
let cols: usize = a.cols();
|
||||
|
||||
let size: usize = a.size();
|
||||
let steps: usize = k / basek;
|
||||
|
||||
a.raw_mut().rotate_right(n * steps * cols);
|
||||
(0..cols).for_each(|i| {
|
||||
(0..steps).for_each(|j| {
|
||||
a.zero_at(i, j);
|
||||
})
|
||||
});
|
||||
|
||||
let k_rem: usize = k % basek;
|
||||
|
||||
if k_rem != 0 {
|
||||
let (carry, _) = scratch.tmp_slice::<V::Scalar>(rsh_tmp_bytes::<V::Scalar>(n));
|
||||
|
||||
unsafe {
|
||||
std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::<V::Scalar>());
|
||||
}
|
||||
|
||||
let basek_t = V::Scalar::from(basek);
|
||||
let shift = V::Scalar::from(V::Scalar::BITS as usize - k_rem);
|
||||
let k_rem_t = V::Scalar::from(k_rem);
|
||||
|
||||
(0..cols).for_each(|i| {
|
||||
(steps..size).for_each(|j| {
|
||||
izip!(carry.iter_mut(), a.at_mut(i, j).iter_mut()).for_each(|(ci, xi)| {
|
||||
*xi += *ci << basek_t;
|
||||
*ci = (*xi << shift) >> shift;
|
||||
*xi = (*xi - *ci) >> k_rem_t;
|
||||
});
|
||||
});
|
||||
carry.iter_mut().for_each(|r| *r = V::Scalar::zero());
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rsh_tmp_bytes<T>(n: usize) -> usize {
|
||||
n * std::mem::size_of::<T>()
|
||||
fn zero(&mut self);
|
||||
fn zero_at(&mut self, i: usize, j: usize);
|
||||
}
|
||||
32
backend/src/hal/delegates/mat_znx.rs
Normal file
32
backend/src/hal/delegates/mat_znx.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use crate::hal::{
|
||||
api::{MatZnxAlloc, MatZnxAllocBytes, MatZnxFromBytes},
|
||||
layouts::{Backend, MatZnxOwned, Module},
|
||||
oep::{MatZnxAllocBytesImpl, MatZnxAllocImpl, MatZnxFromBytesImpl},
|
||||
};
|
||||
|
||||
impl<B> MatZnxAlloc for Module<B>
|
||||
where
|
||||
B: Backend + MatZnxAllocImpl<B>,
|
||||
{
|
||||
fn mat_znx_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned {
|
||||
B::mat_znx_alloc_impl(self, rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> MatZnxAllocBytes for Module<B>
|
||||
where
|
||||
B: Backend + MatZnxAllocBytesImpl<B>,
|
||||
{
|
||||
fn mat_znx_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
B::mat_znx_alloc_bytes_impl(self, rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> MatZnxFromBytes for Module<B>
|
||||
where
|
||||
B: Backend + MatZnxFromBytesImpl<B>,
|
||||
{
|
||||
fn mat_znx_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> MatZnxOwned {
|
||||
B::mat_znx_from_bytes_impl(self, rows, cols_in, cols_out, size, bytes)
|
||||
}
|
||||
}
|
||||
9
backend/src/hal/delegates/mod.rs
Normal file
9
backend/src/hal/delegates/mod.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
mod mat_znx;
|
||||
mod module;
|
||||
mod scalar_znx;
|
||||
mod scratch;
|
||||
mod svp_ppol;
|
||||
mod vec_znx;
|
||||
mod vec_znx_big;
|
||||
mod vec_znx_dft;
|
||||
mod vmp_pmat;
|
||||
14
backend/src/hal/delegates/module.rs
Normal file
14
backend/src/hal/delegates/module.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
use crate::hal::{
|
||||
api::ModuleNew,
|
||||
layouts::{Backend, Module},
|
||||
oep::ModuleNewImpl,
|
||||
};
|
||||
|
||||
impl<B> ModuleNew<B> for Module<B>
|
||||
where
|
||||
B: Backend + ModuleNewImpl<B>,
|
||||
{
|
||||
fn new(n: u64) -> Self {
|
||||
B::new_impl(n)
|
||||
}
|
||||
}
|
||||
88
backend/src/hal/delegates/scalar_znx.rs
Normal file
88
backend/src/hal/delegates/scalar_znx.rs
Normal file
@@ -0,0 +1,88 @@
|
||||
use crate::hal::{
|
||||
api::{
|
||||
ScalarZnxAlloc, ScalarZnxAllocBytes, ScalarZnxAutomorphism, ScalarZnxAutomorphismInplace, ScalarZnxFromBytes,
|
||||
ScalarZnxMulXpMinusOne, ScalarZnxMulXpMinusOneInplace,
|
||||
},
|
||||
layouts::{Backend, Module, ScalarZnxOwned, ScalarZnxToMut, ScalarZnxToRef},
|
||||
oep::{
|
||||
ScalarZnxAllocBytesImpl, ScalarZnxAllocImpl, ScalarZnxAutomorphismImpl, ScalarZnxAutomorphismInplaceIml,
|
||||
ScalarZnxFromBytesImpl, ScalarZnxMulXpMinusOneImpl, ScalarZnxMulXpMinusOneInplaceImpl,
|
||||
},
|
||||
};
|
||||
|
||||
impl<B> ScalarZnxAllocBytes for Module<B>
|
||||
where
|
||||
B: Backend + ScalarZnxAllocBytesImpl<B>,
|
||||
{
|
||||
fn scalar_znx_alloc_bytes(&self, cols: usize) -> usize {
|
||||
B::scalar_znx_alloc_bytes_impl(self.n(), cols)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ScalarZnxAlloc for Module<B>
|
||||
where
|
||||
B: Backend + ScalarZnxAllocImpl<B>,
|
||||
{
|
||||
fn scalar_znx_alloc(&self, cols: usize) -> ScalarZnxOwned {
|
||||
B::scalar_znx_alloc_impl(self.n(), cols)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ScalarZnxFromBytes for Module<B>
|
||||
where
|
||||
B: Backend + ScalarZnxFromBytesImpl<B>,
|
||||
{
|
||||
fn scalar_znx_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned {
|
||||
B::scalar_znx_from_bytes_impl(self.n(), cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ScalarZnxAutomorphism for Module<B>
|
||||
where
|
||||
B: Backend + ScalarZnxAutomorphismImpl<B>,
|
||||
{
|
||||
fn scalar_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: ScalarZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
B::scalar_znx_automorphism_impl(self, k, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ScalarZnxAutomorphismInplace for Module<B>
|
||||
where
|
||||
B: Backend + ScalarZnxAutomorphismInplaceIml<B>,
|
||||
{
|
||||
fn scalar_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: ScalarZnxToMut,
|
||||
{
|
||||
B::scalar_znx_automorphism_inplace_impl(self, k, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ScalarZnxMulXpMinusOne for Module<B>
|
||||
where
|
||||
B: Backend + ScalarZnxMulXpMinusOneImpl<B>,
|
||||
{
|
||||
fn scalar_znx_mul_xp_minus_one<R, A>(&self, p: i64, r: &mut R, r_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: ScalarZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
B::scalar_znx_mul_xp_minus_one_impl(self, p, r, r_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ScalarZnxMulXpMinusOneInplace for Module<B>
|
||||
where
|
||||
B: Backend + ScalarZnxMulXpMinusOneInplaceImpl<B>,
|
||||
{
|
||||
fn scalar_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, r: &mut R, r_col: usize)
|
||||
where
|
||||
R: ScalarZnxToMut,
|
||||
{
|
||||
B::scalar_znx_mul_xp_minus_one_inplace_impl(self, p, r, r_col);
|
||||
}
|
||||
}
|
||||
243
backend/src/hal/delegates/scratch.rs
Normal file
243
backend/src/hal/delegates/scratch.rs
Normal file
@@ -0,0 +1,243 @@
|
||||
use crate::hal::{
|
||||
api::{
|
||||
ScratchAvailable, ScratchFromBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, TakeLike, TakeMatZnx, TakeScalarZnx,
|
||||
TakeSlice, TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice, TakeVmpPMat,
|
||||
},
|
||||
layouts::{
|
||||
Backend, DataRef, MatZnx, Module, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat,
|
||||
},
|
||||
oep::{
|
||||
ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeLikeImpl, TakeMatZnxImpl,
|
||||
TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl,
|
||||
TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl,
|
||||
},
|
||||
};
|
||||
|
||||
impl<B> ScratchOwnedAlloc<B> for ScratchOwned<B>
|
||||
where
|
||||
B: Backend + ScratchOwnedAllocImpl<B>,
|
||||
{
|
||||
fn alloc(size: usize) -> Self {
|
||||
B::scratch_owned_alloc_impl(size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ScratchOwnedBorrow<B> for ScratchOwned<B>
|
||||
where
|
||||
B: Backend + ScratchOwnedBorrowImpl<B>,
|
||||
{
|
||||
fn borrow(&mut self) -> &mut Scratch<B> {
|
||||
B::scratch_owned_borrow_impl(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ScratchFromBytes<B> for Scratch<B>
|
||||
where
|
||||
B: Backend + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn from_bytes(data: &mut [u8]) -> &mut Scratch<B> {
|
||||
B::scratch_from_bytes_impl(data)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ScratchAvailable for Scratch<B>
|
||||
where
|
||||
B: Backend + ScratchAvailableImpl<B>,
|
||||
{
|
||||
fn available(&self) -> usize {
|
||||
B::scratch_available_impl(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> TakeSlice for Scratch<B>
|
||||
where
|
||||
B: Backend + TakeSliceImpl<B>,
|
||||
{
|
||||
fn take_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self) {
|
||||
B::take_slice_impl(self, len)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> TakeScalarZnx<B> for Scratch<B>
|
||||
where
|
||||
B: Backend + TakeScalarZnxImpl<B>,
|
||||
{
|
||||
fn take_scalar_znx(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
|
||||
B::take_scalar_znx_impl(self, module.n(), cols)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> TakeSvpPPol<B> for Scratch<B>
|
||||
where
|
||||
B: Backend + TakeSvpPPolImpl<B>,
|
||||
{
|
||||
fn take_svp_ppol(&mut self, module: &Module<B>, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self) {
|
||||
B::take_svp_ppol_impl(self, module.n(), cols)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> TakeVecZnx<B> for Scratch<B>
|
||||
where
|
||||
B: Backend + TakeVecZnxImpl<B>,
|
||||
{
|
||||
fn take_vec_znx(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
|
||||
B::take_vec_znx_impl(self, module.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> TakeVecZnxSlice<B> for Scratch<B>
|
||||
where
|
||||
B: Backend + TakeVecZnxSliceImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_slice(
|
||||
&mut self,
|
||||
len: usize,
|
||||
module: &Module<B>,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnx<&mut [u8]>>, &mut Self) {
|
||||
B::take_vec_znx_slice_impl(self, len, module.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> TakeVecZnxBig<B> for Scratch<B>
|
||||
where
|
||||
B: Backend + TakeVecZnxBigImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_big(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self) {
|
||||
B::take_vec_znx_big_impl(self, module.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> TakeVecZnxDft<B> for Scratch<B>
|
||||
where
|
||||
B: Backend + TakeVecZnxDftImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_dft(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self) {
|
||||
B::take_vec_znx_dft_impl(self, module.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> TakeVecZnxDftSlice<B> for Scratch<B>
|
||||
where
|
||||
B: Backend + TakeVecZnxDftSliceImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_dft_slice(
|
||||
&mut self,
|
||||
len: usize,
|
||||
module: &Module<B>,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Self) {
|
||||
B::take_vec_znx_dft_slice_impl(self, len, module.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> TakeVmpPMat<B> for Scratch<B>
|
||||
where
|
||||
B: Backend + TakeVmpPMatImpl<B>,
|
||||
{
|
||||
fn take_vmp_pmat(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (VmpPMat<&mut [u8], B>, &mut Self) {
|
||||
B::take_vmp_pmat_impl(self, module.n(), rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> TakeMatZnx<B> for Scratch<B>
|
||||
where
|
||||
B: Backend + TakeMatZnxImpl<B>,
|
||||
{
|
||||
fn take_mat_znx(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (MatZnx<&mut [u8]>, &mut Self) {
|
||||
B::take_mat_znx_impl(self, module.n(), rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLike<'a, B, ScalarZnx<D>> for Scratch<B>
|
||||
where
|
||||
B: TakeLikeImpl<'a, B, ScalarZnx<D>, Output = ScalarZnx<&'a mut [u8]>>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = ScalarZnx<&'a mut [u8]>;
|
||||
fn take_like(&'a mut self, template: &ScalarZnx<D>) -> (Self::Output, &'a mut Self) {
|
||||
B::take_like_impl(self, template)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLike<'a, B, SvpPPol<D, B>> for Scratch<B>
|
||||
where
|
||||
B: TakeLikeImpl<'a, B, SvpPPol<D, B>, Output = SvpPPol<&'a mut [u8], B>>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = SvpPPol<&'a mut [u8], B>;
|
||||
fn take_like(&'a mut self, template: &SvpPPol<D, B>) -> (Self::Output, &'a mut Self) {
|
||||
B::take_like_impl(self, template)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLike<'a, B, VecZnx<D>> for Scratch<B>
|
||||
where
|
||||
B: TakeLikeImpl<'a, B, VecZnx<D>, Output = VecZnx<&'a mut [u8]>>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = VecZnx<&'a mut [u8]>;
|
||||
fn take_like(&'a mut self, template: &VecZnx<D>) -> (Self::Output, &'a mut Self) {
|
||||
B::take_like_impl(self, template)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLike<'a, B, VecZnxBig<D, B>> for Scratch<B>
|
||||
where
|
||||
B: TakeLikeImpl<'a, B, VecZnxBig<D, B>, Output = VecZnxBig<&'a mut [u8], B>>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = VecZnxBig<&'a mut [u8], B>;
|
||||
fn take_like(&'a mut self, template: &VecZnxBig<D, B>) -> (Self::Output, &'a mut Self) {
|
||||
B::take_like_impl(self, template)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLike<'a, B, VecZnxDft<D, B>> for Scratch<B>
|
||||
where
|
||||
B: TakeLikeImpl<'a, B, VecZnxDft<D, B>, Output = VecZnxDft<&'a mut [u8], B>>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = VecZnxDft<&'a mut [u8], B>;
|
||||
fn take_like(&'a mut self, template: &VecZnxDft<D, B>) -> (Self::Output, &'a mut Self) {
|
||||
B::take_like_impl(self, template)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLike<'a, B, MatZnx<D>> for Scratch<B>
|
||||
where
|
||||
B: TakeLikeImpl<'a, B, MatZnx<D>, Output = MatZnx<&'a mut [u8]>>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = MatZnx<&'a mut [u8]>;
|
||||
fn take_like(&'a mut self, template: &MatZnx<D>) -> (Self::Output, &'a mut Self) {
|
||||
B::take_like_impl(self, template)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLike<'a, B, VmpPMat<D, B>> for Scratch<B>
|
||||
where
|
||||
B: TakeLikeImpl<'a, B, VmpPMat<D, B>, Output = VmpPMat<&'a mut [u8], B>>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = VmpPMat<&'a mut [u8], B>;
|
||||
fn take_like(&'a mut self, template: &VmpPMat<D, B>) -> (Self::Output, &'a mut Self) {
|
||||
B::take_like_impl(self, template)
|
||||
}
|
||||
}
|
||||
72
backend/src/hal/delegates/svp_ppol.rs
Normal file
72
backend/src/hal/delegates/svp_ppol.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
use crate::hal::{
|
||||
api::{SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPPolFromBytes, SvpPrepare},
|
||||
layouts::{Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef},
|
||||
oep::{SvpApplyImpl, SvpApplyInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl},
|
||||
};
|
||||
|
||||
impl<B> SvpPPolFromBytes<B> for Module<B>
|
||||
where
|
||||
B: Backend + SvpPPolFromBytesImpl<B>,
|
||||
{
|
||||
fn svp_ppol_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B> {
|
||||
B::svp_ppol_from_bytes_impl(self.n(), cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> SvpPPolAlloc<B> for Module<B>
|
||||
where
|
||||
B: Backend + SvpPPolAllocImpl<B>,
|
||||
{
|
||||
fn svp_ppol_alloc(&self, cols: usize) -> SvpPPolOwned<B> {
|
||||
B::svp_ppol_alloc_impl(self.n(), cols)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> SvpPPolAllocBytes for Module<B>
|
||||
where
|
||||
B: Backend + SvpPPolAllocBytesImpl<B>,
|
||||
{
|
||||
fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize {
|
||||
B::svp_ppol_alloc_bytes_impl(self.n(), cols)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> SvpPrepare<B> for Module<B>
|
||||
where
|
||||
B: Backend + SvpPrepareImpl<B>,
|
||||
{
|
||||
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: SvpPPolToMut<B>,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
B::svp_prepare_impl(self, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> SvpApply<B> for Module<B>
|
||||
where
|
||||
B: Backend + SvpApplyImpl<B>,
|
||||
{
|
||||
fn svp_apply<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: SvpPPolToRef<B>,
|
||||
C: VecZnxDftToRef<B>,
|
||||
{
|
||||
B::svp_apply_impl(self, res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> SvpApplyInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + SvpApplyInplaceImpl,
|
||||
{
|
||||
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: SvpPPolToRef<B>,
|
||||
{
|
||||
B::svp_apply_inplace_impl(self, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
518
backend/src/hal/delegates/vec_znx.rs
Normal file
518
backend/src/hal/delegates/vec_znx.rs
Normal file
@@ -0,0 +1,518 @@
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::hal::{
|
||||
api::{
|
||||
VecZnxAdd, VecZnxAddDistF64, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAlloc, VecZnxAllocBytes,
|
||||
VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxCopy, VecZnxDecodeCoeffsi64, VecZnxDecodeVecFloat,
|
||||
VecZnxDecodeVeci64, VecZnxEncodeCoeffsi64, VecZnxEncodeVeci64, VecZnxFillDistF64, VecZnxFillNormal, VecZnxFillUniform,
|
||||
VecZnxFromBytes, VecZnxLshInplace, VecZnxMerge, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate,
|
||||
VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace,
|
||||
VecZnxRshInplace, VecZnxSplit, VecZnxStd, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalarInplace,
|
||||
VecZnxSwithcDegree,
|
||||
},
|
||||
layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxOwned, VecZnxToMut, VecZnxToRef},
|
||||
oep::{
|
||||
VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
|
||||
VecZnxAllocBytesImpl, VecZnxAllocImpl, VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl,
|
||||
VecZnxDecodeCoeffsi64Impl, VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl,
|
||||
VecZnxEncodeVeci64Impl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxFromBytesImpl,
|
||||
VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl,
|
||||
VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl,
|
||||
VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxStdImpl, VecZnxSubABInplaceImpl,
|
||||
VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
|
||||
},
|
||||
};
|
||||
|
||||
impl<B> VecZnxAlloc for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxAllocImpl<B>,
|
||||
{
|
||||
fn vec_znx_alloc(&self, cols: usize, size: usize) -> VecZnxOwned {
|
||||
B::vec_znx_alloc_impl(self.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxFromBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxFromBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned {
|
||||
B::vec_znx_from_bytes_impl(self.n(), cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxAllocBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxAllocBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_alloc_bytes(&self, cols: usize, size: usize) -> usize {
|
||||
B::vec_znx_alloc_bytes_impl(self.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxNormalizeTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxNormalizeTmpBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_normalize_tmp_bytes(&self, n: usize) -> usize {
|
||||
B::vec_znx_normalize_tmp_bytes_impl(self, n)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxNormalize<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxNormalizeImpl<B>,
|
||||
{
|
||||
fn vec_znx_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_normalize_impl(self, basek, res, res_col, a, a_col, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxNormalizeInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxNormalizeInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_normalize_inplace<A>(&self, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_normalize_inplace_impl(self, basek, a, a_col, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxAdd for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxAddImpl<B>,
|
||||
{
|
||||
fn vec_znx_add<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
C: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_add_impl(self, res, res_col, a, a_col, b, b_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxAddInplace for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxAddInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_add_inplace_impl(self, res, res_col, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxAddScalarInplace for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxAddScalarInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
B::vec_znx_add_scalar_inplace_impl(self, res, res_col, res_limb, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxSub for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxSubImpl<B>,
|
||||
{
|
||||
fn vec_znx_sub<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
C: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_sub_impl(self, res, res_col, a, a_col, b, b_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxSubABInplace for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxSubABInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_sub_ab_inplace_impl(self, res, res_col, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxSubBAInplace for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxSubBAInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_sub_ba_inplace_impl(self, res, res_col, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxSubScalarInplace for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxSubScalarInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_sub_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
B::vec_znx_sub_scalar_inplace_impl(self, res, res_col, res_limb, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxNegate for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxNegateImpl<B>,
|
||||
{
|
||||
fn vec_znx_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_negate_impl(self, res, res_col, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxNegateInplace for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxNegateInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_negate_inplace<A>(&self, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_negate_inplace_impl(self, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxLshInplace for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxLshInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_lsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_lsh_inplace_impl(self, basek, k, a)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxRshInplace for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxRshInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_rsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_rsh_inplace_impl(self, basek, k, a)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxRotate for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxRotateImpl<B>,
|
||||
{
|
||||
fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_rotate_impl(self, k, res, res_col, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxRotateInplace for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxRotateInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_rotate_inplace_impl(self, k, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxAutomorphism for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxAutomorphismImpl<B>,
|
||||
{
|
||||
fn vec_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_automorphism_impl(self, k, res, res_col, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxAutomorphismInplace for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxAutomorphismInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_automorphism_inplace_impl(self, k, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxMulXpMinusOne for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxMulXpMinusOneImpl<B>,
|
||||
{
|
||||
fn vec_znx_mul_xp_minus_one<R, A>(&self, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_mul_xp_minus_one_impl(self, p, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxMulXpMinusOneInplace for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxMulXpMinusOneInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_mul_xp_minus_one_inplace<R>(&self, p: i64, res: &mut R, res_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_mul_xp_minus_one_inplace_impl(self, p, res, res_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxSplit<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxSplitImpl<B>,
|
||||
{
|
||||
fn vec_znx_split<R, A>(&self, res: &mut Vec<R>, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_split_impl(self, res, res_col, a, a_col, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxMerge for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxMergeImpl<B>,
|
||||
{
|
||||
fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: Vec<A>, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_merge_impl(self, res, res_col, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxSwithcDegree for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxSwithcDegreeImpl<B>,
|
||||
{
|
||||
fn vec_znx_switch_degree<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_switch_degree_impl(self, res, res_col, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxCopy for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxCopyImpl<B>,
|
||||
{
|
||||
fn vec_znx_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_copy_impl(self, res, res_col, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxStd for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxStdImpl<B>,
|
||||
{
|
||||
fn vec_znx_std<A>(&self, basek: usize, a: &A, a_col: usize) -> f64
|
||||
where
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_std_impl(self, basek, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxFillUniform for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxFillUniformImpl<B>,
|
||||
{
|
||||
fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_fill_uniform_impl(self, basek, res, res_col, k, source);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxFillDistF64 for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxFillDistF64Impl<B>,
|
||||
{
|
||||
fn vec_znx_fill_dist_f64<R, D: rand::prelude::Distribution<f64>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_fill_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxAddDistF64 for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxAddDistF64Impl<B>,
|
||||
{
|
||||
fn vec_znx_add_dist_f64<R, D: rand::prelude::Distribution<f64>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_add_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxFillNormal for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxFillNormalImpl<B>,
|
||||
{
|
||||
fn vec_znx_fill_normal<R>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_fill_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxAddNormal for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxAddNormalImpl<B>,
|
||||
{
|
||||
fn vec_znx_add_normal<R>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
B::vec_znx_add_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxEncodeVeci64 for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxEncodeVeci64Impl<B>,
|
||||
{
|
||||
fn encode_vec_i64<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, data: &[i64], log_max: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
B::encode_vec_i64_impl(self, basek, res, res_col, k, data, log_max);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxEncodeCoeffsi64 for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxEncodeCoeffsi64Impl<B>,
|
||||
{
|
||||
fn encode_coeff_i64<R>(&self, basek: usize, res: &mut R, res_col: usize, k: usize, i: usize, data: i64, log_max: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
B::encode_coeff_i64_impl(self, basek, res, res_col, k, i, data, log_max);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxDecodeVeci64 for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDecodeVeci64Impl<B>,
|
||||
{
|
||||
fn decode_vec_i64<R>(&self, basek: usize, res: &R, res_col: usize, k: usize, data: &mut [i64])
|
||||
where
|
||||
R: VecZnxToRef,
|
||||
{
|
||||
B::decode_vec_i64_impl(self, basek, res, res_col, k, data);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxDecodeCoeffsi64 for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDecodeCoeffsi64Impl<B>,
|
||||
{
|
||||
fn decode_coeff_i64<R>(&self, basek: usize, res: &R, res_col: usize, k: usize, i: usize) -> i64
|
||||
where
|
||||
R: VecZnxToRef,
|
||||
{
|
||||
B::decode_coeff_i64_impl(self, basek, res, res_col, k, i)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxDecodeVecFloat for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDecodeVecFloatImpl<B>,
|
||||
{
|
||||
fn decode_vec_float<R>(&self, basek: usize, res: &R, col_i: usize, data: &mut [rug::Float])
|
||||
where
|
||||
R: VecZnxToRef,
|
||||
{
|
||||
B::decode_vec_float_impl(self, basek, res, col_i, data);
|
||||
}
|
||||
}
|
||||
334
backend/src/hal/delegates/vec_znx_big.rs
Normal file
334
backend/src/hal/delegates/vec_znx_big.rs
Normal file
@@ -0,0 +1,334 @@
|
||||
use rand_distr::Distribution;
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::hal::{
|
||||
api::{
|
||||
VecZnxBigAdd, VecZnxBigAddDistF64, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmall, VecZnxBigAddSmallInplace,
|
||||
VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigFillDistF64,
|
||||
VecZnxBigFillNormal, VecZnxBigFromBytes, VecZnxBigNegateInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
|
||||
VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA, VecZnxBigSubSmallAInplace,
|
||||
VecZnxBigSubSmallB, VecZnxBigSubSmallBInplace,
|
||||
},
|
||||
layouts::{Backend, Module, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef},
|
||||
oep::{
|
||||
VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl,
|
||||
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
|
||||
VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl, VecZnxBigFromBytesImpl,
|
||||
VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl,
|
||||
VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl,
|
||||
VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
|
||||
},
|
||||
};
|
||||
|
||||
impl<B> VecZnxBigAlloc<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigAllocImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned<B> {
|
||||
B::vec_znx_big_alloc_impl(self.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigFromBytes<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigFromBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B> {
|
||||
B::vec_znx_big_from_bytes_impl(self.n(), cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigAllocBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigAllocBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize {
|
||||
B::vec_znx_big_alloc_bytes_impl(self.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigAddDistF64<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigAddDistF64Impl<B>,
|
||||
{
|
||||
fn vec_znx_big_add_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
B::add_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigAddNormal<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigAddNormalImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_add_normal<R: VecZnxBigToMut<B>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) {
|
||||
B::add_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigFillDistF64<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigFillDistF64Impl<B>,
|
||||
{
|
||||
fn vec_znx_big_fill_dist_f64<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
B::fill_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigFillNormal<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigFillNormalImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_fill_normal<R: VecZnxBigToMut<B>>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) {
|
||||
B::fill_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigAdd<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigAddImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_add<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
C: VecZnxBigToRef<B>,
|
||||
{
|
||||
B::vec_znx_big_add_impl(self, res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigAddInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigAddInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
{
|
||||
B::vec_znx_big_add_inplace_impl(self, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigAddSmall<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigAddSmallImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_add_small<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
C: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_big_add_small_impl(self, res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigAddSmallInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigAddSmallInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_add_small_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_big_add_small_inplace_impl(self, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigSub<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigSubImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_sub<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
C: VecZnxBigToRef<B>,
|
||||
{
|
||||
B::vec_znx_big_sub_impl(self, res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigSubABInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigSubABInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
{
|
||||
B::vec_znx_big_sub_ab_inplace_impl(self, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigSubBAInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigSubBAInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
{
|
||||
B::vec_znx_big_sub_ba_inplace_impl(self, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigSubSmallA<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigSubSmallAImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_sub_small_a<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef,
|
||||
C: VecZnxBigToRef<B>,
|
||||
{
|
||||
B::vec_znx_big_sub_small_a_impl(self, res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigSubSmallAInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigSubSmallAInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_sub_small_a_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_big_sub_small_a_inplace_impl(self, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigSubSmallB<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigSubSmallBImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_sub_small_b<R, A, C>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
C: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_big_sub_small_b_impl(self, res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigSubSmallBInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigSubSmallBInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_sub_small_b_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_big_sub_small_b_inplace_impl(self, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigNegateInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigNegateInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<B>,
|
||||
{
|
||||
B::vec_znx_big_negate_inplace_impl(self, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigNormalizeTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigNormalizeTmpBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_normalize_tmp_bytes(&self, n: usize) -> usize {
|
||||
B::vec_znx_big_normalize_tmp_bytes_impl(self, n)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigNormalize<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigNormalizeImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_normalize<R, A>(
|
||||
&self,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxBigToRef<B>,
|
||||
{
|
||||
B::vec_znx_big_normalize_impl(self, basek, res, res_col, a, a_col, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigAutomorphism<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigAutomorphismImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
{
|
||||
B::vec_znx_big_automorphism_impl(self, k, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxBigAutomorphismInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigAutomorphismInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<B>,
|
||||
{
|
||||
B::vec_znx_big_automorphism_inplace_impl(self, k, a, a_col);
|
||||
}
|
||||
}
|
||||
196
backend/src/hal/delegates/vec_znx_dft.rs
Normal file
196
backend/src/hal/delegates/vec_znx_dft.rs
Normal file
@@ -0,0 +1,196 @@
|
||||
use crate::hal::{
|
||||
api::{
|
||||
VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxDftFromBytes,
|
||||
VecZnxDftFromVecZnx, VecZnxDftSub, VecZnxDftSubABInplace, VecZnxDftSubBAInplace, VecZnxDftToVecZnxBig,
|
||||
VecZnxDftToVecZnxBigConsume, VecZnxDftToVecZnxBigTmpA, VecZnxDftToVecZnxBigTmpBytes, VecZnxDftZero,
|
||||
},
|
||||
layouts::{
|
||||
Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef,
|
||||
VecZnxToRef,
|
||||
},
|
||||
oep::{
|
||||
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftCopyImpl,
|
||||
VecZnxDftFromBytesImpl, VecZnxDftFromVecZnxImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl,
|
||||
VecZnxDftToVecZnxBigConsumeImpl, VecZnxDftToVecZnxBigImpl, VecZnxDftToVecZnxBigTmpAImpl,
|
||||
VecZnxDftToVecZnxBigTmpBytesImpl, VecZnxDftZeroImpl,
|
||||
},
|
||||
};
|
||||
|
||||
impl<B> VecZnxDftFromBytes<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftFromBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B> {
|
||||
B::vec_znx_dft_from_bytes_impl(self.n(), cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxDftAllocBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftAllocBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize {
|
||||
B::vec_znx_dft_alloc_bytes_impl(self.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxDftAlloc<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftAllocImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_alloc(&self, cols: usize, size: usize) -> VecZnxDftOwned<B> {
|
||||
B::vec_znx_dft_alloc_impl(self.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxDftToVecZnxBigTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftToVecZnxBigTmpBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self) -> usize {
|
||||
B::vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxDftToVecZnxBig<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftToVecZnxBigImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_to_vec_znx_big<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
{
|
||||
B::vec_znx_dft_to_vec_znx_big_impl(self, res, res_col, a, a_col, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxDftToVecZnxBigTmpA<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftToVecZnxBigTmpAImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxDftToMut<B>,
|
||||
{
|
||||
B::vec_znx_dft_to_vec_znx_big_tmp_a_impl(self, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxDftToVecZnxBigConsume<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftToVecZnxBigConsumeImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_to_vec_znx_big_consume<D: Data>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
|
||||
where
|
||||
VecZnxDft<D, B>: VecZnxDftToMut<B>,
|
||||
{
|
||||
B::vec_znx_dft_to_vec_znx_big_consume_impl(self, a)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxDftFromVecZnx<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftFromVecZnxImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_from_vec_znx<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_dft_from_vec_znx_impl(self, step, offset, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxDftAdd<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftAddImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_add<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
D: VecZnxDftToRef<B>,
|
||||
{
|
||||
B::vec_znx_dft_add_impl(self, res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxDftAddInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftAddInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
{
|
||||
B::vec_znx_dft_add_inplace_impl(self, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxDftSub<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftSubImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_sub<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
D: VecZnxDftToRef<B>,
|
||||
{
|
||||
B::vec_znx_dft_sub_impl(self, res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxDftSubABInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftSubABInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
{
|
||||
B::vec_znx_dft_sub_ab_inplace_impl(self, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxDftSubBAInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftSubBAInplaceImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
{
|
||||
B::vec_znx_dft_sub_ba_inplace_impl(self, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxDftCopy<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftCopyImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_copy<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
{
|
||||
B::vec_znx_dft_copy_impl(self, step, offset, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VecZnxDftZero<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftZeroImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_zero<R>(&self, res: &mut R)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
{
|
||||
B::vec_znx_dft_zero_impl(self, res);
|
||||
}
|
||||
}
|
||||
126
backend/src/hal/delegates/vmp_pmat.rs
Normal file
126
backend/src/hal/delegates/vmp_pmat.rs
Normal file
@@ -0,0 +1,126 @@
|
||||
use crate::hal::{
|
||||
api::{
|
||||
VmpApply, VmpApplyAdd, VmpApplyAddTmpBytes, VmpApplyTmpBytes, VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatFromBytes,
|
||||
VmpPMatPrepare, VmpPrepareTmpBytes,
|
||||
},
|
||||
layouts::{Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef},
|
||||
oep::{
|
||||
VmpApplyAddImpl, VmpApplyAddTmpBytesImpl, VmpApplyImpl, VmpApplyTmpBytesImpl, VmpPMatAllocBytesImpl, VmpPMatAllocImpl,
|
||||
VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
|
||||
},
|
||||
};
|
||||
|
||||
impl<B> VmpPMatAlloc<B> for Module<B>
|
||||
where
|
||||
B: Backend + VmpPMatAllocImpl<B>,
|
||||
{
|
||||
fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B> {
|
||||
B::vmp_pmat_alloc_impl(self.n(), rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VmpPMatAllocBytes for Module<B>
|
||||
where
|
||||
B: Backend + VmpPMatAllocBytesImpl<B>,
|
||||
{
|
||||
fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
B::vmp_pmat_alloc_bytes_impl(self.n(), rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VmpPMatFromBytes<B> for Module<B>
|
||||
where
|
||||
B: Backend + VmpPMatFromBytesImpl<B>,
|
||||
{
|
||||
fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> VmpPMatOwned<B> {
|
||||
B::vmp_pmat_from_bytes_impl(self.n(), rows, cols_in, cols_out, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VmpPrepareTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VmpPrepareTmpBytesImpl<B>,
|
||||
{
|
||||
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
B::vmp_prepare_tmp_bytes_impl(self, rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VmpPMatPrepare<B> for Module<B>
|
||||
where
|
||||
B: Backend + VmpPMatPrepareImpl<B>,
|
||||
{
|
||||
fn vmp_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VmpPMatToMut<B>,
|
||||
A: MatZnxToRef,
|
||||
{
|
||||
B::vmp_prepare_impl(self, res, a, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VmpApplyTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VmpApplyTmpBytesImpl<B>,
|
||||
{
|
||||
fn vmp_apply_tmp_bytes(
|
||||
&self,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
B::vmp_apply_tmp_bytes_impl(
|
||||
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VmpApply<B> for Module<B>
|
||||
where
|
||||
B: Backend + VmpApplyImpl<B>,
|
||||
{
|
||||
fn vmp_apply<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
C: VmpPMatToRef<B>,
|
||||
{
|
||||
B::vmp_apply_impl(self, res, a, b, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VmpApplyAddTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VmpApplyAddTmpBytesImpl<B>,
|
||||
{
|
||||
fn vmp_apply_add_tmp_bytes(
|
||||
&self,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
B::vmp_apply_add_tmp_bytes_impl(
|
||||
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> VmpApplyAdd<B> for Module<B>
|
||||
where
|
||||
B: Backend + VmpApplyAddImpl<B>,
|
||||
{
|
||||
fn vmp_apply_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
C: VmpPMatToRef<B>,
|
||||
{
|
||||
B::vmp_apply_add_impl(self, res, a, b, scale, scratch);
|
||||
}
|
||||
}
|
||||
246
backend/src/hal/layouts/mat_znx.rs
Normal file
246
backend/src/hal/layouts/mat_znx.rs
Normal file
@@ -0,0 +1,246 @@
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
hal::{
|
||||
api::{DataView, DataViewMut, ZnxInfos, ZnxSliceSize, ZnxView},
|
||||
layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, WriterTo},
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct MatZnx<D: Data> {
|
||||
data: D,
|
||||
n: usize,
|
||||
size: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxInfos for MatZnx<D> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols_in
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
self.rows
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxSliceSize for MatZnx<D> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n() * self.cols_out()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> DataView for MatZnx<D> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> DataViewMut for MatZnx<D> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> ZnxView for MatZnx<D> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
impl<D: Data> MatZnx<D> {
|
||||
pub fn cols_in(&self) -> usize {
|
||||
self.cols_in
|
||||
}
|
||||
|
||||
pub fn cols_out(&self) -> usize {
|
||||
self.cols_out
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> MatZnx<D> {
|
||||
pub fn bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
rows * cols_in * VecZnx::<Vec<u8>>::alloc_bytes::<i64>(n, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef + From<Vec<u8>>> MatZnx<D> {
|
||||
pub(crate) fn new(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
|
||||
let data: Vec<u8> = alloc_aligned(Self::bytes_of(n, rows, cols_in, cols_out, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
size,
|
||||
rows,
|
||||
cols_in,
|
||||
cols_out,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_from_bytes(
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: impl Into<Vec<u8>>,
|
||||
) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == Self::bytes_of(n, rows, cols_in, cols_out, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
size,
|
||||
rows,
|
||||
cols_in,
|
||||
cols_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> MatZnx<D> {
|
||||
pub fn at(&self, row: usize, col: usize) -> VecZnx<&[u8]> {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(row < self.rows(), "rows: {} >= {}", row, self.rows());
|
||||
assert!(col < self.cols_in(), "cols: {} >= {}", col, self.cols_in());
|
||||
}
|
||||
|
||||
let self_ref: MatZnx<&[u8]> = self.to_ref();
|
||||
let nb_bytes: usize = VecZnx::<Vec<u8>>::alloc_bytes::<i64>(self.n, self.cols_out, self.size);
|
||||
let start: usize = nb_bytes * self.cols() * row + col * nb_bytes;
|
||||
let end: usize = start + nb_bytes;
|
||||
|
||||
VecZnx {
|
||||
data: &self_ref.data[start..end],
|
||||
n: self.n,
|
||||
cols: self.cols_out,
|
||||
size: self.size,
|
||||
max_size: self.size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataMut> MatZnx<D> {
|
||||
pub fn at_mut(&mut self, row: usize, col: usize) -> VecZnx<&mut [u8]> {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(row < self.rows(), "rows: {} >= {}", row, self.rows());
|
||||
assert!(col < self.cols_in(), "cols: {} >= {}", col, self.cols_in());
|
||||
}
|
||||
|
||||
let n: usize = self.n();
|
||||
let cols_out: usize = self.cols_out();
|
||||
let cols_in: usize = self.cols_in();
|
||||
let size: usize = self.size();
|
||||
|
||||
let self_ref: MatZnx<&mut [u8]> = self.to_mut();
|
||||
let nb_bytes: usize = VecZnx::<Vec<u8>>::alloc_bytes::<i64>(n, cols_out, size);
|
||||
let start: usize = nb_bytes * cols_in * row + col * nb_bytes;
|
||||
let end: usize = start + nb_bytes;
|
||||
|
||||
VecZnx {
|
||||
data: &mut self_ref.data[start..end],
|
||||
n,
|
||||
cols: cols_out,
|
||||
size,
|
||||
max_size: size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type MatZnxOwned = MatZnx<Vec<u8>>;
|
||||
pub type MatZnxMut<'a> = MatZnx<&'a mut [u8]>;
|
||||
pub type MatZnxRef<'a> = MatZnx<&'a [u8]>;
|
||||
|
||||
pub trait MatZnxToRef {
|
||||
fn to_ref(&self) -> MatZnx<&[u8]>;
|
||||
}
|
||||
|
||||
impl<D: DataRef> MatZnxToRef for MatZnx<D> {
|
||||
fn to_ref(&self) -> MatZnx<&[u8]> {
|
||||
MatZnx {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
rows: self.rows,
|
||||
cols_in: self.cols_in,
|
||||
cols_out: self.cols_out,
|
||||
size: self.size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait MatZnxToMut {
|
||||
fn to_mut(&mut self) -> MatZnx<&mut [u8]>;
|
||||
}
|
||||
|
||||
impl<D: DataMut> MatZnxToMut for MatZnx<D> {
|
||||
fn to_mut(&mut self) -> MatZnx<&mut [u8]> {
|
||||
MatZnx {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
rows: self.rows,
|
||||
cols_in: self.cols_in,
|
||||
cols_out: self.cols_out,
|
||||
size: self.size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> MatZnx<D> {
|
||||
pub(crate) fn from_data(data: D, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
rows,
|
||||
cols_in,
|
||||
cols_out,
|
||||
size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
|
||||
impl<D: DataMut> ReaderFrom for MatZnx<D> {
|
||||
fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
|
||||
self.n = reader.read_u64::<LittleEndian>()? as usize;
|
||||
self.size = reader.read_u64::<LittleEndian>()? as usize;
|
||||
self.rows = reader.read_u64::<LittleEndian>()? as usize;
|
||||
self.cols_in = reader.read_u64::<LittleEndian>()? as usize;
|
||||
self.cols_out = reader.read_u64::<LittleEndian>()? as usize;
|
||||
let len: usize = reader.read_u64::<LittleEndian>()? as usize;
|
||||
let buf: &mut [u8] = self.data.as_mut();
|
||||
if buf.len() != len {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
format!("self.data.len()={} != read len={}", buf.len(), len),
|
||||
));
|
||||
}
|
||||
reader.read_exact(&mut buf[..len])?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> WriterTo for MatZnx<D> {
|
||||
fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
|
||||
writer.write_u64::<LittleEndian>(self.n as u64)?;
|
||||
writer.write_u64::<LittleEndian>(self.size as u64)?;
|
||||
writer.write_u64::<LittleEndian>(self.rows as u64)?;
|
||||
writer.write_u64::<LittleEndian>(self.cols_in as u64)?;
|
||||
writer.write_u64::<LittleEndian>(self.cols_out as u64)?;
|
||||
let buf: &[u8] = self.data.as_ref();
|
||||
writer.write_u64::<LittleEndian>(buf.len() as u64)?;
|
||||
writer.write_all(buf)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
25
backend/src/hal/layouts/mod.rs
Normal file
25
backend/src/hal/layouts/mod.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
mod mat_znx;
|
||||
mod module;
|
||||
mod scalar_znx;
|
||||
mod scratch;
|
||||
mod serialization;
|
||||
mod svp_ppol;
|
||||
mod vec_znx;
|
||||
mod vec_znx_big;
|
||||
mod vec_znx_dft;
|
||||
mod vmp_pmat;
|
||||
|
||||
pub use mat_znx::*;
|
||||
pub use module::*;
|
||||
pub use scalar_znx::*;
|
||||
pub use scratch::*;
|
||||
pub use serialization::*;
|
||||
pub use svp_ppol::*;
|
||||
pub use vec_znx::*;
|
||||
pub use vec_znx_big::*;
|
||||
pub use vec_znx_dft::*;
|
||||
pub use vmp_pmat::*;
|
||||
|
||||
pub trait Data = PartialEq + Eq + Sized;
|
||||
pub trait DataRef = Data + AsRef<[u8]>;
|
||||
pub trait DataMut = DataRef + AsMut<[u8]>;
|
||||
@@ -1,71 +1,56 @@
|
||||
use std::{marker::PhantomData, ptr::NonNull};
|
||||
|
||||
use crate::GALOISGENERATOR;
|
||||
use crate::ffi::module::{MODULE, delete_module_info, module_info_t, new_module_info};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[repr(u8)]
|
||||
pub enum BACKEND {
|
||||
FFT64,
|
||||
NTT120,
|
||||
}
|
||||
|
||||
pub trait Backend {
|
||||
const KIND: BACKEND;
|
||||
fn module_type() -> u32;
|
||||
}
|
||||
|
||||
pub struct FFT64;
|
||||
pub struct NTT120;
|
||||
|
||||
impl Backend for FFT64 {
|
||||
const KIND: BACKEND = BACKEND::FFT64;
|
||||
fn module_type() -> u32 {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
impl Backend for NTT120 {
|
||||
const KIND: BACKEND = BACKEND::NTT120;
|
||||
fn module_type() -> u32 {
|
||||
1
|
||||
}
|
||||
pub trait Backend: Sized {
|
||||
type Handle: 'static;
|
||||
unsafe fn destroy(handle: NonNull<Self::Handle>);
|
||||
}
|
||||
|
||||
pub struct Module<B: Backend> {
|
||||
pub ptr: *mut MODULE,
|
||||
n: usize,
|
||||
ptr: NonNull<B::Handle>,
|
||||
n: u64,
|
||||
_marker: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Module<B> {
|
||||
// Instantiates a new module.
|
||||
pub fn new(n: usize) -> Self {
|
||||
unsafe {
|
||||
let m: *mut module_info_t = new_module_info(n as u64, B::module_type());
|
||||
if m.is_null() {
|
||||
panic!("Failed to create module.");
|
||||
}
|
||||
Self {
|
||||
ptr: m,
|
||||
n: n,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
/// Construct from a raw pointer managed elsewhere.
|
||||
/// SAFETY: `ptr` must be non-null and remain valid for the lifetime of this Module.
|
||||
#[inline]
|
||||
pub unsafe fn from_raw_parts(ptr: *mut B::Handle, n: u64) -> Self {
|
||||
Self {
|
||||
ptr: NonNull::new(ptr).expect("null module ptr"),
|
||||
n,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn n(&self) -> usize {
|
||||
self.n
|
||||
#[inline]
|
||||
pub unsafe fn ptr(&self) -> *mut <B as Backend>::Handle {
|
||||
self.ptr.as_ptr()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn n(&self) -> usize {
|
||||
self.n as usize
|
||||
}
|
||||
#[inline]
|
||||
pub fn as_mut_ptr(&self) -> *mut B::Handle {
|
||||
self.ptr.as_ptr()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn log_n(&self) -> usize {
|
||||
(usize::BITS - (self.n() - 1).leading_zeros()) as _
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn cyclotomic_order(&self) -> u64 {
|
||||
(self.n() << 1) as _
|
||||
}
|
||||
|
||||
// Returns GALOISGENERATOR^|generator| * sign(generator)
|
||||
#[inline]
|
||||
pub fn galois_element(&self, generator: i64) -> i64 {
|
||||
if generator == 0 {
|
||||
return 1;
|
||||
@@ -74,6 +59,7 @@ impl<B: Backend> Module<B> {
|
||||
}
|
||||
|
||||
// Returns gen^-1
|
||||
#[inline]
|
||||
pub fn galois_element_inv(&self, gal_el: i64) -> i64 {
|
||||
if gal_el == 0 {
|
||||
panic!("cannot invert 0")
|
||||
@@ -85,11 +71,11 @@ impl<B: Backend> Module<B> {
|
||||
|
||||
impl<B: Backend> Drop for Module<B> {
|
||||
fn drop(&mut self) {
|
||||
unsafe { delete_module_info(self.ptr) }
|
||||
unsafe { B::destroy(self.ptr) }
|
||||
}
|
||||
}
|
||||
|
||||
fn mod_exp_u64(x: u64, e: usize) -> u64 {
|
||||
pub fn mod_exp_u64(x: u64, e: usize) -> u64 {
|
||||
let mut y: u64 = 1;
|
||||
let mut x_pow: u64 = x;
|
||||
let mut exp = e;
|
||||
@@ -1,20 +1,24 @@
|
||||
use crate::ffi::vec_znx;
|
||||
use crate::znx_base::ZnxInfos;
|
||||
use crate::{
|
||||
Backend, DataView, DataViewMut, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned,
|
||||
};
|
||||
use rand::seq::SliceRandom;
|
||||
use rand_core::RngCore;
|
||||
use rand_distr::{Distribution, weighted::WeightedIndex};
|
||||
use sampling::source::Source;
|
||||
|
||||
pub struct ScalarZnx<D> {
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
hal::{
|
||||
api::{DataView, DataViewMut, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
|
||||
layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo},
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct ScalarZnx<D: Data> {
|
||||
pub(crate) data: D,
|
||||
pub(crate) n: usize,
|
||||
pub(crate) cols: usize,
|
||||
}
|
||||
|
||||
impl<D> ZnxInfos for ScalarZnx<D> {
|
||||
impl<D: Data> ZnxInfos for ScalarZnx<D> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
@@ -32,30 +36,30 @@ impl<D> ZnxInfos for ScalarZnx<D> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> ZnxSliceSize for ScalarZnx<D> {
|
||||
impl<D: Data> ZnxSliceSize for ScalarZnx<D> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> DataView for ScalarZnx<D> {
|
||||
impl<D: Data> DataView for ScalarZnx<D> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> DataViewMut for ScalarZnx<D> {
|
||||
impl<D: Data> DataViewMut for ScalarZnx<D> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> ZnxView for ScalarZnx<D> {
|
||||
impl<D: DataRef> ZnxView for ScalarZnx<D> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
impl<D: AsMut<[u8]> + AsRef<[u8]>> ScalarZnx<D> {
|
||||
impl<D: DataMut> ScalarZnx<D> {
|
||||
pub fn fill_ternary_prob(&mut self, col: usize, prob: f64, source: &mut Source) {
|
||||
let choices: [i64; 3] = [-1, 0, 1];
|
||||
let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0];
|
||||
@@ -103,11 +107,13 @@ impl<D: AsMut<[u8]> + AsRef<[u8]>> ScalarZnx<D> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: From<Vec<u8>>> ScalarZnx<D> {
|
||||
pub(crate) fn bytes_of(n: usize, cols: usize) -> usize {
|
||||
impl<D: DataRef> ScalarZnx<D> {
|
||||
pub fn bytes_of(n: usize, cols: usize) -> usize {
|
||||
n * cols * size_of::<i64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef + From<Vec<u8>>> ScalarZnx<D> {
|
||||
pub fn new(n: usize, cols: usize) -> Self {
|
||||
let data: Vec<u8> = alloc_aligned::<u8>(Self::bytes_of(n, cols));
|
||||
Self {
|
||||
@@ -128,94 +134,18 @@ impl<D: From<Vec<u8>>> ScalarZnx<D> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataMut> ZnxZero for ScalarZnx<D> {
|
||||
fn zero(&mut self) {
|
||||
self.raw_mut().fill(0)
|
||||
}
|
||||
fn zero_at(&mut self, i: usize, j: usize) {
|
||||
self.at_mut(i, j).fill(0);
|
||||
}
|
||||
}
|
||||
|
||||
pub type ScalarZnxOwned = ScalarZnx<Vec<u8>>;
|
||||
|
||||
pub(crate) fn bytes_of_scalar_znx<B: Backend>(module: &Module<B>, cols: usize) -> usize {
|
||||
ScalarZnxOwned::bytes_of(module.n(), cols)
|
||||
}
|
||||
|
||||
pub trait ScalarZnxAlloc {
|
||||
fn bytes_of_scalar_znx(&self, cols: usize) -> usize;
|
||||
fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned;
|
||||
fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned;
|
||||
}
|
||||
|
||||
impl<B: Backend> ScalarZnxAlloc for Module<B> {
|
||||
fn bytes_of_scalar_znx(&self, cols: usize) -> usize {
|
||||
ScalarZnxOwned::bytes_of(self.n(), cols)
|
||||
}
|
||||
fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned {
|
||||
ScalarZnxOwned::new(self.n(), cols)
|
||||
}
|
||||
fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned {
|
||||
ScalarZnxOwned::new_from_bytes(self.n(), cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ScalarZnxOps {
|
||||
fn scalar_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: ScalarZnxToMut,
|
||||
A: ScalarZnxToRef;
|
||||
|
||||
/// Applies the automorphism X^i -> X^ik on the selected column of `a`.
|
||||
fn scalar_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: ScalarZnxToMut;
|
||||
}
|
||||
|
||||
impl<B: Backend> ScalarZnxOps for Module<B> {
|
||||
fn scalar_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: ScalarZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
let a: ScalarZnx<&[u8]> = a.to_ref();
|
||||
let mut res: ScalarZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
self.ptr,
|
||||
k,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn scalar_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: ScalarZnxToMut,
|
||||
{
|
||||
let mut a: ScalarZnx<&mut [u8]> = a.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
self.ptr,
|
||||
k,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> ScalarZnx<D> {
|
||||
impl<D: Data> ScalarZnx<D> {
|
||||
pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self {
|
||||
Self { data, n, cols }
|
||||
}
|
||||
@@ -225,10 +155,7 @@ pub trait ScalarZnxToRef {
|
||||
fn to_ref(&self) -> ScalarZnx<&[u8]>;
|
||||
}
|
||||
|
||||
impl<D> ScalarZnxToRef for ScalarZnx<D>
|
||||
where
|
||||
D: AsRef<[u8]>,
|
||||
{
|
||||
impl<D: DataRef> ScalarZnxToRef for ScalarZnx<D> {
|
||||
fn to_ref(&self) -> ScalarZnx<&[u8]> {
|
||||
ScalarZnx {
|
||||
data: self.data.as_ref(),
|
||||
@@ -242,10 +169,7 @@ pub trait ScalarZnxToMut {
|
||||
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]>;
|
||||
}
|
||||
|
||||
impl<D> ScalarZnxToMut for ScalarZnx<D>
|
||||
where
|
||||
D: AsRef<[u8]> + AsMut<[u8]>,
|
||||
{
|
||||
impl<D: DataMut> ScalarZnxToMut for ScalarZnx<D> {
|
||||
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
|
||||
ScalarZnx {
|
||||
data: self.data.as_mut(),
|
||||
@@ -255,30 +179,56 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> VecZnxToRef for ScalarZnx<D>
|
||||
where
|
||||
D: AsRef<[u8]>,
|
||||
{
|
||||
impl<D: DataRef> VecZnxToRef for ScalarZnx<D> {
|
||||
fn to_ref(&self) -> VecZnx<&[u8]> {
|
||||
VecZnx {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: 1,
|
||||
max_size: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> VecZnxToMut for ScalarZnx<D>
|
||||
where
|
||||
D: AsRef<[u8]> + AsMut<[u8]>,
|
||||
{
|
||||
impl<D: DataMut> VecZnxToMut for ScalarZnx<D> {
|
||||
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
|
||||
VecZnx {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: 1,
|
||||
max_size: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
|
||||
impl<D: DataMut> ReaderFrom for ScalarZnx<D> {
|
||||
fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
|
||||
self.n = reader.read_u64::<LittleEndian>()? as usize;
|
||||
self.cols = reader.read_u64::<LittleEndian>()? as usize;
|
||||
let len: usize = reader.read_u64::<LittleEndian>()? as usize;
|
||||
let buf: &mut [u8] = self.data.as_mut();
|
||||
if buf.len() != len {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
format!("self.data.len()={} != read len={}", buf.len(), len),
|
||||
));
|
||||
}
|
||||
reader.read_exact(&mut buf[..len])?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> WriterTo for ScalarZnx<D> {
|
||||
fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
|
||||
writer.write_u64::<LittleEndian>(self.n as u64)?;
|
||||
writer.write_u64::<LittleEndian>(self.cols as u64)?;
|
||||
let buf: &[u8] = self.data.as_ref();
|
||||
writer.write_u64::<LittleEndian>(buf.len() as u64)?;
|
||||
writer.write_all(buf)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
13
backend/src/hal/layouts/scratch.rs
Normal file
13
backend/src/hal/layouts/scratch.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::hal::layouts::Backend;
|
||||
|
||||
pub struct ScratchOwned<B: Backend> {
|
||||
pub(crate) data: Vec<u8>,
|
||||
pub(crate) _phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
pub struct Scratch<B: Backend> {
|
||||
pub(crate) _phantom: PhantomData<B>,
|
||||
pub(crate) data: [u8],
|
||||
}
|
||||
9
backend/src/hal/layouts/serialization.rs
Normal file
9
backend/src/hal/layouts/serialization.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
use std::io::{Read, Result, Write};
|
||||
|
||||
pub trait WriterTo {
|
||||
fn write_to<W: Write>(&self, writer: &mut W) -> Result<()>;
|
||||
}
|
||||
|
||||
pub trait ReaderFrom {
|
||||
fn read_from<R: Read>(&mut self, reader: &mut R) -> Result<()>;
|
||||
}
|
||||
151
backend/src/hal/layouts/svp_ppol.rs
Normal file
151
backend/src/hal/layouts/svp_ppol.rs
Normal file
@@ -0,0 +1,151 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
hal::{
|
||||
api::{DataView, DataViewMut, ZnxInfos},
|
||||
layouts::{Backend, Data, DataMut, DataRef, ReaderFrom, WriterTo},
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct SvpPPol<D: Data, B: Backend> {
|
||||
data: D,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
_phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> ZnxInfos for SvpPPol<D, B> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> DataView for SvpPPol<D, B> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> DataViewMut for SvpPPol<D, B> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
pub trait SvpPPolBytesOf {
|
||||
fn bytes_of(n: usize, cols: usize) -> usize;
|
||||
}
|
||||
|
||||
impl<D: Data + From<Vec<u8>>, B: Backend> SvpPPol<D, B>
|
||||
where
|
||||
SvpPPol<D, B>: SvpPPolBytesOf,
|
||||
{
|
||||
pub(crate) fn alloc(n: usize, cols: usize) -> Self {
|
||||
let data: Vec<u8> = alloc_aligned::<u8>(Self::bytes_of(n, cols));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn from_bytes(n: usize, cols: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == Self::bytes_of(n, cols));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type SvpPPolOwned<B> = SvpPPol<Vec<u8>, B>;
|
||||
|
||||
pub trait SvpPPolToRef<B: Backend> {
|
||||
fn to_ref(&self) -> SvpPPol<&[u8], B>;
|
||||
}
|
||||
|
||||
impl<D: DataRef, B: Backend> SvpPPolToRef<B> for SvpPPol<D, B> {
|
||||
fn to_ref(&self) -> SvpPPol<&[u8], B> {
|
||||
SvpPPol {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait SvpPPolToMut<B: Backend> {
|
||||
fn to_mut(&mut self) -> SvpPPol<&mut [u8], B>;
|
||||
}
|
||||
|
||||
impl<D: DataMut, B: Backend> SvpPPolToMut<B> for SvpPPol<D, B> {
|
||||
fn to_mut(&mut self) -> SvpPPol<&mut [u8], B> {
|
||||
SvpPPol {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> SvpPPol<D, B> {
|
||||
pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
|
||||
impl<D: DataMut, B: Backend> ReaderFrom for SvpPPol<D, B> {
|
||||
fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
|
||||
self.n = reader.read_u64::<LittleEndian>()? as usize;
|
||||
self.cols = reader.read_u64::<LittleEndian>()? as usize;
|
||||
let len: usize = reader.read_u64::<LittleEndian>()? as usize;
|
||||
let buf: &mut [u8] = self.data.as_mut();
|
||||
if buf.len() != len {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
format!("self.data.len()={} != read len={}", buf.len(), len),
|
||||
));
|
||||
}
|
||||
reader.read_exact(&mut buf[..len])?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef, B: Backend> WriterTo for SvpPPol<D, B> {
|
||||
fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
|
||||
writer.write_u64::<LittleEndian>(self.n as u64)?;
|
||||
writer.write_u64::<LittleEndian>(self.cols as u64)?;
|
||||
let buf: &[u8] = self.data.as_ref();
|
||||
writer.write_u64::<LittleEndian>(buf.len() as u64)?;
|
||||
writer.write_all(buf)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
241
backend/src/hal/layouts/vec_znx.rs
Normal file
241
backend/src/hal/layouts/vec_znx.rs
Normal file
@@ -0,0 +1,241 @@
|
||||
use std::fmt;
|
||||
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
hal::{
|
||||
api::{DataView, DataViewMut, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
|
||||
layouts::{Data, DataMut, DataRef, ReaderFrom, WriterTo},
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct VecZnx<D: Data> {
|
||||
pub(crate) data: D,
|
||||
pub(crate) n: usize,
|
||||
pub(crate) cols: usize,
|
||||
pub(crate) size: usize,
|
||||
pub(crate) max_size: usize,
|
||||
}
|
||||
|
||||
impl<D: DataRef> fmt::Debug for VecZnx<D> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxInfos for VecZnx<D> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxSliceSize for VecZnx<D> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> DataView for VecZnx<D> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> DataViewMut for VecZnx<D> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> ZnxView for VecZnx<D> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
impl VecZnx<Vec<u8>> {
|
||||
pub fn rsh_scratch_space(n: usize) -> usize {
|
||||
n * std::mem::size_of::<i64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataMut> ZnxZero for VecZnx<D> {
|
||||
fn zero(&mut self) {
|
||||
self.raw_mut().fill(0)
|
||||
}
|
||||
fn zero_at(&mut self, i: usize, j: usize) {
|
||||
self.at_mut(i, j).fill(0);
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> VecZnx<D> {
|
||||
pub fn alloc_bytes<Scalar: Sized>(n: usize, cols: usize, size: usize) -> usize {
|
||||
n * cols * size * size_of::<Scalar>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef + From<Vec<u8>>> VecZnx<D> {
|
||||
pub fn new<Scalar: Sized>(n: usize, cols: usize, size: usize) -> Self {
|
||||
let data: Vec<u8> = alloc_aligned::<u8>(Self::alloc_bytes::<Scalar>(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
max_size: size,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_bytes<Scalar: Sized>(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == Self::alloc_bytes::<Scalar>(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
max_size: size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> VecZnx<D> {
|
||||
pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
max_size: size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> fmt::Display for VecZnx<D> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"VecZnx(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {}:", col)?;
|
||||
for size in 0..self.size {
|
||||
let coeffs = self.at(col, size);
|
||||
write!(f, " Size {}: [", size)?;
|
||||
|
||||
let max_show = 100;
|
||||
let show_count = coeffs.len().min(max_show);
|
||||
|
||||
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "{}", coeff)?;
|
||||
}
|
||||
|
||||
if coeffs.len() > max_show {
|
||||
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
|
||||
}
|
||||
|
||||
writeln!(f, "]")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub type VecZnxOwned = VecZnx<Vec<u8>>;
|
||||
pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>;
|
||||
pub type VecZnxRef<'a> = VecZnx<&'a [u8]>;
|
||||
|
||||
pub trait VecZnxToRef {
|
||||
fn to_ref(&self) -> VecZnx<&[u8]>;
|
||||
}
|
||||
|
||||
impl<D: DataRef> VecZnxToRef for VecZnx<D> {
|
||||
fn to_ref(&self) -> VecZnx<&[u8]> {
|
||||
VecZnx {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
max_size: self.max_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxToMut {
|
||||
fn to_mut(&mut self) -> VecZnx<&mut [u8]>;
|
||||
}
|
||||
|
||||
impl<D: DataMut> VecZnxToMut for VecZnx<D> {
|
||||
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
|
||||
VecZnx {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
max_size: self.max_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> VecZnx<D> {
|
||||
pub fn clone(&self) -> VecZnx<Vec<u8>> {
|
||||
let self_ref: VecZnx<&[u8]> = self.to_ref();
|
||||
VecZnx {
|
||||
data: self_ref.data.to_vec(),
|
||||
n: self_ref.n,
|
||||
cols: self_ref.cols,
|
||||
size: self_ref.size,
|
||||
max_size: self_ref.max_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
|
||||
impl<D: DataMut> ReaderFrom for VecZnx<D> {
|
||||
fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
|
||||
self.n = reader.read_u64::<LittleEndian>()? as usize;
|
||||
self.cols = reader.read_u64::<LittleEndian>()? as usize;
|
||||
self.size = reader.read_u64::<LittleEndian>()? as usize;
|
||||
self.max_size = reader.read_u64::<LittleEndian>()? as usize;
|
||||
let len: usize = reader.read_u64::<LittleEndian>()? as usize;
|
||||
let buf: &mut [u8] = self.data.as_mut();
|
||||
if buf.len() != len {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
format!("self.data.len()={} != read len={}", buf.len(), len),
|
||||
));
|
||||
}
|
||||
reader.read_exact(&mut buf[..len])?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> WriterTo for VecZnx<D> {
|
||||
fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
|
||||
writer.write_u64::<LittleEndian>(self.n as u64)?;
|
||||
writer.write_u64::<LittleEndian>(self.cols as u64)?;
|
||||
writer.write_u64::<LittleEndian>(self.size as u64)?;
|
||||
writer.write_u64::<LittleEndian>(self.max_size as u64)?;
|
||||
let buf: &[u8] = self.data.as_ref();
|
||||
writer.write_u64::<LittleEndian>(buf.len() as u64)?;
|
||||
writer.write_all(buf)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
148
backend/src/hal/layouts/vec_znx_big.rs
Normal file
148
backend/src/hal/layouts/vec_znx_big.rs
Normal file
@@ -0,0 +1,148 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use rand_distr::num_traits::Zero;
|
||||
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
hal::{
|
||||
api::{DataView, DataViewMut, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero},
|
||||
layouts::{Backend, Data, DataMut, DataRef},
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct VecZnxBig<D: Data, B: Backend> {
|
||||
pub(crate) data: D,
|
||||
pub(crate) n: usize,
|
||||
pub(crate) cols: usize,
|
||||
pub(crate) size: usize,
|
||||
pub(crate) max_size: usize,
|
||||
pub(crate) _phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> ZnxInfos for VecZnxBig<D, B> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> DataView for VecZnxBig<D, B> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> DataViewMut for VecZnxBig<D, B> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxBigBytesOf {
|
||||
fn bytes_of(n: usize, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
impl<D: DataMut, B: Backend> ZnxZero for VecZnxBig<D, B>
|
||||
where
|
||||
Self: ZnxViewMut,
|
||||
<Self as ZnxView>::Scalar: Zero + Copy,
|
||||
{
|
||||
fn zero(&mut self) {
|
||||
self.raw_mut().fill(<Self as ZnxView>::Scalar::zero())
|
||||
}
|
||||
fn zero_at(&mut self, i: usize, j: usize) {
|
||||
self.at_mut(i, j).fill(<Self as ZnxView>::Scalar::zero());
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef + From<Vec<u8>>, B: Backend> VecZnxBig<D, B>
|
||||
where
|
||||
VecZnxBig<D, B>: VecZnxBigBytesOf,
|
||||
{
|
||||
pub(crate) fn new(n: usize, cols: usize, size: usize) -> Self {
|
||||
let data = alloc_aligned::<u8>(Self::bytes_of(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
max_size: size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == Self::bytes_of(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
max_size: size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> VecZnxBig<D, B> {
|
||||
pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
max_size: size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type VecZnxBigOwned<B> = VecZnxBig<Vec<u8>, B>;
|
||||
|
||||
pub trait VecZnxBigToRef<B: Backend> {
|
||||
fn to_ref(&self) -> VecZnxBig<&[u8], B>;
|
||||
}
|
||||
|
||||
impl<D: DataRef, B: Backend> VecZnxBigToRef<B> for VecZnxBig<D, B> {
|
||||
fn to_ref(&self) -> VecZnxBig<&[u8], B> {
|
||||
VecZnxBig {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
max_size: self.max_size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxBigToMut<B: Backend> {
|
||||
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B>;
|
||||
}
|
||||
|
||||
impl<D: DataMut, B: Backend> VecZnxBigToMut<B> for VecZnxBig<D, B> {
|
||||
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
|
||||
VecZnxBig {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
max_size: self.max_size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
166
backend/src/hal/layouts/vec_znx_dft.rs
Normal file
166
backend/src/hal/layouts/vec_znx_dft.rs
Normal file
@@ -0,0 +1,166 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use rand_distr::num_traits::Zero;
|
||||
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
hal::{
|
||||
api::{DataView, DataViewMut, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero},
|
||||
layouts::{Backend, Data, DataMut, DataRef, VecZnxBig},
|
||||
},
|
||||
};
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct VecZnxDft<D: Data, B: Backend> {
|
||||
pub(crate) data: D,
|
||||
pub(crate) n: usize,
|
||||
pub(crate) cols: usize,
|
||||
pub(crate) size: usize,
|
||||
pub(crate) max_size: usize,
|
||||
pub(crate) _phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> VecZnxDft<D, B> {
|
||||
pub fn into_big(self) -> VecZnxBig<D, B> {
|
||||
VecZnxBig::<D, B>::from_data(self.data, self.n, self.cols, self.size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> ZnxInfos for VecZnxDft<D, B> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> DataView for VecZnxDft<D, B> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> DataViewMut for VecZnxDft<D, B> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef, B: Backend> VecZnxDft<D, B> {
|
||||
pub fn max_size(&self) -> usize {
|
||||
self.max_size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataMut, B: Backend> VecZnxDft<D, B> {
|
||||
pub fn set_size(&mut self, size: usize) {
|
||||
assert!(size <= self.max_size);
|
||||
self.size = size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataMut, B: Backend> ZnxZero for VecZnxDft<D, B>
|
||||
where
|
||||
Self: ZnxViewMut,
|
||||
<Self as ZnxView>::Scalar: Zero + Copy,
|
||||
{
|
||||
fn zero(&mut self) {
|
||||
self.raw_mut().fill(<Self as ZnxView>::Scalar::zero())
|
||||
}
|
||||
fn zero_at(&mut self, i: usize, j: usize) {
|
||||
self.at_mut(i, j).fill(<Self as ZnxView>::Scalar::zero());
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxDftBytesOf {
|
||||
fn bytes_of(n: usize, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
impl<D: DataRef + From<Vec<u8>>, B: Backend> VecZnxDft<D, B>
|
||||
where
|
||||
VecZnxDft<D, B>: VecZnxDftBytesOf,
|
||||
{
|
||||
pub(crate) fn alloc(n: usize, cols: usize, size: usize) -> Self {
|
||||
let data: Vec<u8> = alloc_aligned::<u8>(Self::bytes_of(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: n,
|
||||
cols,
|
||||
size,
|
||||
max_size: size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == Self::bytes_of(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: n,
|
||||
cols,
|
||||
size,
|
||||
max_size: size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type VecZnxDftOwned<B> = VecZnxDft<Vec<u8>, B>;
|
||||
|
||||
impl<D: Data, B: Backend> VecZnxDft<D, B> {
|
||||
pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
max_size: size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxDftToRef<B: Backend> {
|
||||
fn to_ref(&self) -> VecZnxDft<&[u8], B>;
|
||||
}
|
||||
|
||||
impl<D: DataRef, B: Backend> VecZnxDftToRef<B> for VecZnxDft<D, B> {
|
||||
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
|
||||
VecZnxDft {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
max_size: self.max_size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxDftToMut<B: Backend> {
|
||||
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B>;
|
||||
}
|
||||
|
||||
impl<D: DataMut, B: Backend> VecZnxDftToMut<B> for VecZnxDft<D, B> {
|
||||
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
|
||||
VecZnxDft {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
max_size: self.max_size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
157
backend/src/hal/layouts/vmp_pmat.rs
Normal file
157
backend/src/hal/layouts/vmp_pmat.rs
Normal file
@@ -0,0 +1,157 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
hal::{
|
||||
api::{DataView, DataViewMut, ZnxInfos},
|
||||
layouts::{Backend, Data, DataMut, DataRef},
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct VmpPMat<D: Data, B: Backend> {
|
||||
data: D,
|
||||
n: usize,
|
||||
size: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
_phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> ZnxInfos for VmpPMat<D, B> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols_in
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
self.rows
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> DataView for VmpPMat<D, B> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> DataViewMut for VmpPMat<D, B> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> VmpPMat<D, B> {
|
||||
pub fn cols_in(&self) -> usize {
|
||||
self.cols_in
|
||||
}
|
||||
|
||||
pub fn cols_out(&self) -> usize {
|
||||
self.cols_out
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VmpPMatBytesOf {
|
||||
fn vmp_pmat_bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
impl<D: DataRef + From<Vec<u8>>, B: Backend> VmpPMat<D, B>
|
||||
where
|
||||
B: VmpPMatBytesOf,
|
||||
{
|
||||
pub(crate) fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
|
||||
let data: Vec<u8> = alloc_aligned(B::vmp_pmat_bytes_of(n, rows, cols_in, cols_out, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
size,
|
||||
rows,
|
||||
cols_in,
|
||||
cols_out,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn from_bytes(
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: impl Into<Vec<u8>>,
|
||||
) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == B::vmp_pmat_bytes_of(n, rows, cols_in, cols_out, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
size,
|
||||
rows,
|
||||
cols_in,
|
||||
cols_out,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type VmpPMatOwned<B> = VmpPMat<Vec<u8>, B>;
|
||||
pub type VmpPMatRef<'a, B> = VmpPMat<&'a [u8], B>;
|
||||
|
||||
pub trait VmpPMatToRef<B: Backend> {
|
||||
fn to_ref(&self) -> VmpPMat<&[u8], B>;
|
||||
}
|
||||
|
||||
impl<D: DataRef, B: Backend> VmpPMatToRef<B> for VmpPMat<D, B> {
|
||||
fn to_ref(&self) -> VmpPMat<&[u8], B> {
|
||||
VmpPMat {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
rows: self.rows,
|
||||
cols_in: self.cols_in,
|
||||
cols_out: self.cols_out,
|
||||
size: self.size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VmpPMatToMut<B: Backend> {
|
||||
fn to_mut(&mut self) -> VmpPMat<&mut [u8], B>;
|
||||
}
|
||||
|
||||
impl<D: DataMut, B: Backend> VmpPMatToMut<B> for VmpPMat<D, B> {
|
||||
fn to_mut(&mut self) -> VmpPMat<&mut [u8], B> {
|
||||
VmpPMat {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
rows: self.rows,
|
||||
cols_in: self.cols_in,
|
||||
cols_out: self.cols_out,
|
||||
size: self.size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> VmpPMat<D, B> {
|
||||
pub(crate) fn from_data(data: D, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
rows,
|
||||
cols_in,
|
||||
cols_out,
|
||||
size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
5
backend/src/hal/mod.rs
Normal file
5
backend/src/hal/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
pub mod api;
|
||||
pub mod delegates;
|
||||
pub mod layouts;
|
||||
pub mod oep;
|
||||
pub mod tests;
|
||||
20
backend/src/hal/oep/mat_znx.rs
Normal file
20
backend/src/hal/oep/mat_znx.rs
Normal file
@@ -0,0 +1,20 @@
|
||||
use crate::hal::layouts::{Backend, MatZnxOwned, Module};
|
||||
|
||||
pub unsafe trait MatZnxAllocImpl<B: Backend> {
|
||||
fn mat_znx_alloc_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned;
|
||||
}
|
||||
|
||||
pub unsafe trait MatZnxAllocBytesImpl<B: Backend> {
|
||||
fn mat_znx_alloc_bytes_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub unsafe trait MatZnxFromBytesImpl<B: Backend> {
|
||||
fn mat_znx_from_bytes_impl(
|
||||
module: &Module<B>,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: Vec<u8>,
|
||||
) -> MatZnxOwned;
|
||||
}
|
||||
19
backend/src/hal/oep/mod.rs
Normal file
19
backend/src/hal/oep/mod.rs
Normal file
@@ -0,0 +1,19 @@
|
||||
mod mat_znx;
|
||||
mod module;
|
||||
mod scalar_znx;
|
||||
mod scratch;
|
||||
mod svp_ppol;
|
||||
mod vec_znx;
|
||||
mod vec_znx_big;
|
||||
mod vec_znx_dft;
|
||||
mod vmp_pmat;
|
||||
|
||||
pub use mat_znx::*;
|
||||
pub use module::*;
|
||||
pub use scalar_znx::*;
|
||||
pub use scratch::*;
|
||||
pub use svp_ppol::*;
|
||||
pub use vec_znx::*;
|
||||
pub use vec_znx_big::*;
|
||||
pub use vec_znx_dft::*;
|
||||
pub use vmp_pmat::*;
|
||||
5
backend/src/hal/oep/module.rs
Normal file
5
backend/src/hal/oep/module.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
use crate::hal::layouts::{Backend, Module};
|
||||
|
||||
pub unsafe trait ModuleNewImpl<B: Backend> {
|
||||
fn new_impl(n: u64) -> Module<B>;
|
||||
}
|
||||
39
backend/src/hal/oep/scalar_znx.rs
Normal file
39
backend/src/hal/oep/scalar_znx.rs
Normal file
@@ -0,0 +1,39 @@
|
||||
use crate::hal::layouts::{Backend, Module, ScalarZnxOwned, ScalarZnxToMut, ScalarZnxToRef};
|
||||
|
||||
pub unsafe trait ScalarZnxFromBytesImpl<B: Backend> {
|
||||
fn scalar_znx_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned;
|
||||
}
|
||||
|
||||
pub unsafe trait ScalarZnxAllocBytesImpl<B: Backend> {
|
||||
fn scalar_znx_alloc_bytes_impl(n: usize, cols: usize) -> usize;
|
||||
}
|
||||
|
||||
pub unsafe trait ScalarZnxAllocImpl<B: Backend> {
|
||||
fn scalar_znx_alloc_impl(n: usize, cols: usize) -> ScalarZnxOwned;
|
||||
}
|
||||
|
||||
pub unsafe trait ScalarZnxAutomorphismImpl<B: Backend> {
|
||||
fn scalar_znx_automorphism_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: ScalarZnxToMut,
|
||||
A: ScalarZnxToRef;
|
||||
}
|
||||
|
||||
pub unsafe trait ScalarZnxAutomorphismInplaceIml<B: Backend> {
|
||||
fn scalar_znx_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: ScalarZnxToMut;
|
||||
}
|
||||
|
||||
pub unsafe trait ScalarZnxMulXpMinusOneImpl<B: Backend> {
|
||||
fn scalar_znx_mul_xp_minus_one_impl<R, A>(module: &Module<B>, p: i64, r: &mut R, r_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: ScalarZnxToMut,
|
||||
A: ScalarZnxToRef;
|
||||
}
|
||||
|
||||
pub unsafe trait ScalarZnxMulXpMinusOneInplaceImpl<B: Backend> {
|
||||
fn scalar_znx_mul_xp_minus_one_inplace_impl<R>(module: &Module<B>, p: i64, r: &mut R, r_col: usize)
|
||||
where
|
||||
R: ScalarZnxToMut;
|
||||
}
|
||||
199
backend/src/hal/oep/scratch.rs
Normal file
199
backend/src/hal/oep/scratch.rs
Normal file
@@ -0,0 +1,199 @@
|
||||
use crate::hal::{
|
||||
api::ZnxInfos,
|
||||
layouts::{Backend, DataRef, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
|
||||
};
|
||||
|
||||
pub unsafe trait ScratchOwnedAllocImpl<B: Backend> {
|
||||
fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait ScratchOwnedBorrowImpl<B: Backend> {
|
||||
fn scratch_owned_borrow_impl(scratch: &mut ScratchOwned<B>) -> &mut Scratch<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait ScratchFromBytesImpl<B: Backend> {
|
||||
fn scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait ScratchAvailableImpl<B: Backend> {
|
||||
fn scratch_available_impl(scratch: &Scratch<B>) -> usize;
|
||||
}
|
||||
|
||||
pub unsafe trait TakeSliceImpl<B: Backend> {
|
||||
fn take_slice_impl<T>(scratch: &mut Scratch<B>, len: usize) -> (&mut [T], &mut Scratch<B>);
|
||||
}
|
||||
|
||||
pub unsafe trait TakeScalarZnxImpl<B: Backend> {
|
||||
fn take_scalar_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch<B>);
|
||||
}
|
||||
|
||||
pub unsafe trait TakeSvpPPolImpl<B: Backend> {
|
||||
fn take_svp_ppol_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch<B>);
|
||||
}
|
||||
|
||||
pub unsafe trait TakeVecZnxImpl<B: Backend> {
|
||||
fn take_vec_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch<B>);
|
||||
}
|
||||
|
||||
pub unsafe trait TakeVecZnxSliceImpl<B: Backend> {
|
||||
fn take_vec_znx_slice_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
len: usize,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnx<&mut [u8]>>, &mut Scratch<B>);
|
||||
}
|
||||
|
||||
pub unsafe trait TakeVecZnxBigImpl<B: Backend> {
|
||||
fn take_vec_znx_big_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (VecZnxBig<&mut [u8], B>, &mut Scratch<B>);
|
||||
}
|
||||
|
||||
pub unsafe trait TakeVecZnxDftImpl<B: Backend> {
|
||||
fn take_vec_znx_dft_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (VecZnxDft<&mut [u8], B>, &mut Scratch<B>);
|
||||
}
|
||||
|
||||
pub unsafe trait TakeVecZnxDftSliceImpl<B: Backend> {
|
||||
fn take_vec_znx_dft_slice_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
len: usize,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Scratch<B>);
|
||||
}
|
||||
|
||||
pub unsafe trait TakeVmpPMatImpl<B: Backend> {
|
||||
fn take_vmp_pmat_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (VmpPMat<&mut [u8], B>, &mut Scratch<B>);
|
||||
}
|
||||
|
||||
pub unsafe trait TakeMatZnxImpl<B: Backend> {
|
||||
fn take_mat_znx_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (MatZnx<&mut [u8]>, &mut Scratch<B>);
|
||||
}
|
||||
|
||||
pub trait TakeLikeImpl<'a, B: Backend, T> {
|
||||
type Output;
|
||||
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &T) -> (Self::Output, &'a mut Scratch<B>);
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VmpPMat<D, B>> for B
|
||||
where
|
||||
B: TakeVmpPMatImpl<B>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = VmpPMat<&'a mut [u8], B>;
|
||||
|
||||
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &VmpPMat<D, B>) -> (Self::Output, &'a mut Scratch<B>) {
|
||||
B::take_vmp_pmat_impl(
|
||||
scratch,
|
||||
template.n(),
|
||||
template.rows(),
|
||||
template.cols_in(),
|
||||
template.cols_out(),
|
||||
template.size(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, MatZnx<D>> for B
|
||||
where
|
||||
B: TakeMatZnxImpl<B>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = MatZnx<&'a mut [u8]>;
|
||||
|
||||
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &MatZnx<D>) -> (Self::Output, &'a mut Scratch<B>) {
|
||||
B::take_mat_znx_impl(
|
||||
scratch,
|
||||
template.n(),
|
||||
template.rows(),
|
||||
template.cols_in(),
|
||||
template.cols_out(),
|
||||
template.size(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VecZnxDft<D, B>> for B
|
||||
where
|
||||
B: TakeVecZnxDftImpl<B>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = VecZnxDft<&'a mut [u8], B>;
|
||||
|
||||
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &VecZnxDft<D, B>) -> (Self::Output, &'a mut Scratch<B>) {
|
||||
B::take_vec_znx_dft_impl(scratch, template.n(), template.cols(), template.size())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VecZnxBig<D, B>> for B
|
||||
where
|
||||
B: TakeVecZnxBigImpl<B>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = VecZnxBig<&'a mut [u8], B>;
|
||||
|
||||
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &VecZnxBig<D, B>) -> (Self::Output, &'a mut Scratch<B>) {
|
||||
B::take_vec_znx_big_impl(scratch, template.n(), template.cols(), template.size())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, SvpPPol<D, B>> for B
|
||||
where
|
||||
B: TakeSvpPPolImpl<B>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = SvpPPol<&'a mut [u8], B>;
|
||||
|
||||
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &SvpPPol<D, B>) -> (Self::Output, &'a mut Scratch<B>) {
|
||||
B::take_svp_ppol_impl(scratch, template.n(), template.cols())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VecZnx<D>> for B
|
||||
where
|
||||
B: TakeVecZnxImpl<B>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = VecZnx<&'a mut [u8]>;
|
||||
|
||||
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &VecZnx<D>) -> (Self::Output, &'a mut Scratch<B>) {
|
||||
B::take_vec_znx_impl(scratch, template.n(), template.cols(), template.size())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, ScalarZnx<D>> for B
|
||||
where
|
||||
B: TakeScalarZnxImpl<B>,
|
||||
D: DataRef,
|
||||
{
|
||||
type Output = ScalarZnx<&'a mut [u8]>;
|
||||
|
||||
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &ScalarZnx<D>) -> (Self::Output, &'a mut Scratch<B>) {
|
||||
B::take_scalar_znx_impl(scratch, template.n(), template.cols())
|
||||
}
|
||||
}
|
||||
37
backend/src/hal/oep/svp_ppol.rs
Normal file
37
backend/src/hal/oep/svp_ppol.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use crate::hal::layouts::{
|
||||
Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef,
|
||||
};
|
||||
|
||||
pub unsafe trait SvpPPolFromBytesImpl<B: Backend> {
|
||||
fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait SvpPPolAllocImpl<B: Backend> {
|
||||
fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait SvpPPolAllocBytesImpl<B: Backend> {
|
||||
fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize;
|
||||
}
|
||||
|
||||
pub unsafe trait SvpPrepareImpl<B: Backend> {
|
||||
fn svp_prepare_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: SvpPPolToMut<B>,
|
||||
A: ScalarZnxToRef;
|
||||
}
|
||||
|
||||
pub unsafe trait SvpApplyImpl<B: Backend> {
|
||||
fn svp_apply_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: SvpPPolToRef<B>,
|
||||
C: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait SvpApplyInplaceImpl: Backend {
|
||||
fn svp_apply_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: SvpPPolToRef<Self>;
|
||||
}
|
||||
465
backend/src/hal/oep/vec_znx.rs
Normal file
465
backend/src/hal/oep/vec_znx.rs
Normal file
@@ -0,0 +1,465 @@
|
||||
use rand_distr::Distribution;
|
||||
use rug::Float;
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::hal::layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxOwned, VecZnxToMut, VecZnxToRef};
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::hal::layouts::VecZnx::new] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxAlloc] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
/// * See test \[TODO\]
|
||||
pub unsafe trait VecZnxAllocImpl<B: Backend> {
|
||||
fn vec_znx_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxOwned;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::hal::layouts::VecZnx::from_bytes] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxFromBytes] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxFromBytesImpl<B: Backend> {
|
||||
fn vec_znx_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::hal::layouts::VecZnx::alloc_bytes] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxAllocBytes] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxAllocBytesImpl<B: Backend> {
|
||||
fn vec_znx_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_normalize_base2k_tmp_bytes_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L245C17-L245C55) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxNormalizeTmpBytes] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxNormalizeTmpBytesImpl<B: Backend> {
|
||||
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<B>, n: usize) -> usize;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L212) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxNormalize] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxNormalizeImpl<B: Backend> {
|
||||
fn vec_znx_normalize_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L212) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxNormalizeInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxNormalizeInplaceImpl<B: Backend> {
|
||||
fn vec_znx_normalize_inplace_impl<A>(module: &Module<B>, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxAdd] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxAddImpl<B: Backend> {
|
||||
fn vec_znx_add_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
C: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxAddInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxAddInplaceImpl<B: Backend> {
|
||||
fn vec_znx_add_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxAddScalarInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxAddScalarInplaceImpl<B: Backend> {
|
||||
fn vec_znx_add_scalar_inplace_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
res_limb: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxSub] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxSubImpl<B: Backend> {
|
||||
fn vec_znx_sub_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
C: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxSubABInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxSubABInplaceImpl<B: Backend> {
|
||||
fn vec_znx_sub_ab_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxSubBAInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxSubBAInplaceImpl<B: Backend> {
|
||||
fn vec_znx_sub_ba_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxSubScalarInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxSubScalarInplaceImpl<B: Backend> {
|
||||
fn vec_znx_sub_scalar_inplace_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
res_limb: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_negate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L322C13-L322C31) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxNegate] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxNegateImpl<B: Backend> {
|
||||
fn vec_znx_negate_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_negate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L322C13-L322C31) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxNegateInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxNegateInplaceImpl<B: Backend> {
|
||||
fn vec_znx_negate_inplace_impl<A>(module: &Module<B>, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_rsh_inplace_ref] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxRshInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxRshInplaceImpl<B: Backend> {
|
||||
fn vec_znx_rsh_inplace_impl<A>(module: &Module<B>, basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_lsh_inplace_ref] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxLshInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxLshInplaceImpl<B: Backend> {
|
||||
fn vec_znx_lsh_inplace_impl<A>(module: &Module<B>, basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_rotate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L164) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxRotate] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxRotateImpl<B: Backend> {
|
||||
fn vec_znx_rotate_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_rotate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L164) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxRotateInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxRotateInplaceImpl<B: Backend> {
|
||||
fn vec_znx_rotate_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_automorphism_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L188) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxAutomorphism] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxAutomorphismImpl<B: Backend> {
|
||||
fn vec_znx_automorphism_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_automorphism_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L188) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxAutomorphismInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxAutomorphismInplaceImpl<B: Backend> {
|
||||
fn vec_znx_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_mul_xp_minus_one_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/7160f588da49712a042931ea247b4259b95cefcc/spqlios/arithmetic/vec_znx.c#L200C13-L200C41) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxMulXpMinusOne] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxMulXpMinusOneImpl<B: Backend> {
|
||||
fn vec_znx_mul_xp_minus_one_impl<R, A>(module: &Module<B>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [vec_znx_mul_xp_minus_one_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/7160f588da49712a042931ea247b4259b95cefcc/spqlios/arithmetic/vec_znx.c#L200C13-L200C41) for reference code.
|
||||
/// * See [crate::hal::api::VecZnxMulXpMinusOneInplace] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxMulXpMinusOneInplaceImpl<B: Backend> {
|
||||
fn vec_znx_mul_xp_minus_one_inplace_impl<R>(module: &Module<B>, p: i64, res: &mut R, res_col: usize)
|
||||
where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_split_ref] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxSplit] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxSplitImpl<B: Backend> {
|
||||
fn vec_znx_split_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
res: &mut Vec<R>,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_merge_ref] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxMerge] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxMergeImpl<B: Backend> {
|
||||
fn vec_znx_merge_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: Vec<A>, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_switch_degree_ref] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxSwithcDegree] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxSwithcDegreeImpl<B: Backend> {
|
||||
fn vec_znx_switch_degree_impl<R: VecZnxToMut, A: VecZnxToRef>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
);
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::implementation::cpu_spqlios::vec_znx::vec_znx_copy_ref] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxCopy] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxCopyImpl<B: Backend> {
|
||||
fn vec_znx_copy_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::hal::api::VecZnxStd] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxStdImpl<B: Backend> {
|
||||
fn vec_znx_std_impl<A>(module: &Module<B>, basek: usize, a: &A, a_col: usize) -> f64
|
||||
where
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::hal::api::VecZnxFillUniform] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxFillUniformImpl<B: Backend> {
|
||||
fn vec_znx_fill_uniform_impl<R>(module: &Module<B>, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
|
||||
where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::hal::api::VecZnxFillDistF64] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxFillDistF64Impl<B: Backend> {
|
||||
fn vec_znx_fill_dist_f64_impl<R, D: Distribution<f64>>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::hal::api::VecZnxAddDistF64] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxAddDistF64Impl<B: Backend> {
|
||||
fn vec_znx_add_dist_f64_impl<R, D: Distribution<f64>>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::hal::api::VecZnxFillNormal] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxFillNormalImpl<B: Backend> {
|
||||
fn vec_znx_fill_normal_impl<R>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::hal::api::VecZnxAddNormal] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxAddNormalImpl<B: Backend> {
|
||||
fn vec_znx_add_normal_impl<R>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See \[TODO\] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxEncodeVeci64] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxEncodeVeci64Impl<B: Backend> {
|
||||
fn encode_vec_i64_impl<R>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
data: &[i64],
|
||||
log_max: usize,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See \[TODO\] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxEncodeCoeffsi64] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxEncodeCoeffsi64Impl<B: Backend> {
|
||||
fn encode_coeff_i64_impl<R>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
i: usize,
|
||||
data: i64,
|
||||
log_max: usize,
|
||||
) where
|
||||
R: VecZnxToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See \[TODO\] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxDecodeVeci64] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxDecodeVeci64Impl<B: Backend> {
|
||||
fn decode_vec_i64_impl<R>(module: &Module<B>, basek: usize, res: &R, res_col: usize, k: usize, data: &mut [i64])
|
||||
where
|
||||
R: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See \[TODO\] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxDecodeCoeffsi64] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxDecodeCoeffsi64Impl<B: Backend> {
|
||||
fn decode_coeff_i64_impl<R>(module: &Module<B>, basek: usize, res: &R, res_col: usize, k: usize, i: usize) -> i64
|
||||
where
|
||||
R: VecZnxToRef;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See \[TODO\] for reference code.
|
||||
/// * See [crate::hal::api::VecZnxDecodeVecFloat] for corresponding public API.
|
||||
/// * See [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxDecodeVecFloatImpl<B: Backend> {
|
||||
fn decode_vec_float_impl<R>(module: &Module<B>, basek: usize, res: &R, res_col: usize, data: &mut [Float])
|
||||
where
|
||||
R: VecZnxToRef;
|
||||
}
|
||||
208
backend/src/hal/oep/vec_znx_big.rs
Normal file
208
backend/src/hal/oep/vec_znx_big.rs
Normal file
@@ -0,0 +1,208 @@
|
||||
use rand_distr::Distribution;
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::hal::layouts::{Backend, Module, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef};
|
||||
|
||||
pub unsafe trait VecZnxBigAllocImpl<B: Backend> {
|
||||
fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigFromBytesImpl<B: Backend> {
|
||||
fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigAllocBytesImpl<B: Backend> {
|
||||
fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigAddNormalImpl<B: Backend> {
|
||||
fn add_normal_impl<R: VecZnxBigToMut<B>>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigFillNormalImpl<B: Backend> {
|
||||
fn fill_normal_impl<R: VecZnxBigToMut<B>>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigFillDistF64Impl<B: Backend> {
|
||||
fn fill_dist_f64_impl<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigAddDistF64Impl<B: Backend> {
|
||||
fn add_dist_f64_impl<R: VecZnxBigToMut<B>, D: Distribution<f64>>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigAddImpl<B: Backend> {
|
||||
fn vec_znx_big_add_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
C: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigAddInplaceImpl<B: Backend> {
|
||||
fn vec_znx_big_add_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigAddSmallImpl<B: Backend> {
|
||||
fn vec_znx_big_add_small_impl<R, A, C>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &C,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
C: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigAddSmallInplaceImpl<B: Backend> {
|
||||
fn vec_znx_big_add_small_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigSubImpl<B: Backend> {
|
||||
fn vec_znx_big_sub_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
C: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigSubABInplaceImpl<B: Backend> {
|
||||
fn vec_znx_big_sub_ab_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigSubBAInplaceImpl<B: Backend> {
|
||||
fn vec_znx_big_sub_ba_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigSubSmallAImpl<B: Backend> {
|
||||
fn vec_znx_big_sub_small_a_impl<R, A, C>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &C,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef,
|
||||
C: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigSubSmallAInplaceImpl<B: Backend> {
|
||||
fn vec_znx_big_sub_small_a_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigSubSmallBImpl<B: Backend> {
|
||||
fn vec_znx_big_sub_small_b_impl<R, A, C>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &C,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>,
|
||||
C: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigSubSmallBInplaceImpl<B: Backend> {
|
||||
fn vec_znx_big_sub_small_b_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigNegateInplaceImpl<B: Backend> {
|
||||
fn vec_znx_big_negate_inplace_impl<A>(module: &Module<B>, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigNormalizeTmpBytesImpl<B: Backend> {
|
||||
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<B>, n: usize) -> usize;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigNormalizeImpl<B: Backend> {
|
||||
fn vec_znx_big_normalize_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigAutomorphismImpl<B: Backend> {
|
||||
fn vec_znx_big_automorphism_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxBigToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxBigAutomorphismInplaceImpl<B: Backend> {
|
||||
fn vec_znx_big_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<B>;
|
||||
}
|
||||
117
backend/src/hal/oep/vec_znx_dft.rs
Normal file
117
backend/src/hal/oep/vec_znx_dft.rs
Normal file
@@ -0,0 +1,117 @@
|
||||
use crate::hal::layouts::{
|
||||
Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef,
|
||||
VecZnxToRef,
|
||||
};
|
||||
|
||||
pub unsafe trait VecZnxDftAllocImpl<B: Backend> {
|
||||
fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftFromBytesImpl<B: Backend> {
|
||||
fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftAllocBytesImpl<B: Backend> {
|
||||
fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftToVecZnxBigTmpBytesImpl<B: Backend> {
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<B>) -> usize;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftToVecZnxBigImpl<B: Backend> {
|
||||
fn vec_znx_dft_to_vec_znx_big_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftToVecZnxBigTmpAImpl<B: Backend> {
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_a_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxDftToMut<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftToVecZnxBigConsumeImpl<B: Backend> {
|
||||
fn vec_znx_dft_to_vec_znx_big_consume_impl<D: Data>(module: &Module<B>, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
|
||||
where
|
||||
VecZnxDft<D, B>: VecZnxDftToMut<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftAddImpl<B: Backend> {
|
||||
fn vec_znx_dft_add_impl<R, A, D>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
D: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftAddInplaceImpl<B: Backend> {
|
||||
fn vec_znx_dft_add_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftSubImpl<B: Backend> {
|
||||
fn vec_znx_dft_sub_impl<R, A, D>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
D: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftSubABInplaceImpl<B: Backend> {
|
||||
fn vec_znx_dft_sub_ab_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftSubBAInplaceImpl<B: Backend> {
|
||||
fn vec_znx_dft_sub_ba_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftCopyImpl<B: Backend> {
|
||||
fn vec_znx_dft_copy_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
step: usize,
|
||||
offset: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftFromVecZnxImpl<B: Backend> {
|
||||
fn vec_znx_dft_from_vec_znx_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
step: usize,
|
||||
offset: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub unsafe trait VecZnxDftZeroImpl<B: Backend> {
|
||||
fn vec_znx_dft_zero_impl<R>(module: &Module<B>, res: &mut R)
|
||||
where
|
||||
R: VecZnxDftToMut<B>;
|
||||
}
|
||||
74
backend/src/hal/oep/vmp_pmat.rs
Normal file
74
backend/src/hal/oep/vmp_pmat.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
use crate::hal::layouts::{
|
||||
Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef,
|
||||
};
|
||||
|
||||
pub unsafe trait VmpPMatAllocImpl<B: Backend> {
|
||||
fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VmpPMatAllocBytesImpl<B: Backend> {
|
||||
fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub unsafe trait VmpPMatFromBytesImpl<B: Backend> {
|
||||
fn vmp_pmat_from_bytes_impl(
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: Vec<u8>,
|
||||
) -> VmpPMatOwned<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VmpPrepareTmpBytesImpl<B: Backend> {
|
||||
fn vmp_prepare_tmp_bytes_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub unsafe trait VmpPMatPrepareImpl<B: Backend> {
|
||||
fn vmp_prepare_impl<R, A>(module: &Module<B>, res: &mut R, a: &A, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VmpPMatToMut<B>,
|
||||
A: MatZnxToRef;
|
||||
}
|
||||
|
||||
pub unsafe trait VmpApplyTmpBytesImpl<B: Backend> {
|
||||
fn vmp_apply_tmp_bytes_impl(
|
||||
module: &Module<B>,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize;
|
||||
}
|
||||
|
||||
pub unsafe trait VmpApplyImpl<B: Backend> {
|
||||
fn vmp_apply_impl<R, A, C>(module: &Module<B>, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
C: VmpPMatToRef<B>;
|
||||
}
|
||||
|
||||
pub unsafe trait VmpApplyAddTmpBytesImpl<B: Backend> {
|
||||
fn vmp_apply_add_tmp_bytes_impl(
|
||||
module: &Module<B>,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize;
|
||||
}
|
||||
|
||||
pub unsafe trait VmpApplyAddImpl<B: Backend> {
|
||||
// Same as [MatZnxDftOps::vmp_apply] except result is added on R instead of overwritting R.
|
||||
fn vmp_apply_add_impl<R, A, C>(module: &Module<B>, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
C: VmpPMatToRef<B>;
|
||||
}
|
||||
1
backend/src/hal/tests/mod.rs
Normal file
1
backend/src/hal/tests/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod vec_znx;
|
||||
120
backend/src/hal/tests/vec_znx/generics.rs
Normal file
120
backend/src/hal/tests/vec_znx/generics.rs
Normal file
@@ -0,0 +1,120 @@
|
||||
use itertools::izip;
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::hal::{
|
||||
api::{
|
||||
VecZnxAddNormal, VecZnxAlloc, VecZnxDecodeVeci64, VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxStd, ZnxInfos, ZnxView,
|
||||
ZnxViewMut,
|
||||
},
|
||||
layouts::{Backend, Module, VecZnx},
|
||||
};
|
||||
|
||||
pub fn test_vec_znx_fill_uniform<B: Backend>(module: &Module<B>)
|
||||
where
|
||||
Module<B>: VecZnxFillUniform + VecZnxStd + VecZnxAlloc,
|
||||
{
|
||||
let basek: usize = 17;
|
||||
let size: usize = 5;
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
let zero: Vec<i64> = vec![0; module.n()];
|
||||
let one_12_sqrt: f64 = 0.28867513459481287;
|
||||
(0..cols).for_each(|col_i| {
|
||||
let mut a: VecZnx<_> = module.vec_znx_alloc(cols, size);
|
||||
module.vec_znx_fill_uniform(basek, &mut a, col_i, size * basek, &mut source);
|
||||
(0..cols).for_each(|col_j| {
|
||||
if col_j != col_i {
|
||||
(0..size).for_each(|limb_i| {
|
||||
assert_eq!(a.at(col_j, limb_i), zero);
|
||||
})
|
||||
} else {
|
||||
let std: f64 = module.vec_znx_std(basek, &a, col_i);
|
||||
assert!(
|
||||
(std - one_12_sqrt).abs() < 0.01,
|
||||
"std={} ~!= {}",
|
||||
std,
|
||||
one_12_sqrt
|
||||
);
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
pub fn test_vec_znx_add_normal<B: Backend>(module: &Module<B>)
|
||||
where
|
||||
Module<B>: VecZnxAddNormal + VecZnxStd + VecZnxAlloc,
|
||||
{
|
||||
let basek: usize = 17;
|
||||
let k: usize = 2 * 17;
|
||||
let size: usize = 5;
|
||||
let sigma: f64 = 3.2;
|
||||
let bound: f64 = 6.0 * sigma;
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
let zero: Vec<i64> = vec![0; module.n()];
|
||||
let k_f64: f64 = (1u64 << k as u64) as f64;
|
||||
(0..cols).for_each(|col_i| {
|
||||
let mut a: VecZnx<_> = module.vec_znx_alloc(cols, size);
|
||||
module.vec_znx_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
|
||||
(0..cols).for_each(|col_j| {
|
||||
if col_j != col_i {
|
||||
(0..size).for_each(|limb_i| {
|
||||
assert_eq!(a.at(col_j, limb_i), zero);
|
||||
})
|
||||
} else {
|
||||
let std: f64 = module.vec_znx_std(basek, &a, col_i) * k_f64;
|
||||
assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma);
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
pub fn test_vec_znx_encode_vec_i64_lo_norm<B: Backend>(module: &Module<B>)
|
||||
where
|
||||
Module<B>: VecZnxEncodeVeci64 + VecZnxDecodeVeci64 + VecZnxAlloc,
|
||||
{
|
||||
let basek: usize = 17;
|
||||
let size: usize = 5;
|
||||
let k: usize = size * basek - 5;
|
||||
let mut a: VecZnx<_> = module.vec_znx_alloc(2, size);
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let raw: &mut [i64] = a.raw_mut();
|
||||
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
(0..a.cols()).for_each(|col_i| {
|
||||
let mut have: Vec<i64> = vec![i64::default(); module.n()];
|
||||
have.iter_mut()
|
||||
.for_each(|x| *x = (source.next_i64() << 56) >> 56);
|
||||
module.encode_vec_i64(basek, &mut a, col_i, k, &have, 10);
|
||||
let mut want: Vec<i64> = vec![i64::default(); module.n()];
|
||||
module.decode_vec_i64(basek, &a, col_i, k, &mut want);
|
||||
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
|
||||
});
|
||||
}
|
||||
|
||||
pub fn test_vec_znx_encode_vec_i64_hi_norm<B: Backend>(module: &Module<B>)
|
||||
where
|
||||
Module<B>: VecZnxEncodeVeci64 + VecZnxDecodeVeci64 + VecZnxAlloc,
|
||||
{
|
||||
let basek: usize = 17;
|
||||
let size: usize = 5;
|
||||
for k in [1, basek / 2, size * basek - 5] {
|
||||
let mut a: VecZnx<_> = module.vec_znx_alloc(2, size);
|
||||
let mut source = Source::new([0u8; 32]);
|
||||
let raw: &mut [i64] = a.raw_mut();
|
||||
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
(0..a.cols()).for_each(|col_i| {
|
||||
let mut have: Vec<i64> = vec![i64::default(); module.n()];
|
||||
have.iter_mut().for_each(|x| {
|
||||
if k < 64 {
|
||||
*x = source.next_u64n(1 << k, (1 << k) - 1) as i64;
|
||||
} else {
|
||||
*x = source.next_i64();
|
||||
}
|
||||
});
|
||||
module.encode_vec_i64(basek, &mut a, col_i, k, &have, 63);
|
||||
let mut want: Vec<i64> = vec![i64::default(); module.n()];
|
||||
module.decode_vec_i64(basek, &a, col_i, k, &mut want);
|
||||
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
|
||||
})
|
||||
}
|
||||
}
|
||||
2
backend/src/hal/tests/vec_znx/mod.rs
Normal file
2
backend/src/hal/tests/vec_znx/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
mod generics;
|
||||
pub use generics::*;
|
||||
@@ -1,48 +1,47 @@
|
||||
use crate::ffi::module::MODULE;
|
||||
use crate::ffi::vec_znx_dft::VEC_ZNX_DFT;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct svp_ppol_t {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub type SVP_PPOL = svp_ppol_t;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn bytes_of_svp_ppol(module: *const MODULE) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_svp_ppol(module: *const MODULE) -> *mut SVP_PPOL;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn delete_svp_ppol(res: *mut SVP_PPOL);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn svp_prepare(module: *const MODULE, ppol: *mut SVP_PPOL, pol: *const i64);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn svp_apply_dft(
|
||||
module: *const MODULE,
|
||||
res: *const VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
ppol: *const SVP_PPOL,
|
||||
a: *const i64,
|
||||
a_size: u64,
|
||||
a_sl: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn svp_apply_dft_to_dft(
|
||||
module: *const MODULE,
|
||||
res: *const VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
res_cols: u64,
|
||||
ppol: *const SVP_PPOL,
|
||||
a: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
a_cols: u64,
|
||||
);
|
||||
}
|
||||
use crate::implementation::cpu_spqlios::ffi::{module::MODULE, vec_znx_dft::VEC_ZNX_DFT};
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct svp_ppol_t {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub type SVP_PPOL = svp_ppol_t;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn bytes_of_svp_ppol(module: *const MODULE) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_svp_ppol(module: *const MODULE) -> *mut SVP_PPOL;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn delete_svp_ppol(res: *mut SVP_PPOL);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn svp_prepare(module: *const MODULE, ppol: *mut SVP_PPOL, pol: *const i64);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn svp_apply_dft(
|
||||
module: *const MODULE,
|
||||
res: *const VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
ppol: *const SVP_PPOL,
|
||||
a: *const i64,
|
||||
a_size: u64,
|
||||
a_sl: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn svp_apply_dft_to_dft(
|
||||
module: *const MODULE,
|
||||
res: *const VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
res_cols: u64,
|
||||
ppol: *const SVP_PPOL,
|
||||
a: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
a_cols: u64,
|
||||
);
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::ffi::module::MODULE;
|
||||
use crate::implementation::cpu_spqlios::ffi::module::MODULE;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_add(
|
||||
@@ -28,6 +28,19 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_mul_xp_minus_one(
|
||||
module: *const MODULE,
|
||||
p: i64,
|
||||
res: *mut i64,
|
||||
res_size: u64,
|
||||
res_sl: u64,
|
||||
a: *const i64,
|
||||
a_size: u64,
|
||||
a_sl: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_negate(
|
||||
module: *const MODULE,
|
||||
@@ -86,6 +99,7 @@ unsafe extern "C" {
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_normalize_base2k(
|
||||
module: *const MODULE,
|
||||
n: u64,
|
||||
base2k: u64,
|
||||
res: *mut i64,
|
||||
res_size: u64,
|
||||
@@ -97,5 +111,5 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
|
||||
pub unsafe fn vec_znx_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64;
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::ffi::module::MODULE;
|
||||
use crate::implementation::cpu_spqlios::ffi::module::MODULE;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
@@ -103,12 +103,13 @@ unsafe extern "C" {
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
|
||||
pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_normalize_base2k(
|
||||
module: *const MODULE,
|
||||
n: u64,
|
||||
log2_base2k: u64,
|
||||
res: *mut i64,
|
||||
res_size: u64,
|
||||
@@ -122,6 +123,7 @@ unsafe extern "C" {
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_range_normalize_base2k(
|
||||
module: *const MODULE,
|
||||
n: u64,
|
||||
log2_base2k: u64,
|
||||
res: *mut i64,
|
||||
res_size: u64,
|
||||
@@ -135,7 +137,7 @@ unsafe extern "C" {
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
|
||||
pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
@@ -1,86 +1,85 @@
|
||||
use crate::ffi::module::MODULE;
|
||||
use crate::ffi::vec_znx_big::VEC_ZNX_BIG;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct vec_znx_dft_t {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub type VEC_ZNX_DFT = vec_znx_dft_t;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn bytes_of_vec_znx_dft(module: *const MODULE, size: u64) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_vec_znx_dft(module: *const MODULE, size: u64) -> *mut VEC_ZNX_DFT;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn delete_vec_znx_dft(res: *mut VEC_ZNX_DFT);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_dft_zero(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_dft_add(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
b: *const VEC_ZNX_DFT,
|
||||
b_size: u64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_dft_sub(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
b: *const VEC_ZNX_DFT,
|
||||
b_size: u64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_dft(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64, a: *const i64, a_size: u64, a_sl: u64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_idft(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a_dft: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
tmp: *mut u8,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_idft_tmp_bytes(module: *const MODULE) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_idft_tmp_a(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a_dft: *mut VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_dft_automorphism(
|
||||
module: *const MODULE,
|
||||
d: i64,
|
||||
res_dft: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a_dft: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
tmp: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_dft_automorphism_tmp_bytes(module: *const MODULE) -> u64;
|
||||
}
|
||||
use crate::implementation::cpu_spqlios::ffi::{module::MODULE, vec_znx_big::VEC_ZNX_BIG};
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct vec_znx_dft_t {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub type VEC_ZNX_DFT = vec_znx_dft_t;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn bytes_of_vec_znx_dft(module: *const MODULE, size: u64) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_vec_znx_dft(module: *const MODULE, size: u64) -> *mut VEC_ZNX_DFT;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn delete_vec_znx_dft(res: *mut VEC_ZNX_DFT);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_dft_zero(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_dft_add(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
b: *const VEC_ZNX_DFT,
|
||||
b_size: u64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_dft_sub(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
b: *const VEC_ZNX_DFT,
|
||||
b_size: u64,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_dft(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64, a: *const i64, a_size: u64, a_sl: u64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_idft(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a_dft: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
tmp: *mut u8,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_idft_tmp_bytes(module: *const MODULE) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_idft_tmp_a(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a_dft: *mut VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_dft_automorphism(
|
||||
module: *const MODULE,
|
||||
d: i64,
|
||||
res_dft: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a_dft: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
tmp: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_dft_automorphism_tmp_bytes(module: *const MODULE) -> u64;
|
||||
}
|
||||
@@ -1,167 +1,113 @@
|
||||
use crate::ffi::module::MODULE;
|
||||
use crate::ffi::vec_znx_big::VEC_ZNX_BIG;
|
||||
use crate::ffi::vec_znx_dft::VEC_ZNX_DFT;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct vmp_pmat_t {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
|
||||
// [rows][cols] = [#Decomposition][#Limbs]
|
||||
pub type VMP_PMAT = vmp_pmat_t;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn bytes_of_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> *mut VMP_PMAT;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn delete_vmp_pmat(res: *mut VMP_PMAT);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a: *const i64,
|
||||
a_size: u64,
|
||||
a_sl: u64,
|
||||
pmat: *const VMP_PMAT,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_add(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a: *const i64,
|
||||
a_size: u64,
|
||||
a_sl: u64,
|
||||
pmat: *const VMP_PMAT,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
pmat_scale: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_tmp_bytes(module: *const MODULE, res_size: u64, a_size: u64, nrows: u64, ncols: u64) -> u64;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_to_dft(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a_dft: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
pmat: *const VMP_PMAT,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_to_dft_add(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a_dft: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
pmat: *const VMP_PMAT,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
pmat_scale: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_to_dft_tmp_bytes(
|
||||
module: *const MODULE,
|
||||
res_size: u64,
|
||||
a_size: u64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
) -> u64;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_contiguous(
|
||||
module: *const MODULE,
|
||||
pmat: *mut VMP_PMAT,
|
||||
mat: *const i64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_dblptr(
|
||||
module: *const MODULE,
|
||||
pmat: *mut VMP_PMAT,
|
||||
mat: *const *const i64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_row(
|
||||
module: *const MODULE,
|
||||
pmat: *mut VMP_PMAT,
|
||||
row: *const i64,
|
||||
row_i: u64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_row_dft(
|
||||
module: *const MODULE,
|
||||
pmat: *mut VMP_PMAT,
|
||||
row: *const VEC_ZNX_DFT,
|
||||
row_i: u64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_extract_row_dft(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
pmat: *const VMP_PMAT,
|
||||
row_i: u64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_extract_row(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
pmat: *const VMP_PMAT,
|
||||
row_i: u64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
|
||||
}
|
||||
use crate::implementation::cpu_spqlios::ffi::{module::MODULE, vec_znx_dft::VEC_ZNX_DFT};
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct vmp_pmat_t {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
|
||||
// [rows][cols] = [#Decomposition][#Limbs]
|
||||
pub type VMP_PMAT = vmp_pmat_t;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn bytes_of_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn new_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> *mut VMP_PMAT;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn delete_vmp_pmat(res: *mut VMP_PMAT);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a: *const i64,
|
||||
a_size: u64,
|
||||
a_sl: u64,
|
||||
pmat: *const VMP_PMAT,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_add(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a: *const i64,
|
||||
a_size: u64,
|
||||
a_sl: u64,
|
||||
pmat: *const VMP_PMAT,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
pmat_scale: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_tmp_bytes(module: *const MODULE, res_size: u64, a_size: u64, nrows: u64, ncols: u64) -> u64;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_to_dft(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a_dft: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
pmat: *const VMP_PMAT,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_to_dft_add(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a_dft: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
pmat: *const VMP_PMAT,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
pmat_scale: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_apply_dft_to_dft_tmp_bytes(
|
||||
module: *const MODULE,
|
||||
res_size: u64,
|
||||
a_size: u64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
) -> u64;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_contiguous(
|
||||
module: *const MODULE,
|
||||
pmat: *mut VMP_PMAT,
|
||||
mat: *const i64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_contiguous_dft(module: *const MODULE, pmat: *mut VMP_PMAT, mat: *const f64, nrows: u64, ncols: u64);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
|
||||
}
|
||||
@@ -1,76 +1,79 @@
|
||||
use crate::ffi::module::MODULE;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_add_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_add_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_sub_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_sub_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_negate_i64_ref(nn: u64, res: *mut i64, a: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_negate_i64_avx(nn: u64, res: *mut i64, a: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_copy_i64_ref(nn: u64, res: *mut i64, a: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_zero_i64_ref(nn: u64, res: *mut i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_divide_by_m_ref(nn: u64, m: f64, res: *mut f64, a: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_divide_by_m_avx(nn: u64, m: f64, res: *mut f64, a: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_rotate_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_rotate_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_rotate_inplace_f64(nn: u64, p: i64, res: *mut f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_rotate_inplace_i64(nn: u64, p: i64, res: *mut i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_automorphism_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_automorphism_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_automorphism_inplace_f64(nn: u64, p: i64, res: *mut f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_automorphism_inplace_i64(nn: u64, p: i64, res: *mut i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_mul_xp_minus_one(nn: u64, p: i64, res: *mut f64, in_: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_mul_xp_minus_one(nn: u64, p: i64, res: *mut i64, in_: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_mul_xp_minus_one_inplace(nn: u64, p: i64, res: *mut f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_normalize(nn: u64, base_k: u64, out: *mut i64, carry_out: *mut i64, in_: *const i64, carry_in: *const i64);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_small_single_product(module: *const MODULE, res: *mut i64, a: *const i64, b: *const i64, tmp: *mut u8);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_small_single_product_tmp_bytes(module: *const MODULE) -> u64;
|
||||
}
|
||||
use crate::implementation::cpu_spqlios::ffi::module::MODULE;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_add_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_add_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_sub_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_sub_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_negate_i64_ref(nn: u64, res: *mut i64, a: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_negate_i64_avx(nn: u64, res: *mut i64, a: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_copy_i64_ref(nn: u64, res: *mut i64, a: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_zero_i64_ref(nn: u64, res: *mut i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_divide_by_m_ref(nn: u64, m: f64, res: *mut f64, a: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_divide_by_m_avx(nn: u64, m: f64, res: *mut f64, a: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_rotate_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_rotate_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_rotate_inplace_f64(nn: u64, p: i64, res: *mut f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_rotate_inplace_i64(nn: u64, p: i64, res: *mut i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_automorphism_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_automorphism_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_automorphism_inplace_f64(nn: u64, p: i64, res: *mut f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_automorphism_inplace_i64(nn: u64, p: i64, res: *mut i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_mul_xp_minus_one_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_mul_xp_minus_one_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn rnx_mul_xp_minus_one_inplace_f64(nn: u64, p: i64, res: *mut f64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_mul_xp_minus_one_inplace_i64(nn: u64, p: i64, res: *mut i64);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_normalize(nn: u64, base_k: u64, out: *mut i64, carry_out: *mut i64, in_: *const i64, carry_in: *const i64);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_small_single_product(module: *const MODULE, res: *mut i64, a: *const i64, b: *const i64, tmp: *mut u8);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn znx_small_single_product_tmp_bytes(module: *const MODULE) -> u64;
|
||||
}
|
||||
41
backend/src/implementation/cpu_spqlios/mat_znx.rs
Normal file
41
backend/src/implementation/cpu_spqlios/mat_znx.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
layouts::{Backend, MatZnxOwned, Module},
|
||||
oep::{MatZnxAllocBytesImpl, MatZnxAllocImpl, MatZnxFromBytesImpl},
|
||||
},
|
||||
implementation::cpu_spqlios::CPUAVX,
|
||||
};
|
||||
|
||||
unsafe impl<B: Backend> MatZnxAllocImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn mat_znx_alloc_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned {
|
||||
MatZnxOwned::new(module.n(), rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> MatZnxAllocBytesImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn mat_znx_alloc_bytes_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
MatZnxOwned::bytes_of(module.n(), rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> MatZnxFromBytesImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn mat_znx_from_bytes_impl(
|
||||
module: &Module<B>,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: Vec<u8>,
|
||||
) -> MatZnxOwned {
|
||||
MatZnxOwned::new_from_bytes(module.n(), rows, cols_in, cols_out, size, bytes)
|
||||
}
|
||||
}
|
||||
26
backend/src/implementation/cpu_spqlios/mod.rs
Normal file
26
backend/src/implementation/cpu_spqlios/mod.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
mod ffi;
|
||||
mod mat_znx;
|
||||
mod module_fft64;
|
||||
mod module_ntt120;
|
||||
mod scalar_znx;
|
||||
mod scratch;
|
||||
mod svp_ppol_fft64;
|
||||
mod svp_ppol_ntt120;
|
||||
mod vec_znx;
|
||||
mod vec_znx_big_fft64;
|
||||
mod vec_znx_big_ntt120;
|
||||
mod vec_znx_dft_fft64;
|
||||
mod vec_znx_dft_ntt120;
|
||||
mod vmp_pmat_fft64;
|
||||
mod vmp_pmat_ntt120;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
pub use module_fft64::*;
|
||||
pub use module_ntt120::*;
|
||||
|
||||
/// For external documentation
|
||||
pub use vec_znx::{vec_znx_copy_ref, vec_znx_lsh_inplace_ref, vec_znx_merge_ref, vec_znx_rsh_inplace_ref, vec_znx_split_ref, vec_znx_switch_degree_ref};
|
||||
|
||||
pub trait CPUAVX {}
|
||||
29
backend/src/implementation/cpu_spqlios/module_fft64.rs
Normal file
29
backend/src/implementation/cpu_spqlios/module_fft64.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use std::ptr::NonNull;
|
||||
|
||||
use crate::{
|
||||
hal::{
|
||||
layouts::{Backend, Module},
|
||||
oep::ModuleNewImpl,
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
CPUAVX,
|
||||
ffi::module::{MODULE, delete_module_info, new_module_info},
|
||||
},
|
||||
};
|
||||
|
||||
pub struct FFT64;
|
||||
|
||||
impl CPUAVX for FFT64 {}
|
||||
|
||||
impl Backend for FFT64 {
|
||||
type Handle = MODULE;
|
||||
unsafe fn destroy(handle: NonNull<Self::Handle>) {
|
||||
unsafe { delete_module_info(handle.as_ptr()) }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ModuleNewImpl<Self> for FFT64 {
|
||||
fn new_impl(n: u64) -> Module<Self> {
|
||||
unsafe { Module::from_raw_parts(new_module_info(n, 0), n) }
|
||||
}
|
||||
}
|
||||
29
backend/src/implementation/cpu_spqlios/module_ntt120.rs
Normal file
29
backend/src/implementation/cpu_spqlios/module_ntt120.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use std::ptr::NonNull;
|
||||
|
||||
use crate::{
|
||||
hal::{
|
||||
layouts::{Backend, Module},
|
||||
oep::ModuleNewImpl,
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
CPUAVX,
|
||||
ffi::module::{MODULE, delete_module_info, new_module_info},
|
||||
},
|
||||
};
|
||||
|
||||
pub struct NTT120;
|
||||
|
||||
impl CPUAVX for NTT120 {}
|
||||
|
||||
impl Backend for NTT120 {
|
||||
type Handle = MODULE;
|
||||
unsafe fn destroy(handle: NonNull<Self::Handle>) {
|
||||
unsafe { delete_module_info(handle.as_ptr()) }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ModuleNewImpl<Self> for NTT120 {
|
||||
fn new_impl(n: u64) -> Module<Self> {
|
||||
unsafe { Module::from_raw_parts(new_module_info(n, 1), n) }
|
||||
}
|
||||
}
|
||||
100
backend/src/implementation/cpu_spqlios/scalar_znx.rs
Normal file
100
backend/src/implementation/cpu_spqlios/scalar_znx.rs
Normal file
@@ -0,0 +1,100 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut},
|
||||
layouts::{Backend, Module, ScalarZnx, ScalarZnxOwned, ScalarZnxToMut, ScalarZnxToRef},
|
||||
oep::{
|
||||
ScalarZnxAllocBytesImpl, ScalarZnxAllocImpl, ScalarZnxAutomorphismImpl, ScalarZnxAutomorphismInplaceIml,
|
||||
ScalarZnxFromBytesImpl,
|
||||
},
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
CPUAVX,
|
||||
ffi::{module::module_info_t, vec_znx},
|
||||
},
|
||||
};
|
||||
|
||||
unsafe impl<B: Backend> ScalarZnxAllocBytesImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn scalar_znx_alloc_bytes_impl(n: usize, cols: usize) -> usize {
|
||||
ScalarZnxOwned::bytes_of(n, cols)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScalarZnxAllocImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn scalar_znx_alloc_impl(n: usize, cols: usize) -> ScalarZnxOwned {
|
||||
ScalarZnxOwned::new(n, cols)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScalarZnxFromBytesImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn scalar_znx_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned {
|
||||
ScalarZnxOwned::new_from_bytes(n, cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScalarZnxAutomorphismImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn scalar_znx_automorphism_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: ScalarZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
let a: ScalarZnx<&[u8]> = a.to_ref();
|
||||
let mut res: ScalarZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
module.ptr() as *const module_info_t,
|
||||
k,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScalarZnxAutomorphismInplaceIml<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn scalar_znx_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: ScalarZnxToMut,
|
||||
{
|
||||
let mut a: ScalarZnx<&mut [u8]> = a.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
module.ptr() as *const module_info_t,
|
||||
k,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
274
backend/src/implementation/cpu_spqlios/scratch.rs
Normal file
274
backend/src/implementation/cpu_spqlios/scratch.rs
Normal file
@@ -0,0 +1,274 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
DEFAULTALIGN, alloc_aligned,
|
||||
hal::{
|
||||
api::ScratchFromBytes,
|
||||
layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
|
||||
oep::{
|
||||
ScalarZnxAllocBytesImpl, ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl,
|
||||
SvpPPolAllocBytesImpl, TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl,
|
||||
TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl,
|
||||
VecZnxAllocBytesImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl,
|
||||
},
|
||||
},
|
||||
implementation::cpu_spqlios::CPUAVX,
|
||||
};
|
||||
|
||||
unsafe impl<B: Backend> ScratchOwnedAllocImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned<B> {
|
||||
let data: Vec<u8> = alloc_aligned(size);
|
||||
ScratchOwned {
|
||||
data,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScratchOwnedBorrowImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn scratch_owned_borrow_impl(scratch: &mut ScratchOwned<B>) -> &mut Scratch<B> {
|
||||
Scratch::from_bytes(&mut scratch.data)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScratchFromBytesImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch<B> {
|
||||
unsafe { &mut *(data as *mut [u8] as *mut Scratch<B>) }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScratchAvailableImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn scratch_available_impl(scratch: &Scratch<B>) -> usize {
|
||||
let ptr: *const u8 = scratch.data.as_ptr();
|
||||
let self_len: usize = scratch.data.len();
|
||||
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
|
||||
self_len.saturating_sub(aligned_offset)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeSliceImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn take_slice_impl<T>(scratch: &mut Scratch<B>, len: usize) -> (&mut [T], &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, len * std::mem::size_of::<T>());
|
||||
|
||||
unsafe {
|
||||
(
|
||||
&mut *(std::ptr::slice_from_raw_parts_mut(take_slice.as_mut_ptr() as *mut T, len)),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeScalarZnxImpl<B> for B
|
||||
where
|
||||
B: CPUAVX + ScalarZnxAllocBytesImpl<B>,
|
||||
{
|
||||
fn take_scalar_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::scalar_znx_alloc_bytes_impl(n, cols));
|
||||
(
|
||||
ScalarZnx::from_data(take_slice, n, cols),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeSvpPPolImpl<B> for B
|
||||
where
|
||||
B: CPUAVX + SvpPPolAllocBytesImpl<B>,
|
||||
{
|
||||
fn take_svp_ppol_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::svp_ppol_alloc_bytes_impl(n, cols));
|
||||
(
|
||||
SvpPPol::from_data(take_slice, n, cols),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxImpl<B> for B
|
||||
where
|
||||
B: CPUAVX + VecZnxAllocBytesImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(
|
||||
&mut scratch.data,
|
||||
B::vec_znx_alloc_bytes_impl(n, cols, size),
|
||||
);
|
||||
(
|
||||
VecZnx::from_data(take_slice, n, cols, size),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxBigImpl<B> for B
|
||||
where
|
||||
B: CPUAVX + VecZnxBigAllocBytesImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_big_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (VecZnxBig<&mut [u8], B>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(
|
||||
&mut scratch.data,
|
||||
B::vec_znx_big_alloc_bytes_impl(n, cols, size),
|
||||
);
|
||||
(
|
||||
VecZnxBig::from_data(take_slice, n, cols, size),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxDftImpl<B> for B
|
||||
where
|
||||
B: CPUAVX + VecZnxDftAllocBytesImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_dft_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (VecZnxDft<&mut [u8], B>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(
|
||||
&mut scratch.data,
|
||||
B::vec_znx_dft_alloc_bytes_impl(n, cols, size),
|
||||
);
|
||||
|
||||
(
|
||||
VecZnxDft::from_data(take_slice, n, cols, size),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxDftSliceImpl<B> for B
|
||||
where
|
||||
B: CPUAVX + VecZnxDftAllocBytesImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_dft_slice_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
len: usize,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Scratch<B>) {
|
||||
let mut scratch: &mut Scratch<B> = scratch;
|
||||
let mut slice: Vec<VecZnxDft<&mut [u8], B>> = Vec::with_capacity(len);
|
||||
for _ in 0..len {
|
||||
let (znx, new_scratch) = B::take_vec_znx_dft_impl(scratch, n, cols, size);
|
||||
scratch = new_scratch;
|
||||
slice.push(znx);
|
||||
}
|
||||
(slice, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxSliceImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn take_vec_znx_slice_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
len: usize,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnx<&mut [u8]>>, &mut Scratch<B>) {
|
||||
let mut scratch: &mut Scratch<B> = scratch;
|
||||
let mut slice: Vec<VecZnx<&mut [u8]>> = Vec::with_capacity(len);
|
||||
for _ in 0..len {
|
||||
let (znx, new_scratch) = B::take_vec_znx_impl(scratch, n, cols, size);
|
||||
scratch = new_scratch;
|
||||
slice.push(znx);
|
||||
}
|
||||
(slice, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVmpPMatImpl<B> for B
|
||||
where
|
||||
B: CPUAVX + VmpPMatAllocBytesImpl<B>,
|
||||
{
|
||||
fn take_vmp_pmat_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (VmpPMat<&mut [u8], B>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(
|
||||
&mut scratch.data,
|
||||
B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size),
|
||||
);
|
||||
(
|
||||
VmpPMat::from_data(take_slice, n, rows, cols_in, cols_out, size),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeMatZnxImpl<B> for B
|
||||
where
|
||||
B: CPUAVX,
|
||||
{
|
||||
fn take_mat_znx_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (MatZnx<&mut [u8]>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(
|
||||
&mut scratch.data,
|
||||
MatZnx::<Vec<u8>>::bytes_of(n, rows, cols_in, cols_out, size),
|
||||
);
|
||||
(
|
||||
MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) {
|
||||
let ptr: *mut u8 = data.as_mut_ptr();
|
||||
let self_len: usize = data.len();
|
||||
|
||||
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
|
||||
let aligned_len: usize = self_len.saturating_sub(aligned_offset);
|
||||
|
||||
if let Some(rem_len) = aligned_len.checked_sub(take_len) {
|
||||
unsafe {
|
||||
let rem_ptr: *mut u8 = ptr.add(aligned_offset).add(take_len);
|
||||
let rem_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len);
|
||||
|
||||
let take_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len);
|
||||
|
||||
return (take_slice, rem_slice);
|
||||
}
|
||||
} else {
|
||||
panic!(
|
||||
"Attempted to take {} from scratch with {} aligned bytes left",
|
||||
take_len, aligned_len,
|
||||
);
|
||||
}
|
||||
}
|
||||
Submodule backend/src/implementation/cpu_spqlios/spqlios-arithmetic added at 7160f588da
114
backend/src/implementation/cpu_spqlios/svp_ppol_fft64.rs
Normal file
114
backend/src/implementation/cpu_spqlios/svp_ppol_fft64.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut},
|
||||
layouts::{
|
||||
Data, DataRef, Module, ScalarZnxToRef, SvpPPol, SvpPPolBytesOf, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDft,
|
||||
VecZnxDftToMut, VecZnxDftToRef,
|
||||
},
|
||||
oep::{SvpApplyImpl, SvpApplyInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl},
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
ffi::{svp, vec_znx_dft::vec_znx_dft_t},
|
||||
module_fft64::FFT64,
|
||||
},
|
||||
};
|
||||
|
||||
const SVP_PPOL_FFT64_WORD_SIZE: usize = 1;
|
||||
|
||||
impl<D: Data> SvpPPolBytesOf for SvpPPol<D, FFT64> {
|
||||
fn bytes_of(n: usize, cols: usize) -> usize {
|
||||
SVP_PPOL_FFT64_WORD_SIZE * n * cols * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxSliceSize for SvpPPol<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
SVP_PPOL_FFT64_WORD_SIZE * self.n()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> ZnxView for SvpPPol<D, FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolFromBytesImpl<Self> for FFT64 {
|
||||
fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<Self> {
|
||||
SvpPPolOwned::from_bytes(n, cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolAllocImpl<Self> for FFT64 {
|
||||
fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned<Self> {
|
||||
SvpPPolOwned::alloc(n, cols)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolAllocBytesImpl<Self> for FFT64 {
|
||||
fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize {
|
||||
SvpPPol::<Vec<u8>, Self>::bytes_of(n, cols)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPrepareImpl<Self> for FFT64 {
|
||||
fn svp_prepare_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: SvpPPolToMut<Self>,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
unsafe {
|
||||
svp::svp_prepare(
|
||||
module.ptr(),
|
||||
res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t,
|
||||
a.to_ref().at_ptr(a_col, 0),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpApplyImpl<Self> for FFT64 {
|
||||
fn svp_apply_impl<R, A, B>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: SvpPPolToRef<Self>,
|
||||
B: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
|
||||
let a: SvpPPol<&[u8], Self> = a.to_ref();
|
||||
let b: VecZnxDft<&[u8], Self> = b.to_ref();
|
||||
unsafe {
|
||||
svp::svp_apply_dft_to_dft(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
|
||||
res.size() as u64,
|
||||
res.cols() as u64,
|
||||
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
|
||||
b.at_ptr(b_col, 0) as *const vec_znx_dft_t,
|
||||
b.size() as u64,
|
||||
b.cols() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpApplyInplaceImpl for FFT64 {
|
||||
fn svp_apply_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: SvpPPolToRef<Self>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
|
||||
let a: SvpPPol<&[u8], Self> = a.to_ref();
|
||||
unsafe {
|
||||
svp::svp_apply_dft_to_dft(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
|
||||
res.size() as u64,
|
||||
res.cols() as u64,
|
||||
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
|
||||
res.at_ptr(res_col, 0) as *const vec_znx_dft_t,
|
||||
res.size() as u64,
|
||||
res.cols() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
44
backend/src/implementation/cpu_spqlios/svp_ppol_ntt120.rs
Normal file
44
backend/src/implementation/cpu_spqlios/svp_ppol_ntt120.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{ZnxInfos, ZnxSliceSize, ZnxView},
|
||||
layouts::{Data, DataRef, SvpPPol, SvpPPolBytesOf, SvpPPolOwned},
|
||||
oep::{SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl},
|
||||
},
|
||||
implementation::cpu_spqlios::module_ntt120::NTT120,
|
||||
};
|
||||
|
||||
const SVP_PPOL_NTT120_WORD_SIZE: usize = 4;
|
||||
|
||||
impl<D: Data> SvpPPolBytesOf for SvpPPol<D, NTT120> {
|
||||
fn bytes_of(n: usize, cols: usize) -> usize {
|
||||
SVP_PPOL_NTT120_WORD_SIZE * n * cols * size_of::<i64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxSliceSize for SvpPPol<D, NTT120> {
|
||||
fn sl(&self) -> usize {
|
||||
SVP_PPOL_NTT120_WORD_SIZE * self.n()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> ZnxView for SvpPPol<D, NTT120> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolFromBytesImpl<Self> for NTT120 {
|
||||
fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<NTT120> {
|
||||
SvpPPolOwned::from_bytes(n, cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolAllocImpl<Self> for NTT120 {
|
||||
fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned<NTT120> {
|
||||
SvpPPolOwned::alloc(n, cols)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolAllocBytesImpl<Self> for NTT120 {
|
||||
fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize {
|
||||
SvpPPol::<Vec<u8>, Self>::bytes_of(n, cols)
|
||||
}
|
||||
}
|
||||
1
backend/src/implementation/cpu_spqlios/test/mod.rs
Normal file
1
backend/src/implementation/cpu_spqlios/test/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
mod vec_znx_fft64;
|
||||
35
backend/src/implementation/cpu_spqlios/test/vec_znx_fft64.rs
Normal file
35
backend/src/implementation/cpu_spqlios/test/vec_znx_fft64.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::ModuleNew,
|
||||
layouts::Module,
|
||||
tests::vec_znx::{
|
||||
test_vec_znx_add_normal, test_vec_znx_encode_vec_i64_hi_norm, test_vec_znx_encode_vec_i64_lo_norm,
|
||||
test_vec_znx_fill_uniform,
|
||||
},
|
||||
},
|
||||
implementation::cpu_spqlios::FFT64,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_vec_znx_fill_uniform_fft64() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << 12);
|
||||
test_vec_znx_fill_uniform(&module);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vec_znx_add_normal_fft64() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << 12);
|
||||
test_vec_znx_add_normal(&module);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vec_znx_encode_vec_lo_norm_fft64() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << 8);
|
||||
test_vec_znx_encode_vec_i64_lo_norm(&module);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vec_znx_encode_vec_hi_norm_fft64() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << 8);
|
||||
test_vec_znx_encode_vec_i64_hi_norm(&module);
|
||||
}
|
||||
1344
backend/src/implementation/cpu_spqlios/vec_znx.rs
Normal file
1344
backend/src/implementation/cpu_spqlios/vec_znx.rs
Normal file
File diff suppressed because it is too large
Load Diff
758
backend/src/implementation/cpu_spqlios/vec_znx_big_fft64.rs
Normal file
758
backend/src/implementation/cpu_spqlios/vec_znx_big_fft64.rs
Normal file
@@ -0,0 +1,758 @@
|
||||
use std::fmt;
|
||||
|
||||
use rand_distr::{Distribution, Normal};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{
|
||||
TakeSlice, VecZnxBigAddDistF64, VecZnxBigFillDistF64, VecZnxBigNormalizeTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView,
|
||||
ZnxViewMut,
|
||||
},
|
||||
layouts::{
|
||||
Data, DataRef, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigBytesOf, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef,
|
||||
VecZnxToMut, VecZnxToRef,
|
||||
},
|
||||
oep::{
|
||||
VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl,
|
||||
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
|
||||
VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl, VecZnxBigFromBytesImpl,
|
||||
VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl,
|
||||
VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl,
|
||||
VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
|
||||
},
|
||||
},
|
||||
implementation::cpu_spqlios::{ffi::vec_znx, module_fft64::FFT64},
|
||||
};
|
||||
|
||||
const VEC_ZNX_BIG_FFT64_WORDSIZE: usize = 1;
|
||||
|
||||
impl<D: DataRef> ZnxView for VecZnxBig<D, FFT64> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
impl<D: Data> VecZnxBigBytesOf for VecZnxBig<D, FFT64> {
|
||||
fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
|
||||
VEC_ZNX_BIG_FFT64_WORDSIZE * n * cols * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxSliceSize for VecZnxBig<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
VEC_ZNX_BIG_FFT64_WORDSIZE * self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAllocImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned<FFT64> {
|
||||
VecZnxBig::<Vec<u8>, FFT64>::new(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFromBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<FFT64> {
|
||||
VecZnxBig::<Vec<u8>, FFT64>::new_from_bytes(n, cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAllocBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
VecZnxBig::<Vec<u8>, FFT64>::bytes_of(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddDistF64Impl<FFT64> for FFT64 {
|
||||
fn add_dist_f64_impl<R: VecZnxBigToMut<FFT64>, D: Distribution<f64>>(
|
||||
_module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*x += (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*x += dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddNormalImpl<FFT64> for FFT64 {
|
||||
fn add_normal_impl<R: VecZnxBigToMut<FFT64>>(
|
||||
module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) {
|
||||
module.vec_znx_big_add_dist_f64(
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFillDistF64Impl<FFT64> for FFT64 {
|
||||
fn fill_dist_f64_impl<R: VecZnxBigToMut<FFT64>, D: Distribution<f64>>(
|
||||
_module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*x = (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*x = dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFillNormalImpl<FFT64> for FFT64 {
|
||||
fn fill_normal_impl<R: VecZnxBigToMut<FFT64>>(
|
||||
module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) {
|
||||
module.vec_znx_big_fill_dist_f64(
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddImpl<FFT64> for FFT64 {
|
||||
/// Adds `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_add_impl<R, A, B>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(b.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddInplaceImpl<FFT64> for FFT64 {
|
||||
/// Adds `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_add_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddSmallImpl<FFT64> for FFT64 {
|
||||
/// Adds `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_add_small_impl<R, A, B>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(b.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddSmallInplaceImpl<FFT64> for FFT64 {
|
||||
/// Adds `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_add_small_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubImpl<FFT64> for FFT64 {
|
||||
/// Subtracts `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_impl<R, A, B>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(b.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubABInplaceImpl<FFT64> for FFT64 {
|
||||
/// Subtracts `a` from `b` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_ab_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubBAInplaceImpl<FFT64> for FFT64 {
|
||||
/// Subtracts `b` from `a` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_ba_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallAImpl<FFT64> for FFT64 {
|
||||
/// Subtracts `b` from `a` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_small_a_impl<R, A, B>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(b.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallAInplaceImpl<FFT64> for FFT64 {
|
||||
/// Subtracts `a` from `res` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_a_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallBImpl<FFT64> for FFT64 {
|
||||
/// Subtracts `b` from `a` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_small_b_impl<R, A, B>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(b.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallBInplaceImpl<FFT64> for FFT64 {
|
||||
/// Subtracts `res` from `a` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_b_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNegateInplaceImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_negate_inplace_impl<A>(module: &Module<FFT64>, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_negate(
|
||||
module.ptr(),
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNormalizeTmpBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<FFT64>, n: usize) -> usize {
|
||||
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr(), n as u64) as usize }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNormalizeImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_big_normalize_impl<R, A>(
|
||||
module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<FFT64>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), a.n());
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes(a.n()));
|
||||
unsafe {
|
||||
vec_znx::vec_znx_normalize_base2k(
|
||||
module.ptr(),
|
||||
a.n() as u64,
|
||||
basek as u64,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAutomorphismImpl<FFT64> for FFT64 {
|
||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
|
||||
fn vec_znx_big_automorphism_impl<R, A>(module: &Module<FFT64>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
module.ptr(),
|
||||
k,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAutomorphismInplaceImpl<FFT64> for FFT64 {
|
||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
|
||||
fn vec_znx_big_automorphism_inplace_impl<A>(module: &Module<FFT64>, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
module.ptr(),
|
||||
k,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> fmt::Display for VecZnxBig<D, FFT64> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"VecZnxBig(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {}:", col)?;
|
||||
for size in 0..self.size {
|
||||
let coeffs = self.at(col, size);
|
||||
write!(f, " Size {}: [", size)?;
|
||||
|
||||
let max_show = 100;
|
||||
let show_count = coeffs.len().min(max_show);
|
||||
|
||||
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "{}", coeff)?;
|
||||
}
|
||||
|
||||
if coeffs.len() > max_show {
|
||||
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
|
||||
}
|
||||
|
||||
writeln!(f, "]")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
32
backend/src/implementation/cpu_spqlios/vec_znx_big_ntt120.rs
Normal file
32
backend/src/implementation/cpu_spqlios/vec_znx_big_ntt120.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{ZnxInfos, ZnxSliceSize, ZnxView},
|
||||
layouts::{Data, DataRef, VecZnxBig, VecZnxBigBytesOf},
|
||||
oep::VecZnxBigAllocBytesImpl,
|
||||
},
|
||||
implementation::cpu_spqlios::module_ntt120::NTT120,
|
||||
};
|
||||
|
||||
const VEC_ZNX_BIG_NTT120_WORDSIZE: usize = 4;
|
||||
|
||||
impl<D: DataRef> ZnxView for VecZnxBig<D, NTT120> {
|
||||
type Scalar = i128;
|
||||
}
|
||||
|
||||
impl<D: Data> VecZnxBigBytesOf for VecZnxBig<D, NTT120> {
|
||||
fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
|
||||
VEC_ZNX_BIG_NTT120_WORDSIZE * n * cols * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxSliceSize for VecZnxBig<D, NTT120> {
|
||||
fn sl(&self) -> usize {
|
||||
VEC_ZNX_BIG_NTT120_WORDSIZE * self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAllocBytesImpl<NTT120> for NTT120 {
|
||||
fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
VecZnxBig::<Vec<u8>, NTT120>::bytes_of(n, cols, size)
|
||||
}
|
||||
}
|
||||
@@ -1,375 +1,90 @@
|
||||
use crate::ffi::{vec_znx_big, vec_znx_dft};
|
||||
use crate::vec_znx_dft::bytes_of_vec_znx_dft;
|
||||
use crate::znx_base::ZnxInfos;
|
||||
use std::fmt;
|
||||
|
||||
use crate::{
|
||||
Backend, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
|
||||
ZnxSliceSize,
|
||||
hal::{
|
||||
api::{TakeSlice, VecZnxDftToVecZnxBigTmpBytes, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
|
||||
layouts::{
|
||||
Data, DataRef, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftBytesOf, VecZnxDftOwned,
|
||||
VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
|
||||
},
|
||||
oep::{
|
||||
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftCopyImpl,
|
||||
VecZnxDftFromBytesImpl, VecZnxDftFromVecZnxImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl,
|
||||
VecZnxDftSubImpl, VecZnxDftToVecZnxBigConsumeImpl, VecZnxDftToVecZnxBigImpl, VecZnxDftToVecZnxBigTmpAImpl,
|
||||
VecZnxDftToVecZnxBigTmpBytesImpl, VecZnxDftZeroImpl,
|
||||
},
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
ffi::{vec_znx_big, vec_znx_dft},
|
||||
module_fft64::FFT64,
|
||||
},
|
||||
};
|
||||
use crate::{FFT64, Module, ZnxView, ZnxViewMut, ZnxZero};
|
||||
use std::cmp::min;
|
||||
|
||||
pub trait VecZnxDftAlloc<B: Backend> {
|
||||
/// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
|
||||
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned<B>;
|
||||
const VEC_ZNX_DFT_FFT64_WORDSIZE: usize = 1;
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||
///
|
||||
/// Behavior: takes ownership of the backing array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of cols of the [VecZnxDft].
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of cols of the [VecZnxDft].
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftOps<B: Backend> {
|
||||
/// Returns the minimum number of bytes necessary to allocate
|
||||
/// a new [VecZnxDft] through [VecZnxDft::from_bytes].
|
||||
fn vec_znx_idft_tmp_bytes(&self) -> usize;
|
||||
|
||||
fn vec_znx_dft_add<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
D: VecZnxDftToRef<B>;
|
||||
|
||||
fn vec_znx_dft_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
|
||||
fn vec_znx_dft_sub<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
D: VecZnxDftToRef<B>;
|
||||
|
||||
fn vec_znx_dft_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
|
||||
fn vec_znx_dft_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
|
||||
fn vec_znx_dft_copy<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
|
||||
/// b <- IDFT(a), uses a as scratch space.
|
||||
fn vec_znx_idft_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxDftToMut<B>;
|
||||
|
||||
/// Consumes a to return IDFT(a) in big coeff space.
|
||||
fn vec_znx_idft_consume<D>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, FFT64>
|
||||
where
|
||||
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>;
|
||||
|
||||
fn vec_znx_idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
|
||||
fn vec_znx_dft<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxDftAlloc<B> for Module<B> {
|
||||
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned<B> {
|
||||
VecZnxDftOwned::new(&self, cols, size)
|
||||
}
|
||||
|
||||
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B> {
|
||||
VecZnxDftOwned::new_from_bytes(self, cols, size, bytes)
|
||||
}
|
||||
|
||||
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize {
|
||||
bytes_of_vec_znx_dft(self, cols, size)
|
||||
impl<D: Data> ZnxSliceSize for VecZnxDft<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
VEC_ZNX_DFT_FFT64_WORDSIZE * self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
impl VecZnxDftOps<FFT64> for Module<FFT64> {
|
||||
fn vec_znx_dft_add<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
D: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_add(
|
||||
self.ptr,
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
impl<D: Data> VecZnxDftBytesOf for VecZnxDft<D, FFT64> {
|
||||
fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
|
||||
VEC_ZNX_DFT_FFT64_WORDSIZE * n * cols * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_dft_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
impl<D: DataRef> ZnxView for VecZnxDft<D, FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_add(
|
||||
self.ptr,
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
unsafe impl VecZnxDftFromBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<FFT64> {
|
||||
VecZnxDft::<Vec<u8>, FFT64>::from_bytes(n, cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_dft_sub<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
D: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
self.ptr,
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
unsafe impl VecZnxDftAllocBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
VecZnxDft::<Vec<u8>, FFT64>::bytes_of(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_dft_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
self.ptr,
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
unsafe impl VecZnxDftAllocImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<FFT64> {
|
||||
VecZnxDftOwned::alloc(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_dft_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
self.ptr,
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
unsafe impl VecZnxDftToVecZnxBigTmpBytesImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<FFT64>) -> usize {
|
||||
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(module.ptr()) as usize }
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_dft_copy<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let steps: usize = (a_ref.size() + step - 1) / step;
|
||||
let min_steps: usize = min(res_mut.size(), steps);
|
||||
|
||||
(0..min_steps).for_each(|j| {
|
||||
let limb: usize = offset + j * step;
|
||||
if limb < a_ref.size() {
|
||||
res_mut
|
||||
.at_mut(res_col, j)
|
||||
.copy_from_slice(a_ref.at(a_col, limb));
|
||||
}
|
||||
});
|
||||
(min_steps..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
|
||||
fn vec_znx_idft_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
|
||||
|
||||
let min_size: usize = min(res_mut.size(), a_mut.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_znx_idft_tmp_a(
|
||||
self.ptr,
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
|
||||
1 as u64,
|
||||
a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1 as u64,
|
||||
)
|
||||
});
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_idft_consume<D>(&self, mut a: VecZnxDft<D, FFT64>) -> VecZnxBig<D, FFT64>
|
||||
where
|
||||
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
|
||||
|
||||
unsafe {
|
||||
// Rev col and rows because ZnxDft.sl() >= ZnxBig.sl()
|
||||
(0..a_mut.size()).for_each(|j| {
|
||||
(0..a_mut.cols()).for_each(|i| {
|
||||
vec_znx_dft::vec_znx_idft_tmp_a(
|
||||
self.ptr,
|
||||
a_mut.at_mut_ptr(i, j) as *mut vec_znx_big::vec_znx_big_t,
|
||||
1 as u64,
|
||||
a_mut.at_mut_ptr(i, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1 as u64,
|
||||
)
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
a.into_big()
|
||||
}
|
||||
|
||||
fn vec_znx_idft_tmp_bytes(&self) -> usize {
|
||||
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize }
|
||||
}
|
||||
|
||||
fn vec_znx_dft<R, A>(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: crate::VecZnx<&[u8]> = a.to_ref();
|
||||
let steps: usize = (a_ref.size() + step - 1) / step;
|
||||
let min_steps: usize = min(res_mut.size(), steps);
|
||||
unsafe {
|
||||
(0..min_steps).for_each(|j| {
|
||||
let limb: usize = offset + j * step;
|
||||
if limb < a_ref.size() {
|
||||
vec_znx_dft::vec_znx_dft(
|
||||
self.ptr,
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1 as u64,
|
||||
a_ref.at_ptr(a_col, limb),
|
||||
1 as u64,
|
||||
a_ref.sl() as u64,
|
||||
)
|
||||
}
|
||||
});
|
||||
(min_steps..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
|
||||
fn vec_znx_idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
unsafe impl VecZnxDftToVecZnxBigImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_to_vec_znx_big_impl<R, A>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<FFT64>,
|
||||
) where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_idft_tmp_bytes());
|
||||
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_dft_to_vec_znx_big_tmp_bytes());
|
||||
|
||||
let min_size: usize = min(res_mut.size(), a_ref.size());
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_znx_idft(
|
||||
self.ptr,
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
|
||||
1 as u64,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
@@ -383,3 +98,331 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftToVecZnxBigTmpAImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_a_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_mut.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_znx_idft_tmp_a(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
|
||||
1 as u64,
|
||||
a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1 as u64,
|
||||
)
|
||||
});
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftToVecZnxBigConsumeImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_to_vec_znx_big_consume_impl<D: Data>(module: &Module<FFT64>, mut a: VecZnxDft<D, FFT64>) -> VecZnxBig<D, FFT64>
|
||||
where
|
||||
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
|
||||
|
||||
unsafe {
|
||||
// Rev col and rows because ZnxDft.sl() >= ZnxBig.sl()
|
||||
(0..a_mut.size()).for_each(|j| {
|
||||
(0..a_mut.cols()).for_each(|i| {
|
||||
vec_znx_dft::vec_znx_idft_tmp_a(
|
||||
module.ptr(),
|
||||
a_mut.at_mut_ptr(i, j) as *mut vec_znx_big::vec_znx_big_t,
|
||||
1 as u64,
|
||||
a_mut.at_mut_ptr(i, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1 as u64,
|
||||
)
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
a.into_big()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftFromVecZnxImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_from_vec_znx_impl<R, A>(
|
||||
module: &Module<FFT64>,
|
||||
step: usize,
|
||||
offset: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnx<&[u8]> = a.to_ref();
|
||||
let steps: usize = a_ref.size().div_ceil(step);
|
||||
let min_steps: usize = res_mut.size().min(steps);
|
||||
unsafe {
|
||||
(0..min_steps).for_each(|j| {
|
||||
let limb: usize = offset + j * step;
|
||||
if limb < a_ref.size() {
|
||||
vec_znx_dft::vec_znx_dft(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1 as u64,
|
||||
a_ref.at_ptr(a_col, limb),
|
||||
1 as u64,
|
||||
a_ref.sl() as u64,
|
||||
)
|
||||
}
|
||||
});
|
||||
(min_steps..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAddImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_add_impl<R, A, D>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &D,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
D: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_add(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAddInplaceImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_add_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_add(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftSubImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_sub_impl<R, A, D>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &D,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
D: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftSubABInplaceImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_sub_ab_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftSubBAInplaceImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_sub_ba_inplace_impl<R, A>(module: &Module<FFT64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftCopyImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_copy_impl<R, A>(
|
||||
_module: &Module<FFT64>,
|
||||
step: usize,
|
||||
offset: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let steps: usize = a_ref.size().div_ceil(step);
|
||||
let min_steps: usize = res_mut.size().min(steps);
|
||||
|
||||
(0..min_steps).for_each(|j| {
|
||||
let limb: usize = offset + j * step;
|
||||
if limb < a_ref.size() {
|
||||
res_mut
|
||||
.at_mut(res_col, j)
|
||||
.copy_from_slice(a_ref.at(a_col, limb));
|
||||
}
|
||||
});
|
||||
(min_steps..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftZeroImpl<FFT64> for FFT64 {
|
||||
fn vec_znx_dft_zero_impl<R>(_module: &Module<FFT64>, res: &mut R)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
res.to_mut().data.fill(0);
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> fmt::Display for VecZnxDft<D, FFT64> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"VecZnxDft(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {}:", col)?;
|
||||
for size in 0..self.size {
|
||||
let coeffs = self.at(col, size);
|
||||
write!(f, " Size {}: [", size)?;
|
||||
|
||||
let max_show = 100;
|
||||
let show_count = coeffs.len().min(max_show);
|
||||
|
||||
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "{}", coeff)?;
|
||||
}
|
||||
|
||||
if coeffs.len() > max_show {
|
||||
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
|
||||
}
|
||||
|
||||
writeln!(f, "]")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
38
backend/src/implementation/cpu_spqlios/vec_znx_dft_ntt120.rs
Normal file
38
backend/src/implementation/cpu_spqlios/vec_znx_dft_ntt120.rs
Normal file
@@ -0,0 +1,38 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{ZnxInfos, ZnxSliceSize, ZnxView},
|
||||
layouts::{Data, DataRef, VecZnxDft, VecZnxDftBytesOf, VecZnxDftOwned},
|
||||
oep::{VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl},
|
||||
},
|
||||
implementation::cpu_spqlios::module_ntt120::NTT120,
|
||||
};
|
||||
|
||||
const VEC_ZNX_DFT_NTT120_WORDSIZE: usize = 4;
|
||||
|
||||
impl<D: Data> ZnxSliceSize for VecZnxDft<D, NTT120> {
|
||||
fn sl(&self) -> usize {
|
||||
VEC_ZNX_DFT_NTT120_WORDSIZE * self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> VecZnxDftBytesOf for VecZnxDft<D, NTT120> {
|
||||
fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
|
||||
VEC_ZNX_DFT_NTT120_WORDSIZE * n * cols * size * size_of::<i64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> ZnxView for VecZnxDft<D, NTT120> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAllocBytesImpl<NTT120> for NTT120 {
|
||||
fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
VecZnxDft::<Vec<u8>, NTT120>::bytes_of(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAllocImpl<NTT120> for NTT120 {
|
||||
fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<NTT120> {
|
||||
VecZnxDftOwned::alloc(n, cols, size)
|
||||
}
|
||||
}
|
||||
286
backend/src/implementation/cpu_spqlios/vmp_pmat_fft64.rs
Normal file
286
backend/src/implementation/cpu_spqlios/vmp_pmat_fft64.rs
Normal file
@@ -0,0 +1,286 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::{TakeSlice, VmpApplyTmpBytes, VmpPrepareTmpBytes, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
layouts::{
|
||||
DataRef, MatZnx, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatBytesOf,
|
||||
VmpPMatOwned, VmpPMatToMut, VmpPMatToRef,
|
||||
},
|
||||
oep::{
|
||||
VmpApplyAddImpl, VmpApplyAddTmpBytesImpl, VmpApplyImpl, VmpApplyTmpBytesImpl, VmpPMatAllocBytesImpl,
|
||||
VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
|
||||
},
|
||||
},
|
||||
implementation::cpu_spqlios::{
|
||||
ffi::{vec_znx_dft::vec_znx_dft_t, vmp},
|
||||
module_fft64::FFT64,
|
||||
},
|
||||
};
|
||||
|
||||
const VMP_PMAT_FFT64_WORDSIZE: usize = 1;
|
||||
|
||||
impl<D: DataRef> ZnxView for VmpPMat<D, FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
impl VmpPMatBytesOf for FFT64 {
|
||||
fn vmp_pmat_bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
VMP_PMAT_FFT64_WORDSIZE * n * rows * cols_in * cols_out * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPMatAllocBytesImpl<FFT64> for FFT64
|
||||
where
|
||||
FFT64: VmpPMatBytesOf,
|
||||
{
|
||||
fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
FFT64::vmp_pmat_bytes_of(n, rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPMatFromBytesImpl<FFT64> for FFT64 {
|
||||
fn vmp_pmat_from_bytes_impl(
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: Vec<u8>,
|
||||
) -> VmpPMatOwned<FFT64> {
|
||||
VmpPMatOwned::from_bytes(n, rows, cols_in, cols_out, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPMatAllocImpl<FFT64> for FFT64 {
|
||||
fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<FFT64> {
|
||||
VmpPMatOwned::alloc(n, rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPrepareTmpBytesImpl<FFT64> for FFT64 {
|
||||
fn vmp_prepare_tmp_bytes_impl(module: &Module<FFT64>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_prepare_tmp_bytes(
|
||||
module.ptr(),
|
||||
(rows * cols_in) as u64,
|
||||
(cols_out * size) as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPMatPrepareImpl<FFT64> for FFT64 {
|
||||
fn vmp_prepare_impl<R, A>(module: &Module<FFT64>, res: &mut R, a: &A, scratch: &mut Scratch<FFT64>)
|
||||
where
|
||||
R: VmpPMatToMut<FFT64>,
|
||||
A: MatZnxToRef,
|
||||
{
|
||||
let mut res: VmpPMat<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: MatZnx<&[u8]> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(
|
||||
res.cols_in(),
|
||||
a.cols_in(),
|
||||
"res.cols_in: {} != a.cols_in: {}",
|
||||
res.cols_in(),
|
||||
a.cols_in()
|
||||
);
|
||||
assert_eq!(
|
||||
res.rows(),
|
||||
a.rows(),
|
||||
"res.rows: {} != a.rows: {}",
|
||||
res.rows(),
|
||||
a.rows()
|
||||
);
|
||||
assert_eq!(
|
||||
res.cols_out(),
|
||||
a.cols_out(),
|
||||
"res.cols_out: {} != a.cols_out: {}",
|
||||
res.cols_out(),
|
||||
a.cols_out()
|
||||
);
|
||||
assert_eq!(
|
||||
res.size(),
|
||||
a.size(),
|
||||
"res.size: {} != a.size: {}",
|
||||
res.size(),
|
||||
a.size()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.take_slice(module.vmp_prepare_tmp_bytes(a.rows(), a.cols_in(), a.cols_out(), a.size()));
|
||||
|
||||
unsafe {
|
||||
vmp::vmp_prepare_contiguous(
|
||||
module.ptr(),
|
||||
res.as_mut_ptr() as *mut vmp::vmp_pmat_t,
|
||||
a.as_ptr(),
|
||||
(a.rows() * a.cols_in()) as u64,
|
||||
(a.size() * a.cols_out()) as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyTmpBytesImpl<FFT64> for FFT64 {
|
||||
fn vmp_apply_tmp_bytes_impl(
|
||||
module: &Module<FFT64>,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_tmp_bytes(
|
||||
module.ptr(),
|
||||
(res_size * b_cols_out) as u64,
|
||||
(a_size * b_cols_in) as u64,
|
||||
(b_rows * b_cols_in) as u64,
|
||||
(b_size * b_cols_out) as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyImpl<FFT64> for FFT64 {
|
||||
fn vmp_apply_impl<R, A, C>(module: &Module<FFT64>, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<FFT64>)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
C: VmpPMatToRef<FFT64>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], _> = a.to_ref();
|
||||
let b: VmpPMat<&[u8], _> = b.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(b.n(), module.n());
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(
|
||||
res.cols(),
|
||||
b.cols_out(),
|
||||
"res.cols(): {} != b.cols_out: {}",
|
||||
res.cols(),
|
||||
b.cols_out()
|
||||
);
|
||||
assert_eq!(
|
||||
a.cols(),
|
||||
b.cols_in(),
|
||||
"a.cols(): {} != b.cols_in: {}",
|
||||
a.cols(),
|
||||
b.cols_in()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes(
|
||||
res.size(),
|
||||
a.size(),
|
||||
b.rows(),
|
||||
b.cols_in(),
|
||||
b.cols_out(),
|
||||
b.size(),
|
||||
));
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft(
|
||||
module.ptr(),
|
||||
res.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
(res.size() * res.cols()) as u64,
|
||||
a.as_ptr() as *const vec_znx_dft_t,
|
||||
(a.size() * a.cols()) as u64,
|
||||
b.as_ptr() as *const vmp::vmp_pmat_t,
|
||||
(b.rows() * b.cols_in()) as u64,
|
||||
(b.size() * b.cols_out()) as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyAddTmpBytesImpl<FFT64> for FFT64 {
|
||||
fn vmp_apply_add_tmp_bytes_impl(
|
||||
module: &Module<FFT64>,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_tmp_bytes(
|
||||
module.ptr(),
|
||||
(res_size * b_cols_out) as u64,
|
||||
(a_size * b_cols_in) as u64,
|
||||
(b_rows * b_cols_in) as u64,
|
||||
(b_size * b_cols_out) as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyAddImpl<FFT64> for FFT64 {
|
||||
fn vmp_apply_add_impl<R, A, C>(module: &Module<FFT64>, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<FFT64>)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
C: VmpPMatToRef<FFT64>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], _> = a.to_ref();
|
||||
let b: VmpPMat<&[u8], _> = b.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
use crate::hal::api::ZnxInfos;
|
||||
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(b.n(), module.n());
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(
|
||||
res.cols(),
|
||||
b.cols_out(),
|
||||
"res.cols(): {} != b.cols_out: {}",
|
||||
res.cols(),
|
||||
b.cols_out()
|
||||
);
|
||||
assert_eq!(
|
||||
a.cols(),
|
||||
b.cols_in(),
|
||||
"a.cols(): {} != b.cols_in: {}",
|
||||
a.cols(),
|
||||
b.cols_in()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes(
|
||||
res.size(),
|
||||
a.size(),
|
||||
b.rows(),
|
||||
b.cols_in(),
|
||||
b.cols_out(),
|
||||
b.size(),
|
||||
));
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_add(
|
||||
module.ptr(),
|
||||
res.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
(res.size() * res.cols()) as u64,
|
||||
a.as_ptr() as *const vec_znx_dft_t,
|
||||
(a.size() * a.cols()) as u64,
|
||||
b.as_ptr() as *const vmp::vmp_pmat_t,
|
||||
(b.rows() * b.cols_in()) as u64,
|
||||
(b.size() * b.cols_out()) as u64,
|
||||
(scale * b.cols_out()) as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
11
backend/src/implementation/cpu_spqlios/vmp_pmat_ntt120.rs
Normal file
11
backend/src/implementation/cpu_spqlios/vmp_pmat_ntt120.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
use crate::{
|
||||
hal::{
|
||||
api::ZnxView,
|
||||
layouts::{DataRef, VmpPMat},
|
||||
},
|
||||
implementation::cpu_spqlios::module_ntt120::NTT120,
|
||||
};
|
||||
|
||||
impl<D: DataRef> ZnxView for VmpPMat<D, NTT120> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
1
backend/src/implementation/mod.rs
Normal file
1
backend/src/implementation/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod cpu_spqlios;
|
||||
@@ -1,39 +1,17 @@
|
||||
pub mod encoding;
|
||||
#[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)]
|
||||
// Other modules and exports
|
||||
pub mod ffi;
|
||||
pub mod mat_znx_dft;
|
||||
pub mod mat_znx_dft_ops;
|
||||
pub mod module;
|
||||
pub mod sampling;
|
||||
pub mod scalar_znx;
|
||||
pub mod scalar_znx_dft;
|
||||
pub mod scalar_znx_dft_ops;
|
||||
pub mod stats;
|
||||
pub mod vec_znx;
|
||||
pub mod vec_znx_big;
|
||||
pub mod vec_znx_big_ops;
|
||||
pub mod vec_znx_dft;
|
||||
pub mod vec_znx_dft_ops;
|
||||
pub mod vec_znx_ops;
|
||||
pub mod znx_base;
|
||||
#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)]
|
||||
#![deny(rustdoc::broken_intra_doc_links)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
#![feature(trait_alias)]
|
||||
|
||||
pub use encoding::*;
|
||||
pub use mat_znx_dft::*;
|
||||
pub use mat_znx_dft_ops::*;
|
||||
pub use module::*;
|
||||
pub use sampling::*;
|
||||
pub use scalar_znx::*;
|
||||
pub use scalar_znx_dft::*;
|
||||
pub use scalar_znx_dft_ops::*;
|
||||
pub use stats::*;
|
||||
pub use vec_znx::*;
|
||||
pub use vec_znx_big::*;
|
||||
pub use vec_znx_big_ops::*;
|
||||
pub use vec_znx_dft::*;
|
||||
pub use vec_znx_dft_ops::*;
|
||||
pub use vec_znx_ops::*;
|
||||
pub use znx_base::*;
|
||||
pub mod hal;
|
||||
pub mod implementation;
|
||||
|
||||
pub mod doc {
|
||||
#[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/backend_safety_contract.md"))]
|
||||
pub mod backend_safety {
|
||||
pub const _PLACEHOLDER: () = ();
|
||||
}
|
||||
}
|
||||
|
||||
pub const GALOISGENERATOR: u64 = 5;
|
||||
pub const DEFAULTALIGN: usize = 64;
|
||||
@@ -118,190 +96,10 @@ pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
|
||||
}
|
||||
|
||||
/// Allocates an aligned vector of size equal to the smallest multiple
|
||||
/// of [DEFAULTALIGN]/size_of::<T>() that is equal or greater to `size`.
|
||||
/// of [DEFAULTALIGN]/`size_of::<T>`() that is equal or greater to `size`.
|
||||
pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
|
||||
alloc_aligned_custom::<T>(
|
||||
size + (DEFAULTALIGN - (size % (DEFAULTALIGN / size_of::<T>()))),
|
||||
size + (DEFAULTALIGN - (size % (DEFAULTALIGN / size_of::<T>()))) % DEFAULTALIGN,
|
||||
DEFAULTALIGN,
|
||||
)
|
||||
}
|
||||
|
||||
// Scratch implementation below
|
||||
|
||||
pub struct ScratchOwned(Vec<u8>);
|
||||
|
||||
impl ScratchOwned {
|
||||
pub fn new(byte_count: usize) -> Self {
|
||||
let data: Vec<u8> = alloc_aligned(byte_count);
|
||||
Self(data)
|
||||
}
|
||||
|
||||
pub fn borrow(&mut self) -> &mut Scratch {
|
||||
Scratch::new(&mut self.0)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Scratch {
|
||||
data: [u8],
|
||||
}
|
||||
|
||||
impl Scratch {
|
||||
fn new(data: &mut [u8]) -> &mut Self {
|
||||
unsafe { &mut *(data as *mut [u8] as *mut Self) }
|
||||
}
|
||||
|
||||
pub fn zero(&mut self) {
|
||||
self.data.fill(0);
|
||||
}
|
||||
|
||||
pub fn available(&self) -> usize {
|
||||
let ptr: *const u8 = self.data.as_ptr();
|
||||
let self_len: usize = self.data.len();
|
||||
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
|
||||
self_len.saturating_sub(aligned_offset)
|
||||
}
|
||||
|
||||
fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) {
|
||||
let ptr: *mut u8 = data.as_mut_ptr();
|
||||
let self_len: usize = data.len();
|
||||
|
||||
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
|
||||
let aligned_len: usize = self_len.saturating_sub(aligned_offset);
|
||||
|
||||
if let Some(rem_len) = aligned_len.checked_sub(take_len) {
|
||||
unsafe {
|
||||
let rem_ptr: *mut u8 = ptr.add(aligned_offset).add(take_len);
|
||||
let rem_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len);
|
||||
|
||||
let take_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len);
|
||||
|
||||
return (take_slice, rem_slice);
|
||||
}
|
||||
} else {
|
||||
panic!(
|
||||
"Attempted to take {} from scratch with {} aligned bytes left",
|
||||
take_len,
|
||||
aligned_len,
|
||||
// type_name::<T>(),
|
||||
// aligned_len
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tmp_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self) {
|
||||
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, len * std::mem::size_of::<T>());
|
||||
|
||||
unsafe {
|
||||
(
|
||||
&mut *(std::ptr::slice_from_raw_parts_mut(take_slice.as_mut_ptr() as *mut T, len)),
|
||||
Self::new(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tmp_scalar_znx<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
|
||||
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx(module, cols));
|
||||
|
||||
(
|
||||
ScalarZnx::from_data(take_slice, module.n(), cols),
|
||||
Self::new(rem_slice),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn tmp_scalar_znx_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnxDft<&mut [u8], B>, &mut Self) {
|
||||
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx_dft(module, cols));
|
||||
|
||||
(
|
||||
ScalarZnxDft::from_data(take_slice, module.n(), cols),
|
||||
Self::new(rem_slice),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn tmp_vec_znx_dft<B: Backend>(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (VecZnxDft<&mut [u8], B>, &mut Self) {
|
||||
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_dft(module, cols, size));
|
||||
|
||||
(
|
||||
VecZnxDft::from_data(take_slice, module.n(), cols, size),
|
||||
Self::new(rem_slice),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn tmp_slice_vec_znx_dft<B: Backend>(
|
||||
&mut self,
|
||||
slice_size: usize,
|
||||
module: &Module<B>,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Self) {
|
||||
let mut scratch: &mut Scratch = self;
|
||||
let mut slice: Vec<VecZnxDft<&mut [u8], B>> = Vec::with_capacity(slice_size);
|
||||
for _ in 0..slice_size {
|
||||
let (znx, new_scratch) = scratch.tmp_vec_znx_dft(module, cols, size);
|
||||
scratch = new_scratch;
|
||||
slice.push(znx);
|
||||
}
|
||||
(slice, scratch)
|
||||
}
|
||||
|
||||
pub fn tmp_vec_znx_big<B: Backend>(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (VecZnxBig<&mut [u8], B>, &mut Self) {
|
||||
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_big(module, cols, size));
|
||||
|
||||
(
|
||||
VecZnxBig::from_data(take_slice, module.n(), cols, size),
|
||||
Self::new(rem_slice),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn tmp_vec_znx<B: Backend>(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
|
||||
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, module.bytes_of_vec_znx(cols, size));
|
||||
(
|
||||
VecZnx::from_data(take_slice, module.n(), cols, size),
|
||||
Self::new(rem_slice),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn tmp_slice_vec_znx<B: Backend>(
|
||||
&mut self,
|
||||
slice_size: usize,
|
||||
module: &Module<B>,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnx<&mut [u8]>>, &mut Self) {
|
||||
let mut scratch: &mut Scratch = self;
|
||||
let mut slice: Vec<VecZnx<&mut [u8]>> = Vec::with_capacity(slice_size);
|
||||
for _ in 0..slice_size {
|
||||
let (znx, new_scratch) = scratch.tmp_vec_znx(module, cols, size);
|
||||
scratch = new_scratch;
|
||||
slice.push(znx);
|
||||
}
|
||||
(slice, scratch)
|
||||
}
|
||||
|
||||
pub fn tmp_mat_znx_dft<B: Backend>(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (MatZnxDft<&mut [u8], B>, &mut Self) {
|
||||
let (take_slice, rem_slice) = Self::take_slice_aligned(
|
||||
&mut self.data,
|
||||
module.bytes_of_mat_znx_dft(rows, cols_in, cols_out, size),
|
||||
);
|
||||
(
|
||||
MatZnxDft::from_data(take_slice, module.n(), rows, cols_in, cols_out, size),
|
||||
Self::new(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,214 +0,0 @@
|
||||
use crate::znx_base::ZnxInfos;
|
||||
use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
|
||||
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
|
||||
/// Each col of the [MatZnxDft] can be seen as a collection of [VecZnxDft].
|
||||
///
|
||||
/// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft].
|
||||
/// See the trait [MatZnxDftOps] for additional information.
|
||||
pub struct MatZnxDft<D, B: Backend> {
|
||||
data: D,
|
||||
n: usize,
|
||||
size: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
_phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<D, B: Backend> ZnxInfos for MatZnxDft<D, B> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols_in
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
self.rows
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> ZnxSliceSize for MatZnxDft<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n() * self.cols_out()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> DataView for MatZnxDft<D, B> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> DataViewMut for MatZnxDft<D, B> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> ZnxView for MatZnxDft<D, FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
impl<D, B: Backend> MatZnxDft<D, B> {
|
||||
pub fn cols_in(&self) -> usize {
|
||||
self.cols_in
|
||||
}
|
||||
|
||||
pub fn cols_out(&self) -> usize {
|
||||
self.cols_out
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: From<Vec<u8>>, B: Backend> MatZnxDft<D, B> {
|
||||
pub(crate) fn bytes_of(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
unsafe {
|
||||
crate::ffi::vmp::bytes_of_vmp_pmat(
|
||||
module.ptr,
|
||||
(rows * cols_in) as u64,
|
||||
(size * cols_out) as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
|
||||
let data: Vec<u8> = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
size,
|
||||
rows,
|
||||
cols_in,
|
||||
cols_out,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_from_bytes(
|
||||
module: &Module<B>,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: impl Into<Vec<u8>>,
|
||||
) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == Self::bytes_of(module, rows, cols_in, cols_out, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
size,
|
||||
rows,
|
||||
cols_in,
|
||||
cols_out,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> MatZnxDft<D, FFT64> {
|
||||
/// Returns a copy of the backend array at index (i, j) of the [MatZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `row`: row index (i).
|
||||
/// * `col`: col index (j).
|
||||
#[allow(dead_code)]
|
||||
fn at(&self, row: usize, col: usize) -> Vec<f64> {
|
||||
let n: usize = self.n();
|
||||
|
||||
let mut res: Vec<f64> = alloc_aligned(n);
|
||||
|
||||
if n < 8 {
|
||||
res.copy_from_slice(&self.raw()[(row + col * self.rows()) * n..(row + col * self.rows()) * (n + 1)]);
|
||||
} else {
|
||||
(0..n >> 3).for_each(|blk| {
|
||||
res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]);
|
||||
});
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] {
|
||||
let nrows: usize = self.rows();
|
||||
let nsize: usize = self.size();
|
||||
if col == (nsize - 1) && (nsize & 1 == 1) {
|
||||
&self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..]
|
||||
} else {
|
||||
&self.raw()[blk * nrows * nsize * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type MatZnxDftOwned<B> = MatZnxDft<Vec<u8>, B>;
|
||||
pub type MatZnxDftMut<'a, B> = MatZnxDft<&'a mut [u8], B>;
|
||||
pub type MatZnxDftRef<'a, B> = MatZnxDft<&'a [u8], B>;
|
||||
|
||||
pub trait MatZnxToRef<B: Backend> {
|
||||
fn to_ref(&self) -> MatZnxDft<&[u8], B>;
|
||||
}
|
||||
|
||||
impl<D, B: Backend> MatZnxToRef<B> for MatZnxDft<D, B>
|
||||
where
|
||||
D: AsRef<[u8]>,
|
||||
B: Backend,
|
||||
{
|
||||
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
|
||||
MatZnxDft {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
rows: self.rows,
|
||||
cols_in: self.cols_in,
|
||||
cols_out: self.cols_out,
|
||||
size: self.size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait MatZnxToMut<B: Backend> {
|
||||
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B>;
|
||||
}
|
||||
|
||||
impl<D, B: Backend> MatZnxToMut<B> for MatZnxDft<D, B>
|
||||
where
|
||||
D: AsRef<[u8]> + AsMut<[u8]>,
|
||||
B: Backend,
|
||||
{
|
||||
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
|
||||
MatZnxDft {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
rows: self.rows,
|
||||
cols_in: self.cols_in,
|
||||
cols_out: self.cols_out,
|
||||
size: self.size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> MatZnxDft<D, B> {
|
||||
pub(crate) fn from_data(data: D, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
rows,
|
||||
cols_in,
|
||||
cols_out,
|
||||
size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,996 +0,0 @@
|
||||
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
|
||||
use crate::ffi::vmp;
|
||||
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
|
||||
use crate::{
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftOwned, MatZnxToMut, MatZnxToRef, Module, ScalarZnxAlloc, ScalarZnxDftAlloc,
|
||||
ScalarZnxDftOps, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps,
|
||||
};
|
||||
|
||||
pub trait MatZnxDftAlloc<B: Backend> {
|
||||
/// Allocates a new [MatZnxDft] with the given number of rows and columns.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `rows`: number of rows (number of [VecZnxDft]).
|
||||
/// * `size`: number of size (number of size of each [VecZnxDft]).
|
||||
fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftOwned<B>;
|
||||
|
||||
fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
|
||||
|
||||
fn new_mat_znx_dft_from_bytes(
|
||||
&self,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: Vec<u8>,
|
||||
) -> MatZnxDftOwned<B>;
|
||||
}
|
||||
|
||||
pub trait MatZnxDftScratch {
|
||||
/// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft].
|
||||
fn vmp_apply_tmp_bytes(
|
||||
&self,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize;
|
||||
|
||||
fn mat_znx_dft_mul_x_pow_minus_one_scratch_space(&self, size: usize, cols_out: usize) -> usize;
|
||||
}
|
||||
|
||||
/// This trait implements methods for vector matrix product,
|
||||
/// that is, multiplying a [VecZnx] with a [MatZnxDft].
|
||||
pub trait MatZnxDftOps<BACKEND: Backend> {
|
||||
/// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `res`: [MatZnxDft] on which the values are encoded.
|
||||
/// * `a`: the [VecZnxDft] to encode on the [MatZnxDft].
|
||||
/// * `row_i`: the index of the row to prepare.
|
||||
///
|
||||
/// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
|
||||
fn mat_znx_dft_set_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
|
||||
where
|
||||
R: MatZnxToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>;
|
||||
|
||||
/// Extracts the ith-row of [MatZnxDft] into a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `res`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft].
|
||||
/// * `a`: [MatZnxDft] on which the values are encoded.
|
||||
/// * `row_i`: the index of the row to extract.
|
||||
fn mat_znx_dft_get_row<R, A>(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: MatZnxToRef<FFT64>;
|
||||
|
||||
/// Multiplies A by (X^{k} - 1) and stores the result on R.
|
||||
fn mat_znx_dft_mul_x_pow_minus_one<R, A>(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch)
|
||||
where
|
||||
R: MatZnxToMut<FFT64>,
|
||||
A: MatZnxToRef<FFT64>;
|
||||
|
||||
/// Multiplies A by (X^{k} - 1).
|
||||
fn mat_znx_dft_mul_x_pow_minus_one_inplace<A>(&self, k: i64, a: &mut A, scratch: &mut Scratch)
|
||||
where
|
||||
A: MatZnxToMut<FFT64>;
|
||||
|
||||
/// Multiplies A by (X^{k} - 1).
|
||||
fn mat_znx_dft_mul_x_pow_minus_one_add_inplace<R, A>(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch)
|
||||
where
|
||||
R: MatZnxToMut<FFT64>,
|
||||
A: MatZnxToRef<FFT64>;
|
||||
|
||||
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
|
||||
/// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
///
|
||||
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
|
||||
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
|
||||
/// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
|
||||
///
|
||||
/// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
|
||||
/// `j` size, the output is a [VecZnx] of `j` size.
|
||||
///
|
||||
/// If there is a mismatch between the dimensions the largest valid ones are used.
|
||||
///
|
||||
/// ```text
|
||||
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
|
||||
/// |h i j|
|
||||
/// |k l m|
|
||||
/// ```
|
||||
/// where each element is a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `c`: the output of the vector matrix product, as a [VecZnxDft].
|
||||
/// * `a`: the left operand [VecZnxDft] of the vector matrix product.
|
||||
/// * `b`: the right operand [MatZnxDft] of the vector matrix product.
|
||||
/// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
fn vmp_apply<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
B: MatZnxToRef<FFT64>;
|
||||
|
||||
// Same as [MatZnxDftOps::vmp_apply] except result is added on R instead of overwritting R.
|
||||
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scale: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
B: MatZnxToRef<FFT64>;
|
||||
}
|
||||
|
||||
impl<B: Backend> MatZnxDftAlloc<B> for Module<B> {
|
||||
fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
MatZnxDftOwned::bytes_of(self, rows, cols_in, cols_out, size)
|
||||
}
|
||||
|
||||
fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftOwned<B> {
|
||||
MatZnxDftOwned::new(self, rows, cols_in, cols_out, size)
|
||||
}
|
||||
|
||||
fn new_mat_znx_dft_from_bytes(
|
||||
&self,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: Vec<u8>,
|
||||
) -> MatZnxDftOwned<B> {
|
||||
MatZnxDftOwned::new_from_bytes(self, rows, cols_in, cols_out, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<BACKEND: Backend> MatZnxDftScratch for Module<BACKEND> {
|
||||
fn vmp_apply_tmp_bytes(
|
||||
&self,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_tmp_bytes(
|
||||
self.ptr,
|
||||
(res_size * b_cols_out) as u64,
|
||||
(a_size * b_cols_in) as u64,
|
||||
(b_rows * b_cols_in) as u64,
|
||||
(b_size * b_cols_out) as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
|
||||
fn mat_znx_dft_mul_x_pow_minus_one_scratch_space(&self, size: usize, cols_out: usize) -> usize {
|
||||
let xpm1_dft: usize = self.bytes_of_scalar_znx(1);
|
||||
let xpm1: usize = self.bytes_of_scalar_znx_dft(1);
|
||||
let tmp: usize = self.bytes_of_vec_znx_dft(cols_out, size);
|
||||
xpm1_dft + (xpm1 | 2 * tmp)
|
||||
}
|
||||
}
|
||||
|
||||
impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
fn mat_znx_dft_mul_x_pow_minus_one<R, A>(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch)
|
||||
where
|
||||
R: MatZnxToMut<FFT64>,
|
||||
A: MatZnxToRef<FFT64>,
|
||||
{
|
||||
let mut res: MatZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: MatZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.rows(), a.rows());
|
||||
assert_eq!(res.cols_in(), a.cols_in());
|
||||
assert_eq!(res.cols_out(), a.cols_out());
|
||||
}
|
||||
|
||||
let (mut xpm1_dft, scratch1) = scratch.tmp_scalar_znx_dft(self, 1);
|
||||
|
||||
{
|
||||
let (mut xpm1, _) = scratch1.tmp_scalar_znx(self, 1);
|
||||
xpm1.data[0] = 1;
|
||||
self.vec_znx_rotate_inplace(k, &mut xpm1, 0);
|
||||
self.svp_prepare(&mut xpm1_dft, 0, &xpm1, 0);
|
||||
}
|
||||
|
||||
let (mut tmp_0, scratch2) = scratch1.tmp_vec_znx_dft(self, res.cols_out(), res.size());
|
||||
let (mut tmp_1, _) = scratch2.tmp_vec_znx_dft(self, res.cols_out(), res.size());
|
||||
|
||||
(0..res.rows()).for_each(|row_i| {
|
||||
(0..res.cols_in()).for_each(|col_j| {
|
||||
self.mat_znx_dft_get_row(&mut tmp_0, &a, row_i, col_j);
|
||||
|
||||
(0..tmp_0.cols()).for_each(|i| {
|
||||
self.svp_apply(&mut tmp_1, i, &xpm1_dft, 0, &tmp_0, i);
|
||||
self.vec_znx_dft_sub_ab_inplace(&mut tmp_1, i, &tmp_0, i);
|
||||
});
|
||||
|
||||
self.mat_znx_dft_set_row(&mut res, row_i, col_j, &tmp_1);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn mat_znx_dft_mul_x_pow_minus_one_inplace<A>(&self, k: i64, a: &mut A, scratch: &mut Scratch)
|
||||
where
|
||||
A: MatZnxToMut<FFT64>,
|
||||
{
|
||||
let mut a: MatZnxDft<&mut [u8], FFT64> = a.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
}
|
||||
|
||||
let (mut xpm1_dft, scratch1) = scratch.tmp_scalar_znx_dft(self, 1);
|
||||
|
||||
{
|
||||
let (mut xpm1, _) = scratch1.tmp_scalar_znx(self, 1);
|
||||
xpm1.data[0] = 1;
|
||||
self.vec_znx_rotate_inplace(k, &mut xpm1, 0);
|
||||
self.svp_prepare(&mut xpm1_dft, 0, &xpm1, 0);
|
||||
}
|
||||
|
||||
let (mut tmp_0, scratch2) = scratch1.tmp_vec_znx_dft(self, a.cols_out(), a.size());
|
||||
let (mut tmp_1, _) = scratch2.tmp_vec_znx_dft(self, a.cols_out(), a.size());
|
||||
|
||||
(0..a.rows()).for_each(|row_i| {
|
||||
(0..a.cols_in()).for_each(|col_j| {
|
||||
self.mat_znx_dft_get_row(&mut tmp_0, &a, row_i, col_j);
|
||||
|
||||
(0..tmp_0.cols()).for_each(|i| {
|
||||
self.svp_apply(&mut tmp_1, i, &xpm1_dft, 0, &tmp_0, i);
|
||||
self.vec_znx_dft_sub_ab_inplace(&mut tmp_1, i, &tmp_0, i);
|
||||
});
|
||||
|
||||
self.mat_znx_dft_set_row(&mut a, row_i, col_j, &tmp_1);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn mat_znx_dft_mul_x_pow_minus_one_add_inplace<R, A>(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch)
|
||||
where
|
||||
R: MatZnxToMut<FFT64>,
|
||||
A: MatZnxToRef<FFT64>,
|
||||
{
|
||||
let mut res: MatZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: MatZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
}
|
||||
|
||||
let (mut xpm1_dft, scratch1) = scratch.tmp_scalar_znx_dft(self, 1);
|
||||
|
||||
{
|
||||
let (mut xpm1, _) = scratch1.tmp_scalar_znx(self, 1);
|
||||
xpm1.data[0] = 1;
|
||||
self.vec_znx_rotate_inplace(k, &mut xpm1, 0);
|
||||
self.svp_prepare(&mut xpm1_dft, 0, &xpm1, 0);
|
||||
}
|
||||
|
||||
let (mut tmp_0, scratch2) = scratch1.tmp_vec_znx_dft(self, a.cols_out(), a.size());
|
||||
let (mut tmp_1, _) = scratch2.tmp_vec_znx_dft(self, a.cols_out(), a.size());
|
||||
|
||||
(0..a.rows()).for_each(|row_i| {
|
||||
(0..a.cols_in()).for_each(|col_j| {
|
||||
self.mat_znx_dft_get_row(&mut tmp_0, &a, row_i, col_j);
|
||||
|
||||
(0..tmp_0.cols()).for_each(|i| {
|
||||
self.svp_apply(&mut tmp_1, i, &xpm1_dft, 0, &tmp_0, i);
|
||||
self.vec_znx_dft_sub_ab_inplace(&mut tmp_1, i, &tmp_0, i);
|
||||
});
|
||||
|
||||
self.mat_znx_dft_get_row(&mut tmp_0, &res, row_i, col_j);
|
||||
|
||||
(0..tmp_0.cols()).for_each(|i| {
|
||||
self.vec_znx_dft_add_inplace(&mut tmp_0, i, &tmp_1, i);
|
||||
});
|
||||
|
||||
self.mat_znx_dft_set_row(&mut res, row_i, col_j, &tmp_0);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn mat_znx_dft_set_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
|
||||
where
|
||||
R: MatZnxToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res: MatZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], _> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(
|
||||
a.cols(),
|
||||
res.cols_out(),
|
||||
"a.cols(): {} != res.cols_out(): {}",
|
||||
a.cols(),
|
||||
res.cols_out()
|
||||
);
|
||||
assert!(
|
||||
res_row < res.rows(),
|
||||
"res_row: {} >= res.rows(): {}",
|
||||
res_row,
|
||||
res.rows()
|
||||
);
|
||||
assert!(
|
||||
res_col_in < res.cols_in(),
|
||||
"res_col_in: {} >= res.cols_in(): {}",
|
||||
res_col_in,
|
||||
res.cols_in()
|
||||
);
|
||||
assert_eq!(
|
||||
res.size(),
|
||||
a.size(),
|
||||
"res.size(): {} != a.size(): {}",
|
||||
res.size(),
|
||||
a.size()
|
||||
);
|
||||
}
|
||||
|
||||
unsafe {
|
||||
vmp::vmp_prepare_row_dft(
|
||||
self.ptr,
|
||||
res.as_mut_ptr() as *mut vmp::vmp_pmat_t,
|
||||
a.as_ptr() as *const vec_znx_dft_t,
|
||||
(res_row * res.cols_in() + res_col_in) as u64,
|
||||
(res.rows() * res.cols_in()) as u64,
|
||||
(res.size() * res.cols_out()) as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn mat_znx_dft_get_row<R, A>(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: MatZnxToRef<FFT64>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: MatZnxDft<&[u8], _> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(
|
||||
res.cols(),
|
||||
a.cols_out(),
|
||||
"res.cols(): {} != a.cols_out(): {}",
|
||||
res.cols(),
|
||||
a.cols_out()
|
||||
);
|
||||
assert!(
|
||||
a_row < a.rows(),
|
||||
"a_row: {} >= a.rows(): {}",
|
||||
a_row,
|
||||
a.rows()
|
||||
);
|
||||
assert!(
|
||||
a_col_in < a.cols_in(),
|
||||
"a_col_in: {} >= a.cols_in(): {}",
|
||||
a_col_in,
|
||||
a.cols_in()
|
||||
);
|
||||
assert_eq!(
|
||||
res.size(),
|
||||
a.size(),
|
||||
"res.size(): {} != a.size(): {}",
|
||||
res.size(),
|
||||
a.size()
|
||||
);
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_extract_row_dft(
|
||||
self.ptr,
|
||||
res.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
a.as_ptr() as *const vmp::vmp_pmat_t,
|
||||
(a_row * a.cols_in() + a_col_in) as u64,
|
||||
(a.rows() * a.cols_in()) as u64,
|
||||
(a.size() * a.cols_out()) as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
B: MatZnxToRef<FFT64>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], _> = a.to_ref();
|
||||
let b: MatZnxDft<&[u8], _> = b.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(
|
||||
res.cols(),
|
||||
b.cols_out(),
|
||||
"res.cols(): {} != b.cols_out: {}",
|
||||
res.cols(),
|
||||
b.cols_out()
|
||||
);
|
||||
assert_eq!(
|
||||
a.cols(),
|
||||
b.cols_in(),
|
||||
"a.cols(): {} != b.cols_in: {}",
|
||||
a.cols(),
|
||||
b.cols_in()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.tmp_slice(self.vmp_apply_tmp_bytes(
|
||||
res.size(),
|
||||
a.size(),
|
||||
b.rows(),
|
||||
b.cols_in(),
|
||||
b.cols_out(),
|
||||
b.size(),
|
||||
));
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft(
|
||||
self.ptr,
|
||||
res.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
(res.size() * res.cols()) as u64,
|
||||
a.as_ptr() as *const vec_znx_dft_t,
|
||||
(a.size() * a.cols()) as u64,
|
||||
b.as_ptr() as *const vmp::vmp_pmat_t,
|
||||
(b.rows() * b.cols_in()) as u64,
|
||||
(b.size() * b.cols_out()) as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scale: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
B: MatZnxToRef<FFT64>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], _> = a.to_ref();
|
||||
let b: MatZnxDft<&[u8], _> = b.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(
|
||||
res.cols(),
|
||||
b.cols_out(),
|
||||
"res.cols(): {} != b.cols_out: {}",
|
||||
res.cols(),
|
||||
b.cols_out()
|
||||
);
|
||||
assert_eq!(
|
||||
a.cols(),
|
||||
b.cols_in(),
|
||||
"a.cols(): {} != b.cols_in: {}",
|
||||
a.cols(),
|
||||
b.cols_in()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.tmp_slice(self.vmp_apply_tmp_bytes(
|
||||
res.size(),
|
||||
a.size(),
|
||||
b.rows(),
|
||||
b.cols_in(),
|
||||
b.cols_out(),
|
||||
b.size(),
|
||||
));
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_add(
|
||||
self.ptr,
|
||||
res.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
(res.size() * res.cols()) as u64,
|
||||
a.as_ptr() as *const vec_znx_dft_t,
|
||||
(a.size() * a.cols()) as u64,
|
||||
b.as_ptr() as *const vmp::vmp_pmat_t,
|
||||
(b.rows() * b.cols_in()) as u64,
|
||||
(b.size() * b.cols_out()) as u64,
|
||||
(scale * b.cols_out()) as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
Decoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig,
|
||||
VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxInfos, ZnxView,
|
||||
ZnxViewMut, ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use super::{MatZnxDftAlloc, MatZnxDftScratch};
|
||||
|
||||
#[test]
|
||||
fn vmp_set_row() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(16);
|
||||
let basek: usize = 8;
|
||||
let mat_rows: usize = 4;
|
||||
let mat_cols_in: usize = 2;
|
||||
let mat_cols_out: usize = 2;
|
||||
let mat_size: usize = 5;
|
||||
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
|
||||
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
|
||||
let mut b_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
|
||||
let mut mat: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
|
||||
|
||||
for col_in in 0..mat_cols_in {
|
||||
for row_i in 0..mat_rows {
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
(0..mat_cols_out).for_each(|col_out| {
|
||||
a.fill_uniform(basek, col_out, mat_size, &mut source);
|
||||
module.vec_znx_dft(1, 0, &mut a_dft, col_out, &a, col_out);
|
||||
});
|
||||
module.mat_znx_dft_set_row(&mut mat, row_i, col_in, &a_dft);
|
||||
module.mat_znx_dft_get_row(&mut b_dft, &mat, row_i, col_in);
|
||||
assert_eq!(a_dft.raw(), b_dft.raw());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vmp_apply() {
|
||||
let log_n: i32 = 5;
|
||||
let n: usize = 1 << log_n;
|
||||
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 15;
|
||||
let a_size: usize = 5;
|
||||
let mat_size: usize = 6;
|
||||
let res_size: usize = a_size;
|
||||
|
||||
[1, 2].iter().for_each(|cols_in| {
|
||||
[1, 2].iter().for_each(|cols_out| {
|
||||
let a_cols: usize = *cols_in;
|
||||
let res_cols: usize = *cols_out;
|
||||
|
||||
let mat_rows: usize = a_size;
|
||||
let mat_cols_in: usize = a_cols;
|
||||
let mat_cols_out: usize = res_cols;
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
module.vmp_apply_tmp_bytes(
|
||||
res_size,
|
||||
a_size,
|
||||
mat_rows,
|
||||
mat_cols_in,
|
||||
mat_cols_out,
|
||||
mat_size,
|
||||
) | module.vec_znx_big_normalize_tmp_bytes(),
|
||||
);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
|
||||
|
||||
(0..a_cols).for_each(|i| {
|
||||
a.at_mut(i, a_size - 1)[i + 1] = 1;
|
||||
});
|
||||
|
||||
let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
|
||||
module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
|
||||
|
||||
let mut c_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
|
||||
let mut c_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
|
||||
|
||||
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
|
||||
|
||||
// Construts a [VecZnxMatDft] that performs cyclic rotations on each submatrix.
|
||||
(0..a.size()).for_each(|row_i| {
|
||||
(0..mat_cols_in).for_each(|col_in_i| {
|
||||
(0..mat_cols_out).for_each(|col_out_i| {
|
||||
let idx = 1 + col_in_i * mat_cols_out + col_out_i;
|
||||
tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx}
|
||||
module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i);
|
||||
tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64;
|
||||
});
|
||||
module.mat_znx_dft_set_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
|
||||
});
|
||||
});
|
||||
|
||||
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, a_size);
|
||||
(0..a_cols).for_each(|i| {
|
||||
module.vec_znx_dft(1, 0, &mut a_dft, i, &a, i);
|
||||
});
|
||||
|
||||
module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow());
|
||||
|
||||
let mut res_have_vi64: Vec<i64> = vec![i64::default(); n];
|
||||
|
||||
let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, res_size);
|
||||
(0..mat_cols_out).for_each(|i| {
|
||||
module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
|
||||
module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow());
|
||||
});
|
||||
|
||||
(0..mat_cols_out).for_each(|col_i| {
|
||||
let mut res_want_vi64: Vec<i64> = vec![i64::default(); n];
|
||||
(0..a_cols).for_each(|i| {
|
||||
res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1;
|
||||
});
|
||||
res_have.decode_vec_i64(col_i, basek, basek * a_size, &mut res_have_vi64);
|
||||
assert_eq!(res_have_vi64, res_want_vi64);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vmp_apply_add() {
|
||||
let log_n: i32 = 4;
|
||||
let n: usize = 1 << log_n;
|
||||
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 8;
|
||||
let a_size: usize = 5;
|
||||
let mat_size: usize = 5;
|
||||
let res_size: usize = a_size;
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
[1, 2].iter().for_each(|cols_in| {
|
||||
[1, 2].iter().for_each(|cols_out| {
|
||||
(0..res_size).for_each(|shift| {
|
||||
let a_cols: usize = *cols_in;
|
||||
let res_cols: usize = *cols_out;
|
||||
|
||||
let mat_rows: usize = a_size;
|
||||
let mat_cols_in: usize = a_cols;
|
||||
let mat_cols_out: usize = res_cols;
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
module.vmp_apply_tmp_bytes(
|
||||
res_size,
|
||||
a_size,
|
||||
mat_rows,
|
||||
mat_cols_in,
|
||||
mat_cols_out,
|
||||
mat_size,
|
||||
) | module.vec_znx_big_normalize_tmp_bytes(),
|
||||
);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
|
||||
|
||||
(0..a_cols).for_each(|col_i| {
|
||||
a.fill_uniform(basek, col_i, a.size(), &mut source);
|
||||
});
|
||||
|
||||
let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
|
||||
module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
|
||||
|
||||
let mut c_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
|
||||
let mut c_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
|
||||
|
||||
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
|
||||
|
||||
// Construts a [VecZnxMatDft] that performs cyclic rotations on each submatrix.
|
||||
(0..a.size()).for_each(|row_i| {
|
||||
(0..mat_cols_in).for_each(|col_in_i| {
|
||||
(0..mat_cols_out).for_each(|col_out_i| {
|
||||
let idx: usize = 1 + col_in_i * mat_cols_out + col_out_i;
|
||||
tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx}
|
||||
module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i);
|
||||
tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64;
|
||||
});
|
||||
module.mat_znx_dft_set_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
|
||||
});
|
||||
});
|
||||
|
||||
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, a_size);
|
||||
(0..a_cols).for_each(|i| {
|
||||
module.vec_znx_dft(1, 0, &mut a_dft, i, &a, i);
|
||||
});
|
||||
|
||||
c_dft.zero();
|
||||
(0..c_dft.cols()).for_each(|i| {
|
||||
module.vec_znx_dft(1, 0, &mut c_dft, i, &a, 0);
|
||||
});
|
||||
|
||||
module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, shift, scratch.borrow());
|
||||
|
||||
let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
|
||||
(0..mat_cols_out).for_each(|i| {
|
||||
module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
|
||||
module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow());
|
||||
});
|
||||
|
||||
let mut res_want: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
|
||||
|
||||
// Equivalent to vmp_add & scale
|
||||
module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow());
|
||||
(0..mat_cols_out).for_each(|i| {
|
||||
module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
|
||||
module.vec_znx_big_normalize(basek, &mut res_want, i, &c_big, i, scratch.borrow());
|
||||
});
|
||||
module.vec_znx_shift_inplace(
|
||||
basek,
|
||||
(shift * basek) as i64,
|
||||
&mut res_want,
|
||||
scratch.borrow(),
|
||||
);
|
||||
(0..res_cols).for_each(|i| {
|
||||
module.vec_znx_add_inplace(&mut res_want, i, &a, 0);
|
||||
module.vec_znx_normalize_inplace(basek, &mut res_want, i, scratch.borrow());
|
||||
});
|
||||
|
||||
assert_eq!(res_want, res_have);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vmp_apply_digits() {
|
||||
let log_n: i32 = 4;
|
||||
let n: usize = 1 << log_n;
|
||||
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 8;
|
||||
let a_size: usize = 6;
|
||||
let mat_size: usize = 6;
|
||||
let res_size: usize = a_size;
|
||||
|
||||
[1, 2].iter().for_each(|cols_in| {
|
||||
[1, 2].iter().for_each(|cols_out| {
|
||||
[1, 3, 6].iter().for_each(|digits| {
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let a_cols: usize = *cols_in;
|
||||
let res_cols: usize = *cols_out;
|
||||
|
||||
let mat_rows: usize = a_size;
|
||||
let mat_cols_in: usize = a_cols;
|
||||
let mat_cols_out: usize = res_cols;
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
module.vmp_apply_tmp_bytes(
|
||||
res_size,
|
||||
a_size,
|
||||
mat_rows,
|
||||
mat_cols_in,
|
||||
mat_cols_out,
|
||||
mat_size,
|
||||
) | module.vec_znx_big_normalize_tmp_bytes(),
|
||||
);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
|
||||
|
||||
(0..a_cols).for_each(|col_i| {
|
||||
a.fill_uniform(basek, col_i, a.size(), &mut source);
|
||||
});
|
||||
|
||||
let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
|
||||
module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
|
||||
|
||||
let mut c_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
|
||||
let mut c_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
|
||||
|
||||
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
|
||||
|
||||
let rows: usize = a.size() / digits;
|
||||
|
||||
let shift: usize = 1;
|
||||
|
||||
// Construts a [VecZnxMatDft] that performs cyclic rotations on each submatrix.
|
||||
(0..rows).for_each(|row_i| {
|
||||
(0..mat_cols_in).for_each(|col_in_i| {
|
||||
(0..mat_cols_out).for_each(|col_out_i| {
|
||||
let idx: usize = shift + col_in_i * mat_cols_out + col_out_i;
|
||||
let limb: usize = (digits - 1) + row_i * digits;
|
||||
tmp.at_mut(col_out_i, limb)[idx] = 1 as i64; // X^{idx}
|
||||
module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i);
|
||||
tmp.at_mut(col_out_i, limb)[idx] = 0 as i64;
|
||||
});
|
||||
module.mat_znx_dft_set_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
|
||||
});
|
||||
});
|
||||
|
||||
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, (a_size + digits - 1) / digits);
|
||||
|
||||
(0..*digits).for_each(|di| {
|
||||
(0..a_cols).for_each(|col_i| {
|
||||
module.vec_znx_dft(*digits, digits - 1 - di, &mut a_dft, col_i, &a, col_i);
|
||||
});
|
||||
|
||||
if di == 0 {
|
||||
module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow());
|
||||
} else {
|
||||
module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, di, scratch.borrow());
|
||||
}
|
||||
});
|
||||
|
||||
let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
|
||||
(0..mat_cols_out).for_each(|i| {
|
||||
module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
|
||||
module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow());
|
||||
});
|
||||
|
||||
let mut res_want: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
|
||||
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, mat_size);
|
||||
(0..res_cols).for_each(|col_i| {
|
||||
(0..a_cols).for_each(|j| {
|
||||
module.vec_znx_rotate(
|
||||
(col_i + j * mat_cols_out + shift) as i64,
|
||||
&mut tmp,
|
||||
0,
|
||||
&a,
|
||||
j,
|
||||
);
|
||||
module.vec_znx_add_inplace(&mut res_want, col_i, &tmp, 0);
|
||||
});
|
||||
module.vec_znx_normalize_inplace(basek, &mut res_want, col_i, scratch.borrow());
|
||||
});
|
||||
|
||||
assert_eq!(res_have, res_want)
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mat_znx_dft_mul_x_pow_minus_one() {
|
||||
let log_n: i32 = 5;
|
||||
let n: usize = 1 << log_n;
|
||||
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 8;
|
||||
let rows: usize = 2;
|
||||
let cols_in: usize = 2;
|
||||
let cols_out: usize = 2;
|
||||
let size: usize = 4;
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(module.mat_znx_dft_mul_x_pow_minus_one_scratch_space(size, cols_out));
|
||||
|
||||
let mut mat_want: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size);
|
||||
let mut mat_have: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size);
|
||||
|
||||
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(1, size);
|
||||
let mut tmp_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(cols_out, size);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
(0..mat_want.rows()).for_each(|row_i| {
|
||||
(0..mat_want.cols_in()).for_each(|col_i| {
|
||||
(0..cols_out).for_each(|j| {
|
||||
tmp.fill_uniform(basek, 0, size, &mut source);
|
||||
module.vec_znx_dft(1, 0, &mut tmp_dft, j, &tmp, 0);
|
||||
});
|
||||
|
||||
module.mat_znx_dft_set_row(&mut mat_want, row_i, col_i, &tmp_dft);
|
||||
});
|
||||
});
|
||||
|
||||
let k: i64 = 1;
|
||||
|
||||
module.mat_znx_dft_mul_x_pow_minus_one(k, &mut mat_have, &mat_want, scratch.borrow());
|
||||
|
||||
let mut have: VecZnx<Vec<u8>> = module.new_vec_znx(cols_out, size);
|
||||
let mut want: VecZnx<Vec<u8>> = module.new_vec_znx(cols_out, size);
|
||||
let mut tmp_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, size);
|
||||
|
||||
(0..mat_want.rows()).for_each(|row_i| {
|
||||
(0..mat_want.cols_in()).for_each(|col_i| {
|
||||
module.mat_znx_dft_get_row(&mut tmp_dft, &mat_want, row_i, col_i);
|
||||
|
||||
(0..cols_out).for_each(|j| {
|
||||
module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow());
|
||||
module.vec_znx_big_normalize(basek, &mut tmp, 0, &tmp_big, 0, scratch.borrow());
|
||||
module.vec_znx_rotate(k, &mut want, j, &tmp, 0);
|
||||
module.vec_znx_sub_ab_inplace(&mut want, j, &tmp, 0);
|
||||
module.vec_znx_normalize_inplace(basek, &mut want, j, scratch.borrow());
|
||||
});
|
||||
|
||||
module.mat_znx_dft_get_row(&mut tmp_dft, &mat_have, row_i, col_i);
|
||||
|
||||
(0..cols_out).for_each(|j| {
|
||||
module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow());
|
||||
module.vec_znx_big_normalize(basek, &mut have, j, &tmp_big, 0, scratch.borrow());
|
||||
});
|
||||
|
||||
assert_eq!(have, want)
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mat_znx_dft_mul_x_pow_minus_one_add_inplace() {
|
||||
let log_n: i32 = 5;
|
||||
let n: usize = 1 << log_n;
|
||||
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 8;
|
||||
let rows: usize = 2;
|
||||
let cols_in: usize = 2;
|
||||
let cols_out: usize = 2;
|
||||
let size: usize = 4;
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(module.mat_znx_dft_mul_x_pow_minus_one_scratch_space(size, cols_out));
|
||||
|
||||
let mut mat_want: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size);
|
||||
let mut mat_have: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size);
|
||||
|
||||
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(1, size);
|
||||
let mut tmp_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(cols_out, size);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
(0..mat_have.rows()).for_each(|row_i| {
|
||||
(0..mat_have.cols_in()).for_each(|col_i| {
|
||||
(0..cols_out).for_each(|j| {
|
||||
tmp.fill_uniform(basek, 0, size, &mut source);
|
||||
module.vec_znx_dft(1, 0, &mut tmp_dft, j, &tmp, 0);
|
||||
});
|
||||
|
||||
module.mat_znx_dft_set_row(&mut mat_have, row_i, col_i, &tmp_dft);
|
||||
});
|
||||
});
|
||||
|
||||
(0..mat_want.rows()).for_each(|row_i| {
|
||||
(0..mat_want.cols_in()).for_each(|col_i| {
|
||||
(0..cols_out).for_each(|j| {
|
||||
tmp.fill_uniform(basek, 0, size, &mut source);
|
||||
module.vec_znx_dft(1, 0, &mut tmp_dft, j, &tmp, 0);
|
||||
});
|
||||
|
||||
module.mat_znx_dft_set_row(&mut mat_want, row_i, col_i, &tmp_dft);
|
||||
});
|
||||
});
|
||||
|
||||
let k: i64 = 1;
|
||||
|
||||
module.mat_znx_dft_mul_x_pow_minus_one_add_inplace(k, &mut mat_have, &mat_want, scratch.borrow());
|
||||
|
||||
let mut have: VecZnx<Vec<u8>> = module.new_vec_znx(cols_out, size);
|
||||
let mut want: VecZnx<Vec<u8>> = module.new_vec_znx(cols_out, size);
|
||||
let mut tmp_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, size);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
(0..mat_want.rows()).for_each(|row_i| {
|
||||
(0..mat_want.cols_in()).for_each(|col_i| {
|
||||
module.mat_znx_dft_get_row(&mut tmp_dft, &mat_want, row_i, col_i);
|
||||
|
||||
(0..cols_out).for_each(|j| {
|
||||
module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow());
|
||||
module.vec_znx_big_normalize(basek, &mut tmp, 0, &tmp_big, 0, scratch.borrow());
|
||||
module.vec_znx_rotate(k, &mut want, j, &tmp, 0);
|
||||
module.vec_znx_sub_ab_inplace(&mut want, j, &tmp, 0);
|
||||
|
||||
tmp.fill_uniform(basek, 0, size, &mut source);
|
||||
module.vec_znx_add_inplace(&mut want, j, &tmp, 0);
|
||||
module.vec_znx_normalize_inplace(basek, &mut want, j, scratch.borrow());
|
||||
});
|
||||
|
||||
module.mat_znx_dft_get_row(&mut tmp_dft, &mat_have, row_i, col_i);
|
||||
|
||||
(0..cols_out).for_each(|j| {
|
||||
module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow());
|
||||
module.vec_znx_big_normalize(basek, &mut have, j, &tmp_big, 0, scratch.borrow());
|
||||
});
|
||||
|
||||
assert_eq!(have, want)
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1,365 +0,0 @@
|
||||
use crate::znx_base::ZnxViewMut;
|
||||
use crate::{FFT64, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxToMut};
|
||||
use rand_distr::{Distribution, Normal};
|
||||
use sampling::source::Source;
|
||||
|
||||
pub trait FillUniform {
|
||||
/// Fills the first `size` size with uniform values in \[-2^{basek-1}, 2^{basek-1}\]
|
||||
fn fill_uniform(&mut self, basek: usize, col_i: usize, size: usize, source: &mut Source);
|
||||
}
|
||||
|
||||
pub trait FillDistF64 {
|
||||
fn fill_dist_f64<D: Distribution<f64>>(
|
||||
&mut self,
|
||||
basek: usize,
|
||||
col_i: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
pub trait AddDistF64 {
|
||||
/// Adds vector sampled according to the provided distribution, scaled by 2^{-k} and bounded to \[-bound, bound\].
|
||||
fn add_dist_f64<D: Distribution<f64>>(
|
||||
&mut self,
|
||||
basek: usize,
|
||||
col_i: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
pub trait FillNormal {
|
||||
fn fill_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64);
|
||||
}
|
||||
|
||||
pub trait AddNormal {
|
||||
/// Adds a discrete normal vector scaled by 2^{-k} with the provided standard deviation and bounded to \[-bound, bound\].
|
||||
fn add_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64);
|
||||
}
|
||||
|
||||
impl<T: AsMut<[u8]> + AsRef<[u8]>> FillUniform for VecZnx<T>
|
||||
where
|
||||
VecZnx<T>: VecZnxToMut,
|
||||
{
|
||||
fn fill_uniform(&mut self, basek: usize, col_i: usize, size: usize, source: &mut Source) {
|
||||
let mut a: VecZnx<&mut [u8]> = self.to_mut();
|
||||
let base2k: u64 = 1 << basek;
|
||||
let mask: u64 = base2k - 1;
|
||||
let base2k_half: i64 = (base2k >> 1) as i64;
|
||||
(0..size).for_each(|j| {
|
||||
a.at_mut(col_i, j)
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsMut<[u8]> + AsRef<[u8]>> FillDistF64 for VecZnx<T>
|
||||
where
|
||||
VecZnx<T>: VecZnxToMut,
|
||||
{
|
||||
fn fill_dist_f64<D: Distribution<f64>>(
|
||||
&mut self,
|
||||
basek: usize,
|
||||
col_i: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut a: VecZnx<&mut [u8]> = self.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = (k + basek - 1) / basek - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsMut<[u8]> + AsRef<[u8]>> AddDistF64 for VecZnx<T>
|
||||
where
|
||||
VecZnx<T>: VecZnxToMut,
|
||||
{
|
||||
fn add_dist_f64<D: Distribution<f64>>(
|
||||
&mut self,
|
||||
basek: usize,
|
||||
col_i: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut a: VecZnx<&mut [u8]> = self.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = (k + basek - 1) / basek - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsMut<[u8]> + AsRef<[u8]>> FillNormal for VecZnx<T>
|
||||
where
|
||||
VecZnx<T>: VecZnxToMut,
|
||||
{
|
||||
fn fill_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64) {
|
||||
self.fill_dist_f64(
|
||||
basek,
|
||||
col_i,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsMut<[u8]> + AsRef<[u8]>> AddNormal for VecZnx<T>
|
||||
where
|
||||
VecZnx<T>: VecZnxToMut,
|
||||
{
|
||||
fn add_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64) {
|
||||
self.add_dist_f64(
|
||||
basek,
|
||||
col_i,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsMut<[u8]> + AsRef<[u8]>> FillDistF64 for VecZnxBig<T, FFT64>
|
||||
where
|
||||
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
fn fill_dist_f64<D: Distribution<f64>>(
|
||||
&mut self,
|
||||
basek: usize,
|
||||
col_i: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut a: VecZnxBig<&mut [u8], FFT64> = self.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = (k + basek - 1) / basek - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsMut<[u8]> + AsRef<[u8]>> AddDistF64 for VecZnxBig<T, FFT64>
|
||||
where
|
||||
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
fn add_dist_f64<D: Distribution<f64>>(
|
||||
&mut self,
|
||||
basek: usize,
|
||||
col_i: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut a: VecZnxBig<&mut [u8], FFT64> = self.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = (k + basek - 1) / basek - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsMut<[u8]> + AsRef<[u8]>> FillNormal for VecZnxBig<T, FFT64>
|
||||
where
|
||||
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
fn fill_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64) {
|
||||
self.fill_dist_f64(
|
||||
basek,
|
||||
col_i,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsMut<[u8]> + AsRef<[u8]>> AddNormal for VecZnxBig<T, FFT64>
|
||||
where
|
||||
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
fn add_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64) {
|
||||
self.add_dist_f64(
|
||||
basek,
|
||||
col_i,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{AddNormal, FillUniform};
|
||||
use crate::vec_znx_ops::*;
|
||||
use crate::znx_base::*;
|
||||
use crate::{FFT64, Module, Stats, VecZnx};
|
||||
use sampling::source::Source;
|
||||
|
||||
#[test]
|
||||
fn vec_znx_fill_uniform() {
|
||||
let n: usize = 4096;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 17;
|
||||
let size: usize = 5;
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
let zero: Vec<i64> = vec![0; n];
|
||||
let one_12_sqrt: f64 = 0.28867513459481287;
|
||||
(0..cols).for_each(|col_i| {
|
||||
let mut a: VecZnx<_> = module.new_vec_znx(cols, size);
|
||||
a.fill_uniform(basek, col_i, size, &mut source);
|
||||
(0..cols).for_each(|col_j| {
|
||||
if col_j != col_i {
|
||||
(0..size).for_each(|limb_i| {
|
||||
assert_eq!(a.at(col_j, limb_i), zero);
|
||||
})
|
||||
} else {
|
||||
let std: f64 = a.std(col_i, basek);
|
||||
assert!(
|
||||
(std - one_12_sqrt).abs() < 0.01,
|
||||
"std={} ~!= {}",
|
||||
std,
|
||||
one_12_sqrt
|
||||
);
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vec_znx_add_normal() {
|
||||
let n: usize = 4096;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 17;
|
||||
let k: usize = 2 * 17;
|
||||
let size: usize = 5;
|
||||
let sigma: f64 = 3.2;
|
||||
let bound: f64 = 6.0 * sigma;
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
let zero: Vec<i64> = vec![0; n];
|
||||
let k_f64: f64 = (1u64 << k as u64) as f64;
|
||||
(0..cols).for_each(|col_i| {
|
||||
let mut a: VecZnx<_> = module.new_vec_znx(cols, size);
|
||||
a.add_normal(basek, col_i, k, &mut source, sigma, bound);
|
||||
(0..cols).for_each(|col_j| {
|
||||
if col_j != col_i {
|
||||
(0..size).for_each(|limb_i| {
|
||||
assert_eq!(a.at(col_j, limb_i), zero);
|
||||
})
|
||||
} else {
|
||||
let std: f64 = a.std(col_i, basek) * k_f64;
|
||||
assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma);
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1,180 +0,0 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::ffi::svp;
|
||||
use crate::znx_base::ZnxInfos;
|
||||
use crate::{
|
||||
Backend, DataView, DataViewMut, FFT64, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxSliceSize, ZnxView,
|
||||
alloc_aligned,
|
||||
};
|
||||
|
||||
pub struct ScalarZnxDft<D, B: Backend> {
|
||||
data: D,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
_phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<D, B: Backend> ZnxInfos for ScalarZnxDft<D, B> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> ZnxSliceSize for ScalarZnxDft<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> DataView for ScalarZnxDft<D, B> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> DataViewMut for ScalarZnxDft<D, B> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> ZnxView for ScalarZnxDft<D, FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
pub(crate) fn bytes_of_scalar_znx_dft<B: Backend>(module: &Module<B>, cols: usize) -> usize {
|
||||
ScalarZnxDftOwned::bytes_of(module, cols)
|
||||
}
|
||||
|
||||
impl<D: From<Vec<u8>>, B: Backend> ScalarZnxDft<D, B> {
|
||||
pub(crate) fn bytes_of(module: &Module<B>, cols: usize) -> usize {
|
||||
unsafe { svp::bytes_of_svp_ppol(module.ptr) as usize * cols }
|
||||
}
|
||||
|
||||
pub(crate) fn new(module: &Module<B>, cols: usize) -> Self {
|
||||
let data = alloc_aligned::<u8>(Self::bytes_of(module, cols));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_from_bytes(module: &Module<B>, cols: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == Self::bytes_of(module, cols));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> ScalarZnxDft<D, B> {
|
||||
pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_vec_znx_dft(self) -> VecZnxDft<D, B> {
|
||||
VecZnxDft {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: 1,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type ScalarZnxDftOwned<B> = ScalarZnxDft<Vec<u8>, B>;
|
||||
|
||||
pub trait ScalarZnxDftToRef<B: Backend> {
|
||||
fn to_ref(&self) -> ScalarZnxDft<&[u8], B>;
|
||||
}
|
||||
|
||||
impl<D, B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<D, B>
|
||||
where
|
||||
D: AsRef<[u8]>,
|
||||
B: Backend,
|
||||
{
|
||||
fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
|
||||
ScalarZnxDft {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ScalarZnxDftToMut<B: Backend> {
|
||||
fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B>;
|
||||
}
|
||||
|
||||
impl<D, B: Backend> ScalarZnxDftToMut<B> for ScalarZnxDft<D, B>
|
||||
where
|
||||
D: AsMut<[u8]> + AsRef<[u8]>,
|
||||
B: Backend,
|
||||
{
|
||||
fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> {
|
||||
ScalarZnxDft {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<D, B>
|
||||
where
|
||||
D: AsRef<[u8]>,
|
||||
B: Backend,
|
||||
{
|
||||
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
|
||||
VecZnxDft {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: 1,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> VecZnxDftToMut<B> for ScalarZnxDft<D, B>
|
||||
where
|
||||
D: AsRef<[u8]> + AsMut<[u8]>,
|
||||
B: Backend,
|
||||
{
|
||||
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
|
||||
VecZnxDft {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: 1,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,122 +0,0 @@
|
||||
use crate::ffi::svp;
|
||||
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
|
||||
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
|
||||
use crate::{
|
||||
Backend, FFT64, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToMut,
|
||||
ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps,
|
||||
};
|
||||
|
||||
pub trait ScalarZnxDftAlloc<B: Backend> {
|
||||
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B>;
|
||||
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize;
|
||||
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B>;
|
||||
}
|
||||
|
||||
pub trait ScalarZnxDftOps<BACKEND: Backend> {
|
||||
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: ScalarZnxDftToMut<BACKEND>,
|
||||
A: ScalarZnxToRef;
|
||||
|
||||
fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<BACKEND>,
|
||||
A: ScalarZnxDftToRef<BACKEND>,
|
||||
B: VecZnxDftToRef<BACKEND>;
|
||||
|
||||
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<BACKEND>,
|
||||
A: ScalarZnxDftToRef<BACKEND>;
|
||||
|
||||
fn scalar_znx_idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: ScalarZnxToMut,
|
||||
A: ScalarZnxDftToRef<BACKEND>;
|
||||
}
|
||||
|
||||
impl<B: Backend> ScalarZnxDftAlloc<B> for Module<B> {
|
||||
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B> {
|
||||
ScalarZnxDftOwned::new(self, cols)
|
||||
}
|
||||
|
||||
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize {
|
||||
ScalarZnxDftOwned::bytes_of(self, cols)
|
||||
}
|
||||
|
||||
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B> {
|
||||
ScalarZnxDftOwned::new_from_bytes(self, cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
|
||||
fn scalar_znx_idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: ScalarZnxToMut,
|
||||
A: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
let res_mut: &mut ScalarZnx<&mut [u8]> = &mut res.to_mut();
|
||||
let a_ref: &ScalarZnxDft<&[u8], FFT64> = &a.to_ref();
|
||||
let (mut vec_znx_big, scratch1) = scratch.tmp_vec_znx_big(self, 1, 1);
|
||||
self.vec_znx_idft(&mut vec_znx_big, 0, a_ref, a_col, scratch1);
|
||||
self.vec_znx_copy(res_mut, res_col, &vec_znx_big.to_vec_znx_small(), 0);
|
||||
}
|
||||
|
||||
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: ScalarZnxDftToMut<FFT64>,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
unsafe {
|
||||
svp::svp_prepare(
|
||||
self.ptr,
|
||||
res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t,
|
||||
a.to_ref().at_ptr(a_col, 0),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: ScalarZnxDftToRef<FFT64>,
|
||||
B: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnxDft<&[u8], FFT64> = b.to_ref();
|
||||
unsafe {
|
||||
svp::svp_apply_dft_to_dft(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
|
||||
res.size() as u64,
|
||||
res.cols() as u64,
|
||||
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
|
||||
b.at_ptr(b_col, 0) as *const vec_znx_dft_t,
|
||||
b.size() as u64,
|
||||
b.cols() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
unsafe {
|
||||
svp::svp_apply_dft_to_dft(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
|
||||
res.size() as u64,
|
||||
res.cols() as u64,
|
||||
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
|
||||
res.at_ptr(res_col, 0) as *const vec_znx_dft_t,
|
||||
res.size() as u64,
|
||||
res.cols() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
use crate::znx_base::ZnxInfos;
|
||||
use crate::{Decoding, VecZnx};
|
||||
use rug::Float;
|
||||
use rug::float::Round;
|
||||
use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound};
|
||||
|
||||
pub trait Stats {
|
||||
/// Returns the standard devaition of the i-th polynomial.
|
||||
fn std(&self, col_i: usize, basek: usize) -> f64;
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> Stats for VecZnx<D> {
|
||||
fn std(&self, col_i: usize, basek: usize) -> f64 {
|
||||
let prec: u32 = (self.size() * basek) as u32;
|
||||
let mut data: Vec<Float> = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect();
|
||||
self.decode_vec_float(col_i, basek, &mut data);
|
||||
// std = sqrt(sum((xi - avg)^2) / n)
|
||||
let mut avg: Float = Float::with_val(prec, 0);
|
||||
data.iter().for_each(|x| {
|
||||
avg.add_assign_round(x, Round::Nearest);
|
||||
});
|
||||
avg.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
|
||||
data.iter_mut().for_each(|x| {
|
||||
x.sub_assign_round(&avg, Round::Nearest);
|
||||
});
|
||||
let mut std: Float = Float::with_val(prec, 0);
|
||||
data.iter().for_each(|x| std += x * x);
|
||||
std.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
|
||||
std = std.sqrt();
|
||||
std.to_f64()
|
||||
}
|
||||
}
|
||||
@@ -1,413 +0,0 @@
|
||||
use itertools::izip;
|
||||
|
||||
use crate::DataView;
|
||||
use crate::DataViewMut;
|
||||
use crate::ScalarZnx;
|
||||
use crate::Scratch;
|
||||
use crate::ZnxSliceSize;
|
||||
use crate::ZnxZero;
|
||||
use crate::alloc_aligned;
|
||||
use crate::assert_alignement;
|
||||
use crate::cast_mut;
|
||||
use crate::ffi::znx;
|
||||
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
|
||||
use std::{cmp::min, fmt};
|
||||
|
||||
/// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of
|
||||
/// Zn\[X\] with [i64] coefficients.
|
||||
/// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single contiguous array
|
||||
/// in the memory.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// Given 3 polynomials (a, b, c) of Zn\[X\], each with 4 columns, then the memory
|
||||
/// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci
|
||||
/// are small polynomials of Zn\[X\].
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct VecZnx<D> {
|
||||
pub data: D,
|
||||
pub n: usize,
|
||||
pub cols: usize,
|
||||
pub size: usize,
|
||||
}
|
||||
|
||||
impl<D> fmt::Debug for VecZnx<D>
|
||||
where
|
||||
D: AsRef<[u8]>,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> ZnxInfos for VecZnx<D> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> ZnxSliceSize for VecZnx<D> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> DataView for VecZnx<D> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> DataViewMut for VecZnx<D> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> ZnxView for VecZnx<D> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
impl VecZnx<Vec<u8>> {
|
||||
pub fn rsh_scratch_space(n: usize) -> usize {
|
||||
n * std::mem::size_of::<i64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnx<D> {
|
||||
/// Truncates the precision of the [VecZnx] by k bits.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `basek`: the base two logarithm of the coefficients decomposition.
|
||||
/// * `k`: the number of bits of precision to drop.
|
||||
pub fn trunc_pow2(&mut self, basek: usize, k: usize, col: usize) {
|
||||
if k == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
self.size -= k / basek;
|
||||
|
||||
let k_rem: usize = k % basek;
|
||||
|
||||
if k_rem != 0 {
|
||||
let mask: i64 = ((1 << (basek - k_rem - 1)) - 1) << k_rem;
|
||||
self.at_mut(col, self.size() - 1)
|
||||
.iter_mut()
|
||||
.for_each(|x: &mut i64| *x &= mask)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rotate(&mut self, k: i64) {
|
||||
unsafe {
|
||||
(0..self.cols()).for_each(|i| {
|
||||
(0..self.size()).for_each(|j| {
|
||||
znx::znx_rotate_inplace_i64(self.n() as u64, k, self.at_mut_ptr(i, j));
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rsh(&mut self, basek: usize, k: usize, scratch: &mut Scratch) {
|
||||
let n: usize = self.n();
|
||||
let cols: usize = self.cols();
|
||||
let size: usize = self.size();
|
||||
let steps: usize = k / basek;
|
||||
|
||||
self.raw_mut().rotate_right(n * steps * cols);
|
||||
(0..cols).for_each(|i| {
|
||||
(0..steps).for_each(|j| {
|
||||
self.zero_at(i, j);
|
||||
})
|
||||
});
|
||||
|
||||
let k_rem: usize = k % basek;
|
||||
|
||||
if k_rem != 0 {
|
||||
let (carry, _) = scratch.tmp_slice::<i64>(n);
|
||||
let shift = i64::BITS as usize - k_rem;
|
||||
(0..cols).for_each(|i| {
|
||||
carry.fill(0);
|
||||
(steps..size).for_each(|j| {
|
||||
izip!(carry.iter_mut(), self.at_mut(i, j).iter_mut()).for_each(|(ci, xi)| {
|
||||
*xi += *ci << basek;
|
||||
*ci = (*xi << shift) >> shift;
|
||||
*xi = (*xi - *ci) >> k_rem;
|
||||
});
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn lsh(&mut self, basek: usize, k: usize, scratch: &mut Scratch) {
|
||||
let n: usize = self.n();
|
||||
let cols: usize = self.cols();
|
||||
let size: usize = self.size();
|
||||
let steps: usize = k / basek;
|
||||
|
||||
self.raw_mut().rotate_left(n * steps * cols);
|
||||
(0..cols).for_each(|i| {
|
||||
(size - steps..size).for_each(|j| {
|
||||
self.zero_at(i, j);
|
||||
})
|
||||
});
|
||||
|
||||
let k_rem: usize = k % basek;
|
||||
|
||||
if k_rem != 0 {
|
||||
let shift: usize = i64::BITS as usize - k_rem;
|
||||
let (tmp_bytes, _) = scratch.tmp_slice::<u8>(n * size_of::<i64>());
|
||||
(0..cols).for_each(|i| {
|
||||
(0..steps).for_each(|j| {
|
||||
self.at_mut(i, j).iter_mut().for_each(|xi| {
|
||||
*xi <<= shift;
|
||||
});
|
||||
});
|
||||
normalize(basek, self, i, tmp_bytes);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: From<Vec<u8>>> VecZnx<D> {
|
||||
pub(crate) fn bytes_of<Scalar: Sized>(n: usize, cols: usize, size: usize) -> usize {
|
||||
n * cols * size * size_of::<Scalar>()
|
||||
}
|
||||
|
||||
pub fn new<Scalar: Sized>(n: usize, cols: usize, size: usize) -> Self {
|
||||
let data = alloc_aligned::<u8>(Self::bytes_of::<Scalar>(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_from_bytes<Scalar: Sized>(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == Self::bytes_of::<Scalar>(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> VecZnx<D> {
|
||||
pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_scalar_znx(self) -> ScalarZnx<D> {
|
||||
debug_assert_eq!(
|
||||
self.size, 1,
|
||||
"cannot convert VecZnx to ScalarZnx if cols: {} != 1",
|
||||
self.cols
|
||||
);
|
||||
ScalarZnx {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Copies the coefficients of `a` on the receiver.
|
||||
/// Copy is done with the minimum size matching both backing arrays.
|
||||
/// Panics if the cols do not match.
|
||||
pub fn copy_vec_znx_from<DataMut, Data>(b: &mut VecZnx<DataMut>, a: &VecZnx<Data>)
|
||||
where
|
||||
DataMut: AsMut<[u8]> + AsRef<[u8]>,
|
||||
Data: AsRef<[u8]>,
|
||||
{
|
||||
assert_eq!(b.cols(), a.cols());
|
||||
let data_a: &[i64] = a.raw();
|
||||
let data_b: &mut [i64] = b.raw_mut();
|
||||
let size = min(data_b.len(), data_a.len());
|
||||
data_b[..size].copy_from_slice(&data_a[..size])
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn normalize_tmp_bytes(n: usize) -> usize {
|
||||
n * std::mem::size_of::<i64>()
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]> + AsMut<[u8]>> VecZnx<D> {
|
||||
pub fn normalize(&mut self, basek: usize, a_col: usize, tmp_bytes: &mut [u8]) {
|
||||
normalize(basek, self, a_col, tmp_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize<D: AsMut<[u8]> + AsRef<[u8]>>(basek: usize, a: &mut VecZnx<D>, a_col: usize, tmp_bytes: &mut [u8]) {
|
||||
let n: usize = a.n();
|
||||
|
||||
debug_assert!(
|
||||
tmp_bytes.len() >= normalize_tmp_bytes(n),
|
||||
"invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({})",
|
||||
tmp_bytes.len(),
|
||||
n,
|
||||
);
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr())
|
||||
}
|
||||
|
||||
let carry_i64: &mut [i64] = cast_mut(tmp_bytes);
|
||||
|
||||
unsafe {
|
||||
znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr());
|
||||
(0..a.size()).rev().for_each(|i| {
|
||||
znx::znx_normalize(
|
||||
n as u64,
|
||||
basek as u64,
|
||||
a.at_mut_ptr(a_col, i),
|
||||
carry_i64.as_mut_ptr(),
|
||||
a.at_mut_ptr(a_col, i),
|
||||
carry_i64.as_mut_ptr(),
|
||||
)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnx<D>
|
||||
where
|
||||
VecZnx<D>: VecZnxToMut + ZnxInfos,
|
||||
{
|
||||
/// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self].
|
||||
pub fn extract_column<R>(&mut self, self_col: usize, a: &VecZnx<R>, a_col: usize)
|
||||
where
|
||||
R: AsRef<[u8]>,
|
||||
VecZnx<R>: VecZnxToRef + ZnxInfos,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(self_col < self.cols());
|
||||
assert!(a_col < a.cols());
|
||||
}
|
||||
|
||||
let min_size: usize = self.size.min(a.size());
|
||||
let max_size: usize = self.size;
|
||||
|
||||
let mut self_mut: VecZnx<&mut [u8]> = self.to_mut();
|
||||
let a_ref: VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
(0..min_size).for_each(|i: usize| {
|
||||
self_mut
|
||||
.at_mut(self_col, i)
|
||||
.copy_from_slice(a_ref.at(a_col, i));
|
||||
});
|
||||
|
||||
(min_size..max_size).for_each(|i| {
|
||||
self_mut.zero_at(self_col, i);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> fmt::Display for VecZnx<D> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"VecZnx(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {}:", col)?;
|
||||
for size in 0..self.size {
|
||||
let coeffs = self.at(col, size);
|
||||
write!(f, " Size {}: [", size)?;
|
||||
|
||||
let max_show = 100;
|
||||
let show_count = coeffs.len().min(max_show);
|
||||
|
||||
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "{}", coeff)?;
|
||||
}
|
||||
|
||||
if coeffs.len() > max_show {
|
||||
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
|
||||
}
|
||||
|
||||
writeln!(f, "]")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub type VecZnxOwned = VecZnx<Vec<u8>>;
|
||||
pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>;
|
||||
pub type VecZnxRef<'a> = VecZnx<&'a [u8]>;
|
||||
|
||||
pub trait VecZnxToRef {
|
||||
fn to_ref(&self) -> VecZnx<&[u8]>;
|
||||
}
|
||||
|
||||
impl<D> VecZnxToRef for VecZnx<D>
|
||||
where
|
||||
D: AsRef<[u8]>,
|
||||
{
|
||||
fn to_ref(&self) -> VecZnx<&[u8]> {
|
||||
VecZnx {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxToMut {
|
||||
fn to_mut(&mut self) -> VecZnx<&mut [u8]>;
|
||||
}
|
||||
|
||||
impl<D> VecZnxToMut for VecZnx<D>
|
||||
where
|
||||
D: AsRef<[u8]> + AsMut<[u8]>,
|
||||
{
|
||||
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
|
||||
VecZnx {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf: AsRef<[u8]>> VecZnx<DataSelf> {
|
||||
pub fn clone(&self) -> VecZnx<Vec<u8>> {
|
||||
let self_ref: VecZnx<&[u8]> = self.to_ref();
|
||||
VecZnx {
|
||||
data: self_ref.data.to_vec(),
|
||||
n: self_ref.n,
|
||||
cols: self_ref.cols,
|
||||
size: self_ref.size,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,216 +0,0 @@
|
||||
use crate::ffi::vec_znx_big;
|
||||
use crate::znx_base::{ZnxInfos, ZnxView};
|
||||
use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnx, ZnxSliceSize, ZnxViewMut, ZnxZero, alloc_aligned};
|
||||
use std::fmt;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
pub struct VecZnxBig<D, B: Backend> {
|
||||
data: D,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
_phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<D, B: Backend> ZnxInfos for VecZnxBig<D, B> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> ZnxSliceSize for VecZnxBig<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> DataView for VecZnxBig<D, B> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> DataViewMut for VecZnxBig<D, B> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> ZnxView for VecZnxBig<D, FFT64> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
pub(crate) fn bytes_of_vec_znx_big<B: Backend>(module: &Module<B>, cols: usize, size: usize) -> usize {
|
||||
unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols }
|
||||
}
|
||||
|
||||
impl<D: From<Vec<u8>>, B: Backend> VecZnxBig<D, B> {
|
||||
pub(crate) fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
|
||||
let data = alloc_aligned::<u8>(bytes_of_vec_znx_big(module, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
cols,
|
||||
size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == bytes_of_vec_znx_big(module, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
cols,
|
||||
size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> VecZnxBig<D, B> {
|
||||
pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnxBig<D, FFT64>
|
||||
where
|
||||
VecZnxBig<D, FFT64>: VecZnxBigToMut<FFT64> + ZnxInfos,
|
||||
{
|
||||
// Consumes the VecZnxBig to return a VecZnx.
|
||||
// Useful when no normalization is needed.
|
||||
pub fn to_vec_znx_small(self) -> VecZnx<D> {
|
||||
VecZnx {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self].
|
||||
pub fn extract_column<C>(&mut self, self_col: usize, a: &C, a_col: usize)
|
||||
where
|
||||
C: VecZnxBigToRef<FFT64> + ZnxInfos,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(self_col < self.cols());
|
||||
assert!(a_col < a.cols());
|
||||
}
|
||||
|
||||
let min_size: usize = self.size.min(a.size());
|
||||
let max_size: usize = self.size;
|
||||
|
||||
let mut self_mut: VecZnxBig<&mut [u8], FFT64> = self.to_mut();
|
||||
let a_ref: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
(0..min_size).for_each(|i: usize| {
|
||||
self_mut
|
||||
.at_mut(self_col, i)
|
||||
.copy_from_slice(a_ref.at(a_col, i));
|
||||
});
|
||||
|
||||
(min_size..max_size).for_each(|i| {
|
||||
self_mut.zero_at(self_col, i);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub type VecZnxBigOwned<B> = VecZnxBig<Vec<u8>, B>;
|
||||
|
||||
pub trait VecZnxBigToRef<B: Backend> {
|
||||
fn to_ref(&self) -> VecZnxBig<&[u8], B>;
|
||||
}
|
||||
|
||||
impl<D, B: Backend> VecZnxBigToRef<B> for VecZnxBig<D, B>
|
||||
where
|
||||
D: AsRef<[u8]>,
|
||||
B: Backend,
|
||||
{
|
||||
fn to_ref(&self) -> VecZnxBig<&[u8], B> {
|
||||
VecZnxBig {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxBigToMut<B: Backend> {
|
||||
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B>;
|
||||
}
|
||||
|
||||
impl<D, B: Backend> VecZnxBigToMut<B> for VecZnxBig<D, B>
|
||||
where
|
||||
D: AsRef<[u8]> + AsMut<[u8]>,
|
||||
B: Backend,
|
||||
{
|
||||
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
|
||||
VecZnxBig {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> fmt::Display for VecZnxBig<D, FFT64> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"VecZnxBig(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {}:", col)?;
|
||||
for size in 0..self.size {
|
||||
let coeffs = self.at(col, size);
|
||||
write!(f, " Size {}: [", size)?;
|
||||
|
||||
let max_show = 100;
|
||||
let show_count = coeffs.len().min(max_show);
|
||||
|
||||
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "{}", coeff)?;
|
||||
}
|
||||
|
||||
if coeffs.len() > max_show {
|
||||
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
|
||||
}
|
||||
|
||||
writeln!(f, "]")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,618 +0,0 @@
|
||||
use crate::ffi::vec_znx;
|
||||
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
|
||||
use crate::{
|
||||
Backend, FFT64, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxScratch,
|
||||
VecZnxToMut, VecZnxToRef, ZnxSliceSize, bytes_of_vec_znx_big,
|
||||
};
|
||||
|
||||
pub trait VecZnxBigAlloc<B: Backend> {
|
||||
/// Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
|
||||
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned<B>;
|
||||
|
||||
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
|
||||
///
|
||||
/// Behavior: takes ownership of the backing array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of polynomials..
|
||||
/// * `size`: the number of polynomials per column.
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
|
||||
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B>;
|
||||
|
||||
// /// Returns a new [VecZnxBig] with the provided bytes array as backing array.
|
||||
// ///
|
||||
// /// Behavior: the backing array is only borrowed.
|
||||
// ///
|
||||
// /// # Arguments
|
||||
// ///
|
||||
// /// * `cols`: the number of polynomials..
|
||||
// /// * `size`: the number of polynomials per column.
|
||||
// /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
|
||||
// ///
|
||||
// /// # Panics
|
||||
// /// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
|
||||
// fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<B>;
|
||||
|
||||
/// Returns the minimum number of bytes necessary to allocate
|
||||
/// a new [VecZnxBig] through [VecZnxBig::from_bytes].
|
||||
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigOps<BACKEND: Backend> {
|
||||
/// Adds `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>,
|
||||
B: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Adds `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Adds `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_add_small<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>,
|
||||
B: VecZnxToRef;
|
||||
|
||||
/// Adds `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_add_small_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Subtracts `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>,
|
||||
B: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Subtracts `a` from `b` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Subtracts `b` from `a` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Subtracts `b` from `a` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_small_a<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Subtracts `a` from `res` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_a_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Subtracts `b` from `a` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_small_b<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>,
|
||||
B: VecZnxToRef;
|
||||
|
||||
/// Subtracts `res` from `a` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_b_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Negates `a` inplace.
|
||||
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<BACKEND>;
|
||||
|
||||
/// Normalizes `a` and stores the result on `b`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `basek`: normalization basis.
|
||||
/// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize].
|
||||
fn vec_znx_big_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxBigToRef<FFT64>;
|
||||
|
||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
|
||||
fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
|
||||
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<BACKEND>;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigScratch {
|
||||
/// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize].
|
||||
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxBigAlloc<B> for Module<B> {
|
||||
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned<B> {
|
||||
VecZnxBig::new(self, cols, size)
|
||||
}
|
||||
|
||||
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B> {
|
||||
VecZnxBig::new_from_bytes(self, cols, size, bytes)
|
||||
}
|
||||
|
||||
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
|
||||
bytes_of_vec_znx_big(self, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl VecZnxBigOps<FFT64> for Module<FFT64> {
|
||||
fn vec_znx_big_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_small_b<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_small_b_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_small_a<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_small_a_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_add_small<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_add_small_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_negate(
|
||||
self.ptr,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
//(Jay)Note: This is calling VezZnxOps::vec_znx_normalize_tmp_bytes and not VecZnxBigOps::vec_znx_big_normalize_tmp_bytes.
|
||||
// In the FFT backend the tmp sizes are same but will be different in the NTT backend
|
||||
// assert!(tmp_bytes.len() >= <Self as VecZnxOps<&mut [u8], & [u8]>>::vec_znx_normalize_tmp_bytes(&self));
|
||||
// assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.tmp_slice(<Self as VecZnxBigScratch>::vec_znx_big_normalize_tmp_bytes(
|
||||
&self,
|
||||
));
|
||||
unsafe {
|
||||
vec_znx::vec_znx_normalize_base2k(
|
||||
self.ptr,
|
||||
basek as u64,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
self.ptr,
|
||||
k,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
self.ptr,
|
||||
k,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxBigScratch for Module<B> {
|
||||
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
|
||||
<Self as VecZnxScratch>::vec_znx_normalize_tmp_bytes(self)
|
||||
}
|
||||
}
|
||||
@@ -1,190 +0,0 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::ffi::vec_znx_dft;
|
||||
use crate::znx_base::ZnxInfos;
|
||||
use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, alloc_aligned};
|
||||
use std::fmt;
|
||||
|
||||
pub struct VecZnxDft<D, B: Backend> {
|
||||
pub(crate) data: D,
|
||||
pub(crate) n: usize,
|
||||
pub(crate) cols: usize,
|
||||
pub(crate) size: usize,
|
||||
pub(crate) _phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<D, B: Backend> VecZnxDft<D, B> {
|
||||
pub fn into_big(self) -> VecZnxBig<D, B> {
|
||||
VecZnxBig::<D, B>::from_data(self.data, self.n, self.cols, self.size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> ZnxInfos for VecZnxDft<D, B> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> ZnxSliceSize for VecZnxDft<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> DataView for VecZnxDft<D, B> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> DataViewMut for VecZnxDft<D, B> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> ZnxView for VecZnxDft<D, FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnxDft<D, FFT64> {
|
||||
pub fn set_size(&mut self, size: usize) {
|
||||
assert!(size <= self.data.as_ref().len() / (self.n * self.cols()));
|
||||
self.size = size
|
||||
}
|
||||
|
||||
pub fn max_size(&mut self) -> usize {
|
||||
self.data.as_ref().len() / (self.n * self.cols)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn bytes_of_vec_znx_dft<B: Backend>(module: &Module<B>, cols: usize, size: usize) -> usize {
|
||||
unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols }
|
||||
}
|
||||
|
||||
impl<D: From<Vec<u8>>, B: Backend> VecZnxDft<D, B> {
|
||||
pub(crate) fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
|
||||
let data = alloc_aligned::<u8>(bytes_of_vec_znx_dft(module, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
cols,
|
||||
size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == bytes_of_vec_znx_dft(module, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
cols,
|
||||
size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type VecZnxDftOwned<B> = VecZnxDft<Vec<u8>, B>;
|
||||
|
||||
impl<D, B: Backend> VecZnxDft<D, B> {
|
||||
pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxDftToRef<B: Backend> {
|
||||
fn to_ref(&self) -> VecZnxDft<&[u8], B>;
|
||||
}
|
||||
|
||||
impl<D, B: Backend> VecZnxDftToRef<B> for VecZnxDft<D, B>
|
||||
where
|
||||
D: AsRef<[u8]>,
|
||||
B: Backend,
|
||||
{
|
||||
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
|
||||
VecZnxDft {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxDftToMut<B: Backend> {
|
||||
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B>;
|
||||
}
|
||||
|
||||
impl<D, B: Backend> VecZnxDftToMut<B> for VecZnxDft<D, B>
|
||||
where
|
||||
D: AsRef<[u8]> + AsMut<[u8]>,
|
||||
B: Backend,
|
||||
{
|
||||
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
|
||||
VecZnxDft {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> fmt::Display for VecZnxDft<D, FFT64> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"VecZnxDft(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {}:", col)?;
|
||||
for size in 0..self.size {
|
||||
let coeffs = self.at(col, size);
|
||||
write!(f, " Size {}: [", size)?;
|
||||
|
||||
let max_show = 100;
|
||||
let show_count = coeffs.len().min(max_show);
|
||||
|
||||
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "{}", coeff)?;
|
||||
}
|
||||
|
||||
if coeffs.len() > max_show {
|
||||
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
|
||||
}
|
||||
|
||||
writeln!(f, "]")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,736 +0,0 @@
|
||||
use crate::ffi::vec_znx;
|
||||
use crate::{
|
||||
Backend, Module, ScalarZnxToRef, Scratch, VecZnx, VecZnxOwned, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView,
|
||||
ZnxViewMut, ZnxZero,
|
||||
};
|
||||
use itertools::izip;
|
||||
use std::cmp::min;
|
||||
|
||||
pub trait VecZnxAlloc {
|
||||
/// Allocates a new [VecZnx].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of polynomials.
|
||||
/// * `size`: the number small polynomials per column.
|
||||
fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned;
|
||||
|
||||
/// Instantiates a new [VecZnx] from a slice of bytes.
|
||||
/// The returned [VecZnx] takes ownership of the slice of bytes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of polynomials.
|
||||
/// * `size`: the number small polynomials per column.
|
||||
///
|
||||
/// # Panic
|
||||
/// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx].
|
||||
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned;
|
||||
|
||||
/// Returns the number of bytes necessary to allocate
|
||||
/// a new [VecZnx] through [VecZnxOps::new_vec_znx_from_bytes]
|
||||
/// or [VecZnxOps::new_vec_znx_from_bytes_borrow].
|
||||
fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxOps {
|
||||
/// Normalizes the selected column of `a` and stores the result into the selected column of `res`.
|
||||
fn vec_znx_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Normalizes the selected column of `a`.
|
||||
fn vec_znx_normalize_inplace<A>(&self, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
|
||||
/// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`.
|
||||
fn vec_znx_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxToRef;
|
||||
|
||||
/// Adds the selected column of `a` to the selected column of `res` and writes the result on the selected column of `res`.
|
||||
fn vec_znx_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Adds the selected column of `a` on the selected column and limb of `res`.
|
||||
fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef;
|
||||
|
||||
/// Subtracts the selected column of `b` from the selected column of `a` and writes the result on the selected column of `res`.
|
||||
fn vec_znx_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxToRef;
|
||||
|
||||
/// Subtracts the selected column of `a` from the selected column of `res` inplace.
|
||||
///
|
||||
/// res[res_col] -= a[a_col]
|
||||
fn vec_znx_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Subtracts the selected column of `res` from the selected column of `a` and inplace mutates `res`
|
||||
///
|
||||
/// res[res_col] = a[a_col] - res[res_col]
|
||||
fn vec_znx_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Subtracts the selected column of `a` on the selected column and limb of `res`.
|
||||
fn vec_znx_sub_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef;
|
||||
|
||||
// Negates the selected column of `a` and stores the result in `res_col` of `res`.
|
||||
fn vec_znx_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Negates the selected column of `a`.
|
||||
fn vec_znx_negate_inplace<A>(&self, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
|
||||
/// Shifts by k bits all columns of `a`.
|
||||
/// A positive k applies a left shift, while a negative k applies a right shift.
|
||||
fn vec_znx_shift_inplace<A>(&self, basek: usize, k: i64, a: &mut A, scratch: &mut Scratch)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
|
||||
/// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`.
|
||||
fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Multiplies the selected column of `a` by X^k.
|
||||
fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
|
||||
/// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result in `res_col` column of `res`.
|
||||
fn vec_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Applies the automorphism X^i -> X^ik on the selected column of `a`.
|
||||
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
|
||||
/// Splits the selected columns of `b` into subrings and copies them them into the selected column of `res`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This method requires that all [VecZnx] of b have the same ring degree
|
||||
/// and that b.n() * b.len() <= a.n()
|
||||
fn vec_znx_split<R, A>(&self, res: &mut Vec<R>, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Merges the subrings of the selected column of `a` into the selected column of `res`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This method requires that all [VecZnx] of a have the same ring degree
|
||||
/// and that a.n() * a.len() <= b.n()
|
||||
fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: Vec<A>, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
|
||||
fn switch_degree<R, A>(&self, r: &mut R, col_b: usize, a: &A, col_a: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
|
||||
fn vec_znx_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxScratch {
|
||||
/// Returns the minimum number of bytes necessary for normalization.
|
||||
fn vec_znx_normalize_tmp_bytes(&self) -> usize;
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxAlloc for Module<B> {
|
||||
fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned {
|
||||
VecZnxOwned::new::<i64>(self.n(), cols, size)
|
||||
}
|
||||
|
||||
fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize {
|
||||
VecZnxOwned::bytes_of::<i64>(self.n(), cols, size)
|
||||
}
|
||||
|
||||
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned {
|
||||
VecZnxOwned::new_from_bytes::<i64>(self.n(), cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<BACKEND: Backend> VecZnxOps for Module<BACKEND> {
|
||||
fn vec_znx_shift_inplace<A>(&self, basek: usize, k: i64, a: &mut A, scratch: &mut Scratch)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
if k > 0 {
|
||||
a.to_mut().lsh(basek, k as usize, scratch);
|
||||
} else {
|
||||
a.to_mut().rsh(basek, k.abs() as usize, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let mut res_mut: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let a_ref: VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
let min_size: usize = min(res_mut.size(), a_ref.size());
|
||||
|
||||
(0..min_size).for_each(|j| {
|
||||
res_mut
|
||||
.at_mut(res_col, j)
|
||||
.copy_from_slice(a_ref.at(a_col, j));
|
||||
});
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
|
||||
fn vec_znx_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_normalize_tmp_bytes());
|
||||
|
||||
unsafe {
|
||||
vec_znx::vec_znx_normalize_base2k(
|
||||
self.ptr,
|
||||
basek as u64,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_normalize_inplace<A>(&self, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_normalize_tmp_bytes());
|
||||
|
||||
unsafe {
|
||||
vec_znx::vec_znx_normalize_base2k(
|
||||
self.ptr,
|
||||
basek as u64,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let a: crate::ScalarZnx<&[u8]> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, res_limb),
|
||||
1 as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, res_limb),
|
||||
1 as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_sub_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let a: crate::ScalarZnx<&[u8]> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, res_limb),
|
||||
1 as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, res_limb),
|
||||
1 as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_negate(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_negate_inplace<A>(&self, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_negate(
|
||||
self.ptr,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_rotate(
|
||||
self.ptr,
|
||||
k,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_rotate(
|
||||
self.ptr,
|
||||
k,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
self.ptr,
|
||||
k,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert!(
|
||||
k & 1 != 0,
|
||||
"invalid galois element: must be odd but is {}",
|
||||
k
|
||||
);
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
self.ptr,
|
||||
k,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_split<R, A>(&self, res: &mut Vec<R>, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
let (n_in, n_out) = (a.n(), res[0].to_mut().n());
|
||||
|
||||
let (mut buf, _) = scratch.tmp_vec_znx(self, 1, a.size());
|
||||
|
||||
debug_assert!(
|
||||
n_out < n_in,
|
||||
"invalid a: output ring degree should be smaller"
|
||||
);
|
||||
res[1..].iter_mut().for_each(|bi| {
|
||||
debug_assert_eq!(
|
||||
bi.to_mut().n(),
|
||||
n_out,
|
||||
"invalid input a: all VecZnx must have the same degree"
|
||||
)
|
||||
});
|
||||
|
||||
res.iter_mut().enumerate().for_each(|(i, bi)| {
|
||||
if i == 0 {
|
||||
self.switch_degree(bi, res_col, &a, a_col);
|
||||
self.vec_znx_rotate(-1, &mut buf, 0, &a, a_col);
|
||||
} else {
|
||||
self.switch_degree(bi, res_col, &mut buf, a_col);
|
||||
self.vec_znx_rotate_inplace(-1, &mut buf, a_col);
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: Vec<A>, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
let (n_in, n_out) = (res.n(), a[0].to_ref().n());
|
||||
|
||||
debug_assert!(
|
||||
n_out < n_in,
|
||||
"invalid a: output ring degree should be smaller"
|
||||
);
|
||||
a[1..].iter().for_each(|ai| {
|
||||
debug_assert_eq!(
|
||||
ai.to_ref().n(),
|
||||
n_out,
|
||||
"invalid input a: all VecZnx must have the same degree"
|
||||
)
|
||||
});
|
||||
|
||||
a.iter().enumerate().for_each(|(_, ai)| {
|
||||
self.switch_degree(&mut res, res_col, ai, a_col);
|
||||
self.vec_znx_rotate_inplace(-1, &mut res, res_col);
|
||||
});
|
||||
|
||||
self.vec_znx_rotate_inplace(a.len() as i64, &mut res, res_col);
|
||||
}
|
||||
|
||||
fn switch_degree<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
let (n_in, n_out) = (a.n(), res.n());
|
||||
let (gap_in, gap_out): (usize, usize);
|
||||
|
||||
if n_in > n_out {
|
||||
(gap_in, gap_out) = (n_in / n_out, 1)
|
||||
} else {
|
||||
(gap_in, gap_out) = (1, n_out / n_in);
|
||||
res.zero();
|
||||
}
|
||||
|
||||
let size: usize = min(a.size(), res.size());
|
||||
|
||||
(0..size).for_each(|i| {
|
||||
izip!(
|
||||
a.at(a_col, i).iter().step_by(gap_in),
|
||||
res.at_mut(res_col, i).iter_mut().step_by(gap_out)
|
||||
)
|
||||
.for_each(|(x_in, x_out)| *x_out = *x_in);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxScratch for Module<B> {
|
||||
fn vec_znx_normalize_tmp_bytes(&self) -> usize {
|
||||
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize }
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user