From 6cbd2a6a9380dd7648aac6e05e6ca93227757321 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 7 May 2025 16:47:58 +0200 Subject: [PATCH] Some fixes & QoL to Base2k --- base2k/examples/rlwe_encrypt.rs | 6 ++--- base2k/spqlios-arithmetic | 2 +- base2k/src/encoding.rs | 44 ++++++++++++++++++--------------- base2k/src/mat_znx_dft_ops.rs | 2 +- base2k/src/sampling.rs | 8 +++--- base2k/src/stats.rs | 4 +-- base2k/src/vec_znx_big.rs | 4 +-- base2k/src/vec_znx_dft.rs | 36 +++++++++++++++++++++++++++ 8 files changed, 73 insertions(+), 33 deletions(-) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index b9d78f4..4db6ef5 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,7 +1,7 @@ use base2k::{ - AddNormal, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, - ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, - VecZnxDftOps, VecZnxOps, ZnxInfos, + AddNormal, Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, + ScalarZnxDftOps, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, + VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxInfos, }; use itertools::izip; use sampling::source::Source; diff --git a/base2k/spqlios-arithmetic b/base2k/spqlios-arithmetic index b6fa494..b919282 160000 --- a/base2k/spqlios-arithmetic +++ b/base2k/spqlios-arithmetic @@ -1 +1 @@ -Subproject commit b6fa494a14c52842712f8ff032ea80812467dec2 +Subproject commit b919282c9b913e8b11418df6afdb0baa02debc9b diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index ba48474..45214c6 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -17,6 +17,20 @@ pub trait Encoding { /// * `log_max`: base two logarithm of the infinity norm of the input data. fn encode_vec_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize); + /// encodes a single i64 on the receiver at the given index. + /// + /// # Arguments + /// + /// * `col_i`: the index of the poly where to encode the data. + /// * `log_base2k`: base two negative logarithm decomposition of the receiver. + /// * `log_k`: base two negative logarithm of the scaling of the data. + /// * `i`: index of the coefficient on which to encode the data. + /// * `data`: data to encode on the receiver. + /// * `log_max`: base two logarithm of the infinity norm of the input data. + fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, data: i64, log_max: usize); +} + +pub trait Decoding { /// decode a vector of i64 from the receiver. /// /// # Arguments @@ -35,18 +49,6 @@ pub trait Encoding { /// * `data`: data to decode from the receiver. fn decode_vec_float(&self, col_i: usize, log_base2k: usize, data: &mut [Float]); - /// encodes a single i64 on the receiver at the given index. - /// - /// # Arguments - /// - /// * `col_i`: the index of the poly where to encode the data. - /// * `log_base2k`: base two negative logarithm decomposition of the receiver. - /// * `log_k`: base two negative logarithm of the scaling of the data. - /// * `i`: index of the coefficient on which to encode the data. - /// * `data`: data to encode on the receiver. - /// * `log_max`: base two logarithm of the infinity norm of the input data. - fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, data: i64, log_max: usize); - /// decode a single of i64 from the receiver at the given index. /// /// # Arguments @@ -64,6 +66,12 @@ impl + AsRef<[u8]>> Encoding for VecZnx { encode_vec_i64(self, col_i, log_base2k, log_k, data, log_max) } + fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) { + encode_coeff_i64(self, col_i, log_base2k, log_k, i, value, log_max) + } +} + +impl> Decoding for VecZnx { fn decode_vec_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { decode_vec_i64(self, col_i, log_base2k, log_k, data) } @@ -72,10 +80,6 @@ impl + AsRef<[u8]>> Encoding for VecZnx { decode_vec_float(self, col_i, log_base2k, data) } - fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) { - encode_coeff_i64(self, col_i, log_base2k, log_k, i, value, log_max) - } - fn decode_coeff_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { decode_coeff_i64(self, col_i, log_base2k, log_k, i) } @@ -139,7 +143,7 @@ fn encode_vec_i64 + AsRef<[u8]>>( } } -fn decode_vec_i64 + AsRef<[u8]>>(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { +fn decode_vec_i64>(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { let size: usize = (log_k + log_base2k - 1) / log_base2k; #[cfg(debug_assertions)] { @@ -167,7 +171,7 @@ fn decode_vec_i64 + AsRef<[u8]>>(a: &VecZnx, col_i: usize, log }) } -fn decode_vec_float + AsRef<[u8]>>(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Float]) { +fn decode_vec_float>(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Float]) { let size: usize = a.size(); #[cfg(debug_assertions)] { @@ -252,7 +256,7 @@ fn encode_coeff_i64 + AsRef<[u8]>>( } } -fn decode_coeff_i64 + AsRef<[u8]>>(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { +fn decode_coeff_i64>(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { #[cfg(debug_assertions)] { assert!(i < a.n()); @@ -280,7 +284,7 @@ fn decode_coeff_i64 + AsRef<[u8]>>(a: &VecZnx, col_i: usize, l mod tests { use crate::vec_znx_ops::*; use crate::znx_base::*; - use crate::{Encoding, FFT64, Module, VecZnx, znx_base::ZnxInfos}; + use crate::{Decoding, Encoding, FFT64, Module, VecZnx, znx_base::ZnxInfos}; use itertools::izip; use sampling::source::Source; diff --git a/base2k/src/mat_znx_dft_ops.rs b/base2k/src/mat_znx_dft_ops.rs index 85e6264..f302e9b 100644 --- a/base2k/src/mat_znx_dft_ops.rs +++ b/base2k/src/mat_znx_dft_ops.rs @@ -305,7 +305,7 @@ impl MatZnxDftOps for Module { #[cfg(test)] mod tests { use crate::{ - Encoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, + Decoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut, }; use sampling::source::Source; diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index b2d6f22..b4e1489 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -80,7 +80,7 @@ where ); let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1; - let log_base2k_rem: usize = log_k % log_base2k; + let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k; if log_base2k_rem != 0 { a.at_mut(col_i, limb).iter_mut().for_each(|a| { @@ -123,7 +123,7 @@ where ); let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1; - let log_base2k_rem: usize = log_k % log_base2k; + let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k; if log_base2k_rem != 0 { a.at_mut(col_i, limb).iter_mut().for_each(|a| { @@ -198,7 +198,7 @@ where ); let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1; - let log_base2k_rem: usize = log_k % log_base2k; + let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k; if log_base2k_rem != 0 { a.at_mut(col_i, limb).iter_mut().for_each(|a| { @@ -241,7 +241,7 @@ where ); let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1; - let log_base2k_rem: usize = log_k % log_base2k; + let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k; if log_base2k_rem != 0 { a.at_mut(col_i, limb).iter_mut().for_each(|a| { diff --git a/base2k/src/stats.rs b/base2k/src/stats.rs index c6d16b4..8db40f2 100644 --- a/base2k/src/stats.rs +++ b/base2k/src/stats.rs @@ -1,5 +1,5 @@ use crate::znx_base::ZnxInfos; -use crate::{Encoding, VecZnx}; +use crate::{Decoding, VecZnx}; use rug::Float; use rug::float::Round; use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound}; @@ -9,7 +9,7 @@ pub trait Stats { fn std(&self, col_i: usize, log_base2k: usize) -> f64; } -impl + AsRef<[u8]>> Stats for VecZnx { +impl> Stats for VecZnx { fn std(&self, col_i: usize, log_base2k: usize) -> f64 { let prec: u32 = (self.size() * log_base2k) as u32; let mut data: Vec = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect(); diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index f5f220e..d8c1bdd 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,8 +1,8 @@ use crate::ffi::vec_znx_big; use crate::znx_base::{ZnxInfos, ZnxView}; use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, alloc_aligned}; +use std::fmt; use std::marker::PhantomData; -use std::{cmp::min, fmt}; pub struct VecZnxBig { data: D, @@ -168,7 +168,7 @@ impl> fmt::Display for VecZnxBig { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!( f, - "VecZnx(n={}, cols={}, size={})", + "VecZnxBig(n={}, cols={}, size={})", self.n, self.cols, self.size )?; diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 66e58cf..0e7f952 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -3,6 +3,7 @@ use std::marker::PhantomData; use crate::ffi::vec_znx_dft; use crate::znx_base::ZnxInfos; use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; +use std::fmt; pub struct VecZnxDft { data: D, @@ -163,3 +164,38 @@ impl VecZnxDftToRef for VecZnxDft<&[u8], B> { } } } + +impl> fmt::Display for VecZnxDft { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "VecZnxDft(n={}, cols={}, size={})", + self.n, self.cols, self.size + )?; + + for col in 0..self.cols { + writeln!(f, "Column {}:", col)?; + for size in 0..self.size { + let coeffs = self.at(col, size); + write!(f, " Size {}: [", size)?; + + let max_show = 100; + let show_count = coeffs.len().min(max_show); + + for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", coeff)?; + } + + if coeffs.len() > max_show { + write!(f, ", ... ({} more)", coeffs.len() - max_show)?; + } + + writeln!(f, "]")?; + } + } + Ok(()) + } +}