//! Generalized LWE. //! use anyhow::Result; use itertools::zip_eq; use rand::Rng; use rand_distr::{Normal, Uniform}; use std::iter::Sum; use std::ops::{Add, AddAssign, Mul, Sub}; use arith::{Ring, RingParam, Rq, Zq, TR}; use crate::glev::GLev; // error deviation for the Gaussian(Normal) distribution // sigma=3.2 from: https://eprint.iacr.org/2022/162.pdf page 5 pub(crate) const ERR_SIGMA: f64 = 3.2; #[derive(Clone, Copy, Debug)] pub struct Param { pub err_sigma: f64, pub ring: RingParam, pub k: usize, pub t: u64, } impl Param { /// returns the plaintext param pub fn pt(&self) -> RingParam { // TODO think if maybe return a new truct "PtParam" to differenciate // between the ciphertexxt (RingParam) and the plaintext param. Maybe it // can be just a wrapper on top of RingParam. RingParam { q: self.t, n: self.ring.n, } } /// returns the LWE param for the given GLWE (self), that is, it uses k=K*N /// as the length for the secret key. This follows [2018-421] where /// TLWE sk: s \in B^n , where n=K*N /// TRLWE sk: s \in B_N[X]^K pub fn lwe(&self) -> Self { Self { err_sigma: ERR_SIGMA, ring: RingParam { q: self.ring.q, n: 1, }, k: self.k * self.ring.n, t: self.t, } } } /// GLWE implemented over the `Ring` trait, so that it can be also instantiated /// over the Torus polynomials 𝕋_[X] = 𝕋_q[X]/ (X^N+1). #[derive(Clone, Debug)] pub struct GLWE(pub TR, pub R); #[derive(Clone, Debug)] pub struct SecretKey(pub TR); #[derive(Clone, Debug)] pub struct PublicKey(pub R, pub TR); // K GLevs, each KSK_i=l GLWEs #[derive(Clone, Debug)] pub struct KSK(Vec>); impl GLWE { pub fn zero(k: usize, param: &RingParam) -> Self { Self(TR::zero(k, ¶m), R::zero(¶m)) } pub fn from_plaintext(k: usize, param: &RingParam, p: R) -> Self { Self(TR::zero(k, ¶m), p) } pub fn new_key(mut rng: impl Rng, param: &Param) -> Result<(SecretKey, PublicKey)> { let Xi_key = Uniform::new(0_f64, 2_f64)?; let Xi_err = Normal::new(0_f64, param.err_sigma)?; let s: TR = TR::rand(&mut rng, Xi_key, param.k, ¶m.ring); let a: TR = TR::rand( &mut rng, Uniform::new(0_f64, param.ring.q as f64)?, param.k, ¶m.ring, ); let e = R::rand(&mut rng, Xi_err, ¶m.ring); let pk: PublicKey = PublicKey((&a * &s) + e, a); Ok((SecretKey(s), pk)) } pub fn pk_from_sk(mut rng: impl Rng, param: &Param, sk: SecretKey) -> Result> { let Xi_err = Normal::new(0_f64, param.err_sigma)?; let a: TR = TR::rand( &mut rng, Uniform::new(0_f64, param.ring.q as f64)?, param.k, ¶m.ring, ); let e = R::rand(&mut rng, Xi_err, ¶m.ring); let pk: PublicKey = PublicKey((&a * &sk.0) + e, a); Ok(pk) } pub fn new_ksk( mut rng: impl Rng, param: &Param, beta: u32, l: u32, sk: &SecretKey, new_sk: &SecretKey, ) -> Result> { debug_assert_eq!(param.k, sk.0.k); let k = sk.0.k; let r: Vec> = (0..k) .into_iter() .map(|i| // treat sk_i as the msg being encrypted GLev::::encrypt_s(&mut rng, param, beta, l, &new_sk, &sk.0 .r[i])) .collect::>>()?; Ok(KSK(r)) } pub fn key_switch(&self, param: &Param, beta: u32, l: u32, ksk: &KSK) -> Self { let (a, b): (TR, R) = (self.0.clone(), self.1.clone()); // TODO rm clones let lhs: GLWE = GLWE(TR::zero(param.k, ¶m.ring), b); // K iterations, ksk.0 contains K times GLev let rhs: GLWE = zip_eq(a.r, ksk.0.clone()) .map(|(a_i, ksk_i)| ksk_i * a_i.decompose(beta, l)) // dot_product .sum(); lhs - rhs } // encrypts with the given SecretKey (instead of PublicKey) pub fn encrypt_s( mut rng: impl Rng, param: &Param, sk: &SecretKey, m: &R, // already scaled ) -> Result { let Xi_key = Uniform::new(0_f64, 2_f64)?; let Xi_err = Normal::new(0_f64, param.err_sigma)?; let a: TR = TR::rand(&mut rng, Xi_key, param.k, ¶m.ring); let e = R::rand(&mut rng, Xi_err, ¶m.ring); let b: R = (&a * &sk.0) + m.clone() + e; // TODO rm clone Ok(Self(a, b)) } pub fn encrypt( mut rng: impl Rng, param: &Param, pk: &PublicKey, m: &R, // already scaled ) -> Result { let Xi_key = Uniform::new(0_f64, 2_f64)?; let Xi_err = Normal::new(0_f64, param.err_sigma)?; let u: R = R::rand(&mut rng, Xi_key, ¶m.ring); let e0 = R::rand(&mut rng, Xi_err, ¶m.ring); let e1 = TR::::rand(&mut rng, Xi_err, param.k, ¶m.ring); let b: R = pk.0.clone() * u.clone() + m.clone() + e0; // TODO rm clones let d: TR = &pk.1 * &u + e1; Ok(Self(d, b)) } // returns m' not downscaled pub fn decrypt(&self, sk: &SecretKey) -> R { let (d, b): (TR, R) = (self.0.clone(), self.1.clone()); let p: R = b - &d * &sk.0; p } } // Methods for when Ring=Rq impl GLWE { // scale up pub fn encode(param: &Param, m: &Rq) -> Rq { debug_assert_eq!(param.t, m.param.q); let m = m.remodule(param.ring.q); let delta = param.ring.q / param.t; // floored m * delta } // scale down pub fn decode(param: &Param, m: &Rq) -> Rq { let r = m.mul_div_round(param.t, param.ring.q); let r: Rq = r.remodule(param.t); r } pub fn mod_switch(&self, p: u64) -> GLWE { let a: TR = TR { k: self.0.k, r: self.0.r.iter().map(|r| r.mod_switch(p)).collect::>(), }; let b: Rq = self.1.mod_switch(p); GLWE(a, b) } } impl Add> for GLWE { type Output = Self; fn add(self, other: Self) -> Self { debug_assert_eq!(self.0.k, other.0.k); debug_assert_eq!(self.1.param(), other.1.param()); let a: TR = self.0 + other.0; let b: R = self.1 + other.1; Self(a, b) } } impl Add for GLWE { type Output = Self; fn add(self, plaintext: R) -> Self { debug_assert_eq!(self.1.param(), plaintext.param()); let a: TR = self.0; let b: R = self.1 + plaintext; Self(a, b) } } impl AddAssign for GLWE { fn add_assign(&mut self, rhs: Self) { debug_assert_eq!(self.0.k, rhs.0.k); debug_assert_eq!(self.1.param(), rhs.1.param()); let k = self.0.k; for i in 0..k { self.0.r[i] = self.0.r[i].clone() + rhs.0.r[i].clone(); } self.1 = self.1.clone() + rhs.1.clone(); } } impl Sum> for GLWE { fn sum(mut iter: I) -> Self where I: Iterator, { let first = iter.next().unwrap(); iter.fold(first, |acc, e| acc + e) } } impl Sub> for GLWE { type Output = Self; fn sub(self, other: Self) -> Self { debug_assert_eq!(self.0.k, other.0.k); debug_assert_eq!(self.1.param(), other.1.param()); let a: TR = self.0 - other.0; let b: R = self.1 - other.1; Self(a, b) } } impl Mul for GLWE { type Output = Self; fn mul(self, plaintext: R) -> Self { debug_assert_eq!(self.1.param(), plaintext.param()); let a: TR = TR { k: self.0.k, r: self .0 .r .iter() .map(|r_i| r_i.clone() * plaintext.clone()) .collect(), }; let b: R = self.1 * plaintext; Self(a, b) } } // for when R = Rq // impl Mul> for GLWE, K> { // type Output = Self; // fn mul(self, plaintext: Rq) -> Self { // // first compute the NTT for plaintext, to avoid computing it at each // // iteration, speeding up the multiplications // let mut plaintext = plaintext.clone(); // plaintext.compute_evals(); // // let a: TR, K> = TR(self.0 .0.iter().map(|r_i| *r_i * plaintext).collect()); // let b: Rq = self.1 * plaintext; // Self(a, b) // } // } // impl Mul for GLWE // // where // // // R: std::ops::Mul<::C>, // // // Vec: FromIterator<::C>>::Output>, // // Vec: FromIterator<::C>>::Output>, // { // type Output = Self; // fn mul(self, e: R::C) -> Self { // let a: TR = TR(self.0 .0.iter().map(|r_i| *r_i * e.clone()).collect()); // let b: R = self.1 * e.clone(); // Self(a, b) // } // } // impl Mul> for GLWE { // type Output = Self; // fn mul(self, e: Zq) -> Self { // let a: TR, K> = TR(self.0 .0.iter().map(|r_i| *r_i * e).collect()); // let b: Rq = self.1 * e; // Self(a, b) // } // } #[cfg(test)] mod tests { use anyhow::Result; use rand::distr::Uniform; use super::*; #[test] fn test_encrypt_decrypt_ring_nq() -> Result<()> { let param = Param { err_sigma: ERR_SIGMA, ring: RingParam { q: 2u64.pow(16) + 1, n: 128, }, k: 16, t: 32, // plaintext modulus }; type S = GLWE; let mut rng = rand::rng(); let msg_dist = Uniform::new(0_u64, param.t)?; for _ in 0..200 { let (sk, pk) = S::new_key(&mut rng, ¶m)?; let m = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; // msg let p = S::encode(¶m, &m); // plaintext let c = S::encrypt(&mut rng, ¶m, &pk, &p)?; // ciphertext let p_recovered = c.decrypt(&sk); let m_recovered = S::decode(¶m, &p_recovered); assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t)); // same but using encrypt_s (with sk instead of pk)) let c = S::encrypt_s(&mut rng, ¶m, &sk, &p)?; let p_recovered = c.decrypt(&sk); let m_recovered = S::decode(¶m, &p_recovered); assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t)); } Ok(()) } use arith::{Tn, T64}; pub fn t_encode(param: &RingParam, m: &Rq) -> Tn { let p = m.param.q; // plaintext space let delta = u64::MAX / p; // floored let coeffs = m.coeffs(); Tn { param: *param, coeffs: coeffs.iter().map(|c_i| T64(c_i.v * delta)).collect(), } } pub fn t_decode(param: &Param, pt: &Tn) -> Rq { let pt = pt.mul_div_round(param.t, u64::MAX); Rq::from_vec_u64(¶m.pt(), pt.coeffs().iter().map(|c| c.0).collect()) } #[test] fn test_encrypt_decrypt_torus() -> Result<()> { let param = Param { err_sigma: ERR_SIGMA, ring: RingParam { q: u64::MAX, n: 128, }, k: 16, t: 32, // plaintext modulus }; type S = GLWE; let mut rng = rand::rng(); let msg_dist = Uniform::new(0_f64, param.t as f64)?; for _ in 0..200 { let (sk, pk) = S::new_key(&mut rng, ¶m)?; let m = Rq::rand(&mut rng, msg_dist, ¶m.pt()); // msg let p = t_encode(¶m.ring, &m); // plaintext let c = S::encrypt(&mut rng, ¶m, &pk, &p)?; // ciphertext let p_recovered = c.decrypt(&sk); let m_recovered = t_decode(¶m, &p_recovered); assert_eq!(m, m_recovered); // same but using encrypt_s (with sk instead of pk)) let c = S::encrypt_s(&mut rng, ¶m, &sk, &p)?; let p_recovered = c.decrypt(&sk); let m_recovered = t_decode(¶m, &p_recovered); assert_eq!(m, m_recovered); } Ok(()) } #[test] fn test_addition() -> Result<()> { let param = Param { err_sigma: ERR_SIGMA, ring: RingParam { q: 2u64.pow(16) + 1, n: 128, }, k: 16, t: 20, // plaintext modulus }; type S = GLWE; let mut rng = rand::rng(); let msg_dist = Uniform::new(0_u64, param.t)?; for _ in 0..200 { let (sk, pk) = S::new_key(&mut rng, ¶m)?; let m1 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; let m2 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; let p1: Rq = S::encode(¶m, &m1); // plaintext let p2: Rq = S::encode(¶m, &m2); // plaintext let c1 = S::encrypt(&mut rng, ¶m, &pk, &p1)?; let c2 = S::encrypt(&mut rng, ¶m, &pk, &p2)?; let c3 = c1 + c2; let p3_recovered = c3.decrypt(&sk); let m3_recovered = S::decode(¶m, &p3_recovered); assert_eq!((m1 + m2).remodule(param.t), m3_recovered.remodule(param.t)); } Ok(()) } #[test] fn test_add_plaintext() -> Result<()> { let param = Param { err_sigma: ERR_SIGMA, ring: RingParam { q: 2u64.pow(16) + 1, n: 128, }, k: 16, t: 32, // plaintext modulus }; type S = GLWE; let mut rng = rand::rng(); let msg_dist = Uniform::new(0_u64, param.t)?; for _ in 0..200 { let (sk, pk) = S::new_key(&mut rng, ¶m)?; let m1 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; let m2 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; let p1: Rq = S::encode(¶m, &m1); // plaintext let p2: Rq = S::encode(¶m, &m2); // plaintext let c1 = S::encrypt(&mut rng, ¶m, &pk, &p1)?; let c3 = c1 + p2; let p3_recovered = c3.decrypt(&sk); let m3_recovered = S::decode(¶m, &p3_recovered); assert_eq!((m1 + m2).remodule(param.t), m3_recovered.remodule(param.t)); } Ok(()) } #[test] fn test_mul_plaintext() -> Result<()> { let param = Param { err_sigma: ERR_SIGMA, ring: RingParam { q: 2u64.pow(16) + 1, n: 16, }, k: 16, t: 4, // plaintext modulus }; type S = GLWE; let mut rng = rand::rng(); let msg_dist = Uniform::new(0_u64, param.t)?; for _ in 0..200 { let (sk, pk) = S::new_key(&mut rng, ¶m)?; let m1 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; let m2 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; let p1: Rq = S::encode(¶m, &m1); // plaintext let p2 = m2.remodule(param.ring.q); // notice we don't encode (scale by delta) let c1 = S::encrypt(&mut rng, ¶m, &pk, &p1)?; let c3 = c1 * p2; let p3_recovered: Rq = c3.decrypt(&sk); let m3_recovered: Rq = S::decode(¶m, &p3_recovered); assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), m3_recovered); } Ok(()) } #[test] fn test_mod_switch() -> Result<()> { let param = Param { err_sigma: ERR_SIGMA, ring: RingParam { q: 2u64.pow(16) + 1, n: 8, }, k: 16, t: 4, // plaintext modulus, must be a prime or power of a prime }; let new_q: u64 = 2u64.pow(8) + 1; // note: wip, Q and P chosen so that P/Q is an integer type S = GLWE; let mut rng = rand::rng(); let msg_dist = Uniform::new(0_u64, param.t)?; for _ in 0..200 { let (sk, pk) = S::new_key(&mut rng, ¶m)?; let m = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; let p = S::encode(¶m, &m); let c = S::encrypt(&mut rng, ¶m, &pk, &p)?; let c2: GLWE = c.mod_switch(new_q); assert_eq!(c2.1.param.q, new_q); let sk2: SecretKey = SecretKey(TR { k: param.k, r: sk.0.r.iter().map(|s_i| s_i.remodule(new_q)).collect(), }); let p_recovered = c2.decrypt(&sk2); let new_param = Param { err_sigma: ERR_SIGMA, ring: RingParam { q: new_q, n: param.ring.n, }, k: param.k, t: param.t, }; let m_recovered = GLWE::::decode(&new_param, &p_recovered); assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t)); } Ok(()) } #[test] fn test_key_switch() -> Result<()> { let param = Param { err_sigma: ERR_SIGMA, ring: RingParam { q: 2u64.pow(16) + 1, n: 128, }, k: 16, t: 2, }; type S = GLWE; let beta: u32 = 2; let l: u32 = 16; let mut rng = rand::rng(); let (sk, pk) = S::new_key(&mut rng, ¶m)?; let (sk2, _) = S::new_key(&mut rng, ¶m)?; // ksk to switch from sk to sk2 let ksk = S::new_ksk(&mut rng, ¶m, beta, l, &sk, &sk2)?; let msg_dist = Uniform::new(0_u64, param.t)?; let m = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; let p = S::encode(¶m, &m); // plaintext // let c = S::encrypt_s(&mut rng, ¶m, &sk, &p)?; let c2 = c.key_switch(¶m, beta, l, &ksk); // decrypt with the 2nd secret key let p_recovered = c2.decrypt(&sk2); let m_recovered = S::decode(¶m, &p_recovered); assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t)); // do the same but now encrypting with pk let c = S::encrypt(&mut rng, ¶m, &pk, &p)?; let c2 = c.key_switch(¶m, beta, l, &ksk); let p_recovered = c2.decrypt(&sk2); let m_recovered = S::decode(¶m, &p_recovered); assert_eq!(m, m_recovered); Ok(()) } }