diff --git a/src/rgsw/mod.rs b/src/rgsw/mod.rs index 21e21ce..14bfee7 100644 --- a/src/rgsw/mod.rs +++ b/src/rgsw/mod.rs @@ -329,12 +329,12 @@ pub struct ShoupRgswCiphertextEvaluationDomain { } impl, R, N> - From> for ShoupRgswCiphertextEvaluationDomain + From<&RgswCiphertextEvaluationDomain> for ShoupRgswCiphertextEvaluationDomain where M::R: RowMut, M::MatElement: ToShoup + Copy, { - fn from(value: RgswCiphertextEvaluationDomain) -> Self { + fn from(value: &RgswCiphertextEvaluationDomain) -> Self { let (row, col) = value.data.dimension(); let mut shoup_data = M::zeros(row, col); @@ -534,7 +534,10 @@ pub(crate) mod tests { decomposer::{Decomposer, DefaultDecomposer, RlweDecomposer}, ntt::{self, Ntt, NttBackendU64, NttInit}, random::{DefaultSecureRng, NewWithSeed, RandomFillUniformInModulus}, - rgsw::{galois_auto_shoup, ShoupAutoKeyEvaluationDomain}, + rgsw::{ + galois_auto_shoup, rlwe_by_rgsw_shoup, ShoupAutoKeyEvaluationDomain, + ShoupRgswCiphertextEvaluationDomain, + }, utils::{generate_prime, negacyclic_mul, Stats, TryConvertFrom1}, Matrix, Secret, }; @@ -846,14 +849,46 @@ pub(crate) mod tests { decomposer.b().decomposition_count() ) + 2 ]; - rlwe_by_rgsw( - &mut rlwe_in_ct, - &rgsw_ct.data, - &mut scratch_space, - &decomposer, - &ntt_op, - &mod_op, - ); + + // rlwe x rgsw with additional RGSW ciphertexts in shoup repr + let rlwe_in_ct_shoup = { + let mut rlwe_in_ct_shoup = RlweCiphertext::<_, DefaultSecureRng> { + data: rlwe_in_ct.data.clone(), + is_trivial: rlwe_in_ct.is_trivial, + _phatom: PhantomData::default(), + }; + + let rgsw_ct_shoup = ShoupRgswCiphertextEvaluationDomain::from(&rgsw_ct); + + rlwe_by_rgsw_shoup( + &mut rlwe_in_ct_shoup, + &rgsw_ct.data, + &rgsw_ct_shoup.data, + &mut scratch_space, + &decomposer, + &ntt_op, + &mod_op, + ); + + rlwe_in_ct_shoup + }; + + // rlwe x rgsw normal + { + rlwe_by_rgsw( + &mut rlwe_in_ct, + &rgsw_ct.data, + &mut scratch_space, + &decomposer, + &ntt_op, + &mod_op, + ); + } + + // output from both functions must be equal + { + assert_eq!(rlwe_in_ct.data, rlwe_in_ct_shoup.data); + } // Decrypt RLWE(m0m1) let mut encoded_m0m1_back = vec![0u64; ring_size as usize]; diff --git a/src/rgsw/runtime.rs b/src/rgsw/runtime.rs index 28cb475..77bbe16 100644 --- a/src/rgsw/runtime.rs +++ b/src/rgsw/runtime.rs @@ -345,9 +345,12 @@ pub(crate) fn rlwe_by_rgsw< // decomposed RLWE x RGSW let (rlwe_dash_nsm, rlwe_dash_m) = rgsw_in.split_at_row(d_a * 2); - let (scratch_matrix_d_ring, scratch_rlwe_out) = scratch_matrix.split_at_row_mut(max_d); + let (scratch_matrix_d_ring, rest) = scratch_matrix.split_at_row_mut(max_d); + let (scratch_rlwe_out, _) = rest.split_at_mut(2); + scratch_rlwe_out[0].as_mut().fill(Mmut::MatElement::zero()); scratch_rlwe_out[1].as_mut().fill(Mmut::MatElement::zero()); + // RLWE_in = a_in, b_in; RLWE_out = a_out, b_out if !rlwe_in.is_trivial() { // a_in = 0 when RLWE_in is trivial RLWE ciphertext @@ -364,14 +367,14 @@ pub(crate) fn rlwe_by_rgsw< // a_out += decomp \cdot RLWE_A'(-sm) routine( scratch_rlwe_out[0].as_mut(), - scratch_matrix_d_ring.as_ref(), + &scratch_matrix_d_ring[..d_a], &rlwe_dash_nsm[..d_a], mod_op, ); // b_out += decomp \cdot RLWE_B'(-sm) routine( scratch_rlwe_out[1].as_mut(), - scratch_matrix_d_ring.as_ref(), + &scratch_matrix_d_ring[..d_a], &rlwe_dash_nsm[d_a..], mod_op, ); @@ -389,14 +392,14 @@ pub(crate) fn rlwe_by_rgsw< // a_out += decomp \cdot RLWE_A'(m) routine( scratch_rlwe_out[0].as_mut(), - scratch_matrix_d_ring.as_ref(), + &scratch_matrix_d_ring[..d_b], &rlwe_dash_m[..d_b], mod_op, ); // b_out += decomp \cdot RLWE_B'(m) routine( scratch_rlwe_out[1].as_mut(), - scratch_matrix_d_ring.as_ref(), + &scratch_matrix_d_ring[..d_b], &rlwe_dash_m[d_b..], mod_op, ); @@ -415,6 +418,116 @@ pub(crate) fn rlwe_by_rgsw< rlwe_in.set_not_trivial(); } +pub(crate) fn rlwe_by_rgsw_shoup< + Mmut: MatrixMut, + MT: Matrix + MatrixMut + IsTrivial, + D: RlweDecomposer, + ModOp: VectorOps + ShoupMatrixFMA, + NttOp: Ntt, +>( + rlwe_in: &mut MT, + rgsw_in: &Mmut, + rgsw_in_shoup: &Mmut, + scratch_matrix: &mut Mmut, + decomposer: &D, + ntt_op: &NttOp, + mod_op: &ModOp, +) where + Mmut::MatElement: Copy + Zero, + ::R: RowMut, + ::R: RowMut, +{ + let decomposer_a = decomposer.a(); + let decomposer_b = decomposer.b(); + let d_a = decomposer_a.decomposition_count(); + let d_b = decomposer_b.decomposition_count(); + let max_d = std::cmp::max(d_a, d_b); + assert!(scratch_matrix.fits(max_d + 2, rlwe_in.dimension().1)); + assert!(rgsw_in.dimension() == (d_a * 2 + d_b * 2, rlwe_in.dimension().1)); + assert!(rgsw_in.dimension() == rgsw_in_shoup.dimension()); + + // decomposed RLWE x RGSW + let (rlwe_dash_nsm, rlwe_dash_m) = rgsw_in.split_at_row(d_a * 2); + let (rlwe_dash_nsm_shoup, rlwe_dash_m_shoup) = rgsw_in_shoup.split_at_row(d_a * 2); + let (scratch_matrix_d_ring, rest) = scratch_matrix.split_at_row_mut(max_d); + let (scratch_rlwe_out, _) = rest.split_at_mut(2); + + scratch_rlwe_out[1].as_mut().fill(Mmut::MatElement::zero()); + scratch_rlwe_out[0].as_mut().fill(Mmut::MatElement::zero()); + + // RLWE_in = a_in, b_in; RLWE_out = a_out, b_out + if !rlwe_in.is_trivial() { + // a_in = 0 when RLWE_in is trivial RLWE ciphertext + // decomp + decompose_r( + rlwe_in.get_row_slice(0), + &mut scratch_matrix_d_ring[..d_a], + decomposer_a, + ); + scratch_matrix_d_ring + .iter_mut() + .take(d_a) + .for_each(|r| ntt_op.forward_lazy(r.as_mut())); + + // a_out += decomp \cdot RLWE_A'(-sm) + mod_op.shoup_matrix_fma( + scratch_rlwe_out[0].as_mut(), + &rlwe_dash_nsm[..d_a], + &rlwe_dash_nsm_shoup[..d_a], + &scratch_matrix_d_ring[..d_a], + ); + + // b_out += decomp \cdot RLWE_B'(-sm) + mod_op.shoup_matrix_fma( + scratch_rlwe_out[1].as_mut(), + &rlwe_dash_nsm[d_a..], + &rlwe_dash_nsm_shoup[d_a..], + &scratch_matrix_d_ring[..d_a], + ); + } + { + // decomp + decompose_r( + rlwe_in.get_row_slice(1), + &mut scratch_matrix_d_ring[..d_b], + decomposer_b, + ); + scratch_matrix_d_ring + .iter_mut() + .take(d_b) + .for_each(|r| ntt_op.forward_lazy(r.as_mut())); + + // a_out += decomp \cdot RLWE_A'(m) + mod_op.shoup_matrix_fma( + scratch_rlwe_out[0].as_mut(), + &rlwe_dash_m[..d_b], + &rlwe_dash_m_shoup[..d_b], + &scratch_matrix_d_ring[..d_b], + ); + + // b_out += decomp \cdot RLWE_B'(m) + mod_op.shoup_matrix_fma( + scratch_rlwe_out[1].as_mut(), + &rlwe_dash_m[d_b..], + &rlwe_dash_m_shoup[d_b..], + &scratch_matrix_d_ring[..d_b], + ); + } + + // transform rlwe_out to coefficient domain + scratch_rlwe_out + .iter_mut() + .for_each(|r| ntt_op.backward(r.as_mut())); + + rlwe_in + .get_row_mut(0) + .copy_from_slice(scratch_rlwe_out[0].as_mut()); + rlwe_in + .get_row_mut(1) + .copy_from_slice(scratch_rlwe_out[1].as_mut()); + rlwe_in.set_not_trivial(); +} + /// Inplace mutates rlwe_0 to equal RGSW(m0m1) = RGSW(m0)xRGSW(m1) /// in evaluation domain ///