diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index 8c41381..7f8a0cc 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -107,7 +107,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, // values on the last limb. // Else we decompose values base2k. if log_max + log_k_rem < 63 || log_k_rem == log_base2k { - a.at_poly_mut(col_i, size - 1)[..data_len].copy_from_slice(&data[..data_len]); + a.at_mut(col_i, size - 1)[..data_len].copy_from_slice(&data[..data_len]); } else { let mask: i64 = (1 << log_base2k) - 1; let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); @@ -116,7 +116,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, .enumerate() .for_each(|(i, i_rev)| { let shift: usize = i * log_base2k; - izip!(a.at_poly_mut(col_i, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask); + izip!(a.at_mut(col_i, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask); }) } @@ -124,7 +124,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, if log_k_rem != log_base2k { let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); (size - steps..size).rev().for_each(|i| { - a.at_poly_mut(col_i, i)[..data_len] + a.at_mut(col_i, i)[..data_len] .iter_mut() .for_each(|x| *x <<= log_k_rem); }) @@ -143,16 +143,16 @@ fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, dat ); assert!(col_i < a.cols()); } - data.copy_from_slice(a.at_poly(col_i, 0)); + data.copy_from_slice(a.at(col_i, 0)); let rem: usize = log_base2k - (log_k % log_base2k); (1..size).for_each(|i| { if i == size - 1 && rem != log_base2k { let k_rem: usize = log_base2k - rem; - izip!(a.at_poly(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { *y = (*y << k_rem) + (x >> rem); }); } else { - izip!(a.at_poly(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { *y = (*y << log_base2k) + x; }); } @@ -180,12 +180,12 @@ fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Flo // y[i] = sum x[j][i] * 2^{-log_base2k*j} (0..size).for_each(|i| { if i == 0 { - izip!(a.at_poly(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { y.assign(*x); *y /= &base; }); } else { - izip!(a.at_poly(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { *y += Float::with_val(prec, *x); *y /= &base; }); @@ -209,13 +209,13 @@ fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usiz } let log_k_rem: usize = log_base2k - (log_k % log_base2k); - (0..a.size()).for_each(|j| a.at_poly_mut(col_i, j)[i] = 0); + (0..a.size()).for_each(|j| a.at_mut(col_i, j)[i] = 0); // If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy // values on the last limb. // Else we decompose values base2k. if log_max + log_k_rem < 63 || log_k_rem == log_base2k { - a.at_poly_mut(col_i, size - 1)[i] = value; + a.at_mut(col_i, size - 1)[i] = value; } else { let mask: i64 = (1 << log_base2k) - 1; let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); @@ -223,7 +223,7 @@ fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usiz .rev() .enumerate() .for_each(|(j, j_rev)| { - a.at_poly_mut(col_i, j_rev)[i] = (value >> (j * log_base2k)) & mask; + a.at_mut(col_i, j_rev)[i] = (value >> (j * log_base2k)) & mask; }) } @@ -231,7 +231,7 @@ fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usiz if log_k_rem != log_base2k { let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); (size - steps..size).rev().for_each(|j| { - a.at_poly_mut(col_i, j)[i] <<= log_k_rem; + a.at_mut(col_i, j)[i] <<= log_k_rem; }) } } diff --git a/base2k/src/internals.rs b/base2k/src/internals.rs deleted file mode 100644 index f2fbe3b..0000000 --- a/base2k/src/internals.rs +++ /dev/null @@ -1,96 +0,0 @@ -use std::cmp::{max, min}; - -use crate::{Backend, IntegerType, Module, ZnxBasics, ZnxLayout, ffi::module::MODULE}; - -#[inline(always)] -pub fn apply_unary_op( - module: &Module, - b: &mut T, - a: &T, - op: impl Fn(&mut [T::Scalar], &[T::Scalar]), -) where - ::Scalar: IntegerType, -{ - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), module.n()); - assert_eq!(b.n(), module.n()); - } - let a_cols: usize = a.cols(); - let b_cols: usize = b.cols(); - let min_cols: usize = min(a_cols, b_cols); - // Applies over the shared cols between (a, b) - (0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0))); - // Zeroes the remaining cols of b. - (min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j))); -} - -pub fn ffi_ternary_op_factory( - module_ptr: *const MODULE, - c_size: usize, - c_sl: usize, - a_size: usize, - a_sl: usize, - b_size: usize, - b_sl: usize, - op_fn: unsafe extern "C" fn(*const MODULE, *mut T, u64, u64, *const T, u64, u64, *const T, u64, u64), -) -> impl Fn(&mut [T], &[T], &[T]) { - move |cv: &mut [T], av: &[T], bv: &[T]| unsafe { - op_fn( - module_ptr, - cv.as_mut_ptr(), - c_size as u64, - c_sl as u64, - av.as_ptr(), - a_size as u64, - a_sl as u64, - bv.as_ptr(), - b_size as u64, - b_sl as u64, - ) - } -} - -pub fn ffi_binary_op_factory_type_0( - module_ptr: *const MODULE, - b_size: usize, - b_sl: usize, - a_size: usize, - a_sl: usize, - op_fn: unsafe extern "C" fn(*const MODULE, *mut T, u64, u64, *const T, u64, u64), -) -> impl Fn(&mut [T], &[T]) { - move |bv: &mut [T], av: &[T]| unsafe { - op_fn( - module_ptr, - bv.as_mut_ptr(), - b_size as u64, - b_sl as u64, - av.as_ptr(), - a_size as u64, - a_sl as u64, - ) - } -} - -pub fn ffi_binary_op_factory_type_1( - module_ptr: *const MODULE, - k: i64, - b_size: usize, - b_sl: usize, - a_size: usize, - a_sl: usize, - op_fn: unsafe extern "C" fn(*const MODULE, i64, *mut T, u64, u64, *const T, u64, u64), -) -> impl Fn(&mut [T], &[T]) { - move |bv: &mut [T], av: &[T]| unsafe { - op_fn( - module_ptr, - k, - bv.as_mut_ptr(), - b_size as u64, - b_sl as u64, - av.as_ptr(), - a_size as u64, - a_sl as u64, - ) - } -} diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 2a9a899..7a8a3f8 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -3,7 +3,6 @@ pub mod encoding; #[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)] // Other modules and exports pub mod ffi; -mod internals; pub mod mat_znx_dft; pub mod module; pub mod sampling; diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index a96937e..5261207 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -32,12 +32,12 @@ pub trait Sampling { } impl Sampling for Module { - fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, size: usize, source: &mut Source) { + fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, size: usize, source: &mut Source) { let base2k: u64 = 1 << log_base2k; let mask: u64 = base2k - 1; let base2k_half: i64 = (base2k >> 1) as i64; (0..size).for_each(|j| { - a.at_poly_mut(col_i, j) + a.at_mut(col_a, j) .iter_mut() .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); }) @@ -47,7 +47,7 @@ impl Sampling for Module { &self, log_base2k: usize, a: &mut VecZnx, - col_i: usize, + col_a: usize, log_k: usize, source: &mut Source, dist: D, @@ -63,7 +63,7 @@ impl Sampling for Module { let log_base2k_rem: usize = log_k % log_base2k; if log_base2k_rem != 0 { - a.at_poly_mut(col_i, limb).iter_mut().for_each(|a| { + a.at_mut(col_a, limb).iter_mut().for_each(|a| { let mut dist_f64: f64 = dist.sample(source); while dist_f64.abs() > bound { dist_f64 = dist.sample(source) @@ -71,7 +71,7 @@ impl Sampling for Module { *a += (dist_f64.round() as i64) << log_base2k_rem; }); } else { - a.at_poly_mut(col_i, limb).iter_mut().for_each(|a| { + a.at_mut(col_a, limb).iter_mut().for_each(|a| { let mut dist_f64: f64 = dist.sample(source); while dist_f64.abs() > bound { dist_f64 = dist.sample(source) @@ -85,7 +85,7 @@ impl Sampling for Module { &self, log_base2k: usize, a: &mut VecZnx, - col_i: usize, + col_a: usize, log_k: usize, source: &mut Source, sigma: f64, @@ -94,7 +94,7 @@ impl Sampling for Module { self.add_dist_f64( log_base2k, a, - col_i, + col_a, log_k, source, Normal::new(0.0, sigma).unwrap(), @@ -125,7 +125,7 @@ mod tests { (0..cols).for_each(|col_j| { if col_j != col_i { (0..size).for_each(|limb_i| { - assert_eq!(a.at_poly(col_j, limb_i), zero); + assert_eq!(a.at(col_j, limb_i), zero); }) } else { let std: f64 = a.std(col_i, log_base2k); @@ -159,7 +159,7 @@ mod tests { (0..cols).for_each(|col_j| { if col_j != col_i { (0..size).for_each(|limb_i| { - assert_eq!(a.at_poly(col_j, limb_i), zero); + assert_eq!(a.at(col_j, limb_i), zero); }) } else { let std: f64 = a.std(col_i, log_base2k) * k_f64; diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index d9c9e60..1b88af5 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -341,7 +341,7 @@ mod tests { module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes); // a <- AUTO(a) - module.vec_znx_automorphism_inplace(p, &mut a); + module.vec_znx_automorphism_inplace(p, &mut a, 0); // b_dft <- DFT(AUTO(a)) module.vec_znx_dft(&mut b_dft, &a);