refactoring of vec_znx
(mirror of https://github.com/arnaucube/poulpy.git)
@@ -35,7 +35,7 @@ fn main() {
     module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source);

     // Scratch space for DFT values
-    let mut buf_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, a.limbs());
+    let mut buf_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, a.size());

     // Applies buf_dft <- s * a
     module.svp_apply_dft(&mut buf_dft, &s_ppol, &a);
@@ -93,9 +93,9 @@ fn main() {

     // have = m * 2^{log_scale} + e
     let mut have: Vec<i64> = vec![i64::default(); n];
-    res.decode_vec_i64(0, log_base2k, res.limbs() * log_base2k, &mut have);
+    res.decode_vec_i64(0, log_base2k, res.size() * log_base2k, &mut have);

-    let scale: f64 = (1 << (res.limbs() * log_base2k - log_scale)) as f64;
+    let scale: f64 = (1 << (res.size() * log_base2k - log_scale)) as f64;
     izip!(want.iter(), have.iter())
         .enumerate()
         .for_each(|(i, (a, b))| {
@@ -33,7 +33,7 @@ fn main() {

     let mut mat_znx_dft: MatZnxDft<FFT64> = module.new_mat_znx_dft(rows_mat, 1, limbs_mat);

-    (0..a.limbs()).for_each(|row_i| {
+    (0..a.size()).for_each(|row_i| {
         let mut tmp: VecZnx = module.new_vec_znx(1, limbs_mat);
         tmp.at_limb_mut(row_i)[1] = 1 as i64;
         module.vmp_prepare_row(&mut mat_znx_dft, tmp.raw(), row_i, &mut buf);
@@ -1,11 +1,15 @@
-use crate::{Backend, Module};
+use crate::{Backend, Module, assert_alignement, cast_mut};
+use itertools::izip;
+use std::cmp::{max, min};

 pub trait ZnxInfos {
     /// Returns the ring degree of the polynomials.
     fn n(&self) -> usize;

     /// Returns the base two logarithm of the ring dimension of the polynomials.
-    fn log_n(&self) -> usize;
+    fn log_n(&self) -> usize {
+        (usize::BITS - (self.n() - 1).leading_zeros()) as _
+    }

     /// Returns the number of rows.
     fn rows(&self) -> usize;
@@ -13,21 +17,28 @@ pub trait ZnxInfos {
     /// Returns the number of polynomials in each row.
     fn cols(&self) -> usize;

-    /// Returns the number of limbs per polynomial.
-    fn limbs(&self) -> usize;
+    /// Returns the size (number of limbs) of each polynomial.
+    fn size(&self) -> usize;

     /// Returns the total number of small polynomials.
-    fn poly_count(&self) -> usize;
+    fn poly_count(&self) -> usize {
+        self.rows() * self.cols() * self.size()
+    }

+    /// Returns the slice size, which is the offset between
+    /// two consecutive limbs of the same column.
+    fn sl(&self) -> usize {
+        self.n() * self.cols()
+    }
 }

 pub trait ZnxBase<B: Backend> {
     type Scalar;
-    fn new(module: &Module<B>, cols: usize, limbs: usize) -> Self;
-    fn from_bytes(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self;
-    fn from_bytes_borrow(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self;
-    fn bytes_of(module: &Module<B>, cols: usize, limbs: usize) -> usize;
+    fn new(module: &Module<B>, cols: usize, size: usize) -> Self;
+    fn from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self;
+    fn from_bytes_borrow(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self;
+    fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize;
 }

 pub trait ZnxLayout: ZnxInfos {
     type Scalar;
@@ -52,7 +63,7 @@ pub trait ZnxLayout: ZnxInfos {
         #[cfg(debug_assertions)]
         {
             assert!(i < self.cols());
-            assert!(j < self.limbs());
+            assert!(j < self.size());
         }
         let offset = self.n() * (j * self.cols() + i);
         unsafe { self.as_ptr().add(offset) }
@@ -63,7 +74,7 @@ pub trait ZnxLayout: ZnxInfos {
         #[cfg(debug_assertions)]
        {
             assert!(i < self.cols());
-            assert!(j < self.limbs());
+            assert!(j < self.size());
         }
         let offset = self.n() * (j * self.cols() + i);
         unsafe { self.as_mut_ptr().add(offset) }
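For orientation, here is a minimal sketch (an editor's illustration, not part of the commit) of how the refactored accessors and the new default methods compose, assuming a hypothetical struct that stores the new `size` field directly:

    // Hypothetical example type; only the four required accessors are implemented.
    // log_n(), poly_count() and sl() come from the default bodies shown above.
    struct Toy {
        n: usize,
        rows: usize,
        cols: usize,
        size: usize,
    }

    impl ZnxInfos for Toy {
        fn n(&self) -> usize { self.n }
        fn rows(&self) -> usize { self.rows }
        fn cols(&self) -> usize { self.cols }
        fn size(&self) -> usize { self.size }
    }

    // With n = 8, rows = 1, cols = 2, size = 3:
    //   log_n()      = 3   (usize::BITS - (8 - 1).leading_zeros())
    //   poly_count() = 6   (rows * cols * size)
    //   sl()         = 16  (n * cols: offset between two consecutive limbs of a column)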
@@ -89,3 +100,195 @@ pub trait ZnxLayout: ZnxInfos {
         unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, j), self.n() * self.cols()) }
     }
 }
+
+use std::convert::TryFrom;
+use std::num::TryFromIntError;
+use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub};
+pub trait IntegerType:
+    Copy
+    + std::fmt::Debug
+    + Default
+    + PartialEq
+    + PartialOrd
+    + Add<Output = Self>
+    + Sub<Output = Self>
+    + Mul<Output = Self>
+    + Div<Output = Self>
+    + Neg<Output = Self>
+    + Shr<Output = Self>
+    + Shl<Output = Self>
+    + AddAssign
+    + TryFrom<usize, Error = TryFromIntError>
+{
+    const BITS: u32;
+}
+
+impl IntegerType for i64 {
+    const BITS: u32 = 64;
+}
+
+impl IntegerType for i128 {
+    const BITS: u32 = 128;
+}
+
+pub trait ZnxBasics: ZnxLayout
+where
+    Self: Sized,
+    Self::Scalar: IntegerType,
+{
+    fn zero(&mut self) {
+        unsafe {
+            std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * size_of::<Self::Scalar>());
+        }
+    }
+
+    fn zero_at(&mut self, i: usize, j: usize) {
+        unsafe {
+            std::ptr::write_bytes(
+                self.at_mut_ptr(i, j),
+                0,
+                self.n() * size_of::<Self::Scalar>(),
+            );
+        }
+    }
+
+    fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) {
+        rsh(log_base2k, self, k, carry)
+    }
+}
+
+pub fn rsh<V: ZnxBasics>(log_base2k: usize, a: &mut V, k: usize, tmp_bytes: &mut [u8])
+where
+    V::Scalar: IntegerType,
+{
+    let n: usize = a.n();
+    let size: usize = a.size();
+    let cols: usize = a.cols();
+
+    #[cfg(debug_assertions)]
+    {
+        assert!(
+            tmp_bytes.len() >= rsh_tmp_bytes::<V::Scalar>(n, cols),
+            "invalid carry: carry.len()/size_of::<V::Scalar>()={} < rsh_tmp_bytes({}, {})",
+            tmp_bytes.len() / size_of::<V::Scalar>(),
+            n,
+            size,
+        );
+        assert_alignement(tmp_bytes.as_ptr());
+    }
+
+    let size: usize = a.size();
+    let steps: usize = k / log_base2k;
+
+    a.raw_mut().rotate_right(n * steps * cols);
+    (0..cols).for_each(|i| {
+        (0..steps).for_each(|j| {
+            a.zero_at(i, j);
+        })
+    });
+
+    let k_rem: usize = k % log_base2k;
+
+    if k_rem != 0 {
+        let carry: &mut [V::Scalar] = cast_mut(tmp_bytes);
+
+        unsafe {
+            std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::<V::Scalar>());
+        }
+
+        let log_base2k_t: V::Scalar = V::Scalar::try_from(log_base2k).unwrap();
+        let shift: V::Scalar = V::Scalar::try_from(V::Scalar::BITS as usize - k_rem).unwrap();
+        let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap();
+
+        (steps..size).for_each(|i| {
+            izip!(carry.iter_mut(), a.at_limb_mut(i).iter_mut()).for_each(|(ci, xi)| {
+                *xi += *ci << log_base2k_t;
+                *ci = get_base_k_carry(*xi, shift);
+                *xi = (*xi - *ci) >> k_rem_t;
+            });
+        })
+    }
+}
+
+#[inline(always)]
+fn get_base_k_carry<T: IntegerType>(x: T, shift: T) -> T {
+    (x << shift) >> shift
+}
+
+pub fn rsh_tmp_bytes<T: IntegerType>(n: usize, cols: usize) -> usize {
+    n * cols * std::mem::size_of::<T>()
+}
+
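A quick worked example (editor's illustration, not part of the commit) of the carry extraction used by `rsh` above: with arithmetic shifts, `(x << shift) >> shift` where `shift = BITS - k_rem` sign-extends the low `k_rem` bits of `x`, so subtracting the carry leaves a value that is exactly divisible by 2^{k_rem}:

    // Same formula as get_base_k_carry, specialized to i64 for illustration.
    fn carry_i64(x: i64, k_rem: u32) -> i64 {
        (x << (64 - k_rem)) >> (64 - k_rem)
    }

    fn main() {
        let (x, k_rem) = (-13i64, 3u32);
        let c = carry_i64(x, k_rem);      // c = 3, since -13 = -2 * 8 + 3
        assert_eq!((x - c) >> k_rem, -2); // the remaining shift is exact
    }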
+pub fn switch_degree<T: ZnxLayout + ZnxBasics>(b: &mut T, a: &T)
+where
+    <T as ZnxLayout>::Scalar: IntegerType,
+{
+    let (n_in, n_out) = (a.n(), b.n());
+    let (gap_in, gap_out): (usize, usize);
+
+    if n_in > n_out {
+        (gap_in, gap_out) = (n_in / n_out, 1)
+    } else {
+        (gap_in, gap_out) = (1, n_out / n_in);
+        b.zero();
+    }
+
+    let size: usize = min(a.size(), b.size());
+
+    (0..size).for_each(|i| {
+        izip!(
+            a.at_limb(i).iter().step_by(gap_in),
+            b.at_limb_mut(i).iter_mut().step_by(gap_out)
+        )
+        .for_each(|(x_in, x_out)| *x_out = *x_in);
+    });
+}
+
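To illustrate the degree switch above (editor's note, hypothetical values): when the source ring is larger, every `gap_in`-th coefficient is kept; when it is smaller, coefficients are spread out by `gap_out` into a zeroed target.

    // n_in = 8, n_out = 4 (gap_in = 2):  [a0 a1 a2 a3 a4 a5 a6 a7] -> [a0 a2 a4 a6]
    // n_in = 4, n_out = 8 (gap_out = 2, b zeroed first): [a0 a1 a2 a3] -> [a0 0 a1 0 a2 0 a3 0]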
+pub fn znx_post_process_ternary_op<T: ZnxInfos + ZnxLayout + ZnxBasics, const NEGATE: bool>(c: &mut T, a: &T, b: &T)
+where
+    <T as ZnxLayout>::Scalar: IntegerType,
+{
+    #[cfg(debug_assertions)]
+    {
+        assert_ne!(a.as_ptr(), b.as_ptr());
+        assert_ne!(b.as_ptr(), c.as_ptr());
+        assert_ne!(a.as_ptr(), c.as_ptr());
+    }
+
+    let a_cols: usize = a.cols();
+    let b_cols: usize = b.cols();
+    let c_cols: usize = c.cols();
+
+    let min_ab_cols: usize = min(a_cols, b_cols);
+    let max_ab_cols: usize = max(a_cols, b_cols);
+
+    // Copies shared cols between (c, max(a, b))
+    if a_cols != b_cols {
+        let mut x: &T = a;
+        if a_cols < b_cols {
+            x = b;
+        }
+
+        let min_size = min(c.size(), x.size());
+        (min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| {
+            (0..min_size).for_each(|j| {
+                c.at_poly_mut(i, j).copy_from_slice(x.at_poly(i, j));
+                if NEGATE {
+                    c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
+                }
+            });
+            (min_size..c.size()).for_each(|j| {
+                c.zero_at(i, j);
+            });
+        });
+    }
+
+    // Zeroes the cols of c > max(a, b).
+    if c_cols > max_ab_cols {
+        (max_ab_cols..c_cols).for_each(|i| {
+            (0..c.size()).for_each(|j| {
+                c.zero_at(i, j);
+            })
+        });
+    }
+}
+
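A small worked example of the column post-processing above (editor's illustration, not from the commit): suppose `c = a - b` with NEGATE = true, `a.cols() = 1`, `b.cols() = 3` and `c.cols() = 4`. The column shared by a and b (col 0) is handled by the backend subtraction itself; the columns present only in b (cols 1 and 2) are copied into c and negated, with any extra limbs of c zeroed; the remaining column of c (col 3) is zeroed entirely.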
@@ -81,15 +81,15 @@ impl Encoding for VecZnx {
|
||||
}
|
||||
|
||||
fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) {
|
||||
let limbs: usize = (log_k + log_base2k - 1) / log_base2k;
|
||||
let size: usize = (log_k + log_base2k - 1) / log_base2k;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
limbs <= a.limbs(),
|
||||
"invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.limbs()={}",
|
||||
limbs,
|
||||
a.limbs()
|
||||
size <= a.size(),
|
||||
"invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.size()={}",
|
||||
size,
|
||||
a.size()
|
||||
);
|
||||
assert!(col_i < a.cols());
|
||||
assert!(data.len() <= a.n())
|
||||
@@ -99,7 +99,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize,
|
||||
let log_k_rem: usize = log_base2k - (log_k % log_base2k);
|
||||
|
||||
// Zeroes coefficients of the i-th column
|
||||
(0..a.limbs()).for_each(|i| unsafe {
|
||||
(0..a.size()).for_each(|i| unsafe {
|
||||
znx_zero_i64_ref(a.n() as u64, a.at_mut_ptr(col_i, i));
|
||||
});
|
||||
|
||||
@@ -107,11 +107,11 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize,
|
||||
// values on the last limb.
|
||||
// Else we decompose values base2k.
|
||||
if log_max + log_k_rem < 63 || log_k_rem == log_base2k {
|
||||
a.at_poly_mut(col_i, limbs - 1)[..data_len].copy_from_slice(&data[..data_len]);
|
||||
a.at_poly_mut(col_i, size - 1)[..data_len].copy_from_slice(&data[..data_len]);
|
||||
} else {
|
||||
let mask: i64 = (1 << log_base2k) - 1;
|
||||
let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k);
|
||||
(limbs - steps..limbs)
|
||||
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
|
||||
(size - steps..size)
|
||||
.rev()
|
||||
.enumerate()
|
||||
.for_each(|(i, i_rev)| {
|
||||
@@ -122,8 +122,8 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize,
|
||||
|
||||
// Case where self.prec % self.k != 0.
|
||||
if log_k_rem != log_base2k {
|
||||
let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k);
|
||||
(limbs - steps..limbs).rev().for_each(|i| {
|
||||
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
|
||||
(size - steps..size).rev().for_each(|i| {
|
||||
a.at_poly_mut(col_i, i)[..data_len]
|
||||
.iter_mut()
|
||||
.for_each(|x| *x <<= log_k_rem);
|
||||
@@ -132,7 +132,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize,
|
||||
}
|
||||
|
||||
fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) {
|
||||
let limbs: usize = (log_k + log_base2k - 1) / log_base2k;
|
||||
let size: usize = (log_k + log_base2k - 1) / log_base2k;
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
@@ -145,8 +145,8 @@ fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, dat
|
||||
}
|
||||
data.copy_from_slice(a.at_poly(col_i, 0));
|
||||
let rem: usize = log_base2k - (log_k % log_base2k);
|
||||
(1..limbs).for_each(|i| {
|
||||
if i == limbs - 1 && rem != log_base2k {
|
||||
(1..size).for_each(|i| {
|
||||
if i == size - 1 && rem != log_base2k {
|
||||
let k_rem: usize = log_base2k - rem;
|
||||
izip!(a.at_poly(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
*y = (*y << k_rem) + (x >> rem);
|
||||
@@ -160,7 +160,7 @@ fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, dat
|
||||
}
|
||||
|
||||
fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Float]) {
|
||||
let limbs: usize = a.limbs();
|
||||
let size: usize = a.size();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
@@ -172,20 +172,20 @@ fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Flo
|
||||
assert!(col_i < a.cols());
|
||||
}
|
||||
|
||||
let prec: u32 = (log_base2k * limbs) as u32;
|
||||
let prec: u32 = (log_base2k * size) as u32;
|
||||
|
||||
// 2^{log_base2k}
|
||||
let base = Float::with_val(prec, (1 << log_base2k) as f64);
|
||||
|
||||
// y[i] = sum x[j][i] * 2^{-log_base2k*j}
|
||||
(0..limbs).for_each(|i| {
|
||||
(0..size).for_each(|i| {
|
||||
if i == 0 {
|
||||
izip!(a.at_poly(col_i, limbs - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
izip!(a.at_poly(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
y.assign(*x);
|
||||
*y /= &base;
|
||||
});
|
||||
} else {
|
||||
izip!(a.at_poly(col_i, limbs - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
izip!(a.at_poly(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
*y += Float::with_val(prec, *x);
|
||||
*y /= &base;
|
||||
});
|
||||
@@ -194,32 +194,32 @@ fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Flo
|
||||
}
|
||||
|
||||
fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) {
|
||||
let limbs: usize = (log_k + log_base2k - 1) / log_base2k;
|
||||
let size: usize = (log_k + log_base2k - 1) / log_base2k;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(i < a.n());
|
||||
assert!(
|
||||
limbs <= a.limbs(),
|
||||
"invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.limbs()={}",
|
||||
limbs,
|
||||
a.limbs()
|
||||
size <= a.size(),
|
||||
"invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.size()={}",
|
||||
size,
|
||||
a.size()
|
||||
);
|
||||
assert!(col_i < a.cols());
|
||||
}
|
||||
|
||||
let log_k_rem: usize = log_base2k - (log_k % log_base2k);
|
||||
(0..a.limbs()).for_each(|j| a.at_poly_mut(col_i, j)[i] = 0);
|
||||
(0..a.size()).for_each(|j| a.at_poly_mut(col_i, j)[i] = 0);
|
||||
|
||||
// If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy
|
||||
// values on the last limb.
|
||||
// Else we decompose values base2k.
|
||||
if log_max + log_k_rem < 63 || log_k_rem == log_base2k {
|
||||
a.at_poly_mut(col_i, limbs - 1)[i] = value;
|
||||
a.at_poly_mut(col_i, size - 1)[i] = value;
|
||||
} else {
|
||||
let mask: i64 = (1 << log_base2k) - 1;
|
||||
let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k);
|
||||
(limbs - steps..limbs)
|
||||
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
|
||||
(size - steps..size)
|
||||
.rev()
|
||||
.enumerate()
|
||||
.for_each(|(j, j_rev)| {
|
||||
@@ -229,8 +229,8 @@ fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usiz
|
||||
|
||||
// Case where prec % k != 0.
|
||||
if log_k_rem != log_base2k {
|
||||
let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k);
|
||||
(limbs - steps..limbs).rev().for_each(|j| {
|
||||
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
|
||||
(size - steps..size).rev().for_each(|j| {
|
||||
a.at_poly_mut(col_i, j)[i] <<= log_k_rem;
|
||||
})
|
||||
}
|
||||
@@ -247,7 +247,7 @@ fn decode_coeff_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i
|
||||
let data: &[i64] = a.raw();
|
||||
let mut res: i64 = data[i];
|
||||
let rem: usize = log_base2k - (log_k % log_base2k);
|
||||
let slice_size: usize = a.n() * a.limbs();
|
||||
let slice_size: usize = a.n() * a.size();
|
||||
(1..cols).for_each(|i| {
|
||||
let x = data[i * slice_size];
|
||||
if i == cols - 1 && rem != log_base2k {
|
||||
@@ -271,9 +271,9 @@ mod tests {
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let log_base2k: usize = 17;
|
||||
let limbs: usize = 5;
|
||||
let log_k: usize = limbs * log_base2k - 5;
|
||||
let mut a: VecZnx = VecZnx::new(&module, 2, limbs);
|
||||
let size: usize = 5;
|
||||
let log_k: usize = size * log_base2k - 5;
|
||||
let mut a: VecZnx = VecZnx::new(&module, 2, size);
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let raw: &mut [i64] = a.raw_mut();
|
||||
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
@@ -293,9 +293,9 @@ mod tests {
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let log_base2k: usize = 17;
|
||||
let limbs: usize = 5;
|
||||
let log_k: usize = limbs * log_base2k - 5;
|
||||
let mut a: VecZnx = VecZnx::new(&module, 2, limbs);
|
||||
let size: usize = 5;
|
||||
let log_k: usize = size * log_base2k - 5;
|
||||
let mut a: VecZnx = VecZnx::new(&module, 2, size);
|
||||
let mut source = Source::new([0u8; 32]);
|
||||
let raw: &mut [i64] = a.raw_mut();
|
||||
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
|
||||
@@ -11,6 +11,7 @@ pub mod stats;
|
||||
pub mod vec_znx;
|
||||
pub mod vec_znx_big;
|
||||
pub mod vec_znx_dft;
|
||||
pub mod vec_znx_ops;
|
||||
|
||||
pub use commons::*;
|
||||
pub use encoding::*;
|
||||
@@ -23,6 +24,7 @@ pub use stats::*;
|
||||
pub use vec_znx::*;
|
||||
pub use vec_znx_big::*;
|
||||
pub use vec_znx_dft::*;
|
||||
pub use vec_znx_ops::*;
|
||||
|
||||
pub const GALOISGENERATOR: u64 = 5;
|
||||
pub const DEFAULTALIGN: usize = 64;
|
||||
|
||||
@@ -22,7 +22,7 @@ pub struct MatZnxDft<B: Backend> {
|
||||
/// Number of cols
|
||||
cols: usize,
|
||||
/// The number of small polynomials
|
||||
limbs: usize,
|
||||
size: usize,
|
||||
_marker: PhantomData<B>,
|
||||
}
|
||||
|
||||
@@ -31,10 +31,6 @@ impl<B: Backend> ZnxInfos for MatZnxDft<B> {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn log_n(&self) -> usize {
|
||||
(usize::BITS - (self.n() - 1).leading_zeros()) as _
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
self.rows
|
||||
}
|
||||
@@ -43,18 +39,14 @@ impl<B: Backend> ZnxInfos for MatZnxDft<B> {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn limbs(&self) -> usize {
|
||||
self.limbs
|
||||
}
|
||||
|
||||
fn poly_count(&self) -> usize {
|
||||
self.rows * self.cols * self.limbs
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl MatZnxDft<FFT64> {
|
||||
fn new(module: &Module<FFT64>, rows: usize, cols: usize, limbs: usize) -> MatZnxDft<FFT64> {
|
||||
let mut data: Vec<u8> = alloc_aligned::<u8>(module.bytes_of_mat_znx_dft(rows, cols, limbs));
|
||||
fn new(module: &Module<FFT64>, rows: usize, cols: usize, size: usize) -> MatZnxDft<FFT64> {
|
||||
let mut data: Vec<u8> = alloc_aligned::<u8>(module.bytes_of_mat_znx_dft(rows, cols, size));
|
||||
let ptr: *mut u8 = data.as_mut_ptr();
|
||||
MatZnxDft::<FFT64> {
|
||||
data: data,
|
||||
@@ -62,7 +54,7 @@ impl MatZnxDft<FFT64> {
|
||||
n: module.n(),
|
||||
rows: rows,
|
||||
cols: cols,
|
||||
limbs: limbs,
|
||||
size: size,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
@@ -115,7 +107,7 @@ impl MatZnxDft<FFT64> {
|
||||
|
||||
fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] {
|
||||
let nrows: usize = self.rows();
|
||||
let nsize: usize = self.limbs();
|
||||
let nsize: usize = self.size();
|
||||
if col == (nsize - 1) && (nsize & 1 == 1) {
|
||||
&self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..]
|
||||
} else {
|
||||
@@ -127,7 +119,7 @@ impl MatZnxDft<FFT64> {
|
||||
/// This trait implements methods for vector matrix product,
|
||||
/// that is, multiplying a [VecZnx] with a [VmpPMat].
|
||||
pub trait MatZnxDftOps<B: Backend> {
|
||||
fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> usize;
|
||||
fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize;
|
||||
|
||||
/// Allocates a new [VmpPMat] with the given number of rows and columns.
|
||||
///
|
||||
@@ -135,7 +127,7 @@ pub trait MatZnxDftOps<B: Backend> {
|
||||
///
|
||||
/// * `rows`: number of rows (number of [VecZnxDft]).
|
||||
/// * `size`: the size (number of limbs) of each [VecZnxDft].
|
||||
fn new_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> MatZnxDft<B>;
|
||||
fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft<B>;
|
||||
|
||||
/// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous].
|
||||
///
|
||||
@@ -351,12 +343,12 @@ pub trait MatZnxDftOps<B: Backend> {
|
||||
}
|
||||
|
||||
impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
fn new_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> MatZnxDft<FFT64> {
|
||||
MatZnxDft::<FFT64>::new(self, rows, cols, limbs)
|
||||
fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft<FFT64> {
|
||||
MatZnxDft::<FFT64>::new(self, rows, cols, size)
|
||||
}
|
||||
|
||||
fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> usize {
|
||||
unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (limbs * cols) as u64) as usize }
|
||||
fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize {
|
||||
unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (size * cols) as u64) as usize }
|
||||
}
|
||||
|
||||
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize {
|
||||
@@ -367,7 +359,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), b.n() * b.poly_count());
|
||||
assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.limbs()));
|
||||
assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size()));
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
@@ -376,7 +368,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
b.as_mut_ptr() as *mut vmp_pmat_t,
|
||||
a.as_ptr(),
|
||||
b.rows() as u64,
|
||||
(b.limbs() * b.cols()) as u64,
|
||||
(b.size() * b.cols()) as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
@@ -385,8 +377,8 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
fn vmp_prepare_row(&self, b: &mut MatZnxDft<FFT64>, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), b.limbs() * self.n() * b.cols());
|
||||
assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.limbs()));
|
||||
assert_eq!(a.len(), b.size() * self.n() * b.cols());
|
||||
assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size()));
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
@@ -396,7 +388,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
a.as_ptr(),
|
||||
row_i as u64,
|
||||
b.rows() as u64,
|
||||
(b.limbs() * b.cols()) as u64,
|
||||
(b.size() * b.cols()) as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
@@ -406,7 +398,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), b.n());
|
||||
assert_eq!(a.limbs(), b.limbs());
|
||||
assert_eq!(a.size(), b.size());
|
||||
assert_eq!(a.cols(), b.cols());
|
||||
}
|
||||
unsafe {
|
||||
@@ -416,7 +408,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
a.as_ptr() as *const vmp_pmat_t,
|
||||
row_i as u64,
|
||||
a.rows() as u64,
|
||||
(a.limbs() * a.cols()) as u64,
|
||||
(a.size() * a.cols()) as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -425,7 +417,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), b.n());
|
||||
assert_eq!(a.limbs(), b.limbs());
|
||||
assert_eq!(a.size(), b.size());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_prepare_row_dft(
|
||||
@@ -434,7 +426,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
a.ptr as *const vec_znx_dft_t,
|
||||
row_i as u64,
|
||||
b.rows() as u64,
|
||||
b.limbs() as u64,
|
||||
b.size() as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -443,7 +435,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), b.n());
|
||||
assert_eq!(a.limbs(), b.limbs());
|
||||
assert_eq!(a.size(), b.size());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_extract_row_dft(
|
||||
@@ -452,7 +444,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
a.as_ptr() as *const vmp_pmat_t,
|
||||
row_i as u64,
|
||||
a.rows() as u64,
|
||||
a.limbs() as u64,
|
||||
a.size() as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -470,7 +462,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
}
|
||||
|
||||
fn vmp_apply_dft(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnx, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs()));
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
@@ -479,20 +471,20 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
vmp::vmp_apply_dft(
|
||||
self.ptr,
|
||||
c.ptr as *mut vec_znx_dft_t,
|
||||
c.limbs() as u64,
|
||||
c.size() as u64,
|
||||
a.as_ptr(),
|
||||
a.limbs() as u64,
|
||||
a.size() as u64,
|
||||
(a.n() * a.cols()) as u64,
|
||||
b.as_ptr() as *const vmp_pmat_t,
|
||||
b.rows() as u64,
|
||||
b.limbs() as u64,
|
||||
b.size() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_add(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnx, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs()));
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
@@ -501,13 +493,13 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
vmp::vmp_apply_dft_add(
|
||||
self.ptr,
|
||||
c.ptr as *mut vec_znx_dft_t,
|
||||
c.limbs() as u64,
|
||||
c.size() as u64,
|
||||
a.as_ptr(),
|
||||
a.limbs() as u64,
|
||||
(a.n() * a.limbs()) as u64,
|
||||
a.size() as u64,
|
||||
(a.n() * a.size()) as u64,
|
||||
b.as_ptr() as *const vmp_pmat_t,
|
||||
b.rows() as u64,
|
||||
b.limbs() as u64,
|
||||
b.size() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
@@ -526,7 +518,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnxDft<FFT64>, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs()));
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
@@ -535,12 +527,12 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
vmp::vmp_apply_dft_to_dft(
|
||||
self.ptr,
|
||||
c.ptr as *mut vec_znx_dft_t,
|
||||
c.limbs() as u64,
|
||||
c.size() as u64,
|
||||
a.ptr as *const vec_znx_dft_t,
|
||||
a.limbs() as u64,
|
||||
a.size() as u64,
|
||||
b.as_ptr() as *const vmp_pmat_t,
|
||||
b.rows() as u64,
|
||||
b.limbs() as u64,
|
||||
b.size() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
@@ -553,7 +545,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
b: &MatZnxDft<FFT64>,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs()));
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
@@ -562,19 +554,19 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
vmp::vmp_apply_dft_to_dft_add(
|
||||
self.ptr,
|
||||
c.ptr as *mut vec_znx_dft_t,
|
||||
c.limbs() as u64,
|
||||
c.size() as u64,
|
||||
a.ptr as *const vec_znx_dft_t,
|
||||
a.limbs() as u64,
|
||||
a.size() as u64,
|
||||
b.as_ptr() as *const vmp_pmat_t,
|
||||
b.rows() as u64,
|
||||
b.limbs() as u64,
|
||||
b.size() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft<FFT64>, a: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.limbs(), b.limbs(), a.rows(), a.limbs()));
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.size(), b.size(), a.rows(), a.size()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
@@ -583,12 +575,12 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
vmp::vmp_apply_dft_to_dft(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_dft_t,
|
||||
b.limbs() as u64,
|
||||
b.size() as u64,
|
||||
b.ptr as *mut vec_znx_dft_t,
|
||||
b.limbs() as u64,
|
||||
b.size() as u64,
|
||||
a.as_ptr() as *const vmp_pmat_t,
|
||||
a.rows() as u64,
|
||||
a.limbs() as u64,
|
||||
a.size() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -3,8 +3,8 @@ use rand_distr::{Distribution, Normal};
|
||||
use sampling::source::Source;
|
||||
|
||||
pub trait Sampling {
|
||||
/// Fills the first `limbs` limbs with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\]
|
||||
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, limbs: usize, source: &mut Source);
|
||||
/// Fills the first `size` limbs with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\]
|
||||
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, size: usize, source: &mut Source);
|
||||
|
||||
/// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\].
|
||||
fn add_dist_f64<D: Distribution<f64>>(
|
||||
@@ -32,11 +32,11 @@ pub trait Sampling {
|
||||
}
|
||||
|
||||
impl<B: Backend> Sampling for Module<B> {
|
||||
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, limbs: usize, source: &mut Source) {
|
||||
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, size: usize, source: &mut Source) {
|
||||
let base2k: u64 = 1 << log_base2k;
|
||||
let mask: u64 = base2k - 1;
|
||||
let base2k_half: i64 = (base2k >> 1) as i64;
|
||||
(0..limbs).for_each(|j| {
|
||||
(0..size).for_each(|j| {
|
||||
a.at_poly_mut(col_i, j)
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
|
||||
@@ -114,17 +114,17 @@ mod tests {
|
||||
let n: usize = 4096;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let log_base2k: usize = 17;
|
||||
let limbs: usize = 5;
|
||||
let size: usize = 5;
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
let zero: Vec<i64> = vec![0; n];
|
||||
let one_12_sqrt: f64 = 0.28867513459481287;
|
||||
(0..cols).for_each(|col_i| {
|
||||
let mut a: VecZnx = VecZnx::new(&module, cols, limbs);
|
||||
module.fill_uniform(log_base2k, &mut a, col_i, limbs, &mut source);
|
||||
let mut a: VecZnx = VecZnx::new(&module, cols, size);
|
||||
module.fill_uniform(log_base2k, &mut a, col_i, size, &mut source);
|
||||
(0..cols).for_each(|col_j| {
|
||||
if col_j != col_i {
|
||||
(0..limbs).for_each(|limb_i| {
|
||||
(0..size).for_each(|limb_i| {
|
||||
assert_eq!(a.at_poly(col_j, limb_i), zero);
|
||||
})
|
||||
} else {
|
||||
@@ -146,7 +146,7 @@ mod tests {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let log_base2k: usize = 17;
|
||||
let log_k: usize = 2 * 17;
|
||||
let limbs: usize = 5;
|
||||
let size: usize = 5;
|
||||
let sigma: f64 = 3.2;
|
||||
let bound: f64 = 6.0 * sigma;
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
@@ -154,11 +154,11 @@ mod tests {
|
||||
let zero: Vec<i64> = vec![0; n];
|
||||
let k_f64: f64 = (1u64 << log_k as u64) as f64;
|
||||
(0..cols).for_each(|col_i| {
|
||||
let mut a: VecZnx = VecZnx::new(&module, cols, limbs);
|
||||
let mut a: VecZnx = VecZnx::new(&module, cols, size);
|
||||
module.add_normal(log_base2k, &mut a, col_i, log_k, &mut source, sigma, bound);
|
||||
(0..cols).for_each(|col_j| {
|
||||
if col_j != col_i {
|
||||
(0..limbs).for_each(|limb_i| {
|
||||
(0..size).for_each(|limb_i| {
|
||||
assert_eq!(a.at_poly(col_j, limb_i), zero);
|
||||
})
|
||||
} else {
|
||||
|
||||
@@ -120,7 +120,7 @@ impl Scalar {
|
||||
VecZnx {
|
||||
n: self.n,
|
||||
cols: 1,
|
||||
limbs: 1,
|
||||
size: 1,
|
||||
data: Vec::new(),
|
||||
ptr: self.ptr,
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ pub trait Stats {
|
||||
|
||||
impl Stats for VecZnx {
|
||||
fn std(&self, col_i: usize, log_base2k: usize) -> f64 {
|
||||
let prec: u32 = (self.limbs() * log_base2k) as u32;
|
||||
let prec: u32 = (self.size() * log_base2k) as u32;
|
||||
let mut data: Vec<Float> = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect();
|
||||
self.decode_vec_float(col_i, log_base2k, &mut data);
|
||||
// std = sqrt(sum((xi - avg)^2) / n)
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
use crate::Backend;
|
||||
use crate::ZnxBase;
|
||||
use crate::cast_mut;
|
||||
use crate::ffi::vec_znx;
|
||||
use crate::ffi::znx;
|
||||
use crate::{Module, ZnxInfos, ZnxLayout};
|
||||
use crate::switch_degree;
|
||||
use crate::{Module, ZnxBasics, ZnxInfos, ZnxLayout};
|
||||
use crate::{alloc_aligned, assert_alignement};
|
||||
use itertools::izip;
|
||||
use std::cmp::min;
|
||||
|
||||
/// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of
|
||||
@@ -26,8 +25,8 @@ pub struct VecZnx {
|
||||
/// The number of polynomials
|
||||
pub cols: usize,
|
||||
|
||||
/// The number of limbs per polynomial (a.k.a small polynomials).
|
||||
pub limbs: usize,
|
||||
/// The size of each polynomial, i.e. its number of limbs (a.k.a small polynomials).
|
||||
pub size: usize,
|
||||
|
||||
/// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n.
|
||||
pub data: Vec<i64>,
|
||||
@@ -41,10 +40,6 @@ impl ZnxInfos for VecZnx {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn log_n(&self) -> usize {
|
||||
(usize::BITS - (self.n() - 1).leading_zeros()) as _
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
@@ -53,12 +48,8 @@ impl ZnxInfos for VecZnx {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn limbs(&self) -> usize {
|
||||
self.limbs
|
||||
}
|
||||
|
||||
fn poly_count(&self) -> usize {
|
||||
self.cols * self.limbs
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,6 +65,8 @@ impl ZnxLayout for VecZnx {
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxBasics for VecZnx {}
|
||||
|
||||
/// Copies the coefficients of `a` on the receiver.
|
||||
/// Copy is done with the minimum size matching both backing arrays.
|
||||
/// Panics if the cols do not match.
|
||||
@@ -89,28 +82,28 @@ impl<B: Backend> ZnxBase<B> for VecZnx {
|
||||
type Scalar = i64;
|
||||
|
||||
/// Allocates a new [VecZnx] composed of #size polynomials of Z\[X\].
|
||||
fn new(module: &Module<B>, cols: usize, limbs: usize) -> Self {
|
||||
fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
|
||||
let n: usize = module.n();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(n > 0);
|
||||
assert!(n & (n - 1) == 0);
|
||||
assert!(cols > 0);
|
||||
assert!(limbs > 0);
|
||||
assert!(size > 0);
|
||||
}
|
||||
let mut data: Vec<i64> = alloc_aligned::<i64>(Self::bytes_of(module, cols, limbs));
|
||||
let mut data: Vec<i64> = alloc_aligned::<i64>(Self::bytes_of(module, cols, size));
|
||||
let ptr: *mut i64 = data.as_mut_ptr();
|
||||
Self {
|
||||
n: n,
|
||||
cols: cols,
|
||||
limbs: limbs,
|
||||
size: size,
|
||||
data: data,
|
||||
ptr: ptr,
|
||||
}
|
||||
}
|
||||
|
||||
fn bytes_of(module: &Module<B>, cols: usize, limbs: usize) -> usize {
|
||||
module.n() * cols * limbs * size_of::<i64>()
|
||||
fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize {
|
||||
module.n() * cols * size * size_of::<i64>()
|
||||
}
|
||||
|
||||
/// Returns a new struct implementing [VecZnx] with the provided data as backing array.
|
||||
@@ -118,14 +111,14 @@ impl<B: Backend> ZnxBase<B> for VecZnx {
|
||||
/// The struct will take ownership of buf[..[Self::bytes_of]]
|
||||
///
|
||||
/// User must ensure that data is properly aligned and that
|
||||
/// the limbs of data is equal to [Self::bytes_of].
|
||||
fn from_bytes(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self {
|
||||
/// the size of data is equal to [Self::bytes_of].
|
||||
fn from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
|
||||
let n: usize = module.n();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(cols > 0);
|
||||
assert!(limbs > 0);
|
||||
assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs));
|
||||
assert!(size > 0);
|
||||
assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
|
||||
assert_alignement(bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
@@ -134,25 +127,25 @@ impl<B: Backend> ZnxBase<B> for VecZnx {
|
||||
Self {
|
||||
n: n,
|
||||
cols: cols,
|
||||
limbs: limbs,
|
||||
size: size,
|
||||
data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()),
|
||||
ptr: ptr,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn from_bytes_borrow(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self {
|
||||
fn from_bytes_borrow(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(cols > 0);
|
||||
assert!(limbs > 0);
|
||||
assert!(bytes.len() >= Self::bytes_of(module, cols, limbs));
|
||||
assert!(size > 0);
|
||||
assert!(bytes.len() >= Self::bytes_of(module, cols, size));
|
||||
assert_alignement(bytes.as_ptr());
|
||||
}
|
||||
Self {
|
||||
n: module.n(),
|
||||
cols: cols,
|
||||
limbs: limbs,
|
||||
size: size,
|
||||
data: Vec::new(),
|
||||
ptr: bytes.as_mut_ptr() as *mut i64,
|
||||
}
|
||||
@@ -173,16 +166,16 @@ impl VecZnx {
|
||||
|
||||
if !self.borrowing() {
|
||||
self.data
|
||||
.truncate(self.n() * self.cols() * (self.limbs() - k / log_base2k));
|
||||
.truncate(self.n() * self.cols() * (self.size() - k / log_base2k));
|
||||
}
|
||||
|
||||
self.limbs -= k / log_base2k;
|
||||
self.size -= k / log_base2k;
|
||||
|
||||
let k_rem: usize = k % log_base2k;
|
||||
|
||||
if k_rem != 0 {
|
||||
let mask: i64 = ((1 << (log_base2k - k_rem - 1)) - 1) << k_rem;
|
||||
self.at_limb_mut(self.limbs() - 1)
|
||||
self.at_limb_mut(self.size() - 1)
|
||||
.iter_mut()
|
||||
.for_each(|x: &mut i64| *x &= mask)
|
||||
}
|
||||
@@ -196,52 +189,22 @@ impl VecZnx {
|
||||
self.data.len() == 0
|
||||
}
|
||||
|
||||
pub fn zero(&mut self) {
|
||||
unsafe { znx::znx_zero_i64_ref((self.n * self.poly_count()) as u64, self.ptr) }
|
||||
}
|
||||
|
||||
pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) {
|
||||
normalize(log_base2k, self, carry)
|
||||
}
|
||||
|
||||
pub fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) {
|
||||
rsh(log_base2k, self, k, carry)
|
||||
}
|
||||
|
||||
pub fn switch_degree(&self, a: &mut Self) {
|
||||
switch_degree(a, self)
|
||||
}
|
||||
|
||||
// Prints the first `n` coefficients of each limb
|
||||
pub fn print(&self, n: usize) {
|
||||
(0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]))
|
||||
(0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn switch_degree(b: &mut VecZnx, a: &VecZnx) {
|
||||
let (n_in, n_out) = (a.n(), b.n());
|
||||
let (gap_in, gap_out): (usize, usize);
|
||||
|
||||
if n_in > n_out {
|
||||
(gap_in, gap_out) = (n_in / n_out, 1)
|
||||
} else {
|
||||
(gap_in, gap_out) = (1, n_out / n_in);
|
||||
b.zero();
|
||||
}
|
||||
|
||||
let limbs: usize = min(a.limbs(), b.limbs());
|
||||
|
||||
(0..limbs).for_each(|i| {
|
||||
izip!(
|
||||
a.at_limb(i).iter().step_by(gap_in),
|
||||
b.at_limb_mut(i).iter_mut().step_by(gap_out)
|
||||
)
|
||||
.for_each(|(x_in, x_out)| *x_out = *x_in);
|
||||
});
|
||||
}
|
||||
|
||||
fn normalize_tmp_bytes(n: usize, limbs: usize) -> usize {
|
||||
n * limbs * std::mem::size_of::<i64>()
|
||||
fn normalize_tmp_bytes(n: usize, size: usize) -> usize {
|
||||
n * size * std::mem::size_of::<i64>()
|
||||
}
|
||||
|
||||
fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) {
|
||||
@@ -264,7 +227,7 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) {
|
||||
|
||||
unsafe {
|
||||
znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr());
|
||||
(0..a.limbs()).rev().for_each(|i| {
|
||||
(0..a.size()).rev().for_each(|i| {
|
||||
znx::znx_normalize(
|
||||
(n * cols) as u64,
|
||||
log_base2k as u64,
|
||||
@@ -276,462 +239,3 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) {
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rsh_tmp_bytes(n: usize, limbs: usize) -> usize {
|
||||
n * limbs * std::mem::size_of::<i64>()
|
||||
}
|
||||
|
||||
pub fn rsh(log_base2k: usize, a: &mut VecZnx, k: usize, tmp_bytes: &mut [u8]) {
|
||||
let n: usize = a.n();
|
||||
let limbs: usize = a.limbs();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
tmp_bytes.len() >= rsh_tmp_bytes(n, limbs),
|
||||
"invalid carry: carry.len()/8={} < rsh_tmp_bytes({}, {})",
|
||||
tmp_bytes.len() >> 3,
|
||||
n,
|
||||
limbs,
|
||||
);
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
|
||||
let limbs: usize = a.limbs();
|
||||
let size_steps: usize = k / log_base2k;
|
||||
|
||||
a.raw_mut().rotate_right(n * limbs * size_steps);
|
||||
unsafe {
|
||||
znx::znx_zero_i64_ref((n * limbs * size_steps) as u64, a.as_mut_ptr());
|
||||
}
|
||||
|
||||
let k_rem = k % log_base2k;
|
||||
|
||||
if k_rem != 0 {
|
||||
let carry_i64: &mut [i64] = cast_mut(tmp_bytes);
|
||||
|
||||
unsafe {
|
||||
znx::znx_zero_i64_ref((n * limbs) as u64, carry_i64.as_mut_ptr());
|
||||
}
|
||||
|
||||
let log_base2k: usize = log_base2k;
|
||||
|
||||
(size_steps..limbs).for_each(|i| {
|
||||
izip!(carry_i64.iter_mut(), a.at_limb_mut(i).iter_mut()).for_each(|(ci, xi)| {
|
||||
*xi += *ci << log_base2k;
|
||||
*ci = get_base_k_carry(*xi, k_rem);
|
||||
*xi = (*xi - *ci) >> k_rem;
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn get_base_k_carry(x: i64, k: usize) -> i64 {
|
||||
(x << 64 - k) >> (64 - k)
|
||||
}
|
||||
|
||||
pub trait VecZnxOps {
|
||||
/// Allocates a new [VecZnx].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of polynomials.
|
||||
/// * `limbs`: the number of limbs per polynomial (a.k.a small polynomials).
|
||||
fn new_vec_znx(&self, cols: usize, limbs: usize) -> VecZnx;
|
||||
|
||||
fn new_vec_znx_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnx;
|
||||
fn new_vec_znx_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnx;
|
||||
|
||||
/// Returns the minimum number of bytes necessary to allocate
|
||||
/// a new [VecZnx] through [VecZnx::from_bytes].
|
||||
fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize;
|
||||
|
||||
fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize;
|
||||
|
||||
/// c <- a + b.
|
||||
fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx);
|
||||
|
||||
/// b <- b + a.
|
||||
fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx);
|
||||
|
||||
/// c <- a - b.
|
||||
fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx);
|
||||
|
||||
/// b <- a - b.
|
||||
fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx);
|
||||
|
||||
/// b <- b - a.
|
||||
fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx);
|
||||
|
||||
/// b <- -a.
|
||||
fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx);
|
||||
|
||||
/// b <- -b.
|
||||
fn vec_znx_negate_inplace(&self, a: &mut VecZnx);
|
||||
|
||||
/// b <- a * X^k (mod X^{n} + 1)
|
||||
fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx);
|
||||
|
||||
/// a <- a * X^k (mod X^{n} + 1)
|
||||
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx);
|
||||
|
||||
/// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
|
||||
fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx);
|
||||
|
||||
/// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
|
||||
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx);
|
||||
|
||||
/// Splits b into subrings and copies them into a.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This method requires that all [VecZnx] of b have the same ring degree
|
||||
/// and that b.n() * b.len() <= a.n()
|
||||
fn vec_znx_split(&self, b: &mut Vec<VecZnx>, a: &VecZnx, buf: &mut VecZnx);
|
||||
|
||||
/// Merges the subrings a into b.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This method requires that all [VecZnx] of a have the same ring degree
|
||||
/// and that a.n() * a.len() <= b.n()
|
||||
fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec<VecZnx>);
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxOps for Module<B> {
|
||||
fn new_vec_znx(&self, cols: usize, limbs: usize) -> VecZnx {
|
||||
VecZnx::new(self, cols, limbs)
|
||||
}
|
||||
|
||||
fn bytes_of_vec_znx(&self, cols: usize, limbs: usize) -> usize {
|
||||
VecZnx::bytes_of(self, cols, limbs)
|
||||
}
|
||||
|
||||
fn new_vec_znx_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnx {
|
||||
VecZnx::from_bytes(self, cols, limbs, bytes)
|
||||
}
|
||||
|
||||
fn new_vec_znx_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnx {
|
||||
VecZnx::from_bytes_borrow(self, cols, limbs, tmp_bytes)
|
||||
}
|
||||
|
||||
fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize {
|
||||
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize * cols }
|
||||
}
|
||||
|
||||
// c <- a + b
|
||||
fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) {
|
||||
let n: usize = self.n();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(c.n(), n);
|
||||
assert_eq!(a.n(), n);
|
||||
assert_eq!(b.n(), n);
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
self.ptr,
|
||||
c.as_mut_ptr(),
|
||||
c.limbs() as u64,
|
||||
(n * c.cols()) as u64,
|
||||
a.as_ptr(),
|
||||
a.limbs() as u64,
|
||||
(n * a.cols()) as u64,
|
||||
b.as_ptr(),
|
||||
b.limbs() as u64,
|
||||
(n * b.cols()) as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// b <- a + b
|
||||
fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
|
||||
let n: usize = self.n();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), n);
|
||||
assert_eq!(b.n(), n);
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
self.ptr,
|
||||
b.as_mut_ptr(),
|
||||
b.limbs() as u64,
|
||||
(n * b.cols()) as u64,
|
||||
a.as_ptr(),
|
||||
a.limbs() as u64,
|
||||
(n * a.cols()) as u64,
|
||||
b.as_ptr(),
|
||||
b.limbs() as u64,
|
||||
(n * b.cols()) as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// c <- a + b
|
||||
fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) {
|
||||
let n: usize = self.n();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(c.n(), n);
|
||||
assert_eq!(a.n(), n);
|
||||
assert_eq!(b.n(), n);
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
c.as_mut_ptr(),
|
||||
c.limbs() as u64,
|
||||
(n * c.cols()) as u64,
|
||||
a.as_ptr(),
|
||||
a.limbs() as u64,
|
||||
(n * a.cols()) as u64,
|
||||
b.as_ptr(),
|
||||
b.limbs() as u64,
|
||||
(n * b.cols()) as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// b <- a - b
|
||||
fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
|
||||
let n: usize = self.n();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), n);
|
||||
assert_eq!(b.n(), n);
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
b.as_mut_ptr(),
|
||||
b.limbs() as u64,
|
||||
(n * b.cols()) as u64,
|
||||
a.as_ptr(),
|
||||
a.limbs() as u64,
|
||||
(n * a.cols()) as u64,
|
||||
b.as_ptr(),
|
||||
b.limbs() as u64,
|
||||
(n * b.cols()) as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// b <- b - a
|
||||
fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
|
||||
let n: usize = self.n();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), n);
|
||||
assert_eq!(b.n(), n);
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
b.as_mut_ptr(),
|
||||
b.limbs() as u64,
|
||||
(n * b.cols()) as u64,
|
||||
b.as_ptr(),
|
||||
b.limbs() as u64,
|
||||
(n * b.cols()) as u64,
|
||||
a.as_ptr(),
|
||||
a.limbs() as u64,
|
||||
(n * a.cols()) as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) {
|
||||
let n: usize = self.n();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), n);
|
||||
assert_eq!(b.n(), n);
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_negate(
|
||||
self.ptr,
|
||||
b.as_mut_ptr(),
|
||||
b.limbs() as u64,
|
||||
(n * b.cols()) as u64,
|
||||
a.as_ptr(),
|
||||
a.limbs() as u64,
|
||||
(n * a.cols()) as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_negate_inplace(&self, a: &mut VecZnx) {
|
||||
let n: usize = self.n();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), n);
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_negate(
|
||||
self.ptr,
|
||||
a.as_mut_ptr(),
|
||||
a.limbs() as u64,
|
||||
(n * a.cols()) as u64,
|
||||
a.as_ptr(),
|
||||
a.limbs() as u64,
|
||||
(n * a.cols()) as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx) {
|
||||
let n: usize = self.n();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), n);
|
||||
assert_eq!(b.n(), n);
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_rotate(
|
||||
self.ptr,
|
||||
k,
|
||||
b.as_mut_ptr(),
|
||||
b.limbs() as u64,
|
||||
(n * b.cols()) as u64,
|
||||
a.as_ptr(),
|
||||
a.limbs() as u64,
|
||||
(n * a.cols()) as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx) {
|
||||
let n: usize = self.n();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), n);
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_rotate(
|
||||
self.ptr,
|
||||
k,
|
||||
a.as_mut_ptr(),
|
||||
a.limbs() as u64,
|
||||
(n * a.cols()) as u64,
|
||||
a.as_ptr(),
|
||||
a.limbs() as u64,
|
||||
(n * a.cols()) as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a`: input.
|
||||
/// * `b`: output.
|
||||
/// * `k`: the power to which to map each coefficients.
|
||||
/// * `a_size`: the number of limbs of `a` on which to apply the mapping.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// The method will panic if the argument `a` is greater than `a.limbs()`.
|
||||
fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) {
|
||||
let n: usize = self.n();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), n);
|
||||
assert_eq!(b.n(), n);
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
self.ptr,
|
||||
k,
|
||||
b.as_mut_ptr(),
|
||||
b.limbs() as u64,
|
||||
(n * b.cols()) as u64,
|
||||
a.as_ptr(),
|
||||
a.limbs() as u64,
|
||||
(n * a.cols()) as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a`: input and output.
|
||||
/// * `k`: the power to which to map each coefficients.
|
||||
/// * `a_size`: the number of limbs on which to apply the mapping.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// The method will panic if the argument `size` is greater than `self.limbs()`.
|
||||
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) {
|
||||
let n: usize = self.n();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), n);
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
self.ptr,
|
||||
k,
|
||||
a.as_mut_ptr(),
|
||||
a.limbs() as u64,
|
||||
(n * a.cols()) as u64,
|
||||
a.as_ptr(),
|
||||
a.limbs() as u64,
|
||||
(n * a.cols()) as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_split(&self, b: &mut Vec<VecZnx>, a: &VecZnx, buf: &mut VecZnx) {
|
||||
let (n_in, n_out) = (a.n(), b[0].n());
|
||||
|
||||
debug_assert!(
|
||||
n_out < n_in,
|
||||
"invalid a: output ring degree should be smaller"
|
||||
);
|
||||
b[1..].iter().for_each(|bi| {
|
||||
debug_assert_eq!(
|
||||
bi.n(),
|
||||
n_out,
|
||||
"invalid input a: all VecZnx must have the same degree"
|
||||
)
|
||||
});
|
||||
|
||||
b.iter_mut().enumerate().for_each(|(i, bi)| {
|
||||
if i == 0 {
|
||||
switch_degree(bi, a);
|
||||
self.vec_znx_rotate(-1, buf, a);
|
||||
} else {
|
||||
switch_degree(bi, buf);
|
||||
self.vec_znx_rotate_inplace(-1, buf);
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec<VecZnx>) {
|
||||
let (n_in, n_out) = (b.n(), a[0].n());
|
||||
|
||||
debug_assert!(
|
||||
n_out < n_in,
|
||||
"invalid a: output ring degree should be smaller"
|
||||
);
|
||||
a[1..].iter().for_each(|ai| {
|
||||
debug_assert_eq!(
|
||||
ai.n(),
|
||||
n_out,
|
||||
"invalid input a: all VecZnx must have the same degree"
|
||||
)
|
||||
});
|
||||
|
||||
a.iter().enumerate().for_each(|(_, ai)| {
|
||||
switch_degree(b, ai);
|
||||
self.vec_znx_rotate_inplace(-1, b);
|
||||
});
|
||||
|
||||
self.vec_znx_rotate_inplace(a.len() as i64, b);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,43 +7,43 @@ pub struct VecZnxBig<B: Backend> {
|
||||
pub ptr: *mut u8,
|
||||
pub n: usize,
|
||||
pub cols: usize,
|
||||
pub limbs: usize,
|
||||
pub size: usize,
|
||||
pub _marker: PhantomData<B>,
|
||||
}
|
||||
impl<B: Backend> ZnxBase<B> for VecZnxBig<B> {
|
||||
type Scalar = u8;
|
||||
|
||||
fn new(module: &Module<B>, cols: usize, limbs: usize) -> Self {
|
||||
fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(cols > 0);
|
||||
assert!(limbs > 0);
|
||||
assert!(size > 0);
|
||||
}
|
||||
let mut data: Vec<Self::Scalar> = alloc_aligned::<u8>(Self::bytes_of(module, cols, limbs));
|
||||
let mut data: Vec<Self::Scalar> = alloc_aligned::<u8>(Self::bytes_of(module, cols, size));
|
||||
let ptr: *mut Self::Scalar = data.as_mut_ptr();
|
||||
Self {
|
||||
data: data,
|
||||
ptr: ptr,
|
||||
n: module.n(),
|
||||
cols: cols,
|
||||
limbs: limbs,
|
||||
size: size,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn bytes_of(module: &Module<B>, cols: usize, limbs: usize) -> usize {
|
||||
unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, limbs as u64) as usize * cols }
|
||||
fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize {
|
||||
unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols }
|
||||
}
|
||||
|
||||
/// Returns a new [VecZnxBig] with the provided data as backing array.
|
||||
/// User must ensure that data is properly aligned and that
|
||||
/// the size of data is at least equal to [Module::bytes_of_vec_znx_big].
|
||||
fn from_bytes(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self {
|
||||
fn from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(cols > 0);
|
||||
assert!(limbs > 0);
|
||||
assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs));
|
||||
assert!(size > 0);
|
||||
assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
|
||||
assert_alignement(bytes.as_ptr())
|
||||
};
|
||||
unsafe {
|
||||
@@ -52,18 +52,18 @@ impl<B: Backend> ZnxBase<B> for VecZnxBig<B> {
|
||||
ptr: bytes.as_mut_ptr(),
|
||||
n: module.n(),
|
||||
cols: cols,
|
||||
limbs: limbs,
|
||||
size: size,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn from_bytes_borrow(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self {
|
||||
fn from_bytes_borrow(module: &Module<B>, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(cols > 0);
|
||||
assert!(limbs > 0);
|
||||
assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs));
|
||||
assert!(size > 0);
|
||||
assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
|
||||
assert_alignement(bytes.as_ptr());
|
||||
}
|
||||
Self {
|
||||
@@ -71,17 +71,13 @@ impl<B: Backend> ZnxBase<B> for VecZnxBig<B> {
|
||||
ptr: bytes.as_mut_ptr(),
|
||||
n: module.n(),
|
||||
cols: cols,
|
||||
limbs: limbs,
|
||||
size: size,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ZnxInfos for VecZnxBig<B> {
|
||||
fn log_n(&self) -> usize {
|
||||
(usize::BITS - (self.n - 1).leading_zeros()) as _
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
@@ -94,12 +90,8 @@ impl<B: Backend> ZnxInfos for VecZnxBig<B> {
|
||||
1
|
||||
}
|
||||
|
||||
fn limbs(&self) -> usize {
|
||||
self.limbs
|
||||
}
|
||||
|
||||
fn poly_count(&self) -> usize {
|
||||
self.cols * self.limbs
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,13 +109,13 @@ impl ZnxLayout for VecZnxBig<FFT64> {
|
||||
|
||||
impl VecZnxBig<FFT64> {
|
||||
pub fn print(&self, n: usize) {
|
||||
(0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
|
||||
(0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxBigOps<B: Backend> {
|
||||
/// Allocates a vector Z[X]/(X^N+1) that stores non-normalized values.
fn new_vec_znx_big(&self, cols: usize, limbs: usize) -> VecZnxBig<B>;
|
||||
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<B>;
|
||||
|
||||
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
|
||||
///
|
||||
@@ -132,12 +124,12 @@ pub trait VecZnxBigOps<B: Backend> {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of polynomials.
/// * `limbs`: the number of limbs (a.k.a small polynomials) per polynomial.
/// * `size`: the number of sizes (a.k.a. small polynomials) per polynomial.
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
|
||||
fn new_vec_znx_big_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxBig<B>;
|
||||
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<B>;
|
||||
|
||||
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
|
||||
///
|
||||
@@ -146,25 +138,25 @@ pub trait VecZnxBigOps<B: Backend> {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of polynomials.
/// * `limbs`: the number of limbs (a.k.a small polynomials) per polynomial.
/// * `size`: the number of sizes (a.k.a. small polynomials) per polynomial.
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
|
||||
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<B>;
|
||||
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<B>;
|
||||
|
||||
/// Returns the minimum number of bytes necessary to allocate
|
||||
/// a new [VecZnxBig] through [VecZnxBig::from_bytes].
|
||||
fn bytes_of_vec_znx_big(&self, cols: usize, limbs: usize) -> usize;
|
||||
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize;
|
||||
|
||||
/// b[VecZnxBig] <- b[VecZnxBig] - a[VecZnx]
|
||||
///
|
||||
/// # Behavior
|
||||
///
|
||||
/// [VecZnxBig] (3 cols and 4 limbs)
|
||||
/// [VecZnxBig] (3 cols and 4 sizes)
/// [a0, b0, c0] [a1, b1, c1] [a2, b2, c2] [a3, b3, c3]
|
||||
/// -
|
||||
/// [VecZnx] (2 cols and 3 limbs)
|
||||
/// [VecZnx] (2 cols and 3 sizes)
/// [d0, e0] [d1, e1] [d2, e2]
|
||||
/// =
|
||||
/// [a0-d0, b0-e0, c0] [a1-d1, b1-e1, c1] [a2-d2, b2-e2, c2] [a3, b3, c3]
|
||||
@@ -203,26 +195,26 @@ pub trait VecZnxBigOps<B: Backend> {
|
||||
}
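The `from_bytes_borrow` constructors above let callers reuse one pre-allocated, aligned scratch buffer instead of allocating per call. A minimal sketch of that pattern, written as if inside the crate and assuming `alloc_aligned` and the `FFT64` backend are importable as they appear elsewhere in this diff (all names outside this trait are assumptions, not the crate's documented usage):

use crate::{FFT64, Module, VecZnxBig, VecZnxBigOps, alloc_aligned};

fn scratch_example() {
    let module: Module<FFT64> = Module::<FFT64>::new(8);
    let (cols, size) = (2usize, 3usize);

    // Size the scratch once with bytes_of_vec_znx_big, then view it as a
    // VecZnxBig without a second allocation.
    let mut scratch: Vec<u8> = alloc_aligned(module.bytes_of_vec_znx_big(cols, size));
    let _big: VecZnxBig<FFT64> = module.new_vec_znx_big_from_bytes_borrow(cols, size, &mut scratch);
}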
|
||||
|
||||
impl VecZnxBigOps<FFT64> for Module<FFT64> {
|
||||
fn new_vec_znx_big(&self, cols: usize, limbs: usize) -> VecZnxBig<FFT64> {
|
||||
VecZnxBig::new(self, cols, limbs)
|
||||
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<FFT64> {
|
||||
VecZnxBig::new(self, cols, size)
|
||||
}
|
||||
|
||||
fn new_vec_znx_big_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxBig<FFT64> {
|
||||
VecZnxBig::from_bytes(self, cols, limbs, bytes)
|
||||
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<FFT64> {
|
||||
VecZnxBig::from_bytes(self, cols, size, bytes)
|
||||
}
|
||||
|
||||
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<FFT64> {
|
||||
VecZnxBig::from_bytes_borrow(self, cols, limbs, tmp_bytes)
|
||||
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<FFT64> {
|
||||
VecZnxBig::from_bytes_borrow(self, cols, size, tmp_bytes)
|
||||
}
|
||||
|
||||
fn bytes_of_vec_znx_big(&self, cols: usize, limbs: usize) -> usize {
|
||||
VecZnxBig::bytes_of(self, cols, limbs)
|
||||
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
|
||||
VecZnxBig::bytes_of(self, cols, size)
|
||||
}
|
||||
|
||||
/// [VecZnxBig] (3 cols and 4 limbs)
|
||||
/// [VecZnxBig] (3 cols and 4 sizes)
/// [a0, b0, c0] [a1, b1, c1] [a2, b2, c2] [a3, b3, c3]
|
||||
/// -
|
||||
/// [VecZnx] (2 cols and 3 limbs)
|
||||
/// [VecZnx] (2 cols and 3 sizes)
/// [d0, e0] [d1, e1] [d2, e2]
|
||||
/// =
|
||||
/// [a0-d0, b0-e0, c0] [a1-d1, b1-e1, c1] [a2-d2, b2-e2, c2] [a3, b3, c3]
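The diagram above describes a zero-padding rule: subtract on the shared columns and sizes, and pass the tail of the larger operand through unchanged. A minimal sketch of that rule on plain integers, where each entry stands in for one small polynomial; this only models the intended semantics and is not the crate's FFI-backed routine:

// b has 4 sizes of 3 columns, a has 3 sizes of 2 columns; entries missing from
// a are treated as zero, so the remaining entries of b are left untouched.
fn sub_padded(b: &mut [Vec<i64>], a: &[Vec<i64>]) {
    for (bj, aj) in b.iter_mut().zip(a.iter()) {
        for (bi, ai) in bj.iter_mut().zip(aj.iter()) {
            *bi -= *ai;
        }
    }
}

fn main() {
    let mut b = vec![vec![10, 20, 30], vec![11, 21, 31], vec![12, 22, 32], vec![13, 23, 33]];
    let a = vec![vec![1, 2], vec![1, 2], vec![1, 2]];
    sub_padded(&mut b, &a);
    assert_eq!(b[0], vec![9, 18, 30]); // shared entries subtracted, third column untouched
    assert_eq!(b[3], vec![13, 23, 33]); // fourth size of b passes through
}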
|
||||
@@ -306,10 +298,10 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
|
||||
self.ptr,
|
||||
log_base2k as u64,
|
||||
b.as_mut_ptr(),
|
||||
b.limbs() as u64,
|
||||
b.size() as u64,
|
||||
b.n() as u64,
|
||||
a.ptr as *mut vec_znx_big_t,
|
||||
a.limbs() as u64,
|
||||
a.size() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
@@ -344,7 +336,7 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
|
||||
self.ptr,
|
||||
log_base2k as u64,
|
||||
res.as_mut_ptr(),
|
||||
res.limbs() as u64,
|
||||
res.size() as u64,
|
||||
res.n() as u64,
|
||||
a.ptr as *mut vec_znx_big_t,
|
||||
a_range_begin as u64,
|
||||
|
||||
@@ -10,44 +10,44 @@ pub struct VecZnxDft<B: Backend> {
|
||||
pub ptr: *mut u8,
|
||||
pub n: usize,
|
||||
pub cols: usize,
|
||||
pub limbs: usize,
|
||||
pub size: usize,
|
||||
pub _marker: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ZnxBase<B> for VecZnxDft<B> {
|
||||
type Scalar = u8;
|
||||
|
||||
fn new(module: &Module<B>, cols: usize, limbs: usize) -> Self {
|
||||
fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(cols > 0);
|
||||
assert!(limbs > 0);
|
||||
assert!(size > 0);
|
||||
}
|
||||
let mut data: Vec<Self::Scalar> = alloc_aligned(Self::bytes_of(module, cols, limbs));
|
||||
let mut data: Vec<Self::Scalar> = alloc_aligned(Self::bytes_of(module, cols, size));
|
||||
let ptr: *mut Self::Scalar = data.as_mut_ptr();
|
||||
Self {
|
||||
data: data,
|
||||
ptr: ptr,
|
||||
n: module.n(),
|
||||
limbs: limbs,
|
||||
size: size,
|
||||
cols: cols,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn bytes_of(module: &Module<B>, cols: usize, limbs: usize) -> usize {
|
||||
unsafe { bytes_of_vec_znx_dft(module.ptr, limbs as u64) as usize * cols }
|
||||
fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize {
|
||||
unsafe { bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols }
|
||||
}
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided data as backing array.
/// User must ensure that data is properly aligned and that
/// the size of data is at least equal to [Module::bytes_of_vec_znx_dft].
fn from_bytes(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self {
|
||||
fn from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(cols > 0);
|
||||
assert!(limbs > 0);
|
||||
assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs));
|
||||
assert!(size > 0);
|
||||
assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
|
||||
assert_alignement(bytes.as_ptr())
|
||||
}
|
||||
unsafe {
|
||||
@@ -56,18 +56,18 @@ impl<B: Backend> ZnxBase<B> for VecZnxDft<B> {
|
||||
ptr: bytes.as_mut_ptr(),
|
||||
n: module.n(),
|
||||
cols: cols,
|
||||
limbs: limbs,
|
||||
size: size,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn from_bytes_borrow(module: &Module<B>, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self {
|
||||
fn from_bytes_borrow(module: &Module<B>, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(cols > 0);
|
||||
assert!(limbs > 0);
|
||||
assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs));
|
||||
assert!(size > 0);
|
||||
assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
|
||||
assert_alignement(bytes.as_ptr());
|
||||
}
|
||||
Self {
|
||||
@@ -75,7 +75,7 @@ impl<B: Backend> ZnxBase<B> for VecZnxDft<B> {
|
||||
ptr: bytes.as_mut_ptr(),
|
||||
n: module.n(),
|
||||
cols: cols,
|
||||
limbs: limbs,
|
||||
size: size,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
@@ -91,7 +91,7 @@ impl<B: Backend> VecZnxDft<B> {
|
||||
ptr: self.ptr,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
limbs: self.limbs,
|
||||
size: self.size,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
@@ -102,10 +102,6 @@ impl<B: Backend> ZnxInfos for VecZnxDft<B> {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn log_n(&self) -> usize {
|
||||
(usize::BITS - (self.n() - 1).leading_zeros()) as _
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
@@ -114,12 +110,8 @@ impl<B: Backend> ZnxInfos for VecZnxDft<B> {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn limbs(&self) -> usize {
|
||||
self.limbs
|
||||
}
|
||||
|
||||
fn poly_count(&self) -> usize {
|
||||
self.cols * self.limbs
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,13 +129,13 @@ impl ZnxLayout for VecZnxDft<FFT64> {
|
||||
|
||||
impl VecZnxDft<FFT64> {
|
||||
pub fn print(&self, n: usize) {
|
||||
(0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
|
||||
(0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxDftOps<B: Backend> {
|
||||
/// Allocates a vector Z[X]/(X^N+1) that stores values in the DFT space.
fn new_vec_znx_dft(&self, cols: usize, limbs: usize) -> VecZnxDft<B>;
|
||||
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft<B>;
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||
///
|
||||
@@ -156,7 +148,7 @@ pub trait VecZnxDftOps<B: Backend> {
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||
fn new_vec_znx_dft_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxDft<B>;
|
||||
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<B>;
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||
///
|
||||
@@ -169,7 +161,7 @@ pub trait VecZnxDftOps<B: Backend> {
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxDft<B>;
|
||||
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<B>;
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||
///
|
||||
@@ -180,7 +172,7 @@ pub trait VecZnxDftOps<B: Backend> {
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||
fn bytes_of_vec_znx_dft(&self, cols: usize, limbs: usize) -> usize;
|
||||
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize;
|
||||
|
||||
/// Returns the minimum number of bytes necessary to allocate
|
||||
/// a new [VecZnxDft] through [VecZnxDft::from_bytes].
|
||||
@@ -201,20 +193,20 @@ pub trait VecZnxDftOps<B: Backend> {
|
||||
}
|
||||
|
||||
impl VecZnxDftOps<FFT64> for Module<FFT64> {
|
||||
fn new_vec_znx_dft(&self, cols: usize, limbs: usize) -> VecZnxDft<FFT64> {
|
||||
VecZnxDft::<FFT64>::new(&self, cols, limbs)
|
||||
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft<FFT64> {
|
||||
VecZnxDft::<FFT64>::new(&self, cols, size)
|
||||
}
|
||||
|
||||
fn new_vec_znx_dft_from_bytes(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxDft<FFT64> {
|
||||
VecZnxDft::from_bytes(self, cols, limbs, tmp_bytes)
|
||||
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxDft<FFT64> {
|
||||
VecZnxDft::from_bytes(self, cols, size, tmp_bytes)
|
||||
}
|
||||
|
||||
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxDft<FFT64> {
|
||||
VecZnxDft::from_bytes_borrow(self, cols, limbs, tmp_bytes)
|
||||
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxDft<FFT64> {
|
||||
VecZnxDft::from_bytes_borrow(self, cols, size, tmp_bytes)
|
||||
}
|
||||
|
||||
fn bytes_of_vec_znx_dft(&self, cols: usize, limbs: usize) -> usize {
|
||||
VecZnxDft::bytes_of(&self, cols, limbs)
|
||||
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize {
|
||||
VecZnxDft::bytes_of(&self, cols, size)
|
||||
}
|
||||
|
||||
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig<FFT64>, a: &mut VecZnxDft<FFT64>) {
|
||||
@@ -242,9 +234,9 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
|
||||
vec_znx_dft::vec_znx_dft(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_dft_t,
|
||||
b.limbs() as u64,
|
||||
b.size() as u64,
|
||||
a.as_ptr(),
|
||||
a.limbs() as u64,
|
||||
a.size() as u64,
|
||||
(a.n() * a.cols()) as u64,
|
||||
)
|
||||
}
|
||||
@@ -329,14 +321,14 @@ mod tests {
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
|
||||
let limbs: usize = 2;
|
||||
let size: usize = 2;
|
||||
let log_base2k: usize = 17;
|
||||
let mut a: VecZnx = module.new_vec_znx(1, limbs);
|
||||
let mut a_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, limbs);
|
||||
let mut b_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, limbs);
|
||||
let mut a: VecZnx = module.new_vec_znx(1, size);
|
||||
let mut a_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, size);
|
||||
let mut b_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, size);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source);
|
||||
module.fill_uniform(log_base2k, &mut a, 0, size, &mut source);
|
||||
|
||||
let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes());
|
||||
|
||||
|
||||
795
base2k/src/vec_znx_ops.rs
Normal file
@@ -0,0 +1,795 @@
|
||||
use crate::ffi::module::MODULE;
|
||||
use crate::ffi::vec_znx;
|
||||
use crate::{Backend, Module, VecZnx, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, switch_degree, znx_post_process_ternary_op};
|
||||
use std::cmp::min;
|
||||
pub trait VecZnxOps {
|
||||
/// Allocates a new [VecZnx].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of polynomials.
|
||||
/// * `size`: the number of sizes per polynomial (a.k.a. small polynomials).
fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnx;
|
||||
|
||||
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx;
|
||||
fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx;
|
||||
|
||||
/// Returns the minimum number of bytes necessary to allocate
|
||||
/// a new [VecZnx] through [VecZnx::from_bytes].
|
||||
fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize;
|
||||
|
||||
fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize;
|
||||
|
||||
/// c <- a + b.
|
||||
fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx);
|
||||
|
||||
/// b <- b + a.
|
||||
fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx);
|
||||
|
||||
/// c <- a - b.
|
||||
fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx);
|
||||
|
||||
/// b <- a - b.
|
||||
fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx);
|
||||
|
||||
/// b <- b - a.
|
||||
fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx);
|
||||
|
||||
/// b <- -a.
|
||||
fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx);
|
||||
|
||||
/// b <- -b.
|
||||
fn vec_znx_negate_inplace(&self, a: &mut VecZnx);
|
||||
|
||||
/// b <- a * X^k (mod X^{n} + 1)
|
||||
fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx);
|
||||
|
||||
/// a <- a * X^k (mod X^{n} + 1)
|
||||
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx);
|
||||
|
||||
/// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
|
||||
fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx);
|
||||
|
||||
/// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
|
||||
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx);
|
||||
|
||||
/// Splits a into subrings and copies them into b.
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This method requires that all [VecZnx] of b have the same ring degree
|
||||
/// and that b.n() * b.len() <= a.n()
|
||||
fn vec_znx_split(&self, b: &mut Vec<VecZnx>, a: &VecZnx, buf: &mut VecZnx);
|
||||
|
||||
/// Merges the subrings a into b.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This method requires that all [VecZnx] of a have the same ring degree
|
||||
/// and that a.n() * a.len() <= b.n()
|
||||
fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec<VecZnx>);
|
||||
}
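A minimal usage sketch of this trait on the FFT64 backend, using only methods declared above plus `fill_uniform` from the `Sampling` trait that appears elsewhere in this diff; the crate name `base2k` is inferred from the file path and is an assumption:

use base2k::{FFT64, Module, Sampling, VecZnx, VecZnxOps};
use sampling::source::Source;

fn main() {
    let module: Module<FFT64> = Module::<FFT64>::new(8);
    let (cols, size, log_base2k) = (1usize, 3usize, 17usize);

    let mut a: VecZnx = module.new_vec_znx(cols, size);
    let mut b: VecZnx = module.new_vec_znx(cols, size);
    let mut c: VecZnx = module.new_vec_znx(cols, size);

    let mut source: Source = Source::new([0u8; 32]);
    module.fill_uniform(log_base2k, &mut a, 0, size, &mut source);
    module.fill_uniform(log_base2k, &mut b, 0, size, &mut source);

    module.vec_znx_add(&mut c, &a, &b); // c <- a + b
    module.vec_znx_rotate_inplace(1, &mut c); // c <- c * X (mod X^8 + 1)
}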
|
||||
|
||||
impl<B: Backend> VecZnxOps for Module<B> {
|
||||
fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnx {
|
||||
VecZnx::new(self, cols, size)
|
||||
}
|
||||
|
||||
fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize {
|
||||
VecZnx::bytes_of(self, cols, size)
|
||||
}
|
||||
|
||||
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx {
|
||||
VecZnx::from_bytes(self, cols, size, bytes)
|
||||
}
|
||||
|
||||
fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx {
|
||||
VecZnx::from_bytes_borrow(self, cols, size, tmp_bytes)
|
||||
}
|
||||
|
||||
fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize {
|
||||
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize * cols }
|
||||
}
|
||||
|
||||
fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) {
|
||||
let op = ffi_ternary_op_factory(
|
||||
self.ptr,
|
||||
c.size(),
|
||||
c.sl(),
|
||||
a.size(),
|
||||
a.sl(),
|
||||
b.size(),
|
||||
b.sl(),
|
||||
vec_znx::vec_znx_add,
|
||||
);
|
||||
vec_znx_apply_binary_op::<B, false>(self, c, a, b, op);
|
||||
}
|
||||
|
||||
fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
|
||||
unsafe {
|
||||
let b_ptr: *mut VecZnx = b as *mut VecZnx;
|
||||
Self::vec_znx_add(self, &mut *b_ptr, a, &*b_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) {
|
||||
let op = ffi_ternary_op_factory(
|
||||
self.ptr,
|
||||
c.size(),
|
||||
c.sl(),
|
||||
a.size(),
|
||||
a.sl(),
|
||||
b.size(),
|
||||
b.sl(),
|
||||
vec_znx::vec_znx_sub,
|
||||
);
|
||||
vec_znx_apply_binary_op::<B, true>(self, c, a, b, op);
|
||||
}
|
||||
|
||||
fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
|
||||
unsafe {
|
||||
let b_ptr: *mut VecZnx = b as *mut VecZnx;
|
||||
Self::vec_znx_sub(self, &mut *b_ptr, a, &*b_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
|
||||
unsafe {
|
||||
let b_ptr: *mut VecZnx = b as *mut VecZnx;
|
||||
Self::vec_znx_sub(self, &mut *b_ptr, &*b_ptr, a);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) {
|
||||
let op = ffi_binary_op_factory_type_0(
|
||||
self.ptr,
|
||||
b.size(),
|
||||
b.sl(),
|
||||
a.size(),
|
||||
a.sl(),
|
||||
vec_znx::vec_znx_negate,
|
||||
);
|
||||
vec_znx_apply_unary_op::<B>(self, b, a, op);
|
||||
}
|
||||
|
||||
fn vec_znx_negate_inplace(&self, a: &mut VecZnx) {
|
||||
unsafe {
|
||||
let a_ptr: *mut VecZnx = a as *mut VecZnx;
|
||||
Self::vec_znx_negate(self, &mut *a_ptr, &*a_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx) {
|
||||
let op = ffi_binary_op_factory_type_1(
|
||||
self.ptr,
|
||||
k,
|
||||
b.size(),
|
||||
b.sl(),
|
||||
a.size(),
|
||||
a.sl(),
|
||||
vec_znx::vec_znx_rotate,
|
||||
);
|
||||
vec_znx_apply_unary_op::<B>(self, b, a, op);
|
||||
}
|
||||
|
||||
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx) {
|
||||
unsafe {
|
||||
let a_ptr: *mut VecZnx = a as *mut VecZnx;
|
||||
Self::vec_znx_rotate(self, k, &mut *a_ptr, &*a_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
/// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a`: input.
|
||||
/// * `b`: output.
|
||||
/// * `k`: the power to which to map each coefficient.
/// * `a_size`: the number of sizes on which to apply the mapping.
///
/// # Panics
///
/// The method will panic if `a_size` is greater than `a.size()`.
fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) {
|
||||
let op = ffi_binary_op_factory_type_1(
|
||||
self.ptr,
|
||||
k,
|
||||
b.size(),
|
||||
b.sl(),
|
||||
a.size(),
|
||||
a.sl(),
|
||||
vec_znx::vec_znx_automorphism,
|
||||
);
|
||||
vec_znx_apply_unary_op::<B>(self, b, a, op);
|
||||
}
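For reference, phi_k sends the coefficient of X^i to exponent i*k taken modulo 2n, flipping the sign whenever the exponent wraps past n (since X^n = -1 mod X^n + 1). A hypothetical, backend-free sketch of that rule on a single small polynomial; this is not the crate's implementation, and k is normally chosen odd so that the map is a permutation:

fn automorphism_ref(k: i64, a: &[i64]) -> Vec<i64> {
    let n = a.len() as i64;
    let two_n = 2 * n;
    let k = ((k % two_n) + two_n) % two_n; // reduce k into [0, 2n)
    let mut b = vec![0i64; n as usize];
    for (i, &ai) in a.iter().enumerate() {
        let e = (i as i64 * k) % two_n; // exponent of X^{i*k} modulo 2n
        if e < n {
            b[e as usize] += ai; // X^e stays as-is
        } else {
            b[(e - n) as usize] -= ai; // X^e = -X^{e-n} (mod X^n + 1)
        }
    }
    b
}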
|
||||
|
||||
/// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a`: input and output.
|
||||
/// * `k`: the power to which to map each coefficient.
/// * `a_size`: the number of sizes on which to apply the mapping.
///
/// # Panics
///
/// The method will panic if `a_size` is greater than `a.size()`.
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) {
|
||||
unsafe {
|
||||
let a_ptr: *mut VecZnx = a as *mut VecZnx;
|
||||
Self::vec_znx_automorphism(self, k, &mut *a_ptr, &*a_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_split(&self, b: &mut Vec<VecZnx>, a: &VecZnx, buf: &mut VecZnx) {
|
||||
let (n_in, n_out) = (a.n(), b[0].n());
|
||||
|
||||
debug_assert!(
|
||||
n_out < n_in,
|
||||
"invalid a: output ring degree should be smaller"
|
||||
);
|
||||
b[1..].iter().for_each(|bi| {
|
||||
debug_assert_eq!(
|
||||
bi.n(),
|
||||
n_out,
|
||||
"invalid input a: all VecZnx must have the same degree"
|
||||
)
|
||||
});
|
||||
|
||||
b.iter_mut().enumerate().for_each(|(i, bi)| {
|
||||
if i == 0 {
|
||||
switch_degree(bi, a);
|
||||
self.vec_znx_rotate(-1, buf, a);
|
||||
} else {
|
||||
switch_degree(bi, buf);
|
||||
self.vec_znx_rotate_inplace(-1, buf);
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec<VecZnx>) {
|
||||
let (n_in, n_out) = (b.n(), a[0].n());
|
||||
|
||||
debug_assert!(
|
||||
n_out < n_in,
|
||||
"invalid a: output ring degree should be smaller"
|
||||
);
|
||||
a[1..].iter().for_each(|ai| {
|
||||
debug_assert_eq!(
|
||||
ai.n(),
|
||||
n_out,
|
||||
"invalid input a: all VecZnx must have the same degree"
|
||||
)
|
||||
});
|
||||
|
||||
a.iter().for_each(|ai| {
switch_degree(b, ai);
|
||||
self.vec_znx_rotate_inplace(-1, b);
|
||||
});
|
||||
|
||||
self.vec_znx_rotate_inplace(a.len() as i64, b);
|
||||
}
|
||||
}
|
||||
|
||||
fn ffi_ternary_op_factory(
|
||||
module_ptr: *const MODULE,
|
||||
c_size: usize,
|
||||
c_sl: usize,
|
||||
a_size: usize,
|
||||
a_sl: usize,
|
||||
b_size: usize,
|
||||
b_sl: usize,
|
||||
op_fn: unsafe extern "C" fn(*const MODULE, *mut i64, u64, u64, *const i64, u64, u64, *const i64, u64, u64),
|
||||
) -> impl Fn(&mut [i64], &[i64], &[i64]) {
|
||||
move |cv: &mut [i64], av: &[i64], bv: &[i64]| unsafe {
|
||||
op_fn(
|
||||
module_ptr,
|
||||
cv.as_mut_ptr(),
|
||||
c_size as u64,
|
||||
c_sl as u64,
|
||||
av.as_ptr(),
|
||||
a_size as u64,
|
||||
a_sl as u64,
|
||||
bv.as_ptr(),
|
||||
b_size as u64,
|
||||
b_sl as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn ffi_binary_op_factory_type_0(
|
||||
module_ptr: *const MODULE,
|
||||
b_size: usize,
|
||||
b_sl: usize,
|
||||
a_size: usize,
|
||||
a_sl: usize,
|
||||
op_fn: unsafe extern "C" fn(*const MODULE, *mut i64, u64, u64, *const i64, u64, u64),
|
||||
) -> impl Fn(&mut [i64], &[i64]) {
|
||||
move |bv: &mut [i64], av: &[i64]| unsafe {
|
||||
op_fn(
|
||||
module_ptr,
|
||||
bv.as_mut_ptr(),
|
||||
b_size as u64,
|
||||
b_sl as u64,
|
||||
av.as_ptr(),
|
||||
a_size as u64,
|
||||
a_sl as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn ffi_binary_op_factory_type_1(
|
||||
module_ptr: *const MODULE,
|
||||
k: i64,
|
||||
b_size: usize,
|
||||
b_sl: usize,
|
||||
a_size: usize,
|
||||
a_sl: usize,
|
||||
op_fn: unsafe extern "C" fn(*const MODULE, i64, *mut i64, u64, u64, *const i64, u64, u64),
|
||||
) -> impl Fn(&mut [i64], &[i64]) {
|
||||
move |bv: &mut [i64], av: &[i64]| unsafe {
|
||||
op_fn(
|
||||
module_ptr,
|
||||
k,
|
||||
bv.as_mut_ptr(),
|
||||
b_size as u64,
|
||||
b_sl as u64,
|
||||
av.as_ptr(),
|
||||
a_size as u64,
|
||||
a_sl as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn vec_znx_apply_binary_op<B: Backend, const NEGATE: bool>(
|
||||
module: &Module<B>,
|
||||
c: &mut VecZnx,
|
||||
a: &VecZnx,
|
||||
b: &VecZnx,
|
||||
op: impl Fn(&mut [i64], &[i64], &[i64]),
|
||||
) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(b.n(), module.n());
|
||||
assert_eq!(c.n(), module.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
|
||||
let a_cols: usize = a.cols();
|
||||
let b_cols: usize = b.cols();
|
||||
let c_cols: usize = c.cols();
|
||||
|
||||
let min_ab_cols: usize = min(a_cols, b_cols);
|
||||
let min_cols: usize = min(c_cols, min_ab_cols);
|
||||
|
||||
// Applies over shared cols between (a, b, c)
|
||||
(0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0)));
|
||||
// Copies/Negates/Zeroes the remaining cols if op is not inplace.
|
||||
if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() {
|
||||
znx_post_process_ternary_op::<VecZnx, NEGATE>(c, a, b);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn vec_znx_apply_unary_op<B: Backend>(module: &Module<B>, b: &mut VecZnx, a: &VecZnx, op: impl Fn(&mut [i64], &[i64])) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert_eq!(b.n(), module.n());
|
||||
}
|
||||
let a_cols: usize = a.cols();
|
||||
let b_cols: usize = b.cols();
|
||||
let min_cols: usize = min(a_cols, b_cols);
|
||||
// Applies over the shared cols between (a, b)
|
||||
(0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0)));
|
||||
// Zeroes the remaining cols of b.
|
||||
(min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j)));
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
Backend, FFT64, Module, Sampling, VecZnx, VecZnxOps, ZnxBasics, ZnxInfos, ZnxLayout, ffi::vec_znx,
|
||||
znx_post_process_ternary_op,
|
||||
};
|
||||
use itertools::izip;
|
||||
use sampling::source::Source;
|
||||
use std::cmp::min;
|
||||
|
||||
#[test]
|
||||
fn vec_znx_add() {
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let op = |cv: &mut [i64], av: &[i64], bv: &[i64]| {
|
||||
izip!(cv.iter_mut(), bv.iter(), av.iter()).for_each(|(ci, bi, ai)| *ci = *bi + *ai);
|
||||
};
|
||||
test_binary_op::<false, _>(
|
||||
&module,
|
||||
&|c: &mut VecZnx, a: &VecZnx, b: &VecZnx| module.vec_znx_add(c, a, b),
|
||||
op,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vec_znx_add_inplace() {
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let op = |bv: &mut [i64], av: &[i64]| {
|
||||
izip!(bv.iter_mut(), av.iter()).for_each(|(bi, ai)| *bi = *bi + *ai);
|
||||
};
|
||||
test_binary_op_inplace::<false, _>(
|
||||
&module,
|
||||
&|b: &mut VecZnx, a: &VecZnx| module.vec_znx_add_inplace(b, a),
|
||||
op,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vec_znx_sub() {
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let op = |cv: &mut [i64], av: &[i64], bv: &[i64]| {
|
||||
izip!(cv.iter_mut(), bv.iter(), av.iter()).for_each(|(ci, bi, ai)| *ci = *bi - *ai);
|
||||
};
|
||||
test_binary_op::<true, _>(
|
||||
&module,
|
||||
&|c: &mut VecZnx, a: &VecZnx, b: &VecZnx| module.vec_znx_sub(c, a, b),
|
||||
op,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vec_znx_sub_ab_inplace() {
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let op = |bv: &mut [i64], av: &[i64]| {
|
||||
izip!(bv.iter_mut(), av.iter()).for_each(|(bi, ai)| *bi = *ai - *bi);
|
||||
};
|
||||
test_binary_op_inplace::<true, _>(
|
||||
&module,
|
||||
&|b: &mut VecZnx, a: &VecZnx| module.vec_znx_sub_ab_inplace(b, a),
|
||||
op,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vec_znx_sub_ba_inplace() {
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let op = |bv: &mut [i64], av: &[i64]| {
|
||||
izip!(bv.iter_mut(), av.iter()).for_each(|(bi, ai)| *bi = *bi - *ai);
|
||||
};
|
||||
test_binary_op_inplace::<false, _>(
|
||||
&module,
|
||||
&|b: &mut VecZnx, a: &VecZnx| module.vec_znx_sub_ba_inplace(b, a),
|
||||
op,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vec_znx_negate() {
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let op = |b: &mut [i64], a: &[i64]| {
|
||||
izip!(b.iter_mut(), a.iter()).for_each(|(bi, ai)| *bi = -*ai);
|
||||
};
|
||||
test_unary_op(
|
||||
&module,
|
||||
|b: &mut VecZnx, a: &VecZnx| module.vec_znx_negate(b, a),
|
||||
op,
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vec_znx_negate_inplace() {
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let op = |a: &mut [i64]| a.iter_mut().for_each(|xi| *xi = -*xi);
|
||||
test_unary_op_inplace(
|
||||
&module,
|
||||
|a: &mut VecZnx| module.vec_znx_negate_inplace(a),
|
||||
op,
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vec_znx_rotate() {
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let k: i64 = 53;
|
||||
let op = |b: &mut [i64], a: &[i64]| {
|
||||
assert_eq!(b.len(), a.len());
|
||||
b.copy_from_slice(a);
|
||||
|
||||
let mut k_mod2n: i64 = k % (2 * n as i64);
|
||||
if k_mod2n < 0 {
|
||||
k_mod2n += 2 * n as i64;
|
||||
}
|
||||
let sign: i64 = (k_mod2n.abs() / (n as i64)) & 1;
|
||||
let k_modn: i64 = k_mod2n % (n as i64);
|
||||
|
||||
b.rotate_right(k_modn as usize);
|
||||
b[0..k_modn as usize].iter_mut().for_each(|x| *x = -*x);
|
||||
|
||||
if sign == 1 {
|
||||
b.iter_mut().for_each(|x| *x = -*x);
|
||||
}
|
||||
};
|
||||
test_unary_op(
|
||||
&module,
|
||||
|b: &mut VecZnx, a: &VecZnx| module.vec_znx_rotate(k, b, a),
|
||||
op,
|
||||
)
|
||||
}
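A hypothetical worked example of the reference rotation checked above: for n = 4 and k = 1, multiplying by X wraps the top coefficient around with a sign flip, since X^4 = -1 modulo X^4 + 1.

fn main() {
    let a: [i64; 4] = [1, 2, 3, 4]; // a(X) = 1 + 2X + 3X^2 + 4X^3
    let rotated: [i64; 4] = [-4, 1, 2, 3]; // a(X) * X mod (X^4 + 1)
    assert_eq!(rotated, [-a[3], a[0], a[1], a[2]]);
}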
|
||||
|
||||
#[test]
|
||||
fn vec_znx_rotate_inplace() {
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let k: i64 = 53;
|
||||
let rot = |a: &mut [i64]| {
|
||||
let mut k_mod2n: i64 = k % (2 * n as i64);
|
||||
if k_mod2n < 0 {
|
||||
k_mod2n += 2 * n as i64;
|
||||
}
|
||||
let sign: i64 = (k_mod2n.abs() / (n as i64)) & 1;
|
||||
let k_modn: i64 = k_mod2n % (n as i64);
|
||||
|
||||
a.rotate_right(k_modn as usize);
|
||||
a[0..k_modn as usize].iter_mut().for_each(|x| *x = -*x);
|
||||
|
||||
if sign == 1 {
|
||||
a.iter_mut().for_each(|x| *x = -*x);
|
||||
}
|
||||
};
|
||||
test_unary_op_inplace(
|
||||
&module,
|
||||
|a: &mut VecZnx| module.vec_znx_rotate_inplace(k, a),
|
||||
rot,
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vec_znx_automorphism() {
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let k: i64 = -5;
|
||||
let op = |b: &mut [i64], a: &[i64]| {
|
||||
assert_eq!(b.len(), a.len());
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
module.ptr,
|
||||
k,
|
||||
b.as_mut_ptr(),
|
||||
1u64,
|
||||
n as u64,
|
||||
a.as_ptr(),
|
||||
1u64,
|
||||
n as u64,
|
||||
);
|
||||
}
|
||||
};
|
||||
test_unary_op(
|
||||
&module,
|
||||
|b: &mut VecZnx, a: &VecZnx| module.vec_znx_automorphism(k, b, a),
|
||||
op,
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vec_znx_automorphism_inplace() {
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let k: i64 = -5;
|
||||
let op = |a: &mut [i64]| unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
module.ptr,
|
||||
k,
|
||||
a.as_mut_ptr(),
|
||||
1u64,
|
||||
n as u64,
|
||||
a.as_ptr(),
|
||||
1u64,
|
||||
n as u64,
|
||||
);
|
||||
};
|
||||
test_unary_op_inplace(
|
||||
&module,
|
||||
|a: &mut VecZnx| module.vec_znx_automorphism_inplace(k, a),
|
||||
op,
|
||||
)
|
||||
}
|
||||
|
||||
fn test_binary_op<const NEGATE: bool, B: Backend>(
|
||||
module: &Module<B>,
|
||||
func_have: impl Fn(&mut VecZnx, &VecZnx, &VecZnx),
|
||||
func_want: impl Fn(&mut [i64], &[i64], &[i64]),
|
||||
) {
|
||||
let a_size: usize = 3;
|
||||
let b_size: usize = 4;
|
||||
let c_size: usize = 5;
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
[1usize, 2, 3].iter().for_each(|a_cols| {
|
||||
[1usize, 2, 3].iter().for_each(|b_cols| {
|
||||
[1usize, 2, 3].iter().for_each(|c_cols| {
|
||||
let min_ab_cols: usize = min(*a_cols, *b_cols);
|
||||
let min_cols: usize = min(*c_cols, min_ab_cols);
|
||||
let min_size: usize = min(c_size, min(a_size, b_size));
|
||||
|
||||
let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size);
|
||||
(0..*a_cols).for_each(|i| {
|
||||
module.fill_uniform(3, &mut a, i, a_size, &mut source);
|
||||
});
|
||||
|
||||
let mut b: VecZnx = module.new_vec_znx(*b_cols, b_size);
|
||||
(0..*b_cols).for_each(|i| {
|
||||
module.fill_uniform(3, &mut b, i, b_size, &mut source);
|
||||
});
|
||||
|
||||
let mut c_have: VecZnx = module.new_vec_znx(*c_cols, c_size);
|
||||
(0..c_have.cols()).for_each(|i| {
|
||||
module.fill_uniform(3, &mut c_have, i, c_size, &mut source);
|
||||
});
|
||||
|
||||
func_have(&mut c_have, &a, &b);
|
||||
|
||||
let mut c_want: VecZnx = module.new_vec_znx(*c_cols, c_size);
|
||||
|
||||
// Adds with the minimum matching columns
|
||||
(0..min_cols).for_each(|i| {
|
||||
// Adds with the minimum matching size
(0..min_size).for_each(|j| {
|
||||
func_want(c_want.at_poly_mut(i, j), b.at_poly(i, j), a.at_poly(i, j));
|
||||
});
|
||||
|
||||
if a_size > b_size {
|
||||
// Copies the remaining sizes of a when a.size() > b.size()
(min_size..a_size).for_each(|j| {
|
||||
izip!(c_want.at_poly_mut(i, j).iter_mut(), a.at_poly(i, j).iter()).for_each(|(ci, ai)| *ci = *ai);
|
||||
if NEGATE {
|
||||
c_want.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// Copies the remaining sizes of b otherwise
(min_size..b_size).for_each(|j| {
|
||||
izip!(c_want.at_poly_mut(i, j).iter_mut(), b.at_poly(i, j).iter()).for_each(|(ci, bi)| *ci = *bi);
|
||||
if NEGATE {
|
||||
c_want.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
znx_post_process_ternary_op::<_, NEGATE>(&mut c_want, &a, &b);
|
||||
|
||||
assert_eq!(c_have.raw(), c_want.raw());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn test_binary_op_inplace<const NEGATE: bool, B: Backend>(
|
||||
module: &Module<B>,
|
||||
func_have: impl Fn(&mut VecZnx, &VecZnx),
|
||||
func_want: impl Fn(&mut [i64], &[i64]),
|
||||
) {
|
||||
let a_size: usize = 3;
|
||||
let b_size: usize = 5;
|
||||
let mut source = Source::new([0u8; 32]);
|
||||
|
||||
[1usize, 2, 3].iter().for_each(|a_cols| {
|
||||
[1usize, 2, 3].iter().for_each(|b_cols| {
|
||||
let min_cols: usize = min(*b_cols, *a_cols);
|
||||
let min_size: usize = min(b_size, a_size);
|
||||
|
||||
let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size);
|
||||
(0..*a_cols).for_each(|i| {
|
||||
module.fill_uniform(3, &mut a, i, a_size, &mut source);
|
||||
});
|
||||
|
||||
let mut b_have: VecZnx = module.new_vec_znx(*b_cols, b_size);
|
||||
(0..*b_cols).for_each(|i| {
|
||||
module.fill_uniform(3, &mut b_have, i, b_size, &mut source);
|
||||
});
|
||||
|
||||
let mut b_want: VecZnx = module.new_vec_znx(*b_cols, b_size);
|
||||
b_want.raw_mut().copy_from_slice(b_have.raw());
|
||||
|
||||
func_have(&mut b_have, &a);
|
||||
|
||||
// Applies with the minimum matching columns
|
||||
(0..min_cols).for_each(|i| {
|
||||
// Adds with the minimum matching size
(0..min_size).for_each(|j| func_want(b_want.at_poly_mut(i, j), a.at_poly(i, j)));
|
||||
if NEGATE {
|
||||
(min_size..b_size).for_each(|j| {
|
||||
b_want.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
assert_eq!(b_have.raw(), b_want.raw());
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn test_unary_op<B: Backend>(
|
||||
module: &Module<B>,
|
||||
func_have: impl Fn(&mut VecZnx, &VecZnx),
|
||||
func_want: impl Fn(&mut [i64], &[i64]),
|
||||
) {
|
||||
let a_size: usize = 3;
|
||||
let b_size: usize = 5;
|
||||
let mut source = Source::new([0u8; 32]);
|
||||
|
||||
[1usize, 2, 3].iter().for_each(|a_cols| {
|
||||
[1usize, 2, 3].iter().for_each(|b_cols| {
|
||||
let min_cols: usize = min(*b_cols, *a_cols);
|
||||
let min_size: usize = min(b_size, a_size);
|
||||
|
||||
let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size);
|
||||
(0..a.cols()).for_each(|i| {
|
||||
module.fill_uniform(3, &mut a, i, a_size, &mut source);
|
||||
});
|
||||
|
||||
let mut b_have: VecZnx = module.new_vec_znx(*b_cols, b_size);
|
||||
(0..b_have.cols()).for_each(|i| {
|
||||
module.fill_uniform(3, &mut b_have, i, b_size, &mut source);
|
||||
});
|
||||
|
||||
let mut b_want: VecZnx = module.new_vec_znx(*b_cols, b_size);
|
||||
|
||||
func_have(&mut b_have, &a);
|
||||
|
||||
// Applies on the minimum matching columns
|
||||
(0..min_cols).for_each(|i| {
|
||||
// Applies on the minimum matching size
|
||||
(0..min_size).for_each(|j| func_want(b_want.at_poly_mut(i, j), a.at_poly(i, j)));
|
||||
|
||||
// Zeroes the non-matching sizes
(min_size..b_size).for_each(|j| {
|
||||
b_want.zero_at(i, j);
|
||||
})
|
||||
});
|
||||
|
||||
// Zeroes the non-matching columns
(min_cols..*b_cols).for_each(|i| {
|
||||
(0..b_size).for_each(|j| {
|
||||
b_want.zero_at(i, j);
|
||||
})
|
||||
});
|
||||
|
||||
assert_eq!(b_have.raw(), b_want.raw());
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn test_unary_op_inplace<B: Backend>(module: &Module<B>, func_have: impl Fn(&mut VecZnx), func_want: impl Fn(&mut [i64])) {
|
||||
let a_size: usize = 3;
|
||||
let mut source = Source::new([0u8; 32]);
|
||||
[1usize, 2, 3].iter().for_each(|a_cols| {
|
||||
let mut a_have: VecZnx = module.new_vec_znx(*a_cols, a_size);
|
||||
(0..*a_cols).for_each(|i| {
|
||||
module.fill_uniform(3, &mut a_have, i, a_size, &mut source);
|
||||
});
|
||||
|
||||
let mut a_want: VecZnx = module.new_vec_znx(*a_cols, a_size);
|
||||
a_want.raw_mut().copy_from_slice(a_have.raw());
|
||||
func_have(&mut a_have);
|
||||
|
||||
// Applies on the minimum matching columns
|
||||
(0..*a_cols).for_each(|i| {
|
||||
// Applies on the minimum matching size
|
||||
(0..a_size).for_each(|j| func_want(a_want.at_poly_mut(i, j)));
|
||||
});
|
||||
|
||||
assert_eq!(a_have.raw(), a_want.raw());
|
||||
});
|
||||
}
|
||||
}