minor mods

This commit is contained in:
Janmajaya Mall
2024-06-05 12:00:55 +05:30
parent 15464c1ecc
commit 6cea691749
8 changed files with 77 additions and 34 deletions

View File

@@ -17,3 +17,7 @@ criterion = "0.5.1"
[[bench]] [[bench]]
name = "ntt" name = "ntt"
harness = false harness = false
[[bench]]
name = "modulus"
harness = false

38
benches/modulus.rs Normal file
View File

@@ -0,0 +1,38 @@
use bin_rs::{ModInit, ModularOpsU64, VectorOps};
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use itertools::Itertools;
use rand::{thread_rng, Rng};
use rand_distr::Uniform;
fn benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("modulus");
// 55
for prime in [36028797017456641] {
for ring_size in [1 << 11, 1 << 15] {
let modop = ModularOpsU64::new(prime);
let mut rng = thread_rng();
let dist = Uniform::new(0, prime);
let a0 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();
let a1 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();
let a2 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();
group.bench_function(
BenchmarkId::new("elwise_fma", format!("q={prime}/{ring_size}")),
|b| {
b.iter_batched_ref(
|| (a0.clone(), a1.clone(), a2.clone()),
|(a0, a1, a2)| black_box(modop.elwise_fma_mut(a0, a1, a2)),
criterion::BatchSize::PerIteration,
)
},
);
}
}
group.finish();
}
criterion_group!(modulus, benchmark);
criterion_main!(modulus);

View File

@@ -95,12 +95,12 @@ mod test {
let true_el_encoded = evaluator.parameters().rlwe_q().true_el(); let true_el_encoded = evaluator.parameters().rlwe_q().true_el();
let false_el_encoded = evaluator.parameters().rlwe_q().false_el(); let false_el_encoded = evaluator.parameters().rlwe_q().false_el();
let mut stats = Stats::new(); // let mut stats = Stats::new();
for _ in 0..1000 { for _ in 0..1000 {
// let now = std::time::Instant::now(); let now = std::time::Instant::now();
let c_out = evaluator.xor(&c_m0, &c_m1, &server_key_eval_domain); let c_out = evaluator.xor(&c_m0, &c_m1, &server_key_eval_domain);
// println!("Gate time: {:?}", now.elapsed()); println!("Gate time: {:?}", now.elapsed());
// mp decrypt // mp decrypt
let decryption_shares = cks let decryption_shares = cks
@@ -111,36 +111,36 @@ mod test {
let m_expected = (m0 ^ m1); let m_expected = (m0 ^ m1);
assert_eq!(m_expected, m_out, "Expected {m_expected} but got {m_out}"); assert_eq!(m_expected, m_out, "Expected {m_expected} but got {m_out}");
// find noise update // // find noise update
{ // {
let out = decrypt_lwe( // let out = decrypt_lwe(
&c_out, // &c_out,
ideal_client_key.sk_rlwe().values(), // ideal_client_key.sk_rlwe().values(),
evaluator.pbs_info().modop_rlweq(), // evaluator.pbs_info().modop_rlweq(),
); // );
let out_want = { // let out_want = {
if m_expected == true { // if m_expected == true {
true_el_encoded // true_el_encoded
} else { // } else {
false_el_encoded // false_el_encoded
} // }
}; // };
let diff = evaluator.pbs_info().modop_rlweq().sub(&out, &out_want); // let diff = evaluator.pbs_info().modop_rlweq().sub(&out, &out_want);
stats.add_more(&vec![evaluator // stats.add_more(&vec![evaluator
.pbs_info() // .pbs_info()
.rlwe_q() // .rlwe_q()
.map_element_to_i64(&diff)]); // .map_element_to_i64(&diff)]);
} // }
m1 = m0; m1 = m0;
m0 = m_out; m0 = m_expected;
c_m1 = c_m0; c_m1 = c_m0;
c_m0 = c_out; c_m0 = c_out;
} }
println!("log2 std dev {}", stats.std_dev().abs().log2()); // println!("log2 std dev {}", stats.std_dev().abs().log2());
} }
} }

View File

@@ -337,7 +337,7 @@ pub(crate) const SMALL_MP_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u6
lwe_decomposer_base: DecompostionLogBase(4), lwe_decomposer_base: DecompostionLogBase(4),
lwe_decomposer_count: DecompositionCount(5), lwe_decomposer_count: DecompositionCount(5),
rlrg_decomposer_base: DecompostionLogBase(11), rlrg_decomposer_base: DecompostionLogBase(11),
rlrg_decomposer_count: (DecompositionCount(2), DecompositionCount(2)), rlrg_decomposer_count: (DecompositionCount(2), DecompositionCount(1)),
rgrg_decomposer_base: DecompostionLogBase(11), rgrg_decomposer_base: DecompostionLogBase(11),
rgrg_decomposer_count: (DecompositionCount(5), DecompositionCount(4)), rgrg_decomposer_count: (DecompositionCount(5), DecompositionCount(4)),
auto_decomposer_base: DecompostionLogBase(11), auto_decomposer_base: DecompostionLogBase(11),

View File

@@ -119,10 +119,7 @@ impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo> Decompose
} }
} }
/// Signed BNAF decomposition. Only returns most significant `d` // TODO(Jay): Outline the caveat
/// decomposition limbs
///
/// Implements algorithm 3 of https://eprint.iacr.org/2021/1161.pdf
fn decompose(&self, value: &T) -> Vec<T> { fn decompose(&self, value: &T) -> Vec<T> {
let mut value = round_value(*value, self.ignore_bits); let mut value = round_value(*value, self.ignore_bits);

View File

@@ -20,6 +20,7 @@ mod rgsw;
mod shortint; mod shortint;
mod utils; mod utils;
pub use backend::{ModInit, ModularOpsU64, VectorOps};
pub use ntt::{Ntt, NttBackendU64, NttInit}; pub use ntt::{Ntt, NttBackendU64, NttInit};
pub trait Matrix: AsRef<[Self::R]> { pub trait Matrix: AsRef<[Self::R]> {

View File

@@ -240,7 +240,7 @@ fn blind_rotation<
let s_indices = &gk_to_si[q_by_4 + i]; let s_indices = &gk_to_si[q_by_4 + i];
s_indices.iter().for_each(|s_index| { s_indices.iter().for_each(|s_index| {
let new = std::time::Instant::now(); // let new = std::time::Instant::now();
rlwe_by_rgsw( rlwe_by_rgsw(
trivial_rlwe_test_poly, trivial_rlwe_test_poly,
pbs_key.rgsw_ct_lwe_si(*s_index), pbs_key.rgsw_ct_lwe_si(*s_index),
@@ -249,14 +249,14 @@ fn blind_rotation<
ntt_op, ntt_op,
mod_op, mod_op,
); );
println!("Rlwe x Rgsw time: {:?}", new.elapsed()); // println!("Rlwe x Rgsw time: {:?}", new.elapsed());
}); });
v += 1; v += 1;
if gk_to_si[q_by_4 + i - 1].len() != 0 || v == w || i == 1 { if gk_to_si[q_by_4 + i - 1].len() != 0 || v == w || i == 1 {
let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(v); let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(v);
let now = std::time::Instant::now(); // let now = std::time::Instant::now();
galois_auto( galois_auto(
trivial_rlwe_test_poly, trivial_rlwe_test_poly,
pbs_key.galois_key_for_auto(v), pbs_key.galois_key_for_auto(v),
@@ -267,7 +267,7 @@ fn blind_rotation<
ntt_op, ntt_op,
auto_decomposer, auto_decomposer,
); );
println!("Auto time: {:?}", now.elapsed()); // println!("Auto time: {:?}", now.elapsed());
count += 1; count += 1;
v = 0; v = 0;
@@ -296,6 +296,7 @@ fn blind_rotation<
ntt_op, ntt_op,
auto_decomposer, auto_decomposer,
); );
count += 1;
// +(g^k) // +(g^k)
let mut v = 0; let mut v = 0;

View File

@@ -791,6 +791,7 @@ pub(crate) fn rlwe_by_rgsw<
); );
scratch_matrix_d_ring scratch_matrix_d_ring
.iter_mut() .iter_mut()
.take(d_a)
.for_each(|r| ntt_op.forward(r.as_mut())); .for_each(|r| ntt_op.forward(r.as_mut()));
// a_out += decomp<a_in> \cdot RLWE_A'(-sm) // a_out += decomp<a_in> \cdot RLWE_A'(-sm)
routine( routine(
@@ -815,6 +816,7 @@ pub(crate) fn rlwe_by_rgsw<
); );
scratch_matrix_d_ring scratch_matrix_d_ring
.iter_mut() .iter_mut()
.take(d_b)
.for_each(|r| ntt_op.forward(r.as_mut())); .for_each(|r| ntt_op.forward(r.as_mut()));
// a_out += decomp<b_in> \cdot RLWE_A'(m) // a_out += decomp<b_in> \cdot RLWE_A'(m)
routine( routine(