Browse Source

fix pbs up with shoup

par-agg-key-shares
Janmajaya Mall 10 months ago
parent
commit
590a222c92
10 changed files with 1253 additions and 1236 deletions
  1. +7
    -7
      benches/modulus.rs
  2. +2
    -1
      src/backend/mod.rs
  3. +1117
    -1114
      src/bool/evaluator.rs
  4. +44
    -6
      src/bool/keys.rs
  5. +18
    -44
      src/bool/mod.rs
  6. +11
    -6
      src/bool/noise.rs
  7. +3
    -1
      src/lib.rs
  8. +8
    -28
      src/rgsw/mod.rs
  9. +18
    -21
      src/shortint/mod.rs
  10. +25
    -8
      src/utils.rs

+ 7
- 7
benches/modulus.rs

@ -1,4 +1,7 @@
use bin_rs::{ArithmeticOps, Decomposer, DefaultDecomposer, ModInit, ModularOpsU64, VectorOps};
use bin_rs::{
ArithmeticLazyOps, ArithmeticOps, Decomposer, DefaultDecomposer, ModInit, ModularOpsU64,
ShoupMatrixFMA, VectorOps,
};
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use itertools::{izip, Itertools}; use itertools::{izip, Itertools};
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
@ -21,12 +24,9 @@ fn decompose_r(r: &[u64], decomp_r: &mut [Vec], decomposer: &DefaultDecompo
} }
fn matrix_fma(out: &mut [u64], a: &Vec<Vec<u64>>, b: &Vec<Vec<u64>>, modop: &ModularOpsU64<u64>) { fn matrix_fma(out: &mut [u64], a: &Vec<Vec<u64>>, b: &Vec<Vec<u64>>, modop: &ModularOpsU64<u64>) {
izip!(out.iter_mut(), a[0].iter(), b[0].iter())
.for_each(|(o, ai, bi)| *o = modop.add(o, &modop.mul_lazy(ai, bi)));
izip!(a.iter().skip(1), b.iter().skip(1)).for_each(|(a_r, b_r)| {
izip!(a.iter(), b.iter()).for_each(|(a_r, b_r)| {
izip!(out.iter_mut(), a_r.iter(), b_r.iter()) izip!(out.iter_mut(), a_r.iter(), b_r.iter())
.for_each(|(o, ai, bi)| *o = modop.add_lazy(o, &modop.mul(ai, bi)));
.for_each(|(o, ai, bi)| *o = modop.add_lazy(o, &modop.mul_lazy(ai, bi)));
}); });
} }
@ -127,7 +127,7 @@ fn benchmark(c: &mut Criterion) {
b.iter_batched_ref( b.iter_batched_ref(
|| (vec![0u64; ring_size]), || (vec![0u64; ring_size]),
|(out)| { |(out)| {
black_box(modop.shoup_fma(
black_box(modop.shoup_matrix_fma(
out, out,
&a0_matrix, &a0_matrix,
&a0_shoup_matrix, &a0_shoup_matrix,

+ 2
- 1
src/backend/mod.rs

@ -127,6 +127,7 @@ pub trait ArithmeticLazyOps {
} }
pub trait ShoupMatrixFMA<R: Row> { pub trait ShoupMatrixFMA<R: Row> {
/// Returns summation of `row-wise product of matrix a and b` + out.
/// Returns summation of `row-wise product of matrix a and b` + out where
/// each element is in range [0, 2q)
fn shoup_matrix_fma(&self, out: &mut [R::Element], a: &[R], a_shoup: &[R], b: &[R]); fn shoup_matrix_fma(&self, out: &mut [R::Element], a: &[R], a_shoup: &[R], b: &[R]);
} }

+ 1117
- 1114
src/bool/evaluator.rs
File diff suppressed because it is too large
View File


+ 44
- 6
src/bool/keys.rs

@ -6,7 +6,7 @@ use crate::{
pbs::WithShoupRepr, pbs::WithShoupRepr,
random::{NewWithSeed, RandomFillUniformInModulus}, random::{NewWithSeed, RandomFillUniformInModulus},
rgsw::RlweSecret, rgsw::RlweSecret,
utils::WithLocal,
utils::{ToShoup, WithLocal},
Decryptor, Encryptor, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, RowEntity, RowMut, Decryptor, Encryptor, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, RowEntity, RowMut,
}; };
@ -669,7 +669,7 @@ pub(super) mod impl_server_key_eval_domain {
} }
/// Server key in evaluation domain /// Server key in evaluation domain
pub(crate) struct ShoupServerKeyEvaluationDomain<M, P, R, N> {
pub(crate) struct ShoupServerKeyEvaluationDomain<M> {
/// Rgsw cts of LWE secret elements /// Rgsw cts of LWE secret elements
rgsw_cts: Vec<NormalAndShoup<M>>, rgsw_cts: Vec<NormalAndShoup<M>>,
/// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding /// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding
@ -677,12 +677,18 @@ pub(crate) struct ShoupServerKeyEvaluationDomain {
galois_keys: HashMap<usize, NormalAndShoup<M>>, galois_keys: HashMap<usize, NormalAndShoup<M>>,
/// LWE ksk to key switching LWE ciphertext from RLWE secret to LWE secret /// LWE ksk to key switching LWE ciphertext from RLWE secret to LWE secret
lwe_ksk: M, lwe_ksk: M,
parameters: P,
_phanton: PhantomData<(R, N)>,
} }
/// Stores normal and shoup representation of Matrix elements (Normal, Shoup)
pub(crate) struct NormalAndShoup<M>(M, M); pub(crate) struct NormalAndShoup<M>(M, M);
impl<M: ToShoup> NormalAndShoup<M> {
fn new_with_modulus(value: M, modulus: <M as ToShoup>::Modulus) -> Self {
let value_shoup = M::to_shoup(&value, modulus);
NormalAndShoup(value, value_shoup)
}
}
impl<M> AsRef<M> for NormalAndShoup<M> { impl<M> AsRef<M> for NormalAndShoup<M> {
fn as_ref(&self) -> &M { fn as_ref(&self) -> &M {
&self.0 &self.0
@ -697,11 +703,43 @@ impl WithShoupRepr for NormalAndShoup {
} }
mod shoup_server_key_eval_domain { mod shoup_server_key_eval_domain {
use crate::pbs::PbsKey;
use itertools::{izip, Itertools};
use num_traits::{FromPrimitive, PrimInt};
use crate::{backend::Modulus, pbs::PbsKey};
use super::*; use super::*;
impl<M: Matrix, P, R, N> PbsKey for ShoupServerKeyEvaluationDomain<M, P, R, N> {
impl<M: MatrixMut + MatrixEntity + ToShoup<Modulus = M::MatElement>, R, N>
From<ServerKeyEvaluationDomain<M, BoolParameters<M::MatElement>, R, N>>
for ShoupServerKeyEvaluationDomain<M>
where
<M as Matrix>::R: RowMut,
M::MatElement: PrimInt + FromPrimitive,
{
fn from(value: ServerKeyEvaluationDomain<M, BoolParameters<M::MatElement>, R, N>) -> Self {
let q = value.parameters.rlwe_q().q().unwrap();
// Rgsw ciphertexts
let rgsw_cts = value
.rgsw_cts
.into_iter()
.map(|ct| NormalAndShoup::new_with_modulus(ct, q))
.collect_vec();
let mut auto_keys = HashMap::new();
value.galois_keys.into_iter().for_each(|(index, key)| {
auto_keys.insert(index, NormalAndShoup::new_with_modulus(key, q));
});
Self {
rgsw_cts,
galois_keys: auto_keys,
lwe_ksk: value.lwe_ksk,
}
}
}
impl<M: Matrix> PbsKey for ShoupServerKeyEvaluationDomain<M> {
type AutoKey = NormalAndShoup<M>; type AutoKey = NormalAndShoup<M>;
type LweKskKey = M; type LweKskKey = M;
type RgswCt = NormalAndShoup<M>; type RgswCt = NormalAndShoup<M>;

+ 18
- 44
src/bool/mod.rs

@ -19,17 +19,10 @@ use crate::{
}; };
thread_local! { thread_local! {
static BOOL_EVALUATOR: RefCell<Option<BoolEvaluator<Vec<Vec<u64>>, NttBackendU64, ModularOpsU64<CiphertextModulus<u64>>, ModularOpsU64<CiphertextModulus<u64>>>>> = RefCell::new(None);
static BOOL_EVALUATOR: RefCell<Option<BoolEvaluator<Vec<Vec<u64>>, NttBackendU64, ModularOpsU64<CiphertextModulus<u64>>, ModularOpsU64<CiphertextModulus<u64>>, ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>>>> = RefCell::new(None);
} }
static BOOL_SERVER_KEY: OnceLock<
ShoupServerKeyEvaluationDomain<
Vec<Vec<u64>>,
BoolParameters<u64>,
DefaultSecureRng,
NttBackendU64,
>,
> = OnceLock::new();
static BOOL_SERVER_KEY: OnceLock<ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>> = OnceLock::new();
static MULTI_PARTY_CRS: OnceLock<MultiPartyCrs<[u8; 32]>> = OnceLock::new(); static MULTI_PARTY_CRS: OnceLock<MultiPartyCrs<[u8; 32]>> = OnceLock::new();
@ -44,14 +37,7 @@ pub fn set_mp_seed(seed: [u8; 32]) {
) )
} }
fn set_server_key(
key: ShoupServerKeyEvaluationDomain<
Vec<Vec<u64>>,
BoolParameters<u64>,
DefaultSecureRng,
NttBackendU64,
>,
) {
fn set_server_key(key: ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>) {
assert!( assert!(
BOOL_SERVER_KEY.set(key).is_ok(), BOOL_SERVER_KEY.set(key).is_ok(),
"Attempted to set server key twice." "Attempted to set server key twice."
@ -64,7 +50,7 @@ pub(crate) fn gen_keys() -> (
) { ) {
BoolEvaluator::with_local_mut(|e| { BoolEvaluator::with_local_mut(|e| {
let ck = e.client_key(); let ck = e.client_key();
let sk = e.server_key(&ck);
let sk = e.single_party_server_key(&ck);
(ck, sk) (ck, sk)
}) })
@ -115,15 +101,11 @@ pub fn aggregate_server_key_shares(
BoolEvaluator::with_local(|e| e.aggregate_multi_party_server_key_shares(shares)) BoolEvaluator::with_local(|e| e.aggregate_multi_party_server_key_shares(shares))
} }
// SERVER KEY EVAL DOMAIN //
// SERVER KEY EVAL (/SHOUP) DOMAIN //
impl SeededServerKey<Vec<Vec<u64>>, BoolParameters<u64>, [u8; 32]> { impl SeededServerKey<Vec<Vec<u64>>, BoolParameters<u64>, [u8; 32]> {
pub fn set_server_key(&self) { pub fn set_server_key(&self) {
set_server_key(ServerKeyEvaluationDomain::<
_,
_,
DefaultSecureRng,
NttBackendU64,
>::from(self));
let eval = ServerKeyEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from(self);
set_server_key(ShoupServerKeyEvaluationDomain::from(eval));
} }
} }
@ -135,25 +117,9 @@ impl
> >
{ {
pub fn set_server_key(&self) { pub fn set_server_key(&self) {
set_server_key(ServerKeyEvaluationDomain::<
_,
_,
DefaultSecureRng,
NttBackendU64,
>::from(self))
}
}
impl Global
for ShoupServerKeyEvaluationDomain<
Vec<Vec<u64>>,
BoolParameters<u64>,
DefaultSecureRng,
NttBackendU64,
>
{
fn global() -> &'static Self {
BOOL_SERVER_KEY.get().unwrap()
set_server_key(ShoupServerKeyEvaluationDomain::from(
ServerKeyEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from(self),
))
} }
} }
@ -173,6 +139,7 @@ impl WithLocal
NttBackendU64, NttBackendU64,
ModularOpsU64<CiphertextModulus<u64>>, ModularOpsU64<CiphertextModulus<u64>>,
ModularOpsU64<CiphertextModulus<u64>>, ModularOpsU64<CiphertextModulus<u64>>,
ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>,
> >
{ {
fn with_local<F, R>(func: F) -> R fn with_local<F, R>(func: F) -> R
@ -196,3 +163,10 @@ impl WithLocal
BOOL_EVALUATOR.with_borrow_mut(|s| func(s.as_mut().expect("Parameters not set"))) BOOL_EVALUATOR.with_borrow_mut(|s| func(s.as_mut().expect("Parameters not set")))
} }
} }
pub(crate) type RuntimeServerKey = ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>;
impl Global for RuntimeServerKey {
fn global() -> &'static Self {
BOOL_SERVER_KEY.get().expect("Server key not set!")
}
}

+ 11
- 6
src/bool/noise.rs

@ -1,5 +1,3 @@
use std::cell::RefCell;
mod test { mod test {
use itertools::{izip, Itertools}; use itertools::{izip, Itertools};
@ -7,7 +5,8 @@ mod test {
backend::{ArithmeticOps, ModularOpsU64, Modulus}, backend::{ArithmeticOps, ModularOpsU64, Modulus},
bool::{ bool::{
set_parameter_set, BoolEncoding, BoolEvaluator, BooleanGates, CiphertextModulus, set_parameter_set, BoolEncoding, BoolEvaluator, BooleanGates, CiphertextModulus,
ClientKey, PublicKey, ServerKeyEvaluationDomain, MP_BOOL_PARAMS, SMALL_MP_BOOL_PARAMS,
ClientKey, PublicKey, ServerKeyEvaluationDomain, ShoupServerKeyEvaluationDomain,
MP_BOOL_PARAMS, SMALL_MP_BOOL_PARAMS,
}, },
lwe::{decrypt_lwe, LweSecret}, lwe::{decrypt_lwe, LweSecret},
ntt::NttBackendU64, ntt::NttBackendU64,
@ -15,7 +14,7 @@ mod test {
random::DefaultSecureRng, random::DefaultSecureRng,
rgsw::RlweSecret, rgsw::RlweSecret,
utils::Stats, utils::Stats,
Secret,
Ntt, Secret,
}; };
#[test] #[test]
@ -26,6 +25,7 @@ mod test {
NttBackendU64, NttBackendU64,
ModularOpsU64<CiphertextModulus<u64>>, ModularOpsU64<CiphertextModulus<u64>>,
ModularOpsU64<CiphertextModulus<u64>>, ModularOpsU64<CiphertextModulus<u64>>,
ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>,
>::new(SMALL_MP_BOOL_PARAMS); >::new(SMALL_MP_BOOL_PARAMS);
let parties = 2; let parties = 2;
@ -84,7 +84,12 @@ mod test {
.collect_vec(); .collect_vec();
let server_key = evaluator.aggregate_multi_party_server_key_shares(&server_key_shares); let server_key = evaluator.aggregate_multi_party_server_key_shares(&server_key_shares);
let server_key_eval_domain = ServerKeyEvaluationDomain::from(&server_key);
let runtime_server_key = ShoupServerKeyEvaluationDomain::from(ServerKeyEvaluationDomain::<
_,
_,
DefaultSecureRng,
NttBackendU64,
>::from(&server_key));
let mut m0 = false; let mut m0 = false;
let mut m1 = true; let mut m1 = true;
@ -99,7 +104,7 @@ mod test {
for _ in 0..1000 { for _ in 0..1000 {
let now = std::time::Instant::now(); let now = std::time::Instant::now();
let c_out = evaluator.xor(&c_m0, &c_m1, &server_key_eval_domain);
let c_out = evaluator.xor(&c_m0, &c_m1, &runtime_server_key);
println!("Gate time: {:?}", now.elapsed()); println!("Gate time: {:?}", now.elapsed());
// mp decrypt // mp decrypt

+ 3
- 1
src/lib.rs

@ -20,7 +20,9 @@ mod rgsw;
mod shortint; mod shortint;
mod utils; mod utils;
pub use backend::{ArithmeticOps, ModInit, ModularOpsU64, VectorOps};
pub use backend::{
ArithmeticLazyOps, ArithmeticOps, ModInit, ModularOpsU64, ShoupMatrixFMA, VectorOps,
};
pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer}; pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer};
pub use ntt::{Ntt, NttBackendU64, NttInit}; pub use ntt::{Ntt, NttBackendU64, NttInit};

+ 8
- 28
src/rgsw/mod.rs

@ -95,23 +95,13 @@ pub struct ShoupAutoKeyEvaluationDomain {
data: M, data: M,
} }
impl<M: MatrixMut + MatrixEntity, Mod: Modulus<Element = M::MatElement>, R, N>
impl<M: Matrix + ToShoup<Modulus = M::MatElement>, Mod: Modulus<Element = M::MatElement>, R, N>
From<&AutoKeyEvaluationDomain<M, Mod, R, N>> for ShoupAutoKeyEvaluationDomain<M> From<&AutoKeyEvaluationDomain<M, Mod, R, N>> for ShoupAutoKeyEvaluationDomain<M>
where
M::R: RowMut,
M::MatElement: ToShoup + Copy,
{ {
fn from(value: &AutoKeyEvaluationDomain<M, Mod, R, N>) -> Self { fn from(value: &AutoKeyEvaluationDomain<M, Mod, R, N>) -> Self {
let (row, col) = value.data.dimension();
let mut shoup_data = M::zeros(row, col);
izip!(shoup_data.iter_rows_mut(), value.data.iter_rows()).for_each(|(shoup_r, r)| {
izip!(shoup_r.as_mut().iter_mut(), r.as_ref().iter()).for_each(|(s, e)| {
*s = M::MatElement::to_shoup(*e, value.modulus.q().unwrap());
});
});
Self { data: shoup_data }
Self {
data: M::to_shoup(&value.data, value.modulus.q().unwrap()),
}
} }
} }
@ -328,23 +318,13 @@ pub struct ShoupRgswCiphertextEvaluationDomain {
pub(crate) data: M, pub(crate) data: M,
} }
impl<M: MatrixMut + MatrixEntity, Mod: Modulus<Element = M::MatElement>, R, N>
impl<M: Matrix + ToShoup<Modulus = M::MatElement>, Mod: Modulus<Element = M::MatElement>, R, N>
From<&RgswCiphertextEvaluationDomain<M, Mod, R, N>> for ShoupRgswCiphertextEvaluationDomain<M> From<&RgswCiphertextEvaluationDomain<M, Mod, R, N>> for ShoupRgswCiphertextEvaluationDomain<M>
where
M::R: RowMut,
M::MatElement: ToShoup + Copy,
{ {
fn from(value: &RgswCiphertextEvaluationDomain<M, Mod, R, N>) -> Self { fn from(value: &RgswCiphertextEvaluationDomain<M, Mod, R, N>) -> Self {
let (row, col) = value.data.dimension();
let mut shoup_data = M::zeros(row, col);
izip!(shoup_data.iter_rows_mut(), value.data.iter_rows()).for_each(|(shoup_r, r)| {
izip!(shoup_r.as_mut().iter_mut(), r.as_ref().iter()).for_each(|(s, e)| {
*s = M::MatElement::to_shoup(*e, value.modulus.q().unwrap());
});
});
Self { data: shoup_data }
Self {
data: M::to_shoup(&value.data, value.modulus.q().unwrap()),
}
} }
} }

+ 18
- 21
src/shortint/mod.rs

@ -106,19 +106,16 @@ mod frontend {
use super::FheUint8; use super::FheUint8;
type ShortIntBoolEvaluator<M, Ntt, RlweModOp, LweModOp> =
BoolEvaluator<M, Ntt, RlweModOp, LweModOp>;
mod arithetic { mod arithetic {
use crate::bool::{evaluator::BooleanGates, FheBool};
use crate::bool::{FheBool, RuntimeServerKey};
use super::*; use super::*;
use std::ops::{Add, AddAssign, Div, Mul, Rem, Sub}; use std::ops::{Add, AddAssign, Div, Mul, Rem, Sub};
impl AddAssign<&FheUint8> for FheUint8 { impl AddAssign<&FheUint8> for FheUint8 {
fn add_assign(&mut self, rhs: &FheUint8) { fn add_assign(&mut self, rhs: &FheUint8) {
ShortIntBoolEvaluator::with_local_mut_mut(&mut |e| {
let key = <ShortIntBoolEvaluator<_, _, _, _> as BooleanGates>::Key::global();
BoolEvaluator::with_local_mut_mut(&mut |e| {
let key = RuntimeServerKey::global();
arbitrary_bit_adder(e, self.data_mut(), rhs.data(), false, key); arbitrary_bit_adder(e, self.data_mut(), rhs.data(), false, key);
}); });
} }
@ -137,7 +134,7 @@ mod frontend {
type Output = FheUint8; type Output = FheUint8;
fn sub(self, rhs: &FheUint8) -> Self::Output { fn sub(self, rhs: &FheUint8) -> Self::Output {
BoolEvaluator::with_local_mut(|e| { BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
let key = RuntimeServerKey::global();
let (out, _, _) = arbitrary_bit_subtractor(e, self.data(), rhs.data(), key); let (out, _, _) = arbitrary_bit_subtractor(e, self.data(), rhs.data(), key);
FheUint8 { data: out } FheUint8 { data: out }
}) })
@ -148,7 +145,7 @@ mod frontend {
type Output = FheUint8; type Output = FheUint8;
fn mul(self, rhs: &FheUint8) -> Self::Output { fn mul(self, rhs: &FheUint8) -> Self::Output {
BoolEvaluator::with_local_mut(|e| { BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
let key = RuntimeServerKey::global();
let out = eight_bit_mul(e, self.data(), rhs.data(), key); let out = eight_bit_mul(e, self.data(), rhs.data(), key);
FheUint8 { data: out } FheUint8 { data: out }
}) })
@ -160,7 +157,7 @@ mod frontend {
fn div(self, rhs: &FheUint8) -> Self::Output { fn div(self, rhs: &FheUint8) -> Self::Output {
// TODO(Jay:) Figure out how to set zero error flag // TODO(Jay:) Figure out how to set zero error flag
BoolEvaluator::with_local_mut(|e| { BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
let key = RuntimeServerKey::global();
let (quotient, _) = arbitrary_bit_division_for_quotient_and_rem( let (quotient, _) = arbitrary_bit_division_for_quotient_and_rem(
e, e,
self.data(), self.data(),
@ -176,7 +173,7 @@ mod frontend {
type Output = FheUint8; type Output = FheUint8;
fn rem(self, rhs: &FheUint8) -> Self::Output { fn rem(self, rhs: &FheUint8) -> Self::Output {
BoolEvaluator::with_local_mut(|e| { BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
let key = RuntimeServerKey::global();
let (_, remainder) = arbitrary_bit_division_for_quotient_and_rem( let (_, remainder) = arbitrary_bit_division_for_quotient_and_rem(
e, e,
self.data(), self.data(),
@ -191,7 +188,7 @@ mod frontend {
impl FheUint8 { impl FheUint8 {
pub fn overflowing_add_assign(&mut self, rhs: &FheUint8) -> FheBool { pub fn overflowing_add_assign(&mut self, rhs: &FheUint8) -> FheBool {
BoolEvaluator::with_local_mut_mut(&mut |e| { BoolEvaluator::with_local_mut_mut(&mut |e| {
let key = ServerKeyEvaluationDomain::global();
let key = RuntimeServerKey::global();
let (overflow, _) = let (overflow, _) =
arbitrary_bit_adder(e, self.data_mut(), rhs.data(), false, key); arbitrary_bit_adder(e, self.data_mut(), rhs.data(), false, key);
overflow overflow
@ -201,7 +198,7 @@ mod frontend {
pub fn overflowing_add(self, rhs: &FheUint8) -> (FheUint8, FheBool) { pub fn overflowing_add(self, rhs: &FheUint8) -> (FheUint8, FheBool) {
BoolEvaluator::with_local_mut(|e| { BoolEvaluator::with_local_mut(|e| {
let mut lhs = self.clone(); let mut lhs = self.clone();
let key = ServerKeyEvaluationDomain::global();
let key = RuntimeServerKey::global();
let (overflow, _) = let (overflow, _) =
arbitrary_bit_adder(e, lhs.data_mut(), rhs.data(), false, key); arbitrary_bit_adder(e, lhs.data_mut(), rhs.data(), false, key);
(lhs, overflow) (lhs, overflow)
@ -210,7 +207,7 @@ mod frontend {
pub fn overflowing_sub(&self, rhs: &FheUint8) -> (FheUint8, FheBool) { pub fn overflowing_sub(&self, rhs: &FheUint8) -> (FheUint8, FheBool) {
BoolEvaluator::with_local_mut(|e| { BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
let key = RuntimeServerKey::global();
let (out, mut overflow, _) = let (out, mut overflow, _) =
arbitrary_bit_subtractor(e, self.data(), rhs.data(), key); arbitrary_bit_subtractor(e, self.data(), rhs.data(), key);
e.not_inplace(&mut overflow); e.not_inplace(&mut overflow);
@ -221,7 +218,7 @@ mod frontend {
pub fn div_rem(&self, rhs: &FheUint8) -> (FheUint8, FheUint8) { pub fn div_rem(&self, rhs: &FheUint8) -> (FheUint8, FheUint8) {
// TODO(Jay:) Figure out how to set zero error flag // TODO(Jay:) Figure out how to set zero error flag
BoolEvaluator::with_local_mut(|e| { BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
let key = RuntimeServerKey::global();
let (quotient, remainder) = arbitrary_bit_division_for_quotient_and_rem( let (quotient, remainder) = arbitrary_bit_division_for_quotient_and_rem(
e, e,
self.data(), self.data(),
@ -236,7 +233,7 @@ mod frontend {
mod booleans { mod booleans {
use crate::{ use crate::{
bool::{evaluator::BooleanGates, FheBool},
bool::{evaluator::BooleanGates, FheBool, RuntimeServerKey},
shortint::ops::{ shortint::ops::{
arbitrary_bit_comparator, arbitrary_bit_equality, arbitrary_signed_bit_comparator, arbitrary_bit_comparator, arbitrary_bit_equality, arbitrary_signed_bit_comparator,
}, },
@ -248,7 +245,7 @@ mod frontend {
/// a == b /// a == b
pub fn eq(&self, other: &FheUint8) -> FheBool { pub fn eq(&self, other: &FheUint8) -> FheBool {
BoolEvaluator::with_local_mut(|e| { BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
let key = RuntimeServerKey::global();
arbitrary_bit_equality(e, self.data(), other.data(), key) arbitrary_bit_equality(e, self.data(), other.data(), key)
}) })
} }
@ -256,7 +253,7 @@ mod frontend {
/// a != b /// a != b
pub fn neq(&self, other: &FheUint8) -> FheBool { pub fn neq(&self, other: &FheUint8) -> FheBool {
BoolEvaluator::with_local_mut(|e| { BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
let key = RuntimeServerKey::global();
let mut is_equal = arbitrary_bit_equality(e, self.data(), other.data(), key); let mut is_equal = arbitrary_bit_equality(e, self.data(), other.data(), key);
e.not_inplace(&mut is_equal); e.not_inplace(&mut is_equal);
is_equal is_equal
@ -266,7 +263,7 @@ mod frontend {
/// a < b /// a < b
pub fn lt(&self, other: &FheUint8) -> FheBool { pub fn lt(&self, other: &FheUint8) -> FheBool {
BoolEvaluator::with_local_mut(|e| { BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
let key = RuntimeServerKey::global();
arbitrary_bit_comparator(e, other.data(), self.data(), key) arbitrary_bit_comparator(e, other.data(), self.data(), key)
}) })
} }
@ -274,7 +271,7 @@ mod frontend {
/// a > b /// a > b
pub fn gt(&self, other: &FheUint8) -> FheBool { pub fn gt(&self, other: &FheUint8) -> FheBool {
BoolEvaluator::with_local_mut(|e| { BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
let key = RuntimeServerKey::global();
arbitrary_bit_comparator(e, self.data(), other.data(), key) arbitrary_bit_comparator(e, self.data(), other.data(), key)
}) })
} }
@ -282,7 +279,7 @@ mod frontend {
/// a <= b /// a <= b
pub fn le(&self, other: &FheUint8) -> FheBool { pub fn le(&self, other: &FheUint8) -> FheBool {
BoolEvaluator::with_local_mut(|e| { BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
let key = RuntimeServerKey::global();
let mut a_greater_b = let mut a_greater_b =
arbitrary_bit_comparator(e, self.data(), other.data(), key); arbitrary_bit_comparator(e, self.data(), other.data(), key);
e.not_inplace(&mut a_greater_b); e.not_inplace(&mut a_greater_b);
@ -293,7 +290,7 @@ mod frontend {
/// a >= b /// a >= b
pub fn ge(&self, other: &FheUint8) -> FheBool { pub fn ge(&self, other: &FheUint8) -> FheBool {
BoolEvaluator::with_local_mut(|e| { BoolEvaluator::with_local_mut(|e| {
let key = ServerKeyEvaluationDomain::global();
let key = RuntimeServerKey::global();
let mut a_less_b = arbitrary_bit_comparator(e, other.data(), self.data(), key); let mut a_less_b = arbitrary_bit_comparator(e, other.data(), self.data(), key);
e.not_inplace(&mut a_less_b); e.not_inplace(&mut a_less_b);
a_less_b a_less_b

+ 25
- 8
src/utils.rs

@ -1,11 +1,12 @@
use std::{fmt::Debug, usize};
use std::{fmt::Debug, usize, vec};
use itertools::Itertools;
use itertools::{izip, Itertools};
use num_traits::{FromPrimitive, PrimInt, Signed, Unsigned}; use num_traits::{FromPrimitive, PrimInt, Signed, Unsigned};
use crate::{ use crate::{
backend::Modulus, backend::Modulus,
random::{RandomElement, RandomElementInModulus, RandomFill}, random::{RandomElement, RandomElementInModulus, RandomFill},
Matrix,
}; };
pub trait WithLocal { pub trait WithLocal {
fn with_local<F, R>(func: F) -> R fn with_local<F, R>(func: F) -> R
@ -30,10 +31,6 @@ pub(crate) trait ShoupMul {
fn mul(a: Self, b: Self, b_shoup: Self, q: Self) -> Self; fn mul(a: Self, b: Self, b_shoup: Self, q: Self) -> Self;
} }
pub(crate) trait ToShoup {
fn to_shoup(value: Self, modulus: Self) -> Self;
}
impl ShoupMul for u64 { impl ShoupMul for u64 {
#[inline] #[inline]
fn representation(value: Self, q: Self) -> Self { fn representation(value: Self, q: Self) -> Self {
@ -48,9 +45,29 @@ impl ShoupMul for u64 {
} }
} }
pub(crate) trait ToShoup {
type Modulus;
fn to_shoup(value: &Self, modulus: Self::Modulus) -> Self;
}
impl ToShoup for u64 { impl ToShoup for u64 {
fn to_shoup(value: Self, modulus: Self) -> Self {
((value as u128 * (1u128 << 64)) / modulus as u128) as u64
type Modulus = u64;
fn to_shoup(value: &Self, modulus: Self) -> Self {
((*value as u128 * (1u128 << 64)) / modulus as u128) as u64
}
}
impl ToShoup for Vec<Vec<u64>> {
type Modulus = u64;
fn to_shoup(value: &Self, modulus: Self::Modulus) -> Self {
let (row, col) = value.dimension();
let mut shoup_value = vec![vec![0u64; col]; row];
izip!(shoup_value.iter_mut(), value.iter()).for_each(|(shoup_r, r)| {
izip!(shoup_r.iter_mut(), r.iter()).for_each(|(s, e)| {
*s = u64::to_shoup(e, modulus);
})
});
shoup_value
} }
} }

Loading…
Cancel
Save