diff --git a/tfhe/src/tlwe.rs b/tfhe/src/tlwe.rs index edd2e51..da5deef 100644 --- a/tfhe/src/tlwe.rs +++ b/tfhe/src/tlwe.rs @@ -10,6 +10,8 @@ use std::ops::{Add, AddAssign, Mul, Sub}; use arith::{Ring, Rq, Tn, T64, TR}; use gfhe::{glwe, GLWE}; +use crate::tlev::TLev; + // #[derive(Clone, Debug)] // pub struct SecretKey(glwe::SecretKey); pub type SecretKey = glwe::SecretKey; @@ -18,6 +20,9 @@ pub type SecretKey = glwe::SecretKey; // pub struct PublicKey(glwe::PublicKey); pub type PublicKey = glwe::PublicKey; +#[derive(Clone, Debug)] +pub struct KSK(Vec>); + #[derive(Clone, Debug)] pub struct TLWE(pub GLWE); @@ -55,6 +60,35 @@ impl TLWE { pub fn decrypt(&self, sk: &SecretKey) -> T64 { self.0.decrypt(&sk) } + + pub fn new_ksk( + mut rng: impl Rng, + beta: u32, + l: u32, + sk: &SecretKey, + new_sk: &SecretKey, + ) -> Result> { + let r: Vec> = (0..K) + .into_iter() + .map(|i| + // treat sk_i as the msg being encrypted + TLev::::encrypt_s(&mut rng, beta, l, &new_sk, &sk.0 .0[i])) + .collect::>>()?; + + Ok(KSK(r)) + } + pub fn key_switch(&self, beta: u32, l: u32, ksk: &KSK) -> Self { + let (a, b): (TR, T64) = (self.0 .0.clone(), self.0 .1); + + let lhs: TLWE = TLWE(GLWE(TR::zero(), b)); + + // K iterations, ksk.0 contains K times GLev + let rhs: TLWE = zip_eq(a.0, ksk.0.clone()) + .map(|(a_i, ksk_i)| ksk_i * a_i.decompose(beta, l)) // dot_product + .sum(); + + lhs - rhs + } } impl Add> for TLWE { @@ -248,4 +282,43 @@ mod tests { Ok(()) } + + #[test] + fn test_key_switch() -> Result<()> { + const T: u64 = 128; // plaintext modulus + const K: usize = 16; + type S = TLWE; + + let beta: u32 = 2; + let l: u32 = 64; + + let mut rng = rand::thread_rng(); + + let (sk, pk) = S::new_key(&mut rng)?; + let (sk2, _) = S::new_key(&mut rng)?; + // ksk to switch from sk to sk2 + let ksk = S::new_ksk(&mut rng, beta, l, &sk, &sk2)?; + + let msg_dist = Uniform::new(0_u64, T); + let m = Rq::::rand_u64(&mut rng, msg_dist)?; + let p = S::encode::(&m); // plaintext + // + let c = S::encrypt_s(&mut rng, &sk, &p)?; + + let c2 = c.key_switch(beta, l, &ksk); + + // decrypt with the 2nd secret key + let p_recovered = c2.decrypt(&sk2); + let m_recovered = S::decode::(&p_recovered); + assert_eq!(m.remodule::(), m_recovered.remodule::()); + + // do the same but now encrypting with pk + let c = S::encrypt(&mut rng, &pk, &p)?; + let c2 = c.key_switch(beta, l, &ksk); + let p_recovered = c2.decrypt(&sk2); + let m_recovered = S::decode::(&p_recovered); + assert_eq!(m, m_recovered); + + Ok(()) + } }