Add cross-basek normalization (#90)

* added cross_basek_normalization

* updated method signatures to take layouts

* fixed cross-base normalization

fix #91
fix #93
This commit is contained in:
Jean-Philippe Bossuat
2025-09-30 14:40:10 +02:00
committed by GitHub
parent 4da790ea6a
commit 37e13b965c
216 changed files with 12481 additions and 7745 deletions

View File

@@ -1,11 +1,11 @@
use poulpy_hal::{
api::{VecZnxCopy, VecZnxFillUniform},
layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, Reset, WriterTo},
layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo},
source::Source,
};
use crate::layouts::{
GGLWETensorKey, Infos,
Base2K, Degree, Digits, GGLWELayoutInfos, GGLWETensorKey, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision,
compressed::{Decompress, GGLWESwitchingKeyCompressed},
};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
@@ -16,9 +16,49 @@ pub struct GGLWETensorKeyCompressed<D: Data> {
pub(crate) keys: Vec<GGLWESwitchingKeyCompressed<D>>,
}
impl<D: Data> LWEInfos for GGLWETensorKeyCompressed<D> {
fn n(&self) -> Degree {
self.keys[0].n()
}
fn base2k(&self) -> Base2K {
self.keys[0].base2k()
}
fn k(&self) -> TorusPrecision {
self.keys[0].k()
}
fn size(&self) -> usize {
self.keys[0].size()
}
}
impl<D: Data> GLWEInfos for GGLWETensorKeyCompressed<D> {
fn rank(&self) -> Rank {
self.rank_out()
}
}
impl<D: Data> GGLWELayoutInfos for GGLWETensorKeyCompressed<D> {
fn rank_in(&self) -> Rank {
self.rank_out()
}
fn rank_out(&self) -> Rank {
self.keys[0].rank_out()
}
fn digits(&self) -> Digits {
self.keys[0].digits()
}
fn rows(&self) -> Rows {
self.keys[0].rows()
}
}
impl<D: DataRef> fmt::Debug for GGLWETensorKeyCompressed<D> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self)
write!(f, "{self}")
}
}
@@ -30,76 +70,79 @@ impl<D: DataMut> FillUniform for GGLWETensorKeyCompressed<D> {
}
}
impl<D: DataMut> Reset for GGLWETensorKeyCompressed<D>
where
MatZnx<D>: Reset,
{
fn reset(&mut self) {
self.keys
.iter_mut()
.for_each(|key: &mut GGLWESwitchingKeyCompressed<D>| key.reset())
}
}
impl<D: DataRef> fmt::Display for GGLWETensorKeyCompressed<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "(GLWETensorKeyCompressed)",)?;
for (i, key) in self.keys.iter().enumerate() {
write!(f, "{}: {}", i, key)?;
write!(f, "{i}: {key}")?;
}
Ok(())
}
}
impl GGLWETensorKeyCompressed<Vec<u8>> {
pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self {
pub fn alloc<A>(infos: &A) -> Self
where
A: GGLWELayoutInfos,
{
assert_eq!(
infos.rank_in(),
infos.rank_out(),
"rank_in != rank_out is not supported for GGLWETensorKeyCompressed"
);
Self::alloc_with(
infos.n(),
infos.base2k(),
infos.k(),
infos.rows(),
infos.digits(),
infos.rank_out(),
)
}
pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> Self {
let mut keys: Vec<GGLWESwitchingKeyCompressed<Vec<u8>>> = Vec::new();
let pairs: usize = (((rank + 1) * rank) >> 1).max(1);
let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1);
(0..pairs).for_each(|_| {
keys.push(GGLWESwitchingKeyCompressed::alloc(
n, basek, k, rows, digits, 1, rank,
keys.push(GGLWESwitchingKeyCompressed::alloc_with(
n,
base2k,
k,
rows,
digits,
Rank(1),
rank,
));
});
Self { keys }
}
pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize {
let pairs: usize = (((rank + 1) * rank) >> 1).max(1);
pairs * GGLWESwitchingKeyCompressed::bytes_of(n, basek, k, rows, digits, 1)
}
}
impl<D: Data> Infos for GGLWETensorKeyCompressed<D> {
type Inner = MatZnx<D>;
fn inner(&self) -> &Self::Inner {
self.keys[0].inner()
pub fn alloc_bytes<A>(infos: &A) -> usize
where
A: GGLWELayoutInfos,
{
assert_eq!(
infos.rank_in(),
infos.rank_out(),
"rank_in != rank_out is not supported for GGLWETensorKeyCompressed"
);
let rank_out: usize = infos.rank_out().into();
let pairs: usize = (((rank_out + 1) * rank_out) >> 1).max(1);
pairs
* GGLWESwitchingKeyCompressed::alloc_bytes_with(
infos.n(),
infos.base2k(),
infos.k(),
infos.rows(),
infos.digits(),
Rank(1),
infos.rank_out(),
)
}
fn basek(&self) -> usize {
self.keys[0].basek()
}
fn k(&self) -> usize {
self.keys[0].k()
}
}
impl<D: Data> GGLWETensorKeyCompressed<D> {
pub fn rank(&self) -> usize {
self.keys[0].rank()
}
pub fn digits(&self) -> usize {
self.keys[0].digits()
}
pub fn rank_in(&self) -> usize {
self.keys[0].rank_in()
}
pub fn rank_out(&self) -> usize {
self.keys[0].rank_out()
pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> usize {
let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize;
pairs * GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, rows, digits, Rank(1), rank)
}
}
@@ -134,7 +177,7 @@ impl<D: DataMut> GGLWETensorKeyCompressed<D> {
if i > j {
std::mem::swap(&mut i, &mut j);
};
let rank: usize = self.rank();
let rank: usize = self.rank_out().into();
&mut self.keys[i * rank + j - (i * (i + 1) / 2)]
}
}