From 6776391395c8d875fcb36623dc6b98bfeb782247 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Fri, 5 Jul 2024 10:46:16 +0530 Subject: [PATCH] add interactive_mp_bool_gates inside print noise --- src/bool/mp_api.rs | 121 +---------------------------------- src/bool/print_noise.rs | 136 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 133 insertions(+), 124 deletions(-) diff --git a/src/bool/mp_api.rs b/src/bool/mp_api.rs index 4791033..f0dca1e 100644 --- a/src/bool/mp_api.rs +++ b/src/bool/mp_api.rs @@ -310,129 +310,10 @@ mod tests { use itertools::Itertools; use rand::{thread_rng, Rng, RngCore}; - use crate::{ - backend::Modulus, - bool::{ - evaluator::BoolEncoding, - keys::tests::{ideal_sk_rlwe, measure_noise_lwe}, - BooleanGates, - }, - lwe::decrypt_lwe, - utils::tests::Stats, - Encoder, Encryptor, MultiPartyDecryptor, SampleExtractor, - }; + use crate::{bool::evaluator::BoolEncoding, Encryptor, MultiPartyDecryptor, SampleExtractor}; use super::*; - #[test] - fn multi_party_bool_gates() { - set_parameter_set(ParameterSelector::InteractiveLTE8Party); - let mut seed = [0u8; 32]; - thread_rng().fill_bytes(&mut seed); - set_common_reference_seed(seed); - - let parties = 8; - let cks = (0..parties).map(|_| gen_client_key()).collect_vec(); - - // round 1 - let pk_shares = cks - .iter() - .map(|k| interactive_multi_party_round1_share(k)) - .collect_vec(); - - // collective pk - let pk = aggregate_public_key_shares(&pk_shares); - - // round 2 - let server_key_shares = cks - .iter() - .enumerate() - .map(|(user_id, k)| gen_mp_keys_phase2(k, user_id, parties, &pk)) - .collect_vec(); - - // server key - let server_key = aggregate_server_key_shares(&server_key_shares); - server_key.set_server_key(); - - let mut m0 = false; - let mut m1 = true; - - let mut ct0 = pk.encrypt(&m0); - let mut ct1 = pk.encrypt(&m1); - - let ideal_sk_rlwe = ideal_sk_rlwe(&cks); - let parameters = BoolEvaluator::with_local(|e| e.parameters().clone()); - let rlwe_modop = parameters.default_rlwe_modop(); - - let mut stats = Stats::new(); - - for _ in 0..1000 { - let now = std::time::Instant::now(); - let ct_out = - BoolEvaluator::with_local_mut(|e| e.nand(&ct0, &ct1, RuntimeServerKey::global())); - println!("Time: {:?}", now.elapsed()); - - let m_expected = !(m0 && m1); - - let decryption_shares = cks - .iter() - .map(|k| k.gen_decryption_share(&ct_out)) - .collect_vec(); - let m_out = cks[0].aggregate_decryption_shares(&ct_out, &decryption_shares); - - assert!(m_out == m_expected, "Expected {m_expected}, got {m_out}"); - - { - let noise = measure_noise_lwe( - &ct_out, - parameters.rlwe_q().encode(m_expected), - &ideal_sk_rlwe, - &rlwe_modop, - ); - stats.add_sample(parameters.rlwe_q().map_element_to_i64(&noise)); - } - - m1 = m0; - m0 = m_expected; - - ct1 = ct0; - ct0 = ct_out; - } - - for _ in 0..1000 { - let ct_out = - BoolEvaluator::with_local_mut(|e| e.xnor(&ct0, &ct1, RuntimeServerKey::global())); - - let m_expected = !(m0 ^ m1); - - let decryption_shares = cks - .iter() - .map(|k| k.gen_decryption_share(&ct_out)) - .collect_vec(); - let m_out = cks[0].aggregate_decryption_shares(&ct_out, &decryption_shares); - - assert!(m_out == m_expected, "Expected {m_expected}, got {m_out}"); - - { - let noise = measure_noise_lwe( - &ct_out, - parameters.rlwe_q().encode(m_expected), - &ideal_sk_rlwe, - &rlwe_modop, - ); - stats.add_sample(parameters.rlwe_q().map_element_to_i64(&noise)); - } - - m1 = m0; - m0 = m_expected; - - ct1 = ct0; - ct0 = ct_out; - } - - println!("Noise std_dev log2: {}", stats.std_dev().abs().log2()); - } - #[test] fn batched_fhe_u8s_extract_works() { set_parameter_set(ParameterSelector::InteractiveLTE2Party); diff --git a/src/bool/print_noise.rs b/src/bool/print_noise.rs index 064835e..eda0b83 100644 --- a/src/bool/print_noise.rs +++ b/src/bool/print_noise.rs @@ -473,11 +473,139 @@ mod tests { ); } - const K: usize = 100000; + const K: usize = 10; - // #[test] - // #[cfg(feature = "interactive_mp")] - // fn interactive_mp_bool_gates() {} + #[test] + #[cfg(feature = "interactive_mp")] + fn interactive_mp_bool_gates() { + use rand::{thread_rng, RngCore}; + + use crate::{ + aggregate_public_key_shares, aggregate_server_key_shares, + backend::Modulus, + bool::{ + keys::{ + tests::{ideal_sk_rlwe, measure_noise_lwe}, + ServerKeyEvaluationDomain, + }, + print_noise::collect_server_key_stats, + }, + gen_client_key, gen_mp_keys_phase2, interactive_multi_party_round1_share, + parameters::CiphertextModulus, + random::DefaultSecureRng, + set_common_reference_seed, set_parameter_set, + utils::{tests::Stats, Global, WithLocal}, + BoolEvaluator, BooleanGates, DefaultDecomposer, Encoder, Encryptor, ModInit, + ModularOpsU64, MultiPartyDecryptor, NttBackendU64, ParameterSelector, RuntimeServerKey, + }; + + set_parameter_set(ParameterSelector::InteractiveLTE8Party); + + let mut seed = [0u8; 32]; + thread_rng().fill_bytes(&mut seed); + set_common_reference_seed(seed); + + let no_of_parties = 8; + + let cks = (0..no_of_parties).map(|_| gen_client_key()).collect_vec(); + + // round 1 + let pk_shares = cks + .iter() + .map(|k| interactive_multi_party_round1_share(k)) + .collect_vec(); + + let pk = aggregate_public_key_shares(&pk_shares); + + // round 2 + let server_key_shares = cks + .iter() + .enumerate() + .map(|(user_id, k)| gen_mp_keys_phase2(k, user_id, no_of_parties, &pk)) + .collect_vec(); + + let server_key = aggregate_server_key_shares(&server_key_shares); + server_key.set_server_key(); + + let mut m0 = false; + let mut m1 = true; + + let mut ct0 = pk.encrypt(&m0); + let mut ct1 = pk.encrypt(&m1); + + let ideal_sk_rlwe = ideal_sk_rlwe(&cks); + let parameters = BoolEvaluator::with_local(|e| e.parameters().clone()); + let rlwe_modop = ModularOpsU64::new(*parameters.rlwe_q()); + + let mut stats = Stats::new(); + + for _ in 0..K { + // let now = std::time::Instant::now(); + let ct_out = + BoolEvaluator::with_local_mut(|e| e.xor(&ct0, &ct1, RuntimeServerKey::global())); + // println!("Time: {:?}", now.elapsed()); + + let m_expected = m0 ^ m1; + + let decryption_shares = cks + .iter() + .map(|k| k.gen_decryption_share(&ct_out)) + .collect_vec(); + let m_out = cks[0].aggregate_decryption_shares(&ct_out, &decryption_shares); + + assert!(m_out == m_expected, "Expected {m_expected}, got {m_out}"); + + { + let noise = measure_noise_lwe( + &ct_out, + parameters.rlwe_q().encode(m_expected), + &ideal_sk_rlwe, + &rlwe_modop, + ); + stats.add_sample(parameters.rlwe_q().map_element_to_i64(&noise)); + } + + m1 = m0; + m0 = m_expected; + + ct1 = ct0; + ct0 = ct_out; + } + + let server_key_stats = collect_server_key_stats::< + _, + DefaultDecomposer, + NttBackendU64, + ModularOpsU64>, + _, + >( + parameters, + &cks, + &ServerKeyEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from(&server_key), + ); + + println!("## Bootstrapping Statistics ##"); + println!("Bootstrapped ciphertext noise std_dev: {}", stats.std_dev()); + + println!("## Key Statistics ##"); + println!( + "Rgsw nsm std_dev {}", + server_key_stats.brk_rgsw_cts.0.std_dev() + ); + println!( + "Rgsw m std_dev {}", + server_key_stats.brk_rgsw_cts.1.std_dev() + ); + println!( + "rlwe post 1 auto std_dev {}", + server_key_stats.post_1_auto.std_dev() + ); + println!( + "key switching noise rlwe secret s to lwe secret z std_dev {}", + server_key_stats.post_lwe_key_switch.std_dev() + ); + println!(); + } #[test] #[cfg(feature = "non_interactive_mp")]