From 752525a1c64740d9cc1960c72d101a7230f597d1 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Mon, 28 Jul 2025 12:14:44 +0000 Subject: [PATCH] add TGLWE logic (pending to abstract it with TLWE to reuse part of the impl) --- tfhe/src/lib.rs | 1 + tfhe/src/tglwe.rs | 248 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 249 insertions(+) create mode 100644 tfhe/src/tglwe.rs diff --git a/tfhe/src/lib.rs b/tfhe/src/lib.rs index 784d84e..9db005b 100644 --- a/tfhe/src/lib.rs +++ b/tfhe/src/lib.rs @@ -5,6 +5,7 @@ #![allow(clippy::upper_case_acronyms)] #![allow(dead_code)] // TMP +pub mod tglwe; pub mod tgsw; pub mod tlev; pub mod tlwe; diff --git a/tfhe/src/tglwe.rs b/tfhe/src/tglwe.rs new file mode 100644 index 0000000..68eacf7 --- /dev/null +++ b/tfhe/src/tglwe.rs @@ -0,0 +1,248 @@ +use anyhow::Result; +use itertools::zip_eq; +use rand::distributions::Standard; +use rand::Rng; +use rand_distr::{Normal, Uniform}; +use std::array; +use std::iter::Sum; +use std::ops::{Add, AddAssign, Mul, Sub}; + +use arith::{Ring, Rq, Tn, T64, TR}; +use gfhe::{glwe, GLWE}; + +use crate::tlev::TLev; + +pub type SecretKey = glwe::SecretKey, K>; +pub type PublicKey = glwe::PublicKey, K>; + +#[derive(Clone, Debug)] +pub struct TGLWE(pub GLWE, K>); + +impl TGLWE { + pub fn zero() -> Self { + Self(GLWE::, K>::zero()) + } + + pub fn new_key(rng: impl Rng) -> Result<(SecretKey, PublicKey)> { + let (sk, pk) = GLWE::new_key(rng)?; + // Ok((SecretKey(sk), PublicKey(pk))) + Ok((sk, pk)) + } + + pub fn encode(m: &Rq) -> Tn { + let delta = u64::MAX / P; // floored + let coeffs = m.coeffs(); + Tn(array::from_fn(|i| T64(coeffs[i].0 * delta))) + } + pub fn decode(p: &Tn) -> Rq { + let p = p.mul_div_round(P, u64::MAX); + Rq::::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect()) + } + + // encrypts with the given SecretKey (instead of PublicKey) + pub fn encrypt_s(rng: impl Rng, sk: &SecretKey, p: &Tn) -> Result { + let glwe = GLWE::encrypt_s(rng, &sk, p)?; + Ok(Self(glwe)) + } + pub fn encrypt(rng: impl Rng, pk: &PublicKey, p: &Tn) -> Result { + let glwe = GLWE::encrypt(rng, &pk, p)?; + Ok(Self(glwe)) + } + pub fn decrypt(&self, sk: &SecretKey) -> Tn { + self.0.decrypt(&sk) + } +} + +impl Add> for TGLWE { + type Output = Self; + fn add(self, other: Self) -> Self { + Self(self.0 + other.0) + } +} +impl AddAssign for TGLWE { + fn add_assign(&mut self, rhs: Self) { + self.0 += rhs.0 + } +} +impl Sum> for TGLWE { + fn sum(iter: I) -> Self + where + I: Iterator, + { + let mut acc = TGLWE::::zero(); + for e in iter { + acc += e; + } + acc + } +} + +impl Sub> for TGLWE { + type Output = Self; + fn sub(self, other: Self) -> Self { + Self(self.0 - other.0) + } +} + +// plaintext addition +impl Add> for TGLWE { + type Output = Self; + fn add(self, plaintext: Tn) -> Self { + let a: TR, K> = self.0 .0; + let b: Tn = self.0 .1 + plaintext; + Self(GLWE(a, b)) + } +} +// plaintext substraction +impl Sub> for TGLWE { + type Output = Self; + fn sub(self, plaintext: Tn) -> Self { + let a: TR, K> = self.0 .0; + let b: Tn = self.0 .1 - plaintext; + Self(GLWE(a, b)) + } +} +// plaintext multiplication +impl Mul> for TGLWE { + type Output = Self; + fn mul(self, plaintext: Tn) -> Self { + let a: TR, K> = TR(self.0 .0 .0.iter().map(|r_i| *r_i * plaintext).collect()); + let b: Tn = self.0 .1 * plaintext; + Self(GLWE(a, b)) + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use rand::distributions::Uniform; + + use super::*; + + #[test] + fn test_encrypt_decrypt() -> Result<()> { + const T: u64 = 128; // msg space (msg modulus) + const N: usize = 64; + const K: usize = 16; + type S = TGLWE; + + let mut rng = rand::thread_rng(); + let msg_dist = Uniform::new(0_u64, T); + + for _ in 0..200 { + let (sk, pk) = S::new_key(&mut rng)?; + + let m = Rq::::rand_u64(&mut rng, msg_dist)?; + let p: Tn = S::encode::(&m); + + let c = S::encrypt(&mut rng, &pk, &p)?; + let p_recovered = c.decrypt(&sk); + let m_recovered = S::decode::(&p_recovered); + + assert_eq!(m, m_recovered); + + // same but using encrypt_s (with sk instead of pk)) + let c = S::encrypt_s(&mut rng, &sk, &p)?; + let p_recovered = c.decrypt(&sk); + let m_recovered = S::decode::(&p_recovered); + + assert_eq!(m, m_recovered); + } + + Ok(()) + } + + #[test] + fn test_addition() -> Result<()> { + const T: u64 = 128; + const N: usize = 64; + const K: usize = 16; + type S = TGLWE; + + let mut rng = rand::thread_rng(); + let msg_dist = Uniform::new(0_u64, T); + + for _ in 0..200 { + let (sk, pk) = S::new_key(&mut rng)?; + + let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; + let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; + let p1: Tn = S::encode::(&m1); // plaintext + let p2: Tn = S::encode::(&m2); // plaintext + + let c1 = S::encrypt(&mut rng, &pk, &p1)?; + let c2 = S::encrypt(&mut rng, &pk, &p2)?; + + let c3 = c1 + c2; + + let p3_recovered = c3.decrypt(&sk); + let m3_recovered = S::decode::(&p3_recovered); + + assert_eq!((m1 + m2).remodule::(), m3_recovered.remodule::()); + } + + Ok(()) + } + + #[test] + fn test_add_plaintext() -> Result<()> { + const T: u64 = 128; + const N: usize = 64; + const K: usize = 16; + type S = TGLWE; + + let mut rng = rand::thread_rng(); + let msg_dist = Uniform::new(0_u64, T); + + for _ in 0..200 { + let (sk, pk) = S::new_key(&mut rng)?; + + let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; + let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; + let p1: Tn = S::encode::(&m1); // plaintext + let p2: Tn = S::encode::(&m2); // plaintext + + let c1 = S::encrypt(&mut rng, &pk, &p1)?; + + let c3 = c1 + p2; + + let p3_recovered = c3.decrypt(&sk); + let m3_recovered = S::decode::(&p3_recovered); + + assert_eq!(m1 + m2, m3_recovered); + } + + Ok(()) + } + + #[test] + fn test_mul_plaintext() -> Result<()> { + const T: u64 = 128; + const N: usize = 64; + const K: usize = 16; + type S = TGLWE; + + let mut rng = rand::thread_rng(); + let msg_dist = Uniform::new(0_u64, T); + + for _ in 0..200 { + let (sk, pk) = S::new_key(&mut rng)?; + + let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; + let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; + let p1: Tn = S::encode::(&m1); + // don't scale up p2, set it directly from m2 + let p2: Tn = Tn(array::from_fn(|i| T64(m2.coeffs()[i].0))); + + let c1 = S::encrypt(&mut rng, &pk, &p1)?; + + let c3 = c1 * p2; + + let p3_recovered: Tn = c3.decrypt(&sk); + let m3_recovered = S::decode::(&p3_recovered); + assert_eq!((m1.to_r() * m2.to_r()).to_rq::(), m3_recovered); + } + + Ok(()) + } +}