mirror of
https://github.com/arnaucube/phantom-zone.git
synced 2026-01-10 16:11:30 +01:00
add decomp_iter
This commit is contained in:
@@ -1,9 +1,70 @@
|
||||
use bin_rs::{ModInit, ModularOpsU64, VectorOps};
|
||||
use bin_rs::{Decomposer, DefaultDecomposer, ModInit, ModularOpsU64, VectorOps};
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||
use itertools::Itertools;
|
||||
use itertools::{izip, Itertools};
|
||||
use rand::{thread_rng, Rng};
|
||||
use rand_distr::Uniform;
|
||||
|
||||
pub(crate) fn decompose_r(
|
||||
r: &[u64],
|
||||
decomp_r: &mut [Vec<u64>],
|
||||
decomposer: &DefaultDecomposer<u64>,
|
||||
) {
|
||||
let ring_size = r.len();
|
||||
// let d = decomposer.decomposition_count();
|
||||
// let mut count = 0;
|
||||
for ri in 0..ring_size {
|
||||
// let el_decomposed = decomposer.decompose(&r[ri]);
|
||||
decomposer
|
||||
.decompose_iter(&r[ri])
|
||||
.enumerate()
|
||||
.into_iter()
|
||||
.for_each(|(j, el)| {
|
||||
decomp_r[j][ri] = el;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn benchmark_decomposer(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("decomposer");
|
||||
|
||||
// let decomposers = vec![];
|
||||
// 55
|
||||
for prime in [36028797017456641] {
|
||||
for ring_size in [1 << 11] {
|
||||
let logb = 11;
|
||||
let decomposer = DefaultDecomposer::new(prime, logb, 2);
|
||||
|
||||
let mut rng = thread_rng();
|
||||
let dist = Uniform::new(0, prime);
|
||||
let a = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"decompose",
|
||||
format!(
|
||||
"q={prime}/N={ring_size}/logB={logb}/d={}",
|
||||
decomposer.decomposition_count()
|
||||
),
|
||||
),
|
||||
|b| {
|
||||
b.iter_batched_ref(
|
||||
|| {
|
||||
(
|
||||
a.clone(),
|
||||
vec![vec![0u64; ring_size]; decomposer.decomposition_count()],
|
||||
)
|
||||
},
|
||||
|(r, decomp_r)| (decompose_r(r, decomp_r, &decomposer)),
|
||||
criterion::BatchSize::PerIteration,
|
||||
)
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn benchmark(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("modulus");
|
||||
// 55
|
||||
@@ -34,5 +95,6 @@ fn benchmark(c: &mut Criterion) {
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(decomposer, benchmark_decomposer);
|
||||
criterion_group!(modulus, benchmark);
|
||||
criterion_main!(modulus);
|
||||
criterion_main!(modulus, decomposer);
|
||||
|
||||
@@ -103,13 +103,13 @@ mod test {
|
||||
println!("Gate time: {:?}", now.elapsed());
|
||||
|
||||
// mp decrypt
|
||||
let decryption_shares = cks
|
||||
.iter()
|
||||
.map(|c| evaluator.multi_party_decryption_share(&c_out, c))
|
||||
.collect_vec();
|
||||
let m_out = evaluator.multi_party_decrypt(&decryption_shares, &c_out);
|
||||
// let decryption_shares = cks
|
||||
// .iter()
|
||||
// .map(|c| evaluator.multi_party_decryption_share(&c_out, c))
|
||||
// .collect_vec();
|
||||
// let m_out = evaluator.multi_party_decrypt(&decryption_shares, &c_out);
|
||||
let m_expected = (m0 ^ m1);
|
||||
assert_eq!(m_expected, m_out, "Expected {m_expected} but got {m_out}");
|
||||
// assert_eq!(m_expected, m_out, "Expected {m_expected} but got {m_out}");
|
||||
|
||||
// // find noise update
|
||||
// {
|
||||
|
||||
@@ -39,19 +39,33 @@ where
|
||||
|
||||
pub trait Decomposer {
|
||||
type Element;
|
||||
type Iter: Iterator<Item = Self::Element>;
|
||||
fn new(q: Self::Element, logb: usize, d: usize) -> Self;
|
||||
//FIXME(Jay): there's no reason why it returns a vec instead of an iterator
|
||||
fn decompose(&self, v: &Self::Element) -> Vec<Self::Element>;
|
||||
|
||||
fn decompose_to_vec(&self, v: &Self::Element) -> Vec<Self::Element>;
|
||||
fn decompose_iter(&self, v: &Self::Element) -> Self::Iter;
|
||||
fn decomposition_count(&self) -> usize;
|
||||
}
|
||||
|
||||
// TODO(Jay): Shouldn't Decompose also return corresponding gadget vector ?
|
||||
pub struct DefaultDecomposer<T> {
|
||||
/// Ciphertext modulus
|
||||
q: T,
|
||||
/// Log of ciphertext modulus
|
||||
logq: usize,
|
||||
/// Log of base B
|
||||
logb: usize,
|
||||
/// base B
|
||||
b: T,
|
||||
/// (B - 1). To simulate (% B) as &(B-1), that is extract least significant
|
||||
/// logb bits
|
||||
b_mask: T,
|
||||
/// B/2
|
||||
bby2: T,
|
||||
/// Decomposition count
|
||||
d: usize,
|
||||
/// No. of bits to ignore in rounding
|
||||
ignore_bits: usize,
|
||||
/// No. of limbs to ignore in rounding. Set to ceil(logq / logb) - d
|
||||
ignore_limbs: usize,
|
||||
}
|
||||
|
||||
@@ -96,6 +110,7 @@ impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo> Decompose
|
||||
for DefaultDecomposer<T>
|
||||
{
|
||||
type Element = T;
|
||||
type Iter = DecomposerIter<T>;
|
||||
|
||||
fn new(q: T, logb: usize, d: usize) -> DefaultDecomposer<T> {
|
||||
// if q is power of 2, then `BITS - leading_zeros` outputs logq + 1.
|
||||
@@ -113,6 +128,9 @@ impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo> Decompose
|
||||
q,
|
||||
logq,
|
||||
logb,
|
||||
b: T::one() << logb,
|
||||
b_mask: (T::one() << logb) - T::one(),
|
||||
bby2: T::one() << (logb - 1),
|
||||
d,
|
||||
ignore_bits,
|
||||
ignore_limbs,
|
||||
@@ -120,7 +138,7 @@ impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo> Decompose
|
||||
}
|
||||
|
||||
// TODO(Jay): Outline the caveat
|
||||
fn decompose(&self, value: &T) -> Vec<T> {
|
||||
fn decompose_to_vec(&self, value: &T) -> Vec<T> {
|
||||
let mut value = round_value(*value, self.ignore_bits);
|
||||
|
||||
let q = self.q;
|
||||
@@ -153,6 +171,75 @@ impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo> Decompose
|
||||
fn decomposition_count(&self) -> usize {
|
||||
self.d
|
||||
}
|
||||
|
||||
fn decompose_iter(&self, value: &T) -> DecomposerIter<T> {
|
||||
let mut value = round_value(*value, self.ignore_bits);
|
||||
|
||||
if value >= (self.q >> 1) {
|
||||
value = !(self.q - value) + T::one()
|
||||
}
|
||||
|
||||
DecomposerIter {
|
||||
value,
|
||||
q: self.q,
|
||||
logb: self.logb,
|
||||
b: self.b,
|
||||
bby2: self.bby2,
|
||||
b_mask: self.b_mask,
|
||||
steps_left: self.d,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: PrimInt> DefaultDecomposer<T> {}
|
||||
|
||||
pub struct DecomposerIter<T> {
|
||||
/// Value to decompose
|
||||
value: T,
|
||||
steps_left: usize,
|
||||
/// (1 << logb) - 1 (for % (1<<logb); i.e. to extract least signiciant logb
|
||||
/// bits)
|
||||
b_mask: T,
|
||||
logb: usize,
|
||||
// b/2 = 1 << (logb-1)
|
||||
bby2: T,
|
||||
/// Ciphertext modulus
|
||||
q: T,
|
||||
/// b = 1 << logb
|
||||
b: T,
|
||||
}
|
||||
|
||||
impl<T: PrimInt> Iterator for DecomposerIter<T> {
|
||||
type Item = T;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.steps_left != 0 {
|
||||
self.steps_left -= 1;
|
||||
let k_i = self.value & self.b_mask;
|
||||
|
||||
self.value = (self.value - k_i) >> self.logb;
|
||||
|
||||
if k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one())) {
|
||||
self.value = self.value + T::one();
|
||||
Some(self.q + k_i - self.b)
|
||||
} else {
|
||||
Some(k_i)
|
||||
}
|
||||
|
||||
// let carry = <T as From<bool>>::from(
|
||||
// k_i > self.bby2 || (k_i == self.bby2 && ((self.value &
|
||||
// T::one()) == T::one())), );
|
||||
// self.value = self.value + carry;
|
||||
|
||||
// Some(
|
||||
// (self.q & ((carry << self.logq) - (T::one() & carry))) + k_i
|
||||
// - (carry << self.logb), )
|
||||
|
||||
// Some(k_i)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn round_value<T: PrimInt>(value: T, ignore_bits: usize) -> T {
|
||||
@@ -197,15 +284,18 @@ mod tests {
|
||||
let modq_op = ModularOpsU64::new(q);
|
||||
for _ in 0..100000 {
|
||||
let value = rng.gen_range(0..q);
|
||||
let limbs = decomposer.decompose(&value);
|
||||
let value_back = decomposer.recompose(&limbs, &modq_op);
|
||||
let rounded_value =
|
||||
round_value(value, decomposer.ignore_bits) << decomposer.ignore_bits;
|
||||
stats.add_more(&Vec::<i64>::try_convert_from(&limbs, &q));
|
||||
let limbs = decomposer.decompose_to_vec(&value);
|
||||
let value_back = round_value(
|
||||
decomposer.recompose(&limbs, &modq_op),
|
||||
decomposer.ignore_bits,
|
||||
);
|
||||
let rounded_value = round_value(value, decomposer.ignore_bits);
|
||||
assert_eq!(
|
||||
rounded_value, value_back,
|
||||
"Expected {rounded_value} got {value_back} for q={q}"
|
||||
);
|
||||
|
||||
stats.add_more(&Vec::<i64>::try_convert_from(&limbs, &q));
|
||||
}
|
||||
}
|
||||
println!("Mean: {}", stats.mean());
|
||||
|
||||
@@ -21,6 +21,7 @@ mod shortint;
|
||||
mod utils;
|
||||
|
||||
pub use backend::{ModInit, ModularOpsU64, VectorOps};
|
||||
pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer};
|
||||
pub use ntt::{Ntt, NttBackendU64, NttInit};
|
||||
|
||||
pub trait Matrix: AsRef<[Self::R]> {
|
||||
|
||||
@@ -125,7 +125,7 @@ pub(crate) fn lwe_key_switch<
|
||||
.as_ref()
|
||||
.iter()
|
||||
.skip(1)
|
||||
.flat_map(|ai| decomposer.decompose(ai));
|
||||
.flat_map(|ai| decomposer.decompose_to_vec(ai));
|
||||
izip!(lwe_in_a_decomposed, lwe_ksk.iter_rows()).for_each(|(ai_j, beta_ij_lwe)| {
|
||||
operator.elwise_fma_scalar_mut(lwe_out.as_mut(), beta_ij_lwe.as_ref(), &ai_j);
|
||||
});
|
||||
|
||||
@@ -16,8 +16,7 @@ fn decomposer(mut value: u64, q: u64, d: usize, logb: u64) -> Vec<u64> {
|
||||
// for _ in 0..d {
|
||||
// let k_i = carry + (value & full_mask);
|
||||
// value = (value) >> logb;
|
||||
// let go = thread_rng().gen_bool(1.0 / 2.0);
|
||||
// if k_i > bby2 || (k_i == bby2 && ((value & 1) == 1)) {
|
||||
// if k_i > bby2 {
|
||||
// // if (k_i == bby2 && ((value & 1) == 1)) {
|
||||
// // println!("AA");
|
||||
// // }
|
||||
@@ -31,7 +30,6 @@ fn decomposer(mut value: u64, q: u64, d: usize, logb: u64) -> Vec<u64> {
|
||||
// carry = 0;
|
||||
// }
|
||||
// }
|
||||
// println!("Last carry {carry}");
|
||||
// return out;
|
||||
|
||||
let mut out = Vec::with_capacity(d);
|
||||
|
||||
@@ -518,7 +518,8 @@ pub(crate) fn decompose_r<R: RowMut, D: Decomposer<Element = R::Element>>(
|
||||
let d = decomposer.decomposition_count();
|
||||
|
||||
for ri in 0..ring_size {
|
||||
let el_decomposed = decomposer.decompose(&r[ri]);
|
||||
let el_decomposed = decomposer.decompose_to_vec(&r[ri]);
|
||||
|
||||
for j in 0..d {
|
||||
decomp_r[j].as_mut()[ri] = el_decomposed[j];
|
||||
}
|
||||
@@ -570,7 +571,7 @@ pub(crate) fn galois_auto<
|
||||
.for_each(|(el_in, to_index, sign)| {
|
||||
let el_out = if !*sign { mod_op.neg(el_in) } else { *el_in };
|
||||
|
||||
let el_out_decomposed = decomposer.decompose(&el_out);
|
||||
let el_out_decomposed = decomposer.decompose_to_vec(&el_out);
|
||||
for j in 0..d {
|
||||
scratch_matrix_d_ring[j].as_mut()[*to_index] = el_out_decomposed[j];
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user