mirror of
https://github.com/arnaucube/phantom-zone.git
synced 2026-01-08 23:21:29 +01:00
decomp checks
This commit is contained in:
@@ -12,7 +12,9 @@ use std::{
|
||||
};
|
||||
|
||||
use itertools::{izip, partition, Itertools};
|
||||
use num_traits::{FromPrimitive, Num, One, Pow, PrimInt, ToPrimitive, WrappingSub, Zero};
|
||||
use num_traits::{
|
||||
FromPrimitive, Num, One, Pow, PrimInt, ToPrimitive, WrappingAdd, WrappingSub, Zero,
|
||||
};
|
||||
use rand::Rng;
|
||||
use rand_distr::uniform::SampleUniform;
|
||||
|
||||
@@ -324,7 +326,8 @@ pub(super) struct BoolPbsInfo<M: Matrix, Ntt, RlweModOp, LweModOp> {
|
||||
|
||||
impl<M: Matrix, NttOp, RlweModOp, LweModOp> PbsInfo for BoolPbsInfo<M, NttOp, RlweModOp, LweModOp>
|
||||
where
|
||||
M::MatElement: PrimInt + WrappingSub + NumInfo + FromPrimitive + From<bool> + Display,
|
||||
M::MatElement:
|
||||
PrimInt + WrappingSub + NumInfo + FromPrimitive + From<bool> + Display + WrappingAdd,
|
||||
RlweModOp: ArithmeticOps<Element = M::MatElement> + ShoupMatrixFMA<M::R>,
|
||||
LweModOp: ArithmeticOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>,
|
||||
NttOp: Ntt<Element = M::MatElement>,
|
||||
@@ -515,6 +518,7 @@ where
|
||||
+ NumInfo
|
||||
+ FromPrimitive
|
||||
+ WrappingSub
|
||||
+ WrappingAdd
|
||||
+ SampleUniform
|
||||
+ From<bool>,
|
||||
NttOp: Ntt<Element = M::MatElement>,
|
||||
@@ -1990,8 +1994,16 @@ impl<M, NttOp, RlweModOp, LweModOp, Skey> BooleanGates
|
||||
where
|
||||
M: MatrixMut + MatrixEntity,
|
||||
M::R: RowMut + RowEntity + Clone,
|
||||
M::MatElement:
|
||||
PrimInt + FromPrimitive + One + Copy + Zero + Display + WrappingSub + NumInfo + From<bool>,
|
||||
M::MatElement: PrimInt
|
||||
+ FromPrimitive
|
||||
+ One
|
||||
+ Copy
|
||||
+ Zero
|
||||
+ Display
|
||||
+ WrappingSub
|
||||
+ NumInfo
|
||||
+ From<bool>
|
||||
+ WrappingAdd,
|
||||
RlweModOp: VectorOps<Element = M::MatElement>
|
||||
+ ArithmeticOps<Element = M::MatElement>
|
||||
+ ShoupMatrixFMA<M::R>,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use itertools::Itertools;
|
||||
use num_traits::{AsPrimitive, FromPrimitive, Num, One, PrimInt, ToPrimitive, WrappingSub, Zero};
|
||||
use itertools::{izip, Itertools};
|
||||
use num_traits::{
|
||||
AsPrimitive, FromPrimitive, Num, One, PrimInt, ToPrimitive, WrappingAdd, WrappingSub, Zero,
|
||||
};
|
||||
use std::{
|
||||
fmt::{Debug, Display},
|
||||
marker::PhantomData,
|
||||
@@ -9,11 +11,11 @@ use std::{
|
||||
use crate::backend::{ArithmeticOps, ModularOpsU64};
|
||||
|
||||
fn gadget_vector<T: PrimInt>(logq: usize, logb: usize, d: usize) -> Vec<T> {
|
||||
let d_ideal = (logq as f64 / logb as f64).ceil().to_usize().unwrap();
|
||||
let ignored_limbs = d_ideal - d;
|
||||
(ignored_limbs..ignored_limbs + d)
|
||||
let ignored_bits = logq - (logb * d);
|
||||
|
||||
(0..d)
|
||||
.into_iter()
|
||||
.map(|i| T::one() << (logb * i))
|
||||
.map(|i| T::one() << (logb * i + ignored_bits))
|
||||
.collect_vec()
|
||||
}
|
||||
|
||||
@@ -69,8 +71,6 @@ pub struct DefaultDecomposer<T> {
|
||||
d: usize,
|
||||
/// No. of bits to ignore in rounding
|
||||
ignore_bits: usize,
|
||||
/// No. of limbs to ignore in rounding. Set to ceil(logq / logb) - d
|
||||
ignore_limbs: usize,
|
||||
}
|
||||
|
||||
pub trait NumInfo {
|
||||
@@ -93,15 +93,11 @@ impl<T: PrimInt + NumInfo + Debug> DefaultDecomposer<T> {
|
||||
Op: ArithmeticOps<Element = T>,
|
||||
{
|
||||
let mut value = T::zero();
|
||||
for i in 0..self.d {
|
||||
value = modq_op.add(
|
||||
&value,
|
||||
&(modq_op.mul(
|
||||
&limbs[i],
|
||||
&(T::one() << (self.logb * (i + self.ignore_limbs))),
|
||||
)),
|
||||
)
|
||||
}
|
||||
let gadget_vector = self.gadget_vector();
|
||||
assert!(limbs.len() == gadget_vector.len());
|
||||
izip!(limbs.iter(), gadget_vector.iter())
|
||||
.for_each(|(d_el, beta)| value = modq_op.add(&value, &modq_op.mul(d_el, beta)));
|
||||
|
||||
value
|
||||
}
|
||||
|
||||
@@ -110,8 +106,16 @@ impl<T: PrimInt + NumInfo + Debug> DefaultDecomposer<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo + From<bool> + Display>
|
||||
Decomposer for DefaultDecomposer<T>
|
||||
impl<
|
||||
T: PrimInt
|
||||
+ ToPrimitive
|
||||
+ FromPrimitive
|
||||
+ WrappingSub
|
||||
+ WrappingAdd
|
||||
+ NumInfo
|
||||
+ From<bool>
|
||||
+ Display,
|
||||
> Decomposer for DefaultDecomposer<T>
|
||||
{
|
||||
type Element = T;
|
||||
type Iter = DecomposerIter<T>;
|
||||
@@ -124,9 +128,7 @@ impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo + From<boo
|
||||
(T::BITS - q.leading_zeros()) as usize
|
||||
};
|
||||
|
||||
let d_ideal = (logq as f64 / logb as f64).ceil().to_usize().unwrap();
|
||||
let ignore_limbs = (d_ideal - d);
|
||||
let ignore_bits = (d_ideal - d) * logb;
|
||||
let ignore_bits = logq - (logb * d);
|
||||
|
||||
DefaultDecomposer {
|
||||
q,
|
||||
@@ -137,7 +139,6 @@ impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo + From<boo
|
||||
bby2: T::one() << (logb - 1),
|
||||
d,
|
||||
ignore_bits,
|
||||
ignore_limbs,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -257,23 +258,23 @@ impl<T: PrimInt + From<bool> + WrappingSub + Display> Iterator for DecomposerIte
|
||||
}
|
||||
}
|
||||
|
||||
fn round_value<T: PrimInt>(value: T, ignore_bits: usize) -> T {
|
||||
fn round_value<T: PrimInt + WrappingAdd>(value: T, ignore_bits: usize) -> T {
|
||||
if ignore_bits == 0 {
|
||||
return value;
|
||||
}
|
||||
|
||||
let ignored_msb = (value & ((T::one() << ignore_bits) - T::one())) >> (ignore_bits - 1);
|
||||
(value >> ignore_bits) + ignored_msb
|
||||
(value >> ignore_bits).wrapping_add(&ignored_msb)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use itertools::Itertools;
|
||||
use itertools::{izip, Itertools};
|
||||
use rand::{thread_rng, Rng};
|
||||
|
||||
use crate::{
|
||||
backend::{ModInit, ModularOpsU64},
|
||||
backend::{ModInit, ModularOpsU64, Modulus},
|
||||
decomposer::round_value,
|
||||
utils::{generate_prime, tests::Stats, TryConvertFrom1},
|
||||
};
|
||||
@@ -283,12 +284,12 @@ mod tests {
|
||||
#[test]
|
||||
fn decomposition_works() {
|
||||
let logq = 55;
|
||||
let logb = 11;
|
||||
let d = 5;
|
||||
let logb = 12;
|
||||
let d = 4;
|
||||
let ring_size = 1 << 11;
|
||||
|
||||
let mut rng = thread_rng();
|
||||
let mut stats = Stats::new();
|
||||
let mut stats = vec![Stats::new(); d];
|
||||
|
||||
for i in [true] {
|
||||
let q = if i {
|
||||
@@ -297,26 +298,35 @@ mod tests {
|
||||
1u64 << logq
|
||||
};
|
||||
let decomposer = DefaultDecomposer::new(q, logb, d);
|
||||
dbg!(decomposer.ignore_bits);
|
||||
let modq_op = ModularOpsU64::new(q);
|
||||
for _ in 0..100000 {
|
||||
let value = rng.gen_range(0..q);
|
||||
let limbs = decomposer.decompose_to_vec(&value);
|
||||
let limbs_from_iter = decomposer.decompose_iter(&value).collect_vec();
|
||||
assert_eq!(limbs, limbs_from_iter);
|
||||
// let limbs_from_iter = decomposer.decompose_iter(&value).collect_vec();
|
||||
// assert_eq!(limbs, limbs_from_iter);
|
||||
let value_back = round_value(
|
||||
decomposer.recompose(&limbs, &modq_op),
|
||||
decomposer.ignore_bits,
|
||||
);
|
||||
let rounded_value = round_value(value, decomposer.ignore_bits);
|
||||
assert_eq!(
|
||||
rounded_value, value_back,
|
||||
"Expected {rounded_value} got {value_back} for q={q}"
|
||||
);
|
||||
// assert_eq!(
|
||||
// rounded_value, value_back,
|
||||
// "Expected {rounded_value} got {value_back} for q={q}"
|
||||
// );
|
||||
|
||||
stats.add_more(&Vec::<i64>::try_convert_from(&limbs, &q));
|
||||
izip!(stats.iter_mut(), limbs.iter()).for_each(|(s, l)| {
|
||||
s.add_more(&vec![q.map_element_to_i64(l)]);
|
||||
});
|
||||
}
|
||||
}
|
||||
println!("Mean: {}", stats.mean());
|
||||
println!("Std: {}", stats.std_dev().abs().log2());
|
||||
|
||||
stats.iter().enumerate().for_each(|(index, s)| {
|
||||
println!(
|
||||
"Limb {index} - Mean: {}, Std: {}",
|
||||
s.mean(),
|
||||
s.std_dev().abs().log2()
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
147
src/rgsw/mod.rs
147
src/rgsw/mod.rs
@@ -504,7 +504,7 @@ impl RlweSecret {
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use std::{marker::PhantomData, ops::Mul, vec};
|
||||
use std::{clone, marker::PhantomData, ops::Mul, vec};
|
||||
|
||||
use itertools::{izip, Itertools};
|
||||
use rand::{thread_rng, Rng};
|
||||
@@ -513,13 +513,13 @@ pub(crate) mod tests {
|
||||
backend::{GetModulus, ModInit, ModularOpsU64, Modulus, VectorOps},
|
||||
decomposer::{Decomposer, DefaultDecomposer, RlweDecomposer},
|
||||
ntt::{Ntt, NttBackendU64, NttInit},
|
||||
random::{DefaultSecureRng, RandomFillUniformInModulus},
|
||||
random::{DefaultSecureRng, RandomFillGaussianInModulus, RandomFillUniformInModulus},
|
||||
rgsw::{
|
||||
galois_auto_shoup, rlwe_by_rgsw_shoup, ShoupAutoKeyEvaluationDomain,
|
||||
ShoupRgswCiphertextEvaluationDomain,
|
||||
},
|
||||
utils::{generate_prime, negacyclic_mul, tests::Stats, TryConvertFrom1},
|
||||
Matrix, Secret,
|
||||
Matrix, MatrixMut, Secret,
|
||||
};
|
||||
|
||||
use super::{
|
||||
@@ -1153,105 +1153,80 @@ pub(crate) mod tests {
|
||||
|
||||
#[test]
|
||||
fn some_work() {
|
||||
let logq = 50;
|
||||
let logq = 55;
|
||||
let ring_size = 1 << 11;
|
||||
let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap();
|
||||
let d_rgsw = 10;
|
||||
let logb = 5;
|
||||
let decomposer = (
|
||||
DefaultDecomposer::new(q, logb, d_rgsw),
|
||||
DefaultDecomposer::new(q, logb, d_rgsw),
|
||||
);
|
||||
let q = generate_prime(logq, ring_size as u64, 1u64 << logq).unwrap();
|
||||
let d = 12;
|
||||
let logb = 4;
|
||||
let decomposer = DefaultDecomposer::new(q, logb, d);
|
||||
|
||||
let ntt_op = NttBackendU64::new(&q, ring_size as usize);
|
||||
let mod_op = ModularOpsU64::new(q);
|
||||
let mut rng = DefaultSecureRng::new_seeded([0u8; 32]);
|
||||
let mut rng = DefaultSecureRng::new();
|
||||
|
||||
let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize);
|
||||
let mut stats = Stats::new();
|
||||
|
||||
let mut check = Stats { samples: vec![] };
|
||||
for _ in 0..10 {
|
||||
let mut a = vec![0u64; ring_size];
|
||||
RandomFillUniformInModulus::random_fill(&mut rng, &q, a.as_mut());
|
||||
let mut e = vec![1u64; ring_size];
|
||||
// RandomFillGaussianInModulus::random_fill(&mut rng, &q, e.as_mut());
|
||||
|
||||
for _ in 0..100 {
|
||||
let mut m0 = vec![0u64; ring_size as usize];
|
||||
m0[thread_rng().gen_range(0..ring_size) as usize] = 1;
|
||||
let mut m1 = vec![0u64; ring_size as usize];
|
||||
m1[thread_rng().gen_range(0..ring_size) as usize] = 1;
|
||||
let gadget_vector = decomposer.gadget_vector();
|
||||
|
||||
let mut rgsw_ct0 = {
|
||||
let seeded_rgsw_ct =
|
||||
_sk_encrypt_rgsw(&m0, s.values(), &decomposer, &mod_op, &ntt_op);
|
||||
RgswCiphertextEvaluationDomain::<Vec<Vec<u64>>, _,DefaultSecureRng, NttBackendU64>::from(&seeded_rgsw_ct)
|
||||
};
|
||||
let rgsw_ct1 = {
|
||||
let seeded_rgsw_ct =
|
||||
_sk_encrypt_rgsw(&m1, s.values(), &decomposer, &mod_op, &ntt_op);
|
||||
RgswCiphertextEvaluationDomain::<Vec<Vec<u64>>,_, DefaultSecureRng, NttBackendU64>::from(&seeded_rgsw_ct)
|
||||
};
|
||||
// ksk (beta e)
|
||||
let mut ksk = vec![vec![0u64; ring_size]; decomposer.decomposition_count()];
|
||||
izip!(ksk.iter_rows_mut(), gadget_vector.iter()).for_each(|(row, beta)| {
|
||||
row.as_mut_slice().copy_from_slice(e.as_ref());
|
||||
mod_op.elwise_scalar_mul_mut(row.as_mut_slice(), beta);
|
||||
});
|
||||
|
||||
// RGSW x RGSW
|
||||
// send RGSW(m0) to coefficient domain
|
||||
rgsw_ct0
|
||||
.data
|
||||
.iter_mut()
|
||||
.for_each(|r| ntt_op.backward(r.as_mut_slice()));
|
||||
let mut scratch_matrix = vec![
|
||||
vec![0u64; ring_size as usize];
|
||||
std::cmp::max(
|
||||
decomposer.a().decomposition_count(),
|
||||
decomposer.b().decomposition_count()
|
||||
) + decomposer.a().decomposition_count() * 2
|
||||
+ decomposer.b().decomposition_count() * 2
|
||||
];
|
||||
rgsw_by_rgsw_inplace(
|
||||
&mut rgsw_ct0.data,
|
||||
&rgsw_ct1.data,
|
||||
&decomposer,
|
||||
&mut scratch_matrix,
|
||||
&ntt_op,
|
||||
&mod_op,
|
||||
);
|
||||
let mut rgsw_m0m1 = rgsw_ct0;
|
||||
// Back to Evaluation for RLWExRGSW
|
||||
rgsw_m0m1
|
||||
.data
|
||||
// decompose a
|
||||
let mut decomposed_a = vec![vec![0u64; ring_size]; decomposer.decomposition_count()];
|
||||
a.iter().enumerate().for_each(|(ri, el)| {
|
||||
decomposer
|
||||
.decompose_iter(el)
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.for_each(|(j, d_el)| {
|
||||
decomposed_a[j][ri] = d_el;
|
||||
});
|
||||
});
|
||||
|
||||
// println!("Last limb");
|
||||
|
||||
// decomp_a * ksk(beta e)
|
||||
ksk.iter_mut()
|
||||
.for_each(|r| ntt_op.forward(r.as_mut_slice()));
|
||||
decomposed_a
|
||||
.iter_mut()
|
||||
.for_each(|r| ntt_op.forward(r.as_mut_slice()));
|
||||
let mut out = vec![0u64; ring_size];
|
||||
izip!(decomposed_a.iter(), ksk.iter()).for_each(|(a, b)| {
|
||||
// out += a * b
|
||||
let mut a_clone = a.clone();
|
||||
mod_op.elwise_mul_mut(a_clone.as_mut_slice(), b.as_ref());
|
||||
mod_op.elwise_add_mut(out.as_mut_slice(), a_clone.as_ref());
|
||||
});
|
||||
ntt_op.backward(out.as_mut_slice());
|
||||
|
||||
// Sample m2, encrypt it as RLWE(m2) and multiply RLWE(m2)xRGSW(m0m1)
|
||||
let mut m2 = vec![0u64; ring_size as usize];
|
||||
RandomFillUniformInModulus::random_fill(&mut rng, &q, m2.as_mut_slice());
|
||||
let mut rlwe_in_ct = { _sk_encrypt_rlwe(&m2, s.values(), &ntt_op, &mod_op) };
|
||||
let mut scratch_space = vec![
|
||||
vec![0u64; ring_size as usize];
|
||||
std::cmp::max(
|
||||
decomposer.a().decomposition_count(),
|
||||
decomposer.b().decomposition_count()
|
||||
) + 2
|
||||
];
|
||||
rlwe_by_rgsw(
|
||||
&mut rlwe_in_ct,
|
||||
&rgsw_m0m1.data,
|
||||
&mut scratch_space,
|
||||
&decomposer,
|
||||
&ntt_op,
|
||||
&mod_op,
|
||||
);
|
||||
let out_expected = {
|
||||
let mut a_clone = a.clone();
|
||||
let mut e_clone = e.clone();
|
||||
|
||||
// Decrypt RLWE(m0m1m2)
|
||||
let mut m0m1m2_back = vec![0u64; ring_size as usize];
|
||||
decrypt_rlwe(&rlwe_in_ct, s.values(), &mut m0m1m2_back, &ntt_op, &mod_op);
|
||||
ntt_op.forward(a_clone.as_mut_slice());
|
||||
ntt_op.forward(e_clone.as_mut_slice());
|
||||
|
||||
// Calculate m0m1m2
|
||||
let mul_mod = |v0: &u64, v1: &u64| ((*v0 as u128 * *v1 as u128) % q as u128) as u64;
|
||||
let m0m1 = negacyclic_mul(&m0, &m1, mul_mod, q);
|
||||
let m0m1m2 = negacyclic_mul(&m2, &m0m1, mul_mod, q);
|
||||
mod_op.elwise_mul_mut(a_clone.as_mut_slice(), e_clone.as_mut_slice());
|
||||
ntt_op.backward(a_clone.as_mut_slice());
|
||||
a_clone
|
||||
};
|
||||
|
||||
// diff
|
||||
mod_op.elwise_sub_mut(m0m1m2_back.as_mut_slice(), m0m1m2.as_ref());
|
||||
|
||||
check.add_more(&Vec::<i64>::try_convert_from(&m0m1m2_back, &q));
|
||||
let mut diff = out_expected;
|
||||
mod_op.elwise_sub_mut(diff.as_mut_slice(), out.as_ref());
|
||||
stats.add_more(&Vec::<i64>::try_convert_from(diff.as_ref(), &q));
|
||||
}
|
||||
|
||||
println!("Std: {}", check.std_dev().abs().log2());
|
||||
println!("Std: {}", stats.std_dev().abs().log2());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -243,6 +243,7 @@ pub(crate) mod tests {
|
||||
|
||||
use super::fill_random_ternary_secret_with_hamming_weight;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct Stats<T> {
|
||||
pub(crate) samples: Vec<T>,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user