mirror of https://github.com/arnaucube/poulpy.git (synced 2026-02-10 13:16:44 +01:00)
wip: change of approach, enables to select columns on which to operate
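In short: instead of always operating on every column of a multi-column VecZnx, each operation now takes explicit column indices, so a caller can pick which column of each operand is read and which column of the result is written. A minimal sketch of the new calling convention (the signatures come from the diff below; the surrounding setup and names are assumptions):

    // Hypothetical caller, post-commit: res[0] <- a[0] + ct[1]
    let mut res: VecZnx = module.new_vec_znx(2, size);
    module.vec_znx_add(&mut res, 0, &a, 0, &ct, 1);
    // and in place: res[0] <- -res[0]
    module.vec_znx_negate_inplace(&mut res, 0);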
@@ -1,6 +1,6 @@
use base2k::{
Encoding, FFT64, Module, Sampling, Scalar, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
VecZnxDftOps, VecZnxOps, ZnxInfos, ZnxLayout, alloc_aligned,
VecZnxDftOps, VecZnxOps, ZnxInfos, alloc_aligned,
};
use itertools::izip;
use sampling::source::Source;
@@ -13,13 +13,11 @@ fn main() {
let log_scale: usize = msg_size * log_base2k - 5;
let module: Module<FFT64> = Module::<FFT64>::new(n);

let mut carry: Vec<u8> = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes(1));
let mut carry: Vec<u8> = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes(2));

let seed: [u8; 32] = [0; 32];
let mut source: Source = Source::new(seed);

let mut res: VecZnx = module.new_vec_znx(1, ct_size);

// s <- Z_{-1, 0, 1}[X]/(X^{N}+1)
let mut s: Scalar = Scalar::new(n);
s.fill_ternary_prob(0.5, &mut source);
@@ -30,47 +28,50 @@ fn main() {
// s_ppol <- DFT(s)
module.svp_prepare(&mut s_ppol, &s);

// a <- Z_{2^prec}[X]/(X^{N}+1)
let mut a: VecZnx = module.new_vec_znx(1, ct_size);
module.fill_uniform(log_base2k, &mut a, 0, ct_size, &mut source);
// ct = (c0, c1)
let mut ct: VecZnx = module.new_vec_znx(2, ct_size);

// Fill c1 with random values
module.fill_uniform(log_base2k, &mut ct, 1, ct_size, &mut source);

// Scratch space for DFT values
let mut buf_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, a.size());
let mut buf_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, ct.size());

// Applies buf_dft <- s * a
module.svp_apply_dft(&mut buf_dft, &s_ppol, &a);
// Applies buf_dft <- s * c1
module.svp_apply_dft(
&mut buf_dft, // DFT(c1 * s)
&s_ppol,
&ct,
1, // c1
);

// Alias scratch space
// Alias scratch space (VecZnxDft is always at least as big as VecZnxBig)
let mut buf_big: VecZnxBig<FFT64> = buf_dft.as_vec_znx_big();

// buf_big <- IDFT(buf_dft) (not normalized)
// BIG(c1 * s) <- IDFT(DFT(c1 * s)) (not normalized)
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);

println!("{:?}", buf_big.raw());

// m <- (0)
let mut m: VecZnx = module.new_vec_znx(1, msg_size);

let mut want: Vec<i64> = vec![0; n];
want.iter_mut()
.for_each(|x| *x = source.next_u64n(16, 15) as i64);

// m
m.encode_vec_i64(0, log_base2k, log_scale, &want, 4);
m.normalize(log_base2k, &mut carry);

// buf_big <- m - buf_big
// m - BIG(c1 * s)
module.vec_znx_big_sub_small_ab_inplace(&mut buf_big, &m);

println!("{:?}", buf_big.raw());
// c0 <- m - BIG(c1 * s)
module.vec_znx_big_normalize(log_base2k, &mut ct, &buf_big, &mut carry);

// b <- normalize(buf_big) + e
let mut b: VecZnx = module.new_vec_znx(1, ct_size);
module.vec_znx_big_normalize(log_base2k, &mut b, &buf_big, &mut carry);
b.print(n);
ct.print(ct.sl());

// (c0 + e, c1)
module.add_normal(
log_base2k,
&mut b,
0,
&mut ct,
0, // c0
log_base2k * ct_size,
&mut source,
3.2,
@@ -79,16 +80,16 @@ fn main() {

// Decrypt

// buf_big <- a * s
module.svp_apply_dft(&mut buf_dft, &s_ppol, &a);
// DFT(c1 * s)
module.svp_apply_dft(&mut buf_dft, &s_ppol, &ct, 1);
// BIG(c1 * s) = IDFT(DFT(c1 * s))
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);

// buf_big <- a * s + b
module.vec_znx_big_add_small_inplace(&mut buf_big, &b);
// BIG(c1 * s) + c0
module.vec_znx_big_add_small_inplace(&mut buf_big, &ct);

println!("raw: {:?}", &buf_big.raw());

// res <- normalize(buf_big)
// m + e <- BIG(c1 * s + c0)
let mut res: VecZnx = module.new_vec_znx(1, ct_size);
module.vec_znx_big_normalize(log_base2k, &mut res, &buf_big, &mut carry);

// have = m * 2^{log_scale} + e
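For reference, the round trip the updated example exercises: c1 is sampled uniformly, c0 is set to m - c1*s + e, and the two polynomials are stored as the two columns of a single VecZnx ct; decryption then computes c0 + c1*s = m + e, and the final vec_znx_big_normalize brings that sum back to base 2^log_base2k so the message sits at scale 2^log_scale plus noise.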
@@ -81,12 +81,12 @@ pub trait ZnxLayout: ZnxInfos {
}

/// Returns non-mutable reference to the (i, j)-th small polynomial.
fn at_poly(&self, i: usize, j: usize) -> &[Self::Scalar] {
fn at(&self, i: usize, j: usize) -> &[Self::Scalar] {
unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) }
}

/// Returns mutable reference to the (i, j)-th small polynomial.
fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] {
fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] {
unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) }
}

@@ -219,7 +219,7 @@ pub fn rsh_tmp_bytes<T: IntegerType>(n: usize, cols: usize) -> usize {
n * cols * std::mem::size_of::<T>()
}

pub fn switch_degree<T: ZnxLayout + ZnxBasics>(b: &mut T, a: &T)
pub fn switch_degree<T: ZnxLayout + ZnxBasics>(b: &mut T, col_b: usize, a: &T, col_a: usize)
where
<T as ZnxLayout>::Scalar: IntegerType,
{
@@ -237,8 +237,8 @@ where

(0..size).for_each(|i| {
izip!(
a.at_limb(i).iter().step_by(gap_in),
b.at_limb_mut(i).iter_mut().step_by(gap_out)
a.at(col_a, i).iter().step_by(gap_in),
b.at_mut(col_b, i).iter_mut().step_by(gap_out)
)
.for_each(|(x_in, x_out)| *x_out = *x_in);
});
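The rename from at_poly/at_poly_mut to at/at_mut keeps the same (i, j) addressing: the returned slice holds the n coefficients of the j-th limb of the i-th column, which is what lets the new switch_degree walk one selected column of each operand. A minimal usage sketch (the VecZnx values and their shapes are assumed):

    // Hypothetical: read column 1, limb 0 of `a`; zero column 0, limb 2 of `b`.
    let coeffs: &[i64] = a.at(1, 0);                 // n coefficients
    b.at_mut(0, 2).iter_mut().for_each(|x| *x = 0);  // zero one small polynomial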
@@ -2,102 +2,6 @@ use std::cmp::{max, min};

use crate::{Backend, IntegerType, Module, ZnxBasics, ZnxLayout, ffi::module::MODULE};

pub(crate) fn znx_post_process_ternary_op<C, A, B, const NEGATE: bool>(c: &mut C, a: &A, b: &B)
where
C: ZnxBasics + ZnxLayout,
A: ZnxBasics + ZnxLayout<Scalar = C::Scalar>,
B: ZnxBasics + ZnxLayout<Scalar = C::Scalar>,
C::Scalar: IntegerType,
{
#[cfg(debug_assertions)]
{
assert_ne!(a.as_ptr(), b.as_ptr());
assert_ne!(b.as_ptr(), c.as_ptr());
assert_ne!(a.as_ptr(), c.as_ptr());
}

let a_cols: usize = a.cols();
let b_cols: usize = b.cols();
let c_cols: usize = c.cols();

let min_ab_cols: usize = min(a_cols, b_cols);
let max_ab_cols: usize = max(a_cols, b_cols);

// Copies shared cols between (c, max(a, b))
if a_cols != b_cols {
if a_cols > b_cols {
let min_size = min(c.size(), a.size());
(min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| {
(0..min_size).for_each(|j| {
c.at_poly_mut(i, j).copy_from_slice(a.at_poly(i, j));
if NEGATE {
c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
}
});
(min_size..c.size()).for_each(|j| {
c.zero_at(i, j);
});
});
} else {
let min_size = min(c.size(), b.size());
(min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| {
(0..min_size).for_each(|j| {
c.at_poly_mut(i, j).copy_from_slice(b.at_poly(i, j));
if NEGATE {
c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
}
});
(min_size..c.size()).for_each(|j| {
c.zero_at(i, j);
});
});
}
}

// Zeroes the cols of c > max(a, b).
if c_cols > max_ab_cols {
(max_ab_cols..c_cols).for_each(|i| {
(0..c.size()).for_each(|j| {
c.zero_at(i, j);
})
});
}
}

#[inline(always)]
pub fn apply_binary_op<BE, C, A, B, const NEGATE: bool>(
module: &Module<BE>,
c: &mut C,
a: &A,
b: &B,
op: impl Fn(&mut [C::Scalar], &[A::Scalar], &[B::Scalar]),
) where
BE: Backend,
C: ZnxBasics + ZnxLayout,
A: ZnxBasics + ZnxLayout<Scalar = C::Scalar>,
B: ZnxBasics + ZnxLayout<Scalar = C::Scalar>,
C::Scalar: IntegerType,
{
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
assert_eq!(c.n(), module.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
let a_cols: usize = a.cols();
let b_cols: usize = b.cols();
let c_cols: usize = c.cols();
let min_ab_cols: usize = min(a_cols, b_cols);
let min_cols: usize = min(c_cols, min_ab_cols);
// Applies over shared cols between (a, b, c)
(0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0)));
// Copies/Negates/Zeroes the remaining cols if op is not inplace.
if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() {
znx_post_process_ternary_op::<C, A, B, NEGATE>(c, a, b);
}
}

#[inline(always)]
pub fn apply_unary_op<B: Backend, T: ZnxBasics + ZnxLayout>(
module: &Module<B>,

@@ -230,7 +230,7 @@ pub trait ScalarZnxDftOps<B: Backend> {

/// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of
/// the [VecZnxDft] is multiplied with [SvpPPol].
fn svp_apply_dft(&self, c: &mut VecZnxDft<B>, a: &ScalarZnxDft<B>, b: &VecZnx);
fn svp_apply_dft(&self, c: &mut VecZnxDft<B>, a: &ScalarZnxDft<B>, b: &VecZnx, b_col: usize);
}

impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
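With the added b_col parameter, the scalar-vector product can target a single column of a multi-column VecZnx; the updated example uses exactly this to multiply only the c1 column of the ciphertext:

    // From the example above: buf_dft <- DFT(s) * ct[1], i.e. DFT(c1 * s)
    module.svp_apply_dft(&mut buf_dft, &s_ppol, &ct, 1);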
@@ -261,16 +261,16 @@ impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
unsafe { svp::svp_prepare(self.ptr, svp_ppol.ptr as *mut svp_ppol_t, a.as_ptr()) }
}

fn svp_apply_dft(&self, c: &mut VecZnxDft<FFT64>, a: &ScalarZnxDft<FFT64>, b: &VecZnx) {
fn svp_apply_dft(&self, c: &mut VecZnxDft<FFT64>, a: &ScalarZnxDft<FFT64>, b: &VecZnx, b_col: usize) {
unsafe {
svp::svp_apply_dft(
self.ptr,
c.ptr as *mut vec_znx_dft_t,
c.cols() as u64,
c.size() as u64,
a.ptr as *const svp_ppol_t,
b.as_ptr(),
b.cols() as u64,
b.n() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}

@@ -193,8 +193,8 @@ impl VecZnx {
normalize(log_base2k, self, carry)
}

pub fn switch_degree(&self, a: &mut Self) {
switch_degree(a, self)
pub fn switch_degree(&self, col: usize, a: &mut Self, col_a: usize) {
switch_degree(a, col_a, self, col)
}

// Prints the first `n` coefficients of each limb

@@ -232,9 +232,9 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
}

let a_size: usize = a.size();
let b_size: usize = b.sl();
let a_sl: usize = a.size();
let b_sl: usize = a.sl();
let b_size: usize = b.size();
let a_sl: usize = a.sl();
let b_sl: usize = b.sl();
let a_cols: usize = a.cols();
let b_cols: usize = b.cols();
let min_cols: usize = min(a_cols, b_cols);

@@ -1,9 +1,5 @@
use std::cmp::min;

use crate::ffi::module::MODULE;
use crate::ffi::vec_znx;
use crate::internals::{apply_binary_op, apply_unary_op, ffi_binary_op_factory_type_0, ffi_binary_op_factory_type_1};
use crate::{Backend, Module, VecZnx, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, assert_alignement, switch_degree};
use crate::{Backend, Module, VecZnx, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement, switch_degree};
pub trait VecZnxOps {
/// Allocates a new [VecZnx].
///
@@ -43,62 +39,70 @@ pub trait VecZnxOps {
fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize;

/// Returns the minimum number of bytes necessary for normalization.
fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize;
fn vec_znx_normalize_tmp_bytes(&self) -> usize;

/// Normalizes `a` and stores the result into `b`.
fn vec_znx_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnx, tmp_bytes: &mut [u8]);
/// Normalizes the selected column of `a` and stores the result into the selected column of `res`.
fn vec_znx_normalize(
&self,
log_base2k: usize,
res: &mut VecZnx,
col_res: usize,
a: &VecZnx,
col_a: usize,
tmp_bytes: &mut [u8],
);

/// Normalizes `a` and stores the result into `a`.
fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]);
/// Normalizes the selected column of `a`.
fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, tmp_bytes: &mut [u8]);

/// Adds `a` to `b` and writes the result to `c`.
fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx);
/// Adds the selected column of `a` to the selected column of `b` and writes the result to the selected column of `res`.
fn vec_znx_add(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize);

/// Adds `a` to `b` and writes the result to `b`.
fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx);
/// Adds the selected column of `a` to the selected column of `res`, in place.
fn vec_znx_add_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize);

/// Subtracts `b` from `a` and writes the result to `c`.
fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx);
/// Subtracts the selected column of `b` from the selected column of `a` and writes the result to the selected column of `res`.
fn vec_znx_sub(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize);

/// Subtracts `b` from `a` and writes the result to `b` (b <- a - b).
fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx);
/// Subtracts the selected column of `res` from the selected column of `a` and writes the result to `res` (res <- a - res).
fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize);

/// Subtracts `a` from `b` and writes the result to `b` (b <- b - a).
fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx);
/// Subtracts the selected column of `a` from the selected column of `res` (res <- res - a).
fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize);

/// Negates `a` and stores the result in `b`.
fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx);
/// Negates the selected column of `a` and stores the result in the selected column of `res`.
fn vec_znx_negate(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize);

/// Negates `a` and stores the result in `a`.
fn vec_znx_negate_inplace(&self, a: &mut VecZnx);
/// Negates the selected column of `a`.
fn vec_znx_negate_inplace(&self, a: &mut VecZnx, col_a: usize);

/// Multiplies `a` by X^k and stores the result in `b`.
fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx);
/// Multiplies the selected column of `a` by X^k and stores the result in the selected column of `res`.
fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize);

/// Multiplies `a` by X^k and stores the result in `a`.
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx);
/// Multiplies the selected column of `a` by X^k.
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize);

/// Applies the automorphism X^i -> X^ik on `a` and stores the result in `b`.
fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx);
/// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result in the selected column of `res`.
fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize);

/// Applies the automorphism X^i -> X^ik on `a` and stores the result in `a`.
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx);
/// Applies the automorphism X^i -> X^ik on the selected column of `a`.
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize);

/// Splits `b` into subrings and copies them into `a`.
/// Splits the selected column of `a` into subrings and copies them into the selected column of `res`.
///
/// # Panics
///
/// This method requires that all [VecZnx] of b have the same ring degree
/// and that b.n() * b.len() <= a.n()
fn vec_znx_split(&self, b: &mut Vec<VecZnx>, a: &VecZnx, buf: &mut VecZnx);
fn vec_znx_split(&self, res: &mut Vec<VecZnx>, col_res: usize, a: &VecZnx, col_a: usize, buf: &mut VecZnx);

/// Merges the subrings `a` into `b`.
/// Merges the subrings of the selected column of `a` into the selected column of `res`.
///
/// # Panics
///
/// This method requires that all [VecZnx] of a have the same ring degree
/// and that a.n() * a.len() <= b.n()
fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec<VecZnx>);
fn vec_znx_merge(&self, res: &mut VecZnx, col_res: usize, a: &Vec<VecZnx>, col_a: usize);
}

impl<B: Backend> VecZnxOps for Module<B> {
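A design point visible throughout the implementations that follow: the *_inplace variants do not duplicate the FFI plumbing; they cast the target to a raw pointer and call the out-of-place form with the same VecZnx as both destination and source, which presumes the underlying C routines accept aliased operands. The pattern, as in vec_znx_add_inplace below:

    // res[col_res] <- res[col_res] + a[col_a], aliasing res as output and input.
    let res_ptr: *mut VecZnx = res as *mut VecZnx;
    unsafe {
        Self::vec_znx_add(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res);
    }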
@@ -118,164 +122,213 @@ impl<B: Backend> VecZnxOps for Module<B> {
VecZnx::from_bytes_borrow(self, cols, size, tmp_bytes)
}

fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize {
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize * cols }
fn vec_znx_normalize_tmp_bytes(&self) -> usize {
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize }
}

fn vec_znx_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnx, tmp_bytes: &mut [u8]) {
fn vec_znx_normalize(
&self,
log_base2k: usize,
res: &mut VecZnx,
col_res: usize,
a: &VecZnx,
col_a: usize,
tmp_bytes: &mut [u8],
) {
#[cfg(debug_assertions)]
{
assert!(tmp_bytes.len() >= Self::vec_znx_normalize_tmp_bytes(&self, a.cols()));
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
assert!(tmp_bytes.len() >= Self::vec_znx_normalize_tmp_bytes(&self));
assert_alignement(tmp_bytes.as_ptr());
}

let a_size: usize = a.size();
let b_size: usize = b.sl();
let a_sl: usize = a.size();
let b_sl: usize = a.sl();
let a_cols: usize = a.cols();
let b_cols: usize = b.cols();
let min_cols: usize = min(a_cols, b_cols);
(0..min_cols).for_each(|i| unsafe {
unsafe {
vec_znx::vec_znx_normalize_base2k(
self.ptr,
log_base2k as u64,
b.at_mut_ptr(i, 0),
b_size as u64,
b_sl as u64,
a.at_ptr(i, 0),
a_size as u64,
a_sl as u64,
res.at_mut_ptr(col_res, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(col_a, 0),
a.size() as u64,
a.sl() as u64,
tmp_bytes.as_mut_ptr(),
);
});

(min_cols..b_cols).for_each(|i| (0..b_size).for_each(|j| b.zero_at(i, j)));
}
}

fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) {
fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, tmp_bytes: &mut [u8]) {
unsafe {
let a_ptr: *mut VecZnx = a as *mut VecZnx;
Self::vec_znx_normalize(self, log_base2k, &mut *a_ptr, &*a_ptr, tmp_bytes);
}
}

fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) {
let op = ffi_ternary_op_factory(
self.ptr,
c.size(),
c.sl(),
a.size(),
a.sl(),
b.size(),
b.sl(),
vec_znx::vec_znx_add,
Self::vec_znx_normalize(
self,
log_base2k,
&mut *a_ptr,
col_a,
&*a_ptr,
col_a,
tmp_bytes,
);
apply_binary_op::<B, VecZnx, VecZnx, VecZnx, false>(self, c, a, b, op);
}
}

fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
fn vec_znx_add(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(res.n(), self.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
let b_ptr: *mut VecZnx = b as *mut VecZnx;
Self::vec_znx_add(self, &mut *b_ptr, a, &*b_ptr);
}
}

fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) {
let op = ffi_ternary_op_factory(
vec_znx::vec_znx_add(
self.ptr,
c.size(),
c.sl(),
a.size(),
a.sl(),
b.size(),
b.sl(),
vec_znx::vec_znx_sub,
);
apply_binary_op::<B, VecZnx, VecZnx, VecZnx, true>(self, c, a, b, op);
res.at_mut_ptr(col_res, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(col_a, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(col_b, 0),
b.size() as u64,
b.sl() as u64,
)
}
}

fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
fn vec_znx_add_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) {
unsafe {
let b_ptr: *mut VecZnx = b as *mut VecZnx;
Self::vec_znx_sub(self, &mut *b_ptr, a, &*b_ptr);
let res_ptr: *mut VecZnx = res as *mut VecZnx;
Self::vec_znx_add(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res);
}
}

fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx) {
fn vec_znx_sub(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(res.n(), self.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
let b_ptr: *mut VecZnx = b as *mut VecZnx;
Self::vec_znx_sub(self, &mut *b_ptr, &*b_ptr, a);
}
}

fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) {
let op = ffi_binary_op_factory_type_0(
vec_znx::vec_znx_sub(
self.ptr,
b.size(),
b.sl(),
a.size(),
a.sl(),
vec_znx::vec_znx_negate,
);
apply_unary_op::<B, VecZnx>(self, b, a, op);
res.at_mut_ptr(col_res, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(col_a, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(col_b, 0),
b.size() as u64,
b.sl() as u64,
)
}
}

fn vec_znx_negate_inplace(&self, a: &mut VecZnx) {
fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) {
unsafe {
let res_ptr: *mut VecZnx = res as *mut VecZnx;
Self::vec_znx_sub(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res);
}
}

fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) {
unsafe {
let res_ptr: *mut VecZnx = res as *mut VecZnx;
Self::vec_znx_sub(self, &mut *res_ptr, col_res, &*res_ptr, col_res, a, col_a);
}
}

fn vec_znx_negate(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_negate(
self.ptr,
res.at_mut_ptr(col_res, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(col_a, 0),
a.size() as u64,
a.sl() as u64,
)
}
}

fn vec_znx_negate_inplace(&self, a: &mut VecZnx, col_a: usize) {
unsafe {
let a_ptr: *mut VecZnx = a as *mut VecZnx;
Self::vec_znx_negate(self, &mut *a_ptr, &*a_ptr);
Self::vec_znx_negate(self, &mut *a_ptr, col_a, &*a_ptr, col_a);
}
}

fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx) {
let op = ffi_binary_op_factory_type_1(
fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_rotate(
self.ptr,
k,
b.size(),
b.sl(),
a.size(),
a.sl(),
vec_znx::vec_znx_rotate,
);
apply_unary_op::<B, VecZnx>(self, b, a, op);
res.at_mut_ptr(col_res, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(col_a, 0),
a.size() as u64,
a.sl() as u64,
)
}
}

fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx) {
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize) {
unsafe {
let a_ptr: *mut VecZnx = a as *mut VecZnx;
Self::vec_znx_rotate(self, k, &mut *a_ptr, &*a_ptr);
Self::vec_znx_rotate(self, k, &mut *a_ptr, col_a, &*a_ptr, col_a);
}
}

fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) {
let op = ffi_binary_op_factory_type_1(
fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_automorphism(
self.ptr,
k,
b.size(),
b.sl(),
a.size(),
a.sl(),
vec_znx::vec_znx_automorphism,
);
apply_unary_op::<B, VecZnx>(self, b, a, op);
res.at_mut_ptr(col_res, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(col_a, 0),
a.size() as u64,
a.sl() as u64,
)
}
}

fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) {
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize) {
unsafe {
let a_ptr: *mut VecZnx = a as *mut VecZnx;
Self::vec_znx_automorphism(self, k, &mut *a_ptr, &*a_ptr);
Self::vec_znx_automorphism(self, k, &mut *a_ptr, col_a, &*a_ptr, col_a);
}
}

fn vec_znx_split(&self, b: &mut Vec<VecZnx>, a: &VecZnx, buf: &mut VecZnx) {
let (n_in, n_out) = (a.n(), b[0].n());
fn vec_znx_split(&self, res: &mut Vec<VecZnx>, col_res: usize, a: &VecZnx, col_a: usize, buf: &mut VecZnx) {
let (n_in, n_out) = (a.n(), res[0].n());

debug_assert!(
n_out < n_in,
"invalid a: output ring degree should be smaller"
);
b[1..].iter().for_each(|bi| {
res[1..].iter().for_each(|bi| {
debug_assert_eq!(
bi.n(),
n_out,
@@ -283,19 +336,19 @@ impl<B: Backend> VecZnxOps for Module<B> {
)
});

b.iter_mut().enumerate().for_each(|(i, bi)| {
res.iter_mut().enumerate().for_each(|(i, bi)| {
if i == 0 {
switch_degree(bi, a);
self.vec_znx_rotate(-1, buf, a);
switch_degree(bi, col_res, a, col_a);
self.vec_znx_rotate(-1, buf, 0, a, col_a);
} else {
switch_degree(bi, buf);
self.vec_znx_rotate_inplace(-1, buf);
switch_degree(bi, col_res, buf, col_a);
self.vec_znx_rotate_inplace(-1, buf, col_a);
}
})
}

fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec<VecZnx>) {
let (n_in, n_out) = (b.n(), a[0].n());
fn vec_znx_merge(&self, res: &mut VecZnx, col_res: usize, a: &Vec<VecZnx>, col_a: usize) {
let (n_in, n_out) = (res.n(), a[0].n());

debug_assert!(
n_out < n_in,
@@ -310,456 +363,10 @@ impl<B: Backend> VecZnxOps for Module<B> {
});

a.iter().enumerate().for_each(|(_, ai)| {
switch_degree(b, ai);
self.vec_znx_rotate_inplace(-1, b);
switch_degree(res, col_res, ai, col_a);
self.vec_znx_rotate_inplace(-1, res, col_res);
});

self.vec_znx_rotate_inplace(a.len() as i64, b);
}
}

fn ffi_ternary_op_factory(
module_ptr: *const MODULE,
c_size: usize,
c_sl: usize,
a_size: usize,
a_sl: usize,
b_size: usize,
b_sl: usize,
op_fn: unsafe extern "C" fn(*const MODULE, *mut i64, u64, u64, *const i64, u64, u64, *const i64, u64, u64),
) -> impl Fn(&mut [i64], &[i64], &[i64]) {
move |cv: &mut [i64], av: &[i64], bv: &[i64]| unsafe {
op_fn(
module_ptr,
cv.as_mut_ptr(),
c_size as u64,
c_sl as u64,
av.as_ptr(),
a_size as u64,
a_sl as u64,
bv.as_ptr(),
b_size as u64,
b_sl as u64,
)
}
}

#[cfg(test)]
mod tests {
use crate::internals::znx_post_process_ternary_op;
use crate::{Backend, FFT64, Module, Sampling, VecZnx, VecZnxOps, ZnxBasics, ZnxInfos, ZnxLayout, ffi::vec_znx};

use itertools::izip;
use sampling::source::Source;
use std::cmp::min;

#[test]
fn vec_znx_add() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let op = |cv: &mut [i64], av: &[i64], bv: &[i64]| {
izip!(cv.iter_mut(), bv.iter(), av.iter()).for_each(|(ci, bi, ai)| *ci = *bi + *ai);
};
test_binary_op::<false, _>(
&module,
&|c: &mut VecZnx, a: &VecZnx, b: &VecZnx| module.vec_znx_add(c, a, b),
op,
);
}

#[test]
fn vec_znx_add_inplace() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let op = |bv: &mut [i64], av: &[i64]| {
izip!(bv.iter_mut(), av.iter()).for_each(|(bi, ai)| *bi = *bi + *ai);
};
test_binary_op_inplace::<false, _>(
&module,
&|b: &mut VecZnx, a: &VecZnx| module.vec_znx_add_inplace(b, a),
op,
);
}

#[test]
fn vec_znx_sub() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let op = |cv: &mut [i64], av: &[i64], bv: &[i64]| {
izip!(cv.iter_mut(), bv.iter(), av.iter()).for_each(|(ci, bi, ai)| *ci = *bi - *ai);
};
test_binary_op::<true, _>(
&module,
&|c: &mut VecZnx, a: &VecZnx, b: &VecZnx| module.vec_znx_sub(c, a, b),
op,
);
}

#[test]
fn vec_znx_sub_ab_inplace() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let op = |bv: &mut [i64], av: &[i64]| {
izip!(bv.iter_mut(), av.iter()).for_each(|(bi, ai)| *bi = *ai - *bi);
};
test_binary_op_inplace::<true, _>(
&module,
&|b: &mut VecZnx, a: &VecZnx| module.vec_znx_sub_ab_inplace(b, a),
op,
);
}

#[test]
fn vec_znx_sub_ba_inplace() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let op = |bv: &mut [i64], av: &[i64]| {
izip!(bv.iter_mut(), av.iter()).for_each(|(bi, ai)| *bi = *bi - *ai);
};
test_binary_op_inplace::<false, _>(
&module,
&|b: &mut VecZnx, a: &VecZnx| module.vec_znx_sub_ba_inplace(b, a),
op,
);
}

#[test]
fn vec_znx_negate() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let op = |b: &mut [i64], a: &[i64]| {
izip!(b.iter_mut(), a.iter()).for_each(|(bi, ai)| *bi = -*ai);
};
test_unary_op(
&module,
|b: &mut VecZnx, a: &VecZnx| module.vec_znx_negate(b, a),
op,
)
}

#[test]
fn vec_znx_negate_inplace() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let op = |a: &mut [i64]| a.iter_mut().for_each(|xi| *xi = -*xi);
test_unary_op_inplace(
&module,
|a: &mut VecZnx| module.vec_znx_negate_inplace(a),
op,
)
}

#[test]
fn vec_znx_rotate() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let k: i64 = 53;
let op = |b: &mut [i64], a: &[i64]| {
assert_eq!(b.len(), a.len());
b.copy_from_slice(a);

let mut k_mod2n: i64 = k % (2 * n as i64);
if k_mod2n < 0 {
k_mod2n += 2 * n as i64;
}
let sign: i64 = (k_mod2n.abs() / (n as i64)) & 1;
let k_modn: i64 = k_mod2n % (n as i64);

b.rotate_right(k_modn as usize);
b[0..k_modn as usize].iter_mut().for_each(|x| *x = -*x);

if sign == 1 {
b.iter_mut().for_each(|x| *x = -*x);
}
};
test_unary_op(
&module,
|b: &mut VecZnx, a: &VecZnx| module.vec_znx_rotate(k, b, a),
op,
)
}

#[test]
fn vec_znx_rotate_inplace() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let k: i64 = 53;
let rot = |a: &mut [i64]| {
let mut k_mod2n: i64 = k % (2 * n as i64);
if k_mod2n < 0 {
k_mod2n += 2 * n as i64;
}
let sign: i64 = (k_mod2n.abs() / (n as i64)) & 1;
let k_modn: i64 = k_mod2n % (n as i64);

a.rotate_right(k_modn as usize);
a[0..k_modn as usize].iter_mut().for_each(|x| *x = -*x);

if sign == 1 {
a.iter_mut().for_each(|x| *x = -*x);
}
};
test_unary_op_inplace(
&module,
|a: &mut VecZnx| module.vec_znx_rotate_inplace(k, a),
rot,
)
}

#[test]
fn vec_znx_automorphism() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let k: i64 = -5;
let op = |b: &mut [i64], a: &[i64]| {
assert_eq!(b.len(), a.len());
unsafe {
vec_znx::vec_znx_automorphism(
module.ptr,
k,
b.as_mut_ptr(),
1u64,
n as u64,
a.as_ptr(),
1u64,
n as u64,
);
}
};
test_unary_op(
&module,
|b: &mut VecZnx, a: &VecZnx| module.vec_znx_automorphism(k, b, a),
op,
)
}

#[test]
fn vec_znx_automorphism_inplace() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let k: i64 = -5;
let op = |a: &mut [i64]| unsafe {
vec_znx::vec_znx_automorphism(
module.ptr,
k,
a.as_mut_ptr(),
1u64,
n as u64,
a.as_ptr(),
1u64,
n as u64,
);
};
test_unary_op_inplace(
&module,
|a: &mut VecZnx| module.vec_znx_automorphism_inplace(k, a),
op,
)
}

fn test_binary_op<const NEGATE: bool, B: Backend>(
module: &Module<B>,
func_have: impl Fn(&mut VecZnx, &VecZnx, &VecZnx),
func_want: impl Fn(&mut [i64], &[i64], &[i64]),
) {
let a_size: usize = 3;
let b_size: usize = 4;
let c_size: usize = 5;
let mut source: Source = Source::new([0u8; 32]);

[1usize, 2, 3].iter().for_each(|a_cols| {
[1usize, 2, 3].iter().for_each(|b_cols| {
[1usize, 2, 3].iter().for_each(|c_cols| {
let min_ab_cols: usize = min(*a_cols, *b_cols);
let min_cols: usize = min(*c_cols, min_ab_cols);
let min_size: usize = min(c_size, min(a_size, b_size));

// Allocates a and populates it with random values.
let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size);
(0..*a_cols).for_each(|i| {
module.fill_uniform(3, &mut a, i, a_size, &mut source);
});

// Allocates b and populates it with random values.
let mut b: VecZnx = module.new_vec_znx(*b_cols, b_size);
(0..*b_cols).for_each(|i| {
module.fill_uniform(3, &mut b, i, b_size, &mut source);
});

// Allocates c and populates it with random values.
let mut c_have: VecZnx = module.new_vec_znx(*c_cols, c_size);
(0..c_have.cols()).for_each(|i| {
module.fill_uniform(3, &mut c_have, i, c_size, &mut source);
});

// Applies the function to test
func_have(&mut c_have, &a, &b);

let mut c_want: VecZnx = module.new_vec_znx(*c_cols, c_size);

// Applies the reference function and expected behavior.
// Adds with the minimum matching columns
(0..min_cols).for_each(|i| {
// Adds with the minimum matching size
(0..min_size).for_each(|j| {
func_want(c_want.at_poly_mut(i, j), b.at_poly(i, j), a.at_poly(i, j));
});

if a_size > b_size {
// Copies remaining size of lh if lh.size() > rh.size()
(min_size..a_size).for_each(|j| {
izip!(c_want.at_poly_mut(i, j).iter_mut(), a.at_poly(i, j).iter()).for_each(|(ci, ai)| *ci = *ai);
if NEGATE {
c_want.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
}
});
} else {
// Copies the remaining size of rh if they are greater
(min_size..b_size).for_each(|j| {
izip!(c_want.at_poly_mut(i, j).iter_mut(), b.at_poly(i, j).iter()).for_each(|(ci, bi)| *ci = *bi);
if NEGATE {
c_want.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
}
});
}
});

znx_post_process_ternary_op::<VecZnx, VecZnx, VecZnx, NEGATE>(&mut c_want, &a, &b);

assert_eq!(c_have.raw(), c_want.raw());
});
});
});
}

fn test_binary_op_inplace<const NEGATE: bool, B: Backend>(
module: &Module<B>,
func_have: impl Fn(&mut VecZnx, &VecZnx),
func_want: impl Fn(&mut [i64], &[i64]),
) {
let a_size: usize = 3;
let b_size: usize = 5;
let mut source = Source::new([0u8; 32]);

[1usize, 2, 3].iter().for_each(|a_cols| {
[1usize, 2, 3].iter().for_each(|b_cols| {
let min_cols: usize = min(*b_cols, *a_cols);
let min_size: usize = min(b_size, a_size);

// Allocates a and populates it with random values.
let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size);
(0..*a_cols).for_each(|i| {
module.fill_uniform(3, &mut a, i, a_size, &mut source);
});

// Allocates b and populates it with random values.
let mut b_have: VecZnx = module.new_vec_znx(*b_cols, b_size);
(0..*b_cols).for_each(|i| {
module.fill_uniform(3, &mut b_have, i, b_size, &mut source);
});

let mut b_want: VecZnx = module.new_vec_znx(*b_cols, b_size);
b_want.raw_mut().copy_from_slice(b_have.raw());

// Applies the function to test.
func_have(&mut b_have, &a);

// Applies the reference function and expected behavior.
// Applies with the minimum matching columns
(0..min_cols).for_each(|i| {
// Adds with the minimum matching size
(0..min_size).for_each(|j| func_want(b_want.at_poly_mut(i, j), a.at_poly(i, j)));
if NEGATE {
(min_size..b_size).for_each(|j| {
b_want.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x);
});
}
});

assert_eq!(b_have.raw(), b_want.raw());
});
});
}

fn test_unary_op<B: Backend>(
module: &Module<B>,
func_have: impl Fn(&mut VecZnx, &VecZnx),
func_want: impl Fn(&mut [i64], &[i64]),
) {
let a_size: usize = 3;
let b_size: usize = 5;
let mut source = Source::new([0u8; 32]);

[1usize, 2, 3].iter().for_each(|a_cols| {
[1usize, 2, 3].iter().for_each(|b_cols| {
let min_cols: usize = min(*b_cols, *a_cols);
let min_size: usize = min(b_size, a_size);

// Allocates a and populates it with random values.
let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size);
(0..a.cols()).for_each(|i| {
module.fill_uniform(3, &mut a, i, a_size, &mut source);
});

// Allocates b and populates it with random values.
let mut b_have: VecZnx = module.new_vec_znx(*b_cols, b_size);
(0..b_have.cols()).for_each(|i| {
module.fill_uniform(3, &mut b_have, i, b_size, &mut source);
});

let mut b_want: VecZnx = module.new_vec_znx(*b_cols, b_size);

// Applies the function to test.
func_have(&mut b_have, &a);

// Applies the reference function and expected behavior.
// Applies on the minimum matching columns
(0..min_cols).for_each(|i| {
// Applies on the minimum matching size
(0..min_size).for_each(|j| func_want(b_want.at_poly_mut(i, j), a.at_poly(i, j)));

// Zeroes the unmatching size
(min_size..b_size).for_each(|j| {
b_want.zero_at(i, j);
})
});

// Zeroes the unmatching columns
(min_cols..*b_cols).for_each(|i| {
(0..b_size).for_each(|j| {
b_want.zero_at(i, j);
})
});

assert_eq!(b_have.raw(), b_want.raw());
});
});
}

fn test_unary_op_inplace<B: Backend>(module: &Module<B>, func_have: impl Fn(&mut VecZnx), func_want: impl Fn(&mut [i64])) {
let a_size: usize = 3;
let mut source = Source::new([0u8; 32]);
[1usize, 2, 3].iter().for_each(|a_cols| {
let mut a_have: VecZnx = module.new_vec_znx(*a_cols, a_size);
(0..*a_cols).for_each(|i| {
module.fill_uniform(3, &mut a_have, i, a_size, &mut source);
});

// Allocates a and populates it with random values.
let mut a_want: VecZnx = module.new_vec_znx(*a_cols, a_size);
a_have.raw_mut().copy_from_slice(a_want.raw());

// Applies the function to test.
func_have(&mut a_have);

// Applies the reference function and expected behavior.
// Applies on the minimum matching columns
(0..*a_cols).for_each(|i| {
// Applies on the minimum matching size
(0..a_size).for_each(|j| func_want(a_want.at_poly_mut(i, j)));
});

assert_eq!(a_have.raw(), a_want.raw());
});
self.vec_znx_rotate_inplace(a.len() as i64, res, col_res);
}
}