fix pbs up with shoup

This commit is contained in:
Janmajaya Mall
2024-06-11 13:36:02 +05:30
parent 80ae5d7c8f
commit 590a222c92
10 changed files with 1253 additions and 1236 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -6,7 +6,7 @@ use crate::{
pbs::WithShoupRepr,
random::{NewWithSeed, RandomFillUniformInModulus},
rgsw::RlweSecret,
utils::WithLocal,
utils::{ToShoup, WithLocal},
Decryptor, Encryptor, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, RowEntity, RowMut,
};
@@ -669,7 +669,7 @@ pub(super) mod impl_server_key_eval_domain {
}
/// Server key in evaluation domain
pub(crate) struct ShoupServerKeyEvaluationDomain<M, P, R, N> {
pub(crate) struct ShoupServerKeyEvaluationDomain<M> {
/// Rgsw cts of LWE secret elements
rgsw_cts: Vec<NormalAndShoup<M>>,
/// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding
@@ -677,12 +677,18 @@ pub(crate) struct ShoupServerKeyEvaluationDomain<M, P, R, N> {
galois_keys: HashMap<usize, NormalAndShoup<M>>,
/// LWE ksk to key switching LWE ciphertext from RLWE secret to LWE secret
lwe_ksk: M,
parameters: P,
_phanton: PhantomData<(R, N)>,
}
/// Stores normal and shoup representation of Matrix elements (Normal, Shoup)
pub(crate) struct NormalAndShoup<M>(M, M);
impl<M: ToShoup> NormalAndShoup<M> {
fn new_with_modulus(value: M, modulus: <M as ToShoup>::Modulus) -> Self {
let value_shoup = M::to_shoup(&value, modulus);
NormalAndShoup(value, value_shoup)
}
}
impl<M> AsRef<M> for NormalAndShoup<M> {
fn as_ref(&self) -> &M {
&self.0
@@ -697,11 +703,43 @@ impl<M> WithShoupRepr for NormalAndShoup<M> {
}
mod shoup_server_key_eval_domain {
use crate::pbs::PbsKey;
use itertools::{izip, Itertools};
use num_traits::{FromPrimitive, PrimInt};
use crate::{backend::Modulus, pbs::PbsKey};
use super::*;
impl<M: Matrix, P, R, N> PbsKey for ShoupServerKeyEvaluationDomain<M, P, R, N> {
impl<M: MatrixMut + MatrixEntity + ToShoup<Modulus = M::MatElement>, R, N>
From<ServerKeyEvaluationDomain<M, BoolParameters<M::MatElement>, R, N>>
for ShoupServerKeyEvaluationDomain<M>
where
<M as Matrix>::R: RowMut,
M::MatElement: PrimInt + FromPrimitive,
{
fn from(value: ServerKeyEvaluationDomain<M, BoolParameters<M::MatElement>, R, N>) -> Self {
let q = value.parameters.rlwe_q().q().unwrap();
// Rgsw ciphertexts
let rgsw_cts = value
.rgsw_cts
.into_iter()
.map(|ct| NormalAndShoup::new_with_modulus(ct, q))
.collect_vec();
let mut auto_keys = HashMap::new();
value.galois_keys.into_iter().for_each(|(index, key)| {
auto_keys.insert(index, NormalAndShoup::new_with_modulus(key, q));
});
Self {
rgsw_cts,
galois_keys: auto_keys,
lwe_ksk: value.lwe_ksk,
}
}
}
impl<M: Matrix> PbsKey for ShoupServerKeyEvaluationDomain<M> {
type AutoKey = NormalAndShoup<M>;
type LweKskKey = M;
type RgswCt = NormalAndShoup<M>;

View File

@@ -19,17 +19,10 @@ use crate::{
};
thread_local! {
static BOOL_EVALUATOR: RefCell<Option<BoolEvaluator<Vec<Vec<u64>>, NttBackendU64, ModularOpsU64<CiphertextModulus<u64>>, ModularOpsU64<CiphertextModulus<u64>>>>> = RefCell::new(None);
static BOOL_EVALUATOR: RefCell<Option<BoolEvaluator<Vec<Vec<u64>>, NttBackendU64, ModularOpsU64<CiphertextModulus<u64>>, ModularOpsU64<CiphertextModulus<u64>>, ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>>>> = RefCell::new(None);
}
static BOOL_SERVER_KEY: OnceLock<
ShoupServerKeyEvaluationDomain<
Vec<Vec<u64>>,
BoolParameters<u64>,
DefaultSecureRng,
NttBackendU64,
>,
> = OnceLock::new();
static BOOL_SERVER_KEY: OnceLock<ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>> = OnceLock::new();
static MULTI_PARTY_CRS: OnceLock<MultiPartyCrs<[u8; 32]>> = OnceLock::new();
@@ -44,14 +37,7 @@ pub fn set_mp_seed(seed: [u8; 32]) {
)
}
fn set_server_key(
key: ShoupServerKeyEvaluationDomain<
Vec<Vec<u64>>,
BoolParameters<u64>,
DefaultSecureRng,
NttBackendU64,
>,
) {
fn set_server_key(key: ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>) {
assert!(
BOOL_SERVER_KEY.set(key).is_ok(),
"Attempted to set server key twice."
@@ -64,7 +50,7 @@ pub(crate) fn gen_keys() -> (
) {
BoolEvaluator::with_local_mut(|e| {
let ck = e.client_key();
let sk = e.server_key(&ck);
let sk = e.single_party_server_key(&ck);
(ck, sk)
})
@@ -115,15 +101,11 @@ pub fn aggregate_server_key_shares(
BoolEvaluator::with_local(|e| e.aggregate_multi_party_server_key_shares(shares))
}
// SERVER KEY EVAL DOMAIN //
// SERVER KEY EVAL (/SHOUP) DOMAIN //
impl SeededServerKey<Vec<Vec<u64>>, BoolParameters<u64>, [u8; 32]> {
pub fn set_server_key(&self) {
set_server_key(ServerKeyEvaluationDomain::<
_,
_,
DefaultSecureRng,
NttBackendU64,
>::from(self));
let eval = ServerKeyEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from(self);
set_server_key(ShoupServerKeyEvaluationDomain::from(eval));
}
}
@@ -135,25 +117,9 @@ impl
>
{
pub fn set_server_key(&self) {
set_server_key(ServerKeyEvaluationDomain::<
_,
_,
DefaultSecureRng,
NttBackendU64,
>::from(self))
}
}
impl Global
for ShoupServerKeyEvaluationDomain<
Vec<Vec<u64>>,
BoolParameters<u64>,
DefaultSecureRng,
NttBackendU64,
>
{
fn global() -> &'static Self {
BOOL_SERVER_KEY.get().unwrap()
set_server_key(ShoupServerKeyEvaluationDomain::from(
ServerKeyEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from(self),
))
}
}
@@ -173,6 +139,7 @@ impl WithLocal
NttBackendU64,
ModularOpsU64<CiphertextModulus<u64>>,
ModularOpsU64<CiphertextModulus<u64>>,
ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>,
>
{
fn with_local<F, R>(func: F) -> R
@@ -196,3 +163,10 @@ impl WithLocal
BOOL_EVALUATOR.with_borrow_mut(|s| func(s.as_mut().expect("Parameters not set")))
}
}
pub(crate) type RuntimeServerKey = ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>;
impl Global for RuntimeServerKey {
fn global() -> &'static Self {
BOOL_SERVER_KEY.get().expect("Server key not set!")
}
}

View File

@@ -1,5 +1,3 @@
use std::cell::RefCell;
mod test {
use itertools::{izip, Itertools};
@@ -7,7 +5,8 @@ mod test {
backend::{ArithmeticOps, ModularOpsU64, Modulus},
bool::{
set_parameter_set, BoolEncoding, BoolEvaluator, BooleanGates, CiphertextModulus,
ClientKey, PublicKey, ServerKeyEvaluationDomain, MP_BOOL_PARAMS, SMALL_MP_BOOL_PARAMS,
ClientKey, PublicKey, ServerKeyEvaluationDomain, ShoupServerKeyEvaluationDomain,
MP_BOOL_PARAMS, SMALL_MP_BOOL_PARAMS,
},
lwe::{decrypt_lwe, LweSecret},
ntt::NttBackendU64,
@@ -15,7 +14,7 @@ mod test {
random::DefaultSecureRng,
rgsw::RlweSecret,
utils::Stats,
Secret,
Ntt, Secret,
};
#[test]
@@ -26,6 +25,7 @@ mod test {
NttBackendU64,
ModularOpsU64<CiphertextModulus<u64>>,
ModularOpsU64<CiphertextModulus<u64>>,
ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>,
>::new(SMALL_MP_BOOL_PARAMS);
let parties = 2;
@@ -84,7 +84,12 @@ mod test {
.collect_vec();
let server_key = evaluator.aggregate_multi_party_server_key_shares(&server_key_shares);
let server_key_eval_domain = ServerKeyEvaluationDomain::from(&server_key);
let runtime_server_key = ShoupServerKeyEvaluationDomain::from(ServerKeyEvaluationDomain::<
_,
_,
DefaultSecureRng,
NttBackendU64,
>::from(&server_key));
let mut m0 = false;
let mut m1 = true;
@@ -99,7 +104,7 @@ mod test {
for _ in 0..1000 {
let now = std::time::Instant::now();
let c_out = evaluator.xor(&c_m0, &c_m1, &server_key_eval_domain);
let c_out = evaluator.xor(&c_m0, &c_m1, &runtime_server_key);
println!("Gate time: {:?}", now.elapsed());
// mp decrypt