pbs works again

This commit is contained in:
Janmajaya Mall
2024-05-19 20:13:12 +05:30
parent 8ec7143d80
commit 892b29e775
7 changed files with 1326 additions and 217 deletions

View File

@@ -7,6 +7,8 @@ pub trait Modulus {
type Element; type Element;
/// Modulus value if it fits in Element /// Modulus value if it fits in Element
fn q(&self) -> Option<Self::Element>; fn q(&self) -> Option<Self::Element>;
/// Modulus value as f64 if it fits in f64
fn q_as_f64(&self) -> Option<f64>;
/// Is modulus native? /// Is modulus native?
fn is_native(&self) -> bool; fn is_native(&self) -> bool;
/// -1 in signed representaiton /// -1 in signed representaiton
@@ -17,11 +19,11 @@ pub trait Modulus {
/// Always assmed to be 0. /// Always assmed to be 0.
fn smallest_unsigned_value(&self) -> Self::Element; fn smallest_unsigned_value(&self) -> Self::Element;
/// Convert unsigned value in signed represetation to i64 /// Convert unsigned value in signed represetation to i64
fn to_i64(&self, v: &Self::Element) -> i64; fn map_element_to_i64(&self, v: &Self::Element) -> i64;
/// Convert f64 to signed represented in modulus /// Convert f64 to signed represented in modulus
fn from_f64(&self, v: f64) -> Self::Element; fn map_element_from_f64(&self, v: f64) -> Self::Element;
/// Convert i64 to signed represented in modulus /// Convert i64 to signed represented in modulus
fn from_i64(&self, v: i64) -> Self::Element; fn map_element_from_i64(&self, v: i64) -> Self::Element;
} }
impl Modulus for u64 { impl Modulus for u64 {
@@ -39,7 +41,7 @@ impl Modulus for u64 {
fn smallest_unsigned_value(&self) -> Self::Element { fn smallest_unsigned_value(&self) -> Self::Element {
0 0
} }
fn to_i64(&self, v: &Self::Element) -> i64 { fn map_element_to_i64(&self, v: &Self::Element) -> i64 {
assert!(v < self); assert!(v < self);
if *v > (self >> 1) { if *v > (self >> 1) {
@@ -48,7 +50,7 @@ impl Modulus for u64 {
ToPrimitive::to_i64(v).unwrap() ToPrimitive::to_i64(v).unwrap()
} }
} }
fn from_f64(&self, v: f64) -> Self::Element { fn map_element_from_f64(&self, v: f64) -> Self::Element {
//FIXME (Jay): Before I check whether v is smaller than 0 with `let is_neg = //FIXME (Jay): Before I check whether v is smaller than 0 with `let is_neg =
// o.is_sign_negative() && o != 0.0; I'm ocnfused why didn't I simply check < // o.is_sign_negative() && o != 0.0; I'm ocnfused why didn't I simply check <
// 0.0? // 0.0?
@@ -59,7 +61,7 @@ impl Modulus for u64 {
v.to_u64().unwrap() v.to_u64().unwrap()
} }
} }
fn from_i64(&self, v: i64) -> Self::Element { fn map_element_from_i64(&self, v: i64) -> Self::Element {
if v < 0 { if v < 0 {
self - v.to_u64().unwrap() self - v.to_u64().unwrap()
} else { } else {
@@ -69,6 +71,9 @@ impl Modulus for u64 {
fn q(&self) -> Option<Self::Element> { fn q(&self) -> Option<Self::Element> {
Some(*self) Some(*self)
} }
fn q_as_f64(&self) -> Option<f64> {
self.to_f64()
}
} }
pub trait ModInit { pub trait ModInit {

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
use num_traits::{ConstZero, PrimInt, Zero}; use num_traits::{ConstZero, FromPrimitive, PrimInt, ToPrimitive, Zero};
use crate::{backend::Modulus, decomposer::Decomposer}; use crate::{backend::Modulus, decomposer::Decomposer};
@@ -183,23 +183,52 @@ impl<T: ConstZero> CiphertextModulus<T> {
} }
} }
impl<T> Modulus for CiphertextModulus<T> impl<T> CiphertextModulus<T>
where where
T: PrimInt, T: PrimInt,
{
pub(crate) fn _bits() -> usize {
std::mem::size_of::<T>() as usize * 8
}
fn _native(&self) -> bool {
self.1
}
fn _half_q(&self) -> T {
if self._native() {
T::one() << (Self::_bits() - 1)
} else {
self.0 >> 1
}
}
fn _q(&self) -> Option<T> {
if self._native() {
None
} else {
Some(self.0)
}
}
}
impl<T> Modulus for CiphertextModulus<T>
where
T: PrimInt + FromPrimitive,
{ {
type Element = T; type Element = T;
fn is_native(&self) -> bool { fn is_native(&self) -> bool {
false self._native()
} }
fn largest_unsigned_value(&self) -> Self::Element { fn largest_unsigned_value(&self) -> Self::Element {
if self.1 { if self._native() {
T::max_value() T::max_value()
} else { } else {
self.0 - T::one() self.0 - T::one()
} }
} }
fn neg_one(&self) -> Self::Element { fn neg_one(&self) -> Self::Element {
if self.1 { if self._native() {
T::max_value() T::max_value()
} else { } else {
self.0 - T::one() self.0 - T::one()
@@ -211,20 +240,43 @@ where
T::zero() T::zero()
} }
fn to_i64(&self, v: &Self::Element) -> i64 { fn map_element_to_i64(&self, v: &Self::Element) -> i64 {
todo!() if *v > self._half_q() {
-((self.largest_unsigned_value() - *v) + T::one())
.to_i64()
.unwrap()
} else {
v.to_i64().unwrap()
}
} }
fn from_f64(&self, v: f64) -> Self::Element { fn map_element_from_f64(&self, v: f64) -> Self::Element {
todo!() let v = v.round();
if v < 0.0 {
self.largest_unsigned_value() - T::from_f64(v.abs()).unwrap() + T::one()
} else {
T::from_f64(v.abs()).unwrap()
}
} }
fn from_i64(&self, v: i64) -> Self::Element { fn map_element_from_i64(&self, v: i64) -> Self::Element {
todo!() if v < 0 {
self.largest_unsigned_value() - T::from_i64(v.abs()).unwrap() + T::one()
} else {
T::from_i64(v.abs()).unwrap()
}
} }
fn q(&self) -> Option<Self::Element> { fn q(&self) -> Option<Self::Element> {
todo!() self._q()
}
fn q_as_f64(&self) -> Option<f64> {
if self._native() {
Some(T::max_value().to_f64().unwrap() + 1.0)
} else {
self.0.to_f64()
}
} }
} }

View File

@@ -12,8 +12,8 @@ use crate::{
backend::{ArithmeticOps, GetModulus, Modulus, VectorOps}, backend::{ArithmeticOps, GetModulus, Modulus, VectorOps},
decomposer::Decomposer, decomposer::Decomposer,
random::{ random::{
DefaultSecureRng, NewWithSeed, RandomFillGaussianInModulus, RandomGaussianElementInModulus, DefaultSecureRng, NewWithSeed, RandomFillGaussianInModulus, RandomFillUniformInModulus,
RandomFillUniformInModulus, DEFAULT_RNG, RandomGaussianElementInModulus, DEFAULT_RNG,
}, },
utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom1, WithLocal}, utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom1, WithLocal},
Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret,
@@ -65,7 +65,11 @@ where
let mut p_rng = R::new_with_seed(value.seed.clone()); let mut p_rng = R::new_with_seed(value.seed.clone());
let mut data = M::zeros(value.data.as_ref().len(), value.to_lwe_n + 1); let mut data = M::zeros(value.data.as_ref().len(), value.to_lwe_n + 1);
izip!(value.data.as_ref().iter(), data.iter_rows_mut()).for_each(|(bi, lwe_i)| { izip!(value.data.as_ref().iter(), data.iter_rows_mut()).for_each(|(bi, lwe_i)| {
RandomFillUniformInModulus::random_fill(&mut p_rng, &value.modulus, &mut lwe_i.as_mut()[1..]); RandomFillUniformInModulus::random_fill(
&mut p_rng,
&value.modulus,
&mut lwe_i.as_mut()[1..],
);
lwe_i.as_mut()[0] = *bi; lwe_i.as_mut()[0] = *bi;
}); });
LweKeySwitchingKey { LweKeySwitchingKey {
@@ -189,7 +193,8 @@ pub fn lwe_ksk_keygen<
pub fn encrypt_lwe< pub fn encrypt_lwe<
Ro: Row + RowMut, Ro: Row + RowMut,
Op: ArithmeticOps<Element = Ro::Element> + GetModulus<Element = Ro::Element>, Op: ArithmeticOps<Element = Ro::Element> + GetModulus<Element = Ro::Element>,
R: RandomGaussianElementInModulus<Ro::Element, Op::M> + RandomFillUniformInModulus<[Ro::Element], Op::M>, R: RandomGaussianElementInModulus<Ro::Element, Op::M>
+ RandomFillUniformInModulus<[Ro::Element], Op::M>,
S, S,
>( >(
lwe_out: &mut Ro, lwe_out: &mut Ro,
@@ -273,7 +278,7 @@ where
let mut diff = operator.sub(&m, ideal_m); let mut diff = operator.sub(&m, ideal_m);
let q = operator.modulus(); let q = operator.modulus();
return q.to_i64(&diff).to_f64().unwrap().abs().log2(); return q.map_element_to_i64(&diff).to_f64().unwrap().abs().log2();
} }
#[cfg(test)] #[cfg(test)]

View File

@@ -118,7 +118,7 @@ where
container.iter_mut() container.iter_mut()
) )
.for_each(|(from, to)| { .for_each(|(from, to)| {
*to = modulus.from_f64(from); *to = modulus.map_element_from_f64(from);
}); });
} }
} }
@@ -152,13 +152,13 @@ where
T: PrimInt + SampleUniform, T: PrimInt + SampleUniform,
{ {
fn random(&mut self, modulus: &T) -> T { fn random(&mut self, modulus: &T) -> T {
Uniform::new_inclusive(T::zero(), modulus).sample(&mut self.rng) Uniform::new(T::zero(), modulus).sample(&mut self.rng)
} }
} }
impl<T, M: Modulus<Element = T>> RandomGaussianElementInModulus<T, M> for DefaultSecureRng { impl<T, M: Modulus<Element = T>> RandomGaussianElementInModulus<T, M> for DefaultSecureRng {
fn random(&mut self, modulus: &M) -> T { fn random(&mut self, modulus: &M) -> T {
modulus.from_f64( modulus.map_element_from_f64(
rand_distr::Normal::new(0.0, 3.19f64) rand_distr::Normal::new(0.0, 3.19f64)
.unwrap() .unwrap()
.sample(&mut self.rng), .sample(&mut self.rng),

View File

@@ -13,8 +13,8 @@ use crate::{
decomposer::{self, Decomposer, RlweDecomposer}, decomposer::{self, Decomposer, RlweDecomposer},
ntt::{self, Ntt, NttInit}, ntt::{self, Ntt, NttInit},
random::{ random::{
DefaultSecureRng, NewWithSeed, RandomElementInModulus, RandomFill, RandomFillGaussianInModulus, DefaultSecureRng, NewWithSeed, RandomElementInModulus, RandomFill,
RandomFillUniformInModulus, RandomFillGaussianInModulus, RandomFillUniformInModulus,
}, },
utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom1, WithLocal}, utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom1, WithLocal},
Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret,
@@ -1528,7 +1528,7 @@ where
let mut max_diff_bits = f64::MIN; let mut max_diff_bits = f64::MIN;
m_plus_e.as_ref().iter().for_each(|v| { m_plus_e.as_ref().iter().for_each(|v| {
let bits = (q.to_i64(v).to_f64().unwrap()).log2(); let bits = (q.map_element_to_i64(v).to_f64().unwrap()).log2();
if max_diff_bits < bits { if max_diff_bits < bits {
max_diff_bits = bits; max_diff_bits = bits;
@@ -1744,7 +1744,11 @@ pub(crate) mod tests {
// sample m0 // sample m0
let mut m0 = vec![0u64; ring_size as usize]; let mut m0 = vec![0u64; ring_size as usize];
RandomFillUniformInModulus::<[u64], u64>::random_fill(&mut rng, &(1u64 << logp), m0.as_mut_slice()); RandomFillUniformInModulus::<[u64], u64>::random_fill(
&mut rng,
&(1u64 << logp),
m0.as_mut_slice(),
);
let ntt_op = NttBackendU64::new(&q, ring_size as usize); let ntt_op = NttBackendU64::new(&q, ring_size as usize);
let mod_op = ModularOpsU64::new(q); let mod_op = ModularOpsU64::new(q);
@@ -1787,7 +1791,11 @@ pub(crate) mod tests {
let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize); let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize);
let mut m0 = vec![0u64; ring_size as usize]; let mut m0 = vec![0u64; ring_size as usize];
RandomFillUniformInModulus::<[u64], _>::random_fill(&mut rng, &(1u64 << logp), m0.as_mut_slice()); RandomFillUniformInModulus::<[u64], _>::random_fill(
&mut rng,
&(1u64 << logp),
m0.as_mut_slice(),
);
let mut m1 = vec![0u64; ring_size as usize]; let mut m1 = vec![0u64; ring_size as usize];
m1[thread_rng().gen_range(0..ring_size) as usize] = 1; m1[thread_rng().gen_range(0..ring_size) as usize] = 1;

View File

@@ -146,7 +146,10 @@ pub trait TryConvertFrom1<T: ?Sized, P> {
impl<P: Modulus<Element = u64>> TryConvertFrom1<[i64], P> for Vec<u64> { impl<P: Modulus<Element = u64>> TryConvertFrom1<[i64], P> for Vec<u64> {
fn try_convert_from(value: &[i64], parameters: &P) -> Self { fn try_convert_from(value: &[i64], parameters: &P) -> Self {
value.iter().map(|v| parameters.from_i64(*v)).collect_vec() value
.iter()
.map(|v| parameters.map_element_from_i64(*v))
.collect_vec()
} }
} }
@@ -154,14 +157,17 @@ impl<P: Modulus<Element = u64>> TryConvertFrom1<[i32], P> for Vec<u64> {
fn try_convert_from(value: &[i32], parameters: &P) -> Self { fn try_convert_from(value: &[i32], parameters: &P) -> Self {
value value
.iter() .iter()
.map(|v| parameters.from_i64(*v as i64)) .map(|v| parameters.map_element_from_i64(*v as i64))
.collect_vec() .collect_vec()
} }
} }
impl<P: Modulus> TryConvertFrom1<[P::Element], P> for Vec<i64> { impl<P: Modulus> TryConvertFrom1<[P::Element], P> for Vec<i64> {
fn try_convert_from(value: &[P::Element], parameters: &P) -> Self { fn try_convert_from(value: &[P::Element], parameters: &P) -> Self {
value.iter().map(|v| parameters.to_i64(v)).collect_vec() value
.iter()
.map(|v| parameters.map_element_to_i64(v))
.collect_vec()
} }
} }