use std::{collections::HashMap, fmt::Debug, marker::PhantomData}; use itertools::Itertools; use num_traits::{FromPrimitive, Num, One, PrimInt, ToPrimitive, Zero}; use crate::{ backend::{ArithmeticOps, ModInit, VectorOps}, decomposer::{gadget_vector, Decomposer, DefaultDecomposer, NumInfo}, lwe::{decrypt_lwe, encrypt_lwe, lwe_key_switch, lwe_ksk_keygen, LweSecret}, ntt::{Ntt, NttInit}, random::{DefaultSecureRng, RandomGaussianDist, RandomUniformDist}, rgsw::{encrypt_rgsw, galois_auto, galois_key_gen, rlwe_by_rgsw, IsTrivial, RlweSecret}, utils::{generate_prime, mod_exponent, TryConvertFrom, WithLocal}, Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, }; trait PbsKey { type M: Matrix; fn rgsw_ct_secret_el(&self, si: usize) -> &Self::M; fn galois_key_for_auto(&self, k: isize) -> &Self::M; fn auto_map_index(&self, k: isize) -> &[usize]; fn auto_map_sign(&self, k: isize) -> &[bool]; } trait Parameters { type Element; type D: Decomposer; fn rlwe_q(&self) -> Self::Element; fn lwe_q(&self) -> Self::Element; fn br_q(&self) -> usize; fn d_rgsw(&self) -> usize; fn d_lwe(&self) -> usize; fn rlwe_n(&self) -> usize; fn lwe_n(&self) -> usize; /// Embedding fator for ring X^{q}+1 inside fn embedding_factor(&self) -> usize; /// generator g fn g(&self) -> isize; fn decomoposer_lwe(&self) -> &Self::D; fn decomoposer_rlwe(&self) -> &Self::D; /// Maps a \in Z^*_{q} to discrete log k, with generator g (i.e. g^k = /// a). Returned vector is of size q that stores dlog of a at `vec[a]`. /// For any a, if k is s.t. a = g^{k}, then k is expressed as k. If k is s.t /// a = -g^{k}, then k is expressed as k=k+q/2 fn g_k_dlog_map(&self) -> &[usize]; } struct ClientKey { sk_rlwe: RlweSecret, sk_lwe: LweSecret, } struct ServerKey { /// Rgsw cts of LWE secret elements rgsw_cts: Vec, /// Galois keys galois_keys: HashMap, /// LWE ksk to key switching LWE ciphertext from RLWE secret to LWE secret lwe_ksk: M, } struct BoolParameters { rlwe_q: El, rlwe_logq: usize, lwe_q: El, lwe_logq: usize, br_q: usize, rlwe_n: usize, lwe_n: usize, d_rgsw: usize, logb_rgsw: usize, d_lwe: usize, logb_lwe: usize, g: usize, w: usize, } struct BoolEvaluator { parameters: BoolParameters, decomposer_rlwe: DefaultDecomposer, decomposer_lwe: DefaultDecomposer, g_k_dlog_map: Vec, rlwe_nttop: Ntt, rlwe_modop: ModOp, lwe_modop: ModOp, embedding_factor: usize, _phantom: PhantomData, } impl BoolEvaluator where NttOp: NttInit + Ntt, ModOp: ModInit + ArithmeticOps + VectorOps, M::MatElement: PrimInt + Debug + NumInfo + FromPrimitive, M: MatrixEntity + MatrixMut, M::R: TryConvertFrom<[i32], Parameters = M::MatElement> + RowEntity, M: TryConvertFrom<[i32], Parameters = M::MatElement>, ::R: RowMut, DefaultSecureRng: RandomGaussianDist<[M::MatElement], Parameters = M::MatElement> + RandomGaussianDist + RandomUniformDist<[M::MatElement], Parameters = M::MatElement>, { fn new(parameters: BoolParameters) -> Self { //TODO(Jay): Run sanity checks for modulus values in parameters let decomposer_rlwe = DefaultDecomposer::new(parameters.rlwe_q, parameters.logb_rgsw, parameters.d_rgsw); let decomposer_lwe = DefaultDecomposer::new(parameters.lwe_q, parameters.logb_lwe, parameters.d_lwe); // generatr dlog map s.t. g^{k} % q = a, for all a \in Z*_{q} let g = parameters.g; let q = parameters.br_q; let mut g_k_dlog_map = vec![0usize; q]; for i in 0..q / 2 { let v = mod_exponent(g as u64, i as u64, q as u64) as usize; // g^i g_k_dlog_map[v] = i; // -(g^i) g_k_dlog_map[q - v] = i + (q / 2); } let embedding_factor = (2 * parameters.rlwe_n) / q; let rlwe_nttop = NttOp::new(parameters.rlwe_q, parameters.rlwe_n); let rlwe_modop = ModInit::new(parameters.rlwe_q); let lwe_modop = ModInit::new(parameters.lwe_q); BoolEvaluator { parameters: parameters, decomposer_lwe, decomposer_rlwe, g_k_dlog_map, embedding_factor, lwe_modop, rlwe_modop, rlwe_nttop, _phantom: PhantomData, } } fn client_key(&self) -> ClientKey { let sk_lwe = LweSecret::random(self.parameters.lwe_n >> 1, self.parameters.lwe_n); let sk_rlwe = RlweSecret::random(self.parameters.rlwe_n >> 1, self.parameters.rlwe_n); ClientKey { sk_rlwe, sk_lwe } } fn server_key(&self, client_key: &ClientKey) -> ServerKey { let sk_rlwe = &client_key.sk_rlwe; let sk_lwe = &client_key.sk_lwe; let d_rgsw_gadget_vec = gadget_vector( self.parameters.rlwe_logq, self.parameters.logb_rgsw, self.parameters.d_rgsw, ); // generate galois key -g, g let mut galois_keys = HashMap::new(); let g = self.parameters.g as isize; for i in [g, -g] { let gk = DefaultSecureRng::with_local_mut(|rng| { let mut ksk_out = M::zeros(self.parameters.d_rgsw * 2, self.parameters.rlwe_n); galois_key_gen( &mut ksk_out, sk_rlwe, i, &d_rgsw_gadget_vec, &self.rlwe_modop, &self.rlwe_nttop, rng, ); ksk_out }); galois_keys.insert(i, gk); } // generate rgsw ciphertexts RGSW(si) where si is i^th LWE secret element let ring_size = self.parameters.rlwe_n; let rlwe_q = self.parameters.rlwe_q; let rgsw_cts = sk_lwe .values() .iter() .map(|si| { // X^{si}; assume |emebedding_factor * si| < N let mut m = M::zeros(1, ring_size); let si = (self.embedding_factor as i32) * si; if si < 0 { // X^{-i} = X^{2N - i} = -X^{N-i} m.set( 0, ring_size - (si.abs() as usize), rlwe_q - M::MatElement::one(), ); } else { // X^{i} m.set(0, (si.abs() as usize), M::MatElement::one()); } self.rlwe_nttop.forward(m.get_row_mut(0)); let rgsw_si = DefaultSecureRng::with_local_mut(|rng| { let mut rgsw_si = M::zeros(self.parameters.d_rgsw * 4, ring_size); encrypt_rgsw( &mut rgsw_si, &m, &d_rgsw_gadget_vec, sk_rlwe, &self.rlwe_modop, &self.rlwe_nttop, rng, ); rgsw_si }); rgsw_si }) .collect_vec(); // LWE KSK from RLWE secret s -> LWE secret z let d_lwe_gadget = gadget_vector( self.parameters.lwe_logq, self.parameters.logb_lwe, self.parameters.d_lwe, ); let mut lwe_ksk = DefaultSecureRng::with_local_mut(|rng| { let mut out = M::zeros(self.parameters.d_lwe * ring_size, self.parameters.lwe_n + 1); lwe_ksk_keygen( &sk_rlwe.values(), &sk_lwe.values(), &mut out, &d_lwe_gadget, &self.lwe_modop, rng, ); out }); ServerKey { rgsw_cts, galois_keys, lwe_ksk, } } pub fn encrypt(&self, m: bool, client_key: &ClientKey) -> M::R { let rlwe_q_by8 = M::MatElement::from_f64((self.parameters.rlwe_q.to_f64().unwrap() / 8.0).round()) .unwrap(); let m = if m { // Q/8 rlwe_q_by8 } else { // -Q/8 self.parameters.rlwe_q - rlwe_q_by8 }; DefaultSecureRng::with_local_mut(|rng| { let mut lwe_out = M::R::zeros(self.parameters.rlwe_n + 1); encrypt_lwe( &mut lwe_out, &m, client_key.sk_rlwe.values(), &self.rlwe_modop, rng, ); lwe_out }) } pub fn decrypt(&self, lwe_ct: &M::R, client_key: &ClientKey) -> bool { let m = decrypt_lwe(lwe_ct, client_key.sk_rlwe.values(), &self.rlwe_modop); let m = { // m + q/8 => {0,q/4 1} let rlwe_q_by8 = M::MatElement::from_f64((self.parameters.rlwe_q.to_f64().unwrap() / 8.0).round()) .unwrap(); (((m + rlwe_q_by8).to_f64().unwrap() * 4.0) / self.parameters.rlwe_q.to_f64().unwrap()) .round() .to_usize() .unwrap() % 4 }; if m == 0 { false } else if m == 1 { true } else { panic!("Incorrect bool decryption. Got m={m} expected m to be 0 or 1") } } } /// LMKCY+ Blind rotation /// /// gk_to_si: [-g^0, -g^1, .., -g^{q/2-1}, g^0, ..., g^{q/2-1}] fn blind_rotation< MT: IsTrivial + MatrixMut, Mmut: MatrixMut + Matrix, D: Decomposer, NttOp: Ntt, ModOp: ArithmeticOps + VectorOps, K: PbsKey, >( trivial_rlwe_test_poly: &mut MT, scratch_matrix_dplus2_ring: &mut Mmut, g: isize, w: usize, q: usize, gk_to_si: &[Vec], decomposer: &D, ntt_op: &NttOp, mod_op: &ModOp, pbs_key: &K, ) where ::R: RowMut, Mmut::MatElement: Copy + Zero, ::R: RowMut, { let q_by_2 = q / 2; // -(g^k) for i in 1..q_by_2 { gk_to_si[q_by_2 + i].iter().for_each(|s_index| { rlwe_by_rgsw( trivial_rlwe_test_poly, pbs_key.rgsw_ct_secret_el(*s_index), scratch_matrix_dplus2_ring, decomposer, ntt_op, mod_op, ); }); galois_auto( trivial_rlwe_test_poly, pbs_key.galois_key_for_auto(g), scratch_matrix_dplus2_ring, pbs_key.auto_map_index(g), pbs_key.auto_map_sign(g), mod_op, ntt_op, decomposer, ); } // -(g^0) gk_to_si[q_by_2].iter().for_each(|s_index| { rlwe_by_rgsw( trivial_rlwe_test_poly, pbs_key.rgsw_ct_secret_el(*s_index), scratch_matrix_dplus2_ring, decomposer, ntt_op, mod_op, ); }); galois_auto( trivial_rlwe_test_poly, pbs_key.galois_key_for_auto(-g), scratch_matrix_dplus2_ring, pbs_key.auto_map_index(-g), pbs_key.auto_map_sign(-g), mod_op, ntt_op, decomposer, ); // +(g^k) for i in 1..q_by_2 { gk_to_si[i].iter().for_each(|s_index| { rlwe_by_rgsw( trivial_rlwe_test_poly, pbs_key.rgsw_ct_secret_el(*s_index), scratch_matrix_dplus2_ring, decomposer, ntt_op, mod_op, ); }); galois_auto( trivial_rlwe_test_poly, pbs_key.galois_key_for_auto(g), scratch_matrix_dplus2_ring, pbs_key.auto_map_index(g), pbs_key.auto_map_sign(g), mod_op, ntt_op, decomposer, ); } // +(g^0) gk_to_si[0].iter().for_each(|s_index| { rlwe_by_rgsw( trivial_rlwe_test_poly, pbs_key.rgsw_ct_secret_el(gk_to_si[q_by_2][*s_index]), scratch_matrix_dplus2_ring, decomposer, ntt_op, mod_op, ); }); } /// - Mod down /// - key switching /// - mod down /// - blind rotate fn pbs< M: Matrix + MatrixMut + MatrixEntity, MT: MatrixMut + IsTrivial + MatrixEntity, P: Parameters, NttOp: Ntt, ModOp: ArithmeticOps + VectorOps, K: PbsKey, >( parameters: &P, test_vec: &M::R, lwe_in: &mut M::R, lwe_ksk: &M, scratch_lwen_plus1: &mut M::R, scratch_matrix_dplus2_ring: &mut M, modop_lweq: &ModOp, modop_rlweq: &ModOp, nttop_rlweq: &NttOp, pbs_key: K, ) where ::R: RowMut, ::R: RowMut, M::MatElement: PrimInt + ToPrimitive + FromPrimitive + One + Copy + Zero, { let rlwe_q = parameters.rlwe_q(); let lwe_q = parameters.lwe_q(); let br_q = parameters.br_q(); let rlwe_qf64 = rlwe_q.to_f64().unwrap(); let lwe_qf64 = lwe_q.to_f64().unwrap(); let br_qf64 = br_q.to_f64().unwrap(); let rlwe_n = parameters.rlwe_n(); // moddown Q -> Q_ks lwe_in.as_mut().iter_mut().for_each(|v| { *v = M::MatElement::from_f64(((v.to_f64().unwrap() * lwe_qf64) / rlwe_qf64).round()).unwrap() }); // key switch // let mut lwe_out = M::zeros(1, parameters.lwe_n() + 1); scratch_lwen_plus1.as_mut().fill(M::MatElement::zero()); lwe_key_switch( scratch_lwen_plus1, lwe_in, lwe_ksk, modop_lweq, parameters.decomoposer_lwe(), ); // odd mowdown Q_ks -> q let g_k_dlog_map = parameters.g_k_dlog_map(); let mut g_k_si = vec![vec![]; br_q]; scratch_lwen_plus1 .as_ref() .iter() .skip(1) .enumerate() .for_each(|(index, v)| { let odd_v = mod_switch_odd(v.to_f64().unwrap(), lwe_qf64, br_qf64); let k = g_k_dlog_map[odd_v]; g_k_si[k].push(index); }); // handle b and set trivial test RLWE let g = parameters.g() as usize; let g_times_b = (g * mod_switch_odd( scratch_lwen_plus1.as_ref()[0].to_f64().unwrap(), lwe_qf64, br_qf64, )) % (br_q); // v = (v(X) * X^{g*b}) mod X^{q/2}+1 let br_qby2 = br_q / 2; let mut gb_monomial_sign = true; let mut gb_monomial_exp = g_times_b; // X^{g*b} mod X^{q}+1 if gb_monomial_exp > br_qby2 { gb_monomial_exp -= br_qby2; gb_monomial_sign = false } // monomial mul let mut trivial_rlwe_test_poly = MT::zeros(2, rlwe_n); if parameters.embedding_factor() == 1 { monomial_mul( test_vec.as_ref(), trivial_rlwe_test_poly.get_row_mut(1).as_mut(), gb_monomial_exp, gb_monomial_sign, br_q, modop_rlweq, ); } else { // use lwe_in to store the `t = v(X) * X^{g*2} mod X^{q/2}+1` temporarily. This // works because q/2 < N (where N is lwe_in LWE dimension) always. monomial_mul( test_vec.as_ref(), &mut lwe_in.as_mut()[..br_qby2], gb_monomial_exp, gb_monomial_sign, br_q, modop_rlweq, ); // emebed poly `t` in ring X^{q/2}+1 inside the bigger ring X^{N}+1 let partb_trivial_rlwe = trivial_rlwe_test_poly.get_row_mut(1); lwe_in.as_ref()[..br_qby2] .iter() .enumerate() .for_each(|(index, v)| { partb_trivial_rlwe[2 * index] = *v; }); } // blind rotate blind_rotation( &mut trivial_rlwe_test_poly, scratch_matrix_dplus2_ring, parameters.g(), 1, br_q, &g_k_si, parameters.decomoposer_rlwe(), nttop_rlweq, modop_rlweq, &pbs_key, ); // sample extract sample_extract(lwe_in, &trivial_rlwe_test_poly, modop_rlweq, 0); } fn mod_switch_odd(v: f64, from_q: f64, to_q: f64) -> usize { let odd_v = (((v.to_f64().unwrap() * to_q) / (from_q)).floor()) .to_usize() .unwrap(); //TODO(Jay): check correctness of this odd_v + (odd_v ^ (usize::one())) } fn sample_extract>( lwe_out: &mut M::R, rlwe_in: &M, mod_op: &ModOp, index: usize, ) where ::R: RowMut, M::MatElement: Copy, { let ring_size = rlwe_in.dimension().1; // index..=0 let to = &mut lwe_out.as_mut()[1..]; let from = rlwe_in.get_row_slice(0); for i in 0..index + 1 { to[i] = from[index - i]; } // -(N..index) for i in index + 1..ring_size { to[i] = mod_op.neg(&from[ring_size + index - i]); } // set b lwe_out.as_mut()[0] = *rlwe_in.get(1, index); } fn monomial_mul>( p_in: &[El], p_out: &mut [El], mon_exp: usize, mon_sign: bool, ring_size: usize, mod_op: &ModOp, ) where El: Copy, { debug_assert!(p_in.as_ref().len() == ring_size); debug_assert!(p_in.as_ref().len() == p_out.as_ref().len()); debug_assert!(mon_exp < ring_size); p_in.as_ref().iter().enumerate().for_each(|(index, v)| { let mut to_index = index + mon_exp; let mut to_sign = mon_sign; if to_index >= ring_size { to_index = to_index - ring_size; to_sign = !to_sign; } if !to_sign { p_out.as_mut()[to_index] = mod_op.neg(v); } else { p_out.as_mut()[to_index] = *v; } }); } #[cfg(test)] mod tests { use crate::{backend::ModularOpsU64, ntt::NttBackendU64}; use super::*; const SP_BOOL_PARAMS: BoolParameters = BoolParameters:: { rlwe_q: 4294957057u64, rlwe_logq: 32, lwe_q: 1 << 16, lwe_logq: 16, br_q: 1 << 9, rlwe_n: 1 << 10, lwe_n: 490, d_rgsw: 4, logb_rgsw: 7, d_lwe: 4, logb_lwe: 4, g: 5, w: 1, }; #[test] fn encrypt_decrypt_works() { // let prime = generate_prime(32, 2 * 1024, 1 << 32); // dbg!(prime); let bool_evaluator = BoolEvaluator::>, u64, NttBackendU64, ModularOpsU64>::new(SP_BOOL_PARAMS); let client_key = bool_evaluator.client_key(); // let sever_key = bool_evaluator.server_key(&client_key); let mut m = true; for _ in 0..1000 { let lwe_ct = bool_evaluator.encrypt(m, &client_key); let m_back = bool_evaluator.decrypt(&lwe_ct, &client_key); assert_eq!(m, m_back); m = !m; } } }