Browse Source

uint8 frontend works

par-agg-key-shares
Janmajaya Mall 10 months ago
parent
commit
3dc00766aa
4 changed files with 269 additions and 123 deletions
  1. +3
    -3
      src/bool/evaluator.rs
  2. +2
    -0
      src/bool/mod.rs
  3. +248
    -104
      src/shortint/mod.rs
  4. +16
    -16
      src/shortint/ops.rs

+ 3
- 3
src/bool/evaluator.rs

@ -1580,7 +1580,7 @@ where
} }
/// Returns 2(c0 - c1) + Q/4 /// Returns 2(c0 - c1) + Q/4
fn _subtract_double_and_shift_lwe_cts(&self, c0: &mut M::R, c1: &M::R) {
fn _subtract_double_lwe_cts(&self, c0: &mut M::R, c1: &M::R) {
let modop = &self.pbs_info.rlwe_modop; let modop = &self.pbs_info.rlwe_modop;
// c0 - c1 // c0 - c1
modop.elwise_sub_mut(c0.as_mut(), c1.as_ref()); modop.elwise_sub_mut(c0.as_mut(), c1.as_ref());
@ -1688,7 +1688,7 @@ where
c1: &M::R, c1: &M::R,
server_key: &ServerKeyEvaluationDomain<M, DefaultSecureRng, NttOp>, server_key: &ServerKeyEvaluationDomain<M, DefaultSecureRng, NttOp>,
) { ) {
self._subtract_double_and_shift_lwe_cts(c0, c1);
self._subtract_double_lwe_cts(c0, c1);
// PBS // PBS
pbs( pbs(
@ -1707,7 +1707,7 @@ where
c1: &M::R, c1: &M::R,
server_key: &ServerKeyEvaluationDomain<M, DefaultSecureRng, NttOp>, server_key: &ServerKeyEvaluationDomain<M, DefaultSecureRng, NttOp>,
) { ) {
self._subtract_double_and_shift_lwe_cts(c0, c1);
self._subtract_double_lwe_cts(c0, c1);
// PBS // PBS
pbs( pbs(

+ 2
- 0
src/bool/mod.rs

@ -1,2 +1,4 @@
pub(crate) mod evaluator; pub(crate) mod evaluator;
pub(crate) mod parameters; pub(crate) mod parameters;
pub type FheBool = Vec<u64>;

+ 248
- 104
src/shortint/mod.rs

@ -5,49 +5,12 @@ use crate::{
utils::{Global, WithLocal}, utils::{Global, WithLocal},
Decryptor, Encryptor, Decryptor, Encryptor,
}; };
use ops::{
arbitrary_bit_adder, arbitrary_bit_division_for_quotient_and_rem, arbitrary_bit_subtractor,
eight_bit_mul,
};
mod ops; mod ops;
mod types; mod types;
type FheUint8 = types::FheUint8<Vec<u64>>; type FheUint8 = types::FheUint8<Vec<u64>>;
fn add_mut(a: &mut FheUint8, b: &FheUint8) {
BoolEvaluator::with_local_mut_mut(&mut |e| {
let key = ServerKeyEvaluationDomain::global();
arbitrary_bit_adder(e, a.data_mut(), b.data(), false, key);
});
}
fn sub(a: &FheUint8, b: &FheUint8) -> FheUint8 {
BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
let (out, _, _) = arbitrary_bit_subtractor(e, a.data(), b.data(), key);
FheUint8 { data: out }
})
}
fn mul(a: &FheUint8, b: &FheUint8) -> FheUint8 {
BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
let out = eight_bit_mul(e, a.data(), b.data(), key);
FheUint8 { data: out }
})
}
fn div(a: &FheUint8, b: &FheUint8) -> (FheUint8, FheUint8) {
BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
let (quotient, remainder) =
arbitrary_bit_division_for_quotient_and_rem(e, a.data(), b.data(), key);
(FheUint8 { data: quotient }, FheUint8 { data: remainder })
})
}
impl Encryptor<u8, FheUint8> for ClientKey { impl Encryptor<u8, FheUint8> for ClientKey {
fn encrypt(&self, m: &u8) -> FheUint8 { fn encrypt(&self, m: &u8) -> FheUint8 {
let cts = (0..8) let cts = (0..8)
@ -84,9 +47,11 @@ mod frontend {
utils::{Global, WithLocal}, utils::{Global, WithLocal},
}; };
use super::{add_mut, div, mul, FheUint8};
use super::FheUint8;
mod arithetic { mod arithetic {
use crate::bool::{evaluator::BooleanGates, FheBool};
use super::*; use super::*;
use std::ops::{Add, AddAssign, Div, Mul, Rem, Sub}; use std::ops::{Add, AddAssign, Div, Mul, Rem, Sub};
@ -113,7 +78,7 @@ mod frontend {
fn sub(self, rhs: &FheUint8) -> Self::Output { fn sub(self, rhs: &FheUint8) -> Self::Output {
BoolEvaluator::with_local_mut(|e| { BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global(); let key = ServerKeyEvaluationDomain::global();
let (out, _, _) = arbitrary_bit_subtractor(e, self.data(), self.data(), key);
let (out, _, _) = arbitrary_bit_subtractor(e, self.data(), rhs.data(), key);
FheUint8 { data: out } FheUint8 { data: out }
}) })
} }
@ -133,6 +98,7 @@ mod frontend {
impl Div<&FheUint8> for &FheUint8 { impl Div<&FheUint8> for &FheUint8 {
type Output = FheUint8; type Output = FheUint8;
fn div(self, rhs: &FheUint8) -> Self::Output { fn div(self, rhs: &FheUint8) -> Self::Output {
// TODO(Jay:) Figure out how to set zero error flag
BoolEvaluator::with_local_mut(|e| { BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global(); let key = ServerKeyEvaluationDomain::global();
let (quotient, _) = arbitrary_bit_division_for_quotient_and_rem( let (quotient, _) = arbitrary_bit_division_for_quotient_and_rem(
@ -161,9 +127,120 @@ mod frontend {
}) })
} }
} }
impl FheUint8 {
pub fn overflowing_add_assign(&mut self, rhs: &FheUint8) -> FheBool {
BoolEvaluator::with_local_mut_mut(&mut |e| {
let key = ServerKeyEvaluationDomain::global();
let (overflow, _) =
arbitrary_bit_adder(e, self.data_mut(), rhs.data(), false, key);
overflow
})
}
pub fn overflowing_add(self, rhs: &FheUint8) -> (FheUint8, FheBool) {
BoolEvaluator::with_local_mut(|e| {
let mut lhs = self.clone();
let key = ServerKeyEvaluationDomain::global();
let (overflow, _) =
arbitrary_bit_adder(e, lhs.data_mut(), rhs.data(), false, key);
(lhs, overflow)
})
}
pub fn overflowing_sub(&self, rhs: &FheUint8) -> (FheUint8, FheBool) {
BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
let (out, mut overflow, _) =
arbitrary_bit_subtractor(e, self.data(), rhs.data(), key);
e.not_inplace(&mut overflow);
(FheUint8 { data: out }, overflow)
})
}
pub fn div_rem(&self, rhs: &FheUint8) -> (FheUint8, FheUint8) {
// TODO(Jay:) Figure out how to set zero error flag
BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
let (quotient, remainder) = arbitrary_bit_division_for_quotient_and_rem(
e,
self.data(),
rhs.data(),
key,
);
(FheUint8 { data: quotient }, FheUint8 { data: remainder })
})
}
}
} }
mod booleans {}
mod booleans {
use crate::{
bool::{evaluator::BooleanGates, FheBool},
shortint::ops::{
arbitrary_bit_comparator, arbitrary_bit_equality, arbitrary_signed_bit_comparator,
},
};
use super::*;
impl FheUint8 {
/// a == b
pub fn eq(&self, other: &FheUint8) -> FheBool {
BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
arbitrary_bit_equality(e, self.data(), other.data(), key)
})
}
/// a != b
pub fn neq(&self, other: &FheUint8) -> FheBool {
BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
let mut is_equal = arbitrary_bit_equality(e, self.data(), other.data(), key);
e.not_inplace(&mut is_equal);
is_equal
})
}
/// a < b
pub fn lt(&self, other: &FheUint8) -> FheBool {
BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
arbitrary_bit_comparator(e, other.data(), self.data(), key)
})
}
/// a > b
pub fn gt(&self, other: &FheUint8) -> FheBool {
BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
arbitrary_bit_comparator(e, self.data(), other.data(), key)
})
}
/// a <= b
pub fn le(&self, other: &FheUint8) -> FheBool {
BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
let mut a_greater_b =
arbitrary_bit_comparator(e, self.data(), other.data(), key);
e.not_inplace(&mut a_greater_b);
a_greater_b
})
}
/// a >= b
pub fn ge(&self, other: &FheUint8) -> FheBool {
BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
let mut a_less_b = arbitrary_bit_comparator(e, other.data(), self.data(), key);
e.not_inplace(&mut a_less_b);
a_less_b
})
}
}
}
} }
#[cfg(test)] #[cfg(test)]
@ -175,19 +252,19 @@ mod tests {
evaluator::{gen_keys, set_parameter_set, BoolEvaluator}, evaluator::{gen_keys, set_parameter_set, BoolEvaluator},
parameters::SP_BOOL_PARAMS, parameters::SP_BOOL_PARAMS,
}, },
shortint::{add_mut, div, mul, sub, types::FheUint8},
shortint::types::FheUint8,
Decryptor, Encryptor, Decryptor, Encryptor,
}; };
#[test] #[test]
fn qwerty() {
fn all_uint8_apis() {
set_parameter_set(&SP_BOOL_PARAMS); set_parameter_set(&SP_BOOL_PARAMS);
let (ck, sk) = gen_keys(); let (ck, sk) = gen_keys();
sk.set_server_key(); sk.set_server_key();
for i in 1..=255 {
for j in 0..=255 {
for i in 144..=255 {
for j in 100..=255 {
let m0 = i; let m0 = i;
let m1 = j; let m1 = j;
let c0 = ck.encrypt(&m0); let c0 = ck.encrypt(&m0);
@ -196,66 +273,133 @@ mod tests {
assert!(ck.decrypt(&c0) == m0); assert!(ck.decrypt(&c0) == m0);
assert!(ck.decrypt(&c1) == m1); assert!(ck.decrypt(&c1) == m1);
// Add
// let mut c_m0_plus_m1 = FheUint8 {
// data: c0.data().to_vec(),
// };
// add_mut(&mut c_m0_plus_m1, &c1);
// let m0_plus_m1 = ck.decrypt(&c_m0_plus_m1);
// assert_eq!(
// m0_plus_m1,
// m0.wrapping_add(m1),
// "Expected {} but got {m0_plus_m1} for {i}+{j}",
// m0.wrapping_add(m1)
// );
// Sub
// let c_sub = sub(&c0, &c1);
// let m0_sub_m1 = ck.decrypt(&c_sub);
// dbg!(m0, m1, m0_sub_m1);
// assert_eq!(
// m0_sub_m1,
// m0.wrapping_sub(m1),
// "Expected {} but got {m0_sub_m1} for {i}-{j}",
// m0.wrapping_sub(m1)
// );
// Mul
// let c_m0m1 = mul(&c0, &c1);
// let m0m1 = ck.decrypt(&c_m0m1);
// assert_eq!(
// m0m1,
// m0.wrapping_mul(m1),
// "Expected {} but got {m0m1} for {i}x{j}",
// m0.wrapping_mul(m1)
// );
// Div
// let (c_quotient, c_rem) = div(&c0, &c1);
// let m_quotient = ck.decrypt(&c_quotient);
// let m_remainder = ck.decrypt(&c_rem);
// if j != 0 {
// let (q, r) = i.div_rem_euclid(&j);
// assert_eq!(
// m_quotient, q,
// "Expected {} but got {m_quotient} for {i}/{j}",
// q
// );
// assert_eq!(
// m_remainder, r,
// "Expected {} but got {m_quotient} for {i}%{j}",
// r
// );
// } else {
// assert_eq!(
// m_quotient, 255,
// "Expected 255 but got {m_quotient}. Case div by zero"
// );
// assert_eq!(
// m_remainder, i,
// "Expected {i} but got {m_quotient}. Case div by zero"
// )
// }
// Arithmetic
{
{
// Add
let mut c_m0_plus_m1 = FheUint8 {
data: c0.data().to_vec(),
};
c_m0_plus_m1 += &c1;
let m0_plus_m1 = ck.decrypt(&c_m0_plus_m1);
assert_eq!(
m0_plus_m1,
m0.wrapping_add(m1),
"Expected {} but got {m0_plus_m1} for {i}+{j}",
m0.wrapping_add(m1)
);
}
{
// Sub
let c_sub = &c0 - &c1;
let m0_sub_m1 = ck.decrypt(&c_sub);
assert_eq!(
m0_sub_m1,
m0.wrapping_sub(m1),
"Expected {} but got {m0_sub_m1} for {i}-{j}",
m0.wrapping_sub(m1)
);
}
{
// Mul
let c_m0m1 = &c0 * &c1;
let m0m1 = ck.decrypt(&c_m0m1);
assert_eq!(
m0m1,
m0.wrapping_mul(m1),
"Expected {} but got {m0m1} for {i}x{j}",
m0.wrapping_mul(m1)
);
}
// Div & Rem
{
let (c_quotient, c_rem) = c0.div_rem(&c1);
let m_quotient = ck.decrypt(&c_quotient);
let m_remainder = ck.decrypt(&c_rem);
if j != 0 {
let (q, r) = i.div_rem_euclid(&j);
assert_eq!(
m_quotient, q,
"Expected {} but got {m_quotient} for {i}/{j}",
q
);
assert_eq!(
m_remainder, r,
"Expected {} but got {m_quotient} for {i}%{j}",
r
);
} else {
assert_eq!(
m_quotient, 255,
"Expected 255 but got {m_quotient}. Case div by zero"
);
assert_eq!(
m_remainder, i,
"Expected {i} but got {m_quotient}. Case div by zero"
)
}
}
}
// Comparisons
{
{
let c_eq = c0.eq(&c1);
let is_eq = ck.decrypt(&c_eq);
assert_eq!(
is_eq,
i == j,
"Expected {} but got {is_eq} for {i}=={j}",
i == j
);
}
{
let c_gt = c0.gt(&c1);
let is_gt = ck.decrypt(&c_gt);
assert_eq!(
is_gt,
i > j,
"Expected {} but got {is_gt} for {i}>{j}",
i > j
);
}
{
let c_lt = c0.lt(&c1);
let is_lt = ck.decrypt(&c_lt);
assert_eq!(
is_lt,
i < j,
"Expected {} but got {is_lt} for {i}<{j}",
i < j
);
}
{
let c_ge = c0.ge(&c1);
let is_ge = ck.decrypt(&c_ge);
assert_eq!(
is_ge,
i >= j,
"Expected {} but got {is_ge} for {i}>={j}",
i >= j
);
}
{
let c_le = c0.le(&c1);
let is_le = ck.decrypt(&c_le);
assert_eq!(
is_le,
i <= j,
"Expected {} but got {is_le} for {i}<={j}",
i <= j
);
}
}
} }
} }
} }

+ 16
- 16
src/shortint/ops.rs

@ -277,22 +277,22 @@ fn is_zero(evaluator: &mut E, a: &[E::Ciphertext], key: &E::Key
return a.remove(0); return a.remove(0);
} }
fn arbitrary_bit_equality<E: BooleanGates>(
pub(super) fn arbitrary_bit_equality<E: BooleanGates>(
evaluator: &mut E, evaluator: &mut E,
a: &[E::Ciphertext], a: &[E::Ciphertext],
b: &[E::Ciphertext], b: &[E::Ciphertext],
key: &E::Key, key: &E::Key,
) -> E::Ciphertext { ) -> E::Ciphertext {
assert!(a.len() == b.len()); assert!(a.len() == b.len());
let mut out = evaluator.and(&a[0], &b[0], key);
let mut out = evaluator.xnor(&a[0], &b[0], key);
izip!(a.iter(), b.iter()).skip(1).for_each(|(abit, bbit)| { izip!(a.iter(), b.iter()).skip(1).for_each(|(abit, bbit)| {
let e = evaluator.xnor(abit, bbit, key); let e = evaluator.xnor(abit, bbit, key);
evaluator.and(&mut out, &e, key);
evaluator.and_inplace(&mut out, &e, key);
}); });
return out; return out;
} }
/// Comaprator handle computes comparator result 2ns MSB onwards. It is
/// Comparator handle computes comparator result 2ns MSB onwards. It is
/// separated because comparator subroutine for signed and unsgind integers /// separated because comparator subroutine for signed and unsgind integers
/// differs only for 1st MSB and is common second MSB onwards /// differs only for 1st MSB and is common second MSB onwards
fn _comparator_handler_from_second_msb<E: BooleanGates>( fn _comparator_handler_from_second_msb<E: BooleanGates>(
@ -307,27 +307,27 @@ fn _comparator_handler_from_second_msb(
// handle MSB - 1 // handle MSB - 1
let mut tmp = evaluator.not(&b[n - 2]); let mut tmp = evaluator.not(&b[n - 2]);
evaluator.and(&mut tmp, &a[n - 2], key);
evaluator.and(&mut tmp, &casc, key);
evaluator.or(&mut comp, &tmp, key);
evaluator.and_inplace(&mut tmp, &a[n - 2], key);
evaluator.and_inplace(&mut tmp, &casc, key);
evaluator.or_inplace(&mut comp, &tmp, key);
for i in 2..n { for i in 2..n {
// calculate cascading bit // calculate cascading bit
let tmp_casc = evaluator.xnor(&a[n - 2 - i], &b[n - 2 - i], key);
evaluator.and(&mut casc, &tmp_casc, key);
let tmp_casc = evaluator.xnor(&a[n - i], &b[n - i], key);
evaluator.and_inplace(&mut casc, &tmp_casc, key);
// calculate computate bit // calculate computate bit
let mut tmp = evaluator.not(&b[n - 1 - i]); let mut tmp = evaluator.not(&b[n - 1 - i]);
evaluator.and(&mut tmp, &a[n - 1 - i], key);
evaluator.and(&mut tmp, &casc, key);
evaluator.or(&mut comp, &tmp, key);
evaluator.and_inplace(&mut tmp, &a[n - 1 - i], key);
evaluator.and_inplace(&mut tmp, &casc, key);
evaluator.or_inplace(&mut comp, &tmp, key);
} }
return comp; return comp;
} }
/// Signed integer comparison is same as unsigned integer with MSB flipped. /// Signed integer comparison is same as unsigned integer with MSB flipped.
fn arbitrary_signed_bit_comparator<E: BooleanGates>(
pub(super) fn arbitrary_signed_bit_comparator<E: BooleanGates>(
evaluator: &mut E, evaluator: &mut E,
a: &[E::Ciphertext], a: &[E::Ciphertext],
b: &[E::Ciphertext], b: &[E::Ciphertext],
@ -338,13 +338,13 @@ fn arbitrary_signed_bit_comparator(
// handle MSB // handle MSB
let mut comp = evaluator.not(&a[n - 1]); let mut comp = evaluator.not(&a[n - 1]);
evaluator.and(&mut comp, &b[n - 1], key); // comp
evaluator.and_inplace(&mut comp, &b[n - 1], key); // comp
let casc = evaluator.xnor(&a[n - 1], &b[n - 1], key); // casc let casc = evaluator.xnor(&a[n - 1], &b[n - 1], key); // casc
return _comparator_handler_from_second_msb(evaluator, a, b, comp, casc, key); return _comparator_handler_from_second_msb(evaluator, a, b, comp, casc, key);
} }
fn arbitrary_bit_comparator<E: BooleanGates>(
pub(super) fn arbitrary_bit_comparator<E: BooleanGates>(
evaluator: &mut E, evaluator: &mut E,
a: &[E::Ciphertext], a: &[E::Ciphertext],
b: &[E::Ciphertext], b: &[E::Ciphertext],
@ -355,7 +355,7 @@ fn arbitrary_bit_comparator(
// handle MSB // handle MSB
let mut comp = evaluator.not(&b[n - 1]); let mut comp = evaluator.not(&b[n - 1]);
evaluator.and(&mut comp, &a[n - 1], key);
evaluator.and_inplace(&mut comp, &a[n - 1], key);
let casc = evaluator.xnor(&a[n - 1], &b[n - 1], key); let casc = evaluator.xnor(&a[n - 1], &b[n - 1], key);
return _comparator_handler_from_second_msb(evaluator, a, b, comp, casc, key); return _comparator_handler_from_second_msb(evaluator, a, b, comp, casc, key);

Loading…
Cancel
Save