mirror of
https://github.com/arnaucube/phantom-zone.git
synced 2026-01-12 00:51:29 +01:00
add decomp_iter
This commit is contained in:
@@ -1,9 +1,70 @@
|
|||||||
use bin_rs::{ModInit, ModularOpsU64, VectorOps};
|
use bin_rs::{Decomposer, DefaultDecomposer, ModInit, ModularOpsU64, VectorOps};
|
||||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
|
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||||
use itertools::Itertools;
|
use itertools::{izip, Itertools};
|
||||||
use rand::{thread_rng, Rng};
|
use rand::{thread_rng, Rng};
|
||||||
use rand_distr::Uniform;
|
use rand_distr::Uniform;
|
||||||
|
|
||||||
|
pub(crate) fn decompose_r(
|
||||||
|
r: &[u64],
|
||||||
|
decomp_r: &mut [Vec<u64>],
|
||||||
|
decomposer: &DefaultDecomposer<u64>,
|
||||||
|
) {
|
||||||
|
let ring_size = r.len();
|
||||||
|
// let d = decomposer.decomposition_count();
|
||||||
|
// let mut count = 0;
|
||||||
|
for ri in 0..ring_size {
|
||||||
|
// let el_decomposed = decomposer.decompose(&r[ri]);
|
||||||
|
decomposer
|
||||||
|
.decompose_iter(&r[ri])
|
||||||
|
.enumerate()
|
||||||
|
.into_iter()
|
||||||
|
.for_each(|(j, el)| {
|
||||||
|
decomp_r[j][ri] = el;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn benchmark_decomposer(c: &mut Criterion) {
|
||||||
|
let mut group = c.benchmark_group("decomposer");
|
||||||
|
|
||||||
|
// let decomposers = vec![];
|
||||||
|
// 55
|
||||||
|
for prime in [36028797017456641] {
|
||||||
|
for ring_size in [1 << 11] {
|
||||||
|
let logb = 11;
|
||||||
|
let decomposer = DefaultDecomposer::new(prime, logb, 2);
|
||||||
|
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
let dist = Uniform::new(0, prime);
|
||||||
|
let a = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();
|
||||||
|
|
||||||
|
group.bench_function(
|
||||||
|
BenchmarkId::new(
|
||||||
|
"decompose",
|
||||||
|
format!(
|
||||||
|
"q={prime}/N={ring_size}/logB={logb}/d={}",
|
||||||
|
decomposer.decomposition_count()
|
||||||
|
),
|
||||||
|
),
|
||||||
|
|b| {
|
||||||
|
b.iter_batched_ref(
|
||||||
|
|| {
|
||||||
|
(
|
||||||
|
a.clone(),
|
||||||
|
vec![vec![0u64; ring_size]; decomposer.decomposition_count()],
|
||||||
|
)
|
||||||
|
},
|
||||||
|
|(r, decomp_r)| (decompose_r(r, decomp_r, &decomposer)),
|
||||||
|
criterion::BatchSize::PerIteration,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
group.finish();
|
||||||
|
}
|
||||||
|
|
||||||
fn benchmark(c: &mut Criterion) {
|
fn benchmark(c: &mut Criterion) {
|
||||||
let mut group = c.benchmark_group("modulus");
|
let mut group = c.benchmark_group("modulus");
|
||||||
// 55
|
// 55
|
||||||
@@ -34,5 +95,6 @@ fn benchmark(c: &mut Criterion) {
|
|||||||
group.finish();
|
group.finish();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
criterion_group!(decomposer, benchmark_decomposer);
|
||||||
criterion_group!(modulus, benchmark);
|
criterion_group!(modulus, benchmark);
|
||||||
criterion_main!(modulus);
|
criterion_main!(modulus, decomposer);
|
||||||
|
|||||||
@@ -103,13 +103,13 @@ mod test {
|
|||||||
println!("Gate time: {:?}", now.elapsed());
|
println!("Gate time: {:?}", now.elapsed());
|
||||||
|
|
||||||
// mp decrypt
|
// mp decrypt
|
||||||
let decryption_shares = cks
|
// let decryption_shares = cks
|
||||||
.iter()
|
// .iter()
|
||||||
.map(|c| evaluator.multi_party_decryption_share(&c_out, c))
|
// .map(|c| evaluator.multi_party_decryption_share(&c_out, c))
|
||||||
.collect_vec();
|
// .collect_vec();
|
||||||
let m_out = evaluator.multi_party_decrypt(&decryption_shares, &c_out);
|
// let m_out = evaluator.multi_party_decrypt(&decryption_shares, &c_out);
|
||||||
let m_expected = (m0 ^ m1);
|
let m_expected = (m0 ^ m1);
|
||||||
assert_eq!(m_expected, m_out, "Expected {m_expected} but got {m_out}");
|
// assert_eq!(m_expected, m_out, "Expected {m_expected} but got {m_out}");
|
||||||
|
|
||||||
// // find noise update
|
// // find noise update
|
||||||
// {
|
// {
|
||||||
|
|||||||
@@ -39,19 +39,33 @@ where
|
|||||||
|
|
||||||
pub trait Decomposer {
|
pub trait Decomposer {
|
||||||
type Element;
|
type Element;
|
||||||
|
type Iter: Iterator<Item = Self::Element>;
|
||||||
fn new(q: Self::Element, logb: usize, d: usize) -> Self;
|
fn new(q: Self::Element, logb: usize, d: usize) -> Self;
|
||||||
//FIXME(Jay): there's no reason why it returns a vec instead of an iterator
|
|
||||||
fn decompose(&self, v: &Self::Element) -> Vec<Self::Element>;
|
fn decompose_to_vec(&self, v: &Self::Element) -> Vec<Self::Element>;
|
||||||
|
fn decompose_iter(&self, v: &Self::Element) -> Self::Iter;
|
||||||
fn decomposition_count(&self) -> usize;
|
fn decomposition_count(&self) -> usize;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(Jay): Shouldn't Decompose also return corresponding gadget vector ?
|
|
||||||
pub struct DefaultDecomposer<T> {
|
pub struct DefaultDecomposer<T> {
|
||||||
|
/// Ciphertext modulus
|
||||||
q: T,
|
q: T,
|
||||||
|
/// Log of ciphertext modulus
|
||||||
logq: usize,
|
logq: usize,
|
||||||
|
/// Log of base B
|
||||||
logb: usize,
|
logb: usize,
|
||||||
|
/// base B
|
||||||
|
b: T,
|
||||||
|
/// (B - 1). To simulate (% B) as &(B-1), that is extract least significant
|
||||||
|
/// logb bits
|
||||||
|
b_mask: T,
|
||||||
|
/// B/2
|
||||||
|
bby2: T,
|
||||||
|
/// Decomposition count
|
||||||
d: usize,
|
d: usize,
|
||||||
|
/// No. of bits to ignore in rounding
|
||||||
ignore_bits: usize,
|
ignore_bits: usize,
|
||||||
|
/// No. of limbs to ignore in rounding. Set to ceil(logq / logb) - d
|
||||||
ignore_limbs: usize,
|
ignore_limbs: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -96,6 +110,7 @@ impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo> Decompose
|
|||||||
for DefaultDecomposer<T>
|
for DefaultDecomposer<T>
|
||||||
{
|
{
|
||||||
type Element = T;
|
type Element = T;
|
||||||
|
type Iter = DecomposerIter<T>;
|
||||||
|
|
||||||
fn new(q: T, logb: usize, d: usize) -> DefaultDecomposer<T> {
|
fn new(q: T, logb: usize, d: usize) -> DefaultDecomposer<T> {
|
||||||
// if q is power of 2, then `BITS - leading_zeros` outputs logq + 1.
|
// if q is power of 2, then `BITS - leading_zeros` outputs logq + 1.
|
||||||
@@ -113,6 +128,9 @@ impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo> Decompose
|
|||||||
q,
|
q,
|
||||||
logq,
|
logq,
|
||||||
logb,
|
logb,
|
||||||
|
b: T::one() << logb,
|
||||||
|
b_mask: (T::one() << logb) - T::one(),
|
||||||
|
bby2: T::one() << (logb - 1),
|
||||||
d,
|
d,
|
||||||
ignore_bits,
|
ignore_bits,
|
||||||
ignore_limbs,
|
ignore_limbs,
|
||||||
@@ -120,7 +138,7 @@ impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo> Decompose
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO(Jay): Outline the caveat
|
// TODO(Jay): Outline the caveat
|
||||||
fn decompose(&self, value: &T) -> Vec<T> {
|
fn decompose_to_vec(&self, value: &T) -> Vec<T> {
|
||||||
let mut value = round_value(*value, self.ignore_bits);
|
let mut value = round_value(*value, self.ignore_bits);
|
||||||
|
|
||||||
let q = self.q;
|
let q = self.q;
|
||||||
@@ -153,6 +171,75 @@ impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo> Decompose
|
|||||||
fn decomposition_count(&self) -> usize {
|
fn decomposition_count(&self) -> usize {
|
||||||
self.d
|
self.d
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn decompose_iter(&self, value: &T) -> DecomposerIter<T> {
|
||||||
|
let mut value = round_value(*value, self.ignore_bits);
|
||||||
|
|
||||||
|
if value >= (self.q >> 1) {
|
||||||
|
value = !(self.q - value) + T::one()
|
||||||
|
}
|
||||||
|
|
||||||
|
DecomposerIter {
|
||||||
|
value,
|
||||||
|
q: self.q,
|
||||||
|
logb: self.logb,
|
||||||
|
b: self.b,
|
||||||
|
bby2: self.bby2,
|
||||||
|
b_mask: self.b_mask,
|
||||||
|
steps_left: self.d,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: PrimInt> DefaultDecomposer<T> {}
|
||||||
|
|
||||||
|
pub struct DecomposerIter<T> {
|
||||||
|
/// Value to decompose
|
||||||
|
value: T,
|
||||||
|
steps_left: usize,
|
||||||
|
/// (1 << logb) - 1 (for % (1<<logb); i.e. to extract least signiciant logb
|
||||||
|
/// bits)
|
||||||
|
b_mask: T,
|
||||||
|
logb: usize,
|
||||||
|
// b/2 = 1 << (logb-1)
|
||||||
|
bby2: T,
|
||||||
|
/// Ciphertext modulus
|
||||||
|
q: T,
|
||||||
|
/// b = 1 << logb
|
||||||
|
b: T,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: PrimInt> Iterator for DecomposerIter<T> {
|
||||||
|
type Item = T;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
if self.steps_left != 0 {
|
||||||
|
self.steps_left -= 1;
|
||||||
|
let k_i = self.value & self.b_mask;
|
||||||
|
|
||||||
|
self.value = (self.value - k_i) >> self.logb;
|
||||||
|
|
||||||
|
if k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one())) {
|
||||||
|
self.value = self.value + T::one();
|
||||||
|
Some(self.q + k_i - self.b)
|
||||||
|
} else {
|
||||||
|
Some(k_i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// let carry = <T as From<bool>>::from(
|
||||||
|
// k_i > self.bby2 || (k_i == self.bby2 && ((self.value &
|
||||||
|
// T::one()) == T::one())), );
|
||||||
|
// self.value = self.value + carry;
|
||||||
|
|
||||||
|
// Some(
|
||||||
|
// (self.q & ((carry << self.logq) - (T::one() & carry))) + k_i
|
||||||
|
// - (carry << self.logb), )
|
||||||
|
|
||||||
|
// Some(k_i)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn round_value<T: PrimInt>(value: T, ignore_bits: usize) -> T {
|
fn round_value<T: PrimInt>(value: T, ignore_bits: usize) -> T {
|
||||||
@@ -197,15 +284,18 @@ mod tests {
|
|||||||
let modq_op = ModularOpsU64::new(q);
|
let modq_op = ModularOpsU64::new(q);
|
||||||
for _ in 0..100000 {
|
for _ in 0..100000 {
|
||||||
let value = rng.gen_range(0..q);
|
let value = rng.gen_range(0..q);
|
||||||
let limbs = decomposer.decompose(&value);
|
let limbs = decomposer.decompose_to_vec(&value);
|
||||||
let value_back = decomposer.recompose(&limbs, &modq_op);
|
let value_back = round_value(
|
||||||
let rounded_value =
|
decomposer.recompose(&limbs, &modq_op),
|
||||||
round_value(value, decomposer.ignore_bits) << decomposer.ignore_bits;
|
decomposer.ignore_bits,
|
||||||
stats.add_more(&Vec::<i64>::try_convert_from(&limbs, &q));
|
);
|
||||||
|
let rounded_value = round_value(value, decomposer.ignore_bits);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
rounded_value, value_back,
|
rounded_value, value_back,
|
||||||
"Expected {rounded_value} got {value_back} for q={q}"
|
"Expected {rounded_value} got {value_back} for q={q}"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
stats.add_more(&Vec::<i64>::try_convert_from(&limbs, &q));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
println!("Mean: {}", stats.mean());
|
println!("Mean: {}", stats.mean());
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ mod shortint;
|
|||||||
mod utils;
|
mod utils;
|
||||||
|
|
||||||
pub use backend::{ModInit, ModularOpsU64, VectorOps};
|
pub use backend::{ModInit, ModularOpsU64, VectorOps};
|
||||||
|
pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer};
|
||||||
pub use ntt::{Ntt, NttBackendU64, NttInit};
|
pub use ntt::{Ntt, NttBackendU64, NttInit};
|
||||||
|
|
||||||
pub trait Matrix: AsRef<[Self::R]> {
|
pub trait Matrix: AsRef<[Self::R]> {
|
||||||
|
|||||||
@@ -125,7 +125,7 @@ pub(crate) fn lwe_key_switch<
|
|||||||
.as_ref()
|
.as_ref()
|
||||||
.iter()
|
.iter()
|
||||||
.skip(1)
|
.skip(1)
|
||||||
.flat_map(|ai| decomposer.decompose(ai));
|
.flat_map(|ai| decomposer.decompose_to_vec(ai));
|
||||||
izip!(lwe_in_a_decomposed, lwe_ksk.iter_rows()).for_each(|(ai_j, beta_ij_lwe)| {
|
izip!(lwe_in_a_decomposed, lwe_ksk.iter_rows()).for_each(|(ai_j, beta_ij_lwe)| {
|
||||||
operator.elwise_fma_scalar_mut(lwe_out.as_mut(), beta_ij_lwe.as_ref(), &ai_j);
|
operator.elwise_fma_scalar_mut(lwe_out.as_mut(), beta_ij_lwe.as_ref(), &ai_j);
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -16,8 +16,7 @@ fn decomposer(mut value: u64, q: u64, d: usize, logb: u64) -> Vec<u64> {
|
|||||||
// for _ in 0..d {
|
// for _ in 0..d {
|
||||||
// let k_i = carry + (value & full_mask);
|
// let k_i = carry + (value & full_mask);
|
||||||
// value = (value) >> logb;
|
// value = (value) >> logb;
|
||||||
// let go = thread_rng().gen_bool(1.0 / 2.0);
|
// if k_i > bby2 {
|
||||||
// if k_i > bby2 || (k_i == bby2 && ((value & 1) == 1)) {
|
|
||||||
// // if (k_i == bby2 && ((value & 1) == 1)) {
|
// // if (k_i == bby2 && ((value & 1) == 1)) {
|
||||||
// // println!("AA");
|
// // println!("AA");
|
||||||
// // }
|
// // }
|
||||||
@@ -31,7 +30,6 @@ fn decomposer(mut value: u64, q: u64, d: usize, logb: u64) -> Vec<u64> {
|
|||||||
// carry = 0;
|
// carry = 0;
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
// println!("Last carry {carry}");
|
|
||||||
// return out;
|
// return out;
|
||||||
|
|
||||||
let mut out = Vec::with_capacity(d);
|
let mut out = Vec::with_capacity(d);
|
||||||
|
|||||||
@@ -518,7 +518,8 @@ pub(crate) fn decompose_r<R: RowMut, D: Decomposer<Element = R::Element>>(
|
|||||||
let d = decomposer.decomposition_count();
|
let d = decomposer.decomposition_count();
|
||||||
|
|
||||||
for ri in 0..ring_size {
|
for ri in 0..ring_size {
|
||||||
let el_decomposed = decomposer.decompose(&r[ri]);
|
let el_decomposed = decomposer.decompose_to_vec(&r[ri]);
|
||||||
|
|
||||||
for j in 0..d {
|
for j in 0..d {
|
||||||
decomp_r[j].as_mut()[ri] = el_decomposed[j];
|
decomp_r[j].as_mut()[ri] = el_decomposed[j];
|
||||||
}
|
}
|
||||||
@@ -570,7 +571,7 @@ pub(crate) fn galois_auto<
|
|||||||
.for_each(|(el_in, to_index, sign)| {
|
.for_each(|(el_in, to_index, sign)| {
|
||||||
let el_out = if !*sign { mod_op.neg(el_in) } else { *el_in };
|
let el_out = if !*sign { mod_op.neg(el_in) } else { *el_in };
|
||||||
|
|
||||||
let el_out_decomposed = decomposer.decompose(&el_out);
|
let el_out_decomposed = decomposer.decompose_to_vec(&el_out);
|
||||||
for j in 0..d {
|
for j in 0..d {
|
||||||
scratch_matrix_d_ring[j].as_mut()[*to_index] = el_out_decomposed[j];
|
scratch_matrix_d_ring[j].as_mut()[*to_index] = el_out_decomposed[j];
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user