Add multi-party Uint8

This commit is contained in:
Janmajaya Mall
2024-06-01 13:09:01 +05:30
parent 70cb18da57
commit 1c0ac104e2
5 changed files with 417 additions and 50 deletions

View File

@@ -1,9 +1,11 @@
use itertools::Itertools;
use crate::{
bool::evaluator::{BoolEvaluator, ClientKey, ServerKeyEvaluationDomain, BOOL_SERVER_KEY},
bool::evaluator::{
BoolEvaluator, ClientKey, PublicKey, ServerKeyEvaluationDomain, BOOL_SERVER_KEY,
},
utils::{Global, WithLocal},
Decryptor, Encryptor,
Decryptor, Encryptor, Matrix, MultiPartyDecryptor,
};
mod ops;
@@ -26,6 +28,7 @@ impl Encryptor<u8, FheUint8> for ClientKey {
impl Decryptor<u8, FheUint8> for ClientKey {
fn decrypt(&self, c: &FheUint8) -> u8 {
assert!(c.data.len() == 8);
let mut out = 0u8;
c.data().iter().enumerate().for_each(|(index, bit_c)| {
let bool = Decryptor::<bool, Vec<u64>>::decrypt(self, bit_c);
@@ -37,6 +40,60 @@ impl Decryptor<u8, FheUint8> for ClientKey {
}
}
impl<M, R, Mo> Encryptor<u8, FheUint8> for PublicKey<M, R, Mo>
where
PublicKey<M, R, Mo>: Encryptor<bool, Vec<u64>>,
{
fn encrypt(&self, m: &u8) -> FheUint8 {
let cts = (0..8)
.into_iter()
.map(|i| {
let bit = ((m >> i) & 1) == 1;
Encryptor::<bool, Vec<u64>>::encrypt(self, &bit)
})
.collect_vec();
FheUint8 { data: cts }
}
}
impl MultiPartyDecryptor<u8, FheUint8> for ClientKey
where
ClientKey: MultiPartyDecryptor<bool, Vec<u64>>,
{
type DecryptionShare = Vec<<Self as MultiPartyDecryptor<bool, Vec<u64>>>::DecryptionShare>;
fn gen_decryption_share(&self, c: &FheUint8) -> Self::DecryptionShare {
assert!(c.data().len() == 8);
c.data()
.iter()
.map(|bit_c| {
let decryption_share =
MultiPartyDecryptor::<bool, Vec<u64>>::gen_decryption_share(self, bit_c);
decryption_share
})
.collect_vec()
}
fn aggregate_decryption_shares(&self, c: &FheUint8, shares: &[Self::DecryptionShare]) -> u8 {
let mut out = 0u8;
(0..8).into_iter().for_each(|i| {
// Collect bit i^th decryption share of each party
let bit_i_decryption_shares = shares.iter().map(|s| s[i]).collect_vec();
let bit_i = MultiPartyDecryptor::<bool, Vec<u64>>::aggregate_decryption_shares(
self,
&c.data()[i],
&bit_i_decryption_shares,
);
if bit_i {
out += 1 << i;
}
});
out
}
}
mod frontend {
use super::ops::{
arbitrary_bit_adder, arbitrary_bit_division_for_quotient_and_rem, arbitrary_bit_subtractor,
@@ -245,15 +302,20 @@ mod frontend {
#[cfg(test)]
mod tests {
use itertools::Itertools;
use num_traits::Euclid;
use crate::{
bool::{
evaluator::{gen_keys, set_parameter_set, BoolEvaluator},
parameters::SP_BOOL_PARAMS,
evaluator::{
aggregate_public_key_shares, aggregate_server_key_shares, gen_client_key, gen_keys,
gen_mp_keys_phase1, gen_mp_keys_phase2, set_mp_seed, set_parameter_set,
BoolEvaluator, ClientKey,
},
parameters::{MP_BOOL_PARAMS, SP_BOOL_PARAMS},
},
shortint::types::FheUint8,
Decryptor, Encryptor,
Decryptor, Encryptor, MultiPartyDecryptor,
};
#[test]
@@ -403,4 +465,57 @@ mod tests {
}
}
}
#[test]
fn fheuint8_test_multi_party() {
set_parameter_set(&MP_BOOL_PARAMS);
set_mp_seed([0; 32]);
let parties = 8;
// client keys and public key share
let cks = (0..parties)
.into_iter()
.map(|i| gen_client_key())
.collect_vec();
// round 1: generate pulic key shares
let pk_shares = cks.iter().map(|key| gen_mp_keys_phase1(key)).collect_vec();
let public_key = aggregate_public_key_shares(&pk_shares);
// round 2: generate server key shares
let server_key_shares = cks
.iter()
.map(|key| gen_mp_keys_phase2(key, &public_key))
.collect_vec();
// server aggregates the server key
let server_key = aggregate_server_key_shares(&server_key_shares);
server_key.set_server_key();
// Clients use Pk to encrypt private inputs
let a = 8u8;
let b = 10u8;
let c = 155u8;
let ct_a = public_key.encrypt(&a);
let ct_b = public_key.encrypt(&b);
let ct_c = public_key.encrypt(&c);
// server computes
// a*b + c
let mut ct_ab = &ct_a * &ct_b;
ct_ab += &ct_c;
// decrypt ab and check
// generate decryption shares
let decryption_shares = cks
.iter()
.map(|k| k.gen_decryption_share(&ct_ab))
.collect_vec();
// aggregate and decryption ab
let ab_add_c = cks[0].aggregate_decryption_shares(&ct_ab, &decryption_shares);
assert!(ab_add_c == (a.wrapping_mul(b)).wrapping_add(c));
}
}