use std::{collections::HashMap, time::Instant, usize}; use backend::hal::{ api::{ ScratchAvailable, TakeMatZnx, TakeScalarZnx, TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice, VecZnxAddInplace, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAutomorphismInplace, VecZnxBigSubSmallBInplace, VecZnxCopy, VecZnxDftCopy, VecZnxDftToVecZnxBigTmpA, VecZnxNegateInplace, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree, VmpPMatAlloc, VmpPMatPrepare, }, layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, }; use sampling::source::Source; use crate::{ BlindRotationKeyCGGI, BlindRotationKeyCGGIExec, BlindRotationKeyCGGIExecLayoutFamily, CCGIBlindRotationFamily, GLWEOps, Infos, LookUpTable, LookUpTableRotationDirection, TakeGGLWE, TakeGLWECt, cggi_blind_rotate, layouts::{ AutomorphismKey, GGSWCiphertext, GLWECiphertext, GLWESecret, GLWETensorKey, LWECiphertext, LWESecret, prepared::{AutomorphismKeyExec, GLWESecretExec, GLWETensorKeyExec}, }, }; use crate::trait_families::{ AutomorphismKeyEncryptSkFamily, GGSWEncryptSkFamily, GLWESecretExecModuleFamily, GLWETensorKeyEncryptSkFamily, GLWETraceModuleFamily, }; pub struct CircuitBootstrappingKeyCGGI { pub(crate) brk: BlindRotationKeyCGGI, pub(crate) tsk: GLWETensorKey>, pub(crate) atk: HashMap>>, } impl CircuitBootstrappingKeyCGGI> { pub fn generate( module: &Module, basek: usize, sk_lwe: &LWESecret, sk_glwe: &GLWESecret, k_brk: usize, rows_brk: usize, k_trace: usize, rows_trace: usize, k_tsk: usize, rows_tsk: usize, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, scratch: &mut Scratch, ) -> Self where Module: GGSWEncryptSkFamily + GLWESecretExecModuleFamily + VecZnxAddScalarInplace + AutomorphismKeyEncryptSkFamily + VecZnxAutomorphism + VecZnxSwithcDegree + GLWETensorKeyEncryptSkFamily, DLwe: DataRef, DGlwe: DataRef, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeSvpPPol + TakeVecZnxBig, { let mut auto_keys: HashMap>> = HashMap::new(); let gal_els: Vec = GLWECiphertext::trace_galois_elements(&module); gal_els.iter().for_each(|gal_el| { let mut key: AutomorphismKey> = AutomorphismKey::alloc(sk_glwe.n(), basek, k_trace, rows_trace, 1, sk_glwe.rank()); key.encrypt_sk( &module, *gal_el, &sk_glwe, source_xa, source_xe, sigma, scratch, ); auto_keys.insert(*gal_el, key); }); let sk_glwe_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_glwe); let mut brk: BlindRotationKeyCGGI> = BlindRotationKeyCGGI::alloc( sk_glwe.n(), sk_lwe.n(), basek, k_brk, rows_brk, sk_glwe.rank(), ); brk.generate_from_sk( module, &sk_glwe_exec, sk_lwe, source_xa, source_xe, sigma, scratch, ); let mut tsk: GLWETensorKey> = GLWETensorKey::alloc(sk_glwe.n(), basek, k_tsk, rows_tsk, 1, sk_glwe.rank()); tsk.encrypt_sk(module, &sk_glwe, source_xa, source_xe, sigma, scratch); Self { brk, atk: auto_keys, tsk, } } } pub struct CircuitBootstrappingKeyCGGIExec { pub(crate) brk: BlindRotationKeyCGGIExec, pub(crate) tsk: GLWETensorKeyExec, B>, pub(crate) atk: HashMap, B>>, } impl CircuitBootstrappingKeyCGGIExec, B> { pub fn from( module: &Module, other: &CircuitBootstrappingKeyCGGI, scratch: &mut Scratch, ) -> CircuitBootstrappingKeyCGGIExec, B> where Module: BlindRotationKeyCGGIExecLayoutFamily + VmpPMatAlloc + VmpPMatPrepare, { let brk: BlindRotationKeyCGGIExec, B> = BlindRotationKeyCGGIExec::from(module, &other.brk, scratch); let tsk: GLWETensorKeyExec, B> = GLWETensorKeyExec::from(module, &other.tsk, scratch); let mut atk: HashMap, B>> = HashMap::new(); for (key, value) in &other.atk { atk.insert(*key, AutomorphismKeyExec::from(module, value, scratch)); } CircuitBootstrappingKeyCGGIExec { brk, tsk, atk } } } pub trait CGGICircuitBootstrapFamily = VecZnxRotateInplace + VecZnxNormalizeInplace + VecZnxNormalizeTmpBytes + CCGIBlindRotationFamily + VecZnxSwithcDegree + VecZnxBigAutomorphismInplace + VecZnxRshInplace + VecZnxDftCopy + VecZnxDftToVecZnxBigTmpA + VecZnxSub + VecZnxAddInplace + VecZnxNegateInplace + VecZnxCopy + VecZnxSubABInplace + GLWETraceModuleFamily + VecZnxRotateInplace + VecZnxAutomorphismInplace + VecZnxBigSubSmallBInplace; pub fn circuit_bootstrap_to_constant_cggi( module: &Module, res: &mut GGSWCiphertext, lwe: &LWECiphertext, log_domain: usize, extension_factor: usize, key: &CircuitBootstrappingKeyCGGIExec, scratch: &mut Scratch, ) where DRes: DataMut, DLwe: DataRef, DBrk: DataRef, Module: CGGICircuitBootstrapFamily, B: ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, Scratch: TakeVecZnx + TakeVecZnxDftSlice + TakeVecZnxBig + TakeVecZnxDft + TakeMatZnx + ScratchAvailable + TakeVecZnxSlice, { circuit_bootstrap_core_cggi( false, module, 0, res, lwe, log_domain, extension_factor, key, scratch, ); } pub fn circuit_bootstrap_to_exponent_cggi( module: &Module, log_gap_out: usize, res: &mut GGSWCiphertext, lwe: &LWECiphertext, log_domain: usize, extension_factor: usize, key: &CircuitBootstrappingKeyCGGIExec, scratch: &mut Scratch, ) where DRes: DataMut, DLwe: DataRef, DBrk: DataRef, Module: CGGICircuitBootstrapFamily, B: ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, Scratch: TakeVecZnx + TakeVecZnxDftSlice + TakeVecZnxBig + TakeVecZnxDft + TakeMatZnx + ScratchAvailable + TakeVecZnxSlice, { circuit_bootstrap_core_cggi( true, module, log_gap_out, res, lwe, log_domain, extension_factor, key, scratch, ); } pub fn circuit_bootstrap_core_cggi( to_exponent: bool, module: &Module, log_gap_out: usize, res: &mut GGSWCiphertext, lwe: &LWECiphertext, log_domain: usize, extension_factor: usize, key: &CircuitBootstrappingKeyCGGIExec, scratch: &mut Scratch, ) where DRes: DataMut, DLwe: DataRef, DBrk: DataRef, Module: CGGICircuitBootstrapFamily, B: ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, Scratch: TakeGGLWE + TakeVecZnxDftSlice + TakeVecZnxBig + TakeVecZnxDft + TakeVecZnx + ScratchAvailable + TakeVecZnxSlice, { #[cfg(debug_assertions)] { use crate::Infos; assert_eq!(res.n(), key.brk.n()); assert_eq!(lwe.basek(), key.brk.basek()); assert_eq!(res.basek(), key.brk.basek()); } let n: usize = res.n(); let basek: usize = res.basek(); let rows: usize = res.rows(); let rank: usize = res.rank(); let k: usize = res.k(); let alpha: usize = rows.next_power_of_two(); let mut f: Vec = vec![0i64; (1 << log_domain) * alpha]; if to_exponent { (0..rows).for_each(|i| { f[i] = 1 << (basek * (rows - 1 - i)); }); } else { (0..1 << log_domain).for_each(|j| { (0..rows).for_each(|i| { f[j * alpha + i] = j as i64 * (1 << (basek * (rows - 1 - i))); }); }); } // Lut precision, basically must be able to hold the decomposition power basis of the GGSW let mut lut: LookUpTable = LookUpTable::alloc(n, basek, basek * rows, extension_factor); lut.set(module, &f, basek * rows); if to_exponent { lut.set_rotation_direction(LookUpTableRotationDirection::Right); } // TODO: separate GGSW k from output of blind rotation k let (mut res_glwe, scratch1) = scratch.take_glwe_ct(n, basek, k, rank); let (mut tmp_gglwe, scratch2) = scratch1.take_gglwe(n, basek, k, rows, 1, rank, rank); let now: Instant = Instant::now(); cggi_blind_rotate(module, &mut res_glwe, &lwe, &lut, &key.brk, scratch2); println!("cggi_blind_rotate: {} ms", now.elapsed().as_millis()); let gap: usize = 2 * lut.drift / lut.extension_factor(); let log_gap_in: usize = (usize::BITS - (gap * alpha - 1).leading_zeros()) as _; (0..rows).for_each(|i| { let mut tmp_glwe: GLWECiphertext<&mut [u8]> = tmp_gglwe.at_mut(i, 0); if to_exponent { let now: Instant = Instant::now(); // Isolates i-th LUT and moves coefficients according to requested gap. post_process( module, &mut tmp_glwe, &res_glwe, log_gap_in, log_gap_out, log_domain, &key.atk, scratch2, ); println!("post_process: {} ms", now.elapsed().as_millis()); } else { tmp_glwe.trace(module, 0, module.log_n(), &res_glwe, &key.atk, scratch2); } if i < rows { res_glwe.rotate_inplace(module, -(gap as i64)); } }); // Expands GGLWE to GGSW using GGLWE(s^2) res.from_gglwe(module, &tmp_gglwe, &key.tsk, scratch2); } fn post_process( module: &Module, res: &mut GLWECiphertext, a: &GLWECiphertext, log_gap_in: usize, log_gap_out: usize, log_domain: usize, auto_keys: &HashMap, B>>, scratch: &mut Scratch, ) where DataRes: DataMut, DataA: DataRef, Module: CGGICircuitBootstrapFamily, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { let log_n: usize = module.log_n(); let mut cts: HashMap>> = HashMap::new(); // First partial trace, vanishes all coefficients which are not multiples of gap_in // [1, 1, 1, 1, 0, 0, 0, ..., 0, 0, -1, -1, -1, -1] -> [1, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0] res.trace( module, module.log_n() - log_gap_in as usize + 1, log_n, &a, auto_keys, scratch, ); // TODO: optimize with packing and final partial trace // If gap_out < gap_in, then we need to repack, i.e. reduce the cap between coefficients. if log_gap_in != log_gap_out { let steps: i32 = 1 << log_domain; (0..steps).for_each(|i| { if i != 0 { res.rotate_inplace(module, -(1 << log_gap_in)); } cts.insert(i as usize * (1 << log_gap_out), res.clone()); }); let now: Instant = Instant::now(); pack(module, &mut cts, log_gap_out, auto_keys, scratch); println!("pack: {} ms", now.elapsed().as_millis()); let packed: GLWECiphertext> = cts.remove(&0).unwrap(); res.trace( module, log_n - log_gap_out, log_n, &packed, auto_keys, scratch, ); } } pub fn pack( module: &Module, cts: &mut HashMap>, log_gap_out: usize, auto_keys: &HashMap, B>>, scratch: &mut Scratch, ) where Module: CGGICircuitBootstrapFamily, Scratch: TakeVecZnx + TakeVecZnxDft + ScratchAvailable, { let log_n: usize = module.log_n(); let basek: usize = cts.get(&0).unwrap().basek(); let k: usize = cts.get(&0).unwrap().k(); let rank: usize = cts.get(&0).unwrap().rank(); (0..log_n - log_gap_out).for_each(|i| { let now: Instant = Instant::now(); let t = 16.min(1 << (log_n - 1 - i)); let auto_key: &AutomorphismKeyExec, B>; if i == 0 { auto_key = auto_keys.get(&-1).unwrap() } else { auto_key = auto_keys.get(&module.galois_element(1 << (i - 1))).unwrap(); } (0..t).for_each(|j| { let mut a: Option> = cts.remove(&j); let mut b: Option> = cts.remove(&(j + t)); combine( module, basek, k, rank, a.as_mut(), b.as_mut(), i, auto_key, scratch, ); if let Some(a) = a { cts.insert(j, a); } else if let Some(b) = b { cts.insert(j, b); } }); println!("combine: {} us", now.elapsed().as_micros()); }); } fn combine( module: &Module, basek: usize, k: usize, rank: usize, a: Option<&mut GLWECiphertext>, b: Option<&mut GLWECiphertext>, i: usize, auto_key: &AutomorphismKeyExec, scratch: &mut Scratch, ) where Module: CGGICircuitBootstrapFamily, Scratch: TakeVecZnx + TakeVecZnxDft + ScratchAvailable, { // Goal is to evaluate: a = a + b*X^t + phi(a - b*X^t)) // We also use the identity: AUTO(a * X^t, g) = -X^t * AUTO(a, g) // where t = 2^(log_n - i - 1) and g = 5^{2^(i - 1)} // Different cases for wether a and/or b are zero. // // Implicite RSH without modulus switch, introduces extra I(X) * Q/2 on decryption. // Necessary so that the scaling of the plaintext remains constant. // It however is ok to do so here because coefficients are eventually // either mapped to garbage or twice their value which vanishes I(X) // since 2*(I(X) * Q/2) = I(X) * Q = 0 mod Q. if let Some(a) = a { let n: usize = a.n(); let log_n: usize = (u64::BITS - (n - 1).leading_zeros()) as _; let t: i64 = 1 << (log_n - i - 1); if let Some(b) = b { let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank); // a = a * X^-t a.rotate_inplace(module, -t); // tmp_b = a * X^-t - b tmp_b.sub(module, a, b); tmp_b.rsh(module, 1); // a = a * X^-t + b a.add_inplace(module, b); a.rsh(module, 1); tmp_b.normalize_inplace(module, scratch_1); // tmp_b = phi(a * X^-t - b) tmp_b.automorphism_inplace(module, auto_key, scratch_1); // a = a * X^-t + b - phi(a * X^-t - b) a.sub_inplace_ab(module, &tmp_b); a.normalize_inplace(module, scratch_1); // a = a + b * X^t - phi(a * X^-t - b) * X^t // = a + b * X^t - phi(a * X^-t - b) * - phi(X^t) // = a + b * X^t + phi(a - b * X^t) a.rotate_inplace(module, t); } else { a.rsh(module, 1); // a = a + phi(a) a.automorphism_add_inplace(module, auto_key, scratch); } } else { if let Some(b) = b { let n: usize = b.n(); let log_n: usize = (u64::BITS - (n - 1).leading_zeros()) as _; let t: i64 = 1 << (log_n - i - 1); let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank); tmp_b.rotate(module, t, b); tmp_b.rsh(module, 1); // a = (b* X^t - phi(b* X^t)) b.automorphism_sub_ba(module, &tmp_b, auto_key, scratch_1); } } }