Browse Source

move multi-party crs to puncturing

par-agg-key-shares
Janmajaya Mall 10 months ago
parent
commit
5d5100e6d1
5 changed files with 164 additions and 155 deletions
  1. +102
    -108
      src/bool/evaluator.rs
  2. +45
    -29
      src/bool/keys.rs
  3. +12
    -9
      src/bool/mp_api.rs
  4. +4
    -7
      src/bool/noise.rs
  5. +1
    -2
      src/decomposer.rs

+ 102
- 108
src/bool/evaluator.rs

@ -64,6 +64,20 @@ use super::{
},
};
/// Common reference seed used for Interactive multi-party,
///
/// Seeds for public key shares and differents parts of server key shares are
/// derived from common reference seed with different puncture rountines.
///
/// ## Punctures
///
/// Initial Seed:
/// Puncture 1 -> Public key share seed
/// Puncture 2 -> Main server key share seed
/// Puncture 1 -> RGSW cuphertexts seed
/// Puncture 2 -> Auto keys cipertexts seed
/// Puncture 3 -> LWE ksk seed
#[derive(Clone, PartialEq)]
pub struct MultiPartyCrs<S> {
pub(super) seed: S,
}
@ -77,6 +91,34 @@ impl MultiPartyCrs<[u8; 32]> {
})
}
}
impl<S: Default + Copy> MultiPartyCrs<S> {
/// Seed to generate public key share
fn public_key_share_seed<Rng: NewWithSeed<Seed = S> + RandomFill<S>>(&self) -> S {
let mut prng = Rng::new_with_seed(self.seed);
puncture_p_rng(&mut prng, 1)
}
/// Main server key share seed
fn key_seed<Rng: NewWithSeed<Seed = S> + RandomFill<S>>(&self) -> S {
let mut prng = Rng::new_with_seed(self.seed);
puncture_p_rng(&mut prng, 2)
}
pub(super) fn rgsw_cts_seed<Rng: NewWithSeed<Seed = S> + RandomFill<S>>(&self) -> S {
let mut key_prng = Rng::new_with_seed(self.key_seed::<Rng>());
puncture_p_rng(&mut key_prng, 1)
}
pub(super) fn auto_keys_cts_seed<Rng: NewWithSeed<Seed = S> + RandomFill<S>>(&self) -> S {
let mut key_prng = Rng::new_with_seed(self.key_seed::<Rng>());
puncture_p_rng(&mut key_prng, 2)
}
pub(super) fn lwe_ksk_cts_seed_seed<Rng: NewWithSeed<Seed = S> + RandomFill<S>>(&self) -> S {
let mut key_prng = Rng::new_with_seed(self.key_seed::<Rng>());
puncture_p_rng(&mut key_prng, 3)
}
}
/// Common reference seed used for non-interactive multi-party.
///
@ -99,20 +141,17 @@ impl NonInteractiveMultiPartyCrs {
}
pub(crate) fn rgsw_cts_seed<R: NewWithSeed<Seed = S> + RandomFill<S>>(&self) -> S {
let key_seed = self.key_seed::<R>();
let mut p_rng = R::new_with_seed(key_seed);
let mut p_rng = R::new_with_seed(self.key_seed::<R>());
puncture_p_rng(&mut p_rng, 1)
}
pub(crate) fn auto_keys_cts_seed<R: NewWithSeed<Seed = S> + RandomFill<S>>(&self) -> S {
let key_seed = self.key_seed::<R>();
let mut p_rng = R::new_with_seed(key_seed);
let mut p_rng = R::new_with_seed(self.key_seed::<R>());
puncture_p_rng(&mut p_rng, 2)
}
pub(crate) fn lwe_ksk_cts_seed<R: NewWithSeed<Seed = S> + RandomFill<S>>(&self) -> S {
let key_seed = self.key_seed::<R>();
let mut p_rng = R::new_with_seed(key_seed);
let mut p_rng = R::new_with_seed(self.key_seed::<R>());
puncture_p_rng(&mut p_rng, 3)
}
@ -132,33 +171,6 @@ impl NonInteractiveMultiPartyCrs {
}
}
impl<S: Default + Copy> MultiPartyCrs<S> {
/// Seed to generate public key share using MultiPartyCrs as the main seed.
///
/// Public key seed equals the 1st seed extracted from PRNG Seeded with
/// MiltiPartyCrs's seed.
pub(super) fn public_key_share_seed<Rng: NewWithSeed<Seed = S> + RandomFill<S>>(&self) -> S {
let mut prng = Rng::new_with_seed(self.seed);
let mut seed = S::default();
RandomFill::<S>::random_fill(&mut prng, &mut seed);
seed
}
/// Seed to generate server key share using MultiPartyCrs as the main seed.
///
/// Server key seed equals the 2nd seed extracted from PRNG Seeded with
/// MiltiPartyCrs's seed.
pub(super) fn server_key_share_seed<Rng: NewWithSeed<Seed = S> + RandomFill<S>>(&self) -> S {
let mut prng = Rng::new_with_seed(self.seed);
let mut seed = S::default();
RandomFill::<S>::random_fill(&mut prng, &mut seed);
RandomFill::<S>::random_fill(&mut prng, &mut seed);
seed
}
}
pub(crate) trait BooleanGates {
type Ciphertext: RowEntity;
type Key;
@ -788,61 +800,43 @@ where
pub(super) fn multi_party_server_key_share<K: InteractiveMultiPartyClientKey<Element = i32>>(
&self,
cr_seed: [u8; 32],
cr_seed: &MultiPartyCrs<[u8; 32]>,
collective_pk: &M,
client_key: &K,
) -> CommonReferenceSeededMultiPartyServerKeyShare<M, BoolParameters<M::MatElement>, [u8; 32]>
{
) -> CommonReferenceSeededMultiPartyServerKeyShare<
M,
BoolParameters<M::MatElement>,
MultiPartyCrs<[u8; 32]>,
> {
assert_eq!(self.parameters().variant(), &ParameterVariant::MultiParty);
// let user_id = 0;
DefaultSecureRng::with_local_mut(|rng| {
let mut main_prng = DefaultSecureRng::new_seeded(cr_seed);
// let user_segment_start = 0;
// let user_segment_end = 1;
let sk_rlwe = client_key.sk_rlwe();
let sk_lwe = client_key.sk_lwe();
let sk_rlwe = client_key.sk_rlwe();
let sk_lwe = client_key.sk_lwe();
let g = self.pbs_info.parameters.g();
let ring_size = self.pbs_info.parameters.rlwe_n().0;
let rlwe_q = self.pbs_info.parameters.rlwe_q();
let lwe_q = self.pbs_info.parameters.lwe_q();
let g = self.pbs_info.parameters.g();
let ring_size = self.pbs_info.parameters.rlwe_n().0;
let rlwe_q = self.pbs_info.parameters.rlwe_q();
let lwe_q = self.pbs_info.parameters.lwe_q();
let rlweq_modop = &self.pbs_info.rlwe_modop;
let rlweq_nttop = &self.pbs_info.rlwe_nttop;
let rlweq_modop = &self.pbs_info.rlwe_modop;
let rlweq_nttop = &self.pbs_info.rlwe_nttop;
// sanity check
assert!(sk_rlwe.len() == ring_size);
assert!(sk_lwe.len() == self.pbs_info.parameters.lwe_n().0);
// sanity check
assert!(sk_rlwe.len() == ring_size);
assert!(sk_lwe.len() == self.pbs_info.parameters.lwe_n().0);
// auto keys
let mut auto_keys = HashMap::new();
let auto_gadget = self.pbs_info.auto_decomposer.gadget_vector();
let auto_element_dlogs = self.pbs_info.parameters.auto_element_dlogs();
let br_q = self.pbs_info.parameters.br_q();
for i in auto_element_dlogs.into_iter() {
let g_pow = if i == 0 {
-(g as isize)
} else {
(g.pow(i as u32) % br_q) as isize
};
let mut ksk_out = M::zeros(
self.pbs_info.auto_decomposer.decomposition_count(),
ring_size,
);
galois_key_gen(
&mut ksk_out,
&sk_rlwe,
g_pow,
&auto_gadget,
rlweq_modop,
rlweq_nttop,
&mut main_prng,
rng,
);
auto_keys.insert(i, ksk_out);
}
// auto keys
let auto_keys = self._common_rountine_multi_party_auto_keys_share_gen(
cr_seed.auto_keys_cts_seed::<DefaultSecureRng>(),
&sk_rlwe,
);
// rgsw ciphertexts of lwe secret elements
// rgsw ciphertexts of lwe secret elements
let rgsw_cts = DefaultSecureRng::with_local_mut(|rng| {
let rgsw_rgsw_decomposer = self
.pbs_info
.parameters
@ -888,30 +882,23 @@ where
out_rgsw
})
.collect_vec();
rgsw_cts
});
// LWE ksk
let mut lwe_ksk =
M::R::zeros(self.pbs_info.lwe_decomposer.decomposition_count() * ring_size);
let lwe_modop = &self.pbs_info.lwe_modop;
let d_lwe_gadget_vec = self.pbs_info.lwe_decomposer.gadget_vector();
lwe_ksk_keygen(
&sk_rlwe,
&sk_lwe,
&mut lwe_ksk,
&d_lwe_gadget_vec,
lwe_modop,
&mut main_prng,
rng,
);
// LWE Ksk
let lwe_ksk = self._common_rountine_multi_party_lwe_ksk_share_gen(
cr_seed.lwe_ksk_cts_seed_seed::<DefaultSecureRng>(),
&sk_rlwe,
&sk_lwe,
);
CommonReferenceSeededMultiPartyServerKeyShare::new(
rgsw_cts,
auto_keys,
lwe_ksk,
cr_seed,
self.pbs_info.parameters.clone(),
)
})
CommonReferenceSeededMultiPartyServerKeyShare::new(
rgsw_cts,
auto_keys,
lwe_ksk,
cr_seed.clone(),
self.pbs_info.parameters.clone(),
)
}
pub(super) fn aggregate_non_interactive_multi_party_key_share(
@ -1657,7 +1644,7 @@ where
pub(super) fn multi_party_public_key_share<K: InteractiveMultiPartyClientKey<Element = i32>>(
&self,
cr_seed: [u8; 32],
cr_seed: &MultiPartyCrs<[u8; 32]>,
client_key: &K,
) -> CommonReferenceSeededCollectivePublicKeyShare<
<M as Matrix>::R,
@ -1668,7 +1655,8 @@ where
let mut share_out = M::R::zeros(self.pbs_info.parameters.rlwe_n().0);
let modop = &self.pbs_info.rlwe_modop;
let nttop = &self.pbs_info.rlwe_nttop;
let mut main_prng = DefaultSecureRng::new_seeded(cr_seed);
let pk_seed = cr_seed.public_key_share_seed::<DefaultSecureRng>();
let mut main_prng = DefaultSecureRng::new_seeded(pk_seed);
public_key_share(
&mut share_out,
&client_key.sk_rlwe(),
@ -1679,7 +1667,7 @@ where
);
CommonReferenceSeededCollectivePublicKeyShare::new(
share_out,
cr_seed,
pk_seed,
self.pbs_info.parameters.clone(),
)
})
@ -1852,9 +1840,9 @@ where
shares: &[CommonReferenceSeededMultiPartyServerKeyShare<
M,
BoolParameters<M::MatElement>,
S,
MultiPartyCrs<S>,
>],
) -> SeededMultiPartyServerKey<M, S, BoolParameters<M::MatElement>>
) -> SeededMultiPartyServerKey<M, MultiPartyCrs<S>, BoolParameters<M::MatElement>>
where
S: PartialEq + Clone,
M: Clone,
@ -2256,6 +2244,8 @@ mod tests {
.map(|_| bool_evaluator.client_key())
.collect_vec();
let int_mp_seed = MultiPartyCrs::random();
let mut ideal_rlwe_sk = vec![0i32; bool_evaluator.pbs_info.rlwe_n()];
parties.iter().for_each(|k| {
izip!(
@ -2287,7 +2277,7 @@ mod tests {
rng.fill_bytes(&mut pk_cr_seed);
let public_key_share = parties
.iter()
.map(|k| bool_evaluator.multi_party_public_key_share(pk_cr_seed, k))
.map(|k| bool_evaluator.multi_party_public_key_share(&int_mp_seed, k))
.collect_vec();
let collective_pk = PublicKey::<
Vec<Vec<u64>>,
@ -2331,7 +2321,7 @@ mod tests {
rng.fill_bytes(&mut pk_cr_seed);
let public_key_share = parties
.iter()
.map(|k| bool_evaluator.multi_party_public_key_share(pk_cr_seed, k))
.map(|k| bool_evaluator.multi_party_public_key_share(&int_mp_seed, k))
.collect_vec();
let collective_pk = PublicKey::<
Vec<Vec<u64>>,
@ -2344,7 +2334,11 @@ mod tests {
let server_key_shares = parties
.iter()
.map(|k| {
bool_evaluator.multi_party_server_key_share(pbs_cr_seed, collective_pk.key(), k)
bool_evaluator.multi_party_server_key_share(
&int_mp_seed,
collective_pk.key(),
k,
)
})
.collect_vec();

+ 45
- 29
src/bool/keys.rs

@ -145,6 +145,8 @@ pub struct PublicKey {
}
pub(super) mod impl_pk {
use crate::evaluator::MultiPartyCrs;
use super::*;
impl<M, R, Mo> PublicKey<M, R, Mo> {
@ -462,8 +464,10 @@ pub(super) mod impl_server_key_eval_domain {
use itertools::{izip, Itertools};
use crate::{
evaluator::MultiPartyCrs,
ntt::{Ntt, NttInit},
pbs::PbsKey,
random::RandomFill,
};
use super::*;
@ -610,16 +614,22 @@ pub(super) mod impl_server_key_eval_domain {
M: MatrixMut + MatrixEntity,
Rng: NewWithSeed,
N: NttInit<CiphertextModulus<M::MatElement>> + Ntt<Element = M::MatElement>,
> From<&SeededMultiPartyServerKey<M, Rng::Seed, BoolParameters<M::MatElement>>>
>
From<&SeededMultiPartyServerKey<M, MultiPartyCrs<Rng::Seed>, BoolParameters<M::MatElement>>>
for ServerKeyEvaluationDomain<M, BoolParameters<M::MatElement>, Rng, N>
where
<M as Matrix>::R: RowMut,
Rng::Seed: Copy,
Rng: RandomFillUniformInModulus<[M::MatElement], CiphertextModulus<M::MatElement>>,
Rng::Seed: Copy + Default,
Rng: RandomFillUniformInModulus<[M::MatElement], CiphertextModulus<M::MatElement>>
+ RandomFill<Rng::Seed>,
M::MatElement: Copy,
{
fn from(
value: &SeededMultiPartyServerKey<M, Rng::Seed, BoolParameters<M::MatElement>>,
value: &SeededMultiPartyServerKey<
M,
MultiPartyCrs<Rng::Seed>,
BoolParameters<M::MatElement>,
>,
) -> Self {
let g = value.parameters.g() as isize;
let rlwe_n = value.parameters.rlwe_n().0;
@ -627,37 +637,42 @@ pub(super) mod impl_server_key_eval_domain {
let rlwe_q = value.parameters.rlwe_q();
let lwe_q = value.parameters.lwe_q();
let mut main_prng = Rng::new_with_seed(value.cr_seed);
let rlwe_nttop = N::new(rlwe_q, rlwe_n);
// auto keys
let mut auto_keys = HashMap::new();
let auto_d_count = value.parameters.auto_decomposition_count().0;
let auto_element_dlogs = value.parameters.auto_element_dlogs();
for i in auto_element_dlogs.into_iter() {
let mut key = M::zeros(auto_d_count * 2, rlwe_n);
// sample a
key.iter_rows_mut().take(auto_d_count).for_each(|ri| {
RandomFillUniformInModulus::random_fill(&mut main_prng, &rlwe_q, ri.as_mut())
});
{
let mut auto_prng = Rng::new_with_seed(value.cr_seed.auto_keys_cts_seed::<Rng>());
let auto_d_count = value.parameters.auto_decomposition_count().0;
let auto_element_dlogs = value.parameters.auto_element_dlogs();
for i in auto_element_dlogs.into_iter() {
let mut key = M::zeros(auto_d_count * 2, rlwe_n);
// sample a
key.iter_rows_mut().take(auto_d_count).for_each(|ri| {
RandomFillUniformInModulus::random_fill(
&mut auto_prng,
&rlwe_q,
ri.as_mut(),
)
});
let key_part_b = value.auto_keys.get(&i).unwrap();
assert!(key_part_b.dimension() == (auto_d_count, rlwe_n));
izip!(
key.iter_rows_mut().skip(auto_d_count),
key_part_b.iter_rows()
)
.for_each(|(to_ri, from_ri)| {
to_ri.as_mut().copy_from_slice(from_ri.as_ref());
});
let key_part_b = value.auto_keys.get(&i).unwrap();
assert!(key_part_b.dimension() == (auto_d_count, rlwe_n));
izip!(
key.iter_rows_mut().skip(auto_d_count),
key_part_b.iter_rows()
)
.for_each(|(to_ri, from_ri)| {
to_ri.as_mut().copy_from_slice(from_ri.as_ref());
});
// send to evaluation domain
key.iter_rows_mut()
.for_each(|ri| rlwe_nttop.forward(ri.as_mut()));
// send to evaluation domain
key.iter_rows_mut()
.for_each(|ri| rlwe_nttop.forward(ri.as_mut()));
auto_keys.insert(i, key);
auto_keys.insert(i, key);
}
}
// rgsw cts
@ -682,12 +697,13 @@ pub(super) mod impl_server_key_eval_domain {
.collect_vec();
// lwe ksk
let mut lwe_ksk_prng = Rng::new_with_seed(value.cr_seed.lwe_ksk_cts_seed_seed::<Rng>());
let d_lwe = value.parameters.lwe_decomposition_count().0;
let mut lwe_ksk = M::zeros(rlwe_n * d_lwe, lwe_n + 1);
izip!(lwe_ksk.iter_rows_mut(), value.lwe_ksk.as_ref().iter()).for_each(
|(lwe_i, bi)| {
RandomFillUniformInModulus::random_fill(
&mut main_prng,
&mut lwe_ksk_prng,
&lwe_q,
&mut lwe_i.as_mut()[1..],
);

+ 12
- 9
src/bool/mp_api.rs

@ -1,4 +1,4 @@
use std::{cell::RefCell, sync::OnceLock};
use std::{cell::RefCell, ops::Mul, sync::OnceLock};
use crate::{
backend::{ModularOpsU64, ModulusPowerOf2},
@ -50,9 +50,8 @@ pub fn gen_client_key() -> ClientKey {
pub fn gen_mp_keys_phase1(
ck: &ClientKey,
) -> CommonReferenceSeededCollectivePublicKeyShare<Vec<u64>, [u8; 32], BoolParameters<u64>> {
let seed = MultiPartyCrs::global().public_key_share_seed::<DefaultSecureRng>();
BoolEvaluator::with_local(|e| {
let pk_share = e.multi_party_public_key_share(seed, ck);
let pk_share = e.multi_party_public_key_share(MultiPartyCrs::global(), ck);
pk_share
})
}
@ -60,10 +59,14 @@ pub fn gen_mp_keys_phase1(
pub fn gen_mp_keys_phase2<R, ModOp>(
ck: &ClientKey,
pk: &PublicKey<Vec<Vec<u64>>, R, ModOp>,
) -> CommonReferenceSeededMultiPartyServerKeyShare<Vec<Vec<u64>>, BoolParameters<u64>, [u8; 32]> {
let seed = MultiPartyCrs::global().server_key_share_seed::<DefaultSecureRng>();
) -> CommonReferenceSeededMultiPartyServerKeyShare<
Vec<Vec<u64>>,
BoolParameters<u64>,
MultiPartyCrs<[u8; 32]>,
> {
BoolEvaluator::with_local_mut(|e| {
let server_key_share = e.multi_party_server_key_share(seed, pk.key(), ck);
let server_key_share =
e.multi_party_server_key_share(MultiPartyCrs::global(), pk.key(), ck);
server_key_share
})
}
@ -82,16 +85,16 @@ pub fn aggregate_server_key_shares(
shares: &[CommonReferenceSeededMultiPartyServerKeyShare<
Vec<Vec<u64>>,
BoolParameters<u64>,
[u8; 32],
MultiPartyCrs<[u8; 32]>,
>],
) -> SeededMultiPartyServerKey<Vec<Vec<u64>>, [u8; 32], BoolParameters<u64>> {
) -> SeededMultiPartyServerKey<Vec<Vec<u64>>, MultiPartyCrs<[u8; 32]>, BoolParameters<u64>> {
BoolEvaluator::with_local(|e| e.aggregate_multi_party_server_key_shares(shares))
}
impl
SeededMultiPartyServerKey<
Vec<Vec<u64>>,
<DefaultSecureRng as NewWithSeed>::Seed,
MultiPartyCrs<<DefaultSecureRng as NewWithSeed>::Seed>,
BoolParameters<u64>,
>
{

+ 4
- 7
src/bool/noise.rs

@ -11,6 +11,7 @@ mod test {
},
parameters::{CiphertextModulus, SMALL_MP_BOOL_PARAMS},
},
evaluator::MultiPartyCrs,
ntt::NttBackendU64,
random::DefaultSecureRng,
};
@ -28,11 +29,7 @@ mod test {
let parties = 2;
let mut rng = DefaultSecureRng::new();
let mut pk_cr_seed = [0u8; 32];
let mut bk_cr_seed = [0u8; 32];
rng.fill_bytes(&mut pk_cr_seed);
rng.fill_bytes(&mut bk_cr_seed);
let cr_seed = MultiPartyCrs::random();
let cks = (0..parties)
.into_iter()
@ -64,7 +61,7 @@ mod test {
// round 1
let pk_shares = cks
.iter()
.map(|c| evaluator.multi_party_public_key_share(pk_cr_seed, c))
.map(|c| evaluator.multi_party_public_key_share(&cr_seed, c))
.collect_vec();
// public key
@ -75,7 +72,7 @@ mod test {
// round 2
let server_key_shares = cks
.iter()
.map(|c| evaluator.multi_party_server_key_share(bk_cr_seed, &pk.key(), c))
.map(|c| evaluator.multi_party_server_key_share(&cr_seed, &pk.key(), c))
.collect_vec();
let server_key = evaluator.aggregate_multi_party_server_key_shares(&server_key_shares);

+ 1
- 2
src/decomposer.rs

@ -298,14 +298,13 @@ mod tests {
let d = 3;
let mut stats = vec![Stats::new(); d];
for i in [true] {
for i in [true, false] {
let q = if i {
generate_prime(logq, 2 * ring_size, 1u64 << logq).unwrap()
} else {
1u64 << logq
};
let decomposer = DefaultDecomposer::new(q, logb, d);
dbg!(decomposer.ignore_bits);
let modq_op = ModularOpsU64::new(q);
for _ in 0..1000000 {
let value = rng.gen_range(0..q);

Loading…
Cancel
Save