From e9bdaaea696d70fa16fcea61d149aeee982f4ab4 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Mon, 10 Jun 2024 13:08:51 +0530 Subject: [PATCH] speed up backward ntt --- benches/ntt.rs | 101 ++++++++++++++++++++++++------------------- src/ntt.rs | 114 ++++++++++++++++++++++++++++++++++--------------- 2 files changed, 138 insertions(+), 77 deletions(-) diff --git a/benches/ntt.rs b/benches/ntt.rs index 13cacfe..d860749 100644 --- a/benches/ntt.rs +++ b/benches/ntt.rs @@ -35,52 +35,67 @@ fn benchmark(c: &mut Criterion) { }) .collect_vec(); - group.bench_function( - BenchmarkId::new("forward", format!("q={prime}/N={ring_size}")), - |b| { - b.iter_batched_ref( - || a.clone(), - |mut a| black_box(ntt.forward(&mut a)), - criterion::BatchSize::PerIteration, - ) - }, - ); + { + group.bench_function( + BenchmarkId::new("forward", format!("q={prime}/N={ring_size}")), + |b| { + b.iter_batched_ref( + || a.clone(), + |mut a| black_box(ntt.forward(&mut a)), + criterion::BatchSize::PerIteration, + ) + }, + ); - group.bench_function( - BenchmarkId::new("forward_lazy", format!("q={prime}/N={ring_size}")), - |b| { - b.iter_batched_ref( - || a.clone(), - |mut a| black_box(ntt.forward_lazy(&mut a)), - criterion::BatchSize::PerIteration, - ) - }, - ); + group.bench_function( + BenchmarkId::new("forward_lazy", format!("q={prime}/N={ring_size}")), + |b| { + b.iter_batched_ref( + || a.clone(), + |mut a| black_box(ntt.forward_lazy(&mut a)), + criterion::BatchSize::PerIteration, + ) + }, + ); - group.bench_function( - BenchmarkId::new("forward_matrix", format!("q={prime}/N={ring_size}/d={d}")), - |b| { - b.iter_batched_ref( - || a_matrix.clone(), - |a_matrix| black_box(forward_matrix(a_matrix, &ntt)), - criterion::BatchSize::PerIteration, - ) - }, - ); + group.bench_function( + BenchmarkId::new("forward_matrix", format!("q={prime}/N={ring_size}/d={d}")), + |b| { + b.iter_batched_ref( + || a_matrix.clone(), + |a_matrix| black_box(forward_matrix(a_matrix, &ntt)), + criterion::BatchSize::PerIteration, + ) + }, + ); - group.bench_function( - BenchmarkId::new( - "forward_lazy_matrix", - format!("q={prime}/N={ring_size}/d={d}"), - ), - |b| { - b.iter_batched_ref( - || a_matrix.clone(), - |a_matrix| black_box(forward_lazy_matrix(a_matrix, &ntt)), - criterion::BatchSize::PerIteration, - ) - }, - ); + group.bench_function( + BenchmarkId::new( + "forward_lazy_matrix", + format!("q={prime}/N={ring_size}/d={d}"), + ), + |b| { + b.iter_batched_ref( + || a_matrix.clone(), + |a_matrix| black_box(forward_lazy_matrix(a_matrix, &ntt)), + criterion::BatchSize::PerIteration, + ) + }, + ); + } + + { + group.bench_function( + BenchmarkId::new("backward_lazy", format!("q={prime}/N={ring_size}")), + |b| { + b.iter_batched_ref( + || a.clone(), + |mut a| black_box(ntt.backward_lazy(&mut a)), + criterion::BatchSize::PerIteration, + ) + }, + ); + } } } diff --git a/src/ntt.rs b/src/ntt.rs index 5f31aff..649af8b 100644 --- a/src/ntt.rs +++ b/src/ntt.rs @@ -84,27 +84,27 @@ pub fn forward_butterly_0_to_2q( /// and both x' and y' are \in [0, 2q) /// /// Implements Algorithm 3 of [FASTER ARITHMETIC FOR NUMBER-THEORETIC TRANSFORMS](https://arxiv.org/pdf/1205.2926.pdf) -pub unsafe fn inverse_butterfly( - x: *mut u64, - y: *mut u64, - w_inv: &u64, - w_inv_shoup: &u64, - q: &u64, - q_twice: &u64, -) { - debug_assert!(*x < *q_twice, "{} >= (2q){q_twice}", *x); - debug_assert!(*y < *q_twice, "{} >= (2q){q_twice}", *y); +pub fn inverse_butterfly( + x: u64, + y: u64, + w_inv: u64, + w_inv_shoup: u64, + q: u64, + q_twice: u64, +) -> (u64, u64) { + debug_assert!(x < q_twice, "{} >= (2q){q_twice}", x); + debug_assert!(y < q_twice, "{} >= (2q){q_twice}", y); - let mut x_dash = *x + *y; - if x_dash >= *q_twice { + let mut x_dash = x + y; + if x_dash >= q_twice { x_dash -= q_twice } - let t = *x + q_twice - *y; - let k = ((*w_inv_shoup as u128 * t as u128) >> 64) as u64; // TODO (Jay): Hot path - *y = w_inv.wrapping_mul(t).wrapping_sub(k.wrapping_mul(*q)); + let t = x + q_twice - y; + let k = ((w_inv_shoup as u128 * t as u128) >> 64) as u64; + let y = w_inv.wrapping_mul(t).wrapping_sub(k.wrapping_mul(q)); - *x = x_dash; + (x_dash, y) } /// Number theoretic transform of vector `a` where each element can be in range @@ -192,36 +192,78 @@ pub fn ntt_inv_lazy( psi_inv: &[u64], psi_inv_shoup: &[u64], n_inv: u64, + n_inv_shoup: u64, q: u64, q_twice: u64, ) { - debug_assert!(a.len() == psi_inv.len()); + assert!(a.len() == psi_inv.len()); - let mut m = a.len(); + let mut m = a.len() >> 1; let mut t = 1; - while m > 1 { - let mut j_1: usize = 0; - let h = m >> 1; - for i in 0..h { - let j_2 = j_1 + t; - unsafe { - let w_inv = psi_inv.get_unchecked(h + i); - let w_inv_shoup = psi_inv_shoup.get_unchecked(h + i); - - for j in j_1..j_2 { - let x = a.get_unchecked_mut(j) as *mut u64; - let y = a.get_unchecked_mut(j + t) as *mut u64; - inverse_butterfly(x, y, w_inv, w_inv_shoup, &q, &q_twice); + + while m > 0 { + let w_inv = &psi_inv[m..]; + let w_inv_shoup = &psi_inv_shoup[m..]; + + if m == 1 { + let (left, right) = a.split_at_mut(t); + + for (x, y) in izip!(left.iter_mut(), right.iter_mut()) { + let (ox, oy) = inverse_butterfly(*x, *y, w_inv[0], w_inv_shoup[0], q, q_twice); + + *x = (n_inv.wrapping_mul(ox)).wrapping_sub( + q.wrapping_mul(((ox as u128 * n_inv_shoup as u128) >> 64) as u64), + ); + *y = (n_inv.wrapping_mul(oy)).wrapping_sub( + q.wrapping_mul(((oy as u128 * n_inv_shoup as u128) >> 64) as u64), + ); + } + } else { + for i in 0..m { + let a = &mut a[2 * i * t..2 * (i + 1) * t]; + let (left, right) = a.split_at_mut(t); + + for (x, y) in izip!(left.iter_mut(), right.iter_mut()) { + let (ox, oy) = inverse_butterfly(*x, *y, w_inv[i], w_inv_shoup[i], q, q_twice); + *x = ox; + *y = oy; } } - j_1 = j_1 + 2 * t; } + + // for i in 0..m { + // let a = &mut a[2 * i * t..2 * (i + 1) * t]; + // let (left, right) = a.split_at_mut(t); + + // for (x, y) in izip!(left.iter_mut(), right.iter_mut()) { + // let (ox, oy) = inverse_butterfly(*x, *y, w_inv[i], w_inv_shoup[i], q, + // q_twice); *x = ox; + // *y = oy; + // } + // } + + // for i in 0..h { + // let j_2 = j_1 + t; + // unsafe { + // let w_inv = psi_inv.get_unchecked(h + i); + // let w_inv_shoup = psi_inv_shoup.get_unchecked(h + i); + + // for j in j_1..j_2 { + // let x = a.get_unchecked_mut(j) as *mut u64; + // let y = a.get_unchecked_mut(j + t) as *mut u64; + // inverse_butterfly(x, y, w_inv, w_inv_shoup, &q, &q_twice); + // } + // } + // j_1 = j_1 + 2 * t; + // } t *= 2; m >>= 1; } - a.iter_mut() - .for_each(|a0| *a0 = ((*a0 as u128 * n_inv as u128) % q as u128) as u64); + // a.iter_mut().for_each(|a0| { + // *a0 = (n_inv.wrapping_mul(*a0)) + // .wrapping_sub(((*a0 as u128 * n_inv_shoup as u128) >> 64) as u64) + // }); } /// Find n^{th} root of unity in field F_q, if one exists @@ -259,6 +301,7 @@ pub struct NttBackendU64 { q_twice: u64, n: u64, n_inv: u64, + n_inv_shoup: u64, psi_powers_bo: Box<[u64]>, psi_inv_powers_bo: Box<[u64]>, psi_powers_bo_shoup: Box<[u64]>, @@ -322,6 +365,7 @@ impl NttBackendU64 { q_twice: 2 * q, n: n as u64, n_inv, + n_inv_shoup: shoup_representation_fq(n_inv, q), psi_powers_bo: psi_powers_bo.into_boxed_slice(), psi_inv_powers_bo: psi_inv_powers_bo.into_boxed_slice(), psi_powers_bo_shoup: psi_powers_bo_shoup.into_boxed_slice(), @@ -378,6 +422,7 @@ impl Ntt for NttBackendU64 { &self.psi_inv_powers_bo, &self.psi_inv_powers_bo_shoup, self.n_inv, + self.n_inv_shoup, self.q, self.q_twice, ) @@ -389,6 +434,7 @@ impl Ntt for NttBackendU64 { &self.psi_inv_powers_bo, &self.psi_inv_powers_bo_shoup, self.n_inv, + self.n_inv_shoup, self.q, self.q_twice, );