use backend::hal::{ api::{FillUniform, Reset, VecZnxCopy, VecZnxFillUniform, ZnxInfos}, layouts::{Backend, Data, DataMut, DataRef, Module, ReaderFrom, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo}, }; use sampling::source::Source; use crate::{Decompress, GLWEOps, Infos, SetMetaData}; use std::fmt; #[derive(PartialEq, Eq, Clone)] pub struct GLWECiphertext { pub data: VecZnx, pub basek: usize, pub k: usize, } impl fmt::Debug for GLWECiphertext { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self) } } impl fmt::Display for GLWECiphertext { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, "GLWECiphertext: basek={} k={}: {}", self.basek(), self.k(), self.data ) } } impl Reset for GLWECiphertext where VecZnx: Reset, { fn reset(&mut self) { self.data.reset(); self.basek = 0; self.k = 0; } } impl FillUniform for GLWECiphertext where VecZnx: FillUniform, { fn fill_uniform(&mut self, source: &mut Source) { self.data.fill_uniform(source); } } impl GLWECiphertext> { pub fn alloc(n: usize, basek: usize, k: usize, rank: usize) -> Self { Self { data: VecZnx::alloc(n, rank + 1, k.div_ceil(basek)), basek, k, } } pub fn bytes_of(n: usize, basek: usize, k: usize, rank: usize) -> usize { VecZnx::alloc_bytes(n, rank + 1, k.div_ceil(basek)) } } impl Infos for GLWECiphertext { type Inner = VecZnx; fn inner(&self) -> &Self::Inner { &self.data } fn basek(&self) -> usize { self.basek } fn k(&self) -> usize { self.k } } impl GLWECiphertext { pub fn rank(&self) -> usize { self.cols() - 1 } } impl GLWECiphertext { pub fn clone(&self) -> GLWECiphertext> { GLWECiphertext { data: self.data.clone(), basek: self.basek(), k: self.k(), } } } impl SetMetaData for GLWECiphertext { fn set_k(&mut self, k: usize) { self.k = k } fn set_basek(&mut self, basek: usize) { self.basek = basek } } pub trait GLWECiphertextToRef: Infos { fn to_ref(&self) -> GLWECiphertext<&[u8]>; } impl GLWECiphertextToRef for GLWECiphertext { fn to_ref(&self) -> GLWECiphertext<&[u8]> { GLWECiphertext { data: self.data.to_ref(), basek: self.basek, k: self.k, } } } pub trait GLWECiphertextToMut: Infos { fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]>; } impl GLWECiphertextToMut for GLWECiphertext { fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]> { GLWECiphertext { data: self.data.to_mut(), basek: self.basek, k: self.k, } } } impl GLWEOps for GLWECiphertext where GLWECiphertext: GLWECiphertextToMut + Infos + SetMetaData {} use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; impl ReaderFrom for GLWECiphertext { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.k = reader.read_u64::()? as usize; self.basek = reader.read_u64::()? as usize; self.data.read_from(reader) } } impl WriterTo for GLWECiphertext { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u64::(self.k as u64)?; writer.write_u64::(self.basek as u64)?; self.data.write_to(writer) } } #[derive(PartialEq, Eq, Clone)] pub struct GLWECiphertextCompressed { pub(crate) data: VecZnx, pub(crate) basek: usize, pub(crate) k: usize, pub(crate) rank: usize, pub(crate) seed: [u8; 32], } impl fmt::Debug for GLWECiphertextCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self) } } impl fmt::Display for GLWECiphertextCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, "GLWECiphertextCompressed: basek={} k={} rank={} seed={:?}: {}", self.basek(), self.k(), self.rank, self.seed, self.data ) } } impl Reset for GLWECiphertextCompressed where VecZnx: Reset, { fn reset(&mut self) { self.data.reset(); self.basek = 0; self.k = 0; self.rank = 0; self.seed = [0u8; 32]; } } impl FillUniform for GLWECiphertextCompressed where VecZnx: FillUniform, { fn fill_uniform(&mut self, source: &mut Source) { self.data.fill_uniform(source); } } impl Infos for GLWECiphertextCompressed { type Inner = VecZnx; fn inner(&self) -> &Self::Inner { &self.data } fn basek(&self) -> usize { self.basek } fn k(&self) -> usize { self.k } } impl GLWECiphertextCompressed { pub fn rank(&self) -> usize { self.rank } } impl GLWECiphertextCompressed> { pub fn alloc(n: usize, basek: usize, k: usize, rank: usize) -> Self { Self { data: VecZnx::alloc(n, 1, k.div_ceil(basek)), basek, k, rank, seed: [0u8; 32], } } pub fn bytes_of(n: usize, basek: usize, k: usize) -> usize { GLWECiphertext::bytes_of(n, basek, k, 1) } } impl ReaderFrom for GLWECiphertextCompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.k = reader.read_u64::()? as usize; self.basek = reader.read_u64::()? as usize; self.rank = reader.read_u64::()? as usize; reader.read(&mut self.seed)?; self.data.read_from(reader) } } impl WriterTo for GLWECiphertextCompressed { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u64::(self.k as u64)?; writer.write_u64::(self.basek as u64)?; writer.write_u64::(self.rank as u64)?; writer.write_all(&self.seed)?; self.data.write_to(writer) } } impl Decompress> for GLWECiphertext { fn decompress(&mut self, module: &Module, other: &GLWECiphertextCompressed) where Module: VecZnxFillUniform + VecZnxCopy, { #[cfg(debug_assertions)] { assert_eq!( self.n(), other.data.n(), "invalid receiver: self.n()={} != other.n()={}", self.n(), other.data.n() ); assert_eq!( self.size(), other.size(), "invalid receiver: self.size()={} != other.size()={}", self.size(), other.size() ); assert_eq!( self.rank(), other.rank(), "invalid receiver: self.rank()={} != other.rank()={}", self.rank(), other.rank() ); let mut source: Source = Source::new(other.seed); self.decompress_internal(module, other, &mut source); } } } impl GLWECiphertext { pub(crate) fn decompress_internal( &mut self, module: &Module, other: &GLWECiphertextCompressed, source: &mut Source, ) where DataOther: DataRef, Module: VecZnxFillUniform + VecZnxCopy, { #[cfg(debug_assertions)] { assert_eq!(self.rank(), other.rank()) } let k: usize = other.k; let basek: usize = other.basek; let cols: usize = other.rank() + 1; module.vec_znx_copy(&mut self.data, 0, &other.data, 0); (1..cols).for_each(|i| { module.vec_znx_fill_uniform(basek, &mut self.data, i, k, source); }); self.basek = basek; self.k = k; } }