From 3f624f04ded28978f27f5909aad8f8ee21dc12cf Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Wed, 5 Jun 2024 18:13:53 +0530 Subject: [PATCH] minor fixes --- src/rgsw.rs | 110 +++++++--------------------------------------------- 1 file changed, 13 insertions(+), 97 deletions(-) diff --git a/src/rgsw.rs b/src/rgsw.rs index 0e09eda..9f2812e 100644 --- a/src/rgsw.rs +++ b/src/rgsw.rs @@ -1,6 +1,7 @@ use std::{ clone, fmt::Debug, + iter, marker::PhantomData, ops::{Div, Neg, Sub}, }; @@ -555,7 +556,16 @@ pub(crate) fn galois_auto< assert!(rlwe_in.dimension().0 == 2); assert!(scratch_matrix.fits(d + 2, ring_size)); - let (scratch_matrix_d_ring, tmp_rlwe_out) = scratch_matrix.split_at_row_mut(d); + // scratch matrix is guaranteed to have at-least d+2 rows but can have more than + // d+2 rows. We require to split them into sub-matrices of exact sizes one with + // d rows for storing decomposed polynomial and second with 2 rows to act + // tomperary space for RLWE ciphertext. Exact sizes is necessary to avoid any + // irrelevant extra FMA or NTT ops. + let (scratch_matrix_d_ring, other_half) = scratch_matrix.split_at_row_mut(d); + let (tmp_rlwe_out, _) = other_half.split_at_mut(2); + + assert!(tmp_rlwe_out.len() == 2); + assert!(scratch_matrix_d_ring.len() == d); if !rlwe_in.is_trivial() { tmp_rlwe_out.iter_mut().for_each(|r| { @@ -650,100 +660,6 @@ pub(crate) fn galois_auto< .copy_from_slice(tmp_rlwe_out[1].as_ref()); } -/// Returns RLWE(m0m1) = RLWE(m0) x RGSW(m1). Mutates rlwe_in inplace to equal -/// RLWE(m0m1) -/// -/// - rlwe_in: is RLWE(m0) with polynomials in coefficient domain -/// - rgsw_in: is RGSW(m1) with polynomials in evaluation domain -/// - scratch_matrix_d_ring: is a matrix of dimension (d_rgsw, ring_size) used -/// as scratch space to store decomposed Ring elements temporarily -pub(crate) fn less1_rlwe_by_rgsw< - Mmut: MatrixMut, - MT: Matrix + MatrixMut + IsTrivial, - D: Decomposer, - ModOp: VectorOps, - NttOp: Ntt, ->( - rlwe_in: &mut MT, - rgsw_in: &Mmut, - scratch_matrix_dplus2_ring: &mut Mmut, - decomposer: &D, - ntt_op: &NttOp, - mod_op: &ModOp, - skip0: usize, - skip1: usize, -) where - Mmut::MatElement: Copy + Zero, - ::R: RowMut, - ::R: RowMut, -{ - let d_rgsw = decomposer.decomposition_count(); - assert!(scratch_matrix_dplus2_ring.dimension() == (d_rgsw + 2, rlwe_in.dimension().1)); - assert!(rgsw_in.dimension() == (d_rgsw * 4, rlwe_in.dimension().1)); - - // decomposed RLWE x RGSW - let (rlwe_dash_nsm, rlwe_dash_m) = rgsw_in.split_at_row(d_rgsw * 2); - let (scratch_matrix_d_ring, scratch_rlwe_out) = - scratch_matrix_dplus2_ring.split_at_row_mut(d_rgsw); - scratch_rlwe_out[0].as_mut().fill(Mmut::MatElement::zero()); - scratch_rlwe_out[1].as_mut().fill(Mmut::MatElement::zero()); - // RLWE_in = a_in, b_in; RLWE_out = a_out, b_out - if !rlwe_in.is_trivial() { - // a_in = 0 when RLWE_in is trivial RLWE ciphertext - // decomp - decompose_r(rlwe_in.get_row_slice(0), scratch_matrix_d_ring, decomposer); - scratch_matrix_d_ring - .iter_mut() - .for_each(|r| ntt_op.forward(r.as_mut())); - // a_out += decomp \cdot RLWE_A'(-sm) - routine( - scratch_rlwe_out[0].as_mut(), - scratch_matrix_d_ring[skip0..].as_ref(), - &rlwe_dash_nsm[skip0..d_rgsw], - mod_op, - ); - // b_out += decomp \cdot RLWE_B'(-sm) - routine( - scratch_rlwe_out[1].as_mut(), - scratch_matrix_d_ring[skip0..].as_ref(), - &rlwe_dash_nsm[d_rgsw + skip0..], - mod_op, - ); - } - // decomp - decompose_r(rlwe_in.get_row_slice(1), scratch_matrix_d_ring, decomposer); - scratch_matrix_d_ring - .iter_mut() - .for_each(|r| ntt_op.forward(r.as_mut())); - // a_out += decomp \cdot RLWE_A'(m) - routine( - scratch_rlwe_out[0].as_mut(), - scratch_matrix_d_ring[skip1..].as_ref(), - &rlwe_dash_m[skip1..d_rgsw], - mod_op, - ); - // b_out += decomp \cdot RLWE_B'(m) - routine( - scratch_rlwe_out[1].as_mut(), - scratch_matrix_d_ring[skip1..].as_ref(), - &rlwe_dash_m[d_rgsw + skip1..], - mod_op, - ); - - // transform rlwe_out to coefficient domain - scratch_rlwe_out - .iter_mut() - .for_each(|r| ntt_op.backward(r.as_mut())); - - rlwe_in - .get_row_mut(0) - .copy_from_slice(scratch_rlwe_out[0].as_mut()); - rlwe_in - .get_row_mut(1) - .copy_from_slice(scratch_rlwe_out[1].as_mut()); - rlwe_in.set_not_trivial(); -} - /// Returns RLWE(m0m1) = RLWE(m0) x RGSW(m1). Mutates rlwe_in inplace to equal /// RLWE(m0m1) /// @@ -789,7 +705,7 @@ pub(crate) fn rlwe_by_rgsw< // decomp decompose_r( rlwe_in.get_row_slice(0), - scratch_matrix_d_ring, + &mut scratch_matrix_d_ring[..d_a], decomposer_a, ); scratch_matrix_d_ring @@ -814,7 +730,7 @@ pub(crate) fn rlwe_by_rgsw< // decomp decompose_r( rlwe_in.get_row_slice(1), - scratch_matrix_d_ring, + &mut scratch_matrix_d_ring[..d_b], decomposer_b, ); scratch_matrix_d_ring