add differing base feature for RLWExRGSw and RGSWxRGSW for interactive mpc

This commit is contained in:
Janmajaya Mall
2024-06-24 15:26:53 +07:00
parent 5d5100e6d1
commit 1d7099600a
8 changed files with 382 additions and 211 deletions

View File

@@ -1114,6 +1114,8 @@ pub(crate) mod tests {
);
rgsw_by_rgsw_inplace(
&mut rgsw_carrym,
decomposer.a().decomposition_count(),
decomposer.b().decomposition_count(),
&rgsw_m.data,
&decomposer,
&mut scratch_matrix,

View File

@@ -546,14 +546,19 @@ pub(crate) fn rlwe_by_rgsw_shoup<
/// - rgsw_1_eval: RGSW(m1) in Evaluation domain
/// - scratch_matrix_d_plus_rgsw_by_ring: scratch space matrix with rows
/// (max(d_a, d_b) + d_a*2+d_b*2) and columns ring_size
///
/// ## Note:
/// - We treat RGSW x RGSW as multiple RLWE x RGSW multiplications. .
pub(crate) fn rgsw_by_rgsw_inplace<
Mmut: MatrixMut,
D: RlweDecomposer<Element = Mmut::MatElement>,
ModOp: VectorOps<Element = Mmut::MatElement>,
NttOp: Ntt<Element = Mmut::MatElement>,
>(
rgsw_0: &mut Mmut,
rgsw_1_eval: &Mmut,
rgsw0: &mut Mmut,
rgsw0_da: usize,
rgsw0_db: usize,
rgsw1_eval: &Mmut,
decomposer: &D,
scratch_matrix: &mut Mmut,
ntt_op: &NttOp,
@@ -567,11 +572,12 @@ pub(crate) fn rgsw_by_rgsw_inplace<
let d_a = decomposer_a.decomposition_count();
let d_b = decomposer_b.decomposition_count();
let max_d = std::cmp::max(d_a, d_b);
let rgsw_rows = d_a * 2 + d_b * 2;
assert!(rgsw_0.dimension().0 == rgsw_rows);
let ring_size = rgsw_0.dimension().1;
assert!(rgsw_1_eval.dimension() == (rgsw_rows, ring_size));
assert!(scratch_matrix.fits(max_d + rgsw_rows, ring_size));
let rgsw1_rows = d_a * 2 + d_b * 2;
let rgsw0_rows = rgsw0_da * 2 + rgsw0_db * 2;
let ring_size = rgsw0.dimension().1;
assert!(rgsw0.dimension().0 == rgsw0_rows);
assert!(rgsw1_eval.dimension() == (rgsw1_rows, ring_size));
assert!(scratch_matrix.fits(max_d + rgsw0_rows, ring_size));
let (decomp_r_space, rgsw_space) = scratch_matrix.split_at_row_mut(max_d);
@@ -579,18 +585,25 @@ pub(crate) fn rgsw_by_rgsw_inplace<
rgsw_space
.iter_mut()
.for_each(|ri| ri.as_mut().fill(Mmut::MatElement::zero()));
let (rlwe_dash_space_nsm, rlwe_dash_space_m) = rgsw_space.split_at_mut(d_a * 2);
let (rlwe_dash_space_nsm, rlwe_dash_space_m) = rgsw_space.split_at_mut(rgsw0_da * 2);
let (rlwe_dash_space_nsm_parta, rlwe_dash_space_nsm_partb) =
rlwe_dash_space_nsm.split_at_mut(d_a);
let (rlwe_dash_space_m_parta, rlwe_dash_space_m_partb) = rlwe_dash_space_m.split_at_mut(d_b);
rlwe_dash_space_nsm.split_at_mut(rgsw0_da);
let (rlwe_dash_space_m_parta, rlwe_dash_space_m_partb) =
rlwe_dash_space_m.split_at_mut(rgsw0_db);
let (rgsw0_nsm, rgsw0_m) = rgsw_0.split_at_row(d_a * 2);
let (rgsw1_nsm, rgsw1_m) = rgsw_1_eval.split_at_row(d_a * 2);
let (rgsw0_nsm, rgsw0_m) = rgsw0.split_at_row(rgsw0_da * 2);
let (rgsw1_nsm, rgsw1_m) = rgsw1_eval.split_at_row(d_a * 2);
// RGSW x RGSW
izip!(
rgsw0_nsm.iter().take(d_a).chain(rgsw0_m.iter().take(d_b)),
rgsw0_nsm.iter().skip(d_a).chain(rgsw0_m.iter().skip(d_b)),
rgsw0_nsm
.iter()
.take(rgsw0_da)
.chain(rgsw0_m.iter().take(rgsw0_db)),
rgsw0_nsm
.iter()
.skip(rgsw0_da)
.chain(rgsw0_m.iter().skip(rgsw0_db)),
rlwe_dash_space_nsm_parta
.iter_mut()
.chain(rlwe_dash_space_m_parta.iter_mut()),
@@ -599,7 +612,9 @@ pub(crate) fn rgsw_by_rgsw_inplace<
.chain(rlwe_dash_space_m_partb.iter_mut()),
)
.for_each(|(rlwe_a, rlwe_b, rlwe_out_a, rlwe_out_b)| {
// Part A
// RLWE(m0) x RGSW(m1)
// Part A: Decomp<RLWE(m0)[A]> \cdot RLWE'(-sm1)
decompose_r(rlwe_a.as_ref(), decomp_r_space.as_mut(), decomposer_a);
decomp_r_space
.iter_mut()
@@ -618,7 +633,7 @@ pub(crate) fn rgsw_by_rgsw_inplace<
mod_op,
);
// Part B
// Part B: Decompose<RLWE(m0)[B]> \cdot RLWE'(m1)
decompose_r(rlwe_b.as_ref(), decomp_r_space.as_mut(), decomposer_b);
decomp_r_space
.iter_mut()
@@ -639,11 +654,11 @@ pub(crate) fn rgsw_by_rgsw_inplace<
});
// copy over RGSW(m0m1) into RGSW(m0)
izip!(rgsw_0.iter_rows_mut(), rgsw_space.iter())
izip!(rgsw0.iter_rows_mut(), rgsw_space.iter())
.for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref()));
// send back to coefficient domain
rgsw_0
rgsw0
.iter_rows_mut()
.for_each(|ri| ntt_op.backward(ri.as_mut()));
}