mirror of
https://github.com/arnaucube/phantom-zone.git
synced 2026-01-09 23:51:30 +01:00
fix rounding in decom
This commit is contained in:
@@ -326,8 +326,14 @@ pub(super) struct BoolPbsInfo<M: Matrix, Ntt, RlweModOp, LweModOp> {
|
|||||||
|
|
||||||
impl<M: Matrix, NttOp, RlweModOp, LweModOp> PbsInfo for BoolPbsInfo<M, NttOp, RlweModOp, LweModOp>
|
impl<M: Matrix, NttOp, RlweModOp, LweModOp> PbsInfo for BoolPbsInfo<M, NttOp, RlweModOp, LweModOp>
|
||||||
where
|
where
|
||||||
M::MatElement:
|
M::MatElement: PrimInt
|
||||||
PrimInt + WrappingSub + NumInfo + FromPrimitive + From<bool> + Display + WrappingAdd,
|
+ WrappingSub
|
||||||
|
+ NumInfo
|
||||||
|
+ FromPrimitive
|
||||||
|
+ From<bool>
|
||||||
|
+ Display
|
||||||
|
+ WrappingAdd
|
||||||
|
+ Debug,
|
||||||
RlweModOp: ArithmeticOps<Element = M::MatElement> + ShoupMatrixFMA<M::R>,
|
RlweModOp: ArithmeticOps<Element = M::MatElement> + ShoupMatrixFMA<M::R>,
|
||||||
LweModOp: ArithmeticOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>,
|
LweModOp: ArithmeticOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>,
|
||||||
NttOp: Ntt<Element = M::MatElement>,
|
NttOp: Ntt<Element = M::MatElement>,
|
||||||
@@ -2003,7 +2009,8 @@ where
|
|||||||
+ WrappingSub
|
+ WrappingSub
|
||||||
+ NumInfo
|
+ NumInfo
|
||||||
+ From<bool>
|
+ From<bool>
|
||||||
+ WrappingAdd,
|
+ WrappingAdd
|
||||||
|
+ Debug,
|
||||||
RlweModOp: VectorOps<Element = M::MatElement>
|
RlweModOp: VectorOps<Element = M::MatElement>
|
||||||
+ ArithmeticOps<Element = M::MatElement>
|
+ ArithmeticOps<Element = M::MatElement>
|
||||||
+ ShoupMatrixFMA<M::R>,
|
+ ShoupMatrixFMA<M::R>,
|
||||||
@@ -2195,7 +2202,9 @@ mod tests {
|
|||||||
SP_TEST_BOOL_PARAMS,
|
SP_TEST_BOOL_PARAMS,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
evaluator,
|
||||||
ntt::NttBackendU64,
|
ntt::NttBackendU64,
|
||||||
|
parameters::OPTIMISED_SMALL_MP_BOOL_PARAMS,
|
||||||
random::{RandomElementInModulus, DEFAULT_RNG},
|
random::{RandomElementInModulus, DEFAULT_RNG},
|
||||||
rgsw::{
|
rgsw::{
|
||||||
self, measure_noise, public_key_encrypt_rlwe, secret_key_encrypt_rlwe,
|
self, measure_noise, public_key_encrypt_rlwe, secret_key_encrypt_rlwe,
|
||||||
@@ -2216,11 +2225,11 @@ mod tests {
|
|||||||
ModularOpsU64<CiphertextModulus<u64>>,
|
ModularOpsU64<CiphertextModulus<u64>>,
|
||||||
ModularOpsU64<CiphertextModulus<u64>>,
|
ModularOpsU64<CiphertextModulus<u64>>,
|
||||||
ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>,
|
ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>,
|
||||||
>::new(SMALL_MP_BOOL_PARAMS);
|
>::new(OPTIMISED_SMALL_MP_BOOL_PARAMS);
|
||||||
|
|
||||||
// let (_, collective_pk, _, _, server_key_eval, ideal_client_key) =
|
// let (_, collective_pk, _, _, server_key_eval, ideal_client_key) =
|
||||||
// _multi_party_all_keygen(&bool_evaluator, 20);
|
// _multi_party_all_keygen(&bool_evaluator, 20);
|
||||||
let no_of_parties = 16;
|
let no_of_parties = 2;
|
||||||
let lwe_q = bool_evaluator.pbs_info.parameters.lwe_q();
|
let lwe_q = bool_evaluator.pbs_info.parameters.lwe_q();
|
||||||
let rlwe_q = bool_evaluator.pbs_info.parameters.rlwe_q();
|
let rlwe_q = bool_evaluator.pbs_info.parameters.rlwe_q();
|
||||||
let lwe_n = bool_evaluator.pbs_info.parameters.lwe_n().0;
|
let lwe_n = bool_evaluator.pbs_info.parameters.lwe_n().0;
|
||||||
@@ -2269,7 +2278,7 @@ mod tests {
|
|||||||
});
|
});
|
||||||
|
|
||||||
// check noise in freshly encrypted RLWE ciphertext (ie var_fresh)
|
// check noise in freshly encrypted RLWE ciphertext (ie var_fresh)
|
||||||
if true {
|
if false {
|
||||||
let mut rng = DefaultSecureRng::new();
|
let mut rng = DefaultSecureRng::new();
|
||||||
let mut check = Stats { samples: vec![] };
|
let mut check = Stats { samples: vec![] };
|
||||||
for _ in 0..10 {
|
for _ in 0..10 {
|
||||||
@@ -2343,7 +2352,7 @@ mod tests {
|
|||||||
bool_evaluator.aggregate_multi_party_server_key_shares(&server_key_shares);
|
bool_evaluator.aggregate_multi_party_server_key_shares(&server_key_shares);
|
||||||
|
|
||||||
// Check noise in RGSW ciphertexts of ideal LWE secret elements
|
// Check noise in RGSW ciphertexts of ideal LWE secret elements
|
||||||
if false {
|
if true {
|
||||||
let mut check = Stats { samples: vec![] };
|
let mut check = Stats { samples: vec![] };
|
||||||
izip!(ideal_lwe_sk.iter(), seeded_server_key.rgsw_cts().iter()).for_each(
|
izip!(ideal_lwe_sk.iter(), seeded_server_key.rgsw_cts().iter()).for_each(
|
||||||
|(s_i, rgsw_ct_i)| {
|
|(s_i, rgsw_ct_i)| {
|
||||||
@@ -2361,6 +2370,10 @@ mod tests {
|
|||||||
Vec::<u64>::try_convert_from(ideal_rlwe_sk.as_slice(), rlwe_q);
|
Vec::<u64>::try_convert_from(ideal_rlwe_sk.as_slice(), rlwe_q);
|
||||||
rlwe_modop.elwise_neg_mut(&mut neg_s_eval);
|
rlwe_modop.elwise_neg_mut(&mut neg_s_eval);
|
||||||
rlwe_nttop.forward(&mut neg_s_eval);
|
rlwe_nttop.forward(&mut neg_s_eval);
|
||||||
|
// let tmp_decomp = bool_evaluator
|
||||||
|
// .parameters()
|
||||||
|
// .rgsw_rgsw_decomposer::<DefaultDecomposer<u64>>();
|
||||||
|
// let tmp_gadget = tmp_decomp.a().gadget_vector()
|
||||||
for j in 0..rlwe_rgsw_decomposer.a().decomposition_count() {
|
for j in 0..rlwe_rgsw_decomposer.a().decomposition_count() {
|
||||||
// RLWE(B^{j} * -s[X]*X^{s_lwe[i]})
|
// RLWE(B^{j} * -s[X]*X^{s_lwe[i]})
|
||||||
|
|
||||||
@@ -2616,7 +2629,7 @@ mod tests {
|
|||||||
|
|
||||||
// check noise in RLWE(X^k) after sending RLWE(X) -> RLWE(X^k)using collective
|
// check noise in RLWE(X^k) after sending RLWE(X) -> RLWE(X^k)using collective
|
||||||
// auto key
|
// auto key
|
||||||
if true {
|
if false {
|
||||||
let mut check = Stats { samples: vec![] };
|
let mut check = Stats { samples: vec![] };
|
||||||
let br_q = bool_evaluator.pbs_info.br_q();
|
let br_q = bool_evaluator.pbs_info.br_q();
|
||||||
let g = bool_evaluator.pbs_info.g();
|
let g = bool_evaluator.pbs_info.g();
|
||||||
@@ -2692,7 +2705,7 @@ mod tests {
|
|||||||
|
|
||||||
// Check noise growth in ksk
|
// Check noise growth in ksk
|
||||||
// TODO check in LWE key switching keys
|
// TODO check in LWE key switching keys
|
||||||
if true {
|
if false {
|
||||||
// 1. encrypt LWE ciphertext
|
// 1. encrypt LWE ciphertext
|
||||||
// 2. Key switching
|
// 2. Key switching
|
||||||
// 3.
|
// 3.
|
||||||
|
|||||||
@@ -486,6 +486,28 @@ pub(crate) const SMALL_MP_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u6
|
|||||||
variant: ParameterVariant::MultiParty,
|
variant: ParameterVariant::MultiParty,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub(crate) const OPTIMISED_SMALL_MP_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u64> {
|
||||||
|
rlwe_q: CiphertextModulus::new_non_native(18014398509404161),
|
||||||
|
lwe_q: CiphertextModulus::new_non_native(1 << 15),
|
||||||
|
br_q: 1 << 11,
|
||||||
|
rlwe_n: PolynomialSize(1 << 11),
|
||||||
|
lwe_n: LweDimension(500),
|
||||||
|
lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(11)),
|
||||||
|
rlrg_decomposer_params: (
|
||||||
|
DecompostionLogBase(24),
|
||||||
|
(DecompositionCount(1), DecompositionCount(1)),
|
||||||
|
),
|
||||||
|
rgrg_decomposer_params: Some((
|
||||||
|
DecompostionLogBase(12),
|
||||||
|
(DecompositionCount(3), DecompositionCount(3)),
|
||||||
|
)),
|
||||||
|
auto_decomposer_params: (DecompostionLogBase(20), DecompositionCount(1)),
|
||||||
|
non_interactive_ui_to_s_key_switch_decomposer: None,
|
||||||
|
g: 5,
|
||||||
|
w: 10,
|
||||||
|
variant: ParameterVariant::MultiParty,
|
||||||
|
};
|
||||||
|
|
||||||
pub(crate) const NON_INTERACTIVE_SMALL_MP_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u64> {
|
pub(crate) const NON_INTERACTIVE_SMALL_MP_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u64> {
|
||||||
rlwe_q: CiphertextModulus::new_non_native(36028797018820609),
|
rlwe_q: CiphertextModulus::new_non_native(36028797018820609),
|
||||||
lwe_q: CiphertextModulus::new_non_native(1 << 20),
|
lwe_q: CiphertextModulus::new_non_native(1 << 20),
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ use std::{
|
|||||||
use crate::backend::{ArithmeticOps, ModularOpsU64};
|
use crate::backend::{ArithmeticOps, ModularOpsU64};
|
||||||
|
|
||||||
fn gadget_vector<T: PrimInt>(logq: usize, logb: usize, d: usize) -> Vec<T> {
|
fn gadget_vector<T: PrimInt>(logq: usize, logb: usize, d: usize) -> Vec<T> {
|
||||||
|
assert!(logq >= (logb * d));
|
||||||
let ignored_bits = logq - (logb * d);
|
let ignored_bits = logq - (logb * d);
|
||||||
|
|
||||||
(0..d)
|
(0..d)
|
||||||
@@ -114,7 +115,8 @@ impl<
|
|||||||
+ WrappingAdd
|
+ WrappingAdd
|
||||||
+ NumInfo
|
+ NumInfo
|
||||||
+ From<bool>
|
+ From<bool>
|
||||||
+ Display,
|
+ Display
|
||||||
|
+ Debug,
|
||||||
> Decomposer for DefaultDecomposer<T>
|
> Decomposer for DefaultDecomposer<T>
|
||||||
{
|
{
|
||||||
type Element = T;
|
type Element = T;
|
||||||
@@ -128,6 +130,11 @@ impl<
|
|||||||
(T::BITS - q.leading_zeros()) as usize
|
(T::BITS - q.leading_zeros()) as usize
|
||||||
};
|
};
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
logq >= (logb * d),
|
||||||
|
"Decomposer wants logq >= logb*d but got logq={logq}, logb={logb}, d={d}"
|
||||||
|
);
|
||||||
|
|
||||||
let ignore_bits = logq - (logb * d);
|
let ignore_bits = logq - (logb * d);
|
||||||
|
|
||||||
DefaultDecomposer {
|
DefaultDecomposer {
|
||||||
@@ -144,20 +151,19 @@ impl<
|
|||||||
|
|
||||||
// TODO(Jay): Outline the caveat
|
// TODO(Jay): Outline the caveat
|
||||||
fn decompose_to_vec(&self, value: &T) -> Vec<T> {
|
fn decompose_to_vec(&self, value: &T) -> Vec<T> {
|
||||||
let mut value = round_value(*value, self.ignore_bits);
|
|
||||||
|
|
||||||
let q = self.q;
|
let q = self.q;
|
||||||
let logb = self.logb;
|
let logb = self.logb;
|
||||||
let b = T::one() << logb;
|
let b = T::one() << logb;
|
||||||
let full_mask = b - T::one();
|
let full_mask = b - T::one();
|
||||||
let bby2 = b >> 1;
|
let bby2 = b >> 1;
|
||||||
|
|
||||||
|
let mut value = *value;
|
||||||
if value >= (q >> 1) {
|
if value >= (q >> 1) {
|
||||||
value = !(q - value) + T::one()
|
value = !(q - value) + T::one()
|
||||||
}
|
}
|
||||||
|
value = round_value(value, self.ignore_bits);
|
||||||
let mut out = Vec::with_capacity(self.d);
|
let mut out = Vec::with_capacity(self.d);
|
||||||
for _ in 0..self.d {
|
for _ in 0..(self.d) {
|
||||||
let k_i = value & full_mask;
|
let k_i = value & full_mask;
|
||||||
|
|
||||||
value = (value - k_i) >> logb;
|
value = (value - k_i) >> logb;
|
||||||
@@ -178,11 +184,11 @@ impl<
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn decompose_iter(&self, value: &T) -> DecomposerIter<T> {
|
fn decompose_iter(&self, value: &T) -> DecomposerIter<T> {
|
||||||
let mut value = round_value(*value, self.ignore_bits);
|
let mut value = *value;
|
||||||
|
|
||||||
if value >= (self.q >> 1) {
|
if value >= (self.q >> 1) {
|
||||||
value = !(self.q - value) + T::one()
|
value = !(self.q - value) + T::one()
|
||||||
}
|
}
|
||||||
|
value = round_value(value, self.ignore_bits);
|
||||||
|
|
||||||
DecomposerIter {
|
DecomposerIter {
|
||||||
value,
|
value,
|
||||||
@@ -283,12 +289,13 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn decomposition_works() {
|
fn decomposition_works() {
|
||||||
let logq = 55;
|
|
||||||
let logb = 12;
|
|
||||||
let d = 4;
|
|
||||||
let ring_size = 1 << 11;
|
let ring_size = 1 << 11;
|
||||||
|
|
||||||
let mut rng = thread_rng();
|
let mut rng = thread_rng();
|
||||||
|
|
||||||
|
for logq in [37, 55] {
|
||||||
|
let logb = 11;
|
||||||
|
let d = 3;
|
||||||
let mut stats = vec![Stats::new(); d];
|
let mut stats = vec![Stats::new(); d];
|
||||||
|
|
||||||
for i in [true] {
|
for i in [true] {
|
||||||
@@ -300,20 +307,17 @@ mod tests {
|
|||||||
let decomposer = DefaultDecomposer::new(q, logb, d);
|
let decomposer = DefaultDecomposer::new(q, logb, d);
|
||||||
dbg!(decomposer.ignore_bits);
|
dbg!(decomposer.ignore_bits);
|
||||||
let modq_op = ModularOpsU64::new(q);
|
let modq_op = ModularOpsU64::new(q);
|
||||||
for _ in 0..100000 {
|
for _ in 0..1000000 {
|
||||||
let value = rng.gen_range(0..q);
|
let value = rng.gen_range(0..q);
|
||||||
let limbs = decomposer.decompose_to_vec(&value);
|
let limbs = decomposer.decompose_to_vec(&value);
|
||||||
// let limbs_from_iter = decomposer.decompose_iter(&value).collect_vec();
|
let limbs_from_iter = decomposer.decompose_iter(&value).collect_vec();
|
||||||
// assert_eq!(limbs, limbs_from_iter);
|
assert_eq!(limbs, limbs_from_iter);
|
||||||
let value_back = round_value(
|
let value_back = round_value(
|
||||||
decomposer.recompose(&limbs, &modq_op),
|
decomposer.recompose(&limbs, &modq_op),
|
||||||
decomposer.ignore_bits,
|
decomposer.ignore_bits,
|
||||||
);
|
);
|
||||||
let rounded_value = round_value(value, decomposer.ignore_bits);
|
let rounded_value = round_value(value, decomposer.ignore_bits);
|
||||||
// assert_eq!(
|
assert!((rounded_value as i64 - value_back as i64).abs() <= 1,);
|
||||||
// rounded_value, value_back,
|
|
||||||
// "Expected {rounded_value} got {value_back} for q={q}"
|
|
||||||
// );
|
|
||||||
|
|
||||||
izip!(stats.iter_mut(), limbs.iter()).for_each(|(s, l)| {
|
izip!(stats.iter_mut(), limbs.iter()).for_each(|(s, l)| {
|
||||||
s.add_more(&vec![q.map_element_to_i64(l)]);
|
s.add_more(&vec![q.map_element_to_i64(l)]);
|
||||||
@@ -330,3 +334,4 @@ mod tests {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1156,8 +1156,8 @@ pub(crate) mod tests {
|
|||||||
let logq = 55;
|
let logq = 55;
|
||||||
let ring_size = 1 << 11;
|
let ring_size = 1 << 11;
|
||||||
let q = generate_prime(logq, ring_size as u64, 1u64 << logq).unwrap();
|
let q = generate_prime(logq, ring_size as u64, 1u64 << logq).unwrap();
|
||||||
let d = 12;
|
let d = 2;
|
||||||
let logb = 4;
|
let logb = 12;
|
||||||
let decomposer = DefaultDecomposer::new(q, logb, d);
|
let decomposer = DefaultDecomposer::new(q, logb, d);
|
||||||
|
|
||||||
let ntt_op = NttBackendU64::new(&q, ring_size as usize);
|
let ntt_op = NttBackendU64::new(&q, ring_size as usize);
|
||||||
@@ -1169,16 +1169,42 @@ pub(crate) mod tests {
|
|||||||
for _ in 0..10 {
|
for _ in 0..10 {
|
||||||
let mut a = vec![0u64; ring_size];
|
let mut a = vec![0u64; ring_size];
|
||||||
RandomFillUniformInModulus::random_fill(&mut rng, &q, a.as_mut());
|
RandomFillUniformInModulus::random_fill(&mut rng, &q, a.as_mut());
|
||||||
let mut e = vec![1u64; ring_size];
|
let mut m = vec![0u64; ring_size];
|
||||||
// RandomFillGaussianInModulus::random_fill(&mut rng, &q, e.as_mut());
|
RandomFillGaussianInModulus::random_fill(&mut rng, &q, m.as_mut());
|
||||||
|
|
||||||
|
let mut sk = vec![0u64; ring_size];
|
||||||
|
RandomFillGaussianInModulus::random_fill(&mut rng, &q, sk.as_mut());
|
||||||
|
let mut sk_eval = sk.clone();
|
||||||
|
ntt_op.forward(sk_eval.as_mut_slice());
|
||||||
|
|
||||||
let gadget_vector = decomposer.gadget_vector();
|
let gadget_vector = decomposer.gadget_vector();
|
||||||
|
|
||||||
// ksk (beta e)
|
// ksk (beta e)
|
||||||
let mut ksk = vec![vec![0u64; ring_size]; decomposer.decomposition_count()];
|
let mut ksk_part_b = vec![vec![0u64; ring_size]; decomposer.decomposition_count()];
|
||||||
izip!(ksk.iter_rows_mut(), gadget_vector.iter()).for_each(|(row, beta)| {
|
let mut ksk_part_a = vec![vec![0u64; ring_size]; decomposer.decomposition_count()];
|
||||||
row.as_mut_slice().copy_from_slice(e.as_ref());
|
izip!(
|
||||||
mod_op.elwise_scalar_mul_mut(row.as_mut_slice(), beta);
|
ksk_part_b.iter_rows_mut(),
|
||||||
|
ksk_part_a.iter_rows_mut(),
|
||||||
|
gadget_vector.iter()
|
||||||
|
)
|
||||||
|
.for_each(|(part_b, part_a, beta)| {
|
||||||
|
RandomFillUniformInModulus::random_fill(&mut rng, &q, part_a.as_mut());
|
||||||
|
|
||||||
|
// a * s
|
||||||
|
let mut tmp = part_a.to_vec();
|
||||||
|
ntt_op.forward(tmp.as_mut());
|
||||||
|
mod_op.elwise_mul_mut(tmp.as_mut(), sk_eval.as_ref());
|
||||||
|
ntt_op.backward(tmp.as_mut());
|
||||||
|
|
||||||
|
// a*s + e + beta m
|
||||||
|
RandomFillGaussianInModulus::random_fill(&mut rng, &q, part_b.as_mut());
|
||||||
|
// println!("E: {:?}", &part_b);
|
||||||
|
// a*s + e
|
||||||
|
mod_op.elwise_add_mut(part_b.as_mut_slice(), tmp.as_ref());
|
||||||
|
// a*s + e + beta m
|
||||||
|
let mut tmp = m.to_vec();
|
||||||
|
mod_op.elwise_scalar_mul_mut(tmp.as_mut_slice(), beta);
|
||||||
|
mod_op.elwise_add_mut(part_b.as_mut_slice(), tmp.as_ref());
|
||||||
});
|
});
|
||||||
|
|
||||||
// decompose a
|
// decompose a
|
||||||
@@ -1195,35 +1221,60 @@ pub(crate) mod tests {
|
|||||||
|
|
||||||
// println!("Last limb");
|
// println!("Last limb");
|
||||||
|
|
||||||
// decomp_a * ksk(beta e)
|
// decomp_a * ksk(beta m)
|
||||||
ksk.iter_mut()
|
ksk_part_b
|
||||||
|
.iter_mut()
|
||||||
|
.for_each(|r| ntt_op.forward(r.as_mut_slice()));
|
||||||
|
ksk_part_a
|
||||||
|
.iter_mut()
|
||||||
.for_each(|r| ntt_op.forward(r.as_mut_slice()));
|
.for_each(|r| ntt_op.forward(r.as_mut_slice()));
|
||||||
decomposed_a
|
decomposed_a
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.for_each(|r| ntt_op.forward(r.as_mut_slice()));
|
.for_each(|r| ntt_op.forward(r.as_mut_slice()));
|
||||||
let mut out = vec![0u64; ring_size];
|
let mut out = vec![vec![0u64; ring_size]; 2];
|
||||||
izip!(decomposed_a.iter(), ksk.iter()).for_each(|(a, b)| {
|
izip!(decomposed_a.iter(), ksk_part_b.iter(), ksk_part_a.iter()).for_each(
|
||||||
// out += a * b
|
|(d_a, part_b, part_a)| {
|
||||||
let mut a_clone = a.clone();
|
// out_a += d_a * part_a
|
||||||
mod_op.elwise_mul_mut(a_clone.as_mut_slice(), b.as_ref());
|
let mut d_a_clone = d_a.clone();
|
||||||
mod_op.elwise_add_mut(out.as_mut_slice(), a_clone.as_ref());
|
mod_op.elwise_mul_mut(d_a_clone.as_mut_slice(), part_a.as_ref());
|
||||||
});
|
mod_op.elwise_add_mut(out[0].as_mut_slice(), d_a_clone.as_ref());
|
||||||
ntt_op.backward(out.as_mut_slice());
|
|
||||||
|
// out_b += d_a * part_b
|
||||||
|
let mut d_a_clone = d_a.clone();
|
||||||
|
mod_op.elwise_mul_mut(d_a_clone.as_mut_slice(), part_b.as_ref());
|
||||||
|
mod_op.elwise_add_mut(out[1].as_mut_slice(), d_a_clone.as_ref());
|
||||||
|
},
|
||||||
|
);
|
||||||
|
out.iter_mut()
|
||||||
|
.for_each(|r| ntt_op.backward(r.as_mut_slice()));
|
||||||
|
|
||||||
|
let out_back = {
|
||||||
|
// decrypt
|
||||||
|
// a*s
|
||||||
|
ntt_op.forward(out[0].as_mut());
|
||||||
|
mod_op.elwise_mul_mut(out[0].as_mut(), sk_eval.as_ref());
|
||||||
|
ntt_op.backward(out[0].as_mut());
|
||||||
|
|
||||||
|
// b - a*s
|
||||||
|
let tmp = (out[0]).clone();
|
||||||
|
mod_op.elwise_sub_mut(out[1].as_mut(), tmp.as_ref());
|
||||||
|
out.remove(1)
|
||||||
|
};
|
||||||
|
|
||||||
let out_expected = {
|
let out_expected = {
|
||||||
let mut a_clone = a.clone();
|
let mut a_clone = a.clone();
|
||||||
let mut e_clone = e.clone();
|
let mut m_clone = m.clone();
|
||||||
|
|
||||||
ntt_op.forward(a_clone.as_mut_slice());
|
ntt_op.forward(a_clone.as_mut_slice());
|
||||||
ntt_op.forward(e_clone.as_mut_slice());
|
ntt_op.forward(m_clone.as_mut_slice());
|
||||||
|
|
||||||
mod_op.elwise_mul_mut(a_clone.as_mut_slice(), e_clone.as_mut_slice());
|
mod_op.elwise_mul_mut(a_clone.as_mut_slice(), m_clone.as_mut_slice());
|
||||||
ntt_op.backward(a_clone.as_mut_slice());
|
ntt_op.backward(a_clone.as_mut_slice());
|
||||||
a_clone
|
a_clone
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut diff = out_expected;
|
let mut diff = out_expected;
|
||||||
mod_op.elwise_sub_mut(diff.as_mut_slice(), out.as_ref());
|
mod_op.elwise_sub_mut(diff.as_mut_slice(), out_back.as_ref());
|
||||||
stats.add_more(&Vec::<i64>::try_convert_from(diff.as_ref(), &q));
|
stats.add_more(&Vec::<i64>::try_convert_from(diff.as_ref(), &q));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user