From 80ae5d7c8fe5e5ca2d5f587b2d6120a525c30b93 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Tue, 11 Jun 2024 11:24:03 +0530 Subject: [PATCH] prolly a mistake --- src/bool/evaluator.rs | 86 +++++++++++++++------------------ src/bool/keys.rs | 14 ++++++ src/bool/mod.rs | 11 +++-- src/pbs.rs | 110 +++++++++++++++++++++++------------------- src/rgsw/runtime.rs | 4 +- src/shortint/mod.rs | 12 +++-- 6 files changed, 133 insertions(+), 104 deletions(-) diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index eeb100d..77be7eb 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -15,7 +15,9 @@ use num_traits::{FromPrimitive, Num, One, Pow, PrimInt, ToPrimitive, WrappingSub use rand_distr::uniform::SampleUniform; use crate::{ - backend::{ArithmeticOps, GetModulus, ModInit, ModularOpsU64, Modulus, VectorOps}, + backend::{ + ArithmeticOps, GetModulus, ModInit, ModularOpsU64, Modulus, ShoupMatrixFMA, VectorOps, + }, decomposer::{Decomposer, DefaultDecomposer, NumInfo, RlweDecomposer}, lwe::{decrypt_lwe, encrypt_lwe, lwe_key_switch, lwe_ksk_keygen, measure_noise_lwe, LweSecret}, multi_party::public_key_share, @@ -43,6 +45,7 @@ use super::{ parameters::{BoolParameters, CiphertextModulus}, CommonReferenceSeededCollectivePublicKeyShare, CommonReferenceSeededMultiPartyServerKeyShare, SeededMultiPartyServerKey, SeededServerKey, ServerKeyEvaluationDomain, + ShoupServerKeyEvaluationDomain, }; pub struct MultiPartyCrs { @@ -78,7 +81,7 @@ impl MultiPartyCrs { pub(crate) trait BooleanGates { type Ciphertext: RowEntity; - type Key; + type Key: Global; fn and_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); fn nand_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); @@ -231,12 +234,12 @@ pub(super) struct BoolPbsInfo { impl PbsInfo for BoolPbsInfo where M::MatElement: PrimInt + WrappingSub + NumInfo + FromPrimitive + From + Display, - RlweModOp: ArithmeticOps + VectorOps, + RlweModOp: ArithmeticOps + ShoupMatrixFMA, LweModOp: ArithmeticOps + VectorOps, NttOp: Ntt, { + type M = M; type Modulus = CiphertextModulus; - type Element = M::MatElement; type D = DefaultDecomposer; type RlweModOp = RlweModOp; type LweModOp = LweModOp; @@ -291,7 +294,7 @@ where } } -pub(crate) struct BoolEvaluator +pub(crate) struct BoolEvaluator where M: Matrix, { @@ -316,7 +319,7 @@ impl BoolEvaluator BoolEvaluator +impl BoolEvaluator where M: MatrixEntity + MatrixMut, M::MatElement: PrimInt @@ -330,7 +333,8 @@ where NttOp: Ntt, RlweModOp: ArithmeticOps + VectorOps - + GetModulus>, + + GetModulus> + + ShoupMatrixFMA, LweModOp: ArithmeticOps + VectorOps + GetModulus>, @@ -1083,12 +1087,8 @@ where M: MatrixMut + MatrixEntity, M::R: RowMut + RowEntity, M::MatElement: PrimInt + FromPrimitive + One + Copy + Zero + Display + WrappingSub + NumInfo, - RlweModOp: VectorOps - + ArithmeticOps - + GetModulus>, - LweModOp: VectorOps - + ArithmeticOps - + GetModulus>, + RlweModOp: VectorOps + ArithmeticOps, + LweModOp: VectorOps + ArithmeticOps, NttOp: Ntt, { /// Returns c0 + c1 + Q/4 @@ -1118,14 +1118,12 @@ where PrimInt + FromPrimitive + One + Copy + Zero + Display + WrappingSub + NumInfo + From, RlweModOp: VectorOps + ArithmeticOps - + GetModulus>, - LweModOp: VectorOps - + ArithmeticOps - + GetModulus>, + + ShoupMatrixFMA, + LweModOp: VectorOps + ArithmeticOps, NttOp: Ntt, { type Ciphertext = M::R; - type Key = ServerKeyEvaluationDomain, DefaultSecureRng, NttOp>; + type Key = Key; fn nand_inplace(&mut self, c0: &mut M::R, c1: &M::R, server_key: &Self::Key) { self._add_and_shift_lwe_cts(c0, c1); @@ -1307,7 +1305,7 @@ where // self, measure_noise, public_key_encrypt_rlwe, // secret_key_encrypt_rlwe, tests::{_measure_noise_rgsw, // _sk_encrypt_rlwe}, RgswCiphertext, -// RgswCiphertextEvaluationDomain, SeededRgswCiphertext, +// RgswCiphertextEvaluationDomain, SeededRgswCiphertext, // SeededRlweCiphertext, }, // utils::{negacyclic_mul, Stats}, // }; @@ -1439,7 +1437,7 @@ where // let public_key_share = parties // .iter() // .map(|k| -// bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) +// bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) // .collect_vec(); // let collective_pk = PublicKey::< @@ -1559,7 +1557,7 @@ where // &collective_pk.key(), k) }) // .collect_vec(); // let seeded_server_key = -// +// // bool_evaluator.aggregate_multi_party_server_key_shares(&server_key_shares); // let server_key_eval = ServerKeyEvaluationDomain::<_, // DefaultSecureRng, NttBackendU64>::from( &seeded_server_key, @@ -1570,7 +1568,7 @@ where // let mut ideal_rlwe_sk = vec![0i32; // bool_evaluator.pbs_info.rlwe_n()]; parties.iter().for_each(|k| { // izip!(ideal_rlwe_sk.iter_mut(), -// k.sk_rlwe().values()).for_each(|(ideal_i, s_i)| { +// k.sk_rlwe().values()).for_each(|(ideal_i, s_i)| { // *ideal_i = *ideal_i + s_i; }); // }); // let mut ideal_lwe_sk = vec![0i32; @@ -1628,7 +1626,7 @@ where // let decryption_shares = parties // .iter() // .map(|k| -// bool_evaluator.multi_party_decryption_share(&lwe_out, k)) +// bool_evaluator.multi_party_decryption_share(&lwe_out, k)) // .collect_vec(); let m_back = // bool_evaluator.multi_party_decrypt(&decryption_shares, &lwe_out); @@ -1687,7 +1685,7 @@ where // let mut ideal_rlwe_sk = vec![0i32; // bool_evaluator.pbs_info.rlwe_n()]; parties.iter().for_each(|k| { // izip!(ideal_rlwe_sk.iter_mut(), -// k.sk_rlwe().values()).for_each(|(ideal_i, s_i)| { +// k.sk_rlwe().values()).for_each(|(ideal_i, s_i)| { // *ideal_i = *ideal_i + s_i; }); // }); // let mut ideal_lwe_sk = vec![0i32; @@ -1718,7 +1716,7 @@ where // let public_key_share = parties // .iter() // .map(|k| -// bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) +// bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) // .collect_vec(); let collective_pk = PublicKey::< // Vec>, // DefaultSecureRng, @@ -1763,7 +1761,7 @@ where // let public_key_share = parties // .iter() // .map(|k| -// bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) +// bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) // .collect_vec(); let collective_pk = PublicKey::< // Vec>, // DefaultSecureRng, @@ -1780,7 +1778,7 @@ where // .collect_vec(); // let seeded_server_key = -// +// // bool_evaluator.aggregate_multi_party_server_key_shares(&server_key_shares); // // Check noise in RGSW ciphertexts of ideal LWE secret elements @@ -1802,21 +1800,21 @@ where // // RLWE'(-sm) // let mut neg_s_eval = -// +// // Vec::::try_convert_from(ideal_client_key.sk_rlwe().values(), rlwe_q); // rlwe_modop.elwise_neg_mut(&mut neg_s_eval); // rlwe_nttop.forward(&mut neg_s_eval); // for j in -// 0..rlwe_rgsw_decomposer.a().decomposition_count() { +// 0..rlwe_rgsw_decomposer.a().decomposition_count() { // // RLWE(B^{j} * -s[X]*X^{s_lwe[i]}) // // -s[X]*X^{s_lwe[i]}*B_j // let mut m_ideal = m_si.clone(); // rlwe_nttop.forward(m_ideal.as_mut_slice()); // rlwe_modop.elwise_mul_mut(m_ideal.as_mut_slice(), -// neg_s_eval.as_slice()); -// rlwe_nttop.backward(m_ideal.as_mut_slice()); -// rlwe_modop +// neg_s_eval.as_slice()); +// rlwe_nttop.backward(m_ideal.as_mut_slice()); +// rlwe_modop // .elwise_scalar_mul_mut(m_ideal.as_mut_slice(), &rlwe_rgsw_gadget_a[j]); // // RLWE(-s*X^{s_lwe[i]}*B_j) @@ -1842,7 +1840,7 @@ where // // RLWE'(m) // for j in -// 0..rlwe_rgsw_decomposer.b().decomposition_count() { +// 0..rlwe_rgsw_decomposer.b().decomposition_count() { // // RLWE(B^{j} * X^{s_lwe[i]}) // // X^{s_lwe[i]}*B_j @@ -1959,7 +1957,7 @@ where // ); // rlwe_nttop.forward(m_plus_e_times_m1.as_mut_slice()); // rlwe_nttop.forward(m1.as_mut_slice()); -// +// // rlwe_modop.elwise_mul_mut(m_plus_e_times_m1.as_mut_slice(), m1.as_slice()); // rlwe_nttop.backward(m_plus_e_times_m1.as_mut_slice()); @@ -2010,7 +2008,7 @@ where // let mut check = Stats { samples: vec![] }; // let mut neg_s_poly = -// +// // Vec::::try_convert_from(ideal_client_key.sk_rlwe().values(), rlwe_q); // rlwe_modop.elwise_neg_mut(neg_s_poly.as_mut_slice()); @@ -2045,7 +2043,7 @@ where // auto_gadget.iter().enumerate().for_each(|(i, b_i)| { // // B^i * -s[X^k] // let mut m_ideal = neg_s_poly_auto_i.clone(); -// +// // rlwe_modop.elwise_scalar_mul_mut(m_ideal.as_mut_slice(), b_i); // let mut m_out = vec![0u64; rlwe_n]; @@ -2053,14 +2051,8 @@ where // rlwe_ct[0].copy_from_slice(&auto_key_i[i]); // rlwe_ct[1].copy_from_slice( // &auto_key_i[auto_decomposer.decomposition_count() -// + i], ); -// decrypt_rlwe( -// &rlwe_ct, -// ideal_client_key.sk_rlwe().values(), -// &mut m_out, -// rlwe_nttop, -// rlwe_modop, -// ); +// + i], ); decrypt_rlwe( &rlwe_ct, +// ideal_client_key.sk_rlwe().values(), &mut m_out, rlwe_nttop, rlwe_modop, ); // // diff // rlwe_modop.elwise_sub_mut(m_out.as_mut_slice(), @@ -2111,10 +2103,10 @@ where // let auto_key = // server_key_eval_domain.galois_key_for_auto(i); let -// (auto_map_index, auto_map_sign) = +// (auto_map_index, auto_map_sign) = // bool_evaluator.pbs_info.rlwe_auto_map(i); let mut // scratch = vec![vec![0u64; rlwe_n]; -// auto_decomposer.decomposition_count() + 2]; +// auto_decomposer.decomposition_count() + 2]; // galois_auto( &mut rlwe_ct, // auto_key, // &mut scratch, @@ -2149,7 +2141,7 @@ where // rlwe_modop.elwise_sub_mut(m_out.as_mut_slice(), // m_plus_e_auto.as_slice()); -// +// // check.add_more(&Vec::::try_convert_from(m_out.as_slice(), rlwe_q)); // } // } diff --git a/src/bool/keys.rs b/src/bool/keys.rs index e6994cc..bd90b2c 100644 --- a/src/bool/keys.rs +++ b/src/bool/keys.rs @@ -3,6 +3,7 @@ use std::{collections::HashMap, hash::Hash, marker::PhantomData}; use crate::{ backend::{ModInit, VectorOps}, lwe::LweSecret, + pbs::WithShoupRepr, random::{NewWithSeed, RandomFillUniformInModulus}, rgsw::RlweSecret, utils::WithLocal, @@ -682,6 +683,19 @@ pub(crate) struct ShoupServerKeyEvaluationDomain { pub(crate) struct NormalAndShoup(M, M); +impl AsRef for NormalAndShoup { + fn as_ref(&self) -> &M { + &self.0 + } +} + +impl WithShoupRepr for NormalAndShoup { + type M = M; + fn shoup_repr(&self) -> &Self::M { + &self.1 + } +} + mod shoup_server_key_eval_domain { use crate::pbs::PbsKey; diff --git a/src/bool/mod.rs b/src/bool/mod.rs index 2f936cd..fb5ab6f 100644 --- a/src/bool/mod.rs +++ b/src/bool/mod.rs @@ -23,7 +23,12 @@ thread_local! { } static BOOL_SERVER_KEY: OnceLock< - ServerKeyEvaluationDomain>, BoolParameters, DefaultSecureRng, NttBackendU64>, + ShoupServerKeyEvaluationDomain< + Vec>, + BoolParameters, + DefaultSecureRng, + NttBackendU64, + >, > = OnceLock::new(); static MULTI_PARTY_CRS: OnceLock> = OnceLock::new(); @@ -40,7 +45,7 @@ pub fn set_mp_seed(seed: [u8; 32]) { } fn set_server_key( - key: ServerKeyEvaluationDomain< + key: ShoupServerKeyEvaluationDomain< Vec>, BoolParameters, DefaultSecureRng, @@ -140,7 +145,7 @@ impl } impl Global - for ServerKeyEvaluationDomain< + for ShoupServerKeyEvaluationDomain< Vec>, BoolParameters, DefaultSecureRng, diff --git a/src/pbs.rs b/src/pbs.rs index eb49f97..3a0b256 100644 --- a/src/pbs.rs +++ b/src/pbs.rs @@ -3,12 +3,14 @@ use std::{fmt::Display, marker::PhantomData}; use num_traits::{FromPrimitive, One, PrimInt, ToPrimitive, Zero}; use crate::{ - backend::{ArithmeticOps, Modulus, VectorOps}, + backend::{ArithmeticOps, Modulus, ShoupMatrixFMA, VectorOps}, decomposer::Decomposer, lwe::lwe_key_switch, ntt::Ntt, random::DefaultSecureRng, - rgsw::{galois_auto, rlwe_by_rgsw, rlwe_by_rgsw_shoup, IsTrivial, RlweCiphertext}, + rgsw::{ + galois_auto, galois_auto_shoup, rlwe_by_rgsw, rlwe_by_rgsw_shoup, IsTrivial, RlweCiphertext, + }, Matrix, MatrixEntity, MatrixMut, RowMut, }; pub(crate) trait PbsKey { @@ -24,22 +26,24 @@ pub(crate) trait PbsKey { fn lwe_ksk(&self) -> &Self::LweKskKey; } -trait WithShoupRepr: AsRef { +pub(crate) trait WithShoupRepr: AsRef { type M; - fn shoup_repr(&self) -> Self::M; + fn shoup_repr(&self) -> &Self::M; } pub(crate) trait PbsInfo { - type Element; - type Modulus: Modulus; - type NttOp: Ntt; - type D: Decomposer; + type M: Matrix; + type Modulus: Modulus::MatElement>; + type NttOp: Ntt::MatElement>; + type D: Decomposer::MatElement>; // Although both types have same bounds, they can be different types. For ex, // type RlweModOp may only support native modulus, where LweModOp may only // support prime modulus, etc. - type RlweModOp: VectorOps + ArithmeticOps; - type LweModOp: VectorOps + ArithmeticOps; + type RlweModOp: ArithmeticOps::MatElement> + + ShoupMatrixFMA<::R>; + type LweModOp: VectorOps::MatElement> + + ArithmeticOps::MatElement>; fn rlwe_q(&self) -> &Self::Modulus; fn lwe_q(&self) -> &Self::Modulus; @@ -79,8 +83,9 @@ pub(crate) trait PbsInfo { /// - blind rotate pub(crate) fn pbs< M: MatrixMut + MatrixEntity, - P: PbsInfo, - K: PbsKey, + MShoup: WithShoupRepr, + P: PbsInfo, + K: PbsKey, >( pbs_info: &P, test_vec: &M::R, @@ -217,10 +222,10 @@ fn blind_rotation< Mmut: MatrixMut, D: Decomposer, NttOp: Ntt, - ModOp: ArithmeticOps + VectorOps, + ModOp: ArithmeticOps + ShoupMatrixFMA, MShoup: WithShoupRepr, K: PbsKey, - P: PbsInfo, + P: PbsInfo, >( trivial_rlwe_test_poly: &mut MT, scratch_matrix: &mut Mmut, @@ -249,19 +254,11 @@ fn blind_rotation< s_indices.iter().for_each(|s_index| { // let new = std::time::Instant::now(); - // rlwe_by_rgsw( - // trivial_rlwe_test_poly, - // pbs_key.rgsw_ct_lwe_si(*s_index), - // scratch_matrix, - // rlwe_rgsw_decomposer, - // ntt_op, - // mod_op, - // ); let ct = pbs_key.rgsw_ct_lwe_si(*s_index); rlwe_by_rgsw_shoup( trivial_rlwe_test_poly, ct.as_ref(), - &ct.shoup_repr(), + ct.shoup_repr(), scratch_matrix, rlwe_rgsw_decomposer, ntt_op, @@ -275,9 +272,11 @@ fn blind_rotation< let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(v); // let now = std::time::Instant::now(); - galois_auto( + let auto_key = pbs_key.galois_key_for_auto(v); + galois_auto_shoup( trivial_rlwe_test_poly, - pbs_key.galois_key_for_auto(v), + auto_key.as_ref(), + auto_key.shoup_repr(), scratch_matrix, &auto_map_index, &auto_map_sign, @@ -293,37 +292,46 @@ fn blind_rotation< } // -(g^0) - gk_to_si[q_by_4].iter().for_each(|s_index| { - rlwe_by_rgsw( + { + gk_to_si[q_by_4].iter().for_each(|s_index| { + let ct = pbs_key.rgsw_ct_lwe_si(*s_index); + rlwe_by_rgsw_shoup( + trivial_rlwe_test_poly, + ct.as_ref(), + ct.shoup_repr(), + scratch_matrix, + rlwe_rgsw_decomposer, + ntt_op, + mod_op, + ); + }); + + let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(0); + let auto_key = pbs_key.galois_key_for_auto(0); + galois_auto_shoup( trivial_rlwe_test_poly, - pbs_key.rgsw_ct_lwe_si(*s_index), + auto_key.as_ref(), + auto_key.shoup_repr(), scratch_matrix, - rlwe_rgsw_decomposer, - ntt_op, + &auto_map_index, + &auto_map_sign, mod_op, + ntt_op, + auto_decomposer, ); - }); - let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(0); - galois_auto( - trivial_rlwe_test_poly, - pbs_key.galois_key_for_auto(0), - scratch_matrix, - &auto_map_index, - &auto_map_sign, - mod_op, - ntt_op, - auto_decomposer, - ); - count += 1; + count += 1; + } // +(g^k) let mut v = 0; for i in (1..q_by_4).rev() { let s_indices = &gk_to_si[i]; s_indices.iter().for_each(|s_index| { - rlwe_by_rgsw( + let ct = pbs_key.rgsw_ct_lwe_si(*s_index); + rlwe_by_rgsw_shoup( trivial_rlwe_test_poly, - pbs_key.rgsw_ct_lwe_si(*s_index), + ct.as_ref(), + ct.shoup_repr(), scratch_matrix, rlwe_rgsw_decomposer, ntt_op, @@ -334,9 +342,11 @@ fn blind_rotation< if gk_to_si[i - 1].len() != 0 || v == w || i == 1 { let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(v); - galois_auto( + let auto_key = pbs_key.galois_key_for_auto(v); + galois_auto_shoup( trivial_rlwe_test_poly, - pbs_key.galois_key_for_auto(v), + auto_key.as_ref(), + auto_key.shoup_repr(), scratch_matrix, &auto_map_index, &auto_map_sign, @@ -351,9 +361,11 @@ fn blind_rotation< // +(g^0) gk_to_si[0].iter().for_each(|s_index| { - rlwe_by_rgsw( + let ct = pbs_key.rgsw_ct_lwe_si(*s_index); + rlwe_by_rgsw_shoup( trivial_rlwe_test_poly, - pbs_key.rgsw_ct_lwe_si(*s_index), + ct.as_ref(), + ct.shoup_repr(), scratch_matrix, rlwe_rgsw_decomposer, ntt_op, diff --git a/src/rgsw/runtime.rs b/src/rgsw/runtime.rs index 77bbe16..5a48fbf 100644 --- a/src/rgsw/runtime.rs +++ b/src/rgsw/runtime.rs @@ -185,7 +185,7 @@ pub(crate) fn galois_auto_shoup< MT: Matrix + IsTrivial + MatrixMut, Mmut: MatrixMut, ModOp: ArithmeticOps - + VectorOps + // + VectorOps + ShoupMatrixFMA, NttOp: Ntt, D: Decomposer, @@ -422,7 +422,7 @@ pub(crate) fn rlwe_by_rgsw_shoup< Mmut: MatrixMut, MT: Matrix + MatrixMut + IsTrivial, D: RlweDecomposer, - ModOp: VectorOps + ShoupMatrixFMA, + ModOp: ShoupMatrixFMA, NttOp: Ntt, >( rlwe_in: &mut MT, diff --git a/src/shortint/mod.rs b/src/shortint/mod.rs index 5fd8b9c..1d0f0fa 100644 --- a/src/shortint/mod.rs +++ b/src/shortint/mod.rs @@ -97,12 +97,18 @@ mod frontend { eight_bit_mul, }; use crate::{ - bool::{evaluator::BoolEvaluator, keys::ServerKeyEvaluationDomain}, + bool::{ + evaluator::{self, BoolEvaluator, BooleanGates}, + keys::{ServerKeyEvaluationDomain, ShoupServerKeyEvaluationDomain}, + }, utils::{Global, WithLocal}, }; use super::FheUint8; + type ShortIntBoolEvaluator = + BoolEvaluator; + mod arithetic { use crate::bool::{evaluator::BooleanGates, FheBool}; @@ -111,8 +117,8 @@ mod frontend { impl AddAssign<&FheUint8> for FheUint8 { fn add_assign(&mut self, rhs: &FheUint8) { - BoolEvaluator::with_local_mut_mut(&mut |e| { - let key = ServerKeyEvaluationDomain::global(); + ShortIntBoolEvaluator::with_local_mut_mut(&mut |e| { + let key = as BooleanGates>::Key::global(); arbitrary_bit_adder(e, self.data_mut(), rhs.data(), false, key); }); }