Browse Source

galois auto works

par-agg-key-shares
Janmajaya Mall 1 year ago
parent
commit
60c09a2e18
3 changed files with 425 additions and 17 deletions
  1. +12
    -0
      src/backend.rs
  2. +23
    -12
      src/decomposer.rs
  3. +390
    -5
      src/rgsw.rs

+ 12
- 0
src/backend.rs

@ -7,6 +7,7 @@ pub trait VectorOps {
fn elwise_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &[Self::Element]); fn elwise_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &[Self::Element]);
fn elwise_add_mut(&self, a: &mut [Self::Element], b: &[Self::Element]); fn elwise_add_mut(&self, a: &mut [Self::Element], b: &[Self::Element]);
fn elwise_sub_mut(&self, a: &mut [Self::Element], b: &[Self::Element]);
fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]); fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]);
fn elwise_neg_mut(&self, a: &mut [Self::Element]); fn elwise_neg_mut(&self, a: &mut [Self::Element]);
/// inplace mutates `a`: a = a + b*c /// inplace mutates `a`: a = a + b*c
@ -21,6 +22,7 @@ pub trait ArithmeticOps {
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
fn neg(&self, a: &Self::Element) -> Self::Element;
fn modulus(&self) -> Self::Element; fn modulus(&self) -> Self::Element;
} }
@ -115,6 +117,10 @@ impl ArithmeticOps for ModularOpsU64 {
self.sub_mod_fast(*a, *b) self.sub_mod_fast(*a, *b)
} }
fn neg(&self, a: &Self::Element) -> Self::Element {
self.q - *a
}
fn modulus(&self) -> Self::Element { fn modulus(&self) -> Self::Element {
self.q self.q
} }
@ -129,6 +135,12 @@ impl VectorOps for ModularOpsU64 {
}); });
} }
fn elwise_sub_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) {
izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| {
*ai = self.sub_mod_fast(*ai, *bi);
});
}
fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) { fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) {
izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| { izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| {
*ai = self.mul_mod_fast(*ai, *bi); *ai = self.mul_mod_fast(*ai, *bi);

+ 23
- 12
src/decomposer.rs

@ -45,7 +45,7 @@ impl NumInfo for u128 {
impl<T: PrimInt + NumInfo + Debug> DefaultDecomposer<T> { impl<T: PrimInt + NumInfo + Debug> DefaultDecomposer<T> {
pub fn new(q: T, logb: usize, d: usize) -> DefaultDecomposer<T> { pub fn new(q: T, logb: usize, d: usize) -> DefaultDecomposer<T> {
// if q is power of 2, then BITS - leading zeros outputs logq + 1.
// if q is power of 2, then `BITS - leading_zeros` outputs logq + 1.
let logq = if q & (q - T::one()) == T::zero() { let logq = if q & (q - T::one()) == T::zero() {
(T::BITS - q.leading_zeros() - 1) as usize (T::BITS - q.leading_zeros() - 1) as usize
} else { } else {
@ -71,7 +71,6 @@ impl DefaultDecomposer {
Op: ArithmeticOps<Element = T>, Op: ArithmeticOps<Element = T>,
{ {
let mut value = T::zero(); let mut value = T::zero();
dbg!(self.ignore_limbs);
for i in 0..self.d { for i in 0..self.d {
value = modq_op.add( value = modq_op.add(
&value, &value,
@ -88,10 +87,15 @@ impl DefaultDecomposer {
impl<T: PrimInt + WrappingSub + Debug> Decomposer for DefaultDecomposer<T> { impl<T: PrimInt + WrappingSub + Debug> Decomposer for DefaultDecomposer<T> {
type Element = T; type Element = T;
fn decompose(&self, value: &T) -> Vec<T> { fn decompose(&self, value: &T) -> Vec<T> {
let value = round_value(*value, self.ignore_bits);
let mut value = round_value(*value, self.ignore_bits);
let q = self.q; let q = self.q;
if value >= (q >> 1) {
value = value.wrapping_sub(&q);
}
let logb = self.logb; let logb = self.logb;
// let b = T::one() << logb; // base
let b = T::one() << logb; // base
let b_by2 = T::one() << (logb - 1); let b_by2 = T::one() << (logb - 1);
// let neg_b_by2_modq = q - b_by2; // let neg_b_by2_modq = q - b_by2;
let full_mask = (T::one() << logb) - T::one(); let full_mask = (T::one() << logb) - T::one();
@ -100,15 +104,22 @@ impl Decomposer for DefaultDecomposer {
let mut out = Vec::<T>::with_capacity(self.d); let mut out = Vec::<T>::with_capacity(self.d);
for i in 0..self.d { for i in 0..self.d {
let mut limb = ((value >> (logb * i)) & full_mask) + carry; let mut limb = ((value >> (logb * i)) & full_mask) + carry;
carry = limb & b_by2;
limb = (q + limb) - (carry << 1);
if limb > q {
limb = limb - q;
carry = T::zero();
if limb > b_by2 {
limb = (q + limb) - b;
carry = T::one();
} }
// carry = ((q + g - limb) % q) >> logb;
// carry = limb & b_by2;
// limb = (q + limb) - (carry << 1);
// if limb > q {
// limb = limb - q;
// }
out.push(limb); out.push(limb);
carry = carry >> (logb - 1);
// carry = carry >> (logb - 1);
} }
return out; return out;
@ -154,13 +165,13 @@ mod tests {
}; };
let decomposer = DefaultDecomposer::new(q, logb, d); let decomposer = DefaultDecomposer::new(q, logb, d);
let modq_op = ModularOpsU64::new(q); let modq_op = ModularOpsU64::new(q);
for _ in 0..1 {
for _ in 0..100 {
let value = rng.gen_range(0..q); let value = rng.gen_range(0..q);
let limbs = decomposer.decompose(&value); let limbs = decomposer.decompose(&value);
let value_back = decomposer.recompose(&limbs, &modq_op); let value_back = decomposer.recompose(&limbs, &modq_op);
let rounded_value = let rounded_value =
round_value(value, decomposer.ignore_bits) << decomposer.ignore_bits; round_value(value, decomposer.ignore_bits) << decomposer.ignore_bits;
dbg!(value, rounded_value, value_back, &limbs);
dbg!(rounded_value, value, value_back);
assert_eq!( assert_eq!(
rounded_value, value_back, rounded_value, value_back,
"Expected {rounded_value} got {value_back} for q={q}" "Expected {rounded_value} got {value_back} for q={q}"

+ 390
- 5
src/rgsw.rs

@ -1,9 +1,15 @@
use itertools::izip;
use std::{
fmt::Debug,
ops::{Neg, Sub},
};
use itertools::{izip, Itertools};
use num_traits::{PrimInt, ToPrimitive};
use crate::{ use crate::{
backend::VectorOps,
backend::{ArithmeticOps, VectorOps},
decomposer::{self, Decomposer}, decomposer::{self, Decomposer},
ntt::Ntt,
ntt::{self, Ntt},
random::{DefaultSecureRng, RandomGaussianDist, RandomUniformDist}, random::{DefaultSecureRng, RandomGaussianDist, RandomUniformDist},
utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom, WithLocal}, utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom, WithLocal},
Matrix, MatrixEntity, MatrixMut, RowMut, Secret, Matrix, MatrixEntity, MatrixMut, RowMut, Secret,
@ -31,6 +37,219 @@ impl RlweSecret {
} }
} }
fn generate_auto_map(ring_size: usize, k: usize) -> (Vec<usize>, Vec<bool>) {
assert!(k & 1 == 1, "Auto {k} must be odd");
let (auto_map_index, auto_sign_index): (Vec<usize>, Vec<bool>) = (0..ring_size)
.into_iter()
.map(|i| {
let mut to_index = (i * k) % (2 * ring_size);
let mut sign = true;
// wrap around. false implies negative
if to_index >= ring_size {
to_index = to_index - ring_size;
sign = false;
}
(to_index, sign)
})
.unzip();
(auto_map_index, auto_sign_index)
}
/// Generates RLWE Key switching key to key switch ciphertext RLWE_{from_s}(m)
/// to RLWE_{to_s}(m).
///
/// Key switching equals
/// \sum decompose(c_1)_i * RLWE_{to_s}(\beta^i -from_s)
/// Hence, key switchin key equals RLWE'(-from_s) = RLWE(-from_s), RLWE(beta^1
/// -from_s), ..., RLWE(beta^{d-1} -from_s).
///
/// - ksk_out: Output Key switching key. Key switching key stores RLWE
/// ciphertexts as [RLWE'_A(-from_s) || RLWE'_B(-from_s)]
/// - neg_from_s_eval: Negative of secret polynomial to key switch from in
/// evaluation domain
/// - to_s_eval: secret polynomial to key switch to in evalution domain.
fn rlwe_ksk_gen<
Mmut: MatrixMut + MatrixEntity,
ModOp: ArithmeticOps<Element = Mmut::MatElement> + VectorOps<Element = Mmut::MatElement>,
NttOp: Ntt<Element = Mmut::MatElement>,
R: RandomGaussianDist<[Mmut::MatElement], Parameters = Mmut::MatElement>
+ RandomUniformDist<[Mmut::MatElement], Parameters = Mmut::MatElement>,
>(
ksk_out: &mut Mmut,
neg_from_s_eval: &Mmut,
to_s_eval: &Mmut,
gadget_vector: &[Mmut::MatElement],
mod_op: &ModOp,
ntt_op: &NttOp,
rng: &mut R,
) where
<Mmut as Matrix>::R: RowMut,
{
let ring_size = neg_from_s_eval.dimension().1;
let d = gadget_vector.len();
assert!(neg_from_s_eval.dimension().0 == 1);
assert!(ksk_out.dimension() == (d * 2, ring_size));
assert!(to_s_eval.dimension() == (1, ring_size));
let q = ArithmeticOps::modulus(mod_op);
let mut scratch_space = Mmut::zeros(1, ring_size);
// RLWE'_{to_s}(-from_s)
let (part_a, part_b) = ksk_out.split_at_row(d);
izip!(part_a.iter_mut(), part_b.iter_mut(), gadget_vector.iter()).for_each(
|(ai, bi, beta_i)| {
// sample ai and transform to evaluation
RandomUniformDist::random_fill(rng, &q, ai.as_mut());
ntt_op.forward(ai.as_mut());
// to_s * ai
mod_op.elwise_mul(
scratch_space.get_row_mut(0),
ai.as_ref(),
to_s_eval.get_row_slice(0),
);
// ei + to_s*ai
RandomGaussianDist::random_fill(rng, &q, bi.as_mut());
ntt_op.forward(bi.as_mut());
mod_op.elwise_add_mut(bi.as_mut(), scratch_space.get_row_slice(0));
// beta_i * -from_s
mod_op.elwise_scalar_mul(
scratch_space.get_row_mut(0),
neg_from_s_eval.get_row_slice(0),
beta_i,
);
// bi = ei + to_s*ai + beta_i*-from_s
mod_op.elwise_add_mut(bi.as_mut(), scratch_space.get_row_slice(0));
},
);
}
fn galois_key_gen<
Mmut: MatrixMut + MatrixEntity,
ModOp: ArithmeticOps<Element = Mmut::MatElement> + VectorOps<Element = Mmut::MatElement>,
NttOp: Ntt<Element = Mmut::MatElement>,
S: Secret,
R: RandomGaussianDist<[Mmut::MatElement], Parameters = Mmut::MatElement>
+ RandomUniformDist<[Mmut::MatElement], Parameters = Mmut::MatElement>,
>(
ksk_out: &mut Mmut,
s: &S,
auto_k: usize,
gadget_vector: &[Mmut::MatElement],
mod_op: &ModOp,
ntt_op: &NttOp,
rng: &mut R,
) where
<Mmut as Matrix>::R: RowMut,
Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>,
Mmut::MatElement: Copy + Sub<Output = Mmut::MatElement>,
{
let ring_size = s.values().len();
let (auto_map_index, auto_map_sign) = generate_auto_map(ring_size, auto_k);
let q = ArithmeticOps::modulus(mod_op);
// s(X) -> -s(X^k)
let mut s = Mmut::try_convert_from(s.values(), &q);
let mut neg_s_auto = Mmut::zeros(1, s.dimension().1);
izip!(s.get_row(0), auto_map_index.iter(), auto_map_sign.iter()).for_each(
|(el, to_index, sign)| {
// if sign is +ve (true), then negate because we need -s(X) (i.e. do the
// opposite than the usual case)
if *sign {
neg_s_auto.set(0, *to_index, q - *el)
} else {
neg_s_auto.set(0, *to_index, *el)
}
},
);
// send both s(X) and -s(X^k) to evaluation domain
ntt_op.forward(s.get_row_mut(0));
ntt_op.forward(neg_s_auto.get_row_mut(0));
// Ksk from -s(X^k) to s(X)
rlwe_ksk_gen(ksk_out, &neg_s_auto, &s, gadget_vector, mod_op, ntt_op, rng);
}
/// Sends RLWE_{s}(X) -> RLWE_{s}(X^k) where k is some galois element
fn galois_auto<
M: Matrix,
Mmut: MatrixMut<MatElement = M::MatElement>,
ModOp: ArithmeticOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>,
NttOp: Ntt<Element = M::MatElement>,
D: Decomposer<Element = M::MatElement>,
>(
rlwe_in: &M,
ksk: &M,
rlwe_out: &mut Mmut,
a_rlwe_decomposed: &mut Mmut,
auto_map_index: &[usize],
auto_map_sign: &[bool],
mod_op: &ModOp,
ntt_op: &NttOp,
decomposer: &D,
) where
<Mmut as Matrix>::R: RowMut,
M::MatElement: Copy,
{
let d = decomposer.d();
// send b(X) -> b(X^k)
izip!(
rlwe_in.get_row(1),
auto_map_index.iter(),
auto_map_sign.iter()
)
.for_each(|(el_in, to_index, sign)| {
if !*sign {
rlwe_out.set(1, *to_index, mod_op.neg(el_in));
} else {
rlwe_out.set(1, *to_index, *el_in);
}
});
// send a(X) -> a(X^k) and decompose a(X^k)
izip!(
rlwe_in.get_row(0),
auto_map_index.iter(),
auto_map_sign.iter()
)
.for_each(|(el_in, to_index, sign)| {
let el_out = if !*sign { mod_op.neg(el_in) } else { *el_in };
let el_out_decomposed = decomposer.decompose(&el_out);
for j in 0..d {
a_rlwe_decomposed.set(j, *to_index, el_out_decomposed[j]);
}
});
// transform decomposed a(X^k) to evaluation domain
a_rlwe_decomposed.iter_rows_mut().for_each(|r| {
ntt_op.forward(r.as_mut());
});
// key switch (a(X^k) * RLWE'(s(X^k)))
izip!(a_rlwe_decomposed.iter_rows(), ksk.iter_rows().take(d)).for_each(|(a, b)| {
mod_op.elwise_fma_mut(rlwe_out.get_row_mut(0), a.as_ref(), b.as_ref());
});
ntt_op.forward(rlwe_out.get_row_mut(1));
izip!(a_rlwe_decomposed.iter_rows(), ksk.iter_rows().skip(d)).for_each(|(a, b)| {
mod_op.elwise_fma_mut(rlwe_out.get_row_mut(1), a.as_ref(), b.as_ref());
});
// transform RLWE(-s(X^k) * a(X^k)) to coefficient domain
rlwe_out
.iter_rows_mut()
.for_each(|r| ntt_op.backward(r.as_mut()));
}
/// Encrypts message m as a RGSW ciphertext. /// Encrypts message m as a RGSW ciphertext.
/// ///
/// - m_eval: is `m` is evaluation domain /// - m_eval: is `m` is evaluation domain
@ -366,11 +585,71 @@ fn decrypt_rlwe<
mod_op.elwise_add_mut(m_out.get_row_mut(0), rlwe_ct.get_row_slice(1)); mod_op.elwise_add_mut(m_out.get_row_mut(0), rlwe_ct.get_row_slice(1));
} }
// Measures noise in degree 1 RLWE ciphertext against encoded ideal message
// encoded_m
fn measure_noise<
Mmut: MatrixMut + Matrix + MatrixEntity,
ModOp: VectorOps<Element = Mmut::MatElement>,
NttOp: Ntt<Element = Mmut::MatElement>,
S: Secret,
>(
rlwe_ct: &Mmut,
encoded_m_ideal: &Mmut,
ntt_op: &NttOp,
mod_op: &ModOp,
s: &S,
) -> f64
where
<Mmut as Matrix>::R: RowMut,
Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>,
Mmut::MatElement: PrimInt + ToPrimitive + Debug,
{
let ring_size = s.values().len();
assert!(rlwe_ct.dimension() == (2, ring_size));
assert!(encoded_m_ideal.dimension() == (1, ring_size));
// -(s * a)
let q = VectorOps::modulus(mod_op);
let mut s = Mmut::try_convert_from(s.values(), &q);
ntt_op.forward(s.get_row_mut(0));
let mut a = Mmut::zeros(1, ring_size);
a.get_row_mut(0).copy_from_slice(rlwe_ct.get_row_slice(0));
ntt_op.forward(a.get_row_mut(0));
mod_op.elwise_mul_mut(s.get_row_mut(0), a.get_row_slice(0));
mod_op.elwise_neg_mut(s.get_row_mut(0));
ntt_op.backward(s.get_row_mut(0));
// m+e = b - s*a
let mut m_plus_e = s;
mod_op.elwise_add_mut(m_plus_e.get_row_mut(0), rlwe_ct.get_row_slice(1));
// difference
mod_op.elwise_sub_mut(m_plus_e.get_row_mut(0), encoded_m_ideal.get_row_slice(0));
let mut max_diff_bits = f64::MIN;
m_plus_e.get_row_slice(0).iter().for_each(|v| {
let mut v = *v;
if v >= (q >> 1) {
// v is -ve
v = q - v;
}
let bits = (v.to_f64().unwrap()).log2();
if max_diff_bits < bits {
max_diff_bits = bits;
}
});
return max_diff_bits;
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::vec; use std::vec;
use itertools::Itertools;
use itertools::{izip, Itertools};
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use crate::{ use crate::{
@ -378,10 +657,14 @@ mod tests {
decomposer::{gadget_vector, DefaultDecomposer}, decomposer::{gadget_vector, DefaultDecomposer},
ntt::{self, Ntt, NttBackendU64}, ntt::{self, Ntt, NttBackendU64},
random::{DefaultSecureRng, RandomUniformDist}, random::{DefaultSecureRng, RandomUniformDist},
rgsw::measure_noise,
utils::{generate_prime, negacyclic_mul}, utils::{generate_prime, negacyclic_mul},
}; };
use super::{decrypt_rlwe, encrypt_rgsw, encrypt_rlwe, rlwe_by_rgsw, RlweSecret};
use super::{
decrypt_rlwe, encrypt_rgsw, encrypt_rlwe, galois_auto, galois_key_gen, generate_auto_map,
rlwe_by_rgsw, RlweSecret,
};
#[test] #[test]
fn rlwe_by_rgsw_works() { fn rlwe_by_rgsw_works() {
@ -463,4 +746,106 @@ mod tests {
assert_eq!(m0m1, m0m1_back, "Expected {:?} got {:?}", m0m1, m0m1_back); assert_eq!(m0m1, m0m1_back, "Expected {:?} got {:?}", m0m1, m0m1_back);
// dbg!(&m0m1_back, m0m1, q); // dbg!(&m0m1_back, m0m1, q);
} }
#[test]
fn galois_auto_works() {
let logq = 50;
let ring_size = 1 << 5;
let q = generate_prime(logq, 2 * ring_size, 1u64 << logq).unwrap();
let logp = 3;
let p = 1u64 << logp;
let d_rgsw = 10;
let logb = 5;
let mut rng = DefaultSecureRng::new();
let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize);
let mut m = vec![0u64; ring_size as usize];
RandomUniformDist::random_fill(&mut rng, &p, m.as_mut_slice());
let encoded_m = m
.iter()
.map(|v| (((*v as f64 * q as f64) / (p as f64)).round() as u64))
.collect_vec();
let ntt_op = NttBackendU64::new(q, ring_size as usize);
let mod_op = ModularOpsU64::new(q);
// RLWE_{s}(m)
let mut rlwe_m = vec![vec![0u64; ring_size as usize]; 2];
encrypt_rlwe(
&vec![encoded_m.clone()],
&mut rlwe_m,
&s,
&mod_op,
&ntt_op,
&mut rng,
);
let auto_k = 25;
// Generate galois key to key switch from s^k to s
let mut ksk_out = vec![vec![0u64; ring_size as usize]; d_rgsw * 2];
let gadget_vector = gadget_vector(logq, logb, d_rgsw);
galois_key_gen(
&mut ksk_out,
&s,
auto_k,
&gadget_vector,
&mod_op,
&ntt_op,
&mut rng,
);
// Send RLWE_{s}(m) -> RLWE_{s}(m^k)
let mut rlwe_m_k = vec![vec![0u64; ring_size as usize]; 2];
let mut scratch_space = vec![vec![0u64; ring_size as usize]; d_rgsw];
let (auto_map_index, auto_map_sign) = generate_auto_map(ring_size as usize, auto_k);
let decomposer = DefaultDecomposer::new(q, logb, d_rgsw);
galois_auto(
&rlwe_m,
&ksk_out,
&mut rlwe_m_k,
&mut scratch_space,
&auto_map_index,
&auto_map_sign,
&mod_op,
&ntt_op,
&decomposer,
);
// Decrypt RLWE_{s}(m^k) and check
let mut encoded_m_k_back = vec![vec![0u64; ring_size as usize]];
decrypt_rlwe(&rlwe_m_k, &s, &mut encoded_m_k_back, &ntt_op, &mod_op);
let m_k_back = encoded_m_k_back[0]
.iter()
.map(|v| (((*v as f64 * p as f64) / q as f64).round() as u64) % p)
.collect_vec();
let mut m_k = vec![0u64; ring_size as usize];
// Send \delta m -> \delta m^k
izip!(m.iter(), auto_map_index.iter(), auto_map_sign.iter()).for_each(
|(v, to_index, sign)| {
if !*sign {
m_k[*to_index] = (p - *v) % p;
} else {
m_k[*to_index] = *v;
}
},
);
{
let encoded_m_k = m_k
.iter()
.map(|v| ((*v as f64 * q as f64) / p as f64).round() as u64)
.collect_vec();
let noise = measure_noise(&rlwe_m_k, &vec![encoded_m_k], &ntt_op, &mod_op, &s);
println!("Ksk noise: {noise}");
}
// FIXME(Jay): Galios autormophism will incur high error unless we fix in
// accurate decomoposition of Decomposer when q is prime
assert_eq!(m_k_back, m_k);
// dbg!(m_k_back, m_k, q);
}
} }

Loading…
Cancel
Save