mirror of https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
wip
@@ -1,12 +1,14 @@
 use poulpy_hal::{
-    api::{VecZnxCopy, VecZnxFillUniform},
     layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo},
     source::Source,
 };

 use crate::layouts::{
-    Base2K, Degree, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, Rank, TensorKey, TorusPrecision,
-    compressed::{Decompress, GLWESwitchingKeyCompressed, GLWESwitchingKeyCompressedToMut, GLWESwitchingKeyCompressedToRef},
+    Base2K, Degree, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, Rank, TensorKey, TensorKeyToMut, TorusPrecision,
+    compressed::{
+        GLWESwitchingKeyCompressed, GLWESwitchingKeyCompressedAlloc, GLWESwitchingKeyCompressedToMut,
+        GLWESwitchingKeyCompressedToRef, GLWESwitchingKeyDecompress,
+    },
 };
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
 use std::fmt;
@@ -80,8 +82,27 @@ impl<D: DataRef> fmt::Display for TensorKeyCompressed<D> {
     }
 }

-impl TensorKeyCompressed<Vec<u8>> {
-    pub fn alloc<A>(infos: &A) -> Self
+pub trait TensorKeyCompressedAlloc
+where
+    Self: GLWESwitchingKeyCompressedAlloc,
+{
+    fn tensor_key_compressed_alloc(
+        &self,
+        base2k: Base2K,
+        k: TorusPrecision,
+        rank: Rank,
+        dnum: Dnum,
+        dsize: Dsize,
+    ) -> TensorKeyCompressed<Vec<u8>> {
+        let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1);
+        TensorKeyCompressed {
+            keys: (0..pairs)
+                .map(|_| self.alloc_glwe_switching_key_compressed(base2k, k, Rank(1), rank, dnum, dsize))
+                .collect(),
+        }
+    }
+
+    fn tensor_key_compressed_alloc_from_infos<A>(&self, infos: &A) -> TensorKeyCompressed<Vec<u8>>
     where
         A: GGLWEInfos,
     {
@@ -90,58 +111,70 @@ impl TensorKeyCompressed<Vec<u8>> {
             infos.rank_out(),
             "rank_in != rank_out is not supported for GGLWETensorKeyCompressed"
         );
-        Self::alloc_with(
-            infos.n(),
+        self.tensor_key_compressed_alloc(
             infos.base2k(),
             infos.k(),
-            infos.rank_out(),
+            infos.rank(),
             infos.dnum(),
             infos.dsize(),
         )
     }

-    pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self {
-        let mut keys: Vec<GLWESwitchingKeyCompressed<Vec<u8>>> = Vec::new();
-        let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1);
-        (0..pairs).for_each(|_| {
-            keys.push(GLWESwitchingKeyCompressed::alloc_with(
-                n,
-                base2k,
-                k,
-                Rank(1),
-                rank,
-                dnum,
-                dsize,
-            ));
-        });
-        Self { keys }
+    fn tensor_key_compressed_bytes_of(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize {
+        let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize;
+        pairs * self.bytes_of_glwe_switching_key_compressed(base2k, k, Rank(1), dnum, dsize)
     }

-    pub fn alloc_bytes<A>(infos: &A) -> usize
+    fn tensor_key_compressed_bytes_of_from_infos<A>(&self, infos: &A) -> usize
     where
         A: GGLWEInfos,
     {
-        assert_eq!(
-            infos.rank_in(),
-            infos.rank_out(),
-            "rank_in != rank_out is not supported for GGLWETensorKeyCompressed"
-        );
-        let rank_out: usize = infos.rank_out().into();
-        let pairs: usize = (((rank_out + 1) * rank_out) >> 1).max(1);
-        pairs
-            * GLWESwitchingKeyCompressed::alloc_bytes_with(
-                infos.n(),
-                infos.base2k(),
-                infos.k(),
-                Rank(1),
-                infos.dnum(),
-                infos.dsize(),
-            )
+        self.tensor_key_compressed_bytes_of(
+            infos.base2k(),
+            infos.k(),
+            infos.rank(),
+            infos.dnum(),
+            infos.dsize(),
+        )
     }
+}

+impl TensorKeyCompressed<Vec<u8>> {
+    pub fn alloc_from_infos<A, B: Backend>(module: Module<B>, infos: &A) -> Self
+    where
+        A: GGLWEInfos,
+        Module<B>: TensorKeyCompressedAlloc,
+    {
+        module.tensor_key_compressed_alloc_from_infos(infos)
+    }
+
-    pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize {
-        let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize;
-        pairs * GLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, Rank(1), dnum, dsize)
+    pub fn alloc<B: Backend>(module: Module<B>, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self
+    where
+        Module<B>: TensorKeyCompressedAlloc,
+    {
+        module.tensor_key_compressed_alloc(base2k, k, rank, dnum, dsize)
+    }
+
+    pub fn bytes_of_from_infos<A, B: Backend>(module: Module<B>, infos: &A) -> usize
+    where
+        A: GGLWEInfos,
+        Module<B>: TensorKeyCompressedAlloc,
+    {
+        module.tensor_key_compressed_bytes_of_from_infos(infos)
+    }
+
+    pub fn bytes_of<B: Backend>(
+        module: Module<B>,
+        base2k: Base2K,
+        k: TorusPrecision,
+        rank: Rank,
+        dnum: Dnum,
+        dsize: Dsize,
+    ) -> usize
+    where
+        Module<B>: TensorKeyCompressedAlloc,
+    {
+        module.tensor_key_compressed_bytes_of(base2k, k, rank, dnum, dsize)
     }
 }

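Note (editor's sketch, not part of the commit): allocation now goes through the backend-dispatched TensorKeyCompressedAlloc trait instead of the inherent alloc_with/alloc_bytes constructors. The sketch below shows how the new API could be called; it assumes it lives in the same module as the definitions above so every name is already in scope, and prepare_tensor_key_compressed is a hypothetical helper, not something from the commit.

// Hypothetical helper, assumed to sit next to the definitions above.
fn prepare_tensor_key_compressed<B: Backend>(
    module: Module<B>,
    base2k: Base2K,
    k: TorusPrecision,
    rank: Rank,
    dnum: Dnum,
    dsize: Dsize,
) -> TensorKeyCompressed<Vec<u8>>
where
    Module<B>: TensorKeyCompressedAlloc,
{
    // Serialized size: rank*(rank+1)/2 switching keys (at least 1), one per
    // unordered pair of secret components, as computed by the trait methods.
    let _size: usize = module.tensor_key_compressed_bytes_of(base2k, k, rank, dnum, dsize);
    // Equivalent to module.tensor_key_compressed_alloc(base2k, k, rank, dnum, dsize);
    // the inherent wrapper added by this commit forwards to the trait method.
    TensorKeyCompressed::alloc(module, base2k, k, rank, dnum, dsize)
}
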
@@ -181,28 +214,41 @@ impl<D: DataMut> TensorKeyCompressed<D> {
     }
 }

-impl<D: DataMut, DR: DataRef, B: Backend> Decompress<B, TensorKeyCompressed<DR>> for TensorKey<D>
+pub trait TensorKeyDecompress
 where
-    Module<B>: VecZnxFillUniform + VecZnxCopy,
+    Self: GLWESwitchingKeyDecompress,
 {
-    fn decompress(&mut self, module: &Module<B>, other: &TensorKeyCompressed<DR>) {
-        #[cfg(debug_assertions)]
-        {
-            assert_eq!(
-                self.keys.len(),
-                other.keys.len(),
-                "invalid receiver: self.keys.len()={} != other.keys.len()={}",
-                self.keys.len(),
-                other.keys.len()
-            );
-        }
+    fn decompress_tensor_key<R, O>(&self, res: &mut R, other: &O)
+    where
+        R: TensorKeyToMut,
+        O: TensorKeyCompressedToRef,
+    {
+        let res: &mut TensorKey<&mut [u8]> = &mut res.to_mut();
+        let other: &TensorKeyCompressed<&[u8]> = &other.to_ref();

-        self.keys
-            .iter_mut()
-            .zip(other.keys.iter())
-            .for_each(|(a, b)| {
-                a.decompress(module, b);
-            });
+        assert_eq!(
+            res.keys.len(),
+            other.keys.len(),
+            "invalid receiver: res.keys.len()={} != other.keys.len()={}",
+            res.keys.len(),
+            other.keys.len()
+        );
+
+        for (a, b) in res.keys.iter_mut().zip(other.keys.iter()) {
+            self.decompress_glwe_switching_key(a, b);
+        }
     }
 }
+
+impl<B: Backend> TensorKeyDecompress for Module<B> where Self: GLWESwitchingKeyDecompress {}
+
+impl<D: DataMut> TensorKey<D> {
+    pub fn decompress<O, B: Backend>(&mut self, module: &Module<B>, other: &O)
+    where
+        O: TensorKeyCompressedToRef,
+        Module<B>: GLWESwitchingKeyDecompress,
+    {
+        module.decompress_tensor_key(self, other);
+    }
+}
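Note (editor's sketch, not part of the commit): decompression follows the same pattern; the Decompress<B, TensorKeyCompressed<DR>> impl is replaced by a TensorKeyDecompress trait implemented for Module<B>, with a thin inherent wrapper on TensorKey so call sites keep a method-style API. The sketch below makes the same same-module assumption as above, and expand_tensor_key is a hypothetical helper.

// Hypothetical helper, assumed to sit next to the definitions above.
fn expand_tensor_key<D, O, B>(module: &Module<B>, res: &mut TensorKey<D>, compressed: &O)
where
    D: DataMut,
    O: TensorKeyCompressedToRef,
    B: Backend,
    Module<B>: GLWESwitchingKeyDecompress,
{
    // Forwards to module.decompress_tensor_key(res, compressed): each of the
    // rank*(rank+1)/2 compressed switching keys is expanded independently via
    // decompress_glwe_switching_key, after the key-count check shown in the diff.
    res.decompress(module, compressed);
}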