From 4e6a9aa3a755e92832f74ab6094f23b78df741b8 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Tue, 2 Jul 2024 15:28:46 +0530 Subject: [PATCH] amend interactive fhe uint8 example --- examples/interactive_fheuint8.rs | 188 +++++++++++++++++++++------ examples/non_interactive_fheuint8.rs | 15 ++- src/bool/evaluator.rs | 2 +- src/bool/mp_api.rs | 22 ++-- src/bool/print_noise.rs | 11 +- 5 files changed, 182 insertions(+), 56 deletions(-) diff --git a/examples/interactive_fheuint8.rs b/examples/interactive_fheuint8.rs index 82c3f71..081c9c9 100644 --- a/examples/interactive_fheuint8.rs +++ b/examples/interactive_fheuint8.rs @@ -1,72 +1,184 @@ use bin_rs::*; use itertools::Itertools; -use rand::{thread_rng, RngCore}; +use rand::{thread_rng, Rng, RngCore}; -fn plain_circuit(a: u8, b: u8, c: u8) -> u8 { - (a + b) * c +fn function1(a: u8, b: u8, c: u8, d: u8) -> u8 { + ((a + b) * c) * d } -fn fhe_circuit(fhe_a: &FheUint8, fhe_b: &FheUint8, fhe_c: &FheUint8) -> FheUint8 { - &(fhe_a + fhe_b) * fhe_c +fn function1_fhe(a: &FheUint8, b: &FheUint8, c: &FheUint8, d: &FheUint8) -> FheUint8 { + &(&(a + b) * c) * d +} + +fn function2(a: u8, b: u8, c: u8, d: u8) -> u8 { + (a * b) + (c * d) +} + +fn function2_fhe(a: &FheUint8, b: &FheUint8, c: &FheUint8, d: &FheUint8) -> FheUint8 { + &(a * b) + &(c * d) } fn main() { + // Select parameter set set_parameter_set(ParameterSelector::InteractiveLTE2Party); + + // set application's common reference seed + let mut seed = [0u8; 32]; + thread_rng().fill_bytes(&mut seed); + set_common_reference_seed(seed); + let no_of_parties = 2; - let client_keys = (0..no_of_parties) + + // Client side // + + // Clients generate their private keys + let cks = (0..no_of_parties) .into_iter() .map(|_| gen_client_key()) .collect_vec(); - // set Multi-Party seed - let mut seed = [0u8; 32]; - thread_rng().fill_bytes(&mut seed); - set_mp_seed(seed); + // -- Round 1 -- // + // In round 1 each client generates their share for the collective public key. + // They send public key shares to each other with out without server. After + // receiving others public key shares client independently aggregates the share + // and produces the collective public key `pk` - // multi-party key gen round 1 - let pk_shares = client_keys + let pk_shares = cks .iter() - .map(|k| gen_mp_keys_phase1(k)) + .map(|k| interactive_multi_party_round1_share(k)) .collect_vec(); - // create public key - let public_key = aggregate_public_key_shares(&pk_shares); + // Clients aggregate public key shares to produce collective public key `pk` + let pk = aggregate_public_key_shares(&pk_shares); - // multi-party key gen round 2 - let server_key_shares = client_keys + // -- Round 2 -- // + // In round 2 each client generates server key shares using the public key `pk`. + // Clients may also encrypt their private inputs using collective public key + // `pk`. Each client then uploads their server key share and private input + // ciphertexts to the server. + + // Clients generate server key shares + // + // We assign user_id 0 to client 0, user_id 1 to client 1, user_id 2 to client + // 2, and user_id 4 to client 4. + // + // Note that `user_id`'s must be unique among the clients and must be less than + // total number of clients. + let server_key_shares = cks .iter() .enumerate() - .map(|(user_id, k)| gen_mp_keys_phase2(k, user_id, no_of_parties, &public_key)) + .map(|(user_id, k)| gen_mp_keys_phase2(k, user_id, no_of_parties, &pk)) .collect_vec(); - // server aggregates server key shares and sets it + // Each client encrypts their private inputs using the collective public key + // `pk`. Unlike non-inteactive MPC protocol, given that private inputs are + // encrypted using collective public key, the private inputs are directly + // encrypted under the ideal RLWE secret `s`. + let c0_a = thread_rng().gen::(); + let c0_enc = pk.encrypt(vec![c0_a].as_slice()); + let c1_a = thread_rng().gen::(); + let c1_enc = pk.encrypt(vec![c1_a].as_slice()); + let c2_a = thread_rng().gen::(); + let c2_enc = pk.encrypt(vec![c2_a].as_slice()); + let c3_a = thread_rng().gen::(); + let c3_enc = pk.encrypt(vec![c3_a].as_slice()); + + // Clients upload their server key along with private encrypted inputs to + // the server + + // Server side // + + // Server receives server key shares from each client and proceeds to + // aggregated the shares and produce the server key let server_key = aggregate_server_key_shares(&server_key_shares); server_key.set_server_key(); - // private inputs - let a = 4u8; - let b = 6u8; - let c = 128u8; - let fhe_a = public_key.encrypt(&a); - let fhe_b = public_key.encrypt(&b); - let fhe_c = public_key.encrypt(&c); + // Server proceeds to extract clients private inputs + // + // Clients encrypt their FheUint8s inputs packed in a batched ciphertext. + // The server must extract clients private inputs from the batch ciphertext + // either (1) using `extract_at(index)` to extract `index`^{th} FheUint8 + // ciphertext (2) `extract_all()` to extract all available FheUint8s (3) + // `extract_many(many)` to extract first `many` available FheUint8s + let c0_a_enc = c0_enc.extract_at(0); + let c1_a_enc = c1_enc.extract_at(0); + let c2_a_enc = c2_enc.extract_at(0); + let c3_a_enc = c3_enc.extract_at(0); + + // Server proceeds to evaluate function1 on clients private inputs + let ct_out_f1 = function1_fhe(&c0_a_enc, &c1_a_enc, &c2_a_enc, &c3_a_enc); + + // After server has finished evaluating the circuit on client private + // inputs. Clients can proceed to multi-party decryption protocol to + // decryption output ciphertext + + // Client Side // + + // In multi-party decryption protocol, client must come online, download the + // output ciphertext from the server, product "output ciphertext" dependent + // decryption share, and send it to other parties. After receiving + // decryption shares of other parties, client independently aggregates the + // decrytion shares and decrypts the output ciphertext. + + // Client generate decryption shares + let decryption_shares = cks + .iter() + .map(|k| k.gen_decryption_share(&ct_out_f1)) + .collect_vec(); + + // After receiving decryption shares from other parties, client aggregates the + // shares and decryption output ciphertext + let out_f1 = cks[0].aggregate_decryption_shares(&ct_out_f1, &decryption_shares); + + // Check correctness of function1 output + let want_f1 = function1(c0_a, c1_a, c2_a, c3_a); + assert!(out_f1 == want_f1); + + // -------- + + // Once server key is produced it can be re-used across different functions + // with different private client inputs for the same set of clients. + // + // Here we run `function2_fhe` for the same of clients but with different + // private inputs. Clients do not need to participate in the 2 round + // protocol again, instead they only upload their new private inputs to the + // server. + + // Clients encrypt their private inputs + let c0_a = thread_rng().gen::(); + let c0_enc = pk.encrypt(vec![c0_a].as_slice()); + let c1_a = thread_rng().gen::(); + let c1_enc = pk.encrypt(vec![c1_a].as_slice()); + let c2_a = thread_rng().gen::(); + let c2_enc = pk.encrypt(vec![c2_a].as_slice()); + let c3_a = thread_rng().gen::(); + let c3_enc = pk.encrypt(vec![c3_a].as_slice()); + + // Clients uploads only their new private inputs to the server + + // Server side // + + // Server receives private inputs from the clients, extract them, and + // proceeds to evaluate `function2_fhe` + let c0_a_enc = c0_enc.extract_at(0); + let c1_a_enc = c1_enc.extract_at(0); + let c2_a_enc = c2_enc.extract_at(0); + let c3_a_enc = c3_enc.extract_at(0); - // fhe evaluation - let now = std::time::Instant::now(); - let fhe_out = fhe_circuit(&fhe_a, &fhe_b, &fhe_c); - println!("Circuit time: {:?}", now.elapsed()); + let ct_out_f2 = function2_fhe(&c0_a_enc, &c1_a_enc, &c2_a_enc, &c3_a_enc); - // plain evaluation - let out = plain_circuit(a, b, c); + // Client side // - // generate decryption shares to decrypt ciphertext fhe_out - let decryption_shares = client_keys + // Clients generate decryption shares for `ct_out_f2` + let decryption_shares = cks .iter() - .map(|k| k.gen_decryption_share(&fhe_out)) + .map(|k| k.gen_decryption_share(&ct_out_f2)) .collect_vec(); - // decrypt fhe_out using decryption shares - let got_out = client_keys[0].aggregate_decryption_shares(&fhe_out, &decryption_shares); + // Clients aggregate decryption shares and decrypt `ct_out_f2` + let out_f2 = cks[0].aggregate_decryption_shares(&ct_out_f2, &decryption_shares); - assert_eq!(got_out, out); + // We check correctness of function2 + let want_f2 = function2(c0_a, c1_a, c2_a, c3_a); + assert!(want_f2 == out_f2); } diff --git a/examples/non_interactive_fheuint8.rs b/examples/non_interactive_fheuint8.rs index 75ed907..3680477 100644 --- a/examples/non_interactive_fheuint8.rs +++ b/examples/non_interactive_fheuint8.rs @@ -35,7 +35,8 @@ fn main() { // client 0 encrypts its private inputs let c0_a = thread_rng().gen::(); - // Clients encrypt their private inputs in a seeded batched ciphertext + // Clients encrypt their private inputs in a seeded batched ciphertext using + // their private RLWE secret `u_j`. let c0_enc = cks[0].encrypt(vec![c0_a].as_slice()); // client 1 encrypts its private inputs @@ -54,6 +55,9 @@ fn main() { // // We assign user_id 0 to client 0, user_id 1 to client 1, user_id 2 to client // 2, user_id 3 to client 3. + // + // Note that `user_id`s must be unique among the clients and must be less than + // total number of clients. let server_key_shares = cks .iter() .enumerate() @@ -126,9 +130,10 @@ fn main() { // ----------- // Server key can be re-used for different functions with different private - // client inputs for the same set of clients. Here we run `function2_fhe` for - // the same set of client but with new inputs. Clients only have to upload their - // private inputs to the server this time. + // client inputs for the same set of clients. + // + // Here we run `function2_fhe` for the same set of client but with new inputs. + // Clients only have to upload their private inputs to the server this time. // Each client encrypts their private input let c0_a = thread_rng().gen::(); @@ -140,7 +145,7 @@ fn main() { let c3_a = thread_rng().gen::(); let c3_enc = cks[3].encrypt(vec![c3_a].as_slice()); - // Clients upload their private inputs to the server + // Clients upload only their new private inputs to the server // Server side // diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index 00ed7db..1e55c5b 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -22,7 +22,7 @@ use crate::{ pbs::{pbs, PbsInfo, PbsKey, WithShoupRepr}, random::{ DefaultSecureRng, NewWithSeed, RandomFill, RandomFillGaussianInModulus, - RandomFillUniformInModulus, RandomGaussianElementInModulus, + RandomFillUniformInModulus, }, rgsw::{ generate_auto_map, public_key_encrypt_rgsw, rgsw_by_rgsw_inplace, rgsw_x_rgsw_scratch_rows, diff --git a/src/bool/mp_api.rs b/src/bool/mp_api.rs index 623bca3..e28ebe3 100644 --- a/src/bool/mp_api.rs +++ b/src/bool/mp_api.rs @@ -43,7 +43,7 @@ pub fn set_parameter_set(select: ParameterSelector) { } /// Set application specific interactive multi-party common reference string -pub fn set_mp_seed(seed: [u8; 32]) { +pub fn set_common_reference_seed(seed: [u8; 32]) { assert!( MULTI_PARTY_CRS .set(InteractiveMultiPartyCrs { seed: seed }) @@ -57,9 +57,9 @@ pub fn gen_client_key() -> ClientKey { BoolEvaluator::with_local(|e| e.client_key()) } -/// Generate client's share for collective public key, i.e round 1, of the -/// protocol -pub fn gen_mp_keys_phase1( +/// Generate client's share for collective public key, i.e round 1 share, in +/// round 1 of the 2 round protocol +pub fn interactive_multi_party_round1_share( ck: &ClientKey, ) -> CommonReferenceSeededCollectivePublicKeyShare, [u8; 32], BoolParameters> { BoolEvaluator::with_local(|e| { @@ -319,13 +319,16 @@ mod tests { set_parameter_set(ParameterSelector::InteractiveLTE2Party); let mut seed = [0u8; 32]; thread_rng().fill_bytes(&mut seed); - set_mp_seed(seed); + set_common_reference_seed(seed); let parties = 2; let cks = (0..parties).map(|_| gen_client_key()).collect_vec(); // round 1 - let pk_shares = cks.iter().map(|k| gen_mp_keys_phase1(k)).collect_vec(); + let pk_shares = cks + .iter() + .map(|k| interactive_multi_party_round1_share(k)) + .collect_vec(); // collective pk let pk = aggregate_public_key_shares(&pk_shares); @@ -408,13 +411,16 @@ mod tests { set_parameter_set(ParameterSelector::InteractiveLTE2Party); let mut seed = [0u8; 32]; thread_rng().fill_bytes(&mut seed); - set_mp_seed(seed); + set_common_reference_seed(seed); let parties = 2; let cks = (0..parties).map(|_| gen_client_key()).collect_vec(); // round 1 - let pk_shares = cks.iter().map(|k| gen_mp_keys_phase1(k)).collect_vec(); + let pk_shares = cks + .iter() + .map(|k| interactive_multi_party_round1_share(k)) + .collect_vec(); // collective pk let pk = aggregate_public_key_shares(&pk_shares); diff --git a/src/bool/print_noise.rs b/src/bool/print_noise.rs index cb1764a..5ca9dc4 100644 --- a/src/bool/print_noise.rs +++ b/src/bool/print_noise.rs @@ -374,19 +374,22 @@ mod tests { evaluator::InteractiveMultiPartyCrs, keys::{key_size::KeySize, ServerKeyEvaluationDomain}, }, - gen_client_key, gen_mp_keys_phase1, gen_mp_keys_phase2, + gen_client_key, gen_mp_keys_phase2, interactive_multi_party_round1_share, parameters::CiphertextModulus, random::DefaultSecureRng, - set_mp_seed, set_parameter_set, + set_common_reference_seed, set_parameter_set, utils::WithLocal, BoolEvaluator, DefaultDecomposer, ModularOpsU64, Ntt, NttBackendU64, }; set_parameter_set(crate::ParameterSelector::InteractiveLTE2Party); - set_mp_seed(InteractiveMultiPartyCrs::random().seed); + set_common_reference_seed(InteractiveMultiPartyCrs::random().seed); let parties = 2; let cks = (0..parties).map(|_| gen_client_key()).collect_vec(); - let pk_shares = cks.iter().map(|k| gen_mp_keys_phase1(k)).collect_vec(); + let pk_shares = cks + .iter() + .map(|k| interactive_multi_party_round1_share(k)) + .collect_vec(); let pk = aggregate_public_key_shares(&pk_shares); let server_key_shares = cks