diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index fb66b6b..47bcd9b 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -1,5 +1,5 @@ use std::{ - cell::RefCell, + cell::{OnceCell, RefCell}, collections::HashMap, fmt::{Debug, Display}, marker::PhantomData, @@ -10,15 +10,16 @@ use num_traits::{FromPrimitive, Num, One, PrimInt, ToPrimitive, WrappingSub, Zer use crate::{ backend::{ArithmeticOps, ModInit, ModularOpsU64, VectorOps}, - decomposer::{Decomposer, DefaultDecomposer, NumInfo}, + bool::parameters::{MP_BOOL_PARAMS, SP_BOOL_PARAMS}, + decomposer::{Decomposer, DefaultDecomposer, NumInfo, RlweDecomposer}, lwe::{decrypt_lwe, encrypt_lwe, lwe_key_switch, lwe_ksk_keygen, measure_noise_lwe, LweSecret}, multi_party::public_key_share, ntt::{self, Ntt, NttBackendU64, NttInit}, random::{DefaultSecureRng, NewWithSeed, RandomGaussianDist, RandomUniformDist}, rgsw::{ decrypt_rlwe, galois_auto, galois_key_gen, generate_auto_map, public_key_encrypt_rgsw, - rgsw_by_rgsw_inplace, rlwe_by_rgsw, secret_key_encrypt_rgsw, IsTrivial, RlweCiphertext, - RlweSecret, + rgsw_by_rgsw_inplace, rlwe_by_rgsw, secret_key_encrypt_rgsw, IsTrivial, RgswCiphertext, + RlweCiphertext, RlweSecret, }, utils::{ fill_random_ternary_secret_with_hamming_weight, generate_prime, mod_exponent, @@ -27,7 +28,65 @@ use crate::{ Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, }; -use super::parameters::{self, BoolParameters}; +use super::parameters::BoolParameters; + +thread_local! { + static BOOL_EVALUATOR: RefCell>, NttBackendU64, ModularOpsU64>> = RefCell::new(BoolEvaluator::new(MP_BOOL_PARAMS)); +} + +pub fn set_parameter_set(parameter: &BoolParameters) { + BoolEvaluator::with_local_mut(|e| *e = BoolEvaluator::new(parameter.clone())) +} + +impl WithLocal for BoolEvaluator>, NttBackendU64, ModularOpsU64> { + fn with_local(func: F) -> R + where + F: Fn(&Self) -> R, + { + BOOL_EVALUATOR.with_borrow(|s| func(s)) + } + + fn with_local_mut(func: F) -> R + where + F: Fn(&mut Self) -> R, + { + BOOL_EVALUATOR.with_borrow_mut(|s| func(s)) + } +} + +struct ScratchMemory +where + M: Matrix, +{ + lwe_vector: M::R, + decomposition_matrix: M, +} + +impl ScratchMemory +where + M::R: RowEntity, +{ + fn new(parameters: &BoolParameters) -> Self { + // Vector to store LWE ciphertext with LWE dimesnion n + let lwe_vector = M::R::zeros(parameters.lwe_n().0 + 1); + + // Matrix to store decomposed polynomials + // Max decompistion count + space for temporary RLWE + let d = std::cmp::max( + parameters.auto_decomposition_count().0, + std::cmp::max( + parameters.rlwe_rgsw_decomposition_count().0 .0, + parameters.rlwe_rgsw_decomposition_count().1 .0, + ), + ) + 2; + let decomposition_matrix = M::zeros(d, parameters.rlwe_n().0); + + Self { + lwe_vector, + decomposition_matrix, + } + } +} // thread_local! { // pub(crate) static CLIENT_KEY: RefCell = @@ -44,22 +103,31 @@ trait PbsKey { fn lwe_ksk(&self) -> &Self::M; } -trait PbsParameters { +trait PbsInfo { type Element; + type ModOp: VectorOps + ArithmeticOps; + type NttOp: Ntt; type D: Decomposer; fn rlwe_q(&self) -> Self::Element; fn lwe_q(&self) -> Self::Element; fn br_q(&self) -> usize; - fn d_rgsw(&self) -> usize; - fn d_lwe(&self) -> usize; fn rlwe_n(&self) -> usize; fn lwe_n(&self) -> usize; /// Embedding fator for ring X^{q}+1 inside fn embedding_factor(&self) -> usize; /// generator g fn g(&self) -> isize; - fn decomoposer_lwe(&self) -> &Self::D; - fn decomoposer_rlwe(&self) -> &Self::D; + /// Decomposers + fn lwe_decomposer(&self) -> &Self::D; + fn rlwe_rgsw_decomposer(&self) -> &(Self::D, Self::D); + fn auto_decomposer(&self) -> &Self::D; + + /// Modulus operators + fn modop_lweq(&self) -> &Self::ModOp; + fn modop_rlweq(&self) -> &Self::ModOp; + + /// Ntt operators + fn nttop_rlweq(&self) -> &Self::NttOp; /// Maps a \in Z^*_{q} to discrete log k, with generator g (i.e. g^k = /// a). Returned vector is of size q that stores dlog of a at `vec[a]`. @@ -187,12 +255,11 @@ struct SeededMultiPartyServerKey { fn aggregate_multi_party_server_key_shares< M: MatrixMut + MatrixEntity, S: Copy + PartialEq, - D: Decomposer, + D: RlweDecomposer, ModOp: VectorOps + ModInit, NttOp: Ntt + NttInit, >( shares: &[CommonReferenceSeededMultiPartyServerKeyShare, S>], - d_rgsw_decomposer: &D, ) -> SeededMultiPartyServerKey> where ::R: RowMut + RowEntity, @@ -239,9 +306,17 @@ where // rgsw ciphertext (most expensive part!) let lwe_n = parameters.lwe_n().0; - let mut scratch_d_plus_rgsw_by_ring = M::zeros(d_rgsw + (d_rgsw * 4), rlwe_n); + let rgsw_by_rgsw_decomposer = parameters.rgsw_rgsw_decomposer::(); + let mut scratch_matrix = M::zeros( + std::cmp::max( + rgsw_by_rgsw_decomposer.a().decomposition_count(), + rgsw_by_rgsw_decomposer.b().decomposition_count(), + ) + (rgsw_by_rgsw_decomposer.a().decomposition_count() * 2 + + rgsw_by_rgsw_decomposer.b().decomposition_count() * 2), + rlwe_n, + ); - let mut tmp_rgsw = M::zeros(d_rgsw * 2 * 2, rlwe_n); + let mut tmp_rgsw = RgswCiphertext::::empty(rlwe_n, &rgsw_by_rgsw_decomposer, rlwe_q).data; let rgsw_cts = (0..lwe_n) .into_iter() .map(|index| { @@ -261,8 +336,8 @@ where rgsw_by_rgsw_inplace( &mut rgsw_i, &tmp_rgsw, - d_rgsw_decomposer, - &mut scratch_d_plus_rgsw_by_ring, + &rgsw_by_rgsw_decomposer, + &mut scratch_matrix, &rlweq_nttop, &rlweq_modop, ); @@ -523,21 +598,78 @@ where } // rgsw cts + let (rgswrgsw_d_a, rgswrgsw_d_b) = value.parameters.rgsw_rgsw_decomposition_count(); let (rlrg_d_a, rlrg_d_b) = value.parameters.rlwe_rgsw_decomposition_count(); - let rgsw_ct_rows = rlrg_d_a.0 * 2 + rlrg_d_b.0 * 2; + // d_a and d_b may differ for RGSWxRGSW multiplication and RLWExRGSW + // multiplication. After this point RGSW ciphertexts will only be used for + // RLWExRGSW multiplication (in blind rotation). Thus we drop any additional + // RLWE ciphertexts in incoming collective RGSW ciphertexts, which are + // result of RGSWxRGSW multiplications. + let rgsw_ct_rows_in = rgswrgsw_d_a.0 * 2 + rgswrgsw_d_b.0 * 2; + let rgsw_ct_rows_out = rlrg_d_a.0 * 2 + rlrg_d_b.0 * 2; let rgsw_cts = value .rgsw_cts .iter() - .map(|ct_i| { - assert!(ct_i.dimension() == (rgsw_ct_rows, rlwe_n)); - let mut eval_ct_i = M::zeros(rgsw_ct_rows, rlwe_n); + .map(|ct_i_in| { + assert!(ct_i_in.dimension() == (rgsw_ct_rows_in, rlwe_n)); + let mut eval_ct_i_out = M::zeros(rgsw_ct_rows_out, rlwe_n); - izip!(eval_ct_i.iter_rows_mut(), ct_i.iter_rows()).for_each(|(to_ri, from_ri)| { + // RLWE'(-sm) part A + izip!( + eval_ct_i_out.iter_rows_mut().take(rlrg_d_a.0), + ct_i_in.iter_rows().take(rlrg_d_a.0) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + rlwe_nttop.forward(to_ri.as_mut()); + }); + + // RLWE'(-sm) part B + izip!( + eval_ct_i_out + .iter_rows_mut() + .skip(rlrg_d_a.0) + .take(rlrg_d_a.0), + ct_i_in.iter_rows().skip(rgswrgsw_d_a.0).take(rlrg_d_a.0) + ) + .for_each(|(to_ri, from_ri)| { to_ri.as_mut().copy_from_slice(from_ri.as_ref()); rlwe_nttop.forward(to_ri.as_mut()); }); - eval_ct_i + // RLWE'(m) Part A + izip!( + eval_ct_i_out + .iter_rows_mut() + .skip(rlrg_d_a.0 * 2) + .take(rlrg_d_b.0), + ct_i_in + .iter_rows() + .skip(rgswrgsw_d_a.0 * 2) + .take(rlrg_d_b.0) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + rlwe_nttop.forward(to_ri.as_mut()); + }); + + // RLWE'(m) Part B + izip!( + eval_ct_i_out + .iter_rows_mut() + .skip(rlrg_d_a.0 * 2 + rlrg_d_b.0) + .take(rlrg_d_b.0), + ct_i_in + .iter_rows() + .skip(rgswrgsw_d_a.0 * 2 + rgswrgsw_d_b.0) + .take(rlrg_d_b.0) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + rlwe_nttop.forward(to_ri.as_mut()); + }); + + eval_ct_i_out }) .collect_vec(); @@ -575,13 +707,13 @@ impl PbsKey for ServerKeyEvaluationDomain { } } -struct BoolEvaluator -where - M: Matrix, -{ - parameters: BoolParameters, - decomposer_rlwe: DefaultDecomposer, - decomposer_lwe: DefaultDecomposer, +struct BoolPbsInfo { + auto_decomposer: DefaultDecomposer, + rlwe_rgsw_decomposer: ( + DefaultDecomposer, + DefaultDecomposer, + ), + lwe_decomposer: DefaultDecomposer, g_k_dlog_map: Vec, rlwe_nttop: Ntt, rlwe_modop: ModOp, @@ -591,10 +723,83 @@ where rlweq_by8: M::MatElement, rlwe_qby4: M::MatElement, rlwe_auto_maps: Vec<(Vec, Vec)>, + parameters: BoolParameters, +} + +impl PbsInfo for BoolPbsInfo +where + M::MatElement: PrimInt + WrappingSub + NumInfo + Debug, + ModOp: ArithmeticOps + VectorOps, + NttOp: Ntt, +{ + type Element = M::MatElement; + type D = DefaultDecomposer; + type ModOp = ModOp; + type NttOp = NttOp; + fn rlwe_auto_map(&self, k: isize) -> &(Vec, Vec) { + let g = self.parameters.g() as isize; + if k == g { + &self.rlwe_auto_maps[0] + } else if k == -g { + &self.rlwe_auto_maps[1] + } else { + panic!("RLWE auto map only supports k in [-g, g], but got k={k}"); + } + } + fn br_q(&self) -> usize { + self.parameters.br_q().0.to_usize().unwrap() + } + fn lwe_decomposer(&self) -> &Self::D { + &self.lwe_decomposer + } + fn rlwe_rgsw_decomposer(&self) -> &(Self::D, Self::D) { + &self.rlwe_rgsw_decomposer + } + fn auto_decomposer(&self) -> &Self::D { + &self.auto_decomposer + } + fn embedding_factor(&self) -> usize { + self.embedding_factor + } + fn g(&self) -> isize { + self.parameters.g() as isize + } + fn g_k_dlog_map(&self) -> &[usize] { + &self.g_k_dlog_map + } + fn lwe_n(&self) -> usize { + self.parameters.lwe_n().0 + } + fn lwe_q(&self) -> Self::Element { + self.parameters.lwe_q().0 + } + fn rlwe_n(&self) -> usize { + self.parameters.rlwe_n().0 + } + fn rlwe_q(&self) -> Self::Element { + self.parameters.rlwe_q().0 + } + fn modop_lweq(&self) -> &Self::ModOp { + &self.lwe_modop + } + fn modop_rlweq(&self) -> &Self::ModOp { + &self.rlwe_modop + } + fn nttop_rlweq(&self) -> &Self::NttOp { + &self.rlwe_nttop + } +} + +struct BoolEvaluator +where + M: Matrix, +{ + pbs_info: BoolPbsInfo, + scratch_memory: ScratchMemory, _phantom: PhantomData, } -impl BoolEvaluator +impl BoolEvaluator where NttOp: NttInit + Ntt, ModOp: ModInit @@ -612,16 +817,10 @@ where { fn new(parameters: BoolParameters) -> Self { //TODO(Jay): Run sanity checks for modulus values in parameters - assert!(parameters.br_q.is_power_of_two()); - - let decomposer_rlwe = - DefaultDecomposer::new(parameters.rlwe_q, parameters.logb_rgsw, parameters.d_rgsw); - let decomposer_lwe = - DefaultDecomposer::new(parameters.lwe_q, parameters.logb_lwe, parameters.d_lwe); // generatr dlog map s.t. g^{k} % q = a, for all a \in Z*_{q} - let g = parameters.g; - let q = parameters.br_q; + let g = parameters.g(); + let q = parameters.br_q().0.to_usize().unwrap(); let mut g_k_dlog_map = vec![0usize; q]; for i in 0..q / 2 { let v = mod_exponent(g as u64, i as u64, q as u64) as usize; @@ -631,48 +830,30 @@ where g_k_dlog_map[q - v] = i + (q / 2); } - let embedding_factor = (2 * parameters.rlwe_n) / q; + let embedding_factor = (2 * parameters.rlwe_n().0) / q; - let rlwe_nttop = NttOp::new(parameters.rlwe_q, parameters.rlwe_n); - let rlwe_modop = ModInit::new(parameters.rlwe_q); - let lwe_modop = ModInit::new(parameters.lwe_q); + let rlwe_nttop = NttOp::new(parameters.rlwe_q().0, parameters.rlwe_n().0); + let rlwe_modop = ModInit::new(parameters.rlwe_q().0); + let lwe_modop = ModInit::new(parameters.lwe_q().0); // set test vectors - let el_one = M::MatElement::one(); - let nand_map = |index: usize, qby8: usize| { - if index < (3 * qby8) { - true - } else { - false - } - }; - - let q = parameters.br_q; + let rlwe_q = parameters.rlwe_q().0; + let q = parameters.br_q().0.to_usize().unwrap(); let qby2 = q >> 1; let qby8 = q >> 3; let mut nand_test_vec = M::R::zeros(qby2); // Q/8 (Q: rlwe_q) - let rlwe_qby8 = - M::MatElement::from_f64((parameters.rlwe_q.to_f64().unwrap() / 8.0).round()).unwrap(); + let rlwe_qby8 = M::MatElement::from_f64((rlwe_q.to_f64().unwrap() / 8.0).round()).unwrap(); let true_m_el = rlwe_qby8; // -Q/8 - let false_m_el = parameters.rlwe_q - rlwe_qby8; + let false_m_el = rlwe_q - rlwe_qby8; for i in 0..qby2 { - let v = nand_map(i, qby8); - if v { + if i < (3 * qby8) { nand_test_vec.as_mut()[i] = true_m_el; } else { nand_test_vec.as_mut()[i] = false_m_el; } } - // // Rotate and negate by q/8 - // let mut tmp = M::R::zeros(qby2); - // tmp.as_mut()[..qby2 - qby8].copy_from_slice(&nand_test_vec.as_ref()[qby8..]); - // tmp.as_mut()[qby2 - qby8..].copy_from_slice(&nand_test_vec.as_ref()[..qby8]); - // tmp.as_mut()[qby2 - qby8..].iter_mut().for_each(|v| { - // *v = parameters.rlwe_q - *v; - // }); - // let nand_test_vec = tmp; // v(X) -> v(X^{-g}) let (auto_map_index, auto_map_sign) = generate_auto_map(qby2, -(g as isize)); @@ -685,7 +866,7 @@ where .for_each(|(v, to_index, to_sign)| { if !to_sign { // negate - nand_test_vec_autog.as_mut()[*to_index] = parameters.rlwe_q - *v; + nand_test_vec_autog.as_mut()[*to_index] = parameters.rlwe_q().0 - *v; } else { nand_test_vec_autog.as_mut()[*to_index] = *v; } @@ -693,19 +874,20 @@ where // auto map indices and sign let mut rlwe_auto_maps = vec![]; - let ring_size = parameters.rlwe_n; - let g = parameters.g as isize; + let ring_size = parameters.rlwe_n().0; + let g = parameters.g() as isize; for i in [g, -g] { rlwe_auto_maps.push(generate_auto_map(ring_size, i)) } - let rlwe_qby4 = - M::MatElement::from_f64((parameters.rlwe_q.to_f64().unwrap() / 4.0).round()).unwrap(); + let rlwe_qby4 = M::MatElement::from_f64((rlwe_q.to_f64().unwrap() / 4.0).round()).unwrap(); - BoolEvaluator { - parameters: parameters, - decomposer_lwe, - decomposer_rlwe, + let scratch_memory = ScratchMemory::new(¶meters); + + let pbs_info = BoolPbsInfo { + auto_decomposer: parameters.auto_decomposer(), + lwe_decomposer: parameters.lwe_decomposer(), + rlwe_rgsw_decomposer: parameters.rlwe_rgsw_decomposer(), g_k_dlog_map, embedding_factor, lwe_modop, @@ -715,14 +897,25 @@ where rlweq_by8: rlwe_qby8, rlwe_qby4: rlwe_qby4, rlwe_auto_maps, + parameters: parameters, + }; + BoolEvaluator { + pbs_info, + scratch_memory, _phantom: PhantomData, } } fn client_key(&self) -> ClientKey { - let sk_lwe = LweSecret::random(self.parameters.lwe_n >> 1, self.parameters.lwe_n); - let sk_rlwe = RlweSecret::random(self.parameters.rlwe_n >> 1, self.parameters.rlwe_n); + let sk_lwe = LweSecret::random( + self.pbs_info.parameters.lwe_n().0 >> 1, + self.pbs_info.parameters.lwe_n().0, + ); + let sk_rlwe = RlweSecret::random( + self.pbs_info.parameters.rlwe_n().0 >> 1, + self.pbs_info.parameters.rlwe_n().0, + ); ClientKey { sk_rlwe, sk_lwe } } @@ -736,23 +929,23 @@ where let mut main_prng = DefaultSecureRng::new_seeded(main_seed); + let rlwe_n = self.pbs_info.parameters.rlwe_n().0; let sk_rlwe = &client_key.sk_rlwe; let sk_lwe = &client_key.sk_lwe; - let d_rgsw_gadget_vec = self.decomposer_rlwe.gadget_vector(); - // generate auto keys -g, g let mut auto_keys = HashMap::new(); - let g = self.parameters.g as isize; + let auto_gadget = self.pbs_info.auto_decomposer.gadget_vector(); + let g = self.pbs_info.parameters.g() as isize; for i in [g, -g] { - let mut gk = M::zeros(self.parameters.d_rgsw, self.parameters.rlwe_n); + let mut gk = M::zeros(self.pbs_info.auto_decomposer.decomposition_count(), rlwe_n); galois_key_gen( &mut gk, sk_rlwe.values(), i, - &d_rgsw_gadget_vec, - &self.rlwe_modop, - &self.rlwe_nttop, + &auto_gadget, + &self.pbs_info.rlwe_modop, + &self.pbs_info.rlwe_nttop, &mut main_prng, rng, ); @@ -760,15 +953,21 @@ where } // generate rgsw ciphertexts RGSW(si) where si is i^th LWE secret element - let ring_size = self.parameters.rlwe_n; - let rlwe_q = self.parameters.rlwe_q; + let ring_size = self.pbs_info.parameters.rlwe_n().0; + let rlwe_q = self.pbs_info.parameters.rlwe_q().0; + let (rlrg_d_a, rlrg_d_b) = ( + self.pbs_info.rlwe_rgsw_decomposer.0.decomposition_count(), + self.pbs_info.rlwe_rgsw_decomposer.1.decomposition_count(), + ); + let rlrg_gadget_a = self.pbs_info.rlwe_rgsw_decomposer.0.gadget_vector(); + let rlrg_gadget_b = self.pbs_info.rlwe_rgsw_decomposer.1.gadget_vector(); let rgsw_cts = sk_lwe .values() .iter() .map(|si| { // X^{si}; assume |emebedding_factor * si| < N let mut m = M::R::zeros(ring_size); - let si = (self.embedding_factor as i32) * si; + let si = (self.pbs_info.embedding_factor as i32) * si; // dbg!(si); if si < 0 { // X^{-i} = X^{2N - i} = -X^{N-i} @@ -778,14 +977,15 @@ where m.as_mut()[si.abs() as usize] = M::MatElement::one(); } - let mut rgsw_si = M::zeros(self.parameters.d_rgsw * 3, ring_size); + let mut rgsw_si = M::zeros(rlrg_d_a * 2 + rlrg_d_b, ring_size); secret_key_encrypt_rgsw( &mut rgsw_si, m.as_ref(), - &d_rgsw_gadget_vec, + &rlrg_gadget_a, + &rlrg_gadget_b, sk_rlwe.values(), - &self.rlwe_modop, - &self.rlwe_nttop, + &self.pbs_info.rlwe_modop, + &self.pbs_info.rlwe_nttop, &mut main_prng, rng, ); @@ -795,15 +995,15 @@ where .collect_vec(); // LWE KSK from RLWE secret s -> LWE secret z - let d_lwe_gadget = self.decomposer_lwe.gadget_vector(); - - let mut lwe_ksk = M::R::zeros(self.parameters.d_lwe * ring_size); + let d_lwe_gadget = self.pbs_info.lwe_decomposer.gadget_vector(); + let mut lwe_ksk = + M::R::zeros(self.pbs_info.lwe_decomposer.decomposition_count() * ring_size); lwe_ksk_keygen( &sk_rlwe.values(), &sk_lwe.values(), &mut lwe_ksk, &d_lwe_gadget, - &self.lwe_modop, + &self.pbs_info.lwe_modop, &mut main_prng, rng, ); @@ -812,13 +1012,13 @@ where auto_keys, rgsw_cts, lwe_ksk, - self.parameters.clone(), + self.pbs_info.parameters.clone(), main_seed, ) }) } - fn multi_party_sever_key_share( + fn multi_party_server_key_share( &self, cr_seed: [u8; 32], collective_pk: &M, @@ -831,31 +1031,31 @@ where let sk_rlwe = &client_key.sk_rlwe; let sk_lwe = &client_key.sk_lwe; - let g = self.parameters.g as isize; - let ring_size = self.parameters.rlwe_n; - let d_rgsw = self.parameters.d_rgsw; - let d_lwe = self.parameters.d_lwe; - let rlwe_q = self.parameters.rlwe_q; - let lwe_q = self.parameters.lwe_q; - - let d_rgsw_gadget_vec = self.decomposer_rlwe.gadget_vector(); + let g = self.pbs_info.parameters.g() as isize; + let ring_size = self.pbs_info.parameters.rlwe_n().0; + let rlwe_q = self.pbs_info.parameters.rlwe_q().0; + let lwe_q = self.pbs_info.parameters.lwe_q().0; let rlweq_modop = ModOp::new(rlwe_q); let rlweq_nttop = NttOp::new(rlwe_q, ring_size); // sanity check assert!(sk_rlwe.values().len() == ring_size); - assert!(sk_lwe.values().len() == self.parameters.lwe_n); + assert!(sk_lwe.values().len() == self.pbs_info.parameters.lwe_n().0); // auto keys let mut auto_keys = HashMap::new(); + let auto_gadget = self.pbs_info.auto_decomposer.gadget_vector(); for i in [g, -g] { - let mut ksk_out = M::zeros(d_rgsw, ring_size); + let mut ksk_out = M::zeros( + self.pbs_info.auto_decomposer.decomposition_count(), + ring_size, + ); galois_key_gen( &mut ksk_out, sk_rlwe.values(), i, - &d_rgsw_gadget_vec, + &auto_gadget, &rlweq_modop, &rlweq_nttop, &mut main_prng, @@ -865,6 +1065,18 @@ where } // rgsw ciphertexts of lwe secret elements + let rgsw_rgsw_decomposer = self + .pbs_info + .parameters + .rgsw_rgsw_decomposer::>(); + let (rgrg_d_a, rgrg_d_b) = ( + rgsw_rgsw_decomposer.0.decomposition_count(), + rgsw_rgsw_decomposer.1.decomposition_count(), + ); + let (rgrg_gadget_a, rgrg_gadget_b) = ( + rgsw_rgsw_decomposer.0.gadget_vector(), + rgsw_rgsw_decomposer.1.gadget_vector(), + ); let rgsw_cts = sk_lwe .values() .iter() @@ -873,7 +1085,7 @@ where //TODO(Jay): It will be nice to have a function that returns polynomial // (monomial infact!) corresponding to secret element embedded in ring X^{2N+1}. // Save lots of mistakes where one forgest to emebed si in bigger ring. - let si = *si * (self.embedding_factor as i32); + let si = *si * (self.pbs_info.embedding_factor as i32); if si < 0 { // X^{-si} = X^{2N-si} = -X^{N-si}, assuming abs(si) < N // (which it is given si is secret element) @@ -884,12 +1096,13 @@ where // public key RGSW encryption has no part that can be seeded, unlike secret key // RGSW encryption where RLWE'_A(m) is seeded - let mut out_rgsw = M::zeros(d_rgsw * 4, ring_size); + let mut out_rgsw = M::zeros(rgrg_d_a * 2 + rgrg_d_b * 2, ring_size); public_key_encrypt_rgsw( &mut out_rgsw, &m.as_ref(), collective_pk, - &d_rgsw_gadget_vec, + &rgrg_gadget_a, + &rgrg_gadget_b, &rlweq_modop, &rlweq_nttop, rng, @@ -900,9 +1113,10 @@ where .collect_vec(); // LWE ksk - let mut lwe_ksk = M::R::zeros(d_lwe * ring_size); + let mut lwe_ksk = + M::R::zeros(self.pbs_info.lwe_decomposer.decomposition_count() * ring_size); let lwe_modop = ModOp::new(lwe_q); - let d_lwe_gadget_vec = self.decomposer_lwe.gadget_vector(); + let d_lwe_gadget_vec = self.pbs_info.lwe_decomposer.gadget_vector(); lwe_ksk_keygen( sk_rlwe.values(), sk_lwe.values(), @@ -918,7 +1132,7 @@ where rgsw_cts, lwe_ksk, cr_seed, - parameters: self.parameters.clone(), + parameters: self.pbs_info.parameters.clone(), } }) } @@ -933,9 +1147,12 @@ where BoolParameters<::MatElement>, > { DefaultSecureRng::with_local_mut(|rng| { - let mut share_out = M::R::zeros(self.parameters.rlwe_n); - let modop = ModOp::new(self.parameters.rlwe_q); - let nttop = NttOp::new(self.parameters.rlwe_q, self.parameters.rlwe_n); + let mut share_out = M::R::zeros(self.pbs_info.parameters.rlwe_n().0); + let modop = ModOp::new(self.pbs_info.parameters.rlwe_q().0); + let nttop = NttOp::new( + self.pbs_info.parameters.rlwe_q().0, + self.pbs_info.parameters.rlwe_n().0, + ); let mut main_prng = DefaultSecureRng::new_seeded(cr_seed); public_key_share( &mut share_out, @@ -949,7 +1166,7 @@ where CommonReferenceSeededCollectivePublicKeyShare { share: share_out, cr_seed: cr_seed, - parameters: self.parameters.clone(), + parameters: self.pbs_info.parameters.clone(), } }) } @@ -959,10 +1176,12 @@ where lwe_ct: &M::R, client_key: &ClientKey, ) -> MultiPartyDecryptionShare<::MatElement> { - assert!(lwe_ct.as_ref().len() == self.parameters.rlwe_n + 1); - let modop = &self.rlwe_modop; - let mut neg_s = - M::R::try_convert_from(client_key.sk_rlwe.values(), &self.parameters.rlwe_q); + assert!(lwe_ct.as_ref().len() == self.pbs_info.parameters.rlwe_n().0 + 1); + let modop = &self.pbs_info.rlwe_modop; + let mut neg_s = M::R::try_convert_from( + client_key.sk_rlwe.values(), + &self.pbs_info.parameters.rlwe_q().0, + ); modop.elwise_neg_mut(neg_s.as_mut()); let mut neg_sa = M::MatElement::zero(); @@ -972,7 +1191,7 @@ where let e = DefaultSecureRng::with_local_mut(|rng| { let mut e = M::MatElement::zero(); - RandomGaussianDist::random_fill(rng, &self.parameters.rlwe_q, &mut e); + RandomGaussianDist::random_fill(rng, &self.pbs_info.parameters.rlwe_q().0, &mut e); e }); let share = modop.add(&neg_sa, &e); @@ -985,7 +1204,7 @@ where shares: &[MultiPartyDecryptionShare], lwe_ct: &M::R, ) -> bool { - let modop = &self.rlwe_modop; + let modop = &self.pbs_info.rlwe_modop; let mut sum_a = M::MatElement::zero(); shares .iter() @@ -993,8 +1212,8 @@ where let encoded_m = modop.add(&lwe_ct.as_ref()[0], &sum_a); - let m = (((encoded_m + self.rlweq_by8).to_f64().unwrap() * 4f64) - / self.parameters.rlwe_q.to_f64().unwrap()) + let m = (((encoded_m + self.pbs_info.rlweq_by8).to_f64().unwrap() * 4f64) + / self.pbs_info.parameters.rlwe_q().0.to_f64().unwrap()) .round() as usize % 4usize; @@ -1011,15 +1230,15 @@ where /// LWE ciphertext pub(crate) fn pk_encrypt(&self, pk: &M, m: bool) -> M::R { DefaultSecureRng::with_local_mut(|rng| { - let modop = &self.rlwe_modop; - let nttop = &self.rlwe_nttop; + let modop = &self.pbs_info.rlwe_modop; + let nttop = &self.pbs_info.rlwe_nttop; // RLWE(0) // sample ephemeral key u - let ring_size = self.parameters.rlwe_n; + let ring_size = self.pbs_info.parameters.rlwe_n().0; let mut u = vec![0i32; ring_size]; fill_random_ternary_secret_with_hamming_weight(u.as_mut(), ring_size >> 1, rng); - let mut u = M::R::try_convert_from(&u, &self.parameters.rlwe_q); + let mut u = M::R::try_convert_from(&u, &self.pbs_info.parameters.rlwe_q().0); nttop.forward(u.as_mut()); let mut ua = M::R::zeros(ring_size); @@ -1040,7 +1259,11 @@ where let mut rlwe = M::zeros(2, ring_size); // sample error rlwe.iter_rows_mut().for_each(|ri| { - RandomGaussianDist::random_fill(rng, &self.parameters.rlwe_q, ri.as_mut()); + RandomGaussianDist::random_fill( + rng, + &self.pbs_info.parameters.rlwe_q().0, + ri.as_mut(), + ); }); // a*u + e0 @@ -1050,10 +1273,10 @@ where let m = if m { // Q/8 - self.rlweq_by8 + self.pbs_info.rlweq_by8 } else { // -Q/8 - self.parameters.rlwe_q - self.rlweq_by8 + self.pbs_info.parameters.rlwe_q().0 - self.pbs_info.rlweq_by8 }; // b*u + e1 + m, where m is constant polynomial @@ -1071,19 +1294,19 @@ where pub fn sk_encrypt(&self, m: bool, client_key: &ClientKey) -> M::R { let m = if m { // Q/8 - self.rlweq_by8 + self.pbs_info.rlweq_by8 } else { // -Q/8 - self.parameters.rlwe_q - self.rlweq_by8 + self.pbs_info.parameters.rlwe_q().0 - self.pbs_info.rlweq_by8 }; DefaultSecureRng::with_local_mut(|rng| { - let mut lwe_out = M::R::zeros(self.parameters.rlwe_n + 1); + let mut lwe_out = M::R::zeros(self.pbs_info.parameters.rlwe_n().0 + 1); encrypt_lwe( &mut lwe_out, &m, client_key.sk_rlwe.values(), - &self.rlwe_modop, + &self.pbs_info.rlwe_modop, rng, ); lwe_out @@ -1091,11 +1314,15 @@ where } pub fn sk_decrypt(&self, lwe_ct: &M::R, client_key: &ClientKey) -> bool { - let m = decrypt_lwe(lwe_ct, client_key.sk_rlwe.values(), &self.rlwe_modop); + let m = decrypt_lwe( + lwe_ct, + client_key.sk_rlwe.values(), + &self.pbs_info.rlwe_modop, + ); let m = { // m + q/8 => {0,q/4 1} - (((m + self.rlweq_by8).to_f64().unwrap() * 4.0) - / self.parameters.rlwe_q.to_f64().unwrap()) + (((m + self.pbs_info.rlweq_by8).to_f64().unwrap() * 4.0) + / self.pbs_info.parameters.rlwe_q().0.to_f64().unwrap()) .round() .to_usize() .unwrap() @@ -1113,30 +1340,13 @@ where // TODO(Jay): scratch spaces must be thread local. Don't pass them as arguments pub fn nand( - &self, + &mut self, c0: &M::R, c1: &M::R, server_key: &ServerKeyEvaluationDomain, - scratch_lwen_plus1: &mut M::R, - scratch_matrix_dplus2_ring: &mut M, ) -> M::R { - // ClientKey::with_local(|ck| { - // let c0_noise = measure_noise_lwe( - // c0, - // ck.sk_rlwe.values(), - // &self.rlwe_modop, - // &(self.rlwe_q() - self.rlweq_by8), - // ); - // let c1_noise = - // measure_noise_lwe(c1, ck.sk_rlwe.values(), &self.rlwe_modop, - // &(self.rlweq_by8)); println!( - // "c0 noise: {c0_noise}; c1 noise: - // {c1_noise}" - // ); - // }); - let mut c_out = M::R::zeros(c0.as_ref().len()); - let modop = &self.rlwe_modop; + let modop = &self.pbs_info.rlwe_modop; izip!( c_out.as_mut().iter_mut(), c0.as_ref().iter(), @@ -1145,94 +1355,26 @@ where .for_each(|(o, i0, i1)| { *o = modop.add(i0, i1); }); - // +Q/8 - c_out.as_mut()[0] = modop.add(&c_out.as_ref()[0], &self.rlwe_qby4); - - // ClientKey::with_local(|ck| { - // let noise = measure_noise_lwe( - // &c_out, - // ck.sk_rlwe.values(), - // &self.rlwe_modop, - // &(self.rlweq_by8), - // ); - // println!("cout_noise: {noise}"); - // }); + // +Q/4 + c_out.as_mut()[0] = modop.add(&c_out.as_ref()[0], &self.pbs_info.rlwe_qby4); // PBS pbs( - self, - &self.nand_test_vec, + &self.pbs_info, + &self.pbs_info.nand_test_vec, &mut c_out, - scratch_lwen_plus1, - scratch_matrix_dplus2_ring, - &self.lwe_modop, - &self.rlwe_modop, - &self.rlwe_nttop, server_key, + &mut self.scratch_memory.lwe_vector, + &mut self.scratch_memory.decomposition_matrix, ); c_out } } -impl PbsParameters for BoolEvaluator -where - M::MatElement: PrimInt + WrappingSub + Debug, -{ - type Element = M::MatElement; - type D = DefaultDecomposer; - fn rlwe_auto_map(&self, k: isize) -> &(Vec, Vec) { - let g = self.parameters.g as isize; - if k == g { - &self.rlwe_auto_maps[0] - } else if k == -g { - &self.rlwe_auto_maps[1] - } else { - panic!("RLWE auto map only supports k in [-g, g], but got k={k}"); - } - } - - fn br_q(&self) -> usize { - self.parameters.br_q - } - fn d_lwe(&self) -> usize { - self.parameters.d_lwe - } - fn d_rgsw(&self) -> usize { - self.parameters.d_rgsw - } - fn decomoposer_lwe(&self) -> &Self::D { - &self.decomposer_lwe - } - fn decomoposer_rlwe(&self) -> &Self::D { - &self.decomposer_rlwe - } - fn embedding_factor(&self) -> usize { - self.embedding_factor - } - fn g(&self) -> isize { - self.parameters.g as isize - } - fn g_k_dlog_map(&self) -> &[usize] { - &self.g_k_dlog_map - } - fn lwe_n(&self) -> usize { - self.parameters.lwe_n - } - fn lwe_q(&self) -> Self::Element { - self.parameters.lwe_q - } - fn rlwe_n(&self) -> usize { - self.parameters.rlwe_n - } - fn rlwe_q(&self) -> Self::Element { - self.parameters.rlwe_q - } -} - /// LMKCY+ Blind rotation /// -/// gk_to_si: [-g^0, -g^1, .., -g^{q/2-1}, g^0, ..., g^{q/2-1}] +/// gk_to_si: [g^0, ..., g^{q/2-1}, -g^0, -g^1, .., -g^{q/2-1}] fn blind_rotation< MT: IsTrivial + MatrixMut, Mmut: MatrixMut + Matrix, @@ -1240,15 +1382,16 @@ fn blind_rotation< NttOp: Ntt, ModOp: ArithmeticOps + VectorOps, K: PbsKey, - P: PbsParameters, + P: PbsInfo, >( trivial_rlwe_test_poly: &mut MT, - scratch_matrix_dplus2_ring: &mut Mmut, + scratch_matrix: &mut Mmut, g: isize, w: usize, q: usize, gk_to_si: &[Vec], - decomposer: &D, + rlwe_rgsw_decomposer: &(D, D), + auto_decomposer: &D, ntt_op: &NttOp, mod_op: &ModOp, parameters: &P, @@ -1266,8 +1409,8 @@ fn blind_rotation< rlwe_by_rgsw( trivial_rlwe_test_poly, pbs_key.rgsw_ct_lwe_si(*s_index), - scratch_matrix_dplus2_ring, - decomposer, + scratch_matrix, + rlwe_rgsw_decomposer, ntt_op, mod_op, ); @@ -1277,12 +1420,12 @@ fn blind_rotation< galois_auto( trivial_rlwe_test_poly, pbs_key.galois_key_for_auto(g), - scratch_matrix_dplus2_ring, + scratch_matrix, &auto_map_index, &auto_map_sign, mod_op, ntt_op, - decomposer, + auto_decomposer, ); } @@ -1291,8 +1434,8 @@ fn blind_rotation< rlwe_by_rgsw( trivial_rlwe_test_poly, pbs_key.rgsw_ct_lwe_si(*s_index), - scratch_matrix_dplus2_ring, - decomposer, + scratch_matrix, + rlwe_rgsw_decomposer, ntt_op, mod_op, ); @@ -1301,12 +1444,12 @@ fn blind_rotation< galois_auto( trivial_rlwe_test_poly, pbs_key.galois_key_for_auto(-g), - scratch_matrix_dplus2_ring, + scratch_matrix, &auto_map_index, &auto_map_sign, mod_op, ntt_op, - decomposer, + auto_decomposer, ); // +(g^k) @@ -1315,8 +1458,8 @@ fn blind_rotation< rlwe_by_rgsw( trivial_rlwe_test_poly, pbs_key.rgsw_ct_lwe_si(*s_index), - scratch_matrix_dplus2_ring, - decomposer, + scratch_matrix, + rlwe_rgsw_decomposer, ntt_op, mod_op, ); @@ -1326,12 +1469,12 @@ fn blind_rotation< galois_auto( trivial_rlwe_test_poly, pbs_key.galois_key_for_auto(g), - scratch_matrix_dplus2_ring, + scratch_matrix, &auto_map_index, &auto_map_sign, mod_op, ntt_op, - decomposer, + auto_decomposer, ); } @@ -1340,8 +1483,8 @@ fn blind_rotation< rlwe_by_rgsw( trivial_rlwe_test_poly, pbs_key.rgsw_ct_lwe_si(gk_to_si[q_by_2][*s_index]), - scratch_matrix_dplus2_ring, - decomposer, + scratch_matrix, + rlwe_rgsw_decomposer, ntt_op, mod_op, ); @@ -1354,31 +1497,28 @@ fn blind_rotation< /// - blind rotate fn pbs< M: Matrix + MatrixMut + MatrixEntity, - P: PbsParameters, - NttOp: Ntt, - ModOp: ArithmeticOps + VectorOps, + P: PbsInfo, + // NttOp: Ntt, + // ModOp: ArithmeticOps + VectorOps, K: PbsKey, >( - parameters: &P, + pbs_info: &P, test_vec: &M::R, lwe_in: &mut M::R, - scratch_lwen_plus1: &mut M::R, - scratch_matrix_dplus2_ring: &mut M, - modop_lweq: &ModOp, - modop_rlweq: &ModOp, - nttop_rlweq: &NttOp, pbs_key: &K, + scratch_lwe_vec: &mut M::R, + scratch_blind_rotate_matrix: &mut M, ) where ::R: RowMut, M::MatElement: PrimInt + ToPrimitive + FromPrimitive + One + Copy + Zero + Display, { - let rlwe_q = parameters.rlwe_q(); - let lwe_q = parameters.lwe_q(); - let br_q = parameters.br_q(); + let rlwe_q = pbs_info.rlwe_q(); + let lwe_q = pbs_info.lwe_q(); + let br_q = pbs_info.br_q(); let rlwe_qf64 = rlwe_q.to_f64().unwrap(); let lwe_qf64 = lwe_q.to_f64().unwrap(); let br_qf64 = br_q.to_f64().unwrap(); - let rlwe_n = parameters.rlwe_n(); + let rlwe_n = pbs_info.rlwe_n(); PBSTracer::with_local_mut(|t| { let out = lwe_in @@ -1405,17 +1545,17 @@ fn pbs< }); // key switch RLWE secret to LWE secret - scratch_lwen_plus1.as_mut().fill(M::MatElement::zero()); + scratch_lwe_vec.as_mut().fill(M::MatElement::zero()); lwe_key_switch( - scratch_lwen_plus1, + scratch_lwe_vec, lwe_in, pbs_key.lwe_ksk(), - modop_lweq, - parameters.decomoposer_lwe(), + pbs_info.modop_lweq(), + pbs_info.lwe_decomposer(), ); PBSTracer::with_local_mut(|t| { - let out = scratch_lwen_plus1 + let out = scratch_lwe_vec .as_ref() .iter() .map(|v| v.to_u64().unwrap()) @@ -1424,9 +1564,9 @@ fn pbs< }); // odd mowdown Q_ks -> q - let g_k_dlog_map = parameters.g_k_dlog_map(); + let g_k_dlog_map = pbs_info.g_k_dlog_map(); let mut g_k_si = vec![vec![]; br_q]; - scratch_lwen_plus1 + scratch_lwe_vec .as_ref() .iter() .skip(1) @@ -1438,7 +1578,7 @@ fn pbs< }); PBSTracer::with_local_mut(|t| { - let out = scratch_lwen_plus1 + let out = scratch_lwe_vec .as_ref() .iter() .map(|v| mod_switch_odd(v.to_f64().unwrap(), lwe_qf64, br_qf64) as u64) @@ -1447,9 +1587,9 @@ fn pbs< }); // handle b and set trivial test RLWE - let g = parameters.g() as usize; + let g = pbs_info.g() as usize; let g_times_b = (g * mod_switch_odd( - scratch_lwen_plus1.as_ref()[0].to_f64().unwrap(), + scratch_lwe_vec.as_ref()[0].to_f64().unwrap(), lwe_qf64, br_qf64, )) % (br_q); @@ -1468,14 +1608,14 @@ fn pbs< is_trivial: true, _phatom: PhantomData, }; - if parameters.embedding_factor() == 1 { + if pbs_info.embedding_factor() == 1 { monomial_mul( test_vec.as_ref(), trivial_rlwe_test_poly.get_row_mut(1).as_mut(), gb_monomial_exp, gb_monomial_sign, br_qby2, - modop_rlweq, + pbs_info.modop_rlweq(), ); } else { // use lwe_in to store the `t = v(X) * X^{g*2} mod X^{q/2}+1` temporarily. This @@ -1486,11 +1626,11 @@ fn pbs< gb_monomial_exp, gb_monomial_sign, br_qby2, - modop_rlweq, + pbs_info.modop_rlweq(), ); // emebed poly `t` in ring X^{q/2}+1 inside the bigger ring X^{N}+1 - let embed_factor = parameters.embedding_factor(); + let embed_factor = pbs_info.embedding_factor(); let partb_trivial_rlwe = trivial_rlwe_test_poly.get_row_mut(1); lwe_in.as_ref()[..br_qby2] .iter() @@ -1503,15 +1643,16 @@ fn pbs< // blind rotate blind_rotation( &mut trivial_rlwe_test_poly, - scratch_matrix_dplus2_ring, - parameters.g(), + scratch_blind_rotate_matrix, + pbs_info.g(), 1, br_q, &g_k_si, - parameters.decomoposer_rlwe(), - nttop_rlweq, - modop_rlweq, - parameters, + pbs_info.rlwe_rgsw_decomposer(), + pbs_info.auto_decomposer(), + pbs_info.nttop_rlweq(), + pbs_info.modop_rlweq(), + pbs_info, pbs_key, ); @@ -1541,7 +1682,7 @@ fn pbs< // }); // sample extract - sample_extract(lwe_in, &trivial_rlwe_test_poly, modop_rlweq, 0); + sample_extract(lwe_in, &trivial_rlwe_test_poly, pbs_info.modop_rlweq(), 0); } fn mod_switch_odd(v: f64, from_q: f64, to_q: f64) -> usize { @@ -1626,21 +1767,21 @@ where impl PBSTracer>> { fn trace(&self, parameters: &BoolParameters, sk_lwe: &[i32], sk_rlwe: &[i32]) { - assert!(parameters.rlwe_n == sk_rlwe.len()); - assert!(parameters.lwe_n == sk_lwe.len()); + assert!(parameters.rlwe_n().0 == sk_rlwe.len()); + assert!(parameters.lwe_n().0 == sk_lwe.len()); - let modop_rlweq = ModularOpsU64::new(parameters.rlwe_q as u64); + let modop_rlweq = ModularOpsU64::new(parameters.rlwe_q().0); // noise after mod down Q -> Q_ks let m_back0 = decrypt_lwe(&self.ct_rlwe_q_mod, sk_rlwe, &modop_rlweq); - let modop_lweq = ModularOpsU64::new(parameters.lwe_q as u64); + let modop_lweq = ModularOpsU64::new(parameters.lwe_q().0); // noise after mod down Q -> Q_ks let m_back1 = decrypt_lwe(&self.ct_lwe_q_mod, sk_rlwe, &modop_lweq); // noise after key switch from RLWE -> LWE let m_back2 = decrypt_lwe(&self.ct_lwe_q_mod_after_ksk, sk_lwe, &modop_lweq); // noise after mod down odd from Q_ks -> q - let modop_br_q = ModularOpsU64::new(parameters.br_q as u64); + let modop_br_q = ModularOpsU64::new(parameters.br_q().0); let m_back3 = decrypt_lwe(&self.ct_br_q_mod, sk_lwe, &modop_br_q); println!( @@ -1684,30 +1825,20 @@ mod tests { random::DEFAULT_RNG, rgsw::{ self, measure_noise, public_key_encrypt_rlwe, secret_key_encrypt_rlwe, - tests::{_measure_noise_rgsw, _secret_encrypt_rlwe}, + tests::{_measure_noise_rgsw, _sk_encrypt_rlwe}, RgswCiphertext, RgswCiphertextEvaluationDomain, SeededRgswCiphertext, SeededRlweCiphertext, }, utils::negacyclic_mul, }; - use self::parameters::{MP_BOOL_PARAMS, SP_BOOL_PARAMS}; - use super::*; - // #[test] - // fn trial() { - // dbg!(generate_prime(28, 1 << 11, 1 << 28)); - // } - #[test] fn bool_encrypt_decrypt_works() { - // let prime = generate_prime(32, 2 * 1024, 1 << 32); - // dbg!(prime); let bool_evaluator = - BoolEvaluator::>, u64, NttBackendU64, ModularOpsU64>::new(SP_BOOL_PARAMS); + BoolEvaluator::>, NttBackendU64, ModularOpsU64>::new(SP_BOOL_PARAMS); let client_key = bool_evaluator.client_key(); - // let sever_key = bool_evaluator.server_key(&client_key); let mut m = true; for _ in 0..1000 { @@ -1725,8 +1856,8 @@ mod tests { *r = rng; }); - let bool_evaluator = - BoolEvaluator::>, u64, NttBackendU64, ModularOpsU64>::new(SP_BOOL_PARAMS); + let mut bool_evaluator = + BoolEvaluator::>, NttBackendU64, ModularOpsU64>::new(SP_BOOL_PARAMS); // println!("{:?}", bool_evaluator.nand_test_vec); let client_key = bool_evaluator.client_key(); @@ -1736,25 +1867,13 @@ mod tests { &seeded_server_key, ); - let mut scratch_lwen_plus1 = vec![0u64; bool_evaluator.parameters.lwe_n + 1]; - let mut scratch_matrix_dplus2_ring = vec![ - vec![0u64; bool_evaluator.parameters.rlwe_n]; - bool_evaluator.parameters.d_rgsw + 2 - ]; - let mut m0 = false; let mut m1 = true; let mut ct0 = bool_evaluator.sk_encrypt(m0, &client_key); let mut ct1 = bool_evaluator.sk_encrypt(m1, &client_key); for _ in 0..1000 { - let ct_back = bool_evaluator.nand( - &ct0, - &ct1, - &server_key_eval_domain, - &mut scratch_lwen_plus1, - &mut scratch_matrix_dplus2_ring, - ); + let ct_back = bool_evaluator.nand(&ct0, &ct1, &server_key_eval_domain); let m_out = !(m0 && m1); @@ -1762,39 +1881,39 @@ mod tests { { let noise0 = { let ideal = if m0 { - bool_evaluator.rlweq_by8 + bool_evaluator.pbs_info.rlweq_by8 } else { - bool_evaluator.rlwe_q() - bool_evaluator.rlweq_by8 + bool_evaluator.pbs_info.rlwe_q() - bool_evaluator.pbs_info.rlweq_by8 }; let n = measure_noise_lwe( &ct0, client_key.sk_rlwe.values(), - &bool_evaluator.rlwe_modop, + &bool_evaluator.pbs_info.rlwe_modop, &ideal, ); let v = decrypt_lwe( &ct0, client_key.sk_rlwe.values(), - &bool_evaluator.rlwe_modop, + &bool_evaluator.pbs_info.rlwe_modop, ); (n, v) }; let noise1 = { let ideal = if m1 { - bool_evaluator.rlweq_by8 + bool_evaluator.pbs_info.rlweq_by8 } else { - bool_evaluator.rlwe_q() - bool_evaluator.rlweq_by8 + bool_evaluator.pbs_info.rlwe_q() - bool_evaluator.pbs_info.rlweq_by8 }; let n = measure_noise_lwe( &ct1, client_key.sk_rlwe.values(), - &bool_evaluator.rlwe_modop, + &bool_evaluator.pbs_info.rlwe_modop, &ideal, ); let v = decrypt_lwe( &ct1, client_key.sk_rlwe.values(), - &bool_evaluator.rlwe_modop, + &bool_evaluator.pbs_info.rlwe_modop, ); (n, v) }; @@ -1811,20 +1930,20 @@ mod tests { // Calculate noise in ciphertext post PBS let noise_out = { let ideal = if m_out { - bool_evaluator.rlweq_by8 + bool_evaluator.pbs_info.rlweq_by8 } else { - bool_evaluator.rlwe_q() - bool_evaluator.rlweq_by8 + bool_evaluator.pbs_info.rlwe_q() - bool_evaluator.pbs_info.rlweq_by8 }; let n = measure_noise_lwe( &ct_back, client_key.sk_rlwe.values(), - &bool_evaluator.rlwe_modop, + &bool_evaluator.pbs_info.rlwe_modop, &ideal, ); let v = decrypt_lwe( &ct_back, client_key.sk_rlwe.values(), - &bool_evaluator.rlwe_modop, + &bool_evaluator.pbs_info.rlwe_modop, ); (n, v) }; @@ -1848,14 +1967,14 @@ mod tests { #[test] fn multi_party_encryption_decryption() { let bool_evaluator = - BoolEvaluator::>, u64, NttBackendU64, ModularOpsU64>::new(MP_BOOL_PARAMS); + BoolEvaluator::>, NttBackendU64, ModularOpsU64>::new(MP_BOOL_PARAMS); let no_of_parties = 500; let parties = (0..no_of_parties) .map(|_| bool_evaluator.client_key()) .collect_vec(); - let mut ideal_rlwe_sk = vec![0i32; bool_evaluator.rlwe_n()]; + let mut ideal_rlwe_sk = vec![0i32; bool_evaluator.pbs_info.rlwe_n()]; parties.iter().for_each(|k| { izip!(ideal_rlwe_sk.iter_mut(), k.sk_rlwe.values()).for_each(|(ideal_i, s_i)| { *ideal_i = *ideal_i + s_i; @@ -1887,14 +2006,14 @@ mod tests { { let ideal_m = if m { - bool_evaluator.rlweq_by8 + bool_evaluator.pbs_info.rlweq_by8 } else { - bool_evaluator.parameters.rlwe_q - bool_evaluator.rlweq_by8 + bool_evaluator.pbs_info.rlwe_q() - bool_evaluator.pbs_info.rlweq_by8 }; let noise = measure_noise_lwe( &lwe_ct, &ideal_rlwe_sk, - &bool_evaluator.rlwe_modop, + &bool_evaluator.pbs_info.rlwe_modop, &ideal_m, ); println!("Noise: {noise}"); @@ -1905,155 +2024,6 @@ mod tests { } } - #[test] - fn ms() { - let logbig_q = 50; - let logsmall_q = 20; - let big_q = 1 << logbig_q; - let small_q = 1 << logsmall_q; - let lwe_n = 493; - - let no_of_parties = 10; - let parties_lwe_sk = (0..no_of_parties) - .map(|_| LweSecret::random(lwe_n >> 1, lwe_n)) - .collect_vec(); - - // Ideal secrets - let mut ideal_lwe_sk = vec![0i32; lwe_n]; - parties_lwe_sk.iter().for_each(|k| { - izip!(ideal_lwe_sk.iter_mut(), k.values()).for_each(|(ideal_i, s_i)| { - *ideal_i = *ideal_i + s_i; - }); - }); - - let mut rng = DefaultSecureRng::new(); - - let logp = 3; - let modop_bigq = ModularOpsU64::new(big_q); - let modop_smallq = ModularOpsU64::new(small_q); - - for i in 0..100 { - let m = thread_rng().sample(Uniform::new(0u64, (1u64 << logp))); - let bigq_m = m << (logbig_q - logp); - let smallq_m = m << (logsmall_q - logp); - - // encrypt - let mut lwe_ct = vec![0u64; lwe_n + 1]; - encrypt_lwe(&mut lwe_ct, &bigq_m, &ideal_lwe_sk, &modop_bigq, &mut rng); - - let noise = measure_noise_lwe(&lwe_ct, &ideal_lwe_sk, &modop_bigq, &bigq_m); - println!("Noise Before: {noise}"); - - // mod switch - let lwe_ct_ms = lwe_ct - .iter() - .map(|v| (((*v as f64) * small_q as f64) / (big_q as f64)).round() as u64) - .collect_vec(); - - let noise = measure_noise_lwe(&lwe_ct_ms, &ideal_lwe_sk, &modop_smallq, &smallq_m); - println!("Noise After: {noise}"); - } - } - - #[test] - fn multi_party_lwe_keyswitch() { - let lwe_logq = 18; - let lwe_q = 1 << lwe_logq; - let d_lwe = 1; - let logb_lwe = 6; - let lweq_modop = ModularOpsU64::new(lwe_q); - - let decomposer = DefaultDecomposer::new(lwe_q, logb_lwe, d_lwe); - let lwe_gadgect_vec = decomposer.gadget_vector(); - let logp = 2; - - let from_lwe_n = 2048; - let to_lwe_n = 500; - - let no_of_parties = 10; - let parties_from_lwe_sk = (0..no_of_parties) - .map(|_| LweSecret::random(from_lwe_n >> 1, from_lwe_n)) - .collect_vec(); - let parties_to_lwe_sk = (0..no_of_parties) - .map(|_| LweSecret::random(to_lwe_n >> 1, to_lwe_n)) - .collect_vec(); - - // Ideal secrets - let mut ideal_from_lwe_sk = vec![0i32; from_lwe_n]; - parties_from_lwe_sk.iter().for_each(|k| { - izip!(ideal_from_lwe_sk.iter_mut(), k.values()).for_each(|(ideal_i, s_i)| { - *ideal_i = *ideal_i + s_i; - }); - }); - let mut ideal_to_lwe_sk = vec![0i32; to_lwe_n]; - parties_to_lwe_sk.iter().for_each(|k| { - izip!(ideal_to_lwe_sk.iter_mut(), k.values()).for_each(|(ideal_i, s_i)| { - *ideal_i = *ideal_i + s_i; - }); - }); - - // Generate Lwe KSK share - let mut rng = DefaultSecureRng::new(); - let mut ksk_seed = [0u8; 32]; - rng.fill_bytes(&mut ksk_seed); - let lwe_ksk_shares = izip!(parties_from_lwe_sk.iter(), parties_to_lwe_sk.iter()) - .map(|(from_sk, to_sk)| { - let mut ksk_out = vec![0u64; from_lwe_n * d_lwe]; - let mut p_rng = DefaultSecureRng::new_seeded(ksk_seed); - lwe_ksk_keygen( - from_sk.values(), - to_sk.values(), - &mut ksk_out, - &lwe_gadgect_vec, - &lweq_modop, - &mut p_rng, - &mut rng, - ); - ksk_out - }) - .collect_vec(); - - // Create collective LWE ksk - let mut sum_partb = vec![0u64; d_lwe * from_lwe_n]; - lwe_ksk_shares.iter().for_each(|share| { - lweq_modop.elwise_add_mut(sum_partb.as_mut_slice(), share.as_slice()) - }); - let mut lwe_ksk = vec![vec![0u64; to_lwe_n + 1]; d_lwe * from_lwe_n]; - let mut p_rng = DefaultSecureRng::new_seeded(ksk_seed); - izip!(lwe_ksk.iter_mut(), sum_partb.iter()).for_each(|(lwe_i, part_bi)| { - RandomUniformDist::random_fill(&mut p_rng, &lwe_q, &mut lwe_i.as_mut_slice()[1..]); - lwe_i[0] = *part_bi; - }); - - for i in 0..128 { - println!("############## ITERATION {i} ##############"); - - // Encrypt m - let m = 1; - let mut lwe_ct = vec![0u64; from_lwe_n + 1]; - encrypt_lwe(&mut lwe_ct, &m, &ideal_from_lwe_sk, &lweq_modop, &mut rng); - - let noise = measure_noise_lwe(&lwe_ct, &ideal_from_lwe_sk, &lweq_modop, &m); - println!("Noise before key switch: {noise}"); - - // Key switch - let lwe_ct_key_switched = { - let mut lwe_ct_key_switched = vec![0u64; to_lwe_n + 1]; - lwe_key_switch( - &mut lwe_ct_key_switched, - &lwe_ct, - &lwe_ksk, - &lweq_modop, - &decomposer, - ); - lwe_ct_key_switched - }; - - let noise = measure_noise_lwe(&lwe_ct_key_switched, &ideal_to_lwe_sk, &lweq_modop, &m); - println!("Noise after key switch: {noise}"); - } - } - fn _collecitve_public_key_gen(rlwe_q: u64, parties_rlwe_sk: &[RlweSecret]) -> Vec> { let ring_size = parties_rlwe_sk[0].values.len(); assert!(ring_size.is_power_of_two()); @@ -2087,7 +2057,7 @@ mod tests { } fn _multi_party_all_keygen( - bool_evaluator: &BoolEvaluator>, u64, NttBackendU64, ModularOpsU64>, + bool_evaluator: &BoolEvaluator>, NttBackendU64, ModularOpsU64>, no_of_parties: usize, ) -> ( Vec, @@ -2107,8 +2077,11 @@ mod tests { .map(|_| bool_evaluator.client_key()) .collect_vec(); + let mut rng = DefaultSecureRng::new(); + // Collective public key - let pk_cr_seed = [0u8; 32]; + let mut pk_cr_seed = [0u8; 32]; + rng.fill_bytes(&mut pk_cr_seed); let public_key_share = parties .iter() .map(|k| bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) @@ -2118,29 +2091,34 @@ mod tests { ); // Server key - let pbs_cr_seed = [1u8; 32]; + let mut pbs_cr_seed = [1u8; 32]; + rng.fill_bytes(&mut pbs_cr_seed); let server_key_shares = parties .iter() - .map(|k| bool_evaluator.multi_party_sever_key_share(pbs_cr_seed, &collective_pk.key, k)) + .map(|k| { + bool_evaluator.multi_party_server_key_share(pbs_cr_seed, &collective_pk.key, k) + }) .collect_vec(); - let seeded_server_key = - aggregate_multi_party_server_key_shares::<_, _, _, ModularOpsU64, NttBackendU64>( - &server_key_shares, - &bool_evaluator.decomposer_rlwe, - ); + let seeded_server_key = aggregate_multi_party_server_key_shares::< + _, + _, + (DefaultDecomposer, DefaultDecomposer), + ModularOpsU64, + NttBackendU64, + >(&server_key_shares); let server_key_eval = ServerKeyEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from( &seeded_server_key, ); // construct ideal rlwe sk for meauring noise let ideal_client_key = { - let mut ideal_rlwe_sk = vec![0i32; bool_evaluator.rlwe_n()]; + let mut ideal_rlwe_sk = vec![0i32; bool_evaluator.pbs_info.rlwe_n()]; parties.iter().for_each(|k| { izip!(ideal_rlwe_sk.iter_mut(), k.sk_rlwe.values()).for_each(|(ideal_i, s_i)| { *ideal_i = *ideal_i + s_i; }); }); - let mut ideal_lwe_sk = vec![0i32; bool_evaluator.lwe_n()]; + let mut ideal_lwe_sk = vec![0i32; bool_evaluator.pbs_info.lwe_n()]; parties.iter().for_each(|k| { izip!(ideal_lwe_sk.iter_mut(), k.sk_lwe.values()).for_each(|(ideal_i, s_i)| { *ideal_i = *ideal_i + s_i; @@ -2167,220 +2145,13 @@ mod tests { ) } - #[test] - fn mp_key_correcntess() { - let bool_evaluator = - BoolEvaluator::>, u64, NttBackendU64, ModularOpsU64>::new(MP_BOOL_PARAMS); - - let (_, collective_pk, _, _, server_key_eval, ideal_client_key) = - _multi_party_all_keygen(&bool_evaluator, 20); - - let lwe_q = bool_evaluator.parameters.lwe_q; - let rlwe_q = bool_evaluator.parameters.rlwe_q; - let d_rgsw = bool_evaluator.parameters.d_rgsw; - let lwe_logq = bool_evaluator.parameters.lwe_logq; - let lwe_n = bool_evaluator.parameters.lwe_n; - let rlwe_n = bool_evaluator.parameters.rlwe_n; - let lwe_modop = &bool_evaluator.lwe_modop; - let rlwe_nttop = &bool_evaluator.rlwe_nttop; - let rlwe_modop = &bool_evaluator.rlwe_modop; - let rlwe_decomposer = &bool_evaluator.decomposer_rlwe; - - // test LWE ksk from RLWE -> LWE - if false { - let logp = 2; - let mut rng = DefaultSecureRng::new(); - - let m = 1; - let encoded_m = m << (lwe_logq - logp); - - // Encrypt - let mut lwe_ct = vec![0u64; rlwe_n + 1]; - encrypt_lwe( - &mut lwe_ct, - &encoded_m, - ideal_client_key.sk_rlwe.values(), - lwe_modop, - &mut rng, - ); - - // key switch - let lwe_decomposer = &bool_evaluator.decomposer_lwe; - let mut lwe_out = vec![0u64; lwe_n + 1]; - lwe_key_switch( - &mut lwe_out, - &lwe_ct, - &server_key_eval.lwe_ksk, - lwe_modop, - lwe_decomposer, - ); - - let encoded_m_back = decrypt_lwe(&lwe_out, ideal_client_key.sk_lwe.values(), lwe_modop); - let m_back = - ((encoded_m_back as f64 * (1 << logp) as f64) / (lwe_q as f64)).round() as u64; - dbg!(m_back, m); - - let noise = measure_noise_lwe( - &lwe_out, - ideal_client_key.sk_lwe.values(), - lwe_modop, - &encoded_m, - ); - - println!("Noise: {noise}"); - } - - // Measure noise in RGSW ciphertexts of ideal LWE secrets - if true { - let gadget_vec = rlwe_decomposer.gadget_vector(); - for i in 0..20 { - // measure noise in RGSW(s[i]) - let si = - ideal_client_key.sk_lwe.values[i] * (bool_evaluator.embedding_factor as i32); - let mut si_poly = vec![0u64; rlwe_n]; - if si < 0 { - si_poly[rlwe_n - (si.abs() as usize)] = rlwe_q - 1; - } else { - si_poly[(si.abs() as usize)] = 1; - } - - let mut rgsw_si = server_key_eval.rgsw_cts[i].clone(); - rgsw_si - .iter_mut() - .for_each(|ri| rlwe_nttop.backward(ri.as_mut())); - - println!("####### Noise in RGSW(X^s_{i}) #######"); - _measure_noise_rgsw( - &rgsw_si, - &si_poly, - ideal_client_key.sk_rlwe.values(), - &gadget_vec, - rlwe_q, - ); - println!("####### ##################### #######"); - } - } - - // measure noise grwoth in RLWExRGSW - if true { - let mut rng = DefaultSecureRng::new(); - let mut carry_m = vec![0u64; rlwe_n]; - RandomUniformDist::random_fill(&mut rng, &rlwe_q, carry_m.as_mut_slice()); - - // RGSW(carrym) - let trivial_rlwect = vec![vec![0u64; rlwe_n], carry_m.clone()]; - let mut rlwe_ct = RlweCiphertext::<_, DefaultSecureRng>::new_trivial(trivial_rlwect); - - let mut scratch_matrix_dplus2_ring = vec![vec![0u64; rlwe_n]; d_rgsw + 2]; - let mul_mod = - |v0: &u64, v1: &u64| (((*v0 as u128 * *v1 as u128) % (rlwe_q as u128)) as u64); - - for i in 0..bool_evaluator.parameters.lwe_n { - rlwe_by_rgsw( - &mut rlwe_ct, - server_key_eval.rgsw_ct_lwe_si(i), - &mut scratch_matrix_dplus2_ring, - rlwe_decomposer, - rlwe_nttop, - rlwe_modop, - ); - - // carry_m[X] * s_i[X] - let si = - ideal_client_key.sk_lwe.values[i] * (bool_evaluator.embedding_factor as i32); - let mut si_poly = vec![0u64; rlwe_n]; - if si < 0 { - si_poly[rlwe_n - (si.abs() as usize)] = rlwe_q - 1; - } else { - si_poly[(si.abs() as usize)] = 1; - } - carry_m = negacyclic_mul(&carry_m, &si_poly, mul_mod, rlwe_q); - - let noise = measure_noise( - &rlwe_ct, - &carry_m, - rlwe_nttop, - rlwe_modop, - ideal_client_key.sk_rlwe.values(), - ); - println!("Noise RLWE(carry_m) accumulating {i}^th secret monomial: {noise}"); - } - } - - // Check galois keys - if false { - let g = bool_evaluator.g() as isize; - let mut rng = DefaultSecureRng::new(); - let mut scratch_matrix_dplus2_ring = vec![vec![0u64; rlwe_n]; d_rgsw + 2]; - for i in [g, -g] { - let mut m = vec![0u64; rlwe_n]; - RandomUniformDist::random_fill(&mut rng, &rlwe_q, m.as_mut_slice()); - let mut rlwe_ct = { - let mut data = vec![vec![0u64; rlwe_n]; 2]; - public_key_encrypt_rlwe( - &mut data, - &collective_pk.key, - &m, - rlwe_modop, - rlwe_nttop, - &mut rng, - ); - RlweCiphertext::<_, DefaultSecureRng>::new_trivial(data, false) - }; - - let auto_key = server_key_eval.galois_key_for_auto(i); - let (auto_map_index, auto_map_sign) = generate_auto_map(rlwe_n, i); - galois_auto( - &mut rlwe_ct, - auto_key, - &mut scratch_matrix_dplus2_ring, - &auto_map_index, - &auto_map_sign, - rlwe_modop, - rlwe_nttop, - rlwe_decomposer, - ); - - // send m(X) -> m(X^i) - let mut m_k = vec![0u64; rlwe_n]; - izip!(m.iter(), auto_map_index.iter(), auto_map_sign.iter()).for_each( - |(mi, to_index, to_sign)| { - if !to_sign { - m_k[*to_index] = rlwe_q - *mi; - } else { - m_k[*to_index] = *mi; - } - }, - ); - - // measure noise - let noise = measure_noise( - &rlwe_ct, - &m_k, - rlwe_nttop, - rlwe_modop, - ideal_client_key.sk_rlwe.values(), - ); - - println!("Noise after auto k={i}: {noise}"); - } - } - } - #[test] fn multi_party_nand() { - let bool_evaluator = - BoolEvaluator::>, u64, NttBackendU64, ModularOpsU64>::new(MP_BOOL_PARAMS); + let mut bool_evaluator = + BoolEvaluator::>, NttBackendU64, ModularOpsU64>::new(MP_BOOL_PARAMS); let (parties, collective_pk, _, _, server_key_eval, ideal_client_key) = - _multi_party_all_keygen(&bool_evaluator, 50); - - // PBS - let mut scratch_lwen_plus1 = vec![0u64; bool_evaluator.parameters.lwe_n + 1]; - let mut scratch_matrix_dplus2_ring = vec![ - vec![0u64; bool_evaluator.parameters.rlwe_n]; - bool_evaluator.parameters.d_rgsw + 2 - ]; + _multi_party_all_keygen(&bool_evaluator, 2); let mut m0 = true; let mut m1 = false; @@ -2389,13 +2160,7 @@ mod tests { let mut lwe1 = bool_evaluator.pk_encrypt(&collective_pk.key, m1); for _ in 0..2000 { - let lwe_out = bool_evaluator.nand( - &lwe0, - &lwe1, - &server_key_eval, - &mut scratch_lwen_plus1, - &mut scratch_matrix_dplus2_ring, - ); + let lwe_out = bool_evaluator.nand(&lwe0, &lwe1, &server_key_eval); let m_expected = !(m0 & m1); @@ -2403,39 +2168,39 @@ mod tests { { let noise0 = { let ideal = if m0 { - bool_evaluator.rlweq_by8 + bool_evaluator.pbs_info.rlweq_by8 } else { - bool_evaluator.rlwe_q() - bool_evaluator.rlweq_by8 + bool_evaluator.pbs_info.rlwe_q() - bool_evaluator.pbs_info.rlweq_by8 }; let n = measure_noise_lwe( &lwe0, ideal_client_key.sk_rlwe.values(), - &bool_evaluator.rlwe_modop, + &bool_evaluator.pbs_info.rlwe_modop, &ideal, ); let v = decrypt_lwe( &lwe0, ideal_client_key.sk_rlwe.values(), - &bool_evaluator.rlwe_modop, + &bool_evaluator.pbs_info.rlwe_modop, ); (n, v) }; let noise1 = { let ideal = if m1 { - bool_evaluator.rlweq_by8 + bool_evaluator.pbs_info.rlweq_by8 } else { - bool_evaluator.rlwe_q() - bool_evaluator.rlweq_by8 + bool_evaluator.pbs_info.rlwe_q() - bool_evaluator.pbs_info.rlweq_by8 }; let n = measure_noise_lwe( &lwe1, ideal_client_key.sk_rlwe.values(), - &bool_evaluator.rlwe_modop, + &bool_evaluator.pbs_info.rlwe_modop, &ideal, ); let v = decrypt_lwe( &lwe1, ideal_client_key.sk_rlwe.values(), - &bool_evaluator.rlwe_modop, + &bool_evaluator.pbs_info.rlwe_modop, ); (n, v) }; @@ -2451,20 +2216,21 @@ mod tests { let noise_out = { let ideal_m = if m_expected { - bool_evaluator.rlweq_by8 + bool_evaluator.pbs_info.rlweq_by8 } else { - bool_evaluator.parameters.rlwe_q - bool_evaluator.rlweq_by8 + bool_evaluator.pbs_info.parameters.rlwe_q().0 + - bool_evaluator.pbs_info.rlweq_by8 }; let n = measure_noise_lwe( &lwe_out, ideal_client_key.sk_rlwe.values(), - &bool_evaluator.rlwe_modop, + &bool_evaluator.pbs_info.rlwe_modop, &ideal_m, ); let v = decrypt_lwe( &lwe_out, ideal_client_key.sk_rlwe.values(), - &bool_evaluator.rlwe_modop, + &bool_evaluator.pbs_info.rlwe_modop, ); (n, v) }; @@ -2546,35 +2312,42 @@ mod tests { // }; let bool_evaluator = - BoolEvaluator::>, u64, NttBackendU64, ModularOpsU64>::new(MP_BOOL_PARAMS); + BoolEvaluator::>, NttBackendU64, ModularOpsU64>::new(MP_BOOL_PARAMS); // let (_, collective_pk, _, _, server_key_eval, ideal_client_key) = // _multi_party_all_keygen(&bool_evaluator, 20); let no_of_parties = 2; - let lwe_q = bool_evaluator.parameters.lwe_q; - let rlwe_q = bool_evaluator.parameters.rlwe_q; - let d_rgsw = bool_evaluator.parameters.d_rgsw; - let lwe_logq = bool_evaluator.parameters.lwe_logq; - let lwe_n = bool_evaluator.parameters.lwe_n; - let rlwe_n = bool_evaluator.parameters.rlwe_n; - let lwe_modop = &bool_evaluator.lwe_modop; - let rlwe_nttop = &bool_evaluator.rlwe_nttop; - let rlwe_modop = &bool_evaluator.rlwe_modop; - let rlwe_decomposer = &bool_evaluator.decomposer_rlwe; - let rlwe_gadget_vector = rlwe_decomposer.gadget_vector(); + let lwe_q = bool_evaluator.pbs_info.parameters.lwe_q().0; + let rlwe_q = bool_evaluator.pbs_info.parameters.rlwe_q().0; + let lwe_n = bool_evaluator.pbs_info.parameters.lwe_n().0; + let rlwe_n = bool_evaluator.pbs_info.parameters.rlwe_n().0; + let lwe_modop = &bool_evaluator.pbs_info.lwe_modop; + let rlwe_nttop = &bool_evaluator.pbs_info.rlwe_nttop; + let rlwe_modop = &bool_evaluator.pbs_info.rlwe_modop; + + let rlwe_rgsw_decomposer = &bool_evaluator.pbs_info.rlwe_rgsw_decomposer; + let rlwe_rgsw_gadget_a = rlwe_rgsw_decomposer.0.gadget_vector(); + let rlwe_rgsw_gadget_b = rlwe_rgsw_decomposer.1.gadget_vector(); + + // let rgsw_rgsw_decomposer = &bool_evaluator + // .pbs_info + // .parameters + // .rgsw_rgsw_decomposer::>(); + // let rgsw_rgsw_gagdet_a = rgsw_rgsw_decomposer.a().gadget_vector(); + // let rgsw_rgsw_gagdet_b = rgsw_rgsw_decomposer.b().gadget_vector(); let parties = (0..no_of_parties) .map(|_| bool_evaluator.client_key()) .collect_vec(); let ideal_client_key = { - let mut ideal_rlwe_sk = vec![0i32; bool_evaluator.rlwe_n()]; + let mut ideal_rlwe_sk = vec![0i32; bool_evaluator.pbs_info.rlwe_n()]; parties.iter().for_each(|k| { izip!(ideal_rlwe_sk.iter_mut(), k.sk_rlwe.values()).for_each(|(ideal_i, s_i)| { *ideal_i = *ideal_i + s_i; }); }); - let mut ideal_lwe_sk = vec![0i32; bool_evaluator.lwe_n()]; + let mut ideal_lwe_sk = vec![0i32; bool_evaluator.pbs_info.lwe_n()]; parties.iter().for_each(|k| { izip!(ideal_lwe_sk.iter_mut(), k.sk_lwe.values()).for_each(|(ideal_i, s_i)| { *ideal_i = *ideal_i + s_i; @@ -2608,7 +2381,8 @@ mod tests { public_key_share.as_slice(), ); - let m = vec![0u64; rlwe_n]; + let mut m = vec![0u64; rlwe_n]; + RandomUniformDist::random_fill(&mut rng, &rlwe_q, m.as_mut_slice()); let mut rlwe_ct = vec![vec![0u64; rlwe_n]; 2]; public_key_encrypt_rlwe( &mut rlwe_ct, @@ -2653,15 +2427,17 @@ mod tests { let server_key_shares = parties .iter() .map(|k| { - bool_evaluator.multi_party_sever_key_share(pbs_cr_seed, &collective_pk.key, k) + bool_evaluator.multi_party_server_key_share(pbs_cr_seed, &collective_pk.key, k) }) .collect_vec(); - let seeded_server_key = - aggregate_multi_party_server_key_shares::<_, _, _, ModularOpsU64, NttBackendU64>( - &server_key_shares, - rlwe_decomposer, - ); + let seeded_server_key = aggregate_multi_party_server_key_shares::< + _, + _, + (DefaultDecomposer, DefaultDecomposer), + ModularOpsU64, + NttBackendU64, + >(&server_key_shares); // Check noise in RGSW ciphertexts of ideal LWE secret elements if true { @@ -2673,40 +2449,33 @@ mod tests { .for_each(|(s_i, rgsw_ct_i)| { // X^{s[i]} let mut m_si = vec![0u64; rlwe_n]; - let s_i = *s_i * (bool_evaluator.embedding_factor as i32); + let s_i = *s_i * (bool_evaluator.pbs_info.embedding_factor as i32); if s_i < 0 { m_si[rlwe_n - (s_i.abs() as usize)] = rlwe_q - 1; } else { m_si[s_i as usize] = 1; } - _measure_noise_rgsw( - &rgsw_ct_i, - &m_si, - ideal_client_key.sk_rlwe.values(), - &rlwe_gadget_vector, - rlwe_q, - ); - // RLWE(-sm) let mut neg_s_eval = Vec::::try_convert_from(ideal_client_key.sk_rlwe.values(), &rlwe_q); rlwe_modop.elwise_neg_mut(&mut neg_s_eval); rlwe_nttop.forward(&mut neg_s_eval); - for j in 0..rlwe_decomposer.decomposition_count() { + for j in 0..rlwe_rgsw_decomposer.a().decomposition_count() { // -s[X]*X^{s_lwe[i]}*B_j let mut m_ideal = m_si.clone(); rlwe_nttop.forward(m_ideal.as_mut_slice()); rlwe_modop.elwise_mul_mut(m_ideal.as_mut_slice(), neg_s_eval.as_slice()); rlwe_nttop.backward(m_ideal.as_mut_slice()); rlwe_modop - .elwise_scalar_mul_mut(m_ideal.as_mut_slice(), &rlwe_gadget_vector[j]); + .elwise_scalar_mul_mut(m_ideal.as_mut_slice(), &rlwe_rgsw_gadget_a[j]); // RLWE(-s*X^{s_lwe[i]}*B_j) let mut rlwe_ct = vec![vec![0u64; rlwe_n]; 2]; rlwe_ct[0].copy_from_slice(&rgsw_ct_i[j]); - rlwe_ct[1] - .copy_from_slice(&rgsw_ct_i[j + rlwe_decomposer.decomposition_count()]); + rlwe_ct[1].copy_from_slice( + &rgsw_ct_i[j + rlwe_rgsw_decomposer.a().decomposition_count()], + ); let mut m_back = vec![0u64; rlwe_n]; decrypt_rlwe( @@ -2723,19 +2492,21 @@ mod tests { } // RLWE'(m) - for j in 0..rlwe_decomposer.decomposition_count() { + for j in 0..rlwe_rgsw_decomposer.b().decomposition_count() { // X^{s_lwe[i]}*B_j let mut m_ideal = m_si.clone(); rlwe_modop - .elwise_scalar_mul_mut(m_ideal.as_mut_slice(), &rlwe_gadget_vector[j]); + .elwise_scalar_mul_mut(m_ideal.as_mut_slice(), &rlwe_rgsw_gadget_b[j]); // RLWE(X^{s_lwe[i]}*B_j) let mut rlwe_ct = vec![vec![0u64; rlwe_n]; 2]; rlwe_ct[0].copy_from_slice( - &rgsw_ct_i[j + (2 * rlwe_decomposer.decomposition_count())], + &rgsw_ct_i[j + (2 * rlwe_rgsw_decomposer.a().decomposition_count())], ); rlwe_ct[1].copy_from_slice( - &rgsw_ct_i[j + (3 * rlwe_decomposer.decomposition_count())], + &rgsw_ct_i[j + + (2 * rlwe_rgsw_decomposer.a().decomposition_count() + + rlwe_rgsw_decomposer.b().decomposition_count())], ); let mut m_back = vec![0u64; rlwe_n]; @@ -2797,20 +2568,25 @@ mod tests { ]); // let mut rlwe_after = // RlweCiphertext::<_, DefaultSecureRng>::from_raw(rlwe_ct.clone(), false); - let mut scratch = - vec![vec![0u64; rlwe_n]; rlwe_decomposer.decomposition_count() + 2]; + let mut scratch = vec![ + vec![0u64; rlwe_n]; + std::cmp::max( + rlwe_rgsw_decomposer.0.decomposition_count(), + rlwe_rgsw_decomposer.1.decomposition_count() + ) + 2 + ]; rlwe_by_rgsw( &mut rlwe_after, &rgsw_ct_i, &mut scratch, - rlwe_decomposer, + rlwe_rgsw_decomposer, rlwe_nttop, rlwe_modop, ); // m1 = X^{s[i]} let mut m1 = vec![0u64; rlwe_n]; - let s_i = *s_i * (bool_evaluator.embedding_factor as i32); + let s_i = *s_i * (bool_evaluator.pbs_info.embedding_factor as i32); if s_i < 0 { m1[rlwe_n - (s_i.abs() as usize)] = rlwe_q - 1; } else { @@ -3064,9 +2840,4 @@ mod tests { // } // } } - - fn test_2() { - let bool_evaluator = - BoolEvaluator::>, u64, NttBackendU64, ModularOpsU64>::new(SP_BOOL_PARAMS); - } } diff --git a/src/bool/parameters.rs b/src/bool/parameters.rs index 71d86d3..f9dc15b 100644 --- a/src/bool/parameters.rs +++ b/src/bool/parameters.rs @@ -146,26 +146,25 @@ impl BoolParameters { } #[derive(Clone, Copy, PartialEq)] -struct DecompostionLogBase(pub(crate) usize); +pub(crate) struct DecompostionLogBase(pub(crate) usize); impl AsRef for DecompostionLogBase { fn as_ref(&self) -> &usize { &self.0 } } #[derive(Clone, Copy, PartialEq)] -struct DecompositionCount(pub(crate) usize); +pub(crate) struct DecompositionCount(pub(crate) usize); impl AsRef for DecompositionCount { fn as_ref(&self) -> &usize { &self.0 } } - #[derive(Clone, Copy, PartialEq)] -struct LweDimension(pub(crate) usize); +pub(crate) struct LweDimension(pub(crate) usize); #[derive(Clone, Copy, PartialEq)] -struct PolynomialSize(pub(crate) usize); +pub(crate) struct PolynomialSize(pub(crate) usize); #[derive(Clone, Copy, PartialEq)] -struct Modulus(pub(crate) T); +pub(crate) struct Modulus(pub(crate) T); pub(super) const SP_BOOL_PARAMS: BoolParameters = BoolParameters:: { rlwe_q: Modulus(268369921u64), diff --git a/src/decomposer.rs b/src/decomposer.rs index 544e30c..2db3f58 100644 --- a/src/decomposer.rs +++ b/src/decomposer.rs @@ -13,6 +13,30 @@ fn gadget_vector(logq: usize, logb: usize, d: usize) -> Vec { .collect_vec() } +pub trait RlweDecomposer { + type Element; + type D: Decomposer; + + /// Decomposer for RLWE Part A + fn a(&self) -> &Self::D; + /// Decomposer for RLWE Part B + fn b(&self) -> &Self::D; +} + +impl RlweDecomposer for (D, D) +where + D: Decomposer, +{ + type D = D; + type Element = D::Element; + fn a(&self) -> &Self::D { + &self.0 + } + fn b(&self) -> &Self::D { + &self.1 + } +} + pub trait Decomposer { type Element; fn new(q: Self::Element, logb: usize, d: usize) -> Self; @@ -142,6 +166,44 @@ impl Decomposer for DefaultDecompose } } +// impl Decomposer for dyn AsRef> +// where +// DefaultDecomposer: Decomposer, +// { +// type Element = T; + +// fn new(q: Self::Element, logb: usize, d: usize) -> Self { +// DefaultDecomposer::::new(q, logb, d) +// } + +// fn decompose(&self, v: &Self::Element) -> Vec { +// todo!() +// } + +// fn decomposition_count(&self) -> usize { +// todo!() +// } +// } + +// impl>> Decomposer for U +// where +// DefaultDecomposer: Decomposer, +// { +// type Element = T; + +// fn new(q: Self::Element, logb: usize, d: usize) -> Self { +// todo!() +// } + +// fn decompose(&self, v: &Self::Element) -> Vec { +// todo!() +// } + +// fn decomposition_count(&self) -> usize { +// todo!() +// } +// } + fn round_value(value: T, ignore_bits: usize) -> T { if ignore_bits == 0 { return value; diff --git a/src/lib.rs b/src/lib.rs index 2c77692..69f5dde 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -42,6 +42,9 @@ pub trait Matrix: AsRef<[Self::R]> { fn split_at_row(&self, idx: usize) -> (&[::R], &[::R]) { self.as_ref().split_at(idx) } + + /// Does the matrix fit sub-matrix of dimension row x col + fn fits(&self, row: usize, col: usize) -> bool; } pub trait MatrixMut: Matrix + AsMut<[::R]> @@ -96,6 +99,10 @@ impl Matrix for Vec> { fn dimension(&self) -> (usize, usize) { (self.len(), self[0].len()) } + + fn fits(&self, row: usize, col: usize) -> bool { + self.len() >= row && self[0].len() >= col + } } impl Matrix for &[Vec] { @@ -105,6 +112,10 @@ impl Matrix for &[Vec] { fn dimension(&self) -> (usize, usize) { (self.len(), self[0].len()) } + + fn fits(&self, row: usize, col: usize) -> bool { + self.len() >= row && self[0].len() >= col + } } impl Matrix for &mut [Vec] { @@ -114,6 +125,10 @@ impl Matrix for &mut [Vec] { fn dimension(&self) -> (usize, usize) { (self.len(), self[0].len()) } + + fn fits(&self, row: usize, col: usize) -> bool { + self.len() >= row && self[0].len() >= col + } } impl MatrixMut for Vec> {} diff --git a/src/lwe.rs b/src/lwe.rs index 4f99ed7..27df3a4 100644 --- a/src/lwe.rs +++ b/src/lwe.rs @@ -111,7 +111,9 @@ pub(crate) fn lwe_key_switch< operator: &Op, decomposer: &D, ) { - assert!(lwe_ksk.dimension().0 == ((lwe_in.as_ref().len() - 1) * decomposer.decomposition_count())); + assert!( + lwe_ksk.dimension().0 == ((lwe_in.as_ref().len() - 1) * decomposer.decomposition_count()) + ); assert!(lwe_out.as_ref().len() == lwe_ksk.dimension().1); let lwe_in_a_decomposed = lwe_in diff --git a/src/noise.rs b/src/noise.rs index fa7e28d..6a0b9ed 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -17,229 +17,244 @@ mod tests { Matrix, Row, Secret, }; - // Test B part with limbd -1 when variance of m is 1 - #[test] - fn trial() { - let logq = 28; - let ring_size = 1 << 10; - let q = generate_prime(logq, (ring_size as u64) << 1, 1 << logq).unwrap(); - let logb = 7; - let d0 = 3; - let d1 = d0 - 1; - - let sk = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize); - - let mut rng = DefaultSecureRng::new(); - let decomposer = DefaultDecomposer::new(q, logb, d0); - let gadget_vector = decomposer.gadget_vector(); - - for i in 0..100 { - // m should have norm 1 - let mut m0 = vec![0u64; ring_size as usize]; - m0[thread_rng().gen_range(0..ring_size)] = 1; - - let modq_op = ModularOpsU64::new(q); - let nttq_op = NttBackendU64::new(q, ring_size); - - // Encrypt RGSW(m0) - let mut rgsw_seed = [0u8; 32]; - rng.fill_bytes(&mut rgsw_seed); - let mut seeded_rgsw = - SeededRgswCiphertext::>, _>::empty(ring_size, d0, rgsw_seed, q); - let mut p_rng = DefaultSecureRng::new_seeded(rgsw_seed); - secret_key_encrypt_rgsw( - &mut seeded_rgsw.data, - &m0, - &gadget_vector, - sk.values(), - &modq_op, - &nttq_op, - &mut p_rng, - &mut rng, - ); - - // Encrypt RLWE(m1) - let mut m1 = vec![0u64; ring_size]; - RandomUniformDist::random_fill(&mut rng, &q, m1.as_mut_slice()); - let mut rlwe_seed = [0u8; 32]; - rng.fill_bytes(&mut rlwe_seed); - let mut seeded_rlwe: SeededRlweCiphertext, [u8; 32]> = - SeededRlweCiphertext::, _>::empty(ring_size, rlwe_seed, q); - let mut p_rng = DefaultSecureRng::new_seeded(rlwe_seed); - secret_key_encrypt_rlwe( - &m1, - &mut seeded_rlwe.data, - sk.values(), - &modq_op, - &nttq_op, - &mut p_rng, - &mut rng, - ); - - let mut rlwe = RlweCiphertext::>, DefaultSecureRng>::from(&seeded_rlwe); - let rgsw = RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from( - &seeded_rgsw, - ); - - // RLWE(m0m1) = RLWE(m1) x RGSW(m0) - let mut scratch = vec![vec![0u64; ring_size]; d0 + 2]; - less1_rlwe_by_rgsw( - &mut rlwe, - &rgsw.data, - &mut scratch, - &decomposer, - &nttq_op, - &modq_op, - 0, - 1, - ); - // rlwe_by_rgsw( - // &mut rlwe, - // &rgsw.data, - // &mut scratch, - // &decomposer, - // &nttq_op, - // &modq_op, - // ); - - // measure noise - let mul_mod = |v0: &u64, v1: &u64| ((*v0 as u128 * *v1 as u128) % q as u128) as u64; - let m0m1 = negacyclic_mul(&m0, &m1, mul_mod, q); - let noise = measure_noise(&rlwe, &m0m1, &nttq_op, &modq_op, sk.values()); - println!("Noise: {noise}"); - } - } - - // Test B part with limbd -1 when variance of m is 1 - #[test] - fn rgsw_saver() { - let logq = 60; - let ring_size = 1 << 11; - let q = generate_prime(logq, (ring_size as u64) << 1, 1 << logq).unwrap(); - let logb = 12; - let d0 = 4; - - let sk = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize); - - let mut rng = DefaultSecureRng::new(); - - let decomposer = DefaultDecomposer::new(q, logb, d0); - let gadget_vector = decomposer.gadget_vector(); - - for i in 0..100 { - let modq_op = ModularOpsU64::new(q); - let nttq_op = NttBackendU64::new(q, ring_size); - - // Encrypt RGSW(m0) - let mut m0 = vec![0u64; ring_size as usize]; - m0[thread_rng().gen_range(0..ring_size)] = 1; - let mut rgsw_seed = [0u8; 32]; - rng.fill_bytes(&mut rgsw_seed); - let mut seeded_rgsw0 = - SeededRgswCiphertext::>, _>::empty(ring_size, d0, rgsw_seed, q); - let mut p_rng = DefaultSecureRng::new_seeded(rgsw_seed); - secret_key_encrypt_rgsw( - &mut seeded_rgsw0.data, - &m0, - &gadget_vector, - sk.values(), - &modq_op, - &nttq_op, - &mut p_rng, - &mut rng, - ); - - // Encrypt RGSW(m1) - let mut m1 = vec![0u64; ring_size as usize]; - m1[thread_rng().gen_range(0..ring_size)] = 1; - let mut rgsw_seed = [0u8; 32]; - rng.fill_bytes(&mut rgsw_seed); - let mut seeded_rgsw1 = - SeededRgswCiphertext::>, _>::empty(ring_size, d0, rgsw_seed, q); - let mut p_rng = DefaultSecureRng::new_seeded(rgsw_seed); - secret_key_encrypt_rgsw( - &mut seeded_rgsw1.data, - &m1, - &gadget_vector, - sk.values(), - &modq_op, - &nttq_op, - &mut p_rng, - &mut rng, - ); - - // TODO(Jay): Why cant you create RgswCIphertext from SeededRgswCiphertext? - let mut rgsw0 = { - let mut evl_tmp = - RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from( - &seeded_rgsw0, - ); - evl_tmp - .data - .iter_mut() - .for_each(|ri| nttq_op.backward(ri.as_mut())); - evl_tmp.data - }; - let rgsw1 = RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from( - &seeded_rgsw1, - ); - let mut scratch_matrix_d_plus_rgsw_by_ring = vec![vec![0u64; ring_size]; d0 + (d0 * 4)]; - - // RGSW(m0m1) = RGSW(m0)xRGSW(m1) - rgsw_by_rgsw_inplace( - &mut rgsw0, - &rgsw1.data, - &decomposer, - &mut scratch_matrix_d_plus_rgsw_by_ring, - &nttq_op, - &modq_op, - ); - - // send RGSW(m0m1) to Evaluation domain - let mut rgsw01 = rgsw0; - rgsw01 - .iter_mut() - .for_each(|v| nttq_op.forward(v.as_mut_slice())); - - // RLWE(m2) - let mut m2 = vec![0u64; ring_size as usize]; - RandomUniformDist::random_fill(&mut rng, &q, m2.as_mut_slice()); - let mut rlwe_seed = [0u8; 32]; - rng.fill_bytes(&mut rlwe_seed); - let mut seeded_rlwe = - SeededRlweCiphertext::, _>::empty(ring_size, rlwe_seed, q); - let mut p_rng = DefaultSecureRng::new_seeded(rlwe_seed); - secret_key_encrypt_rlwe( - &m2, - &mut seeded_rlwe.data, - sk.values(), - &modq_op, - &nttq_op, - &mut p_rng, - &mut rng, - ); - - let mut rlwe = RlweCiphertext::>, DefaultSecureRng>::from(&seeded_rlwe); - - // RLWE(m0m1m2) = RLWE(m2) x RGSW(m0m1) - let mut scratch_matrix_dplus2_ring = vec![vec![0u64; ring_size]; d0 + 2]; - less1_rlwe_by_rgsw( - &mut rlwe, - &rgsw01, - &mut scratch_matrix_dplus2_ring, - &decomposer, - &nttq_op, - &modq_op, - 1, - 2, - ); - - let mul_mod = |v0: &u64, v1: &u64| ((*v0 as u128 * *v1 as u128) % q as u128) as u64; - let m0m1 = negacyclic_mul(&m0, &m1, mul_mod, q); - let m0m1m2 = negacyclic_mul(&m2, &m0m1, mul_mod, q); - let noise = measure_noise(&rlwe.data, &m0m1m2, &nttq_op, &modq_op, sk.values()); - - println!("Noise: {noise}"); - } - } + // // Test B part with limbd -1 when variance of m is 1 + // #[test] + // fn trial() { + // let logq = 28; + // let ring_size = 1 << 10; + // let q = generate_prime(logq, (ring_size as u64) << 1, 1 << + // logq).unwrap(); let logb = 7; + // let d0 = 3; + // let d1 = d0 - 1; + + // let sk = RlweSecret::random((ring_size >> 1) as usize, ring_size as + // usize); + + // let mut rng = DefaultSecureRng::new(); + // let decomposer = DefaultDecomposer::new(q, logb, d0); + // let gadget_vector = decomposer.gadget_vector(); + + // for i in 0..100 { + // // m should have norm 1 + // let mut m0 = vec![0u64; ring_size as usize]; + // m0[thread_rng().gen_range(0..ring_size)] = 1; + + // let modq_op = ModularOpsU64::new(q); + // let nttq_op = NttBackendU64::new(q, ring_size); + + // // Encrypt RGSW(m0) + // let mut rgsw_seed = [0u8; 32]; + // rng.fill_bytes(&mut rgsw_seed); + // let mut seeded_rgsw = + // SeededRgswCiphertext::>, _>::empty(ring_size, + // d0, rgsw_seed, q); let mut p_rng = + // DefaultSecureRng::new_seeded(rgsw_seed); + // secret_key_encrypt_rgsw( + // &mut seeded_rgsw.data, + // &m0, + // &gadget_vector, + // &gadget_vector, + // sk.values(), + // &modq_op, + // &nttq_op, + // &mut p_rng, + // &mut rng, + // ); + + // // Encrypt RLWE(m1) + // let mut m1 = vec![0u64; ring_size]; + // RandomUniformDist::random_fill(&mut rng, &q, m1.as_mut_slice()); + // let mut rlwe_seed = [0u8; 32]; + // rng.fill_bytes(&mut rlwe_seed); + // let mut seeded_rlwe: SeededRlweCiphertext, [u8; 32]> = + // SeededRlweCiphertext::, _>::empty(ring_size, + // rlwe_seed, q); let mut p_rng = + // DefaultSecureRng::new_seeded(rlwe_seed); + // secret_key_encrypt_rlwe( + // &m1, + // &mut seeded_rlwe.data, + // sk.values(), + // &modq_op, + // &nttq_op, + // &mut p_rng, + // &mut rng, + // ); + + // let mut rlwe = RlweCiphertext::>, + // DefaultSecureRng>::from(&seeded_rlwe); let rgsw = + // RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, + // NttBackendU64>::from( &seeded_rgsw, + // ); + + // // RLWE(m0m1) = RLWE(m1) x RGSW(m0) + // let mut scratch = vec![vec![0u64; ring_size]; d0 + 2]; + // less1_rlwe_by_rgsw( + // &mut rlwe, + // &rgsw.data, + // &mut scratch, + // &decomposer, + // &nttq_op, + // &modq_op, + // 0, + // 1, + // ); + // // rlwe_by_rgsw( + // // &mut rlwe, + // // &rgsw.data, + // // &mut scratch, + // // &decomposer, + // // &nttq_op, + // // &modq_op, + // // ); + + // // measure noise + // let mul_mod = |v0: &u64, v1: &u64| ((*v0 as u128 * *v1 as u128) % + // q as u128) as u64; let m0m1 = negacyclic_mul(&m0, &m1, + // mul_mod, q); let noise = measure_noise(&rlwe, &m0m1, + // &nttq_op, &modq_op, sk.values()); println!("Noise: {noise}"); + // } + // } + + // // Test B part with limbd -1 when variance of m is 1 + // #[test] + // fn rgsw_saver() { + // let logq = 60; + // let ring_size = 1 << 11; + // let q = generate_prime(logq, (ring_size as u64) << 1, 1 << + // logq).unwrap(); let logb = 12; + // let d0 = 4; + + // let sk = RlweSecret::random((ring_size >> 1) as usize, ring_size as + // usize); + + // let mut rng = DefaultSecureRng::new(); + + // let decomposer = DefaultDecomposer::new(q, logb, d0); + // let gadget_vector = decomposer.gadget_vector(); + + // for i in 0..100 { + // let modq_op = ModularOpsU64::new(q); + // let nttq_op = NttBackendU64::new(q, ring_size); + + // // Encrypt RGSW(m0) + // let mut m0 = vec![0u64; ring_size as usize]; + // m0[thread_rng().gen_range(0..ring_size)] = 1; + // let mut rgsw_seed = [0u8; 32]; + // rng.fill_bytes(&mut rgsw_seed); + // let mut seeded_rgsw0 = + // SeededRgswCiphertext::>, _>::empty(ring_size, + // d0, rgsw_seed, q); let mut p_rng = + // DefaultSecureRng::new_seeded(rgsw_seed); + // secret_key_encrypt_rgsw( + // &mut seeded_rgsw0.data, + // &m0, + // &gadget_vector, + // &gadget_vector, + // sk.values(), + // &modq_op, + // &nttq_op, + // &mut p_rng, + // &mut rng, + // ); + + // // Encrypt RGSW(m1) + // let mut m1 = vec![0u64; ring_size as usize]; + // m1[thread_rng().gen_range(0..ring_size)] = 1; + // let mut rgsw_seed = [0u8; 32]; + // rng.fill_bytes(&mut rgsw_seed); + // let mut seeded_rgsw1 = + // SeededRgswCiphertext::>, _>::empty(ring_size, + // d0, rgsw_seed, q); let mut p_rng = + // DefaultSecureRng::new_seeded(rgsw_seed); + // secret_key_encrypt_rgsw( + // &mut seeded_rgsw1.data, + // &m1, + // &gadget_vector, + // &gadget_vector, + // sk.values(), + // &modq_op, + // &nttq_op, + // &mut p_rng, + // &mut rng, + // ); + + // // TODO(Jay): Why cant you create RgswCIphertext from + // SeededRgswCiphertext? let mut rgsw0 = { + // let mut evl_tmp = + // RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, + // NttBackendU64>::from( &seeded_rgsw0, + // ); + // evl_tmp + // .data + // .iter_mut() + // .for_each(|ri| nttq_op.backward(ri.as_mut())); + // evl_tmp.data + // }; + // let rgsw1 = RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, + // NttBackendU64>::from( &seeded_rgsw1, + // ); + // let mut scratch_matrix_d_plus_rgsw_by_ring = vec![vec![0u64; + // ring_size]; d0 + (d0 * 4)]; + + // // RGSW(m0m1) = RGSW(m0)xRGSW(m1) + // rgsw_by_rgsw_inplace( + // &mut rgsw0, + // &rgsw1.data, + // &decomposer, + // &decomposer, + // &mut scratch_matrix_d_plus_rgsw_by_ring, + // &nttq_op, + // &modq_op, + // ); + + // // send RGSW(m0m1) to Evaluation domain + // let mut rgsw01 = rgsw0; + // rgsw01 + // .iter_mut() + // .for_each(|v| nttq_op.forward(v.as_mut_slice())); + + // // RLWE(m2) + // let mut m2 = vec![0u64; ring_size as usize]; + // RandomUniformDist::random_fill(&mut rng, &q, m2.as_mut_slice()); + // let mut rlwe_seed = [0u8; 32]; + // rng.fill_bytes(&mut rlwe_seed); + // let mut seeded_rlwe = + // SeededRlweCiphertext::, _>::empty(ring_size, + // rlwe_seed, q); let mut p_rng = + // DefaultSecureRng::new_seeded(rlwe_seed); + // secret_key_encrypt_rlwe( + // &m2, + // &mut seeded_rlwe.data, + // sk.values(), + // &modq_op, + // &nttq_op, + // &mut p_rng, + // &mut rng, + // ); + + // let mut rlwe = RlweCiphertext::>, + // DefaultSecureRng>::from(&seeded_rlwe); + + // // RLWE(m0m1m2) = RLWE(m2) x RGSW(m0m1) + // let mut scratch_matrix_dplus2_ring = vec![vec![0u64; ring_size]; + // d0 + 2]; less1_rlwe_by_rgsw( + // &mut rlwe, + // &rgsw01, + // &mut scratch_matrix_dplus2_ring, + // &decomposer, + // &nttq_op, + // &modq_op, + // 1, + // 2, + // ); + + // let mul_mod = |v0: &u64, v1: &u64| ((*v0 as u128 * *v1 as u128) % + // q as u128) as u64; let m0m1 = negacyclic_mul(&m0, &m1, + // mul_mod, q); let m0m1m2 = negacyclic_mul(&m2, &m0m1, mul_mod, + // q); let noise = measure_noise(&rlwe.data, &m0m1m2, &nttq_op, + // &modq_op, sk.values()); + + // println!("Noise: {noise}"); + // } + // } } diff --git a/src/rgsw.rs b/src/rgsw.rs index 590ee45..a46bbe9 100644 --- a/src/rgsw.rs +++ b/src/rgsw.rs @@ -10,7 +10,7 @@ use num_traits::{PrimInt, Signed, ToPrimitive, Zero}; use crate::{ backend::{ArithmeticOps, VectorOps}, - decomposer::{self, Decomposer}, + decomposer::{self, Decomposer, RlweDecomposer}, ntt::{self, Ntt, NttInit}, random::{DefaultSecureRng, NewWithSeed, RandomGaussianDist, RandomUniformDist}, utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom, WithLocal}, @@ -27,19 +27,14 @@ where } impl SeededAutoKey { - fn from_raw(data: M, seed: S, modulus: M::MatElement) -> Self { - assert!(data.dimension().0 % 3 == 0); - - SeededAutoKey { - data, - seed, - modulus, - } - } - - fn empty(ring_size: usize, d_rgsw: usize, seed: S, modulus: M::MatElement) -> Self { + fn empty( + ring_size: usize, + auto_decomposer: &D, + seed: S, + modulus: M::MatElement, + ) -> Self { SeededAutoKey { - data: M::zeros(d_rgsw, ring_size), + data: M::zeros(auto_decomposer.decomposition_count(), ring_size), seed, modulus: modulus, } @@ -89,8 +84,31 @@ where } pub struct RgswCiphertext { - data: M, + /// Rgsw ciphertext polynomials + pub(crate) data: M, modulus: M::MatElement, + /// Decomposition for RLWE part A + d_a: usize, + /// Decomposition for RLWE part B + d_b: usize, +} + +impl RgswCiphertext { + pub(crate) fn empty( + ring_size: usize, + decomposer: &D, + modulus: M::MatElement, + ) -> RgswCiphertext { + RgswCiphertext { + data: M::zeros( + decomposer.a().decomposition_count() * 2 + decomposer.b().decomposition_count() * 2, + ring_size, + ), + d_a: decomposer.a().decomposition_count(), + d_b: decomposer.b().decomposition_count(), + modulus, + } + } } pub struct SeededRgswCiphertext @@ -100,29 +118,28 @@ where pub(crate) data: M, seed: S, modulus: M::MatElement, + /// Decomposition for RLWE part A + d_a: usize, + /// Decomposition for RLWE part B + d_b: usize, } impl SeededRgswCiphertext { - pub(crate) fn from_raw(data: M, seed: S, modulus: M::MatElement) -> SeededRgswCiphertext { - assert!(data.dimension().0 % 3 == 0); - - SeededRgswCiphertext { - data, - seed, - modulus, - } - } - - pub(crate) fn empty( + pub(crate) fn empty( ring_size: usize, - d_rgsw: usize, + decomposer: &D, seed: S, modulus: M::MatElement, ) -> SeededRgswCiphertext { SeededRgswCiphertext { - data: M::zeros(d_rgsw * 3, ring_size), + data: M::zeros( + decomposer.a().decomposition_count() * 2 + decomposer.b().decomposition_count(), + ring_size, + ), seed, modulus: modulus, + d_a: decomposer.a().decomposition_count(), + d_b: decomposer.b().decomposition_count(), } } } @@ -157,28 +174,28 @@ where M: Debug, { fn from(value: &SeededRgswCiphertext) -> Self { - let d = value.data.dimension().0.div(3); - - let mut data = M::zeros(4 * d, value.data.dimension().1); + let mut data = M::zeros(value.d_a * 2 + value.d_b * 2, value.data.dimension().1); // copy RLWE'(-sm) - izip!(data.iter_rows_mut().take(2 * d), value.data.iter_rows()).for_each( - |(to_ri, from_ri)| { - to_ri.as_mut().copy_from_slice(from_ri.as_ref()); - }, - ); + izip!( + data.iter_rows_mut().take(value.d_a * 2), + value.data.iter_rows().take(value.d_a * 2) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); // sample A polynomials of RLWE'(m) - RLWE'A(m) // TODO(Jay): Do we want to be generic over RandomGenerator used here? I think // not. let mut p_rng = R::new_with_seed(value.seed.clone()); - izip!(data.iter_rows_mut().skip(2 * d).take(d)) + izip!(data.iter_rows_mut().skip(value.d_a * 2).take(value.d_b * 1)) .for_each(|ri| p_rng.random_fill(&value.modulus, ri.as_mut())); // RLWE'_B(m) izip!( - data.iter_rows_mut().skip(3 * d), - value.data.iter_rows().skip(2 * d) + data.iter_rows_mut().skip(value.d_a * 2 + value.d_b), + value.data.iter_rows().skip(value.d_a * 2) ) .for_each(|(to_ri, from_ri)| { to_ri.as_mut().copy_from_slice(from_ri.as_ref()); @@ -208,22 +225,21 @@ where M: Debug, { fn from(value: &RgswCiphertext) -> Self { - assert!(value.data.dimension().0 % 4 == 0); - let d = value.data.dimension().0.div(4); - - let mut data = M::zeros(4 * d, value.data.dimension().1); + let mut data = M::zeros(value.d_a * 2 + value.d_b * 2, value.data.dimension().1); // copy RLWE'(-sm) - izip!(data.iter_rows_mut().take(2 * d), value.data.iter_rows()).for_each( - |(to_ri, from_ri)| { - to_ri.as_mut().copy_from_slice(from_ri.as_ref()); - }, - ); + izip!( + data.iter_rows_mut().take(value.d_a * 2), + value.data.iter_rows().take(value.d_a * 2) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); // copy RLWE'(m) izip!( - data.iter_rows_mut().skip(2 * d), - value.data.iter_rows().skip(2 * d) + data.iter_rows_mut().skip(value.d_a * 2), + value.data.iter_rows().skip(value.d_a * 2) ) .for_each(|(to_ri, from_ri)| { to_ri.as_mut().copy_from_slice(from_ri.as_ref()); @@ -258,6 +274,10 @@ impl Matrix for RgswCiphertextEvaluationDomain { fn dimension(&self) -> (usize, usize) { self.data.dimension() } + + fn fits(&self, row: usize, col: usize) -> bool { + self.data.fits(row, col) + } } impl AsRef<[M::R]> for RgswCiphertextEvaluationDomain { @@ -308,6 +328,10 @@ impl Matrix for RlweCiphertext { fn dimension(&self) -> (usize, usize) { self.data.dimension() } + + fn fits(&self, row: usize, col: usize) -> bool { + self.data.fits(row, col) + } } impl MatrixMut for RlweCiphertext where ::R: RowMut {} @@ -501,6 +525,9 @@ pub(crate) fn decompose_r>( } /// Sends RLWE_{s}(X) -> RLWE_{s}(X^k) where k is some galois element +/// +/// - scratch_matrix: must have dimension at-least d+2 x ring_size. d rows to +/// store decomposed polynomials and 2 for rlwe pub(crate) fn galois_auto< MT: Matrix + IsTrivial + MatrixMut, Mmut: MatrixMut, @@ -510,7 +537,7 @@ pub(crate) fn galois_auto< >( rlwe_in: &mut MT, ksk: &Mmut, - scratch_matrix_dplus2_ring: &mut Mmut, + scratch_matrix: &mut Mmut, auto_map_index: &[usize], auto_map_sign: &[bool], mod_op: &ModOp, @@ -522,8 +549,11 @@ pub(crate) fn galois_auto< MT::MatElement: Copy + Zero, { let d = decomposer.decomposition_count(); + let ring_size = rlwe_in.dimension().1; + assert!(rlwe_in.dimension().0 == 2); + assert!(scratch_matrix.fits(d + 2, ring_size)); - let (scratch_matrix_d_ring, tmp_rlwe_out) = scratch_matrix_dplus2_ring.split_at_row_mut(d); + let (scratch_matrix_d_ring, tmp_rlwe_out) = scratch_matrix.split_at_row_mut(d); // send b(X) -> b(X^k) izip!( @@ -697,18 +727,19 @@ pub(crate) fn less1_rlwe_by_rgsw< /// /// - rlwe_in: is RLWE(m0) with polynomials in coefficient domain /// - rgsw_in: is RGSW(m1) with polynomials in evaluation domain -/// - scratch_matrix_d_ring: is a matrix of dimension (d_rgsw, ring_size) used -/// as scratch space to store decomposed Ring elements temporarily +/// - scratch_matrix_d_ring: is a matrix with atleast max(d_a, d_b) rows and +/// ring_size columns. It's used to store decomposed polynomials and out RLWE +/// temoporarily pub(crate) fn rlwe_by_rgsw< Mmut: MatrixMut, MT: Matrix + MatrixMut + IsTrivial, - D: Decomposer, + D: RlweDecomposer, ModOp: VectorOps, NttOp: Ntt, >( rlwe_in: &mut MT, rgsw_in: &Mmut, - scratch_matrix_dplus2_ring: &mut Mmut, + scratch_matrix: &mut Mmut, decomposer: &D, ntt_op: &NttOp, mod_op: &ModOp, @@ -717,21 +748,28 @@ pub(crate) fn rlwe_by_rgsw< ::R: RowMut, ::R: RowMut, { - let d_rgsw = decomposer.decomposition_count(); - assert!(scratch_matrix_dplus2_ring.dimension() == (d_rgsw + 2, rlwe_in.dimension().1)); - assert!(rgsw_in.dimension() == (d_rgsw * 4, rlwe_in.dimension().1)); + let decomposer_a = decomposer.a(); + let decomposer_b = decomposer.b(); + let d_a = decomposer_a.decomposition_count(); + let d_b = decomposer_b.decomposition_count(); + let max_d = std::cmp::max(d_a, d_b); + assert!(scratch_matrix.fits(max_d + 2, rlwe_in.dimension().1)); + assert!(rgsw_in.dimension() == (d_a * 2 + d_b * 2, rlwe_in.dimension().1)); // decomposed RLWE x RGSW - let (rlwe_dash_nsm, rlwe_dash_m) = rgsw_in.split_at_row(d_rgsw * 2); - let (scratch_matrix_d_ring, scratch_rlwe_out) = - scratch_matrix_dplus2_ring.split_at_row_mut(d_rgsw); + let (rlwe_dash_nsm, rlwe_dash_m) = rgsw_in.split_at_row(d_a * 2); + let (scratch_matrix_d_ring, scratch_rlwe_out) = scratch_matrix.split_at_row_mut(max_d); scratch_rlwe_out[0].as_mut().fill(Mmut::MatElement::zero()); scratch_rlwe_out[1].as_mut().fill(Mmut::MatElement::zero()); // RLWE_in = a_in, b_in; RLWE_out = a_out, b_out if !rlwe_in.is_trivial() { // a_in = 0 when RLWE_in is trivial RLWE ciphertext // decomp - decompose_r(rlwe_in.get_row_slice(0), scratch_matrix_d_ring, decomposer); + decompose_r( + rlwe_in.get_row_slice(0), + scratch_matrix_d_ring, + decomposer_a, + ); scratch_matrix_d_ring .iter_mut() .for_each(|r| ntt_op.forward(r.as_mut())); @@ -739,19 +777,23 @@ pub(crate) fn rlwe_by_rgsw< routine( scratch_rlwe_out[0].as_mut(), scratch_matrix_d_ring.as_ref(), - &rlwe_dash_nsm[..d_rgsw], + &rlwe_dash_nsm[..d_a], mod_op, ); // b_out += decomp \cdot RLWE_B'(-sm) routine( scratch_rlwe_out[1].as_mut(), scratch_matrix_d_ring.as_ref(), - &rlwe_dash_nsm[d_rgsw..], + &rlwe_dash_nsm[d_a..], mod_op, ); } // decomp - decompose_r(rlwe_in.get_row_slice(1), scratch_matrix_d_ring, decomposer); + decompose_r( + rlwe_in.get_row_slice(1), + scratch_matrix_d_ring, + decomposer_b, + ); scratch_matrix_d_ring .iter_mut() .for_each(|r| ntt_op.forward(r.as_mut())); @@ -759,14 +801,14 @@ pub(crate) fn rlwe_by_rgsw< routine( scratch_rlwe_out[0].as_mut(), scratch_matrix_d_ring.as_ref(), - &rlwe_dash_m[..d_rgsw], + &rlwe_dash_m[..d_b], mod_op, ); // b_out += decomp \cdot RLWE_B'(m) routine( scratch_rlwe_out[1].as_mut(), scratch_matrix_d_ring.as_ref(), - &rlwe_dash_m[d_rgsw..], + &rlwe_dash_m[d_b..], mod_op, ); @@ -800,54 +842,53 @@ pub(crate) fn rlwe_by_rgsw< /// /// - rgsw_0: RGSW(m0) /// - rgsw_1_eval: RGSW(m1) in Evaluation domain -/// - scratch_matrix_d_plus_rgsw_by_ring: scratch space matrix of size -/// (d+(d*4))xring_size, where d equals d_rgsw +/// - scratch_matrix_d_plus_rgsw_by_ring: scratch space matrix with rows +/// (max(d_a, d_b) + d_a*2+d_b*2) and columns ring_size pub(crate) fn rgsw_by_rgsw_inplace< Mmut: MatrixMut, - D: Decomposer, + D: RlweDecomposer, ModOp: VectorOps, NttOp: Ntt, >( rgsw_0: &mut Mmut, rgsw_1_eval: &Mmut, decomposer: &D, - scratch_matrix_d_plus_rgsw_by_ring: &mut Mmut, + scratch_matrix: &mut Mmut, ntt_op: &NttOp, mod_op: &ModOp, ) where ::R: RowMut, Mmut::MatElement: Copy + Zero, { - let d_rgsw = decomposer.decomposition_count(); - assert!(rgsw_0.dimension().0 == 4 * d_rgsw); + let decomposer_a = decomposer.a(); + let decomposer_b = decomposer.b(); + let d_a = decomposer_a.decomposition_count(); + let d_b = decomposer_b.decomposition_count(); + let max_d = std::cmp::max(d_a, d_b); + let rgsw_rows = d_a * 2 + d_b * 2; + assert!(rgsw_0.dimension().0 == rgsw_rows); let ring_size = rgsw_0.dimension().1; - assert!(rgsw_1_eval.dimension() == (4 * d_rgsw, ring_size)); - assert!(scratch_matrix_d_plus_rgsw_by_ring.dimension() == (d_rgsw + (d_rgsw * 4), ring_size)); + assert!(rgsw_1_eval.dimension() == (rgsw_rows, ring_size)); + assert!(scratch_matrix.fits(max_d + rgsw_rows, ring_size)); - let (decomp_r_space, rgsw_space) = scratch_matrix_d_plus_rgsw_by_ring.split_at_row_mut(d_rgsw); + let (decomp_r_space, rgsw_space) = scratch_matrix.split_at_row_mut(max_d); // zero rgsw_space rgsw_space .iter_mut() .for_each(|ri| ri.as_mut().fill(Mmut::MatElement::zero())); - let (rlwe_dash_space_nsm, rlwe_dash_space_m) = rgsw_space.split_at_mut(d_rgsw * 2); + let (rlwe_dash_space_nsm, rlwe_dash_space_m) = rgsw_space.split_at_mut(d_a * 2); let (rlwe_dash_space_nsm_parta, rlwe_dash_space_nsm_partb) = - rlwe_dash_space_nsm.split_at_mut(d_rgsw); - let (rlwe_dash_space_m_parta, rlwe_dash_space_m_partb) = rlwe_dash_space_m.split_at_mut(d_rgsw); + rlwe_dash_space_nsm.split_at_mut(d_a); + let (rlwe_dash_space_m_parta, rlwe_dash_space_m_partb) = rlwe_dash_space_m.split_at_mut(d_b); - let (rgsw0_nsm, rgsw0_m) = rgsw_0.split_at_row(d_rgsw * 2); - let (rgsw1_nsm, rgsw1_m) = rgsw_1_eval.split_at_row(d_rgsw * 2); + let (rgsw0_nsm, rgsw0_m) = rgsw_0.split_at_row(d_a * 2); + let (rgsw1_nsm, rgsw1_m) = rgsw_1_eval.split_at_row(d_a * 2); // RGSW x RGSW izip!( - rgsw0_nsm - .iter() - .take(d_rgsw) - .chain(rgsw0_m.iter().take(d_rgsw)), - rgsw0_nsm - .iter() - .skip(d_rgsw) - .chain(rgsw0_m.iter().skip(d_rgsw)), + rgsw0_nsm.iter().take(d_a).chain(rgsw0_m.iter().take(d_b)), + rgsw0_nsm.iter().skip(d_a).chain(rgsw0_m.iter().skip(d_b)), rlwe_dash_space_nsm_parta .iter_mut() .chain(rlwe_dash_space_m_parta.iter_mut()), @@ -857,38 +898,40 @@ pub(crate) fn rgsw_by_rgsw_inplace< ) .for_each(|(rlwe_a, rlwe_b, rlwe_out_a, rlwe_out_b)| { // Part A - decompose_r(rlwe_a.as_ref(), decomp_r_space.as_mut(), decomposer); + decompose_r(rlwe_a.as_ref(), decomp_r_space.as_mut(), decomposer_a); decomp_r_space .iter_mut() + .take(d_a) .for_each(|ri| ntt_op.forward(ri.as_mut())); routine( rlwe_out_a.as_mut(), - decomp_r_space, - &rgsw1_nsm[..d_rgsw], + &decomp_r_space[..d_a], + &rgsw1_nsm[..d_a], mod_op, ); routine( rlwe_out_b.as_mut(), - decomp_r_space, - &rgsw1_nsm[d_rgsw..], + &decomp_r_space[..d_a], + &rgsw1_nsm[d_a..], mod_op, ); // Part B - decompose_r(rlwe_b.as_ref(), decomp_r_space.as_mut(), decomposer); + decompose_r(rlwe_b.as_ref(), decomp_r_space.as_mut(), decomposer_b); decomp_r_space .iter_mut() + .take(d_b) .for_each(|ri| ntt_op.forward(ri.as_mut())); routine( rlwe_out_a.as_mut(), - decomp_r_space, - &rgsw1_m[..d_rgsw], + &decomp_r_space[..d_b], + &rgsw1_m[..d_b], mod_op, ); routine( rlwe_out_b.as_mut(), - decomp_r_space, - &rgsw1_m[d_rgsw..], + &decomp_r_space[..d_b], + &rgsw1_m[d_b..], mod_op, ); }); @@ -921,7 +964,8 @@ pub(crate) fn secret_key_encrypt_rgsw< >( out_rgsw: &mut Mmut, m: &[Mmut::MatElement], - gadget_vector: &[Mmut::MatElement], + gadget_a: &[Mmut::MatElement], + gadget_b: &[Mmut::MatElement], s: &[S], mod_op: &ModOp, ntt_op: &NttOp, @@ -932,14 +976,15 @@ pub(crate) fn secret_key_encrypt_rgsw< RowMut + RowEntity + TryConvertFrom<[S], Parameters = Mmut::MatElement> + Debug, Mmut::MatElement: Copy + Debug, { - let d = gadget_vector.len(); + let d_a = gadget_a.len(); + let d_b = gadget_b.len(); let q = mod_op.modulus(); let ring_size = s.len(); - assert!(out_rgsw.dimension() == (d * 3, ring_size)); + assert!(out_rgsw.dimension() == (d_a * 2 + d_b, ring_size)); assert!(m.as_ref().len() == ring_size); // RLWE(-sm), RLWE(m) - let (rlwe_dash_nsm, b_rlwe_dash_m) = out_rgsw.split_at_row_mut(d * 2); + let (rlwe_dash_nsm, b_rlwe_dash_m) = out_rgsw.split_at_row_mut(d_a * 2); let mut s_eval = Mmut::R::try_convert_from(s, &q); ntt_op.forward(s_eval.as_mut()); @@ -947,11 +992,11 @@ pub(crate) fn secret_key_encrypt_rgsw< let mut scratch_space = Mmut::R::zeros(ring_size); // RLWE'(-sm) - let (a_rlwe_dash_nsm, b_rlwe_dash_nsm) = rlwe_dash_nsm.split_at_mut(d); + let (a_rlwe_dash_nsm, b_rlwe_dash_nsm) = rlwe_dash_nsm.split_at_mut(d_a); izip!( a_rlwe_dash_nsm.iter_mut(), b_rlwe_dash_nsm.iter_mut(), - gadget_vector.iter() + gadget_a.iter() ) .for_each(|(ai, bi, beta_i)| { // Sample a_i @@ -975,7 +1020,7 @@ pub(crate) fn secret_key_encrypt_rgsw< // RLWE(m) let mut a_rlwe_dash_m = { // polynomials of part A of RLWE'(m) are sampled from seed - let mut a = Mmut::zeros(d, ring_size); + let mut a = Mmut::zeros(d_b, ring_size); a.iter_rows_mut() .for_each(|ai| RandomUniformDist::random_fill(p_rng, &q, ai.as_mut())); a @@ -984,7 +1029,7 @@ pub(crate) fn secret_key_encrypt_rgsw< izip!( a_rlwe_dash_m.iter_rows_mut(), b_rlwe_dash_m.iter_mut(), - gadget_vector.iter() + gadget_b.iter() ) .for_each(|(ai, bi, beta_i)| { // ai * s @@ -1015,7 +1060,8 @@ pub(crate) fn public_key_encrypt_rgsw< out_rgsw: &mut Mmut, m: &[M::MatElement], public_key: &M, - gadget_vector: &[Mmut::MatElement], + gadget_a: &[Mmut::MatElement], + gadget_b: &[Mmut::MatElement], mod_op: &ModOp, ntt_op: &NttOp, rng: &mut R, @@ -1024,9 +1070,10 @@ pub(crate) fn public_key_encrypt_rgsw< Mmut::MatElement: Copy, { let ring_size = public_key.dimension().1; - let d = gadget_vector.len(); + let d_a = gadget_a.len(); + let d_b = gadget_b.len(); assert!(public_key.dimension().0 == 2); - assert!(out_rgsw.dimension() == (d * 4, ring_size)); + assert!(out_rgsw.dimension() == (d_a * 2 + d_b * 2, ring_size)); let mut pk_eval = Mmut::zeros(2, ring_size); izip!(pk_eval.iter_rows_mut(), public_key.iter_rows()).for_each(|(to_i, from_i)| { @@ -1039,14 +1086,14 @@ pub(crate) fn public_key_encrypt_rgsw< let q = mod_op.modulus(); // RGSW(m) = RLWE'(-sm), RLWE(m) - let (rlwe_dash_nsm, rlwe_dash_m) = out_rgsw.split_at_row_mut(2 * d); + let (rlwe_dash_nsm, rlwe_dash_m) = out_rgsw.split_at_row_mut(d_a * 2); // RLWE(-sm) - let (rlwe_dash_nsm_parta, rlwe_dash_nsm_partb) = rlwe_dash_nsm.split_at_mut(d); + let (rlwe_dash_nsm_parta, rlwe_dash_nsm_partb) = rlwe_dash_nsm.split_at_mut(d_a); izip!( rlwe_dash_nsm_parta.iter_mut(), rlwe_dash_nsm_partb.iter_mut(), - gadget_vector.iter() + gadget_a.iter() ) .for_each(|(ai, bi, beta_i)| { // sample ephemeral secret u_i @@ -1081,11 +1128,11 @@ pub(crate) fn public_key_encrypt_rgsw< }); // RLWE(m) - let (rlwe_dash_m_parta, rlwe_dash_m_partb) = rlwe_dash_m.split_at_mut(d); + let (rlwe_dash_m_parta, rlwe_dash_m_partb) = rlwe_dash_m.split_at_mut(d_b); izip!( rlwe_dash_m_parta.iter_mut(), rlwe_dash_m_partb.iter_mut(), - gadget_vector.iter() + gadget_b.iter() ) .for_each(|(ai, bi, beta_i)| { // sample ephemeral secret u_i @@ -1502,7 +1549,7 @@ pub(crate) mod tests { use crate::{ backend::{ModInit, ModularOpsU64, VectorOps}, - decomposer::{Decomposer, DefaultDecomposer}, + decomposer::{Decomposer, DefaultDecomposer, RlweDecomposer}, ntt::{self, Ntt, NttBackendU64, NttInit}, random::{DefaultSecureRng, NewWithSeed, RandomUniformDist}, rgsw::{ @@ -1520,6 +1567,170 @@ pub(crate) mod tests { RlweSecret, }; + pub(crate) fn _sk_encrypt_rlwe( + m: &[u64], + s: &[i32], + ntt_op: &NttBackendU64, + mod_op: &ModularOpsU64, + ) -> RlweCiphertext>, DefaultSecureRng> { + let ring_size = m.len(); + let q = mod_op.modulus(); + assert!(s.len() == ring_size); + + let mut rng = DefaultSecureRng::new(); + let mut rlwe_seed = [0u8; 32]; + rng.fill_bytes(&mut rlwe_seed); + let mut seeded_rlwe_ct = + SeededRlweCiphertext::<_, [u8; 32]>::empty(ring_size as usize, rlwe_seed, q); + let mut p_rng = DefaultSecureRng::new_seeded(rlwe_seed); + secret_key_encrypt_rlwe( + &m, + &mut seeded_rlwe_ct.data, + s, + mod_op, + ntt_op, + &mut p_rng, + &mut rng, + ); + + RlweCiphertext::>, DefaultSecureRng>::from(&seeded_rlwe_ct) + } + + // Encrypt m as RGSW ciphertext RGSW(m) using supplied public key + pub(crate) fn _pk_encrypt_rgsw( + m: &[u64], + public_key: &RlwePublicKey>, DefaultSecureRng>, + decomposer: &(DefaultDecomposer, DefaultDecomposer), + mod_op: &ModularOpsU64, + ntt_op: &NttBackendU64, + ) -> RgswCiphertext>> { + let (_, ring_size) = Matrix::dimension(&public_key.data); + let gadget_vector_a = decomposer.a().gadget_vector(); + let gadget_vector_b = decomposer.b().gadget_vector(); + + let mut rng = DefaultSecureRng::new(); + + assert!(m.len() == ring_size); + + // public key encrypt RGSW(m1) + let mut rgsw_ct = RgswCiphertext::empty(ring_size, decomposer, mod_op.modulus()); + public_key_encrypt_rgsw( + &mut rgsw_ct.data, + m, + &public_key.data, + &gadget_vector_a, + &gadget_vector_b, + mod_op, + ntt_op, + &mut rng, + ); + + rgsw_ct + } + + /// Encrypts m as RGSW ciphertext RGSW(m) using supplied secret key. Returns + /// unseeded RGSW ciphertext in coefficient domain + pub(crate) fn _sk_encrypt_rgsw( + m: &[u64], + s: &[i32], + decomposer: &(DefaultDecomposer, DefaultDecomposer), + mod_op: &ModularOpsU64, + ntt_op: &NttBackendU64, + ) -> SeededRgswCiphertext>, [u8; 32]> { + let ring_size = s.len(); + assert!(m.len() == s.len()); + + let q = mod_op.modulus(); + + let gadget_vector_a = decomposer.a().gadget_vector(); + let gadget_vector_b = decomposer.b().gadget_vector(); + + let mut rng = DefaultSecureRng::new(); + let mut rgsw_seed = [0u8; 32]; + rng.fill_bytes(&mut rgsw_seed); + let mut seeded_rgsw_ct = SeededRgswCiphertext::>, [u8; 32]>::empty( + ring_size as usize, + decomposer, + rgsw_seed, + q, + ); + let mut p_rng = DefaultSecureRng::new_seeded(rgsw_seed); + secret_key_encrypt_rgsw( + &mut seeded_rgsw_ct.data, + m, + &gadget_vector_a, + &gadget_vector_b, + s, + mod_op, + ntt_op, + &mut p_rng, + &mut rng, + ); + seeded_rgsw_ct + } + + /// Prints noise in RGSW ciphertext RGSW(m). + /// + /// - rgsw_ct: RGSW ciphertext in coefficient domain + pub(crate) fn _measure_noise_rgsw( + rgsw_ct: &[Vec], + m: &[u64], + s: &[i32], + decomposer: &(DefaultDecomposer, DefaultDecomposer), + q: u64, + ) { + let gadget_vector_a = decomposer.a().gadget_vector(); + let gadget_vector_b = decomposer.b().gadget_vector(); + let d_a = gadget_vector_a.len(); + let d_b = gadget_vector_b.len(); + let ring_size = s.len(); + assert!(Matrix::dimension(&rgsw_ct) == (d_a * 2 + d_b * 2, ring_size)); + assert!(m.len() == ring_size); + + let mod_op = ModularOpsU64::new(q); + let ntt_op = NttBackendU64::new(q, ring_size); + + let mul_mod = |a: &u64, b: &u64| ((*a as u128 * *b as u128) % q as u128) as u64; + let s_poly = Vec::::try_convert_from(s, &q); + let mut neg_s = s_poly.clone(); + mod_op.elwise_neg_mut(neg_s.as_mut()); + let neg_sm0m1 = negacyclic_mul(&neg_s, &m, mul_mod, q); + + // RLWE(\beta^j -s * m) + for j in 0..d_a { + let ideal_m = { + // RLWE(\beta^j -s * m) + let mut beta_neg_sm0m1 = vec![0u64; ring_size as usize]; + mod_op.elwise_scalar_mul(beta_neg_sm0m1.as_mut(), &neg_sm0m1, &gadget_vector_a[j]); + beta_neg_sm0m1 + }; + + let mut rlwe = vec![vec![0u64; ring_size as usize]; 2]; + rlwe[0].copy_from_slice(rgsw_ct.get_row_slice(j)); + rlwe[1].copy_from_slice(rgsw_ct.get_row_slice(d_a + j)); + let noise = measure_noise(&rlwe, &ideal_m, &ntt_op, &mod_op, s); + + println!(r"Noise RLWE(\beta^{j} -sm0m1): {noise}"); + } + + // RLWE(\beta^j m) + for j in 0..d_b { + let ideal_m = { + // RLWE(\beta^j m) + let mut beta_m0m1 = vec![0u64; ring_size as usize]; + mod_op.elwise_scalar_mul(beta_m0m1.as_mut(), &m, &gadget_vector_b[j]); + beta_m0m1 + }; + + let mut rlwe = vec![vec![0u64; ring_size as usize]; 2]; + rlwe[0].copy_from_slice(rgsw_ct.get_row_slice(d_a * 2 + j)); + rlwe[1].copy_from_slice(rgsw_ct.get_row_slice(d_a * 2 + d_b + j)); + let noise = measure_noise(&rlwe, &ideal_m, &ntt_op, &mod_op, s); + + println!(r"Noise RLWE(\beta^{j} m0m1): {noise}"); + } + } + #[test] fn rlwe_encrypt_decryption() { let logq = 50; @@ -1540,26 +1751,11 @@ pub(crate) mod tests { let mod_op = ModularOpsU64::new(q); // encrypt m0 - let mut rlwe_seed = [0u8; 32]; - rng.fill_bytes(&mut rlwe_seed); - let mut seeded_rlwe_in_ct = - SeededRlweCiphertext::<_, [u8; 32]>::empty(ring_size as usize, rlwe_seed, q); - let mut p_rng = DefaultSecureRng::new_with_seed(rlwe_seed); let encoded_m = m0 .iter() .map(|v| (((*v as f64) * q as f64) / (p as f64)).round() as u64) .collect_vec(); - secret_key_encrypt_rlwe( - &encoded_m, - &mut seeded_rlwe_in_ct.data, - s.values(), - &mod_op, - &ntt_op, - &mut p_rng, - &mut rng, - ); - let rlwe_in_ct = - RlweCiphertext::>, DefaultSecureRng>::from(&seeded_rlwe_in_ct); + let rlwe_in_ct = _sk_encrypt_rlwe(&encoded_m, s.values(), &ntt_op, &mod_op); let mut encoded_m_back = vec![0u64; ring_size as usize]; decrypt_rlwe( @@ -1575,8 +1771,8 @@ pub(crate) mod tests { .collect_vec(); assert_eq!(m0, m_back); - let noise = measure_noise(&rlwe_in_ct, &encoded_m, &ntt_op, &mod_op, s.values()); - println!("Noise: {noise}"); + // let noise = measure_noise(&rlwe_in_ct, &encoded_m, &ntt_op, &mod_op, + // s.values()); println!("Noise: {noise}"); } #[test] @@ -1585,9 +1781,7 @@ pub(crate) mod tests { let logp = 2; let ring_size = 1 << 9; let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap(); - let p = 1u64 << logp; - let d_rgsw = 10; - let logb = 5; + let p: u64 = 1u64 << logp; let mut rng = DefaultSecureRng::new_seeded([0u8; 32]); @@ -1600,8 +1794,28 @@ pub(crate) mod tests { let ntt_op = NttBackendU64::new(q, ring_size as usize); let mod_op = ModularOpsU64::new(q); - let decomposer = DefaultDecomposer::new(q, logb, d_rgsw); - let gadget_vector = decomposer.gadget_vector(); + let d_rgsw = 10; + let logb = 5; + let decomposer = ( + DefaultDecomposer::new(q, logb, d_rgsw), + DefaultDecomposer::new(q, logb, d_rgsw), + ); + + // create public key + let mut pk_seed = [0u8; 32]; + rng.fill_bytes(&mut pk_seed); + let mut pk_prng = DefaultSecureRng::new_seeded(pk_seed); + let mut seeded_pk = + SeededRlwePublicKey::, _>::empty(ring_size as usize, pk_seed, q); + gen_rlwe_public_key( + &mut seeded_pk.data, + s.values(), + &ntt_op, + &mod_op, + &mut pk_prng, + &mut rng, + ); + let pk = RlwePublicKey::>, DefaultSecureRng>::from(&seeded_pk); // Encrypt m1 as RGSW(m1) let rgsw_ct = { @@ -1610,62 +1824,34 @@ pub(crate) mod tests { if true { // Encryption m1 as RGSW(m1) using secret key - _sk_encrypt_rgsw(&m1, s.values(), &gadget_vector, &mod_op, &ntt_op) + let seeded_rgsw_ct = + _sk_encrypt_rgsw(&m1, s.values(), &decomposer, &mod_op, &ntt_op); + RgswCiphertextEvaluationDomain::>, DefaultSecureRng, NttBackendU64>::from(&seeded_rgsw_ct) } else { - // Encrypt m1 as RGSW(m1) as public key - - // first create public key - let mut pk_seed = [0u8; 32]; - rng.fill_bytes(&mut pk_seed); - let mut pk_prng = DefaultSecureRng::new_seeded(pk_seed); - let mut seeded_pk = - SeededRlwePublicKey::, _>::empty(ring_size as usize, pk_seed, q); - gen_rlwe_public_key( - &mut seeded_pk.data, - s.values(), - &ntt_op, - &mod_op, - &mut pk_prng, - &mut rng, - ); - let pk = RlwePublicKey::>, DefaultSecureRng>::from(&seeded_pk); - - let rgsw_ct = _pk_encrypt_rgsw(&m1, &pk, &gadget_vector, &mod_op, &ntt_op); - RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from( - &RgswCiphertext { - data: rgsw_ct.data, - modulus: q, - }, - ) + // Encrypt m1 as RGSW(m1) using public key + let rgsw_ct = _pk_encrypt_rgsw(&m1, &pk, &decomposer, &mod_op, &ntt_op); + RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from(&rgsw_ct) } }; // Encrypt m0 as RLWE(m0) let mut rlwe_in_ct = { - let mut rlwe_seed = [0u8; 32]; - rng.fill_bytes(&mut rlwe_seed); - let mut seeded_rlwe_in_ct = - SeededRlweCiphertext::<_, [u8; 32]>::empty(ring_size as usize, rlwe_seed, q); - let mut p_rng = DefaultSecureRng::new_seeded(rlwe_seed); let encoded_m = m0 .iter() .map(|v| (((*v as f64) * q as f64) / (p as f64)).round() as u64) .collect_vec(); - secret_key_encrypt_rlwe( - &encoded_m, - &mut seeded_rlwe_in_ct.data, - s.values(), - &mod_op, - &ntt_op, - &mut p_rng, - &mut rng, - ); - RlweCiphertext::>, DefaultSecureRng>::from(&seeded_rlwe_in_ct) + _sk_encrypt_rlwe(&encoded_m, s.values(), &ntt_op, &mod_op) }; // RLWE(m0m1) = RLWE(m0) x RGSW(m1) - let mut scratch_space = vec![vec![0u64; ring_size as usize]; d_rgsw + 2]; + let mut scratch_space = vec![ + vec![0u64; ring_size as usize]; + std::cmp::max( + decomposer.a().decomposition_count(), + decomposer.b().decomposition_count() + ) + 2 + ]; rlwe_by_rgsw( &mut rlwe_in_ct, &rgsw_ct.data, @@ -1692,16 +1878,16 @@ pub(crate) mod tests { let mul_mod = |v0: &u64, v1: &u64| (v0 * v1) % p; let m0m1 = negacyclic_mul(&m0, &m1, mul_mod, p); - { - // measure noise - let encoded_m_ideal = m0m1 - .iter() - .map(|v| (((*v as f64) * q as f64) / (p as f64)).round() as u64) - .collect_vec(); + // { + // // measure noise + // let encoded_m_ideal = m0m1 + // .iter() + // .map(|v| (((*v as f64) * q as f64) / (p as f64)).round() as u64) + // .collect_vec(); - let noise = measure_noise(&rlwe_in_ct, &encoded_m_ideal, &ntt_op, &mod_op, s.values()); - println!("Noise RLWE(m0m1)(= RLWE(m0)xRGSW(m1)) : {noise}"); - } + // let noise = measure_noise(&rlwe_in_ct, &encoded_m_ideal, &ntt_op, + // &mod_op, s.values()); println!("Noise RLWE(m0m1)(= + // RLWE(m0)xRGSW(m1)) : {noise}"); } assert!( m0m1 == m0m1_back, @@ -1711,426 +1897,15 @@ pub(crate) mod tests { ); } - pub(crate) fn _secret_encrypt_rlwe( - m: &[u64], - s: &[i32], - ntt_op: &NttBackendU64, - mod_op: &ModularOpsU64, - ) -> RlweCiphertext>, DefaultSecureRng> { - let ring_size = m.len(); - let q = mod_op.modulus(); - assert!(s.len() == ring_size); - - let mut rng = DefaultSecureRng::new(); - let mut rlwe_seed = [0u8; 32]; - rng.fill_bytes(&mut rlwe_seed); - let mut seeded_rlwe_ct = - SeededRlweCiphertext::<_, [u8; 32]>::empty(ring_size as usize, rlwe_seed, q); - let mut p_rng = DefaultSecureRng::new_seeded(rlwe_seed); - secret_key_encrypt_rlwe( - &m, - &mut seeded_rlwe_ct.data, - s, - mod_op, - ntt_op, - &mut p_rng, - &mut rng, - ); - - RlweCiphertext::>, DefaultSecureRng>::from(&seeded_rlwe_ct) - } - - #[test] - fn rlwe_by_rgsw_noise_growth() { - let logq = 28; - let ring_size = 1 << 10; - let q = generate_prime(logq, ring_size * 2, 1u64 << logq).unwrap(); - let d_rgsw = 2; - let logb = 7; - - let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize); - - let ntt_op = NttBackendU64::new(q, ring_size as usize); - let mod_op = ModularOpsU64::new(q); - let decomposer = DefaultDecomposer::new(q, logb, d_rgsw); - let gadget_vector = decomposer.gadget_vector(); - - let mul_mod = |v0: &u64, v1: &u64| ((*v0 as u128 * *v1 as u128) % (q as u128)) as u64; - - let mut carry_m = vec![0u64; ring_size as usize]; - carry_m[thread_rng().gen_range(0..ring_size) as usize] = 1; - let mut rlwe = _secret_encrypt_rlwe(&carry_m, s.values(), &ntt_op, &mod_op); - - let mut scratch_matrix_dplus2_ring = vec![vec![0u64; ring_size as usize]; d_rgsw + 2]; - for i in 0..1000usize { - // Encrypt monomial as RGSW - let mut m = vec![0u64; ring_size as usize]; - m[thread_rng().gen_range(0..ring_size) as usize] = if i & 1 == 1 { 1 } else { q - 1 }; - let rgsw_ct = _sk_encrypt_rgsw(&m, s.values(), &gadget_vector, &mod_op, &ntt_op); - - // RLWE(carry_m * m) = RLWE(carry_m) x RGSW(m) - rlwe_by_rgsw( - &mut rlwe, - &rgsw_ct.data, - &mut scratch_matrix_dplus2_ring, - &decomposer, - &ntt_op, - &mod_op, - ); - - carry_m = negacyclic_mul(&carry_m, &m, mul_mod, q); - let noise = measure_noise(&rlwe, &carry_m, &ntt_op, &mod_op, s.values()); - - println!("Noise RLWE(carry_m) after {i}^th iteration: {noise}"); - } - } - - // Encrypt m as RGSW ciphertext RGSW(m) using supplied public key - pub(crate) fn _pk_encrypt_rgsw( - m: &[u64], - public_key: &RlwePublicKey>, DefaultSecureRng>, - gadget_vector: &[u64], - mod_op: &ModularOpsU64, - ntt_op: &NttBackendU64, - ) -> RgswCiphertext>> { - let (_, ring_size) = Matrix::dimension(&public_key.data); - let d_rgsw = gadget_vector.len(); - - let mut rng = DefaultSecureRng::new(); - - assert!(m.len() == ring_size); - - // public key encrypt RGSW(m1) - let mut rgsw_ct = vec![vec![0u64; ring_size]; d_rgsw * 4]; - public_key_encrypt_rgsw( - &mut rgsw_ct, - m, - &public_key.data, - gadget_vector, - mod_op, - ntt_op, - &mut rng, - ); - - RgswCiphertext { - data: rgsw_ct, - modulus: mod_op.modulus(), - } - } - - /// Encrypts m as RGSW ciphertext RGSW(m) using supplied secret key. Returns - /// unseeded RGSW ciphertext in coefficient domain - pub(crate) fn _sk_encrypt_rgsw( - m: &[u64], - s: &[i32], - gadget_vector: &[u64], - mod_op: &ModularOpsU64, - ntt_op: &NttBackendU64, - ) -> RgswCiphertextEvaluationDomain>, DefaultSecureRng, NttBackendU64> { - let ring_size = s.len(); - assert!(m.len() == s.len()); - - let d_rgsw = gadget_vector.len(); - let q = mod_op.modulus(); - - let mut rng = DefaultSecureRng::new(); - let mut rgsw_seed = [0u8; 32]; - rng.fill_bytes(&mut rgsw_seed); - let mut seeded_rgsw_ct = SeededRgswCiphertext::>, [u8; 32]>::empty( - ring_size as usize, - d_rgsw, - rgsw_seed, - q, - ); - let mut p_rng = DefaultSecureRng::new_seeded(rgsw_seed); - secret_key_encrypt_rgsw( - &mut seeded_rgsw_ct.data, - m, - &gadget_vector, - s, - mod_op, - ntt_op, - &mut p_rng, - &mut rng, - ); - - RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from(&seeded_rgsw_ct) - } - - /// Prints noise in RGSW ciphertext RGSW(m). - /// - /// - rgsw_ct: RGSW ciphertext in coefficient domain - pub(crate) fn _measure_noise_rgsw( - rgsw_ct: &[Vec], - m: &[u64], - s: &[i32], - gadget_vector: &[u64], - q: u64, - ) { - let d_rgsw = gadget_vector.len(); - let ring_size = s.len(); - assert!(Matrix::dimension(&rgsw_ct) == (d_rgsw * 2 * 2, ring_size)); - assert!(m.len() == ring_size); - - let mod_op = ModularOpsU64::new(q); - let ntt_op = NttBackendU64::new(q, ring_size); - - let mul_mod = |a: &u64, b: &u64| ((*a as u128 * *b as u128) % q as u128) as u64; - let s_poly = Vec::::try_convert_from(s, &q); - let mut neg_s = s_poly.clone(); - mod_op.elwise_neg_mut(neg_s.as_mut()); - let neg_sm0m1 = negacyclic_mul(&neg_s, &m, mul_mod, q); - for i in 0..2 { - for j in 0..d_rgsw { - let ideal_m = { - if i == 0 { - // RLWE(\beta^j -s * m) - let mut beta_neg_sm0m1 = vec![0u64; ring_size as usize]; - mod_op.elwise_scalar_mul( - beta_neg_sm0m1.as_mut(), - &neg_sm0m1, - &gadget_vector[j], - ); - beta_neg_sm0m1 - } else { - // RLWE(\beta^j m) - let mut beta_m0m1 = vec![0u64; ring_size as usize]; - mod_op.elwise_scalar_mul(beta_m0m1.as_mut(), &m, &gadget_vector[j]); - beta_m0m1 - } - }; - - let mut rlwe = vec![vec![0u64; ring_size as usize]; 2]; - rlwe[0].copy_from_slice(rgsw_ct.get_row_slice((i * 2 * d_rgsw) + j)); - rlwe[1].copy_from_slice(rgsw_ct.get_row_slice((i * 2 * d_rgsw) + d_rgsw + j)); - let noise = measure_noise(&rlwe, &ideal_m, &ntt_op, &mod_op, s); - - if i == 0 { - println!(r"Noise RLWE(\beta^{j} -sm0m1): {noise}"); - } else { - println!(r"Noise RLWE(\beta^{j} m0m1): {noise}"); - } - } - // m0m1 - } - } - - #[test] - fn pk_rgsw_by_rgsw() { - let logq = 60; - let logp = 2; - let ring_size = 1 << 11; - let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap(); - let p = 1u64 << logp; - let d_rgsw = 3; - let logb = 15; - - let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize); - - let mut rng = DefaultSecureRng::new(); - let ntt_op = NttBackendU64::new(q, ring_size as usize); - let mod_op = ModularOpsU64::new(q); - let decomposer = DefaultDecomposer::new(q, logb, d_rgsw); - let gadget_vector = decomposer.gadget_vector(); - - let mul_mod = |a: &u64, b: &u64| ((*a as u128 * *b as u128) % q as u128) as u64; - - // Public Key - let public_key = { - let mut pk_seed = [0u8; 32]; - rng.fill_bytes(&mut pk_seed); - let mut pk_prng = DefaultSecureRng::new_seeded(pk_seed); - let mut seeded_pk = - SeededRlwePublicKey::, _>::empty(ring_size as usize, pk_seed, q); - gen_rlwe_public_key( - &mut seeded_pk.data, - s.values(), - &ntt_op, - &mod_op, - &mut pk_prng, - &mut rng, - ); - RlwePublicKey::>, DefaultSecureRng>::from(&seeded_pk) - }; - - let mut carry_m = vec![0u64; ring_size as usize]; - carry_m[thread_rng().gen_range(0..ring_size) as usize] = 1; - - // RGSW(carry_m) - let mut rgsw_carrym = - _pk_encrypt_rgsw(&carry_m, &public_key, &gadget_vector, &mod_op, &ntt_op); - // let mut rgsw_carrym = { - // let mut rgsw_eval = - // _sk_encrypt_rgsw(&carry_m, s.values(), &gadget_vector, &mod_op, - // &ntt_op); rgsw_eval - // .data - // .iter_mut() - // .for_each(|ri| ntt_op.backward(ri.as_mut())); - // rgsw_eval.data - // }; - - println!("########### Noise RGSW(carrym) at start ###########"); - _measure_noise_rgsw(&rgsw_carrym.data, &carry_m, s.values(), &gadget_vector, q); - - let mut scratch_matrix_d_plus_rgsw_by_ring = - vec![vec![0u64; ring_size as usize]; d_rgsw + (d_rgsw * 4)]; - - for i in 0..1 { - let mut m = vec![0u64; ring_size as usize]; - m[thread_rng().gen_range(0..ring_size) as usize] = q - 1; - let rgsw_m = { - RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from( - &_pk_encrypt_rgsw(&m, &public_key, &gadget_vector, &mod_op, &ntt_op), - ) - }; - - rgsw_by_rgsw_inplace( - &mut rgsw_carrym.data, - &rgsw_m.data, - &decomposer, - &mut scratch_matrix_d_plus_rgsw_by_ring, - &ntt_op, - &mod_op, - ); - - // measure noise - carry_m = negacyclic_mul(&carry_m, &m, mul_mod, q); - println!("########### Noise RGSW(carrym) in {i}^th loop ###########"); - _measure_noise_rgsw(&rgsw_carrym.data, &carry_m, s.values(), &gadget_vector, q); - } - - { - // RLWE(m) x RGSW(carry_m) - let mut m = vec![0u64; ring_size as usize]; - RandomUniformDist::random_fill(&mut rng, &q, m.as_mut_slice()); - let mut rlwe_ct = RlweCiphertext::<_, DefaultSecureRng> { - data: vec![vec![0u64; ring_size as usize]; 2], - is_trivial: false, - _phatom: PhantomData, - }; - let mut scratch_matrix_dplus2_ring = vec![vec![0u64; ring_size as usize]; d_rgsw + 2]; - public_key_encrypt_rlwe( - &mut rlwe_ct, - &public_key.data, - &m, - &mod_op, - &ntt_op, - &mut rng, - ); - rlwe_by_rgsw( - &mut rlwe_ct, - &RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from( - &rgsw_carrym, - ) - .data, - &mut scratch_matrix_dplus2_ring, - &decomposer, - &ntt_op, - &mod_op, - ); - let m_expected = negacyclic_mul(&carry_m, &m, mul_mod, q); - let noise = measure_noise(&rlwe_ct, &m_expected, &ntt_op, &mod_op, s.values()); - println!( - "RLWE(m) x RGSW(carry_m): - {noise}" - ); - } - } - - #[test] - fn sk_rgsw_by_rgsw() { - let logq = 60; - let logp = 2; - let ring_size = 1 << 11; - let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap(); - let p = 1u64 << logp; - let d_rgsw = 5; - let logb = 12; - - let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize); - - let mut rng = DefaultSecureRng::new(); - let ntt_op = NttBackendU64::new(q, ring_size as usize); - let mod_op = ModularOpsU64::new(q); - let decomposer = DefaultDecomposer::new(q, logb, d_rgsw); - let gadget_vector = decomposer.gadget_vector(); - let mul_mod = |a: &u64, b: &u64| ((*a as u128 * *b as u128) % q as u128) as u64; - - let mut carry_m = vec![0u64; ring_size as usize]; - carry_m[thread_rng().gen_range(0..ring_size) as usize] = 1; - - // RGSW(carry_m) - let mut rgsw_carrym = { - let mut rgsw_eval = - _sk_encrypt_rgsw(&carry_m, s.values(), &gadget_vector, &mod_op, &ntt_op); - rgsw_eval - .data - .iter_mut() - .for_each(|ri| ntt_op.backward(ri.as_mut())); - rgsw_eval.data - }; - println!("########### Noise RGSW(carrym) at start ###########"); - _measure_noise_rgsw(&rgsw_carrym, &carry_m, s.values(), &gadget_vector, q); - - let mut scratch_matrix_d_plus_rgsw_by_ring = - vec![vec![0u64; ring_size as usize]; d_rgsw + (d_rgsw * 4)]; - - for i in 0..1 { - let mut m = vec![0u64; ring_size as usize]; - m[thread_rng().gen_range(0..ring_size) as usize] = if (i & 1) == 1 { q - 1 } else { 1 }; - let rgsw_m = _sk_encrypt_rgsw(&m, s.values(), &gadget_vector, &mod_op, &ntt_op); - rgsw_by_rgsw_inplace( - &mut rgsw_carrym, - &rgsw_m.data, - &decomposer, - &mut scratch_matrix_d_plus_rgsw_by_ring, - &ntt_op, - &mod_op, - ); - - // measure noise - carry_m = negacyclic_mul(&carry_m, &m, mul_mod, q); - println!("########### Noise RGSW(carrym) in {i}^th loop ###########"); - _measure_noise_rgsw(&rgsw_carrym, &carry_m, s.values(), &gadget_vector, q); - } - { - // RLWE(m) x RGSW(carry_m) - let mut m = vec![0u64; ring_size as usize]; - RandomUniformDist::random_fill(&mut rng, &q, m.as_mut_slice()); - let mut rlwe_ct = _secret_encrypt_rlwe(&m, s.values(), &ntt_op, &mod_op); - let mut scratch_matrix_dplus2_ring = vec![vec![0u64; ring_size as usize]; d_rgsw + 2]; - - // send rgsw to evaluation domain - rgsw_carrym - .iter_mut() - .for_each(|ri| ntt_op.forward(ri.as_mut_slice())); - - rlwe_by_rgsw( - &mut rlwe_ct, - &rgsw_carrym, - &mut scratch_matrix_dplus2_ring, - &decomposer, - &ntt_op, - &mod_op, - ); - let m_expected = negacyclic_mul(&carry_m, &m, mul_mod, q); - let noise = measure_noise(&rlwe_ct, &m_expected, &ntt_op, &mod_op, s.values()); - println!( - "RLWE(m) x RGSW(carry_m): - {noise}" - ); - } - } - - #[test] - fn galois_auto_works() { - let logq = 50; - let ring_size = 1 << 4; - let q = generate_prime(logq, 2 * ring_size, 1u64 << logq).unwrap(); - let logp = 3; - let p = 1u64 << logp; - let d_rgsw = 10; - let logb = 5; + #[test] + fn galois_auto_works() { + let logq = 50; + let ring_size = 1 << 8; + let q = generate_prime(logq, 2 * ring_size, 1u64 << logq).unwrap(); + let logp = 3; + let p = 1u64 << logp; + let d_rgsw = 10; + let logb = 5; let mut rng = DefaultSecureRng::new(); let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize); @@ -2164,11 +1939,12 @@ pub(crate) mod tests { let auto_k = -5; // Generate galois key to key switch from s^k to s + let decomposer = DefaultDecomposer::new(q, logb, d_rgsw); let mut seed_auto = [0u8; 32]; rng.fill_bytes(&mut seed_auto); - let mut seeded_auto_key = SeededAutoKey::empty(ring_size as usize, d_rgsw, seed_auto, q); + let mut seeded_auto_key = + SeededAutoKey::empty(ring_size as usize, &decomposer, seed_auto, q); let mut p_rng = DefaultSecureRng::new_seeded(seed_auto); - let decomposer = DefaultDecomposer::new(q, logb, d_rgsw); let gadget_vector = decomposer.gadget_vector(); galois_key_gen( &mut seeded_auto_key.data, @@ -2239,4 +2015,100 @@ pub(crate) mod tests { assert_eq!(m_k_back, m_k); } + + #[test] + fn sk_rgsw_by_rgsw() { + let logq = 60; + let logp = 2; + let ring_size = 1 << 11; + let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap(); + let p = 1u64 << logp; + let d_rgsw = 5; + let logb = 12; + + let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize); + + let mut rng = DefaultSecureRng::new(); + let ntt_op = NttBackendU64::new(q, ring_size as usize); + let mod_op = ModularOpsU64::new(q); + let decomposer = ( + DefaultDecomposer::new(q, logb, d_rgsw), + DefaultDecomposer::new(q, logb, d_rgsw), + ); + + let mul_mod = |a: &u64, b: &u64| ((*a as u128 * *b as u128) % q as u128) as u64; + + let mut carry_m = vec![0u64; ring_size as usize]; + carry_m[thread_rng().gen_range(0..ring_size) as usize] = 1; + + // RGSW(carry_m) + let mut rgsw_carrym = { + let seeded_rgsw = _sk_encrypt_rgsw(&carry_m, s.values(), &decomposer, &mod_op, &ntt_op); + let mut rgsw_eval = + RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from( + &seeded_rgsw, + ); + rgsw_eval + .data + .iter_mut() + .for_each(|ri| ntt_op.backward(ri.as_mut())); + rgsw_eval.data + }; + + let mut scratch_matrix = vec![ + vec![0u64; ring_size as usize]; + decomposer.a().decomposition_count() * 2 + + decomposer.b().decomposition_count() * 2 + + std::cmp::max( + decomposer.a().decomposition_count(), + decomposer.b().decomposition_count() + ) + ]; + + // _measure_noise_rgsw(&rgsw_carrym, &carry_m, s.values(), &decomposer, q); + + for i in 0..1 { + let mut m = vec![0u64; ring_size as usize]; + m[thread_rng().gen_range(0..ring_size) as usize] = if (i & 1) == 1 { q - 1 } else { 1 }; + let rgsw_m = RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from( + &_sk_encrypt_rgsw(&m, s.values(), &decomposer, &mod_op, &ntt_op), + ); + rgsw_by_rgsw_inplace( + &mut rgsw_carrym, + &rgsw_m.data, + &decomposer, + &mut scratch_matrix, + &ntt_op, + &mod_op, + ); + + // measure noise + carry_m = negacyclic_mul(&carry_m, &m, mul_mod, q); + println!("########### Noise RGSW(carrym) in {i}^th loop ###########"); + _measure_noise_rgsw(&rgsw_carrym, &carry_m, s.values(), &decomposer, q); + } + { + // RLWE(m) x RGSW(carry_m) + let mut m = vec![0u64; ring_size as usize]; + RandomUniformDist::random_fill(&mut rng, &q, m.as_mut_slice()); + let mut rlwe_ct = _sk_encrypt_rlwe(&m, s.values(), &ntt_op, &mod_op); + + // send rgsw to evaluation domain + rgsw_carrym + .iter_mut() + .for_each(|ri| ntt_op.forward(ri.as_mut_slice())); + + rlwe_by_rgsw( + &mut rlwe_ct, + &rgsw_carrym, + &mut scratch_matrix, + &decomposer, + &ntt_op, + &mod_op, + ); + let m_expected = negacyclic_mul(&carry_m, &m, mul_mod, q); + let noise = measure_noise(&rlwe_ct, &m_expected, &ntt_op, &mod_op, s.values()); + println!("RLWE(m) x RGSW(carry_m): {noise}"); + } + } }