diff --git a/benches/ntt.rs b/benches/ntt.rs
index b26b56e..13cacfe 100644
--- a/benches/ntt.rs
+++ b/benches/ntt.rs
@@ -4,20 +4,41 @@ use itertools::Itertools;
 use rand::{thread_rng, Rng};
 use rand_distr::Uniform;
 
+/// Applies `nttop.forward` to every row of `a` in place.
+fn forward_matrix(a: &mut [Vec<u64>], nttop: &NttBackendU64) {
+    a.iter_mut().for_each(|r| nttop.forward(r.as_mut_slice()));
+}
+
+/// Applies `nttop.forward_lazy` to every row of `a` in place.
+fn forward_lazy_matrix(a: &mut [Vec<u64>], nttop: &NttBackendU64) {
+    a.iter_mut()
+        .for_each(|r| nttop.forward_lazy(r.as_mut_slice()));
+}
+
 fn benchmark(c: &mut Criterion) {
     let mut group = c.benchmark_group("ntt");
     // 55
     for prime in [36028797017456641] {
-        for ring_size in [1 << 11, 1 << 15] {
+        for ring_size in [1 << 11] {
             let ntt = NttBackendU64::new(&prime, ring_size);
             let mut rng = thread_rng();
-            let a = rng
+            let a = (&mut rng)
                 .sample_iter(Uniform::new(0, prime))
                 .take(ring_size)
                 .collect_vec();
 
+            let d = 3;
+            let a_matrix = (0..d)
+                .map(|_| {
+                    (&mut rng)
+                        .sample_iter(Uniform::new(0, prime))
+                        .take(ring_size)
+                        .collect_vec()
+                })
+                .collect_vec();
+
             group.bench_function(
-                BenchmarkId::new("forward", format!("q={prime}/{ring_size}")),
+                BenchmarkId::new("forward", format!("q={prime}/N={ring_size}")),
                 |b| {
                     b.iter_batched_ref(
                         || a.clone(),
@@ -26,6 +47,42 @@ fn benchmark(c: &mut Criterion) {
                     )
                 },
             );
+
+            group.bench_function(
+                BenchmarkId::new("forward_lazy", format!("q={prime}/N={ring_size}")),
+                |b| {
+                    b.iter_batched_ref(
+                        || a.clone(),
+                        |mut a| black_box(ntt.forward_lazy(&mut a)),
+                        criterion::BatchSize::PerIteration,
+                    )
+                },
+            );
+
+            group.bench_function(
+                BenchmarkId::new("forward_matrix", format!("q={prime}/N={ring_size}/d={d}")),
+                |b| {
+                    b.iter_batched_ref(
+                        || a_matrix.clone(),
+                        |a_matrix| black_box(forward_matrix(a_matrix, &ntt)),
+                        criterion::BatchSize::PerIteration,
+                    )
+                },
+            );
+
+            group.bench_function(
+                BenchmarkId::new(
+                    "forward_lazy_matrix",
+                    format!("q={prime}/N={ring_size}/d={d}"),
+                ),
+                |b| {
+                    b.iter_batched_ref(
+                        || a_matrix.clone(),
+                        |a_matrix| black_box(forward_lazy_matrix(a_matrix, &ntt)),
+                        criterion::BatchSize::PerIteration,
+                    )
+                },
+            );
         }
     }
 
diff --git a/src/ntt.rs b/src/ntt.rs
index 0a0f17d..5f31aff 100644
--- a/src/ntt.rs
+++ b/src/ntt.rs
@@ -28,7 +28,7 @@ pub trait Ntt {
 /// and both x' and y' are \in [0, 4q)
 ///
 /// Implements Algorithm 4 of [FASTER ARITHMETIC FOR NUMBER-THEORETIC TRANSFORMS](https://arxiv.org/pdf/1205.2926.pdf)
-pub fn forward_butterly(
+pub fn forward_butterly_0_to_4q(
     mut x: u64,
     y: u64,
     w: u64,
@@ -50,6 +50,40 @@ pub fn forward_butterly(
     (x + t, x + q_twice - t)
 }
 
+/// Forward butterfly whose outputs are corrected into `[0, 2q)`.
+///
+/// Given inputs `x < 4q` and `y < 4q` (checked in debug builds only), returns
+/// `(x + t, x - t)` where `t` is the lazy Shoup product of `w` and `y`; each
+/// output has one conditional `2q` correction applied. The `[0, 2q)` output
+/// range assumes `w_shoup = floor(w * 2^64 / q)` and `q_twice = 2 * q` -
+/// TODO confirm against the code that builds the twiddle tables.
+pub fn forward_butterly_0_to_2q(
+    mut x: u64,
+    y: u64,
+    w: u64,
+    w_shoup: u64,
+    q: u64,
+    q_twice: u64,
+) -> (u64, u64) {
+    debug_assert!(x < q * 4, "{} >= (4q){}", x, 4 * q);
+    debug_assert!(y < q * 4, "{} >= (4q){}", y, 4 * q);
+
+    if x >= q_twice {
+        x -= q_twice;
+    }
+
+    let k = ((w_shoup as u128 * y as u128) >> 64) as u64;
+    let t = w.wrapping_mul(y).wrapping_sub(k.wrapping_mul(q));
+
+    let ox = x.wrapping_add(t);
+    let oy = x.wrapping_sub(t);
+
+    (
+        ox.min(ox.wrapping_sub(q_twice)),
+        oy.min(oy.wrapping_add(q_twice)),
+    )
+}
+
 /// Inverse butterfly routine of Inverse Number theoretic transform. Given
 /// inputs `x < 2q` and `y < 2q` mutates x and y to equal x' and y' where
 /// x'= x + y
@@ -96,34 +130,69 @@ pub fn ntt_lazy(a: &mut [u64], psi: &[u64], psi_shoup: &[u64], q: u64, q_twice:
         let w = &psi[m..];
         let w_shoup = &psi_shoup[m..];
 
-        // for (vector, w, w_shoup) in
-        // izip!(a.chunks_mut(t << 1), psi[m..].iter(), psi_shoup[m..].iter())
-        // {
-        //     let (left, right) = vector.split_at_mut(t);
-
-        //     for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
-        //         let (ox, oy) = forward_butterly(*x, *y, *w, *w_shoup, q, q_twice);
-        //         *x = ox;
-        //         *y = oy;
-        //     }
-        // }
-
-        for i in 0..m {
-            let a = &mut a[2 * i * t..(2 * (i + 1) * t)];
-            let (left, right) = a.split_at_mut(t);
-
-            for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
-                let (ox, oy) = forward_butterly(*x, *y, w[i], w_shoup[i], q, q_twice);
-                *x = ox;
-                *y = oy;
+        if t == 1 {
+            for (a, w, w_shoup) in izip!(a.chunks_mut(2), w.iter(), w_shoup.iter()) {
+                let (ox, oy) = forward_butterly_0_to_2q(a[0], a[1], *w, *w_shoup, q, q_twice);
+                a[0] = ox;
+                a[1] = oy;
+            }
+        } else {
+            for i in 0..m {
+                let a = &mut a[2 * i * t..(2 * (i + 1) * t)];
+                let (left, right) = a.split_at_mut(t);
+
+                for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
+                    let (ox, oy) = forward_butterly_0_to_4q(*x, *y, w[i], w_shoup[i], q, q_twice);
+                    *x = ox;
+                    *y = oy;
+                }
             }
         }
+
         m <<= 1;
     }
+}
+
+/// Same as `ntt_lazy` but with every output fully reduced to the range [0, q).
+///
+/// The `t == 1` stage uses `forward_butterly_0_to_2q`, whose outputs are
+/// expected to lie in [0, 2q); a single conditional subtraction of `q` then
+/// brings them into [0, q), so callers need no separate reduction pass.
+pub fn ntt(a: &mut [u64], psi: &[u64], psi_shoup: &[u64], q: u64, q_twice: u64) {
+    assert!(a.len() == psi.len());
+
+    let n = a.len();
+    let mut t = n;
+
+    let mut m = 1;
+    while m < n {
+        t >>= 1;
+        let w = &psi[m..];
+        let w_shoup = &psi_shoup[m..];
 
-    a.iter_mut().for_each(|a0| {
-        *a0 = (*a0).min((*a0).wrapping_sub(q_twice));
-    });
+        if t == 1 {
+            for (a, w, w_shoup) in izip!(a.chunks_mut(2), w.iter(), w_shoup.iter()) {
+                let (ox, oy) = forward_butterly_0_to_2q(a[0], a[1], *w, *w_shoup, q, q_twice);
+                // Reduce [0, 2q) -> [0, q): subtract `q`, not `2q`. For v < 2q,
+                // `v.wrapping_sub(q_twice)` always wraps and `min` is a no-op.
+                a[0] = ox.min(ox.wrapping_sub(q));
+                a[1] = oy.min(oy.wrapping_sub(q));
+            }
+        } else {
+            for i in 0..m {
+                let a = &mut a[2 * i * t..(2 * (i + 1) * t)];
+                let (left, right) = a.split_at_mut(t);
+
+                for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
+                    let (ox, oy) = forward_butterly_0_to_4q(*x, *y, w[i], w_shoup[i], q, q_twice);
+                    *x = ox;
+                    *y = oy;
+                }
+            }
+        }
+
+        m <<= 1;
+    }
 }
 
 /// Inverse number theoretic transform of input vector `a` with each element can
@@ -307,14 +376,13 @@ impl Ntt for NttBackendU64 {
     }
 
     fn forward(&self, v: &mut [Self::Element]) {
-        ntt_lazy(
+        ntt(
             v,
             &self.psi_powers_bo,
             &self.psi_powers_bo_shoup,
             self.q,
             self.q_twice,
         );
-        self.reduce_from_lazy(v);
     }
 
     fn backward_lazy(&self, v: &mut [Self::Element]) {