Add cross-basek normalization (#90)

* added cross_basek_normalization

* updated method signatures to take layouts

* fixed cross-base normalization

fix #91
fix #93
This commit is contained in:
Jean-Philippe Bossuat
2025-09-30 14:40:10 +02:00
committed by GitHub
parent 4da790ea6a
commit 37e13b965c
216 changed files with 12481 additions and 7745 deletions

View File

@@ -6,8 +6,8 @@ use std::{
use crate::{
alloc_aligned,
layouts::{
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, Reset, ToOwnedDeep, WriterTo,
ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, ToOwnedDeep, WriterTo, ZnxInfos,
ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
},
source::Source,
};
@@ -25,6 +25,18 @@ pub struct VecZnx<D: Data> {
pub max_size: usize,
}
impl<D: Data + Default> Default for VecZnx<D> {
fn default() -> Self {
Self {
data: D::default(),
n: 0,
cols: 0,
size: 0,
max_size: 0,
}
}
}
impl<D: DataRef> DigestU64 for VecZnx<D> {
fn digest_u64(&self) -> u64 {
let mut h: DefaultHasher = DefaultHasher::new();
@@ -52,7 +64,7 @@ impl<D: DataRef> ToOwnedDeep for VecZnx<D> {
impl<D: DataRef> fmt::Debug for VecZnx<D> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self)
write!(f, "{self}")
}
}
@@ -162,10 +174,10 @@ impl<D: DataRef> fmt::Display for VecZnx<D> {
)?;
for col in 0..self.cols {
writeln!(f, "Column {}:", col)?;
writeln!(f, "Column {col}:")?;
for size in 0..self.size {
let coeffs = self.at(col, size);
write!(f, " Size {}: [", size)?;
write!(f, " Size {size}: [")?;
let max_show = 100;
let show_count = coeffs.len().min(max_show);
@@ -174,7 +186,7 @@ impl<D: DataRef> fmt::Display for VecZnx<D> {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", coeff)?;
write!(f, "{coeff}")?;
}
if coeffs.len() > max_show {
@@ -204,16 +216,6 @@ impl<D: DataMut> FillUniform for VecZnx<D> {
}
}
impl<D: DataMut> Reset for VecZnx<D> {
fn reset(&mut self) {
self.zero();
self.n = 0;
self.cols = 0;
self.size = 0;
self.max_size = 0;
}
}
pub type VecZnxOwned = VecZnx<Vec<u8>>;
pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>;
pub type VecZnxRef<'a> = VecZnx<&'a [u8]>;