Browse Source

tfhe: add initial bootstrapping impl

arnaucube 4 days ago
parent
commit
204bec5352
3 changed files with 105 additions and 5 deletions
  1. +0
    -3
      arith/src/ring_torus.rs
  2. +46
    -0
      tfhe/src/tggsw.rs
  3. +59
    -2
      tfhe/src/tlwe.rs

+ 0
- 3
arith/src/ring_torus.rs

@ -85,14 +85,11 @@ impl Ring for Tn {
impl<const N: usize> Tn<N> { impl<const N: usize> Tn<N> {
// multiply self by X^-h // multiply self by X^-h
pub fn left_rotate(&self, h: usize) -> Self { pub fn left_rotate(&self, h: usize) -> Self {
dbg!(&h);
dbg!(&N);
let h = h % N; let h = h % N;
assert!(h < N); assert!(h < N);
let c = self.0; let c = self.0;
// c[h], c[h+1], c[h+2], ..., c[n-1], -c[0], -c[1], ..., -c[h-1] // c[h], c[h+1], c[h+2], ..., c[n-1], -c[0], -c[1], ..., -c[h-1]
// let r: Vec<T64> = vec![c[h..N], c[0..h].iter().map(|&c_i| -c_i).collect()].concat(); // let r: Vec<T64> = vec![c[h..N], c[0..h].iter().map(|&c_i| -c_i).collect()].concat();
dbg!(&h);
let r: Vec<T64> = c[h..N] let r: Vec<T64> = c[h..N]
.iter() .iter()
.copied() .copied()

+ 46
- 0
tfhe/src/tggsw.rs

@ -132,3 +132,49 @@ impl Mul>> for TGLev {
r r
} }
} }
#[cfg(test)]
mod tests {
use anyhow::Result;
use rand::distributions::Uniform;
use super::*;
#[test]
fn test_external_product() -> Result<()> {
const T: u64 = 16; // plaintext modulus
const K: usize = 4;
const N: usize = 64;
const KN: usize = K * N;
let beta: u32 = 2;
let l: u32 = 64;
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
for _ in 0..50 {
let (sk, _) = TGLWE::<N, K>::new_key::<KN>(&mut rng)?;
let m1: Rq<T, N> = Rq::rand_u64(&mut rng, msg_dist)?;
let p1: Tn<N> = TGLev::<N, K>::encode::<T>(&m1);
let m2: Rq<T, N> = Rq::rand_u64(&mut rng, msg_dist)?;
let p2: Tn<N> = TGLWE::<N, K>::encode::<T>(&m2); // scaled by delta
let tgsw = TGGSW::<N, K>::encrypt_s(&mut rng, beta, l, &sk, &p1)?;
let tlwe = TGLWE::<N, K>::encrypt_s(&mut rng, &sk, &p2)?;
let res: TGLWE<N, K> = tgsw * tlwe;
// let p_recovered = res.decrypt(&sk, beta);
let p_recovered = res.decrypt(&sk);
// downscaled by delta^-1
let res_recovered = TGLWE::<N, K>::decode::<T>(&p_recovered);
// assert_eq!(m1 * m2, m_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq::<T>(), res_recovered);
}
Ok(())
}
}

+ 59
- 2
tfhe/src/tlwe.rs

@ -132,21 +132,37 @@ pub fn blind_rotation
c_j c_j
} }
pub fn bootstrapping<const N: usize, const K: usize, const KN: usize, const KN2: u64>(
btk: BootstrappingKey<N, K, KN>,
table: TGLWE<N, K>,
c: TLWE<KN>,
) -> TLWE<KN> {
let rotated: TGLWE<N, K> = blind_rotation::<N, K, KN, KN2>(c, btk.clone(), table);
let c_h: TLWE<KN> = rotated.sample_extraction(0);
let r = c_h.key_switch(2, 64, &btk.1);
r
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct BootstrappingKey<const N: usize, const K: usize, const KN: usize>(pub Vec<TGGSW<N, K>>);
pub struct BootstrappingKey<const N: usize, const K: usize, const KN: usize>(
pub Vec<TGGSW<N, K>>,
pub KSK<KN>,
);
impl<const N: usize, const K: usize, const KN: usize> BootstrappingKey<N, K, KN> { impl<const N: usize, const K: usize, const KN: usize> BootstrappingKey<N, K, KN> {
pub fn from_sk(mut rng: impl Rng, sk: &tglwe::SecretKey<N, K>) -> Result<Self> { pub fn from_sk(mut rng: impl Rng, sk: &tglwe::SecretKey<N, K>) -> Result<Self> {
let (beta, l) = (2u32, 64u32); // TMP let (beta, l) = (2u32, 64u32); // TMP
// //
let s: TR<Tn<N>, K> = sk.0 .0.clone(); let s: TR<Tn<N>, K> = sk.0 .0.clone();
let (sk2, _) = TLWE::<KN>::new_key(&mut rng)?; // TLWE<KN> compatible with TGLWE<N,K>
// each btk_j = TGGSW_sk(s_i) // each btk_j = TGGSW_sk(s_i)
let btk: Vec<TGGSW<N, K>> = s let btk: Vec<TGGSW<N, K>> = s
.iter() .iter()
.map(|s_i| TGGSW::<N, K>::encrypt_s(&mut rng, beta, l, sk, s_i)) .map(|s_i| TGGSW::<N, K>::encrypt_s(&mut rng, beta, l, sk, s_i))
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
let ksk = TLWE::<KN>::new_ksk(&mut rng, beta, l, &sk.to_tlwe(), &sk2)?;
Ok(Self(btk))
Ok(Self(btk, ksk))
} }
} }
@ -399,4 +415,45 @@ mod tests {
Ok(()) Ok(())
} }
#[test]
fn test_bootstrapping() -> Result<()> {
const T: u64 = 128; // plaintext modulus
const K: usize = 1;
const N: usize = 1024;
const KN: usize = K * N;
let mut rng = rand::thread_rng();
let start = Instant::now();
let table: TGLWE<N, K> = compute_lookup_table::<T, K, N>();
println!("table took: {:?}", start.elapsed());
let (sk, _) = TGLWE::<N, K>::new_key::<KN>(&mut rng)?;
let sk_tlwe: SecretKey<KN> = sk.to_tlwe::<KN>();
let start = Instant::now();
let btk = BootstrappingKey::<N, K, KN>::from_sk(&mut rng, &sk)?;
println!("btk took: {:?}", start.elapsed());
let msg_dist = Uniform::new(0_u64, T);
let m = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
// let m = Rq::<T, 1>::from_vec(vec![Zq(5)]);
dbg!(&m);
let p = TLWE::<K>::encode::<T>(&m); // plaintext
let c = TLWE::<KN>::encrypt_s(&mut rng, &sk_tlwe, &p)?;
let start = Instant::now();
// the ugly const generics are temporary
let bootstrapped: TLWE<KN> =
bootstrapping::<N, K, KN, { K as u64 * N as u64 }>(btk, table, c);
println!("bootstrapping took: {:?}", start.elapsed());
let p_recovered: T64 = bootstrapped.decrypt(&sk_tlwe);
let m_recovered = TLWE::<KN>::decode::<T>(&p_recovered);
dbg!(&m_recovered);
assert_eq!(m_recovered, m);
Ok(())
}
} }

Loading…
Cancel
Save