diff --git a/tfhe/src/lib.rs b/tfhe/src/lib.rs index 604bf17..784d84e 100644 --- a/tfhe/src/lib.rs +++ b/tfhe/src/lib.rs @@ -5,5 +5,6 @@ #![allow(clippy::upper_case_acronyms)] #![allow(dead_code)] // TMP +pub mod tgsw; pub mod tlev; pub mod tlwe; diff --git a/tfhe/src/tgsw.rs b/tfhe/src/tgsw.rs new file mode 100644 index 0000000..28c7593 --- /dev/null +++ b/tfhe/src/tgsw.rs @@ -0,0 +1,74 @@ +use anyhow::Result; +use itertools::zip_eq; +use rand::Rng; +use std::array; +use std::ops::{Add, Mul}; + +use arith::{Ring, Rq, Tn, T64, TR}; + +use crate::tlev::TLev; +use crate::tlwe::{PublicKey, SecretKey, TLWE}; +use gfhe::glwe::GLWE; + +/// vector of length K+1 = [K], [1] +#[derive(Clone, Debug)] +pub struct TGSW(pub(crate) Vec>, TLev); + +impl TGSW { + pub fn encrypt_s( + mut rng: impl Rng, + beta: u32, + l: u32, + sk: &SecretKey, + m: &T64, + ) -> Result { + let a: Vec> = (0..K) + .map(|i| TLev::encrypt_s(&mut rng, beta, l, sk, &(-sk.0 .0[i] * *m))) + .collect::>>()?; + let b: TLev = TLev::encrypt_s(&mut rng, beta, l, sk, m)?; + Ok(Self(a, b)) + } + + pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> T64 { + self.1.decrypt(sk, beta) + } + pub fn from_tlwe(_tlwe: TLWE) -> Self { + todo!() + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use rand::distributions::Uniform; + + use super::*; + + #[test] + fn test_encrypt_decrypt() -> Result<()> { + const T: u64 = 2; // plaintext modulus + const K: usize = 16; + type S = TGSW; + + let beta: u32 = 2; + let l: u32 = 16; + + let mut rng = rand::thread_rng(); + let msg_dist = Uniform::new(0_u64, T); + + for _ in 0..50 { + let (sk, _) = TLWE::::new_key(&mut rng)?; + + let m: Rq = Rq::rand_u64(&mut rng, msg_dist)?; + let p: T64 = TLev::::encode::(&m); // plaintext + + let c = S::encrypt_s(&mut rng, beta, l, &sk, &p)?; + let p_recovered = c.decrypt(&sk, beta); + let m_recovered = TLev::::decode::(&p_recovered); + + assert_eq!(m, m_recovered); + } + + Ok(()) + } +} diff --git a/tfhe/src/tlev.rs b/tfhe/src/tlev.rs index de7a2ab..8209330 100644 --- a/tfhe/src/tlev.rs +++ b/tfhe/src/tlev.rs @@ -1,4 +1,5 @@ use anyhow::Result; +use itertools::zip_eq; use rand::Rng; use std::array; use std::ops::{Add, Mul}; @@ -11,11 +12,11 @@ use crate::tlwe::{PublicKey, SecretKey, TLWE}; pub struct TLev(pub(crate) Vec>); impl TLev { - pub fn encode(m: &Rq) -> Tn<1> { + pub fn encode(m: &Rq) -> T64 { let coeffs = m.coeffs(); - Tn(array::from_fn(|i| T64(coeffs[i].0))) + T64(coeffs[0].0) // N=1, so take the only coeff } - pub fn decode(p: &Tn<1>) -> Rq { + pub fn decode(p: &T64) -> Rq { Rq::::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect()) } pub fn encrypt( @@ -23,7 +24,7 @@ impl TLev { beta: u32, l: u32, pk: &PublicKey, - m: &Tn<1>, + m: &T64, ) -> Result { let tlev: Vec> = (1..l + 1) .map(|i| { @@ -35,25 +36,33 @@ impl TLev { } pub fn encrypt_s( mut rng: impl Rng, - beta: u32, + _beta: u32, // TODO rm, and make beta=2 always l: u32, sk: &SecretKey, - m: &Tn<1>, + m: &T64, ) -> Result { - let tlev: Vec> = (1..l + 1) + let tlev: Vec> = (1..l as u64 + 1) .map(|i| { - TLWE::::encrypt_s(&mut rng, sk, &(*m * (u64::MAX / beta.pow(i as u32) as u64))) + let aux = if i < 64 { + *m * (u64::MAX / (1u64 << i)) + } else { + // 1<<64 would overflow, and anyways we're dividing u64::MAX + // by it, which would be equal to 1 + *m + }; + TLWE::::encrypt_s(&mut rng, sk, &aux) }) .collect::>>()?; Ok(Self(tlev)) } - pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> Tn<1> { + pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> T64 { let pt = self.0[0].decrypt(sk); pt.mul_div_round(beta as u64, u64::MAX) } } +// TODO review u64::MAX, since is -1 of the value we actually want #[cfg(test)] mod tests { @@ -78,7 +87,7 @@ mod tests { let (sk, pk) = TLWE::::new_key(&mut rng)?; let m: Rq = Rq::rand_u64(&mut rng, msg_dist)?; - let p: Tn<1> = S::encode::(&m); // plaintext + let p: T64 = S::encode::(&m); // plaintext let c = S::encrypt(&mut rng, beta, l, &pk, &p)?; let p_recovered = c.decrypt(&sk, beta); diff --git a/tfhe/src/tlwe.rs b/tfhe/src/tlwe.rs index 6ab4809..edd2e51 100644 --- a/tfhe/src/tlwe.rs +++ b/tfhe/src/tlwe.rs @@ -10,47 +10,50 @@ use std::ops::{Add, AddAssign, Mul, Sub}; use arith::{Ring, Rq, Tn, T64, TR}; use gfhe::{glwe, GLWE}; -const ERR_SIGMA: f64 = 3.2; +// #[derive(Clone, Debug)] +// pub struct SecretKey(glwe::SecretKey); +pub type SecretKey = glwe::SecretKey; -#[derive(Clone, Debug)] -pub struct SecretKey(glwe::SecretKey, K>); -#[derive(Clone, Debug)] -pub struct PublicKey(glwe::PublicKey, K>); +// #[derive(Clone, Debug)] +// pub struct PublicKey(glwe::PublicKey); +pub type PublicKey = glwe::PublicKey; #[derive(Clone, Debug)] -pub struct TLWE(pub GLWE, K>); +pub struct TLWE(pub GLWE); impl TLWE { pub fn zero() -> Self { - Self(GLWE::, K>::zero()) + Self(GLWE::::zero()) } pub fn new_key(rng: impl Rng) -> Result<(SecretKey, PublicKey)> { let (sk, pk) = GLWE::new_key(rng)?; - Ok((SecretKey(sk), PublicKey(pk))) + // Ok((SecretKey(sk), PublicKey(pk))) + Ok((sk, pk)) } - pub fn encode(m: &Rq) -> Tn<1> { + pub fn encode(m: &Rq) -> T64 { let delta = u64::MAX / P; // floored let coeffs = m.coeffs(); - Tn(array::from_fn(|i| T64(coeffs[i].0 * delta))) + // Tn(array::from_fn(|i| T64(coeffs[i].0 * delta))) + T64(coeffs[0].0 * delta) } - pub fn decode(p: &Tn<1>) -> Rq { + pub fn decode(p: &T64) -> Rq { let p = p.mul_div_round(P, u64::MAX); Rq::::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect()) } // encrypts with the given SecretKey (instead of PublicKey) - pub fn encrypt_s(rng: impl Rng, sk: &SecretKey, p: &Tn<1>) -> Result { - let glwe = GLWE::encrypt_s(rng, &sk.0, p)?; + pub fn encrypt_s(rng: impl Rng, sk: &SecretKey, p: &T64) -> Result { + let glwe = GLWE::encrypt_s(rng, &sk, p)?; Ok(Self(glwe)) } - pub fn encrypt(rng: impl Rng, pk: &PublicKey, p: &Tn<1>) -> Result { - let glwe = GLWE::encrypt(rng, &pk.0, p)?; + pub fn encrypt(rng: impl Rng, pk: &PublicKey, p: &T64) -> Result { + let glwe = GLWE::encrypt(rng, &pk, p)?; Ok(Self(glwe)) } - pub fn decrypt(&self, sk: &SecretKey) -> Tn<1> { - self.0.decrypt(&sk.0) + pub fn decrypt(&self, sk: &SecretKey) -> T64 { + self.0.decrypt(&sk) } } @@ -86,29 +89,29 @@ impl Sub> for TLWE { } // plaintext addition -impl Add> for TLWE { +impl Add for TLWE { type Output = Self; - fn add(self, plaintext: Tn<1>) -> Self { - let a: TR, K> = self.0 .0; - let b: Tn<1> = self.0 .1 + plaintext; + fn add(self, plaintext: T64) -> Self { + let a: TR = self.0 .0; + let b: T64 = self.0 .1 + plaintext; Self(GLWE(a, b)) } } // plaintext substraction -impl Sub> for TLWE { +impl Sub for TLWE { type Output = Self; - fn sub(self, plaintext: Tn<1>) -> Self { - let a: TR, K> = self.0 .0; - let b: Tn<1> = self.0 .1 - plaintext; + fn sub(self, plaintext: T64) -> Self { + let a: TR = self.0 .0; + let b: T64 = self.0 .1 - plaintext; Self(GLWE(a, b)) } } // plaintext multiplication -impl Mul> for TLWE { +impl Mul for TLWE { type Output = Self; - fn mul(self, plaintext: Tn<1>) -> Self { - let a: TR, K> = TR(self.0 .0 .0.iter().map(|r_i| *r_i * plaintext).collect()); - let b: Tn<1> = self.0 .1 * plaintext; + fn mul(self, plaintext: T64) -> Self { + let a: TR = TR(self.0 .0 .0.iter().map(|r_i| *r_i * plaintext).collect()); + let b: T64 = self.0 .1 * plaintext; Self(GLWE(a, b)) } } @@ -134,7 +137,7 @@ mod tests { let m = Rq::::rand_u64(&mut rng, msg_dist)?; dbg!(&m); - let p: Tn<1> = S::encode::(&m); + let p: T64 = S::encode::(&m); dbg!(&p); let c = S::encrypt(&mut rng, &pk, &p)?; @@ -168,8 +171,8 @@ mod tests { let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; - let p1: Tn<1> = S::encode::(&m1); // plaintext - let p2: Tn<1> = S::encode::(&m2); // plaintext + let p1: T64 = S::encode::(&m1); // plaintext + let p2: T64 = S::encode::(&m2); // plaintext let c1 = S::encrypt(&mut rng, &pk, &p1)?; let c2 = S::encrypt(&mut rng, &pk, &p2)?; @@ -199,8 +202,8 @@ mod tests { let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; - let p1: Tn<1> = S::encode::(&m1); // plaintext - let p2: Tn<1> = S::encode::(&m2); // plaintext + let p1: T64 = S::encode::(&m1); // plaintext + let p2: T64 = S::encode::(&m2); // plaintext let c1 = S::encrypt(&mut rng, &pk, &p1)?; @@ -209,7 +212,7 @@ mod tests { let p3_recovered = c3.decrypt(&sk); let m3_recovered = S::decode::(&p3_recovered); - assert_eq!((m1 + m2).remodule::(), m3_recovered); + assert_eq!(m1 + m2, m3_recovered); } Ok(()) @@ -229,15 +232,16 @@ mod tests { let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; - let p1: Tn<1> = S::encode::(&m1); + let p1: T64 = S::encode::(&m1); // don't scale up p2, set it directly from m2 - let p2: Tn<1> = Tn(array::from_fn(|i| T64(m2.coeffs()[i].0))); + // let p2: T64 = Tn(array::from_fn(|i| T64(m2.coeffs()[i].0))); + let p2: T64 = T64(m2.coeffs()[0].0); let c1 = S::encrypt(&mut rng, &pk, &p1)?; let c3 = c1 * p2; - let p3_recovered: Tn<1> = c3.decrypt(&sk); + let p3_recovered: T64 = c3.decrypt(&sk); let m3_recovered = S::decode::(&p3_recovered); assert_eq!((m1.to_r() * m2.to_r()).to_rq::(), m3_recovered); }