diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index 7f748a8..c3fb8f1 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -856,13 +856,16 @@ where .iter() .map(|si| { let mut m = M::R::zeros(ring_size); - - if *si < 0 { + //TODO(Jay): It will be nice to have a function that returns polynomial + // (monomial infact!) corresponding to secret element embedded in ring X^{2N+1}. + // Save lots of mistakes where one forgest to emebed si in bigger ring. + let si = *si * (self.embedding_factor as i32); + if si < 0 { // X^{-si} = X^{2N-si} = -X^{N-si}, assuming abs(si) < N // (which it is given si is secret element) m.as_mut()[ring_size - (si.abs() as usize)] = rlwe_q - M::MatElement::one(); } else { - m.as_mut()[*si as usize] = M::MatElement::one(); + m.as_mut()[si as usize] = M::MatElement::one(); } // public key RGSW encryption has no part that can be seeded, unlike secret key @@ -1812,11 +1815,174 @@ mod tests { } #[test] - fn mp_key_correcntess() { - let bool_evaluator = - BoolEvaluator::>, u64, NttBackendU64, ModularOpsU64>::new(MP_BOOL_PARAMS); + fn multi_party_lwe_keyswitch() { + let lwe_logq = 18; + let lwe_q = 1 << lwe_logq; + let d_lwe = 4; + let logb_lwe = 4; + let lwe_gadgect_vec = gadget_vector(lwe_logq, logb_lwe, d_lwe); + let lweq_modop = ModularOpsU64::new(lwe_q); + let logp = 2; + + let from_lwe_n = 2048; + let to_lwe_n = 583; let no_of_parties = 10; + let parties_from_lwe_sk = (0..no_of_parties) + .map(|_| LweSecret::random(from_lwe_n >> 1, from_lwe_n)) + .collect_vec(); + let parties_to_lwe_sk = (0..no_of_parties) + .map(|_| LweSecret::random(to_lwe_n >> 1, to_lwe_n)) + .collect_vec(); + + // Generate Lwe KSK share + let mut rng = DefaultSecureRng::new(); + let mut ksk_seed = [0u8; 32]; + rng.fill_bytes(&mut ksk_seed); + let lwe_ksk_shares = + izip!(parties_from_lwe_sk.iter(), parties_to_lwe_sk.iter()).map(|(from_sk, to_sk)| { + let mut ksk_out = vec![0u64; from_lwe_n * d_lwe]; + let mut p_rng = DefaultSecureRng::new_seeded(ksk_seed); + lwe_ksk_keygen( + from_sk.values(), + to_sk.values(), + &mut ksk_out, + &lwe_gadgect_vec, + &lweq_modop, + &mut p_rng, + &mut rng, + ); + ksk_out + }); + + // Create collective LWE ksk + let mut sum_partb = vec![0u64; d_lwe * from_lwe_n]; + lwe_ksk_shares.for_each(|share| { + lweq_modop.elwise_add_mut(sum_partb.as_mut_slice(), share.as_slice()) + }); + let mut lwe_ksk = vec![vec![0u64; to_lwe_n + 1]; d_lwe * from_lwe_n]; + let mut p_rng = DefaultSecureRng::new_seeded(ksk_seed); + izip!(lwe_ksk.iter_mut(), sum_partb.iter()).for_each(|(lwe_i, part_bi)| { + RandomUniformDist::random_fill(&mut p_rng, &lwe_q, &mut lwe_i.as_mut_slice()[1..]); + lwe_i[0] = *part_bi; + }); + + // Collective pk + // let collective_pk = _collecitve_public_key_gen( + // lwe_q, + // &parties_from_lwe_sk + // .iter() + // .map(|s| RlweSecret { + // values: s.values.clone(), + // }) + // .collect_vec(), + // ); + + // // Encrypt m as LWE ciphertext + // let m = 1; + // let lwe_ct = { + // let nttop = NttBackendU64::new(lwe_q, from_lwe_n); + // let modop = ModularOpsU64::new(lwe_q); + // let mut rlwe_out = vec![vec![0u64]; from_lwe_n]; + // let mut m_vec = vec![0u64; from_lwe_n]; + // m_vec[0] = m; + // public_key_encrypt_rlwe( + // &mut rlwe_out, + // &collective_pk, + // &m_vec, + // &modop, + // &nttop, + // &mut rng, + // ); + // let mut lwe_ct = vec![0u64; from_lwe_n + 1]; + // sample_extract(&mut lwe_ct, &rlwe_out, &modop, 0); + // lwe_ct + // }; + // Encrypt m + let m = 1; + let mut ideal_from_lwe_sk = vec![0i32; from_lwe_n]; + parties_from_lwe_sk.iter().for_each(|k| { + izip!(ideal_from_lwe_sk.iter_mut(), k.values()).for_each(|(ideal_i, s_i)| { + *ideal_i = *ideal_i + s_i; + }); + }); + let mut lwe_ct = vec![0u64; from_lwe_n + 1]; + encrypt_lwe(&mut lwe_ct, &m, &ideal_from_lwe_sk, &lweq_modop, &mut rng); + + // Key switch + let lwe_ct_key_switched = { + let mut lwe_ct_key_switched = vec![0u64; to_lwe_n + 1]; + let decomposer = DefaultDecomposer::new(lwe_q, logb_lwe, d_lwe); + lwe_key_switch( + &mut lwe_ct_key_switched, + &lwe_ct, + &lwe_ksk, + &lweq_modop, + &decomposer, + ); + lwe_ct_key_switched + }; + + // Measure noise + let mut ideal_to_lwe_sk = vec![0i32; to_lwe_n]; + parties_to_lwe_sk.iter().for_each(|k| { + izip!(ideal_to_lwe_sk.iter_mut(), k.values()).for_each(|(ideal_i, s_i)| { + *ideal_i = *ideal_i + s_i; + }); + }); + let noise = measure_noise_lwe(&lwe_ct_key_switched, &ideal_to_lwe_sk, &lweq_modop, &m); + println!("Noise: {noise}"); + } + + fn _collecitve_public_key_gen(rlwe_q: u64, parties_rlwe_sk: &[RlweSecret]) -> Vec> { + let ring_size = parties_rlwe_sk[0].values.len(); + assert!(ring_size.is_power_of_two()); + let mut rng = DefaultSecureRng::new(); + let nttop = NttBackendU64::new(rlwe_q, ring_size); + let modop = ModularOpsU64::new(rlwe_q); + + // Generate Pk shares + let pk_seed = [0u8; 32]; + let pk_shares = parties_rlwe_sk.iter().map(|sk| { + let mut p_rng = DefaultSecureRng::new_seeded(pk_seed); + let mut share_out = vec![0u64; ring_size]; + public_key_share( + &mut share_out, + sk.values(), + &modop, + &nttop, + &mut p_rng, + &mut rng, + ); + share_out + }); + + let mut pk_part_b = vec![0u64; ring_size]; + pk_shares.for_each(|share| modop.elwise_add_mut(&mut pk_part_b, &share)); + let mut pk_part_a = vec![0u64; ring_size]; + let mut p_rng = DefaultSecureRng::new_seeded(pk_seed); + RandomUniformDist::random_fill(&mut p_rng, &rlwe_q, pk_part_a.as_mut_slice()); + + vec![pk_part_a, pk_part_b] + } + + fn _multi_party_keygen( + bool_evaluator: &BoolEvaluator>, u64, NttBackendU64, ModularOpsU64>, + no_of_parties: usize, + ) -> ( + Vec, + PublicKey>, DefaultSecureRng, ModularOpsU64>, + Vec< + CommonReferenceSeededMultiPartyServerKeyShare< + Vec>, + BoolParameters, + [u8; 32], + >, + >, + SeededMultiPartyServerKey>, [u8; 32], BoolParameters>, + ServerKeyEvaluationDomain>, DefaultSecureRng, NttBackendU64>, + ClientKey, + ) { let parties = (0..no_of_parties) .map(|_| bool_evaluator.client_key()) .collect_vec(); @@ -1871,6 +2037,24 @@ mod tests { } }; + ( + parties, + collective_pk, + server_key_shares, + seeded_server_key, + server_key_eval, + ideal_client_key, + ) + } + + #[test] + fn mp_key_correcntess() { + let bool_evaluator = + BoolEvaluator::>, u64, NttBackendU64, ModularOpsU64>::new(MP_BOOL_PARAMS); + + let (_, collective_pk, _, _, server_key_eval, ideal_client_key) = + _multi_party_keygen(&bool_evaluator, 2); + let lwe_q = bool_evaluator.parameters.lwe_q; let rlwe_q = bool_evaluator.parameters.rlwe_q; let d_rgsw = bool_evaluator.parameters.d_rgsw; @@ -1936,7 +2120,8 @@ mod tests { for i in 0..20 { // measure noise in RGSW(s[i]) - let si = ideal_client_key.sk_lwe.values[i]; + let si = + ideal_client_key.sk_lwe.values[i] * (bool_evaluator.embedding_factor as i32); let mut si_poly = vec![0u64; rlwe_n]; if si < 0 { si_poly[rlwe_n - (si.abs() as usize)] = rlwe_q - 1; @@ -1986,7 +2171,8 @@ mod tests { ); // carry_m[X] * s_i[X] - let si = ideal_client_key.sk_lwe.values[i]; + let si = + ideal_client_key.sk_lwe.values[i] * (bool_evaluator.embedding_factor as i32); let mut si_poly = vec![0u64; rlwe_n]; if si < 0 { si_poly[rlwe_n - (si.abs() as usize)] = rlwe_q - 1; @@ -2067,11 +2253,11 @@ mod tests { } #[test] - fn trial12() { + fn multi_party_nand() { let bool_evaluator = BoolEvaluator::>, u64, NttBackendU64, ModularOpsU64>::new(MP_BOOL_PARAMS); - let no_of_parties = 2; + let no_of_parties = 20; let parties = (0..no_of_parties) .map(|_| bool_evaluator.client_key()) .collect_vec(); @@ -2184,7 +2370,7 @@ mod tests { // let m_back = bool_evaluator.sk_decrypt(&lwe_out, &ideal_client_key); - assert_eq!(m_expected, m_back); + assert!(m_expected == m_back, "Expected {m_expected}, got {m_back}"); m1 = m0; m0 = m_expected; } diff --git a/src/bool/parameters.rs b/src/bool/parameters.rs index e2e68b7..f211d51 100644 --- a/src/bool/parameters.rs +++ b/src/bool/parameters.rs @@ -38,16 +38,17 @@ pub(super) const SP_BOOL_PARAMS: BoolParameters = BoolParameters:: { }; pub(super) const MP_BOOL_PARAMS: BoolParameters = BoolParameters:: { - rlwe_q: 1152921504606830593, - rlwe_logq: 60, - lwe_q: 1 << 20, - lwe_logq: 20, - br_q: 1 << 12, + rlwe_q: 18014398509404161, + rlwe_logq: 54, + lwe_q: 1 << 18, + lwe_logq: 18, + // TODO(Jay:) why does this fail when q=1<<11? + br_q: 1 << 11, rlwe_n: 1 << 11, - lwe_n: 500, - d_rgsw: 10, - logb_rgsw: 6, - d_lwe: 5, + lwe_n: 200, + d_rgsw: 5, + logb_rgsw: 10, + d_lwe: 4, logb_lwe: 4, g: 5, w: 1, @@ -59,7 +60,7 @@ mod tests { #[test] fn find_prime() { - let bits = 60; + let bits = 54; let ring_size = 1 << 11; let prime = generate_prime(bits, ring_size * 2, 1 << bits).unwrap(); dbg!(prime);