// NOTE(review): this file appears to have lost every `<...>` span in transit.
// The residue is visible throughout: `impl GLWEPublicKey> {`,
// `GLWEPublicKeyExec, B>`, `read_u64::()`, and `Module` / `VecZnx` / `Scratch`
// used without type arguments while `B` and `DataOther` appear in `where`
// clauses without being declared. The code below is kept token-for-token as
// found; the generic parameter lists must be restored from the upstream
// definitions before this can compile. TODO confirm against version control.
use backend::hal::{
    api::{
        ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAlloc, VecZnxAllocBytes, VecZnxDftAlloc, VecZnxDftAllocBytes,
        VecZnxDftFromVecZnx,
    },
    layouts::{Backend, Data, DataMut, DataRef, Module, ReaderFrom, Scratch, ScratchOwned, VecZnx, VecZnxDft, WriterTo},
    oep::{ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxDftImpl, TakeVecZnxImpl},
};
use sampling::source::Source;

use crate::{GLWECiphertext, GLWEEncryptSkFamily, GLWESecretExec, Infos, dist::Distribution};

/// Bound required to generate a public key; declared as an alias of
/// [`GLWEEncryptSkFamily`].
///
/// NOTE(review): `trait A = B;` is trait-alias syntax, which presumably needs
/// the unstable `trait_alias` feature to be enabled at the crate root — confirm.
pub trait GLWEPublicKeyFamily = GLWEEncryptSkFamily;

/// GLWE public key: a `VecZnx` payload plus the `basek` / `k` parameters it
/// was allocated with and the [`Distribution`] tag of the secret key it was
/// generated from (`Distribution::NONE` until `generate_from_sk` is called).
#[derive(PartialEq, Eq)]
pub struct GLWEPublicKey {
    pub(crate) data: VecZnx,
    pub(crate) basek: usize,
    pub(crate) k: usize,
    pub(crate) dist: Distribution,
}

// NOTE(review): the stray `>` is extraction residue; this impl presumably
// targeted an owned-buffer instantiation of the key — confirm.
impl GLWEPublicKey> {
    /// Allocates a fresh key through `module.vec_znx_alloc`, passing
    /// `rank + 1` and `ceil(k / basek)` as its two arguments (presumably the
    /// column count and the limb count — confirm against `VecZnxAlloc`).
    ///
    /// The returned key records `basek` and `k` and starts with
    /// `dist = Distribution::NONE`.
    pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self
    where
        Module: VecZnxAlloc,
    {
        Self {
            data: module.vec_znx_alloc(rank + 1, k.div_ceil(basek)),
            basek: basek,
            k: k,
            dist: Distribution::NONE,
        }
    }

    /// Returns whatever `module.vec_znx_alloc_bytes` reports for the same
    /// `(rank + 1, ceil(k / basek))` arguments that [`Self::alloc`] uses.
    pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize
    where
        Module: VecZnxAllocBytes,
    {
        module.vec_znx_alloc_bytes(rank + 1, k.div_ceil(basek))
    }
}

/// Exposes the underlying `VecZnx` and the stored `basek` / `k` through the
/// crate's [`Infos`] trait.
impl Infos for GLWEPublicKey {
    type Inner = VecZnx;

    fn inner(&self) -> &Self::Inner {
        &self.data
    }

    fn basek(&self) -> usize {
        self.basek
    }

    fn k(&self) -> usize {
        self.k
    }
}

impl GLWEPublicKey {
    /// Rank of the key, computed as `self.cols() - 1` (the inverse of the
    /// `rank + 1` used at allocation). `cols()` is not defined in this file;
    /// presumably it is provided by [`Infos`] — confirm.
    pub fn rank(&self) -> usize {
        self.cols() - 1
    }
}

impl GLWEPublicKey {
    /// Generates the public key from the secret key `sk`.
    ///
    /// Steps performed by the visible code:
    /// 1. In debug builds only, panics if `sk.dist` is `Distribution::NONE`.
    /// 2. Allocates an owned scratch buffer sized by
    ///    `GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)`.
    /// 3. Allocates a temporary `GLWECiphertext` with this key's
    ///    `basek` / `k` / `rank` and calls `encrypt_zero_sk` on it with
    ///    `source_xa`, `source_xe` and `sigma`.
    /// 4. Copies `sk.dist` into `self.dist`.
    ///
    /// NOTE(review): nothing here writes `tmp` back into `self.data`, so as
    /// written the key's data is left untouched and only `dist` changes. This
    /// looks like a missing copy step — confirm against the intended behavior.
    ///
    /// # Panics
    /// In debug builds, if `sk.dist` is `Distribution::NONE`.
    pub fn generate_from_sk(
        &mut self,
        module: &Module,
        sk: &GLWESecretExec,
        source_xa: &mut Source,
        source_xe: &mut Source,
        sigma: f64,
    ) where
        Module: GLWEPublicKeyFamily + VecZnxAlloc,
        B: ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl + TakeVecZnxDftImpl + ScratchAvailableImpl + TakeVecZnxImpl,
    {
        // Debug-only sanity check: a secret key with no distribution tag is rejected.
        #[cfg(debug_assertions)]
        {
            match sk.dist {
                Distribution::NONE => panic!("invalid sk: SecretDistribution::NONE"),
                _ => {}
            }
        }

        // It's OK to allocate scratch space here since pk is usually generated only once.
        let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::encrypt_sk_scratch_space(
            module,
            self.basek(),
            self.k(),
        ));

        // NOTE(review): stray `>` below is extraction residue (see file header note).
        let mut tmp: GLWECiphertext> = GLWECiphertext::alloc(module, self.basek(), self.k(), self.rank());
        tmp.encrypt_zero_sk(module, sk, source_xa, source_xe, sigma, scratch.borrow());
        self.dist = sk.dist;
    }
}

use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};

/// Deserialization: reads the fields in the same order `WriterTo` emits them.
impl ReaderFrom for GLWEPublicKey {
    /// Reads, in order: `k` (u64, cast to `usize`), `basek` (u64, cast to
    /// `usize`), the [`Distribution`] tag, then the `VecZnx` payload via its
    /// own `read_from`. Any I/O error is returned to the caller unchanged.
    ///
    /// NOTE(review): the byte-order turbofish on `read_u64` was stripped;
    /// `LittleEndian` is the only order imported, so presumably that — confirm.
    /// The `u64 as usize` casts would truncate on targets where `usize` is
    /// narrower than 64 bits.
    fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> {
        self.k = reader.read_u64::()? as usize;
        self.basek = reader.read_u64::()? as usize;
        // Equivalent to `self.dist = Distribution::read_from(reader)?;`.
        match Distribution::read_from(reader) {
            Ok(dist) => self.dist = dist,
            Err(e) => return Err(e),
        }
        self.data.read_from(reader)
    }
}

/// Serialization: `k`, `basek`, distribution tag, then the payload.
impl WriterTo for GLWEPublicKey {
    /// Writes, in order: `k` and `basek` as u64, the [`Distribution`] tag via
    /// its `write_to`, then the `VecZnx` payload via its own `write_to`.
    /// Any I/O error is returned to the caller unchanged.
    ///
    /// NOTE(review): byte-order turbofish stripped here as well; presumably
    /// `LittleEndian` to match the reader — confirm.
    fn write_to(&self, writer: &mut W) -> std::io::Result<()> {
        writer.write_u64::(self.k as u64)?;
        writer.write_u64::(self.basek as u64)?;
        // Equivalent to `self.dist.write_to(writer)?;`.
        match self.dist.write_to(writer) {
            Ok(()) => {}
            Err(e) => return Err(e),
        }
        self.data.write_to(writer)
    }
}

/// Prepared form of a [`GLWEPublicKey`] whose payload is a `VecZnxDft`
/// instead of a `VecZnx`; carries the same `basek` / `k` / `dist` metadata.
#[derive(PartialEq, Eq)]
pub struct GLWEPublicKeyExec {
    pub(crate) data: VecZnxDft,
    pub(crate) basek: usize,
    pub(crate) k: usize,
    pub(crate) dist: Distribution,
}

/// Exposes the underlying `VecZnxDft` and the stored `basek` / `k` through
/// the crate's [`Infos`] trait.
impl Infos for GLWEPublicKeyExec {
    type Inner = VecZnxDft;

    fn inner(&self) -> &Self::Inner {
        &self.data
    }

    fn basek(&self) -> usize {
        self.basek
    }

    fn k(&self) -> usize {
        self.k
    }
}

impl GLWEPublicKeyExec {
    /// Rank of the key, computed as `self.cols() - 1`.
    pub fn rank(&self) -> usize {
        self.cols() - 1
    }
}

// NOTE(review): `, B>` is extraction residue; the first type argument was
// stripped (presumably an owned buffer type — confirm).
impl GLWEPublicKeyExec, B> {
    /// Allocates a prepared key through `module.vec_znx_dft_alloc`, passing
    /// `rank + 1` and `ceil(k / basek)`; records `basek` and `k`, and starts
    /// with `dist = Distribution::NONE`.
    pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self
    where
        Module: VecZnxDftAlloc,
    {
        Self {
            data: module.vec_znx_dft_alloc(rank + 1, k.div_ceil(basek)),
            basek: basek,
            k: k,
            dist: Distribution::NONE,
        }
    }

    /// Returns whatever `module.vec_znx_dft_alloc_bytes` reports for the same
    /// `(rank + 1, ceil(k / basek))` arguments that [`Self::alloc`] uses.
    pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize
    where
        Module: VecZnxDftAllocBytes,
    {
        module.vec_znx_dft_alloc_bytes(rank + 1, k.div_ceil(basek))
    }

    /// Convenience constructor: allocates a prepared key with `other`'s
    /// `basek` / `k` / `rank`, fills it via [`Self::prepare`], and returns it.
    /// `scratch` is forwarded to `prepare` as-is.
    pub fn from(module: &Module, other: &GLWEPublicKey, scratch: &mut Scratch) -> Self
    where
        DataOther: DataRef,
        Module: VecZnxDftAlloc + VecZnxDftFromVecZnx,
    {
        let mut pk_exec: GLWEPublicKeyExec, B> = GLWEPublicKeyExec::alloc(module, other.basek(), other.k(), other.rank());
        pk_exec.prepare(module, other, scratch);
        pk_exec
    }
}

impl GLWEPublicKeyExec {
    /// Fills this prepared key from `other`.
    ///
    /// For every column index `i` in `0..self.cols()` it calls
    /// `module.vec_znx_dft_from_vec_znx(1, 0, &mut self.data, i, &other.data, i)`
    /// (the leading `1, 0` arguments are passed through unchanged; their
    /// meaning is defined by `VecZnxDftFromVecZnx` — confirm there), then
    /// copies `k`, `basek` and `dist` from `other`.
    ///
    /// `_scratch` is accepted but not used by this implementation.
    ///
    /// # Panics
    /// In debug builds, if `self.n()` or `other.n()` differs from
    /// `module.n()`, or if `self.size()` differs from `other.size()`.
    pub fn prepare(&mut self, module: &Module, other: &GLWEPublicKey, _scratch: &mut Scratch)
    where
        DataOther: DataRef,
        Module: VecZnxDftFromVecZnx,
    {
        // Debug-only shape checks; release builds skip them entirely.
        #[cfg(debug_assertions)]
        {
            assert_eq!(self.n(), module.n());
            assert_eq!(other.n(), module.n());
            assert_eq!(self.size(), other.size());
        }

        // Column-by-column conversion; source and destination column indices match.
        (0..self.cols()).for_each(|i| {
            module.vec_znx_dft_from_vec_znx(1, 0, &mut self.data, i, &other.data, i);
        });
        self.k = other.k;
        self.basek = other.basek;
        self.dist = other.dist;
    }
}