more fixes

This commit is contained in:
Jean-Philippe Bossuat
2025-04-29 18:16:09 +02:00
parent 917a472437
commit 06d0c5e832
5 changed files with 22 additions and 119 deletions

View File

@@ -107,7 +107,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize,
// values on the last limb. // values on the last limb.
// Else we decompose values base2k. // Else we decompose values base2k.
if log_max + log_k_rem < 63 || log_k_rem == log_base2k { if log_max + log_k_rem < 63 || log_k_rem == log_base2k {
a.at_poly_mut(col_i, size - 1)[..data_len].copy_from_slice(&data[..data_len]); a.at_mut(col_i, size - 1)[..data_len].copy_from_slice(&data[..data_len]);
} else { } else {
let mask: i64 = (1 << log_base2k) - 1; let mask: i64 = (1 << log_base2k) - 1;
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
@@ -116,7 +116,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize,
.enumerate() .enumerate()
.for_each(|(i, i_rev)| { .for_each(|(i, i_rev)| {
let shift: usize = i * log_base2k; let shift: usize = i * log_base2k;
izip!(a.at_poly_mut(col_i, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask); izip!(a.at_mut(col_i, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask);
}) })
} }
@@ -124,7 +124,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize,
if log_k_rem != log_base2k { if log_k_rem != log_base2k {
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
(size - steps..size).rev().for_each(|i| { (size - steps..size).rev().for_each(|i| {
a.at_poly_mut(col_i, i)[..data_len] a.at_mut(col_i, i)[..data_len]
.iter_mut() .iter_mut()
.for_each(|x| *x <<= log_k_rem); .for_each(|x| *x <<= log_k_rem);
}) })
@@ -143,16 +143,16 @@ fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, dat
); );
assert!(col_i < a.cols()); assert!(col_i < a.cols());
} }
data.copy_from_slice(a.at_poly(col_i, 0)); data.copy_from_slice(a.at(col_i, 0));
let rem: usize = log_base2k - (log_k % log_base2k); let rem: usize = log_base2k - (log_k % log_base2k);
(1..size).for_each(|i| { (1..size).for_each(|i| {
if i == size - 1 && rem != log_base2k { if i == size - 1 && rem != log_base2k {
let k_rem: usize = log_base2k - rem; let k_rem: usize = log_base2k - rem;
izip!(a.at_poly(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
*y = (*y << k_rem) + (x >> rem); *y = (*y << k_rem) + (x >> rem);
}); });
} else { } else {
izip!(a.at_poly(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
*y = (*y << log_base2k) + x; *y = (*y << log_base2k) + x;
}); });
} }
@@ -180,12 +180,12 @@ fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Flo
// y[i] = sum x[j][i] * 2^{-log_base2k*j} // y[i] = sum x[j][i] * 2^{-log_base2k*j}
(0..size).for_each(|i| { (0..size).for_each(|i| {
if i == 0 { if i == 0 {
izip!(a.at_poly(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
y.assign(*x); y.assign(*x);
*y /= &base; *y /= &base;
}); });
} else { } else {
izip!(a.at_poly(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
*y += Float::with_val(prec, *x); *y += Float::with_val(prec, *x);
*y /= &base; *y /= &base;
}); });
@@ -209,13 +209,13 @@ fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usiz
} }
let log_k_rem: usize = log_base2k - (log_k % log_base2k); let log_k_rem: usize = log_base2k - (log_k % log_base2k);
(0..a.size()).for_each(|j| a.at_poly_mut(col_i, j)[i] = 0); (0..a.size()).for_each(|j| a.at_mut(col_i, j)[i] = 0);
// If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy // If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy
// values on the last limb. // values on the last limb.
// Else we decompose values base2k. // Else we decompose values base2k.
if log_max + log_k_rem < 63 || log_k_rem == log_base2k { if log_max + log_k_rem < 63 || log_k_rem == log_base2k {
a.at_poly_mut(col_i, size - 1)[i] = value; a.at_mut(col_i, size - 1)[i] = value;
} else { } else {
let mask: i64 = (1 << log_base2k) - 1; let mask: i64 = (1 << log_base2k) - 1;
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
@@ -223,7 +223,7 @@ fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usiz
.rev() .rev()
.enumerate() .enumerate()
.for_each(|(j, j_rev)| { .for_each(|(j, j_rev)| {
a.at_poly_mut(col_i, j_rev)[i] = (value >> (j * log_base2k)) & mask; a.at_mut(col_i, j_rev)[i] = (value >> (j * log_base2k)) & mask;
}) })
} }
@@ -231,7 +231,7 @@ fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usiz
if log_k_rem != log_base2k { if log_k_rem != log_base2k {
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
(size - steps..size).rev().for_each(|j| { (size - steps..size).rev().for_each(|j| {
a.at_poly_mut(col_i, j)[i] <<= log_k_rem; a.at_mut(col_i, j)[i] <<= log_k_rem;
}) })
} }
} }

View File

@@ -1,96 +0,0 @@
use std::cmp::{max, min};
use crate::{Backend, IntegerType, Module, ZnxBasics, ZnxLayout, ffi::module::MODULE};
#[inline(always)]
pub fn apply_unary_op<B: Backend, T: ZnxBasics + ZnxLayout>(
module: &Module<B>,
b: &mut T,
a: &T,
op: impl Fn(&mut [T::Scalar], &[T::Scalar]),
) where
<T as ZnxLayout>::Scalar: IntegerType,
{
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
}
let a_cols: usize = a.cols();
let b_cols: usize = b.cols();
let min_cols: usize = min(a_cols, b_cols);
// Applies over the shared cols between (a, b)
(0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0)));
// Zeroes the remaining cols of b.
(min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j)));
}
pub fn ffi_ternary_op_factory<T>(
module_ptr: *const MODULE,
c_size: usize,
c_sl: usize,
a_size: usize,
a_sl: usize,
b_size: usize,
b_sl: usize,
op_fn: unsafe extern "C" fn(*const MODULE, *mut T, u64, u64, *const T, u64, u64, *const T, u64, u64),
) -> impl Fn(&mut [T], &[T], &[T]) {
move |cv: &mut [T], av: &[T], bv: &[T]| unsafe {
op_fn(
module_ptr,
cv.as_mut_ptr(),
c_size as u64,
c_sl as u64,
av.as_ptr(),
a_size as u64,
a_sl as u64,
bv.as_ptr(),
b_size as u64,
b_sl as u64,
)
}
}
pub fn ffi_binary_op_factory_type_0<T>(
module_ptr: *const MODULE,
b_size: usize,
b_sl: usize,
a_size: usize,
a_sl: usize,
op_fn: unsafe extern "C" fn(*const MODULE, *mut T, u64, u64, *const T, u64, u64),
) -> impl Fn(&mut [T], &[T]) {
move |bv: &mut [T], av: &[T]| unsafe {
op_fn(
module_ptr,
bv.as_mut_ptr(),
b_size as u64,
b_sl as u64,
av.as_ptr(),
a_size as u64,
a_sl as u64,
)
}
}
pub fn ffi_binary_op_factory_type_1<T>(
module_ptr: *const MODULE,
k: i64,
b_size: usize,
b_sl: usize,
a_size: usize,
a_sl: usize,
op_fn: unsafe extern "C" fn(*const MODULE, i64, *mut T, u64, u64, *const T, u64, u64),
) -> impl Fn(&mut [T], &[T]) {
move |bv: &mut [T], av: &[T]| unsafe {
op_fn(
module_ptr,
k,
bv.as_mut_ptr(),
b_size as u64,
b_sl as u64,
av.as_ptr(),
a_size as u64,
a_sl as u64,
)
}
}

View File

@@ -3,7 +3,6 @@ pub mod encoding;
#[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)] #[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)]
// Other modules and exports // Other modules and exports
pub mod ffi; pub mod ffi;
mod internals;
pub mod mat_znx_dft; pub mod mat_znx_dft;
pub mod module; pub mod module;
pub mod sampling; pub mod sampling;

View File

@@ -32,12 +32,12 @@ pub trait Sampling {
} }
impl<B: Backend> Sampling for Module<B> { impl<B: Backend> Sampling for Module<B> {
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, size: usize, source: &mut Source) { fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, size: usize, source: &mut Source) {
let base2k: u64 = 1 << log_base2k; let base2k: u64 = 1 << log_base2k;
let mask: u64 = base2k - 1; let mask: u64 = base2k - 1;
let base2k_half: i64 = (base2k >> 1) as i64; let base2k_half: i64 = (base2k >> 1) as i64;
(0..size).for_each(|j| { (0..size).for_each(|j| {
a.at_poly_mut(col_i, j) a.at_mut(col_a, j)
.iter_mut() .iter_mut()
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
}) })
@@ -47,7 +47,7 @@ impl<B: Backend> Sampling for Module<B> {
&self, &self,
log_base2k: usize, log_base2k: usize,
a: &mut VecZnx, a: &mut VecZnx,
col_i: usize, col_a: usize,
log_k: usize, log_k: usize,
source: &mut Source, source: &mut Source,
dist: D, dist: D,
@@ -63,7 +63,7 @@ impl<B: Backend> Sampling for Module<B> {
let log_base2k_rem: usize = log_k % log_base2k; let log_base2k_rem: usize = log_k % log_base2k;
if log_base2k_rem != 0 { if log_base2k_rem != 0 {
a.at_poly_mut(col_i, limb).iter_mut().for_each(|a| { a.at_mut(col_a, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source); let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound { while dist_f64.abs() > bound {
dist_f64 = dist.sample(source) dist_f64 = dist.sample(source)
@@ -71,7 +71,7 @@ impl<B: Backend> Sampling for Module<B> {
*a += (dist_f64.round() as i64) << log_base2k_rem; *a += (dist_f64.round() as i64) << log_base2k_rem;
}); });
} else { } else {
a.at_poly_mut(col_i, limb).iter_mut().for_each(|a| { a.at_mut(col_a, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source); let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound { while dist_f64.abs() > bound {
dist_f64 = dist.sample(source) dist_f64 = dist.sample(source)
@@ -85,7 +85,7 @@ impl<B: Backend> Sampling for Module<B> {
&self, &self,
log_base2k: usize, log_base2k: usize,
a: &mut VecZnx, a: &mut VecZnx,
col_i: usize, col_a: usize,
log_k: usize, log_k: usize,
source: &mut Source, source: &mut Source,
sigma: f64, sigma: f64,
@@ -94,7 +94,7 @@ impl<B: Backend> Sampling for Module<B> {
self.add_dist_f64( self.add_dist_f64(
log_base2k, log_base2k,
a, a,
col_i, col_a,
log_k, log_k,
source, source,
Normal::new(0.0, sigma).unwrap(), Normal::new(0.0, sigma).unwrap(),
@@ -125,7 +125,7 @@ mod tests {
(0..cols).for_each(|col_j| { (0..cols).for_each(|col_j| {
if col_j != col_i { if col_j != col_i {
(0..size).for_each(|limb_i| { (0..size).for_each(|limb_i| {
assert_eq!(a.at_poly(col_j, limb_i), zero); assert_eq!(a.at(col_j, limb_i), zero);
}) })
} else { } else {
let std: f64 = a.std(col_i, log_base2k); let std: f64 = a.std(col_i, log_base2k);
@@ -159,7 +159,7 @@ mod tests {
(0..cols).for_each(|col_j| { (0..cols).for_each(|col_j| {
if col_j != col_i { if col_j != col_i {
(0..size).for_each(|limb_i| { (0..size).for_each(|limb_i| {
assert_eq!(a.at_poly(col_j, limb_i), zero); assert_eq!(a.at(col_j, limb_i), zero);
}) })
} else { } else {
let std: f64 = a.std(col_i, log_base2k) * k_f64; let std: f64 = a.std(col_i, log_base2k) * k_f64;

View File

@@ -341,7 +341,7 @@ mod tests {
module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes); module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes);
// a <- AUTO(a) // a <- AUTO(a)
module.vec_znx_automorphism_inplace(p, &mut a); module.vec_znx_automorphism_inplace(p, &mut a, 0);
// b_dft <- DFT(AUTO(a)) // b_dft <- DFT(AUTO(a))
module.vec_znx_dft(&mut b_dft, &a); module.vec_znx_dft(&mut b_dft, &a);