use poulpy_core::{
    layouts::{
        GGLWEInfos, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedFactory, GGSWInfos, GLWEAutomorphismKeyHelper,
        GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWEInfos, GLWETensorKeyLayout,
        GLWETensorKeyPreparedFactory, LWEInfos, prepared::GLWEAutomorphismKeyPrepared,
    },
    trace_galois_elements,
};
use std::collections::HashMap;

use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch};

use crate::tfhe::{
    blind_rotation::{
        BlindRotationAlgo, BlindRotationKeyInfos, BlindRotationKeyLayout, BlindRotationKeyPrepared,
        BlindRotationKeyPreparedFactory,
    },
    circuit_bootstrapping::{CircuitBootstrappingKey, CircuitBootstrappingKeyInfos},
};

impl CircuitBootstrappingKeyPrepared, BRA, BE> {
    /// Allocates a prepared circuit-bootstrapping key whose component layouts are taken from `infos`.
    ///
    /// Thin wrapper around
    /// [`CircuitBootstrappingKeyPreparedFactory::circuit_bootstrapping_key_prepared_alloc_from_infos`].
    pub fn alloc_from_infos(module: &M, infos: &A) -> CircuitBootstrappingKeyPrepared, BRA, BE>
    where
        A: CircuitBootstrappingKeyInfos,
        M: CircuitBootstrappingKeyPreparedFactory,
    {
        module.circuit_bootstrapping_key_prepared_alloc_from_infos(infos)
    }
}

impl CircuitBootstrappingKeyPrepared {
    /// Fills `self` from the non-prepared key `other`, using `scratch` as temporary memory.
    ///
    /// Thin wrapper around
    /// [`CircuitBootstrappingKeyPreparedFactory::circuit_bootstrapping_key_prepare`]; see that
    /// method for the panic conditions.
    pub fn prepare(&mut self, module: &M, other: &CircuitBootstrappingKey, scratch: &mut Scratch)
    where
        DR: DataRef,
        M: CircuitBootstrappingKeyPreparedFactory,
    {
        module.circuit_bootstrapping_key_prepare(self, other, scratch);
    }
}

// The bounds below mirror the supertraits of `CircuitBootstrappingKeyPreparedFactory` exactly.
// The default methods rely on `GGLWEToGGSWKeyPreparedFactory` (the `tsk` component is a
// `GGLWEToGGSWKeyPrepared`), so that is the factory the blanket impl must require.
impl CircuitBootstrappingKeyPreparedFactory for Module
where
    Self: Sized + BlindRotationKeyPreparedFactory + GGLWEToGGSWKeyPreparedFactory + GLWEAutomorphismKeyPreparedFactory,
{
}

/// Allocation and preparation of [`CircuitBootstrappingKeyPrepared`] on top of the
/// blind-rotation, GGLWE-to-GGSW and automorphism key factories.
pub trait CircuitBootstrappingKeyPreparedFactory
where
    Self: Sized + BlindRotationKeyPreparedFactory + GGLWEToGGSWKeyPreparedFactory + GLWEAutomorphismKeyPreparedFactory,
{
    /// Allocates an (unfilled) prepared key whose layouts come from `infos`.
    ///
    /// One automorphism key is allocated for every Galois element returned by
    /// `trace_galois_elements` for the `atk_infos` layout. All of them share that same layout.
    fn circuit_bootstrapping_key_prepared_alloc_from_infos(
        &self,
        infos: &A,
    ) -> CircuitBootstrappingKeyPrepared, BRA, BE>
    where
        A: CircuitBootstrappingKeyInfos,
    {
        let atk_infos: &GLWEAutomorphismKeyLayout = &infos.atk_infos();
        let gal_els: Vec = trace_galois_elements(atk_infos.log_n(), 2 * atk_infos.n().as_usize() as i64);
        CircuitBootstrappingKeyPrepared {
            brk: BlindRotationKeyPrepared::alloc(self, &infos.brk_infos()),
            tsk: GGLWEToGGSWKeyPrepared::alloc_from_infos(self, &infos.tsk_infos()),
            atk: gal_els
                .iter()
                .map(|&gal_el| {
                    let key = GLWEAutomorphismKeyPrepared::alloc_from_infos(self, atk_infos);
                    (gal_el, key)
                })
                .collect(),
        }
    }

    /// Returns the scratch size, in bytes, needed by
    /// [`Self::circuit_bootstrapping_key_prepare`] for a key with the layouts in `infos`.
    ///
    /// The components are prepared one after another with the same scratch buffer. The maximum
    /// (not the sum) of the per-component requirements is therefore reported.
    fn circuit_bootstrapping_key_prepare_tmp_bytes(&self, infos: &A) -> usize
    where
        A: CircuitBootstrappingKeyInfos,
    {
        self.blind_rotation_key_prepare_tmp_bytes(&infos.brk_infos())
            .max(self.prepare_gglwe_to_ggsw_key_tmp_bytes(&infos.tsk_infos()))
            .max(self.prepare_glwe_automorphism_key_tmp_bytes(&infos.atk_infos()))
    }

    /// Prepares every component of `res` from the matching component of `other`.
    ///
    /// The components are processed in order: `brk`, then `tsk`, then each automorphism key.
    /// Automorphism keys are matched by their Galois-element map key. Entries of `other.atk`
    /// with no counterpart in `res.atk` are ignored.
    ///
    /// # Panics
    ///
    /// Panics if `other.atk` has no key for a Galois element present in `res.atk`.
    fn circuit_bootstrapping_key_prepare(
        &self,
        res: &mut CircuitBootstrappingKeyPrepared,
        other: &CircuitBootstrappingKey,
        scratch: &mut Scratch,
    ) where
        DM: DataMut,
        DR: DataRef,
    {
        res.brk.prepare(self, &other.brk, scratch);
        res.tsk.prepare(self, &other.tsk, scratch);
        for (k, a) in res.atk.iter_mut() {
            // Name the missing element instead of a bare `unwrap()` panic, so that mismatched
            // key sets are diagnosable.
            let src = other
                .atk
                .get(k)
                .unwrap_or_else(|| panic!("missing automorphism key for Galois element {k}"));
            a.prepare(self, src, scratch);
        }
    }
}

/// Prepared circuit-bootstrapping key material.
pub struct CircuitBootstrappingKeyPrepared {
    /// Prepared blind-rotation key.
    pub(crate) brk: BlindRotationKeyPrepared,
    /// Prepared GGLWE-to-GGSW key.
    pub(crate) tsk: GGLWEToGGSWKeyPrepared,
    /// Prepared automorphism keys, indexed by Galois element.
    pub(crate) atk: HashMap>,
}

/// Exposes the automorphism keys by delegating to the `GLWEAutomorphismKeyHelper`
/// implementation of the underlying `atk` map.
impl GLWEAutomorphismKeyHelper, BE> for CircuitBootstrappingKeyPrepared {
    /// Returns the automorphism key stored for Galois element `k`, if any.
    fn get_automorphism_key(&self, k: i64) -> Option<&GLWEAutomorphismKeyPrepared> {
        self.atk.get_automorphism_key(k)
    }

    /// Returns the layout reported by the `atk` map.
    fn automorphism_key_infos(&self) -> poulpy_core::layouts::GGLWELayout {
        self.atk.automorphism_key_infos()
    }
}

/// Reports the layouts of the prepared components, so that a key of the same shape can be
/// allocated or sized from an existing one.
impl CircuitBootstrappingKeyInfos for CircuitBootstrappingKeyPrepared {
    /// Layout of the automorphism keys.
    ///
    /// It is read from an arbitrary entry of `atk` (hash-map iteration order). This relies on
    /// all entries sharing one layout, as keys built by
    /// `circuit_bootstrapping_key_prepared_alloc_from_infos` do.
    ///
    /// # Panics
    ///
    /// Panics if `atk` is empty.
    fn atk_infos(&self) -> GLWEAutomorphismKeyLayout {
        let (_, atk) = self.atk.iter().next().expect("atk is empty");
        GLWEAutomorphismKeyLayout {
            n: atk.n(),
            base2k: atk.base2k(),
            k: atk.k(),
            dnum: atk.dnum(),
            dsize: atk.dsize(),
            rank: atk.rank(),
        }
    }

    /// Layout of the blind-rotation key.
    fn brk_infos(&self) -> BlindRotationKeyLayout {
        BlindRotationKeyLayout {
            n_glwe: self.brk.n_glwe(),
            n_lwe: self.brk.n_lwe(),
            base2k: self.brk.base2k(),
            k: self.brk.k(),
            dnum: self.brk.dnum(),
            rank: self.brk.rank(),
        }
    }

    /// Layout of the `tsk` component, expressed as a `GLWETensorKeyLayout`.
    fn tsk_infos(&self) -> GLWETensorKeyLayout {
        GLWETensorKeyLayout {
            n: self.tsk.n(),
            base2k: self.tsk.base2k(),
            k: self.tsk.k(),
            dnum: self.tsk.dnum(),
            dsize: self.tsk.dsize(),
            rank: self.tsk.rank(),
        }
    }
}