use poulpy_core::layouts::{ GGLWEAutomorphismKey, GGLWETensorKey, GLWECiphertext, GLWESecret, LWESecret, prepared::{GGLWEAutomorphismKeyPrepared, GGLWETensorKeyPrepared, GLWESecretPrepared, PrepareAlloc}, }; use std::collections::HashMap; use poulpy_backend::hal::{ api::{ ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxDftToVecZnxBigTmpA, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree, VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, Data, DataRef, Module, Scratch}, source::Source, }; use crate::tfhe::blind_rotation::{ BlindRotationAlgo, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk, BlindRotationKeyPrepared, }; pub trait CircuitBootstrappingKeyEncryptSk { #[allow(clippy::too_many_arguments)] fn encrypt_sk( module: &Module, basek: usize, sk_lwe: &LWESecret, sk_glwe: &GLWESecret, k_brk: usize, rows_brk: usize, k_trace: usize, rows_trace: usize, k_tsk: usize, rows_tsk: usize, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, scratch: &mut Scratch, ) -> Self where DLwe: DataRef, DGlwe: DataRef; } pub struct CircuitBootstrappingKey { pub(crate) brk: BlindRotationKey, pub(crate) tsk: GGLWETensorKey>, pub(crate) atk: HashMap>>, } impl CircuitBootstrappingKeyEncryptSk for CircuitBootstrappingKey, BRA> where BlindRotationKey, BRA>: BlindRotationKeyAlloc + BlindRotationKeyEncryptSk, Module: SvpApply + VecZnxDftToVecZnxBigTmpA + VecZnxAddScalarInplace + VecZnxDftAllocBytes + VecZnxBigNormalize + VecZnxDftFromVecZnx + SvpApplyInplace + VecZnxDftToVecZnxBigConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal + VecZnxNormalize + VecZnxSub + SvpPrepare + VecZnxSwithcDegree + SvpPPolAllocBytes + SvpPPolAlloc + VecZnxAutomorphism, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeSvpPPol + TakeVecZnxBig, { fn encrypt_sk( module: &Module, basek: usize, sk_lwe: &LWESecret, sk_glwe: &GLWESecret, k_brk: usize, rows_brk: usize, k_trace: usize, rows_trace: usize, k_tsk: usize, rows_tsk: usize, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, scratch: &mut Scratch, ) -> Self where DLwe: DataRef, DGlwe: DataRef, Module:, { let mut auto_keys: HashMap>> = HashMap::new(); let gal_els: Vec = GLWECiphertext::trace_galois_elements(module); gal_els.iter().for_each(|gal_el| { let mut key: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(sk_glwe.n(), basek, k_trace, rows_trace, 1, sk_glwe.rank()); key.encrypt_sk( module, *gal_el, sk_glwe, source_xa, source_xe, sigma, scratch, ); auto_keys.insert(*gal_el, key); }); let sk_glwe_prepared: GLWESecretPrepared, B> = sk_glwe.prepare_alloc(module, scratch); let mut brk: BlindRotationKey, BRA> = BlindRotationKey::, BRA>::alloc( sk_glwe.n(), sk_lwe.n(), basek, k_brk, rows_brk, sk_glwe.rank(), ); brk.encrypt_sk( module, &sk_glwe_prepared, sk_lwe, source_xa, source_xe, sigma, scratch, ); let mut tsk: GGLWETensorKey> = GGLWETensorKey::alloc(sk_glwe.n(), basek, k_tsk, rows_tsk, 1, sk_glwe.rank()); tsk.encrypt_sk(module, sk_glwe, source_xa, source_xe, sigma, scratch); Self { brk, atk: auto_keys, tsk, } } } pub struct CircuitBootstrappingKeyPrepared { pub(crate) brk: BlindRotationKeyPrepared, pub(crate) tsk: GGLWETensorKeyPrepared, B>, pub(crate) atk: HashMap, B>>, } impl PrepareAlloc, BRA, B>> for CircuitBootstrappingKey where Module: VmpPMatAlloc + VmpPrepare, BlindRotationKey: PrepareAlloc, BRA, B>>, GGLWETensorKey: PrepareAlloc, B>>, GGLWEAutomorphismKey: PrepareAlloc, B>>, { fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> CircuitBootstrappingKeyPrepared, BRA, B> { let brk: BlindRotationKeyPrepared, BRA, B> = self.brk.prepare_alloc(module, scratch); let tsk: GGLWETensorKeyPrepared, B> = self.tsk.prepare_alloc(module, scratch); let mut atk: HashMap, B>> = HashMap::new(); for (key, value) in &self.atk { atk.insert(*key, value.prepare_alloc(module, scratch)); } CircuitBootstrappingKeyPrepared { brk, tsk, atk } } }