From bc02262f9d9ec022cf9b37e41a2e3493f46c54ba Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Sat, 29 Jun 2024 19:22:15 +0530 Subject: [PATCH] modify rgsw/runtime ot use traits --- src/bool/evaluator.rs | 79 +-- src/bool/keys.rs | 1 - src/bool/print_noise.rs | 25 +- src/pbs.rs | 84 +-- src/rgsw/keygen.rs | 13 +- src/rgsw/mod.rs | 1324 ++++++++++++++++----------------------- src/rgsw/runtime.rs | 876 +++++++++++++++++--------- 7 files changed, 1224 insertions(+), 1178 deletions(-) diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index d09434c..4d0a228 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -25,15 +25,15 @@ use crate::{ RandomFillUniformInModulus, RandomGaussianElementInModulus, }, rgsw::{ - decrypt_rlwe, generate_auto_map, public_key_encrypt_rgsw, rgsw_by_rgsw_inplace, rlwe_auto, - secret_key_encrypt_rgsw, seeded_auto_key_gen, + decrypt_rlwe, generate_auto_map, public_key_encrypt_rgsw, rgsw_by_rgsw_inplace, + rgsw_x_rgsw_scratch_rows, rlwe_auto, secret_key_encrypt_rgsw, seeded_auto_key_gen, + RgswCiphertext, RgswCiphertextMutRef, RgswCiphertextRef, RuntimeScratchMutRef, }, utils::{ encode_x_pow_si_with_emebedding_factor, fill_random_ternary_secret_with_hamming_weight, generate_prime, mod_exponent, puncture_p_rng, Global, TryConvertFrom1, WithLocal, }, - Decryptor, Encoder, Encryptor, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, Row, - RowEntity, RowMut, Secret, + Encoder, Matrix, MatrixEntity, MatrixMut, RowEntity, RowMut, }; use super::{ @@ -45,10 +45,7 @@ use super::{ SeededMultiPartyServerKey, SeededNonInteractiveMultiPartyServerKey, SeededSinglePartyServerKey, SinglePartyClientKey, }, - parameters::{ - BoolParameters, CiphertextModulus, DecompositionCount, DecompostionLogBase, - DoubleDecomposerParams, - }, + parameters::{BoolParameters, CiphertextModulus, DecompositionCount, DoubleDecomposerParams}, }; /// Common reference seed used for Interactive multi-party, @@ -1146,12 +1143,12 @@ where // rgsw ciphertext (most expensive part!) let rgsw_cts = { - let rgsw_by_rgsw_decomposer = + let rgsw_x_rgsw_decomposer = parameters.rgsw_rgsw_decomposer::>(); let rlwe_x_rgsw_decomposer = self.pbs_info().rlwe_rgsw_decomposer(); let rgsw_x_rgsw_dimension = ( - rgsw_by_rgsw_decomposer.a().decomposition_count() * 2 - + rgsw_by_rgsw_decomposer.b().decomposition_count() * 2, + rgsw_x_rgsw_decomposer.a().decomposition_count() * 2 + + rgsw_x_rgsw_decomposer.b().decomposition_count() * 2, rlwe_n, ); let rlwe_x_rgsw_dimension = ( @@ -1159,11 +1156,9 @@ where + rlwe_x_rgsw_decomposer.b().decomposition_count() * 2, rlwe_n, ); - let mut rgsw_x_rgsw_scratch_mat = M::zeros( - std::cmp::max( - rgsw_by_rgsw_decomposer.a().decomposition_count(), - rgsw_by_rgsw_decomposer.b().decomposition_count(), - ) + rlwe_x_rgsw_dimension.0, + + let mut rgsw_x_rgsw_scratch = M::zeros( + rgsw_x_rgsw_scratch_rows(rlwe_x_rgsw_decomposer, &rgsw_x_rgsw_decomposer), rlwe_n, ); @@ -1216,15 +1211,22 @@ where .for_each(|r| rlweq_nttop.forward(r.as_mut())); rgsw_by_rgsw_inplace( - &mut rgsw_i, - rlwe_x_rgsw_decomposer.a().decomposition_count(), - rlwe_x_rgsw_decomposer.b().decomposition_count(), - &other_rgsw_i, - &rgsw_by_rgsw_decomposer, - &mut rgsw_x_rgsw_scratch_mat, + &mut RgswCiphertextMutRef::new( + rgsw_i.as_mut(), + rlwe_x_rgsw_decomposer.a().decomposition_count(), + rlwe_x_rgsw_decomposer.b().decomposition_count(), + ), + &RgswCiphertextRef::new( + other_rgsw_i.as_ref(), + rgsw_x_rgsw_decomposer.a().decomposition_count(), + rgsw_x_rgsw_decomposer.b().decomposition_count(), + ), + rlwe_x_rgsw_decomposer, + &rgsw_x_rgsw_decomposer, + &mut RuntimeScratchMutRef::new(rgsw_x_rgsw_scratch.as_mut()), rlweq_nttop, rlweq_modop, - ) + ); }); rgsw_cts.push(rgsw_i); @@ -1370,11 +1372,7 @@ where }; let mut scratch_rgsw_x_rgsw = M::zeros( - std::cmp::max( - rgsw_x_rgsw_decomposer.a().decomposition_count(), - rgsw_x_rgsw_decomposer.b().decomposition_count(), - ) + rlwe_x_rgsw_decomposer.a().decomposition_count() * 2 - + rlwe_x_rgsw_decomposer.b().decomposition_count() * 2, + rgsw_x_rgsw_scratch_rows(&rlwe_x_rgsw_decomposer, &rgsw_x_rgsw_decomposer), self.parameters().rlwe_n().0, ); @@ -1534,7 +1532,7 @@ where (0..total_users) .filter(|i| *i != user_id) .for_each(|other_user_id| { - let other_rgsw_i = produce_rgsw_ciphertext_from_ni_rgsw( + let mut other_rgsw_i = produce_rgsw_ciphertext_from_ni_rgsw( key_shares[other_user_id] .ni_rgsw_cts_for_self_not_leader_lwe_index(lwe_index), &ni_rgsw_zero_encs, @@ -1551,12 +1549,21 @@ where ); rgsw_by_rgsw_inplace( - &mut rgsw_i, - rlwe_x_rgsw_decomposer.a().decomposition_count(), - rlwe_x_rgsw_decomposer.b().decomposition_count(), - &other_rgsw_i, + &mut RgswCiphertextMutRef::new( + rgsw_i.as_mut(), + rlwe_x_rgsw_decomposer.a().decomposition_count(), + rlwe_x_rgsw_decomposer.b().decomposition_count(), + ), + &RgswCiphertextRef::new( + other_rgsw_i.as_ref(), + rgsw_x_rgsw_decomposer.a().decomposition_count(), + rgsw_x_rgsw_decomposer.b().decomposition_count(), + ), + &rlwe_x_rgsw_decomposer, &rgsw_x_rgsw_decomposer, - &mut scratch_rgsw_x_rgsw, + &mut RuntimeScratchMutRef::new( + scratch_rgsw_x_rgsw.as_mut(), + ), nttop, rlwe_modop, ) @@ -2096,9 +2103,7 @@ where }); let e = DefaultSecureRng::with_local_mut(|rng| { - let mut e = - RandomGaussianElementInModulus::random(rng, self.pbs_info.parameters.rlwe_q()); - e + RandomGaussianElementInModulus::random(rng, self.pbs_info.parameters.rlwe_q()) }); let share = modop.add(&neg_sa, &e); diff --git a/src/bool/keys.rs b/src/bool/keys.rs index 18b6dc0..ab2afde 100644 --- a/src/bool/keys.rs +++ b/src/bool/keys.rs @@ -4,7 +4,6 @@ use crate::{ backend::{ModInit, VectorOps}, pbs::WithShoupRepr, random::{NewWithSeed, RandomFillUniformInModulus}, - rgsw::RlweSecret, utils::{ToShoup, WithLocal}, Decryptor, Encryptor, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, RowEntity, RowMut, }; diff --git a/src/bool/print_noise.rs b/src/bool/print_noise.rs index 85aa730..fc1e9a9 100644 --- a/src/bool/print_noise.rs +++ b/src/bool/print_noise.rs @@ -10,7 +10,10 @@ use crate::{ lwe::{decrypt_lwe, lwe_key_switch}, parameters::{BoolParameters, CiphertextModulus}, random::{DefaultSecureRng, RandomFillUniformInModulus}, - rgsw::{decrypt_rlwe, rlwe_auto, IsTrivial, RlweCiphertext}, + rgsw::{ + decrypt_rlwe, rlwe_auto, rlwe_auto_scratch_rows, RlweCiphertextMutRef, RlweKskRef, + RuntimeScratchMutRef, + }, utils::{encode_x_pow_si_with_emebedding_factor, tests::Stats, TryConvertFrom1}, ArithmeticOps, ClientKey, Decomposer, MatrixEntity, MatrixMut, ModInit, Ntt, NttInit, RowEntity, RowMut, VectorOps, @@ -223,7 +226,8 @@ where let br_q = parameters.br_q(); let g_dlogs = parameters.auto_element_dlogs(); let auto_decomposer = parameters.auto_decomposer::(); - let mut scratch_matrix = M::zeros(auto_decomposer.decomposition_count() + 2, rlwe_n); + let mut scratch_matrix = M::zeros(rlwe_auto_scratch_rows(&auto_decomposer), rlwe_n); + let mut scratch_matrix_ref = RuntimeScratchMutRef::new(scratch_matrix.as_mut()); g_dlogs.iter().for_each(|k| { let g_pow_k = if *k == 0 { @@ -279,19 +283,22 @@ where // RLWE auto sends part A, A(X), of RLWE to A(X^{g^k}) and then multiplies it // with -s(X^{g^k}) using auto key. Deliberately set RLWE = (0, m(X)) // (ie. m in part A) to get back RLWE(-m(X^{g^k})s(X^{g^k})) - let mut rlwe = RlweCiphertext::<_, DefaultSecureRng>::new_trivial(M::zeros(2, rlwe_n)); - rlwe.data.get_row_mut(0).copy_from_slice(m.as_ref()); - rlwe.set_not_trivial(); + let mut rlwe = M::zeros(2, rlwe_n); + rlwe.get_row_mut(0).copy_from_slice(m.as_ref()); rlwe_auto( - &mut rlwe, - server_key.galois_key_for_auto(*k), - &mut scratch_matrix, + &mut RlweCiphertextMutRef::new(rlwe.as_mut()), + &RlweKskRef::new( + server_key.galois_key_for_auto(*k).as_ref(), + auto_decomposer.decomposition_count(), + ), + &mut scratch_matrix_ref, &auto_index_map, &auto_sign_map, &rlwe_modop, &rlwe_nttop, &auto_decomposer, + false, ); // decrypt RLWE(-m(X)s(X^{g^k]})) @@ -430,7 +437,7 @@ mod tests { set_parameter_set(crate::ParameterSelector::NonInteractiveLTE2Party); set_common_reference_seed(NonInteractiveMultiPartyCrs::random().seed); let parties = 2; - let cks = (0..parties).map(|i| gen_client_key()).collect_vec(); + let cks = (0..parties).map(|_| gen_client_key()).collect_vec(); let server_key_shares = cks .iter() .enumerate() diff --git a/src/pbs.rs b/src/pbs.rs index 000dab6..7a727d5 100644 --- a/src/pbs.rs +++ b/src/pbs.rs @@ -1,14 +1,16 @@ -use std::{fmt::Display, marker::PhantomData}; +use std::fmt::Display; use num_traits::{FromPrimitive, One, PrimInt, ToPrimitive, Zero}; use crate::{ backend::{ArithmeticOps, Modulus, ShoupMatrixFMA, VectorOps}, - decomposer::Decomposer, + decomposer::{Decomposer, RlweDecomposer}, lwe::lwe_key_switch, ntt::Ntt, - random::DefaultSecureRng, - rgsw::{galois_auto_shoup, rlwe_by_rgsw_shoup, IsTrivial, RlweCiphertext}, + rgsw::{ + rlwe_auto_shoup, rlwe_by_rgsw_shoup, RgswCiphertextRef, RlweCiphertextMutRef, RlweKskRef, + RuntimeScratchMutRef, + }, Matrix, MatrixEntity, MatrixMut, RowMut, }; pub(crate) trait PbsKey { @@ -215,7 +217,8 @@ pub(crate) fn pbs< /// gk_to_si: [g^0, ..., g^{q/2-1}, -g^0, -g^1, .., -g^{q/2-1}] fn blind_rotation< Mmut: MatrixMut, - D: Decomposer, + RlweD: RlweDecomposer, + AutoD: Decomposer, NttOp: Ntt, ModOp: ArithmeticOps + ShoupMatrixFMA, MShoup: WithShoupRepr, @@ -228,8 +231,8 @@ fn blind_rotation< w: usize, q: usize, gk_to_si: &[Vec], - rlwe_rgsw_decomposer: &(D, D), - auto_decomposer: &D, + rlwe_rgsw_decomposer: &RlweD, + auto_decomposer: &AutoD, ntt_op: &NttOp, mod_op: &ModOp, parameters: &P, @@ -239,6 +242,11 @@ fn blind_rotation< Mmut::MatElement: Copy + Zero, { let mut is_trivial = true; + let mut scratch_matrix = RuntimeScratchMutRef::new(scratch_matrix.as_mut()); + let mut rlwe = RlweCiphertextMutRef::new(trivial_rlwe_test_poly.as_mut()); + let d_a = rlwe_rgsw_decomposer.a().decomposition_count(); + let d_b = rlwe_rgsw_decomposer.b().decomposition_count(); + let d_auto = auto_decomposer.decomposition_count(); let q_by_4 = q >> 2; let mut count = 0; @@ -252,10 +260,10 @@ fn blind_rotation< // let new = std::time::Instant::now(); let ct = pbs_key.rgsw_ct_lwe_si(*s_index); rlwe_by_rgsw_shoup( - trivial_rlwe_test_poly, - ct.as_ref(), - ct.shoup_repr(), - scratch_matrix, + &mut rlwe, + &RgswCiphertextRef::new(ct.as_ref().as_ref(), d_a, d_b), + &RgswCiphertextRef::new(ct.shoup_repr().as_ref(), d_a, d_b), + &mut scratch_matrix, rlwe_rgsw_decomposer, ntt_op, mod_op, @@ -271,11 +279,11 @@ fn blind_rotation< // let now = std::time::Instant::now(); let auto_key = pbs_key.galois_key_for_auto(v); - galois_auto_shoup( - trivial_rlwe_test_poly, - auto_key.as_ref(), - auto_key.shoup_repr(), - scratch_matrix, + rlwe_auto_shoup( + &mut rlwe, + &RlweKskRef::new(auto_key.as_ref().as_ref(), d_auto), + &RlweKskRef::new(auto_key.shoup_repr().as_ref(), d_auto), + &mut scratch_matrix, &auto_map_index, &auto_map_sign, mod_op, @@ -295,10 +303,10 @@ fn blind_rotation< gk_to_si[q_by_4].iter().for_each(|s_index| { let ct = pbs_key.rgsw_ct_lwe_si(*s_index); rlwe_by_rgsw_shoup( - trivial_rlwe_test_poly, - ct.as_ref(), - ct.shoup_repr(), - scratch_matrix, + &mut rlwe, + &RgswCiphertextRef::new(ct.as_ref().as_ref(), d_a, d_b), + &RgswCiphertextRef::new(ct.shoup_repr().as_ref(), d_a, d_b), + &mut scratch_matrix, rlwe_rgsw_decomposer, ntt_op, mod_op, @@ -309,11 +317,11 @@ fn blind_rotation< let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(0); let auto_key = pbs_key.galois_key_for_auto(0); - galois_auto_shoup( - trivial_rlwe_test_poly, - auto_key.as_ref(), - auto_key.shoup_repr(), - scratch_matrix, + rlwe_auto_shoup( + &mut rlwe, + &RlweKskRef::new(auto_key.as_ref().as_ref(), d_auto), + &RlweKskRef::new(auto_key.shoup_repr().as_ref(), d_auto), + &mut scratch_matrix, &auto_map_index, &auto_map_sign, mod_op, @@ -331,10 +339,10 @@ fn blind_rotation< s_indices.iter().for_each(|s_index| { let ct = pbs_key.rgsw_ct_lwe_si(*s_index); rlwe_by_rgsw_shoup( - trivial_rlwe_test_poly, - ct.as_ref(), - ct.shoup_repr(), - scratch_matrix, + &mut rlwe, + &RgswCiphertextRef::new(ct.as_ref().as_ref(), d_a, d_b), + &RgswCiphertextRef::new(ct.shoup_repr().as_ref(), d_a, d_b), + &mut scratch_matrix, rlwe_rgsw_decomposer, ntt_op, mod_op, @@ -347,11 +355,11 @@ fn blind_rotation< if gk_to_si[i - 1].len() != 0 || v == w || i == 1 { let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(v); let auto_key = pbs_key.galois_key_for_auto(v); - galois_auto_shoup( - trivial_rlwe_test_poly, - auto_key.as_ref(), - auto_key.shoup_repr(), - scratch_matrix, + rlwe_auto_shoup( + &mut rlwe, + &RlweKskRef::new(auto_key.as_ref().as_ref(), d_auto), + &RlweKskRef::new(auto_key.shoup_repr().as_ref(), d_auto), + &mut scratch_matrix, &auto_map_index, &auto_map_sign, mod_op, @@ -368,10 +376,10 @@ fn blind_rotation< gk_to_si[0].iter().for_each(|s_index| { let ct = pbs_key.rgsw_ct_lwe_si(*s_index); rlwe_by_rgsw_shoup( - trivial_rlwe_test_poly, - ct.as_ref(), - ct.shoup_repr(), - scratch_matrix, + &mut rlwe, + &RgswCiphertextRef::new(ct.as_ref().as_ref(), d_a, d_b), + &RgswCiphertextRef::new(ct.shoup_repr().as_ref(), d_a, d_b), + &mut scratch_matrix, rlwe_rgsw_decomposer, ntt_op, mod_op, diff --git a/src/rgsw/keygen.rs b/src/rgsw/keygen.rs index 094acb1..9d69161 100644 --- a/src/rgsw/keygen.rs +++ b/src/rgsw/keygen.rs @@ -624,16 +624,15 @@ pub(crate) fn decrypt_rlwe< mod_op.elwise_add_mut(m_out.as_mut(), rlwe_ct.get_row_slice(1)); } -// Measures noise in degree 1 RLWE ciphertext against encoded ideal message -// encoded_m -pub(crate) fn measure_noise< +// Measures maximum noise in degree 1 RLWE ciphertext against message `want_m` +pub(crate) fn measure_max_noise< Mmut: MatrixMut + Matrix, ModOp: VectorOps + GetModulus, NttOp: Ntt, S, >( rlwe_ct: &Mmut, - encoded_m_ideal: &Mmut::R, + want_m: &Mmut::R, ntt_op: &NttOp, mod_op: &ModOp, s: &[S], @@ -645,7 +644,7 @@ where { let ring_size = s.len(); assert!(rlwe_ct.dimension() == (2, ring_size)); - assert!(encoded_m_ideal.as_ref().len() == ring_size); + assert!(want_m.as_ref().len() == ring_size); // -(s * a) let q = mod_op.modulus(); @@ -663,11 +662,11 @@ where mod_op.elwise_add_mut(m_plus_e.as_mut(), rlwe_ct.get_row_slice(1)); // difference - mod_op.elwise_sub_mut(m_plus_e.as_mut(), encoded_m_ideal.as_ref()); + mod_op.elwise_sub_mut(m_plus_e.as_mut(), want_m.as_ref()); let mut max_diff_bits = f64::MIN; m_plus_e.as_ref().iter().for_each(|v| { - let bits = (q.map_element_to_i64(v).to_f64().unwrap()).log2(); + let bits = (q.map_element_to_i64(v).to_f64().unwrap().abs()).log2(); if max_diff_bits < bits { max_diff_bits = bits; diff --git a/src/rgsw/mod.rs b/src/rgsw/mod.rs index 2ae3597..6a55542 100644 --- a/src/rgsw/mod.rs +++ b/src/rgsw/mod.rs @@ -1,598 +1,438 @@ -use itertools::{izip, Itertools}; -use num_traits::{PrimInt, Signed, ToPrimitive, Zero}; -use std::{ - clone, - fmt::Debug, - iter, - marker::PhantomData, - ops::{Div, Neg, Sub}, -}; - -use crate::{ - backend::Modulus, - decomposer::{Decomposer, RlweDecomposer}, - ntt::{Ntt, NttInit}, - random::{DefaultSecureRng, NewWithSeed, RandomFillUniformInModulus}, - utils::{fill_random_ternary_secret_with_hamming_weight, ToShoup, WithLocal}, - Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, -}; - mod keygen; mod runtime; pub(crate) use keygen::*; pub(crate) use runtime::*; -pub struct SeededAutoKey -where - M: Matrix, -{ - data: M, - seed: S, - modulus: Mod, -} - -impl> SeededAutoKey { - fn empty(ring_size: usize, auto_decomposer: &D, seed: S, modulus: Mod) -> Self { - SeededAutoKey { - data: M::zeros(auto_decomposer.decomposition_count(), ring_size), - seed, - modulus, - } - } -} - -pub struct AutoKeyEvaluationDomain { - data: M, - _phantom: PhantomData<(R, N)>, - modulus: Mod, -} - -impl< - M: MatrixMut + MatrixEntity, - Mod: Modulus + Clone, - R: RandomFillUniformInModulus<[M::MatElement], Mod> + NewWithSeed, - N: NttInit + Ntt, - > From<&SeededAutoKey> for AutoKeyEvaluationDomain -where - ::R: RowMut, - M::MatElement: Copy, - - R::Seed: Clone, -{ - fn from(value: &SeededAutoKey) -> Self { - let (d, ring_size) = value.data.dimension(); - let mut data = M::zeros(2 * d, ring_size); - - // sample RLWE'_A(-s(X^k)) - let mut p_rng = R::new_with_seed(value.seed.clone()); - data.iter_rows_mut().take(d).for_each(|r| { - RandomFillUniformInModulus::random_fill(&mut p_rng, &value.modulus, r.as_mut()); - }); - - // copy over RLWE'_B(-s(X^k)) - izip!(data.iter_rows_mut().skip(d), value.data.iter_rows()).for_each(|(to_r, from_r)| { - to_r.as_mut().copy_from_slice(from_r.as_ref()); - }); - - // send RLWE'(-s(X^k)) polynomials to evaluation domain - let ntt_op = N::new(&value.modulus, ring_size); - data.iter_rows_mut() - .for_each(|r| ntt_op.forward(r.as_mut())); - - AutoKeyEvaluationDomain { - data, - _phantom: PhantomData, - modulus: value.modulus.clone(), - } - } -} +#[cfg(test)] +pub(crate) mod tests { + use std::{fmt::Debug, marker::PhantomData, vec}; -pub struct ShoupAutoKeyEvaluationDomain { - data: M, -} + use itertools::{izip, Itertools}; + use rand::{thread_rng, Rng}; -impl, Mod: Modulus, R, N> - From<&AutoKeyEvaluationDomain> for ShoupAutoKeyEvaluationDomain -{ - fn from(value: &AutoKeyEvaluationDomain) -> Self { - Self { - data: M::to_shoup(&value.data, value.modulus.q().unwrap()), - } - } -} + use crate::{ + backend::{GetModulus, ModInit, ModularOpsU64, Modulus, VectorOps}, + decomposer::{Decomposer, DefaultDecomposer, RlweDecomposer}, + ntt::{Ntt, NttBackendU64, NttInit}, + random::{DefaultSecureRng, NewWithSeed, RandomFillUniformInModulus}, + rgsw::{ + rlwe_auto_scratch_rows, rlwe_auto_shoup, rlwe_by_rgsw_shoup, rlwe_x_rgsw_scratch_rows, + RgswCiphertextRef, RlweCiphertextMutRef, RlweKskRef, RuntimeScratchMutRef, + }, + utils::{ + fill_random_ternary_secret_with_hamming_weight, generate_prime, negacyclic_mul, + tests::Stats, ToShoup, TryConvertFrom1, WithLocal, + }, + Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, + }; -pub struct RgswCiphertext { - /// Rgsw ciphertext polynomials - pub(crate) data: M, - modulus: Mod, - /// Decomposition for RLWE part A - d_a: usize, - /// Decomposition for RLWE part B - d_b: usize, -} + use super::{ + keygen::{ + decrypt_rlwe, generate_auto_map, rlwe_public_key, secret_key_encrypt_rgsw, + seeded_auto_key_gen, seeded_secret_key_encrypt_rlwe, + }, + rgsw_x_rgsw_scratch_rows, + runtime::{rgsw_by_rgsw_inplace, rlwe_auto, rlwe_by_rgsw}, + RgswCiphertextMutRef, + }; -impl> RgswCiphertext { - pub(crate) fn empty( - ring_size: usize, - decomposer: &D, + struct SeededAutoKey + where + M: Matrix, + { + data: M, + seed: S, modulus: Mod, - ) -> RgswCiphertext { - RgswCiphertext { - data: M::zeros( - decomposer.a().decomposition_count() * 2 + decomposer.b().decomposition_count() * 2, - ring_size, - ), - d_a: decomposer.a().decomposition_count(), - d_b: decomposer.b().decomposition_count(), - modulus, - } } -} -pub struct SeededRgswCiphertext -where - M: Matrix, -{ - pub(crate) data: M, - seed: S, - modulus: Mod, - /// Decomposition for RLWE part A - d_a: usize, - /// Decomposition for RLWE part B - d_b: usize, -} - -impl SeededRgswCiphertext { - pub(crate) fn empty( - ring_size: usize, - decomposer: &D, - seed: S, - modulus: Mod, - ) -> SeededRgswCiphertext { - SeededRgswCiphertext { - data: M::zeros( - decomposer.a().decomposition_count() * 2 + decomposer.b().decomposition_count(), - ring_size, - ), - seed, - modulus, - d_a: decomposer.a().decomposition_count(), - d_b: decomposer.b().decomposition_count(), + impl> SeededAutoKey { + fn empty( + ring_size: usize, + auto_decomposer: &D, + seed: S, + modulus: Mod, + ) -> Self { + SeededAutoKey { + data: M::zeros(auto_decomposer.decomposition_count(), ring_size), + seed, + modulus, + } } } -} -impl Debug for SeededRgswCiphertext -where - M::MatElement: Debug, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SeededRgswCiphertext") - .field("data", &self.data) - .field("seed", &self.seed) - .field("modulus", &self.modulus) - .finish() + struct AutoKeyEvaluationDomain { + data: M, + _phantom: PhantomData<(R, N)>, } -} -pub struct RgswCiphertextEvaluationDomain { - pub(crate) data: M, - modulus: Mod, - _phantom: PhantomData<(R, N)>, -} + impl< + M: MatrixMut + MatrixEntity, + Mod: Modulus + Clone, + R: RandomFillUniformInModulus<[M::MatElement], Mod> + NewWithSeed, + N: NttInit + Ntt, + > From<&SeededAutoKey> for AutoKeyEvaluationDomain + where + ::R: RowMut, + M::MatElement: Copy, -impl< - M: MatrixMut + MatrixEntity, - Mod: Modulus + Clone, - R: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], Mod>, - N: NttInit + Ntt + Debug, - > From<&SeededRgswCiphertext> for RgswCiphertextEvaluationDomain -where - ::R: RowMut, - M::MatElement: Copy, - R::Seed: Clone, - M: Debug, -{ - fn from(value: &SeededRgswCiphertext) -> Self { - let mut data = M::zeros(value.d_a * 2 + value.d_b * 2, value.data.dimension().1); - - // copy RLWE'(-sm) - izip!( - data.iter_rows_mut().take(value.d_a * 2), - value.data.iter_rows().take(value.d_a * 2) - ) - .for_each(|(to_ri, from_ri)| { - to_ri.as_mut().copy_from_slice(from_ri.as_ref()); - }); - - // sample A polynomials of RLWE'(m) - RLWE'A(m) - let mut p_rng = R::new_with_seed(value.seed.clone()); - izip!(data.iter_rows_mut().skip(value.d_a * 2).take(value.d_b * 1)) - .for_each(|ri| p_rng.random_fill(&value.modulus, ri.as_mut())); - - // RLWE'_B(m) - izip!( - data.iter_rows_mut().skip(value.d_a * 2 + value.d_b), - value.data.iter_rows().skip(value.d_a * 2) - ) - .for_each(|(to_ri, from_ri)| { - to_ri.as_mut().copy_from_slice(from_ri.as_ref()); - }); - - // Send polynomials to evaluation domain - let ring_size = data.dimension().1; - let nttop = N::new(&value.modulus, ring_size); - data.iter_rows_mut() - .for_each(|ri| nttop.forward(ri.as_mut())); - - Self { - data: data, - modulus: value.modulus.clone(), - _phantom: PhantomData, - } - } -} + R::Seed: Clone, + { + fn from(value: &SeededAutoKey) -> Self { + let (d, ring_size) = value.data.dimension(); + let mut data = M::zeros(2 * d, ring_size); -impl< - M: MatrixMut + MatrixEntity, - Mod: Modulus + Clone, - R, - N: NttInit + Ntt, - > From<&RgswCiphertext> for RgswCiphertextEvaluationDomain -where - ::R: RowMut, - M::MatElement: Copy, - M: Debug, -{ - fn from(value: &RgswCiphertext) -> Self { - let mut data = M::zeros(value.d_a * 2 + value.d_b * 2, value.data.dimension().1); - - // copy RLWE'(-sm) - izip!( - data.iter_rows_mut().take(value.d_a * 2), - value.data.iter_rows().take(value.d_a * 2) - ) - .for_each(|(to_ri, from_ri)| { - to_ri.as_mut().copy_from_slice(from_ri.as_ref()); - }); - - // copy RLWE'(m) - izip!( - data.iter_rows_mut().skip(value.d_a * 2), - value.data.iter_rows().skip(value.d_a * 2) - ) - .for_each(|(to_ri, from_ri)| { - to_ri.as_mut().copy_from_slice(from_ri.as_ref()); - }); - - // Send polynomials to evaluation domain - let ring_size = data.dimension().1; - let nttop = N::new(&value.modulus, ring_size); - data.iter_rows_mut() - .for_each(|ri| nttop.forward(ri.as_mut())); - - Self { - data: data, - modulus: value.modulus.clone(), - _phantom: PhantomData, - } - } -} + // sample RLWE'_A(-s(X^k)) + let mut p_rng = R::new_with_seed(value.seed.clone()); + data.iter_rows_mut().take(d).for_each(|r| { + RandomFillUniformInModulus::random_fill(&mut p_rng, &value.modulus, r.as_mut()); + }); -impl Debug for RgswCiphertextEvaluationDomain { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("RgswCiphertextEvaluationDomain") - .field("data", &self.data) - .field("modulus", &self.modulus) - .field("_phantom", &self._phantom) - .finish() - } -} + // copy over RLWE'_B(-s(X^k)) + izip!(data.iter_rows_mut().skip(d), value.data.iter_rows()).for_each( + |(to_r, from_r)| { + to_r.as_mut().copy_from_slice(from_r.as_ref()); + }, + ); -impl Matrix for RgswCiphertextEvaluationDomain { - type MatElement = M::MatElement; - type R = M::R; + // send RLWE'(-s(X^k)) polynomials to evaluation domain + let ntt_op = N::new(&value.modulus, ring_size); + data.iter_rows_mut() + .for_each(|r| ntt_op.forward(r.as_mut())); - fn dimension(&self) -> (usize, usize) { - self.data.dimension() + AutoKeyEvaluationDomain { + data, + _phantom: PhantomData, + } + } } - fn fits(&self, row: usize, col: usize) -> bool { - self.data.fits(row, col) + struct RgswCiphertext { + /// Rgsw ciphertext polynomials + data: M, + modulus: Mod, + /// Decomposition for RLWE part A + d_a: usize, + /// Decomposition for RLWE part B + d_b: usize, + } + + impl> RgswCiphertext { + pub(crate) fn empty( + ring_size: usize, + decomposer: &D, + modulus: Mod, + ) -> RgswCiphertext { + RgswCiphertext { + data: M::zeros( + decomposer.a().decomposition_count() * 2 + + decomposer.b().decomposition_count() * 2, + ring_size, + ), + d_a: decomposer.a().decomposition_count(), + d_b: decomposer.b().decomposition_count(), + modulus, + } + } } -} -impl AsRef<[M::R]> for RgswCiphertextEvaluationDomain { - fn as_ref(&self) -> &[M::R] { - self.data.as_ref() + pub struct SeededRgswCiphertext + where + M: Matrix, + { + pub(crate) data: M, + seed: S, + modulus: Mod, + /// Decomposition for RLWE part A + d_a: usize, + /// Decomposition for RLWE part B + d_b: usize, + } + + impl SeededRgswCiphertext { + pub(crate) fn empty( + ring_size: usize, + decomposer: &D, + seed: S, + modulus: Mod, + ) -> SeededRgswCiphertext { + SeededRgswCiphertext { + data: M::zeros( + decomposer.a().decomposition_count() * 2 + decomposer.b().decomposition_count(), + ring_size, + ), + seed, + modulus, + d_a: decomposer.a().decomposition_count(), + d_b: decomposer.b().decomposition_count(), + } + } } -} - -pub struct ShoupRgswCiphertextEvaluationDomain { - pub(crate) data: M, -} -impl, Mod: Modulus, R, N> - From<&RgswCiphertextEvaluationDomain> for ShoupRgswCiphertextEvaluationDomain -{ - fn from(value: &RgswCiphertextEvaluationDomain) -> Self { - Self { - data: M::to_shoup(&value.data, value.modulus.q().unwrap()), + impl Debug for SeededRgswCiphertext + where + M::MatElement: Debug, + { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SeededRgswCiphertext") + .field("data", &self.data) + .field("seed", &self.seed) + .field("modulus", &self.modulus) + .finish() } } -} -pub struct SeededRlweCiphertext { - pub(crate) data: R, - pub(crate) seed: S, - pub(crate) modulus: Mod, -} + pub struct RgswCiphertextEvaluationDomain { + pub(crate) data: M, + modulus: Mod, + _phantom: PhantomData<(R, N)>, + } + + impl< + M: MatrixMut + MatrixEntity, + Mod: Modulus + Clone, + R: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], Mod>, + N: NttInit + Ntt + Debug, + > From<&SeededRgswCiphertext> + for RgswCiphertextEvaluationDomain + where + ::R: RowMut, + M::MatElement: Copy, + R::Seed: Clone, + M: Debug, + { + fn from(value: &SeededRgswCiphertext) -> Self { + let mut data = M::zeros(value.d_a * 2 + value.d_b * 2, value.data.dimension().1); + + // copy RLWE'(-sm) + izip!( + data.iter_rows_mut().take(value.d_a * 2), + value.data.iter_rows().take(value.d_a * 2) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); -impl SeededRlweCiphertext { - pub(crate) fn empty(ring_size: usize, seed: S, modulus: Mod) -> Self { - SeededRlweCiphertext { - data: R::zeros(ring_size), - seed, - modulus, - } - } -} + // sample A polynomials of RLWE'(m) - RLWE'A(m) + let mut p_rng = R::new_with_seed(value.seed.clone()); + izip!(data.iter_rows_mut().skip(value.d_a * 2).take(value.d_b * 1)) + .for_each(|ri| p_rng.random_fill(&value.modulus, ri.as_mut())); -pub struct RlweCiphertext { - pub(crate) data: M, - pub(crate) is_trivial: bool, - pub(crate) _phatom: PhantomData, -} + // RLWE'_B(m) + izip!( + data.iter_rows_mut().skip(value.d_a * 2 + value.d_b), + value.data.iter_rows().skip(value.d_a * 2) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); -impl RlweCiphertext { - pub(crate) fn new_trivial(data: M) -> Self { - RlweCiphertext { - data, - is_trivial: true, - _phatom: PhantomData, + // Send polynomials to evaluation domain + let ring_size = data.dimension().1; + let nttop = N::new(&value.modulus, ring_size); + data.iter_rows_mut() + .for_each(|ri| nttop.forward(ri.as_mut())); + + Self { + data: data, + modulus: value.modulus.clone(), + _phantom: PhantomData, + } } } -} -impl Matrix for RlweCiphertext { - type MatElement = M::MatElement; - type R = M::R; - - fn dimension(&self) -> (usize, usize) { - self.data.dimension() - } + impl< + M: MatrixMut + MatrixEntity, + Mod: Modulus + Clone, + R, + N: NttInit + Ntt, + > From<&RgswCiphertext> for RgswCiphertextEvaluationDomain + where + ::R: RowMut, + M::MatElement: Copy, + M: Debug, + { + fn from(value: &RgswCiphertext) -> Self { + let mut data = M::zeros(value.d_a * 2 + value.d_b * 2, value.data.dimension().1); + + // copy RLWE'(-sm) + izip!( + data.iter_rows_mut().take(value.d_a * 2), + value.data.iter_rows().take(value.d_a * 2) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); - fn fits(&self, row: usize, col: usize) -> bool { - self.data.fits(row, col) - } -} + // copy RLWE'(m) + izip!( + data.iter_rows_mut().skip(value.d_a * 2), + value.data.iter_rows().skip(value.d_a * 2) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); -impl MatrixMut for RlweCiphertext where ::R: RowMut {} + // Send polynomials to evaluation domain + let ring_size = data.dimension().1; + let nttop = N::new(&value.modulus, ring_size); + data.iter_rows_mut() + .for_each(|ri| nttop.forward(ri.as_mut())); -impl AsRef<[::R]> for RlweCiphertext { - fn as_ref(&self) -> &[::R] { - self.data.as_ref() + Self { + data: data, + modulus: value.modulus.clone(), + _phantom: PhantomData, + } + } } -} -impl AsMut<[::R]> for RlweCiphertext -where - ::R: RowMut, -{ - fn as_mut(&mut self) -> &mut [::R] { - self.data.as_mut() + impl Debug for RgswCiphertextEvaluationDomain { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RgswCiphertextEvaluationDomain") + .field("data", &self.data) + .field("modulus", &self.modulus) + .field("_phantom", &self._phantom) + .finish() + } } -} -impl IsTrivial for RlweCiphertext { - fn is_trivial(&self) -> bool { - self.is_trivial - } - fn set_not_trivial(&mut self) { - self.is_trivial = false; + struct SeededRlweCiphertext { + data: R, + seed: S, + modulus: Mod, } -} -impl< - R: Row, - M: MatrixEntity + MatrixMut, - Rng: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], Mod>, - Mod: Modulus, - > From<&SeededRlweCiphertext> for RlweCiphertext -where - Rng::Seed: Clone, - ::R: RowMut, - R::Element: Copy, -{ - fn from(value: &SeededRlweCiphertext) -> Self { - let mut data = M::zeros(2, value.data.as_ref().len()); - - // sample a - let mut p_rng = Rng::new_with_seed(value.seed.clone()); - RandomFillUniformInModulus::random_fill(&mut p_rng, &value.modulus, data.get_row_mut(0)); - - data.get_row_mut(1).copy_from_slice(value.data.as_ref()); - - RlweCiphertext { - data, - is_trivial: false, - _phatom: PhantomData, + impl SeededRlweCiphertext { + fn empty(ring_size: usize, seed: S, modulus: Mod) -> Self { + SeededRlweCiphertext { + data: R::zeros(ring_size), + seed, + modulus, + } } } -} -pub trait IsTrivial { - fn is_trivial(&self) -> bool; - fn set_not_trivial(&mut self); -} + pub struct RlweCiphertext { + data: M, + _phatom: PhantomData, + } + + impl< + R: Row, + M: MatrixEntity + MatrixMut, + Rng: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], Mod>, + Mod: Modulus, + > From<&SeededRlweCiphertext> for RlweCiphertext + where + Rng::Seed: Clone, + ::R: RowMut, + R::Element: Copy, + { + fn from(value: &SeededRlweCiphertext) -> Self { + let mut data = M::zeros(2, value.data.as_ref().len()); + + // sample a + let mut p_rng = Rng::new_with_seed(value.seed.clone()); + RandomFillUniformInModulus::random_fill( + &mut p_rng, + &value.modulus, + data.get_row_mut(0), + ); -pub struct SeededRlwePublicKey { - data: Ro, - seed: S, - modulus: Ro::Element, -} + data.get_row_mut(1).copy_from_slice(value.data.as_ref()); -impl SeededRlwePublicKey { - pub(crate) fn empty(ring_size: usize, seed: S, modulus: Ro::Element) -> Self { - Self { - data: Ro::zeros(ring_size), - seed, - modulus, + RlweCiphertext { + data, + _phatom: PhantomData, + } } } -} -pub struct RlwePublicKey { - data: M, - _phantom: PhantomData, -} + struct SeededRlwePublicKey { + data: Ro, + seed: S, + modulus: Ro::Element, + } -impl< - M: MatrixMut + MatrixEntity, - Rng: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], M::MatElement>, - > From<&SeededRlwePublicKey> for RlwePublicKey -where - ::R: RowMut, - M::MatElement: Copy, - Rng::Seed: Copy, -{ - fn from(value: &SeededRlwePublicKey) -> Self { - let mut data = M::zeros(2, value.data.as_ref().len()); - - // sample a - let mut p_rng = Rng::new_with_seed(value.seed); - RandomFillUniformInModulus::random_fill(&mut p_rng, &value.modulus, data.get_row_mut(0)); - - // copy over b - data.get_row_mut(1).copy_from_slice(value.data.as_ref()); - - Self { - data, - _phantom: PhantomData, + impl SeededRlwePublicKey { + pub(crate) fn empty(ring_size: usize, seed: S, modulus: Ro::Element) -> Self { + Self { + data: Ro::zeros(ring_size), + seed, + modulus, + } } } -} - -#[derive(Clone)] -pub struct RlweSecret { - pub(crate) values: Vec, -} -impl Secret for RlweSecret { - type Element = i32; - fn values(&self) -> &[Self::Element] { - &self.values - } -} + struct RlwePublicKey { + data: M, + _phantom: PhantomData, + } + + impl< + M: MatrixMut + MatrixEntity, + Rng: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], M::MatElement>, + > From<&SeededRlwePublicKey> for RlwePublicKey + where + ::R: RowMut, + M::MatElement: Copy, + Rng::Seed: Copy, + { + fn from(value: &SeededRlwePublicKey) -> Self { + let mut data = M::zeros(2, value.data.as_ref().len()); + + // sample a + let mut p_rng = Rng::new_with_seed(value.seed); + RandomFillUniformInModulus::random_fill( + &mut p_rng, + &value.modulus, + data.get_row_mut(0), + ); -impl RlweSecret { - pub fn random(hw: usize, n: usize) -> RlweSecret { - DefaultSecureRng::with_local_mut(|rng| { - let mut out = vec![0i32; n]; - fill_random_ternary_secret_with_hamming_weight(&mut out, hw, rng); + // copy over b + data.get_row_mut(1).copy_from_slice(value.data.as_ref()); - RlweSecret { values: out } - }) + Self { + data, + _phantom: PhantomData, + } + } } -} - -#[cfg(test)] -pub(crate) mod tests { - use std::{clone, marker::PhantomData, ops::Mul, vec}; - use itertools::{izip, Itertools}; - use rand::{thread_rng, Rng}; - - use crate::{ - backend::{GetModulus, ModInit, ModularOpsU64, Modulus, VectorOps}, - decomposer::{Decomposer, DefaultDecomposer, RlweDecomposer}, - ntt::{Ntt, NttBackendU64, NttInit}, - random::{DefaultSecureRng, RandomFillGaussianInModulus, RandomFillUniformInModulus}, - rgsw::{ - galois_auto_shoup, rlwe_by_rgsw_shoup, ShoupAutoKeyEvaluationDomain, - ShoupRgswCiphertextEvaluationDomain, - }, - utils::{generate_prime, negacyclic_mul, tests::Stats, TryConvertFrom1}, - Matrix, MatrixMut, Secret, - }; - - use super::{ - keygen::{ - decrypt_rlwe, generate_auto_map, measure_noise, public_key_encrypt_rgsw, - rlwe_public_key, secret_key_encrypt_rgsw, seeded_auto_key_gen, - seeded_secret_key_encrypt_rlwe, - }, - runtime::{rgsw_by_rgsw_inplace, rlwe_auto, rlwe_by_rgsw}, - AutoKeyEvaluationDomain, RgswCiphertext, RgswCiphertextEvaluationDomain, RlweCiphertext, - RlwePublicKey, RlweSecret, SeededAutoKey, SeededRgswCiphertext, SeededRlweCiphertext, - SeededRlwePublicKey, - }; + #[derive(Clone)] + struct RlweSecret { + pub(crate) values: Vec, + } - pub(crate) fn _sk_encrypt_rlwe + Clone>( - m: &[u64], - s: &[i32], - ntt_op: &NttBackendU64, - mod_op: &ModularOpsU64, - ) -> RlweCiphertext>, DefaultSecureRng> { - let ring_size = m.len(); - let q = mod_op.modulus(); - assert!(s.len() == ring_size); + impl Secret for RlweSecret { + type Element = i32; + fn values(&self) -> &[Self::Element] { + &self.values + } + } - let mut rng = DefaultSecureRng::new(); - let mut rlwe_seed = [0u8; 32]; - rng.fill_bytes(&mut rlwe_seed); - let mut seeded_rlwe_ct = - SeededRlweCiphertext::<_, [u8; 32], _>::empty(ring_size as usize, rlwe_seed, q.clone()); - let mut p_rng = DefaultSecureRng::new_seeded(rlwe_seed); - seeded_secret_key_encrypt_rlwe( - &m, - &mut seeded_rlwe_ct.data, - s, - mod_op, - ntt_op, - &mut p_rng, - &mut rng, - ); + impl RlweSecret { + pub fn random(hw: usize, n: usize) -> RlweSecret { + DefaultSecureRng::with_local_mut(|rng| { + let mut out = vec![0i32; n]; + fill_random_ternary_secret_with_hamming_weight(&mut out, hw, rng); - RlweCiphertext::>, DefaultSecureRng>::from(&seeded_rlwe_ct) + RlweSecret { values: out } + }) + } } - // Encrypt m as RGSW ciphertext RGSW(m) using supplied public key - pub(crate) fn _pk_encrypt_rgsw + Clone>( - m: &[u64], - public_key: &RlwePublicKey>, DefaultSecureRng>, - decomposer: &(DefaultDecomposer, DefaultDecomposer), - mod_op: &ModularOpsU64, - ntt_op: &NttBackendU64, - ) -> RgswCiphertext>, T> { - let (_, ring_size) = Matrix::dimension(&public_key.data); - let gadget_vector_a = decomposer.a().gadget_vector(); - let gadget_vector_b = decomposer.b().gadget_vector(); - + fn random_seed() -> [u8; 32] { let mut rng = DefaultSecureRng::new(); - - assert!(m.len() == ring_size); - - // public key encrypt RGSW(m1) - let mut rgsw_ct = RgswCiphertext::empty(ring_size, decomposer, mod_op.modulus().clone()); - public_key_encrypt_rgsw( - &mut rgsw_ct.data, - m, - &public_key.data, - &gadget_vector_a, - &gadget_vector_b, - mod_op, - ntt_op, - &mut rng, - ); - - rgsw_ct + let mut seed = [0u8; 32]; + rng.fill_bytes(&mut seed); + seed } /// Encrypts m as RGSW ciphertext RGSW(m) using supplied secret key. Returns - /// unseeded RGSW ciphertext in coefficient domain - pub(crate) fn _sk_encrypt_rgsw + Clone>( + /// seeded RGSW ciphertext in coefficient domain + fn sk_encrypt_rgsw + Clone>( m: &[u64], s: &[i32], decomposer: &(DefaultDecomposer, DefaultDecomposer), @@ -602,14 +442,10 @@ pub(crate) mod tests { let ring_size = s.len(); assert!(m.len() == s.len()); - let q = mod_op.modulus(); - - let gadget_vector_a = decomposer.a().gadget_vector(); - let gadget_vector_b = decomposer.b().gadget_vector(); - let mut rng = DefaultSecureRng::new(); - let mut rgsw_seed = [0u8; 32]; - rng.fill_bytes(&mut rgsw_seed); + + let q = mod_op.modulus(); + let rgsw_seed = random_seed(); let mut seeded_rgsw_ct = SeededRgswCiphertext::>, [u8; 32], T>::empty( ring_size as usize, decomposer, @@ -620,8 +456,8 @@ pub(crate) mod tests { secret_key_encrypt_rgsw( &mut seeded_rgsw_ct.data, m, - &gadget_vector_a, - &gadget_vector_b, + &decomposer.a().gadget_vector(), + &decomposer.b().gadget_vector(), s, mod_op, ntt_op, @@ -659,11 +495,24 @@ pub(crate) mod tests { .iter() .map(|v| (((*v as f64) * q as f64) / (p as f64)).round() as u64) .collect_vec(); - let rlwe_in_ct = _sk_encrypt_rlwe(&encoded_m, s.values(), &ntt_op, &mod_op); + let seed = random_seed(); + let mut rlwe_in_ct = + SeededRlweCiphertext::, _, _>::empty(ring_size as usize, seed, q); + let mut p_rng = DefaultSecureRng::new_seeded(seed); + seeded_secret_key_encrypt_rlwe( + &encoded_m, + &mut rlwe_in_ct.data, + s.values(), + &mod_op, + &ntt_op, + &mut p_rng, + &mut rng, + ); + let rlwe_in_ct = RlweCiphertext::>, DefaultSecureRng>::from(&rlwe_in_ct); let mut encoded_m_back = vec![0u64; ring_size as usize]; decrypt_rlwe( - &rlwe_in_ct, + &rlwe_in_ct.data, s.values(), &mut encoded_m_back, &ntt_op, @@ -680,7 +529,7 @@ pub(crate) mod tests { fn rlwe_by_rgsw_works() { let logq = 50; let logp = 2; - let ring_size = 1 << 9; + let ring_size = 1 << 4; let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap(); let p: u64 = 1u64 << logp; @@ -707,8 +556,7 @@ pub(crate) mod tests { ); // create public key - let mut pk_seed = [0u8; 32]; - rng.fill_bytes(&mut pk_seed); + let pk_seed = random_seed(); let mut pk_prng = DefaultSecureRng::new_seeded(pk_seed); let mut seeded_pk = SeededRlwePublicKey::, _>::empty(ring_size as usize, pk_seed, q); @@ -720,25 +568,13 @@ pub(crate) mod tests { &mut pk_prng, &mut rng, ); - let pk = RlwePublicKey::>, DefaultSecureRng>::from(&seeded_pk); + // let pk = RlwePublicKey::>, DefaultSecureRng>::from(&seeded_pk); // Encrypt m1 as RGSW(m1) let rgsw_ct = { - //TODO(Jay): Figure out better way to test secret key and public key variant of - // RGSW ciphertext encryption within the same test - - if true { - // Encryption m1 as RGSW(m1) using secret key - let seeded_rgsw_ct = - _sk_encrypt_rgsw(&m1, s.values(), &decomposer, &mod_op, &ntt_op); - RgswCiphertextEvaluationDomain::>, _,DefaultSecureRng, NttBackendU64>::from(&seeded_rgsw_ct) - } else { - // Encrypt m1 as RGSW(m1) using public key - let rgsw_ct = _pk_encrypt_rgsw(&m1, &pk, &decomposer, &mod_op, &ntt_op); - RgswCiphertextEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from( - &rgsw_ct, - ) - } + // Encryption m1 as RGSW(m1) using secret key + let seeded_rgsw_ct = sk_encrypt_rgsw(&m1, s.values(), &decomposer, &mod_op, &ntt_op); + RgswCiphertextEvaluationDomain::>, _,DefaultSecureRng, NttBackendU64>::from(&seeded_rgsw_ct) }; // Encrypt m0 as RLWE(m0) @@ -748,29 +584,44 @@ pub(crate) mod tests { .map(|v| (((*v as f64) * q as f64) / (p as f64)).round() as u64) .collect_vec(); - _sk_encrypt_rlwe(&encoded_m, s.values(), &ntt_op, &mod_op) + let seed = random_seed(); + let mut p_rng = DefaultSecureRng::new_seeded(seed); + let mut seeded_rlwe = SeededRlweCiphertext::empty(ring_size as usize, seed, q); + seeded_secret_key_encrypt_rlwe( + &encoded_m, + &mut seeded_rlwe.data, + s.values(), + &mod_op, + &ntt_op, + &mut p_rng, + &mut rng, + ); + RlweCiphertext::>, DefaultSecureRng>::from(&seeded_rlwe) }; // RLWE(m0m1) = RLWE(m0) x RGSW(m1) - let mut scratch_space = vec![ - vec![0u64; ring_size as usize]; - std::cmp::max( - decomposer.a().decomposition_count(), - decomposer.b().decomposition_count() - ) + 2 - ]; + let mut scratch_space = + vec![vec![0u64; ring_size as usize]; rlwe_x_rgsw_scratch_rows(&decomposer)]; - // rlwe x rgsw with additional RGSW ciphertexts in shoup repr + // rlwe x rgsw with with soup repr let rlwe_in_ct_shoup = { let mut rlwe_in_ct_shoup = rlwe_in_ct.data.clone(); - let rgsw_ct_shoup = ShoupRgswCiphertextEvaluationDomain::from(&rgsw_ct); + let rgsw_ct_shoup = ToShoup::to_shoup(&rgsw_ct.data, q); rlwe_by_rgsw_shoup( - &mut rlwe_in_ct_shoup, - &rgsw_ct.data, - &rgsw_ct_shoup.data, - &mut scratch_space, + &mut RlweCiphertextMutRef::new(rlwe_in_ct_shoup.as_mut()), + &RgswCiphertextRef::new( + rgsw_ct.data.as_ref(), + decomposer.a().decomposition_count(), + decomposer.b().decomposition_count(), + ), + &RgswCiphertextRef::new( + rgsw_ct_shoup.as_ref(), + decomposer.a().decomposition_count(), + decomposer.b().decomposition_count(), + ), + &mut RuntimeScratchMutRef::new(scratch_space.as_mut()), &decomposer, &ntt_op, &mod_op, @@ -783,24 +634,27 @@ pub(crate) mod tests { // rlwe x rgsw normal { rlwe_by_rgsw( - &mut rlwe_in_ct, - &rgsw_ct.data, - &mut scratch_space, + &mut RlweCiphertextMutRef::new(rlwe_in_ct.data.as_mut()), + &RgswCiphertextRef::new( + rgsw_ct.data.as_ref(), + decomposer.a().decomposition_count(), + decomposer.b().decomposition_count(), + ), + &mut RuntimeScratchMutRef::new(scratch_space.as_mut()), &decomposer, &ntt_op, &mod_op, + false, ); } // output from both functions must be equal - { - assert_eq!(rlwe_in_ct.data, rlwe_in_ct_shoup); - } + assert_eq!(rlwe_in_ct.data, rlwe_in_ct_shoup); // Decrypt RLWE(m0m1) let mut encoded_m0m1_back = vec![0u64; ring_size as usize]; decrypt_rlwe( - &rlwe_in_ct, + &rlwe_in_ct_shoup, s.values(), &mut encoded_m0m1_back, &ntt_op, @@ -834,7 +688,7 @@ pub(crate) mod tests { } #[test] - fn galois_auto_works() { + fn rlwe_auto_works() { let logq = 55; let ring_size = 1 << 11; let q = generate_prime(logq, 2 * ring_size, 1u64 << logq).unwrap(); @@ -857,8 +711,7 @@ pub(crate) mod tests { let mod_op = ModularOpsU64::new(q); // RLWE_{s}(m) - let mut seed_rlwe = [0u8; 32]; - rng.fill_bytes(&mut seed_rlwe); + let seed_rlwe = random_seed(); let mut seeded_rlwe_m = SeededRlweCiphertext::empty(ring_size as usize, seed_rlwe, q); let mut p_rng = DefaultSecureRng::new_seeded(seed_rlwe); seeded_secret_key_encrypt_rlwe( @@ -874,10 +727,9 @@ pub(crate) mod tests { let auto_k = -125; - // Generate galois key to key switch from s^k to s + // Generate auto key to key switch from s^k to s let decomposer = DefaultDecomposer::new(q, logb, d_rgsw); - let mut seed_auto = [0u8; 32]; - rng.fill_bytes(&mut seed_auto); + let seed_auto = random_seed(); let mut seeded_auto_key = SeededAutoKey::empty(ring_size as usize, &decomposer, seed_auto, q); let mut p_rng = DefaultSecureRng::new_seeded(seed_auto); @@ -893,23 +745,24 @@ pub(crate) mod tests { &mut rng, ); let auto_key = - AutoKeyEvaluationDomain::>, _, DefaultSecureRng, NttBackendU64>::from( + AutoKeyEvaluationDomain::>, DefaultSecureRng, NttBackendU64>::from( &seeded_auto_key, ); // Send RLWE_{s}(m) -> RLWE_{s}(m^k) - let mut scratch_space = vec![vec![0u64; ring_size as usize]; d_rgsw + 2]; + let mut scratch_space = + vec![vec![0; ring_size as usize]; rlwe_auto_scratch_rows(&decomposer)]; let (auto_map_index, auto_map_sign) = generate_auto_map(ring_size as usize, auto_k); - // galois auto with additional auto key in shoup repr + // galois auto with auto key in shoup repr let rlwe_m_shoup = { - let auto_key_shoup = ShoupAutoKeyEvaluationDomain::from(&auto_key); + let auto_key_shoup = ToShoup::to_shoup(&auto_key.data, q); let mut rlwe_m_shoup = rlwe_m.data.clone(); - galois_auto_shoup( - &mut rlwe_m_shoup, - &auto_key.data, - &auto_key_shoup.data, - &mut scratch_space, + rlwe_auto_shoup( + &mut RlweCiphertextMutRef::new(&mut rlwe_m_shoup), + &RlweKskRef::new(&auto_key.data, decomposer.decomposition_count()), + &RlweKskRef::new(&auto_key_shoup, decomposer.decomposition_count()), + &mut RuntimeScratchMutRef::new(&mut scratch_space), &auto_map_index, &auto_map_sign, &mod_op, @@ -923,14 +776,15 @@ pub(crate) mod tests { // normal galois auto { rlwe_auto( - &mut rlwe_m, - &auto_key.data, - &mut scratch_space, + &mut RlweCiphertextMutRef::new(rlwe_m.data.as_mut()), + &RlweKskRef::new(auto_key.data.as_ref(), decomposer.decomposition_count()), + &mut RuntimeScratchMutRef::new(scratch_space.as_mut()), &auto_map_index, &auto_map_sign, &mod_op, &ntt_op, &decomposer, + false, ); } @@ -942,7 +796,7 @@ pub(crate) mod tests { // Decrypt RLWE_{s}(m^k) and check let mut encoded_m_k_back = vec![0u64; ring_size as usize]; decrypt_rlwe( - &rlwe_m_k, + &rlwe_m_k.data, s.values(), &mut encoded_m_k_back, &ntt_op, @@ -965,32 +819,105 @@ pub(crate) mod tests { }, ); - { - let encoded_m_k = m_k - .iter() - .map(|v| ((*v as f64 * q as f64) / p as f64).round() as u64) - .collect_vec(); + // { + // let encoded_m_k = m_k + // .iter() + // .map(|v| ((*v as f64 * q as f64) / p as f64).round() as u64) + // .collect_vec(); - let noise = measure_noise(&rlwe_m_k, &encoded_m_k, &ntt_op, &mod_op, s.values()); - println!("Ksk noise: {noise}"); - } + // let noise = measure_noise(&rlwe_m_k, &encoded_m_k, &ntt_op, &mod_op, + // s.values()); println!("Ksk noise: {noise}"); + // } assert_eq!(m_k_back, m_k); } + /// Collect noise stats of RGSW ciphertext + /// + /// - rgsw_ct: RGSW ciphertext must be in coefficient domain + fn rgsw_noise_stats + Clone>( + rgsw_ct: &[Vec], + m: &[u64], + s: &[i32], + decomposer: &(DefaultDecomposer, DefaultDecomposer), + q: &T, + ) -> Stats { + let gadget_vector_a = decomposer.a().gadget_vector(); + let gadget_vector_b = decomposer.b().gadget_vector(); + let d_a = gadget_vector_a.len(); + let d_b = gadget_vector_b.len(); + let ring_size = s.len(); + assert!(Matrix::dimension(&rgsw_ct) == (d_a * 2 + d_b * 2, ring_size)); + assert!(m.len() == ring_size); + + let mod_op = ModularOpsU64::new(q.clone()); + let ntt_op = NttBackendU64::new(q, ring_size); + + let mul_mod = + |a: &u64, b: &u64| ((*a as u128 * *b as u128) % q.q().unwrap() as u128) as u64; + let s_poly = Vec::::try_convert_from(s, q); + let mut neg_s = s_poly.clone(); + mod_op.elwise_neg_mut(neg_s.as_mut()); + let neg_sm0m1 = negacyclic_mul(&neg_s, &m, mul_mod, q.q().unwrap()); + + let mut stats = Stats::new(); + + // RLWE(\beta^j -s * m) + for j in 0..d_a { + let want_m = { + // RLWE(\beta^j -s * m) + let mut beta_neg_sm0m1 = vec![0u64; ring_size as usize]; + mod_op.elwise_scalar_mul(beta_neg_sm0m1.as_mut(), &neg_sm0m1, &gadget_vector_a[j]); + beta_neg_sm0m1 + }; + + let mut rlwe = vec![vec![0u64; ring_size as usize]; 2]; + rlwe[0].copy_from_slice(rgsw_ct.get_row_slice(j)); + rlwe[1].copy_from_slice(rgsw_ct.get_row_slice(d_a + j)); + + let mut got_m = vec![0; ring_size]; + decrypt_rlwe(&rlwe, s, &mut got_m, &ntt_op, &mod_op); + + let mut diff = want_m; + mod_op.elwise_sub_mut(diff.as_mut(), got_m.as_ref()); + stats.add_more(&Vec::::try_convert_from(&diff, q)); + } + + // RLWE(\beta^j m) + for j in 0..d_b { + let want_m = { + // RLWE(\beta^j m) + let mut beta_m0m1 = vec![0u64; ring_size as usize]; + mod_op.elwise_scalar_mul(beta_m0m1.as_mut(), &m, &gadget_vector_b[j]); + beta_m0m1 + }; + + let mut rlwe = vec![vec![0u64; ring_size as usize]; 2]; + rlwe[0].copy_from_slice(rgsw_ct.get_row_slice(d_a * 2 + j)); + rlwe[1].copy_from_slice(rgsw_ct.get_row_slice(d_a * 2 + d_b + j)); + + let mut got_m = vec![0; ring_size]; + decrypt_rlwe(&rlwe, s, &mut got_m, &ntt_op, &mod_op); + + let mut diff = want_m; + mod_op.elwise_sub_mut(diff.as_mut(), got_m.as_ref()); + stats.add_more(&Vec::::try_convert_from(&diff, q)); + } + + stats + } + #[test] - fn sk_rgsw_by_rgsw() { + fn print_noise_stats_rgsw_x_rgsw() { let logq = 60; let logp = 2; let ring_size = 1 << 11; let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap(); - let p = 1u64 << logp; let d_rgsw = 12; let logb = 5; let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize); - let mut rng = DefaultSecureRng::new(); let ntt_op = NttBackendU64::new(&q, ring_size as usize); let mod_op = ModularOpsU64::new(q); let decomposer = ( @@ -998,14 +925,17 @@ pub(crate) mod tests { DefaultDecomposer::new(q, logb, d_rgsw), ); + let d_a = decomposer.a().decomposition_count(); + let d_b = decomposer.b().decomposition_count(); + let mul_mod = |a: &u64, b: &u64| ((*a as u128 * *b as u128) % q as u128) as u64; let mut carry_m = vec![0u64; ring_size as usize]; - carry_m[thread_rng().gen_range(0..ring_size) as usize] = 1; + carry_m[thread_rng().gen_range(0..ring_size) as usize] = 1 << logp; // RGSW(carry_m) let mut rgsw_carrym = { - let seeded_rgsw = _sk_encrypt_rgsw(&carry_m, s.values(), &decomposer, &mod_op, &ntt_op); + let seeded_rgsw = sk_encrypt_rgsw(&carry_m, s.values(), &decomposer, &mod_op, &ntt_op); let mut rgsw_eval = RgswCiphertextEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from( &seeded_rgsw, @@ -1019,192 +949,36 @@ pub(crate) mod tests { let mut scratch_matrix = vec![ vec![0u64; ring_size as usize]; - decomposer.a().decomposition_count() * 2 - + decomposer.b().decomposition_count() * 2 - + std::cmp::max( - decomposer.a().decomposition_count(), - decomposer.b().decomposition_count() - ) + rgsw_x_rgsw_scratch_rows(&decomposer, &decomposer) ]; - // _measure_noise_rgsw(&rgsw_carrym, &carry_m, s.values(), &decomposer, &q); + rgsw_noise_stats(&rgsw_carrym, &carry_m, s.values(), &decomposer, &q); - for i in 0..2 { + for i in 0..8 { let mut m = vec![0u64; ring_size as usize]; m[thread_rng().gen_range(0..ring_size) as usize] = if (i & 1) == 1 { q - 1 } else { 1 }; let rgsw_m = RgswCiphertextEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from( - &_sk_encrypt_rgsw(&m, s.values(), &decomposer, &mod_op, &ntt_op), + &sk_encrypt_rgsw(&m, s.values(), &decomposer, &mod_op, &ntt_op), ); + rgsw_by_rgsw_inplace( - &mut rgsw_carrym, - decomposer.a().decomposition_count(), - decomposer.b().decomposition_count(), - &rgsw_m.data, + &mut RgswCiphertextMutRef::new(rgsw_carrym.as_mut(), d_a, d_b), + &RgswCiphertextRef::new(rgsw_m.data.as_ref(), d_a, d_b), + &decomposer, &decomposer, - &mut scratch_matrix, + &mut RuntimeScratchMutRef::new(scratch_matrix.as_mut()), &ntt_op, &mod_op, ); // measure noise carry_m = negacyclic_mul(&carry_m, &m, mul_mod, q); - println!("########### Noise RGSW(carrym) in {i}^th loop ###########"); - // _measure_noise_rgsw(&rgsw_carrym, &carry_m, s.values(), - // &decomposer, &q); - } - { - // RLWE(m) x RGSW(carry_m) - let mut m = vec![0u64; ring_size as usize]; - RandomFillUniformInModulus::random_fill(&mut rng, &q, m.as_mut_slice()); - let mut rlwe_ct = _sk_encrypt_rlwe(&m, s.values(), &ntt_op, &mod_op); - - // send rgsw to evaluation domain - rgsw_carrym - .iter_mut() - .for_each(|ri| ntt_op.forward(ri.as_mut_slice())); - - rlwe_by_rgsw( - &mut rlwe_ct, - &rgsw_carrym, - &mut scratch_matrix, - &decomposer, - &ntt_op, - &mod_op, - ); - let m_expected = negacyclic_mul(&carry_m, &m, mul_mod, q); - let noise = measure_noise(&rlwe_ct, &m_expected, &ntt_op, &mod_op, s.values()); - println!("RLWE(m) x RGSW(carry_m): {noise}"); - } - } - - #[test] - fn some_work() { - let logq = 55; - let ring_size = 1 << 11; - let q = generate_prime(logq, ring_size as u64, 1u64 << logq).unwrap(); - let d = 2; - let logb = 12; - let decomposer = DefaultDecomposer::new(q, logb, d); - - let ntt_op = NttBackendU64::new(&q, ring_size as usize); - let mod_op = ModularOpsU64::new(q); - let mut rng = DefaultSecureRng::new(); - - let mut stats = Stats::new(); - - for _ in 0..10 { - let mut a = vec![0u64; ring_size]; - RandomFillUniformInModulus::random_fill(&mut rng, &q, a.as_mut()); - let mut m = vec![0u64; ring_size]; - RandomFillGaussianInModulus::random_fill(&mut rng, &q, m.as_mut()); - - let mut sk = vec![0u64; ring_size]; - RandomFillGaussianInModulus::random_fill(&mut rng, &q, sk.as_mut()); - let mut sk_eval = sk.clone(); - ntt_op.forward(sk_eval.as_mut_slice()); - - let gadget_vector = decomposer.gadget_vector(); - - // ksk (beta e) - let mut ksk_part_b = vec![vec![0u64; ring_size]; decomposer.decomposition_count()]; - let mut ksk_part_a = vec![vec![0u64; ring_size]; decomposer.decomposition_count()]; - izip!( - ksk_part_b.iter_rows_mut(), - ksk_part_a.iter_rows_mut(), - gadget_vector.iter() - ) - .for_each(|(part_b, part_a, beta)| { - RandomFillUniformInModulus::random_fill(&mut rng, &q, part_a.as_mut()); - - // a * s - let mut tmp = part_a.to_vec(); - ntt_op.forward(tmp.as_mut()); - mod_op.elwise_mul_mut(tmp.as_mut(), sk_eval.as_ref()); - ntt_op.backward(tmp.as_mut()); - - // a*s + e + beta m - RandomFillGaussianInModulus::random_fill(&mut rng, &q, part_b.as_mut()); - // println!("E: {:?}", &part_b); - // a*s + e - mod_op.elwise_add_mut(part_b.as_mut_slice(), tmp.as_ref()); - // a*s + e + beta m - let mut tmp = m.to_vec(); - mod_op.elwise_scalar_mul_mut(tmp.as_mut_slice(), beta); - mod_op.elwise_add_mut(part_b.as_mut_slice(), tmp.as_ref()); - }); - - // decompose a - let mut decomposed_a = vec![vec![0u64; ring_size]; decomposer.decomposition_count()]; - a.iter().enumerate().for_each(|(ri, el)| { - decomposer - .decompose_iter(el) - .into_iter() - .enumerate() - .for_each(|(j, d_el)| { - decomposed_a[j][ri] = d_el; - }); - }); - - // println!("Last limb"); - - // decomp_a * ksk(beta m) - ksk_part_b - .iter_mut() - .for_each(|r| ntt_op.forward(r.as_mut_slice())); - ksk_part_a - .iter_mut() - .for_each(|r| ntt_op.forward(r.as_mut_slice())); - decomposed_a - .iter_mut() - .for_each(|r| ntt_op.forward(r.as_mut_slice())); - let mut out = vec![vec![0u64; ring_size]; 2]; - izip!(decomposed_a.iter(), ksk_part_b.iter(), ksk_part_a.iter()).for_each( - |(d_a, part_b, part_a)| { - // out_a += d_a * part_a - let mut d_a_clone = d_a.clone(); - mod_op.elwise_mul_mut(d_a_clone.as_mut_slice(), part_a.as_ref()); - mod_op.elwise_add_mut(out[0].as_mut_slice(), d_a_clone.as_ref()); - - // out_b += d_a * part_b - let mut d_a_clone = d_a.clone(); - mod_op.elwise_mul_mut(d_a_clone.as_mut_slice(), part_b.as_ref()); - mod_op.elwise_add_mut(out[1].as_mut_slice(), d_a_clone.as_ref()); - }, + let stats = rgsw_noise_stats(&rgsw_carrym, &carry_m, s.values(), &decomposer, &q); + println!( + "Log2 of noise std after {i} RGSW x RGSW: {}", + stats.std_dev().abs().log2() ); - out.iter_mut() - .for_each(|r| ntt_op.backward(r.as_mut_slice())); - - let out_back = { - // decrypt - // a*s - ntt_op.forward(out[0].as_mut()); - mod_op.elwise_mul_mut(out[0].as_mut(), sk_eval.as_ref()); - ntt_op.backward(out[0].as_mut()); - - // b - a*s - let tmp = (out[0]).clone(); - mod_op.elwise_sub_mut(out[1].as_mut(), tmp.as_ref()); - out.remove(1) - }; - - let out_expected = { - let mut a_clone = a.clone(); - let mut m_clone = m.clone(); - - ntt_op.forward(a_clone.as_mut_slice()); - ntt_op.forward(m_clone.as_mut_slice()); - - mod_op.elwise_mul_mut(a_clone.as_mut_slice(), m_clone.as_mut_slice()); - ntt_op.backward(a_clone.as_mut_slice()); - a_clone - }; - - let mut diff = out_expected; - mod_op.elwise_sub_mut(diff.as_mut_slice(), out_back.as_ref()); - stats.add_more(&Vec::::try_convert_from(diff.as_ref(), &q)); } - - println!("Std: {}", stats.std_dev().abs().log2()); } } diff --git a/src/rgsw/runtime.rs b/src/rgsw/runtime.rs index db34679..d8f66be 100644 --- a/src/rgsw/runtime.rs +++ b/src/rgsw/runtime.rs @@ -5,10 +5,305 @@ use crate::{ backend::{ArithmeticOps, GetModulus, ShoupMatrixFMA, VectorOps}, decomposer::{Decomposer, RlweDecomposer}, ntt::Ntt, - Matrix, MatrixEntity, MatrixMut, RowEntity, RowMut, + Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, }; -use super::IsTrivial; +/// Degree 1 RLWE ciphertext. +/// +/// RLWE(m) = [a, b] s.t. m+e = b - as +pub(crate) trait RlweCiphertext { + type R: RowMut; + /// Returns polynomial `a` of RLWE ciphertext as slice of elements + fn part_a(&self) -> &[::Element]; + /// Returns polynomial `a` of RLWE ciphertext as mutable slice of elements + fn part_a_mut(&mut self) -> &mut [::Element]; + /// Returns polynomial `b` of RLWE ciphertext as slice of elements + fn part_b(&self) -> &[::Element]; + /// Returns polynomial `b` of RLWE ciphertext as mut slice of elements + fn part_b_mut(&mut self) -> &mut [::Element]; + /// Returns ring size of polynomials + fn ring_size(&self) -> usize; +} + +/// RGSW ciphertext +/// +/// RGSW is a collection of RLWE' ciphertext which are collection degree 1 of +/// RLWE ciphertexts +/// +/// Let +/// RGSW = [RLWE'(-sm) || RLW'(m)] = [RW] +pub(crate) trait RgswCiphertext { + type R: Row; + + fn split(&self) -> ((&[Self::R], &[Self::R]), (&[Self::R], &[Self::R])); +} + +pub(crate) trait RgswCiphertextMut: RgswCiphertext { + fn split_mut( + &mut self, + ) -> ( + (&mut [Self::R], &mut [Self::R]), + (&mut [Self::R], &mut [Self::R]), + ); +} + +pub(crate) struct RlweCiphertextMutRef<'a, R> { + data: &'a mut [R], +} + +impl<'a, R> RlweCiphertextMutRef<'a, R> { + pub(crate) fn new(data: &'a mut [R]) -> Self { + Self { data } + } +} + +impl<'a, R: RowMut> RlweCiphertext for RlweCiphertextMutRef<'a, R> { + type R = R; + fn part_a(&self) -> &[::Element] { + self.data[0].as_ref() + } + fn part_a_mut(&mut self) -> &mut [::Element] { + self.data[0].as_mut() + } + fn part_b(&self) -> &[::Element] { + self.data[1].as_ref() + } + fn part_b_mut(&mut self) -> &mut [::Element] { + self.data[1].as_mut() + } + fn ring_size(&self) -> usize { + self.data[0].as_ref().len() + } +} + +pub(crate) struct RgswCiphertextRef<'a, R> { + data: &'a [R], + d_a: usize, + d_b: usize, +} + +impl<'a, R> RgswCiphertextRef<'a, R> { + pub(crate) fn new(data: &'a [R], d_a: usize, d_b: usize) -> Self { + RgswCiphertextRef { data, d_a, d_b } + } +} + +impl<'a, R> RgswCiphertext for RgswCiphertextRef<'a, R> +where + R: Row, +{ + type R = R; + + fn split(&self) -> ((&[Self::R], &[Self::R]), (&[Self::R], &[Self::R])) { + let (rlwe_dash_nsm, rlwe_dash_m) = self.data.split_at(self.d_a * 2); + ( + rlwe_dash_nsm.split_at(self.d_a), + rlwe_dash_m.split_at(self.d_b), + ) + } +} + +pub(crate) struct RgswCiphertextMutRef<'a, R> { + data: &'a mut [R], + d_a: usize, + d_b: usize, +} + +impl<'a, R> RgswCiphertextMutRef<'a, R> { + pub(crate) fn new(data: &'a mut [R], d_a: usize, d_b: usize) -> Self { + RgswCiphertextMutRef { data, d_a, d_b } + } +} + +impl<'a, R: RowMut> AsMut<[R]> for RgswCiphertextMutRef<'a, R> { + fn as_mut(&mut self) -> &mut [R] { + &mut self.data + } +} + +impl<'a, R> RgswCiphertext for RgswCiphertextMutRef<'a, R> +where + R: Row, +{ + type R = R; + + fn split(&self) -> ((&[Self::R], &[Self::R]), (&[Self::R], &[Self::R])) { + let (rlwe_dash_nsm, rlwe_dash_m) = self.data.split_at(self.d_a * 2); + ( + rlwe_dash_nsm.split_at(self.d_a), + rlwe_dash_m.split_at(self.d_b), + ) + } +} + +impl<'a, R> RgswCiphertextMut for RgswCiphertextMutRef<'a, R> +where + R: RowMut, +{ + fn split_mut( + &mut self, + ) -> ( + (&mut [Self::R], &mut [Self::R]), + (&mut [Self::R], &mut [Self::R]), + ) { + let (rlwe_dash_nsm, rlwe_dash_m) = self.data.split_at_mut(self.d_a * 2); + ( + rlwe_dash_nsm.split_at_mut(self.d_a), + rlwe_dash_m.split_at_mut(self.d_b), + ) + } +} + +pub(crate) trait RlweKsk { + type R: Row; + fn ksk_part_a(&self) -> &[Self::R]; + fn ksk_part_b(&self) -> &[Self::R]; +} + +pub(crate) struct RlweKskRef<'a, R> { + data: &'a [R], + decomposition_count: usize, +} +impl<'a, R: Row> RlweKskRef<'a, R> { + pub(crate) fn new(ksk: &'a [R], decomposition_count: usize) -> Self { + Self { + data: ksk, + decomposition_count, + } + } +} + +impl<'a, R: Row> RlweKsk for RlweKskRef<'a, R> { + type R = R; + + fn ksk_part_a(&self) -> &[Self::R] { + &self.data[..self.decomposition_count] + } + + fn ksk_part_b(&self) -> &[Self::R] { + &self.data[self.decomposition_count..] + } +} + +pub(crate) trait RlweAutoScratch { + type R: RowMut; + type Rgsw: RgswCiphertext; + + fn split_for_rlwe_auto_and_zero_rlwe_space( + &mut self, + decompostion_count: usize, + ) -> (&mut [Self::R], &mut [Self::R]); + + fn split_for_rlwe_auto_trivial_case(&mut self) -> &mut Self::R; + + fn split_for_rlwe_x_rgsw_and_zero_rlwe_space( + &mut self, + decomposer: &D, + ) -> (&mut [Self::R], &mut [Self::R]); + + fn split_for_rgsw_x_rgsw_and_zero_rgsw0_space( + &mut self, + d0: &D, + d1: &D, + ) -> (&mut [Self::R], &mut [Self::R]); +} + +pub(crate) struct RuntimeScratchMutRef<'a, R> { + data: &'a mut [R], +} + +impl<'a, R> RuntimeScratchMutRef<'a, R> { + pub(crate) fn new(data: &'a mut [R]) -> Self { + Self { data } + } +} + +impl<'a, R: RowMut> RlweAutoScratch for RuntimeScratchMutRef<'a, R> +where + R::Element: Zero + Clone, +{ + type R = R; + type Rgsw = RgswCiphertextRef<'a, R>; + + fn split_for_rlwe_auto_and_zero_rlwe_space( + &mut self, + decompostion_count: usize, + ) -> (&mut [Self::R], &mut [Self::R]) { + let (decomp_poly, other) = self.data.split_at_mut(decompostion_count); + let (rlwe, _) = other.split_at_mut(2); + + // zero fill rlwe + rlwe.iter_mut() + .for_each(|r| r.as_mut().fill(R::Element::zero())); + + (decomp_poly, rlwe) + } + + fn split_for_rlwe_auto_trivial_case(&mut self) -> &mut Self::R { + &mut self.data[0] + } + + fn split_for_rgsw_x_rgsw_and_zero_rgsw0_space( + &mut self, + rgsw0_decoposer: &D, + rgsw1_decoposer: &D, + ) -> (&mut [Self::R], &mut [Self::R]) { + let (decomp_poly, other) = self.data.split_at_mut(std::cmp::max( + rgsw1_decoposer.a().decomposition_count(), + rgsw1_decoposer.b().decomposition_count(), + )); + let (rgsw, _) = other.split_at_mut( + rgsw0_decoposer.a().decomposition_count() * 2 + + rgsw0_decoposer.b().decomposition_count() * 2, + ); + + // zero fill rgsw0 + rgsw.iter_mut() + .for_each(|r| r.as_mut().fill(R::Element::zero())); + + (decomp_poly, rgsw) + } + + fn split_for_rlwe_x_rgsw_and_zero_rlwe_space( + &mut self, + decomposer: &D, + ) -> (&mut [Self::R], &mut [Self::R]) { + let (decomp_poly, other) = self.data.split_at_mut(std::cmp::max( + decomposer.a().decomposition_count(), + decomposer.b().decomposition_count(), + )); + + let (rlwe, _) = other.split_at_mut(2); + + // zero fill rlwe + rlwe.iter_mut() + .for_each(|r| r.as_mut().fill(R::Element::zero())); + + (decomp_poly, rlwe) + } +} + +pub(crate) fn rgsw_x_rgsw_scratch_rows( + rgsw0_decomposer: &D, + rgsw1_decomposer: &D, +) -> usize { + std::cmp::max( + rgsw1_decomposer.a().decomposition_count(), + rgsw1_decomposer.b().decomposition_count(), + ) + rgsw0_decomposer.a().decomposition_count() * 2 + + rgsw0_decomposer.b().decomposition_count() * 2 +} + +pub(crate) fn rlwe_x_rgsw_scratch_rows(rgsw_decomposer: &D) -> usize { + std::cmp::max( + rgsw_decomposer.a().decomposition_count(), + rgsw_decomposer.b().decomposition_count(), + ) + 2 +} + +pub(crate) fn rlwe_auto_scratch_rows(decomposer: &D) -> usize { + decomposer.decomposition_count() + 2 +} pub(crate) fn poly_fma_routine>( write_to_row: &mut [R::Element], @@ -60,49 +355,38 @@ pub(crate) fn decompose_r>( /// - scratch_matrix: must have dimension at-least d+2 x ring_size. `d` rows to /// store decomposed polynomials nad 2 rows to store out RLWE temporarily. pub(crate) fn rlwe_auto< - MT: Matrix + IsTrivial + MatrixMut, - Mmut: MatrixMut, - ModOp: ArithmeticOps + VectorOps, - NttOp: Ntt, - D: Decomposer, + Rlwe: RlweCiphertext, + Ksk: RlweKsk, + Sc: RlweAutoScratch, + ModOp: ArithmeticOps::Element> + + VectorOps::Element>, + NttOp: Ntt::Element>, + D: Decomposer::Element>, >( - rlwe_in: &mut MT, - ksk: &Mmut, - scratch_matrix: &mut Mmut, + rlwe_in: &mut Rlwe, + ksk: &Ksk, + scratch_matrix: &mut Sc, auto_map_index: &[usize], auto_map_sign: &[bool], mod_op: &ModOp, ntt_op: &NttOp, decomposer: &D, + is_trivial: bool, ) where - ::R: RowMut, - ::R: RowMut, - MT::MatElement: Copy + Zero, + ::Element: Copy + Zero, { - let d = decomposer.decomposition_count(); - let ring_size = rlwe_in.dimension().1; - assert!(rlwe_in.dimension().0 == 2); - assert!(scratch_matrix.fits(d + 2, ring_size)); - - // scratch matrix is guaranteed to have at-least d+2 rows but can have more than - // d+2 rows. We require to split them into sub-matrices of exact sizes one with - // d rows for storing decomposed polynomial and second with 2 rows to act - // tomperary space for RLWE ciphertext. Exact sizes is necessary to avoid any - // irrelevant extra FMA or NTT ops. - let (scratch_matrix_d_ring, other_half) = scratch_matrix.split_at_row_mut(d); - let (tmp_rlwe_out, _) = other_half.split_at_mut(2); - - debug_assert!(tmp_rlwe_out.len() == 2); - debug_assert!(scratch_matrix_d_ring.len() == d); - - if !rlwe_in.is_trivial() { - tmp_rlwe_out.iter_mut().for_each(|r| { - r.as_mut().fill(Mmut::MatElement::zero()); - }); + // let ring_size = rlwe_in.dimension().1; + // assert!(rlwe_in.dimension().0 == 2); + // assert!(scratch_matrix.fits(d + 2, ring_size)); + + if !is_trivial { + let (decomp_poly_scratch, tmp_rlwe) = scratch_matrix + .split_for_rlwe_auto_and_zero_rlwe_space(decomposer.decomposition_count()); + let mut tmp_rlwe = RlweCiphertextMutRef::new(tmp_rlwe); // send a(X) -> a(X^k) and decompose a(X^k) izip!( - rlwe_in.get_row(0), + rlwe_in.part_a(), auto_map_index.iter(), auto_map_sign.iter() ) @@ -113,47 +397,45 @@ pub(crate) fn rlwe_auto< .decompose_iter(&el_out) .enumerate() .for_each(|(index, el)| { - scratch_matrix_d_ring[index].as_mut()[*to_index] = el; + decomp_poly_scratch[index].as_mut()[*to_index] = el; }); }); // transform decomposed a(X^k) to evaluation domain - scratch_matrix_d_ring.iter_mut().for_each(|r| { + decomp_poly_scratch.iter_mut().for_each(|r| { ntt_op.forward(r.as_mut()); }); // RLWE(m^k) = a', b'; RLWE(m) = a, b // key switch: (a * RLWE'(s(X^k))) - let (ksk_a, ksk_b) = ksk.split_at_row(d); // a' = decomp * RLWE'_A(s(X^k)) poly_fma_routine( - tmp_rlwe_out[0].as_mut(), - scratch_matrix_d_ring, - ksk_a, + tmp_rlwe.part_a_mut(), + decomp_poly_scratch, + ksk.ksk_part_a(), mod_op, ); // b' += decomp * RLWE'_B(s(X^k)) poly_fma_routine( - tmp_rlwe_out[1].as_mut(), - scratch_matrix_d_ring, - ksk_b, + tmp_rlwe.part_b_mut(), + decomp_poly_scratch, + ksk.ksk_part_b(), mod_op, ); // transform RLWE(m^k) to coefficient domain - tmp_rlwe_out - .iter_mut() - .for_each(|r| ntt_op.backward(r.as_mut())); + ntt_op.backward(tmp_rlwe.part_a_mut()); + ntt_op.backward(tmp_rlwe.part_b_mut()); // send b(X) -> b(X^k) and then b'(X) += b(X^k) izip!( - rlwe_in.get_row(1), + rlwe_in.part_b(), auto_map_index.iter(), auto_map_sign.iter() ) .for_each(|(el_in, to_index, sign)| { - let row = tmp_rlwe_out[1].as_mut(); + let row = tmp_rlwe.part_b_mut(); if !*sign { row[*to_index] = mod_op.sub(&row[*to_index], el_in); } else { @@ -162,30 +444,26 @@ pub(crate) fn rlwe_auto< }); // copy over A; Leave B for later - rlwe_in - .get_row_mut(0) - .copy_from_slice(tmp_rlwe_out[0].as_ref()); + rlwe_in.part_a_mut().copy_from_slice(tmp_rlwe.part_a()); + rlwe_in.part_b_mut().copy_from_slice(tmp_rlwe.part_b()); } else { // RLWE is trivial, a(X) is 0. // send b(X) -> b(X^k) + let tmp_row = scratch_matrix.split_for_rlwe_auto_trivial_case(); izip!( - rlwe_in.get_row(1), + rlwe_in.part_b(), auto_map_index.iter(), auto_map_sign.iter() ) .for_each(|(el_in, to_index, sign)| { if !*sign { - tmp_rlwe_out[1].as_mut()[*to_index] = mod_op.neg(el_in); + tmp_row.as_mut()[*to_index] = mod_op.neg(el_in); } else { - tmp_rlwe_out[1].as_mut()[*to_index] = *el_in; + tmp_row.as_mut()[*to_index] = *el_in; } }); + rlwe_in.part_b_mut().copy_from_slice(tmp_row.as_ref()); } - - // Copy over B - rlwe_in - .get_row_mut(1) - .copy_from_slice(tmp_rlwe_out[1].as_ref()); } /// Sends RLWE_{s(X)}(m(X)) -> RLWE_{s(X)}(m{X^k}) where k is some galois @@ -194,18 +472,20 @@ pub(crate) fn rlwe_auto< /// This is same as `galois_auto` with the difference that alongside `ksk` with /// key switching polynomials in evaluation domain, shoup representation, /// `ksk_shoup`, of the polynomials in evaluation domain is also supplied. -pub(crate) fn galois_auto_shoup< - Mmut: MatrixMut, - ModOp: ArithmeticOps +pub(crate) fn rlwe_auto_shoup< + Rlwe: RlweCiphertext, + Ksk: RlweKsk, + Sc: RlweAutoScratch, + ModOp: ArithmeticOps::Element> // + VectorOps - + ShoupMatrixFMA, - NttOp: Ntt, - D: Decomposer, + + ShoupMatrixFMA, + NttOp: Ntt::Element>, + D: Decomposer::Element>, >( - rlwe_in: &mut Mmut, - ksk: &Mmut, - ksk_shoup: &Mmut, - scratch_matrix: &mut Mmut, + rlwe_in: &mut Rlwe, + ksk: &Ksk, + ksk_shoup: &Ksk, + scratch_matrix: &mut Sc, auto_map_index: &[usize], auto_map_sign: &[bool], mod_op: &ModOp, @@ -213,28 +493,21 @@ pub(crate) fn galois_auto_shoup< decomposer: &D, is_trivial: bool, ) where - ::R: RowMut, - Mmut::MatElement: Copy + Zero, + ::Element: Copy + Zero, { - let d = decomposer.decomposition_count(); - let ring_size = rlwe_in.dimension().1; - assert!(rlwe_in.dimension().0 == 2); - assert!(scratch_matrix.fits(d + 2, ring_size)); - - let (scratch_matrix_d_ring, other_half) = scratch_matrix.split_at_row_mut(d); - let (tmp_rlwe_out, _) = other_half.split_at_mut(2); - - debug_assert!(tmp_rlwe_out.len() == 2); - debug_assert!(scratch_matrix_d_ring.len() == d); + // let d = decomposer.decomposition_count(); + // let ring_size = rlwe_in.dimension().1; + // assert!(rlwe_in.dimension().0 == 2); + // assert!(scratch_matrix.fits(d + 2, ring_size)); if !is_trivial { - tmp_rlwe_out.iter_mut().for_each(|r| { - r.as_mut().fill(Mmut::MatElement::zero()); - }); + let (decomp_poly_scratch, tmp_rlwe) = scratch_matrix + .split_for_rlwe_auto_and_zero_rlwe_space(decomposer.decomposition_count()); + let mut tmp_rlwe = RlweCiphertextMutRef::new(tmp_rlwe); // send a(X) -> a(X^k) and decompose a(X^k) izip!( - rlwe_in.get_row(0), + rlwe_in.part_a(), auto_map_index.iter(), auto_map_sign.iter() ) @@ -245,48 +518,45 @@ pub(crate) fn galois_auto_shoup< .decompose_iter(&el_out) .enumerate() .for_each(|(index, el)| { - scratch_matrix_d_ring[index].as_mut()[*to_index] = el; + decomp_poly_scratch[index].as_mut()[*to_index] = el; }); }); // transform decomposed a(X^k) to evaluation domain - scratch_matrix_d_ring.iter_mut().for_each(|r| { + decomp_poly_scratch.iter_mut().for_each(|r| { ntt_op.forward_lazy(r.as_mut()); }); // RLWE(m^k) = a', b'; RLWE(m) = a, b // key switch: (a * RLWE'(s(X^k))) - let (ksk_a, ksk_b) = ksk.split_at_row(d); - let (ksk_a_shoup, ksk_b_shoup) = ksk_shoup.split_at_row(d); // a' = decomp * RLWE'_A(s(X^k)) mod_op.shoup_matrix_fma( - tmp_rlwe_out[0].as_mut(), - ksk_a, - ksk_a_shoup, - scratch_matrix_d_ring, + tmp_rlwe.part_a_mut(), + ksk.ksk_part_a(), + ksk_shoup.ksk_part_a(), + decomp_poly_scratch, ); // b'= decomp * RLWE'_B(s(X^k)) mod_op.shoup_matrix_fma( - tmp_rlwe_out[1].as_mut(), - ksk_b, - ksk_b_shoup, - scratch_matrix_d_ring, + tmp_rlwe.part_b_mut(), + ksk.ksk_part_b(), + ksk_shoup.ksk_part_b(), + decomp_poly_scratch, ); // transform RLWE(m^k) to coefficient domain - tmp_rlwe_out - .iter_mut() - .for_each(|r| ntt_op.backward(r.as_mut())); + ntt_op.backward(tmp_rlwe.part_a_mut()); + ntt_op.backward(tmp_rlwe.part_b_mut()); // send b(X) -> b(X^k) and then b'(X) += b(X^k) + let row = tmp_rlwe.part_b_mut(); izip!( - rlwe_in.get_row(1), + rlwe_in.part_b(), auto_map_index.iter(), auto_map_sign.iter() ) .for_each(|(el_in, to_index, sign)| { - let row = tmp_rlwe_out[1].as_mut(); if !*sign { row[*to_index] = mod_op.sub(&row[*to_index], el_in); } else { @@ -294,31 +564,27 @@ pub(crate) fn galois_auto_shoup< } }); - // copy over A; Leave B for later - rlwe_in - .get_row_mut(0) - .copy_from_slice(tmp_rlwe_out[0].as_ref()); + // copy over A, B + rlwe_in.part_a_mut().copy_from_slice(tmp_rlwe.part_a()); + rlwe_in.part_b_mut().copy_from_slice(tmp_rlwe.part_b()); } else { // RLWE is trivial, a(X) is 0. // send b(X) -> b(X^k) + let row = scratch_matrix.split_for_rlwe_auto_trivial_case(); izip!( - rlwe_in.get_row(1), + rlwe_in.part_b(), auto_map_index.iter(), auto_map_sign.iter() ) .for_each(|(el_in, to_index, sign)| { if !*sign { - tmp_rlwe_out[1].as_mut()[*to_index] = mod_op.neg(el_in); + row.as_mut()[*to_index] = mod_op.neg(el_in); } else { - tmp_rlwe_out[1].as_mut()[*to_index] = *el_in; + row.as_mut()[*to_index] = *el_in; } }); + rlwe_in.part_b_mut().copy_from_slice(row.as_ref()); } - - // Copy over B - rlwe_in - .get_row_mut(1) - .copy_from_slice(tmp_rlwe_out[1].as_ref()); } /// Inplace mutates RLWE(m0) to equal RLWE(m0m1) = RLWE(m0) x RGSW(m1). @@ -328,104 +594,101 @@ pub(crate) fn galois_auto_shoup< /// - scratch_matrix: with dimension (max(d_a, d_b) + 2) x ring_size columns. /// It's used to store decomposed polynomials and out RLWE temporarily pub(crate) fn rlwe_by_rgsw< - Mmut: MatrixMut, - MT: Matrix + MatrixMut + IsTrivial, - D: RlweDecomposer, - ModOp: VectorOps, - NttOp: Ntt, + Rlwe: RlweCiphertext, + Rgsw: RgswCiphertext, + Sc: RlweAutoScratch, + D: RlweDecomposer::Element>, + ModOp: VectorOps::Element>, + NttOp: Ntt::Element>, >( - rlwe_in: &mut MT, - rgsw_in: &Mmut, - scratch_matrix: &mut Mmut, + rlwe_in: &mut Rlwe, + rgsw_in: &Rgsw, + scratch_matrix: &mut Sc, decomposer: &D, ntt_op: &NttOp, mod_op: &ModOp, + is_trivial: bool, ) where - Mmut::MatElement: Copy + Zero, - ::R: RowMut, - ::R: RowMut, + ::Element: Copy + Zero, { let decomposer_a = decomposer.a(); let decomposer_b = decomposer.b(); let d_a = decomposer_a.decomposition_count(); let d_b = decomposer_b.decomposition_count(); - let max_d = std::cmp::max(d_a, d_b); - assert!(scratch_matrix.fits(max_d + 2, rlwe_in.dimension().1)); - assert!(rgsw_in.dimension() == (d_a * 2 + d_b * 2, rlwe_in.dimension().1)); - // decomposed RLWE x RGSW - let (rlwe_dash_nsm, rlwe_dash_m) = rgsw_in.split_at_row(d_a * 2); - let (scratch_matrix_d_ring, rest) = scratch_matrix.split_at_row_mut(max_d); - let (scratch_rlwe_out, _) = rest.split_at_mut(2); + let ((rlwe_dash_nsm_parta, rlwe_dash_nsm_partb), (rlwe_dash_m_parta, rlwe_dash_m_partb)) = + rgsw_in.split(); - scratch_rlwe_out[0].as_mut().fill(Mmut::MatElement::zero()); - scratch_rlwe_out[1].as_mut().fill(Mmut::MatElement::zero()); + let (decomposed_poly_scratch, tmp_rlwe) = + scratch_matrix.split_for_rlwe_x_rgsw_and_zero_rlwe_space(decomposer); // RLWE_in = a_in, b_in; RLWE_out = a_out, b_out - if !rlwe_in.is_trivial() { + if !is_trivial { // a_in = 0 when RLWE_in is trivial RLWE ciphertext // decomp + let mut decomposed_polys_of_rlwea = &mut decomposed_poly_scratch[..d_a]; decompose_r( - rlwe_in.get_row_slice(0), - &mut scratch_matrix_d_ring[..d_a], + rlwe_in.part_a(), + &mut decomposed_polys_of_rlwea, decomposer_a, ); - scratch_matrix_d_ring + + decomposed_polys_of_rlwea .iter_mut() - .take(d_a) .for_each(|r| ntt_op.forward(r.as_mut())); + // a_out += decomp \cdot RLWE_A'(-sm) poly_fma_routine( - scratch_rlwe_out[0].as_mut(), - &scratch_matrix_d_ring[..d_a], - &rlwe_dash_nsm[..d_a], + tmp_rlwe[0].as_mut(), + &decomposed_polys_of_rlwea, + rlwe_dash_nsm_parta, mod_op, ); // b_out += decomp \cdot RLWE_B'(-sm) poly_fma_routine( - scratch_rlwe_out[1].as_mut(), - &scratch_matrix_d_ring[..d_a], - &rlwe_dash_nsm[d_a..], + tmp_rlwe[1].as_mut(), + &decomposed_polys_of_rlwea, + &rlwe_dash_nsm_partb, + mod_op, + ); + } + + { + // decomp + let mut decomposed_polys_of_rlweb = &mut decomposed_poly_scratch[..d_b]; + decompose_r( + rlwe_in.part_b(), + &mut decomposed_polys_of_rlweb, + decomposer_b, + ); + + decomposed_polys_of_rlweb + .iter_mut() + .for_each(|r| ntt_op.forward(r.as_mut())); + + // a_out += decomp \cdot RLWE_A'(m) + poly_fma_routine( + tmp_rlwe[0].as_mut(), + &decomposed_polys_of_rlweb, + &rlwe_dash_m_parta, + mod_op, + ); + // b_out += decomp \cdot RLWE_B'(m) + poly_fma_routine( + tmp_rlwe[1].as_mut(), + &decomposed_polys_of_rlweb, + &rlwe_dash_m_partb, mod_op, ); } - // decomp - decompose_r( - rlwe_in.get_row_slice(1), - &mut scratch_matrix_d_ring[..d_b], - decomposer_b, - ); - scratch_matrix_d_ring - .iter_mut() - .take(d_b) - .for_each(|r| ntt_op.forward(r.as_mut())); - // a_out += decomp \cdot RLWE_A'(m) - poly_fma_routine( - scratch_rlwe_out[0].as_mut(), - &scratch_matrix_d_ring[..d_b], - &rlwe_dash_m[..d_b], - mod_op, - ); - // b_out += decomp \cdot RLWE_B'(m) - poly_fma_routine( - scratch_rlwe_out[1].as_mut(), - &scratch_matrix_d_ring[..d_b], - &rlwe_dash_m[d_b..], - mod_op, - ); // transform rlwe_out to coefficient domain - scratch_rlwe_out + tmp_rlwe .iter_mut() .for_each(|r| ntt_op.backward(r.as_mut())); - rlwe_in - .get_row_mut(0) - .copy_from_slice(scratch_rlwe_out[0].as_mut()); - rlwe_in - .get_row_mut(1) - .copy_from_slice(scratch_rlwe_out[1].as_mut()); - rlwe_in.set_not_trivial(); + rlwe_in.part_a_mut().copy_from_slice(tmp_rlwe[0].as_mut()); + rlwe_in.part_b_mut().copy_from_slice(tmp_rlwe[1].as_mut()); } /// Inplace mutates RLWE(m0) to equal RLWE(m0m1) = RLWE(m0) x RGSW(m1). @@ -434,111 +697,106 @@ pub(crate) fn rlwe_by_rgsw< /// polynomials in evaluation domain, shoup representation of polynomials in /// evaluation domain, `rgsw_in_shoup`, is also supplied. pub(crate) fn rlwe_by_rgsw_shoup< - Mmut: MatrixMut, - D: RlweDecomposer, - ModOp: ShoupMatrixFMA, - NttOp: Ntt, + Rlwe: RlweCiphertext, + Rgsw: RgswCiphertext, + Sc: RlweAutoScratch, + D: RlweDecomposer::Element>, + ModOp: ShoupMatrixFMA, + NttOp: Ntt::Element>, >( - rlwe_in: &mut Mmut, - rgsw_in: &Mmut, - rgsw_in_shoup: &Mmut, - scratch_matrix: &mut Mmut, + rlwe_in: &mut Rlwe, + rgsw_in: &Rgsw, + rgsw_in_shoup: &Rgsw, + scratch_matrix: &mut Sc, decomposer: &D, ntt_op: &NttOp, mod_op: &ModOp, is_trivial: bool, ) where - Mmut::MatElement: Copy + Zero, - ::R: RowMut, + ::Element: Copy + Zero, { let decomposer_a = decomposer.a(); let decomposer_b = decomposer.b(); let d_a = decomposer_a.decomposition_count(); let d_b = decomposer_b.decomposition_count(); - let max_d = std::cmp::max(d_a, d_b); - assert!(scratch_matrix.fits(max_d + 2, rlwe_in.dimension().1)); - assert!(rgsw_in.dimension() == (d_a * 2 + d_b * 2, rlwe_in.dimension().1)); - assert!(rgsw_in.dimension() == rgsw_in_shoup.dimension()); - // decomposed RLWE x RGSW - let (rlwe_dash_nsm, rlwe_dash_m) = rgsw_in.split_at_row(d_a * 2); - let (rlwe_dash_nsm_shoup, rlwe_dash_m_shoup) = rgsw_in_shoup.split_at_row(d_a * 2); - let (scratch_matrix_d_ring, rest) = scratch_matrix.split_at_row_mut(max_d); - let (scratch_rlwe_out, _) = rest.split_at_mut(2); + let ((rlwe_dash_nsm_parta, rlwe_dash_nsm_partb), (rlwe_dash_m_parta, rlwe_dash_m_partb)) = + rgsw_in.split(); + + let ( + (rlwe_dash_nsm_parta_shoup, rlwe_dash_nsm_partb_shoup), + (rlwe_dash_m_parta_shoup, rlwe_dash_m_partb_shoup), + ) = rgsw_in_shoup.split(); - scratch_rlwe_out[1].as_mut().fill(Mmut::MatElement::zero()); - scratch_rlwe_out[0].as_mut().fill(Mmut::MatElement::zero()); + let (decomposed_poly_scratch, tmp_rlwe) = + scratch_matrix.split_for_rlwe_x_rgsw_and_zero_rlwe_space(decomposer); // RLWE_in = a_in, b_in; RLWE_out = a_out, b_out if !is_trivial { // a_in = 0 when RLWE_in is trivial RLWE ciphertext // decomp + let mut decomposed_polys_of_rlwea = &mut decomposed_poly_scratch[..d_a]; decompose_r( - rlwe_in.get_row_slice(0), - &mut scratch_matrix_d_ring[..d_a], + rlwe_in.part_a(), + &mut decomposed_polys_of_rlwea, decomposer_a, ); - scratch_matrix_d_ring + decomposed_polys_of_rlwea .iter_mut() - .take(d_a) .for_each(|r| ntt_op.forward_lazy(r.as_mut())); // a_out += decomp \cdot RLWE_A'(-sm) mod_op.shoup_matrix_fma( - scratch_rlwe_out[0].as_mut(), - &rlwe_dash_nsm[..d_a], - &rlwe_dash_nsm_shoup[..d_a], - &scratch_matrix_d_ring[..d_a], + tmp_rlwe[0].as_mut(), + &rlwe_dash_nsm_parta, + &rlwe_dash_nsm_parta_shoup, + &decomposed_polys_of_rlwea, ); // b_out += decomp \cdot RLWE_B'(-sm) mod_op.shoup_matrix_fma( - scratch_rlwe_out[1].as_mut(), - &rlwe_dash_nsm[d_a..], - &rlwe_dash_nsm_shoup[d_a..], - &scratch_matrix_d_ring[..d_a], + tmp_rlwe[1].as_mut(), + &rlwe_dash_nsm_partb, + &rlwe_dash_nsm_partb_shoup, + &decomposed_polys_of_rlwea, ); } { // decomp + let mut decomposed_polys_of_rlweb = &mut decomposed_poly_scratch[..d_b]; decompose_r( - rlwe_in.get_row_slice(1), - &mut scratch_matrix_d_ring[..d_b], + rlwe_in.part_b(), + &mut decomposed_polys_of_rlweb, decomposer_b, ); - scratch_matrix_d_ring + decomposed_polys_of_rlweb .iter_mut() - .take(d_b) .for_each(|r| ntt_op.forward_lazy(r.as_mut())); // a_out += decomp \cdot RLWE_A'(m) mod_op.shoup_matrix_fma( - scratch_rlwe_out[0].as_mut(), - &rlwe_dash_m[..d_b], - &rlwe_dash_m_shoup[..d_b], - &scratch_matrix_d_ring[..d_b], + tmp_rlwe[0].as_mut(), + &rlwe_dash_m_parta, + &rlwe_dash_m_parta_shoup, + &decomposed_polys_of_rlweb, ); // b_out += decomp \cdot RLWE_B'(m) mod_op.shoup_matrix_fma( - scratch_rlwe_out[1].as_mut(), - &rlwe_dash_m[d_b..], - &rlwe_dash_m_shoup[d_b..], - &scratch_matrix_d_ring[..d_b], + tmp_rlwe[1].as_mut(), + &rlwe_dash_m_partb, + &rlwe_dash_m_partb_shoup, + &decomposed_polys_of_rlweb, ); } // transform rlwe_out to coefficient domain - scratch_rlwe_out + tmp_rlwe .iter_mut() .for_each(|r| ntt_op.backward(r.as_mut())); - rlwe_in - .get_row_mut(0) - .copy_from_slice(scratch_rlwe_out[0].as_mut()); - rlwe_in - .get_row_mut(1) - .copy_from_slice(scratch_rlwe_out[1].as_mut()); + rlwe_in.part_a_mut().copy_from_slice(tmp_rlwe[0].as_mut()); + rlwe_in.part_b_mut().copy_from_slice(tmp_rlwe[1].as_mut()); } /// Inplace mutates RGSW(m0) to equal RGSW(m0m1) = RGSW(m0)xRGSW(m1) @@ -568,60 +826,52 @@ pub(crate) fn rlwe_by_rgsw_shoup< /// - rgsw_0: RGSW(m0) in coefficient domain /// - rgsw_1_eval: RGSW(m1) in evaluation domain pub(crate) fn rgsw_by_rgsw_inplace< - Mmut: MatrixMut, - D: RlweDecomposer, - ModOp: VectorOps, - NttOp: Ntt, + Rgsw: RgswCiphertext, + RgswMut: RgswCiphertextMut, + Sc: RlweAutoScratch, + D: RlweDecomposer::Element>, + ModOp: VectorOps::Element>, + NttOp: Ntt::Element>, >( - rgsw0: &mut Mmut, - rgsw0_da: usize, - rgsw0_db: usize, - rgsw1_eval: &Mmut, - decomposer: &D, - scratch_matrix: &mut Mmut, + rgsw0: &mut RgswMut, + rgsw1_eval: &Rgsw, + rgsw0_decomposer: &D, + rgsw1_decomposer: &D, + scratch_matrix: &mut Sc, ntt_op: &NttOp, mod_op: &ModOp, ) where - ::R: RowMut, - Mmut::MatElement: Copy + Zero, + ::Element: Copy + Zero, + RgswMut: AsMut<[Rgsw::R]>, + RgswMut::R: RowMut, + // Rgsw: AsRef<[Rgsw::R]>, { - let decomposer_a = decomposer.a(); - let decomposer_b = decomposer.b(); - let d_a = decomposer_a.decomposition_count(); - let d_b = decomposer_b.decomposition_count(); - let max_d = std::cmp::max(d_a, d_b); - let rgsw1_rows = d_a * 2 + d_b * 2; - let rgsw0_rows = rgsw0_da * 2 + rgsw0_db * 2; - let ring_size = rgsw0.dimension().1; - assert!(rgsw0.dimension().0 == rgsw0_rows); - assert!(rgsw1_eval.dimension() == (rgsw1_rows, ring_size)); - assert!(scratch_matrix.fits(max_d + rgsw0_rows, ring_size)); - - let (decomp_r_space, rgsw_space) = scratch_matrix.split_at_row_mut(max_d); - - // zero rgsw_space - rgsw_space - .iter_mut() - .for_each(|ri| ri.as_mut().fill(Mmut::MatElement::zero())); - let (rlwe_dash_space_nsm, rlwe_dash_space_m) = rgsw_space.split_at_mut(rgsw0_da * 2); - let (rlwe_dash_space_nsm_parta, rlwe_dash_space_nsm_partb) = - rlwe_dash_space_nsm.split_at_mut(rgsw0_da); - let (rlwe_dash_space_m_parta, rlwe_dash_space_m_partb) = - rlwe_dash_space_m.split_at_mut(rgsw0_db); + // let rgsw0_rows = rgsw0_da * 2 + rgsw0_db * 2; + // let ring_size = rgsw0.dimension().1; + // assert!(rgsw0.dimension().0 == rgsw0_rows); + // assert!(rgsw1_eval.dimension() == (rgsw1_rows, ring_size)); + // assert!(scratch_matrix.fits(max_d + rgsw0_rows, ring_size)); + + let (decomp_r_space, rgsw_space) = scratch_matrix + .split_for_rgsw_x_rgsw_and_zero_rgsw0_space(rgsw0_decomposer, rgsw1_decomposer); + + let mut rgsw_space = RgswCiphertextMutRef::new( + rgsw_space, + rgsw0_decomposer.a().decomposition_count(), + rgsw0_decomposer.b().decomposition_count(), + ); + let ( + (rlwe_dash_space_nsm_parta, rlwe_dash_space_nsm_partb), + (rlwe_dash_space_m_parta, rlwe_dash_space_m_partb), + ) = rgsw_space.split_mut(); - let (rgsw0_nsm, rgsw0_m) = rgsw0.split_at_row(rgsw0_da * 2); - let (rgsw1_nsm, rgsw1_m) = rgsw1_eval.split_at_row(d_a * 2); + let ((rgsw0_nsm_parta, rgsw0_nsm_partb), (rgsw0_m_parta, rgsw0_m_partb)) = rgsw0.split(); + let ((rgsw1_nsm_parta, rgsw1_nsm_partb), (rgsw1_m_parta, rgsw1_m_partb)) = rgsw1_eval.split(); // RGSW x RGSW izip!( - rgsw0_nsm - .iter() - .take(rgsw0_da) - .chain(rgsw0_m.iter().take(rgsw0_db)), - rgsw0_nsm - .iter() - .skip(rgsw0_da) - .chain(rgsw0_m.iter().skip(rgsw0_db)), + rgsw0_nsm_parta.iter().chain(rgsw0_m_parta), + rgsw0_nsm_partb.iter().chain(rgsw0_m_partb), rlwe_dash_space_nsm_parta .iter_mut() .chain(rlwe_dash_space_m_parta.iter_mut()), @@ -633,51 +883,55 @@ pub(crate) fn rgsw_by_rgsw_inplace< // RLWE(m0) x RGSW(m1) // Part A: Decomp \cdot RLWE'(-sm1) - decompose_r(rlwe_a.as_ref(), decomp_r_space.as_mut(), decomposer_a); - decomp_r_space - .iter_mut() - .take(d_a) - .for_each(|ri| ntt_op.forward(ri.as_mut())); - poly_fma_routine( - rlwe_out_a.as_mut(), - &decomp_r_space[..d_a], - &rgsw1_nsm[..d_a], - mod_op, - ); - poly_fma_routine( - rlwe_out_b.as_mut(), - &decomp_r_space[..d_a], - &rgsw1_nsm[d_a..], - mod_op, - ); + { + let decomp_r_parta = &mut decomp_r_space[..rgsw1_decomposer.a().decomposition_count()]; + decompose_r( + rlwe_a.as_ref(), + decomp_r_parta.as_mut(), + rgsw1_decomposer.a(), + ); + decomp_r_parta + .iter_mut() + .for_each(|ri| ntt_op.forward(ri.as_mut())); + poly_fma_routine( + rlwe_out_a.as_mut(), + &decomp_r_parta, + &rgsw1_nsm_parta, + mod_op, + ); + poly_fma_routine( + rlwe_out_b.as_mut(), + &decomp_r_parta, + &rgsw1_nsm_partb, + mod_op, + ); + } // Part B: Decompose \cdot RLWE'(m1) - decompose_r(rlwe_b.as_ref(), decomp_r_space.as_mut(), decomposer_b); - decomp_r_space - .iter_mut() - .take(d_b) - .for_each(|ri| ntt_op.forward(ri.as_mut())); - poly_fma_routine( - rlwe_out_a.as_mut(), - &decomp_r_space[..d_b], - &rgsw1_m[..d_b], - mod_op, - ); - poly_fma_routine( - rlwe_out_b.as_mut(), - &decomp_r_space[..d_b], - &rgsw1_m[d_b..], - mod_op, - ); + { + let decomp_r_partb = &mut decomp_r_space[..rgsw1_decomposer.b().decomposition_count()]; + decompose_r( + rlwe_b.as_ref(), + decomp_r_partb.as_mut(), + rgsw1_decomposer.b(), + ); + decomp_r_partb + .iter_mut() + .for_each(|ri| ntt_op.forward(ri.as_mut())); + poly_fma_routine(rlwe_out_a.as_mut(), &decomp_r_partb, &rgsw1_m_parta, mod_op); + poly_fma_routine(rlwe_out_b.as_mut(), &decomp_r_partb, &rgsw1_m_partb, mod_op); + } }); - // copy over RGSW(m0m1) into RGSW(m0) - izip!(rgsw0.iter_rows_mut(), rgsw_space.iter()) + // copy over RGSW(m0m1) to RGSW(m0) + // let d = rgsw0.as_mut(); + izip!(rgsw0.as_mut().iter_mut(), rgsw_space.data.iter()) .for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref())); // send back to coefficient domain rgsw0 - .iter_rows_mut() + .as_mut() + .iter_mut() .for_each(|ri| ntt_op.backward(ri.as_mut())); }