implement min, max, mux

This commit is contained in:
Janmajaya Mall
2024-07-02 10:30:11 +05:30
parent d74c96d08a
commit d8d5e40f00
9 changed files with 136 additions and 94 deletions

View File

@@ -517,7 +517,7 @@ mod tests {
let ct = seeded_ct.unseed::<Vec<Vec<u64>>>();
let m_back = (0..batch_size)
.map(|i| ck.decrypt(&ct.extract(i)))
.map(|i| ck.decrypt(&ct.extract_at(i)))
.collect_vec();
assert_eq!(m, m_back);
@@ -528,7 +528,7 @@ mod tests {
fn all_uint8_apis() {
use num_traits::Euclid;
use crate::div_zero_error_flag;
use crate::{div_zero_error_flag, FheBool};
set_single_party_parameter_sets(SP_TEST_BOOL_PARAMS);
@@ -624,7 +624,7 @@ mod tests {
}
}
// Comparisons
// // Comparisons
{
{
let c_eq = c0.eq(&c1);
@@ -681,6 +681,15 @@ mod tests {
);
}
}
// mux
{
let selector = thread_rng().gen_bool(0.5);
let selector_enc: FheBool = ck.encrypt(&selector);
let mux_out = ck.decrypt(&c0.mux(&c1, &selector_enc));
let want_mux_out = if selector { m0 } else { m1 };
assert_eq!(mux_out, want_mux_out);
}
}
}
}

View File

@@ -16,13 +16,10 @@ mod utils;
pub use backend::{
ArithmeticLazyOps, ArithmeticOps, ModInit, ModularOpsU64, ShoupMatrixFMA, VectorOps,
};
// pub use bool::{
// aggregate_public_key_shares, aggregate_server_key_shares, gen_client_key,
// gen_mp_keys_phase1, gen_mp_keys_phase2, set_mp_seed, set_parameter_set,
// ParameterSelector, };
pub use bool::*;
pub use ntt::{Ntt, NttBackendU64, NttInit};
pub use shortint::{div_zero_error_flag, FheUint8};
pub use shortint::{div_zero_error_flag, reset_error_flags, FheUint8};
pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer};
@@ -96,11 +93,6 @@ pub trait RowEntity: Row {
fn zeros(col: usize) -> Self;
}
trait Secret {
type Element;
fn values(&self) -> &[Self::Element];
}
impl<T> Matrix for Vec<Vec<T>> {
type MatElement = T;
type R = Vec<T>;

View File

@@ -173,7 +173,7 @@ mod tests {
decomposer::DefaultDecomposer,
random::{DefaultSecureRng, NewWithSeed},
utils::{fill_random_ternary_secret_with_hamming_weight, WithLocal},
MatrixEntity, MatrixMut, Secret,
MatrixEntity, MatrixMut,
};
use super::*;
@@ -185,13 +185,6 @@ mod tests {
pub(crate) values: Vec<i32>,
}
impl Secret for LweSecret {
type Element = i32;
fn values(&self) -> &[Self::Element] {
&self.values
}
}
impl LweSecret {
fn random(hw: usize, n: usize) -> LweSecret {
DefaultSecureRng::with_local_mut(|rng| {
@@ -201,6 +194,10 @@ mod tests {
LweSecret { values: out }
})
}
fn values(&self) -> &[i32] {
&self.values
}
}
struct LweKeySwitchingKey<M, R> {

View File

@@ -24,7 +24,7 @@ pub(crate) mod tests {
fill_random_ternary_secret_with_hamming_weight, generate_prime, negacyclic_mul,
tests::Stats, ToShoup, TryConvertFrom1, WithLocal,
},
Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret,
Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut,
};
use super::{
@@ -406,13 +406,6 @@ pub(crate) mod tests {
pub(crate) values: Vec<i32>,
}
impl Secret for RlweSecret {
type Element = i32;
fn values(&self) -> &[Self::Element] {
&self.values
}
}
impl RlweSecret {
pub fn random(hw: usize, n: usize) -> RlweSecret {
DefaultSecureRng::with_local_mut(|rng| {
@@ -422,6 +415,10 @@ pub(crate) mod tests {
RlweSecret { values: out }
})
}
fn values(&self) -> &[i32] {
&self.values
}
}
fn random_seed() -> [u8; 32] {

View File

@@ -17,6 +17,16 @@ pub fn div_zero_error_flag() -> Option<FheBool> {
DIV_ZERO_ERROR.with_borrow(|c| c.clone())
}
/// Reset all error flags
///
/// Error flags are thread local. When running multiple circuits in sequence
/// within a single program you must prevent error flags set during the
/// execution of previous circuit to affect error flags set during execution of
/// the next circuit by resetting the flags before starting with next circuit.
pub fn reset_error_flags() {
DIV_ZERO_ERROR.with_borrow_mut(|c| *c = None);
}
mod frontend {
use super::ops::{
arbitrary_bit_adder, arbitrary_bit_division_for_quotient_and_rem, arbitrary_bit_subtractor,
@@ -176,7 +186,9 @@ mod frontend {
}
mod booleans {
use crate::shortint::ops::{arbitrary_bit_comparator, arbitrary_bit_equality};
use crate::shortint::ops::{
arbitrary_bit_comparator, arbitrary_bit_equality, arbitrary_bit_mux,
};
use super::*;
@@ -238,6 +250,27 @@ mod frontend {
FheBool { data: a_less_b }
})
}
/// Returns `Self` if `selector = True` else returns `other`
pub fn mux(&self, other: &FheUint8, selector: &FheBool) -> FheUint8 {
BoolEvaluator::with_local_mut(|e| {
let key = RuntimeServerKey::global();
let out = arbitrary_bit_mux(e, selector.data(), self.data(), other.data(), key);
FheUint8 { data: out }
})
}
/// max(`Self`, `other`)
pub fn max(&self, other: &FheUint8) -> FheUint8 {
let self_gt = self.gt(other);
self.mux(other, &self_gt)
}
/// min(`Self`, `other`)
pub fn min(&self, other: &FheUint8) -> FheUint8 {
let self_lt = self.lt(other);
self.mux(other, &self_lt)
}
}
}
}

View File

@@ -114,9 +114,10 @@ pub(super) fn bit_mux<E: BooleanGates>(
// (s&a) | ((1-s)^b)
let not_selector = evaluator.not(&selector);
let s_and_a = evaluator.and(&selector, if_true, key);
let mut s_and_a = evaluator.and(&selector, if_true, key);
let s_and_b = evaluator.and(&not_selector, if_false, key);
evaluator.or(&s_and_a, &s_and_b, key)
evaluator.or(&mut s_and_a, &s_and_b, key);
s_and_a
}
pub(super) fn arbitrary_bit_mux<E: BooleanGates>(
@@ -131,9 +132,10 @@ pub(super) fn arbitrary_bit_mux<E: BooleanGates>(
izip!(if_true.iter(), if_false.iter())
.map(|(a, b)| {
let s_and_a = evaluator.and(&selector, a, key);
let mut s_and_a = evaluator.and(&selector, a, key);
let s_and_b = evaluator.and(&not_selector, b, key);
evaluator.or(&s_and_a, &s_and_b, key)
evaluator.or_inplace(&mut s_and_a, &s_and_b, key);
s_and_a
})
.collect()
}