change decomp_iter last check

This commit is contained in:
Janmajaya Mall
2024-06-06 13:33:46 +05:30
parent 9b09549e18
commit 77039d7918
4 changed files with 27 additions and 21 deletions

View File

@@ -230,7 +230,7 @@ pub(super) struct BoolPbsInfo<M: Matrix, Ntt, RlweModOp, LweModOp> {
impl<M: Matrix, NttOp, RlweModOp, LweModOp> PbsInfo for BoolPbsInfo<M, NttOp, RlweModOp, LweModOp>
where
M::MatElement: PrimInt + WrappingSub + NumInfo + FromPrimitive + From<bool>,
M::MatElement: PrimInt + WrappingSub + NumInfo + FromPrimitive + From<bool> + Display,
RlweModOp: ArithmeticOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>,
LweModOp: ArithmeticOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>,
NttOp: Ntt<Element = M::MatElement>,

View File

@@ -103,13 +103,13 @@ mod test {
println!("Gate time: {:?}", now.elapsed());
// mp decrypt
// let decryption_shares = cks
// .iter()
// .map(|c| evaluator.multi_party_decryption_share(&c_out, c))
// .collect_vec();
// let m_out = evaluator.multi_party_decrypt(&decryption_shares, &c_out);
let decryption_shares = cks
.iter()
.map(|c| evaluator.multi_party_decryption_share(&c_out, c))
.collect_vec();
let m_out = evaluator.multi_party_decrypt(&decryption_shares, &c_out);
let m_expected = (m0 ^ m1);
// assert_eq!(m_expected, m_out, "Expected {m_expected} but got {m_out}");
assert_eq!(m_expected, m_out, "Expected {m_expected} but got {m_out}");
// // find noise update
// {

View File

@@ -1,6 +1,10 @@
use itertools::Itertools;
use num_traits::{AsPrimitive, FromPrimitive, Num, One, PrimInt, ToPrimitive, WrappingSub, Zero};
use std::{fmt::Debug, marker::PhantomData, ops::Rem};
use std::{
fmt::{Debug, Display},
marker::PhantomData,
ops::Rem,
};
use crate::backend::{ArithmeticOps, ModularOpsU64};
@@ -106,8 +110,8 @@ impl<T: PrimInt + NumInfo + Debug> DefaultDecomposer<T> {
}
}
impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo + From<bool>> Decomposer
for DefaultDecomposer<T>
impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo + From<bool> + Display>
Decomposer for DefaultDecomposer<T>
{
type Element = T;
type Iter = DecomposerIter<T>;
@@ -212,7 +216,7 @@ pub struct DecomposerIter<T> {
b: T,
}
impl<T: PrimInt + From<bool>> Iterator for DecomposerIter<T> {
impl<T: PrimInt + From<bool> + WrappingSub + Display> Iterator for DecomposerIter<T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
@@ -235,14 +239,16 @@ impl<T: PrimInt + From<bool>> Iterator for DecomposerIter<T> {
// Suprisingly the improvement does not show up when I benchmark
// `decomposer_iter` in isolation. Putting this remark here as a
// future task to investiage (TODO).
let carry = <T as From<bool>>::from(
k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one())),
);
let carry_bool =
k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one()));
let carry = <T as From<bool>>::from(carry_bool);
let neg_carry = (T::zero().wrapping_sub(&carry)) >> 9;
self.value = self.value + carry;
Some((neg_carry & self.q) + k_i - (carry << self.logb))
Some(
(self.q & ((carry << self.logq) - (T::one() & carry))) + k_i - (carry << self.logb),
)
// Some(
// (self.q & ((carry << self.logq) - (T::one() & carry))) + k_i
// - (carry << self.logb), )
// Some(k_i)
} else {

View File

@@ -240,7 +240,7 @@ fn blind_rotation<
let s_indices = &gk_to_si[q_by_4 + i];
s_indices.iter().for_each(|s_index| {
let new = std::time::Instant::now();
// let new = std::time::Instant::now();
rlwe_by_rgsw(
trivial_rlwe_test_poly,
pbs_key.rgsw_ct_lwe_si(*s_index),
@@ -249,14 +249,14 @@ fn blind_rotation<
ntt_op,
mod_op,
);
println!("Rlwe x Rgsw time: {:?}", new.elapsed());
// println!("Rlwe x Rgsw time: {:?}", new.elapsed());
});
v += 1;
if gk_to_si[q_by_4 + i - 1].len() != 0 || v == w || i == 1 {
let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(v);
let now = std::time::Instant::now();
// let now = std::time::Instant::now();
galois_auto(
trivial_rlwe_test_poly,
pbs_key.galois_key_for_auto(v),
@@ -267,7 +267,7 @@ fn blind_rotation<
ntt_op,
auto_decomposer,
);
println!("Auto time: {:?}", now.elapsed());
// println!("Auto time: {:?}", now.elapsed());
count += 1;
v = 0;