use backend::{ Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToMut, ScalarZnxToRef, ScratchOwned, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos, ZnxZero, }; use sampling::source::Source; use crate::{elem::Infos, glwe_ciphertext_fourier::GLWECiphertextFourier}; #[derive(Clone, Copy, Debug)] pub enum SecretDistribution { TernaryFixed(usize), // Ternary with fixed Hamming weight TernaryProb(f64), // Ternary with probabilistic Hamming weight ZERO, // Debug mod NONE, } pub struct SecretKey { pub data: ScalarZnx, pub dist: SecretDistribution, } impl SecretKey> { pub fn alloc(module: &Module, rank: usize) -> Self { Self { data: module.new_scalar_znx(rank), dist: SecretDistribution::NONE, } } } impl SecretKey { pub fn n(&self) -> usize { self.data.n() } pub fn log_n(&self) -> usize { self.data.log_n() } pub fn rank(&self) -> usize { self.data.cols() } } impl SecretKey where S: AsMut<[u8]> + AsRef<[u8]>, { pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { (0..self.rank()).for_each(|i| { self.data.fill_ternary_prob(i, prob, source); }); self.dist = SecretDistribution::TernaryProb(prob); } pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { (0..self.rank()).for_each(|i| { self.data.fill_ternary_hw(i, hw, source); }); self.dist = SecretDistribution::TernaryFixed(hw); } pub fn fill_zero(&mut self) { self.data.zero(); self.dist = SecretDistribution::ZERO; } } impl ScalarZnxToMut for SecretKey where ScalarZnx: ScalarZnxToMut, { fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> { self.data.to_mut() } } impl ScalarZnxToRef for SecretKey where ScalarZnx: ScalarZnxToRef, { fn to_ref(&self) -> ScalarZnx<&[u8]> { self.data.to_ref() } } pub struct SecretKeyFourier { pub data: ScalarZnxDft, pub dist: SecretDistribution, } impl SecretKeyFourier { pub fn n(&self) -> usize { self.data.n() } pub fn log_n(&self) -> usize { self.data.log_n() } pub fn rank(&self) -> usize { self.data.cols() } } impl SecretKeyFourier, B> { pub fn alloc(module: &Module, rank: usize) -> Self { Self { data: module.new_scalar_znx_dft(rank), dist: SecretDistribution::NONE, } } pub fn dft(&mut self, module: &Module, sk: &SecretKey) where SecretKeyFourier, B>: ScalarZnxDftToMut, SecretKey: ScalarZnxToRef, { #[cfg(debug_assertions)] { match sk.dist { SecretDistribution::NONE => panic!("invalid sk: SecretDistribution::NONE"), _ => {} } assert_eq!(self.n(), module.n()); assert_eq!(sk.n(), module.n()); assert_eq!(self.rank(), sk.rank()); } (0..self.rank()).for_each(|i| { module.svp_prepare(self, i, sk, i); }); self.dist = sk.dist; } } impl ScalarZnxDftToMut for SecretKeyFourier where ScalarZnxDft: ScalarZnxDftToMut, { fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> { self.data.to_mut() } } impl ScalarZnxDftToRef for SecretKeyFourier where ScalarZnxDft: ScalarZnxDftToRef, { fn to_ref(&self) -> ScalarZnxDft<&[u8], B> { self.data.to_ref() } } pub struct GLWEPublicKey { pub data: GLWECiphertextFourier, pub dist: SecretDistribution, } impl GLWEPublicKey, B> { pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { Self { data: GLWECiphertextFourier::alloc(module, basek, k, rank), dist: SecretDistribution::NONE, } } } impl Infos for GLWEPublicKey { type Inner = VecZnxDft; fn inner(&self) -> &Self::Inner { &self.data.data } fn basek(&self) -> usize { self.data.basek } fn k(&self) -> usize { self.data.k } } impl GLWEPublicKey { pub fn rank(&self) -> usize { self.cols() - 1 } } impl VecZnxDftToMut for GLWEPublicKey where VecZnxDft: VecZnxDftToMut, { fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { self.data.to_mut() } } impl VecZnxDftToRef for GLWEPublicKey where VecZnxDft: VecZnxDftToRef, { fn to_ref(&self) -> VecZnxDft<&[u8], B> { self.data.to_ref() } } impl GLWEPublicKey { pub fn generate_from_sk( &mut self, module: &Module, sk_dft: &SecretKeyFourier, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, ) where VecZnxDft: VecZnxDftToMut, ScalarZnxDft: ScalarZnxDftToRef + ZnxInfos, { #[cfg(debug_assertions)] { match sk_dft.dist { SecretDistribution::NONE => panic!("invalid sk_dft: SecretDistribution::NONE"), _ => {} } } // Its ok to allocate scratch space here since pk is usually generated only once. let mut scratch: ScratchOwned = ScratchOwned::new(GLWECiphertextFourier::encrypt_sk_scratch_space( module, self.rank(), self.size(), )); self.data.encrypt_zero_sk( module, sk_dft, source_xa, source_xe, sigma, scratch.borrow(), ); self.dist = sk_dft.dist; } }