refactoring of vec_znx

This commit is contained in:
Jean-Philippe Bossuat
2025-04-28 10:33:15 +02:00
parent 39bbe5b917
commit 2f9a1cf6d9
13 changed files with 1218 additions and 738 deletions

View File

@@ -35,7 +35,7 @@ fn main() {
module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source); module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source);
// Scratch space for DFT values // Scratch space for DFT values
let mut buf_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, a.limbs()); let mut buf_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, a.size());
// Applies buf_dft <- s * a // Applies buf_dft <- s * a
module.svp_apply_dft(&mut buf_dft, &s_ppol, &a); module.svp_apply_dft(&mut buf_dft, &s_ppol, &a);
@@ -93,9 +93,9 @@ fn main() {
// have = m * 2^{log_scale} + e // have = m * 2^{log_scale} + e
let mut have: Vec<i64> = vec![i64::default(); n]; let mut have: Vec<i64> = vec![i64::default(); n];
res.decode_vec_i64(0, log_base2k, res.limbs() * log_base2k, &mut have); res.decode_vec_i64(0, log_base2k, res.size() * log_base2k, &mut have);
let scale: f64 = (1 << (res.limbs() * log_base2k - log_scale)) as f64; let scale: f64 = (1 << (res.size() * log_base2k - log_scale)) as f64;
izip!(want.iter(), have.iter()) izip!(want.iter(), have.iter())
.enumerate() .enumerate()
.for_each(|(i, (a, b))| { .for_each(|(i, (a, b))| {

View File

@@ -33,7 +33,7 @@ fn main() {
let mut mat_znx_dft: MatZnxDft<FFT64> = module.new_mat_znx_dft(rows_mat, 1, limbs_mat); let mut mat_znx_dft: MatZnxDft<FFT64> = module.new_mat_znx_dft(rows_mat, 1, limbs_mat);
(0..a.limbs()).for_each(|row_i| { (0..a.size()).for_each(|row_i| {
let mut tmp: VecZnx = module.new_vec_znx(1, limbs_mat); let mut tmp: VecZnx = module.new_vec_znx(1, limbs_mat);
tmp.at_limb_mut(row_i)[1] = 1 as i64; tmp.at_limb_mut(row_i)[1] = 1 as i64;
module.vmp_prepare_row(&mut mat_znx_dft, tmp.raw(), row_i, &mut buf); module.vmp_prepare_row(&mut mat_znx_dft, tmp.raw(), row_i, &mut buf);

View File

@@ -1,11 +1,15 @@
use crate::{Backend, Module}; use crate::{Backend, Module, assert_alignement, cast_mut};
use itertools::izip;
use std::cmp::{max, min};
pub trait ZnxInfos { pub trait ZnxInfos {
/// Returns the ring degree of the polynomials. /// Returns the ring degree of the polynomials.
fn n(&self) -> usize; fn n(&self) -> usize;
/// Returns the base two logarithm of the ring dimension of the polynomials. /// Returns the base two logarithm of the ring dimension of the polynomials.
fn log_n(&self) -> usize; fn log_n(&self) -> usize {
(usize::BITS - (self.n() - 1).leading_zeros()) as _
}
/// Returns the number of rows. /// Returns the number of rows.
fn rows(&self) -> usize; fn rows(&self) -> usize;
@@ -13,21 +17,28 @@ pub trait ZnxInfos {
/// Returns the number of polynomials in each row. /// Returns the number of polynomials in each row.
fn cols(&self) -> usize; fn cols(&self) -> usize;
/// Returns the number of limbs per polynomial. /// Returns the number of size per polynomial.
fn limbs(&self) -> usize; fn size(&self) -> usize;
/// Returns the total number of small polynomials. /// Returns the total number of small polynomials.
fn poly_count(&self) -> usize; fn poly_count(&self) -> usize {
self.rows() * self.cols() * self.size()
}
/// Returns the slice size, which is the offset between
/// two size of the same column.
fn sl(&self) -> usize {
self.n() * self.cols()
}
} }
pub trait ZnxBase<B: Backend> { pub trait ZnxBase<B: Backend> {
type Scalar; type Scalar;
fn new(module: &Module<B>, cols: usize, limbs: usize) -> Self; fn new(module: &Module<B>, cols: usize, size: usize) -> Self;
fn from_bytes(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self; fn from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self;
fn from_bytes_borrow(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self; fn from_bytes_borrow(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self;
fn bytes_of(module: &Module<B>, cols: usize, limbs: usize) -> usize; fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize;
} }
pub trait ZnxLayout: ZnxInfos { pub trait ZnxLayout: ZnxInfos {
type Scalar; type Scalar;
@@ -52,7 +63,7 @@ pub trait ZnxLayout: ZnxInfos {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!(i < self.cols()); assert!(i < self.cols());
assert!(j < self.limbs()); assert!(j < self.size());
} }
let offset = self.n() * (j * self.cols() + i); let offset = self.n() * (j * self.cols() + i);
unsafe { self.as_ptr().add(offset) } unsafe { self.as_ptr().add(offset) }
@@ -63,7 +74,7 @@ pub trait ZnxLayout: ZnxInfos {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!(i < self.cols()); assert!(i < self.cols());
assert!(j < self.limbs()); assert!(j < self.size());
} }
let offset = self.n() * (j * self.cols() + i); let offset = self.n() * (j * self.cols() + i);
unsafe { self.as_mut_ptr().add(offset) } unsafe { self.as_mut_ptr().add(offset) }
@@ -89,3 +100,195 @@ pub trait ZnxLayout: ZnxInfos {
unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, j), self.n() * self.cols()) } unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, j), self.n() * self.cols()) }
} }
} }
use std::convert::TryFrom;
use std::num::TryFromIntError;
use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub};
pub trait IntegerType:
Copy
+ std::fmt::Debug
+ Default
+ PartialEq
+ PartialOrd
+ Add<Output = Self>
+ Sub<Output = Self>
+ Mul<Output = Self>
+ Div<Output = Self>
+ Neg<Output = Self>
+ Shr<Output = Self>
+ Shl<Output = Self>
+ AddAssign
+ TryFrom<usize, Error = TryFromIntError>
{
const BITS: u32;
}
impl IntegerType for i64 {
const BITS: u32 = 64;
}
impl IntegerType for i128 {
const BITS: u32 = 128;
}
pub trait ZnxBasics: ZnxLayout
where
Self: Sized,
Self::Scalar: IntegerType,
{
fn zero(&mut self) {
unsafe {
std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * size_of::<Self::Scalar>());
}
}
fn zero_at(&mut self, i: usize, j: usize) {
unsafe {
std::ptr::write_bytes(
self.at_mut_ptr(i, j),
0,
self.n() * size_of::<Self::Scalar>(),
);
}
}
fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) {
rsh(log_base2k, self, k, carry)
}
}
pub fn rsh<V: ZnxBasics>(log_base2k: usize, a: &mut V, k: usize, tmp_bytes: &mut [u8])
where
V::Scalar: IntegerType,
{
let n: usize = a.n();
let size: usize = a.size();
let cols: usize = a.cols();
#[cfg(debug_assertions)]
{
assert!(
tmp_bytes.len() >= rsh_tmp_bytes::<V::Scalar>(n, cols),
"invalid carry: carry.len()/size_ofSelf::Scalar={} < rsh_tmp_bytes({}, {})",
tmp_bytes.len() / size_of::<V::Scalar>(),
n,
size,
);
assert_alignement(tmp_bytes.as_ptr());
}
let size: usize = a.size();
let steps: usize = k / log_base2k;
a.raw_mut().rotate_right(n * steps * cols);
(0..cols).for_each(|i| {
(0..steps).for_each(|j| {
a.zero_at(i, j);
})
});
let k_rem: usize = k % log_base2k;
if k_rem != 0 {
let carry: &mut [V::Scalar] = cast_mut(tmp_bytes);
unsafe {
std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::<V::Scalar>());
}
let log_base2k_t: V::Scalar = V::Scalar::try_from(log_base2k).unwrap();
let shift: V::Scalar = V::Scalar::try_from(V::Scalar::BITS as usize - k_rem).unwrap();
let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap();
(steps..size).for_each(|i| {
izip!(carry.iter_mut(), a.at_limb_mut(i).iter_mut()).for_each(|(ci, xi)| {
*xi += *ci << log_base2k_t;
*ci = get_base_k_carry(*xi, shift);
*xi = (*xi - *ci) >> k_rem_t;
});
})
}
}
#[inline(always)]
fn get_base_k_carry<T: IntegerType>(x: T, shift: T) -> T {
(x << shift) >> shift
}
pub fn rsh_tmp_bytes<T: IntegerType>(n: usize, cols: usize) -> usize {
n * cols * std::mem::size_of::<T>()
}
pub fn switch_degree<T: ZnxLayout + ZnxBasics>(b: &mut T, a: &T)
where
<T as ZnxLayout>::Scalar: IntegerType,
{
let (n_in, n_out) = (a.n(), b.n());
let (gap_in, gap_out): (usize, usize);
if n_in > n_out {
(gap_in, gap_out) = (n_in / n_out, 1)
} else {
(gap_in, gap_out) = (1, n_out / n_in);
b.zero();
}
let size: usize = min(a.size(), b.size());
(0..size).for_each(|i| {
izip!(
a.at_limb(i).iter().step_by(gap_in),
b.at_limb_mut(i).iter_mut().step_by(gap_out)
)
.for_each(|(x_in, x_out)| *x_out = *x_in);
});
}
pub fn znx_post_process_ternary_op<T: ZnxInfos + ZnxLayout + ZnxBasics, const NEGATE: bool>(c: &mut T, a: &T, b: &T)
where
<T as ZnxLayout>::Scalar: IntegerType,
{
#[cfg(debug_assertions)]
{
assert_ne!(a.as_ptr(), b.as_ptr());
assert_ne!(b.as_ptr(), c.as_ptr());
assert_ne!(a.as_ptr(), c.as_ptr());
}
let a_cols: usize = a.cols();
let b_cols: usize = b.cols();
let c_cols: usize = c.cols();
let min_ab_cols: usize = min(a_cols, b_cols);
let max_ab_cols: usize = max(a_cols, b_cols);
// Copies shared shared cols between (c, max(a, b))
if a_cols != b_cols {
let mut x: &T = a;
if a_cols < b_cols {
x = b;
}
let min_size = min(c.size(), x.size());
(min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| {
(0..min_size).for_each(|j| {
c.at_poly_mut(i, j).copy_from_slice(x.at_poly(i, j));
if NEGATE {
c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
}
});
(min_size..c.size()).for_each(|j| {
c.zero_at(i, j);
});
});
}
// Zeroes the cols of c > max(a, b).
if c_cols > max_ab_cols {
(max_ab_cols..c_cols).for_each(|i| {
(0..c.size()).for_each(|j| {
c.zero_at(i, j);
})
});
}
}

View File

@@ -81,15 +81,15 @@ impl Encoding for VecZnx {
} }
fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) {
let limbs: usize = (log_k + log_base2k - 1) / log_base2k; let size: usize = (log_k + log_base2k - 1) / log_base2k;
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!( assert!(
limbs <= a.limbs(), size <= a.size(),
"invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.limbs()={}", "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.size()={}",
limbs, size,
a.limbs() a.size()
); );
assert!(col_i < a.cols()); assert!(col_i < a.cols());
assert!(data.len() <= a.n()) assert!(data.len() <= a.n())
@@ -99,7 +99,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize,
let log_k_rem: usize = log_base2k - (log_k % log_base2k); let log_k_rem: usize = log_base2k - (log_k % log_base2k);
// Zeroes coefficients of the i-th column // Zeroes coefficients of the i-th column
(0..a.limbs()).for_each(|i| unsafe { (0..a.size()).for_each(|i| unsafe {
znx_zero_i64_ref(a.n() as u64, a.at_mut_ptr(col_i, i)); znx_zero_i64_ref(a.n() as u64, a.at_mut_ptr(col_i, i));
}); });
@@ -107,11 +107,11 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize,
// values on the last limb. // values on the last limb.
// Else we decompose values base2k. // Else we decompose values base2k.
if log_max + log_k_rem < 63 || log_k_rem == log_base2k { if log_max + log_k_rem < 63 || log_k_rem == log_base2k {
a.at_poly_mut(col_i, limbs - 1)[..data_len].copy_from_slice(&data[..data_len]); a.at_poly_mut(col_i, size - 1)[..data_len].copy_from_slice(&data[..data_len]);
} else { } else {
let mask: i64 = (1 << log_base2k) - 1; let mask: i64 = (1 << log_base2k) - 1;
let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
(limbs - steps..limbs) (size - steps..size)
.rev() .rev()
.enumerate() .enumerate()
.for_each(|(i, i_rev)| { .for_each(|(i, i_rev)| {
@@ -122,8 +122,8 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize,
// Case where self.prec % self.k != 0. // Case where self.prec % self.k != 0.
if log_k_rem != log_base2k { if log_k_rem != log_base2k {
let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
(limbs - steps..limbs).rev().for_each(|i| { (size - steps..size).rev().for_each(|i| {
a.at_poly_mut(col_i, i)[..data_len] a.at_poly_mut(col_i, i)[..data_len]
.iter_mut() .iter_mut()
.for_each(|x| *x <<= log_k_rem); .for_each(|x| *x <<= log_k_rem);
@@ -132,7 +132,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize,
} }
fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) {
let limbs: usize = (log_k + log_base2k - 1) / log_base2k; let size: usize = (log_k + log_base2k - 1) / log_base2k;
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!( assert!(
@@ -145,8 +145,8 @@ fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, dat
} }
data.copy_from_slice(a.at_poly(col_i, 0)); data.copy_from_slice(a.at_poly(col_i, 0));
let rem: usize = log_base2k - (log_k % log_base2k); let rem: usize = log_base2k - (log_k % log_base2k);
(1..limbs).for_each(|i| { (1..size).for_each(|i| {
if i == limbs - 1 && rem != log_base2k { if i == size - 1 && rem != log_base2k {
let k_rem: usize = log_base2k - rem; let k_rem: usize = log_base2k - rem;
izip!(a.at_poly(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { izip!(a.at_poly(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
*y = (*y << k_rem) + (x >> rem); *y = (*y << k_rem) + (x >> rem);
@@ -160,7 +160,7 @@ fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, dat
} }
fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Float]) { fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Float]) {
let limbs: usize = a.limbs(); let size: usize = a.size();
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!( assert!(
@@ -172,20 +172,20 @@ fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Flo
assert!(col_i < a.cols()); assert!(col_i < a.cols());
} }
let prec: u32 = (log_base2k * limbs) as u32; let prec: u32 = (log_base2k * size) as u32;
// 2^{log_base2k} // 2^{log_base2k}
let base = Float::with_val(prec, (1 << log_base2k) as f64); let base = Float::with_val(prec, (1 << log_base2k) as f64);
// y[i] = sum x[j][i] * 2^{-log_base2k*j} // y[i] = sum x[j][i] * 2^{-log_base2k*j}
(0..limbs).for_each(|i| { (0..size).for_each(|i| {
if i == 0 { if i == 0 {
izip!(a.at_poly(col_i, limbs - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { izip!(a.at_poly(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
y.assign(*x); y.assign(*x);
*y /= &base; *y /= &base;
}); });
} else { } else {
izip!(a.at_poly(col_i, limbs - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { izip!(a.at_poly(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
*y += Float::with_val(prec, *x); *y += Float::with_val(prec, *x);
*y /= &base; *y /= &base;
}); });
@@ -194,32 +194,32 @@ fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Flo
} }
fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) { fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) {
let limbs: usize = (log_k + log_base2k - 1) / log_base2k; let size: usize = (log_k + log_base2k - 1) / log_base2k;
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!(i < a.n()); assert!(i < a.n());
assert!( assert!(
limbs <= a.limbs(), size <= a.size(),
"invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.limbs()={}", "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.size()={}",
limbs, size,
a.limbs() a.size()
); );
assert!(col_i < a.cols()); assert!(col_i < a.cols());
} }
let log_k_rem: usize = log_base2k - (log_k % log_base2k); let log_k_rem: usize = log_base2k - (log_k % log_base2k);
(0..a.limbs()).for_each(|j| a.at_poly_mut(col_i, j)[i] = 0); (0..a.size()).for_each(|j| a.at_poly_mut(col_i, j)[i] = 0);
// If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy // If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy
// values on the last limb. // values on the last limb.
// Else we decompose values base2k. // Else we decompose values base2k.
if log_max + log_k_rem < 63 || log_k_rem == log_base2k { if log_max + log_k_rem < 63 || log_k_rem == log_base2k {
a.at_poly_mut(col_i, limbs - 1)[i] = value; a.at_poly_mut(col_i, size - 1)[i] = value;
} else { } else {
let mask: i64 = (1 << log_base2k) - 1; let mask: i64 = (1 << log_base2k) - 1;
let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
(limbs - steps..limbs) (size - steps..size)
.rev() .rev()
.enumerate() .enumerate()
.for_each(|(j, j_rev)| { .for_each(|(j, j_rev)| {
@@ -229,8 +229,8 @@ fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usiz
// Case where prec % k != 0. // Case where prec % k != 0.
if log_k_rem != log_base2k { if log_k_rem != log_base2k {
let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
(limbs - steps..limbs).rev().for_each(|j| { (size - steps..size).rev().for_each(|j| {
a.at_poly_mut(col_i, j)[i] <<= log_k_rem; a.at_poly_mut(col_i, j)[i] <<= log_k_rem;
}) })
} }
@@ -247,7 +247,7 @@ fn decode_coeff_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i
let data: &[i64] = a.raw(); let data: &[i64] = a.raw();
let mut res: i64 = data[i]; let mut res: i64 = data[i];
let rem: usize = log_base2k - (log_k % log_base2k); let rem: usize = log_base2k - (log_k % log_base2k);
let slice_size: usize = a.n() * a.limbs(); let slice_size: usize = a.n() * a.size();
(1..cols).for_each(|i| { (1..cols).for_each(|i| {
let x = data[i * slice_size]; let x = data[i * slice_size];
if i == cols - 1 && rem != log_base2k { if i == cols - 1 && rem != log_base2k {
@@ -271,9 +271,9 @@ mod tests {
let n: usize = 8; let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n); let module: Module<FFT64> = Module::<FFT64>::new(n);
let log_base2k: usize = 17; let log_base2k: usize = 17;
let limbs: usize = 5; let size: usize = 5;
let log_k: usize = limbs * log_base2k - 5; let log_k: usize = size * log_base2k - 5;
let mut a: VecZnx = VecZnx::new(&module, 2, limbs); let mut a: VecZnx = VecZnx::new(&module, 2, size);
let mut source: Source = Source::new([0u8; 32]); let mut source: Source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut(); let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
@@ -293,9 +293,9 @@ mod tests {
let n: usize = 8; let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n); let module: Module<FFT64> = Module::<FFT64>::new(n);
let log_base2k: usize = 17; let log_base2k: usize = 17;
let limbs: usize = 5; let size: usize = 5;
let log_k: usize = limbs * log_base2k - 5; let log_k: usize = size * log_base2k - 5;
let mut a: VecZnx = VecZnx::new(&module, 2, limbs); let mut a: VecZnx = VecZnx::new(&module, 2, size);
let mut source = Source::new([0u8; 32]); let mut source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut(); let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);

View File

@@ -11,6 +11,7 @@ pub mod stats;
pub mod vec_znx; pub mod vec_znx;
pub mod vec_znx_big; pub mod vec_znx_big;
pub mod vec_znx_dft; pub mod vec_znx_dft;
pub mod vec_znx_ops;
pub use commons::*; pub use commons::*;
pub use encoding::*; pub use encoding::*;
@@ -23,6 +24,7 @@ pub use stats::*;
pub use vec_znx::*; pub use vec_znx::*;
pub use vec_znx_big::*; pub use vec_znx_big::*;
pub use vec_znx_dft::*; pub use vec_znx_dft::*;
pub use vec_znx_ops::*;
pub const GALOISGENERATOR: u64 = 5; pub const GALOISGENERATOR: u64 = 5;
pub const DEFAULTALIGN: usize = 64; pub const DEFAULTALIGN: usize = 64;

View File

@@ -22,7 +22,7 @@ pub struct MatZnxDft<B: Backend> {
/// Number of cols /// Number of cols
cols: usize, cols: usize,
/// The number of small polynomials /// The number of small polynomials
limbs: usize, size: usize,
_marker: PhantomData<B>, _marker: PhantomData<B>,
} }
@@ -31,10 +31,6 @@ impl<B: Backend> ZnxInfos for MatZnxDft<B> {
self.n self.n
} }
fn log_n(&self) -> usize {
(usize::BITS - (self.n() - 1).leading_zeros()) as _
}
fn rows(&self) -> usize { fn rows(&self) -> usize {
self.rows self.rows
} }
@@ -43,18 +39,14 @@ impl<B: Backend> ZnxInfos for MatZnxDft<B> {
self.cols self.cols
} }
fn limbs(&self) -> usize { fn size(&self) -> usize {
self.limbs self.size
}
fn poly_count(&self) -> usize {
self.rows * self.cols * self.limbs
} }
} }
impl MatZnxDft<FFT64> { impl MatZnxDft<FFT64> {
fn new(module: &Module<FFT64>, rows: usize, cols: usize, limbs: usize) -> MatZnxDft<FFT64> { fn new(module: &Module<FFT64>, rows: usize, cols: usize, size: usize) -> MatZnxDft<FFT64> {
let mut data: Vec<u8> = alloc_aligned::<u8>(module.bytes_of_mat_znx_dft(rows, cols, limbs)); let mut data: Vec<u8> = alloc_aligned::<u8>(module.bytes_of_mat_znx_dft(rows, cols, size));
let ptr: *mut u8 = data.as_mut_ptr(); let ptr: *mut u8 = data.as_mut_ptr();
MatZnxDft::<FFT64> { MatZnxDft::<FFT64> {
data: data, data: data,
@@ -62,7 +54,7 @@ impl MatZnxDft<FFT64> {
n: module.n(), n: module.n(),
rows: rows, rows: rows,
cols: cols, cols: cols,
limbs: limbs, size: size,
_marker: PhantomData, _marker: PhantomData,
} }
} }
@@ -115,7 +107,7 @@ impl MatZnxDft<FFT64> {
fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] { fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] {
let nrows: usize = self.rows(); let nrows: usize = self.rows();
let nsize: usize = self.limbs(); let nsize: usize = self.size();
if col == (nsize - 1) && (nsize & 1 == 1) { if col == (nsize - 1) && (nsize & 1 == 1) {
&self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..] &self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..]
} else { } else {
@@ -127,7 +119,7 @@ impl MatZnxDft<FFT64> {
/// This trait implements methods for vector matrix product, /// This trait implements methods for vector matrix product,
/// that is, multiplying a [VecZnx] with a [VmpPMat]. /// that is, multiplying a [VecZnx] with a [VmpPMat].
pub trait MatZnxDftOps<B: Backend> { pub trait MatZnxDftOps<B: Backend> {
fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> usize; fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize;
/// Allocates a new [VmpPMat] with the given number of rows and columns. /// Allocates a new [VmpPMat] with the given number of rows and columns.
/// ///
@@ -135,7 +127,7 @@ pub trait MatZnxDftOps<B: Backend> {
/// ///
/// * `rows`: number of rows (number of [VecZnxDft]). /// * `rows`: number of rows (number of [VecZnxDft]).
/// * `size`: number of size (number of size of each [VecZnxDft]). /// * `size`: number of size (number of size of each [VecZnxDft]).
fn new_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> MatZnxDft<B>; fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft<B>;
/// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous]. /// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous].
/// ///
@@ -351,12 +343,12 @@ pub trait MatZnxDftOps<B: Backend> {
} }
impl MatZnxDftOps<FFT64> for Module<FFT64> { impl MatZnxDftOps<FFT64> for Module<FFT64> {
fn new_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> MatZnxDft<FFT64> { fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft<FFT64> {
MatZnxDft::<FFT64>::new(self, rows, cols, limbs) MatZnxDft::<FFT64>::new(self, rows, cols, size)
} }
fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> usize { fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize {
unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (limbs * cols) as u64) as usize } unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (size * cols) as u64) as usize }
} }
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize { fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize {
@@ -367,7 +359,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(a.len(), b.n() * b.poly_count()); assert_eq!(a.len(), b.n() * b.poly_count());
assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.limbs())); assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size()));
assert_alignement(tmp_bytes.as_ptr()); assert_alignement(tmp_bytes.as_ptr());
} }
unsafe { unsafe {
@@ -376,7 +368,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
b.as_mut_ptr() as *mut vmp_pmat_t, b.as_mut_ptr() as *mut vmp_pmat_t,
a.as_ptr(), a.as_ptr(),
b.rows() as u64, b.rows() as u64,
(b.limbs() * b.cols()) as u64, (b.size() * b.cols()) as u64,
tmp_bytes.as_mut_ptr(), tmp_bytes.as_mut_ptr(),
); );
} }
@@ -385,8 +377,8 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
fn vmp_prepare_row(&self, b: &mut MatZnxDft<FFT64>, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) { fn vmp_prepare_row(&self, b: &mut MatZnxDft<FFT64>, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(a.len(), b.limbs() * self.n() * b.cols()); assert_eq!(a.len(), b.size() * self.n() * b.cols());
assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.limbs())); assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size()));
assert_alignement(tmp_bytes.as_ptr()); assert_alignement(tmp_bytes.as_ptr());
} }
unsafe { unsafe {
@@ -396,7 +388,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
a.as_ptr(), a.as_ptr(),
row_i as u64, row_i as u64,
b.rows() as u64, b.rows() as u64,
(b.limbs() * b.cols()) as u64, (b.size() * b.cols()) as u64,
tmp_bytes.as_mut_ptr(), tmp_bytes.as_mut_ptr(),
); );
} }
@@ -406,7 +398,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(a.n(), b.n()); assert_eq!(a.n(), b.n());
assert_eq!(a.limbs(), b.limbs()); assert_eq!(a.size(), b.size());
assert_eq!(a.cols(), b.cols()); assert_eq!(a.cols(), b.cols());
} }
unsafe { unsafe {
@@ -416,7 +408,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
a.as_ptr() as *const vmp_pmat_t, a.as_ptr() as *const vmp_pmat_t,
row_i as u64, row_i as u64,
a.rows() as u64, a.rows() as u64,
(a.limbs() * a.cols()) as u64, (a.size() * a.cols()) as u64,
); );
} }
} }
@@ -425,7 +417,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(a.n(), b.n()); assert_eq!(a.n(), b.n());
assert_eq!(a.limbs(), b.limbs()); assert_eq!(a.size(), b.size());
} }
unsafe { unsafe {
vmp::vmp_prepare_row_dft( vmp::vmp_prepare_row_dft(
@@ -434,7 +426,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
a.ptr as *const vec_znx_dft_t, a.ptr as *const vec_znx_dft_t,
row_i as u64, row_i as u64,
b.rows() as u64, b.rows() as u64,
b.limbs() as u64, b.size() as u64,
); );
} }
} }
@@ -443,7 +435,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(a.n(), b.n()); assert_eq!(a.n(), b.n());
assert_eq!(a.limbs(), b.limbs()); assert_eq!(a.size(), b.size());
} }
unsafe { unsafe {
vmp::vmp_extract_row_dft( vmp::vmp_extract_row_dft(
@@ -452,7 +444,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
a.as_ptr() as *const vmp_pmat_t, a.as_ptr() as *const vmp_pmat_t,
row_i as u64, row_i as u64,
a.rows() as u64, a.rows() as u64,
a.limbs() as u64, a.size() as u64,
); );
} }
} }
@@ -470,7 +462,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
} }
fn vmp_apply_dft(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnx, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) { fn vmp_apply_dft(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnx, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_alignement(tmp_bytes.as_ptr()); assert_alignement(tmp_bytes.as_ptr());
@@ -479,20 +471,20 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
vmp::vmp_apply_dft( vmp::vmp_apply_dft(
self.ptr, self.ptr,
c.ptr as *mut vec_znx_dft_t, c.ptr as *mut vec_znx_dft_t,
c.limbs() as u64, c.size() as u64,
a.as_ptr(), a.as_ptr(),
a.limbs() as u64, a.size() as u64,
(a.n() * a.cols()) as u64, (a.n() * a.cols()) as u64,
b.as_ptr() as *const vmp_pmat_t, b.as_ptr() as *const vmp_pmat_t,
b.rows() as u64, b.rows() as u64,
b.limbs() as u64, b.size() as u64,
tmp_bytes.as_mut_ptr(), tmp_bytes.as_mut_ptr(),
) )
} }
} }
fn vmp_apply_dft_add(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnx, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) { fn vmp_apply_dft_add(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnx, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_alignement(tmp_bytes.as_ptr()); assert_alignement(tmp_bytes.as_ptr());
@@ -501,13 +493,13 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
vmp::vmp_apply_dft_add( vmp::vmp_apply_dft_add(
self.ptr, self.ptr,
c.ptr as *mut vec_znx_dft_t, c.ptr as *mut vec_znx_dft_t,
c.limbs() as u64, c.size() as u64,
a.as_ptr(), a.as_ptr(),
a.limbs() as u64, a.size() as u64,
(a.n() * a.limbs()) as u64, (a.n() * a.size()) as u64,
b.as_ptr() as *const vmp_pmat_t, b.as_ptr() as *const vmp_pmat_t,
b.rows() as u64, b.rows() as u64,
b.limbs() as u64, b.size() as u64,
tmp_bytes.as_mut_ptr(), tmp_bytes.as_mut_ptr(),
) )
} }
@@ -526,7 +518,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
} }
fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnxDft<FFT64>, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) { fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnxDft<FFT64>, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_alignement(tmp_bytes.as_ptr()); assert_alignement(tmp_bytes.as_ptr());
@@ -535,12 +527,12 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
vmp::vmp_apply_dft_to_dft( vmp::vmp_apply_dft_to_dft(
self.ptr, self.ptr,
c.ptr as *mut vec_znx_dft_t, c.ptr as *mut vec_znx_dft_t,
c.limbs() as u64, c.size() as u64,
a.ptr as *const vec_znx_dft_t, a.ptr as *const vec_znx_dft_t,
a.limbs() as u64, a.size() as u64,
b.as_ptr() as *const vmp_pmat_t, b.as_ptr() as *const vmp_pmat_t,
b.rows() as u64, b.rows() as u64,
b.limbs() as u64, b.size() as u64,
tmp_bytes.as_mut_ptr(), tmp_bytes.as_mut_ptr(),
) )
} }
@@ -553,7 +545,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
b: &MatZnxDft<FFT64>, b: &MatZnxDft<FFT64>,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) { ) {
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_alignement(tmp_bytes.as_ptr()); assert_alignement(tmp_bytes.as_ptr());
@@ -562,19 +554,19 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
vmp::vmp_apply_dft_to_dft_add( vmp::vmp_apply_dft_to_dft_add(
self.ptr, self.ptr,
c.ptr as *mut vec_znx_dft_t, c.ptr as *mut vec_znx_dft_t,
c.limbs() as u64, c.size() as u64,
a.ptr as *const vec_znx_dft_t, a.ptr as *const vec_znx_dft_t,
a.limbs() as u64, a.size() as u64,
b.as_ptr() as *const vmp_pmat_t, b.as_ptr() as *const vmp_pmat_t,
b.rows() as u64, b.rows() as u64,
b.limbs() as u64, b.size() as u64,
tmp_bytes.as_mut_ptr(), tmp_bytes.as_mut_ptr(),
) )
} }
} }
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft<FFT64>, a: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) { fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft<FFT64>, a: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.limbs(), b.limbs(), a.rows(), a.limbs())); debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.size(), b.size(), a.rows(), a.size()));
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_alignement(tmp_bytes.as_ptr()); assert_alignement(tmp_bytes.as_ptr());
@@ -583,12 +575,12 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
vmp::vmp_apply_dft_to_dft( vmp::vmp_apply_dft_to_dft(
self.ptr, self.ptr,
b.ptr as *mut vec_znx_dft_t, b.ptr as *mut vec_znx_dft_t,
b.limbs() as u64, b.size() as u64,
b.ptr as *mut vec_znx_dft_t, b.ptr as *mut vec_znx_dft_t,
b.limbs() as u64, b.size() as u64,
a.as_ptr() as *const vmp_pmat_t, a.as_ptr() as *const vmp_pmat_t,
a.rows() as u64, a.rows() as u64,
a.limbs() as u64, a.size() as u64,
tmp_bytes.as_mut_ptr(), tmp_bytes.as_mut_ptr(),
) )
} }

View File

@@ -3,8 +3,8 @@ use rand_distr::{Distribution, Normal};
use sampling::source::Source; use sampling::source::Source;
pub trait Sampling { pub trait Sampling {
/// Fills the first `limbs` limbs with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\] /// Fills the first `size` size with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\]
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, limbs: usize, source: &mut Source); fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, size: usize, source: &mut Source);
/// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\]. /// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\].
fn add_dist_f64<D: Distribution<f64>>( fn add_dist_f64<D: Distribution<f64>>(
@@ -32,11 +32,11 @@ pub trait Sampling {
} }
impl<B: Backend> Sampling for Module<B> { impl<B: Backend> Sampling for Module<B> {
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, limbs: usize, source: &mut Source) { fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, size: usize, source: &mut Source) {
let base2k: u64 = 1 << log_base2k; let base2k: u64 = 1 << log_base2k;
let mask: u64 = base2k - 1; let mask: u64 = base2k - 1;
let base2k_half: i64 = (base2k >> 1) as i64; let base2k_half: i64 = (base2k >> 1) as i64;
(0..limbs).for_each(|j| { (0..size).for_each(|j| {
a.at_poly_mut(col_i, j) a.at_poly_mut(col_i, j)
.iter_mut() .iter_mut()
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
@@ -114,17 +114,17 @@ mod tests {
let n: usize = 4096; let n: usize = 4096;
let module: Module<FFT64> = Module::<FFT64>::new(n); let module: Module<FFT64> = Module::<FFT64>::new(n);
let log_base2k: usize = 17; let log_base2k: usize = 17;
let limbs: usize = 5; let size: usize = 5;
let mut source: Source = Source::new([0u8; 32]); let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2; let cols: usize = 2;
let zero: Vec<i64> = vec![0; n]; let zero: Vec<i64> = vec![0; n];
let one_12_sqrt: f64 = 0.28867513459481287; let one_12_sqrt: f64 = 0.28867513459481287;
(0..cols).for_each(|col_i| { (0..cols).for_each(|col_i| {
let mut a: VecZnx = VecZnx::new(&module, cols, limbs); let mut a: VecZnx = VecZnx::new(&module, cols, size);
module.fill_uniform(log_base2k, &mut a, col_i, limbs, &mut source); module.fill_uniform(log_base2k, &mut a, col_i, size, &mut source);
(0..cols).for_each(|col_j| { (0..cols).for_each(|col_j| {
if col_j != col_i { if col_j != col_i {
(0..limbs).for_each(|limb_i| { (0..size).for_each(|limb_i| {
assert_eq!(a.at_poly(col_j, limb_i), zero); assert_eq!(a.at_poly(col_j, limb_i), zero);
}) })
} else { } else {
@@ -146,7 +146,7 @@ mod tests {
let module: Module<FFT64> = Module::<FFT64>::new(n); let module: Module<FFT64> = Module::<FFT64>::new(n);
let log_base2k: usize = 17; let log_base2k: usize = 17;
let log_k: usize = 2 * 17; let log_k: usize = 2 * 17;
let limbs: usize = 5; let size: usize = 5;
let sigma: f64 = 3.2; let sigma: f64 = 3.2;
let bound: f64 = 6.0 * sigma; let bound: f64 = 6.0 * sigma;
let mut source: Source = Source::new([0u8; 32]); let mut source: Source = Source::new([0u8; 32]);
@@ -154,11 +154,11 @@ mod tests {
let zero: Vec<i64> = vec![0; n]; let zero: Vec<i64> = vec![0; n];
let k_f64: f64 = (1u64 << log_k as u64) as f64; let k_f64: f64 = (1u64 << log_k as u64) as f64;
(0..cols).for_each(|col_i| { (0..cols).for_each(|col_i| {
let mut a: VecZnx = VecZnx::new(&module, cols, limbs); let mut a: VecZnx = VecZnx::new(&module, cols, size);
module.add_normal(log_base2k, &mut a, col_i, log_k, &mut source, sigma, bound); module.add_normal(log_base2k, &mut a, col_i, log_k, &mut source, sigma, bound);
(0..cols).for_each(|col_j| { (0..cols).for_each(|col_j| {
if col_j != col_i { if col_j != col_i {
(0..limbs).for_each(|limb_i| { (0..size).for_each(|limb_i| {
assert_eq!(a.at_poly(col_j, limb_i), zero); assert_eq!(a.at_poly(col_j, limb_i), zero);
}) })
} else { } else {

View File

@@ -120,7 +120,7 @@ impl Scalar {
VecZnx { VecZnx {
n: self.n, n: self.n,
cols: 1, cols: 1,
limbs: 1, size: 1,
data: Vec::new(), data: Vec::new(),
ptr: self.ptr, ptr: self.ptr,
} }

View File

@@ -10,7 +10,7 @@ pub trait Stats {
impl Stats for VecZnx { impl Stats for VecZnx {
fn std(&self, col_i: usize, log_base2k: usize) -> f64 { fn std(&self, col_i: usize, log_base2k: usize) -> f64 {
let prec: u32 = (self.limbs() * log_base2k) as u32; let prec: u32 = (self.size() * log_base2k) as u32;
let mut data: Vec<Float> = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect(); let mut data: Vec<Float> = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect();
self.decode_vec_float(col_i, log_base2k, &mut data); self.decode_vec_float(col_i, log_base2k, &mut data);
// std = sqrt(sum((xi - avg)^2) / n) // std = sqrt(sum((xi - avg)^2) / n)

View File

@@ -1,11 +1,10 @@
use crate::Backend; use crate::Backend;
use crate::ZnxBase; use crate::ZnxBase;
use crate::cast_mut; use crate::cast_mut;
use crate::ffi::vec_znx;
use crate::ffi::znx; use crate::ffi::znx;
use crate::{Module, ZnxInfos, ZnxLayout}; use crate::switch_degree;
use crate::{Module, ZnxBasics, ZnxInfos, ZnxLayout};
use crate::{alloc_aligned, assert_alignement}; use crate::{alloc_aligned, assert_alignement};
use itertools::izip;
use std::cmp::min; use std::cmp::min;
/// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of /// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of
@@ -26,8 +25,8 @@ pub struct VecZnx {
/// The number of polynomials /// The number of polynomials
pub cols: usize, pub cols: usize,
/// The number of limbs per polynomial (a.k.a small polynomials). /// The number of size per polynomial (a.k.a small polynomials).
pub limbs: usize, pub size: usize,
/// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n. /// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n.
pub data: Vec<i64>, pub data: Vec<i64>,
@@ -41,10 +40,6 @@ impl ZnxInfos for VecZnx {
self.n self.n
} }
fn log_n(&self) -> usize {
(usize::BITS - (self.n() - 1).leading_zeros()) as _
}
fn rows(&self) -> usize { fn rows(&self) -> usize {
1 1
} }
@@ -53,12 +48,8 @@ impl ZnxInfos for VecZnx {
self.cols self.cols
} }
fn limbs(&self) -> usize { fn size(&self) -> usize {
self.limbs self.size
}
fn poly_count(&self) -> usize {
self.cols * self.limbs
} }
} }
@@ -74,6 +65,8 @@ impl ZnxLayout for VecZnx {
} }
} }
impl ZnxBasics for VecZnx {}
/// Copies the coefficients of `a` on the receiver. /// Copies the coefficients of `a` on the receiver.
/// Copy is done with the minimum size matching both backing arrays. /// Copy is done with the minimum size matching both backing arrays.
/// Panics if the cols do not match. /// Panics if the cols do not match.
@@ -89,28 +82,28 @@ impl<B: Backend> ZnxBase<B> for VecZnx {
type Scalar = i64; type Scalar = i64;
/// Allocates a new [VecZnx] composed of #size polynomials of Z\[X\]. /// Allocates a new [VecZnx] composed of #size polynomials of Z\[X\].
fn new(module: &Module<B>, cols: usize, limbs: usize) -> Self { fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
let n: usize = module.n(); let n: usize = module.n();
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!(n > 0); assert!(n > 0);
assert!(n & (n - 1) == 0); assert!(n & (n - 1) == 0);
assert!(cols > 0); assert!(cols > 0);
assert!(limbs > 0); assert!(size > 0);
} }
let mut data: Vec<i64> = alloc_aligned::<i64>(Self::bytes_of(module, cols, limbs)); let mut data: Vec<i64> = alloc_aligned::<i64>(Self::bytes_of(module, cols, size));
let ptr: *mut i64 = data.as_mut_ptr(); let ptr: *mut i64 = data.as_mut_ptr();
Self { Self {
n: n, n: n,
cols: cols, cols: cols,
limbs: limbs, size: size,
data: data, data: data,
ptr: ptr, ptr: ptr,
} }
} }
fn bytes_of(module: &Module<B>, cols: usize, limbs: usize) -> usize { fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize {
module.n() * cols * limbs * size_of::<i64>() module.n() * cols * size * size_of::<i64>()
} }
/// Returns a new struct implementing [VecZnx] with the provided data as backing array. /// Returns a new struct implementing [VecZnx] with the provided data as backing array.
@@ -118,14 +111,14 @@ impl<B: Backend> ZnxBase<B> for VecZnx {
/// The struct will take ownership of buf[..[Self::bytes_of]] /// The struct will take ownership of buf[..[Self::bytes_of]]
/// ///
/// User must ensure that data is properly alligned and that /// User must ensure that data is properly alligned and that
/// the limbs of data is equal to [Self::bytes_of]. /// the size of data is equal to [Self::bytes_of].
fn from_bytes(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { fn from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
let n: usize = module.n(); let n: usize = module.n();
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!(cols > 0); assert!(cols > 0);
assert!(limbs > 0); assert!(size > 0);
assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
assert_alignement(bytes.as_ptr()); assert_alignement(bytes.as_ptr());
} }
unsafe { unsafe {
@@ -134,25 +127,25 @@ impl<B: Backend> ZnxBase<B> for VecZnx {
Self { Self {
n: n, n: n,
cols: cols, cols: cols,
limbs: limbs, size: size,
data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()), data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()),
ptr: ptr, ptr: ptr,
} }
} }
} }
fn from_bytes_borrow(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { fn from_bytes_borrow(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!(cols > 0); assert!(cols > 0);
assert!(limbs > 0); assert!(size > 0);
assert!(bytes.len() >= Self::bytes_of(module, cols, limbs)); assert!(bytes.len() >= Self::bytes_of(module, cols, size));
assert_alignement(bytes.as_ptr()); assert_alignement(bytes.as_ptr());
} }
Self { Self {
n: module.n(), n: module.n(),
cols: cols, cols: cols,
limbs: limbs, size: size,
data: Vec::new(), data: Vec::new(),
ptr: bytes.as_mut_ptr() as *mut i64, ptr: bytes.as_mut_ptr() as *mut i64,
} }
@@ -173,16 +166,16 @@ impl VecZnx {
if !self.borrowing() { if !self.borrowing() {
self.data self.data
.truncate(self.n() * self.cols() * (self.limbs() - k / log_base2k)); .truncate(self.n() * self.cols() * (self.size() - k / log_base2k));
} }
self.limbs -= k / log_base2k; self.size -= k / log_base2k;
let k_rem: usize = k % log_base2k; let k_rem: usize = k % log_base2k;
if k_rem != 0 { if k_rem != 0 {
let mask: i64 = ((1 << (log_base2k - k_rem - 1)) - 1) << k_rem; let mask: i64 = ((1 << (log_base2k - k_rem - 1)) - 1) << k_rem;
self.at_limb_mut(self.limbs() - 1) self.at_limb_mut(self.size() - 1)
.iter_mut() .iter_mut()
.for_each(|x: &mut i64| *x &= mask) .for_each(|x: &mut i64| *x &= mask)
} }
@@ -196,52 +189,22 @@ impl VecZnx {
self.data.len() == 0 self.data.len() == 0
} }
pub fn zero(&mut self) {
unsafe { znx::znx_zero_i64_ref((self.n * self.poly_count()) as u64, self.ptr) }
}
pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) { pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) {
normalize(log_base2k, self, carry) normalize(log_base2k, self, carry)
} }
pub fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) {
rsh(log_base2k, self, k, carry)
}
pub fn switch_degree(&self, a: &mut Self) { pub fn switch_degree(&self, a: &mut Self) {
switch_degree(a, self) switch_degree(a, self)
} }
// Prints the first `n` coefficients of each limb // Prints the first `n` coefficients of each limb
pub fn print(&self, n: usize) { pub fn print(&self, n: usize) {
(0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])) (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]))
} }
} }
pub fn switch_degree(b: &mut VecZnx, a: &VecZnx) { fn normalize_tmp_bytes(n: usize, size: usize) -> usize {
let (n_in, n_out) = (a.n(), b.n()); n * size * std::mem::size_of::<i64>()
let (gap_in, gap_out): (usize, usize);
if n_in > n_out {
(gap_in, gap_out) = (n_in / n_out, 1)
} else {
(gap_in, gap_out) = (1, n_out / n_in);
b.zero();
}
let limbs: usize = min(a.limbs(), b.limbs());
(0..limbs).for_each(|i| {
izip!(
a.at_limb(i).iter().step_by(gap_in),
b.at_limb_mut(i).iter_mut().step_by(gap_out)
)
.for_each(|(x_in, x_out)| *x_out = *x_in);
});
}
fn normalize_tmp_bytes(n: usize, limbs: usize) -> usize {
n * limbs * std::mem::size_of::<i64>()
} }
fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) {
@@ -264,7 +227,7 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) {
unsafe { unsafe {
znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr()); znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr());
(0..a.limbs()).rev().for_each(|i| { (0..a.size()).rev().for_each(|i| {
znx::znx_normalize( znx::znx_normalize(
(n * cols) as u64, (n * cols) as u64,
log_base2k as u64, log_base2k as u64,
@@ -276,462 +239,3 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) {
}); });
} }
} }
pub fn rsh_tmp_bytes(n: usize, limbs: usize) -> usize {
n * limbs * std::mem::size_of::<i64>()
}
pub fn rsh(log_base2k: usize, a: &mut VecZnx, k: usize, tmp_bytes: &mut [u8]) {
let n: usize = a.n();
let limbs: usize = a.limbs();
#[cfg(debug_assertions)]
{
assert!(
tmp_bytes.len() >= rsh_tmp_bytes(n, limbs),
"invalid carry: carry.len()/8={} < rsh_tmp_bytes({}, {})",
tmp_bytes.len() >> 3,
n,
limbs,
);
assert_alignement(tmp_bytes.as_ptr());
}
let limbs: usize = a.limbs();
let size_steps: usize = k / log_base2k;
a.raw_mut().rotate_right(n * limbs * size_steps);
unsafe {
znx::znx_zero_i64_ref((n * limbs * size_steps) as u64, a.as_mut_ptr());
}
let k_rem = k % log_base2k;
if k_rem != 0 {
let carry_i64: &mut [i64] = cast_mut(tmp_bytes);
unsafe {
znx::znx_zero_i64_ref((n * limbs) as u64, carry_i64.as_mut_ptr());
}
let log_base2k: usize = log_base2k;
(size_steps..limbs).for_each(|i| {
izip!(carry_i64.iter_mut(), a.at_limb_mut(i).iter_mut()).for_each(|(ci, xi)| {
*xi += *ci << log_base2k;
*ci = get_base_k_carry(*xi, k_rem);
*xi = (*xi - *ci) >> k_rem;
});
})
}
}
#[inline(always)]
fn get_base_k_carry(x: i64, k: usize) -> i64 {
(x << 64 - k) >> (64 - k)
}
pub trait VecZnxOps {
/// Allocates a new [VecZnx].
///
/// # Arguments
///
/// * `cols`: the number of polynomials.
/// * `limbs`: the number of limbs per polynomial (a.k.a small polynomials).
fn new_vec_znx(&self, cols: usize, limbs: usize) -> VecZnx;
fn new_vec_znx_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnx;
fn new_vec_znx_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnx;
/// Returns the minimum number of bytes necessary to allocate
/// a new [VecZnx] through [VecZnx::from_bytes].
fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize;
fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize;
/// c <- a + b.
fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx);
/// b <- b + a.
fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx);
/// c <- a - b.
fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx);
/// b <- a - b.
fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx);
/// b <- b - a.
fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx);
/// b <- -a.
fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx);
/// b <- -b.
fn vec_znx_negate_inplace(&self, a: &mut VecZnx);
/// b <- a * X^k (mod X^{n} + 1)
fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx);
/// a <- a * X^k (mod X^{n} + 1)
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx);
/// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx);
/// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx);
/// Splits b into subrings and copies them them into a.
///
/// # Panics
///
/// This method requires that all [VecZnx] of b have the same ring degree
/// and that b.n() * b.len() <= a.n()
fn vec_znx_split(&self, b: &mut Vec<VecZnx>, a: &VecZnx, buf: &mut VecZnx);
/// Merges the subrings a into b.
///
/// # Panics
///
/// This method requires that all [VecZnx] of a have the same ring degree
/// and that a.n() * a.len() <= b.n()
fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec<VecZnx>);
}
impl<B: Backend> VecZnxOps for Module<B> {
fn new_vec_znx(&self, cols: usize, limbs: usize) -> VecZnx {
VecZnx::new(self, cols, limbs)
}
fn bytes_of_vec_znx(&self, cols: usize, limbs: usize) -> usize {
VecZnx::bytes_of(self, cols, limbs)
}
fn new_vec_znx_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnx {
VecZnx::from_bytes(self, cols, limbs, bytes)
}
fn new_vec_znx_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnx {
VecZnx::from_bytes_borrow(self, cols, limbs, tmp_bytes)
}
fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize {
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize * cols }
}
// c <- a + b
fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) {
let n: usize = self.n();
#[cfg(debug_assertions)]
{
assert_eq!(c.n(), n);
assert_eq!(a.n(), n);
assert_eq!(b.n(), n);
}
unsafe {
vec_znx::vec_znx_add(
self.ptr,
c.as_mut_ptr(),
c.limbs() as u64,
(n * c.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
(n * a.cols()) as u64,
b.as_ptr(),
b.limbs() as u64,
(n * b.cols()) as u64,
)
}
}
// b <- a + b
fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
let n: usize = self.n();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), n);
assert_eq!(b.n(), n);
}
unsafe {
vec_znx::vec_znx_add(
self.ptr,
b.as_mut_ptr(),
b.limbs() as u64,
(n * b.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
(n * a.cols()) as u64,
b.as_ptr(),
b.limbs() as u64,
(n * b.cols()) as u64,
)
}
}
// c <- a + b
fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) {
let n: usize = self.n();
#[cfg(debug_assertions)]
{
assert_eq!(c.n(), n);
assert_eq!(a.n(), n);
assert_eq!(b.n(), n);
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
c.as_mut_ptr(),
c.limbs() as u64,
(n * c.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
(n * a.cols()) as u64,
b.as_ptr(),
b.limbs() as u64,
(n * b.cols()) as u64,
)
}
}
// b <- a - b
fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
let n: usize = self.n();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), n);
assert_eq!(b.n(), n);
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
b.as_mut_ptr(),
b.limbs() as u64,
(n * b.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
(n * a.cols()) as u64,
b.as_ptr(),
b.limbs() as u64,
(n * b.cols()) as u64,
)
}
}
// b <- b - a
fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
let n: usize = self.n();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), n);
assert_eq!(b.n(), n);
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
b.as_mut_ptr(),
b.limbs() as u64,
(n * b.cols()) as u64,
b.as_ptr(),
b.limbs() as u64,
(n * b.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
(n * a.cols()) as u64,
)
}
}
fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) {
let n: usize = self.n();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), n);
assert_eq!(b.n(), n);
}
unsafe {
vec_znx::vec_znx_negate(
self.ptr,
b.as_mut_ptr(),
b.limbs() as u64,
(n * b.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
(n * a.cols()) as u64,
)
}
}
fn vec_znx_negate_inplace(&self, a: &mut VecZnx) {
let n: usize = self.n();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), n);
}
unsafe {
vec_znx::vec_znx_negate(
self.ptr,
a.as_mut_ptr(),
a.limbs() as u64,
(n * a.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
(n * a.cols()) as u64,
)
}
}
fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx) {
let n: usize = self.n();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), n);
assert_eq!(b.n(), n);
}
unsafe {
vec_znx::vec_znx_rotate(
self.ptr,
k,
b.as_mut_ptr(),
b.limbs() as u64,
(n * b.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
(n * a.cols()) as u64,
)
}
}
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx) {
let n: usize = self.n();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), n);
}
unsafe {
vec_znx::vec_znx_rotate(
self.ptr,
k,
a.as_mut_ptr(),
a.limbs() as u64,
(n * a.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
(n * a.cols()) as u64,
)
}
}
/// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size.
///
/// # Arguments
///
/// * `a`: input.
/// * `b`: output.
/// * `k`: the power to which to map each coefficients.
/// * `a_size`: the number of a_size on which to apply the mapping.
///
/// # Panics
///
/// The method will panic if the argument `a` is greater than `a.limbs()`.
fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) {
let n: usize = self.n();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), n);
assert_eq!(b.n(), n);
}
unsafe {
vec_znx::vec_znx_automorphism(
self.ptr,
k,
b.as_mut_ptr(),
b.limbs() as u64,
(n * b.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
(n * a.cols()) as u64,
);
}
}
/// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size.
///
/// # Arguments
///
/// * `a`: input and output.
/// * `k`: the power to which to map each coefficients.
/// * `a_size`: the number of size on which to apply the mapping.
///
/// # Panics
///
/// The method will panic if the argument `size` is greater than `self.limbs()`.
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) {
let n: usize = self.n();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), n);
}
unsafe {
vec_znx::vec_znx_automorphism(
self.ptr,
k,
a.as_mut_ptr(),
a.limbs() as u64,
(n * a.cols()) as u64,
a.as_ptr(),
a.limbs() as u64,
(n * a.cols()) as u64,
);
}
}
fn vec_znx_split(&self, b: &mut Vec<VecZnx>, a: &VecZnx, buf: &mut VecZnx) {
let (n_in, n_out) = (a.n(), b[0].n());
debug_assert!(
n_out < n_in,
"invalid a: output ring degree should be smaller"
);
b[1..].iter().for_each(|bi| {
debug_assert_eq!(
bi.n(),
n_out,
"invalid input a: all VecZnx must have the same degree"
)
});
b.iter_mut().enumerate().for_each(|(i, bi)| {
if i == 0 {
switch_degree(bi, a);
self.vec_znx_rotate(-1, buf, a);
} else {
switch_degree(bi, buf);
self.vec_znx_rotate_inplace(-1, buf);
}
})
}
fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec<VecZnx>) {
let (n_in, n_out) = (b.n(), a[0].n());
debug_assert!(
n_out < n_in,
"invalid a: output ring degree should be smaller"
);
a[1..].iter().for_each(|ai| {
debug_assert_eq!(
ai.n(),
n_out,
"invalid input a: all VecZnx must have the same degree"
)
});
a.iter().enumerate().for_each(|(_, ai)| {
switch_degree(b, ai);
self.vec_znx_rotate_inplace(-1, b);
});
self.vec_znx_rotate_inplace(a.len() as i64, b);
}
}

View File

@@ -7,43 +7,43 @@ pub struct VecZnxBig<B: Backend> {
pub ptr: *mut u8, pub ptr: *mut u8,
pub n: usize, pub n: usize,
pub cols: usize, pub cols: usize,
pub limbs: usize, pub size: usize,
pub _marker: PhantomData<B>, pub _marker: PhantomData<B>,
} }
impl<B: Backend> ZnxBase<B> for VecZnxBig<B> { impl<B: Backend> ZnxBase<B> for VecZnxBig<B> {
type Scalar = u8; type Scalar = u8;
fn new(module: &Module<B>, cols: usize, limbs: usize) -> Self { fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!(cols > 0); assert!(cols > 0);
assert!(limbs > 0); assert!(size > 0);
} }
let mut data: Vec<Self::Scalar> = alloc_aligned::<u8>(Self::bytes_of(module, cols, limbs)); let mut data: Vec<Self::Scalar> = alloc_aligned::<u8>(Self::bytes_of(module, cols, size));
let ptr: *mut Self::Scalar = data.as_mut_ptr(); let ptr: *mut Self::Scalar = data.as_mut_ptr();
Self { Self {
data: data, data: data,
ptr: ptr, ptr: ptr,
n: module.n(), n: module.n(),
cols: cols, cols: cols,
limbs: limbs, size: size,
_marker: PhantomData, _marker: PhantomData,
} }
} }
fn bytes_of(module: &Module<B>, cols: usize, limbs: usize) -> usize { fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize {
unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, limbs as u64) as usize * cols } unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols }
} }
/// Returns a new [VecZnxBig] with the provided data as backing array. /// Returns a new [VecZnxBig] with the provided data as backing array.
/// User must ensure that data is properly alligned and that /// User must ensure that data is properly alligned and that
/// the size of data is at least equal to [Module::bytes_of_vec_znx_big]. /// the size of data is at least equal to [Module::bytes_of_vec_znx_big].
fn from_bytes(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self { fn from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!(cols > 0); assert!(cols > 0);
assert!(limbs > 0); assert!(size > 0);
assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
assert_alignement(bytes.as_ptr()) assert_alignement(bytes.as_ptr())
}; };
unsafe { unsafe {
@@ -52,18 +52,18 @@ impl<B: Backend> ZnxBase<B> for VecZnxBig<B> {
ptr: bytes.as_mut_ptr(), ptr: bytes.as_mut_ptr(),
n: module.n(), n: module.n(),
cols: cols, cols: cols,
limbs: limbs, size: size,
_marker: PhantomData, _marker: PhantomData,
} }
} }
} }
fn from_bytes_borrow(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self { fn from_bytes_borrow(module: &Module<B>, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!(cols > 0); assert!(cols > 0);
assert!(limbs > 0); assert!(size > 0);
assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
assert_alignement(bytes.as_ptr()); assert_alignement(bytes.as_ptr());
} }
Self { Self {
@@ -71,17 +71,13 @@ impl<B: Backend> ZnxBase<B> for VecZnxBig<B> {
ptr: bytes.as_mut_ptr(), ptr: bytes.as_mut_ptr(),
n: module.n(), n: module.n(),
cols: cols, cols: cols,
limbs: limbs, size: size,
_marker: PhantomData, _marker: PhantomData,
} }
} }
} }
impl<B: Backend> ZnxInfos for VecZnxBig<B> { impl<B: Backend> ZnxInfos for VecZnxBig<B> {
fn log_n(&self) -> usize {
(usize::BITS - (self.n - 1).leading_zeros()) as _
}
fn n(&self) -> usize { fn n(&self) -> usize {
self.n self.n
} }
@@ -94,12 +90,8 @@ impl<B: Backend> ZnxInfos for VecZnxBig<B> {
1 1
} }
fn limbs(&self) -> usize { fn size(&self) -> usize {
self.limbs self.size
}
fn poly_count(&self) -> usize {
self.cols * self.limbs
} }
} }
@@ -117,13 +109,13 @@ impl ZnxLayout for VecZnxBig<FFT64> {
impl VecZnxBig<FFT64> { impl VecZnxBig<FFT64> {
pub fn print(&self, n: usize) { pub fn print(&self, n: usize) {
(0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
} }
} }
pub trait VecZnxBigOps<B: Backend> { pub trait VecZnxBigOps<B: Backend> {
/// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
fn new_vec_znx_big(&self, cols: usize, limbs: usize) -> VecZnxBig<B>; fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<B>;
/// Returns a new [VecZnxBig] with the provided bytes array as backing array. /// Returns a new [VecZnxBig] with the provided bytes array as backing array.
/// ///
@@ -132,12 +124,12 @@ pub trait VecZnxBigOps<B: Backend> {
/// # Arguments /// # Arguments
/// ///
/// * `cols`: the number of polynomials.. /// * `cols`: the number of polynomials..
/// * `limbs`: the number of limbs (a.k.a small polynomials) per polynomial. /// * `size`: the number of size (a.k.a small polynomials) per polynomial.
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
/// ///
/// # Panics /// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. /// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
fn new_vec_znx_big_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxBig<B>; fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<B>;
/// Returns a new [VecZnxBig] with the provided bytes array as backing array. /// Returns a new [VecZnxBig] with the provided bytes array as backing array.
/// ///
@@ -146,25 +138,25 @@ pub trait VecZnxBigOps<B: Backend> {
/// # Arguments /// # Arguments
/// ///
/// * `cols`: the number of polynomials.. /// * `cols`: the number of polynomials..
/// * `limbs`: the number of limbs (a.k.a small polynomials) per polynomial. /// * `size`: the number of size (a.k.a small polynomials) per polynomial.
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
/// ///
/// # Panics /// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. /// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<B>; fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<B>;
/// Returns the minimum number of bytes necessary to allocate /// Returns the minimum number of bytes necessary to allocate
/// a new [VecZnxBig] through [VecZnxBig::from_bytes]. /// a new [VecZnxBig] through [VecZnxBig::from_bytes].
fn bytes_of_vec_znx_big(&self, cols: usize, limbs: usize) -> usize; fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize;
/// b[VecZnxBig] <- b[VecZnxBig] - a[VecZnx] /// b[VecZnxBig] <- b[VecZnxBig] - a[VecZnx]
/// ///
/// # Behavior /// # Behavior
/// ///
/// [VecZnxBig] (3 cols and 4 limbs) /// [VecZnxBig] (3 cols and 4 size)
/// [a0, b0, c0] [a1, b1, c1] [a2, b2, c2] [a3, b3, c3] /// [a0, b0, c0] [a1, b1, c1] [a2, b2, c2] [a3, b3, c3]
/// - /// -
/// [VecZnx] (2 cols and 3 limbs) /// [VecZnx] (2 cols and 3 size)
/// [d0, e0] [d1, e1] [d2, e2] /// [d0, e0] [d1, e1] [d2, e2]
/// = /// =
/// [a0-d0, b0-e0, c0] [a1-d1, b1-e1, c1] [a2-d2, b2-e2, c2] [a3, b3, c3] /// [a0-d0, b0-e0, c0] [a1-d1, b1-e1, c1] [a2-d2, b2-e2, c2] [a3, b3, c3]
@@ -203,26 +195,26 @@ pub trait VecZnxBigOps<B: Backend> {
} }
impl VecZnxBigOps<FFT64> for Module<FFT64> { impl VecZnxBigOps<FFT64> for Module<FFT64> {
fn new_vec_znx_big(&self, cols: usize, limbs: usize) -> VecZnxBig<FFT64> { fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<FFT64> {
VecZnxBig::new(self, cols, limbs) VecZnxBig::new(self, cols, size)
} }
fn new_vec_znx_big_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxBig<FFT64> { fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<FFT64> {
VecZnxBig::from_bytes(self, cols, limbs, bytes) VecZnxBig::from_bytes(self, cols, size, bytes)
} }
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<FFT64> { fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<FFT64> {
VecZnxBig::from_bytes_borrow(self, cols, limbs, tmp_bytes) VecZnxBig::from_bytes_borrow(self, cols, size, tmp_bytes)
} }
fn bytes_of_vec_znx_big(&self, cols: usize, limbs: usize) -> usize { fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
VecZnxBig::bytes_of(self, cols, limbs) VecZnxBig::bytes_of(self, cols, size)
} }
/// [VecZnxBig] (3 cols and 4 limbs) /// [VecZnxBig] (3 cols and 4 size)
/// [a0, b0, c0] [a1, b1, c1] [a2, b2, c2] [a3, b3, c3] /// [a0, b0, c0] [a1, b1, c1] [a2, b2, c2] [a3, b3, c3]
/// - /// -
/// [VecZnx] (2 cols and 3 limbs) /// [VecZnx] (2 cols and 3 size)
/// [d0, e0] [d1, e1] [d2, e2] /// [d0, e0] [d1, e1] [d2, e2]
/// = /// =
/// [a0-d0, b0-e0, c0] [a1-d1, b1-e1, c1] [a2-d2, b2-e2, c2] [a3, b3, c3] /// [a0-d0, b0-e0, c0] [a1-d1, b1-e1, c1] [a2-d2, b2-e2, c2] [a3, b3, c3]
@@ -306,10 +298,10 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
self.ptr, self.ptr,
log_base2k as u64, log_base2k as u64,
b.as_mut_ptr(), b.as_mut_ptr(),
b.limbs() as u64, b.size() as u64,
b.n() as u64, b.n() as u64,
a.ptr as *mut vec_znx_big_t, a.ptr as *mut vec_znx_big_t,
a.limbs() as u64, a.size() as u64,
tmp_bytes.as_mut_ptr(), tmp_bytes.as_mut_ptr(),
) )
} }
@@ -344,7 +336,7 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
self.ptr, self.ptr,
log_base2k as u64, log_base2k as u64,
res.as_mut_ptr(), res.as_mut_ptr(),
res.limbs() as u64, res.size() as u64,
res.n() as u64, res.n() as u64,
a.ptr as *mut vec_znx_big_t, a.ptr as *mut vec_znx_big_t,
a_range_begin as u64, a_range_begin as u64,

View File

@@ -10,44 +10,44 @@ pub struct VecZnxDft<B: Backend> {
pub ptr: *mut u8, pub ptr: *mut u8,
pub n: usize, pub n: usize,
pub cols: usize, pub cols: usize,
pub limbs: usize, pub size: usize,
pub _marker: PhantomData<B>, pub _marker: PhantomData<B>,
} }
impl<B: Backend> ZnxBase<B> for VecZnxDft<B> { impl<B: Backend> ZnxBase<B> for VecZnxDft<B> {
type Scalar = u8; type Scalar = u8;
fn new(module: &Module<B>, cols: usize, limbs: usize) -> Self { fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!(cols > 0); assert!(cols > 0);
assert!(limbs > 0); assert!(size > 0);
} }
let mut data: Vec<Self::Scalar> = alloc_aligned(Self::bytes_of(module, cols, limbs)); let mut data: Vec<Self::Scalar> = alloc_aligned(Self::bytes_of(module, cols, size));
let ptr: *mut Self::Scalar = data.as_mut_ptr(); let ptr: *mut Self::Scalar = data.as_mut_ptr();
Self { Self {
data: data, data: data,
ptr: ptr, ptr: ptr,
n: module.n(), n: module.n(),
limbs: limbs, size: size,
cols: cols, cols: cols,
_marker: PhantomData, _marker: PhantomData,
} }
} }
fn bytes_of(module: &Module<B>, cols: usize, limbs: usize) -> usize { fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize {
unsafe { bytes_of_vec_znx_dft(module.ptr, limbs as u64) as usize * cols } unsafe { bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols }
} }
/// Returns a new [VecZnxDft] with the provided data as backing array. /// Returns a new [VecZnxDft] with the provided data as backing array.
/// User must ensure that data is properly alligned and that /// User must ensure that data is properly alligned and that
/// the size of data is at least equal to [Module::bytes_of_vec_znx_dft]. /// the size of data is at least equal to [Module::bytes_of_vec_znx_dft].
fn from_bytes(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self { fn from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!(cols > 0); assert!(cols > 0);
assert!(limbs > 0); assert!(size > 0);
assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
assert_alignement(bytes.as_ptr()) assert_alignement(bytes.as_ptr())
} }
unsafe { unsafe {
@@ -56,18 +56,18 @@ impl<B: Backend> ZnxBase<B> for VecZnxDft<B> {
ptr: bytes.as_mut_ptr(), ptr: bytes.as_mut_ptr(),
n: module.n(), n: module.n(),
cols: cols, cols: cols,
limbs: limbs, size: size,
_marker: PhantomData, _marker: PhantomData,
} }
} }
} }
fn from_bytes_borrow(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self { fn from_bytes_borrow(module: &Module<B>, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!(cols > 0); assert!(cols > 0);
assert!(limbs > 0); assert!(size > 0);
assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
assert_alignement(bytes.as_ptr()); assert_alignement(bytes.as_ptr());
} }
Self { Self {
@@ -75,7 +75,7 @@ impl<B: Backend> ZnxBase<B> for VecZnxDft<B> {
ptr: bytes.as_mut_ptr(), ptr: bytes.as_mut_ptr(),
n: module.n(), n: module.n(),
cols: cols, cols: cols,
limbs: limbs, size: size,
_marker: PhantomData, _marker: PhantomData,
} }
} }
@@ -91,7 +91,7 @@ impl<B: Backend> VecZnxDft<B> {
ptr: self.ptr, ptr: self.ptr,
n: self.n, n: self.n,
cols: self.cols, cols: self.cols,
limbs: self.limbs, size: self.size,
_marker: PhantomData, _marker: PhantomData,
} }
} }
@@ -102,10 +102,6 @@ impl<B: Backend> ZnxInfos for VecZnxDft<B> {
self.n self.n
} }
fn log_n(&self) -> usize {
(usize::BITS - (self.n() - 1).leading_zeros()) as _
}
fn rows(&self) -> usize { fn rows(&self) -> usize {
1 1
} }
@@ -114,12 +110,8 @@ impl<B: Backend> ZnxInfos for VecZnxDft<B> {
self.cols self.cols
} }
fn limbs(&self) -> usize { fn size(&self) -> usize {
self.limbs self.size
}
fn poly_count(&self) -> usize {
self.cols * self.limbs
} }
} }
@@ -137,13 +129,13 @@ impl ZnxLayout for VecZnxDft<FFT64> {
impl VecZnxDft<FFT64> { impl VecZnxDft<FFT64> {
pub fn print(&self, n: usize) { pub fn print(&self, n: usize) {
(0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
} }
} }
pub trait VecZnxDftOps<B: Backend> { pub trait VecZnxDftOps<B: Backend> {
/// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
fn new_vec_znx_dft(&self, cols: usize, limbs: usize) -> VecZnxDft<B>; fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft<B>;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// Returns a new [VecZnxDft] with the provided bytes array as backing array.
/// ///
@@ -156,7 +148,7 @@ pub trait VecZnxDftOps<B: Backend> {
/// ///
/// # Panics /// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn new_vec_znx_dft_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxDft<B>; fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<B>;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// Returns a new [VecZnxDft] with the provided bytes array as backing array.
/// ///
@@ -169,7 +161,7 @@ pub trait VecZnxDftOps<B: Backend> {
/// ///
/// # Panics /// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxDft<B>; fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<B>;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// Returns a new [VecZnxDft] with the provided bytes array as backing array.
/// ///
@@ -180,7 +172,7 @@ pub trait VecZnxDftOps<B: Backend> {
/// ///
/// # Panics /// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn bytes_of_vec_znx_dft(&self, cols: usize, limbs: usize) -> usize; fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize;
/// Returns the minimum number of bytes necessary to allocate /// Returns the minimum number of bytes necessary to allocate
/// a new [VecZnxDft] through [VecZnxDft::from_bytes]. /// a new [VecZnxDft] through [VecZnxDft::from_bytes].
@@ -201,20 +193,20 @@ pub trait VecZnxDftOps<B: Backend> {
} }
impl VecZnxDftOps<FFT64> for Module<FFT64> { impl VecZnxDftOps<FFT64> for Module<FFT64> {
fn new_vec_znx_dft(&self, cols: usize, limbs: usize) -> VecZnxDft<FFT64> { fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft<FFT64> {
VecZnxDft::<FFT64>::new(&self, cols, limbs) VecZnxDft::<FFT64>::new(&self, cols, size)
} }
fn new_vec_znx_dft_from_bytes(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxDft<FFT64> { fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxDft<FFT64> {
VecZnxDft::from_bytes(self, cols, limbs, tmp_bytes) VecZnxDft::from_bytes(self, cols, size, tmp_bytes)
} }
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxDft<FFT64> { fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxDft<FFT64> {
VecZnxDft::from_bytes_borrow(self, cols, limbs, tmp_bytes) VecZnxDft::from_bytes_borrow(self, cols, size, tmp_bytes)
} }
fn bytes_of_vec_znx_dft(&self, cols: usize, limbs: usize) -> usize { fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize {
VecZnxDft::bytes_of(&self, cols, limbs) VecZnxDft::bytes_of(&self, cols, size)
} }
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig<FFT64>, a: &mut VecZnxDft<FFT64>) { fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig<FFT64>, a: &mut VecZnxDft<FFT64>) {
@@ -242,9 +234,9 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
vec_znx_dft::vec_znx_dft( vec_znx_dft::vec_znx_dft(
self.ptr, self.ptr,
b.ptr as *mut vec_znx_dft_t, b.ptr as *mut vec_znx_dft_t,
b.limbs() as u64, b.size() as u64,
a.as_ptr(), a.as_ptr(),
a.limbs() as u64, a.size() as u64,
(a.n() * a.cols()) as u64, (a.n() * a.cols()) as u64,
) )
} }
@@ -329,14 +321,14 @@ mod tests {
let n: usize = 8; let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n); let module: Module<FFT64> = Module::<FFT64>::new(n);
let limbs: usize = 2; let size: usize = 2;
let log_base2k: usize = 17; let log_base2k: usize = 17;
let mut a: VecZnx = module.new_vec_znx(1, limbs); let mut a: VecZnx = module.new_vec_znx(1, size);
let mut a_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, limbs); let mut a_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, size);
let mut b_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, limbs); let mut b_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, size);
let mut source: Source = Source::new([0u8; 32]); let mut source: Source = Source::new([0u8; 32]);
module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source); module.fill_uniform(log_base2k, &mut a, 0, size, &mut source);
let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes()); let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes());

795
base2k/src/vec_znx_ops.rs Normal file
View File

@@ -0,0 +1,795 @@
use crate::ffi::module::MODULE;
use crate::ffi::vec_znx;
use crate::{Backend, Module, VecZnx, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, switch_degree, znx_post_process_ternary_op};
use std::cmp::min;
pub trait VecZnxOps {
/// Allocates a new [VecZnx].
///
/// # Arguments
///
/// * `cols`: the number of polynomials.
/// * `size`: the number of size per polynomial (a.k.a small polynomials).
fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnx;
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx;
fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx;
/// Returns the minimum number of bytes necessary to allocate
/// a new [VecZnx] through [VecZnx::from_bytes].
fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize;
fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize;
/// c <- a + b.
fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx);
/// b <- b + a.
fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx);
/// c <- a - b.
fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx);
/// b <- a - b.
fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx);
/// b <- b - a.
fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx);
/// b <- -a.
fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx);
/// b <- -b.
fn vec_znx_negate_inplace(&self, a: &mut VecZnx);
/// b <- a * X^k (mod X^{n} + 1)
fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx);
/// a <- a * X^k (mod X^{n} + 1)
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx);
/// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx);
/// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx);
/// Splits b into subrings and copies them them into a.
///
/// # Panics
///
/// This method requires that all [VecZnx] of b have the same ring degree
/// and that b.n() * b.len() <= a.n()
fn vec_znx_split(&self, b: &mut Vec<VecZnx>, a: &VecZnx, buf: &mut VecZnx);
/// Merges the subrings a into b.
///
/// # Panics
///
/// This method requires that all [VecZnx] of a have the same ring degree
/// and that a.n() * a.len() <= b.n()
fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec<VecZnx>);
}
impl<B: Backend> VecZnxOps for Module<B> {
fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnx {
VecZnx::new(self, cols, size)
}
fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize {
VecZnx::bytes_of(self, cols, size)
}
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx {
VecZnx::from_bytes(self, cols, size, bytes)
}
fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx {
VecZnx::from_bytes_borrow(self, cols, size, tmp_bytes)
}
fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize {
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize * cols }
}
fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) {
let op = ffi_ternary_op_factory(
self.ptr,
c.size(),
c.sl(),
a.size(),
a.sl(),
b.size(),
b.sl(),
vec_znx::vec_znx_add,
);
vec_znx_apply_binary_op::<B, false>(self, c, a, b, op);
}
fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
unsafe {
let b_ptr: *mut VecZnx = b as *mut VecZnx;
Self::vec_znx_add(self, &mut *b_ptr, a, &*b_ptr);
}
}
fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) {
let op = ffi_ternary_op_factory(
self.ptr,
c.size(),
c.sl(),
a.size(),
a.sl(),
b.size(),
b.sl(),
vec_znx::vec_znx_sub,
);
vec_znx_apply_binary_op::<B, true>(self, c, a, b, op);
}
fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
unsafe {
let b_ptr: *mut VecZnx = b as *mut VecZnx;
Self::vec_znx_sub(self, &mut *b_ptr, a, &*b_ptr);
}
}
fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
unsafe {
let b_ptr: *mut VecZnx = b as *mut VecZnx;
Self::vec_znx_sub(self, &mut *b_ptr, &*b_ptr, a);
}
}
fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) {
let op = ffi_binary_op_factory_type_0(
self.ptr,
b.size(),
b.sl(),
a.size(),
a.sl(),
vec_znx::vec_znx_negate,
);
vec_znx_apply_unary_op::<B>(self, b, a, op);
}
fn vec_znx_negate_inplace(&self, a: &mut VecZnx) {
unsafe {
let a_ptr: *mut VecZnx = a as *mut VecZnx;
Self::vec_znx_negate(self, &mut *a_ptr, &*a_ptr);
}
}
fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx) {
let op = ffi_binary_op_factory_type_1(
self.ptr,
k,
b.size(),
b.sl(),
a.size(),
a.sl(),
vec_znx::vec_znx_rotate,
);
vec_znx_apply_unary_op::<B>(self, b, a, op);
}
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx) {
unsafe {
let a_ptr: *mut VecZnx = a as *mut VecZnx;
Self::vec_znx_rotate(self, k, &mut *a_ptr, &*a_ptr);
}
}
/// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size.
///
/// # Arguments
///
/// * `a`: input.
/// * `b`: output.
/// * `k`: the power to which to map each coefficients.
/// * `a_size`: the number of a_size on which to apply the mapping.
///
/// # Panics
///
/// The method will panic if the argument `a` is greater than `a.size()`.
fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) {
let op = ffi_binary_op_factory_type_1(
self.ptr,
k,
b.size(),
b.sl(),
a.size(),
a.sl(),
vec_znx::vec_znx_automorphism,
);
vec_znx_apply_unary_op::<B>(self, b, a, op);
}
/// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size.
///
/// # Arguments
///
/// * `a`: input and output.
/// * `k`: the power to which to map each coefficients.
/// * `a_size`: the number of size on which to apply the mapping.
///
/// # Panics
///
/// The method will panic if the argument `size` is greater than `self.size()`.
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) {
unsafe {
let a_ptr: *mut VecZnx = a as *mut VecZnx;
Self::vec_znx_automorphism(self, k, &mut *a_ptr, &*a_ptr);
}
}
fn vec_znx_split(&self, b: &mut Vec<VecZnx>, a: &VecZnx, buf: &mut VecZnx) {
let (n_in, n_out) = (a.n(), b[0].n());
debug_assert!(
n_out < n_in,
"invalid a: output ring degree should be smaller"
);
b[1..].iter().for_each(|bi| {
debug_assert_eq!(
bi.n(),
n_out,
"invalid input a: all VecZnx must have the same degree"
)
});
b.iter_mut().enumerate().for_each(|(i, bi)| {
if i == 0 {
switch_degree(bi, a);
self.vec_znx_rotate(-1, buf, a);
} else {
switch_degree(bi, buf);
self.vec_znx_rotate_inplace(-1, buf);
}
})
}
fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec<VecZnx>) {
let (n_in, n_out) = (b.n(), a[0].n());
debug_assert!(
n_out < n_in,
"invalid a: output ring degree should be smaller"
);
a[1..].iter().for_each(|ai| {
debug_assert_eq!(
ai.n(),
n_out,
"invalid input a: all VecZnx must have the same degree"
)
});
a.iter().enumerate().for_each(|(_, ai)| {
switch_degree(b, ai);
self.vec_znx_rotate_inplace(-1, b);
});
self.vec_znx_rotate_inplace(a.len() as i64, b);
}
}
fn ffi_ternary_op_factory(
module_ptr: *const MODULE,
c_size: usize,
c_sl: usize,
a_size: usize,
a_sl: usize,
b_size: usize,
b_sl: usize,
op_fn: unsafe extern "C" fn(*const MODULE, *mut i64, u64, u64, *const i64, u64, u64, *const i64, u64, u64),
) -> impl Fn(&mut [i64], &[i64], &[i64]) {
move |cv: &mut [i64], av: &[i64], bv: &[i64]| unsafe {
op_fn(
module_ptr,
cv.as_mut_ptr(),
c_size as u64,
c_sl as u64,
av.as_ptr(),
a_size as u64,
a_sl as u64,
bv.as_ptr(),
b_size as u64,
b_sl as u64,
)
}
}
fn ffi_binary_op_factory_type_0(
module_ptr: *const MODULE,
b_size: usize,
b_sl: usize,
a_size: usize,
a_sl: usize,
op_fn: unsafe extern "C" fn(*const MODULE, *mut i64, u64, u64, *const i64, u64, u64),
) -> impl Fn(&mut [i64], &[i64]) {
move |bv: &mut [i64], av: &[i64]| unsafe {
op_fn(
module_ptr,
bv.as_mut_ptr(),
b_size as u64,
b_sl as u64,
av.as_ptr(),
a_size as u64,
a_sl as u64,
)
}
}
fn ffi_binary_op_factory_type_1(
module_ptr: *const MODULE,
k: i64,
b_size: usize,
b_sl: usize,
a_size: usize,
a_sl: usize,
op_fn: unsafe extern "C" fn(*const MODULE, i64, *mut i64, u64, u64, *const i64, u64, u64),
) -> impl Fn(&mut [i64], &[i64]) {
move |bv: &mut [i64], av: &[i64]| unsafe {
op_fn(
module_ptr,
k,
bv.as_mut_ptr(),
b_size as u64,
b_sl as u64,
av.as_ptr(),
a_size as u64,
a_sl as u64,
)
}
}
#[inline(always)]
pub fn vec_znx_apply_binary_op<B: Backend, const NEGATE: bool>(
module: &Module<B>,
c: &mut VecZnx,
a: &VecZnx,
b: &VecZnx,
op: impl Fn(&mut [i64], &[i64], &[i64]),
) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
assert_eq!(c.n(), module.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
let a_cols: usize = a.cols();
let b_cols: usize = b.cols();
let c_cols: usize = c.cols();
let min_ab_cols: usize = min(a_cols, b_cols);
let min_cols: usize = min(c_cols, min_ab_cols);
// Applies over shared cols between (a, b, c)
(0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0)));
// Copies/Negates/Zeroes the remaining cols if op is not inplace.
if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() {
znx_post_process_ternary_op::<VecZnx, NEGATE>(c, a, b);
}
}
#[inline(always)]
pub fn vec_znx_apply_unary_op<B: Backend>(module: &Module<B>, b: &mut VecZnx, a: &VecZnx, op: impl Fn(&mut [i64], &[i64])) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
}
let a_cols: usize = a.cols();
let b_cols: usize = b.cols();
let min_cols: usize = min(a_cols, b_cols);
// Applies over the shared cols between (a, b)
(0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0)));
// Zeroes the remaining cols of b.
(min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j)));
}
#[cfg(test)]
mod tests {
use crate::{
Backend, FFT64, Module, Sampling, VecZnx, VecZnxOps, ZnxBasics, ZnxInfos, ZnxLayout, ffi::vec_znx,
znx_post_process_ternary_op,
};
use itertools::izip;
use sampling::source::Source;
use std::cmp::min;
#[test]
fn vec_znx_add() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let op = |cv: &mut [i64], av: &[i64], bv: &[i64]| {
izip!(cv.iter_mut(), bv.iter(), av.iter()).for_each(|(ci, bi, ai)| *ci = *bi + *ai);
};
test_binary_op::<false, _>(
&module,
&|c: &mut VecZnx, a: &VecZnx, b: &VecZnx| module.vec_znx_add(c, a, b),
op,
);
}
#[test]
fn vec_znx_add_inplace() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let op = |bv: &mut [i64], av: &[i64]| {
izip!(bv.iter_mut(), av.iter()).for_each(|(bi, ai)| *bi = *bi + *ai);
};
test_binary_op_inplace::<false, _>(
&module,
&|b: &mut VecZnx, a: &VecZnx| module.vec_znx_add_inplace(b, a),
op,
);
}
#[test]
fn vec_znx_sub() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let op = |cv: &mut [i64], av: &[i64], bv: &[i64]| {
izip!(cv.iter_mut(), bv.iter(), av.iter()).for_each(|(ci, bi, ai)| *ci = *bi - *ai);
};
test_binary_op::<true, _>(
&module,
&|c: &mut VecZnx, a: &VecZnx, b: &VecZnx| module.vec_znx_sub(c, a, b),
op,
);
}
#[test]
fn vec_znx_sub_ab_inplace() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let op = |bv: &mut [i64], av: &[i64]| {
izip!(bv.iter_mut(), av.iter()).for_each(|(bi, ai)| *bi = *ai - *bi);
};
test_binary_op_inplace::<true, _>(
&module,
&|b: &mut VecZnx, a: &VecZnx| module.vec_znx_sub_ab_inplace(b, a),
op,
);
}
#[test]
fn vec_znx_sub_ba_inplace() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let op = |bv: &mut [i64], av: &[i64]| {
izip!(bv.iter_mut(), av.iter()).for_each(|(bi, ai)| *bi = *bi - *ai);
};
test_binary_op_inplace::<false, _>(
&module,
&|b: &mut VecZnx, a: &VecZnx| module.vec_znx_sub_ba_inplace(b, a),
op,
);
}
#[test]
fn vec_znx_negate() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let op = |b: &mut [i64], a: &[i64]| {
izip!(b.iter_mut(), a.iter()).for_each(|(bi, ai)| *bi = -*ai);
};
test_unary_op(
&module,
|b: &mut VecZnx, a: &VecZnx| module.vec_znx_negate(b, a),
op,
)
}
#[test]
fn vec_znx_negate_inplace() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let op = |a: &mut [i64]| a.iter_mut().for_each(|xi| *xi = -*xi);
test_unary_op_inplace(
&module,
|a: &mut VecZnx| module.vec_znx_negate_inplace(a),
op,
)
}
#[test]
fn vec_znx_rotate() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let k: i64 = 53;
let op = |b: &mut [i64], a: &[i64]| {
assert_eq!(b.len(), a.len());
b.copy_from_slice(a);
let mut k_mod2n: i64 = k % (2 * n as i64);
if k_mod2n < 0 {
k_mod2n += 2 * n as i64;
}
let sign: i64 = (k_mod2n.abs() / (n as i64)) & 1;
let k_modn: i64 = k_mod2n % (n as i64);
b.rotate_right(k_modn as usize);
b[0..k_modn as usize].iter_mut().for_each(|x| *x = -*x);
if sign == 1 {
b.iter_mut().for_each(|x| *x = -*x);
}
};
test_unary_op(
&module,
|b: &mut VecZnx, a: &VecZnx| module.vec_znx_rotate(k, b, a),
op,
)
}
#[test]
fn vec_znx_rotate_inplace() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let k: i64 = 53;
let rot = |a: &mut [i64]| {
let mut k_mod2n: i64 = k % (2 * n as i64);
if k_mod2n < 0 {
k_mod2n += 2 * n as i64;
}
let sign: i64 = (k_mod2n.abs() / (n as i64)) & 1;
let k_modn: i64 = k_mod2n % (n as i64);
a.rotate_right(k_modn as usize);
a[0..k_modn as usize].iter_mut().for_each(|x| *x = -*x);
if sign == 1 {
a.iter_mut().for_each(|x| *x = -*x);
}
};
test_unary_op_inplace(
&module,
|a: &mut VecZnx| module.vec_znx_rotate_inplace(k, a),
rot,
)
}
#[test]
fn vec_znx_automorphism() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let k: i64 = -5;
let op = |b: &mut [i64], a: &[i64]| {
assert_eq!(b.len(), a.len());
unsafe {
vec_znx::vec_znx_automorphism(
module.ptr,
k,
b.as_mut_ptr(),
1u64,
n as u64,
a.as_ptr(),
1u64,
n as u64,
);
}
};
test_unary_op(
&module,
|b: &mut VecZnx, a: &VecZnx| module.vec_znx_automorphism(k, b, a),
op,
)
}
#[test]
fn vec_znx_automorphism_inplace() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let k: i64 = -5;
let op = |a: &mut [i64]| unsafe {
vec_znx::vec_znx_automorphism(
module.ptr,
k,
a.as_mut_ptr(),
1u64,
n as u64,
a.as_ptr(),
1u64,
n as u64,
);
};
test_unary_op_inplace(
&module,
|a: &mut VecZnx| module.vec_znx_automorphism_inplace(k, a),
op,
)
}
fn test_binary_op<const NEGATE: bool, B: Backend>(
module: &Module<B>,
func_have: impl Fn(&mut VecZnx, &VecZnx, &VecZnx),
func_want: impl Fn(&mut [i64], &[i64], &[i64]),
) {
let a_size: usize = 3;
let b_size: usize = 4;
let c_size: usize = 5;
let mut source: Source = Source::new([0u8; 32]);
[1usize, 2, 3].iter().for_each(|a_cols| {
[1usize, 2, 3].iter().for_each(|b_cols| {
[1usize, 2, 3].iter().for_each(|c_cols| {
let min_ab_cols: usize = min(*a_cols, *b_cols);
let min_cols: usize = min(*c_cols, min_ab_cols);
let min_size: usize = min(c_size, min(a_size, b_size));
let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size);
(0..*a_cols).for_each(|i| {
module.fill_uniform(3, &mut a, i, a_size, &mut source);
});
let mut b: VecZnx = module.new_vec_znx(*b_cols, b_size);
(0..*b_cols).for_each(|i| {
module.fill_uniform(3, &mut b, i, b_size, &mut source);
});
let mut c_have: VecZnx = module.new_vec_znx(*c_cols, c_size);
(0..c_have.cols()).for_each(|i| {
module.fill_uniform(3, &mut c_have, i, c_size, &mut source);
});
func_have(&mut c_have, &a, &b);
let mut c_want: VecZnx = module.new_vec_znx(*c_cols, c_size);
// Adds with the minimum matching columns
(0..min_cols).for_each(|i| {
// Adds with th eminimum matching size
(0..min_size).for_each(|j| {
func_want(c_want.at_poly_mut(i, j), b.at_poly(i, j), a.at_poly(i, j));
});
if a_size > b_size {
// Copies remaining size of lh if lh.size() > rh.size()
(min_size..a_size).for_each(|j| {
izip!(c_want.at_poly_mut(i, j).iter_mut(), a.at_poly(i, j).iter()).for_each(|(ci, ai)| *ci = *ai);
if NEGATE {
c_want.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
}
});
} else {
// Copies the remaining size of rh if the are greater
(min_size..b_size).for_each(|j| {
izip!(c_want.at_poly_mut(i, j).iter_mut(), b.at_poly(i, j).iter()).for_each(|(ci, bi)| *ci = *bi);
if NEGATE {
c_want.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
}
});
}
});
znx_post_process_ternary_op::<_, NEGATE>(&mut c_want, &a, &b);
assert_eq!(c_have.raw(), c_want.raw());
});
});
});
}
fn test_binary_op_inplace<const NEGATE: bool, B: Backend>(
module: &Module<B>,
func_have: impl Fn(&mut VecZnx, &VecZnx),
func_want: impl Fn(&mut [i64], &[i64]),
) {
let a_size: usize = 3;
let b_size: usize = 5;
let mut source = Source::new([0u8; 32]);
[1usize, 2, 3].iter().for_each(|a_cols| {
[1usize, 2, 3].iter().for_each(|b_cols| {
let min_cols: usize = min(*b_cols, *a_cols);
let min_size: usize = min(b_size, a_size);
let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size);
(0..*a_cols).for_each(|i| {
module.fill_uniform(3, &mut a, i, a_size, &mut source);
});
let mut b_have: VecZnx = module.new_vec_znx(*b_cols, b_size);
(0..*b_cols).for_each(|i| {
module.fill_uniform(3, &mut b_have, i, b_size, &mut source);
});
let mut b_want: VecZnx = module.new_vec_znx(*b_cols, b_size);
b_want.raw_mut().copy_from_slice(b_have.raw());
func_have(&mut b_have, &a);
// Applies with the minimum matching columns
(0..min_cols).for_each(|i| {
// Adds with th eminimum matching size
(0..min_size).for_each(|j| func_want(b_want.at_poly_mut(i, j), a.at_poly(i, j)));
if NEGATE {
(min_size..b_size).for_each(|j| {
b_want.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
});
}
});
assert_eq!(b_have.raw(), b_want.raw());
});
});
}
fn test_unary_op<B: Backend>(
module: &Module<B>,
func_have: impl Fn(&mut VecZnx, &VecZnx),
func_want: impl Fn(&mut [i64], &[i64]),
) {
let a_size: usize = 3;
let b_size: usize = 5;
let mut source = Source::new([0u8; 32]);
[1usize, 2, 3].iter().for_each(|a_cols| {
[1usize, 2, 3].iter().for_each(|b_cols| {
let min_cols: usize = min(*b_cols, *a_cols);
let min_size: usize = min(b_size, a_size);
let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size);
(0..a.cols()).for_each(|i| {
module.fill_uniform(3, &mut a, i, a_size, &mut source);
});
let mut b_have: VecZnx = module.new_vec_znx(*b_cols, b_size);
(0..b_have.cols()).for_each(|i| {
module.fill_uniform(3, &mut b_have, i, b_size, &mut source);
});
let mut b_want: VecZnx = module.new_vec_znx(*b_cols, b_size);
func_have(&mut b_have, &a);
// Applies on the minimum matching columns
(0..min_cols).for_each(|i| {
// Applies on the minimum matching size
(0..min_size).for_each(|j| func_want(b_want.at_poly_mut(i, j), a.at_poly(i, j)));
// Zeroes the unmatching size
(min_size..b_size).for_each(|j| {
b_want.zero_at(i, j);
})
});
// Zeroes the unmatching columns
(min_cols..*b_cols).for_each(|i| {
(0..b_size).for_each(|j| {
b_want.zero_at(i, j);
})
});
assert_eq!(b_have.raw(), b_want.raw());
});
});
}
fn test_unary_op_inplace<B: Backend>(module: &Module<B>, func_have: impl Fn(&mut VecZnx), func_want: impl Fn(&mut [i64])) {
let a_size: usize = 3;
let mut source = Source::new([0u8; 32]);
[1usize, 2, 3].iter().for_each(|a_cols| {
let mut a_have: VecZnx = module.new_vec_znx(*a_cols, a_size);
(0..*a_cols).for_each(|i| {
module.fill_uniform(3, &mut a_have, i, a_size, &mut source);
});
let mut a_want: VecZnx = module.new_vec_znx(*a_cols, a_size);
a_have.raw_mut().copy_from_slice(a_want.raw());
func_have(&mut a_have);
// Applies on the minimum matching columns
(0..*a_cols).for_each(|i| {
// Applies on the minimum matching size
(0..a_size).for_each(|j| func_want(a_want.at_poly_mut(i, j)));
});
assert_eq!(a_have.raw(), a_want.raw());
});
}
}