From d1554a84269fbd2d9bf2e999b09813b382be0661 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Fri, 3 May 2024 20:15:17 +0530 Subject: [PATCH] add multi-party bool, but fails --- src/{bool.rs => bool/evaluator.rs} | 967 +++++++++++++++++++++++++++-- src/bool/mod.rs | 2 + src/bool/parameters.rs | 67 ++ src/lwe.rs | 2 +- src/multi_party.rs | 25 +- src/rgsw.rs | 671 +++++++++++++++++--- 6 files changed, 1586 insertions(+), 148 deletions(-) rename src/{bool.rs => bool/evaluator.rs} (53%) create mode 100644 src/bool/mod.rs create mode 100644 src/bool/parameters.rs diff --git a/src/bool.rs b/src/bool/evaluator.rs similarity index 53% rename from src/bool.rs rename to src/bool/evaluator.rs index f00979e..e3938d7 100644 --- a/src/bool.rs +++ b/src/bool/evaluator.rs @@ -14,16 +14,23 @@ use crate::{ backend::{ArithmeticOps, ModInit, ModularOpsU64, VectorOps}, decomposer::{gadget_vector, Decomposer, DefaultDecomposer, NumInfo}, lwe::{decrypt_lwe, encrypt_lwe, lwe_key_switch, lwe_ksk_keygen, measure_noise_lwe, LweSecret}, - ntt::{Ntt, NttBackendU64, NttInit}, + multi_party::public_key_share, + ntt::{self, Ntt, NttBackendU64, NttInit}, random::{DefaultSecureRng, NewWithSeed, RandomGaussianDist, RandomUniformDist}, rgsw::{ - decrypt_rlwe, galois_auto, galois_key_gen, generate_auto_map, rlwe_by_rgsw, - secret_key_encrypt_rgsw, IsTrivial, RlweCiphertext, RlweSecret, + decrypt_rlwe, galois_auto, galois_key_gen, generate_auto_map, public_key_encrypt_rgsw, + rgsw_by_rgsw_inplace, rlwe_by_rgsw, secret_key_encrypt_rgsw, IsTrivial, RlweCiphertext, + RlweSecret, + }, + utils::{ + fill_random_ternary_secret_with_hamming_weight, generate_prime, mod_exponent, + TryConvertFrom, WithLocal, }, - utils::{generate_prime, mod_exponent, TryConvertFrom, WithLocal}, Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, }; +use super::parameters::{self, BoolParameters}; + // thread_local! { // pub(crate) static CLIENT_KEY: RefCell = // RefCell::new(ClientKey::random()); } @@ -97,6 +104,189 @@ impl ClientKey { // ClientKey::with_local_mut(|k| *k = key.clone()) // } +struct MultiPartyDecryptionShare { + share: E, +} + +struct CommonReferenceSeededCollectivePublicKeyShare { + share: R, + cr_seed: S, + parameters: P, +} + +struct PublicKey { + key: M, + _phantom: PhantomData<(R, O)>, +} + +impl< + M: MatrixMut + MatrixEntity, + Rng: NewWithSeed + RandomUniformDist<[M::MatElement], Parameters = M::MatElement>, + ModOp: VectorOps + ModInit, + > + From< + &[CommonReferenceSeededCollectivePublicKeyShare< + M::R, + Rng::Seed, + BoolParameters, + >], + > for PublicKey +where + ::R: RowMut, + Rng::Seed: Copy + PartialEq, + M::MatElement: PartialEq + Copy, +{ + fn from( + value: &[CommonReferenceSeededCollectivePublicKeyShare< + M::R, + Rng::Seed, + BoolParameters, + >], + ) -> Self { + assert!(value.len() > 0); + + let parameters = &value[0].parameters; + let mut key = M::zeros(2, parameters.rlwe_n); + + // sample A + let seed = value[0].cr_seed; + let mut main_rng = Rng::new_with_seed(seed); + RandomUniformDist::random_fill(&mut main_rng, ¶meters.rlwe_q, key.get_row_mut(0)); + + // Sum all Bs + let rlweq_modop = ModOp::new(parameters.rlwe_q); + value.iter().for_each(|share_i| { + assert!(share_i.cr_seed == seed); + assert!(&share_i.parameters == parameters); + + rlweq_modop.elwise_add_mut(key.get_row_mut(1), share_i.share.as_ref()); + }); + + PublicKey { + key, + _phantom: PhantomData, + } + } +} + +struct CommonReferenceSeededMultiPartyServerKeyShare { + rgsw_cts: Vec, + auto_keys: HashMap, + lwe_ksk: M::R, + /// Common reference seed + cr_seed: S, + parameters: P, +} +struct SeededMultiPartyServerKey { + rgsw_cts: Vec, + auto_keys: HashMap, + lwe_ksk: M::R, + cr_seed: S, + parameters: P, +} + +fn aggregate_multi_party_server_key_shares< + M: MatrixMut + MatrixEntity, + S: Copy + PartialEq, + D: Decomposer, + ModOp: VectorOps + ModInit, + NttOp: Ntt + NttInit, +>( + shares: &[CommonReferenceSeededMultiPartyServerKeyShare, S>], + d_rgsw_decomposer: &D, +) -> SeededMultiPartyServerKey> +where + ::R: RowMut + RowEntity, + M::MatElement: Copy + PartialEq + Zero, + M: Clone, +{ + assert!(shares.len() > 0); + let parameters = shares[0].parameters.clone(); + let cr_seed = shares[0].cr_seed; + + let rlwe_n = parameters.rlwe_n; + let g = parameters.g as isize; + let d_rgsw = parameters.d_rgsw; + let d_lwe = parameters.d_lwe; + let rlwe_q = parameters.rlwe_q; + let lwe_q = parameters.lwe_q; + + // sanity checks + shares.iter().skip(1).for_each(|s| { + assert!(s.parameters == parameters); + assert!(s.cr_seed == cr_seed); + }); + + let rlweq_modop = ModOp::new(rlwe_q); + let rlweq_nttop = NttOp::new(rlwe_q, rlwe_n); + + // auto keys + let mut auto_keys = HashMap::new(); + for i in [g, -g] { + let mut key = M::zeros(d_rgsw, rlwe_n); + + shares.iter().for_each(|s| { + let auto_key_share_i = s.auto_keys.get(&i).expect("Auto key {i} missing"); + assert!(auto_key_share_i.dimension() == (d_rgsw, rlwe_n)); + izip!(key.iter_rows_mut(), auto_key_share_i.iter_rows()).for_each( + |(partb_out, partb_share)| { + rlweq_modop.elwise_add_mut(partb_out.as_mut(), partb_share.as_ref()); + }, + ); + }); + + auto_keys.insert(i, key); + } + + // rgsw ciphertext (most expensive part!) + let lwe_n = parameters.lwe_n; + let mut scratch_d_plus_rgsw_by_ring = M::zeros(d_rgsw + (d_rgsw * 4), rlwe_n); + let rgsw_cts = (0..lwe_n) + .into_iter() + .map(|index| { + // copy over rgsw ciphertext for index^th secret element from first share and + // send it to evaluation domain + let mut rgsw_i = shares[0].rgsw_cts[index].clone(); + rgsw_i + .iter_rows_mut() + .for_each(|ri| rlweq_nttop.forward(ri.as_mut())); + + shares.iter().skip(1).for_each(|si| { + rgsw_by_rgsw_inplace( + &mut rgsw_i, + &si.rgsw_cts[index], + d_rgsw_decomposer, + &mut scratch_d_plus_rgsw_by_ring, + &rlweq_nttop, + &rlweq_modop, + ); + }); + + // send final rgsw ciphertext of secret element at index to coefficient domain + rgsw_i + .iter_rows_mut() + .for_each(|ri| rlweq_nttop.backward(ri.as_mut())); + rgsw_i + }) + .collect_vec(); + + // LWE ksks + let mut lwe_ksk = M::R::zeros(rlwe_n * d_lwe); + let lweq_modop = ModOp::new(lwe_q); + shares.iter().for_each(|si| { + assert!(si.lwe_ksk.as_ref().len() == rlwe_n * d_lwe); + lweq_modop.elwise_add_mut(lwe_ksk.as_mut(), si.lwe_ksk.as_ref()) + }); + + SeededMultiPartyServerKey { + rgsw_cts, + auto_keys, + lwe_ksk, + cr_seed, + parameters: parameters, + } +} + struct SeededServerKey { /// Rgsw cts of LWE secret elements rgsw_cts: Vec, @@ -259,6 +449,90 @@ where } } +impl< + M: MatrixMut + MatrixEntity, + Rng: NewWithSeed, + N: NttInit + Ntt, + > From<&SeededMultiPartyServerKey>> + for ServerKeyEvaluationDomain +where + ::R: RowMut, + Rng::Seed: Copy, + Rng: RandomUniformDist<[M::MatElement], Parameters = M::MatElement>, + M::MatElement: Copy, +{ + fn from( + value: &SeededMultiPartyServerKey>, + ) -> Self { + let g = value.parameters.g as isize; + let rlwe_n = value.parameters.rlwe_n; + let lwe_n = value.parameters.lwe_n; + let rlwe_q = value.parameters.rlwe_q; + let lwe_q = value.parameters.lwe_q; + let d_rgsw = value.parameters.d_rgsw; + + let mut main_prng = Rng::new_with_seed(value.cr_seed); + + let rlwe_nttop = N::new(rlwe_q, rlwe_n); + + // auto keys + let mut auto_keys = HashMap::new(); + for i in [g, -g] { + let mut key = M::zeros(value.parameters.d_rgsw * 2, rlwe_n); + + // sample a + key.iter_rows_mut().take(d_rgsw).for_each(|ri| { + RandomUniformDist::random_fill(&mut main_prng, &rlwe_q, ri.as_mut()) + }); + + let key_part_b = value.auto_keys.get(&i).unwrap(); + assert!(key_part_b.dimension() == (d_rgsw, rlwe_n)); + izip!(key.iter_rows_mut().skip(d_rgsw), key_part_b.iter_rows()).for_each( + |(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }, + ); + + // send to evaluation domain + key.iter_rows_mut() + .for_each(|ri| rlwe_nttop.forward(ri.as_mut())); + + auto_keys.insert(i, key); + } + + // rgsw cts + let rgsw_cts = value + .rgsw_cts + .iter() + .map(|ct_i| { + let mut eval_ct_i = M::zeros(d_rgsw * 4, rlwe_n); + + izip!(eval_ct_i.iter_rows_mut(), ct_i.iter_rows()).for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + rlwe_nttop.forward(to_ri.as_mut()); + }); + + eval_ct_i + }) + .collect_vec(); + + // lwe ksk + let d_lwe = value.parameters.d_lwe; + let mut lwe_ksk = M::zeros(rlwe_n * d_lwe, lwe_n + 1); + izip!(lwe_ksk.iter_rows_mut(), value.lwe_ksk.as_ref().iter()).for_each(|(lwe_i, bi)| { + RandomUniformDist::random_fill(&mut main_prng, &lwe_q, &mut lwe_i.as_mut()[1..]); + lwe_i.as_mut()[0] = *bi; + }); + + ServerKeyEvaluationDomain { + rgsw_cts, + galois_keys: auto_keys, + lwe_ksk, + _phanton: PhantomData, + } + } +} + //FIXME(Jay): Figure out a way for BoolEvaluator to have access to ServerKey // via a pointer and implement PbsKey for BoolEvaluator instead of ServerKey // directly @@ -276,23 +550,6 @@ impl PbsKey for ServerKeyEvaluationDomain { } } -#[derive(Clone)] -struct BoolParameters { - rlwe_q: El, - rlwe_logq: usize, - lwe_q: El, - lwe_logq: usize, - br_q: usize, - rlwe_n: usize, - lwe_n: usize, - d_rgsw: usize, - logb_rgsw: usize, - d_lwe: usize, - logb_lwe: usize, - g: usize, - w: usize, -} - struct BoolEvaluator where M: Matrix, @@ -540,8 +797,256 @@ where }) } + fn multi_party_sever_key_share( + &self, + cr_seed: [u8; 32], + collective_pk: &M, + client_key: &ClientKey, + ) -> CommonReferenceSeededMultiPartyServerKeyShare, [u8; 32]> + { + DefaultSecureRng::with_local_mut(|rng| { + let mut main_prng = DefaultSecureRng::new_seeded(cr_seed); + + let sk_rlwe = &client_key.sk_rlwe; + let sk_lwe = &client_key.sk_lwe; + + let g = self.parameters.g as isize; + let ring_size = self.parameters.rlwe_n; + let d_rgsw = self.parameters.d_rgsw; + let d_lwe = self.parameters.d_lwe; + let rlwe_q = self.parameters.rlwe_q; + let lwe_q = self.parameters.lwe_q; + + let d_rgsw_gadget_vec = + gadget_vector(self.parameters.rlwe_logq, self.parameters.logb_rgsw, d_rgsw); + + let rlweq_modop = ModOp::new(rlwe_q); + let rlweq_nttop = NttOp::new(rlwe_q, ring_size); + + // sanity check + assert!(sk_rlwe.values().len() == ring_size); + assert!(sk_lwe.values().len() == self.parameters.lwe_n); + + // auto keys + let mut auto_keys = HashMap::new(); + for i in [g, -g] { + let mut ksk_out = M::zeros(d_rgsw, ring_size); + galois_key_gen( + &mut ksk_out, + sk_rlwe.values(), + i, + &d_rgsw_gadget_vec, + &rlweq_modop, + &rlweq_nttop, + &mut main_prng, + rng, + ); + auto_keys.insert(i, ksk_out); + } + + // rgsw ciphertexts of lwe secret elements + let rgsw_cts = sk_lwe + .values() + .iter() + .map(|si| { + let mut m = M::R::zeros(ring_size); + + if *si < 0 { + // X^{-si} = X^{2N-si} = -X^{N-si}, assuming abs(si) < N + // (which it is given si is secret element) + m.as_mut()[ring_size - (si.abs() as usize)] = rlwe_q - M::MatElement::one(); + } else { + m.as_mut()[*si as usize] = M::MatElement::one(); + } + + // public key RGSW encryption has no part that can be seeded, unlike secret key + // RGSW encryption where RLWE'_A(m) is seeded + let mut out_rgsw = M::zeros(d_rgsw * 4, ring_size); + public_key_encrypt_rgsw( + &mut out_rgsw, + &m.as_ref(), + collective_pk, + &d_rgsw_gadget_vec, + &rlweq_modop, + &rlweq_nttop, + rng, + ); + + out_rgsw + }) + .collect_vec(); + + // LWE ksk + let mut lwe_ksk = M::R::zeros(d_lwe * ring_size); + let lwe_modop = ModOp::new(lwe_q); + let d_lwe_gadget_vec = + gadget_vector(self.parameters.lwe_logq, self.parameters.logb_lwe, d_lwe); + lwe_ksk_keygen( + sk_rlwe.values(), + sk_lwe.values(), + &mut lwe_ksk, + &d_lwe_gadget_vec, + &lwe_modop, + &mut main_prng, + rng, + ); + + CommonReferenceSeededMultiPartyServerKeyShare { + auto_keys, + rgsw_cts, + lwe_ksk, + cr_seed, + parameters: self.parameters.clone(), + } + }) + } + + fn multi_party_public_key_share( + &self, + cr_seed: [u8; 32], + client_key: &ClientKey, + ) -> CommonReferenceSeededCollectivePublicKeyShare< + ::R, + [u8; 32], + BoolParameters<::MatElement>, + > { + DefaultSecureRng::with_local_mut(|rng| { + let mut share_out = M::R::zeros(self.parameters.rlwe_n); + let modop = ModOp::new(self.parameters.rlwe_q); + let nttop = NttOp::new(self.parameters.rlwe_q, self.parameters.rlwe_n); + let mut main_prng = DefaultSecureRng::new_seeded(cr_seed); + public_key_share( + &mut share_out, + client_key.sk_rlwe.values(), + &modop, + &nttop, + &mut main_prng, + rng, + ); + + CommonReferenceSeededCollectivePublicKeyShare { + share: share_out, + cr_seed: cr_seed, + parameters: self.parameters.clone(), + } + }) + } + + fn multi_party_decryption_share( + &self, + lwe_ct: &M::R, + client_key: &ClientKey, + ) -> MultiPartyDecryptionShare<::MatElement> { + assert!(lwe_ct.as_ref().len() == self.parameters.rlwe_n + 1); + let modop = &self.rlwe_modop; + let mut neg_s = + M::R::try_convert_from(client_key.sk_rlwe.values(), &self.parameters.rlwe_q); + modop.elwise_neg_mut(neg_s.as_mut()); + + let mut neg_sa = M::MatElement::zero(); + izip!(lwe_ct.as_ref().iter().skip(1), neg_s.as_ref().iter()).for_each(|(ai, nsi)| { + neg_sa = modop.add(&neg_sa, &modop.mul(ai, nsi)); + }); + + let e = DefaultSecureRng::with_local_mut(|rng| { + let mut e = M::MatElement::zero(); + RandomGaussianDist::random_fill(rng, &self.parameters.rlwe_q, &mut e); + e + }); + let share = modop.add(&neg_sa, &e); + + MultiPartyDecryptionShare { share } + } + + pub(crate) fn multi_party_decrypt( + &self, + shares: &[MultiPartyDecryptionShare], + lwe_ct: &M::R, + ) -> bool { + let modop = &self.rlwe_modop; + let mut sum_a = M::MatElement::zero(); + shares + .iter() + .for_each(|share_i| sum_a = modop.add(&sum_a, &share_i.share)); + + let encoded_m = modop.add(&lwe_ct.as_ref()[0], &sum_a); + + let m = (((encoded_m + self.rlweq_by8).to_f64().unwrap() * 4f64) + / self.parameters.rlwe_q.to_f64().unwrap()) + .round() as usize + % 4usize; + + if m == 0 { + return false; + } else if m == 1 { + return true; + } else { + panic!("Bool decryption failure. Expected m to be either 1 or 0, but m={m} "); + } + } + + /// First encrypt as RLWE(m) with m as constant polynomial and extract it as + /// LWE ciphertext + pub(crate) fn pk_encrypt(&self, pk: &M, m: bool) -> M::R { + DefaultSecureRng::with_local_mut(|rng| { + let modop = &self.rlwe_modop; + let nttop = &self.rlwe_nttop; + + // RLWE(0) + // sample ephemeral key u + let ring_size = self.parameters.rlwe_n; + let mut u = vec![0i32; ring_size]; + fill_random_ternary_secret_with_hamming_weight(u.as_mut(), ring_size >> 1, rng); + let mut u = M::R::try_convert_from(&u, &self.parameters.rlwe_q); + nttop.forward(u.as_mut()); + + let mut ua = M::R::zeros(ring_size); + ua.as_mut().copy_from_slice(pk.get_row_slice(0)); + let mut ub = M::R::zeros(ring_size); + ub.as_mut().copy_from_slice(pk.get_row_slice(1)); + + // a*u + nttop.forward(ua.as_mut()); + modop.elwise_mul_mut(ua.as_mut(), u.as_ref()); + nttop.backward(ua.as_mut()); + + // b*u + nttop.forward(ub.as_mut()); + modop.elwise_mul_mut(ub.as_mut(), u.as_ref()); + nttop.backward(ub.as_mut()); + + let mut rlwe = M::zeros(2, ring_size); + // sample error + rlwe.iter_rows_mut().for_each(|ri| { + RandomGaussianDist::random_fill(rng, &self.parameters.rlwe_q, ri.as_mut()); + }); + + // a*u + e0 + modop.elwise_add_mut(rlwe.get_row_mut(0), ua.as_ref()); + // b*u + e1 + modop.elwise_add_mut(rlwe.get_row_mut(1), ub.as_ref()); + + let m = if m { + // Q/8 + self.rlweq_by8 + } else { + // -Q/8 + self.parameters.rlwe_q - self.rlweq_by8 + }; + + // b*u + e1 + m, where m is constant polynomial + rlwe.set(1, 0, modop.add(rlwe.get(1, 0), &m)); + + // sample extract index 0 + let mut lwe_out = M::R::zeros(ring_size + 1); + sample_extract(&mut lwe_out, &rlwe, modop, 0); + + lwe_out + }) + } + /// TODO(Jay): Fetch client key from thread local - pub fn encrypt(&self, m: bool, client_key: &ClientKey) -> M::R { + pub fn sk_encrypt(&self, m: bool, client_key: &ClientKey) -> M::R { let m = if m { // Q/8 self.rlweq_by8 @@ -563,7 +1068,7 @@ where }) } - pub fn decrypt(&self, lwe_ct: &M::R, client_key: &ClientKey) -> bool { + pub fn sk_decrypt(&self, lwe_ct: &M::R, client_key: &ClientKey) -> bool { let m = decrypt_lwe(lwe_ct, client_key.sk_rlwe.values(), &self.rlwe_modop); let m = { // m + q/8 => {0,q/4 1} @@ -1082,20 +1587,19 @@ where } impl PBSTracer>> { - fn trace(&self, parameters: &BoolParameters, client_key: &ClientKey) { + fn trace(&self, parameters: &BoolParameters, sk_lwe: &[i32], sk_rlwe: &[i32]) { + assert!(parameters.rlwe_n == sk_rlwe.len()); + assert!(parameters.lwe_n == sk_lwe.len()); + let modop_lweq = ModularOpsU64::new(parameters.lwe_q as u64); // noise after mod down Q -> Q_ks - let m_back0 = decrypt_lwe(&self.ct_lwe_q_mod, client_key.sk_rlwe.values(), &modop_lweq); + let m_back0 = decrypt_lwe(&self.ct_lwe_q_mod, sk_rlwe, &modop_lweq); // noise after key switch from RLWE -> LWE - let m_back1 = decrypt_lwe( - &self.ct_lwe_q_mod_after_ksk, - client_key.sk_lwe.values(), - &modop_lweq, - ); + let m_back1 = decrypt_lwe(&self.ct_lwe_q_mod_after_ksk, sk_lwe, &modop_lweq); // noise after mod down odd from Q_ks -> q let modop_br_q = ModularOpsU64::new(parameters.br_q as u64); - let m_back2 = decrypt_lwe(&self.ct_br_q_mod, client_key.sk_lwe.values(), &modop_br_q); + let m_back2 = decrypt_lwe(&self.ct_br_q_mod, sk_lwe, &modop_br_q); println!( " @@ -1125,25 +1629,21 @@ impl WithLocal for PBSTracer>> { #[cfg(test)] mod tests { - use crate::{backend::ModularOpsU64, ntt::NttBackendU64, random::DEFAULT_RNG}; + use crate::{ + backend::ModularOpsU64, + bool, + ntt::NttBackendU64, + random::DEFAULT_RNG, + rgsw::{ + secret_key_encrypt_rlwe, RgswCiphertextEvaluationDomain, SeededRgswCiphertext, + SeededRlweCiphertext, + }, + utils::negacyclic_mul, + }; - use super::*; + use self::parameters::{MP_BOOL_PARAMS, SP_BOOL_PARAMS}; - const SP_BOOL_PARAMS: BoolParameters = BoolParameters:: { - rlwe_q: 268369921u64, - rlwe_logq: 28, - lwe_q: 1 << 16, - lwe_logq: 16, - br_q: 1 << 10, - rlwe_n: 1 << 10, - lwe_n: 493, - d_rgsw: 3, - logb_rgsw: 8, - d_lwe: 3, - logb_lwe: 4, - g: 5, - w: 1, - }; + use super::*; // #[test] // fn trial() { @@ -1161,8 +1661,8 @@ mod tests { let mut m = true; for _ in 0..1000 { - let lwe_ct = bool_evaluator.encrypt(m, &client_key); - let m_back = bool_evaluator.decrypt(&lwe_ct, &client_key); + let lwe_ct = bool_evaluator.sk_encrypt(m, &client_key); + let m_back = bool_evaluator.sk_decrypt(&lwe_ct, &client_key); assert_eq!(m, m_back); m = !m; } @@ -1194,8 +1694,8 @@ mod tests { let mut m0 = false; let mut m1 = true; - let mut ct0 = bool_evaluator.encrypt(m0, &client_key); - let mut ct1 = bool_evaluator.encrypt(m1, &client_key); + let mut ct0 = bool_evaluator.sk_encrypt(m0, &client_key); + let mut ct1 = bool_evaluator.sk_encrypt(m1, &client_key); for _ in 0..1000 { let ct_back = bool_evaluator.nand( &ct0, @@ -1210,7 +1710,13 @@ mod tests { // Trace and measure PBS noise { // Trace PBS - PBSTracer::with_local(|t| t.trace(&SP_BOOL_PARAMS, &client_key)); + PBSTracer::with_local(|t| { + t.trace( + &SP_BOOL_PARAMS, + &client_key.sk_lwe.values(), + client_key.sk_rlwe.values(), + ) + }); // Calculate noise in ciphertext post PBS let ideal = if m_out { @@ -1226,7 +1732,7 @@ mod tests { ); println!("PBS noise: {noise}"); } - let m_back = bool_evaluator.decrypt(&ct_back, &client_key); + let m_back = bool_evaluator.sk_decrypt(&ct_back, &client_key); assert_eq!(m_out, m_back); println!("----------"); @@ -1237,4 +1743,357 @@ mod tests { ct0 = ct_back; } } + + #[test] + fn multi_party_encryption_decryption() { + let bool_evaluator = + BoolEvaluator::>, u64, NttBackendU64, ModularOpsU64>::new(MP_BOOL_PARAMS); + + let no_of_parties = 5; + let parties = (0..no_of_parties) + .map(|_| bool_evaluator.client_key()) + .collect_vec(); + + let mut ideal_rlwe_sk = vec![0i32; bool_evaluator.rlwe_n()]; + parties.iter().for_each(|k| { + izip!(ideal_rlwe_sk.iter_mut(), k.sk_rlwe.values()).for_each(|(ideal_i, s_i)| { + *ideal_i = *ideal_i + s_i; + }); + }); + + let mut m = true; + for i in 0..100 { + let pk_cr_seed = [0u8; 32]; + + let public_key_share = parties + .iter() + .map(|k| bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) + .collect_vec(); + + let collective_pk = PublicKey::>, DefaultSecureRng, ModularOpsU64>::from( + public_key_share.as_slice(), + ); + let lwe_ct = bool_evaluator.pk_encrypt(&collective_pk.key, m); + + let decryption_shares = parties + .iter() + .map(|k| bool_evaluator.multi_party_decryption_share(&lwe_ct, k)) + .collect_vec(); + + let m_back = bool_evaluator.multi_party_decrypt(&decryption_shares, &lwe_ct); + + { + let ideal_m = if m { + bool_evaluator.rlweq_by8 + } else { + bool_evaluator.parameters.rlwe_q - bool_evaluator.rlweq_by8 + }; + let noise = measure_noise_lwe( + &lwe_ct, + &ideal_rlwe_sk, + &bool_evaluator.rlwe_modop, + &ideal_m, + ); + println!("Noise: {noise}"); + } + + assert_eq!(m_back, m); + m = !m; + } + } + + #[test] + fn trial_mp() { + let bool_evaluator = + BoolEvaluator::>, u64, NttBackendU64, ModularOpsU64>::new(MP_BOOL_PARAMS); + + let no_of_parties = 2; + let parties = (0..no_of_parties) + .map(|_| bool_evaluator.client_key()) + .collect_vec(); + + // Collective public key + let pk_cr_seed = [0u8; 32]; + let public_key_share = parties + .iter() + .map(|k| bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) + .collect_vec(); + let collective_pk = PublicKey::>, DefaultSecureRng, ModularOpsU64>::from( + public_key_share.as_slice(), + ); + + // Server key + let pbs_cr_seed = [1u8; 32]; + let server_key_shares = parties + .iter() + .map(|k| bool_evaluator.multi_party_sever_key_share(pbs_cr_seed, &collective_pk.key, k)) + .collect_vec(); + let seeded_server_key = + aggregate_multi_party_server_key_shares::<_, _, _, ModularOpsU64, NttBackendU64>( + &server_key_shares, + &bool_evaluator.decomposer_rlwe, + ); + let server_key_eval = ServerKeyEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from( + &seeded_server_key, + ); + + // construct ideal rlwe sk for meauring noise + let ideal_client_key = { + let mut ideal_rlwe_sk = vec![0i32; bool_evaluator.rlwe_n()]; + parties.iter().for_each(|k| { + izip!(ideal_rlwe_sk.iter_mut(), k.sk_rlwe.values()).for_each(|(ideal_i, s_i)| { + *ideal_i = *ideal_i + s_i; + }); + }); + let mut ideal_lwe_sk = vec![0i32; bool_evaluator.lwe_n()]; + parties.iter().for_each(|k| { + izip!(ideal_lwe_sk.iter_mut(), k.sk_lwe.values()).for_each(|(ideal_i, s_i)| { + *ideal_i = *ideal_i + s_i; + }); + }); + + ClientKey { + sk_lwe: LweSecret { + values: ideal_lwe_sk, + }, + sk_rlwe: RlweSecret { + values: ideal_rlwe_sk, + }, + } + }; + + // test LWE ksk from RLWE -> LWE + // { + // let lwe_q = bool_evaluator.parameters.lwe_q; + // let lwe_logq = bool_evaluator.parameters.lwe_logq; + // let lwe_n = bool_evaluator.parameters.lwe_n; + // let rlwe_n = bool_evaluator.parameters.rlwe_n; + // let logp = 2; + // let lwe_modop = &bool_evaluator.lwe_modop; + // let mut rng = DefaultSecureRng::new(); + + // let m = 1; + // let encoded_m = m << (lwe_logq - logp); + + // // Encrypt + // let mut lwe_ct = vec![0u64; rlwe_n + 1]; + // encrypt_lwe( + // &mut lwe_ct, + // &encoded_m, + // ideal_client_key.sk_rlwe.values(), + // lwe_modop, + // &mut rng, + // ); + + // // key switch + // let lwe_decomposer = &bool_evaluator.decomposer_lwe; + // let mut lwe_out = vec![0u64; lwe_n + 1]; + // lwe_key_switch( + // &mut lwe_out, + // &lwe_ct, + // &server_key_eval.lwe_ksk, + // lwe_modop, + // lwe_decomposer, + // ); + + // let encoded_m_back = decrypt_lwe(&lwe_out, + // ideal_client_key.sk_lwe.values(), lwe_modop); let m_back = + // ((encoded_m_back as f64 * (1 << logp) as f64) / (lwe_q as + // f64)).round() as u64; dbg!(m_back, m); + + // let noise = measure_noise_lwe( + // &lwe_out, + // ideal_client_key.sk_lwe.values(), + // lwe_modop, + // &encoded_m, + // ); + + // println!("Noise: {noise}"); + // } + + { + let rlwe_q = bool_evaluator.parameters.rlwe_q; + let rlwe_n = bool_evaluator.parameters.rlwe_n; + let logp = 2; + let p = 1 << logp; + let rlwe_modop = &bool_evaluator.rlwe_modop; + let rlwe_nttop = &bool_evaluator.rlwe_nttop; + let d_rgsw = bool_evaluator.parameters.d_rgsw; + + let mut rng = DefaultSecureRng::new(); + let mut m = vec![0u64; rlwe_n]; + RandomUniformDist::random_fill(&mut rng, &p, m.as_mut_slice()); + + // Encode message m + let encoded_m = m + .iter() + .map(|el| ((*el as f64 * rlwe_q as f64) / (p as f64)).round() as u64) + .collect_vec(); + + // Encrypt encoded m -> RLWE(m) + let mut rlwe_seed = [0u8; 32]; + rng.fill_bytes(&mut rlwe_seed); + let mut seeded_rlwe_ct = SeededRlweCiphertext::empty(rlwe_n, rlwe_seed, rlwe_q); + let mut rlwe_prng = DefaultSecureRng::new_seeded(rlwe_seed); + secret_key_encrypt_rlwe( + &encoded_m, + &mut seeded_rlwe_ct.data, + ideal_client_key.sk_rlwe.values(), + rlwe_modop, + rlwe_nttop, + &mut rlwe_prng, + &mut rng, + ); + // public_key_encrypt_rgsw(out_rgsw, m, public_key, gadget_vector, mod_op, + // ntt_op, rng) + let mut rlwe_ct = + RlweCiphertext::>, DefaultSecureRng>::from(&seeded_rlwe_ct); + + let index = 0; + + let mut rgsw_ct = { + let rgsw_seed = [0u8; 32]; + let mut rgsw_prng = DefaultSecureRng::new_seeded(rgsw_seed); + let mut rgsw_ct = SeededRgswCiphertext::>, _>::empty( + rlwe_n, d_rgsw, rgsw_seed, rlwe_q, + ); + let mut si_poly = vec![0u64; rlwe_n]; + // dbg!(ideal_client_key.sk_lwe.values()); + let secret_el_i = ideal_client_key.sk_lwe.values[index]; + dbg!(secret_el_i); + if secret_el_i < 0 { + si_poly[rlwe_n - secret_el_i.abs() as usize] = rlwe_q - 1; + } else { + si_poly[secret_el_i.abs() as usize] = 1; + } + secret_key_encrypt_rgsw( + &mut rgsw_ct.data, + &si_poly, + &gadget_vector( + bool_evaluator.parameters.rlwe_logq, + bool_evaluator.parameters.logb_rgsw, + d_rgsw, + ), + ideal_client_key.sk_rlwe.values(), + rlwe_modop, + rlwe_nttop, + &mut rgsw_prng, + &mut rng, + ); + + RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from(&rgsw_ct) + }; + + // RLWE(m*X^{s[i]}) = RLWE(m) x RGSW(X^{s[i]}) + let mut scratch_matrix_dplus2_ring = vec![vec![0u64; rlwe_n]; d_rgsw + 2]; + let rlwe_decomposer = &bool_evaluator.decomposer_rlwe; + rlwe_by_rgsw( + &mut rlwe_ct, + server_key_eval.rgsw_ct_lwe_si(index), + // &rgsw_ct.data, + &mut scratch_matrix_dplus2_ring, + rlwe_decomposer, + rlwe_nttop, + rlwe_modop, + ); + + // decrypt RLWE(m*X^{s[i]}) to get encoded m[X]*X^{s[i]} + let mut encoded_m_back = vec![0u64; rlwe_n]; + decrypt_rlwe( + &rlwe_ct, + ideal_client_key.sk_rlwe.values(), + &mut encoded_m_back, + rlwe_nttop, + rlwe_modop, + ); + let m_back = encoded_m_back + .iter() + .map(|el| (((*el as f64 * p as f64) / (rlwe_q as f64)).round() as u64) % p) + .collect_vec(); + + // calculate m[X]X^{s[i]} in plain + let mut si_poly = vec![0u64; rlwe_n]; + // dbg!(ideal_client_key.sk_lwe.values()); + let secret_el_i = ideal_client_key.sk_lwe.values[index]; + dbg!(secret_el_i); + if secret_el_i < 0 { + si_poly[rlwe_n - secret_el_i.abs() as usize] = p - 1; + } else { + si_poly[secret_el_i.abs() as usize] = 1; + } + let mul = |a: &u64, b: &u64| ((*a as u128 * *b as u128) % p as u128) as u64; + let expected_m = negacyclic_mul(&m, &si_poly, mul, p); + assert_eq!(expected_m, m_back); + // println!("M:{:?}", m); + // println!("M_back:{:?} \n Expected_m:{:?}", m_back, expected_m); + } + + // // PBS + // let mut scratch_lwen_plus1 = vec![0u64; + // bool_evaluator.parameters.lwe_n + 1]; + // let mut scratch_matrix_dplus2_ring = vec![ + // vec![0u64; bool_evaluator.parameters.rlwe_n]; + // bool_evaluator.parameters.d_rgsw + 2 + // ]; + + // let mut m0 = true; + // let mut m1 = false; + + // for _ in 0..100 { + // let lwe0 = bool_evaluator.pk_encrypt(&collective_pk.key, m0); + // let lwe1 = bool_evaluator.pk_encrypt(&collective_pk.key, m1); + + // let lwe_out = bool_evaluator.nand( + // &lwe0, + // &lwe1, + // &server_key_eval, + // &mut scratch_lwen_plus1, + // &mut scratch_matrix_dplus2_ring, + // ); + + // let m_expected = !(m0 & m1); + + // // measure noise + // { + // // Trace PBS + // PBSTracer::with_local(|t| { + // t.trace( + // &MP_BOOL_PARAMS, + // &ideal_client_key.sk_lwe.values(), + // &ideal_client_key.sk_rlwe.values(), + // ) + // }); + + // let ideal_m = if m_expected { + // bool_evaluator.rlweq_by8 + // } else { + // bool_evaluator.parameters.rlwe_q - + // bool_evaluator.rlweq_by8 }; + // let noise = measure_noise_lwe( + // &lwe_out, + // ideal_client_key.sk_rlwe.values(), + // &bool_evaluator.rlwe_modop, + // &ideal_m, + // ); + // println!("Noise: {noise}"); + // } + + // // multi-party decrypt + // // let decryption_shares = parties + // // .iter() + // // .map(|k| + // bool_evaluator.multi_party_decryption_share(&lwe_out, k)) + // // .collect_vec(); + // // let m_back = + // bool_evaluator.multi_party_decrypt(&decryption_shares, // + // &lwe_out); + + // let m_back = bool_evaluator.sk_decrypt(&lwe_out, + // &ideal_client_key); + + // dbg!(m_expected, m_back); + // m1 = m0; + // m0 = m_back; + // } + } } diff --git a/src/bool/mod.rs b/src/bool/mod.rs new file mode 100644 index 0000000..bfa8111 --- /dev/null +++ b/src/bool/mod.rs @@ -0,0 +1,2 @@ +mod evaluator; +mod parameters; diff --git a/src/bool/parameters.rs b/src/bool/parameters.rs new file mode 100644 index 0000000..6c6556c --- /dev/null +++ b/src/bool/parameters.rs @@ -0,0 +1,67 @@ +#[derive(Clone, PartialEq)] +pub(super) struct BoolParameters { + pub(super) rlwe_q: El, + pub(super) rlwe_logq: usize, + pub(super) lwe_q: El, + pub(super) lwe_logq: usize, + pub(super) br_q: usize, + pub(super) rlwe_n: usize, + pub(super) lwe_n: usize, + pub(super) d_rgsw: usize, + pub(super) logb_rgsw: usize, + pub(super) d_lwe: usize, + pub(super) logb_lwe: usize, + pub(super) g: usize, + pub(super) w: usize, +} + +// impl BoolParameters { +// fn rlwe_q(&self) -> &El { +// &self.rlwe_q +// } +// } + +pub(super) const SP_BOOL_PARAMS: BoolParameters = BoolParameters:: { + rlwe_q: 268369921u64, + rlwe_logq: 28, + lwe_q: 1 << 16, + lwe_logq: 16, + br_q: 1 << 10, + rlwe_n: 1 << 10, + lwe_n: 493, + d_rgsw: 3, + logb_rgsw: 8, + d_lwe: 3, + logb_lwe: 4, + g: 5, + w: 1, +}; + +pub(super) const MP_BOOL_PARAMS: BoolParameters = BoolParameters:: { + rlwe_q: 2305843009213616129u64, + rlwe_logq: 61, + lwe_q: 1 << 25, + lwe_logq: 25, + br_q: 1 << 11, + rlwe_n: 1 << 11, + lwe_n: 500, + d_rgsw: 7, + logb_rgsw: 8, + d_lwe: 5, + logb_lwe: 5, + g: 5, + w: 1, +}; + +#[cfg(test)] +mod tests { + use crate::utils::generate_prime; + + #[test] + fn find_prime() { + let bits = 61; + let ring_size = 1 << 11; + let prime = generate_prime(bits, ring_size * 2, 1 << bits).unwrap(); + dbg!(prime); + } +} diff --git a/src/lwe.rs b/src/lwe.rs index 2d3b288..76899b8 100644 --- a/src/lwe.rs +++ b/src/lwe.rs @@ -78,7 +78,7 @@ trait LweCiphertext {} #[derive(Clone)] pub struct LweSecret { - values: Vec, + pub(crate) values: Vec, } impl Secret for LweSecret { diff --git a/src/multi_party.rs b/src/multi_party.rs index 9ae4e09..cda5bf8 100644 --- a/src/multi_party.rs +++ b/src/multi_party.rs @@ -1,25 +1,24 @@ use crate::{ backend::VectorOps, - ntt::{self, Ntt}, + ntt::Ntt, random::{NewWithSeed, RandomGaussianDist, RandomUniformDist}, utils::TryConvertFrom, Matrix, Row, RowEntity, RowMut, }; -fn public_key_share< +pub(crate) fn public_key_share< R: Row + RowMut + RowEntity, S, ModOp: VectorOps, NttOp: Ntt, - Rng: RandomGaussianDist<[R::Element], Parameters = R::Element> - + NewWithSeed - + RandomUniformDist<[R::Element], Parameters = R::Element>, + Rng: RandomGaussianDist<[R::Element], Parameters = R::Element>, + PRng: RandomUniformDist<[R::Element], Parameters = R::Element>, >( share_out: &mut R, s_i: &[S], modop: &ModOp, nttop: &NttOp, - crp_seed: Rng::Seed, + p_rng: &mut PRng, rng: &mut Rng, ) where R: TryConvertFrom<[S], Parameters = R::Element>, @@ -29,10 +28,14 @@ fn public_key_share< let q = modop.modulus(); - // a*s - let mut a = R::zeros(ring_size); - let mut p_rng = Rng::new_with_seed(crp_seed); - RandomUniformDist::random_fill(&mut p_rng, &q, a.as_mut()); + // sample a + let mut a = { + let mut a = R::zeros(ring_size); + RandomUniformDist::random_fill(p_rng, &q, a.as_mut()); + a + }; + + // s*a nttop.forward(a.as_mut()); let mut s = R::try_convert_from(s_i, &q); nttop.forward(s.as_mut()); @@ -42,5 +45,3 @@ fn public_key_share< RandomGaussianDist::random_fill(rng, &q, share_out.as_mut()); modop.elwise_add_mut(share_out.as_mut(), s.as_ref()); // s*e + e } - -fn rlwe_galois_auto_key_share() {} diff --git a/src/rgsw.rs b/src/rgsw.rs index 53b9817..b6515ea 100644 --- a/src/rgsw.rs +++ b/src/rgsw.rs @@ -88,11 +88,16 @@ where } } +pub struct RgswCiphertext { + data: M, + modulus: M::MatElement, +} + pub struct SeededRgswCiphertext where M: Matrix, { - data: M, + pub(crate) data: M, seed: S, modulus: M::MatElement, } @@ -108,7 +113,7 @@ impl SeededRgswCiphertext { } } - fn empty( + pub(crate) fn empty( ring_size: usize, d_rgsw: usize, seed: S, @@ -136,7 +141,7 @@ where } pub struct RgswCiphertextEvaluationDomain { - data: M, + pub(crate) data: M, _phantom: PhantomData<(R, N)>, } @@ -192,6 +197,51 @@ where } } +impl< + M: MatrixMut + MatrixEntity, + R, + N: NttInit + Ntt, + > From<&RgswCiphertext> for RgswCiphertextEvaluationDomain +where + ::R: RowMut, + M::MatElement: Copy, + M: Debug, +{ + fn from(value: &RgswCiphertext) -> Self { + assert!(value.data.dimension().0 % 4 == 0); + let d = value.data.dimension().0.div(4); + + let mut data = M::zeros(4 * d, value.data.dimension().1); + + // copy RLWE'(-sm) + izip!(data.iter_rows_mut().take(2 * d), value.data.iter_rows()).for_each( + |(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }, + ); + + // copy RLWE'(m) + izip!( + data.iter_rows_mut().skip(2 * d), + value.data.iter_rows().skip(2 * d) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); + + // Send polynomials to evaluation domain + let ring_size = data.dimension().1; + let nttop = N::new(value.modulus, ring_size); + data.iter_rows_mut() + .for_each(|ri| nttop.forward(ri.as_mut())); + + Self { + data: data, + _phantom: PhantomData, + } + } +} + impl Debug for RgswCiphertextEvaluationDomain { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("RgswCiphertextEvaluationDomain") @@ -220,13 +270,13 @@ pub struct SeededRlweCiphertext where R: Row, { - data: R, - seed: S, - modulus: R::Element, + pub(crate) data: R, + pub(crate) seed: S, + pub(crate) modulus: R::Element, } impl SeededRlweCiphertext { - fn empty(ring_size: usize, seed: S, modulus: R::Element) -> Self { + pub(crate) fn empty(ring_size: usize, seed: S, modulus: R::Element) -> Self { SeededRlweCiphertext { data: R::zeros(ring_size), seed, @@ -236,8 +286,8 @@ impl SeededRlweCiphertext { } pub struct RlweCiphertext { - data: M, - is_trivial: bool, + pub(crate) data: M, + pub(crate) is_trivial: bool, _phatom: PhantomData, } @@ -316,9 +366,56 @@ pub trait IsTrivial { fn set_not_trivial(&mut self); } +pub struct SeededRlwePublicKey { + data: Ro, + seed: S, + modulus: Ro::Element, +} + +impl SeededRlwePublicKey { + pub(crate) fn empty(ring_size: usize, seed: S, modulus: Ro::Element) -> Self { + Self { + data: Ro::zeros(ring_size), + seed, + modulus, + } + } +} + +pub struct RlwePublicKey { + data: M, + _phantom: PhantomData, +} + +impl< + M: MatrixMut + MatrixEntity, + Rng: NewWithSeed + RandomUniformDist<[M::MatElement], Parameters = M::MatElement>, + > From<&SeededRlwePublicKey> for RlwePublicKey +where + ::R: RowMut, + M::MatElement: Copy, + Rng::Seed: Copy, +{ + fn from(value: &SeededRlwePublicKey) -> Self { + let mut data = M::zeros(2, value.data.as_ref().len()); + + // sample a + let mut p_rng = Rng::new_with_seed(value.seed); + RandomUniformDist::random_fill(&mut p_rng, &value.modulus, data.get_row_mut(0)); + + // copy over b + data.get_row_mut(1).copy_from_slice(value.data.as_ref()); + + Self { + data, + _phantom: PhantomData, + } + } +} + #[derive(Clone)] pub struct RlweSecret { - values: Vec, + pub(crate) values: Vec, } impl Secret for RlweSecret { @@ -490,10 +587,10 @@ pub(crate) fn galois_key_gen< ); } -pub(crate) fn routine>( - write_to_row: &mut [M::MatElement], - matrix_a: &[M::R], - matrix_b: &[M::R], +pub(crate) fn routine>( + write_to_row: &mut [R::Element], + matrix_a: &[R], + matrix_b: &[R], mod_op: &ModOp, ) { izip!(matrix_a.iter(), matrix_b.iter()).for_each(|(a, b)| { @@ -509,13 +606,12 @@ pub(crate) fn routine>( - r: &[M::MatElement], - decomp_r: &mut [M::R], +pub(crate) fn decompose_r>( + r: &[R::Element], + decomp_r: &mut [R], decomposer: &D, ) where - ::R: RowMut, - M::MatElement: Copy, + R::Element: Copy, { let ring_size = r.len(); let d = decomposer.d(); @@ -594,7 +690,7 @@ pub(crate) fn galois_auto< let (ksk_a, ksk_b) = ksk.split_at_row(d); tmp_rlwe_out[0].as_mut().fill(Mmut::MatElement::zero()); // a' = decomp * RLWE'_A(s(X^k)) - routine::( + routine( tmp_rlwe_out[0].as_mut(), scratch_matrix_d_ring, ksk_a, @@ -604,7 +700,7 @@ pub(crate) fn galois_auto< ntt_op.forward(tmp_rlwe_out[1].as_mut()); // b' = b(X^k) // b' += decomp * RLWE'_B(s(X^k)) - routine::( + routine( tmp_rlwe_out[1].as_mut(), scratch_matrix_d_ring, ksk_b, @@ -665,19 +761,19 @@ pub(crate) fn rlwe_by_rgsw< if !rlwe_in.is_trivial() { // a_in = 0 when RLWE_in is trivial RLWE ciphertext // decomp - decompose_r::(rlwe_in.get_row_slice(0), scratch_matrix_d_ring, decomposer); + decompose_r(rlwe_in.get_row_slice(0), scratch_matrix_d_ring, decomposer); scratch_matrix_d_ring .iter_mut() .for_each(|r| ntt_op.forward(r.as_mut())); // a_out += decomp \cdot RLWE_A'(-sm) - routine::( + routine( scratch_rlwe_out[0].as_mut(), scratch_matrix_d_ring.as_ref(), &rlwe_dash_nsm[..d_rgsw], mod_op, ); // b_out += decomp \cdot RLWE_B'(-sm) - routine::( + routine( scratch_rlwe_out[1].as_mut(), scratch_matrix_d_ring.as_ref(), &rlwe_dash_nsm[d_rgsw..], @@ -685,19 +781,19 @@ pub(crate) fn rlwe_by_rgsw< ); } // decomp - decompose_r::(rlwe_in.get_row_slice(1), scratch_matrix_d_ring, decomposer); + decompose_r(rlwe_in.get_row_slice(1), scratch_matrix_d_ring, decomposer); scratch_matrix_d_ring .iter_mut() .for_each(|r| ntt_op.forward(r.as_mut())); // a_out += decomp \cdot RLWE_A'(m) - routine::( + routine( scratch_rlwe_out[0].as_mut(), scratch_matrix_d_ring.as_ref(), &rlwe_dash_m[..d_rgsw], mod_op, ); // b_out += decomp \cdot RLWE_B'(m) - routine::( + routine( scratch_rlwe_out[1].as_mut(), scratch_matrix_d_ring.as_ref(), &rlwe_dash_m[d_rgsw..], @@ -718,12 +814,110 @@ pub(crate) fn rlwe_by_rgsw< rlwe_in.set_not_trivial(); } +/// Inplace mutates rlwe_0_eval_domain to equal RGSW(m0m1) = RGSW(m0)xRGSW(m1) +/// in evaluation domain +/// +/// - rgsw_0_eval_domain: RGSW(m0) in evaluation domain +/// - rgsw_1: RGSW(m1) +/// - scratch_matrix_d_plus_rgsw_by_ring: scratch space matrix of size +/// (d+(d*4))xring_size, where d equals d_rgsw +pub(crate) fn rgsw_by_rgsw_inplace< + Mmut: MatrixMut, + D: Decomposer, + ModOp: VectorOps, + NttOp: Ntt, +>( + rgsw_0_eval_domain: &mut Mmut, + rgsw_1: &Mmut, + decomposer: &D, + scratch_matrix_d_plus_rgsw_by_ring: &mut Mmut, + ntt_op: &NttOp, + mod_op: &ModOp, +) where + ::R: RowMut, + Mmut::MatElement: Copy + Zero, +{ + let d_rgsw = decomposer.d(); + assert!(rgsw_0_eval_domain.dimension().0 == 4 * d_rgsw); + let ring_size = rgsw_0_eval_domain.dimension().1; + assert!(rgsw_1.dimension() == (4 * d_rgsw, ring_size)); + assert!(scratch_matrix_d_plus_rgsw_by_ring.dimension() == (d_rgsw + (d_rgsw * 4), ring_size)); + + let (decomp_r_space, rgsw_space) = scratch_matrix_d_plus_rgsw_by_ring.split_at_row_mut(d_rgsw); + + // zero rgsw_space + rgsw_space + .iter_mut() + .for_each(|ri| ri.as_mut().fill(Mmut::MatElement::zero())); + let (rlwe_dash_space_nsm, rlwe_dash_space_m) = rgsw_space.split_at_mut(d_rgsw * 2); + let (rlwe_dash_space_nsm_parta, rlwe_dash_space_nsm_partb) = + rlwe_dash_space_nsm.split_at_mut(d_rgsw); + let (rlwe_dash_space_m_parta, rlwe_dash_space_m_partb) = rlwe_dash_space_m.split_at_mut(d_rgsw); + + let (rgsw0_nsm, rgsw0_m) = rgsw_0_eval_domain.split_at_row(d_rgsw * 2); + let (rgsw1_nsm, rgsw1_m) = rgsw_1.split_at_row(d_rgsw * 2); + + // RGSW x RGSW + izip!( + rgsw1_nsm.iter().take(d_rgsw).chain(rgsw1_m).take(d_rgsw), + rgsw1_nsm.iter().skip(d_rgsw).chain(rgsw1_m).skip(d_rgsw), + rlwe_dash_space_nsm_parta + .iter_mut() + .chain(rlwe_dash_space_m_parta), + rlwe_dash_space_nsm_partb + .iter_mut() + .chain(rlwe_dash_space_m_partb), + ) + .for_each(|(rlwe_a, rlwe_b, rlwe_out_a, rlwe_out_b)| { + // Part A + decompose_r(rlwe_a.as_ref(), decomp_r_space.as_mut(), decomposer); + decomp_r_space + .iter_mut() + .for_each(|ri| ntt_op.forward(ri.as_mut())); + routine( + rlwe_out_a.as_mut(), + decomp_r_space, + &rgsw0_nsm[..d_rgsw], + mod_op, + ); + routine( + rlwe_out_b.as_mut(), + decomp_r_space, + &rgsw0_nsm[d_rgsw..], + mod_op, + ); + + // Part B + decompose_r(rlwe_b.as_ref(), decomp_r_space.as_mut(), decomposer); + decomp_r_space + .iter_mut() + .for_each(|ri| ntt_op.forward(ri.as_mut())); + routine( + rlwe_out_a.as_mut(), + decomp_r_space, + &rgsw0_m[..d_rgsw], + mod_op, + ); + routine( + rlwe_out_b.as_mut(), + decomp_r_space, + &rgsw0_m[d_rgsw..], + mod_op, + ); + }); + + // copy over RGSW(m0m1) into RGSW(m0) + izip!(rgsw_0_eval_domain.iter_rows_mut(), rgsw_space.iter()) + .for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref())) +} + /// Encrypts message m as a RGSW ciphertext. /// /// - m_eval: is `m` is evaluation domain /// - out_rgsw: RGSW(m) is stored as single matrix of dimension (d_rgsw * 3, /// ring_size). The matrix has the following structure [RLWE'_A(-sm) || -/// RLWE'_B(-sm) || RLWE'_B(m)]^T and RLWE'_A(m) is generated via seed +/// RLWE'_B(-sm) || RLWE'_B(m)]^T and RLWE'_A(m) is generated via seed (where +/// p_rng is assumed to be seeded with seed) pub(crate) fn secret_key_encrypt_rgsw< Mmut: MatrixMut + MatrixEntity, S, @@ -816,6 +1010,123 @@ pub(crate) fn secret_key_encrypt_rgsw< }); } +pub(crate) fn public_key_encrypt_rgsw< + Mmut: MatrixMut + MatrixEntity, + M: Matrix, + R: RandomGaussianDist<[Mmut::MatElement], Parameters = Mmut::MatElement> + + RandomUniformDist<[u8], Parameters = u8> + + RandomUniformDist, + ModOp: VectorOps, + NttOp: Ntt, +>( + out_rgsw: &mut Mmut, + m: &[M::MatElement], + public_key: &M, + gadget_vector: &[Mmut::MatElement], + mod_op: &ModOp, + ntt_op: &NttOp, + rng: &mut R, +) where + ::R: RowMut + RowEntity + TryConvertFrom<[i32], Parameters = Mmut::MatElement>, + Mmut::MatElement: Copy, +{ + let ring_size = public_key.dimension().1; + let d = gadget_vector.len(); + assert!(public_key.dimension().0 == 2); + assert!(out_rgsw.dimension() == (d * 4, ring_size)); + + let mut pk_eval = Mmut::zeros(2, ring_size); + izip!(pk_eval.iter_rows_mut(), public_key.iter_rows()).for_each(|(to_i, from_i)| { + to_i.as_mut().copy_from_slice(from_i.as_ref()); + ntt_op.forward(to_i.as_mut()); + }); + let p0 = pk_eval.get_row_slice(0); + let p1 = pk_eval.get_row_slice(1); + + let q = mod_op.modulus(); + + // RGSW(m) = RLWE'(-sm), RLWE(m) + let (rlwe_dash_nsm, rlwe_dash_m) = out_rgsw.split_at_row_mut(2 * d); + + // RLWE(-sm) + let (rlwe_dash_nsm_parta, rlwe_dash_nsm_partb) = rlwe_dash_nsm.split_at_mut(d); + izip!( + rlwe_dash_nsm_parta.iter_mut(), + rlwe_dash_nsm_partb.iter_mut(), + gadget_vector.iter() + ) + .for_each(|(ai, bi, beta_i)| { + // sample ephemeral secret u_i + let mut u = vec![0i32; ring_size]; + fill_random_ternary_secret_with_hamming_weight(u.as_mut(), ring_size >> 1, rng); + let mut u_eval = Mmut::R::try_convert_from(u.as_ref(), &q); + ntt_op.forward(u_eval.as_mut()); + + let mut u_eval_copy = Mmut::R::zeros(ring_size); + u_eval_copy.as_mut().copy_from_slice(u_eval.as_ref()); + + // p0 * u + mod_op.elwise_mul_mut(u_eval.as_mut(), p0.as_ref()); + // p1 * u + mod_op.elwise_mul_mut(u_eval_copy.as_mut(), p1.as_ref()); + ntt_op.backward(u_eval.as_mut()); + ntt_op.backward(u_eval_copy.as_mut()); + + // sample error + RandomGaussianDist::random_fill(rng, &q, ai.as_mut()); + RandomGaussianDist::random_fill(rng, &q, bi.as_mut()); + + // a = p0*u+e0 + mod_op.elwise_add_mut(ai.as_mut(), u_eval.as_ref()); + // b = p1*u+e1 + mod_op.elwise_add_mut(bi.as_mut(), u_eval_copy.as_ref()); + + // a = p0*u + e0 + \beta*m + // use u_eval as scratch + mod_op.elwise_scalar_mul(u_eval.as_mut(), m.as_ref(), beta_i); + mod_op.elwise_add_mut(ai.as_mut(), u_eval.as_ref()); + }); + + // RLWE(m) + let (rlwe_dash_m_parta, rlwe_dash_m_partb) = rlwe_dash_m.split_at_mut(d); + izip!( + rlwe_dash_m_parta.iter_mut(), + rlwe_dash_m_partb.iter_mut(), + gadget_vector.iter() + ) + .for_each(|(ai, bi, beta_i)| { + // sample ephemeral secret u_i + let mut u = vec![0i32; ring_size]; + fill_random_ternary_secret_with_hamming_weight(u.as_mut(), ring_size >> 1, rng); + let mut u_eval = Mmut::R::try_convert_from(u.as_ref(), &q); + ntt_op.forward(u_eval.as_mut()); + + let mut u_eval_copy = Mmut::R::zeros(ring_size); + u_eval_copy.as_mut().copy_from_slice(u_eval.as_ref()); + + // p0 * u + mod_op.elwise_mul_mut(u_eval.as_mut(), p0.as_ref()); + // p1 * u + mod_op.elwise_mul_mut(u_eval_copy.as_mut(), p1.as_ref()); + ntt_op.backward(u_eval.as_mut()); + ntt_op.backward(u_eval_copy.as_mut()); + + // sample error + RandomGaussianDist::random_fill(rng, &q, ai.as_mut()); + RandomGaussianDist::random_fill(rng, &q, bi.as_mut()); + + // a = p0*u+e0 + mod_op.elwise_add_mut(ai.as_mut(), u_eval.as_ref()); + // b = p1*u+e1 + mod_op.elwise_add_mut(bi.as_mut(), u_eval_copy.as_ref()); + + // b = p1*u + e0 + \beta*m + // use u_eval as scratch + mod_op.elwise_scalar_mul(u_eval.as_mut(), m.as_ref(), beta_i); + mod_op.elwise_add_mut(bi.as_mut(), u_eval.as_ref()); + }); +} + /// Encrypt polynomial m(X) as RLWE ciphertext. /// /// - rlwe_out: returned RLWE ciphertext RLWE(m) in coefficient domain. RLWE @@ -865,6 +1176,48 @@ pub(crate) fn secret_key_encrypt_rlwe< mod_op.elwise_add_mut(b_rlwe_out.as_mut(), sa.as_ref()); } +/// Generates RLWE public key +pub(crate) fn gen_rlwe_public_key< + Ro: RowMut + RowEntity, + S, + ModOp: VectorOps, + NttOp: Ntt, + PRng: RandomUniformDist<[Ro::Element], Parameters = Ro::Element>, + Rng: RandomGaussianDist<[Ro::Element], Parameters = Ro::Element>, +>( + part_b_out: &mut Ro, + s: &[S], + ntt_op: &NttOp, + mod_op: &ModOp, + p_rng: &mut PRng, + rng: &mut Rng, +) where + Ro: TryConvertFrom<[S], Parameters = Ro::Element>, +{ + let ring_size = s.len(); + assert!(part_b_out.as_ref().len() == ring_size); + + let q = mod_op.modulus(); + + // sample a + let mut a = { + let mut tmp = Ro::zeros(ring_size); + RandomUniformDist::random_fill(p_rng, &q, tmp.as_mut()); + tmp + }; + ntt_op.forward(a.as_mut()); + + // s*a + let mut sa = Ro::try_convert_from(s, &q); + ntt_op.forward(sa.as_mut()); + mod_op.elwise_mul_mut(sa.as_mut(), a.as_ref()); + ntt_op.backward(sa.as_mut()); + + // s*a + e + RandomGaussianDist::random_fill(rng, &q, part_b_out.as_mut()); + mod_op.elwise_add_mut(part_b_out.as_mut(), sa.as_ref()); +} + /// Decrypts degree 1 RLWE ciphertext RLWE(m) and returns m /// /// - rlwe_ct: input degree 1 ciphertext RLWE(m). @@ -965,27 +1318,28 @@ where #[cfg(test)] mod tests { - use std::vec; + use std::{ops::Mul, vec}; use itertools::{izip, Itertools}; use rand::{thread_rng, Rng}; use crate::{ - backend::{ModInit, ModularOpsU64}, + backend::{ArithmeticOps, ModInit, ModularOpsU64}, decomposer::{gadget_vector, DefaultDecomposer}, ntt::{self, Ntt, NttBackendU64, NttInit}, random::{DefaultSecureRng, NewWithSeed, RandomUniformDist}, rgsw::{ - measure_noise, AutoKeyEvaluationDomain, RgswCiphertextEvaluationDomain, RlweCiphertext, - SeededAutoKey, SeededRgswCiphertext, SeededRlweCiphertext, + gen_rlwe_public_key, measure_noise, public_key_encrypt_rgsw, AutoKeyEvaluationDomain, + RgswCiphertext, RgswCiphertextEvaluationDomain, RlweCiphertext, RlwePublicKey, + SeededAutoKey, SeededRgswCiphertext, SeededRlweCiphertext, SeededRlwePublicKey, }, utils::{generate_prime, negacyclic_mul}, Matrix, Secret, }; use super::{ - decrypt_rlwe, galois_auto, galois_key_gen, generate_auto_map, rlwe_by_rgsw, - secret_key_encrypt_rgsw, secret_key_encrypt_rlwe, RlweSecret, + decrypt_rlwe, galois_auto, galois_key_gen, generate_auto_map, rgsw_by_rgsw_inplace, + rlwe_by_rgsw, secret_key_encrypt_rgsw, secret_key_encrypt_rlwe, RlweSecret, }; #[test] @@ -1046,13 +1400,13 @@ mod tests { #[test] fn rlwe_by_rgsw_works() { - let logq = 24; + let logq = 50; let logp = 2; - let ring_size = 1 << 4; + let ring_size = 1 << 9; let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap(); let p = 1u64 << logp; - let d_rgsw = 2; - let logb = 12; + let d_rgsw = 9; + let logb = 5; let mut rng = DefaultSecureRng::new_seeded([0u8; 32]); @@ -1067,51 +1421,99 @@ mod tests { let mod_op = ModularOpsU64::new(q); // Encrypt m1 as RGSW(m1) - let mut rgsw_seed = [0u8; 32]; - rng.fill_bytes(&mut rgsw_seed); - let mut seeded_rgsw_ct = SeededRgswCiphertext::>, [u8; 32]>::empty( - ring_size as usize, - d_rgsw, - rgsw_seed, - q, - ); - let mut p_rng = DefaultSecureRng::new_seeded(rgsw_seed); - let gadget_vector = gadget_vector(logq, logb, d_rgsw); - secret_key_encrypt_rgsw( - &mut seeded_rgsw_ct.data, - &m1, - &gadget_vector, - s.values(), - &mod_op, - &ntt_op, - &mut p_rng, - &mut rng, - ); - let rgsw_ct = RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from( - &seeded_rgsw_ct, - ); + let rgsw_ct = { + //TODO(Jay): Figure out better way to test secret key and public key variant of + // RGSW ciphertext encryption within the same test + + if false { + // RGSW(m1) encryption using secret key + let mut rgsw_seed = [0u8; 32]; + rng.fill_bytes(&mut rgsw_seed); + let mut seeded_rgsw_ct = SeededRgswCiphertext::>, [u8; 32]>::empty( + ring_size as usize, + d_rgsw, + rgsw_seed, + q, + ); + let mut p_rng = DefaultSecureRng::new_seeded(rgsw_seed); + let gadget_vector = gadget_vector(logq, logb, d_rgsw); + secret_key_encrypt_rgsw( + &mut seeded_rgsw_ct.data, + &m1, + &gadget_vector, + s.values(), + &mod_op, + &ntt_op, + &mut p_rng, + &mut rng, + ); + RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from( + &seeded_rgsw_ct, + ) + } else { + // RGSW(m1) encryption using public key + + // first create public key + let mut pk_seed = [0u8; 32]; + rng.fill_bytes(&mut pk_seed); + let mut pk_prng = DefaultSecureRng::new_seeded(pk_seed); + let mut seeded_pk = + SeededRlwePublicKey::, _>::empty(ring_size as usize, pk_seed, q); + gen_rlwe_public_key( + &mut seeded_pk.data, + s.values(), + &ntt_op, + &mod_op, + &mut pk_prng, + &mut rng, + ); + let pk = RlwePublicKey::>, DefaultSecureRng>::from(&seeded_pk); + + // public key encrypt RGSW(m1) + let mut rgsw_ct = vec![vec![0u64; ring_size as usize]; d_rgsw * 4]; + let gadget_vector = gadget_vector(logq, logb, d_rgsw); + public_key_encrypt_rgsw( + &mut rgsw_ct, + &m1, + &pk.data, + &gadget_vector, + &mod_op, + &ntt_op, + &mut rng, + ); + + RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from( + &RgswCiphertext { + data: rgsw_ct, + modulus: q, + }, + ) + } + }; // Encrypt m0 as RLWE(m0) - let mut rlwe_seed = [0u8; 32]; - rng.fill_bytes(&mut rlwe_seed); - let mut seeded_rlwe_in_ct = - SeededRlweCiphertext::<_, [u8; 32]>::empty(ring_size as usize, rlwe_seed, q); - let mut p_rng = DefaultSecureRng::new_seeded(rlwe_seed); - let encoded_m = m0 - .iter() - .map(|v| (((*v as f64) * q as f64) / (p as f64)).round() as u64) - .collect_vec(); - secret_key_encrypt_rlwe( - &encoded_m, - &mut seeded_rlwe_in_ct.data, - s.values(), - &mod_op, - &ntt_op, - &mut p_rng, - &mut rng, - ); - let mut rlwe_in_ct = - RlweCiphertext::>, DefaultSecureRng>::from(&seeded_rlwe_in_ct); + let mut rlwe_in_ct = { + let mut rlwe_seed = [0u8; 32]; + rng.fill_bytes(&mut rlwe_seed); + let mut seeded_rlwe_in_ct = + SeededRlweCiphertext::<_, [u8; 32]>::empty(ring_size as usize, rlwe_seed, q); + let mut p_rng = DefaultSecureRng::new_seeded(rlwe_seed); + let encoded_m = m0 + .iter() + .map(|v| (((*v as f64) * q as f64) / (p as f64)).round() as u64) + .collect_vec(); + secret_key_encrypt_rlwe( + &encoded_m, + &mut seeded_rlwe_in_ct.data, + s.values(), + &mod_op, + &ntt_op, + &mut p_rng, + &mut rng, + ); + + RlweCiphertext::>, DefaultSecureRng>::from(&seeded_rlwe_in_ct) + }; // RLWE(m0m1) = RLWE(m0) x RGSW(m1) let mut scratch_space = vec![vec![0u64; ring_size as usize]; d_rgsw + 2]; @@ -1149,6 +1551,113 @@ mod tests { ); } + fn _pk_encrypt_rgsw( + m: &[u64], + public_key: &RlwePublicKey>, DefaultSecureRng>, + gadget_vector: &[u64], + mod_op: &ModularOpsU64, + ntt_op: &NttBackendU64, + ) -> RgswCiphertext>> { + let (_, ring_size) = Matrix::dimension(&public_key.data); + let d_rgsw = gadget_vector.len(); + + let mut rng = DefaultSecureRng::new(); + + assert!(m.len() == ring_size); + + // public key encrypt RGSW(m1) + let mut rgsw_ct = vec![vec![0u64; ring_size]; d_rgsw * 4]; + public_key_encrypt_rgsw( + &mut rgsw_ct, + m, + &public_key.data, + gadget_vector, + mod_op, + ntt_op, + &mut rng, + ); + + RgswCiphertext { + data: rgsw_ct, + modulus: mod_op.modulus(), + } + } + + #[test] + fn rgsw_by_rgsw() { + let logq = 50; + let logp = 2; + let ring_size = 1 << 4; + let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap(); + let p = 1u64 << logp; + let d_rgsw = 10; + let logb = 5; + + let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize); + + let mut rng = DefaultSecureRng::new(); + let ntt_op = NttBackendU64::new(q, ring_size as usize); + let mod_op = ModularOpsU64::new(q); + let gadget_vector = gadget_vector(logq, logb, d_rgsw); + let decomposer = DefaultDecomposer::new(q, logb, d_rgsw); + + // Public Key + let public_key = { + let mut pk_seed = [0u8; 32]; + rng.fill_bytes(&mut pk_seed); + let mut pk_prng = DefaultSecureRng::new_seeded(pk_seed); + let mut seeded_pk = + SeededRlwePublicKey::, _>::empty(ring_size as usize, pk_seed, q); + gen_rlwe_public_key( + &mut seeded_pk.data, + s.values(), + &ntt_op, + &mod_op, + &mut pk_prng, + &mut rng, + ); + RlwePublicKey::>, DefaultSecureRng>::from(&seeded_pk) + }; + + let mut m0 = vec![0u64; ring_size as usize]; + m0[thread_rng().gen_range(0..ring_size) as usize] = 1; + let mut m1 = vec![0u64; ring_size as usize]; + m1[thread_rng().gen_range(0..ring_size) as usize] = 1; + + // RGSW(m0) + let rgsw_m0 = _pk_encrypt_rgsw(&m0, &public_key, &gadget_vector, &mod_op, &ntt_op); + // RGSW(m1) + let rgsw_m1 = _pk_encrypt_rgsw(&m0, &public_key, &gadget_vector, &mod_op, &ntt_op); + + let mut rgsw_m0_eval = + RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from(&rgsw_m0); + + let mut scratch_matrix_d_plus_rgsw_by_ring = + vec![vec![0u64; ring_size as usize]; d_rgsw + (d_rgsw * 4)]; + rgsw_by_rgsw_inplace( + &mut rgsw_m0_eval.data, + &rgsw_m1.data, + &decomposer, + &mut scratch_matrix_d_plus_rgsw_by_ring, + &ntt_op, + &mod_op, + ); + dbg!(&rgsw_m0_eval.data); + + // RLWE(m0m1) + let mut rlwe_m0m1 = vec![vec![0u64; ring_size as usize]; 2]; + rlwe_m0m1[0].copy_from_slice(rgsw_m0_eval.get_row_slice(2 * d_rgsw)); + rlwe_m0m1[1].copy_from_slice(rgsw_m0_eval.get_row_slice(3 * d_rgsw)); + rlwe_m0m1.iter_mut().for_each(|ri| ntt_op.backward(ri)); + + // m0m1 + let mul = |a: &u64, b: &u64| ((*a as u128 * *b as u128) % q as u128) as u64; + let m0m1 = negacyclic_mul(&m0, &m1, mul, q); + + let noise = measure_noise(&rlwe_m0m1, &m0m1, &ntt_op, &mod_op, s.values()); + dbg!(noise); + } + #[test] fn galois_auto_works() { let logq = 50;