mirror of
https://github.com/arnaucube/phantom-zone.git
synced 2026-01-07 22:51:29 +01:00
speed up backward ntt
This commit is contained in:
101
benches/ntt.rs
101
benches/ntt.rs
@@ -35,52 +35,67 @@ fn benchmark(c: &mut Criterion) {
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("forward", format!("q={prime}/N={ring_size}")),
|
||||
|b| {
|
||||
b.iter_batched_ref(
|
||||
|| a.clone(),
|
||||
|mut a| black_box(ntt.forward(&mut a)),
|
||||
criterion::BatchSize::PerIteration,
|
||||
)
|
||||
},
|
||||
);
|
||||
{
|
||||
group.bench_function(
|
||||
BenchmarkId::new("forward", format!("q={prime}/N={ring_size}")),
|
||||
|b| {
|
||||
b.iter_batched_ref(
|
||||
|| a.clone(),
|
||||
|mut a| black_box(ntt.forward(&mut a)),
|
||||
criterion::BatchSize::PerIteration,
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("forward_lazy", format!("q={prime}/N={ring_size}")),
|
||||
|b| {
|
||||
b.iter_batched_ref(
|
||||
|| a.clone(),
|
||||
|mut a| black_box(ntt.forward_lazy(&mut a)),
|
||||
criterion::BatchSize::PerIteration,
|
||||
)
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("forward_lazy", format!("q={prime}/N={ring_size}")),
|
||||
|b| {
|
||||
b.iter_batched_ref(
|
||||
|| a.clone(),
|
||||
|mut a| black_box(ntt.forward_lazy(&mut a)),
|
||||
criterion::BatchSize::PerIteration,
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("forward_matrix", format!("q={prime}/N={ring_size}/d={d}")),
|
||||
|b| {
|
||||
b.iter_batched_ref(
|
||||
|| a_matrix.clone(),
|
||||
|a_matrix| black_box(forward_matrix(a_matrix, &ntt)),
|
||||
criterion::BatchSize::PerIteration,
|
||||
)
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new("forward_matrix", format!("q={prime}/N={ring_size}/d={d}")),
|
||||
|b| {
|
||||
b.iter_batched_ref(
|
||||
|| a_matrix.clone(),
|
||||
|a_matrix| black_box(forward_matrix(a_matrix, &ntt)),
|
||||
criterion::BatchSize::PerIteration,
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"forward_lazy_matrix",
|
||||
format!("q={prime}/N={ring_size}/d={d}"),
|
||||
),
|
||||
|b| {
|
||||
b.iter_batched_ref(
|
||||
|| a_matrix.clone(),
|
||||
|a_matrix| black_box(forward_lazy_matrix(a_matrix, &ntt)),
|
||||
criterion::BatchSize::PerIteration,
|
||||
)
|
||||
},
|
||||
);
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"forward_lazy_matrix",
|
||||
format!("q={prime}/N={ring_size}/d={d}"),
|
||||
),
|
||||
|b| {
|
||||
b.iter_batched_ref(
|
||||
|| a_matrix.clone(),
|
||||
|a_matrix| black_box(forward_lazy_matrix(a_matrix, &ntt)),
|
||||
criterion::BatchSize::PerIteration,
|
||||
)
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
group.bench_function(
|
||||
BenchmarkId::new("backward_lazy", format!("q={prime}/N={ring_size}")),
|
||||
|b| {
|
||||
b.iter_batched_ref(
|
||||
|| a.clone(),
|
||||
|mut a| black_box(ntt.backward_lazy(&mut a)),
|
||||
criterion::BatchSize::PerIteration,
|
||||
)
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
112
src/ntt.rs
112
src/ntt.rs
@@ -84,27 +84,27 @@ pub fn forward_butterly_0_to_2q(
|
||||
/// and both x' and y' are \in [0, 2q)
|
||||
///
|
||||
/// Implements Algorithm 3 of [FASTER ARITHMETIC FOR NUMBER-THEORETIC TRANSFORMS](https://arxiv.org/pdf/1205.2926.pdf)
|
||||
pub unsafe fn inverse_butterfly(
|
||||
x: *mut u64,
|
||||
y: *mut u64,
|
||||
w_inv: &u64,
|
||||
w_inv_shoup: &u64,
|
||||
q: &u64,
|
||||
q_twice: &u64,
|
||||
) {
|
||||
debug_assert!(*x < *q_twice, "{} >= (2q){q_twice}", *x);
|
||||
debug_assert!(*y < *q_twice, "{} >= (2q){q_twice}", *y);
|
||||
pub fn inverse_butterfly(
|
||||
x: u64,
|
||||
y: u64,
|
||||
w_inv: u64,
|
||||
w_inv_shoup: u64,
|
||||
q: u64,
|
||||
q_twice: u64,
|
||||
) -> (u64, u64) {
|
||||
debug_assert!(x < q_twice, "{} >= (2q){q_twice}", x);
|
||||
debug_assert!(y < q_twice, "{} >= (2q){q_twice}", y);
|
||||
|
||||
let mut x_dash = *x + *y;
|
||||
if x_dash >= *q_twice {
|
||||
let mut x_dash = x + y;
|
||||
if x_dash >= q_twice {
|
||||
x_dash -= q_twice
|
||||
}
|
||||
|
||||
let t = *x + q_twice - *y;
|
||||
let k = ((*w_inv_shoup as u128 * t as u128) >> 64) as u64; // TODO (Jay): Hot path
|
||||
*y = w_inv.wrapping_mul(t).wrapping_sub(k.wrapping_mul(*q));
|
||||
let t = x + q_twice - y;
|
||||
let k = ((w_inv_shoup as u128 * t as u128) >> 64) as u64;
|
||||
let y = w_inv.wrapping_mul(t).wrapping_sub(k.wrapping_mul(q));
|
||||
|
||||
*x = x_dash;
|
||||
(x_dash, y)
|
||||
}
|
||||
|
||||
/// Number theoretic transform of vector `a` where each element can be in range
|
||||
@@ -192,36 +192,78 @@ pub fn ntt_inv_lazy(
|
||||
psi_inv: &[u64],
|
||||
psi_inv_shoup: &[u64],
|
||||
n_inv: u64,
|
||||
n_inv_shoup: u64,
|
||||
q: u64,
|
||||
q_twice: u64,
|
||||
) {
|
||||
debug_assert!(a.len() == psi_inv.len());
|
||||
assert!(a.len() == psi_inv.len());
|
||||
|
||||
let mut m = a.len();
|
||||
let mut m = a.len() >> 1;
|
||||
let mut t = 1;
|
||||
while m > 1 {
|
||||
let mut j_1: usize = 0;
|
||||
let h = m >> 1;
|
||||
for i in 0..h {
|
||||
let j_2 = j_1 + t;
|
||||
unsafe {
|
||||
let w_inv = psi_inv.get_unchecked(h + i);
|
||||
let w_inv_shoup = psi_inv_shoup.get_unchecked(h + i);
|
||||
|
||||
for j in j_1..j_2 {
|
||||
let x = a.get_unchecked_mut(j) as *mut u64;
|
||||
let y = a.get_unchecked_mut(j + t) as *mut u64;
|
||||
inverse_butterfly(x, y, w_inv, w_inv_shoup, &q, &q_twice);
|
||||
while m > 0 {
|
||||
let w_inv = &psi_inv[m..];
|
||||
let w_inv_shoup = &psi_inv_shoup[m..];
|
||||
|
||||
if m == 1 {
|
||||
let (left, right) = a.split_at_mut(t);
|
||||
|
||||
for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
|
||||
let (ox, oy) = inverse_butterfly(*x, *y, w_inv[0], w_inv_shoup[0], q, q_twice);
|
||||
|
||||
*x = (n_inv.wrapping_mul(ox)).wrapping_sub(
|
||||
q.wrapping_mul(((ox as u128 * n_inv_shoup as u128) >> 64) as u64),
|
||||
);
|
||||
*y = (n_inv.wrapping_mul(oy)).wrapping_sub(
|
||||
q.wrapping_mul(((oy as u128 * n_inv_shoup as u128) >> 64) as u64),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
for i in 0..m {
|
||||
let a = &mut a[2 * i * t..2 * (i + 1) * t];
|
||||
let (left, right) = a.split_at_mut(t);
|
||||
|
||||
for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
|
||||
let (ox, oy) = inverse_butterfly(*x, *y, w_inv[i], w_inv_shoup[i], q, q_twice);
|
||||
*x = ox;
|
||||
*y = oy;
|
||||
}
|
||||
}
|
||||
j_1 = j_1 + 2 * t;
|
||||
}
|
||||
|
||||
// for i in 0..m {
|
||||
// let a = &mut a[2 * i * t..2 * (i + 1) * t];
|
||||
// let (left, right) = a.split_at_mut(t);
|
||||
|
||||
// for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
|
||||
// let (ox, oy) = inverse_butterfly(*x, *y, w_inv[i], w_inv_shoup[i], q,
|
||||
// q_twice); *x = ox;
|
||||
// *y = oy;
|
||||
// }
|
||||
// }
|
||||
|
||||
// for i in 0..h {
|
||||
// let j_2 = j_1 + t;
|
||||
// unsafe {
|
||||
// let w_inv = psi_inv.get_unchecked(h + i);
|
||||
// let w_inv_shoup = psi_inv_shoup.get_unchecked(h + i);
|
||||
|
||||
// for j in j_1..j_2 {
|
||||
// let x = a.get_unchecked_mut(j) as *mut u64;
|
||||
// let y = a.get_unchecked_mut(j + t) as *mut u64;
|
||||
// inverse_butterfly(x, y, w_inv, w_inv_shoup, &q, &q_twice);
|
||||
// }
|
||||
// }
|
||||
// j_1 = j_1 + 2 * t;
|
||||
// }
|
||||
t *= 2;
|
||||
m >>= 1;
|
||||
}
|
||||
|
||||
a.iter_mut()
|
||||
.for_each(|a0| *a0 = ((*a0 as u128 * n_inv as u128) % q as u128) as u64);
|
||||
// a.iter_mut().for_each(|a0| {
|
||||
// *a0 = (n_inv.wrapping_mul(*a0))
|
||||
// .wrapping_sub(((*a0 as u128 * n_inv_shoup as u128) >> 64) as u64)
|
||||
// });
|
||||
}
|
||||
|
||||
/// Find n^{th} root of unity in field F_q, if one exists
|
||||
@@ -259,6 +301,7 @@ pub struct NttBackendU64 {
|
||||
q_twice: u64,
|
||||
n: u64,
|
||||
n_inv: u64,
|
||||
n_inv_shoup: u64,
|
||||
psi_powers_bo: Box<[u64]>,
|
||||
psi_inv_powers_bo: Box<[u64]>,
|
||||
psi_powers_bo_shoup: Box<[u64]>,
|
||||
@@ -322,6 +365,7 @@ impl NttBackendU64 {
|
||||
q_twice: 2 * q,
|
||||
n: n as u64,
|
||||
n_inv,
|
||||
n_inv_shoup: shoup_representation_fq(n_inv, q),
|
||||
psi_powers_bo: psi_powers_bo.into_boxed_slice(),
|
||||
psi_inv_powers_bo: psi_inv_powers_bo.into_boxed_slice(),
|
||||
psi_powers_bo_shoup: psi_powers_bo_shoup.into_boxed_slice(),
|
||||
@@ -378,6 +422,7 @@ impl Ntt for NttBackendU64 {
|
||||
&self.psi_inv_powers_bo,
|
||||
&self.psi_inv_powers_bo_shoup,
|
||||
self.n_inv,
|
||||
self.n_inv_shoup,
|
||||
self.q,
|
||||
self.q_twice,
|
||||
)
|
||||
@@ -389,6 +434,7 @@ impl Ntt for NttBackendU64 {
|
||||
&self.psi_inv_powers_bo,
|
||||
&self.psi_inv_powers_bo_shoup,
|
||||
self.n_inv,
|
||||
self.n_inv_shoup,
|
||||
self.q,
|
||||
self.q_twice,
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user