wip major refactoring (compiles & all test + example passing)

This commit is contained in:
Jean-Philippe Bossuat
2025-04-30 13:43:18 +02:00
parent 2cc51eee18
commit 6f7b93c7ca
18 changed files with 662 additions and 870 deletions

View File

@@ -1,12 +1,13 @@
use crate::Backend;
use crate::ZnxBase;
use crate::Module;
use crate::assert_alignement;
use crate::cast_mut;
use crate::ffi::znx;
use crate::switch_degree;
use crate::{Module, ZnxBasics, ZnxInfos, ZnxLayout};
use crate::{alloc_aligned, assert_alignement};
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, ZnxSliceSize, switch_degree};
use std::cmp::min;
pub const VEC_ZNX_ROWS: usize = 1;
/// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of
/// Zn\[X\] with [i64] coefficients.
/// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single contiguous array
@@ -17,56 +18,54 @@ use std::cmp::min;
/// Given 3 polynomials (a, b, c) of Zn\[X\], each with 4 columns, then the memory
/// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci
/// are small polynomials of Zn\[X\].
#[derive(Clone)]
pub struct VecZnx {
/// Polynomial degree.
pub n: usize,
/// The number of polynomials
pub cols: usize,
/// The number of size per polynomial (a.k.a small polynomials).
pub size: usize,
/// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n.
pub data: Vec<i64>,
/// Pointer to data (data can be enpty if [VecZnx] borrows space instead of owning it).
pub ptr: *mut i64,
pub inner: ZnxBase,
}
impl ZnxInfos for VecZnx {
fn n(&self) -> usize {
self.n
impl GetZnxBase for VecZnx {
fn znx(&self) -> &ZnxBase {
&self.inner
}
fn rows(&self) -> usize {
1
fn znx_mut(&mut self) -> &mut ZnxBase {
&mut self.inner
}
}
fn cols(&self) -> usize {
self.cols
}
impl ZnxInfos for VecZnx {}
fn size(&self) -> usize {
self.size
impl ZnxSliceSize for VecZnx {
fn sl(&self) -> usize {
self.cols() * self.n()
}
}
impl ZnxLayout for VecZnx {
type Scalar = i64;
fn as_ptr(&self) -> *const Self::Scalar {
self.ptr
}
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
self.ptr
}
}
impl ZnxBasics for VecZnx {}
impl<B: Backend> ZnxAlloc<B> for VecZnx {
type Scalar = i64;
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx {
debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size));
VecZnx {
inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_ROWS, cols, size, bytes),
}
}
fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, size: usize) -> usize {
debug_assert_eq!(
_rows, VEC_ZNX_ROWS,
"rows != {} not supported for VecZnx",
VEC_ZNX_ROWS
);
module.n() * cols * size * size_of::<Self::Scalar>()
}
}
/// Copies the coefficients of `a` on the receiver.
/// Copy is done with the minimum size matching both backing arrays.
/// Panics if the cols do not match.
@@ -78,80 +77,6 @@ pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) {
data_b[..size].copy_from_slice(&data_a[..size])
}
impl<B: Backend> ZnxBase<B> for VecZnx {
type Scalar = i64;
/// Allocates a new [VecZnx] composed of #size polynomials of Z\[X\].
fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
let n: usize = module.n();
#[cfg(debug_assertions)]
{
assert!(n > 0);
assert!(n & (n - 1) == 0);
assert!(cols > 0);
assert!(size > 0);
}
let mut data: Vec<i64> = alloc_aligned::<i64>(Self::bytes_of(module, cols, size));
let ptr: *mut i64 = data.as_mut_ptr();
Self {
n: n,
cols: cols,
size: size,
data: data,
ptr: ptr,
}
}
fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize {
module.n() * cols * size * size_of::<i64>()
}
/// Returns a new struct implementing [VecZnx] with the provided data as backing array.
///
/// The struct will take ownership of buf[..[Self::bytes_of]]
///
/// User must ensure that data is properly alligned and that
/// the size of data is equal to [Self::bytes_of].
fn from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
let n: usize = module.n();
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(size > 0);
assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
assert_alignement(bytes.as_ptr());
}
unsafe {
let bytes_i64: &mut [i64] = cast_mut::<u8, i64>(bytes);
let ptr: *mut i64 = bytes_i64.as_mut_ptr();
Self {
n: n,
cols: cols,
size: size,
data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()),
ptr: ptr,
}
}
}
fn from_bytes_borrow(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(size > 0);
assert!(bytes.len() >= Self::bytes_of(module, cols, size));
assert_alignement(bytes.as_ptr());
}
Self {
n: module.n(),
cols: cols,
size: size,
data: Vec::new(),
ptr: bytes.as_mut_ptr() as *mut i64,
}
}
}
impl VecZnx {
/// Truncates the precision of the [VecZnx] by k bits.
///
@@ -165,11 +90,12 @@ impl VecZnx {
}
if !self.borrowing() {
self.data
self.inner
.data
.truncate(self.n() * self.cols() * (self.size() - k / log_base2k));
}
self.size -= k / log_base2k;
self.inner.size -= k / log_base2k;
let k_rem: usize = k % log_base2k;
@@ -185,10 +111,6 @@ impl VecZnx {
copy_vec_znx_from(self, a);
}
pub fn borrowing(&self) -> bool {
self.data.len() == 0
}
pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) {
normalize(log_base2k, self, carry)
}