Browse Source

add example

par-agg-key-shares
Janmajaya Mall 10 months ago
parent
commit
1e0fb86782
7 changed files with 98 additions and 16 deletions
  1. +71
    -0
      examples/fheuint8.rs
  2. +10
    -2
      src/bool/mod.rs
  3. +2
    -2
      src/bool/parameters.rs
  4. +8
    -3
      src/lib.rs
  5. +2
    -2
      src/random.rs
  6. +4
    -6
      src/shortint/mod.rs
  7. +1
    -1
      src/shortint/types.rs

+ 71
- 0
examples/fheuint8.rs

@ -0,0 +1,71 @@
use bin_rs::*;
use itertools::Itertools;
use rand::{thread_rng, RngCore};
fn plain_circuit(a: u8, b: u8, c: u8) -> u8 {
(a + b) * c
}
fn fhe_circuit(fhe_a: &FheUint8, fhe_b: &FheUint8, fhe_c: &FheUint8) -> FheUint8 {
&(fhe_a + fhe_b) * fhe_c
}
fn main() {
set_parameter_set(ParameterSelector::MultiPartyLessThan16);
let no_of_parties = 2;
let client_keys = (0..no_of_parties)
.into_iter()
.map(|_| gen_client_key())
.collect_vec();
// set Multi-Party seed
let mut seed = [0u8; 32];
thread_rng().fill_bytes(&mut seed);
set_mp_seed(seed);
// multi-party key gen round 1
let pk_shares = client_keys
.iter()
.map(|k| gen_mp_keys_phase1(k))
.collect_vec();
// create public key
let public_key = aggregate_public_key_shares(&pk_shares);
// multi-party key gen round 2
let server_key_shares = client_keys
.iter()
.map(|k| gen_mp_keys_phase2(k, &public_key))
.collect_vec();
// server aggregates server key shares and sets it
let server_key = aggregate_server_key_shares(&server_key_shares);
server_key.set_server_key();
// private inputs
let a = 4u8;
let b = 6u8;
let c = 128u8;
let fhe_a = public_key.encrypt(&a);
let fhe_b = public_key.encrypt(&b);
let fhe_c = public_key.encrypt(&c);
// fhe evaluation
let now = std::time::Instant::now();
let fhe_out = fhe_circuit(&fhe_a, &fhe_b, &fhe_c);
println!("Circuit time: {:?}", now.elapsed());
// plain evaluation
let out = plain_circuit(a, b, c);
// generate decryption shares to decrypt ciphertext fhe_out
let decryption_shares = client_keys
.iter()
.map(|k| k.gen_decryption_share(&fhe_out))
.collect_vec();
// decrypt fhe_out using decryption shares
let got_out = client_keys[0].aggregate_decryption_shares(&fhe_out, &decryption_shares);
assert_eq!(got_out, out);
}

+ 10
- 2
src/bool/mod.rs

@ -26,8 +26,16 @@ static BOOL_SERVER_KEY: OnceLock>>>
static MULTI_PARTY_CRS: OnceLock<MultiPartyCrs<[u8; 32]>> = OnceLock::new(); static MULTI_PARTY_CRS: OnceLock<MultiPartyCrs<[u8; 32]>> = OnceLock::new();
pub fn set_parameter_set(parameter: &BoolParameters<u64>) {
BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(parameter.clone())));
pub enum ParameterSelector {
MultiPartyLessThan16,
}
pub fn set_parameter_set(select: ParameterSelector) {
match select {
ParameterSelector::MultiPartyLessThan16 => {
BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(SMALL_MP_BOOL_PARAMS)));
}
}
} }
pub fn set_mp_seed(seed: [u8; 32]) { pub fn set_mp_seed(seed: [u8; 32]) {

+ 2
- 2
src/bool/parameters.rs

@ -3,7 +3,7 @@ use num_traits::{ConstZero, FromPrimitive, PrimInt};
use crate::{backend::Modulus, decomposer::Decomposer}; use crate::{backend::Modulus, decomposer::Decomposer};
#[derive(Clone, PartialEq)] #[derive(Clone, PartialEq)]
pub(crate) struct BoolParameters<El> {
pub struct BoolParameters<El> {
rlwe_q: CiphertextModulus<El>, rlwe_q: CiphertextModulus<El>,
lwe_q: CiphertextModulus<El>, lwe_q: CiphertextModulus<El>,
br_q: usize, br_q: usize,
@ -181,7 +181,7 @@ pub(crate) struct PolynomialSize(pub(crate) usize);
/// T equals modulus when modulus is non-native. Otherwise T equals 0. bool is /// T equals modulus when modulus is non-native. Otherwise T equals 0. bool is
/// true when modulus is native, false otherwise. /// true when modulus is native, false otherwise.
pub(crate) struct CiphertextModulus<T>(T, bool);
pub struct CiphertextModulus<T>(T, bool);
impl<T: ConstZero> CiphertextModulus<T> { impl<T: ConstZero> CiphertextModulus<T> {
const fn new_native() -> Self { const fn new_native() -> Self {

+ 8
- 3
src/lib.rs

@ -23,8 +23,13 @@ mod utils;
pub use backend::{ pub use backend::{
ArithmeticLazyOps, ArithmeticOps, ModInit, ModularOpsU64, ShoupMatrixFMA, VectorOps, ArithmeticLazyOps, ArithmeticOps, ModInit, ModularOpsU64, ShoupMatrixFMA, VectorOps,
}; };
pub use bool::{
aggregate_public_key_shares, aggregate_server_key_shares, gen_client_key, gen_mp_keys_phase1,
gen_mp_keys_phase2, set_mp_seed, set_parameter_set, ParameterSelector,
};
pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer}; pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer};
pub use ntt::{Ntt, NttBackendU64, NttInit}; pub use ntt::{Ntt, NttBackendU64, NttInit};
pub use shortint::FheUint8;
pub trait Matrix: AsRef<[Self::R]> { pub trait Matrix: AsRef<[Self::R]> {
type MatElement; type MatElement;
@ -165,15 +170,15 @@ impl RowEntity for Vec {
} }
} }
trait Encryptor<M: ?Sized, C> {
pub trait Encryptor<M: ?Sized, C> {
fn encrypt(&self, m: &M) -> C; fn encrypt(&self, m: &M) -> C;
} }
trait Decryptor<M, C> {
pub trait Decryptor<M, C> {
fn decrypt(&self, c: &C) -> M; fn decrypt(&self, c: &C) -> M;
} }
trait MultiPartyDecryptor<M, C> {
pub trait MultiPartyDecryptor<M, C> {
type DecryptionShare; type DecryptionShare;
fn gen_decryption_share(&self, c: &C) -> Self::DecryptionShare; fn gen_decryption_share(&self, c: &C) -> Self::DecryptionShare;

+ 2
- 2
src/random.rs

@ -12,7 +12,7 @@ thread_local! {
pub(crate) static DEFAULT_RNG: RefCell<DefaultSecureRng> = RefCell::new(DefaultSecureRng::new_seeded([0u8;32])); pub(crate) static DEFAULT_RNG: RefCell<DefaultSecureRng> = RefCell::new(DefaultSecureRng::new_seeded([0u8;32]));
} }
pub(crate) trait NewWithSeed {
pub trait NewWithSeed {
type Seed; type Seed;
fn new_with_seed(seed: Self::Seed) -> Self; fn new_with_seed(seed: Self::Seed) -> Self;
} }
@ -59,7 +59,7 @@ where
fn random_fill(&mut self, modulus: &P, container: &mut M); fn random_fill(&mut self, modulus: &P, container: &mut M);
} }
pub(crate) struct DefaultSecureRng {
pub struct DefaultSecureRng {
rng: ChaCha8Rng, rng: ChaCha8Rng,
} }

+ 4
- 6
src/shortint/mod.rs

@ -8,7 +8,7 @@ use crate::{
mod ops; mod ops;
mod types; mod types;
type FheUint8 = types::FheUint8<Vec<u64>>;
pub type FheUint8 = types::FheUint8<Vec<u64>>;
impl Encryptor<u8, FheUint8> for ClientKey { impl Encryptor<u8, FheUint8> for ClientKey {
fn encrypt(&self, m: &u8) -> FheUint8 { fn encrypt(&self, m: &u8) -> FheUint8 {
@ -308,9 +308,7 @@ mod tests {
use crate::{ use crate::{
bool::{ bool::{
aggregate_public_key_shares, aggregate_server_key_shares, gen_client_key, gen_keys, aggregate_public_key_shares, aggregate_server_key_shares, gen_client_key, gen_keys,
gen_mp_keys_phase1, gen_mp_keys_phase2,
parameters::{MP_BOOL_PARAMS, SMALL_MP_BOOL_PARAMS, SP_BOOL_PARAMS},
set_mp_seed, set_parameter_set,
gen_mp_keys_phase1, gen_mp_keys_phase2, set_mp_seed, set_parameter_set,
}, },
shortint::types::FheUint8, shortint::types::FheUint8,
Decryptor, Encryptor, MultiPartyDecryptor, Decryptor, Encryptor, MultiPartyDecryptor,
@ -318,7 +316,7 @@ mod tests {
#[test] #[test]
fn all_uint8_apis() { fn all_uint8_apis() {
set_parameter_set(&SP_BOOL_PARAMS);
set_parameter_set(crate::ParameterSelector::MultiPartyLessThan16);
let (ck, sk) = gen_keys(); let (ck, sk) = gen_keys();
sk.set_server_key(); sk.set_server_key();
@ -466,7 +464,7 @@ mod tests {
#[test] #[test]
fn fheuint8_test_multi_party() { fn fheuint8_test_multi_party() {
set_parameter_set(&SMALL_MP_BOOL_PARAMS);
set_parameter_set(crate::ParameterSelector::MultiPartyLessThan16);
set_mp_seed([0; 32]); set_mp_seed([0; 32]);
let parties = 8; let parties = 8;

+ 1
- 1
src/shortint/types.rs

@ -1,5 +1,5 @@
#[derive(Clone)] #[derive(Clone)]
pub(super) struct FheUint8<C> {
pub struct FheUint8<C> {
pub(super) data: Vec<C>, pub(super) data: Vec<C>,
} }

Loading…
Cancel
Save