mirror of
https://github.com/arnaucube/fhe-study.git
synced 2026-01-24 04:33:52 +01:00
ntt: get rid of Zq and use u64 instead (>2x speed improvement)
This commit is contained in:
117
arith/src/ntt.rs
117
arith/src/ntt.rs
@@ -6,7 +6,6 @@
|
|||||||
//! generics; but once using real-world parameters, the stack could not handle
|
//! generics; but once using real-world parameters, the stack could not handle
|
||||||
//! it, so moved to use Vec instead of fixed-sized arrays, and adapted the NTT
|
//! it, so moved to use Vec instead of fixed-sized arrays, and adapted the NTT
|
||||||
//! implementation to that too.
|
//! implementation to that too.
|
||||||
use crate::{ring::RingParam, ring_nq::Rq, zq::Zq};
|
|
||||||
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
@@ -15,22 +14,19 @@ pub struct NTT {}
|
|||||||
|
|
||||||
use std::sync::{Mutex, OnceLock};
|
use std::sync::{Mutex, OnceLock};
|
||||||
|
|
||||||
static CACHE: OnceLock<Mutex<HashMap<(u64, usize), (Vec<Zq>, Vec<Zq>, Zq)>>> = OnceLock::new();
|
static CACHE: OnceLock<Mutex<HashMap<(u64, usize), (Vec<u64>, Vec<u64>, u64)>>> = OnceLock::new();
|
||||||
|
|
||||||
fn roots(q: u64, n: usize) -> (Vec<Zq>, Vec<Zq>, Zq) {
|
fn roots(q: u64, n: usize) -> (Vec<u64>, Vec<u64>, u64) {
|
||||||
let cache_lock = CACHE.get_or_init(|| Mutex::new(HashMap::new()));
|
let cache_lock = CACHE.get_or_init(|| Mutex::new(HashMap::new()));
|
||||||
let mut cache = cache_lock.lock().unwrap();
|
let mut cache = cache_lock.lock().unwrap();
|
||||||
if let Some(value) = cache.get(&(q, n)) {
|
if let Some(value) = cache.get(&(q, n)) {
|
||||||
return value.clone();
|
return value.clone();
|
||||||
}
|
}
|
||||||
|
|
||||||
let n_inv: Zq = Zq {
|
let n_inv: u64 = const_inv_mod(q, n as u64);
|
||||||
q,
|
|
||||||
v: const_inv_mod(q, n as u64),
|
|
||||||
};
|
|
||||||
let root_of_unity: u64 = primitive_root_of_unity(q, 2 * n);
|
let root_of_unity: u64 = primitive_root_of_unity(q, 2 * n);
|
||||||
let roots_of_unity: Vec<Zq> = roots_of_unity(q, n, root_of_unity);
|
let roots_of_unity: Vec<u64> = roots_of_unity(q, n, root_of_unity);
|
||||||
let roots_of_unity_inv: Vec<Zq> = roots_of_unity_inv(q, n, roots_of_unity.clone());
|
let roots_of_unity_inv: Vec<u64> = roots_of_unity_inv(q, n, roots_of_unity.clone());
|
||||||
let value = (roots_of_unity, roots_of_unity_inv, n_inv);
|
let value = (roots_of_unity, roots_of_unity_inv, n_inv);
|
||||||
|
|
||||||
cache.insert((q, n), value.clone());
|
cache.insert((q, n), value.clone());
|
||||||
@@ -41,56 +37,70 @@ impl NTT {
|
|||||||
/// implements the Cooley-Tukey (CT) algorithm. Details at
|
/// implements the Cooley-Tukey (CT) algorithm. Details at
|
||||||
/// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.1 of
|
/// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.1 of
|
||||||
/// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf
|
/// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf
|
||||||
pub fn ntt(a: &Rq) -> Rq {
|
pub fn ntt(q: u64, n: usize, a: &Vec<u64>) -> Vec<u64> {
|
||||||
let (q, n) = (a.param.q, a.param.n);
|
debug_assert_eq!(n, a.len());
|
||||||
|
|
||||||
let (roots_of_unity, _, _) = roots(q, n);
|
let (roots_of_unity, _, _) = roots(q, n);
|
||||||
|
|
||||||
let mut t = n / 2;
|
let mut t = n / 2;
|
||||||
let mut m = 1;
|
let mut m = 1;
|
||||||
let mut r: Vec<Zq> = a.coeffs.clone();
|
let mut r: Vec<u64> = a.clone();
|
||||||
while m < n {
|
while m < n {
|
||||||
let mut k = 0;
|
let mut k = 0;
|
||||||
for i in 0..m {
|
for i in 0..m {
|
||||||
let S: Zq = roots_of_unity[m + i];
|
let S: u64 = roots_of_unity[m + i];
|
||||||
for j in k..k + t {
|
for j in k..k + t {
|
||||||
let U: Zq = r[j];
|
let U: u64 = r[j];
|
||||||
let V: Zq = r[j + t] * S;
|
let V: u64 = (r[j + t] * S) % q;
|
||||||
|
// compute r[j] = (U + V) % q:
|
||||||
r[j] = U + V;
|
r[j] = U + V;
|
||||||
r[j + t] = U - V;
|
if r[j] >= q {
|
||||||
|
r[j] -= q;
|
||||||
|
}
|
||||||
|
// compute r[j + t] = (U - V) % q:
|
||||||
|
if U >= V {
|
||||||
|
r[j + t] = U - V;
|
||||||
|
} else {
|
||||||
|
r[j + t] = (q + U) - V;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
k = k + 2 * t;
|
k = k + 2 * t;
|
||||||
}
|
}
|
||||||
t /= 2;
|
t /= 2;
|
||||||
m *= 2;
|
m *= 2;
|
||||||
}
|
}
|
||||||
// TODO think if maybe not return a Rq type, or if returned Rq, maybe
|
r
|
||||||
// fill the `evals` field, which is what we're actually returning here
|
|
||||||
Rq {
|
|
||||||
param: RingParam { q, n },
|
|
||||||
coeffs: r,
|
|
||||||
evals: None,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// implements the Cooley-Tukey (CT) algorithm. Details at
|
/// implements the Cooley-Tukey (CT) algorithm. Details at
|
||||||
/// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.2 of
|
/// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.2 of
|
||||||
/// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf
|
/// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf
|
||||||
pub fn intt(a: &Rq) -> Rq {
|
pub fn intt(q: u64, n: usize, a: &Vec<u64>) -> Vec<u64> {
|
||||||
let (q, n) = (a.param.q, a.param.n);
|
debug_assert_eq!(n, a.len());
|
||||||
|
|
||||||
let (_, roots_of_unity_inv, n_inv) = roots(q, n);
|
let (_, roots_of_unity_inv, n_inv) = roots(q, n);
|
||||||
|
|
||||||
let mut t = 1;
|
let mut t = 1;
|
||||||
let mut m = n / 2;
|
let mut m = n / 2;
|
||||||
let mut r: Vec<Zq> = a.coeffs.clone();
|
let mut r: Vec<u64> = a.clone();
|
||||||
while m > 0 {
|
while m > 0 {
|
||||||
let mut k = 0;
|
let mut k = 0;
|
||||||
for i in 0..m {
|
for i in 0..m {
|
||||||
let S: Zq = roots_of_unity_inv[m + i];
|
let S: u64 = roots_of_unity_inv[m + i];
|
||||||
for j in k..k + t {
|
for j in k..k + t {
|
||||||
let U: Zq = r[j];
|
let U: u64 = r[j];
|
||||||
let V: Zq = r[j + t];
|
let V: u64 = r[j + t];
|
||||||
|
// compute r[j] = (U + V) % q:
|
||||||
r[j] = U + V;
|
r[j] = U + V;
|
||||||
r[j + t] = (U - V) * S;
|
if r[j] >= q {
|
||||||
|
r[j] -= q;
|
||||||
|
}
|
||||||
|
// compute r[j + t] = ((U - V) * S) % q;
|
||||||
|
if U >= V {
|
||||||
|
r[j + t] = ((U - V) * S) % q;
|
||||||
|
} else {
|
||||||
|
r[j + t] = ((q + U - V) * S) % q;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
k += 2 * t;
|
k += 2 * t;
|
||||||
}
|
}
|
||||||
@@ -98,15 +108,9 @@ impl NTT {
|
|||||||
m /= 2;
|
m /= 2;
|
||||||
}
|
}
|
||||||
for i in 0..n {
|
for i in 0..n {
|
||||||
r[i] = r[i] * n_inv;
|
r[i] = (r[i] * n_inv) % q;
|
||||||
}
|
|
||||||
Rq {
|
|
||||||
param: RingParam { q, n },
|
|
||||||
coeffs: r,
|
|
||||||
// TODO maybe at `evals` place the inputed `a` which is the evals
|
|
||||||
// format
|
|
||||||
evals: None,
|
|
||||||
}
|
}
|
||||||
|
r
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -130,31 +134,25 @@ const fn primitive_root_of_unity(q: u64, n: usize) -> u64 {
|
|||||||
panic!("No primitive root of unity");
|
panic!("No primitive root of unity");
|
||||||
}
|
}
|
||||||
|
|
||||||
fn roots_of_unity(q: u64, n: usize, w: u64) -> Vec<Zq> {
|
fn roots_of_unity(q: u64, n: usize, w: u64) -> Vec<u64> {
|
||||||
let mut r: Vec<Zq> = vec![Zq { q, v: 0 }; n];
|
let mut r: Vec<u64> = vec![0; n];
|
||||||
let mut i = 0;
|
let mut i = 0;
|
||||||
let log_n = n.ilog2();
|
let log_n = n.ilog2();
|
||||||
while i < n {
|
while i < n {
|
||||||
// (return the roots in bit-reverset order)
|
// (return the roots in bit-reverset order)
|
||||||
let j = ((i as u64).reverse_bits() >> (64 - log_n)) as usize;
|
let j = ((i as u64).reverse_bits() >> (64 - log_n)) as usize;
|
||||||
r[i] = Zq {
|
r[i] = const_exp_mod(q, w, j as u64);
|
||||||
q,
|
|
||||||
v: const_exp_mod(q, w, j as u64),
|
|
||||||
};
|
|
||||||
i += 1;
|
i += 1;
|
||||||
}
|
}
|
||||||
r
|
r
|
||||||
}
|
}
|
||||||
|
|
||||||
fn roots_of_unity_inv(q: u64, n: usize, v: Vec<Zq>) -> Vec<Zq> {
|
fn roots_of_unity_inv(q: u64, n: usize, v: Vec<u64>) -> Vec<u64> {
|
||||||
// assumes that the inputted roots are already in bit-reverset order
|
// assumes that the inputted roots are already in bit-reverset order
|
||||||
let mut r: Vec<Zq> = vec![Zq { q, v: 0 }; n];
|
let mut r: Vec<u64> = vec![0; n];
|
||||||
let mut i = 0;
|
let mut i = 0;
|
||||||
while i < n {
|
while i < n {
|
||||||
r[i] = Zq {
|
r[i] = const_inv_mod(q, v[i]);
|
||||||
q,
|
|
||||||
v: const_inv_mod(q, v[i].v),
|
|
||||||
};
|
|
||||||
i += 1;
|
i += 1;
|
||||||
}
|
}
|
||||||
r
|
r
|
||||||
@@ -187,7 +185,7 @@ const fn const_inv_mod(q: u64, x: u64) -> u64 {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::Ring;
|
use rand_distr::Distribution;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
|
||||||
@@ -195,14 +193,12 @@ mod tests {
|
|||||||
fn test_ntt() -> Result<()> {
|
fn test_ntt() -> Result<()> {
|
||||||
let q: u64 = 2u64.pow(16) + 1;
|
let q: u64 = 2u64.pow(16) + 1;
|
||||||
let n: usize = 4;
|
let n: usize = 4;
|
||||||
let param = RingParam { q, n };
|
|
||||||
|
|
||||||
let a: Vec<u64> = vec![1u64, 2, 3, 4];
|
let a: Vec<u64> = vec![1u64, 2, 3, 4];
|
||||||
let a: Rq = Rq::from_vec_u64(¶m, a);
|
|
||||||
|
|
||||||
let a_ntt = NTT::ntt(&a);
|
let a_ntt = NTT::ntt(q, n, &a);
|
||||||
|
|
||||||
let a_intt = NTT::intt(&a_ntt);
|
let a_intt = NTT::intt(q, n, &a_ntt);
|
||||||
|
|
||||||
dbg!(&a);
|
dbg!(&a);
|
||||||
dbg!(&a_ntt);
|
dbg!(&a_ntt);
|
||||||
@@ -218,16 +214,17 @@ mod tests {
|
|||||||
fn test_ntt_loop() -> Result<()> {
|
fn test_ntt_loop() -> Result<()> {
|
||||||
let q: u64 = 2u64.pow(16) + 1;
|
let q: u64 = 2u64.pow(16) + 1;
|
||||||
let n: usize = 512;
|
let n: usize = 512;
|
||||||
let param = RingParam { q, n };
|
|
||||||
|
|
||||||
use rand::distributions::Uniform;
|
use rand::distributions::Uniform;
|
||||||
let mut rng = rand::thread_rng();
|
let mut rng = rand::thread_rng();
|
||||||
let dist = Uniform::new(0_f64, q as f64);
|
let dist = Uniform::new(0_u64, q as u64);
|
||||||
|
|
||||||
for _ in 0..1000 {
|
for _ in 0..1000 {
|
||||||
let a: Rq = Rq::rand(&mut rng, dist, ¶m);
|
let a: Vec<u64> = std::iter::repeat_with(|| dist.sample(&mut rng))
|
||||||
let a_ntt = NTT::ntt(&a);
|
.take(n)
|
||||||
let a_intt = NTT::intt(&a_ntt);
|
.collect();
|
||||||
|
let a_ntt = NTT::ntt(q, n, &a);
|
||||||
|
let a_intt = NTT::intt(q, n, &a_ntt);
|
||||||
assert_eq!(a, a_intt);
|
assert_eq!(a, a_intt);
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -113,6 +113,24 @@ impl Ring for Rq {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Rq {
|
||||||
|
fn coeffs_u64(&self) -> Vec<u64> {
|
||||||
|
self.coeffs.iter().map(|c_i| c_i.v).collect()
|
||||||
|
}
|
||||||
|
fn ntt(&self) -> Vec<Zq> {
|
||||||
|
NTT::ntt(self.param.q, self.param.n, &self.coeffs_u64())
|
||||||
|
.iter()
|
||||||
|
.map(|c_i| Zq::from_u64(self.param.q, *c_i))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
fn intt(&self) -> Vec<Zq> {
|
||||||
|
NTT::intt(self.param.q, self.param.n, &self.coeffs_u64())
|
||||||
|
.iter()
|
||||||
|
.map(|c_i| Zq::from_u64(self.param.q, *c_i))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<(u64, crate::ring_n::R)> for Rq {
|
impl From<(u64, crate::ring_n::R)> for Rq {
|
||||||
fn from(qr: (u64, crate::ring_n::R)) -> Self {
|
fn from(qr: (u64, crate::ring_n::R)) -> Self {
|
||||||
let (q, r) = qr;
|
let (q, r) = qr;
|
||||||
@@ -145,7 +163,7 @@ impl Rq {
|
|||||||
self.coeffs.clone()
|
self.coeffs.clone()
|
||||||
}
|
}
|
||||||
pub fn compute_evals(&mut self) {
|
pub fn compute_evals(&mut self) {
|
||||||
self.evals = Some(NTT::ntt(self).coeffs);
|
self.evals = Some(self.ntt());
|
||||||
// TODO improve, ntt returns Rq but here just needs Vec<Zq>
|
// TODO improve, ntt returns Rq but here just needs Vec<Zq>
|
||||||
}
|
}
|
||||||
pub fn to_r(self) -> crate::R {
|
pub fn to_r(self) -> crate::R {
|
||||||
@@ -566,10 +584,10 @@ fn mul_mut(lhs: &mut Rq, rhs: &mut Rq) -> Rq {
|
|||||||
|
|
||||||
// reuse evaluations if already computed
|
// reuse evaluations if already computed
|
||||||
if !lhs.evals.is_some() {
|
if !lhs.evals.is_some() {
|
||||||
lhs.evals = Some(NTT::ntt(lhs).coeffs);
|
lhs.evals = Some(lhs.ntt());
|
||||||
};
|
};
|
||||||
if !rhs.evals.is_some() {
|
if !rhs.evals.is_some() {
|
||||||
rhs.evals = Some(NTT::ntt(rhs).coeffs);
|
rhs.evals = Some(rhs.ntt());
|
||||||
};
|
};
|
||||||
let lhs_evals = lhs.evals.clone().unwrap();
|
let lhs_evals = lhs.evals.clone().unwrap();
|
||||||
let rhs_evals = rhs.evals.clone().unwrap();
|
let rhs_evals = rhs.evals.clone().unwrap();
|
||||||
@@ -578,8 +596,8 @@ fn mul_mut(lhs: &mut Rq, rhs: &mut Rq) -> Rq {
|
|||||||
&lhs.param,
|
&lhs.param,
|
||||||
zip_eq(lhs_evals, rhs_evals).map(|(l, r)| l * r).collect(),
|
zip_eq(lhs_evals, rhs_evals).map(|(l, r)| l * r).collect(),
|
||||||
);
|
);
|
||||||
let c = NTT::intt(&c_ntt);
|
let c: Vec<Zq> = c_ntt.intt();
|
||||||
Rq::new(&lhs.param, c.coeffs, Some(c_ntt.coeffs))
|
Rq::new(&lhs.param, c, Some(c_ntt.coeffs))
|
||||||
}
|
}
|
||||||
// note: this assumes that Q is prime
|
// note: this assumes that Q is prime
|
||||||
// TODO impl karatsuba for non-prime Q. Alternatively check NTT with RNS trick.
|
// TODO impl karatsuba for non-prime Q. Alternatively check NTT with RNS trick.
|
||||||
@@ -590,20 +608,20 @@ fn mul(lhs: &Rq, rhs: &Rq) -> Rq {
|
|||||||
let lhs_evals: Vec<Zq> = if lhs.evals.is_some() {
|
let lhs_evals: Vec<Zq> = if lhs.evals.is_some() {
|
||||||
lhs.evals.clone().unwrap()
|
lhs.evals.clone().unwrap()
|
||||||
} else {
|
} else {
|
||||||
NTT::ntt(lhs).coeffs
|
lhs.ntt()
|
||||||
};
|
};
|
||||||
let rhs_evals: Vec<Zq> = if rhs.evals.is_some() {
|
let rhs_evals: Vec<Zq> = if rhs.evals.is_some() {
|
||||||
rhs.evals.clone().unwrap()
|
rhs.evals.clone().unwrap()
|
||||||
} else {
|
} else {
|
||||||
NTT::ntt(rhs).coeffs
|
rhs.ntt()
|
||||||
};
|
};
|
||||||
|
|
||||||
let c_ntt: Rq = Rq::from_vec(
|
let c_ntt: Rq = Rq::from_vec(
|
||||||
&lhs.param,
|
&lhs.param,
|
||||||
zip_eq(lhs_evals, rhs_evals).map(|(l, r)| l * r).collect(),
|
zip_eq(lhs_evals, rhs_evals).map(|(l, r)| l * r).collect(),
|
||||||
);
|
);
|
||||||
let c = NTT::intt(&c_ntt);
|
let c = c_ntt.intt();
|
||||||
Rq::new(&lhs.param, c.coeffs, Some(c_ntt.coeffs))
|
Rq::new(&lhs.param, c, Some(c_ntt.coeffs))
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for Rq {
|
impl fmt::Display for Rq {
|
||||||
|
|||||||
@@ -252,6 +252,7 @@ impl Mul<Tn> for Tn {
|
|||||||
type Output = Self;
|
type Output = Self;
|
||||||
|
|
||||||
fn mul(self, rhs: Self) -> Self {
|
fn mul(self, rhs: Self) -> Self {
|
||||||
|
// TODO NTT/FFT
|
||||||
naive_poly_mul(&self, &rhs)
|
naive_poly_mul(&self, &rhs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -259,6 +260,7 @@ impl Mul<&Tn> for &Tn {
|
|||||||
type Output = Tn;
|
type Output = Tn;
|
||||||
|
|
||||||
fn mul(self, rhs: &Tn) -> Self::Output {
|
fn mul(self, rhs: &Tn) -> Self::Output {
|
||||||
|
// TODO NTT/FFT
|
||||||
naive_poly_mul(self, rhs)
|
naive_poly_mul(self, rhs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user