diff --git a/Cargo.toml b/Cargo.toml index 8ac3f64..32238ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,8 @@ members = [ "arith", "gfhe", "bfv", - "ckks" + "ckks", + "tfhe" ] resolver = "2" diff --git a/README.md b/README.md index 096c069..6d682d9 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,10 @@ # fhe-study Implementations from scratch done while studying some FHE papers; do not use in production. -- `arith`: contains $\mathbb{Z}_q$, $R_q=\mathbb{Z}_q[X]/(X^N+1)$ and $R=\mathbb{Z}[X]/(X^N+1)$ arithmetic implementations, together with the NTT implementation. +- `arith`: contains $\mathbb{Z}_q$, $R_q=\mathbb{Z}_q[X]/(X^N+1)$, $R=\mathbb{Z}[X]/(X^N+1)$, $\mathbb{T}_{Q}[X]/(X^N +1)$ arithmetic implementations, together with the NTT implementation. - `gfhe`: (gfhe=generalized-fhe) contains the structs and logic for RLWE, GLWE, GLev, GGSW, RGSW cryptosystems, and modulus switching and key switching methods, which can be used by concrete FHE schemes. - `bfv`: https://eprint.iacr.org/2012/144.pdf scheme implementation - `ckks`: https://eprint.iacr.org/2016/421.pdf scheme implementation +- `tfhe`: https://eprint.iacr.org/2018/421.pdf scheme implementation `cargo test --release` diff --git a/arith/src/ring.rs b/arith/src/ring.rs index 03205ac..50ab7f6 100644 --- a/arith/src/ring.rs +++ b/arith/src/ring.rs @@ -30,4 +30,10 @@ pub trait Ring: fn from_vec(coeffs: Vec) -> Self; fn decompose(&self, beta: u32, l: u32) -> Vec; + + /// returns [ [(num/den) * self].round() ] mod q + /// ie. performs the multiplication and division over f64, and then it + /// rounds the result, only applying the mod Q (if the ring is mod Q) at the + /// end. + fn mul_div_round(&self, num: u64, den: u64) -> Self; } diff --git a/arith/src/ring_n.rs b/arith/src/ring_n.rs index a2cbef4..e15c915 100644 --- a/arith/src/ring_n.rs +++ b/arith/src/ring_n.rs @@ -43,6 +43,19 @@ impl Ring for R { unimplemented!(); // array::from_fn(|i| self.coeffs[i].decompose(beta, l)) } + + // performs the multiplication and division over f64, and then it rounds the + // result, only applying the mod Q at the end + fn mul_div_round(&self, num: u64, den: u64) -> Self { + unimplemented!() + // fn mul_div_round(&self, num: u64, den: u64) -> crate::Rq { + // let r: Vec = self + // .coeffs() + // .iter() + // .map(|e| ((num as f64 * *e as f64) / den as f64).round()) + // .collect(); + // crate::Rq::::from_vec_f64(r) + } } impl From> for R { @@ -74,16 +87,6 @@ impl R { pub fn mul_by_i64(&self, s: i64) -> Self { Self(array::from_fn(|i| self.0[i] * s)) } - // performs the multiplication and division over f64, and then it rounds the - // result, only applying the mod Q at the end - pub fn mul_div_round(&self, num: u64, den: u64) -> crate::Rq { - let r: Vec = self - .coeffs() - .iter() - .map(|e| ((num as f64 * *e as f64) / den as f64).round()) - .collect(); - crate::Rq::::from_vec_f64(r) - } pub fn infinity_norm(&self) -> u64 { self.coeffs() diff --git a/arith/src/ring_nq.rs b/arith/src/ring_nq.rs index 4c6b1c0..7e1a373 100644 --- a/arith/src/ring_nq.rs +++ b/arith/src/ring_nq.rs @@ -70,6 +70,18 @@ impl Ring for Rq { // convert it to Rq r.iter().map(|a_i| Self::from_vec(a_i.clone())).collect() } + + // returns [ [(num/den) * self].round() ] mod q + // ie. performs the multiplication and division over f64, and then it rounds the + // result, only applying the mod Q at the end + fn mul_div_round(&self, num: u64, den: u64) -> Self { + let r: Vec = self + .coeffs() + .iter() + .map(|e| ((num as f64 * e.0 as f64) / den as f64).round()) + .collect(); + Rq::::from_vec_f64(r) + } } impl From> for Rq { @@ -231,17 +243,6 @@ impl Rq { .collect(); Rq::::from_vec_f64(r) } - // returns [ [(num/den) * self].round() ] mod q - // ie. performs the multiplication and division over f64, and then it rounds the - // result, only applying the mod Q at the end - pub fn mul_div_round(&self, num: u64, den: u64) -> Self { - let r: Vec = self - .coeffs() - .iter() - .map(|e| ((num as f64 * e.0 as f64) / den as f64).round()) - .collect(); - Rq::::from_vec_f64(r) - } fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { // TODO simplify diff --git a/arith/src/ring_torus.rs b/arith/src/ring_torus.rs index 41c7b10..9da3656 100644 --- a/arith/src/ring_torus.rs +++ b/arith/src/ring_torus.rs @@ -49,6 +49,18 @@ impl Ring for Tn { // convert it to Tn r.iter().map(|a_i| Self::from_vec(a_i.clone())).collect() } + + /// returns [ [(num/den) * self].round() ] mod q + /// ie. performs the multiplication and division over f64, and then it rounds the + /// result, only applying the mod Q at the end + fn mul_div_round(&self, num: u64, den: u64) -> Self { + let r: Vec = self + .coeffs() + .iter() + .map(|e| T64(((num as f64 * e.0 as f64) / den as f64).round() as u64)) + .collect(); + Self::from_vec(r) + } } // apply mod (X^N+1) diff --git a/arith/src/tuple_ring.rs b/arith/src/tuple_ring.rs index ab2e13a..af18e76 100644 --- a/arith/src/tuple_ring.rs +++ b/arith/src/tuple_ring.rs @@ -1,12 +1,15 @@ //! This file implements the struct for an Tuple of Ring Rq elements and its -//! operations. +//! operations, which are performed element-wise. use anyhow::Result; use itertools::zip_eq; use rand::{distributions::Distribution, Rng}; use rand_distr::{Normal, Uniform}; use std::iter::Sum; -use std::{array, ops}; +use std::{ + array, + ops::{Add, Mul, Sub}, +}; use crate::Ring; @@ -34,7 +37,7 @@ impl TR { } } -impl ops::Add> for TR { +impl Add> for TR { type Output = Self; fn add(self, other: Self) -> Self { Self( @@ -45,7 +48,7 @@ impl ops::Add> for TR { } } -impl ops::Sub> for TR { +impl Sub> for TR { type Output = Self; fn sub(self, other: Self) -> Self { Self(zip_eq(self.0, other.0).map(|(s, o)| s - o).collect()) @@ -54,13 +57,13 @@ impl ops::Sub> for TR { /// for (TR,TR), the Mul operation is defined as: /// for A, B \in R^k, result = Σ A_i * B_i \in R -impl ops::Mul> for TR { +impl Mul> for TR { type Output = R; fn mul(self, other: Self) -> R { zip_eq(self.0, other.0).map(|(s, o)| s * o).sum() } } -impl ops::Mul<&TR> for &TR { +impl Mul<&TR> for &TR { type Output = R; fn mul(self, other: &TR) -> R { zip_eq(self.0.clone(), other.0.clone()) @@ -71,13 +74,13 @@ impl ops::Mul<&TR> for &TR { /// for (TR, R), the Mul operation is defined as each element of TR is /// multiplied by R -impl ops::Mul for TR { +impl Mul for TR { type Output = TR; fn mul(self, other: R) -> TR { Self(self.0.iter().map(|s| s.clone() * other.clone()).collect()) } } -impl ops::Mul<&R> for &TR { +impl Mul<&R> for &TR { type Output = TR; fn mul(self, other: &R) -> TR { TR::(self.0.iter().map(|s| s.clone() * other.clone()).collect()) diff --git a/gfhe/src/glev.rs b/gfhe/src/glev.rs index 8337c1b..ce6d083 100644 --- a/gfhe/src/glev.rs +++ b/gfhe/src/glev.rs @@ -14,17 +14,22 @@ const ERR_SIGMA: f64 = 3.2; pub struct GLev(pub(crate) Vec>); impl GLev { + pub fn encode(m: &Rq) -> Rq { + m.remodule::() + } + pub fn decode(p: &Rq) -> Rq { + p.remodule::() + } pub fn encrypt( mut rng: impl Rng, beta: u32, l: u32, pk: &PublicKey, m: &Rq, - // delta: u64, ) -> Result { - let glev: Vec> = (0..l) + let glev: Vec> = (1..l + 1) .map(|i| { - GLWE::::encrypt(&mut rng, pk, &(*m * (Q / beta.pow(i as u32) as u64)), 1) + GLWE::::encrypt(&mut rng, pk, &(*m * (Q / beta.pow(i as u32) as u64))) }) .collect::>>()?; @@ -36,19 +41,19 @@ impl GLev { l: u32, sk: &SecretKey, m: &Rq, - // delta: u64, ) -> Result { - let glev: Vec> = (0..l) + let glev: Vec> = (1..l + 1) .map(|i| { - GLWE::::encrypt_s(&mut rng, sk, &(*m * (Q / beta.pow(i as u32) as u64)), 1) + GLWE::::encrypt_s(&mut rng, sk, &(*m * (Q / beta.pow(i as u32) as u64))) }) .collect::>>()?; Ok(Self(glev)) } - pub fn decrypt(&self, sk: &SecretKey, delta: u64) -> Rq { - self.0[1].decrypt::(sk, delta) + pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> Rq { + let pt = self.0[0].decrypt(sk); + pt.mul_div_round(beta as u64, Q) } } @@ -70,7 +75,6 @@ mod tests { let beta: u32 = 2; let l: u32 = 16; - let delta: u64 = Q / T; // floored let mut rng = rand::thread_rng(); for _ in 0..200 { @@ -78,12 +82,13 @@ mod tests { let msg_dist = Uniform::new(0_u64, T); let m = Rq::::rand_u64(&mut rng, msg_dist)?; - let m: Rq = m.remodule::(); + let p: Rq = S::encode::(&m); // plaintext - let c = S::encrypt(&mut rng, beta, l, &pk, &m)?; - let m_recovered = c.decrypt::(&sk, delta); + let c = S::encrypt(&mut rng, beta, l, &pk, &p)?; + let p_recovered = c.decrypt::(&sk, beta); + let m_recovered = S::decode::(&p_recovered); - assert_eq!(m.remodule::(), m_recovered.remodule::()); + assert_eq!(m, m_recovered); } Ok(()) diff --git a/gfhe/src/glwe.rs b/gfhe/src/glwe.rs index 417c2ec..39ddb48 100644 --- a/gfhe/src/glwe.rs +++ b/gfhe/src/glwe.rs @@ -80,29 +80,30 @@ impl GLWE { r } + // scale up + pub fn encode(m: &Rq) -> Rq { + let m = m.remodule::(); + let delta = Q / T; // floored + m * delta + } + // scale down + pub fn decode(p: &Rq) -> Rq { + let r = p.mul_div_round(T, Q); + r.remodule::() + } + // encrypts with the given SecretKey (instead of PublicKey) - pub fn encrypt_s( - mut rng: impl Rng, - sk: &SecretKey, - m: &Rq, - // TODO delta not as input - delta: u64, - ) -> Result { + pub fn encrypt_s(mut rng: impl Rng, sk: &SecretKey, m: &Rq) -> Result { let Xi_key = Uniform::new(0_f64, 2_f64); let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; let a: TR, K> = TR::rand(&mut rng, Xi_key); let e = Rq::::rand(&mut rng, Xi_err); - let b: Rq = (&a * &sk.0) + *m * delta + e; + let b: Rq = (&a * &sk.0) + *m + e; Ok(Self(a, b)) } - pub fn encrypt( - mut rng: impl Rng, - pk: &PublicKey, - m: &Rq, - delta: u64, - ) -> Result { + pub fn encrypt(mut rng: impl Rng, pk: &PublicKey, m: &Rq) -> Result { let Xi_key = Uniform::new(0_f64, 2_f64); let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; @@ -111,15 +112,14 @@ impl GLWE { let e0 = Rq::::rand(&mut rng, Xi_err); let e1 = TR::, K>::rand(&mut rng, Xi_err); - let b: Rq = pk.0 * u + *m * delta + e0; + let b: Rq = pk.0 * u + *m + e0; let d: TR, K> = &pk.1 * &u + e1; Ok(Self(d, b)) } - pub fn decrypt(&self, sk: &SecretKey, delta: u64) -> Rq { + pub fn decrypt(&self, sk: &SecretKey) -> Rq { let (d, b): (TR, K>, Rq) = (self.0.clone(), self.1); let r: Rq = b - &d * &sk.0; - let r = r.mul_div_round(T, Q); r } @@ -215,7 +215,6 @@ mod tests { const K: usize = 16; type S = GLWE; - let delta: u64 = Q / T; // floored let mut rng = rand::thread_rng(); for _ in 0..200 { @@ -223,16 +222,18 @@ mod tests { let msg_dist = Uniform::new(0_u64, T); let m = Rq::::rand_u64(&mut rng, msg_dist)?; - let m: Rq = m.remodule::(); + let p = S::encode::(&m); // plaintext - let c = S::encrypt(&mut rng, &pk, &m, delta)?; - let m_recovered = c.decrypt::(&sk, delta); + let c = S::encrypt(&mut rng, &pk, &p)?; + let p_recovered = c.decrypt(&sk); + let m_recovered = S::decode::(&p_recovered); assert_eq!(m.remodule::(), m_recovered.remodule::()); // same but using encrypt_s (with sk instead of pk)) - let c = S::encrypt_s(&mut rng, &sk, &m, delta)?; - let m_recovered = c.decrypt::(&sk, delta); + let c = S::encrypt_s(&mut rng, &sk, &p)?; + let p_recovered = c.decrypt(&sk); + let m_recovered = S::decode::(&p_recovered); assert_eq!(m.remodule::(), m_recovered.remodule::()); } @@ -248,7 +249,6 @@ mod tests { const K: usize = 16; type S = GLWE; - let delta: u64 = Q / T; // floored let mut rng = rand::thread_rng(); for _ in 0..200 { @@ -257,15 +257,16 @@ mod tests { let msg_dist = Uniform::new(0_u64, T); let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; - let m1: Rq = m1.remodule::(); - let m2: Rq = m2.remodule::(); + let p1: Rq = S::encode::(&m1); // plaintext + let p2: Rq = S::encode::(&m2); // plaintext - let c1 = S::encrypt(&mut rng, &pk, &m1, delta)?; - let c2 = S::encrypt(&mut rng, &pk, &m2, delta)?; + let c1 = S::encrypt(&mut rng, &pk, &p1)?; + let c2 = S::encrypt(&mut rng, &pk, &p2)?; let c3 = c1 + c2; - let m3_recovered = c3.decrypt::(&sk, delta); + let p3_recovered = c3.decrypt(&sk); + let m3_recovered = S::decode::(&p3_recovered); assert_eq!((m1 + m2).remodule::(), m3_recovered.remodule::()); } @@ -281,7 +282,6 @@ mod tests { const K: usize = 16; type S = GLWE; - let delta: u64 = Q / T; // floored let mut rng = rand::thread_rng(); for _ in 0..200 { @@ -290,17 +290,17 @@ mod tests { let msg_dist = Uniform::new(0_u64, T); let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; - let m1: Rq = m1.remodule::(); - let m2: Rq = m2.remodule::(); - let m2_scaled: Rq = m2 * delta; + let p1: Rq = S::encode::(&m1); // plaintext + let p2: Rq = S::encode::(&m2); // plaintext - let c1 = S::encrypt(&mut rng, &pk, &m1, delta)?; + let c1 = S::encrypt(&mut rng, &pk, &p1)?; - let c3 = c1 + m2_scaled; + let c3 = c1 + p2; - let m3_recovered = c3.decrypt::(&sk, delta); + let p3_recovered = c3.decrypt(&sk); + let m3_recovered = S::decode::(&p3_recovered); - assert_eq!((m1 + m2).remodule::(), m3_recovered.remodule::()); + assert_eq!((m1 + m2).remodule::(), m3_recovered); } Ok(()) @@ -314,7 +314,6 @@ mod tests { const K: usize = 16; type S = GLWE; - let delta: u64 = Q / T; // floored let mut rng = rand::thread_rng(); for _ in 0..200 { @@ -323,14 +322,15 @@ mod tests { let msg_dist = Uniform::new(0_u64, T); let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; - let m1: Rq = m1.remodule::(); - let m2: Rq = m2.remodule::(); - let c1 = S::encrypt(&mut rng, &pk, &m1, delta)?; + let p1: Rq = S::encode::(&m1); // plaintext + let p2: Rq = m2.remodule::(); + + let c1 = S::encrypt(&mut rng, &pk, &p1)?; - let c3 = c1 * m2; + let c3 = c1 * p2; - let m3_recovered: Rq = c3.decrypt::(&sk, delta); - let m3_recovered: Rq = m3_recovered.remodule::(); + let p3_recovered: Rq = c3.decrypt(&sk); + let m3_recovered = S::decode::(&p3_recovered); assert_eq!((m1.to_r() * m2.to_r()).to_rq::(), m3_recovered); } @@ -360,17 +360,18 @@ mod tests { let msg_dist = Uniform::new(0_u64, T); let m = Rq::::rand_u64(&mut rng, msg_dist)?; - let m: Rq = m.remodule::(); + let p = S::encode::(&m); // plaintext - let c = S::encrypt(&mut rng, &pk, &m, delta)?; + let c = S::encrypt(&mut rng, &pk, &p)?; // let c = S::encrypt_s(&mut rng, &sk, &m, delta)?; let c2 = c.mod_switch::

(); let sk2: SecretKey = SecretKey(TR(sk.0 .0.iter().map(|s_i| s_i.remodule::

()).collect())); - let delta2: u64 = ((P as f64 * delta as f64) / Q as f64).round() as u64; + // let delta2: u64 = ((P as f64 * delta as f64) / Q as f64).round() as u64; - let m_recovered = c2.decrypt::(&sk2, delta2); + let p_recovered = c2.decrypt(&sk2); + let m_recovered = GLWE::::decode::(&p_recovered); assert_eq!(m.remodule::(), m_recovered.remodule::()); } @@ -389,7 +390,6 @@ mod tests { let beta: u32 = 2; let l: u32 = 16; - let delta: u64 = Q / T; // floored let mut rng = rand::thread_rng(); let (sk, pk) = S::new_key(&mut rng)?; @@ -399,21 +399,23 @@ mod tests { let msg_dist = Uniform::new(0_u64, T); let m = Rq::::rand_u64(&mut rng, msg_dist)?; - let m: Rq = m.remodule::(); + let p: Rq = S::encode::(&m); // plaintext - let c = S::encrypt_s(&mut rng, &sk, &m, delta)?; + let c = S::encrypt_s(&mut rng, &sk, &p)?; let c2 = c.key_switch(beta, l, &ksk); // decrypt with the 2nd secret key - let m_recovered = c2.decrypt::(&sk2, delta); - assert_eq!(m.remodule::(), m_recovered.remodule::()); + let p_recovered = c2.decrypt(&sk2); + let m_recovered = S::decode::(&p_recovered); + assert_eq!(m, m_recovered); // do the same but now encrypting with pk - // let c = S::encrypt(&mut rng, &pk, &m, delta)?; + // let c = S::encrypt(&mut rng, &pk, &p)?; // let c2 = c.key_switch(beta, l, &ksk); - // let m_recovered = c2.decrypt::(&sk2, delta); - // assert_eq!(m.remodule::(), m_recovered.remodule::()); + // let p_recovered = c2.decrypt(&sk2); + // let m_recovered = S::decode::(&p_recovered); + // assert_eq!(m, m_recovered); Ok(()) } diff --git a/tfhe/Cargo.toml b/tfhe/Cargo.toml new file mode 100644 index 0000000..01e6a56 --- /dev/null +++ b/tfhe/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "tfhe" +version = "0.1.0" +edition = "2024" + +[dependencies] +anyhow = { workspace = true } +rand = { workspace = true } +rand_distr = { workspace = true } +itertools = { workspace = true } + +arith = { path="../arith" } diff --git a/tfhe/src/lib.rs b/tfhe/src/lib.rs new file mode 100644 index 0000000..42c7d85 --- /dev/null +++ b/tfhe/src/lib.rs @@ -0,0 +1,8 @@ +//! Implementation of TFHE https://eprint.iacr.org/2018/421.pdf +#![allow(non_snake_case)] +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(clippy::upper_case_acronyms)] +#![allow(dead_code)] // TMP + +pub mod tlwe; diff --git a/tfhe/src/tlwe.rs b/tfhe/src/tlwe.rs new file mode 100644 index 0000000..bc5bdee --- /dev/null +++ b/tfhe/src/tlwe.rs @@ -0,0 +1,120 @@ +use anyhow::Result; +use itertools::zip_eq; +use rand::distributions::Standard; +use rand::Rng; +use rand_distr::{Normal, Uniform}; +use std::array; +use std::iter::Sum; +use std::ops::{Add, AddAssign, Mul, Sub}; + +use arith::{Ring, Rq, Tn, Zq, T64, TR}; + +const ERR_SIGMA: f64 = 3.2; + +#[derive(Clone, Debug)] +pub struct TLWE(TR, K>, Tn<1>); + +#[derive(Clone, Debug)] +pub struct SecretKey(TR, K>); +#[derive(Clone, Debug)] +pub struct PublicKey(Tn<1>, TR, K>); + +impl TLWE { + pub fn zero() -> Self { + Self(TR::zero(), Tn::zero()) + } + + pub fn new_key(mut rng: impl Rng) -> Result<(SecretKey, PublicKey)> { + let Xi_key = Uniform::new(0_f64, 2_f64); + let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; + + let s: TR, K> = TR::rand(&mut rng, Xi_key); + let a: TR, K> = TR::rand(&mut rng, Standard); + let e = Tn::rand(&mut rng, Xi_err); + + let pk: PublicKey = PublicKey((&a * &s) + e, a); + Ok((SecretKey(s), pk)) + } + + pub fn encode(m: Rq) -> Tn<1> { + let delta = u64::MAX / P; // floored + let coeffs = m.coeffs(); + Tn(array::from_fn(|i| T64(coeffs[i].0 * delta))) + } + pub fn decode(p: Tn<1>) -> Rq { + let p = p.mul_div_round(P, u64::MAX); + Rq::::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect()) + } + + // encrypts with the given SecretKey (instead of PublicKey) + pub fn encrypt_s(mut rng: impl Rng, sk: &SecretKey, m: &Tn<1>) -> Result { + let Xi_key = Uniform::new(0_f64, 2_f64); + let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; + + let a: TR, K> = TR::rand(&mut rng, Xi_key); + let e = Tn::rand(&mut rng, Xi_err); + + let b: Tn<1> = (&a * &sk.0) + *m + e; + Ok(Self(a, b)) + } + pub fn encrypt(mut rng: impl Rng, pk: &PublicKey, m: &Tn<1>) -> Result { + let Xi_key = Uniform::new(0_f64, 2_f64); + let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; + + let u: Tn<1> = Tn::rand(&mut rng, Xi_key); + + let e0: Tn<1> = Tn::rand(&mut rng, Xi_err); + let e1 = TR::, K>::rand(&mut rng, Xi_err); + + let b: Tn<1> = pk.0 * u + *m + e0; + let d: TR, K> = &pk.1 * &u + e1; + + Ok(Self(d, b)) + } + pub fn decrypt(&self, sk: &SecretKey) -> Tn<1> { + let (d, b): (TR, K>, Tn<1>) = (self.0.clone(), self.1); + b - &d * &sk.0 + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use rand::distributions::Uniform; + + use super::*; + + #[test] + fn test_encrypt_decrypt() -> Result<()> { + const T: u64 = 32; // plaintext modulus + const K: usize = 16; + type S = TLWE; + + let mut rng = rand::thread_rng(); + + for _ in 0..200 { + let (sk, pk) = S::new_key(&mut rng)?; + + let msg_dist = Uniform::new(0_u64, T); + let m = Rq::::rand_u64(&mut rng, msg_dist)?; + dbg!(&m); + let p: Tn<1> = S::encode::(m); + dbg!(&p); + + let c = S::encrypt(&mut rng, &pk, &p)?; + let p_recovered = c.decrypt(&sk); + let m_recovered = S::decode::(p_recovered); + + assert_eq!(m, m_recovered); + + // same but using encrypt_s (with sk instead of pk)) + let c = S::encrypt_s(&mut rng, &sk, &p)?; + let p_recovered = c.decrypt(&sk); + let m_recovered = S::decode::(p_recovered); + + assert_eq!(m, m_recovered); + } + + Ok(()) + } +}