Browse Source

clean lwe

par-agg-key-shares
Janmajaya Mall 9 months ago
parent
commit
8e6cde2d89
9 changed files with 257 additions and 325 deletions
  1. +1
    -1
      examples/non_interactive_fheuint8.rs
  2. +16
    -43
      src/bool/evaluator.rs
  3. +0
    -1
      src/bool/keys.rs
  4. +2
    -1
      src/bool/mod.rs
  5. +15
    -71
      src/bool/ni_mp_api.rs
  6. +25
    -0
      src/bool/parameters.rs
  7. +70
    -1
      src/bool/print_noise.rs
  8. +15
    -22
      src/decomposer.rs
  9. +113
    -185
      src/lwe.rs

+ 1
- 1
examples/non_interactive_fheuint8.rs

@ -11,7 +11,7 @@ fn fhe_circuit(a: &FheUint8, b: &FheUint8, c: &FheUint8, d: &FheUint8) -> FheUin
}
fn main() {
set_parameter_set(ParameterSelector::NonInteractiveMultiPartyLessThanOrEqualTo16);
set_parameter_set(ParameterSelector::NonInteractiveLTE2Party);
// set CRS
let mut seed = [0u8; 32];

+ 16
- 43
src/bool/evaluator.rs

@ -5,23 +5,18 @@ use std::{
usize,
};
use itertools::{izip, partition, Itertools};
use num_traits::{
zero, FromPrimitive, Num, One, Pow, PrimInt, ToPrimitive, WrappingAdd, WrappingSub, Zero,
};
use rand::Rng;
use itertools::{izip, Itertools};
use num_traits::{FromPrimitive, One, PrimInt, ToPrimitive, WrappingAdd, WrappingSub, Zero};
use rand_distr::uniform::SampleUniform;
use crate::{
backend::{
ArithmeticOps, GetModulus, ModInit, ModularOpsU64, Modulus, ShoupMatrixFMA, VectorOps,
},
backend::{ArithmeticOps, GetModulus, ModInit, Modulus, ShoupMatrixFMA, VectorOps},
bool::parameters::ParameterVariant,
decomposer::{Decomposer, DefaultDecomposer, NumInfo, RlweDecomposer},
lwe::{decrypt_lwe, encrypt_lwe, lwe_key_switch, lwe_ksk_keygen, measure_noise_lwe, LweSecret},
lwe::{decrypt_lwe, encrypt_lwe, seeded_lwe_ksk_keygen},
multi_party::{
non_interactive_ksk_gen, non_interactive_ksk_zero_encryptions_for_other_party_i,
non_interactive_rgsw_ct, public_key_share,
public_key_share,
},
ntt::{self, Ntt, NttBackendU64, NttInit},
pbs::{pbs, sample_extract, PbsInfo, PbsKey, WithShoupRepr},
@ -48,8 +43,7 @@ use super::{
CommonReferenceSeededNonInteractiveMultiPartyServerKeyShare,
InteractiveMultiPartyClientKey, NonInteractiveMultiPartyClientKey,
SeededMultiPartyServerKey, SeededNonInteractiveMultiPartyServerKey,
SeededSinglePartyServerKey, ServerKeyEvaluationDomain, ShoupServerKeyEvaluationDomain,
SinglePartyClientKey,
SeededSinglePartyServerKey, SinglePartyClientKey,
},
parameters::{
BoolParameters, CiphertextModulus, DecompositionCount, DecompostionLogBase,
@ -67,9 +61,8 @@ use super::{
/// Initial Seed:
/// Puncture 1 -> Public key share seed
/// Puncture 2 -> Main server key share seed
/// Puncture 1 -> RGSW cuphertexts seed
/// Puncture 2 -> Auto keys cipertexts seed
/// Puncture 3 -> LWE ksk seed
/// Puncture 1 -> Auto keys cipertexts seed
/// Puncture 2 -> LWE ksk seed
#[derive(Clone, PartialEq)]
pub struct MultiPartyCrs<S> {
pub(super) seed: S,
@ -97,19 +90,14 @@ impl MultiPartyCrs {
puncture_p_rng(&mut prng, 2)
}
pub(super) fn rgsw_cts_seed<Rng: NewWithSeed<Seed = S> + RandomFill<S>>(&self) -> S {
let mut key_prng = Rng::new_with_seed(self.key_seed::<Rng>());
puncture_p_rng(&mut key_prng, 1)
}
pub(super) fn auto_keys_cts_seed<Rng: NewWithSeed<Seed = S> + RandomFill<S>>(&self) -> S {
let mut key_prng = Rng::new_with_seed(self.key_seed::<Rng>());
puncture_p_rng(&mut key_prng, 2)
puncture_p_rng(&mut key_prng, 1)
}
pub(super) fn lwe_ksk_cts_seed_seed<Rng: NewWithSeed<Seed = S> + RandomFill<S>>(&self) -> S {
let mut key_prng = Rng::new_with_seed(self.key_seed::<Rng>());
puncture_p_rng(&mut key_prng, 3)
puncture_p_rng(&mut key_prng, 2)
}
}
@ -119,7 +107,8 @@ impl MultiPartyCrs {
/// Puncture 1 -> Key Seed
/// Puncture 1 -> Rgsw ciphertext seed
/// Puncture l+1 -> Seed for zero encs and non-interactive
/// multi-party RGSW ciphertext corresponding to l^th LWE index.
/// multi-party RGSW ciphertexts of
/// l^th LWE index.
/// Puncture 2 -> auto keys seed
/// Puncture 3 -> Lwe key switching key seed
/// Puncture 2 -> user specific seed for u_j to s ksk
@ -931,12 +920,9 @@ where
// LWE KSK from RLWE secret s -> LWE secret z
let d_lwe_gadget = self.pbs_info.lwe_decomposer.gadget_vector();
let mut lwe_ksk =
M::R::zeros(self.pbs_info.lwe_decomposer.decomposition_count() * ring_size);
lwe_ksk_keygen(
let lwe_ksk = seeded_lwe_ksk_keygen(
&sk_rlwe,
&sk_lwe,
&mut lwe_ksk,
&d_lwe_gadget,
&self.pbs_info.lwe_modop,
&mut main_prng,
@ -2049,21 +2035,16 @@ where
) -> M::R {
DefaultSecureRng::with_local_mut(|rng| {
let mut p_rng = DefaultSecureRng::new_seeded(lwe_ksk_seed);
let mut lwe_ksk = M::R::zeros(
self.pbs_info.lwe_decomposer.decomposition_count() * self.parameters().rlwe_n().0,
);
let lwe_modop = &self.pbs_info.lwe_modop;
let d_lwe_gadget_vec = self.pbs_info.lwe_decomposer.gadget_vector();
lwe_ksk_keygen(
seeded_lwe_ksk_keygen(
sk_rlwe,
sk_lwe,
&mut lwe_ksk,
&d_lwe_gadget_vec,
lwe_modop,
&mut p_rng,
rng,
);
lwe_ksk
)
})
}
@ -2238,15 +2219,7 @@ where
};
DefaultSecureRng::with_local_mut(|rng| {
let mut lwe_out = M::R::zeros(self.pbs_info.parameters.rlwe_n().0 + 1);
encrypt_lwe(
&mut lwe_out,
&m,
&client_key.sk_rlwe(),
&self.pbs_info.rlwe_modop,
rng,
);
lwe_out
encrypt_lwe(&m, &client_key.sk_rlwe(), &self.pbs_info.rlwe_modop, rng)
})
}

+ 0
- 1
src/bool/keys.rs

@ -2,7 +2,6 @@ use std::{collections::HashMap, hash::Hash, marker::PhantomData};
use crate::{
backend::{ModInit, VectorOps},
lwe::LweSecret,
pbs::WithShoupRepr,
random::{NewWithSeed, RandomFillUniformInModulus},
rgsw::RlweSecret,

+ 2
- 1
src/bool/mod.rs

@ -22,7 +22,8 @@ pub type ClientKey = keys::ClientKey<[u8; 32], u64>;
pub enum ParameterSelector {
HighCommunicationButFast2Party,
MultiPartyLessThanOrEqualTo16,
NonInteractiveMultiPartyLessThanOrEqualTo16,
NonInteractiveLTE2Party,
NonInteractiveLTE4Party,
}
mod common_mp_enc_dec {

+ 15
- 71
src/bool/ni_mp_api.rs

@ -3,6 +3,7 @@ use std::{cell::RefCell, sync::OnceLock};
use crate::{
backend::ModulusPowerOf2,
bool::parameters::ParameterVariant,
parameters::NI_4P,
random::DefaultSecureRng,
utils::{Global, WithLocal},
ModularOpsU64, NttBackendU64,
@ -38,9 +39,13 @@ static MULTI_PARTY_CRS: OnceLock> = OnceLo
pub fn set_parameter_set(select: ParameterSelector) {
match select {
ParameterSelector::NonInteractiveMultiPartyLessThanOrEqualTo16 => {
ParameterSelector::NonInteractiveLTE2Party => {
BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(NI_2P)));
}
ParameterSelector::NonInteractiveLTE4Party => {
BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(NI_4P)));
}
_ => {
panic!("Paramerters not supported")
}
@ -160,6 +165,13 @@ impl Global for RuntimeServerKey {
}
}
pub(super) struct NonInteractiveBatchedFheBools<C> {
data: Vec<C>,
}
pub(super) struct BatchedFheBools<C> {
pub(in super::super) data: Vec<C>,
}
/// Non interactive multi-party specfic encryptor decryptor routines
mod impl_enc_dec {
use crate::{
@ -177,10 +189,6 @@ mod impl_enc_dec {
type Mat = Vec<Vec<u64>>;
pub(super) struct BatchedFheBools<C> {
pub(super) data: Vec<C>,
}
impl<C: MatrixMut<MatElement = u64>> BatchedFheBools<C>
where
C::R: RowEntity + RowMut,
@ -202,10 +210,6 @@ mod impl_enc_dec {
}
}
pub(super) struct NonInteractiveBatchedFheBools<C> {
data: Vec<C>,
}
impl<M: MatrixEntity + MatrixMut<MatElement = u64>> From<&(Vec<M::R>, [u8; 32])>
for NonInteractiveBatchedFheBools<M>
where
@ -349,10 +353,9 @@ mod impl_enc_dec {
#[cfg(test)]
mod tests {
use impl_enc_dec::NonInteractiveBatchedFheBools;
use itertools::{izip, Itertools};
use num_traits::{FromPrimitive, PrimInt, ToPrimitive, Zero};
use rand::{thread_rng, RngCore};
use rand::{thread_rng, Rng, RngCore};
use crate::{
backend::{GetModulus, Modulus},
@ -374,7 +377,7 @@ mod tests {
#[test]
fn non_interactive_mp_bool_nand() {
set_parameter_set(ParameterSelector::NonInteractiveMultiPartyLessThanOrEqualTo16);
set_parameter_set(ParameterSelector::NonInteractiveLTE2Party);
let mut seed = [0u8; 32];
thread_rng().fill_bytes(&mut seed);
set_common_reference_seed(seed);
@ -444,63 +447,4 @@ mod tests {
ct0 = ct_out;
}
}
#[test]
fn trialtest() {
set_parameter_set(ParameterSelector::NonInteractiveMultiPartyLessThanOrEqualTo16);
set_common_reference_seed([2; 32]);
let parties = 2;
let cks = (0..parties).map(|_| gen_client_key()).collect_vec();
let key_shares = cks
.iter()
.enumerate()
.map(|(user_index, ck)| gen_server_key_share(user_index, parties, ck))
.collect_vec();
let seeded_server_key = aggregate_server_key_shares(&key_shares);
seeded_server_key.set_server_key();
let m = vec![false, true];
let ct: NonInteractiveBatchedFheBools<_> = cks[0].encrypt(m.as_slice());
let ct = ct.key_switch(0);
let parameters = BoolEvaluator::with_local(|e| e.parameters().clone());
let nttop = NttBackendU64::new(parameters.rlwe_q(), parameters.rlwe_n().0);
let rlwe_q_modop = ModularOpsU64::new(*parameters.rlwe_q());
let mut ideal_rlwe_sk = vec![0i32; parameters.rlwe_n().0];
cks.iter().for_each(|k| {
let sk_rlwe = k.sk_rlwe();
izip!(ideal_rlwe_sk.iter_mut(), sk_rlwe.iter()).for_each(|(a, b)| {
*a = *a + b;
});
});
let message = m
.iter()
.map(|b| parameters.rlwe_q().encode(*b))
.collect_vec();
let mut m_out = vec![0u64; parameters.rlwe_n().0];
decrypt_rlwe(
&ct.data[0],
&ideal_rlwe_sk,
&mut m_out,
&nttop,
&rlwe_q_modop,
);
let mut diff = m_out;
rlwe_q_modop.elwise_sub_mut(diff.as_mut_slice(), message.as_ref());
let mut stats = Stats::new();
stats.add_more(&Vec::<i64>::try_convert_from(
diff.as_slice(),
parameters.rlwe_q(),
));
println!("Noise: {}", stats.std_dev().abs().log2());
}
}

+ 25
- 0
src/bool/parameters.rs

@ -534,6 +534,31 @@ pub(crate) const NI_2P: BoolParameters = BoolParameters:: {
variant: ParameterVariant::NonInteractiveMultiParty,
};
pub(crate) const NI_4P: BoolParameters<u64> = BoolParameters::<u64> {
rlwe_q: CiphertextModulus::new_non_native(18014398509404161),
lwe_q: CiphertextModulus::new_non_native(1 << 16),
br_q: 1 << 11,
rlwe_n: PolynomialSize(1 << 11),
lwe_n: LweDimension(510),
lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(12)),
rlrg_decomposer_params: (
DecompostionLogBase(17),
(DecompositionCount(1), DecompositionCount(1)),
),
rgrg_decomposer_params: Some((
DecompostionLogBase(4),
(DecompositionCount(10), DecompositionCount(9)),
)),
auto_decomposer_params: (DecompostionLogBase(24), DecompositionCount(1)),
non_interactive_ui_to_s_key_switch_decomposer: Some((
DecompostionLogBase(1),
DecompositionCount(50),
)),
g: 5,
w: 10,
variant: ParameterVariant::NonInteractiveMultiParty,
};
#[cfg(test)]
pub(crate) const SP_TEST_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u64> {
rlwe_q: CiphertextModulus::new_non_native(268369921u64),

+ 70
- 1
src/bool/print_noise.rs

@ -427,7 +427,7 @@ mod tests {
NttBackendU64,
};
set_parameter_set(crate::ParameterSelector::NonInteractiveMultiPartyLessThanOrEqualTo16);
set_parameter_set(crate::ParameterSelector::NonInteractiveLTE2Party);
set_common_reference_seed(NonInteractiveMultiPartyCrs::random().seed);
let parties = 2;
let cks = (0..parties).map(|i| gen_client_key()).collect_vec();
@ -469,4 +469,73 @@ mod tests {
server_key_stats.post_lwe_key_switch.std_dev().abs().log2()
);
}
#[test]
#[cfg(feature = "non_interactive_mp")]
fn enc_under_sk_and_key_switch() {
use rand::{thread_rng, Rng};
use crate::{
aggregate_server_key_shares,
bool::{keys::tests::ideal_sk_rlwe, ni_mp_api::NonInteractiveBatchedFheBools},
gen_client_key, gen_server_key_share,
rgsw::decrypt_rlwe,
set_common_reference_seed, set_parameter_set,
utils::{tests::Stats, TryConvertFrom1, WithLocal},
BoolEvaluator, Encoder, Encryptor, KeySwitchWithId, ModInit, ModularOpsU64,
NttBackendU64, NttInit, ParameterSelector, VectorOps,
};
set_parameter_set(ParameterSelector::NonInteractiveLTE2Party);
set_common_reference_seed([2; 32]);
let parties = 2;
let cks = (0..parties).map(|_| gen_client_key()).collect_vec();
let key_shares = cks
.iter()
.enumerate()
.map(|(user_index, ck)| gen_server_key_share(user_index, parties, ck))
.collect_vec();
let seeded_server_key = aggregate_server_key_shares(&key_shares);
seeded_server_key.set_server_key();
let parameters = BoolEvaluator::with_local(|e| e.parameters().clone());
let nttop = NttBackendU64::new(parameters.rlwe_q(), parameters.rlwe_n().0);
let rlwe_q_modop = ModularOpsU64::new(*parameters.rlwe_q());
let m = (0..parameters.rlwe_n().0)
.map(|_| thread_rng().gen_bool(0.5))
.collect_vec();
let ct: NonInteractiveBatchedFheBools<_> = cks[0].encrypt(m.as_slice());
let ct = ct.key_switch(0);
let ideal_rlwe_sk = ideal_sk_rlwe(&cks);
let message = m
.iter()
.map(|b| parameters.rlwe_q().encode(*b))
.collect_vec();
let mut m_out = vec![0u64; parameters.rlwe_n().0];
decrypt_rlwe(
&ct.data[0],
&ideal_rlwe_sk,
&mut m_out,
&nttop,
&rlwe_q_modop,
);
let mut diff = m_out;
rlwe_q_modop.elwise_sub_mut(diff.as_mut_slice(), message.as_ref());
let mut stats = Stats::new();
stats.add_more(&Vec::<i64>::try_convert_from(
diff.as_slice(),
parameters.rlwe_q(),
));
println!("Noise std log2: {}", stats.std_dev().abs().log2());
}
}

+ 15
- 22
src/decomposer.rs

@ -1,14 +1,8 @@
use itertools::{izip, Itertools};
use num_traits::{
AsPrimitive, FromPrimitive, Num, One, PrimInt, ToPrimitive, WrappingAdd, WrappingSub, Zero,
};
use std::{
fmt::{Debug, Display},
marker::PhantomData,
ops::Rem,
};
use num_traits::{FromPrimitive, PrimInt, ToPrimitive, WrappingAdd, WrappingSub};
use std::fmt::{Debug, Display};
use crate::backend::{ArithmeticOps, ModularOpsU64};
use crate::backend::ArithmeticOps;
fn gadget_vector<T: PrimInt>(logq: usize, logb: usize, d: usize) -> Vec<T> {
assert!(logq >= (logb * d));
@ -146,7 +140,6 @@ impl<
}
}
// TODO(Jay): Outline the caveat
fn decompose_to_vec(&self, value: &T) -> Vec<T> {
let q = self.q;
let logb = self.logb;
@ -283,7 +276,7 @@ mod tests {
use crate::{
backend::{ModInit, ModularOpsU64, Modulus},
decomposer::round_value,
utils::{generate_prime, tests::Stats, TryConvertFrom1},
utils::{generate_prime, tests::Stats},
};
use super::{Decomposer, DefaultDecomposer};
@ -297,7 +290,7 @@ mod tests {
for logq in [37, 55] {
let logb = 11;
let d = 3;
let mut stats = vec![Stats::new(); d];
// let mut stats = vec![Stats::new(); d];
for i in [true, false] {
let q = if i {
@ -319,19 +312,19 @@ mod tests {
let rounded_value = round_value(value, decomposer.ignore_bits);
assert!((rounded_value as i64 - value_back as i64).abs() <= 1,);
izip!(stats.iter_mut(), limbs.iter()).for_each(|(s, l)| {
s.add_more(&vec![q.map_element_to_i64(l)]);
});
// izip!(stats.iter_mut(), limbs.iter()).for_each(|(s, l)| {
// s.add_more(&vec![q.map_element_to_i64(l)]);
// });
}
}
stats.iter().enumerate().for_each(|(index, s)| {
println!(
"Limb {index} - Mean: {}, Std: {}",
s.mean(),
s.std_dev().abs().log2()
);
});
// stats.iter().enumerate().for_each(|(index, s)| {
// println!(
// "Limb {index} - Mean: {}, Std: {}",
// s.mean(),
// s.std_dev().abs().log2()
// );
// });
}
}
}

+ 113
- 185
src/lwe.rs

@ -1,109 +1,16 @@
use std::{
cell::RefCell,
collections::btree_map::Values,
fmt::{Debug, Display},
marker::PhantomData,
};
use std::fmt::Debug;
use itertools::{izip, Itertools};
use num_traits::{abs, PrimInt, ToPrimitive, Zero};
use itertools::izip;
use num_traits::Zero;
use crate::{
backend::{ArithmeticOps, GetModulus, Modulus, VectorOps},
backend::{ArithmeticOps, GetModulus, VectorOps},
decomposer::Decomposer,
random::{
DefaultSecureRng, NewWithSeed, RandomFillGaussianInModulus, RandomFillUniformInModulus,
RandomGaussianElementInModulus, DEFAULT_RNG,
},
utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom1, WithLocal},
Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret,
random::{RandomFillUniformInModulus, RandomGaussianElementInModulus},
utils::TryConvertFrom1,
Matrix, Row, RowEntity, RowMut,
};
struct SeededLweKeySwitchingKey<Ro, S>
where
Ro: Row,
{
data: Ro,
seed: S,
to_lwe_n: usize,
modulus: Ro::Element,
}
impl<Ro: RowEntity, S> SeededLweKeySwitchingKey<Ro, S> {
pub(crate) fn empty(
from_lwe_n: usize,
to_lwe_n: usize,
d: usize,
seed: S,
modulus: Ro::Element,
) -> Self {
let data = Ro::zeros(from_lwe_n * d);
SeededLweKeySwitchingKey {
data,
to_lwe_n,
seed,
modulus,
}
}
}
struct LweKeySwitchingKey<M, R> {
data: M,
_phantom: PhantomData<R>,
}
impl<
M: MatrixMut + MatrixEntity,
R: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], M::MatElement>,
> From<&SeededLweKeySwitchingKey<M::R, R::Seed>> for LweKeySwitchingKey<M, R>
where
M::R: RowMut,
R::Seed: Clone,
M::MatElement: Copy,
{
fn from(value: &SeededLweKeySwitchingKey<M::R, R::Seed>) -> Self {
let mut p_rng = R::new_with_seed(value.seed.clone());
let mut data = M::zeros(value.data.as_ref().len(), value.to_lwe_n + 1);
izip!(value.data.as_ref().iter(), data.iter_rows_mut()).for_each(|(bi, lwe_i)| {
RandomFillUniformInModulus::random_fill(
&mut p_rng,
&value.modulus,
&mut lwe_i.as_mut()[1..],
);
lwe_i.as_mut()[0] = *bi;
});
LweKeySwitchingKey {
data,
_phantom: PhantomData,
}
}
}
trait LweCiphertext<M: Matrix> {}
#[derive(Clone)]
pub struct LweSecret {
pub(crate) values: Vec<i32>,
}
impl Secret for LweSecret {
type Element = i32;
fn values(&self) -> &[Self::Element] {
&self.values
}
}
impl LweSecret {
pub(crate) fn random(hw: usize, n: usize) -> LweSecret {
DefaultSecureRng::with_local_mut(|rng| {
let mut out = vec![0i32; n];
fill_random_ternary_secret_with_hamming_weight(&mut out, hw, rng);
LweSecret { values: out }
})
}
}
pub(crate) fn lwe_key_switch<
M: Matrix,
Ro: AsMut<[M::MatElement]> + AsRef<[M::MatElement]>,
@ -127,15 +34,17 @@ pub(crate) fn lwe_key_switch<
.skip(1)
.flat_map(|ai| decomposer.decompose_iter(ai));
izip!(lwe_in_a_decomposed, lwe_ksk.iter_rows()).for_each(|(ai_j, beta_ij_lwe)| {
// let now = std::time::Instant::now();
operator.elwise_fma_scalar_mut(lwe_out.as_mut(), beta_ij_lwe.as_ref(), &ai_j);
// println!("Time elwise_fma_scalar_mut: {:?}", now.elapsed());
});
let out_b = operator.add(&lwe_out.as_ref()[0], &lwe_in.as_ref()[0]);
lwe_out.as_mut()[0] = out_b;
}
pub fn lwe_ksk_keygen<
Ro: Row + RowMut + RowEntity,
pub fn seeded_lwe_ksk_keygen<
Ro: RowMut + RowEntity,
S,
Op: VectorOps<Element = Ro::Element>
+ ArithmeticOps<Element = Ro::Element>
@ -145,16 +54,16 @@ pub fn lwe_ksk_keygen<
>(
from_lwe_sk: &[S],
to_lwe_sk: &[S],
ksk_out: &mut Ro,
gadget: &[Ro::Element],
operator: &Op,
p_rng: &mut PR,
rng: &mut R,
) where
) -> Ro
where
Ro: TryConvertFrom1<[S], Op::M>,
Ro::Element: Zero + Debug,
{
assert!(ksk_out.as_ref().len() == (from_lwe_sk.len() * gadget.len()));
let mut ksk_out = Ro::zeros(from_lwe_sk.len() * gadget.len());
let d = gadget.len();
@ -167,7 +76,7 @@ pub fn lwe_ksk_keygen<
izip!(neg_sk_in_m.as_ref(), ksk_out.as_mut().chunks_mut(d)).for_each(
|(neg_sk_in_si, d_lwes_partb)| {
izip!(gadget.iter(), d_lwes_partb.into_iter()).for_each(|(f, lwe_b)| {
izip!(gadget.iter(), d_lwes_partb.into_iter()).for_each(|(beta, lwe_b)| {
// sample `a`
RandomFillUniformInModulus::random_fill(p_rng, &modulus, scratch.as_mut());
@ -179,7 +88,7 @@ pub fn lwe_ksk_keygen<
});
// a*z + (-s_i)*\beta^j + e
let mut b = operator.add(&az, &operator.mul(f, neg_sk_in_si));
let mut b = operator.add(&az, &operator.mul(beta, neg_sk_in_si));
let e = RandomGaussianElementInModulus::random(rng, &modulus);
b = operator.add(&b, &e);
@ -187,27 +96,29 @@ pub fn lwe_ksk_keygen<
})
},
);
ksk_out
}
/// Encrypts encoded message m as LWE ciphertext
pub fn encrypt_lwe<
Ro: Row + RowMut,
Ro: RowMut + RowEntity,
Op: ArithmeticOps<Element = Ro::Element> + GetModulus<Element = Ro::Element>,
R: RandomGaussianElementInModulus<Ro::Element, Op::M>
+ RandomFillUniformInModulus<[Ro::Element], Op::M>,
S,
>(
lwe_out: &mut Ro,
m: &Ro::Element,
s: &[S],
operator: &Op,
rng: &mut R,
) where
) -> Ro
where
Ro: TryConvertFrom1<[S], Op::M>,
Ro::Element: Zero,
{
let s = Ro::try_convert_from(s, operator.modulus());
assert!(s.as_ref().len() == (lwe_out.as_ref().len() - 1));
let mut lwe_out = Ro::zeros(s.as_ref().len() + 1);
// a*s
RandomFillUniformInModulus::random_fill(rng, operator.modulus(), &mut lwe_out.as_mut()[1..]);
@ -221,9 +132,11 @@ pub fn encrypt_lwe<
let e = RandomGaussianElementInModulus::random(rng, operator.modulus());
let b = operator.add(&operator.add(&sa, &e), m);
lwe_out.as_mut()[0] = b;
lwe_out
}
pub fn decrypt_lwe<
pub(crate) fn decrypt_lwe<
Ro: Row,
Op: ArithmeticOps<Element = Ro::Element> + GetModulus<Element = Ro::Element>,
S,
@ -248,58 +161,85 @@ where
operator.sub(b, &sa)
}
/// Measures noise in input LWE ciphertext with reference of `ideal_m`
///
/// - ct: Input LWE ciphertext
/// - s: corresponding secret
/// - ideal_m: Ideal `encoded` message
pub(crate) fn measure_noise_lwe<
Ro: Row,
Op: ArithmeticOps<Element = Ro::Element> + GetModulus<Element = Ro::Element>,
S,
>(
ct: &Ro,
s: &[S],
operator: &Op,
ideal_m: &Ro::Element,
) -> f64
where
Ro: TryConvertFrom1<[S], Op::M>,
Ro::Element: Zero + ToPrimitive + PrimInt + Display,
{
assert!(s.len() == ct.as_ref().len() - 1,);
let s = Ro::try_convert_from(s, &operator.modulus());
let mut sa = Ro::Element::zero();
izip!(s.as_ref().iter(), ct.as_ref().iter().skip(1)).for_each(|(si, ai)| {
sa = operator.add(&sa, &operator.mul(si, ai));
});
let m = operator.sub(&ct.as_ref()[0], &sa);
let mut diff = operator.sub(&m, ideal_m);
let q = operator.modulus();
return q.map_element_to_i64(&diff).to_f64().unwrap().abs().log2();
}
#[cfg(test)]
mod tests {
use std::marker::PhantomData;
use itertools::izip;
use crate::{
backend::{ModInit, ModularOpsU64, ModulusPowerOf2},
decomposer::{Decomposer, DefaultDecomposer},
lwe::{lwe_key_switch, measure_noise_lwe},
random::DefaultSecureRng,
rgsw::measure_noise,
Secret,
backend::{ModInit, ModulusPowerOf2},
decomposer::DefaultDecomposer,
random::{DefaultSecureRng, NewWithSeed},
utils::{fill_random_ternary_secret_with_hamming_weight, WithLocal},
MatrixEntity, MatrixMut, Secret,
};
use super::{
decrypt_lwe, encrypt_lwe, lwe_ksk_keygen, LweKeySwitchingKey, LweSecret,
SeededLweKeySwitchingKey,
};
use super::*;
const K: usize = 50;
#[derive(Clone)]
struct LweSecret {
pub(crate) values: Vec<i32>,
}
impl Secret for LweSecret {
type Element = i32;
fn values(&self) -> &[Self::Element] {
&self.values
}
}
impl LweSecret {
fn random(hw: usize, n: usize) -> LweSecret {
DefaultSecureRng::with_local_mut(|rng| {
let mut out = vec![0i32; n];
fill_random_ternary_secret_with_hamming_weight(&mut out, hw, rng);
LweSecret { values: out }
})
}
}
struct LweKeySwitchingKey<M, R> {
data: M,
_phantom: PhantomData<R>,
}
impl<
M: MatrixMut + MatrixEntity,
R: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], M::MatElement>,
> From<&(M::R, R::Seed, usize, M::MatElement)> for LweKeySwitchingKey<M, R>
where
M::R: RowMut,
R::Seed: Clone,
M::MatElement: Copy,
{
fn from(value: &(M::R, R::Seed, usize, M::MatElement)) -> Self {
let data_in = &value.0;
let seed = &value.1;
let to_lwe_n = value.2;
let modulus = value.3;
let mut p_rng = R::new_with_seed(seed.clone());
let mut data = M::zeros(data_in.as_ref().len(), to_lwe_n + 1);
izip!(data_in.as_ref().iter(), data.iter_rows_mut()).for_each(|(bi, lwe_i)| {
RandomFillUniformInModulus::random_fill(
&mut p_rng,
&modulus,
&mut lwe_i.as_mut()[1..],
);
lwe_i.as_mut()[0] = *bi;
});
LweKeySwitchingKey {
data,
_phantom: PhantomData,
}
}
}
#[test]
fn encrypt_decrypt_works() {
let logq = 16;
@ -315,14 +255,8 @@ mod tests {
// encrypt
for m in 0..1u64 << logp {
let encoded_m = m << (logq - logp);
let mut lwe_ct = vec![0u64; lwe_n + 1];
encrypt_lwe(
&mut lwe_ct,
&encoded_m,
&lwe_sk.values(),
&modq_op,
&mut rng,
);
let lwe_ct =
encrypt_lwe::<Vec<u64>, _, _, _>(&encoded_m, &lwe_sk.values(), &modq_op, &mut rng);
let encoded_m_back = decrypt_lwe(&lwe_ct, &lwe_sk.values(), &modq_op);
let m_back = ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as f64).round()
as u64)
@ -351,34 +285,26 @@ mod tests {
for _ in 0..1 {
let mut ksk_seed = [0u8; 32];
rng.fill_bytes(&mut ksk_seed);
let mut seeded_ksk =
SeededLweKeySwitchingKey::empty(lwe_in_n, lwe_out_n, d_ks, ksk_seed, q);
let mut p_rng = DefaultSecureRng::new_seeded(ksk_seed);
let decomposer = DefaultDecomposer::new(q, logb, d_ks);
let gadget = decomposer.gadget_vector();
lwe_ksk_keygen(
let seeded_ksk = seeded_lwe_ksk_keygen(
&lwe_sk_in.values(),
&lwe_sk_out.values(),
&mut seeded_ksk.data,
&gadget,
&modq_op,
&mut p_rng,
&mut rng,
);
// println!("{:?}", ksk);
let ksk = LweKeySwitchingKey::<Vec<Vec<u64>>, DefaultSecureRng>::from(&seeded_ksk);
let ksk = LweKeySwitchingKey::<Vec<Vec<u64>>, DefaultSecureRng>::from(&(
seeded_ksk, ksk_seed, lwe_out_n, q,
));
for m in 0..(1 << logp) {
// encrypt using lwe_sk_in
let encoded_m = m << (logq - logp);
let mut lwe_in_ct = vec![0u64; lwe_in_n + 1];
encrypt_lwe(
&mut lwe_in_ct,
&encoded_m,
lwe_sk_in.values(),
&modq_op,
&mut rng,
);
let lwe_in_ct = encrypt_lwe(&encoded_m, lwe_sk_in.values(), &modq_op, &mut rng);
// key switch from lwe_sk_in to lwe_sk_out
let mut lwe_out_ct = vec![0u64; lwe_out_n + 1];
@ -393,15 +319,17 @@ mod tests {
println!("Time: {:?}", now.elapsed());
// decrypt lwe_out_ct using lwe_sk_out
let encoded_m_back = decrypt_lwe(&lwe_out_ct, &lwe_sk_out.values(), &modq_op);
let m_back = ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as f64).round()
as u64)
% (1u64 << logp);
let noise =
measure_noise_lwe(&lwe_out_ct, lwe_sk_out.values(), &modq_op, &encoded_m);
println!("Noise: {noise}");
// assert_eq!(m, m_back, "Expected {m} but got {m_back}");
// dbg!(m, m_back);
// TODO(Jay): Fix me
// let encoded_m_back = decrypt_lwe(&lwe_out_ct,
// &lwe_sk_out.values(), &modq_op); let m_back =
// ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as
// f64).round() as u64)
// % (1u64 << logp);
// let noise =
// measure_noise_lwe(&lwe_out_ct, lwe_sk_out.values(),
// &modq_op, &encoded_m); println!("Noise:
// {noise}"); assert_eq!(m, m_back, "Expected
// {m} but got {m_back}"); dbg!(m, m_back);
// dbg!(encoded_m, encoded_m_back);
}
}

Loading…
Cancel
Save