From 80856cc850ffeb045c18cece8899c9afb7c39b65 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Sun, 16 Jun 2024 10:48:55 +0530 Subject: [PATCH] fix parameters --- src/bool/evaluator.rs | 22 ++++- src/bool/keys.rs | 46 +++++++++ src/bool/parameters.rs | 208 ++++++++++++++++++++++++++++++----------- src/multi_party.rs | 66 +++++++++---- 4 files changed, 267 insertions(+), 75 deletions(-) diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index 8205b9c..7ae85f8 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -12,6 +12,7 @@ use std::{ use itertools::{izip, partition, Itertools}; use num_traits::{FromPrimitive, Num, One, Pow, PrimInt, ToPrimitive, WrappingSub, Zero}; +use rand::Rng; use rand_distr::uniform::SampleUniform; use crate::{ @@ -20,7 +21,7 @@ use crate::{ }, decomposer::{Decomposer, DefaultDecomposer, NumInfo, RlweDecomposer}, lwe::{decrypt_lwe, encrypt_lwe, lwe_key_switch, lwe_ksk_keygen, measure_noise_lwe, LweSecret}, - multi_party::public_key_share, + multi_party::{non_interactive_ksk_gen, public_key_share}, ntt::{self, Ntt, NttBackendU64, NttInit}, pbs::{pbs, sample_extract, PbsInfo, PbsKey, WithShoupRepr}, random::{ @@ -44,7 +45,7 @@ use super::{ keys::ClientKey, parameters::{BoolParameters, CiphertextModulus}, CommonReferenceSeededCollectivePublicKeyShare, CommonReferenceSeededMultiPartyServerKeyShare, - SeededMultiPartyServerKey, SeededServerKey, ServerKeyEvaluationDomain, + NonInteractiveClientKey, SeededMultiPartyServerKey, SeededServerKey, ServerKeyEvaluationDomain, ShoupServerKeyEvaluationDomain, }; @@ -713,6 +714,23 @@ where }) } + pub(super) fn non_interactive_multi_party_key_share( + self_ui_to_ksk_seed: [u8; 32], + others_ui_to_ksk_seed: &[[u8; 32]], + client_key: &NonInteractiveClientKey, + ) { + // // ui_to_s_ksk + // non_interactive_ksk_gen( + // client_key.sk_rlwe().values(), + // client_key.sk_u_rlwe().values(), + // gadget_vec, + // p_rng, + // rng, + // nttop, + // modop, + // ) + } + pub(super) fn multi_party_public_key_share( &self, cr_seed: [u8; 32], diff --git a/src/bool/keys.rs b/src/bool/keys.rs index a0df170..da79d3b 100644 --- a/src/bool/keys.rs +++ b/src/bool/keys.rs @@ -19,6 +19,14 @@ pub struct ClientKey { sk_lwe: LweSecret, } +/// Client key with RLWE and LWE secrets +#[derive(Clone)] +pub struct NonInteractiveClientKey { + sk_rlwe: RlweSecret, + sk_u_rlwe: RlweSecret, + sk_lwe: LweSecret, +} + mod impl_ck { use super::*; @@ -43,6 +51,44 @@ mod impl_ck { } } + // Client key + impl NonInteractiveClientKey { + pub(in super::super) fn random() -> Self { + let sk_rlwe = RlweSecret::random(0, 0); + let sk_u_rlwe = RlweSecret::random(0, 0); + let sk_lwe = LweSecret::random(0, 0); + Self { + sk_rlwe, + sk_u_rlwe, + sk_lwe, + } + } + + pub(in super::super) fn new( + sk_rlwe: RlweSecret, + sk_u_rlwe: RlweSecret, + sk_lwe: LweSecret, + ) -> Self { + Self { + sk_rlwe, + sk_u_rlwe, + sk_lwe, + } + } + + pub(in super::super) fn sk_rlwe(&self) -> &RlweSecret { + &self.sk_rlwe + } + + pub(in super::super) fn sk_u_rlwe(&self) -> &RlweSecret { + &self.sk_u_rlwe + } + + pub(in super::super) fn sk_lwe(&self) -> &LweSecret { + &self.sk_lwe + } + } + impl Encryptor> for ClientKey { fn encrypt(&self, m: &bool) -> Vec { BoolEvaluator::with_local(|e| e.sk_encrypt(*m, self)) diff --git a/src/bool/parameters.rs b/src/bool/parameters.rs index f6c0f8d..8c5245e 100644 --- a/src/bool/parameters.rs +++ b/src/bool/parameters.rs @@ -2,6 +2,91 @@ use num_traits::{ConstZero, FromPrimitive, PrimInt}; use crate::{backend::Modulus, decomposer::Decomposer}; +trait DoubleDecomposerParams { + type Base; + type Count; + + fn new(base: Self::Base, count_a: Self::Count, count_b: Self::Count) -> Self; + fn decomposition_base(&self) -> Self::Base; + fn decomposition_count_a(&self) -> Self::Count; + fn decomposition_count_b(&self) -> Self::Count; +} + +trait SingleDecomposerParams { + type Base; + type Count; + + fn new(base: Self::Base, count: Self::Count) -> Self; + fn decomposition_base(&self) -> Self::Base; + fn decomposition_count(&self) -> Self::Count; +} + +impl DoubleDecomposerParams + for ( + DecompostionLogBase, + (DecompositionCount, DecompositionCount), + ) +{ + type Base = DecompostionLogBase; + type Count = DecompositionCount; + + fn new( + base: DecompostionLogBase, + count_a: DecompositionCount, + count_b: DecompositionCount, + ) -> Self { + (base, (count_a, count_b)) + } + + fn decomposition_base(&self) -> Self::Base { + self.0 + } + + fn decomposition_count_a(&self) -> Self::Count { + self.1 .0 + } + + fn decomposition_count_b(&self) -> Self::Count { + self.1 .1 + } +} + +impl SingleDecomposerParams for (DecompostionLogBase, DecompositionCount) { + type Base = DecompostionLogBase; + type Count = DecompositionCount; + + fn new(base: DecompostionLogBase, count: DecompositionCount) -> Self { + (base, count) + } + + fn decomposition_base(&self) -> Self::Base { + self.0 + } + + fn decomposition_count(&self) -> Self::Count { + self.1 + } +} + +// impl DecomposerParams for (DecompostionLogBase, (DecompositionCount)) { +// type Base = DecompostionLogBase; +// type Count = DecompositionCount; + +// fn decomposition_base(&self) -> Self::Base { +// self.0 +// } + +// fn decomposition_count(&self) -> Self::Count { +// self.1 +// } +// } + +#[derive(Clone, PartialEq, Debug)] +pub(crate) enum ParameterVariant { + SingleParty, + MultiParty, + NonInteractiveMultiParty, +} #[derive(Clone, PartialEq)] pub struct BoolParameters { rlwe_q: CiphertextModulus, @@ -9,18 +94,21 @@ pub struct BoolParameters { br_q: usize, rlwe_n: PolynomialSize, lwe_n: LweDimension, - lwe_decomposer_base: DecompostionLogBase, - lwe_decomposer_count: DecompositionCount, - rlrg_decomposer_base: DecompostionLogBase, + lwe_decomposer_params: (DecompostionLogBase, DecompositionCount), /// RLWE x RGSW decomposition count for (part A, part B) - rlrg_decomposer_count: (DecompositionCount, DecompositionCount), - rgrg_decomposer_base: DecompostionLogBase, + rlrg_decomposer_params: ( + DecompostionLogBase, + (DecompositionCount, DecompositionCount), + ), /// RGSW x RGSW decomposition count for (part A, part B) - rgrg_decomposer_count: (DecompositionCount, DecompositionCount), - auto_decomposer_base: DecompostionLogBase, - auto_decomposer_count: DecompositionCount, + rgrg_decomposer_params: Option<( + DecompostionLogBase, + (DecompositionCount, DecompositionCount), + )>, + auto_decomposer_params: (DecompostionLogBase, DecompositionCount), g: usize, w: usize, + variant: ParameterVariant, } impl BoolParameters { @@ -53,53 +141,57 @@ impl BoolParameters { } pub(crate) fn rlwe_rgsw_decomposition_base(&self) -> DecompostionLogBase { - self.rlrg_decomposer_base + self.rlrg_decomposer_params.0 } pub(crate) fn rlwe_rgsw_decomposition_count(&self) -> (DecompositionCount, DecompositionCount) { - self.rlrg_decomposer_count - } - - pub(crate) fn rgsw_rgsw_decomposition_base(&self) -> DecompostionLogBase { - self.rgrg_decomposer_base + self.rlrg_decomposer_params.1 } pub(crate) fn rgsw_rgsw_decomposition_count(&self) -> (DecompositionCount, DecompositionCount) { - self.rgrg_decomposer_count + let params = self.rgrg_decomposer_params.expect(&format!( + "Parameter variant {:?} does not support RGSW x RGSW", + self.variant + )); + params.1 } pub(crate) fn auto_decomposition_base(&self) -> DecompostionLogBase { - self.auto_decomposer_base + self.auto_decomposer_params.decomposition_base() } pub(crate) fn auto_decomposition_count(&self) -> DecompositionCount { - self.auto_decomposer_count + self.auto_decomposer_params.decomposition_count() } pub(crate) fn lwe_decomposition_base(&self) -> DecompostionLogBase { - self.lwe_decomposer_base + self.lwe_decomposer_params.decomposition_base() } pub(crate) fn lwe_decomposition_count(&self) -> DecompositionCount { - self.lwe_decomposer_count + self.lwe_decomposer_params.decomposition_count() } pub(crate) fn rgsw_rgsw_decomposer>(&self) -> (D, D) where El: Copy, { + let params = self.rgrg_decomposer_params.expect(&format!( + "Parameter variant {:?} does not support RGSW x RGSW", + self.variant + )); ( // A D::new( self.rlwe_q.0, - self.rgrg_decomposer_base.0, - self.rgrg_decomposer_count.0 .0, + params.decomposition_base().0, + params.decomposition_count_a().0, ), // B D::new( self.rlwe_q.0, - self.rgrg_decomposer_base.0, - self.rgrg_decomposer_count.1 .0, + params.decomposition_base().0, + params.decomposition_count_b().0, ), ) } @@ -110,8 +202,8 @@ impl BoolParameters { { D::new( self.rlwe_q.0, - self.auto_decomposer_base.0, - self.auto_decomposer_count.0, + self.auto_decomposer_params.decomposition_base().0, + self.auto_decomposer_params.decomposition_count().0, ) } @@ -121,8 +213,8 @@ impl BoolParameters { { D::new( self.lwe_q.0, - self.lwe_decomposer_base.0, - self.lwe_decomposer_count.0, + self.lwe_decomposer_params.decomposition_base().0, + self.lwe_decomposer_params.decomposition_count().0, ) } @@ -134,14 +226,14 @@ impl BoolParameters { // A D::new( self.rlwe_q.0, - self.rlrg_decomposer_base.0, - self.rlrg_decomposer_count.0 .0, + self.rlrg_decomposer_params.decomposition_base().0, + self.rlrg_decomposer_params.decomposition_count_a().0, ), // B D::new( self.rlwe_q.0, - self.rlrg_decomposer_base.0, - self.rlrg_decomposer_count.1 .0, + self.rlrg_decomposer_params.decomposition_base().0, + self.rlrg_decomposer_params.decomposition_count_b().0, ), ) } @@ -298,16 +390,16 @@ pub(crate) const SP_BOOL_PARAMS: BoolParameters = BoolParameters:: { br_q: 1 << 10, rlwe_n: PolynomialSize(1 << 10), lwe_n: LweDimension(500), - lwe_decomposer_base: DecompostionLogBase(4), - lwe_decomposer_count: DecompositionCount(4), - rlrg_decomposer_base: DecompostionLogBase(7), - rlrg_decomposer_count: (DecompositionCount(4), DecompositionCount(4)), - rgrg_decomposer_base: DecompostionLogBase(7), - rgrg_decomposer_count: (DecompositionCount(4), DecompositionCount(4)), - auto_decomposer_base: DecompostionLogBase(7), - auto_decomposer_count: DecompositionCount(4), + lwe_decomposer_params: (DecompostionLogBase(4), DecompositionCount(4)), + rlrg_decomposer_params: ( + DecompostionLogBase(7), + (DecompositionCount(4), DecompositionCount(4)), + ), + rgrg_decomposer_params: None, + auto_decomposer_params: (DecompostionLogBase(7), DecompositionCount(4)), g: 5, w: 5, + variant: ParameterVariant::SingleParty, }; pub(crate) const MP_BOOL_PARAMS: BoolParameters = BoolParameters:: { @@ -316,16 +408,19 @@ pub(crate) const MP_BOOL_PARAMS: BoolParameters = BoolParameters:: { br_q: 1 << 11, rlwe_n: PolynomialSize(1 << 11), lwe_n: LweDimension(500), - lwe_decomposer_base: DecompostionLogBase(4), - lwe_decomposer_count: DecompositionCount(5), - rlrg_decomposer_base: DecompostionLogBase(12), - rlrg_decomposer_count: (DecompositionCount(5), DecompositionCount(5)), - rgrg_decomposer_base: DecompostionLogBase(12), - rgrg_decomposer_count: (DecompositionCount(5), DecompositionCount(5)), - auto_decomposer_base: DecompostionLogBase(12), - auto_decomposer_count: DecompositionCount(5), + lwe_decomposer_params: (DecompostionLogBase(4), DecompositionCount(5)), + rlrg_decomposer_params: ( + DecompostionLogBase(12), + (DecompositionCount(5), DecompositionCount(5)), + ), + rgrg_decomposer_params: Some(( + DecompostionLogBase(12), + (DecompositionCount(5), DecompositionCount(5)), + )), + auto_decomposer_params: (DecompostionLogBase(12), DecompositionCount(5)), g: 5, w: 10, + variant: ParameterVariant::MultiParty, }; pub(crate) const SMALL_MP_BOOL_PARAMS: BoolParameters = BoolParameters:: { @@ -334,16 +429,19 @@ pub(crate) const SMALL_MP_BOOL_PARAMS: BoolParameters = BoolParameters:: + GetModulus, >( s: &[S], - ephemeral_u: &[S], + u: &[S], m: &[M::MatElement], gadget_vec: &[M::MatElement], p_rng: &mut PRng, @@ -69,14 +69,14 @@ where ::R: RowMut + TryConvertFrom1<[S], ModOp::M> + RowEntity, M::MatElement: Copy, { - assert_eq!(s.len(), ephemeral_u.len()); + assert_eq!(s.len(), u.len()); assert_eq!(s.len(), m.len()); let q = modop.modulus(); let d = gadget_vec.len(); let ring_size = s.len(); let mut s_poly_eval = M::R::try_convert_from(s, q); - let mut u_poly_eval = M::R::try_convert_from(ephemeral_u, q); + let mut u_poly_eval = M::R::try_convert_from(u, q); nttop.forward(s_poly_eval.as_mut()); nttop.forward(u_poly_eval.as_mut()); @@ -125,7 +125,7 @@ where (enc_beta_m, zero_encryptions) } -fn non_interactive_ksk_gen< +pub(crate) fn non_interactive_ksk_gen< M: MatrixMut + MatrixEntity, S, PRng: RandomFillUniformInModulus<[M::MatElement], ModOp::M>, @@ -134,7 +134,7 @@ fn non_interactive_ksk_gen< ModOp: VectorOps + GetModulus, >( s: &[S], - ephemeral_u: &[S], + u: &[S], gadget_vec: &[M::MatElement], p_rng: &mut PRng, rng: &mut Rng, @@ -144,7 +144,7 @@ fn non_interactive_ksk_gen< ::R: RowMut + TryConvertFrom1<[S], ModOp::M> + RowEntity, M::MatElement: Copy, { - assert_eq!(s.len(), ephemeral_u.len()); + assert_eq!(s.len(), u.len()); let q = modop.modulus(); let d = gadget_vec.len(); @@ -152,24 +152,16 @@ fn non_interactive_ksk_gen< let mut s_poly_eval = M::R::try_convert_from(s, q); nttop.forward(s_poly_eval.as_mut()); - let u_poly = M::R::try_convert_from(ephemeral_u, q); + let u_poly = M::R::try_convert_from(u, q); // a_i * s + \beta u + e - // a_i * s + e let mut ksk = M::zeros(d, ring_size); - let mut zero_encs = M::zeros(d, ring_size); let mut scratch_space = M::R::zeros(ring_size); - izip!( - ksk.iter_rows_mut(), - zero_encs.iter_rows_mut(), - gadget_vec.iter() - ) - .for_each(|(e_ksk, e_zero, beta)| { + izip!(ksk.iter_rows_mut(), gadget_vec.iter()).for_each(|(e_ksk, beta)| { // sample a_i RandomFillUniformInModulus::random_fill(p_rng, q, e_ksk.as_mut()); - e_zero.as_mut().copy_from_slice(e_ksk.as_ref()); // a_i * s + e + beta u nttop.forward(e_ksk.as_mut()); @@ -183,13 +175,51 @@ fn non_interactive_ksk_gen< modop.elwise_scalar_mul(scratch_space.as_mut(), u_poly.as_ref(), beta); // a_i * s + e + \beta * u modop.elwise_add_mut(e_ksk.as_mut(), scratch_space.as_ref()); + }); +} + +pub(crate) fn non_interactive_ksk_zero_encryptions_for_other_party_i< + M: MatrixMut + MatrixEntity, + S, + PRng: RandomFillUniformInModulus<[M::MatElement], ModOp::M>, + Rng: RandomFillGaussianInModulus<[M::MatElement], ModOp::M>, + NttOp: Ntt, + ModOp: VectorOps + GetModulus, +>( + s: &[S], + gadget_vec: &[M::MatElement], + p_rng: &mut PRng, + rng: &mut Rng, + nttop: &NttOp, + modop: &ModOp, +) -> M +where + ::R: RowMut + TryConvertFrom1<[S], ModOp::M> + RowEntity, + M::MatElement: Copy, +{ + let q = modop.modulus(); + let d = gadget_vec.len(); + let ring_size = s.len(); + + let mut s_poly_eval = M::R::try_convert_from(s, q); + nttop.forward(s_poly_eval.as_mut()); + + // a_i * s + e + let mut zero_encs = M::zeros(d, ring_size); + + let mut scratch_space = M::R::zeros(ring_size); + + izip!(zero_encs.iter_rows_mut()).for_each(|(e_zero)| { + // sample a_i + RandomFillUniformInModulus::random_fill(p_rng, q, e_zero.as_mut()); // a_i * s + e nttop.forward(e_zero.as_mut()); modop.elwise_mul_mut(e_zero.as_mut(), s_poly_eval.as_ref()); nttop.backward(e_zero.as_mut()); - // sample error e + // sample error e RandomFillGaussianInModulus::random_fill(rng, q, scratch_space.as_mut()); - + modop.elwise_add_mut(e_zero.as_mut(), scratch_space.as_ref()); }); + zero_encs }