From 504cb72f82bac2f381513f89f3174383eb99441f Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 13 Aug 2025 15:56:09 +0200 Subject: [PATCH] Add CBT prototype (#68) --- core/src/blind_rotation/cggi.rs | 28 +- core/src/blind_rotation/lut.rs | 67 ++- core/src/blind_rotation/test_fft64/cggi.rs | 130 +++++ core/src/blind_rotation/tests/generic_cggi.rs | 10 +- core/src/blind_rotation/tests/generic_lut.rs | 44 +- .../circuit_bootstrapping.rs | 508 ++++++++++++++++++ core/src/circuit_bootstrapping/mod.rs | 6 + .../test_fft64/circuit_bootstrapping.rs | 357 ++++++++++++ .../circuit_bootstrapping/test_fft64/mod.rs | 1 + core/src/gglwe/layouts_exec.rs | 21 + core/src/ggsw/automorphism.rs | 47 +- core/src/ggsw/keyswitch.rs | 304 ++++++----- core/src/ggsw/noise.rs | 50 ++ core/src/glwe/mod.rs | 1 + core/src/glwe/trace.rs | 6 +- core/src/lib.rs | 2 + 16 files changed, 1380 insertions(+), 202 deletions(-) create mode 100644 core/src/blind_rotation/test_fft64/cggi.rs create mode 100644 core/src/circuit_bootstrapping/circuit_bootstrapping.rs create mode 100644 core/src/circuit_bootstrapping/mod.rs create mode 100644 core/src/circuit_bootstrapping/test_fft64/circuit_bootstrapping.rs create mode 100644 core/src/circuit_bootstrapping/test_fft64/mod.rs diff --git a/core/src/blind_rotation/cggi.rs b/core/src/blind_rotation/cggi.rs index 644fe65..1eabdac 100644 --- a/core/src/blind_rotation/cggi.rs +++ b/core/src/blind_rotation/cggi.rs @@ -12,7 +12,7 @@ use itertools::izip; use crate::{ GLWECiphertext, GLWECiphertextToMut, GLWEExternalProductFamily, GLWEOps, Infos, LWECiphertext, LWECiphertextToRef, - TakeGLWECt, + LookUpTableRotationDirection, TakeGLWECt, blind_rotation::{key::BlindRotationKeyCGGIExec, lut::LookUpTable}, dist::Distribution, }; @@ -158,7 +158,7 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended = lwe.to_ref(); let basek: usize = brk.basek(); - negate_and_mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe_ref); + mod_switch_2n( + 2 * lut.domain_size(), + &mut lwe_2n, + &lwe_ref, + lut.rotation_direction(), + ); let a: &[i64] = &lwe_2n[1..]; let b: i64 = lwe_2n[0]; @@ -429,13 +439,19 @@ pub(crate) fn cggi_blind_rotate_binary_standard) { +pub(crate) fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>, rot_dir: LookUpTableRotationDirection) { let basek: usize = lwe.basek(); let log2n: usize = usize::BITS as usize - (n - 1).leading_zeros() as usize + 1; res.copy_from_slice(&lwe.data.at(0, 0)); - res.iter_mut().for_each(|x| *x = -*x); + + match rot_dir { + LookUpTableRotationDirection::Left => { + res.iter_mut().for_each(|x| *x = -*x); + } + LookUpTableRotationDirection::Right => {} + } if basek > log2n { let diff: usize = basek - log2n; diff --git a/core/src/blind_rotation/lut.rs b/core/src/blind_rotation/lut.rs index a06aa46..6ec093e 100644 --- a/core/src/blind_rotation/lut.rs +++ b/core/src/blind_rotation/lut.rs @@ -7,10 +7,18 @@ use backend::hal::{ oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, }; +#[derive(Debug, Clone, Copy)] +pub enum LookUpTableRotationDirection { + Left, + Right, +} + pub struct LookUpTable { pub(crate) data: Vec>>, + pub(crate) rot_dir: LookUpTableRotationDirection, pub(crate) basek: usize, pub(crate) k: usize, + pub(crate) drift: usize, } impl LookUpTable { @@ -28,7 +36,13 @@ impl LookUpTable { (0..extension_factor).for_each(|_| { data.push(VecZnx::alloc(n, 1, size)); }); - Self { data, basek, k } + Self { + data, + basek, + k, + drift: 0, + rot_dir: LookUpTableRotationDirection::Left, + } } pub fn log_extension_factor(&self) -> usize { @@ -43,6 +57,18 @@ impl LookUpTable { self.data.len() * self.data[0].n() } + pub fn rotation_direction(&self) -> LookUpTableRotationDirection { + self.rot_dir + } + + // By default X^{-dec(lwe)} is computed during the blind rotation. + // Setting [reverse_rotation] to true will reverse the sign of + // rotation of the LUT by instead evaluating X^{dec(lwe)} during + // the blind rotation. + pub fn set_rotation_direction(&mut self, rot_dir: LookUpTableRotationDirection) { + self.rot_dir = rot_dir + } + pub fn set(&mut self, module: &Module, f: &Vec, k: usize) where Module: VecZnxRotateInplace + VecZnxNormalizeInplace + VecZnxNormalizeTmpBytes + VecZnxSwithcDegree + VecZnxCopy, @@ -53,15 +79,23 @@ impl LookUpTable { let basek: usize = self.basek; // Get the number minimum limb to store the message modulus - let limbs: usize = k.div_ceil(1 << basek); + let limbs: usize = k.div_ceil(basek); #[cfg(debug_assertions)] { + assert!(f.len() <= module.n()); + assert!( + (max_bit_size(f) + (k % basek) as u32) < i64::BITS, + "overflow: max(|f|) << (k%basek) > i64::BITS" + ); assert!(limbs <= self.data[0].size()); } // Scaling factor - let scale: i64 = 1 << (k % basek) as i64; + let mut scale = 1; + if k % basek != 0 { + scale <<= basek - (k % basek); + } // #elements in lookup table let f_len: usize = f.len(); @@ -76,16 +110,18 @@ impl LookUpTable { let lut_at: &mut [i64] = lut_full.at_mut(0, limbs - 1); + let step: usize = domain_size.div_round(f_len); + f.iter().enumerate().for_each(|(i, fi)| { - let start: usize = (i * domain_size).div_round(f_len); - let end: usize = ((i + 1) * domain_size).div_round(f_len); + let start: usize = i * step; + let end: usize = start + step; lut_at[start..end].fill(fi * scale); }); - // Rotates half the step to the left - let half_step: usize = domain_size.div_round(f_len << 1); + let drift: usize = step >> 1; - module.vec_znx_rotate_inplace(-(half_step as i64), &mut lut_full, 0); + // Rotates half the step to the left + module.vec_znx_rotate_inplace(-(drift as i64), &mut lut_full, 0); let n_large: usize = lut_full.n(); @@ -106,6 +142,8 @@ impl LookUpTable { } else { module.vec_znx_copy(&mut self.data[0], 0, &lut_full, 0); } + + self.drift = drift } #[allow(dead_code)] @@ -144,3 +182,16 @@ impl DivRound for usize { (self + rhs / 2) / rhs } } + +fn max_bit_size(vec: &[i64]) -> u32 { + vec.iter() + .map(|&v| { + if v == 0 { + 0 + } else { + v.unsigned_abs().ilog2() + 1 + } + }) + .max() + .unwrap_or(0) +} diff --git a/core/src/blind_rotation/test_fft64/cggi.rs b/core/src/blind_rotation/test_fft64/cggi.rs new file mode 100644 index 0000000..446dcd2 --- /dev/null +++ b/core/src/blind_rotation/test_fft64/cggi.rs @@ -0,0 +1,130 @@ +use backend::{Encoding, FFT64, Module, ScratchOwned, ZnxView}; +use sampling::source::Source; + +use crate::{ + FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, LWECiphertext, LWESecret, + blind_rotation::{ + cggi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space, mod_switch_2n}, + key::BlindRotationKeyCGGI, + lut::LookUpTable, + }, + lwe::{LWEPlaintext, ciphertext::LWECiphertextToRef}, +}; + +#[test] +fn standard() { + blind_rotatio_test(224, 1, 1); +} + +#[test] +fn block_binary() { + blind_rotatio_test(224, 7, 1); +} + +#[test] +fn block_binary_extended() { + blind_rotatio_test(224, 7, 2); +} + +fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize) { + let module: Module = Module::::new(512); + let basek: usize = 19; + + let k_lwe: usize = 24; + let k_brk: usize = 3 * basek; + let rows_brk: usize = 2; // Ensures first limb is noise-free. + let k_lut: usize = 1 * basek; + let k_res: usize = 2 * basek; + let rank: usize = 1; + + let message_modulus: usize = 1 << 4; + + let mut source_xs: Source = Source::new([2u8; 32]); + let mut source_xe: Source = Source::new([2u8; 32]); + let mut source_xa: Source = Source::new([1u8; 32]); + + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_glwe.fill_ternary_prob(0.5, &mut source_xs); + let sk_glwe_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_glwe); + + let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe); + sk_lwe.fill_binary_block(block_size, &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::new(BlindRotationKeyCGGI::generate_from_sk_scratch_space( + &module, basek, k_brk, rank, + )); + + let mut scratch_br: ScratchOwned = ScratchOwned::new(cggi_blind_rotate_scratch_space( + &module, + block_size, + extension_factor, + basek, + k_res, + k_brk, + rows_brk, + rank, + )); + + let mut brk: BlindRotationKeyCGGI, FFT64> = + BlindRotationKeyCGGI::allocate(&module, n_lwe, basek, k_brk, rows_brk, rank); + + brk.generate_from_sk( + &module, + &sk_glwe_dft, + &sk_lwe, + &mut source_xa, + &mut source_xe, + 3.2, + scratch.borrow(), + ); + + let mut lwe: LWECiphertext> = LWECiphertext::alloc(n_lwe, basek, k_lwe); + + let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe); + + let x: i64 = 2; + let bits: usize = 8; + + pt_lwe.data.encode_coeff_i64(0, basek, bits, 0, x, bits); + + lwe.encrypt_sk(&pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe, 3.2); + + let mut f: Vec = vec![0i64; message_modulus]; + f.iter_mut() + .enumerate() + .for_each(|(i, x)| *x = 2 * (i as i64) + 1); + + let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor); + lut.set(&module, &f, message_modulus); + + let mut res: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_res, rank); + + cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow()); + + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_res); + + res.decrypt(&module, &mut pt_have, &sk_glwe_dft, scratch.borrow()); + + let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space + + mod_switch_2n( + 2 * lut.domain_size(), + &mut lwe_2n, + &lwe.to_ref(), + lut.rotation_direction(), + ); + + let pt_want: i64 = (lwe_2n[0] + + lwe_2n[1..] + .iter() + .zip(sk_lwe.data.at(0, 0)) + .map(|(x, y)| x * y) + .sum::()) + & (2 * lut.domain_size() - 1) as i64; + + lut.rotate(pt_want); + + // First limb should be exactly equal (test are parameterized such that the noise does not reach + // the first limb) + assert_eq!(pt_have.data.at(0, 0), lut.data[0].at(0, 0)); +} diff --git a/core/src/blind_rotation/tests/generic_cggi.rs b/core/src/blind_rotation/tests/generic_cggi.rs index ffec8af..01c1a8b 100644 --- a/core/src/blind_rotation/tests/generic_cggi.rs +++ b/core/src/blind_rotation/tests/generic_cggi.rs @@ -14,8 +14,7 @@ use sampling::source::Source; use crate::{ BlindRotationKeyCGGI, BlindRotationKeyCGGIExec, BlindRotationKeyCGGIExecLayoutFamily, CCGIBlindRotationFamily, GLWECiphertext, GLWEDecryptFamily, GLWEPlaintext, GLWESecret, GLWESecretExec, GLWESecretFamily, Infos, LWECiphertext, - LWECiphertextToRef, LWEPlaintext, LWESecret, LookUpTable, cggi_blind_rotate, cggi_blind_rotate_scratch_space, - negate_and_mod_switch_2n, + LWECiphertextToRef, LWEPlaintext, LWESecret, LookUpTable, cggi_blind_rotate, cggi_blind_rotate_scratch_space, mod_switch_2n, }; pub(crate) trait CGGITestModuleFamily = CCGIBlindRotationFamily @@ -133,7 +132,12 @@ where let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space - negate_and_mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe.to_ref()); + mod_switch_2n( + 2 * lut.domain_size(), + &mut lwe_2n, + &lwe.to_ref(), + lut.rotation_direction(), + ); let pt_want: i64 = (lwe_2n[0] + lwe_2n[1..] diff --git a/core/src/blind_rotation/tests/generic_lut.rs b/core/src/blind_rotation/tests/generic_lut.rs index 86263b0..0486ed0 100644 --- a/core/src/blind_rotation/tests/generic_lut.rs +++ b/core/src/blind_rotation/tests/generic_lut.rs @@ -1,7 +1,9 @@ use std::vec; use backend::hal::{ - api::{VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSwithcDegree, ZnxView}, + api::{ + VecZnxCopy, VecZnxDecodeVeci64, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSwithcDegree, + }, layouts::{Backend, Module}, oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, }; @@ -10,7 +12,12 @@ use crate::{DivRound, LookUpTable}; pub(crate) fn test_lut_standard(module: &Module) where - Module: VecZnxRotateInplace + VecZnxNormalizeInplace + VecZnxNormalizeTmpBytes + VecZnxSwithcDegree + VecZnxCopy, + Module: VecZnxRotateInplace + + VecZnxNormalizeInplace + + VecZnxNormalizeTmpBytes + + VecZnxSwithcDegree + + VecZnxCopy + + VecZnxDecodeVeci64, B: ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { let n: usize = module.n(); @@ -34,20 +41,24 @@ where let step: usize = lut.domain_size().div_round(message_modulus); + let mut lut_dec: Vec = vec![0i64; module.n()]; + module.decode_vec_i64(basek, &lut.data[0], 0, log_scale, &mut lut_dec); + (0..lut.domain_size()).step_by(step).for_each(|i| { (0..step).for_each(|_| { - assert_eq!( - f[i / step] % message_modulus as i64, - lut.data[0].raw()[0] / (1 << (log_scale % basek)) as i64 - ); - lut.rotate(module, -1); + assert_eq!(f[i / step] % message_modulus as i64, lut_dec[i]); }); }); } pub(crate) fn test_lut_extended(module: &Module) where - Module: VecZnxRotateInplace + VecZnxNormalizeInplace + VecZnxNormalizeTmpBytes + VecZnxSwithcDegree + VecZnxCopy, + Module: VecZnxRotateInplace + + VecZnxNormalizeInplace + + VecZnxNormalizeTmpBytes + + VecZnxSwithcDegree + + VecZnxCopy + + VecZnxDecodeVeci64, B: ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { let n: usize = module.n(); @@ -69,15 +80,16 @@ where let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; lut.rotate(&module, half_step); - let step: usize = lut.domain_size().div_round(message_modulus); + let step: usize = module.n().div_round(message_modulus); - (0..lut.domain_size()).step_by(step).for_each(|i| { - (0..step).for_each(|_| { - assert_eq!( - f[i / step] % message_modulus as i64, - lut.data[0].raw()[0] / (1 << (log_scale % basek)) as i64 - ); - lut.rotate(&module, -1); + let mut lut_dec: Vec = vec![0i64; module.n()]; + + (0..extension_factor).for_each(|ext| { + module.decode_vec_i64(basek, &lut.data[ext], 0, log_scale, &mut lut_dec); + (0..module.n()).step_by(step).for_each(|i| { + (0..step).for_each(|_| { + assert_eq!(f[i / step] % message_modulus as i64, lut_dec[i]); + }); }); }); } diff --git a/core/src/circuit_bootstrapping/circuit_bootstrapping.rs b/core/src/circuit_bootstrapping/circuit_bootstrapping.rs new file mode 100644 index 0000000..44124b9 --- /dev/null +++ b/core/src/circuit_bootstrapping/circuit_bootstrapping.rs @@ -0,0 +1,508 @@ +use std::{collections::HashMap, time::Instant, usize}; + +use backend::hal::{ + api::{ + ScratchAvailable, TakeMatZnx, TakeScalarZnx, TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, + TakeVecZnxSlice, VecZnxAddInplace, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, + VecZnxBigAutomorphismInplace, VecZnxBigSubSmallBInplace, VecZnxCopy, VecZnxDftCopy, VecZnxDftToVecZnxBigTmpA, + VecZnxNegateInplace, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, + VecZnxSubABInplace, VecZnxSwithcDegree, + }, + layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, + oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, +}; +use sampling::source::Source; + +use crate::{ + AutomorphismKey, AutomorphismKeyEncryptSkFamily, AutomorphismKeyExec, BlindRotationKeyCGGI, BlindRotationKeyCGGIExec, + BlindRotationKeyCGGIExecLayoutFamily, CCGIBlindRotationFamily, GGSWCiphertext, GGSWEncryptSkFamily, GLWECiphertext, GLWEOps, + GLWESecret, GLWESecretExec, GLWESecretFamily, GLWETensorKey, GLWETensorKeyEncryptSkFamily, GLWETensorKeyExec, + GLWETraceFamily, Infos, LWECiphertext, LWESecret, LookUpTable, LookUpTableRotationDirection, TakeGGLWE, TakeGLWECt, + cggi_blind_rotate, +}; + +pub struct CircuitBootstrappingKeyCGGI { + pub(crate) brk: BlindRotationKeyCGGI, + pub(crate) tsk: GLWETensorKey>, + pub(crate) atk: HashMap>>, +} + +impl CircuitBootstrappingKeyCGGI> { + pub fn generate( + module: &Module, + basek: usize, + sk_lwe: &LWESecret, + sk_glwe: &GLWESecret, + k_brk: usize, + rows_brk: usize, + k_trace: usize, + rows_trace: usize, + k_tsk: usize, + rows_tsk: usize, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) -> Self + where + Module: GLWESecretFamily + + GGSWEncryptSkFamily + + VecZnxAddScalarInplace + + AutomorphismKeyEncryptSkFamily + + VecZnxAutomorphism + + VecZnxSwithcDegree + + GLWETensorKeyEncryptSkFamily, + DLwe: DataRef, + DGlwe: DataRef, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeSvpPPol + TakeVecZnxBig, + { + let mut auto_keys: HashMap>> = HashMap::new(); + let gal_els: Vec = GLWECiphertext::trace_galois_elements(&module); + gal_els.iter().for_each(|gal_el| { + let mut key: AutomorphismKey> = + AutomorphismKey::alloc(sk_glwe.n(), basek, k_trace, rows_trace, 1, sk_glwe.rank()); + key.encrypt_sk( + &module, *gal_el, &sk_glwe, source_xa, source_xe, sigma, scratch, + ); + auto_keys.insert(*gal_el, key); + }); + + let sk_glwe_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_glwe); + + let mut brk: BlindRotationKeyCGGI> = BlindRotationKeyCGGI::alloc( + sk_glwe.n(), + sk_lwe.n(), + basek, + k_brk, + rows_brk, + sk_glwe.rank(), + ); + brk.generate_from_sk( + module, + &sk_glwe_exec, + sk_lwe, + source_xa, + source_xe, + sigma, + scratch, + ); + + let mut tsk: GLWETensorKey> = GLWETensorKey::alloc(sk_glwe.n(), basek, k_tsk, rows_tsk, 1, sk_glwe.rank()); + tsk.encrypt_sk(module, &sk_glwe, source_xa, source_xe, sigma, scratch); + + Self { + brk, + atk: auto_keys, + tsk, + } + } +} + +pub struct CircuitBootstrappingKeyCGGIExec { + pub(crate) brk: BlindRotationKeyCGGIExec, + pub(crate) tsk: GLWETensorKeyExec, B>, + pub(crate) atk: HashMap, B>>, +} + +impl CircuitBootstrappingKeyCGGIExec, B> { + pub fn from( + module: &Module, + other: &CircuitBootstrappingKeyCGGI, + scratch: &mut Scratch, + ) -> CircuitBootstrappingKeyCGGIExec, B> + where + Module: BlindRotationKeyCGGIExecLayoutFamily, + { + let brk: BlindRotationKeyCGGIExec, B> = BlindRotationKeyCGGIExec::from(module, &other.brk, scratch); + let tsk: GLWETensorKeyExec, B> = GLWETensorKeyExec::from(module, &other.tsk, scratch); + let mut atk: HashMap, B>> = HashMap::new(); + for (key, value) in &other.atk { + atk.insert(*key, AutomorphismKeyExec::from(module, value, scratch)); + } + CircuitBootstrappingKeyCGGIExec { brk, tsk, atk } + } +} + +pub trait CGGICircuitBootstrapFamily = VecZnxRotateInplace + + VecZnxNormalizeInplace + + VecZnxNormalizeTmpBytes + + CCGIBlindRotationFamily + + VecZnxSwithcDegree + + VecZnxBigAutomorphismInplace + + VecZnxRshInplace + + VecZnxDftCopy + + VecZnxDftToVecZnxBigTmpA + + VecZnxSub + + VecZnxAddInplace + + VecZnxNegateInplace + + VecZnxCopy + + VecZnxSubABInplace + + GLWETraceFamily + + VecZnxRotateInplace + + VecZnxAutomorphismInplace + + VecZnxBigSubSmallBInplace; + +pub fn circuit_bootstrap_to_constant_cggi( + module: &Module, + res: &mut GGSWCiphertext, + lwe: &LWECiphertext, + log_domain: usize, + extension_factor: usize, + key: &CircuitBootstrappingKeyCGGIExec, + scratch: &mut Scratch, +) where + DRes: DataMut, + DLwe: DataRef, + DBrk: DataRef, + Module: CGGICircuitBootstrapFamily, + B: ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, + Scratch: TakeVecZnx + + TakeVecZnxDftSlice + + TakeVecZnxBig + + TakeVecZnxDft + + TakeMatZnx + + ScratchAvailable + + TakeVecZnxSlice, +{ + circuit_bootstrap_core_cggi( + false, + module, + 0, + res, + lwe, + log_domain, + extension_factor, + key, + scratch, + ); +} + +pub fn circuit_bootstrap_to_exponent_cggi( + module: &Module, + log_gap_out: usize, + res: &mut GGSWCiphertext, + lwe: &LWECiphertext, + log_domain: usize, + extension_factor: usize, + key: &CircuitBootstrappingKeyCGGIExec, + scratch: &mut Scratch, +) where + DRes: DataMut, + DLwe: DataRef, + DBrk: DataRef, + Module: CGGICircuitBootstrapFamily, + B: ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, + Scratch: TakeVecZnx + + TakeVecZnxDftSlice + + TakeVecZnxBig + + TakeVecZnxDft + + TakeMatZnx + + ScratchAvailable + + TakeVecZnxSlice, +{ + circuit_bootstrap_core_cggi( + true, + module, + log_gap_out, + res, + lwe, + log_domain, + extension_factor, + key, + scratch, + ); +} + +pub fn circuit_bootstrap_core_cggi( + to_exponent: bool, + module: &Module, + log_gap_out: usize, + res: &mut GGSWCiphertext, + lwe: &LWECiphertext, + log_domain: usize, + extension_factor: usize, + key: &CircuitBootstrappingKeyCGGIExec, + scratch: &mut Scratch, +) where + DRes: DataMut, + DLwe: DataRef, + DBrk: DataRef, + Module: CGGICircuitBootstrapFamily, + B: ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, + Scratch: TakeGGLWE + + TakeVecZnxDftSlice + + TakeVecZnxBig + + TakeVecZnxDft + + TakeVecZnx + + ScratchAvailable + + TakeVecZnxSlice, +{ + #[cfg(debug_assertions)] + { + use crate::Infos; + + assert_eq!(res.n(), key.brk.n()); + assert_eq!(lwe.basek(), key.brk.basek()); + assert_eq!(res.basek(), key.brk.basek()); + } + + let n: usize = res.n(); + let basek: usize = res.basek(); + let rows: usize = res.rows(); + let rank: usize = res.rank(); + let k: usize = res.k(); + + let alpha: usize = rows.next_power_of_two(); + + let mut f: Vec = vec![0i64; (1 << log_domain) * alpha]; + + if to_exponent { + (0..rows).for_each(|i| { + f[i] = 1 << (basek * (rows - 1 - i)); + }); + } else { + (0..1 << log_domain).for_each(|j| { + (0..rows).for_each(|i| { + f[j * alpha + i] = j as i64 * (1 << (basek * (rows - 1 - i))); + }); + }); + } + + // Lut precision, basically must be able to hold the decomposition power basis of the GGSW + let mut lut: LookUpTable = LookUpTable::alloc(n, basek, basek * rows, extension_factor); + lut.set(module, &f, basek * rows); + + if to_exponent { + lut.set_rotation_direction(LookUpTableRotationDirection::Right); + } + + // TODO: separate GGSW k from output of blind rotation k + let (mut res_glwe, scratch1) = scratch.take_glwe_ct(n, basek, k, rank); + let (mut tmp_gglwe, scratch2) = scratch1.take_gglwe(n, basek, k, rows, 1, rank, rank); + + let now: Instant = Instant::now(); + cggi_blind_rotate(module, &mut res_glwe, &lwe, &lut, &key.brk, scratch2); + println!("cggi_blind_rotate: {} ms", now.elapsed().as_millis()); + + let gap: usize = 2 * lut.drift / lut.extension_factor(); + + let log_gap_in: usize = (usize::BITS - (gap * alpha - 1).leading_zeros()) as _; + + (0..rows).for_each(|i| { + let mut tmp_glwe: GLWECiphertext<&mut [u8]> = tmp_gglwe.at_mut(i, 0); + + if to_exponent { + let now: Instant = Instant::now(); + + // Isolates i-th LUT and moves coefficients according to requested gap. + post_process( + module, + &mut tmp_glwe, + &res_glwe, + log_gap_in, + log_gap_out, + log_domain, + &key.atk, + scratch2, + ); + println!("post_process: {} ms", now.elapsed().as_millis()); + } else { + tmp_glwe.trace(module, 0, module.log_n(), &res_glwe, &key.atk, scratch2); + } + + if i < rows { + res_glwe.rotate_inplace(module, -(gap as i64)); + } + }); + + // Expands GGLWE to GGSW using GGLWE(s^2) + res.from_gglwe(module, &tmp_gglwe, &key.tsk, scratch2); +} + +fn post_process( + module: &Module, + res: &mut GLWECiphertext, + a: &GLWECiphertext, + log_gap_in: usize, + log_gap_out: usize, + log_domain: usize, + auto_keys: &HashMap, B>>, + scratch: &mut Scratch, +) where + DataRes: DataMut, + DataA: DataRef, + Module: CGGICircuitBootstrapFamily, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, +{ + let log_n: usize = module.log_n(); + + let mut cts: HashMap>> = HashMap::new(); + + // First partial trace, vanishes all coefficients which are not multiples of gap_in + // [1, 1, 1, 1, 0, 0, 0, ..., 0, 0, -1, -1, -1, -1] -> [1, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0] + res.trace( + module, + module.log_n() - log_gap_in as usize + 1, + log_n, + &a, + auto_keys, + scratch, + ); + + // TODO: optimize with packing and final partial trace + // If gap_out < gap_in, then we need to repack, i.e. reduce the cap between coefficients. + if log_gap_in != log_gap_out { + let steps: i32 = 1 << log_domain; + (0..steps).for_each(|i| { + if i != 0 { + res.rotate_inplace(module, -(1 << log_gap_in)); + } + cts.insert(i as usize * (1 << log_gap_out), res.clone()); + }); + + let now: Instant = Instant::now(); + pack(module, &mut cts, log_gap_out, auto_keys, scratch); + println!("pack: {} ms", now.elapsed().as_millis()); + let packed: GLWECiphertext> = cts.remove(&0).unwrap(); + res.trace( + module, + log_n - log_gap_out, + log_n, + &packed, + auto_keys, + scratch, + ); + } +} + +pub fn pack( + module: &Module, + cts: &mut HashMap>, + log_gap_out: usize, + auto_keys: &HashMap, B>>, + scratch: &mut Scratch, +) where + Module: CGGICircuitBootstrapFamily, + Scratch: TakeVecZnx + TakeVecZnxDft + ScratchAvailable, +{ + let log_n: usize = module.log_n(); + + let basek: usize = cts.get(&0).unwrap().basek(); + let k: usize = cts.get(&0).unwrap().k(); + let rank: usize = cts.get(&0).unwrap().rank(); + + (0..log_n - log_gap_out).for_each(|i| { + let now: Instant = Instant::now(); + + let t = 16.min(1 << (log_n - 1 - i)); + + let auto_key: &AutomorphismKeyExec, B>; + if i == 0 { + auto_key = auto_keys.get(&-1).unwrap() + } else { + auto_key = auto_keys.get(&module.galois_element(1 << (i - 1))).unwrap(); + } + + (0..t).for_each(|j| { + let mut a: Option> = cts.remove(&j); + let mut b: Option> = cts.remove(&(j + t)); + + combine( + module, + basek, + k, + rank, + a.as_mut(), + b.as_mut(), + i, + auto_key, + scratch, + ); + + if let Some(a) = a { + cts.insert(j, a); + } else if let Some(b) = b { + cts.insert(j, b); + } + }); + + println!("combine: {} us", now.elapsed().as_micros()); + }); +} + +fn combine( + module: &Module, + basek: usize, + k: usize, + rank: usize, + a: Option<&mut GLWECiphertext>, + b: Option<&mut GLWECiphertext>, + i: usize, + auto_key: &AutomorphismKeyExec, + scratch: &mut Scratch, +) where + Module: CGGICircuitBootstrapFamily, + Scratch: TakeVecZnx + TakeVecZnxDft + ScratchAvailable, +{ + // Goal is to evaluate: a = a + b*X^t + phi(a - b*X^t)) + // We also use the identity: AUTO(a * X^t, g) = -X^t * AUTO(a, g) + // where t = 2^(log_n - i - 1) and g = 5^{2^(i - 1)} + // Different cases for wether a and/or b are zero. + // + // Implicite RSH without modulus switch, introduces extra I(X) * Q/2 on decryption. + // Necessary so that the scaling of the plaintext remains constant. + // It however is ok to do so here because coefficients are eventually + // either mapped to garbage or twice their value which vanishes I(X) + // since 2*(I(X) * Q/2) = I(X) * Q = 0 mod Q. + if let Some(a) = a { + let n: usize = a.n(); + let log_n: usize = (u64::BITS - (n - 1).leading_zeros()) as _; + let t: i64 = 1 << (log_n - i - 1); + + if let Some(b) = b { + let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank); + + // a = a * X^-t + a.rotate_inplace(module, -t); + + // tmp_b = a * X^-t - b + tmp_b.sub(module, a, b); + tmp_b.rsh(module, 1); + + // a = a * X^-t + b + a.add_inplace(module, b); + a.rsh(module, 1); + + tmp_b.normalize_inplace(module, scratch_1); + + // tmp_b = phi(a * X^-t - b) + tmp_b.automorphism_inplace(module, auto_key, scratch_1); + + // a = a * X^-t + b - phi(a * X^-t - b) + a.sub_inplace_ab(module, &tmp_b); + a.normalize_inplace(module, scratch_1); + + // a = a + b * X^t - phi(a * X^-t - b) * X^t + // = a + b * X^t - phi(a * X^-t - b) * - phi(X^t) + // = a + b * X^t + phi(a - b * X^t) + a.rotate_inplace(module, t); + } else { + a.rsh(module, 1); + // a = a + phi(a) + a.automorphism_add_inplace(module, auto_key, scratch); + } + } else { + if let Some(b) = b { + let n: usize = b.n(); + let log_n: usize = (u64::BITS - (n - 1).leading_zeros()) as _; + let t: i64 = 1 << (log_n - i - 1); + + let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank); + tmp_b.rotate(module, t, b); + tmp_b.rsh(module, 1); + + // a = (b* X^t - phi(b* X^t)) + b.automorphism_sub_ba(module, &tmp_b, auto_key, scratch_1); + } + } +} diff --git a/core/src/circuit_bootstrapping/mod.rs b/core/src/circuit_bootstrapping/mod.rs new file mode 100644 index 0000000..36ba50e --- /dev/null +++ b/core/src/circuit_bootstrapping/mod.rs @@ -0,0 +1,6 @@ +mod circuit_bootstrapping; + +pub use circuit_bootstrapping::*; + +#[cfg(test)] +mod test_fft64; diff --git a/core/src/circuit_bootstrapping/test_fft64/circuit_bootstrapping.rs b/core/src/circuit_bootstrapping/test_fft64/circuit_bootstrapping.rs new file mode 100644 index 0000000..59cd010 --- /dev/null +++ b/core/src/circuit_bootstrapping/test_fft64/circuit_bootstrapping.rs @@ -0,0 +1,357 @@ +use std::time::Instant; + +use backend::{ + hal::{ + api::{ + ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, + VecZnxEncodeCoeffsi64, VecZnxFillUniform, VecZnxNormalizeInplace, VecZnxRotateInplace, VecZnxStd, VecZnxSwithcDegree, + ZnxView, ZnxViewMut, + }, + layouts::{Backend, Module, ScalarZnx, ScratchOwned}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeMatZnxImpl, TakeScalarZnxImpl, + TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, + }, + }, + implementation::cpu_spqlios::FFT64, +}; +use sampling::source::Source; + +use crate::{ + AutomorphismKeyEncryptSkFamily, BlindRotationKeyCGGIExecLayoutFamily, GGSWAssertNoiseFamily, GGSWCiphertext, + GGSWCiphertextExec, GGSWEncryptSkFamily, GLWECiphertext, GLWEDecryptFamily, GLWEPlaintext, GLWESecret, GLWESecretExec, + GLWESecretFamily, GLWETensorKeyEncryptSkFamily, LWECiphertext, LWESecret, + circuit_bootstrapping::circuit_bootstrapping::{ + CGGICircuitBootstrapFamily, CircuitBootstrappingKeyCGGI, CircuitBootstrappingKeyCGGIExec, + circuit_bootstrap_to_constant_cggi, circuit_bootstrap_to_exponent_cggi, + }, + lwe::LWEPlaintext, +}; + +#[test] +fn test_to_exponent() { + let module: Module = Module::::new(256); + to_exponent(&module); +} + +fn to_exponent(module: &Module) +where + Module: GLWESecretFamily + + VecZnxEncodeCoeffsi64 + + VecZnxFillUniform + + VecZnxAddNormal + + VecZnxNormalizeInplace + + GLWESecretFamily + + GGSWEncryptSkFamily + + VecZnxAddScalarInplace + + AutomorphismKeyEncryptSkFamily + + VecZnxAutomorphism + + VecZnxSwithcDegree + + GLWETensorKeyEncryptSkFamily + + BlindRotationKeyCGGIExecLayoutFamily + + CGGICircuitBootstrapFamily + + GLWEDecryptFamily + + GGSWAssertNoiseFamily + + VecZnxStd, + B: ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + TakeVecZnxDftImpl + + ScratchAvailableImpl + + TakeVecZnxImpl + + TakeScalarZnxImpl + + TakeSvpPPolImpl + + TakeVecZnxBigImpl + + TakeVecZnxDftSliceImpl + + TakeMatZnxImpl + + TakeVecZnxSliceImpl, +{ + let n: usize = module.n(); + let basek: usize = 17; + let extension_factor: usize = 1; + let rank: usize = 1; + let sigma: f64 = 3.2; + + let n_lwe: usize = 77; + let k_lwe_pt: usize = 4; + let k_lwe_ct: usize = 22; + let block_size: usize = 7; + + let k_brk: usize = 5 * basek; + let rows_brk: usize = 4; + + let k_trace: usize = 5 * basek; + let rows_trace: usize = 4; + + let k_tsk: usize = 5 * basek; + let rows_tsk: usize = 4; + + let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 23); + + let mut source_xs: Source = Source::new([1u8; 32]); + let mut source_xa: Source = Source::new([1u8; 32]); + let mut source_xe: Source = Source::new([1u8; 32]); + + let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe); + sk_lwe.fill_binary_block(block_size, &mut source_xs); + + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n, rank); + sk_glwe.fill_ternary_prob(0.5, &mut source_xs); + + let sk_glwe_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_glwe); + + let data: i64 = 1; + + let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe_pt); + module.encode_coeff_i64(basek, &mut pt_lwe.data, 0, k_lwe_pt + 2, 0, data, k_lwe_pt); + + println!("pt_lwe: {}", pt_lwe.data); + + let mut ct_lwe: LWECiphertext> = LWECiphertext::alloc(n_lwe, basek, k_lwe_ct); + ct_lwe.encrypt_sk( + module, + &pt_lwe, + &sk_lwe, + &mut source_xa, + &mut source_xe, + sigma, + ); + + let now: Instant = Instant::now(); + let cbt_key: CircuitBootstrappingKeyCGGI> = CircuitBootstrappingKeyCGGI::generate( + module, + basek, + &sk_lwe, + &sk_glwe, + k_brk, + rows_brk, + k_trace, + rows_trace, + k_tsk, + rows_tsk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + println!("CBT-KGEN: {} ms", now.elapsed().as_millis()); + + let k_ggsw_res: usize = 4 * basek; + let rows_ggsw_res: usize = 2; + + let mut res: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw_res, rows_ggsw_res, 1, rank); + + let log_gap_out = 1; + + let cbt_exec: CircuitBootstrappingKeyCGGIExec, B> = + CircuitBootstrappingKeyCGGIExec::from(module, &cbt_key, scratch.borrow()); + + let now: Instant = Instant::now(); + circuit_bootstrap_to_exponent_cggi( + module, + log_gap_out, + &mut res, + &ct_lwe, + k_lwe_pt, + extension_factor, + &cbt_exec, + scratch.borrow(), + ); + println!("CBT: {} ms", now.elapsed().as_millis()); + + // X^{data * 2^log_gap_out} + let mut pt_ggsw: ScalarZnx> = ScalarZnx::alloc(n, 1); + pt_ggsw.at_mut(0, 0)[0] = 1; + module.vec_znx_rotate_inplace(data * (1 << log_gap_out), &mut pt_ggsw.as_vec_znx_mut(), 0); + + res.print_noise(module, &sk_glwe_exec, &pt_ggsw); + + let k_glwe: usize = k_ggsw_res; + + let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_glwe, rank); + let mut pt_glwe: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, basek); + pt_glwe.data.at_mut(0, 0)[0] = 1 << (basek - 2); + + ct_glwe.encrypt_sk( + module, + &pt_glwe, + &sk_glwe_exec, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let res_exec: GGSWCiphertextExec, B> = GGSWCiphertextExec::from(module, &res, scratch.borrow()); + + ct_glwe.external_product_inplace(module, &res_exec, scratch.borrow()); + + let mut pt_res: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_glwe); + ct_glwe.decrypt(module, &mut pt_res, &sk_glwe_exec, scratch.borrow()); + + // Parameters are set such that the first limb should be noiseless. + let mut pt_want: Vec = vec![0i64; module.n()]; + pt_want[data as usize * (1 << log_gap_out)] = pt_glwe.data.at(0, 0)[0]; + assert_eq!(pt_res.data.at(0, 0), pt_want); +} + +#[test] +fn test_to_constant() { + let module: Module = Module::::new(256); + to_constant(&module); +} + +fn to_constant(module: &Module) +where + Module: GLWESecretFamily + + VecZnxEncodeCoeffsi64 + + VecZnxFillUniform + + VecZnxAddNormal + + VecZnxNormalizeInplace + + GLWESecretFamily + + GGSWEncryptSkFamily + + VecZnxAddScalarInplace + + AutomorphismKeyEncryptSkFamily + + VecZnxAutomorphism + + VecZnxSwithcDegree + + GLWETensorKeyEncryptSkFamily + + BlindRotationKeyCGGIExecLayoutFamily + + CGGICircuitBootstrapFamily + + GLWEDecryptFamily + + GGSWAssertNoiseFamily + + VecZnxStd, + B: ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + TakeVecZnxDftImpl + + ScratchAvailableImpl + + TakeVecZnxImpl + + TakeScalarZnxImpl + + TakeSvpPPolImpl + + TakeVecZnxBigImpl + + TakeVecZnxDftSliceImpl + + TakeMatZnxImpl + + TakeVecZnxSliceImpl, +{ + let n = module.n(); + let basek: usize = 14; + let extension_factor: usize = 1; + let rank: usize = 2; + let sigma: f64 = 3.2; + + let n_lwe: usize = 77; + let k_lwe_pt: usize = 1; + let k_lwe_ct: usize = 13; + let block_size: usize = 7; + + let k_brk: usize = 5 * basek; + let rows_brk: usize = 3; + + let k_trace: usize = 5 * basek; + let rows_trace: usize = 4; + + let k_tsk: usize = 5 * basek; + let rows_tsk: usize = 4; + + let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 23); + + let mut source_xs: Source = Source::new([1u8; 32]); + let mut source_xa: Source = Source::new([1u8; 32]); + let mut source_xe: Source = Source::new([1u8; 32]); + + let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe); + sk_lwe.fill_binary_block(block_size, &mut source_xs); + + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n, rank); + sk_glwe.fill_ternary_prob(0.5, &mut source_xs); + + let sk_glwe_exec: GLWESecretExec, B> = GLWESecretExec::from(module, &sk_glwe); + + let data: i64 = 1; + + let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe_pt); + module.encode_coeff_i64(basek, &mut pt_lwe.data, 0, k_lwe_pt + 2, 0, data, k_lwe_pt); + + println!("pt_lwe: {}", pt_lwe.data); + + let mut ct_lwe: LWECiphertext> = LWECiphertext::alloc(n_lwe, basek, k_lwe_ct); + ct_lwe.encrypt_sk( + module, + &pt_lwe, + &sk_lwe, + &mut source_xa, + &mut source_xe, + sigma, + ); + + let now: Instant = Instant::now(); + let cbt_key: CircuitBootstrappingKeyCGGI> = CircuitBootstrappingKeyCGGI::generate( + module, + basek, + &sk_lwe, + &sk_glwe, + k_brk, + rows_brk, + k_trace, + rows_trace, + k_tsk, + rows_tsk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + println!("CBT-KGEN: {} ms", now.elapsed().as_millis()); + + let k_ggsw_res: usize = 4 * basek; + let rows_ggsw_res: usize = 3; + + let mut res: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw_res, rows_ggsw_res, 1, rank); + + let cbt_exec: CircuitBootstrappingKeyCGGIExec, B> = + CircuitBootstrappingKeyCGGIExec::from(module, &cbt_key, scratch.borrow()); + + let now: Instant = Instant::now(); + circuit_bootstrap_to_constant_cggi( + module, + &mut res, + &ct_lwe, + k_lwe_pt, + extension_factor, + &cbt_exec, + scratch.borrow(), + ); + println!("CBT: {} ms", now.elapsed().as_millis()); + + // X^{data * 2^log_gap_out} + let mut pt_ggsw: ScalarZnx> = ScalarZnx::alloc(n, 1); + pt_ggsw.at_mut(0, 0)[0] = data; + + res.print_noise(module, &sk_glwe_exec, &pt_ggsw); + + let k_glwe: usize = k_ggsw_res; + + let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_glwe, rank); + let mut pt_glwe: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, basek); + pt_glwe.data.at_mut(0, 0)[0] = 1 << (basek - k_lwe_pt - 1); + + ct_glwe.encrypt_sk( + module, + &pt_glwe, + &sk_glwe_exec, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let res_exec: GGSWCiphertextExec, B> = GGSWCiphertextExec::from(module, &res, scratch.borrow()); + + ct_glwe.external_product_inplace(module, &res_exec, scratch.borrow()); + + let mut pt_res: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_glwe); + ct_glwe.decrypt(module, &mut pt_res, &sk_glwe_exec, scratch.borrow()); + + // Parameters are set such that the first limb should be noiseless. + let mut pt_want: Vec = vec![0i64; module.n()]; + pt_want[0] = pt_glwe.data.at(0, 0)[0] * data; + assert_eq!(pt_res.data.at(0, 0), pt_want); +} diff --git a/core/src/circuit_bootstrapping/test_fft64/mod.rs b/core/src/circuit_bootstrapping/test_fft64/mod.rs new file mode 100644 index 0000000..cc73725 --- /dev/null +++ b/core/src/circuit_bootstrapping/test_fft64/mod.rs @@ -0,0 +1 @@ +mod circuit_bootstrapping; diff --git a/core/src/gglwe/layouts_exec.rs b/core/src/gglwe/layouts_exec.rs index 4ba0d35..52bbbee 100644 --- a/core/src/gglwe/layouts_exec.rs +++ b/core/src/gglwe/layouts_exec.rs @@ -366,6 +366,27 @@ impl GLWETensorKeyExec, B> { let pairs: usize = (((rank + 1) * rank) >> 1).max(1); pairs * GLWESwitchingKeyExec::bytes_of(module, n, basek, k, rows, digits, 1, rank) } + + pub fn from( + module: &Module, + other: &GLWETensorKey, + scratch: &mut Scratch, + ) -> GLWETensorKeyExec, B> + where + Module: GGLWEExecLayoutFamily, + { + let mut tsk_exec: GLWETensorKeyExec, B> = Self::alloc( + module, + other.n(), + other.basek(), + other.k(), + other.rows(), + other.digits(), + other.rank(), + ); + tsk_exec.prepare(module, other, scratch); + tsk_exec + } } impl Infos for GLWETensorKeyExec { diff --git a/core/src/ggsw/automorphism.rs b/core/src/ggsw/automorphism.rs index 9e25c37..1081828 100644 --- a/core/src/ggsw/automorphism.rs +++ b/core/src/ggsw/automorphism.rs @@ -108,32 +108,8 @@ impl GGSWCiphertext { ) }; - let n: usize = auto_key.n(); - let rank: usize = self.rank(); - let cols: usize = rank + 1; - - // Keyswitch the j-th row of the col 0 - (0..lhs.rows()).for_each(|row_i| { - // Key-switch column 0, i.e. - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) - self.at_mut(row_i, 0) - .automorphism(module, &lhs.at(row_i, 0), auto_key, scratch); - - // Isolates DFT(AUTO(a[i])) - let (mut ci_dft, scratch1) = scratch.take_vec_znx_dft(n, cols, self.size()); - (0..cols).for_each(|i| { - module.vec_znx_dft_from_vec_znx(1, 0, &mut ci_dft, i, &self.at(row_i, 0).data, i); - }); - - // Generates - // - // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + pi(M[i]), b1 , b2 ) - // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + pi(M[i]), c2 ) - // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + pi(M[i])) - (1..cols).for_each(|col_j| { - self.expand_row(module, row_i, col_j, &ci_dft, tensor_key, scratch1); - }); - }) + self.automorphism_internal(module, lhs, auto_key, scratch); + self.expand_row(module, tensor_key, scratch); } pub fn automorphism_inplace( @@ -151,4 +127,23 @@ impl GGSWCiphertext { self.automorphism(module, &*self_ptr, auto_key, tensor_key, scratch); } } + + fn automorphism_internal( + &mut self, + module: &Module, + lhs: &GGSWCiphertext, + auto_key: &AutomorphismKeyExec, + scratch: &mut Scratch, + ) where + Module: GLWEKeyswitchFamily + GGSWKeySwitchFamily + VecZnxAutomorphismInplace + VecZnxNormalizeTmpBytes, + Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig, + { + // Keyswitch the j-th row of the col 0 + (0..lhs.rows()).for_each(|row_i| { + // Key-switch column 0, i.e. + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) + self.at_mut(row_i, 0) + .automorphism(module, &lhs.at(row_i, 0), auto_key, scratch); + }); + } } diff --git a/core/src/ggsw/keyswitch.rs b/core/src/ggsw/keyswitch.rs index f549e19..4c85b0c 100644 --- a/core/src/ggsw/keyswitch.rs +++ b/core/src/ggsw/keyswitch.rs @@ -1,12 +1,14 @@ use backend::hal::{ api::{ - ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAllocBytes, VecZnxDftAddInplace, VecZnxDftCopy, - VecZnxDftToVecZnxBigTmpA, VecZnxNormalizeTmpBytes, ZnxInfos, + ScratchAvailable, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAllocBytes, VecZnxCopy, VecZnxDftAddInplace, + VecZnxDftCopy, VecZnxDftToVecZnxBigTmpA, VecZnxNormalizeTmpBytes, ZnxInfos, }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VecZnxDft, VmpPMat}, + layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VmpPMat}, }; -use crate::{GGSWCiphertext, GLWECiphertext, GLWEKeyswitchFamily, GLWESwitchingKeyExec, GLWETensorKeyExec, Infos}; +use crate::{ + GGLWECiphertext, GGSWCiphertext, GLWECiphertext, GLWEKeyswitchFamily, GLWEOps, GLWESwitchingKeyExec, GLWETensorKeyExec, Infos, +}; pub trait GGSWKeySwitchFamily = GLWEKeyswitchFamily + VecZnxBigAllocBytes + VecZnxDftCopy + VecZnxDftAddInplace + VecZnxDftToVecZnxBigTmpA; @@ -40,7 +42,7 @@ impl GGSWCiphertext> { tsk_size, ); let tmp_idft: usize = module.vec_znx_big_alloc_bytes(n, 1, tsk_size); - let norm: usize = module.vec_znx_normalize_tmp_bytes(module.n()); + let norm: usize = module.vec_znx_normalize_tmp_bytes(n); tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm)) } @@ -89,124 +91,30 @@ impl GGSWCiphertext> { } impl GGSWCiphertext { - pub(crate) fn expand_row( + pub fn from_gglwe( &mut self, module: &Module, - row_i: usize, - col_j: usize, - ci_dft: &VecZnxDft, + a: &GGLWECiphertext, tsk: &GLWETensorKeyExec, scratch: &mut Scratch, ) where - Module: GGSWKeySwitchFamily + VecZnxNormalizeTmpBytes, - Scratch: TakeVecZnxDft + TakeVecZnxBig + ScratchAvailable, + DataA: DataRef, + DataTsk: DataRef, + Module: GGSWKeySwitchFamily + VecZnxNormalizeTmpBytes + VecZnxCopy, + Scratch: TakeVecZnxDft + TakeVecZnxBig + ScratchAvailable + TakeVecZnx, { - let cols: usize = self.rank() + 1; - #[cfg(debug_assertions)] { - assert_eq!(self.n(), tsk.n()); + assert_eq!(self.rank(), a.rank()); + assert_eq!(self.rows(), a.rows()); + assert_eq!(self.n(), module.n()); + assert_eq!(a.n(), module.n()); + assert_eq!(tsk.n(), module.n()); } - - assert!( - scratch.available() - >= GGSWCiphertext::expand_row_scratch_space( - module, - self.n(), - self.basek(), - self.k(), - tsk.k(), - tsk.digits(), - tsk.rank() - ) - ); - - // Example for rank 3: - // - // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is - // actually composed of that many rows and we focus on a specific row here - // implicitely given ci_dft. - // - // # Input - // - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) - // col 1: (0, 0, 0, 0) - // col 2: (0, 0, 0, 0) - // col 3: (0, 0, 0, 0) - // - // # Output - // - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) - // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 ) - // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) - // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) - - let n: usize = self.n(); - let digits: usize = tsk.digits(); - - let (mut tmp_dft_i, scratch1) = scratch.take_vec_znx_dft(n, cols, tsk.size()); - let (mut tmp_a, scratch2) = scratch1.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(digits)); - - { - // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 - // - // # Example for col=1 - // - // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) - // + - // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) - // + - // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) - // = - // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) - (1..cols).for_each(|col_i| { - let pmat: &VmpPMat = &tsk.at(col_i - 1, col_j - 1).key.data; // Selects Enc(s[i]s[j]) - - // Extracts a[i] and multipies with Enc(s[i]s[j]) - (0..digits).for_each(|di| { - tmp_a.set_size((ci_dft.size() + di) / digits); - - // Small optimization for digits > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. - // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last digits-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - tmp_dft_i.set_size(tsk.size() - ((digits - di) as isize - 2).max(0) as usize); - - module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, ci_dft, col_i); - if di == 0 && col_i == 1 { - module.vmp_apply(&mut tmp_dft_i, &tmp_a, pmat, scratch2); - } else { - module.vmp_apply_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch2); - } - }); - }); - } - - // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i - // - // (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2) - // + - // (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0) - // = - // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2) - // = - // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) - module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, ci_dft, 0); - let (mut tmp_idft, scratch2) = scratch1.take_vec_znx_big(n, 1, tsk.size()); - (0..cols).for_each(|i| { - module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i); - module.vec_znx_big_normalize( - self.basek(), - &mut self.at_mut(row_i, col_j).data, - i, - &tmp_idft, - 0, - scratch2, - ); + (0..self.rows()).for_each(|row_i| { + self.at_mut(row_i, 0).copy(module, &a.at(row_i, 0)); }); + self.expand_row(module, tsk, scratch); } pub fn keyswitch( @@ -220,31 +128,8 @@ impl GGSWCiphertext { Module: GLWEKeyswitchFamily + GGSWKeySwitchFamily + VecZnxNormalizeTmpBytes, Scratch: TakeVecZnxDft + TakeVecZnxBig + ScratchAvailable, { - let n: usize = self.n(); - let rank: usize = self.rank(); - let cols: usize = rank + 1; - - // Keyswitch the j-th row of the col 0 - (0..lhs.rows()).for_each(|row_i| { - // Key-switch column 0, i.e. - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) - self.at_mut(row_i, 0) - .keyswitch(module, &lhs.at(row_i, 0), ksk, scratch); - - // Pre-compute DFT of (a0, a1, a2) - let (mut ci_dft, scratch1) = scratch.take_vec_znx_dft(n, cols, self.size()); - (0..cols).for_each(|i| { - module.vec_znx_dft_from_vec_znx(1, 0, &mut ci_dft, i, &self.at(row_i, 0).data, i); - }); - // Generates - // - // col 1: (-(b0s0' + b1s1' + b2s2') , b0 + M[i], b1 , b2 ) - // col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M[i], c2 ) - // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i]) - (1..cols).for_each(|col_j| { - self.expand_row(module, row_i, col_j, &ci_dft, tsk, scratch1); - }); - }) + self.keyswitch_internal(module, lhs, ksk, scratch); + self.expand_row(module, tsk, scratch); } pub fn keyswitch_inplace( @@ -262,4 +147,147 @@ impl GGSWCiphertext { self.keyswitch(module, &*self_ptr, ksk, tsk, scratch); } } + + pub fn expand_row( + &mut self, + module: &Module, + tsk: &GLWETensorKeyExec, + scratch: &mut Scratch, + ) where + Module: GGSWKeySwitchFamily + VecZnxNormalizeTmpBytes, + Scratch: TakeVecZnxDft + TakeVecZnxBig + ScratchAvailable, + { + assert!( + scratch.available() + >= GGSWCiphertext::expand_row_scratch_space( + module, + self.n(), + self.basek(), + self.k(), + tsk.k(), + tsk.digits(), + tsk.rank() + ) + ); + + let n: usize = self.n(); + let rank: usize = self.rank(); + let cols: usize = rank + 1; + + // Keyswitch the j-th row of the col 0 + (0..self.rows()).for_each(|row_i| { + // Pre-compute DFT of (a0, a1, a2) + let (mut ci_dft, scratch1) = scratch.take_vec_znx_dft(n, cols, self.size()); + (0..cols).for_each(|i| { + module.vec_znx_dft_from_vec_znx(1, 0, &mut ci_dft, i, &self.at(row_i, 0).data, i); + }); + + (1..cols).for_each(|col_j| { + // Example for rank 3: + // + // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is + // actually composed of that many rows and we focus on a specific row here + // implicitely given ci_dft. + // + // # Input + // + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (0, 0, 0, 0) + // col 2: (0, 0, 0, 0) + // col 3: (0, 0, 0, 0) + // + // # Output + // + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 ) + // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) + // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) + + let digits: usize = tsk.digits(); + + let (mut tmp_dft_i, scratch2) = scratch1.take_vec_znx_dft(n, cols, tsk.size()); + let (mut tmp_a, scratch3) = scratch2.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(digits)); + + { + // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 + // + // # Example for col=1 + // + // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) + // + + // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) + // + + // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) + // = + // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) + (1..cols).for_each(|col_i| { + let pmat: &VmpPMat = &tsk.at(col_i - 1, col_j - 1).key.data; // Selects Enc(s[i]s[j]) + + // Extracts a[i] and multipies with Enc(s[i]s[j]) + (0..digits).for_each(|di| { + tmp_a.set_size((ci_dft.size() + di) / digits); + + // Small optimization for digits > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. + // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last digits-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + tmp_dft_i.set_size(tsk.size() - ((digits - di) as isize - 2).max(0) as usize); + + module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, &ci_dft, col_i); + if di == 0 && col_i == 1 { + module.vmp_apply(&mut tmp_dft_i, &tmp_a, pmat, scratch3); + } else { + module.vmp_apply_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch3); + } + }); + }); + } + + // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i + // + // (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2) + // + + // (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0) + // = + // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2) + // = + // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) + module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, &ci_dft, 0); + let (mut tmp_idft, scratch3) = scratch2.take_vec_znx_big(n, 1, tsk.size()); + (0..cols).for_each(|i| { + module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i); + module.vec_znx_big_normalize( + self.basek(), + &mut self.at_mut(row_i, col_j).data, + i, + &tmp_idft, + 0, + scratch3, + ); + }); + }) + }) + } + + fn keyswitch_internal( + &mut self, + module: &Module, + lhs: &GGSWCiphertext, + ksk: &GLWESwitchingKeyExec, + scratch: &mut Scratch, + ) where + Module: GLWEKeyswitchFamily + GGSWKeySwitchFamily + VecZnxNormalizeTmpBytes, + Scratch: TakeVecZnxDft + TakeVecZnxBig + ScratchAvailable, + { + // Keyswitch the j-th row of the col 0 + (0..lhs.rows()).for_each(|row_i| { + // Key-switch column 0, i.e. + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) + self.at_mut(row_i, 0) + .keyswitch(module, &lhs.at(row_i, 0), ksk, scratch); + }) + } } diff --git a/core/src/ggsw/noise.rs b/core/src/ggsw/noise.rs index 83590e2..c8c7148 100644 --- a/core/src/ggsw/noise.rs +++ b/core/src/ggsw/noise.rs @@ -71,3 +71,53 @@ impl GGSWCiphertext { }); } } + +impl GGSWCiphertext { + pub fn print_noise( + &self, + module: &Module, + sk_exec: &GLWESecretExec, + pt_want: &ScalarZnx, + ) where + DataSk: DataRef, + DataScalar: DataRef, + Module: GGSWAssertNoiseFamily + VecZnxAddScalarInplace + VecZnxSubABInplace + VecZnxStd, + B: TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, + { + let basek: usize = self.basek(); + let k: usize = self.k(); + let digits: usize = self.digits(); + + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(self.n(), basek, k); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(self.n(), basek, k); + let mut pt_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(self.n(), 1, self.size()); + let mut pt_big: VecZnxBig, B> = module.vec_znx_big_alloc(self.n(), 1, self.size()); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWECiphertext::decrypt_scratch_space(module, self.n(), basek, k) | module.vec_znx_normalize_tmp_bytes(module.n()), + ); + + (0..self.rank() + 1).for_each(|col_j| { + (0..self.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt.data, 0, (digits - 1) + row_i * digits, pt_want, 0); + + // mul with sk[col_j-1] + if col_j > 0 { + module.vec_znx_dft_from_vec_znx(1, 0, &mut pt_dft, 0, &pt.data, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_exec.data, col_j - 1); + module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(basek, &mut pt.data, 0, &pt_big, 0, scratch.borrow()); + } + + self.at(row_i, col_j) + .decrypt(module, &mut pt_have, &sk_exec, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt.data, 0); + + let std_pt: f64 = module.vec_znx_std(basek, &pt_have.data, 0).log2(); + println!("{}", std_pt); + pt.data.zero(); + }); + }); + } +} diff --git a/core/src/glwe/mod.rs b/core/src/glwe/mod.rs index 42b71d9..ec4a48d 100644 --- a/core/src/glwe/mod.rs +++ b/core/src/glwe/mod.rs @@ -25,3 +25,4 @@ pub use packing::*; pub use plaintext::*; pub use public_key::*; pub use secret::*; +pub use trace::*; diff --git a/core/src/glwe/trace.rs b/core/src/glwe/trace.rs index 3ae9869..ab21375 100644 --- a/core/src/glwe/trace.rs +++ b/core/src/glwe/trace.rs @@ -5,10 +5,7 @@ use backend::hal::{ layouts::{Backend, DataMut, DataRef, Module, Scratch}, }; -use crate::{ - AutomorphismKeyExec, GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEKeyswitchFamily, GLWEOps, Infos, - SetMetaData, -}; +use crate::{AutomorphismKeyExec, GLWECiphertext, GLWECiphertextToMut, GLWEKeyswitchFamily, GLWEOps, Infos, SetMetaData}; pub trait GLWETraceFamily = GLWEKeyswitchFamily + VecZnxCopy + VecZnxRshInplace + VecZnxBigAutomorphismInplace; @@ -70,7 +67,6 @@ where auto_keys: &HashMap>, scratch: &mut Scratch, ) where - GLWECiphertext: GLWECiphertextToRef + Infos + VecZnxRshInplace, Module: GLWETraceFamily, Scratch: TakeVecZnxDft + ScratchAvailable, { diff --git a/core/src/lib.rs b/core/src/lib.rs index fa3b87f..8e8389f 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,5 +1,6 @@ #![feature(trait_alias)] mod blind_rotation; +mod circuit_bootstrapping; mod dist; mod elem; mod gglwe; @@ -12,6 +13,7 @@ mod scratch; use crate::dist::Distribution; pub use blind_rotation::*; +pub use circuit_bootstrapping::*; pub use elem::*; pub use gglwe::*; pub use ggsw::*;