diff --git a/examples/bomberman.rs b/examples/bomberman.rs index ad24526..f802af6 100644 --- a/examples/bomberman.rs +++ b/examples/bomberman.rs @@ -126,20 +126,20 @@ fn main() { let player_0_moves_enc = { let c = player_0_enc.unseed::>>().key_switch(0); (0..no_of_moves) - .map(|i| Coordinates::new(c.extract(2 * i), c.extract(i * 2 + 1))) + .map(|i| Coordinates::new(c.extract_at(2 * i), c.extract_at(i * 2 + 1))) .collect_vec() }; let player_1_bomb_enc = { let c = player_1_enc.unseed::>>().key_switch(1); - Coordinates::new(c.extract(0), c.extract(1)) + Coordinates::new(c.extract_at(0), c.extract_at(1)) }; let player_2_bomb_enc = { let c = player_2_enc.unseed::>>().key_switch(2); - Coordinates::new(c.extract(0), c.extract(1)) + Coordinates::new(c.extract_at(0), c.extract_at(1)) }; let player_3_bomb_enc = { let c = player_3_enc.unseed::>>().key_switch(3); - Coordinates::new(c.extract(0), c.extract(1)) + Coordinates::new(c.extract_at(0), c.extract_at(1)) }; // run the game diff --git a/examples/meeting_friends.rs b/examples/meeting_friends.rs index 5139d2c..8e9f31a 100644 --- a/examples/meeting_friends.rs +++ b/examples/meeting_friends.rs @@ -105,11 +105,14 @@ fn main() { // Server parses private inputs from user a and b let user_a_location_enc = { let c = user_a_enc.unseed::>>().key_switch(0); - Location::new(c.extract(0), c.extract(1)) + Location::new(c.extract_at(0), c.extract_at(1)) }; let (user_b_location_enc, user_b_threshold_enc) = { let c = user_b_enc.unseed::>>().key_switch(1); - (Location::new(c.extract(0), c.extract(1)), c.extract(2)) + ( + Location::new(c.extract_at(0), c.extract_at(1)), + c.extract_at(2), + ) }; // run the circuit diff --git a/examples/non_interactive_fheuint8.rs b/examples/non_interactive_fheuint8.rs index d34cee0..9edf512 100644 --- a/examples/non_interactive_fheuint8.rs +++ b/examples/non_interactive_fheuint8.rs @@ -50,7 +50,7 @@ fn main() { // let now = std::time::Instant::now(); let (ct_c0_a, ct_c0_b) = { let ct = c0_batched_to_send.unseed::>>().key_switch(0); - (ct.extract(0), ct.extract(1)) + (ct.extract_at(0), ct.extract_at(1)) }; // println!( // "Time to unseed, key switch, and extract 2 ciphertexts: {:?}", @@ -60,7 +60,7 @@ fn main() { // extract a and b from client1 inputs let (ct_c1_a, ct_c1_b) = { let ct = c1_batch_to_send.unseed::>>().key_switch(1); - (ct.extract(0), ct.extract(1)) + (ct.extract_at(0), ct.extract_at(1)) }; let now = std::time::Instant::now(); diff --git a/src/bool/mod.rs b/src/bool/mod.rs index bbc9cde..e465f7d 100644 --- a/src/bool/mod.rs +++ b/src/bool/mod.rs @@ -194,6 +194,8 @@ mod impl_bool_frontend { } mod common_mp_enc_dec { + use itertools::Itertools; + use super::BoolEvaluator; use crate::{ pbs::{sample_extract, PbsInfo}, @@ -225,7 +227,7 @@ mod common_mp_enc_dec { impl SampleExtractor<::R> for Mat { /// Sample extract coefficient at `index` as a LWE ciphertext from RLWE /// ciphertext `Self` - fn extract(&self, index: usize) -> ::R { + fn extract_at(&self, index: usize) -> ::R { // input is RLWE ciphertext assert!(self.dimension().0 == 2); @@ -238,6 +240,41 @@ mod common_mp_enc_dec { lwe_out }) } + + /// Extract first `how_many` coefficients of `Self` as LWE ciphertexts + fn extract_many(&self, how_many: usize) -> Vec<::R> { + assert!(self.dimension().0 == 2); + + let ring_size = self.dimension().1; + assert!(how_many <= ring_size); + + (0..how_many) + .map(|index| { + BoolEvaluator::with_local(|e| { + let mut lwe_out = ::R::zeros(ring_size + 1); + sample_extract(&mut lwe_out, self, e.pbs_info().modop_rlweq(), index); + lwe_out + }) + }) + .collect_vec() + } + + /// Extracts all coefficients of `Self` as LWE ciphertexts + fn extract_all(&self) -> Vec<::R> { + assert!(self.dimension().0 == 2); + + let ring_size = self.dimension().1; + + (0..ring_size) + .map(|index| { + BoolEvaluator::with_local(|e| { + let mut lwe_out = ::R::zeros(ring_size + 1); + sample_extract(&mut lwe_out, self, e.pbs_info().modop_rlweq(), index); + lwe_out + }) + }) + .collect_vec() + } } } diff --git a/src/lib.rs b/src/lib.rs index 0e91204..a60222a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -185,7 +185,12 @@ pub trait KeySwitchWithId { } pub trait SampleExtractor { - fn extract(&self, index: usize) -> R; + /// Extract ciphertext at `index` + fn extract_at(&self, index: usize) -> R; + /// Extract all ciphertexts + fn extract_all(&self) -> Vec; + /// Extract first `how_many` ciphertexts + fn extract_many(&self, how_many: usize) -> Vec; } trait Encoder { diff --git a/src/shortint/enc_dec.rs b/src/shortint/enc_dec.rs index 30f8b3e..c909b24 100644 --- a/src/shortint/enc_dec.rs +++ b/src/shortint/enc_dec.rs @@ -1,7 +1,7 @@ use itertools::Itertools; use crate::{ - bool::{BoolEvaluator, FheBool}, + bool::BoolEvaluator, random::{DefaultSecureRng, RandomFillUniformInModulus}, utils::WithLocal, Decryptor, Encryptor, KeySwitchWithId, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, @@ -31,7 +31,10 @@ impl FheUint8 { /// /// To extract Fhe Uint8 ciphertext at `index` call `self.extract(index)` pub struct BatchedFheUint8 { + /// Vector of RLWE ciphertexts `C` data: Vec, + /// Count of FheUint8s packed in vector of RLWE ciphertexts + count: usize, } impl SampleExtractor> for BatchedFheUint8 @@ -45,7 +48,8 @@ where /// ciphertexts, Fhe uint8 ciphertext at index `i` is stored in coefficients /// `i*8...(i+1)*8`. To extract Fhe uint8 at index `i`, sample extract bool /// ciphertext at indices `[i*8, ..., (i+1)*8)` - fn extract(&self, index: usize) -> FheUint8 { + fn extract_at(&self, index: usize) -> FheUint8 { + assert!(index < self.count); BoolEvaluator::with_local(|e| { let ring_size = e.parameters().rlwe_n().0; @@ -55,12 +59,27 @@ where .map(|i| { let rlwe_index = i / ring_size; let coeff_index = i % ring_size; - self.data[rlwe_index].extract(coeff_index) + self.data[rlwe_index].extract_at(coeff_index) }) .collect_vec(); FheUint8 { data } }) } + + /// Extracts all FheUint8s packed in vector of RLWE ciphertexts of `Self` + fn extract_all(&self) -> Vec> { + (0..self.count) + .map(|index| self.extract_at(index)) + .collect_vec() + } + + /// Extracts first `how_many` FheUint8s packed in vector of RLWE + /// ciphertexts of `Self` + fn extract_many(&self, how_many: usize) -> Vec> { + (0..how_many) + .map(|index| self.extract_at(index)) + .collect_vec() + } } impl> From<&SeededBatchedFheUint8> @@ -92,14 +111,24 @@ where rlwe }) .collect_vec(); - Self { data: rlwes } + Self { + data: rlwes, + count: value.count, + } }) } } pub struct SeededBatchedFheUint8 { + /// Vector of Seeded RLWE ciphertexts `C`. + /// + /// If RLWE(m) = [a, b] s.t. m + e = b - as, `a` can be seeded and seeded + /// RLWE ciphertext only contains `b` polynomial data: Vec, + /// Seed for the ciphertexts seed: S, + /// Count of FheUint8s packed in vector of RLWE ciphertexts + count: usize, } impl SeededBatchedFheUint8 { @@ -119,12 +148,16 @@ where /// Encrypt a slice of u8s of arbitray length as `SeededBatchedFheUint8` fn encrypt(&self, m: &[u8]) -> SeededBatchedFheUint8 { // convert vector of u8s to vector bools - let m = m + let bool_m = m .iter() .flat_map(|v| (0..8).into_iter().map(|i| (((*v) >> i) & 1) == 1)) .collect_vec(); - let (cts, seed) = K::encrypt(&self, &m); - SeededBatchedFheUint8 { data: cts, seed } + let (cts, seed) = K::encrypt(&self, &bool_m); + SeededBatchedFheUint8 { + data: cts, + seed, + count: m.len(), + } } } @@ -133,7 +166,7 @@ where K: Encryptor<[bool], Vec>, { fn encrypt(&self, m: &[u8]) -> BatchedFheUint8 { - let m = m + let bool_m = m .iter() .flat_map(|v| { (0..8) @@ -142,8 +175,11 @@ where .collect_vec() }) .collect_vec(); - let cts = K::encrypt(&self, &m); - BatchedFheUint8 { data: cts } + let cts = K::encrypt(&self, &bool_m); + BatchedFheUint8 { + data: cts, + count: m.len(), + } } } @@ -161,7 +197,10 @@ where .iter() .map(|c| c.key_switch(user_id)) .collect_vec(); - BatchedFheUint8 { data } + BatchedFheUint8 { + data, + count: self.count, + } } }