Added trace operation + test and renamed base2k to backend

Jean-Philippe Bossuat
2025-05-21 16:54:29 +02:00
parent acd81c40c2
commit 27a5395ce2
62 changed files with 1926 additions and 1620 deletions

333
backend/src/encoding.rs Normal file

@@ -0,0 +1,333 @@
use crate::ffi::znx::znx_zero_i64_ref;
use crate::znx_base::{ZnxView, ZnxViewMut};
use crate::{VecZnx, znx_base::ZnxInfos};
use itertools::izip;
use rug::{Assign, Float};
use std::cmp::min;
pub trait Encoding {
/// Encodes a vector of i64 into the receiver.
///
/// # Arguments
///
/// * `col_i`: index of the column (polynomial) in which to encode the data.
/// * `basek`: base-2 logarithm of the decomposition basis of the receiver (bits per limb).
/// * `k`: base-2 negative logarithm of the scaling of the data (the data is scaled by 2^{-k}).
/// * `data`: data to encode into the receiver.
/// * `log_max`: base-2 logarithm of the infinity norm of the input data.
fn encode_vec_i64(&mut self, col_i: usize, basek: usize, k: usize, data: &[i64], log_max: usize);
/// Encodes a single i64 into the receiver at the given index.
///
/// # Arguments
///
/// * `col_i`: index of the column (polynomial) in which to encode the data.
/// * `basek`: base-2 logarithm of the decomposition basis of the receiver (bits per limb).
/// * `k`: base-2 negative logarithm of the scaling of the data (the data is scaled by 2^{-k}).
/// * `i`: index of the coefficient on which to encode the data.
/// * `data`: data to encode into the receiver.
/// * `log_max`: base-2 logarithm of the infinity norm of the input data.
fn encode_coeff_i64(&mut self, col_i: usize, basek: usize, k: usize, i: usize, data: i64, log_max: usize);
}
pub trait Decoding {
/// Decodes a vector of i64 from the receiver.
///
/// # Arguments
///
/// * `col_i`: index of the column (polynomial) from which to decode the data.
/// * `basek`: base-2 logarithm of the decomposition basis of the receiver (bits per limb).
/// * `k`: base-2 negative logarithm of the scaling of the data (the data is scaled by 2^{-k}).
/// * `data`: buffer in which the decoded data is written.
fn decode_vec_i64(&self, col_i: usize, basek: usize, k: usize, data: &mut [i64]);
/// Decodes a vector of Float from the receiver.
///
/// # Arguments
///
/// * `col_i`: index of the column (polynomial) from which to decode the data.
/// * `basek`: base-2 logarithm of the decomposition basis of the receiver (bits per limb).
/// * `data`: buffer in which the decoded data is written.
fn decode_vec_float(&self, col_i: usize, basek: usize, data: &mut [Float]);
/// Decodes a single i64 from the receiver at the given index.
///
/// # Arguments
///
/// * `col_i`: index of the column (polynomial) from which to decode the data.
/// * `basek`: base-2 logarithm of the decomposition basis of the receiver (bits per limb).
/// * `k`: base-2 negative logarithm of the scaling of the data (the data is scaled by 2^{-k}).
/// * `i`: index of the coefficient to decode.
fn decode_coeff_i64(&self, col_i: usize, basek: usize, k: usize, i: usize) -> i64;
}
impl<D: AsMut<[u8]> + AsRef<[u8]>> Encoding for VecZnx<D> {
fn encode_vec_i64(&mut self, col_i: usize, basek: usize, k: usize, data: &[i64], log_max: usize) {
encode_vec_i64(self, col_i, basek, k, data, log_max)
}
fn encode_coeff_i64(&mut self, col_i: usize, basek: usize, k: usize, i: usize, value: i64, log_max: usize) {
encode_coeff_i64(self, col_i, basek, k, i, value, log_max)
}
}
impl<D: AsRef<[u8]>> Decoding for VecZnx<D> {
fn decode_vec_i64(&self, col_i: usize, basek: usize, k: usize, data: &mut [i64]) {
decode_vec_i64(self, col_i, basek, k, data)
}
fn decode_vec_float(&self, col_i: usize, basek: usize, data: &mut [Float]) {
decode_vec_float(self, col_i, basek, data)
}
fn decode_coeff_i64(&self, col_i: usize, basek: usize, k: usize, i: usize) -> i64 {
decode_coeff_i64(self, col_i, basek, k, i)
}
}
fn encode_vec_i64<D: AsMut<[u8]> + AsRef<[u8]>>(
a: &mut VecZnx<D>,
col_i: usize,
basek: usize,
k: usize,
data: &[i64],
log_max: usize,
) {
let size: usize = (k + basek - 1) / basek;
#[cfg(debug_assertions)]
{
assert!(
size <= a.size(),
"invalid argument k: (k + a.basek - 1)/a.basek={} > a.size()={}",
size,
a.size()
);
assert!(col_i < a.cols());
assert!(data.len() <= a.n())
}
let data_len: usize = data.len();
let k_rem: usize = basek - (k % basek);
// Zeroes the coefficients of column col_i
(0..a.size()).for_each(|i| unsafe {
znx_zero_i64_ref(a.n() as u64, a.at_mut_ptr(col_i, i));
});
// If 2^{log_max} * 2^{k_rem} < 2^{63}, we can simply copy
// the values onto the last limb.
// Otherwise we decompose the values in base 2^{basek}.
if log_max + k_rem < 63 || k_rem == basek {
a.at_mut(col_i, size - 1)[..data_len].copy_from_slice(&data[..data_len]);
} else {
let mask: i64 = (1 << basek) - 1;
let steps: usize = min(size, (log_max + basek - 1) / basek);
(size - steps..size)
.rev()
.enumerate()
.for_each(|(i, i_rev)| {
let shift: usize = i * basek;
izip!(a.at_mut(col_i, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask);
})
}
// Case where k % basek != 0.
if k_rem != basek {
let steps: usize = min(size, (log_max + basek - 1) / basek);
(size - steps..size).rev().for_each(|i| {
a.at_mut(col_i, i)[..data_len]
.iter_mut()
.for_each(|x| *x <<= k_rem);
})
}
}
fn decode_vec_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, k: usize, data: &mut [i64]) {
let size: usize = (k + basek - 1) / basek;
#[cfg(debug_assertions)]
{
assert!(
data.len() >= a.n(),
"invalid data: data.len()={} < a.n()={}",
data.len(),
a.n()
);
assert!(col_i < a.cols());
}
data[..a.n()].copy_from_slice(a.at(col_i, 0));
let rem: usize = basek - (k % basek);
(1..size).for_each(|i| {
if i == size - 1 && rem != basek {
let k_rem: usize = basek - rem;
izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
*y = (*y << k_rem) + (x >> rem);
});
} else {
izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
*y = (*y << basek) + x;
});
}
})
}
fn decode_vec_float<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, data: &mut [Float]) {
let size: usize = a.size();
#[cfg(debug_assertions)]
{
assert!(
data.len() >= a.n(),
"invalid data: data.len()={} < a.n()={}",
data.len(),
a.n()
);
assert!(col_i < a.cols());
}
let prec: u32 = (basek * size) as u32;
// 2^{basek}
let base = Float::with_val(prec, (1 << basek) as f64);
// y[i] = sum_j x[j][i] * 2^{-basek*(j+1)}
(0..size).for_each(|i| {
if i == 0 {
izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
y.assign(*x);
*y /= &base;
});
} else {
izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
*y += Float::with_val(prec, *x);
*y /= &base;
});
}
});
}
fn encode_coeff_i64<D: AsMut<[u8]> + AsRef<[u8]>>(
a: &mut VecZnx<D>,
col_i: usize,
basek: usize,
k: usize,
i: usize,
value: i64,
log_max: usize,
) {
let size: usize = (k + basek - 1) / basek;
#[cfg(debug_assertions)]
{
assert!(i < a.n());
assert!(
size <= a.size(),
"invalid argument k: (k + a.basek - 1)/a.basek={} > a.size()={}",
size,
a.size()
);
assert!(col_i < a.cols());
}
let k_rem: usize = basek - (k % basek);
(0..a.size()).for_each(|j| a.at_mut(col_i, j)[i] = 0);
// If 2^{log_max} * 2^{k_rem} < 2^{63}, we can simply copy
// the value onto the last limb.
// Otherwise we decompose the value in base 2^{basek}.
if log_max + k_rem < 63 || k_rem == basek {
a.at_mut(col_i, size - 1)[i] = value;
} else {
let mask: i64 = (1 << basek) - 1;
let steps: usize = min(size, (log_max + basek - 1) / basek);
(size - steps..size)
.rev()
.enumerate()
.for_each(|(j, j_rev)| {
a.at_mut(col_i, j_rev)[i] = (value >> (j * basek)) & mask;
})
}
// Case where k % basek != 0.
if k_rem != basek {
let steps: usize = min(size, (log_max + basek - 1) / basek);
(size - steps..size).rev().for_each(|j| {
a.at_mut(col_i, j)[i] <<= k_rem;
})
}
}
fn decode_coeff_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, k: usize, i: usize) -> i64 {
#[cfg(debug_assertions)]
{
assert!(i < a.n());
assert!(col_i < a.cols())
}
let size: usize = (k + basek - 1) / basek;
let rem: usize = basek - (k % basek);
// Mirrors decode_vec_i64, but reads a single coefficient through the same
// column/limb accessor used everywhere else in this file.
let mut res: i64 = a.at(col_i, 0)[i];
(1..size).for_each(|j| {
let x: i64 = a.at(col_i, j)[i];
if j == size - 1 && rem != basek {
let k_rem: usize = basek - rem;
res = (res << k_rem) + (x >> rem);
} else {
res = (res << basek) + x;
}
});
res
}
#[cfg(test)]
mod tests {
use crate::vec_znx_ops::*;
use crate::znx_base::*;
use crate::{Decoding, Encoding, FFT64, Module, VecZnx, znx_base::ZnxInfos};
use itertools::izip;
use sampling::source::Source;
#[test]
fn test_set_get_i64_lo_norm() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let basek: usize = 17;
let size: usize = 5;
let k: usize = size * basek - 5;
let mut a: VecZnx<_> = module.new_vec_znx(2, size);
let mut source: Source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
(0..a.cols()).for_each(|col_i| {
let mut have: Vec<i64> = vec![i64::default(); n];
have.iter_mut()
.for_each(|x| *x = (source.next_i64() << 56) >> 56);
a.encode_vec_i64(col_i, basek, k, &have, 10);
let mut want: Vec<i64> = vec![i64::default(); n];
a.decode_vec_i64(col_i, basek, k, &mut want);
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
});
}
#[test]
fn test_set_get_i64_hi_norm() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let basek: usize = 17;
let size: usize = 5;
let k: usize = size * basek - 5;
let mut a: VecZnx<_> = module.new_vec_znx(2, size);
let mut source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
(0..a.cols()).for_each(|col_i| {
let mut have: Vec<i64> = vec![i64::default(); n];
have.iter_mut().for_each(|x| *x = source.next_i64());
a.encode_vec_i64(col_i, basek, k, &have, 64);
let mut want = vec![i64::default(); n];
a.decode_vec_i64(col_i, basek, k, &mut want);
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
})
}
}
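
A note on the arithmetic above: encode_vec_i64/decode_vec_i64 split a value into base-2^basek limbs, with limb 0 holding the most significant digit, and rebuild it by shift-and-add. The following standalone sketch illustrates only that decomposition; its names are not part of this crate, and it assumes a non-negative value that fits in size*basek bits.

fn decompose_base2k(x: i64, basek: usize, size: usize) -> Vec<i64> {
    // limbs[0] is the most significant digit, matching the limb order read by
    // decode_vec_i64 above (limb 0 first, then shift-and-add for the rest).
    let mask: i64 = (1 << basek) - 1;
    (0..size).map(|j| (x >> ((size - 1 - j) * basek)) & mask).collect()
}

fn recompose_base2k(limbs: &[i64], basek: usize) -> i64 {
    // Horner-style reconstruction, as in decode_vec_i64.
    limbs.iter().fold(0i64, |acc, &l| (acc << basek) + l)
}

fn main() {
    let (basek, size) = (17, 3);
    let x: i64 = 123_456_789;
    let limbs = decompose_base2k(x, basek, size);
    assert_eq!(recompose_base2k(&limbs, basek), x);
}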

7
backend/src/ffi/cnv.rs Normal file

@@ -0,0 +1,7 @@
// Opaque handle; definition inferred from cnv_pvec_r_t below so the alias compiles.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct cnv_pvec_l_t {
_unused: [u8; 0],
}
pub type CNV_PVEC_L = cnv_pvec_l_t;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct cnv_pvec_r_t {
_unused: [u8; 0],
}
pub type CNV_PVEC_R = cnv_pvec_r_t;

8
backend/src/ffi/mod.rs Normal file

@@ -0,0 +1,8 @@
pub mod module;
pub mod reim;
pub mod svp;
pub mod vec_znx;
pub mod vec_znx_big;
pub mod vec_znx_dft;
pub mod vmp;
pub mod znx;

18
backend/src/ffi/module.rs Normal file

@@ -0,0 +1,18 @@
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct module_info_t {
_unused: [u8; 0],
}
pub type module_type_t = ::std::os::raw::c_uint;
pub use self::module_type_t as MODULE_TYPE;
pub type MODULE = module_info_t;
unsafe extern "C" {
pub unsafe fn new_module_info(N: u64, mode: MODULE_TYPE) -> *mut MODULE;
}
unsafe extern "C" {
pub unsafe fn delete_module_info(module_info: *mut MODULE);
}
unsafe extern "C" {
pub unsafe fn module_get_n(module: *const MODULE) -> u64;
}

172
backend/src/ffi/reim.rs Normal file

@@ -0,0 +1,172 @@
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_fft_precomp {
_unused: [u8; 0],
}
pub type REIM_FFT_PRECOMP = reim_fft_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_ifft_precomp {
_unused: [u8; 0],
}
pub type REIM_IFFT_PRECOMP = reim_ifft_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_mul_precomp {
_unused: [u8; 0],
}
pub type REIM_FFTVEC_MUL_PRECOMP = reim_mul_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_addmul_precomp {
_unused: [u8; 0],
}
pub type REIM_FFTVEC_ADDMUL_PRECOMP = reim_addmul_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_from_znx32_precomp {
_unused: [u8; 0],
}
pub type REIM_FROM_ZNX32_PRECOMP = reim_from_znx32_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_from_znx64_precomp {
_unused: [u8; 0],
}
pub type REIM_FROM_ZNX64_PRECOMP = reim_from_znx64_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_from_tnx32_precomp {
_unused: [u8; 0],
}
pub type REIM_FROM_TNX32_PRECOMP = reim_from_tnx32_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_to_tnx32_precomp {
_unused: [u8; 0],
}
pub type REIM_TO_TNX32_PRECOMP = reim_to_tnx32_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_to_tnx_precomp {
_unused: [u8; 0],
}
pub type REIM_TO_TNX_PRECOMP = reim_to_tnx_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_to_znx64_precomp {
_unused: [u8; 0],
}
pub type REIM_TO_ZNX64_PRECOMP = reim_to_znx64_precomp;
unsafe extern "C" {
pub unsafe fn new_reim_fft_precomp(m: u32, num_buffers: u32) -> *mut REIM_FFT_PRECOMP;
}
unsafe extern "C" {
pub unsafe fn reim_fft_precomp_get_buffer(tables: *const REIM_FFT_PRECOMP, buffer_index: u32) -> *mut f64;
}
unsafe extern "C" {
pub unsafe fn new_reim_fft_buffer(m: u32) -> *mut f64;
}
unsafe extern "C" {
pub unsafe fn delete_reim_fft_buffer(buffer: *mut f64);
}
unsafe extern "C" {
pub unsafe fn reim_fft(tables: *const REIM_FFT_PRECOMP, data: *mut f64);
}
unsafe extern "C" {
pub unsafe fn new_reim_ifft_precomp(m: u32, num_buffers: u32) -> *mut REIM_IFFT_PRECOMP;
}
unsafe extern "C" {
pub unsafe fn reim_ifft_precomp_get_buffer(tables: *const REIM_IFFT_PRECOMP, buffer_index: u32) -> *mut f64;
}
unsafe extern "C" {
pub unsafe fn reim_ifft(tables: *const REIM_IFFT_PRECOMP, data: *mut f64);
}
unsafe extern "C" {
pub unsafe fn new_reim_fftvec_mul_precomp(m: u32) -> *mut REIM_FFTVEC_MUL_PRECOMP;
}
unsafe extern "C" {
pub unsafe fn reim_fftvec_mul(tables: *const REIM_FFTVEC_MUL_PRECOMP, r: *mut f64, a: *const f64, b: *const f64);
}
unsafe extern "C" {
pub unsafe fn new_reim_fftvec_addmul_precomp(m: u32) -> *mut REIM_FFTVEC_ADDMUL_PRECOMP;
}
unsafe extern "C" {
pub unsafe fn reim_fftvec_addmul(tables: *const REIM_FFTVEC_ADDMUL_PRECOMP, r: *mut f64, a: *const f64, b: *const f64);
}
unsafe extern "C" {
pub unsafe fn new_reim_from_znx32_precomp(m: u32, log2bound: u32) -> *mut REIM_FROM_ZNX32_PRECOMP;
}
unsafe extern "C" {
pub unsafe fn reim_from_znx32(tables: *const REIM_FROM_ZNX32_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i32);
}
unsafe extern "C" {
pub unsafe fn reim_from_znx64(tables: *const REIM_FROM_ZNX64_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i64);
}
unsafe extern "C" {
pub unsafe fn new_reim_from_znx64_precomp(m: u32, maxbnd: u32) -> *mut REIM_FROM_ZNX64_PRECOMP;
}
unsafe extern "C" {
pub unsafe fn reim_from_znx64_simple(m: u32, log2bound: u32, r: *mut ::std::os::raw::c_void, a: *const i64);
}
unsafe extern "C" {
pub unsafe fn new_reim_from_tnx32_precomp(m: u32) -> *mut REIM_FROM_TNX32_PRECOMP;
}
unsafe extern "C" {
pub unsafe fn reim_from_tnx32(tables: *const REIM_FROM_TNX32_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i32);
}
unsafe extern "C" {
pub unsafe fn new_reim_to_tnx32_precomp(m: u32, divisor: f64, log2overhead: u32) -> *mut REIM_TO_TNX32_PRECOMP;
}
unsafe extern "C" {
pub unsafe fn reim_to_tnx32(tables: *const REIM_TO_TNX32_PRECOMP, r: *mut i32, a: *const ::std::os::raw::c_void);
}
unsafe extern "C" {
pub unsafe fn new_reim_to_tnx_precomp(m: u32, divisor: f64, log2overhead: u32) -> *mut REIM_TO_TNX_PRECOMP;
}
unsafe extern "C" {
pub unsafe fn reim_to_tnx(tables: *const REIM_TO_TNX_PRECOMP, r: *mut f64, a: *const f64);
}
unsafe extern "C" {
pub unsafe fn reim_to_tnx_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut f64, a: *const f64);
}
unsafe extern "C" {
pub unsafe fn new_reim_to_znx64_precomp(m: u32, divisor: f64, log2bound: u32) -> *mut REIM_TO_ZNX64_PRECOMP;
}
unsafe extern "C" {
pub unsafe fn reim_to_znx64(precomp: *const REIM_TO_ZNX64_PRECOMP, r: *mut i64, a: *const ::std::os::raw::c_void);
}
unsafe extern "C" {
pub unsafe fn reim_to_znx64_simple(m: u32, divisor: f64, log2bound: u32, r: *mut i64, a: *const ::std::os::raw::c_void);
}
unsafe extern "C" {
pub unsafe fn reim_fft_simple(m: u32, data: *mut ::std::os::raw::c_void);
}
unsafe extern "C" {
pub unsafe fn reim_ifft_simple(m: u32, data: *mut ::std::os::raw::c_void);
}
unsafe extern "C" {
pub unsafe fn reim_fftvec_mul_simple(
m: u32,
r: *mut ::std::os::raw::c_void,
a: *const ::std::os::raw::c_void,
b: *const ::std::os::raw::c_void,
);
}
unsafe extern "C" {
pub unsafe fn reim_fftvec_addmul_simple(
m: u32,
r: *mut ::std::os::raw::c_void,
a: *const ::std::os::raw::c_void,
b: *const ::std::os::raw::c_void,
);
}
unsafe extern "C" {
pub unsafe fn reim_from_znx32_simple(m: u32, log2bound: u32, r: *mut ::std::os::raw::c_void, x: *const i32);
}
unsafe extern "C" {
pub unsafe fn reim_from_tnx32_simple(m: u32, r: *mut ::std::os::raw::c_void, x: *const i32);
}
unsafe extern "C" {
pub unsafe fn reim_to_tnx32_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut i32, x: *const ::std::os::raw::c_void);
}

48
backend/src/ffi/svp.rs Normal file

@@ -0,0 +1,48 @@
use crate::ffi::module::MODULE;
use crate::ffi::vec_znx_dft::VEC_ZNX_DFT;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct svp_ppol_t {
_unused: [u8; 0],
}
pub type SVP_PPOL = svp_ppol_t;
unsafe extern "C" {
pub unsafe fn bytes_of_svp_ppol(module: *const MODULE) -> u64;
}
unsafe extern "C" {
pub unsafe fn new_svp_ppol(module: *const MODULE) -> *mut SVP_PPOL;
}
unsafe extern "C" {
pub unsafe fn delete_svp_ppol(res: *mut SVP_PPOL);
}
unsafe extern "C" {
pub unsafe fn svp_prepare(module: *const MODULE, ppol: *mut SVP_PPOL, pol: *const i64);
}
unsafe extern "C" {
pub unsafe fn svp_apply_dft(
module: *const MODULE,
res: *const VEC_ZNX_DFT,
res_size: u64,
ppol: *const SVP_PPOL,
a: *const i64,
a_size: u64,
a_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn svp_apply_dft_to_dft(
module: *const MODULE,
res: *const VEC_ZNX_DFT,
res_size: u64,
res_cols: u64,
ppol: *const SVP_PPOL,
a: *const VEC_ZNX_DFT,
a_size: u64,
a_cols: u64,
);
}

101
backend/src/ffi/vec_znx.rs Normal file

@@ -0,0 +1,101 @@
use crate::ffi::module::MODULE;
unsafe extern "C" {
pub unsafe fn vec_znx_add(
module: *const MODULE,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
b: *const i64,
b_size: u64,
b_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_automorphism(
module: *const MODULE,
p: i64,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_negate(
module: *const MODULE,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_rotate(
module: *const MODULE,
p: i64,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_sub(
module: *const MODULE,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
b: *const i64,
b_size: u64,
b_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_zero(module: *const MODULE, res: *mut i64, res_size: u64, res_sl: u64);
}
unsafe extern "C" {
pub unsafe fn vec_znx_copy(
module: *const MODULE,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_normalize_base2k(
module: *const MODULE,
base2k: u64,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
}
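
vec_znx_normalize_base2k above (and znx_normalize in ffi/znx.rs further down) presumably bring a vector back to a balanced base-2^k representation by propagating carries towards the most significant limb. The spqlios-arithmetic implementation is not reproduced here; the sketch below only illustrates the general idea on a single coefficient, with limb 0 as the most significant digit, and all names are illustrative.

// Reduce every digit to the centered range [-2^(basek-1), 2^(basek-1)) and
// push the carry into the next more significant limb. Any carry left after
// the last limb is dropped, i.e. the value is reduced mod 2^(basek * len).
fn normalize_base2k(limbs: &mut [i64], basek: usize) {
    let half: i64 = 1 << (basek - 1);
    let full: i64 = 1 << basek;
    let mut carry: i64 = 0;
    for l in limbs.iter_mut().rev() {
        let v = *l + carry;
        let mut digit = v & (full - 1);
        if digit >= half {
            digit -= full;
        }
        carry = (v - digit) >> basek;
        *l = digit;
    }
}

fn main() {
    let mut limbs = vec![0, 0, 37]; // un-normalized digits, basek = 4
    normalize_base2k(&mut limbs, 4);
    assert_eq!(limbs, vec![0, 2, 5]); // 2 * 16 + 5 = 37
}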

161
backend/src/ffi/vec_znx_big.rs Normal file

@@ -0,0 +1,161 @@
use crate::ffi::module::MODULE;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct vec_znx_big_t {
_unused: [u8; 0],
}
pub type VEC_ZNX_BIG = vec_znx_big_t;
unsafe extern "C" {
pub unsafe fn bytes_of_vec_znx_big(module: *const MODULE, size: u64) -> u64;
}
unsafe extern "C" {
pub unsafe fn new_vec_znx_big(module: *const MODULE, size: u64) -> *mut VEC_ZNX_BIG;
}
unsafe extern "C" {
pub unsafe fn delete_vec_znx_big(res: *mut VEC_ZNX_BIG);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_add(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
b: *const VEC_ZNX_BIG,
b_size: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_add_small(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
b: *const i64,
b_size: u64,
b_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_add_small2(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
b: *const i64,
b_size: u64,
b_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_sub(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
b: *const VEC_ZNX_BIG,
b_size: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_sub_small_b(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
b: *const i64,
b_size: u64,
b_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_sub_small_a(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
b: *const VEC_ZNX_BIG,
b_size: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_sub_small2(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
b: *const i64,
b_size: u64,
b_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_normalize_base2k(
module: *const MODULE,
log2_base2k: u64,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_range_normalize_base2k(
module: *const MODULE,
log2_base2k: u64,
res: *mut i64,
res_size: u64,
res_sl: u64,
a: *const VEC_ZNX_BIG,
a_range_begin: u64,
a_range_xend: u64,
a_range_step: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_automorphism(
module: *const MODULE,
p: i64,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_rotate(
module: *const MODULE,
p: i64,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
);
}

86
backend/src/ffi/vec_znx_dft.rs Normal file

@@ -0,0 +1,86 @@
use crate::ffi::module::MODULE;
use crate::ffi::vec_znx_big::VEC_ZNX_BIG;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct vec_znx_dft_t {
_unused: [u8; 0],
}
pub type VEC_ZNX_DFT = vec_znx_dft_t;
unsafe extern "C" {
pub unsafe fn bytes_of_vec_znx_dft(module: *const MODULE, size: u64) -> u64;
}
unsafe extern "C" {
pub unsafe fn new_vec_znx_dft(module: *const MODULE, size: u64) -> *mut VEC_ZNX_DFT;
}
unsafe extern "C" {
pub unsafe fn delete_vec_znx_dft(res: *mut VEC_ZNX_DFT);
}
unsafe extern "C" {
pub unsafe fn vec_dft_zero(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64);
}
unsafe extern "C" {
pub unsafe fn vec_dft_add(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64,
a: *const VEC_ZNX_DFT,
a_size: u64,
b: *const VEC_ZNX_DFT,
b_size: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_dft_sub(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64,
a: *const VEC_ZNX_DFT,
a_size: u64,
b: *const VEC_ZNX_DFT,
b_size: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_dft(module: *const MODULE, res: *mut VEC_ZNX_DFT, res_size: u64, a: *const i64, a_size: u64, a_sl: u64);
}
unsafe extern "C" {
pub unsafe fn vec_znx_idft(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a_dft: *const VEC_ZNX_DFT,
a_size: u64,
tmp: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_idft_tmp_bytes(module: *const MODULE) -> u64;
}
unsafe extern "C" {
pub unsafe fn vec_znx_idft_tmp_a(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a_dft: *mut VEC_ZNX_DFT,
a_size: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_dft_automorphism(
module: *const MODULE,
d: i64,
res_dft: *mut VEC_ZNX_DFT,
res_size: u64,
a_dft: *const VEC_ZNX_DFT,
a_size: u64,
tmp: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_dft_automorphism_tmp_bytes(module: *const MODULE) -> u64;
}

165
backend/src/ffi/vmp.rs Normal file

@@ -0,0 +1,165 @@
use crate::ffi::module::MODULE;
use crate::ffi::vec_znx_big::VEC_ZNX_BIG;
use crate::ffi::vec_znx_dft::VEC_ZNX_DFT;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct vmp_pmat_t {
_unused: [u8; 0],
}
// [rows][cols] = [#Decomposition][#Limbs]
pub type VMP_PMAT = vmp_pmat_t;
unsafe extern "C" {
pub unsafe fn bytes_of_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
}
unsafe extern "C" {
pub unsafe fn new_vmp_pmat(module: *const MODULE, nrows: u64, ncols: u64) -> *mut VMP_PMAT;
}
unsafe extern "C" {
pub unsafe fn delete_vmp_pmat(res: *mut VMP_PMAT);
}
unsafe extern "C" {
pub unsafe fn vmp_apply_dft(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
pmat: *const VMP_PMAT,
nrows: u64,
ncols: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vmp_apply_dft_add(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
pmat: *const VMP_PMAT,
nrows: u64,
ncols: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vmp_apply_dft_tmp_bytes(module: *const MODULE, res_size: u64, a_size: u64, nrows: u64, ncols: u64) -> u64;
}
unsafe extern "C" {
pub unsafe fn vmp_apply_dft_to_dft(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64,
a_dft: *const VEC_ZNX_DFT,
a_size: u64,
pmat: *const VMP_PMAT,
nrows: u64,
ncols: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vmp_apply_dft_to_dft_add(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64,
a_dft: *const VEC_ZNX_DFT,
a_size: u64,
pmat: *const VMP_PMAT,
nrows: u64,
ncols: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vmp_apply_dft_to_dft_tmp_bytes(
module: *const MODULE,
res_size: u64,
a_size: u64,
nrows: u64,
ncols: u64,
) -> u64;
}
unsafe extern "C" {
pub unsafe fn vmp_prepare_contiguous(
module: *const MODULE,
pmat: *mut VMP_PMAT,
mat: *const i64,
nrows: u64,
ncols: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vmp_prepare_dblptr(
module: *const MODULE,
pmat: *mut VMP_PMAT,
mat: *const *const i64,
nrows: u64,
ncols: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vmp_prepare_row(
module: *const MODULE,
pmat: *mut VMP_PMAT,
row: *const i64,
row_i: u64,
nrows: u64,
ncols: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vmp_prepare_row_dft(
module: *const MODULE,
pmat: *mut VMP_PMAT,
row: *const VEC_ZNX_DFT,
row_i: u64,
nrows: u64,
ncols: u64,
);
}
unsafe extern "C" {
pub unsafe fn vmp_extract_row_dft(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
pmat: *const VMP_PMAT,
row_i: u64,
nrows: u64,
ncols: u64,
);
}
unsafe extern "C" {
pub unsafe fn vmp_extract_row(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
pmat: *const VMP_PMAT,
row_i: u64,
nrows: u64,
ncols: u64,
);
}
unsafe extern "C" {
pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
}

76
backend/src/ffi/znx.rs Normal file

@@ -0,0 +1,76 @@
use crate::ffi::module::MODULE;
unsafe extern "C" {
pub unsafe fn znx_add_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_add_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_sub_i64_ref(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_sub_i64_avx(nn: u64, res: *mut i64, a: *const i64, b: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_negate_i64_ref(nn: u64, res: *mut i64, a: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_negate_i64_avx(nn: u64, res: *mut i64, a: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_copy_i64_ref(nn: u64, res: *mut i64, a: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_zero_i64_ref(nn: u64, res: *mut i64);
}
unsafe extern "C" {
pub unsafe fn rnx_divide_by_m_ref(nn: u64, m: f64, res: *mut f64, a: *const f64);
}
unsafe extern "C" {
pub unsafe fn rnx_divide_by_m_avx(nn: u64, m: f64, res: *mut f64, a: *const f64);
}
unsafe extern "C" {
pub unsafe fn rnx_rotate_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64);
}
unsafe extern "C" {
pub unsafe fn znx_rotate_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
}
unsafe extern "C" {
pub unsafe fn rnx_rotate_inplace_f64(nn: u64, p: i64, res: *mut f64);
}
unsafe extern "C" {
pub unsafe fn znx_rotate_inplace_i64(nn: u64, p: i64, res: *mut i64);
}
unsafe extern "C" {
pub unsafe fn rnx_automorphism_f64(nn: u64, p: i64, res: *mut f64, in_: *const f64);
}
unsafe extern "C" {
pub unsafe fn znx_automorphism_i64(nn: u64, p: i64, res: *mut i64, in_: *const i64);
}
unsafe extern "C" {
pub unsafe fn rnx_automorphism_inplace_f64(nn: u64, p: i64, res: *mut f64);
}
unsafe extern "C" {
pub unsafe fn znx_automorphism_inplace_i64(nn: u64, p: i64, res: *mut i64);
}
unsafe extern "C" {
pub unsafe fn rnx_mul_xp_minus_one(nn: u64, p: i64, res: *mut f64, in_: *const f64);
}
unsafe extern "C" {
pub unsafe fn znx_mul_xp_minus_one(nn: u64, p: i64, res: *mut i64, in_: *const i64);
}
unsafe extern "C" {
pub unsafe fn rnx_mul_xp_minus_one_inplace(nn: u64, p: i64, res: *mut f64);
}
unsafe extern "C" {
pub unsafe fn znx_normalize(nn: u64, base_k: u64, out: *mut i64, carry_out: *mut i64, in_: *const i64, carry_in: *const i64);
}
unsafe extern "C" {
pub unsafe fn znx_small_single_product(module: *const MODULE, res: *mut i64, a: *const i64, b: *const i64, tmp: *mut u8);
}
unsafe extern "C" {
pub unsafe fn znx_small_single_product_tmp_bytes(module: *const MODULE) -> u64;
}

251
backend/src/lib.rs Normal file

@@ -0,0 +1,251 @@
pub mod encoding;
#[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)]
// Other modules and exports
pub mod ffi;
pub mod mat_znx_dft;
pub mod mat_znx_dft_ops;
pub mod module;
pub mod sampling;
pub mod scalar_znx;
pub mod scalar_znx_dft;
pub mod scalar_znx_dft_ops;
pub mod stats;
pub mod vec_znx;
pub mod vec_znx_big;
pub mod vec_znx_big_ops;
pub mod vec_znx_dft;
pub mod vec_znx_dft_ops;
pub mod vec_znx_ops;
pub mod znx_base;
pub use encoding::*;
pub use mat_znx_dft::*;
pub use mat_znx_dft_ops::*;
pub use module::*;
pub use sampling::*;
pub use scalar_znx::*;
pub use scalar_znx_dft::*;
pub use scalar_znx_dft_ops::*;
pub use stats::*;
pub use vec_znx::*;
pub use vec_znx_big::*;
pub use vec_znx_big_ops::*;
pub use vec_znx_dft::*;
pub use vec_znx_dft_ops::*;
pub use vec_znx_ops::*;
pub use znx_base::*;
pub const GALOISGENERATOR: u64 = 5;
pub const DEFAULTALIGN: usize = 64;
fn is_aligned_custom<T>(ptr: *const T, align: usize) -> bool {
(ptr as usize) % align == 0
}
pub fn is_aligned<T>(ptr: *const T) -> bool {
is_aligned_custom(ptr, DEFAULTALIGN)
}
pub fn assert_alignement<T>(ptr: *const T) {
assert!(
is_aligned(ptr),
"invalid alignment: ensure the passed bytes have been allocated with [alloc_aligned] or [alloc_aligned_custom]"
)
}
pub fn cast<T, V>(data: &[T]) -> &[V] {
let ptr: *const V = data.as_ptr() as *const V;
// Length is derived from the byte size of the input slice.
let len: usize = (data.len() * size_of::<T>()) / size_of::<V>();
unsafe { std::slice::from_raw_parts(ptr, len) }
}
pub fn cast_mut<T, V>(data: &mut [T]) -> &mut [V] {
let ptr: *mut V = data.as_mut_ptr() as *mut V;
let len: usize = (data.len() * size_of::<T>()) / size_of::<V>();
unsafe { std::slice::from_raw_parts_mut(ptr, len) }
}
/// Allocates a block of bytes with a custom alignment.
/// The alignment must be a power of two and the size a multiple of the alignment.
/// Allocated memory is initialized to zero.
fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
assert!(
align.is_power_of_two(),
"Alignment must be a power of two but is {}",
align
);
assert_eq!(
(size * size_of::<u8>()) % align,
0,
"size={} must be a multiple of align={}",
size,
align
);
unsafe {
let layout: std::alloc::Layout = std::alloc::Layout::from_size_align(size, align).expect("Invalid alignment");
let ptr: *mut u8 = std::alloc::alloc(layout);
if ptr.is_null() {
panic!("Memory allocation failed");
}
assert!(
is_aligned_custom(ptr, align),
"Memory allocation at {:p} is not aligned to {} bytes",
ptr,
align
);
// Init allocated memory to zero
std::ptr::write_bytes(ptr, 0, size);
Vec::from_raw_parts(ptr, size, size)
}
}
/// Allocates a block of `size` elements of T with a custom alignment `align`.
/// size_of::<T>() * size must be a multiple of `align`.
pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
assert_eq!(
(size * size_of::<T>()) % align,
0,
"size={} must be a multiple of align={}",
size,
align
);
let mut vec_u8: Vec<u8> = alloc_aligned_custom_u8(size_of::<T>() * size, align);
let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T;
let len: usize = vec_u8.len() / size_of::<T>();
let cap: usize = vec_u8.capacity() / size_of::<T>();
std::mem::forget(vec_u8);
unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
/// Allocates an aligned vector of size equal to the smallest multiple
/// of [DEFAULTALIGN]/size_of::<T>() that is greater than or equal to `size`.
pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
let chunk: usize = DEFAULTALIGN / size_of::<T>();
alloc_aligned_custom::<T>(((size + chunk - 1) / chunk) * chunk, DEFAULTALIGN)
}
// Scratch implementation below
pub struct ScratchOwned(Vec<u8>);
impl ScratchOwned {
pub fn new(byte_count: usize) -> Self {
let data: Vec<u8> = alloc_aligned(byte_count);
Self(data)
}
pub fn borrow(&mut self) -> &mut Scratch {
Scratch::new(&mut self.0)
}
}
pub struct Scratch {
data: [u8],
}
impl Scratch {
fn new(data: &mut [u8]) -> &mut Self {
unsafe { &mut *(data as *mut [u8] as *mut Self) }
}
pub fn available(&self) -> usize {
let ptr: *const u8 = self.data.as_ptr();
let self_len: usize = self.data.len();
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
self_len.saturating_sub(aligned_offset)
}
fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) {
let ptr: *mut u8 = data.as_mut_ptr();
let self_len: usize = data.len();
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
let aligned_len: usize = self_len.saturating_sub(aligned_offset);
if let Some(rem_len) = aligned_len.checked_sub(take_len) {
unsafe {
let rem_ptr: *mut u8 = ptr.add(aligned_offset).add(take_len);
let rem_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len);
let take_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len);
return (take_slice, rem_slice);
}
} else {
panic!(
"Attempted to take {} bytes from scratch with only {} aligned bytes left",
take_len,
aligned_len,
);
}
}
pub fn tmp_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self) {
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, len * std::mem::size_of::<T>());
unsafe {
(
&mut *(std::ptr::slice_from_raw_parts_mut(take_slice.as_mut_ptr() as *mut T, len)),
Self::new(rem_slice),
)
}
}
pub fn tmp_scalar_znx<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx(module, cols));
(
ScalarZnx::from_data(take_slice, module.n(), cols),
Self::new(rem_slice),
)
}
pub fn tmp_scalar_znx_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnxDft<&mut [u8], B>, &mut Self) {
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx_dft(module, cols));
(
ScalarZnxDft::from_data(take_slice, module.n(), cols),
Self::new(rem_slice),
)
}
pub fn tmp_vec_znx_dft<B: Backend>(
&mut self,
module: &Module<B>,
cols: usize,
size: usize,
) -> (VecZnxDft<&mut [u8], B>, &mut Self) {
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_dft(module, cols, size));
(
VecZnxDft::from_data(take_slice, module.n(), cols, size),
Self::new(rem_slice),
)
}
pub fn tmp_vec_znx_big<B: Backend>(
&mut self,
module: &Module<B>,
cols: usize,
size: usize,
) -> (VecZnxBig<&mut [u8], B>, &mut Self) {
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_big(module, cols, size));
(
VecZnxBig::from_data(take_slice, module.n(), cols, size),
Self::new(rem_slice),
)
}
pub fn tmp_vec_znx<B: Backend>(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, module.bytes_of_vec_znx(cols, size));
(
VecZnx::from_data(take_slice, module.n(), cols, size),
Self::new(rem_slice),
)
}
}
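
The Scratch type defined above hands out DEFAULTALIGN-aligned sub-slices of one pre-allocated buffer, so temporaries inside hot loops never touch the allocator. Below is a usage sketch relying only on the methods defined in this file; the free function and the sizes are illustrative and not part of the crate.

fn scratch_usage_sketch() {
    // One up-front, zero-initialized, aligned allocation.
    let mut owned: ScratchOwned = ScratchOwned::new(1 << 16);
    let scratch: &mut Scratch = owned.borrow();
    let before: usize = scratch.available();

    // Carve two disjoint temporary buffers out of the same backing memory.
    let (tmp_f64, rest) = scratch.tmp_slice::<f64>(1024);
    tmp_f64.iter_mut().for_each(|x| *x = 0.5);

    let (tmp_i64, rest) = rest.tmp_slice::<i64>(512);
    tmp_i64.iter_mut().for_each(|x| *x = -1);

    // Whatever remains is at most the initial budget minus what was taken.
    assert!(rest.available() <= before - 1024 * 8 - 512 * 8);
}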

232
backend/src/mat_znx_dft.rs Normal file

@@ -0,0 +1,232 @@
use crate::znx_base::ZnxInfos;
use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned};
use std::marker::PhantomData;
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
/// Each column of the [MatZnxDft] can be seen as a collection of [VecZnxDft].
///
/// [MatZnxDft] is used to perform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft].
/// See the trait [MatZnxDftOps] for additional information.
pub struct MatZnxDft<D, B: Backend> {
data: D,
n: usize,
size: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
_phantom: PhantomData<B>,
}
impl<D, B: Backend> ZnxInfos for MatZnxDft<D, B> {
fn cols(&self) -> usize {
self.cols_in
}
fn rows(&self) -> usize {
self.rows
}
fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
self.size
}
}
impl<D> ZnxSliceSize for MatZnxDft<D, FFT64> {
fn sl(&self) -> usize {
self.n() * self.cols_out()
}
}
impl<D, B: Backend> DataView for MatZnxDft<D, B> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D, B: Backend> DataViewMut for MatZnxDft<D, B> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
impl<D: AsRef<[u8]>> ZnxView for MatZnxDft<D, FFT64> {
type Scalar = f64;
}
impl<D, B: Backend> MatZnxDft<D, B> {
pub fn cols_in(&self) -> usize {
self.cols_in
}
pub fn cols_out(&self) -> usize {
self.cols_out
}
}
impl<D: From<Vec<u8>>, B: Backend> MatZnxDft<D, B> {
pub(crate) fn bytes_of(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
unsafe {
crate::ffi::vmp::bytes_of_vmp_pmat(
module.ptr,
(rows * cols_in) as u64,
(size * cols_out) as u64,
) as usize
}
}
pub(crate) fn new(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
let data: Vec<u8> = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size));
Self {
data: data.into(),
n: module.n(),
size,
rows,
cols_in,
cols_out,
_phantom: PhantomData,
}
}
pub(crate) fn new_from_bytes(
module: &Module<B>,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: impl Into<Vec<u8>>,
) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::bytes_of(module, rows, cols_in, cols_out, size));
Self {
data: data.into(),
n: module.n(),
size,
rows,
cols_in,
cols_out,
_phantom: PhantomData,
}
}
}
impl<D: AsRef<[u8]>> MatZnxDft<D, FFT64> {
/// Returns a copy of the backend array at index (i, j) of the [MatZnxDft].
///
/// # Arguments
///
/// * `row`: row index (i).
/// * `col`: col index (j).
#[allow(dead_code)]
fn at(&self, row: usize, col: usize) -> Vec<f64> {
let n: usize = self.n();
let mut res: Vec<f64> = alloc_aligned(n);
if n < 8 {
res.copy_from_slice(&self.raw()[(row + col * self.rows()) * n..(row + col * self.rows() + 1) * n]);
} else {
(0..n >> 3).for_each(|blk| {
res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]);
});
}
res
}
#[allow(dead_code)]
fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] {
let nrows: usize = self.rows();
let nsize: usize = self.size();
if col == (nsize - 1) && (nsize & 1 == 1) {
&self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..]
} else {
&self.raw()[blk * nrows * nsize * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..]
}
}
}
pub type MatZnxDftOwned<B> = MatZnxDft<Vec<u8>, B>;
pub trait MatZnxDftToRef<B: Backend> {
fn to_ref(&self) -> MatZnxDft<&[u8], B>;
}
pub trait MatZnxDftToMut<B: Backend> {
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B>;
}
impl<B: Backend> MatZnxDftToMut<B> for MatZnxDft<Vec<u8>, B> {
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
MatZnxDft {
data: self.data.as_mut_slice(),
n: self.n,
rows: self.rows,
cols_in: self.cols_in,
cols_out: self.cols_out,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<Vec<u8>, B> {
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
MatZnxDft {
data: self.data.as_slice(),
n: self.n,
rows: self.rows,
cols_in: self.cols_in,
cols_out: self.cols_out,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> MatZnxDftToMut<B> for MatZnxDft<&mut [u8], B> {
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
MatZnxDft {
data: self.data,
n: self.n,
rows: self.rows,
cols_in: self.cols_in,
cols_out: self.cols_out,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<&mut [u8], B> {
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
MatZnxDft {
data: self.data,
n: self.n,
rows: self.rows,
cols_in: self.cols_in,
cols_out: self.cols_out,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<&[u8], B> {
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
MatZnxDft {
data: self.data,
n: self.n,
rows: self.rows,
cols_in: self.cols_in,
cols_out: self.cols_out,
size: self.size,
_phantom: PhantomData,
}
}
}
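
The *ToRef / *ToMut traits above exist so that a routine can accept owned, borrowed or scratch-backed matrices through a single generic bound. A minimal sketch of that pattern, assuming Backend, ZnxInfos and the traits defined above are in scope; print_mat_dims itself is illustrative and not part of the crate.

fn print_mat_dims<B: Backend, A: MatZnxDftToRef<B>>(a: &A) {
    // to_ref() normalizes any storage into a borrowed MatZnxDft<&[u8], B>.
    let a: MatZnxDft<&[u8], B> = a.to_ref();
    println!(
        "rows={} cols_in={} cols_out={} size={}",
        a.rows(),
        a.cols_in(),
        a.cols_out(),
        a.size()
    );
}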

488
backend/src/mat_znx_dft_ops.rs Normal file

@@ -0,0 +1,488 @@
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::ffi::vmp;
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
use crate::{
Backend, FFT64, MatZnxDft, MatZnxDftOwned, MatZnxDftToMut, MatZnxDftToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut,
VecZnxDftToRef,
};
pub trait MatZnxDftAlloc<B: Backend> {
/// Allocates a new [MatZnxDft] with the given number of rows and columns.
///
/// # Arguments
///
/// * `rows`: number of rows (one [VecZnxDft] per row and input column).
/// * `cols_in`: number of input columns.
/// * `cols_out`: number of output columns.
/// * `size`: number of limbs of each [VecZnxDft].
fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftOwned<B>;
fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
fn new_mat_znx_dft_from_bytes(
&self,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: Vec<u8>,
) -> MatZnxDftOwned<B>;
}
pub trait MatZnxDftScratch {
/// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply] and [MatZnxDftOps::vmp_apply_add].
fn vmp_apply_tmp_bytes(
&self,
res_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize;
}
/// This trait implements methods for vector matrix product,
/// that is, multiplying a [VecZnx] with a [MatZnxDft].
pub trait MatZnxDftOps<BACKEND: Backend> {
/// Prepares the row (`res_row`, `res_col_in`) of a [MatZnxDft] from a [VecZnxDft].
///
/// # Arguments
///
/// * `res`: the [MatZnxDft] on which the values are encoded.
/// * `res_row`: index of the row to prepare.
/// * `res_col_in`: input column of the row to prepare.
/// * `a`: the [VecZnxDft] to encode into the [MatZnxDft].
fn vmp_prepare_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
where
R: MatZnxDftToMut<BACKEND>,
A: VecZnxDftToRef<BACKEND>;
/// Extracts the row (`a_row`, `a_col_in`) of a [MatZnxDft] into a [VecZnxDft].
///
/// # Arguments
///
/// * `res`: the [VecZnxDft] into which the row of the [MatZnxDft] is extracted.
/// * `a`: the [MatZnxDft] from which the values are read.
/// * `a_row`: index of the row to extract.
/// * `a_col_in`: input column of the row to extract.
fn vmp_extract_row<R, A>(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize)
where
R: VecZnxDftToMut<BACKEND>,
A: MatZnxDftToRef<BACKEND>;
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
/// The scratch space required is given by [MatZnxDftScratch::vmp_apply_tmp_bytes].
///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
///
/// As such, given an input [VecZnxDft] with `i` limbs and a [MatZnxDft] with `i` rows and
/// `j` limbs, the output is a [VecZnxDft] with `j` limbs.
///
/// If there is a mismatch between the dimensions the largest valid ones are used.
///
/// ```text
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
/// |h i j|
/// |k l m|
/// ```
/// where each element is a [VecZnxDft].
///
/// # Arguments
///
/// * `res`: the output of the vector matrix product, as a [VecZnxDft].
/// * `a`: the left operand [VecZnxDft] of the vector matrix product.
/// * `b`: the right operand [MatZnxDft] of the vector matrix product.
/// * `scratch`: scratch space, whose size can be obtained with [MatZnxDftScratch::vmp_apply_tmp_bytes].
fn vmp_apply<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
where
R: VecZnxDftToMut<BACKEND>,
A: VecZnxDftToRef<BACKEND>,
B: MatZnxDftToRef<BACKEND>;
/// Same as [MatZnxDftOps::vmp_apply], except that the result is added to `res` instead of overwriting it.
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
where
R: VecZnxDftToMut<BACKEND>,
A: VecZnxDftToRef<BACKEND>,
B: MatZnxDftToRef<BACKEND>;
}
impl<B: Backend> MatZnxDftAlloc<B> for Module<B> {
fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
MatZnxDftOwned::bytes_of(self, rows, cols_in, cols_out, size)
}
fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftOwned<B> {
MatZnxDftOwned::new(self, rows, cols_in, cols_out, size)
}
fn new_mat_znx_dft_from_bytes(
&self,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: Vec<u8>,
) -> MatZnxDftOwned<B> {
MatZnxDftOwned::new_from_bytes(self, rows, cols_in, cols_out, size, bytes)
}
}
impl<BACKEND: Backend> MatZnxDftScratch for Module<BACKEND> {
fn vmp_apply_tmp_bytes(
&self,
res_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize {
unsafe {
vmp::vmp_apply_dft_to_dft_tmp_bytes(
self.ptr,
(res_size * b_cols_out) as u64,
(a_size * b_cols_in) as u64,
(b_rows * b_cols_in) as u64,
(b_size * b_cols_out) as u64,
) as usize
}
}
}
impl MatZnxDftOps<FFT64> for Module<FFT64> {
fn vmp_prepare_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
where
R: MatZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
{
let mut res: MatZnxDft<&mut [u8], _> = res.to_mut();
let a: VecZnxDft<&[u8], _> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), self.n());
assert_eq!(a.n(), self.n());
assert_eq!(
a.cols(),
res.cols_out(),
"a.cols(): {} != res.cols_out(): {}",
a.cols(),
res.cols_out()
);
assert!(
res_row < res.rows(),
"res_row: {} >= res.rows(): {}",
res_row,
res.rows()
);
assert!(
res_col_in < res.cols_in(),
"res_col_in: {} >= res.cols_in(): {}",
res_col_in,
res.cols_in()
);
assert_eq!(
res.size(),
a.size(),
"res.size(): {} != a.size(): {}",
res.size(),
a.size()
);
}
unsafe {
vmp::vmp_prepare_row_dft(
self.ptr,
res.as_mut_ptr() as *mut vmp::vmp_pmat_t,
a.as_ptr() as *const vec_znx_dft_t,
(res_row * res.cols_in() + res_col_in) as u64,
(res.rows() * res.cols_in()) as u64,
(res.size() * res.cols_out()) as u64,
);
}
}
fn vmp_extract_row<R, A>(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize)
where
R: VecZnxDftToMut<FFT64>,
A: MatZnxDftToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
let a: MatZnxDft<&[u8], _> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), self.n());
assert_eq!(a.n(), self.n());
assert_eq!(
res.cols(),
a.cols_out(),
"res.cols(): {} != a.cols_out(): {}",
res.cols(),
a.cols_out()
);
assert!(
a_row < a.rows(),
"a_row: {} >= a.rows(): {}",
a_row,
a.rows()
);
assert!(
a_col_in < a.cols_in(),
"a_col_in: {} >= a.cols_in(): {}",
a_col_in,
a.cols_in()
);
assert_eq!(
res.size(),
a.size(),
"res.size(): {} != a.size(): {}",
res.size(),
a.size()
);
}
unsafe {
vmp::vmp_extract_row_dft(
self.ptr,
res.as_mut_ptr() as *mut vec_znx_dft_t,
a.as_ptr() as *const vmp::vmp_pmat_t,
(a_row * a.cols_in() + a_col_in) as u64,
(a.rows() * a.cols_in()) as u64,
(a.size() * a.cols_out()) as u64,
);
}
}
fn vmp_apply<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
B: MatZnxDftToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
let a: VecZnxDft<&[u8], _> = a.to_ref();
let b: MatZnxDft<&[u8], _> = b.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(a.n(), self.n());
assert_eq!(
res.cols(),
b.cols_out(),
"res.cols(): {} != b.cols_out: {}",
res.cols(),
b.cols_out()
);
assert_eq!(
a.cols(),
b.cols_in(),
"a.cols(): {} != b.cols_in: {}",
a.cols(),
b.cols_in()
);
}
let (tmp_bytes, _) = scratch.tmp_slice(self.vmp_apply_tmp_bytes(
res.size(),
a.size(),
b.rows(),
b.cols_in(),
b.cols_out(),
b.size(),
));
unsafe {
vmp::vmp_apply_dft_to_dft(
self.ptr,
res.as_mut_ptr() as *mut vec_znx_dft_t,
(res.size() * res.cols()) as u64,
a.as_ptr() as *const vec_znx_dft_t,
(a.size() * a.cols()) as u64,
b.as_ptr() as *const vmp::vmp_pmat_t,
(b.rows() * b.cols_in()) as u64,
(b.size() * b.cols_out()) as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
B: MatZnxDftToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
let a: VecZnxDft<&[u8], _> = a.to_ref();
let b: MatZnxDft<&[u8], _> = b.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(a.n(), self.n());
assert_eq!(
res.cols(),
b.cols_out(),
"res.cols(): {} != b.cols_out: {}",
res.cols(),
b.cols_out()
);
assert_eq!(
a.cols(),
b.cols_in(),
"a.cols(): {} != b.cols_in: {}",
a.cols(),
b.cols_in()
);
}
let (tmp_bytes, _) = scratch.tmp_slice(self.vmp_apply_tmp_bytes(
res.size(),
a.size(),
b.rows(),
b.cols_in(),
b.cols_out(),
b.size(),
));
unsafe {
vmp::vmp_apply_dft_to_dft_add(
self.ptr,
res.as_mut_ptr() as *mut vec_znx_dft_t,
(res.size() * res.cols()) as u64,
a.as_ptr() as *const vec_znx_dft_t,
(a.size() * a.cols()) as u64,
b.as_ptr() as *const vmp::vmp_pmat_t,
(b.rows() * b.cols_in()) as u64,
(b.size() * b.cols_out()) as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
}
#[cfg(test)]
mod tests {
use crate::{
Decoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig,
VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut,
};
use sampling::source::Source;
use super::{MatZnxDftAlloc, MatZnxDftScratch};
#[test]
fn vmp_prepare_row() {
let module: Module<FFT64> = Module::<FFT64>::new(16);
let basek: usize = 8;
let mat_rows: usize = 4;
let mat_cols_in: usize = 2;
let mat_cols_out: usize = 2;
let mat_size: usize = 5;
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
let mut b_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
let mut mat: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
for col_in in 0..mat_cols_in {
for row_i in 0..mat_rows {
let mut source: Source = Source::new([0u8; 32]);
(0..mat_cols_out).for_each(|col_out| {
a.fill_uniform(basek, col_out, mat_size, &mut source);
module.vec_znx_dft(&mut a_dft, col_out, &a, col_out);
});
module.vmp_prepare_row(&mut mat, row_i, col_in, &a_dft);
module.vmp_extract_row(&mut b_dft, &mat, row_i, col_in);
assert_eq!(a_dft.raw(), b_dft.raw());
}
}
}
#[test]
fn vmp_apply() {
let log_n: i32 = 5;
let n: usize = 1 << log_n;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let basek: usize = 15;
let a_size: usize = 5;
let mat_size: usize = 6;
let res_size: usize = 5;
[1, 2].iter().for_each(|in_cols| {
[1, 2].iter().for_each(|out_cols| {
let a_cols: usize = *in_cols;
let res_cols: usize = *out_cols;
let mat_rows: usize = a_size;
let mat_cols_in: usize = a_cols;
let mat_cols_out: usize = res_cols;
let res_cols: usize = mat_cols_out;
let mut scratch: ScratchOwned = ScratchOwned::new(
module.vmp_apply_tmp_bytes(
res_size,
a_size,
mat_rows,
mat_cols_in,
mat_cols_out,
mat_size,
) | module.vec_znx_big_normalize_tmp_bytes(),
);
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
(0..a_cols).for_each(|i| {
a.at_mut(i, 2)[i + 1] = 1;
});
let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
let mut c_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
let mut c_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
// Constructs a [MatZnxDft] that performs cyclic rotations on each submatrix.
(0..a.size()).for_each(|row_i| {
(0..mat_cols_in).for_each(|col_in_i| {
(0..mat_cols_out).for_each(|col_out_i| {
let idx = 1 + col_in_i * mat_cols_out + col_out_i;
tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx}
module.vec_znx_dft(&mut c_dft, col_out_i, &tmp, col_out_i);
tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64;
});
module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
});
});
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, a_size);
(0..a_cols).for_each(|i| {
module.vec_znx_dft(&mut a_dft, i, &a, i);
});
module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow());
let mut res_have_vi64: Vec<i64> = vec![i64::default(); n];
let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, res_size);
(0..mat_cols_out).for_each(|i| {
module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow());
});
(0..mat_cols_out).for_each(|col_i| {
let mut res_want_vi64: Vec<i64> = vec![i64::default(); n];
(0..a_cols).for_each(|i| {
res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1;
});
res_have.decode_vec_i64(col_i, basek, basek * 3, &mut res_have_vi64);
assert_eq!(res_have_vi64, res_want_vi64);
});
});
});
}
}
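
The row-times-limb semantics documented for vmp_apply can be pictured with plain integers standing in for the DFT polynomials: row i of the matrix is weighted by limb i of the input and the weighted rows are summed. The standalone sketch below mirrors that indexing only; there is no DFT and no spqlios call, and all names are illustrative.

// Each matrix entry stands in for one VecZnxDft limb; the real routine works
// on polynomials in the DFT domain, this only mirrors the sum-of-products.
fn vmp_apply_sketch(a: &[i64], mat: &[Vec<i64>]) -> Vec<i64> {
    let cols = mat[0].len();
    let mut res = vec![0i64; cols];
    // zip truncates to the shorter operand, matching "the largest valid
    // dimensions are used" in the documentation above.
    for (ai, row) in a.iter().zip(mat.iter()) {
        for (r, m) in res.iter_mut().zip(row.iter()) {
            *r += ai * m;
        }
    }
    res
}

fn main() {
    // |1 2 3| x |1 0| = 1*|1 0| + 2*|0 1| + 3*|1 1| = |4 5|
    //           |0 1|
    //           |1 1|
    let a = vec![1, 2, 3];
    let mat = vec![vec![1, 0], vec![0, 1], vec![1, 1]];
    assert_eq!(vmp_apply_sketch(&a, &mat), vec![4, 5]);
}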

104
backend/src/module.rs Normal file

@@ -0,0 +1,104 @@
use crate::GALOISGENERATOR;
use crate::ffi::module::{MODULE, delete_module_info, module_info_t, new_module_info};
use std::marker::PhantomData;
#[derive(Copy, Clone)]
#[repr(u8)]
pub enum BACKEND {
FFT64,
NTT120,
}
pub trait Backend {
const KIND: BACKEND;
fn module_type() -> u32;
}
pub struct FFT64;
pub struct NTT120;
impl Backend for FFT64 {
const KIND: BACKEND = BACKEND::FFT64;
fn module_type() -> u32 {
0
}
}
impl Backend for NTT120 {
const KIND: BACKEND = BACKEND::NTT120;
fn module_type() -> u32 {
1
}
}
pub struct Module<B: Backend> {
pub ptr: *mut MODULE,
n: usize,
_marker: PhantomData<B>,
}
impl<B: Backend> Module<B> {
// Instantiates a new module.
pub fn new(n: usize) -> Self {
unsafe {
let m: *mut module_info_t = new_module_info(n as u64, B::module_type());
if m.is_null() {
panic!("Failed to create module.");
}
Self {
ptr: m,
n: n,
_marker: PhantomData,
}
}
}
pub fn n(&self) -> usize {
self.n
}
pub fn log_n(&self) -> usize {
(usize::BITS - (self.n() - 1).leading_zeros()) as _
}
pub fn cyclotomic_order(&self) -> u64 {
(self.n() << 1) as _
}
// Returns GALOISGENERATOR^|generator| * sign(generator)
pub fn galois_element(&self, generator: i64) -> i64 {
if generator == 0 {
return 1;
}
((mod_exp_u64(GALOISGENERATOR, generator.abs() as usize) & (self.cyclotomic_order() - 1)) as i64) * generator.signum()
}
// Returns gen^-1
pub fn galois_element_inv(&self, gal_el: i64) -> i64 {
if gal_el == 0 {
panic!("cannot invert 0")
}
((mod_exp_u64(gal_el.abs() as u64, (self.cyclotomic_order() - 1) as usize) & (self.cyclotomic_order() - 1)) as i64)
* gal_el.signum()
}
}
impl<B: Backend> Drop for Module<B> {
fn drop(&mut self) {
unsafe { delete_module_info(self.ptr) }
}
}
fn mod_exp_u64(x: u64, e: usize) -> u64 {
let mut y: u64 = 1;
let mut x_pow: u64 = x;
let mut exp = e;
while exp > 0 {
if exp & 1 == 1 {
y = y.wrapping_mul(x_pow);
}
x_pow = x_pow.wrapping_mul(x_pow);
exp >>= 1;
}
y
}
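
mod_exp_u64 above deliberately skips modular reduction: galois_element reduces afterwards with the mask cyclotomic_order() - 1, which is exact because 2N is a power of two dividing 2^64. A small standalone check of that square-and-multiply loop against a naive modular product; the names below are illustrative.

fn mod_exp_u64_sketch(x: u64, e: usize) -> u64 {
    // Same square-and-multiply loop as mod_exp_u64; the result is only
    // meaningful after reduction modulo a power of two.
    let (mut y, mut x_pow, mut exp) = (1u64, x, e);
    while exp > 0 {
        if exp & 1 == 1 {
            y = y.wrapping_mul(x_pow);
        }
        x_pow = x_pow.wrapping_mul(x_pow);
        exp >>= 1;
    }
    y
}

fn main() {
    let two_n: u64 = 32; // cyclotomic order of Z[X]/(X^16 + 1)
    let g: u64 = 5; // GALOISGENERATOR
    let naive = (0..7).fold(1u64, |acc, _| (acc * g) % two_n);
    assert_eq!(mod_exp_u64_sketch(g, 7) & (two_n - 1), naive);
}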

365
backend/src/sampling.rs Normal file

@@ -0,0 +1,365 @@
use crate::znx_base::ZnxViewMut;
use crate::{FFT64, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxToMut};
use rand_distr::{Distribution, Normal};
use sampling::source::Source;
pub trait FillUniform {
/// Fills the first `size` limbs of column `col_i` with uniform values in \[-2^{basek-1}, 2^{basek-1}\)
fn fill_uniform(&mut self, basek: usize, col_i: usize, size: usize, source: &mut Source);
}
pub trait FillDistF64 {
/// Fills the receiver with a vector sampled according to the provided distribution, scaled by 2^{-k} and bounded to \[-bound, bound\].
fn fill_dist_f64<D: Distribution<f64>>(
&mut self,
basek: usize,
col_i: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
);
}
pub trait AddDistF64 {
/// Adds vector sampled according to the provided distribution, scaled by 2^{-k} and bounded to \[-bound, bound\].
fn add_dist_f64<D: Distribution<f64>>(
&mut self,
basek: usize,
col_i: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
);
}
pub trait FillNormal {
/// Fills the receiver with a discrete normal vector scaled by 2^{-k} with the provided standard deviation and bounded to \[-bound, bound\].
fn fill_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64);
}
pub trait AddNormal {
/// Adds a discrete normal vector scaled by 2^{-k} with the provided standard deviation and bounded to \[-bound, bound\].
fn add_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64);
}
impl<T> FillUniform for VecZnx<T>
where
VecZnx<T>: VecZnxToMut,
{
fn fill_uniform(&mut self, basek: usize, col_i: usize, size: usize, source: &mut Source) {
let mut a: VecZnx<&mut [u8]> = self.to_mut();
let base2k: u64 = 1 << basek;
let mask: u64 = base2k - 1;
let base2k_half: i64 = (base2k >> 1) as i64;
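        // `next_u64n(base2k, mask)` yields values uniform in [0, 2^basek); subtracting
        // 2^{basek-1} re-centers them to [-2^{basek-1}, 2^{basek-1}).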
(0..size).for_each(|j| {
a.at_mut(col_i, j)
.iter_mut()
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
})
}
}
impl<T> FillDistF64 for VecZnx<T>
where
VecZnx<T>: VecZnxToMut,
{
fn fill_dist_f64<D: Distribution<f64>>(
&mut self,
basek: usize,
col_i: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) {
let mut a: VecZnx<&mut [u8]> = self.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = (k + basek - 1) / basek - 1;
let basek_rem: usize = (limb + 1) * basek - k;
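        // `limb` is the index of the limb holding bit-precision k; when k is not a
        // multiple of basek, the sampled value is pre-shifted by `basek_rem` so that its
        // overall scaling is exactly 2^{-k}.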
if basek_rem != 0 {
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a = (dist_f64.round() as i64) << basek_rem;
});
} else {
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a = dist_f64.round() as i64
});
}
}
}
impl<T> AddDistF64 for VecZnx<T>
where
VecZnx<T>: VecZnxToMut,
{
fn add_dist_f64<D: Distribution<f64>>(
&mut self,
basek: usize,
col_i: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) {
let mut a: VecZnx<&mut [u8]> = self.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = (k + basek - 1) / basek - 1;
let basek_rem: usize = (limb + 1) * basek - k;
if basek_rem != 0 {
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a += (dist_f64.round() as i64) << basek_rem;
});
} else {
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a += dist_f64.round() as i64
});
}
}
}
impl<T> FillNormal for VecZnx<T>
where
VecZnx<T>: VecZnxToMut,
{
fn fill_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64) {
self.fill_dist_f64(
basek,
col_i,
k,
source,
Normal::new(0.0, sigma).unwrap(),
bound,
);
}
}
impl<T> AddNormal for VecZnx<T>
where
VecZnx<T>: VecZnxToMut,
{
fn add_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64) {
self.add_dist_f64(
basek,
col_i,
k,
source,
Normal::new(0.0, sigma).unwrap(),
bound,
);
}
}
impl<T> FillDistF64 for VecZnxBig<T, FFT64>
where
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
{
fn fill_dist_f64<D: Distribution<f64>>(
&mut self,
basek: usize,
col_i: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) {
let mut a: VecZnxBig<&mut [u8], FFT64> = self.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = (k + basek - 1) / basek - 1;
let basek_rem: usize = (limb + 1) * basek - k;
if basek_rem != 0 {
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a = (dist_f64.round() as i64) << basek_rem;
});
} else {
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a = dist_f64.round() as i64
});
}
}
}
impl<T> AddDistF64 for VecZnxBig<T, FFT64>
where
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
{
fn add_dist_f64<D: Distribution<f64>>(
&mut self,
basek: usize,
col_i: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) {
let mut a: VecZnxBig<&mut [u8], FFT64> = self.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = (k + basek - 1) / basek - 1;
let basek_rem: usize = (limb + 1) * basek - k;
if basek_rem != 0 {
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a += (dist_f64.round() as i64) << basek_rem;
});
} else {
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a += dist_f64.round() as i64
});
}
}
}
impl<T> FillNormal for VecZnxBig<T, FFT64>
where
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
{
fn fill_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64) {
self.fill_dist_f64(
basek,
col_i,
k,
source,
Normal::new(0.0, sigma).unwrap(),
bound,
);
}
}
impl<T> AddNormal for VecZnxBig<T, FFT64>
where
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
{
fn add_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64) {
self.add_dist_f64(
basek,
col_i,
k,
source,
Normal::new(0.0, sigma).unwrap(),
bound,
);
}
}
#[cfg(test)]
mod tests {
use super::{AddNormal, FillUniform};
use crate::vec_znx_ops::*;
use crate::znx_base::*;
use crate::{FFT64, Module, Stats, VecZnx};
use sampling::source::Source;
#[test]
fn vec_znx_fill_uniform() {
let n: usize = 4096;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let basek: usize = 17;
let size: usize = 5;
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let zero: Vec<i64> = vec![0; n];
let one_12_sqrt: f64 = 0.28867513459481287;
(0..cols).for_each(|col_i| {
let mut a: VecZnx<_> = module.new_vec_znx(cols, size);
a.fill_uniform(basek, col_i, size, &mut source);
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..size).for_each(|limb_i| {
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(col_i, basek);
assert!(
(std - one_12_sqrt).abs() < 0.01,
"std={} ~!= {}",
std,
one_12_sqrt
);
}
})
});
}
#[test]
fn vec_znx_add_normal() {
let n: usize = 4096;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let basek: usize = 17;
let k: usize = 2 * 17;
let size: usize = 5;
let sigma: f64 = 3.2;
let bound: f64 = 6.0 * sigma;
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let zero: Vec<i64> = vec![0; n];
let k_f64: f64 = (1u64 << k as u64) as f64;
(0..cols).for_each(|col_i| {
let mut a: VecZnx<_> = module.new_vec_znx(cols, size);
a.add_normal(basek, col_i, k, &mut source, sigma, bound);
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..size).for_each(|limb_i| {
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(col_i, basek) * k_f64;
assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma);
}
})
});
}
}

306
backend/src/scalar_znx.rs Normal file
View File

@@ -0,0 +1,306 @@
use crate::ffi::vec_znx;
use crate::znx_base::ZnxInfos;
use crate::{
Backend, DataView, DataViewMut, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned,
};
use rand::seq::SliceRandom;
use rand_core::RngCore;
use rand_distr::{Distribution, weighted::WeightedIndex};
use sampling::source::Source;
pub struct ScalarZnx<D> {
pub(crate) data: D,
pub(crate) n: usize,
pub(crate) cols: usize,
}
impl<D> ZnxInfos for ScalarZnx<D> {
fn cols(&self) -> usize {
self.cols
}
fn rows(&self) -> usize {
1
}
fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
1
}
}
impl<D> ZnxSliceSize for ScalarZnx<D> {
fn sl(&self) -> usize {
self.n()
}
}
impl<D> DataView for ScalarZnx<D> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D> DataViewMut for ScalarZnx<D> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
impl<D: AsRef<[u8]>> ZnxView for ScalarZnx<D> {
type Scalar = i64;
}
impl<D: AsMut<[u8]> + AsRef<[u8]>> ScalarZnx<D> {
pub fn fill_ternary_prob(&mut self, col: usize, prob: f64, source: &mut Source) {
let choices: [i64; 3] = [-1, 0, 1];
let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0];
let dist: WeightedIndex<f64> = WeightedIndex::new(&weights).unwrap();
self.at_mut(col, 0)
.iter_mut()
.for_each(|x: &mut i64| *x = choices[dist.sample(source)]);
}
pub fn fill_ternary_hw(&mut self, col: usize, hw: usize, source: &mut Source) {
assert!(hw <= self.n());
self.at_mut(col, 0)[..hw]
.iter_mut()
.for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
self.at_mut(col, 0).shuffle(source);
}
}
impl<D: From<Vec<u8>>> ScalarZnx<D> {
pub(crate) fn bytes_of<S: Sized>(n: usize, cols: usize) -> usize {
n * cols * size_of::<S>()
}
pub(crate) fn new<S: Sized>(n: usize, cols: usize) -> Self {
let data = alloc_aligned::<u8>(Self::bytes_of::<S>(n, cols));
Self {
data: data.into(),
n,
cols,
}
}
pub(crate) fn new_from_bytes<S: Sized>(n: usize, cols: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::bytes_of::<S>(n, cols));
Self {
data: data.into(),
n,
cols,
}
}
}
pub type ScalarZnxOwned = ScalarZnx<Vec<u8>>;
pub(crate) fn bytes_of_scalar_znx<B: Backend>(module: &Module<B>, cols: usize) -> usize {
ScalarZnxOwned::bytes_of::<i64>(module.n(), cols)
}
pub trait ScalarZnxAlloc {
fn bytes_of_scalar_znx(&self, cols: usize) -> usize;
fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned;
fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned;
}
impl<B: Backend> ScalarZnxAlloc for Module<B> {
fn bytes_of_scalar_znx(&self, cols: usize) -> usize {
ScalarZnxOwned::bytes_of::<i64>(self.n(), cols)
}
fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned {
ScalarZnxOwned::new::<i64>(self.n(), cols)
}
fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned {
ScalarZnxOwned::new_from_bytes::<i64>(self.n(), cols, bytes)
}
}
pub trait ScalarZnxOps {
fn scalar_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: ScalarZnxToMut,
A: ScalarZnxToRef;
/// Applies the automorphism X^i -> X^ik on the selected column of `a`.
fn scalar_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: ScalarZnxToMut;
}
impl<B: Backend> ScalarZnxOps for Module<B> {
fn scalar_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: ScalarZnxToMut,
A: ScalarZnxToRef,
{
let a: ScalarZnx<&[u8]> = a.to_ref();
let mut res: ScalarZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_automorphism(
self.ptr,
k,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn scalar_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: ScalarZnxToMut,
{
let mut a: ScalarZnx<&mut [u8]> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
}
unsafe {
vec_znx::vec_znx_automorphism(
self.ptr,
k,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
}
impl<D> ScalarZnx<D> {
pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self {
Self { data, n, cols }
}
}
pub trait ScalarZnxToRef {
fn to_ref(&self) -> ScalarZnx<&[u8]>;
}
pub trait ScalarZnxToMut {
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]>;
}
impl ScalarZnxToMut for ScalarZnx<Vec<u8>> {
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
ScalarZnx {
data: self.data.as_mut_slice(),
n: self.n,
cols: self.cols,
}
}
}
impl VecZnxToMut for ScalarZnx<Vec<u8>> {
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
VecZnx {
data: self.data.as_mut_slice(),
n: self.n,
cols: self.cols,
size: 1,
}
}
}
impl ScalarZnxToRef for ScalarZnx<Vec<u8>> {
fn to_ref(&self) -> ScalarZnx<&[u8]> {
ScalarZnx {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
}
}
}
impl VecZnxToRef for ScalarZnx<Vec<u8>> {
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
size: 1,
}
}
}
impl ScalarZnxToMut for ScalarZnx<&mut [u8]> {
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
ScalarZnx {
data: self.data,
n: self.n,
cols: self.cols,
}
}
}
impl VecZnxToMut for ScalarZnx<&mut [u8]> {
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
VecZnx {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
}
}
}
impl ScalarZnxToRef for ScalarZnx<&mut [u8]> {
fn to_ref(&self) -> ScalarZnx<&[u8]> {
ScalarZnx {
data: self.data,
n: self.n,
cols: self.cols,
}
}
}
impl VecZnxToRef for ScalarZnx<&mut [u8]> {
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
}
}
}
impl ScalarZnxToRef for ScalarZnx<&[u8]> {
fn to_ref(&self) -> ScalarZnx<&[u8]> {
ScalarZnx {
data: self.data,
n: self.n,
cols: self.cols,
}
}
}
impl VecZnxToRef for ScalarZnx<&[u8]> {
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
}
}
}
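
// Added sanity sketch (not part of the original commit): `fill_ternary_hw` writes ±1 into
// the first `hw` coefficients of a freshly allocated (zeroed) column and then shuffles it,
// so the column ends up with exactly `hw` non-zero coefficients.
#[cfg(test)]
mod ternary_tests {
    use super::{ScalarZnxAlloc, ScalarZnxOwned};
    use crate::{FFT64, Module, ZnxView};
    use sampling::source::Source;

    #[test]
    fn fill_ternary_hw_has_exact_hamming_weight() {
        let module: Module<FFT64> = Module::<FFT64>::new(4096);
        let mut s: ScalarZnxOwned = module.new_scalar_znx(1);
        let mut source: Source = Source::new([0u8; 32]);
        let hw: usize = 128;
        s.fill_ternary_hw(0, hw, &mut source);
        let non_zero: usize = s.at(0, 0).iter().filter(|x| **x != 0).count();
        assert_eq!(non_zero, hw);
        assert!(s.at(0, 0).iter().all(|x| x.abs() <= 1));
    }
}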

View File

@@ -0,0 +1,233 @@
use std::marker::PhantomData;
use crate::ffi::svp;
use crate::znx_base::ZnxInfos;
use crate::{
Backend, DataView, DataViewMut, FFT64, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxSliceSize, ZnxView,
alloc_aligned,
};
pub struct ScalarZnxDft<D, B: Backend> {
data: D,
n: usize,
cols: usize,
_phantom: PhantomData<B>,
}
impl<D, B: Backend> ZnxInfos for ScalarZnxDft<D, B> {
fn cols(&self) -> usize {
self.cols
}
fn rows(&self) -> usize {
1
}
fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
1
}
}
impl<D> ZnxSliceSize for ScalarZnxDft<D, FFT64> {
fn sl(&self) -> usize {
self.n()
}
}
impl<D, B: Backend> DataView for ScalarZnxDft<D, B> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D, B: Backend> DataViewMut for ScalarZnxDft<D, B> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
impl<D: AsRef<[u8]>> ZnxView for ScalarZnxDft<D, FFT64> {
type Scalar = f64;
}
pub(crate) fn bytes_of_scalar_znx_dft<B: Backend>(module: &Module<B>, cols: usize) -> usize {
ScalarZnxDftOwned::bytes_of(module, cols)
}
impl<D: From<Vec<u8>>, B: Backend> ScalarZnxDft<D, B> {
pub(crate) fn bytes_of(module: &Module<B>, cols: usize) -> usize {
unsafe { svp::bytes_of_svp_ppol(module.ptr) as usize * cols }
}
pub(crate) fn new(module: &Module<B>, cols: usize) -> Self {
let data = alloc_aligned::<u8>(Self::bytes_of(module, cols));
Self {
data: data.into(),
n: module.n(),
cols,
_phantom: PhantomData,
}
}
pub(crate) fn new_from_bytes(module: &Module<B>, cols: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::bytes_of(module, cols));
Self {
data: data.into(),
n: module.n(),
cols,
_phantom: PhantomData,
}
}
}
impl<D, B: Backend> ScalarZnxDft<D, B> {
pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self {
Self {
data,
n,
cols,
_phantom: PhantomData,
}
}
pub fn as_vec_znx_dft(self) -> VecZnxDft<D, B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
pub type ScalarZnxDftOwned<B> = ScalarZnxDft<Vec<u8>, B>;
pub trait ScalarZnxDftToRef<B: Backend> {
fn to_ref(&self) -> ScalarZnxDft<&[u8], B>;
}
pub trait ScalarZnxDftToMut<B: Backend> {
fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B>;
}
impl<B: Backend> ScalarZnxDftToMut<B> for ScalarZnxDft<Vec<u8>, B> {
fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> {
ScalarZnxDft {
data: self.data.as_mut_slice(),
n: self.n,
cols: self.cols,
_phantom: PhantomData,
}
}
}
impl<B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<Vec<u8>, B> {
fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
ScalarZnxDft {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
_phantom: PhantomData,
}
}
}
impl<B: Backend> ScalarZnxDftToMut<B> for ScalarZnxDft<&mut [u8], B> {
fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> {
ScalarZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
_phantom: PhantomData,
}
}
}
impl<B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<&mut [u8], B> {
fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
ScalarZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
_phantom: PhantomData,
}
}
}
impl<B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<&[u8], B> {
fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
ScalarZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToMut<B> for ScalarZnxDft<Vec<u8>, B> {
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
VecZnxDft {
data: self.data.as_mut_slice(),
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<Vec<u8>, B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToMut<B> for ScalarZnxDft<&mut [u8], B> {
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<&mut [u8], B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<&[u8], B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}

View File

@@ -0,0 +1,103 @@
use crate::ffi::svp;
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
use crate::{
Backend, FFT64, Module, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToRef, VecZnxDft,
VecZnxDftToMut, VecZnxDftToRef,
};
pub trait ScalarZnxDftAlloc<B: Backend> {
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B>;
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize;
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B>;
}
pub trait ScalarZnxDftOps<BACKEND: Backend> {
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: ScalarZnxDftToMut<BACKEND>,
A: ScalarZnxToRef;
fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxDftToMut<BACKEND>,
A: ScalarZnxDftToRef<BACKEND>,
        B: VecZnxDftToRef<BACKEND>;
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<BACKEND>,
A: ScalarZnxDftToRef<BACKEND>;
}
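// Usage sketch: `svp_prepare` converts a `ScalarZnx` into its prepared DFT representation
// (a `ScalarZnxDft`), after which `svp_apply` multiplies a `VecZnxDft` by it limb-wise,
// i.e. a scalar-polynomial times vector product carried out entirely in the DFT domain.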
impl<B: Backend> ScalarZnxDftAlloc<B> for Module<B> {
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B> {
ScalarZnxDftOwned::new(self, cols)
}
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize {
ScalarZnxDftOwned::bytes_of(self, cols)
}
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B> {
ScalarZnxDftOwned::new_from_bytes(self, cols, bytes)
}
}
impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: ScalarZnxDftToMut<FFT64>,
A: ScalarZnxToRef,
{
unsafe {
svp::svp_prepare(
self.ptr,
res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t,
a.to_ref().at_ptr(a_col, 0),
)
}
}
fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: ScalarZnxDftToRef<FFT64>,
B: VecZnxDftToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
let b: VecZnxDft<&[u8], FFT64> = b.to_ref();
unsafe {
svp::svp_apply_dft_to_dft(
self.ptr,
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
res.size() as u64,
res.cols() as u64,
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
b.at_ptr(b_col, 0) as *const vec_znx_dft_t,
b.size() as u64,
b.cols() as u64,
)
}
}
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: ScalarZnxDftToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
unsafe {
svp::svp_apply_dft_to_dft(
self.ptr,
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
res.size() as u64,
res.cols() as u64,
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
res.at_ptr(res_col, 0) as *const vec_znx_dft_t,
res.size() as u64,
res.cols() as u64,
)
}
}
}

32
backend/src/stats.rs Normal file
View File

@@ -0,0 +1,32 @@
use crate::znx_base::ZnxInfos;
use crate::{Decoding, VecZnx};
use rug::Float;
use rug::float::Round;
use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound};
pub trait Stats {
    /// Returns the standard deviation of the `col_i`-th column polynomial.
fn std(&self, col_i: usize, basek: usize) -> f64;
}
impl<D: AsRef<[u8]>> Stats for VecZnx<D> {
fn std(&self, col_i: usize, basek: usize) -> f64 {
let prec: u32 = (self.size() * basek) as u32;
let mut data: Vec<Float> = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect();
self.decode_vec_float(col_i, basek, &mut data);
// std = sqrt(sum((xi - avg)^2) / n)
let mut avg: Float = Float::with_val(prec, 0);
data.iter().for_each(|x| {
avg.add_assign_round(x, Round::Nearest);
});
avg.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
data.iter_mut().for_each(|x| {
x.sub_assign_round(&avg, Round::Nearest);
});
let mut std: Float = Float::with_val(prec, 0);
data.iter().for_each(|x| std += x * x);
std.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest);
std = std.sqrt();
std.to_f64()
}
}
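
// Added sanity sketch (not part of the original commit): a column whose coefficients are
// all equal decodes to a constant vector, so its standard deviation is (numerically) zero,
// independently of the base-2^k scaling.
#[cfg(test)]
mod tests {
    use super::Stats;
    use crate::vec_znx_ops::*;
    use crate::znx_base::*;
    use crate::{FFT64, Module, VecZnx};

    #[test]
    fn std_of_constant_column_is_zero() {
        let module: Module<FFT64> = Module::<FFT64>::new(4096);
        let basek: usize = 17;
        let mut a: VecZnx<_> = module.new_vec_znx(1, 3);
        a.at_mut(0, 0).iter_mut().for_each(|x| *x = 42);
        assert!(a.std(0, basek) < 1e-9);
    }
}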

373
backend/src/vec_znx.rs Normal file
View File

@@ -0,0 +1,373 @@
use itertools::izip;
use crate::DataView;
use crate::DataViewMut;
use crate::ScalarZnx;
use crate::Scratch;
use crate::ZnxSliceSize;
use crate::ZnxZero;
use crate::alloc_aligned;
use crate::assert_alignement;
use crate::cast_mut;
use crate::ffi::znx;
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
use std::{cmp::min, fmt};
/// [VecZnx] represents a collection of contiguously stacked vectors of small-norm
/// polynomials of Zn\[X\] with [i64] coefficients.
/// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single contiguous
/// array in memory.
///
/// # Example
///
/// Given 3 column polynomials (a, b, c) of Zn\[X\], each decomposed into 4 limbs, the
/// memory layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci
/// are small polynomials of Zn\[X\].
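/// With this layout, the coefficients of limb `j` of column `i` start at offset
/// `(j * cols + i) * n` in the backing array of `i64` values.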
pub struct VecZnx<D> {
pub data: D,
pub n: usize,
pub cols: usize,
pub size: usize,
}
impl<D> ZnxInfos for VecZnx<D> {
fn cols(&self) -> usize {
self.cols
}
fn rows(&self) -> usize {
1
}
fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
self.size
}
}
impl<D> ZnxSliceSize for VecZnx<D> {
fn sl(&self) -> usize {
self.n() * self.cols()
}
}
impl<D> DataView for VecZnx<D> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D> DataViewMut for VecZnx<D> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
impl<D: AsRef<[u8]>> ZnxView for VecZnx<D> {
type Scalar = i64;
}
impl<D: AsRef<[u8]>> VecZnx<D> {
pub fn rsh_scratch_space(n: usize) -> usize {
n * std::mem::size_of::<i64>()
}
}
impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnx<D> {
/// Truncates the precision of the [VecZnx] by k bits.
///
/// # Arguments
///
/// * `basek`: the base two logarithm of the coefficients decomposition.
/// * `k`: the number of bits of precision to drop.
pub fn trunc_pow2(&mut self, basek: usize, k: usize, col: usize) {
if k == 0 {
return;
}
self.size -= k / basek;
let k_rem: usize = k % basek;
if k_rem != 0 {
let mask: i64 = ((1 << (basek - k_rem - 1)) - 1) << k_rem;
self.at_mut(col, self.size() - 1)
.iter_mut()
.for_each(|x: &mut i64| *x &= mask)
}
}
pub fn rsh(&mut self, basek: usize, k: usize, scratch: &mut Scratch) {
let n: usize = self.n();
let cols: usize = self.cols();
let size: usize = self.size();
let steps: usize = k / basek;
self.raw_mut().rotate_right(n * steps * cols);
(0..cols).for_each(|i| {
(0..steps).for_each(|j| {
self.zero_at(i, j);
})
});
let k_rem: usize = k % basek;
if k_rem != 0 {
let (carry, _) = scratch.tmp_slice::<i64>(n);
let shift = i64::BITS as usize - k_rem;
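            // Propagate the low `k_rem` bits of each limb into the next (less significant)
            // limb: `(x << shift) >> shift` sign-extends those bits, and they re-enter the
            // following limb scaled by 2^{basek} before the division by 2^{k_rem}.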
(0..cols).for_each(|i| {
carry.fill(0);
(steps..size).for_each(|j| {
izip!(carry.iter_mut(), self.at_mut(i, j).iter_mut()).for_each(|(ci, xi)| {
*xi += *ci << basek;
*ci = (*xi << shift) >> shift;
*xi = (*xi - *ci) >> k_rem;
});
});
})
}
}
}
impl<D: From<Vec<u8>>> VecZnx<D> {
pub(crate) fn bytes_of<Scalar: Sized>(n: usize, cols: usize, size: usize) -> usize {
n * cols * size * size_of::<Scalar>()
}
pub(crate) fn new<Scalar: Sized>(n: usize, cols: usize, size: usize) -> Self {
let data = alloc_aligned::<u8>(Self::bytes_of::<Scalar>(n, cols, size));
Self {
data: data.into(),
n,
cols,
size,
}
}
pub(crate) fn new_from_bytes<Scalar: Sized>(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::bytes_of::<Scalar>(n, cols, size));
Self {
data: data.into(),
n,
cols,
size,
}
}
}
impl<D> VecZnx<D> {
pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
Self {
data,
n,
cols,
size,
}
}
pub fn to_scalar_znx(self) -> ScalarZnx<D> {
debug_assert_eq!(
self.size, 1,
"cannot convert VecZnx to ScalarZnx if cols: {} != 1",
self.cols
);
ScalarZnx {
data: self.data,
n: self.n,
cols: self.cols,
}
}
}
/// Copies the coefficients of `a` on the receiver.
/// Copy is done with the minimum size matching both backing arrays.
/// Panics if the cols do not match.
pub fn copy_vec_znx_from<DataMut, Data>(b: &mut VecZnx<DataMut>, a: &VecZnx<Data>)
where
DataMut: AsMut<[u8]> + AsRef<[u8]>,
Data: AsRef<[u8]>,
{
assert_eq!(b.cols(), a.cols());
let data_a: &[i64] = a.raw();
let data_b: &mut [i64] = b.raw_mut();
let size = min(data_b.len(), data_a.len());
data_b[..size].copy_from_slice(&data_a[..size])
}
#[allow(dead_code)]
fn normalize_tmp_bytes(n: usize) -> usize {
n * std::mem::size_of::<i64>()
}
#[allow(dead_code)]
fn normalize<D: AsMut<[u8]> + AsRef<[u8]>>(basek: usize, a: &mut VecZnx<D>, a_col: usize, tmp_bytes: &mut [u8]) {
let n: usize = a.n();
debug_assert!(
tmp_bytes.len() >= normalize_tmp_bytes(n),
"invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({})",
tmp_bytes.len(),
n,
);
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr())
}
let carry_i64: &mut [i64] = cast_mut(tmp_bytes);
unsafe {
znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr());
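        // Limbs are processed from least to most significant so that the carry produced
        // by `znx_normalize` propagates towards the most significant limb.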
(0..a.size()).rev().for_each(|i| {
znx::znx_normalize(
n as u64,
basek as u64,
a.at_mut_ptr(a_col, i),
carry_i64.as_mut_ptr(),
a.at_mut_ptr(a_col, i),
carry_i64.as_mut_ptr(),
)
});
}
}
impl<D> VecZnx<D>
where
VecZnx<D>: VecZnxToMut + ZnxInfos,
{
/// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self].
pub fn extract_column<R>(&mut self, self_col: usize, a: &R, a_col: usize)
where
R: VecZnxToRef + ZnxInfos,
{
#[cfg(debug_assertions)]
{
assert!(self_col < self.cols());
assert!(a_col < a.cols());
}
let min_size: usize = self.size.min(a.size());
let max_size: usize = self.size;
let mut self_mut: VecZnx<&mut [u8]> = self.to_mut();
let a_ref: VecZnx<&[u8]> = a.to_ref();
(0..min_size).for_each(|i: usize| {
self_mut
.at_mut(self_col, i)
.copy_from_slice(a_ref.at(a_col, i));
});
(min_size..max_size).for_each(|i| {
self_mut.zero_at(self_col, i);
});
}
}
impl<D: AsRef<[u8]>> fmt::Display for VecZnx<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"VecZnx(n={}, cols={}, size={})",
self.n, self.cols, self.size
)?;
for col in 0..self.cols {
writeln!(f, "Column {}:", col)?;
for size in 0..self.size {
let coeffs = self.at(col, size);
write!(f, " Size {}: [", size)?;
let max_show = 100;
let show_count = coeffs.len().min(max_show);
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", coeff)?;
}
if coeffs.len() > max_show {
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
}
writeln!(f, "]")?;
}
}
Ok(())
}
}
pub type VecZnxOwned = VecZnx<Vec<u8>>;
pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>;
pub type VecZnxRef<'a> = VecZnx<&'a [u8]>;
pub trait VecZnxToRef {
fn to_ref(&self) -> VecZnx<&[u8]>;
}
pub trait VecZnxToMut {
fn to_mut(&mut self) -> VecZnx<&mut [u8]>;
}
impl VecZnxToMut for VecZnx<Vec<u8>> {
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
VecZnx {
data: self.data.as_mut_slice(),
n: self.n,
cols: self.cols,
size: self.size,
}
}
}
impl VecZnxToRef for VecZnx<Vec<u8>> {
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
size: self.size,
}
}
}
impl VecZnxToMut for VecZnx<&mut [u8]> {
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
VecZnx {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
}
}
}
impl VecZnxToRef for VecZnx<&mut [u8]> {
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
}
}
}
impl VecZnxToRef for VecZnx<&[u8]> {
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
}
}
}

244
backend/src/vec_znx_big.rs Normal file
View File

@@ -0,0 +1,244 @@
use crate::ffi::vec_znx_big;
use crate::znx_base::{ZnxInfos, ZnxView};
use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnx, ZnxSliceSize, ZnxViewMut, ZnxZero, alloc_aligned};
use std::fmt;
use std::marker::PhantomData;
pub struct VecZnxBig<D, B: Backend> {
data: D,
n: usize,
cols: usize,
size: usize,
_phantom: PhantomData<B>,
}
impl<D, B: Backend> ZnxInfos for VecZnxBig<D, B> {
fn cols(&self) -> usize {
self.cols
}
fn rows(&self) -> usize {
1
}
fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
self.size
}
}
impl<D> ZnxSliceSize for VecZnxBig<D, FFT64> {
fn sl(&self) -> usize {
self.n() * self.cols()
}
}
impl<D, B: Backend> DataView for VecZnxBig<D, B> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D, B: Backend> DataViewMut for VecZnxBig<D, B> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
impl<D: AsRef<[u8]>> ZnxView for VecZnxBig<D, FFT64> {
type Scalar = i64;
}
pub(crate) fn bytes_of_vec_znx_big<B: Backend>(module: &Module<B>, cols: usize, size: usize) -> usize {
unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols }
}
impl<D: From<Vec<u8>>, B: Backend> VecZnxBig<D, B> {
pub(crate) fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
let data = alloc_aligned::<u8>(bytes_of_vec_znx_big(module, cols, size));
Self {
data: data.into(),
n: module.n(),
cols,
size,
_phantom: PhantomData,
}
}
pub(crate) fn new_from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == bytes_of_vec_znx_big(module, cols, size));
Self {
data: data.into(),
n: module.n(),
cols,
size,
_phantom: PhantomData,
}
}
}
impl<D, B: Backend> VecZnxBig<D, B> {
pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
Self {
data,
n,
cols,
size,
_phantom: PhantomData,
}
}
}
impl<D> VecZnxBig<D, FFT64>
where
VecZnxBig<D, FFT64>: VecZnxBigToMut<FFT64> + ZnxInfos,
{
// Consumes the VecZnxBig to return a VecZnx.
// Useful when no normalization is needed.
pub fn to_vec_znx_small(self) -> VecZnx<D> {
VecZnx {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
}
}
/// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self].
pub fn extract_column<C>(&mut self, self_col: usize, a: &VecZnxBig<C, FFT64>, a_col: usize)
where
VecZnxBig<C, FFT64>: VecZnxBigToRef<FFT64> + ZnxInfos,
{
#[cfg(debug_assertions)]
{
assert!(self_col < self.cols());
assert!(a_col < a.cols());
}
let min_size: usize = self.size.min(a.size());
let max_size: usize = self.size;
let mut self_mut: VecZnxBig<&mut [u8], FFT64> = self.to_mut();
let a_ref: VecZnxBig<&[u8], FFT64> = a.to_ref();
(0..min_size).for_each(|i: usize| {
self_mut
.at_mut(self_col, i)
.copy_from_slice(a_ref.at(a_col, i));
});
(min_size..max_size).for_each(|i| {
self_mut.zero_at(self_col, i);
});
}
}
pub type VecZnxBigOwned<B> = VecZnxBig<Vec<u8>, B>;
pub trait VecZnxBigToRef<B: Backend> {
fn to_ref(&self) -> VecZnxBig<&[u8], B>;
}
pub trait VecZnxBigToMut<B: Backend> {
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B>;
}
impl<B: Backend> VecZnxBigToMut<B> for VecZnxBig<Vec<u8>, B> {
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
VecZnxBig {
data: self.data.as_mut_slice(),
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxBigToRef<B> for VecZnxBig<Vec<u8>, B> {
fn to_ref(&self) -> VecZnxBig<&[u8], B> {
VecZnxBig {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxBigToMut<B> for VecZnxBig<&mut [u8], B> {
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
VecZnxBig {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxBigToRef<B> for VecZnxBig<&mut [u8], B> {
fn to_ref(&self) -> VecZnxBig<&[u8], B> {
VecZnxBig {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxBigToRef<B> for VecZnxBig<&[u8], B> {
fn to_ref(&self) -> VecZnxBig<&[u8], B> {
VecZnxBig {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<D: AsRef<[u8]>> fmt::Display for VecZnxBig<D, FFT64> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"VecZnxBig(n={}, cols={}, size={})",
self.n, self.cols, self.size
)?;
for col in 0..self.cols {
writeln!(f, "Column {}:", col)?;
for size in 0..self.size {
let coeffs = self.at(col, size);
write!(f, " Size {}: [", size)?;
let max_show = 100;
let show_count = coeffs.len().min(max_show);
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", coeff)?;
}
if coeffs.len() > max_show {
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
}
writeln!(f, "]")?;
}
}
Ok(())
}
}

View File

@@ -0,0 +1,632 @@
use crate::ffi::vec_znx;
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
use crate::{
Backend, FFT64, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxScratch,
VecZnxToMut, VecZnxToRef, ZnxSliceSize, bytes_of_vec_znx_big,
};
pub trait VecZnxBigAlloc<B: Backend> {
    /// Allocates a vector Z[X]/(X^N+1) that stores non-normalized values.
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned<B>;
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
///
/// Behavior: takes ownership of the backing array.
///
/// # Arguments
///
    /// * `cols`: the number of columns (polynomials).
    /// * `size`: the number of limbs (small polynomials) per column.
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B>;
// /// Returns a new [VecZnxBig] with the provided bytes array as backing array.
// ///
// /// Behavior: the backing array is only borrowed.
// ///
// /// # Arguments
// ///
// /// * `cols`: the number of polynomials..
// /// * `size`: the number of polynomials per column.
// /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
// ///
// /// # Panics
// /// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
// fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<B>;
/// Returns the minimum number of bytes necessary to allocate
/// a new [VecZnxBig] through [VecZnxBig::from_bytes].
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize;
}
pub trait VecZnxBigOps<BACKEND: Backend> {
    /// Adds `a` to `b` and stores the result on `res`.
fn vec_znx_big_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxBigToRef<BACKEND>,
B: VecZnxBigToRef<BACKEND>;
    /// Adds `a` to `res` and stores the result on `res`.
fn vec_znx_big_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxBigToRef<BACKEND>;
    /// Adds `a` to `b` and stores the result on `res`.
fn vec_znx_big_add_small<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxBigToRef<BACKEND>,
B: VecZnxToRef;
    /// Adds `a` to `res` and stores the result on `res`.
fn vec_znx_big_add_small_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxToRef;
    /// Subtracts `b` from `a` and stores the result on `res`.
fn vec_znx_big_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxBigToRef<BACKEND>,
B: VecZnxBigToRef<BACKEND>;
    /// Subtracts `a` from `res` and stores the result on `res`.
fn vec_znx_big_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxBigToRef<BACKEND>;
    /// Subtracts `res` from `a` and stores the result on `res`.
fn vec_znx_big_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxBigToRef<BACKEND>;
    /// Subtracts `b` from `a` and stores the result on `res`.
fn vec_znx_big_sub_small_a<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxToRef,
B: VecZnxBigToRef<BACKEND>;
/// Subtracts `a` from `res` and stores the result on `res`.
fn vec_znx_big_sub_small_a_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxToRef;
    /// Subtracts `b` from `a` and stores the result on `res`.
fn vec_znx_big_sub_small_b<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxBigToRef<BACKEND>,
B: VecZnxToRef;
/// Subtracts `res` from `a` and stores the result on `res`.
fn vec_znx_big_sub_small_b_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxToRef;
/// Negates `a` inplace.
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize)
where
A: VecZnxBigToMut<BACKEND>;
    /// Normalizes `a` and stores the result on `res`.
    ///
    /// # Arguments
    ///
    /// * `basek`: normalization basis.
    /// * `scratch`: scratch space of size at least [VecZnxBigScratch::vec_znx_big_normalize_tmp_bytes].
fn vec_znx_big_normalize<R, A>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch,
) where
R: VecZnxToMut,
A: VecZnxBigToRef<BACKEND>;
    /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `res`.
fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxBigToRef<BACKEND>;
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxBigToMut<BACKEND>;
}
pub trait VecZnxBigScratch {
/// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize].
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
}
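// Normalization sketch: a `VecZnxBig` accumulator is typically reduced back to a `VecZnx`
// with `module.vec_znx_big_normalize(basek, &mut res, col, &acc, col, scratch.borrow())`,
// where the scratch provides at least `vec_znx_big_normalize_tmp_bytes()` bytes.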
impl<B: Backend> VecZnxBigAlloc<B> for Module<B> {
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned<B> {
VecZnxBig::new(self, cols, size)
}
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B> {
VecZnxBig::new_from_bytes(self, cols, size, bytes)
}
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
bytes_of_vec_znx_big(self, cols, size)
}
}
impl VecZnxBigOps<FFT64> for Module<FFT64> {
fn vec_znx_big_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
B: VecZnxBigToRef<FFT64>,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(res.n(), self.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_add(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
fn vec_znx_big_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_add(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
)
}
}
fn vec_znx_big_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
B: VecZnxBigToRef<FFT64>,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(res.n(), self.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
fn vec_znx_big_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_big_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
)
}
}
fn vec_znx_big_sub_small_b<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
B: VecZnxToRef,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(res.n(), self.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
fn vec_znx_big_sub_small_b_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
)
}
}
fn vec_znx_big_sub_small_a<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxToRef,
B: VecZnxBigToRef<FFT64>,
{
let a: VecZnx<&[u8]> = a.to_ref();
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(res.n(), self.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
fn vec_znx_big_sub_small_a_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_big_add_small<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
B: VecZnxToRef,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(res.n(), self.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_add(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
fn vec_znx_big_add_small_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_add(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, res_col: usize)
where
A: VecZnxBigToMut<FFT64>,
{
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
}
unsafe {
vec_znx::vec_znx_negate(
self.ptr,
a.at_mut_ptr(res_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(res_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_big_normalize<R, A>(
&self,
basek: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch,
) where
R: VecZnxToMut,
A: VecZnxBigToRef<FFT64>,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
//(Jay)Note: This is calling VezZnxOps::vec_znx_normalize_tmp_bytes and not VecZnxBigOps::vec_znx_big_normalize_tmp_bytes.
// In the FFT backend the tmp sizes are same but will be different in the NTT backend
// assert!(tmp_bytes.len() >= <Self as VecZnxOps<&mut [u8], & [u8]>>::vec_znx_normalize_tmp_bytes(&self));
// assert_alignement(tmp_bytes.as_ptr());
}
let (tmp_bytes, _) = scratch.tmp_slice(<Self as VecZnxBigScratch>::vec_znx_big_normalize_tmp_bytes(
&self,
));
unsafe {
vec_znx::vec_znx_normalize_base2k(
self.ptr,
basek as u64,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_automorphism(
self.ptr,
k,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxBigToMut<FFT64>,
{
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
}
unsafe {
vec_znx::vec_znx_automorphism(
self.ptr,
k,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
}
impl<B: Backend> VecZnxBigScratch for Module<B> {
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
<Self as VecZnxScratch>::vec_znx_normalize_tmp_bytes(self)
}
}

242
backend/src/vec_znx_dft.rs Normal file
View File

@@ -0,0 +1,242 @@
use std::marker::PhantomData;
use crate::ffi::vec_znx_dft;
use crate::znx_base::ZnxInfos;
use crate::{
Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, alloc_aligned,
};
use std::fmt;
pub struct VecZnxDft<D, B: Backend> {
pub(crate) data: D,
pub(crate) n: usize,
pub(crate) cols: usize,
pub(crate) size: usize,
pub(crate) _phantom: PhantomData<B>,
}
impl<D, B: Backend> VecZnxDft<D, B> {
pub fn into_big(self) -> VecZnxBig<D, B> {
VecZnxBig::<D, B>::from_data(self.data, self.n, self.cols, self.size)
}
}
impl<D, B: Backend> ZnxInfos for VecZnxDft<D, B> {
fn cols(&self) -> usize {
self.cols
}
fn rows(&self) -> usize {
1
}
fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
self.size
}
}
impl<D> ZnxSliceSize for VecZnxDft<D, FFT64> {
fn sl(&self) -> usize {
self.n() * self.cols()
}
}
impl<D, B: Backend> DataView for VecZnxDft<D, B> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D, B: Backend> DataViewMut for VecZnxDft<D, B> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
impl<D: AsRef<[u8]>> ZnxView for VecZnxDft<D, FFT64> {
type Scalar = f64;
}
pub(crate) fn bytes_of_vec_znx_dft<B: Backend>(module: &Module<B>, cols: usize, size: usize) -> usize {
unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols }
}
impl<D: From<Vec<u8>>, B: Backend> VecZnxDft<D, B> {
pub(crate) fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
let data = alloc_aligned::<u8>(bytes_of_vec_znx_dft(module, cols, size));
Self {
data: data.into(),
n: module.n(),
cols,
size,
_phantom: PhantomData,
}
}
pub(crate) fn new_from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == bytes_of_vec_znx_dft(module, cols, size));
Self {
data: data.into(),
n: module.n(),
cols,
size,
_phantom: PhantomData,
}
}
}
impl<D> VecZnxDft<D, FFT64>
where
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64> + ZnxInfos,
{
/// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self].
pub fn extract_column<C>(&mut self, self_col: usize, a: &VecZnxDft<C, FFT64>, a_col: usize)
where
VecZnxDft<C, FFT64>: VecZnxDftToRef<FFT64> + ZnxInfos,
{
#[cfg(debug_assertions)]
{
assert!(self_col < self.cols());
assert!(a_col < a.cols());
}
let min_size: usize = self.size.min(a.size());
let max_size: usize = self.size;
let mut self_mut: VecZnxDft<&mut [u8], FFT64> = self.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
(0..min_size).for_each(|i: usize| {
self_mut
.at_mut(self_col, i)
.copy_from_slice(a_ref.at(a_col, i));
});
(min_size..max_size).for_each(|i| {
self_mut.zero_at(self_col, i);
});
}
}
pub type VecZnxDftOwned<B> = VecZnxDft<Vec<u8>, B>;
impl<D, B: Backend> VecZnxDft<D, B> {
pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
Self {
data,
n,
cols,
size,
_phantom: PhantomData,
}
}
}
pub trait VecZnxDftToRef<B: Backend> {
fn to_ref(&self) -> VecZnxDft<&[u8], B>;
}
pub trait VecZnxDftToMut<B: Backend> {
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B>;
}
impl<B: Backend> VecZnxDftToMut<B> for VecZnxDft<Vec<u8>, B> {
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
VecZnxDft {
data: self.data.as_mut_slice(),
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for VecZnxDft<Vec<u8>, B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToMut<B> for VecZnxDft<&mut [u8], B> {
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for VecZnxDft<&mut [u8], B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for VecZnxDft<&[u8], B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<D: AsRef<[u8]>> fmt::Display for VecZnxDft<D, FFT64> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"VecZnxDft(n={}, cols={}, size={})",
self.n, self.cols, self.size
)?;
for col in 0..self.cols {
writeln!(f, "Column {}:", col)?;
for size in 0..self.size {
let coeffs = self.at(col, size);
write!(f, " Size {}: [", size)?;
let max_show = 100;
let show_count = coeffs.len().min(max_show);
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", coeff)?;
}
if coeffs.len() > max_show {
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
}
writeln!(f, "]")?;
}
}
Ok(())
}
}

View File

@@ -0,0 +1,287 @@
use crate::ffi::{vec_znx_big, vec_znx_dft};
use crate::vec_znx_dft::bytes_of_vec_znx_dft;
use crate::znx_base::ZnxInfos;
use crate::{
Backend, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
ZnxSliceSize,
};
use crate::{FFT64, Module, ZnxView, ZnxViewMut, ZnxZero};
use std::cmp::min;
pub trait VecZnxDftAlloc<B: Backend> {
    /// Allocates a vector Z[X]/(X^N+1) that stores values in the DFT space.
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned<B>;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// Behavior: takes ownership of the backing array.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxDft].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;
    /// Returns the minimum number of bytes necessary to allocate
    /// a new [VecZnxDft] through [VecZnxDftAlloc::new_vec_znx_dft_from_bytes].
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize;
}
pub trait VecZnxDftOps<B: Backend> {
    /// Returns the minimum number of scratch bytes required by [VecZnxDftOps::vec_znx_idft].
fn vec_znx_idft_tmp_bytes(&self) -> usize;
fn vec_znx_dft_add<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
D: VecZnxDftToRef<B>;
fn vec_znx_dft_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>;
fn vec_znx_dft_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>;
    /// res <- IDFT(a), uses `a` as scratch space.
fn vec_znx_idft_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxDftToMut<B>;
/// Consumes a to return IDFT(a) in big coeff space.
    fn vec_znx_idft_consume<D>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
    where
        VecZnxDft<D, B>: VecZnxDftToMut<B>;
fn vec_znx_idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
where
R: VecZnxBigToMut<B>,
A: VecZnxDftToRef<B>;
fn vec_znx_dft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxToRef;
}
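// Round-trip sketch: `vec_znx_dft` maps a (small) `VecZnx` into the DFT domain, products
// and additions are applied there (e.g. via `ScalarZnxDftOps::svp_apply`), and
// `vec_znx_idft_tmp_a` followed by `VecZnxBigOps::vec_znx_big_normalize` brings the result
// back to a normalized `VecZnx`.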
impl<B: Backend> VecZnxDftAlloc<B> for Module<B> {
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned<B> {
VecZnxDftOwned::new(&self, cols, size)
}
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B> {
VecZnxDftOwned::new_from_bytes(self, cols, size, bytes)
}
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize {
bytes_of_vec_znx_dft(self, cols, size)
}
}
impl VecZnxDftOps<FFT64> for Module<FFT64> {
fn vec_znx_dft_add<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
D: VecZnxDftToRef<FFT64>,
{
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref();
let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size());
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_dft_add(
self.ptr,
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1,
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
);
});
}
(min_size..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
})
}
fn vec_znx_dft_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
{
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
let min_size: usize = res_mut.size().min(a_ref.size());
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_dft_add(
self.ptr,
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1,
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
);
});
}
}
fn vec_znx_dft_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
{
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
let min_size: usize = min(res_mut.size(), a_ref.size());
(0..min_size).for_each(|j| {
res_mut
.at_mut(res_col, j)
.copy_from_slice(a_ref.at(a_col, j));
});
(min_size..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
})
}
fn vec_znx_idft_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxDftToMut<FFT64>,
{
let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
let min_size: usize = min(res_mut.size(), a_mut.size());
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_znx_idft_tmp_a(
self.ptr,
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
1 as u64,
a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1 as u64,
)
});
(min_size..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
})
}
}
fn vec_znx_idft_consume<D>(&self, mut a: VecZnxDft<D, FFT64>) -> VecZnxBig<D, FFT64>
where
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>,
{
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
unsafe {
// Rev col and rows because ZnxDft.sl() >= ZnxBig.sl()
(0..a_mut.size()).for_each(|j| {
(0..a_mut.cols()).for_each(|i| {
vec_znx_dft::vec_znx_idft_tmp_a(
self.ptr,
a_mut.at_mut_ptr(i, j) as *mut vec_znx_big::vec_znx_big_t,
1 as u64,
a_mut.at_mut_ptr(i, j) as *mut vec_znx_dft::vec_znx_dft_t,
1 as u64,
)
});
});
}
a.into_big()
}
fn vec_znx_idft_tmp_bytes(&self) -> usize {
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize }
}
    /// res <- DFT(a)
    ///
    /// # Panics
    /// In debug builds, if `res_col >= res.cols()` or `a_col >= a.cols()`.
fn vec_znx_dft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxToRef,
{
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a_ref: crate::VecZnx<&[u8]> = a.to_ref();
let min_size: usize = min(res_mut.size(), a_ref.size());
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_znx_dft(
self.ptr,
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1 as u64,
a_ref.at_ptr(a_col, j),
1 as u64,
a_ref.sl() as u64,
)
});
(min_size..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
});
}
}
    // res <- IDFT(a); the scratch space size is obtained with [vec_znx_idft_tmp_bytes].
fn vec_znx_idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
{
let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_idft_tmp_bytes());
let min_size: usize = min(res_mut.size(), a_ref.size());
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_znx_idft(
self.ptr,
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
1 as u64,
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1 as u64,
tmp_bytes.as_mut_ptr(),
)
});
(min_size..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
});
}
}
}
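// Summary of the three inverse-DFT entry points above (no new behavior):
// `vec_znx_idft` requires external scratch of `vec_znx_idft_tmp_bytes()` bytes,
// `vec_znx_idft_tmp_a` reuses `a` itself as scratch and therefore clobbers it,
// and `vec_znx_idft_consume` takes ownership of `a`, computes the IDFT in place
// and converts it into the resulting [VecZnxBig] via `into_big`.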

694
backend/src/vec_znx_ops.rs Normal file
View File

@@ -0,0 +1,694 @@
use crate::ffi::vec_znx;
use crate::{
Backend, Module, ScalarZnxToRef, Scratch, VecZnx, VecZnxOwned, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView,
ZnxViewMut, ZnxZero,
};
use itertools::izip;
use std::cmp::min;
pub trait VecZnxAlloc {
/// Allocates a new [VecZnx].
///
/// # Arguments
///
/// * `cols`: the number of polynomials.
    /// * `size`: the number of small polynomials per column.
fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned;
/// Instantiates a new [VecZnx] from a slice of bytes.
/// The returned [VecZnx] takes ownership of the slice of bytes.
///
/// # Arguments
///
/// * `cols`: the number of polynomials.
    /// * `size`: the number of small polynomials per column.
///
    /// # Panics
    /// If `bytes.len()` is not equal to [VecZnxAlloc::bytes_of_vec_znx].
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned;
    /// Returns the number of bytes necessary to allocate
    /// a new [VecZnx] through [VecZnxAlloc::new_vec_znx_from_bytes]
    /// or its borrowing variant (`new_vec_znx_from_bytes_borrow`).
fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize;
}
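// Illustrative allocation sketch (not part of the API): assumes the caller
// already holds a `Module<B>`; constructing the module is outside this file.
#[allow(dead_code)]
fn vec_znx_alloc_sketch<B: Backend>(module: &Module<B>) -> VecZnxOwned {
    // Direct allocation: 2 columns of 4 small polynomials each.
    let a: VecZnxOwned = module.new_vec_znx(2, 4);
    // Equivalent allocation over a caller-provided byte buffer of exactly the
    // required size (any other length panics, see above).
    let bytes: Vec<u8> = vec![0u8; module.bytes_of_vec_znx(2, 4)];
    let _b: VecZnxOwned = module.new_vec_znx_from_bytes(2, 4, bytes);
    a
}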
pub trait VecZnxOps {
/// Normalizes the selected column of `a` and stores the result into the selected column of `res`.
fn vec_znx_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
where
R: VecZnxToMut,
A: VecZnxToRef;
/// Normalizes the selected column of `a`.
fn vec_znx_normalize_inplace<A>(&self, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch)
where
A: VecZnxToMut;
/// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`.
fn vec_znx_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
B: VecZnxToRef;
/// Adds the selected column of `a` to the selected column of `res` and writes the result on the selected column of `res`.
fn vec_znx_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
    /// Adds the selected column of `a` to the selected column and limb of `res`.
    fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef;
/// Subtracts the selected column of `b` from the selected column of `a` and writes the result on the selected column of `res`.
fn vec_znx_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
B: VecZnxToRef;
/// Subtracts the selected column of `a` from the selected column of `res` inplace.
///
/// res[res_col] -= a[a_col]
fn vec_znx_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
    /// Subtracts the selected column of `res` from the selected column of `a` and stores the result in `res`.
///
/// res[res_col] = a[a_col] - res[res_col]
fn vec_znx_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
    /// Subtracts the selected column of `a` from the selected column and limb of `res`.
    fn vec_znx_sub_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef;
    /// Negates the selected column of `a` and stores the result in the selected column of `res`.
fn vec_znx_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
/// Negates the selected column of `a`.
fn vec_znx_negate_inplace<A>(&self, a: &mut A, a_col: usize)
where
A: VecZnxToMut;
/// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`.
fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
/// Multiplies the selected column of `a` by X^k.
fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxToMut;
/// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result in `res_col` column of `res`.
fn vec_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
/// Applies the automorphism X^i -> X^ik on the selected column of `a`.
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxToMut;
    /// Splits the selected column of `a` into subrings and copies them into the selected column of each [VecZnx] in `res`.
///
/// # Panics
///
    /// This method requires that all [VecZnx] of `res` have the same ring degree
    /// and that res[i].n() * res.len() <= a.n().
fn vec_znx_split<R, A>(&self, res: &mut Vec<R>, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
where
R: VecZnxToMut,
A: VecZnxToRef;
/// Merges the subrings of the selected column of `a` into the selected column of `res`.
///
/// # Panics
///
    /// This method requires that all [VecZnx] of `a` have the same ring degree
    /// and that a[i].n() * a.len() <= res.n().
fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: Vec<A>, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
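    /// Copies the selected column of `a` into the selected column of `res`,
    /// switching the ring degree: coefficients are sub-sampled when
    /// `a.n() > res.n()`, and spread out with the remaining coefficients
    /// zeroed when `a.n() < res.n()`.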
    fn switch_degree<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxScratch {
/// Returns the minimum number of bytes necessary for normalization.
fn vec_znx_normalize_tmp_bytes(&self) -> usize;
}
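// Illustrative sketch of a typical add-then-normalize step (not part of the
// API): `basek` is the base-2^k decomposition basis of the vectors and
// `scratch` must provide at least `vec_znx_normalize_tmp_bytes()` bytes.
#[allow(dead_code)]
fn add_normalize_sketch<B: Backend, R, A>(
    module: &Module<B>,
    basek: usize,
    res: &mut R,
    a: &A,
    scratch: &mut Scratch,
) where
    R: VecZnxToMut,
    A: VecZnxToRef,
{
    // res[0] += a[0], then renormalize the limbs of res[0] to base 2^basek.
    module.vec_znx_add_inplace(res, 0, a, 0);
    module.vec_znx_normalize_inplace(basek, res, 0, scratch);
}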
impl<B: Backend> VecZnxAlloc for Module<B> {
fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned {
VecZnxOwned::new::<i64>(self.n(), cols, size)
}
fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize {
VecZnxOwned::bytes_of::<i64>(self.n(), cols, size)
}
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned {
VecZnxOwned::new_from_bytes::<i64>(self.n(), cols, size, bytes)
}
}
impl<BACKEND: Backend> VecZnxOps for Module<BACKEND> {
fn vec_znx_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_normalize_tmp_bytes());
unsafe {
vec_znx::vec_znx_normalize_base2k(
self.ptr,
basek as u64,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
fn vec_znx_normalize_inplace<A>(&self, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch)
where
A: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
}
let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_normalize_tmp_bytes());
unsafe {
vec_znx::vec_znx_normalize_base2k(
self.ptr,
basek as u64,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
fn vec_znx_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
B: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(res.n(), self.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_add(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
fn vec_znx_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_add(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
)
}
}
fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: crate::ScalarZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_add(
self.ptr,
res.at_mut_ptr(res_col, res_limb),
1 as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, res_limb),
1 as u64,
res.sl() as u64,
)
}
}
fn vec_znx_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
B: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(res.n(), self.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
fn vec_znx_sub_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: crate::ScalarZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, res_limb),
1 as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, res_limb),
1 as u64,
res.sl() as u64,
)
}
}
fn vec_znx_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
)
}
}
fn vec_znx_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_negate(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_negate_inplace<A>(&self, a: &mut A, a_col: usize)
where
A: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
}
unsafe {
vec_znx::vec_znx_negate(
self.ptr,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_rotate(
self.ptr,
k,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
}
unsafe {
vec_znx::vec_znx_rotate(
self.ptr,
k,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_automorphism(
self.ptr,
k,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert!(
k & 1 != 0,
"invalid galois element: must be odd but is {}",
k
);
}
unsafe {
vec_znx::vec_znx_automorphism(
self.ptr,
k,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_split<R, A>(&self, res: &mut Vec<R>, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let (n_in, n_out) = (a.n(), res[0].to_mut().n());
let (mut buf, _) = scratch.tmp_vec_znx(self, 1, a.size());
        debug_assert!(
            n_out < n_in,
            "invalid res: the output ring degree must be smaller than the input ring degree"
        );
res[1..].iter_mut().for_each(|bi| {
            debug_assert_eq!(
                bi.to_mut().n(),
                n_out,
                "invalid res: all output VecZnx must have the same ring degree"
            )
});
res.iter_mut().enumerate().for_each(|(i, bi)| {
if i == 0 {
self.switch_degree(bi, res_col, &a, a_col);
self.vec_znx_rotate(-1, &mut buf, 0, &a, a_col);
} else {
self.switch_degree(bi, res_col, &mut buf, a_col);
self.vec_znx_rotate_inplace(-1, &mut buf, a_col);
}
})
}
fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: Vec<A>, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let (n_in, n_out) = (res.n(), a[0].to_ref().n());
        debug_assert!(
            n_out < n_in,
            "invalid a: the input ring degree must be smaller than the output ring degree"
        );
a[1..].iter().for_each(|ai| {
debug_assert_eq!(
ai.to_ref().n(),
n_out,
"invalid input a: all VecZnx must have the same degree"
)
});
        a.iter().for_each(|ai| {
self.switch_degree(&mut res, res_col, ai, a_col);
self.vec_znx_rotate_inplace(-1, &mut res, res_col);
});
self.vec_znx_rotate_inplace(a.len() as i64, &mut res, res_col);
}
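    // Worked example for `switch_degree` below (illustrative): with a.n() = 8
    // and res.n() = 4, gap_in = 2 and every second coefficient is kept, i.e.
    // res = (a0, a2, a4, a6) per small polynomial; with a.n() = 4 and
    // res.n() = 8, res is zeroed first and gap_out = 2 spreads the
    // coefficients, i.e. res = (a0, 0, a1, 0, a2, 0, a3, 0).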
fn switch_degree<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let (n_in, n_out) = (a.n(), res.n());
let (gap_in, gap_out): (usize, usize);
if n_in > n_out {
(gap_in, gap_out) = (n_in / n_out, 1)
} else {
(gap_in, gap_out) = (1, n_out / n_in);
res.zero();
}
let size: usize = min(a.size(), res.size());
(0..size).for_each(|i| {
izip!(
a.at(a_col, i).iter().step_by(gap_in),
res.at_mut(res_col, i).iter_mut().step_by(gap_out)
)
.for_each(|(x_in, x_out)| *x_out = *x_in);
});
}
}
impl<B: Backend> VecZnxScratch for Module<B> {
fn vec_znx_normalize_tmp_bytes(&self) -> usize {
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize }
}
}

199
backend/src/znx_base.rs Normal file
View File

@@ -0,0 +1,199 @@
use itertools::izip;
use rand_distr::num_traits::Zero;
pub trait ZnxInfos {
/// Returns the ring degree of the polynomials.
fn n(&self) -> usize;
/// Returns the base two logarithm of the ring dimension of the polynomials.
fn log_n(&self) -> usize {
(usize::BITS - (self.n() - 1).leading_zeros()) as _
}
/// Returns the number of rows.
fn rows(&self) -> usize;
/// Returns the number of polynomials in each row.
fn cols(&self) -> usize;
    /// Returns the number of small polynomials (the size) per column.
fn size(&self) -> usize;
/// Returns the total number of small polynomials.
fn poly_count(&self) -> usize {
self.rows() * self.cols() * self.size()
}
}
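// Illustrative helper (not part of the API): the total number of scalar
// coefficients backing a Znx value follows directly from [ZnxInfos].
#[allow(dead_code)]
fn total_coeffs<T: ZnxInfos>(a: &T) -> usize {
    // One length-n slot per small polynomial, e.g. n = 1024, rows = 1,
    // cols = 2, size = 3 gives 1024 * 6 coefficients.
    a.n() * a.poly_count()
}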
pub trait ZnxSliceSize {
    /// Returns the slice size, i.e. the offset between two consecutive
    /// small polynomials (sizes) of the same column.
fn sl(&self) -> usize;
}
pub trait DataView {
type D;
fn data(&self) -> &Self::D;
}
pub trait DataViewMut: DataView {
fn data_mut(&mut self) -> &mut Self::D;
}
pub trait ZnxView: ZnxInfos + DataView<D: AsRef<[u8]>> {
type Scalar: Copy;
/// Returns a non-mutable pointer to the underlying coefficients array.
fn as_ptr(&self) -> *const Self::Scalar {
self.data().as_ref().as_ptr() as *const Self::Scalar
}
/// Returns a non-mutable reference to the entire underlying coefficient array.
fn raw(&self) -> &[Self::Scalar] {
unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) }
}
/// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column.
fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
#[cfg(debug_assertions)]
{
assert!(i < self.cols());
assert!(j < self.size());
}
let offset: usize = self.n() * (j * self.cols() + i);
unsafe { self.as_ptr().add(offset) }
}
/// Returns non-mutable reference to the (i, j)-th small polynomial.
fn at(&self, i: usize, j: usize) -> &[Self::Scalar] {
unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) }
}
}
pub trait ZnxViewMut: ZnxView + DataViewMut<D: AsMut<[u8]>> {
/// Returns a mutable pointer to the underlying coefficients array.
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
self.data_mut().as_mut().as_mut_ptr() as *mut Self::Scalar
}
/// Returns a mutable reference to the entire underlying coefficient array.
fn raw_mut(&mut self) -> &mut [Self::Scalar] {
unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) }
}
/// Returns a mutable pointer starting at the j-th small polynomial of the i-th column.
fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
#[cfg(debug_assertions)]
{
assert!(i < self.cols());
assert!(j < self.size());
}
let offset: usize = self.n() * (j * self.cols() + i);
unsafe { self.as_mut_ptr().add(offset) }
}
/// Returns mutable reference to the (i, j)-th small polynomial.
fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] {
unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) }
}
}
//(Jay)Note: Can't provide blanket impl. of ZnxView because Scalar is not known
impl<T> ZnxViewMut for T where T: ZnxView + DataViewMut<D: AsMut<[u8]>> {}
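// Illustrative helper (not part of the API): the scalar offset applied by
// `at_ptr`/`at_mut_ptr` above. The limb index is the outer dimension, so
// limb `j` of column `i` starts `n * (j * cols + i)` scalars into the backing
// array, and two consecutive limbs of the same column are `n * cols` scalars apart.
#[allow(dead_code)]
fn limb_offset<T: ZnxView>(a: &T, col: usize, limb: usize) -> usize {
    a.n() * (limb * a.cols() + col)
}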
pub trait ZnxZero: ZnxViewMut + ZnxSliceSize
where
Self: Sized,
{
fn zero(&mut self) {
unsafe {
std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * self.poly_count());
}
}
fn zero_at(&mut self, i: usize, j: usize) {
unsafe {
std::ptr::write_bytes(self.at_mut_ptr(i, j), 0, self.n());
}
}
}
// Blanket implementations
impl<T> ZnxZero for T where T: ZnxViewMut + ZnxSliceSize {} // WARNING should not work for mat_znx_dft but it does
use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub};
use crate::Scratch;
pub trait Integer:
Copy
+ Default
+ PartialEq
+ PartialOrd
+ Add<Output = Self>
+ Sub<Output = Self>
+ Mul<Output = Self>
+ Div<Output = Self>
+ Neg<Output = Self>
+ Shl<Output = Self>
+ Shr<Output = Self>
+ AddAssign
{
const BITS: u32;
}
impl Integer for i64 {
const BITS: u32 = 64;
}
impl Integer for i128 {
const BITS: u32 = 128;
}
//(Jay)Note: `rsh` impl. ignores the column
pub fn rsh<V: ZnxZero>(k: usize, basek: usize, a: &mut V, _a_col: usize, scratch: &mut Scratch)
where
V::Scalar: From<usize> + Integer + Zero,
{
    let n: usize = a.n();
    let cols: usize = a.cols();
    let size: usize = a.size();
let steps: usize = k / basek;
a.raw_mut().rotate_right(n * steps * cols);
(0..cols).for_each(|i| {
(0..steps).for_each(|j| {
a.zero_at(i, j);
})
});
let k_rem: usize = k % basek;
if k_rem != 0 {
        let (carry, _) = scratch.tmp_slice::<V::Scalar>(rsh_tmp_bytes::<V::Scalar>(n));
        // Zero the carry buffer before use (exactly `carry.len()` scalars).
        carry.fill(V::Scalar::zero());
let basek_t = V::Scalar::from(basek);
let shift = V::Scalar::from(V::Scalar::BITS as usize - k_rem);
let k_rem_t = V::Scalar::from(k_rem);
(0..cols).for_each(|i| {
(steps..size).for_each(|j| {
izip!(carry.iter_mut(), a.at_mut(i, j).iter_mut()).for_each(|(ci, xi)| {
*xi += *ci << basek_t;
*ci = (*xi << shift) >> shift;
*xi = (*xi - *ci) >> k_rem_t;
});
});
carry.iter_mut().for_each(|r| *r = V::Scalar::zero());
})
}
}
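// Worked example for `rsh` (illustrative): with basek = 4 and k = 6, the whole
// representation is first shifted by k / basek = 1 limb towards the less
// significant limbs (the freed most-significant limbs are zeroed), then the
// remaining k % basek = 2 bits are handled limb by limb: each limb absorbs the
// previous carry scaled by 2^basek, keeps its own low 2 bits as the next
// carry, and is itself shifted down by 2 bits.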
pub fn rsh_tmp_bytes<T>(n: usize) -> usize {
n * std::mem::size_of::<T>()
}