mirror of
https://github.com/arnaucube/phantom-zone.git
synced 2026-01-09 15:41:30 +01:00
non-interactive example
This commit is contained in:
@@ -26,3 +26,12 @@ harness = false
|
|||||||
[[bench]]
|
[[bench]]
|
||||||
name = "modulus"
|
name = "modulus"
|
||||||
harness = false
|
harness = false
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "interactive_fheuint8"
|
||||||
|
path = "./examples/interactive_fheuint8.rs"
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "non_interactive_fheuint8"
|
||||||
|
path = "./examples/non_interactive_fheuint8.rs"
|
||||||
|
required-features = ["non_interactive_mp"]
|
||||||
@@ -6,13 +6,13 @@ fn plain_circuit(a: u8, b: u8, c: u8) -> u8 {
|
|||||||
(a + b) * c
|
(a + b) * c
|
||||||
}
|
}
|
||||||
|
|
||||||
// fn fhe_circuit(fhe_a: &FheUint8, fhe_b: &FheUint8, fhe_c: &FheUint8) ->
|
fn fhe_circuit(fhe_a: &FheUint8, fhe_b: &FheUint8, fhe_c: &FheUint8) -> FheUint8 {
|
||||||
// FheUint8 { &(fhe_a + fhe_b) * fhe_c
|
&(fhe_a + fhe_b) * fhe_c
|
||||||
// }
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
set_parameter_set(ParameterSelector::MultiPartyLessThanOrEqualTo16);
|
set_parameter_set(ParameterSelector::MultiPartyLessThanOrEqualTo16);
|
||||||
let no_of_parties = 2;
|
let no_of_parties = 8;
|
||||||
let client_keys = (0..no_of_parties)
|
let client_keys = (0..no_of_parties)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|_| gen_client_key())
|
.map(|_| gen_client_key())
|
||||||
@@ -50,25 +50,22 @@ fn main() {
|
|||||||
let fhe_b = public_key.encrypt(&b);
|
let fhe_b = public_key.encrypt(&b);
|
||||||
let fhe_c = public_key.encrypt(&c);
|
let fhe_c = public_key.encrypt(&c);
|
||||||
|
|
||||||
let fhe_batched = public_key.encrypt(vec![12, 3u8].as_slice());
|
|
||||||
|
|
||||||
// fhe evaluation
|
// fhe evaluation
|
||||||
// let now = std::time::Instant::now();
|
let now = std::time::Instant::now();
|
||||||
// let fhe_out = fhe_circuit(&fhe_a, &fhe_b, &fhe_c);
|
let fhe_out = fhe_circuit(&fhe_a, &fhe_b, &fhe_c);
|
||||||
// println!("Circuit time: {:?}", now.elapsed());
|
println!("Circuit time: {:?}", now.elapsed());
|
||||||
|
|
||||||
// // plain evaluation
|
// plain evaluation
|
||||||
// let out = plain_circuit(a, b, c);
|
let out = plain_circuit(a, b, c);
|
||||||
|
|
||||||
// // generate decryption shares to decrypt ciphertext fhe_out
|
// generate decryption shares to decrypt ciphertext fhe_out
|
||||||
// let decryption_shares = client_keys
|
let decryption_shares = client_keys
|
||||||
// .iter()
|
.iter()
|
||||||
// .map(|k| k.gen_decryption_share(&fhe_out))
|
.map(|k| k.gen_decryption_share(&fhe_out))
|
||||||
// .collect_vec();
|
.collect_vec();
|
||||||
|
|
||||||
// // decrypt fhe_out using decryption shares
|
// decrypt fhe_out using decryption shares
|
||||||
// let got_out = client_keys[0].aggregate_decryption_shares(&fhe_out,
|
let got_out = client_keys[0].aggregate_decryption_shares(&fhe_out, &decryption_shares);
|
||||||
// &decryption_shares);
|
|
||||||
|
|
||||||
// assert_eq!(got_out, out);
|
assert_eq!(got_out, out);
|
||||||
}
|
}
|
||||||
73
examples/non_interactive_fheuint8.rs
Normal file
73
examples/non_interactive_fheuint8.rs
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
use bin_rs::*;
|
||||||
|
use itertools::Itertools;
|
||||||
|
use rand::{thread_rng, Rng, RngCore};
|
||||||
|
|
||||||
|
fn circuit(a: u8, b: u8, c: u8, d: u8) -> u8 {
|
||||||
|
((a + b) * c) * d
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fhe_circuit(a: &FheUint8, b: &FheUint8, c: &FheUint8, d: &FheUint8) -> FheUint8 {
|
||||||
|
&(&(a + b) * c) * d
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
set_parameter_set(ParameterSelector::NonInteractiveMultiPartyLessThanOrEqualTo16);
|
||||||
|
|
||||||
|
// set CRS
|
||||||
|
let mut seed = [0u8; 32];
|
||||||
|
thread_rng().fill_bytes(&mut seed);
|
||||||
|
set_common_reference_seed(seed);
|
||||||
|
|
||||||
|
let no_of_parties = 2;
|
||||||
|
|
||||||
|
// Generate client keys
|
||||||
|
let cks = (0..no_of_parties).map(|_| gen_client_key()).collect_vec();
|
||||||
|
|
||||||
|
// client 0 encrypts private input
|
||||||
|
let c0_a = thread_rng().gen::<u8>();
|
||||||
|
let c0_b = thread_rng().gen::<u8>();
|
||||||
|
let c0_batched_to_send = cks[0].encrypt(vec![c0_a, c0_b].as_slice());
|
||||||
|
|
||||||
|
// client 1 encrypts private input
|
||||||
|
let c1_a = thread_rng().gen::<u8>();
|
||||||
|
let c1_b = thread_rng().gen::<u8>();
|
||||||
|
let c1_batch_to_send = cks[1].encrypt(vec![c1_a, c1_b].as_slice());
|
||||||
|
|
||||||
|
// Both client indenpendently generate their server key shares
|
||||||
|
let server_key_shares = cks
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(id, k)| gen_server_key_share(id, no_of_parties, k))
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
// Server side
|
||||||
|
|
||||||
|
// aggregates shares and generates server key
|
||||||
|
let server_key = aggregate_server_key_shares(&server_key_shares);
|
||||||
|
server_key.set_server_key();
|
||||||
|
|
||||||
|
// extract a and b from client0 inputs
|
||||||
|
let (ct_c0_a, ct_c0_b) = {
|
||||||
|
let ct = c0_batched_to_send.unseed::<Vec<Vec<u64>>>().key_switch(0);
|
||||||
|
(ct.extract(0), ct.extract(1))
|
||||||
|
};
|
||||||
|
|
||||||
|
// extract a and b from client1 inputs
|
||||||
|
let (ct_c1_a, ct_c1_b) = {
|
||||||
|
let ct = c1_batch_to_send.unseed::<Vec<Vec<u64>>>().key_switch(1);
|
||||||
|
(ct.extract(0), ct.extract(1))
|
||||||
|
};
|
||||||
|
|
||||||
|
let now = std::time::Instant::now();
|
||||||
|
let c_out = fhe_circuit(&ct_c0_a, &ct_c1_a, &ct_c0_b, &ct_c1_b);
|
||||||
|
println!("Circuit Time: {:?}", now.elapsed());
|
||||||
|
|
||||||
|
// decrypt c_out
|
||||||
|
let decryption_shares = cks
|
||||||
|
.iter()
|
||||||
|
.map(|k| k.gen_decryption_share(&c_out))
|
||||||
|
.collect_vec();
|
||||||
|
let m_out = cks[0].aggregate_decryption_shares(&c_out, &decryption_shares);
|
||||||
|
let m_expected = circuit(c0_a, c1_a, c0_b, c1_b);
|
||||||
|
assert!(m_expected == m_out);
|
||||||
|
}
|
||||||
@@ -1,16 +1,21 @@
|
|||||||
pub(crate) mod evaluator;
|
pub(crate) mod evaluator;
|
||||||
mod keys;
|
mod keys;
|
||||||
mod mp_api;
|
|
||||||
mod ni_mp_api;
|
|
||||||
mod noise;
|
mod noise;
|
||||||
pub(crate) mod parameters;
|
pub(crate) mod parameters;
|
||||||
|
|
||||||
pub(crate) use keys::PublicKey;
|
pub(crate) use keys::PublicKey;
|
||||||
|
|
||||||
|
#[cfg(feature = "interactive_mp")]
|
||||||
|
#[cfg(not(feature = "non_interactive_mp"))]
|
||||||
|
mod mp_api;
|
||||||
|
#[cfg(feature = "non_interactive_mp")]
|
||||||
|
mod ni_mp_api;
|
||||||
|
|
||||||
#[cfg(feature = "non_interactive_mp")]
|
#[cfg(feature = "non_interactive_mp")]
|
||||||
pub use ni_mp_api::*;
|
pub use ni_mp_api::*;
|
||||||
|
|
||||||
#[cfg(feature = "interactive_mp")]
|
#[cfg(feature = "interactive_mp")]
|
||||||
|
#[cfg(not(feature = "non_interactive_mp"))]
|
||||||
pub use mp_api::*;
|
pub use mp_api::*;
|
||||||
|
|
||||||
pub type ClientKey = keys::ClientKey<[u8; 32], u64>;
|
pub type ClientKey = keys::ClientKey<[u8; 32], u64>;
|
||||||
@@ -22,7 +27,11 @@ pub enum ParameterSelector {
|
|||||||
|
|
||||||
mod common_mp_enc_dec {
|
mod common_mp_enc_dec {
|
||||||
use super::BoolEvaluator;
|
use super::BoolEvaluator;
|
||||||
use crate::{utils::WithLocal, Matrix, MultiPartyDecryptor};
|
use crate::{
|
||||||
|
pbs::{sample_extract, PbsInfo},
|
||||||
|
utils::WithLocal,
|
||||||
|
Matrix, MultiPartyDecryptor, RowEntity, SampleExtractor,
|
||||||
|
};
|
||||||
|
|
||||||
type Mat = Vec<Vec<u64>>;
|
type Mat = Vec<Vec<u64>>;
|
||||||
|
|
||||||
@@ -41,4 +50,20 @@ mod common_mp_enc_dec {
|
|||||||
BoolEvaluator::with_local(|e| e.multi_party_decrypt(shares, c))
|
BoolEvaluator::with_local(|e| e.multi_party_decrypt(shares, c))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl SampleExtractor<<Mat as Matrix>::R> for Mat {
|
||||||
|
fn extract(&self, index: usize) -> <Mat as Matrix>::R {
|
||||||
|
// input is RLWE ciphertext
|
||||||
|
assert!(self.dimension().0 == 2);
|
||||||
|
|
||||||
|
let ring_size = self.dimension().1;
|
||||||
|
assert!(index < ring_size);
|
||||||
|
|
||||||
|
BoolEvaluator::with_local(|e| {
|
||||||
|
let mut lwe_out = <Mat as Matrix>::R::zeros(ring_size + 1);
|
||||||
|
sample_extract(&mut lwe_out, self, e.pbs_info().modop_rlweq(), index);
|
||||||
|
lwe_out
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ use crate::{
|
|||||||
|
|
||||||
use super::{evaluator::MultiPartyCrs, keys::*, parameters::*, ClientKey, ParameterSelector};
|
use super::{evaluator::MultiPartyCrs, keys::*, parameters::*, ClientKey, ParameterSelector};
|
||||||
|
|
||||||
pub type BoolEvaluator = super::evaluator::BoolEvaluator<
|
pub(crate) type BoolEvaluator = super::evaluator::BoolEvaluator<
|
||||||
Vec<Vec<u64>>,
|
Vec<Vec<u64>>,
|
||||||
NttBackendU64,
|
NttBackendU64,
|
||||||
ModularOpsU64<CiphertextModulus<u64>>,
|
ModularOpsU64<CiphertextModulus<u64>>,
|
||||||
@@ -304,8 +304,13 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
mod sp_api {
|
mod sp_api {
|
||||||
|
use num_traits::ToPrimitive;
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
backend::ModulusPowerOf2, utils::WithLocal, Decryptor, ModularOpsU64, NttBackendU64,
|
backend::ModulusPowerOf2, evaluator::BoolEncoding, pbs::PbsInfo,
|
||||||
|
rgsw::secret_key_encrypt_rlwe, utils::WithLocal, Decryptor, ModularOpsU64,
|
||||||
|
NttBackendU64, SampleExtractor,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -358,6 +363,86 @@ mod tests {
|
|||||||
BoolEvaluator::with_local(|e| e.sk_decrypt(c, self))
|
BoolEvaluator::with_local(|e| e.sk_decrypt(c, self))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
impl<K> Encryptor<[bool], (Vec<Vec<u64>>, [u8; 32])> for K
|
||||||
|
where
|
||||||
|
K: SinglePartyClientKey<Element = i32>,
|
||||||
|
{
|
||||||
|
fn encrypt(&self, m: &[bool]) -> (Vec<Vec<u64>>, [u8; 32]) {
|
||||||
|
BoolEvaluator::with_local(|e| {
|
||||||
|
DefaultSecureRng::with_local_mut(|rng| {
|
||||||
|
let parameters = e.parameters();
|
||||||
|
let ring_size = parameters.rlwe_n().0;
|
||||||
|
|
||||||
|
let rlwe_count = ((m.len() as f64 / ring_size as f64).ceil())
|
||||||
|
.to_usize()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut seed = <DefaultSecureRng as NewWithSeed>::Seed::default();
|
||||||
|
rng.fill_bytes(&mut seed);
|
||||||
|
let mut prng = DefaultSecureRng::new_seeded(seed);
|
||||||
|
|
||||||
|
let sk_u = self.sk_rlwe();
|
||||||
|
|
||||||
|
// encrypt `m` into ceil(len(m)/N) RLWE ciphertexts
|
||||||
|
let rlwes = (0..rlwe_count)
|
||||||
|
.map(|index| {
|
||||||
|
let mut message = vec![0; ring_size];
|
||||||
|
m[(index * ring_size)
|
||||||
|
..std::cmp::min(m.len(), (index + 1) * ring_size)]
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.for_each(|(i, v)| {
|
||||||
|
if *v {
|
||||||
|
message[i] = parameters.rlwe_q().true_el()
|
||||||
|
} else {
|
||||||
|
message[i] = parameters.rlwe_q().false_el()
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// encrypt message
|
||||||
|
let mut rlwe_out = vec![0u64; parameters.rlwe_n().0];
|
||||||
|
|
||||||
|
secret_key_encrypt_rlwe(
|
||||||
|
&message,
|
||||||
|
&mut rlwe_out,
|
||||||
|
&sk_u,
|
||||||
|
e.pbs_info().modop_rlweq(),
|
||||||
|
e.pbs_info().nttop_rlweq(),
|
||||||
|
&mut prng,
|
||||||
|
rng,
|
||||||
|
);
|
||||||
|
|
||||||
|
rlwe_out
|
||||||
|
})
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
(rlwes, seed)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn batch_extract_works() {
|
||||||
|
set_single_party_parameter_sets(SP_TEST_BOOL_PARAMS);
|
||||||
|
|
||||||
|
let (ck, sk) = gen_keys();
|
||||||
|
sk.set_server_key();
|
||||||
|
|
||||||
|
let batch_size = (SP_TEST_BOOL_PARAMS.rlwe_n().0 * 3 + 123);
|
||||||
|
let m = (0..batch_size)
|
||||||
|
.map(|_| thread_rng().gen::<u8>())
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
let seeded_ct = ck.encrypt(m.as_slice());
|
||||||
|
let ct = seeded_ct.unseed::<Vec<Vec<u64>>>();
|
||||||
|
|
||||||
|
let m_back = (0..batch_size)
|
||||||
|
.map(|i| ck.decrypt(&ct.extract(i)))
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
assert_eq!(m, m_back);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg(feature = "interactive_mp")]
|
#[cfg(feature = "interactive_mp")]
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ use super::{
|
|||||||
ClientKey, ParameterSelector,
|
ClientKey, ParameterSelector,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub type BoolEvaluator = super::evaluator::BoolEvaluator<
|
pub(crate) type BoolEvaluator = super::evaluator::BoolEvaluator<
|
||||||
Vec<Vec<u64>>,
|
Vec<Vec<u64>>,
|
||||||
NttBackendU64,
|
NttBackendU64,
|
||||||
ModularOpsU64<CiphertextModulus<u64>>,
|
ModularOpsU64<CiphertextModulus<u64>>,
|
||||||
|
|||||||
@@ -23,8 +23,10 @@ pub use backend::{
|
|||||||
// aggregate_public_key_shares, aggregate_server_key_shares, gen_client_key,
|
// aggregate_public_key_shares, aggregate_server_key_shares, gen_client_key,
|
||||||
// gen_mp_keys_phase1, gen_mp_keys_phase2, set_mp_seed, set_parameter_set,
|
// gen_mp_keys_phase1, gen_mp_keys_phase2, set_mp_seed, set_parameter_set,
|
||||||
// ParameterSelector, };
|
// ParameterSelector, };
|
||||||
|
pub use bool::*;
|
||||||
pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer};
|
pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer};
|
||||||
pub use ntt::{Ntt, NttBackendU64, NttInit};
|
pub use ntt::{Ntt, NttBackendU64, NttInit};
|
||||||
|
pub use shortint::FheUint8;
|
||||||
|
|
||||||
pub trait Matrix: AsRef<[Self::R]> {
|
pub trait Matrix: AsRef<[Self::R]> {
|
||||||
type MatElement;
|
type MatElement;
|
||||||
@@ -184,6 +186,10 @@ pub trait KeySwitchWithId<C> {
|
|||||||
fn key_switch(&self, user_id: usize) -> C;
|
fn key_switch(&self, user_id: usize) -> C;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) trait Encoder<F, T> {
|
pub trait SampleExtractor<R> {
|
||||||
|
fn extract(&self, index: usize) -> R;
|
||||||
|
}
|
||||||
|
|
||||||
|
trait Encoder<F, T> {
|
||||||
fn encode(&self, v: F) -> T;
|
fn encode(&self, v: F) -> T;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,9 +5,13 @@ use crate::{
|
|||||||
random::{DefaultSecureRng, RandomFillUniformInModulus},
|
random::{DefaultSecureRng, RandomFillUniformInModulus},
|
||||||
utils::{TryConvertFrom1, WithLocal},
|
utils::{TryConvertFrom1, WithLocal},
|
||||||
Decryptor, Encryptor, KeySwitchWithId, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor,
|
Decryptor, Encryptor, KeySwitchWithId, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor,
|
||||||
RowMut,
|
RowMut, SampleExtractor,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Fhe UInt8 type
|
||||||
|
///
|
||||||
|
/// - Stores encryptions of bits in little endian (i.e least signficant bit
|
||||||
|
/// stored at 0th index and most signficant bit stores at 7th index)
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct FheUint8<C> {
|
pub struct FheUint8<C> {
|
||||||
pub(super) data: Vec<C>,
|
pub(super) data: Vec<C>,
|
||||||
@@ -27,6 +31,28 @@ pub struct BatchedFheUint8<C> {
|
|||||||
data: Vec<C>,
|
data: Vec<C>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<C, R> SampleExtractor<FheUint8<R>> for BatchedFheUint8<C>
|
||||||
|
where
|
||||||
|
C: SampleExtractor<R>,
|
||||||
|
{
|
||||||
|
fn extract(&self, index: usize) -> FheUint8<R> {
|
||||||
|
BoolEvaluator::with_local(|e| {
|
||||||
|
let ring_size = e.parameters().rlwe_n().0;
|
||||||
|
|
||||||
|
let start_index = index * 8;
|
||||||
|
let end_index = (index + 1) * 8;
|
||||||
|
let data = (start_index..end_index)
|
||||||
|
.map(|i| {
|
||||||
|
let rlwe_index = i / ring_size;
|
||||||
|
let coeff_index = i % ring_size;
|
||||||
|
self.data[rlwe_index].extract(coeff_index)
|
||||||
|
})
|
||||||
|
.collect_vec();
|
||||||
|
FheUint8 { data }
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<M: MatrixEntity + MatrixMut<MatElement = u64>> From<&SeededBatchedFheUint8<M::R, [u8; 32]>>
|
impl<M: MatrixEntity + MatrixMut<MatElement = u64>> From<&SeededBatchedFheUint8<M::R, [u8; 32]>>
|
||||||
for BatchedFheUint8<M>
|
for BatchedFheUint8<M>
|
||||||
where
|
where
|
||||||
@@ -85,7 +111,6 @@ where
|
|||||||
.flat_map(|v| (0..8).into_iter().map(|i| (((*v) >> i) & 1) == 1))
|
.flat_map(|v| (0..8).into_iter().map(|i| (((*v) >> i) & 1) == 1))
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
let (cts, seed) = K::encrypt(&self, &m);
|
let (cts, seed) = K::encrypt(&self, &m);
|
||||||
dbg!(cts.len());
|
|
||||||
SeededBatchedFheUint8 { data: cts, seed }
|
SeededBatchedFheUint8 { data: cts, seed }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user