Browse Source

Add multi-party Uint8

par-agg-key-shares
Janmajaya Mall 10 months ago
parent
commit
1c0ac104e2
5 changed files with 418 additions and 51 deletions
  1. +269
    -41
      src/bool/evaluator.rs
  2. +3
    -3
      src/bool/parameters.rs
  3. +11
    -2
      src/lib.rs
  4. +15
    -0
      src/random.rs
  5. +120
    -5
      src/shortint/mod.rs

+ 269
- 41
src/bool/evaluator.rs

@ -1,5 +1,6 @@
use std::{ use std::{
cell::{OnceCell, RefCell}, cell::{OnceCell, RefCell},
clone,
collections::HashMap, collections::HashMap,
fmt::{Debug, Display}, fmt::{Debug, Display},
iter::Once, iter::Once,
@ -20,8 +21,8 @@ use crate::{
multi_party::public_key_share, multi_party::public_key_share,
ntt::{self, Ntt, NttBackendU64, NttInit}, ntt::{self, Ntt, NttBackendU64, NttInit},
random::{ random::{
DefaultSecureRng, NewWithSeed, RandomFillGaussianInModulus, RandomFillUniformInModulus,
RandomGaussianElementInModulus,
DefaultSecureRng, NewWithSeed, RandomFill, RandomFillGaussianInModulus,
RandomFillUniformInModulus, RandomGaussianElementInModulus,
}, },
rgsw::{ rgsw::{
decrypt_rlwe, galois_auto, galois_key_gen, generate_auto_map, public_key_encrypt_rgsw, decrypt_rlwe, galois_auto, galois_key_gen, generate_auto_map, public_key_encrypt_rgsw,
@ -32,10 +33,11 @@ use crate::{
fill_random_ternary_secret_with_hamming_weight, generate_prime, mod_exponent, Global, fill_random_ternary_secret_with_hamming_weight, generate_prime, mod_exponent, Global,
TryConvertFrom1, WithLocal, TryConvertFrom1, WithLocal,
}, },
Decryptor, Encryptor, Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret,
Decryptor, Encryptor, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, Row, RowEntity,
RowMut, Secret,
}; };
use super::parameters::{BoolParameters, CiphertextModulus};
use super::parameters::{self, BoolParameters, CiphertextModulus};
thread_local! { thread_local! {
pub(crate) static BOOL_EVALUATOR: RefCell<BoolEvaluator<Vec<Vec<u64>>, NttBackendU64, ModularOpsU64<CiphertextModulus<u64>>, ModularOpsU64<CiphertextModulus<u64>>>> = RefCell::new(BoolEvaluator::new(MP_BOOL_PARAMS)); pub(crate) static BOOL_EVALUATOR: RefCell<BoolEvaluator<Vec<Vec<u64>>, NttBackendU64, ModularOpsU64<CiphertextModulus<u64>>, ModularOpsU64<CiphertextModulus<u64>>>> = RefCell::new(BoolEvaluator::new(MP_BOOL_PARAMS));
@ -45,10 +47,19 @@ pub(crate) static BOOL_SERVER_KEY: OnceLock<
ServerKeyEvaluationDomain<Vec<Vec<u64>>, DefaultSecureRng, NttBackendU64>, ServerKeyEvaluationDomain<Vec<Vec<u64>>, DefaultSecureRng, NttBackendU64>,
> = OnceLock::new(); > = OnceLock::new();
pub(crate) static MULTI_PARTY_CRS: OnceLock<MultiPartyCrs<[u8; 32]>> = OnceLock::new();
pub fn set_parameter_set(parameter: &BoolParameters<u64>) { pub fn set_parameter_set(parameter: &BoolParameters<u64>) {
BoolEvaluator::with_local_mut(|e| *e = BoolEvaluator::new(parameter.clone())) BoolEvaluator::with_local_mut(|e| *e = BoolEvaluator::new(parameter.clone()))
} }
pub fn set_mp_seed(seed: [u8; 32]) {
assert!(
MULTI_PARTY_CRS.set(MultiPartyCrs { seed: seed }).is_ok(),
"Attempted to set MP SEED twice."
)
}
fn set_server_key(key: ServerKeyEvaluationDomain<Vec<Vec<u64>>, DefaultSecureRng, NttBackendU64>) { fn set_server_key(key: ServerKeyEvaluationDomain<Vec<Vec<u64>>, DefaultSecureRng, NttBackendU64>) {
assert!( assert!(
BOOL_SERVER_KEY.set(key).is_ok(), BOOL_SERVER_KEY.set(key).is_ok(),
@ -56,7 +67,7 @@ fn set_server_key(key: ServerKeyEvaluationDomain>, DefaultSecureRng
); );
} }
pub fn gen_keys() -> (
pub(crate) fn gen_keys() -> (
ClientKey, ClientKey,
SeededServerKey<Vec<Vec<u64>>, BoolParameters<u64>, [u8; 32]>, SeededServerKey<Vec<Vec<u64>>, BoolParameters<u64>, [u8; 32]>,
) { ) {
@ -67,6 +78,93 @@ pub fn gen_keys() -> (
(ck, sk) (ck, sk)
}) })
} }
pub fn gen_client_key() -> ClientKey {
BoolEvaluator::with_local(|e| e.client_key())
}
pub fn gen_mp_keys_phase1(
ck: &ClientKey,
) -> CommonReferenceSeededCollectivePublicKeyShare<Vec<u64>, [u8; 32], BoolParameters<u64>> {
let seed = MultiPartyCrs::global().public_key_share_seed::<DefaultSecureRng>();
BoolEvaluator::with_local(|e| {
let pk_share = e.multi_party_public_key_share(seed, &ck);
pk_share
})
}
pub fn gen_mp_keys_phase2<R, ModOp>(
ck: &ClientKey,
pk: &PublicKey<Vec<Vec<u64>>, R, ModOp>,
) -> CommonReferenceSeededMultiPartyServerKeyShare<Vec<Vec<u64>>, BoolParameters<u64>, [u8; 32]> {
let seed = MultiPartyCrs::global().server_key_share_seed::<DefaultSecureRng>();
BoolEvaluator::with_local_mut(|e| {
let server_key_share = e.multi_party_server_key_share(seed, &pk.key, ck);
server_key_share
})
}
pub fn aggregate_public_key_shares(
shares: &[CommonReferenceSeededCollectivePublicKeyShare<
Vec<u64>,
[u8; 32],
BoolParameters<u64>,
>],
) -> PublicKey<Vec<Vec<u64>>, DefaultSecureRng, ModularOpsU64<CiphertextModulus<u64>>> {
PublicKey::from(shares)
}
pub fn aggregate_server_key_shares(
shares: &[CommonReferenceSeededMultiPartyServerKeyShare<
Vec<Vec<u64>>,
BoolParameters<u64>,
[u8; 32],
>],
) -> SeededMultiPartyServerKey<Vec<Vec<u64>>, [u8; 32], BoolParameters<u64>> {
BoolEvaluator::with_local(|e| e.aggregate_multi_party_server_key_shares(shares))
}
// GENERIC BELOW
pub struct MultiPartyCrs<S> {
seed: S,
}
impl<S: Default + Copy> MultiPartyCrs<S> {
/// Seed to generate public key share using MultiPartyCrs as the main seed.
///
/// Public key seed equals the 1st seed extracted from PRNG Seeded with
/// MiltiPartyCrs's seed.
fn public_key_share_seed<Rng: NewWithSeed<Seed = S> + RandomFill<S>>(&self) -> S {
let mut prng = Rng::new_with_seed(self.seed);
let mut seed = S::default();
RandomFill::<S>::random_fill(&mut prng, &mut seed);
seed
}
/// Seed to generate server key share using MultiPartyCrs as the main seed.
///
/// Server key seed equals the 2nd seed extracted from PRNG Seeded with
/// MiltiPartyCrs's seed.
fn server_key_share_seed<Rng: NewWithSeed<Seed = S> + RandomFill<S>>(&self) -> S {
let mut prng = Rng::new_with_seed(self.seed);
let mut seed = S::default();
RandomFill::<S>::random_fill(&mut prng, &mut seed);
RandomFill::<S>::random_fill(&mut prng, &mut seed);
seed
}
}
impl Global for MultiPartyCrs<[u8; 32]> {
fn global() -> &'static Self {
MULTI_PARTY_CRS
.get()
.expect("Multi Party Common Reference String not set")
}
}
pub(crate) trait BooleanGates { pub(crate) trait BooleanGates {
type Ciphertext: RowEntity; type Ciphertext: RowEntity;
type Key; type Key;
@ -323,19 +421,116 @@ impl Decryptor> for ClientKey {
} }
} }
struct MultiPartyDecryptionShare<E> {
share: E,
impl MultiPartyDecryptor<bool, Vec<u64>> for ClientKey {
type DecryptionShare = u64;
fn gen_decryption_share(&self, c: &Vec<u64>) -> Self::DecryptionShare {
BoolEvaluator::with_local(|e| e.multi_party_decryption_share(c, &self))
}
fn aggregate_decryption_shares(&self, c: &Vec<u64>, shares: &[Self::DecryptionShare]) -> bool {
BoolEvaluator::with_local(|e| e.multi_party_decrypt(shares, c))
}
} }
struct CommonReferenceSeededCollectivePublicKeyShare<R, S, P> {
// struct MultiPartyDecryptionShare<E> {
// share: E,
// }
pub struct CommonReferenceSeededCollectivePublicKeyShare<R, S, P> {
share: R, share: R,
cr_seed: S, cr_seed: S,
parameters: P, parameters: P,
} }
struct PublicKey<M, R, O> {
struct SeededPublicKey<R, S, P, ModOp> {
part_b: R,
seed: S,
parameters: P,
_phantom: PhantomData<ModOp>,
}
impl<R, S, ModOp>
From<&[CommonReferenceSeededCollectivePublicKeyShare<R, S, BoolParameters<R::Element>>]>
for SeededPublicKey<R, S, BoolParameters<R::Element>, ModOp>
where
ModOp: VectorOps<Element = R::Element> + ModInit<M = CiphertextModulus<R::Element>>,
S: PartialEq + Clone,
R: RowMut + RowEntity + Clone,
R::Element: Clone + PartialEq,
{
fn from(
value: &[CommonReferenceSeededCollectivePublicKeyShare<R, S, BoolParameters<R::Element>>],
) -> Self {
assert!(value.len() > 0);
let parameters = &value[0].parameters;
let cr_seed = value[0].cr_seed.clone();
// Sum all Bs
let rlweq_modop = ModOp::new(parameters.rlwe_q().clone());
let mut part_b = value[0].share.clone();
value.iter().skip(1).for_each(|share_i| {
assert!(&share_i.cr_seed == &cr_seed);
assert!(&share_i.parameters == parameters);
rlweq_modop.elwise_add_mut(part_b.as_mut(), share_i.share.as_ref());
});
Self {
part_b,
seed: cr_seed,
parameters: parameters.clone(),
_phantom: PhantomData,
}
}
}
pub struct PublicKey<M, Rng, ModOp> {
key: M, key: M,
_phantom: PhantomData<(R, O)>,
_phantom: PhantomData<(Rng, ModOp)>,
}
impl<Rng, ModOp> Encryptor<bool, Vec<u64>> for PublicKey<Vec<Vec<u64>>, Rng, ModOp> {
fn encrypt(&self, m: &bool) -> Vec<u64> {
BoolEvaluator::with_local(|e| e.pk_encrypt(&self.key, *m))
}
}
impl<Rng, ModOp> Encryptor<[bool], Vec<Vec<u64>>> for PublicKey<Vec<Vec<u64>>, Rng, ModOp> {
fn encrypt(&self, m: &[bool]) -> Vec<Vec<u64>> {
BoolEvaluator::with_local(|e| e.pk_encrypt_batched(&self.key, m))
}
}
impl<
M: MatrixMut + MatrixEntity,
Rng: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], CiphertextModulus<M::MatElement>>,
ModOp,
> From<SeededPublicKey<M::R, Rng::Seed, BoolParameters<M::MatElement>, ModOp>>
for PublicKey<M, Rng, ModOp>
where
<M as Matrix>::R: RowMut,
M::MatElement: Copy,
{
fn from(value: SeededPublicKey<M::R, Rng::Seed, BoolParameters<M::MatElement>, ModOp>) -> Self {
let mut prng = Rng::new_with_seed(value.seed);
let mut key = M::zeros(2, value.parameters.rlwe_n().0);
// sample A
RandomFillUniformInModulus::random_fill(
&mut prng,
value.parameters.rlwe_q(),
key.get_row_mut(0),
);
// Copy over B
key.get_row_mut(1).copy_from_slice(value.part_b.as_ref());
PublicKey {
key,
_phantom: PhantomData,
}
}
} }
impl< impl<
@ -392,7 +587,7 @@ where
} }
} }
struct CommonReferenceSeededMultiPartyServerKeyShare<M: Matrix, P, S> {
pub struct CommonReferenceSeededMultiPartyServerKeyShare<M: Matrix, P, S> {
rgsw_cts: Vec<M>, rgsw_cts: Vec<M>,
/// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding /// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding
/// to -g is at 0 /// to -g is at 0
@ -402,7 +597,7 @@ struct CommonReferenceSeededMultiPartyServerKeyShare {
cr_seed: S, cr_seed: S,
parameters: P, parameters: P,
} }
struct SeededMultiPartyServerKey<M: Matrix, S, P> {
pub struct SeededMultiPartyServerKey<M: Matrix, S, P> {
rgsw_cts: Vec<M>, rgsw_cts: Vec<M>,
/// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding /// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding
/// to -g is at 0 /// to -g is at 0
@ -412,6 +607,22 @@ struct SeededMultiPartyServerKey {
parameters: P, parameters: P,
} }
impl
SeededMultiPartyServerKey<
Vec<Vec<u64>>,
<DefaultSecureRng as NewWithSeed>::Seed,
BoolParameters<u64>,
>
{
pub fn set_server_key(&self) {
set_server_key(ServerKeyEvaluationDomain::<
Vec<Vec<u64>>,
DefaultSecureRng,
NttBackendU64,
>::from(self))
}
}
/// Seeded single party server key /// Seeded single party server key
pub struct SeededServerKey<M: Matrix, P, S> { pub struct SeededServerKey<M: Matrix, P, S> {
/// Rgsw cts of LWE secret elements /// Rgsw cts of LWE secret elements
@ -709,7 +920,6 @@ where
} }
} }
impl<M: Matrix, R, N> PbsKey for ServerKeyEvaluationDomain<M, R, N> { impl<M: Matrix, R, N> PbsKey for ServerKeyEvaluationDomain<M, R, N> {
type M = M; type M = M;
fn galois_key_for_auto(&self, k: usize) -> &Self::M { fn galois_key_for_auto(&self, k: usize) -> &Self::M {
@ -1241,7 +1451,7 @@ where
&self, &self,
lwe_ct: &M::R, lwe_ct: &M::R,
client_key: &ClientKey, client_key: &ClientKey,
) -> MultiPartyDecryptionShare<<M as Matrix>::MatElement> {
) -> <M as Matrix>::MatElement {
assert!(lwe_ct.as_ref().len() == self.pbs_info.parameters.rlwe_n().0 + 1); assert!(lwe_ct.as_ref().len() == self.pbs_info.parameters.rlwe_n().0 + 1);
let modop = &self.pbs_info.rlwe_modop; let modop = &self.pbs_info.rlwe_modop;
let mut neg_s = M::R::try_convert_from( let mut neg_s = M::R::try_convert_from(
@ -1262,34 +1472,44 @@ where
}); });
let share = modop.add(&neg_sa, &e); let share = modop.add(&neg_sa, &e);
MultiPartyDecryptionShare { share }
share
} }
pub(crate) fn multi_party_decrypt(
&self,
shares: &[MultiPartyDecryptionShare<M::MatElement>],
lwe_ct: &M::R,
) -> bool {
pub(crate) fn multi_party_decrypt(&self, shares: &[M::MatElement], lwe_ct: &M::R) -> bool {
let modop = &self.pbs_info.rlwe_modop; let modop = &self.pbs_info.rlwe_modop;
let mut sum_a = M::MatElement::zero(); let mut sum_a = M::MatElement::zero();
shares shares
.iter() .iter()
.for_each(|share_i| sum_a = modop.add(&sum_a, &share_i.share));
.for_each(|share_i| sum_a = modop.add(&sum_a, &share_i));
let encoded_m = modop.add(&lwe_ct.as_ref()[0], &sum_a); let encoded_m = modop.add(&lwe_ct.as_ref()[0], &sum_a);
self.pbs_info.parameters.rlwe_q().decode(encoded_m) self.pbs_info.parameters.rlwe_q().decode(encoded_m)
} }
/// First encrypt as RLWE(m) with m as constant polynomial and extract it as
/// LWE ciphertext
pub(crate) fn pk_encrypt(&self, pk: &M, m: bool) -> M::R { pub(crate) fn pk_encrypt(&self, pk: &M, m: bool) -> M::R {
self.pk_encrypt_batched(pk, &vec![m]).remove(0)
}
/// Encrypts a batch booleans as multiple LWE ciphertexts.
///
/// For public key encryption we first encrypt `m` as a RLWE ciphertext and
/// then sample extract LWE samples at required indices.
///
/// - TODO(Jay:) Communication can be improved by not sample exctracting and
/// instead just truncate degree 0 values (part Bs)
pub(crate) fn pk_encrypt_batched(&self, pk: &M, m: &[bool]) -> Vec<M::R> {
DefaultSecureRng::with_local_mut(|rng| { DefaultSecureRng::with_local_mut(|rng| {
let ring_size = self.pbs_info.parameters.rlwe_n().0;
assert!(
m.len() <= ring_size,
"Cannot batch encrypt > ring_size{ring_size} elements at once"
);
let modop = &self.pbs_info.rlwe_modop; let modop = &self.pbs_info.rlwe_modop;
let nttop = &self.pbs_info.rlwe_nttop; let nttop = &self.pbs_info.rlwe_nttop;
// RLWE(0) // RLWE(0)
// sample ephemeral key u // sample ephemeral key u
let ring_size = self.pbs_info.parameters.rlwe_n().0;
let mut u = vec![0i32; ring_size]; let mut u = vec![0i32; ring_size];
fill_random_ternary_secret_with_hamming_weight(u.as_mut(), ring_size >> 1, rng); fill_random_ternary_secret_with_hamming_weight(u.as_mut(), ring_size >> 1, rng);
let mut u = M::R::try_convert_from(&u, &self.pbs_info.parameters.rlwe_q()); let mut u = M::R::try_convert_from(&u, &self.pbs_info.parameters.rlwe_q());
@ -1326,22 +1546,31 @@ where
modop.elwise_add_mut(rlwe.get_row_mut(1), ub.as_ref()); modop.elwise_add_mut(rlwe.get_row_mut(1), ub.as_ref());
//FIXME(Jay): Figure out a way to get Q/8 form modulus //FIXME(Jay): Figure out a way to get Q/8 form modulus
let m = if m {
// Q/8
self.pbs_info.rlwe_q().true_el()
} else {
// -Q/8
self.pbs_info.rlwe_q().false_el()
};
// b*u + e1 + m, where m is constant polynomial
rlwe.set(1, 0, modop.add(rlwe.get(1, 0), &m));
// sample extract index 0
let mut lwe_out = M::R::zeros(ring_size + 1);
sample_extract(&mut lwe_out, &rlwe, modop, 0);
let mut m_vec = M::R::zeros(ring_size);
izip!(m_vec.as_mut().iter_mut(), m.iter()).for_each(|(m_el, m_bool)| {
if *m_bool {
// Q/8
*m_el = self.pbs_info.rlwe_q().true_el()
} else {
// -Q/8
*m_el = self.pbs_info.rlwe_q().false_el()
}
});
lwe_out
// b*u + e1 + m
modop.elwise_add_mut(rlwe.get_row_mut(1), m_vec.as_ref());
// rlwe.set(1, 0, modop.add(rlwe.get(1, 0), &m));
// sample extract index required indices
let samples = m.len();
(0..samples)
.into_iter()
.map(|i| {
let mut lwe_out = M::R::zeros(ring_size + 1);
sample_extract(&mut lwe_out, &rlwe, modop, i);
lwe_out
})
.collect_vec()
}) })
} }
@ -2103,7 +2332,6 @@ fn pbs, K: PbsK
pbs_key, pbs_key,
); );
// sample extract // sample extract
sample_extract(lwe_in, &trivial_rlwe_test_poly, pbs_info.modop_rlweq(), 0); sample_extract(lwe_in, &trivial_rlwe_test_poly, pbs_info.modop_rlweq(), 0);
} }
@ -2731,7 +2959,7 @@ mod tests {
>::new(MP_BOOL_PARAMS); >::new(MP_BOOL_PARAMS);
let (parties, collective_pk, _, _, server_key_eval, ideal_client_key) = let (parties, collective_pk, _, _, server_key_eval, ideal_client_key) =
_multi_party_all_keygen(&bool_evaluator, 8);
_multi_party_all_keygen(&bool_evaluator, 64);
let mut m0 = true; let mut m0 = true;
let mut m1 = false; let mut m1 = false;

+ 3
- 3
src/bool/parameters.rs

@ -307,10 +307,10 @@ pub(crate) const SP_BOOL_PARAMS: BoolParameters = BoolParameters:: {
auto_decomposer_base: DecompostionLogBase(7), auto_decomposer_base: DecompostionLogBase(7),
auto_decomposer_count: DecompositionCount(4), auto_decomposer_count: DecompositionCount(4),
g: 5, g: 5,
w: 10,
w: 5,
}; };
pub(super) const MP_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u64> {
pub(crate) const MP_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u64> {
rlwe_q: CiphertextModulus::new_non_native(1152921504606830593), rlwe_q: CiphertextModulus::new_non_native(1152921504606830593),
lwe_q: CiphertextModulus::new_non_native(1 << 20), lwe_q: CiphertextModulus::new_non_native(1 << 20),
br_q: 1 << 11, br_q: 1 << 11,
@ -325,7 +325,7 @@ pub(super) const MP_BOOL_PARAMS: BoolParameters = BoolParameters:: {
auto_decomposer_base: DecompostionLogBase(12), auto_decomposer_base: DecompostionLogBase(12),
auto_decomposer_count: DecompositionCount(5), auto_decomposer_count: DecompositionCount(5),
g: 5, g: 5,
w: 5,
w: 10,
}; };
#[cfg(test)] #[cfg(test)]

+ 11
- 2
src/lib.rs

@ -1,3 +1,5 @@
use std::{iter::Once, sync::OnceLock};
use itertools::{izip, Itertools}; use itertools::{izip, Itertools};
use num::UnsignedInteger; use num::UnsignedInteger;
use num_traits::{abs, Zero}; use num_traits::{abs, Zero};
@ -156,10 +158,17 @@ impl RowEntity for Vec {
} }
} }
trait Encryptor<M, C> {
trait Encryptor<M: ?Sized, C> {
fn encrypt(&self, m: &M) -> C; fn encrypt(&self, m: &M) -> C;
} }
trait Decryptor<M, C: ?Sized> {
trait Decryptor<M, C> {
fn decrypt(&self, c: &C) -> M; fn decrypt(&self, c: &C) -> M;
} }
trait MultiPartyDecryptor<M, C> {
type DecryptionShare;
fn gen_decryption_share(&self, c: &C) -> Self::DecryptionShare;
fn aggregate_decryption_shares(&self, c: &C, shares: &[Self::DecryptionShare]) -> M;
}

+ 15
- 0
src/random.rs

@ -138,6 +138,21 @@ where
} }
} }
impl<T> RandomFill<[T; 32]> for DefaultSecureRng
where
T: PrimInt + SampleUniform,
{
fn random_fill(&mut self, container: &mut [T; 32]) {
izip!(
(&mut self.rng).sample_iter(Uniform::new_inclusive(T::zero(), T::max_value())),
container.iter_mut()
)
.for_each(|(from, to)| {
*to = from;
});
}
}
impl<T> RandomElement<T> for DefaultSecureRng impl<T> RandomElement<T> for DefaultSecureRng
where where
T: PrimInt + SampleUniform, T: PrimInt + SampleUniform,

+ 120
- 5
src/shortint/mod.rs

@ -1,9 +1,11 @@
use itertools::Itertools; use itertools::Itertools;
use crate::{ use crate::{
bool::evaluator::{BoolEvaluator, ClientKey, ServerKeyEvaluationDomain, BOOL_SERVER_KEY},
bool::evaluator::{
BoolEvaluator, ClientKey, PublicKey, ServerKeyEvaluationDomain, BOOL_SERVER_KEY,
},
utils::{Global, WithLocal}, utils::{Global, WithLocal},
Decryptor, Encryptor,
Decryptor, Encryptor, Matrix, MultiPartyDecryptor,
}; };
mod ops; mod ops;
@ -26,6 +28,7 @@ impl Encryptor for ClientKey {
impl Decryptor<u8, FheUint8> for ClientKey { impl Decryptor<u8, FheUint8> for ClientKey {
fn decrypt(&self, c: &FheUint8) -> u8 { fn decrypt(&self, c: &FheUint8) -> u8 {
assert!(c.data.len() == 8);
let mut out = 0u8; let mut out = 0u8;
c.data().iter().enumerate().for_each(|(index, bit_c)| { c.data().iter().enumerate().for_each(|(index, bit_c)| {
let bool = Decryptor::<bool, Vec<u64>>::decrypt(self, bit_c); let bool = Decryptor::<bool, Vec<u64>>::decrypt(self, bit_c);
@ -37,6 +40,60 @@ impl Decryptor for ClientKey {
} }
} }
impl<M, R, Mo> Encryptor<u8, FheUint8> for PublicKey<M, R, Mo>
where
PublicKey<M, R, Mo>: Encryptor<bool, Vec<u64>>,
{
fn encrypt(&self, m: &u8) -> FheUint8 {
let cts = (0..8)
.into_iter()
.map(|i| {
let bit = ((m >> i) & 1) == 1;
Encryptor::<bool, Vec<u64>>::encrypt(self, &bit)
})
.collect_vec();
FheUint8 { data: cts }
}
}
impl MultiPartyDecryptor<u8, FheUint8> for ClientKey
where
ClientKey: MultiPartyDecryptor<bool, Vec<u64>>,
{
type DecryptionShare = Vec<<Self as MultiPartyDecryptor<bool, Vec<u64>>>::DecryptionShare>;
fn gen_decryption_share(&self, c: &FheUint8) -> Self::DecryptionShare {
assert!(c.data().len() == 8);
c.data()
.iter()
.map(|bit_c| {
let decryption_share =
MultiPartyDecryptor::<bool, Vec<u64>>::gen_decryption_share(self, bit_c);
decryption_share
})
.collect_vec()
}
fn aggregate_decryption_shares(&self, c: &FheUint8, shares: &[Self::DecryptionShare]) -> u8 {
let mut out = 0u8;
(0..8).into_iter().for_each(|i| {
// Collect bit i^th decryption share of each party
let bit_i_decryption_shares = shares.iter().map(|s| s[i]).collect_vec();
let bit_i = MultiPartyDecryptor::<bool, Vec<u64>>::aggregate_decryption_shares(
self,
&c.data()[i],
&bit_i_decryption_shares,
);
if bit_i {
out += 1 << i;
}
});
out
}
}
mod frontend { mod frontend {
use super::ops::{ use super::ops::{
arbitrary_bit_adder, arbitrary_bit_division_for_quotient_and_rem, arbitrary_bit_subtractor, arbitrary_bit_adder, arbitrary_bit_division_for_quotient_and_rem, arbitrary_bit_subtractor,
@ -245,15 +302,20 @@ mod frontend {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use itertools::Itertools;
use num_traits::Euclid; use num_traits::Euclid;
use crate::{ use crate::{
bool::{ bool::{
evaluator::{gen_keys, set_parameter_set, BoolEvaluator},
parameters::SP_BOOL_PARAMS,
evaluator::{
aggregate_public_key_shares, aggregate_server_key_shares, gen_client_key, gen_keys,
gen_mp_keys_phase1, gen_mp_keys_phase2, set_mp_seed, set_parameter_set,
BoolEvaluator, ClientKey,
},
parameters::{MP_BOOL_PARAMS, SP_BOOL_PARAMS},
}, },
shortint::types::FheUint8, shortint::types::FheUint8,
Decryptor, Encryptor,
Decryptor, Encryptor, MultiPartyDecryptor,
}; };
#[test] #[test]
@ -403,4 +465,57 @@ mod tests {
} }
} }
} }
#[test]
fn fheuint8_test_multi_party() {
set_parameter_set(&MP_BOOL_PARAMS);
set_mp_seed([0; 32]);
let parties = 8;
// client keys and public key share
let cks = (0..parties)
.into_iter()
.map(|i| gen_client_key())
.collect_vec();
// round 1: generate pulic key shares
let pk_shares = cks.iter().map(|key| gen_mp_keys_phase1(key)).collect_vec();
let public_key = aggregate_public_key_shares(&pk_shares);
// round 2: generate server key shares
let server_key_shares = cks
.iter()
.map(|key| gen_mp_keys_phase2(key, &public_key))
.collect_vec();
// server aggregates the server key
let server_key = aggregate_server_key_shares(&server_key_shares);
server_key.set_server_key();
// Clients use Pk to encrypt private inputs
let a = 8u8;
let b = 10u8;
let c = 155u8;
let ct_a = public_key.encrypt(&a);
let ct_b = public_key.encrypt(&b);
let ct_c = public_key.encrypt(&c);
// server computes
// a*b + c
let mut ct_ab = &ct_a * &ct_b;
ct_ab += &ct_c;
// decrypt ab and check
// generate decryption shares
let decryption_shares = cks
.iter()
.map(|k| k.gen_decryption_share(&ct_ab))
.collect_vec();
// aggregate and decryption ab
let ab_add_c = cks[0].aggregate_decryption_shares(&ct_ab, &decryption_shares);
assert!(ab_add_c == (a.wrapping_mul(b)).wrapping_add(c));
}
} }

Loading…
Cancel
Save