diff --git a/src/backend/mod.rs b/src/backend/mod.rs index d3cc655..097aab4 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -1,6 +1,6 @@ use num_traits::ToPrimitive; -use crate::{Matrix, RowMut}; +use crate::{Matrix, Row, RowMut}; mod modulus_u64; mod word_size; @@ -126,10 +126,7 @@ pub trait ArithmeticLazyOps { fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; } -pub trait ShoupMatrixFMA -where - M::R: RowMut, -{ - /// Returns summation of row-wise product of matrix a and b. - fn shoup_matrix_fma(&self, out: &mut M::R, a: &M, a_shoup: &M, b: &M); +pub trait ShoupMatrixFMA { + /// Returns summation of `row-wise product of matrix a and b` + out. + fn shoup_matrix_fma(&self, out: &mut [R::Element], a: &[R], a_shoup: &[R], b: &[R]); } diff --git a/src/backend/modulus_u64.rs b/src/backend/modulus_u64.rs index 16c0e85..f279c03 100644 --- a/src/backend/modulus_u64.rs +++ b/src/backend/modulus_u64.rs @@ -230,34 +230,15 @@ impl VectorOps for ModularOpsU64 { // } } -impl, T> ShoupMatrixFMA for ModularOpsU64 -where - M::R: RowMut, -{ - fn shoup_matrix_fma(&self, out: &mut ::R, a: &M, a_shoup: &M, b: &M) { - assert!(a.dimension() == a_shoup.dimension()); - assert!(a.dimension() == b.dimension()); +impl, T> ShoupMatrixFMA for ModularOpsU64 { + fn shoup_matrix_fma(&self, out: &mut [R::Element], a: &[R], a_shoup: &[R], b: &[R]) { + assert!(a.len() == a_shoup.len()); + assert!(a.len() == b.len()); let q = self.q; let q_twice = self.q << 1; - // first row (without summation) - izip!( - out.as_mut().iter_mut(), - a.get_row(0), - a_shoup.get_row(0), - b.get_row(0) - ) - .for_each(|(o, a, a_shoup, b)| { - *o = ShoupMul::mul(*b, *a, *a_shoup, q); - }); - - izip!( - a.iter_rows().skip(1), - a_shoup.iter_rows().skip(1), - b.iter_rows().skip(1) - ) - .for_each(|(a_row, a_shoup_row, b_row)| { + izip!(a.iter(), a_shoup.iter(), b.iter()).for_each(|(a_row, a_shoup_row, b_row)| { izip!( out.as_mut().iter_mut(), a_row.as_ref().iter(), diff --git a/src/ntt.rs b/src/ntt.rs index a2e0c40..ff76aa3 100644 --- a/src/ntt.rs +++ b/src/ntt.rs @@ -158,8 +158,9 @@ pub fn ntt(a: &mut [u64], psi: &[u64], psi_shoup: &[u64], q: u64, q_twice: u64) if t == 1 { for (a, w, w_shoup) in izip!(a.chunks_mut(2), w.iter(), w_shoup.iter()) { let (ox, oy) = forward_butterly_0_to_2q(a[0], a[1], *w, *w_shoup, q, q_twice); - a[0] = ox.min(ox.wrapping_sub(q_twice)); - a[1] = oy.min(oy.wrapping_sub(q_twice)); + // reduce from range [0, 2q) to [0, q) + a[0] = ox.min(ox.wrapping_sub(q)); + a[1] = oy.min(oy.wrapping_sub(q)); } } else { for i in 0..m { @@ -476,6 +477,11 @@ mod tests { .collect_vec() } + fn assert_output_range(a: &[u64], max_val: u64) { + a.iter() + .for_each(|v| assert!(v <= &max_val, "{v} > {max_val}")); + } + #[test] fn native_ntt_backend_works() { // TODO(Jay): Improve tests. Add tests for different primes and ring size. @@ -485,20 +491,26 @@ mod tests { let a_clone = a.clone(); ntt_backend.forward(&mut a); + assert_output_range(a.as_ref(), Q_60_BITS - 1); assert_ne!(a, a_clone); ntt_backend.backward(&mut a); + assert_output_range(a.as_ref(), Q_60_BITS - 1); assert_eq!(a, a_clone); ntt_backend.forward_lazy(&mut a); + assert_output_range(a.as_ref(), (2 * Q_60_BITS) - 1); assert_ne!(a, a_clone); ntt_backend.backward(&mut a); + assert_output_range(a.as_ref(), Q_60_BITS - 1); assert_eq!(a, a_clone); ntt_backend.forward(&mut a); + assert_output_range(a.as_ref(), Q_60_BITS - 1); ntt_backend.backward_lazy(&mut a); + assert_output_range(a.as_ref(), (2 * Q_60_BITS) - 1); // reduce a.iter_mut().for_each(|a0| { - if *a0 > Q_60_BITS { + if *a0 >= Q_60_BITS { *a0 -= *a0 - Q_60_BITS; } }); diff --git a/src/rgsw/mod.rs b/src/rgsw/mod.rs index d4335c6..21e21ce 100644 --- a/src/rgsw/mod.rs +++ b/src/rgsw/mod.rs @@ -16,7 +16,7 @@ use crate::{ DefaultSecureRng, NewWithSeed, RandomElementInModulus, RandomFill, RandomFillGaussianInModulus, RandomFillUniformInModulus, }, - utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom1, WithLocal}, + utils::{fill_random_ternary_secret_with_hamming_weight, ToShoup, TryConvertFrom1, WithLocal}, Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, }; @@ -91,21 +91,17 @@ where } } -pub(crate) trait ToShoup { - fn to_shoup(value: Self, modulus: Self) -> Self; -} - pub struct ShoupAutoKeyEvaluationDomain { data: M, } impl, R, N> - From> for ShoupAutoKeyEvaluationDomain + From<&AutoKeyEvaluationDomain> for ShoupAutoKeyEvaluationDomain where M::R: RowMut, M::MatElement: ToShoup + Copy, { - fn from(value: AutoKeyEvaluationDomain) -> Self { + fn from(value: &AutoKeyEvaluationDomain) -> Self { let (row, col) = value.data.dimension(); let mut shoup_data = M::zeros(row, col); @@ -538,6 +534,7 @@ pub(crate) mod tests { decomposer::{Decomposer, DefaultDecomposer, RlweDecomposer}, ntt::{self, Ntt, NttBackendU64, NttInit}, random::{DefaultSecureRng, NewWithSeed, RandomFillUniformInModulus}, + rgsw::{galois_auto_shoup, ShoupAutoKeyEvaluationDomain}, utils::{generate_prime, negacyclic_mul, Stats, TryConvertFrom1}, Matrix, Secret, }; @@ -961,16 +958,45 @@ pub(crate) mod tests { // Send RLWE_{s}(m) -> RLWE_{s}(m^k) let mut scratch_space = vec![vec![0u64; ring_size as usize]; d_rgsw + 2]; let (auto_map_index, auto_map_sign) = generate_auto_map(ring_size as usize, auto_k); - galois_auto( - &mut rlwe_m, - &auto_key.data, - &mut scratch_space, - &auto_map_index, - &auto_map_sign, - &mod_op, - &ntt_op, - &decomposer, - ); + + // galois auto with additional auto key in shoup repr + let rlwe_m_shoup = { + let auto_key_shoup = ShoupAutoKeyEvaluationDomain::from(&auto_key); + let mut rlwe_m_shoup = RlweCiphertext::<_, DefaultSecureRng> { + data: rlwe_m.data.clone(), + is_trivial: rlwe_m.is_trivial, + _phatom: PhantomData::default(), + }; + galois_auto_shoup( + &mut rlwe_m_shoup, + &auto_key.data, + &auto_key_shoup.data, + &mut scratch_space, + &auto_map_index, + &auto_map_sign, + &mod_op, + &ntt_op, + &decomposer, + ); + rlwe_m_shoup + }; + + // normal galois auto + { + galois_auto( + &mut rlwe_m, + &auto_key.data, + &mut scratch_space, + &auto_map_index, + &auto_map_sign, + &mod_op, + &ntt_op, + &decomposer, + ); + } + + // rlwe out from both functions must be same + assert_eq!(rlwe_m.data, rlwe_m_shoup.data); let rlwe_m_k = rlwe_m; diff --git a/src/rgsw/runtime.rs b/src/rgsw/runtime.rs index 88085b4..28cb475 100644 --- a/src/rgsw/runtime.rs +++ b/src/rgsw/runtime.rs @@ -2,7 +2,7 @@ use itertools::izip; use num_traits::Zero; use crate::{ - backend::{ArithmeticOps, GetModulus, Modulus, VectorOps}, + backend::{ArithmeticOps, GetModulus, Modulus, ShoupMatrixFMA, VectorOps}, decomposer::{Decomposer, RlweDecomposer}, ntt::Ntt, Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, @@ -181,6 +181,134 @@ pub(crate) fn galois_auto< .copy_from_slice(tmp_rlwe_out[1].as_ref()); } +pub(crate) fn galois_auto_shoup< + MT: Matrix + IsTrivial + MatrixMut, + Mmut: MatrixMut, + ModOp: ArithmeticOps + + VectorOps + + ShoupMatrixFMA, + NttOp: Ntt, + D: Decomposer, +>( + rlwe_in: &mut MT, + ksk: &Mmut, + ksk_shoup: &Mmut, + scratch_matrix: &mut Mmut, + auto_map_index: &[usize], + auto_map_sign: &[bool], + mod_op: &ModOp, + ntt_op: &NttOp, + decomposer: &D, +) where + ::R: RowMut, + ::R: RowMut, + MT::MatElement: Copy + Zero, +{ + let d = decomposer.decomposition_count(); + let ring_size = rlwe_in.dimension().1; + assert!(rlwe_in.dimension().0 == 2); + assert!(scratch_matrix.fits(d + 2, ring_size)); + + let (scratch_matrix_d_ring, other_half) = scratch_matrix.split_at_row_mut(d); + let (tmp_rlwe_out, _) = other_half.split_at_mut(2); + + debug_assert!(tmp_rlwe_out.len() == 2); + debug_assert!(scratch_matrix_d_ring.len() == d); + + if !rlwe_in.is_trivial() { + tmp_rlwe_out.iter_mut().for_each(|r| { + r.as_mut().fill(Mmut::MatElement::zero()); + }); + + // send a(X) -> a(X^k) and decompose a(X^k) + izip!( + rlwe_in.get_row(0), + auto_map_index.iter(), + auto_map_sign.iter() + ) + .for_each(|(el_in, to_index, sign)| { + let el_out = if !*sign { mod_op.neg(el_in) } else { *el_in }; + + decomposer + .decompose_iter(&el_out) + .enumerate() + .for_each(|(index, el)| { + scratch_matrix_d_ring[index].as_mut()[*to_index] = el; + }); + }); + + // transform decomposed a(X^k) to evaluation domain + scratch_matrix_d_ring.iter_mut().for_each(|r| { + ntt_op.forward_lazy(r.as_mut()); + }); + + // RLWE(m^k) = a', b'; RLWE(m) = a, b + // key switch: (a * RLWE'(s(X^k))) + let (ksk_a, ksk_b) = ksk.split_at_row(d); + let (ksk_a_shoup, ksk_b_shoup) = ksk_shoup.split_at_row(d); + // a' = decomp * RLWE'_A(s(X^k)) + mod_op.shoup_matrix_fma( + tmp_rlwe_out[0].as_mut(), + ksk_a, + ksk_a_shoup, + scratch_matrix_d_ring, + ); + + // b'= decomp * RLWE'_B(s(X^k)) + mod_op.shoup_matrix_fma( + tmp_rlwe_out[1].as_mut(), + ksk_b, + ksk_b_shoup, + scratch_matrix_d_ring, + ); + + // transform RLWE(m^k) to coefficient domain + tmp_rlwe_out + .iter_mut() + .for_each(|r| ntt_op.backward(r.as_mut())); + + // send b(X) -> b(X^k) and then b'(X) += b(X^k) + izip!( + rlwe_in.get_row(1), + auto_map_index.iter(), + auto_map_sign.iter() + ) + .for_each(|(el_in, to_index, sign)| { + let row = tmp_rlwe_out[1].as_mut(); + if !*sign { + row[*to_index] = mod_op.sub(&row[*to_index], el_in); + } else { + row[*to_index] = mod_op.add(&row[*to_index], el_in); + } + }); + + // copy over A; Leave B for later + rlwe_in + .get_row_mut(0) + .copy_from_slice(tmp_rlwe_out[0].as_ref()); + } else { + // RLWE is trivial, a(X) is 0. + // send b(X) -> b(X^k) + izip!( + rlwe_in.get_row(1), + auto_map_index.iter(), + auto_map_sign.iter() + ) + .for_each(|(el_in, to_index, sign)| { + if !*sign { + tmp_rlwe_out[1].as_mut()[*to_index] = mod_op.neg(el_in); + } else { + tmp_rlwe_out[1].as_mut()[*to_index] = *el_in; + } + }); + } + + // Copy over B + rlwe_in + .get_row_mut(1) + .copy_from_slice(tmp_rlwe_out[1].as_ref()); +} + /// Returns RLWE(m0m1) = RLWE(m0) x RGSW(m1). Mutates rlwe_in inplace to equal /// RLWE(m0m1) /// diff --git a/src/utils.rs b/src/utils.rs index 6fd9c08..c5e0d00 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -25,11 +25,15 @@ pub trait Global { fn global() -> &'static Self; } -pub trait ShoupMul { +pub(crate) trait ShoupMul { fn representation(value: Self, q: Self) -> Self; fn mul(a: Self, b: Self, b_shoup: Self, q: Self) -> Self; } +pub(crate) trait ToShoup { + fn to_shoup(value: Self, modulus: Self) -> Self; +} + impl ShoupMul for u64 { #[inline] fn representation(value: Self, q: Self) -> Self { @@ -44,6 +48,12 @@ impl ShoupMul for u64 { } } +impl ToShoup for u64 { + fn to_shoup(value: Self, modulus: Self) -> Self { + ((value as u128 * (1u128 << 64)) / modulus as u128) as u64 + } +} + pub fn fill_random_ternary_secret_with_hamming_weight< T: Signed, R: RandomFill<[u8]> + RandomElementInModulus,