From 77039d7918b143b9707d948f1314d603ada44512 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Thu, 6 Jun 2024 13:33:46 +0530 Subject: [PATCH] change decomp_iter last check --- src/bool/evaluator.rs | 2 +- src/bool/noise.rs | 12 ++++++------ src/decomposer.rs | 26 ++++++++++++++++---------- src/pbs.rs | 8 ++++---- 4 files changed, 27 insertions(+), 21 deletions(-) diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index 5f28000..f03ec12 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -230,7 +230,7 @@ pub(super) struct BoolPbsInfo { impl PbsInfo for BoolPbsInfo where - M::MatElement: PrimInt + WrappingSub + NumInfo + FromPrimitive + From, + M::MatElement: PrimInt + WrappingSub + NumInfo + FromPrimitive + From + Display, RlweModOp: ArithmeticOps + VectorOps, LweModOp: ArithmeticOps + VectorOps, NttOp: Ntt, diff --git a/src/bool/noise.rs b/src/bool/noise.rs index a6ff6b0..d2b7491 100644 --- a/src/bool/noise.rs +++ b/src/bool/noise.rs @@ -103,13 +103,13 @@ mod test { println!("Gate time: {:?}", now.elapsed()); // mp decrypt - // let decryption_shares = cks - // .iter() - // .map(|c| evaluator.multi_party_decryption_share(&c_out, c)) - // .collect_vec(); - // let m_out = evaluator.multi_party_decrypt(&decryption_shares, &c_out); + let decryption_shares = cks + .iter() + .map(|c| evaluator.multi_party_decryption_share(&c_out, c)) + .collect_vec(); + let m_out = evaluator.multi_party_decrypt(&decryption_shares, &c_out); let m_expected = (m0 ^ m1); - // assert_eq!(m_expected, m_out, "Expected {m_expected} but got {m_out}"); + assert_eq!(m_expected, m_out, "Expected {m_expected} but got {m_out}"); // // find noise update // { diff --git a/src/decomposer.rs b/src/decomposer.rs index 4337bef..2275458 100644 --- a/src/decomposer.rs +++ b/src/decomposer.rs @@ -1,6 +1,10 @@ use itertools::Itertools; use num_traits::{AsPrimitive, FromPrimitive, Num, One, PrimInt, ToPrimitive, WrappingSub, Zero}; -use std::{fmt::Debug, marker::PhantomData, ops::Rem}; +use std::{ + fmt::{Debug, Display}, + marker::PhantomData, + ops::Rem, +}; use crate::backend::{ArithmeticOps, ModularOpsU64}; @@ -106,8 +110,8 @@ impl DefaultDecomposer { } } -impl> Decomposer - for DefaultDecomposer +impl + Display> + Decomposer for DefaultDecomposer { type Element = T; type Iter = DecomposerIter; @@ -212,7 +216,7 @@ pub struct DecomposerIter { b: T, } -impl> Iterator for DecomposerIter { +impl + WrappingSub + Display> Iterator for DecomposerIter { type Item = T; fn next(&mut self) -> Option { @@ -235,14 +239,16 @@ impl> Iterator for DecomposerIter { // Suprisingly the improvement does not show up when I benchmark // `decomposer_iter` in isolation. Putting this remark here as a // future task to investiage (TODO). - let carry = >::from( - k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one())), - ); + let carry_bool = + k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one())); + let carry = >::from(carry_bool); + let neg_carry = (T::zero().wrapping_sub(&carry)) >> 9; self.value = self.value + carry; + Some((neg_carry & self.q) + k_i - (carry << self.logb)) - Some( - (self.q & ((carry << self.logq) - (T::one() & carry))) + k_i - (carry << self.logb), - ) + // Some( + // (self.q & ((carry << self.logq) - (T::one() & carry))) + k_i + // - (carry << self.logb), ) // Some(k_i) } else { diff --git a/src/pbs.rs b/src/pbs.rs index 698be32..ddfda22 100644 --- a/src/pbs.rs +++ b/src/pbs.rs @@ -240,7 +240,7 @@ fn blind_rotation< let s_indices = &gk_to_si[q_by_4 + i]; s_indices.iter().for_each(|s_index| { - let new = std::time::Instant::now(); + // let new = std::time::Instant::now(); rlwe_by_rgsw( trivial_rlwe_test_poly, pbs_key.rgsw_ct_lwe_si(*s_index), @@ -249,14 +249,14 @@ fn blind_rotation< ntt_op, mod_op, ); - println!("Rlwe x Rgsw time: {:?}", new.elapsed()); + // println!("Rlwe x Rgsw time: {:?}", new.elapsed()); }); v += 1; if gk_to_si[q_by_4 + i - 1].len() != 0 || v == w || i == 1 { let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(v); - let now = std::time::Instant::now(); + // let now = std::time::Instant::now(); galois_auto( trivial_rlwe_test_poly, pbs_key.galois_key_for_auto(v), @@ -267,7 +267,7 @@ fn blind_rotation< ntt_op, auto_decomposer, ); - println!("Auto time: {:?}", now.elapsed()); + // println!("Auto time: {:?}", now.elapsed()); count += 1; v = 0;