Browse Source

fix test

par-agg-key-shares
Janmajaya Mall 10 months ago
parent
commit
691995f1c3
3 changed files with 83 additions and 134 deletions
  1. +71
    -132
      src/bool/evaluator.rs
  2. +11
    -1
      src/bool/keys.rs
  3. +1
    -1
      src/bool/parameters.rs

+ 71
- 132
src/bool/evaluator.rs

@ -935,7 +935,7 @@ where
// record the order s.t. key_order[i] stores the position of i^th
// users key share in existing order
let mut key_order = Vec::with_capacity(existing_key_order.len());
(0..total_users).map(|i| {
(0..total_users).for_each(|i| {
// find i
let index = existing_key_order
.iter()
@ -2166,8 +2166,10 @@ mod tests {
use crate::{
backend::ModulusPowerOf2,
bool::{
self, CommonReferenceSeededMultiPartyServerKeyShare, PublicKey,
SeededMultiPartyServerKey, NON_INTERACTIVE_SMALL_MP_BOOL_PARAMS, SMALL_MP_BOOL_PARAMS,
self, CommonReferenceSeededMultiPartyServerKeyShare,
NonInteractiveServerKeyEvaluationDomain, PublicKey, SeededMultiPartyServerKey,
ShoupNonInteractiveServerKeyEvaluationDomain, NON_INTERACTIVE_SMALL_MP_BOOL_PARAMS,
SMALL_MP_BOOL_PARAMS,
},
ntt::NttBackendU64,
random::{RandomElementInModulus, DEFAULT_RNG},
@ -2478,7 +2480,7 @@ mod tests {
}
#[test]
fn multi_party_nand() {
fn interactive_multi_party_nand() {
let mut bool_evaluator = BoolEvaluator::<
Vec<Vec<u64>>,
NttBackendU64,
@ -2508,13 +2510,7 @@ mod tests {
.collect_vec();
let m_back = bool_evaluator.multi_party_decrypt(&decryption_shares, &lwe_out);
let m_back = bool_evaluator.sk_decrypt(&lwe_out, &ideal_client_key);
assert!(
m_expected == m_back,
"Expected {m_expected}, got
{m_back}"
);
assert!(m_expected == m_back, "Expected {m_expected}, got {m_back}");
m1 = m0;
m0 = m_expected;
@ -3267,7 +3263,7 @@ mod tests {
NttBackendU64,
ModularOpsU64<CiphertextModulus<u64>>,
ModulusPowerOf2<CiphertextModulus<u64>>,
ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>,
ShoupNonInteractiveServerKeyEvaluationDomain<Vec<Vec<u64>>>,
>::new(NON_INTERACTIVE_SMALL_MP_BOOL_PARAMS);
let mp_seed = NonInteractiveMultiPartyCrs { seed: [1u8; 32] };
@ -3287,11 +3283,15 @@ mod tests {
.collect_vec();
// dbg!(key_shares[1].user_index);
let rgsw_cts = evaluator.aggregate_non_interactive_multi_party_key_share(
let seeded_server_key = evaluator.aggregate_non_interactive_multi_party_key_share(
&mp_seed,
parties,
&key_shares,
);
let server_key_evaluation_domain =
NonInteractiveServerKeyEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from(
seeded_server_key,
);
let mut ideal_rlwe = vec![0; ring_size];
cks.iter().for_each(|k| {
@ -3307,132 +3307,71 @@ mod tests {
});
});
// let mut stats = Stats::new();
let mut stats = Stats::new();
let (rlrg_decomp_a, rlrg_decomp_b) = evaluator
.parameters()
.rlwe_rgsw_decomposer::<DefaultDecomposer<u64>>();
let gadget_vec_a = rlrg_decomp_a.gadget_vector();
let gadget_vec_b = rlrg_decomp_b.gadget_vector();
let d_a = rlrg_decomp_a.decomposition_count();
let d_b = rlrg_decomp_b.decomposition_count();
let s_poly = Vec::<u64>::try_convert_from(ideal_rlwe.as_slice(), rlwe_q);
let mut neg_s_poly_eval = s_poly.clone();
rlwe_modop.elwise_neg_mut(&mut neg_s_poly_eval);
nttop.forward(neg_s_poly_eval.as_mut());
// rgsw_cts.iter().enumerate().for_each(|(s_index, ct)| {
// // X^{lwe_s[i]}
// let mut m = vec![0u64; ring_size];
// if ideal_lwe[s_index] < 0 {
// m[ring_size - (ideal_lwe[s_index].abs() as usize)] =
// rlwe_q.neg_one(); } else {
// m[(ideal_lwe[s_index] as usize)] = 1;
// }
// let mut neg_sm = m.clone();
// nttop.forward(&mut neg_sm);
// rlwe_modop.elwise_mul_mut(&mut neg_sm, &neg_s_poly_eval);
// nttop.backward(&mut neg_sm);
{
let (rlrg_decomp_a, rlrg_decomp_b) = evaluator
.parameters()
.rlwe_rgsw_decomposer::<DefaultDecomposer<u64>>();
let gadget_vec_a = rlrg_decomp_a.gadget_vector();
let gadget_vec_b = rlrg_decomp_b.gadget_vector();
let d_a = rlrg_decomp_a.decomposition_count();
let d_b = rlrg_decomp_b.decomposition_count();
// s[X]
let s_poly = Vec::<u64>::try_convert_from(ideal_rlwe.as_slice(), rlwe_q);
// -s[X]
let mut neg_s_poly_eval = s_poly.clone();
rlwe_modop.elwise_neg_mut(&mut neg_s_poly_eval);
nttop.forward(neg_s_poly_eval.as_mut());
server_key_evaluation_domain
.rgsw_cts()
.iter()
.enumerate()
.for_each(|(s_index, ct)| {
// X^{lwe_s[i]}
let mut m = vec![0u64; ring_size];
let s_i = ideal_lwe[s_index] * (evaluator.pbs_info().embedding_factor() as i32);
if s_i < 0 {
m[ring_size - (s_i.abs() as usize)] = rlwe_q.neg_one();
} else {
m[(s_i as usize)] = 1;
}
// // RLWE'(-sm)
// gadget_vec_a.iter().enumerate().for_each(|(index, beta)| {
// // RLWE(\beta -sm)
let mut neg_sm = m.clone();
nttop.forward(&mut neg_sm);
rlwe_modop.elwise_mul_mut(&mut neg_sm, &neg_s_poly_eval);
nttop.backward(&mut neg_sm);
// // \beta * -sX^[lwe_s[i]]
// let mut beta_neg_sm = neg_sm.clone();
// rlwe_modop.elwise_scalar_mul_mut(&mut beta_neg_sm, beta);
// RLWE'(-sm)
gadget_vec_a.iter().enumerate().for_each(|(index, beta)| {
// RLWE(\beta -sm)
// // extract RLWE(-sm \beta)
// let mut rlwe = vec![vec![0u64; ring_size]; 2];
// rlwe[0].copy_from_slice(&ct[index]);
// rlwe[1].copy_from_slice(&ct[index + d_a]);
// \beta * -sX^[lwe_s[i]]
let mut beta_neg_sm = neg_sm.clone();
rlwe_modop.elwise_scalar_mul_mut(&mut beta_neg_sm, beta);
// // decrypt
// let mut m_out = vec![0u64; ring_size];
// decrypt_rlwe(&rlwe, &ideal_rlwe, &mut m_out, nttop,
// rlwe_modop); // println!("{:?}", &beta_neg_sm);
// extract RLWE(-sm \beta)
let mut rlwe = vec![vec![0u64; ring_size]; 2];
rlwe[0].copy_from_slice(&ct[index]);
rlwe[1].copy_from_slice(&ct[index + d_a]);
// send back to coefficient domain
rlwe.iter_rows_mut()
.for_each(|r| nttop.backward(r.as_mut_slice()));
// let mut diff = m_out;
// rlwe_modop.elwise_sub_mut(&mut diff, &beta_neg_sm);
// decrypt
let mut m_out = vec![0u64; ring_size];
decrypt_rlwe(&rlwe, &ideal_rlwe, &mut m_out, nttop, rlwe_modop); // println!("{:?}", &beta_neg_sm);
// stats.add_more(&Vec::<i64>::try_convert_from(&diff, rlwe_q));
// });
// });
let mut diff = m_out;
rlwe_modop.elwise_sub_mut(&mut diff, &beta_neg_sm);
// println!("Stats: {}", stats.std_dev().abs().log2());
stats.add_more(&Vec::<i64>::try_convert_from(&diff, rlwe_q));
});
});
}
println!("Stats: {}", stats.std_dev().abs().log2());
}
}
// let (rgswrgsw_d_a, rgswrgsw_d_b) =
// self.pbs_info.parameters.rgsw_rgsw_decomposition_count(); let (rlrg_d_a,
// rlrg_d_b) = self.pbs_info.parameters.rlwe_rgsw_decomposition_count(); let
// rgsw_ct_rows_in = rgswrgsw_d_a.0 * 2 + rgswrgsw_d_b.0 * 2;
// let rgsw_ct_rows_out = rlrg_d_a.0 * 2 + rlrg_d_b.0 * 2;
// assert!(rgswrgsw_d_a.0 >= rlrg_d_a.0, "RGSWxRGSW part A decomposition count {} must be >= RLWExRGSW part A decomposition count {}", rgswrgsw_d_a.0 , rlrg_d_a.0);
// assert!(rgswrgsw_d_b.0 >= rlrg_d_b.0, "RGSWxRGSW part B decomposition count {} must be >= RLWExRGSW part B decomposition count {}", rgswrgsw_d_b.0 , rlrg_d_b.0);
// let rgsw_cts = rgsw_cts
// .map(|ct_i_in| {
// assert!(ct_i_in.dimension() == (rgsw_ct_rows_in, rlwe_n));
// let mut reduced_ct_i_out = M::zeros(rgsw_ct_rows_out, rlwe_n);
// // RLWE'(-sm) part A
// izip!(
// reduced_ct_i_out.iter_rows_mut().take(rlrg_d_a.0),
// ct_i_in
// .iter_rows()
// .skip(rgswrgsw_d_a.0 - rlrg_d_a.0)
// .take(rlrg_d_a.0)
// )
// .for_each(|(to_ri, from_ri)| {
// to_ri.as_mut().copy_from_slice(from_ri.as_ref());
// });
// // RLWE'(-sm) part B
// izip!(
// reduced_ct_i_out
// .iter_rows_mut()
// .skip(rlrg_d_a.0)
// .take(rlrg_d_a.0),
// ct_i_in
// .iter_rows()
// .skip(rgswrgsw_d_a.0 + (rgswrgsw_d_a.0 - rlrg_d_a.0))
// .take(rlrg_d_a.0)
// )
// .for_each(|(to_ri, from_ri)| {
// to_ri.as_mut().copy_from_slice(from_ri.as_ref());
// });
// // RLWE'(m) Part A
// izip!(
// reduced_ct_i_out
// .iter_rows_mut()
// .skip(rlrg_d_a.0 * 2)
// .take(rlrg_d_b.0),
// ct_i_in
// .iter_rows()
// .skip(rgswrgsw_d_a.0 * 2 + (rgswrgsw_d_b.0 - rlrg_d_b.0))
// .take(rlrg_d_b.0)
// )
// .for_each(|(to_ri, from_ri)| {
// to_ri.as_mut().copy_from_slice(from_ri.as_ref());
// });
// // RLWE'(m) Part B
// izip!(
// reduced_ct_i_out
// .iter_rows_mut()
// .skip(rlrg_d_a.0 * 2 + rlrg_d_b.0)
// .take(rlrg_d_b.0),
// ct_i_in
// .iter_rows()
// .skip(rgswrgsw_d_a.0 * 2 + rgswrgsw_d_b.0 +
// (rgswrgsw_d_b.0 - rlrg_d_b.0)) .take(rlrg_d_b.0)
// )
// .for_each(|(to_ri, from_ri)| {
// to_ri.as_mut().copy_from_slice(from_ri.as_ref());
// });
// reduced_ct_i_out
// })
// .collect_vec();

+ 11
- 1
src/bool/keys.rs

@ -740,6 +740,12 @@ pub(super) mod impl_non_interactive_server_key_eval_domain {
use super::*;
impl<M, P, R, N> NonInteractiveServerKeyEvaluationDomain<M, P, R, N> {
pub(in super::super) fn rgsw_cts(&self) -> &[M] {
&self.rgsw_cts
}
}
impl<M, Rng, N>
From<
SeededNonInteractiveMultiPartyServerKey<
@ -836,6 +842,10 @@ pub(super) mod impl_non_interactive_server_key_eval_domain {
)
});
// copy over part bs
assert!(
value.lwe_ksk.as_ref().len()
== value.parameters.lwe_decomposition_count().0 * ring_size
);
izip!(value.lwe_ksk.as_ref().iter(), lwe_ksk.iter_rows_mut()).for_each(
|(b_el, lwe_ct)| {
lwe_ct.as_mut()[0] = *b_el;
@ -865,7 +875,7 @@ pub(super) mod impl_non_interactive_server_key_eval_domain {
let incoming_ksk_partb_ref =
&value.ui_to_s_ksks[value.ui_to_s_ksks_key_order[user_index]];
assert!(ksk_ct.dimension() == (d_uitos, ring_size));
assert!(incoming_ksk_partb_ref.dimension() == (d_uitos, ring_size));
izip!(
ksk_ct.iter_rows_mut().skip(d_uitos),
incoming_ksk_partb_ref.iter_rows()

+ 1
- 1
src/bool/parameters.rs

@ -504,7 +504,7 @@ pub(crate) const SMALL_MP_BOOL_PARAMS: BoolParameters = BoolParameters::
pub(crate) const NON_INTERACTIVE_SMALL_MP_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u64> {
rlwe_q: CiphertextModulus::new_non_native(36028797018820609),
lwe_q: CiphertextModulus::new_non_native(1 << 20),
br_q: 1 << 12,
br_q: 1 << 11,
rlwe_n: PolynomialSize(1 << 11),
lwe_n: LweDimension(10),
lwe_decomposer_params: (DecompostionLogBase(4), DecompositionCount(5)),

Loading…
Cancel
Save