mirror of
https://github.com/arnaucube/phantom-zone.git
synced 2026-01-10 16:11:30 +01:00
add more ntt benches
This commit is contained in:
@@ -13,6 +13,15 @@ fn forward_lazy_matrix(a: &mut [Vec<u64>], nttop: &NttBackendU64) {
|
|||||||
.for_each(|r| nttop.forward_lazy(r.as_mut_slice()));
|
.for_each(|r| nttop.forward_lazy(r.as_mut_slice()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn backward_matrix(a: &mut [Vec<u64>], nttop: &NttBackendU64) {
|
||||||
|
a.iter_mut().for_each(|r| nttop.backward(r.as_mut_slice()));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn backward_lazy_matrix(a: &mut [Vec<u64>], nttop: &NttBackendU64) {
|
||||||
|
a.iter_mut()
|
||||||
|
.for_each(|r| nttop.backward_lazy(r.as_mut_slice()));
|
||||||
|
}
|
||||||
|
|
||||||
fn benchmark(c: &mut Criterion) {
|
fn benchmark(c: &mut Criterion) {
|
||||||
let mut group = c.benchmark_group("ntt");
|
let mut group = c.benchmark_group("ntt");
|
||||||
// 55
|
// 55
|
||||||
@@ -85,6 +94,17 @@ fn benchmark(c: &mut Criterion) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
group.bench_function(
|
||||||
|
BenchmarkId::new("backward", format!("q={prime}/N={ring_size}")),
|
||||||
|
|b| {
|
||||||
|
b.iter_batched_ref(
|
||||||
|
|| a.clone(),
|
||||||
|
|mut a| black_box(ntt.backward(&mut a)),
|
||||||
|
criterion::BatchSize::PerIteration,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
group.bench_function(
|
group.bench_function(
|
||||||
BenchmarkId::new("backward_lazy", format!("q={prime}/N={ring_size}")),
|
BenchmarkId::new("backward_lazy", format!("q={prime}/N={ring_size}")),
|
||||||
|b| {
|
|b| {
|
||||||
@@ -95,6 +115,31 @@ fn benchmark(c: &mut Criterion) {
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
group.bench_function(
|
||||||
|
BenchmarkId::new("backward_matrix", format!("q={prime}/N={ring_size}")),
|
||||||
|
|b| {
|
||||||
|
b.iter_batched_ref(
|
||||||
|
|| a_matrix.clone(),
|
||||||
|
|a_matrix| black_box(backward_matrix(a_matrix, &ntt)),
|
||||||
|
criterion::BatchSize::PerIteration,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
group.bench_function(
|
||||||
|
BenchmarkId::new(
|
||||||
|
"backward_lazy_matrix",
|
||||||
|
format!("q={prime}/N={ring_size}/d={d}"),
|
||||||
|
),
|
||||||
|
|b| {
|
||||||
|
b.iter_batched_ref(
|
||||||
|
|| a_matrix.clone(),
|
||||||
|
|a_matrix| black_box(backward_lazy_matrix(a_matrix, &ntt)),
|
||||||
|
criterion::BatchSize::PerIteration,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
120
src/ntt.rs
120
src/ntt.rs
@@ -4,7 +4,7 @@ use rand_chacha::{rand_core::le, ChaCha8Rng};
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
backend::{ArithmeticOps, ModInit, ModularOpsU64, Modulus},
|
backend::{ArithmeticOps, ModInit, ModularOpsU64, Modulus},
|
||||||
utils::{mod_exponent, mod_inverse, shoup_representation_fq},
|
utils::{mod_exponent, mod_inverse, ShoupMul},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub trait NttInit<M> {
|
pub trait NttInit<M> {
|
||||||
@@ -43,9 +43,7 @@ pub fn forward_butterly_0_to_4q(
|
|||||||
x = x - q_twice;
|
x = x - q_twice;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO (Jay): Hot path expected. How expensive is it?
|
let t = ShoupMul::mul(y, w, w_shoup, q);
|
||||||
let k = ((w_shoup as u128 * y as u128) >> 64) as u64;
|
|
||||||
let t = w.wrapping_mul(y).wrapping_sub(k.wrapping_mul(q));
|
|
||||||
|
|
||||||
(x + t, x + q_twice - t)
|
(x + t, x + q_twice - t)
|
||||||
}
|
}
|
||||||
@@ -65,8 +63,7 @@ pub fn forward_butterly_0_to_2q(
|
|||||||
x = x - q_twice;
|
x = x - q_twice;
|
||||||
}
|
}
|
||||||
|
|
||||||
let k = ((w_shoup as u128 * y as u128) >> 64) as u64;
|
let t = ShoupMul::mul(y, w, w_shoup, q);
|
||||||
let t = w.wrapping_mul(y).wrapping_sub(k.wrapping_mul(q));
|
|
||||||
|
|
||||||
let ox = x.wrapping_add(t);
|
let ox = x.wrapping_add(t);
|
||||||
let oy = x.wrapping_sub(t);
|
let oy = x.wrapping_sub(t);
|
||||||
@@ -84,7 +81,7 @@ pub fn forward_butterly_0_to_2q(
|
|||||||
/// and both x' and y' are \in [0, 2q)
|
/// and both x' and y' are \in [0, 2q)
|
||||||
///
|
///
|
||||||
/// Implements Algorithm 3 of [FASTER ARITHMETIC FOR NUMBER-THEORETIC TRANSFORMS](https://arxiv.org/pdf/1205.2926.pdf)
|
/// Implements Algorithm 3 of [FASTER ARITHMETIC FOR NUMBER-THEORETIC TRANSFORMS](https://arxiv.org/pdf/1205.2926.pdf)
|
||||||
pub fn inverse_butterfly(
|
pub fn inverse_butterfly_0_to_2q(
|
||||||
x: u64,
|
x: u64,
|
||||||
y: u64,
|
y: u64,
|
||||||
w_inv: u64,
|
w_inv: u64,
|
||||||
@@ -101,8 +98,7 @@ pub fn inverse_butterfly(
|
|||||||
}
|
}
|
||||||
|
|
||||||
let t = x + q_twice - y;
|
let t = x + q_twice - y;
|
||||||
let k = ((w_inv_shoup as u128 * t as u128) >> 64) as u64;
|
let y = ShoupMul::mul(t, w_inv, w_inv_shoup, q);
|
||||||
let y = w_inv.wrapping_mul(t).wrapping_sub(k.wrapping_mul(q));
|
|
||||||
|
|
||||||
(x_dash, y)
|
(x_dash, y)
|
||||||
}
|
}
|
||||||
@@ -202,68 +198,82 @@ pub fn ntt_inv_lazy(
|
|||||||
let mut t = 1;
|
let mut t = 1;
|
||||||
|
|
||||||
while m > 0 {
|
while m > 0 {
|
||||||
let w_inv = &psi_inv[m..];
|
|
||||||
let w_inv_shoup = &psi_inv_shoup[m..];
|
|
||||||
|
|
||||||
if m == 1 {
|
if m == 1 {
|
||||||
let (left, right) = a.split_at_mut(t);
|
let (left, right) = a.split_at_mut(t);
|
||||||
|
|
||||||
for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
|
for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
|
||||||
let (ox, oy) = inverse_butterfly(*x, *y, w_inv[0], w_inv_shoup[0], q, q_twice);
|
let (ox, oy) =
|
||||||
|
inverse_butterfly_0_to_2q(*x, *y, psi_inv[1], psi_inv_shoup[1], q, q_twice);
|
||||||
*x = (n_inv.wrapping_mul(ox)).wrapping_sub(
|
*x = ShoupMul::mul(ox, n_inv, n_inv_shoup, q);
|
||||||
q.wrapping_mul(((ox as u128 * n_inv_shoup as u128) >> 64) as u64),
|
*y = ShoupMul::mul(oy, n_inv, n_inv_shoup, q);
|
||||||
);
|
|
||||||
*y = (n_inv.wrapping_mul(oy)).wrapping_sub(
|
|
||||||
q.wrapping_mul(((oy as u128 * n_inv_shoup as u128) >> 64) as u64),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
let w_inv = &psi_inv[m..];
|
||||||
|
let w_inv_shoup = &psi_inv_shoup[m..];
|
||||||
for i in 0..m {
|
for i in 0..m {
|
||||||
let a = &mut a[2 * i * t..2 * (i + 1) * t];
|
let a = &mut a[2 * i * t..2 * (i + 1) * t];
|
||||||
let (left, right) = a.split_at_mut(t);
|
let (left, right) = a.split_at_mut(t);
|
||||||
|
|
||||||
for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
|
for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
|
||||||
let (ox, oy) = inverse_butterfly(*x, *y, w_inv[i], w_inv_shoup[i], q, q_twice);
|
let (ox, oy) =
|
||||||
|
inverse_butterfly_0_to_2q(*x, *y, w_inv[i], w_inv_shoup[i], q, q_twice);
|
||||||
*x = ox;
|
*x = ox;
|
||||||
*y = oy;
|
*y = oy;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// for i in 0..m {
|
|
||||||
// let a = &mut a[2 * i * t..2 * (i + 1) * t];
|
|
||||||
// let (left, right) = a.split_at_mut(t);
|
|
||||||
|
|
||||||
// for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
|
|
||||||
// let (ox, oy) = inverse_butterfly(*x, *y, w_inv[i], w_inv_shoup[i], q,
|
|
||||||
// q_twice); *x = ox;
|
|
||||||
// *y = oy;
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// for i in 0..h {
|
|
||||||
// let j_2 = j_1 + t;
|
|
||||||
// unsafe {
|
|
||||||
// let w_inv = psi_inv.get_unchecked(h + i);
|
|
||||||
// let w_inv_shoup = psi_inv_shoup.get_unchecked(h + i);
|
|
||||||
|
|
||||||
// for j in j_1..j_2 {
|
|
||||||
// let x = a.get_unchecked_mut(j) as *mut u64;
|
|
||||||
// let y = a.get_unchecked_mut(j + t) as *mut u64;
|
|
||||||
// inverse_butterfly(x, y, w_inv, w_inv_shoup, &q, &q_twice);
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// j_1 = j_1 + 2 * t;
|
|
||||||
// }
|
|
||||||
t *= 2;
|
t *= 2;
|
||||||
m >>= 1;
|
m >>= 1;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// a.iter_mut().for_each(|a0| {
|
/// Same as `ntt_inv_lazy` with output in range [0, q)
|
||||||
// *a0 = (n_inv.wrapping_mul(*a0))
|
pub fn ntt_inv(
|
||||||
// .wrapping_sub(((*a0 as u128 * n_inv_shoup as u128) >> 64) as u64)
|
a: &mut [u64],
|
||||||
// });
|
psi_inv: &[u64],
|
||||||
|
psi_inv_shoup: &[u64],
|
||||||
|
n_inv: u64,
|
||||||
|
n_inv_shoup: u64,
|
||||||
|
q: u64,
|
||||||
|
q_twice: u64,
|
||||||
|
) {
|
||||||
|
assert!(a.len() == psi_inv.len());
|
||||||
|
|
||||||
|
let mut m = a.len() >> 1;
|
||||||
|
let mut t = 1;
|
||||||
|
|
||||||
|
while m > 0 {
|
||||||
|
if m == 1 {
|
||||||
|
let (left, right) = a.split_at_mut(t);
|
||||||
|
|
||||||
|
for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
|
||||||
|
let (ox, oy) =
|
||||||
|
inverse_butterfly_0_to_2q(*x, *y, psi_inv[1], psi_inv_shoup[1], q, q_twice);
|
||||||
|
let ox = ShoupMul::mul(ox, n_inv, n_inv_shoup, q);
|
||||||
|
let oy = ShoupMul::mul(oy, n_inv, n_inv_shoup, q);
|
||||||
|
*x = ox.min(ox.wrapping_sub(q));
|
||||||
|
*y = oy.min(oy.wrapping_sub(q));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let w_inv = &psi_inv[m..];
|
||||||
|
let w_inv_shoup = &psi_inv_shoup[m..];
|
||||||
|
for i in 0..m {
|
||||||
|
let a = &mut a[2 * i * t..2 * (i + 1) * t];
|
||||||
|
let (left, right) = a.split_at_mut(t);
|
||||||
|
|
||||||
|
for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
|
||||||
|
let (ox, oy) =
|
||||||
|
inverse_butterfly_0_to_2q(*x, *y, w_inv[i], w_inv_shoup[i], q, q_twice);
|
||||||
|
*x = ox;
|
||||||
|
*y = oy;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t *= 2;
|
||||||
|
m >>= 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Find n^{th} root of unity in field F_q, if one exists
|
/// Find n^{th} root of unity in field F_q, if one exists
|
||||||
@@ -350,11 +360,11 @@ impl NttBackendU64 {
|
|||||||
// shoup representation
|
// shoup representation
|
||||||
let psi_powers_bo_shoup = psi_powers_bo
|
let psi_powers_bo_shoup = psi_powers_bo
|
||||||
.iter()
|
.iter()
|
||||||
.map(|v| shoup_representation_fq(*v, q))
|
.map(|v| ShoupMul::representation(*v, q))
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
let psi_inv_powers_bo_shoup = psi_inv_powers_bo
|
let psi_inv_powers_bo_shoup = psi_inv_powers_bo
|
||||||
.iter()
|
.iter()
|
||||||
.map(|v| shoup_representation_fq(*v, q))
|
.map(|v| ShoupMul::representation(*v, q))
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
|
|
||||||
// n^{-1} \mod{q}
|
// n^{-1} \mod{q}
|
||||||
@@ -365,7 +375,7 @@ impl NttBackendU64 {
|
|||||||
q_twice: 2 * q,
|
q_twice: 2 * q,
|
||||||
n: n as u64,
|
n: n as u64,
|
||||||
n_inv,
|
n_inv,
|
||||||
n_inv_shoup: shoup_representation_fq(n_inv, q),
|
n_inv_shoup: ShoupMul::representation(n_inv, q),
|
||||||
psi_powers_bo: psi_powers_bo.into_boxed_slice(),
|
psi_powers_bo: psi_powers_bo.into_boxed_slice(),
|
||||||
psi_inv_powers_bo: psi_inv_powers_bo.into_boxed_slice(),
|
psi_inv_powers_bo: psi_inv_powers_bo.into_boxed_slice(),
|
||||||
psi_powers_bo_shoup: psi_powers_bo_shoup.into_boxed_slice(),
|
psi_powers_bo_shoup: psi_powers_bo_shoup.into_boxed_slice(),
|
||||||
@@ -429,7 +439,7 @@ impl Ntt for NttBackendU64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn backward(&self, v: &mut [Self::Element]) {
|
fn backward(&self, v: &mut [Self::Element]) {
|
||||||
ntt_inv_lazy(
|
ntt_inv(
|
||||||
v,
|
v,
|
||||||
&self.psi_inv_powers_bo,
|
&self.psi_inv_powers_bo,
|
||||||
&self.psi_inv_powers_bo_shoup,
|
&self.psi_inv_powers_bo_shoup,
|
||||||
@@ -438,10 +448,10 @@ impl Ntt for NttBackendU64 {
|
|||||||
self.q,
|
self.q,
|
||||||
self.q_twice,
|
self.q_twice,
|
||||||
);
|
);
|
||||||
self.reduce_from_lazy(v);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use rand::{thread_rng, Rng};
|
use rand::{thread_rng, Rng};
|
||||||
|
|||||||
@@ -342,7 +342,7 @@ fn blind_rotation<
|
|||||||
mod_op,
|
mod_op,
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
println!("Auto count: {count}");
|
// println!("Auto count: {count}");
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mod_switch_odd(v: f64, from_q: f64, to_q: f64) -> usize {
|
fn mod_switch_odd(v: f64, from_q: f64, to_q: f64) -> usize {
|
||||||
|
|||||||
@@ -306,7 +306,7 @@ mod tests {
|
|||||||
bool::{
|
bool::{
|
||||||
aggregate_public_key_shares, aggregate_server_key_shares, gen_client_key, gen_keys,
|
aggregate_public_key_shares, aggregate_server_key_shares, gen_client_key, gen_keys,
|
||||||
gen_mp_keys_phase1, gen_mp_keys_phase2,
|
gen_mp_keys_phase1, gen_mp_keys_phase2,
|
||||||
parameters::{MP_BOOL_PARAMS, SP_BOOL_PARAMS},
|
parameters::{MP_BOOL_PARAMS, SMALL_MP_BOOL_PARAMS, SP_BOOL_PARAMS},
|
||||||
set_mp_seed, set_parameter_set,
|
set_mp_seed, set_parameter_set,
|
||||||
},
|
},
|
||||||
shortint::types::FheUint8,
|
shortint::types::FheUint8,
|
||||||
@@ -463,7 +463,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn fheuint8_test_multi_party() {
|
fn fheuint8_test_multi_party() {
|
||||||
set_parameter_set(&MP_BOOL_PARAMS);
|
set_parameter_set(&SMALL_MP_BOOL_PARAMS);
|
||||||
set_mp_seed([0; 32]);
|
set_mp_seed([0; 32]);
|
||||||
|
|
||||||
let parties = 8;
|
let parties = 8;
|
||||||
@@ -497,10 +497,12 @@ mod tests {
|
|||||||
let ct_b = public_key.encrypt(&b);
|
let ct_b = public_key.encrypt(&b);
|
||||||
let ct_c = public_key.encrypt(&c);
|
let ct_c = public_key.encrypt(&c);
|
||||||
|
|
||||||
|
let now = std::time::Instant::now();
|
||||||
// server computes
|
// server computes
|
||||||
// a*b + c
|
// a*b + c
|
||||||
let mut ct_ab = &ct_a * &ct_b;
|
let mut ct_ab = &ct_a * &ct_b;
|
||||||
ct_ab += &ct_c;
|
ct_ab += &ct_c;
|
||||||
|
println!("Circuit time: {:?}", now.elapsed());
|
||||||
|
|
||||||
// decrypt ab and check
|
// decrypt ab and check
|
||||||
// generate decryption shares
|
// generate decryption shares
|
||||||
|
|||||||
23
src/utils.rs
23
src/utils.rs
@@ -25,6 +25,25 @@ pub trait Global {
|
|||||||
fn global() -> &'static Self;
|
fn global() -> &'static Self;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait ShoupMul {
|
||||||
|
fn representation(value: Self, q: Self) -> Self;
|
||||||
|
fn mul(a: Self, b: Self, b_shoup: Self, q: Self) -> Self;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ShoupMul for u64 {
|
||||||
|
#[inline]
|
||||||
|
fn representation(value: Self, q: Self) -> Self {
|
||||||
|
((value as u128 * (1u128 << 64)) / q as u128) as u64
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
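/// Returns a value congruent to a * b mod q in the lazy range [0, 2q), not fully reduced to [0, q) (assumes b < q and b_shoup == representation(b, q))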
|
||||||
|
fn mul(a: Self, b: Self, b_shoup: Self, q: Self) -> Self {
|
||||||
|
(b.wrapping_mul(a))
|
||||||
|
.wrapping_sub(q.wrapping_mul(((b_shoup as u128 * a as u128) >> 64) as u64))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn fill_random_ternary_secret_with_hamming_weight<
|
pub fn fill_random_ternary_secret_with_hamming_weight<
|
||||||
T: Signed,
|
T: Signed,
|
||||||
R: RandomFill<[u8]> + RandomElementInModulus<usize, usize>,
|
R: RandomFill<[u8]> + RandomElementInModulus<usize, usize>,
|
||||||
@@ -121,10 +140,6 @@ pub fn mod_inverse(a: u64, q: u64) -> u64 {
|
|||||||
mod_exponent(a, q - 2, q)
|
mod_exponent(a, q - 2, q)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn shoup_representation_fq(v: u64, q: u64) -> u64 {
|
|
||||||
((v as u128 * (1u128 << 64)) / q as u128) as u64
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn negacyclic_mul<T: PrimInt, F: Fn(&T, &T) -> T>(
|
pub fn negacyclic_mul<T: PrimInt, F: Fn(&T, &T) -> T>(
|
||||||
a: &[T],
|
a: &[T],
|
||||||
b: &[T],
|
b: &[T],
|
||||||
|
|||||||
Reference in New Issue
Block a user