mirror of
https://github.com/arnaucube/phantom-zone.git
synced 2026-01-09 15:41:30 +01:00
fix rounding in decom
This commit is contained in:
@@ -326,8 +326,14 @@ pub(super) struct BoolPbsInfo<M: Matrix, Ntt, RlweModOp, LweModOp> {
|
||||
|
||||
impl<M: Matrix, NttOp, RlweModOp, LweModOp> PbsInfo for BoolPbsInfo<M, NttOp, RlweModOp, LweModOp>
|
||||
where
|
||||
M::MatElement:
|
||||
PrimInt + WrappingSub + NumInfo + FromPrimitive + From<bool> + Display + WrappingAdd,
|
||||
M::MatElement: PrimInt
|
||||
+ WrappingSub
|
||||
+ NumInfo
|
||||
+ FromPrimitive
|
||||
+ From<bool>
|
||||
+ Display
|
||||
+ WrappingAdd
|
||||
+ Debug,
|
||||
RlweModOp: ArithmeticOps<Element = M::MatElement> + ShoupMatrixFMA<M::R>,
|
||||
LweModOp: ArithmeticOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>,
|
||||
NttOp: Ntt<Element = M::MatElement>,
|
||||
@@ -2003,7 +2009,8 @@ where
|
||||
+ WrappingSub
|
||||
+ NumInfo
|
||||
+ From<bool>
|
||||
+ WrappingAdd,
|
||||
+ WrappingAdd
|
||||
+ Debug,
|
||||
RlweModOp: VectorOps<Element = M::MatElement>
|
||||
+ ArithmeticOps<Element = M::MatElement>
|
||||
+ ShoupMatrixFMA<M::R>,
|
||||
@@ -2195,7 +2202,9 @@ mod tests {
|
||||
SP_TEST_BOOL_PARAMS,
|
||||
},
|
||||
},
|
||||
evaluator,
|
||||
ntt::NttBackendU64,
|
||||
parameters::OPTIMISED_SMALL_MP_BOOL_PARAMS,
|
||||
random::{RandomElementInModulus, DEFAULT_RNG},
|
||||
rgsw::{
|
||||
self, measure_noise, public_key_encrypt_rlwe, secret_key_encrypt_rlwe,
|
||||
@@ -2216,11 +2225,11 @@ mod tests {
|
||||
ModularOpsU64<CiphertextModulus<u64>>,
|
||||
ModularOpsU64<CiphertextModulus<u64>>,
|
||||
ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>,
|
||||
>::new(SMALL_MP_BOOL_PARAMS);
|
||||
>::new(OPTIMISED_SMALL_MP_BOOL_PARAMS);
|
||||
|
||||
// let (_, collective_pk, _, _, server_key_eval, ideal_client_key) =
|
||||
// _multi_party_all_keygen(&bool_evaluator, 20);
|
||||
let no_of_parties = 16;
|
||||
let no_of_parties = 2;
|
||||
let lwe_q = bool_evaluator.pbs_info.parameters.lwe_q();
|
||||
let rlwe_q = bool_evaluator.pbs_info.parameters.rlwe_q();
|
||||
let lwe_n = bool_evaluator.pbs_info.parameters.lwe_n().0;
|
||||
@@ -2269,7 +2278,7 @@ mod tests {
|
||||
});
|
||||
|
||||
// check noise in freshly encrypted RLWE ciphertext (ie var_fresh)
|
||||
if true {
|
||||
if false {
|
||||
let mut rng = DefaultSecureRng::new();
|
||||
let mut check = Stats { samples: vec![] };
|
||||
for _ in 0..10 {
|
||||
@@ -2343,7 +2352,7 @@ mod tests {
|
||||
bool_evaluator.aggregate_multi_party_server_key_shares(&server_key_shares);
|
||||
|
||||
// Check noise in RGSW ciphertexts of ideal LWE secret elements
|
||||
if false {
|
||||
if true {
|
||||
let mut check = Stats { samples: vec![] };
|
||||
izip!(ideal_lwe_sk.iter(), seeded_server_key.rgsw_cts().iter()).for_each(
|
||||
|(s_i, rgsw_ct_i)| {
|
||||
@@ -2361,6 +2370,10 @@ mod tests {
|
||||
Vec::<u64>::try_convert_from(ideal_rlwe_sk.as_slice(), rlwe_q);
|
||||
rlwe_modop.elwise_neg_mut(&mut neg_s_eval);
|
||||
rlwe_nttop.forward(&mut neg_s_eval);
|
||||
// let tmp_decomp = bool_evaluator
|
||||
// .parameters()
|
||||
// .rgsw_rgsw_decomposer::<DefaultDecomposer<u64>>();
|
||||
// let tmp_gadget = tmp_decomp.a().gadget_vector()
|
||||
for j in 0..rlwe_rgsw_decomposer.a().decomposition_count() {
|
||||
// RLWE(B^{j} * -s[X]*X^{s_lwe[i]})
|
||||
|
||||
@@ -2616,7 +2629,7 @@ mod tests {
|
||||
|
||||
// check noise in RLWE(X^k) after sending RLWE(X) -> RLWE(X^k)using collective
|
||||
// auto key
|
||||
if true {
|
||||
if false {
|
||||
let mut check = Stats { samples: vec![] };
|
||||
let br_q = bool_evaluator.pbs_info.br_q();
|
||||
let g = bool_evaluator.pbs_info.g();
|
||||
@@ -2692,7 +2705,7 @@ mod tests {
|
||||
|
||||
// Check noise growth in ksk
|
||||
// TODO check in LWE key switching keys
|
||||
if true {
|
||||
if false {
|
||||
// 1. encrypt LWE ciphertext
|
||||
// 2. Key switching
|
||||
// 3.
|
||||
|
||||
@@ -486,6 +486,28 @@ pub(crate) const SMALL_MP_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u6
|
||||
variant: ParameterVariant::MultiParty,
|
||||
};
|
||||
|
||||
pub(crate) const OPTIMISED_SMALL_MP_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u64> {
|
||||
rlwe_q: CiphertextModulus::new_non_native(18014398509404161),
|
||||
lwe_q: CiphertextModulus::new_non_native(1 << 15),
|
||||
br_q: 1 << 11,
|
||||
rlwe_n: PolynomialSize(1 << 11),
|
||||
lwe_n: LweDimension(500),
|
||||
lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(11)),
|
||||
rlrg_decomposer_params: (
|
||||
DecompostionLogBase(24),
|
||||
(DecompositionCount(1), DecompositionCount(1)),
|
||||
),
|
||||
rgrg_decomposer_params: Some((
|
||||
DecompostionLogBase(12),
|
||||
(DecompositionCount(3), DecompositionCount(3)),
|
||||
)),
|
||||
auto_decomposer_params: (DecompostionLogBase(20), DecompositionCount(1)),
|
||||
non_interactive_ui_to_s_key_switch_decomposer: None,
|
||||
g: 5,
|
||||
w: 10,
|
||||
variant: ParameterVariant::MultiParty,
|
||||
};
|
||||
|
||||
pub(crate) const NON_INTERACTIVE_SMALL_MP_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u64> {
|
||||
rlwe_q: CiphertextModulus::new_non_native(36028797018820609),
|
||||
lwe_q: CiphertextModulus::new_non_native(1 << 20),
|
||||
|
||||
@@ -11,6 +11,7 @@ use std::{
|
||||
use crate::backend::{ArithmeticOps, ModularOpsU64};
|
||||
|
||||
fn gadget_vector<T: PrimInt>(logq: usize, logb: usize, d: usize) -> Vec<T> {
|
||||
assert!(logq >= (logb * d));
|
||||
let ignored_bits = logq - (logb * d);
|
||||
|
||||
(0..d)
|
||||
@@ -114,7 +115,8 @@ impl<
|
||||
+ WrappingAdd
|
||||
+ NumInfo
|
||||
+ From<bool>
|
||||
+ Display,
|
||||
+ Display
|
||||
+ Debug,
|
||||
> Decomposer for DefaultDecomposer<T>
|
||||
{
|
||||
type Element = T;
|
||||
@@ -128,6 +130,11 @@ impl<
|
||||
(T::BITS - q.leading_zeros()) as usize
|
||||
};
|
||||
|
||||
assert!(
|
||||
logq >= (logb * d),
|
||||
"Decomposer wants logq >= logb*d but got logq={logq}, logb={logb}, d={d}"
|
||||
);
|
||||
|
||||
let ignore_bits = logq - (logb * d);
|
||||
|
||||
DefaultDecomposer {
|
||||
@@ -144,20 +151,19 @@ impl<
|
||||
|
||||
// TODO(Jay): Outline the caveat
|
||||
fn decompose_to_vec(&self, value: &T) -> Vec<T> {
|
||||
let mut value = round_value(*value, self.ignore_bits);
|
||||
|
||||
let q = self.q;
|
||||
let logb = self.logb;
|
||||
let b = T::one() << logb;
|
||||
let full_mask = b - T::one();
|
||||
let bby2 = b >> 1;
|
||||
|
||||
let mut value = *value;
|
||||
if value >= (q >> 1) {
|
||||
value = !(q - value) + T::one()
|
||||
}
|
||||
|
||||
value = round_value(value, self.ignore_bits);
|
||||
let mut out = Vec::with_capacity(self.d);
|
||||
for _ in 0..self.d {
|
||||
for _ in 0..(self.d) {
|
||||
let k_i = value & full_mask;
|
||||
|
||||
value = (value - k_i) >> logb;
|
||||
@@ -178,11 +184,11 @@ impl<
|
||||
}
|
||||
|
||||
fn decompose_iter(&self, value: &T) -> DecomposerIter<T> {
|
||||
let mut value = round_value(*value, self.ignore_bits);
|
||||
|
||||
let mut value = *value;
|
||||
if value >= (self.q >> 1) {
|
||||
value = !(self.q - value) + T::one()
|
||||
}
|
||||
value = round_value(value, self.ignore_bits);
|
||||
|
||||
DecomposerIter {
|
||||
value,
|
||||
@@ -283,12 +289,13 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn decomposition_works() {
|
||||
let logq = 55;
|
||||
let logb = 12;
|
||||
let d = 4;
|
||||
let ring_size = 1 << 11;
|
||||
|
||||
let mut rng = thread_rng();
|
||||
|
||||
for logq in [37, 55] {
|
||||
let logb = 11;
|
||||
let d = 3;
|
||||
let mut stats = vec![Stats::new(); d];
|
||||
|
||||
for i in [true] {
|
||||
@@ -300,20 +307,17 @@ mod tests {
|
||||
let decomposer = DefaultDecomposer::new(q, logb, d);
|
||||
dbg!(decomposer.ignore_bits);
|
||||
let modq_op = ModularOpsU64::new(q);
|
||||
for _ in 0..100000 {
|
||||
for _ in 0..1000000 {
|
||||
let value = rng.gen_range(0..q);
|
||||
let limbs = decomposer.decompose_to_vec(&value);
|
||||
// let limbs_from_iter = decomposer.decompose_iter(&value).collect_vec();
|
||||
// assert_eq!(limbs, limbs_from_iter);
|
||||
let limbs_from_iter = decomposer.decompose_iter(&value).collect_vec();
|
||||
assert_eq!(limbs, limbs_from_iter);
|
||||
let value_back = round_value(
|
||||
decomposer.recompose(&limbs, &modq_op),
|
||||
decomposer.ignore_bits,
|
||||
);
|
||||
let rounded_value = round_value(value, decomposer.ignore_bits);
|
||||
// assert_eq!(
|
||||
// rounded_value, value_back,
|
||||
// "Expected {rounded_value} got {value_back} for q={q}"
|
||||
// );
|
||||
assert!((rounded_value as i64 - value_back as i64).abs() <= 1,);
|
||||
|
||||
izip!(stats.iter_mut(), limbs.iter()).for_each(|(s, l)| {
|
||||
s.add_more(&vec![q.map_element_to_i64(l)]);
|
||||
@@ -330,3 +334,4 @@ mod tests {
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1156,8 +1156,8 @@ pub(crate) mod tests {
|
||||
let logq = 55;
|
||||
let ring_size = 1 << 11;
|
||||
let q = generate_prime(logq, ring_size as u64, 1u64 << logq).unwrap();
|
||||
let d = 12;
|
||||
let logb = 4;
|
||||
let d = 2;
|
||||
let logb = 12;
|
||||
let decomposer = DefaultDecomposer::new(q, logb, d);
|
||||
|
||||
let ntt_op = NttBackendU64::new(&q, ring_size as usize);
|
||||
@@ -1169,16 +1169,42 @@ pub(crate) mod tests {
|
||||
for _ in 0..10 {
|
||||
let mut a = vec![0u64; ring_size];
|
||||
RandomFillUniformInModulus::random_fill(&mut rng, &q, a.as_mut());
|
||||
let mut e = vec![1u64; ring_size];
|
||||
// RandomFillGaussianInModulus::random_fill(&mut rng, &q, e.as_mut());
|
||||
let mut m = vec![0u64; ring_size];
|
||||
RandomFillGaussianInModulus::random_fill(&mut rng, &q, m.as_mut());
|
||||
|
||||
let mut sk = vec![0u64; ring_size];
|
||||
RandomFillGaussianInModulus::random_fill(&mut rng, &q, sk.as_mut());
|
||||
let mut sk_eval = sk.clone();
|
||||
ntt_op.forward(sk_eval.as_mut_slice());
|
||||
|
||||
let gadget_vector = decomposer.gadget_vector();
|
||||
|
||||
// ksk (beta e)
|
||||
let mut ksk = vec![vec![0u64; ring_size]; decomposer.decomposition_count()];
|
||||
izip!(ksk.iter_rows_mut(), gadget_vector.iter()).for_each(|(row, beta)| {
|
||||
row.as_mut_slice().copy_from_slice(e.as_ref());
|
||||
mod_op.elwise_scalar_mul_mut(row.as_mut_slice(), beta);
|
||||
let mut ksk_part_b = vec![vec![0u64; ring_size]; decomposer.decomposition_count()];
|
||||
let mut ksk_part_a = vec![vec![0u64; ring_size]; decomposer.decomposition_count()];
|
||||
izip!(
|
||||
ksk_part_b.iter_rows_mut(),
|
||||
ksk_part_a.iter_rows_mut(),
|
||||
gadget_vector.iter()
|
||||
)
|
||||
.for_each(|(part_b, part_a, beta)| {
|
||||
RandomFillUniformInModulus::random_fill(&mut rng, &q, part_a.as_mut());
|
||||
|
||||
// a * s
|
||||
let mut tmp = part_a.to_vec();
|
||||
ntt_op.forward(tmp.as_mut());
|
||||
mod_op.elwise_mul_mut(tmp.as_mut(), sk_eval.as_ref());
|
||||
ntt_op.backward(tmp.as_mut());
|
||||
|
||||
// a*s + e + beta m
|
||||
RandomFillGaussianInModulus::random_fill(&mut rng, &q, part_b.as_mut());
|
||||
// println!("E: {:?}", &part_b);
|
||||
// a*s + e
|
||||
mod_op.elwise_add_mut(part_b.as_mut_slice(), tmp.as_ref());
|
||||
// a*s + e + beta m
|
||||
let mut tmp = m.to_vec();
|
||||
mod_op.elwise_scalar_mul_mut(tmp.as_mut_slice(), beta);
|
||||
mod_op.elwise_add_mut(part_b.as_mut_slice(), tmp.as_ref());
|
||||
});
|
||||
|
||||
// decompose a
|
||||
@@ -1195,35 +1221,60 @@ pub(crate) mod tests {
|
||||
|
||||
// println!("Last limb");
|
||||
|
||||
// decomp_a * ksk(beta e)
|
||||
ksk.iter_mut()
|
||||
// decomp_a * ksk(beta m)
|
||||
ksk_part_b
|
||||
.iter_mut()
|
||||
.for_each(|r| ntt_op.forward(r.as_mut_slice()));
|
||||
ksk_part_a
|
||||
.iter_mut()
|
||||
.for_each(|r| ntt_op.forward(r.as_mut_slice()));
|
||||
decomposed_a
|
||||
.iter_mut()
|
||||
.for_each(|r| ntt_op.forward(r.as_mut_slice()));
|
||||
let mut out = vec![0u64; ring_size];
|
||||
izip!(decomposed_a.iter(), ksk.iter()).for_each(|(a, b)| {
|
||||
// out += a * b
|
||||
let mut a_clone = a.clone();
|
||||
mod_op.elwise_mul_mut(a_clone.as_mut_slice(), b.as_ref());
|
||||
mod_op.elwise_add_mut(out.as_mut_slice(), a_clone.as_ref());
|
||||
});
|
||||
ntt_op.backward(out.as_mut_slice());
|
||||
let mut out = vec![vec![0u64; ring_size]; 2];
|
||||
izip!(decomposed_a.iter(), ksk_part_b.iter(), ksk_part_a.iter()).for_each(
|
||||
|(d_a, part_b, part_a)| {
|
||||
// out_a += d_a * part_a
|
||||
let mut d_a_clone = d_a.clone();
|
||||
mod_op.elwise_mul_mut(d_a_clone.as_mut_slice(), part_a.as_ref());
|
||||
mod_op.elwise_add_mut(out[0].as_mut_slice(), d_a_clone.as_ref());
|
||||
|
||||
// out_b += d_a * part_b
|
||||
let mut d_a_clone = d_a.clone();
|
||||
mod_op.elwise_mul_mut(d_a_clone.as_mut_slice(), part_b.as_ref());
|
||||
mod_op.elwise_add_mut(out[1].as_mut_slice(), d_a_clone.as_ref());
|
||||
},
|
||||
);
|
||||
out.iter_mut()
|
||||
.for_each(|r| ntt_op.backward(r.as_mut_slice()));
|
||||
|
||||
let out_back = {
|
||||
// decrypt
|
||||
// a*s
|
||||
ntt_op.forward(out[0].as_mut());
|
||||
mod_op.elwise_mul_mut(out[0].as_mut(), sk_eval.as_ref());
|
||||
ntt_op.backward(out[0].as_mut());
|
||||
|
||||
// b - a*s
|
||||
let tmp = (out[0]).clone();
|
||||
mod_op.elwise_sub_mut(out[1].as_mut(), tmp.as_ref());
|
||||
out.remove(1)
|
||||
};
|
||||
|
||||
let out_expected = {
|
||||
let mut a_clone = a.clone();
|
||||
let mut e_clone = e.clone();
|
||||
let mut m_clone = m.clone();
|
||||
|
||||
ntt_op.forward(a_clone.as_mut_slice());
|
||||
ntt_op.forward(e_clone.as_mut_slice());
|
||||
ntt_op.forward(m_clone.as_mut_slice());
|
||||
|
||||
mod_op.elwise_mul_mut(a_clone.as_mut_slice(), e_clone.as_mut_slice());
|
||||
mod_op.elwise_mul_mut(a_clone.as_mut_slice(), m_clone.as_mut_slice());
|
||||
ntt_op.backward(a_clone.as_mut_slice());
|
||||
a_clone
|
||||
};
|
||||
|
||||
let mut diff = out_expected;
|
||||
mod_op.elwise_sub_mut(diff.as_mut_slice(), out.as_ref());
|
||||
mod_op.elwise_sub_mut(diff.as_mut_slice(), out_back.as_ref());
|
||||
stats.add_more(&Vec::<i64>::try_convert_from(diff.as_ref(), &q));
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user