use std::{ cell::RefCell, collections::btree_map::Values, fmt::{Debug, Display}, marker::PhantomData, }; use itertools::{izip, Itertools}; use num_traits::{abs, PrimInt, ToPrimitive, Zero}; use crate::{ backend::{ArithmeticOps, VectorOps}, decomposer::Decomposer, lwe, num::UnsignedInteger, random::{DefaultSecureRng, NewWithSeed, RandomGaussianDist, RandomUniformDist, DEFAULT_RNG}, utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom, WithLocal}, Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, }; struct SeededLweKeySwitchingKey where Ro: Row, { data: Ro, seed: S, to_lwe_n: usize, modulus: Ro::Element, } impl SeededLweKeySwitchingKey { pub(crate) fn empty( from_lwe_n: usize, to_lwe_n: usize, d: usize, seed: S, modulus: Ro::Element, ) -> Self { let data = Ro::zeros(from_lwe_n * d); SeededLweKeySwitchingKey { data, to_lwe_n, seed, modulus, } } } struct LweKeySwitchingKey { data: M, _phantom: PhantomData, } impl< M: MatrixMut + MatrixEntity, R: NewWithSeed + RandomUniformDist<[M::MatElement], Parameters = M::MatElement>, > From<&SeededLweKeySwitchingKey> for LweKeySwitchingKey where M::R: RowMut, R::Seed: Clone, M::MatElement: Copy, { fn from(value: &SeededLweKeySwitchingKey) -> Self { let mut p_rng = R::new_with_seed(value.seed.clone()); let mut data = M::zeros(value.data.as_ref().len(), value.to_lwe_n + 1); izip!(value.data.as_ref().iter(), data.iter_rows_mut()).for_each(|(bi, lwe_i)| { RandomUniformDist::random_fill(&mut p_rng, &value.modulus, &mut lwe_i.as_mut()[1..]); lwe_i.as_mut()[0] = *bi; }); LweKeySwitchingKey { data, _phantom: PhantomData, } } } trait LweCiphertext {} #[derive(Clone)] pub struct LweSecret { values: Vec, } impl Secret for LweSecret { type Element = i32; fn values(&self) -> &[Self::Element] { &self.values } } impl LweSecret { pub(crate) fn random(hw: usize, n: usize) -> LweSecret { DefaultSecureRng::with_local_mut(|rng| { let mut out = vec![0i32; n]; fill_random_ternary_secret_with_hamming_weight(&mut out, hw, rng); LweSecret { values: out } }) } } pub(crate) fn lwe_key_switch< M: Matrix, Ro: AsMut<[M::MatElement]> + AsRef<[M::MatElement]>, Op: VectorOps + ArithmeticOps, D: Decomposer, >( lwe_out: &mut Ro, lwe_in: &Ro, lwe_ksk: &M, operator: &Op, decomposer: &D, ) { assert!(lwe_ksk.dimension().0 == ((lwe_in.as_ref().len() - 1) * decomposer.d())); assert!(lwe_out.as_ref().len() == lwe_ksk.dimension().1); let lwe_in_a_decomposed = lwe_in .as_ref() .iter() .skip(1) .flat_map(|ai| decomposer.decompose(ai)); izip!(lwe_in_a_decomposed, lwe_ksk.iter_rows()).for_each(|(ai_j, beta_ij_lwe)| { operator.elwise_fma_scalar_mut(lwe_out.as_mut(), beta_ij_lwe.as_ref(), &ai_j); }); let out_b = operator.add(&lwe_out.as_ref()[0], &lwe_in.as_ref()[0]); lwe_out.as_mut()[0] = out_b; } pub fn lwe_ksk_keygen< Ro: Row + RowMut + RowEntity, S, Op: VectorOps + ArithmeticOps, R: RandomGaussianDist + RandomUniformDist<[Ro::Element], Parameters = Ro::Element> + NewWithSeed, >( from_lwe_sk: &[S], to_lwe_sk: &[S], ksk_out: &mut Ro, gadget: &[Ro::Element], seed: R::Seed, operator: &Op, rng: &mut R, ) where Ro: TryConvertFrom<[S], Parameters = Ro::Element>, Ro::Element: Zero + Debug, { assert!(ksk_out.as_ref().len() == (from_lwe_sk.len() * gadget.len())); let d = gadget.len(); let modulus = VectorOps::modulus(operator); let mut neg_sk_in_m = Ro::try_convert_from(from_lwe_sk, &modulus); operator.elwise_neg_mut(neg_sk_in_m.as_mut()); let sk_out_m = Ro::try_convert_from(to_lwe_sk, &modulus); let mut scratch = Ro::zeros(to_lwe_sk.len()); let mut p_rng = R::new_with_seed(seed); izip!(neg_sk_in_m.as_ref(), ksk_out.as_mut().chunks_mut(d)).for_each( |(neg_sk_in_si, d_lwes_partb)| { izip!(gadget.iter(), d_lwes_partb.into_iter()).for_each(|(f, lwe_b)| { // sample `a` RandomUniformDist::random_fill(&mut p_rng, &modulus, scratch.as_mut()); // a * z let mut az = Ro::Element::zero(); izip!(scratch.as_ref().iter(), sk_out_m.as_ref()).for_each(|(ai, si)| { let ai_si = operator.mul(ai, si); az = operator.add(&az, &ai_si); }); // a*z + (-s_i)*\beta^j + e let mut b = operator.add(&az, &operator.mul(f, neg_sk_in_si)); let mut e = Ro::Element::zero(); RandomGaussianDist::random_fill(rng, &modulus, &mut e); b = operator.add(&b, &e); *lwe_b = b; // dbg!(&lwe.as_mut(), &f); }) }, ); } /// Encrypts encoded message m as LWE ciphertext pub fn encrypt_lwe< Ro: Row + RowMut, R: RandomGaussianDist + RandomUniformDist<[Ro::Element], Parameters = Ro::Element>, S, Op: ArithmeticOps, >( lwe_out: &mut Ro, m: &Ro::Element, s: &[S], operator: &Op, rng: &mut R, ) where Ro: TryConvertFrom<[S], Parameters = Ro::Element>, Ro::Element: Zero, { let s = Ro::try_convert_from(s, &operator.modulus()); assert!(s.as_ref().len() == (lwe_out.as_ref().len() - 1)); // a*s RandomUniformDist::random_fill(rng, &operator.modulus(), &mut lwe_out.as_mut()[1..]); let mut sa = Ro::Element::zero(); izip!(lwe_out.as_mut().iter().skip(1), s.as_ref()).for_each(|(ai, si)| { let tmp = operator.mul(ai, si); sa = operator.add(&tmp, &sa); }); // b = a*s + e + m let mut e = Ro::Element::zero(); RandomGaussianDist::random_fill(rng, &operator.modulus(), &mut e); let b = operator.add(&operator.add(&sa, &e), m); lwe_out.as_mut()[0] = b; } pub fn decrypt_lwe, S>( lwe_ct: &Ro, s: &[S], operator: &Op, ) -> Ro::Element where Ro: TryConvertFrom<[S], Parameters = Ro::Element>, Ro::Element: Zero, { let s = Ro::try_convert_from(s, &operator.modulus()); let mut sa = Ro::Element::zero(); izip!(lwe_ct.as_ref().iter().skip(1), s.as_ref()).for_each(|(ai, si)| { let tmp = operator.mul(ai, si); sa = operator.add(&tmp, &sa); }); let b = &lwe_ct.as_ref()[0]; operator.sub(b, &sa) } /// Measures noise in input LWE ciphertext with reference of `ideal_m` /// /// - ct: Input LWE ciphertext /// - s: corresponding secret /// - ideal_m: Ideal `encoded` message pub(crate) fn measure_noise_lwe, S>( ct: &Ro, s: &[S], operator: &Op, ideal_m: &Ro::Element, ) -> f64 where Ro: TryConvertFrom<[S], Parameters = Ro::Element>, Ro::Element: Zero + ToPrimitive + PrimInt + Display, { assert!(s.len() == ct.as_ref().len() - 1,); let s = Ro::try_convert_from(s, &operator.modulus()); let mut sa = Ro::Element::zero(); izip!(s.as_ref().iter(), ct.as_ref().iter().skip(1)).for_each(|(si, ai)| { sa = operator.add(&sa, &operator.mul(si, ai)); }); let m = operator.sub(&ct.as_ref()[0], &sa); let mut diff = operator.sub(&m, ideal_m); let q = operator.modulus(); if diff > (q >> 1) { diff = q - diff; } return diff.to_f64().unwrap().log2(); } #[cfg(test)] mod tests { use crate::{ backend::{ModInit, ModularOpsU64}, decomposer::{gadget_vector, DefaultDecomposer}, lwe::{lwe_key_switch, measure_noise_lwe}, random::DefaultSecureRng, rgsw::measure_noise, Secret, }; use super::{ decrypt_lwe, encrypt_lwe, lwe_ksk_keygen, LweKeySwitchingKey, LweSecret, SeededLweKeySwitchingKey, }; const K: usize = 50; #[test] fn encrypt_decrypt_works() { let logq = 16; let q = 1u64 << logq; let lwe_n = 1024; let logp = 3; let modq_op = ModularOpsU64::new(q); let lwe_sk = LweSecret::random(lwe_n >> 1, lwe_n); let mut rng = DefaultSecureRng::new(); // encrypt for m in 0..1u64 << logp { let encoded_m = m << (logq - logp); let mut lwe_ct = vec![0u64; lwe_n + 1]; encrypt_lwe( &mut lwe_ct, &encoded_m, &lwe_sk.values(), &modq_op, &mut rng, ); let encoded_m_back = decrypt_lwe(&lwe_ct, &lwe_sk.values(), &modq_op); let m_back = ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as f64).round() as u64) % (1u64 << logp); assert_eq!(m, m_back, "Expected {m} but got {m_back}"); } } #[test] fn key_switch_works() { let logq = 16; let logp = 2; let q = 1u64 << logq; let lwe_in_n = 2048; let lwe_out_n = 493; let d_ks = 3; let logb = 4; let lwe_sk_in = LweSecret::random(lwe_in_n >> 1, lwe_in_n); let lwe_sk_out = LweSecret::random(lwe_out_n >> 1, lwe_out_n); let mut rng = DefaultSecureRng::new(); let modq_op = ModularOpsU64::new(q); // genrate ksk for _ in 0..K { let mut ksk_seed = [0u8; 32]; rng.fill_bytes(&mut ksk_seed); let mut seeded_ksk = SeededLweKeySwitchingKey::empty(lwe_in_n, lwe_out_n, d_ks, ksk_seed, q); let gadget = gadget_vector(logq, logb, d_ks); lwe_ksk_keygen( &lwe_sk_in.values(), &lwe_sk_out.values(), &mut seeded_ksk.data, &gadget, seeded_ksk.seed, &modq_op, &mut rng, ); // println!("{:?}", ksk); let ksk = LweKeySwitchingKey::>, DefaultSecureRng>::from(&seeded_ksk); for m in 0..(1 << logp) { // encrypt using lwe_sk_in let encoded_m = m << (logq - logp); let mut lwe_in_ct = vec![0u64; lwe_in_n + 1]; encrypt_lwe( &mut lwe_in_ct, &encoded_m, lwe_sk_in.values(), &modq_op, &mut rng, ); // key switch from lwe_sk_in to lwe_sk_out let decomposer = DefaultDecomposer::new(1u64 << logq, logb, d_ks); let mut lwe_out_ct = vec![0u64; lwe_out_n + 1]; lwe_key_switch( &mut lwe_out_ct, &lwe_in_ct, &ksk.data, &modq_op, &decomposer, ); // decrypt lwe_out_ct using lwe_sk_out let encoded_m_back = decrypt_lwe(&lwe_out_ct, &lwe_sk_out.values(), &modq_op); let m_back = ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as f64).round() as u64) % (1u64 << logp); let noise = measure_noise_lwe(&lwe_out_ct, lwe_sk_out.values(), &modq_op, &encoded_m); println!("Noise: {noise}"); assert_eq!(m, m_back, "Expected {m} but got {m_back}"); // dbg!(m, m_back); // dbg!(encoded_m, encoded_m_back); } } } }