mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 21:26:41 +01:00
Applied discussed changes, everything working, but still to discuss
This commit is contained in:
@@ -3,7 +3,7 @@ use crate::Module;
|
||||
use crate::assert_alignement;
|
||||
use crate::cast_mut;
|
||||
use crate::ffi::znx;
|
||||
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, ZnxSliceSize, switch_degree};
|
||||
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxRsh, ZnxSliceSize, ZnxZero, switch_degree};
|
||||
use std::cmp::min;
|
||||
|
||||
pub const VEC_ZNX_ROWS: usize = 1;
|
||||
@@ -44,7 +44,9 @@ impl ZnxLayout for VecZnx {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
impl ZnxBasics for VecZnx {}
|
||||
impl ZnxZero for VecZnx {}
|
||||
|
||||
impl ZnxRsh for VecZnx {}
|
||||
|
||||
impl<B: Backend> ZnxAlloc<B> for VecZnx {
|
||||
type Scalar = i64;
|
||||
@@ -84,7 +86,7 @@ impl VecZnx {
|
||||
///
|
||||
/// * `log_base2k`: the base two logarithm of the coefficients decomposition.
|
||||
/// * `k`: the number of bits of precision to drop.
|
||||
pub fn trunc_pow2(&mut self, log_base2k: usize, k: usize) {
|
||||
pub fn trunc_pow2(&mut self, log_base2k: usize, k: usize, col: usize) {
|
||||
if k == 0 {
|
||||
return;
|
||||
}
|
||||
@@ -101,7 +103,7 @@ impl VecZnx {
|
||||
|
||||
if k_rem != 0 {
|
||||
let mask: i64 = ((1 << (log_base2k - k_rem - 1)) - 1) << k_rem;
|
||||
self.at_limb_mut(self.size() - 1)
|
||||
self.at_mut(col, self.size() - 1)
|
||||
.iter_mut()
|
||||
.for_each(|x: &mut i64| *x &= mask)
|
||||
}
|
||||
@@ -111,8 +113,8 @@ impl VecZnx {
|
||||
copy_vec_znx_from(self, a);
|
||||
}
|
||||
|
||||
pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) {
|
||||
normalize(log_base2k, self, carry)
|
||||
pub fn normalize(&mut self, log_base2k: usize, col: usize, carry: &mut [u8]) {
|
||||
normalize(log_base2k, self, col, carry)
|
||||
}
|
||||
|
||||
pub fn switch_degree(&self, col: usize, a: &mut Self, col_a: usize) {
|
||||
@@ -120,26 +122,25 @@ impl VecZnx {
|
||||
}
|
||||
|
||||
// Prints the first `n` coefficients of each limb
|
||||
pub fn print(&self, n: usize) {
|
||||
(0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]))
|
||||
pub fn print(&self, n: usize, col: usize) {
|
||||
(0..self.size()).for_each(|j| println!("{}: {:?}", j, &self.at(col, j)[..n]));
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_tmp_bytes(n: usize, size: usize) -> usize {
|
||||
n * size * std::mem::size_of::<i64>()
|
||||
fn normalize_tmp_bytes(n: usize) -> usize {
|
||||
n * std::mem::size_of::<i64>()
|
||||
}
|
||||
|
||||
fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) {
|
||||
fn normalize(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) {
|
||||
let n: usize = a.n();
|
||||
let cols: usize = a.cols();
|
||||
|
||||
debug_assert!(
|
||||
tmp_bytes.len() >= normalize_tmp_bytes(n, cols),
|
||||
"invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({}, {})",
|
||||
tmp_bytes.len() >= normalize_tmp_bytes(n),
|
||||
"invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({})",
|
||||
tmp_bytes.len(),
|
||||
n,
|
||||
cols,
|
||||
);
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr())
|
||||
@@ -151,11 +152,11 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) {
|
||||
znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr());
|
||||
(0..a.size()).rev().for_each(|i| {
|
||||
znx::znx_normalize(
|
||||
(n * cols) as u64,
|
||||
n as u64,
|
||||
log_base2k as u64,
|
||||
a.at_mut_ptr(0, i),
|
||||
a.at_mut_ptr(a_col, i),
|
||||
carry_i64.as_mut_ptr(),
|
||||
a.at_mut_ptr(0, i),
|
||||
a.at_mut_ptr(a_col, i),
|
||||
carry_i64.as_mut_ptr(),
|
||||
)
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user