Browse Source

add decomp_iter

par-agg-key-shares
Janmajaya Mall 10 months ago
parent
commit
a05e959e75
7 changed files with 176 additions and 24 deletions
  1. +65
    -3
      benches/modulus.rs
  2. +6
    -6
      src/bool/noise.rs
  3. +99
    -9
      src/decomposer.rs
  4. +1
    -0
      src/lib.rs
  5. +1
    -1
      src/lwe.rs
  6. +1
    -3
      src/main.rs
  7. +3
    -2
      src/rgsw.rs

+ 65
- 3
benches/modulus.rs

@ -1,9 +1,70 @@
use bin_rs::{ModInit, ModularOpsU64, VectorOps};
use bin_rs::{Decomposer, DefaultDecomposer, ModInit, ModularOpsU64, VectorOps};
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use itertools::Itertools;
use itertools::{izip, Itertools};
use rand::{thread_rng, Rng};
use rand_distr::Uniform;
pub(crate) fn decompose_r(
r: &[u64],
decomp_r: &mut [Vec<u64>],
decomposer: &DefaultDecomposer<u64>,
) {
let ring_size = r.len();
// let d = decomposer.decomposition_count();
// let mut count = 0;
for ri in 0..ring_size {
// let el_decomposed = decomposer.decompose(&r[ri]);
decomposer
.decompose_iter(&r[ri])
.enumerate()
.into_iter()
.for_each(|(j, el)| {
decomp_r[j][ri] = el;
});
}
}
fn benchmark_decomposer(c: &mut Criterion) {
let mut group = c.benchmark_group("decomposer");
// let decomposers = vec![];
// 55
for prime in [36028797017456641] {
for ring_size in [1 << 11] {
let logb = 11;
let decomposer = DefaultDecomposer::new(prime, logb, 2);
let mut rng = thread_rng();
let dist = Uniform::new(0, prime);
let a = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();
group.bench_function(
BenchmarkId::new(
"decompose",
format!(
"q={prime}/N={ring_size}/logB={logb}/d={}",
decomposer.decomposition_count()
),
),
|b| {
b.iter_batched_ref(
|| {
(
a.clone(),
vec![vec![0u64; ring_size]; decomposer.decomposition_count()],
)
},
|(r, decomp_r)| (decompose_r(r, decomp_r, &decomposer)),
criterion::BatchSize::PerIteration,
)
},
);
}
}
group.finish();
}
fn benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("modulus");
// 55
@ -34,5 +95,6 @@ fn benchmark(c: &mut Criterion) {
group.finish();
}
criterion_group!(decomposer, benchmark_decomposer);
criterion_group!(modulus, benchmark);
criterion_main!(modulus);
criterion_main!(modulus, decomposer);

+ 6
- 6
src/bool/noise.rs

@ -103,13 +103,13 @@ mod test {
println!("Gate time: {:?}", now.elapsed());
// mp decrypt
let decryption_shares = cks
.iter()
.map(|c| evaluator.multi_party_decryption_share(&c_out, c))
.collect_vec();
let m_out = evaluator.multi_party_decrypt(&decryption_shares, &c_out);
// let decryption_shares = cks
// .iter()
// .map(|c| evaluator.multi_party_decryption_share(&c_out, c))
// .collect_vec();
// let m_out = evaluator.multi_party_decrypt(&decryption_shares, &c_out);
let m_expected = (m0 ^ m1);
assert_eq!(m_expected, m_out, "Expected {m_expected} but got {m_out}");
// assert_eq!(m_expected, m_out, "Expected {m_expected} but got {m_out}");
// // find noise update
// {

+ 99
- 9
src/decomposer.rs

@ -39,19 +39,33 @@ where
pub trait Decomposer {
type Element;
type Iter: Iterator<Item = Self::Element>;
fn new(q: Self::Element, logb: usize, d: usize) -> Self;
//FIXME(Jay): there's no reason why it returns a vec instead of an iterator
fn decompose(&self, v: &Self::Element) -> Vec<Self::Element>;
fn decompose_to_vec(&self, v: &Self::Element) -> Vec<Self::Element>;
fn decompose_iter(&self, v: &Self::Element) -> Self::Iter;
fn decomposition_count(&self) -> usize;
}
// TODO(Jay): Shouldn't Decompose also return corresponding gadget vector ?
pub struct DefaultDecomposer<T> {
/// Ciphertext modulus
q: T,
/// Log of ciphertext modulus
logq: usize,
/// Log of base B
logb: usize,
/// base B
b: T,
/// (B - 1). To simulate (% B) as &(B-1), that is extract least significant
/// logb bits
b_mask: T,
/// B/2
bby2: T,
/// Decomposition count
d: usize,
/// No. of bits to ignore in rounding
ignore_bits: usize,
/// No. of limbs to ignore in rounding. Set to ceil(logq / logb) - d
ignore_limbs: usize,
}
@ -96,6 +110,7 @@ impl Decompose
for DefaultDecomposer<T>
{
type Element = T;
type Iter = DecomposerIter<T>;
fn new(q: T, logb: usize, d: usize) -> DefaultDecomposer<T> {
// if q is power of 2, then `BITS - leading_zeros` outputs logq + 1.
@ -113,6 +128,9 @@ impl Decompose
q,
logq,
logb,
b: T::one() << logb,
b_mask: (T::one() << logb) - T::one(),
bby2: T::one() << (logb - 1),
d,
ignore_bits,
ignore_limbs,
@ -120,7 +138,7 @@ impl Decompose
}
// TODO(Jay): Outline the caveat
fn decompose(&self, value: &T) -> Vec<T> {
fn decompose_to_vec(&self, value: &T) -> Vec<T> {
let mut value = round_value(*value, self.ignore_bits);
let q = self.q;
@ -153,6 +171,75 @@ impl Decompose
fn decomposition_count(&self) -> usize {
self.d
}
fn decompose_iter(&self, value: &T) -> DecomposerIter<T> {
let mut value = round_value(*value, self.ignore_bits);
if value >= (self.q >> 1) {
value = !(self.q - value) + T::one()
}
DecomposerIter {
value,
q: self.q,
logb: self.logb,
b: self.b,
bby2: self.bby2,
b_mask: self.b_mask,
steps_left: self.d,
}
}
}
impl<T: PrimInt> DefaultDecomposer<T> {}
pub struct DecomposerIter<T> {
/// Value to decompose
value: T,
steps_left: usize,
/// (1 << logb) - 1 (for % (1<<logb); i.e. to extract least signiciant logb
/// bits)
b_mask: T,
logb: usize,
// b/2 = 1 << (logb-1)
bby2: T,
/// Ciphertext modulus
q: T,
/// b = 1 << logb
b: T,
}
impl<T: PrimInt> Iterator for DecomposerIter<T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
if self.steps_left != 0 {
self.steps_left -= 1;
let k_i = self.value & self.b_mask;
self.value = (self.value - k_i) >> self.logb;
if k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one())) {
self.value = self.value + T::one();
Some(self.q + k_i - self.b)
} else {
Some(k_i)
}
// let carry = <T as From<bool>>::from(
// k_i > self.bby2 || (k_i == self.bby2 && ((self.value &
// T::one()) == T::one())), );
// self.value = self.value + carry;
// Some(
// (self.q & ((carry << self.logq) - (T::one() & carry))) + k_i
// - (carry << self.logb), )
// Some(k_i)
} else {
None
}
}
}
fn round_value<T: PrimInt>(value: T, ignore_bits: usize) -> T {
@ -197,15 +284,18 @@ mod tests {
let modq_op = ModularOpsU64::new(q);
for _ in 0..100000 {
let value = rng.gen_range(0..q);
let limbs = decomposer.decompose(&value);
let value_back = decomposer.recompose(&limbs, &modq_op);
let rounded_value =
round_value(value, decomposer.ignore_bits) << decomposer.ignore_bits;
stats.add_more(&Vec::<i64>::try_convert_from(&limbs, &q));
let limbs = decomposer.decompose_to_vec(&value);
let value_back = round_value(
decomposer.recompose(&limbs, &modq_op),
decomposer.ignore_bits,
);
let rounded_value = round_value(value, decomposer.ignore_bits);
assert_eq!(
rounded_value, value_back,
"Expected {rounded_value} got {value_back} for q={q}"
);
stats.add_more(&Vec::<i64>::try_convert_from(&limbs, &q));
}
}
println!("Mean: {}", stats.mean());

+ 1
- 0
src/lib.rs

@ -21,6 +21,7 @@ mod shortint;
mod utils;
pub use backend::{ModInit, ModularOpsU64, VectorOps};
pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer};
pub use ntt::{Ntt, NttBackendU64, NttInit};
pub trait Matrix: AsRef<[Self::R]> {

+ 1
- 1
src/lwe.rs

@ -125,7 +125,7 @@ pub(crate) fn lwe_key_switch<
.as_ref()
.iter()
.skip(1)
.flat_map(|ai| decomposer.decompose(ai));
.flat_map(|ai| decomposer.decompose_to_vec(ai));
izip!(lwe_in_a_decomposed, lwe_ksk.iter_rows()).for_each(|(ai_j, beta_ij_lwe)| {
operator.elwise_fma_scalar_mut(lwe_out.as_mut(), beta_ij_lwe.as_ref(), &ai_j);
});

+ 1
- 3
src/main.rs

@ -16,8 +16,7 @@ fn decomposer(mut value: u64, q: u64, d: usize, logb: u64) -> Vec {
// for _ in 0..d {
// let k_i = carry + (value & full_mask);
// value = (value) >> logb;
// let go = thread_rng().gen_bool(1.0 / 2.0);
// if k_i > bby2 || (k_i == bby2 && ((value & 1) == 1)) {
// if k_i > bby2 {
// // if (k_i == bby2 && ((value & 1) == 1)) {
// // println!("AA");
// // }
@ -31,7 +30,6 @@ fn decomposer(mut value: u64, q: u64, d: usize, logb: u64) -> Vec {
// carry = 0;
// }
// }
// println!("Last carry {carry}");
// return out;
let mut out = Vec::with_capacity(d);

+ 3
- 2
src/rgsw.rs

@ -518,7 +518,8 @@ pub(crate) fn decompose_r>(
let d = decomposer.decomposition_count();
for ri in 0..ring_size {
let el_decomposed = decomposer.decompose(&r[ri]);
let el_decomposed = decomposer.decompose_to_vec(&r[ri]);
for j in 0..d {
decomp_r[j].as_mut()[ri] = el_decomposed[j];
}
@ -570,7 +571,7 @@ pub(crate) fn galois_auto<
.for_each(|(el_in, to_index, sign)| {
let el_out = if !*sign { mod_op.neg(el_in) } else { *el_in };
let el_out_decomposed = decomposer.decompose(&el_out);
let el_out_decomposed = decomposer.decompose_to_vec(&el_out);
for j in 0..d {
scratch_matrix_d_ring[j].as_mut()[*to_index] = el_out_decomposed[j];
}

Loading…
Cancel
Save