diff --git a/src/ntt.rs b/src/ntt.rs index b6b9649..0a0f17d 100644 --- a/src/ntt.rs +++ b/src/ntt.rs @@ -1,6 +1,6 @@ -use itertools::Itertools; +use itertools::{izip, Itertools}; use rand::{thread_rng, Rng, RngCore, SeedableRng}; -use rand_chacha::ChaCha8Rng; +use rand_chacha::{rand_core::le, ChaCha8Rng}; use crate::{ backend::{ArithmeticOps, ModInit, ModularOpsU64, Modulus}, @@ -28,27 +28,26 @@ pub trait Ntt { /// and both x' and y' are \in [0, 4q) /// /// Implements Algorithm 4 of [FASTER ARITHMETIC FOR NUMBER-THEORETIC TRANSFORMS](https://arxiv.org/pdf/1205.2926.pdf) -pub unsafe fn forward_butterly( - x: *mut u64, - y: *mut u64, - w: &u64, - w_shoup: &u64, - q: &u64, - q_twice: &u64, -) { - debug_assert!(*x < *q * 4, "{} >= (4q){}", *x, 4 * q); - debug_assert!(*y < *q * 4, "{} >= (4q){}", *y, 4 * q); +pub fn forward_butterly( + mut x: u64, + y: u64, + w: u64, + w_shoup: u64, + q: u64, + q_twice: u64, +) -> (u64, u64) { + debug_assert!(x < q * 4, "{} >= (4q){}", x, 4 * q); + debug_assert!(y < q * 4, "{} >= (4q){}", y, 4 * q); - if *x >= *q_twice { - *x = *x - q_twice; + if x >= q_twice { + x = x - q_twice; } // TODO (Jay): Hot path expected. How expensive is it? - let k = ((*w_shoup as u128 * *y as u128) >> 64) as u64; - let t = w.wrapping_mul(*y).wrapping_sub(k.wrapping_mul(*q)); + let k = ((w_shoup as u128 * y as u128) >> 64) as u64; + let t = w.wrapping_mul(y).wrapping_sub(k.wrapping_mul(q)); - *y = *x + q_twice - t; - *x = *x + t; + (x + t, x + q_twice - t) } /// Inverse butterfly routine of Inverse Number theoretic transform. Given @@ -86,7 +85,7 @@ pub unsafe fn inverse_butterfly( /// /// Implements Cooley-tukey based forward NTT as given in Algorithm 1 of https://eprint.iacr.org/2016/504.pdf. pub fn ntt_lazy(a: &mut [u64], psi: &[u64], psi_shoup: &[u64], q: u64, q_twice: u64) { - debug_assert!(a.len() == psi.len()); + assert!(a.len() == psi.len()); let n = a.len(); let mut t = n; @@ -94,29 +93,36 @@ pub fn ntt_lazy(a: &mut [u64], psi: &[u64], psi_shoup: &[u64], q: u64, q_twice: let mut m = 1; while m < n { t >>= 1; + let w = &psi[m..]; + let w_shoup = &psi_shoup[m..]; + + // for (vector, w, w_shoup) in + // izip!(a.chunks_mut(t << 1), psi[m..].iter(), psi_shoup[m..].iter()) + // { + // let (left, right) = vector.split_at_mut(t); + + // for (x, y) in izip!(left.iter_mut(), right.iter_mut()) { + // let (ox, oy) = forward_butterly(*x, *y, *w, *w_shoup, q, q_twice); + // *x = ox; + // *y = oy; + // } + // } for i in 0..m { - let j_1 = 2 * i * t; - let j_2 = j_1 + t; + let a = &mut a[2 * i * t..(2 * (i + 1) * t)]; + let (left, right) = a.split_at_mut(t); - unsafe { - let w = psi.get_unchecked(m + i); - let w_shoup = psi_shoup.get_unchecked(m + i); - for j in j_1..j_2 { - let x = a.get_unchecked_mut(j) as *mut u64; - let y = a.get_unchecked_mut(j + t) as *mut u64; - forward_butterly(x, y, w, w_shoup, &q, &q_twice); - } + for (x, y) in izip!(left.iter_mut(), right.iter_mut()) { + let (ox, oy) = forward_butterly(*x, *y, w[i], w_shoup[i], q, q_twice); + *x = ox; + *y = oy; } } - m <<= 1; } a.iter_mut().for_each(|a0| { - if *a0 >= q_twice { - *a0 -= q_twice - } + *a0 = (*a0).min((*a0).wrapping_sub(q_twice)); }); }