diff --git a/Cargo.toml b/Cargo.toml index 6194cd1..a78aba5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,4 +33,10 @@ path = "./examples/interactive_fheuint8.rs" [[example]] name = "non_interactive_fheuint8" path = "./examples/non_interactive_fheuint8.rs" +required-features = ["non_interactive_mp"] + + +[[example]] +name = "meeting_friends" +path = "./examples/meeting_friends.rs" required-features = ["non_interactive_mp"] \ No newline at end of file diff --git a/examples/meeting_friends.rs b/examples/meeting_friends.rs new file mode 100644 index 0000000..5139d2c --- /dev/null +++ b/examples/meeting_friends.rs @@ -0,0 +1,145 @@ +use bin_rs::*; +use itertools::Itertools; +use rand::{thread_rng, RngCore}; + +struct Location(T, T); + +impl Location { + fn new(x: T, y: T) -> Self { + Location(x, y) + } + + fn x(&self) -> &T { + &self.0 + } + fn y(&self) -> &T { + &self.1 + } +} + +fn should_meet(a: &Location, b: &Location, b_threshold: &u8) -> bool { + let diff_x = a.x() - b.x(); + let diff_y = a.y() - b.y(); + let d_sq = &(&diff_x * &diff_x) + &(&diff_y * &diff_y); + + d_sq.le(b_threshold) +} + +/// Calculates distance square between a's and b's location. Returns a boolean +/// indicating whether diatance sqaure is <= `b_threshold`. +fn should_meet_fhe( + a: &Location, + b: &Location, + b_threshold: &FheUint8, +) -> FheBool { + let diff_x = a.x() - b.x(); + let diff_y = a.y() - b.y(); + let d_sq = &(&diff_x * &diff_x) + &(&diff_y * &diff_y); + + d_sq.le(b_threshold) +} + +// Even wondered who are the long distance friends (friends of friends or +// friends of friends of friends...) that live nearby ? But how do you find +// them? Surely no-one will simply reveal their exact location just because +// there's a slight chance that a long distance friend lives nearby. +// +// Here we write a simple application with two users `a` and `b`. User `a` wants +// to find (long distance) friends that live in their neighbourhood. User `b` is +// open to meeting new friends within some distance of their location. Both user +// `a` and `b` encrypt their location and upload to the server. User `b` also +// encrypts the distance square threshold within which they are interested in +// meeting new friends. The server calculates the square of the distance between +// user a's location and user b's location and returns encrypted boolean output +// indicating whether square of distance is <= user b's supplied distance square +// threshold. User `a` then comes online, downloads output ciphertext, produces +// their decryption share for user `b`, and uploads the decryption share to the +// server. User `b` comes online, downloads output ciphertext and user a's +// decryption share, produces their own decryption share, and then decrypts the +// encrypted boolean output. If the output is `True`, it indicates +// user `a` is within the distance square threshold defined by user `b`. +fn main() { + set_parameter_set(ParameterSelector::NonInteractiveLTE2Party); + + // set application's common reference seed + let mut seed = [0u8; 32]; + thread_rng().fill_bytes(&mut seed); + set_common_reference_seed(seed); + + let no_of_parties = 2; + + // Client Side // + + // Generate client keys + let cks = (0..no_of_parties).map(|_| gen_client_key()).collect_vec(); + + // We assign id 0 to client 0 and id 1 to client 1 + let a_id = 0; + let b_id = 1; + let user_a_secret = &cks[0]; + let user_b_secret = &cks[1]; + + // User a and b generate server key shares + let a_server_key_share = gen_server_key_share(a_id, no_of_parties, user_a_secret); + let b_server_key_share = gen_server_key_share(b_id, no_of_parties, user_b_secret); + + // User a and b encrypt their locations + let user_a_secret = &cks[0]; + let user_a_location = Location::new(50, 60); + let user_a_enc = + user_a_secret.encrypt(vec![*user_a_location.x(), *user_a_location.y()].as_slice()); + + let user_b_location = Location::new(50, 60); + // User b also encrypts the distance sq threshold + let user_b_threshold = 20; + let user_b_enc = user_b_secret + .encrypt(vec![*user_b_location.x(), *user_b_location.y(), user_b_threshold].as_slice()); + + // Server Side // + + // Both user a and b upload their private inputs and server key shares to + // the server in one shot message + let server_key = aggregate_server_key_shares(&vec![a_server_key_share, b_server_key_share]); + server_key.set_server_key(); + + // Server parses private inputs from user a and b + let user_a_location_enc = { + let c = user_a_enc.unseed::>>().key_switch(0); + Location::new(c.extract(0), c.extract(1)) + }; + let (user_b_location_enc, user_b_threshold_enc) = { + let c = user_b_enc.unseed::>>().key_switch(1); + (Location::new(c.extract(0), c.extract(1)), c.extract(2)) + }; + + // run the circuit + let out_c = should_meet_fhe( + &user_a_location_enc, + &user_b_location_enc, + &user_b_threshold_enc, + ); + + // Client Side // + + // user a comes online, downloads out_c, produces a decryption share, and + // uploads the decryption share to the server. + let a_dec_share = user_a_secret.gen_decryption_share(&out_c); + + // user b comes online downloads user a's decryption share, generates their + // own decryption share, decrypts the output ciphertext. If the output is + // True, they contact user a to meet. + let b_dec_share = user_b_secret.gen_decryption_share(&out_c); + let out_bool = + user_b_secret.aggregate_decryption_shares(&out_c, &vec![b_dec_share, a_dec_share]); + + assert_eq!( + out_bool, + should_meet(&user_a_location, &user_b_location, &user_b_threshold) + ); + + if out_bool { + println!("A lives nearby. B should meet A."); + } else { + println!("A lives too far away!") + } +} diff --git a/src/lib.rs b/src/lib.rs index 0e91204..e504983 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,7 +22,7 @@ pub use backend::{ // ParameterSelector, }; pub use bool::*; pub use ntt::{Ntt, NttBackendU64, NttInit}; -pub use shortint::{div_zero_error_flag, FheUint8}; +pub use shortint::{div_zero_error_flag, FheBool, FheUint8}; pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer}; diff --git a/src/shortint/enc_dec.rs b/src/shortint/enc_dec.rs index 8c15df9..4c2c593 100644 --- a/src/shortint/enc_dec.rs +++ b/src/shortint/enc_dec.rs @@ -8,6 +8,12 @@ use crate::{ RowMut, SampleExtractor, }; +/// Fhe Bool ciphertext +#[derive(Clone)] +pub struct FheBool { + pub(super) data: C, +} + /// Fhe UInt8 type /// /// - Stores encryptions of bits in little endian (i.e least signficant bit @@ -204,6 +210,25 @@ where } } +impl MultiPartyDecryptor> for K +where + K: MultiPartyDecryptor, +{ + type DecryptionShare = >::DecryptionShare; + + fn aggregate_decryption_shares( + &self, + c: &FheBool, + shares: &[Self::DecryptionShare], + ) -> bool { + self.aggregate_decryption_shares(&c.data, shares) + } + + fn gen_decryption_share(&self, c: &FheBool) -> Self::DecryptionShare { + self.gen_decryption_share(&c.data) + } +} + impl Encryptor> for K where K: Encryptor, diff --git a/src/shortint/mod.rs b/src/shortint/mod.rs index 61ce65a..09f65c8 100644 --- a/src/shortint/mod.rs +++ b/src/shortint/mod.rs @@ -1,9 +1,8 @@ mod enc_dec; mod ops; -mod types; pub type FheUint8 = enc_dec::FheUint8>; -pub type FheBool = Vec; +pub type FheBool = enc_dec::FheBool>; use std::cell::RefCell; @@ -15,7 +14,7 @@ thread_local! { /// Returns Boolean ciphertext indicating whether last division was attempeted /// with decnomiantor set to 0. -pub fn div_zero_error_flag() -> Option> { +pub fn div_zero_error_flag() -> Option { DIV_ZERO_ERROR.with_borrow(|c| c.clone()) } @@ -83,7 +82,7 @@ mod frontend { // set div by 0 error flag let is_zero = is_zero(e, rhs.data(), key); - DIV_ZERO_ERROR.set(Some(is_zero)); + DIV_ZERO_ERROR.set(Some(FheBool { data: is_zero })); let (quotient, _) = arbitrary_bit_division_for_quotient_and_rem( e, @@ -118,7 +117,7 @@ mod frontend { let key = RuntimeServerKey::global(); let (overflow, _) = arbitrary_bit_adder(e, self.data_mut(), rhs.data(), false, key); - overflow + FheBool { data: overflow } }) } @@ -128,7 +127,7 @@ mod frontend { let key = RuntimeServerKey::global(); let (overflow, _) = arbitrary_bit_adder(e, lhs.data_mut(), rhs.data(), false, key); - (lhs, overflow) + (lhs, FheBool { data: overflow }) }) } @@ -138,7 +137,7 @@ mod frontend { let (out, mut overflow, _) = arbitrary_bit_subtractor(e, self.data(), rhs.data(), key); e.not_inplace(&mut overflow); - (FheUint8 { data: out }, overflow) + (FheUint8 { data: out }, FheBool { data: overflow }) }) } @@ -148,7 +147,7 @@ mod frontend { // set div by 0 error flag let is_zero = is_zero(e, rhs.data(), key); - DIV_ZERO_ERROR.set(Some(is_zero)); + DIV_ZERO_ERROR.set(Some(FheBool { data: is_zero })); let (quotient, remainder) = arbitrary_bit_division_for_quotient_and_rem( e, @@ -172,7 +171,8 @@ mod frontend { pub fn eq(&self, other: &FheUint8) -> FheBool { BoolEvaluator::with_local_mut(|e| { let key = RuntimeServerKey::global(); - arbitrary_bit_equality(e, self.data(), other.data(), key) + let out = arbitrary_bit_equality(e, self.data(), other.data(), key); + FheBool { data: out } }) } @@ -182,7 +182,7 @@ mod frontend { let key = RuntimeServerKey::global(); let mut is_equal = arbitrary_bit_equality(e, self.data(), other.data(), key); e.not_inplace(&mut is_equal); - is_equal + FheBool { data: is_equal } }) } @@ -190,7 +190,8 @@ mod frontend { pub fn lt(&self, other: &FheUint8) -> FheBool { BoolEvaluator::with_local_mut(|e| { let key = RuntimeServerKey::global(); - arbitrary_bit_comparator(e, other.data(), self.data(), key) + let out = arbitrary_bit_comparator(e, other.data(), self.data(), key); + FheBool { data: out } }) } @@ -198,7 +199,8 @@ mod frontend { pub fn gt(&self, other: &FheUint8) -> FheBool { BoolEvaluator::with_local_mut(|e| { let key = RuntimeServerKey::global(); - arbitrary_bit_comparator(e, self.data(), other.data(), key) + let out = arbitrary_bit_comparator(e, self.data(), other.data(), key); + FheBool { data: out } }) } @@ -209,7 +211,7 @@ mod frontend { let mut a_greater_b = arbitrary_bit_comparator(e, self.data(), other.data(), key); e.not_inplace(&mut a_greater_b); - a_greater_b + FheBool { data: a_greater_b } }) } @@ -219,7 +221,7 @@ mod frontend { let key = RuntimeServerKey::global(); let mut a_less_b = arbitrary_bit_comparator(e, other.data(), self.data(), key); e.not_inplace(&mut a_less_b); - a_less_b + FheBool { data: a_less_b } }) } } diff --git a/src/shortint/types.rs b/src/shortint/types.rs deleted file mode 100644 index e69de29..0000000