From f2be81f7eb90a06c26be28f888e7ba4635800c76 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Mon, 1 Jul 2024 11:48:10 +0530 Subject: [PATCH] add div by zero --- src/bool/mp_api.rs | 18 +++++++++++++----- src/lib.rs | 2 +- src/shortint/mod.rs | 30 +++++++++++++++++++++++++----- src/shortint/ops.rs | 6 +++++- 4 files changed, 44 insertions(+), 12 deletions(-) diff --git a/src/bool/mp_api.rs b/src/bool/mp_api.rs index 19fb93c..e0529b3 100644 --- a/src/bool/mp_api.rs +++ b/src/bool/mp_api.rs @@ -508,13 +508,15 @@ mod tests { fn all_uint8_apis() { use num_traits::Euclid; + use crate::div_zero_error_flag; + set_single_party_parameter_sets(SP_TEST_BOOL_PARAMS); let (ck, sk) = gen_keys(); sk.set_server_key(); - for i in 144..=255 { - for j in 100..=255 { + for i in 0..=255 { + for j in 0..=255 { let m0 = i; let m1 = j; let c0 = ck.encrypt(&m0); @@ -574,7 +576,7 @@ mod tests { ); assert_eq!( m_remainder, r, - "Expected {} but got {m_quotient} for {i}%{j}", + "Expected {} but got {m_remainder} for {i}%{j}", r ); } else { @@ -584,8 +586,14 @@ mod tests { ); assert_eq!( m_remainder, i, - "Expected {i} but got {m_quotient}. Case div by zero" - ) + "Expected {i} but got {m_remainder}. Case div by zero" + ); + + let div_by_zero = ck.decrypt(&div_zero_error_flag().unwrap()); + assert_eq!( + div_by_zero, true, + "Expected true but got {div_by_zero}" + ); } } } diff --git a/src/lib.rs b/src/lib.rs index c62348c..0e91204 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,7 +22,7 @@ pub use backend::{ // ParameterSelector, }; pub use bool::*; pub use ntt::{Ntt, NttBackendU64, NttInit}; -pub use shortint::FheUint8; +pub use shortint::{div_zero_error_flag, FheUint8}; pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer}; diff --git a/src/shortint/mod.rs b/src/shortint/mod.rs index a073c13..61ce65a 100644 --- a/src/shortint/mod.rs +++ b/src/shortint/mod.rs @@ -5,8 +5,20 @@ mod types; pub type FheUint8 = enc_dec::FheUint8>; pub type FheBool = Vec; +use std::cell::RefCell; + use crate::bool::{evaluator::BooleanGates, BoolEvaluator, RuntimeServerKey}; +thread_local! { + static DIV_ZERO_ERROR: RefCell> = RefCell::new(None); +} + +/// Returns Boolean ciphertext indicating whether last division was attempeted +/// with decnomiantor set to 0. +pub fn div_zero_error_flag() -> Option> { + DIV_ZERO_ERROR.with_borrow(|c| c.clone()) +} + mod frontend { use super::ops::{ arbitrary_bit_adder, arbitrary_bit_division_for_quotient_and_rem, arbitrary_bit_subtractor, @@ -18,6 +30,8 @@ mod frontend { mod arithetic { + use ops::is_zero; + use super::*; use std::ops::{Add, AddAssign, Div, Mul, Rem, Sub}; @@ -64,9 +78,13 @@ mod frontend { impl Div<&FheUint8> for &FheUint8 { type Output = FheUint8; fn div(self, rhs: &FheUint8) -> Self::Output { - // TODO(Jay:) Figure out how to set zero error flag BoolEvaluator::with_local_mut(|e| { let key = RuntimeServerKey::global(); + + // set div by 0 error flag + let is_zero = is_zero(e, rhs.data(), key); + DIV_ZERO_ERROR.set(Some(is_zero)); + let (quotient, _) = arbitrary_bit_division_for_quotient_and_rem( e, self.data(), @@ -125,9 +143,13 @@ mod frontend { } pub fn div_rem(&self, rhs: &FheUint8) -> (FheUint8, FheUint8) { - // TODO(Jay:) Figure out how to set zero error flag BoolEvaluator::with_local_mut(|e| { let key = RuntimeServerKey::global(); + + // set div by 0 error flag + let is_zero = is_zero(e, rhs.data(), key); + DIV_ZERO_ERROR.set(Some(is_zero)); + let (quotient, remainder) = arbitrary_bit_division_for_quotient_and_rem( e, self.data(), @@ -141,9 +163,7 @@ mod frontend { } mod booleans { - use crate::shortint::ops::{ - arbitrary_bit_comparator, arbitrary_bit_equality, arbitrary_signed_bit_comparator, - }; + use crate::shortint::ops::{arbitrary_bit_comparator, arbitrary_bit_equality}; use super::*; diff --git a/src/shortint/ops.rs b/src/shortint/ops.rs index 9f472b7..a6f6833 100644 --- a/src/shortint/ops.rs +++ b/src/shortint/ops.rs @@ -256,7 +256,11 @@ where (quotient, remainder) } -fn is_zero(evaluator: &mut E, a: &[E::Ciphertext], key: &E::Key) -> E::Ciphertext { +pub(super) fn is_zero( + evaluator: &mut E, + a: &[E::Ciphertext], + key: &E::Key, +) -> E::Ciphertext { let mut a = a.iter().map(|v| evaluator.not(v)).collect_vec(); let (out, rest_a) = a.split_at_mut(1); rest_a.iter().for_each(|c| {