mirror of
https://github.com/arnaucube/phantom-zone.git
synced 2026-01-09 15:41:30 +01:00
clean lwe
This commit is contained in:
@@ -11,7 +11,7 @@ fn fhe_circuit(a: &FheUint8, b: &FheUint8, c: &FheUint8, d: &FheUint8) -> FheUin
|
||||
}
|
||||
|
||||
fn main() {
|
||||
set_parameter_set(ParameterSelector::NonInteractiveMultiPartyLessThanOrEqualTo16);
|
||||
set_parameter_set(ParameterSelector::NonInteractiveLTE2Party);
|
||||
|
||||
// set CRS
|
||||
let mut seed = [0u8; 32];
|
||||
|
||||
@@ -5,23 +5,18 @@ use std::{
|
||||
usize,
|
||||
};
|
||||
|
||||
use itertools::{izip, partition, Itertools};
|
||||
use num_traits::{
|
||||
zero, FromPrimitive, Num, One, Pow, PrimInt, ToPrimitive, WrappingAdd, WrappingSub, Zero,
|
||||
};
|
||||
use rand::Rng;
|
||||
use itertools::{izip, Itertools};
|
||||
use num_traits::{FromPrimitive, One, PrimInt, ToPrimitive, WrappingAdd, WrappingSub, Zero};
|
||||
use rand_distr::uniform::SampleUniform;
|
||||
|
||||
use crate::{
|
||||
backend::{
|
||||
ArithmeticOps, GetModulus, ModInit, ModularOpsU64, Modulus, ShoupMatrixFMA, VectorOps,
|
||||
},
|
||||
backend::{ArithmeticOps, GetModulus, ModInit, Modulus, ShoupMatrixFMA, VectorOps},
|
||||
bool::parameters::ParameterVariant,
|
||||
decomposer::{Decomposer, DefaultDecomposer, NumInfo, RlweDecomposer},
|
||||
lwe::{decrypt_lwe, encrypt_lwe, lwe_key_switch, lwe_ksk_keygen, measure_noise_lwe, LweSecret},
|
||||
lwe::{decrypt_lwe, encrypt_lwe, seeded_lwe_ksk_keygen},
|
||||
multi_party::{
|
||||
non_interactive_ksk_gen, non_interactive_ksk_zero_encryptions_for_other_party_i,
|
||||
non_interactive_rgsw_ct, public_key_share,
|
||||
public_key_share,
|
||||
},
|
||||
ntt::{self, Ntt, NttBackendU64, NttInit},
|
||||
pbs::{pbs, sample_extract, PbsInfo, PbsKey, WithShoupRepr},
|
||||
@@ -48,8 +43,7 @@ use super::{
|
||||
CommonReferenceSeededNonInteractiveMultiPartyServerKeyShare,
|
||||
InteractiveMultiPartyClientKey, NonInteractiveMultiPartyClientKey,
|
||||
SeededMultiPartyServerKey, SeededNonInteractiveMultiPartyServerKey,
|
||||
SeededSinglePartyServerKey, ServerKeyEvaluationDomain, ShoupServerKeyEvaluationDomain,
|
||||
SinglePartyClientKey,
|
||||
SeededSinglePartyServerKey, SinglePartyClientKey,
|
||||
},
|
||||
parameters::{
|
||||
BoolParameters, CiphertextModulus, DecompositionCount, DecompostionLogBase,
|
||||
@@ -67,9 +61,8 @@ use super::{
|
||||
/// Initial Seed:
|
||||
/// Puncture 1 -> Public key share seed
|
||||
/// Puncture 2 -> Main server key share seed
|
||||
/// Puncture 1 -> RGSW cuphertexts seed
|
||||
/// Puncture 2 -> Auto keys cipertexts seed
|
||||
/// Puncture 3 -> LWE ksk seed
|
||||
/// Puncture 1 -> Auto keys cipertexts seed
|
||||
/// Puncture 2 -> LWE ksk seed
|
||||
#[derive(Clone, PartialEq)]
|
||||
pub struct MultiPartyCrs<S> {
|
||||
pub(super) seed: S,
|
||||
@@ -97,19 +90,14 @@ impl<S: Default + Copy> MultiPartyCrs<S> {
|
||||
puncture_p_rng(&mut prng, 2)
|
||||
}
|
||||
|
||||
pub(super) fn rgsw_cts_seed<Rng: NewWithSeed<Seed = S> + RandomFill<S>>(&self) -> S {
|
||||
pub(super) fn auto_keys_cts_seed<Rng: NewWithSeed<Seed = S> + RandomFill<S>>(&self) -> S {
|
||||
let mut key_prng = Rng::new_with_seed(self.key_seed::<Rng>());
|
||||
puncture_p_rng(&mut key_prng, 1)
|
||||
}
|
||||
|
||||
pub(super) fn auto_keys_cts_seed<Rng: NewWithSeed<Seed = S> + RandomFill<S>>(&self) -> S {
|
||||
let mut key_prng = Rng::new_with_seed(self.key_seed::<Rng>());
|
||||
puncture_p_rng(&mut key_prng, 2)
|
||||
}
|
||||
|
||||
pub(super) fn lwe_ksk_cts_seed_seed<Rng: NewWithSeed<Seed = S> + RandomFill<S>>(&self) -> S {
|
||||
let mut key_prng = Rng::new_with_seed(self.key_seed::<Rng>());
|
||||
puncture_p_rng(&mut key_prng, 3)
|
||||
puncture_p_rng(&mut key_prng, 2)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,7 +107,8 @@ impl<S: Default + Copy> MultiPartyCrs<S> {
|
||||
/// Puncture 1 -> Key Seed
|
||||
/// Puncture 1 -> Rgsw ciphertext seed
|
||||
/// Puncture l+1 -> Seed for zero encs and non-interactive
|
||||
/// multi-party RGSW ciphertext corresponding to l^th LWE index.
|
||||
/// multi-party RGSW ciphertexts of
|
||||
/// l^th LWE index.
|
||||
/// Puncture 2 -> auto keys seed
|
||||
/// Puncture 3 -> Lwe key switching key seed
|
||||
/// Puncture 2 -> user specific seed for u_j to s ksk
|
||||
@@ -931,12 +920,9 @@ where
|
||||
|
||||
// LWE KSK from RLWE secret s -> LWE secret z
|
||||
let d_lwe_gadget = self.pbs_info.lwe_decomposer.gadget_vector();
|
||||
let mut lwe_ksk =
|
||||
M::R::zeros(self.pbs_info.lwe_decomposer.decomposition_count() * ring_size);
|
||||
lwe_ksk_keygen(
|
||||
let lwe_ksk = seeded_lwe_ksk_keygen(
|
||||
&sk_rlwe,
|
||||
&sk_lwe,
|
||||
&mut lwe_ksk,
|
||||
&d_lwe_gadget,
|
||||
&self.pbs_info.lwe_modop,
|
||||
&mut main_prng,
|
||||
@@ -2049,21 +2035,16 @@ where
|
||||
) -> M::R {
|
||||
DefaultSecureRng::with_local_mut(|rng| {
|
||||
let mut p_rng = DefaultSecureRng::new_seeded(lwe_ksk_seed);
|
||||
let mut lwe_ksk = M::R::zeros(
|
||||
self.pbs_info.lwe_decomposer.decomposition_count() * self.parameters().rlwe_n().0,
|
||||
);
|
||||
let lwe_modop = &self.pbs_info.lwe_modop;
|
||||
let d_lwe_gadget_vec = self.pbs_info.lwe_decomposer.gadget_vector();
|
||||
lwe_ksk_keygen(
|
||||
seeded_lwe_ksk_keygen(
|
||||
sk_rlwe,
|
||||
sk_lwe,
|
||||
&mut lwe_ksk,
|
||||
&d_lwe_gadget_vec,
|
||||
lwe_modop,
|
||||
&mut p_rng,
|
||||
rng,
|
||||
);
|
||||
lwe_ksk
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2238,15 +2219,7 @@ where
|
||||
};
|
||||
|
||||
DefaultSecureRng::with_local_mut(|rng| {
|
||||
let mut lwe_out = M::R::zeros(self.pbs_info.parameters.rlwe_n().0 + 1);
|
||||
encrypt_lwe(
|
||||
&mut lwe_out,
|
||||
&m,
|
||||
&client_key.sk_rlwe(),
|
||||
&self.pbs_info.rlwe_modop,
|
||||
rng,
|
||||
);
|
||||
lwe_out
|
||||
encrypt_lwe(&m, &client_key.sk_rlwe(), &self.pbs_info.rlwe_modop, rng)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ use std::{collections::HashMap, hash::Hash, marker::PhantomData};
|
||||
|
||||
use crate::{
|
||||
backend::{ModInit, VectorOps},
|
||||
lwe::LweSecret,
|
||||
pbs::WithShoupRepr,
|
||||
random::{NewWithSeed, RandomFillUniformInModulus},
|
||||
rgsw::RlweSecret,
|
||||
|
||||
@@ -22,7 +22,8 @@ pub type ClientKey = keys::ClientKey<[u8; 32], u64>;
|
||||
pub enum ParameterSelector {
|
||||
HighCommunicationButFast2Party,
|
||||
MultiPartyLessThanOrEqualTo16,
|
||||
NonInteractiveMultiPartyLessThanOrEqualTo16,
|
||||
NonInteractiveLTE2Party,
|
||||
NonInteractiveLTE4Party,
|
||||
}
|
||||
|
||||
mod common_mp_enc_dec {
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::{cell::RefCell, sync::OnceLock};
|
||||
use crate::{
|
||||
backend::ModulusPowerOf2,
|
||||
bool::parameters::ParameterVariant,
|
||||
parameters::NI_4P,
|
||||
random::DefaultSecureRng,
|
||||
utils::{Global, WithLocal},
|
||||
ModularOpsU64, NttBackendU64,
|
||||
@@ -38,9 +39,13 @@ static MULTI_PARTY_CRS: OnceLock<NonInteractiveMultiPartyCrs<[u8; 32]>> = OnceLo
|
||||
|
||||
pub fn set_parameter_set(select: ParameterSelector) {
|
||||
match select {
|
||||
ParameterSelector::NonInteractiveMultiPartyLessThanOrEqualTo16 => {
|
||||
ParameterSelector::NonInteractiveLTE2Party => {
|
||||
BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(NI_2P)));
|
||||
}
|
||||
ParameterSelector::NonInteractiveLTE4Party => {
|
||||
BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(NI_4P)));
|
||||
}
|
||||
|
||||
_ => {
|
||||
panic!("Paramerters not supported")
|
||||
}
|
||||
@@ -160,6 +165,13 @@ impl Global for RuntimeServerKey {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) struct NonInteractiveBatchedFheBools<C> {
|
||||
data: Vec<C>,
|
||||
}
|
||||
pub(super) struct BatchedFheBools<C> {
|
||||
pub(in super::super) data: Vec<C>,
|
||||
}
|
||||
|
||||
/// Non interactive multi-party specfic encryptor decryptor routines
|
||||
mod impl_enc_dec {
|
||||
use crate::{
|
||||
@@ -177,10 +189,6 @@ mod impl_enc_dec {
|
||||
|
||||
type Mat = Vec<Vec<u64>>;
|
||||
|
||||
pub(super) struct BatchedFheBools<C> {
|
||||
pub(super) data: Vec<C>,
|
||||
}
|
||||
|
||||
impl<C: MatrixMut<MatElement = u64>> BatchedFheBools<C>
|
||||
where
|
||||
C::R: RowEntity + RowMut,
|
||||
@@ -202,10 +210,6 @@ mod impl_enc_dec {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) struct NonInteractiveBatchedFheBools<C> {
|
||||
data: Vec<C>,
|
||||
}
|
||||
|
||||
impl<M: MatrixEntity + MatrixMut<MatElement = u64>> From<&(Vec<M::R>, [u8; 32])>
|
||||
for NonInteractiveBatchedFheBools<M>
|
||||
where
|
||||
@@ -349,10 +353,9 @@ mod impl_enc_dec {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use impl_enc_dec::NonInteractiveBatchedFheBools;
|
||||
use itertools::{izip, Itertools};
|
||||
use num_traits::{FromPrimitive, PrimInt, ToPrimitive, Zero};
|
||||
use rand::{thread_rng, RngCore};
|
||||
use rand::{thread_rng, Rng, RngCore};
|
||||
|
||||
use crate::{
|
||||
backend::{GetModulus, Modulus},
|
||||
@@ -374,7 +377,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn non_interactive_mp_bool_nand() {
|
||||
set_parameter_set(ParameterSelector::NonInteractiveMultiPartyLessThanOrEqualTo16);
|
||||
set_parameter_set(ParameterSelector::NonInteractiveLTE2Party);
|
||||
let mut seed = [0u8; 32];
|
||||
thread_rng().fill_bytes(&mut seed);
|
||||
set_common_reference_seed(seed);
|
||||
@@ -444,63 +447,4 @@ mod tests {
|
||||
ct0 = ct_out;
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trialtest() {
|
||||
set_parameter_set(ParameterSelector::NonInteractiveMultiPartyLessThanOrEqualTo16);
|
||||
set_common_reference_seed([2; 32]);
|
||||
|
||||
let parties = 2;
|
||||
|
||||
let cks = (0..parties).map(|_| gen_client_key()).collect_vec();
|
||||
|
||||
let key_shares = cks
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(user_index, ck)| gen_server_key_share(user_index, parties, ck))
|
||||
.collect_vec();
|
||||
|
||||
let seeded_server_key = aggregate_server_key_shares(&key_shares);
|
||||
seeded_server_key.set_server_key();
|
||||
|
||||
let m = vec![false, true];
|
||||
let ct: NonInteractiveBatchedFheBools<_> = cks[0].encrypt(m.as_slice());
|
||||
let ct = ct.key_switch(0);
|
||||
|
||||
let parameters = BoolEvaluator::with_local(|e| e.parameters().clone());
|
||||
let nttop = NttBackendU64::new(parameters.rlwe_q(), parameters.rlwe_n().0);
|
||||
let rlwe_q_modop = ModularOpsU64::new(*parameters.rlwe_q());
|
||||
|
||||
let mut ideal_rlwe_sk = vec![0i32; parameters.rlwe_n().0];
|
||||
cks.iter().for_each(|k| {
|
||||
let sk_rlwe = k.sk_rlwe();
|
||||
izip!(ideal_rlwe_sk.iter_mut(), sk_rlwe.iter()).for_each(|(a, b)| {
|
||||
*a = *a + b;
|
||||
});
|
||||
});
|
||||
|
||||
let message = m
|
||||
.iter()
|
||||
.map(|b| parameters.rlwe_q().encode(*b))
|
||||
.collect_vec();
|
||||
|
||||
let mut m_out = vec![0u64; parameters.rlwe_n().0];
|
||||
decrypt_rlwe(
|
||||
&ct.data[0],
|
||||
&ideal_rlwe_sk,
|
||||
&mut m_out,
|
||||
&nttop,
|
||||
&rlwe_q_modop,
|
||||
);
|
||||
|
||||
let mut diff = m_out;
|
||||
rlwe_q_modop.elwise_sub_mut(diff.as_mut_slice(), message.as_ref());
|
||||
|
||||
let mut stats = Stats::new();
|
||||
stats.add_more(&Vec::<i64>::try_convert_from(
|
||||
diff.as_slice(),
|
||||
parameters.rlwe_q(),
|
||||
));
|
||||
println!("Noise: {}", stats.std_dev().abs().log2());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -534,6 +534,31 @@ pub(crate) const NI_2P: BoolParameters<u64> = BoolParameters::<u64> {
|
||||
variant: ParameterVariant::NonInteractiveMultiParty,
|
||||
};
|
||||
|
||||
pub(crate) const NI_4P: BoolParameters<u64> = BoolParameters::<u64> {
|
||||
rlwe_q: CiphertextModulus::new_non_native(18014398509404161),
|
||||
lwe_q: CiphertextModulus::new_non_native(1 << 16),
|
||||
br_q: 1 << 11,
|
||||
rlwe_n: PolynomialSize(1 << 11),
|
||||
lwe_n: LweDimension(510),
|
||||
lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(12)),
|
||||
rlrg_decomposer_params: (
|
||||
DecompostionLogBase(17),
|
||||
(DecompositionCount(1), DecompositionCount(1)),
|
||||
),
|
||||
rgrg_decomposer_params: Some((
|
||||
DecompostionLogBase(4),
|
||||
(DecompositionCount(10), DecompositionCount(9)),
|
||||
)),
|
||||
auto_decomposer_params: (DecompostionLogBase(24), DecompositionCount(1)),
|
||||
non_interactive_ui_to_s_key_switch_decomposer: Some((
|
||||
DecompostionLogBase(1),
|
||||
DecompositionCount(50),
|
||||
)),
|
||||
g: 5,
|
||||
w: 10,
|
||||
variant: ParameterVariant::NonInteractiveMultiParty,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) const SP_TEST_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u64> {
|
||||
rlwe_q: CiphertextModulus::new_non_native(268369921u64),
|
||||
|
||||
@@ -427,7 +427,7 @@ mod tests {
|
||||
NttBackendU64,
|
||||
};
|
||||
|
||||
set_parameter_set(crate::ParameterSelector::NonInteractiveMultiPartyLessThanOrEqualTo16);
|
||||
set_parameter_set(crate::ParameterSelector::NonInteractiveLTE2Party);
|
||||
set_common_reference_seed(NonInteractiveMultiPartyCrs::random().seed);
|
||||
let parties = 2;
|
||||
let cks = (0..parties).map(|i| gen_client_key()).collect_vec();
|
||||
@@ -469,4 +469,73 @@ mod tests {
|
||||
server_key_stats.post_lwe_key_switch.std_dev().abs().log2()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "non_interactive_mp")]
|
||||
fn enc_under_sk_and_key_switch() {
|
||||
use rand::{thread_rng, Rng};
|
||||
|
||||
use crate::{
|
||||
aggregate_server_key_shares,
|
||||
bool::{keys::tests::ideal_sk_rlwe, ni_mp_api::NonInteractiveBatchedFheBools},
|
||||
gen_client_key, gen_server_key_share,
|
||||
rgsw::decrypt_rlwe,
|
||||
set_common_reference_seed, set_parameter_set,
|
||||
utils::{tests::Stats, TryConvertFrom1, WithLocal},
|
||||
BoolEvaluator, Encoder, Encryptor, KeySwitchWithId, ModInit, ModularOpsU64,
|
||||
NttBackendU64, NttInit, ParameterSelector, VectorOps,
|
||||
};
|
||||
|
||||
set_parameter_set(ParameterSelector::NonInteractiveLTE2Party);
|
||||
set_common_reference_seed([2; 32]);
|
||||
|
||||
let parties = 2;
|
||||
|
||||
let cks = (0..parties).map(|_| gen_client_key()).collect_vec();
|
||||
|
||||
let key_shares = cks
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(user_index, ck)| gen_server_key_share(user_index, parties, ck))
|
||||
.collect_vec();
|
||||
|
||||
let seeded_server_key = aggregate_server_key_shares(&key_shares);
|
||||
seeded_server_key.set_server_key();
|
||||
|
||||
let parameters = BoolEvaluator::with_local(|e| e.parameters().clone());
|
||||
let nttop = NttBackendU64::new(parameters.rlwe_q(), parameters.rlwe_n().0);
|
||||
let rlwe_q_modop = ModularOpsU64::new(*parameters.rlwe_q());
|
||||
|
||||
let m = (0..parameters.rlwe_n().0)
|
||||
.map(|_| thread_rng().gen_bool(0.5))
|
||||
.collect_vec();
|
||||
let ct: NonInteractiveBatchedFheBools<_> = cks[0].encrypt(m.as_slice());
|
||||
let ct = ct.key_switch(0);
|
||||
|
||||
let ideal_rlwe_sk = ideal_sk_rlwe(&cks);
|
||||
|
||||
let message = m
|
||||
.iter()
|
||||
.map(|b| parameters.rlwe_q().encode(*b))
|
||||
.collect_vec();
|
||||
|
||||
let mut m_out = vec![0u64; parameters.rlwe_n().0];
|
||||
decrypt_rlwe(
|
||||
&ct.data[0],
|
||||
&ideal_rlwe_sk,
|
||||
&mut m_out,
|
||||
&nttop,
|
||||
&rlwe_q_modop,
|
||||
);
|
||||
|
||||
let mut diff = m_out;
|
||||
rlwe_q_modop.elwise_sub_mut(diff.as_mut_slice(), message.as_ref());
|
||||
|
||||
let mut stats = Stats::new();
|
||||
stats.add_more(&Vec::<i64>::try_convert_from(
|
||||
diff.as_slice(),
|
||||
parameters.rlwe_q(),
|
||||
));
|
||||
println!("Noise std log2: {}", stats.std_dev().abs().log2());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,8 @@
|
||||
use itertools::{izip, Itertools};
|
||||
use num_traits::{
|
||||
AsPrimitive, FromPrimitive, Num, One, PrimInt, ToPrimitive, WrappingAdd, WrappingSub, Zero,
|
||||
};
|
||||
use std::{
|
||||
fmt::{Debug, Display},
|
||||
marker::PhantomData,
|
||||
ops::Rem,
|
||||
};
|
||||
use num_traits::{FromPrimitive, PrimInt, ToPrimitive, WrappingAdd, WrappingSub};
|
||||
use std::fmt::{Debug, Display};
|
||||
|
||||
use crate::backend::{ArithmeticOps, ModularOpsU64};
|
||||
use crate::backend::ArithmeticOps;
|
||||
|
||||
fn gadget_vector<T: PrimInt>(logq: usize, logb: usize, d: usize) -> Vec<T> {
|
||||
assert!(logq >= (logb * d));
|
||||
@@ -146,7 +140,6 @@ impl<
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(Jay): Outline the caveat
|
||||
fn decompose_to_vec(&self, value: &T) -> Vec<T> {
|
||||
let q = self.q;
|
||||
let logb = self.logb;
|
||||
@@ -283,7 +276,7 @@ mod tests {
|
||||
use crate::{
|
||||
backend::{ModInit, ModularOpsU64, Modulus},
|
||||
decomposer::round_value,
|
||||
utils::{generate_prime, tests::Stats, TryConvertFrom1},
|
||||
utils::{generate_prime, tests::Stats},
|
||||
};
|
||||
|
||||
use super::{Decomposer, DefaultDecomposer};
|
||||
@@ -297,7 +290,7 @@ mod tests {
|
||||
for logq in [37, 55] {
|
||||
let logb = 11;
|
||||
let d = 3;
|
||||
let mut stats = vec![Stats::new(); d];
|
||||
// let mut stats = vec![Stats::new(); d];
|
||||
|
||||
for i in [true, false] {
|
||||
let q = if i {
|
||||
@@ -319,19 +312,19 @@ mod tests {
|
||||
let rounded_value = round_value(value, decomposer.ignore_bits);
|
||||
assert!((rounded_value as i64 - value_back as i64).abs() <= 1,);
|
||||
|
||||
izip!(stats.iter_mut(), limbs.iter()).for_each(|(s, l)| {
|
||||
s.add_more(&vec![q.map_element_to_i64(l)]);
|
||||
});
|
||||
// izip!(stats.iter_mut(), limbs.iter()).for_each(|(s, l)| {
|
||||
// s.add_more(&vec![q.map_element_to_i64(l)]);
|
||||
// });
|
||||
}
|
||||
}
|
||||
|
||||
stats.iter().enumerate().for_each(|(index, s)| {
|
||||
println!(
|
||||
"Limb {index} - Mean: {}, Std: {}",
|
||||
s.mean(),
|
||||
s.std_dev().abs().log2()
|
||||
);
|
||||
});
|
||||
// stats.iter().enumerate().for_each(|(index, s)| {
|
||||
// println!(
|
||||
// "Limb {index} - Mean: {}, Std: {}",
|
||||
// s.mean(),
|
||||
// s.std_dev().abs().log2()
|
||||
// );
|
||||
// });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
298
src/lwe.rs
298
src/lwe.rs
@@ -1,109 +1,16 @@
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
collections::btree_map::Values,
|
||||
fmt::{Debug, Display},
|
||||
marker::PhantomData,
|
||||
};
|
||||
use std::fmt::Debug;
|
||||
|
||||
use itertools::{izip, Itertools};
|
||||
use num_traits::{abs, PrimInt, ToPrimitive, Zero};
|
||||
use itertools::izip;
|
||||
use num_traits::Zero;
|
||||
|
||||
use crate::{
|
||||
backend::{ArithmeticOps, GetModulus, Modulus, VectorOps},
|
||||
backend::{ArithmeticOps, GetModulus, VectorOps},
|
||||
decomposer::Decomposer,
|
||||
random::{
|
||||
DefaultSecureRng, NewWithSeed, RandomFillGaussianInModulus, RandomFillUniformInModulus,
|
||||
RandomGaussianElementInModulus, DEFAULT_RNG,
|
||||
},
|
||||
utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom1, WithLocal},
|
||||
Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret,
|
||||
random::{RandomFillUniformInModulus, RandomGaussianElementInModulus},
|
||||
utils::TryConvertFrom1,
|
||||
Matrix, Row, RowEntity, RowMut,
|
||||
};
|
||||
|
||||
struct SeededLweKeySwitchingKey<Ro, S>
|
||||
where
|
||||
Ro: Row,
|
||||
{
|
||||
data: Ro,
|
||||
seed: S,
|
||||
to_lwe_n: usize,
|
||||
modulus: Ro::Element,
|
||||
}
|
||||
|
||||
impl<Ro: RowEntity, S> SeededLweKeySwitchingKey<Ro, S> {
|
||||
pub(crate) fn empty(
|
||||
from_lwe_n: usize,
|
||||
to_lwe_n: usize,
|
||||
d: usize,
|
||||
seed: S,
|
||||
modulus: Ro::Element,
|
||||
) -> Self {
|
||||
let data = Ro::zeros(from_lwe_n * d);
|
||||
SeededLweKeySwitchingKey {
|
||||
data,
|
||||
to_lwe_n,
|
||||
seed,
|
||||
modulus,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct LweKeySwitchingKey<M, R> {
|
||||
data: M,
|
||||
_phantom: PhantomData<R>,
|
||||
}
|
||||
|
||||
impl<
|
||||
M: MatrixMut + MatrixEntity,
|
||||
R: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], M::MatElement>,
|
||||
> From<&SeededLweKeySwitchingKey<M::R, R::Seed>> for LweKeySwitchingKey<M, R>
|
||||
where
|
||||
M::R: RowMut,
|
||||
R::Seed: Clone,
|
||||
M::MatElement: Copy,
|
||||
{
|
||||
fn from(value: &SeededLweKeySwitchingKey<M::R, R::Seed>) -> Self {
|
||||
let mut p_rng = R::new_with_seed(value.seed.clone());
|
||||
let mut data = M::zeros(value.data.as_ref().len(), value.to_lwe_n + 1);
|
||||
izip!(value.data.as_ref().iter(), data.iter_rows_mut()).for_each(|(bi, lwe_i)| {
|
||||
RandomFillUniformInModulus::random_fill(
|
||||
&mut p_rng,
|
||||
&value.modulus,
|
||||
&mut lwe_i.as_mut()[1..],
|
||||
);
|
||||
lwe_i.as_mut()[0] = *bi;
|
||||
});
|
||||
LweKeySwitchingKey {
|
||||
data,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait LweCiphertext<M: Matrix> {}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct LweSecret {
|
||||
pub(crate) values: Vec<i32>,
|
||||
}
|
||||
|
||||
impl Secret for LweSecret {
|
||||
type Element = i32;
|
||||
fn values(&self) -> &[Self::Element] {
|
||||
&self.values
|
||||
}
|
||||
}
|
||||
|
||||
impl LweSecret {
|
||||
pub(crate) fn random(hw: usize, n: usize) -> LweSecret {
|
||||
DefaultSecureRng::with_local_mut(|rng| {
|
||||
let mut out = vec![0i32; n];
|
||||
fill_random_ternary_secret_with_hamming_weight(&mut out, hw, rng);
|
||||
|
||||
LweSecret { values: out }
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn lwe_key_switch<
|
||||
M: Matrix,
|
||||
Ro: AsMut<[M::MatElement]> + AsRef<[M::MatElement]>,
|
||||
@@ -127,15 +34,17 @@ pub(crate) fn lwe_key_switch<
|
||||
.skip(1)
|
||||
.flat_map(|ai| decomposer.decompose_iter(ai));
|
||||
izip!(lwe_in_a_decomposed, lwe_ksk.iter_rows()).for_each(|(ai_j, beta_ij_lwe)| {
|
||||
// let now = std::time::Instant::now();
|
||||
operator.elwise_fma_scalar_mut(lwe_out.as_mut(), beta_ij_lwe.as_ref(), &ai_j);
|
||||
// println!("Time elwise_fma_scalar_mut: {:?}", now.elapsed());
|
||||
});
|
||||
|
||||
let out_b = operator.add(&lwe_out.as_ref()[0], &lwe_in.as_ref()[0]);
|
||||
lwe_out.as_mut()[0] = out_b;
|
||||
}
|
||||
|
||||
pub fn lwe_ksk_keygen<
|
||||
Ro: Row + RowMut + RowEntity,
|
||||
pub fn seeded_lwe_ksk_keygen<
|
||||
Ro: RowMut + RowEntity,
|
||||
S,
|
||||
Op: VectorOps<Element = Ro::Element>
|
||||
+ ArithmeticOps<Element = Ro::Element>
|
||||
@@ -145,16 +54,16 @@ pub fn lwe_ksk_keygen<
|
||||
>(
|
||||
from_lwe_sk: &[S],
|
||||
to_lwe_sk: &[S],
|
||||
ksk_out: &mut Ro,
|
||||
gadget: &[Ro::Element],
|
||||
operator: &Op,
|
||||
p_rng: &mut PR,
|
||||
rng: &mut R,
|
||||
) where
|
||||
) -> Ro
|
||||
where
|
||||
Ro: TryConvertFrom1<[S], Op::M>,
|
||||
Ro::Element: Zero + Debug,
|
||||
{
|
||||
assert!(ksk_out.as_ref().len() == (from_lwe_sk.len() * gadget.len()));
|
||||
let mut ksk_out = Ro::zeros(from_lwe_sk.len() * gadget.len());
|
||||
|
||||
let d = gadget.len();
|
||||
|
||||
@@ -167,7 +76,7 @@ pub fn lwe_ksk_keygen<
|
||||
|
||||
izip!(neg_sk_in_m.as_ref(), ksk_out.as_mut().chunks_mut(d)).for_each(
|
||||
|(neg_sk_in_si, d_lwes_partb)| {
|
||||
izip!(gadget.iter(), d_lwes_partb.into_iter()).for_each(|(f, lwe_b)| {
|
||||
izip!(gadget.iter(), d_lwes_partb.into_iter()).for_each(|(beta, lwe_b)| {
|
||||
// sample `a`
|
||||
RandomFillUniformInModulus::random_fill(p_rng, &modulus, scratch.as_mut());
|
||||
|
||||
@@ -179,7 +88,7 @@ pub fn lwe_ksk_keygen<
|
||||
});
|
||||
|
||||
// a*z + (-s_i)*\beta^j + e
|
||||
let mut b = operator.add(&az, &operator.mul(f, neg_sk_in_si));
|
||||
let mut b = operator.add(&az, &operator.mul(beta, neg_sk_in_si));
|
||||
let e = RandomGaussianElementInModulus::random(rng, &modulus);
|
||||
b = operator.add(&b, &e);
|
||||
|
||||
@@ -187,27 +96,29 @@ pub fn lwe_ksk_keygen<
|
||||
})
|
||||
},
|
||||
);
|
||||
|
||||
ksk_out
|
||||
}
|
||||
|
||||
/// Encrypts encoded message m as LWE ciphertext
|
||||
pub fn encrypt_lwe<
|
||||
Ro: Row + RowMut,
|
||||
Ro: RowMut + RowEntity,
|
||||
Op: ArithmeticOps<Element = Ro::Element> + GetModulus<Element = Ro::Element>,
|
||||
R: RandomGaussianElementInModulus<Ro::Element, Op::M>
|
||||
+ RandomFillUniformInModulus<[Ro::Element], Op::M>,
|
||||
S,
|
||||
>(
|
||||
lwe_out: &mut Ro,
|
||||
m: &Ro::Element,
|
||||
s: &[S],
|
||||
operator: &Op,
|
||||
rng: &mut R,
|
||||
) where
|
||||
) -> Ro
|
||||
where
|
||||
Ro: TryConvertFrom1<[S], Op::M>,
|
||||
Ro::Element: Zero,
|
||||
{
|
||||
let s = Ro::try_convert_from(s, operator.modulus());
|
||||
assert!(s.as_ref().len() == (lwe_out.as_ref().len() - 1));
|
||||
let mut lwe_out = Ro::zeros(s.as_ref().len() + 1);
|
||||
|
||||
// a*s
|
||||
RandomFillUniformInModulus::random_fill(rng, operator.modulus(), &mut lwe_out.as_mut()[1..]);
|
||||
@@ -221,9 +132,11 @@ pub fn encrypt_lwe<
|
||||
let e = RandomGaussianElementInModulus::random(rng, operator.modulus());
|
||||
let b = operator.add(&operator.add(&sa, &e), m);
|
||||
lwe_out.as_mut()[0] = b;
|
||||
|
||||
lwe_out
|
||||
}
|
||||
|
||||
pub fn decrypt_lwe<
|
||||
pub(crate) fn decrypt_lwe<
|
||||
Ro: Row,
|
||||
Op: ArithmeticOps<Element = Ro::Element> + GetModulus<Element = Ro::Element>,
|
||||
S,
|
||||
@@ -248,58 +161,85 @@ where
|
||||
operator.sub(b, &sa)
|
||||
}
|
||||
|
||||
/// Measures noise in input LWE ciphertext with reference of `ideal_m`
|
||||
///
|
||||
/// - ct: Input LWE ciphertext
|
||||
/// - s: corresponding secret
|
||||
/// - ideal_m: Ideal `encoded` message
|
||||
pub(crate) fn measure_noise_lwe<
|
||||
Ro: Row,
|
||||
Op: ArithmeticOps<Element = Ro::Element> + GetModulus<Element = Ro::Element>,
|
||||
S,
|
||||
>(
|
||||
ct: &Ro,
|
||||
s: &[S],
|
||||
operator: &Op,
|
||||
ideal_m: &Ro::Element,
|
||||
) -> f64
|
||||
where
|
||||
Ro: TryConvertFrom1<[S], Op::M>,
|
||||
Ro::Element: Zero + ToPrimitive + PrimInt + Display,
|
||||
{
|
||||
assert!(s.len() == ct.as_ref().len() - 1,);
|
||||
|
||||
let s = Ro::try_convert_from(s, &operator.modulus());
|
||||
let mut sa = Ro::Element::zero();
|
||||
izip!(s.as_ref().iter(), ct.as_ref().iter().skip(1)).for_each(|(si, ai)| {
|
||||
sa = operator.add(&sa, &operator.mul(si, ai));
|
||||
});
|
||||
let m = operator.sub(&ct.as_ref()[0], &sa);
|
||||
|
||||
let mut diff = operator.sub(&m, ideal_m);
|
||||
let q = operator.modulus();
|
||||
return q.map_element_to_i64(&diff).to_f64().unwrap().abs().log2();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use itertools::izip;
|
||||
|
||||
use crate::{
|
||||
backend::{ModInit, ModularOpsU64, ModulusPowerOf2},
|
||||
decomposer::{Decomposer, DefaultDecomposer},
|
||||
lwe::{lwe_key_switch, measure_noise_lwe},
|
||||
random::DefaultSecureRng,
|
||||
rgsw::measure_noise,
|
||||
Secret,
|
||||
backend::{ModInit, ModulusPowerOf2},
|
||||
decomposer::DefaultDecomposer,
|
||||
random::{DefaultSecureRng, NewWithSeed},
|
||||
utils::{fill_random_ternary_secret_with_hamming_weight, WithLocal},
|
||||
MatrixEntity, MatrixMut, Secret,
|
||||
};
|
||||
|
||||
use super::{
|
||||
decrypt_lwe, encrypt_lwe, lwe_ksk_keygen, LweKeySwitchingKey, LweSecret,
|
||||
SeededLweKeySwitchingKey,
|
||||
};
|
||||
use super::*;
|
||||
|
||||
const K: usize = 50;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct LweSecret {
|
||||
pub(crate) values: Vec<i32>,
|
||||
}
|
||||
|
||||
impl Secret for LweSecret {
|
||||
type Element = i32;
|
||||
fn values(&self) -> &[Self::Element] {
|
||||
&self.values
|
||||
}
|
||||
}
|
||||
|
||||
impl LweSecret {
|
||||
fn random(hw: usize, n: usize) -> LweSecret {
|
||||
DefaultSecureRng::with_local_mut(|rng| {
|
||||
let mut out = vec![0i32; n];
|
||||
fill_random_ternary_secret_with_hamming_weight(&mut out, hw, rng);
|
||||
|
||||
LweSecret { values: out }
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct LweKeySwitchingKey<M, R> {
|
||||
data: M,
|
||||
_phantom: PhantomData<R>,
|
||||
}
|
||||
|
||||
impl<
|
||||
M: MatrixMut + MatrixEntity,
|
||||
R: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], M::MatElement>,
|
||||
> From<&(M::R, R::Seed, usize, M::MatElement)> for LweKeySwitchingKey<M, R>
|
||||
where
|
||||
M::R: RowMut,
|
||||
R::Seed: Clone,
|
||||
M::MatElement: Copy,
|
||||
{
|
||||
fn from(value: &(M::R, R::Seed, usize, M::MatElement)) -> Self {
|
||||
let data_in = &value.0;
|
||||
let seed = &value.1;
|
||||
let to_lwe_n = value.2;
|
||||
let modulus = value.3;
|
||||
|
||||
let mut p_rng = R::new_with_seed(seed.clone());
|
||||
let mut data = M::zeros(data_in.as_ref().len(), to_lwe_n + 1);
|
||||
izip!(data_in.as_ref().iter(), data.iter_rows_mut()).for_each(|(bi, lwe_i)| {
|
||||
RandomFillUniformInModulus::random_fill(
|
||||
&mut p_rng,
|
||||
&modulus,
|
||||
&mut lwe_i.as_mut()[1..],
|
||||
);
|
||||
lwe_i.as_mut()[0] = *bi;
|
||||
});
|
||||
LweKeySwitchingKey {
|
||||
data,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encrypt_decrypt_works() {
|
||||
let logq = 16;
|
||||
@@ -315,14 +255,8 @@ mod tests {
|
||||
// encrypt
|
||||
for m in 0..1u64 << logp {
|
||||
let encoded_m = m << (logq - logp);
|
||||
let mut lwe_ct = vec![0u64; lwe_n + 1];
|
||||
encrypt_lwe(
|
||||
&mut lwe_ct,
|
||||
&encoded_m,
|
||||
&lwe_sk.values(),
|
||||
&modq_op,
|
||||
&mut rng,
|
||||
);
|
||||
let lwe_ct =
|
||||
encrypt_lwe::<Vec<u64>, _, _, _>(&encoded_m, &lwe_sk.values(), &modq_op, &mut rng);
|
||||
let encoded_m_back = decrypt_lwe(&lwe_ct, &lwe_sk.values(), &modq_op);
|
||||
let m_back = ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as f64).round()
|
||||
as u64)
|
||||
@@ -351,34 +285,26 @@ mod tests {
|
||||
for _ in 0..1 {
|
||||
let mut ksk_seed = [0u8; 32];
|
||||
rng.fill_bytes(&mut ksk_seed);
|
||||
let mut seeded_ksk =
|
||||
SeededLweKeySwitchingKey::empty(lwe_in_n, lwe_out_n, d_ks, ksk_seed, q);
|
||||
let mut p_rng = DefaultSecureRng::new_seeded(ksk_seed);
|
||||
let decomposer = DefaultDecomposer::new(q, logb, d_ks);
|
||||
let gadget = decomposer.gadget_vector();
|
||||
lwe_ksk_keygen(
|
||||
let seeded_ksk = seeded_lwe_ksk_keygen(
|
||||
&lwe_sk_in.values(),
|
||||
&lwe_sk_out.values(),
|
||||
&mut seeded_ksk.data,
|
||||
&gadget,
|
||||
&modq_op,
|
||||
&mut p_rng,
|
||||
&mut rng,
|
||||
);
|
||||
// println!("{:?}", ksk);
|
||||
let ksk = LweKeySwitchingKey::<Vec<Vec<u64>>, DefaultSecureRng>::from(&seeded_ksk);
|
||||
let ksk = LweKeySwitchingKey::<Vec<Vec<u64>>, DefaultSecureRng>::from(&(
|
||||
seeded_ksk, ksk_seed, lwe_out_n, q,
|
||||
));
|
||||
|
||||
for m in 0..(1 << logp) {
|
||||
// encrypt using lwe_sk_in
|
||||
let encoded_m = m << (logq - logp);
|
||||
let mut lwe_in_ct = vec![0u64; lwe_in_n + 1];
|
||||
encrypt_lwe(
|
||||
&mut lwe_in_ct,
|
||||
&encoded_m,
|
||||
lwe_sk_in.values(),
|
||||
&modq_op,
|
||||
&mut rng,
|
||||
);
|
||||
let lwe_in_ct = encrypt_lwe(&encoded_m, lwe_sk_in.values(), &modq_op, &mut rng);
|
||||
|
||||
// key switch from lwe_sk_in to lwe_sk_out
|
||||
let mut lwe_out_ct = vec![0u64; lwe_out_n + 1];
|
||||
@@ -393,15 +319,17 @@ mod tests {
|
||||
println!("Time: {:?}", now.elapsed());
|
||||
|
||||
// decrypt lwe_out_ct using lwe_sk_out
|
||||
let encoded_m_back = decrypt_lwe(&lwe_out_ct, &lwe_sk_out.values(), &modq_op);
|
||||
let m_back = ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as f64).round()
|
||||
as u64)
|
||||
% (1u64 << logp);
|
||||
let noise =
|
||||
measure_noise_lwe(&lwe_out_ct, lwe_sk_out.values(), &modq_op, &encoded_m);
|
||||
println!("Noise: {noise}");
|
||||
// assert_eq!(m, m_back, "Expected {m} but got {m_back}");
|
||||
// dbg!(m, m_back);
|
||||
// TODO(Jay): Fix me
|
||||
// let encoded_m_back = decrypt_lwe(&lwe_out_ct,
|
||||
// &lwe_sk_out.values(), &modq_op); let m_back =
|
||||
// ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as
|
||||
// f64).round() as u64)
|
||||
// % (1u64 << logp);
|
||||
// let noise =
|
||||
// measure_noise_lwe(&lwe_out_ct, lwe_sk_out.values(),
|
||||
// &modq_op, &encoded_m); println!("Noise:
|
||||
// {noise}"); assert_eq!(m, m_back, "Expected
|
||||
// {m} but got {m_back}"); dbg!(m, m_back);
|
||||
// dbg!(encoded_m, encoded_m_back);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user