mirror of
https://github.com/arnaucube/phantom-zone.git
synced 2026-01-18 20:01:34 +01:00
Add multi-party Uint8
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::{
|
||||
bool::evaluator::{BoolEvaluator, ClientKey, ServerKeyEvaluationDomain, BOOL_SERVER_KEY},
|
||||
bool::evaluator::{
|
||||
BoolEvaluator, ClientKey, PublicKey, ServerKeyEvaluationDomain, BOOL_SERVER_KEY,
|
||||
},
|
||||
utils::{Global, WithLocal},
|
||||
Decryptor, Encryptor,
|
||||
Decryptor, Encryptor, Matrix, MultiPartyDecryptor,
|
||||
};
|
||||
|
||||
mod ops;
|
||||
@@ -26,6 +28,7 @@ impl Encryptor<u8, FheUint8> for ClientKey {
|
||||
|
||||
impl Decryptor<u8, FheUint8> for ClientKey {
|
||||
fn decrypt(&self, c: &FheUint8) -> u8 {
|
||||
assert!(c.data.len() == 8);
|
||||
let mut out = 0u8;
|
||||
c.data().iter().enumerate().for_each(|(index, bit_c)| {
|
||||
let bool = Decryptor::<bool, Vec<u64>>::decrypt(self, bit_c);
|
||||
@@ -37,6 +40,60 @@ impl Decryptor<u8, FheUint8> for ClientKey {
|
||||
}
|
||||
}
|
||||
|
||||
impl<M, R, Mo> Encryptor<u8, FheUint8> for PublicKey<M, R, Mo>
|
||||
where
|
||||
PublicKey<M, R, Mo>: Encryptor<bool, Vec<u64>>,
|
||||
{
|
||||
fn encrypt(&self, m: &u8) -> FheUint8 {
|
||||
let cts = (0..8)
|
||||
.into_iter()
|
||||
.map(|i| {
|
||||
let bit = ((m >> i) & 1) == 1;
|
||||
Encryptor::<bool, Vec<u64>>::encrypt(self, &bit)
|
||||
})
|
||||
.collect_vec();
|
||||
FheUint8 { data: cts }
|
||||
}
|
||||
}
|
||||
|
||||
impl MultiPartyDecryptor<u8, FheUint8> for ClientKey
|
||||
where
|
||||
ClientKey: MultiPartyDecryptor<bool, Vec<u64>>,
|
||||
{
|
||||
type DecryptionShare = Vec<<Self as MultiPartyDecryptor<bool, Vec<u64>>>::DecryptionShare>;
|
||||
fn gen_decryption_share(&self, c: &FheUint8) -> Self::DecryptionShare {
|
||||
assert!(c.data().len() == 8);
|
||||
c.data()
|
||||
.iter()
|
||||
.map(|bit_c| {
|
||||
let decryption_share =
|
||||
MultiPartyDecryptor::<bool, Vec<u64>>::gen_decryption_share(self, bit_c);
|
||||
decryption_share
|
||||
})
|
||||
.collect_vec()
|
||||
}
|
||||
|
||||
fn aggregate_decryption_shares(&self, c: &FheUint8, shares: &[Self::DecryptionShare]) -> u8 {
|
||||
let mut out = 0u8;
|
||||
|
||||
(0..8).into_iter().for_each(|i| {
|
||||
// Collect bit i^th decryption share of each party
|
||||
let bit_i_decryption_shares = shares.iter().map(|s| s[i]).collect_vec();
|
||||
let bit_i = MultiPartyDecryptor::<bool, Vec<u64>>::aggregate_decryption_shares(
|
||||
self,
|
||||
&c.data()[i],
|
||||
&bit_i_decryption_shares,
|
||||
);
|
||||
|
||||
if bit_i {
|
||||
out += 1 << i;
|
||||
}
|
||||
});
|
||||
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
mod frontend {
|
||||
use super::ops::{
|
||||
arbitrary_bit_adder, arbitrary_bit_division_for_quotient_and_rem, arbitrary_bit_subtractor,
|
||||
@@ -245,15 +302,20 @@ mod frontend {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use itertools::Itertools;
|
||||
use num_traits::Euclid;
|
||||
|
||||
use crate::{
|
||||
bool::{
|
||||
evaluator::{gen_keys, set_parameter_set, BoolEvaluator},
|
||||
parameters::SP_BOOL_PARAMS,
|
||||
evaluator::{
|
||||
aggregate_public_key_shares, aggregate_server_key_shares, gen_client_key, gen_keys,
|
||||
gen_mp_keys_phase1, gen_mp_keys_phase2, set_mp_seed, set_parameter_set,
|
||||
BoolEvaluator, ClientKey,
|
||||
},
|
||||
parameters::{MP_BOOL_PARAMS, SP_BOOL_PARAMS},
|
||||
},
|
||||
shortint::types::FheUint8,
|
||||
Decryptor, Encryptor,
|
||||
Decryptor, Encryptor, MultiPartyDecryptor,
|
||||
};
|
||||
|
||||
#[test]
|
||||
@@ -403,4 +465,57 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fheuint8_test_multi_party() {
|
||||
set_parameter_set(&MP_BOOL_PARAMS);
|
||||
set_mp_seed([0; 32]);
|
||||
|
||||
let parties = 8;
|
||||
|
||||
// client keys and public key share
|
||||
let cks = (0..parties)
|
||||
.into_iter()
|
||||
.map(|i| gen_client_key())
|
||||
.collect_vec();
|
||||
|
||||
// round 1: generate pulic key shares
|
||||
let pk_shares = cks.iter().map(|key| gen_mp_keys_phase1(key)).collect_vec();
|
||||
|
||||
let public_key = aggregate_public_key_shares(&pk_shares);
|
||||
|
||||
// round 2: generate server key shares
|
||||
let server_key_shares = cks
|
||||
.iter()
|
||||
.map(|key| gen_mp_keys_phase2(key, &public_key))
|
||||
.collect_vec();
|
||||
|
||||
// server aggregates the server key
|
||||
let server_key = aggregate_server_key_shares(&server_key_shares);
|
||||
server_key.set_server_key();
|
||||
|
||||
// Clients use Pk to encrypt private inputs
|
||||
let a = 8u8;
|
||||
let b = 10u8;
|
||||
let c = 155u8;
|
||||
let ct_a = public_key.encrypt(&a);
|
||||
let ct_b = public_key.encrypt(&b);
|
||||
let ct_c = public_key.encrypt(&c);
|
||||
|
||||
// server computes
|
||||
// a*b + c
|
||||
let mut ct_ab = &ct_a * &ct_b;
|
||||
ct_ab += &ct_c;
|
||||
|
||||
// decrypt ab and check
|
||||
// generate decryption shares
|
||||
let decryption_shares = cks
|
||||
.iter()
|
||||
.map(|k| k.gen_decryption_share(&ct_ab))
|
||||
.collect_vec();
|
||||
|
||||
// aggregate and decryption ab
|
||||
let ab_add_c = cks[0].aggregate_decryption_shares(&ct_ab, &decryption_shares);
|
||||
assert!(ab_add_c == (a.wrapping_mul(b)).wrapping_add(c));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user