Applied discussed changes, everything working, but still to discuss

This commit is contained in:
Jean-Philippe Bossuat
2025-05-01 10:33:19 +02:00
parent 4e6fce3458
commit ca5e6d46c9
14 changed files with 710 additions and 508 deletions

View File

@@ -3,7 +3,7 @@ use crate::Module;
use crate::assert_alignement;
use crate::cast_mut;
use crate::ffi::znx;
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, ZnxSliceSize, switch_degree};
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxRsh, ZnxSliceSize, ZnxZero, switch_degree};
use std::cmp::min;
pub const VEC_ZNX_ROWS: usize = 1;
@@ -44,7 +44,9 @@ impl ZnxLayout for VecZnx {
type Scalar = i64;
}
impl ZnxBasics for VecZnx {}
impl ZnxZero for VecZnx {}
impl ZnxRsh for VecZnx {}
impl<B: Backend> ZnxAlloc<B> for VecZnx {
type Scalar = i64;
@@ -84,7 +86,7 @@ impl VecZnx {
///
/// * `log_base2k`: the base two logarithm of the coefficients decomposition.
/// * `k`: the number of bits of precision to drop.
pub fn trunc_pow2(&mut self, log_base2k: usize, k: usize) {
pub fn trunc_pow2(&mut self, log_base2k: usize, k: usize, col: usize) {
if k == 0 {
return;
}
@@ -101,7 +103,7 @@ impl VecZnx {
if k_rem != 0 {
let mask: i64 = ((1 << (log_base2k - k_rem - 1)) - 1) << k_rem;
self.at_limb_mut(self.size() - 1)
self.at_mut(col, self.size() - 1)
.iter_mut()
.for_each(|x: &mut i64| *x &= mask)
}
@@ -111,8 +113,8 @@ impl VecZnx {
copy_vec_znx_from(self, a);
}
pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) {
normalize(log_base2k, self, carry)
pub fn normalize(&mut self, log_base2k: usize, col: usize, carry: &mut [u8]) {
normalize(log_base2k, self, col, carry)
}
pub fn switch_degree(&self, col: usize, a: &mut Self, col_a: usize) {
@@ -120,26 +122,25 @@ impl VecZnx {
}
// Prints the first `n` coefficients of each limb
pub fn print(&self, n: usize) {
(0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]))
pub fn print(&self, n: usize, col: usize) {
(0..self.size()).for_each(|j| println!("{}: {:?}", j, &self.at(col, j)[..n]));
}
}
fn normalize_tmp_bytes(n: usize, size: usize) -> usize {
n * size * std::mem::size_of::<i64>()
fn normalize_tmp_bytes(n: usize) -> usize {
n * std::mem::size_of::<i64>()
}
fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) {
fn normalize(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) {
let n: usize = a.n();
let cols: usize = a.cols();
debug_assert!(
tmp_bytes.len() >= normalize_tmp_bytes(n, cols),
"invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({}, {})",
tmp_bytes.len() >= normalize_tmp_bytes(n),
"invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({})",
tmp_bytes.len(),
n,
cols,
);
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr())
@@ -151,11 +152,11 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) {
znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr());
(0..a.size()).rev().for_each(|i| {
znx::znx_normalize(
(n * cols) as u64,
n as u64,
log_base2k as u64,
a.at_mut_ptr(0, i),
a.at_mut_ptr(a_col, i),
carry_i64.as_mut_ptr(),
a.at_mut_ptr(0, i),
a.at_mut_ptr(a_col, i),
carry_i64.as_mut_ptr(),
)
});