From d8d5e40f00c7255356e4020ab830dd79537b2de7 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Tue, 2 Jul 2024 10:30:11 +0530 Subject: [PATCH] implement min, max, mux --- examples/bomberman.rs | 44 +++++++++++++++++----------- examples/meeting_friends.rs | 44 +++++++++++++++------------- examples/non_interactive_fheuint8.rs | 44 ++++++++++++++-------------- src/bool/mp_api.rs | 15 ++++++++-- src/lib.rs | 12 ++------ src/lwe.rs | 13 ++++---- src/rgsw/mod.rs | 13 ++++---- src/shortint/mod.rs | 35 +++++++++++++++++++++- src/shortint/ops.rs | 10 ++++--- 9 files changed, 136 insertions(+), 94 deletions(-) diff --git a/examples/bomberman.rs b/examples/bomberman.rs index e51a8a5..fe5bba2 100644 --- a/examples/bomberman.rs +++ b/examples/bomberman.rs @@ -34,6 +34,8 @@ fn coordinates_is_equal(a: &Coordinates, b: &Coordinates) -> &(a.x().eq(b.x())) & &(a.y().eq(b.y())) } +/// Traverse the map with `p0` moves and check whether any of the moves equal +/// bomb coordinates (in encrypted domain) fn traverse_map(p0: &[Coordinates], bomb_coords: &[Coordinates]) -> FheBool { // First move let mut out = coordinates_is_equal(&p0[0], &bomb_coords[0]); @@ -52,23 +54,23 @@ fn traverse_map(p0: &[Coordinates], bomb_coords: &[Coordinates(), thread_rng().gen())) .collect_vec(); + // Coordinates of bomb placed by Player 1 let player_1_bomb = Coordinates::new(thread_rng().gen::(), thread_rng().gen()); + // Coordinates of bomb placed by Player 2 let player_2_bomb = Coordinates::new(thread_rng().gen::(), thread_rng().gen()); + // Coordinates of bomb placed by Player 3 let player_3_bomb = Coordinates::new(thread_rng().gen::(), thread_rng().gen()); println!("P0 moves coordinates: {:?}", &player_0_moves); - println!("P1 bomb coordinate : {:?}", &player_1_bomb); - println!("P2 bomb coordinate : {:?}", &player_2_bomb); - println!("P3 bomb coordinate : {:?}", &player_3_bomb); + println!("P1 bomb coordinates : {:?}", &player_1_bomb); + println!("P2 bomb coordinates : {:?}", &player_2_bomb); + println!("P3 bomb coordinates : {:?}", &player_3_bomb); - // Al players encrypt their private inputs + // Players encrypt their private inputs let player_0_enc = cks[0].encrypt( player_0_moves .iter() @@ -115,14 +122,14 @@ fn main() { let player_2_enc = cks[2].encrypt(vec![*player_2_bomb.x(), *player_2_bomb.y()].as_slice()); let player_3_enc = cks[3].encrypt(vec![*player_3_bomb.x(), *player_3_bomb.y()].as_slice()); - // All player upload the encrypted inputs and server key shates to the server + // Players upload the encrypted inputs and server key shares to the server // Server side // let server_key = aggregate_server_key_shares(&server_key_shares); server_key.set_server_key(); - // server parses all player inputs + // server parses Player inputs let player_0_moves_enc = { let c = player_0_enc .unseed::>>() @@ -147,17 +154,20 @@ fn main() { Coordinates::new(c.extract_at(0), c.extract_at(1)) }; - // run the game + // Server runs the game let player_0_dead_ct = traverse_map( &player_0_moves_enc, &vec![player_1_bomb_enc, player_2_bomb_enc, player_3_bomb_enc], ); - // All players generate decryption shares + // Client side // + + // Players generate decryption shares and send them to each other let decryption_shares = cks .iter() .map(|k| k.gen_decryption_share(&player_0_dead_ct)) .collect_vec(); + // Players decrypt to find whether Player 0 survived let player_0_dead = cks[0].aggregate_decryption_shares(&player_0_dead_ct, &decryption_shares); if player_0_dead { diff --git a/examples/meeting_friends.rs b/examples/meeting_friends.rs index 8e9f31a..856b28f 100644 --- a/examples/meeting_friends.rs +++ b/examples/meeting_friends.rs @@ -1,6 +1,6 @@ use bin_rs::*; use itertools::Itertools; -use rand::{thread_rng, RngCore}; +use rand::{thread_rng, Rng, RngCore}; struct Location(T, T); @@ -47,17 +47,19 @@ fn should_meet_fhe( // Here we write a simple application with two users `a` and `b`. User `a` wants // to find (long distance) friends that live in their neighbourhood. User `b` is // open to meeting new friends within some distance of their location. Both user -// `a` and `b` encrypt their location and upload to the server. User `b` also -// encrypts the distance square threshold within which they are interested in -// meeting new friends. The server calculates the square of the distance between -// user a's location and user b's location and returns encrypted boolean output -// indicating whether square of distance is <= user b's supplied distance square -// threshold. User `a` then comes online, downloads output ciphertext, produces -// their decryption share for user `b`, and uploads the decryption share to the +// `a` and `b` encrypt their locations and upload their encrypted locations to +// the server. User `b` also encrypts the distance square threshold within which +// they are interested in meeting new friends. and send encrypted distance +// square threshold to the server. +// The server calculates the square of the distance between user a's location +// and user b's location and produces encrypted boolean output indicating +// whether square of distance is <= user b's supplied distance square threshold. +// User `a` then comes online, downloads output ciphertext, produces their +// decryption share for user `b`, and uploads the decryption share to the // server. User `b` comes online, downloads output ciphertext and user a's // decryption share, produces their own decryption share, and then decrypts the -// encrypted boolean output. If the output is `True`, it indicates -// user `a` is within the distance square threshold defined by user `b`. +// encrypted boolean output. If the output is `True`, it indicates user `a` is +// within the distance square threshold defined by user `b`. fn main() { set_parameter_set(ParameterSelector::NonInteractiveLTE2Party); @@ -73,7 +75,7 @@ fn main() { // Generate client keys let cks = (0..no_of_parties).map(|_| gen_client_key()).collect_vec(); - // We assign id 0 to client 0 and id 1 to client 1 + // We assign user_id 0 to client 0 and user_id 1 to client 1 let a_id = 0; let b_id = 1; let user_a_secret = &cks[0]; @@ -85,30 +87,30 @@ fn main() { // User a and b encrypt their locations let user_a_secret = &cks[0]; - let user_a_location = Location::new(50, 60); + let user_a_location = Location::new(thread_rng().gen::(), thread_rng().gen::()); let user_a_enc = user_a_secret.encrypt(vec![*user_a_location.x(), *user_a_location.y()].as_slice()); - let user_b_location = Location::new(50, 60); - // User b also encrypts the distance sq threshold - let user_b_threshold = 20; + let user_b_location = Location::new(thread_rng().gen::(), thread_rng().gen::()); + // User b also encrypts the distance square threshold + let user_b_threshold = 40; let user_b_enc = user_b_secret .encrypt(vec![*user_b_location.x(), *user_b_location.y(), user_b_threshold].as_slice()); // Server Side // // Both user a and b upload their private inputs and server key shares to - // the server in one shot message + // the server in single shot message let server_key = aggregate_server_key_shares(&vec![a_server_key_share, b_server_key_share]); server_key.set_server_key(); // Server parses private inputs from user a and b let user_a_location_enc = { - let c = user_a_enc.unseed::>>().key_switch(0); + let c = user_a_enc.unseed::>>().key_switch(a_id); Location::new(c.extract_at(0), c.extract_at(1)) }; let (user_b_location_enc, user_b_threshold_enc) = { - let c = user_b_enc.unseed::>>().key_switch(1); + let c = user_b_enc.unseed::>>().key_switch(b_id); ( Location::new(c.extract_at(0), c.extract_at(1)), c.extract_at(2), @@ -124,13 +126,13 @@ fn main() { // Client Side // - // user a comes online, downloads out_c, produces a decryption share, and + // user `a` comes online, downloads `out_c`, produces a decryption share, and // uploads the decryption share to the server. let a_dec_share = user_a_secret.gen_decryption_share(&out_c); - // user b comes online downloads user a's decryption share, generates their + // user `b` comes online downloads user `a`'s decryption share, generates their // own decryption share, decrypts the output ciphertext. If the output is - // True, they contact user a to meet. + // True, they contact user `a` to meet. let b_dec_share = user_b_secret.gen_decryption_share(&out_c); let out_bool = user_b_secret.aggregate_decryption_shares(&out_c, &vec![b_dec_share, a_dec_share]); diff --git a/examples/non_interactive_fheuint8.rs b/examples/non_interactive_fheuint8.rs index fee6453..75ed907 100644 --- a/examples/non_interactive_fheuint8.rs +++ b/examples/non_interactive_fheuint8.rs @@ -46,7 +46,7 @@ fn main() { let c2_a = thread_rng().gen::(); let c2_enc = cks[2].encrypt(vec![c2_a].as_slice()); - // client 1 encrypts its private inputs + // client 3 encrypts its private inputs let c3_a = thread_rng().gen::(); let c3_enc = cks[3].encrypt(vec![c3_a].as_slice()); @@ -66,26 +66,26 @@ fn main() { // Server side // // Server receives server key shares from each client and proceeds to aggregate - // them to produce server key. After this point, server can use server key share - // to evaluate any arbitrary function on encrypted private inputs from the fixed - // set of clients + // them to produce the server key. After this point, server can use the server + // key to evaluate any arbitrary function on encrypted private inputs from + // the fixed set of clients - // aggregate shares and generates server key + // aggregate server shares and generate the server key let server_key = aggregate_server_key_shares(&server_key_shares); server_key.set_server_key(); // Server proceeds to extract private inputs sent by clients // // To extract client 0's (with user_id=0) private inputs we first key switch - // client 0's private inputs from thei secret to ideal secret of the mpc + // client 0's private inputs from theit secret to ideal secret of the mpc // protocol. To indicate we're key switching client 0's private input we - // supply client 0's user_id i.e. we call `key_switch(0)`. Then we extract + // supply client 0's `user_id` i.e. we call `key_switch(0)`. Then we extract // the first ciphertext by calling `extract_at(0)`. // - // Since client 0 only encrypted 1 input in batched ciphertext calling + // Since client 0 only encrypts 1 input in batched ciphertext, calling // extract_at(index) for `index` > 0 will panic. If client 0 had more private - // inputs then we can either extract them all at once by `extract_all` or first - // `many` of them by `extract_many(many)` + // inputs then we can either extract them all at once with `extract_all` or + // first `many` of them with `extract_many(many)` let ct_c0_a = c0_enc.unseed::>>().key_switch(0).extract_at(0); let ct_c1_a = c1_enc.unseed::>>().key_switch(1).extract_at(0); @@ -93,7 +93,7 @@ fn main() { let ct_c3_a = c3_enc.unseed::>>().key_switch(3).extract_at(0); // After extracting each client's private inputs, server proceeds to evaluate - // the function1 + // function1 let now = std::time::Instant::now(); let ct_out_f1 = function1_fhe(&ct_c0_a, &ct_c1_a, &ct_c2_a, &ct_c3_a); println!("Function1 FHE evaluation time: {:?}", now.elapsed()); @@ -104,10 +104,10 @@ fn main() { // Client side // // In multi-party decryption, each client needs to come online, download output - // ciphertext from the server, produce decryption share, and send to other - // parties (either via p2p or via server). After receving decryption shares - // for output ciphertext from other parties, client can independently decrypt - // output ciphertext. + // ciphertext from the server, produce "output ciphertext" dependent decryption + // share, and send it to other parties (either via p2p or via server). After + // receving decryption shares from other parties, clients can independently + // decrypt output ciphertext. // each client produces decryption share let decryption_shares = cks @@ -115,19 +115,19 @@ fn main() { .map(|k| k.gen_decryption_share(&ct_out_f1)) .collect_vec(); - // With all decrytpion shares, client can aggregate the shares and decrypt the + // With all decrytpion shares, clients can aggregate the shares and decrypt the // ciphertext let out_f1 = cks[0].aggregate_decryption_shares(&ct_out_f1, &decryption_shares); - // we check that output is correct + // we check correctness of function1 let want_out_f1 = function1(c0_a, c1_a, c2_a, c3_a); assert_eq!(out_f1, want_out_f1); // ----------- - // Server key can be re-used for different function with different private - // client inputs for same set of clients. Here we run `function2_fhe` for - // the same set of client but with new inputs. Client only have to upload their + // Server key can be re-used for different functions with different private + // client inputs for the same set of clients. Here we run `function2_fhe` for + // the same set of client but with new inputs. Clients only have to upload their // private inputs to the server this time. // Each client encrypts their private input @@ -140,7 +140,7 @@ fn main() { let c3_a = thread_rng().gen::(); let c3_enc = cks[3].encrypt(vec![c3_a].as_slice()); - // Client upload their private inputs to the server + // Clients upload their private inputs to the server // Server side // @@ -163,7 +163,7 @@ fn main() { .map(|k| k.gen_decryption_share(&ct_out_f2)) .collect_vec(); - // Client independently aggregate the shares and decrypt + // Clients independently aggregate the shares and decrypt let out_f2 = cks[0].aggregate_decryption_shares(&ct_out_f2, &decryption_shares); // We check correctness of function2 diff --git a/src/bool/mp_api.rs b/src/bool/mp_api.rs index 37b650b..d3a5030 100644 --- a/src/bool/mp_api.rs +++ b/src/bool/mp_api.rs @@ -517,7 +517,7 @@ mod tests { let ct = seeded_ct.unseed::>>(); let m_back = (0..batch_size) - .map(|i| ck.decrypt(&ct.extract(i))) + .map(|i| ck.decrypt(&ct.extract_at(i))) .collect_vec(); assert_eq!(m, m_back); @@ -528,7 +528,7 @@ mod tests { fn all_uint8_apis() { use num_traits::Euclid; - use crate::div_zero_error_flag; + use crate::{div_zero_error_flag, FheBool}; set_single_party_parameter_sets(SP_TEST_BOOL_PARAMS); @@ -624,7 +624,7 @@ mod tests { } } - // Comparisons + // // Comparisons { { let c_eq = c0.eq(&c1); @@ -681,6 +681,15 @@ mod tests { ); } } + + // mux + { + let selector = thread_rng().gen_bool(0.5); + let selector_enc: FheBool = ck.encrypt(&selector); + let mux_out = ck.decrypt(&c0.mux(&c1, &selector_enc)); + let want_mux_out = if selector { m0 } else { m1 }; + assert_eq!(mux_out, want_mux_out); + } } } } diff --git a/src/lib.rs b/src/lib.rs index a60222a..544973a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,13 +16,10 @@ mod utils; pub use backend::{ ArithmeticLazyOps, ArithmeticOps, ModInit, ModularOpsU64, ShoupMatrixFMA, VectorOps, }; -// pub use bool::{ -// aggregate_public_key_shares, aggregate_server_key_shares, gen_client_key, -// gen_mp_keys_phase1, gen_mp_keys_phase2, set_mp_seed, set_parameter_set, -// ParameterSelector, }; + pub use bool::*; pub use ntt::{Ntt, NttBackendU64, NttInit}; -pub use shortint::{div_zero_error_flag, FheUint8}; +pub use shortint::{div_zero_error_flag, reset_error_flags, FheUint8}; pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer}; @@ -96,11 +93,6 @@ pub trait RowEntity: Row { fn zeros(col: usize) -> Self; } -trait Secret { - type Element; - fn values(&self) -> &[Self::Element]; -} - impl Matrix for Vec> { type MatElement = T; type R = Vec; diff --git a/src/lwe.rs b/src/lwe.rs index 01d818b..aa2188c 100644 --- a/src/lwe.rs +++ b/src/lwe.rs @@ -173,7 +173,7 @@ mod tests { decomposer::DefaultDecomposer, random::{DefaultSecureRng, NewWithSeed}, utils::{fill_random_ternary_secret_with_hamming_weight, WithLocal}, - MatrixEntity, MatrixMut, Secret, + MatrixEntity, MatrixMut, }; use super::*; @@ -185,13 +185,6 @@ mod tests { pub(crate) values: Vec, } - impl Secret for LweSecret { - type Element = i32; - fn values(&self) -> &[Self::Element] { - &self.values - } - } - impl LweSecret { fn random(hw: usize, n: usize) -> LweSecret { DefaultSecureRng::with_local_mut(|rng| { @@ -201,6 +194,10 @@ mod tests { LweSecret { values: out } }) } + + fn values(&self) -> &[i32] { + &self.values + } } struct LweKeySwitchingKey { diff --git a/src/rgsw/mod.rs b/src/rgsw/mod.rs index a3bb5c2..2e0234c 100644 --- a/src/rgsw/mod.rs +++ b/src/rgsw/mod.rs @@ -24,7 +24,7 @@ pub(crate) mod tests { fill_random_ternary_secret_with_hamming_weight, generate_prime, negacyclic_mul, tests::Stats, ToShoup, TryConvertFrom1, WithLocal, }, - Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, + Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, }; use super::{ @@ -406,13 +406,6 @@ pub(crate) mod tests { pub(crate) values: Vec, } - impl Secret for RlweSecret { - type Element = i32; - fn values(&self) -> &[Self::Element] { - &self.values - } - } - impl RlweSecret { pub fn random(hw: usize, n: usize) -> RlweSecret { DefaultSecureRng::with_local_mut(|rng| { @@ -422,6 +415,10 @@ pub(crate) mod tests { RlweSecret { values: out } }) } + + fn values(&self) -> &[i32] { + &self.values + } } fn random_seed() -> [u8; 32] { diff --git a/src/shortint/mod.rs b/src/shortint/mod.rs index fc4d7ed..e8ecdb7 100644 --- a/src/shortint/mod.rs +++ b/src/shortint/mod.rs @@ -17,6 +17,16 @@ pub fn div_zero_error_flag() -> Option { DIV_ZERO_ERROR.with_borrow(|c| c.clone()) } +/// Reset all error flags +/// +/// Error flags are thread local. When running multiple circuits in sequence +/// within a single program you must prevent error flags set during the +/// execution of previous circuit to affect error flags set during execution of +/// the next circuit by resetting the flags before starting with next circuit. +pub fn reset_error_flags() { + DIV_ZERO_ERROR.with_borrow_mut(|c| *c = None); +} + mod frontend { use super::ops::{ arbitrary_bit_adder, arbitrary_bit_division_for_quotient_and_rem, arbitrary_bit_subtractor, @@ -176,7 +186,9 @@ mod frontend { } mod booleans { - use crate::shortint::ops::{arbitrary_bit_comparator, arbitrary_bit_equality}; + use crate::shortint::ops::{ + arbitrary_bit_comparator, arbitrary_bit_equality, arbitrary_bit_mux, + }; use super::*; @@ -238,6 +250,27 @@ mod frontend { FheBool { data: a_less_b } }) } + + /// Returns `Self` if `selector = True` else returns `other` + pub fn mux(&self, other: &FheUint8, selector: &FheBool) -> FheUint8 { + BoolEvaluator::with_local_mut(|e| { + let key = RuntimeServerKey::global(); + let out = arbitrary_bit_mux(e, selector.data(), self.data(), other.data(), key); + FheUint8 { data: out } + }) + } + + /// max(`Self`, `other`) + pub fn max(&self, other: &FheUint8) -> FheUint8 { + let self_gt = self.gt(other); + self.mux(other, &self_gt) + } + + /// min(`Self`, `other`) + pub fn min(&self, other: &FheUint8) -> FheUint8 { + let self_lt = self.lt(other); + self.mux(other, &self_lt) + } } } } diff --git a/src/shortint/ops.rs b/src/shortint/ops.rs index e4911ec..1beae6a 100644 --- a/src/shortint/ops.rs +++ b/src/shortint/ops.rs @@ -114,9 +114,10 @@ pub(super) fn bit_mux( // (s&a) | ((1-s)^b) let not_selector = evaluator.not(&selector); - let s_and_a = evaluator.and(&selector, if_true, key); + let mut s_and_a = evaluator.and(&selector, if_true, key); let s_and_b = evaluator.and(¬_selector, if_false, key); - evaluator.or(&s_and_a, &s_and_b, key) + evaluator.or(&mut s_and_a, &s_and_b, key); + s_and_a } pub(super) fn arbitrary_bit_mux( @@ -131,9 +132,10 @@ pub(super) fn arbitrary_bit_mux( izip!(if_true.iter(), if_false.iter()) .map(|(a, b)| { - let s_and_a = evaluator.and(&selector, a, key); + let mut s_and_a = evaluator.and(&selector, a, key); let s_and_b = evaluator.and(¬_selector, b, key); - evaluator.or(&s_and_a, &s_and_b, key) + evaluator.or_inplace(&mut s_and_a, &s_and_b, key); + s_and_a }) .collect() }