diff --git a/src/lwe.rs b/src/lwe.rs index dcf08c2..5406708 100644 --- a/src/lwe.rs +++ b/src/lwe.rs @@ -1,6 +1,8 @@ use std::{ cell::RefCell, + collections::btree_map::Values, fmt::{Debug, Display}, + marker::PhantomData, }; use itertools::{izip, Itertools}; @@ -11,15 +13,65 @@ use crate::{ decomposer::Decomposer, lwe, num::UnsignedInteger, - random::{DefaultSecureRng, RandomGaussianDist, RandomUniformDist, DEFAULT_RNG}, + random::{DefaultSecureRng, NewWithSeed, RandomGaussianDist, RandomUniformDist, DEFAULT_RNG}, utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom, WithLocal}, - Matrix, MatrixEntity, MatrixMut, Row, RowMut, Secret, + Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, }; -trait LweKeySwitchParameters { - fn n_in(&self) -> usize; - fn n_out(&self) -> usize; - fn d_ks(&self) -> usize; +struct SeededLweKeySwitchingKey +where + Ro: Row, +{ + data: Ro, + seed: S, + to_lwe_n: usize, + modulus: Ro::Element, +} + +impl SeededLweKeySwitchingKey { + pub(crate) fn empty( + from_lwe_n: usize, + to_lwe_n: usize, + d: usize, + seed: S, + modulus: Ro::Element, + ) -> Self { + let data = Ro::zeros(from_lwe_n * d); + SeededLweKeySwitchingKey { + data, + to_lwe_n, + seed, + modulus, + } + } +} + +struct LweKeySwitchingKey { + data: M, + _phantom: PhantomData, +} + +impl< + M: MatrixMut + MatrixEntity, + R: NewWithSeed + RandomUniformDist<[M::MatElement], Parameters = M::MatElement>, + > From<&SeededLweKeySwitchingKey> for LweKeySwitchingKey +where + M::R: RowMut, + R::Seed: Clone, + M::MatElement: Copy, +{ + fn from(value: &SeededLweKeySwitchingKey) -> Self { + let mut p_rng = R::new_with_seed(value.seed.clone()); + let mut data = M::zeros(value.data.as_ref().len(), value.to_lwe_n + 1); + izip!(value.data.as_ref().iter(), data.iter_rows_mut()).for_each(|(bi, lwe_i)| { + RandomUniformDist::random_fill(&mut p_rng, &value.modulus, &mut lwe_i.as_mut()[1..]); + lwe_i.as_mut()[0] = *bi; + }); + LweKeySwitchingKey { + data, + _phantom: PhantomData, + } + } } trait LweCiphertext {} @@ -76,59 +128,61 @@ pub(crate) fn lwe_key_switch< } pub fn lwe_ksk_keygen< - Mmut: MatrixMut, + Ro: Row + RowMut + RowEntity, S, - Op: VectorOps + ArithmeticOps, - R: RandomGaussianDist - + RandomUniformDist<[Mmut::MatElement], Parameters = Mmut::MatElement>, + Op: VectorOps + ArithmeticOps, + R: RandomGaussianDist + + RandomUniformDist<[Ro::Element], Parameters = Ro::Element> + + NewWithSeed, >( from_lwe_sk: &[S], to_lwe_sk: &[S], - ksk_out: &mut Mmut, - gadget: &[Mmut::MatElement], + ksk_out: &mut Ro, + gadget: &[Ro::Element], + seed: R::Seed, operator: &Op, rng: &mut R, ) where - ::R: RowMut, - Mmut::R: TryConvertFrom<[S], Parameters = Mmut::MatElement>, - Mmut::MatElement: Zero + Debug, + Ro: TryConvertFrom<[S], Parameters = Ro::Element>, + Ro::Element: Zero + Debug, { - assert!(ksk_out.dimension() == (from_lwe_sk.len() * gadget.len(), to_lwe_sk.len() + 1,)); + assert!(ksk_out.as_ref().len() == (from_lwe_sk.len() * gadget.len())); let d = gadget.len(); let modulus = VectorOps::modulus(operator); - let mut neg_sk_in_m = Mmut::R::try_convert_from(from_lwe_sk, &modulus); + let mut neg_sk_in_m = Ro::try_convert_from(from_lwe_sk, &modulus); operator.elwise_neg_mut(neg_sk_in_m.as_mut()); - let sk_out_m = Mmut::R::try_convert_from(to_lwe_sk, &modulus); - - izip!( - neg_sk_in_m.as_ref(), - ksk_out.iter_rows_mut().chunks(d).into_iter() - ) - .for_each(|(neg_sk_in_si, d_ks_lwes)| { - izip!(gadget.iter(), d_ks_lwes.into_iter()).for_each(|(f, lwe)| { - // sample `a` - RandomUniformDist::random_fill(rng, &modulus, &mut lwe.as_mut()[1..]); - - // a * z - let mut az = Mmut::MatElement::zero(); - izip!(lwe.as_ref()[1..].iter(), sk_out_m.as_ref()).for_each(|(ai, si)| { - let ai_si = operator.mul(ai, si); - az = operator.add(&az, &ai_si); - }); - - // a*z + (-s_i)*\beta^j + e - let mut b = operator.add(&az, &operator.mul(f, neg_sk_in_si)); - let mut e = Mmut::MatElement::zero(); - RandomGaussianDist::random_fill(rng, &modulus, &mut e); - b = operator.add(&b, &e); - - lwe.as_mut()[0] = b; - - // dbg!(&lwe.as_mut(), &f); - }) - }); + let sk_out_m = Ro::try_convert_from(to_lwe_sk, &modulus); + + let mut scratch = Ro::zeros(to_lwe_sk.len()); + let mut p_rng = R::new_with_seed(seed); + + izip!(neg_sk_in_m.as_ref(), ksk_out.as_mut().chunks_mut(d)).for_each( + |(neg_sk_in_si, d_lwes_partb)| { + izip!(gadget.iter(), d_lwes_partb.into_iter()).for_each(|(f, lwe_b)| { + // sample `a` + RandomUniformDist::random_fill(&mut p_rng, &modulus, scratch.as_mut()); + + // a * z + let mut az = Ro::Element::zero(); + izip!(scratch.as_ref().iter(), sk_out_m.as_ref()).for_each(|(ai, si)| { + let ai_si = operator.mul(ai, si); + az = operator.add(&az, &ai_si); + }); + + // a*z + (-s_i)*\beta^j + e + let mut b = operator.add(&az, &operator.mul(f, neg_sk_in_si)); + let mut e = Ro::Element::zero(); + RandomGaussianDist::random_fill(rng, &modulus, &mut e); + b = operator.add(&b, &e); + + *lwe_b = b; + + // dbg!(&lwe.as_mut(), &f); + }) + }, + ); } /// Encrypts encoded message m as LWE ciphertext @@ -231,9 +285,12 @@ mod tests { Secret, }; - use super::{decrypt_lwe, encrypt_lwe, lwe_ksk_keygen, LweSecret}; + use super::{ + decrypt_lwe, encrypt_lwe, lwe_ksk_keygen, LweKeySwitchingKey, LweSecret, + SeededLweKeySwitchingKey, + }; - const K: usize = 500; + const K: usize = 50; #[test] fn encrypt_decrypt_works() { @@ -274,7 +331,7 @@ mod tests { let lwe_in_n = 2048; let lwe_out_n = 493; let d_ks = 3; - let logb = 5; + let logb = 4; let lwe_sk_in = LweSecret::random(lwe_in_n >> 1, lwe_in_n); let lwe_sk_out = LweSecret::random(lwe_out_n >> 1, lwe_out_n); @@ -284,17 +341,22 @@ mod tests { // genrate ksk for _ in 0..K { - let mut ksk = vec![vec![0u64; lwe_out_n + 1]; d_ks * lwe_in_n]; + let mut ksk_seed = [0u8; 32]; + rng.fill_bytes(&mut ksk_seed); + let mut seeded_ksk = + SeededLweKeySwitchingKey::empty(lwe_in_n, lwe_out_n, d_ks, ksk_seed, q); let gadget = gadget_vector(logq, logb, d_ks); lwe_ksk_keygen( &lwe_sk_in.values(), &lwe_sk_out.values(), - &mut ksk, + &mut seeded_ksk.data, &gadget, + seeded_ksk.seed, &modq_op, &mut rng, ); // println!("{:?}", ksk); + let ksk = LweKeySwitchingKey::>, DefaultSecureRng>::from(&seeded_ksk); for m in 0..(1 << logp) { // encrypt using lwe_sk_in @@ -311,7 +373,13 @@ mod tests { // key switch from lwe_sk_in to lwe_sk_out let decomposer = DefaultDecomposer::new(1u64 << logq, logb, d_ks); let mut lwe_out_ct = vec![0u64; lwe_out_n + 1]; - lwe_key_switch(&mut lwe_out_ct, &lwe_in_ct, &ksk, &modq_op, &decomposer); + lwe_key_switch( + &mut lwe_out_ct, + &lwe_in_ct, + &ksk.data, + &modq_op, + &decomposer, + ); // decrypt lwe_out_ct using lwe_sk_out let encoded_m_back = decrypt_lwe(&lwe_out_ct, &lwe_sk_out.values(), &modq_op);