diff --git a/src/bool.rs b/src/bool.rs index c6c88e5..4177454 100644 --- a/src/bool.rs +++ b/src/bool.rs @@ -1,28 +1,44 @@ -use std::{collections::HashMap, fmt::Debug, marker::PhantomData}; +use std::{ + cell::RefCell, + collections::HashMap, + fmt::{Debug, Display}, + hash::Hash, + marker::PhantomData, + thread::panicking, +}; -use itertools::Itertools; -use num_traits::{FromPrimitive, Num, One, PrimInt, ToPrimitive, Zero}; +use itertools::{izip, partition, Itertools}; +use num_traits::{FromPrimitive, Num, One, PrimInt, ToPrimitive, WrappingSub, Zero}; use crate::{ - backend::{ArithmeticOps, ModInit, VectorOps}, + backend::{ArithmeticOps, ModInit, ModularOpsU64, VectorOps}, decomposer::{gadget_vector, Decomposer, DefaultDecomposer, NumInfo}, - lwe::{decrypt_lwe, encrypt_lwe, lwe_key_switch, lwe_ksk_keygen, LweSecret}, - ntt::{Ntt, NttInit}, + lwe::{decrypt_lwe, encrypt_lwe, lwe_key_switch, lwe_ksk_keygen, measure_noise_lwe, LweSecret}, + ntt::{Ntt, NttBackendU64, NttInit}, random::{DefaultSecureRng, RandomGaussianDist, RandomUniformDist}, - rgsw::{encrypt_rgsw, galois_auto, galois_key_gen, rlwe_by_rgsw, IsTrivial, RlweSecret}, + rgsw::{ + decrypt_rlwe, encrypt_rgsw, galois_auto, galois_key_gen, generate_auto_map, rlwe_by_rgsw, + IsTrivial, RlweCiphertext, RlweSecret, + }, utils::{generate_prime, mod_exponent, TryConvertFrom, WithLocal}, Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, }; +thread_local! { + pub(crate) static CLIENT_KEY: RefCell = RefCell::new(ClientKey::random()); +} + trait PbsKey { type M: Matrix; - fn rgsw_ct_secret_el(&self, si: usize) -> &Self::M; + /// RGSW ciphertext of LWE secret elements + fn rgsw_ct_lwe_si(&self, si: usize) -> &Self::M; + /// Key for automorphism fn galois_key_for_auto(&self, k: isize) -> &Self::M; - fn auto_map_index(&self, k: isize) -> &[usize]; - fn auto_map_sign(&self, k: isize) -> &[bool]; + /// LWE ksk to key switch from RLWE secret to LWE secret + fn lwe_ksk(&self) -> &Self::M; } -trait Parameters { +trait PbsParameters { type Element; type D: Decomposer; fn rlwe_q(&self) -> Self::Element; @@ -44,12 +60,43 @@ trait Parameters { /// For any a, if k is s.t. a = g^{k}, then k is expressed as k. If k is s.t /// a = -g^{k}, then k is expressed as k=k+q/2 fn g_k_dlog_map(&self) -> &[usize]; + fn rlwe_auto_map(&self, k: isize) -> &(Vec, Vec); } + +#[derive(Clone)] struct ClientKey { sk_rlwe: RlweSecret, sk_lwe: LweSecret, } +impl ClientKey { + fn random() -> Self { + let sk_rlwe = RlweSecret::random(0, 0); + let sk_lwe = LweSecret::random(0, 0); + Self { sk_rlwe, sk_lwe } + } +} + +impl WithLocal for ClientKey { + fn with_local(func: F) -> R + where + F: Fn(&Self) -> R, + { + CLIENT_KEY.with_borrow(|client_key| func(client_key)) + } + + fn with_local_mut(func: F) -> R + where + F: Fn(&mut Self) -> R, + { + CLIENT_KEY.with_borrow_mut(|client_key| func(client_key)) + } +} + +fn set_client_key(key: &ClientKey) { + ClientKey::with_local_mut(|k| *k = key.clone()) +} + struct ServerKey { /// Rgsw cts of LWE secret elements rgsw_cts: Vec, @@ -59,6 +106,23 @@ struct ServerKey { lwe_ksk: M, } +//FIXME(Jay): Figure out a way for BoolEvaluator to have access to ServerKey +// via a pointer and implement PbsKey for BoolEvaluator instead of ServerKey +// directly +impl PbsKey for ServerKey { + type M = M; + fn galois_key_for_auto(&self, k: isize) -> &Self::M { + self.galois_keys.get(&k).unwrap() + } + fn rgsw_ct_lwe_si(&self, si: usize) -> &Self::M { + &self.rgsw_cts[si] + } + + fn lwe_ksk(&self) -> &Self::M { + &self.lwe_ksk + } +} + struct BoolParameters { rlwe_q: El, rlwe_logq: usize, @@ -75,7 +139,10 @@ struct BoolParameters { w: usize, } -struct BoolEvaluator { +struct BoolEvaluator +where + M: Matrix, +{ parameters: BoolParameters, decomposer_rlwe: DefaultDecomposer, decomposer_lwe: DefaultDecomposer, @@ -84,7 +151,11 @@ struct BoolEvaluator { rlwe_modop: ModOp, lwe_modop: ModOp, embedding_factor: usize, - + nand_test_vec: M::R, + rlweq_by8: M::MatElement, + rlwe_auto_maps: Vec<(Vec, Vec)>, + scratch_lwen_plus1: M::R, + scratch_dplus2_ring: M, _phantom: PhantomData, } @@ -94,7 +165,7 @@ where ModOp: ModInit + ArithmeticOps + VectorOps, - M::MatElement: PrimInt + Debug + NumInfo + FromPrimitive, + M::MatElement: PrimInt + Debug + Display + NumInfo + FromPrimitive + WrappingSub, M: MatrixEntity + MatrixMut, M::R: TryConvertFrom<[i32], Parameters = M::MatElement> + RowEntity, M: TryConvertFrom<[i32], Parameters = M::MatElement>, @@ -105,6 +176,7 @@ where { fn new(parameters: BoolParameters) -> Self { //TODO(Jay): Run sanity checks for modulus values in parameters + assert!(parameters.br_q.is_power_of_two()); let decomposer_rlwe = DefaultDecomposer::new(parameters.rlwe_q, parameters.logb_rgsw, parameters.d_rgsw); @@ -129,6 +201,73 @@ where let rlwe_modop = ModInit::new(parameters.rlwe_q); let lwe_modop = ModInit::new(parameters.lwe_q); + // set test vectors + let el_one = M::MatElement::one(); + let nand_map = |index: usize, qby8: usize| { + if index < (3 * qby8) { + true + } else { + false + } + }; + + let q = parameters.br_q; + let qby2 = q >> 1; + let qby8 = q >> 3; + let qby16 = q >> 4; + let mut nand_test_vec = M::R::zeros(qby2); + // Q/8 (Q: rlwe_q) + let rlwe_qby8 = + M::MatElement::from_f64((parameters.rlwe_q.to_f64().unwrap() / 8.0).round()).unwrap(); + let true_m_el = rlwe_qby8; + // -Q/8 + let false_m_el = parameters.rlwe_q - rlwe_qby8; + for i in 0..qby2 { + let v = nand_map(i, qby8); + if v { + nand_test_vec.as_mut()[i] = true_m_el; + } else { + nand_test_vec.as_mut()[i] = false_m_el; + } + } + // Rotate and negate by q/16 + let mut tmp = M::R::zeros(qby2); + tmp.as_mut()[..qby2 - qby16].copy_from_slice(&nand_test_vec.as_ref()[qby16..]); + tmp.as_mut()[qby2 - qby16..].copy_from_slice(&nand_test_vec.as_ref()[..qby16]); + tmp.as_mut()[qby2 - qby16..].iter_mut().for_each(|v| { + *v = parameters.rlwe_q - *v; + }); + let nand_test_vec = tmp; + + // v(X) -> v(X^{-g}) + let (auto_map_index, auto_map_sign) = generate_auto_map(qby2, -(g as isize)); + let mut nand_test_vec_autog = M::R::zeros(qby2); + izip!( + nand_test_vec.as_ref().iter(), + auto_map_index.iter(), + auto_map_sign.iter() + ) + .for_each(|(v, to_index, to_sign)| { + if !to_sign { + // negate + nand_test_vec_autog.as_mut()[*to_index] = parameters.rlwe_q - *v; + } else { + nand_test_vec_autog.as_mut()[*to_index] = *v; + } + }); + + // auto map indices and sign + let mut rlwe_auto_maps = vec![]; + let ring_size = parameters.rlwe_n; + let g = parameters.g as isize; + for i in [g, -g] { + rlwe_auto_maps.push(generate_auto_map(ring_size, i)) + } + + // create srcatch spaces + let scratch_lwen_plus1 = M::R::zeros(parameters.lwe_n + 1); + let scratch_dplus2_ring = M::zeros(parameters.d_rgsw + 2, parameters.rlwe_n); + BoolEvaluator { parameters: parameters, decomposer_lwe, @@ -138,7 +277,11 @@ where lwe_modop, rlwe_modop, rlwe_nttop, - + nand_test_vec: nand_test_vec_autog, + rlweq_by8: rlwe_qby8, + rlwe_auto_maps, + scratch_lwen_plus1, + scratch_dplus2_ring, _phantom: PhantomData, } } @@ -190,6 +333,7 @@ where // X^{si}; assume |emebedding_factor * si| < N let mut m = M::zeros(1, ring_size); let si = (self.embedding_factor as i32) * si; + // dbg!(si); if si < 0 { // X^{-i} = X^{2N - i} = -X^{N-i} m.set( @@ -246,16 +390,14 @@ where } } + /// TODO(Jay): Fetch client key from thread local pub fn encrypt(&self, m: bool, client_key: &ClientKey) -> M::R { - let rlwe_q_by8 = - M::MatElement::from_f64((self.parameters.rlwe_q.to_f64().unwrap() / 8.0).round()) - .unwrap(); let m = if m { // Q/8 - rlwe_q_by8 + self.rlweq_by8 } else { // -Q/8 - self.parameters.rlwe_q - rlwe_q_by8 + self.parameters.rlwe_q - self.rlweq_by8 }; DefaultSecureRng::with_local_mut(|rng| { @@ -275,13 +417,11 @@ where let m = decrypt_lwe(lwe_ct, client_key.sk_rlwe.values(), &self.rlwe_modop); let m = { // m + q/8 => {0,q/4 1} - let rlwe_q_by8 = - M::MatElement::from_f64((self.parameters.rlwe_q.to_f64().unwrap() / 8.0).round()) - .unwrap(); - (((m + rlwe_q_by8).to_f64().unwrap() * 4.0) / self.parameters.rlwe_q.to_f64().unwrap()) - .round() - .to_usize() - .unwrap() + (((m + self.rlweq_by8).to_f64().unwrap() * 4.0) + / self.parameters.rlwe_q.to_f64().unwrap()) + .round() + .to_usize() + .unwrap() % 4 }; @@ -290,9 +430,123 @@ where } else if m == 1 { true } else { - panic!("Incorrect bool decryption. Got m={m} expected m to be 0 or 1") + panic!("Incorrect bool decryption. Got m={m} but expected m to be 0 or 1") + } + } + + pub fn nand( + &self, + c0: &M::R, + c1: &M::R, + server_key: &ServerKey, + scratch_lwen_plus1: &mut M::R, + scratch_matrix_dplus2_ring: &mut M, + ) -> M::R { + // ClientKey::with_local(|ck| { + // let c0_noise = measure_noise_lwe( + // c0, + // ck.sk_rlwe.values(), + // &self.rlwe_modop, + // &(self.rlwe_q() - self.rlweq_by8), + // ); + // let c1_noise = + // measure_noise_lwe(c1, ck.sk_rlwe.values(), &self.rlwe_modop, + // &(self.rlweq_by8)); println!("c0 noise: {c0_noise}; c1 noise: + // {c1_noise}"); }); + + let mut c_out = M::R::zeros(c0.as_ref().len()); + let modop = &self.rlwe_modop; + izip!( + c_out.as_mut().iter_mut(), + c0.as_ref().iter(), + c1.as_ref().iter() + ) + .for_each(|(o, i0, i1)| { + *o = modop.add(i0, i1); + }); + // +Q/8 + c_out.as_mut()[0] = modop.add(&c_out.as_ref()[0], &self.rlweq_by8); + + // ClientKey::with_local(|ck| { + // let noise = measure_noise_lwe( + // &c_out, + // ck.sk_rlwe.values(), + // &self.rlwe_modop, + // &(self.rlweq_by8), + // ); + // println!("cout_noise: {noise}"); + // }); + + // PBS + pbs( + self, + &self.nand_test_vec, + &mut c_out, + scratch_lwen_plus1, + scratch_matrix_dplus2_ring, + &self.lwe_modop, + &self.rlwe_modop, + &self.rlwe_nttop, + server_key, + ); + + c_out + } +} + +impl PbsParameters for BoolEvaluator +where + M::MatElement: PrimInt + WrappingSub + Debug, +{ + type Element = M::MatElement; + type D = DefaultDecomposer; + fn rlwe_auto_map(&self, k: isize) -> &(Vec, Vec) { + let g = self.parameters.g as isize; + if k == g { + &self.rlwe_auto_maps[0] + } else if k == -g { + &self.rlwe_auto_maps[1] + } else { + panic!("RLWE auto map only supports k in [-g, g], but got k={k}"); } } + + fn br_q(&self) -> usize { + self.parameters.br_q + } + fn d_lwe(&self) -> usize { + self.parameters.d_lwe + } + fn d_rgsw(&self) -> usize { + self.parameters.d_rgsw + } + fn decomoposer_lwe(&self) -> &Self::D { + &self.decomposer_lwe + } + fn decomoposer_rlwe(&self) -> &Self::D { + &self.decomposer_rlwe + } + fn embedding_factor(&self) -> usize { + self.embedding_factor + } + fn g(&self) -> isize { + self.parameters.g as isize + } + fn g_k_dlog_map(&self) -> &[usize] { + &self.g_k_dlog_map + } + fn lwe_n(&self) -> usize { + self.parameters.lwe_n + } + fn lwe_q(&self) -> Self::Element { + self.parameters.lwe_q + } + fn rlwe_n(&self) -> usize { + self.parameters.rlwe_n + } + fn rlwe_q(&self) -> Self::Element { + self.parameters.rlwe_q + } } /// LMKCY+ Blind rotation @@ -305,6 +559,7 @@ fn blind_rotation< NttOp: Ntt, ModOp: ArithmeticOps + VectorOps, K: PbsKey, + P: PbsParameters, >( trivial_rlwe_test_poly: &mut MT, scratch_matrix_dplus2_ring: &mut Mmut, @@ -315,6 +570,7 @@ fn blind_rotation< decomposer: &D, ntt_op: &NttOp, mod_op: &ModOp, + parameters: &P, pbs_key: &K, ) where ::R: RowMut, @@ -324,11 +580,11 @@ fn blind_rotation< let q_by_2 = q / 2; // -(g^k) - for i in 1..q_by_2 { + for i in (1..q_by_2).rev() { gk_to_si[q_by_2 + i].iter().for_each(|s_index| { rlwe_by_rgsw( trivial_rlwe_test_poly, - pbs_key.rgsw_ct_secret_el(*s_index), + pbs_key.rgsw_ct_lwe_si(*s_index), scratch_matrix_dplus2_ring, decomposer, ntt_op, @@ -336,12 +592,13 @@ fn blind_rotation< ); }); + let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(g); galois_auto( trivial_rlwe_test_poly, pbs_key.galois_key_for_auto(g), scratch_matrix_dplus2_ring, - pbs_key.auto_map_index(g), - pbs_key.auto_map_sign(g), + &auto_map_index, + &auto_map_sign, mod_op, ntt_op, decomposer, @@ -352,30 +609,31 @@ fn blind_rotation< gk_to_si[q_by_2].iter().for_each(|s_index| { rlwe_by_rgsw( trivial_rlwe_test_poly, - pbs_key.rgsw_ct_secret_el(*s_index), + pbs_key.rgsw_ct_lwe_si(*s_index), scratch_matrix_dplus2_ring, decomposer, ntt_op, mod_op, ); }); + let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(-g); galois_auto( trivial_rlwe_test_poly, pbs_key.galois_key_for_auto(-g), scratch_matrix_dplus2_ring, - pbs_key.auto_map_index(-g), - pbs_key.auto_map_sign(-g), + &auto_map_index, + &auto_map_sign, mod_op, ntt_op, decomposer, ); // +(g^k) - for i in 1..q_by_2 { + for i in (1..q_by_2).rev() { gk_to_si[i].iter().for_each(|s_index| { rlwe_by_rgsw( trivial_rlwe_test_poly, - pbs_key.rgsw_ct_secret_el(*s_index), + pbs_key.rgsw_ct_lwe_si(*s_index), scratch_matrix_dplus2_ring, decomposer, ntt_op, @@ -383,12 +641,13 @@ fn blind_rotation< ); }); + let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(g); galois_auto( trivial_rlwe_test_poly, pbs_key.galois_key_for_auto(g), scratch_matrix_dplus2_ring, - pbs_key.auto_map_index(g), - pbs_key.auto_map_sign(g), + &auto_map_index, + &auto_map_sign, mod_op, ntt_op, decomposer, @@ -399,7 +658,7 @@ fn blind_rotation< gk_to_si[0].iter().for_each(|s_index| { rlwe_by_rgsw( trivial_rlwe_test_poly, - pbs_key.rgsw_ct_secret_el(gk_to_si[q_by_2][*s_index]), + pbs_key.rgsw_ct_lwe_si(gk_to_si[q_by_2][*s_index]), scratch_matrix_dplus2_ring, decomposer, ntt_op, @@ -414,8 +673,7 @@ fn blind_rotation< /// - blind rotate fn pbs< M: Matrix + MatrixMut + MatrixEntity, - MT: MatrixMut + IsTrivial + MatrixEntity, - P: Parameters, + P: PbsParameters, NttOp: Ntt, ModOp: ArithmeticOps + VectorOps, K: PbsKey, @@ -423,17 +681,17 @@ fn pbs< parameters: &P, test_vec: &M::R, lwe_in: &mut M::R, - lwe_ksk: &M, scratch_lwen_plus1: &mut M::R, scratch_matrix_dplus2_ring: &mut M, modop_lweq: &ModOp, modop_rlweq: &ModOp, nttop_rlweq: &NttOp, - pbs_key: K, + pbs_key: &K, ) where - ::R: RowMut, - ::R: RowMut, - M::MatElement: PrimInt + ToPrimitive + FromPrimitive + One + Copy + Zero, + // FIXME(Jay): TryConvertFrom<[i32], Parameters = M::MatElement> are only needed for + // debugging purposes + ::R: RowMut + TryConvertFrom<[i32], Parameters = M::MatElement>, + M::MatElement: PrimInt + ToPrimitive + FromPrimitive + One + Copy + Zero + Display, { let rlwe_q = parameters.rlwe_q(); let lwe_q = parameters.lwe_q(); @@ -449,17 +707,34 @@ fn pbs< M::MatElement::from_f64(((v.to_f64().unwrap() * lwe_qf64) / rlwe_qf64).round()).unwrap() }); - // key switch - // let mut lwe_out = M::zeros(1, parameters.lwe_n() + 1); + PBSTracer::with_local_mut(|t| { + let out = lwe_in + .as_ref() + .iter() + .map(|v| v.to_u64().unwrap()) + .collect_vec(); + t.ct_lwe_q_mod = out; + }); + + // key switch RLWE secret to LWE secret scratch_lwen_plus1.as_mut().fill(M::MatElement::zero()); lwe_key_switch( scratch_lwen_plus1, lwe_in, - lwe_ksk, + pbs_key.lwe_ksk(), modop_lweq, parameters.decomoposer_lwe(), ); + PBSTracer::with_local_mut(|t| { + let out = scratch_lwen_plus1 + .as_ref() + .iter() + .map(|v| v.to_u64().unwrap()) + .collect_vec(); + t.ct_lwe_q_mod_after_ksk = out; + }); + // odd mowdown Q_ks -> q let g_k_dlog_map = parameters.g_k_dlog_map(); let mut g_k_si = vec![vec![]; br_q]; @@ -474,6 +749,15 @@ fn pbs< g_k_si[k].push(index); }); + PBSTracer::with_local_mut(|t| { + let out = scratch_lwen_plus1 + .as_ref() + .iter() + .map(|v| mod_switch_odd(v.to_f64().unwrap(), lwe_qf64, br_qf64) as u64) + .collect_vec(); + t.ct_br_q_mod = out; + }); + // handle b and set trivial test RLWE let g = parameters.g() as usize; let g_times_b = (g * mod_switch_odd( @@ -485,41 +769,42 @@ fn pbs< let br_qby2 = br_q / 2; let mut gb_monomial_sign = true; let mut gb_monomial_exp = g_times_b; - // X^{g*b} mod X^{q}+1 + // X^{g*b} mod X^{q/2}+1 if gb_monomial_exp > br_qby2 { gb_monomial_exp -= br_qby2; gb_monomial_sign = false } // monomial mul - let mut trivial_rlwe_test_poly = MT::zeros(2, rlwe_n); + let mut trivial_rlwe_test_poly = RlweCiphertext(M::zeros(2, rlwe_n), true); if parameters.embedding_factor() == 1 { monomial_mul( test_vec.as_ref(), trivial_rlwe_test_poly.get_row_mut(1).as_mut(), gb_monomial_exp, gb_monomial_sign, - br_q, + br_qby2, modop_rlweq, ); } else { // use lwe_in to store the `t = v(X) * X^{g*2} mod X^{q/2}+1` temporarily. This - // works because q/2 < N (where N is lwe_in LWE dimension) always. + // works because q/2 <= N (where N is lwe_in LWE dimension) always. monomial_mul( test_vec.as_ref(), &mut lwe_in.as_mut()[..br_qby2], gb_monomial_exp, gb_monomial_sign, - br_q, + br_qby2, modop_rlweq, ); // emebed poly `t` in ring X^{q/2}+1 inside the bigger ring X^{N}+1 + let embed_factor = parameters.embedding_factor(); let partb_trivial_rlwe = trivial_rlwe_test_poly.get_row_mut(1); lwe_in.as_ref()[..br_qby2] .iter() .enumerate() .for_each(|(index, v)| { - partb_trivial_rlwe[2 * index] = *v; + partb_trivial_rlwe[embed_factor * index] = *v; }); } @@ -534,19 +819,48 @@ fn pbs< parameters.decomoposer_rlwe(), nttop_rlweq, modop_rlweq, - &pbs_key, + parameters, + pbs_key, ); + // ClientKey::with_local(|ck| { + // let ring_size = parameters.rlwe_n(); + // let mut rlwe_ct = vec![vec![0u64; ring_size]; 2]; + // izip!( + // rlwe_ct[0].iter_mut(), + // trivial_rlwe_test_poly.0.get_row_slice(0) + // ) + // .for_each(|(t, f)| { + // *t = f.to_u64().unwrap(); + // }); + // izip!( + // rlwe_ct[1].iter_mut(), + // trivial_rlwe_test_poly.0.get_row_slice(1) + // ) + // .for_each(|(t, f)| { + // *t = f.to_u64().unwrap(); + // }); + // let mut m_out = vec![vec![0u64; ring_size]]; + // let modop = ModularOpsU64::new(rlwe_q.to_u64().unwrap()); + // let nttop = NttBackendU64::new(rlwe_q.to_u64().unwrap(), ring_size); + // decrypt_rlwe(&rlwe_ct, ck.sk_rlwe.values(), &mut m_out, &nttop, &modop); + + // println!("RLWE post PBS message: {:?}", m_out[0]); + // }); + // sample extract sample_extract(lwe_in, &trivial_rlwe_test_poly, modop_rlweq, 0); } fn mod_switch_odd(v: f64, from_q: f64, to_q: f64) -> usize { - let odd_v = (((v.to_f64().unwrap() * to_q) / (from_q)).floor()) - .to_usize() - .unwrap(); + // println!("v: {v}, odd_v: {odd_v}, lwe_q:{lwe_q}, br_q:{br_q}"); + let odd_v = (((v * to_q) / (from_q)).floor()).to_usize().unwrap(); + // println!( + // "v: {v}, odd_v: {odd_v}, returned_oddv: {},lwe_q:{from_q}, br_q:{to_q}", + // odd_v + ((odd_v & 1) ^ 1) + // ); //TODO(Jay): check correctness of this - odd_v + (odd_v ^ (usize::one())) + odd_v + ((odd_v & 1) ^ 1) } fn sample_extract>( @@ -576,6 +890,7 @@ fn sample_extract>( p_in: &[El], p_out: &mut [El], @@ -606,28 +921,121 @@ fn monomial_mul>( }); } +thread_local! { + static PBS_TRACER: RefCell>>> = RefCell::new(PBSTracer::default()); +} + +#[derive(Default)] +struct PBSTracer +where + M: Matrix + Default, +{ + pub(crate) ct_lwe_q_mod: M::R, + pub(crate) ct_lwe_q_mod_after_ksk: M::R, + pub(crate) ct_br_q_mod: Vec, +} + +impl PBSTracer>> { + fn trace(&self, parameters: &BoolParameters, client_key: &ClientKey, expected_m: bool) { + let lwe_q = parameters.lwe_q; + let lwe_qby8 = ((lwe_q as f64) / 8.0).round() as u64; + let expected_m_lweq = if expected_m { + lwe_qby8 + } else { + lwe_q - lwe_qby8 + }; + let modop_lweq = ModularOpsU64::new(lwe_q); + // noise after mod down Q -> Q_ks + let noise0 = { + measure_noise_lwe( + &self.ct_lwe_q_mod, + client_key.sk_rlwe.values(), + &modop_lweq, + &expected_m_lweq, + ) + }; + + // noise after key switch from RLWE -> LWE + let noise1 = { + measure_noise_lwe( + &self.ct_lwe_q_mod_after_ksk, + client_key.sk_lwe.values(), + &modop_lweq, + &expected_m_lweq, + ) + }; + + // noise after mod down odd from Q_ks -> q + let br_q = parameters.br_q as u64; + let expected_m_brq = if expected_m { + br_q >> 3 + } else { + br_q - (br_q >> 3) + }; + let modop_br_q = ModularOpsU64::new(br_q); + let noise2 = { + measure_noise_lwe( + &self.ct_br_q_mod, + client_key.sk_lwe.values(), + &modop_br_q, + &expected_m_brq, + ) + }; + + println!( + " + m: {expected_m}, + Noise after mod down Q -> Q_ks: {noise0}, + Noise after key switch from RLWE -> LWE: {noise1}, + Noise after mod dwon Q_ks -> q: {noise2} + " + ); + } +} + +impl WithLocal for PBSTracer>> { + fn with_local(func: F) -> R + where + F: Fn(&Self) -> R, + { + PBS_TRACER.with_borrow(|t| func(t)) + } + + fn with_local_mut(func: F) -> R + where + F: Fn(&mut Self) -> R, + { + PBS_TRACER.with_borrow_mut(|t| func(t)) + } +} + #[cfg(test)] mod tests { - use crate::{backend::ModularOpsU64, ntt::NttBackendU64}; + use crate::{backend::ModularOpsU64, ntt::NttBackendU64, random::DEFAULT_RNG}; use super::*; const SP_BOOL_PARAMS: BoolParameters = BoolParameters:: { - rlwe_q: 4294957057u64, - rlwe_logq: 32, + rlwe_q: 268369921u64, + rlwe_logq: 28, lwe_q: 1 << 16, lwe_logq: 16, - br_q: 1 << 9, + br_q: 1 << 10, rlwe_n: 1 << 10, - lwe_n: 490, - d_rgsw: 4, - logb_rgsw: 7, - d_lwe: 4, + lwe_n: 493, + d_rgsw: 3, + logb_rgsw: 8, + d_lwe: 3, logb_lwe: 4, g: 5, w: 1, }; + // #[test] + // fn trial() { + // dbg!(generate_prime(28, 1 << 11, 1 << 28)); + // } + #[test] fn encrypt_decrypt_works() { // let prime = generate_prime(32, 2 * 1024, 1 << 32); @@ -645,4 +1053,71 @@ mod tests { m = !m; } } + + #[test] + fn trial12() { + // DefaultSecureRng::with_local_mut(|r| { + // let rng = DefaultSecureRng::new_seeded([19u8; 32]); + // *r = rng; + // }); + + let bool_evaluator = + BoolEvaluator::>, u64, NttBackendU64, ModularOpsU64>::new(SP_BOOL_PARAMS); + // println!("{:?}", bool_evaluator.nand_test_vec); + let client_key = bool_evaluator.client_key(); + set_client_key(&client_key); + + let server_key = bool_evaluator.server_key(&client_key); + + let mut scratch_lwen_plus1 = vec![0u64; bool_evaluator.parameters.lwe_n + 1]; + let mut scratch_matrix_dplus2_ring = vec![ + vec![0u64; bool_evaluator.parameters.rlwe_n]; + bool_evaluator.parameters.d_rgsw + 2 + ]; + + let mut m0 = false; + let mut m1 = true; + let mut ct0 = bool_evaluator.encrypt(m0, &client_key); + let mut ct1 = bool_evaluator.encrypt(m1, &client_key); + for _ in 0..4 { + let ct_back = bool_evaluator.nand( + &ct0, + &ct1, + &server_key, + &mut scratch_lwen_plus1, + &mut scratch_matrix_dplus2_ring, + ); + + let m_out = !(m0 && m1); + + // Trace and measure PBS noise + { + // Trace PBS + PBSTracer::with_local(|t| t.trace(&SP_BOOL_PARAMS, &client_key, m_out)); + + // Calculate nosie in ciphertext post PBS + let ideal = if m_out { + bool_evaluator.rlweq_by8 + } else { + bool_evaluator.rlwe_q() - bool_evaluator.rlweq_by8 + }; + let noise = measure_noise_lwe( + &ct_back, + client_key.sk_rlwe.values(), + &bool_evaluator.rlwe_modop, + &ideal, + ); + println!("PBS noise: {noise}"); + } + let m_back = bool_evaluator.decrypt(&ct_back, &client_key); + assert_eq!(m_out, m_back); + println!("----------"); + + m1 = m0; + m0 = m_out; + + ct1 = ct0; + ct0 = ct_back; + } + } } diff --git a/src/lwe.rs b/src/lwe.rs index 849450f..040de3c 100644 --- a/src/lwe.rs +++ b/src/lwe.rs @@ -1,7 +1,10 @@ -use std::fmt::Debug; +use std::{ + cell::RefCell, + fmt::{Debug, Display}, +}; use itertools::{izip, Itertools}; -use num_traits::{abs, Zero}; +use num_traits::{abs, PrimInt, ToPrimitive, Zero}; use crate::{ backend::{ArithmeticOps, VectorOps}, @@ -21,6 +24,7 @@ trait LweKeySwitchParameters { trait LweCiphertext {} +#[derive(Clone)] pub struct LweSecret { values: Vec, } @@ -183,6 +187,34 @@ where operator.sub(b, &sa) } +pub(crate) fn measure_noise_lwe, S>( + ct: &Ro, + s: &[S], + operator: &Op, + ideal_m: &Ro::Element, +) -> f64 +where + Ro: TryConvertFrom<[S], Parameters = Ro::Element>, + Ro::Element: Zero + ToPrimitive + PrimInt + Display, +{ + assert!(s.len() == ct.as_ref().len() - 1,); + + let s = Ro::try_convert_from(s, &operator.modulus()); + let mut sa = Ro::Element::zero(); + izip!(s.as_ref().iter(), ct.as_ref().iter().skip(1)).for_each(|(si, ai)| { + sa = operator.add(&sa, &operator.mul(si, ai)); + }); + let m = operator.sub(&ct.as_ref()[0], &sa); + + println!("measire: {m} {ideal_m}"); + let mut diff = operator.sub(&m, ideal_m); + let q = operator.modulus(); + if diff > (q >> 1) { + diff = q - diff; + } + return diff.to_f64().unwrap().log2(); +} + #[cfg(test)] mod tests { diff --git a/src/rgsw.rs b/src/rgsw.rs index b3b400a..6d305b3 100644 --- a/src/rgsw.rs +++ b/src/rgsw.rs @@ -16,7 +16,7 @@ use crate::{ Matrix, MatrixEntity, MatrixMut, RowMut, Secret, }; -pub struct RlweCiphertext(M, bool); +pub struct RlweCiphertext(pub(crate) M, pub(crate) bool); impl Matrix for RlweCiphertext { type MatElement = M::MatElement; @@ -58,6 +58,7 @@ pub trait IsTrivial { fn set_not_trivial(&mut self); } +#[derive(Clone)] pub struct RlweSecret { values: Vec, } @@ -80,12 +81,12 @@ impl RlweSecret { } } -fn generate_auto_map(ring_size: usize, k: isize) -> (Vec, Vec) { +pub(crate) fn generate_auto_map(ring_size: usize, k: isize) -> (Vec, Vec) { assert!(k & 1 == 1, "Auto {k} must be odd"); - // k = k % 2*N let k = if k < 0 { - (2 * ring_size) - (k.abs() as usize) + // k is -ve, return k%(2*N) + (2 * ring_size) - (k.abs() as usize % (2 * ring_size)) } else { k as usize }; @@ -712,19 +713,19 @@ pub(crate) fn decrypt_rlwe< M: Matrix, ModOp: VectorOps, NttOp: Ntt, - S: Secret, + S, >( rlwe_ct: &M, - s: &S, + s: &[S], m_out: &mut Mmut, ntt_op: &NttOp, mod_op: &ModOp, ) where ::R: RowMut, - Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>, + Mmut: TryConvertFrom<[S], Parameters = Mmut::MatElement>, Mmut::MatElement: Copy, { - let ring_size = s.values().len(); + let ring_size = s.len(); assert!(rlwe_ct.dimension() == (2, ring_size)); assert!(m_out.dimension() == (1, ring_size)); @@ -735,7 +736,7 @@ pub(crate) fn decrypt_rlwe< ntt_op.forward(m_out.get_row_mut(0)); // -s*a - let mut s = Mmut::try_convert_from(&s.values(), &mod_op.modulus()); + let mut s = Mmut::try_convert_from(&s, &mod_op.modulus()); ntt_op.forward(s.get_row_mut(0)); mod_op.elwise_mul_mut(m_out.get_row_mut(0), s.get_row_slice(0)); mod_op.elwise_neg_mut(m_out.get_row_mut(0)); @@ -819,7 +820,7 @@ mod tests { random::{DefaultSecureRng, RandomUniformDist}, rgsw::{measure_noise, RlweCiphertext}, utils::{generate_prime, negacyclic_mul}, - Matrix, + Matrix, Secret, }; use super::{ @@ -834,7 +835,7 @@ mod tests { let ring_size = 1 << 10; let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap(); let p = 1u64 << logp; - let d_rgsw = 10; + let d_rgsw = 9; let logb = 5; let mut rng = DefaultSecureRng::new(); @@ -895,7 +896,13 @@ mod tests { // Decrypt RLWE(m0m1) let mut encoded_m0m1_back = vec![vec![0u64; ring_size as usize]]; - decrypt_rlwe(&rlwe_in_ct, &s, &mut encoded_m0m1_back, &ntt_op, &mod_op); + decrypt_rlwe( + &rlwe_in_ct, + s.values(), + &mut encoded_m0m1_back, + &ntt_op, + &mod_op, + ); let m0m1_back = encoded_m0m1_back[0] .iter() .map(|v| (((*v as f64 * p as f64) / (q as f64)).round() as u64) % p) @@ -941,7 +948,7 @@ mod tests { &mut rng, ); - let auto_k = -25; + let auto_k = -5; // Generate galois key to key switch from s^k to s let mut ksk_out = vec![vec![0u64; ring_size as usize]; d_rgsw * 2]; @@ -976,7 +983,13 @@ mod tests { // Decrypt RLWE_{s}(m^k) and check let mut encoded_m_k_back = vec![vec![0u64; ring_size as usize]]; - decrypt_rlwe(&rlwe_m_k, &s, &mut encoded_m_k_back, &ntt_op, &mod_op); + decrypt_rlwe( + &rlwe_m_k, + s.values(), + &mut encoded_m_k_back, + &ntt_op, + &mod_op, + ); let m_k_back = encoded_m_k_back[0] .iter() .map(|v| (((*v as f64 * p as f64) / q as f64).round() as u64) % p)