Browse Source

mid-way through big refactor

par-agg-key-shares
Janmajaya Mall 11 months ago
parent
commit
66464941cc
5 changed files with 342 additions and 156 deletions
  1. +106
    -73
      src/bool/evaluator.rs
  2. +190
    -40
      src/bool/parameters.rs
  3. +29
    -27
      src/decomposer.rs
  4. +2
    -2
      src/lwe.rs
  5. +15
    -14
      src/rgsw.rs

+ 106
- 73
src/bool/evaluator.rs

@ -145,15 +145,15 @@ where
assert!(value.len() > 0); assert!(value.len() > 0);
let parameters = &value[0].parameters; let parameters = &value[0].parameters;
let mut key = M::zeros(2, parameters.rlwe_n);
let mut key = M::zeros(2, parameters.rlwe_n().0);
// sample A // sample A
let seed = value[0].cr_seed; let seed = value[0].cr_seed;
let mut main_rng = Rng::new_with_seed(seed); let mut main_rng = Rng::new_with_seed(seed);
RandomUniformDist::random_fill(&mut main_rng, &parameters.rlwe_q, key.get_row_mut(0));
RandomUniformDist::random_fill(&mut main_rng, &parameters.rlwe_q().0, key.get_row_mut(0));
// Sum all Bs // Sum all Bs
let rlweq_modop = ModOp::new(parameters.rlwe_q);
let rlweq_modop = ModOp::new(parameters.rlwe_q().0);
value.iter().for_each(|share_i| { value.iter().for_each(|share_i| {
assert!(share_i.cr_seed == seed); assert!(share_i.cr_seed == seed);
assert!(&share_i.parameters == parameters); assert!(&share_i.parameters == parameters);
@ -203,12 +203,10 @@ where
let parameters = shares[0].parameters.clone(); let parameters = shares[0].parameters.clone();
let cr_seed = shares[0].cr_seed; let cr_seed = shares[0].cr_seed;
let rlwe_n = parameters.rlwe_n;
let g = parameters.g as isize;
let d_rgsw = parameters.d_rgsw;
let d_lwe = parameters.d_lwe;
let rlwe_q = parameters.rlwe_q;
let lwe_q = parameters.lwe_q;
let rlwe_n = parameters.rlwe_n().0;
let g = parameters.g() as isize;
let rlwe_q = parameters.rlwe_q().0;
let lwe_q = parameters.lwe_q().0;
// sanity checks // sanity checks
shares.iter().skip(1).for_each(|s| { shares.iter().skip(1).for_each(|s| {
@ -222,11 +220,13 @@ where
// auto keys // auto keys
let mut auto_keys = HashMap::new(); let mut auto_keys = HashMap::new();
for i in [g, -g] { for i in [g, -g] {
let mut key = M::zeros(d_rgsw, rlwe_n);
let mut key = M::zeros(parameters.auto_decomposition_count().0, rlwe_n);
shares.iter().for_each(|s| { shares.iter().for_each(|s| {
let auto_key_share_i = s.auto_keys.get(&i).expect("Auto key {i} missing"); let auto_key_share_i = s.auto_keys.get(&i).expect("Auto key {i} missing");
assert!(auto_key_share_i.dimension() == (d_rgsw, rlwe_n));
assert!(
auto_key_share_i.dimension() == (parameters.auto_decomposition_count().0, rlwe_n)
);
izip!(key.iter_rows_mut(), auto_key_share_i.iter_rows()).for_each( izip!(key.iter_rows_mut(), auto_key_share_i.iter_rows()).for_each(
|(partb_out, partb_share)| { |(partb_out, partb_share)| {
rlweq_modop.elwise_add_mut(partb_out.as_mut(), partb_share.as_ref()); rlweq_modop.elwise_add_mut(partb_out.as_mut(), partb_share.as_ref());
@ -238,8 +238,9 @@ where
} }
// rgsw ciphertext (most expensive part!) // rgsw ciphertext (most expensive part!)
let lwe_n = parameters.lwe_n;
let lwe_n = parameters.lwe_n().0;
let mut scratch_d_plus_rgsw_by_ring = M::zeros(d_rgsw + (d_rgsw * 4), rlwe_n); let mut scratch_d_plus_rgsw_by_ring = M::zeros(d_rgsw + (d_rgsw * 4), rlwe_n);
let mut tmp_rgsw = M::zeros(d_rgsw * 2 * 2, rlwe_n); let mut tmp_rgsw = M::zeros(d_rgsw * 2 * 2, rlwe_n);
let rgsw_cts = (0..lwe_n) let rgsw_cts = (0..lwe_n)
.into_iter() .into_iter()
@ -272,10 +273,10 @@ where
.collect_vec(); .collect_vec();
// LWE ksks // LWE ksks
let mut lwe_ksk = M::R::zeros(rlwe_n * d_lwe);
let mut lwe_ksk = M::R::zeros(rlwe_n * parameters.lwe_decomposition_count().0);
let lweq_modop = ModOp::new(lwe_q); let lweq_modop = ModOp::new(lwe_q);
shares.iter().for_each(|si| { shares.iter().for_each(|si| {
assert!(si.lwe_ksk.as_ref().len() == rlwe_n * d_lwe);
assert!(si.lwe_ksk.as_ref().len() == rlwe_n * parameters.lwe_decomposition_count().0);
lweq_modop.elwise_add_mut(lwe_ksk.as_mut(), si.lwe_ksk.as_ref()) lweq_modop.elwise_add_mut(lwe_ksk.as_mut(), si.lwe_ksk.as_ref())
}); });
@ -288,6 +289,7 @@ where
} }
} }
/// Seeded single party server key
struct SeededServerKey<M: Matrix, P, S> { struct SeededServerKey<M: Matrix, P, S> {
/// Rgsw cts of LWE secret elements /// Rgsw cts of LWE secret elements
pub(crate) rgsw_cts: Vec<M>, pub(crate) rgsw_cts: Vec<M>,
@ -310,13 +312,24 @@ impl SeededServerKey, S> {
seed: S, seed: S,
) -> Self { ) -> Self {
// sanity checks // sanity checks
auto_keys
.iter()
.for_each(|v| assert!(v.1.dimension() == (parameters.d_rgsw, parameters.rlwe_n)));
rgsw_cts
.iter()
.for_each(|v| assert!(v.dimension() == (parameters.d_rgsw * 3, parameters.rlwe_n)));
assert!(lwe_ksk.as_ref().len() == (parameters.d_lwe * parameters.rlwe_n));
auto_keys.iter().for_each(|v| {
assert!(
v.1.dimension()
== (
parameters.auto_decomposition_count().0,
parameters.rlwe_n().0
)
)
});
let (part_a_d, part_b_d) = parameters.rlwe_rgsw_decomposition_count();
rgsw_cts.iter().for_each(|v| {
assert!(v.dimension() == (part_a_d.0 * 2 + part_b_d.0, parameters.rlwe_n().0))
});
assert!(
lwe_ksk.as_ref().len()
== (parameters.lwe_decomposition_count().0 * parameters.rlwe_n().0)
);
SeededServerKey { SeededServerKey {
rgsw_cts, rgsw_cts,
@ -328,6 +341,7 @@ impl SeededServerKey, S> {
} }
} }
/// Server key in evaluation domain
struct ServerKeyEvaluationDomain<M, R, N> { struct ServerKeyEvaluationDomain<M, R, N> {
/// Rgsw cts of LWE secret elements /// Rgsw cts of LWE secret elements
rgsw_cts: Vec<M>, rgsw_cts: Vec<M>,
@ -351,33 +365,32 @@ where
{ {
fn from(value: &SeededServerKey<M, BoolParameters<M::MatElement>, R::Seed>) -> Self { fn from(value: &SeededServerKey<M, BoolParameters<M::MatElement>, R::Seed>) -> Self {
let mut main_prng = R::new_with_seed(value.seed.clone()); let mut main_prng = R::new_with_seed(value.seed.clone());
let g = value.parameters.g as isize;
let ring_size = value.parameters.rlwe_n;
let lwe_n = value.parameters.lwe_n;
let d_rgsw = value.parameters.d_rgsw;
let d_lwe = value.parameters.d_lwe;
let rlwe_q = value.parameters.rlwe_q;
let lwq_q = value.parameters.lwe_q;
let parameters = &value.parameters;
let g = parameters.g() as isize;
let ring_size = value.parameters.rlwe_n().0;
let lwe_n = value.parameters.lwe_n().0;
let rlwe_q = value.parameters.rlwe_q().0;
let lwq_q = value.parameters.lwe_q().0;
let nttop = N::new(rlwe_q, ring_size); let nttop = N::new(rlwe_q, ring_size);
// galois keys // galois keys
let mut auto_keys = HashMap::new(); let mut auto_keys = HashMap::new();
let auto_decomp_count = parameters.auto_decomposition_count().0;
for i in [g, -g] { for i in [g, -g] {
let seeded_auto_key = value.auto_keys.get(&i).unwrap(); let seeded_auto_key = value.auto_keys.get(&i).unwrap();
assert!(seeded_auto_key.dimension() == (d_rgsw, ring_size));
assert!(seeded_auto_key.dimension() == (auto_decomp_count, ring_size));
let mut data = M::zeros(d_rgsw * 2, ring_size);
let mut data = M::zeros(auto_decomp_count * 2, ring_size);
// sample RLWE'_A(-s(X^k)) // sample RLWE'_A(-s(X^k))
data.iter_rows_mut().take(d_rgsw).for_each(|ri| {
data.iter_rows_mut().take(auto_decomp_count).for_each(|ri| {
RandomUniformDist::random_fill(&mut main_prng, &rlwe_q, ri.as_mut()) RandomUniformDist::random_fill(&mut main_prng, &rlwe_q, ri.as_mut())
}); });
// copy over RLWE'B_(-s(X^k)) // copy over RLWE'B_(-s(X^k))
izip!( izip!(
data.iter_rows_mut().skip(d_rgsw),
data.iter_rows_mut().skip(auto_decomp_count),
seeded_auto_key.iter_rows() seeded_auto_key.iter_rows()
) )
.for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref())); .for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref()));
@ -390,33 +403,38 @@ where
} }
// RGSW ciphertexts // RGSW ciphertexts
let (rlrg_a_decomp, rlrg_b_decomp) = parameters.rlwe_rgsw_decomposition_count();
let rgsw_cts = value let rgsw_cts = value
.rgsw_cts .rgsw_cts
.iter() .iter()
.map(|seeded_rgsw_si| { .map(|seeded_rgsw_si| {
assert!(seeded_rgsw_si.dimension() == (3 * d_rgsw, ring_size));
assert!(
seeded_rgsw_si.dimension()
== (rlrg_a_decomp.0 * 2 + rlrg_b_decomp.0, ring_size)
);
let mut data = M::zeros(d_rgsw * 4, ring_size);
let mut data = M::zeros(rlrg_a_decomp.0 * 2 + rlrg_b_decomp.0 * 2, ring_size);
// copy over RLWE'(-sm) // copy over RLWE'(-sm)
izip!( izip!(
data.iter_rows_mut().take(d_rgsw * 2),
seeded_rgsw_si.iter_rows().take(d_rgsw * 2)
data.iter_rows_mut().take(rlrg_a_decomp.0 * 2),
seeded_rgsw_si.iter_rows().take(rlrg_a_decomp.0 * 2)
) )
.for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref())); .for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref()));
// sample RLWE'_A(m) // sample RLWE'_A(m)
data.iter_rows_mut() data.iter_rows_mut()
.skip(2 * d_rgsw)
.take(d_rgsw)
.skip(rlrg_a_decomp.0 * 2)
.take(rlrg_b_decomp.0)
.for_each(|ri| { .for_each(|ri| {
RandomUniformDist::random_fill(&mut main_prng, &rlwe_q, ri.as_mut()) RandomUniformDist::random_fill(&mut main_prng, &rlwe_q, ri.as_mut())
}); });
// copy over RLWE'_B(m) // copy over RLWE'_B(m)
izip!( izip!(
data.iter_rows_mut().skip(d_rgsw * 3),
seeded_rgsw_si.iter_rows().skip(d_rgsw * 2)
data.iter_rows_mut()
.skip(rlrg_a_decomp.0 * 2 + rlrg_b_decomp.0),
seeded_rgsw_si.iter_rows().skip(rlrg_a_decomp.0 * 2)
) )
.for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref())); .for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref()));
@ -430,9 +448,10 @@ where
// LWE ksk // LWE ksk
let lwe_ksk = { let lwe_ksk = {
assert!(value.lwe_ksk.as_ref().len() == d_lwe * ring_size);
let d = parameters.lwe_decomposition_count().0;
assert!(value.lwe_ksk.as_ref().len() == d * ring_size);
let mut data = M::zeros(d_lwe * ring_size, lwe_n + 1);
let mut data = M::zeros(d * ring_size, lwe_n + 1);
izip!(data.iter_rows_mut(), value.lwe_ksk.as_ref().iter()).for_each(|(lwe_i, bi)| { izip!(data.iter_rows_mut(), value.lwe_ksk.as_ref().iter()).for_each(|(lwe_i, bi)| {
RandomUniformDist::random_fill(&mut main_prng, &lwq_q, &mut lwe_i.as_mut()[1..]); RandomUniformDist::random_fill(&mut main_prng, &lwq_q, &mut lwe_i.as_mut()[1..]);
lwe_i.as_mut()[0] = *bi; lwe_i.as_mut()[0] = *bi;
@ -465,12 +484,11 @@ where
fn from( fn from(
value: &SeededMultiPartyServerKey<M, Rng::Seed, BoolParameters<M::MatElement>>, value: &SeededMultiPartyServerKey<M, Rng::Seed, BoolParameters<M::MatElement>>,
) -> Self { ) -> Self {
let g = value.parameters.g as isize;
let rlwe_n = value.parameters.rlwe_n;
let lwe_n = value.parameters.lwe_n;
let rlwe_q = value.parameters.rlwe_q;
let lwe_q = value.parameters.lwe_q;
let d_rgsw = value.parameters.d_rgsw;
let g = value.parameters.g() as isize;
let rlwe_n = value.parameters.rlwe_n().0;
let lwe_n = value.parameters.lwe_n().0;
let rlwe_q = value.parameters.rlwe_q().0;
let lwe_q = value.parameters.lwe_q().0;
let mut main_prng = Rng::new_with_seed(value.cr_seed); let mut main_prng = Rng::new_with_seed(value.cr_seed);
@ -478,21 +496,24 @@ where
// auto keys // auto keys
let mut auto_keys = HashMap::new(); let mut auto_keys = HashMap::new();
let auto_d_count = value.parameters.auto_decomposition_count().0;
for i in [g, -g] { for i in [g, -g] {
let mut key = M::zeros(value.parameters.d_rgsw * 2, rlwe_n);
let mut key = M::zeros(auto_d_count * 2, rlwe_n);
// sample a // sample a
key.iter_rows_mut().take(d_rgsw).for_each(|ri| {
key.iter_rows_mut().take(auto_d_count).for_each(|ri| {
RandomUniformDist::random_fill(&mut main_prng, &rlwe_q, ri.as_mut()) RandomUniformDist::random_fill(&mut main_prng, &rlwe_q, ri.as_mut())
}); });
let key_part_b = value.auto_keys.get(&i).unwrap(); let key_part_b = value.auto_keys.get(&i).unwrap();
assert!(key_part_b.dimension() == (d_rgsw, rlwe_n));
izip!(key.iter_rows_mut().skip(d_rgsw), key_part_b.iter_rows()).for_each(
|(to_ri, from_ri)| {
to_ri.as_mut().copy_from_slice(from_ri.as_ref());
},
);
assert!(key_part_b.dimension() == (auto_d_count, rlwe_n));
izip!(
key.iter_rows_mut().skip(auto_d_count),
key_part_b.iter_rows()
)
.for_each(|(to_ri, from_ri)| {
to_ri.as_mut().copy_from_slice(from_ri.as_ref());
});
// send to evaluation domain // send to evaluation domain
key.iter_rows_mut() key.iter_rows_mut()
@ -502,11 +523,14 @@ where
} }
// rgsw cts // rgsw cts
let (rlrg_d_a, rlrg_d_b) = value.parameters.rlwe_rgsw_decomposition_count();
let rgsw_ct_rows = rlrg_d_a.0 * 2 + rlrg_d_b.0 * 2;
let rgsw_cts = value let rgsw_cts = value
.rgsw_cts .rgsw_cts
.iter() .iter()
.map(|ct_i| { .map(|ct_i| {
let mut eval_ct_i = M::zeros(d_rgsw * 4, rlwe_n);
assert!(ct_i.dimension() == (rgsw_ct_rows, rlwe_n));
let mut eval_ct_i = M::zeros(rgsw_ct_rows, rlwe_n);
izip!(eval_ct_i.iter_rows_mut(), ct_i.iter_rows()).for_each(|(to_ri, from_ri)| { izip!(eval_ct_i.iter_rows_mut(), ct_i.iter_rows()).for_each(|(to_ri, from_ri)| {
to_ri.as_mut().copy_from_slice(from_ri.as_ref()); to_ri.as_mut().copy_from_slice(from_ri.as_ref());
@ -518,7 +542,7 @@ where
.collect_vec(); .collect_vec();
// lwe ksk // lwe ksk
let d_lwe = value.parameters.d_lwe;
let d_lwe = value.parameters.lwe_decomposition_count().0;
let mut lwe_ksk = M::zeros(rlwe_n * d_lwe, lwe_n + 1); let mut lwe_ksk = M::zeros(rlwe_n * d_lwe, lwe_n + 1);
izip!(lwe_ksk.iter_rows_mut(), value.lwe_ksk.as_ref().iter()).for_each(|(lwe_i, bi)| { izip!(lwe_ksk.iter_rows_mut(), value.lwe_ksk.as_ref().iter()).for_each(|(lwe_i, bi)| {
RandomUniformDist::random_fill(&mut main_prng, &lwe_q, &mut lwe_i.as_mut()[1..]); RandomUniformDist::random_fill(&mut main_prng, &lwe_q, &mut lwe_i.as_mut()[1..]);
@ -1439,8 +1463,11 @@ fn pbs<
gb_monomial_sign = false gb_monomial_sign = false
} }
// monomial mul // monomial mul
let mut trivial_rlwe_test_poly =
RlweCiphertext::<_, DefaultSecureRng>::from_raw(M::zeros(2, rlwe_n), true);
let mut trivial_rlwe_test_poly = RlweCiphertext::<_, DefaultSecureRng> {
data: M::zeros(2, rlwe_n),
is_trivial: true,
_phatom: PhantomData,
};
if parameters.embedding_factor() == 1 { if parameters.embedding_factor() == 1 {
monomial_mul( monomial_mul(
test_vec.as_ref(), test_vec.as_ref(),
@ -2242,7 +2269,7 @@ mod tests {
// RGSW(carrym) // RGSW(carrym)
let trivial_rlwect = vec![vec![0u64; rlwe_n], carry_m.clone()]; let trivial_rlwect = vec![vec![0u64; rlwe_n], carry_m.clone()];
let mut rlwe_ct = RlweCiphertext::<_, DefaultSecureRng>::from_raw(trivial_rlwect, true);
let mut rlwe_ct = RlweCiphertext::<_, DefaultSecureRng>::new_trivial(trivial_rlwect);
let mut scratch_matrix_dplus2_ring = vec![vec![0u64; rlwe_n]; d_rgsw + 2]; let mut scratch_matrix_dplus2_ring = vec![vec![0u64; rlwe_n]; d_rgsw + 2];
let mul_mod = let mul_mod =
@ -2298,7 +2325,7 @@ mod tests {
rlwe_nttop, rlwe_nttop,
&mut rng, &mut rng,
); );
RlweCiphertext::<_, DefaultSecureRng>::from_raw(data, false)
RlweCiphertext::<_, DefaultSecureRng>::new_trivial(data, false)
}; };
let auto_key = server_key_eval.galois_key_for_auto(i); let auto_key = server_key_eval.galois_key_for_auto(i);
@ -2666,7 +2693,7 @@ mod tests {
Vec::<u64>::try_convert_from(ideal_client_key.sk_rlwe.values(), &rlwe_q); Vec::<u64>::try_convert_from(ideal_client_key.sk_rlwe.values(), &rlwe_q);
rlwe_modop.elwise_neg_mut(&mut neg_s_eval); rlwe_modop.elwise_neg_mut(&mut neg_s_eval);
rlwe_nttop.forward(&mut neg_s_eval); rlwe_nttop.forward(&mut neg_s_eval);
for j in 0..rlwe_decomposer.d() {
for j in 0..rlwe_decomposer.decomposition_count() {
// -s[X]*X^{s_lwe[i]}*B_j // -s[X]*X^{s_lwe[i]}*B_j
let mut m_ideal = m_si.clone(); let mut m_ideal = m_si.clone();
rlwe_nttop.forward(m_ideal.as_mut_slice()); rlwe_nttop.forward(m_ideal.as_mut_slice());
@ -2678,7 +2705,8 @@ mod tests {
// RLWE(-s*X^{s_lwe[i]}*B_j) // RLWE(-s*X^{s_lwe[i]}*B_j)
let mut rlwe_ct = vec![vec![0u64; rlwe_n]; 2]; let mut rlwe_ct = vec![vec![0u64; rlwe_n]; 2];
rlwe_ct[0].copy_from_slice(&rgsw_ct_i[j]); rlwe_ct[0].copy_from_slice(&rgsw_ct_i[j]);
rlwe_ct[1].copy_from_slice(&rgsw_ct_i[j + rlwe_decomposer.d()]);
rlwe_ct[1]
.copy_from_slice(&rgsw_ct_i[j + rlwe_decomposer.decomposition_count()]);
let mut m_back = vec![0u64; rlwe_n]; let mut m_back = vec![0u64; rlwe_n];
decrypt_rlwe( decrypt_rlwe(
@ -2695,7 +2723,7 @@ mod tests {
} }
// RLWE'(m) // RLWE'(m)
for j in 0..rlwe_decomposer.d() {
for j in 0..rlwe_decomposer.decomposition_count() {
// X^{s_lwe[i]}*B_j // X^{s_lwe[i]}*B_j
let mut m_ideal = m_si.clone(); let mut m_ideal = m_si.clone();
rlwe_modop rlwe_modop
@ -2703,8 +2731,12 @@ mod tests {
// RLWE(X^{s_lwe[i]}*B_j) // RLWE(X^{s_lwe[i]}*B_j)
let mut rlwe_ct = vec![vec![0u64; rlwe_n]; 2]; let mut rlwe_ct = vec![vec![0u64; rlwe_n]; 2];
rlwe_ct[0].copy_from_slice(&rgsw_ct_i[j + (2 * rlwe_decomposer.d())]);
rlwe_ct[1].copy_from_slice(&rgsw_ct_i[j + (3 * rlwe_decomposer.d())]);
rlwe_ct[0].copy_from_slice(
&rgsw_ct_i[j + (2 * rlwe_decomposer.decomposition_count())],
);
rlwe_ct[1].copy_from_slice(
&rgsw_ct_i[j + (3 * rlwe_decomposer.decomposition_count())],
);
let mut m_back = vec![0u64; rlwe_n]; let mut m_back = vec![0u64; rlwe_n];
decrypt_rlwe( decrypt_rlwe(
@ -2759,13 +2791,14 @@ mod tests {
); );
// RLWE(m*X^{s[i]}) = RLWE(m) x RGSW(X^{s[i]}) // RLWE(m*X^{s[i]}) = RLWE(m) x RGSW(X^{s[i]})
let mut rlwe_after = RlweCiphertext::<_, DefaultSecureRng>::from_raw(
vec![vec![0u64; rlwe_n], m.clone()],
true,
);
let mut rlwe_after = RlweCiphertext::<_, DefaultSecureRng>::new_trivial(vec![
vec![0u64; rlwe_n],
m.clone(),
]);
// let mut rlwe_after = // let mut rlwe_after =
// RlweCiphertext::<_, DefaultSecureRng>::from_raw(rlwe_ct.clone(), false); // RlweCiphertext::<_, DefaultSecureRng>::from_raw(rlwe_ct.clone(), false);
let mut scratch = vec![vec![0u64; rlwe_n]; rlwe_decomposer.d() + 2];
let mut scratch =
vec![vec![0u64; rlwe_n]; rlwe_decomposer.decomposition_count() + 2];
rlwe_by_rgsw( rlwe_by_rgsw(
&mut rlwe_after, &mut rlwe_after,
&rgsw_ct_i, &rgsw_ct_i,

+ 190
- 40
src/bool/parameters.rs

@ -1,54 +1,204 @@
use crate::decomposer::Decomposer;
#[derive(Clone, PartialEq)] #[derive(Clone, PartialEq)]
pub(super) struct BoolParameters<El> { pub(super) struct BoolParameters<El> {
pub(super) rlwe_q: El,
pub(super) rlwe_logq: usize,
pub(super) lwe_q: El,
pub(super) lwe_logq: usize,
pub(super) br_q: usize,
pub(super) rlwe_n: usize,
pub(super) lwe_n: usize,
pub(super) d_rgsw: usize,
pub(super) logb_rgsw: usize,
pub(super) d_lwe: usize,
pub(super) logb_lwe: usize,
pub(super) g: usize,
pub(super) w: usize,
rlwe_q: Modulus<El>,
lwe_q: Modulus<El>,
br_q: Modulus<El>,
rlwe_n: PolynomialSize,
lwe_n: LweDimension,
lwe_decomposer_base: DecompostionLogBase,
lwe_decomposer_count: DecompositionCount,
rlrg_decomposer_base: DecompostionLogBase,
/// RLWE x RGSW decomposition count for (part A, part B)
rlrg_decomposer_count: (DecompositionCount, DecompositionCount),
rgrg_decomposer_base: DecompostionLogBase,
/// RGSW x RGSW decomposition count for (part A, part B)
rgrg_decomposer_count: (DecompositionCount, DecompositionCount),
auto_decomposer_base: DecompostionLogBase,
auto_decomposer_count: DecompositionCount,
g: usize,
w: usize,
}
impl<El> BoolParameters<El> {
pub(crate) fn rlwe_q(&self) -> &Modulus<El> {
&self.rlwe_q
}
pub(crate) fn lwe_q(&self) -> &Modulus<El> {
&self.lwe_q
}
pub(crate) fn br_q(&self) -> &Modulus<El> {
&self.br_q
}
pub(crate) fn rlwe_n(&self) -> &PolynomialSize {
&self.rlwe_n
}
pub(crate) fn lwe_n(&self) -> &LweDimension {
&self.lwe_n
}
pub(crate) fn g(&self) -> usize {
self.g
}
pub(crate) fn w(&self) -> usize {
self.w
}
pub(crate) fn rlwe_rgsw_decomposition_base(&self) -> DecompostionLogBase {
self.rlrg_decomposer_base
}
pub(crate) fn rlwe_rgsw_decomposition_count(&self) -> (DecompositionCount, DecompositionCount) {
self.rlrg_decomposer_count
}
pub(crate) fn rgsw_rgsw_decomposition_base(&self) -> DecompostionLogBase {
self.rgrg_decomposer_base
}
pub(crate) fn rgsw_rgsw_decomposition_count(&self) -> (DecompositionCount, DecompositionCount) {
self.rgrg_decomposer_count
}
pub(crate) fn auto_decomposition_base(&self) -> DecompostionLogBase {
self.auto_decomposer_base
}
pub(crate) fn auto_decomposition_count(&self) -> DecompositionCount {
self.auto_decomposer_count
}
pub(crate) fn lwe_decomposition_base(&self) -> DecompostionLogBase {
self.lwe_decomposer_base
}
pub(crate) fn lwe_decomposition_count(&self) -> DecompositionCount {
self.lwe_decomposer_count
}
pub(crate) fn rgsw_rgsw_decomposer<D: Decomposer<Element = El>>(&self) -> (D, D)
where
El: Copy,
{
(
// A
D::new(
self.rlwe_q.0,
self.rgrg_decomposer_base.0,
self.rgrg_decomposer_count.0 .0,
),
// B
D::new(
self.rlwe_q.0,
self.rgrg_decomposer_base.0,
self.rgrg_decomposer_count.1 .0,
),
)
}
pub(crate) fn auto_decomposer<D: Decomposer<Element = El>>(&self) -> D
where
El: Copy,
{
D::new(
self.rlwe_q.0,
self.auto_decomposer_base.0,
self.auto_decomposer_count.0,
)
}
pub(crate) fn lwe_decomposer<D: Decomposer<Element = El>>(&self) -> D
where
El: Copy,
{
D::new(
self.lwe_q.0,
self.lwe_decomposer_base.0,
self.lwe_decomposer_count.0,
)
}
pub(crate) fn rlwe_rgsw_decomposer<D: Decomposer<Element = El>>(&self) -> (D, D)
where
El: Copy,
{
(
// A
D::new(
self.rlwe_q.0,
self.rlrg_decomposer_base.0,
self.rlrg_decomposer_count.0 .0,
),
// B
D::new(
self.rlwe_q.0,
self.rlrg_decomposer_base.0,
self.rlrg_decomposer_count.1 .0,
),
)
}
}
#[derive(Clone, Copy, PartialEq)]
struct DecompostionLogBase(pub(crate) usize);
impl AsRef<usize> for DecompostionLogBase {
fn as_ref(&self) -> &usize {
&self.0
}
}
#[derive(Clone, Copy, PartialEq)]
struct DecompositionCount(pub(crate) usize);
impl AsRef<usize> for DecompositionCount {
fn as_ref(&self) -> &usize {
&self.0
}
} }
// impl<El> BoolParameters<El> {
// fn rlwe_q(&self) -> &El {
// &self.rlwe_q
// }
// }
#[derive(Clone, Copy, PartialEq)]
struct LweDimension(pub(crate) usize);
#[derive(Clone, Copy, PartialEq)]
struct PolynomialSize(pub(crate) usize);
#[derive(Clone, Copy, PartialEq)]
struct Modulus<T>(pub(crate) T);
pub(super) const SP_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u64> { pub(super) const SP_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u64> {
rlwe_q: 268369921u64,
rlwe_logq: 28,
lwe_q: 1 << 16,
lwe_logq: 16,
br_q: 1 << 10,
rlwe_n: 1 << 10,
lwe_n: 493,
d_rgsw: 4,
logb_rgsw: 7,
d_lwe: 4,
logb_lwe: 4,
rlwe_q: Modulus(268369921u64),
lwe_q: Modulus(1 << 16),
br_q: Modulus(1 << 10),
rlwe_n: PolynomialSize(1 << 10),
lwe_n: LweDimension(493),
lwe_decomposer_base: DecompostionLogBase(4),
lwe_decomposer_count: DecompositionCount(4),
rlrg_decomposer_base: DecompostionLogBase(7),
rlrg_decomposer_count: (DecompositionCount(4), DecompositionCount(4)),
rgrg_decomposer_base: DecompostionLogBase(7),
rgrg_decomposer_count: (DecompositionCount(4), DecompositionCount(4)),
auto_decomposer_base: DecompostionLogBase(7),
auto_decomposer_count: DecompositionCount(4),
g: 5, g: 5,
w: 1, w: 1,
}; };
pub(super) const MP_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u64> { pub(super) const MP_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u64> {
rlwe_q: 1152921504606830593,
rlwe_logq: 60,
lwe_q: 1 << 20,
lwe_logq: 20,
br_q: 1 << 11,
rlwe_n: 1 << 11,
lwe_n: 500,
d_rgsw: 5,
logb_rgsw: 12,
d_lwe: 5,
logb_lwe: 4,
rlwe_q: Modulus(1152921504606830593),
lwe_q: Modulus(1 << 20),
br_q: Modulus(1 << 11),
rlwe_n: PolynomialSize(1 << 11),
lwe_n: LweDimension(500),
lwe_decomposer_base: DecompostionLogBase(4),
lwe_decomposer_count: DecompositionCount(5),
rlrg_decomposer_base: DecompostionLogBase(12),
rlrg_decomposer_count: (DecompositionCount(5), DecompositionCount(5)),
rgrg_decomposer_base: DecompostionLogBase(12),
rgrg_decomposer_count: (DecompositionCount(5), DecompositionCount(5)),
auto_decomposer_base: DecompostionLogBase(12),
auto_decomposer_count: DecompositionCount(5),
g: 5, g: 5,
w: 1, w: 1,
}; };

+ 29
- 27
src/decomposer.rs

@ -1,5 +1,5 @@
use itertools::Itertools; use itertools::Itertools;
use num_traits::{AsPrimitive, One, PrimInt, ToPrimitive, WrappingSub, Zero};
use num_traits::{AsPrimitive, Num, One, PrimInt, ToPrimitive, WrappingSub, Zero};
use std::{fmt::Debug, marker::PhantomData, ops::Rem}; use std::{fmt::Debug, marker::PhantomData, ops::Rem};
use crate::backend::{ArithmeticOps, ModularOpsU64}; use crate::backend::{ArithmeticOps, ModularOpsU64};
@ -15,9 +15,10 @@ fn gadget_vector(logq: usize, logb: usize, d: usize) -> Vec {
pub trait Decomposer { pub trait Decomposer {
type Element; type Element;
fn new(q: Self::Element, logb: usize, d: usize) -> Self;
//FIXME(Jay): there's no reason why it returns a vec instead of an iterator //FIXME(Jay): there's no reason why it returns a vec instead of an iterator
fn decompose(&self, v: &Self::Element) -> Vec<Self::Element>; fn decompose(&self, v: &Self::Element) -> Vec<Self::Element>;
fn d(&self) -> usize;
fn decomposition_count(&self) -> usize;
} }
// TODO(Jay): Shouldn't Decompose also return corresponding gadget vector ? // TODO(Jay): Shouldn't Decompose also return corresponding gadget vector ?
@ -45,28 +46,6 @@ impl NumInfo for u128 {
} }
impl<T: PrimInt + NumInfo + Debug> DefaultDecomposer<T> { impl<T: PrimInt + NumInfo + Debug> DefaultDecomposer<T> {
pub fn new(q: T, logb: usize, d: usize) -> DefaultDecomposer<T> {
// if q is power of 2, then `BITS - leading_zeros` outputs logq + 1.
let logq = if q & (q - T::one()) == T::zero() {
(T::BITS - q.leading_zeros() - 1) as usize
} else {
(T::BITS - q.leading_zeros()) as usize
};
let d_ideal = (logq as f64 / logb as f64).ceil().to_usize().unwrap();
let ignore_limbs = (d_ideal - d);
let ignore_bits = (d_ideal - d) * logb;
DefaultDecomposer {
q,
logq,
logb,
d,
ignore_bits,
ignore_limbs,
}
}
fn recompose<Op>(&self, limbs: &[T], modq_op: &Op) -> T fn recompose<Op>(&self, limbs: &[T], modq_op: &Op) -> T
where where
Op: ArithmeticOps<Element = T>, Op: ArithmeticOps<Element = T>,
@ -89,10 +68,33 @@ impl DefaultDecomposer {
} }
} }
impl<T: PrimInt + WrappingSub + Debug> Decomposer for DefaultDecomposer<T> {
impl<T: PrimInt + WrappingSub + Debug + NumInfo> Decomposer for DefaultDecomposer<T> {
type Element = T; type Element = T;
fn new(q: T, logb: usize, d: usize) -> DefaultDecomposer<T> {
// if q is power of 2, then `BITS - leading_zeros` outputs logq + 1.
let logq = if q & (q - T::one()) == T::zero() {
(T::BITS - q.leading_zeros() - 1) as usize
} else {
(T::BITS - q.leading_zeros()) as usize
};
let d_ideal = (logq as f64 / logb as f64).ceil().to_usize().unwrap();
let ignore_limbs = (d_ideal - d);
let ignore_bits = (d_ideal - d) * logb;
DefaultDecomposer {
q,
logq,
logb,
d,
ignore_bits,
ignore_limbs,
}
}
fn decompose(&self, value: &T) -> Vec<T> { fn decompose(&self, value: &T) -> Vec<T> {
let mut value = round_value(*value, self.ignore_bits);
let value = round_value(*value, self.ignore_bits);
let q = self.q; let q = self.q;
// if value >= (q >> 1) { // if value >= (q >> 1) {
@ -135,7 +137,7 @@ impl Decomposer for DefaultDecomposer {
return out; return out;
} }
fn d(&self) -> usize {
fn decomposition_count(&self) -> usize {
self.d self.d
} }
} }

+ 2
- 2
src/lwe.rs

@ -111,7 +111,7 @@ pub(crate) fn lwe_key_switch<
operator: &Op, operator: &Op,
decomposer: &D, decomposer: &D,
) { ) {
assert!(lwe_ksk.dimension().0 == ((lwe_in.as_ref().len() - 1) * decomposer.d()));
assert!(lwe_ksk.dimension().0 == ((lwe_in.as_ref().len() - 1) * decomposer.decomposition_count()));
assert!(lwe_out.as_ref().len() == lwe_ksk.dimension().1); assert!(lwe_out.as_ref().len() == lwe_ksk.dimension().1);
let lwe_in_a_decomposed = lwe_in let lwe_in_a_decomposed = lwe_in
@ -274,7 +274,7 @@ mod tests {
use crate::{ use crate::{
backend::{ModInit, ModularOpsU64}, backend::{ModInit, ModularOpsU64},
decomposer::DefaultDecomposer,
decomposer::{Decomposer, DefaultDecomposer},
lwe::{lwe_key_switch, measure_noise_lwe}, lwe::{lwe_key_switch, measure_noise_lwe},
random::DefaultSecureRng, random::DefaultSecureRng,
rgsw::measure_noise, rgsw::measure_noise,

+ 15
- 14
src/rgsw.rs

@ -288,14 +288,14 @@ impl SeededRlweCiphertext {
pub struct RlweCiphertext<M, Rng> { pub struct RlweCiphertext<M, Rng> {
pub(crate) data: M, pub(crate) data: M,
pub(crate) is_trivial: bool, pub(crate) is_trivial: bool,
_phatom: PhantomData<Rng>,
pub(crate) _phatom: PhantomData<Rng>,
} }
impl<M, Rng> RlweCiphertext<M, Rng> { impl<M, Rng> RlweCiphertext<M, Rng> {
pub(crate) fn from_raw(data: M, is_trivial: bool) -> Self {
pub(crate) fn new_trivial(data: M) -> Self {
RlweCiphertext { RlweCiphertext {
data, data,
is_trivial,
is_trivial: true,
_phatom: PhantomData, _phatom: PhantomData,
} }
} }
@ -490,7 +490,7 @@ pub(crate) fn decompose_r>(
R::Element: Copy, R::Element: Copy,
{ {
let ring_size = r.len(); let ring_size = r.len();
let d = decomposer.d();
let d = decomposer.decomposition_count();
for ri in 0..ring_size { for ri in 0..ring_size {
let el_decomposed = decomposer.decompose(&r[ri]); let el_decomposed = decomposer.decompose(&r[ri]);
@ -521,7 +521,7 @@ pub(crate) fn galois_auto<
<MT as Matrix>::R: RowMut, <MT as Matrix>::R: RowMut,
MT::MatElement: Copy + Zero, MT::MatElement: Copy + Zero,
{ {
let d = decomposer.d();
let d = decomposer.decomposition_count();
let (scratch_matrix_d_ring, tmp_rlwe_out) = scratch_matrix_dplus2_ring.split_at_row_mut(d); let (scratch_matrix_d_ring, tmp_rlwe_out) = scratch_matrix_dplus2_ring.split_at_row_mut(d);
@ -625,7 +625,7 @@ pub(crate) fn less1_rlwe_by_rgsw<
<Mmut as Matrix>::R: RowMut, <Mmut as Matrix>::R: RowMut,
<MT as Matrix>::R: RowMut, <MT as Matrix>::R: RowMut,
{ {
let d_rgsw = decomposer.d();
let d_rgsw = decomposer.decomposition_count();
assert!(scratch_matrix_dplus2_ring.dimension() == (d_rgsw + 2, rlwe_in.dimension().1)); assert!(scratch_matrix_dplus2_ring.dimension() == (d_rgsw + 2, rlwe_in.dimension().1));
assert!(rgsw_in.dimension() == (d_rgsw * 4, rlwe_in.dimension().1)); assert!(rgsw_in.dimension() == (d_rgsw * 4, rlwe_in.dimension().1));
@ -717,7 +717,7 @@ pub(crate) fn rlwe_by_rgsw<
<Mmut as Matrix>::R: RowMut, <Mmut as Matrix>::R: RowMut,
<MT as Matrix>::R: RowMut, <MT as Matrix>::R: RowMut,
{ {
let d_rgsw = decomposer.d();
let d_rgsw = decomposer.decomposition_count();
assert!(scratch_matrix_dplus2_ring.dimension() == (d_rgsw + 2, rlwe_in.dimension().1)); assert!(scratch_matrix_dplus2_ring.dimension() == (d_rgsw + 2, rlwe_in.dimension().1));
assert!(rgsw_in.dimension() == (d_rgsw * 4, rlwe_in.dimension().1)); assert!(rgsw_in.dimension() == (d_rgsw * 4, rlwe_in.dimension().1));
@ -818,7 +818,7 @@ pub(crate) fn rgsw_by_rgsw_inplace<
<Mmut as Matrix>::R: RowMut, <Mmut as Matrix>::R: RowMut,
Mmut::MatElement: Copy + Zero, Mmut::MatElement: Copy + Zero,
{ {
let d_rgsw = decomposer.d();
let d_rgsw = decomposer.decomposition_count();
assert!(rgsw_0.dimension().0 == 4 * d_rgsw); assert!(rgsw_0.dimension().0 == 4 * d_rgsw);
let ring_size = rgsw_0.dimension().1; let ring_size = rgsw_0.dimension().1;
assert!(rgsw_1_eval.dimension() == (4 * d_rgsw, ring_size)); assert!(rgsw_1_eval.dimension() == (4 * d_rgsw, ring_size));
@ -1495,14 +1495,14 @@ where
#[cfg(test)] #[cfg(test)]
pub(crate) mod tests { pub(crate) mod tests {
use std::{ops::Mul, vec};
use std::{marker::PhantomData, ops::Mul, vec};
use itertools::{izip, Itertools}; use itertools::{izip, Itertools};
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use crate::{ use crate::{
backend::{ModInit, ModularOpsU64, VectorOps}, backend::{ModInit, ModularOpsU64, VectorOps},
decomposer::DefaultDecomposer,
decomposer::{Decomposer, DefaultDecomposer},
ntt::{self, Ntt, NttBackendU64, NttInit}, ntt::{self, Ntt, NttBackendU64, NttInit},
random::{DefaultSecureRng, NewWithSeed, RandomUniformDist}, random::{DefaultSecureRng, NewWithSeed, RandomUniformDist},
rgsw::{ rgsw::{
@ -2003,10 +2003,11 @@ pub(crate) mod tests {
// RLWE(m) x RGSW(carry_m) // RLWE(m) x RGSW(carry_m)
let mut m = vec![0u64; ring_size as usize]; let mut m = vec![0u64; ring_size as usize];
RandomUniformDist::random_fill(&mut rng, &q, m.as_mut_slice()); RandomUniformDist::random_fill(&mut rng, &q, m.as_mut_slice());
let mut rlwe_ct = RlweCiphertext::<_, DefaultSecureRng>::from_raw(
vec![vec![0u64; ring_size as usize]; 2],
false,
);
let mut rlwe_ct = RlweCiphertext::<_, DefaultSecureRng> {
data: vec![vec![0u64; ring_size as usize]; 2],
is_trivial: false,
_phatom: PhantomData,
};
let mut scratch_matrix_dplus2_ring = vec![vec![0u64; ring_size as usize]; d_rgsw + 2]; let mut scratch_matrix_dplus2_ring = vec![vec![0u64; ring_size as usize]; d_rgsw + 2];
public_key_encrypt_rlwe( public_key_encrypt_rlwe(
&mut rlwe_ct, &mut rlwe_ct,

Loading…
Cancel
Save