Added more serialization tests + generalize methods to any n

This commit is contained in:
Pro7ech
2025-08-13 15:28:52 +02:00
parent 068470783e
commit 940742ce6c
117 changed files with 3658 additions and 2577 deletions

View File

@@ -1,12 +1,11 @@
use std::fmt::Debug;
use backend::hal::{
api::{FillUniform, VecZnxAlloc, VecZnxAllocBytes, VecZnxCopy, VecZnxFillUniform, ZnxInfos, ZnxZero},
api::{FillUniform, Reset, VecZnxCopy, VecZnxFillUniform, ZnxInfos},
layouts::{Backend, Data, DataMut, DataRef, Module, ReaderFrom, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo},
};
use sampling::source::Source;
use crate::{Decompress, GLWEOps, Infos, SetMetaData};
use std::fmt;
#[derive(PartialEq, Eq, Clone)]
pub struct GLWECiphertext<D: Data> {
@@ -15,8 +14,14 @@ pub struct GLWECiphertext<D: Data> {
pub k: usize,
}
impl<D: DataRef> Debug for GLWECiphertext<D> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
impl<D: DataRef> fmt::Debug for GLWECiphertext<D> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self)
}
}
impl<D: DataRef> fmt::Display for GLWECiphertext<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"GLWECiphertext: basek={} k={}: {}",
@@ -27,16 +32,14 @@ impl<D: DataRef> Debug for GLWECiphertext<D> {
}
}
impl<D: DataMut> ZnxZero for GLWECiphertext<D>
impl<D: DataMut> Reset for GLWECiphertext<D>
where
VecZnx<D>: ZnxZero,
VecZnx<D>: Reset,
{
fn zero(&mut self) {
self.data.zero()
}
fn zero_at(&mut self, i: usize, j: usize) {
self.data.zero_at(i, j);
fn reset(&mut self) {
self.data.reset();
self.basek = 0;
self.k = 0;
}
}
@@ -50,22 +53,16 @@ where
}
impl GLWECiphertext<Vec<u8>> {
pub fn alloc<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self
where
Module<B>: VecZnxAlloc,
{
pub fn alloc(n: usize, basek: usize, k: usize, rank: usize) -> Self {
Self {
data: module.vec_znx_alloc(rank + 1, k.div_ceil(basek)),
data: VecZnx::alloc(n, rank + 1, k.div_ceil(basek)),
basek,
k,
}
}
pub fn bytes_of<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
where
Module<B>: VecZnxAllocBytes,
{
module.vec_znx_alloc_bytes(rank + 1, k.div_ceil(basek))
pub fn bytes_of(n: usize, basek: usize, k: usize, rank: usize) -> usize {
VecZnx::alloc_bytes(n, rank + 1, k.div_ceil(basek))
}
}
@@ -168,28 +165,36 @@ pub struct GLWECiphertextCompressed<D: Data> {
pub(crate) seed: [u8; 32],
}
impl<D: DataRef> Debug for GLWECiphertextCompressed<D> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
impl<D: DataRef> fmt::Debug for GLWECiphertextCompressed<D> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self)
}
}
impl<D: DataRef> fmt::Display for GLWECiphertextCompressed<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"GLWECiphertext: basek={} k={}: {}",
"GLWECiphertextCompressed: basek={} k={} rank={} seed={:?}: {}",
self.basek(),
self.k(),
self.rank,
self.seed,
self.data
)
}
}
impl<D: DataMut> ZnxZero for GLWECiphertextCompressed<D>
impl<D: DataMut> Reset for GLWECiphertextCompressed<D>
where
VecZnx<D>: ZnxZero,
VecZnx<D>: Reset,
{
fn zero(&mut self) {
self.data.zero()
}
fn zero_at(&mut self, i: usize, j: usize) {
self.data.zero_at(i, j);
fn reset(&mut self) {
self.data.reset();
self.basek = 0;
self.k = 0;
self.rank = 0;
self.seed = [0u8; 32];
}
}
@@ -225,12 +230,9 @@ impl<D: Data> GLWECiphertextCompressed<D> {
}
impl GLWECiphertextCompressed<Vec<u8>> {
pub fn alloc<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self
where
Module<B>: VecZnxAlloc,
{
pub fn alloc(n: usize, basek: usize, k: usize, rank: usize) -> Self {
Self {
data: module.vec_znx_alloc(1, k.div_ceil(basek)),
data: VecZnx::alloc(n, 1, k.div_ceil(basek)),
basek,
k,
rank,
@@ -238,11 +240,8 @@ impl GLWECiphertextCompressed<Vec<u8>> {
}
}
pub fn bytes_of<B: Backend>(module: &Module<B>, basek: usize, k: usize) -> usize
where
Module<B>: VecZnxAllocBytes,
{
GLWECiphertext::bytes_of(module, basek, k, 1)
pub fn bytes_of(n: usize, basek: usize, k: usize) -> usize {
GLWECiphertext::bytes_of(n, basek, k, 1)
}
}