diff --git a/README.md b/README.md new file mode 100644 index 0000000..585300f --- /dev/null +++ b/README.md @@ -0,0 +1,4 @@ +# fhe-study +Code done while studying some FHE papers. + +- arithmetic: contains $\mathbb{Z}_q$ and $\mathbb{Z}_q[X]/(X^N+1)$ arithmetic implementations, together with the NTT implementation. diff --git a/arithmetic/.gitignore b/arithmetic/.gitignore new file mode 100644 index 0000000..d7524a2 --- /dev/null +++ b/arithmetic/.gitignore @@ -0,0 +1,3 @@ +/target +Cargo.lock +*.sage.py diff --git a/arithmetic/README.md b/arithmetic/README.md new file mode 100644 index 0000000..6ba3705 --- /dev/null +++ b/arithmetic/README.md @@ -0,0 +1,2 @@ +# arithmetic +Contains $\mathbb{Z}_q$ and $\mathbb{Z}_q[X]/(X^N+1)$ arithmetic implementations, together with the NTT implementation. diff --git a/arithmetic/src/lib.rs b/arithmetic/src/lib.rs index a0c2bce..3ec2590 100644 --- a/arithmetic/src/lib.rs +++ b/arithmetic/src/lib.rs @@ -4,8 +4,11 @@ #![allow(clippy::upper_case_acronyms)] #![allow(dead_code)] // TMP +mod naive; // TODO rm +pub mod ntt; pub mod ring; pub mod zq; +pub use ntt::NTT; pub use ring::PR; pub use zq::Zq; diff --git a/arithmetic/src/naive.rs b/arithmetic/src/naive.rs new file mode 100644 index 0000000..7968e08 --- /dev/null +++ b/arithmetic/src/naive.rs @@ -0,0 +1,195 @@ +//! this file implements the non-efficient NTT, which uses multiplication by the +//! Vandermonde matrix. +use crate::zq::Zq; + +use anyhow::{anyhow, Result}; + +#[derive(Debug)] +pub struct NTT { + pub primitive: Zq, + // nth_roots: Vec>, + pub ntt: Vec>>, + pub intt: Vec>>, +} + +impl NTT { + pub fn new() -> Result { + // TODO change n to be u64 and ensure that is n>> = Self::vandermonde(primitive); + let intt = Self::invert_vandermonde(&ntt); + Ok(Self { + primitive, + // nth_roots, + ntt, + intt, + }) + } + pub fn vandermonde(primitive: Zq) -> Vec>> { + let mut v: Vec>> = vec![]; + let n = (2 * N) as u64; + // let n = N as u64; + for i in 0..n { + let mut row: Vec> = vec![]; + let primitive_i = primitive.exp(Zq(i)); + let mut primitive_ij = Zq(1); + for _ in 0..n { + row.push(primitive_ij); + primitive_ij = primitive_ij * primitive_i; + } + v.push(row); + } + v + } + // specifically for the Vandermonde matrix + pub fn invert_vandermonde(v: &Vec>>) -> Vec>> { + let n = 2 * N; + // let n = N; + let mut inv: Vec>> = vec![]; + for i in 0..n { + let w_i = v[i][1]; // = w_i^1=w^i^1 = w^i + let w_i_inv = w_i.inv(); + let mut row: Vec> = vec![]; + for j in 0..n { + row.push(w_i_inv.exp(Zq(j as u64)) / Zq(n as u64)); + } + inv.push(row); + } + inv + } + + pub fn get_primitive_root_of_unity(n: u64) -> Result> { + // using the method described by Thomas Pornin in + // https://crypto.stackexchange.com/a/63616 + + // assert!((Q - 1) % N as u64 == 0); + assert!((Q - 1) % n == 0); + + // TODO maybe not using Zq and using u64 directly + let n = Zq(n); + for k in 0..Q { + if k == 0 { + continue; + } + let g = Zq(k); + // g = F.random_element() + if g == Zq(0) { + continue; + } + let w = g.exp((-Zq(1)) / n); + if w.exp(n / Zq(2)) != Zq(1) { + // g is the generator + return Ok(w); + } + } + Err(anyhow!("can not find the primitive root of unity")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand_distr::Uniform; + + use crate::ring::matrix_vec_product; + use crate::ring::PR; + + #[test] + fn roots_of_unity() -> Result<()> { + const Q: u64 = 12289; + const N: usize = 512; + let _ntt = NTT::::new()?; + Ok(()) + } + + #[test] + fn vandermonde_ntt() -> Result<()> { + const Q: u64 = 41; + const N: usize = 4; + let primitive = NTT::::get_primitive_root_of_unity((2 * N) as u64)?; + let v = NTT::::vandermonde(primitive); + + // naively compute the Vandermonde matrix, and assert that the one from the method matches + // the naively obtained one + let n2 = (2 * N) as u64; + let mut v2: Vec>> = vec![]; + for i in 0..n2 { + let mut row: Vec> = vec![]; + for j in 0..n2 { + row.push(primitive.exp(Zq(i * j))); + } + v2.push(row); + } + assert_eq!(v, v2); + + let v_inv = NTT::::invert_vandermonde(&v); + + let mut rng = rand::thread_rng(); + let uniform_distr = Uniform::new(0_f64, Q as f64); + let a = PR::::rand(&mut rng, uniform_distr)?; + // let a = PR::::new_from_u64(vec![36, 21, 9, 19]); + + // let a_padded_coeffs: [Zq; 2 * N] = + // std::array::from_fn(|i| if i < N { a.coeffs[i] } else { Zq::zero() }); + let mut a_padded = a.coeffs.to_vec(); + a_padded.append(&mut vec![Zq(0); N]); + // let a_ntt = a_padded.mul_by_matrix(&v)?; + let a_ntt = matrix_vec_product(&v, &a_padded)?; + let a_intt: Vec> = matrix_vec_product(&v_inv, &a_ntt)?; + assert_eq!(a_intt, a_padded); + let a_intt_arr: [Zq; N] = std::array::from_fn(|i| a_intt[i]); + assert_eq!(PR::new(a_intt_arr, None), a); + + Ok(()) + } + + #[test] + fn vec_by_ntt() -> Result<()> { + const Q: u64 = 257; + const N: usize = 4; + // let primitive = NTT::::get_primitive_root_of_unity((2*N) as u64)?; + let ntt = NTT::::new()?; + + let a: Vec> = vec![256, 256, 256, 256, 0, 0, 0, 0] + .iter() + .map(|&e| Zq::new(e)) + .collect(); + let a_ntt = matrix_vec_product(&ntt.ntt, &a)?; + let a_intt = matrix_vec_product(&ntt.intt, &a_ntt)?; + assert_eq!(a_intt, a); + + Ok(()) + } + + #[test] + fn bench_ntt() -> Result<()> { + // const Q: u64 = 12289; + // const N: usize = 512; + const Q: u64 = 257; + const N: usize = 4; + // let primitive = NTT::::get_primitive_root_of_unity((2*N) as u64)?; + let ntt = NTT::::new()?; + + let rng = rand::thread_rng(); + let a = PR::::rand(rng, Uniform::new(0_f64, (Q - 1) as f64))?; + let a = a.coeffs; + dbg!(&a); + let a_ntt = matrix_vec_product(&ntt.ntt, &a.to_vec())?; + dbg!(&a_ntt); + let a_intt = matrix_vec_product(&ntt.intt, &a_ntt)?; + dbg!(&a_intt); + assert_eq!(a_intt, a); + + Ok(()) + } +} diff --git a/arithmetic/src/ntt.rs b/arithmetic/src/ntt.rs new file mode 100644 index 0000000..9a535d6 --- /dev/null +++ b/arithmetic/src/ntt.rs @@ -0,0 +1,183 @@ +//! Implementation of the NTT & iNTT, following the CT & GS algorighms, more +//! details in https://github.com/arnaucube/math/blob/master/notes_ntt.pdf . +use crate::zq::Zq; + +#[derive(Debug)] +pub struct NTT {} + +impl NTT { + const N_INV: Zq = Zq(const_inv_mod::(N as u64)); + // since we work over Zq[X]/(X^N+1) (negacyclic), get the 2*N-th root of unity + pub(crate) const ROOT_OF_UNITY: u64 = primitive_root_of_unity::(2 * N); + pub(crate) const ROOTS_OF_UNITY: [Zq; N] = roots_of_unity(Self::ROOT_OF_UNITY); + const ROOTS_OF_UNITY_INV: [Zq; N] = roots_of_unity_inv(Self::ROOTS_OF_UNITY); +} + +impl NTT { + /// implements the Cooley-Tukey (CT) algorithm. Details at section 3.1 of + /// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf + pub fn ntt(a: [Zq; N]) -> [Zq; N] { + let mut t = N / 2; + let mut m = 1; + let mut r: [Zq; N] = a.clone(); + while m < N { + let mut k = 0; + for i in 0..m { + let S: Zq = Self::ROOTS_OF_UNITY[m + i]; + for j in k..k + t { + let U: Zq = r[j]; + let V: Zq = r[j + t] * S; + r[j] = U + V; + r[j + t] = U - V; + } + k = k + 2 * t; + } + t /= 2; + m *= 2; + } + r + } + + /// implements the Gentleman-Sande (GS) algorithm. Details at section 3.2 of + /// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf + pub fn intt(a: [Zq; N]) -> [Zq; N] { + let mut t = 1; + let mut m = N / 2; + let mut r: [Zq; N] = a.clone(); + while m > 0 { + let mut k = 0; + for i in 0..m { + let S: Zq = Self::ROOTS_OF_UNITY_INV[m + i]; + for j in k..k + t { + let U: Zq = r[j]; + let V: Zq = r[j + t]; + r[j] = U + V; + r[j + t] = (U - V) * S; + } + k += 2 * t; + } + t *= 2; + m /= 2; + } + for i in 0..N { + r[i] = r[i] * Self::N_INV; + } + r + } +} + +/// computes a primitive N-th root of unity using the method described by Thomas +/// Pornin in https://crypto.stackexchange.com/a/63616 +const fn primitive_root_of_unity(N: usize) -> u64 { + assert!(N.is_power_of_two()); + assert!((Q - 1) % N as u64 == 0); + + let n: u64 = N as u64; + let mut k = 1; + while k < Q { + // alternatively could get a random k at each iteration, if so, add the following if: + // `if k == 0 { continue; }` + let w = const_exp_mod::(k, (Q - 1) / n); + if const_exp_mod::(w, n / 2) != 1 { + return w; // w is a primitive N-th root of unity + } + k += 1; + } + panic!("No primitive root of unity"); +} + +const fn roots_of_unity(w: u64) -> [Zq; N] { + let mut r: [Zq; N] = [Zq(0u64); N]; + let mut i = 0; + let log_n = N.ilog2(); + while i < N { + // (return the roots in bit-reverset order) + let j = ((i as u64).reverse_bits() >> (64 - log_n)) as usize; + r[i] = Zq(const_exp_mod::(w, j as u64)); + i += 1; + } + r +} + +const fn roots_of_unity_inv(v: [Zq; N]) -> [Zq; N] { + // assumes that the inputted roots are already in bit-reverset order + let mut r: [Zq; N] = [Zq(0u64); N]; + let mut i = 0; + while i < N { + r[i] = Zq(const_inv_mod::(v[i].0)); + i += 1; + } + r +} + +/// returns x^k mod Q +const fn const_exp_mod(x: u64, k: u64) -> u64 { + let mut r = 1u64; + let mut x = x; + let mut k = k; + x = x % Q; + // exponentiation by square strategy + while k > 0 { + if k % 2 == 1 { + r = (r * x) % Q; + } + x = (x * x) % Q; + k /= 2; + } + r +} + +/// returns x^-1 mod Q +const fn const_inv_mod(x: u64) -> u64 { + // by Fermat's Little Theorem, x^-1 mod q \equiv x^{q-2} mod q + const_exp_mod::(x, Q - 2) +} + +#[cfg(test)] +mod tests { + use super::*; + + use anyhow::Result; + use std::array; + + #[test] + fn test_ntt() -> Result<()> { + const Q: u64 = 2u64.pow(16) + 1; + const N: usize = 4; + + let a: [u64; N] = [1u64, 2, 3, 4]; + let a: [Zq; N] = array::from_fn(|i| Zq::new(a[i])); + + let a_ntt = NTT::::ntt(a); + + let a_intt = NTT::::intt(a_ntt); + + dbg!(&a); + dbg!(&a_ntt); + dbg!(&a_intt); + dbg!(NTT::::ROOT_OF_UNITY); + dbg!(NTT::::ROOTS_OF_UNITY); + + assert_eq!(a, a_intt); + Ok(()) + } + + #[test] + fn test_ntt_loop() -> Result<()> { + const Q: u64 = 2u64.pow(16) + 1; + const N: usize = 512; + + use rand::distributions::Distribution; + use rand::distributions::Uniform; + let mut rng = rand::thread_rng(); + let dist = Uniform::new(0_f64, Q as f64); + + for _ in 0..100 { + let a: [Zq; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng))); + let a_ntt = NTT::::ntt(a); + let a_intt = NTT::::intt(a_ntt); + assert_eq!(a, a_intt); + } + Ok(()) + } +} diff --git a/arithmetic/src/ring.rs b/arithmetic/src/ring.rs index 6b8e883..0b9b57a 100644 --- a/arithmetic/src/ring.rs +++ b/arithmetic/src/ring.rs @@ -3,6 +3,7 @@ use std::array; use std::fmt; use std::ops; +use crate::ntt::NTT; use crate::zq::Zq; use anyhow::{anyhow, Result}; @@ -78,6 +79,35 @@ impl PR { }) } + // TODO review if needed, or if with this interface + pub fn mul_by_matrix(&self, m: &Vec>>) -> Result>> { + matrix_vec_product(m, &self.coeffs.to_vec()) + } + pub fn mul_by_zq(&self, s: &Zq) -> Self { + Self { + coeffs: array::from_fn(|i| self.coeffs[i] * *s), + evals: None, + } + } + pub fn mul_by_u64(&self, s: u64) -> Self { + let s = Zq::new(s); + Self { + coeffs: array::from_fn(|i| self.coeffs[i] * s), + // coeffs: self.coeffs.iter().map(|&e| e * s).collect(), + evals: None, + } + } + pub fn mul_by_f64(&self, s: f64) -> Self { + Self { + coeffs: array::from_fn(|i| Zq::from_f64(self.coeffs[i].0 as f64 * s)), + evals: None, + } + } + + pub fn mul(&mut self, rhs: &mut Self) -> Self { + mul_mut(self, rhs) + } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { // TODO simplify let mut str = ""; @@ -207,6 +237,51 @@ impl ops::Sub<&PR> for &PR { } } } +impl ops::Mul> for PR { + type Output = Self; + + fn mul(self, rhs: Self) -> Self { + mul(&self, &rhs) + } +} +impl ops::Mul<&PR> for &PR { + type Output = PR; + + fn mul(self, rhs: &PR) -> Self::Output { + mul(self, rhs) + } +} + +// mul by Zq element +impl ops::Mul> for PR { + type Output = Self; + + fn mul(self, s: Zq) -> Self { + self.mul_by_zq(&s) + } +} +impl ops::Mul<&Zq> for &PR { + type Output = PR; + + fn mul(self, s: &Zq) -> Self::Output { + self.mul_by_zq(s) + } +} +// mul by u64 +impl ops::Mul for PR { + type Output = Self; + + fn mul(self, s: u64) -> Self { + self.mul_by_u64(s) + } +} +impl ops::Mul<&u64> for &PR { + type Output = PR; + + fn mul(self, s: &u64) -> Self::Output { + self.mul_by_u64(*s) + } +} impl ops::Neg for PR { type Output = Self; @@ -219,6 +294,39 @@ impl ops::Neg for PR { } } +fn mul_mut(lhs: &mut PR, rhs: &mut PR) -> PR { + // reuse evaluations if already computed + if !lhs.evals.is_some() { + lhs.evals = Some(NTT::::ntt(lhs.coeffs)); + }; + if !rhs.evals.is_some() { + rhs.evals = Some(NTT::::ntt(rhs.coeffs)); + }; + let lhs_evals = lhs.evals.unwrap(); + let rhs_evals = rhs.evals.unwrap(); + + let c_ntt: [Zq; N] = array::from_fn(|i| lhs_evals[i] * rhs_evals[i]); + let c = NTT::::intt(c_ntt); + PR::new(c, Some(c_ntt)) +} +fn mul(lhs: &PR, rhs: &PR) -> PR { + // reuse evaluations if already computed + let lhs_evals = if lhs.evals.is_some() { + lhs.evals.unwrap() + } else { + NTT::::ntt(lhs.coeffs) + }; + let rhs_evals = if rhs.evals.is_some() { + rhs.evals.unwrap() + } else { + NTT::::ntt(rhs.coeffs) + }; + + let c_ntt: [Zq; N] = array::from_fn(|i| lhs_evals[i] * rhs_evals[i]); + let c = NTT::::intt(c_ntt); + PR::new(c, Some(c_ntt)) +} + impl fmt::Display for PR { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.fmt(f)?; @@ -277,4 +385,40 @@ mod tests { "x^2 + x + 1 mod Z_7/(X^3+1)" ); } + + fn test_mul_opt( + a: [u64; N], + b: [u64; N], + expected_c: [u64; N], + ) -> Result<()> { + let a: [Zq; N] = array::from_fn(|i| Zq::new(a[i])); + let mut a = PR::new(a, None); + let b: [Zq; N] = array::from_fn(|i| Zq::new(b[i])); + let mut b = PR::new(b, None); + let expected_c: [Zq; N] = array::from_fn(|i| Zq::new(expected_c[i])); + let expected_c = PR::new(expected_c, None); + + let c = mul_mut(&mut a, &mut b); + assert_eq!(c, expected_c); + Ok(()) + } + #[test] + fn test_mul() -> Result<()> { + const Q: u64 = 2u64.pow(16) + 1; + const N: usize = 4; + + let a: [u64; N] = [1u64, 2, 3, 4]; + let b: [u64; N] = [1u64, 2, 3, 4]; + let c: [u64; N] = [65513, 65517, 65531, 20]; + test_mul_opt::(a, b, c)?; + + let a: [u64; N] = [0u64, 0, 0, 2]; + let b: [u64; N] = [0u64, 0, 0, 2]; + let c: [u64; N] = [0u64, 0, 65533, 0]; + test_mul_opt::(a, b, c)?; + + // TODO more testvectors + + Ok(()) + } }