Some fixes & QoL to Base2k

This commit is contained in:
Jean-Philippe Bossuat
2025-05-07 16:47:58 +02:00
parent 64874dbda8
commit 6cbd2a6a93
8 changed files with 73 additions and 33 deletions

View File

@@ -1,7 +1,7 @@
use base2k::{ use base2k::{
AddNormal, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, AddNormal, Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc,
ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
VecZnxDftOps, VecZnxOps, ZnxInfos, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxInfos,
}; };
use itertools::izip; use itertools::izip;
use sampling::source::Source; use sampling::source::Source;

View File

@@ -17,6 +17,20 @@ pub trait Encoding {
/// * `log_max`: base two logarithm of the infinity norm of the input data. /// * `log_max`: base two logarithm of the infinity norm of the input data.
fn encode_vec_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize); fn encode_vec_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize);
/// encodes a single i64 on the receiver at the given index.
///
/// # Arguments
///
/// * `col_i`: the index of the poly where to encode the data.
/// * `log_base2k`: base two negative logarithm decomposition of the receiver.
/// * `log_k`: base two negative logarithm of the scaling of the data.
/// * `i`: index of the coefficient on which to encode the data.
/// * `data`: data to encode on the receiver.
/// * `log_max`: base two logarithm of the infinity norm of the input data.
fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, data: i64, log_max: usize);
}
pub trait Decoding {
/// decode a vector of i64 from the receiver. /// decode a vector of i64 from the receiver.
/// ///
/// # Arguments /// # Arguments
@@ -35,18 +49,6 @@ pub trait Encoding {
/// * `data`: data to decode from the receiver. /// * `data`: data to decode from the receiver.
fn decode_vec_float(&self, col_i: usize, log_base2k: usize, data: &mut [Float]); fn decode_vec_float(&self, col_i: usize, log_base2k: usize, data: &mut [Float]);
/// encodes a single i64 on the receiver at the given index.
///
/// # Arguments
///
/// * `col_i`: the index of the poly where to encode the data.
/// * `log_base2k`: base two negative logarithm decomposition of the receiver.
/// * `log_k`: base two negative logarithm of the scaling of the data.
/// * `i`: index of the coefficient on which to encode the data.
/// * `data`: data to encode on the receiver.
/// * `log_max`: base two logarithm of the infinity norm of the input data.
fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, data: i64, log_max: usize);
/// decode a single of i64 from the receiver at the given index. /// decode a single of i64 from the receiver at the given index.
/// ///
/// # Arguments /// # Arguments
@@ -64,6 +66,12 @@ impl<D: AsMut<[u8]> + AsRef<[u8]>> Encoding for VecZnx<D> {
encode_vec_i64(self, col_i, log_base2k, log_k, data, log_max) encode_vec_i64(self, col_i, log_base2k, log_k, data, log_max)
} }
fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) {
encode_coeff_i64(self, col_i, log_base2k, log_k, i, value, log_max)
}
}
impl<D: AsRef<[u8]>> Decoding for VecZnx<D> {
fn decode_vec_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { fn decode_vec_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) {
decode_vec_i64(self, col_i, log_base2k, log_k, data) decode_vec_i64(self, col_i, log_base2k, log_k, data)
} }
@@ -72,10 +80,6 @@ impl<D: AsMut<[u8]> + AsRef<[u8]>> Encoding for VecZnx<D> {
decode_vec_float(self, col_i, log_base2k, data) decode_vec_float(self, col_i, log_base2k, data)
} }
fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) {
encode_coeff_i64(self, col_i, log_base2k, log_k, i, value, log_max)
}
fn decode_coeff_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { fn decode_coeff_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 {
decode_coeff_i64(self, col_i, log_base2k, log_k, i) decode_coeff_i64(self, col_i, log_base2k, log_k, i)
} }
@@ -139,7 +143,7 @@ fn encode_vec_i64<D: AsMut<[u8]> + AsRef<[u8]>>(
} }
} }
fn decode_vec_i64<D: AsMut<[u8]> + AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { fn decode_vec_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) {
let size: usize = (log_k + log_base2k - 1) / log_base2k; let size: usize = (log_k + log_base2k - 1) / log_base2k;
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
@@ -167,7 +171,7 @@ fn decode_vec_i64<D: AsMut<[u8]> + AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, log
}) })
} }
fn decode_vec_float<D: AsMut<[u8]> + AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, log_base2k: usize, data: &mut [Float]) { fn decode_vec_float<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, log_base2k: usize, data: &mut [Float]) {
let size: usize = a.size(); let size: usize = a.size();
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
@@ -252,7 +256,7 @@ fn encode_coeff_i64<D: AsMut<[u8]> + AsRef<[u8]>>(
} }
} }
fn decode_coeff_i64<D: AsMut<[u8]> + AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { fn decode_coeff_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!(i < a.n()); assert!(i < a.n());
@@ -280,7 +284,7 @@ fn decode_coeff_i64<D: AsMut<[u8]> + AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, l
mod tests { mod tests {
use crate::vec_znx_ops::*; use crate::vec_znx_ops::*;
use crate::znx_base::*; use crate::znx_base::*;
use crate::{Encoding, FFT64, Module, VecZnx, znx_base::ZnxInfos}; use crate::{Decoding, Encoding, FFT64, Module, VecZnx, znx_base::ZnxInfos};
use itertools::izip; use itertools::izip;
use sampling::source::Source; use sampling::source::Source;

View File

@@ -305,7 +305,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{ use crate::{
Encoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, Decoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig,
VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut,
}; };
use sampling::source::Source; use sampling::source::Source;

View File

@@ -80,7 +80,7 @@ where
); );
let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1; let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1;
let log_base2k_rem: usize = log_k % log_base2k; let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k;
if log_base2k_rem != 0 { if log_base2k_rem != 0 {
a.at_mut(col_i, limb).iter_mut().for_each(|a| { a.at_mut(col_i, limb).iter_mut().for_each(|a| {
@@ -123,7 +123,7 @@ where
); );
let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1; let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1;
let log_base2k_rem: usize = log_k % log_base2k; let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k;
if log_base2k_rem != 0 { if log_base2k_rem != 0 {
a.at_mut(col_i, limb).iter_mut().for_each(|a| { a.at_mut(col_i, limb).iter_mut().for_each(|a| {
@@ -198,7 +198,7 @@ where
); );
let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1; let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1;
let log_base2k_rem: usize = log_k % log_base2k; let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k;
if log_base2k_rem != 0 { if log_base2k_rem != 0 {
a.at_mut(col_i, limb).iter_mut().for_each(|a| { a.at_mut(col_i, limb).iter_mut().for_each(|a| {
@@ -241,7 +241,7 @@ where
); );
let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1; let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1;
let log_base2k_rem: usize = log_k % log_base2k; let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k;
if log_base2k_rem != 0 { if log_base2k_rem != 0 {
a.at_mut(col_i, limb).iter_mut().for_each(|a| { a.at_mut(col_i, limb).iter_mut().for_each(|a| {

View File

@@ -1,5 +1,5 @@
use crate::znx_base::ZnxInfos; use crate::znx_base::ZnxInfos;
use crate::{Encoding, VecZnx}; use crate::{Decoding, VecZnx};
use rug::Float; use rug::Float;
use rug::float::Round; use rug::float::Round;
use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound}; use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound};
@@ -9,7 +9,7 @@ pub trait Stats {
fn std(&self, col_i: usize, log_base2k: usize) -> f64; fn std(&self, col_i: usize, log_base2k: usize) -> f64;
} }
impl<D: AsMut<[u8]> + AsRef<[u8]>> Stats for VecZnx<D> { impl<D: AsRef<[u8]>> Stats for VecZnx<D> {
fn std(&self, col_i: usize, log_base2k: usize) -> f64 { fn std(&self, col_i: usize, log_base2k: usize) -> f64 {
let prec: u32 = (self.size() * log_base2k) as u32; let prec: u32 = (self.size() * log_base2k) as u32;
let mut data: Vec<Float> = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect(); let mut data: Vec<Float> = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect();

View File

@@ -1,8 +1,8 @@
use crate::ffi::vec_znx_big; use crate::ffi::vec_znx_big;
use crate::znx_base::{ZnxInfos, ZnxView}; use crate::znx_base::{ZnxInfos, ZnxView};
use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, alloc_aligned}; use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, alloc_aligned};
use std::fmt;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::{cmp::min, fmt};
pub struct VecZnxBig<D, B: Backend> { pub struct VecZnxBig<D, B: Backend> {
data: D, data: D,
@@ -168,7 +168,7 @@ impl<D: AsRef<[u8]>> fmt::Display for VecZnxBig<D, FFT64> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!( writeln!(
f, f,
"VecZnx(n={}, cols={}, size={})", "VecZnxBig(n={}, cols={}, size={})",
self.n, self.cols, self.size self.n, self.cols, self.size
)?; )?;

View File

@@ -3,6 +3,7 @@ use std::marker::PhantomData;
use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft;
use crate::znx_base::ZnxInfos; use crate::znx_base::ZnxInfos;
use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned};
use std::fmt;
pub struct VecZnxDft<D, B: Backend> { pub struct VecZnxDft<D, B: Backend> {
data: D, data: D,
@@ -163,3 +164,38 @@ impl<B: Backend> VecZnxDftToRef<B> for VecZnxDft<&[u8], B> {
} }
} }
} }
impl<D: AsRef<[u8]>> fmt::Display for VecZnxDft<D, FFT64> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"VecZnxDft(n={}, cols={}, size={})",
self.n, self.cols, self.size
)?;
for col in 0..self.cols {
writeln!(f, "Column {}:", col)?;
for size in 0..self.size {
let coeffs = self.at(col, size);
write!(f, " Size {}: [", size)?;
let max_show = 100;
let show_count = coeffs.len().min(max_show);
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", coeff)?;
}
if coeffs.len() > max_show {
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
}
writeln!(f, "]")?;
}
}
Ok(())
}
}