From fe0f8877064450601656e35596b06dc0e07086ec Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Mon, 10 Jun 2024 17:47:58 +0530 Subject: [PATCH] divide rgsw into keygen ops and runtime ops --- src/rgsw/keygen.rs | 640 +++++++++++++++++++ src/{rgsw.rs => rgsw/mod.rs} | 1142 +++------------------------------- src/rgsw/runtime.rs | 408 ++++++++++++ 3 files changed, 1145 insertions(+), 1045 deletions(-) create mode 100644 src/rgsw/keygen.rs rename src/{rgsw.rs => rgsw/mod.rs} (50%) create mode 100644 src/rgsw/runtime.rs diff --git a/src/rgsw/keygen.rs b/src/rgsw/keygen.rs new file mode 100644 index 0000000..e665cd0 --- /dev/null +++ b/src/rgsw/keygen.rs @@ -0,0 +1,640 @@ +use std::{ + clone, + fmt::Debug, + iter, + marker::PhantomData, + ops::{Div, Neg, Sub}, +}; + +use itertools::{izip, Itertools}; +use num_traits::{PrimInt, Signed, ToPrimitive, Zero}; + +use crate::{ + backend::{ArithmeticOps, GetModulus, Modulus, VectorOps}, + decomposer::{self, Decomposer, RlweDecomposer}, + ntt::{self, Ntt, NttInit}, + random::{ + DefaultSecureRng, NewWithSeed, RandomElementInModulus, RandomFill, + RandomFillGaussianInModulus, RandomFillUniformInModulus, + }, + utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom1, WithLocal}, + Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, +}; + +pub(crate) fn generate_auto_map(ring_size: usize, k: isize) -> (Vec, Vec) { + assert!(k & 1 == 1, "Auto {k} must be odd"); + + let k = if k < 0 { + // k is -ve, return k%(2*N) + (2 * ring_size) - (k.abs() as usize % (2 * ring_size)) + } else { + k as usize + }; + let (auto_map_index, auto_sign_index): (Vec, Vec) = (0..ring_size) + .into_iter() + .map(|i| { + let mut to_index = (i * k) % (2 * ring_size); + let mut sign = true; + + // wrap around. false implies negative + if to_index >= ring_size { + to_index = to_index - ring_size; + sign = false; + } + + (to_index, sign) + }) + .unzip(); + (auto_map_index, auto_sign_index) +} + +/// Encrypts message m as a RGSW ciphertext. +/// +/// - m_eval: is `m` is evaluation domain +/// - out_rgsw: RGSW(m) is stored as single matrix of dimension (d_rgsw * 3, +/// ring_size). The matrix has the following structure [RLWE'_A(-sm) || +/// RLWE'_B(-sm) || RLWE'_B(m)]^T and RLWE'_A(m) is generated via seed (where +/// p_rng is assumed to be seeded with seed) +pub(crate) fn secret_key_encrypt_rgsw< + Mmut: MatrixMut + MatrixEntity, + S, + R: RandomFillGaussianInModulus<[Mmut::MatElement], ModOp::M> + + RandomFillUniformInModulus<[Mmut::MatElement], ModOp::M>, + PR: RandomFillUniformInModulus<[Mmut::MatElement], ModOp::M>, + ModOp: VectorOps + GetModulus, + NttOp: Ntt, +>( + out_rgsw: &mut Mmut, + m: &[Mmut::MatElement], + gadget_a: &[Mmut::MatElement], + gadget_b: &[Mmut::MatElement], + s: &[S], + mod_op: &ModOp, + ntt_op: &NttOp, + p_rng: &mut PR, + rng: &mut R, +) where + ::R: RowMut + RowEntity + TryConvertFrom1<[S], ModOp::M> + Debug, + Mmut::MatElement: Copy + Debug, +{ + let d_a = gadget_a.len(); + let d_b = gadget_b.len(); + let q = mod_op.modulus(); + let ring_size = s.len(); + assert!(out_rgsw.dimension() == (d_a * 2 + d_b, ring_size)); + assert!(m.as_ref().len() == ring_size); + + // RLWE(-sm), RLWE(m) + let (rlwe_dash_nsm, b_rlwe_dash_m) = out_rgsw.split_at_row_mut(d_a * 2); + + let mut s_eval = Mmut::R::try_convert_from(s, &q); + ntt_op.forward(s_eval.as_mut()); + + let mut scratch_space = Mmut::R::zeros(ring_size); + + // RLWE'(-sm) + let (a_rlwe_dash_nsm, b_rlwe_dash_nsm) = rlwe_dash_nsm.split_at_mut(d_a); + izip!( + a_rlwe_dash_nsm.iter_mut(), + b_rlwe_dash_nsm.iter_mut(), + gadget_a.iter() + ) + .for_each(|(ai, bi, beta_i)| { + // Sample a_i + RandomFillUniformInModulus::random_fill(rng, &q, ai.as_mut()); + + // a_i * s + scratch_space.as_mut().copy_from_slice(ai.as_ref()); + ntt_op.forward(scratch_space.as_mut()); + mod_op.elwise_mul_mut(scratch_space.as_mut(), s_eval.as_ref()); + ntt_op.backward(scratch_space.as_mut()); + + // b_i = e_i + a_i * s + RandomFillGaussianInModulus::random_fill(rng, &q, bi.as_mut()); + mod_op.elwise_add_mut(bi.as_mut(), scratch_space.as_ref()); + + // a_i + \beta_i * m + mod_op.elwise_scalar_mul(scratch_space.as_mut(), m.as_ref(), beta_i); + mod_op.elwise_add_mut(ai.as_mut(), scratch_space.as_ref()); + }); + + // RLWE(m) + let mut a_rlwe_dash_m = { + // polynomials of part A of RLWE'(m) are sampled from seed + let mut a = Mmut::zeros(d_b, ring_size); + a.iter_rows_mut() + .for_each(|ai| RandomFillUniformInModulus::random_fill(p_rng, &q, ai.as_mut())); + a + }; + + izip!( + a_rlwe_dash_m.iter_rows_mut(), + b_rlwe_dash_m.iter_mut(), + gadget_b.iter() + ) + .for_each(|(ai, bi, beta_i)| { + // ai * s + ntt_op.forward(ai.as_mut()); + mod_op.elwise_mul_mut(ai.as_mut(), s_eval.as_ref()); + ntt_op.backward(ai.as_mut()); + + // beta_i * m + mod_op.elwise_scalar_mul(scratch_space.as_mut(), m.as_ref(), beta_i); + + // Sample e_i + RandomFillGaussianInModulus::random_fill(rng, &q, bi.as_mut()); + // e_i + beta_i * m + ai*s + mod_op.elwise_add_mut(bi.as_mut(), scratch_space.as_ref()); + mod_op.elwise_add_mut(bi.as_mut(), ai.as_ref()); + }); +} + +pub(crate) fn public_key_encrypt_rgsw< + Mmut: MatrixMut + MatrixEntity, + M: Matrix, + R: RandomFillGaussianInModulus<[Mmut::MatElement], ModOp::M> + + RandomFill<[u8]> + + RandomElementInModulus, + ModOp: VectorOps + GetModulus, + NttOp: Ntt, +>( + out_rgsw: &mut Mmut, + m: &[M::MatElement], + public_key: &M, + gadget_a: &[Mmut::MatElement], + gadget_b: &[Mmut::MatElement], + mod_op: &ModOp, + ntt_op: &NttOp, + rng: &mut R, +) where + ::R: RowMut + RowEntity + TryConvertFrom1<[i32], ModOp::M>, + Mmut::MatElement: Copy, +{ + let ring_size = public_key.dimension().1; + let d_a = gadget_a.len(); + let d_b = gadget_b.len(); + assert!(public_key.dimension().0 == 2); + assert!(out_rgsw.dimension() == (d_a * 2 + d_b * 2, ring_size)); + + let mut pk_eval = Mmut::zeros(2, ring_size); + izip!(pk_eval.iter_rows_mut(), public_key.iter_rows()).for_each(|(to_i, from_i)| { + to_i.as_mut().copy_from_slice(from_i.as_ref()); + ntt_op.forward(to_i.as_mut()); + }); + let p0 = pk_eval.get_row_slice(0); + let p1 = pk_eval.get_row_slice(1); + + let q = mod_op.modulus(); + + // RGSW(m) = RLWE'(-sm), RLWE(m) + let (rlwe_dash_nsm, rlwe_dash_m) = out_rgsw.split_at_row_mut(d_a * 2); + + // RLWE(-sm) + let (rlwe_dash_nsm_parta, rlwe_dash_nsm_partb) = rlwe_dash_nsm.split_at_mut(d_a); + izip!( + rlwe_dash_nsm_parta.iter_mut(), + rlwe_dash_nsm_partb.iter_mut(), + gadget_a.iter() + ) + .for_each(|(ai, bi, beta_i)| { + // sample ephemeral secret u_i + let mut u = vec![0i32; ring_size]; + fill_random_ternary_secret_with_hamming_weight(u.as_mut(), ring_size >> 1, rng); + let mut u_eval = Mmut::R::try_convert_from(u.as_ref(), &q); + ntt_op.forward(u_eval.as_mut()); + + let mut u_eval_copy = Mmut::R::zeros(ring_size); + u_eval_copy.as_mut().copy_from_slice(u_eval.as_ref()); + + // p0 * u + mod_op.elwise_mul_mut(u_eval.as_mut(), p0.as_ref()); + // p1 * u + mod_op.elwise_mul_mut(u_eval_copy.as_mut(), p1.as_ref()); + ntt_op.backward(u_eval.as_mut()); + ntt_op.backward(u_eval_copy.as_mut()); + + // sample error + RandomFillGaussianInModulus::random_fill(rng, &q, ai.as_mut()); + RandomFillGaussianInModulus::random_fill(rng, &q, bi.as_mut()); + + // a = p0*u+e0 + mod_op.elwise_add_mut(ai.as_mut(), u_eval.as_ref()); + // b = p1*u+e1 + mod_op.elwise_add_mut(bi.as_mut(), u_eval_copy.as_ref()); + + // a = p0*u + e0 + \beta*m + // use u_eval as scratch + mod_op.elwise_scalar_mul(u_eval.as_mut(), m.as_ref(), beta_i); + mod_op.elwise_add_mut(ai.as_mut(), u_eval.as_ref()); + }); + + // RLWE(m) + let (rlwe_dash_m_parta, rlwe_dash_m_partb) = rlwe_dash_m.split_at_mut(d_b); + izip!( + rlwe_dash_m_parta.iter_mut(), + rlwe_dash_m_partb.iter_mut(), + gadget_b.iter() + ) + .for_each(|(ai, bi, beta_i)| { + // sample ephemeral secret u_i + let mut u = vec![0i32; ring_size]; + fill_random_ternary_secret_with_hamming_weight(u.as_mut(), ring_size >> 1, rng); + let mut u_eval = Mmut::R::try_convert_from(u.as_ref(), &q); + ntt_op.forward(u_eval.as_mut()); + + let mut u_eval_copy = Mmut::R::zeros(ring_size); + u_eval_copy.as_mut().copy_from_slice(u_eval.as_ref()); + + // p0 * u + mod_op.elwise_mul_mut(u_eval.as_mut(), p0.as_ref()); + // p1 * u + mod_op.elwise_mul_mut(u_eval_copy.as_mut(), p1.as_ref()); + ntt_op.backward(u_eval.as_mut()); + ntt_op.backward(u_eval_copy.as_mut()); + + // sample error + RandomFillGaussianInModulus::random_fill(rng, &q, ai.as_mut()); + RandomFillGaussianInModulus::random_fill(rng, &q, bi.as_mut()); + + // a = p0*u+e0 + mod_op.elwise_add_mut(ai.as_mut(), u_eval.as_ref()); + // b = p1*u+e1 + mod_op.elwise_add_mut(bi.as_mut(), u_eval_copy.as_ref()); + + // b = p1*u + e0 + \beta*m + // use u_eval as scratch + mod_op.elwise_scalar_mul(u_eval.as_mut(), m.as_ref(), beta_i); + mod_op.elwise_add_mut(bi.as_mut(), u_eval.as_ref()); + }); +} + +/// Generates RLWE Key switching key to key switch ciphertext RLWE_{from_s}(m) +/// to RLWE_{to_s}(m). +/// +/// Key switching equals +/// \sum decompose(c_1)_i * RLWE_{to_s}(\beta^i -from_s) +/// Hence, key switchin key equals RLWE'(-from_s) = RLWE(-from_s), RLWE(beta^1 +/// -from_s), ..., RLWE(beta^{d-1} -from_s). +/// +/// - ksk_out: Output Key switching key. Key switching key stores only part B +/// polynomials of ksk RLWE ciphertexts (i.e. RLWE'_B(-from_s)) in coefficient +/// domain +/// - neg_from_s: Negative of secret polynomial to key switch from +/// - to_s: secret polynomial to key switch to. +pub(crate) fn rlwe_ksk_gen< + Mmut: MatrixMut + MatrixEntity, + ModOp: ArithmeticOps + + VectorOps + + GetModulus, + NttOp: Ntt, + R: RandomFillGaussianInModulus<[Mmut::MatElement], ModOp::M>, + PR: RandomFillUniformInModulus<[Mmut::MatElement], ModOp::M>, +>( + ksk_out: &mut Mmut, + neg_from_s: Mmut::R, + mut to_s: Mmut::R, + gadget_vector: &[Mmut::MatElement], + mod_op: &ModOp, + ntt_op: &NttOp, + p_rng: &mut PR, + rng: &mut R, +) where + ::R: RowMut, +{ + let ring_size = neg_from_s.as_ref().len(); + let d = gadget_vector.len(); + assert!(ksk_out.dimension() == (d, ring_size)); + + let q = mod_op.modulus(); + + ntt_op.forward(to_s.as_mut()); + + // RLWE'_{to_s}(-from_s) + let mut part_a = { + let mut a = Mmut::zeros(d, ring_size); + a.iter_rows_mut() + .for_each(|ai| RandomFillUniformInModulus::random_fill(p_rng, q, ai.as_mut())); + a + }; + izip!( + part_a.iter_rows_mut(), + ksk_out.iter_rows_mut(), + gadget_vector.iter(), + ) + .for_each(|(ai, bi, beta_i)| { + // si * ai + ntt_op.forward(ai.as_mut()); + mod_op.elwise_mul_mut(ai.as_mut(), to_s.as_ref()); + ntt_op.backward(ai.as_mut()); + + // ei + to_s*ai + RandomFillGaussianInModulus::random_fill(rng, &q, bi.as_mut()); + mod_op.elwise_add_mut(bi.as_mut(), ai.as_ref()); + + // beta_i * -from_s + // use ai as scratch space + mod_op.elwise_scalar_mul(ai.as_mut(), neg_from_s.as_ref(), beta_i); + + // bi = ei + to_s*ai + beta_i*-from_s + mod_op.elwise_add_mut(bi.as_mut(), ai.as_ref()); + }); +} + +pub(crate) fn galois_key_gen< + Mmut: MatrixMut + MatrixEntity, + ModOp: ArithmeticOps + + VectorOps + + GetModulus, + NttOp: Ntt, + S, + R: RandomFillGaussianInModulus<[Mmut::MatElement], ModOp::M>, + PR: RandomFillUniformInModulus<[Mmut::MatElement], ModOp::M>, +>( + ksk_out: &mut Mmut, + s: &[S], + auto_k: isize, + gadget_vector: &[Mmut::MatElement], + mod_op: &ModOp, + ntt_op: &NttOp, + p_rng: &mut PR, + rng: &mut R, +) where + ::R: RowMut, + Mmut::R: TryConvertFrom1<[S], ModOp::M> + RowEntity, + Mmut::MatElement: Copy + Sub, +{ + let ring_size = s.len(); + let (auto_map_index, auto_map_sign) = generate_auto_map(ring_size, auto_k); + + let q = mod_op.modulus(); + + // s(X) -> -s(X^k) + let s = Mmut::R::try_convert_from(s, q); + let mut neg_s_auto = Mmut::R::zeros(s.as_ref().len()); + izip!(s.as_ref(), auto_map_index.iter(), auto_map_sign.iter()).for_each( + |(el, to_index, sign)| { + // if sign is +ve (true), then negate because we need -s(X) (i.e. do the + // opposite than the usual case) + if *sign { + neg_s_auto.as_mut()[*to_index] = mod_op.neg(el); + } else { + neg_s_auto.as_mut()[*to_index] = *el; + } + }, + ); + + // Ksk from -s(X^k) to s(X) + rlwe_ksk_gen( + ksk_out, + neg_s_auto, + s, + gadget_vector, + mod_op, + ntt_op, + p_rng, + rng, + ); +} + +/// Encrypt polynomial m(X) as RLWE ciphertext. +/// +/// - rlwe_out: returned RLWE ciphertext RLWE(m) in coefficient domain. RLWE +/// ciphertext is a matirx with first row consiting polynomial `a` and the +/// second rows consting polynomial `b` +pub(crate) fn secret_key_encrypt_rlwe< + Ro: Row + RowMut + RowEntity, + ModOp: VectorOps + GetModulus, + NttOp: Ntt, + S, + R: RandomFillGaussianInModulus<[Ro::Element], ModOp::M>, + PR: RandomFillUniformInModulus<[Ro::Element], ModOp::M>, +>( + m: &[Ro::Element], + b_rlwe_out: &mut Ro, + s: &[S], + mod_op: &ModOp, + ntt_op: &NttOp, + p_rng: &mut PR, + rng: &mut R, +) where + Ro: TryConvertFrom1<[S], ModOp::M> + Debug, +{ + let ring_size = s.len(); + assert!(m.as_ref().len() == ring_size); + assert!(b_rlwe_out.as_ref().len() == ring_size); + + let q = mod_op.modulus(); + + // sample a + let mut a = { + let mut a = Ro::zeros(ring_size); + RandomFillUniformInModulus::random_fill(p_rng, q, a.as_mut()); + a + }; + + // s * a + let mut sa = Ro::try_convert_from(s, q); + ntt_op.forward(sa.as_mut()); + ntt_op.forward(a.as_mut()); + mod_op.elwise_mul_mut(sa.as_mut(), a.as_ref()); + ntt_op.backward(sa.as_mut()); + + // sample e + RandomFillGaussianInModulus::random_fill(rng, q, b_rlwe_out.as_mut()); + mod_op.elwise_add_mut(b_rlwe_out.as_mut(), m.as_ref()); + mod_op.elwise_add_mut(b_rlwe_out.as_mut(), sa.as_ref()); +} + +pub(crate) fn public_key_encrypt_rlwe< + M: Matrix, + Mmut: MatrixMut, + ModOp: VectorOps + GetModulus, + NttOp: Ntt, + S, + R: RandomFillGaussianInModulus<[M::MatElement], ModOp::M> + + RandomFillUniformInModulus<[M::MatElement], ModOp::M> + + RandomFill<[u8]> + + RandomElementInModulus, +>( + rlwe_out: &mut Mmut, + pk: &M, + m: &[M::MatElement], + mod_op: &ModOp, + ntt_op: &NttOp, + rng: &mut R, +) where + ::R: RowMut + TryConvertFrom1<[S], ModOp::M> + RowEntity, + M::MatElement: Copy, + S: Zero + Signed + Copy, +{ + let ring_size = m.len(); + assert!(rlwe_out.dimension() == (2, ring_size)); + + let q = mod_op.modulus(); + + let mut u = vec![S::zero(); ring_size]; + fill_random_ternary_secret_with_hamming_weight(u.as_mut(), ring_size >> 1, rng); + let mut u = Mmut::R::try_convert_from(&u, q); + ntt_op.forward(u.as_mut()); + + let mut ua = Mmut::R::zeros(ring_size); + ua.as_mut().copy_from_slice(pk.get_row_slice(0)); + let mut ub = Mmut::R::zeros(ring_size); + ub.as_mut().copy_from_slice(pk.get_row_slice(1)); + + // a*u + ntt_op.forward(ua.as_mut()); + mod_op.elwise_mul_mut(ua.as_mut(), u.as_ref()); + ntt_op.backward(ua.as_mut()); + + // b*u + ntt_op.forward(ub.as_mut()); + mod_op.elwise_mul_mut(ub.as_mut(), u.as_ref()); + ntt_op.backward(ub.as_mut()); + + // sample error + rlwe_out.iter_rows_mut().for_each(|ri| { + RandomFillGaussianInModulus::random_fill(rng, &q, ri.as_mut()); + }); + + // a*u + e0 + mod_op.elwise_add_mut(rlwe_out.get_row_mut(0), ua.as_ref()); + // b*u + e1 + mod_op.elwise_add_mut(rlwe_out.get_row_mut(1), ub.as_ref()); + + // b*u + e1 + m + mod_op.elwise_add_mut(rlwe_out.get_row_mut(1), m); +} + +/// Generates RLWE public key +pub(crate) fn gen_rlwe_public_key< + Ro: RowMut + RowEntity, + S, + ModOp: VectorOps + GetModulus, + NttOp: Ntt, + PRng: RandomFillUniformInModulus<[Ro::Element], ModOp::M>, + Rng: RandomFillGaussianInModulus<[Ro::Element], ModOp::M>, +>( + part_b_out: &mut Ro, + s: &[S], + ntt_op: &NttOp, + mod_op: &ModOp, + p_rng: &mut PRng, + rng: &mut Rng, +) where + Ro: TryConvertFrom1<[S], ModOp::M>, +{ + let ring_size = s.len(); + assert!(part_b_out.as_ref().len() == ring_size); + + let q = mod_op.modulus(); + + // sample a + let mut a = { + let mut tmp = Ro::zeros(ring_size); + RandomFillUniformInModulus::random_fill(p_rng, &q, tmp.as_mut()); + tmp + }; + ntt_op.forward(a.as_mut()); + + // s*a + let mut sa = Ro::try_convert_from(s, &q); + ntt_op.forward(sa.as_mut()); + mod_op.elwise_mul_mut(sa.as_mut(), a.as_ref()); + ntt_op.backward(sa.as_mut()); + + // s*a + e + RandomFillGaussianInModulus::random_fill(rng, &q, part_b_out.as_mut()); + mod_op.elwise_add_mut(part_b_out.as_mut(), sa.as_ref()); +} + +/// Decrypts degree 1 RLWE ciphertext RLWE(m) and returns m +/// +/// - rlwe_ct: input degree 1 ciphertext RLWE(m). +pub(crate) fn decrypt_rlwe< + R: RowMut, + M: Matrix, + ModOp: VectorOps + GetModulus, + NttOp: Ntt, + S, +>( + rlwe_ct: &M, + s: &[S], + m_out: &mut R, + ntt_op: &NttOp, + mod_op: &ModOp, +) where + R: TryConvertFrom1<[S], ModOp::M>, + R::Element: Copy, +{ + let ring_size = s.len(); + assert!(rlwe_ct.dimension() == (2, ring_size)); + assert!(m_out.as_ref().len() == ring_size); + + // transform a to evluation form + m_out.as_mut().copy_from_slice(rlwe_ct.get_row_slice(0)); + ntt_op.forward(m_out.as_mut()); + + // -s*a + let mut s = R::try_convert_from(&s, mod_op.modulus()); + ntt_op.forward(s.as_mut()); + mod_op.elwise_mul_mut(m_out.as_mut(), s.as_ref()); + mod_op.elwise_neg_mut(m_out.as_mut()); + ntt_op.backward(m_out.as_mut()); + + // m+e = b - s*a + mod_op.elwise_add_mut(m_out.as_mut(), rlwe_ct.get_row_slice(1)); +} + +// Measures noise in degree 1 RLWE ciphertext against encoded ideal message +// encoded_m +pub(crate) fn measure_noise< + Mmut: MatrixMut + Matrix, + ModOp: VectorOps + GetModulus, + NttOp: Ntt, + S, +>( + rlwe_ct: &Mmut, + encoded_m_ideal: &Mmut::R, + ntt_op: &NttOp, + mod_op: &ModOp, + s: &[S], +) -> f64 +where + ::R: RowMut, + Mmut::R: RowEntity + TryConvertFrom1<[S], ModOp::M>, + Mmut::MatElement: PrimInt + ToPrimitive + Debug, +{ + let ring_size = s.len(); + assert!(rlwe_ct.dimension() == (2, ring_size)); + assert!(encoded_m_ideal.as_ref().len() == ring_size); + + // -(s * a) + let q = mod_op.modulus(); + let mut s = Mmut::R::try_convert_from(s, &q); + ntt_op.forward(s.as_mut()); + let mut a = Mmut::R::zeros(ring_size); + a.as_mut().copy_from_slice(rlwe_ct.get_row_slice(0)); + ntt_op.forward(a.as_mut()); + mod_op.elwise_mul_mut(s.as_mut(), a.as_ref()); + mod_op.elwise_neg_mut(s.as_mut()); + ntt_op.backward(s.as_mut()); + + // m+e = b - s*a + let mut m_plus_e = s; + mod_op.elwise_add_mut(m_plus_e.as_mut(), rlwe_ct.get_row_slice(1)); + + // difference + mod_op.elwise_sub_mut(m_plus_e.as_mut(), encoded_m_ideal.as_ref()); + + let mut max_diff_bits = f64::MIN; + m_plus_e.as_ref().iter().for_each(|v| { + let bits = (q.map_element_to_i64(v).to_f64().unwrap()).log2(); + + if max_diff_bits < bits { + max_diff_bits = bits; + } + }); + + return max_diff_bits; +} diff --git a/src/rgsw.rs b/src/rgsw/mod.rs similarity index 50% rename from src/rgsw.rs rename to src/rgsw/mod.rs index 1218d9b..d4335c6 100644 --- a/src/rgsw.rs +++ b/src/rgsw/mod.rs @@ -1,3 +1,5 @@ +use itertools::{izip, Itertools}; +use num_traits::{PrimInt, Signed, ToPrimitive, Zero}; use std::{ clone, fmt::Debug, @@ -6,9 +8,6 @@ use std::{ ops::{Div, Neg, Sub}, }; -use itertools::{izip, Itertools}; -use num_traits::{PrimInt, Signed, ToPrimitive, Zero}; - use crate::{ backend::{ArithmeticOps, GetModulus, Modulus, VectorOps}, decomposer::{self, Decomposer, RlweDecomposer}, @@ -21,6 +20,12 @@ use crate::{ Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, }; +mod keygen; +mod runtime; + +pub(crate) use keygen::*; +pub(crate) use runtime::*; + pub struct SeededAutoKey where M: Matrix, @@ -40,9 +45,10 @@ impl> SeededA } } -pub struct AutoKeyEvaluationDomain { +pub struct AutoKeyEvaluationDomain { data: M, _phantom: PhantomData<(R, N)>, + modulus: Mod, } impl< @@ -50,10 +56,11 @@ impl< Mod: Modulus + Clone, R: RandomFillUniformInModulus<[M::MatElement], Mod> + NewWithSeed, N: NttInit + Ntt, - > From<&SeededAutoKey> for AutoKeyEvaluationDomain + > From<&SeededAutoKey> for AutoKeyEvaluationDomain where ::R: RowMut, M::MatElement: Copy, + R::Seed: Clone, { fn from(value: &SeededAutoKey) -> Self { @@ -79,10 +86,39 @@ where AutoKeyEvaluationDomain { data, _phantom: PhantomData, + modulus: value.modulus.clone(), } } } +pub(crate) trait ToShoup { + fn to_shoup(value: Self, modulus: Self) -> Self; +} + +pub struct ShoupAutoKeyEvaluationDomain { + data: M, +} + +impl, R, N> + From> for ShoupAutoKeyEvaluationDomain +where + M::R: RowMut, + M::MatElement: ToShoup + Copy, +{ + fn from(value: AutoKeyEvaluationDomain) -> Self { + let (row, col) = value.data.dimension(); + let mut shoup_data = M::zeros(row, col); + + izip!(shoup_data.iter_rows_mut(), value.data.iter_rows()).for_each(|(shoup_r, r)| { + izip!(shoup_r.as_mut().iter_mut(), r.as_ref().iter()).for_each(|(s, e)| { + *s = M::MatElement::to_shoup(*e, value.modulus.q().unwrap()); + }); + }); + + Self { data: shoup_data } + } +} + pub struct RgswCiphertext { /// Rgsw ciphertext polynomials pub(crate) data: M, @@ -157,17 +193,18 @@ where } } -pub struct RgswCiphertextEvaluationDomain { +pub struct RgswCiphertextEvaluationDomain { pub(crate) data: M, + modulus: Mod, _phantom: PhantomData<(R, N)>, } impl< M: MatrixMut + MatrixEntity, - Mod: Modulus, + Mod: Modulus + Clone, R: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], Mod>, N: NttInit + Ntt + Debug, - > From<&SeededRgswCiphertext> for RgswCiphertextEvaluationDomain + > From<&SeededRgswCiphertext> for RgswCiphertextEvaluationDomain where ::R: RowMut, M::MatElement: Copy, @@ -210,6 +247,7 @@ where Self { data: data, + modulus: value.modulus.clone(), _phantom: PhantomData, } } @@ -217,10 +255,10 @@ where impl< M: MatrixMut + MatrixEntity, - Mod: Modulus, + Mod: Modulus + Clone, R, N: NttInit + Ntt, - > From<&RgswCiphertext> for RgswCiphertextEvaluationDomain + > From<&RgswCiphertext> for RgswCiphertextEvaluationDomain where ::R: RowMut, M::MatElement: Copy, @@ -255,21 +293,23 @@ where Self { data: data, + modulus: value.modulus.clone(), _phantom: PhantomData, } } } -impl Debug for RgswCiphertextEvaluationDomain { +impl Debug for RgswCiphertextEvaluationDomain { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("RgswCiphertextEvaluationDomain") .field("data", &self.data) + .field("modulus", &self.modulus) .field("_phantom", &self._phantom) .finish() } } -impl Matrix for RgswCiphertextEvaluationDomain { +impl Matrix for RgswCiphertextEvaluationDomain { type MatElement = M::MatElement; type R = M::R; @@ -282,12 +322,36 @@ impl Matrix for RgswCiphertextEvaluationDomain { } } -impl AsRef<[M::R]> for RgswCiphertextEvaluationDomain { +impl AsRef<[M::R]> for RgswCiphertextEvaluationDomain { fn as_ref(&self) -> &[M::R] { self.data.as_ref() } } +pub struct ShoupRgswCiphertextEvaluationDomain { + pub(crate) data: M, +} + +impl, R, N> + From> for ShoupRgswCiphertextEvaluationDomain +where + M::R: RowMut, + M::MatElement: ToShoup + Copy, +{ + fn from(value: RgswCiphertextEvaluationDomain) -> Self { + let (row, col) = value.data.dimension(); + let mut shoup_data = M::zeros(row, col); + + izip!(shoup_data.iter_rows_mut(), value.data.iter_rows()).for_each(|(shoup_r, r)| { + izip!(shoup_r.as_mut().iter_mut(), r.as_ref().iter()).for_each(|(s, e)| { + *s = M::MatElement::to_shoup(*e, value.modulus.q().unwrap()); + }); + }); + + Self { data: shoup_data } + } +} + pub struct SeededRlweCiphertext { pub(crate) data: R, pub(crate) seed: S, @@ -462,1021 +526,6 @@ impl RlweSecret { } } -pub(crate) fn generate_auto_map(ring_size: usize, k: isize) -> (Vec, Vec) { - assert!(k & 1 == 1, "Auto {k} must be odd"); - - let k = if k < 0 { - // k is -ve, return k%(2*N) - (2 * ring_size) - (k.abs() as usize % (2 * ring_size)) - } else { - k as usize - }; - let (auto_map_index, auto_sign_index): (Vec, Vec) = (0..ring_size) - .into_iter() - .map(|i| { - let mut to_index = (i * k) % (2 * ring_size); - let mut sign = true; - - // wrap around. false implies negative - if to_index >= ring_size { - to_index = to_index - ring_size; - sign = false; - } - - (to_index, sign) - }) - .unzip(); - (auto_map_index, auto_sign_index) -} - -pub(crate) fn routine>( - write_to_row: &mut [R::Element], - matrix_a: &[R], - matrix_b: &[R], - mod_op: &ModOp, -) { - izip!(matrix_a.iter(), matrix_b.iter()).for_each(|(a, b)| { - mod_op.elwise_fma_mut(write_to_row, a.as_ref(), b.as_ref()); - }); -} - -/// Decomposes ring polynomial r(X) into d polynomials using decomposer into -/// output matrix decomp_r -/// -/// Note that decomposition of r(X) requires decomposition of each of -/// coefficients. -/// -/// - decomp_r: must have dimensions d x ring_size. i^th decomposed polynomial -/// will be stored at i^th row. -pub(crate) fn decompose_r>( - r: &[R::Element], - decomp_r: &mut [R], - decomposer: &D, -) where - R::Element: Copy, -{ - let ring_size = r.len(); - - for ri in 0..ring_size { - decomposer - .decompose_iter(&r[ri]) - .enumerate() - .for_each(|(index, el)| { - decomp_r[index].as_mut()[ri] = el; - }); - } -} - -/// Sends RLWE_{s}(X) -> RLWE_{s}(X^k) where k is some galois element -/// -/// - scratch_matrix: must have dimension at-least d+2 x ring_size. d rows to -/// store decomposed polynomials and 2 for rlwe -pub(crate) fn galois_auto< - MT: Matrix + IsTrivial + MatrixMut, - Mmut: MatrixMut, - ModOp: ArithmeticOps + VectorOps, - NttOp: Ntt, - D: Decomposer, ->( - rlwe_in: &mut MT, - ksk: &Mmut, - scratch_matrix: &mut Mmut, - auto_map_index: &[usize], - auto_map_sign: &[bool], - mod_op: &ModOp, - ntt_op: &NttOp, - decomposer: &D, -) where - ::R: RowMut, - ::R: RowMut, - MT::MatElement: Copy + Zero, -{ - let d = decomposer.decomposition_count(); - let ring_size = rlwe_in.dimension().1; - assert!(rlwe_in.dimension().0 == 2); - assert!(scratch_matrix.fits(d + 2, ring_size)); - - // scratch matrix is guaranteed to have at-least d+2 rows but can have more than - // d+2 rows. We require to split them into sub-matrices of exact sizes one with - // d rows for storing decomposed polynomial and second with 2 rows to act - // tomperary space for RLWE ciphertext. Exact sizes is necessary to avoid any - // irrelevant extra FMA or NTT ops. - let (scratch_matrix_d_ring, other_half) = scratch_matrix.split_at_row_mut(d); - let (tmp_rlwe_out, _) = other_half.split_at_mut(2); - - debug_assert!(tmp_rlwe_out.len() == 2); - debug_assert!(scratch_matrix_d_ring.len() == d); - - if !rlwe_in.is_trivial() { - tmp_rlwe_out.iter_mut().for_each(|r| { - r.as_mut().fill(Mmut::MatElement::zero()); - }); - - // send a(X) -> a(X^k) and decompose a(X^k) - izip!( - rlwe_in.get_row(0), - auto_map_index.iter(), - auto_map_sign.iter() - ) - .for_each(|(el_in, to_index, sign)| { - let el_out = if !*sign { mod_op.neg(el_in) } else { *el_in }; - - decomposer - .decompose_iter(&el_out) - .enumerate() - .for_each(|(index, el)| { - scratch_matrix_d_ring[index].as_mut()[*to_index] = el; - }); - }); - - // transform decomposed a(X^k) to evaluation domain - scratch_matrix_d_ring.iter_mut().for_each(|r| { - ntt_op.forward(r.as_mut()); - }); - - // RLWE(m^k) = a', b'; RLWE(m) = a, b - // key switch: (a * RLWE'(s(X^k))) - let (ksk_a, ksk_b) = ksk.split_at_row(d); - // a' = decomp * RLWE'_A(s(X^k)) - routine( - tmp_rlwe_out[0].as_mut(), - scratch_matrix_d_ring, - ksk_a, - mod_op, - ); - - // b' += decomp * RLWE'_B(s(X^k)) - routine( - tmp_rlwe_out[1].as_mut(), - scratch_matrix_d_ring, - ksk_b, - mod_op, - ); - - // transform RLWE(m^k) to coefficient domain - tmp_rlwe_out - .iter_mut() - .for_each(|r| ntt_op.backward(r.as_mut())); - - // send b(X) -> b(X^k) and then b'(X) += b(X^k) - izip!( - rlwe_in.get_row(1), - auto_map_index.iter(), - auto_map_sign.iter() - ) - .for_each(|(el_in, to_index, sign)| { - let row = tmp_rlwe_out[1].as_mut(); - if !*sign { - row[*to_index] = mod_op.sub(&row[*to_index], el_in); - } else { - row[*to_index] = mod_op.add(&row[*to_index], el_in); - } - }); - - // copy over A; Leave B for later - rlwe_in - .get_row_mut(0) - .copy_from_slice(tmp_rlwe_out[0].as_ref()); - } else { - // RLWE is trivial, a(X) is 0. - // send b(X) -> b(X^k) - izip!( - rlwe_in.get_row(1), - auto_map_index.iter(), - auto_map_sign.iter() - ) - .for_each(|(el_in, to_index, sign)| { - if !*sign { - tmp_rlwe_out[1].as_mut()[*to_index] = mod_op.neg(el_in); - } else { - tmp_rlwe_out[1].as_mut()[*to_index] = *el_in; - } - }); - } - - // Copy over B - rlwe_in - .get_row_mut(1) - .copy_from_slice(tmp_rlwe_out[1].as_ref()); -} - -/// Returns RLWE(m0m1) = RLWE(m0) x RGSW(m1). Mutates rlwe_in inplace to equal -/// RLWE(m0m1) -/// -/// - rlwe_in: is RLWE(m0) with polynomials in coefficient domain -/// - rgsw_in: is RGSW(m1) with polynomials in evaluation domain -/// - scratch_matrix_d_ring: is a matrix with atleast max(d_a, d_b) rows and -/// ring_size columns. It's used to store decomposed polynomials and out RLWE -/// temoporarily -pub(crate) fn rlwe_by_rgsw< - Mmut: MatrixMut, - MT: Matrix + MatrixMut + IsTrivial, - D: RlweDecomposer, - ModOp: VectorOps, - NttOp: Ntt, ->( - rlwe_in: &mut MT, - rgsw_in: &Mmut, - scratch_matrix: &mut Mmut, - decomposer: &D, - ntt_op: &NttOp, - mod_op: &ModOp, -) where - Mmut::MatElement: Copy + Zero, - ::R: RowMut, - ::R: RowMut, -{ - let decomposer_a = decomposer.a(); - let decomposer_b = decomposer.b(); - let d_a = decomposer_a.decomposition_count(); - let d_b = decomposer_b.decomposition_count(); - let max_d = std::cmp::max(d_a, d_b); - assert!(scratch_matrix.fits(max_d + 2, rlwe_in.dimension().1)); - assert!(rgsw_in.dimension() == (d_a * 2 + d_b * 2, rlwe_in.dimension().1)); - - // decomposed RLWE x RGSW - let (rlwe_dash_nsm, rlwe_dash_m) = rgsw_in.split_at_row(d_a * 2); - let (scratch_matrix_d_ring, scratch_rlwe_out) = scratch_matrix.split_at_row_mut(max_d); - scratch_rlwe_out[0].as_mut().fill(Mmut::MatElement::zero()); - scratch_rlwe_out[1].as_mut().fill(Mmut::MatElement::zero()); - // RLWE_in = a_in, b_in; RLWE_out = a_out, b_out - if !rlwe_in.is_trivial() { - // a_in = 0 when RLWE_in is trivial RLWE ciphertext - // decomp - decompose_r( - rlwe_in.get_row_slice(0), - &mut scratch_matrix_d_ring[..d_a], - decomposer_a, - ); - scratch_matrix_d_ring - .iter_mut() - .take(d_a) - .for_each(|r| ntt_op.forward(r.as_mut())); - // a_out += decomp \cdot RLWE_A'(-sm) - routine( - scratch_rlwe_out[0].as_mut(), - scratch_matrix_d_ring.as_ref(), - &rlwe_dash_nsm[..d_a], - mod_op, - ); - // b_out += decomp \cdot RLWE_B'(-sm) - routine( - scratch_rlwe_out[1].as_mut(), - scratch_matrix_d_ring.as_ref(), - &rlwe_dash_nsm[d_a..], - mod_op, - ); - } - // decomp - decompose_r( - rlwe_in.get_row_slice(1), - &mut scratch_matrix_d_ring[..d_b], - decomposer_b, - ); - scratch_matrix_d_ring - .iter_mut() - .take(d_b) - .for_each(|r| ntt_op.forward(r.as_mut())); - // a_out += decomp \cdot RLWE_A'(m) - routine( - scratch_rlwe_out[0].as_mut(), - scratch_matrix_d_ring.as_ref(), - &rlwe_dash_m[..d_b], - mod_op, - ); - // b_out += decomp \cdot RLWE_B'(m) - routine( - scratch_rlwe_out[1].as_mut(), - scratch_matrix_d_ring.as_ref(), - &rlwe_dash_m[d_b..], - mod_op, - ); - - // transform rlwe_out to coefficient domain - scratch_rlwe_out - .iter_mut() - .for_each(|r| ntt_op.backward(r.as_mut())); - - rlwe_in - .get_row_mut(0) - .copy_from_slice(scratch_rlwe_out[0].as_mut()); - rlwe_in - .get_row_mut(1) - .copy_from_slice(scratch_rlwe_out[1].as_mut()); - rlwe_in.set_not_trivial(); -} - -/// Inplace mutates rlwe_0 to equal RGSW(m0m1) = RGSW(m0)xRGSW(m1) -/// in evaluation domain -/// -/// Warning - -/// Pass a fresh RGSW ciphertext as the second operand, i.e. as `rgsw_1`. -/// This is to assure minimal error growth in the resulting RGSW ciphertext. -/// RGSWxRGSW boils down to d_rgsw*2 RLWExRGSW multiplications. Hence, the noise -/// growth in resulting ciphertext depends on the norm of second RGSW -/// ciphertext, not the first. This is useful in cases where one is accumulating -/// multiple RGSW ciphertexts into 1. In which case, pass the accumulating RGSW -/// ciphertext as rlwe_0 (the one with higher noise) and subsequent RGSW -/// ciphertexts, with lower noise, to be accumulated as second -/// operand. -/// -/// - rgsw_0: RGSW(m0) -/// - rgsw_1_eval: RGSW(m1) in Evaluation domain -/// - scratch_matrix_d_plus_rgsw_by_ring: scratch space matrix with rows -/// (max(d_a, d_b) + d_a*2+d_b*2) and columns ring_size -pub(crate) fn rgsw_by_rgsw_inplace< - Mmut: MatrixMut, - D: RlweDecomposer, - ModOp: VectorOps, - NttOp: Ntt, ->( - rgsw_0: &mut Mmut, - rgsw_1_eval: &Mmut, - decomposer: &D, - scratch_matrix: &mut Mmut, - ntt_op: &NttOp, - mod_op: &ModOp, -) where - ::R: RowMut, - Mmut::MatElement: Copy + Zero, -{ - let decomposer_a = decomposer.a(); - let decomposer_b = decomposer.b(); - let d_a = decomposer_a.decomposition_count(); - let d_b = decomposer_b.decomposition_count(); - let max_d = std::cmp::max(d_a, d_b); - let rgsw_rows = d_a * 2 + d_b * 2; - assert!(rgsw_0.dimension().0 == rgsw_rows); - let ring_size = rgsw_0.dimension().1; - assert!(rgsw_1_eval.dimension() == (rgsw_rows, ring_size)); - assert!(scratch_matrix.fits(max_d + rgsw_rows, ring_size)); - - let (decomp_r_space, rgsw_space) = scratch_matrix.split_at_row_mut(max_d); - - // zero rgsw_space - rgsw_space - .iter_mut() - .for_each(|ri| ri.as_mut().fill(Mmut::MatElement::zero())); - let (rlwe_dash_space_nsm, rlwe_dash_space_m) = rgsw_space.split_at_mut(d_a * 2); - let (rlwe_dash_space_nsm_parta, rlwe_dash_space_nsm_partb) = - rlwe_dash_space_nsm.split_at_mut(d_a); - let (rlwe_dash_space_m_parta, rlwe_dash_space_m_partb) = rlwe_dash_space_m.split_at_mut(d_b); - - let (rgsw0_nsm, rgsw0_m) = rgsw_0.split_at_row(d_a * 2); - let (rgsw1_nsm, rgsw1_m) = rgsw_1_eval.split_at_row(d_a * 2); - - // RGSW x RGSW - izip!( - rgsw0_nsm.iter().take(d_a).chain(rgsw0_m.iter().take(d_b)), - rgsw0_nsm.iter().skip(d_a).chain(rgsw0_m.iter().skip(d_b)), - rlwe_dash_space_nsm_parta - .iter_mut() - .chain(rlwe_dash_space_m_parta.iter_mut()), - rlwe_dash_space_nsm_partb - .iter_mut() - .chain(rlwe_dash_space_m_partb.iter_mut()), - ) - .for_each(|(rlwe_a, rlwe_b, rlwe_out_a, rlwe_out_b)| { - // Part A - decompose_r(rlwe_a.as_ref(), decomp_r_space.as_mut(), decomposer_a); - decomp_r_space - .iter_mut() - .take(d_a) - .for_each(|ri| ntt_op.forward(ri.as_mut())); - routine( - rlwe_out_a.as_mut(), - &decomp_r_space[..d_a], - &rgsw1_nsm[..d_a], - mod_op, - ); - routine( - rlwe_out_b.as_mut(), - &decomp_r_space[..d_a], - &rgsw1_nsm[d_a..], - mod_op, - ); - - // Part B - decompose_r(rlwe_b.as_ref(), decomp_r_space.as_mut(), decomposer_b); - decomp_r_space - .iter_mut() - .take(d_b) - .for_each(|ri| ntt_op.forward(ri.as_mut())); - routine( - rlwe_out_a.as_mut(), - &decomp_r_space[..d_b], - &rgsw1_m[..d_b], - mod_op, - ); - routine( - rlwe_out_b.as_mut(), - &decomp_r_space[..d_b], - &rgsw1_m[d_b..], - mod_op, - ); - }); - - // copy over RGSW(m0m1) into RGSW(m0) - izip!(rgsw_0.iter_rows_mut(), rgsw_space.iter()) - .for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref())); - - // send back to coefficient domain - rgsw_0 - .iter_rows_mut() - .for_each(|ri| ntt_op.backward(ri.as_mut())); -} - -/// Encrypts message m as a RGSW ciphertext. -/// -/// - m_eval: is `m` is evaluation domain -/// - out_rgsw: RGSW(m) is stored as single matrix of dimension (d_rgsw * 3, -/// ring_size). The matrix has the following structure [RLWE'_A(-sm) || -/// RLWE'_B(-sm) || RLWE'_B(m)]^T and RLWE'_A(m) is generated via seed (where -/// p_rng is assumed to be seeded with seed) -pub(crate) fn secret_key_encrypt_rgsw< - Mmut: MatrixMut + MatrixEntity, - S, - R: RandomFillGaussianInModulus<[Mmut::MatElement], ModOp::M> - + RandomFillUniformInModulus<[Mmut::MatElement], ModOp::M>, - PR: RandomFillUniformInModulus<[Mmut::MatElement], ModOp::M>, - ModOp: VectorOps + GetModulus, - NttOp: Ntt, ->( - out_rgsw: &mut Mmut, - m: &[Mmut::MatElement], - gadget_a: &[Mmut::MatElement], - gadget_b: &[Mmut::MatElement], - s: &[S], - mod_op: &ModOp, - ntt_op: &NttOp, - p_rng: &mut PR, - rng: &mut R, -) where - ::R: RowMut + RowEntity + TryConvertFrom1<[S], ModOp::M> + Debug, - Mmut::MatElement: Copy + Debug, -{ - let d_a = gadget_a.len(); - let d_b = gadget_b.len(); - let q = mod_op.modulus(); - let ring_size = s.len(); - assert!(out_rgsw.dimension() == (d_a * 2 + d_b, ring_size)); - assert!(m.as_ref().len() == ring_size); - - // RLWE(-sm), RLWE(m) - let (rlwe_dash_nsm, b_rlwe_dash_m) = out_rgsw.split_at_row_mut(d_a * 2); - - let mut s_eval = Mmut::R::try_convert_from(s, &q); - ntt_op.forward(s_eval.as_mut()); - - let mut scratch_space = Mmut::R::zeros(ring_size); - - // RLWE'(-sm) - let (a_rlwe_dash_nsm, b_rlwe_dash_nsm) = rlwe_dash_nsm.split_at_mut(d_a); - izip!( - a_rlwe_dash_nsm.iter_mut(), - b_rlwe_dash_nsm.iter_mut(), - gadget_a.iter() - ) - .for_each(|(ai, bi, beta_i)| { - // Sample a_i - RandomFillUniformInModulus::random_fill(rng, &q, ai.as_mut()); - - // a_i * s - scratch_space.as_mut().copy_from_slice(ai.as_ref()); - ntt_op.forward(scratch_space.as_mut()); - mod_op.elwise_mul_mut(scratch_space.as_mut(), s_eval.as_ref()); - ntt_op.backward(scratch_space.as_mut()); - - // b_i = e_i + a_i * s - RandomFillGaussianInModulus::random_fill(rng, &q, bi.as_mut()); - mod_op.elwise_add_mut(bi.as_mut(), scratch_space.as_ref()); - - // a_i + \beta_i * m - mod_op.elwise_scalar_mul(scratch_space.as_mut(), m.as_ref(), beta_i); - mod_op.elwise_add_mut(ai.as_mut(), scratch_space.as_ref()); - }); - - // RLWE(m) - let mut a_rlwe_dash_m = { - // polynomials of part A of RLWE'(m) are sampled from seed - let mut a = Mmut::zeros(d_b, ring_size); - a.iter_rows_mut() - .for_each(|ai| RandomFillUniformInModulus::random_fill(p_rng, &q, ai.as_mut())); - a - }; - - izip!( - a_rlwe_dash_m.iter_rows_mut(), - b_rlwe_dash_m.iter_mut(), - gadget_b.iter() - ) - .for_each(|(ai, bi, beta_i)| { - // ai * s - ntt_op.forward(ai.as_mut()); - mod_op.elwise_mul_mut(ai.as_mut(), s_eval.as_ref()); - ntt_op.backward(ai.as_mut()); - - // beta_i * m - mod_op.elwise_scalar_mul(scratch_space.as_mut(), m.as_ref(), beta_i); - - // Sample e_i - RandomFillGaussianInModulus::random_fill(rng, &q, bi.as_mut()); - // e_i + beta_i * m + ai*s - mod_op.elwise_add_mut(bi.as_mut(), scratch_space.as_ref()); - mod_op.elwise_add_mut(bi.as_mut(), ai.as_ref()); - }); -} - -pub(crate) fn public_key_encrypt_rgsw< - Mmut: MatrixMut + MatrixEntity, - M: Matrix, - R: RandomFillGaussianInModulus<[Mmut::MatElement], ModOp::M> - + RandomFill<[u8]> - + RandomElementInModulus, - ModOp: VectorOps + GetModulus, - NttOp: Ntt, ->( - out_rgsw: &mut Mmut, - m: &[M::MatElement], - public_key: &M, - gadget_a: &[Mmut::MatElement], - gadget_b: &[Mmut::MatElement], - mod_op: &ModOp, - ntt_op: &NttOp, - rng: &mut R, -) where - ::R: RowMut + RowEntity + TryConvertFrom1<[i32], ModOp::M>, - Mmut::MatElement: Copy, -{ - let ring_size = public_key.dimension().1; - let d_a = gadget_a.len(); - let d_b = gadget_b.len(); - assert!(public_key.dimension().0 == 2); - assert!(out_rgsw.dimension() == (d_a * 2 + d_b * 2, ring_size)); - - let mut pk_eval = Mmut::zeros(2, ring_size); - izip!(pk_eval.iter_rows_mut(), public_key.iter_rows()).for_each(|(to_i, from_i)| { - to_i.as_mut().copy_from_slice(from_i.as_ref()); - ntt_op.forward(to_i.as_mut()); - }); - let p0 = pk_eval.get_row_slice(0); - let p1 = pk_eval.get_row_slice(1); - - let q = mod_op.modulus(); - - // RGSW(m) = RLWE'(-sm), RLWE(m) - let (rlwe_dash_nsm, rlwe_dash_m) = out_rgsw.split_at_row_mut(d_a * 2); - - // RLWE(-sm) - let (rlwe_dash_nsm_parta, rlwe_dash_nsm_partb) = rlwe_dash_nsm.split_at_mut(d_a); - izip!( - rlwe_dash_nsm_parta.iter_mut(), - rlwe_dash_nsm_partb.iter_mut(), - gadget_a.iter() - ) - .for_each(|(ai, bi, beta_i)| { - // sample ephemeral secret u_i - let mut u = vec![0i32; ring_size]; - fill_random_ternary_secret_with_hamming_weight(u.as_mut(), ring_size >> 1, rng); - let mut u_eval = Mmut::R::try_convert_from(u.as_ref(), &q); - ntt_op.forward(u_eval.as_mut()); - - let mut u_eval_copy = Mmut::R::zeros(ring_size); - u_eval_copy.as_mut().copy_from_slice(u_eval.as_ref()); - - // p0 * u - mod_op.elwise_mul_mut(u_eval.as_mut(), p0.as_ref()); - // p1 * u - mod_op.elwise_mul_mut(u_eval_copy.as_mut(), p1.as_ref()); - ntt_op.backward(u_eval.as_mut()); - ntt_op.backward(u_eval_copy.as_mut()); - - // sample error - RandomFillGaussianInModulus::random_fill(rng, &q, ai.as_mut()); - RandomFillGaussianInModulus::random_fill(rng, &q, bi.as_mut()); - - // a = p0*u+e0 - mod_op.elwise_add_mut(ai.as_mut(), u_eval.as_ref()); - // b = p1*u+e1 - mod_op.elwise_add_mut(bi.as_mut(), u_eval_copy.as_ref()); - - // a = p0*u + e0 + \beta*m - // use u_eval as scratch - mod_op.elwise_scalar_mul(u_eval.as_mut(), m.as_ref(), beta_i); - mod_op.elwise_add_mut(ai.as_mut(), u_eval.as_ref()); - }); - - // RLWE(m) - let (rlwe_dash_m_parta, rlwe_dash_m_partb) = rlwe_dash_m.split_at_mut(d_b); - izip!( - rlwe_dash_m_parta.iter_mut(), - rlwe_dash_m_partb.iter_mut(), - gadget_b.iter() - ) - .for_each(|(ai, bi, beta_i)| { - // sample ephemeral secret u_i - let mut u = vec![0i32; ring_size]; - fill_random_ternary_secret_with_hamming_weight(u.as_mut(), ring_size >> 1, rng); - let mut u_eval = Mmut::R::try_convert_from(u.as_ref(), &q); - ntt_op.forward(u_eval.as_mut()); - - let mut u_eval_copy = Mmut::R::zeros(ring_size); - u_eval_copy.as_mut().copy_from_slice(u_eval.as_ref()); - - // p0 * u - mod_op.elwise_mul_mut(u_eval.as_mut(), p0.as_ref()); - // p1 * u - mod_op.elwise_mul_mut(u_eval_copy.as_mut(), p1.as_ref()); - ntt_op.backward(u_eval.as_mut()); - ntt_op.backward(u_eval_copy.as_mut()); - - // sample error - RandomFillGaussianInModulus::random_fill(rng, &q, ai.as_mut()); - RandomFillGaussianInModulus::random_fill(rng, &q, bi.as_mut()); - - // a = p0*u+e0 - mod_op.elwise_add_mut(ai.as_mut(), u_eval.as_ref()); - // b = p1*u+e1 - mod_op.elwise_add_mut(bi.as_mut(), u_eval_copy.as_ref()); - - // b = p1*u + e0 + \beta*m - // use u_eval as scratch - mod_op.elwise_scalar_mul(u_eval.as_mut(), m.as_ref(), beta_i); - mod_op.elwise_add_mut(bi.as_mut(), u_eval.as_ref()); - }); -} - -/// Generates RLWE Key switching key to key switch ciphertext RLWE_{from_s}(m) -/// to RLWE_{to_s}(m). -/// -/// Key switching equals -/// \sum decompose(c_1)_i * RLWE_{to_s}(\beta^i -from_s) -/// Hence, key switchin key equals RLWE'(-from_s) = RLWE(-from_s), RLWE(beta^1 -/// -from_s), ..., RLWE(beta^{d-1} -from_s). -/// -/// - ksk_out: Output Key switching key. Key switching key stores only part B -/// polynomials of ksk RLWE ciphertexts (i.e. RLWE'_B(-from_s)) in coefficient -/// domain -/// - neg_from_s: Negative of secret polynomial to key switch from -/// - to_s: secret polynomial to key switch to. -pub(crate) fn rlwe_ksk_gen< - Mmut: MatrixMut + MatrixEntity, - ModOp: ArithmeticOps - + VectorOps - + GetModulus, - NttOp: Ntt, - R: RandomFillGaussianInModulus<[Mmut::MatElement], ModOp::M>, - PR: RandomFillUniformInModulus<[Mmut::MatElement], ModOp::M>, ->( - ksk_out: &mut Mmut, - neg_from_s: Mmut::R, - mut to_s: Mmut::R, - gadget_vector: &[Mmut::MatElement], - mod_op: &ModOp, - ntt_op: &NttOp, - p_rng: &mut PR, - rng: &mut R, -) where - ::R: RowMut, -{ - let ring_size = neg_from_s.as_ref().len(); - let d = gadget_vector.len(); - assert!(ksk_out.dimension() == (d, ring_size)); - - let q = mod_op.modulus(); - - ntt_op.forward(to_s.as_mut()); - - // RLWE'_{to_s}(-from_s) - let mut part_a = { - let mut a = Mmut::zeros(d, ring_size); - a.iter_rows_mut() - .for_each(|ai| RandomFillUniformInModulus::random_fill(p_rng, q, ai.as_mut())); - a - }; - izip!( - part_a.iter_rows_mut(), - ksk_out.iter_rows_mut(), - gadget_vector.iter(), - ) - .for_each(|(ai, bi, beta_i)| { - // si * ai - ntt_op.forward(ai.as_mut()); - mod_op.elwise_mul_mut(ai.as_mut(), to_s.as_ref()); - ntt_op.backward(ai.as_mut()); - - // ei + to_s*ai - RandomFillGaussianInModulus::random_fill(rng, &q, bi.as_mut()); - mod_op.elwise_add_mut(bi.as_mut(), ai.as_ref()); - - // beta_i * -from_s - // use ai as scratch space - mod_op.elwise_scalar_mul(ai.as_mut(), neg_from_s.as_ref(), beta_i); - - // bi = ei + to_s*ai + beta_i*-from_s - mod_op.elwise_add_mut(bi.as_mut(), ai.as_ref()); - }); -} - -pub(crate) fn galois_key_gen< - Mmut: MatrixMut + MatrixEntity, - ModOp: ArithmeticOps - + VectorOps - + GetModulus, - NttOp: Ntt, - S, - R: RandomFillGaussianInModulus<[Mmut::MatElement], ModOp::M>, - PR: RandomFillUniformInModulus<[Mmut::MatElement], ModOp::M>, ->( - ksk_out: &mut Mmut, - s: &[S], - auto_k: isize, - gadget_vector: &[Mmut::MatElement], - mod_op: &ModOp, - ntt_op: &NttOp, - p_rng: &mut PR, - rng: &mut R, -) where - ::R: RowMut, - Mmut::R: TryConvertFrom1<[S], ModOp::M> + RowEntity, - Mmut::MatElement: Copy + Sub, -{ - let ring_size = s.len(); - let (auto_map_index, auto_map_sign) = generate_auto_map(ring_size, auto_k); - - let q = mod_op.modulus(); - - // s(X) -> -s(X^k) - let s = Mmut::R::try_convert_from(s, q); - let mut neg_s_auto = Mmut::R::zeros(s.as_ref().len()); - izip!(s.as_ref(), auto_map_index.iter(), auto_map_sign.iter()).for_each( - |(el, to_index, sign)| { - // if sign is +ve (true), then negate because we need -s(X) (i.e. do the - // opposite than the usual case) - if *sign { - neg_s_auto.as_mut()[*to_index] = mod_op.neg(el); - } else { - neg_s_auto.as_mut()[*to_index] = *el; - } - }, - ); - - // Ksk from -s(X^k) to s(X) - rlwe_ksk_gen( - ksk_out, - neg_s_auto, - s, - gadget_vector, - mod_op, - ntt_op, - p_rng, - rng, - ); -} - -/// Encrypt polynomial m(X) as RLWE ciphertext. -/// -/// - rlwe_out: returned RLWE ciphertext RLWE(m) in coefficient domain. RLWE -/// ciphertext is a matirx with first row consiting polynomial `a` and the -/// second rows consting polynomial `b` -pub(crate) fn secret_key_encrypt_rlwe< - Ro: Row + RowMut + RowEntity, - ModOp: VectorOps + GetModulus, - NttOp: Ntt, - S, - R: RandomFillGaussianInModulus<[Ro::Element], ModOp::M>, - PR: RandomFillUniformInModulus<[Ro::Element], ModOp::M>, ->( - m: &[Ro::Element], - b_rlwe_out: &mut Ro, - s: &[S], - mod_op: &ModOp, - ntt_op: &NttOp, - p_rng: &mut PR, - rng: &mut R, -) where - Ro: TryConvertFrom1<[S], ModOp::M> + Debug, -{ - let ring_size = s.len(); - assert!(m.as_ref().len() == ring_size); - assert!(b_rlwe_out.as_ref().len() == ring_size); - - let q = mod_op.modulus(); - - // sample a - let mut a = { - let mut a = Ro::zeros(ring_size); - RandomFillUniformInModulus::random_fill(p_rng, q, a.as_mut()); - a - }; - - // s * a - let mut sa = Ro::try_convert_from(s, q); - ntt_op.forward(sa.as_mut()); - ntt_op.forward(a.as_mut()); - mod_op.elwise_mul_mut(sa.as_mut(), a.as_ref()); - ntt_op.backward(sa.as_mut()); - - // sample e - RandomFillGaussianInModulus::random_fill(rng, q, b_rlwe_out.as_mut()); - mod_op.elwise_add_mut(b_rlwe_out.as_mut(), m.as_ref()); - mod_op.elwise_add_mut(b_rlwe_out.as_mut(), sa.as_ref()); -} - -pub(crate) fn public_key_encrypt_rlwe< - M: Matrix, - Mmut: MatrixMut, - ModOp: VectorOps + GetModulus, - NttOp: Ntt, - S, - R: RandomFillGaussianInModulus<[M::MatElement], ModOp::M> - + RandomFillUniformInModulus<[M::MatElement], ModOp::M> - + RandomFill<[u8]> - + RandomElementInModulus, ->( - rlwe_out: &mut Mmut, - pk: &M, - m: &[M::MatElement], - mod_op: &ModOp, - ntt_op: &NttOp, - rng: &mut R, -) where - ::R: RowMut + TryConvertFrom1<[S], ModOp::M> + RowEntity, - M::MatElement: Copy, - S: Zero + Signed + Copy, -{ - let ring_size = m.len(); - assert!(rlwe_out.dimension() == (2, ring_size)); - - let q = mod_op.modulus(); - - let mut u = vec![S::zero(); ring_size]; - fill_random_ternary_secret_with_hamming_weight(u.as_mut(), ring_size >> 1, rng); - let mut u = Mmut::R::try_convert_from(&u, q); - ntt_op.forward(u.as_mut()); - - let mut ua = Mmut::R::zeros(ring_size); - ua.as_mut().copy_from_slice(pk.get_row_slice(0)); - let mut ub = Mmut::R::zeros(ring_size); - ub.as_mut().copy_from_slice(pk.get_row_slice(1)); - - // a*u - ntt_op.forward(ua.as_mut()); - mod_op.elwise_mul_mut(ua.as_mut(), u.as_ref()); - ntt_op.backward(ua.as_mut()); - - // b*u - ntt_op.forward(ub.as_mut()); - mod_op.elwise_mul_mut(ub.as_mut(), u.as_ref()); - ntt_op.backward(ub.as_mut()); - - // sample error - rlwe_out.iter_rows_mut().for_each(|ri| { - RandomFillGaussianInModulus::random_fill(rng, &q, ri.as_mut()); - }); - - // a*u + e0 - mod_op.elwise_add_mut(rlwe_out.get_row_mut(0), ua.as_ref()); - // b*u + e1 - mod_op.elwise_add_mut(rlwe_out.get_row_mut(1), ub.as_ref()); - - // b*u + e1 + m - mod_op.elwise_add_mut(rlwe_out.get_row_mut(1), m); -} - -/// Generates RLWE public key -pub(crate) fn gen_rlwe_public_key< - Ro: RowMut + RowEntity, - S, - ModOp: VectorOps + GetModulus, - NttOp: Ntt, - PRng: RandomFillUniformInModulus<[Ro::Element], ModOp::M>, - Rng: RandomFillGaussianInModulus<[Ro::Element], ModOp::M>, ->( - part_b_out: &mut Ro, - s: &[S], - ntt_op: &NttOp, - mod_op: &ModOp, - p_rng: &mut PRng, - rng: &mut Rng, -) where - Ro: TryConvertFrom1<[S], ModOp::M>, -{ - let ring_size = s.len(); - assert!(part_b_out.as_ref().len() == ring_size); - - let q = mod_op.modulus(); - - // sample a - let mut a = { - let mut tmp = Ro::zeros(ring_size); - RandomFillUniformInModulus::random_fill(p_rng, &q, tmp.as_mut()); - tmp - }; - ntt_op.forward(a.as_mut()); - - // s*a - let mut sa = Ro::try_convert_from(s, &q); - ntt_op.forward(sa.as_mut()); - mod_op.elwise_mul_mut(sa.as_mut(), a.as_ref()); - ntt_op.backward(sa.as_mut()); - - // s*a + e - RandomFillGaussianInModulus::random_fill(rng, &q, part_b_out.as_mut()); - mod_op.elwise_add_mut(part_b_out.as_mut(), sa.as_ref()); -} - -/// Decrypts degree 1 RLWE ciphertext RLWE(m) and returns m -/// -/// - rlwe_ct: input degree 1 ciphertext RLWE(m). -pub(crate) fn decrypt_rlwe< - R: RowMut, - M: Matrix, - ModOp: VectorOps + GetModulus, - NttOp: Ntt, - S, ->( - rlwe_ct: &M, - s: &[S], - m_out: &mut R, - ntt_op: &NttOp, - mod_op: &ModOp, -) where - R: TryConvertFrom1<[S], ModOp::M>, - R::Element: Copy, -{ - let ring_size = s.len(); - assert!(rlwe_ct.dimension() == (2, ring_size)); - assert!(m_out.as_ref().len() == ring_size); - - // transform a to evluation form - m_out.as_mut().copy_from_slice(rlwe_ct.get_row_slice(0)); - ntt_op.forward(m_out.as_mut()); - - // -s*a - let mut s = R::try_convert_from(&s, mod_op.modulus()); - ntt_op.forward(s.as_mut()); - mod_op.elwise_mul_mut(m_out.as_mut(), s.as_ref()); - mod_op.elwise_neg_mut(m_out.as_mut()); - ntt_op.backward(m_out.as_mut()); - - // m+e = b - s*a - mod_op.elwise_add_mut(m_out.as_mut(), rlwe_ct.get_row_slice(1)); -} - -// Measures noise in degree 1 RLWE ciphertext against encoded ideal message -// encoded_m -pub(crate) fn measure_noise< - Mmut: MatrixMut + Matrix, - ModOp: VectorOps + GetModulus, - NttOp: Ntt, - S, ->( - rlwe_ct: &Mmut, - encoded_m_ideal: &Mmut::R, - ntt_op: &NttOp, - mod_op: &ModOp, - s: &[S], -) -> f64 -where - ::R: RowMut, - Mmut::R: RowEntity + TryConvertFrom1<[S], ModOp::M>, - Mmut::MatElement: PrimInt + ToPrimitive + Debug, -{ - let ring_size = s.len(); - assert!(rlwe_ct.dimension() == (2, ring_size)); - assert!(encoded_m_ideal.as_ref().len() == ring_size); - - // -(s * a) - let q = mod_op.modulus(); - let mut s = Mmut::R::try_convert_from(s, &q); - ntt_op.forward(s.as_mut()); - let mut a = Mmut::R::zeros(ring_size); - a.as_mut().copy_from_slice(rlwe_ct.get_row_slice(0)); - ntt_op.forward(a.as_mut()); - mod_op.elwise_mul_mut(s.as_mut(), a.as_ref()); - mod_op.elwise_neg_mut(s.as_mut()); - ntt_op.backward(s.as_mut()); - - // m+e = b - s*a - let mut m_plus_e = s; - mod_op.elwise_add_mut(m_plus_e.as_mut(), rlwe_ct.get_row_slice(1)); - - // difference - mod_op.elwise_sub_mut(m_plus_e.as_mut(), encoded_m_ideal.as_ref()); - - let mut max_diff_bits = f64::MIN; - m_plus_e.as_ref().iter().for_each(|v| { - let bits = (q.map_element_to_i64(v).to_f64().unwrap()).log2(); - - if max_diff_bits < bits { - max_diff_bits = bits; - } - }); - - return max_diff_bits; -} - #[cfg(test)] pub(crate) mod tests { use std::{marker::PhantomData, ops::Mul, vec}; @@ -1489,19 +538,19 @@ pub(crate) mod tests { decomposer::{Decomposer, DefaultDecomposer, RlweDecomposer}, ntt::{self, Ntt, NttBackendU64, NttInit}, random::{DefaultSecureRng, NewWithSeed, RandomFillUniformInModulus}, - rgsw::{ - gen_rlwe_public_key, measure_noise, public_key_encrypt_rgsw, AutoKeyEvaluationDomain, - RgswCiphertext, RgswCiphertextEvaluationDomain, RlweCiphertext, RlwePublicKey, - SeededAutoKey, SeededRgswCiphertext, SeededRlweCiphertext, SeededRlwePublicKey, - }, utils::{generate_prime, negacyclic_mul, Stats, TryConvertFrom1}, Matrix, Secret, }; use super::{ - decrypt_rlwe, galois_auto, galois_key_gen, generate_auto_map, public_key_encrypt_rlwe, - rgsw_by_rgsw_inplace, rlwe_by_rgsw, secret_key_encrypt_rgsw, secret_key_encrypt_rlwe, - RlweSecret, + keygen::{ + decrypt_rlwe, galois_key_gen, gen_rlwe_public_key, generate_auto_map, measure_noise, + public_key_encrypt_rgsw, secret_key_encrypt_rgsw, secret_key_encrypt_rlwe, + }, + runtime::{galois_auto, rgsw_by_rgsw_inplace, rlwe_by_rgsw}, + AutoKeyEvaluationDomain, RgswCiphertext, RgswCiphertextEvaluationDomain, RlweCiphertext, + RlwePublicKey, RlweSecret, SeededAutoKey, SeededRgswCiphertext, SeededRlweCiphertext, + SeededRlwePublicKey, }; pub(crate) fn _sk_encrypt_rlwe + Clone>( @@ -1772,11 +821,13 @@ pub(crate) mod tests { // Encryption m1 as RGSW(m1) using secret key let seeded_rgsw_ct = _sk_encrypt_rgsw(&m1, s.values(), &decomposer, &mod_op, &ntt_op); - RgswCiphertextEvaluationDomain::>, DefaultSecureRng, NttBackendU64>::from(&seeded_rgsw_ct) + RgswCiphertextEvaluationDomain::>, _,DefaultSecureRng, NttBackendU64>::from(&seeded_rgsw_ct) } else { // Encrypt m1 as RGSW(m1) using public key let rgsw_ct = _pk_encrypt_rgsw(&m1, &pk, &decomposer, &mod_op, &ntt_op); - RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from(&rgsw_ct) + RgswCiphertextEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from( + &rgsw_ct, + ) } }; @@ -1903,7 +954,7 @@ pub(crate) mod tests { &mut rng, ); let auto_key = - AutoKeyEvaluationDomain::>, DefaultSecureRng, NttBackendU64>::from( + AutoKeyEvaluationDomain::>, _, DefaultSecureRng, NttBackendU64>::from( &seeded_auto_key, ); @@ -1991,7 +1042,7 @@ pub(crate) mod tests { let mut rgsw_carrym = { let seeded_rgsw = _sk_encrypt_rgsw(&carry_m, s.values(), &decomposer, &mod_op, &ntt_op); let mut rgsw_eval = - RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from( + RgswCiphertextEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from( &seeded_rgsw, ); rgsw_eval @@ -2016,9 +1067,10 @@ pub(crate) mod tests { for i in 0..2 { let mut m = vec![0u64; ring_size as usize]; m[thread_rng().gen_range(0..ring_size) as usize] = if (i & 1) == 1 { q - 1 } else { 1 }; - let rgsw_m = RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from( - &_sk_encrypt_rgsw(&m, s.values(), &decomposer, &mod_op, &ntt_op), - ); + let rgsw_m = + RgswCiphertextEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from( + &_sk_encrypt_rgsw(&m, s.values(), &decomposer, &mod_op, &ntt_op), + ); rgsw_by_rgsw_inplace( &mut rgsw_carrym, &rgsw_m.data, @@ -2087,12 +1139,12 @@ pub(crate) mod tests { let mut rgsw_ct0 = { let seeded_rgsw_ct = _sk_encrypt_rgsw(&m0, s.values(), &decomposer, &mod_op, &ntt_op); - RgswCiphertextEvaluationDomain::>, DefaultSecureRng, NttBackendU64>::from(&seeded_rgsw_ct) + RgswCiphertextEvaluationDomain::>, _,DefaultSecureRng, NttBackendU64>::from(&seeded_rgsw_ct) }; let rgsw_ct1 = { let seeded_rgsw_ct = _sk_encrypt_rgsw(&m1, s.values(), &decomposer, &mod_op, &ntt_op); - RgswCiphertextEvaluationDomain::>, DefaultSecureRng, NttBackendU64>::from(&seeded_rgsw_ct) + RgswCiphertextEvaluationDomain::>,_, DefaultSecureRng, NttBackendU64>::from(&seeded_rgsw_ct) }; // RGSW x RGSW diff --git a/src/rgsw/runtime.rs b/src/rgsw/runtime.rs new file mode 100644 index 0000000..88085b4 --- /dev/null +++ b/src/rgsw/runtime.rs @@ -0,0 +1,408 @@ +use itertools::izip; +use num_traits::Zero; + +use crate::{ + backend::{ArithmeticOps, GetModulus, Modulus, VectorOps}, + decomposer::{Decomposer, RlweDecomposer}, + ntt::Ntt, + Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, +}; + +use super::IsTrivial; + +pub(crate) fn routine>( + write_to_row: &mut [R::Element], + matrix_a: &[R], + matrix_b: &[R], + mod_op: &ModOp, +) { + izip!(matrix_a.iter(), matrix_b.iter()).for_each(|(a, b)| { + mod_op.elwise_fma_mut(write_to_row, a.as_ref(), b.as_ref()); + }); +} + +/// Decomposes ring polynomial r(X) into d polynomials using decomposer into +/// output matrix decomp_r +/// +/// Note that decomposition of r(X) requires decomposition of each of +/// coefficients. +/// +/// - decomp_r: must have dimensions d x ring_size. i^th decomposed polynomial +/// will be stored at i^th row. +pub(crate) fn decompose_r>( + r: &[R::Element], + decomp_r: &mut [R], + decomposer: &D, +) where + R::Element: Copy, +{ + let ring_size = r.len(); + + for ri in 0..ring_size { + decomposer + .decompose_iter(&r[ri]) + .enumerate() + .for_each(|(index, el)| { + decomp_r[index].as_mut()[ri] = el; + }); + } +} + +/// Sends RLWE_{s}(X) -> RLWE_{s}(X^k) where k is some galois element +/// +/// - scratch_matrix: must have dimension at-least d+2 x ring_size. d rows to +/// store decomposed polynomials and 2 for rlwe +pub(crate) fn galois_auto< + MT: Matrix + IsTrivial + MatrixMut, + Mmut: MatrixMut, + ModOp: ArithmeticOps + VectorOps, + NttOp: Ntt, + D: Decomposer, +>( + rlwe_in: &mut MT, + ksk: &Mmut, + scratch_matrix: &mut Mmut, + auto_map_index: &[usize], + auto_map_sign: &[bool], + mod_op: &ModOp, + ntt_op: &NttOp, + decomposer: &D, +) where + ::R: RowMut, + ::R: RowMut, + MT::MatElement: Copy + Zero, +{ + let d = decomposer.decomposition_count(); + let ring_size = rlwe_in.dimension().1; + assert!(rlwe_in.dimension().0 == 2); + assert!(scratch_matrix.fits(d + 2, ring_size)); + + // scratch matrix is guaranteed to have at-least d+2 rows but can have more than + // d+2 rows. We require to split them into sub-matrices of exact sizes one with + // d rows for storing decomposed polynomial and second with 2 rows to act + // tomperary space for RLWE ciphertext. Exact sizes is necessary to avoid any + // irrelevant extra FMA or NTT ops. + let (scratch_matrix_d_ring, other_half) = scratch_matrix.split_at_row_mut(d); + let (tmp_rlwe_out, _) = other_half.split_at_mut(2); + + debug_assert!(tmp_rlwe_out.len() == 2); + debug_assert!(scratch_matrix_d_ring.len() == d); + + if !rlwe_in.is_trivial() { + tmp_rlwe_out.iter_mut().for_each(|r| { + r.as_mut().fill(Mmut::MatElement::zero()); + }); + + // send a(X) -> a(X^k) and decompose a(X^k) + izip!( + rlwe_in.get_row(0), + auto_map_index.iter(), + auto_map_sign.iter() + ) + .for_each(|(el_in, to_index, sign)| { + let el_out = if !*sign { mod_op.neg(el_in) } else { *el_in }; + + decomposer + .decompose_iter(&el_out) + .enumerate() + .for_each(|(index, el)| { + scratch_matrix_d_ring[index].as_mut()[*to_index] = el; + }); + }); + + // transform decomposed a(X^k) to evaluation domain + scratch_matrix_d_ring.iter_mut().for_each(|r| { + ntt_op.forward(r.as_mut()); + }); + + // RLWE(m^k) = a', b'; RLWE(m) = a, b + // key switch: (a * RLWE'(s(X^k))) + let (ksk_a, ksk_b) = ksk.split_at_row(d); + // a' = decomp * RLWE'_A(s(X^k)) + routine( + tmp_rlwe_out[0].as_mut(), + scratch_matrix_d_ring, + ksk_a, + mod_op, + ); + + // b' += decomp * RLWE'_B(s(X^k)) + routine( + tmp_rlwe_out[1].as_mut(), + scratch_matrix_d_ring, + ksk_b, + mod_op, + ); + + // transform RLWE(m^k) to coefficient domain + tmp_rlwe_out + .iter_mut() + .for_each(|r| ntt_op.backward(r.as_mut())); + + // send b(X) -> b(X^k) and then b'(X) += b(X^k) + izip!( + rlwe_in.get_row(1), + auto_map_index.iter(), + auto_map_sign.iter() + ) + .for_each(|(el_in, to_index, sign)| { + let row = tmp_rlwe_out[1].as_mut(); + if !*sign { + row[*to_index] = mod_op.sub(&row[*to_index], el_in); + } else { + row[*to_index] = mod_op.add(&row[*to_index], el_in); + } + }); + + // copy over A; Leave B for later + rlwe_in + .get_row_mut(0) + .copy_from_slice(tmp_rlwe_out[0].as_ref()); + } else { + // RLWE is trivial, a(X) is 0. + // send b(X) -> b(X^k) + izip!( + rlwe_in.get_row(1), + auto_map_index.iter(), + auto_map_sign.iter() + ) + .for_each(|(el_in, to_index, sign)| { + if !*sign { + tmp_rlwe_out[1].as_mut()[*to_index] = mod_op.neg(el_in); + } else { + tmp_rlwe_out[1].as_mut()[*to_index] = *el_in; + } + }); + } + + // Copy over B + rlwe_in + .get_row_mut(1) + .copy_from_slice(tmp_rlwe_out[1].as_ref()); +} + +/// Returns RLWE(m0m1) = RLWE(m0) x RGSW(m1). Mutates rlwe_in inplace to equal +/// RLWE(m0m1) +/// +/// - rlwe_in: is RLWE(m0) with polynomials in coefficient domain +/// - rgsw_in: is RGSW(m1) with polynomials in evaluation domain +/// - scratch_matrix_d_ring: is a matrix with atleast max(d_a, d_b) rows and +/// ring_size columns. It's used to store decomposed polynomials and out RLWE +/// temoporarily +pub(crate) fn rlwe_by_rgsw< + Mmut: MatrixMut, + MT: Matrix + MatrixMut + IsTrivial, + D: RlweDecomposer, + ModOp: VectorOps, + NttOp: Ntt, +>( + rlwe_in: &mut MT, + rgsw_in: &Mmut, + scratch_matrix: &mut Mmut, + decomposer: &D, + ntt_op: &NttOp, + mod_op: &ModOp, +) where + Mmut::MatElement: Copy + Zero, + ::R: RowMut, + ::R: RowMut, +{ + let decomposer_a = decomposer.a(); + let decomposer_b = decomposer.b(); + let d_a = decomposer_a.decomposition_count(); + let d_b = decomposer_b.decomposition_count(); + let max_d = std::cmp::max(d_a, d_b); + assert!(scratch_matrix.fits(max_d + 2, rlwe_in.dimension().1)); + assert!(rgsw_in.dimension() == (d_a * 2 + d_b * 2, rlwe_in.dimension().1)); + + // decomposed RLWE x RGSW + let (rlwe_dash_nsm, rlwe_dash_m) = rgsw_in.split_at_row(d_a * 2); + let (scratch_matrix_d_ring, scratch_rlwe_out) = scratch_matrix.split_at_row_mut(max_d); + scratch_rlwe_out[0].as_mut().fill(Mmut::MatElement::zero()); + scratch_rlwe_out[1].as_mut().fill(Mmut::MatElement::zero()); + // RLWE_in = a_in, b_in; RLWE_out = a_out, b_out + if !rlwe_in.is_trivial() { + // a_in = 0 when RLWE_in is trivial RLWE ciphertext + // decomp + decompose_r( + rlwe_in.get_row_slice(0), + &mut scratch_matrix_d_ring[..d_a], + decomposer_a, + ); + scratch_matrix_d_ring + .iter_mut() + .take(d_a) + .for_each(|r| ntt_op.forward(r.as_mut())); + // a_out += decomp \cdot RLWE_A'(-sm) + routine( + scratch_rlwe_out[0].as_mut(), + scratch_matrix_d_ring.as_ref(), + &rlwe_dash_nsm[..d_a], + mod_op, + ); + // b_out += decomp \cdot RLWE_B'(-sm) + routine( + scratch_rlwe_out[1].as_mut(), + scratch_matrix_d_ring.as_ref(), + &rlwe_dash_nsm[d_a..], + mod_op, + ); + } + // decomp + decompose_r( + rlwe_in.get_row_slice(1), + &mut scratch_matrix_d_ring[..d_b], + decomposer_b, + ); + scratch_matrix_d_ring + .iter_mut() + .take(d_b) + .for_each(|r| ntt_op.forward(r.as_mut())); + // a_out += decomp \cdot RLWE_A'(m) + routine( + scratch_rlwe_out[0].as_mut(), + scratch_matrix_d_ring.as_ref(), + &rlwe_dash_m[..d_b], + mod_op, + ); + // b_out += decomp \cdot RLWE_B'(m) + routine( + scratch_rlwe_out[1].as_mut(), + scratch_matrix_d_ring.as_ref(), + &rlwe_dash_m[d_b..], + mod_op, + ); + + // transform rlwe_out to coefficient domain + scratch_rlwe_out + .iter_mut() + .for_each(|r| ntt_op.backward(r.as_mut())); + + rlwe_in + .get_row_mut(0) + .copy_from_slice(scratch_rlwe_out[0].as_mut()); + rlwe_in + .get_row_mut(1) + .copy_from_slice(scratch_rlwe_out[1].as_mut()); + rlwe_in.set_not_trivial(); +} + +/// Inplace mutates rlwe_0 to equal RGSW(m0m1) = RGSW(m0)xRGSW(m1) +/// in evaluation domain +/// +/// Warning - +/// Pass a fresh RGSW ciphertext as the second operand, i.e. as `rgsw_1`. +/// This is to assure minimal error growth in the resulting RGSW ciphertext. +/// RGSWxRGSW boils down to d_rgsw*2 RLWExRGSW multiplications. Hence, the noise +/// growth in resulting ciphertext depends on the norm of second RGSW +/// ciphertext, not the first. This is useful in cases where one is accumulating +/// multiple RGSW ciphertexts into 1. In which case, pass the accumulating RGSW +/// ciphertext as rlwe_0 (the one with higher noise) and subsequent RGSW +/// ciphertexts, with lower noise, to be accumulated as second +/// operand. +/// +/// - rgsw_0: RGSW(m0) +/// - rgsw_1_eval: RGSW(m1) in Evaluation domain +/// - scratch_matrix_d_plus_rgsw_by_ring: scratch space matrix with rows +/// (max(d_a, d_b) + d_a*2+d_b*2) and columns ring_size +pub(crate) fn rgsw_by_rgsw_inplace< + Mmut: MatrixMut, + D: RlweDecomposer, + ModOp: VectorOps, + NttOp: Ntt, +>( + rgsw_0: &mut Mmut, + rgsw_1_eval: &Mmut, + decomposer: &D, + scratch_matrix: &mut Mmut, + ntt_op: &NttOp, + mod_op: &ModOp, +) where + ::R: RowMut, + Mmut::MatElement: Copy + Zero, +{ + let decomposer_a = decomposer.a(); + let decomposer_b = decomposer.b(); + let d_a = decomposer_a.decomposition_count(); + let d_b = decomposer_b.decomposition_count(); + let max_d = std::cmp::max(d_a, d_b); + let rgsw_rows = d_a * 2 + d_b * 2; + assert!(rgsw_0.dimension().0 == rgsw_rows); + let ring_size = rgsw_0.dimension().1; + assert!(rgsw_1_eval.dimension() == (rgsw_rows, ring_size)); + assert!(scratch_matrix.fits(max_d + rgsw_rows, ring_size)); + + let (decomp_r_space, rgsw_space) = scratch_matrix.split_at_row_mut(max_d); + + // zero rgsw_space + rgsw_space + .iter_mut() + .for_each(|ri| ri.as_mut().fill(Mmut::MatElement::zero())); + let (rlwe_dash_space_nsm, rlwe_dash_space_m) = rgsw_space.split_at_mut(d_a * 2); + let (rlwe_dash_space_nsm_parta, rlwe_dash_space_nsm_partb) = + rlwe_dash_space_nsm.split_at_mut(d_a); + let (rlwe_dash_space_m_parta, rlwe_dash_space_m_partb) = rlwe_dash_space_m.split_at_mut(d_b); + + let (rgsw0_nsm, rgsw0_m) = rgsw_0.split_at_row(d_a * 2); + let (rgsw1_nsm, rgsw1_m) = rgsw_1_eval.split_at_row(d_a * 2); + + // RGSW x RGSW + izip!( + rgsw0_nsm.iter().take(d_a).chain(rgsw0_m.iter().take(d_b)), + rgsw0_nsm.iter().skip(d_a).chain(rgsw0_m.iter().skip(d_b)), + rlwe_dash_space_nsm_parta + .iter_mut() + .chain(rlwe_dash_space_m_parta.iter_mut()), + rlwe_dash_space_nsm_partb + .iter_mut() + .chain(rlwe_dash_space_m_partb.iter_mut()), + ) + .for_each(|(rlwe_a, rlwe_b, rlwe_out_a, rlwe_out_b)| { + // Part A + decompose_r(rlwe_a.as_ref(), decomp_r_space.as_mut(), decomposer_a); + decomp_r_space + .iter_mut() + .take(d_a) + .for_each(|ri| ntt_op.forward(ri.as_mut())); + routine( + rlwe_out_a.as_mut(), + &decomp_r_space[..d_a], + &rgsw1_nsm[..d_a], + mod_op, + ); + routine( + rlwe_out_b.as_mut(), + &decomp_r_space[..d_a], + &rgsw1_nsm[d_a..], + mod_op, + ); + + // Part B + decompose_r(rlwe_b.as_ref(), decomp_r_space.as_mut(), decomposer_b); + decomp_r_space + .iter_mut() + .take(d_b) + .for_each(|ri| ntt_op.forward(ri.as_mut())); + routine( + rlwe_out_a.as_mut(), + &decomp_r_space[..d_b], + &rgsw1_m[..d_b], + mod_op, + ); + routine( + rlwe_out_b.as_mut(), + &decomp_r_space[..d_b], + &rgsw1_m[d_b..], + mod_op, + ); + }); + + // copy over RGSW(m0m1) into RGSW(m0) + izip!(rgsw_0.iter_rows_mut(), rgsw_space.iter()) + .for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref())); + + // send back to coefficient domain + rgsw_0 + .iter_rows_mut() + .for_each(|ri| ntt_op.backward(ri.as_mut())); +}