Browse Source

add PBS tracer

par-agg-key-shares
Janmajaya Mall 11 months ago
parent
commit
4b835461dd
3 changed files with 605 additions and 85 deletions
  1. +544
    -69
      src/bool.rs
  2. +34
    -2
      src/lwe.rs
  3. +27
    -14
      src/rgsw.rs

+ 544
- 69
src/bool.rs

@ -1,28 +1,44 @@
use std::{collections::HashMap, fmt::Debug, marker::PhantomData};
use std::{
cell::RefCell,
collections::HashMap,
fmt::{Debug, Display},
hash::Hash,
marker::PhantomData,
thread::panicking,
};
use itertools::Itertools;
use num_traits::{FromPrimitive, Num, One, PrimInt, ToPrimitive, Zero};
use itertools::{izip, partition, Itertools};
use num_traits::{FromPrimitive, Num, One, PrimInt, ToPrimitive, WrappingSub, Zero};
use crate::{
backend::{ArithmeticOps, ModInit, VectorOps},
backend::{ArithmeticOps, ModInit, ModularOpsU64, VectorOps},
decomposer::{gadget_vector, Decomposer, DefaultDecomposer, NumInfo},
lwe::{decrypt_lwe, encrypt_lwe, lwe_key_switch, lwe_ksk_keygen, LweSecret},
ntt::{Ntt, NttInit},
lwe::{decrypt_lwe, encrypt_lwe, lwe_key_switch, lwe_ksk_keygen, measure_noise_lwe, LweSecret},
ntt::{Ntt, NttBackendU64, NttInit},
random::{DefaultSecureRng, RandomGaussianDist, RandomUniformDist},
rgsw::{encrypt_rgsw, galois_auto, galois_key_gen, rlwe_by_rgsw, IsTrivial, RlweSecret},
rgsw::{
decrypt_rlwe, encrypt_rgsw, galois_auto, galois_key_gen, generate_auto_map, rlwe_by_rgsw,
IsTrivial, RlweCiphertext, RlweSecret,
},
utils::{generate_prime, mod_exponent, TryConvertFrom, WithLocal},
Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret,
};
thread_local! {
pub(crate) static CLIENT_KEY: RefCell<ClientKey> = RefCell::new(ClientKey::random());
}
trait PbsKey {
type M: Matrix;
fn rgsw_ct_secret_el(&self, si: usize) -> &Self::M;
/// RGSW ciphertext of LWE secret elements
fn rgsw_ct_lwe_si(&self, si: usize) -> &Self::M;
/// Key for automorphism
fn galois_key_for_auto(&self, k: isize) -> &Self::M;
fn auto_map_index(&self, k: isize) -> &[usize];
fn auto_map_sign(&self, k: isize) -> &[bool];
/// LWE ksk to key switch from RLWE secret to LWE secret
fn lwe_ksk(&self) -> &Self::M;
}
trait Parameters {
trait PbsParameters {
type Element;
type D: Decomposer<Element = Self::Element>;
fn rlwe_q(&self) -> Self::Element;
@ -44,12 +60,43 @@ trait Parameters {
/// For any a, if k is s.t. a = g^{k}, then k is expressed as k. If k is s.t
/// a = -g^{k}, then k is expressed as k=k+q/2
fn g_k_dlog_map(&self) -> &[usize];
fn rlwe_auto_map(&self, k: isize) -> &(Vec<usize>, Vec<bool>);
}
#[derive(Clone)]
struct ClientKey {
sk_rlwe: RlweSecret,
sk_lwe: LweSecret,
}
impl ClientKey {
fn random() -> Self {
let sk_rlwe = RlweSecret::random(0, 0);
let sk_lwe = LweSecret::random(0, 0);
Self { sk_rlwe, sk_lwe }
}
}
impl WithLocal for ClientKey {
fn with_local<F, R>(func: F) -> R
where
F: Fn(&Self) -> R,
{
CLIENT_KEY.with_borrow(|client_key| func(client_key))
}
fn with_local_mut<F, R>(func: F) -> R
where
F: Fn(&mut Self) -> R,
{
CLIENT_KEY.with_borrow_mut(|client_key| func(client_key))
}
}
fn set_client_key(key: &ClientKey) {
ClientKey::with_local_mut(|k| *k = key.clone())
}
struct ServerKey<M> {
/// Rgsw cts of LWE secret elements
rgsw_cts: Vec<M>,
@ -59,6 +106,23 @@ struct ServerKey {
lwe_ksk: M,
}
//FIXME(Jay): Figure out a way for BoolEvaluator to have access to ServerKey
// via a pointer and implement PbsKey for BoolEvaluator instead of ServerKey
// directly
impl<M: Matrix> PbsKey for ServerKey<M> {
type M = M;
fn galois_key_for_auto(&self, k: isize) -> &Self::M {
self.galois_keys.get(&k).unwrap()
}
fn rgsw_ct_lwe_si(&self, si: usize) -> &Self::M {
&self.rgsw_cts[si]
}
fn lwe_ksk(&self) -> &Self::M {
&self.lwe_ksk
}
}
struct BoolParameters<El> {
rlwe_q: El,
rlwe_logq: usize,
@ -75,7 +139,10 @@ struct BoolParameters {
w: usize,
}
struct BoolEvaluator<M, E, Ntt, ModOp> {
struct BoolEvaluator<M, E, Ntt, ModOp>
where
M: Matrix,
{
parameters: BoolParameters<E>,
decomposer_rlwe: DefaultDecomposer<E>,
decomposer_lwe: DefaultDecomposer<E>,
@ -84,7 +151,11 @@ struct BoolEvaluator {
rlwe_modop: ModOp,
lwe_modop: ModOp,
embedding_factor: usize,
nand_test_vec: M::R,
rlweq_by8: M::MatElement,
rlwe_auto_maps: Vec<(Vec<usize>, Vec<bool>)>,
scratch_lwen_plus1: M::R,
scratch_dplus2_ring: M,
_phantom: PhantomData<M>,
}
@ -94,7 +165,7 @@ where
ModOp: ModInit<Element = M::MatElement>
+ ArithmeticOps<Element = M::MatElement>
+ VectorOps<Element = M::MatElement>,
M::MatElement: PrimInt + Debug + NumInfo + FromPrimitive,
M::MatElement: PrimInt + Debug + Display + NumInfo + FromPrimitive + WrappingSub,
M: MatrixEntity + MatrixMut,
M::R: TryConvertFrom<[i32], Parameters = M::MatElement> + RowEntity,
M: TryConvertFrom<[i32], Parameters = M::MatElement>,
@ -105,6 +176,7 @@ where
{
fn new(parameters: BoolParameters<M::MatElement>) -> Self {
//TODO(Jay): Run sanity checks for modulus values in parameters
assert!(parameters.br_q.is_power_of_two());
let decomposer_rlwe =
DefaultDecomposer::new(parameters.rlwe_q, parameters.logb_rgsw, parameters.d_rgsw);
@ -129,6 +201,73 @@ where
let rlwe_modop = ModInit::new(parameters.rlwe_q);
let lwe_modop = ModInit::new(parameters.lwe_q);
// set test vectors
let el_one = M::MatElement::one();
let nand_map = |index: usize, qby8: usize| {
if index < (3 * qby8) {
true
} else {
false
}
};
let q = parameters.br_q;
let qby2 = q >> 1;
let qby8 = q >> 3;
let qby16 = q >> 4;
let mut nand_test_vec = M::R::zeros(qby2);
// Q/8 (Q: rlwe_q)
let rlwe_qby8 =
M::MatElement::from_f64((parameters.rlwe_q.to_f64().unwrap() / 8.0).round()).unwrap();
let true_m_el = rlwe_qby8;
// -Q/8
let false_m_el = parameters.rlwe_q - rlwe_qby8;
for i in 0..qby2 {
let v = nand_map(i, qby8);
if v {
nand_test_vec.as_mut()[i] = true_m_el;
} else {
nand_test_vec.as_mut()[i] = false_m_el;
}
}
// Rotate and negate by q/16
let mut tmp = M::R::zeros(qby2);
tmp.as_mut()[..qby2 - qby16].copy_from_slice(&nand_test_vec.as_ref()[qby16..]);
tmp.as_mut()[qby2 - qby16..].copy_from_slice(&nand_test_vec.as_ref()[..qby16]);
tmp.as_mut()[qby2 - qby16..].iter_mut().for_each(|v| {
*v = parameters.rlwe_q - *v;
});
let nand_test_vec = tmp;
// v(X) -> v(X^{-g})
let (auto_map_index, auto_map_sign) = generate_auto_map(qby2, -(g as isize));
let mut nand_test_vec_autog = M::R::zeros(qby2);
izip!(
nand_test_vec.as_ref().iter(),
auto_map_index.iter(),
auto_map_sign.iter()
)
.for_each(|(v, to_index, to_sign)| {
if !to_sign {
// negate
nand_test_vec_autog.as_mut()[*to_index] = parameters.rlwe_q - *v;
} else {
nand_test_vec_autog.as_mut()[*to_index] = *v;
}
});
// auto map indices and sign
let mut rlwe_auto_maps = vec![];
let ring_size = parameters.rlwe_n;
let g = parameters.g as isize;
for i in [g, -g] {
rlwe_auto_maps.push(generate_auto_map(ring_size, i))
}
// create srcatch spaces
let scratch_lwen_plus1 = M::R::zeros(parameters.lwe_n + 1);
let scratch_dplus2_ring = M::zeros(parameters.d_rgsw + 2, parameters.rlwe_n);
BoolEvaluator {
parameters: parameters,
decomposer_lwe,
@ -138,7 +277,11 @@ where
lwe_modop,
rlwe_modop,
rlwe_nttop,
nand_test_vec: nand_test_vec_autog,
rlweq_by8: rlwe_qby8,
rlwe_auto_maps,
scratch_lwen_plus1,
scratch_dplus2_ring,
_phantom: PhantomData,
}
}
@ -190,6 +333,7 @@ where
// X^{si}; assume |emebedding_factor * si| < N
let mut m = M::zeros(1, ring_size);
let si = (self.embedding_factor as i32) * si;
// dbg!(si);
if si < 0 {
// X^{-i} = X^{2N - i} = -X^{N-i}
m.set(
@ -246,16 +390,14 @@ where
}
}
/// TODO(Jay): Fetch client key from thread local
pub fn encrypt(&self, m: bool, client_key: &ClientKey) -> M::R {
let rlwe_q_by8 =
M::MatElement::from_f64((self.parameters.rlwe_q.to_f64().unwrap() / 8.0).round())
.unwrap();
let m = if m {
// Q/8
rlwe_q_by8
self.rlweq_by8
} else {
// -Q/8
self.parameters.rlwe_q - rlwe_q_by8
self.parameters.rlwe_q - self.rlweq_by8
};
DefaultSecureRng::with_local_mut(|rng| {
@ -275,13 +417,11 @@ where
let m = decrypt_lwe(lwe_ct, client_key.sk_rlwe.values(), &self.rlwe_modop);
let m = {
// m + q/8 => {0,q/4 1}
let rlwe_q_by8 =
M::MatElement::from_f64((self.parameters.rlwe_q.to_f64().unwrap() / 8.0).round())
.unwrap();
(((m + rlwe_q_by8).to_f64().unwrap() * 4.0) / self.parameters.rlwe_q.to_f64().unwrap())
.round()
.to_usize()
.unwrap()
(((m + self.rlweq_by8).to_f64().unwrap() * 4.0)
/ self.parameters.rlwe_q.to_f64().unwrap())
.round()
.to_usize()
.unwrap()
% 4
};
@ -290,9 +430,123 @@ where
} else if m == 1 {
true
} else {
panic!("Incorrect bool decryption. Got m={m} expected m to be 0 or 1")
panic!("Incorrect bool decryption. Got m={m} but expected m to be 0 or 1")
}
}
pub fn nand(
&self,
c0: &M::R,
c1: &M::R,
server_key: &ServerKey<M>,
scratch_lwen_plus1: &mut M::R,
scratch_matrix_dplus2_ring: &mut M,
) -> M::R {
// ClientKey::with_local(|ck| {
// let c0_noise = measure_noise_lwe(
// c0,
// ck.sk_rlwe.values(),
// &self.rlwe_modop,
// &(self.rlwe_q() - self.rlweq_by8),
// );
// let c1_noise =
// measure_noise_lwe(c1, ck.sk_rlwe.values(), &self.rlwe_modop,
// &(self.rlweq_by8)); println!("c0 noise: {c0_noise}; c1 noise:
// {c1_noise}"); });
let mut c_out = M::R::zeros(c0.as_ref().len());
let modop = &self.rlwe_modop;
izip!(
c_out.as_mut().iter_mut(),
c0.as_ref().iter(),
c1.as_ref().iter()
)
.for_each(|(o, i0, i1)| {
*o = modop.add(i0, i1);
});
// +Q/8
c_out.as_mut()[0] = modop.add(&c_out.as_ref()[0], &self.rlweq_by8);
// ClientKey::with_local(|ck| {
// let noise = measure_noise_lwe(
// &c_out,
// ck.sk_rlwe.values(),
// &self.rlwe_modop,
// &(self.rlweq_by8),
// );
// println!("cout_noise: {noise}");
// });
// PBS
pbs(
self,
&self.nand_test_vec,
&mut c_out,
scratch_lwen_plus1,
scratch_matrix_dplus2_ring,
&self.lwe_modop,
&self.rlwe_modop,
&self.rlwe_nttop,
server_key,
);
c_out
}
}
impl<M: Matrix, NttOp, ModOp> PbsParameters for BoolEvaluator<M, M::MatElement, NttOp, ModOp>
where
M::MatElement: PrimInt + WrappingSub + Debug,
{
type Element = M::MatElement;
type D = DefaultDecomposer<M::MatElement>;
fn rlwe_auto_map(&self, k: isize) -> &(Vec<usize>, Vec<bool>) {
let g = self.parameters.g as isize;
if k == g {
&self.rlwe_auto_maps[0]
} else if k == -g {
&self.rlwe_auto_maps[1]
} else {
panic!("RLWE auto map only supports k in [-g, g], but got k={k}");
}
}
fn br_q(&self) -> usize {
self.parameters.br_q
}
fn d_lwe(&self) -> usize {
self.parameters.d_lwe
}
fn d_rgsw(&self) -> usize {
self.parameters.d_rgsw
}
fn decomoposer_lwe(&self) -> &Self::D {
&self.decomposer_lwe
}
fn decomoposer_rlwe(&self) -> &Self::D {
&self.decomposer_rlwe
}
fn embedding_factor(&self) -> usize {
self.embedding_factor
}
fn g(&self) -> isize {
self.parameters.g as isize
}
fn g_k_dlog_map(&self) -> &[usize] {
&self.g_k_dlog_map
}
fn lwe_n(&self) -> usize {
self.parameters.lwe_n
}
fn lwe_q(&self) -> Self::Element {
self.parameters.lwe_q
}
fn rlwe_n(&self) -> usize {
self.parameters.rlwe_n
}
fn rlwe_q(&self) -> Self::Element {
self.parameters.rlwe_q
}
}
/// LMKCY+ Blind rotation
@ -305,6 +559,7 @@ fn blind_rotation<
NttOp: Ntt<Element = MT::MatElement>,
ModOp: ArithmeticOps<Element = MT::MatElement> + VectorOps<Element = MT::MatElement>,
K: PbsKey<M = Mmut>,
P: PbsParameters<Element = MT::MatElement>,
>(
trivial_rlwe_test_poly: &mut MT,
scratch_matrix_dplus2_ring: &mut Mmut,
@ -315,6 +570,7 @@ fn blind_rotation<
decomposer: &D,
ntt_op: &NttOp,
mod_op: &ModOp,
parameters: &P,
pbs_key: &K,
) where
<Mmut as Matrix>::R: RowMut,
@ -324,11 +580,11 @@ fn blind_rotation<
let q_by_2 = q / 2;
// -(g^k)
for i in 1..q_by_2 {
for i in (1..q_by_2).rev() {
gk_to_si[q_by_2 + i].iter().for_each(|s_index| {
rlwe_by_rgsw(
trivial_rlwe_test_poly,
pbs_key.rgsw_ct_secret_el(*s_index),
pbs_key.rgsw_ct_lwe_si(*s_index),
scratch_matrix_dplus2_ring,
decomposer,
ntt_op,
@ -336,12 +592,13 @@ fn blind_rotation<
);
});
let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(g);
galois_auto(
trivial_rlwe_test_poly,
pbs_key.galois_key_for_auto(g),
scratch_matrix_dplus2_ring,
pbs_key.auto_map_index(g),
pbs_key.auto_map_sign(g),
&auto_map_index,
&auto_map_sign,
mod_op,
ntt_op,
decomposer,
@ -352,30 +609,31 @@ fn blind_rotation<
gk_to_si[q_by_2].iter().for_each(|s_index| {
rlwe_by_rgsw(
trivial_rlwe_test_poly,
pbs_key.rgsw_ct_secret_el(*s_index),
pbs_key.rgsw_ct_lwe_si(*s_index),
scratch_matrix_dplus2_ring,
decomposer,
ntt_op,
mod_op,
);
});
let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(-g);
galois_auto(
trivial_rlwe_test_poly,
pbs_key.galois_key_for_auto(-g),
scratch_matrix_dplus2_ring,
pbs_key.auto_map_index(-g),
pbs_key.auto_map_sign(-g),
&auto_map_index,
&auto_map_sign,
mod_op,
ntt_op,
decomposer,
);
// +(g^k)
for i in 1..q_by_2 {
for i in (1..q_by_2).rev() {
gk_to_si[i].iter().for_each(|s_index| {
rlwe_by_rgsw(
trivial_rlwe_test_poly,
pbs_key.rgsw_ct_secret_el(*s_index),
pbs_key.rgsw_ct_lwe_si(*s_index),
scratch_matrix_dplus2_ring,
decomposer,
ntt_op,
@ -383,12 +641,13 @@ fn blind_rotation<
);
});
let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(g);
galois_auto(
trivial_rlwe_test_poly,
pbs_key.galois_key_for_auto(g),
scratch_matrix_dplus2_ring,
pbs_key.auto_map_index(g),
pbs_key.auto_map_sign(g),
&auto_map_index,
&auto_map_sign,
mod_op,
ntt_op,
decomposer,
@ -399,7 +658,7 @@ fn blind_rotation<
gk_to_si[0].iter().for_each(|s_index| {
rlwe_by_rgsw(
trivial_rlwe_test_poly,
pbs_key.rgsw_ct_secret_el(gk_to_si[q_by_2][*s_index]),
pbs_key.rgsw_ct_lwe_si(gk_to_si[q_by_2][*s_index]),
scratch_matrix_dplus2_ring,
decomposer,
ntt_op,
@ -414,8 +673,7 @@ fn blind_rotation<
/// - blind rotate
fn pbs<
M: Matrix + MatrixMut + MatrixEntity,
MT: MatrixMut<MatElement = M::MatElement, R = M::R> + IsTrivial + MatrixEntity,
P: Parameters<Element = M::MatElement>,
P: PbsParameters<Element = M::MatElement>,
NttOp: Ntt<Element = M::MatElement>,
ModOp: ArithmeticOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>,
K: PbsKey<M = M>,
@ -423,17 +681,17 @@ fn pbs<
parameters: &P,
test_vec: &M::R,
lwe_in: &mut M::R,
lwe_ksk: &M,
scratch_lwen_plus1: &mut M::R,
scratch_matrix_dplus2_ring: &mut M,
modop_lweq: &ModOp,
modop_rlweq: &ModOp,
nttop_rlweq: &NttOp,
pbs_key: K,
pbs_key: &K,
) where
<M as Matrix>::R: RowMut,
<MT as Matrix>::R: RowMut,
M::MatElement: PrimInt + ToPrimitive + FromPrimitive + One + Copy + Zero,
// FIXME(Jay): TryConvertFrom<[i32], Parameters = M::MatElement> are only needed for
// debugging purposes
<M as Matrix>::R: RowMut + TryConvertFrom<[i32], Parameters = M::MatElement>,
M::MatElement: PrimInt + ToPrimitive + FromPrimitive + One + Copy + Zero + Display,
{
let rlwe_q = parameters.rlwe_q();
let lwe_q = parameters.lwe_q();
@ -449,17 +707,34 @@ fn pbs<
M::MatElement::from_f64(((v.to_f64().unwrap() * lwe_qf64) / rlwe_qf64).round()).unwrap()
});
// key switch
// let mut lwe_out = M::zeros(1, parameters.lwe_n() + 1);
PBSTracer::with_local_mut(|t| {
let out = lwe_in
.as_ref()
.iter()
.map(|v| v.to_u64().unwrap())
.collect_vec();
t.ct_lwe_q_mod = out;
});
// key switch RLWE secret to LWE secret
scratch_lwen_plus1.as_mut().fill(M::MatElement::zero());
lwe_key_switch(
scratch_lwen_plus1,
lwe_in,
lwe_ksk,
pbs_key.lwe_ksk(),
modop_lweq,
parameters.decomoposer_lwe(),
);
PBSTracer::with_local_mut(|t| {
let out = scratch_lwen_plus1
.as_ref()
.iter()
.map(|v| v.to_u64().unwrap())
.collect_vec();
t.ct_lwe_q_mod_after_ksk = out;
});
// odd mowdown Q_ks -> q
let g_k_dlog_map = parameters.g_k_dlog_map();
let mut g_k_si = vec![vec![]; br_q];
@ -474,6 +749,15 @@ fn pbs<
g_k_si[k].push(index);
});
PBSTracer::with_local_mut(|t| {
let out = scratch_lwen_plus1
.as_ref()
.iter()
.map(|v| mod_switch_odd(v.to_f64().unwrap(), lwe_qf64, br_qf64) as u64)
.collect_vec();
t.ct_br_q_mod = out;
});
// handle b and set trivial test RLWE
let g = parameters.g() as usize;
let g_times_b = (g * mod_switch_odd(
@ -485,41 +769,42 @@ fn pbs<
let br_qby2 = br_q / 2;
let mut gb_monomial_sign = true;
let mut gb_monomial_exp = g_times_b;
// X^{g*b} mod X^{q}+1
// X^{g*b} mod X^{q/2}+1
if gb_monomial_exp > br_qby2 {
gb_monomial_exp -= br_qby2;
gb_monomial_sign = false
}
// monomial mul
let mut trivial_rlwe_test_poly = MT::zeros(2, rlwe_n);
let mut trivial_rlwe_test_poly = RlweCiphertext(M::zeros(2, rlwe_n), true);
if parameters.embedding_factor() == 1 {
monomial_mul(
test_vec.as_ref(),
trivial_rlwe_test_poly.get_row_mut(1).as_mut(),
gb_monomial_exp,
gb_monomial_sign,
br_q,
br_qby2,
modop_rlweq,
);
} else {
// use lwe_in to store the `t = v(X) * X^{g*2} mod X^{q/2}+1` temporarily. This
// works because q/2 < N (where N is lwe_in LWE dimension) always.
// works because q/2 <= N (where N is lwe_in LWE dimension) always.
monomial_mul(
test_vec.as_ref(),
&mut lwe_in.as_mut()[..br_qby2],
gb_monomial_exp,
gb_monomial_sign,
br_q,
br_qby2,
modop_rlweq,
);
// emebed poly `t` in ring X^{q/2}+1 inside the bigger ring X^{N}+1
let embed_factor = parameters.embedding_factor();
let partb_trivial_rlwe = trivial_rlwe_test_poly.get_row_mut(1);
lwe_in.as_ref()[..br_qby2]
.iter()
.enumerate()
.for_each(|(index, v)| {
partb_trivial_rlwe[2 * index] = *v;
partb_trivial_rlwe[embed_factor * index] = *v;
});
}
@ -534,19 +819,48 @@ fn pbs<
parameters.decomoposer_rlwe(),
nttop_rlweq,
modop_rlweq,
&pbs_key,
parameters,
pbs_key,
);
// ClientKey::with_local(|ck| {
// let ring_size = parameters.rlwe_n();
// let mut rlwe_ct = vec![vec![0u64; ring_size]; 2];
// izip!(
// rlwe_ct[0].iter_mut(),
// trivial_rlwe_test_poly.0.get_row_slice(0)
// )
// .for_each(|(t, f)| {
// *t = f.to_u64().unwrap();
// });
// izip!(
// rlwe_ct[1].iter_mut(),
// trivial_rlwe_test_poly.0.get_row_slice(1)
// )
// .for_each(|(t, f)| {
// *t = f.to_u64().unwrap();
// });
// let mut m_out = vec![vec![0u64; ring_size]];
// let modop = ModularOpsU64::new(rlwe_q.to_u64().unwrap());
// let nttop = NttBackendU64::new(rlwe_q.to_u64().unwrap(), ring_size);
// decrypt_rlwe(&rlwe_ct, ck.sk_rlwe.values(), &mut m_out, &nttop, &modop);
// println!("RLWE post PBS message: {:?}", m_out[0]);
// });
// sample extract
sample_extract(lwe_in, &trivial_rlwe_test_poly, modop_rlweq, 0);
}
fn mod_switch_odd(v: f64, from_q: f64, to_q: f64) -> usize {
let odd_v = (((v.to_f64().unwrap() * to_q) / (from_q)).floor())
.to_usize()
.unwrap();
// println!("v: {v}, odd_v: {odd_v}, lwe_q:{lwe_q}, br_q:{br_q}");
let odd_v = (((v * to_q) / (from_q)).floor()).to_usize().unwrap();
// println!(
// "v: {v}, odd_v: {odd_v}, returned_oddv: {},lwe_q:{from_q}, br_q:{to_q}",
// odd_v + ((odd_v & 1) ^ 1)
// );
//TODO(Jay): check correctness of this
odd_v + (odd_v ^ (usize::one()))
odd_v + ((odd_v & 1) ^ 1)
}
fn sample_extract<M: Matrix + MatrixMut, ModOp: ArithmeticOps<Element = M::MatElement>>(
@ -576,6 +890,7 @@ fn sample_extract
lwe_out.as_mut()[0] = *rlwe_in.get(1, index);
}
/// TODO(Jay): Write tests for monomial mul
fn monomial_mul<El, ModOp: ArithmeticOps<Element = El>>(
p_in: &[El],
p_out: &mut [El],
@ -606,28 +921,121 @@ fn monomial_mul>(
});
}
thread_local! {
static PBS_TRACER: RefCell<PBSTracer<Vec<Vec<u64>>>> = RefCell::new(PBSTracer::default());
}
#[derive(Default)]
struct PBSTracer<M>
where
M: Matrix + Default,
{
pub(crate) ct_lwe_q_mod: M::R,
pub(crate) ct_lwe_q_mod_after_ksk: M::R,
pub(crate) ct_br_q_mod: Vec<u64>,
}
impl PBSTracer<Vec<Vec<u64>>> {
fn trace(&self, parameters: &BoolParameters<u64>, client_key: &ClientKey, expected_m: bool) {
let lwe_q = parameters.lwe_q;
let lwe_qby8 = ((lwe_q as f64) / 8.0).round() as u64;
let expected_m_lweq = if expected_m {
lwe_qby8
} else {
lwe_q - lwe_qby8
};
let modop_lweq = ModularOpsU64::new(lwe_q);
// noise after mod down Q -> Q_ks
let noise0 = {
measure_noise_lwe(
&self.ct_lwe_q_mod,
client_key.sk_rlwe.values(),
&modop_lweq,
&expected_m_lweq,
)
};
// noise after key switch from RLWE -> LWE
let noise1 = {
measure_noise_lwe(
&self.ct_lwe_q_mod_after_ksk,
client_key.sk_lwe.values(),
&modop_lweq,
&expected_m_lweq,
)
};
// noise after mod down odd from Q_ks -> q
let br_q = parameters.br_q as u64;
let expected_m_brq = if expected_m {
br_q >> 3
} else {
br_q - (br_q >> 3)
};
let modop_br_q = ModularOpsU64::new(br_q);
let noise2 = {
measure_noise_lwe(
&self.ct_br_q_mod,
client_key.sk_lwe.values(),
&modop_br_q,
&expected_m_brq,
)
};
println!(
"
m: {expected_m},
Noise after mod down Q -> Q_ks: {noise0},
Noise after key switch from RLWE -> LWE: {noise1},
Noise after mod dwon Q_ks -> q: {noise2}
"
);
}
}
impl WithLocal for PBSTracer<Vec<Vec<u64>>> {
fn with_local<F, R>(func: F) -> R
where
F: Fn(&Self) -> R,
{
PBS_TRACER.with_borrow(|t| func(t))
}
fn with_local_mut<F, R>(func: F) -> R
where
F: Fn(&mut Self) -> R,
{
PBS_TRACER.with_borrow_mut(|t| func(t))
}
}
#[cfg(test)]
mod tests {
use crate::{backend::ModularOpsU64, ntt::NttBackendU64};
use crate::{backend::ModularOpsU64, ntt::NttBackendU64, random::DEFAULT_RNG};
use super::*;
const SP_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u64> {
rlwe_q: 4294957057u64,
rlwe_logq: 32,
rlwe_q: 268369921u64,
rlwe_logq: 28,
lwe_q: 1 << 16,
lwe_logq: 16,
br_q: 1 << 9,
br_q: 1 << 10,
rlwe_n: 1 << 10,
lwe_n: 490,
d_rgsw: 4,
logb_rgsw: 7,
d_lwe: 4,
lwe_n: 493,
d_rgsw: 3,
logb_rgsw: 8,
d_lwe: 3,
logb_lwe: 4,
g: 5,
w: 1,
};
// #[test]
// fn trial() {
// dbg!(generate_prime(28, 1 << 11, 1 << 28));
// }
#[test]
fn encrypt_decrypt_works() {
// let prime = generate_prime(32, 2 * 1024, 1 << 32);
@ -645,4 +1053,71 @@ mod tests {
m = !m;
}
}
#[test]
fn trial12() {
// DefaultSecureRng::with_local_mut(|r| {
// let rng = DefaultSecureRng::new_seeded([19u8; 32]);
// *r = rng;
// });
let bool_evaluator =
BoolEvaluator::<Vec<Vec<u64>>, u64, NttBackendU64, ModularOpsU64>::new(SP_BOOL_PARAMS);
// println!("{:?}", bool_evaluator.nand_test_vec);
let client_key = bool_evaluator.client_key();
set_client_key(&client_key);
let server_key = bool_evaluator.server_key(&client_key);
let mut scratch_lwen_plus1 = vec![0u64; bool_evaluator.parameters.lwe_n + 1];
let mut scratch_matrix_dplus2_ring = vec![
vec![0u64; bool_evaluator.parameters.rlwe_n];
bool_evaluator.parameters.d_rgsw + 2
];
let mut m0 = false;
let mut m1 = true;
let mut ct0 = bool_evaluator.encrypt(m0, &client_key);
let mut ct1 = bool_evaluator.encrypt(m1, &client_key);
for _ in 0..4 {
let ct_back = bool_evaluator.nand(
&ct0,
&ct1,
&server_key,
&mut scratch_lwen_plus1,
&mut scratch_matrix_dplus2_ring,
);
let m_out = !(m0 && m1);
// Trace and measure PBS noise
{
// Trace PBS
PBSTracer::with_local(|t| t.trace(&SP_BOOL_PARAMS, &client_key, m_out));
// Calculate nosie in ciphertext post PBS
let ideal = if m_out {
bool_evaluator.rlweq_by8
} else {
bool_evaluator.rlwe_q() - bool_evaluator.rlweq_by8
};
let noise = measure_noise_lwe(
&ct_back,
client_key.sk_rlwe.values(),
&bool_evaluator.rlwe_modop,
&ideal,
);
println!("PBS noise: {noise}");
}
let m_back = bool_evaluator.decrypt(&ct_back, &client_key);
assert_eq!(m_out, m_back);
println!("----------");
m1 = m0;
m0 = m_out;
ct1 = ct0;
ct0 = ct_back;
}
}
}

+ 34
- 2
src/lwe.rs

@ -1,7 +1,10 @@
use std::fmt::Debug;
use std::{
cell::RefCell,
fmt::{Debug, Display},
};
use itertools::{izip, Itertools};
use num_traits::{abs, Zero};
use num_traits::{abs, PrimInt, ToPrimitive, Zero};
use crate::{
backend::{ArithmeticOps, VectorOps},
@ -21,6 +24,7 @@ trait LweKeySwitchParameters {
trait LweCiphertext<M: Matrix> {}
#[derive(Clone)]
pub struct LweSecret {
values: Vec<i32>,
}
@ -183,6 +187,34 @@ where
operator.sub(b, &sa)
}
pub(crate) fn measure_noise_lwe<Ro: Row, Op: ArithmeticOps<Element = Ro::Element>, S>(
ct: &Ro,
s: &[S],
operator: &Op,
ideal_m: &Ro::Element,
) -> f64
where
Ro: TryConvertFrom<[S], Parameters = Ro::Element>,
Ro::Element: Zero + ToPrimitive + PrimInt + Display,
{
assert!(s.len() == ct.as_ref().len() - 1,);
let s = Ro::try_convert_from(s, &operator.modulus());
let mut sa = Ro::Element::zero();
izip!(s.as_ref().iter(), ct.as_ref().iter().skip(1)).for_each(|(si, ai)| {
sa = operator.add(&sa, &operator.mul(si, ai));
});
let m = operator.sub(&ct.as_ref()[0], &sa);
println!("measire: {m} {ideal_m}");
let mut diff = operator.sub(&m, ideal_m);
let q = operator.modulus();
if diff > (q >> 1) {
diff = q - diff;
}
return diff.to_f64().unwrap().log2();
}
#[cfg(test)]
mod tests {

+ 27
- 14
src/rgsw.rs

@ -16,7 +16,7 @@ use crate::{
Matrix, MatrixEntity, MatrixMut, RowMut, Secret,
};
pub struct RlweCiphertext<M>(M, bool);
pub struct RlweCiphertext<M>(pub(crate) M, pub(crate) bool);
impl<M: Matrix> Matrix for RlweCiphertext<M> {
type MatElement = M::MatElement;
@ -58,6 +58,7 @@ pub trait IsTrivial {
fn set_not_trivial(&mut self);
}
#[derive(Clone)]
pub struct RlweSecret {
values: Vec<i32>,
}
@ -80,12 +81,12 @@ impl RlweSecret {
}
}
fn generate_auto_map(ring_size: usize, k: isize) -> (Vec<usize>, Vec<bool>) {
pub(crate) fn generate_auto_map(ring_size: usize, k: isize) -> (Vec<usize>, Vec<bool>) {
assert!(k & 1 == 1, "Auto {k} must be odd");
// k = k % 2*N
let k = if k < 0 {
(2 * ring_size) - (k.abs() as usize)
// k is -ve, return k%(2*N)
(2 * ring_size) - (k.abs() as usize % (2 * ring_size))
} else {
k as usize
};
@ -712,19 +713,19 @@ pub(crate) fn decrypt_rlwe<
M: Matrix<MatElement = Mmut::MatElement>,
ModOp: VectorOps<Element = Mmut::MatElement>,
NttOp: Ntt<Element = Mmut::MatElement>,
S: Secret,
S,
>(
rlwe_ct: &M,
s: &S,
s: &[S],
m_out: &mut Mmut,
ntt_op: &NttOp,
mod_op: &ModOp,
) where
<Mmut as Matrix>::R: RowMut,
Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>,
Mmut: TryConvertFrom<[S], Parameters = Mmut::MatElement>,
Mmut::MatElement: Copy,
{
let ring_size = s.values().len();
let ring_size = s.len();
assert!(rlwe_ct.dimension() == (2, ring_size));
assert!(m_out.dimension() == (1, ring_size));
@ -735,7 +736,7 @@ pub(crate) fn decrypt_rlwe<
ntt_op.forward(m_out.get_row_mut(0));
// -s*a
let mut s = Mmut::try_convert_from(&s.values(), &mod_op.modulus());
let mut s = Mmut::try_convert_from(&s, &mod_op.modulus());
ntt_op.forward(s.get_row_mut(0));
mod_op.elwise_mul_mut(m_out.get_row_mut(0), s.get_row_slice(0));
mod_op.elwise_neg_mut(m_out.get_row_mut(0));
@ -819,7 +820,7 @@ mod tests {
random::{DefaultSecureRng, RandomUniformDist},
rgsw::{measure_noise, RlweCiphertext},
utils::{generate_prime, negacyclic_mul},
Matrix,
Matrix, Secret,
};
use super::{
@ -834,7 +835,7 @@ mod tests {
let ring_size = 1 << 10;
let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap();
let p = 1u64 << logp;
let d_rgsw = 10;
let d_rgsw = 9;
let logb = 5;
let mut rng = DefaultSecureRng::new();
@ -895,7 +896,13 @@ mod tests {
// Decrypt RLWE(m0m1)
let mut encoded_m0m1_back = vec![vec![0u64; ring_size as usize]];
decrypt_rlwe(&rlwe_in_ct, &s, &mut encoded_m0m1_back, &ntt_op, &mod_op);
decrypt_rlwe(
&rlwe_in_ct,
s.values(),
&mut encoded_m0m1_back,
&ntt_op,
&mod_op,
);
let m0m1_back = encoded_m0m1_back[0]
.iter()
.map(|v| (((*v as f64 * p as f64) / (q as f64)).round() as u64) % p)
@ -941,7 +948,7 @@ mod tests {
&mut rng,
);
let auto_k = -25;
let auto_k = -5;
// Generate galois key to key switch from s^k to s
let mut ksk_out = vec![vec![0u64; ring_size as usize]; d_rgsw * 2];
@ -976,7 +983,13 @@ mod tests {
// Decrypt RLWE_{s}(m^k) and check
let mut encoded_m_k_back = vec![vec![0u64; ring_size as usize]];
decrypt_rlwe(&rlwe_m_k, &s, &mut encoded_m_k_back, &ntt_op, &mod_op);
decrypt_rlwe(
&rlwe_m_k,
s.values(),
&mut encoded_m_k_back,
&ntt_op,
&mod_op,
);
let m_k_back = encoded_m_k_back[0]
.iter()
.map(|v| (((*v as f64 * p as f64) / q as f64).round() as u64) % p)

Loading…
Cancel
Save