mirror of
https://github.com/arnaucube/phantom-zone.git
synced 2026-01-10 16:11:30 +01:00
implement min, max, mux
This commit is contained in:
@@ -517,7 +517,7 @@ mod tests {
|
||||
let ct = seeded_ct.unseed::<Vec<Vec<u64>>>();
|
||||
|
||||
let m_back = (0..batch_size)
|
||||
.map(|i| ck.decrypt(&ct.extract(i)))
|
||||
.map(|i| ck.decrypt(&ct.extract_at(i)))
|
||||
.collect_vec();
|
||||
|
||||
assert_eq!(m, m_back);
|
||||
@@ -528,7 +528,7 @@ mod tests {
|
||||
fn all_uint8_apis() {
|
||||
use num_traits::Euclid;
|
||||
|
||||
use crate::div_zero_error_flag;
|
||||
use crate::{div_zero_error_flag, FheBool};
|
||||
|
||||
set_single_party_parameter_sets(SP_TEST_BOOL_PARAMS);
|
||||
|
||||
@@ -624,7 +624,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
// Comparisons
|
||||
// // Comparisons
|
||||
{
|
||||
{
|
||||
let c_eq = c0.eq(&c1);
|
||||
@@ -681,6 +681,15 @@ mod tests {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// mux
|
||||
{
|
||||
let selector = thread_rng().gen_bool(0.5);
|
||||
let selector_enc: FheBool = ck.encrypt(&selector);
|
||||
let mux_out = ck.decrypt(&c0.mux(&c1, &selector_enc));
|
||||
let want_mux_out = if selector { m0 } else { m1 };
|
||||
assert_eq!(mux_out, want_mux_out);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
12
src/lib.rs
12
src/lib.rs
@@ -16,13 +16,10 @@ mod utils;
|
||||
pub use backend::{
|
||||
ArithmeticLazyOps, ArithmeticOps, ModInit, ModularOpsU64, ShoupMatrixFMA, VectorOps,
|
||||
};
|
||||
// pub use bool::{
|
||||
// aggregate_public_key_shares, aggregate_server_key_shares, gen_client_key,
|
||||
// gen_mp_keys_phase1, gen_mp_keys_phase2, set_mp_seed, set_parameter_set,
|
||||
// ParameterSelector, };
|
||||
|
||||
pub use bool::*;
|
||||
pub use ntt::{Ntt, NttBackendU64, NttInit};
|
||||
pub use shortint::{div_zero_error_flag, FheUint8};
|
||||
pub use shortint::{div_zero_error_flag, reset_error_flags, FheUint8};
|
||||
|
||||
pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer};
|
||||
|
||||
@@ -96,11 +93,6 @@ pub trait RowEntity: Row {
|
||||
fn zeros(col: usize) -> Self;
|
||||
}
|
||||
|
||||
trait Secret {
|
||||
type Element;
|
||||
fn values(&self) -> &[Self::Element];
|
||||
}
|
||||
|
||||
impl<T> Matrix for Vec<Vec<T>> {
|
||||
type MatElement = T;
|
||||
type R = Vec<T>;
|
||||
|
||||
13
src/lwe.rs
13
src/lwe.rs
@@ -173,7 +173,7 @@ mod tests {
|
||||
decomposer::DefaultDecomposer,
|
||||
random::{DefaultSecureRng, NewWithSeed},
|
||||
utils::{fill_random_ternary_secret_with_hamming_weight, WithLocal},
|
||||
MatrixEntity, MatrixMut, Secret,
|
||||
MatrixEntity, MatrixMut,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
@@ -185,13 +185,6 @@ mod tests {
|
||||
pub(crate) values: Vec<i32>,
|
||||
}
|
||||
|
||||
impl Secret for LweSecret {
|
||||
type Element = i32;
|
||||
fn values(&self) -> &[Self::Element] {
|
||||
&self.values
|
||||
}
|
||||
}
|
||||
|
||||
impl LweSecret {
|
||||
fn random(hw: usize, n: usize) -> LweSecret {
|
||||
DefaultSecureRng::with_local_mut(|rng| {
|
||||
@@ -201,6 +194,10 @@ mod tests {
|
||||
LweSecret { values: out }
|
||||
})
|
||||
}
|
||||
|
||||
fn values(&self) -> &[i32] {
|
||||
&self.values
|
||||
}
|
||||
}
|
||||
|
||||
struct LweKeySwitchingKey<M, R> {
|
||||
|
||||
@@ -24,7 +24,7 @@ pub(crate) mod tests {
|
||||
fill_random_ternary_secret_with_hamming_weight, generate_prime, negacyclic_mul,
|
||||
tests::Stats, ToShoup, TryConvertFrom1, WithLocal,
|
||||
},
|
||||
Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret,
|
||||
Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut,
|
||||
};
|
||||
|
||||
use super::{
|
||||
@@ -406,13 +406,6 @@ pub(crate) mod tests {
|
||||
pub(crate) values: Vec<i32>,
|
||||
}
|
||||
|
||||
impl Secret for RlweSecret {
|
||||
type Element = i32;
|
||||
fn values(&self) -> &[Self::Element] {
|
||||
&self.values
|
||||
}
|
||||
}
|
||||
|
||||
impl RlweSecret {
|
||||
pub fn random(hw: usize, n: usize) -> RlweSecret {
|
||||
DefaultSecureRng::with_local_mut(|rng| {
|
||||
@@ -422,6 +415,10 @@ pub(crate) mod tests {
|
||||
RlweSecret { values: out }
|
||||
})
|
||||
}
|
||||
|
||||
fn values(&self) -> &[i32] {
|
||||
&self.values
|
||||
}
|
||||
}
|
||||
|
||||
fn random_seed() -> [u8; 32] {
|
||||
|
||||
@@ -17,6 +17,16 @@ pub fn div_zero_error_flag() -> Option<FheBool> {
|
||||
DIV_ZERO_ERROR.with_borrow(|c| c.clone())
|
||||
}
|
||||
|
||||
/// Reset all error flags
|
||||
///
|
||||
/// Error flags are thread local. When running multiple circuits in sequence
|
||||
/// within a single program you must prevent error flags set during the
|
||||
/// execution of previous circuit to affect error flags set during execution of
|
||||
/// the next circuit by resetting the flags before starting with next circuit.
|
||||
pub fn reset_error_flags() {
|
||||
DIV_ZERO_ERROR.with_borrow_mut(|c| *c = None);
|
||||
}
|
||||
|
||||
mod frontend {
|
||||
use super::ops::{
|
||||
arbitrary_bit_adder, arbitrary_bit_division_for_quotient_and_rem, arbitrary_bit_subtractor,
|
||||
@@ -176,7 +186,9 @@ mod frontend {
|
||||
}
|
||||
|
||||
mod booleans {
|
||||
use crate::shortint::ops::{arbitrary_bit_comparator, arbitrary_bit_equality};
|
||||
use crate::shortint::ops::{
|
||||
arbitrary_bit_comparator, arbitrary_bit_equality, arbitrary_bit_mux,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -238,6 +250,27 @@ mod frontend {
|
||||
FheBool { data: a_less_b }
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns `Self` if `selector = True` else returns `other`
|
||||
pub fn mux(&self, other: &FheUint8, selector: &FheBool) -> FheUint8 {
|
||||
BoolEvaluator::with_local_mut(|e| {
|
||||
let key = RuntimeServerKey::global();
|
||||
let out = arbitrary_bit_mux(e, selector.data(), self.data(), other.data(), key);
|
||||
FheUint8 { data: out }
|
||||
})
|
||||
}
|
||||
|
||||
/// max(`Self`, `other`)
|
||||
pub fn max(&self, other: &FheUint8) -> FheUint8 {
|
||||
let self_gt = self.gt(other);
|
||||
self.mux(other, &self_gt)
|
||||
}
|
||||
|
||||
/// min(`Self`, `other`)
|
||||
pub fn min(&self, other: &FheUint8) -> FheUint8 {
|
||||
let self_lt = self.lt(other);
|
||||
self.mux(other, &self_lt)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,9 +114,10 @@ pub(super) fn bit_mux<E: BooleanGates>(
|
||||
// (s&a) | ((1-s)^b)
|
||||
let not_selector = evaluator.not(&selector);
|
||||
|
||||
let s_and_a = evaluator.and(&selector, if_true, key);
|
||||
let mut s_and_a = evaluator.and(&selector, if_true, key);
|
||||
let s_and_b = evaluator.and(¬_selector, if_false, key);
|
||||
evaluator.or(&s_and_a, &s_and_b, key)
|
||||
evaluator.or(&mut s_and_a, &s_and_b, key);
|
||||
s_and_a
|
||||
}
|
||||
|
||||
pub(super) fn arbitrary_bit_mux<E: BooleanGates>(
|
||||
@@ -131,9 +132,10 @@ pub(super) fn arbitrary_bit_mux<E: BooleanGates>(
|
||||
|
||||
izip!(if_true.iter(), if_false.iter())
|
||||
.map(|(a, b)| {
|
||||
let s_and_a = evaluator.and(&selector, a, key);
|
||||
let mut s_and_a = evaluator.and(&selector, a, key);
|
||||
let s_and_b = evaluator.and(¬_selector, b, key);
|
||||
evaluator.or(&s_and_a, &s_and_b, key)
|
||||
evaluator.or_inplace(&mut s_and_a, &s_and_b, key);
|
||||
s_and_a
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user