Browse Source

Add short int and more gates

par-agg-key-shares
Janmajaya Mall 10 months ago
parent
commit
0bb653c816
3 changed files with 292 additions and 34 deletions
  1. +277
    -34
      src/bool/evaluator.rs
  2. +1
    -0
      src/lib.rs
  3. +14
    -0
      src/shortint.rs

+ 277
- 34
src/bool/evaluator.rs

@ -636,7 +636,6 @@ struct BoolPbsInfo {
rlwe_modop: RlweModOp, rlwe_modop: RlweModOp,
lwe_modop: LweModOp, lwe_modop: LweModOp,
embedding_factor: usize, embedding_factor: usize,
nand_test_vec: M::R,
rlwe_qby4: M::MatElement, rlwe_qby4: M::MatElement,
rlwe_auto_maps: Vec<(Vec<usize>, Vec<bool>)>, rlwe_auto_maps: Vec<(Vec<usize>, Vec<bool>)>,
parameters: BoolParameters<M::MatElement>, parameters: BoolParameters<M::MatElement>,
@ -715,6 +714,12 @@ where
{ {
pbs_info: BoolPbsInfo<M, Ntt, RlweModOp, LweModOp>, pbs_info: BoolPbsInfo<M, Ntt, RlweModOp, LweModOp>,
scratch_memory: ScratchMemory<M>, scratch_memory: ScratchMemory<M>,
nand_test_vec: M::R,
and_test_vec: M::R,
or_test_vec: M::R,
nor_test_vec: M::R,
xor_test_vec: M::R,
xnor_test_vec: M::R,
_phantom: PhantomData<M>, _phantom: PhantomData<M>,
} }
@ -764,39 +769,79 @@ where
let rlwe_modop = RlweModOp::new(*parameters.rlwe_q()); let rlwe_modop = RlweModOp::new(*parameters.rlwe_q());
let lwe_modop = LweModOp::new(*parameters.lwe_q()); let lwe_modop = LweModOp::new(*parameters.lwe_q());
// set test vectors
let q = *parameters.br_q(); let q = *parameters.br_q();
let qby2 = q >> 1; let qby2 = q >> 1;
let qby8 = q >> 3; let qby8 = q >> 3;
let mut nand_test_vec = M::R::zeros(qby2);
// Q/8 (Q: rlwe_q) // Q/8 (Q: rlwe_q)
let true_m_el = parameters.rlwe_q().true_el(); let true_m_el = parameters.rlwe_q().true_el();
// -Q/8 // -Q/8
let false_m_el = parameters.rlwe_q().false_el(); let false_m_el = parameters.rlwe_q().false_el();
for i in 0..qby2 {
if i < (3 * qby8) {
nand_test_vec.as_mut()[i] = true_m_el;
} else {
nand_test_vec.as_mut()[i] = false_m_el;
}
}
// v(X) -> v(X^{-g})
let (auto_map_index, auto_map_sign) = generate_auto_map(qby2, -(g as isize)); let (auto_map_index, auto_map_sign) = generate_auto_map(qby2, -(g as isize));
let mut nand_test_vec_autog = M::R::zeros(qby2);
izip!(
nand_test_vec.as_ref().iter(),
auto_map_index.iter(),
auto_map_sign.iter()
)
.for_each(|(v, to_index, to_sign)| {
if !to_sign {
// negate
nand_test_vec_autog.as_mut()[*to_index] = rlwe_modop.neg(v);
} else {
nand_test_vec_autog.as_mut()[*to_index] = *v;
let init_test_vec = |partition_el: usize,
before_partition_el: M::MatElement,
after_partition_el: M::MatElement| {
let mut test_vec = M::R::zeros(qby2);
for i in 0..qby2 {
if i < partition_el {
test_vec.as_mut()[i] = before_partition_el;
} else {
test_vec.as_mut()[i] = after_partition_el;
}
} }
});
// v(X) -> v(X^{-g})
let mut test_vec_autog = M::R::zeros(qby2);
izip!(
test_vec.as_ref().iter(),
auto_map_index.iter(),
auto_map_sign.iter()
)
.for_each(|(v, to_index, to_sign)| {
if !to_sign {
// negate
test_vec_autog.as_mut()[*to_index] = rlwe_modop.neg(v);
} else {
test_vec_autog.as_mut()[*to_index] = *v;
}
});
return test_vec_autog;
};
let nand_test_vec = init_test_vec(3 * qby8, true_m_el, false_m_el);
let and_test_vec = init_test_vec(3 * qby8, false_m_el, true_m_el);
let or_test_vec = init_test_vec(qby8, false_m_el, true_m_el);
let nor_test_vec = init_test_vec(qby8, true_m_el, false_m_el);
let xor_test_vec = init_test_vec(qby8, false_m_el, true_m_el);
let xnor_test_vec = init_test_vec(qby8, true_m_el, false_m_el);
// // set test vectors
// let mut nand_test_vec = M::R::zeros(qby2);
// for i in 0..qby2 {
// if i < (3 * qby8) {
// nand_test_vec.as_mut()[i] = true_m_el;
// } else {
// nand_test_vec.as_mut()[i] = false_m_el;
// }
// }
// // v(X) -> v(X^{-g})
// let (auto_map_index, auto_map_sign) = generate_auto_map(qby2, -(g as isize));
// let mut nand_test_vec_autog = M::R::zeros(qby2);
// izip!(
// nand_test_vec.as_ref().iter(),
// auto_map_index.iter(),
// auto_map_sign.iter()
// )
// .for_each(|(v, to_index, to_sign)| {
// if !to_sign {
// // negate
// nand_test_vec_autog.as_mut()[*to_index] = rlwe_modop.neg(v);
// } else {
// nand_test_vec_autog.as_mut()[*to_index] = *v;
// }
// });
// auto map indices and sign // auto map indices and sign
let mut rlwe_auto_maps = vec![]; let mut rlwe_auto_maps = vec![];
@ -819,7 +864,6 @@ where
lwe_modop, lwe_modop,
rlwe_modop, rlwe_modop,
rlwe_nttop, rlwe_nttop,
nand_test_vec: nand_test_vec_autog,
rlwe_qby4, rlwe_qby4,
rlwe_auto_maps, rlwe_auto_maps,
parameters: parameters, parameters: parameters,
@ -828,6 +872,12 @@ where
BoolEvaluator { BoolEvaluator {
pbs_info, pbs_info,
scratch_memory, scratch_memory,
nand_test_vec,
and_test_vec,
or_test_vec,
nor_test_vec,
xnor_test_vec,
xor_test_vec,
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
@ -1419,13 +1469,11 @@ where
} }
} }
// TODO(Jay): scratch spaces must be thread local. Don't pass them as arguments
pub fn nand(
&mut self,
c0: &M::R,
c1: &M::R,
server_key: &ServerKeyEvaluationDomain<M, DefaultSecureRng, NttOp>,
) -> M::R {
/// Returns c0 + c1 + Q/4
fn _add_and_shift_lwe_cts(&self, c0: &M::R, c1: &M::R) -> M::R
where
M::R: Clone,
{
let mut c_out = M::R::zeros(c0.as_ref().len()); let mut c_out = M::R::zeros(c0.as_ref().len());
let modop = &self.pbs_info.rlwe_modop; let modop = &self.pbs_info.rlwe_modop;
izip!( izip!(
@ -1438,11 +1486,111 @@ where
}); });
// +Q/4 // +Q/4
c_out.as_mut()[0] = modop.add(&c_out.as_ref()[0], &self.pbs_info.rlwe_qby4); c_out.as_mut()[0] = modop.add(&c_out.as_ref()[0], &self.pbs_info.rlwe_qby4);
c_out
}
/// Returns 2(c0 - c1) + Q/4
fn _subtract_double_and_shift_lwe_cts(&self, c0: &M::R, c1: &M::R) -> M::R
where
M::R: Clone,
{
let mut c_out = c0.clone();
let modop = &self.pbs_info.rlwe_modop;
// c0 - c1
modop.elwise_sub_mut(c_out.as_mut(), c1.as_ref());
// double
c_out.as_mut().iter_mut().for_each(|v| *v = modop.add(v, v));
c_out
}
pub fn nand(
&mut self,
c0: &M::R,
c1: &M::R,
server_key: &ServerKeyEvaluationDomain<M, DefaultSecureRng, NttOp>,
) -> M::R
where
M::R: Clone,
{
let mut c_out = self._add_and_shift_lwe_cts(c0, c1);
// PBS
pbs(
&self.pbs_info,
&self.nand_test_vec,
&mut c_out,
server_key,
&mut self.scratch_memory.lwe_vector,
&mut self.scratch_memory.decomposition_matrix,
);
c_out
}
pub fn and(
&mut self,
c0: &M::R,
c1: &M::R,
server_key: &ServerKeyEvaluationDomain<M, DefaultSecureRng, NttOp>,
) -> M::R
where
M::R: Clone,
{
let mut c_out = self._add_and_shift_lwe_cts(c0, c1);
// PBS
pbs(
&self.pbs_info,
&self.and_test_vec,
&mut c_out,
server_key,
&mut self.scratch_memory.lwe_vector,
&mut self.scratch_memory.decomposition_matrix,
);
c_out
}
pub fn or(
&mut self,
c0: &M::R,
c1: &M::R,
server_key: &ServerKeyEvaluationDomain<M, DefaultSecureRng, NttOp>,
) -> M::R
where
M::R: Clone,
{
let mut c_out = self._add_and_shift_lwe_cts(c0, c1);
// PBS
pbs(
&self.pbs_info,
&self.or_test_vec,
&mut c_out,
server_key,
&mut self.scratch_memory.lwe_vector,
&mut self.scratch_memory.decomposition_matrix,
);
c_out
}
pub fn nor(
&mut self,
c0: &M::R,
c1: &M::R,
server_key: &ServerKeyEvaluationDomain<M, DefaultSecureRng, NttOp>,
) -> M::R
where
M::R: Clone,
{
let mut c_out = self._add_and_shift_lwe_cts(c0, c1);
// PBS // PBS
pbs( pbs(
&self.pbs_info, &self.pbs_info,
&self.pbs_info.nand_test_vec,
&self.nor_test_vec,
&mut c_out, &mut c_out,
server_key, server_key,
&mut self.scratch_memory.lwe_vector, &mut self.scratch_memory.lwe_vector,
@ -1451,6 +1599,62 @@ where
c_out c_out
} }
pub fn xor(
&mut self,
c0: &M::R,
c1: &M::R,
server_key: &ServerKeyEvaluationDomain<M, DefaultSecureRng, NttOp>,
) -> M::R
where
M::R: Clone,
{
let mut c_out = self._subtract_double_and_shift_lwe_cts(c0, c1);
// PBS
pbs(
&self.pbs_info,
&self.xor_test_vec,
&mut c_out,
server_key,
&mut self.scratch_memory.lwe_vector,
&mut self.scratch_memory.decomposition_matrix,
);
c_out
}
pub fn xnor(
&mut self,
c0: &M::R,
c1: &M::R,
server_key: &ServerKeyEvaluationDomain<M, DefaultSecureRng, NttOp>,
) -> M::R
where
M::R: Clone,
{
let mut c_out = self._subtract_double_and_shift_lwe_cts(c0, c1);
// PBS
pbs(
&self.pbs_info,
&self.xnor_test_vec,
&mut c_out,
server_key,
&mut self.scratch_memory.lwe_vector,
&mut self.scratch_memory.decomposition_matrix,
);
c_out
}
pub fn not(&mut self, c0: &M::R) -> M::R
where
<M as Matrix>::R: FromIterator<<M as Matrix>::MatElement>,
{
let modop = &self.pbs_info.rlwe_modop;
c0.as_ref().iter().map(|v| modop.neg(v)).collect()
}
} }
/// LMKCY+ Blind rotation /// LMKCY+ Blind rotation
@ -1956,6 +2160,7 @@ mod tests {
let mut m0 = false; let mut m0 = false;
let mut m1 = true; let mut m1 = true;
let mut ct0 = bool_evaluator.sk_encrypt(m0, &client_key); let mut ct0 = bool_evaluator.sk_encrypt(m0, &client_key);
let mut ct1 = bool_evaluator.sk_encrypt(m1, &client_key); let mut ct1 = bool_evaluator.sk_encrypt(m1, &client_key);
@ -2051,6 +2256,44 @@ mod tests {
} }
} }
#[test]
fn bool_xor() {
let mut bool_evaluator = BoolEvaluator::<
Vec<Vec<u64>>,
NttBackendU64,
ModularOpsU64<CiphertextModulus<u64>>,
ModularOpsU64<CiphertextModulus<u64>>,
>::new(SP_BOOL_PARAMS);
// println!("{:?}", bool_evaluator.nand_test_vec);
let client_key = bool_evaluator.client_key();
let seeded_server_key = bool_evaluator.server_key(&client_key);
let server_key_eval_domain =
ServerKeyEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from(
&seeded_server_key,
);
let mut m0 = false;
let mut m1 = true;
let mut ct0 = bool_evaluator.sk_encrypt(m0, &client_key);
let mut ct1 = bool_evaluator.sk_encrypt(m1, &client_key);
for _ in 0..1000 {
let ct_back = bool_evaluator.xor(&ct0, &ct1, &server_key_eval_domain);
let m_out = (m0 ^ m1);
let m_back = bool_evaluator.sk_decrypt(&ct_back, &client_key);
assert!(m_out == m_back, "Expected {m_out}, got {m_back}");
m1 = m0;
m0 = m_out;
ct1 = ct0;
ct0 = ct_back;
}
}
#[test] #[test]
fn multi_party_encryption_decryption() { fn multi_party_encryption_decryption() {
let bool_evaluator = BoolEvaluator::< let bool_evaluator = BoolEvaluator::<

+ 1
- 0
src/lib.rs

@ -14,6 +14,7 @@ mod ntt;
mod num; mod num;
mod random; mod random;
mod rgsw; mod rgsw;
mod shortint;
mod utils; mod utils;
pub trait Matrix: AsRef<[Self::R]> { pub trait Matrix: AsRef<[Self::R]> {

+ 14
- 0
src/shortint.rs

@ -0,0 +1,14 @@
use itertools::izip;
use crate::Matrix;
struct FheUint8<M: Matrix> {
data: M,
}
fn add<M: Matrix>(a: FheUint8<M>, b: FheUint8<M>) {
// CALL THE EVALUATOR
izip!(a.data.iter_rows(), b.data.iter_rows()).for_each(|(a_bit, b_bit)| {
// A ^ B
});
}

Loading…
Cancel
Save