From 5d5100e6d1ac1a08b93f0ab84031a08da4f149c2 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Sun, 23 Jun 2024 15:26:15 +0700 Subject: [PATCH] move multi-party crs to puncturing --- src/bool/evaluator.rs | 210 ++++++++++++++++++++---------------------- src/bool/keys.rs | 74 +++++++++------ src/bool/mp_api.rs | 21 +++-- src/bool/noise.rs | 11 +-- src/decomposer.rs | 3 +- 5 files changed, 164 insertions(+), 155 deletions(-) diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index f3452d5..5166fc3 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -64,6 +64,20 @@ use super::{ }, }; +/// Common reference seed used for Interactive multi-party, +/// +/// Seeds for public key shares and differents parts of server key shares are +/// derived from common reference seed with different puncture rountines. +/// +/// ## Punctures +/// +/// Initial Seed: +/// Puncture 1 -> Public key share seed +/// Puncture 2 -> Main server key share seed +/// Puncture 1 -> RGSW cuphertexts seed +/// Puncture 2 -> Auto keys cipertexts seed +/// Puncture 3 -> LWE ksk seed +#[derive(Clone, PartialEq)] pub struct MultiPartyCrs { pub(super) seed: S, } @@ -77,6 +91,34 @@ impl MultiPartyCrs<[u8; 32]> { }) } } +impl MultiPartyCrs { + /// Seed to generate public key share + fn public_key_share_seed + RandomFill>(&self) -> S { + let mut prng = Rng::new_with_seed(self.seed); + puncture_p_rng(&mut prng, 1) + } + + /// Main server key share seed + fn key_seed + RandomFill>(&self) -> S { + let mut prng = Rng::new_with_seed(self.seed); + puncture_p_rng(&mut prng, 2) + } + + pub(super) fn rgsw_cts_seed + RandomFill>(&self) -> S { + let mut key_prng = Rng::new_with_seed(self.key_seed::()); + puncture_p_rng(&mut key_prng, 1) + } + + pub(super) fn auto_keys_cts_seed + RandomFill>(&self) -> S { + let mut key_prng = Rng::new_with_seed(self.key_seed::()); + puncture_p_rng(&mut key_prng, 2) + } + + pub(super) fn lwe_ksk_cts_seed_seed + RandomFill>(&self) -> S { + let mut key_prng = Rng::new_with_seed(self.key_seed::()); + puncture_p_rng(&mut key_prng, 3) + } +} /// Common reference seed used for non-interactive multi-party. /// @@ -99,20 +141,17 @@ impl NonInteractiveMultiPartyCrs { } pub(crate) fn rgsw_cts_seed + RandomFill>(&self) -> S { - let key_seed = self.key_seed::(); - let mut p_rng = R::new_with_seed(key_seed); + let mut p_rng = R::new_with_seed(self.key_seed::()); puncture_p_rng(&mut p_rng, 1) } pub(crate) fn auto_keys_cts_seed + RandomFill>(&self) -> S { - let key_seed = self.key_seed::(); - let mut p_rng = R::new_with_seed(key_seed); + let mut p_rng = R::new_with_seed(self.key_seed::()); puncture_p_rng(&mut p_rng, 2) } pub(crate) fn lwe_ksk_cts_seed + RandomFill>(&self) -> S { - let key_seed = self.key_seed::(); - let mut p_rng = R::new_with_seed(key_seed); + let mut p_rng = R::new_with_seed(self.key_seed::()); puncture_p_rng(&mut p_rng, 3) } @@ -132,33 +171,6 @@ impl NonInteractiveMultiPartyCrs { } } -impl MultiPartyCrs { - /// Seed to generate public key share using MultiPartyCrs as the main seed. - /// - /// Public key seed equals the 1st seed extracted from PRNG Seeded with - /// MiltiPartyCrs's seed. - pub(super) fn public_key_share_seed + RandomFill>(&self) -> S { - let mut prng = Rng::new_with_seed(self.seed); - - let mut seed = S::default(); - RandomFill::::random_fill(&mut prng, &mut seed); - seed - } - - /// Seed to generate server key share using MultiPartyCrs as the main seed. - /// - /// Server key seed equals the 2nd seed extracted from PRNG Seeded with - /// MiltiPartyCrs's seed. - pub(super) fn server_key_share_seed + RandomFill>(&self) -> S { - let mut prng = Rng::new_with_seed(self.seed); - - let mut seed = S::default(); - RandomFill::::random_fill(&mut prng, &mut seed); - RandomFill::::random_fill(&mut prng, &mut seed); - seed - } -} - pub(crate) trait BooleanGates { type Ciphertext: RowEntity; type Key; @@ -788,61 +800,43 @@ where pub(super) fn multi_party_server_key_share>( &self, - cr_seed: [u8; 32], + cr_seed: &MultiPartyCrs<[u8; 32]>, collective_pk: &M, client_key: &K, - ) -> CommonReferenceSeededMultiPartyServerKeyShare, [u8; 32]> - { + ) -> CommonReferenceSeededMultiPartyServerKeyShare< + M, + BoolParameters, + MultiPartyCrs<[u8; 32]>, + > { assert_eq!(self.parameters().variant(), &ParameterVariant::MultiParty); + // let user_id = 0; - DefaultSecureRng::with_local_mut(|rng| { - let mut main_prng = DefaultSecureRng::new_seeded(cr_seed); + // let user_segment_start = 0; + // let user_segment_end = 1; - let sk_rlwe = client_key.sk_rlwe(); - let sk_lwe = client_key.sk_lwe(); + let sk_rlwe = client_key.sk_rlwe(); + let sk_lwe = client_key.sk_lwe(); - let g = self.pbs_info.parameters.g(); - let ring_size = self.pbs_info.parameters.rlwe_n().0; - let rlwe_q = self.pbs_info.parameters.rlwe_q(); - let lwe_q = self.pbs_info.parameters.lwe_q(); + let g = self.pbs_info.parameters.g(); + let ring_size = self.pbs_info.parameters.rlwe_n().0; + let rlwe_q = self.pbs_info.parameters.rlwe_q(); + let lwe_q = self.pbs_info.parameters.lwe_q(); - let rlweq_modop = &self.pbs_info.rlwe_modop; - let rlweq_nttop = &self.pbs_info.rlwe_nttop; + let rlweq_modop = &self.pbs_info.rlwe_modop; + let rlweq_nttop = &self.pbs_info.rlwe_nttop; - // sanity check - assert!(sk_rlwe.len() == ring_size); - assert!(sk_lwe.len() == self.pbs_info.parameters.lwe_n().0); + // sanity check + assert!(sk_rlwe.len() == ring_size); + assert!(sk_lwe.len() == self.pbs_info.parameters.lwe_n().0); - // auto keys - let mut auto_keys = HashMap::new(); - let auto_gadget = self.pbs_info.auto_decomposer.gadget_vector(); - let auto_element_dlogs = self.pbs_info.parameters.auto_element_dlogs(); - let br_q = self.pbs_info.parameters.br_q(); - for i in auto_element_dlogs.into_iter() { - let g_pow = if i == 0 { - -(g as isize) - } else { - (g.pow(i as u32) % br_q) as isize - }; - - let mut ksk_out = M::zeros( - self.pbs_info.auto_decomposer.decomposition_count(), - ring_size, - ); - galois_key_gen( - &mut ksk_out, - &sk_rlwe, - g_pow, - &auto_gadget, - rlweq_modop, - rlweq_nttop, - &mut main_prng, - rng, - ); - auto_keys.insert(i, ksk_out); - } + // auto keys + let auto_keys = self._common_rountine_multi_party_auto_keys_share_gen( + cr_seed.auto_keys_cts_seed::(), + &sk_rlwe, + ); - // rgsw ciphertexts of lwe secret elements + // rgsw ciphertexts of lwe secret elements + let rgsw_cts = DefaultSecureRng::with_local_mut(|rng| { let rgsw_rgsw_decomposer = self .pbs_info .parameters @@ -888,30 +882,23 @@ where out_rgsw }) .collect_vec(); + rgsw_cts + }); - // LWE ksk - let mut lwe_ksk = - M::R::zeros(self.pbs_info.lwe_decomposer.decomposition_count() * ring_size); - let lwe_modop = &self.pbs_info.lwe_modop; - let d_lwe_gadget_vec = self.pbs_info.lwe_decomposer.gadget_vector(); - lwe_ksk_keygen( - &sk_rlwe, - &sk_lwe, - &mut lwe_ksk, - &d_lwe_gadget_vec, - lwe_modop, - &mut main_prng, - rng, - ); + // LWE Ksk + let lwe_ksk = self._common_rountine_multi_party_lwe_ksk_share_gen( + cr_seed.lwe_ksk_cts_seed_seed::(), + &sk_rlwe, + &sk_lwe, + ); - CommonReferenceSeededMultiPartyServerKeyShare::new( - rgsw_cts, - auto_keys, - lwe_ksk, - cr_seed, - self.pbs_info.parameters.clone(), - ) - }) + CommonReferenceSeededMultiPartyServerKeyShare::new( + rgsw_cts, + auto_keys, + lwe_ksk, + cr_seed.clone(), + self.pbs_info.parameters.clone(), + ) } pub(super) fn aggregate_non_interactive_multi_party_key_share( @@ -1657,7 +1644,7 @@ where pub(super) fn multi_party_public_key_share>( &self, - cr_seed: [u8; 32], + cr_seed: &MultiPartyCrs<[u8; 32]>, client_key: &K, ) -> CommonReferenceSeededCollectivePublicKeyShare< ::R, @@ -1668,7 +1655,8 @@ where let mut share_out = M::R::zeros(self.pbs_info.parameters.rlwe_n().0); let modop = &self.pbs_info.rlwe_modop; let nttop = &self.pbs_info.rlwe_nttop; - let mut main_prng = DefaultSecureRng::new_seeded(cr_seed); + let pk_seed = cr_seed.public_key_share_seed::(); + let mut main_prng = DefaultSecureRng::new_seeded(pk_seed); public_key_share( &mut share_out, &client_key.sk_rlwe(), @@ -1679,7 +1667,7 @@ where ); CommonReferenceSeededCollectivePublicKeyShare::new( share_out, - cr_seed, + pk_seed, self.pbs_info.parameters.clone(), ) }) @@ -1852,9 +1840,9 @@ where shares: &[CommonReferenceSeededMultiPartyServerKeyShare< M, BoolParameters, - S, + MultiPartyCrs, >], - ) -> SeededMultiPartyServerKey> + ) -> SeededMultiPartyServerKey, BoolParameters> where S: PartialEq + Clone, M: Clone, @@ -2256,6 +2244,8 @@ mod tests { .map(|_| bool_evaluator.client_key()) .collect_vec(); + let int_mp_seed = MultiPartyCrs::random(); + let mut ideal_rlwe_sk = vec![0i32; bool_evaluator.pbs_info.rlwe_n()]; parties.iter().for_each(|k| { izip!( @@ -2287,7 +2277,7 @@ mod tests { rng.fill_bytes(&mut pk_cr_seed); let public_key_share = parties .iter() - .map(|k| bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) + .map(|k| bool_evaluator.multi_party_public_key_share(&int_mp_seed, k)) .collect_vec(); let collective_pk = PublicKey::< Vec>, @@ -2331,7 +2321,7 @@ mod tests { rng.fill_bytes(&mut pk_cr_seed); let public_key_share = parties .iter() - .map(|k| bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) + .map(|k| bool_evaluator.multi_party_public_key_share(&int_mp_seed, k)) .collect_vec(); let collective_pk = PublicKey::< Vec>, @@ -2344,7 +2334,11 @@ mod tests { let server_key_shares = parties .iter() .map(|k| { - bool_evaluator.multi_party_server_key_share(pbs_cr_seed, collective_pk.key(), k) + bool_evaluator.multi_party_server_key_share( + &int_mp_seed, + collective_pk.key(), + k, + ) }) .collect_vec(); diff --git a/src/bool/keys.rs b/src/bool/keys.rs index a4223d4..99bc03f 100644 --- a/src/bool/keys.rs +++ b/src/bool/keys.rs @@ -145,6 +145,8 @@ pub struct PublicKey { } pub(super) mod impl_pk { + use crate::evaluator::MultiPartyCrs; + use super::*; impl PublicKey { @@ -462,8 +464,10 @@ pub(super) mod impl_server_key_eval_domain { use itertools::{izip, Itertools}; use crate::{ + evaluator::MultiPartyCrs, ntt::{Ntt, NttInit}, pbs::PbsKey, + random::RandomFill, }; use super::*; @@ -610,16 +614,22 @@ pub(super) mod impl_server_key_eval_domain { M: MatrixMut + MatrixEntity, Rng: NewWithSeed, N: NttInit> + Ntt, - > From<&SeededMultiPartyServerKey>> + > + From<&SeededMultiPartyServerKey, BoolParameters>> for ServerKeyEvaluationDomain, Rng, N> where ::R: RowMut, - Rng::Seed: Copy, - Rng: RandomFillUniformInModulus<[M::MatElement], CiphertextModulus>, + Rng::Seed: Copy + Default, + Rng: RandomFillUniformInModulus<[M::MatElement], CiphertextModulus> + + RandomFill, M::MatElement: Copy, { fn from( - value: &SeededMultiPartyServerKey>, + value: &SeededMultiPartyServerKey< + M, + MultiPartyCrs, + BoolParameters, + >, ) -> Self { let g = value.parameters.g() as isize; let rlwe_n = value.parameters.rlwe_n().0; @@ -627,37 +637,42 @@ pub(super) mod impl_server_key_eval_domain { let rlwe_q = value.parameters.rlwe_q(); let lwe_q = value.parameters.lwe_q(); - let mut main_prng = Rng::new_with_seed(value.cr_seed); - let rlwe_nttop = N::new(rlwe_q, rlwe_n); // auto keys let mut auto_keys = HashMap::new(); - let auto_d_count = value.parameters.auto_decomposition_count().0; - let auto_element_dlogs = value.parameters.auto_element_dlogs(); - for i in auto_element_dlogs.into_iter() { - let mut key = M::zeros(auto_d_count * 2, rlwe_n); - - // sample a - key.iter_rows_mut().take(auto_d_count).for_each(|ri| { - RandomFillUniformInModulus::random_fill(&mut main_prng, &rlwe_q, ri.as_mut()) - }); + { + let mut auto_prng = Rng::new_with_seed(value.cr_seed.auto_keys_cts_seed::()); + let auto_d_count = value.parameters.auto_decomposition_count().0; + let auto_element_dlogs = value.parameters.auto_element_dlogs(); + for i in auto_element_dlogs.into_iter() { + let mut key = M::zeros(auto_d_count * 2, rlwe_n); + + // sample a + key.iter_rows_mut().take(auto_d_count).for_each(|ri| { + RandomFillUniformInModulus::random_fill( + &mut auto_prng, + &rlwe_q, + ri.as_mut(), + ) + }); - let key_part_b = value.auto_keys.get(&i).unwrap(); - assert!(key_part_b.dimension() == (auto_d_count, rlwe_n)); - izip!( - key.iter_rows_mut().skip(auto_d_count), - key_part_b.iter_rows() - ) - .for_each(|(to_ri, from_ri)| { - to_ri.as_mut().copy_from_slice(from_ri.as_ref()); - }); + let key_part_b = value.auto_keys.get(&i).unwrap(); + assert!(key_part_b.dimension() == (auto_d_count, rlwe_n)); + izip!( + key.iter_rows_mut().skip(auto_d_count), + key_part_b.iter_rows() + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); - // send to evaluation domain - key.iter_rows_mut() - .for_each(|ri| rlwe_nttop.forward(ri.as_mut())); + // send to evaluation domain + key.iter_rows_mut() + .for_each(|ri| rlwe_nttop.forward(ri.as_mut())); - auto_keys.insert(i, key); + auto_keys.insert(i, key); + } } // rgsw cts @@ -682,12 +697,13 @@ pub(super) mod impl_server_key_eval_domain { .collect_vec(); // lwe ksk + let mut lwe_ksk_prng = Rng::new_with_seed(value.cr_seed.lwe_ksk_cts_seed_seed::()); let d_lwe = value.parameters.lwe_decomposition_count().0; let mut lwe_ksk = M::zeros(rlwe_n * d_lwe, lwe_n + 1); izip!(lwe_ksk.iter_rows_mut(), value.lwe_ksk.as_ref().iter()).for_each( |(lwe_i, bi)| { RandomFillUniformInModulus::random_fill( - &mut main_prng, + &mut lwe_ksk_prng, &lwe_q, &mut lwe_i.as_mut()[1..], ); diff --git a/src/bool/mp_api.rs b/src/bool/mp_api.rs index 7d745b7..f77ec09 100644 --- a/src/bool/mp_api.rs +++ b/src/bool/mp_api.rs @@ -1,4 +1,4 @@ -use std::{cell::RefCell, sync::OnceLock}; +use std::{cell::RefCell, ops::Mul, sync::OnceLock}; use crate::{ backend::{ModularOpsU64, ModulusPowerOf2}, @@ -50,9 +50,8 @@ pub fn gen_client_key() -> ClientKey { pub fn gen_mp_keys_phase1( ck: &ClientKey, ) -> CommonReferenceSeededCollectivePublicKeyShare, [u8; 32], BoolParameters> { - let seed = MultiPartyCrs::global().public_key_share_seed::(); BoolEvaluator::with_local(|e| { - let pk_share = e.multi_party_public_key_share(seed, ck); + let pk_share = e.multi_party_public_key_share(MultiPartyCrs::global(), ck); pk_share }) } @@ -60,10 +59,14 @@ pub fn gen_mp_keys_phase1( pub fn gen_mp_keys_phase2( ck: &ClientKey, pk: &PublicKey>, R, ModOp>, -) -> CommonReferenceSeededMultiPartyServerKeyShare>, BoolParameters, [u8; 32]> { - let seed = MultiPartyCrs::global().server_key_share_seed::(); +) -> CommonReferenceSeededMultiPartyServerKeyShare< + Vec>, + BoolParameters, + MultiPartyCrs<[u8; 32]>, +> { BoolEvaluator::with_local_mut(|e| { - let server_key_share = e.multi_party_server_key_share(seed, pk.key(), ck); + let server_key_share = + e.multi_party_server_key_share(MultiPartyCrs::global(), pk.key(), ck); server_key_share }) } @@ -82,16 +85,16 @@ pub fn aggregate_server_key_shares( shares: &[CommonReferenceSeededMultiPartyServerKeyShare< Vec>, BoolParameters, - [u8; 32], + MultiPartyCrs<[u8; 32]>, >], -) -> SeededMultiPartyServerKey>, [u8; 32], BoolParameters> { +) -> SeededMultiPartyServerKey>, MultiPartyCrs<[u8; 32]>, BoolParameters> { BoolEvaluator::with_local(|e| e.aggregate_multi_party_server_key_shares(shares)) } impl SeededMultiPartyServerKey< Vec>, - ::Seed, + MultiPartyCrs<::Seed>, BoolParameters, > { diff --git a/src/bool/noise.rs b/src/bool/noise.rs index 1354a42..dcf87e3 100644 --- a/src/bool/noise.rs +++ b/src/bool/noise.rs @@ -11,6 +11,7 @@ mod test { }, parameters::{CiphertextModulus, SMALL_MP_BOOL_PARAMS}, }, + evaluator::MultiPartyCrs, ntt::NttBackendU64, random::DefaultSecureRng, }; @@ -28,11 +29,7 @@ mod test { let parties = 2; - let mut rng = DefaultSecureRng::new(); - let mut pk_cr_seed = [0u8; 32]; - let mut bk_cr_seed = [0u8; 32]; - rng.fill_bytes(&mut pk_cr_seed); - rng.fill_bytes(&mut bk_cr_seed); + let cr_seed = MultiPartyCrs::random(); let cks = (0..parties) .into_iter() @@ -64,7 +61,7 @@ mod test { // round 1 let pk_shares = cks .iter() - .map(|c| evaluator.multi_party_public_key_share(pk_cr_seed, c)) + .map(|c| evaluator.multi_party_public_key_share(&cr_seed, c)) .collect_vec(); // public key @@ -75,7 +72,7 @@ mod test { // round 2 let server_key_shares = cks .iter() - .map(|c| evaluator.multi_party_server_key_share(bk_cr_seed, &pk.key(), c)) + .map(|c| evaluator.multi_party_server_key_share(&cr_seed, &pk.key(), c)) .collect_vec(); let server_key = evaluator.aggregate_multi_party_server_key_shares(&server_key_shares); diff --git a/src/decomposer.rs b/src/decomposer.rs index 5592f09..599678a 100644 --- a/src/decomposer.rs +++ b/src/decomposer.rs @@ -298,14 +298,13 @@ mod tests { let d = 3; let mut stats = vec![Stats::new(); d]; - for i in [true] { + for i in [true, false] { let q = if i { generate_prime(logq, 2 * ring_size, 1u64 << logq).unwrap() } else { 1u64 << logq }; let decomposer = DefaultDecomposer::new(q, logb, d); - dbg!(decomposer.ignore_bits); let modq_op = ModularOpsU64::new(q); for _ in 0..1000000 { let value = rng.gen_range(0..q);