Browse Source

rlwe x rgsw not working

par-agg-key-shares
Janmajaya Mall 11 months ago
parent
commit
56752a7559
6 changed files with 886 additions and 707 deletions
  1. +143
    -185
      src/bool.rs
  2. +2
    -0
      src/lib.rs
  3. +18
    -8
      src/lwe.rs
  4. +7
    -4
      src/ntt.rs
  5. +44
    -28
      src/random.rs
  6. +672
    -482
      src/rgsw.rs

+ 143
- 185
src/bool.rs

@ -17,16 +17,16 @@ use crate::{
ntt::{Ntt, NttBackendU64, NttInit},
random::{DefaultSecureRng, RandomGaussianDist, RandomUniformDist},
rgsw::{
decrypt_rlwe, encrypt_rgsw, galois_auto, galois_key_gen, generate_auto_map, rlwe_by_rgsw,
IsTrivial, RlweCiphertext, RlweSecret,
decrypt_rlwe, galois_auto, galois_key_gen, generate_auto_map, rlwe_by_rgsw,
secret_key_encrypt_rgsw, IsTrivial, RlweCiphertext, RlweSecret,
},
utils::{generate_prime, mod_exponent, TryConvertFrom, WithLocal},
Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret,
};
thread_local! {
pub(crate) static CLIENT_KEY: RefCell<ClientKey> = RefCell::new(ClientKey::random());
}
// thread_local! {
// pub(crate) static CLIENT_KEY: RefCell<ClientKey> =
// RefCell::new(ClientKey::random()); }
trait PbsKey {
type M: Matrix;
@ -77,25 +77,25 @@ impl ClientKey {
}
}
impl WithLocal for ClientKey {
fn with_local<F, R>(func: F) -> R
where
F: Fn(&Self) -> R,
{
CLIENT_KEY.with_borrow(|client_key| func(client_key))
}
fn with_local_mut<F, R>(func: F) -> R
where
F: Fn(&mut Self) -> R,
{
CLIENT_KEY.with_borrow_mut(|client_key| func(client_key))
}
}
fn set_client_key(key: &ClientKey) {
ClientKey::with_local_mut(|k| *k = key.clone())
}
// impl WithLocal for ClientKey {
// fn with_local<F, R>(func: F) -> R
// where
// F: Fn(&Self) -> R,
// {
// CLIENT_KEY.with_borrow(|client_key| func(client_key))
// }
// fn with_local_mut<F, R>(func: F) -> R
// where
// F: Fn(&mut Self) -> R,
// {
// CLIENT_KEY.with_borrow_mut(|client_key| func(client_key))
// }
// }
// fn set_client_key(key: &ClientKey) {
// ClientKey::with_local_mut(|k| *k = key.clone())
// }
struct ServerKey<M> {
/// Rgsw cts of LWE secret elements
@ -154,8 +154,6 @@ where
nand_test_vec: M::R,
rlweq_by8: M::MatElement,
rlwe_auto_maps: Vec<(Vec<usize>, Vec<bool>)>,
scratch_lwen_plus1: M::R,
scratch_dplus2_ring: M,
_phantom: PhantomData<M>,
}
@ -264,10 +262,6 @@ where
rlwe_auto_maps.push(generate_auto_map(ring_size, i))
}
// create srcatch spaces
let scratch_lwen_plus1 = M::R::zeros(parameters.lwe_n + 1);
let scratch_dplus2_ring = M::zeros(parameters.d_rgsw + 2, parameters.rlwe_n);
BoolEvaluator {
parameters: parameters,
decomposer_lwe,
@ -280,8 +274,7 @@ where
nand_test_vec: nand_test_vec_autog,
rlweq_by8: rlwe_qby8,
rlwe_auto_maps,
scratch_lwen_plus1,
scratch_dplus2_ring,
_phantom: PhantomData,
}
}
@ -293,101 +286,103 @@ where
}
fn server_key(&self, client_key: &ClientKey) -> ServerKey<M> {
let sk_rlwe = &client_key.sk_rlwe;
let sk_lwe = &client_key.sk_lwe;
let d_rgsw_gadget_vec = gadget_vector(
self.parameters.rlwe_logq,
self.parameters.logb_rgsw,
self.parameters.d_rgsw,
);
// generate galois key -g, g
let mut galois_keys = HashMap::new();
let g = self.parameters.g as isize;
for i in [g, -g] {
let gk = DefaultSecureRng::with_local_mut(|rng| {
let mut ksk_out = M::zeros(self.parameters.d_rgsw * 2, self.parameters.rlwe_n);
galois_key_gen(
&mut ksk_out,
sk_rlwe,
i,
&d_rgsw_gadget_vec,
&self.rlwe_modop,
&self.rlwe_nttop,
rng,
);
ksk_out
});
galois_keys.insert(i, gk);
}
// generate rgsw ciphertexts RGSW(si) where si is i^th LWE secret element
let ring_size = self.parameters.rlwe_n;
let rlwe_q = self.parameters.rlwe_q;
let rgsw_cts = sk_lwe
.values()
.iter()
.map(|si| {
// X^{si}; assume |emebedding_factor * si| < N
let mut m = M::zeros(1, ring_size);
let si = (self.embedding_factor as i32) * si;
// dbg!(si);
if si < 0 {
// X^{-i} = X^{2N - i} = -X^{N-i}
m.set(
0,
ring_size - (si.abs() as usize),
rlwe_q - M::MatElement::one(),
);
} else {
// X^{i}
m.set(0, (si.abs() as usize), M::MatElement::one());
}
self.rlwe_nttop.forward(m.get_row_mut(0));
let rgsw_si = DefaultSecureRng::with_local_mut(|rng| {
let mut rgsw_si = M::zeros(self.parameters.d_rgsw * 4, ring_size);
encrypt_rgsw(
&mut rgsw_si,
&m,
&d_rgsw_gadget_vec,
sk_rlwe,
&self.rlwe_modop,
&self.rlwe_nttop,
rng,
);
rgsw_si
});
rgsw_si
})
.collect_vec();
// let sk_rlwe = &client_key.sk_rlwe;
// let sk_lwe = &client_key.sk_lwe;
// let d_rgsw_gadget_vec = gadget_vector(
// self.parameters.rlwe_logq,
// self.parameters.logb_rgsw,
// self.parameters.d_rgsw,
// );
// // generate galois key -g, g
// let mut galois_keys = HashMap::new();
// let g = self.parameters.g as isize;
// for i in [g, -g] {
// let gk = DefaultSecureRng::with_local_mut(|rng| {
// let mut ksk_out = M::zeros(self.parameters.d_rgsw * 2,
// self.parameters.rlwe_n); galois_key_gen(
// &mut ksk_out,
// sk_rlwe,
// i,
// &d_rgsw_gadget_vec,
// &self.rlwe_modop,
// &self.rlwe_nttop,
// rng,
// );
// ksk_out
// });
// galois_keys.insert(i, gk);
// }
// // generate rgsw ciphertexts RGSW(si) where si is i^th LWE secret element
// let ring_size = self.parameters.rlwe_n;
// let rlwe_q = self.parameters.rlwe_q;
// let rgsw_cts = sk_lwe
// .values()
// .iter()
// .map(|si| {
// // X^{si}; assume |emebedding_factor * si| < N
// let mut m = M::zeros(1, ring_size);
// let si = (self.embedding_factor as i32) * si;
// // dbg!(si);
// if si < 0 {
// // X^{-i} = X^{2N - i} = -X^{N-i}
// m.set(
// 0,
// ring_size - (si.abs() as usize),
// rlwe_q - M::MatElement::one(),
// );
// } else {
// // X^{i}
// m.set(0, (si.abs() as usize), M::MatElement::one());
// }
// self.rlwe_nttop.forward(m.get_row_mut(0));
// let rgsw_si = DefaultSecureRng::with_local_mut(|rng| {
// let mut rgsw_si = M::zeros(self.parameters.d_rgsw * 4,
// ring_size); secret_key_encrypt_rgsw(
// &mut rgsw_si,
// &m,
// &d_rgsw_gadget_vec,
// sk_rlwe,
// &self.rlwe_modop,
// &self.rlwe_nttop,
// rng,
// );
// rgsw_si
// });
// rgsw_si
// })
// .collect_vec();
// // LWE KSK from RLWE secret s -> LWE secret z
// let d_lwe_gadget = gadget_vector(
// self.parameters.lwe_logq,
// self.parameters.logb_lwe,
// self.parameters.d_lwe,
// );
// let mut lwe_ksk = DefaultSecureRng::with_local_mut(|rng| {
// let mut out = M::zeros(self.parameters.d_lwe * ring_size,
// self.parameters.lwe_n + 1); lwe_ksk_keygen(
// &sk_rlwe.values(),
// &sk_lwe.values(),
// &mut out,
// &d_lwe_gadget,
// &self.lwe_modop,
// rng,
// );
// out
// });
// LWE KSK from RLWE secret s -> LWE secret z
let d_lwe_gadget = gadget_vector(
self.parameters.lwe_logq,
self.parameters.logb_lwe,
self.parameters.d_lwe,
);
let mut lwe_ksk = DefaultSecureRng::with_local_mut(|rng| {
let mut out = M::zeros(self.parameters.d_lwe * ring_size, self.parameters.lwe_n + 1);
lwe_ksk_keygen(
&sk_rlwe.values(),
&sk_lwe.values(),
&mut out,
&d_lwe_gadget,
&self.lwe_modop,
rng,
);
out
});
// ServerKey {
// rgsw_cts,
// galois_keys,
// lwe_ksk,
// }
ServerKey {
rgsw_cts,
galois_keys,
lwe_ksk,
}
todo!()
}
/// TODO(Jay): Fetch client key from thread local
@ -434,6 +429,7 @@ where
}
}
// TODO(Jay): scratch spaces must be thread local. Don't pass them as arguments
pub fn nand(
&self,
c0: &M::R,
@ -688,9 +684,7 @@ fn pbs<
nttop_rlweq: &NttOp,
pbs_key: &K,
) where
// FIXME(Jay): TryConvertFrom<[i32], Parameters = M::MatElement> are only needed for
// debugging purposes
<M as Matrix>::R: RowMut + TryConvertFrom<[i32], Parameters = M::MatElement>,
<M as Matrix>::R: RowMut,
M::MatElement: PrimInt + ToPrimitive + FromPrimitive + One + Copy + Zero + Display,
{
let rlwe_q = parameters.rlwe_q();
@ -775,7 +769,8 @@ fn pbs<
gb_monomial_sign = false
}
// monomial mul
let mut trivial_rlwe_test_poly = RlweCiphertext(M::zeros(2, rlwe_n), true);
let mut trivial_rlwe_test_poly =
RlweCiphertext::<_, DefaultSecureRng>::from_raw(M::zeros(2, rlwe_n), true);
if parameters.embedding_factor() == 1 {
monomial_mul(
test_vec.as_ref(),
@ -853,16 +848,12 @@ fn pbs<
}
fn mod_switch_odd(v: f64, from_q: f64, to_q: f64) -> usize {
// println!("v: {v}, odd_v: {odd_v}, lwe_q:{lwe_q}, br_q:{br_q}");
let odd_v = (((v * to_q) / (from_q)).floor()).to_usize().unwrap();
// println!(
// "v: {v}, odd_v: {odd_v}, returned_oddv: {},lwe_q:{from_q}, br_q:{to_q}",
// odd_v + ((odd_v & 1) ^ 1)
// );
//TODO(Jay): check correctness of this
odd_v + ((odd_v & 1) ^ 1)
}
// TODO(Jay): Add tests for sample extract
fn sample_extract<M: Matrix + MatrixMut, ModOp: ArithmeticOps<Element = M::MatElement>>(
lwe_out: &mut M::R,
rlwe_in: &M,
@ -936,58 +927,26 @@ where
}
impl PBSTracer<Vec<Vec<u64>>> {
fn trace(&self, parameters: &BoolParameters<u64>, client_key: &ClientKey, expected_m: bool) {
let lwe_q = parameters.lwe_q;
let lwe_qby8 = ((lwe_q as f64) / 8.0).round() as u64;
let expected_m_lweq = if expected_m {
lwe_qby8
} else {
lwe_q - lwe_qby8
};
let modop_lweq = ModularOpsU64::new(lwe_q);
fn trace(&self, parameters: &BoolParameters<u64>, client_key: &ClientKey) {
let modop_lweq = ModularOpsU64::new(parameters.lwe_q as u64);
// noise after mod down Q -> Q_ks
let noise0 = {
measure_noise_lwe(
&self.ct_lwe_q_mod,
client_key.sk_rlwe.values(),
&modop_lweq,
&expected_m_lweq,
)
};
let m_back0 = decrypt_lwe(&self.ct_lwe_q_mod, client_key.sk_rlwe.values(), &modop_lweq);
// noise after key switch from RLWE -> LWE
let noise1 = {
measure_noise_lwe(
&self.ct_lwe_q_mod_after_ksk,
client_key.sk_lwe.values(),
&modop_lweq,
&expected_m_lweq,
)
};
let m_back1 = decrypt_lwe(
&self.ct_lwe_q_mod_after_ksk,
client_key.sk_lwe.values(),
&modop_lweq,
);
// noise after mod down odd from Q_ks -> q
let br_q = parameters.br_q as u64;
let expected_m_brq = if expected_m {
br_q >> 3
} else {
br_q - (br_q >> 3)
};
let modop_br_q = ModularOpsU64::new(br_q);
let noise2 = {
measure_noise_lwe(
&self.ct_br_q_mod,
client_key.sk_lwe.values(),
&modop_br_q,
&expected_m_brq,
)
};
let modop_br_q = ModularOpsU64::new(parameters.br_q as u64);
let m_back2 = decrypt_lwe(&self.ct_br_q_mod, client_key.sk_lwe.values(), &modop_br_q);
println!(
"
m: {expected_m},
Noise after mod down Q -> Q_ks: {noise0},
Noise after key switch from RLWE -> LWE: {noise1},
Noise after mod dwon Q_ks -> q: {noise2}
M after mod down Q -> Q_ks: {m_back0},
M after key switch from RLWE -> LWE: {m_back1},
M after mod dwon Q_ks -> q: {m_back2}
"
);
}
@ -1025,7 +984,7 @@ mod tests {
lwe_n: 493,
d_rgsw: 3,
logb_rgsw: 8,
d_lwe: 3,
d_lwe: 2,
logb_lwe: 4,
g: 5,
w: 1,
@ -1063,10 +1022,9 @@ mod tests {
let bool_evaluator =
BoolEvaluator::<Vec<Vec<u64>>, u64, NttBackendU64, ModularOpsU64>::new(SP_BOOL_PARAMS);
// println!("{:?}", bool_evaluator.nand_test_vec);
let client_key = bool_evaluator.client_key();
set_client_key(&client_key);
let server_key = bool_evaluator.server_key(&client_key);
let mut scratch_lwen_plus1 = vec![0u64; bool_evaluator.parameters.lwe_n + 1];
@ -1079,7 +1037,7 @@ mod tests {
let mut m1 = true;
let mut ct0 = bool_evaluator.encrypt(m0, &client_key);
let mut ct1 = bool_evaluator.encrypt(m1, &client_key);
for _ in 0..4 {
for _ in 0..100 {
let ct_back = bool_evaluator.nand(
&ct0,
&ct1,
@ -1093,9 +1051,9 @@ mod tests {
// Trace and measure PBS noise
{
// Trace PBS
PBSTracer::with_local(|t| t.trace(&SP_BOOL_PARAMS, &client_key, m_out));
PBSTracer::with_local(|t| t.trace(&SP_BOOL_PARAMS, &client_key));
// Calculate nosie in ciphertext post PBS
// Calculate noise in ciphertext post PBS
let ideal = if m_out {
bool_evaluator.rlweq_by8
} else {

+ 2
- 0
src/lib.rs

@ -9,11 +9,13 @@ mod backend;
mod bool;
mod decomposer;
mod lwe;
mod multi_party;
mod ntt;
mod num;
mod random;
mod rgsw;
mod utils;
pub trait Matrix: AsRef<[Self::R]> {
type MatElement;
type R: Row<Element = Self::MatElement>;

+ 18
- 8
src/lwe.rs

@ -187,6 +187,11 @@ where
operator.sub(b, &sa)
}
/// Measures noise in input LWE ciphertext with reference of `ideal_m`
///
/// - ct: Input LWE ciphertext
/// - s: corresponding secret
/// - ideal_m: Ideal `encoded` message
pub(crate) fn measure_noise_lwe<Ro: Row, Op: ArithmeticOps<Element = Ro::Element>, S>(
ct: &Ro,
s: &[S],
@ -206,7 +211,6 @@ where
});
let m = operator.sub(&ct.as_ref()[0], &sa);
println!("measire: {m} {ideal_m}");
let mut diff = operator.sub(&m, ideal_m);
let q = operator.modulus();
if diff > (q >> 1) {
@ -221,16 +225,19 @@ mod tests {
use crate::{
backend::{ModInit, ModularOpsU64},
decomposer::{gadget_vector, DefaultDecomposer},
lwe::lwe_key_switch,
lwe::{lwe_key_switch, measure_noise_lwe},
random::DefaultSecureRng,
rgsw::measure_noise,
Secret,
};
use super::{decrypt_lwe, encrypt_lwe, lwe_ksk_keygen, LweSecret};
const K: usize = 500;
#[test]
fn encrypt_decrypt_works() {
let logq = 20;
let logq = 16;
let q = 1u64 << logq;
let lwe_n = 1024;
let logp = 3;
@ -262,12 +269,12 @@ mod tests {
#[test]
fn key_switch_works() {
let logq = 16;
let logp = 3;
let logp = 2;
let q = 1u64 << logq;
let lwe_in_n = 1024;
let lwe_out_n = 470;
let lwe_in_n = 2048;
let lwe_out_n = 493;
let d_ks = 3;
let logb = 4;
let logb = 5;
let lwe_sk_in = LweSecret::random(lwe_in_n >> 1, lwe_in_n);
let lwe_sk_out = LweSecret::random(lwe_out_n >> 1, lwe_out_n);
@ -276,7 +283,7 @@ mod tests {
let modq_op = ModularOpsU64::new(q);
// genrate ksk
for _ in 0..10 {
for _ in 0..K {
let mut ksk = vec![vec![0u64; lwe_out_n + 1]; d_ks * lwe_in_n];
let gadget = gadget_vector(logq, logb, d_ks);
lwe_ksk_keygen(
@ -311,6 +318,9 @@ mod tests {
let m_back = ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as f64).round()
as u64)
% (1u64 << logp);
let noise =
measure_noise_lwe(&lwe_out_ct, lwe_sk_out.values(), &modq_op, &encoded_m);
println!("Noise: {noise}");
assert_eq!(m, m_back, "Expected {m} but got {m_back}");
// dbg!(m, m_back);
// dbg!(encoded_m, encoded_m_back);

+ 7
- 4
src/ntt.rs

@ -1,5 +1,6 @@
use itertools::Itertools;
use rand::{thread_rng, Rng, RngCore};
use rand::{thread_rng, Rng, RngCore, SeedableRng};
use rand_chacha::ChaCha8Rng;
use crate::{
backend::{ArithmeticOps, ModInit, ModularOpsU64},
@ -8,6 +9,8 @@ use crate::{
pub trait NttInit {
type Element;
/// Ntt istance must be compatible across different instances with same `q`
/// and `n`
fn new(q: Self::Element, n: usize) -> Self;
}
@ -189,6 +192,7 @@ pub(crate) fn find_primitive_root(q: u64, n: u64, rng: &mut R) -> Op
None
}
#[derive(Debug)]
pub struct NttBackendU64 {
q: u64,
q_twice: u64,
@ -204,10 +208,9 @@ impl NttInit for NttBackendU64 {
type Element = u64;
fn new(q: u64, n: usize) -> Self {
// \psi = 2n^{th} primitive root of unity in F_q
let mut rng = thread_rng();
let mut rng = ChaCha8Rng::from_seed([0u8; 32]);
let psi = find_primitive_root(q, (n * 2) as u64, &mut rng)
.expect("Unable to find 2n^th root of unity");
let psi_inv = mod_inverse(psi, q);
// assert!(
@ -382,7 +385,7 @@ mod tests {
#[test]
fn native_ntt_negacylic_mul() {
let primes = [40, 50, 60]
let primes = [25, 40, 50, 60]
.iter()
.map(|bits| generate_prime(*bits, (2 * N) as u64, 1u64 << bits).unwrap())
.collect_vec();

+ 44
- 28
src/random.rs

@ -11,6 +11,11 @@ thread_local! {
pub(crate) static DEFAULT_RNG: RefCell<DefaultSecureRng> = RefCell::new(DefaultSecureRng::new());
}
pub(crate) trait NewWithSeed {
type Seed;
fn new_with_seed(seed: Self::Seed) -> Self;
}
pub trait RandomGaussianDist<M>
where
M: ?Sized,
@ -41,6 +46,17 @@ impl DefaultSecureRng {
let rng = ChaCha8Rng::from_entropy();
DefaultSecureRng { rng }
}
pub fn fill_bytes(&mut self, a: &mut [u8; 32]) {
self.rng.fill_bytes(a);
}
}
impl NewWithSeed for DefaultSecureRng {
type Seed = <ChaCha8Rng as SeedableRng>::Seed;
fn new_with_seed(seed: Self::Seed) -> Self {
DefaultSecureRng::new_seeded(seed)
}
}
impl RandomUniformDist<usize> for DefaultSecureRng {
@ -86,19 +102,19 @@ impl RandomUniformDist<[u64]> for DefaultSecureRng {
impl RandomGaussianDist<u64> for DefaultSecureRng {
type Parameters = u64;
fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut u64) {
let o = rand_distr::Normal::new(0.0, 3.2f64)
.unwrap()
.sample(&mut self.rng)
.round();
// let o = 0.0f64;
let is_neg = o.is_sign_negative() && o != 0.0;
if is_neg {
*container = parameters - (o.abs() as u64);
} else {
*container = o as u64;
}
// let o = rand_distr::Normal::new(0.0, 3.2f64)
// .unwrap()
// .sample(&mut self.rng)
// .round();
// // let o = 0.0f64;
// let is_neg = o.is_sign_negative() && o != 0.0;
// if is_neg {
// *container = parameters - (o.abs() as u64);
// } else {
// *container = o as u64;
// }
}
}
@ -124,21 +140,21 @@ impl RandomGaussianDist for DefaultSecureRng {
impl RandomGaussianDist<[u64]> for DefaultSecureRng {
type Parameters = u64;
fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut [u64]) {
izip!(
rand_distr::Normal::new(0.0, 3.2f64)
.unwrap()
.sample_iter(&mut self.rng),
container.iter_mut()
)
.for_each(|(oi, v)| {
let oi = oi.round();
let is_neg = oi.is_sign_negative() && oi != 0.0;
if is_neg {
*v = parameters - (oi.abs() as u64);
} else {
*v = oi as u64;
}
});
// izip!(
// rand_distr::Normal::new(0.0, 3.2f64)
// .unwrap()
// .sample_iter(&mut self.rng),
// container.iter_mut()
// )
// .for_each(|(oi, v)| {
// let oi = oi.round();
// let is_neg = oi.is_sign_negative() && oi != 0.0;
// if is_neg {
// *v = parameters - (oi.abs() as u64);
// } else {
// *v = oi as u64;
// }
// });
}
}

+ 672
- 482
src/rgsw.rs
File diff suppressed because it is too large
View File


Loading…
Cancel
Save