This commit is contained in:
Pro7ech
2025-10-14 18:46:25 +02:00
parent 0533cdff8a
commit 72dca47cbe
153 changed files with 3099 additions and 1956 deletions

View File

@@ -1,12 +1,14 @@
use poulpy_hal::{
api::{VecZnxCopy, VecZnxFillUniform},
layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo},
source::Source,
};
use crate::layouts::{
Base2K, Degree, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, Rank, TensorKey, TorusPrecision,
compressed::{Decompress, GLWESwitchingKeyCompressed, GLWESwitchingKeyCompressedToMut, GLWESwitchingKeyCompressedToRef},
Base2K, Degree, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, Rank, TensorKey, TensorKeyToMut, TorusPrecision,
compressed::{
GLWESwitchingKeyCompressed, GLWESwitchingKeyCompressedAlloc, GLWESwitchingKeyCompressedToMut,
GLWESwitchingKeyCompressedToRef, GLWESwitchingKeyDecompress,
},
};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::fmt;
@@ -80,8 +82,27 @@ impl<D: DataRef> fmt::Display for TensorKeyCompressed<D> {
}
}
impl TensorKeyCompressed<Vec<u8>> {
pub fn alloc<A>(infos: &A) -> Self
pub trait TensorKeyCompressedAlloc
where
Self: GLWESwitchingKeyCompressedAlloc,
{
fn tensor_key_compressed_alloc(
&self,
base2k: Base2K,
k: TorusPrecision,
rank: Rank,
dnum: Dnum,
dsize: Dsize,
) -> TensorKeyCompressed<Vec<u8>> {
let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1);
TensorKeyCompressed {
keys: (0..pairs)
.map(|_| self.alloc_glwe_switching_key_compressed(base2k, k, Rank(1), rank, dnum, dsize))
.collect(),
}
}
fn tensor_key_compressed_alloc_from_infos<A>(&self, infos: &A) -> TensorKeyCompressed<Vec<u8>>
where
A: GGLWEInfos,
{
@@ -90,58 +111,70 @@ impl TensorKeyCompressed<Vec<u8>> {
infos.rank_out(),
"rank_in != rank_out is not supported for GGLWETensorKeyCompressed"
);
Self::alloc_with(
infos.n(),
self.tensor_key_compressed_alloc(
infos.base2k(),
infos.k(),
infos.rank_out(),
infos.rank(),
infos.dnum(),
infos.dsize(),
)
}
pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self {
let mut keys: Vec<GLWESwitchingKeyCompressed<Vec<u8>>> = Vec::new();
let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1);
(0..pairs).for_each(|_| {
keys.push(GLWESwitchingKeyCompressed::alloc_with(
n,
base2k,
k,
Rank(1),
rank,
dnum,
dsize,
));
});
Self { keys }
fn tensor_key_compressed_bytes_of(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize {
let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize;
pairs * self.bytes_of_glwe_switching_key_compressed(base2k, k, Rank(1), dnum, dsize)
}
pub fn alloc_bytes<A>(infos: &A) -> usize
fn tensor_key_compressed_bytes_of_from_infos<A>(&self, infos: &A) -> usize
where
A: GGLWEInfos,
{
assert_eq!(
infos.rank_in(),
infos.rank_out(),
"rank_in != rank_out is not supported for GGLWETensorKeyCompressed"
);
let rank_out: usize = infos.rank_out().into();
let pairs: usize = (((rank_out + 1) * rank_out) >> 1).max(1);
pairs
* GLWESwitchingKeyCompressed::alloc_bytes_with(
infos.n(),
infos.base2k(),
infos.k(),
Rank(1),
infos.dnum(),
infos.dsize(),
)
self.tensor_key_compressed_bytes_of(
infos.base2k(),
infos.k(),
infos.rank(),
infos.dnum(),
infos.dsize(),
)
}
}
impl TensorKeyCompressed<Vec<u8>> {
pub fn alloc_from_infos<A, B: Backend>(module: Module<B>, infos: &A) -> Self
where
A: GGLWEInfos,
Module<B>: TensorKeyCompressedAlloc,
{
module.tensor_key_compressed_alloc_from_infos(infos)
}
pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize {
let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize;
pairs * GLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, Rank(1), dnum, dsize)
pub fn alloc<B: Backend>(module: Module<B>, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self
where
Module<B>: TensorKeyCompressedAlloc,
{
module.tensor_key_compressed_alloc(base2k, k, rank, dnum, dsize)
}
pub fn bytes_of_from_infos<A, B: Backend>(module: Module<B>, infos: &A) -> usize
where
A: GGLWEInfos,
Module<B>: TensorKeyCompressedAlloc,
{
module.tensor_key_compressed_bytes_of_from_infos(infos)
}
pub fn bytes_of<B: Backend>(
module: Module<B>,
base2k: Base2K,
k: TorusPrecision,
rank: Rank,
dnum: Dnum,
dsize: Dsize,
) -> usize
where
Module<B>: TensorKeyCompressedAlloc,
{
module.tensor_key_compressed_bytes_of(base2k, k, rank, dnum, dsize)
}
}
@@ -181,28 +214,41 @@ impl<D: DataMut> TensorKeyCompressed<D> {
}
}
impl<D: DataMut, DR: DataRef, B: Backend> Decompress<B, TensorKeyCompressed<DR>> for TensorKey<D>
pub trait TensorKeyDecompress
where
Module<B>: VecZnxFillUniform + VecZnxCopy,
Self: GLWESwitchingKeyDecompress,
{
fn decompress(&mut self, module: &Module<B>, other: &TensorKeyCompressed<DR>) {
#[cfg(debug_assertions)]
{
assert_eq!(
self.keys.len(),
other.keys.len(),
"invalid receiver: self.keys.len()={} != other.keys.len()={}",
self.keys.len(),
other.keys.len()
);
}
fn decompress_tensor_key<R, O>(&self, res: &mut R, other: &O)
where
R: TensorKeyToMut,
O: TensorKeyCompressedToRef,
{
let res: &mut TensorKey<&mut [u8]> = &mut res.to_mut();
let other: &TensorKeyCompressed<&[u8]> = &other.to_ref();
self.keys
.iter_mut()
.zip(other.keys.iter())
.for_each(|(a, b)| {
a.decompress(module, b);
});
assert_eq!(
res.keys.len(),
other.keys.len(),
"invalid receiver: res.keys.len()={} != other.keys.len()={}",
res.keys.len(),
other.keys.len()
);
for (a, b) in res.keys.iter_mut().zip(other.keys.iter()) {
self.decompress_glwe_switching_key(a, b);
}
}
}
impl<B: Backend> TensorKeyDecompress for Module<B> where Self: GLWESwitchingKeyDecompress {}
impl<D: DataMut> TensorKey<D> {
pub fn decompress<O, B: Backend>(&mut self, module: &Module<B>, other: &O)
where
O: TensorKeyCompressedToRef,
Module<B>: GLWESwitchingKeyDecompress,
{
module.decompress_tensor_key(self, other);
}
}