From 8e6cde2d89b82607214c29463916bb7627dedaf7 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Fri, 28 Jun 2024 17:51:40 +0530 Subject: [PATCH] clean lwe --- examples/non_interactive_fheuint8.rs | 2 +- src/bool/evaluator.rs | 59 ++---- src/bool/keys.rs | 1 - src/bool/mod.rs | 3 +- src/bool/ni_mp_api.rs | 86 ++------ src/bool/parameters.rs | 25 +++ src/bool/print_noise.rs | 71 ++++++- src/decomposer.rs | 37 ++-- src/lwe.rs | 298 ++++++++++----------------- 9 files changed, 257 insertions(+), 325 deletions(-) diff --git a/examples/non_interactive_fheuint8.rs b/examples/non_interactive_fheuint8.rs index 1c8b10b..32bfed5 100644 --- a/examples/non_interactive_fheuint8.rs +++ b/examples/non_interactive_fheuint8.rs @@ -11,7 +11,7 @@ fn fhe_circuit(a: &FheUint8, b: &FheUint8, c: &FheUint8, d: &FheUint8) -> FheUin } fn main() { - set_parameter_set(ParameterSelector::NonInteractiveMultiPartyLessThanOrEqualTo16); + set_parameter_set(ParameterSelector::NonInteractiveLTE2Party); // set CRS let mut seed = [0u8; 32]; diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index 557e9e4..053d149 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -5,23 +5,18 @@ use std::{ usize, }; -use itertools::{izip, partition, Itertools}; -use num_traits::{ - zero, FromPrimitive, Num, One, Pow, PrimInt, ToPrimitive, WrappingAdd, WrappingSub, Zero, -}; -use rand::Rng; +use itertools::{izip, Itertools}; +use num_traits::{FromPrimitive, One, PrimInt, ToPrimitive, WrappingAdd, WrappingSub, Zero}; use rand_distr::uniform::SampleUniform; use crate::{ - backend::{ - ArithmeticOps, GetModulus, ModInit, ModularOpsU64, Modulus, ShoupMatrixFMA, VectorOps, - }, + backend::{ArithmeticOps, GetModulus, ModInit, Modulus, ShoupMatrixFMA, VectorOps}, bool::parameters::ParameterVariant, decomposer::{Decomposer, DefaultDecomposer, NumInfo, RlweDecomposer}, - lwe::{decrypt_lwe, encrypt_lwe, lwe_key_switch, lwe_ksk_keygen, measure_noise_lwe, LweSecret}, + lwe::{decrypt_lwe, encrypt_lwe, seeded_lwe_ksk_keygen}, multi_party::{ non_interactive_ksk_gen, non_interactive_ksk_zero_encryptions_for_other_party_i, - non_interactive_rgsw_ct, public_key_share, + public_key_share, }, ntt::{self, Ntt, NttBackendU64, NttInit}, pbs::{pbs, sample_extract, PbsInfo, PbsKey, WithShoupRepr}, @@ -48,8 +43,7 @@ use super::{ CommonReferenceSeededNonInteractiveMultiPartyServerKeyShare, InteractiveMultiPartyClientKey, NonInteractiveMultiPartyClientKey, SeededMultiPartyServerKey, SeededNonInteractiveMultiPartyServerKey, - SeededSinglePartyServerKey, ServerKeyEvaluationDomain, ShoupServerKeyEvaluationDomain, - SinglePartyClientKey, + SeededSinglePartyServerKey, SinglePartyClientKey, }, parameters::{ BoolParameters, CiphertextModulus, DecompositionCount, DecompostionLogBase, @@ -67,9 +61,8 @@ use super::{ /// Initial Seed: /// Puncture 1 -> Public key share seed /// Puncture 2 -> Main server key share seed -/// Puncture 1 -> RGSW cuphertexts seed -/// Puncture 2 -> Auto keys cipertexts seed -/// Puncture 3 -> LWE ksk seed +/// Puncture 1 -> Auto keys cipertexts seed +/// Puncture 2 -> LWE ksk seed #[derive(Clone, PartialEq)] pub struct MultiPartyCrs { pub(super) seed: S, @@ -97,19 +90,14 @@ impl MultiPartyCrs { puncture_p_rng(&mut prng, 2) } - pub(super) fn rgsw_cts_seed + RandomFill>(&self) -> S { - let mut key_prng = Rng::new_with_seed(self.key_seed::()); - puncture_p_rng(&mut key_prng, 1) - } - pub(super) fn auto_keys_cts_seed + RandomFill>(&self) -> S { let mut key_prng = Rng::new_with_seed(self.key_seed::()); - puncture_p_rng(&mut key_prng, 2) + puncture_p_rng(&mut key_prng, 1) } pub(super) fn lwe_ksk_cts_seed_seed + RandomFill>(&self) -> S { let mut key_prng = Rng::new_with_seed(self.key_seed::()); - puncture_p_rng(&mut key_prng, 3) + puncture_p_rng(&mut key_prng, 2) } } @@ -119,7 +107,8 @@ impl MultiPartyCrs { /// Puncture 1 -> Key Seed /// Puncture 1 -> Rgsw ciphertext seed /// Puncture l+1 -> Seed for zero encs and non-interactive -/// multi-party RGSW ciphertext corresponding to l^th LWE index. +/// multi-party RGSW ciphertexts of +/// l^th LWE index. /// Puncture 2 -> auto keys seed /// Puncture 3 -> Lwe key switching key seed /// Puncture 2 -> user specific seed for u_j to s ksk @@ -931,12 +920,9 @@ where // LWE KSK from RLWE secret s -> LWE secret z let d_lwe_gadget = self.pbs_info.lwe_decomposer.gadget_vector(); - let mut lwe_ksk = - M::R::zeros(self.pbs_info.lwe_decomposer.decomposition_count() * ring_size); - lwe_ksk_keygen( + let lwe_ksk = seeded_lwe_ksk_keygen( &sk_rlwe, &sk_lwe, - &mut lwe_ksk, &d_lwe_gadget, &self.pbs_info.lwe_modop, &mut main_prng, @@ -2049,21 +2035,16 @@ where ) -> M::R { DefaultSecureRng::with_local_mut(|rng| { let mut p_rng = DefaultSecureRng::new_seeded(lwe_ksk_seed); - let mut lwe_ksk = M::R::zeros( - self.pbs_info.lwe_decomposer.decomposition_count() * self.parameters().rlwe_n().0, - ); let lwe_modop = &self.pbs_info.lwe_modop; let d_lwe_gadget_vec = self.pbs_info.lwe_decomposer.gadget_vector(); - lwe_ksk_keygen( + seeded_lwe_ksk_keygen( sk_rlwe, sk_lwe, - &mut lwe_ksk, &d_lwe_gadget_vec, lwe_modop, &mut p_rng, rng, - ); - lwe_ksk + ) }) } @@ -2238,15 +2219,7 @@ where }; DefaultSecureRng::with_local_mut(|rng| { - let mut lwe_out = M::R::zeros(self.pbs_info.parameters.rlwe_n().0 + 1); - encrypt_lwe( - &mut lwe_out, - &m, - &client_key.sk_rlwe(), - &self.pbs_info.rlwe_modop, - rng, - ); - lwe_out + encrypt_lwe(&m, &client_key.sk_rlwe(), &self.pbs_info.rlwe_modop, rng) }) } diff --git a/src/bool/keys.rs b/src/bool/keys.rs index 2321dc5..18b6dc0 100644 --- a/src/bool/keys.rs +++ b/src/bool/keys.rs @@ -2,7 +2,6 @@ use std::{collections::HashMap, hash::Hash, marker::PhantomData}; use crate::{ backend::{ModInit, VectorOps}, - lwe::LweSecret, pbs::WithShoupRepr, random::{NewWithSeed, RandomFillUniformInModulus}, rgsw::RlweSecret, diff --git a/src/bool/mod.rs b/src/bool/mod.rs index 08261dc..c27ef13 100644 --- a/src/bool/mod.rs +++ b/src/bool/mod.rs @@ -22,7 +22,8 @@ pub type ClientKey = keys::ClientKey<[u8; 32], u64>; pub enum ParameterSelector { HighCommunicationButFast2Party, MultiPartyLessThanOrEqualTo16, - NonInteractiveMultiPartyLessThanOrEqualTo16, + NonInteractiveLTE2Party, + NonInteractiveLTE4Party, } mod common_mp_enc_dec { diff --git a/src/bool/ni_mp_api.rs b/src/bool/ni_mp_api.rs index 0110b24..ee5b7c5 100644 --- a/src/bool/ni_mp_api.rs +++ b/src/bool/ni_mp_api.rs @@ -3,6 +3,7 @@ use std::{cell::RefCell, sync::OnceLock}; use crate::{ backend::ModulusPowerOf2, bool::parameters::ParameterVariant, + parameters::NI_4P, random::DefaultSecureRng, utils::{Global, WithLocal}, ModularOpsU64, NttBackendU64, @@ -38,9 +39,13 @@ static MULTI_PARTY_CRS: OnceLock> = OnceLo pub fn set_parameter_set(select: ParameterSelector) { match select { - ParameterSelector::NonInteractiveMultiPartyLessThanOrEqualTo16 => { + ParameterSelector::NonInteractiveLTE2Party => { BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(NI_2P))); } + ParameterSelector::NonInteractiveLTE4Party => { + BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(NI_4P))); + } + _ => { panic!("Paramerters not supported") } @@ -160,6 +165,13 @@ impl Global for RuntimeServerKey { } } +pub(super) struct NonInteractiveBatchedFheBools { + data: Vec, +} +pub(super) struct BatchedFheBools { + pub(in super::super) data: Vec, +} + /// Non interactive multi-party specfic encryptor decryptor routines mod impl_enc_dec { use crate::{ @@ -177,10 +189,6 @@ mod impl_enc_dec { type Mat = Vec>; - pub(super) struct BatchedFheBools { - pub(super) data: Vec, - } - impl> BatchedFheBools where C::R: RowEntity + RowMut, @@ -202,10 +210,6 @@ mod impl_enc_dec { } } - pub(super) struct NonInteractiveBatchedFheBools { - data: Vec, - } - impl> From<&(Vec, [u8; 32])> for NonInteractiveBatchedFheBools where @@ -349,10 +353,9 @@ mod impl_enc_dec { #[cfg(test)] mod tests { - use impl_enc_dec::NonInteractiveBatchedFheBools; use itertools::{izip, Itertools}; use num_traits::{FromPrimitive, PrimInt, ToPrimitive, Zero}; - use rand::{thread_rng, RngCore}; + use rand::{thread_rng, Rng, RngCore}; use crate::{ backend::{GetModulus, Modulus}, @@ -374,7 +377,7 @@ mod tests { #[test] fn non_interactive_mp_bool_nand() { - set_parameter_set(ParameterSelector::NonInteractiveMultiPartyLessThanOrEqualTo16); + set_parameter_set(ParameterSelector::NonInteractiveLTE2Party); let mut seed = [0u8; 32]; thread_rng().fill_bytes(&mut seed); set_common_reference_seed(seed); @@ -444,63 +447,4 @@ mod tests { ct0 = ct_out; } } - - #[test] - fn trialtest() { - set_parameter_set(ParameterSelector::NonInteractiveMultiPartyLessThanOrEqualTo16); - set_common_reference_seed([2; 32]); - - let parties = 2; - - let cks = (0..parties).map(|_| gen_client_key()).collect_vec(); - - let key_shares = cks - .iter() - .enumerate() - .map(|(user_index, ck)| gen_server_key_share(user_index, parties, ck)) - .collect_vec(); - - let seeded_server_key = aggregate_server_key_shares(&key_shares); - seeded_server_key.set_server_key(); - - let m = vec![false, true]; - let ct: NonInteractiveBatchedFheBools<_> = cks[0].encrypt(m.as_slice()); - let ct = ct.key_switch(0); - - let parameters = BoolEvaluator::with_local(|e| e.parameters().clone()); - let nttop = NttBackendU64::new(parameters.rlwe_q(), parameters.rlwe_n().0); - let rlwe_q_modop = ModularOpsU64::new(*parameters.rlwe_q()); - - let mut ideal_rlwe_sk = vec![0i32; parameters.rlwe_n().0]; - cks.iter().for_each(|k| { - let sk_rlwe = k.sk_rlwe(); - izip!(ideal_rlwe_sk.iter_mut(), sk_rlwe.iter()).for_each(|(a, b)| { - *a = *a + b; - }); - }); - - let message = m - .iter() - .map(|b| parameters.rlwe_q().encode(*b)) - .collect_vec(); - - let mut m_out = vec![0u64; parameters.rlwe_n().0]; - decrypt_rlwe( - &ct.data[0], - &ideal_rlwe_sk, - &mut m_out, - &nttop, - &rlwe_q_modop, - ); - - let mut diff = m_out; - rlwe_q_modop.elwise_sub_mut(diff.as_mut_slice(), message.as_ref()); - - let mut stats = Stats::new(); - stats.add_more(&Vec::::try_convert_from( - diff.as_slice(), - parameters.rlwe_q(), - )); - println!("Noise: {}", stats.std_dev().abs().log2()); - } } diff --git a/src/bool/parameters.rs b/src/bool/parameters.rs index 770915d..53f2f2e 100644 --- a/src/bool/parameters.rs +++ b/src/bool/parameters.rs @@ -534,6 +534,31 @@ pub(crate) const NI_2P: BoolParameters = BoolParameters:: { variant: ParameterVariant::NonInteractiveMultiParty, }; +pub(crate) const NI_4P: BoolParameters = BoolParameters:: { + rlwe_q: CiphertextModulus::new_non_native(18014398509404161), + lwe_q: CiphertextModulus::new_non_native(1 << 16), + br_q: 1 << 11, + rlwe_n: PolynomialSize(1 << 11), + lwe_n: LweDimension(510), + lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(12)), + rlrg_decomposer_params: ( + DecompostionLogBase(17), + (DecompositionCount(1), DecompositionCount(1)), + ), + rgrg_decomposer_params: Some(( + DecompostionLogBase(4), + (DecompositionCount(10), DecompositionCount(9)), + )), + auto_decomposer_params: (DecompostionLogBase(24), DecompositionCount(1)), + non_interactive_ui_to_s_key_switch_decomposer: Some(( + DecompostionLogBase(1), + DecompositionCount(50), + )), + g: 5, + w: 10, + variant: ParameterVariant::NonInteractiveMultiParty, +}; + #[cfg(test)] pub(crate) const SP_TEST_BOOL_PARAMS: BoolParameters = BoolParameters:: { rlwe_q: CiphertextModulus::new_non_native(268369921u64), diff --git a/src/bool/print_noise.rs b/src/bool/print_noise.rs index 59842b4..3522ea1 100644 --- a/src/bool/print_noise.rs +++ b/src/bool/print_noise.rs @@ -427,7 +427,7 @@ mod tests { NttBackendU64, }; - set_parameter_set(crate::ParameterSelector::NonInteractiveMultiPartyLessThanOrEqualTo16); + set_parameter_set(crate::ParameterSelector::NonInteractiveLTE2Party); set_common_reference_seed(NonInteractiveMultiPartyCrs::random().seed); let parties = 2; let cks = (0..parties).map(|i| gen_client_key()).collect_vec(); @@ -469,4 +469,73 @@ mod tests { server_key_stats.post_lwe_key_switch.std_dev().abs().log2() ); } + + #[test] + #[cfg(feature = "non_interactive_mp")] + fn enc_under_sk_and_key_switch() { + use rand::{thread_rng, Rng}; + + use crate::{ + aggregate_server_key_shares, + bool::{keys::tests::ideal_sk_rlwe, ni_mp_api::NonInteractiveBatchedFheBools}, + gen_client_key, gen_server_key_share, + rgsw::decrypt_rlwe, + set_common_reference_seed, set_parameter_set, + utils::{tests::Stats, TryConvertFrom1, WithLocal}, + BoolEvaluator, Encoder, Encryptor, KeySwitchWithId, ModInit, ModularOpsU64, + NttBackendU64, NttInit, ParameterSelector, VectorOps, + }; + + set_parameter_set(ParameterSelector::NonInteractiveLTE2Party); + set_common_reference_seed([2; 32]); + + let parties = 2; + + let cks = (0..parties).map(|_| gen_client_key()).collect_vec(); + + let key_shares = cks + .iter() + .enumerate() + .map(|(user_index, ck)| gen_server_key_share(user_index, parties, ck)) + .collect_vec(); + + let seeded_server_key = aggregate_server_key_shares(&key_shares); + seeded_server_key.set_server_key(); + + let parameters = BoolEvaluator::with_local(|e| e.parameters().clone()); + let nttop = NttBackendU64::new(parameters.rlwe_q(), parameters.rlwe_n().0); + let rlwe_q_modop = ModularOpsU64::new(*parameters.rlwe_q()); + + let m = (0..parameters.rlwe_n().0) + .map(|_| thread_rng().gen_bool(0.5)) + .collect_vec(); + let ct: NonInteractiveBatchedFheBools<_> = cks[0].encrypt(m.as_slice()); + let ct = ct.key_switch(0); + + let ideal_rlwe_sk = ideal_sk_rlwe(&cks); + + let message = m + .iter() + .map(|b| parameters.rlwe_q().encode(*b)) + .collect_vec(); + + let mut m_out = vec![0u64; parameters.rlwe_n().0]; + decrypt_rlwe( + &ct.data[0], + &ideal_rlwe_sk, + &mut m_out, + &nttop, + &rlwe_q_modop, + ); + + let mut diff = m_out; + rlwe_q_modop.elwise_sub_mut(diff.as_mut_slice(), message.as_ref()); + + let mut stats = Stats::new(); + stats.add_more(&Vec::::try_convert_from( + diff.as_slice(), + parameters.rlwe_q(), + )); + println!("Noise std log2: {}", stats.std_dev().abs().log2()); + } } diff --git a/src/decomposer.rs b/src/decomposer.rs index c4e785c..6bfbf35 100644 --- a/src/decomposer.rs +++ b/src/decomposer.rs @@ -1,14 +1,8 @@ use itertools::{izip, Itertools}; -use num_traits::{ - AsPrimitive, FromPrimitive, Num, One, PrimInt, ToPrimitive, WrappingAdd, WrappingSub, Zero, -}; -use std::{ - fmt::{Debug, Display}, - marker::PhantomData, - ops::Rem, -}; +use num_traits::{FromPrimitive, PrimInt, ToPrimitive, WrappingAdd, WrappingSub}; +use std::fmt::{Debug, Display}; -use crate::backend::{ArithmeticOps, ModularOpsU64}; +use crate::backend::ArithmeticOps; fn gadget_vector(logq: usize, logb: usize, d: usize) -> Vec { assert!(logq >= (logb * d)); @@ -146,7 +140,6 @@ impl< } } - // TODO(Jay): Outline the caveat fn decompose_to_vec(&self, value: &T) -> Vec { let q = self.q; let logb = self.logb; @@ -283,7 +276,7 @@ mod tests { use crate::{ backend::{ModInit, ModularOpsU64, Modulus}, decomposer::round_value, - utils::{generate_prime, tests::Stats, TryConvertFrom1}, + utils::{generate_prime, tests::Stats}, }; use super::{Decomposer, DefaultDecomposer}; @@ -297,7 +290,7 @@ mod tests { for logq in [37, 55] { let logb = 11; let d = 3; - let mut stats = vec![Stats::new(); d]; + // let mut stats = vec![Stats::new(); d]; for i in [true, false] { let q = if i { @@ -319,19 +312,19 @@ mod tests { let rounded_value = round_value(value, decomposer.ignore_bits); assert!((rounded_value as i64 - value_back as i64).abs() <= 1,); - izip!(stats.iter_mut(), limbs.iter()).for_each(|(s, l)| { - s.add_more(&vec![q.map_element_to_i64(l)]); - }); + // izip!(stats.iter_mut(), limbs.iter()).for_each(|(s, l)| { + // s.add_more(&vec![q.map_element_to_i64(l)]); + // }); } } - stats.iter().enumerate().for_each(|(index, s)| { - println!( - "Limb {index} - Mean: {}, Std: {}", - s.mean(), - s.std_dev().abs().log2() - ); - }); + // stats.iter().enumerate().for_each(|(index, s)| { + // println!( + // "Limb {index} - Mean: {}, Std: {}", + // s.mean(), + // s.std_dev().abs().log2() + // ); + // }); } } } diff --git a/src/lwe.rs b/src/lwe.rs index c036acb..bebee65 100644 --- a/src/lwe.rs +++ b/src/lwe.rs @@ -1,109 +1,16 @@ -use std::{ - cell::RefCell, - collections::btree_map::Values, - fmt::{Debug, Display}, - marker::PhantomData, -}; +use std::fmt::Debug; -use itertools::{izip, Itertools}; -use num_traits::{abs, PrimInt, ToPrimitive, Zero}; +use itertools::izip; +use num_traits::Zero; use crate::{ - backend::{ArithmeticOps, GetModulus, Modulus, VectorOps}, + backend::{ArithmeticOps, GetModulus, VectorOps}, decomposer::Decomposer, - random::{ - DefaultSecureRng, NewWithSeed, RandomFillGaussianInModulus, RandomFillUniformInModulus, - RandomGaussianElementInModulus, DEFAULT_RNG, - }, - utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom1, WithLocal}, - Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, + random::{RandomFillUniformInModulus, RandomGaussianElementInModulus}, + utils::TryConvertFrom1, + Matrix, Row, RowEntity, RowMut, }; -struct SeededLweKeySwitchingKey -where - Ro: Row, -{ - data: Ro, - seed: S, - to_lwe_n: usize, - modulus: Ro::Element, -} - -impl SeededLweKeySwitchingKey { - pub(crate) fn empty( - from_lwe_n: usize, - to_lwe_n: usize, - d: usize, - seed: S, - modulus: Ro::Element, - ) -> Self { - let data = Ro::zeros(from_lwe_n * d); - SeededLweKeySwitchingKey { - data, - to_lwe_n, - seed, - modulus, - } - } -} - -struct LweKeySwitchingKey { - data: M, - _phantom: PhantomData, -} - -impl< - M: MatrixMut + MatrixEntity, - R: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], M::MatElement>, - > From<&SeededLweKeySwitchingKey> for LweKeySwitchingKey -where - M::R: RowMut, - R::Seed: Clone, - M::MatElement: Copy, -{ - fn from(value: &SeededLweKeySwitchingKey) -> Self { - let mut p_rng = R::new_with_seed(value.seed.clone()); - let mut data = M::zeros(value.data.as_ref().len(), value.to_lwe_n + 1); - izip!(value.data.as_ref().iter(), data.iter_rows_mut()).for_each(|(bi, lwe_i)| { - RandomFillUniformInModulus::random_fill( - &mut p_rng, - &value.modulus, - &mut lwe_i.as_mut()[1..], - ); - lwe_i.as_mut()[0] = *bi; - }); - LweKeySwitchingKey { - data, - _phantom: PhantomData, - } - } -} - -trait LweCiphertext {} - -#[derive(Clone)] -pub struct LweSecret { - pub(crate) values: Vec, -} - -impl Secret for LweSecret { - type Element = i32; - fn values(&self) -> &[Self::Element] { - &self.values - } -} - -impl LweSecret { - pub(crate) fn random(hw: usize, n: usize) -> LweSecret { - DefaultSecureRng::with_local_mut(|rng| { - let mut out = vec![0i32; n]; - fill_random_ternary_secret_with_hamming_weight(&mut out, hw, rng); - - LweSecret { values: out } - }) - } -} - pub(crate) fn lwe_key_switch< M: Matrix, Ro: AsMut<[M::MatElement]> + AsRef<[M::MatElement]>, @@ -127,15 +34,17 @@ pub(crate) fn lwe_key_switch< .skip(1) .flat_map(|ai| decomposer.decompose_iter(ai)); izip!(lwe_in_a_decomposed, lwe_ksk.iter_rows()).for_each(|(ai_j, beta_ij_lwe)| { + // let now = std::time::Instant::now(); operator.elwise_fma_scalar_mut(lwe_out.as_mut(), beta_ij_lwe.as_ref(), &ai_j); + // println!("Time elwise_fma_scalar_mut: {:?}", now.elapsed()); }); let out_b = operator.add(&lwe_out.as_ref()[0], &lwe_in.as_ref()[0]); lwe_out.as_mut()[0] = out_b; } -pub fn lwe_ksk_keygen< - Ro: Row + RowMut + RowEntity, +pub fn seeded_lwe_ksk_keygen< + Ro: RowMut + RowEntity, S, Op: VectorOps + ArithmeticOps @@ -145,16 +54,16 @@ pub fn lwe_ksk_keygen< >( from_lwe_sk: &[S], to_lwe_sk: &[S], - ksk_out: &mut Ro, gadget: &[Ro::Element], operator: &Op, p_rng: &mut PR, rng: &mut R, -) where +) -> Ro +where Ro: TryConvertFrom1<[S], Op::M>, Ro::Element: Zero + Debug, { - assert!(ksk_out.as_ref().len() == (from_lwe_sk.len() * gadget.len())); + let mut ksk_out = Ro::zeros(from_lwe_sk.len() * gadget.len()); let d = gadget.len(); @@ -167,7 +76,7 @@ pub fn lwe_ksk_keygen< izip!(neg_sk_in_m.as_ref(), ksk_out.as_mut().chunks_mut(d)).for_each( |(neg_sk_in_si, d_lwes_partb)| { - izip!(gadget.iter(), d_lwes_partb.into_iter()).for_each(|(f, lwe_b)| { + izip!(gadget.iter(), d_lwes_partb.into_iter()).for_each(|(beta, lwe_b)| { // sample `a` RandomFillUniformInModulus::random_fill(p_rng, &modulus, scratch.as_mut()); @@ -179,7 +88,7 @@ pub fn lwe_ksk_keygen< }); // a*z + (-s_i)*\beta^j + e - let mut b = operator.add(&az, &operator.mul(f, neg_sk_in_si)); + let mut b = operator.add(&az, &operator.mul(beta, neg_sk_in_si)); let e = RandomGaussianElementInModulus::random(rng, &modulus); b = operator.add(&b, &e); @@ -187,27 +96,29 @@ pub fn lwe_ksk_keygen< }) }, ); + + ksk_out } /// Encrypts encoded message m as LWE ciphertext pub fn encrypt_lwe< - Ro: Row + RowMut, + Ro: RowMut + RowEntity, Op: ArithmeticOps + GetModulus, R: RandomGaussianElementInModulus + RandomFillUniformInModulus<[Ro::Element], Op::M>, S, >( - lwe_out: &mut Ro, m: &Ro::Element, s: &[S], operator: &Op, rng: &mut R, -) where +) -> Ro +where Ro: TryConvertFrom1<[S], Op::M>, Ro::Element: Zero, { let s = Ro::try_convert_from(s, operator.modulus()); - assert!(s.as_ref().len() == (lwe_out.as_ref().len() - 1)); + let mut lwe_out = Ro::zeros(s.as_ref().len() + 1); // a*s RandomFillUniformInModulus::random_fill(rng, operator.modulus(), &mut lwe_out.as_mut()[1..]); @@ -221,9 +132,11 @@ pub fn encrypt_lwe< let e = RandomGaussianElementInModulus::random(rng, operator.modulus()); let b = operator.add(&operator.add(&sa, &e), m); lwe_out.as_mut()[0] = b; + + lwe_out } -pub fn decrypt_lwe< +pub(crate) fn decrypt_lwe< Ro: Row, Op: ArithmeticOps + GetModulus, S, @@ -248,58 +161,85 @@ where operator.sub(b, &sa) } -/// Measures noise in input LWE ciphertext with reference of `ideal_m` -/// -/// - ct: Input LWE ciphertext -/// - s: corresponding secret -/// - ideal_m: Ideal `encoded` message -pub(crate) fn measure_noise_lwe< - Ro: Row, - Op: ArithmeticOps + GetModulus, - S, ->( - ct: &Ro, - s: &[S], - operator: &Op, - ideal_m: &Ro::Element, -) -> f64 -where - Ro: TryConvertFrom1<[S], Op::M>, - Ro::Element: Zero + ToPrimitive + PrimInt + Display, -{ - assert!(s.len() == ct.as_ref().len() - 1,); - - let s = Ro::try_convert_from(s, &operator.modulus()); - let mut sa = Ro::Element::zero(); - izip!(s.as_ref().iter(), ct.as_ref().iter().skip(1)).for_each(|(si, ai)| { - sa = operator.add(&sa, &operator.mul(si, ai)); - }); - let m = operator.sub(&ct.as_ref()[0], &sa); - - let mut diff = operator.sub(&m, ideal_m); - let q = operator.modulus(); - return q.map_element_to_i64(&diff).to_f64().unwrap().abs().log2(); -} - #[cfg(test)] mod tests { + use std::marker::PhantomData; + + use itertools::izip; + use crate::{ - backend::{ModInit, ModularOpsU64, ModulusPowerOf2}, - decomposer::{Decomposer, DefaultDecomposer}, - lwe::{lwe_key_switch, measure_noise_lwe}, - random::DefaultSecureRng, - rgsw::measure_noise, - Secret, + backend::{ModInit, ModulusPowerOf2}, + decomposer::DefaultDecomposer, + random::{DefaultSecureRng, NewWithSeed}, + utils::{fill_random_ternary_secret_with_hamming_weight, WithLocal}, + MatrixEntity, MatrixMut, Secret, }; - use super::{ - decrypt_lwe, encrypt_lwe, lwe_ksk_keygen, LweKeySwitchingKey, LweSecret, - SeededLweKeySwitchingKey, - }; + use super::*; const K: usize = 50; + #[derive(Clone)] + struct LweSecret { + pub(crate) values: Vec, + } + + impl Secret for LweSecret { + type Element = i32; + fn values(&self) -> &[Self::Element] { + &self.values + } + } + + impl LweSecret { + fn random(hw: usize, n: usize) -> LweSecret { + DefaultSecureRng::with_local_mut(|rng| { + let mut out = vec![0i32; n]; + fill_random_ternary_secret_with_hamming_weight(&mut out, hw, rng); + + LweSecret { values: out } + }) + } + } + + struct LweKeySwitchingKey { + data: M, + _phantom: PhantomData, + } + + impl< + M: MatrixMut + MatrixEntity, + R: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], M::MatElement>, + > From<&(M::R, R::Seed, usize, M::MatElement)> for LweKeySwitchingKey + where + M::R: RowMut, + R::Seed: Clone, + M::MatElement: Copy, + { + fn from(value: &(M::R, R::Seed, usize, M::MatElement)) -> Self { + let data_in = &value.0; + let seed = &value.1; + let to_lwe_n = value.2; + let modulus = value.3; + + let mut p_rng = R::new_with_seed(seed.clone()); + let mut data = M::zeros(data_in.as_ref().len(), to_lwe_n + 1); + izip!(data_in.as_ref().iter(), data.iter_rows_mut()).for_each(|(bi, lwe_i)| { + RandomFillUniformInModulus::random_fill( + &mut p_rng, + &modulus, + &mut lwe_i.as_mut()[1..], + ); + lwe_i.as_mut()[0] = *bi; + }); + LweKeySwitchingKey { + data, + _phantom: PhantomData, + } + } + } + #[test] fn encrypt_decrypt_works() { let logq = 16; @@ -315,14 +255,8 @@ mod tests { // encrypt for m in 0..1u64 << logp { let encoded_m = m << (logq - logp); - let mut lwe_ct = vec![0u64; lwe_n + 1]; - encrypt_lwe( - &mut lwe_ct, - &encoded_m, - &lwe_sk.values(), - &modq_op, - &mut rng, - ); + let lwe_ct = + encrypt_lwe::, _, _, _>(&encoded_m, &lwe_sk.values(), &modq_op, &mut rng); let encoded_m_back = decrypt_lwe(&lwe_ct, &lwe_sk.values(), &modq_op); let m_back = ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as f64).round() as u64) @@ -351,34 +285,26 @@ mod tests { for _ in 0..1 { let mut ksk_seed = [0u8; 32]; rng.fill_bytes(&mut ksk_seed); - let mut seeded_ksk = - SeededLweKeySwitchingKey::empty(lwe_in_n, lwe_out_n, d_ks, ksk_seed, q); let mut p_rng = DefaultSecureRng::new_seeded(ksk_seed); let decomposer = DefaultDecomposer::new(q, logb, d_ks); let gadget = decomposer.gadget_vector(); - lwe_ksk_keygen( + let seeded_ksk = seeded_lwe_ksk_keygen( &lwe_sk_in.values(), &lwe_sk_out.values(), - &mut seeded_ksk.data, &gadget, &modq_op, &mut p_rng, &mut rng, ); // println!("{:?}", ksk); - let ksk = LweKeySwitchingKey::>, DefaultSecureRng>::from(&seeded_ksk); + let ksk = LweKeySwitchingKey::>, DefaultSecureRng>::from(&( + seeded_ksk, ksk_seed, lwe_out_n, q, + )); for m in 0..(1 << logp) { // encrypt using lwe_sk_in let encoded_m = m << (logq - logp); - let mut lwe_in_ct = vec![0u64; lwe_in_n + 1]; - encrypt_lwe( - &mut lwe_in_ct, - &encoded_m, - lwe_sk_in.values(), - &modq_op, - &mut rng, - ); + let lwe_in_ct = encrypt_lwe(&encoded_m, lwe_sk_in.values(), &modq_op, &mut rng); // key switch from lwe_sk_in to lwe_sk_out let mut lwe_out_ct = vec![0u64; lwe_out_n + 1]; @@ -393,15 +319,17 @@ mod tests { println!("Time: {:?}", now.elapsed()); // decrypt lwe_out_ct using lwe_sk_out - let encoded_m_back = decrypt_lwe(&lwe_out_ct, &lwe_sk_out.values(), &modq_op); - let m_back = ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as f64).round() - as u64) - % (1u64 << logp); - let noise = - measure_noise_lwe(&lwe_out_ct, lwe_sk_out.values(), &modq_op, &encoded_m); - println!("Noise: {noise}"); - // assert_eq!(m, m_back, "Expected {m} but got {m_back}"); - // dbg!(m, m_back); + // TODO(Jay): Fix me + // let encoded_m_back = decrypt_lwe(&lwe_out_ct, + // &lwe_sk_out.values(), &modq_op); let m_back = + // ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as + // f64).round() as u64) + // % (1u64 << logp); + // let noise = + // measure_noise_lwe(&lwe_out_ct, lwe_sk_out.values(), + // &modq_op, &encoded_m); println!("Noise: + // {noise}"); assert_eq!(m, m_back, "Expected + // {m} but got {m_back}"); dbg!(m, m_back); // dbg!(encoded_m, encoded_m_back); } }