diff --git a/src/backend.rs b/src/backend.rs index 2b3f92c..2b21f1b 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -7,6 +7,7 @@ pub trait VectorOps { fn elwise_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &[Self::Element]); fn elwise_add_mut(&self, a: &mut [Self::Element], b: &[Self::Element]); + fn elwise_sub_mut(&self, a: &mut [Self::Element], b: &[Self::Element]); fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]); fn elwise_neg_mut(&self, a: &mut [Self::Element]); /// inplace mutates `a`: a = a + b*c @@ -21,6 +22,7 @@ pub trait ArithmeticOps { fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; + fn neg(&self, a: &Self::Element) -> Self::Element; fn modulus(&self) -> Self::Element; } @@ -115,6 +117,10 @@ impl ArithmeticOps for ModularOpsU64 { self.sub_mod_fast(*a, *b) } + fn neg(&self, a: &Self::Element) -> Self::Element { + self.q - *a + } + fn modulus(&self) -> Self::Element { self.q } @@ -129,6 +135,12 @@ impl VectorOps for ModularOpsU64 { }); } + fn elwise_sub_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| { + *ai = self.sub_mod_fast(*ai, *bi); + }); + } + fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) { izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| { *ai = self.mul_mod_fast(*ai, *bi); diff --git a/src/decomposer.rs b/src/decomposer.rs index fc291ba..f148705 100644 --- a/src/decomposer.rs +++ b/src/decomposer.rs @@ -45,7 +45,7 @@ impl NumInfo for u128 { impl DefaultDecomposer { pub fn new(q: T, logb: usize, d: usize) -> DefaultDecomposer { - // if q is power of 2, then BITS - leading zeros outputs logq + 1. + // if q is power of 2, then `BITS - leading_zeros` outputs logq + 1. let logq = if q & (q - T::one()) == T::zero() { (T::BITS - q.leading_zeros() - 1) as usize } else { @@ -71,7 +71,6 @@ impl DefaultDecomposer { Op: ArithmeticOps, { let mut value = T::zero(); - dbg!(self.ignore_limbs); for i in 0..self.d { value = modq_op.add( &value, @@ -88,10 +87,15 @@ impl DefaultDecomposer { impl Decomposer for DefaultDecomposer { type Element = T; fn decompose(&self, value: &T) -> Vec { - let value = round_value(*value, self.ignore_bits); + let mut value = round_value(*value, self.ignore_bits); + let q = self.q; + if value >= (q >> 1) { + value = value.wrapping_sub(&q); + } + let logb = self.logb; - // let b = T::one() << logb; // base + let b = T::one() << logb; // base let b_by2 = T::one() << (logb - 1); // let neg_b_by2_modq = q - b_by2; let full_mask = (T::one() << logb) - T::one(); @@ -100,15 +104,22 @@ impl Decomposer for DefaultDecomposer { let mut out = Vec::::with_capacity(self.d); for i in 0..self.d { let mut limb = ((value >> (logb * i)) & full_mask) + carry; - - carry = limb & b_by2; - limb = (q + limb) - (carry << 1); - if limb > q { - limb = limb - q; + carry = T::zero(); + if limb > b_by2 { + limb = (q + limb) - b; + carry = T::one(); } + + // carry = ((q + g - limb) % q) >> logb; + + // carry = limb & b_by2; + // limb = (q + limb) - (carry << 1); + // if limb > q { + // limb = limb - q; + // } out.push(limb); - carry = carry >> (logb - 1); + // carry = carry >> (logb - 1); } return out; @@ -154,13 +165,13 @@ mod tests { }; let decomposer = DefaultDecomposer::new(q, logb, d); let modq_op = ModularOpsU64::new(q); - for _ in 0..1 { + for _ in 0..100 { let value = rng.gen_range(0..q); let limbs = decomposer.decompose(&value); let value_back = decomposer.recompose(&limbs, &modq_op); let rounded_value = round_value(value, decomposer.ignore_bits) << decomposer.ignore_bits; - dbg!(value, rounded_value, value_back, &limbs); + dbg!(rounded_value, value, value_back); assert_eq!( rounded_value, value_back, "Expected {rounded_value} got {value_back} for q={q}" diff --git a/src/rgsw.rs b/src/rgsw.rs index 72a30cf..84f473d 100644 --- a/src/rgsw.rs +++ b/src/rgsw.rs @@ -1,9 +1,15 @@ -use itertools::izip; +use std::{ + fmt::Debug, + ops::{Neg, Sub}, +}; + +use itertools::{izip, Itertools}; +use num_traits::{PrimInt, ToPrimitive}; use crate::{ - backend::VectorOps, + backend::{ArithmeticOps, VectorOps}, decomposer::{self, Decomposer}, - ntt::Ntt, + ntt::{self, Ntt}, random::{DefaultSecureRng, RandomGaussianDist, RandomUniformDist}, utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom, WithLocal}, Matrix, MatrixEntity, MatrixMut, RowMut, Secret, @@ -31,6 +37,219 @@ impl RlweSecret { } } +fn generate_auto_map(ring_size: usize, k: usize) -> (Vec, Vec) { + assert!(k & 1 == 1, "Auto {k} must be odd"); + let (auto_map_index, auto_sign_index): (Vec, Vec) = (0..ring_size) + .into_iter() + .map(|i| { + let mut to_index = (i * k) % (2 * ring_size); + let mut sign = true; + + // wrap around. false implies negative + if to_index >= ring_size { + to_index = to_index - ring_size; + sign = false; + } + + (to_index, sign) + }) + .unzip(); + (auto_map_index, auto_sign_index) +} + +/// Generates RLWE Key switching key to key switch ciphertext RLWE_{from_s}(m) +/// to RLWE_{to_s}(m). +/// +/// Key switching equals +/// \sum decompose(c_1)_i * RLWE_{to_s}(\beta^i -from_s) +/// Hence, key switchin key equals RLWE'(-from_s) = RLWE(-from_s), RLWE(beta^1 +/// -from_s), ..., RLWE(beta^{d-1} -from_s). +/// +/// - ksk_out: Output Key switching key. Key switching key stores RLWE +/// ciphertexts as [RLWE'_A(-from_s) || RLWE'_B(-from_s)] +/// - neg_from_s_eval: Negative of secret polynomial to key switch from in +/// evaluation domain +/// - to_s_eval: secret polynomial to key switch to in evalution domain. +fn rlwe_ksk_gen< + Mmut: MatrixMut + MatrixEntity, + ModOp: ArithmeticOps + VectorOps, + NttOp: Ntt, + R: RandomGaussianDist<[Mmut::MatElement], Parameters = Mmut::MatElement> + + RandomUniformDist<[Mmut::MatElement], Parameters = Mmut::MatElement>, +>( + ksk_out: &mut Mmut, + neg_from_s_eval: &Mmut, + to_s_eval: &Mmut, + gadget_vector: &[Mmut::MatElement], + mod_op: &ModOp, + ntt_op: &NttOp, + rng: &mut R, +) where + ::R: RowMut, +{ + let ring_size = neg_from_s_eval.dimension().1; + let d = gadget_vector.len(); + assert!(neg_from_s_eval.dimension().0 == 1); + assert!(ksk_out.dimension() == (d * 2, ring_size)); + assert!(to_s_eval.dimension() == (1, ring_size)); + + let q = ArithmeticOps::modulus(mod_op); + + let mut scratch_space = Mmut::zeros(1, ring_size); + + // RLWE'_{to_s}(-from_s) + let (part_a, part_b) = ksk_out.split_at_row(d); + izip!(part_a.iter_mut(), part_b.iter_mut(), gadget_vector.iter()).for_each( + |(ai, bi, beta_i)| { + // sample ai and transform to evaluation + RandomUniformDist::random_fill(rng, &q, ai.as_mut()); + ntt_op.forward(ai.as_mut()); + + // to_s * ai + mod_op.elwise_mul( + scratch_space.get_row_mut(0), + ai.as_ref(), + to_s_eval.get_row_slice(0), + ); + + // ei + to_s*ai + RandomGaussianDist::random_fill(rng, &q, bi.as_mut()); + ntt_op.forward(bi.as_mut()); + mod_op.elwise_add_mut(bi.as_mut(), scratch_space.get_row_slice(0)); + + // beta_i * -from_s + mod_op.elwise_scalar_mul( + scratch_space.get_row_mut(0), + neg_from_s_eval.get_row_slice(0), + beta_i, + ); + + // bi = ei + to_s*ai + beta_i*-from_s + mod_op.elwise_add_mut(bi.as_mut(), scratch_space.get_row_slice(0)); + }, + ); +} + +fn galois_key_gen< + Mmut: MatrixMut + MatrixEntity, + ModOp: ArithmeticOps + VectorOps, + NttOp: Ntt, + S: Secret, + R: RandomGaussianDist<[Mmut::MatElement], Parameters = Mmut::MatElement> + + RandomUniformDist<[Mmut::MatElement], Parameters = Mmut::MatElement>, +>( + ksk_out: &mut Mmut, + s: &S, + auto_k: usize, + gadget_vector: &[Mmut::MatElement], + mod_op: &ModOp, + ntt_op: &NttOp, + rng: &mut R, +) where + ::R: RowMut, + Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>, + Mmut::MatElement: Copy + Sub, +{ + let ring_size = s.values().len(); + let (auto_map_index, auto_map_sign) = generate_auto_map(ring_size, auto_k); + + let q = ArithmeticOps::modulus(mod_op); + + // s(X) -> -s(X^k) + let mut s = Mmut::try_convert_from(s.values(), &q); + let mut neg_s_auto = Mmut::zeros(1, s.dimension().1); + izip!(s.get_row(0), auto_map_index.iter(), auto_map_sign.iter()).for_each( + |(el, to_index, sign)| { + // if sign is +ve (true), then negate because we need -s(X) (i.e. do the + // opposite than the usual case) + if *sign { + neg_s_auto.set(0, *to_index, q - *el) + } else { + neg_s_auto.set(0, *to_index, *el) + } + }, + ); + + // send both s(X) and -s(X^k) to evaluation domain + ntt_op.forward(s.get_row_mut(0)); + ntt_op.forward(neg_s_auto.get_row_mut(0)); + + // Ksk from -s(X^k) to s(X) + rlwe_ksk_gen(ksk_out, &neg_s_auto, &s, gadget_vector, mod_op, ntt_op, rng); +} + +/// Sends RLWE_{s}(X) -> RLWE_{s}(X^k) where k is some galois element +fn galois_auto< + M: Matrix, + Mmut: MatrixMut, + ModOp: ArithmeticOps + VectorOps, + NttOp: Ntt, + D: Decomposer, +>( + rlwe_in: &M, + ksk: &M, + rlwe_out: &mut Mmut, + a_rlwe_decomposed: &mut Mmut, + auto_map_index: &[usize], + auto_map_sign: &[bool], + mod_op: &ModOp, + ntt_op: &NttOp, + decomposer: &D, +) where + ::R: RowMut, + M::MatElement: Copy, +{ + let d = decomposer.d(); + + // send b(X) -> b(X^k) + izip!( + rlwe_in.get_row(1), + auto_map_index.iter(), + auto_map_sign.iter() + ) + .for_each(|(el_in, to_index, sign)| { + if !*sign { + rlwe_out.set(1, *to_index, mod_op.neg(el_in)); + } else { + rlwe_out.set(1, *to_index, *el_in); + } + }); + + // send a(X) -> a(X^k) and decompose a(X^k) + izip!( + rlwe_in.get_row(0), + auto_map_index.iter(), + auto_map_sign.iter() + ) + .for_each(|(el_in, to_index, sign)| { + let el_out = if !*sign { mod_op.neg(el_in) } else { *el_in }; + + let el_out_decomposed = decomposer.decompose(&el_out); + for j in 0..d { + a_rlwe_decomposed.set(j, *to_index, el_out_decomposed[j]); + } + }); + + // transform decomposed a(X^k) to evaluation domain + a_rlwe_decomposed.iter_rows_mut().for_each(|r| { + ntt_op.forward(r.as_mut()); + }); + + // key switch (a(X^k) * RLWE'(s(X^k))) + izip!(a_rlwe_decomposed.iter_rows(), ksk.iter_rows().take(d)).for_each(|(a, b)| { + mod_op.elwise_fma_mut(rlwe_out.get_row_mut(0), a.as_ref(), b.as_ref()); + }); + ntt_op.forward(rlwe_out.get_row_mut(1)); + izip!(a_rlwe_decomposed.iter_rows(), ksk.iter_rows().skip(d)).for_each(|(a, b)| { + mod_op.elwise_fma_mut(rlwe_out.get_row_mut(1), a.as_ref(), b.as_ref()); + }); + + // transform RLWE(-s(X^k) * a(X^k)) to coefficient domain + rlwe_out + .iter_rows_mut() + .for_each(|r| ntt_op.backward(r.as_mut())); +} + /// Encrypts message m as a RGSW ciphertext. /// /// - m_eval: is `m` is evaluation domain @@ -366,11 +585,71 @@ fn decrypt_rlwe< mod_op.elwise_add_mut(m_out.get_row_mut(0), rlwe_ct.get_row_slice(1)); } +// Measures noise in degree 1 RLWE ciphertext against encoded ideal message +// encoded_m +fn measure_noise< + Mmut: MatrixMut + Matrix + MatrixEntity, + ModOp: VectorOps, + NttOp: Ntt, + S: Secret, +>( + rlwe_ct: &Mmut, + encoded_m_ideal: &Mmut, + ntt_op: &NttOp, + mod_op: &ModOp, + s: &S, +) -> f64 +where + ::R: RowMut, + Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>, + Mmut::MatElement: PrimInt + ToPrimitive + Debug, +{ + let ring_size = s.values().len(); + assert!(rlwe_ct.dimension() == (2, ring_size)); + assert!(encoded_m_ideal.dimension() == (1, ring_size)); + + // -(s * a) + let q = VectorOps::modulus(mod_op); + let mut s = Mmut::try_convert_from(s.values(), &q); + ntt_op.forward(s.get_row_mut(0)); + let mut a = Mmut::zeros(1, ring_size); + a.get_row_mut(0).copy_from_slice(rlwe_ct.get_row_slice(0)); + ntt_op.forward(a.get_row_mut(0)); + mod_op.elwise_mul_mut(s.get_row_mut(0), a.get_row_slice(0)); + mod_op.elwise_neg_mut(s.get_row_mut(0)); + ntt_op.backward(s.get_row_mut(0)); + + // m+e = b - s*a + let mut m_plus_e = s; + mod_op.elwise_add_mut(m_plus_e.get_row_mut(0), rlwe_ct.get_row_slice(1)); + + // difference + mod_op.elwise_sub_mut(m_plus_e.get_row_mut(0), encoded_m_ideal.get_row_slice(0)); + + let mut max_diff_bits = f64::MIN; + m_plus_e.get_row_slice(0).iter().for_each(|v| { + let mut v = *v; + + if v >= (q >> 1) { + // v is -ve + v = q - v; + } + + let bits = (v.to_f64().unwrap()).log2(); + + if max_diff_bits < bits { + max_diff_bits = bits; + } + }); + + return max_diff_bits; +} + #[cfg(test)] mod tests { use std::vec; - use itertools::Itertools; + use itertools::{izip, Itertools}; use rand::{thread_rng, Rng}; use crate::{ @@ -378,10 +657,14 @@ mod tests { decomposer::{gadget_vector, DefaultDecomposer}, ntt::{self, Ntt, NttBackendU64}, random::{DefaultSecureRng, RandomUniformDist}, + rgsw::measure_noise, utils::{generate_prime, negacyclic_mul}, }; - use super::{decrypt_rlwe, encrypt_rgsw, encrypt_rlwe, rlwe_by_rgsw, RlweSecret}; + use super::{ + decrypt_rlwe, encrypt_rgsw, encrypt_rlwe, galois_auto, galois_key_gen, generate_auto_map, + rlwe_by_rgsw, RlweSecret, + }; #[test] fn rlwe_by_rgsw_works() { @@ -463,4 +746,106 @@ mod tests { assert_eq!(m0m1, m0m1_back, "Expected {:?} got {:?}", m0m1, m0m1_back); // dbg!(&m0m1_back, m0m1, q); } + + #[test] + fn galois_auto_works() { + let logq = 50; + let ring_size = 1 << 5; + let q = generate_prime(logq, 2 * ring_size, 1u64 << logq).unwrap(); + let logp = 3; + let p = 1u64 << logp; + let d_rgsw = 10; + let logb = 5; + + let mut rng = DefaultSecureRng::new(); + let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize); + + let mut m = vec![0u64; ring_size as usize]; + RandomUniformDist::random_fill(&mut rng, &p, m.as_mut_slice()); + let encoded_m = m + .iter() + .map(|v| (((*v as f64 * q as f64) / (p as f64)).round() as u64)) + .collect_vec(); + + let ntt_op = NttBackendU64::new(q, ring_size as usize); + let mod_op = ModularOpsU64::new(q); + + // RLWE_{s}(m) + let mut rlwe_m = vec![vec![0u64; ring_size as usize]; 2]; + encrypt_rlwe( + &vec![encoded_m.clone()], + &mut rlwe_m, + &s, + &mod_op, + &ntt_op, + &mut rng, + ); + + let auto_k = 25; + + // Generate galois key to key switch from s^k to s + let mut ksk_out = vec![vec![0u64; ring_size as usize]; d_rgsw * 2]; + let gadget_vector = gadget_vector(logq, logb, d_rgsw); + galois_key_gen( + &mut ksk_out, + &s, + auto_k, + &gadget_vector, + &mod_op, + &ntt_op, + &mut rng, + ); + + // Send RLWE_{s}(m) -> RLWE_{s}(m^k) + let mut rlwe_m_k = vec![vec![0u64; ring_size as usize]; 2]; + let mut scratch_space = vec![vec![0u64; ring_size as usize]; d_rgsw]; + let (auto_map_index, auto_map_sign) = generate_auto_map(ring_size as usize, auto_k); + let decomposer = DefaultDecomposer::new(q, logb, d_rgsw); + galois_auto( + &rlwe_m, + &ksk_out, + &mut rlwe_m_k, + &mut scratch_space, + &auto_map_index, + &auto_map_sign, + &mod_op, + &ntt_op, + &decomposer, + ); + + // Decrypt RLWE_{s}(m^k) and check + let mut encoded_m_k_back = vec![vec![0u64; ring_size as usize]]; + decrypt_rlwe(&rlwe_m_k, &s, &mut encoded_m_k_back, &ntt_op, &mod_op); + let m_k_back = encoded_m_k_back[0] + .iter() + .map(|v| (((*v as f64 * p as f64) / q as f64).round() as u64) % p) + .collect_vec(); + + let mut m_k = vec![0u64; ring_size as usize]; + // Send \delta m -> \delta m^k + izip!(m.iter(), auto_map_index.iter(), auto_map_sign.iter()).for_each( + |(v, to_index, sign)| { + if !*sign { + m_k[*to_index] = (p - *v) % p; + } else { + m_k[*to_index] = *v; + } + }, + ); + + { + let encoded_m_k = m_k + .iter() + .map(|v| ((*v as f64 * q as f64) / p as f64).round() as u64) + .collect_vec(); + + let noise = measure_noise(&rlwe_m_k, &vec![encoded_m_k], &ntt_op, &mod_op, &s); + println!("Ksk noise: {noise}"); + } + + // FIXME(Jay): Galios autormophism will incur high error unless we fix in + // accurate decomoposition of Decomposer when q is prime + assert_eq!(m_k_back, m_k); + // dbg!(m_k_back, m_k, q); + } }