From 0f496a1032572e60f35e190f6af77a67ab988b85 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Mon, 10 Jun 2024 13:59:05 +0530 Subject: [PATCH] add more ntt benches --- benches/ntt.rs | 45 +++++++++++++++++ src/ntt.rs | 120 ++++++++++++++++++++++++-------------------- src/pbs.rs | 2 +- src/shortint/mod.rs | 6 ++- src/utils.rs | 23 +++++++-- 5 files changed, 134 insertions(+), 62 deletions(-) diff --git a/benches/ntt.rs b/benches/ntt.rs index d860749..fe0931e 100644 --- a/benches/ntt.rs +++ b/benches/ntt.rs @@ -13,6 +13,15 @@ fn forward_lazy_matrix(a: &mut [Vec], nttop: &NttBackendU64) { .for_each(|r| nttop.forward_lazy(r.as_mut_slice())); } +fn backward_matrix(a: &mut [Vec], nttop: &NttBackendU64) { + a.iter_mut().for_each(|r| nttop.backward(r.as_mut_slice())); +} + +fn backward_lazy_matrix(a: &mut [Vec], nttop: &NttBackendU64) { + a.iter_mut() + .for_each(|r| nttop.backward_lazy(r.as_mut_slice())); +} + fn benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("ntt"); // 55 @@ -85,6 +94,17 @@ fn benchmark(c: &mut Criterion) { } { + group.bench_function( + BenchmarkId::new("backward", format!("q={prime}/N={ring_size}")), + |b| { + b.iter_batched_ref( + || a.clone(), + |mut a| black_box(ntt.backward(&mut a)), + criterion::BatchSize::PerIteration, + ) + }, + ); + group.bench_function( BenchmarkId::new("backward_lazy", format!("q={prime}/N={ring_size}")), |b| { @@ -95,6 +115,31 @@ fn benchmark(c: &mut Criterion) { ) }, ); + + group.bench_function( + BenchmarkId::new("backward_matrix", format!("q={prime}/N={ring_size}")), + |b| { + b.iter_batched_ref( + || a_matrix.clone(), + |a_matrix| black_box(backward_matrix(a_matrix, &ntt)), + criterion::BatchSize::PerIteration, + ) + }, + ); + + group.bench_function( + BenchmarkId::new( + "backward_lazy_matrix", + format!("q={prime}/N={ring_size}/d={d}"), + ), + |b| { + b.iter_batched_ref( + || a_matrix.clone(), + |a_matrix| black_box(backward_lazy_matrix(a_matrix, &ntt)), + criterion::BatchSize::PerIteration, + ) + }, + ); } } } diff --git a/src/ntt.rs b/src/ntt.rs index 649af8b..a2e0c40 100644 --- a/src/ntt.rs +++ b/src/ntt.rs @@ -4,7 +4,7 @@ use rand_chacha::{rand_core::le, ChaCha8Rng}; use crate::{ backend::{ArithmeticOps, ModInit, ModularOpsU64, Modulus}, - utils::{mod_exponent, mod_inverse, shoup_representation_fq}, + utils::{mod_exponent, mod_inverse, ShoupMul}, }; pub trait NttInit { @@ -43,9 +43,7 @@ pub fn forward_butterly_0_to_4q( x = x - q_twice; } - // TODO (Jay): Hot path expected. How expensive is it? - let k = ((w_shoup as u128 * y as u128) >> 64) as u64; - let t = w.wrapping_mul(y).wrapping_sub(k.wrapping_mul(q)); + let t = ShoupMul::mul(y, w, w_shoup, q); (x + t, x + q_twice - t) } @@ -65,8 +63,7 @@ pub fn forward_butterly_0_to_2q( x = x - q_twice; } - let k = ((w_shoup as u128 * y as u128) >> 64) as u64; - let t = w.wrapping_mul(y).wrapping_sub(k.wrapping_mul(q)); + let t = ShoupMul::mul(y, w, w_shoup, q); let ox = x.wrapping_add(t); let oy = x.wrapping_sub(t); @@ -84,7 +81,7 @@ pub fn forward_butterly_0_to_2q( /// and both x' and y' are \in [0, 2q) /// /// Implements Algorithm 3 of [FASTER ARITHMETIC FOR NUMBER-THEORETIC TRANSFORMS](https://arxiv.org/pdf/1205.2926.pdf) -pub fn inverse_butterfly( +pub fn inverse_butterfly_0_to_2q( x: u64, y: u64, w_inv: u64, @@ -101,8 +98,7 @@ pub fn inverse_butterfly( } let t = x + q_twice - y; - let k = ((w_inv_shoup as u128 * t as u128) >> 64) as u64; - let y = w_inv.wrapping_mul(t).wrapping_sub(k.wrapping_mul(q)); + let y = ShoupMul::mul(t, w_inv, w_inv_shoup, q); (x_dash, y) } @@ -202,68 +198,82 @@ pub fn ntt_inv_lazy( let mut t = 1; while m > 0 { - let w_inv = &psi_inv[m..]; - let w_inv_shoup = &psi_inv_shoup[m..]; - if m == 1 { let (left, right) = a.split_at_mut(t); for (x, y) in izip!(left.iter_mut(), right.iter_mut()) { - let (ox, oy) = inverse_butterfly(*x, *y, w_inv[0], w_inv_shoup[0], q, q_twice); - - *x = (n_inv.wrapping_mul(ox)).wrapping_sub( - q.wrapping_mul(((ox as u128 * n_inv_shoup as u128) >> 64) as u64), - ); - *y = (n_inv.wrapping_mul(oy)).wrapping_sub( - q.wrapping_mul(((oy as u128 * n_inv_shoup as u128) >> 64) as u64), - ); + let (ox, oy) = + inverse_butterfly_0_to_2q(*x, *y, psi_inv[1], psi_inv_shoup[1], q, q_twice); + *x = ShoupMul::mul(ox, n_inv, n_inv_shoup, q); + *y = ShoupMul::mul(oy, n_inv, n_inv_shoup, q); } } else { + let w_inv = &psi_inv[m..]; + let w_inv_shoup = &psi_inv_shoup[m..]; for i in 0..m { let a = &mut a[2 * i * t..2 * (i + 1) * t]; let (left, right) = a.split_at_mut(t); for (x, y) in izip!(left.iter_mut(), right.iter_mut()) { - let (ox, oy) = inverse_butterfly(*x, *y, w_inv[i], w_inv_shoup[i], q, q_twice); + let (ox, oy) = + inverse_butterfly_0_to_2q(*x, *y, w_inv[i], w_inv_shoup[i], q, q_twice); *x = ox; *y = oy; } } } - // for i in 0..m { - // let a = &mut a[2 * i * t..2 * (i + 1) * t]; - // let (left, right) = a.split_at_mut(t); - - // for (x, y) in izip!(left.iter_mut(), right.iter_mut()) { - // let (ox, oy) = inverse_butterfly(*x, *y, w_inv[i], w_inv_shoup[i], q, - // q_twice); *x = ox; - // *y = oy; - // } - // } - - // for i in 0..h { - // let j_2 = j_1 + t; - // unsafe { - // let w_inv = psi_inv.get_unchecked(h + i); - // let w_inv_shoup = psi_inv_shoup.get_unchecked(h + i); - - // for j in j_1..j_2 { - // let x = a.get_unchecked_mut(j) as *mut u64; - // let y = a.get_unchecked_mut(j + t) as *mut u64; - // inverse_butterfly(x, y, w_inv, w_inv_shoup, &q, &q_twice); - // } - // } - // j_1 = j_1 + 2 * t; - // } t *= 2; m >>= 1; } +} + +/// Same as `ntt_inv_lazy` with output in range [0, q) +pub fn ntt_inv( + a: &mut [u64], + psi_inv: &[u64], + psi_inv_shoup: &[u64], + n_inv: u64, + n_inv_shoup: u64, + q: u64, + q_twice: u64, +) { + assert!(a.len() == psi_inv.len()); + + let mut m = a.len() >> 1; + let mut t = 1; + + while m > 0 { + if m == 1 { + let (left, right) = a.split_at_mut(t); - // a.iter_mut().for_each(|a0| { - // *a0 = (n_inv.wrapping_mul(*a0)) - // .wrapping_sub(((*a0 as u128 * n_inv_shoup as u128) >> 64) as u64) - // }); + for (x, y) in izip!(left.iter_mut(), right.iter_mut()) { + let (ox, oy) = + inverse_butterfly_0_to_2q(*x, *y, psi_inv[1], psi_inv_shoup[1], q, q_twice); + let ox = ShoupMul::mul(ox, n_inv, n_inv_shoup, q); + let oy = ShoupMul::mul(oy, n_inv, n_inv_shoup, q); + *x = ox.min(ox.wrapping_sub(q)); + *y = oy.min(oy.wrapping_sub(q)); + } + } else { + let w_inv = &psi_inv[m..]; + let w_inv_shoup = &psi_inv_shoup[m..]; + for i in 0..m { + let a = &mut a[2 * i * t..2 * (i + 1) * t]; + let (left, right) = a.split_at_mut(t); + + for (x, y) in izip!(left.iter_mut(), right.iter_mut()) { + let (ox, oy) = + inverse_butterfly_0_to_2q(*x, *y, w_inv[i], w_inv_shoup[i], q, q_twice); + *x = ox; + *y = oy; + } + } + } + + t *= 2; + m >>= 1; + } } /// Find n^{th} root of unity in field F_q, if one exists @@ -350,11 +360,11 @@ impl NttBackendU64 { // shoup representation let psi_powers_bo_shoup = psi_powers_bo .iter() - .map(|v| shoup_representation_fq(*v, q)) + .map(|v| ShoupMul::representation(*v, q)) .collect_vec(); let psi_inv_powers_bo_shoup = psi_inv_powers_bo .iter() - .map(|v| shoup_representation_fq(*v, q)) + .map(|v| ShoupMul::representation(*v, q)) .collect_vec(); // n^{-1} \mod{q} @@ -365,7 +375,7 @@ impl NttBackendU64 { q_twice: 2 * q, n: n as u64, n_inv, - n_inv_shoup: shoup_representation_fq(n_inv, q), + n_inv_shoup: ShoupMul::representation(n_inv, q), psi_powers_bo: psi_powers_bo.into_boxed_slice(), psi_inv_powers_bo: psi_inv_powers_bo.into_boxed_slice(), psi_powers_bo_shoup: psi_powers_bo_shoup.into_boxed_slice(), @@ -429,7 +439,7 @@ impl Ntt for NttBackendU64 { } fn backward(&self, v: &mut [Self::Element]) { - ntt_inv_lazy( + ntt_inv( v, &self.psi_inv_powers_bo, &self.psi_inv_powers_bo_shoup, @@ -438,10 +448,10 @@ impl Ntt for NttBackendU64 { self.q, self.q_twice, ); - self.reduce_from_lazy(v); } } +#[cfg(test)] mod tests { use itertools::Itertools; use rand::{thread_rng, Rng}; diff --git a/src/pbs.rs b/src/pbs.rs index ddfda22..f6e7370 100644 --- a/src/pbs.rs +++ b/src/pbs.rs @@ -342,7 +342,7 @@ fn blind_rotation< mod_op, ); }); - println!("Auto count: {count}"); + // println!("Auto count: {count}"); } fn mod_switch_odd(v: f64, from_q: f64, to_q: f64) -> usize { diff --git a/src/shortint/mod.rs b/src/shortint/mod.rs index 9cb394c..5fd8b9c 100644 --- a/src/shortint/mod.rs +++ b/src/shortint/mod.rs @@ -306,7 +306,7 @@ mod tests { bool::{ aggregate_public_key_shares, aggregate_server_key_shares, gen_client_key, gen_keys, gen_mp_keys_phase1, gen_mp_keys_phase2, - parameters::{MP_BOOL_PARAMS, SP_BOOL_PARAMS}, + parameters::{MP_BOOL_PARAMS, SMALL_MP_BOOL_PARAMS, SP_BOOL_PARAMS}, set_mp_seed, set_parameter_set, }, shortint::types::FheUint8, @@ -463,7 +463,7 @@ mod tests { #[test] fn fheuint8_test_multi_party() { - set_parameter_set(&MP_BOOL_PARAMS); + set_parameter_set(&SMALL_MP_BOOL_PARAMS); set_mp_seed([0; 32]); let parties = 8; @@ -497,10 +497,12 @@ mod tests { let ct_b = public_key.encrypt(&b); let ct_c = public_key.encrypt(&c); + let now = std::time::Instant::now(); // server computes // a*b + c let mut ct_ab = &ct_a * &ct_b; ct_ab += &ct_c; + println!("Circuit time: {:?}", now.elapsed()); // decrypt ab and check // generate decryption shares diff --git a/src/utils.rs b/src/utils.rs index 31b694e..b448a83 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -25,6 +25,25 @@ pub trait Global { fn global() -> &'static Self; } +pub trait ShoupMul { + fn representation(value: Self, q: Self) -> Self; + fn mul(a: Self, b: Self, b_shoup: Self, q: Self) -> Self; +} + +impl ShoupMul for u64 { + #[inline] + fn representation(value: Self, q: Self) -> Self { + ((value as u128 * (1u128 << 64)) / q as u128) as u64 + } + + #[inline] + /// Returns a * b % q + fn mul(a: Self, b: Self, b_shoup: Self, q: Self) -> Self { + (b.wrapping_mul(a)) + .wrapping_sub(q.wrapping_mul(((b_shoup as u128 * a as u128) >> 64) as u64)) + } +} + pub fn fill_random_ternary_secret_with_hamming_weight< T: Signed, R: RandomFill<[u8]> + RandomElementInModulus, @@ -121,10 +140,6 @@ pub fn mod_inverse(a: u64, q: u64) -> u64 { mod_exponent(a, q - 2, q) } -pub fn shoup_representation_fq(v: u64, q: u64) -> u64 { - ((v as u128 * (1u128 << 64)) / q as u128) as u64 -} - pub fn negacyclic_mul T>( a: &[T], b: &[T],