From 4dca2c6ff5c6da0a3bd897fd6e0742f81da0213d Mon Sep 17 00:00:00 2001 From: arnaucube Date: Mon, 18 Aug 2025 14:14:19 +0000 Subject: [PATCH] ntt: get rid of Zq and use u64 instead (>2x speed improvement) --- arith/src/ntt.rs | 117 ++++++++++++++++++++-------------------- arith/src/ring_nq.rs | 36 +++++++++---- arith/src/ring_torus.rs | 2 + 3 files changed, 86 insertions(+), 69 deletions(-) diff --git a/arith/src/ntt.rs b/arith/src/ntt.rs index 890f8ed..9970904 100644 --- a/arith/src/ntt.rs +++ b/arith/src/ntt.rs @@ -6,7 +6,6 @@ //! generics; but once using real-world parameters, the stack could not handle //! it, so moved to use Vec instead of fixed-sized arrays, and adapted the NTT //! implementation to that too. -use crate::{ring::RingParam, ring_nq::Rq, zq::Zq}; use std::collections::HashMap; @@ -15,22 +14,19 @@ pub struct NTT {} use std::sync::{Mutex, OnceLock}; -static CACHE: OnceLock, Vec, Zq)>>> = OnceLock::new(); +static CACHE: OnceLock, Vec, u64)>>> = OnceLock::new(); -fn roots(q: u64, n: usize) -> (Vec, Vec, Zq) { +fn roots(q: u64, n: usize) -> (Vec, Vec, u64) { let cache_lock = CACHE.get_or_init(|| Mutex::new(HashMap::new())); let mut cache = cache_lock.lock().unwrap(); if let Some(value) = cache.get(&(q, n)) { return value.clone(); } - let n_inv: Zq = Zq { - q, - v: const_inv_mod(q, n as u64), - }; + let n_inv: u64 = const_inv_mod(q, n as u64); let root_of_unity: u64 = primitive_root_of_unity(q, 2 * n); - let roots_of_unity: Vec = roots_of_unity(q, n, root_of_unity); - let roots_of_unity_inv: Vec = roots_of_unity_inv(q, n, roots_of_unity.clone()); + let roots_of_unity: Vec = roots_of_unity(q, n, root_of_unity); + let roots_of_unity_inv: Vec = roots_of_unity_inv(q, n, roots_of_unity.clone()); let value = (roots_of_unity, roots_of_unity_inv, n_inv); cache.insert((q, n), value.clone()); @@ -41,56 +37,70 @@ impl NTT { /// implements the Cooley-Tukey (CT) algorithm. Details at /// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.1 of /// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf - pub fn ntt(a: &Rq) -> Rq { - let (q, n) = (a.param.q, a.param.n); + pub fn ntt(q: u64, n: usize, a: &Vec) -> Vec { + debug_assert_eq!(n, a.len()); + let (roots_of_unity, _, _) = roots(q, n); let mut t = n / 2; let mut m = 1; - let mut r: Vec = a.coeffs.clone(); + let mut r: Vec = a.clone(); while m < n { let mut k = 0; for i in 0..m { - let S: Zq = roots_of_unity[m + i]; + let S: u64 = roots_of_unity[m + i]; for j in k..k + t { - let U: Zq = r[j]; - let V: Zq = r[j + t] * S; + let U: u64 = r[j]; + let V: u64 = (r[j + t] * S) % q; + // compute r[j] = (U + V) % q: r[j] = U + V; - r[j + t] = U - V; + if r[j] >= q { + r[j] -= q; + } + // compute r[j + t] = (U - V) % q: + if U >= V { + r[j + t] = U - V; + } else { + r[j + t] = (q + U) - V; + } } k = k + 2 * t; } t /= 2; m *= 2; } - // TODO think if maybe not return a Rq type, or if returned Rq, maybe - // fill the `evals` field, which is what we're actually returning here - Rq { - param: RingParam { q, n }, - coeffs: r, - evals: None, - } + r } /// implements the Cooley-Tukey (CT) algorithm. Details at /// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.2 of /// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf - pub fn intt(a: &Rq) -> Rq { - let (q, n) = (a.param.q, a.param.n); + pub fn intt(q: u64, n: usize, a: &Vec) -> Vec { + debug_assert_eq!(n, a.len()); + let (_, roots_of_unity_inv, n_inv) = roots(q, n); let mut t = 1; let mut m = n / 2; - let mut r: Vec = a.coeffs.clone(); + let mut r: Vec = a.clone(); while m > 0 { let mut k = 0; for i in 0..m { - let S: Zq = roots_of_unity_inv[m + i]; + let S: u64 = roots_of_unity_inv[m + i]; for j in k..k + t { - let U: Zq = r[j]; - let V: Zq = r[j + t]; + let U: u64 = r[j]; + let V: u64 = r[j + t]; + // compute r[j] = (U + V) % q: r[j] = U + V; - r[j + t] = (U - V) * S; + if r[j] >= q { + r[j] -= q; + } + // compute r[j + t] = ((U - V) * S) % q; + if U >= V { + r[j + t] = ((U - V) * S) % q; + } else { + r[j + t] = ((q + U - V) * S) % q; + } } k += 2 * t; } @@ -98,15 +108,9 @@ impl NTT { m /= 2; } for i in 0..n { - r[i] = r[i] * n_inv; - } - Rq { - param: RingParam { q, n }, - coeffs: r, - // TODO maybe at `evals` place the inputed `a` which is the evals - // format - evals: None, + r[i] = (r[i] * n_inv) % q; } + r } } @@ -130,31 +134,25 @@ const fn primitive_root_of_unity(q: u64, n: usize) -> u64 { panic!("No primitive root of unity"); } -fn roots_of_unity(q: u64, n: usize, w: u64) -> Vec { - let mut r: Vec = vec![Zq { q, v: 0 }; n]; +fn roots_of_unity(q: u64, n: usize, w: u64) -> Vec { + let mut r: Vec = vec![0; n]; let mut i = 0; let log_n = n.ilog2(); while i < n { // (return the roots in bit-reverset order) let j = ((i as u64).reverse_bits() >> (64 - log_n)) as usize; - r[i] = Zq { - q, - v: const_exp_mod(q, w, j as u64), - }; + r[i] = const_exp_mod(q, w, j as u64); i += 1; } r } -fn roots_of_unity_inv(q: u64, n: usize, v: Vec) -> Vec { +fn roots_of_unity_inv(q: u64, n: usize, v: Vec) -> Vec { // assumes that the inputted roots are already in bit-reverset order - let mut r: Vec = vec![Zq { q, v: 0 }; n]; + let mut r: Vec = vec![0; n]; let mut i = 0; while i < n { - r[i] = Zq { - q, - v: const_inv_mod(q, v[i].v), - }; + r[i] = const_inv_mod(q, v[i]); i += 1; } r @@ -187,7 +185,7 @@ const fn const_inv_mod(q: u64, x: u64) -> u64 { #[cfg(test)] mod tests { use super::*; - use crate::Ring; + use rand_distr::Distribution; use anyhow::Result; @@ -195,14 +193,12 @@ mod tests { fn test_ntt() -> Result<()> { let q: u64 = 2u64.pow(16) + 1; let n: usize = 4; - let param = RingParam { q, n }; let a: Vec = vec![1u64, 2, 3, 4]; - let a: Rq = Rq::from_vec_u64(¶m, a); - let a_ntt = NTT::ntt(&a); + let a_ntt = NTT::ntt(q, n, &a); - let a_intt = NTT::intt(&a_ntt); + let a_intt = NTT::intt(q, n, &a_ntt); dbg!(&a); dbg!(&a_ntt); @@ -218,16 +214,17 @@ mod tests { fn test_ntt_loop() -> Result<()> { let q: u64 = 2u64.pow(16) + 1; let n: usize = 512; - let param = RingParam { q, n }; use rand::distributions::Uniform; let mut rng = rand::thread_rng(); - let dist = Uniform::new(0_f64, q as f64); + let dist = Uniform::new(0_u64, q as u64); for _ in 0..1000 { - let a: Rq = Rq::rand(&mut rng, dist, ¶m); - let a_ntt = NTT::ntt(&a); - let a_intt = NTT::intt(&a_ntt); + let a: Vec = std::iter::repeat_with(|| dist.sample(&mut rng)) + .take(n) + .collect(); + let a_ntt = NTT::ntt(q, n, &a); + let a_intt = NTT::intt(q, n, &a_ntt); assert_eq!(a, a_intt); } Ok(()) diff --git a/arith/src/ring_nq.rs b/arith/src/ring_nq.rs index d51c1c2..4a938bb 100644 --- a/arith/src/ring_nq.rs +++ b/arith/src/ring_nq.rs @@ -113,6 +113,24 @@ impl Ring for Rq { } } +impl Rq { + fn coeffs_u64(&self) -> Vec { + self.coeffs.iter().map(|c_i| c_i.v).collect() + } + fn ntt(&self) -> Vec { + NTT::ntt(self.param.q, self.param.n, &self.coeffs_u64()) + .iter() + .map(|c_i| Zq::from_u64(self.param.q, *c_i)) + .collect() + } + fn intt(&self) -> Vec { + NTT::intt(self.param.q, self.param.n, &self.coeffs_u64()) + .iter() + .map(|c_i| Zq::from_u64(self.param.q, *c_i)) + .collect() + } +} + impl From<(u64, crate::ring_n::R)> for Rq { fn from(qr: (u64, crate::ring_n::R)) -> Self { let (q, r) = qr; @@ -145,7 +163,7 @@ impl Rq { self.coeffs.clone() } pub fn compute_evals(&mut self) { - self.evals = Some(NTT::ntt(self).coeffs); + self.evals = Some(self.ntt()); // TODO improve, ntt returns Rq but here just needs Vec } pub fn to_r(self) -> crate::R { @@ -566,10 +584,10 @@ fn mul_mut(lhs: &mut Rq, rhs: &mut Rq) -> Rq { // reuse evaluations if already computed if !lhs.evals.is_some() { - lhs.evals = Some(NTT::ntt(lhs).coeffs); + lhs.evals = Some(lhs.ntt()); }; if !rhs.evals.is_some() { - rhs.evals = Some(NTT::ntt(rhs).coeffs); + rhs.evals = Some(rhs.ntt()); }; let lhs_evals = lhs.evals.clone().unwrap(); let rhs_evals = rhs.evals.clone().unwrap(); @@ -578,8 +596,8 @@ fn mul_mut(lhs: &mut Rq, rhs: &mut Rq) -> Rq { &lhs.param, zip_eq(lhs_evals, rhs_evals).map(|(l, r)| l * r).collect(), ); - let c = NTT::intt(&c_ntt); - Rq::new(&lhs.param, c.coeffs, Some(c_ntt.coeffs)) + let c: Vec = c_ntt.intt(); + Rq::new(&lhs.param, c, Some(c_ntt.coeffs)) } // note: this assumes that Q is prime // TODO impl karatsuba for non-prime Q. Alternatively check NTT with RNS trick. @@ -590,20 +608,20 @@ fn mul(lhs: &Rq, rhs: &Rq) -> Rq { let lhs_evals: Vec = if lhs.evals.is_some() { lhs.evals.clone().unwrap() } else { - NTT::ntt(lhs).coeffs + lhs.ntt() }; let rhs_evals: Vec = if rhs.evals.is_some() { rhs.evals.clone().unwrap() } else { - NTT::ntt(rhs).coeffs + rhs.ntt() }; let c_ntt: Rq = Rq::from_vec( &lhs.param, zip_eq(lhs_evals, rhs_evals).map(|(l, r)| l * r).collect(), ); - let c = NTT::intt(&c_ntt); - Rq::new(&lhs.param, c.coeffs, Some(c_ntt.coeffs)) + let c = c_ntt.intt(); + Rq::new(&lhs.param, c, Some(c_ntt.coeffs)) } impl fmt::Display for Rq { diff --git a/arith/src/ring_torus.rs b/arith/src/ring_torus.rs index 7f5d756..d9602e5 100644 --- a/arith/src/ring_torus.rs +++ b/arith/src/ring_torus.rs @@ -252,6 +252,7 @@ impl Mul for Tn { type Output = Self; fn mul(self, rhs: Self) -> Self { + // TODO NTT/FFT naive_poly_mul(&self, &rhs) } } @@ -259,6 +260,7 @@ impl Mul<&Tn> for &Tn { type Output = Tn; fn mul(self, rhs: &Tn) -> Self::Output { + // TODO NTT/FFT naive_poly_mul(self, rhs) } }