From 691995f1c334641cc38581e2794c9ee4c007d7f4 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Mon, 17 Jun 2024 14:55:41 +0530 Subject: [PATCH] fix test --- src/bool/evaluator.rs | 203 ++++++++++++++--------------------------- src/bool/keys.rs | 12 ++- src/bool/parameters.rs | 2 +- 3 files changed, 83 insertions(+), 134 deletions(-) diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index c2ae5b3..d712445 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -935,7 +935,7 @@ where // record the order s.t. key_order[i] stores the position of i^th // users key share in existing order let mut key_order = Vec::with_capacity(existing_key_order.len()); - (0..total_users).map(|i| { + (0..total_users).for_each(|i| { // find i let index = existing_key_order .iter() @@ -2166,8 +2166,10 @@ mod tests { use crate::{ backend::ModulusPowerOf2, bool::{ - self, CommonReferenceSeededMultiPartyServerKeyShare, PublicKey, - SeededMultiPartyServerKey, NON_INTERACTIVE_SMALL_MP_BOOL_PARAMS, SMALL_MP_BOOL_PARAMS, + self, CommonReferenceSeededMultiPartyServerKeyShare, + NonInteractiveServerKeyEvaluationDomain, PublicKey, SeededMultiPartyServerKey, + ShoupNonInteractiveServerKeyEvaluationDomain, NON_INTERACTIVE_SMALL_MP_BOOL_PARAMS, + SMALL_MP_BOOL_PARAMS, }, ntt::NttBackendU64, random::{RandomElementInModulus, DEFAULT_RNG}, @@ -2478,7 +2480,7 @@ mod tests { } #[test] - fn multi_party_nand() { + fn interactive_multi_party_nand() { let mut bool_evaluator = BoolEvaluator::< Vec>, NttBackendU64, @@ -2508,13 +2510,7 @@ mod tests { .collect_vec(); let m_back = bool_evaluator.multi_party_decrypt(&decryption_shares, &lwe_out); - let m_back = bool_evaluator.sk_decrypt(&lwe_out, &ideal_client_key); - - assert!( - m_expected == m_back, - "Expected {m_expected}, got -{m_back}" - ); + assert!(m_expected == m_back, "Expected {m_expected}, got {m_back}"); m1 = m0; m0 = m_expected; @@ -3267,7 +3263,7 @@ mod tests { NttBackendU64, ModularOpsU64>, ModulusPowerOf2>, - ShoupServerKeyEvaluationDomain>>, + ShoupNonInteractiveServerKeyEvaluationDomain>>, >::new(NON_INTERACTIVE_SMALL_MP_BOOL_PARAMS); let mp_seed = NonInteractiveMultiPartyCrs { seed: [1u8; 32] }; @@ -3287,11 +3283,15 @@ mod tests { .collect_vec(); // dbg!(key_shares[1].user_index); - let rgsw_cts = evaluator.aggregate_non_interactive_multi_party_key_share( + let seeded_server_key = evaluator.aggregate_non_interactive_multi_party_key_share( &mp_seed, parties, &key_shares, ); + let server_key_evaluation_domain = + NonInteractiveServerKeyEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from( + seeded_server_key, + ); let mut ideal_rlwe = vec![0; ring_size]; cks.iter().for_each(|k| { @@ -3307,132 +3307,71 @@ mod tests { }); }); - // let mut stats = Stats::new(); + let mut stats = Stats::new(); - let (rlrg_decomp_a, rlrg_decomp_b) = evaluator - .parameters() - .rlwe_rgsw_decomposer::>(); - let gadget_vec_a = rlrg_decomp_a.gadget_vector(); - let gadget_vec_b = rlrg_decomp_b.gadget_vector(); - let d_a = rlrg_decomp_a.decomposition_count(); - let d_b = rlrg_decomp_b.decomposition_count(); - let s_poly = Vec::::try_convert_from(ideal_rlwe.as_slice(), rlwe_q); - let mut neg_s_poly_eval = s_poly.clone(); - rlwe_modop.elwise_neg_mut(&mut neg_s_poly_eval); - nttop.forward(neg_s_poly_eval.as_mut()); - - // rgsw_cts.iter().enumerate().for_each(|(s_index, ct)| { - // // X^{lwe_s[i]} - // let mut m = vec![0u64; ring_size]; - // if ideal_lwe[s_index] < 0 { - // m[ring_size - (ideal_lwe[s_index].abs() as usize)] = - // rlwe_q.neg_one(); } else { - // m[(ideal_lwe[s_index] as usize)] = 1; - // } - - // let mut neg_sm = m.clone(); - // nttop.forward(&mut neg_sm); - // rlwe_modop.elwise_mul_mut(&mut neg_sm, &neg_s_poly_eval); - // nttop.backward(&mut neg_sm); + { + let (rlrg_decomp_a, rlrg_decomp_b) = evaluator + .parameters() + .rlwe_rgsw_decomposer::>(); + let gadget_vec_a = rlrg_decomp_a.gadget_vector(); + let gadget_vec_b = rlrg_decomp_b.gadget_vector(); + let d_a = rlrg_decomp_a.decomposition_count(); + let d_b = rlrg_decomp_b.decomposition_count(); + + // s[X] + let s_poly = Vec::::try_convert_from(ideal_rlwe.as_slice(), rlwe_q); + + // -s[X] + let mut neg_s_poly_eval = s_poly.clone(); + rlwe_modop.elwise_neg_mut(&mut neg_s_poly_eval); + nttop.forward(neg_s_poly_eval.as_mut()); + + server_key_evaluation_domain + .rgsw_cts() + .iter() + .enumerate() + .for_each(|(s_index, ct)| { + // X^{lwe_s[i]} + let mut m = vec![0u64; ring_size]; + let s_i = ideal_lwe[s_index] * (evaluator.pbs_info().embedding_factor() as i32); + if s_i < 0 { + m[ring_size - (s_i.abs() as usize)] = rlwe_q.neg_one(); + } else { + m[(s_i as usize)] = 1; + } - // // RLWE'(-sm) - // gadget_vec_a.iter().enumerate().for_each(|(index, beta)| { - // // RLWE(\beta -sm) + let mut neg_sm = m.clone(); + nttop.forward(&mut neg_sm); + rlwe_modop.elwise_mul_mut(&mut neg_sm, &neg_s_poly_eval); + nttop.backward(&mut neg_sm); - // // \beta * -sX^[lwe_s[i]] - // let mut beta_neg_sm = neg_sm.clone(); - // rlwe_modop.elwise_scalar_mul_mut(&mut beta_neg_sm, beta); + // RLWE'(-sm) + gadget_vec_a.iter().enumerate().for_each(|(index, beta)| { + // RLWE(\beta -sm) - // // extract RLWE(-sm \beta) - // let mut rlwe = vec![vec![0u64; ring_size]; 2]; - // rlwe[0].copy_from_slice(&ct[index]); - // rlwe[1].copy_from_slice(&ct[index + d_a]); + // \beta * -sX^[lwe_s[i]] + let mut beta_neg_sm = neg_sm.clone(); + rlwe_modop.elwise_scalar_mul_mut(&mut beta_neg_sm, beta); - // // decrypt - // let mut m_out = vec![0u64; ring_size]; - // decrypt_rlwe(&rlwe, &ideal_rlwe, &mut m_out, nttop, - // rlwe_modop); // println!("{:?}", &beta_neg_sm); + // extract RLWE(-sm \beta) + let mut rlwe = vec![vec![0u64; ring_size]; 2]; + rlwe[0].copy_from_slice(&ct[index]); + rlwe[1].copy_from_slice(&ct[index + d_a]); + // send back to coefficient domain + rlwe.iter_rows_mut() + .for_each(|r| nttop.backward(r.as_mut_slice())); - // let mut diff = m_out; - // rlwe_modop.elwise_sub_mut(&mut diff, &beta_neg_sm); + // decrypt + let mut m_out = vec![0u64; ring_size]; + decrypt_rlwe(&rlwe, &ideal_rlwe, &mut m_out, nttop, rlwe_modop); // println!("{:?}", &beta_neg_sm); - // stats.add_more(&Vec::::try_convert_from(&diff, rlwe_q)); - // }); - // }); + let mut diff = m_out; + rlwe_modop.elwise_sub_mut(&mut diff, &beta_neg_sm); - // println!("Stats: {}", stats.std_dev().abs().log2()); + stats.add_more(&Vec::::try_convert_from(&diff, rlwe_q)); + }); + }); + } + println!("Stats: {}", stats.std_dev().abs().log2()); } } - -// let (rgswrgsw_d_a, rgswrgsw_d_b) = -// self.pbs_info.parameters.rgsw_rgsw_decomposition_count(); let (rlrg_d_a, -// rlrg_d_b) = self.pbs_info.parameters.rlwe_rgsw_decomposition_count(); let -// rgsw_ct_rows_in = rgswrgsw_d_a.0 * 2 + rgswrgsw_d_b.0 * 2; -// let rgsw_ct_rows_out = rlrg_d_a.0 * 2 + rlrg_d_b.0 * 2; -// assert!(rgswrgsw_d_a.0 >= rlrg_d_a.0, "RGSWxRGSW part A decomposition count {} must be >= RLWExRGSW part A decomposition count {}", rgswrgsw_d_a.0 , rlrg_d_a.0); -// assert!(rgswrgsw_d_b.0 >= rlrg_d_b.0, "RGSWxRGSW part B decomposition count {} must be >= RLWExRGSW part B decomposition count {}", rgswrgsw_d_b.0 , rlrg_d_b.0); -// let rgsw_cts = rgsw_cts -// .map(|ct_i_in| { -// assert!(ct_i_in.dimension() == (rgsw_ct_rows_in, rlwe_n)); -// let mut reduced_ct_i_out = M::zeros(rgsw_ct_rows_out, rlwe_n); - -// // RLWE'(-sm) part A -// izip!( -// reduced_ct_i_out.iter_rows_mut().take(rlrg_d_a.0), -// ct_i_in -// .iter_rows() -// .skip(rgswrgsw_d_a.0 - rlrg_d_a.0) -// .take(rlrg_d_a.0) -// ) -// .for_each(|(to_ri, from_ri)| { -// to_ri.as_mut().copy_from_slice(from_ri.as_ref()); -// }); - -// // RLWE'(-sm) part B -// izip!( -// reduced_ct_i_out -// .iter_rows_mut() -// .skip(rlrg_d_a.0) -// .take(rlrg_d_a.0), -// ct_i_in -// .iter_rows() -// .skip(rgswrgsw_d_a.0 + (rgswrgsw_d_a.0 - rlrg_d_a.0)) -// .take(rlrg_d_a.0) -// ) -// .for_each(|(to_ri, from_ri)| { -// to_ri.as_mut().copy_from_slice(from_ri.as_ref()); -// }); - -// // RLWE'(m) Part A -// izip!( -// reduced_ct_i_out -// .iter_rows_mut() -// .skip(rlrg_d_a.0 * 2) -// .take(rlrg_d_b.0), -// ct_i_in -// .iter_rows() -// .skip(rgswrgsw_d_a.0 * 2 + (rgswrgsw_d_b.0 - rlrg_d_b.0)) -// .take(rlrg_d_b.0) -// ) -// .for_each(|(to_ri, from_ri)| { -// to_ri.as_mut().copy_from_slice(from_ri.as_ref()); -// }); - -// // RLWE'(m) Part B -// izip!( -// reduced_ct_i_out -// .iter_rows_mut() -// .skip(rlrg_d_a.0 * 2 + rlrg_d_b.0) -// .take(rlrg_d_b.0), -// ct_i_in -// .iter_rows() -// .skip(rgswrgsw_d_a.0 * 2 + rgswrgsw_d_b.0 + -// (rgswrgsw_d_b.0 - rlrg_d_b.0)) .take(rlrg_d_b.0) -// ) -// .for_each(|(to_ri, from_ri)| { -// to_ri.as_mut().copy_from_slice(from_ri.as_ref()); -// }); - -// reduced_ct_i_out -// }) -// .collect_vec(); diff --git a/src/bool/keys.rs b/src/bool/keys.rs index de5ffe2..289e0a9 100644 --- a/src/bool/keys.rs +++ b/src/bool/keys.rs @@ -740,6 +740,12 @@ pub(super) mod impl_non_interactive_server_key_eval_domain { use super::*; + impl NonInteractiveServerKeyEvaluationDomain { + pub(in super::super) fn rgsw_cts(&self) -> &[M] { + &self.rgsw_cts + } + } + impl From< SeededNonInteractiveMultiPartyServerKey< @@ -836,6 +842,10 @@ pub(super) mod impl_non_interactive_server_key_eval_domain { ) }); // copy over part bs + assert!( + value.lwe_ksk.as_ref().len() + == value.parameters.lwe_decomposition_count().0 * ring_size + ); izip!(value.lwe_ksk.as_ref().iter(), lwe_ksk.iter_rows_mut()).for_each( |(b_el, lwe_ct)| { lwe_ct.as_mut()[0] = *b_el; @@ -865,7 +875,7 @@ pub(super) mod impl_non_interactive_server_key_eval_domain { let incoming_ksk_partb_ref = &value.ui_to_s_ksks[value.ui_to_s_ksks_key_order[user_index]]; - assert!(ksk_ct.dimension() == (d_uitos, ring_size)); + assert!(incoming_ksk_partb_ref.dimension() == (d_uitos, ring_size)); izip!( ksk_ct.iter_rows_mut().skip(d_uitos), incoming_ksk_partb_ref.iter_rows() diff --git a/src/bool/parameters.rs b/src/bool/parameters.rs index c251f47..4e603ab 100644 --- a/src/bool/parameters.rs +++ b/src/bool/parameters.rs @@ -504,7 +504,7 @@ pub(crate) const SMALL_MP_BOOL_PARAMS: BoolParameters = BoolParameters:: = BoolParameters:: { rlwe_q: CiphertextModulus::new_non_native(36028797018820609), lwe_q: CiphertextModulus::new_non_native(1 << 20), - br_q: 1 << 12, + br_q: 1 << 11, rlwe_n: PolynomialSize(1 << 11), lwe_n: LweDimension(10), lwe_decomposer_params: (DecompostionLogBase(4), DecompositionCount(5)),