Added more serialization tests + generalize methods to any n

This commit is contained in:
Pro7ech
2025-08-13 15:28:52 +02:00
parent 068470783e
commit 940742ce6c
117 changed files with 3658 additions and 2577 deletions

View File

@@ -1,8 +1,5 @@
use backend::hal::{
api::{
ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAlloc, VecZnxAllocBytes, VecZnxDftAlloc, VecZnxDftAllocBytes,
VecZnxDftFromVecZnx,
},
api::{ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftFromVecZnx},
layouts::{Backend, Data, DataMut, DataRef, Module, ReaderFrom, Scratch, ScratchOwned, VecZnx, VecZnxDft, WriterTo},
oep::{ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxDftImpl, TakeVecZnxImpl},
};
@@ -21,23 +18,17 @@ pub struct GLWEPublicKey<D: Data> {
}
impl GLWEPublicKey<Vec<u8>> {
pub fn alloc<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self
where
Module<B>: VecZnxAlloc,
{
pub fn alloc(n: usize, basek: usize, k: usize, rank: usize) -> Self {
Self {
data: module.vec_znx_alloc(rank + 1, k.div_ceil(basek)),
data: VecZnx::alloc(n, rank + 1, k.div_ceil(basek)),
basek: basek,
k: k,
dist: Distribution::NONE,
}
}
pub fn bytes_of<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
where
Module<B>: VecZnxAllocBytes,
{
module.vec_znx_alloc_bytes(rank + 1, k.div_ceil(basek))
pub fn bytes_of(n: usize, basek: usize, k: usize, rank: usize) -> usize {
VecZnx::alloc_bytes(n, rank + 1, k.div_ceil(basek))
}
}
@@ -72,7 +63,7 @@ impl<D: DataMut> GLWEPublicKey<D> {
source_xe: &mut Source,
sigma: f64,
) where
Module<B>: GLWEPublicKeyFamily<B> + VecZnxAlloc,
Module<B>: GLWEPublicKeyFamily<B>,
B: ScratchOwnedAllocImpl<B>
+ ScratchOwnedBorrowImpl<B>
+ TakeVecZnxDftImpl<B>
@@ -81,6 +72,8 @@ impl<D: DataMut> GLWEPublicKey<D> {
{
#[cfg(debug_assertions)]
{
assert_eq!(self.n(), sk.n());
match sk.dist {
Distribution::NONE => panic!("invalid sk: SecretDistribution::NONE"),
_ => {}
@@ -90,11 +83,12 @@ impl<D: DataMut> GLWEPublicKey<D> {
// Its ok to allocate scratch space here since pk is usually generated only once.
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(GLWECiphertext::encrypt_sk_scratch_space(
module,
self.n(),
self.basek(),
self.k(),
));
let mut tmp: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(module, self.basek(), self.k(), self.rank());
let mut tmp: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(self.n(), self.basek(), self.k(), self.rank());
tmp.encrypt_zero_sk(module, sk, source_xa, source_xe, sigma, scratch.borrow());
self.dist = sk.dist;
}
@@ -157,23 +151,23 @@ impl<D: Data, B: Backend> GLWEPublicKeyExec<D, B> {
}
impl<B: Backend> GLWEPublicKeyExec<Vec<u8>, B> {
pub fn alloc(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self
pub fn alloc(module: &Module<B>, n: usize, basek: usize, k: usize, rank: usize) -> Self
where
Module<B>: VecZnxDftAlloc<B>,
{
Self {
data: module.vec_znx_dft_alloc(rank + 1, k.div_ceil(basek)),
data: module.vec_znx_dft_alloc(n, rank + 1, k.div_ceil(basek)),
basek: basek,
k: k,
dist: Distribution::NONE,
}
}
pub fn bytes_of(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
pub fn bytes_of(module: &Module<B>, n: usize, basek: usize, k: usize, rank: usize) -> usize
where
Module<B>: VecZnxDftAllocBytes,
{
module.vec_znx_dft_alloc_bytes(rank + 1, k.div_ceil(basek))
module.vec_znx_dft_alloc_bytes(n, rank + 1, k.div_ceil(basek))
}
pub fn from<DataOther>(module: &Module<B>, other: &GLWEPublicKey<DataOther>, scratch: &mut Scratch<B>) -> Self
@@ -181,7 +175,8 @@ impl<B: Backend> GLWEPublicKeyExec<Vec<u8>, B> {
DataOther: DataRef,
Module<B>: VecZnxDftAlloc<B> + VecZnxDftFromVecZnx<B>,
{
let mut pk_exec: GLWEPublicKeyExec<Vec<u8>, B> = GLWEPublicKeyExec::alloc(module, other.basek(), other.k(), other.rank());
let mut pk_exec: GLWEPublicKeyExec<Vec<u8>, B> =
GLWEPublicKeyExec::alloc(module, other.n(), other.basek(), other.k(), other.rank());
pk_exec.prepare(module, other, scratch);
pk_exec
}
@@ -195,8 +190,7 @@ impl<D: DataMut, B: Backend> GLWEPublicKeyExec<D, B> {
{
#[cfg(debug_assertions)]
{
assert_eq!(self.n(), module.n());
assert_eq!(other.n(), module.n());
assert_eq!(self.n(), other.n());
assert_eq!(self.size(), other.size());
}