use itertools::{izip, Itertools}; use rand::{thread_rng, Rng, RngCore, SeedableRng}; use rand_chacha::{rand_core::le, ChaCha8Rng}; use crate::{ backend::{ArithmeticOps, ModInit, ModularOpsU64, Modulus}, utils::{mod_exponent, mod_inverse, shoup_representation_fq}, }; pub trait NttInit { /// Ntt istance must be compatible across different instances with same `q` /// and `n` fn new(q: &M, n: usize) -> Self; } pub trait Ntt { type Element; fn forward_lazy(&self, v: &mut [Self::Element]); fn forward(&self, v: &mut [Self::Element]); fn backward_lazy(&self, v: &mut [Self::Element]); fn backward(&self, v: &mut [Self::Element]); } /// Forward butterfly routine for Number theoretic transform. Given inputs `x < /// 4q` and `y < 4q` mutates x and y in place to equal x' and y' where /// x' = x + wy /// y' = x - wy /// and both x' and y' are \in [0, 4q) /// /// Implements Algorithm 4 of [FASTER ARITHMETIC FOR NUMBER-THEORETIC TRANSFORMS](https://arxiv.org/pdf/1205.2926.pdf) pub fn forward_butterly_0_to_4q( mut x: u64, y: u64, w: u64, w_shoup: u64, q: u64, q_twice: u64, ) -> (u64, u64) { debug_assert!(x < q * 4, "{} >= (4q){}", x, 4 * q); debug_assert!(y < q * 4, "{} >= (4q){}", y, 4 * q); if x >= q_twice { x = x - q_twice; } // TODO (Jay): Hot path expected. How expensive is it? let k = ((w_shoup as u128 * y as u128) >> 64) as u64; let t = w.wrapping_mul(y).wrapping_sub(k.wrapping_mul(q)); (x + t, x + q_twice - t) } pub fn forward_butterly_0_to_2q( mut x: u64, mut y: u64, w: u64, w_shoup: u64, q: u64, q_twice: u64, ) -> (u64, u64) { debug_assert!(x < q * 4, "{} >= (4q){}", x, 4 * q); debug_assert!(y < q * 4, "{} >= (4q){}", y, 4 * q); if x >= q_twice { x = x - q_twice; } let k = ((w_shoup as u128 * y as u128) >> 64) as u64; let t = w.wrapping_mul(y).wrapping_sub(k.wrapping_mul(q)); let ox = x.wrapping_add(t); let oy = x.wrapping_sub(t); ( (ox).min(ox.wrapping_sub(q_twice)), oy.min(oy.wrapping_add(q_twice)), ) } /// Inverse butterfly routine of Inverse Number theoretic transform. Given /// inputs `x < 2q` and `y < 2q` mutates x and y to equal x' and y' where /// x'= x + y /// y' = w(x - y) /// and both x' and y' are \in [0, 2q) /// /// Implements Algorithm 3 of [FASTER ARITHMETIC FOR NUMBER-THEORETIC TRANSFORMS](https://arxiv.org/pdf/1205.2926.pdf) pub unsafe fn inverse_butterfly( x: *mut u64, y: *mut u64, w_inv: &u64, w_inv_shoup: &u64, q: &u64, q_twice: &u64, ) { debug_assert!(*x < *q_twice, "{} >= (2q){q_twice}", *x); debug_assert!(*y < *q_twice, "{} >= (2q){q_twice}", *y); let mut x_dash = *x + *y; if x_dash >= *q_twice { x_dash -= q_twice } let t = *x + q_twice - *y; let k = ((*w_inv_shoup as u128 * t as u128) >> 64) as u64; // TODO (Jay): Hot path *y = w_inv.wrapping_mul(t).wrapping_sub(k.wrapping_mul(*q)); *x = x_dash; } /// Number theoretic transform of vector `a` where each element can be in range /// [0, 2q). Outputs NTT(a) where each element is in range [0,2q) /// /// Implements Cooley-tukey based forward NTT as given in Algorithm 1 of https://eprint.iacr.org/2016/504.pdf. pub fn ntt_lazy(a: &mut [u64], psi: &[u64], psi_shoup: &[u64], q: u64, q_twice: u64) { assert!(a.len() == psi.len()); let n = a.len(); let mut t = n; let mut m = 1; while m < n { t >>= 1; let w = &psi[m..]; let w_shoup = &psi_shoup[m..]; if t == 1 { for (a, w, w_shoup) in izip!(a.chunks_mut(2), w.iter(), w_shoup.iter()) { let (ox, oy) = forward_butterly_0_to_2q(a[0], a[1], *w, *w_shoup, q, q_twice); a[0] = ox; a[1] = oy; } } else { for i in 0..m { let a = &mut a[2 * i * t..(2 * (i + 1) * t)]; let (left, right) = a.split_at_mut(t); for (x, y) in izip!(left.iter_mut(), right.iter_mut()) { let (ox, oy) = forward_butterly_0_to_4q(*x, *y, w[i], w_shoup[i], q, q_twice); *x = ox; *y = oy; } } } m <<= 1; } } /// Same as `ntt_lazy` with output in range [0, q) pub fn ntt(a: &mut [u64], psi: &[u64], psi_shoup: &[u64], q: u64, q_twice: u64) { assert!(a.len() == psi.len()); let n = a.len(); let mut t = n; let mut m = 1; while m < n { t >>= 1; let w = &psi[m..]; let w_shoup = &psi_shoup[m..]; if t == 1 { for (a, w, w_shoup) in izip!(a.chunks_mut(2), w.iter(), w_shoup.iter()) { let (ox, oy) = forward_butterly_0_to_2q(a[0], a[1], *w, *w_shoup, q, q_twice); a[0] = ox.min(ox.wrapping_sub(q_twice)); a[1] = oy.min(oy.wrapping_sub(q_twice)); } } else { for i in 0..m { let a = &mut a[2 * i * t..(2 * (i + 1) * t)]; let (left, right) = a.split_at_mut(t); for (x, y) in izip!(left.iter_mut(), right.iter_mut()) { let (ox, oy) = forward_butterly_0_to_4q(*x, *y, w[i], w_shoup[i], q, q_twice); *x = ox; *y = oy; } } } m <<= 1; } } /// Inverse number theoretic transform of input vector `a` with each element can /// be in range [0, 2q). Outputs vector INTT(a) with each element in range [0, /// 2q) /// /// Implements backward number theorectic transform using GS algorithm as given in Algorithm 2 of https://eprint.iacr.org/2016/504.pdf pub fn ntt_inv_lazy( a: &mut [u64], psi_inv: &[u64], psi_inv_shoup: &[u64], n_inv: u64, q: u64, q_twice: u64, ) { debug_assert!(a.len() == psi_inv.len()); let mut m = a.len(); let mut t = 1; while m > 1 { let mut j_1: usize = 0; let h = m >> 1; for i in 0..h { let j_2 = j_1 + t; unsafe { let w_inv = psi_inv.get_unchecked(h + i); let w_inv_shoup = psi_inv_shoup.get_unchecked(h + i); for j in j_1..j_2 { let x = a.get_unchecked_mut(j) as *mut u64; let y = a.get_unchecked_mut(j + t) as *mut u64; inverse_butterfly(x, y, w_inv, w_inv_shoup, &q, &q_twice); } } j_1 = j_1 + 2 * t; } t *= 2; m >>= 1; } a.iter_mut() .for_each(|a0| *a0 = ((*a0 as u128 * n_inv as u128) % q as u128) as u64); } /// Find n^{th} root of unity in field F_q, if one exists /// /// Note: n^{th} root of unity exists if and only if $q = 1 \mod{n}$ pub(crate) fn find_primitive_root(q: u64, n: u64, rng: &mut R) -> Option { assert!(n.is_power_of_two(), "{n} is not power of two"); // n^th root of unity only exists if n|(q-1) assert!(q % n == 1, "{n}^th root of unity in F_{q} does not exists"); let t = (q - 1) / n; for _ in 0..100 { let mut omega = rng.gen::() % q; // \omega = \omega^t. \omega is now n^th root of unity omega = mod_exponent(omega, t, q); // We restrict n to be power of 2. Thus checking whether \omega is primitive // n^th root of unity is as simple as checking: \omega^{n/2} != 1 if mod_exponent(omega, n >> 1, q) == 1 { continue; } else { return Some(omega); } } None } #[derive(Debug)] pub struct NttBackendU64 { q: u64, q_twice: u64, n: u64, n_inv: u64, psi_powers_bo: Box<[u64]>, psi_inv_powers_bo: Box<[u64]>, psi_powers_bo_shoup: Box<[u64]>, psi_inv_powers_bo_shoup: Box<[u64]>, } impl NttBackendU64 { fn _new(q: u64, n: usize) -> Self { // \psi = 2n^{th} primitive root of unity in F_q let mut rng = ChaCha8Rng::from_seed([0u8; 32]); let psi = find_primitive_root(q, (n * 2) as u64, &mut rng) .expect("Unable to find 2n^th root of unity"); let psi_inv = mod_inverse(psi, q); // assert!( // ((psi_inv as u128 * psi as u128) % q as u128) == 1, // "psi:{psi}, psi_inv:{psi_inv}" // ); let modulus = ModularOpsU64::new(q); let mut psi_powers = Vec::with_capacity(n as usize); let mut psi_inv_powers = Vec::with_capacity(n as usize); let mut running_psi = 1; let mut running_psi_inv = 1; for _ in 0..n { psi_powers.push(running_psi); psi_inv_powers.push(running_psi_inv); running_psi = modulus.mul(&running_psi, &psi); running_psi_inv = modulus.mul(&running_psi_inv, &psi_inv); } // powers stored in bit reversed order let mut psi_powers_bo = vec![0u64; n as usize]; let mut psi_inv_powers_bo = vec![0u64; n as usize]; let shift_by = n.leading_zeros() + 1; for i in 0..n as usize { // i in bit reversed order let bo_index = i.reverse_bits() >> shift_by; psi_powers_bo[bo_index] = psi_powers[i]; psi_inv_powers_bo[bo_index] = psi_inv_powers[i]; } // shoup representation let psi_powers_bo_shoup = psi_powers_bo .iter() .map(|v| shoup_representation_fq(*v, q)) .collect_vec(); let psi_inv_powers_bo_shoup = psi_inv_powers_bo .iter() .map(|v| shoup_representation_fq(*v, q)) .collect_vec(); // n^{-1} \mod{q} let n_inv = mod_inverse(n as u64, q); NttBackendU64 { q, q_twice: 2 * q, n: n as u64, n_inv, psi_powers_bo: psi_powers_bo.into_boxed_slice(), psi_inv_powers_bo: psi_inv_powers_bo.into_boxed_slice(), psi_powers_bo_shoup: psi_powers_bo_shoup.into_boxed_slice(), psi_inv_powers_bo_shoup: psi_inv_powers_bo_shoup.into_boxed_slice(), } } } impl> NttInit for NttBackendU64 { fn new(q: &M, n: usize) -> Self { // This NTT does not support native modulus assert!(!q.is_native()); NttBackendU64::_new(q.q().unwrap(), n) } } impl NttBackendU64 { fn reduce_from_lazy(&self, a: &mut [u64]) { let q = self.q; a.iter_mut().for_each(|a0| { if *a0 >= q { *a0 = *a0 - q; } }); } } impl Ntt for NttBackendU64 { type Element = u64; fn forward_lazy(&self, v: &mut [Self::Element]) { ntt_lazy( v, &self.psi_powers_bo, &self.psi_powers_bo_shoup, self.q, self.q_twice, ) } fn forward(&self, v: &mut [Self::Element]) { ntt( v, &self.psi_powers_bo, &self.psi_powers_bo_shoup, self.q, self.q_twice, ); } fn backward_lazy(&self, v: &mut [Self::Element]) { ntt_inv_lazy( v, &self.psi_inv_powers_bo, &self.psi_inv_powers_bo_shoup, self.n_inv, self.q, self.q_twice, ) } fn backward(&self, v: &mut [Self::Element]) { ntt_inv_lazy( v, &self.psi_inv_powers_bo, &self.psi_inv_powers_bo_shoup, self.n_inv, self.q, self.q_twice, ); self.reduce_from_lazy(v); } } mod tests { use itertools::Itertools; use rand::{thread_rng, Rng}; use rand_distr::Uniform; use super::{NttBackendU64, NttInit}; use crate::{ backend::{ArithmeticOps, ModInit, ModularOpsU64, VectorOps}, ntt::Ntt, utils::{generate_prime, negacyclic_mul}, }; const Q_60_BITS: u64 = 1152921504606748673; const N: usize = 1 << 4; const K: usize = 128; fn random_vec_in_fq(size: usize, q: u64) -> Vec { thread_rng() .sample_iter(Uniform::new(0, q)) .take(size) .collect_vec() } #[test] fn native_ntt_backend_works() { // TODO(Jay): Improve tests. Add tests for different primes and ring size. let ntt_backend = NttBackendU64::_new(Q_60_BITS, N); for _ in 0..K { let mut a = random_vec_in_fq(N, Q_60_BITS); let a_clone = a.clone(); ntt_backend.forward(&mut a); assert_ne!(a, a_clone); ntt_backend.backward(&mut a); assert_eq!(a, a_clone); ntt_backend.forward_lazy(&mut a); assert_ne!(a, a_clone); ntt_backend.backward(&mut a); assert_eq!(a, a_clone); ntt_backend.forward(&mut a); ntt_backend.backward_lazy(&mut a); // reduce a.iter_mut().for_each(|a0| { if *a0 > Q_60_BITS { *a0 -= *a0 - Q_60_BITS; } }); assert_eq!(a, a_clone); } } #[test] fn native_ntt_negacylic_mul() { let primes = [25, 40, 50, 60] .iter() .map(|bits| generate_prime(*bits, (2 * N) as u64, 1u64 << bits).unwrap()) .collect_vec(); for p in primes.into_iter() { let ntt_backend = NttBackendU64::_new(p, N); let modulus_backend = ModularOpsU64::new(p); for _ in 0..K { let a = random_vec_in_fq(N, p); let b = random_vec_in_fq(N, p); let mut a_clone = a.clone(); let mut b_clone = b.clone(); ntt_backend.forward_lazy(&mut a_clone); ntt_backend.forward_lazy(&mut b_clone); modulus_backend.elwise_mul_mut(&mut a_clone, &b_clone); ntt_backend.backward(&mut a_clone); let mul = |a: &u64, b: &u64| { let tmp = *a as u128 * *b as u128; (tmp % p as u128) as u64 }; let expected_out = negacyclic_mul(&a, &b, mul, p); assert_eq!(a_clone, expected_out); } } } }