From 1ff98541c87fcbfbc225b5e2ead96090394ad6d5 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Sun, 30 Jun 2024 11:17:18 +0530 Subject: [PATCH] implement DoubleDecomposer for Rlwe Decomposer --- src/bool/evaluator.rs | 185 ++++++++++++++++++++++------------------ src/bool/parameters.rs | 44 ++++++---- src/bool/print_noise.rs | 17 ++-- src/decomposer.rs | 55 ++++++++++-- src/lwe.rs | 2 +- src/pbs.rs | 6 +- src/rgsw/mod.rs | 39 ++++----- src/rgsw/runtime.rs | 70 ++++++++------- 8 files changed, 251 insertions(+), 167 deletions(-) diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index 4d0a228..c1197ef 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -26,8 +26,9 @@ use crate::{ }, rgsw::{ decrypt_rlwe, generate_auto_map, public_key_encrypt_rgsw, rgsw_by_rgsw_inplace, - rgsw_x_rgsw_scratch_rows, rlwe_auto, secret_key_encrypt_rgsw, seeded_auto_key_gen, - RgswCiphertext, RgswCiphertextMutRef, RgswCiphertextRef, RuntimeScratchMutRef, + rgsw_x_rgsw_scratch_rows, rlwe_auto, rlwe_auto_scratch_rows, rlwe_x_rgsw_scratch_rows, + secret_key_encrypt_rgsw, seeded_auto_key_gen, RgswCiphertext, RgswCiphertextMutRef, + RgswCiphertextRef, RuntimeScratchMutRef, }, utils::{ encode_x_pow_si_with_emebedding_factor, fill_random_ternary_secret_with_hamming_weight, @@ -237,16 +238,17 @@ where // Vector to store LWE ciphertext with LWE dimesnion n let lwe_vector = M::R::zeros(parameters.lwe_n().0 + 1); - // Matrix to store decomposed polynomials - // Max decompistion count + space for temporary RLWE - let d = std::cmp::max( - parameters.auto_decomposition_count().0, + // PBS perform two operations at runtime: RLWE x RGW and RLWE auto. Since the + // operations are performed serially same scratch space can be used for both. + // Hence we create scratch space that contains maximum amount of rows that + // suffices for RLWE x RGSW and RLWE auto + let decomposition_matrix = M::zeros( std::cmp::max( - parameters.rlwe_rgsw_decomposition_count().0 .0, - parameters.rlwe_rgsw_decomposition_count().1 .0, + rlwe_x_rgsw_scratch_rows(parameters.rlwe_by_rgsw_decomposition_params()), + rlwe_auto_scratch_rows(parameters.auto_decomposition_param()), ), - ) + 2; - let decomposition_matrix = M::zeros(d, parameters.rlwe_n().0); + parameters.rlwe_n().0, + ); Self { lwe_vector, @@ -546,24 +548,30 @@ where ::R: RowMut + Clone, { let max_decomposer = - if decomposer.a().decomposition_count() > decomposer.b().decomposition_count() { + if decomposer.a().decomposition_count().0 > decomposer.b().decomposition_count().0 { decomposer.a() } else { decomposer.b() }; assert!( - ni_rgsw_ct.dimension() == (max_decomposer.decomposition_count(), parameters.rlwe_n().0) + ni_rgsw_ct.dimension() + == ( + max_decomposer.decomposition_count().0, + parameters.rlwe_n().0 + ) + ); + assert!( + aggregated_decomposed_ni_rgsw_zero_encs.len() == decomposer.a().decomposition_count().0, ); - assert!(aggregated_decomposed_ni_rgsw_zero_encs.len() == decomposer.a().decomposition_count(),); - assert!(decomposed_neg_ais.len() == decomposer.b().decomposition_count()); + assert!(decomposed_neg_ais.len() == decomposer.b().decomposition_count().0); let mut rgsw_i = M::zeros( - decomposer.a().decomposition_count() * 2 + decomposer.b().decomposition_count() * 2, + decomposer.a().decomposition_count().0 * 2 + decomposer.b().decomposition_count().0 * 2, parameters.rlwe_n().0, ); let (rlwe_dash_nsm, rlwe_dash_m) = - rgsw_i.split_at_row_mut(decomposer.a().decomposition_count() * 2); + rgsw_i.split_at_row_mut(decomposer.a().decomposition_count().0 * 2); // RLWE'_{s}(-sm) // Key switch `s * a_{i, l} + e` using ksk(u_j -> s) to produce RLWE(s * @@ -573,13 +581,13 @@ where // RLWE(s * u_{j=user_id} * a_{i, l}) { let (rlwe_dash_nsm_parta, rlwe_dash_nsm_partb) = - rlwe_dash_nsm.split_at_mut(decomposer.a().decomposition_count()); + rlwe_dash_nsm.split_at_mut(decomposer.a().decomposition_count().0); izip!( rlwe_dash_nsm_parta.iter_mut(), rlwe_dash_nsm_partb.iter_mut(), - ni_rgsw_ct - .iter_rows() - .skip(max_decomposer.decomposition_count() - decomposer.a().decomposition_count()), + ni_rgsw_ct.iter_rows().skip( + max_decomposer.decomposition_count().0 - decomposer.a().decomposition_count().0 + ), aggregated_decomposed_ni_rgsw_zero_encs.iter() ) .for_each(|(rlwe_a, rlwe_b, ni_rlwe_ct, decomp_zero_enc)| { @@ -612,13 +620,13 @@ where // RLWE'_{s}(m) { let (rlwe_dash_m_parta, rlwe_dash_partb) = - rlwe_dash_m.split_at_mut(decomposer.b().decomposition_count()); + rlwe_dash_m.split_at_mut(decomposer.b().decomposition_count().0); izip!( rlwe_dash_m_parta.iter_mut(), rlwe_dash_partb.iter_mut(), - ni_rgsw_ct - .iter_rows() - .skip(max_decomposer.decomposition_count() - decomposer.b().decomposition_count()), + ni_rgsw_ct.iter_rows().skip( + max_decomposer.decomposition_count().0 - decomposer.b().decomposition_count().0 + ), decomposed_neg_ais.iter() ) .for_each(|(rlwe_a, rlwe_b, ni_rlwe_ct, decomp_neg_ai)| { @@ -860,7 +868,10 @@ where } else { (g.pow(i as u32) % br_q) as isize }; - let mut gk = M::zeros(self.pbs_info.auto_decomposer.decomposition_count(), rlwe_n); + let mut gk = M::zeros( + self.pbs_info.auto_decomposer.decomposition_count().0, + rlwe_n, + ); seeded_auto_key_gen( &mut gk, &sk_rlwe, @@ -878,8 +889,8 @@ where let ring_size = self.pbs_info.parameters.rlwe_n().0; let rlwe_q = self.pbs_info.parameters.rlwe_q(); let (rlrg_d_a, rlrg_d_b) = ( - self.pbs_info.rlwe_rgsw_decomposer.0.decomposition_count(), - self.pbs_info.rlwe_rgsw_decomposer.1.decomposition_count(), + self.pbs_info.rlwe_rgsw_decomposer.0.decomposition_count().0, + self.pbs_info.rlwe_rgsw_decomposer.1.decomposition_count().0, ); let rlrg_gadget_a = self.pbs_info.rlwe_rgsw_decomposer.0.gadget_vector(); let rlrg_gadget_b = self.pbs_info.rlwe_rgsw_decomposer.1.gadget_vector(); @@ -997,7 +1008,7 @@ where rlrg_decomposer.b().gadget_vector(), ); for s_index in segment_start..segment_end { - let mut out_rgsw = M::zeros(rlrg_d_a * 2 + rlrg_d_b * 2, ring_size); + let mut out_rgsw = M::zeros(rlrg_d_a.0 * 2 + rlrg_d_b.0 * 2, ring_size); public_key_encrypt_rgsw( &mut out_rgsw, &encode_x_pow_si_with_emebedding_factor::< @@ -1039,7 +1050,7 @@ where ); for s_index in (0..segment_start).chain(segment_end..self.parameters().lwe_n().0) { - let mut out_rgsw = M::zeros(rgrg_d_a * 2 + rgrg_d_b * 2, ring_size); + let mut out_rgsw = M::zeros(rgrg_d_a.0 * 2 + rgrg_d_b.0 * 2, ring_size); public_key_encrypt_rgsw( &mut out_rgsw, &encode_x_pow_si_with_emebedding_factor::< @@ -1147,13 +1158,13 @@ where parameters.rgsw_rgsw_decomposer::>(); let rlwe_x_rgsw_decomposer = self.pbs_info().rlwe_rgsw_decomposer(); let rgsw_x_rgsw_dimension = ( - rgsw_x_rgsw_decomposer.a().decomposition_count() * 2 - + rgsw_x_rgsw_decomposer.b().decomposition_count() * 2, + rgsw_x_rgsw_decomposer.a().decomposition_count().0 * 2 + + rgsw_x_rgsw_decomposer.b().decomposition_count().0 * 2, rlwe_n, ); let rlwe_x_rgsw_dimension = ( - rlwe_x_rgsw_decomposer.a().decomposition_count() * 2 - + rlwe_x_rgsw_decomposer.b().decomposition_count() * 2, + rlwe_x_rgsw_decomposer.a().decomposition_count().0 * 2 + + rlwe_x_rgsw_decomposer.b().decomposition_count().0 * 2, rlwe_n, ); @@ -1213,13 +1224,13 @@ where rgsw_by_rgsw_inplace( &mut RgswCiphertextMutRef::new( rgsw_i.as_mut(), - rlwe_x_rgsw_decomposer.a().decomposition_count(), - rlwe_x_rgsw_decomposer.b().decomposition_count(), + rlwe_x_rgsw_decomposer.a().decomposition_count().0, + rlwe_x_rgsw_decomposer.b().decomposition_count().0, ), &RgswCiphertextRef::new( other_rgsw_i.as_ref(), - rgsw_x_rgsw_decomposer.a().decomposition_count(), - rgsw_x_rgsw_decomposer.b().decomposition_count(), + rgsw_x_rgsw_decomposer.a().decomposition_count().0, + rgsw_x_rgsw_decomposer.b().decomposition_count().0, ), rlwe_x_rgsw_decomposer, &rgsw_x_rgsw_decomposer, @@ -1306,7 +1317,7 @@ where let mut useri_ui_to_s_ksk = share.ui_to_s_ksk().clone(); assert!( useri_ui_to_s_ksk.dimension() - == (ni_uj_to_s_decomposer.decomposition_count(), ring_size) + == (ni_uj_to_s_decomposer.decomposition_count().0, ring_size) ); key_shares .iter() @@ -1315,7 +1326,7 @@ where let op2 = other_share.ui_to_s_ksk_zero_encs_for_user_i(share.user_index()); assert!( op2.dimension() - == (ni_uj_to_s_decomposer.decomposition_count(), ring_size) + == (ni_uj_to_s_decomposer.decomposition_count().0, ring_size) ); izip!(useri_ui_to_s_ksk.iter_rows_mut(), op2.iter_rows()).for_each( |(add_to, add_from)| { @@ -1341,7 +1352,8 @@ where let mut ksk_prng = DefaultSecureRng::new_seeded( cr_seed.ui_to_s_ks_seed_for_user_i::(share.user_index()), ); - let mut ais = M::zeros(ni_uj_to_s_decomposer.decomposition_count(), ring_size); + let mut ais = + M::zeros(ni_uj_to_s_decomposer.decomposition_count().0, ring_size); ais.iter_rows_mut().for_each(|r_ai| { RandomFillUniformInModulus::random_fill( @@ -1363,12 +1375,12 @@ where .parameters() .rlwe_rgsw_decomposer::>(); - let d_max = if rgsw_x_rgsw_decomposer.a().decomposition_count() - > rgsw_x_rgsw_decomposer.b().decomposition_count() + let d_max = if rgsw_x_rgsw_decomposer.a().decomposition_count().0 + > rgsw_x_rgsw_decomposer.b().decomposition_count().0 { - rgsw_x_rgsw_decomposer.a().decomposition_count() + rgsw_x_rgsw_decomposer.a().decomposition_count().0 } else { - rgsw_x_rgsw_decomposer.b().decomposition_count() + rgsw_x_rgsw_decomposer.b().decomposition_count().0 }; let mut scratch_rgsw_x_rgsw = M::zeros( @@ -1416,19 +1428,19 @@ where ); let mut scratch = M::R::zeros(self.parameters().rlwe_n().0); - (0..d_max - rgsw_x_rgsw_decomposer.b().decomposition_count()).for_each( - |_| { + (0..d_max - rgsw_x_rgsw_decomposer.b().decomposition_count().0) + .for_each(|_| { RandomFillUniformInModulus::random_fill( &mut a_prng, rlwe_q, scratch.as_mut(), ); - }, - ); + }); let decomp_neg_ais = (0..rgsw_x_rgsw_decomposer .b() - .decomposition_count()) + .decomposition_count() + .0) .map(|_| { RandomFillUniformInModulus::random_fill( &mut a_prng, @@ -1438,7 +1450,7 @@ where rlwe_modop.elwise_neg_mut(scratch.as_mut()); let mut decomp_neg_ai = M::zeros( - ni_uj_to_s_decomposer.decomposition_count(), + ni_uj_to_s_decomposer.decomposition_count().0, self.parameters().rlwe_n().0, ); scratch.as_ref().iter().enumerate().for_each(|(index, el)| { @@ -1468,7 +1480,8 @@ where // prepare for key switching let ni_rgsw_zero_encs = (0..rgsw_x_rgsw_decomposer .a() - .decomposition_count()) + .decomposition_count() + .0) .map(|i| { let mut sum = M::R::zeros(self.parameters().rlwe_n().0); key_shares.iter().for_each(|k| { @@ -1481,7 +1494,7 @@ where // decompose let mut decomp_sum = M::zeros( - ni_uj_to_s_decomposer.decomposition_count(), + ni_uj_to_s_decomposer.decomposition_count().0, self.parameters().rlwe_n().0, ); sum.as_ref().iter().enumerate().for_each(|(index, el)| { @@ -1513,9 +1526,13 @@ where &ni_rgsw_zero_encs[rgsw_x_rgsw_decomposer .a() .decomposition_count() - - rlwe_x_rgsw_decomposer.a().decomposition_count()..], - &decomp_neg_ais[rgsw_x_rgsw_decomposer.b().decomposition_count() - - rlwe_x_rgsw_decomposer.b().decomposition_count()..], + .0 + - rlwe_x_rgsw_decomposer.a().decomposition_count().0..], + &decomp_neg_ais[rgsw_x_rgsw_decomposer + .b() + .decomposition_count() + .0 + - rlwe_x_rgsw_decomposer.b().decomposition_count().0..], &rlwe_x_rgsw_decomposer, self.parameters(), (&uj_to_s_ksks[user_id], &uj_to_s_ksks_part_a_eval[user_id]), @@ -1551,13 +1568,13 @@ where rgsw_by_rgsw_inplace( &mut RgswCiphertextMutRef::new( rgsw_i.as_mut(), - rlwe_x_rgsw_decomposer.a().decomposition_count(), - rlwe_x_rgsw_decomposer.b().decomposition_count(), + rlwe_x_rgsw_decomposer.a().decomposition_count().0, + rlwe_x_rgsw_decomposer.b().decomposition_count().0, ), &RgswCiphertextRef::new( other_rgsw_i.as_ref(), - rgsw_x_rgsw_decomposer.a().decomposition_count(), - rgsw_x_rgsw_decomposer.b().decomposition_count(), + rgsw_x_rgsw_decomposer.a().decomposition_count().0, + rgsw_x_rgsw_decomposer.b().decomposition_count().0, ), &rlwe_x_rgsw_decomposer, &rgsw_x_rgsw_decomposer, @@ -1719,12 +1736,12 @@ where // We assume that d_{a/b} for RGSW x RGSW are always < d'_{a/b} for RLWE x RGSW assert!( - rlwe_x_rgsw_decomposer.a().decomposition_count() - < rgsw_x_rgsw_decomposer.a().decomposition_count() + rlwe_x_rgsw_decomposer.a().decomposition_count().0 + < rgsw_x_rgsw_decomposer.a().decomposition_count().0 ); assert!( - rlwe_x_rgsw_decomposer.b().decomposition_count() - < rgsw_x_rgsw_decomposer.b().decomposition_count() + rlwe_x_rgsw_decomposer.b().decomposition_count().0 + < rgsw_x_rgsw_decomposer.b().decomposition_count().0 ); let sj_poly_eval = { @@ -1733,8 +1750,8 @@ where s }; - let d_rgsw_a = rgsw_x_rgsw_decomposer.a().decomposition_count(); - let d_rgsw_b = rgsw_x_rgsw_decomposer.b().decomposition_count(); + let d_rgsw_a = rgsw_x_rgsw_decomposer.a().decomposition_count().0; + let d_rgsw_b = rgsw_x_rgsw_decomposer.b().decomposition_count().0; let d_max = std::cmp::max(d_rgsw_a, d_rgsw_b); // Zero encyptions for each LWE index. We generate d_a zero encryptions for each @@ -1809,13 +1826,14 @@ where // RGSW multiplication. We refer to such indices as where user is // not the leader. let self_leader_ni_rgsw_cts = { - let max_rlwe_x_rgsw_decomposer = if rlwe_x_rgsw_decomposer.a().decomposition_count() - > rlwe_x_rgsw_decomposer.b().decomposition_count() - { - rlwe_x_rgsw_decomposer.a() - } else { - rlwe_x_rgsw_decomposer.b() - }; + let max_rlwe_x_rgsw_decomposer = + if rlwe_x_rgsw_decomposer.a().decomposition_count().0 + > rlwe_x_rgsw_decomposer.b().decomposition_count().0 + { + rlwe_x_rgsw_decomposer.a() + } else { + rlwe_x_rgsw_decomposer.b() + }; let gadget_vec = max_rlwe_x_rgsw_decomposer.gadget_vector(); @@ -1828,7 +1846,7 @@ where // puncture p_rng d_max - d'_max time to align with `a_{i, l}`s used to // produce RGSW cts for RGSW x RGSW let mut scratch = M::R::zeros(self.parameters().rlwe_n().0); - (0..(d_max - max_rlwe_x_rgsw_decomposer.decomposition_count())) + (0..(d_max - max_rlwe_x_rgsw_decomposer.decomposition_count().0)) .into_iter() .for_each(|_| { RandomFillUniformInModulus::random_fill( @@ -1839,7 +1857,7 @@ where }); let mut ni_rgsw_cts = M::zeros( - max_rlwe_x_rgsw_decomposer.decomposition_count(), + max_rlwe_x_rgsw_decomposer.decomposition_count().0, self.parameters().rlwe_n().0, ); @@ -1891,13 +1909,14 @@ where }; let not_self_leader_rgsw_cts = { - let max_rgsw_x_rgsw_decomposer = if rgsw_x_rgsw_decomposer.a().decomposition_count() - > rgsw_x_rgsw_decomposer.b().decomposition_count() - { - rgsw_x_rgsw_decomposer.a() - } else { - rgsw_x_rgsw_decomposer.b() - }; + let max_rgsw_x_rgsw_decomposer = + if rgsw_x_rgsw_decomposer.a().decomposition_count().0 + > rgsw_x_rgsw_decomposer.b().decomposition_count().0 + { + rgsw_x_rgsw_decomposer.a() + } else { + rgsw_x_rgsw_decomposer.b() + }; let gadget_vec = max_rgsw_x_rgsw_decomposer.gadget_vector(); ((0..self_start_index).chain(self_end_index..self.parameters().lwe_n().0)) @@ -1906,7 +1925,7 @@ where cr_seed.ni_rgsw_ct_seed_for_index::(lwe_index), ); let mut ni_rgsw_cts = M::zeros( - max_rgsw_x_rgsw_decomposer.decomposition_count(), + max_rgsw_x_rgsw_decomposer.decomposition_count().0, self.parameters().rlwe_n().0, ); let mut scratch = M::R::zeros(self.parameters().rlwe_n().0); @@ -2014,7 +2033,7 @@ where }; let mut ksk_out = M::zeros( - self.pbs_info.auto_decomposer.decomposition_count(), + self.pbs_info.auto_decomposer.decomposition_count().0, ring_size, ); seeded_auto_key_gen( diff --git a/src/bool/parameters.rs b/src/bool/parameters.rs index 1f43788..d174709 100644 --- a/src/bool/parameters.rs +++ b/src/bool/parameters.rs @@ -1,22 +1,29 @@ +use std::ops::Deref; + use num_traits::{ConstZero, FromPrimitive, PrimInt}; use crate::{backend::Modulus, decomposer::Decomposer}; -pub(super) trait DoubleDecomposerParams { +pub(crate) trait DoubleDecomposerCount { + type Count; + fn a(&self) -> Self::Count; + fn b(&self) -> Self::Count; +} + +pub(crate) trait DoubleDecomposerParams { type Base; type Count; - fn new(base: Self::Base, count_a: Self::Count, count_b: Self::Count) -> Self; fn decomposition_base(&self) -> Self::Base; fn decomposition_count_a(&self) -> Self::Count; fn decomposition_count_b(&self) -> Self::Count; } -trait SingleDecomposerParams { +pub(crate) trait SingleDecomposerParams { type Base; type Count; - fn new(base: Self::Base, count: Self::Count) -> Self; + // fn new(base: Self::Base, count: Self::Count) -> Self; fn decomposition_base(&self) -> Self::Base; fn decomposition_count(&self) -> Self::Count; } @@ -31,13 +38,13 @@ impl DoubleDecomposerParams type Base = DecompostionLogBase; type Count = DecompositionCount; - fn new( - base: DecompostionLogBase, - count_a: DecompositionCount, - count_b: DecompositionCount, - ) -> Self { - (base, (count_a, count_b)) - } + // fn new( + // base: DecompostionLogBase, + // count_a: DecompositionCount, + // count_b: DecompositionCount, + // ) -> Self { + // (base, (count_a, count_b)) + // } fn decomposition_base(&self) -> Self::Base { self.0 @@ -56,9 +63,9 @@ impl SingleDecomposerParams for (DecompostionLogBase, DecompositionCount) { type Base = DecompostionLogBase; type Count = DecompositionCount; - fn new(base: DecompostionLogBase, count: DecompositionCount) -> Self { - (base, count) - } + // fn new(base: DecompostionLogBase, count: DecompositionCount) -> Self { + // (base, count) + // } fn decomposition_base(&self) -> Self::Base { self.0 @@ -132,11 +139,11 @@ impl BoolParameters { pub(crate) fn rlwe_by_rgsw_decomposition_params( &self, - ) -> ( + ) -> &( DecompostionLogBase, (DecompositionCount, DecompositionCount), ) { - self.rlrg_decomposer_params + &self.rlrg_decomposer_params } pub(crate) fn rgsw_by_rgsw_decomposition_params( @@ -167,6 +174,10 @@ impl BoolParameters { params.1 } + pub(crate) fn auto_decomposition_param(&self) -> &(DecompostionLogBase, DecompositionCount) { + &self.auto_decomposer_params + } + pub(crate) fn auto_decomposition_base(&self) -> DecompostionLogBase { self.auto_decomposer_params.decomposition_base() } @@ -311,6 +322,7 @@ impl AsRef for DecompositionCount { &self.0 } } + #[derive(Clone, Copy, PartialEq)] pub(crate) struct LweDimension(pub(crate) usize); #[derive(Clone, Copy, PartialEq)] diff --git a/src/bool/print_noise.rs b/src/bool/print_noise.rs index fc1e9a9..d5b0d29 100644 --- a/src/bool/print_noise.rs +++ b/src/bool/print_noise.rs @@ -128,7 +128,7 @@ where rlwe_modop.elwise_neg_mut(neg_s_eval.as_mut()); rlwe_nttop.forward(neg_s_eval.as_mut()); - for j in 0..rlwe_x_rgsw_decomposer.a().decomposition_count() { + for j in 0..rlwe_x_rgsw_decomposer.a().decomposition_count().0 { // RLWE(B^{j} * -s[X]*X^{s_lwe[i]}) // -s[X]*X^{s_lwe[i]}*B_j @@ -144,7 +144,7 @@ where .get_row_mut(0) .copy_from_slice(rgsw_ct_i.get_row_slice(j)); rlwe_ct.get_row_mut(1).copy_from_slice( - rgsw_ct_i.get_row_slice(j + rlwe_x_rgsw_decomposer.a().decomposition_count()), + rgsw_ct_i.get_row_slice(j + rlwe_x_rgsw_decomposer.a().decomposition_count().0), ); // RGSW ciphertexts are in eval domain. We put RLWE ciphertexts back in // coefficient domain @@ -170,7 +170,7 @@ where } // RLWE'(m) - for j in 0..rlwe_x_rgsw_decomposer.b().decomposition_count() { + for j in 0..rlwe_x_rgsw_decomposer.b().decomposition_count().0 { // RLWE(B^{j} * X^{s_lwe[i]}) // X^{s_lwe[i]}*B_j @@ -180,14 +180,15 @@ where // RLWE(X^{s_lwe[i]}*B_j) let mut rlwe_ct = M::zeros(2, rlwe_n); rlwe_ct.get_row_mut(0).copy_from_slice( - rgsw_ct_i - .get_row_slice(j + (2 * rlwe_x_rgsw_decomposer.a().decomposition_count())), + rgsw_ct_i.get_row_slice( + j + (2 * rlwe_x_rgsw_decomposer.a().decomposition_count().0), + ), ); rlwe_ct .get_row_mut(1) .copy_from_slice(rgsw_ct_i.get_row_slice( - j + (2 * rlwe_x_rgsw_decomposer.a().decomposition_count()) - + rlwe_x_rgsw_decomposer.b().decomposition_count(), + j + (2 * rlwe_x_rgsw_decomposer.a().decomposition_count().0) + + rlwe_x_rgsw_decomposer.b().decomposition_count().0, )); rlwe_ct .iter_rows_mut() @@ -290,7 +291,7 @@ where &mut RlweCiphertextMutRef::new(rlwe.as_mut()), &RlweKskRef::new( server_key.galois_key_for_auto(*k).as_ref(), - auto_decomposer.decomposition_count(), + auto_decomposer.decomposition_count().0, ), &mut scratch_matrix_ref, &auto_index_map, diff --git a/src/decomposer.rs b/src/decomposer.rs index 6bfbf35..f315c93 100644 --- a/src/decomposer.rs +++ b/src/decomposer.rs @@ -1,8 +1,13 @@ -use itertools::{izip, Itertools}; +use itertools::{assert_equal, izip, Itertools}; use num_traits::{FromPrimitive, PrimInt, ToPrimitive, WrappingAdd, WrappingSub}; use std::fmt::{Debug, Display}; -use crate::backend::ArithmeticOps; +use crate::{ + backend::ArithmeticOps, + parameters::{ + DecompositionCount, DecompostionLogBase, DoubleDecomposerParams, SingleDecomposerParams, + }, +}; fn gadget_vector(logq: usize, logb: usize, d: usize) -> Vec { assert!(logq >= (logb * d)); @@ -38,6 +43,41 @@ where } } +impl DoubleDecomposerParams for D +where + D: RlweDecomposer, +{ + type Base = DecompostionLogBase; + type Count = DecompositionCount; + + fn decomposition_base(&self) -> Self::Base { + assert!( + Decomposer::decomposition_base(self.a()) == Decomposer::decomposition_base(self.b()) + ); + Decomposer::decomposition_base(self.a()) + } + fn decomposition_count_a(&self) -> Self::Count { + Decomposer::decomposition_count(self.a()) + } + fn decomposition_count_b(&self) -> Self::Count { + Decomposer::decomposition_count(self.b()) + } +} + +impl SingleDecomposerParams for D +where + D: Decomposer, +{ + type Base = DecompostionLogBase; + type Count = DecompositionCount; + fn decomposition_base(&self) -> Self::Base { + Decomposer::decomposition_base(self) + } + fn decomposition_count(&self) -> Self::Count { + Decomposer::decomposition_count(self) + } +} + pub trait Decomposer { type Element; type Iter: Iterator; @@ -45,7 +85,8 @@ pub trait Decomposer { fn decompose_to_vec(&self, v: &Self::Element) -> Vec; fn decompose_iter(&self, v: &Self::Element) -> Self::Iter; - fn decomposition_count(&self) -> usize; + fn decomposition_count(&self) -> DecompositionCount; + fn decomposition_base(&self) -> DecompostionLogBase; fn gadget_vector(&self) -> Vec; } @@ -169,8 +210,12 @@ impl< return out; } - fn decomposition_count(&self) -> usize { - self.d + fn decomposition_count(&self) -> DecompositionCount { + DecompositionCount(self.d) + } + + fn decomposition_base(&self) -> DecompostionLogBase { + DecompostionLogBase(self.logb) } fn decompose_iter(&self, value: &T) -> DecomposerIter { diff --git a/src/lwe.rs b/src/lwe.rs index bebee65..b6e56a6 100644 --- a/src/lwe.rs +++ b/src/lwe.rs @@ -24,7 +24,7 @@ pub(crate) fn lwe_key_switch< decomposer: &D, ) { assert!( - lwe_ksk.dimension().0 == ((lwe_in.as_ref().len() - 1) * decomposer.decomposition_count()) + lwe_ksk.dimension().0 == ((lwe_in.as_ref().len() - 1) * decomposer.decomposition_count().0) ); assert!(lwe_out.as_ref().len() == lwe_ksk.dimension().1); diff --git a/src/pbs.rs b/src/pbs.rs index 7a727d5..cad6892 100644 --- a/src/pbs.rs +++ b/src/pbs.rs @@ -244,9 +244,9 @@ fn blind_rotation< let mut is_trivial = true; let mut scratch_matrix = RuntimeScratchMutRef::new(scratch_matrix.as_mut()); let mut rlwe = RlweCiphertextMutRef::new(trivial_rlwe_test_poly.as_mut()); - let d_a = rlwe_rgsw_decomposer.a().decomposition_count(); - let d_b = rlwe_rgsw_decomposer.b().decomposition_count(); - let d_auto = auto_decomposer.decomposition_count(); + let d_a = rlwe_rgsw_decomposer.a().decomposition_count().0; + let d_b = rlwe_rgsw_decomposer.b().decomposition_count().0; + let d_auto = auto_decomposer.decomposition_count().0; let q_by_4 = q >> 2; let mut count = 0; diff --git a/src/rgsw/mod.rs b/src/rgsw/mod.rs index 6a55542..a3bb5c2 100644 --- a/src/rgsw/mod.rs +++ b/src/rgsw/mod.rs @@ -54,7 +54,7 @@ pub(crate) mod tests { modulus: Mod, ) -> Self { SeededAutoKey { - data: M::zeros(auto_decomposer.decomposition_count(), ring_size), + data: M::zeros(auto_decomposer.decomposition_count().0, ring_size), seed, modulus, } @@ -125,12 +125,12 @@ pub(crate) mod tests { ) -> RgswCiphertext { RgswCiphertext { data: M::zeros( - decomposer.a().decomposition_count() * 2 - + decomposer.b().decomposition_count() * 2, + decomposer.a().decomposition_count().0 * 2 + + decomposer.b().decomposition_count().0 * 2, ring_size, ), - d_a: decomposer.a().decomposition_count(), - d_b: decomposer.b().decomposition_count(), + d_a: decomposer.a().decomposition_count().0, + d_b: decomposer.b().decomposition_count().0, modulus, } } @@ -158,13 +158,14 @@ pub(crate) mod tests { ) -> SeededRgswCiphertext { SeededRgswCiphertext { data: M::zeros( - decomposer.a().decomposition_count() * 2 + decomposer.b().decomposition_count(), + decomposer.a().decomposition_count().0 * 2 + + decomposer.b().decomposition_count().0, ring_size, ), seed, modulus, - d_a: decomposer.a().decomposition_count(), - d_b: decomposer.b().decomposition_count(), + d_a: decomposer.a().decomposition_count().0, + d_b: decomposer.b().decomposition_count().0, } } } @@ -613,13 +614,13 @@ pub(crate) mod tests { &mut RlweCiphertextMutRef::new(rlwe_in_ct_shoup.as_mut()), &RgswCiphertextRef::new( rgsw_ct.data.as_ref(), - decomposer.a().decomposition_count(), - decomposer.b().decomposition_count(), + decomposer.a().decomposition_count().0, + decomposer.b().decomposition_count().0, ), &RgswCiphertextRef::new( rgsw_ct_shoup.as_ref(), - decomposer.a().decomposition_count(), - decomposer.b().decomposition_count(), + decomposer.a().decomposition_count().0, + decomposer.b().decomposition_count().0, ), &mut RuntimeScratchMutRef::new(scratch_space.as_mut()), &decomposer, @@ -637,8 +638,8 @@ pub(crate) mod tests { &mut RlweCiphertextMutRef::new(rlwe_in_ct.data.as_mut()), &RgswCiphertextRef::new( rgsw_ct.data.as_ref(), - decomposer.a().decomposition_count(), - decomposer.b().decomposition_count(), + decomposer.a().decomposition_count().0, + decomposer.b().decomposition_count().0, ), &mut RuntimeScratchMutRef::new(scratch_space.as_mut()), &decomposer, @@ -760,8 +761,8 @@ pub(crate) mod tests { let mut rlwe_m_shoup = rlwe_m.data.clone(); rlwe_auto_shoup( &mut RlweCiphertextMutRef::new(&mut rlwe_m_shoup), - &RlweKskRef::new(&auto_key.data, decomposer.decomposition_count()), - &RlweKskRef::new(&auto_key_shoup, decomposer.decomposition_count()), + &RlweKskRef::new(&auto_key.data, decomposer.decomposition_count().0), + &RlweKskRef::new(&auto_key_shoup, decomposer.decomposition_count().0), &mut RuntimeScratchMutRef::new(&mut scratch_space), &auto_map_index, &auto_map_sign, @@ -777,7 +778,7 @@ pub(crate) mod tests { { rlwe_auto( &mut RlweCiphertextMutRef::new(rlwe_m.data.as_mut()), - &RlweKskRef::new(auto_key.data.as_ref(), decomposer.decomposition_count()), + &RlweKskRef::new(auto_key.data.as_ref(), decomposer.decomposition_count().0), &mut RuntimeScratchMutRef::new(scratch_space.as_mut()), &auto_map_index, &auto_map_sign, @@ -925,8 +926,8 @@ pub(crate) mod tests { DefaultDecomposer::new(q, logb, d_rgsw), ); - let d_a = decomposer.a().decomposition_count(); - let d_b = decomposer.b().decomposition_count(); + let d_a = decomposer.a().decomposition_count().0; + let d_b = decomposer.b().decomposition_count().0; let mul_mod = |a: &u64, b: &u64| ((*a as u128 * *b as u128) % q as u128) as u64; diff --git a/src/rgsw/runtime.rs b/src/rgsw/runtime.rs index e9c645f..bd56c5d 100644 --- a/src/rgsw/runtime.rs +++ b/src/rgsw/runtime.rs @@ -5,6 +5,7 @@ use crate::{ backend::{ArithmeticOps, GetModulus, ShoupMatrixFMA, VectorOps}, decomposer::{Decomposer, RlweDecomposer}, ntt::Ntt, + parameters::{DecompositionCount, DoubleDecomposerParams, SingleDecomposerParams}, Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, }; @@ -296,12 +297,12 @@ where rgsw1_decoposer: &D, ) -> (&mut [Self::R], &mut [Self::R]) { let (decomp_poly, other) = self.data.split_at_mut(std::cmp::max( - rgsw1_decoposer.a().decomposition_count(), - rgsw1_decoposer.b().decomposition_count(), + rgsw1_decoposer.decomposition_count_a().0, + rgsw1_decoposer.decomposition_count_b().0, )); let (rgsw, _) = other.split_at_mut( - rgsw0_decoposer.a().decomposition_count() * 2 - + rgsw0_decoposer.b().decomposition_count() * 2, + rgsw0_decoposer.decomposition_count_a().0 * 2 + + rgsw0_decoposer.decomposition_count_b().0 * 2, ); // zero fill rgsw0 @@ -316,8 +317,8 @@ where decomposer: &D, ) -> (&mut [Self::R], &mut [Self::R]) { let (decomp_poly, other) = self.data.split_at_mut(std::cmp::max( - decomposer.a().decomposition_count(), - decomposer.b().decomposition_count(), + decomposer.decomposition_count_a().0, + decomposer.decomposition_count_b().0, )); let (rlwe, _) = other.split_at_mut(2); @@ -331,27 +332,32 @@ where } /// Returns no. of rows in scratch space for RGSW0 x RGSW1 product -pub(crate) fn rgsw_x_rgsw_scratch_rows( - rgsw0_decomposer: &D, - rgsw1_decomposer: &D, +pub(crate) fn rgsw_x_rgsw_scratch_rows>( + rgsw0_decomposer_param: &D, + rgsw1_decomposer_param: &D, ) -> usize { std::cmp::max( - rgsw1_decomposer.a().decomposition_count(), - rgsw1_decomposer.b().decomposition_count(), - ) + rgsw0_decomposer.a().decomposition_count() * 2 - + rgsw0_decomposer.b().decomposition_count() * 2 + rgsw1_decomposer_param.decomposition_count_a().0, + rgsw1_decomposer_param.decomposition_count_b().0, + ) + rgsw0_decomposer_param.decomposition_count_a().0 * 2 + + rgsw0_decomposer_param.decomposition_count_b().0 * 2 } /// Returns no. of rows in scratch space for RLWE x RGSW product -pub(crate) fn rlwe_x_rgsw_scratch_rows(rgsw_decomposer: &D) -> usize { +pub(crate) fn rlwe_x_rgsw_scratch_rows>( + rgsw_decomposer_param: &D, +) -> usize { std::cmp::max( - rgsw_decomposer.a().decomposition_count(), - rgsw_decomposer.b().decomposition_count(), + rgsw_decomposer_param.decomposition_count_a().0, + rgsw_decomposer_param.decomposition_count_b().0, ) + 2 } + /// Returns no. of rows in scratch space for RLWE auto -pub(crate) fn rlwe_auto_scratch_rows(decomposer: &D) -> usize { - decomposer.decomposition_count() + 2 +pub(crate) fn rlwe_auto_scratch_rows>( + param: &D, +) -> usize { + param.decomposition_count().0 + 2 } pub(crate) fn poly_fma_routine>( @@ -430,7 +436,7 @@ pub(crate) fn rlwe_auto< if !is_trivial { let (decomp_poly_scratch, tmp_rlwe) = scratch_matrix - .scratch_for_rlwe_auto_and_zero_rlwe_space(decomposer.decomposition_count()); + .scratch_for_rlwe_auto_and_zero_rlwe_space(decomposer.decomposition_count().0); let mut tmp_rlwe = RlweCiphertextMutRef::new(tmp_rlwe); // send a(X) -> a(X^k) and decompose a(X^k) @@ -551,7 +557,7 @@ pub(crate) fn rlwe_auto_shoup< if !is_trivial { let (decomp_poly_scratch, tmp_rlwe) = scratch_matrix - .scratch_for_rlwe_auto_and_zero_rlwe_space(decomposer.decomposition_count()); + .scratch_for_rlwe_auto_and_zero_rlwe_space(decomposer.decomposition_count().0); let mut tmp_rlwe = RlweCiphertextMutRef::new(tmp_rlwe); // send a(X) -> a(X^k) and decompose a(X^k) @@ -662,8 +668,8 @@ pub(crate) fn rlwe_by_rgsw< { let decomposer_a = decomposer.a(); let decomposer_b = decomposer.b(); - let d_a = decomposer_a.decomposition_count(); - let d_b = decomposer_b.decomposition_count(); + let d_a = decomposer.decomposition_count_a().0; + let d_b = decomposer.decomposition_count_b().0; let ((rlwe_dash_nsm_parta, rlwe_dash_nsm_partb), (rlwe_dash_m_parta, rlwe_dash_m_partb)) = rgsw_in.split(); @@ -766,8 +772,8 @@ pub(crate) fn rlwe_by_rgsw_shoup< { let decomposer_a = decomposer.a(); let decomposer_b = decomposer.b(); - let d_a = decomposer_a.decomposition_count(); - let d_b = decomposer_b.decomposition_count(); + let d_a = decomposer.decomposition_count_a().0; + let d_b = decomposer.decomposition_count_b().0; let ((rlwe_dash_nsm_parta, rlwe_dash_nsm_partb), (rlwe_dash_m_parta, rlwe_dash_m_partb)) = rgsw_in.split(); @@ -900,8 +906,8 @@ pub(crate) fn rgsw_by_rgsw_inplace< let mut rgsw_space = RgswCiphertextMutRef::new( rgsw_space, - rgsw0_decomposer.a().decomposition_count(), - rgsw0_decomposer.b().decomposition_count(), + rgsw0_decomposer.decomposition_count_a().0, + rgsw0_decomposer.decomposition_count_b().0, ); let ( (rlwe_dash_space_nsm_parta, rlwe_dash_space_nsm_partb), @@ -927,7 +933,7 @@ pub(crate) fn rgsw_by_rgsw_inplace< // Part A: Decomp \cdot RLWE'(-sm1) { - let decomp_r_parta = &mut decomp_r_space[..rgsw1_decomposer.a().decomposition_count()]; + let decomp_r_parta = &mut decomp_r_space[..rgsw1_decomposer.decomposition_count_a().0]; decompose_r( rlwe_a.as_ref(), decomp_r_parta.as_mut(), @@ -952,7 +958,7 @@ pub(crate) fn rgsw_by_rgsw_inplace< // Part B: Decompose \cdot RLWE'(m1) { - let decomp_r_partb = &mut decomp_r_space[..rgsw1_decomposer.b().decomposition_count()]; + let decomp_r_partb = &mut decomp_r_space[..rgsw1_decomposer.decomposition_count_b().0]; decompose_r( rlwe_b.as_ref(), decomp_r_partb.as_mut(), @@ -1011,11 +1017,11 @@ where { let ring_size = rlwe_in.dimension().1; assert!(rlwe_in.dimension().0 == 2); - assert!(ksk.dimension() == (decomposer.decomposition_count() * 2, ring_size)); + assert!(ksk.dimension() == (decomposer.decomposition_count().0 * 2, ring_size)); let mut rlwe_out = M::zeros(2, ring_size); - let mut tmp = M::zeros(decomposer.decomposition_count(), ring_size); + let mut tmp = M::zeros(decomposer.decomposition_count().0, ring_size); let mut tmp_row = M::R::zeros(ring_size); // key switch RLWE part -A @@ -1028,9 +1034,9 @@ where .for_each(|r| ntt_op.forward_lazy(r.as_mut())); // RLWE_s(-A u) = B' + B, A' = (decomp(-A) * Ksk(u -> s)) + (B, 0) - let (ksk_part_a, ksk_part_b) = ksk.split_at_row(decomposer.decomposition_count()); + let (ksk_part_a, ksk_part_b) = ksk.split_at_row(decomposer.decomposition_count().0); let (ksk_part_a_shoup, ksk_part_b_shoup) = - ksk_shoup.split_at_row(decomposer.decomposition_count()); + ksk_shoup.split_at_row(decomposer.decomposition_count().0); // Part A' mod_op.shoup_matrix_fma( rlwe_out.get_row_mut(0),