| @ -0,0 +1,123 @@ | |||
| use bin_rs::*;
 | |||
| use itertools::Itertools;
 | |||
| use rand::{thread_rng, Rng, RngCore};
 | |||
| 
 | |||
| fn main() {
 | |||
|     set_parameter_set(ParameterSelector::NonInteractiveLTE2Party);
 | |||
| 
 | |||
|     // set application's common reference seed
 | |||
|     let mut seed = [0u8; 32];
 | |||
|     thread_rng().fill_bytes(&mut seed);
 | |||
|     set_common_reference_seed(seed);
 | |||
| 
 | |||
|     let no_of_parties = 2;
 | |||
| 
 | |||
|     // Generate client keys
 | |||
|     let cks = (0..no_of_parties).map(|_| gen_client_key()).collect_vec();
 | |||
| 
 | |||
|     // Generate server key shares
 | |||
|     let server_key_shares = cks
 | |||
|         .iter()
 | |||
|         .enumerate()
 | |||
|         .map(|(id, k)| gen_server_key_share(id, no_of_parties, k))
 | |||
|         .collect_vec();
 | |||
| 
 | |||
|     // Aggregate server key shares and set the server key
 | |||
|     let server_key = aggregate_server_key_shares(&server_key_shares);
 | |||
|     server_key.set_server_key();
 | |||
| 
 | |||
|     // --------
 | |||
| 
 | |||
|     // We attempt to divide by 0 in encrypted domain and then check whether div by 0
 | |||
|     // error flag is set to True.
 | |||
|     let numerator = thread_rng().gen::<u8>();
 | |||
|     let numerator_enc = cks[0]
 | |||
|         .encrypt(vec![numerator].as_slice())
 | |||
|         .unseed::<Vec<Vec<u64>>>()
 | |||
|         .key_switch(0)
 | |||
|         .extract_at(0);
 | |||
|     let zero_enc = cks[1]
 | |||
|         .encrypt(vec![0].as_slice())
 | |||
|         .unseed::<Vec<Vec<u64>>>()
 | |||
|         .key_switch(1)
 | |||
|         .extract_at(0);
 | |||
| 
 | |||
|     let (quotient_enc, remainder_enc) = numerator_enc.div_rem(&zero_enc);
 | |||
| 
 | |||
|     // When attempting to divide by zero, for uint8, quotient is always 255 and
 | |||
|     // remainder = numerator
 | |||
|     let quotient = cks[0].aggregate_decryption_shares(
 | |||
|         "ient_enc,
 | |||
|         &cks.iter()
 | |||
|             .map(|k| k.gen_decryption_share("ient_enc))
 | |||
|             .collect_vec(),
 | |||
|     );
 | |||
|     let remainder = cks[0].aggregate_decryption_shares(
 | |||
|         &remainder_enc,
 | |||
|         &cks.iter()
 | |||
|             .map(|k| k.gen_decryption_share(&remainder_enc))
 | |||
|             .collect_vec(),
 | |||
|     );
 | |||
|     assert!(quotient == 255);
 | |||
|     assert!(remainder == numerator);
 | |||
| 
 | |||
|     // Div by zero error flag must be True
 | |||
|     let div_by_zero_enc = div_zero_error_flag().expect("We performed division. Flag must be set");
 | |||
|     let div_by_zero = cks[0].aggregate_decryption_shares(
 | |||
|         &div_by_zero_enc,
 | |||
|         &cks.iter()
 | |||
|             .map(|k| k.gen_decryption_share(&div_by_zero_enc))
 | |||
|             .collect_vec(),
 | |||
|     );
 | |||
|     assert!(div_by_zero == true);
 | |||
| 
 | |||
|     // -------
 | |||
| 
 | |||
|     // div by zero error flag is thread local. If we were to run another circuit
 | |||
|     // without stopping the thread (i.e. within the same program as previous
 | |||
|     // one), we must reset errors flags set by previous circuit with
 | |||
|     // `reset_error_flags()` to prevent error flags of previous circuit affecting
 | |||
|     // the flags of the next circuit.
 | |||
|     reset_error_flags();
 | |||
| 
 | |||
|     // We divide again but with non-zero denominator this time and check that div
 | |||
|     // by zero flag is set to False
 | |||
|     let numerator = thread_rng().gen::<u8>();
 | |||
|     let denominator = thread_rng().gen::<u8>();
 | |||
|     let numerator_enc = cks[0]
 | |||
|         .encrypt(vec![numerator].as_slice())
 | |||
|         .unseed::<Vec<Vec<u64>>>()
 | |||
|         .key_switch(0)
 | |||
|         .extract_at(0);
 | |||
|     let denominator_enc = cks[1]
 | |||
|         .encrypt(vec![denominator].as_slice())
 | |||
|         .unseed::<Vec<Vec<u64>>>()
 | |||
|         .key_switch(1)
 | |||
|         .extract_at(0);
 | |||
| 
 | |||
|     let (quotient_enc, remainder_enc) = numerator_enc.div_rem(&denominator_enc);
 | |||
|     let quotient = cks[0].aggregate_decryption_shares(
 | |||
|         "ient_enc,
 | |||
|         &cks.iter()
 | |||
|             .map(|k| k.gen_decryption_share("ient_enc))
 | |||
|             .collect_vec(),
 | |||
|     );
 | |||
|     let remainder = cks[0].aggregate_decryption_shares(
 | |||
|         &remainder_enc,
 | |||
|         &cks.iter()
 | |||
|             .map(|k| k.gen_decryption_share(&remainder_enc))
 | |||
|             .collect_vec(),
 | |||
|     );
 | |||
|     assert!(quotient == numerator.div_euclid(denominator));
 | |||
|     assert!(remainder == numerator.rem_euclid(denominator));
 | |||
| 
 | |||
|     // Div by zero error flag must be set to False
 | |||
|     let div_by_zero_enc = div_zero_error_flag().expect("We performed division. Flag must be set");
 | |||
|     let div_by_zero = cks[0].aggregate_decryption_shares(
 | |||
|         &div_by_zero_enc,
 | |||
|         &cks.iter()
 | |||
|             .map(|k| k.gen_decryption_share(&div_by_zero_enc))
 | |||
|             .collect_vec(),
 | |||
|     );
 | |||
|     assert!(div_by_zero == false);
 | |||
| }
 | |||
| @ -0,0 +1,107 @@ | |||
| use bin_rs::*;
 | |||
| use itertools::Itertools;
 | |||
| use rand::{thread_rng, Rng, RngCore};
 | |||
| 
 | |||
| /// Code that runs if condition of conditional branch is `True`
 | |||
| fn circuit_branch_true(a: &FheUint8, b: &FheUint8) -> FheUint8 {
 | |||
|     a + b
 | |||
| }
 | |||
| 
 | |||
| /// Code that runs if condition of conditional branch is `False`
 | |||
| fn circuit_branch_false(a: &FheUint8, b: &FheUint8) -> FheUint8 {
 | |||
|     a * b
 | |||
| }
 | |||
| 
 | |||
| // Conditional branching (ie. If and else) are generally expensive in encrypted
 | |||
| // domain. The code must execute all the branches, and, as apparent, the
 | |||
| // runtime cost grows exponentially with no. of conditional branches.
 | |||
| //
 | |||
| // In general we recommend to write branchless code. In case the code cannot be
 | |||
| // modified to be branchless, the code must execute all branches and use a
 | |||
| // muxer to select correct output at the end.
 | |||
| //
 | |||
| // Below we showcase example of a single conditional branch in encrypted domain.
 | |||
| // The code executes both the branches (i.e. program runs both If and Else) and
 | |||
| // selects output of one of the branches with a mux.
 | |||
| fn main() {
 | |||
|     set_parameter_set(ParameterSelector::NonInteractiveLTE2Party);
 | |||
| 
 | |||
|     // set application's common reference seed
 | |||
|     let mut seed = [0u8; 32];
 | |||
|     thread_rng().fill_bytes(&mut seed);
 | |||
|     set_common_reference_seed(seed);
 | |||
| 
 | |||
|     let no_of_parties = 2;
 | |||
| 
 | |||
|     // Generate client keys
 | |||
|     let cks = (0..no_of_parties).map(|_| gen_client_key()).collect_vec();
 | |||
| 
 | |||
|     // Generate server key shares
 | |||
|     let server_key_shares = cks
 | |||
|         .iter()
 | |||
|         .enumerate()
 | |||
|         .map(|(id, k)| gen_server_key_share(id, no_of_parties, k))
 | |||
|         .collect_vec();
 | |||
| 
 | |||
|     // Aggregate server key shares and set the server key
 | |||
|     let server_key = aggregate_server_key_shares(&server_key_shares);
 | |||
|     server_key.set_server_key();
 | |||
| 
 | |||
|     // -------
 | |||
| 
 | |||
|     // User 0 encrypts their private input `v_a` and User 1 encrypts their
 | |||
|     // private input `v_b`. We want to execute:
 | |||
|     //
 | |||
|     // if v_a < v_b:
 | |||
|     //      return v_a + v_b
 | |||
|     // else:
 | |||
|     //      return v_a * v_b
 | |||
|     //
 | |||
|     // We define two functions
 | |||
|     //      (1) `circuit_branch_true`: which executes v_a + v_b in encrypted domain.
 | |||
|     //      (2) `circuit_branch_false`: which executes v_a * v_b in encrypted
 | |||
|     //                                  domain.
 | |||
|     //
 | |||
|     // The circuit runs both `circuit_branch_true` and `circuit_branch_false` and
 | |||
|     // then selects the output of `circuit_branch_true` if `v_a < v_b == TRUE`
 | |||
|     // otherwise selects the output of `circuit_branch_false` if `v_a < v_b ==
 | |||
|     // FALSE` using mux.
 | |||
| 
 | |||
|     // Clients private inputs
 | |||
|     let v_a = thread_rng().gen::<u8>();
 | |||
|     let v_b = thread_rng().gen::<u8>();
 | |||
|     let v_a_enc = cks[0]
 | |||
|         .encrypt(vec![v_a].as_slice())
 | |||
|         .unseed::<Vec<Vec<u64>>>()
 | |||
|         .key_switch(0)
 | |||
|         .extract_at(0);
 | |||
|     let v_b_enc = cks[1]
 | |||
|         .encrypt(vec![v_b].as_slice())
 | |||
|         .unseed::<Vec<Vec<u64>>>()
 | |||
|         .key_switch(1)
 | |||
|         .extract_at(0);
 | |||
| 
 | |||
|     // Run both branches
 | |||
|     let out_true_enc = circuit_branch_true(&v_a_enc, &v_b_enc);
 | |||
|     let out_false_enc = circuit_branch_false(&v_a_enc, &v_b_enc);
 | |||
| 
 | |||
|     // define condition select v_a < v_b
 | |||
|     let selector_bit = v_a_enc.lt(&v_b_enc);
 | |||
| 
 | |||
|     // select output of `circuit_branch_true` if selector_bit == TRUE otherwise
 | |||
|     // select output of `circuit_branch_false`
 | |||
|     let out_enc = out_true_enc.mux(&out_false_enc, &selector_bit);
 | |||
| 
 | |||
|     let out = cks[0].aggregate_decryption_shares(
 | |||
|         &out_enc,
 | |||
|         &cks.iter()
 | |||
|             .map(|k| k.gen_decryption_share(&out_enc))
 | |||
|             .collect_vec(),
 | |||
|     );
 | |||
|     let want_out = if v_a < v_b {
 | |||
|         v_a.wrapping_add(v_b)
 | |||
|     } else {
 | |||
|         v_a.wrapping_mul(v_b)
 | |||
|     };
 | |||
|     assert_eq!(out, want_out);
 | |||
| }
 | |||