From b0d53a6fbf1218d7aced7ef51f7165855f0f0676 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Fri, 26 Apr 2024 13:56:52 +0530 Subject: [PATCH] add bool pbs --- src/backend.rs | 17 +++ src/bool.rs | 360 +++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 26 +++- src/lwe.rs | 110 +++++++------- src/rgsw.rs | 391 ++++++++++++++++++++++++++++++++++--------------- src/utils.rs | 21 +++ 6 files changed, 747 insertions(+), 178 deletions(-) create mode 100644 src/bool.rs diff --git a/src/backend.rs b/src/backend.rs index 2b21f1b..0dbf50d 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -12,6 +12,12 @@ pub trait VectorOps { fn elwise_neg_mut(&self, a: &mut [Self::Element]); /// inplace mutates `a`: a = a + b*c fn elwise_fma_mut(&self, a: &mut [Self::Element], b: &[Self::Element], c: &[Self::Element]); + fn elwise_fma_scalar_mut( + &self, + a: &mut [Self::Element], + b: &[Self::Element], + c: &Self::Element, + ); fn modulus(&self) -> Self::Element; } @@ -169,6 +175,17 @@ impl VectorOps for ModularOpsU64 { }); } + fn elwise_fma_scalar_mut( + &self, + a: &mut [Self::Element], + b: &[Self::Element], + c: &Self::Element, + ) { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| { + *ai = self.add_mod_fast(*ai, self.mul_mod_fast(*bi, *c)); + }); + } + fn modulus(&self) -> Self::Element { self.q } diff --git a/src/bool.rs b/src/bool.rs new file mode 100644 index 0000000..310bbc3 --- /dev/null +++ b/src/bool.rs @@ -0,0 +1,360 @@ +use std::collections::HashMap; + +use num_traits::{FromPrimitive, One, PrimInt, ToPrimitive, Zero}; + +use crate::{ + backend::{ArithmeticOps, VectorOps}, + decomposer::Decomposer, + lwe::lwe_key_switch, + ntt::Ntt, + rgsw::{galois_auto, rlwe_by_rgsw, IsTrivial}, + Matrix, MatrixEntity, MatrixMut, Row, RowMut, +}; + +struct BoolEvaluator {} + +impl BoolEvaluator {} + +trait PbsKey { + type M: Matrix; + + fn rgsw_ct_secret_el(&self, si: usize) -> &Self::M; + fn galois_key_for_auto(&self, k: isize) -> &Self::M; + fn auto_map_index(&self, k: isize) -> &[usize]; + fn auto_map_sign(&self, k: isize) -> &[bool]; +} + +/// LMKCY+ Blind rotation +/// +/// gk_to_si: [-g^0, -g^1, .., -g^{q/2-1}, g^0, ..., g^{q/2-1}] +fn blind_rotation< + MT: IsTrivial + MatrixMut, + Mmut: MatrixMut + Matrix, + D: Decomposer, + NttOp: Ntt, + ModOp: ArithmeticOps + VectorOps, + K: PbsKey, +>( + trivial_rlwe_test_poly: &mut MT, + scratch_matrix_dplus2_ring: &mut Mmut, + g: isize, + w: usize, + q: usize, + gk_to_si: &[Vec], + decomposer: &D, + ntt_op: &NttOp, + mod_op: &ModOp, + pbs_key: &K, +) where + ::R: RowMut, + Mmut::MatElement: Copy + Zero, + ::R: RowMut, +{ + let q_by_2 = q / 2; + + // -(g^k) + for i in 1..q_by_2 { + gk_to_si[q_by_2 + i].iter().for_each(|s_index| { + rlwe_by_rgsw( + trivial_rlwe_test_poly, + pbs_key.rgsw_ct_secret_el(*s_index), + scratch_matrix_dplus2_ring, + decomposer, + ntt_op, + mod_op, + ); + }); + + galois_auto( + trivial_rlwe_test_poly, + pbs_key.galois_key_for_auto(g), + scratch_matrix_dplus2_ring, + pbs_key.auto_map_index(g), + pbs_key.auto_map_sign(g), + mod_op, + ntt_op, + decomposer, + ); + } + + // -(g^0) + gk_to_si[q_by_2].iter().for_each(|s_index| { + rlwe_by_rgsw( + trivial_rlwe_test_poly, + pbs_key.rgsw_ct_secret_el(*s_index), + scratch_matrix_dplus2_ring, + decomposer, + ntt_op, + mod_op, + ); + }); + galois_auto( + trivial_rlwe_test_poly, + pbs_key.galois_key_for_auto(-g), + scratch_matrix_dplus2_ring, + pbs_key.auto_map_index(-g), + pbs_key.auto_map_sign(-g), + mod_op, + ntt_op, + decomposer, + ); + + // +(g^k) + for i in 1..q_by_2 { + gk_to_si[i].iter().for_each(|s_index| { + rlwe_by_rgsw( + trivial_rlwe_test_poly, + pbs_key.rgsw_ct_secret_el(*s_index), + scratch_matrix_dplus2_ring, + decomposer, + ntt_op, + mod_op, + ); + }); + + galois_auto( + trivial_rlwe_test_poly, + pbs_key.galois_key_for_auto(g), + scratch_matrix_dplus2_ring, + pbs_key.auto_map_index(g), + pbs_key.auto_map_sign(g), + mod_op, + ntt_op, + decomposer, + ); + } + + // +(g^0) + gk_to_si[0].iter().for_each(|s_index| { + rlwe_by_rgsw( + trivial_rlwe_test_poly, + pbs_key.rgsw_ct_secret_el(gk_to_si[q_by_2][*s_index]), + scratch_matrix_dplus2_ring, + decomposer, + ntt_op, + mod_op, + ); + }); +} + +trait Parameters { + type Element; + type D: Decomposer; + fn rlwe_q(&self) -> Self::Element; + fn lwe_q(&self) -> Self::Element; + fn br_q(&self) -> usize; + fn d_rgsw(&self) -> usize; + fn d_lwe(&self) -> usize; + fn rlwe_n(&self) -> usize; + fn lwe_n(&self) -> usize; + // Embedding fator for ring X^{q}+1 inside + fn embedding_factor(&self) -> usize; + // generator g + fn g(&self) -> isize; + fn decomoposer_lwe(&self) -> &Self::D; + fn decomoposer_rlwe(&self) -> &Self::D; + /// Maps a \in Z^*_{2q} to discrete log k, with generator g (i.e. g^k = + /// a). Returned vector is of size q that stores dlog of a at `vec[a]`. + /// For any a, k is s.t. a = g^{k}, then k is expressed as k. If k is s.t a + /// = -g^{k/2}, then k is expressed as k=k+q/2 + fn g_k_dlog_map(&self) -> &[usize]; +} + +/// - Mod down +/// - key switching +/// - mod down +/// - blind rotate +fn pbs< + M: Matrix + MatrixMut + MatrixEntity, + MT: MatrixMut + IsTrivial + MatrixEntity, + P: Parameters, + NttOp: Ntt, + ModOp: ArithmeticOps + VectorOps, + K: PbsKey, +>( + parameters: &P, + test_vec: &M::R, + lwe_in: &mut M::R, + lwe_ksk: &M, + scratch_lwen_plus1: &mut M::R, + scratch_matrix_dplus2_ring: &mut M, + modop_lweq: &ModOp, + modop_rlweq: &ModOp, + nttop_rlweq: &NttOp, + pbs_key: K, +) where + ::R: RowMut, + ::R: RowMut, + M::MatElement: PrimInt + ToPrimitive + FromPrimitive + One + Copy + Zero, +{ + let rlwe_q = parameters.rlwe_q(); + let lwe_q = parameters.lwe_q(); + let br_q = parameters.br_q(); + let rlwe_qf64 = rlwe_q.to_f64().unwrap(); + let lwe_qf64 = lwe_q.to_f64().unwrap(); + let br_qf64 = br_q.to_f64().unwrap(); + let rlwe_n = parameters.rlwe_n(); + + // moddown Q -> Q_ks + lwe_in.as_mut().iter_mut().for_each(|v| { + *v = + M::MatElement::from_f64(((v.to_f64().unwrap() * lwe_qf64) / rlwe_qf64).round()).unwrap() + }); + + // key switch + // let mut lwe_out = M::zeros(1, parameters.lwe_n() + 1); + scratch_lwen_plus1.as_mut().fill(M::MatElement::zero()); + lwe_key_switch( + scratch_lwen_plus1, + lwe_in, + lwe_ksk, + modop_lweq, + parameters.decomoposer_lwe(), + ); + + // odd mowdown Q_ks -> q + let g_k_dlog_map = parameters.g_k_dlog_map(); + let mut g_k_si = vec![vec![]; br_q]; + scratch_lwen_plus1 + .as_ref() + .iter() + .skip(1) + .enumerate() + .for_each(|(index, v)| { + let odd_v = mod_switch_odd(v.to_f64().unwrap(), lwe_qf64, br_qf64); + let k = g_k_dlog_map[odd_v]; + g_k_si[k].push(index); + }); + + // handle b and set trivial test RLWE + let g = parameters.g() as usize; + let g_times_b = (g * mod_switch_odd( + scratch_lwen_plus1.as_ref()[0].to_f64().unwrap(), + lwe_qf64, + br_qf64, + )) % (br_q); + // v = (v(X) * X^{g*b}) mod X^{q/2}+1 + let br_qby2 = br_q / 2; + let mut gb_monomial_sign = true; + let mut gb_monomial_exp = g_times_b; + // X^{g*b} mod X^{q}+1 + if gb_monomial_exp > br_qby2 { + gb_monomial_exp -= br_qby2; + gb_monomial_sign = false + } + // monomial mul + let mut trivial_rlwe_test_poly = MT::zeros(2, rlwe_n); + if parameters.embedding_factor() == 1 { + monomial_mul( + test_vec.as_ref(), + trivial_rlwe_test_poly.get_row_mut(1).as_mut(), + gb_monomial_exp, + gb_monomial_sign, + br_q, + modop_rlweq, + ); + } else { + // use lwe_in to store the `t = v(X) * X^{g*2} mod X^{q/2}+1` temporarily. This + // works because q/2 < N (where N is lwe_in LWE dimension) always. + monomial_mul( + test_vec.as_ref(), + &mut lwe_in.as_mut()[..br_qby2], + gb_monomial_exp, + gb_monomial_sign, + br_q, + modop_rlweq, + ); + + // emebed poly `t` in ring X^{q/2}+1 inside the bigger ring X^{N}+1 + let partb_trivial_rlwe = trivial_rlwe_test_poly.get_row_mut(1); + lwe_in.as_ref()[..br_qby2] + .iter() + .enumerate() + .for_each(|(index, v)| { + partb_trivial_rlwe[2 * index] = *v; + }); + } + // TODO Rotate test input + + // blind rotate + blind_rotation( + &mut trivial_rlwe_test_poly, + scratch_matrix_dplus2_ring, + parameters.g(), + 1, + br_q, + &g_k_si, + parameters.decomoposer_rlwe(), + nttop_rlweq, + modop_rlweq, + &pbs_key, + ); + + // sample extract + sample_extract(lwe_in, &trivial_rlwe_test_poly, modop_rlweq, 0); +} + +fn mod_switch_odd(v: f64, from_q: f64, to_q: f64) -> usize { + let odd_v = (((v.to_f64().unwrap() * to_q) / (from_q)).floor()) + .to_usize() + .unwrap(); + //TODO(Jay): check correctness of this + odd_v + (odd_v ^ (usize::one())) +} + +fn sample_extract>( + lwe_out: &mut M::R, + rlwe_in: &M, + mod_op: &ModOp, + index: usize, +) where + ::R: RowMut, + M::MatElement: Copy, +{ + let ring_size = rlwe_in.dimension().1; + + // index..=0 + let to = &mut lwe_out.as_mut()[1..]; + let from = rlwe_in.get_row_slice(0); + for i in 0..index + 1 { + to[i] = from[index - i]; + } + + // -(N..index) + for i in index + 1..ring_size { + to[i] = mod_op.neg(&from[ring_size + index - i]); + } + + // set b + lwe_out.as_mut()[0] = *rlwe_in.get(1, index); +} + +fn monomial_mul>( + p_in: &[El], + p_out: &mut [El], + mon_exp: usize, + mon_sign: bool, + ring_size: usize, + mod_op: &ModOp, +) where + El: Copy, +{ + debug_assert!(p_in.as_ref().len() == ring_size); + debug_assert!(p_in.as_ref().len() == p_out.as_ref().len()); + debug_assert!(mon_exp < ring_size); + + p_in.as_ref().iter().enumerate().for_each(|(index, v)| { + let mut to_index = index + mon_exp; + let mut to_sign = mon_sign; + if to_index >= ring_size { + to_index = to_index - ring_size; + to_sign = !to_sign; + } + + if !to_sign { + p_out.as_mut()[to_index] = mod_op.neg(v); + } else { + p_out.as_mut()[to_index] = *v; + } + }); +} diff --git a/src/lib.rs b/src/lib.rs index b185942..23a4a11 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ use random::{RandomGaussianDist, RandomUniformDist}; use utils::TryConvertFrom; mod backend; +mod bool; mod decomposer; mod lwe; mod ntt; @@ -34,6 +35,10 @@ pub trait Matrix: AsRef<[Self::R]> { fn get(&self, row_idx: usize, column_idx: usize) -> &Self::MatElement { &self.as_ref()[row_idx].as_ref()[column_idx] } + + fn split_at_row(&self, idx: usize) -> (&[::R], &[::R]) { + self.as_ref().split_at(idx) + } } pub trait MatrixMut: Matrix + AsMut<[::R]> @@ -52,7 +57,7 @@ where self.as_mut()[row_idx].as_mut()[column_idx] = val; } - fn split_at_row( + fn split_at_row_mut( &mut self, idx: usize, ) -> (&mut [::R], &mut [::R]) { @@ -86,7 +91,26 @@ impl Matrix for Vec> { } } +impl Matrix for &[Vec] { + type MatElement = T; + type R = Vec; + + fn dimension(&self) -> (usize, usize) { + (self.len(), self[0].len()) + } +} + +impl Matrix for &mut [Vec] { + type MatElement = T; + type R = Vec; + + fn dimension(&self) -> (usize, usize) { + (self.len(), self[0].len()) + } +} + impl MatrixMut for Vec> {} +impl MatrixMut for &mut [Vec] {} impl MatrixEntity for Vec> { fn zeros(row: usize, col: usize) -> Self { diff --git a/src/lwe.rs b/src/lwe.rs index bb3cf35..b66e7cd 100644 --- a/src/lwe.rs +++ b/src/lwe.rs @@ -43,36 +43,32 @@ impl LweSecret { } } -fn lwe_key_switch< +pub(crate) fn lwe_key_switch< M: Matrix, - Mmut: MatrixMut + MatrixEntity, + Ro: AsMut<[M::MatElement]> + AsRef<[M::MatElement]>, Op: VectorOps + ArithmeticOps, D: Decomposer, >( - lwe_out: &mut Mmut, - lwe_in: &M, + lwe_out: &mut Ro, + lwe_in: &Ro, lwe_ksk: &M, operator: &Op, decomposer: &D, -) where - ::R: RowMut, -{ - assert!(lwe_ksk.dimension().0 == ((lwe_in.dimension().1 - 1) * decomposer.d())); - assert!(lwe_out.dimension() == (1, lwe_ksk.dimension().1)); - - let mut scratch_space = Mmut::zeros(1, lwe_out.dimension().1); +) { + assert!(lwe_ksk.dimension().0 == ((lwe_in.as_ref().len() - 1) * decomposer.d())); + assert!(lwe_out.as_ref().len() == lwe_ksk.dimension().1); let lwe_in_a_decomposed = lwe_in - .get_row(0) + .as_ref() + .iter() .skip(1) .flat_map(|ai| decomposer.decompose(ai)); izip!(lwe_in_a_decomposed, lwe_ksk.iter_rows()).for_each(|(ai_j, beta_ij_lwe)| { - operator.elwise_scalar_mul(scratch_space.get_row_mut(0), beta_ij_lwe.as_ref(), &ai_j); - operator.elwise_add_mut(lwe_out.get_row_mut(0), scratch_space.get_row_slice(0)) + operator.elwise_fma_scalar_mut(lwe_out.as_mut(), beta_ij_lwe.as_ref(), &ai_j); }); - let out_b = operator.add(lwe_out.get(0, 0), lwe_in.get(0, 0)); - lwe_out.set(0, 0, out_b); + let out_b = operator.add(&lwe_out.as_ref()[0], &lwe_in.as_ref()[0]); + lwe_out.as_mut()[0] = out_b; } fn lwe_ksk_keygen< @@ -82,34 +78,34 @@ fn lwe_ksk_keygen< R: RandomGaussianDist + RandomUniformDist<[Mmut::MatElement], Parameters = Mmut::MatElement>, >( - lwe_sk_in: &S, - lwe_sk_out: &S, + from_lwe_sk: &S, + to_lwe_sk: &S, ksk_out: &mut Mmut, gadget: &[Mmut::MatElement], operator: &Op, rng: &mut R, ) where ::R: RowMut, - Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>, + Mmut::R: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>, Mmut::MatElement: Zero + Debug, { assert!( ksk_out.dimension() == ( - lwe_sk_in.values().len() * gadget.len(), - lwe_sk_out.values().len() + 1, + from_lwe_sk.values().len() * gadget.len(), + to_lwe_sk.values().len() + 1, ) ); let d = gadget.len(); let modulus = VectorOps::modulus(operator); - let mut neg_sk_in_m = Mmut::try_convert_from(lwe_sk_in.values(), &modulus); - operator.elwise_neg_mut(neg_sk_in_m.get_row_mut(0)); - let sk_out_m = Mmut::try_convert_from(lwe_sk_out.values(), &modulus); + let mut neg_sk_in_m = Mmut::R::try_convert_from(from_lwe_sk.values(), &modulus); + operator.elwise_neg_mut(neg_sk_in_m.as_mut()); + let sk_out_m = Mmut::R::try_convert_from(to_lwe_sk.values(), &modulus); izip!( - neg_sk_in_m.get_row(0), + neg_sk_in_m.as_ref(), ksk_out.iter_rows_mut().chunks(d).into_iter() ) .for_each(|(neg_sk_in_si, d_ks_lwes)| { @@ -119,7 +115,7 @@ fn lwe_ksk_keygen< // a * z let mut az = Mmut::MatElement::zero(); - izip!(lwe.as_ref()[1..].iter(), sk_out_m.get_row(0)).for_each(|(ai, si)| { + izip!(lwe.as_ref()[1..].iter(), sk_out_m.as_ref()).for_each(|(ai, si)| { let ai_si = operator.mul(ai, si); az = operator.add(&az, &ai_si); }); @@ -139,59 +135,57 @@ fn lwe_ksk_keygen< /// Encrypts encoded message m as LWE ciphertext fn encrypt_lwe< - Mmut: MatrixMut + MatrixEntity, - R: RandomGaussianDist - + RandomUniformDist<[Mmut::MatElement], Parameters = Mmut::MatElement>, + Ro: Row + RowMut, + R: RandomGaussianDist + + RandomUniformDist<[Ro::Element], Parameters = Ro::Element>, S: Secret, - Op: ArithmeticOps, + Op: ArithmeticOps, >( - lwe_out: &mut Mmut, - m: Mmut::MatElement, + lwe_out: &mut Ro, + m: &Ro::Element, s: &S, operator: &Op, rng: &mut R, ) where - Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>, - Mmut::MatElement: Zero, - ::R: RowMut, + Ro: TryConvertFrom<[S::Element], Parameters = Ro::Element>, + Ro::Element: Zero, { - let s = Mmut::try_convert_from(s.values(), &operator.modulus()); - assert!(s.dimension().0 == (lwe_out.dimension().0)); - assert!(s.dimension().1 == (lwe_out.dimension().1 - 1)); + let s = Ro::try_convert_from(s.values(), &operator.modulus()); + assert!(s.as_ref().len() == (lwe_out.as_ref().len() - 1)); // a*s - RandomUniformDist::random_fill(rng, &operator.modulus(), &mut lwe_out.get_row_mut(0)[1..]); - let mut sa = Mmut::MatElement::zero(); - izip!(lwe_out.get_row(0).skip(1), s.get_row(0)).for_each(|(ai, si)| { + RandomUniformDist::random_fill(rng, &operator.modulus(), &mut lwe_out.as_mut()[1..]); + let mut sa = Ro::Element::zero(); + izip!(lwe_out.as_mut().iter().skip(1), s.as_ref()).for_each(|(ai, si)| { let tmp = operator.mul(ai, si); sa = operator.add(&tmp, &sa); }); // b = a*s + e + m - let mut e = Mmut::MatElement::zero(); + let mut e = Ro::Element::zero(); RandomGaussianDist::random_fill(rng, &operator.modulus(), &mut e); - let b = operator.add(&operator.add(&sa, &e), &m); - lwe_out.set(0, 0, b); + let b = operator.add(&operator.add(&sa, &e), m); + lwe_out.as_mut()[0] = b; } -fn decrypt_lwe, S: Secret>( - lwe_ct: &M, +fn decrypt_lwe, S: Secret>( + lwe_ct: &Ro, s: &S, operator: &Op, -) -> M::MatElement +) -> Ro::Element where - M: TryConvertFrom<[S::Element], Parameters = M::MatElement>, - M::MatElement: Zero, + Ro: TryConvertFrom<[S::Element], Parameters = Ro::Element>, + Ro::Element: Zero, { - let s = M::try_convert_from(s.values(), &operator.modulus()); + let s = Ro::try_convert_from(s.values(), &operator.modulus()); - let mut sa = M::MatElement::zero(); - izip!(lwe_ct.get_row(0).skip(1), s.get_row(0)).for_each(|(ai, si)| { + let mut sa = Ro::Element::zero(); + izip!(lwe_ct.as_ref().iter().skip(1), s.as_ref()).for_each(|(ai, si)| { let tmp = operator.mul(ai, si); sa = operator.add(&tmp, &sa); }); - let b = &lwe_ct.get_row_slice(0)[0]; + let b = &lwe_ct.as_ref()[0]; operator.sub(b, &sa) } @@ -222,8 +216,8 @@ mod tests { // encrypt for m in 0..1u64 << logp { let encoded_m = m << (logq - logp); - let mut lwe_ct = vec![vec![0u64; lwe_n + 1]]; - encrypt_lwe(&mut lwe_ct, encoded_m, &lwe_sk, &modq_op, &mut rng); + let mut lwe_ct = vec![0u64; lwe_n + 1]; + encrypt_lwe(&mut lwe_ct, &encoded_m, &lwe_sk, &modq_op, &mut rng); let encoded_m_back = decrypt_lwe(&lwe_ct, &lwe_sk, &modq_op); let m_back = ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as f64).round() as u64) @@ -265,12 +259,12 @@ mod tests { for m in 0..(1 << logp) { // encrypt using lwe_sk_in let encoded_m = m << (logq - logp); - let mut lwe_in_ct = vec![vec![0u64; lwe_in_n + 1]]; - encrypt_lwe(&mut lwe_in_ct, encoded_m, &lwe_sk_in, &modq_op, &mut rng); + let mut lwe_in_ct = vec![0u64; lwe_in_n + 1]; + encrypt_lwe(&mut lwe_in_ct, &encoded_m, &lwe_sk_in, &modq_op, &mut rng); // key switch from lwe_sk_in to lwe_sk_out let decomposer = DefaultDecomposer::new(1u64 << logq, logb, d_ks); - let mut lwe_out_ct = vec![vec![0u64; lwe_out_n + 1]]; + let mut lwe_out_ct = vec![0u64; lwe_out_n + 1]; lwe_key_switch(&mut lwe_out_ct, &lwe_in_ct, &ksk, &modq_op, &decomposer); // decrypt lwe_out_ct using lwe_sk_out diff --git a/src/rgsw.rs b/src/rgsw.rs index 84f473d..dac3361 100644 --- a/src/rgsw.rs +++ b/src/rgsw.rs @@ -1,10 +1,11 @@ use std::{ + clone, fmt::Debug, ops::{Neg, Sub}, }; use itertools::{izip, Itertools}; -use num_traits::{PrimInt, ToPrimitive}; +use num_traits::{PrimInt, ToPrimitive, Zero}; use crate::{ backend::{ArithmeticOps, VectorOps}, @@ -15,6 +16,48 @@ use crate::{ Matrix, MatrixEntity, MatrixMut, RowMut, Secret, }; +struct RlweCiphertext(M, bool); + +impl Matrix for RlweCiphertext { + type MatElement = M::MatElement; + type R = M::R; + + fn dimension(&self) -> (usize, usize) { + self.0.dimension() + } +} + +impl MatrixMut for RlweCiphertext where ::R: RowMut {} + +impl AsRef<[::R]> for RlweCiphertext { + fn as_ref(&self) -> &[::R] { + self.0.as_ref() + } +} + +impl AsMut<[::R]> for RlweCiphertext +where + ::R: RowMut, +{ + fn as_mut(&mut self) -> &mut [::R] { + self.0.as_mut() + } +} + +impl IsTrivial for RlweCiphertext { + fn is_trivial(&self) -> bool { + self.1 + } + fn set_not_trivial(&mut self) { + self.1 = false; + } +} + +pub trait IsTrivial { + fn is_trivial(&self) -> bool; + fn set_not_trivial(&mut self); +} + struct RlweSecret { values: Vec, } @@ -70,7 +113,7 @@ fn generate_auto_map(ring_size: usize, k: usize) -> (Vec, Vec) { /// - neg_from_s_eval: Negative of secret polynomial to key switch from in /// evaluation domain /// - to_s_eval: secret polynomial to key switch to in evalution domain. -fn rlwe_ksk_gen< +pub(crate) fn rlwe_ksk_gen< Mmut: MatrixMut + MatrixEntity, ModOp: ArithmeticOps + VectorOps, NttOp: Ntt, @@ -98,7 +141,7 @@ fn rlwe_ksk_gen< let mut scratch_space = Mmut::zeros(1, ring_size); // RLWE'_{to_s}(-from_s) - let (part_a, part_b) = ksk_out.split_at_row(d); + let (part_a, part_b) = ksk_out.split_at_row_mut(d); izip!(part_a.iter_mut(), part_b.iter_mut(), gadget_vector.iter()).for_each( |(ai, bi, beta_i)| { // sample ai and transform to evaluation @@ -130,7 +173,7 @@ fn rlwe_ksk_gen< ); } -fn galois_key_gen< +pub(crate) fn galois_key_gen< Mmut: MatrixMut + MatrixEntity, ModOp: ArithmeticOps + VectorOps, NttOp: Ntt, @@ -179,17 +222,16 @@ fn galois_key_gen< } /// Sends RLWE_{s}(X) -> RLWE_{s}(X^k) where k is some galois element -fn galois_auto< - M: Matrix, - Mmut: MatrixMut, - ModOp: ArithmeticOps + VectorOps, - NttOp: Ntt, - D: Decomposer, +pub(crate) fn galois_auto< + MT: Matrix + IsTrivial + MatrixMut, + Mmut: MatrixMut, + ModOp: ArithmeticOps + VectorOps, + NttOp: Ntt, + D: Decomposer, >( - rlwe_in: &M, - ksk: &M, - rlwe_out: &mut Mmut, - a_rlwe_decomposed: &mut Mmut, + rlwe_in: &mut MT, + ksk: &Mmut, + scratch_matrix_dplus2_ring: &mut Mmut, auto_map_index: &[usize], auto_map_sign: &[bool], mod_op: &ModOp, @@ -197,10 +239,13 @@ fn galois_auto< decomposer: &D, ) where ::R: RowMut, - M::MatElement: Copy, + ::R: RowMut, + MT::MatElement: Copy + Zero, { let d = decomposer.d(); + let (scratch_matrix_d_ring, tmp_rlwe_out) = scratch_matrix_dplus2_ring.split_at_row_mut(d); + // send b(X) -> b(X^k) izip!( rlwe_in.get_row(1), @@ -209,45 +254,69 @@ fn galois_auto< ) .for_each(|(el_in, to_index, sign)| { if !*sign { - rlwe_out.set(1, *to_index, mod_op.neg(el_in)); + tmp_rlwe_out[1].as_mut()[*to_index] = mod_op.neg(el_in); } else { - rlwe_out.set(1, *to_index, *el_in); + tmp_rlwe_out[1].as_mut()[*to_index] = *el_in; + // scratch_matrix_dplus2_ring.set(d + 1, *to_index, *el_in); } }); - // send a(X) -> a(X^k) and decompose a(X^k) - izip!( - rlwe_in.get_row(0), - auto_map_index.iter(), - auto_map_sign.iter() - ) - .for_each(|(el_in, to_index, sign)| { - let el_out = if !*sign { mod_op.neg(el_in) } else { *el_in }; - - let el_out_decomposed = decomposer.decompose(&el_out); - for j in 0..d { - a_rlwe_decomposed.set(j, *to_index, el_out_decomposed[j]); - } - }); + if !rlwe_in.is_trivial() { + // send a(X) -> a(X^k) and decompose a(X^k) + izip!( + rlwe_in.get_row(0), + auto_map_index.iter(), + auto_map_sign.iter() + ) + .for_each(|(el_in, to_index, sign)| { + let el_out = if !*sign { mod_op.neg(el_in) } else { *el_in }; + + let el_out_decomposed = decomposer.decompose(&el_out); + for j in 0..d { + scratch_matrix_d_ring[j].as_mut()[*to_index] = el_out_decomposed[j]; + } + }); + + // transform decomposed a(X^k) to evaluation domain + scratch_matrix_d_ring.iter_mut().for_each(|r| { + ntt_op.forward(r.as_mut()); + }); + + // RLWE(m^k) = a', b'; RLWE(m) = a, b + // key switch: (a * RLWE'(s(X^k))) + let (ksk_a, ksk_b) = ksk.split_at_row(d); + tmp_rlwe_out[0].as_mut().fill(Mmut::MatElement::zero()); + // a' = decomp * RLWE'_A(s(X^k)) + routine::( + tmp_rlwe_out[0].as_mut(), + scratch_matrix_d_ring, + ksk_a, + mod_op, + ); + // send b(X^k) to evaluation domain + ntt_op.forward(tmp_rlwe_out[1].as_mut()); + // b' = b(X^k) + // b' += decomp * RLWE'_B(s(X^k)) + routine::( + tmp_rlwe_out[1].as_mut(), + scratch_matrix_d_ring, + ksk_b, + mod_op, + ); - // transform decomposed a(X^k) to evaluation domain - a_rlwe_decomposed.iter_rows_mut().for_each(|r| { - ntt_op.forward(r.as_mut()); - }); + // transform RLWE(m^k) to coefficient domain + tmp_rlwe_out + .iter_mut() + .for_each(|r| ntt_op.backward(r.as_mut())); - // key switch (a(X^k) * RLWE'(s(X^k))) - izip!(a_rlwe_decomposed.iter_rows(), ksk.iter_rows().take(d)).for_each(|(a, b)| { - mod_op.elwise_fma_mut(rlwe_out.get_row_mut(0), a.as_ref(), b.as_ref()); - }); - ntt_op.forward(rlwe_out.get_row_mut(1)); - izip!(a_rlwe_decomposed.iter_rows(), ksk.iter_rows().skip(d)).for_each(|(a, b)| { - mod_op.elwise_fma_mut(rlwe_out.get_row_mut(1), a.as_ref(), b.as_ref()); - }); + rlwe_in + .get_row_mut(0) + .copy_from_slice(tmp_rlwe_out[0].as_ref()); + } - // transform RLWE(-s(X^k) * a(X^k)) to coefficient domain - rlwe_out - .iter_rows_mut() - .for_each(|r| ntt_op.backward(r.as_mut())); + rlwe_in + .get_row_mut(1) + .copy_from_slice(tmp_rlwe_out[1].as_ref()); } /// Encrypts message m as a RGSW ciphertext. @@ -256,7 +325,7 @@ fn galois_auto< /// - out_rgsw: RGSW(m) is stored as single matrix of dimension (d_rgsw * 4, /// ring_size). The matrix has the following structure [RLWE'_A(-sm) || /// RLWE'_B(-sm) || RLWE'_A(m) || RLWE'_B(m)]^T -fn encrypt_rgsw< +pub(crate) fn encrypt_rgsw< Mmut: MatrixMut + MatrixEntity, M: Matrix + Clone, S: Secret, @@ -283,7 +352,7 @@ fn encrypt_rgsw< assert!(m_eval.dimension() == (1, ring_size)); // RLWE(-sm), RLWE(-sm) - let (rlwe_dash_nsm, rlwe_dash_m) = out_rgsw.split_at_row(d * 2); + let (rlwe_dash_nsm, rlwe_dash_m) = out_rgsw.split_at_row_mut(d * 2); let mut s_eval = Mmut::try_convert_from(s.values(), &q); ntt_op.forward(s_eval.get_row_mut(0).as_mut()); @@ -364,7 +433,7 @@ fn encrypt_rgsw< /// - rgsw_in: RGSW(m') in evaluation domain /// - rlwe_in_decomposed: decomposed RLWE(m) in evaluation domain /// - rlwe_out: returned RLWE(mm') in evaluation domain -fn rlwe_in_decomposed_evaluation_domain_mul_rgsw_rlwe_out_evaluation_domain< +pub(crate) fn rlwe_in_decomposed_evaluation_domain_mul_rgsw_rlwe_out_evaluation_domain< Mmut: MatrixMut + MatrixEntity, M: Matrix + Clone, ModOp: VectorOps, @@ -381,7 +450,7 @@ fn rlwe_in_decomposed_evaluation_domain_mul_rgsw_rlwe_out_evaluation_domain< assert!(rlwe_in_decomposed_eval.dimension() == (2 * d_rgsw, ring_size)); assert!(rlwe_out_eval.dimension() == (2, ring_size)); - let (a_rlwe_out, b_rlwe_out) = rlwe_out_eval.split_at_row(1); + let (a_rlwe_out, b_rlwe_out) = rlwe_out_eval.split_at_row_mut(1); // a * RLWE'(-sm) let a_rlwe_dash_nsm = rgsw_in.iter_rows().take(d_rgsw); @@ -420,82 +489,165 @@ fn rlwe_in_decomposed_evaluation_domain_mul_rgsw_rlwe_out_evaluation_domain< }); } -fn decompose_rlwe< - M: Matrix + Clone, - Mmut: MatrixMut + MatrixEntity, - D: Decomposer, ->( - rlwe_in: &M, +pub(crate) fn routine>( + write_to_row: &mut [M::MatElement], + matrix_a: &[M::R], + matrix_b: &[M::R], + mod_op: &ModOp, +) { + izip!(matrix_a.iter(), matrix_b.iter()).for_each(|(a, b)| { + mod_op.elwise_fma_mut(write_to_row, a.as_ref(), b.as_ref()); + }); +} + +// pub(crate) fn decompose_rlwe< +// M: Matrix + Clone, +// Mmut: MatrixMut + MatrixEntity, +// D: Decomposer, +// >( +// rlwe_in: &M, +// decomposer: &D, +// rlwe_in_decomposed: &mut Mmut, +// ) where +// M::MatElement: Copy, +// ::R: RowMut, +// { +// let d_rgsw = decomposer.d(); +// let ring_size = rlwe_in.dimension().1; +// assert!(rlwe_in_decomposed.dimension() == (2 * d_rgsw, ring_size)); + +// // Decompose rlwe_in +// for ri in 0..ring_size { +// // ai +// let ai_decomposed = decomposer.decompose(rlwe_in.get(0, ri)); +// for j in 0..d_rgsw { +// rlwe_in_decomposed.set(j, ri, ai_decomposed[j]); +// } + +// // bi +// let bi_decomposed = decomposer.decompose(rlwe_in.get(1, ri)); +// for j in 0..d_rgsw { +// rlwe_in_decomposed.set(j + d_rgsw, ri, bi_decomposed[j]); +// } +// } +// } + +/// Decomposes ring polynomial r(X) into d polynomials using decomposer into +/// output matrix decomp_r +/// +/// Note that decomposition of r(X) requires decomposition of each of +/// coefficients. +/// +/// - decomp_r: must have dimensions d x ring_size. i^th decomposed polynomial +/// will be stored at i^th row. +pub(crate) fn decompose_r>( + r: &[M::MatElement], + decomp_r: &mut [M::R], decomposer: &D, - rlwe_in_decomposed: &mut Mmut, ) where + ::R: RowMut, M::MatElement: Copy, - ::R: RowMut, { - let d_rgsw = decomposer.d(); - let ring_size = rlwe_in.dimension().1; - assert!(rlwe_in_decomposed.dimension() == (2 * d_rgsw, ring_size)); + let ring_size = r.len(); + let d = decomposer.d(); - // Decompose rlwe_in for ri in 0..ring_size { - // ai - let ai_decomposed = decomposer.decompose(rlwe_in.get(0, ri)); - for j in 0..d_rgsw { - rlwe_in_decomposed.set(j, ri, ai_decomposed[j]); - } - - // bi - let bi_decomposed = decomposer.decompose(rlwe_in.get(1, ri)); - for j in 0..d_rgsw { - rlwe_in_decomposed.set(j + d_rgsw, ri, bi_decomposed[j]); + let el_decomposed = decomposer.decompose(&r[ri]); + for j in 0..d { + decomp_r[j].as_mut()[ri] = el_decomposed[j]; } } } -/// Returns RLWE(m0m1) = RLWE(m0) x RGSW(m1) +/// Returns RLWE(m0m1) = RLWE(m0) x RGSW(m1). Mutates rlwe_in inplace to equal +/// RLWE(m0m1) /// /// - rlwe_in: is RLWE(m0) with polynomials in coefficient domain /// - rgsw_in: is RGSW(m1) with polynomials in evaluation domain -/// - rlwe_out: is output RLWE(m0m1) with polynomials in coefficient domain -/// - rlwe_in_decomposed: is a matrix of dimension (d_rgsw * 2, ring_size) used -/// as scratch space to store decomposed RLWE(m0) -fn rlwe_by_rgsw< - M: Matrix + Clone, - Mmut: MatrixMut + MatrixEntity, - D: Decomposer, - ModOp: VectorOps, - NttOp: Ntt, +/// - scratch_matrix_d_ring: is a matrix of dimension (d_rgsw, ring_size) used +/// as scratch space to store decomposed Ring elements temporarily +pub(crate) fn rlwe_by_rgsw< + Mmut: MatrixMut, + MT: Matrix + MatrixMut + IsTrivial, + D: Decomposer, + ModOp: VectorOps, + NttOp: Ntt, >( - rlwe_in: &M, - rgsw_in: &M, - rlwe_out: &mut Mmut, - rlwe_in_decomposed: &mut Mmut, + rlwe_in: &mut MT, + rgsw_in: &Mmut, + scratch_matrix_dplus2_ring: &mut Mmut, decomposer: &D, ntt_op: &NttOp, mod_op: &ModOp, ) where - M::MatElement: Copy, + Mmut::MatElement: Copy + Zero, ::R: RowMut, + ::R: RowMut, { - decompose_rlwe(rlwe_in, decomposer, rlwe_in_decomposed); - - // transform rlwe_in decomposed to evaluation domain - rlwe_in_decomposed - .iter_rows_mut() - .for_each(|r| ntt_op.forward(r.as_mut())); + let d_rgsw = decomposer.d(); + assert!(scratch_matrix_dplus2_ring.dimension() == (d_rgsw + 2, rlwe_in.dimension().1)); // decomposed RLWE x RGSW - rlwe_in_decomposed_evaluation_domain_mul_rgsw_rlwe_out_evaluation_domain( - rgsw_in, - rlwe_in_decomposed, - rlwe_out, + let (rlwe_dash_nsm, rlwe_dash_m) = rgsw_in.split_at_row(d_rgsw * 2); + let (scratch_matrix_d_ring, scratch_rlwe_out) = + scratch_matrix_dplus2_ring.split_at_row_mut(d_rgsw); + scratch_rlwe_out[0].as_mut().fill(Mmut::MatElement::zero()); + scratch_rlwe_out[1].as_mut().fill(Mmut::MatElement::zero()); + // RLWE_in = a_in, b_in; RLWE_out = a_out, b_out + if !rlwe_in.is_trivial() { + // a_in = 0 when RLWE_in is trivial RLWE ciphertext + // decomp + decompose_r::(rlwe_in.get_row_slice(0), scratch_matrix_d_ring, decomposer); + scratch_matrix_d_ring + .iter_mut() + .for_each(|r| ntt_op.forward(r.as_mut())); + // a_out += decomp \cdot RLWE_A'(-sm) + routine::( + scratch_rlwe_out[0].as_mut(), + scratch_matrix_d_ring.as_ref(), + &rlwe_dash_nsm[..d_rgsw], + mod_op, + ); + // b_out += decomp \cdot RLWE_B'(-sm) + routine::( + scratch_rlwe_out[1].as_mut(), + scratch_matrix_d_ring.as_ref(), + &rlwe_dash_nsm[d_rgsw..], + mod_op, + ); + } + // decomp + decompose_r::(rlwe_in.get_row_slice(1), scratch_matrix_d_ring, decomposer); + scratch_matrix_d_ring + .iter_mut() + .for_each(|r| ntt_op.forward(r.as_mut())); + // a_out += decomp \cdot RLWE_A'(m) + routine::( + scratch_rlwe_out[0].as_mut(), + scratch_matrix_d_ring.as_ref(), + &rlwe_dash_m[..d_rgsw], + mod_op, + ); + // b_out += decomp \cdot RLWE_B'(m) + routine::( + scratch_rlwe_out[1].as_mut(), + scratch_matrix_d_ring.as_ref(), + &rlwe_dash_m[d_rgsw..], mod_op, ); // transform rlwe_out to coefficient domain - rlwe_out - .iter_rows_mut() + scratch_rlwe_out + .iter_mut() .for_each(|r| ntt_op.backward(r.as_mut())); + + rlwe_in + .get_row_mut(0) + .copy_from_slice(scratch_rlwe_out[0].as_mut()); + rlwe_in + .get_row_mut(1) + .copy_from_slice(scratch_rlwe_out[1].as_mut()); + rlwe_in.set_not_trivial(); } /// Encrypt polynomial m(X) as RLWE ciphertext. @@ -503,7 +655,7 @@ fn rlwe_by_rgsw< /// - rlwe_out: returned RLWE ciphertext RLWE(m) in coefficient domain. RLWE /// ciphertext is a matirx with first row consiting polynomial `a` and the /// second rows consting polynomial `b` -fn encrypt_rlwe< +pub(crate) fn encrypt_rlwe< Mmut: Matrix + MatrixMut + Clone, ModOp: VectorOps, NttOp: Ntt, @@ -547,7 +699,7 @@ fn encrypt_rlwe< /// Decrypts degree 1 RLWE ciphertext RLWE(m) and returns m /// /// - rlwe_ct: input degree 1 ciphertext RLWE(m). -fn decrypt_rlwe< +pub(crate) fn decrypt_rlwe< Mmut: MatrixMut + Clone, M: Matrix, ModOp: VectorOps, @@ -587,7 +739,7 @@ fn decrypt_rlwe< // Measures noise in degree 1 RLWE ciphertext against encoded ideal message // encoded_m -fn measure_noise< +pub(crate) fn measure_noise< Mmut: MatrixMut + Matrix + MatrixEntity, ModOp: VectorOps, NttOp: Ntt, @@ -657,8 +809,9 @@ mod tests { decomposer::{gadget_vector, DefaultDecomposer}, ntt::{self, Ntt, NttBackendU64}, random::{DefaultSecureRng, RandomUniformDist}, - rgsw::measure_noise, + rgsw::{measure_noise, RlweCiphertext}, utils::{generate_prime, negacyclic_mul}, + Matrix, }; use super::{ @@ -718,15 +871,14 @@ mod tests { &ntt_op, &mut rng, ); + let mut rlwe_in_ct = RlweCiphertext(rlwe_in_ct, false); // RLWE(m0m1) = RLWE(m0) x RGSW(m1) - let mut rlwe_out_ct = vec![vec![0u64; ring_size as usize]; 2]; - let mut scratch_space = vec![vec![0u64; ring_size as usize]; d_rgsw * 2]; + let mut scratch_space = vec![vec![0u64; ring_size as usize]; d_rgsw + 2]; let decomposer = DefaultDecomposer::new(q, logb, d_rgsw); rlwe_by_rgsw( - &rlwe_in_ct, + &mut rlwe_in_ct, &rgsw_ct, - &mut rlwe_out_ct, &mut scratch_space, &decomposer, &ntt_op, @@ -735,7 +887,7 @@ mod tests { // Decrypt RLWE(m0m1) let mut encoded_m0m1_back = vec![vec![0u64; ring_size as usize]]; - decrypt_rlwe(&rlwe_out_ct, &s, &mut encoded_m0m1_back, &ntt_op, &mod_op); + decrypt_rlwe(&rlwe_in_ct, &s, &mut encoded_m0m1_back, &ntt_op, &mod_op); let m0m1_back = encoded_m0m1_back[0] .iter() .map(|v| (((*v as f64 * p as f64) / (q as f64)).round() as u64) % p) @@ -797,14 +949,13 @@ mod tests { ); // Send RLWE_{s}(m) -> RLWE_{s}(m^k) - let mut rlwe_m_k = vec![vec![0u64; ring_size as usize]; 2]; - let mut scratch_space = vec![vec![0u64; ring_size as usize]; d_rgsw]; + let mut rlwe_m = RlweCiphertext(rlwe_m, false); + let mut scratch_space = vec![vec![0u64; ring_size as usize]; d_rgsw + 2]; let (auto_map_index, auto_map_sign) = generate_auto_map(ring_size as usize, auto_k); let decomposer = DefaultDecomposer::new(q, logb, d_rgsw); galois_auto( - &rlwe_m, + &mut rlwe_m, &ksk_out, - &mut rlwe_m_k, &mut scratch_space, &auto_map_index, &auto_map_sign, @@ -813,6 +964,8 @@ mod tests { &decomposer, ); + let rlwe_m_k = rlwe_m; + // Decrypt RLWE_{s}(m^k) and check let mut encoded_m_k_back = vec![vec![0u64; ring_size as usize]]; decrypt_rlwe(&rlwe_m_k, &s, &mut encoded_m_k_back, &ntt_op, &mod_op); @@ -834,13 +987,13 @@ mod tests { ); { - let encoded_m_k = m_k - .iter() - .map(|v| ((*v as f64 * q as f64) / p as f64).round() as u64) - .collect_vec(); + // let encoded_m_k = m_k + // .iter() + // .map(|v| ((*v as f64 * q as f64) / p as f64).round() as u64) + // .collect_vec(); - let noise = measure_noise(&rlwe_m_k, &vec![encoded_m_k], &ntt_op, &mod_op, &s); - println!("Ksk noise: {noise}"); + // let noise = measure_noise(&rlwe_m_k, &vec![encoded_m_k], &ntt_op, + // &mod_op, &s); println!("Ksk noise: {noise}"); } // FIXME(Jay): Galios autormophism will incur high error unless we fix in diff --git a/src/utils.rs b/src/utils.rs index 40d007a..7fcce6e 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -189,3 +189,24 @@ impl TryConvertFrom<[i32]> for Vec> { vec![row0] } } + +impl TryConvertFrom<[i32]> for Vec { + type Parameters = u64; + fn try_convert_from(value: &[i32], parameters: &Self::Parameters) -> Self { + value + .iter() + .map(|v| { + let is_neg = v.is_negative(); + let v_u64 = v.abs() as u64; + + assert!(v_u64 < *parameters); + + if is_neg { + parameters - v_u64 + } else { + v_u64 + } + }) + .collect_vec() + } +}