From a517dfb83d4c80a823f946c6a350c74307f4ced9 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Tue, 2 Jul 2024 11:53:10 +0530 Subject: [PATCH] add div_by_zero example and if_and_else example --- examples/div_by_zero.rs | 123 ++++++++++++++++++++++++++++++++++++++++ examples/if_and_else.rs | 107 ++++++++++++++++++++++++++++++++++ 2 files changed, 230 insertions(+) create mode 100644 examples/div_by_zero.rs create mode 100644 examples/if_and_else.rs diff --git a/examples/div_by_zero.rs b/examples/div_by_zero.rs new file mode 100644 index 0000000..2a525ce --- /dev/null +++ b/examples/div_by_zero.rs @@ -0,0 +1,123 @@ +use bin_rs::*; +use itertools::Itertools; +use rand::{thread_rng, Rng, RngCore}; + +fn main() { + set_parameter_set(ParameterSelector::NonInteractiveLTE2Party); + + // set application's common reference seed + let mut seed = [0u8; 32]; + thread_rng().fill_bytes(&mut seed); + set_common_reference_seed(seed); + + let no_of_parties = 2; + + // Generate client keys + let cks = (0..no_of_parties).map(|_| gen_client_key()).collect_vec(); + + // Generate server key shares + let server_key_shares = cks + .iter() + .enumerate() + .map(|(id, k)| gen_server_key_share(id, no_of_parties, k)) + .collect_vec(); + + // Aggregate server key shares and set the server key + let server_key = aggregate_server_key_shares(&server_key_shares); + server_key.set_server_key(); + + // -------- + + // We attempt to divide by 0 in encrypted domain and then check whether div by 0 + // error flag is set to True. + let numerator = thread_rng().gen::(); + let numerator_enc = cks[0] + .encrypt(vec![numerator].as_slice()) + .unseed::>>() + .key_switch(0) + .extract_at(0); + let zero_enc = cks[1] + .encrypt(vec![0].as_slice()) + .unseed::>>() + .key_switch(1) + .extract_at(0); + + let (quotient_enc, remainder_enc) = numerator_enc.div_rem(&zero_enc); + + // When attempting to divide by zero, for uint8, quotient is always 255 and + // remainder = numerator + let quotient = cks[0].aggregate_decryption_shares( + "ient_enc, + &cks.iter() + .map(|k| k.gen_decryption_share("ient_enc)) + .collect_vec(), + ); + let remainder = cks[0].aggregate_decryption_shares( + &remainder_enc, + &cks.iter() + .map(|k| k.gen_decryption_share(&remainder_enc)) + .collect_vec(), + ); + assert!(quotient == 255); + assert!(remainder == numerator); + + // Div by zero error flag must be True + let div_by_zero_enc = div_zero_error_flag().expect("We performed division. Flag must be set"); + let div_by_zero = cks[0].aggregate_decryption_shares( + &div_by_zero_enc, + &cks.iter() + .map(|k| k.gen_decryption_share(&div_by_zero_enc)) + .collect_vec(), + ); + assert!(div_by_zero == true); + + // ------- + + // div by zero error flag is thread local. If we were to run another circuit + // without stopping the thread (i.e. within the same program as previous + // one), we must reset errors flags set by previous circuit with + // `reset_error_flags()` to prevent error flags of previous circuit affecting + // the flags of the next circuit. + reset_error_flags(); + + // We divide again but with non-zero denominator this time and check that div + // by zero flag is set to False + let numerator = thread_rng().gen::(); + let denominator = thread_rng().gen::(); + let numerator_enc = cks[0] + .encrypt(vec![numerator].as_slice()) + .unseed::>>() + .key_switch(0) + .extract_at(0); + let denominator_enc = cks[1] + .encrypt(vec![denominator].as_slice()) + .unseed::>>() + .key_switch(1) + .extract_at(0); + + let (quotient_enc, remainder_enc) = numerator_enc.div_rem(&denominator_enc); + let quotient = cks[0].aggregate_decryption_shares( + "ient_enc, + &cks.iter() + .map(|k| k.gen_decryption_share("ient_enc)) + .collect_vec(), + ); + let remainder = cks[0].aggregate_decryption_shares( + &remainder_enc, + &cks.iter() + .map(|k| k.gen_decryption_share(&remainder_enc)) + .collect_vec(), + ); + assert!(quotient == numerator.div_euclid(denominator)); + assert!(remainder == numerator.rem_euclid(denominator)); + + // Div by zero error flag must be set to False + let div_by_zero_enc = div_zero_error_flag().expect("We performed division. Flag must be set"); + let div_by_zero = cks[0].aggregate_decryption_shares( + &div_by_zero_enc, + &cks.iter() + .map(|k| k.gen_decryption_share(&div_by_zero_enc)) + .collect_vec(), + ); + assert!(div_by_zero == false); +} diff --git a/examples/if_and_else.rs b/examples/if_and_else.rs new file mode 100644 index 0000000..dff8047 --- /dev/null +++ b/examples/if_and_else.rs @@ -0,0 +1,107 @@ +use bin_rs::*; +use itertools::Itertools; +use rand::{thread_rng, Rng, RngCore}; + +/// Code that runs if condition of conditional branch is `True` +fn circuit_branch_true(a: &FheUint8, b: &FheUint8) -> FheUint8 { + a + b +} + +/// Code that runs if condition of conditional branch is `False` +fn circuit_branch_false(a: &FheUint8, b: &FheUint8) -> FheUint8 { + a * b +} + +// Conditional branching (ie. If and else) are generally expensive in encrypted +// domain. The code must execute all the branches, and, as apparent, the +// runtime cost grows exponentially with no. of conditional branches. +// +// In general we recommend to write branchless code. In case the code cannot be +// modified to be branchless, the code must execute all branches and use a +// muxer to select correct output at the end. +// +// Below we showcase example of a single conditional branch in encrypted domain. +// The code executes both the branches (i.e. program runs both If and Else) and +// selects output of one of the branches with a mux. +fn main() { + set_parameter_set(ParameterSelector::NonInteractiveLTE2Party); + + // set application's common reference seed + let mut seed = [0u8; 32]; + thread_rng().fill_bytes(&mut seed); + set_common_reference_seed(seed); + + let no_of_parties = 2; + + // Generate client keys + let cks = (0..no_of_parties).map(|_| gen_client_key()).collect_vec(); + + // Generate server key shares + let server_key_shares = cks + .iter() + .enumerate() + .map(|(id, k)| gen_server_key_share(id, no_of_parties, k)) + .collect_vec(); + + // Aggregate server key shares and set the server key + let server_key = aggregate_server_key_shares(&server_key_shares); + server_key.set_server_key(); + + // ------- + + // User 0 encrypts their private input `v_a` and User 1 encrypts their + // private input `v_b`. We want to execute: + // + // if v_a < v_b: + // return v_a + v_b + // else: + // return v_a * v_b + // + // We define two functions + // (1) `circuit_branch_true`: which executes v_a + v_b in encrypted domain. + // (2) `circuit_branch_false`: which executes v_a * v_b in encrypted + // domain. + // + // The circuit runs both `circuit_branch_true` and `circuit_branch_false` and + // then selects the output of `circuit_branch_true` if `v_a < v_b == TRUE` + // otherwise selects the output of `circuit_branch_false` if `v_a < v_b == + // FALSE` using mux. + + // Clients private inputs + let v_a = thread_rng().gen::(); + let v_b = thread_rng().gen::(); + let v_a_enc = cks[0] + .encrypt(vec![v_a].as_slice()) + .unseed::>>() + .key_switch(0) + .extract_at(0); + let v_b_enc = cks[1] + .encrypt(vec![v_b].as_slice()) + .unseed::>>() + .key_switch(1) + .extract_at(0); + + // Run both branches + let out_true_enc = circuit_branch_true(&v_a_enc, &v_b_enc); + let out_false_enc = circuit_branch_false(&v_a_enc, &v_b_enc); + + // define condition select v_a < v_b + let selector_bit = v_a_enc.lt(&v_b_enc); + + // select output of `circuit_branch_true` if selector_bit == TRUE otherwise + // select output of `circuit_branch_false` + let out_enc = out_true_enc.mux(&out_false_enc, &selector_bit); + + let out = cks[0].aggregate_decryption_shares( + &out_enc, + &cks.iter() + .map(|k| k.gen_decryption_share(&out_enc)) + .collect_vec(), + ); + let want_out = if v_a < v_b { + v_a.wrapping_add(v_b) + } else { + v_a.wrapping_mul(v_b) + }; + assert_eq!(out, want_out); +}