Browse Source

fix secret HW and clean a bit

par-agg-key-shares
Janmajaya Mall 10 months ago
parent
commit
1bfb6dc7a5
10 changed files with 82 additions and 207 deletions
  1. +1
    -21
      src/bool/evaluator.rs
  2. +9
    -11
      src/bool/ni_mp_api.rs
  3. +2
    -2
      src/decomposer.rs
  4. +1
    -85
      src/main.rs
  5. +4
    -4
      src/multi_party.rs
  6. +3
    -3
      src/rgsw/mod.rs
  7. +2
    -2
      src/rgsw/runtime.rs
  8. +1
    -1
      src/shortint/enc_dec.rs
  9. +0
    -13
      src/shortint/mod.rs
  10. +59
    -65
      src/utils.rs

+ 1
- 21
src/bool/evaluator.rs

@ -2191,31 +2191,11 @@ mod tests {
RgswCiphertext, RgswCiphertextEvaluationDomain, SeededRgswCiphertext,
SeededRlweCiphertext,
},
utils::{negacyclic_mul, Stats},
utils::{negacyclic_mul, tests::Stats},
};
use super::*;
#[test]
fn bool_encrypt_decrypt_works() {
let bool_evaluator = BoolEvaluator::<
Vec<Vec<u64>>,
NttBackendU64,
ModularOpsU64<CiphertextModulus<u64>>,
ModularOpsU64<CiphertextModulus<u64>>,
ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>,
>::new(SP_TEST_BOOL_PARAMS);
let client_key = bool_evaluator.client_key();
let mut m = true;
for _ in 0..1000 {
let lwe_ct = bool_evaluator.sk_encrypt(m, &client_key);
let m_back = bool_evaluator.sk_decrypt(&lwe_ct, &client_key);
assert_eq!(m, m_back);
m = !m;
}
}
#[test]
fn noise_tester() {
let bool_evaluator = BoolEvaluator::<

+ 9
- 11
src/bool/ni_mp_api.rs

@ -167,11 +167,10 @@ mod impl_enc_dec {
use crate::{
bool::{evaluator::BoolEncoding, keys::NonInteractiveMultiPartyClientKey},
pbs::{sample_extract, PbsInfo, WithShoupRepr},
random::{DefaultSecureRng, NewWithSeed, RandomFillUniformInModulus},
random::{NewWithSeed, RandomFillUniformInModulus},
rgsw::{key_switch, secret_key_encrypt_rlwe},
utils::{TryConvertFrom1, WithLocal},
Encryptor, KeySwitchWithId, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor,
RowEntity, RowMut,
utils::TryConvertFrom1,
Encryptor, KeySwitchWithId, Matrix, MatrixEntity, MatrixMut, RowEntity, RowMut,
};
use itertools::Itertools;
use num_traits::{ToPrimitive, Zero};
@ -359,13 +358,10 @@ mod tests {
use crate::{
backend::{GetModulus, Modulus},
bool::{
evaluator::{BoolEncoding, BooleanGates},
keys::SinglePartyClientKey,
},
bool::{evaluator::BooleanGates, keys::SinglePartyClientKey},
lwe::decrypt_lwe,
rgsw::decrypt_rlwe,
utils::{Stats, TryConvertFrom1},
utils::{tests::Stats, TryConvertFrom1},
ArithmeticOps, Encoder, Encryptor, KeySwitchWithId, ModInit, MultiPartyDecryptor, NttInit,
Row, VectorOps,
};
@ -448,9 +444,11 @@ mod tests {
ct.extract(0)
};
for _ in 0..100 {
for _ in 0..1000 {
// let now = std::time::Instant::now();
let ct_out =
BoolEvaluator::with_local_mut(|e| e.xor(&ct0, &ct1, RuntimeServerKey::global()));
// println!("Time: {:?}", now.elapsed());
let decryption_shares = cks
.iter()
@ -458,7 +456,7 @@ mod tests {
.collect_vec();
let m_out = cks[0].aggregate_decryption_shares(&ct_out, &decryption_shares);
let m_expected = (m0 ^ m1);
let m_expected = m0 ^ m1;
{
let noise = measure_noise_lwe(

+ 2
- 2
src/decomposer.rs

@ -275,7 +275,7 @@ mod tests {
use crate::{
backend::{ModInit, ModularOpsU64},
decomposer::round_value,
utils::{generate_prime, Stats, TryConvertFrom1},
utils::{generate_prime, tests::Stats, TryConvertFrom1},
};
use super::{Decomposer, DefaultDecomposer};
@ -288,7 +288,7 @@ mod tests {
let ring_size = 1 << 11;
let mut rng = thread_rng();
let mut stats = Stats { samples: vec![] };
let mut stats = Stats::new();
for i in [true] {
let q = if i {

+ 1
- 85
src/main.rs

@ -1,85 +1 @@
use std::os::unix::thread;
use rand::{thread_rng, Rng};
fn decomposer(mut value: u64, q: u64, d: usize, logb: u64) -> Vec<u64> {
let b = 1u64 << logb;
let full_mask = b - 1u64;
let bby2 = b >> 1;
if value >= (q >> 1) {
value = !(q - value) + 1;
}
// let mut carry = 0;
// let mut out = Vec::with_capacity(d);
// for _ in 0..d {
// let k_i = carry + (value & full_mask);
// value = (value) >> logb;
// if k_i > bby2 {
// // if (k_i == bby2 && ((value & 1) == 1)) {
// // println!("AA");
// // }
// out.push(q - (b - k_i));
// carry = 1;
// } else {
// // if (k_i == bby2) {
// // println!("BB");
// // }
// out.push(k_i);
// carry = 0;
// }
// }
// return out;
let mut out = Vec::with_capacity(d);
for _ in 0..d {
let k_i = value & full_mask;
value = (value - k_i) >> logb;
if k_i > bby2 || (k_i == bby2 && ((value & 1) == 1)) {
// if (k_i == bby2 && ((value & 1) == 1)) {
// println!("AA");
// }
out.push(q - (b - k_i));
value += 1;
} else {
// if (k_i == bby2) {
// println!("BB");
// }
out.push(k_i);
}
}
return out;
}
fn recompose(limbs: &[u64], q: u64, logb: u64) -> u64 {
let mut out = 0;
limbs.iter().enumerate().for_each(|(i, l)| {
let a = 1u128 << (logb * (i as u64));
let a = ((a * (*l as u128)) % (q as u128)) as u64;
out = (out + a) % q;
});
out % q
}
fn main() {
// let mut v = Vec::with_capacity(10);
// v[0] = 1;
// println!("Hello, world!");
let mut rng = thread_rng();
let q = 36028797018820609u64;
let logb = 11;
let d = 5;
for _ in 0..100000 {
let value = rng.gen_range(0..q);
let limbs = decomposer(value, q, d, logb);
// println!("{:?}", &limbs);
let value_back = recompose(&limbs, q, logb);
assert_eq!(value, value_back)
}
}
fn main() {}

+ 4
- 4
src/multi_party.rs

@ -4,10 +4,10 @@ use itertools::izip;
use crate::{
backend::{GetModulus, VectorOps},
ntt::{self, Ntt},
random::{NewWithSeed, RandomFillGaussianInModulus, RandomFillUniformInModulus},
ntt::Ntt,
random::{RandomFillGaussianInModulus, RandomFillUniformInModulus},
utils::TryConvertFrom1,
Decomposer, Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut,
Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut,
};
pub(crate) fn public_key_share<
@ -213,7 +213,7 @@ where
let mut scratch_space = M::R::zeros(ring_size);
izip!(zero_encs.iter_rows_mut()).for_each(|(e_zero)| {
izip!(zero_encs.iter_rows_mut()).for_each(|e_zero| {
// sample a_i
RandomFillUniformInModulus::random_fill(p_rng, q, e_zero.as_mut());

+ 3
- 3
src/rgsw/mod.rs

@ -512,13 +512,13 @@ pub(crate) mod tests {
use crate::{
backend::{GetModulus, ModInit, ModularOpsU64, Modulus, VectorOps},
decomposer::{Decomposer, DefaultDecomposer, RlweDecomposer},
ntt::{self, Ntt, NttBackendU64, NttInit},
random::{DefaultSecureRng, NewWithSeed, RandomFillUniformInModulus},
ntt::{Ntt, NttBackendU64, NttInit},
random::{DefaultSecureRng, RandomFillUniformInModulus},
rgsw::{
galois_auto_shoup, rlwe_by_rgsw_shoup, ShoupAutoKeyEvaluationDomain,
ShoupRgswCiphertextEvaluationDomain,
},
utils::{generate_prime, negacyclic_mul, Stats, TryConvertFrom1},
utils::{generate_prime, negacyclic_mul, tests::Stats, TryConvertFrom1},
Matrix, Secret,
};

+ 2
- 2
src/rgsw/runtime.rs

@ -2,10 +2,10 @@ use itertools::izip;
use num_traits::Zero;
use crate::{
backend::{ArithmeticOps, GetModulus, Modulus, ShoupMatrixFMA, VectorOps},
backend::{ArithmeticOps, GetModulus, ShoupMatrixFMA, VectorOps},
decomposer::{Decomposer, RlweDecomposer},
ntt::Ntt,
Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret,
Matrix, MatrixEntity, MatrixMut, RowEntity, RowMut,
};
use super::IsTrivial;

+ 1
- 1
src/shortint/enc_dec.rs

@ -3,7 +3,7 @@ use itertools::Itertools;
use crate::{
bool::BoolEvaluator,
random::{DefaultSecureRng, RandomFillUniformInModulus},
utils::{TryConvertFrom1, WithLocal},
utils::WithLocal,
Decryptor, Encryptor, KeySwitchWithId, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor,
RowMut, SampleExtractor,
};

+ 0
- 13
src/shortint/mod.rs

@ -205,16 +205,3 @@ mod frontend {
}
}
}
#[cfg(test)]
mod tests {
use itertools::Itertools;
use num_traits::Euclid;
use crate::{
bool::set_parameter_set, shortint::enc_dec::FheUint8, utils::WithLocal, Decryptor,
Encryptor, MultiPartyDecryptor,
};
use super::*;
}

+ 59
- 65
src/utils.rs

@ -86,7 +86,7 @@ pub fn fill_random_ternary_secret_with_hamming_weight<
let mut secret_indices = (0..size).into_iter().map(|i| i).collect_vec();
let mut bit_index = 0;
let mut byte_index = 0;
for _ in 0..hamming_weight {
for i in 0..hamming_weight {
let s_index = RandomElementInModulus::<usize, usize>::random(rng, &secret_indices.len());
let curr_bit = (bytes[byte_index] >> bit_index) & 1;
@ -97,7 +97,7 @@ pub fn fill_random_ternary_secret_with_hamming_weight<
}
secret_indices[s_index] = *secret_indices.last().unwrap();
secret_indices.truncate(secret_indices.len());
secret_indices.truncate(secret_indices.len() - 1);
if bit_index == 7 {
bit_index = 0;
@ -232,79 +232,73 @@ impl TryConvertFrom1<[P::Element], P> for Vec {
}
}
pub(crate) struct Stats<T> {
pub(crate) samples: Vec<T>,
}
#[cfg(test)]
pub(crate) mod tests {
use std::fmt::Debug;
impl<T: PrimInt + FromPrimitive + Debug> Stats<T>
where
// T: for<'a> Sum<&'a T>,
T: for<'a> std::iter::Sum<&'a T> + std::iter::Sum<T>,
{
pub(crate) fn new() -> Self {
Self { samples: vec![] }
}
use num_traits::{FromPrimitive, PrimInt};
use rand::thread_rng;
use crate::random::DefaultSecureRng;
use super::fill_random_ternary_secret_with_hamming_weight;
pub(crate) fn mean(&self) -> f64 {
self.samples.iter().sum::<T>().to_f64().unwrap() / (self.samples.len() as f64)
pub(crate) struct Stats<T> {
pub(crate) samples: Vec<T>,
}
pub(crate) fn std_dev(&self) -> f64 {
let mean = self.mean();
impl<T: PrimInt + FromPrimitive + Debug> Stats<T>
where
// T: for<'a> Sum<&'a T>,
T: for<'a> std::iter::Sum<&'a T> + std::iter::Sum<T>,
{
pub(crate) fn new() -> Self {
Self { samples: vec![] }
}
// diff
let diff_sq = self
.samples
.iter()
.map(|v| {
let t = v.to_f64().unwrap() - mean;
t * t
})
.into_iter()
.sum::<f64>();
pub(crate) fn mean(&self) -> f64 {
self.samples.iter().sum::<T>().to_f64().unwrap() / (self.samples.len() as f64)
}
(diff_sq / (self.samples.len() as f64)).sqrt()
}
pub(crate) fn std_dev(&self) -> f64 {
let mean = self.mean();
// diff
let diff_sq = self
.samples
.iter()
.map(|v| {
let t = v.to_f64().unwrap() - mean;
t * t
})
.into_iter()
.sum::<f64>();
(diff_sq / (self.samples.len() as f64)).sqrt()
}
pub(crate) fn add_more(&mut self, values: &[T]) {
self.samples.extend(values.iter());
pub(crate) fn add_more(&mut self, values: &[T]) {
self.samples.extend(values.iter());
}
}
}
#[cfg(test)]
mod tests {
use super::is_probably_prime;
// let n = 1 << (11 + 1);
// let mut start = 1 << 55;
// while start < (1 << 56) {
// if start % n == 1 {
// break;
// }
// start += 1;
// }
// let mut prime = None;
// while start < (1 << 56) {
// if is_probably_prime(start) {
// dbg!(start);
// prime = Some(start);
// break;
// }
// dbg!(start);
// start += (n);
// }
#[test]
fn gg() {
let q = 30;
for i in 0..1000 {
let x = (1u64 << (q * 2)) + (i * (1 << q)) + 1;
let is_prime = is_probably_prime(x);
if is_prime {
println!("{x} = 2^{} + {i} * 2^{q} + 1", 2 * q);
}
fn ternary_secret_has_correct_hw() {
let mut rng = DefaultSecureRng::new();
for n in 4..15 {
let ring_size = 1 << n;
let mut out = vec![0i32; ring_size];
fill_random_ternary_secret_with_hamming_weight(&mut out, ring_size >> 1, &mut rng);
// check hamming weight of out equals ring_size/2
let mut non_zeros = 0;
out.iter().for_each(|i| {
if *i != 0 {
non_zeros += 1;
}
});
assert_eq!(ring_size >> 1, non_zeros);
}
// println!("{:?}", prime);
}
}

Loading…
Cancel
Save