Browse Source

implement DoubleDecomposer for Rlwe Decomposer

par-agg-key-shares
Janmajaya Mall 9 months ago
parent
commit
1ff98541c8
8 changed files with 251 additions and 167 deletions
  1. +102
    -83
      src/bool/evaluator.rs
  2. +28
    -16
      src/bool/parameters.rs
  3. +9
    -8
      src/bool/print_noise.rs
  4. +50
    -5
      src/decomposer.rs
  5. +1
    -1
      src/lwe.rs
  6. +3
    -3
      src/pbs.rs
  7. +20
    -19
      src/rgsw/mod.rs
  8. +38
    -32
      src/rgsw/runtime.rs

+ 102
- 83
src/bool/evaluator.rs

@ -26,8 +26,9 @@ use crate::{
}, },
rgsw::{ rgsw::{
decrypt_rlwe, generate_auto_map, public_key_encrypt_rgsw, rgsw_by_rgsw_inplace, decrypt_rlwe, generate_auto_map, public_key_encrypt_rgsw, rgsw_by_rgsw_inplace,
rgsw_x_rgsw_scratch_rows, rlwe_auto, secret_key_encrypt_rgsw, seeded_auto_key_gen,
RgswCiphertext, RgswCiphertextMutRef, RgswCiphertextRef, RuntimeScratchMutRef,
rgsw_x_rgsw_scratch_rows, rlwe_auto, rlwe_auto_scratch_rows, rlwe_x_rgsw_scratch_rows,
secret_key_encrypt_rgsw, seeded_auto_key_gen, RgswCiphertext, RgswCiphertextMutRef,
RgswCiphertextRef, RuntimeScratchMutRef,
}, },
utils::{ utils::{
encode_x_pow_si_with_emebedding_factor, fill_random_ternary_secret_with_hamming_weight, encode_x_pow_si_with_emebedding_factor, fill_random_ternary_secret_with_hamming_weight,
@ -237,16 +238,17 @@ where
// Vector to store LWE ciphertext with LWE dimesnion n // Vector to store LWE ciphertext with LWE dimesnion n
let lwe_vector = M::R::zeros(parameters.lwe_n().0 + 1); let lwe_vector = M::R::zeros(parameters.lwe_n().0 + 1);
// Matrix to store decomposed polynomials
// Max decompistion count + space for temporary RLWE
let d = std::cmp::max(
parameters.auto_decomposition_count().0,
// PBS perform two operations at runtime: RLWE x RGW and RLWE auto. Since the
// operations are performed serially same scratch space can be used for both.
// Hence we create scratch space that contains maximum amount of rows that
// suffices for RLWE x RGSW and RLWE auto
let decomposition_matrix = M::zeros(
std::cmp::max( std::cmp::max(
parameters.rlwe_rgsw_decomposition_count().0 .0,
parameters.rlwe_rgsw_decomposition_count().1 .0,
rlwe_x_rgsw_scratch_rows(parameters.rlwe_by_rgsw_decomposition_params()),
rlwe_auto_scratch_rows(parameters.auto_decomposition_param()),
), ),
) + 2;
let decomposition_matrix = M::zeros(d, parameters.rlwe_n().0);
parameters.rlwe_n().0,
);
Self { Self {
lwe_vector, lwe_vector,
@ -546,24 +548,30 @@ where
<M as Matrix>::R: RowMut + Clone, <M as Matrix>::R: RowMut + Clone,
{ {
let max_decomposer = let max_decomposer =
if decomposer.a().decomposition_count() > decomposer.b().decomposition_count() {
if decomposer.a().decomposition_count().0 > decomposer.b().decomposition_count().0 {
decomposer.a() decomposer.a()
} else { } else {
decomposer.b() decomposer.b()
}; };
assert!( assert!(
ni_rgsw_ct.dimension() == (max_decomposer.decomposition_count(), parameters.rlwe_n().0)
ni_rgsw_ct.dimension()
== (
max_decomposer.decomposition_count().0,
parameters.rlwe_n().0
)
);
assert!(
aggregated_decomposed_ni_rgsw_zero_encs.len() == decomposer.a().decomposition_count().0,
); );
assert!(aggregated_decomposed_ni_rgsw_zero_encs.len() == decomposer.a().decomposition_count(),);
assert!(decomposed_neg_ais.len() == decomposer.b().decomposition_count());
assert!(decomposed_neg_ais.len() == decomposer.b().decomposition_count().0);
let mut rgsw_i = M::zeros( let mut rgsw_i = M::zeros(
decomposer.a().decomposition_count() * 2 + decomposer.b().decomposition_count() * 2,
decomposer.a().decomposition_count().0 * 2 + decomposer.b().decomposition_count().0 * 2,
parameters.rlwe_n().0, parameters.rlwe_n().0,
); );
let (rlwe_dash_nsm, rlwe_dash_m) = let (rlwe_dash_nsm, rlwe_dash_m) =
rgsw_i.split_at_row_mut(decomposer.a().decomposition_count() * 2);
rgsw_i.split_at_row_mut(decomposer.a().decomposition_count().0 * 2);
// RLWE'_{s}(-sm) // RLWE'_{s}(-sm)
// Key switch `s * a_{i, l} + e` using ksk(u_j -> s) to produce RLWE(s * // Key switch `s * a_{i, l} + e` using ksk(u_j -> s) to produce RLWE(s *
@ -573,13 +581,13 @@ where
// RLWE(s * u_{j=user_id} * a_{i, l}) // RLWE(s * u_{j=user_id} * a_{i, l})
{ {
let (rlwe_dash_nsm_parta, rlwe_dash_nsm_partb) = let (rlwe_dash_nsm_parta, rlwe_dash_nsm_partb) =
rlwe_dash_nsm.split_at_mut(decomposer.a().decomposition_count());
rlwe_dash_nsm.split_at_mut(decomposer.a().decomposition_count().0);
izip!( izip!(
rlwe_dash_nsm_parta.iter_mut(), rlwe_dash_nsm_parta.iter_mut(),
rlwe_dash_nsm_partb.iter_mut(), rlwe_dash_nsm_partb.iter_mut(),
ni_rgsw_ct
.iter_rows()
.skip(max_decomposer.decomposition_count() - decomposer.a().decomposition_count()),
ni_rgsw_ct.iter_rows().skip(
max_decomposer.decomposition_count().0 - decomposer.a().decomposition_count().0
),
aggregated_decomposed_ni_rgsw_zero_encs.iter() aggregated_decomposed_ni_rgsw_zero_encs.iter()
) )
.for_each(|(rlwe_a, rlwe_b, ni_rlwe_ct, decomp_zero_enc)| { .for_each(|(rlwe_a, rlwe_b, ni_rlwe_ct, decomp_zero_enc)| {
@ -612,13 +620,13 @@ where
// RLWE'_{s}(m) // RLWE'_{s}(m)
{ {
let (rlwe_dash_m_parta, rlwe_dash_partb) = let (rlwe_dash_m_parta, rlwe_dash_partb) =
rlwe_dash_m.split_at_mut(decomposer.b().decomposition_count());
rlwe_dash_m.split_at_mut(decomposer.b().decomposition_count().0);
izip!( izip!(
rlwe_dash_m_parta.iter_mut(), rlwe_dash_m_parta.iter_mut(),
rlwe_dash_partb.iter_mut(), rlwe_dash_partb.iter_mut(),
ni_rgsw_ct
.iter_rows()
.skip(max_decomposer.decomposition_count() - decomposer.b().decomposition_count()),
ni_rgsw_ct.iter_rows().skip(
max_decomposer.decomposition_count().0 - decomposer.b().decomposition_count().0
),
decomposed_neg_ais.iter() decomposed_neg_ais.iter()
) )
.for_each(|(rlwe_a, rlwe_b, ni_rlwe_ct, decomp_neg_ai)| { .for_each(|(rlwe_a, rlwe_b, ni_rlwe_ct, decomp_neg_ai)| {
@ -860,7 +868,10 @@ where
} else { } else {
(g.pow(i as u32) % br_q) as isize (g.pow(i as u32) % br_q) as isize
}; };
let mut gk = M::zeros(self.pbs_info.auto_decomposer.decomposition_count(), rlwe_n);
let mut gk = M::zeros(
self.pbs_info.auto_decomposer.decomposition_count().0,
rlwe_n,
);
seeded_auto_key_gen( seeded_auto_key_gen(
&mut gk, &mut gk,
&sk_rlwe, &sk_rlwe,
@ -878,8 +889,8 @@ where
let ring_size = self.pbs_info.parameters.rlwe_n().0; let ring_size = self.pbs_info.parameters.rlwe_n().0;
let rlwe_q = self.pbs_info.parameters.rlwe_q(); let rlwe_q = self.pbs_info.parameters.rlwe_q();
let (rlrg_d_a, rlrg_d_b) = ( let (rlrg_d_a, rlrg_d_b) = (
self.pbs_info.rlwe_rgsw_decomposer.0.decomposition_count(),
self.pbs_info.rlwe_rgsw_decomposer.1.decomposition_count(),
self.pbs_info.rlwe_rgsw_decomposer.0.decomposition_count().0,
self.pbs_info.rlwe_rgsw_decomposer.1.decomposition_count().0,
); );
let rlrg_gadget_a = self.pbs_info.rlwe_rgsw_decomposer.0.gadget_vector(); let rlrg_gadget_a = self.pbs_info.rlwe_rgsw_decomposer.0.gadget_vector();
let rlrg_gadget_b = self.pbs_info.rlwe_rgsw_decomposer.1.gadget_vector(); let rlrg_gadget_b = self.pbs_info.rlwe_rgsw_decomposer.1.gadget_vector();
@ -997,7 +1008,7 @@ where
rlrg_decomposer.b().gadget_vector(), rlrg_decomposer.b().gadget_vector(),
); );
for s_index in segment_start..segment_end { for s_index in segment_start..segment_end {
let mut out_rgsw = M::zeros(rlrg_d_a * 2 + rlrg_d_b * 2, ring_size);
let mut out_rgsw = M::zeros(rlrg_d_a.0 * 2 + rlrg_d_b.0 * 2, ring_size);
public_key_encrypt_rgsw( public_key_encrypt_rgsw(
&mut out_rgsw, &mut out_rgsw,
&encode_x_pow_si_with_emebedding_factor::< &encode_x_pow_si_with_emebedding_factor::<
@ -1039,7 +1050,7 @@ where
); );
for s_index in (0..segment_start).chain(segment_end..self.parameters().lwe_n().0) { for s_index in (0..segment_start).chain(segment_end..self.parameters().lwe_n().0) {
let mut out_rgsw = M::zeros(rgrg_d_a * 2 + rgrg_d_b * 2, ring_size);
let mut out_rgsw = M::zeros(rgrg_d_a.0 * 2 + rgrg_d_b.0 * 2, ring_size);
public_key_encrypt_rgsw( public_key_encrypt_rgsw(
&mut out_rgsw, &mut out_rgsw,
&encode_x_pow_si_with_emebedding_factor::< &encode_x_pow_si_with_emebedding_factor::<
@ -1147,13 +1158,13 @@ where
parameters.rgsw_rgsw_decomposer::<DefaultDecomposer<M::MatElement>>(); parameters.rgsw_rgsw_decomposer::<DefaultDecomposer<M::MatElement>>();
let rlwe_x_rgsw_decomposer = self.pbs_info().rlwe_rgsw_decomposer(); let rlwe_x_rgsw_decomposer = self.pbs_info().rlwe_rgsw_decomposer();
let rgsw_x_rgsw_dimension = ( let rgsw_x_rgsw_dimension = (
rgsw_x_rgsw_decomposer.a().decomposition_count() * 2
+ rgsw_x_rgsw_decomposer.b().decomposition_count() * 2,
rgsw_x_rgsw_decomposer.a().decomposition_count().0 * 2
+ rgsw_x_rgsw_decomposer.b().decomposition_count().0 * 2,
rlwe_n, rlwe_n,
); );
let rlwe_x_rgsw_dimension = ( let rlwe_x_rgsw_dimension = (
rlwe_x_rgsw_decomposer.a().decomposition_count() * 2
+ rlwe_x_rgsw_decomposer.b().decomposition_count() * 2,
rlwe_x_rgsw_decomposer.a().decomposition_count().0 * 2
+ rlwe_x_rgsw_decomposer.b().decomposition_count().0 * 2,
rlwe_n, rlwe_n,
); );
@ -1213,13 +1224,13 @@ where
rgsw_by_rgsw_inplace( rgsw_by_rgsw_inplace(
&mut RgswCiphertextMutRef::new( &mut RgswCiphertextMutRef::new(
rgsw_i.as_mut(), rgsw_i.as_mut(),
rlwe_x_rgsw_decomposer.a().decomposition_count(),
rlwe_x_rgsw_decomposer.b().decomposition_count(),
rlwe_x_rgsw_decomposer.a().decomposition_count().0,
rlwe_x_rgsw_decomposer.b().decomposition_count().0,
), ),
&RgswCiphertextRef::new( &RgswCiphertextRef::new(
other_rgsw_i.as_ref(), other_rgsw_i.as_ref(),
rgsw_x_rgsw_decomposer.a().decomposition_count(),
rgsw_x_rgsw_decomposer.b().decomposition_count(),
rgsw_x_rgsw_decomposer.a().decomposition_count().0,
rgsw_x_rgsw_decomposer.b().decomposition_count().0,
), ),
rlwe_x_rgsw_decomposer, rlwe_x_rgsw_decomposer,
&rgsw_x_rgsw_decomposer, &rgsw_x_rgsw_decomposer,
@ -1306,7 +1317,7 @@ where
let mut useri_ui_to_s_ksk = share.ui_to_s_ksk().clone(); let mut useri_ui_to_s_ksk = share.ui_to_s_ksk().clone();
assert!( assert!(
useri_ui_to_s_ksk.dimension() useri_ui_to_s_ksk.dimension()
== (ni_uj_to_s_decomposer.decomposition_count(), ring_size)
== (ni_uj_to_s_decomposer.decomposition_count().0, ring_size)
); );
key_shares key_shares
.iter() .iter()
@ -1315,7 +1326,7 @@ where
let op2 = other_share.ui_to_s_ksk_zero_encs_for_user_i(share.user_index()); let op2 = other_share.ui_to_s_ksk_zero_encs_for_user_i(share.user_index());
assert!( assert!(
op2.dimension() op2.dimension()
== (ni_uj_to_s_decomposer.decomposition_count(), ring_size)
== (ni_uj_to_s_decomposer.decomposition_count().0, ring_size)
); );
izip!(useri_ui_to_s_ksk.iter_rows_mut(), op2.iter_rows()).for_each( izip!(useri_ui_to_s_ksk.iter_rows_mut(), op2.iter_rows()).for_each(
|(add_to, add_from)| { |(add_to, add_from)| {
@ -1341,7 +1352,8 @@ where
let mut ksk_prng = DefaultSecureRng::new_seeded( let mut ksk_prng = DefaultSecureRng::new_seeded(
cr_seed.ui_to_s_ks_seed_for_user_i::<DefaultSecureRng>(share.user_index()), cr_seed.ui_to_s_ks_seed_for_user_i::<DefaultSecureRng>(share.user_index()),
); );
let mut ais = M::zeros(ni_uj_to_s_decomposer.decomposition_count(), ring_size);
let mut ais =
M::zeros(ni_uj_to_s_decomposer.decomposition_count().0, ring_size);
ais.iter_rows_mut().for_each(|r_ai| { ais.iter_rows_mut().for_each(|r_ai| {
RandomFillUniformInModulus::random_fill( RandomFillUniformInModulus::random_fill(
@ -1363,12 +1375,12 @@ where
.parameters() .parameters()
.rlwe_rgsw_decomposer::<DefaultDecomposer<M::MatElement>>(); .rlwe_rgsw_decomposer::<DefaultDecomposer<M::MatElement>>();
let d_max = if rgsw_x_rgsw_decomposer.a().decomposition_count()
> rgsw_x_rgsw_decomposer.b().decomposition_count()
let d_max = if rgsw_x_rgsw_decomposer.a().decomposition_count().0
> rgsw_x_rgsw_decomposer.b().decomposition_count().0
{ {
rgsw_x_rgsw_decomposer.a().decomposition_count()
rgsw_x_rgsw_decomposer.a().decomposition_count().0
} else { } else {
rgsw_x_rgsw_decomposer.b().decomposition_count()
rgsw_x_rgsw_decomposer.b().decomposition_count().0
}; };
let mut scratch_rgsw_x_rgsw = M::zeros( let mut scratch_rgsw_x_rgsw = M::zeros(
@ -1416,19 +1428,19 @@ where
); );
let mut scratch = M::R::zeros(self.parameters().rlwe_n().0); let mut scratch = M::R::zeros(self.parameters().rlwe_n().0);
(0..d_max - rgsw_x_rgsw_decomposer.b().decomposition_count()).for_each(
|_| {
(0..d_max - rgsw_x_rgsw_decomposer.b().decomposition_count().0)
.for_each(|_| {
RandomFillUniformInModulus::random_fill( RandomFillUniformInModulus::random_fill(
&mut a_prng, &mut a_prng,
rlwe_q, rlwe_q,
scratch.as_mut(), scratch.as_mut(),
); );
},
);
});
let decomp_neg_ais = (0..rgsw_x_rgsw_decomposer let decomp_neg_ais = (0..rgsw_x_rgsw_decomposer
.b() .b()
.decomposition_count())
.decomposition_count()
.0)
.map(|_| { .map(|_| {
RandomFillUniformInModulus::random_fill( RandomFillUniformInModulus::random_fill(
&mut a_prng, &mut a_prng,
@ -1438,7 +1450,7 @@ where
rlwe_modop.elwise_neg_mut(scratch.as_mut()); rlwe_modop.elwise_neg_mut(scratch.as_mut());
let mut decomp_neg_ai = M::zeros( let mut decomp_neg_ai = M::zeros(
ni_uj_to_s_decomposer.decomposition_count(),
ni_uj_to_s_decomposer.decomposition_count().0,
self.parameters().rlwe_n().0, self.parameters().rlwe_n().0,
); );
scratch.as_ref().iter().enumerate().for_each(|(index, el)| { scratch.as_ref().iter().enumerate().for_each(|(index, el)| {
@ -1468,7 +1480,8 @@ where
// prepare for key switching // prepare for key switching
let ni_rgsw_zero_encs = (0..rgsw_x_rgsw_decomposer let ni_rgsw_zero_encs = (0..rgsw_x_rgsw_decomposer
.a() .a()
.decomposition_count())
.decomposition_count()
.0)
.map(|i| { .map(|i| {
let mut sum = M::R::zeros(self.parameters().rlwe_n().0); let mut sum = M::R::zeros(self.parameters().rlwe_n().0);
key_shares.iter().for_each(|k| { key_shares.iter().for_each(|k| {
@ -1481,7 +1494,7 @@ where
// decompose // decompose
let mut decomp_sum = M::zeros( let mut decomp_sum = M::zeros(
ni_uj_to_s_decomposer.decomposition_count(),
ni_uj_to_s_decomposer.decomposition_count().0,
self.parameters().rlwe_n().0, self.parameters().rlwe_n().0,
); );
sum.as_ref().iter().enumerate().for_each(|(index, el)| { sum.as_ref().iter().enumerate().for_each(|(index, el)| {
@ -1513,9 +1526,13 @@ where
&ni_rgsw_zero_encs[rgsw_x_rgsw_decomposer &ni_rgsw_zero_encs[rgsw_x_rgsw_decomposer
.a() .a()
.decomposition_count() .decomposition_count()
- rlwe_x_rgsw_decomposer.a().decomposition_count()..],
&decomp_neg_ais[rgsw_x_rgsw_decomposer.b().decomposition_count()
- rlwe_x_rgsw_decomposer.b().decomposition_count()..],
.0
- rlwe_x_rgsw_decomposer.a().decomposition_count().0..],
&decomp_neg_ais[rgsw_x_rgsw_decomposer
.b()
.decomposition_count()
.0
- rlwe_x_rgsw_decomposer.b().decomposition_count().0..],
&rlwe_x_rgsw_decomposer, &rlwe_x_rgsw_decomposer,
self.parameters(), self.parameters(),
(&uj_to_s_ksks[user_id], &uj_to_s_ksks_part_a_eval[user_id]), (&uj_to_s_ksks[user_id], &uj_to_s_ksks_part_a_eval[user_id]),
@ -1551,13 +1568,13 @@ where
rgsw_by_rgsw_inplace( rgsw_by_rgsw_inplace(
&mut RgswCiphertextMutRef::new( &mut RgswCiphertextMutRef::new(
rgsw_i.as_mut(), rgsw_i.as_mut(),
rlwe_x_rgsw_decomposer.a().decomposition_count(),
rlwe_x_rgsw_decomposer.b().decomposition_count(),
rlwe_x_rgsw_decomposer.a().decomposition_count().0,
rlwe_x_rgsw_decomposer.b().decomposition_count().0,
), ),
&RgswCiphertextRef::new( &RgswCiphertextRef::new(
other_rgsw_i.as_ref(), other_rgsw_i.as_ref(),
rgsw_x_rgsw_decomposer.a().decomposition_count(),
rgsw_x_rgsw_decomposer.b().decomposition_count(),
rgsw_x_rgsw_decomposer.a().decomposition_count().0,
rgsw_x_rgsw_decomposer.b().decomposition_count().0,
), ),
&rlwe_x_rgsw_decomposer, &rlwe_x_rgsw_decomposer,
&rgsw_x_rgsw_decomposer, &rgsw_x_rgsw_decomposer,
@ -1719,12 +1736,12 @@ where
// We assume that d_{a/b} for RGSW x RGSW are always < d'_{a/b} for RLWE x RGSW // We assume that d_{a/b} for RGSW x RGSW are always < d'_{a/b} for RLWE x RGSW
assert!( assert!(
rlwe_x_rgsw_decomposer.a().decomposition_count()
< rgsw_x_rgsw_decomposer.a().decomposition_count()
rlwe_x_rgsw_decomposer.a().decomposition_count().0
< rgsw_x_rgsw_decomposer.a().decomposition_count().0
); );
assert!( assert!(
rlwe_x_rgsw_decomposer.b().decomposition_count()
< rgsw_x_rgsw_decomposer.b().decomposition_count()
rlwe_x_rgsw_decomposer.b().decomposition_count().0
< rgsw_x_rgsw_decomposer.b().decomposition_count().0
); );
let sj_poly_eval = { let sj_poly_eval = {
@ -1733,8 +1750,8 @@ where
s s
}; };
let d_rgsw_a = rgsw_x_rgsw_decomposer.a().decomposition_count();
let d_rgsw_b = rgsw_x_rgsw_decomposer.b().decomposition_count();
let d_rgsw_a = rgsw_x_rgsw_decomposer.a().decomposition_count().0;
let d_rgsw_b = rgsw_x_rgsw_decomposer.b().decomposition_count().0;
let d_max = std::cmp::max(d_rgsw_a, d_rgsw_b); let d_max = std::cmp::max(d_rgsw_a, d_rgsw_b);
// Zero encyptions for each LWE index. We generate d_a zero encryptions for each // Zero encyptions for each LWE index. We generate d_a zero encryptions for each
@ -1809,13 +1826,14 @@ where
// RGSW multiplication. We refer to such indices as where user is // RGSW multiplication. We refer to such indices as where user is
// not the leader. // not the leader.
let self_leader_ni_rgsw_cts = { let self_leader_ni_rgsw_cts = {
let max_rlwe_x_rgsw_decomposer = if rlwe_x_rgsw_decomposer.a().decomposition_count()
> rlwe_x_rgsw_decomposer.b().decomposition_count()
{
rlwe_x_rgsw_decomposer.a()
} else {
rlwe_x_rgsw_decomposer.b()
};
let max_rlwe_x_rgsw_decomposer =
if rlwe_x_rgsw_decomposer.a().decomposition_count().0
> rlwe_x_rgsw_decomposer.b().decomposition_count().0
{
rlwe_x_rgsw_decomposer.a()
} else {
rlwe_x_rgsw_decomposer.b()
};
let gadget_vec = max_rlwe_x_rgsw_decomposer.gadget_vector(); let gadget_vec = max_rlwe_x_rgsw_decomposer.gadget_vector();
@ -1828,7 +1846,7 @@ where
// puncture p_rng d_max - d'_max time to align with `a_{i, l}`s used to // puncture p_rng d_max - d'_max time to align with `a_{i, l}`s used to
// produce RGSW cts for RGSW x RGSW // produce RGSW cts for RGSW x RGSW
let mut scratch = M::R::zeros(self.parameters().rlwe_n().0); let mut scratch = M::R::zeros(self.parameters().rlwe_n().0);
(0..(d_max - max_rlwe_x_rgsw_decomposer.decomposition_count()))
(0..(d_max - max_rlwe_x_rgsw_decomposer.decomposition_count().0))
.into_iter() .into_iter()
.for_each(|_| { .for_each(|_| {
RandomFillUniformInModulus::random_fill( RandomFillUniformInModulus::random_fill(
@ -1839,7 +1857,7 @@ where
}); });
let mut ni_rgsw_cts = M::zeros( let mut ni_rgsw_cts = M::zeros(
max_rlwe_x_rgsw_decomposer.decomposition_count(),
max_rlwe_x_rgsw_decomposer.decomposition_count().0,
self.parameters().rlwe_n().0, self.parameters().rlwe_n().0,
); );
@ -1891,13 +1909,14 @@ where
}; };
let not_self_leader_rgsw_cts = { let not_self_leader_rgsw_cts = {
let max_rgsw_x_rgsw_decomposer = if rgsw_x_rgsw_decomposer.a().decomposition_count()
> rgsw_x_rgsw_decomposer.b().decomposition_count()
{
rgsw_x_rgsw_decomposer.a()
} else {
rgsw_x_rgsw_decomposer.b()
};
let max_rgsw_x_rgsw_decomposer =
if rgsw_x_rgsw_decomposer.a().decomposition_count().0
> rgsw_x_rgsw_decomposer.b().decomposition_count().0
{
rgsw_x_rgsw_decomposer.a()
} else {
rgsw_x_rgsw_decomposer.b()
};
let gadget_vec = max_rgsw_x_rgsw_decomposer.gadget_vector(); let gadget_vec = max_rgsw_x_rgsw_decomposer.gadget_vector();
((0..self_start_index).chain(self_end_index..self.parameters().lwe_n().0)) ((0..self_start_index).chain(self_end_index..self.parameters().lwe_n().0))
@ -1906,7 +1925,7 @@ where
cr_seed.ni_rgsw_ct_seed_for_index::<DefaultSecureRng>(lwe_index), cr_seed.ni_rgsw_ct_seed_for_index::<DefaultSecureRng>(lwe_index),
); );
let mut ni_rgsw_cts = M::zeros( let mut ni_rgsw_cts = M::zeros(
max_rgsw_x_rgsw_decomposer.decomposition_count(),
max_rgsw_x_rgsw_decomposer.decomposition_count().0,
self.parameters().rlwe_n().0, self.parameters().rlwe_n().0,
); );
let mut scratch = M::R::zeros(self.parameters().rlwe_n().0); let mut scratch = M::R::zeros(self.parameters().rlwe_n().0);
@ -2014,7 +2033,7 @@ where
}; };
let mut ksk_out = M::zeros( let mut ksk_out = M::zeros(
self.pbs_info.auto_decomposer.decomposition_count(),
self.pbs_info.auto_decomposer.decomposition_count().0,
ring_size, ring_size,
); );
seeded_auto_key_gen( seeded_auto_key_gen(

+ 28
- 16
src/bool/parameters.rs

@ -1,22 +1,29 @@
use std::ops::Deref;
use num_traits::{ConstZero, FromPrimitive, PrimInt}; use num_traits::{ConstZero, FromPrimitive, PrimInt};
use crate::{backend::Modulus, decomposer::Decomposer}; use crate::{backend::Modulus, decomposer::Decomposer};
pub(super) trait DoubleDecomposerParams {
pub(crate) trait DoubleDecomposerCount {
type Count;
fn a(&self) -> Self::Count;
fn b(&self) -> Self::Count;
}
pub(crate) trait DoubleDecomposerParams {
type Base; type Base;
type Count; type Count;
fn new(base: Self::Base, count_a: Self::Count, count_b: Self::Count) -> Self;
fn decomposition_base(&self) -> Self::Base; fn decomposition_base(&self) -> Self::Base;
fn decomposition_count_a(&self) -> Self::Count; fn decomposition_count_a(&self) -> Self::Count;
fn decomposition_count_b(&self) -> Self::Count; fn decomposition_count_b(&self) -> Self::Count;
} }
trait SingleDecomposerParams {
pub(crate) trait SingleDecomposerParams {
type Base; type Base;
type Count; type Count;
fn new(base: Self::Base, count: Self::Count) -> Self;
// fn new(base: Self::Base, count: Self::Count) -> Self;
fn decomposition_base(&self) -> Self::Base; fn decomposition_base(&self) -> Self::Base;
fn decomposition_count(&self) -> Self::Count; fn decomposition_count(&self) -> Self::Count;
} }
@ -31,13 +38,13 @@ impl DoubleDecomposerParams
type Base = DecompostionLogBase; type Base = DecompostionLogBase;
type Count = DecompositionCount; type Count = DecompositionCount;
fn new(
base: DecompostionLogBase,
count_a: DecompositionCount,
count_b: DecompositionCount,
) -> Self {
(base, (count_a, count_b))
}
// fn new(
// base: DecompostionLogBase,
// count_a: DecompositionCount,
// count_b: DecompositionCount,
// ) -> Self {
// (base, (count_a, count_b))
// }
fn decomposition_base(&self) -> Self::Base { fn decomposition_base(&self) -> Self::Base {
self.0 self.0
@ -56,9 +63,9 @@ impl SingleDecomposerParams for (DecompostionLogBase, DecompositionCount) {
type Base = DecompostionLogBase; type Base = DecompostionLogBase;
type Count = DecompositionCount; type Count = DecompositionCount;
fn new(base: DecompostionLogBase, count: DecompositionCount) -> Self {
(base, count)
}
// fn new(base: DecompostionLogBase, count: DecompositionCount) -> Self {
// (base, count)
// }
fn decomposition_base(&self) -> Self::Base { fn decomposition_base(&self) -> Self::Base {
self.0 self.0
@ -132,11 +139,11 @@ impl BoolParameters {
pub(crate) fn rlwe_by_rgsw_decomposition_params( pub(crate) fn rlwe_by_rgsw_decomposition_params(
&self, &self,
) -> (
) -> &(
DecompostionLogBase, DecompostionLogBase,
(DecompositionCount, DecompositionCount), (DecompositionCount, DecompositionCount),
) { ) {
self.rlrg_decomposer_params
&self.rlrg_decomposer_params
} }
pub(crate) fn rgsw_by_rgsw_decomposition_params( pub(crate) fn rgsw_by_rgsw_decomposition_params(
@ -167,6 +174,10 @@ impl BoolParameters {
params.1 params.1
} }
pub(crate) fn auto_decomposition_param(&self) -> &(DecompostionLogBase, DecompositionCount) {
&self.auto_decomposer_params
}
pub(crate) fn auto_decomposition_base(&self) -> DecompostionLogBase { pub(crate) fn auto_decomposition_base(&self) -> DecompostionLogBase {
self.auto_decomposer_params.decomposition_base() self.auto_decomposer_params.decomposition_base()
} }
@ -311,6 +322,7 @@ impl AsRef for DecompositionCount {
&self.0 &self.0
} }
} }
#[derive(Clone, Copy, PartialEq)] #[derive(Clone, Copy, PartialEq)]
pub(crate) struct LweDimension(pub(crate) usize); pub(crate) struct LweDimension(pub(crate) usize);
#[derive(Clone, Copy, PartialEq)] #[derive(Clone, Copy, PartialEq)]

+ 9
- 8
src/bool/print_noise.rs

@ -128,7 +128,7 @@ where
rlwe_modop.elwise_neg_mut(neg_s_eval.as_mut()); rlwe_modop.elwise_neg_mut(neg_s_eval.as_mut());
rlwe_nttop.forward(neg_s_eval.as_mut()); rlwe_nttop.forward(neg_s_eval.as_mut());
for j in 0..rlwe_x_rgsw_decomposer.a().decomposition_count() {
for j in 0..rlwe_x_rgsw_decomposer.a().decomposition_count().0 {
// RLWE(B^{j} * -s[X]*X^{s_lwe[i]}) // RLWE(B^{j} * -s[X]*X^{s_lwe[i]})
// -s[X]*X^{s_lwe[i]}*B_j // -s[X]*X^{s_lwe[i]}*B_j
@ -144,7 +144,7 @@ where
.get_row_mut(0) .get_row_mut(0)
.copy_from_slice(rgsw_ct_i.get_row_slice(j)); .copy_from_slice(rgsw_ct_i.get_row_slice(j));
rlwe_ct.get_row_mut(1).copy_from_slice( rlwe_ct.get_row_mut(1).copy_from_slice(
rgsw_ct_i.get_row_slice(j + rlwe_x_rgsw_decomposer.a().decomposition_count()),
rgsw_ct_i.get_row_slice(j + rlwe_x_rgsw_decomposer.a().decomposition_count().0),
); );
// RGSW ciphertexts are in eval domain. We put RLWE ciphertexts back in // RGSW ciphertexts are in eval domain. We put RLWE ciphertexts back in
// coefficient domain // coefficient domain
@ -170,7 +170,7 @@ where
} }
// RLWE'(m) // RLWE'(m)
for j in 0..rlwe_x_rgsw_decomposer.b().decomposition_count() {
for j in 0..rlwe_x_rgsw_decomposer.b().decomposition_count().0 {
// RLWE(B^{j} * X^{s_lwe[i]}) // RLWE(B^{j} * X^{s_lwe[i]})
// X^{s_lwe[i]}*B_j // X^{s_lwe[i]}*B_j
@ -180,14 +180,15 @@ where
// RLWE(X^{s_lwe[i]}*B_j) // RLWE(X^{s_lwe[i]}*B_j)
let mut rlwe_ct = M::zeros(2, rlwe_n); let mut rlwe_ct = M::zeros(2, rlwe_n);
rlwe_ct.get_row_mut(0).copy_from_slice( rlwe_ct.get_row_mut(0).copy_from_slice(
rgsw_ct_i
.get_row_slice(j + (2 * rlwe_x_rgsw_decomposer.a().decomposition_count())),
rgsw_ct_i.get_row_slice(
j + (2 * rlwe_x_rgsw_decomposer.a().decomposition_count().0),
),
); );
rlwe_ct rlwe_ct
.get_row_mut(1) .get_row_mut(1)
.copy_from_slice(rgsw_ct_i.get_row_slice( .copy_from_slice(rgsw_ct_i.get_row_slice(
j + (2 * rlwe_x_rgsw_decomposer.a().decomposition_count())
+ rlwe_x_rgsw_decomposer.b().decomposition_count(),
j + (2 * rlwe_x_rgsw_decomposer.a().decomposition_count().0)
+ rlwe_x_rgsw_decomposer.b().decomposition_count().0,
)); ));
rlwe_ct rlwe_ct
.iter_rows_mut() .iter_rows_mut()
@ -290,7 +291,7 @@ where
&mut RlweCiphertextMutRef::new(rlwe.as_mut()), &mut RlweCiphertextMutRef::new(rlwe.as_mut()),
&RlweKskRef::new( &RlweKskRef::new(
server_key.galois_key_for_auto(*k).as_ref(), server_key.galois_key_for_auto(*k).as_ref(),
auto_decomposer.decomposition_count(),
auto_decomposer.decomposition_count().0,
), ),
&mut scratch_matrix_ref, &mut scratch_matrix_ref,
&auto_index_map, &auto_index_map,

+ 50
- 5
src/decomposer.rs

@ -1,8 +1,13 @@
use itertools::{izip, Itertools};
use itertools::{assert_equal, izip, Itertools};
use num_traits::{FromPrimitive, PrimInt, ToPrimitive, WrappingAdd, WrappingSub}; use num_traits::{FromPrimitive, PrimInt, ToPrimitive, WrappingAdd, WrappingSub};
use std::fmt::{Debug, Display}; use std::fmt::{Debug, Display};
use crate::backend::ArithmeticOps;
use crate::{
backend::ArithmeticOps,
parameters::{
DecompositionCount, DecompostionLogBase, DoubleDecomposerParams, SingleDecomposerParams,
},
};
fn gadget_vector<T: PrimInt>(logq: usize, logb: usize, d: usize) -> Vec<T> { fn gadget_vector<T: PrimInt>(logq: usize, logb: usize, d: usize) -> Vec<T> {
assert!(logq >= (logb * d)); assert!(logq >= (logb * d));
@ -38,6 +43,41 @@ where
} }
} }
impl<D> DoubleDecomposerParams for D
where
D: RlweDecomposer,
{
type Base = DecompostionLogBase;
type Count = DecompositionCount;
fn decomposition_base(&self) -> Self::Base {
assert!(
Decomposer::decomposition_base(self.a()) == Decomposer::decomposition_base(self.b())
);
Decomposer::decomposition_base(self.a())
}
fn decomposition_count_a(&self) -> Self::Count {
Decomposer::decomposition_count(self.a())
}
fn decomposition_count_b(&self) -> Self::Count {
Decomposer::decomposition_count(self.b())
}
}
impl<D> SingleDecomposerParams for D
where
D: Decomposer,
{
type Base = DecompostionLogBase;
type Count = DecompositionCount;
fn decomposition_base(&self) -> Self::Base {
Decomposer::decomposition_base(self)
}
fn decomposition_count(&self) -> Self::Count {
Decomposer::decomposition_count(self)
}
}
pub trait Decomposer { pub trait Decomposer {
type Element; type Element;
type Iter: Iterator<Item = Self::Element>; type Iter: Iterator<Item = Self::Element>;
@ -45,7 +85,8 @@ pub trait Decomposer {
fn decompose_to_vec(&self, v: &Self::Element) -> Vec<Self::Element>; fn decompose_to_vec(&self, v: &Self::Element) -> Vec<Self::Element>;
fn decompose_iter(&self, v: &Self::Element) -> Self::Iter; fn decompose_iter(&self, v: &Self::Element) -> Self::Iter;
fn decomposition_count(&self) -> usize;
fn decomposition_count(&self) -> DecompositionCount;
fn decomposition_base(&self) -> DecompostionLogBase;
fn gadget_vector(&self) -> Vec<Self::Element>; fn gadget_vector(&self) -> Vec<Self::Element>;
} }
@ -169,8 +210,12 @@ impl<
return out; return out;
} }
fn decomposition_count(&self) -> usize {
self.d
fn decomposition_count(&self) -> DecompositionCount {
DecompositionCount(self.d)
}
fn decomposition_base(&self) -> DecompostionLogBase {
DecompostionLogBase(self.logb)
} }
fn decompose_iter(&self, value: &T) -> DecomposerIter<T> { fn decompose_iter(&self, value: &T) -> DecomposerIter<T> {

+ 1
- 1
src/lwe.rs

@ -24,7 +24,7 @@ pub(crate) fn lwe_key_switch<
decomposer: &D, decomposer: &D,
) { ) {
assert!( assert!(
lwe_ksk.dimension().0 == ((lwe_in.as_ref().len() - 1) * decomposer.decomposition_count())
lwe_ksk.dimension().0 == ((lwe_in.as_ref().len() - 1) * decomposer.decomposition_count().0)
); );
assert!(lwe_out.as_ref().len() == lwe_ksk.dimension().1); assert!(lwe_out.as_ref().len() == lwe_ksk.dimension().1);

+ 3
- 3
src/pbs.rs

@ -244,9 +244,9 @@ fn blind_rotation<
let mut is_trivial = true; let mut is_trivial = true;
let mut scratch_matrix = RuntimeScratchMutRef::new(scratch_matrix.as_mut()); let mut scratch_matrix = RuntimeScratchMutRef::new(scratch_matrix.as_mut());
let mut rlwe = RlweCiphertextMutRef::new(trivial_rlwe_test_poly.as_mut()); let mut rlwe = RlweCiphertextMutRef::new(trivial_rlwe_test_poly.as_mut());
let d_a = rlwe_rgsw_decomposer.a().decomposition_count();
let d_b = rlwe_rgsw_decomposer.b().decomposition_count();
let d_auto = auto_decomposer.decomposition_count();
let d_a = rlwe_rgsw_decomposer.a().decomposition_count().0;
let d_b = rlwe_rgsw_decomposer.b().decomposition_count().0;
let d_auto = auto_decomposer.decomposition_count().0;
let q_by_4 = q >> 2; let q_by_4 = q >> 2;
let mut count = 0; let mut count = 0;

+ 20
- 19
src/rgsw/mod.rs

@ -54,7 +54,7 @@ pub(crate) mod tests {
modulus: Mod, modulus: Mod,
) -> Self { ) -> Self {
SeededAutoKey { SeededAutoKey {
data: M::zeros(auto_decomposer.decomposition_count(), ring_size),
data: M::zeros(auto_decomposer.decomposition_count().0, ring_size),
seed, seed,
modulus, modulus,
} }
@ -125,12 +125,12 @@ pub(crate) mod tests {
) -> RgswCiphertext<M, Mod> { ) -> RgswCiphertext<M, Mod> {
RgswCiphertext { RgswCiphertext {
data: M::zeros( data: M::zeros(
decomposer.a().decomposition_count() * 2
+ decomposer.b().decomposition_count() * 2,
decomposer.a().decomposition_count().0 * 2
+ decomposer.b().decomposition_count().0 * 2,
ring_size, ring_size,
), ),
d_a: decomposer.a().decomposition_count(),
d_b: decomposer.b().decomposition_count(),
d_a: decomposer.a().decomposition_count().0,
d_b: decomposer.b().decomposition_count().0,
modulus, modulus,
} }
} }
@ -158,13 +158,14 @@ pub(crate) mod tests {
) -> SeededRgswCiphertext<M, S, Mod> { ) -> SeededRgswCiphertext<M, S, Mod> {
SeededRgswCiphertext { SeededRgswCiphertext {
data: M::zeros( data: M::zeros(
decomposer.a().decomposition_count() * 2 + decomposer.b().decomposition_count(),
decomposer.a().decomposition_count().0 * 2
+ decomposer.b().decomposition_count().0,
ring_size, ring_size,
), ),
seed, seed,
modulus, modulus,
d_a: decomposer.a().decomposition_count(),
d_b: decomposer.b().decomposition_count(),
d_a: decomposer.a().decomposition_count().0,
d_b: decomposer.b().decomposition_count().0,
} }
} }
} }
@ -613,13 +614,13 @@ pub(crate) mod tests {
&mut RlweCiphertextMutRef::new(rlwe_in_ct_shoup.as_mut()), &mut RlweCiphertextMutRef::new(rlwe_in_ct_shoup.as_mut()),
&RgswCiphertextRef::new( &RgswCiphertextRef::new(
rgsw_ct.data.as_ref(), rgsw_ct.data.as_ref(),
decomposer.a().decomposition_count(),
decomposer.b().decomposition_count(),
decomposer.a().decomposition_count().0,
decomposer.b().decomposition_count().0,
), ),
&RgswCiphertextRef::new( &RgswCiphertextRef::new(
rgsw_ct_shoup.as_ref(), rgsw_ct_shoup.as_ref(),
decomposer.a().decomposition_count(),
decomposer.b().decomposition_count(),
decomposer.a().decomposition_count().0,
decomposer.b().decomposition_count().0,
), ),
&mut RuntimeScratchMutRef::new(scratch_space.as_mut()), &mut RuntimeScratchMutRef::new(scratch_space.as_mut()),
&decomposer, &decomposer,
@ -637,8 +638,8 @@ pub(crate) mod tests {
&mut RlweCiphertextMutRef::new(rlwe_in_ct.data.as_mut()), &mut RlweCiphertextMutRef::new(rlwe_in_ct.data.as_mut()),
&RgswCiphertextRef::new( &RgswCiphertextRef::new(
rgsw_ct.data.as_ref(), rgsw_ct.data.as_ref(),
decomposer.a().decomposition_count(),
decomposer.b().decomposition_count(),
decomposer.a().decomposition_count().0,
decomposer.b().decomposition_count().0,
), ),
&mut RuntimeScratchMutRef::new(scratch_space.as_mut()), &mut RuntimeScratchMutRef::new(scratch_space.as_mut()),
&decomposer, &decomposer,
@ -760,8 +761,8 @@ pub(crate) mod tests {
let mut rlwe_m_shoup = rlwe_m.data.clone(); let mut rlwe_m_shoup = rlwe_m.data.clone();
rlwe_auto_shoup( rlwe_auto_shoup(
&mut RlweCiphertextMutRef::new(&mut rlwe_m_shoup), &mut RlweCiphertextMutRef::new(&mut rlwe_m_shoup),
&RlweKskRef::new(&auto_key.data, decomposer.decomposition_count()),
&RlweKskRef::new(&auto_key_shoup, decomposer.decomposition_count()),
&RlweKskRef::new(&auto_key.data, decomposer.decomposition_count().0),
&RlweKskRef::new(&auto_key_shoup, decomposer.decomposition_count().0),
&mut RuntimeScratchMutRef::new(&mut scratch_space), &mut RuntimeScratchMutRef::new(&mut scratch_space),
&auto_map_index, &auto_map_index,
&auto_map_sign, &auto_map_sign,
@ -777,7 +778,7 @@ pub(crate) mod tests {
{ {
rlwe_auto( rlwe_auto(
&mut RlweCiphertextMutRef::new(rlwe_m.data.as_mut()), &mut RlweCiphertextMutRef::new(rlwe_m.data.as_mut()),
&RlweKskRef::new(auto_key.data.as_ref(), decomposer.decomposition_count()),
&RlweKskRef::new(auto_key.data.as_ref(), decomposer.decomposition_count().0),
&mut RuntimeScratchMutRef::new(scratch_space.as_mut()), &mut RuntimeScratchMutRef::new(scratch_space.as_mut()),
&auto_map_index, &auto_map_index,
&auto_map_sign, &auto_map_sign,
@ -925,8 +926,8 @@ pub(crate) mod tests {
DefaultDecomposer::new(q, logb, d_rgsw), DefaultDecomposer::new(q, logb, d_rgsw),
); );
let d_a = decomposer.a().decomposition_count();
let d_b = decomposer.b().decomposition_count();
let d_a = decomposer.a().decomposition_count().0;
let d_b = decomposer.b().decomposition_count().0;
let mul_mod = |a: &u64, b: &u64| ((*a as u128 * *b as u128) % q as u128) as u64; let mul_mod = |a: &u64, b: &u64| ((*a as u128 * *b as u128) % q as u128) as u64;

+ 38
- 32
src/rgsw/runtime.rs

@ -5,6 +5,7 @@ use crate::{
backend::{ArithmeticOps, GetModulus, ShoupMatrixFMA, VectorOps}, backend::{ArithmeticOps, GetModulus, ShoupMatrixFMA, VectorOps},
decomposer::{Decomposer, RlweDecomposer}, decomposer::{Decomposer, RlweDecomposer},
ntt::Ntt, ntt::Ntt,
parameters::{DecompositionCount, DoubleDecomposerParams, SingleDecomposerParams},
Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut,
}; };
@ -296,12 +297,12 @@ where
rgsw1_decoposer: &D, rgsw1_decoposer: &D,
) -> (&mut [Self::R], &mut [Self::R]) { ) -> (&mut [Self::R], &mut [Self::R]) {
let (decomp_poly, other) = self.data.split_at_mut(std::cmp::max( let (decomp_poly, other) = self.data.split_at_mut(std::cmp::max(
rgsw1_decoposer.a().decomposition_count(),
rgsw1_decoposer.b().decomposition_count(),
rgsw1_decoposer.decomposition_count_a().0,
rgsw1_decoposer.decomposition_count_b().0,
)); ));
let (rgsw, _) = other.split_at_mut( let (rgsw, _) = other.split_at_mut(
rgsw0_decoposer.a().decomposition_count() * 2
+ rgsw0_decoposer.b().decomposition_count() * 2,
rgsw0_decoposer.decomposition_count_a().0 * 2
+ rgsw0_decoposer.decomposition_count_b().0 * 2,
); );
// zero fill rgsw0 // zero fill rgsw0
@ -316,8 +317,8 @@ where
decomposer: &D, decomposer: &D,
) -> (&mut [Self::R], &mut [Self::R]) { ) -> (&mut [Self::R], &mut [Self::R]) {
let (decomp_poly, other) = self.data.split_at_mut(std::cmp::max( let (decomp_poly, other) = self.data.split_at_mut(std::cmp::max(
decomposer.a().decomposition_count(),
decomposer.b().decomposition_count(),
decomposer.decomposition_count_a().0,
decomposer.decomposition_count_b().0,
)); ));
let (rlwe, _) = other.split_at_mut(2); let (rlwe, _) = other.split_at_mut(2);
@ -331,27 +332,32 @@ where
} }
/// Returns no. of rows in scratch space for RGSW0 x RGSW1 product /// Returns no. of rows in scratch space for RGSW0 x RGSW1 product
pub(crate) fn rgsw_x_rgsw_scratch_rows<D: RlweDecomposer>(
rgsw0_decomposer: &D,
rgsw1_decomposer: &D,
pub(crate) fn rgsw_x_rgsw_scratch_rows<D: DoubleDecomposerParams<Count = DecompositionCount>>(
rgsw0_decomposer_param: &D,
rgsw1_decomposer_param: &D,
) -> usize { ) -> usize {
std::cmp::max( std::cmp::max(
rgsw1_decomposer.a().decomposition_count(),
rgsw1_decomposer.b().decomposition_count(),
) + rgsw0_decomposer.a().decomposition_count() * 2
+ rgsw0_decomposer.b().decomposition_count() * 2
rgsw1_decomposer_param.decomposition_count_a().0,
rgsw1_decomposer_param.decomposition_count_b().0,
) + rgsw0_decomposer_param.decomposition_count_a().0 * 2
+ rgsw0_decomposer_param.decomposition_count_b().0 * 2
} }
/// Returns no. of rows in scratch space for RLWE x RGSW product /// Returns no. of rows in scratch space for RLWE x RGSW product
pub(crate) fn rlwe_x_rgsw_scratch_rows<D: RlweDecomposer>(rgsw_decomposer: &D) -> usize {
pub(crate) fn rlwe_x_rgsw_scratch_rows<D: DoubleDecomposerParams<Count = DecompositionCount>>(
rgsw_decomposer_param: &D,
) -> usize {
std::cmp::max( std::cmp::max(
rgsw_decomposer.a().decomposition_count(),
rgsw_decomposer.b().decomposition_count(),
rgsw_decomposer_param.decomposition_count_a().0,
rgsw_decomposer_param.decomposition_count_b().0,
) + 2 ) + 2
} }
/// Returns no. of rows in scratch space for RLWE auto /// Returns no. of rows in scratch space for RLWE auto
pub(crate) fn rlwe_auto_scratch_rows<D: Decomposer>(decomposer: &D) -> usize {
decomposer.decomposition_count() + 2
pub(crate) fn rlwe_auto_scratch_rows<D: SingleDecomposerParams<Count = DecompositionCount>>(
param: &D,
) -> usize {
param.decomposition_count().0 + 2
} }
pub(crate) fn poly_fma_routine<R: RowMut, ModOp: VectorOps<Element = R::Element>>( pub(crate) fn poly_fma_routine<R: RowMut, ModOp: VectorOps<Element = R::Element>>(
@ -430,7 +436,7 @@ pub(crate) fn rlwe_auto<
if !is_trivial { if !is_trivial {
let (decomp_poly_scratch, tmp_rlwe) = scratch_matrix let (decomp_poly_scratch, tmp_rlwe) = scratch_matrix
.scratch_for_rlwe_auto_and_zero_rlwe_space(decomposer.decomposition_count());
.scratch_for_rlwe_auto_and_zero_rlwe_space(decomposer.decomposition_count().0);
let mut tmp_rlwe = RlweCiphertextMutRef::new(tmp_rlwe); let mut tmp_rlwe = RlweCiphertextMutRef::new(tmp_rlwe);
// send a(X) -> a(X^k) and decompose a(X^k) // send a(X) -> a(X^k) and decompose a(X^k)
@ -551,7 +557,7 @@ pub(crate) fn rlwe_auto_shoup<
if !is_trivial { if !is_trivial {
let (decomp_poly_scratch, tmp_rlwe) = scratch_matrix let (decomp_poly_scratch, tmp_rlwe) = scratch_matrix
.scratch_for_rlwe_auto_and_zero_rlwe_space(decomposer.decomposition_count());
.scratch_for_rlwe_auto_and_zero_rlwe_space(decomposer.decomposition_count().0);
let mut tmp_rlwe = RlweCiphertextMutRef::new(tmp_rlwe); let mut tmp_rlwe = RlweCiphertextMutRef::new(tmp_rlwe);
// send a(X) -> a(X^k) and decompose a(X^k) // send a(X) -> a(X^k) and decompose a(X^k)
@ -662,8 +668,8 @@ pub(crate) fn rlwe_by_rgsw<
{ {
let decomposer_a = decomposer.a(); let decomposer_a = decomposer.a();
let decomposer_b = decomposer.b(); let decomposer_b = decomposer.b();
let d_a = decomposer_a.decomposition_count();
let d_b = decomposer_b.decomposition_count();
let d_a = decomposer.decomposition_count_a().0;
let d_b = decomposer.decomposition_count_b().0;
let ((rlwe_dash_nsm_parta, rlwe_dash_nsm_partb), (rlwe_dash_m_parta, rlwe_dash_m_partb)) = let ((rlwe_dash_nsm_parta, rlwe_dash_nsm_partb), (rlwe_dash_m_parta, rlwe_dash_m_partb)) =
rgsw_in.split(); rgsw_in.split();
@ -766,8 +772,8 @@ pub(crate) fn rlwe_by_rgsw_shoup<
{ {
let decomposer_a = decomposer.a(); let decomposer_a = decomposer.a();
let decomposer_b = decomposer.b(); let decomposer_b = decomposer.b();
let d_a = decomposer_a.decomposition_count();
let d_b = decomposer_b.decomposition_count();
let d_a = decomposer.decomposition_count_a().0;
let d_b = decomposer.decomposition_count_b().0;
let ((rlwe_dash_nsm_parta, rlwe_dash_nsm_partb), (rlwe_dash_m_parta, rlwe_dash_m_partb)) = let ((rlwe_dash_nsm_parta, rlwe_dash_nsm_partb), (rlwe_dash_m_parta, rlwe_dash_m_partb)) =
rgsw_in.split(); rgsw_in.split();
@ -900,8 +906,8 @@ pub(crate) fn rgsw_by_rgsw_inplace<
let mut rgsw_space = RgswCiphertextMutRef::new( let mut rgsw_space = RgswCiphertextMutRef::new(
rgsw_space, rgsw_space,
rgsw0_decomposer.a().decomposition_count(),
rgsw0_decomposer.b().decomposition_count(),
rgsw0_decomposer.decomposition_count_a().0,
rgsw0_decomposer.decomposition_count_b().0,
); );
let ( let (
(rlwe_dash_space_nsm_parta, rlwe_dash_space_nsm_partb), (rlwe_dash_space_nsm_parta, rlwe_dash_space_nsm_partb),
@ -927,7 +933,7 @@ pub(crate) fn rgsw_by_rgsw_inplace<
// Part A: Decomp<RLWE(m0)[A]> \cdot RLWE'(-sm1) // Part A: Decomp<RLWE(m0)[A]> \cdot RLWE'(-sm1)
{ {
let decomp_r_parta = &mut decomp_r_space[..rgsw1_decomposer.a().decomposition_count()];
let decomp_r_parta = &mut decomp_r_space[..rgsw1_decomposer.decomposition_count_a().0];
decompose_r( decompose_r(
rlwe_a.as_ref(), rlwe_a.as_ref(),
decomp_r_parta.as_mut(), decomp_r_parta.as_mut(),
@ -952,7 +958,7 @@ pub(crate) fn rgsw_by_rgsw_inplace<
// Part B: Decompose<RLWE(m0)[B]> \cdot RLWE'(m1) // Part B: Decompose<RLWE(m0)[B]> \cdot RLWE'(m1)
{ {
let decomp_r_partb = &mut decomp_r_space[..rgsw1_decomposer.b().decomposition_count()];
let decomp_r_partb = &mut decomp_r_space[..rgsw1_decomposer.decomposition_count_b().0];
decompose_r( decompose_r(
rlwe_b.as_ref(), rlwe_b.as_ref(),
decomp_r_partb.as_mut(), decomp_r_partb.as_mut(),
@ -1011,11 +1017,11 @@ where
{ {
let ring_size = rlwe_in.dimension().1; let ring_size = rlwe_in.dimension().1;
assert!(rlwe_in.dimension().0 == 2); assert!(rlwe_in.dimension().0 == 2);
assert!(ksk.dimension() == (decomposer.decomposition_count() * 2, ring_size));
assert!(ksk.dimension() == (decomposer.decomposition_count().0 * 2, ring_size));
let mut rlwe_out = M::zeros(2, ring_size); let mut rlwe_out = M::zeros(2, ring_size);
let mut tmp = M::zeros(decomposer.decomposition_count(), ring_size);
let mut tmp = M::zeros(decomposer.decomposition_count().0, ring_size);
let mut tmp_row = M::R::zeros(ring_size); let mut tmp_row = M::R::zeros(ring_size);
// key switch RLWE part -A // key switch RLWE part -A
@ -1028,9 +1034,9 @@ where
.for_each(|r| ntt_op.forward_lazy(r.as_mut())); .for_each(|r| ntt_op.forward_lazy(r.as_mut()));
// RLWE_s(-A u) = B' + B, A' = (decomp(-A) * Ksk(u -> s)) + (B, 0) // RLWE_s(-A u) = B' + B, A' = (decomp(-A) * Ksk(u -> s)) + (B, 0)
let (ksk_part_a, ksk_part_b) = ksk.split_at_row(decomposer.decomposition_count());
let (ksk_part_a, ksk_part_b) = ksk.split_at_row(decomposer.decomposition_count().0);
let (ksk_part_a_shoup, ksk_part_b_shoup) = let (ksk_part_a_shoup, ksk_part_b_shoup) =
ksk_shoup.split_at_row(decomposer.decomposition_count());
ksk_shoup.split_at_row(decomposer.decomposition_count().0);
// Part A' // Part A'
mod_op.shoup_matrix_fma( mod_op.shoup_matrix_fma(
rlwe_out.get_row_mut(0), rlwe_out.get_row_mut(0),

Loading…
Cancel
Save