Browse Source

add differing base feature for RLWExRGSw and RGSWxRGSW for interactive mpc

par-agg-key-shares
Janmajaya Mall 9 months ago
parent
commit
1d7099600a
8 changed files with 381 additions and 210 deletions
  1. +276
    -173
      src/bool/evaluator.rs
  2. +19
    -5
      src/bool/keys.rs
  3. +14
    -3
      src/bool/mp_api.rs
  4. +6
    -5
      src/bool/noise.rs
  5. +4
    -4
      src/bool/parameters.rs
  6. +2
    -0
      src/rgsw/mod.rs
  7. +33
    -18
      src/rgsw/runtime.rs
  8. +27
    -2
      src/utils.rs

+ 276
- 173
src/bool/evaluator.rs

@ -41,8 +41,8 @@ use crate::{
RlweCiphertext, RlweSecret, RlweCiphertext, RlweSecret,
}, },
utils::{ utils::{
fill_random_ternary_secret_with_hamming_weight, generate_prime, mod_exponent,
puncture_p_rng, Global, TryConvertFrom1, WithLocal,
encode_x_pow_si_with_emebedding_factor, fill_random_ternary_secret_with_hamming_weight,
generate_prime, mod_exponent, puncture_p_rng, Global, TryConvertFrom1, WithLocal,
}, },
Decryptor, Encoder, Encryptor, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, Row, Decryptor, Encoder, Encryptor, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, Row,
RowEntity, RowMut, Secret, RowEntity, RowMut, Secret,
@ -527,6 +527,24 @@ where
reduced_ct_i_out reduced_ct_i_out
} }
/// Assigns user with user_id segement of LWE secret indices for which they
/// generate RGSW(X^{s[i]}) as the leader (i.e. for RLWExRGSW). If returned
/// tuple is (start, end), user's segment is [start, end)
pub(super) fn interactive_mult_party_user_id_lwe_segment(
user_id: usize,
total_users: usize,
lwe_n: usize,
) -> (usize, usize) {
let per_user = (lwe_n as f64 / total_users as f64)
.ceil()
.to_usize()
.unwrap();
(
per_user * user_id,
std::cmp::min(per_user * (user_id + 1), lwe_n),
)
}
impl<M: Matrix, NttOp, RlweModOp, LweModOp, SKey> BoolEvaluator<M, NttOp, RlweModOp, LweModOp, SKey> impl<M: Matrix, NttOp, RlweModOp, LweModOp, SKey> BoolEvaluator<M, NttOp, RlweModOp, LweModOp, SKey>
where where
M: MatrixEntity + MatrixMut, M: MatrixEntity + MatrixMut,
@ -800,6 +818,8 @@ where
pub(super) fn multi_party_server_key_share<K: InteractiveMultiPartyClientKey<Element = i32>>( pub(super) fn multi_party_server_key_share<K: InteractiveMultiPartyClientKey<Element = i32>>(
&self, &self,
user_id: usize,
total_users: usize,
cr_seed: &MultiPartyCrs<[u8; 32]>, cr_seed: &MultiPartyCrs<[u8; 32]>,
collective_pk: &M, collective_pk: &M,
client_key: &K, client_key: &K,
@ -809,10 +829,7 @@ where
MultiPartyCrs<[u8; 32]>, MultiPartyCrs<[u8; 32]>,
> { > {
assert_eq!(self.parameters().variant(), &ParameterVariant::MultiParty); assert_eq!(self.parameters().variant(), &ParameterVariant::MultiParty);
// let user_id = 0;
// let user_segment_start = 0;
// let user_segment_end = 1;
assert!(user_id < total_users);
let sk_rlwe = client_key.sk_rlwe(); let sk_rlwe = client_key.sk_rlwe();
let sk_lwe = client_key.sk_lwe(); let sk_lwe = client_key.sk_lwe();
@ -836,41 +853,85 @@ where
); );
// rgsw ciphertexts of lwe secret elements // rgsw ciphertexts of lwe secret elements
let rgsw_cts = DefaultSecureRng::with_local_mut(|rng| {
let rgsw_rgsw_decomposer = self
.pbs_info
.parameters
.rgsw_rgsw_decomposer::<DefaultDecomposer<M::MatElement>>();
let (rgrg_d_a, rgrg_d_b) = (
rgsw_rgsw_decomposer.0.decomposition_count(),
rgsw_rgsw_decomposer.1.decomposition_count(),
);
let (rgrg_gadget_a, rgrg_gadget_b) = (
rgsw_rgsw_decomposer.0.gadget_vector(),
rgsw_rgsw_decomposer.1.gadget_vector(),
let (self_leader_rgsws, not_self_leader_rgsws) = DefaultSecureRng::with_local_mut(|rng| {
let mut self_leader_rgsw = vec![];
let mut not_self_leader_rgsws = vec![];
let (segment_start, segment_end) = interactive_mult_party_user_id_lwe_segment(
user_id,
total_users,
self.pbs_info().lwe_n(),
); );
let rgsw_cts = sk_lwe
.iter()
.map(|si| {
let mut m = M::R::zeros(ring_size);
//TODO(Jay): It will be nice to have a function that returns polynomial
// (monomial infact!) corresponding to secret element embedded in ring X^{2N+1}.
// Save lots of mistakes where one forgest to emebed si in bigger ring.
let si = *si * (self.pbs_info.embedding_factor as i32);
if si < 0 {
// X^{-si} = X^{2N-si} = -X^{N-si}, assuming abs(si) < N
// (which it is given si is secret element)
m.as_mut()[ring_size - (si.abs() as usize)] = rlwe_q.neg_one();
} else {
m.as_mut()[si as usize] = M::MatElement::one();
}
// public key RGSW encryption has no part that can be seeded, unlike secret key
// RGSW encryption where RLWE'_A(m) is seeded
// self LWE secret indices
{
// LWE secret indices for which user is the leader they need to send RGSW(m) for
// RLWE x RGSW multiplication
let rlrg_decomposer = self.pbs_info().rlwe_rgsw_decomposer();
let (rlrg_d_a, rlrg_d_b) = (
rlrg_decomposer.a().decomposition_count(),
rlrg_decomposer.b().decomposition_count(),
);
let (gadget_a, gadget_b) = (
rlrg_decomposer.a().gadget_vector(),
rlrg_decomposer.b().gadget_vector(),
);
for s_index in segment_start..segment_end {
let mut out_rgsw = M::zeros(rlrg_d_a * 2 + rlrg_d_b * 2, ring_size);
public_key_encrypt_rgsw(
&mut out_rgsw,
&encode_x_pow_si_with_emebedding_factor::<
M::R,
CiphertextModulus<M::MatElement>,
>(
sk_lwe[s_index],
self.pbs_info().embedding_factor(),
ring_size,
self.pbs_info().rlwe_q(),
)
.as_ref(),
collective_pk,
&gadget_a,
&gadget_b,
rlweq_modop,
rlweq_nttop,
rng,
);
self_leader_rgsw.push(out_rgsw);
}
}
// not self LWE secret indices
{
// LWE secret indices for which user isn't the leader, they need to send RGSW(m)
// for RGSW x RGSW multiplcation
let rgsw_rgsw_decomposer = self
.pbs_info
.parameters
.rgsw_rgsw_decomposer::<DefaultDecomposer<M::MatElement>>();
let (rgrg_d_a, rgrg_d_b) = (
rgsw_rgsw_decomposer.a().decomposition_count(),
rgsw_rgsw_decomposer.b().decomposition_count(),
);
let (rgrg_gadget_a, rgrg_gadget_b) = (
rgsw_rgsw_decomposer.a().gadget_vector(),
rgsw_rgsw_decomposer.b().gadget_vector(),
);
for s_index in (0..segment_start).chain(segment_end..self.parameters().lwe_n().0) {
let mut out_rgsw = M::zeros(rgrg_d_a * 2 + rgrg_d_b * 2, ring_size); let mut out_rgsw = M::zeros(rgrg_d_a * 2 + rgrg_d_b * 2, ring_size);
public_key_encrypt_rgsw( public_key_encrypt_rgsw(
&mut out_rgsw, &mut out_rgsw,
&m.as_ref(),
&encode_x_pow_si_with_emebedding_factor::<
M::R,
CiphertextModulus<M::MatElement>,
>(
sk_lwe[s_index],
self.pbs_info().embedding_factor(),
ring_size,
self.pbs_info().rlwe_q(),
)
.as_ref(),
collective_pk, collective_pk,
&rgrg_gadget_a, &rgrg_gadget_a,
&rgrg_gadget_b, &rgrg_gadget_b,
@ -879,10 +940,11 @@ where
rng, rng,
); );
out_rgsw
})
.collect_vec();
rgsw_cts
not_self_leader_rgsws.push(out_rgsw);
}
}
(self_leader_rgsw, not_self_leader_rgsws)
}); });
// LWE Ksk // LWE Ksk
@ -893,14 +955,173 @@ where
); );
CommonReferenceSeededMultiPartyServerKeyShare::new( CommonReferenceSeededMultiPartyServerKeyShare::new(
rgsw_cts,
self_leader_rgsws,
not_self_leader_rgsws,
auto_keys, auto_keys,
lwe_ksk, lwe_ksk,
cr_seed.clone(), cr_seed.clone(),
self.pbs_info.parameters.clone(), self.pbs_info.parameters.clone(),
user_id,
) )
} }
pub(super) fn aggregate_multi_party_server_key_shares<S>(
&self,
shares: &[CommonReferenceSeededMultiPartyServerKeyShare<
M,
BoolParameters<M::MatElement>,
MultiPartyCrs<S>,
>],
) -> SeededMultiPartyServerKey<M, MultiPartyCrs<S>, BoolParameters<M::MatElement>>
where
S: PartialEq + Clone,
M: Clone,
{
assert_eq!(self.parameters().variant(), &ParameterVariant::MultiParty);
assert!(shares.len() > 0);
let total_users = shares.len();
let parameters = shares[0].parameters().clone();
let cr_seed = shares[0].cr_seed();
let rlwe_n = parameters.rlwe_n().0;
let g = parameters.g() as isize;
let rlwe_q = parameters.rlwe_q();
let lwe_q = parameters.lwe_q();
// sanity checks
shares.iter().skip(1).for_each(|s| {
assert!(s.parameters() == &parameters);
assert!(s.cr_seed() == cr_seed);
});
let rlweq_modop = &self.pbs_info.rlwe_modop;
let rlweq_nttop = &self.pbs_info.rlwe_nttop;
// auto keys
let mut auto_keys = HashMap::new();
let auto_elements_dlog = parameters.auto_element_dlogs();
for i in auto_elements_dlog.into_iter() {
let mut key = M::zeros(parameters.auto_decomposition_count().0, rlwe_n);
shares.iter().for_each(|s| {
let auto_key_share_i = s.auto_keys().get(&i).expect("Auto key {i} missing");
assert!(
auto_key_share_i.dimension()
== (parameters.auto_decomposition_count().0, rlwe_n)
);
izip!(key.iter_rows_mut(), auto_key_share_i.iter_rows()).for_each(
|(partb_out, partb_share)| {
rlweq_modop.elwise_add_mut(partb_out.as_mut(), partb_share.as_ref());
},
);
});
auto_keys.insert(i, key);
}
// rgsw ciphertext (most expensive part!)
let rgsw_cts = {
let rgsw_by_rgsw_decomposer =
parameters.rgsw_rgsw_decomposer::<DefaultDecomposer<M::MatElement>>();
let rlwe_x_rgsw_decomposer = self.pbs_info().rlwe_rgsw_decomposer();
let rgsw_x_rgsw_dimension = (
rgsw_by_rgsw_decomposer.a().decomposition_count() * 2
+ rgsw_by_rgsw_decomposer.b().decomposition_count() * 2,
rlwe_n,
);
let rlwe_x_rgsw_dimension = (
rlwe_x_rgsw_decomposer.a().decomposition_count() * 2
+ rlwe_x_rgsw_decomposer.b().decomposition_count() * 2,
rlwe_n,
);
let mut rgsw_x_rgsw_scratch_mat = M::zeros(
std::cmp::max(
rgsw_by_rgsw_decomposer.a().decomposition_count(),
rgsw_by_rgsw_decomposer.b().decomposition_count(),
) + rlwe_x_rgsw_dimension.0,
rlwe_n,
);
let shares_in_correct_order = (0..total_users)
.map(|i| shares.iter().find(|s| s.user_id() == i).unwrap())
.collect_vec();
let lwe_n = self.parameters().lwe_n().0;
let (users_segments, users_segments_sizes): (Vec<(usize, usize)>, Vec<usize>) = (0
..total_users)
.map(|(user_id)| {
let (start_index, end_index) =
interactive_mult_party_user_id_lwe_segment(user_id, total_users, lwe_n);
((start_index, end_index), end_index - start_index)
})
.unzip();
let mut rgsw_cts = Vec::with_capacity(lwe_n);
users_segments
.iter()
.enumerate()
.for_each(|(user_id, user_segment)| {
let share = shares_in_correct_order[user_id];
for secret_index in user_segment.0..user_segment.1 {
let mut rgsw_i =
share.self_leader_rgsws()[secret_index - user_segment.0].clone();
// assert already exists in RGSW x RGSW rountine
assert!(rgsw_i.dimension() == rlwe_x_rgsw_dimension);
// multiply leader's RGSW ct at `secret_index` with RGSW cts of other users
// for lwe index `secret_index`
(0..total_users)
.filter(|i| i != &user_id)
.for_each(|other_user_id| {
let mut offset = 0;
if other_user_id < user_id {
offset = users_segments_sizes[other_user_id];
}
let mut other_rgsw_i = shares_in_correct_order[other_user_id]
.not_self_leader_rgsws()
[secret_index.checked_sub(offset).unwrap()]
.clone();
// assert already exists in RGSW x RGSW rountine
assert!(other_rgsw_i.dimension() == rgsw_x_rgsw_dimension);
// send to evaluation domain for RGSwxRGSW mul
other_rgsw_i
.iter_rows_mut()
.for_each(|r| rlweq_nttop.forward(r.as_mut()));
rgsw_by_rgsw_inplace(
&mut rgsw_i,
rlwe_x_rgsw_decomposer.a().decomposition_count(),
rlwe_x_rgsw_decomposer.b().decomposition_count(),
&other_rgsw_i,
&rgsw_by_rgsw_decomposer,
&mut rgsw_x_rgsw_scratch_mat,
rlweq_nttop,
rlweq_modop,
)
});
rgsw_cts.push(rgsw_i);
}
});
rgsw_cts
};
// LWE ksks
let mut lwe_ksk = M::R::zeros(rlwe_n * parameters.lwe_decomposition_count().0);
let lweq_modop = &self.pbs_info.lwe_modop;
shares.iter().for_each(|si| {
assert!(si.lwe_ksk().as_ref().len() == rlwe_n * parameters.lwe_decomposition_count().0);
lweq_modop.elwise_add_mut(lwe_ksk.as_mut(), si.lwe_ksk().as_ref())
});
SeededMultiPartyServerKey::new(rgsw_cts, auto_keys, lwe_ksk, cr_seed.clone(), parameters)
}
pub(super) fn aggregate_non_interactive_multi_party_key_share( pub(super) fn aggregate_non_interactive_multi_party_key_share(
&self, &self,
cr_seed: &NonInteractiveMultiPartyCrs<[u8; 32]>, cr_seed: &NonInteractiveMultiPartyCrs<[u8; 32]>,
@ -1351,6 +1572,8 @@ where
.for_each(|user_i_rgsws| { .for_each(|user_i_rgsws| {
rgsw_by_rgsw_inplace( rgsw_by_rgsw_inplace(
&mut rgsw_i, &mut rgsw_i,
rgsw_by_rgsw_decomposer.a().decomposition_count(),
rgsw_by_rgsw_decomposer.b().decomposition_count(),
&user_i_rgsws[s_index], &user_i_rgsws[s_index],
&rgsw_by_rgsw_decomposer, &rgsw_by_rgsw_decomposer,
&mut scratch_matrix, &mut scratch_matrix,
@ -1834,125 +2057,6 @@ where
let m = decrypt_lwe(lwe_ct, &client_key.sk_rlwe(), &self.pbs_info.rlwe_modop); let m = decrypt_lwe(lwe_ct, &client_key.sk_rlwe(), &self.pbs_info.rlwe_modop);
self.pbs_info.rlwe_q().decode(m) self.pbs_info.rlwe_q().decode(m)
} }
pub(super) fn aggregate_multi_party_server_key_shares<S>(
&self,
shares: &[CommonReferenceSeededMultiPartyServerKeyShare<
M,
BoolParameters<M::MatElement>,
MultiPartyCrs<S>,
>],
) -> SeededMultiPartyServerKey<M, MultiPartyCrs<S>, BoolParameters<M::MatElement>>
where
S: PartialEq + Clone,
M: Clone,
{
assert_eq!(self.parameters().variant(), &ParameterVariant::MultiParty);
assert!(shares.len() > 0);
let parameters = shares[0].parameters().clone();
let cr_seed = shares[0].cr_seed();
let rlwe_n = parameters.rlwe_n().0;
let g = parameters.g() as isize;
let rlwe_q = parameters.rlwe_q();
let lwe_q = parameters.lwe_q();
// sanity checks
shares.iter().skip(1).for_each(|s| {
assert!(s.parameters() == &parameters);
assert!(s.cr_seed() == cr_seed);
});
let rlweq_modop = &self.pbs_info.rlwe_modop;
let rlweq_nttop = &self.pbs_info.rlwe_nttop;
// auto keys
let mut auto_keys = HashMap::new();
let auto_elements_dlog = parameters.auto_element_dlogs();
for i in auto_elements_dlog.into_iter() {
let mut key = M::zeros(parameters.auto_decomposition_count().0, rlwe_n);
shares.iter().for_each(|s| {
let auto_key_share_i = s.auto_keys().get(&i).expect("Auto key {i} missing");
assert!(
auto_key_share_i.dimension()
== (parameters.auto_decomposition_count().0, rlwe_n)
);
izip!(key.iter_rows_mut(), auto_key_share_i.iter_rows()).for_each(
|(partb_out, partb_share)| {
rlweq_modop.elwise_add_mut(partb_out.as_mut(), partb_share.as_ref());
},
);
});
auto_keys.insert(i, key);
}
// rgsw ciphertext (most expensive part!)
let lwe_n = parameters.lwe_n().0;
let rgsw_by_rgsw_decomposer =
parameters.rgsw_rgsw_decomposer::<DefaultDecomposer<M::MatElement>>();
let mut scratch_matrix = M::zeros(
std::cmp::max(
rgsw_by_rgsw_decomposer.a().decomposition_count(),
rgsw_by_rgsw_decomposer.b().decomposition_count(),
) + (rgsw_by_rgsw_decomposer.a().decomposition_count() * 2
+ rgsw_by_rgsw_decomposer.b().decomposition_count() * 2),
rlwe_n,
);
let mut tmp_rgsw =
RgswCiphertext::<M, _>::empty(rlwe_n, &rgsw_by_rgsw_decomposer, rlwe_q.clone()).data;
let rgsw_cts = (0..lwe_n).into_iter().map(|index| {
// copy over rgsw ciphertext for index^th secret element from first share and
// treat it as accumulating rgsw ciphertext
let mut rgsw_i = shares[0].rgsw_cts()[index].clone();
shares.iter().skip(1).for_each(|si| {
// copy over si's RGSW[index] ciphertext and send to evaluation domain
izip!(tmp_rgsw.iter_rows_mut(), si.rgsw_cts()[index].iter_rows()).for_each(
|(to_ri, from_ri)| {
to_ri.as_mut().copy_from_slice(from_ri.as_ref());
rlweq_nttop.forward(to_ri.as_mut())
},
);
rgsw_by_rgsw_inplace(
&mut rgsw_i,
&tmp_rgsw,
&rgsw_by_rgsw_decomposer,
&mut scratch_matrix,
rlweq_nttop,
rlweq_modop,
);
});
rgsw_i
});
// d_a and d_b may differ for RGSWxRGSW multiplication and RLWExRGSW
// multiplication. After this point RGSW ciphertexts will only be used for
// RLWExRGSW multiplication (in blind rotation). Thus we drop any additional
// RLWE ciphertexts in RGSW ciphertexts after RGSw x RGSW multiplication
let rgsw_cts = rgsw_cts
.map(|ct_i_in| {
trim_rgsw_ct_matrix_from_rgrg_to_rlrg(
ct_i_in,
self.parameters().rgsw_by_rgsw_decomposition_params(),
self.parameters().rlwe_by_rgsw_decomposition_params(),
)
})
.collect_vec();
// LWE ksks
let mut lwe_ksk = M::R::zeros(rlwe_n * parameters.lwe_decomposition_count().0);
let lweq_modop = &self.pbs_info.lwe_modop;
shares.iter().for_each(|si| {
assert!(si.lwe_ksk().as_ref().len() == rlwe_n * parameters.lwe_decomposition_count().0);
lweq_modop.elwise_add_mut(lwe_ksk.as_mut(), si.lwe_ksk().as_ref())
});
SeededMultiPartyServerKey::new(rgsw_cts, auto_keys, lwe_ksk, cr_seed.clone(), parameters)
}
} }
impl<M, NttOp, RlweModOp, LweModOp, Skey> BoolEvaluator<M, NttOp, RlweModOp, LweModOp, Skey> impl<M, NttOp, RlweModOp, LweModOp, Skey> BoolEvaluator<M, NttOp, RlweModOp, LweModOp, Skey>
@ -2267,6 +2371,8 @@ mod tests {
}); });
}); });
let mut rng = DefaultSecureRng::new();
// check noise in freshly encrypted RLWE ciphertext (ie var_fresh) // check noise in freshly encrypted RLWE ciphertext (ie var_fresh)
if false { if false {
let mut rng = DefaultSecureRng::new(); let mut rng = DefaultSecureRng::new();
@ -2316,9 +2422,6 @@ mod tests {
if true { if true {
// Generate server key shares // Generate server key shares
let mut rng = DefaultSecureRng::new();
let mut pk_cr_seed = [0u8; 32];
rng.fill_bytes(&mut pk_cr_seed);
let public_key_share = parties let public_key_share = parties
.iter() .iter()
.map(|k| bool_evaluator.multi_party_public_key_share(&int_mp_seed, k)) .map(|k| bool_evaluator.multi_party_public_key_share(&int_mp_seed, k))
@ -2329,12 +2432,13 @@ mod tests {
ModularOpsU64<CiphertextModulus<u64>>, ModularOpsU64<CiphertextModulus<u64>>,
>::from(public_key_share.as_slice()); >::from(public_key_share.as_slice());
let pbs_cr_seed = [0u8; 32];
rng.fill_bytes(&mut pk_cr_seed);
let server_key_shares = parties let server_key_shares = parties
.iter() .iter()
.map(|k| {
.enumerate()
.map(|(user_id, k)| {
bool_evaluator.multi_party_server_key_share( bool_evaluator.multi_party_server_key_share(
user_id,
no_of_parties,
&int_mp_seed, &int_mp_seed,
collective_pk.key(), collective_pk.key(),
k, k,
@ -2351,13 +2455,12 @@ mod tests {
izip!(ideal_lwe_sk.iter(), seeded_server_key.rgsw_cts().iter()).for_each( izip!(ideal_lwe_sk.iter(), seeded_server_key.rgsw_cts().iter()).for_each(
|(s_i, rgsw_ct_i)| { |(s_i, rgsw_ct_i)| {
// X^{s[i]} // X^{s[i]}
let mut m_si = vec![0u64; rlwe_n];
let s_i = *s_i * (bool_evaluator.pbs_info.embedding_factor as i32);
if s_i < 0 {
m_si[rlwe_n - (s_i.abs() as usize)] = rlwe_q.neg_one();
} else {
m_si[s_i as usize] = 1;
}
let m_si = encode_x_pow_si_with_emebedding_factor::<Vec<u64>, _>(
*s_i,
bool_evaluator.pbs_info.embedding_factor,
rlwe_n,
rlwe_q,
);
// RLWE'(-sm) // RLWE'(-sm)
let mut neg_s_eval = let mut neg_s_eval =

+ 19
- 5
src/bool/keys.rs

@ -314,7 +314,8 @@ impl CommonReferenceSeededCollectivePublicKeyShare {
/// CRS seeded Multi-party server key share /// CRS seeded Multi-party server key share
pub struct CommonReferenceSeededMultiPartyServerKeyShare<M: Matrix, P, S> { pub struct CommonReferenceSeededMultiPartyServerKeyShare<M: Matrix, P, S> {
rgsw_cts: Vec<M>,
self_leader_rgsws: Vec<M>,
not_self_leader_rgsws: Vec<M>,
/// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding /// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding
/// to -g is at 0 /// to -g is at 0
auto_keys: HashMap<usize, M>, auto_keys: HashMap<usize, M>,
@ -322,22 +323,27 @@ pub struct CommonReferenceSeededMultiPartyServerKeyShare {
/// Common reference seed /// Common reference seed
cr_seed: S, cr_seed: S,
parameters: P, parameters: P,
user_id: usize,
} }
impl<M: Matrix, P, S> CommonReferenceSeededMultiPartyServerKeyShare<M, P, S> { impl<M: Matrix, P, S> CommonReferenceSeededMultiPartyServerKeyShare<M, P, S> {
pub(super) fn new( pub(super) fn new(
rgsw_cts: Vec<M>,
self_leader_rgsws: Vec<M>,
not_self_leader_rgsws: Vec<M>,
auto_keys: HashMap<usize, M>, auto_keys: HashMap<usize, M>,
lwe_ksk: M::R, lwe_ksk: M::R,
cr_seed: S, cr_seed: S,
parameters: P, parameters: P,
user_id: usize,
) -> Self { ) -> Self {
CommonReferenceSeededMultiPartyServerKeyShare { CommonReferenceSeededMultiPartyServerKeyShare {
rgsw_cts,
self_leader_rgsws,
not_self_leader_rgsws,
auto_keys, auto_keys,
lwe_ksk, lwe_ksk,
cr_seed, cr_seed,
parameters, parameters,
user_id,
} }
} }
@ -353,13 +359,21 @@ impl CommonReferenceSeededMultiPartyServerKeyShare {
&self.auto_keys &self.auto_keys
} }
pub(super) fn rgsw_cts(&self) -> &[M] {
&self.rgsw_cts
pub(crate) fn self_leader_rgsws(&self) -> &[M] {
&self.self_leader_rgsws
}
pub(super) fn not_self_leader_rgsws(&self) -> &[M] {
&self.not_self_leader_rgsws
} }
pub(super) fn lwe_ksk(&self) -> &M::R { pub(super) fn lwe_ksk(&self) -> &M::R {
&self.lwe_ksk &self.lwe_ksk
} }
pub(super) fn user_id(&self) -> usize {
self.user_id
}
} }
/// CRS seeded MultiParty server key /// CRS seeded MultiParty server key

+ 14
- 3
src/bool/mp_api.rs

@ -58,6 +58,8 @@ pub fn gen_mp_keys_phase1(
pub fn gen_mp_keys_phase2<R, ModOp>( pub fn gen_mp_keys_phase2<R, ModOp>(
ck: &ClientKey, ck: &ClientKey,
user_id: usize,
total_users: usize,
pk: &PublicKey<Vec<Vec<u64>>, R, ModOp>, pk: &PublicKey<Vec<Vec<u64>>, R, ModOp>,
) -> CommonReferenceSeededMultiPartyServerKeyShare< ) -> CommonReferenceSeededMultiPartyServerKeyShare<
Vec<Vec<u64>>, Vec<Vec<u64>>,
@ -65,8 +67,13 @@ pub fn gen_mp_keys_phase2(
MultiPartyCrs<[u8; 32]>, MultiPartyCrs<[u8; 32]>,
> { > {
BoolEvaluator::with_local_mut(|e| { BoolEvaluator::with_local_mut(|e| {
let server_key_share =
e.multi_party_server_key_share(MultiPartyCrs::global(), pk.key(), ck);
let server_key_share = e.multi_party_server_key_share(
user_id,
total_users,
MultiPartyCrs::global(),
pk.key(),
ck,
);
server_key_share server_key_share
}) })
} }
@ -251,7 +258,11 @@ mod tests {
let pk = aggregate_public_key_shares(&pk_shares); let pk = aggregate_public_key_shares(&pk_shares);
// round 2 // round 2
let server_key_shares = cks.iter().map(|k| gen_mp_keys_phase2(k, &pk)).collect_vec();
let server_key_shares = cks
.iter()
.enumerate()
.map(|(user_id, k)| gen_mp_keys_phase2(k, user_id, parties, &pk))
.collect_vec();
// server key // server key
let server_key = aggregate_server_key_shares(&server_key_shares); let server_key = aggregate_server_key_shares(&server_key_shares);

+ 6
- 5
src/bool/noise.rs

@ -13,6 +13,7 @@ mod test {
}, },
evaluator::MultiPartyCrs, evaluator::MultiPartyCrs,
ntt::NttBackendU64, ntt::NttBackendU64,
parameters::OPTIMISED_SMALL_MP_BOOL_PARAMS,
random::DefaultSecureRng, random::DefaultSecureRng,
}; };
@ -25,7 +26,7 @@ mod test {
ModularOpsU64<CiphertextModulus<u64>>, ModularOpsU64<CiphertextModulus<u64>>,
ModulusPowerOf2<CiphertextModulus<u64>>, ModulusPowerOf2<CiphertextModulus<u64>>,
ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>, ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>,
>::new(SMALL_MP_BOOL_PARAMS);
>::new(OPTIMISED_SMALL_MP_BOOL_PARAMS);
let parties = 2; let parties = 2;
@ -72,7 +73,10 @@ mod test {
// round 2 // round 2
let server_key_shares = cks let server_key_shares = cks
.iter() .iter()
.map(|c| evaluator.multi_party_server_key_share(&cr_seed, &pk.key(), c))
.enumerate()
.map(|(index, c)| {
evaluator.multi_party_server_key_share(index, parties, &cr_seed, &pk.key(), c)
})
.collect_vec(); .collect_vec();
let server_key = evaluator.aggregate_multi_party_server_key_shares(&server_key_shares); let server_key = evaluator.aggregate_multi_party_server_key_shares(&server_key_shares);
@ -89,9 +93,6 @@ mod test {
let mut c_m0 = evaluator.pk_encrypt(pk.key(), m0); let mut c_m0 = evaluator.pk_encrypt(pk.key(), m0);
let mut c_m1 = evaluator.pk_encrypt(pk.key(), m1); let mut c_m1 = evaluator.pk_encrypt(pk.key(), m1);
let true_el_encoded = evaluator.parameters().rlwe_q().true_el();
let false_el_encoded = evaluator.parameters().rlwe_q().false_el();
// let mut stats = Stats::new(); // let mut stats = Stats::new();
for _ in 0..1000 { for _ in 0..1000 {

+ 4
- 4
src/bool/parameters.rs

@ -494,14 +494,14 @@ pub(crate) const OPTIMISED_SMALL_MP_BOOL_PARAMS: BoolParameters = BoolParam
lwe_n: LweDimension(500), lwe_n: LweDimension(500),
lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(11)), lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(11)),
rlrg_decomposer_params: ( rlrg_decomposer_params: (
DecompostionLogBase(24),
DecompostionLogBase(16),
(DecompositionCount(1), DecompositionCount(1)), (DecompositionCount(1), DecompositionCount(1)),
), ),
rgrg_decomposer_params: Some(( rgrg_decomposer_params: Some((
DecompostionLogBase(12),
(DecompositionCount(3), DecompositionCount(3)),
DecompostionLogBase(8),
(DecompositionCount(6), DecompositionCount(6)),
)), )),
auto_decomposer_params: (DecompostionLogBase(20), DecompositionCount(1)),
auto_decomposer_params: (DecompostionLogBase(24), DecompositionCount(1)),
non_interactive_ui_to_s_key_switch_decomposer: None, non_interactive_ui_to_s_key_switch_decomposer: None,
g: 5, g: 5,
w: 10, w: 10,

+ 2
- 0
src/rgsw/mod.rs

@ -1114,6 +1114,8 @@ pub(crate) mod tests {
); );
rgsw_by_rgsw_inplace( rgsw_by_rgsw_inplace(
&mut rgsw_carrym, &mut rgsw_carrym,
decomposer.a().decomposition_count(),
decomposer.b().decomposition_count(),
&rgsw_m.data, &rgsw_m.data,
&decomposer, &decomposer,
&mut scratch_matrix, &mut scratch_matrix,

+ 33
- 18
src/rgsw/runtime.rs

@ -546,14 +546,19 @@ pub(crate) fn rlwe_by_rgsw_shoup<
/// - rgsw_1_eval: RGSW(m1) in Evaluation domain /// - rgsw_1_eval: RGSW(m1) in Evaluation domain
/// - scratch_matrix_d_plus_rgsw_by_ring: scratch space matrix with rows /// - scratch_matrix_d_plus_rgsw_by_ring: scratch space matrix with rows
/// (max(d_a, d_b) + d_a*2+d_b*2) and columns ring_size /// (max(d_a, d_b) + d_a*2+d_b*2) and columns ring_size
///
/// ## Note:
/// - We treat RGSW x RGSW as multiple RLWE x RGSW multiplications. .
pub(crate) fn rgsw_by_rgsw_inplace< pub(crate) fn rgsw_by_rgsw_inplace<
Mmut: MatrixMut, Mmut: MatrixMut,
D: RlweDecomposer<Element = Mmut::MatElement>, D: RlweDecomposer<Element = Mmut::MatElement>,
ModOp: VectorOps<Element = Mmut::MatElement>, ModOp: VectorOps<Element = Mmut::MatElement>,
NttOp: Ntt<Element = Mmut::MatElement>, NttOp: Ntt<Element = Mmut::MatElement>,
>( >(
rgsw_0: &mut Mmut,
rgsw_1_eval: &Mmut,
rgsw0: &mut Mmut,
rgsw0_da: usize,
rgsw0_db: usize,
rgsw1_eval: &Mmut,
decomposer: &D, decomposer: &D,
scratch_matrix: &mut Mmut, scratch_matrix: &mut Mmut,
ntt_op: &NttOp, ntt_op: &NttOp,
@ -567,11 +572,12 @@ pub(crate) fn rgsw_by_rgsw_inplace<
let d_a = decomposer_a.decomposition_count(); let d_a = decomposer_a.decomposition_count();
let d_b = decomposer_b.decomposition_count(); let d_b = decomposer_b.decomposition_count();
let max_d = std::cmp::max(d_a, d_b); let max_d = std::cmp::max(d_a, d_b);
let rgsw_rows = d_a * 2 + d_b * 2;
assert!(rgsw_0.dimension().0 == rgsw_rows);
let ring_size = rgsw_0.dimension().1;
assert!(rgsw_1_eval.dimension() == (rgsw_rows, ring_size));
assert!(scratch_matrix.fits(max_d + rgsw_rows, ring_size));
let rgsw1_rows = d_a * 2 + d_b * 2;
let rgsw0_rows = rgsw0_da * 2 + rgsw0_db * 2;
let ring_size = rgsw0.dimension().1;
assert!(rgsw0.dimension().0 == rgsw0_rows);
assert!(rgsw1_eval.dimension() == (rgsw1_rows, ring_size));
assert!(scratch_matrix.fits(max_d + rgsw0_rows, ring_size));
let (decomp_r_space, rgsw_space) = scratch_matrix.split_at_row_mut(max_d); let (decomp_r_space, rgsw_space) = scratch_matrix.split_at_row_mut(max_d);
@ -579,18 +585,25 @@ pub(crate) fn rgsw_by_rgsw_inplace<
rgsw_space rgsw_space
.iter_mut() .iter_mut()
.for_each(|ri| ri.as_mut().fill(Mmut::MatElement::zero())); .for_each(|ri| ri.as_mut().fill(Mmut::MatElement::zero()));
let (rlwe_dash_space_nsm, rlwe_dash_space_m) = rgsw_space.split_at_mut(d_a * 2);
let (rlwe_dash_space_nsm, rlwe_dash_space_m) = rgsw_space.split_at_mut(rgsw0_da * 2);
let (rlwe_dash_space_nsm_parta, rlwe_dash_space_nsm_partb) = let (rlwe_dash_space_nsm_parta, rlwe_dash_space_nsm_partb) =
rlwe_dash_space_nsm.split_at_mut(d_a);
let (rlwe_dash_space_m_parta, rlwe_dash_space_m_partb) = rlwe_dash_space_m.split_at_mut(d_b);
rlwe_dash_space_nsm.split_at_mut(rgsw0_da);
let (rlwe_dash_space_m_parta, rlwe_dash_space_m_partb) =
rlwe_dash_space_m.split_at_mut(rgsw0_db);
let (rgsw0_nsm, rgsw0_m) = rgsw_0.split_at_row(d_a * 2);
let (rgsw1_nsm, rgsw1_m) = rgsw_1_eval.split_at_row(d_a * 2);
let (rgsw0_nsm, rgsw0_m) = rgsw0.split_at_row(rgsw0_da * 2);
let (rgsw1_nsm, rgsw1_m) = rgsw1_eval.split_at_row(d_a * 2);
// RGSW x RGSW // RGSW x RGSW
izip!( izip!(
rgsw0_nsm.iter().take(d_a).chain(rgsw0_m.iter().take(d_b)),
rgsw0_nsm.iter().skip(d_a).chain(rgsw0_m.iter().skip(d_b)),
rgsw0_nsm
.iter()
.take(rgsw0_da)
.chain(rgsw0_m.iter().take(rgsw0_db)),
rgsw0_nsm
.iter()
.skip(rgsw0_da)
.chain(rgsw0_m.iter().skip(rgsw0_db)),
rlwe_dash_space_nsm_parta rlwe_dash_space_nsm_parta
.iter_mut() .iter_mut()
.chain(rlwe_dash_space_m_parta.iter_mut()), .chain(rlwe_dash_space_m_parta.iter_mut()),
@ -599,7 +612,9 @@ pub(crate) fn rgsw_by_rgsw_inplace<
.chain(rlwe_dash_space_m_partb.iter_mut()), .chain(rlwe_dash_space_m_partb.iter_mut()),
) )
.for_each(|(rlwe_a, rlwe_b, rlwe_out_a, rlwe_out_b)| { .for_each(|(rlwe_a, rlwe_b, rlwe_out_a, rlwe_out_b)| {
// Part A
// RLWE(m0) x RGSW(m1)
// Part A: Decomp<RLWE(m0)[A]> \cdot RLWE'(-sm1)
decompose_r(rlwe_a.as_ref(), decomp_r_space.as_mut(), decomposer_a); decompose_r(rlwe_a.as_ref(), decomp_r_space.as_mut(), decomposer_a);
decomp_r_space decomp_r_space
.iter_mut() .iter_mut()
@ -618,7 +633,7 @@ pub(crate) fn rgsw_by_rgsw_inplace<
mod_op, mod_op,
); );
// Part B
// Part B: Decompose<RLWE(m0)[B]> \cdot RLWE'(m1)
decompose_r(rlwe_b.as_ref(), decomp_r_space.as_mut(), decomposer_b); decompose_r(rlwe_b.as_ref(), decomp_r_space.as_mut(), decomposer_b);
decomp_r_space decomp_r_space
.iter_mut() .iter_mut()
@ -639,11 +654,11 @@ pub(crate) fn rgsw_by_rgsw_inplace<
}); });
// copy over RGSW(m0m1) into RGSW(m0) // copy over RGSW(m0m1) into RGSW(m0)
izip!(rgsw_0.iter_rows_mut(), rgsw_space.iter())
izip!(rgsw0.iter_rows_mut(), rgsw_space.iter())
.for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref())); .for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref()));
// send back to coefficient domain // send back to coefficient domain
rgsw_0
rgsw0
.iter_rows_mut() .iter_rows_mut()
.for_each(|ri| ntt_op.backward(ri.as_mut())); .for_each(|ri| ntt_op.backward(ri.as_mut()));
} }

+ 27
- 2
src/utils.rs

@ -1,12 +1,12 @@
use std::{fmt::Debug, usize, vec}; use std::{fmt::Debug, usize, vec};
use itertools::{izip, Itertools}; use itertools::{izip, Itertools};
use num_traits::{FromPrimitive, PrimInt, Signed};
use num_traits::{FromPrimitive, One, PrimInt, Signed};
use crate::{ use crate::{
backend::Modulus, backend::Modulus,
random::{RandomElementInModulus, RandomFill}, random::{RandomElementInModulus, RandomFill},
Matrix,
Matrix, Row, RowEntity, RowMut,
}; };
pub trait WithLocal { pub trait WithLocal {
fn with_local<F, R>(func: F) -> R fn with_local<F, R>(func: F) -> R
@ -190,6 +190,31 @@ pub fn negacyclic_mul T>(
return r; return r;
} }
/// Returns a polynomial X^{emebedding_factor * si} \mod {Z_Q / X^{N}+1}
pub(crate) fn encode_x_pow_si_with_emebedding_factor<
R: RowEntity + RowMut,
M: Modulus<Element = R::Element>,
>(
si: i32,
embedding_factor: usize,
ring_size: usize,
modulus: &M,
) -> R
where
R::Element: One,
{
assert!((si.abs() as usize) < ring_size);
let mut m = R::zeros(ring_size);
let si = si * (embedding_factor as i32);
if si < 0 {
// X^{-si} = X^{2N-si} = -X^{N-si}, assuming abs(si) < N
m.as_mut()[ring_size - (si.abs() as usize)] = modulus.neg_one();
} else {
m.as_mut()[si as usize] = R::Element::one();
}
m
}
pub(crate) fn puncture_p_rng<S: Default + Copy, R: RandomFill<S>>( pub(crate) fn puncture_p_rng<S: Default + Copy, R: RandomFill<S>>(
p_rng: &mut R, p_rng: &mut R,
times: usize, times: usize,

Loading…
Cancel
Save