diff --git a/README.md b/README.md index 170530a..7cea5d4 100644 --- a/README.md +++ b/README.md @@ -16,30 +16,34 @@ Implementations from scratch done while studying some FHE papers; do not use in This example shows usage of TFHE, but the idea is that the same interface would work for using CKKS & BFV, the only thing to be changed would be the parameters -and the line `type S = TWLE` to use `CKKS` or `BFV`. +and the usage of `TLWE` by `CKKS` or `BFV`. ```rust -const T: u64 = 128; // msg space (msg modulus) -type M = Rq; // msg space -type S = TLWE<256>; +let param = Param { + err_sigma: crate::ERR_SIGMA, + ring: RingParam { q: u64::MAX, n: 1 }, + k: 256, + t: 128, // plaintext modulus +}; let mut rng = rand::thread_rng(); -let msg_dist = Uniform::new(0_u64, T); +let msg_dist = Uniform::new(0_u64, param.t); -let (sk, pk) = S::new_key(&mut rng)?; +let (sk, pk) = TLWE::new_key(&mut rng, ¶m)?; -// get two random msgs in Z_t -let m1 = M::rand_u64(&mut rng, msg_dist)?; -let m2 = M::rand_u64(&mut rng, msg_dist)?; -let m3 = M::rand_u64(&mut rng, msg_dist)?; +// get three random msgs in Rt +let m1 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; +let m2 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; +let m3 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; // encode the msgs into the plaintext space -let p1 = S::encode::(&m1); // plaintext -let p2 = S::encode::(&m2); // plaintext -let c3_const: Tn<1> = Tn(array::from_fn(|i| T64(m3.coeffs()[i].0))); // encode it as constant +let p1 = TLWE::encode(¶m, &m1); // plaintext +let p2 = TLWE::encode(¶m, &m2); // plaintext +let c3_const = TLWE::new_const(¶m, &m3); // as constant/public value -let c1 = S::encrypt(&mut rng, &pk, &p1)?; -let c2 = S::encrypt(&mut rng, &pk, &p2)?; +// encrypt p1 and m2 +let c1 = TLWE::encrypt(&mut rng, ¶m, &pk, &p1)?; +let c2 = TLWE::encrypt(&mut rng, ¶m, &pk, &p2)?; // now we can do encrypted operations (notice that we do them using simple // operation notation by rust's operator overloading): @@ -48,9 +52,10 @@ let c4 = c_12 * c3_const; // decrypt & decode let p4_recovered = c4.decrypt(&sk); -let m4 = S::decode::(&p4_recovered); +let m4 = TLWE::decode(¶m, &p4_recovered); // m4 is equal to (m1+m2)*m3 +assert_eq!(((m1 + m2).to_r() * m3.to_r()).to_rq(param.t), m4); ``` @@ -62,7 +67,7 @@ let m4 = S::decode::(&p4_recovered); - external products of ciphertexts - TGSW x TLWE - TGGSW x TGLWE - - TGSW & TGGSW CMux gate + - {TGSW, TGGSW} CMux gate - blind rotation, key switching, mod switching - bootstrapping - CKKS diff --git a/arith/Cargo.toml b/arith/Cargo.toml index e00269d..692757b 100644 --- a/arith/Cargo.toml +++ b/arith/Cargo.toml @@ -8,6 +8,7 @@ anyhow = { workspace = true } rand = { workspace = true } rand_distr = { workspace = true } itertools = { workspace = true } +lazy_static = "1.5.0" # TMP: the next 4 imports are TMP, to solve systems of linear equations. Used # for the CKKS encoding step, probably remvoed once in ckks the encoding is done diff --git a/arith/src/lib.rs b/arith/src/lib.rs index f04c61b..f72cb1f 100644 --- a/arith/src/lib.rs +++ b/arith/src/lib.rs @@ -15,17 +15,16 @@ pub mod ring_nq; pub mod ring_torus; pub mod tuple_ring; -mod naive_ntt; // note: for dev only +// mod naive_ntt; // note: for dev only pub mod ntt; // expose objects - pub use complex::C; pub use matrix::Matrix; pub use torus::T64; pub use zq::Zq; -pub use ring::Ring; +pub use ring::{Ring, RingParam}; pub use ring_n::R; pub use ring_nq::Rq; pub use ring_torus::Tn; diff --git a/arith/src/ntt.rs b/arith/src/ntt.rs index 90b8d78..890f8ed 100644 --- a/arith/src/ntt.rs +++ b/arith/src/ntt.rs @@ -1,34 +1,60 @@ //! Implementation of the NTT & iNTT, following the CT & GS algorighms, more details in //! https://eprint.iacr.org/2017/727.pdf, some notes at //! https://github.com/arnaucube/math/blob/master/notes_ntt.pdf . -use crate::zq::Zq; +//! +//! NOTE: initially I implemented it with fixed Q & N, given as constant +//! generics; but once using real-world parameters, the stack could not handle +//! it, so moved to use Vec instead of fixed-sized arrays, and adapted the NTT +//! implementation to that too. +use crate::{ring::RingParam, ring_nq::Rq, zq::Zq}; + +use std::collections::HashMap; #[derive(Debug)] -pub struct NTT {} - -impl NTT { - const N_INV: Zq = Zq(const_inv_mod::(N as u64)); - // since we work over Zq[X]/(X^N+1) (negacyclic), get the 2*N-th root of unity - pub(crate) const ROOT_OF_UNITY: u64 = primitive_root_of_unity::(2 * N); - pub(crate) const ROOTS_OF_UNITY: [Zq; N] = roots_of_unity(Self::ROOT_OF_UNITY); - const ROOTS_OF_UNITY_INV: [Zq; N] = roots_of_unity_inv(Self::ROOTS_OF_UNITY); +pub struct NTT {} + +use std::sync::{Mutex, OnceLock}; + +static CACHE: OnceLock, Vec, Zq)>>> = OnceLock::new(); + +fn roots(q: u64, n: usize) -> (Vec, Vec, Zq) { + let cache_lock = CACHE.get_or_init(|| Mutex::new(HashMap::new())); + let mut cache = cache_lock.lock().unwrap(); + if let Some(value) = cache.get(&(q, n)) { + return value.clone(); + } + + let n_inv: Zq = Zq { + q, + v: const_inv_mod(q, n as u64), + }; + let root_of_unity: u64 = primitive_root_of_unity(q, 2 * n); + let roots_of_unity: Vec = roots_of_unity(q, n, root_of_unity); + let roots_of_unity_inv: Vec = roots_of_unity_inv(q, n, roots_of_unity.clone()); + let value = (roots_of_unity, roots_of_unity_inv, n_inv); + + cache.insert((q, n), value.clone()); + value } -impl NTT { +impl NTT { /// implements the Cooley-Tukey (CT) algorithm. Details at /// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.1 of /// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf - pub fn ntt(a: [Zq; N]) -> [Zq; N] { - let mut t = N / 2; + pub fn ntt(a: &Rq) -> Rq { + let (q, n) = (a.param.q, a.param.n); + let (roots_of_unity, _, _) = roots(q, n); + + let mut t = n / 2; let mut m = 1; - let mut r: [Zq; N] = a.clone(); - while m < N { + let mut r: Vec = a.coeffs.clone(); + while m < n { let mut k = 0; for i in 0..m { - let S: Zq = Self::ROOTS_OF_UNITY[m + i]; + let S: Zq = roots_of_unity[m + i]; for j in k..k + t { - let U: Zq = r[j]; - let V: Zq = r[j + t] * S; + let U: Zq = r[j]; + let V: Zq = r[j + t] * S; r[j] = U + V; r[j + t] = U - V; } @@ -37,23 +63,32 @@ impl NTT { t /= 2; m *= 2; } - r + // TODO think if maybe not return a Rq type, or if returned Rq, maybe + // fill the `evals` field, which is what we're actually returning here + Rq { + param: RingParam { q, n }, + coeffs: r, + evals: None, + } } /// implements the Cooley-Tukey (CT) algorithm. Details at /// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.2 of /// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf - pub fn intt(a: [Zq; N]) -> [Zq; N] { + pub fn intt(a: &Rq) -> Rq { + let (q, n) = (a.param.q, a.param.n); + let (_, roots_of_unity_inv, n_inv) = roots(q, n); + let mut t = 1; - let mut m = N / 2; - let mut r: [Zq; N] = a.clone(); + let mut m = n / 2; + let mut r: Vec = a.coeffs.clone(); while m > 0 { let mut k = 0; for i in 0..m { - let S: Zq = Self::ROOTS_OF_UNITY_INV[m + i]; + let S: Zq = roots_of_unity_inv[m + i]; for j in k..k + t { - let U: Zq = r[j]; - let V: Zq = r[j + t]; + let U: Zq = r[j]; + let V: Zq = r[j + t]; r[j] = U + V; r[j + t] = (U - V) * S; } @@ -62,26 +97,32 @@ impl NTT { t *= 2; m /= 2; } - for i in 0..N { - r[i] = r[i] * Self::N_INV; + for i in 0..n { + r[i] = r[i] * n_inv; + } + Rq { + param: RingParam { q, n }, + coeffs: r, + // TODO maybe at `evals` place the inputed `a` which is the evals + // format + evals: None, } - r } } /// computes a primitive N-th root of unity using the method described by Thomas /// Pornin in https://crypto.stackexchange.com/a/63616 -const fn primitive_root_of_unity(N: usize) -> u64 { - assert!(N.is_power_of_two()); - assert!((Q - 1) % N as u64 == 0); +const fn primitive_root_of_unity(q: u64, n: usize) -> u64 { + assert!(n.is_power_of_two()); + assert!((q - 1) % n as u64 == 0); + let n_u64 = n as u64; - let n: u64 = N as u64; let mut k = 1; - while k < Q { + while k < q { // alternatively could get a random k at each iteration, if so, add the following if: // `if k == 0 { continue; }` - let w = const_exp_mod::(k, (Q - 1) / n); - if const_exp_mod::(w, n / 2) != 1 { + let w = const_exp_mod(q, k, (q - 1) / n_u64); + if const_exp_mod(q, w, n_u64 / 2) != 1 { return w; // w is a primitive N-th root of unity } k += 1; @@ -89,78 +130,85 @@ const fn primitive_root_of_unity(N: usize) -> u64 { panic!("No primitive root of unity"); } -const fn roots_of_unity(w: u64) -> [Zq; N] { - let mut r: [Zq; N] = [Zq(0u64); N]; +fn roots_of_unity(q: u64, n: usize, w: u64) -> Vec { + let mut r: Vec = vec![Zq { q, v: 0 }; n]; let mut i = 0; - let log_n = N.ilog2(); - while i < N { + let log_n = n.ilog2(); + while i < n { // (return the roots in bit-reverset order) let j = ((i as u64).reverse_bits() >> (64 - log_n)) as usize; - r[i] = Zq(const_exp_mod::(w, j as u64)); + r[i] = Zq { + q, + v: const_exp_mod(q, w, j as u64), + }; i += 1; } r } -const fn roots_of_unity_inv(v: [Zq; N]) -> [Zq; N] { +fn roots_of_unity_inv(q: u64, n: usize, v: Vec) -> Vec { // assumes that the inputted roots are already in bit-reverset order - let mut r: [Zq; N] = [Zq(0u64); N]; + let mut r: Vec = vec![Zq { q, v: 0 }; n]; let mut i = 0; - while i < N { - r[i] = Zq(const_inv_mod::(v[i].0)); + while i < n { + r[i] = Zq { + q, + v: const_inv_mod(q, v[i].v), + }; i += 1; } r } /// returns x^k mod Q -const fn const_exp_mod(x: u64, k: u64) -> u64 { +const fn const_exp_mod(q: u64, x: u64, k: u64) -> u64 { // work on u128 to avoid overflow let mut r = 1u128; let mut x = x as u128; let mut k = k as u128; - x = x % Q as u128; + x = x % q as u128; // exponentiation by square strategy while k > 0 { if k % 2 == 1 { - r = (r * x) % Q as u128; + r = (r * x) % q as u128; } - x = (x * x) % Q as u128; + x = (x * x) % q as u128; k /= 2; } r as u64 } /// returns x^-1 mod Q -const fn const_inv_mod(x: u64) -> u64 { +const fn const_inv_mod(q: u64, x: u64) -> u64 { // by Fermat's Little Theorem, x^-1 mod q \equiv x^{q-2} mod q - const_exp_mod::(x, Q - 2) + const_exp_mod(q, x, q - 2) } #[cfg(test)] mod tests { use super::*; + use crate::Ring; use anyhow::Result; - use std::array; #[test] fn test_ntt() -> Result<()> { - const Q: u64 = 2u64.pow(16) + 1; - const N: usize = 4; + let q: u64 = 2u64.pow(16) + 1; + let n: usize = 4; + let param = RingParam { q, n }; - let a: [u64; N] = [1u64, 2, 3, 4]; - let a: [Zq; N] = array::from_fn(|i| Zq::from_u64(a[i])); + let a: Vec = vec![1u64, 2, 3, 4]; + let a: Rq = Rq::from_vec_u64(¶m, a); - let a_ntt = NTT::::ntt(a); + let a_ntt = NTT::ntt(&a); - let a_intt = NTT::::intt(a_ntt); + let a_intt = NTT::intt(&a_ntt); dbg!(&a); dbg!(&a_ntt); dbg!(&a_intt); - dbg!(NTT::::ROOT_OF_UNITY); - dbg!(NTT::::ROOTS_OF_UNITY); + // dbg!(NTT::ROOT_OF_UNITY); + // dbg!(NTT::ROOTS_OF_UNITY); assert_eq!(a, a_intt); Ok(()) @@ -168,18 +216,18 @@ mod tests { #[test] fn test_ntt_loop() -> Result<()> { - const Q: u64 = 2u64.pow(16) + 1; - const N: usize = 512; + let q: u64 = 2u64.pow(16) + 1; + let n: usize = 512; + let param = RingParam { q, n }; - use rand::distributions::Distribution; use rand::distributions::Uniform; let mut rng = rand::thread_rng(); - let dist = Uniform::new(0_f64, Q as f64); + let dist = Uniform::new(0_f64, q as f64); - for _ in 0..100 { - let a: [Zq; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng))); - let a_ntt = NTT::::ntt(a); - let a_intt = NTT::::intt(a_ntt); + for _ in 0..1000 { + let a: Rq = Rq::rand(&mut rng, dist, ¶m); + let a_ntt = NTT::ntt(&a); + let a_intt = NTT::intt(&a_ntt); assert_eq!(a, a_intt); } Ok(()) diff --git a/arith/src/ntt_fixedsize.rs b/arith/src/ntt_fixedsize.rs new file mode 100644 index 0000000..90b8d78 --- /dev/null +++ b/arith/src/ntt_fixedsize.rs @@ -0,0 +1,187 @@ +//! Implementation of the NTT & iNTT, following the CT & GS algorighms, more details in +//! https://eprint.iacr.org/2017/727.pdf, some notes at +//! https://github.com/arnaucube/math/blob/master/notes_ntt.pdf . +use crate::zq::Zq; + +#[derive(Debug)] +pub struct NTT {} + +impl NTT { + const N_INV: Zq = Zq(const_inv_mod::(N as u64)); + // since we work over Zq[X]/(X^N+1) (negacyclic), get the 2*N-th root of unity + pub(crate) const ROOT_OF_UNITY: u64 = primitive_root_of_unity::(2 * N); + pub(crate) const ROOTS_OF_UNITY: [Zq; N] = roots_of_unity(Self::ROOT_OF_UNITY); + const ROOTS_OF_UNITY_INV: [Zq; N] = roots_of_unity_inv(Self::ROOTS_OF_UNITY); +} + +impl NTT { + /// implements the Cooley-Tukey (CT) algorithm. Details at + /// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.1 of + /// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf + pub fn ntt(a: [Zq; N]) -> [Zq; N] { + let mut t = N / 2; + let mut m = 1; + let mut r: [Zq; N] = a.clone(); + while m < N { + let mut k = 0; + for i in 0..m { + let S: Zq = Self::ROOTS_OF_UNITY[m + i]; + for j in k..k + t { + let U: Zq = r[j]; + let V: Zq = r[j + t] * S; + r[j] = U + V; + r[j + t] = U - V; + } + k = k + 2 * t; + } + t /= 2; + m *= 2; + } + r + } + + /// implements the Cooley-Tukey (CT) algorithm. Details at + /// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.2 of + /// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf + pub fn intt(a: [Zq; N]) -> [Zq; N] { + let mut t = 1; + let mut m = N / 2; + let mut r: [Zq; N] = a.clone(); + while m > 0 { + let mut k = 0; + for i in 0..m { + let S: Zq = Self::ROOTS_OF_UNITY_INV[m + i]; + for j in k..k + t { + let U: Zq = r[j]; + let V: Zq = r[j + t]; + r[j] = U + V; + r[j + t] = (U - V) * S; + } + k += 2 * t; + } + t *= 2; + m /= 2; + } + for i in 0..N { + r[i] = r[i] * Self::N_INV; + } + r + } +} + +/// computes a primitive N-th root of unity using the method described by Thomas +/// Pornin in https://crypto.stackexchange.com/a/63616 +const fn primitive_root_of_unity(N: usize) -> u64 { + assert!(N.is_power_of_two()); + assert!((Q - 1) % N as u64 == 0); + + let n: u64 = N as u64; + let mut k = 1; + while k < Q { + // alternatively could get a random k at each iteration, if so, add the following if: + // `if k == 0 { continue; }` + let w = const_exp_mod::(k, (Q - 1) / n); + if const_exp_mod::(w, n / 2) != 1 { + return w; // w is a primitive N-th root of unity + } + k += 1; + } + panic!("No primitive root of unity"); +} + +const fn roots_of_unity(w: u64) -> [Zq; N] { + let mut r: [Zq; N] = [Zq(0u64); N]; + let mut i = 0; + let log_n = N.ilog2(); + while i < N { + // (return the roots in bit-reverset order) + let j = ((i as u64).reverse_bits() >> (64 - log_n)) as usize; + r[i] = Zq(const_exp_mod::(w, j as u64)); + i += 1; + } + r +} + +const fn roots_of_unity_inv(v: [Zq; N]) -> [Zq; N] { + // assumes that the inputted roots are already in bit-reverset order + let mut r: [Zq; N] = [Zq(0u64); N]; + let mut i = 0; + while i < N { + r[i] = Zq(const_inv_mod::(v[i].0)); + i += 1; + } + r +} + +/// returns x^k mod Q +const fn const_exp_mod(x: u64, k: u64) -> u64 { + // work on u128 to avoid overflow + let mut r = 1u128; + let mut x = x as u128; + let mut k = k as u128; + x = x % Q as u128; + // exponentiation by square strategy + while k > 0 { + if k % 2 == 1 { + r = (r * x) % Q as u128; + } + x = (x * x) % Q as u128; + k /= 2; + } + r as u64 +} + +/// returns x^-1 mod Q +const fn const_inv_mod(x: u64) -> u64 { + // by Fermat's Little Theorem, x^-1 mod q \equiv x^{q-2} mod q + const_exp_mod::(x, Q - 2) +} + +#[cfg(test)] +mod tests { + use super::*; + + use anyhow::Result; + use std::array; + + #[test] + fn test_ntt() -> Result<()> { + const Q: u64 = 2u64.pow(16) + 1; + const N: usize = 4; + + let a: [u64; N] = [1u64, 2, 3, 4]; + let a: [Zq; N] = array::from_fn(|i| Zq::from_u64(a[i])); + + let a_ntt = NTT::::ntt(a); + + let a_intt = NTT::::intt(a_ntt); + + dbg!(&a); + dbg!(&a_ntt); + dbg!(&a_intt); + dbg!(NTT::::ROOT_OF_UNITY); + dbg!(NTT::::ROOTS_OF_UNITY); + + assert_eq!(a, a_intt); + Ok(()) + } + + #[test] + fn test_ntt_loop() -> Result<()> { + const Q: u64 = 2u64.pow(16) + 1; + const N: usize = 512; + + use rand::distributions::Distribution; + use rand::distributions::Uniform; + let mut rng = rand::thread_rng(); + let dist = Uniform::new(0_f64, Q as f64); + + for _ in 0..100 { + let a: [Zq; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng))); + let a_ntt = NTT::::ntt(a); + let a_intt = NTT::::intt(a_ntt); + assert_eq!(a, a_intt); + } + Ok(()) + } +} diff --git a/arith/src/ring.rs b/arith/src/ring.rs index f6a40ec..3f39f1e 100644 --- a/arith/src/ring.rs +++ b/arith/src/ring.rs @@ -3,6 +3,12 @@ use std::fmt::Debug; use std::iter::Sum; use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign}; +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct RingParam { + pub q: u64, // TODO think if really needed or it's fine with coeffs[0].q + pub n: usize, +} + /// Represents a ring element. Currently implemented by ring_nq.rs#Rq and /// ring_torus.rs#Tn. Is not a 'pure algebraic ring', but more a custom trait /// definition which includes methods like `mod_switch`. @@ -21,27 +27,25 @@ pub trait Ring: + PartialEq + Debug + Clone - + Copy + // + Copy + Sum<::Output> + Sum<::Output> { /// C defines the coefficient type type C: Debug + Clone; - const Q: u64; - const N: usize; - + fn param(&self) -> RingParam; fn coeffs(&self) -> Vec; - fn zero() -> Self; + fn zero(param: &RingParam) -> Self; // note/wip/warning: dist (0,q) with f64, will output more '0=q' elements than other values - fn rand(rng: impl Rng, dist: impl Distribution) -> Self; + fn rand(rng: impl Rng, dist: impl Distribution, param: &RingParam) -> Self; - fn from_vec(coeffs: Vec) -> Self; + fn from_vec(param: &RingParam, coeffs: Vec) -> Self; fn decompose(&self, beta: u32, l: u32) -> Vec; - fn remodule(&self) -> impl Ring; - fn mod_switch(&self) -> impl Ring; + fn remodule(&self, p:u64) -> impl Ring; + fn mod_switch(&self, p:u64) -> impl Ring; /// returns [ [(num/den) * self].round() ] mod q /// ie. performs the multiplication and division over f64, and then it diff --git a/arith/src/ring_n.rs b/arith/src/ring_n.rs index 1a36a9f..67eb17f 100644 --- a/arith/src/ring_n.rs +++ b/arith/src/ring_n.rs @@ -1,45 +1,43 @@ //! Polynomial ring Z[X]/(X^N+1) //! -use anyhow::Result; +use itertools::zip_eq; use rand::{distributions::Distribution, Rng}; -use std::array; use std::fmt; use std::iter::Sum; -use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; - -use crate::Ring; +use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign}; // TODO rename to not have name conflicts with the Ring trait (R: Ring) // PolynomialRing element, where the PolynomialRing is R = Z[X]/(X^n +1) -#[derive(Clone, Copy)] -pub struct R(pub [i64; N]); - -// impl Ring for R { -impl R { - // type C = i64; - // const Q: u64 = i64::MAX as u64; // WIP - // const N: usize = N; +#[derive(Clone)] +pub struct R { + pub n: usize, + pub coeffs: Vec, +} +impl R { pub fn coeffs(&self) -> Vec { - self.0.to_vec() + self.coeffs.clone() } - fn zero() -> Self { - let coeffs: [i64; N] = array::from_fn(|_| 0i64); - Self(coeffs) + fn zero(n: usize) -> Self { + Self { + n, + coeffs: vec![0i64; n], + } } - fn rand(mut rng: impl Rng, dist: impl Distribution) -> Self { - // let coeffs: [i64; N] = array::from_fn(|_| Self::C::rand(&mut rng, &dist)); - let coeffs: [i64; N] = array::from_fn(|_| dist.sample(&mut rng).round() as i64); - Self(coeffs) - // let coeffs: [C; N] = array::from_fn(|_| Zq::from_u64(dist.sample(&mut rng))); - // Self(coeffs) + fn rand(mut rng: impl Rng, dist: impl Distribution, n: usize) -> Self { + Self { + n, + coeffs: std::iter::repeat_with(|| dist.sample(&mut rng).round() as i64) + .take(n) + .collect(), + } } - pub fn from_vec(coeffs: Vec) -> Self { + pub fn from_vec(n: usize, coeffs: Vec) -> Self { let mut p = coeffs; - modulus::(&mut p); - Self(array::from_fn(|i| p[i])) + modulus(n, &mut p); + Self { n, coeffs: p } } /* @@ -71,34 +69,38 @@ impl R { */ } -impl From> for R { - fn from(rq: crate::ring_nq::Rq) -> Self { - Self::from_vec_u64(rq.coeffs().to_vec().iter().map(|e| e.0).collect()) +impl From for R { + fn from(rq: crate::ring_nq::Rq) -> Self { + Self::from_vec_u64( + rq.param.n, + rq.coeffs().to_vec().iter().map(|e| e.v).collect(), + ) } } -impl R { - // pub fn coeffs(&self) -> [i64; N] { - // self.0 - // } - pub fn to_rq(self) -> crate::Rq { - crate::Rq::::from(self) +impl R { + pub fn to_rq(self, q: u64) -> crate::Rq { + crate::Rq::from((q, self)) } // this method is mostly for tests - pub fn from_vec_u64(coeffs: Vec) -> Self { + pub fn from_vec_u64(n: usize, coeffs: Vec) -> Self { let coeffs_i64 = coeffs.iter().map(|c| *c as i64).collect(); - Self::from_vec(coeffs_i64) + Self::from_vec(n, coeffs_i64) } - pub fn from_vec_f64(coeffs: Vec) -> Self { + pub fn from_vec_f64(n: usize, coeffs: Vec) -> Self { let coeffs_i64 = coeffs.iter().map(|c| c.round() as i64).collect(); - Self::from_vec(coeffs_i64) + Self::from_vec(n, coeffs_i64) } - pub fn new(coeffs: [i64; N]) -> Self { - Self(coeffs) + pub fn new(n: usize, coeffs: Vec) -> Self { + assert_eq!(n, coeffs.len()); + Self { n, coeffs } } pub fn mul_by_i64(&self, s: i64) -> Self { - Self(array::from_fn(|i| self.0[i] * s)) + Self { + n: self.n, + coeffs: self.coeffs.iter().map(|c_i| c_i * s).collect(), + } } pub fn infinity_norm(&self) -> u64 { @@ -108,10 +110,10 @@ impl R { .map(|x| x.abs() as u64) .fold(0, |a, b| a.max(b)) } - pub fn mod_centered_q(&self) -> R { - let q = Q as i64; + pub fn mod_centered_q(&self, q: u64) -> R { + let q = q as i64; let r = self - .0 + .coeffs .iter() .map(|v| { let mut res = v % q; @@ -121,190 +123,216 @@ impl R { res }) .collect::>(); - R::::from_vec(r) + R::from_vec(self.n, r) } } -pub fn mul_div_round( - v: Vec, - num: u64, - den: u64, -) -> crate::Rq { +pub fn mul_div_round(q: u64, n: usize, v: Vec, num: u64, den: u64) -> crate::Rq { // dbg!(&v); let r: Vec = v .iter() .map(|e| ((num as f64 * *e as f64) / den as f64).round()) .collect(); // dbg!(&r); - crate::Rq::::from_vec_f64(r) + crate::Rq::from_vec_f64(&crate::ring::RingParam { q, n }, r) } // TODO rename to make it clear that is not mod q, but mod X^N+1 // apply mod (X^N+1) -pub fn modulus(p: &mut Vec) { - if p.len() < N { +pub fn modulus(n: usize, p: &mut Vec) { + if p.len() < n { return; } - for i in N..p.len() { - p[i - N] = p[i - N].clone() - p[i].clone(); + for i in n..p.len() { + p[i - n] = p[i - n].clone() - p[i].clone(); p[i] = 0; } - p.truncate(N); + p.truncate(n); } -pub fn modulus_i128(p: &mut Vec) { - if p.len() < N { +pub fn modulus_i128(n: usize, p: &mut Vec) { + if p.len() < n { return; } - for i in N..p.len() { - p[i - N] = p[i - N].clone() - p[i].clone(); + for i in n..p.len() { + p[i - n] = p[i - n].clone() - p[i].clone(); p[i] = 0; } - p.truncate(N); + p.truncate(n); } -impl PartialEq for R { +impl PartialEq for R { fn eq(&self, other: &Self) -> bool { - self.0 == other.0 + self.coeffs == other.coeffs && self.n == other.n } } -impl Add> for R { +impl Add for R { type Output = Self; fn add(self, rhs: Self) -> Self { - Self(array::from_fn(|i| self.0[i] + rhs.0[i])) + assert_eq!(self.n, rhs.n); + Self { + n: self.n, + coeffs: zip_eq(self.coeffs, rhs.coeffs) + .map(|(l, r)| l + r) + .collect(), + } } } -impl Add<&R> for &R { - type Output = R; - - fn add(self, rhs: &R) -> Self::Output { - R(array::from_fn(|i| self.0[i] + rhs.0[i])) +impl Add<&R> for &R { + type Output = R; + + fn add(self, rhs: &R) -> Self::Output { + assert_eq!(self.n, rhs.n); + R { + n: self.n, + coeffs: zip_eq(self.coeffs.clone(), rhs.coeffs.clone()) + .map(|(l, r)| l + r) + .collect(), + } } } -impl AddAssign for R { +impl AddAssign for R { fn add_assign(&mut self, rhs: Self) { - for i in 0..N { - self.0[i] += rhs.0[i]; + assert_eq!(self.n, rhs.n); + for i in 0..self.n { + self.coeffs[i] += rhs.coeffs[i]; } } } -impl Sum> for R { - fn sum(iter: I) -> Self +impl Sum for R { + fn sum(mut iter: I) -> Self where I: Iterator, { - let mut acc = R::::zero(); - for e in iter { - acc += e; - } - acc + let first = iter.next().unwrap(); + iter.fold(first, |acc, x| acc + x) } } -impl Sub> for R { +impl Sub for R { type Output = Self; fn sub(self, rhs: Self) -> Self { - Self(array::from_fn(|i| self.0[i] - rhs.0[i])) + assert_eq!(self.n, rhs.n); + Self { + n: self.n, + coeffs: zip_eq(self.coeffs, rhs.coeffs) + .map(|(l, r)| l - r) + .collect(), + } } } -impl Sub<&R> for &R { - type Output = R; - - fn sub(self, rhs: &R) -> Self::Output { - R(array::from_fn(|i| self.0[i] - rhs.0[i])) +impl Sub<&R> for &R { + type Output = R; + + fn sub(self, rhs: &R) -> Self::Output { + assert_eq!(self.n, rhs.n); + R { + n: self.n, + coeffs: zip_eq(&self.coeffs, &rhs.coeffs) + .map(|(l, r)| l - r) + .collect(), + } } } -impl SubAssign for R { +impl SubAssign for R { fn sub_assign(&mut self, rhs: Self) { - for i in 0..N { - self.0[i] -= rhs.0[i]; + assert_eq!(self.n, rhs.n); + for i in 0..self.n { + self.coeffs[i] -= rhs.coeffs[i]; } } } -impl Mul> for R { +impl Mul for R { type Output = Self; fn mul(self, rhs: Self) -> Self { naive_poly_mul(&self, &rhs) } } -impl Mul<&R> for &R { - type Output = R; +impl Mul<&R> for &R { + type Output = R; - fn mul(self, rhs: &R) -> Self::Output { + fn mul(self, rhs: &R) -> Self::Output { naive_poly_mul(self, rhs) } } // TODO WIP -pub fn naive_poly_mul(poly1: &R, poly2: &R) -> R { - let poly1: Vec = poly1.0.iter().map(|c| *c as i128).collect(); - let poly2: Vec = poly2.0.iter().map(|c| *c as i128).collect(); - let mut result: Vec = vec![0; (N * 2) - 1]; - for i in 0..N { - for j in 0..N { +pub fn naive_poly_mul(poly1: &R, poly2: &R) -> R { + assert_eq!(poly1.n, poly2.n); + let n = poly1.n; + + let poly1: Vec = poly1.coeffs.iter().map(|c| *c as i128).collect(); + let poly2: Vec = poly2.coeffs.iter().map(|c| *c as i128).collect(); + let mut result: Vec = vec![0; (n * 2) - 1]; + for i in 0..n { + for j in 0..n { result[i + j] = result[i + j] + poly1[i] * poly2[j]; } } // apply mod (X^N + 1)) // R::::from_vec(result.iter().map(|c| *c as i64).collect()) - modulus_i128::(&mut result); + modulus_i128(n, &mut result); // dbg!(&result); // dbg!(R::(array::from_fn(|i| result[i] as i64)).coeffs()); + let result_i64: Vec = result.iter().map(|c_i| *c_i as i64).collect(); + let r = R::from_vec(n, result_i64); // sanity check: check that there are no coeffs > i64_max assert_eq!( result, - R::(array::from_fn(|i| result[i] as i64)) - .coeffs() - .iter() - .map(|c| *c as i128) - .collect::>() + r.coeffs.iter().map(|c| *c as i128).collect::>() ); - R(array::from_fn(|i| result[i] as i64)) + r } -pub fn naive_mul_2(poly1: &Vec, poly2: &Vec) -> Vec { - let mut result: Vec = vec![0; (N * 2) - 1]; - for i in 0..N { - for j in 0..N { +pub fn naive_mul_2(n: usize, poly1: &Vec, poly2: &Vec) -> Vec { + let mut result: Vec = vec![0; (n * 2) - 1]; + for i in 0..n { + for j in 0..n { result[i + j] = result[i + j] + poly1[i] * poly2[j]; } } // apply mod (X^N + 1)) // R::::from_vec(result.iter().map(|c| *c as i64).collect()) - modulus_i128::(&mut result); + modulus_i128(n, &mut result); result } -pub fn naive_mul(poly1: &R, poly2: &R) -> Vec { - let poly1: Vec = poly1.0.iter().map(|c| *c as i128).collect(); - let poly2: Vec = poly2.0.iter().map(|c| *c as i128).collect(); - let mut result = vec![0; (N * 2) - 1]; - for i in 0..N { - for j in 0..N { +pub fn naive_mul(poly1: &R, poly2: &R) -> Vec { + assert_eq!(poly1.n, poly2.n); + let n = poly1.n; + + let poly1: Vec = poly1.coeffs.iter().map(|c| *c as i128).collect(); + let poly2: Vec = poly2.coeffs.iter().map(|c| *c as i128).collect(); + let mut result = vec![0; (n * 2) - 1]; + for i in 0..n { + for j in 0..n { result[i + j] = result[i + j] + poly1[i] * poly2[j]; } } result.iter().map(|c| *c as i64).collect() } -pub fn naive_mul_TMP(poly1: &R, poly2: &R) -> Vec { - let poly1: Vec = poly1.0.iter().map(|c| *c as i128).collect(); - let poly2: Vec = poly2.0.iter().map(|c| *c as i128).collect(); - let mut result: Vec = vec![0; (N * 2) - 1]; - for i in 0..N { - for j in 0..N { +pub fn naive_mul_TMP(poly1: &R, poly2: &R) -> Vec { + assert_eq!(poly1.n, poly2.n); + let n = poly1.n; + + let poly1: Vec = poly1.coeffs.iter().map(|c| *c as i128).collect(); + let poly2: Vec = poly2.coeffs.iter().map(|c| *c as i128).collect(); + let mut result: Vec = vec![0; (n * 2) - 1]; + for i in 0..n { + for j in 0..n { result[i + j] = result[i + j] + poly1[i] * poly2[j]; } } // dbg!(&result); - modulus_i128::(&mut result); + modulus_i128(n, &mut result); // for c_i in result.iter() { // println!("---"); // println!("{:?}", &c_i); @@ -316,8 +344,8 @@ pub fn naive_mul_TMP(poly1: &R, poly2: &R) -> Vec { } // wip -pub fn mod_centered_q(p: Vec) -> R { - let q: i128 = Q as i128; +pub fn mod_centered_q(q: u64, n: usize, p: Vec) -> R { + let q: i128 = q as i128; let r = p .iter() .map(|v| { @@ -328,10 +356,10 @@ pub fn mod_centered_q(p: Vec) -> R { res }) .collect::>(); - R::::from_vec(r.iter().map(|v| *v as i64).collect::>()) + R::from_vec(n, r.iter().map(|v| *v as i64).collect::>()) } -impl Mul for R { +impl Mul for R { type Output = Self; fn mul(self, s: i64) -> Self { @@ -339,34 +367,38 @@ impl Mul for R { } } // mul by u64 -impl Mul for R { +impl Mul for R { type Output = Self; fn mul(self, s: u64) -> Self { self.mul_by_i64(s as i64) } } -impl Mul<&u64> for &R { - type Output = R; +impl Mul<&u64> for &R { + type Output = R; fn mul(self, s: &u64) -> Self::Output { self.mul_by_i64(*s as i64) } } -impl Neg for R { +impl Neg for R { type Output = Self; fn neg(self) -> Self::Output { - Self(array::from_fn(|i| -self.0[i])) + // Self(array::from_fn(|i| -self.0[i])) + Self { + n: self.n, + coeffs: self.coeffs.iter().map(|c_i| -c_i).collect(), + } } } -impl R { +impl R { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let mut str = ""; let mut zero = true; - for (i, coeff) in self.0.iter().enumerate().rev() { + for (i, coeff) in self.coeffs.iter().enumerate().rev() { if *coeff == 0 { continue; } @@ -395,18 +427,18 @@ impl R { f.write_str(" mod Z")?; f.write_str("/(X^")?; - f.write_str(N.to_string().as_str())?; + f.write_str(self.n.to_string().as_str())?; f.write_str("+1)")?; Ok(()) } } -impl fmt::Display for R { +impl fmt::Display for R { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.fmt(f)?; Ok(()) } } -impl fmt::Debug for R { +impl fmt::Debug for R { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.fmt(f)?; Ok(()) @@ -420,38 +452,33 @@ mod tests { #[test] fn test_mul() -> Result<()> { - const Q: u64 = 2u64.pow(16) + 1; - const N: usize = 2; - let q: i64 = Q as i64; + let n: usize = 2; + let q: i64 = (2u64.pow(16) + 1) as i64; // test vectors generated with SageMath - let a: [i64; N] = [q - 1, q - 1]; - let b: [i64; N] = [q - 1, q - 1]; - let c: [i64; N] = [0, 8589934592]; - test_mul_opt::(a, b, c)?; + let a: Vec = vec![q - 1, q - 1]; + let b: Vec = vec![q - 1, q - 1]; + let c: Vec = vec![0, 8589934592]; + test_mul_opt(n, a, b, c)?; - let a: [i64; N] = [1, q - 1]; - let b: [i64; N] = [1, q - 1]; - let c: [i64; N] = [-4294967295, 131072]; - test_mul_opt::(a, b, c)?; + let a: Vec = vec![1, q - 1]; + let b: Vec = vec![1, q - 1]; + let c: Vec = vec![-4294967295, 131072]; + test_mul_opt(n, a, b, c)?; Ok(()) } - fn test_mul_opt( - a: [i64; N], - b: [i64; N], - expected_c: [i64; N], - ) -> Result<()> { - let mut a = R::new(a); - let mut b = R::new(b); + fn test_mul_opt(n: usize, a: Vec, b: Vec, expected_c: Vec) -> Result<()> { + let mut a = R::new(n, a); + let mut b = R::new(n, b); dbg!(&a); dbg!(&b); - let expected_c = R::new(expected_c); + let expected_c = R::new(n, expected_c); let mut c = naive_mul(&mut a, &mut b); - modulus::(&mut c); - dbg!(R::::from_vec(c.clone())); - assert_eq!(c, expected_c.0.to_vec()); + modulus(n, &mut c); + dbg!(R::from_vec(n, c.clone())); + assert_eq!(c, expected_c.coeffs); Ok(()) } } diff --git a/arith/src/ring_nq.rs b/arith/src/ring_nq.rs index 7510ac3..d51c1c2 100644 --- a/arith/src/ring_nq.rs +++ b/arith/src/ring_nq.rs @@ -2,8 +2,8 @@ //! use anyhow::{anyhow, Result}; +use itertools::zip_eq; use rand::{distributions::Distribution, Rng}; -use std::array; use std::fmt; use std::iter::Sum; use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign}; @@ -11,82 +11,91 @@ use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign}; use crate::ntt::NTT; use crate::zq::{modulus_u64, Zq}; -use crate::Ring; +use crate::{Ring, RingParam}; -// NOTE: currently using fixed-size arrays, but pending to see if with -// real-world parameters the stack can keep up; if not will move everything to -// use Vec. /// PolynomialRing element, where the PolynomialRing is R = Z_q[X]/(X^n +1) /// The implementation assumes that q is prime. -#[derive(Clone, Copy)] -pub struct Rq { - pub(crate) coeffs: [Zq; N], +#[derive(Clone)] +pub struct Rq { + pub param: RingParam, + + pub(crate) coeffs: Vec, // evals are set when doig a PRxPR multiplication, so it can be reused in future // multiplications avoiding recomputing it - pub(crate) evals: Option<[Zq; N]>, + pub(crate) evals: Option>, } -impl Ring for Rq { - type C = Zq; - - const Q: u64 = Q; - const N: usize = N; +impl Ring for Rq { + type C = Zq; + fn param(&self) -> RingParam { + self.param + } fn coeffs(&self) -> Vec { self.coeffs.to_vec() } - fn zero() -> Self { - let coeffs = array::from_fn(|_| Zq::zero()); + fn zero(param: &RingParam) -> Self { Self { - coeffs, + param: param.clone(), + coeffs: vec![Zq::zero(param.q); param.n], evals: None, } } - fn rand(mut rng: impl Rng, dist: impl Distribution) -> Self { - // let coeffs: [Zq; N] = array::from_fn(|_| Zq::from_u64(dist.sample(&mut rng))); - let coeffs: [Zq; N] = array::from_fn(|_| Self::C::rand(&mut rng, &dist)); + fn rand(mut rng: impl Rng, dist: impl Distribution, param: &RingParam) -> Self { Self { - coeffs, + param: param.clone(), + coeffs: std::iter::repeat_with(|| Self::C::rand(&mut rng, &dist, param.q)) + .take(param.n) + .collect(), evals: None, } } - fn from_vec(coeffs: Vec>) -> Self { + fn from_vec(param: &RingParam, coeffs: Vec) -> Self { let mut p = coeffs; - modulus::(&mut p); - let coeffs = array::from_fn(|i| p[i]); + modulus(param.q, param.n, &mut p); Self { - coeffs, + param: param.clone(), + coeffs: p, evals: None, } } // returns the decomposition of each polynomial coefficient, such - // decomposition will be a vecotor of length N, containint N vectors of Zq + // decomposition will be a vector of length N, containing N vectors of Zq fn decompose(&self, beta: u32, l: u32) -> Vec { - let elems: Vec>> = self.coeffs.iter().map(|r| r.decompose(beta, l)).collect(); + let elems: Vec> = self.coeffs.iter().map(|r| r.decompose(beta, l)).collect(); // transpose it - let r: Vec>> = (0..elems[0].len()) + let r: Vec> = (0..elems[0].len()) .map(|i| (0..elems.len()).map(|j| elems[j][i]).collect()) .collect(); // convert it to Rq - r.iter().map(|a_i| Self::from_vec(a_i.clone())).collect() + r.iter() + .map(|a_i| Self::from_vec(&self.param, a_i.clone())) + .collect() } // Warning: this method will behave differently depending on the values P and Q: // if Q=P, it crops to mod P - fn remodule(&self) -> Rq { - Rq::::from_vec_u64(self.coeffs().iter().map(|m_i| m_i.0).collect()) + fn remodule(&self, p: u64) -> Rq { + let param = RingParam { + q: p, + n: self.param.n, + }; + Rq::from_vec_u64(¶m, self.coeffs().iter().map(|m_i| m_i.v).collect()) } /// perform the mod switch operation from Q to Q', where Q2=Q' - // fn mod_switch(&self) -> impl Ring { - fn mod_switch(&self) -> Rq { - // assert_eq!(N, M); // sanity check - Rq:: { - coeffs: array::from_fn(|i| self.coeffs[i].mod_switch::

()), + fn mod_switch(&self, p: u64) -> Rq { + let param = RingParam { + q: p, + n: self.param.n, + }; + Rq { + param, + coeffs: self.coeffs.iter().map(|c_i| c_i.mod_switch(p)).collect(), evals: None, } } @@ -98,105 +107,138 @@ impl Ring for Rq { let r: Vec = self .coeffs() .iter() - .map(|e| ((num as f64 * e.0 as f64) / den as f64).round()) + .map(|e| ((num as f64 * e.v as f64) / den as f64).round()) .collect(); - Rq::::from_vec_f64(r) + Rq::from_vec_f64(&self.param, r) } } -impl From> for Rq { - fn from(r: crate::ring_n::R) -> Self { +impl From<(u64, crate::ring_n::R)> for Rq { + fn from(qr: (u64, crate::ring_n::R)) -> Self { + let (q, r) = qr; + assert_eq!(r.n, r.coeffs.len()); + Self::from_vec( + &RingParam { q, n: r.n }, r.coeffs() .iter() - .map(|e| Zq::::from_f64(*e as f64)) + .map(|e| Zq::from_f64(q, *e as f64)) .collect(), ) } } // apply mod (X^N+1) -pub fn modulus(p: &mut Vec>) { - if p.len() < N { +pub fn modulus(q: u64, n: usize, p: &mut Vec) { + if p.len() < n { return; } - for i in N..p.len() { - p[i - N] = p[i - N].clone() - p[i].clone(); - p[i] = Zq(0); + for i in n..p.len() { + p[i - n] = p[i - n].clone() - p[i].clone(); + p[i] = Zq::zero(q); } - p.truncate(N); + p.truncate(n); } -// PR stands for PolynomialRing -impl Rq { - pub fn coeffs(&self) -> [Zq; N] { - self.coeffs +impl Rq { + pub fn coeffs(&self) -> Vec { + self.coeffs.clone() } pub fn compute_evals(&mut self) { - self.evals = Some(NTT::::ntt(self.coeffs)); + self.evals = Some(NTT::ntt(self).coeffs); + // TODO improve, ntt returns Rq but here just needs Vec } - pub fn to_r(self) -> crate::R { - crate::R::::from(self) + pub fn to_r(self) -> crate::R { + crate::R::from(self) } - // TODO rm since it is implemented in Ring trait impl - // pub fn zero() -> Self { - // let coeffs = array::from_fn(|_| Zq::zero()); - // Self { - // coeffs, - // evals: None, - // } - // } // this method is mostly for tests - pub fn from_vec_u64(coeffs: Vec) -> Self { - let coeffs_mod_q = coeffs.iter().map(|c| Zq::from_u64(*c)).collect(); - Self::from_vec(coeffs_mod_q) + pub fn from_vec_u64(param: &RingParam, coeffs: Vec) -> Self { + let coeffs_mod_q: Vec = coeffs.iter().map(|c| Zq::from_u64(param.q, *c)).collect(); + Self::from_vec(param, coeffs_mod_q) } - pub fn from_vec_f64(coeffs: Vec) -> Self { - let coeffs_mod_q = coeffs.iter().map(|c| Zq::from_f64(*c)).collect(); - Self::from_vec(coeffs_mod_q) + pub fn from_vec_f64(param: &RingParam, coeffs: Vec) -> Self { + let coeffs_mod_q: Vec = coeffs.iter().map(|c| Zq::from_f64(param.q, *c)).collect(); + Self::from_vec(param, coeffs_mod_q) } - pub fn from_vec_i64(coeffs: Vec) -> Self { - let coeffs_mod_q = coeffs.iter().map(|c| Zq::from_f64(*c as f64)).collect(); - Self::from_vec(coeffs_mod_q) + pub fn from_vec_i64(param: &RingParam, coeffs: Vec) -> Self { + let coeffs_mod_q: Vec = coeffs + .iter() + .map(|c| Zq::from_f64(param.q, *c as f64)) + .collect(); + Self::from_vec(param, coeffs_mod_q) } - pub fn new(coeffs: [Zq; N], evals: Option<[Zq; N]>) -> Self { - Self { coeffs, evals } + pub fn new(param: &RingParam, coeffs: Vec, evals: Option>) -> Self { + Self { + param: *param, + coeffs, + evals, + } } - pub fn rand_abs(mut rng: impl Rng, dist: impl Distribution) -> Result { - let coeffs: [Zq; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng).abs())); + pub fn rand_abs( + mut rng: impl Rng, + dist: impl Distribution, + param: &RingParam, + ) -> Result { Ok(Self { - coeffs, + param: *param, + coeffs: std::iter::repeat_with(|| Zq::from_f64(param.q, dist.sample(&mut rng).abs())) + .take(param.n) + .collect(), evals: None, }) } - pub fn rand_f64_abs(mut rng: impl Rng, dist: impl Distribution) -> Result { - let coeffs: [Zq; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng).abs())); + pub fn rand_f64_abs( + mut rng: impl Rng, + dist: impl Distribution, + param: &RingParam, + ) -> Result { Ok(Self { - coeffs, + param: *param, + coeffs: std::iter::repeat_with(|| Zq::from_f64(param.q, dist.sample(&mut rng).abs())) + .take(param.n) + .collect(), evals: None, }) } - pub fn rand_f64(mut rng: impl Rng, dist: impl Distribution) -> Result { - let coeffs: [Zq; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng))); + pub fn rand_f64( + mut rng: impl Rng, + dist: impl Distribution, + param: &RingParam, + ) -> Result { Ok(Self { - coeffs, + param: *param, + coeffs: std::iter::repeat_with(|| Zq::from_f64(param.q, dist.sample(&mut rng))) + .take(param.n) + .collect(), evals: None, }) } - pub fn rand_u64(mut rng: impl Rng, dist: impl Distribution) -> Result { - let coeffs: [Zq; N] = array::from_fn(|_| Zq::from_u64(dist.sample(&mut rng))); + pub fn rand_u64( + mut rng: impl Rng, + dist: impl Distribution, + param: &RingParam, + ) -> Result { Ok(Self { - coeffs, + param: *param, + coeffs: std::iter::repeat_with(|| Zq::from_u64(param.q, dist.sample(&mut rng))) + .take(param.n) + .collect(), evals: None, }) } // WIP. returns random v \in {0,1}. // TODO {-1, 0, 1} - pub fn rand_bin(mut rng: impl Rng, dist: impl Distribution) -> Result { - let coeffs: [Zq; N] = array::from_fn(|_| Zq::from_bool(dist.sample(&mut rng))); + pub fn rand_bin( + mut rng: impl Rng, + dist: impl Distribution, + param: &RingParam, + ) -> Result { Ok(Rq { - coeffs, + param: *param, + coeffs: std::iter::repeat_with(|| Zq::from_bool(param.q, dist.sample(&mut rng))) + .take(param.n) + .collect(), evals: None, }) } @@ -208,36 +250,43 @@ impl Rq { // } // applies mod(T) to all coefficients of self - pub fn coeffs_mod(&self) -> Self { - Rq::::from_vec_u64( + pub fn coeffs_mod(&self, param: &RingParam, t: u64) -> Self { + Rq::from_vec_u64( + param, self.coeffs() .iter() - .map(|m_i| modulus_u64::(m_i.0)) + .map(|m_i| modulus_u64(t, m_i.v)) .collect(), ) } // TODO review if needed, or if with this interface - pub fn mul_by_matrix(&self, m: &Vec>>) -> Result>> { + pub fn mul_by_matrix(&self, m: &Vec>) -> Result> { matrix_vec_product(m, &self.coeffs.to_vec()) } - pub fn mul_by_zq(&self, s: &Zq) -> Self { + pub fn mul_by_zq(&self, s: &Zq) -> Self { Self { - coeffs: array::from_fn(|i| self.coeffs[i] * *s), + param: self.param, + coeffs: self.coeffs.iter().map(|c_i| *c_i * *s).collect(), evals: None, } } pub fn mul_by_u64(&self, s: u64) -> Self { - let s = Zq::from_u64(s); + let s = Zq::from_u64(self.param.q, s); Self { - coeffs: array::from_fn(|i| self.coeffs[i] * s), - // coeffs: self.coeffs.iter().map(|&e| e * s).collect(), + param: self.param, + coeffs: self.coeffs.iter().map(|&e| e * s).collect(), evals: None, } } pub fn mul_by_f64(&self, s: f64) -> Self { Self { - coeffs: array::from_fn(|i| Zq::from_f64(self.coeffs[i].0 as f64 * s)), + param: self.param, + coeffs: self + .coeffs + .iter() + .map(|c_i| Zq::from_f64(self.param.q, c_i.v as f64 * s)) + .collect(), evals: None, } } @@ -251,9 +300,9 @@ impl Rq { let r: Vec = self .coeffs() .iter() - .map(|e| (e.0 as f64 / s as f64).round()) + .map(|e| (e.v as f64 / s as f64).round()) .collect(); - Rq::::from_vec_f64(r) + Rq::from_vec_f64(&self.param, r) } fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -261,19 +310,19 @@ impl Rq { let mut str = ""; let mut zero = true; for (i, coeff) in self.coeffs.iter().enumerate().rev() { - if coeff.0 == 0 { + if coeff.v == 0 { continue; } zero = false; f.write_str(str)?; - if coeff.0 != 1 { - f.write_str(coeff.0.to_string().as_str())?; + if coeff.v != 1 { + f.write_str(coeff.v.to_string().as_str())?; if i > 0 { f.write_str("*")?; } } - if coeff.0 == 1 && i == 0 { - f.write_str(coeff.0.to_string().as_str())?; + if coeff.v == 1 && i == 0 { + f.write_str(coeff.v.to_string().as_str())?; } if i == 1 { f.write_str("x")?; @@ -288,9 +337,9 @@ impl Rq { } f.write_str(" mod Z_")?; - f.write_str(Q.to_string().as_str())?; + f.write_str(self.param.q.to_string().as_str())?; f.write_str("/(X^")?; - f.write_str(N.to_string().as_str())?; + f.write_str(self.param.n.to_string().as_str())?; f.write_str("+1)")?; Ok(()) } @@ -298,16 +347,20 @@ impl Rq { pub fn infinity_norm(&self) -> u64 { self.coeffs() .iter() - .map(|x| if x.0 > (Q / 2) { Q - x.0 } else { x.0 }) + .map(|x| { + if x.v > (self.param.q / 2) { + self.param.q - x.v + } else { + x.v + } + }) .fold(0, |a, b| a.max(b)) } - pub fn mod_centered_q(&self) -> crate::ring_n::R { - self.to_r().mod_centered_q::() + pub fn mod_centered_q(&self) -> crate::ring_n::R { + self.clone().to_r().mod_centered_q(self.param.q) } } -pub fn matrix_vec_product(m: &Vec>>, v: &Vec>) -> Result>> { - // assert_eq!(m.len(), m[0].len()); // TODO change to returning err - // assert_eq!(m.len(), v.len()); +pub fn matrix_vec_product(m: &Vec>, v: &Vec) -> Result> { if m.len() != m[0].len() { return Err(anyhow!("expected 'm' to be a square matrix")); } @@ -319,6 +372,8 @@ pub fn matrix_vec_product(m: &Vec>>, v: &Vec>) -> )); } + assert_eq!(m[0][0].q, v[0].q); // TODO change to returning err + Ok(m.iter() .map(|row| { row.iter() @@ -326,12 +381,15 @@ pub fn matrix_vec_product(m: &Vec>>, v: &Vec>) -> .map(|(&row_i, &v_i)| row_i * v_i) .sum() }) - .collect::>>()) + .collect::>()) } -pub fn transpose(m: &[Vec>]) -> Vec>> { +pub fn transpose(m: &[Vec]) -> Vec> { + assert!(m.len() > 0); + assert!(m[0].len() > 0); + let q = m[0][0].q; // TODO case when m[0].len()=0 // TODO non square matrix - let mut r: Vec>> = vec![vec![Zq(0); m[0].len()]; m.len()]; + let mut r: Vec> = vec![vec![Zq::zero(q); m[0].len()]; m.len()]; for (i, m_row) in m.iter().enumerate() { for (j, m_ij) in m_row.iter().enumerate() { r[j][i] = *m_ij; @@ -340,205 +398,221 @@ pub fn transpose(m: &[Vec>]) -> Vec>> { r } -impl PartialEq for Rq { +impl PartialEq for Rq { fn eq(&self, other: &Self) -> bool { - self.coeffs == other.coeffs + self.coeffs == other.coeffs && self.param == other.param } } -impl Add> for Rq { +impl Add for Rq { type Output = Self; fn add(self, rhs: Self) -> Self { + assert_eq!(self.param, rhs.param); Self { - coeffs: array::from_fn(|i| self.coeffs[i] + rhs.coeffs[i]), + param: self.param, + coeffs: zip_eq(self.coeffs, rhs.coeffs) + .map(|(l, r)| l + r) + .collect(), evals: None, } - // Self { - // coeffs: self - // .coeffs - // .iter() - // .zip(rhs.coeffs) - // .map(|(a, b)| *a + b) - // .collect(), - // evals: None, - // } - // Self(r.iter_mut().map(|e| e.r#mod()).collect()) // TODO mod should happen auto in + } } -impl Add<&Rq> for &Rq { - type Output = Rq; +impl Add<&Rq> for &Rq { + type Output = Rq; - fn add(self, rhs: &Rq) -> Self::Output { + fn add(self, rhs: &Rq) -> Self::Output { + assert_eq!(self.param, rhs.param); Rq { - coeffs: array::from_fn(|i| self.coeffs[i] + rhs.coeffs[i]), + param: self.param, + coeffs: zip_eq(self.coeffs.clone(), rhs.coeffs.clone()) + .map(|(l, r)| l + r) + .collect(), evals: None, } } } -impl AddAssign for Rq { +impl AddAssign for Rq { fn add_assign(&mut self, rhs: Self) { - for i in 0..N { + debug_assert_eq!(self.param, rhs.param); + for i in 0..self.param.n { self.coeffs[i] += rhs.coeffs[i]; } } } -impl Sum> for Rq { - fn sum(iter: I) -> Self +impl Sum for Rq { + fn sum(mut iter: I) -> Self where I: Iterator, { - let mut acc = Rq::::zero(); - for e in iter { - acc += e; - } - acc + let first = iter.next().unwrap(); + iter.fold(first, |acc, x| acc + x) } } -impl Sub> for Rq { +impl Sub for Rq { type Output = Self; fn sub(self, rhs: Self) -> Self { + assert_eq!(self.param, rhs.param); Self { - coeffs: array::from_fn(|i| self.coeffs[i] - rhs.coeffs[i]), + param: self.param, + coeffs: zip_eq(self.coeffs, rhs.coeffs) + .map(|(l, r)| l - r) + .collect(), evals: None, } } } -impl Sub<&Rq> for &Rq { - type Output = Rq; +impl Sub<&Rq> for &Rq { + type Output = Rq; - fn sub(self, rhs: &Rq) -> Self::Output { + fn sub(self, rhs: &Rq) -> Self::Output { + debug_assert_eq!(self.param, rhs.param); Rq { - coeffs: array::from_fn(|i| self.coeffs[i] - rhs.coeffs[i]), + param: self.param, + coeffs: zip_eq(self.coeffs.clone(), rhs.coeffs.clone()) + .map(|(l, r)| l - r) + .collect(), evals: None, } } } -impl SubAssign for Rq { +impl SubAssign for Rq { fn sub_assign(&mut self, rhs: Self) { - for i in 0..N { + debug_assert_eq!(self.param, rhs.param); + for i in 0..self.param.n { self.coeffs[i] -= rhs.coeffs[i]; } } } -impl Mul> for Rq { +impl Mul for Rq { type Output = Self; fn mul(self, rhs: Self) -> Self { mul(&self, &rhs) } } -impl Mul<&Rq> for &Rq { - type Output = Rq; +impl Mul<&Rq> for &Rq { + type Output = Rq; - fn mul(self, rhs: &Rq) -> Self::Output { + fn mul(self, rhs: &Rq) -> Self::Output { mul(self, rhs) } } // mul by Zq element -impl Mul> for Rq { +impl Mul for Rq { type Output = Self; - fn mul(self, s: Zq) -> Self { + fn mul(self, s: Zq) -> Self { self.mul_by_zq(&s) } } -impl Mul<&Zq> for &Rq { - type Output = Rq; +impl Mul<&Zq> for &Rq { + type Output = Rq; - fn mul(self, s: &Zq) -> Self::Output { + fn mul(self, s: &Zq) -> Self::Output { self.mul_by_zq(s) } } // mul by u64 -impl Mul for Rq { +impl Mul for Rq { type Output = Self; fn mul(self, s: u64) -> Self { self.mul_by_u64(s) } } -impl Mul<&u64> for &Rq { - type Output = Rq; +impl Mul<&u64> for &Rq { + type Output = Rq; fn mul(self, s: &u64) -> Self::Output { self.mul_by_u64(*s) } } // mul by f64 -impl Mul for Rq { +impl Mul for Rq { type Output = Self; fn mul(self, s: f64) -> Self { self.mul_by_f64(s) } } -impl Mul<&f64> for &Rq { - type Output = Rq; +impl Mul<&f64> for &Rq { + type Output = Rq; fn mul(self, s: &f64) -> Self::Output { self.mul_by_f64(*s) } } -impl Neg for Rq { +impl Neg for Rq { type Output = Self; fn neg(self) -> Self::Output { Self { - coeffs: array::from_fn(|i| -self.coeffs[i]), + param: self.param, + coeffs: self.coeffs.iter().map(|c_i| -*c_i).collect(), evals: None, } } } // note: this assumes that Q is prime -fn mul_mut(lhs: &mut Rq, rhs: &mut Rq) -> Rq { +fn mul_mut(lhs: &mut Rq, rhs: &mut Rq) -> Rq { + assert_eq!(lhs.param, rhs.param); + // reuse evaluations if already computed if !lhs.evals.is_some() { - lhs.evals = Some(NTT::::ntt(lhs.coeffs)); + lhs.evals = Some(NTT::ntt(lhs).coeffs); }; if !rhs.evals.is_some() { - rhs.evals = Some(NTT::::ntt(rhs.coeffs)); + rhs.evals = Some(NTT::ntt(rhs).coeffs); }; - let lhs_evals = lhs.evals.unwrap(); - let rhs_evals = rhs.evals.unwrap(); - - let c_ntt: [Zq; N] = array::from_fn(|i| lhs_evals[i] * rhs_evals[i]); - let c = NTT::::intt(c_ntt); - Rq::new(c, Some(c_ntt)) + let lhs_evals = lhs.evals.clone().unwrap(); + let rhs_evals = rhs.evals.clone().unwrap(); + + let c_ntt: Rq = Rq::from_vec( + &lhs.param, + zip_eq(lhs_evals, rhs_evals).map(|(l, r)| l * r).collect(), + ); + let c = NTT::intt(&c_ntt); + Rq::new(&lhs.param, c.coeffs, Some(c_ntt.coeffs)) } // note: this assumes that Q is prime -// TODO impl karatsuba for non-prime Q -fn mul(lhs: &Rq, rhs: &Rq) -> Rq { +// TODO impl karatsuba for non-prime Q. Alternatively check NTT with RNS trick. +fn mul(lhs: &Rq, rhs: &Rq) -> Rq { + assert_eq!(lhs.param, rhs.param); + // reuse evaluations if already computed - let lhs_evals = if lhs.evals.is_some() { - lhs.evals.unwrap() + let lhs_evals: Vec = if lhs.evals.is_some() { + lhs.evals.clone().unwrap() } else { - NTT::::ntt(lhs.coeffs) + NTT::ntt(lhs).coeffs }; - let rhs_evals = if rhs.evals.is_some() { - rhs.evals.unwrap() + let rhs_evals: Vec = if rhs.evals.is_some() { + rhs.evals.clone().unwrap() } else { - NTT::::ntt(rhs.coeffs) + NTT::ntt(rhs).coeffs }; - let c_ntt: [Zq; N] = array::from_fn(|i| lhs_evals[i] * rhs_evals[i]); - let c = NTT::::intt(c_ntt); - Rq::new(c, Some(c_ntt)) + let c_ntt: Rq = Rq::from_vec( + &lhs.param, + zip_eq(lhs_evals, rhs_evals).map(|(l, r)| l * r).collect(), + ); + let c = NTT::intt(&c_ntt); + Rq::new(&lhs.param, c.coeffs, Some(c_ntt.coeffs)) } -impl fmt::Display for Rq { +impl fmt::Display for Rq { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.fmt(f)?; Ok(()) } } -impl fmt::Debug for Rq { +impl fmt::Debug for Rq { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.fmt(f)?; Ok(()) @@ -552,31 +626,30 @@ mod tests { #[test] fn test_polynomial_ring() { // the test values used are generated with SageMath - const Q: u64 = 7; - const N: usize = 3; + let param = RingParam { q: 7, n: 3 }; // p = 1x + 2x^2 + 3x^3 + 4 x^4 + 5 x^5 in R=Z_q[X]/(X^n +1) - let p = Rq::::from_vec_u64(vec![0u64, 1, 2, 3, 4, 5]); + let p = Rq::from_vec_u64(¶m, vec![0u64, 1, 2, 3, 4, 5]); assert_eq!(p.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)"); // try with coefficients bigger than Q - let p = Rq::::from_vec_u64(vec![0u64, 1, Q + 2, 3, 4, 5]); + let p = Rq::from_vec_u64(¶m, vec![0u64, 1, param.q + 2, 3, 4, 5]); assert_eq!(p.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)"); // try with other ring - let p = Rq::<7, 4>::from_vec_u64(vec![0u64, 1, 2, 3, 4, 5]); + let p = Rq::from_vec_u64(&RingParam { q: 7, n: 4 }, vec![0u64, 1, 2, 3, 4, 5]); assert_eq!(p.to_string(), "3*x^3 + 2*x^2 + 3*x + 3 mod Z_7/(X^4+1)"); - let p = Rq::::from_vec_u64(vec![0u64, 0, 0, 0, 4, 5]); + let p = Rq::from_vec_u64(¶m, vec![0u64, 0, 0, 0, 4, 5]); assert_eq!(p.to_string(), "2*x^2 + 3*x mod Z_7/(X^3+1)"); - let p = Rq::::from_vec_u64(vec![5u64, 4, 5, 2, 1, 0]); + let p = Rq::from_vec_u64(¶m, vec![5u64, 4, 5, 2, 1, 0]); assert_eq!(p.to_string(), "5*x^2 + 3*x + 3 mod Z_7/(X^3+1)"); - let a = Rq::::from_vec_u64(vec![0u64, 1, 2, 3, 4, 5]); + let a = Rq::from_vec_u64(¶m, vec![0u64, 1, 2, 3, 4, 5]); assert_eq!(a.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)"); - let b = Rq::::from_vec_u64(vec![5u64, 4, 3, 2, 1, 0]); + let b = Rq::from_vec_u64(¶m, vec![5u64, 4, 3, 2, 1, 0]); assert_eq!(b.to_string(), "3*x^2 + 3*x + 3 mod Z_7/(X^3+1)"); // add @@ -593,34 +666,37 @@ mod tests { #[test] fn test_mul() -> Result<()> { - const Q: u64 = 2u64.pow(16) + 1; - const N: usize = 4; + let param = RingParam { + q: 2u64.pow(16) + 1, + n: 4, + }; - let a: [u64; N] = [1u64, 2, 3, 4]; - let b: [u64; N] = [1u64, 2, 3, 4]; - let c: [u64; N] = [65513, 65517, 65531, 20]; - test_mul_opt::(a, b, c)?; + let a: Vec = vec![1u64, 2, 3, 4]; + let b: Vec = vec![1u64, 2, 3, 4]; + let c: Vec = vec![65513, 65517, 65531, 20]; + test_mul_opt(¶m, a, b, c)?; - let a: [u64; N] = [0u64, 0, 0, 2]; - let b: [u64; N] = [0u64, 0, 0, 2]; - let c: [u64; N] = [0u64, 0, 65533, 0]; - test_mul_opt::(a, b, c)?; + let a: Vec = vec![0u64, 0, 0, 2]; + let b: Vec = vec![0u64, 0, 0, 2]; + let c: Vec = vec![0u64, 0, 65533, 0]; + test_mul_opt(¶m, a, b, c)?; // TODO more testvectors Ok(()) } - fn test_mul_opt( - a: [u64; N], - b: [u64; N], - expected_c: [u64; N], + fn test_mul_opt( + param: &RingParam, + a: Vec, + b: Vec, + expected_c: Vec, ) -> Result<()> { - let a: [Zq; N] = array::from_fn(|i| Zq::from_u64(a[i])); - let mut a = Rq::new(a, None); - let b: [Zq; N] = array::from_fn(|i| Zq::from_u64(b[i])); - let mut b = Rq::new(b, None); - let expected_c: [Zq; N] = array::from_fn(|i| Zq::from_u64(expected_c[i])); - let expected_c = Rq::new(expected_c, None); + assert_eq!(a.len(), param.n); + assert_eq!(b.len(), param.n); + + let mut a = Rq::from_vec_u64(¶m, a); + let mut b = Rq::from_vec_u64(¶m, b); + let expected_c = Rq::from_vec_u64(¶m, expected_c); let c = mul_mut(&mut a, &mut b); assert_eq!(c, expected_c); @@ -629,26 +705,25 @@ mod tests { #[test] fn test_rq_decompose() -> Result<()> { - const Q: u64 = 16; - const N: usize = 4; + let param = RingParam { q: 16, n: 4 }; let beta = 4; let l = 2; - let a = Rq::::from_vec_u64(vec![7u64, 14, 3, 6]); + let a = Rq::from_vec_u64(¶m, vec![7u64, 14, 3, 6]); let d = a.decompose(beta, l); assert_eq!( - d[0].coeffs().to_vec(), + d[0].coeffs(), vec![1u64, 3, 0, 1] .iter() - .map(|e| Zq::::from_u64(*e)) + .map(|e| Zq::from_u64(param.q, *e)) .collect::>() ); assert_eq!( - d[1].coeffs().to_vec(), + d[1].coeffs(), vec![3u64, 2, 3, 2] .iter() - .map(|e| Zq::::from_u64(*e)) + .map(|e| Zq::from_u64(param.q, *e)) .collect::>() ); Ok(()) diff --git a/arith/src/ring_torus.rs b/arith/src/ring_torus.rs index 46e134c..7f5d756 100644 --- a/arith/src/ring_torus.rs +++ b/arith/src/ring_torus.rs @@ -7,63 +7,94 @@ //! u64, we fit it into the `Ring` trait (from ring.rs) so that we can compose //! the 𝕋_ implementation with the other objects from the code. +use itertools::zip_eq; use rand::{distributions::Distribution, Rng}; -use std::array; use std::iter::Sum; use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign}; -use crate::{ring::Ring, torus::T64, Rq, Zq}; +use crate::{ + ring::{Ring, RingParam}, + torus::T64, + Rq, Zq, +}; /// 𝕋_[X] = 𝕋[X]/(X^N +1), polynomials modulo X^N+1 with coefficients in /// 𝕋, where Q=2^64. -#[derive(Clone, Copy, Debug)] -pub struct Tn(pub [T64; N]); +#[derive(Clone, Debug)] +pub struct Tn { + pub param: RingParam, + pub coeffs: Vec, +} -impl Ring for Tn { +impl Ring for Tn { type C = T64; - const Q: u64 = u64::MAX; // WIP - const N: usize = N; - + fn param(&self) -> RingParam { + RingParam { + q: u64::MAX, + n: self.param.n, + } + } fn coeffs(&self) -> Vec { - self.0.to_vec() + self.coeffs.to_vec() } - fn zero() -> Self { - Self(array::from_fn(|_| T64::zero())) + fn zero(param: &RingParam) -> Self { + Self { + param: *param, + coeffs: vec![T64::zero(param); param.n], + } } - fn rand(mut rng: impl Rng, dist: impl Distribution) -> Self { - Self(array::from_fn(|_| T64::rand(&mut rng, &dist))) + fn rand(mut rng: impl Rng, dist: impl Distribution, param: &RingParam) -> Self { + Self { + param: *param, + coeffs: std::iter::repeat_with(|| T64::rand(&mut rng, &dist, ¶m)) + .take(param.n) + .collect(), + } } - fn from_vec(coeffs: Vec) -> Self { + fn from_vec(param: &RingParam, coeffs: Vec) -> Self { let mut p = coeffs; - modulus::(&mut p); - Self(array::from_fn(|i| p[i])) + modulus(param, &mut p); + Self { + param: *param, + coeffs: p, + } } fn decompose(&self, beta: u32, l: u32) -> Vec { - let elems: Vec> = self.0.iter().map(|r| r.decompose(beta, l)).collect(); + let elems: Vec> = self.coeffs.iter().map(|r| r.decompose(beta, l)).collect(); // transpose it let r: Vec> = (0..elems[0].len()) .map(|i| (0..elems.len()).map(|j| elems[j][i]).collect()) .collect(); - // convert it to Tn - r.iter().map(|a_i| Self::from_vec(a_i.clone())).collect() + // convert it to Tn + r.iter() + .map(|a_i| Self::from_vec(&self.param, a_i.clone())) + .collect() } - fn remodule(&self) -> Tn { + fn remodule(&self, p: u64) -> Tn { todo!() // Rq::::from_vec_u64(self.coeffs().iter().map(|m_i| m_i.0).collect()) } // fn mod_switch(&self) -> impl Ring { - fn mod_switch(&self) -> Rq { + fn mod_switch(&self, p: u64) -> Rq { // unimplemented!() // TODO WIP - let coeffs = array::from_fn(|i| Zq::

::from_u64(self.0[i].mod_switch::

().0)); - Rq:: { + let coeffs = self + .coeffs + .iter() + .map(|c_i| Zq::from_u64(p, c_i.mod_switch(p).0)) + .collect(); + Rq { + param: RingParam { + q: p, + n: self.param.n, + }, coeffs, evals: None, } @@ -78,175 +109,220 @@ impl Ring for Tn { .iter() .map(|e| T64(((num as f64 * e.0 as f64) / den as f64).round() as u64)) .collect(); - Self::from_vec(r) + Self::from_vec(&self.param, r) } } -impl Tn { +impl Tn { // multiply self by X^-h pub fn left_rotate(&self, h: usize) -> Self { - let h = h % N; - assert!(h < N); - let c = self.0; + let n = self.param.n; + + let h = h % n; + assert!(h < n); + let c = &self.coeffs; // c[h], c[h+1], c[h+2], ..., c[n-1], -c[0], -c[1], ..., -c[h-1] // let r: Vec = vec![c[h..N], c[0..h].iter().map(|&c_i| -c_i).collect()].concat(); - let r: Vec = c[h..N] + let r: Vec = c[h..n] .iter() .copied() .chain(c[0..h].iter().map(|&x| -x)) .collect(); - Self::from_vec(r) + Self::from_vec(&self.param, r) } - pub fn from_vec_u64(v: Vec) -> Self { + pub fn from_vec_u64(param: &RingParam, v: Vec) -> Self { let coeffs = v.iter().map(|c| T64(*c)).collect(); - Self::from_vec(coeffs) + Self::from_vec(param, coeffs) } } // apply mod (X^N+1) -pub fn modulus(p: &mut Vec) { - if p.len() < N { +pub fn modulus(param: &RingParam, p: &mut Vec) { + let n = param.n; + if p.len() < n { return; } - for i in N..p.len() { - p[i - N] = p[i - N].clone() - p[i].clone(); - p[i] = T64::zero(); + for i in n..p.len() { + p[i - n] = p[i - n].clone() - p[i].clone(); + p[i] = T64::zero(param); } - p.truncate(N); + p.truncate(n); } -impl Add> for Tn { +impl Add for Tn { type Output = Self; fn add(self, rhs: Self) -> Self { - Self(array::from_fn(|i| self.0[i] + rhs.0[i])) + assert_eq!(self.param, rhs.param); + Self { + param: self.param, + coeffs: zip_eq(self.coeffs, rhs.coeffs) + .map(|(l, r)| l + r) + .collect(), + } } } -impl Add<&Tn> for &Tn { - type Output = Tn; - - fn add(self, rhs: &Tn) -> Self::Output { - Tn(array::from_fn(|i| self.0[i] + rhs.0[i])) +impl Add<&Tn> for &Tn { + type Output = Tn; + + fn add(self, rhs: &Tn) -> Self::Output { + assert_eq!(self.param, rhs.param); + Tn { + param: self.param, + coeffs: zip_eq(self.coeffs.clone(), rhs.coeffs.clone()) + .map(|(l, r)| l + r) + .collect(), + } } } -impl AddAssign for Tn { +impl AddAssign for Tn { fn add_assign(&mut self, rhs: Self) { - for i in 0..N { - self.0[i] += rhs.0[i]; + assert_eq!(self.param, rhs.param); + for i in 0..self.param.n { + self.coeffs[i] += rhs.coeffs[i]; } } } -impl Sum> for Tn { - fn sum(iter: I) -> Self +impl Sum for Tn { + fn sum(mut iter: I) -> Self where I: Iterator, { - let mut acc = Tn::::zero(); - for e in iter { - acc += e; - } - acc + let first = iter.next().unwrap(); + iter.fold(first, |acc, x| acc + x) } } -impl Sub> for Tn { +impl Sub for Tn { type Output = Self; fn sub(self, rhs: Self) -> Self { - Self(array::from_fn(|i| self.0[i] - rhs.0[i])) + assert_eq!(self.param, rhs.param); + Self { + param: self.param, + coeffs: zip_eq(self.coeffs, rhs.coeffs) + .map(|(l, r)| l - r) + .collect(), + } } } -impl Sub<&Tn> for &Tn { - type Output = Tn; - - fn sub(self, rhs: &Tn) -> Self::Output { - Tn(array::from_fn(|i| self.0[i] - rhs.0[i])) +impl Sub<&Tn> for &Tn { + type Output = Tn; + + fn sub(self, rhs: &Tn) -> Self::Output { + assert_eq!(self.param, rhs.param); + Tn { + param: self.param, + coeffs: zip_eq(self.coeffs.clone(), rhs.coeffs.clone()) + .map(|(l, r)| l - r) + .collect(), + } } } -impl SubAssign for Tn { +impl SubAssign for Tn { fn sub_assign(&mut self, rhs: Self) { - for i in 0..N { - self.0[i] -= rhs.0[i]; + assert_eq!(self.param, rhs.param); + for i in 0..self.param.n { + self.coeffs[i] -= rhs.coeffs[i]; } } } -impl Neg for Tn { +impl Neg for Tn { type Output = Self; fn neg(self) -> Self::Output { - Tn(array::from_fn(|i| -self.0[i])) + Self { + param: self.param, + coeffs: self.coeffs.iter().map(|c_i| -*c_i).collect(), + } } } -impl PartialEq for Tn { +impl PartialEq for Tn { fn eq(&self, other: &Self) -> bool { - self.0 == other.0 + self.coeffs == other.coeffs && self.param == other.param } } -impl Mul> for Tn { +impl Mul for Tn { type Output = Self; fn mul(self, rhs: Self) -> Self { naive_poly_mul(&self, &rhs) } } -impl Mul<&Tn> for &Tn { - type Output = Tn; +impl Mul<&Tn> for &Tn { + type Output = Tn; - fn mul(self, rhs: &Tn) -> Self::Output { + fn mul(self, rhs: &Tn) -> Self::Output { naive_poly_mul(self, rhs) } } -fn naive_poly_mul(poly1: &Tn, poly2: &Tn) -> Tn { - let poly1: Vec = poly1.0.iter().map(|c| c.0 as u128).collect(); - let poly2: Vec = poly2.0.iter().map(|c| c.0 as u128).collect(); - let mut result: Vec = vec![0; (N * 2) - 1]; - for i in 0..N { - for j in 0..N { +fn naive_poly_mul(poly1: &Tn, poly2: &Tn) -> Tn { + assert_eq!(poly1.param, poly2.param); + let n = poly1.param.n; + let param = poly1.param; + + let poly1: Vec = poly1.coeffs.iter().map(|c| c.0 as u128).collect(); + let poly2: Vec = poly2.coeffs.iter().map(|c| c.0 as u128).collect(); + let mut result: Vec = vec![0; (n * 2) - 1]; + for i in 0..n { + for j in 0..n { result[i + j] = result[i + j] + poly1[i] * poly2[j]; } } - // apply mod (X^N + 1)) - modulus_u128::(&mut result); + // apply mod (X^n + 1)) + modulus_u128(n, &mut result); - Tn(array::from_fn(|i| T64(result[i] as u64))) + Tn { + param, + coeffs: result.iter().map(|r_i| T64(*r_i as u64)).collect(), + } } -fn modulus_u128(p: &mut Vec) { - if p.len() < N { +fn modulus_u128(n: usize, p: &mut Vec) { + if p.len() < n { return; } - for i in N..p.len() { - // p[i - N] = p[i - N].clone() - p[i].clone(); - p[i - N] = p[i - N].wrapping_sub(p[i]); + for i in n..p.len() { + // p[i - n] = p[i - n].clone() - p[i].clone(); + p[i - n] = p[i - n].wrapping_sub(p[i]); p[i] = 0; } - p.truncate(N); + p.truncate(n); } -impl Mul for Tn { +impl Mul for Tn { type Output = Self; fn mul(self, s: T64) -> Self { - Self(array::from_fn(|i| self.0[i] * s)) + Self { + param: self.param, + coeffs: self.coeffs.iter().map(|c_i| *c_i * s).collect(), + } } } // mul by u64 -impl Mul for Tn { +impl Mul for Tn { type Output = Self; fn mul(self, s: u64) -> Self { - Self(array::from_fn(|i| self.0[i] * s)) + Tn { + param: self.param, + coeffs: self.coeffs.iter().map(|c_i| *c_i * s).collect(), + } } } -impl Mul<&u64> for &Tn { - type Output = Tn; +impl Mul<&u64> for &Tn { + type Output = Tn; fn mul(self, s: &u64) -> Self::Output { - Tn::(array::from_fn(|i| self.0[i] * *s)) + Tn { + param: self.param, + coeffs: self.coeffs.iter().map(|c_i| c_i * s).collect(), + } } } @@ -256,8 +332,9 @@ mod tests { #[test] fn test_left_rotate() { - const N: usize = 4; - let f = Tn::::from_vec( + let param = RingParam { q: u64::MAX, n: 4 }; + let f = Tn::from_vec( + ¶m, vec![2i64, 3, -4, -1] .iter() .map(|c| T64(*c as u64)) @@ -267,7 +344,8 @@ mod tests { // expect f*x^-3 == -1 -2x -3x^2 +4x^3 assert_eq!( f.left_rotate(3), - Tn::::from_vec( + Tn::from_vec( + ¶m, vec![-1i64, -2, -3, 4] .iter() .map(|c| T64(*c as u64)) @@ -277,7 +355,8 @@ mod tests { // expect f*x^-1 == 3 -4x -1x^2 -2x^3 assert_eq!( f.left_rotate(1), - Tn::::from_vec( + Tn::from_vec( + ¶m, vec![3i64, -4, -1, -2] .iter() .map(|c| T64(*c as u64)) diff --git a/arith/src/torus.rs b/arith/src/torus.rs index fa48ad7..f50f3b6 100644 --- a/arith/src/torus.rs +++ b/arith/src/torus.rs @@ -4,7 +4,7 @@ use std::{ ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; -use crate::ring::Ring; +use crate::ring::{Ring, RingParam}; /// Let 𝕋 = ℝ/ℤ, where 𝕋 is a ℤ-module, with homogeneous external product. /// Let 𝕋q @@ -16,20 +16,24 @@ pub struct T64(pub u64); // `Tn<1>`. impl Ring for T64 { type C = T64; - const Q: u64 = u64::MAX; // WIP - const N: usize = 1; + fn param(&self) -> RingParam { + RingParam { + q: u64::MAX, // WIP + n: 1, + } + } fn coeffs(&self) -> Vec { vec![self.clone()] } - fn zero() -> Self { + fn zero(_: &RingParam) -> Self { Self(0u64) } - fn rand(mut rng: impl Rng, dist: impl Distribution) -> Self { + fn rand(mut rng: impl Rng, dist: impl Distribution, _: &RingParam) -> Self { let r: f64 = dist.sample(&mut rng); Self(r.round() as u64) } - fn from_vec(coeffs: Vec) -> Self { + fn from_vec(_n: &RingParam, coeffs: Vec) -> Self { assert_eq!(coeffs.len(), 1); coeffs[0] } @@ -37,27 +41,27 @@ impl Ring for T64 { // TODO rm beta & l from inputs, make it always beta=2,l=64. /// Note: only beta=2 and l=64 is supported. fn decompose(&self, beta: u32, l: u32) -> Vec { - assert_eq!(beta, 2u32); // only beta=2 supported - // assert_eq!(l, 64u32); // only l=64 supported + assert_eq!(beta, 2u32, "only beta=2 supported"); + // assert_eq!(l, 64u32, "only l=64 supported"); // (0..64) - (0..l) + (0..l as u64) .rev() - .map(|i| T64(((self.0 >> i) & 1) as u64)) + .map(|i| T64((self.0 >> i) & 1)) .collect() } - fn remodule(&self) -> T64 { + fn remodule(&self, p: u64) -> T64 { todo!() } // modulus switch from Q to Q2: self * Q2/Q - fn mod_switch(&self) -> T64 { + fn mod_switch(&self, q2: u64) -> T64 { // for the moment we assume Q|Q2, since Q=2^64, check that Q2 is a power // of two: - assert!(Q2.is_power_of_two()); + assert!(q2.is_power_of_two()); // since Q=2^64, dividing Q2/Q is equivalent to dividing 2^log2(Q2)/2^64 // which would be like right-shifting 64-log2(Q2). - let log2_q2 = 63 - Q2.leading_zeros(); + let log2_q2 = 63 - q2.leading_zeros(); T64(self.0 >> (64 - log2_q2)) } @@ -173,9 +177,13 @@ mod tests { let d = x.decompose(beta, l); assert_eq!(recompose(d), T64(u64::MAX - 1)); + let param = RingParam { + q: u64::MAX, // WIP + n: 1, + }; let mut rng = rand::thread_rng(); for _ in 0..1000 { - let x = T64::rand(&mut rng, Standard); + let x = T64::rand(&mut rng, Standard, ¶m); let d = x.decompose(beta, l); assert_eq!(recompose(d), x); } diff --git a/arith/src/tuple_ring.rs b/arith/src/tuple_ring.rs index 1da5e08..b468913 100644 --- a/arith/src/tuple_ring.rs +++ b/arith/src/tuple_ring.rs @@ -1,40 +1,46 @@ //! This file implements the struct for an Tuple of Ring Rq elements and its //! operations, which are performed element-wise. -use anyhow::Result; use itertools::zip_eq; use rand::{distributions::Distribution, Rng}; -use rand_distr::{Normal, Uniform}; -use std::iter::Sum; -use std::{ - array, - ops::{Add, Mul, Neg, Sub}, -}; +use std::ops::{Add, Mul, Neg, Sub}; -use crate::Ring; +use crate::{Ring, RingParam}; /// Tuple of K Ring (Rq) elements. We use Vec to allocate it in the heap, /// since if using a fixed-size array it would overflow the stack. #[derive(Clone, Debug)] -pub struct TR(pub Vec); +pub struct TR { + pub k: usize, + pub r: Vec, +} // TODO rm pub from Vec, so that TR can not be created from a Vec with // invalid length, since it has to be created using the `new` method. -impl TR { - pub fn new(v: Vec) -> Self { - assert_eq!(v.len(), K); - Self(v) +impl TR { + pub fn new(k: usize, r: Vec) -> Self { + assert_eq!(r.len(), k); + Self { k, r } } - pub fn zero() -> Self { - Self((0..K).into_iter().map(|_| R::zero()).collect()) + pub fn zero(k: usize, r_param: &RingParam) -> Self { + Self { + k, + r: (0..k).into_iter().map(|_| R::zero(r_param)).collect(), + } } - pub fn rand(mut rng: impl Rng, dist: impl Distribution) -> Self { - Self( - (0..K) + pub fn rand( + mut rng: impl Rng, + dist: impl Distribution, + k: usize, + r_param: &RingParam, + ) -> Self { + Self { + k, + r: (0..k) .into_iter() - .map(|_| R::rand(&mut rng, &dist)) + .map(|_| R::rand(&mut rng, &dist, r_param)) .collect(), - ) + } } // returns the decomposition of each polynomial element pub fn decompose(&self, beta: u32, l: u32) -> Vec { @@ -43,64 +49,85 @@ impl TR { } } -impl TR { - pub fn mod_switch(&self) -> TR { - TR(self.0.iter().map(|c_i| c_i.mod_switch::()).collect()) +impl TR { + pub fn mod_switch(&self, q2: u64) -> TR { + TR:: { + k: self.k, + r: self.r.iter().map(|c_i| c_i.mod_switch(q2)).collect(), + } } // pub fn mod_switch(&self, Q2: u64) -> TR { // TR(self.0.iter().map(|c_i| c_i.mod_switch(Q2)).collect()) // } } -impl TR, K> { +impl TR { pub fn left_rotate(&self, h: usize) -> Self { - TR(self.0.iter().map(|c_i| c_i.left_rotate(h)).collect()) + TR { + k: self.k, + r: self.r.iter().map(|c_i| c_i.left_rotate(h)).collect(), + } } } -impl TR { +impl TR { pub fn iter(&self) -> std::slice::Iter { - self.0.iter() + self.r.iter() } } -impl Add> for TR { +impl Add> for TR { type Output = Self; fn add(self, other: Self) -> Self { - Self( - zip_eq(self.0, other.0) + debug_assert_eq!(self.k, other.k); + + Self { + k: self.k, + r: zip_eq(self.r, other.r) .map(|(s, o)| s + o) .collect::>(), - ) + } } } -impl Sub> for TR { +impl Sub> for TR { type Output = Self; fn sub(self, other: Self) -> Self { - Self(zip_eq(self.0, other.0).map(|(s, o)| s - o).collect()) + debug_assert_eq!(self.k, other.k); + + Self { + k: self.k, + r: zip_eq(self.r, other.r).map(|(s, o)| s - o).collect(), + } } } -impl Neg for TR { +impl Neg for TR { type Output = Self; fn neg(self) -> Self::Output { - Self(self.0.iter().map(|&e_i| -e_i).collect()) + Self { + k: self.k, + r: self.r.iter().map(|e_i| -e_i.clone()).collect(), + } } } /// for (TR,TR), the Mul operation is defined as the dot product: /// for A, B \in R^k, result = Σ A_i * B_i \in R -impl Mul> for TR { +impl Mul> for TR { type Output = R; fn mul(self, other: Self) -> R { - zip_eq(self.0, other.0).map(|(s, o)| s * o).sum() + debug_assert_eq!(self.k, other.k); + + zip_eq(self.r, other.r).map(|(s, o)| s * o).sum() } } -impl Mul<&TR> for &TR { +impl Mul<&TR> for &TR { type Output = R; - fn mul(self, other: &TR) -> R { - zip_eq(self.0.clone(), other.0.clone()) + fn mul(self, other: &TR) -> R { + debug_assert_eq!(self.k, other.k); + + zip_eq(self.r.clone(), other.r.clone()) .map(|(s, o)| s * o) .sum() } @@ -108,15 +135,21 @@ impl Mul<&TR> for &TR { /// for (TR, R), the Mul operation is defined as each element of TR is /// multiplied by R -impl Mul for TR { - type Output = TR; - fn mul(self, other: R) -> TR { - Self(self.0.iter().map(|s| s.clone() * other.clone()).collect()) +impl Mul for TR { + type Output = TR; + fn mul(self, other: R) -> TR { + Self { + k: self.k, + r: self.r.iter().map(|s| s.clone() * other.clone()).collect(), + } } } -impl Mul<&R> for &TR { - type Output = TR; - fn mul(self, other: &R) -> TR { - TR::(self.0.iter().map(|s| s.clone() * other.clone()).collect()) +impl Mul<&R> for &TR { + type Output = TR; + fn mul(self, other: &R) -> TR { + TR:: { + k: self.k, + r: self.r.iter().map(|s| s.clone() * other.clone()).collect(), + } } } diff --git a/arith/src/zq.rs b/arith/src/zq.rs index f6b5904..392f1d5 100644 --- a/arith/src/zq.rs +++ b/arith/src/zq.rs @@ -4,41 +4,39 @@ use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub, SubAssign}; /// Z_q, integers modulus q, not necessarily prime #[derive(Clone, Copy, PartialEq)] -pub struct Zq(pub u64); - -// WIP -// impl From> for Vec> { -// fn from(v: Vec) -> Self { -// v.into_iter().map(Zq::new).collect() -// } -// } +pub struct Zq { + pub q: u64, + pub v: u64, +} -pub(crate) fn modulus_u64(e: u64) -> u64 { - (e % Q + Q) % Q +pub(crate) fn modulus_u64(q: u64, e: u64) -> u64 { + (e % q + q) % q } -impl Zq { - pub fn rand(mut rng: impl Rng, dist: impl Distribution) -> Self { +impl Zq { + pub fn rand(mut rng: impl Rng, dist: impl Distribution, q: u64) -> Self { // TODO WIP let r: f64 = dist.sample(&mut rng); - Self::from_f64(r) - // Self::from_u64(r.round() as u64) - } - pub fn from_u64(e: u64) -> Self { - if e >= Q { - // (e % Q + Q) % Q - return Zq(modulus_u64::(e)); - // return Zq(e % Q); + Self::from_f64(q, r) + } + pub fn from_u64(q: u64, v: u64) -> Self { + if v >= q { + // (v % Q + Q) % Q + return Zq { + q, + v: modulus_u64(q, v), + }; + // return Zq(v % Q); } - Zq(e) + Zq { q, v } } - pub fn from_f64(e: f64) -> Self { + pub fn from_f64(q: u64, e: f64) -> Self { // WIP method let e: i64 = e.round() as i64; - let q = Q as i64; - if e < 0 || e >= q { - return Zq(((e % q + q) % q) as u64); + let q_i64 = q as i64; + if e < 0 || e >= q_i64 { + return Zq::from_u64(q, ((e % q_i64 + q_i64) % q_i64) as u64); } - Zq(e as u64) + Zq { q, v: e as u64 } // if e < 0 { // // dbg!(&e); @@ -50,15 +48,18 @@ impl Zq { // } // Zq(e as u64) } - pub fn from_bool(b: bool) -> Self { + pub fn from_bool(q: u64, b: bool) -> Self { if b { - Zq(1) + Zq { q, v: 1 } } else { - Zq(0) + Zq { q, v: 0 } } } - pub fn zero() -> Self { - Self(0u64) + pub fn zero(q: u64) -> Self { + Self { q, v: 0u64 } + } + pub fn one(q: u64) -> Self { + Self { q, v: 1u64 } } pub fn square(self) -> Self { self * self @@ -66,18 +67,21 @@ impl Zq { // modular exponentiation pub fn exp(self, e: Self) -> Self { // mul-square approach - let mut res = Self(1); + let mut res = Self::one(self.q); let mut rem = e.clone(); let mut exp = self; // for rem != Self(0) { - while rem != Self(0) { + while rem != Self::zero(self.q) { // if odd - // TODO use a more readible expression - if 1 - ((rem.0 & 1) << 1) as i64 == -1 { + // TODO use a more readeable expression + if 1 - ((rem.v & 1) << 1) as i64 == -1 { res = res * exp; } exp = exp.square(); - rem = Self(rem.0 >> 1); + rem = Self { + q: self.q, + v: rem.v >> 1, + }; } res } @@ -89,9 +93,9 @@ impl Zq { // let a = self.0; // let q = Q; let mut t = 0; - let mut r = Q; + let mut r = self.q; let mut new_t = 0; - let mut new_r = self.0.clone(); + let mut new_r = self.v.clone(); while new_r != 0 { let q = r / new_r; @@ -104,16 +108,16 @@ impl Zq { // if t < 0 { // t = t + q; // } - return Zq::from_u64(t); + return Zq::from_u64(self.q, t); } - pub fn inv(self) -> Zq { - let (g, x, _) = Self::egcd(self.0 as i128, Q as i128); + pub fn inv(self) -> Zq { + let (g, x, _) = Self::egcd(self.v as i128, self.q as i128); if g != 1 { // None panic!("E"); } else { - let q = Q as i128; - Zq(((x % q + q) % q) as u64) // TODO maybe just Zq::new(x) + let q = self.q as i128; + Zq::from_u64(self.q, ((x % q + q) % q) as u64) // TODO maybe just Zq::new(x) } } fn egcd(a: i128, b: i128) -> (i128, i128, i128) { @@ -126,8 +130,11 @@ impl Zq { } /// perform the mod switch operation from Q to Q', where Q2=Q' - pub fn mod_switch(&self) -> Zq { - Zq::::from_u64(((self.0 as f64 * Q2 as f64) / Q as f64).round() as u64) + pub fn mod_switch(&self, q2: u64) -> Zq { + Zq::from_u64( + q2, + ((self.v as f64 * q2 as f64) / self.q as f64).round() as u64, + ) } pub fn decompose(&self, beta: u32, l: u32) -> Vec { @@ -138,19 +145,25 @@ impl Zq { } } pub fn decompose_base_beta(&self, beta: u32, l: u32) -> Vec { - let mut rem: u64 = self.0; + let mut rem: u64 = self.v; // next if is for cases in which beta does not divide Q (concretely // beta^l!=Q). round to the nearest multiple of q/beta^l if rem >= beta.pow(l) as u64 { // rem = Q - 1 - (Q / beta as u64); // floor - return vec![Zq(beta as u64 - 1); l as usize]; + return vec![ + Zq { + q: self.q, + v: beta as u64 - 1 + }; + l as usize + ]; } let mut x: Vec = vec![]; for i in 1..l + 1 { - let den = Q / beta.pow(i) as u64; + let den = self.q / beta.pow(i) as u64; let x_i = rem / den; // division between u64 already does floor - x.push(Self::from_u64(x_i)); + x.push(Self::from_u64(self.q, x_i)); if x_i != 0 { rem = rem % den; } @@ -161,15 +174,15 @@ impl Zq { pub fn decompose_base2(&self, l: u32) -> Vec { // next if is for cases in which beta does not divide Q (concretely // beta^l!=Q). round to the nearest multiple of q/beta^l - if self.0 >= 1 << l as u64 { + if self.v >= 1 << l as u64 { // rem = Q - 1 - (Q / beta as u64); // floor // (where beta=2) - return vec![Zq(1); l as usize]; + return vec![Zq::one(self.q); l as usize]; } (0..l) .rev() - .map(|i| Self(((self.0 >> i) & 1) as u64)) + .map(|i| Self::from_u64(self.q, ((self.v >> i) & 1) as u64)) .collect() // naive ver: @@ -194,114 +207,143 @@ impl Zq { } } -impl Zq { +impl Zq { fn r#mod(self) -> Self { - if self.0 >= Q { - return Zq(self.0 % Q); + if self.v >= self.q { + return Zq::from_u64(self.q, self.v % self.q); } self } } -impl Add> for Zq { +impl Add for Zq { type Output = Self; fn add(self, rhs: Self) -> Self::Output { - let mut r = self.0 + rhs.0; - if r >= Q { - r -= Q; + assert_eq!(self.q, rhs.q); + + let mut v = self.v + rhs.v; + if v >= self.q { + v -= self.q; } - Zq(r) + Zq { q: self.q, v } } } -impl Add<&Zq> for &Zq { - type Output = Zq; +impl Add<&Zq> for &Zq { + type Output = Zq; - fn add(self, rhs: &Zq) -> Self::Output { - let mut r = self.0 + rhs.0; - if r >= Q { - r -= Q; + fn add(self, rhs: &Zq) -> Self::Output { + assert_eq!(self.q, rhs.q); + + let mut v = self.v + rhs.v; + if v >= self.q { + v -= self.q; } - Zq(r) + Zq { q: self.q, v } } } -impl AddAssign> for Zq { +impl AddAssign for Zq { fn add_assign(&mut self, rhs: Self) { *self = *self + rhs } } -impl std::iter::Sum for Zq { - fn sum(iter: I) -> Self +impl std::iter::Sum for Zq { + fn sum(mut iter: I) -> Self where I: Iterator, { - iter.fold(Zq(0), |acc, x| acc + x) + let first: Zq = iter.next().unwrap(); + iter.fold(first, |acc, x| acc + x) } } -impl Sub> for Zq { +impl Sub for Zq { type Output = Self; - fn sub(self, rhs: Self) -> Zq { - if self.0 >= rhs.0 { - Zq(self.0 - rhs.0) + fn sub(self, rhs: Self) -> Zq { + assert_eq!(self.q, rhs.q); + + if self.v >= rhs.v { + Zq { + q: self.q, + v: self.v - rhs.v, + } } else { - Zq((Q + self.0) - rhs.0) + Zq { + q: self.q, + v: (self.q + self.v) - rhs.v, + } } } } -impl Sub<&Zq> for &Zq { - type Output = Zq; +impl Sub<&Zq> for &Zq { + type Output = Zq; + + fn sub(self, rhs: &Zq) -> Self::Output { + assert_eq!(self.q, rhs.q); - fn sub(self, rhs: &Zq) -> Self::Output { - if self.0 >= rhs.0 { - Zq(self.0 - rhs.0) + if self.q >= rhs.q { + Zq { + q: self.q, + v: self.v - rhs.v, + } } else { - Zq((Q + self.0) - rhs.0) + Zq { + q: self.q, + v: (self.q + self.v) - rhs.v, + } } } } -impl SubAssign> for Zq { +impl SubAssign for Zq { fn sub_assign(&mut self, rhs: Self) { *self = *self - rhs } } -impl Neg for Zq { +impl Neg for Zq { type Output = Self; fn neg(self) -> Self::Output { - if self.0 == 0 { + if self.v == 0 { return self; } - Zq(Q - self.0) + Zq { + q: self.q, + v: self.q - self.v, + } } } -impl Mul> for Zq { +impl Mul for Zq { type Output = Self; - fn mul(self, rhs: Self) -> Zq { + fn mul(self, rhs: Self) -> Zq { + assert_eq!(self.q, rhs.q); + // TODO non-naive way - Zq(((self.0 as u128 * rhs.0 as u128) % Q as u128) as u64) + Zq { + q: self.q, + v: ((self.v as u128 * rhs.v as u128) % self.q as u128) as u64, + } // Zq((self.0 * rhs.0) % Q) } } -impl Div> for Zq { +impl Div for Zq { type Output = Self; - fn div(self, rhs: Self) -> Zq { + fn div(self, rhs: Self) -> Zq { // TODO non-naive way // Zq((self.0 / rhs.0) % Q) self * rhs.inv() } } -impl fmt::Display for Zq { +impl fmt::Display for Zq { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.0) + write!(f, "{}", self.v) } } -impl fmt::Debug for Zq { +impl fmt::Debug for Zq { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.0) + write!(f, "{}", self.v) } } @@ -312,80 +354,83 @@ mod tests { #[test] fn exp() { - const Q: u64 = 1021; - let a = Zq::(3); - let b = Zq::(3); - assert_eq!(a.exp(b), Zq(27)); + const q: u64 = 1021; + let a = Zq::from_u64(q, 3); + let b = Zq::from_u64(q, 3); + assert_eq!(a.exp(b), Zq::from_u64(q, 27)); - let a = Zq::(1000); - let b = Zq::(3); - assert_eq!(a.exp(b), Zq(949)); + let a = Zq::from_u64(q, 1000); + let b = Zq::from_u64(q, 3); + assert_eq!(a.exp(b), Zq::from_u64(q, 949)); } #[test] fn neg() { - const Q: u64 = 1021; - let a = Zq::::from_f64(101.0); - let b = Zq::::from_f64(-1.0); + let q: u64 = 1021; + let a = Zq::from_f64(q, 101.0); + let b = Zq::from_f64(q, -1.0); assert_eq!(-a, a * b); } - fn recompose(beta: u32, l: u32, d: Vec>) -> Zq { + fn recompose(q: u64, beta: u32, l: u32, d: Vec) -> Zq { let mut x = 0u64; for i in 0..l { - x += d[i as usize].0 * Q / beta.pow(i + 1) as u64; + x += d[i as usize].v * q / beta.pow(i + 1) as u64; } - Zq::from_u64(x) + Zq::from_u64(q, x) } #[test] fn test_decompose() { - const Q1: u64 = 16; + let q1: u64 = 16; let beta: u32 = 2; let l: u32 = 4; - let x = Zq::::from_u64(9); + let x = Zq::from_u64(q1, 9); let d = x.decompose(beta, l); - assert_eq!(recompose::(beta, l, d), x); + assert_eq!(recompose(q1, beta, l, d), x); - const Q: u64 = 5u64.pow(3); + let q: u64 = 5u64.pow(3); let beta: u32 = 5; let l: u32 = 3; - let dist = Uniform::new(0_u64, Q); + let dist = Uniform::new(0_u64, q); let mut rng = rand::thread_rng(); for _ in 0..1000 { - let x = Zq::::from_u64(dist.sample(&mut rng)); + let x = Zq::from_u64(q, dist.sample(&mut rng)); let d = x.decompose(beta, l); assert_eq!(d.len(), l as usize); - assert_eq!(recompose::(beta, l, d), x); + assert_eq!(recompose(q, beta, l, d), x); } } #[test] fn test_decompose_approx() { - const Q: u64 = 2u64.pow(4) + 1; + let q: u64 = 2u64.pow(4) + 1; let beta: u32 = 2; let l: u32 = 4; - let x = Zq::::from_u64(16); // in q, but bigger than beta^l + let x = Zq::from_u64(q, 16); // in q, but bigger than beta^l let d = x.decompose(beta, l); assert_eq!(d.len(), l as usize); - assert_eq!(recompose::(beta, l, d), Zq(15)); + assert_eq!(recompose(q, beta, l, d), Zq::from_u64(q, 15)); - const Q2: u64 = 5u64.pow(3) + 1; + let q2: u64 = 5u64.pow(3) + 1; let beta: u32 = 5; let l: u32 = 3; - let x = Zq::::from_u64(125); // in q, but bigger than beta^l + let x = Zq::from_u64(q2, 125); // in q, but bigger than beta^l let d = x.decompose(beta, l); assert_eq!(d.len(), l as usize); - assert_eq!(recompose::(beta, l, d), Zq(124)); + assert_eq!(recompose(q2, beta, l, d), Zq::from_u64(q2, 124)); - const Q3: u64 = 2u64.pow(16) + 1; + let q3: u64 = 2u64.pow(16) + 1; let beta: u32 = 2; let l: u32 = 16; - let x = Zq::::from_u64(Q3 - 1); // in q, but bigger than beta^l + let x = Zq::from_u64(q3, q3 - 1); // in q, but bigger than beta^l let d = x.decompose(beta, l); assert_eq!(d.len(), l as usize); - assert_eq!(recompose::(beta, l, d), Zq(beta.pow(l) as u64 - 1)); + assert_eq!( + recompose(q3, beta, l, d), + Zq::from_u64(q3, beta.pow(l) as u64 - 1) + ); } } diff --git a/bfv/src/lib.rs b/bfv/src/lib.rs index 8c43341..abb7b88 100644 --- a/bfv/src/lib.rs +++ b/bfv/src/lib.rs @@ -10,44 +10,61 @@ use rand::Rng; use rand_distr::{Normal, Uniform}; use std::ops; -use arith::{Ring, Rq, R}; +use arith::{Ring, RingParam, Rq, R}; // error deviation for the Gaussian(Normal) distribution // sigma=3.2 from: https://eprint.iacr.org/2022/162.pdf page 5 const ERR_SIGMA: f64 = 3.2; +#[derive(Clone, Copy, Debug)] +pub struct Param { + ring: RingParam, + t: u64, + p: u64, +} +impl Param { + // returns the plaintext param + pub fn pt(&self) -> RingParam { + RingParam { + q: self.t, + n: self.ring.n, + } + } +} + #[derive(Clone, Debug)] -pub struct SecretKey(Rq); +pub struct SecretKey(Rq); #[derive(Clone, Debug)] -pub struct PublicKey(Rq, Rq); +pub struct PublicKey(Rq, Rq); /// Relinearization key #[derive(Clone, Debug)] -pub struct RLK(Rq, Rq); +pub struct RLK(Rq, Rq); // RLWE ciphertext #[derive(Clone, Debug)] -pub struct RLWE(Rq, Rq); +pub struct RLWE(Rq, Rq); -impl RLWE { +impl RLWE { fn add(lhs: Self, rhs: Self) -> Self { - RLWE::(lhs.0 + rhs.0, lhs.1 + rhs.1) + RLWE(lhs.0 + rhs.0, lhs.1 + rhs.1) } - pub fn remodule(&self) -> RLWE { - let x = self.0.remodule::

(); - let y = self.1.remodule::

(); - RLWE::(x, y) + pub fn remodule(&self, p: u64) -> RLWE { + let x = self.0.remodule(p); + let y = self.1.remodule(p); + RLWE(x, y) } - fn tensor(a: &Self, b: &Self) -> (Rq, Rq, Rq) { + fn tensor(t: u64, a: &Self, b: &Self) -> (Rq, Rq, Rq) { + let (q, n) = (a.0.param.q, a.0.param.n); // expand Q->PQ // TODO rm // get the coefficients in Z, ie. interpret a,b \in R (instead of R_q) - let a0: R = a.0.to_r(); - let a1: R = a.1.to_r(); - let b0: R = b.0.to_r(); - let b1: R = b.1.to_r(); + let a0: R = a.0.clone().to_r(); // TODO rm clone() + let a1: R = a.1.clone().to_r(); + let b0: R = b.0.clone().to_r(); + let b1: R = b.1.clone().to_r(); // tensor (\in R) (2021-204 p.9) // NOTE: here can use *, but at first versions want to make it explicit @@ -60,44 +77,47 @@ impl RLWE { let c2: Vec = naive_mul(&a1, &b1); // scale down, then reduce module Q, so result is \in R_q - let c0: Rq = arith::ring_n::mul_div_round::(c0, T, Q); - let c1: Rq = arith::ring_n::mul_div_round::(c1, T, Q); - let c2: Rq = arith::ring_n::mul_div_round::(c2, T, Q); + let c0: Rq = arith::ring_n::mul_div_round(q, n, c0, t, q); + let c1: Rq = arith::ring_n::mul_div_round(q, n, c1, t, q); + let c2: Rq = arith::ring_n::mul_div_round(q, n, c2, t, q); (c0, c1, c2) } /// ciphertext multiplication - fn mul(rlk: &RLK, a: &Self, b: &Self) -> Self { - let (c0, c1, c2) = Self::tensor::(a, b); - BFV::::relinearize_204::(&rlk, &c0, &c1, &c2) + fn mul(t: u64, rlk: &RLK, a: &Self, b: &Self) -> Self { + let (c0, c1, c2) = Self::tensor(t, a, b); + BFV::relinearize_204(&rlk, &c0, &c1, &c2) } } // naive mul in the ring Rq, reusing the ring_n::naive_mul and then applying mod(X^N +1) -fn tmp_naive_mul(a: Rq, b: Rq) -> Rq { - Rq::::from_vec_i64(arith::ring_n::naive_mul(&a.to_r(), &b.to_r())) +fn tmp_naive_mul(a: Rq, b: Rq) -> Rq { + Rq::from_vec_i64( + &a.param.clone(), + arith::ring_n::naive_mul(&a.to_r(), &b.to_r()), + ) } -impl ops::Add> for RLWE { +impl ops::Add for RLWE { type Output = Self; fn add(self, rhs: Self) -> Self { Self::add(self, rhs) } } -impl ops::Add<&Rq> for &RLWE { - type Output = RLWE; - fn add(self, rhs: &Rq) -> Self::Output { - BFV::::add_const(self, rhs) +impl ops::Add<&Rq> for &RLWE { + type Output = RLWE; + fn add(self, rhs: &Rq) -> Self::Output { + BFV::add_const(self, rhs) } } -pub struct BFV {} +pub struct BFV {} -impl BFV { - const DELTA: u64 = Q / T; // floor +impl BFV { + // const DELTA: u64 = Q / T; // floor /// generate a new key pair (privK, pubK) - pub fn new_key(mut rng: impl Rng) -> Result<(SecretKey, PublicKey)> { + pub fn new_key(mut rng: impl Rng, param: &Param) -> Result<(SecretKey, PublicKey)> { // WIP: review probabilities // let Xi_key = Uniform::new(-1_f64, 1_f64); @@ -105,114 +125,135 @@ impl BFV { let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; // secret key - // let mut s = Rq::::rand_f64(&mut rng, Xi_key)?; - let mut s = Rq::::rand_u64(&mut rng, Xi_key)?; + // let mut s = Rq::rand_f64(&mut rng, Xi_key)?; + let mut s = Rq::rand_u64(&mut rng, Xi_key, ¶m.ring)?; // since s is going to be multiplied by other Rq elements, already // compute its NTT s.compute_evals(); // pk = (-a * s + e, a) - let a = Rq::::rand_u64(&mut rng, Uniform::new(0_u64, Q))?; - let e = Rq::::rand_f64(&mut rng, Xi_err)?; - let pk: PublicKey = PublicKey((&(-a) * &s) + e, a.clone()); + let a = Rq::rand_u64(&mut rng, Uniform::new(0_u64, param.ring.q), ¶m.ring)?; + let e = Rq::rand_f64(&mut rng, Xi_err, ¶m.ring)?; + let pk: PublicKey = PublicKey(&(&(-a.clone()) * &s) + &e, a.clone()); // TODO rm clones Ok((SecretKey(s), pk)) } - pub fn encrypt(mut rng: impl Rng, pk: &PublicKey, m: &Rq) -> Result> { + // note: m is modulus t + pub fn encrypt(mut rng: impl Rng, param: &Param, pk: &PublicKey, m: &Rq) -> Result { + // assert param & inputs + debug_assert_eq!(param.ring, pk.0.param); + debug_assert_eq!(param.t, m.param.q); + debug_assert_eq!(param.ring.n, m.param.n); + let Xi_key = Uniform::new(-1_f64, 1_f64); // let Xi_key = Uniform::new(0_u64, 2_u64); let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; - let u = Rq::::rand_f64(&mut rng, Xi_key)?; - // let u = Rq::::rand_u64(&mut rng, Xi_key)?; - let e_1 = Rq::::rand_f64(&mut rng, Xi_err)?; - let e_2 = Rq::::rand_f64(&mut rng, Xi_err)?; + let u = Rq::rand_f64(&mut rng, Xi_key, ¶m.ring)?; + // let u = Rq::rand_u64(&mut rng, Xi_key)?; + let e_1 = Rq::rand_f64(&mut rng, Xi_err, ¶m.ring)?; + let e_2 = Rq::rand_f64(&mut rng, Xi_err, ¶m.ring)?; // migrate m's coeffs to the bigger modulus Q (from T) - let m = m.remodule::(); - let c0 = &pk.0 * &u + e_1 + m * Self::DELTA; + let m = m.remodule(param.ring.q); + let c0 = &pk.0 * &u + e_1 + m * (param.ring.q / param.t); // floor(q/t)=DELTA let c1 = &pk.1 * &u + e_2; - Ok(RLWE::(c0, c1)) + Ok(RLWE(c0, c1)) } - pub fn decrypt(sk: &SecretKey, c: &RLWE) -> Rq { - let cs = c.0 + c.1 * sk.0; // done in mod q + pub fn decrypt(param: &Param, sk: &SecretKey, c: &RLWE) -> Rq { + debug_assert_eq!(param.ring, sk.0.param); + debug_assert_eq!(param.ring.q, c.0.param.q); + debug_assert_eq!(param.ring.n, c.0.param.n); + + let cs: Rq = &c.0 + &(&c.1 * &sk.0); // done in mod q // same but with naive_mul: // let c1s = arith::ring_n::naive_mul(&c.1.to_r(), &sk.0.to_r()); - // let c1s = Rq::::from_vec_i64(c1s); + // let c1s = Rq::from_vec_i64(c1s); // let cs = c.0 + c1s; - let r: Rq = cs.mul_div_round(T, Q); - r.remodule::() + let r: Rq = cs.mul_div_round(param.t, param.ring.q); + r.remodule(param.t) } - fn add_const(c: &RLWE, m: &Rq) -> RLWE { + fn add_const(c: &RLWE, m: &Rq) -> RLWE { + let q = c.0.param.q; + let t = m.param.q; + // assuming T to Zq - let m = m.remodule::(); - RLWE::(c.0 + m * Self::DELTA, c.1) + let m = m.remodule(c.0.param.q); + // TODO rm clones + RLWE(c.0.clone() + m * (q / t), c.1.clone()) // floor(q/t)=DELTA } - fn mul_const(rlk: &RLK, c: &RLWE, m: &Rq) -> RLWE { + fn mul_const(rlk: &RLK, c: &RLWE, m: &Rq) -> RLWE { + // let pq = rlk.0.q; + let q = c.0.param.q; + let t = m.param.q; + // assuming T to Zq - let m = m.remodule::(); + let m = m.remodule(q); // encrypt m*Delta without noise, and then perform normal ciphertext multiplication - let md = RLWE::(m * Self::DELTA, Rq::zero()); - RLWE::::mul::(&rlk, &c, &md) + let md = RLWE(m * (q / t), Rq::zero(&c.0.param)); // floor(q/t)=DELTA + RLWE::mul(t, &rlk, &c, &md) } - fn rlk_key(mut rng: impl Rng, s: &SecretKey) -> Result> { + fn rlk_key(mut rng: impl Rng, param: &Param, s: &SecretKey) -> Result { + let pq = param.p * param.ring.q; + let rlk_param = RingParam { + q: pq, + n: param.ring.n, + }; + // TODO review using Xi' instead of Xi let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; // let Xi_err = Normal::new(0_f64, 0.0)?; - let s = s.0.remodule::(); - let a = Rq::::rand_u64(&mut rng, Uniform::new(0_u64, PQ))?; - let e = Rq::::rand_f64(&mut rng, Xi_err)?; - - let P = PQ / Q; + let s = s.0.remodule(pq); + let a = Rq::rand_u64(&mut rng, Uniform::new(0_u64, pq), &rlk_param)?; + let e = Rq::rand_f64(&mut rng, Xi_err, &rlk_param)?; // let rlk: RLK = RLK::(-(&a * &s + e) + (s * s) * P, a.clone()); - let rlk: RLK = RLK::( - -(tmp_naive_mul(a, s) + e) + tmp_naive_mul(s, s) * P, + // TODO rm clones + let rlk: RLK = RLK( + -(tmp_naive_mul(a.clone(), s.clone()) + e) + + tmp_naive_mul(s.clone(), s.clone()) * param.p, a.clone(), ); Ok(rlk) } - fn relinearize( - rlk: &RLK, - c0: &Rq, - c1: &Rq, - c2: &Rq, - ) -> RLWE { - let P = PQ / Q; + fn relinearize(rlk: &RLK, c0: &Rq, c1: &Rq, c2: &Rq) -> RLWE { + let pq = rlk.0.param.q; + let param = c0.param; + let q = param.q; + let p = pq / q; - let c2rlk0: Vec = (c2.to_r() * rlk.0.to_r()) + let c2rlk0: Vec = (c2.clone().to_r() * rlk.0.clone().to_r()) .coeffs() .iter() - .map(|e| (*e as f64 / P as f64).round()) + .map(|e| (*e as f64 / p as f64).round()) .collect(); - let c2rlk1: Vec = (c2.to_r() * rlk.1.to_r()) + let c2rlk1: Vec = (c2.clone().to_r() * rlk.1.clone().to_r()) // TODO rm clones .coeffs() .iter() - .map(|e| (*e as f64 / P as f64).round()) + .map(|e| (*e as f64 / p as f64).round()) .collect(); - let r0 = Rq::::from_vec_f64(c2rlk0); - let r1 = Rq::::from_vec_f64(c2rlk1); + let r0 = Rq::from_vec_f64(¶m, c2rlk0); + let r1 = Rq::from_vec_f64(¶m, c2rlk1); - let res = RLWE::(c0 + &r0, c1 + &r1); + let res = RLWE(c0 + &r0, c1 + &r1); res } - fn relinearize_204( - rlk: &RLK, - c0: &Rq, - c1: &Rq, - c2: &Rq, - ) -> RLWE { - let P = PQ / Q; + fn relinearize_204(rlk: &RLK, c0: &Rq, c1: &Rq, c2: &Rq) -> RLWE { + let pq = rlk.0.param.q; + let q = c0.param.q; + let p = pq / q; + let n = c0.param.n; + // TODO (in debug) check that all Ns match // let c2rlk0: Rq = c2.remodule::() * rlk.0.remodule::(); // let c2rlk1: Rq = c2.remodule::() * rlk.1.remodule::(); @@ -220,12 +261,12 @@ impl BFV { // let r1: Rq = c2rlk1.mul_div_round(1, P).remodule::(); use arith::ring_n::naive_mul; - let c2rlk0: Vec = naive_mul(&c2.to_r(), &rlk.0.to_r()); - let c2rlk1: Vec = naive_mul(&c2.to_r(), &rlk.1.to_r()); - let r0: Rq = arith::ring_n::mul_div_round::(c2rlk0, 1, P); - let r1: Rq = arith::ring_n::mul_div_round::(c2rlk1, 1, P); + let c2rlk0: Vec = naive_mul(&c2.clone().to_r(), &rlk.0.clone().to_r()); // TODO rm clones + let c2rlk1: Vec = naive_mul(&c2.clone().to_r(), &rlk.1.clone().to_r()); + let r0: Rq = arith::ring_n::mul_div_round(q, n, c2rlk0, 1, p); + let r1: Rq = arith::ring_n::mul_div_round(q, n, c2rlk1, 1, p); - let res = RLWE::(c0 + &r0, c1 + &r1); + let res = RLWE(c0 + &r0, c1 + &r1); res } } @@ -239,21 +280,25 @@ mod tests { #[test] fn test_encrypt_decrypt() -> Result<()> { - const Q: u64 = 2u64.pow(16) + 1; - const N: usize = 512; - const T: u64 = 32; // plaintext modulus - type S = BFV; + let param = Param { + ring: RingParam { + q: 2u64.pow(16) + 1, // q prime, and 2^q + 1 shape + n: 512, + }, + t: 32, // plaintext modulus + p: 0, // unused in this test + }; let mut rng = rand::thread_rng(); for _ in 0..100 { - let (sk, pk) = S::new_key(&mut rng)?; + let (sk, pk) = BFV::new_key(&mut rng, ¶m)?; - let msg_dist = Uniform::new(0_u64, T); - let m = Rq::::rand_u64(&mut rng, msg_dist)?; + let msg_dist = Uniform::new(0_u64, param.t); + let m = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; - let c = S::encrypt(&mut rng, &pk, &m)?; - let m_recovered = S::decrypt(&sk, &c); + let c = BFV::encrypt(&mut rng, ¶m, &pk, &m)?; + let m_recovered = BFV::decrypt(¶m, &sk, &c); assert_eq!(m, m_recovered); } @@ -263,26 +308,30 @@ mod tests { #[test] fn test_addition() -> Result<()> { - const Q: u64 = 2u64.pow(16) + 1; - const N: usize = 128; - const T: u64 = 32; // plaintext modulus - type S = BFV; + let param = Param { + ring: RingParam { + q: 2u64.pow(16) + 1, // q prime, and 2^q + 1 shape + n: 128, + }, + t: 32, // plaintext modulus + p: 0, // unused in this test + }; let mut rng = rand::thread_rng(); for _ in 0..100 { - let (sk, pk) = S::new_key(&mut rng)?; + let (sk, pk) = BFV::new_key(&mut rng, ¶m)?; - let msg_dist = Uniform::new(0_u64, T); - let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; - let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; + let msg_dist = Uniform::new(0_u64, param.t); + let m1 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let m2 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; - let c1 = S::encrypt(&mut rng, &pk, &m1)?; - let c2 = S::encrypt(&mut rng, &pk, &m2)?; + let c1 = BFV::encrypt(&mut rng, ¶m, &pk, &m1)?; + let c2 = BFV::encrypt(&mut rng, ¶m, &pk, &m2)?; let c3 = c1 + c2; - let m3_recovered = S::decrypt(&sk, &c3); + let m3_recovered = BFV::decrypt(¶m, &sk, &c3); assert_eq!(m1 + m2, m3_recovered); } @@ -292,211 +341,208 @@ mod tests { #[test] fn test_constant_add_mul() -> Result<()> { - const Q: u64 = 2u64.pow(16) + 1; - const N: usize = 16; - const T: u64 = 8; // plaintext modulus - type S = BFV; + let q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape + let param = Param { + ring: RingParam { q, n: 16 }, + t: 8, // plaintext modulus + p: q * q, + }; let mut rng = rand::thread_rng(); - let (sk, pk) = S::new_key(&mut rng)?; + let (sk, pk) = BFV::new_key(&mut rng, ¶m)?; - let msg_dist = Uniform::new(0_u64, T); - let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; - let m2_const = Rq::::rand_u64(&mut rng, msg_dist)?; - let c1 = S::encrypt(&mut rng, &pk, &m1)?; + let msg_dist = Uniform::new(0_u64, param.t); + let m1 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let m2_const = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let c1 = BFV::encrypt(&mut rng, ¶m, &pk, &m1)?; let c3_add = &c1 + &m2_const; - let m3_add_recovered = S::decrypt(&sk, &c3_add); - assert_eq!(m1 + m2_const, m3_add_recovered); + let m3_add_recovered = BFV::decrypt(¶m, &sk, &c3_add); + assert_eq!(&m1 + &m2_const, m3_add_recovered); // test multiplication of a ciphertext by a constant - const P: u64 = Q * Q; - const PQ: u64 = P * Q; - let rlk = BFV::::rlk_key::(&mut rng, &sk)?; + let rlk = BFV::rlk_key(&mut rng, ¶m, &sk)?; - let c3_mul = S::mul_const(&rlk, &c1, &m2_const); + let c3_mul = BFV::mul_const(&rlk, &c1, &m2_const); - let m3_mul_recovered = S::decrypt(&sk, &c3_mul); + let m3_mul_recovered = BFV::decrypt(¶m, &sk, &c3_mul); assert_eq!( - (m1.to_r() * m2_const.to_r()).to_rq::().coeffs(), + (m1.to_r() * m2_const.to_r()).to_rq(param.t).coeffs(), m3_mul_recovered.coeffs() ); Ok(()) } - // TMP WIP - #[test] - #[ignore] - fn test_params() -> Result<()> { - const Q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape - const N: usize = 32; - const T: u64 = 8; // plaintext modulus - - const P: u64 = Q * Q; - const PQ: u64 = P * Q; - const DELTA: u64 = Q / T; // floor - - let mut rng = rand::thread_rng(); - - let Xi_key = Uniform::new(0_f64, 1_f64); - let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; - - let s = Rq::::rand_f64(&mut rng, Xi_key)?; - let e = Rq::::rand_f64(&mut rng, Xi_err)?; - let u = Rq::::rand_f64(&mut rng, Xi_key)?; - let e_0 = Rq::::rand_f64(&mut rng, Xi_err)?; - let e_1 = Rq::::rand_f64(&mut rng, Xi_err)?; - let m = Rq::::rand_u64(&mut rng, Uniform::new(0_u64, T))?; - - // v_fresh - let v: Rq = u * e + e_1 * s + e_0; - - let q: f64 = Q as f64; - let t: f64 = T as f64; - let n: f64 = N as f64; - let delta: f64 = DELTA as f64; - - // r_t(q)/t should be equal to q/t-Δ - assert_eq!( - // r_t(q)/t, where r_t(q)=q mod t - (q % t) / t, - // Δt/Q = q - r_t(Q)/Q, so r_t(Q)=q - Δt - (q / t) - delta - ); - let rt: f64 = (q % t) / t; - dbg!(&rt); - - dbg!(v.infinity_norm()); - let bound: f64 = (q / (2_f64 * t)) - (rt / 2_f64); - dbg!(bound); - assert!((v.infinity_norm() as f64) < bound); - let max_v_infnorm = bound - 1.0; - - // addition noise - let v_add: Rq = v + v + u * rt; - let v_add: Rq = v_add + v_add + u * rt; - assert!((v_add.infinity_norm() as f64) < bound); - - // multiplication noise - let (_, pk) = BFV::::new_key(&mut rng)?; - let c = BFV::::encrypt(&mut rng, &pk, &m.remodule::())?; - let b_key: f64 = 1_f64; - // ef: expansion factor - let ef: f64 = 2.0 * n.sqrt(); - let bound: f64 = ((ef * t) / 2.0) - * ((2.0 * max_v_infnorm * max_v_infnorm) / q - + (4.0 + ef * b_key) * (max_v_infnorm + max_v_infnorm) - + rt * (ef * b_key + 5.0)) - + (1.0 + ef * b_key + ef * ef * b_key * b_key) / 2.0; - dbg!(&bound); - - let k: Vec = (c.0 + c.1 * s - m * delta - v) - .coeffs() - .iter() - .map(|e_i| e_i.0 as f64 / q) - .collect(); - let k = Rq::::from_vec_f64(k); - let v_tensor_0 = (v * v) - .coeffs() - .iter() - .map(|e_i| (e_i.0 as f64 * t) / q) - .collect::>(); - let v_tensor_0 = Rq::::from_vec_f64(v_tensor_0); - let v_tensor_1 = ((m * v) + (m * v)) - .coeffs() - .iter() - .map(|e_i| (e_i.0 as f64 * t * delta) / q) - .collect::>(); - let v_tensor_1 = Rq::::from_vec_f64(v_tensor_1); - let v_tensor_2: Rq = (v * k + v * k) * t; - let rm: f64 = (ef * t) / 2.0; - let rm: Rq = Rq::::from_vec_f64(vec![rm; N]); - let v_tensor_3: Rq = (m * k - + m * k - + rm - + Rq::from_vec_f64( - ((m * m) * DELTA) - .coeffs() - .iter() - .map(|e_i| e_i.0 as f64 / q) - .collect::>(), - )) - * rt; - let v_tensor = v_tensor_0 + v_tensor_1 + v_tensor_2 - v_tensor_3; - - let v_r = (1.0 + ef * b_key + ef * ef * b_key * b_key) / 2.0; - let v_mult_norm = v_tensor.infinity_norm() as f64 + v_r; - dbg!(&v_mult_norm); - dbg!(&bound); - assert!(v_mult_norm < bound); - - // let m1 = Rq::::zero(); - // let m2 = Rq::::zero(); - // let (_, pk) = BFV::::new_key(&mut rng)?; - // let c1 = BFV::::encrypt(&mut rng, &pk, &m1)?; - // let c2 = BFV::::encrypt(&mut rng, &pk, &m2)?; - // let (c_a, c_b, c_c) = RLWE::::tensor::(&c1, &c2); - // dbg!(&c_a.infinity_norm()); - // dbg!(&c_b.infinity_norm()); - // dbg!(&c_c.infinity_norm()); - // assert!((c_a.infinity_norm() as f64) < bound); - // assert!((c_b.infinity_norm() as f64) < bound); - // assert!((c_c.infinity_norm() as f64) < bound); - // WIP - - Ok(()) - } + /* + // TMP WIP + #[test] + #[ignore] + fn test_param() -> Result<()> { + const Q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape + const N: usize = 32; + const T: u64 = 8; // plaintext modulus + + const P: u64 = Q * Q; + const PQ: u64 = P * Q; + const DELTA: u64 = Q / T; // floor + + let mut rng = rand::thread_rng(); + + let Xi_key = Uniform::new(0_f64, 1_f64); + let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; + + let s = Rq::rand_f64(&mut rng, Xi_key)?; + let e = Rq::rand_f64(&mut rng, Xi_err)?; + let u = Rq::rand_f64(&mut rng, Xi_key)?; + let e_0 = Rq::rand_f64(&mut rng, Xi_err)?; + let e_1 = Rq::rand_f64(&mut rng, Xi_err)?; + let m = Rq::rand_u64(&mut rng, Uniform::new(0_u64, T))?; + + // v_fresh + let v: Rq = u * e + e_1 * s + e_0; + + let q: f64 = Q as f64; + let t: f64 = T as f64; + let n: f64 = N as f64; + let delta: f64 = DELTA as f64; + + // r_t(q)/t should be equal to q/t-Δ + assert_eq!( + // r_t(q)/t, where r_t(q)=q mod t + (q % t) / t, + // Δt/Q = q - r_t(Q)/Q, so r_t(Q)=q - Δt + (q / t) - delta + ); + let rt: f64 = (q % t) / t; + dbg!(&rt); + + dbg!(v.infinity_norm()); + let bound: f64 = (q / (2_f64 * t)) - (rt / 2_f64); + dbg!(bound); + assert!((v.infinity_norm() as f64) < bound); + let max_v_infnorm = bound - 1.0; + + // addition noise + let v_add: Rq = v + v + u * rt; + let v_add: Rq = v_add + v_add + u * rt; + assert!((v_add.infinity_norm() as f64) < bound); + + // multiplication noise + let (_, pk) = BFV::::new_key(&mut rng)?; + let c = BFV::::encrypt(&mut rng, &pk, &m.remodule::())?; + let b_key: f64 = 1_f64; + // ef: expansion factor + let ef: f64 = 2.0 * n.sqrt(); + let bound: f64 = ((ef * t) / 2.0) + * ((2.0 * max_v_infnorm * max_v_infnorm) / q + + (4.0 + ef * b_key) * (max_v_infnorm + max_v_infnorm) + + rt * (ef * b_key + 5.0)) + + (1.0 + ef * b_key + ef * ef * b_key * b_key) / 2.0; + dbg!(&bound); + + let k: Vec = (c.0 + c.1 * s - m * delta - v) + .coeffs() + .iter() + .map(|e_i| e_i.0 as f64 / q) + .collect(); + let k = Rq::from_vec_f64(k); + let v_tensor_0 = (v * v) + .coeffs() + .iter() + .map(|e_i| (e_i.0 as f64 * t) / q) + .collect::>(); + let v_tensor_0 = Rq::from_vec_f64(v_tensor_0); + let v_tensor_1 = ((m * v) + (m * v)) + .coeffs() + .iter() + .map(|e_i| (e_i.0 as f64 * t * delta) / q) + .collect::>(); + let v_tensor_1 = Rq::from_vec_f64(v_tensor_1); + let v_tensor_2: Rq = (v * k + v * k) * t; + let rm: f64 = (ef * t) / 2.0; + let rm: Rq = Rq::from_vec_f64(vec![rm; N]); + let v_tensor_3: Rq = (m * k + + m * k + + rm + + Rq::from_vec_f64( + ((m * m) * DELTA) + .coeffs() + .iter() + .map(|e_i| e_i.0 as f64 / q) + .collect::>(), + )) + * rt; + let v_tensor = v_tensor_0 + v_tensor_1 + v_tensor_2 - v_tensor_3; + + let v_r = (1.0 + ef * b_key + ef * ef * b_key * b_key) / 2.0; + let v_mult_norm = v_tensor.infinity_norm() as f64 + v_r; + dbg!(&v_mult_norm); + dbg!(&bound); + assert!(v_mult_norm < bound); + + // let m1 = Rq::::zero(); + // let m2 = Rq::::zero(); + // let (_, pk) = BFV::::new_key(&mut rng)?; + // let c1 = BFV::::encrypt(&mut rng, &pk, &m1)?; + // let c2 = BFV::::encrypt(&mut rng, &pk, &m2)?; + // let (c_a, c_b, c_c) = RLWE::tensor::(&c1, &c2); + // dbg!(&c_a.infinity_norm()); + // dbg!(&c_b.infinity_norm()); + // dbg!(&c_c.infinity_norm()); + // assert!((c_a.infinity_norm() as f64) < bound); + // assert!((c_b.infinity_norm() as f64) < bound); + // assert!((c_c.infinity_norm() as f64) < bound); + // WIP + + Ok(()) + } + */ #[test] fn test_tensor() -> Result<()> { - const Q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape - const N: usize = 16; - const T: u64 = 2; // plaintext modulus - - // const P: u64 = Q; - const P: u64 = Q * Q; - const PQ: u64 = P * Q; - + let q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape + let param = Param { + ring: RingParam { q, n: 16 }, + t: 2, // plaintext modulus + p: q * q, + }; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..1_000 { - let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; - let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; + let m1 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let m2 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; - test_tensor_opt::(&mut rng, m1, m2)?; + test_tensor_opt(&mut rng, ¶m, m1, m2)?; } Ok(()) } - fn test_tensor_opt( - mut rng: impl Rng, - m1: Rq, - m2: Rq, - ) -> Result<()> { - let (sk, pk) = BFV::::new_key(&mut rng)?; + fn test_tensor_opt(mut rng: impl Rng, param: &Param, m1: Rq, m2: Rq) -> Result<()> { + let (sk, pk) = BFV::new_key(&mut rng, ¶m)?; - let c1 = BFV::::encrypt(&mut rng, &pk, &m1)?; - let c2 = BFV::::encrypt(&mut rng, &pk, &m2)?; + let c1 = BFV::encrypt(&mut rng, ¶m, &pk, &m1)?; + let c2 = BFV::encrypt(&mut rng, ¶m, &pk, &m2)?; - let (c_a, c_b, c_c) = RLWE::::tensor::(&c1, &c2); - // let (c_a, c_b, c_c) = RLWE::::tensor_new::(&c1, &c2); + let (c_a, c_b, c_c) = RLWE::tensor(param.t, &c1, &c2); + // let (c_a, c_b, c_c) = RLWE::tensor_new::(&c1, &c2); // decrypt non-relinearized mul result - let m3: Rq = c_a + c_b * sk.0 + c_c * sk.0 * sk.0; + let m3: Rq = c_a + &c_b * &sk.0 + &c_c * &(&sk.0 * &sk.0); + // let m3: Rq = c_a - // + Rq::::from_vec_i64(arith::ring_n::naive_mul(&c_b.to_r(), &sk.0.to_r())) - // + Rq::::from_vec_i64(arith::ring_n::naive_mul( + // + Rq::from_vec_i64(arith::ring_n::naive_mul(&c_b.to_r(), &sk.0.to_r())) + // + Rq::from_vec_i64(arith::ring_n::naive_mul( // &c_c.to_r(), // &R::::from_vec(arith::ring_n::naive_mul(&sk.0.to_r(), &sk.0.to_r())), // )); - let m3: Rq = m3.mul_div_round(T, Q); // descale - let m3 = m3.remodule::(); + let m3: Rq = m3.mul_div_round(param.t, param.ring.q); // descale + let m3 = m3.remodule(param.t); - let naive = (m1.to_r() * m2.to_r()).to_rq::(); + let naive = (m1.clone().to_r() * m2.clone().to_r()).to_rq(param.t); // TODO rm clones assert_eq!( m3.coeffs().to_vec(), naive.coeffs().to_vec(), @@ -510,44 +556,39 @@ mod tests { #[test] fn test_mul_relin() -> Result<()> { - const Q: u64 = 2u64.pow(16) + 1; - const N: usize = 16; - const T: u64 = 2; // plaintext modulus - type S = BFV; - - const P: u64 = Q * Q; - const PQ: u64 = P * Q; + let q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape + let param = Param { + ring: RingParam { q, n: 16 }, + t: 2, // plaintext modulus + p: q * q, + }; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..1_000 { - let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; - let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; + let m1 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let m2 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; - test_mul_relin_opt::(&mut rng, m1, m2)?; + test_mul_relin_opt(&mut rng, ¶m, m1, m2)?; } Ok(()) } - fn test_mul_relin_opt( - mut rng: impl Rng, - m1: Rq, - m2: Rq, - ) -> Result<()> { - let (sk, pk) = BFV::::new_key(&mut rng)?; + fn test_mul_relin_opt(mut rng: impl Rng, param: &Param, m1: Rq, m2: Rq) -> Result<()> { + let (sk, pk) = BFV::new_key(&mut rng, ¶m)?; - let rlk = BFV::::rlk_key::(&mut rng, &sk)?; + let rlk = BFV::rlk_key(&mut rng, ¶m, &sk)?; - let c1 = BFV::::encrypt(&mut rng, &pk, &m1)?; - let c2 = BFV::::encrypt(&mut rng, &pk, &m2)?; + let c1 = BFV::encrypt(&mut rng, ¶m, &pk, &m1)?; + let c2 = BFV::encrypt(&mut rng, ¶m, &pk, &m2)?; - let c3 = RLWE::::mul::(&rlk, &c1, &c2); // uses relinearize internally + let c3 = RLWE::mul(param.t, &rlk, &c1, &c2); // uses relinearize internally - let m3 = BFV::::decrypt(&sk, &c3); + let m3 = BFV::decrypt(¶m, &sk, &c3); - let naive = (m1.to_r() * m2.to_r()).to_rq::(); + let naive = (m1.clone().to_r() * m2.clone().to_r()).to_rq(param.t); // TODO rm clones assert_eq!( m3.coeffs().to_vec(), naive.coeffs().to_vec(), diff --git a/ckks/src/encoder.rs b/ckks/src/encoder.rs index a33378c..9e46ed8 100644 --- a/ckks/src/encoder.rs +++ b/ckks/src/encoder.rs @@ -1,14 +1,15 @@ use anyhow::Result; -use arith::{Matrix, Ring, Rq, C, R}; +use arith::{Matrix, Rq, C, R}; #[derive(Clone, Debug)] -pub struct SecretKey(Rq); +pub struct SecretKey(Rq); #[derive(Clone, Debug)] -pub struct PublicKey(Rq, Rq); +pub struct PublicKey(Rq, Rq); -pub struct Encoder { +pub struct Encoder { + n: usize, scale_factor: C, // Δ (delta) primitive: C, basis: Matrix>, @@ -34,13 +35,14 @@ fn vandermonde(n: usize, w: C) -> Matrix> { } Matrix::>(v) } -impl Encoder { - pub fn new(scale_factor: C) -> Self { - let primitive: C = primitive_root_of_unity(2 * N); - let basis = vandermonde(N, primitive); +impl Encoder { + pub fn new(n: usize, scale_factor: C) -> Self { + let primitive: C = primitive_root_of_unity(2 * n); + let basis = vandermonde(n, primitive); let basis_t = basis.transpose(); Self { + n, scale_factor, primitive, basis, @@ -52,7 +54,7 @@ impl Encoder { /// from $\mathbb{C}^{N/2} \longrightarrow \mathbb{Z_q}[X]/(X^N +1) = R$ // TODO use alg.1 from 2018-1043, // or as in 2018-1073: $f(x) = 1N (U^T.conj() m + U^T m.conj())$ - pub fn encode(&self, z: &[C]) -> Result> { + pub fn encode(&self, z: &[C]) -> Result { // $pi^{-1}: \mathbb{C}^{N/2} \longrightarrow \mathbb{H}$ let expanded = self.pi_inv(z); @@ -93,10 +95,10 @@ impl Encoder { // TMP: naive round, maybe do gaussian let coeffs = r.iter().map(|e| e.re.round() as i64).collect::>(); - Ok(R::from_vec(coeffs)) + Ok(R::from_vec(self.n, coeffs)) } - pub fn decode(&self, p: &R) -> Result>> { + pub fn decode(&self, p: &R) -> Result>> { let p: Vec> = p .coeffs() .iter() @@ -110,7 +112,7 @@ impl Encoder { /// pi: \mathbb{H} \longrightarrow \mathbb{C}^{N/2} fn pi(&self, z: &[C]) -> Vec> { - z[..N / 2].to_vec() + z[..self.n / 2].to_vec() } /// pi^{-1}: \mathbb{C}^{N/2} \longrightarrow \mathbb{H} fn pi_inv(&self, z: &[C]) -> Vec> { @@ -154,6 +156,7 @@ mod tests { fn test_encode_decode() -> Result<()> { const Q: u64 = 1024; const N: usize = 32; + let n: usize = 32; let T = 128; // WIP let mut rng = rand::thread_rng(); @@ -166,9 +169,9 @@ mod tests { .collect(); let delta = C::::new(64.0, 0.0); // delta = scaling factor - let encoder = Encoder::::new(delta); + let encoder = Encoder::new(n, delta); - let m: R = encoder.encode(&z)?; // polynomial (encoded vec) \in R + let m: R = encoder.encode(&z)?; // polynomial (encoded vec) \in R let z_decoded = encoder.decode(&m)?; diff --git a/ckks/src/lib.rs b/ckks/src/lib.rs index 1552d03..1fb8bf4 100644 --- a/ckks/src/lib.rs +++ b/ckks/src/lib.rs @@ -5,7 +5,7 @@ #![allow(clippy::upper_case_acronyms)] #![allow(dead_code)] // TMP -use arith::{Rq, C, R}; +use arith::{RingParam, Rq, C, R}; use anyhow::Result; use rand::Rng; @@ -18,35 +18,47 @@ pub use encoder::Encoder; // sigma=3.2 from: https://eprint.iacr.org/2016/421.pdf page 17 const ERR_SIGMA: f64 = 3.2; +#[derive(Clone, Copy, Debug)] +pub struct Param { + ring: RingParam, + t: u64, +} + #[derive(Debug)] -pub struct PublicKey(Rq, Rq); +pub struct PublicKey(Rq, Rq); -pub struct SecretKey(Rq); +pub struct SecretKey(Rq); -pub struct CKKS { - encoder: Encoder, +pub struct CKKS { + param: Param, + encoder: Encoder, } -impl CKKS { - pub fn new(delta: C) -> Self { - let encoder = Encoder::::new(delta); - Self { encoder } +impl CKKS { + pub fn new(param: &Param, delta: C) -> Self { + let encoder = Encoder::new(param.ring.n, delta); + Self { + param: param.clone(), + encoder, + } } /// generate a new key pair (privK, pubK) - pub fn new_key(&self, mut rng: impl Rng) -> Result<(SecretKey, PublicKey)> { + pub fn new_key(&self, mut rng: impl Rng) -> Result<(SecretKey, PublicKey)> { + let param = &self.param; + let Xi_key = Uniform::new(-1_f64, 1_f64); let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; - let e = Rq::::rand_f64(&mut rng, Xi_err)?; + let e = Rq::rand_f64(&mut rng, Xi_err, ¶m.ring)?; - let mut s = Rq::::rand_f64(&mut rng, Xi_key)?; + let mut s = Rq::rand_f64(&mut rng, Xi_key, ¶m.ring)?; // since s is going to be multiplied by other Rq elements, already // compute its NTT s.compute_evals(); - let a = Rq::::rand_f64(&mut rng, Xi_key)?; + let a = Rq::rand_f64(&mut rng, Xi_key, ¶m.ring)?; - let pk: PublicKey = PublicKey((&(-a) * &s) + e, a.clone()); + let pk: PublicKey = PublicKey((&(-a.clone()) * &s) + e, a.clone()); // TODO rm clones Ok((SecretKey(s), pk)) } @@ -54,64 +66,54 @@ impl CKKS { fn encrypt( &self, // TODO maybe rm? mut rng: impl Rng, - pk: &PublicKey, - m: &R, - ) -> Result<(Rq, Rq)> { + pk: &PublicKey, + m: &R, + ) -> Result<(Rq, Rq)> { + let param = self.param; let Xi_key = Uniform::new(-1_f64, 1_f64); let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; - let e_0 = Rq::::rand_f64(&mut rng, Xi_err)?; - let e_1 = Rq::::rand_f64(&mut rng, Xi_err)?; + let e_0 = Rq::rand_f64(&mut rng, Xi_err, ¶m.ring)?; + let e_1 = Rq::rand_f64(&mut rng, Xi_err, ¶m.ring)?; - let v = Rq::::rand_f64(&mut rng, Xi_key)?; + let v = Rq::rand_f64(&mut rng, Xi_key, ¶m.ring)?; - let m: Rq = Rq::::from(*m); + // let m: Rq = Rq::from(*m); + let m: Rq = m.clone().to_rq(param.ring.q); // TODO rm clone - Ok((m + e_0 + v * pk.0.clone(), v * pk.1.clone() + e_1)) + Ok((m + e_0 + &v * &pk.0.clone(), &v * &pk.1 + e_1)) } fn decrypt( &self, // TODO maybe rm? - sk: &SecretKey, - c: (Rq, Rq), - ) -> Result> { - let m = c.0.clone() + c.1 * sk.0; + sk: &SecretKey, + c: (Rq, Rq), + ) -> Result { + let m = c.0.clone() + &c.1 * &sk.0; Ok(m.mod_centered_q()) } pub fn encode_and_encrypt( &self, mut rng: impl Rng, - pk: &PublicKey, + pk: &PublicKey, z: &[C], - ) -> Result<(Rq, Rq)> { - let m: R = self.encoder.encode(&z)?; // polynomial (encoded vec) \in R + ) -> Result<(Rq, Rq)> { + let m: R = self.encoder.encode(&z)?; // polynomial (encoded vec) \in R self.encrypt(&mut rng, pk, &m) } - pub fn decrypt_and_decode( - &self, - sk: SecretKey, - c: (Rq, Rq), - ) -> Result>> { + pub fn decrypt_and_decode(&self, sk: SecretKey, c: (Rq, Rq)) -> Result>> { let d = self.decrypt(&sk, c)?; self.encoder.decode(&d) } - pub fn add( - &self, - c0: &(Rq, Rq), - c1: &(Rq, Rq), - ) -> Result<(Rq, Rq)> { + pub fn add(&self, c0: &(Rq, Rq), c1: &(Rq, Rq)) -> Result<(Rq, Rq)> { Ok((&c0.0 + &c1.0, &c0.1 + &c1.1)) } - pub fn sub( - &self, - c0: &(Rq, Rq), - c1: &(Rq, Rq), - ) -> Result<(Rq, Rq)> { + pub fn sub(&self, c0: &(Rq, Rq), c1: &(Rq, Rq)) -> Result<(Rq, Rq)> { Ok((&c0.0 - &c1.0, &c0.1 + &c1.1)) } } @@ -122,21 +124,26 @@ mod tests { #[test] fn test_encrypt_decrypt() -> Result<()> { - const Q: u64 = 2u64.pow(16) + 1; - const N: usize = 32; - const T: u64 = 50; + let q: u64 = 2u64.pow(16) + 1; + let n: usize = 32; + let t: u64 = 50; + let param = Param { + ring: RingParam { q, n }, + t, + }; let scale_factor_u64 = 512_u64; // delta let scale_factor = C::::new(scale_factor_u64 as f64, 0.0); // delta let mut rng = rand::thread_rng(); for _ in 0..1000 { - let ckks = CKKS::::new(scale_factor); + let ckks = CKKS::new(¶m, scale_factor); let (sk, pk) = ckks.new_key(&mut rng)?; - let m_raw: R = Rq::::rand_f64(&mut rng, Uniform::new(0_f64, T as f64))?.to_r(); - let m = m_raw * scale_factor_u64; + let m_raw: R = + Rq::rand_f64(&mut rng, Uniform::new(0_f64, t as f64), ¶m.ring)?.to_r(); + let m = &m_raw * &scale_factor_u64; let ct = ckks.encrypt(&mut rng, &pk, &m)?; let m_decrypted = ckks.decrypt(&sk, ct)?; @@ -146,8 +153,8 @@ mod tests { .iter() .map(|e| (*e as f64 / (scale_factor_u64 as f64)).round() as u64) .collect(); - let m_decrypted = Rq::::from_vec_u64(m_decrypted); - assert_eq!(m_decrypted, Rq::::from(m_raw)); + let m_decrypted = Rq::from_vec_u64(¶m.ring, m_decrypted); + assert_eq!(m_decrypted, m_raw.to_rq(q)); } Ok(()) @@ -155,21 +162,25 @@ mod tests { #[test] fn test_encode_encrypt_decrypt_decode() -> Result<()> { - const Q: u64 = 2u64.pow(16) + 1; - const N: usize = 16; - const T: u64 = 8; + let q: u64 = 2u64.pow(16) + 1; + let n: usize = 16; + let t: u64 = 8; + let param = Param { + ring: RingParam { q, n }, + t, + }; let scale_factor = C::::new(512.0, 0.0); // delta let mut rng = rand::thread_rng(); for _ in 0..1000 { - let ckks = CKKS::::new(scale_factor); + let ckks = CKKS::new(¶m, scale_factor); let (sk, pk) = ckks.new_key(&mut rng)?; - let z: Vec> = std::iter::repeat_with(|| C::::rand(&mut rng, T)) - .take(N / 2) + let z: Vec> = std::iter::repeat_with(|| C::::rand(&mut rng, t)) + .take(n / 2) .collect(); - let m: R = ckks.encoder.encode(&z)?; + let m: R = ckks.encoder.encode(&z)?; println!("{}", m); // sanity check @@ -200,26 +211,30 @@ mod tests { #[test] fn test_add() -> Result<()> { - const Q: u64 = 2u64.pow(16) + 1; - const N: usize = 16; - const T: u64 = 8; + let q: u64 = 2u64.pow(16) + 1; + let n: usize = 16; + let t: u64 = 8; + let param = Param { + ring: RingParam { q, n }, + t, + }; let scale_factor = C::::new(1024.0, 0.0); // delta let mut rng = rand::thread_rng(); for _ in 0..1000 { - let ckks = CKKS::::new(scale_factor); + let ckks = CKKS::new(¶m, scale_factor); let (sk, pk) = ckks.new_key(&mut rng)?; - let z0: Vec> = std::iter::repeat_with(|| C::::rand(&mut rng, T)) - .take(N / 2) + let z0: Vec> = std::iter::repeat_with(|| C::::rand(&mut rng, t)) + .take(n / 2) .collect(); - let z1: Vec> = std::iter::repeat_with(|| C::::rand(&mut rng, T)) - .take(N / 2) + let z1: Vec> = std::iter::repeat_with(|| C::::rand(&mut rng, t)) + .take(n / 2) .collect(); - let m0: R = ckks.encoder.encode(&z0)?; - let m1: R = ckks.encoder.encode(&z1)?; + let m0: R = ckks.encoder.encode(&z0)?; + let m1: R = ckks.encoder.encode(&z1)?; let ct0 = ckks.encrypt(&mut rng, &pk, &m0)?; let ct1 = ckks.encrypt(&mut rng, &pk, &m1)?; @@ -243,26 +258,30 @@ mod tests { #[test] fn test_sub() -> Result<()> { - const Q: u64 = 2u64.pow(16) + 1; - const N: usize = 16; - const T: u64 = 8; + let q: u64 = 2u64.pow(16) + 1; + let n: usize = 16; + let t: u64 = 2; + let param = Param { + ring: RingParam { q, n }, + t, + }; let scale_factor = C::::new(1024.0, 0.0); // delta let mut rng = rand::thread_rng(); for _ in 0..1000 { - let ckks = CKKS::::new(scale_factor); + let ckks = CKKS::new(¶m, scale_factor); let (sk, pk) = ckks.new_key(&mut rng)?; - let z0: Vec> = std::iter::repeat_with(|| C::::rand(&mut rng, T)) - .take(N / 2) + let z0: Vec> = std::iter::repeat_with(|| C::::rand(&mut rng, t)) + .take(n / 2) .collect(); - let z1: Vec> = std::iter::repeat_with(|| C::::rand(&mut rng, T)) - .take(N / 2) + let z1: Vec> = std::iter::repeat_with(|| C::::rand(&mut rng, t)) + .take(n / 2) .collect(); - let m0: R = ckks.encoder.encode(&z0)?; - let m1: R = ckks.encoder.encode(&z1)?; + let m0: R = ckks.encoder.encode(&z0)?; + let m1: R = ckks.encoder.encode(&z1)?; let ct0 = ckks.encrypt(&mut rng, &pk, &m0)?; let ct1 = ckks.encrypt(&mut rng, &pk, &m1)?; diff --git a/gfhe/src/glev.rs b/gfhe/src/glev.rs index a04f13b..c6a9c05 100644 --- a/gfhe/src/glev.rs +++ b/gfhe/src/glev.rs @@ -1,28 +1,33 @@ use anyhow::Result; use itertools::zip_eq; use rand::Rng; -use rand_distr::{Normal, Uniform}; -use std::ops::{Add, Mul}; +use std::ops::Mul; -use arith::{Ring, TR}; +use arith::Ring; -use crate::glwe::{PublicKey, SecretKey, GLWE}; +use crate::glwe::{Param, PublicKey, SecretKey, GLWE}; // l GLWEs #[derive(Clone, Debug)] -pub struct GLev(pub(crate) Vec>); +pub struct GLev(pub(crate) Vec>); -impl GLev { +impl GLev { pub fn encrypt( mut rng: impl Rng, + param: &Param, beta: u32, l: u32, - pk: &PublicKey, + pk: &PublicKey, m: &R, ) -> Result { - let glev: Vec> = (0..l) + let glev: Vec> = (0..l) .map(|i| { - GLWE::::encrypt(&mut rng, pk, &(*m * (R::Q / beta.pow(i as u32) as u64))) + GLWE::::encrypt( + &mut rng, + param, + pk, + &(m.clone() * (param.ring.q / beta.pow(i as u32) as u64)), + ) }) .collect::>>()?; @@ -30,38 +35,46 @@ impl GLev { } pub fn encrypt_s( mut rng: impl Rng, + param: &Param, beta: u32, l: u32, - sk: &SecretKey, + sk: &SecretKey, m: &R, - // delta: u64, ) -> Result { - let glev: Vec> = (1..l + 1) + let glev: Vec> = (1..l + 1) .map(|i| { - GLWE::::encrypt_s(&mut rng, sk, &(*m * (R::Q / beta.pow(i as u32) as u64))) + GLWE::::encrypt_s( + &mut rng, + param, + sk, + &(m.clone() * (param.ring.q / beta.pow(i as u32) as u64)), // TODO rm clone + ) }) .collect::>>()?; Ok(Self(glev)) } - pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> R { + pub fn decrypt(&self, param: &Param, sk: &SecretKey, beta: u32) -> R { let pt = self.0[1].decrypt(sk); - pt.mul_div_round(beta as u64, R::Q) + pt.mul_div_round(beta as u64, param.ring.q) } } // dot product between a GLev and Vec. // Used for operating decompositions with KSK_i. // GLev * Vec --> GLWE -impl Mul> for GLev { - type Output = GLWE; - fn mul(self, v: Vec) -> GLWE { +impl Mul> for GLev { + type Output = GLWE; + fn mul(self, v: Vec) -> GLWE { + debug_assert_eq!(self.0.len(), v.len()); + // TODO debug_assert_eq of param + // l times GLWES - let glwes: Vec> = self.0; + let glwes: Vec> = self.0; // l iterations - let r: GLWE = zip_eq(v, glwes).map(|(v_i, glwe_i)| glwe_i * v_i).sum(); + let r: GLWE = zip_eq(v, glwes).map(|(v_i, glwe_i)| glwe_i * v_i).sum(); r } } @@ -72,33 +85,37 @@ mod tests { use rand::distributions::Uniform; use super::*; - use arith::Rq; + use arith::{RingParam, Rq}; #[test] fn test_encrypt_decrypt() -> Result<()> { - const Q: u64 = 2u64.pow(16) + 1; - const N: usize = 128; - const T: u64 = 2; // plaintext modulus - const K: usize = 16; - type S = GLev, K>; + let param = Param { + err_sigma: crate::glwe::ERR_SIGMA, + ring: RingParam { + q: 2u64.pow(16) + 1, + n: 128, + }, + k: 16, + t: 2, // plaintext modulus + }; + type S = GLev; let beta: u32 = 2; let l: u32 = 16; - // let delta: u64 = Q / T; // floored let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..200 { - let (sk, pk) = GLWE::, K>::new_key(&mut rng)?; + let (sk, pk) = GLWE::::new_key(&mut rng, ¶m)?; - let m = Rq::::rand_u64(&mut rng, msg_dist)?; - let m: Rq = m.remodule::(); + let m = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let m: Rq = m.remodule(param.ring.q); - let c = S::encrypt(&mut rng, beta, l, &pk, &m)?; - let m_recovered = c.decrypt::(&sk, beta); + let c = S::encrypt(&mut rng, ¶m, beta, l, &pk, &m)?; + let m_recovered = c.decrypt(¶m, &sk, beta); - assert_eq!(m.remodule::(), m_recovered.remodule::()); + assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t)); } Ok(()) diff --git a/gfhe/src/glwe.rs b/gfhe/src/glwe.rs index f1aed8b..df3a830 100644 --- a/gfhe/src/glwe.rs +++ b/gfhe/src/glwe.rs @@ -8,79 +8,128 @@ use rand_distr::{Normal, Uniform}; use std::iter::Sum; use std::ops::{Add, AddAssign, Mul, Sub}; -use arith::{Ring, Rq, Zq, TR}; +use arith::{Ring, RingParam, Rq, Zq, TR}; use crate::glev::GLev; -// const ERR_SIGMA: f64 = 3.2; -const ERR_SIGMA: f64 = 0.0; // TODO WIP +// error deviation for the Gaussian(Normal) distribution +// sigma=3.2 from: https://eprint.iacr.org/2022/162.pdf page 5 +pub(crate) const ERR_SIGMA: f64 = 3.2; + +#[derive(Clone, Copy, Debug)] +pub struct Param { + pub err_sigma: f64, + pub ring: RingParam, + pub k: usize, + pub t: u64, +} +impl Param { + /// returns the plaintext param + pub fn pt(&self) -> RingParam { + // TODO think if maybe return a new truct "PtParam" to differenciate + // between the ciphertexxt (RingParam) and the plaintext param. Maybe it + // can be just a wrapper on top of RingParam. + RingParam { + q: self.t, + n: self.ring.n, + } + } + /// returns the LWE param for the given GLWE (self), that is, it uses k=K*N + /// as the length for the secret key. This follows [2018-421] where + /// TLWE sk: s \in B^n , where n=K*N + /// TRLWE sk: s \in B_N[X]^K + pub fn lwe(&self) -> Self { + Self { + err_sigma: ERR_SIGMA, + ring: RingParam { + q: self.ring.q, + n: 1, + }, + k: self.k * self.ring.n, + t: self.t, + } + } +} /// GLWE implemented over the `Ring` trait, so that it can be also instantiated /// over the Torus polynomials 𝕋_[X] = 𝕋_q[X]/ (X^N+1). #[derive(Clone, Debug)] -pub struct GLWE(pub TR, pub R); +pub struct GLWE(pub TR, pub R); #[derive(Clone, Debug)] -pub struct SecretKey(pub TR); +pub struct SecretKey(pub TR); #[derive(Clone, Debug)] -pub struct PublicKey(pub R, pub TR); +pub struct PublicKey(pub R, pub TR); // K GLevs, each KSK_i=l GLWEs #[derive(Clone, Debug)] -pub struct KSK(Vec>); +pub struct KSK(Vec>); -impl GLWE { - pub fn zero() -> Self { - Self(TR::zero(), R::zero()) +impl GLWE { + pub fn zero(k: usize, param: &RingParam) -> Self { + Self(TR::zero(k, ¶m), R::zero(¶m)) } - pub fn from_plaintext(p: R) -> Self { - Self(TR::zero(), p) + pub fn from_plaintext(k: usize, param: &RingParam, p: R) -> Self { + Self(TR::zero(k, ¶m), p) } - pub fn new_key(mut rng: impl Rng) -> Result<(SecretKey, PublicKey)> { + pub fn new_key(mut rng: impl Rng, param: &Param) -> Result<(SecretKey, PublicKey)> { let Xi_key = Uniform::new(0_f64, 2_f64); - let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; - - let s: TR = TR::rand(&mut rng, Xi_key); - let a: TR = TR::rand(&mut rng, Uniform::new(0_f64, R::Q as f64)); - let e = R::rand(&mut rng, Xi_err); - - let pk: PublicKey = PublicKey((&a * &s) + e, a); + let Xi_err = Normal::new(0_f64, param.err_sigma)?; + + let s: TR = TR::rand(&mut rng, Xi_key, param.k, ¶m.ring); + let a: TR = TR::rand( + &mut rng, + Uniform::new(0_f64, param.ring.q as f64), + param.k, + ¶m.ring, + ); + let e = R::rand(&mut rng, Xi_err, ¶m.ring); + + let pk: PublicKey = PublicKey((&a * &s) + e, a); Ok((SecretKey(s), pk)) } - pub fn pk_from_sk(mut rng: impl Rng, sk: SecretKey) -> Result> { - let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; - - let a: TR = TR::rand(&mut rng, Uniform::new(0_f64, R::Q as f64)); - let e = R::rand(&mut rng, Xi_err); - - let pk: PublicKey = PublicKey((&a * &sk.0) + e, a); + pub fn pk_from_sk(mut rng: impl Rng, param: &Param, sk: SecretKey) -> Result> { + let Xi_err = Normal::new(0_f64, param.err_sigma)?; + + let a: TR = TR::rand( + &mut rng, + Uniform::new(0_f64, param.ring.q as f64), + param.k, + ¶m.ring, + ); + let e = R::rand(&mut rng, Xi_err, ¶m.ring); + + let pk: PublicKey = PublicKey((&a * &sk.0) + e, a); Ok(pk) } pub fn new_ksk( mut rng: impl Rng, + param: &Param, beta: u32, l: u32, - sk: &SecretKey, - new_sk: &SecretKey, - ) -> Result> { - let r: Vec> = (0..K) + sk: &SecretKey, + new_sk: &SecretKey, + ) -> Result> { + debug_assert_eq!(param.k, sk.0.k); + let k = sk.0.k; + let r: Vec> = (0..k) .into_iter() .map(|i| // treat sk_i as the msg being encrypted - GLev::::encrypt_s(&mut rng, beta, l, &new_sk, &sk.0 .0[i])) + GLev::::encrypt_s(&mut rng, param, beta, l, &new_sk, &sk.0 .r[i])) .collect::>>()?; Ok(KSK(r)) } - pub fn key_switch(&self, beta: u32, l: u32, ksk: &KSK) -> Self { - let (a, b): (TR, R) = (self.0.clone(), self.1); + pub fn key_switch(&self, param: &Param, beta: u32, l: u32, ksk: &KSK) -> Self { + let (a, b): (TR, R) = (self.0.clone(), self.1.clone()); // TODO rm clones - let lhs: GLWE = GLWE(TR::zero(), b); + let lhs: GLWE = GLWE(TR::zero(param.k, ¶m.ring), b); // K iterations, ksk.0 contains K times GLev - let rhs: GLWE = zip_eq(a.0, ksk.0.clone()) + let rhs: GLWE = zip_eq(a.r, ksk.0.clone()) .map(|(a_i, ksk_i)| ksk_i * a_i.decompose(beta, l)) // dot_product .sum(); @@ -90,121 +139,141 @@ impl GLWE { // encrypts with the given SecretKey (instead of PublicKey) pub fn encrypt_s( mut rng: impl Rng, - sk: &SecretKey, + param: &Param, + sk: &SecretKey, m: &R, // already scaled ) -> Result { let Xi_key = Uniform::new(0_f64, 2_f64); - let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; + let Xi_err = Normal::new(0_f64, param.err_sigma)?; - let a: TR = TR::rand(&mut rng, Xi_key); - let e = R::rand(&mut rng, Xi_err); + let a: TR = TR::rand(&mut rng, Xi_key, param.k, ¶m.ring); + let e = R::rand(&mut rng, Xi_err, ¶m.ring); - let b: R = (&a * &sk.0) + *m + e; + let b: R = (&a * &sk.0) + m.clone() + e; // TODO rm clone Ok(Self(a, b)) } pub fn encrypt( mut rng: impl Rng, - pk: &PublicKey, + param: &Param, + pk: &PublicKey, m: &R, // already scaled ) -> Result { let Xi_key = Uniform::new(0_f64, 2_f64); - let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; + let Xi_err = Normal::new(0_f64, param.err_sigma)?; - let u: R = R::rand(&mut rng, Xi_key); + let u: R = R::rand(&mut rng, Xi_key, ¶m.ring); - let e0 = R::rand(&mut rng, Xi_err); - let e1 = TR::::rand(&mut rng, Xi_err); + let e0 = R::rand(&mut rng, Xi_err, ¶m.ring); + let e1 = TR::::rand(&mut rng, Xi_err, param.k, ¶m.ring); - let b: R = pk.0.clone() * u.clone() + *m + e0; - let d: TR = &pk.1 * &u + e1; + let b: R = pk.0.clone() * u.clone() + m.clone() + e0; // TODO rm clones + let d: TR = &pk.1 * &u + e1; Ok(Self(d, b)) } // returns m' not downscaled - pub fn decrypt(&self, sk: &SecretKey) -> R { - let (d, b): (TR, R) = (self.0.clone(), self.1); + pub fn decrypt(&self, sk: &SecretKey) -> R { + let (d, b): (TR, R) = (self.0.clone(), self.1.clone()); let p: R = b - &d * &sk.0; p } } // Methods for when Ring=Rq -impl GLWE, K> { +impl GLWE { // scale up - pub fn encode(m: &Rq) -> Rq { - let m = m.remodule::(); - let delta = Q / T; // floored + pub fn encode(param: &Param, m: &Rq) -> Rq { + debug_assert_eq!(param.t, m.param.q); + let m = m.remodule(param.ring.q); + let delta = param.ring.q / param.t; // floored m * delta } // scale down - pub fn decode(m: &Rq) -> Rq { - let r = m.mul_div_round(T, Q); - let r: Rq = r.remodule::(); + pub fn decode(param: &Param, m: &Rq) -> Rq { + let r = m.mul_div_round(param.t, param.ring.q); + let r: Rq = r.remodule(param.t); r } - pub fn mod_switch(&self) -> GLWE, K> { - let a: TR, K> = TR(self - .0 - .0 - .iter() - .map(|r| r.mod_switch::

()) - .collect::>()); - let b: Rq = self.1.mod_switch::

(); + pub fn mod_switch(&self, p: u64) -> GLWE { + let a: TR = TR { + k: self.0.k, + r: self.0.r.iter().map(|r| r.mod_switch(p)).collect::>(), + }; + let b: Rq = self.1.mod_switch(p); GLWE(a, b) } } -impl Add> for GLWE { +impl Add> for GLWE { type Output = Self; fn add(self, other: Self) -> Self { - let a: TR = self.0 + other.0; + debug_assert_eq!(self.0.k, other.0.k); + debug_assert_eq!(self.1.param(), other.1.param()); + + let a: TR = self.0 + other.0; let b: R = self.1 + other.1; Self(a, b) } } -impl Add for GLWE { +impl Add for GLWE { type Output = Self; fn add(self, plaintext: R) -> Self { - let a: TR = self.0; + debug_assert_eq!(self.1.param(), plaintext.param()); + + let a: TR = self.0; let b: R = self.1 + plaintext; Self(a, b) } } -impl AddAssign for GLWE { +impl AddAssign for GLWE { fn add_assign(&mut self, rhs: Self) { - for i in 0..K { - self.0 .0[i] = self.0 .0[i].clone() + rhs.0 .0[i].clone(); + debug_assert_eq!(self.0.k, rhs.0.k); + debug_assert_eq!(self.1.param(), rhs.1.param()); + + let k = self.0.k; + for i in 0..k { + self.0.r[i] = self.0.r[i].clone() + rhs.0.r[i].clone(); } self.1 = self.1.clone() + rhs.1.clone(); } } -impl Sum> for GLWE { - fn sum(iter: I) -> Self +impl Sum> for GLWE { + fn sum(mut iter: I) -> Self where I: Iterator, { - let mut acc = GLWE::::zero(); - for e in iter { - acc += e; - } - acc + let first = iter.next().unwrap(); + iter.fold(first, |acc, e| acc + e) } } -impl Sub> for GLWE { +impl Sub> for GLWE { type Output = Self; fn sub(self, other: Self) -> Self { - let a: TR = self.0 - other.0; + debug_assert_eq!(self.0.k, other.0.k); + debug_assert_eq!(self.1.param(), other.1.param()); + + let a: TR = self.0 - other.0; let b: R = self.1 - other.1; Self(a, b) } } -impl Mul for GLWE { +impl Mul for GLWE { type Output = Self; fn mul(self, plaintext: R) -> Self { - let a: TR = TR(self.0 .0.iter().map(|r_i| *r_i * plaintext).collect()); + debug_assert_eq!(self.1.param(), plaintext.param()); + + let a: TR = TR { + k: self.0.k, + r: self + .0 + .r + .iter() + .map(|r_i| r_i.clone() * plaintext.clone()) + .collect(), + }; let b: R = self.1 * plaintext; Self(a, b) } @@ -255,77 +324,90 @@ mod tests { use super::*; #[test] - fn test_encrypt_decrypt() -> Result<()> { - const Q: u64 = 2u64.pow(16) + 1; - const N: usize = 128; - const T: u64 = 32; // plaintext modulus - const K: usize = 16; - type S = GLWE, K>; + fn test_encrypt_decrypt_ring_nq() -> Result<()> { + let param = Param { + err_sigma: ERR_SIGMA, + ring: RingParam { + q: 2u64.pow(16) + 1, + n: 128, + }, + k: 16, + t: 32, // plaintext modulus + }; + type S = GLWE; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..200 { - let (sk, pk) = S::new_key(&mut rng)?; + let (sk, pk) = S::new_key(&mut rng, ¶m)?; - let m = Rq::::rand_u64(&mut rng, msg_dist)?; // msg - // let m: Rq = m.remodule::(); + let m = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; // msg + let p = S::encode(¶m, &m); // plaintext - let p = S::encode::(&m); // plaintext - let c = S::encrypt(&mut rng, &pk, &p)?; // ciphertext + let c = S::encrypt(&mut rng, ¶m, &pk, &p)?; // ciphertext let p_recovered = c.decrypt(&sk); - let m_recovered = S::decode::(&p_recovered); + let m_recovered = S::decode(¶m, &p_recovered); - assert_eq!(m.remodule::(), m_recovered.remodule::()); + assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t)); // same but using encrypt_s (with sk instead of pk)) - let c = S::encrypt_s(&mut rng, &sk, &p)?; + let c = S::encrypt_s(&mut rng, ¶m, &sk, &p)?; let p_recovered = c.decrypt(&sk); - let m_recovered = S::decode::(&p_recovered); + let m_recovered = S::decode(¶m, &p_recovered); - assert_eq!(m.remodule::(), m_recovered.remodule::()); + assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t)); } Ok(()) } use arith::{Tn, T64}; - use std::array; - pub fn t_encode(m: &Rq) -> Tn<4> { - let delta = u64::MAX / P; // floored + pub fn t_encode(param: &RingParam, m: &Rq) -> Tn { + let p = m.param.q; // plaintext space + let delta = u64::MAX / p; // floored let coeffs = m.coeffs(); - Tn(array::from_fn(|i| T64(coeffs[i].0 * delta))) + Tn { + param: *param, + coeffs: coeffs.iter().map(|c_i| T64(c_i.v * delta)).collect(), + } } - pub fn t_decode(p: &Tn<4>) -> Rq { - let p = p.mul_div_round(P, u64::MAX); - Rq::::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect()) + pub fn t_decode(param: &Param, pt: &Tn) -> Rq { + let pt = pt.mul_div_round(param.t, u64::MAX); + Rq::from_vec_u64(¶m.pt(), pt.coeffs().iter().map(|c| c.0).collect()) } #[test] fn test_encrypt_decrypt_torus() -> Result<()> { - const N: usize = 128; - const T: u64 = 32; // plaintext modulus - const K: usize = 16; - type S = GLWE, K>; + let param = Param { + err_sigma: ERR_SIGMA, + ring: RingParam { + q: u64::MAX, + n: 128, + }, + k: 16, + t: 32, // plaintext modulus + }; + type S = GLWE; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_f64, T as f64); + let msg_dist = Uniform::new(0_f64, param.t as f64); for _ in 0..200 { - let (sk, pk) = S::new_key(&mut rng)?; + let (sk, pk) = S::new_key(&mut rng, ¶m)?; - let m = Rq::::rand(&mut rng, msg_dist); // msg + let m = Rq::rand(&mut rng, msg_dist, ¶m.pt()); // msg - let p = t_encode::(&m); // plaintext - let c = S::encrypt(&mut rng, &pk, &p)?; // ciphertext + let p = t_encode(¶m.ring, &m); // plaintext + let c = S::encrypt(&mut rng, ¶m, &pk, &p)?; // ciphertext let p_recovered = c.decrypt(&sk); - let m_recovered = t_decode::(&p_recovered); + let m_recovered = t_decode(¶m, &p_recovered); assert_eq!(m, m_recovered); // same but using encrypt_s (with sk instead of pk)) - let c = S::encrypt_s(&mut rng, &sk, &p)?; + let c = S::encrypt_s(&mut rng, ¶m, &sk, &p)?; let p_recovered = c.decrypt(&sk); - let m_recovered = t_decode::(&p_recovered); + let m_recovered = t_decode(¶m, &p_recovered); assert_eq!(m, m_recovered); } @@ -335,32 +417,37 @@ mod tests { #[test] fn test_addition() -> Result<()> { - const Q: u64 = 2u64.pow(16) + 1; - const N: usize = 128; - const T: u64 = 20; - const K: usize = 16; - type S = GLWE, K>; + let param = Param { + err_sigma: ERR_SIGMA, + ring: RingParam { + q: 2u64.pow(16) + 1, + n: 128, + }, + k: 16, + t: 20, // plaintext modulus + }; + type S = GLWE; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..200 { - let (sk, pk) = S::new_key(&mut rng)?; + let (sk, pk) = S::new_key(&mut rng, ¶m)?; - let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; - let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; - let p1: Rq = S::encode::(&m1); // plaintext - let p2: Rq = S::encode::(&m2); // plaintext + let m1 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let m2 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p1: Rq = S::encode(¶m, &m1); // plaintext + let p2: Rq = S::encode(¶m, &m2); // plaintext - let c1 = S::encrypt(&mut rng, &pk, &p1)?; - let c2 = S::encrypt(&mut rng, &pk, &p2)?; + let c1 = S::encrypt(&mut rng, ¶m, &pk, &p1)?; + let c2 = S::encrypt(&mut rng, ¶m, &pk, &p2)?; let c3 = c1 + c2; let p3_recovered = c3.decrypt(&sk); - let m3_recovered = S::decode::(&p3_recovered); + let m3_recovered = S::decode(¶m, &p3_recovered); - assert_eq!((m1 + m2).remodule::(), m3_recovered.remodule::()); + assert_eq!((m1 + m2).remodule(param.t), m3_recovered.remodule(param.t)); } Ok(()) @@ -368,31 +455,36 @@ mod tests { #[test] fn test_add_plaintext() -> Result<()> { - const Q: u64 = 2u64.pow(16) + 1; - const N: usize = 128; - const T: u64 = 32; - const K: usize = 16; - type S = GLWE, K>; + let param = Param { + err_sigma: ERR_SIGMA, + ring: RingParam { + q: 2u64.pow(16) + 1, + n: 128, + }, + k: 16, + t: 32, // plaintext modulus + }; + type S = GLWE; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..200 { - let (sk, pk) = S::new_key(&mut rng)?; + let (sk, pk) = S::new_key(&mut rng, ¶m)?; - let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; - let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; - let p1: Rq = S::encode::(&m1); // plaintext - let p2: Rq = S::encode::(&m2); // plaintext + let m1 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let m2 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p1: Rq = S::encode(¶m, &m1); // plaintext + let p2: Rq = S::encode(¶m, &m2); // plaintext - let c1 = S::encrypt(&mut rng, &pk, &p1)?; + let c1 = S::encrypt(&mut rng, ¶m, &pk, &p1)?; let c3 = c1 + p2; let p3_recovered = c3.decrypt(&sk); - let m3_recovered = S::decode::(&p3_recovered); + let m3_recovered = S::decode(¶m, &p3_recovered); - assert_eq!((m1 + m2).remodule::(), m3_recovered.remodule::()); + assert_eq!((m1 + m2).remodule(param.t), m3_recovered.remodule(param.t)); } Ok(()) @@ -400,30 +492,35 @@ mod tests { #[test] fn test_mul_plaintext() -> Result<()> { - const Q: u64 = 2u64.pow(16) + 1; - const N: usize = 16; - const T: u64 = 4; - const K: usize = 16; - type S = GLWE, K>; + let param = Param { + err_sigma: ERR_SIGMA, + ring: RingParam { + q: 2u64.pow(16) + 1, + n: 16, + }, + k: 16, + t: 4, // plaintext modulus + }; + type S = GLWE; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..200 { - let (sk, pk) = S::new_key(&mut rng)?; + let (sk, pk) = S::new_key(&mut rng, ¶m)?; - let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; - let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; - let p1: Rq = S::encode::(&m1); // plaintext - let p2 = m2.remodule::(); // notice we don't encode (scale by delta) + let m1 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let m2 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p1: Rq = S::encode(¶m, &m1); // plaintext + let p2 = m2.remodule(param.ring.q); // notice we don't encode (scale by delta) - let c1 = S::encrypt(&mut rng, &pk, &p1)?; + let c1 = S::encrypt(&mut rng, ¶m, &pk, &p1)?; let c3 = c1 * p2; - let p3_recovered: Rq = c3.decrypt(&sk); - let m3_recovered: Rq = S::decode::(&p3_recovered); - assert_eq!((m1.to_r() * m2.to_r()).to_rq::(), m3_recovered); + let p3_recovered: Rq = c3.decrypt(&sk); + let m3_recovered: Rq = S::decode(¶m, &p3_recovered); + assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), m3_recovered); } Ok(()) @@ -431,33 +528,50 @@ mod tests { #[test] fn test_mod_switch() -> Result<()> { - const Q: u64 = 2u64.pow(16) + 1; - const P: u64 = 2u64.pow(8) + 1; + let param = Param { + err_sigma: ERR_SIGMA, + ring: RingParam { + q: 2u64.pow(16) + 1, + n: 8, + }, + k: 16, + t: 4, // plaintext modulus, must be a prime or power of a prime + }; + let new_q: u64 = 2u64.pow(8) + 1; // note: wip, Q and P chosen so that P/Q is an integer - const N: usize = 8; - const T: u64 = 4; // plaintext modulus, must be a prime or power of a prime - const K: usize = 16; - type S = GLWE, K>; + type S = GLWE; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..200 { - let (sk, pk) = S::new_key(&mut rng)?; + let (sk, pk) = S::new_key(&mut rng, ¶m)?; - let m = Rq::::rand_u64(&mut rng, msg_dist)?; + let m = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; - let p = S::encode::(&m); - let c = S::encrypt(&mut rng, &pk, &p)?; + let p = S::encode(¶m, &m); + let c = S::encrypt(&mut rng, ¶m, &pk, &p)?; - let c2: GLWE, K> = c.mod_switch::

(); - let sk2: SecretKey, K> = - SecretKey(TR(sk.0 .0.iter().map(|s_i| s_i.remodule::

()).collect())); + let c2: GLWE = c.mod_switch(new_q); + assert_eq!(c2.1.param.q, new_q); + let sk2: SecretKey = SecretKey(TR { + k: param.k, + r: sk.0.r.iter().map(|s_i| s_i.remodule(new_q)).collect(), + }); let p_recovered = c2.decrypt(&sk2); - let m_recovered = GLWE::, K>::decode::(&p_recovered); - - assert_eq!(m.remodule::(), m_recovered.remodule::()); + let new_param = Param { + err_sigma: ERR_SIGMA, + ring: RingParam { + q: new_q, + n: param.ring.n, + }, + k: param.k, + t: param.t, + }; + let m_recovered = GLWE::::decode(&new_param, &p_recovered); + + assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t)); } Ok(()) @@ -465,40 +579,45 @@ mod tests { #[test] fn test_key_switch() -> Result<()> { - const Q: u64 = 2u64.pow(16) + 1; - const N: usize = 128; - const T: u64 = 2; // plaintext modulus - const K: usize = 16; - type S = GLWE, K>; + let param = Param { + err_sigma: ERR_SIGMA, + ring: RingParam { + q: 2u64.pow(16) + 1, + n: 128, + }, + k: 16, + t: 2, + }; + type S = GLWE; let beta: u32 = 2; let l: u32 = 16; let mut rng = rand::thread_rng(); - let (sk, pk) = S::new_key(&mut rng)?; - let (sk2, _) = S::new_key(&mut rng)?; + let (sk, pk) = S::new_key(&mut rng, ¶m)?; + let (sk2, _) = S::new_key(&mut rng, ¶m)?; // ksk to switch from sk to sk2 - let ksk = S::new_ksk(&mut rng, beta, l, &sk, &sk2)?; + let ksk = S::new_ksk(&mut rng, ¶m, beta, l, &sk, &sk2)?; - let msg_dist = Uniform::new(0_u64, T); - let m = Rq::::rand_u64(&mut rng, msg_dist)?; - let p = S::encode::(&m); // plaintext - // - let c = S::encrypt_s(&mut rng, &sk, &p)?; + let msg_dist = Uniform::new(0_u64, param.t); + let m = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p = S::encode(¶m, &m); // plaintext + // + let c = S::encrypt_s(&mut rng, ¶m, &sk, &p)?; - let c2 = c.key_switch(beta, l, &ksk); + let c2 = c.key_switch(¶m, beta, l, &ksk); // decrypt with the 2nd secret key let p_recovered = c2.decrypt(&sk2); - let m_recovered = S::decode::(&p_recovered); - assert_eq!(m.remodule::(), m_recovered.remodule::()); + let m_recovered = S::decode(¶m, &p_recovered); + assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t)); // do the same but now encrypting with pk - let c = S::encrypt(&mut rng, &pk, &p)?; - let c2 = c.key_switch(beta, l, &ksk); + let c = S::encrypt(&mut rng, ¶m, &pk, &p)?; + let c2 = c.key_switch(¶m, beta, l, &ksk); let p_recovered = c2.decrypt(&sk2); - let m_recovered = S::decode::(&p_recovered); + let m_recovered = S::decode(¶m, &p_recovered); assert_eq!(m, m_recovered); Ok(()) diff --git a/tfhe/src/lib.rs b/tfhe/src/lib.rs index f2f32fa..114547d 100644 --- a/tfhe/src/lib.rs +++ b/tfhe/src/lib.rs @@ -10,3 +10,5 @@ pub mod tglwe; pub mod tgsw; pub mod tlev; pub mod tlwe; + +pub(crate) const ERR_SIGMA: f64 = 3.2; diff --git a/tfhe/src/tggsw.rs b/tfhe/src/tggsw.rs index 663cd03..bc2fda5 100644 --- a/tfhe/src/tggsw.rs +++ b/tfhe/src/tggsw.rs @@ -4,53 +4,57 @@ use rand::Rng; use std::array; use std::ops::{Add, Mul}; -use arith::{Ring, Rq, Tn, T64, TR}; +use arith::{Ring, RingParam, Rq, Tn, T64, TR}; use crate::tglwe::{PublicKey, SecretKey, TGLWE}; -use gfhe::glwe::GLWE; +use gfhe::glwe::{Param, GLWE}; /// vector of length K+1 = ([K * TGLev], [1 * TGLev]) #[derive(Clone, Debug)] -pub struct TGGSW(pub(crate) Vec>, TGLev); +pub struct TGGSW(pub(crate) Vec, TGLev); -impl TGGSW { +impl TGGSW { pub fn encrypt_s( mut rng: impl Rng, + param: &Param, beta: u32, l: u32, - sk: &SecretKey, - m: &Tn, + sk: &SecretKey, + m: &Tn, ) -> Result { - let a: Vec> = (0..K) - .map(|i| TGLev::encrypt_s(&mut rng, beta, l, sk, &(-sk.0 .0 .0[i] * *m))) + debug_assert_eq!(sk.0 .0.k, param.k); + + let a: Vec = (0..param.k) + .map(|i| TGLev::encrypt_s(&mut rng, param, beta, l, sk, &(&-sk.0 .0.r[i].clone() * m))) + // TODO rm clone .collect::>>()?; - let b: TGLev = TGLev::encrypt_s(&mut rng, beta, l, sk, m)?; + let b: TGLev = TGLev::encrypt_s(&mut rng, ¶m, beta, l, sk, m)?; Ok(Self(a, b)) } - pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> Tn { + pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> Tn { self.1.decrypt(sk, beta) } - pub fn cmux(bit: Self, ct1: TGLWE, ct2: TGLWE) -> TGLWE { + pub fn cmux(bit: Self, ct1: TGLWE, ct2: TGLWE) -> TGLWE { ct1.clone() + (bit * (ct2 - ct1)) } } -/// External product TGGSW x TGLWE -impl Mul> for TGGSW { - type Output = TGLWE; +/// External product tggsw x tglwe +impl Mul for TGGSW { + type Output = TGLWE; - fn mul(self, tglwe: TGLWE) -> TGLWE { + fn mul(self, tglwe: TGLWE) -> TGLWE { let beta: u32 = 2; let l: u32 = 64; // TODO wip - let tglwe_ab: Vec> = [tglwe.0 .0 .0.clone(), vec![tglwe.0 .1]].concat(); + let tglwe_ab: Vec = [tglwe.0 .0.r.clone(), vec![tglwe.0 .1]].concat(); - let tgsw_ab: Vec> = [self.0.clone(), vec![self.1]].concat(); + let tgsw_ab: Vec = [self.0.clone(), vec![self.1]].concat(); assert_eq!(tgsw_ab.len(), tglwe_ab.len()); - let r: TGLWE = zip_eq(tgsw_ab, tglwe_ab) + let r: TGLWE = zip_eq(tgsw_ab, tglwe_ab) .map(|(tlev_i, tglwe_i)| tlev_i * tglwe_i.decompose(beta, l)) .sum(); r @@ -58,26 +62,36 @@ impl Mul> for TGGSW { } #[derive(Clone, Debug)] -pub struct TGLev(pub(crate) Vec>); +pub struct TGLev(pub(crate) Vec); + +impl TGLev { + pub fn encode(param: &Param, m: &Rq) -> Tn { + debug_assert_eq!(param.t, m.param.q); // plaintext modulus -impl TGLev { - pub fn encode(m: &Rq) -> Tn { - let coeffs = m.coeffs(); - Tn(array::from_fn(|i| T64(coeffs[i].0))) + Tn { + param: param.ring, + coeffs: m.coeffs().iter().map(|c_i| T64(c_i.v)).collect(), + } } - pub fn decode(p: &Tn) -> Rq { - Rq::::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect()) + pub fn decode(param: &Param, p: &Tn) -> Rq { + Rq::from_vec_u64(¶m.pt(), p.coeffs().iter().map(|c| c.0).collect()) } pub fn encrypt( mut rng: impl Rng, + param: &Param, beta: u32, l: u32, - pk: &PublicKey, - m: &Tn, + pk: &PublicKey, + m: &Tn, ) -> Result { - let tlev: Vec> = (1..l + 1) + let tlev: Vec = (1..l + 1) .map(|i| { - TGLWE::::encrypt(&mut rng, pk, &(*m * (u64::MAX / beta.pow(i as u32) as u64))) + TGLWE::encrypt( + &mut rng, + ¶m, + pk, + &(m * &(u64::MAX / beta.pow(i as u32) as u64)), + ) }) .collect::>>()?; @@ -85,35 +99,36 @@ impl TGLev { } pub fn encrypt_s( mut rng: impl Rng, + param: &Param, _beta: u32, // TODO rm, and make beta=2 always l: u32, - sk: &SecretKey, - m: &Tn, + sk: &SecretKey, + m: &Tn, ) -> Result { - let tlev: Vec> = (1..l as u64 + 1) + let tlev: Vec = (1..l as u64 + 1) .map(|i| { let aux = if i < 64 { - *m * (u64::MAX / (1u64 << i)) + m * &(u64::MAX / (1u64 << i)) } else { // 1<<64 would overflow, and anyways we're dividing u64::MAX // by it, which would be equal to 1 - *m + m.clone() // TODO rm clone }; - TGLWE::::encrypt_s(&mut rng, sk, &aux) + TGLWE::encrypt_s(&mut rng, ¶m, sk, &aux) }) .collect::>>()?; Ok(Self(tlev)) } - pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> Tn { + pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> Tn { let pt = self.0[0].decrypt(sk); pt.mul_div_round(beta as u64, u64::MAX) } } -impl TGLev { - pub fn iter(&self) -> std::slice::Iter> { +impl TGLev { + pub fn iter(&self) -> std::slice::Iter { self.0.iter() } } @@ -121,14 +136,14 @@ impl TGLev { // dot product between a TGLev and Vec>, usually Vec> comes from a // decomposition of Tn // TGLev * Vec> --> TGLWE -impl Mul>> for TGLev { - type Output = TGLWE; - fn mul(self, v: Vec>) -> Self::Output { +impl Mul> for TGLev { + type Output = TGLWE; + fn mul(self, v: Vec) -> Self::Output { assert_eq!(self.0.len(), v.len()); // l TGLWES - let tlwes: Vec> = self.0; - let r: TGLWE = zip_eq(v, tlwes).map(|(a_d_i, glwe_i)| glwe_i * a_d_i).sum(); + let tlwes: Vec = self.0; + let r: TGLWE = zip_eq(v, tlwes).map(|(a_d_i, glwe_i)| glwe_i * a_d_i).sum(); r } } @@ -141,38 +156,40 @@ mod tests { use super::*; #[test] fn test_external_product() -> Result<()> { - const T: u64 = 16; // plaintext modulus - const K: usize = 4; - const N: usize = 64; - const KN: usize = K * N; + let param = Param { + err_sigma: crate::ERR_SIGMA, + ring: RingParam { q: u64::MAX, n: 64 }, + k: 4, + t: 16, // plaintext modulus + }; let beta: u32 = 2; let l: u32 = 64; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..50 { - let (sk, _) = TGLWE::::new_key::(&mut rng)?; + let (sk, _) = TGLWE::new_key(&mut rng, ¶m)?; - let m1: Rq = Rq::rand_u64(&mut rng, msg_dist)?; - let p1: Tn = TGLev::::encode::(&m1); + let m1: Rq = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p1: Tn = TGLev::encode(¶m, &m1); - let m2: Rq = Rq::rand_u64(&mut rng, msg_dist)?; - let p2: Tn = TGLWE::::encode::(&m2); // scaled by delta + let m2: Rq = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p2: Tn = TGLWE::encode(¶m, &m2); // scaled by delta - let tgsw = TGGSW::::encrypt_s(&mut rng, beta, l, &sk, &p1)?; - let tlwe = TGLWE::::encrypt_s(&mut rng, &sk, &p2)?; + let tgsw = TGGSW::encrypt_s(&mut rng, ¶m, beta, l, &sk, &p1)?; + let tlwe = TGLWE::encrypt_s(&mut rng, ¶m, &sk, &p2)?; - let res: TGLWE = tgsw * tlwe; + let res: TGLWE = tgsw * tlwe; // let p_recovered = res.decrypt(&sk, beta); let p_recovered = res.decrypt(&sk); // downscaled by delta^-1 - let res_recovered = TGLWE::::decode::(&p_recovered); + let res_recovered = TGLWE::decode(¶m, &p_recovered); // assert_eq!(m1 * m2, m_recovered); - assert_eq!((m1.to_r() * m2.to_r()).to_rq::(), res_recovered); + assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), res_recovered); } Ok(()) diff --git a/tfhe/src/tglwe.rs b/tfhe/src/tglwe.rs index 33993f4..ae4c97b 100644 --- a/tfhe/src/tglwe.rs +++ b/tfhe/src/tglwe.rs @@ -1,161 +1,194 @@ use anyhow::Result; -use itertools::zip_eq; -use rand::distributions::Standard; use rand::Rng; -use rand_distr::{Normal, Uniform}; -use std::array; use std::iter::Sum; use std::ops::{Add, AddAssign, Mul, Sub}; -use arith::{Ring, Rq, Tn, T64, TR}; -use gfhe::{glwe, GLWE}; +use arith::{Ring, RingParam, Rq, Tn, T64, TR}; +use gfhe::{glwe, glwe::Param, GLWE}; -use crate::tlev::TLev; use crate::{tlwe, tlwe::TLWE}; -// pub type SecretKey = glwe::SecretKey, K>; #[derive(Clone, Debug)] -pub struct SecretKey(pub glwe::SecretKey, K>); -// pub struct SecretKey(pub tlwe::SecretKey); +pub struct SecretKey(pub glwe::SecretKey); -impl SecretKey { - pub fn to_tlwe(&self) -> tlwe::SecretKey { - let s: TR, K> = self.0 .0.clone(); +impl SecretKey { + pub fn to_tlwe(&self, param: &Param) -> tlwe::SecretKey { + let s: TR = self.0 .0.clone(); + debug_assert_eq!(s.r.len(), param.k); // sanity check - let r: Vec> = s.0.iter().map(|s_i| s_i.coeffs()).collect(); + let kn = param.k * param.ring.n; + let r: Vec> = s.r.iter().map(|s_i| s_i.coeffs()).collect(); let r: Vec = r.into_iter().flatten().collect(); - tlwe::SecretKey::(glwe::SecretKey::(TR::::new(r))) + debug_assert_eq!(r.len(), kn); // sanity check + tlwe::SecretKey(glwe::SecretKey::(TR::::new(kn, r))) } } -pub type PublicKey = glwe::PublicKey, K>; +pub type PublicKey = glwe::PublicKey; #[derive(Clone, Debug)] -pub struct TGLWE(pub GLWE, K>); +pub struct TGLWE(pub GLWE); -impl TGLWE { - pub fn zero() -> Self { - Self(GLWE::, K>::zero()) +impl TGLWE { + pub fn zero(k: usize, param: &RingParam) -> Self { + Self(GLWE::::zero(k, param)) } - pub fn from_plaintext(p: Tn) -> Self { - Self(GLWE::, K>::from_plaintext(p)) + pub fn from_plaintext(k: usize, param: &RingParam, p: Tn) -> Self { + Self(GLWE::::from_plaintext(k, param, p)) } - pub fn new_key( - mut rng: impl Rng, - ) -> Result<(SecretKey, PublicKey)> { - // assert_eq!(KN, K * N); // this is wip, while not being able to compute K*N - let (sk_tlwe, _) = TLWE::::new_key(&mut rng)?; - // let sk = crate::tlwe::sk_to_tglwe::(sk_tlwe); - let sk = sk_tlwe.to_tglwe::(); - let pk: PublicKey = GLWE::pk_from_sk(rng, sk.0.clone())?; + pub fn new_key(mut rng: impl Rng, param: &Param) -> Result<(SecretKey, PublicKey)> { + let (sk_tlwe, _) = TLWE::new_key(&mut rng, ¶m.lwe())?; //param.lwe() so that it uses K*N + debug_assert_eq!(sk_tlwe.0 .0.r.len(), param.lwe().k); // =KN (sanity check) + + let sk = sk_tlwe.to_tglwe(param); + let pk: PublicKey = GLWE::pk_from_sk(rng, param, sk.0.clone())?; Ok((sk, pk)) } - pub fn encode(m: &Rq) -> Tn { - let delta = u64::MAX / P; // floored + pub fn encode(param: &Param, m: &Rq) -> Tn { + debug_assert_eq!(param.t, m.param.q); // plaintext modulus + let p = param.t; + let delta = u64::MAX / p; // floored let coeffs = m.coeffs(); - Tn(array::from_fn(|i| T64(coeffs[i].0 * delta))) + Tn { + param: param.ring, + coeffs: coeffs.iter().map(|c_i| T64(c_i.v * delta)).collect(), + } + } + pub fn decode(param: &Param, pt: &Tn) -> Rq { + let p = param.t; + let pt = pt.mul_div_round(p, u64::MAX); + Rq::from_vec_u64(¶m.pt(), pt.coeffs().iter().map(|c| c.0).collect()) } - pub fn decode(p: &Tn) -> Rq { - let p = p.mul_div_round(P, u64::MAX); - Rq::::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect()) + /// encodes the given message as a TGLWE constant/public value, for using it + /// in ct-pt-multiplication. + pub fn new_const(param: &Param, m: &Rq) -> Tn { + debug_assert_eq!(param.t, m.param.q); + // don't scale up m, set the Tn element directly from m's coefficients + Tn { + param: param.ring, + coeffs: m.coeffs().iter().map(|c_i| T64(c_i.v)).collect(), + } } - // encrypts with the given SecretKey (instead of PublicKey) - pub fn encrypt_s(rng: impl Rng, sk: &SecretKey, p: &Tn) -> Result { - let glwe = GLWE::encrypt_s(rng, &sk.0, p)?; + /// encrypts with the given SecretKey (instead of PublicKey) + pub fn encrypt_s(rng: impl Rng, param: &Param, sk: &SecretKey, p: &Tn) -> Result { + let glwe = GLWE::encrypt_s(rng, param, &sk.0, p)?; Ok(Self(glwe)) } - pub fn encrypt(rng: impl Rng, pk: &PublicKey, p: &Tn) -> Result { - let glwe = GLWE::encrypt(rng, &pk, p)?; + pub fn encrypt(rng: impl Rng, param: &Param, pk: &PublicKey, p: &Tn) -> Result { + let glwe = GLWE::encrypt(rng, param, &pk, p)?; Ok(Self(glwe)) } - pub fn decrypt(&self, sk: &SecretKey) -> Tn { + pub fn decrypt(&self, sk: &SecretKey) -> Tn { self.0.decrypt(&sk.0) } /// Sample extraction / Coefficient extraction - pub fn sample_extraction(&self, h: usize) -> TLWE { - assert!(h < N); + pub fn sample_extraction(&self, param: &Param, h: usize) -> TLWE { + let n = param.ring.n; + assert!(h < n); - let a: TR, K> = self.0 .0.clone(); + let a: TR = self.0 .0.clone(); // set a_{n*i+j} = a_{i, h-j} if j \in {0, h} // -a_{i, n+h-j} if j \in {h+1, n-1} let new_a: Vec = a .iter() .flat_map(|a_i| { let a_i = a_i.coeffs(); - (0..N) - .map(|j| if j <= h { a_i[h - j] } else { -a_i[N + h - j] }) + (0..n) + .map(|j| if j <= h { a_i[h - j] } else { -a_i[n + h - j] }) .collect::>() }) .collect::>(); - - TLWE(GLWE(TR(new_a), self.0 .1.coeffs()[h])) + debug_assert_eq!(new_a.len(), param.k * param.ring.n); // sanity check + + TLWE(GLWE( + TR { + // TODO use constructor `new`, which will check len with k + k: param.k * param.ring.n, + r: new_a, + }, + self.0 .1.coeffs()[h], + )) } pub fn left_rotate(&self, h: usize) -> Self { - dbg!(&h); - let (a, b): (TR, K>, Tn) = (self.0 .0.clone(), self.0 .1); + let (a, b): (TR, Tn) = (self.0 .0.clone(), self.0 .1.clone()); Self(GLWE(a.left_rotate(h), b.left_rotate(h))) } } -impl Add> for TGLWE { +impl Add for TGLWE { type Output = Self; fn add(self, other: Self) -> Self { + debug_assert_eq!(self.0 .0.k, other.0 .0.k); + debug_assert_eq!(self.0 .1.param(), other.0 .1.param()); + Self(self.0 + other.0) } } -impl AddAssign for TGLWE { - fn add_assign(&mut self, rhs: Self) { - self.0 += rhs.0 +impl AddAssign for TGLWE { + fn add_assign(&mut self, other: Self) { + debug_assert_eq!(self.0 .0.k, other.0 .0.k); + debug_assert_eq!(self.0 .1.param(), other.0 .1.param()); + + self.0 += other.0 } } -impl Sum> for TGLWE { - fn sum(iter: I) -> Self +impl Sum for TGLWE { + fn sum(mut iter: I) -> Self where I: Iterator, { - let mut acc = TGLWE::::zero(); - for e in iter { - acc += e; - } - acc + let first = iter.next().unwrap(); + iter.fold(first, |acc, e| acc + e) } } -impl Sub> for TGLWE { +impl Sub for TGLWE { type Output = Self; fn sub(self, other: Self) -> Self { + debug_assert_eq!(self.0 .0.k, other.0 .0.k); + debug_assert_eq!(self.0 .1.param(), other.0 .1.param()); + Self(self.0 - other.0) } } // plaintext addition -impl Add> for TGLWE { +impl Add for TGLWE { type Output = Self; - fn add(self, plaintext: Tn) -> Self { - let a: TR, K> = self.0 .0; - let b: Tn = self.0 .1 + plaintext; + fn add(self, plaintext: Tn) -> Self { + debug_assert_eq!(self.0 .1.param(), plaintext.param()); + + let a: TR = self.0 .0; + let b: Tn = self.0 .1 + plaintext; Self(GLWE(a, b)) } } // plaintext substraction -impl Sub> for TGLWE { +impl Sub for TGLWE { type Output = Self; - fn sub(self, plaintext: Tn) -> Self { - let a: TR, K> = self.0 .0; - let b: Tn = self.0 .1 - plaintext; + fn sub(self, plaintext: Tn) -> Self { + debug_assert_eq!(self.0 .1.param(), plaintext.param()); + + let a: TR = self.0 .0; + let b: Tn = self.0 .1 - plaintext; Self(GLWE(a, b)) } } // plaintext multiplication -impl Mul> for TGLWE { +impl Mul for TGLWE { type Output = Self; - fn mul(self, plaintext: Tn) -> Self { - let a: TR, K> = TR(self.0 .0 .0.iter().map(|r_i| *r_i * plaintext).collect()); - let b: Tn = self.0 .1 * plaintext; + fn mul(self, plaintext: Tn) -> Self { + debug_assert_eq!(self.0 .1.param(), plaintext.param()); + + let a: TR = TR { + k: self.0 .0.k, + r: self.0 .0.r.iter().map(|r_i| r_i * &plaintext).collect(), + }; + let b: Tn = self.0 .1 * plaintext; Self(GLWE(a, b)) } } @@ -169,30 +202,32 @@ mod tests { #[test] fn test_encrypt_decrypt() -> Result<()> { - const T: u64 = 128; // msg space (msg modulus) - const N: usize = 64; - const K: usize = 16; - type S = TGLWE; + let param = Param { + err_sigma: crate::ERR_SIGMA, + ring: RingParam { q: u64::MAX, n: 64 }, + k: 16, + t: 128, // plaintext modulus + }; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..200 { - let (sk, pk) = TGLWE::::new_key::<{ K * N }>(&mut rng)?; + let (sk, pk) = TGLWE::new_key(&mut rng, ¶m)?; - let m = Rq::::rand_u64(&mut rng, msg_dist)?; - let p: Tn = S::encode::(&m); + let m = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p: Tn = TGLWE::encode(¶m, &m); - let c = S::encrypt(&mut rng, &pk, &p)?; + let c = TGLWE::encrypt(&mut rng, ¶m, &pk, &p)?; let p_recovered = c.decrypt(&sk); - let m_recovered = S::decode::(&p_recovered); + let m_recovered = TGLWE::decode(¶m, &p_recovered); assert_eq!(m, m_recovered); // same but using encrypt_s (with sk instead of pk)) - let c = S::encrypt_s(&mut rng, &sk, &p)?; + let c = TGLWE::encrypt_s(&mut rng, ¶m, &sk, &p)?; let p_recovered = c.decrypt(&sk); - let m_recovered = S::decode::(&p_recovered); + let m_recovered = TGLWE::decode(¶m, &p_recovered); assert_eq!(m, m_recovered); } @@ -202,31 +237,33 @@ mod tests { #[test] fn test_addition() -> Result<()> { - const T: u64 = 128; - const N: usize = 64; - const K: usize = 16; - type S = TGLWE; + let param = Param { + err_sigma: crate::ERR_SIGMA, + ring: RingParam { q: u64::MAX, n: 64 }, + k: 16, + t: 128, // plaintext modulus + }; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..200 { - let (sk, pk) = S::new_key::<{ K * N }>(&mut rng)?; + let (sk, pk) = TGLWE::new_key(&mut rng, ¶m)?; - let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; - let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; - let p1: Tn = S::encode::(&m1); // plaintext - let p2: Tn = S::encode::(&m2); // plaintext + let m1 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let m2 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p1: Tn = TGLWE::encode(¶m, &m1); // plaintext + let p2: Tn = TGLWE::encode(¶m, &m2); // plaintext - let c1 = S::encrypt(&mut rng, &pk, &p1)?; - let c2 = S::encrypt(&mut rng, &pk, &p2)?; + let c1 = TGLWE::encrypt(&mut rng, ¶m, &pk, &p1)?; + let c2 = TGLWE::encrypt(&mut rng, ¶m, &pk, &p2)?; let c3 = c1 + c2; let p3_recovered = c3.decrypt(&sk); - let m3_recovered = S::decode::(&p3_recovered); + let m3_recovered = TGLWE::decode(¶m, &p3_recovered); - assert_eq!((m1 + m2).remodule::(), m3_recovered.remodule::()); + assert_eq!((m1 + m2).remodule(param.t), m3_recovered.remodule(param.t)); } Ok(()) @@ -234,28 +271,30 @@ mod tests { #[test] fn test_add_plaintext() -> Result<()> { - const T: u64 = 128; - const N: usize = 64; - const K: usize = 16; - type S = TGLWE; + let param = Param { + err_sigma: crate::ERR_SIGMA, + ring: RingParam { q: u64::MAX, n: 64 }, + k: 16, + t: 128, // plaintext modulus + }; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..200 { - let (sk, pk) = S::new_key::<{ K * N }>(&mut rng)?; + let (sk, pk) = TGLWE::new_key(&mut rng, ¶m)?; - let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; - let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; - let p1: Tn = S::encode::(&m1); // plaintext - let p2: Tn = S::encode::(&m2); // plaintext + let m1 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let m2 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p1: Tn = TGLWE::encode(¶m, &m1); // plaintext + let p2: Tn = TGLWE::encode(¶m, &m2); // plaintext - let c1 = S::encrypt(&mut rng, &pk, &p1)?; + let c1 = TGLWE::encrypt(&mut rng, ¶m, &pk, &p1)?; let c3 = c1 + p2; let p3_recovered = c3.decrypt(&sk); - let m3_recovered = S::decode::(&p3_recovered); + let m3_recovered = TGLWE::decode(¶m, &p3_recovered); assert_eq!(m1 + m2, m3_recovered); } @@ -265,30 +304,31 @@ mod tests { #[test] fn test_mul_plaintext() -> Result<()> { - const T: u64 = 128; - const N: usize = 64; - const K: usize = 16; - type S = TGLWE; + let param = Param { + err_sigma: crate::ERR_SIGMA, + ring: RingParam { q: u64::MAX, n: 64 }, + k: 16, + t: 128, // plaintext modulus + }; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..200 { - let (sk, pk) = S::new_key::<{ K * N }>(&mut rng)?; + let (sk, pk) = TGLWE::new_key(&mut rng, ¶m)?; - let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; - let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; - let p1: Tn = S::encode::(&m1); - // don't scale up p2, set it directly from m2 - let p2: Tn = Tn(array::from_fn(|i| T64(m2.coeffs()[i].0))); + let m1 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let m2 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p1: Tn = TGLWE::encode(¶m, &m1); + let p2: Tn = TGLWE::new_const(¶m, &m2); // as constant/public value - let c1 = S::encrypt(&mut rng, &pk, &p1)?; + let c1 = TGLWE::encrypt(&mut rng, ¶m, &pk, &p1)?; let c3 = c1 * p2; - let p3_recovered: Tn = c3.decrypt(&sk); - let m3_recovered = S::decode::(&p3_recovered); - assert_eq!((m1.to_r() * m2.to_r()).to_rq::(), m3_recovered); + let p3_recovered: Tn = c3.decrypt(&sk); + let m3_recovered = TGLWE::decode(¶m, &p3_recovered); + assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), m3_recovered); } Ok(()) @@ -296,28 +336,30 @@ mod tests { #[test] fn test_sample_extraction() -> Result<()> { - const T: u64 = 128; // msg space (msg modulus) - const N: usize = 64; - const K: usize = 16; - const KN: usize = K * N; + let param = Param { + err_sigma: crate::ERR_SIGMA, + ring: RingParam { q: u64::MAX, n: 64 }, + k: 16, + t: 128, // plaintext modulus + }; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..20 { - let (sk, pk) = TGLWE::::new_key::(&mut rng)?; - let sk_tlwe = sk.to_tlwe::(); + let (sk, pk) = TGLWE::new_key(&mut rng, ¶m)?; + let sk_tlwe = sk.to_tlwe(¶m); - let m = Rq::::rand_u64(&mut rng, msg_dist)?; - let p: Tn = TGLWE::::encode::(&m); + let m = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p: Tn = TGLWE::encode(¶m, &m); - let c = TGLWE::::encrypt(&mut rng, &pk, &p)?; + let c = TGLWE::encrypt(&mut rng, ¶m, &pk, &p)?; - for h in 0..N { - let c_h: TLWE = c.sample_extraction(h); + for h in 0..param.ring.n { + let c_h: TLWE = c.sample_extraction(¶m, h); let p_recovered = c_h.decrypt(&sk_tlwe); - let m_recovered = TLWE::::decode::(&p_recovered); + let m_recovered = TLWE::decode(¶m, &p_recovered); assert_eq!(m.coeffs()[h], m_recovered.coeffs()[0]); } } diff --git a/tfhe/src/tgsw.rs b/tfhe/src/tgsw.rs index 8b83107..8e4eca0 100644 --- a/tfhe/src/tgsw.rs +++ b/tfhe/src/tgsw.rs @@ -1,65 +1,62 @@ use anyhow::Result; use itertools::zip_eq; use rand::Rng; -use std::array; -use std::ops::{Add, Mul}; +use std::ops::Mul; -use arith::{Ring, Rq, Tn, T64, TR}; +use arith::{Ring, T64}; use crate::tlev::TLev; -use crate::{ - tglwe::TGLWE, - tlwe::{PublicKey, SecretKey, TLWE}, -}; -use gfhe::glwe::GLWE; +use crate::tlwe::{SecretKey, TLWE}; +use gfhe::glwe::Param; /// vector of length K+1 = [K], [1] #[derive(Clone, Debug)] -pub struct TGSW(pub(crate) Vec>, TLev); +pub struct TGSW(pub(crate) Vec, TLev); -impl TGSW { +impl TGSW { pub fn encrypt_s( mut rng: impl Rng, + param: &Param, beta: u32, l: u32, - sk: &SecretKey, + sk: &SecretKey, m: &T64, ) -> Result { - let a: Vec> = (0..K) - .map(|i| TLev::encrypt_s(&mut rng, beta, l, sk, &(-sk.0 .0 .0[i] * *m))) + let a: Vec = (0..param.k) + .map(|i| TLev::encrypt_s(&mut rng, ¶m, beta, l, sk, &(-sk.0 .0.r[i] * *m))) .collect::>>()?; - let b: TLev = TLev::encrypt_s(&mut rng, beta, l, sk, m)?; + let b: TLev = TLev::encrypt_s(&mut rng, ¶m, beta, l, sk, m)?; Ok(Self(a, b)) } - pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> T64 { + pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> T64 { self.1.decrypt(sk, beta) } - pub fn from_tlwe(_tlwe: TLWE) -> Self { + pub fn from_tlwe(_tlwe: TLWE) -> Self { todo!() } - pub fn cmux(bit: Self, ct1: TLWE, ct2: TLWE) -> TLWE { + pub fn cmux(bit: Self, ct1: TLWE, ct2: TLWE) -> TLWE { ct1.clone() + (bit * (ct2 - ct1)) } } /// External product TGSW x TLWE -impl Mul> for TGSW { - type Output = TLWE; +impl Mul for TGSW { + type Output = TLWE; - fn mul(self, tlwe: TLWE) -> TLWE { + fn mul(self, tlwe: TLWE) -> TLWE { let beta: u32 = 2; let l: u32 = 64; // TODO wip // since N=1, each tlwe element is a vector of length=1, decomposed into // l elements, and we have K of them - let tlwe_ab: Vec = [tlwe.0 .0 .0.clone(), vec![tlwe.0 .1]].concat(); + let tlwe_ab: Vec = [tlwe.0 .0.r.clone(), vec![tlwe.0 .1]].concat(); - let tgsw_ab: Vec> = [self.0.clone(), vec![self.1]].concat(); + let tgsw_ab: Vec = [self.0.clone(), vec![self.1]].concat(); assert_eq!(tgsw_ab.len(), tlwe_ab.len()); - let r: TLWE = zip_eq(tgsw_ab, tlwe_ab) + let r: TLWE = zip_eq(tgsw_ab, tlwe_ab) .map(|(tlev_i, tlwe_i)| tlev_i * tlwe_i.decompose(beta, l)) .sum(); r @@ -72,28 +69,31 @@ mod tests { use rand::distributions::Uniform; use super::*; + use arith::{RingParam, Rq}; #[test] fn test_encrypt_decrypt() -> Result<()> { - const T: u64 = 2; // plaintext modulus - const K: usize = 16; - type S = TGSW; - + let param = Param { + err_sigma: crate::ERR_SIGMA, + ring: RingParam { q: u64::MAX, n: 1 }, + k: 16, + t: 2, // plaintext modulus + }; let beta: u32 = 2; let l: u32 = 16; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..50 { - let (sk, _) = TLWE::::new_key(&mut rng)?; + let (sk, _) = TLWE::new_key(&mut rng, ¶m)?; - let m: Rq = Rq::rand_u64(&mut rng, msg_dist)?; - let p: T64 = TLev::::encode::(&m); // plaintext + let m: Rq = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p: T64 = TLev::encode(¶m, &m); // plaintext - let c = S::encrypt_s(&mut rng, beta, l, &sk, &p)?; + let c = TGSW::encrypt_s(&mut rng, ¶m, beta, l, &sk, &p)?; let p_recovered = c.decrypt(&sk, beta); - let m_recovered = TLev::::decode::(&p_recovered); + let m_recovered = TLev::decode(¶m, &p_recovered); assert_eq!(m, m_recovered); } @@ -103,36 +103,38 @@ mod tests { #[test] fn test_external_product() -> Result<()> { - const T: u64 = 2; // plaintext modulus - const K: usize = 32; - + let param = Param { + err_sigma: crate::ERR_SIGMA, + ring: RingParam { q: u64::MAX, n: 1 }, + k: 32, + t: 2, // plaintext modulus + }; let beta: u32 = 2; let l: u32 = 64; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..50 { - let (sk, _) = TLWE::::new_key(&mut rng)?; + let (sk, _) = TLWE::new_key(&mut rng, ¶m)?; - let m1: Rq = Rq::rand_u64(&mut rng, msg_dist)?; - let p1: T64 = TLev::::encode::(&m1); + let m1: Rq = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p1: T64 = TLev::encode(¶m, &m1); - let m2: Rq = Rq::rand_u64(&mut rng, msg_dist)?; - let p2: T64 = TLWE::::encode::(&m2); // scaled by delta + let m2: Rq = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p2: T64 = TLWE::encode(¶m, &m2); // scaled by delta - let tgsw = TGSW::::encrypt_s(&mut rng, beta, l, &sk, &p1)?; - let tlwe = TLWE::::encrypt_s(&mut rng, &sk, &p2)?; + let tgsw = TGSW::encrypt_s(&mut rng, ¶m, beta, l, &sk, &p1)?; + let tlwe = TLWE::encrypt_s(&mut rng, ¶m, &sk, &p2)?; - let res: TLWE = tgsw * tlwe; + let res: TLWE = tgsw * tlwe; - // let p_recovered = res.decrypt(&sk, beta); let p_recovered = res.decrypt(&sk); // downscaled by delta^-1 - let res_recovered = TLWE::::decode::(&p_recovered); + let res_recovered = TLWE::decode(¶m, &p_recovered); // assert_eq!(m1 * m2, m_recovered); - assert_eq!((m1.to_r() * m2.to_r()).to_rq::(), res_recovered); + assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), res_recovered); } Ok(()) @@ -140,35 +142,39 @@ mod tests { #[test] fn test_cmux() -> Result<()> { - const T: u64 = 2; // plaintext modulus - const K: usize = 32; + let param = Param { + err_sigma: crate::ERR_SIGMA, + ring: RingParam { q: u64::MAX, n: 1 }, + k: 32, + t: 2, // plaintext modulus + }; let beta: u32 = 2; let l: u32 = 64; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..50 { - let (sk, _) = TLWE::::new_key(&mut rng)?; + let (sk, _) = TLWE::new_key(&mut rng, ¶m)?; - let m1: Rq = Rq::rand_u64(&mut rng, msg_dist)?; - let p1: T64 = TLWE::::encode::(&m1); // scaled by delta + let m1: Rq = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p1: T64 = TLWE::encode(¶m, &m1); // scaled by delta - let m2: Rq = Rq::rand_u64(&mut rng, msg_dist)?; - let p2: T64 = TLWE::::encode::(&m2); // scaled by delta + let m2: Rq = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p2: T64 = TLWE::encode(¶m, &m2); // scaled by delta for bit_raw in 0..2 { - let bit = TGSW::::encrypt_s(&mut rng, beta, l, &sk, &T64(bit_raw))?; + let bit = TGSW::encrypt_s(&mut rng, ¶m, beta, l, &sk, &T64(bit_raw))?; - let c1 = TLWE::::encrypt_s(&mut rng, &sk, &p1)?; - let c2 = TLWE::::encrypt_s(&mut rng, &sk, &p2)?; + let c1 = TLWE::encrypt_s(&mut rng, ¶m, &sk, &p1)?; + let c2 = TLWE::encrypt_s(&mut rng, ¶m, &sk, &p2)?; - let res: TLWE = TGSW::cmux(bit, c1, c2); + let res: TLWE = TGSW::cmux(bit, c1, c2); let p_recovered = res.decrypt(&sk); // downscaled by delta^-1 - let res_recovered = TLWE::::decode::(&p_recovered); + let res_recovered = TLWE::decode(¶m, &p_recovered); if bit_raw == 0 { assert_eq!(m1, res_recovered); diff --git a/tfhe/src/tlev.rs b/tfhe/src/tlev.rs index 66bd2c2..b8a543f 100644 --- a/tfhe/src/tlev.rs +++ b/tfhe/src/tlev.rs @@ -1,35 +1,50 @@ use anyhow::Result; use itertools::zip_eq; use rand::Rng; -use std::array; -use std::ops::{Add, Mul}; +use std::ops::Mul; -use arith::{Ring, Rq, Tn, T64, TR}; +use arith::{Ring, RingParam, Rq, T64}; -use crate::tglwe::TGLWE; use crate::tlwe::{PublicKey, SecretKey, TLWE}; +use gfhe::glwe::Param; #[derive(Clone, Debug)] -pub struct TLev(pub(crate) Vec>); +pub struct TLev(pub(crate) Vec); + +impl TLev { + pub fn encode(param: &Param, m: &Rq) -> T64 { + assert_eq!(m.param.n, 1); + assert_eq!(param.t, m.param.q); -impl TLev { - pub fn encode(m: &Rq) -> T64 { let coeffs = m.coeffs(); - T64(coeffs[0].0) // N=1, so take the only coeff + T64(coeffs[0].v) // N=1, so take the only coeff } - pub fn decode(p: &T64) -> Rq { - Rq::::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect()) + pub fn decode(param: &Param, p: &T64) -> Rq { + Rq::from_vec_u64( + &RingParam { q: param.t, n: 1 }, + p.coeffs().iter().map(|c| c.0).collect(), + ) } pub fn encrypt( mut rng: impl Rng, + param: &Param, beta: u32, l: u32, - pk: &PublicKey, + pk: &PublicKey, m: &T64, ) -> Result { - let tlev: Vec> = (1..l + 1) + debug_assert_eq!(pk.1.k, param.k); + + let tlev: Vec = (1..l as u64 + 1) .map(|i| { - TLWE::::encrypt(&mut rng, pk, &(*m * (u64::MAX / beta.pow(i as u32) as u64))) + let aux = if i < 64 { + *m * (u64::MAX / (1u64 << i)) + } else { + // 1<<64 would overflow, and anyways we're dividing u64::MAX + // by it, which would be equal to 1 + *m + }; + TLWE::encrypt(&mut rng, param, pk, &aux) }) .collect::>>()?; @@ -37,12 +52,15 @@ impl TLev { } pub fn encrypt_s( mut rng: impl Rng, + param: &Param, _beta: u32, // TODO rm, and make beta=2 always l: u32, - sk: &SecretKey, + sk: &SecretKey, m: &T64, ) -> Result { - let tlev: Vec> = (1..l as u64 + 1) + debug_assert_eq!(sk.0 .0.k, param.k); + + let tlev: Vec = (1..l as u64 + 1) .map(|i| { let aux = if i < 64 { *m * (u64::MAX / (1u64 << i)) @@ -51,22 +69,22 @@ impl TLev { // by it, which would be equal to 1 *m }; - TLWE::::encrypt_s(&mut rng, sk, &aux) + TLWE::encrypt_s(&mut rng, ¶m, sk, &aux) }) .collect::>>()?; Ok(Self(tlev)) } - pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> T64 { + pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> T64 { let pt = self.0[0].decrypt(sk); pt.mul_div_round(beta as u64, u64::MAX) } } // TODO review u64::MAX, since is -1 of the value we actually want -impl TLev { - pub fn iter(&self) -> std::slice::Iter> { +impl TLev { + pub fn iter(&self) -> std::slice::Iter { self.0.iter() } } @@ -74,14 +92,14 @@ impl TLev { // dot product between a TLev and Vec, usually Vec comes from a // decomposition of T64 // TLev * Vec --> TLWE -impl Mul> for TLev { - type Output = TLWE; +impl Mul> for TLev { + type Output = TLWE; fn mul(self, v: Vec) -> Self::Output { assert_eq!(self.0.len(), v.len()); // l TLWES - let tlwes: Vec> = self.0; - let r: TLWE = zip_eq(v, tlwes).map(|(a_d_i, glwe_i)| glwe_i * a_d_i).sum(); + let tlwes: Vec = self.0; + let r: TLWE = zip_eq(v, tlwes).map(|(a_d_i, glwe_i)| glwe_i * a_d_i).sum(); r } } @@ -95,27 +113,30 @@ mod tests { #[test] fn test_encrypt_decrypt() -> Result<()> { - const T: u64 = 2; // plaintext modulus - const K: usize = 16; - type S = TLev; + let param = Param { + err_sigma: crate::ERR_SIGMA, + ring: RingParam { q: u64::MAX, n: 1 }, + k: 16, + t: 2, // plaintext modulus + }; let beta: u32 = 2; let l: u32 = 16; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..200 { - let (sk, pk) = TLWE::::new_key(&mut rng)?; + let (sk, pk) = TLWE::new_key(&mut rng, ¶m)?; - let m: Rq = Rq::rand_u64(&mut rng, msg_dist)?; - let p: T64 = S::encode::(&m); // plaintext + let m: Rq = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p: T64 = TLev::encode(¶m, &m); // plaintext - let c = S::encrypt(&mut rng, beta, l, &pk, &p)?; + let c = TLev::encrypt(&mut rng, ¶m, beta, l, &pk, &p)?; let p_recovered = c.decrypt(&sk, beta); - let m_recovered = S::decode::(&p_recovered); + let m_recovered = TLev::decode(¶m, &p_recovered); - assert_eq!(m.remodule::(), m_recovered.remodule::()); + assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t)); } Ok(()) @@ -123,32 +144,37 @@ mod tests { #[test] fn test_tlev_vect64_product() -> Result<()> { - const T: u64 = 2; // plaintext modulus - const K: usize = 16; + let param = Param { + err_sigma: 0.1, // WIP + ring: RingParam { q: u64::MAX, n: 1 }, + k: 16, + t: 2, // plaintext modulus + }; let beta: u32 = 2; - let l: u32 = 16; + // let l: u32 = 16; + let l: u32 = 64; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..200 { - let (sk, pk) = TLWE::::new_key(&mut rng)?; + let (sk, pk) = TLWE::new_key(&mut rng, ¶m)?; - let m1: Rq = Rq::rand_u64(&mut rng, msg_dist)?; - let m2: Rq = Rq::rand_u64(&mut rng, msg_dist)?; - let p1: T64 = TLev::::encode::(&m1); - let p2: T64 = TLev::::encode::(&m2); + let m1: Rq = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let m2: Rq = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p1: T64 = TLev::encode(¶m, &m1); + let p2: T64 = TLev::encode(¶m, &m2); - let c1 = TLev::::encrypt(&mut rng, beta, l, &pk, &p1)?; + let c1 = TLev::encrypt(&mut rng, ¶m, beta, l, &pk, &p1)?; let c2 = p2.decompose(beta, l); let c3 = c1 * c2; let p_recovered = c3.decrypt(&sk); - let m_recovered = TLev::::decode::(&p_recovered); + let m_recovered = TLev::decode(¶m, &p_recovered); - assert_eq!((m1.to_r() * m2.to_r()).to_rq::(), m_recovered); + assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), m_recovered); } Ok(()) diff --git a/tfhe/src/tlwe.rs b/tfhe/src/tlwe.rs index f0fc7ce..dac81ac 100644 --- a/tfhe/src/tlwe.rs +++ b/tfhe/src/tlwe.rs @@ -4,242 +4,275 @@ use rand::Rng; use std::iter::Sum; use std::ops::{Add, AddAssign, Mul, Sub}; -use arith::{Ring, Rq, Tn, Zq, T64, TR}; -use gfhe::{glwe, GLWE}; +use arith::{Ring, RingParam, Rq, Tn, Zq, T64, TR}; +use gfhe::{glwe, glwe::Param, GLWE}; use crate::tggsw::TGGSW; use crate::tlev::TLev; use crate::{tglwe, tglwe::TGLWE}; -pub struct SecretKey(pub glwe::SecretKey); +pub struct SecretKey(pub glwe::SecretKey); -impl SecretKey { +impl SecretKey { /// from TFHE [2018-421] paper: A TLWE key k \in B^n, can be interpreted as a /// TRLWE key K \in B_N[X]^k having the same sequence of coefficients and /// vice-versa. - pub fn to_tglwe(self) -> crate::tglwe::SecretKey { - let s: TR = self.0 .0; + pub fn to_tglwe(self, param: &Param) -> crate::tglwe::SecretKey { + let s: TR = self.0 .0; // of length K*N + assert_eq!(s.r.len(), param.k * param.ring.n); // sanity check + // split into K vectors, and interpret each of them as a T_N[X]/(X^N+1) // polynomial - let r: Vec> = - s.0.chunks(N) - .map(|v| Tn::::from_vec(v.to_vec())) + let r: Vec = + s.r.chunks(param.ring.n) + .map(|v| Tn::from_vec(¶m.ring, v.to_vec())) .collect(); - crate::tglwe::SecretKey(glwe::SecretKey::, K>(TR(r))) + crate::tglwe::SecretKey(glwe::SecretKey::(TR { k: param.k, r })) } } -pub type PublicKey = glwe::PublicKey; +pub type PublicKey = glwe::PublicKey; #[derive(Clone, Debug)] -pub struct KSK(Vec>); +pub struct KSK(Vec); #[derive(Clone, Debug)] -pub struct TLWE(pub GLWE); +pub struct TLWE(pub GLWE); -impl TLWE { - pub fn zero() -> Self { - Self(GLWE::::zero()) +impl TLWE { + pub fn zero(k: usize, ring_param: &RingParam) -> Self { + Self(GLWE::::zero(k, ring_param)) } - pub fn new_key(rng: impl Rng) -> Result<(SecretKey, PublicKey)> { - let (sk, pk): (glwe::SecretKey, glwe::PublicKey) = GLWE::new_key(rng)?; + pub fn new_key(rng: impl Rng, param: &Param) -> Result<(SecretKey, PublicKey)> { + let (sk, pk): (glwe::SecretKey, glwe::PublicKey) = GLWE::new_key(rng, param)?; Ok((SecretKey(sk), pk)) } - pub fn encode(m: &Rq) -> T64 { - let delta = u64::MAX / P; // floored + pub fn encode(param: &Param, m: &Rq) -> T64 { + assert_eq!(param.ring.n, 1); + debug_assert_eq!(param.t, m.param.q); // plaintext modulus + + let delta = u64::MAX / param.t; // floored let coeffs = m.coeffs(); - T64(coeffs[0].0 * delta) + T64(coeffs[0].v * delta) } - pub fn decode(p: &T64) -> Rq { - let p = p.mul_div_round(P, u64::MAX); - Rq::::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect()) + pub fn decode(param: &Param, p: &T64) -> Rq { + let p = p.mul_div_round(param.t, u64::MAX); + Rq::from_vec_u64(¶m.pt(), p.coeffs().iter().map(|c| c.0).collect()) + } + /// encodes the given message as a TLWE constant/public value, for using it + /// in ct-pt-multiplication. + pub fn new_const(param: &Param, m: &Rq) -> T64 { + debug_assert_eq!(param.t, m.param.q); + T64(m.coeffs()[0].v) } // encrypts with the given SecretKey (instead of PublicKey) - pub fn encrypt_s(rng: impl Rng, sk: &SecretKey, p: &T64) -> Result { - let glwe = GLWE::encrypt_s(rng, &sk.0, p)?; + pub fn encrypt_s(rng: impl Rng, param: &Param, sk: &SecretKey, p: &T64) -> Result { + let glwe = GLWE::encrypt_s(rng, param, &sk.0, p)?; Ok(Self(glwe)) } - pub fn encrypt(rng: impl Rng, pk: &PublicKey, p: &T64) -> Result { - let glwe = GLWE::encrypt(rng, &pk, p)?; + pub fn encrypt(rng: impl Rng, param: &Param, pk: &PublicKey, p: &T64) -> Result { + let glwe = GLWE::encrypt(rng, param, pk, p)?; Ok(Self(glwe)) } - pub fn decrypt(&self, sk: &SecretKey) -> T64 { + pub fn decrypt(&self, sk: &SecretKey) -> T64 { self.0.decrypt(&sk.0) } pub fn new_ksk( mut rng: impl Rng, + param: &Param, beta: u32, l: u32, - sk: &SecretKey, - new_sk: &SecretKey, - ) -> Result> { - let r: Vec> = (0..K) + sk: &SecretKey, + new_sk: &SecretKey, + ) -> Result { + let r: Vec = (0..param.k) .into_iter() .map(|i| // treat sk_i as the msg being encrypted - TLev::::encrypt_s(&mut rng, beta, l, &new_sk, &sk.0.0 .0[i])) + TLev::encrypt_s(&mut rng, param, beta, l, &new_sk, &sk.0.0 .r[i])) .collect::>>()?; Ok(KSK(r)) } - pub fn key_switch(&self, beta: u32, l: u32, ksk: &KSK) -> Self { - let (a, b): (TR, T64) = (self.0 .0.clone(), self.0 .1); + pub fn key_switch(&self, param: &Param, beta: u32, l: u32, ksk: &KSK) -> Self { + let (a, b): (TR, T64) = (self.0 .0.clone(), self.0 .1); - let lhs: TLWE = TLWE(GLWE(TR::zero(), b)); + let lhs: TLWE = TLWE(GLWE(TR::zero(param.k * param.ring.n, ¶m.ring), b)); // K iterations, ksk.0 contains K times GLev - let rhs: TLWE = zip_eq(a.0, ksk.0.clone()) + let rhs: TLWE = zip_eq(a.r, ksk.0.clone()) .map(|(a_i, ksk_i)| ksk_i * a_i.decompose(beta, l)) // dot_product .sum(); lhs - rhs } // modulus switch from Q (2^64) to Q2 (in blind_rotation Q2=K*N) - pub fn mod_switch(&self) -> Self { - let a: TR = self.0 .0.mod_switch::(); - let b: T64 = self.0 .1.mod_switch::(); + pub fn mod_switch(&self, q2: u64) -> Self { + let a: TR = self.0 .0.mod_switch(q2); + let b: T64 = self.0 .1.mod_switch(q2); Self(GLWE(a, b)) } } -// NOTE: the ugly const generics are temporary -pub fn blind_rotation( - c: TLWE, - btk: BootstrappingKey, - table: TGLWE, -) -> TGLWE { - let c_kn: TLWE = c.mod_switch::(); - let (a, b): (TR, T64) = (c_kn.0 .0, c_kn.0 .1); + +pub fn blind_rotation( + param: &Param, + c: TLWE, // kn + btk: BootstrappingKey, + table: TGLWE, // n,k +) -> TGLWE { + debug_assert_eq!(c.0 .0.k, param.k); + + // TODO replace `param.k*param.ring.n` by `param.kn()` + let c_kn: TLWE = c.mod_switch((param.k * param.ring.n) as u64); + let (a, b): (TR, T64) = (c_kn.0 .0, c_kn.0 .1); // two main parts: rotate by a known power of X, rotate by a secret // power of X (using the C gate) // table * X^-b, ie. left rotate - let v_xb: TGLWE = table.left_rotate(b.0 as usize); + let v_xb: TGLWE = table.left_rotate(b.0 as usize); // rotate by a secret power of X using the cmux gate - let mut c_j: TGLWE = v_xb.clone(); - let _ = (1..K).map(|j| { - c_j = TGGSW::::cmux( + let mut c_j: TGLWE = v_xb.clone(); + let _ = (1..param.k).map(|j| { + c_j = TGGSW::cmux( btk.0[j].clone(), c_j.clone(), - c_j.clone().left_rotate(a.0[j].0 as usize), + c_j.clone().left_rotate(a.r[j].0 as usize), ); - dbg!(&c_j); }); c_j } -pub fn bootstrapping( - btk: BootstrappingKey, - table: TGLWE, - c: TLWE, -) -> TLWE { - let rotated: TGLWE = blind_rotation::(c, btk.clone(), table); - let c_h: TLWE = rotated.sample_extraction(0); - let r = c_h.key_switch(2, 64, &btk.1); +pub fn bootstrapping( + param: &Param, + btk: BootstrappingKey, + table: TGLWE, + c: TLWE, // kn +) -> TLWE { + // kn + let rotated: TGLWE = blind_rotation(param, c, btk.clone(), table); + let c_h: TLWE = rotated.sample_extraction(¶m, 0); + let r = c_h.key_switch(param, 2, 64, &btk.1); r } #[derive(Clone, Debug)] -pub struct BootstrappingKey( - pub Vec>, - pub KSK, +pub struct BootstrappingKey( + pub Vec, + pub KSK, // kn ); -impl BootstrappingKey { - pub fn from_sk(mut rng: impl Rng, sk: &tglwe::SecretKey) -> Result { +impl BootstrappingKey { + pub fn from_sk(mut rng: impl Rng, param: &Param, sk: &tglwe::SecretKey) -> Result { let (beta, l) = (2u32, 64u32); // TMP - // - let s: TR, K> = sk.0 .0.clone(); - let (sk2, _) = TLWE::::new_key(&mut rng)?; // TLWE compatible with TGLWE + + let s: TR = sk.0 .0.clone(); + let (sk2, _) = TLWE::new_key(&mut rng, ¶m.lwe())?; // TLWE compatible with TGLWE // each btk_j = TGGSW_sk(s_i) - let btk: Vec> = s + let btk: Vec = s .iter() - .map(|s_i| TGGSW::::encrypt_s(&mut rng, beta, l, sk, s_i)) + .map(|s_i| TGGSW::encrypt_s(&mut rng, param, beta, l, sk, s_i)) .collect::>>()?; - let ksk = TLWE::::new_ksk(&mut rng, beta, l, &sk.to_tlwe(), &sk2)?; + + let ksk = TLWE::new_ksk( + &mut rng, + ¶m.lwe(), + beta, + l, + &sk.to_tlwe(¶m.lwe()), // converted to length k*n + &sk2, // created with length k*n + )?; + debug_assert_eq!(ksk.0.len(), param.lwe().k); + debug_assert_eq!(ksk.0.len(), param.k * param.ring.n); Ok(Self(btk, ksk)) } } -pub fn compute_lookup_table() -> TGLWE { +pub fn compute_lookup_table(param: &Param) -> TGLWE { // from 2021-1402: // v(x) = \sum_j^{N-1} [(p_j / 2N mod p)/p] X^j // matrix of coefficients with size K*N = delta x T - let delta: usize = N / T as usize; - let values: Vec> = (0..T).map(|v| Zq::::from_u64(v)).collect(); - let coeffs: Vec> = (0..T as usize) + let delta: usize = param.ring.n / param.t as usize; + let values: Vec = (0..param.t).map(|v| Zq::from_u64(param.t, v)).collect(); + let coeffs: Vec = (0..param.t as usize) .flat_map(|i| vec![values[i]; delta]) .collect(); - let table = Rq::::from_vec(coeffs); + let table = Rq::from_vec(¶m.pt(), coeffs); // encode the table as plaintext - let v: Tn = TGLWE::::encode::(&table); + let v: Tn = TGLWE::encode(param, &table); // encode the table as TGLWE ciphertext - let v: TGLWE = TGLWE::::from_plaintext(v); + let v: TGLWE = TGLWE::from_plaintext(param.k, ¶m.ring, v); v } -impl Add> for TLWE { +impl Add for TLWE { type Output = Self; fn add(self, other: Self) -> Self { + debug_assert_eq!(self.0 .0.k, other.0 .0.k); + debug_assert_eq!(self.0 .1.param(), other.0 .1.param()); Self(self.0 + other.0) } } -impl AddAssign for TLWE { +impl AddAssign for TLWE { fn add_assign(&mut self, rhs: Self) { + debug_assert_eq!(self.0 .0.k, rhs.0 .0.k); + debug_assert_eq!(self.0 .1.param(), rhs.0 .1.param()); self.0 += rhs.0 } } -impl Sum> for TLWE { - fn sum(iter: I) -> Self +impl Sum for TLWE { + fn sum(mut iter: I) -> Self where I: Iterator, { - let mut acc = TLWE::::zero(); - for e in iter { - acc += e; - } - acc + let first = iter.next().unwrap(); + iter.fold(first, |acc, e| acc + e) } } -impl Sub> for TLWE { +impl Sub for TLWE { type Output = Self; fn sub(self, other: Self) -> Self { + debug_assert_eq!(self.0 .0.k, other.0 .0.k); + debug_assert_eq!(self.0 .1.param(), other.0 .1.param()); Self(self.0 - other.0) } } // plaintext addition -impl Add for TLWE { +impl Add for TLWE { type Output = Self; fn add(self, plaintext: T64) -> Self { - let a: TR = self.0 .0; + let a: TR = self.0 .0; let b: T64 = self.0 .1 + plaintext; Self(GLWE(a, b)) } } // plaintext substraction -impl Sub for TLWE { +impl Sub for TLWE { type Output = Self; fn sub(self, plaintext: T64) -> Self { - let a: TR = self.0 .0; + let a: TR = self.0 .0; let b: T64 = self.0 .1 - plaintext; Self(GLWE(a, b)) } } // plaintext multiplication -impl Mul for TLWE { +impl Mul for TLWE { type Output = Self; fn mul(self, plaintext: T64) -> Self { - let a: TR = TR(self.0 .0 .0.iter().map(|r_i| *r_i * plaintext).collect()); + let a: TR = TR { + k: self.0 .0.k, + r: self.0 .0.r.iter().map(|r_i| *r_i * plaintext).collect(), + }; let b: T64 = self.0 .1 * plaintext; Self(GLWE(a, b)) } @@ -255,29 +288,32 @@ mod tests { #[test] fn test_encrypt_decrypt() -> Result<()> { - const T: u64 = 128; // msg space (msg modulus) - const K: usize = 16; - type S = TLWE; + let param = Param { + err_sigma: crate::ERR_SIGMA, + ring: RingParam { q: u64::MAX, n: 1 }, + k: 16, + t: 128, // plaintext modulus + }; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..200 { - let (sk, pk) = S::new_key(&mut rng)?; + let (sk, pk) = TLWE::new_key(&mut rng, ¶m)?; - let m = Rq::::rand_u64(&mut rng, msg_dist)?; - let p: T64 = S::encode::(&m); + let m = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p: T64 = TLWE::encode(¶m, &m); - let c = S::encrypt(&mut rng, &pk, &p)?; + let c = TLWE::encrypt(&mut rng, ¶m, &pk, &p)?; let p_recovered = c.decrypt(&sk); - let m_recovered = S::decode::(&p_recovered); + let m_recovered = TLWE::decode(¶m, &p_recovered); assert_eq!(m, m_recovered); // same but using encrypt_s (with sk instead of pk)) - let c = S::encrypt_s(&mut rng, &sk, &p)?; + let c = TLWE::encrypt_s(&mut rng, ¶m, &sk, &p)?; let p_recovered = c.decrypt(&sk); - let m_recovered = S::decode::(&p_recovered); + let m_recovered = TLWE::decode(¶m, &p_recovered); assert_eq!(m, m_recovered); } @@ -287,30 +323,33 @@ mod tests { #[test] fn test_addition() -> Result<()> { - const T: u64 = 128; - const K: usize = 16; - type S = TLWE; + let param = Param { + err_sigma: crate::ERR_SIGMA, + ring: RingParam { q: u64::MAX, n: 1 }, + k: 16, + t: 128, // plaintext modulus + }; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..200 { - let (sk, pk) = S::new_key(&mut rng)?; + let (sk, pk) = TLWE::new_key(&mut rng, ¶m)?; - let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; - let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; - let p1: T64 = S::encode::(&m1); // plaintext - let p2: T64 = S::encode::(&m2); // plaintext + let m1 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let m2 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p1: T64 = TLWE::encode(¶m, &m1); // plaintext + let p2: T64 = TLWE::encode(¶m, &m2); // plaintext - let c1 = S::encrypt(&mut rng, &pk, &p1)?; - let c2 = S::encrypt(&mut rng, &pk, &p2)?; + let c1 = TLWE::encrypt(&mut rng, ¶m, &pk, &p1)?; + let c2 = TLWE::encrypt(&mut rng, ¶m, &pk, &p2)?; let c3 = c1 + c2; let p3_recovered = c3.decrypt(&sk); - let m3_recovered = S::decode::(&p3_recovered); + let m3_recovered = TLWE::decode(¶m, &p3_recovered); - assert_eq!((m1 + m2).remodule::(), m3_recovered.remodule::()); + assert_eq!((m1 + m2).remodule(param.t), m3_recovered.remodule(param.t)); } Ok(()) @@ -318,27 +357,30 @@ mod tests { #[test] fn test_add_plaintext() -> Result<()> { - const T: u64 = 128; - const K: usize = 16; - type S = TLWE; + let param = Param { + err_sigma: crate::ERR_SIGMA, + ring: RingParam { q: u64::MAX, n: 1 }, + k: 16, + t: 128, // plaintext modulus + }; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..200 { - let (sk, pk) = S::new_key(&mut rng)?; + let (sk, pk) = TLWE::new_key(&mut rng, ¶m)?; - let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; - let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; - let p1: T64 = S::encode::(&m1); // plaintext - let p2: T64 = S::encode::(&m2); // plaintext + let m1 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let m2 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p1: T64 = TLWE::encode(¶m, &m1); // plaintext + let p2: T64 = TLWE::encode(¶m, &m2); // plaintext - let c1 = S::encrypt(&mut rng, &pk, &p1)?; + let c1 = TLWE::encrypt(&mut rng, ¶m, &pk, &p1)?; let c3 = c1 + p2; let p3_recovered = c3.decrypt(&sk); - let m3_recovered = S::decode::(&p3_recovered); + let m3_recovered = TLWE::decode(¶m, &p3_recovered); assert_eq!(m1 + m2, m3_recovered); } @@ -348,30 +390,31 @@ mod tests { #[test] fn test_mul_plaintext() -> Result<()> { - const T: u64 = 128; - const K: usize = 16; - type S = TLWE; + let param = Param { + err_sigma: crate::ERR_SIGMA, + ring: RingParam { q: u64::MAX, n: 1 }, + k: 16, + t: 128, // plaintext modulus + }; let mut rng = rand::thread_rng(); - let msg_dist = Uniform::new(0_u64, T); + let msg_dist = Uniform::new(0_u64, param.t); for _ in 0..200 { - let (sk, pk) = S::new_key(&mut rng)?; + let (sk, pk) = TLWE::new_key(&mut rng, ¶m)?; - let m1 = Rq::::rand_u64(&mut rng, msg_dist)?; - let m2 = Rq::::rand_u64(&mut rng, msg_dist)?; - let p1: T64 = S::encode::(&m1); - // don't scale up p2, set it directly from m2 - // let p2: T64 = Tn(array::from_fn(|i| T64(m2.coeffs()[i].0))); - let p2: T64 = T64(m2.coeffs()[0].0); + let m1 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let m2 = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p1: T64 = TLWE::encode(¶m, &m1); + let p2: T64 = TLWE::new_const(¶m, &m2); // as constant/public value - let c1 = S::encrypt(&mut rng, &pk, &p1)?; + let c1 = TLWE::encrypt(&mut rng, ¶m, &pk, &p1)?; let c3 = c1 * p2; let p3_recovered: T64 = c3.decrypt(&sk); - let m3_recovered = S::decode::(&p3_recovered); - assert_eq!((m1.to_r() * m2.to_r()).to_rq::(), m3_recovered); + let m3_recovered = TLWE::decode(¶m, &p3_recovered); + assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), m3_recovered); } Ok(()) @@ -379,38 +422,41 @@ mod tests { #[test] fn test_key_switch() -> Result<()> { - const T: u64 = 128; // plaintext modulus - const K: usize = 16; - type S = TLWE; + let param = Param { + err_sigma: crate::ERR_SIGMA, + ring: RingParam { q: u64::MAX, n: 1 }, + k: 16, + t: 128, // plaintext modulus + }; let beta: u32 = 2; let l: u32 = 64; let mut rng = rand::thread_rng(); - let (sk, pk) = S::new_key(&mut rng)?; - let (sk2, _) = S::new_key(&mut rng)?; + let (sk, pk) = TLWE::new_key(&mut rng, ¶m)?; + let (sk2, _) = TLWE::new_key(&mut rng, ¶m)?; // ksk to switch from sk to sk2 - let ksk = S::new_ksk(&mut rng, beta, l, &sk, &sk2)?; + let ksk = TLWE::new_ksk(&mut rng, ¶m, beta, l, &sk, &sk2)?; + + let msg_dist = Uniform::new(0_u64, param.t); + let m = Rq::rand_u64(&mut rng, msg_dist, ¶m.pt())?; + let p = TLWE::encode(¶m, &m); // plaintext - let msg_dist = Uniform::new(0_u64, T); - let m = Rq::::rand_u64(&mut rng, msg_dist)?; - let p = S::encode::(&m); // plaintext - // - let c = S::encrypt_s(&mut rng, &sk, &p)?; + let c = TLWE::encrypt_s(&mut rng, ¶m, &sk, &p)?; - let c2 = c.key_switch(beta, l, &ksk); + let c2 = c.key_switch(¶m, beta, l, &ksk); // decrypt with the 2nd secret key let p_recovered = c2.decrypt(&sk2); - let m_recovered = S::decode::(&p_recovered); - assert_eq!(m.remodule::(), m_recovered.remodule::()); + let m_recovered = TLWE::decode(¶m, &p_recovered); + assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t)); // do the same but now encrypting with pk - let c = S::encrypt(&mut rng, &pk, &p)?; - let c2 = c.key_switch(beta, l, &ksk); + let c = TLWE::encrypt(&mut rng, ¶m, &pk, &p)?; + let c2 = c.key_switch(¶m, beta, l, &ksk); let p_recovered = c2.decrypt(&sk2); - let m_recovered = S::decode::(&p_recovered); + let m_recovered = TLWE::decode(¶m, &p_recovered); assert_eq!(m, m_recovered); Ok(()) @@ -418,39 +464,40 @@ mod tests { #[test] fn test_bootstrapping() -> Result<()> { - const T: u64 = 128; // plaintext modulus - const K: usize = 1; - const N: usize = 1024; - const KN: usize = K * N; + let param = Param { + err_sigma: crate::ERR_SIGMA, + ring: RingParam { + q: u64::MAX, + n: 1024, + }, + k: 1, + t: 128, // plaintext modulus + }; let mut rng = rand::thread_rng(); let start = Instant::now(); - let table: TGLWE = compute_lookup_table::(); + let table: TGLWE = compute_lookup_table(¶m); println!("table took: {:?}", start.elapsed()); - let (sk, _) = TGLWE::::new_key::(&mut rng)?; - let sk_tlwe: SecretKey = sk.to_tlwe::(); + let (sk, _) = TGLWE::new_key(&mut rng, ¶m)?; + let sk_tlwe: SecretKey = sk.to_tlwe(¶m); let start = Instant::now(); - let btk = BootstrappingKey::::from_sk(&mut rng, &sk)?; + let btk = BootstrappingKey::from_sk(&mut rng, ¶m, &sk)?; println!("btk took: {:?}", start.elapsed()); - let msg_dist = Uniform::new(0_u64, T); - let m = Rq::::rand_u64(&mut rng, msg_dist)?; - dbg!(&m); - let p = TLWE::::encode::(&m); // plaintext + let msg_dist = Uniform::new(0_u64, param.t); + let m = Rq::rand_u64(&mut rng, msg_dist, ¶m.lwe().pt())?; // q=t, n=1 + let p = TLWE::encode(¶m.lwe(), &m); // plaintext - let c = TLWE::::encrypt_s(&mut rng, &sk_tlwe, &p)?; + let c = TLWE::encrypt_s(&mut rng, ¶m.lwe(), &sk_tlwe, &p)?; let start = Instant::now(); - // the ugly const generics are temporary - let bootstrapped: TLWE = - bootstrapping::(btk, table, c); + let bootstrapped: TLWE = bootstrapping(¶m, btk, table, c); println!("bootstrapping took: {:?}", start.elapsed()); let p_recovered: T64 = bootstrapped.decrypt(&sk_tlwe); - let m_recovered = TLWE::::decode::(&p_recovered); - dbg!(&m_recovered); + let m_recovered = TLWE::decode(¶m.lwe(), &p_recovered); assert_eq!(m_recovered, m); Ok(())