diff --git a/tfhe/src/tlwe.rs b/tfhe/src/tlwe.rs index bc5bdee..a5669a4 100644 --- a/tfhe/src/tlwe.rs +++ b/tfhe/src/tlwe.rs @@ -7,7 +7,7 @@ use std::array; use std::iter::Sum; use std::ops::{Add, AddAssign, Mul, Sub}; -use arith::{Ring, Rq, Tn, Zq, T64, TR}; +use arith::{Ring, Rq, Tn, T64, TR}; const ERR_SIGMA: f64 = 3.2; @@ -36,12 +36,12 @@ impl TLWE { Ok((SecretKey(s), pk)) } - pub fn encode(m: Rq) -> Tn<1> { + pub fn encode(m: &Rq) -> Tn<1> { let delta = u64::MAX / P; // floored let coeffs = m.coeffs(); Tn(array::from_fn(|i| T64(coeffs[i].0 * delta))) } - pub fn decode(p: Tn<1>) -> Rq { + pub fn decode(p: &Tn<1>) -> Rq { let p = p.mul_div_round(P, u64::MAX); Rq::::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect()) } @@ -77,6 +77,63 @@ impl TLWE { } } +impl Add> for TLWE { + type Output = Self; + fn add(self, other: Self) -> Self { + let a: TR, K> = self.0 + other.0; + let b: Tn<1> = self.1 + other.1; + Self(a, b) + } +} +impl AddAssign for TLWE { + fn add_assign(&mut self, rhs: Self) { + for i in 0..K { + self.0 .0[i] = self.0 .0[i] + rhs.0 .0[i]; + } + self.1 = self.1 + rhs.1; + } +} +impl Sum> for TLWE { + fn sum(iter: I) -> Self + where + I: Iterator, + { + let mut acc = TLWE::::zero(); + for e in iter { + acc += e; + } + acc + } +} + +impl Sub> for TLWE { + type Output = Self; + fn sub(self, other: Self) -> Self { + let a: TR, K> = self.0 - other.0; + let b: Tn<1> = self.1 - other.1; + Self(a, b) + } +} + +// plaintext addition +impl Add> for TLWE { + type Output = Self; + fn add(self, plaintext: Tn<1>) -> Self { + let a: TR, K> = self.0; + let b: Tn<1> = self.1 + plaintext; + Self(a, b) + } +} +// plaintext substraction +impl Sub> for TLWE { + type Output = Self; + fn sub(self, plaintext: Tn<1>) -> Self { + let a: TR, K> = self.0; + let b: Tn<1> = self.1 - plaintext; + Self(a, b) + } +} + #[cfg(test)] mod tests { use anyhow::Result; @@ -86,7 +143,7 @@ mod tests { #[test] fn test_encrypt_decrypt() -> Result<()> { - const T: u64 = 32; // plaintext modulus + const T: u64 = 128; // msg space (msg modulus) const K: usize = 16; type S = TLWE; @@ -98,23 +155,84 @@ mod tests { let msg_dist = Uniform::new(0_u64, T); let m = Rq::::rand_u64(&mut rng, msg_dist)?; dbg!(&m); - let p: Tn<1> = S::encode::(m); + let p: Tn<1> = S::encode::(&m); dbg!(&p); let c = S::encrypt(&mut rng, &pk, &p)?; let p_recovered = c.decrypt(&sk); - let m_recovered = S::decode::(p_recovered); + let m_recovered = S::decode::(&p_recovered); assert_eq!(m, m_recovered); // same but using encrypt_s (with sk instead of pk)) let c = S::encrypt_s(&mut rng, &sk, &p)?; let p_recovered = c.decrypt(&sk); - let m_recovered = S::decode::(p_recovered); + let m_recovered = S::decode::(&p_recovered); assert_eq!(m, m_recovered); } Ok(()) } + + #[test] + fn test_addition() -> Result<()> { + const T: u64 = 128; + const K: usize = 16; + type S = TLWE; + + let mut rng = rand::thread_rng(); + + for _ in 0..200 { + let (sk, pk) = S::new_key(&mut rng)?; + + let msg_dist = Uniform::new(0_u64, T); + let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; + let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; + let p1: Tn<1> = S::encode::(&m1); // plaintext + let p2: Tn<1> = S::encode::(&m2); // plaintext + + let c1 = S::encrypt(&mut rng, &pk, &p1)?; + let c2 = S::encrypt(&mut rng, &pk, &p2)?; + + let c3 = c1 + c2; + + let p3_recovered = c3.decrypt(&sk); + let m3_recovered = S::decode::(&p3_recovered); + + assert_eq!((m1 + m2).remodule::(), m3_recovered.remodule::()); + } + + Ok(()) + } + + #[test] + fn test_add_plaintext() -> Result<()> { + const T: u64 = 128; + const K: usize = 16; + type S = TLWE; + + let mut rng = rand::thread_rng(); + + for _ in 0..200 { + let (sk, pk) = S::new_key(&mut rng)?; + + let msg_dist = Uniform::new(0_u64, T); + let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; + let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; + let p1: Tn<1> = S::encode::(&m1); // plaintext + let p2: Tn<1> = S::encode::(&m2); // plaintext + + let c1 = S::encrypt(&mut rng, &pk, &p1)?; + + let c3 = c1 + p2; + + let p3_recovered = c3.decrypt(&sk); + let m3_recovered = S::decode::(&p3_recovered); + + assert_eq!((m1 + m2).remodule::(), m3_recovered); + } + + Ok(()) + } }