diff --git a/src/backend.rs b/src/backend.rs index 5f0fa09..a66a227 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -7,6 +7,8 @@ pub trait Modulus { type Element; /// Modulus value if it fits in Element fn q(&self) -> Option; + /// Modulus value as f64 if it fits in f64 + fn q_as_f64(&self) -> Option; /// Is modulus native? fn is_native(&self) -> bool; /// -1 in signed representaiton @@ -17,11 +19,11 @@ pub trait Modulus { /// Always assmed to be 0. fn smallest_unsigned_value(&self) -> Self::Element; /// Convert unsigned value in signed represetation to i64 - fn to_i64(&self, v: &Self::Element) -> i64; + fn map_element_to_i64(&self, v: &Self::Element) -> i64; /// Convert f64 to signed represented in modulus - fn from_f64(&self, v: f64) -> Self::Element; + fn map_element_from_f64(&self, v: f64) -> Self::Element; /// Convert i64 to signed represented in modulus - fn from_i64(&self, v: i64) -> Self::Element; + fn map_element_from_i64(&self, v: i64) -> Self::Element; } impl Modulus for u64 { @@ -39,7 +41,7 @@ impl Modulus for u64 { fn smallest_unsigned_value(&self) -> Self::Element { 0 } - fn to_i64(&self, v: &Self::Element) -> i64 { + fn map_element_to_i64(&self, v: &Self::Element) -> i64 { assert!(v < self); if *v > (self >> 1) { @@ -48,7 +50,7 @@ impl Modulus for u64 { ToPrimitive::to_i64(v).unwrap() } } - fn from_f64(&self, v: f64) -> Self::Element { + fn map_element_from_f64(&self, v: f64) -> Self::Element { //FIXME (Jay): Before I check whether v is smaller than 0 with `let is_neg = // o.is_sign_negative() && o != 0.0; I'm ocnfused why didn't I simply check < // 0.0? @@ -59,7 +61,7 @@ impl Modulus for u64 { v.to_u64().unwrap() } } - fn from_i64(&self, v: i64) -> Self::Element { + fn map_element_from_i64(&self, v: i64) -> Self::Element { if v < 0 { self - v.to_u64().unwrap() } else { @@ -69,6 +71,9 @@ impl Modulus for u64 { fn q(&self) -> Option { Some(*self) } + fn q_as_f64(&self) -> Option { + self.to_f64() + } } pub trait ModInit { diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index bf506ab..9f07396 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -3,6 +3,7 @@ use std::{ collections::HashMap, fmt::{Debug, Display}, marker::PhantomData, + ops::Shr, }; use itertools::{izip, partition, Itertools}; @@ -110,26 +111,45 @@ trait BoolEncoding { fn decode(&self, m: Self::Element) -> bool; } -impl BoolEncoding for CiphertextModulus { +impl BoolEncoding for CiphertextModulus +where + CiphertextModulus: Modulus, + T: PrimInt, +{ type Element = T; - fn false_el(&self) -> Self::Element { - todo!() - } + fn qby4(&self) -> Self::Element { - todo!() + if self.is_native() { + T::one() << (CiphertextModulus::::_bits() - 2) + } else { + self.q().unwrap() >> 2 + } } + /// Q/8 fn true_el(&self) -> Self::Element { - todo!() + if self.is_native() { + T::one() << (CiphertextModulus::::_bits() - 3) + } else { + self.q().unwrap() >> 3 + } + } + /// -Q/8 + fn false_el(&self) -> Self::Element { + self.largest_unsigned_value() - self.true_el() + T::one() } fn decode(&self, m: Self::Element) -> bool { - // let m = (((encoded_m + self.pbs_info.rlweq_by8).to_f64().unwrap() * 4f64) - // / self.pbs_info.parameters.rlwe_q().0.to_f64().unwrap()) - // .round() as usize - // % 4usize; - // panic!("Incorrect bool decryption. Got m={m} but expected m to be - // 0 or 1") - - todo!() + let qby8 = self.true_el(); + let m = (((m + qby8).to_f64().unwrap() * 4.0f64) / self.q_as_f64().unwrap()).round() + as usize + % 4usize; + + if m == 0 { + return false; + } else if m == 1 { + return true; + } else { + panic!("Incorrect bool decryption. Got m={m} but expected m to be 0 or 1") + } } } @@ -304,119 +324,6 @@ struct SeededMultiPartyServerKey { parameters: P, } -fn aggregate_multi_party_server_key_shares< - M: MatrixMut + MatrixEntity, - S: Copy + PartialEq, - D: RlweDecomposer, - ModOp: VectorOps + ModInit>, - NttOp: Ntt + NttInit>, ->( - shares: &[CommonReferenceSeededMultiPartyServerKeyShare, S>], -) -> SeededMultiPartyServerKey> -where - ::R: RowMut + RowEntity, - M::MatElement: PrimInt + PartialEq + Zero, - M: Clone, -{ - assert!(shares.len() > 0); - let parameters = shares[0].parameters.clone(); - let cr_seed = shares[0].cr_seed; - - let rlwe_n = parameters.rlwe_n().0; - let g = parameters.g() as isize; - let rlwe_q = parameters.rlwe_q(); - let lwe_q = parameters.lwe_q(); - - // sanity checks - shares.iter().skip(1).for_each(|s| { - assert!(s.parameters == parameters); - assert!(s.cr_seed == cr_seed); - }); - - let rlweq_modop = ModOp::new(*rlwe_q); - let rlweq_nttop = NttOp::new(rlwe_q, rlwe_n); - - // auto keys - let mut auto_keys = HashMap::new(); - for i in [g, -g] { - let mut key = M::zeros(parameters.auto_decomposition_count().0, rlwe_n); - - shares.iter().for_each(|s| { - let auto_key_share_i = s.auto_keys.get(&i).expect("Auto key {i} missing"); - assert!( - auto_key_share_i.dimension() == (parameters.auto_decomposition_count().0, rlwe_n) - ); - izip!(key.iter_rows_mut(), auto_key_share_i.iter_rows()).for_each( - |(partb_out, partb_share)| { - rlweq_modop.elwise_add_mut(partb_out.as_mut(), partb_share.as_ref()); - }, - ); - }); - - auto_keys.insert(i, key); - } - - // rgsw ciphertext (most expensive part!) - let lwe_n = parameters.lwe_n().0; - let rgsw_by_rgsw_decomposer = parameters.rgsw_rgsw_decomposer::(); - let mut scratch_matrix = M::zeros( - std::cmp::max( - rgsw_by_rgsw_decomposer.a().decomposition_count(), - rgsw_by_rgsw_decomposer.b().decomposition_count(), - ) + (rgsw_by_rgsw_decomposer.a().decomposition_count() * 2 - + rgsw_by_rgsw_decomposer.b().decomposition_count() * 2), - rlwe_n, - ); - - let mut tmp_rgsw = - RgswCiphertext::::empty(rlwe_n, &rgsw_by_rgsw_decomposer, rlwe_q.clone()).data; - let rgsw_cts = (0..lwe_n) - .into_iter() - .map(|index| { - // copy over rgsw ciphertext for index^th secret element from first share and - // treat it as accumulating rgsw ciphertext - let mut rgsw_i = shares[0].rgsw_cts[index].clone(); - - shares.iter().skip(1).for_each(|si| { - // copy over si's RGSW[index] ciphertext and send to evaluation domain - izip!(tmp_rgsw.iter_rows_mut(), si.rgsw_cts[index].iter_rows()).for_each( - |(to_ri, from_ri)| { - to_ri.as_mut().copy_from_slice(from_ri.as_ref()); - rlweq_nttop.forward(to_ri.as_mut()) - }, - ); - - rgsw_by_rgsw_inplace( - &mut rgsw_i, - &tmp_rgsw, - &rgsw_by_rgsw_decomposer, - &mut scratch_matrix, - &rlweq_nttop, - &rlweq_modop, - ); - }); - - rgsw_i - }) - .collect_vec(); - - // LWE ksks - let mut lwe_ksk = M::R::zeros(rlwe_n * parameters.lwe_decomposition_count().0); - let lweq_modop = ModOp::new(*lwe_q); - shares.iter().for_each(|si| { - assert!(si.lwe_ksk.as_ref().len() == rlwe_n * parameters.lwe_decomposition_count().0); - lweq_modop.elwise_add_mut(lwe_ksk.as_mut(), si.lwe_ksk.as_ref()) - }); - - SeededMultiPartyServerKey { - rgsw_cts, - auto_keys, - lwe_ksk, - cr_seed, - parameters: parameters, - } -} - /// Seeded single party server key struct SeededServerKey { /// Rgsw cts of LWE secret elements @@ -792,7 +699,7 @@ struct BoolPbsInfo { impl PbsInfo for BoolPbsInfo where - M::MatElement: PrimInt + WrappingSub + NumInfo + Debug, + M::MatElement: PrimInt + WrappingSub + NumInfo + Debug + FromPrimitive, RlweModOp: ArithmeticOps + VectorOps, LweModOp: ArithmeticOps + VectorOps, NttOp: Ntt, @@ -1382,6 +1289,119 @@ where self.pbs_info.rlwe_q().decode(m) } + fn aggregate_multi_party_server_key_shares( + &self, + shares: &[CommonReferenceSeededMultiPartyServerKeyShare< + M, + BoolParameters, + S, + >], + ) -> SeededMultiPartyServerKey> + where + S: PartialEq + Clone, + M: Clone, + { + assert!(shares.len() > 0); + let parameters = shares[0].parameters.clone(); + let cr_seed = &shares[0].cr_seed; + + let rlwe_n = parameters.rlwe_n().0; + let g = parameters.g() as isize; + let rlwe_q = parameters.rlwe_q(); + let lwe_q = parameters.lwe_q(); + + // sanity checks + shares.iter().skip(1).for_each(|s| { + assert!(s.parameters == parameters); + assert!(&s.cr_seed == cr_seed); + }); + + let rlweq_modop = &self.pbs_info.rlwe_modop; + let rlweq_nttop = &self.pbs_info.rlwe_nttop; + + // auto keys + let mut auto_keys = HashMap::new(); + for i in [g, -g] { + let mut key = M::zeros(parameters.auto_decomposition_count().0, rlwe_n); + + shares.iter().for_each(|s| { + let auto_key_share_i = s.auto_keys.get(&i).expect("Auto key {i} missing"); + assert!( + auto_key_share_i.dimension() + == (parameters.auto_decomposition_count().0, rlwe_n) + ); + izip!(key.iter_rows_mut(), auto_key_share_i.iter_rows()).for_each( + |(partb_out, partb_share)| { + rlweq_modop.elwise_add_mut(partb_out.as_mut(), partb_share.as_ref()); + }, + ); + }); + + auto_keys.insert(i, key); + } + + // rgsw ciphertext (most expensive part!) + let lwe_n = parameters.lwe_n().0; + let rgsw_by_rgsw_decomposer = + parameters.rgsw_rgsw_decomposer::>(); + let mut scratch_matrix = M::zeros( + std::cmp::max( + rgsw_by_rgsw_decomposer.a().decomposition_count(), + rgsw_by_rgsw_decomposer.b().decomposition_count(), + ) + (rgsw_by_rgsw_decomposer.a().decomposition_count() * 2 + + rgsw_by_rgsw_decomposer.b().decomposition_count() * 2), + rlwe_n, + ); + + let mut tmp_rgsw = + RgswCiphertext::::empty(rlwe_n, &rgsw_by_rgsw_decomposer, rlwe_q.clone()).data; + let rgsw_cts = (0..lwe_n) + .into_iter() + .map(|index| { + // copy over rgsw ciphertext for index^th secret element from first share and + // treat it as accumulating rgsw ciphertext + let mut rgsw_i = shares[0].rgsw_cts[index].clone(); + + shares.iter().skip(1).for_each(|si| { + // copy over si's RGSW[index] ciphertext and send to evaluation domain + izip!(tmp_rgsw.iter_rows_mut(), si.rgsw_cts[index].iter_rows()).for_each( + |(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + rlweq_nttop.forward(to_ri.as_mut()) + }, + ); + + rgsw_by_rgsw_inplace( + &mut rgsw_i, + &tmp_rgsw, + &rgsw_by_rgsw_decomposer, + &mut scratch_matrix, + rlweq_nttop, + rlweq_modop, + ); + }); + + rgsw_i + }) + .collect_vec(); + + // LWE ksks + let mut lwe_ksk = M::R::zeros(rlwe_n * parameters.lwe_decomposition_count().0); + let lweq_modop = &self.pbs_info.lwe_modop; + shares.iter().for_each(|si| { + assert!(si.lwe_ksk.as_ref().len() == rlwe_n * parameters.lwe_decomposition_count().0); + lweq_modop.elwise_add_mut(lwe_ksk.as_mut(), si.lwe_ksk.as_ref()) + }); + + SeededMultiPartyServerKey { + rgsw_cts, + auto_keys, + lwe_ksk, + cr_seed: cr_seed.clone(), + parameters: parameters, + } + } + // TODO(Jay): scratch spaces must be thread local. Don't pass them as arguments pub fn nand( &mut self, @@ -1559,8 +1579,8 @@ fn pbs< let rlwe_q = pbs_info.rlwe_q(); let lwe_q = pbs_info.lwe_q(); let br_q = pbs_info.br_q(); - let rlwe_qf64 = rlwe_q.q().unwrap().to_f64().unwrap(); - let lwe_qf64 = lwe_q.q().unwrap().to_f64().unwrap(); + let rlwe_qf64 = rlwe_q.q_as_f64().unwrap(); + let lwe_qf64 = lwe_q.q_as_f64().unwrap(); let br_qf64 = br_q.to_f64().unwrap(); let rlwe_n = pbs_info.rlwe_n(); @@ -1794,65 +1814,1078 @@ fn monomial_mul>( }); } -// thread_local! { -// static PBS_TRACER: RefCell>>> = -// RefCell::new(PBSTracer::default()); } - -// #[derive(Default)] -// struct PBSTracer -// where -// M: Matrix + Default, -// { -// pub(crate) ct_rlwe_q_mod: M::R, -// pub(crate) ct_lwe_q_mod: M::R, -// pub(crate) ct_lwe_q_mod_after_ksk: M::R, -// pub(crate) ct_br_q_mod: Vec, -// } +thread_local! { + static PBS_TRACER: RefCell>>> = +RefCell::new(PBSTracer::default()); } -// impl PBSTracer>> { -// fn trace(&self, parameters: &BoolParameters, sk_lwe: &[i32], -// sk_rlwe: &[i32]) { assert!(parameters.rlwe_n().0 == sk_rlwe.len()); -// assert!(parameters.lwe_n().0 == sk_lwe.len()); - -// let modop_rlweq = ModularOpsU64::new(*parameters.rlwe_q()); -// // noise after mod down Q -> Q_ks -// let m_back0 = decrypt_lwe(&self.ct_rlwe_q_mod, sk_rlwe, -// &modop_rlweq); - -// let modop_lweq = ModularOpsU64::new(parameters.lwe_q().0); -// // noise after mod down Q -> Q_ks -// let m_back1 = decrypt_lwe(&self.ct_lwe_q_mod, sk_rlwe, &modop_lweq); -// // noise after key switch from RLWE -> LWE -// let m_back2 = decrypt_lwe(&self.ct_lwe_q_mod_after_ksk, sk_lwe, -// &modop_lweq); - -// // noise after mod down odd from Q_ks -> q -// let modop_br_q = ModularOpsU64::new(parameters.br_q().0); -// let m_back3 = decrypt_lwe(&self.ct_br_q_mod, sk_lwe, &modop_br_q); - -// println!( -// " -// M initial mod Q: {m_back0}, -// M after mod down Q -> Q_ks: {m_back1}, -// M after key switch from RLWE -> LWE: {m_back2}, -// M after mod dwon Q_ks -> q: {m_back3} -// " -// ); -// } -// } +#[derive(Default)] +struct PBSTracer +where + M: Matrix + Default, +{ + pub(crate) ct_rlwe_q_mod: M::R, + pub(crate) ct_lwe_q_mod: M::R, + pub(crate) ct_lwe_q_mod_after_ksk: M::R, + pub(crate) ct_br_q_mod: Vec, +} -// impl WithLocal for PBSTracer>> { -// fn with_local(func: F) -> R -// where -// F: Fn(&Self) -> R, -// { -// PBS_TRACER.with_borrow(|t| func(t)) -// } +impl PBSTracer>> { + fn trace(&self, parameters: &BoolParameters, sk_lwe: &[i32], sk_rlwe: &[i32]) { + assert!(parameters.rlwe_n().0 == sk_rlwe.len()); + assert!(parameters.lwe_n().0 == sk_lwe.len()); + + let modop_rlweq = ModularOpsU64::new(*parameters.rlwe_q()); + // noise after mod down Q -> Q_ks + let m_back0 = decrypt_lwe(&self.ct_rlwe_q_mod, sk_rlwe, &modop_rlweq); + + let modop_lweq = ModularOpsU64::>::new(*parameters.lwe_q()); + // noise after mod down Q -> Q_ks + let m_back1 = decrypt_lwe(&self.ct_lwe_q_mod, sk_rlwe, &modop_lweq); + // noise after key switch from RLWE -> LWE + let m_back2 = decrypt_lwe(&self.ct_lwe_q_mod_after_ksk, sk_lwe, &modop_lweq); + + // noise after mod down odd from Q_ks -> q + let modop_br_q = ModularOpsU64::::new(*parameters.br_q() as u64); + let m_back3 = decrypt_lwe(&self.ct_br_q_mod, sk_lwe, &modop_br_q); + + println!( + " + M initial mod Q: {m_back0}, + M after mod down Q -> Q_ks: {m_back1}, + M after key switch from RLWE -> LWE: {m_back2}, + M after mod dwon Q_ks -> q: {m_back3} + " + ); + } +} -// fn with_local_mut(func: F) -> R -// where -// F: Fn(&mut Self) -> R, -// { -// PBS_TRACER.with_borrow_mut(|t| func(t)) -// } -// } +impl WithLocal for PBSTracer>> { + fn with_local(func: F) -> R + where + F: Fn(&Self) -> R, + { + PBS_TRACER.with_borrow(|t| func(t)) + } + + fn with_local_mut(func: F) -> R + where + F: Fn(&mut Self) -> R, + { + PBS_TRACER.with_borrow_mut(|t| func(t)) + } +} + +#[cfg(test)] +mod tests { + use std::iter::Sum; + + use rand::{thread_rng, Rng}; + use rand_distr::Uniform; + + use crate::{ + backend::ModularOpsU64, + bool, + ntt::NttBackendU64, + random::DEFAULT_RNG, + rgsw::{ + self, measure_noise, public_key_encrypt_rlwe, secret_key_encrypt_rlwe, + tests::{_measure_noise_rgsw, _sk_encrypt_rlwe}, + RgswCiphertext, RgswCiphertextEvaluationDomain, SeededRgswCiphertext, + SeededRlweCiphertext, + }, + utils::{negacyclic_mul, Stats}, + }; + + use super::*; + + #[test] + fn bool_encrypt_decrypt_works() { + let bool_evaluator = BoolEvaluator::< + Vec>, + NttBackendU64, + ModularOpsU64>, + ModularOpsU64>, + >::new(SP_BOOL_PARAMS); + let client_key = bool_evaluator.client_key(); + + let mut m = true; + for _ in 0..1000 { + let lwe_ct = bool_evaluator.sk_encrypt(m, &client_key); + let m_back = bool_evaluator.sk_decrypt(&lwe_ct, &client_key); + assert_eq!(m, m_back); + m = !m; + } + } + + #[test] + fn bool_nand() { + DefaultSecureRng::with_local_mut(|r| { + let rng = DefaultSecureRng::new_seeded([19u8; 32]); + *r = rng; + }); + + let mut bool_evaluator = BoolEvaluator::< + Vec>, + NttBackendU64, + ModularOpsU64>, + ModularOpsU64>, + >::new(SP_BOOL_PARAMS); + + // println!("{:?}", bool_evaluator.nand_test_vec); + let client_key = bool_evaluator.client_key(); + let seeded_server_key = bool_evaluator.server_key(&client_key); + let server_key_eval_domain = + ServerKeyEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from( + &seeded_server_key, + ); + + let mut m0 = false; + let mut m1 = true; + let mut ct0 = bool_evaluator.sk_encrypt(m0, &client_key); + let mut ct1 = bool_evaluator.sk_encrypt(m1, &client_key); + + for _ in 0..1000 { + let ct_back = bool_evaluator.nand(&ct0, &ct1, &server_key_eval_domain); + + let m_out = !(m0 && m1); + + // Trace and measure PBS noise + { + let noise0 = { + let ideal = if m0 { + bool_evaluator.pbs_info.parameters.rlwe_q().true_el() + } else { + bool_evaluator.pbs_info.parameters.rlwe_q().false_el() + }; + let n = measure_noise_lwe( + &ct0, + client_key.sk_rlwe.values(), + &bool_evaluator.pbs_info.rlwe_modop, + &ideal, + ); + let v = decrypt_lwe( + &ct0, + client_key.sk_rlwe.values(), + &bool_evaluator.pbs_info.rlwe_modop, + ); + (n, v) + }; + let noise1 = { + let ideal = if m1 { + bool_evaluator.pbs_info.parameters.rlwe_q().true_el() + } else { + bool_evaluator.pbs_info.parameters.rlwe_q().false_el() + }; + let n = measure_noise_lwe( + &ct1, + client_key.sk_rlwe.values(), + &bool_evaluator.pbs_info.rlwe_modop, + &ideal, + ); + let v = decrypt_lwe( + &ct1, + client_key.sk_rlwe.values(), + &bool_evaluator.pbs_info.rlwe_modop, + ); + (n, v) + }; + + // // Trace PBS + // PBSTracer::with_local(|t| { + // t.trace( + // &SP_BOOL_PARAMS, + // &client_key.sk_lwe.values(), + // client_key.sk_rlwe.values(), + // ) + // }); + + // Calculate noise in ciphertext post PBS + let noise_out = { + let ideal = if m_out { + bool_evaluator.pbs_info.parameters.rlwe_q().true_el() + } else { + bool_evaluator.pbs_info.parameters.rlwe_q().false_el() + }; + let n = measure_noise_lwe( + &ct_back, + client_key.sk_rlwe.values(), + &bool_evaluator.pbs_info.rlwe_modop, + &ideal, + ); + let v = decrypt_lwe( + &ct_back, + client_key.sk_rlwe.values(), + &bool_evaluator.pbs_info.rlwe_modop, + ); + (n, v) + }; + dbg!(m0, m1, m_out); + println!( + "ct0 (noise, message): {:?} \n ct1 (noise, message): {:?} \n PBS (noise, message): {:?}", noise0, noise1, noise_out + ); + } + let m_back = bool_evaluator.sk_decrypt(&ct_back, &client_key); + assert!(m_out == m_back, "Expected {m_out}, got {m_back}"); + println!("----------"); + + m1 = m0; + m0 = m_out; + + ct1 = ct0; + ct0 = ct_back; + } + } + + #[test] + fn multi_party_encryption_decryption() { + let bool_evaluator = BoolEvaluator::< + Vec>, + NttBackendU64, + ModularOpsU64>, + ModularOpsU64>, + >::new(MP_BOOL_PARAMS); + + let no_of_parties = 500; + let parties = (0..no_of_parties) + .map(|_| bool_evaluator.client_key()) + .collect_vec(); + + let mut ideal_rlwe_sk = vec![0i32; bool_evaluator.pbs_info.rlwe_n()]; + parties.iter().for_each(|k| { + izip!(ideal_rlwe_sk.iter_mut(), k.sk_rlwe.values()).for_each(|(ideal_i, s_i)| { + *ideal_i = *ideal_i + s_i; + }); + }); + + println!("{:?}", &ideal_rlwe_sk); + + let mut m = true; + for i in 0..100 { + let pk_cr_seed = [0u8; 32]; + + let public_key_share = parties + .iter() + .map(|k| bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) + .collect_vec(); + + let collective_pk = PublicKey::< + Vec>, + DefaultSecureRng, + ModularOpsU64>, + >::from(public_key_share.as_slice()); + let lwe_ct = bool_evaluator.pk_encrypt(&collective_pk.key, m); + + let decryption_shares = parties + .iter() + .map(|k| bool_evaluator.multi_party_decryption_share(&lwe_ct, k)) + .collect_vec(); + + let m_back = bool_evaluator.multi_party_decrypt(&decryption_shares, &lwe_ct); + + { + let ideal_m = if m { + bool_evaluator.pbs_info.parameters.rlwe_q().true_el() + } else { + bool_evaluator.pbs_info.parameters.rlwe_q().false_el() + }; + let noise = measure_noise_lwe( + &lwe_ct, + &ideal_rlwe_sk, + &bool_evaluator.pbs_info.rlwe_modop, + &ideal_m, + ); + println!("Noise: {noise}"); + } + + assert_eq!(m_back, m); + m = !m; + } + } + + fn _collecitve_public_key_gen(rlwe_q: u64, parties_rlwe_sk: &[RlweSecret]) -> Vec> { + let ring_size = parties_rlwe_sk[0].values.len(); + assert!(ring_size.is_power_of_two()); + let mut rng = DefaultSecureRng::new(); + let nttop = NttBackendU64::new(&rlwe_q, ring_size); + let modop = ModularOpsU64::new(rlwe_q); + + // Generate Pk shares + let pk_seed = [0u8; 32]; + let pk_shares = parties_rlwe_sk.iter().map(|sk| { + let mut p_rng = DefaultSecureRng::new_seeded(pk_seed); + let mut share_out = vec![0u64; ring_size]; + public_key_share( + &mut share_out, + sk.values(), + &modop, + &nttop, + &mut p_rng, + &mut rng, + ); + share_out + }); + + let mut pk_part_b = vec![0u64; ring_size]; + pk_shares.for_each(|share| modop.elwise_add_mut(&mut pk_part_b, &share)); + let mut pk_part_a = vec![0u64; ring_size]; + let mut p_rng = DefaultSecureRng::new_seeded(pk_seed); + RandomFillUniformInModulus::random_fill(&mut p_rng, &rlwe_q, pk_part_a.as_mut_slice()); + + vec![pk_part_a, pk_part_b] + } + + fn _multi_party_all_keygen( + bool_evaluator: &BoolEvaluator< + Vec>, + NttBackendU64, + ModularOpsU64>, + ModularOpsU64>, + >, + no_of_parties: usize, + ) -> ( + Vec, + PublicKey>, DefaultSecureRng, ModularOpsU64>>, + Vec< + CommonReferenceSeededMultiPartyServerKeyShare< + Vec>, + BoolParameters, + [u8; 32], + >, + >, + SeededMultiPartyServerKey>, [u8; 32], BoolParameters>, + ServerKeyEvaluationDomain>, DefaultSecureRng, NttBackendU64>, + ClientKey, + ) { + let parties = (0..no_of_parties) + .map(|_| bool_evaluator.client_key()) + .collect_vec(); + + let mut rng = DefaultSecureRng::new(); + + // Collective public key + let mut pk_cr_seed = [0u8; 32]; + rng.fill_bytes(&mut pk_cr_seed); + let public_key_share = parties + .iter() + .map(|k| bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) + .collect_vec(); + let collective_pk = + PublicKey::>, DefaultSecureRng, _>::from(public_key_share.as_slice()); + + // Server key + let mut pbs_cr_seed = [0u8; 32]; + rng.fill_bytes(&mut pbs_cr_seed); + let server_key_shares = parties + .iter() + .map(|k| { + bool_evaluator.multi_party_server_key_share(pbs_cr_seed, &collective_pk.key, k) + }) + .collect_vec(); + let seeded_server_key = + bool_evaluator.aggregate_multi_party_server_key_shares(&server_key_shares); + let server_key_eval = ServerKeyEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from( + &seeded_server_key, + ); + + // construct ideal rlwe sk for meauring noise + let ideal_client_key = { + let mut ideal_rlwe_sk = vec![0i32; bool_evaluator.pbs_info.rlwe_n()]; + parties.iter().for_each(|k| { + izip!(ideal_rlwe_sk.iter_mut(), k.sk_rlwe.values()).for_each(|(ideal_i, s_i)| { + *ideal_i = *ideal_i + s_i; + }); + }); + let mut ideal_lwe_sk = vec![0i32; bool_evaluator.pbs_info.lwe_n()]; + parties.iter().for_each(|k| { + izip!(ideal_lwe_sk.iter_mut(), k.sk_lwe.values()).for_each(|(ideal_i, s_i)| { + *ideal_i = *ideal_i + s_i; + }); + }); + + ClientKey { + sk_lwe: LweSecret { + values: ideal_lwe_sk, + }, + sk_rlwe: RlweSecret { + values: ideal_rlwe_sk, + }, + } + }; + + ( + parties, + collective_pk, + server_key_shares, + seeded_server_key, + server_key_eval, + ideal_client_key, + ) + } + + #[test] + fn multi_party_nand() { + let mut bool_evaluator = BoolEvaluator::< + Vec>, + NttBackendU64, + ModularOpsU64>, + ModularOpsU64>, + >::new(MP_BOOL_PARAMS); + + let (parties, collective_pk, _, _, server_key_eval, ideal_client_key) = + _multi_party_all_keygen(&bool_evaluator, 2); + + let mut m0 = true; + let mut m1 = false; + + let mut lwe0 = bool_evaluator.pk_encrypt(&collective_pk.key, m0); + let mut lwe1 = bool_evaluator.pk_encrypt(&collective_pk.key, m1); + + for _ in 0..2000 { + let lwe_out = bool_evaluator.nand(&lwe0, &lwe1, &server_key_eval); + + let m_expected = !(m0 & m1); + + // measure noise + { + let noise0 = { + let ideal = if m0 { + bool_evaluator.pbs_info.rlwe_q().true_el() + } else { + bool_evaluator.pbs_info.rlwe_q().false_el() + }; + let n = measure_noise_lwe( + &lwe0, + ideal_client_key.sk_rlwe.values(), + &bool_evaluator.pbs_info.rlwe_modop, + &ideal, + ); + let v = decrypt_lwe( + &lwe0, + ideal_client_key.sk_rlwe.values(), + &bool_evaluator.pbs_info.rlwe_modop, + ); + (n, v) + }; + let noise1 = { + let ideal = if m1 { + bool_evaluator.pbs_info.rlwe_q().true_el() + } else { + bool_evaluator.pbs_info.rlwe_q().false_el() + }; + let n = measure_noise_lwe( + &lwe1, + ideal_client_key.sk_rlwe.values(), + &bool_evaluator.pbs_info.rlwe_modop, + &ideal, + ); + let v = decrypt_lwe( + &lwe1, + ideal_client_key.sk_rlwe.values(), + &bool_evaluator.pbs_info.rlwe_modop, + ); + (n, v) + }; + + // // Trace PBS + // PBSTracer::with_local(|t| { + // t.trace( + // &MP_BOOL_PARAMS, + // &ideal_client_key.sk_lwe.values(), + // &ideal_client_key.sk_rlwe.values(), + // ) + // }); + + let noise_out = { + let ideal_m = if m_expected { + bool_evaluator.pbs_info.rlwe_q().true_el() + } else { + bool_evaluator.pbs_info.rlwe_q().false_el() + }; + let n = measure_noise_lwe( + &lwe_out, + ideal_client_key.sk_rlwe.values(), + &bool_evaluator.pbs_info.rlwe_modop, + &ideal_m, + ); + let v = decrypt_lwe( + &lwe_out, + ideal_client_key.sk_rlwe.values(), + &bool_evaluator.pbs_info.rlwe_modop, + ); + (n, v) + }; + dbg!(m0, m1, m_expected); + println!( + "ct0 (noise, message): {:?} \n ct1 (noise, message): {:?} \n PBS (noise, message): {:?}", noise0, noise1, noise_out + ); + } + + // multi-party decrypt + let decryption_shares = parties + .iter() + .map(|k| bool_evaluator.multi_party_decryption_share(&lwe_out, k)) + .collect_vec(); + let m_back = bool_evaluator.multi_party_decrypt(&decryption_shares, &lwe_out); + + // let m_back = bool_evaluator.sk_decrypt(&lwe_out, &ideal_client_key); + + assert!(m_expected == m_back, "Expected {m_expected}, got {m_back}"); + m1 = m0; + m0 = m_expected; + + lwe1 = lwe0; + lwe0 = lwe_out; + } + } + + #[test] + fn tester() { + // pub(super) const TEST_MP_BOOL_PARAMS: BoolParameters = + // BoolParameters:: { rlwe_q: 1152921504606830593, + // rlwe_logq: 60, + // lwe_q: 1 << 20, + // lwe_logq: 20, + // br_q: 1 << 11, + // rlwe_n: 1 << 11, + // lwe_n: 500, + // d_rgsw: 4, + // logb_rgsw: 12, + // d_lwe: 5, + // logb_lwe: 4, + // g: 5, + // w: 1, + // }; + + let bool_evaluator = BoolEvaluator::< + Vec>, + NttBackendU64, + ModularOpsU64>, + ModularOpsU64>, + >::new(MP_BOOL_PARAMS); + + // let (_, collective_pk, _, _, server_key_eval, ideal_client_key) = + // _multi_party_all_keygen(&bool_evaluator, 20); + let no_of_parties = 2; + let lwe_q = bool_evaluator.pbs_info.parameters.lwe_q(); + let rlwe_q = bool_evaluator.pbs_info.parameters.rlwe_q(); + let lwe_n = bool_evaluator.pbs_info.parameters.lwe_n().0; + let rlwe_n = bool_evaluator.pbs_info.parameters.rlwe_n().0; + let lwe_modop = &bool_evaluator.pbs_info.lwe_modop; + let rlwe_nttop = &bool_evaluator.pbs_info.rlwe_nttop; + let rlwe_modop = &bool_evaluator.pbs_info.rlwe_modop; + + let rlwe_rgsw_decomposer = &bool_evaluator.pbs_info.rlwe_rgsw_decomposer; + let rlwe_rgsw_gadget_a = rlwe_rgsw_decomposer.0.gadget_vector(); + let rlwe_rgsw_gadget_b = rlwe_rgsw_decomposer.1.gadget_vector(); + + // let rgsw_rgsw_decomposer = &bool_evaluator + // .pbs_info + // .parameters + // .rgsw_rgsw_decomposer::>(); + // let rgsw_rgsw_gagdet_a = rgsw_rgsw_decomposer.a().gadget_vector(); + // let rgsw_rgsw_gagdet_b = rgsw_rgsw_decomposer.b().gadget_vector(); + + let parties = (0..no_of_parties) + .map(|_| bool_evaluator.client_key()) + .collect_vec(); + + let ideal_client_key = { + let mut ideal_rlwe_sk = vec![0i32; bool_evaluator.pbs_info.rlwe_n()]; + parties.iter().for_each(|k| { + izip!(ideal_rlwe_sk.iter_mut(), k.sk_rlwe.values()).for_each(|(ideal_i, s_i)| { + *ideal_i = *ideal_i + s_i; + }); + }); + let mut ideal_lwe_sk = vec![0i32; bool_evaluator.pbs_info.lwe_n()]; + parties.iter().for_each(|k| { + izip!(ideal_lwe_sk.iter_mut(), k.sk_lwe.values()).for_each(|(ideal_i, s_i)| { + *ideal_i = *ideal_i + s_i; + }); + }); + + ClientKey { + sk_lwe: LweSecret { + values: ideal_lwe_sk, + }, + sk_rlwe: RlweSecret { + values: ideal_rlwe_sk, + }, + } + }; + + // check noise in freshly encrypted RLWE ciphertext (ie var_fresh) + if false { + let mut rng = DefaultSecureRng::new(); + let mut check = Stats { samples: vec![] }; + for _ in 0..10 { + // generate a new collective public key + let mut pk_cr_seed = [0u8; 32]; + rng.fill_bytes(&mut pk_cr_seed); + let public_key_share = parties + .iter() + .map(|k| bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) + .collect_vec(); + let collective_pk = PublicKey::< + Vec>, + DefaultSecureRng, + ModularOpsU64>, + >::from(public_key_share.as_slice()); + + let mut m = vec![0u64; rlwe_n]; + RandomFillUniformInModulus::random_fill(&mut rng, rlwe_q, m.as_mut_slice()); + let mut rlwe_ct = vec![vec![0u64; rlwe_n]; 2]; + public_key_encrypt_rlwe::<_, _, _, _, i32, _>( + &mut rlwe_ct, + &collective_pk.key, + &m, + rlwe_modop, + rlwe_nttop, + &mut rng, + ); + + let mut m_back = vec![0u64; rlwe_n]; + decrypt_rlwe( + &rlwe_ct, + ideal_client_key.sk_rlwe.values(), + &mut m_back, + rlwe_nttop, + rlwe_modop, + ); + + rlwe_modop.elwise_sub_mut(m_back.as_mut_slice(), m.as_slice()); + + check.add_more(Vec::::try_convert_from(&m_back, rlwe_q).as_slice()); + } + + println!("Public key Std: {}", check.std_dev().abs().log2()); + } + + if true { + // Generate server key shares + let mut rng = DefaultSecureRng::new(); + let mut pk_cr_seed = [0u8; 32]; + rng.fill_bytes(&mut pk_cr_seed); + let public_key_share = parties + .iter() + .map(|k| bool_evaluator.multi_party_public_key_share(pk_cr_seed, k)) + .collect_vec(); + let collective_pk = PublicKey::< + Vec>, + DefaultSecureRng, + ModularOpsU64>, + >::from(public_key_share.as_slice()); + + let pbs_cr_seed = [0u8; 32]; + rng.fill_bytes(&mut pk_cr_seed); + let server_key_shares = parties + .iter() + .map(|k| { + bool_evaluator.multi_party_server_key_share(pbs_cr_seed, &collective_pk.key, k) + }) + .collect_vec(); + + let seeded_server_key = + bool_evaluator.aggregate_multi_party_server_key_shares(&server_key_shares); + + // Check noise in RGSW ciphertexts of ideal LWE secret elements + if true { + let mut check = Stats { samples: vec![] }; + izip!( + ideal_client_key.sk_lwe.values.iter(), + seeded_server_key.rgsw_cts.iter() + ) + .for_each(|(s_i, rgsw_ct_i)| { + // X^{s[i]} + let mut m_si = vec![0u64; rlwe_n]; + let s_i = *s_i * (bool_evaluator.pbs_info.embedding_factor as i32); + if s_i < 0 { + m_si[rlwe_n - (s_i.abs() as usize)] = rlwe_q.neg_one(); + } else { + m_si[s_i as usize] = 1; + } + + // RLWE(-sm) + let mut neg_s_eval = + Vec::::try_convert_from(ideal_client_key.sk_rlwe.values(), rlwe_q); + rlwe_modop.elwise_neg_mut(&mut neg_s_eval); + rlwe_nttop.forward(&mut neg_s_eval); + for j in 0..rlwe_rgsw_decomposer.a().decomposition_count() { + // -s[X]*X^{s_lwe[i]}*B_j + let mut m_ideal = m_si.clone(); + rlwe_nttop.forward(m_ideal.as_mut_slice()); + rlwe_modop.elwise_mul_mut(m_ideal.as_mut_slice(), neg_s_eval.as_slice()); + rlwe_nttop.backward(m_ideal.as_mut_slice()); + rlwe_modop + .elwise_scalar_mul_mut(m_ideal.as_mut_slice(), &rlwe_rgsw_gadget_a[j]); + + // RLWE(-s*X^{s_lwe[i]}*B_j) + let mut rlwe_ct = vec![vec![0u64; rlwe_n]; 2]; + rlwe_ct[0].copy_from_slice(&rgsw_ct_i[j]); + rlwe_ct[1].copy_from_slice( + &rgsw_ct_i[j + rlwe_rgsw_decomposer.a().decomposition_count()], + ); + + let mut m_back = vec![0u64; rlwe_n]; + decrypt_rlwe( + &rlwe_ct, + ideal_client_key.sk_rlwe.values(), + &mut m_back, + rlwe_nttop, + rlwe_modop, + ); + + // diff + rlwe_modop.elwise_sub_mut(&mut m_back, &m_ideal); + check.add_more(&Vec::::try_convert_from(&m_back, rlwe_q)); + } + + // RLWE'(m) + for j in 0..rlwe_rgsw_decomposer.b().decomposition_count() { + // X^{s_lwe[i]}*B_j + let mut m_ideal = m_si.clone(); + rlwe_modop + .elwise_scalar_mul_mut(m_ideal.as_mut_slice(), &rlwe_rgsw_gadget_b[j]); + + // RLWE(X^{s_lwe[i]}*B_j) + let mut rlwe_ct = vec![vec![0u64; rlwe_n]; 2]; + rlwe_ct[0].copy_from_slice( + &rgsw_ct_i[j + (2 * rlwe_rgsw_decomposer.a().decomposition_count())], + ); + rlwe_ct[1].copy_from_slice( + &rgsw_ct_i[j + + (2 * rlwe_rgsw_decomposer.a().decomposition_count() + + rlwe_rgsw_decomposer.b().decomposition_count())], + ); + + let mut m_back = vec![0u64; rlwe_n]; + decrypt_rlwe( + &rlwe_ct, + ideal_client_key.sk_rlwe.values(), + &mut m_back, + rlwe_nttop, + rlwe_modop, + ); + + // diff + rlwe_modop.elwise_sub_mut(&mut m_back, &m_ideal); + check.add_more(&Vec::::try_convert_from(&m_back, rlwe_q)); + } + }); + println!( + "RGSW Std: {} {} ;; max={}", + check.mean(), + check.std_dev().abs().log2(), + check.samples.iter().max().unwrap() + ); + } + + // check noise in RLWE x RGSW(X^{s_i}) where RGSW is accunulated RGSW ciphertext + if false { + let mut check = Stats { samples: vec![] }; + // server key in Evaluation domain + let server_key_eval_domain = + ServerKeyEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from( + &seeded_server_key, + ); + izip!( + ideal_client_key.sk_lwe.values(), + seeded_server_key.rgsw_cts.iter() + ) + .for_each(|(s_i, rgsw_ct_i)| { + let mut rgsw_ct_i = rgsw_ct_i.clone(); + rgsw_ct_i + .iter_mut() + .for_each(|ri| rlwe_nttop.forward(ri.as_mut())); + + let mut m = vec![0u64; rlwe_n]; + RandomFillUniformInModulus::random_fill(&mut rng, rlwe_q, m.as_mut_slice()); + let mut rlwe_ct = vec![vec![0u64; rlwe_n]; 2]; + public_key_encrypt_rlwe::<_, _, _, _, i32, _>( + &mut rlwe_ct, + &collective_pk.key, + &m, + rlwe_modop, + rlwe_nttop, + &mut rng, + ); + + // RLWE(m*X^{s[i]}) = RLWE(m) x RGSW(X^{s[i]}) + let mut rlwe_after = RlweCiphertext::<_, DefaultSecureRng>::new_trivial(vec![ + vec![0u64; rlwe_n], + m.clone(), + ]); + // let mut rlwe_after = + // RlweCiphertext::<_, DefaultSecureRng>::from_raw(rlwe_ct.clone(), false); + let mut scratch = vec![ + vec![0u64; rlwe_n]; + std::cmp::max( + rlwe_rgsw_decomposer.0.decomposition_count(), + rlwe_rgsw_decomposer.1.decomposition_count() + ) + 2 + ]; + rlwe_by_rgsw( + &mut rlwe_after, + &rgsw_ct_i, + &mut scratch, + rlwe_rgsw_decomposer, + rlwe_nttop, + rlwe_modop, + ); + + // m1 = X^{s[i]} + let mut m1 = vec![0u64; rlwe_n]; + let s_i = *s_i * (bool_evaluator.pbs_info.embedding_factor as i32); + if s_i < 0 { + m1[rlwe_n - (s_i.abs() as usize)] = rlwe_q.neg_one() + } else { + m1[s_i as usize] = 1; + } + + // (m+e) * m1 + let mut m_plus_e_times_m1 = m.clone(); + // decrypt_rlwe( + // &rlwe_ct, + // ideal_client_key.sk_rlwe.values(), + // &mut m_plus_e_times_m1, + // rlwe_nttop, + // rlwe_modop, + // ); + rlwe_nttop.forward(m_plus_e_times_m1.as_mut_slice()); + rlwe_nttop.forward(m1.as_mut_slice()); + rlwe_modop.elwise_mul_mut(m_plus_e_times_m1.as_mut_slice(), m1.as_slice()); + rlwe_nttop.backward(m_plus_e_times_m1.as_mut_slice()); + + // Resulting RLWE ciphertext will equal: (m0m1 + em1) + e_{rlsw x rgsw}. + // Hence, resulting rlwe ciphertext will have error em1 + e_{rlwe x rgsw}. + // Here we're only concerned with e_{rlwe x rgsw}, that is noise caused due to + // RLWExRGSW. Also note, in practice m1 is a monomial, for ex, X^{s_{i}}, for + // some i and var(em1) = var(e). + let mut m_plus_e_times_m1_more_e = vec![0u64; rlwe_n]; + decrypt_rlwe( + &rlwe_after, + ideal_client_key.sk_rlwe.values(), + &mut m_plus_e_times_m1_more_e, + rlwe_nttop, + rlwe_modop, + ); + + // diff + rlwe_modop.elwise_sub_mut( + m_plus_e_times_m1_more_e.as_mut_slice(), + m_plus_e_times_m1.as_slice(), + ); + + let noise = measure_noise( + &rlwe_after, + &m_plus_e_times_m1, + rlwe_nttop, + rlwe_modop, + ideal_client_key.sk_rlwe.values(), + ); + print!("NOISE: {}", noise); + + check.add_more(&Vec::::try_convert_from( + &m_plus_e_times_m1_more_e, + rlwe_q, + )); + }); + println!( + "RLWE x RGSW, where RGSW has noise var_brk, std: {} {}", + check.std_dev(), + check.std_dev().abs().log2() + ) + } + } + + // Check noise in fresh RGSW ciphertexts, ie X^{s_j[i]}, must equal noise in + // fresh RLWE ciphertext + if true {} + // test LWE ksk from RLWE -> LWE + // if false { + // let logp = 2; + // let mut rng = DefaultSecureRng::new(); + + // let m = 1; + // let encoded_m = m << (lwe_logq - logp); + + // // Encrypt + // let mut lwe_ct = vec![0u64; rlwe_n + 1]; + // encrypt_lwe( + // &mut lwe_ct, + // &encoded_m, + // ideal_client_key.sk_rlwe.values(), + // lwe_modop, + // &mut rng, + // ); + + // // key switch + // let lwe_decomposer = &bool_evaluator.decomposer_lwe; + // let mut lwe_out = vec![0u64; lwe_n + 1]; + // lwe_key_switch( + // &mut lwe_out, + // &lwe_ct, + // &server_key_eval.lwe_ksk, + // lwe_modop, + // lwe_decomposer, + // ); + + // let encoded_m_back = decrypt_lwe(&lwe_out, + // ideal_client_key.sk_lwe.values(), lwe_modop); let m_back + // = ((encoded_m_back as f64 * (1 << logp) as f64) / + // (lwe_q as f64)).round() as u64; dbg!(m_back, m); + + // let noise = measure_noise_lwe( + // &lwe_out, + // ideal_client_key.sk_lwe.values(), + // lwe_modop, + // &encoded_m, + // ); + + // println!("Noise: {noise}"); + // } + + // Measure noise in RGSW ciphertexts of ideal LWE secrets + // if true { + // let gadget_vec = gadget_vector( + // bool_evaluator.parameters.rlwe_logq, + // bool_evaluator.parameters.logb_rgsw, + // bool_evaluator.parameters.d_rgsw, + // ); + + // for i in 0..20 { + // // measure noise in RGSW(s[i]) + // let si = + // ideal_client_key.sk_lwe.values[i] * + // (bool_evaluator.embedding_factor as i32); let mut + // si_poly = vec![0u64; rlwe_n]; if si < 0 { + // si_poly[rlwe_n - (si.abs() as usize)] = rlwe_q - 1; + // } else { + // si_poly[(si.abs() as usize)] = 1; + // } + + // let mut rgsw_si = server_key_eval.rgsw_cts[i].clone(); + // rgsw_si + // .iter_mut() + // .for_each(|ri| rlwe_nttop.backward(ri.as_mut())); + + // println!("####### Noise in RGSW(X^s_{i}) #######"); + // _measure_noise_rgsw( + // &rgsw_si, + // &si_poly, + // ideal_client_key.sk_rlwe.values(), + // &gadget_vec, + // rlwe_q, + // ); + // println!("####### ##################### #######"); + // } + // } + + // // measure noise grwoth in RLWExRGSW + // if true { + // let mut rng = DefaultSecureRng::new(); + // let mut carry_m = vec![0u64; rlwe_n]; + // RandomUniformDist1::random_fill(&mut rng, &rlwe_q, + // carry_m.as_mut_slice()); + + // // RGSW(carrym) + // let trivial_rlwect = vec![vec![0u64; rlwe_n], carry_m.clone()]; + // let mut rlwe_ct = RlweCiphertext::<_, + // DefaultSecureRng>::from_raw(trivial_rlwect, true); + + // let mut scratch_matrix_dplus2_ring = vec![vec![0u64; rlwe_n]; + // d_rgsw + 2]; let mul_mod = + // |v0: &u64, v1: &u64| (((*v0 as u128 * *v1 as u128) % (rlwe_q as u128)) as u64); + + // for i in 0..bool_evaluator.parameters.lwe_n { + // rlwe_by_rgsw( + // &mut rlwe_ct, + // server_key_eval.rgsw_ct_lwe_si(i), + // &mut scratch_matrix_dplus2_ring, + // rlwe_decomposer, + // rlwe_nttop, + // rlwe_modop, + // ); + + // // carry_m[X] * s_i[X] + // let si = + // ideal_client_key.sk_lwe.values[i] * + // (bool_evaluator.embedding_factor as i32); let mut + // si_poly = vec![0u64; rlwe_n]; if si < 0 { + // si_poly[rlwe_n - (si.abs() as usize)] = rlwe_q - 1; + // } else { + // si_poly[(si.abs() as usize)] = 1; + // } + // carry_m = negacyclic_mul(&carry_m, &si_poly, mul_mod, + // rlwe_q); + + // let noise = measure_noise( + // &rlwe_ct, + // &carry_m, + // rlwe_nttop, + // rlwe_modop, + // ideal_client_key.sk_rlwe.values(), + // ); + // println!("Noise RLWE(carry_m) accumulating {i}^th secret + // monomial: {noise}"); } + // } + + // // Check galois keys + // if false { + // let g = bool_evaluator.g() as isize; + // let mut rng = DefaultSecureRng::new(); + // let mut scratch_matrix_dplus2_ring = vec![vec![0u64; rlwe_n]; + // d_rgsw + 2]; for i in [g, -g] { + // let mut m = vec![0u64; rlwe_n]; + // RandomUniformDist1::random_fill(&mut rng, &rlwe_q, + // m.as_mut_slice()); let mut rlwe_ct = { + // let mut data = vec![vec![0u64; rlwe_n]; 2]; + // public_key_encrypt_rlwe( + // &mut data, + // &collective_pk.key, + // &m, + // rlwe_modop, + // rlwe_nttop, + // &mut rng, + // ); + // RlweCiphertext::<_, DefaultSecureRng>::from_raw(data, + // false) }; + + // let auto_key = server_key_eval.galois_key_for_auto(i); + // let (auto_map_index, auto_map_sign) = + // generate_auto_map(rlwe_n, i); galois_auto( + // &mut rlwe_ct, + // auto_key, + // &mut scratch_matrix_dplus2_ring, + // &auto_map_index, + // &auto_map_sign, + // rlwe_modop, + // rlwe_nttop, + // rlwe_decomposer, + // ); + + // // send m(X) -> m(X^i) + // let mut m_k = vec![0u64; rlwe_n]; + // izip!(m.iter(), auto_map_index.iter(), + // auto_map_sign.iter()).for_each( |(mi, to_index, to_sign)| + // { if !to_sign { + // m_k[*to_index] = rlwe_q - *mi; + // } else { + // m_k[*to_index] = *mi; + // } + // }, + // ); + + // // measure noise + // let noise = measure_noise( + // &rlwe_ct, + // &m_k, + // rlwe_nttop, + // rlwe_modop, + // ideal_client_key.sk_rlwe.values(), + // ); + + // println!("Noise after auto k={i}: {noise}"); + // } + // } + } +} diff --git a/src/bool/parameters.rs b/src/bool/parameters.rs index d7f6e83..62292a0 100644 --- a/src/bool/parameters.rs +++ b/src/bool/parameters.rs @@ -1,4 +1,4 @@ -use num_traits::{ConstZero, PrimInt, Zero}; +use num_traits::{ConstZero, FromPrimitive, PrimInt, ToPrimitive, Zero}; use crate::{backend::Modulus, decomposer::Decomposer}; @@ -183,23 +183,52 @@ impl CiphertextModulus { } } -impl Modulus for CiphertextModulus +impl CiphertextModulus where T: PrimInt, +{ + pub(crate) fn _bits() -> usize { + std::mem::size_of::() as usize * 8 + } + + fn _native(&self) -> bool { + self.1 + } + + fn _half_q(&self) -> T { + if self._native() { + T::one() << (Self::_bits() - 1) + } else { + self.0 >> 1 + } + } + + fn _q(&self) -> Option { + if self._native() { + None + } else { + Some(self.0) + } + } +} + +impl Modulus for CiphertextModulus +where + T: PrimInt + FromPrimitive, { type Element = T; fn is_native(&self) -> bool { - false + self._native() } fn largest_unsigned_value(&self) -> Self::Element { - if self.1 { + if self._native() { T::max_value() } else { self.0 - T::one() } } fn neg_one(&self) -> Self::Element { - if self.1 { + if self._native() { T::max_value() } else { self.0 - T::one() @@ -211,20 +240,43 @@ where T::zero() } - fn to_i64(&self, v: &Self::Element) -> i64 { - todo!() + fn map_element_to_i64(&self, v: &Self::Element) -> i64 { + if *v > self._half_q() { + -((self.largest_unsigned_value() - *v) + T::one()) + .to_i64() + .unwrap() + } else { + v.to_i64().unwrap() + } } - fn from_f64(&self, v: f64) -> Self::Element { - todo!() + fn map_element_from_f64(&self, v: f64) -> Self::Element { + let v = v.round(); + if v < 0.0 { + self.largest_unsigned_value() - T::from_f64(v.abs()).unwrap() + T::one() + } else { + T::from_f64(v.abs()).unwrap() + } } - fn from_i64(&self, v: i64) -> Self::Element { - todo!() + fn map_element_from_i64(&self, v: i64) -> Self::Element { + if v < 0 { + self.largest_unsigned_value() - T::from_i64(v.abs()).unwrap() + T::one() + } else { + T::from_i64(v.abs()).unwrap() + } } fn q(&self) -> Option { - todo!() + self._q() + } + + fn q_as_f64(&self) -> Option { + if self._native() { + Some(T::max_value().to_f64().unwrap() + 1.0) + } else { + self.0.to_f64() + } } } diff --git a/src/lwe.rs b/src/lwe.rs index 000235f..e036f41 100644 --- a/src/lwe.rs +++ b/src/lwe.rs @@ -12,8 +12,8 @@ use crate::{ backend::{ArithmeticOps, GetModulus, Modulus, VectorOps}, decomposer::Decomposer, random::{ - DefaultSecureRng, NewWithSeed, RandomFillGaussianInModulus, RandomGaussianElementInModulus, - RandomFillUniformInModulus, DEFAULT_RNG, + DefaultSecureRng, NewWithSeed, RandomFillGaussianInModulus, RandomFillUniformInModulus, + RandomGaussianElementInModulus, DEFAULT_RNG, }, utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom1, WithLocal}, Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, @@ -65,7 +65,11 @@ where let mut p_rng = R::new_with_seed(value.seed.clone()); let mut data = M::zeros(value.data.as_ref().len(), value.to_lwe_n + 1); izip!(value.data.as_ref().iter(), data.iter_rows_mut()).for_each(|(bi, lwe_i)| { - RandomFillUniformInModulus::random_fill(&mut p_rng, &value.modulus, &mut lwe_i.as_mut()[1..]); + RandomFillUniformInModulus::random_fill( + &mut p_rng, + &value.modulus, + &mut lwe_i.as_mut()[1..], + ); lwe_i.as_mut()[0] = *bi; }); LweKeySwitchingKey { @@ -189,7 +193,8 @@ pub fn lwe_ksk_keygen< pub fn encrypt_lwe< Ro: Row + RowMut, Op: ArithmeticOps + GetModulus, - R: RandomGaussianElementInModulus + RandomFillUniformInModulus<[Ro::Element], Op::M>, + R: RandomGaussianElementInModulus + + RandomFillUniformInModulus<[Ro::Element], Op::M>, S, >( lwe_out: &mut Ro, @@ -273,7 +278,7 @@ where let mut diff = operator.sub(&m, ideal_m); let q = operator.modulus(); - return q.to_i64(&diff).to_f64().unwrap().abs().log2(); + return q.map_element_to_i64(&diff).to_f64().unwrap().abs().log2(); } #[cfg(test)] diff --git a/src/random.rs b/src/random.rs index 7980c47..3bac194 100644 --- a/src/random.rs +++ b/src/random.rs @@ -118,7 +118,7 @@ where container.iter_mut() ) .for_each(|(from, to)| { - *to = modulus.from_f64(from); + *to = modulus.map_element_from_f64(from); }); } } @@ -152,13 +152,13 @@ where T: PrimInt + SampleUniform, { fn random(&mut self, modulus: &T) -> T { - Uniform::new_inclusive(T::zero(), modulus).sample(&mut self.rng) + Uniform::new(T::zero(), modulus).sample(&mut self.rng) } } impl> RandomGaussianElementInModulus for DefaultSecureRng { fn random(&mut self, modulus: &M) -> T { - modulus.from_f64( + modulus.map_element_from_f64( rand_distr::Normal::new(0.0, 3.19f64) .unwrap() .sample(&mut self.rng), diff --git a/src/rgsw.rs b/src/rgsw.rs index 1b8d279..02c2079 100644 --- a/src/rgsw.rs +++ b/src/rgsw.rs @@ -13,8 +13,8 @@ use crate::{ decomposer::{self, Decomposer, RlweDecomposer}, ntt::{self, Ntt, NttInit}, random::{ - DefaultSecureRng, NewWithSeed, RandomElementInModulus, RandomFill, RandomFillGaussianInModulus, - RandomFillUniformInModulus, + DefaultSecureRng, NewWithSeed, RandomElementInModulus, RandomFill, + RandomFillGaussianInModulus, RandomFillUniformInModulus, }, utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom1, WithLocal}, Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, @@ -1528,7 +1528,7 @@ where let mut max_diff_bits = f64::MIN; m_plus_e.as_ref().iter().for_each(|v| { - let bits = (q.to_i64(v).to_f64().unwrap()).log2(); + let bits = (q.map_element_to_i64(v).to_f64().unwrap()).log2(); if max_diff_bits < bits { max_diff_bits = bits; @@ -1744,7 +1744,11 @@ pub(crate) mod tests { // sample m0 let mut m0 = vec![0u64; ring_size as usize]; - RandomFillUniformInModulus::<[u64], u64>::random_fill(&mut rng, &(1u64 << logp), m0.as_mut_slice()); + RandomFillUniformInModulus::<[u64], u64>::random_fill( + &mut rng, + &(1u64 << logp), + m0.as_mut_slice(), + ); let ntt_op = NttBackendU64::new(&q, ring_size as usize); let mod_op = ModularOpsU64::new(q); @@ -1787,7 +1791,11 @@ pub(crate) mod tests { let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize); let mut m0 = vec![0u64; ring_size as usize]; - RandomFillUniformInModulus::<[u64], _>::random_fill(&mut rng, &(1u64 << logp), m0.as_mut_slice()); + RandomFillUniformInModulus::<[u64], _>::random_fill( + &mut rng, + &(1u64 << logp), + m0.as_mut_slice(), + ); let mut m1 = vec![0u64; ring_size as usize]; m1[thread_rng().gen_range(0..ring_size) as usize] = 1; diff --git a/src/utils.rs b/src/utils.rs index cd68a11..2720cc3 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -146,7 +146,10 @@ pub trait TryConvertFrom1 { impl> TryConvertFrom1<[i64], P> for Vec { fn try_convert_from(value: &[i64], parameters: &P) -> Self { - value.iter().map(|v| parameters.from_i64(*v)).collect_vec() + value + .iter() + .map(|v| parameters.map_element_from_i64(*v)) + .collect_vec() } } @@ -154,14 +157,17 @@ impl> TryConvertFrom1<[i32], P> for Vec { fn try_convert_from(value: &[i32], parameters: &P) -> Self { value .iter() - .map(|v| parameters.from_i64(*v as i64)) + .map(|v| parameters.map_element_from_i64(*v as i64)) .collect_vec() } } impl TryConvertFrom1<[P::Element], P> for Vec { fn try_convert_from(value: &[P::Element], parameters: &P) -> Self { - value.iter().map(|v| parameters.to_i64(v)).collect_vec() + value + .iter() + .map(|v| parameters.map_element_to_i64(v)) + .collect_vec() } }