diff --git a/benches/modulus.rs b/benches/modulus.rs index 56429e2..eade8c9 100644 --- a/benches/modulus.rs +++ b/benches/modulus.rs @@ -1,4 +1,7 @@ -use bin_rs::{ArithmeticOps, Decomposer, DefaultDecomposer, ModInit, ModularOpsU64, VectorOps}; +use bin_rs::{ + ArithmeticLazyOps, ArithmeticOps, Decomposer, DefaultDecomposer, ModInit, ModularOpsU64, + ShoupMatrixFMA, VectorOps, +}; use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; use itertools::{izip, Itertools}; use rand::{thread_rng, Rng}; @@ -21,12 +24,9 @@ fn decompose_r(r: &[u64], decomp_r: &mut [Vec], decomposer: &DefaultDecompo } fn matrix_fma(out: &mut [u64], a: &Vec>, b: &Vec>, modop: &ModularOpsU64) { - izip!(out.iter_mut(), a[0].iter(), b[0].iter()) - .for_each(|(o, ai, bi)| *o = modop.add(o, &modop.mul_lazy(ai, bi))); - - izip!(a.iter().skip(1), b.iter().skip(1)).for_each(|(a_r, b_r)| { + izip!(a.iter(), b.iter()).for_each(|(a_r, b_r)| { izip!(out.iter_mut(), a_r.iter(), b_r.iter()) - .for_each(|(o, ai, bi)| *o = modop.add_lazy(o, &modop.mul(ai, bi))); + .for_each(|(o, ai, bi)| *o = modop.add_lazy(o, &modop.mul_lazy(ai, bi))); }); } @@ -127,7 +127,7 @@ fn benchmark(c: &mut Criterion) { b.iter_batched_ref( || (vec![0u64; ring_size]), |(out)| { - black_box(modop.shoup_fma( + black_box(modop.shoup_matrix_fma( out, &a0_matrix, &a0_shoup_matrix, diff --git a/src/backend/mod.rs b/src/backend/mod.rs index 097aab4..c2d78b5 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -127,6 +127,7 @@ pub trait ArithmeticLazyOps { } pub trait ShoupMatrixFMA { - /// Returns summation of `row-wise product of matrix a and b` + out. + /// Returns summation of `row-wise product of matrix a and b` + out where + /// each element is in range [0, 2q) fn shoup_matrix_fma(&self, out: &mut [R::Element], a: &[R], a_shoup: &[R], b: &[R]); } diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index 77be7eb..f5211d0 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -22,7 +22,7 @@ use crate::{ lwe::{decrypt_lwe, encrypt_lwe, lwe_key_switch, lwe_ksk_keygen, measure_noise_lwe, LweSecret}, multi_party::public_key_share, ntt::{self, Ntt, NttBackendU64, NttInit}, - pbs::{pbs, sample_extract, PbsInfo, PbsKey}, + pbs::{pbs, sample_extract, PbsInfo, PbsKey, WithShoupRepr}, random::{ DefaultSecureRng, NewWithSeed, RandomFill, RandomFillGaussianInModulus, RandomFillUniformInModulus, RandomGaussianElementInModulus, @@ -81,7 +81,7 @@ impl MultiPartyCrs { pub(crate) trait BooleanGates { type Ciphertext: RowEntity; - type Key: Global; + type Key; fn and_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); fn nand_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); @@ -294,7 +294,7 @@ where } } -pub(crate) struct BoolEvaluator +pub(crate) struct BoolEvaluator where M: Matrix, { @@ -306,10 +306,12 @@ where nor_test_vec: M::R, xor_test_vec: M::R, xnor_test_vec: M::R, - _phantom: PhantomData, + _phantom: PhantomData, } -impl BoolEvaluator { +impl + BoolEvaluator +{ pub(super) fn parameters(&self) -> &BoolParameters { &self.pbs_info.parameters } @@ -478,7 +480,7 @@ where ClientKey::new(sk_rlwe, sk_lwe) } - pub(super) fn server_key( + pub(super) fn single_party_server_key( &self, client_key: &ClientKey, ) -> SeededServerKey, [u8; 32]> { @@ -1082,7 +1084,7 @@ where } } -impl BoolEvaluator +impl BoolEvaluator where M: MatrixMut + MatrixEntity, M::R: RowMut + RowEntity, @@ -1110,7 +1112,8 @@ where } } -impl BooleanGates for BoolEvaluator +impl BooleanGates + for BoolEvaluator where M: MatrixMut + MatrixEntity, M::R: RowMut + RowEntity + Clone, @@ -1121,9 +1124,11 @@ where + ShoupMatrixFMA, LweModOp: VectorOps + ArithmeticOps, NttOp: Ntt, + Skey: PbsKey::RgswCt, LweKskKey = M>, + ::RgswCt: WithShoupRepr, { type Ciphertext = M::R; - type Key = Key; + type Key = Skey; fn nand_inplace(&mut self, c0: &mut M::R, c1: &M::R, server_key: &Self::Key) { self._add_and_shift_lwe_cts(c0, c1); @@ -1287,1108 +1292,1106 @@ where } } -// #[cfg(test)] -// mod tests { -// use bool::parameters::{MP_BOOL_PARAMS, SP_BOOL_PARAMS}; -// use rand::{thread_rng, Rng}; -// use rand_distr::Uniform; - -// use crate::{ -// backend::{GetModulus, ModInit, ModularOpsU64, WordSizeModulus}, -// bool::{ -// self, CommonReferenceSeededMultiPartyServerKeyShare, PublicKey, -// SeededMultiPartyServerKey, -// }, -// ntt::NttBackendU64, -// random::{RandomElementInModulus, DEFAULT_RNG}, -// rgsw::{ -// self, measure_noise, public_key_encrypt_rlwe, -// secret_key_encrypt_rlwe, tests::{_measure_noise_rgsw, -// _sk_encrypt_rlwe}, RgswCiphertext, -// RgswCiphertextEvaluationDomain, SeededRgswCiphertext, -// SeededRlweCiphertext, }, -// utils::{negacyclic_mul, Stats}, -// }; - -// use super::*; - -// #[test] -// fn bool_encrypt_decrypt_works() { -// let bool_evaluator = BoolEvaluator::< -// Vec>, -// NttBackendU64, -// ModularOpsU64>, -// ModularOpsU64>, -// >::new(SP_BOOL_PARAMS); -// let client_key = bool_evaluator.client_key(); - -// let mut m = true; -// for _ in 0..1000 { -// let lwe_ct = bool_evaluator.sk_encrypt(m, &client_key); -// let m_back = bool_evaluator.sk_decrypt(&lwe_ct, &client_key); -// assert_eq!(m, m_back); -// m = !m; -// } -// } - -// #[test] -// fn bool_nand() { -// let mut bool_evaluator = BoolEvaluator::< -// Vec>, -// NttBackendU64, -// ModularOpsU64>, -// ModularOpsU64>, -// >::new(SP_BOOL_PARAMS); - -// // println!("{:?}", bool_evaluator.nand_test_vec); -// let client_key = bool_evaluator.client_key(); -// let seeded_server_key = bool_evaluator.server_key(&client_key); -// let server_key_eval_domain = -// ServerKeyEvaluationDomain::<_, DefaultSecureRng, -// NttBackendU64>::from( &seeded_server_key, -// ); - -// let mut m0 = false; -// let mut m1 = true; - -// let mut ct0 = bool_evaluator.sk_encrypt(m0, &client_key); -// let mut ct1 = bool_evaluator.sk_encrypt(m1, &client_key); - -// for _ in 0..500 { -// let ct_back = bool_evaluator.nand(&ct0, &ct1, -// &server_key_eval_domain); - -// let m_out = !(m0 && m1); - -// let m_back = bool_evaluator.sk_decrypt(&ct_back, &client_key); -// assert!(m_out == m_back, "Expected {m_out}, got {m_back}"); - -// m1 = m0; -// m0 = m_out; - -// ct1 = ct0; -// ct0 = ct_back; -// } -// } - -// #[test] -// fn bool_xor() { -// let mut bool_evaluator = BoolEvaluator::< -// Vec>, -// NttBackendU64, -// ModularOpsU64>, -// ModularOpsU64>, -// >::new(SP_BOOL_PARAMS); - -// // println!("{:?}", bool_evaluator.nand_test_vec); -// let client_key = bool_evaluator.client_key(); -// let seeded_server_key = bool_evaluator.server_key(&client_key); -// let server_key_eval_domain = -// ServerKeyEvaluationDomain::<_, DefaultSecureRng, -// NttBackendU64>::from( &seeded_server_key, -// ); - -// let mut m0 = false; -// let mut m1 = true; - -// let mut ct0 = bool_evaluator.sk_encrypt(m0, &client_key); -// let mut ct1 = bool_evaluator.sk_encrypt(m1, &client_key); - -// for _ in 0..1000 { -// let ct_back = bool_evaluator.xor(&ct0, &ct1, -// &server_key_eval_domain); let m_out = (m0 ^ m1); - -// let m_back = bool_evaluator.sk_decrypt(&ct_back, &client_key); -// assert!(m_out == m_back, "Expected {m_out}, got {m_back}"); - -// m1 = m0; -// m0 = m_out; - -// ct1 = ct0; -// ct0 = ct_back; -// } -// } - -// #[test] -// fn multi_party_encryption_decryption() { -// let bool_evaluator = BoolEvaluator::< -// Vec>, -// NttBackendU64, -// ModularOpsU64>, -// ModularOpsU64>, -// >::new(MP_BOOL_PARAMS); - -// let no_of_parties = 500; -// let parties = (0..no_of_parties) -// .map(|_| bool_evaluator.client_key()) -// .collect_vec(); - -// let mut ideal_rlwe_sk = vec![0i32; bool_evaluator.pbs_info.rlwe_n()]; -// parties.iter().for_each(|k| { -// izip!(ideal_rlwe_sk.iter_mut(), -// k.sk_rlwe().values()).for_each(|(ideal_i, s_i)| { *ideal_i = -// *ideal_i + s_i; }); -// }); - -// let mut m = true; -// for i in 0..100 { -// let pk_cr_seed = [0u8; 32]; - -// let public_key_share = parties -// .iter() -// .map(|k| -// bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) -// .collect_vec(); - -// let collective_pk = PublicKey::< -// Vec>, -// DefaultSecureRng, -// ModularOpsU64>, -// >::from(public_key_share.as_slice()); -// let lwe_ct = bool_evaluator.pk_encrypt(collective_pk.key(), m); - -// let decryption_shares = parties -// .iter() -// .map(|k| bool_evaluator.multi_party_decryption_share(&lwe_ct, -// k)) .collect_vec(); - -// let m_back = -// bool_evaluator.multi_party_decrypt(&decryption_shares, &lwe_ct); - -// { -// let ideal_m = if m { -// bool_evaluator.pbs_info.parameters.rlwe_q().true_el() -// } else { -// bool_evaluator.pbs_info.parameters.rlwe_q().false_el() -// }; -// let noise = measure_noise_lwe( -// &lwe_ct, -// &ideal_rlwe_sk, -// &bool_evaluator.pbs_info.rlwe_modop, -// &ideal_m, -// ); -// println!("Noise: {noise}"); -// } - -// assert_eq!(m_back, m); -// m = !m; -// } -// } - -// fn _collecitve_public_key_gen(rlwe_q: u64, parties_rlwe_sk: -// &[RlweSecret]) -> Vec> { let ring_size = -// parties_rlwe_sk[0].values.len(); assert!(ring_size. -// is_power_of_two()); let mut rng = DefaultSecureRng::new(); -// let nttop = NttBackendU64::new(&rlwe_q, ring_size); -// let modop = ModularOpsU64::new(rlwe_q); - -// // Generate Pk shares -// let pk_seed = [0u8; 32]; -// let pk_shares = parties_rlwe_sk.iter().map(|sk| { -// let mut p_rng = DefaultSecureRng::new_seeded(pk_seed); -// let mut share_out = vec![0u64; ring_size]; -// public_key_share( -// &mut share_out, -// sk.values(), -// &modop, -// &nttop, -// &mut p_rng, -// &mut rng, -// ); -// share_out -// }); - -// let mut pk_part_b = vec![0u64; ring_size]; -// pk_shares.for_each(|share| modop.elwise_add_mut(&mut pk_part_b, -// &share)); let mut pk_part_a = vec![0u64; ring_size]; -// let mut p_rng = DefaultSecureRng::new_seeded(pk_seed); -// RandomFillUniformInModulus::random_fill(&mut p_rng, &rlwe_q, -// pk_part_a.as_mut_slice()); - -// vec![pk_part_a, pk_part_b] -// } - -// fn _multi_party_all_keygen( -// bool_evaluator: &BoolEvaluator< -// Vec>, -// NttBackendU64, -// ModularOpsU64>, -// ModularOpsU64>, -// >, -// no_of_parties: usize, -// ) -> ( -// Vec, -// PublicKey>, DefaultSecureRng, -// ModularOpsU64>>, Vec< -// CommonReferenceSeededMultiPartyServerKeyShare< -// Vec>, -// BoolParameters, -// [u8; 32], -// >, -// >, -// SeededMultiPartyServerKey>, [u8; 32], -// BoolParameters>, ServerKeyEvaluationDomain>, -// DefaultSecureRng, NttBackendU64>, ClientKey, -// ) { -// let parties = (0..no_of_parties) -// .map(|_| bool_evaluator.client_key()) -// .collect_vec(); - -// let mut rng = DefaultSecureRng::new(); - -// // Collective public key -// let mut pk_cr_seed = [0u8; 32]; -// rng.fill_bytes(&mut pk_cr_seed); -// let public_key_share = parties -// .iter() -// .map(|k| bool_evaluator.multi_party_public_key_share(pk_cr_seed, -// k)) .collect_vec(); -// let collective_pk = -// PublicKey::>, DefaultSecureRng, -// _>::from(public_key_share.as_slice()); - -// // Server key -// let mut pbs_cr_seed = [0u8; 32]; -// rng.fill_bytes(&mut pbs_cr_seed); -// let server_key_shares = parties -// .iter() -// .map(|k| { -// bool_evaluator.multi_party_server_key_share(pbs_cr_seed, -// &collective_pk.key(), k) }) -// .collect_vec(); -// let seeded_server_key = -// -// bool_evaluator.aggregate_multi_party_server_key_shares(&server_key_shares); -// let server_key_eval = ServerKeyEvaluationDomain::<_, -// DefaultSecureRng, NttBackendU64>::from( &seeded_server_key, -// ); - -// // construct ideal rlwe sk for meauring noise -// let ideal_client_key = { -// let mut ideal_rlwe_sk = vec![0i32; -// bool_evaluator.pbs_info.rlwe_n()]; parties.iter().for_each(|k| { -// izip!(ideal_rlwe_sk.iter_mut(), -// k.sk_rlwe().values()).for_each(|(ideal_i, s_i)| { -// *ideal_i = *ideal_i + s_i; }); -// }); -// let mut ideal_lwe_sk = vec![0i32; -// bool_evaluator.pbs_info.lwe_n()]; parties.iter().for_each(|k| { -// izip!(ideal_lwe_sk.iter_mut(), -// k.sk_lwe().values()).for_each(|(ideal_i, s_i)| { *ideal_i -// = *ideal_i + s_i; }); -// }); - -// ClientKey::new( -// RlweSecret { -// values: ideal_rlwe_sk, -// }, -// LweSecret { -// values: ideal_lwe_sk, -// }, -// ) -// }; - -// ( -// parties, -// collective_pk, -// server_key_shares, -// seeded_server_key, -// server_key_eval, -// ideal_client_key, -// ) -// } - -// #[test] -// fn multi_party_nand() { -// let mut bool_evaluator = BoolEvaluator::< -// Vec>, -// NttBackendU64, -// ModularOpsU64>, -// ModularOpsU64>, -// >::new(MP_BOOL_PARAMS); - -// let (parties, collective_pk, _, _, server_key_eval, ideal_client_key) -// = _multi_party_all_keygen(&bool_evaluator, 2); - -// let mut m0 = true; -// let mut m1 = false; - -// let mut lwe0 = bool_evaluator.pk_encrypt(collective_pk.key(), m0); -// let mut lwe1 = bool_evaluator.pk_encrypt(collective_pk.key(), m1); - -// for _ in 0..2000 { -// let lwe_out = bool_evaluator.nand(&lwe0, &lwe1, -// &server_key_eval); - -// let m_expected = !(m0 & m1); - -// // multi-party decrypt -// let decryption_shares = parties -// .iter() -// .map(|k| -// bool_evaluator.multi_party_decryption_share(&lwe_out, k)) -// .collect_vec(); let m_back = -// bool_evaluator.multi_party_decrypt(&decryption_shares, &lwe_out); - -// // let m_back = bool_evaluator.sk_decrypt(&lwe_out, -// &ideal_client_key); - -// assert!(m_expected == m_back, "Expected {m_expected}, got -// {m_back}"); m1 = m0; -// m0 = m_expected; - -// lwe1 = lwe0; -// lwe0 = lwe_out; -// } -// } - -// #[test] -// fn noise_tester() { -// let bool_evaluator = BoolEvaluator::< -// Vec>, -// NttBackendU64, -// ModularOpsU64>, -// ModularOpsU64>, -// >::new(SP_BOOL_PARAMS); - -// // let (_, collective_pk, _, _, server_key_eval, ideal_client_key) = -// // _multi_party_all_keygen(&bool_evaluator, 20); -// let no_of_parties = 2; -// let lwe_q = bool_evaluator.pbs_info.parameters.lwe_q(); -// let rlwe_q = bool_evaluator.pbs_info.parameters.rlwe_q(); -// let lwe_n = bool_evaluator.pbs_info.parameters.lwe_n().0; -// let rlwe_n = bool_evaluator.pbs_info.parameters.rlwe_n().0; -// let lwe_modop = &bool_evaluator.pbs_info.lwe_modop; -// let rlwe_nttop = &bool_evaluator.pbs_info.rlwe_nttop; -// let rlwe_modop = &bool_evaluator.pbs_info.rlwe_modop; - -// // let rgsw_rgsw_decomposer = &bool_evaluator -// // .pbs_info -// // .parameters -// // .rgsw_rgsw_decomposer::>(); -// // let rgsw_rgsw_gadget_a = rgsw_rgsw_decomposer.0.gadget_vector(); -// // let rgsw_rgsw_gadget_b = rgsw_rgsw_decomposer.1.gadget_vector(); - -// let rlwe_rgsw_decomposer = -// &bool_evaluator.pbs_info.rlwe_rgsw_decomposer; let rlwe_rgsw_gadget_a -// = rlwe_rgsw_decomposer.0.gadget_vector(); let rlwe_rgsw_gadget_b = -// rlwe_rgsw_decomposer.1.gadget_vector(); - -// let auto_decomposer = &bool_evaluator.pbs_info.auto_decomposer; -// let auto_gadget = auto_decomposer.gadget_vector(); - -// let parties = (0..no_of_parties) -// .map(|_| bool_evaluator.client_key()) -// .collect_vec(); - -// let ideal_client_key = { -// let mut ideal_rlwe_sk = vec![0i32; -// bool_evaluator.pbs_info.rlwe_n()]; parties.iter().for_each(|k| { -// izip!(ideal_rlwe_sk.iter_mut(), -// k.sk_rlwe().values()).for_each(|(ideal_i, s_i)| { -// *ideal_i = *ideal_i + s_i; }); -// }); -// let mut ideal_lwe_sk = vec![0i32; -// bool_evaluator.pbs_info.lwe_n()]; parties.iter().for_each(|k| { -// izip!(ideal_lwe_sk.iter_mut(), -// k.sk_lwe().values()).for_each(|(ideal_i, s_i)| { *ideal_i -// = *ideal_i + s_i; }); -// }); - -// ClientKey::new( -// RlweSecret { -// values: ideal_rlwe_sk, -// }, -// LweSecret { -// values: ideal_lwe_sk, -// }, -// ) -// }; - -// // check noise in freshly encrypted RLWE ciphertext (ie var_fresh) -// if true { -// let mut rng = DefaultSecureRng::new(); -// let mut check = Stats { samples: vec![] }; -// for _ in 0..10 { -// // generate a new collective public key -// let mut pk_cr_seed = [0u8; 32]; -// rng.fill_bytes(&mut pk_cr_seed); -// let public_key_share = parties -// .iter() -// .map(|k| -// bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) -// .collect_vec(); let collective_pk = PublicKey::< -// Vec>, -// DefaultSecureRng, -// ModularOpsU64>, -// >::from(public_key_share.as_slice()); - -// let mut m = vec![0u64; rlwe_n]; -// RandomFillUniformInModulus::random_fill(&mut rng, rlwe_q, -// m.as_mut_slice()); let mut rlwe_ct = vec![vec![0u64; rlwe_n]; -// 2]; public_key_encrypt_rlwe::<_, _, _, _, i32, _>( -// &mut rlwe_ct, -// collective_pk.key(), -// &m, -// rlwe_modop, -// rlwe_nttop, -// &mut rng, -// ); - -// let mut m_back = vec![0u64; rlwe_n]; -// decrypt_rlwe( -// &rlwe_ct, -// ideal_client_key.sk_rlwe().values(), -// &mut m_back, -// rlwe_nttop, -// rlwe_modop, -// ); - -// rlwe_modop.elwise_sub_mut(m_back.as_mut_slice(), -// m.as_slice()); - -// check.add_more(Vec::::try_convert_from(&m_back, -// rlwe_q).as_slice()); } - -// println!("Public key Std: {}", check.std_dev().abs().log2()); -// } - -// if true { -// // Generate server key shares -// let mut rng = DefaultSecureRng::new(); -// let mut pk_cr_seed = [0u8; 32]; -// rng.fill_bytes(&mut pk_cr_seed); -// let public_key_share = parties -// .iter() -// .map(|k| -// bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) -// .collect_vec(); let collective_pk = PublicKey::< -// Vec>, -// DefaultSecureRng, -// ModularOpsU64>, -// >::from(public_key_share.as_slice()); - -// let pbs_cr_seed = [0u8; 32]; -// rng.fill_bytes(&mut pk_cr_seed); -// let server_key_shares = parties -// .iter() -// .map(|k| { -// bool_evaluator.multi_party_server_key_share(pbs_cr_seed, -// collective_pk.key(), k) }) -// .collect_vec(); - -// let seeded_server_key = -// -// bool_evaluator.aggregate_multi_party_server_key_shares(&server_key_shares); - -// // Check noise in RGSW ciphertexts of ideal LWE secret elements -// if false { -// let mut check = Stats { samples: vec![] }; -// izip!( -// ideal_client_key.sk_lwe().values.iter(), -// seeded_server_key.rgsw_cts().iter() -// ) -// .for_each(|(s_i, rgsw_ct_i)| { -// // X^{s[i]} -// let mut m_si = vec![0u64; rlwe_n]; -// let s_i = *s_i * -// (bool_evaluator.pbs_info.embedding_factor as i32); if s_i -// < 0 { m_si[rlwe_n - (s_i.abs() as usize)] = -// rlwe_q.neg_one(); } else { -// m_si[s_i as usize] = 1; -// } - -// // RLWE'(-sm) -// let mut neg_s_eval = -// -// Vec::::try_convert_from(ideal_client_key.sk_rlwe().values(), rlwe_q); -// rlwe_modop.elwise_neg_mut(&mut neg_s_eval); -// rlwe_nttop.forward(&mut neg_s_eval); -// for j in -// 0..rlwe_rgsw_decomposer.a().decomposition_count() { -// // RLWE(B^{j} * -s[X]*X^{s_lwe[i]}) - -// // -s[X]*X^{s_lwe[i]}*B_j -// let mut m_ideal = m_si.clone(); -// rlwe_nttop.forward(m_ideal.as_mut_slice()); -// rlwe_modop.elwise_mul_mut(m_ideal.as_mut_slice(), -// neg_s_eval.as_slice()); -// rlwe_nttop.backward(m_ideal.as_mut_slice()); -// rlwe_modop -// .elwise_scalar_mul_mut(m_ideal.as_mut_slice(), &rlwe_rgsw_gadget_a[j]); - -// // RLWE(-s*X^{s_lwe[i]}*B_j) -// let mut rlwe_ct = vec![vec![0u64; rlwe_n]; 2]; -// rlwe_ct[0].copy_from_slice(&rgsw_ct_i[j]); -// rlwe_ct[1].copy_from_slice( -// &rgsw_ct_i[j + -// rlwe_rgsw_decomposer.a().decomposition_count()], ); - -// let mut m_back = vec![0u64; rlwe_n]; -// decrypt_rlwe( -// &rlwe_ct, -// ideal_client_key.sk_rlwe().values(), -// &mut m_back, -// rlwe_nttop, -// rlwe_modop, -// ); - -// // diff -// rlwe_modop.elwise_sub_mut(&mut m_back, &m_ideal); -// check.add_more(&Vec::::try_convert_from(&m_back, -// rlwe_q)); } - -// // RLWE'(m) -// for j in -// 0..rlwe_rgsw_decomposer.b().decomposition_count() { -// // RLWE(B^{j} * X^{s_lwe[i]}) - -// // X^{s_lwe[i]}*B_j -// let mut m_ideal = m_si.clone(); -// rlwe_modop -// .elwise_scalar_mul_mut(m_ideal.as_mut_slice(), -// &rlwe_rgsw_gadget_b[j]); - -// // RLWE(X^{s_lwe[i]}*B_j) -// let mut rlwe_ct = vec![vec![0u64; rlwe_n]; 2]; -// rlwe_ct[0].copy_from_slice( -// &rgsw_ct_i[j + (2 * -// rlwe_rgsw_decomposer.a().decomposition_count())], ); -// rlwe_ct[1].copy_from_slice( -// &rgsw_ct_i[j -// + (2 * rlwe_rgsw_decomposer.a(). -// decomposition_count() -// + rlwe_rgsw_decomposer.b(). -// decomposition_count())], -// ); - -// let mut m_back = vec![0u64; rlwe_n]; -// decrypt_rlwe( -// &rlwe_ct, -// ideal_client_key.sk_rlwe().values(), -// &mut m_back, -// rlwe_nttop, -// rlwe_modop, -// ); - -// // diff -// rlwe_modop.elwise_sub_mut(&mut m_back, &m_ideal); -// check.add_more(&Vec::::try_convert_from(&m_back, -// rlwe_q)); } -// }); -// println!( -// "RGSW Std: {} {} ;; max={}", -// check.mean(), -// check.std_dev().abs().log2(), -// check.samples.iter().max().unwrap() -// ); -// } - -// // server key in Evaluation domain -// let server_key_eval_domain = -// ServerKeyEvaluationDomain::<_, DefaultSecureRng, -// NttBackendU64>::from( &seeded_server_key, -// ); - -// // check noise in RLWE x RGSW(X^{s_i}) where RGSW is accunulated -// RGSW ciphertext if false { -// let mut check = Stats { samples: vec![] }; - -// izip!( -// ideal_client_key.sk_lwe().values(), -// server_key_eval_domain.rgsw_cts().iter() -// ) -// .for_each(|(s_i, rgsw_ct_i)| { -// let mut m = vec![0u64; rlwe_n]; -// RandomFillUniformInModulus::random_fill(&mut rng, rlwe_q, -// m.as_mut_slice()); let mut rlwe_ct = vec![vec![0u64; -// rlwe_n]; 2]; public_key_encrypt_rlwe::<_, _, _, _, i32, -// _>( &mut rlwe_ct, -// collective_pk.key(), -// &m, -// rlwe_modop, -// rlwe_nttop, -// &mut rng, -// ); - -// // RLWE(m*X^{s[i]}) = RLWE(m) x RGSW(X^{s[i]}) -// // let mut rlwe_after = RlweCiphertext::<_, -// DefaultSecureRng>::new_trivial(vec![ // vec![0u64; -// rlwe_n], // m.clone(), -// // ]); -// let mut rlwe_after = RlweCiphertext::<_, -// DefaultSecureRng> { data: rlwe_ct.clone(), -// is_trivial: false, -// _phatom: PhantomData, -// }; -// let mut scratch = vec![ -// vec![0u64; rlwe_n]; -// std::cmp::max( -// rlwe_rgsw_decomposer.0.decomposition_count(), -// rlwe_rgsw_decomposer.1.decomposition_count() -// ) + 2 -// ]; -// rlwe_by_rgsw( -// &mut rlwe_after, -// rgsw_ct_i, -// &mut scratch, -// rlwe_rgsw_decomposer, -// rlwe_nttop, -// rlwe_modop, -// ); - -// // m1 = X^{s[i]} -// let mut m1 = vec![0u64; rlwe_n]; -// let s_i = *s_i * -// (bool_evaluator.pbs_info.embedding_factor as i32); if s_i -// < 0 { m1[rlwe_n - (s_i.abs() as usize)] = -// rlwe_q.neg_one() } else { -// m1[s_i as usize] = 1; -// } - -// // (m+e) * m1 -// let mut m_plus_e_times_m1 = vec![0u64; rlwe_n]; -// decrypt_rlwe( -// &rlwe_ct, -// ideal_client_key.sk_rlwe().values(), -// &mut m_plus_e_times_m1, -// rlwe_nttop, -// rlwe_modop, -// ); -// rlwe_nttop.forward(m_plus_e_times_m1.as_mut_slice()); -// rlwe_nttop.forward(m1.as_mut_slice()); -// -// rlwe_modop.elwise_mul_mut(m_plus_e_times_m1.as_mut_slice(), m1.as_slice()); -// rlwe_nttop.backward(m_plus_e_times_m1.as_mut_slice()); - -// // Resulting RLWE ciphertext will equal: (m0m1 + em1) + -// e_{rlsw x rgsw}. // Hence, resulting rlwe ciphertext will -// have error em1 + e_{rlwe x rgsw}. // Here we're only -// concerned with e_{rlwe x rgsw}, that is noise added by // -// RLWExRGSW. Also note in practice m1 is a monomial, for ex, X^{s_{i}}, for -// // some i and var(em1) = var(e). -// let mut m_plus_e_times_m1_more_e = vec![0u64; rlwe_n]; -// decrypt_rlwe( -// &rlwe_after, -// ideal_client_key.sk_rlwe().values(), -// &mut m_plus_e_times_m1_more_e, -// rlwe_nttop, -// rlwe_modop, -// ); - -// // diff -// rlwe_modop.elwise_sub_mut( -// m_plus_e_times_m1_more_e.as_mut_slice(), -// m_plus_e_times_m1.as_slice(), -// ); - -// // let noise = measure_noise( -// // &rlwe_after, -// // &m_plus_e_times_m1, -// // rlwe_nttop, -// // rlwe_modop, -// // ideal_client_key.sk_rlwe.values(), -// // ); -// // print!("NOISE: {}", noise); - -// check.add_more(&Vec::::try_convert_from( -// &m_plus_e_times_m1_more_e, -// rlwe_q, -// )); -// }); -// println!( -// "RLWE x RGSW, where RGSW has noise var_brk, std: {} {}", -// check.std_dev(), -// check.std_dev().abs().log2() -// ) -// } - -// // check noise in Auto key -// if false { -// let mut check = Stats { samples: vec![] }; - -// let mut neg_s_poly = -// -// Vec::::try_convert_from(ideal_client_key.sk_rlwe().values(), rlwe_q); -// rlwe_modop.elwise_neg_mut(neg_s_poly.as_mut_slice()); - -// let g = bool_evaluator.pbs_info.g(); -// let br_q = bool_evaluator.pbs_info.br_q(); -// let auto_element_dlogs = -// bool_evaluator.pbs_info.parameters.auto_element_dlogs(); for -// i in auto_element_dlogs.into_iter() { let g_pow = if i == -// 0 { -g -// } else { -// (((g as usize).pow(i as u32)) % br_q) as isize -// }; - -// // -s[X^k] -// let (auto_indices, auto_sign) = generate_auto_map(rlwe_n, -// g_pow); let mut neg_s_poly_auto_i = vec![0u64; rlwe_n]; -// izip!(neg_s_poly.iter(), auto_indices.iter(), -// auto_sign.iter()).for_each( |(v, to_i, to_sign)| { -// if !to_sign { -// neg_s_poly_auto_i[*to_i] = rlwe_modop.neg(v); -// } else { -// neg_s_poly_auto_i[*to_i] = *v; -// } -// }, -// ); - -// let mut auto_key_i = -// server_key_eval_domain.galois_key_for_auto(i).clone(); // -// send i^th auto key to coefficient domain auto_key_i -// .iter_mut() -// .for_each(|r| rlwe_nttop.backward(r.as_mut_slice())); -// auto_gadget.iter().enumerate().for_each(|(i, b_i)| { -// // B^i * -s[X^k] -// let mut m_ideal = neg_s_poly_auto_i.clone(); -// -// rlwe_modop.elwise_scalar_mul_mut(m_ideal.as_mut_slice(), b_i); - -// let mut m_out = vec![0u64; rlwe_n]; -// let mut rlwe_ct = vec![vec![0u64; rlwe_n]; 2]; -// rlwe_ct[0].copy_from_slice(&auto_key_i[i]); -// rlwe_ct[1].copy_from_slice( -// &auto_key_i[auto_decomposer.decomposition_count() -// + i], ); decrypt_rlwe( &rlwe_ct, -// ideal_client_key.sk_rlwe().values(), &mut m_out, rlwe_nttop, rlwe_modop, ); - -// // diff -// rlwe_modop.elwise_sub_mut(m_out.as_mut_slice(), -// m_ideal.as_slice()); - -// check.add_more(&Vec::::try_convert_from(&m_out, -// rlwe_q)); }); -// } - -// println!("Auto key noise std dev: {}", -// check.std_dev().abs().log2()); } - -// // check noise in RLWE(X^k) after sending RLWE(X) -> RLWE(X^k) -// using collective // auto key -// if true { -// let mut check = Stats { samples: vec![] }; -// let br_q = bool_evaluator.pbs_info.br_q(); -// let g = bool_evaluator.pbs_info.g(); -// let auto_element_dlogs = -// bool_evaluator.pbs_info.parameters.auto_element_dlogs(); for -// i in auto_element_dlogs.into_iter() { for _ in 0..10 { -// let mut m = vec![0u64; rlwe_n]; -// RandomFillUniformInModulus::random_fill(&mut rng, -// rlwe_q, m.as_mut_slice()); let mut rlwe_ct = -// RlweCiphertext::<_, DefaultSecureRng> { data: -// vec![vec![0u64; rlwe_n]; 2], is_trivial: false, -// _phatom: PhantomData, -// }; -// public_key_encrypt_rlwe::<_, _, _, _, i32, _>( -// &mut rlwe_ct, -// collective_pk.key(), -// &m, -// rlwe_modop, -// rlwe_nttop, -// &mut rng, -// ); - -// // We're only interested in noise increased as a -// result of automorphism. // Hence, we take m+e as the -// bench. let mut m_plus_e = vec![0u64; rlwe_n]; -// decrypt_rlwe( -// &rlwe_ct, -// ideal_client_key.sk_rlwe().values(), -// &mut m_plus_e, -// rlwe_nttop, -// rlwe_modop, -// ); - -// let auto_key = -// server_key_eval_domain.galois_key_for_auto(i); let -// (auto_map_index, auto_map_sign) = -// bool_evaluator.pbs_info.rlwe_auto_map(i); let mut -// scratch = vec![vec![0u64; rlwe_n]; -// auto_decomposer.decomposition_count() + 2]; -// galois_auto( &mut rlwe_ct, -// auto_key, -// &mut scratch, -// &auto_map_index, -// &auto_map_sign, -// rlwe_modop, -// rlwe_nttop, -// auto_decomposer, -// ); - -// // send m+e from X to X^k -// let mut m_plus_e_auto = vec![0u64; rlwe_n]; -// izip!(m_plus_e.iter(), auto_map_index.iter(), -// auto_map_sign.iter()) .for_each(|(v, to_index, -// to_sign)| { if !to_sign { -// m_plus_e_auto[*to_index] = -// rlwe_modop.neg(v); } else { -// m_plus_e_auto[*to_index] = *v -// } -// }); - -// let mut m_out = vec![0u64; rlwe_n]; -// decrypt_rlwe( -// &rlwe_ct, -// ideal_client_key.sk_rlwe().values(), -// &mut m_out, -// rlwe_nttop, -// rlwe_modop, -// ); - -// // diff -// rlwe_modop.elwise_sub_mut(m_out.as_mut_slice(), -// m_plus_e_auto.as_slice()); - -// -// check.add_more(&Vec::::try_convert_from(m_out.as_slice(), rlwe_q)); -// } -// } - -// println!("Rlwe Auto Noise Std: {}", -// check.std_dev().abs().log2()); } - -// // Check noise growth in ksk -// // TODO check in LWE key switching keys -// if true { -// // 1. encrypt LWE ciphertext -// // 2. Key switching -// // 3. -// let mut check = Stats { samples: vec![] }; - -// for _ in 0..1024 { -// // Encrypt m \in Q_{ks} using RLWE sk -// let mut lwe_in_ct = vec![0u64; rlwe_n + 1]; -// let m = RandomElementInModulus::random(&mut rng, -// &lwe_q.q().unwrap()); encrypt_lwe( -// &mut lwe_in_ct, -// &m, -// ideal_client_key.sk_rlwe().values(), -// lwe_modop, -// &mut rng, -// ); - -// // Key switch -// let mut lwe_out = vec![0u64; lwe_n + 1]; -// lwe_key_switch( -// &mut lwe_out, -// &lwe_in_ct, -// server_key_eval_domain.lwe_ksk(), -// lwe_modop, -// bool_evaluator.pbs_info.lwe_decomposer(), -// ); - -// // We only care about noise added by LWE key switch -// // m+e -// let m_plus_e = -// decrypt_lwe(&lwe_in_ct, -// ideal_client_key.sk_rlwe().values(), lwe_modop); - -// let m_plus_e_plus_lwe_ksk_noise = -// decrypt_lwe(&lwe_out, -// ideal_client_key.sk_lwe().values(), lwe_modop); - -// let diff = lwe_modop.sub(&m_plus_e_plus_lwe_ksk_noise, -// &m_plus_e); - -// check.add_more(&vec![lwe_q.map_element_to_i64(&diff)]); -// } - -// println!("Lwe ksk std dev: {}", -// check.std_dev().abs().log2()); } -// } - -// // Check noise in fresh RGSW ciphertexts, ie X^{s_j[i]}, must equal -// noise in // fresh RLWE ciphertext -// if true {} -// // test LWE ksk from RLWE -> LWE -// // if false { -// // let logp = 2; -// // let mut rng = DefaultSecureRng::new(); - -// // let m = 1; -// // let encoded_m = m << (lwe_logq - logp); - -// // // Encrypt -// // let mut lwe_ct = vec![0u64; rlwe_n + 1]; -// // encrypt_lwe( -// // &mut lwe_ct, -// // &encoded_m, -// // ideal_client_key.sk_rlwe.values(), -// // lwe_modop, -// // &mut rng, -// // ); - -// // // key switch -// // let lwe_decomposer = &bool_evaluator.decomposer_lwe; -// // let mut lwe_out = vec![0u64; lwe_n + 1]; -// // lwe_key_switch( -// // &mut lwe_out, -// // &lwe_ct, -// // &server_key_eval.lwe_ksk, -// // lwe_modop, -// // lwe_decomposer, -// // ); - -// // let encoded_m_back = decrypt_lwe(&lwe_out, -// // ideal_client_key.sk_lwe.values(), lwe_modop); let m_back -// // = ((encoded_m_back as f64 * (1 << logp) as f64) / -// // (lwe_q as f64)).round() as u64; dbg!(m_back, m); - -// // let noise = measure_noise_lwe( -// // &lwe_out, -// // ideal_client_key.sk_lwe.values(), -// // lwe_modop, -// // &encoded_m, -// // ); - -// // println!("Noise: {noise}"); -// // } - -// // Measure noise in RGSW ciphertexts of ideal LWE secrets -// // if true { -// // let gadget_vec = gadget_vector( -// // bool_evaluator.parameters.rlwe_logq, -// // bool_evaluator.parameters.logb_rgsw, -// // bool_evaluator.parameters.d_rgsw, -// // ); - -// // for i in 0..20 { -// // // measure noise in RGSW(s[i]) -// // let si = -// // ideal_client_key.sk_lwe.values[i] * -// // (bool_evaluator.embedding_factor as i32); let mut -// // si_poly = vec![0u64; rlwe_n]; if si < 0 { -// // si_poly[rlwe_n - (si.abs() as usize)] = rlwe_q - 1; -// // } else { -// // si_poly[(si.abs() as usize)] = 1; -// // } - -// // let mut rgsw_si = server_key_eval.rgsw_cts[i].clone(); -// // rgsw_si -// // .iter_mut() -// // .for_each(|ri| rlwe_nttop.backward(ri.as_mut())); - -// // println!("####### Noise in RGSW(X^s_{i}) #######"); -// // _measure_noise_rgsw( -// // &rgsw_si, -// // &si_poly, -// // ideal_client_key.sk_rlwe.values(), -// // &gadget_vec, -// // rlwe_q, -// // ); -// // println!("####### ##################### #######"); -// // } -// // } - -// // // measure noise grwoth in RLWExRGSW -// // if true { -// // let mut rng = DefaultSecureRng::new(); -// // let mut carry_m = vec![0u64; rlwe_n]; -// // RandomUniformDist1::random_fill(&mut rng, &rlwe_q, -// // carry_m.as_mut_slice()); - -// // // RGSW(carrym) -// // let trivial_rlwect = vec![vec![0u64; rlwe_n], -// carry_m.clone()]; // let mut rlwe_ct = RlweCiphertext::<_, -// // DefaultSecureRng>::from_raw(trivial_rlwect, true); - -// // let mut scratch_matrix_dplus2_ring = vec![vec![0u64; rlwe_n]; -// // d_rgsw + 2]; let mul_mod = -// // |v0: &u64, v1: &u64| (((*v0 as u128 * *v1 as u128) % -// (rlwe_q as u128)) as u64); - -// // for i in 0..bool_evaluator.parameters.lwe_n { -// // rlwe_by_rgsw( -// // &mut rlwe_ct, -// // server_key_eval.rgsw_ct_lwe_si(i), -// // &mut scratch_matrix_dplus2_ring, -// // rlwe_decomposer, -// // rlwe_nttop, -// // rlwe_modop, -// // ); - -// // // carry_m[X] * s_i[X] -// // let si = -// // ideal_client_key.sk_lwe.values[i] * -// // (bool_evaluator.embedding_factor as i32); let mut -// // si_poly = vec![0u64; rlwe_n]; if si < 0 { -// // si_poly[rlwe_n - (si.abs() as usize)] = rlwe_q - 1; -// // } else { -// // si_poly[(si.abs() as usize)] = 1; -// // } -// // carry_m = negacyclic_mul(&carry_m, &si_poly, mul_mod, -// // rlwe_q); - -// // let noise = measure_noise( -// // &rlwe_ct, -// // &carry_m, -// // rlwe_nttop, -// // rlwe_modop, -// // ideal_client_key.sk_rlwe.values(), -// // ); -// // println!("Noise RLWE(carry_m) accumulating {i}^th secret -// // monomial: {noise}"); } -// // } - -// // // Check galois keys -// // if false { -// // let g = bool_evaluator.g() as isize; -// // let mut rng = DefaultSecureRng::new(); -// // let mut scratch_matrix_dplus2_ring = vec![vec![0u64; rlwe_n]; -// // d_rgsw + 2]; for i in [g, -g] { -// // let mut m = vec![0u64; rlwe_n]; -// // RandomUniformDist1::random_fill(&mut rng, &rlwe_q, -// // m.as_mut_slice()); let mut rlwe_ct = { -// // let mut data = vec![vec![0u64; rlwe_n]; 2]; -// // public_key_encrypt_rlwe( -// // &mut data, -// // &collective_pk.key, -// // &m, -// // rlwe_modop, -// // rlwe_nttop, -// // &mut rng, -// // ); -// // RlweCiphertext::<_, DefaultSecureRng>::from_raw(data, -// // false) }; - -// // let auto_key = server_key_eval.galois_key_for_auto(i); -// // let (auto_map_index, auto_map_sign) = -// // generate_auto_map(rlwe_n, i); galois_auto( -// // &mut rlwe_ct, -// // auto_key, -// // &mut scratch_matrix_dplus2_ring, -// // &auto_map_index, -// // &auto_map_sign, -// // rlwe_modop, -// // rlwe_nttop, -// // rlwe_decomposer, -// // ); - -// // // send m(X) -> m(X^i) -// // let mut m_k = vec![0u64; rlwe_n]; -// // izip!(m.iter(), auto_map_index.iter(), -// // auto_map_sign.iter()).for_each( |(mi, to_index, -// to_sign)| // { if !to_sign { -// // m_k[*to_index] = rlwe_q - *mi; -// // } else { -// // m_k[*to_index] = *mi; -// // } -// // }, -// // ); - -// // // measure noise -// // let noise = measure_noise( -// // &rlwe_ct, -// // &m_k, -// // rlwe_nttop, -// // rlwe_modop, -// // ideal_client_key.sk_rlwe.values(), -// // ); - -// // println!("Noise after auto k={i}: {noise}"); -// // } -// // } -// } -// } +#[cfg(test)] +mod tests { + use bool::parameters::{MP_BOOL_PARAMS, SP_BOOL_PARAMS}; + use rand::{thread_rng, Rng}; + use rand_distr::Uniform; + + use crate::{ + backend::{GetModulus, ModInit, ModularOpsU64, WordSizeModulus}, + bool::{ + self, CommonReferenceSeededMultiPartyServerKeyShare, PublicKey, + SeededMultiPartyServerKey, + }, + ntt::NttBackendU64, + random::{RandomElementInModulus, DEFAULT_RNG}, + rgsw::{ + self, measure_noise, public_key_encrypt_rlwe, secret_key_encrypt_rlwe, + tests::{_measure_noise_rgsw, _sk_encrypt_rlwe}, + RgswCiphertext, RgswCiphertextEvaluationDomain, SeededRgswCiphertext, + SeededRlweCiphertext, + }, + utils::{negacyclic_mul, Stats}, + }; + + use super::*; + + #[test] + fn bool_encrypt_decrypt_works() { + let bool_evaluator = BoolEvaluator::< + Vec>, + NttBackendU64, + ModularOpsU64>, + ModularOpsU64>, + ShoupServerKeyEvaluationDomain>>, + >::new(SP_BOOL_PARAMS); + let client_key = bool_evaluator.client_key(); + + let mut m = true; + for _ in 0..1000 { + let lwe_ct = bool_evaluator.sk_encrypt(m, &client_key); + let m_back = bool_evaluator.sk_decrypt(&lwe_ct, &client_key); + assert_eq!(m, m_back); + m = !m; + } + } + + #[test] + fn bool_nand() { + let mut bool_evaluator = BoolEvaluator::< + Vec>, + NttBackendU64, + ModularOpsU64>, + ModularOpsU64>, + ShoupServerKeyEvaluationDomain<_>, + >::new(SP_BOOL_PARAMS); + + // println!("{:?}", bool_evaluator.nand_test_vec); + let client_key = bool_evaluator.client_key(); + let seeded_server_key = bool_evaluator.single_party_server_key(&client_key); + let runtime_server_key = + ShoupServerKeyEvaluationDomain::from(ServerKeyEvaluationDomain::< + _, + _, + DefaultSecureRng, + NttBackendU64, + >::from(&seeded_server_key)); + + let mut m0 = false; + let mut m1 = true; + + let mut ct0 = bool_evaluator.sk_encrypt(m0, &client_key); + let mut ct1 = bool_evaluator.sk_encrypt(m1, &client_key); + + for _ in 0..500 { + let ct_back = bool_evaluator.nand(&ct0, &ct1, &runtime_server_key); + + let m_out = !(m0 && m1); + + let m_back = bool_evaluator.sk_decrypt(&ct_back, &client_key); + assert!(m_out == m_back, "Expected {m_out}, got {m_back}"); + + m1 = m0; + m0 = m_out; + + ct1 = ct0; + ct0 = ct_back; + } + } + + #[test] + fn bool_xor() { + let mut bool_evaluator = BoolEvaluator::< + Vec>, + NttBackendU64, + ModularOpsU64>, + ModularOpsU64>, + ShoupServerKeyEvaluationDomain<_>, + >::new(SP_BOOL_PARAMS); + + // println!("{:?}", bool_evaluator.nand_test_vec); + let client_key = bool_evaluator.client_key(); + let seeded_server_key = bool_evaluator.single_party_server_key(&client_key); + let runtime_server_key = + ShoupServerKeyEvaluationDomain::from(ServerKeyEvaluationDomain::< + _, + _, + DefaultSecureRng, + NttBackendU64, + >::from(&seeded_server_key)); + + let mut m0 = false; + let mut m1 = true; + + let mut ct0 = bool_evaluator.sk_encrypt(m0, &client_key); + let mut ct1 = bool_evaluator.sk_encrypt(m1, &client_key); + + for _ in 0..1000 { + let ct_back = bool_evaluator.xor(&ct0, &ct1, &runtime_server_key); + let m_out = (m0 ^ m1); + + let m_back = bool_evaluator.sk_decrypt(&ct_back, &client_key); + assert!(m_out == m_back, "Expected {m_out}, got {m_back}"); + + m1 = m0; + m0 = m_out; + + ct1 = ct0; + ct0 = ct_back; + } + } + + #[test] + fn multi_party_encryption_decryption() { + let bool_evaluator = BoolEvaluator::< + Vec>, + NttBackendU64, + ModularOpsU64>, + ModularOpsU64>, + ShoupServerKeyEvaluationDomain>>, + >::new(MP_BOOL_PARAMS); + + let no_of_parties = 500; + let parties = (0..no_of_parties) + .map(|_| bool_evaluator.client_key()) + .collect_vec(); + + let mut ideal_rlwe_sk = vec![0i32; bool_evaluator.pbs_info.rlwe_n()]; + parties.iter().for_each(|k| { + izip!(ideal_rlwe_sk.iter_mut(), k.sk_rlwe().values()).for_each(|(ideal_i, s_i)| { + *ideal_i = *ideal_i + s_i; + }); + }); + + let mut m = true; + for i in 0..100 { + let pk_cr_seed = [0u8; 32]; + + let public_key_share = parties + .iter() + .map(|k| bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) + .collect_vec(); + + let collective_pk = PublicKey::< + Vec>, + DefaultSecureRng, + ModularOpsU64>, + >::from(public_key_share.as_slice()); + let lwe_ct = bool_evaluator.pk_encrypt(collective_pk.key(), m); + + let decryption_shares = parties + .iter() + .map(|k| bool_evaluator.multi_party_decryption_share(&lwe_ct, k)) + .collect_vec(); + + let m_back = bool_evaluator.multi_party_decrypt(&decryption_shares, &lwe_ct); + + { + let ideal_m = if m { + bool_evaluator.pbs_info.parameters.rlwe_q().true_el() + } else { + bool_evaluator.pbs_info.parameters.rlwe_q().false_el() + }; + let noise = measure_noise_lwe( + &lwe_ct, + &ideal_rlwe_sk, + &bool_evaluator.pbs_info.rlwe_modop, + &ideal_m, + ); + println!("Noise: {noise}"); + } + + assert_eq!(m_back, m); + m = !m; + } + } + + fn _collecitve_public_key_gen(rlwe_q: u64, parties_rlwe_sk: &[RlweSecret]) -> Vec> { + let ring_size = parties_rlwe_sk[0].values.len(); + assert!(ring_size.is_power_of_two()); + let mut rng = DefaultSecureRng::new(); + let nttop = NttBackendU64::new(&rlwe_q, ring_size); + let modop = ModularOpsU64::new(rlwe_q); + + // Generate Pk shares + let pk_seed = [0u8; 32]; + let pk_shares = parties_rlwe_sk.iter().map(|sk| { + let mut p_rng = DefaultSecureRng::new_seeded(pk_seed); + let mut share_out = vec![0u64; ring_size]; + public_key_share( + &mut share_out, + sk.values(), + &modop, + &nttop, + &mut p_rng, + &mut rng, + ); + share_out + }); + + let mut pk_part_b = vec![0u64; ring_size]; + pk_shares.for_each(|share| modop.elwise_add_mut(&mut pk_part_b, &share)); + let mut pk_part_a = vec![0u64; ring_size]; + let mut p_rng = DefaultSecureRng::new_seeded(pk_seed); + RandomFillUniformInModulus::random_fill(&mut p_rng, &rlwe_q, pk_part_a.as_mut_slice()); + + vec![pk_part_a, pk_part_b] + } + + fn _multi_party_all_keygen( + bool_evaluator: &BoolEvaluator< + Vec>, + NttBackendU64, + ModularOpsU64>, + ModularOpsU64>, + ShoupServerKeyEvaluationDomain>>, + >, + no_of_parties: usize, + ) -> ( + Vec, + PublicKey>, DefaultSecureRng, ModularOpsU64>>, + Vec< + CommonReferenceSeededMultiPartyServerKeyShare< + Vec>, + BoolParameters, + [u8; 32], + >, + >, + SeededMultiPartyServerKey>, [u8; 32], BoolParameters>, + ShoupServerKeyEvaluationDomain>>, + ClientKey, + ) { + let parties = (0..no_of_parties) + .map(|_| bool_evaluator.client_key()) + .collect_vec(); + + let mut rng = DefaultSecureRng::new(); + + // Collective public key + let mut pk_cr_seed = [0u8; 32]; + rng.fill_bytes(&mut pk_cr_seed); + let public_key_share = parties + .iter() + .map(|k| bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) + .collect_vec(); + let collective_pk = + PublicKey::>, DefaultSecureRng, _>::from(public_key_share.as_slice()); + + // Server key + let mut pbs_cr_seed = [0u8; 32]; + rng.fill_bytes(&mut pbs_cr_seed); + let server_key_shares = parties + .iter() + .map(|k| { + bool_evaluator.multi_party_server_key_share(pbs_cr_seed, &collective_pk.key(), k) + }) + .collect_vec(); + let seeded_server_key = + bool_evaluator.aggregate_multi_party_server_key_shares(&server_key_shares); + let runtime_server_key = + ShoupServerKeyEvaluationDomain::from(ServerKeyEvaluationDomain::< + _, + _, + DefaultSecureRng, + NttBackendU64, + >::from(&seeded_server_key)); + + // construct ideal rlwe sk for meauring noise + let ideal_client_key = { + let mut ideal_rlwe_sk = vec![0i32; bool_evaluator.pbs_info.rlwe_n()]; + parties.iter().for_each(|k| { + izip!(ideal_rlwe_sk.iter_mut(), k.sk_rlwe().values()).for_each(|(ideal_i, s_i)| { + *ideal_i = *ideal_i + s_i; + }); + }); + let mut ideal_lwe_sk = vec![0i32; bool_evaluator.pbs_info.lwe_n()]; + parties.iter().for_each(|k| { + izip!(ideal_lwe_sk.iter_mut(), k.sk_lwe().values()).for_each(|(ideal_i, s_i)| { + *ideal_i = *ideal_i + s_i; + }); + }); + + ClientKey::new( + RlweSecret { + values: ideal_rlwe_sk, + }, + LweSecret { + values: ideal_lwe_sk, + }, + ) + }; + + ( + parties, + collective_pk, + server_key_shares, + seeded_server_key, + runtime_server_key, + ideal_client_key, + ) + } + + #[test] + fn multi_party_nand() { + let mut bool_evaluator = BoolEvaluator::< + Vec>, + NttBackendU64, + ModularOpsU64>, + ModularOpsU64>, + ShoupServerKeyEvaluationDomain>>, + >::new(MP_BOOL_PARAMS); + + let (parties, collective_pk, _, _, server_key_eval, ideal_client_key) = + _multi_party_all_keygen(&bool_evaluator, 2); + + let mut m0 = true; + let mut m1 = false; + + let mut lwe0 = bool_evaluator.pk_encrypt(collective_pk.key(), m0); + let mut lwe1 = bool_evaluator.pk_encrypt(collective_pk.key(), m1); + + for _ in 0..2000 { + let lwe_out = bool_evaluator.nand(&lwe0, &lwe1, &server_key_eval); + + let m_expected = !(m0 & m1); + + // multi-party decrypt + let decryption_shares = parties + .iter() + .map(|k| bool_evaluator.multi_party_decryption_share(&lwe_out, k)) + .collect_vec(); + let m_back = bool_evaluator.multi_party_decrypt(&decryption_shares, &lwe_out); + + let m_back = bool_evaluator.sk_decrypt(&lwe_out, &ideal_client_key); + + assert!( + m_expected == m_back, + "Expected {m_expected}, got +{m_back}" + ); + m1 = m0; + m0 = m_expected; + + lwe1 = lwe0; + lwe0 = lwe_out; + } + } + + #[test] + fn noise_tester() { + let bool_evaluator = BoolEvaluator::< + Vec>, + NttBackendU64, + ModularOpsU64>, + ModularOpsU64>, + ShoupServerKeyEvaluationDomain>>, + >::new(SP_BOOL_PARAMS); + + // let (_, collective_pk, _, _, server_key_eval, ideal_client_key) = + // _multi_party_all_keygen(&bool_evaluator, 20); + let no_of_parties = 2; + let lwe_q = bool_evaluator.pbs_info.parameters.lwe_q(); + let rlwe_q = bool_evaluator.pbs_info.parameters.rlwe_q(); + let lwe_n = bool_evaluator.pbs_info.parameters.lwe_n().0; + let rlwe_n = bool_evaluator.pbs_info.parameters.rlwe_n().0; + let lwe_modop = &bool_evaluator.pbs_info.lwe_modop; + let rlwe_nttop = &bool_evaluator.pbs_info.rlwe_nttop; + let rlwe_modop = &bool_evaluator.pbs_info.rlwe_modop; + + // let rgsw_rgsw_decomposer = &bool_evaluator + // .pbs_info + // .parameters + // .rgsw_rgsw_decomposer::>(); + // let rgsw_rgsw_gadget_a = rgsw_rgsw_decomposer.0.gadget_vector(); + // let rgsw_rgsw_gadget_b = rgsw_rgsw_decomposer.1.gadget_vector(); + + let rlwe_rgsw_decomposer = &bool_evaluator.pbs_info.rlwe_rgsw_decomposer; + let rlwe_rgsw_gadget_a = rlwe_rgsw_decomposer.0.gadget_vector(); + let rlwe_rgsw_gadget_b = rlwe_rgsw_decomposer.1.gadget_vector(); + + let auto_decomposer = &bool_evaluator.pbs_info.auto_decomposer; + let auto_gadget = auto_decomposer.gadget_vector(); + + let parties = (0..no_of_parties) + .map(|_| bool_evaluator.client_key()) + .collect_vec(); + + let ideal_client_key = { + let mut ideal_rlwe_sk = vec![0i32; bool_evaluator.pbs_info.rlwe_n()]; + parties.iter().for_each(|k| { + izip!(ideal_rlwe_sk.iter_mut(), k.sk_rlwe().values()).for_each(|(ideal_i, s_i)| { + *ideal_i = *ideal_i + s_i; + }); + }); + let mut ideal_lwe_sk = vec![0i32; bool_evaluator.pbs_info.lwe_n()]; + parties.iter().for_each(|k| { + izip!(ideal_lwe_sk.iter_mut(), k.sk_lwe().values()).for_each(|(ideal_i, s_i)| { + *ideal_i = *ideal_i + s_i; + }); + }); + + ClientKey::new( + RlweSecret { + values: ideal_rlwe_sk, + }, + LweSecret { + values: ideal_lwe_sk, + }, + ) + }; + + // check noise in freshly encrypted RLWE ciphertext (ie var_fresh) + if true { + let mut rng = DefaultSecureRng::new(); + let mut check = Stats { samples: vec![] }; + for _ in 0..10 { + // generate a new collective public key + let mut pk_cr_seed = [0u8; 32]; + rng.fill_bytes(&mut pk_cr_seed); + let public_key_share = parties + .iter() + .map(|k| bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) + .collect_vec(); + let collective_pk = PublicKey::< + Vec>, + DefaultSecureRng, + ModularOpsU64>, + >::from(public_key_share.as_slice()); + + let mut m = vec![0u64; rlwe_n]; + RandomFillUniformInModulus::random_fill(&mut rng, rlwe_q, m.as_mut_slice()); + let mut rlwe_ct = vec![vec![0u64; rlwe_n]; 2]; + public_key_encrypt_rlwe::<_, _, _, _, i32, _>( + &mut rlwe_ct, + collective_pk.key(), + &m, + rlwe_modop, + rlwe_nttop, + &mut rng, + ); + + let mut m_back = vec![0u64; rlwe_n]; + decrypt_rlwe( + &rlwe_ct, + ideal_client_key.sk_rlwe().values(), + &mut m_back, + rlwe_nttop, + rlwe_modop, + ); + + rlwe_modop.elwise_sub_mut(m_back.as_mut_slice(), m.as_slice()); + + check.add_more(Vec::::try_convert_from(&m_back, rlwe_q).as_slice()); + } + + println!("Public key Std: {}", check.std_dev().abs().log2()); + } + + if true { + // Generate server key shares + let mut rng = DefaultSecureRng::new(); + let mut pk_cr_seed = [0u8; 32]; + rng.fill_bytes(&mut pk_cr_seed); + let public_key_share = parties + .iter() + .map(|k| bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) + .collect_vec(); + let collective_pk = PublicKey::< + Vec>, + DefaultSecureRng, + ModularOpsU64>, + >::from(public_key_share.as_slice()); + + let pbs_cr_seed = [0u8; 32]; + rng.fill_bytes(&mut pk_cr_seed); + let server_key_shares = parties + .iter() + .map(|k| { + bool_evaluator.multi_party_server_key_share(pbs_cr_seed, collective_pk.key(), k) + }) + .collect_vec(); + + let seeded_server_key = + bool_evaluator.aggregate_multi_party_server_key_shares(&server_key_shares); + + // Check noise in RGSW ciphertexts of ideal LWE secret elements + if false { + let mut check = Stats { samples: vec![] }; + izip!( + ideal_client_key.sk_lwe().values.iter(), + seeded_server_key.rgsw_cts().iter() + ) + .for_each(|(s_i, rgsw_ct_i)| { + // X^{s[i]} + let mut m_si = vec![0u64; rlwe_n]; + let s_i = *s_i * (bool_evaluator.pbs_info.embedding_factor as i32); + if s_i < 0 { + m_si[rlwe_n - (s_i.abs() as usize)] = rlwe_q.neg_one(); + } else { + m_si[s_i as usize] = 1; + } + + // RLWE'(-sm) + let mut neg_s_eval = + Vec::::try_convert_from(ideal_client_key.sk_rlwe().values(), rlwe_q); + rlwe_modop.elwise_neg_mut(&mut neg_s_eval); + rlwe_nttop.forward(&mut neg_s_eval); + for j in 0..rlwe_rgsw_decomposer.a().decomposition_count() { + // RLWE(B^{j} * -s[X]*X^{s_lwe[i]}) + + // -s[X]*X^{s_lwe[i]}*B_j + let mut m_ideal = m_si.clone(); + rlwe_nttop.forward(m_ideal.as_mut_slice()); + rlwe_modop.elwise_mul_mut(m_ideal.as_mut_slice(), neg_s_eval.as_slice()); + rlwe_nttop.backward(m_ideal.as_mut_slice()); + rlwe_modop + .elwise_scalar_mul_mut(m_ideal.as_mut_slice(), &rlwe_rgsw_gadget_a[j]); + + // RLWE(-s*X^{s_lwe[i]}*B_j) + let mut rlwe_ct = vec![vec![0u64; rlwe_n]; 2]; + rlwe_ct[0].copy_from_slice(&rgsw_ct_i[j]); + rlwe_ct[1].copy_from_slice( + &rgsw_ct_i[j + rlwe_rgsw_decomposer.a().decomposition_count()], + ); + + let mut m_back = vec![0u64; rlwe_n]; + decrypt_rlwe( + &rlwe_ct, + ideal_client_key.sk_rlwe().values(), + &mut m_back, + rlwe_nttop, + rlwe_modop, + ); + + // diff + rlwe_modop.elwise_sub_mut(&mut m_back, &m_ideal); + check.add_more(&Vec::::try_convert_from(&m_back, rlwe_q)); + } + + // RLWE'(m) + for j in 0..rlwe_rgsw_decomposer.b().decomposition_count() { + // RLWE(B^{j} * X^{s_lwe[i]}) + + // X^{s_lwe[i]}*B_j + let mut m_ideal = m_si.clone(); + rlwe_modop + .elwise_scalar_mul_mut(m_ideal.as_mut_slice(), &rlwe_rgsw_gadget_b[j]); + + // RLWE(X^{s_lwe[i]}*B_j) + let mut rlwe_ct = vec![vec![0u64; rlwe_n]; 2]; + rlwe_ct[0].copy_from_slice( + &rgsw_ct_i[j + (2 * rlwe_rgsw_decomposer.a().decomposition_count())], + ); + rlwe_ct[1].copy_from_slice( + &rgsw_ct_i[j + + (2 * rlwe_rgsw_decomposer.a().decomposition_count() + + rlwe_rgsw_decomposer.b().decomposition_count())], + ); + + let mut m_back = vec![0u64; rlwe_n]; + decrypt_rlwe( + &rlwe_ct, + ideal_client_key.sk_rlwe().values(), + &mut m_back, + rlwe_nttop, + rlwe_modop, + ); + + // diff + rlwe_modop.elwise_sub_mut(&mut m_back, &m_ideal); + check.add_more(&Vec::::try_convert_from(&m_back, rlwe_q)); + } + }); + println!( + "RGSW Std: {} {} ;; max={}", + check.mean(), + check.std_dev().abs().log2(), + check.samples.iter().max().unwrap() + ); + } + + // server key in Evaluation domain + let runtime_server_key = + ShoupServerKeyEvaluationDomain::from(ServerKeyEvaluationDomain::< + _, + _, + DefaultSecureRng, + NttBackendU64, + >::from(&seeded_server_key)); + + // check noise in RLWE x RGSW(X^{s_i}) where RGSW is accunulated RGSW ciphertext + if false { + let mut check = Stats { samples: vec![] }; + + ideal_client_key + .sk_lwe() + .values() + .iter() + .enumerate() + .for_each(|(index, s_i)| { + let rgsw_ct_i = runtime_server_key.rgsw_ct_lwe_si(index).as_ref(); + + let mut m = vec![0u64; rlwe_n]; + RandomFillUniformInModulus::random_fill(&mut rng, rlwe_q, m.as_mut_slice()); + let mut rlwe_ct = vec![vec![0u64; rlwe_n]; 2]; + public_key_encrypt_rlwe::<_, _, _, _, i32, _>( + &mut rlwe_ct, + collective_pk.key(), + &m, + rlwe_modop, + rlwe_nttop, + &mut rng, + ); + + let mut rlwe_after = RlweCiphertext::<_, DefaultSecureRng> { + data: rlwe_ct.clone(), + is_trivial: false, + _phatom: PhantomData, + }; + let mut scratch = vec![ + vec![0u64; rlwe_n]; + std::cmp::max( + rlwe_rgsw_decomposer.0.decomposition_count(), + rlwe_rgsw_decomposer.1.decomposition_count() + ) + 2 + ]; + rlwe_by_rgsw( + &mut rlwe_after, + rgsw_ct_i, + &mut scratch, + rlwe_rgsw_decomposer, + rlwe_nttop, + rlwe_modop, + ); + + // m1 = X^{s[i]} + let mut m1 = vec![0u64; rlwe_n]; + let s_i = *s_i * (bool_evaluator.pbs_info.embedding_factor as i32); + if s_i < 0 { + m1[rlwe_n - (s_i.abs() as usize)] = rlwe_q.neg_one() + } else { + m1[s_i as usize] = 1; + } + + // (m+e) * m1 + let mut m_plus_e_times_m1 = vec![0u64; rlwe_n]; + decrypt_rlwe( + &rlwe_ct, + ideal_client_key.sk_rlwe().values(), + &mut m_plus_e_times_m1, + rlwe_nttop, + rlwe_modop, + ); + rlwe_nttop.forward(m_plus_e_times_m1.as_mut_slice()); + rlwe_nttop.forward(m1.as_mut_slice()); + + rlwe_modop.elwise_mul_mut(m_plus_e_times_m1.as_mut_slice(), m1.as_slice()); + rlwe_nttop.backward(m_plus_e_times_m1.as_mut_slice()); + + // Resulting RLWE ciphertext will equal: (m0m1 + em1) + e_{rlsw x rgsw}. + // Hence, resulting rlwe ciphertext will have error em1 + e_{rlwe x rgsw}. + // Here we're only concerned with e_{rlwe x rgsw}, that is noise added by + // RLWExRGSW. Also note in practice m1 is a monomial, for ex, X^{s_{i}}, for + // some i and var(em1) = var(e). + let mut m_plus_e_times_m1_more_e = vec![0u64; rlwe_n]; + decrypt_rlwe( + &rlwe_after, + ideal_client_key.sk_rlwe().values(), + &mut m_plus_e_times_m1_more_e, + rlwe_nttop, + rlwe_modop, + ); + + // diff + rlwe_modop.elwise_sub_mut( + m_plus_e_times_m1_more_e.as_mut_slice(), + m_plus_e_times_m1.as_slice(), + ); + + // let noise = measure_noise( + // &rlwe_after, + // &m_plus_e_times_m1, + // rlwe_nttop, + // rlwe_modop, + // ideal_client_key.sk_rlwe.values(), + // ); + // print!("NOISE: {}", noise); + + check.add_more(&Vec::::try_convert_from( + &m_plus_e_times_m1_more_e, + rlwe_q, + )); + }); + println!( + "RLWE x RGSW, where RGSW has noise var_brk, std: {} {}", + check.std_dev(), + check.std_dev().abs().log2() + ) + } + + // check noise in Auto key + if false { + let mut check = Stats { samples: vec![] }; + + let mut neg_s_poly = + Vec::::try_convert_from(ideal_client_key.sk_rlwe().values(), rlwe_q); + rlwe_modop.elwise_neg_mut(neg_s_poly.as_mut_slice()); + + let g = bool_evaluator.pbs_info.g(); + let br_q = bool_evaluator.pbs_info.br_q(); + let auto_element_dlogs = bool_evaluator.pbs_info.parameters.auto_element_dlogs(); + for i in auto_element_dlogs.into_iter() { + let g_pow = if i == 0 { + -g + } else { + (((g as usize).pow(i as u32)) % br_q) as isize + }; + + // -s[X^k] + let (auto_indices, auto_sign) = generate_auto_map(rlwe_n, g_pow); + let mut neg_s_poly_auto_i = vec![0u64; rlwe_n]; + izip!(neg_s_poly.iter(), auto_indices.iter(), auto_sign.iter()).for_each( + |(v, to_i, to_sign)| { + if !to_sign { + neg_s_poly_auto_i[*to_i] = rlwe_modop.neg(v); + } else { + neg_s_poly_auto_i[*to_i] = *v; + } + }, + ); + + let mut auto_key_i = runtime_server_key.galois_key_for_auto(i).as_ref().clone(); //send i^th auto key to coefficient domain + auto_key_i + .iter_mut() + .for_each(|r| rlwe_nttop.backward(r.as_mut_slice())); + auto_gadget.iter().enumerate().for_each(|(i, b_i)| { + // B^i * -s[X^k] + let mut m_ideal = neg_s_poly_auto_i.clone(); + + rlwe_modop.elwise_scalar_mul_mut(m_ideal.as_mut_slice(), b_i); + + let mut m_out = vec![0u64; rlwe_n]; + let mut rlwe_ct = vec![vec![0u64; rlwe_n]; 2]; + rlwe_ct[0].copy_from_slice(&auto_key_i[i]); + rlwe_ct[1].copy_from_slice( + &auto_key_i[auto_decomposer.decomposition_count() + i], + ); + decrypt_rlwe( + &rlwe_ct, + ideal_client_key.sk_rlwe().values(), + &mut m_out, + rlwe_nttop, + rlwe_modop, + ); + + // diff + rlwe_modop.elwise_sub_mut(m_out.as_mut_slice(), m_ideal.as_slice()); + + check.add_more(&Vec::::try_convert_from(&m_out, rlwe_q)); + }); + } + + println!("Auto key noise std dev: {}", check.std_dev().abs().log2()); + } + + // check noise in RLWE(X^k) after sending RLWE(X) -> RLWE(X^k)using collective + // auto key + if true { + let mut check = Stats { samples: vec![] }; + let br_q = bool_evaluator.pbs_info.br_q(); + let g = bool_evaluator.pbs_info.g(); + let auto_element_dlogs = bool_evaluator.pbs_info.parameters.auto_element_dlogs(); + for i in auto_element_dlogs.into_iter() { + for _ in 0..10 { + let mut m = vec![0u64; rlwe_n]; + RandomFillUniformInModulus::random_fill(&mut rng, rlwe_q, m.as_mut_slice()); + let mut rlwe_ct = RlweCiphertext::<_, DefaultSecureRng> { + data: vec![vec![0u64; rlwe_n]; 2], + is_trivial: false, + _phatom: PhantomData, + }; + public_key_encrypt_rlwe::<_, _, _, _, i32, _>( + &mut rlwe_ct, + collective_pk.key(), + &m, + rlwe_modop, + rlwe_nttop, + &mut rng, + ); + + // We're only interested in noise increased as a result of automorphism. + // Hence, we take m+e as the bench. + let mut m_plus_e = vec![0u64; rlwe_n]; + decrypt_rlwe( + &rlwe_ct, + ideal_client_key.sk_rlwe().values(), + &mut m_plus_e, + rlwe_nttop, + rlwe_modop, + ); + + let auto_key = runtime_server_key.galois_key_for_auto(i).as_ref(); + let (auto_map_index, auto_map_sign) = + bool_evaluator.pbs_info.rlwe_auto_map(i); + let mut scratch = + vec![vec![0u64; rlwe_n]; auto_decomposer.decomposition_count() + 2]; + galois_auto( + &mut rlwe_ct, + auto_key, + &mut scratch, + &auto_map_index, + &auto_map_sign, + rlwe_modop, + rlwe_nttop, + auto_decomposer, + ); + + // send m+e from X to X^k + let mut m_plus_e_auto = vec![0u64; rlwe_n]; + izip!(m_plus_e.iter(), auto_map_index.iter(), auto_map_sign.iter()) + .for_each(|(v, to_index, to_sign)| { + if !to_sign { + m_plus_e_auto[*to_index] = rlwe_modop.neg(v); + } else { + m_plus_e_auto[*to_index] = *v + } + }); + + let mut m_out = vec![0u64; rlwe_n]; + decrypt_rlwe( + &rlwe_ct, + ideal_client_key.sk_rlwe().values(), + &mut m_out, + rlwe_nttop, + rlwe_modop, + ); + + // diff + rlwe_modop.elwise_sub_mut(m_out.as_mut_slice(), m_plus_e_auto.as_slice()); + + check.add_more(&Vec::::try_convert_from(m_out.as_slice(), rlwe_q)); + } + } + + println!("Rlwe Auto Noise Std: {}", check.std_dev().abs().log2()); + } + + // Check noise growth in ksk + // TODO check in LWE key switching keys + if true { + // 1. encrypt LWE ciphertext + // 2. Key switching + // 3. + let mut check = Stats { samples: vec![] }; + + for _ in 0..1024 { + // Encrypt m \in Q_{ks} using RLWE sk + let mut lwe_in_ct = vec![0u64; rlwe_n + 1]; + let m = RandomElementInModulus::random(&mut rng, &lwe_q.q().unwrap()); + encrypt_lwe( + &mut lwe_in_ct, + &m, + ideal_client_key.sk_rlwe().values(), + lwe_modop, + &mut rng, + ); + + // Key switch + let mut lwe_out = vec![0u64; lwe_n + 1]; + lwe_key_switch( + &mut lwe_out, + &lwe_in_ct, + runtime_server_key.lwe_ksk(), + lwe_modop, + bool_evaluator.pbs_info.lwe_decomposer(), + ); + + // We only care about noise added by LWE key switch + // m+e + let m_plus_e = + decrypt_lwe(&lwe_in_ct, ideal_client_key.sk_rlwe().values(), lwe_modop); + + let m_plus_e_plus_lwe_ksk_noise = + decrypt_lwe(&lwe_out, ideal_client_key.sk_lwe().values(), lwe_modop); + + let diff = lwe_modop.sub(&m_plus_e_plus_lwe_ksk_noise, &m_plus_e); + + check.add_more(&vec![lwe_q.map_element_to_i64(&diff)]); + } + + println!("Lwe ksk std dev: {}", check.std_dev().abs().log2()); + } + } + + // Check noise in fresh RGSW ciphertexts, ie X^{s_j[i]}, must equalnoise in + // // fresh RLWE ciphertext + if true {} + // test LWE ksk from RLWE -> LWE + // if false { + // let logp = 2; + // let mut rng = DefaultSecureRng::new(); + + // let m = 1; + // let encoded_m = m << (lwe_logq - logp); + + // // Encrypt + // let mut lwe_ct = vec![0u64; rlwe_n + 1]; + // encrypt_lwe( + // &mut lwe_ct, + // &encoded_m, + // ideal_client_key.sk_rlwe.values(), + // lwe_modop, + // &mut rng, + // ); + + // // key switch + // let lwe_decomposer = &bool_evaluator.decomposer_lwe; + // let mut lwe_out = vec![0u64; lwe_n + 1]; + // lwe_key_switch( + // &mut lwe_out, + // &lwe_ct, + // &server_key_eval.lwe_ksk, + // lwe_modop, + // lwe_decomposer, + // ); + + // let encoded_m_back = decrypt_lwe(&lwe_out, + // ideal_client_key.sk_lwe.values(), lwe_modop); let m_back + // = ((encoded_m_back as f64 * (1 << logp) as f64) / + // (lwe_q as f64)).round() as u64; dbg!(m_back, m); + + // let noise = measure_noise_lwe( + // &lwe_out, + // ideal_client_key.sk_lwe.values(), + // lwe_modop, + // &encoded_m, + // ); + + // println!("Noise: {noise}"); + // } + + // Measure noise in RGSW ciphertexts of ideal LWE secrets + // if true { + // let gadget_vec = gadget_vector( + // bool_evaluator.parameters.rlwe_logq, + // bool_evaluator.parameters.logb_rgsw, + // bool_evaluator.parameters.d_rgsw, + // ); + + // for i in 0..20 { + // // measure noise in RGSW(s[i]) + // let si = + // ideal_client_key.sk_lwe.values[i] * + // (bool_evaluator.embedding_factor as i32); let mut + // si_poly = vec![0u64; rlwe_n]; if si < 0 { + // si_poly[rlwe_n - (si.abs() as usize)] = rlwe_q - 1; + // } else { + // si_poly[(si.abs() as usize)] = 1; + // } + + // let mut rgsw_si = server_key_eval.rgsw_cts[i].clone(); + // rgsw_si + // .iter_mut() + // .for_each(|ri| rlwe_nttop.backward(ri.as_mut())); + + // println!("####### Noise in RGSW(X^s_{i}) #######"); + // _measure_noise_rgsw( + // &rgsw_si, + // &si_poly, + // ideal_client_key.sk_rlwe.values(), + // &gadget_vec, + // rlwe_q, + // ); + // println!("####### ##################### #######"); + // } + // } + + // // measure noise grwoth in RLWExRGSW + // if true { + // let mut rng = DefaultSecureRng::new(); + // let mut carry_m = vec![0u64; rlwe_n]; + // RandomUniformDist1::random_fill(&mut rng, &rlwe_q, + // carry_m.as_mut_slice()); + + // // RGSW(carrym) + // let trivial_rlwect = vec![vec![0u64; rlwe_n],carry_m.clone()]; + // let mut rlwe_ct = RlweCiphertext::<_, + // DefaultSecureRng>::from_raw(trivial_rlwect, true); + + // let mut scratch_matrix_dplus2_ring = vec![vec![0u64; rlwe_n]; + // d_rgsw + 2]; let mul_mod = + // |v0: &u64, v1: &u64| (((*v0 as u128 * *v1 as u128) % (rlwe_q as u128)) as u64); + + // for i in 0..bool_evaluator.parameters.lwe_n { + // rlwe_by_rgsw( + // &mut rlwe_ct, + // server_key_eval.rgsw_ct_lwe_si(i), + // &mut scratch_matrix_dplus2_ring, + // rlwe_decomposer, + // rlwe_nttop, + // rlwe_modop, + // ); + + // // carry_m[X] * s_i[X] + // let si = + // ideal_client_key.sk_lwe.values[i] * + // (bool_evaluator.embedding_factor as i32); let mut + // si_poly = vec![0u64; rlwe_n]; if si < 0 { + // si_poly[rlwe_n - (si.abs() as usize)] = rlwe_q - 1; + // } else { + // si_poly[(si.abs() as usize)] = 1; + // } + // carry_m = negacyclic_mul(&carry_m, &si_poly, mul_mod, + // rlwe_q); + + // let noise = measure_noise( + // &rlwe_ct, + // &carry_m, + // rlwe_nttop, + // rlwe_modop, + // ideal_client_key.sk_rlwe.values(), + // ); + // println!("Noise RLWE(carry_m) accumulating {i}^th secret + // monomial: {noise}"); } + // } + + // // Check galois keys + // if false { + // let g = bool_evaluator.g() as isize; + // let mut rng = DefaultSecureRng::new(); + // let mut scratch_matrix_dplus2_ring = vec![vec![0u64; rlwe_n]; + // d_rgsw + 2]; for i in [g, -g] { + // let mut m = vec![0u64; rlwe_n]; + // RandomUniformDist1::random_fill(&mut rng, &rlwe_q, + // m.as_mut_slice()); let mut rlwe_ct = { + // let mut data = vec![vec![0u64; rlwe_n]; 2]; + // public_key_encrypt_rlwe( + // &mut data, + // &collective_pk.key, + // &m, + // rlwe_modop, + // rlwe_nttop, + // &mut rng, + // ); + // RlweCiphertext::<_, DefaultSecureRng>::from_raw(data, + // false) }; + + // let auto_key = server_key_eval.galois_key_for_auto(i); + // let (auto_map_index, auto_map_sign) = + // generate_auto_map(rlwe_n, i); galois_auto( + // &mut rlwe_ct, + // auto_key, + // &mut scratch_matrix_dplus2_ring, + // &auto_map_index, + // &auto_map_sign, + // rlwe_modop, + // rlwe_nttop, + // rlwe_decomposer, + // ); + + // // send m(X) -> m(X^i) + // let mut m_k = vec![0u64; rlwe_n]; + // izip!(m.iter(), auto_map_index.iter(), + // auto_map_sign.iter()).for_each( |(mi, to_index,to_sign)| + // // { if !to_sign { + // m_k[*to_index] = rlwe_q - *mi; } else { + // m_k[*to_index] = *mi; + // } + // }, + // ); + + // // measure noise + // let noise = measure_noise( + // &rlwe_ct, + // &m_k, + // rlwe_nttop, + // rlwe_modop, + // ideal_client_key.sk_rlwe.values(), + // ); + + // println!("Noise after auto k={i}: {noise}"); + // } + // } + } +} diff --git a/src/bool/keys.rs b/src/bool/keys.rs index bd90b2c..a0df170 100644 --- a/src/bool/keys.rs +++ b/src/bool/keys.rs @@ -6,7 +6,7 @@ use crate::{ pbs::WithShoupRepr, random::{NewWithSeed, RandomFillUniformInModulus}, rgsw::RlweSecret, - utils::WithLocal, + utils::{ToShoup, WithLocal}, Decryptor, Encryptor, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, RowEntity, RowMut, }; @@ -669,7 +669,7 @@ pub(super) mod impl_server_key_eval_domain { } /// Server key in evaluation domain -pub(crate) struct ShoupServerKeyEvaluationDomain { +pub(crate) struct ShoupServerKeyEvaluationDomain { /// Rgsw cts of LWE secret elements rgsw_cts: Vec>, /// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding @@ -677,12 +677,18 @@ pub(crate) struct ShoupServerKeyEvaluationDomain { galois_keys: HashMap>, /// LWE ksk to key switching LWE ciphertext from RLWE secret to LWE secret lwe_ksk: M, - parameters: P, - _phanton: PhantomData<(R, N)>, } +/// Stores normal and shoup representation of Matrix elements (Normal, Shoup) pub(crate) struct NormalAndShoup(M, M); +impl NormalAndShoup { + fn new_with_modulus(value: M, modulus: ::Modulus) -> Self { + let value_shoup = M::to_shoup(&value, modulus); + NormalAndShoup(value, value_shoup) + } +} + impl AsRef for NormalAndShoup { fn as_ref(&self) -> &M { &self.0 @@ -697,11 +703,43 @@ impl WithShoupRepr for NormalAndShoup { } mod shoup_server_key_eval_domain { - use crate::pbs::PbsKey; + use itertools::{izip, Itertools}; + use num_traits::{FromPrimitive, PrimInt}; + + use crate::{backend::Modulus, pbs::PbsKey}; use super::*; - impl PbsKey for ShoupServerKeyEvaluationDomain { + impl, R, N> + From, R, N>> + for ShoupServerKeyEvaluationDomain + where + ::R: RowMut, + M::MatElement: PrimInt + FromPrimitive, + { + fn from(value: ServerKeyEvaluationDomain, R, N>) -> Self { + let q = value.parameters.rlwe_q().q().unwrap(); + // Rgsw ciphertexts + let rgsw_cts = value + .rgsw_cts + .into_iter() + .map(|ct| NormalAndShoup::new_with_modulus(ct, q)) + .collect_vec(); + + let mut auto_keys = HashMap::new(); + value.galois_keys.into_iter().for_each(|(index, key)| { + auto_keys.insert(index, NormalAndShoup::new_with_modulus(key, q)); + }); + + Self { + rgsw_cts, + galois_keys: auto_keys, + lwe_ksk: value.lwe_ksk, + } + } + } + + impl PbsKey for ShoupServerKeyEvaluationDomain { type AutoKey = NormalAndShoup; type LweKskKey = M; type RgswCt = NormalAndShoup; diff --git a/src/bool/mod.rs b/src/bool/mod.rs index fb5ab6f..9656e83 100644 --- a/src/bool/mod.rs +++ b/src/bool/mod.rs @@ -19,17 +19,10 @@ use crate::{ }; thread_local! { - static BOOL_EVALUATOR: RefCell>, NttBackendU64, ModularOpsU64>, ModularOpsU64>>>> = RefCell::new(None); + static BOOL_EVALUATOR: RefCell>, NttBackendU64, ModularOpsU64>, ModularOpsU64>, ShoupServerKeyEvaluationDomain>>>>> = RefCell::new(None); } -static BOOL_SERVER_KEY: OnceLock< - ShoupServerKeyEvaluationDomain< - Vec>, - BoolParameters, - DefaultSecureRng, - NttBackendU64, - >, -> = OnceLock::new(); +static BOOL_SERVER_KEY: OnceLock>>> = OnceLock::new(); static MULTI_PARTY_CRS: OnceLock> = OnceLock::new(); @@ -44,14 +37,7 @@ pub fn set_mp_seed(seed: [u8; 32]) { ) } -fn set_server_key( - key: ShoupServerKeyEvaluationDomain< - Vec>, - BoolParameters, - DefaultSecureRng, - NttBackendU64, - >, -) { +fn set_server_key(key: ShoupServerKeyEvaluationDomain>>) { assert!( BOOL_SERVER_KEY.set(key).is_ok(), "Attempted to set server key twice." @@ -64,7 +50,7 @@ pub(crate) fn gen_keys() -> ( ) { BoolEvaluator::with_local_mut(|e| { let ck = e.client_key(); - let sk = e.server_key(&ck); + let sk = e.single_party_server_key(&ck); (ck, sk) }) @@ -115,15 +101,11 @@ pub fn aggregate_server_key_shares( BoolEvaluator::with_local(|e| e.aggregate_multi_party_server_key_shares(shares)) } -// SERVER KEY EVAL DOMAIN // +// SERVER KEY EVAL (/SHOUP) DOMAIN // impl SeededServerKey>, BoolParameters, [u8; 32]> { pub fn set_server_key(&self) { - set_server_key(ServerKeyEvaluationDomain::< - _, - _, - DefaultSecureRng, - NttBackendU64, - >::from(self)); + let eval = ServerKeyEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from(self); + set_server_key(ShoupServerKeyEvaluationDomain::from(eval)); } } @@ -135,25 +117,9 @@ impl > { pub fn set_server_key(&self) { - set_server_key(ServerKeyEvaluationDomain::< - _, - _, - DefaultSecureRng, - NttBackendU64, - >::from(self)) - } -} - -impl Global - for ShoupServerKeyEvaluationDomain< - Vec>, - BoolParameters, - DefaultSecureRng, - NttBackendU64, - > -{ - fn global() -> &'static Self { - BOOL_SERVER_KEY.get().unwrap() + set_server_key(ShoupServerKeyEvaluationDomain::from( + ServerKeyEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from(self), + )) } } @@ -173,6 +139,7 @@ impl WithLocal NttBackendU64, ModularOpsU64>, ModularOpsU64>, + ShoupServerKeyEvaluationDomain>>, > { fn with_local(func: F) -> R @@ -196,3 +163,10 @@ impl WithLocal BOOL_EVALUATOR.with_borrow_mut(|s| func(s.as_mut().expect("Parameters not set"))) } } + +pub(crate) type RuntimeServerKey = ShoupServerKeyEvaluationDomain>>; +impl Global for RuntimeServerKey { + fn global() -> &'static Self { + BOOL_SERVER_KEY.get().expect("Server key not set!") + } +} diff --git a/src/bool/noise.rs b/src/bool/noise.rs index d2b7491..80f3fd9 100644 --- a/src/bool/noise.rs +++ b/src/bool/noise.rs @@ -1,5 +1,3 @@ -use std::cell::RefCell; - mod test { use itertools::{izip, Itertools}; @@ -7,7 +5,8 @@ mod test { backend::{ArithmeticOps, ModularOpsU64, Modulus}, bool::{ set_parameter_set, BoolEncoding, BoolEvaluator, BooleanGates, CiphertextModulus, - ClientKey, PublicKey, ServerKeyEvaluationDomain, MP_BOOL_PARAMS, SMALL_MP_BOOL_PARAMS, + ClientKey, PublicKey, ServerKeyEvaluationDomain, ShoupServerKeyEvaluationDomain, + MP_BOOL_PARAMS, SMALL_MP_BOOL_PARAMS, }, lwe::{decrypt_lwe, LweSecret}, ntt::NttBackendU64, @@ -15,7 +14,7 @@ mod test { random::DefaultSecureRng, rgsw::RlweSecret, utils::Stats, - Secret, + Ntt, Secret, }; #[test] @@ -26,6 +25,7 @@ mod test { NttBackendU64, ModularOpsU64>, ModularOpsU64>, + ShoupServerKeyEvaluationDomain>>, >::new(SMALL_MP_BOOL_PARAMS); let parties = 2; @@ -84,7 +84,12 @@ mod test { .collect_vec(); let server_key = evaluator.aggregate_multi_party_server_key_shares(&server_key_shares); - let server_key_eval_domain = ServerKeyEvaluationDomain::from(&server_key); + let runtime_server_key = ShoupServerKeyEvaluationDomain::from(ServerKeyEvaluationDomain::< + _, + _, + DefaultSecureRng, + NttBackendU64, + >::from(&server_key)); let mut m0 = false; let mut m1 = true; @@ -99,7 +104,7 @@ mod test { for _ in 0..1000 { let now = std::time::Instant::now(); - let c_out = evaluator.xor(&c_m0, &c_m1, &server_key_eval_domain); + let c_out = evaluator.xor(&c_m0, &c_m1, &runtime_server_key); println!("Gate time: {:?}", now.elapsed()); // mp decrypt diff --git a/src/lib.rs b/src/lib.rs index 8b24fb3..99f8674 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,7 +20,9 @@ mod rgsw; mod shortint; mod utils; -pub use backend::{ArithmeticOps, ModInit, ModularOpsU64, VectorOps}; +pub use backend::{ + ArithmeticLazyOps, ArithmeticOps, ModInit, ModularOpsU64, ShoupMatrixFMA, VectorOps, +}; pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer}; pub use ntt::{Ntt, NttBackendU64, NttInit}; diff --git a/src/rgsw/mod.rs b/src/rgsw/mod.rs index 14bfee7..b4b0139 100644 --- a/src/rgsw/mod.rs +++ b/src/rgsw/mod.rs @@ -95,23 +95,13 @@ pub struct ShoupAutoKeyEvaluationDomain { data: M, } -impl, R, N> +impl, Mod: Modulus, R, N> From<&AutoKeyEvaluationDomain> for ShoupAutoKeyEvaluationDomain -where - M::R: RowMut, - M::MatElement: ToShoup + Copy, { fn from(value: &AutoKeyEvaluationDomain) -> Self { - let (row, col) = value.data.dimension(); - let mut shoup_data = M::zeros(row, col); - - izip!(shoup_data.iter_rows_mut(), value.data.iter_rows()).for_each(|(shoup_r, r)| { - izip!(shoup_r.as_mut().iter_mut(), r.as_ref().iter()).for_each(|(s, e)| { - *s = M::MatElement::to_shoup(*e, value.modulus.q().unwrap()); - }); - }); - - Self { data: shoup_data } + Self { + data: M::to_shoup(&value.data, value.modulus.q().unwrap()), + } } } @@ -328,23 +318,13 @@ pub struct ShoupRgswCiphertextEvaluationDomain { pub(crate) data: M, } -impl, R, N> +impl, Mod: Modulus, R, N> From<&RgswCiphertextEvaluationDomain> for ShoupRgswCiphertextEvaluationDomain -where - M::R: RowMut, - M::MatElement: ToShoup + Copy, { fn from(value: &RgswCiphertextEvaluationDomain) -> Self { - let (row, col) = value.data.dimension(); - let mut shoup_data = M::zeros(row, col); - - izip!(shoup_data.iter_rows_mut(), value.data.iter_rows()).for_each(|(shoup_r, r)| { - izip!(shoup_r.as_mut().iter_mut(), r.as_ref().iter()).for_each(|(s, e)| { - *s = M::MatElement::to_shoup(*e, value.modulus.q().unwrap()); - }); - }); - - Self { data: shoup_data } + Self { + data: M::to_shoup(&value.data, value.modulus.q().unwrap()), + } } } diff --git a/src/shortint/mod.rs b/src/shortint/mod.rs index 1d0f0fa..fe3876a 100644 --- a/src/shortint/mod.rs +++ b/src/shortint/mod.rs @@ -106,19 +106,16 @@ mod frontend { use super::FheUint8; - type ShortIntBoolEvaluator = - BoolEvaluator; - mod arithetic { - use crate::bool::{evaluator::BooleanGates, FheBool}; + use crate::bool::{FheBool, RuntimeServerKey}; use super::*; use std::ops::{Add, AddAssign, Div, Mul, Rem, Sub}; impl AddAssign<&FheUint8> for FheUint8 { fn add_assign(&mut self, rhs: &FheUint8) { - ShortIntBoolEvaluator::with_local_mut_mut(&mut |e| { - let key = as BooleanGates>::Key::global(); + BoolEvaluator::with_local_mut_mut(&mut |e| { + let key = RuntimeServerKey::global(); arbitrary_bit_adder(e, self.data_mut(), rhs.data(), false, key); }); } @@ -137,7 +134,7 @@ mod frontend { type Output = FheUint8; fn sub(self, rhs: &FheUint8) -> Self::Output { BoolEvaluator::with_local_mut(|e| { - let key = ServerKeyEvaluationDomain::global(); + let key = RuntimeServerKey::global(); let (out, _, _) = arbitrary_bit_subtractor(e, self.data(), rhs.data(), key); FheUint8 { data: out } }) @@ -148,7 +145,7 @@ mod frontend { type Output = FheUint8; fn mul(self, rhs: &FheUint8) -> Self::Output { BoolEvaluator::with_local_mut(|e| { - let key = ServerKeyEvaluationDomain::global(); + let key = RuntimeServerKey::global(); let out = eight_bit_mul(e, self.data(), rhs.data(), key); FheUint8 { data: out } }) @@ -160,7 +157,7 @@ mod frontend { fn div(self, rhs: &FheUint8) -> Self::Output { // TODO(Jay:) Figure out how to set zero error flag BoolEvaluator::with_local_mut(|e| { - let key = ServerKeyEvaluationDomain::global(); + let key = RuntimeServerKey::global(); let (quotient, _) = arbitrary_bit_division_for_quotient_and_rem( e, self.data(), @@ -176,7 +173,7 @@ mod frontend { type Output = FheUint8; fn rem(self, rhs: &FheUint8) -> Self::Output { BoolEvaluator::with_local_mut(|e| { - let key = ServerKeyEvaluationDomain::global(); + let key = RuntimeServerKey::global(); let (_, remainder) = arbitrary_bit_division_for_quotient_and_rem( e, self.data(), @@ -191,7 +188,7 @@ mod frontend { impl FheUint8 { pub fn overflowing_add_assign(&mut self, rhs: &FheUint8) -> FheBool { BoolEvaluator::with_local_mut_mut(&mut |e| { - let key = ServerKeyEvaluationDomain::global(); + let key = RuntimeServerKey::global(); let (overflow, _) = arbitrary_bit_adder(e, self.data_mut(), rhs.data(), false, key); overflow @@ -201,7 +198,7 @@ mod frontend { pub fn overflowing_add(self, rhs: &FheUint8) -> (FheUint8, FheBool) { BoolEvaluator::with_local_mut(|e| { let mut lhs = self.clone(); - let key = ServerKeyEvaluationDomain::global(); + let key = RuntimeServerKey::global(); let (overflow, _) = arbitrary_bit_adder(e, lhs.data_mut(), rhs.data(), false, key); (lhs, overflow) @@ -210,7 +207,7 @@ mod frontend { pub fn overflowing_sub(&self, rhs: &FheUint8) -> (FheUint8, FheBool) { BoolEvaluator::with_local_mut(|e| { - let key = ServerKeyEvaluationDomain::global(); + let key = RuntimeServerKey::global(); let (out, mut overflow, _) = arbitrary_bit_subtractor(e, self.data(), rhs.data(), key); e.not_inplace(&mut overflow); @@ -221,7 +218,7 @@ mod frontend { pub fn div_rem(&self, rhs: &FheUint8) -> (FheUint8, FheUint8) { // TODO(Jay:) Figure out how to set zero error flag BoolEvaluator::with_local_mut(|e| { - let key = ServerKeyEvaluationDomain::global(); + let key = RuntimeServerKey::global(); let (quotient, remainder) = arbitrary_bit_division_for_quotient_and_rem( e, self.data(), @@ -236,7 +233,7 @@ mod frontend { mod booleans { use crate::{ - bool::{evaluator::BooleanGates, FheBool}, + bool::{evaluator::BooleanGates, FheBool, RuntimeServerKey}, shortint::ops::{ arbitrary_bit_comparator, arbitrary_bit_equality, arbitrary_signed_bit_comparator, }, @@ -248,7 +245,7 @@ mod frontend { /// a == b pub fn eq(&self, other: &FheUint8) -> FheBool { BoolEvaluator::with_local_mut(|e| { - let key = ServerKeyEvaluationDomain::global(); + let key = RuntimeServerKey::global(); arbitrary_bit_equality(e, self.data(), other.data(), key) }) } @@ -256,7 +253,7 @@ mod frontend { /// a != b pub fn neq(&self, other: &FheUint8) -> FheBool { BoolEvaluator::with_local_mut(|e| { - let key = ServerKeyEvaluationDomain::global(); + let key = RuntimeServerKey::global(); let mut is_equal = arbitrary_bit_equality(e, self.data(), other.data(), key); e.not_inplace(&mut is_equal); is_equal @@ -266,7 +263,7 @@ mod frontend { /// a < b pub fn lt(&self, other: &FheUint8) -> FheBool { BoolEvaluator::with_local_mut(|e| { - let key = ServerKeyEvaluationDomain::global(); + let key = RuntimeServerKey::global(); arbitrary_bit_comparator(e, other.data(), self.data(), key) }) } @@ -274,7 +271,7 @@ mod frontend { /// a > b pub fn gt(&self, other: &FheUint8) -> FheBool { BoolEvaluator::with_local_mut(|e| { - let key = ServerKeyEvaluationDomain::global(); + let key = RuntimeServerKey::global(); arbitrary_bit_comparator(e, self.data(), other.data(), key) }) } @@ -282,7 +279,7 @@ mod frontend { /// a <= b pub fn le(&self, other: &FheUint8) -> FheBool { BoolEvaluator::with_local_mut(|e| { - let key = ServerKeyEvaluationDomain::global(); + let key = RuntimeServerKey::global(); let mut a_greater_b = arbitrary_bit_comparator(e, self.data(), other.data(), key); e.not_inplace(&mut a_greater_b); @@ -293,7 +290,7 @@ mod frontend { /// a >= b pub fn ge(&self, other: &FheUint8) -> FheBool { BoolEvaluator::with_local_mut(|e| { - let key = ServerKeyEvaluationDomain::global(); + let key = RuntimeServerKey::global(); let mut a_less_b = arbitrary_bit_comparator(e, other.data(), self.data(), key); e.not_inplace(&mut a_less_b); a_less_b diff --git a/src/utils.rs b/src/utils.rs index 532bc09..f1c6832 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,11 +1,12 @@ -use std::{fmt::Debug, usize}; +use std::{fmt::Debug, usize, vec}; -use itertools::Itertools; +use itertools::{izip, Itertools}; use num_traits::{FromPrimitive, PrimInt, Signed, Unsigned}; use crate::{ backend::Modulus, random::{RandomElement, RandomElementInModulus, RandomFill}, + Matrix, }; pub trait WithLocal { fn with_local(func: F) -> R @@ -30,10 +31,6 @@ pub(crate) trait ShoupMul { fn mul(a: Self, b: Self, b_shoup: Self, q: Self) -> Self; } -pub(crate) trait ToShoup { - fn to_shoup(value: Self, modulus: Self) -> Self; -} - impl ShoupMul for u64 { #[inline] fn representation(value: Self, q: Self) -> Self { @@ -48,9 +45,29 @@ impl ShoupMul for u64 { } } +pub(crate) trait ToShoup { + type Modulus; + fn to_shoup(value: &Self, modulus: Self::Modulus) -> Self; +} + impl ToShoup for u64 { - fn to_shoup(value: Self, modulus: Self) -> Self { - ((value as u128 * (1u128 << 64)) / modulus as u128) as u64 + type Modulus = u64; + fn to_shoup(value: &Self, modulus: Self) -> Self { + ((*value as u128 * (1u128 << 64)) / modulus as u128) as u64 + } +} + +impl ToShoup for Vec> { + type Modulus = u64; + fn to_shoup(value: &Self, modulus: Self::Modulus) -> Self { + let (row, col) = value.dimension(); + let mut shoup_value = vec![vec![0u64; col]; row]; + izip!(shoup_value.iter_mut(), value.iter()).for_each(|(shoup_r, r)| { + izip!(shoup_r.iter_mut(), r.iter()).for_each(|(s, e)| { + *s = u64::to_shoup(e, modulus); + }) + }); + shoup_value } }