implement DoubleDecomposer for Rlwe Decomposer

This commit is contained in:
Janmajaya Mall
2024-06-30 11:17:18 +05:30
parent f5f3700ea4
commit 1ff98541c8
8 changed files with 251 additions and 167 deletions

View File

@@ -54,7 +54,7 @@ pub(crate) mod tests {
modulus: Mod,
) -> Self {
SeededAutoKey {
data: M::zeros(auto_decomposer.decomposition_count(), ring_size),
data: M::zeros(auto_decomposer.decomposition_count().0, ring_size),
seed,
modulus,
}
@@ -125,12 +125,12 @@ pub(crate) mod tests {
) -> RgswCiphertext<M, Mod> {
RgswCiphertext {
data: M::zeros(
decomposer.a().decomposition_count() * 2
+ decomposer.b().decomposition_count() * 2,
decomposer.a().decomposition_count().0 * 2
+ decomposer.b().decomposition_count().0 * 2,
ring_size,
),
d_a: decomposer.a().decomposition_count(),
d_b: decomposer.b().decomposition_count(),
d_a: decomposer.a().decomposition_count().0,
d_b: decomposer.b().decomposition_count().0,
modulus,
}
}
@@ -158,13 +158,14 @@ pub(crate) mod tests {
) -> SeededRgswCiphertext<M, S, Mod> {
SeededRgswCiphertext {
data: M::zeros(
decomposer.a().decomposition_count() * 2 + decomposer.b().decomposition_count(),
decomposer.a().decomposition_count().0 * 2
+ decomposer.b().decomposition_count().0,
ring_size,
),
seed,
modulus,
d_a: decomposer.a().decomposition_count(),
d_b: decomposer.b().decomposition_count(),
d_a: decomposer.a().decomposition_count().0,
d_b: decomposer.b().decomposition_count().0,
}
}
}
@@ -613,13 +614,13 @@ pub(crate) mod tests {
&mut RlweCiphertextMutRef::new(rlwe_in_ct_shoup.as_mut()),
&RgswCiphertextRef::new(
rgsw_ct.data.as_ref(),
decomposer.a().decomposition_count(),
decomposer.b().decomposition_count(),
decomposer.a().decomposition_count().0,
decomposer.b().decomposition_count().0,
),
&RgswCiphertextRef::new(
rgsw_ct_shoup.as_ref(),
decomposer.a().decomposition_count(),
decomposer.b().decomposition_count(),
decomposer.a().decomposition_count().0,
decomposer.b().decomposition_count().0,
),
&mut RuntimeScratchMutRef::new(scratch_space.as_mut()),
&decomposer,
@@ -637,8 +638,8 @@ pub(crate) mod tests {
&mut RlweCiphertextMutRef::new(rlwe_in_ct.data.as_mut()),
&RgswCiphertextRef::new(
rgsw_ct.data.as_ref(),
decomposer.a().decomposition_count(),
decomposer.b().decomposition_count(),
decomposer.a().decomposition_count().0,
decomposer.b().decomposition_count().0,
),
&mut RuntimeScratchMutRef::new(scratch_space.as_mut()),
&decomposer,
@@ -760,8 +761,8 @@ pub(crate) mod tests {
let mut rlwe_m_shoup = rlwe_m.data.clone();
rlwe_auto_shoup(
&mut RlweCiphertextMutRef::new(&mut rlwe_m_shoup),
&RlweKskRef::new(&auto_key.data, decomposer.decomposition_count()),
&RlweKskRef::new(&auto_key_shoup, decomposer.decomposition_count()),
&RlweKskRef::new(&auto_key.data, decomposer.decomposition_count().0),
&RlweKskRef::new(&auto_key_shoup, decomposer.decomposition_count().0),
&mut RuntimeScratchMutRef::new(&mut scratch_space),
&auto_map_index,
&auto_map_sign,
@@ -777,7 +778,7 @@ pub(crate) mod tests {
{
rlwe_auto(
&mut RlweCiphertextMutRef::new(rlwe_m.data.as_mut()),
&RlweKskRef::new(auto_key.data.as_ref(), decomposer.decomposition_count()),
&RlweKskRef::new(auto_key.data.as_ref(), decomposer.decomposition_count().0),
&mut RuntimeScratchMutRef::new(scratch_space.as_mut()),
&auto_map_index,
&auto_map_sign,
@@ -925,8 +926,8 @@ pub(crate) mod tests {
DefaultDecomposer::new(q, logb, d_rgsw),
);
let d_a = decomposer.a().decomposition_count();
let d_b = decomposer.b().decomposition_count();
let d_a = decomposer.a().decomposition_count().0;
let d_b = decomposer.b().decomposition_count().0;
let mul_mod = |a: &u64, b: &u64| ((*a as u128 * *b as u128) % q as u128) as u64;

View File

@@ -5,6 +5,7 @@ use crate::{
backend::{ArithmeticOps, GetModulus, ShoupMatrixFMA, VectorOps},
decomposer::{Decomposer, RlweDecomposer},
ntt::Ntt,
parameters::{DecompositionCount, DoubleDecomposerParams, SingleDecomposerParams},
Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut,
};
@@ -296,12 +297,12 @@ where
rgsw1_decoposer: &D,
) -> (&mut [Self::R], &mut [Self::R]) {
let (decomp_poly, other) = self.data.split_at_mut(std::cmp::max(
rgsw1_decoposer.a().decomposition_count(),
rgsw1_decoposer.b().decomposition_count(),
rgsw1_decoposer.decomposition_count_a().0,
rgsw1_decoposer.decomposition_count_b().0,
));
let (rgsw, _) = other.split_at_mut(
rgsw0_decoposer.a().decomposition_count() * 2
+ rgsw0_decoposer.b().decomposition_count() * 2,
rgsw0_decoposer.decomposition_count_a().0 * 2
+ rgsw0_decoposer.decomposition_count_b().0 * 2,
);
// zero fill rgsw0
@@ -316,8 +317,8 @@ where
decomposer: &D,
) -> (&mut [Self::R], &mut [Self::R]) {
let (decomp_poly, other) = self.data.split_at_mut(std::cmp::max(
decomposer.a().decomposition_count(),
decomposer.b().decomposition_count(),
decomposer.decomposition_count_a().0,
decomposer.decomposition_count_b().0,
));
let (rlwe, _) = other.split_at_mut(2);
@@ -331,27 +332,32 @@ where
}
/// Returns no. of rows in scratch space for RGSW0 x RGSW1 product
pub(crate) fn rgsw_x_rgsw_scratch_rows<D: RlweDecomposer>(
rgsw0_decomposer: &D,
rgsw1_decomposer: &D,
pub(crate) fn rgsw_x_rgsw_scratch_rows<D: DoubleDecomposerParams<Count = DecompositionCount>>(
rgsw0_decomposer_param: &D,
rgsw1_decomposer_param: &D,
) -> usize {
std::cmp::max(
rgsw1_decomposer.a().decomposition_count(),
rgsw1_decomposer.b().decomposition_count(),
) + rgsw0_decomposer.a().decomposition_count() * 2
+ rgsw0_decomposer.b().decomposition_count() * 2
rgsw1_decomposer_param.decomposition_count_a().0,
rgsw1_decomposer_param.decomposition_count_b().0,
) + rgsw0_decomposer_param.decomposition_count_a().0 * 2
+ rgsw0_decomposer_param.decomposition_count_b().0 * 2
}
/// Returns no. of rows in scratch space for RLWE x RGSW product
pub(crate) fn rlwe_x_rgsw_scratch_rows<D: RlweDecomposer>(rgsw_decomposer: &D) -> usize {
pub(crate) fn rlwe_x_rgsw_scratch_rows<D: DoubleDecomposerParams<Count = DecompositionCount>>(
rgsw_decomposer_param: &D,
) -> usize {
std::cmp::max(
rgsw_decomposer.a().decomposition_count(),
rgsw_decomposer.b().decomposition_count(),
rgsw_decomposer_param.decomposition_count_a().0,
rgsw_decomposer_param.decomposition_count_b().0,
) + 2
}
/// Returns no. of rows in scratch space for RLWE auto
pub(crate) fn rlwe_auto_scratch_rows<D: Decomposer>(decomposer: &D) -> usize {
decomposer.decomposition_count() + 2
pub(crate) fn rlwe_auto_scratch_rows<D: SingleDecomposerParams<Count = DecompositionCount>>(
param: &D,
) -> usize {
param.decomposition_count().0 + 2
}
pub(crate) fn poly_fma_routine<R: RowMut, ModOp: VectorOps<Element = R::Element>>(
@@ -430,7 +436,7 @@ pub(crate) fn rlwe_auto<
if !is_trivial {
let (decomp_poly_scratch, tmp_rlwe) = scratch_matrix
.scratch_for_rlwe_auto_and_zero_rlwe_space(decomposer.decomposition_count());
.scratch_for_rlwe_auto_and_zero_rlwe_space(decomposer.decomposition_count().0);
let mut tmp_rlwe = RlweCiphertextMutRef::new(tmp_rlwe);
// send a(X) -> a(X^k) and decompose a(X^k)
@@ -551,7 +557,7 @@ pub(crate) fn rlwe_auto_shoup<
if !is_trivial {
let (decomp_poly_scratch, tmp_rlwe) = scratch_matrix
.scratch_for_rlwe_auto_and_zero_rlwe_space(decomposer.decomposition_count());
.scratch_for_rlwe_auto_and_zero_rlwe_space(decomposer.decomposition_count().0);
let mut tmp_rlwe = RlweCiphertextMutRef::new(tmp_rlwe);
// send a(X) -> a(X^k) and decompose a(X^k)
@@ -662,8 +668,8 @@ pub(crate) fn rlwe_by_rgsw<
{
let decomposer_a = decomposer.a();
let decomposer_b = decomposer.b();
let d_a = decomposer_a.decomposition_count();
let d_b = decomposer_b.decomposition_count();
let d_a = decomposer.decomposition_count_a().0;
let d_b = decomposer.decomposition_count_b().0;
let ((rlwe_dash_nsm_parta, rlwe_dash_nsm_partb), (rlwe_dash_m_parta, rlwe_dash_m_partb)) =
rgsw_in.split();
@@ -766,8 +772,8 @@ pub(crate) fn rlwe_by_rgsw_shoup<
{
let decomposer_a = decomposer.a();
let decomposer_b = decomposer.b();
let d_a = decomposer_a.decomposition_count();
let d_b = decomposer_b.decomposition_count();
let d_a = decomposer.decomposition_count_a().0;
let d_b = decomposer.decomposition_count_b().0;
let ((rlwe_dash_nsm_parta, rlwe_dash_nsm_partb), (rlwe_dash_m_parta, rlwe_dash_m_partb)) =
rgsw_in.split();
@@ -900,8 +906,8 @@ pub(crate) fn rgsw_by_rgsw_inplace<
let mut rgsw_space = RgswCiphertextMutRef::new(
rgsw_space,
rgsw0_decomposer.a().decomposition_count(),
rgsw0_decomposer.b().decomposition_count(),
rgsw0_decomposer.decomposition_count_a().0,
rgsw0_decomposer.decomposition_count_b().0,
);
let (
(rlwe_dash_space_nsm_parta, rlwe_dash_space_nsm_partb),
@@ -927,7 +933,7 @@ pub(crate) fn rgsw_by_rgsw_inplace<
// Part A: Decomp<RLWE(m0)[A]> \cdot RLWE'(-sm1)
{
let decomp_r_parta = &mut decomp_r_space[..rgsw1_decomposer.a().decomposition_count()];
let decomp_r_parta = &mut decomp_r_space[..rgsw1_decomposer.decomposition_count_a().0];
decompose_r(
rlwe_a.as_ref(),
decomp_r_parta.as_mut(),
@@ -952,7 +958,7 @@ pub(crate) fn rgsw_by_rgsw_inplace<
// Part B: Decompose<RLWE(m0)[B]> \cdot RLWE'(m1)
{
let decomp_r_partb = &mut decomp_r_space[..rgsw1_decomposer.b().decomposition_count()];
let decomp_r_partb = &mut decomp_r_space[..rgsw1_decomposer.decomposition_count_b().0];
decompose_r(
rlwe_b.as_ref(),
decomp_r_partb.as_mut(),
@@ -1011,11 +1017,11 @@ where
{
let ring_size = rlwe_in.dimension().1;
assert!(rlwe_in.dimension().0 == 2);
assert!(ksk.dimension() == (decomposer.decomposition_count() * 2, ring_size));
assert!(ksk.dimension() == (decomposer.decomposition_count().0 * 2, ring_size));
let mut rlwe_out = M::zeros(2, ring_size);
let mut tmp = M::zeros(decomposer.decomposition_count(), ring_size);
let mut tmp = M::zeros(decomposer.decomposition_count().0, ring_size);
let mut tmp_row = M::R::zeros(ring_size);
// key switch RLWE part -A
@@ -1028,9 +1034,9 @@ where
.for_each(|r| ntt_op.forward_lazy(r.as_mut()));
// RLWE_s(-A u) = B' + B, A' = (decomp(-A) * Ksk(u -> s)) + (B, 0)
let (ksk_part_a, ksk_part_b) = ksk.split_at_row(decomposer.decomposition_count());
let (ksk_part_a, ksk_part_b) = ksk.split_at_row(decomposer.decomposition_count().0);
let (ksk_part_a_shoup, ksk_part_b_shoup) =
ksk_shoup.split_at_row(decomposer.decomposition_count());
ksk_shoup.split_at_row(decomposer.decomposition_count().0);
// Part A'
mod_op.shoup_matrix_fma(
rlwe_out.get_row_mut(0),