From 6cea6917496e12b9a758eceab25c5208d0242e5a Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Wed, 5 Jun 2024 12:00:55 +0530 Subject: [PATCH] minor mods --- Cargo.toml | 4 ++++ benches/modulus.rs | 38 +++++++++++++++++++++++++++++ src/bool/noise.rs | 54 +++++++++++++++++++++--------------------- src/bool/parameters.rs | 2 +- src/decomposer.rs | 5 +--- src/lib.rs | 1 + src/pbs.rs | 9 +++---- src/rgsw.rs | 2 ++ 8 files changed, 79 insertions(+), 36 deletions(-) create mode 100644 benches/modulus.rs diff --git a/Cargo.toml b/Cargo.toml index 0ebd8dd..732d164 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,4 +16,8 @@ criterion = "0.5.1" [[bench]] name = "ntt" +harness = false + +[[bench]] +name = "modulus" harness = false \ No newline at end of file diff --git a/benches/modulus.rs b/benches/modulus.rs new file mode 100644 index 0000000..087ad93 --- /dev/null +++ b/benches/modulus.rs @@ -0,0 +1,38 @@ +use bin_rs::{ModInit, ModularOpsU64, VectorOps}; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use itertools::Itertools; +use rand::{thread_rng, Rng}; +use rand_distr::Uniform; + +fn benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("modulus"); + // 55 + for prime in [36028797017456641] { + for ring_size in [1 << 11, 1 << 15] { + let modop = ModularOpsU64::new(prime); + + let mut rng = thread_rng(); + let dist = Uniform::new(0, prime); + + let a0 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec(); + let a1 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec(); + let a2 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec(); + + group.bench_function( + BenchmarkId::new("elwise_fma", format!("q={prime}/{ring_size}")), + |b| { + b.iter_batched_ref( + || (a0.clone(), a1.clone(), a2.clone()), + |(a0, a1, a2)| black_box(modop.elwise_fma_mut(a0, a1, a2)), + criterion::BatchSize::PerIteration, + ) + }, + ); + } + } + + group.finish(); +} + +criterion_group!(modulus, benchmark); +criterion_main!(modulus); diff --git a/src/bool/noise.rs b/src/bool/noise.rs index f25c54a..d2b7491 100644 --- a/src/bool/noise.rs +++ b/src/bool/noise.rs @@ -95,12 +95,12 @@ mod test { let true_el_encoded = evaluator.parameters().rlwe_q().true_el(); let false_el_encoded = evaluator.parameters().rlwe_q().false_el(); - let mut stats = Stats::new(); + // let mut stats = Stats::new(); for _ in 0..1000 { - // let now = std::time::Instant::now(); + let now = std::time::Instant::now(); let c_out = evaluator.xor(&c_m0, &c_m1, &server_key_eval_domain); - // println!("Gate time: {:?}", now.elapsed()); + println!("Gate time: {:?}", now.elapsed()); // mp decrypt let decryption_shares = cks @@ -111,36 +111,36 @@ mod test { let m_expected = (m0 ^ m1); assert_eq!(m_expected, m_out, "Expected {m_expected} but got {m_out}"); - // find noise update - { - let out = decrypt_lwe( - &c_out, - ideal_client_key.sk_rlwe().values(), - evaluator.pbs_info().modop_rlweq(), - ); - - let out_want = { - if m_expected == true { - true_el_encoded - } else { - false_el_encoded - } - }; - let diff = evaluator.pbs_info().modop_rlweq().sub(&out, &out_want); - - stats.add_more(&vec![evaluator - .pbs_info() - .rlwe_q() - .map_element_to_i64(&diff)]); - } + // // find noise update + // { + // let out = decrypt_lwe( + // &c_out, + // ideal_client_key.sk_rlwe().values(), + // evaluator.pbs_info().modop_rlweq(), + // ); + + // let out_want = { + // if m_expected == true { + // true_el_encoded + // } else { + // false_el_encoded + // } + // }; + // let diff = evaluator.pbs_info().modop_rlweq().sub(&out, &out_want); + + // stats.add_more(&vec![evaluator + // .pbs_info() + // .rlwe_q() + // .map_element_to_i64(&diff)]); + // } m1 = m0; - m0 = m_out; + m0 = m_expected; c_m1 = c_m0; c_m0 = c_out; } - println!("log2 std dev {}", stats.std_dev().abs().log2()); + // println!("log2 std dev {}", stats.std_dev().abs().log2()); } } diff --git a/src/bool/parameters.rs b/src/bool/parameters.rs index 0bc6133..ba0f63a 100644 --- a/src/bool/parameters.rs +++ b/src/bool/parameters.rs @@ -337,7 +337,7 @@ pub(crate) const SMALL_MP_BOOL_PARAMS: BoolParameters = BoolParameters:: Decompose } } - /// Signed BNAF decomposition. Only returns most significant `d` - /// decomposition limbs - /// - /// Implements algorithm 3 of https://eprint.iacr.org/2021/1161.pdf + // TODO(Jay): Outline the caveat fn decompose(&self, value: &T) -> Vec { let mut value = round_value(*value, self.ignore_bits); diff --git a/src/lib.rs b/src/lib.rs index ac8187e..4f91121 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,7 @@ mod rgsw; mod shortint; mod utils; +pub use backend::{ModInit, ModularOpsU64, VectorOps}; pub use ntt::{Ntt, NttBackendU64, NttInit}; pub trait Matrix: AsRef<[Self::R]> { diff --git a/src/pbs.rs b/src/pbs.rs index 9d32d22..ddfda22 100644 --- a/src/pbs.rs +++ b/src/pbs.rs @@ -240,7 +240,7 @@ fn blind_rotation< let s_indices = &gk_to_si[q_by_4 + i]; s_indices.iter().for_each(|s_index| { - let new = std::time::Instant::now(); + // let new = std::time::Instant::now(); rlwe_by_rgsw( trivial_rlwe_test_poly, pbs_key.rgsw_ct_lwe_si(*s_index), @@ -249,14 +249,14 @@ fn blind_rotation< ntt_op, mod_op, ); - println!("Rlwe x Rgsw time: {:?}", new.elapsed()); + // println!("Rlwe x Rgsw time: {:?}", new.elapsed()); }); v += 1; if gk_to_si[q_by_4 + i - 1].len() != 0 || v == w || i == 1 { let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(v); - let now = std::time::Instant::now(); + // let now = std::time::Instant::now(); galois_auto( trivial_rlwe_test_poly, pbs_key.galois_key_for_auto(v), @@ -267,7 +267,7 @@ fn blind_rotation< ntt_op, auto_decomposer, ); - println!("Auto time: {:?}", now.elapsed()); + // println!("Auto time: {:?}", now.elapsed()); count += 1; v = 0; @@ -296,6 +296,7 @@ fn blind_rotation< ntt_op, auto_decomposer, ); + count += 1; // +(g^k) let mut v = 0; diff --git a/src/rgsw.rs b/src/rgsw.rs index 7b771cc..9e15b59 100644 --- a/src/rgsw.rs +++ b/src/rgsw.rs @@ -791,6 +791,7 @@ pub(crate) fn rlwe_by_rgsw< ); scratch_matrix_d_ring .iter_mut() + .take(d_a) .for_each(|r| ntt_op.forward(r.as_mut())); // a_out += decomp \cdot RLWE_A'(-sm) routine( @@ -815,6 +816,7 @@ pub(crate) fn rlwe_by_rgsw< ); scratch_matrix_d_ring .iter_mut() + .take(d_b) .for_each(|r| ntt_op.forward(r.as_mut())); // a_out += decomp \cdot RLWE_A'(m) routine(