Browse Source

fix rounding in decom

par-agg-key-shares
Janmajaya Mall 10 months ago
parent
commit
ab7b1ca40f
4 changed files with 168 additions and 77 deletions
  1. +22
    -9
      src/bool/evaluator.rs
  2. +22
    -0
      src/bool/parameters.rs
  3. +51
    -46
      src/decomposer.rs
  4. +73
    -22
      src/rgsw/mod.rs

+ 22
- 9
src/bool/evaluator.rs

@ -326,8 +326,14 @@ pub(super) struct BoolPbsInfo {
impl<M: Matrix, NttOp, RlweModOp, LweModOp> PbsInfo for BoolPbsInfo<M, NttOp, RlweModOp, LweModOp>
where
M::MatElement:
PrimInt + WrappingSub + NumInfo + FromPrimitive + From<bool> + Display + WrappingAdd,
M::MatElement: PrimInt
+ WrappingSub
+ NumInfo
+ FromPrimitive
+ From<bool>
+ Display
+ WrappingAdd
+ Debug,
RlweModOp: ArithmeticOps<Element = M::MatElement> + ShoupMatrixFMA<M::R>,
LweModOp: ArithmeticOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>,
NttOp: Ntt<Element = M::MatElement>,
@ -2003,7 +2009,8 @@ where
+ WrappingSub
+ NumInfo
+ From<bool>
+ WrappingAdd,
+ WrappingAdd
+ Debug,
RlweModOp: VectorOps<Element = M::MatElement>
+ ArithmeticOps<Element = M::MatElement>
+ ShoupMatrixFMA<M::R>,
@ -2195,7 +2202,9 @@ mod tests {
SP_TEST_BOOL_PARAMS,
},
},
evaluator,
ntt::NttBackendU64,
parameters::OPTIMISED_SMALL_MP_BOOL_PARAMS,
random::{RandomElementInModulus, DEFAULT_RNG},
rgsw::{
self, measure_noise, public_key_encrypt_rlwe, secret_key_encrypt_rlwe,
@ -2216,11 +2225,11 @@ mod tests {
ModularOpsU64<CiphertextModulus<u64>>,
ModularOpsU64<CiphertextModulus<u64>>,
ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>,
>::new(SMALL_MP_BOOL_PARAMS);
>::new(OPTIMISED_SMALL_MP_BOOL_PARAMS);
// let (_, collective_pk, _, _, server_key_eval, ideal_client_key) =
// _multi_party_all_keygen(&bool_evaluator, 20);
let no_of_parties = 16;
let no_of_parties = 2;
let lwe_q = bool_evaluator.pbs_info.parameters.lwe_q();
let rlwe_q = bool_evaluator.pbs_info.parameters.rlwe_q();
let lwe_n = bool_evaluator.pbs_info.parameters.lwe_n().0;
@ -2269,7 +2278,7 @@ mod tests {
});
// check noise in freshly encrypted RLWE ciphertext (ie var_fresh)
if true {
if false {
let mut rng = DefaultSecureRng::new();
let mut check = Stats { samples: vec![] };
for _ in 0..10 {
@ -2343,7 +2352,7 @@ mod tests {
bool_evaluator.aggregate_multi_party_server_key_shares(&server_key_shares);
// Check noise in RGSW ciphertexts of ideal LWE secret elements
if false {
if true {
let mut check = Stats { samples: vec![] };
izip!(ideal_lwe_sk.iter(), seeded_server_key.rgsw_cts().iter()).for_each(
|(s_i, rgsw_ct_i)| {
@ -2361,6 +2370,10 @@ mod tests {
Vec::<u64>::try_convert_from(ideal_rlwe_sk.as_slice(), rlwe_q);
rlwe_modop.elwise_neg_mut(&mut neg_s_eval);
rlwe_nttop.forward(&mut neg_s_eval);
// let tmp_decomp = bool_evaluator
// .parameters()
// .rgsw_rgsw_decomposer::<DefaultDecomposer<u64>>();
// let tmp_gadget = tmp_decomp.a().gadget_vector()
for j in 0..rlwe_rgsw_decomposer.a().decomposition_count() {
// RLWE(B^{j} * -s[X]*X^{s_lwe[i]})
@ -2616,7 +2629,7 @@ mod tests {
// check noise in RLWE(X^k) after sending RLWE(X) -> RLWE(X^k)using collective
// auto key
if true {
if false {
let mut check = Stats { samples: vec![] };
let br_q = bool_evaluator.pbs_info.br_q();
let g = bool_evaluator.pbs_info.g();
@ -2692,7 +2705,7 @@ mod tests {
// Check noise growth in ksk
// TODO check in LWE key switching keys
if true {
if false {
// 1. encrypt LWE ciphertext
// 2. Key switching
// 3.

+ 22
- 0
src/bool/parameters.rs

@ -486,6 +486,28 @@ pub(crate) const SMALL_MP_BOOL_PARAMS: BoolParameters = BoolParameters::
variant: ParameterVariant::MultiParty,
};
pub(crate) const OPTIMISED_SMALL_MP_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u64> {
rlwe_q: CiphertextModulus::new_non_native(18014398509404161),
lwe_q: CiphertextModulus::new_non_native(1 << 15),
br_q: 1 << 11,
rlwe_n: PolynomialSize(1 << 11),
lwe_n: LweDimension(500),
lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(11)),
rlrg_decomposer_params: (
DecompostionLogBase(24),
(DecompositionCount(1), DecompositionCount(1)),
),
rgrg_decomposer_params: Some((
DecompostionLogBase(12),
(DecompositionCount(3), DecompositionCount(3)),
)),
auto_decomposer_params: (DecompostionLogBase(20), DecompositionCount(1)),
non_interactive_ui_to_s_key_switch_decomposer: None,
g: 5,
w: 10,
variant: ParameterVariant::MultiParty,
};
pub(crate) const NON_INTERACTIVE_SMALL_MP_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u64> {
rlwe_q: CiphertextModulus::new_non_native(36028797018820609),
lwe_q: CiphertextModulus::new_non_native(1 << 20),

+ 51
- 46
src/decomposer.rs

@ -11,6 +11,7 @@ use std::{
use crate::backend::{ArithmeticOps, ModularOpsU64};
fn gadget_vector<T: PrimInt>(logq: usize, logb: usize, d: usize) -> Vec<T> {
assert!(logq >= (logb * d));
let ignored_bits = logq - (logb * d);
(0..d)
@ -114,7 +115,8 @@ impl<
+ WrappingAdd
+ NumInfo
+ From<bool>
+ Display,
+ Display
+ Debug,
> Decomposer for DefaultDecomposer<T>
{
type Element = T;
@ -128,6 +130,11 @@ impl<
(T::BITS - q.leading_zeros()) as usize
};
assert!(
logq >= (logb * d),
"Decomposer wants logq >= logb*d but got logq={logq}, logb={logb}, d={d}"
);
let ignore_bits = logq - (logb * d);
DefaultDecomposer {
@ -144,20 +151,19 @@ impl<
// TODO(Jay): Outline the caveat
fn decompose_to_vec(&self, value: &T) -> Vec<T> {
let mut value = round_value(*value, self.ignore_bits);
let q = self.q;
let logb = self.logb;
let b = T::one() << logb;
let full_mask = b - T::one();
let bby2 = b >> 1;
let mut value = *value;
if value >= (q >> 1) {
value = !(q - value) + T::one()
}
value = round_value(value, self.ignore_bits);
let mut out = Vec::with_capacity(self.d);
for _ in 0..self.d {
for _ in 0..(self.d) {
let k_i = value & full_mask;
value = (value - k_i) >> logb;
@ -178,11 +184,11 @@ impl<
}
fn decompose_iter(&self, value: &T) -> DecomposerIter<T> {
let mut value = round_value(*value, self.ignore_bits);
let mut value = *value;
if value >= (self.q >> 1) {
value = !(self.q - value) + T::one()
}
value = round_value(value, self.ignore_bits);
DecomposerIter {
value,
@ -283,50 +289,49 @@ mod tests {
#[test]
fn decomposition_works() {
let logq = 55;
let logb = 12;
let d = 4;
let ring_size = 1 << 11;
let mut rng = thread_rng();
let mut stats = vec![Stats::new(); d];
for i in [true] {
let q = if i {
generate_prime(logq, 2 * ring_size, 1u64 << logq).unwrap()
} else {
1u64 << logq
};
let decomposer = DefaultDecomposer::new(q, logb, d);
dbg!(decomposer.ignore_bits);
let modq_op = ModularOpsU64::new(q);
for _ in 0..100000 {
let value = rng.gen_range(0..q);
let limbs = decomposer.decompose_to_vec(&value);
// let limbs_from_iter = decomposer.decompose_iter(&value).collect_vec();
// assert_eq!(limbs, limbs_from_iter);
let value_back = round_value(
decomposer.recompose(&limbs, &modq_op),
decomposer.ignore_bits,
);
let rounded_value = round_value(value, decomposer.ignore_bits);
// assert_eq!(
// rounded_value, value_back,
// "Expected {rounded_value} got {value_back} for q={q}"
// );
izip!(stats.iter_mut(), limbs.iter()).for_each(|(s, l)| {
s.add_more(&vec![q.map_element_to_i64(l)]);
});
for logq in [37, 55] {
let logb = 11;
let d = 3;
let mut stats = vec![Stats::new(); d];
for i in [true] {
let q = if i {
generate_prime(logq, 2 * ring_size, 1u64 << logq).unwrap()
} else {
1u64 << logq
};
let decomposer = DefaultDecomposer::new(q, logb, d);
dbg!(decomposer.ignore_bits);
let modq_op = ModularOpsU64::new(q);
for _ in 0..1000000 {
let value = rng.gen_range(0..q);
let limbs = decomposer.decompose_to_vec(&value);
let limbs_from_iter = decomposer.decompose_iter(&value).collect_vec();
assert_eq!(limbs, limbs_from_iter);
let value_back = round_value(
decomposer.recompose(&limbs, &modq_op),
decomposer.ignore_bits,
);
let rounded_value = round_value(value, decomposer.ignore_bits);
assert!((rounded_value as i64 - value_back as i64).abs() <= 1,);
izip!(stats.iter_mut(), limbs.iter()).for_each(|(s, l)| {
s.add_more(&vec![q.map_element_to_i64(l)]);
});
}
}
}
stats.iter().enumerate().for_each(|(index, s)| {
println!(
"Limb {index} - Mean: {}, Std: {}",
s.mean(),
s.std_dev().abs().log2()
);
});
stats.iter().enumerate().for_each(|(index, s)| {
println!(
"Limb {index} - Mean: {}, Std: {}",
s.mean(),
s.std_dev().abs().log2()
);
});
}
}
}

+ 73
- 22
src/rgsw/mod.rs

@ -1156,8 +1156,8 @@ pub(crate) mod tests {
let logq = 55;
let ring_size = 1 << 11;
let q = generate_prime(logq, ring_size as u64, 1u64 << logq).unwrap();
let d = 12;
let logb = 4;
let d = 2;
let logb = 12;
let decomposer = DefaultDecomposer::new(q, logb, d);
let ntt_op = NttBackendU64::new(&q, ring_size as usize);
@ -1169,16 +1169,42 @@ pub(crate) mod tests {
for _ in 0..10 {
let mut a = vec![0u64; ring_size];
RandomFillUniformInModulus::random_fill(&mut rng, &q, a.as_mut());
let mut e = vec![1u64; ring_size];
// RandomFillGaussianInModulus::random_fill(&mut rng, &q, e.as_mut());
let mut m = vec![0u64; ring_size];
RandomFillGaussianInModulus::random_fill(&mut rng, &q, m.as_mut());
let mut sk = vec![0u64; ring_size];
RandomFillGaussianInModulus::random_fill(&mut rng, &q, sk.as_mut());
let mut sk_eval = sk.clone();
ntt_op.forward(sk_eval.as_mut_slice());
let gadget_vector = decomposer.gadget_vector();
// ksk (beta e)
let mut ksk = vec![vec![0u64; ring_size]; decomposer.decomposition_count()];
izip!(ksk.iter_rows_mut(), gadget_vector.iter()).for_each(|(row, beta)| {
row.as_mut_slice().copy_from_slice(e.as_ref());
mod_op.elwise_scalar_mul_mut(row.as_mut_slice(), beta);
let mut ksk_part_b = vec![vec![0u64; ring_size]; decomposer.decomposition_count()];
let mut ksk_part_a = vec![vec![0u64; ring_size]; decomposer.decomposition_count()];
izip!(
ksk_part_b.iter_rows_mut(),
ksk_part_a.iter_rows_mut(),
gadget_vector.iter()
)
.for_each(|(part_b, part_a, beta)| {
RandomFillUniformInModulus::random_fill(&mut rng, &q, part_a.as_mut());
// a * s
let mut tmp = part_a.to_vec();
ntt_op.forward(tmp.as_mut());
mod_op.elwise_mul_mut(tmp.as_mut(), sk_eval.as_ref());
ntt_op.backward(tmp.as_mut());
// a*s + e + beta m
RandomFillGaussianInModulus::random_fill(&mut rng, &q, part_b.as_mut());
// println!("E: {:?}", &part_b);
// a*s + e
mod_op.elwise_add_mut(part_b.as_mut_slice(), tmp.as_ref());
// a*s + e + beta m
let mut tmp = m.to_vec();
mod_op.elwise_scalar_mul_mut(tmp.as_mut_slice(), beta);
mod_op.elwise_add_mut(part_b.as_mut_slice(), tmp.as_ref());
});
// decompose a
@ -1195,35 +1221,60 @@ pub(crate) mod tests {
// println!("Last limb");
// decomp_a * ksk(beta e)
ksk.iter_mut()
// decomp_a * ksk(beta m)
ksk_part_b
.iter_mut()
.for_each(|r| ntt_op.forward(r.as_mut_slice()));
ksk_part_a
.iter_mut()
.for_each(|r| ntt_op.forward(r.as_mut_slice()));
decomposed_a
.iter_mut()
.for_each(|r| ntt_op.forward(r.as_mut_slice()));
let mut out = vec![0u64; ring_size];
izip!(decomposed_a.iter(), ksk.iter()).for_each(|(a, b)| {
// out += a * b
let mut a_clone = a.clone();
mod_op.elwise_mul_mut(a_clone.as_mut_slice(), b.as_ref());
mod_op.elwise_add_mut(out.as_mut_slice(), a_clone.as_ref());
});
ntt_op.backward(out.as_mut_slice());
let mut out = vec![vec![0u64; ring_size]; 2];
izip!(decomposed_a.iter(), ksk_part_b.iter(), ksk_part_a.iter()).for_each(
|(d_a, part_b, part_a)| {
// out_a += d_a * part_a
let mut d_a_clone = d_a.clone();
mod_op.elwise_mul_mut(d_a_clone.as_mut_slice(), part_a.as_ref());
mod_op.elwise_add_mut(out[0].as_mut_slice(), d_a_clone.as_ref());
// out_b += d_a * part_b
let mut d_a_clone = d_a.clone();
mod_op.elwise_mul_mut(d_a_clone.as_mut_slice(), part_b.as_ref());
mod_op.elwise_add_mut(out[1].as_mut_slice(), d_a_clone.as_ref());
},
);
out.iter_mut()
.for_each(|r| ntt_op.backward(r.as_mut_slice()));
let out_back = {
// decrypt
// a*s
ntt_op.forward(out[0].as_mut());
mod_op.elwise_mul_mut(out[0].as_mut(), sk_eval.as_ref());
ntt_op.backward(out[0].as_mut());
// b - a*s
let tmp = (out[0]).clone();
mod_op.elwise_sub_mut(out[1].as_mut(), tmp.as_ref());
out.remove(1)
};
let out_expected = {
let mut a_clone = a.clone();
let mut e_clone = e.clone();
let mut m_clone = m.clone();
ntt_op.forward(a_clone.as_mut_slice());
ntt_op.forward(e_clone.as_mut_slice());
ntt_op.forward(m_clone.as_mut_slice());
mod_op.elwise_mul_mut(a_clone.as_mut_slice(), e_clone.as_mut_slice());
mod_op.elwise_mul_mut(a_clone.as_mut_slice(), m_clone.as_mut_slice());
ntt_op.backward(a_clone.as_mut_slice());
a_clone
};
let mut diff = out_expected;
mod_op.elwise_sub_mut(diff.as_mut_slice(), out.as_ref());
mod_op.elwise_sub_mut(diff.as_mut_slice(), out_back.as_ref());
stats.add_more(&Vec::<i64>::try_convert_from(diff.as_ref(), &q));
}

Loading…
Cancel
Save