use poulpy_hal::{ api::{SvpPPolAlloc, SvpPrepare, VmpPMatAlloc, VmpPrepare}, layouts::{Backend, Data, DataMut, DataRef, Module, ScalarZnx, Scratch, SvpPPol}, }; use std::marker::PhantomData; use poulpy_core::{ Distribution, layouts::{ Base2K, Dnum, Dsize, GGSWInfos, GLWEInfos, LWEInfos, Rank, RingDegree, TorusPrecision, prepared::{GGSWPrepared, Prepare, PrepareAlloc}, }, }; use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKey, BlindRotationKeyInfos, utils::set_xai_plus_y}; pub trait BlindRotationKeyPreparedAlloc { fn alloc(module: &Module, infos: &A) -> Self where A: BlindRotationKeyInfos; } #[derive(PartialEq, Eq)] pub struct BlindRotationKeyPrepared { pub(crate) data: Vec>, pub(crate) dist: Distribution, pub(crate) x_pow_a: Option, B>>>, pub(crate) _phantom: PhantomData, } impl BlindRotationKeyInfos for BlindRotationKeyPrepared { fn n_glwe(&self) -> RingDegree { self.n() } fn n_lwe(&self) -> RingDegree { RingDegree(self.data.len() as u32) } } impl LWEInfos for BlindRotationKeyPrepared { fn base2k(&self) -> Base2K { self.data[0].base2k() } fn k(&self) -> TorusPrecision { self.data[0].k() } fn n(&self) -> RingDegree { self.data[0].n() } fn size(&self) -> usize { self.data[0].size() } } impl GLWEInfos for BlindRotationKeyPrepared { fn rank(&self) -> Rank { self.data[0].rank() } } impl GGSWInfos for BlindRotationKeyPrepared { fn dsize(&self) -> poulpy_core::layouts::Dsize { Dsize(1) } fn dnum(&self) -> Dnum { self.data[0].dnum() } } impl BlindRotationKeyPrepared { pub fn block_size(&self) -> usize { match self.dist { Distribution::BinaryBlock(value) => value, _ => 1, } } } impl PrepareAlloc, BRA, B>> for BlindRotationKey where BlindRotationKeyPrepared, BRA, B>: BlindRotationKeyPreparedAlloc, BlindRotationKeyPrepared, BRA, B>: Prepare>, { fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> BlindRotationKeyPrepared, BRA, B> { let mut brk: BlindRotationKeyPrepared, BRA, B> = BlindRotationKeyPrepared::alloc(module, self); brk.prepare(module, self, scratch); brk } } impl Prepare> for BlindRotationKeyPrepared where Module: VmpPMatAlloc + VmpPrepare + SvpPPolAlloc + SvpPrepare, { fn prepare(&mut self, module: &Module, other: &BlindRotationKey, scratch: &mut Scratch) { #[cfg(debug_assertions)] { assert_eq!(self.data.len(), other.keys.len()); } let n: usize = other.n().as_usize(); self.data .iter_mut() .zip(other.keys.iter()) .for_each(|(ggsw_prepared, other)| { ggsw_prepared.prepare(module, other, scratch); }); self.dist = other.dist; if let Distribution::BinaryBlock(_) = other.dist { let mut x_pow_a: Vec, B>> = Vec::with_capacity(n << 1); let mut buf: ScalarZnx> = ScalarZnx::alloc(n, 1); (0..n << 1).for_each(|i| { let mut res: SvpPPol, B> = module.svp_ppol_alloc(1); set_xai_plus_y(module, i, 0, &mut res, &mut buf); x_pow_a.push(res); }); self.x_pow_a = Some(x_pow_a); } } }