diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index 815a11c..facee81 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -1580,7 +1580,7 @@ where } /// Returns 2(c0 - c1) + Q/4 - fn _subtract_double_and_shift_lwe_cts(&self, c0: &mut M::R, c1: &M::R) { + fn _subtract_double_lwe_cts(&self, c0: &mut M::R, c1: &M::R) { let modop = &self.pbs_info.rlwe_modop; // c0 - c1 modop.elwise_sub_mut(c0.as_mut(), c1.as_ref()); @@ -1688,7 +1688,7 @@ where c1: &M::R, server_key: &ServerKeyEvaluationDomain, ) { - self._subtract_double_and_shift_lwe_cts(c0, c1); + self._subtract_double_lwe_cts(c0, c1); // PBS pbs( @@ -1707,7 +1707,7 @@ where c1: &M::R, server_key: &ServerKeyEvaluationDomain, ) { - self._subtract_double_and_shift_lwe_cts(c0, c1); + self._subtract_double_lwe_cts(c0, c1); // PBS pbs( diff --git a/src/bool/mod.rs b/src/bool/mod.rs index 468272e..992bafe 100644 --- a/src/bool/mod.rs +++ b/src/bool/mod.rs @@ -1,2 +1,4 @@ pub(crate) mod evaluator; pub(crate) mod parameters; + +pub type FheBool = Vec; diff --git a/src/shortint/mod.rs b/src/shortint/mod.rs index ddd9a41..67f0e36 100644 --- a/src/shortint/mod.rs +++ b/src/shortint/mod.rs @@ -5,49 +5,12 @@ use crate::{ utils::{Global, WithLocal}, Decryptor, Encryptor, }; -use ops::{ - arbitrary_bit_adder, arbitrary_bit_division_for_quotient_and_rem, arbitrary_bit_subtractor, - eight_bit_mul, -}; mod ops; mod types; type FheUint8 = types::FheUint8>; -fn add_mut(a: &mut FheUint8, b: &FheUint8) { - BoolEvaluator::with_local_mut_mut(&mut |e| { - let key = ServerKeyEvaluationDomain::global(); - arbitrary_bit_adder(e, a.data_mut(), b.data(), false, key); - }); -} - -fn sub(a: &FheUint8, b: &FheUint8) -> FheUint8 { - BoolEvaluator::with_local_mut(|e| { - let key = ServerKeyEvaluationDomain::global(); - let (out, _, _) = arbitrary_bit_subtractor(e, a.data(), b.data(), key); - FheUint8 { data: out } - }) -} - -fn mul(a: &FheUint8, b: &FheUint8) -> FheUint8 { - BoolEvaluator::with_local_mut(|e| { - let key = ServerKeyEvaluationDomain::global(); - let out = eight_bit_mul(e, a.data(), b.data(), key); - FheUint8 { data: out } - }) -} - -fn div(a: &FheUint8, b: &FheUint8) -> (FheUint8, FheUint8) { - BoolEvaluator::with_local_mut(|e| { - let key = ServerKeyEvaluationDomain::global(); - let (quotient, remainder) = - arbitrary_bit_division_for_quotient_and_rem(e, a.data(), b.data(), key); - - (FheUint8 { data: quotient }, FheUint8 { data: remainder }) - }) -} - impl Encryptor for ClientKey { fn encrypt(&self, m: &u8) -> FheUint8 { let cts = (0..8) @@ -84,9 +47,11 @@ mod frontend { utils::{Global, WithLocal}, }; - use super::{add_mut, div, mul, FheUint8}; + use super::FheUint8; mod arithetic { + use crate::bool::{evaluator::BooleanGates, FheBool}; + use super::*; use std::ops::{Add, AddAssign, Div, Mul, Rem, Sub}; @@ -113,7 +78,7 @@ mod frontend { fn sub(self, rhs: &FheUint8) -> Self::Output { BoolEvaluator::with_local_mut(|e| { let key = ServerKeyEvaluationDomain::global(); - let (out, _, _) = arbitrary_bit_subtractor(e, self.data(), self.data(), key); + let (out, _, _) = arbitrary_bit_subtractor(e, self.data(), rhs.data(), key); FheUint8 { data: out } }) } @@ -133,6 +98,7 @@ mod frontend { impl Div<&FheUint8> for &FheUint8 { type Output = FheUint8; fn div(self, rhs: &FheUint8) -> Self::Output { + // TODO(Jay:) Figure out how to set zero error flag BoolEvaluator::with_local_mut(|e| { let key = ServerKeyEvaluationDomain::global(); let (quotient, _) = arbitrary_bit_division_for_quotient_and_rem( @@ -161,9 +127,120 @@ mod frontend { }) } } + + impl FheUint8 { + pub fn overflowing_add_assign(&mut self, rhs: &FheUint8) -> FheBool { + BoolEvaluator::with_local_mut_mut(&mut |e| { + let key = ServerKeyEvaluationDomain::global(); + let (overflow, _) = + arbitrary_bit_adder(e, self.data_mut(), rhs.data(), false, key); + overflow + }) + } + + pub fn overflowing_add(self, rhs: &FheUint8) -> (FheUint8, FheBool) { + BoolEvaluator::with_local_mut(|e| { + let mut lhs = self.clone(); + let key = ServerKeyEvaluationDomain::global(); + let (overflow, _) = + arbitrary_bit_adder(e, lhs.data_mut(), rhs.data(), false, key); + (lhs, overflow) + }) + } + + pub fn overflowing_sub(&self, rhs: &FheUint8) -> (FheUint8, FheBool) { + BoolEvaluator::with_local_mut(|e| { + let key = ServerKeyEvaluationDomain::global(); + let (out, mut overflow, _) = + arbitrary_bit_subtractor(e, self.data(), rhs.data(), key); + e.not_inplace(&mut overflow); + (FheUint8 { data: out }, overflow) + }) + } + + pub fn div_rem(&self, rhs: &FheUint8) -> (FheUint8, FheUint8) { + // TODO(Jay:) Figure out how to set zero error flag + BoolEvaluator::with_local_mut(|e| { + let key = ServerKeyEvaluationDomain::global(); + let (quotient, remainder) = arbitrary_bit_division_for_quotient_and_rem( + e, + self.data(), + rhs.data(), + key, + ); + (FheUint8 { data: quotient }, FheUint8 { data: remainder }) + }) + } + } } - mod booleans {} + mod booleans { + use crate::{ + bool::{evaluator::BooleanGates, FheBool}, + shortint::ops::{ + arbitrary_bit_comparator, arbitrary_bit_equality, arbitrary_signed_bit_comparator, + }, + }; + + use super::*; + + impl FheUint8 { + /// a == b + pub fn eq(&self, other: &FheUint8) -> FheBool { + BoolEvaluator::with_local_mut(|e| { + let key = ServerKeyEvaluationDomain::global(); + arbitrary_bit_equality(e, self.data(), other.data(), key) + }) + } + + /// a != b + pub fn neq(&self, other: &FheUint8) -> FheBool { + BoolEvaluator::with_local_mut(|e| { + let key = ServerKeyEvaluationDomain::global(); + let mut is_equal = arbitrary_bit_equality(e, self.data(), other.data(), key); + e.not_inplace(&mut is_equal); + is_equal + }) + } + + /// a < b + pub fn lt(&self, other: &FheUint8) -> FheBool { + BoolEvaluator::with_local_mut(|e| { + let key = ServerKeyEvaluationDomain::global(); + arbitrary_bit_comparator(e, other.data(), self.data(), key) + }) + } + + /// a > b + pub fn gt(&self, other: &FheUint8) -> FheBool { + BoolEvaluator::with_local_mut(|e| { + let key = ServerKeyEvaluationDomain::global(); + arbitrary_bit_comparator(e, self.data(), other.data(), key) + }) + } + + /// a <= b + pub fn le(&self, other: &FheUint8) -> FheBool { + BoolEvaluator::with_local_mut(|e| { + let key = ServerKeyEvaluationDomain::global(); + let mut a_greater_b = + arbitrary_bit_comparator(e, self.data(), other.data(), key); + e.not_inplace(&mut a_greater_b); + a_greater_b + }) + } + + /// a >= b + pub fn ge(&self, other: &FheUint8) -> FheBool { + BoolEvaluator::with_local_mut(|e| { + let key = ServerKeyEvaluationDomain::global(); + let mut a_less_b = arbitrary_bit_comparator(e, other.data(), self.data(), key); + e.not_inplace(&mut a_less_b); + a_less_b + }) + } + } + } } #[cfg(test)] @@ -175,19 +252,19 @@ mod tests { evaluator::{gen_keys, set_parameter_set, BoolEvaluator}, parameters::SP_BOOL_PARAMS, }, - shortint::{add_mut, div, mul, sub, types::FheUint8}, + shortint::types::FheUint8, Decryptor, Encryptor, }; #[test] - fn qwerty() { + fn all_uint8_apis() { set_parameter_set(&SP_BOOL_PARAMS); let (ck, sk) = gen_keys(); sk.set_server_key(); - for i in 1..=255 { - for j in 0..=255 { + for i in 144..=255 { + for j in 100..=255 { let m0 = i; let m1 = j; let c0 = ck.encrypt(&m0); @@ -196,66 +273,133 @@ mod tests { assert!(ck.decrypt(&c0) == m0); assert!(ck.decrypt(&c1) == m1); - // Add - // let mut c_m0_plus_m1 = FheUint8 { - // data: c0.data().to_vec(), - // }; - // add_mut(&mut c_m0_plus_m1, &c1); - // let m0_plus_m1 = ck.decrypt(&c_m0_plus_m1); - // assert_eq!( - // m0_plus_m1, - // m0.wrapping_add(m1), - // "Expected {} but got {m0_plus_m1} for {i}+{j}", - // m0.wrapping_add(m1) - // ); - - // Sub - // let c_sub = sub(&c0, &c1); - // let m0_sub_m1 = ck.decrypt(&c_sub); - // dbg!(m0, m1, m0_sub_m1); - // assert_eq!( - // m0_sub_m1, - // m0.wrapping_sub(m1), - // "Expected {} but got {m0_sub_m1} for {i}-{j}", - // m0.wrapping_sub(m1) - // ); - - // Mul - // let c_m0m1 = mul(&c0, &c1); - // let m0m1 = ck.decrypt(&c_m0m1); - // assert_eq!( - // m0m1, - // m0.wrapping_mul(m1), - // "Expected {} but got {m0m1} for {i}x{j}", - // m0.wrapping_mul(m1) - // ); - - // Div - // let (c_quotient, c_rem) = div(&c0, &c1); - // let m_quotient = ck.decrypt(&c_quotient); - // let m_remainder = ck.decrypt(&c_rem); - // if j != 0 { - // let (q, r) = i.div_rem_euclid(&j); - // assert_eq!( - // m_quotient, q, - // "Expected {} but got {m_quotient} for {i}/{j}", - // q - // ); - // assert_eq!( - // m_remainder, r, - // "Expected {} but got {m_quotient} for {i}%{j}", - // r - // ); - // } else { - // assert_eq!( - // m_quotient, 255, - // "Expected 255 but got {m_quotient}. Case div by zero" - // ); - // assert_eq!( - // m_remainder, i, - // "Expected {i} but got {m_quotient}. Case div by zero" - // ) - // } + // Arithmetic + { + { + // Add + let mut c_m0_plus_m1 = FheUint8 { + data: c0.data().to_vec(), + }; + c_m0_plus_m1 += &c1; + let m0_plus_m1 = ck.decrypt(&c_m0_plus_m1); + assert_eq!( + m0_plus_m1, + m0.wrapping_add(m1), + "Expected {} but got {m0_plus_m1} for {i}+{j}", + m0.wrapping_add(m1) + ); + } + { + // Sub + let c_sub = &c0 - &c1; + let m0_sub_m1 = ck.decrypt(&c_sub); + assert_eq!( + m0_sub_m1, + m0.wrapping_sub(m1), + "Expected {} but got {m0_sub_m1} for {i}-{j}", + m0.wrapping_sub(m1) + ); + } + + { + // Mul + let c_m0m1 = &c0 * &c1; + let m0m1 = ck.decrypt(&c_m0m1); + assert_eq!( + m0m1, + m0.wrapping_mul(m1), + "Expected {} but got {m0m1} for {i}x{j}", + m0.wrapping_mul(m1) + ); + } + + // Div & Rem + { + let (c_quotient, c_rem) = c0.div_rem(&c1); + let m_quotient = ck.decrypt(&c_quotient); + let m_remainder = ck.decrypt(&c_rem); + if j != 0 { + let (q, r) = i.div_rem_euclid(&j); + assert_eq!( + m_quotient, q, + "Expected {} but got {m_quotient} for {i}/{j}", + q + ); + assert_eq!( + m_remainder, r, + "Expected {} but got {m_quotient} for {i}%{j}", + r + ); + } else { + assert_eq!( + m_quotient, 255, + "Expected 255 but got {m_quotient}. Case div by zero" + ); + assert_eq!( + m_remainder, i, + "Expected {i} but got {m_quotient}. Case div by zero" + ) + } + } + } + + // Comparisons + { + { + let c_eq = c0.eq(&c1); + let is_eq = ck.decrypt(&c_eq); + assert_eq!( + is_eq, + i == j, + "Expected {} but got {is_eq} for {i}=={j}", + i == j + ); + } + + { + let c_gt = c0.gt(&c1); + let is_gt = ck.decrypt(&c_gt); + assert_eq!( + is_gt, + i > j, + "Expected {} but got {is_gt} for {i}>{j}", + i > j + ); + } + + { + let c_lt = c0.lt(&c1); + let is_lt = ck.decrypt(&c_lt); + assert_eq!( + is_lt, + i < j, + "Expected {} but got {is_lt} for {i}<{j}", + i < j + ); + } + + { + let c_ge = c0.ge(&c1); + let is_ge = ck.decrypt(&c_ge); + assert_eq!( + is_ge, + i >= j, + "Expected {} but got {is_ge} for {i}>={j}", + i >= j + ); + } + + { + let c_le = c0.le(&c1); + let is_le = ck.decrypt(&c_le); + assert_eq!( + is_le, + i <= j, + "Expected {} but got {is_le} for {i}<={j}", + i <= j + ); + } + } } } } diff --git a/src/shortint/ops.rs b/src/shortint/ops.rs index 31ca2af..73fb363 100644 --- a/src/shortint/ops.rs +++ b/src/shortint/ops.rs @@ -277,22 +277,22 @@ fn is_zero(evaluator: &mut E, a: &[E::Ciphertext], key: &E::Key return a.remove(0); } -fn arbitrary_bit_equality( +pub(super) fn arbitrary_bit_equality( evaluator: &mut E, a: &[E::Ciphertext], b: &[E::Ciphertext], key: &E::Key, ) -> E::Ciphertext { assert!(a.len() == b.len()); - let mut out = evaluator.and(&a[0], &b[0], key); + let mut out = evaluator.xnor(&a[0], &b[0], key); izip!(a.iter(), b.iter()).skip(1).for_each(|(abit, bbit)| { let e = evaluator.xnor(abit, bbit, key); - evaluator.and(&mut out, &e, key); + evaluator.and_inplace(&mut out, &e, key); }); return out; } -/// Comaprator handle computes comparator result 2ns MSB onwards. It is +/// Comparator handle computes comparator result 2ns MSB onwards. It is /// separated because comparator subroutine for signed and unsgind integers /// differs only for 1st MSB and is common second MSB onwards fn _comparator_handler_from_second_msb( @@ -307,27 +307,27 @@ fn _comparator_handler_from_second_msb( // handle MSB - 1 let mut tmp = evaluator.not(&b[n - 2]); - evaluator.and(&mut tmp, &a[n - 2], key); - evaluator.and(&mut tmp, &casc, key); - evaluator.or(&mut comp, &tmp, key); + evaluator.and_inplace(&mut tmp, &a[n - 2], key); + evaluator.and_inplace(&mut tmp, &casc, key); + evaluator.or_inplace(&mut comp, &tmp, key); for i in 2..n { // calculate cascading bit - let tmp_casc = evaluator.xnor(&a[n - 2 - i], &b[n - 2 - i], key); - evaluator.and(&mut casc, &tmp_casc, key); + let tmp_casc = evaluator.xnor(&a[n - i], &b[n - i], key); + evaluator.and_inplace(&mut casc, &tmp_casc, key); // calculate computate bit let mut tmp = evaluator.not(&b[n - 1 - i]); - evaluator.and(&mut tmp, &a[n - 1 - i], key); - evaluator.and(&mut tmp, &casc, key); - evaluator.or(&mut comp, &tmp, key); + evaluator.and_inplace(&mut tmp, &a[n - 1 - i], key); + evaluator.and_inplace(&mut tmp, &casc, key); + evaluator.or_inplace(&mut comp, &tmp, key); } return comp; } /// Signed integer comparison is same as unsigned integer with MSB flipped. -fn arbitrary_signed_bit_comparator( +pub(super) fn arbitrary_signed_bit_comparator( evaluator: &mut E, a: &[E::Ciphertext], b: &[E::Ciphertext], @@ -338,13 +338,13 @@ fn arbitrary_signed_bit_comparator( // handle MSB let mut comp = evaluator.not(&a[n - 1]); - evaluator.and(&mut comp, &b[n - 1], key); // comp + evaluator.and_inplace(&mut comp, &b[n - 1], key); // comp let casc = evaluator.xnor(&a[n - 1], &b[n - 1], key); // casc return _comparator_handler_from_second_msb(evaluator, a, b, comp, casc, key); } -fn arbitrary_bit_comparator( +pub(super) fn arbitrary_bit_comparator( evaluator: &mut E, a: &[E::Ciphertext], b: &[E::Ciphertext], @@ -355,7 +355,7 @@ fn arbitrary_bit_comparator( // handle MSB let mut comp = evaluator.not(&b[n - 1]); - evaluator.and(&mut comp, &a[n - 1], key); + evaluator.and_inplace(&mut comp, &a[n - 1], key); let casc = evaluator.xnor(&a[n - 1], &b[n - 1], key); return _comparator_handler_from_second_msb(evaluator, a, b, comp, casc, key);