Browse Source

tfhe: get rid of constant generics

rm-const-generics
arnaucube 2 weeks ago
parent
commit
bb3288f211
16 changed files with 730 additions and 583 deletions
  1. +0
    -1
      arith/src/lib.rs
  2. +2
    -5
      arith/src/ntt.rs
  3. +1
    -6
      arith/src/ring_n.rs
  4. +0
    -2
      arith/src/ring_nq.rs
  5. +0
    -1
      arith/src/ring_torus.rs
  6. +5
    -11
      arith/src/tuple_ring.rs
  7. +0
    -1
      arith/src/zq.rs
  8. +66
    -66
      bfv/src/lib.rs
  9. +25
    -25
      ckks/src/lib.rs
  10. +1
    -1
      gfhe/src/glev.rs
  11. +31
    -5
      gfhe/src/glwe.rs
  12. +74
    -58
      tfhe/src/tggsw.rs
  13. +176
    -130
      tfhe/src/tglwe.rs
  14. +61
    -54
      tfhe/src/tgsw.rs
  15. +63
    -41
      tfhe/src/tlev.rs
  16. +225
    -176
      tfhe/src/tlwe.rs

+ 0
- 1
arith/src/lib.rs

@ -19,7 +19,6 @@ pub mod tuple_ring;
pub mod ntt;
// expose objects
pub use complex::C;
pub use matrix::Matrix;
pub use torus::T64;

+ 2
- 5
arith/src/ntt.rs

@ -6,11 +6,7 @@
//! generics; but once using real-world parameters, the stack could not handle
//! it, so moved to use Vec instead of fixed-sized arrays, and adapted the NTT
//! implementation to that too.
use crate::{
ring::{Ring, RingParam},
ring_nq::Rq,
zq::Zq,
};
use crate::{ring::RingParam, ring_nq::Rq, zq::Zq};
use std::collections::HashMap;
@ -197,6 +193,7 @@ const fn const_inv_mod(q: u64, x: u64) -> u64 {
#[cfg(test)]
mod tests {
use super::*;
use crate::Ring;
use anyhow::Result;

+ 1
- 6
arith/src/ring_n.rs

@ -1,16 +1,11 @@
//! Polynomial ring Z[X]/(X^N+1)
//!
use anyhow::Result;
use itertools::zip_eq;
use rand::{distributions::Distribution, Rng};
use std::array;
use std::borrow::Borrow;
use std::fmt;
use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use crate::Ring;
use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign};
// TODO rename to not have name conflicts with the Ring trait (R: Ring)
// PolynomialRing element, where the PolynomialRing is R = Z[X]/(X^n +1)

+ 0
- 2
arith/src/ring_nq.rs

@ -4,8 +4,6 @@
use anyhow::{anyhow, Result};
use itertools::zip_eq;
use rand::{distributions::Distribution, Rng};
use std::array;
use std::borrow::Borrow;
use std::fmt;
use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign};

+ 0
- 1
arith/src/ring_torus.rs

@ -9,7 +9,6 @@
use itertools::zip_eq;
use rand::{distributions::Distribution, Rng};
use std::array;
use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign};

+ 5
- 11
arith/src/tuple_ring.rs

@ -1,15 +1,9 @@
//! This file implements the struct for an Tuple of Ring Rq elements and its
//! operations, which are performed element-wise.
use anyhow::Result;
use itertools::zip_eq;
use rand::{distributions::Distribution, Rng};
use rand_distr::{Normal, Uniform};
use std::iter::Sum;
use std::{
array,
ops::{Add, Mul, Neg, Sub},
};
use std::ops::{Add, Mul, Neg, Sub};
use crate::{Ring, RingParam};
@ -28,23 +22,23 @@ impl TR {
assert_eq!(r.len(), k);
Self { k, r }
}
pub fn zero(k: usize, r_params: &RingParam) -> Self {
pub fn zero(k: usize, r_param: &RingParam) -> Self {
Self {
k,
r: (0..k).into_iter().map(|_| R::zero(r_params)).collect(),
r: (0..k).into_iter().map(|_| R::zero(r_param)).collect(),
}
}
pub fn rand(
mut rng: impl Rng,
dist: impl Distribution<f64>,
k: usize,
r_params: &RingParam,
r_param: &RingParam,
) -> Self {
Self {
k,
r: (0..k)
.into_iter()
.map(|_| R::rand(&mut rng, &dist, r_params))
.map(|_| R::rand(&mut rng, &dist, r_param))
.collect(),
}
}

+ 0
- 1
arith/src/zq.rs

@ -1,5 +1,4 @@
use rand::{distributions::Distribution, Rng};
use std::borrow::Borrow;
use std::fmt;
use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub, SubAssign};

+ 66
- 66
bfv/src/lib.rs

@ -23,7 +23,7 @@ pub struct Param {
p: u64,
}
impl Param {
// returns the plaintext params
// returns the plaintext param
pub fn pt(&self) -> RingParam {
RingParam {
q: self.t,
@ -117,7 +117,7 @@ impl BFV {
// const DELTA: u64 = Q / T; // floor
/// generate a new key pair (privK, pubK)
pub fn new_key(mut rng: impl Rng, params: &Param) -> Result<(SecretKey, PublicKey)> {
pub fn new_key(mut rng: impl Rng, param: &Param) -> Result<(SecretKey, PublicKey)> {
// WIP: review probabilities
// let Xi_key = Uniform::new(-1_f64, 1_f64);
@ -126,37 +126,37 @@ impl BFV {
// secret key
// let mut s = Rq::rand_f64(&mut rng, Xi_key)?;
let mut s = Rq::rand_u64(&mut rng, Xi_key, &params.ring)?;
let mut s = Rq::rand_u64(&mut rng, Xi_key, &param.ring)?;
// since s is going to be multiplied by other Rq elements, already
// compute its NTT
s.compute_evals();
// pk = (-a * s + e, a)
let a = Rq::rand_u64(&mut rng, Uniform::new(0_u64, params.ring.q), &params.ring)?;
let e = Rq::rand_f64(&mut rng, Xi_err, &params.ring)?;
let a = Rq::rand_u64(&mut rng, Uniform::new(0_u64, param.ring.q), &param.ring)?;
let e = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
let pk: PublicKey = PublicKey(&(&(-a.clone()) * &s) + &e, a.clone()); // TODO rm clones
Ok((SecretKey(s), pk))
}
// note: m is modulus t
pub fn encrypt(mut rng: impl Rng, params: &Param, pk: &PublicKey, m: &Rq) -> Result<RLWE> {
// assert params & inputs
debug_assert_eq!(params.ring, pk.0.param);
debug_assert_eq!(params.t, m.param.q);
debug_assert_eq!(params.ring.n, m.param.n);
pub fn encrypt(mut rng: impl Rng, param: &Param, pk: &PublicKey, m: &Rq) -> Result<RLWE> {
// assert param & inputs
debug_assert_eq!(param.ring, pk.0.param);
debug_assert_eq!(param.t, m.param.q);
debug_assert_eq!(param.ring.n, m.param.n);
let Xi_key = Uniform::new(-1_f64, 1_f64);
// let Xi_key = Uniform::new(0_u64, 2_u64);
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
let u = Rq::rand_f64(&mut rng, Xi_key, &params.ring)?;
let u = Rq::rand_f64(&mut rng, Xi_key, &param.ring)?;
// let u = Rq::rand_u64(&mut rng, Xi_key)?;
let e_1 = Rq::rand_f64(&mut rng, Xi_err, &params.ring)?;
let e_2 = Rq::rand_f64(&mut rng, Xi_err, &params.ring)?;
let e_1 = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
let e_2 = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
// migrate m's coeffs to the bigger modulus Q (from T)
let m = m.remodule(params.ring.q);
let c0 = &pk.0 * &u + e_1 + m * (params.ring.q / params.t); // floor(q/t)=DELTA
let m = m.remodule(param.ring.q);
let c0 = &pk.0 * &u + e_1 + m * (param.ring.q / param.t); // floor(q/t)=DELTA
let c1 = &pk.1 * &u + e_2;
Ok(RLWE(c0, c1))
}
@ -280,7 +280,7 @@ mod tests {
#[test]
fn test_encrypt_decrypt() -> Result<()> {
let params = Param {
let param = Param {
ring: RingParam {
q: 2u64.pow(16) + 1, // q prime, and 2^q + 1 shape
n: 512,
@ -292,13 +292,13 @@ mod tests {
let mut rng = rand::thread_rng();
for _ in 0..100 {
let (sk, pk) = BFV::new_key(&mut rng, &params)?;
let (sk, pk) = BFV::new_key(&mut rng, &param)?;
let msg_dist = Uniform::new(0_u64, params.t);
let m = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?;
let msg_dist = Uniform::new(0_u64, param.t);
let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let c = BFV::encrypt(&mut rng, &params, &pk, &m)?;
let m_recovered = BFV::decrypt(&params, &sk, &c);
let c = BFV::encrypt(&mut rng, &param, &pk, &m)?;
let m_recovered = BFV::decrypt(&param, &sk, &c);
assert_eq!(m, m_recovered);
}
@ -308,7 +308,7 @@ mod tests {
#[test]
fn test_addition() -> Result<()> {
let params = Param {
let param = Param {
ring: RingParam {
q: 2u64.pow(16) + 1, // q prime, and 2^q + 1 shape
n: 128,
@ -320,18 +320,18 @@ mod tests {
let mut rng = rand::thread_rng();
for _ in 0..100 {
let (sk, pk) = BFV::new_key(&mut rng, &params)?;
let (sk, pk) = BFV::new_key(&mut rng, &param)?;
let msg_dist = Uniform::new(0_u64, params.t);
let m1 = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?;
let msg_dist = Uniform::new(0_u64, param.t);
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let c1 = BFV::encrypt(&mut rng, &params, &pk, &m1)?;
let c2 = BFV::encrypt(&mut rng, &params, &pk, &m2)?;
let c1 = BFV::encrypt(&mut rng, &param, &pk, &m1)?;
let c2 = BFV::encrypt(&mut rng, &param, &pk, &m2)?;
let c3 = c1 + c2;
let m3_recovered = BFV::decrypt(&params, &sk, &c3);
let m3_recovered = BFV::decrypt(&param, &sk, &c3);
assert_eq!(m1 + m2, m3_recovered);
}
@ -342,7 +342,7 @@ mod tests {
#[test]
fn test_constant_add_mul() -> Result<()> {
let q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape
let params = Param {
let param = Param {
ring: RingParam { q, n: 16 },
t: 8, // plaintext modulus
p: q * q,
@ -350,26 +350,26 @@ mod tests {
let mut rng = rand::thread_rng();
let (sk, pk) = BFV::new_key(&mut rng, &params)?;
let (sk, pk) = BFV::new_key(&mut rng, &param)?;
let msg_dist = Uniform::new(0_u64, params.t);
let m1 = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?;
let m2_const = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?;
let c1 = BFV::encrypt(&mut rng, &params, &pk, &m1)?;
let msg_dist = Uniform::new(0_u64, param.t);
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2_const = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let c1 = BFV::encrypt(&mut rng, &param, &pk, &m1)?;
let c3_add = &c1 + &m2_const;
let m3_add_recovered = BFV::decrypt(&params, &sk, &c3_add);
let m3_add_recovered = BFV::decrypt(&param, &sk, &c3_add);
assert_eq!(&m1 + &m2_const, m3_add_recovered);
// test multiplication of a ciphertext by a constant
let rlk = BFV::rlk_key(&mut rng, &params, &sk)?;
let rlk = BFV::rlk_key(&mut rng, &param, &sk)?;
let c3_mul = BFV::mul_const(&rlk, &c1, &m2_const);
let m3_mul_recovered = BFV::decrypt(&params, &sk, &c3_mul);
let m3_mul_recovered = BFV::decrypt(&param, &sk, &c3_mul);
assert_eq!(
(m1.to_r() * m2_const.to_r()).to_rq(params.t).coeffs(),
(m1.to_r() * m2_const.to_r()).to_rq(param.t).coeffs(),
m3_mul_recovered.coeffs()
);
@ -380,7 +380,7 @@ mod tests {
// TMP WIP
#[test]
#[ignore]
fn test_params() -> Result<()> {
fn test_param() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape
const N: usize = 32;
const T: u64 = 8; // plaintext modulus
@ -504,30 +504,30 @@ mod tests {
#[test]
fn test_tensor() -> Result<()> {
let q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape
let params = Param {
let param = Param {
ring: RingParam { q, n: 16 },
t: 2, // plaintext modulus
p: q * q,
};
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, params.t);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..1_000 {
let m1 = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?;
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
test_tensor_opt(&mut rng, &params, m1, m2)?;
test_tensor_opt(&mut rng, &param, m1, m2)?;
}
Ok(())
}
fn test_tensor_opt(mut rng: impl Rng, params: &Param, m1: Rq, m2: Rq) -> Result<()> {
let (sk, pk) = BFV::new_key(&mut rng, &params)?;
fn test_tensor_opt(mut rng: impl Rng, param: &Param, m1: Rq, m2: Rq) -> Result<()> {
let (sk, pk) = BFV::new_key(&mut rng, &param)?;
let c1 = BFV::encrypt(&mut rng, &params, &pk, &m1)?;
let c2 = BFV::encrypt(&mut rng, &params, &pk, &m2)?;
let c1 = BFV::encrypt(&mut rng, &param, &pk, &m1)?;
let c2 = BFV::encrypt(&mut rng, &param, &pk, &m2)?;
let (c_a, c_b, c_c) = RLWE::tensor(params.t, &c1, &c2);
let (c_a, c_b, c_c) = RLWE::tensor(param.t, &c1, &c2);
// let (c_a, c_b, c_c) = RLWE::tensor_new::<PQ, T>(&c1, &c2);
// decrypt non-relinearized mul result
@ -539,10 +539,10 @@ mod tests {
// &c_c.to_r(),
// &R::<N>::from_vec(arith::ring_n::naive_mul(&sk.0.to_r(), &sk.0.to_r())),
// ));
let m3: Rq = m3.mul_div_round(params.t, params.ring.q); // descale
let m3 = m3.remodule(params.t);
let m3: Rq = m3.mul_div_round(param.t, param.ring.q); // descale
let m3 = m3.remodule(param.t);
let naive = (m1.clone().to_r() * m2.clone().to_r()).to_rq(params.t); // TODO rm clones
let naive = (m1.clone().to_r() * m2.clone().to_r()).to_rq(param.t); // TODO rm clones
assert_eq!(
m3.coeffs().to_vec(),
naive.coeffs().to_vec(),
@ -557,38 +557,38 @@ mod tests {
#[test]
fn test_mul_relin() -> Result<()> {
let q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape
let params = Param {
let param = Param {
ring: RingParam { q, n: 16 },
t: 2, // plaintext modulus
p: q * q,
};
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, params.t);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..1_000 {
let m1 = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?;
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
test_mul_relin_opt(&mut rng, &params, m1, m2)?;
test_mul_relin_opt(&mut rng, &param, m1, m2)?;
}
Ok(())
}
fn test_mul_relin_opt(mut rng: impl Rng, params: &Param, m1: Rq, m2: Rq) -> Result<()> {
let (sk, pk) = BFV::new_key(&mut rng, &params)?;
fn test_mul_relin_opt(mut rng: impl Rng, param: &Param, m1: Rq, m2: Rq) -> Result<()> {
let (sk, pk) = BFV::new_key(&mut rng, &param)?;
let rlk = BFV::rlk_key(&mut rng, &params, &sk)?;
let rlk = BFV::rlk_key(&mut rng, &param, &sk)?;
let c1 = BFV::encrypt(&mut rng, &params, &pk, &m1)?;
let c2 = BFV::encrypt(&mut rng, &params, &pk, &m2)?;
let c1 = BFV::encrypt(&mut rng, &param, &pk, &m1)?;
let c2 = BFV::encrypt(&mut rng, &param, &pk, &m2)?;
let c3 = RLWE::mul(params.t, &rlk, &c1, &c2); // uses relinearize internally
let c3 = RLWE::mul(param.t, &rlk, &c1, &c2); // uses relinearize internally
let m3 = BFV::decrypt(&params, &sk, &c3);
let m3 = BFV::decrypt(&param, &sk, &c3);
let naive = (m1.clone().to_r() * m2.clone().to_r()).to_rq(params.t); // TODO rm clones
let naive = (m1.clone().to_r() * m2.clone().to_r()).to_rq(param.t); // TODO rm clones
assert_eq!(
m3.coeffs().to_vec(),
naive.coeffs().to_vec(),

+ 25
- 25
ckks/src/lib.rs

@ -19,7 +19,7 @@ pub use encoder::Encoder;
const ERR_SIGMA: f64 = 3.2;
#[derive(Clone, Copy, Debug)]
pub struct Params {
pub struct Param {
ring: RingParam,
t: u64,
}
@ -30,33 +30,33 @@ pub struct PublicKey(Rq, Rq);
pub struct SecretKey(Rq);
pub struct CKKS {
params: Params,
param: Param,
encoder: Encoder,
}
impl CKKS {
pub fn new(params: &Params, delta: C<f64>) -> Self {
let encoder = Encoder::new(params.ring.n, delta);
pub fn new(param: &Param, delta: C<f64>) -> Self {
let encoder = Encoder::new(param.ring.n, delta);
Self {
params: params.clone(),
param: param.clone(),
encoder,
}
}
/// generate a new key pair (privK, pubK)
pub fn new_key(&self, mut rng: impl Rng) -> Result<(SecretKey, PublicKey)> {
let params = &self.params;
let param = &self.param;
let Xi_key = Uniform::new(-1_f64, 1_f64);
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
let e = Rq::rand_f64(&mut rng, Xi_err, &params.ring)?;
let e = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
let mut s = Rq::rand_f64(&mut rng, Xi_key, &params.ring)?;
let mut s = Rq::rand_f64(&mut rng, Xi_key, &param.ring)?;
// since s is going to be multiplied by other Rq elements, already
// compute its NTT
s.compute_evals();
let a = Rq::rand_f64(&mut rng, Xi_key, &params.ring)?;
let a = Rq::rand_f64(&mut rng, Xi_key, &param.ring)?;
let pk: PublicKey = PublicKey((&(-a.clone()) * &s) + e, a.clone()); // TODO rm clones
Ok((SecretKey(s), pk))
@ -69,17 +69,17 @@ impl CKKS {
pk: &PublicKey,
m: &R,
) -> Result<(Rq, Rq)> {
let params = self.params;
let param = self.param;
let Xi_key = Uniform::new(-1_f64, 1_f64);
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
let e_0 = Rq::rand_f64(&mut rng, Xi_err, &params.ring)?;
let e_1 = Rq::rand_f64(&mut rng, Xi_err, &params.ring)?;
let e_0 = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
let e_1 = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
let v = Rq::rand_f64(&mut rng, Xi_key, &params.ring)?;
let v = Rq::rand_f64(&mut rng, Xi_key, &param.ring)?;
// let m: Rq = Rq::from(*m);
let m: Rq = m.clone().to_rq(params.ring.q); // TODO rm clone
let m: Rq = m.clone().to_rq(param.ring.q); // TODO rm clone
Ok((m + e_0 + &v * &pk.0.clone(), &v * &pk.1 + e_1))
}
@ -127,7 +127,7 @@ mod tests {
let q: u64 = 2u64.pow(16) + 1;
let n: usize = 32;
let t: u64 = 50;
let params = Params {
let param = Param {
ring: RingParam { q, n },
t,
};
@ -137,12 +137,12 @@ mod tests {
let mut rng = rand::thread_rng();
for _ in 0..1000 {
let ckks = CKKS::new(&params, scale_factor);
let ckks = CKKS::new(&param, scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?;
let m_raw: R =
Rq::rand_f64(&mut rng, Uniform::new(0_f64, t as f64), &params.ring)?.to_r();
Rq::rand_f64(&mut rng, Uniform::new(0_f64, t as f64), &param.ring)?.to_r();
let m = &m_raw * &scale_factor_u64;
let ct = ckks.encrypt(&mut rng, &pk, &m)?;
@ -153,7 +153,7 @@ mod tests {
.iter()
.map(|e| (*e as f64 / (scale_factor_u64 as f64)).round() as u64)
.collect();
let m_decrypted = Rq::from_vec_u64(&params.ring, m_decrypted);
let m_decrypted = Rq::from_vec_u64(&param.ring, m_decrypted);
// assert_eq!(m_decrypted, Rq::from(m_raw));
assert_eq!(m_decrypted, m_raw.to_rq(q));
}
@ -166,7 +166,7 @@ mod tests {
let q: u64 = 2u64.pow(16) + 1;
let n: usize = 16;
let t: u64 = 8;
let params = Params {
let param = Param {
ring: RingParam { q, n },
t,
};
@ -175,7 +175,7 @@ mod tests {
let mut rng = rand::thread_rng();
for _ in 0..1000 {
let ckks = CKKS::new(&params, scale_factor);
let ckks = CKKS::new(&param, scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?;
let z: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, t))
@ -215,7 +215,7 @@ mod tests {
let q: u64 = 2u64.pow(16) + 1;
let n: usize = 16;
let t: u64 = 8;
let params = Params {
let param = Param {
ring: RingParam { q, n },
t,
};
@ -224,7 +224,7 @@ mod tests {
let mut rng = rand::thread_rng();
for _ in 0..1000 {
let ckks = CKKS::new(&params, scale_factor);
let ckks = CKKS::new(&param, scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?;
@ -261,8 +261,8 @@ mod tests {
fn test_sub() -> Result<()> {
let q: u64 = 2u64.pow(16) + 1;
let n: usize = 16;
let t: u64 = 4;
let params = Params {
let t: u64 = 2;
let param = Param {
ring: RingParam { q, n },
t,
};
@ -271,7 +271,7 @@ mod tests {
let mut rng = rand::thread_rng();
for _ in 0..1000 {
let ckks = CKKS::new(&params, scale_factor);
let ckks = CKKS::new(&param, scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?;

+ 1
- 1
gfhe/src/glev.rs

@ -69,7 +69,7 @@ impl GLev {
impl<R: Ring> Mul<Vec<R>> for GLev<R> {
type Output = GLWE<R>;
fn mul(self, v: Vec<R>) -> GLWE<R> {
// TODO debug_assert_eq of params
// TODO debug_assert_eq of param
// l times GLWES
let glwes: Vec<GLWE<R>> = self.0;

+ 31
- 5
gfhe/src/glwe.rs

@ -22,13 +22,30 @@ pub struct Param {
pub t: u64,
}
impl Param {
// returns the plaintext params
/// returns the plaintext param
pub fn pt(&self) -> RingParam {
// TODO think if maybe return a new truct "PtParam" to differenciate
// between the ciphertexxt (RingParam) and the plaintext param. Maybe it
// can be just a wrapper on top of RingParam.
RingParam {
q: self.t,
n: self.ring.n,
}
}
/// returns the LWE param for the given GLWE (self), that is, it uses k=K*N
/// as the length for the secret key. This follows [2018-421] where
/// TLWE sk: s \in B^n , where n=K*N
/// TRLWE sk: s \in B_N[X]^K
pub fn lwe(&self) -> Self {
Self {
ring: RingParam {
q: self.ring.q,
n: 1,
},
k: self.k * self.ring.n,
t: self.t,
}
}
}
/// GLWE implemented over the `Ring` trait, so that it can be also instantiated
@ -46,8 +63,8 @@ pub struct PublicKey(pub R, pub TR);
pub struct KSK<R: Ring>(Vec<GLev<R>>);
impl<R: Ring> GLWE<R> {
pub fn zero(k: usize, params: &RingParam) -> Self {
Self(TR::zero(k, &params), R::zero(&params))
pub fn zero(k: usize, param: &RingParam) -> Self {
Self(TR::zero(k, &param), R::zero(&param))
}
pub fn from_plaintext(k: usize, param: &RingParam, p: R) -> Self {
Self(TR::zero(k, &param), p)
@ -187,6 +204,9 @@ impl GLWE {
impl<R: Ring> Add<GLWE<R>> for GLWE<R> {
type Output = Self;
fn add(self, other: Self) -> Self {
debug_assert_eq!(self.0.k, other.0.k);
debug_assert_eq!(self.1.param(), other.1.param());
let a: TR<R> = self.0 + other.0;
let b: R = self.1 + other.1;
Self(a, b)
@ -196,6 +216,8 @@ impl Add> for GLWE {
impl<R: Ring> Add<R> for GLWE<R> {
type Output = Self;
fn add(self, plaintext: R) -> Self {
debug_assert_eq!(self.1.param(), plaintext.param());
let a: TR<R> = self.0;
let b: R = self.1 + plaintext;
Self(a, b)
@ -231,6 +253,9 @@ impl Sum> for GLWE {
impl<R: Ring> Sub<GLWE<R>> for GLWE<R> {
type Output = Self;
fn sub(self, other: Self) -> Self {
debug_assert_eq!(self.0.k, other.0.k);
debug_assert_eq!(self.1.param(), other.1.param());
let a: TR<R> = self.0 - other.0;
let b: R = self.1 - other.1;
Self(a, b)
@ -240,6 +265,8 @@ impl Sub> for GLWE {
impl<R: Ring> Mul<R> for GLWE<R> {
type Output = Self;
fn mul(self, plaintext: R) -> Self {
debug_assert_eq!(self.1.param(), plaintext.param());
let a: TR<R> = TR {
k: self.0.k,
r: self
@ -351,8 +378,7 @@ mod tests {
}
}
pub fn t_decode(param: &Param, pt: &Tn) -> Rq {
let p = param.t;
let pt = pt.mul_div_round(p, u64::MAX);
let pt = pt.mul_div_round(param.t, u64::MAX);
Rq::from_vec_u64(&param.pt(), pt.coeffs().iter().map(|c| c.0).collect())
}
#[test]

+ 74
- 58
tfhe/src/tggsw.rs

@ -4,53 +4,57 @@ use rand::Rng;
use std::array;
use std::ops::{Add, Mul};
use arith::{Ring, Rq, Tn, T64, TR};
use arith::{Ring, RingParam, Rq, Tn, T64, TR};
use crate::tglwe::{PublicKey, SecretKey, TGLWE};
use gfhe::glwe::GLWE;
use gfhe::glwe::{Param, GLWE};
/// vector of length K+1 = ([K * TGLev], [1 * TGLev])
#[derive(Clone, Debug)]
pub struct TGGSW<const N: usize, const K: usize>(pub(crate) Vec<TGLev<N, K>>, TGLev<N, K>);
pub struct TGGSW(pub(crate) Vec<TGLev>, TGLev);
impl<const N: usize, const K: usize> TGGSW<N, K> {
impl TGGSW {
pub fn encrypt_s(
mut rng: impl Rng,
param: &Param,
beta: u32,
l: u32,
sk: &SecretKey<N, K>,
m: &Tn<N>,
sk: &SecretKey,
m: &Tn,
) -> Result<Self> {
let a: Vec<TGLev<N, K>> = (0..K)
.map(|i| TGLev::encrypt_s(&mut rng, beta, l, sk, &(-sk.0 .0 .0[i] * *m)))
debug_assert_eq!(sk.0 .0.k, param.k);
let a: Vec<TGLev> = (0..param.k)
.map(|i| TGLev::encrypt_s(&mut rng, param, beta, l, sk, &(&-sk.0 .0.r[i].clone() * m)))
// TODO rm clone
.collect::<Result<Vec<_>>>()?;
let b: TGLev<N, K> = TGLev::encrypt_s(&mut rng, beta, l, sk, m)?;
let b: TGLev = TGLev::encrypt_s(&mut rng, &param, beta, l, sk, m)?;
Ok(Self(a, b))
}
pub fn decrypt(&self, sk: &SecretKey<N, K>, beta: u32) -> Tn<N> {
pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> Tn {
self.1.decrypt(sk, beta)
}
pub fn cmux(bit: Self, ct1: TGLWE<N, K>, ct2: TGLWE<N, K>) -> TGLWE<N, K> {
pub fn cmux(bit: Self, ct1: TGLWE, ct2: TGLWE) -> TGLWE {
ct1.clone() + (bit * (ct2 - ct1))
}
}
/// External product TGGSW x TGLWE
impl<const N: usize, const K: usize> Mul<TGLWE<N, K>> for TGGSW<N, K> {
type Output = TGLWE<N, K>;
/// External product tggsw x tglwe
impl Mul<TGLWE> for TGGSW {
type Output = TGLWE;
fn mul(self, tglwe: TGLWE<N, K>) -> TGLWE<N, K> {
fn mul(self, tglwe: TGLWE) -> TGLWE {
let beta: u32 = 2;
let l: u32 = 64; // TODO wip
let tglwe_ab: Vec<Tn<N>> = [tglwe.0 .0 .0.clone(), vec![tglwe.0 .1]].concat();
let tglwe_ab: Vec<Tn> = [tglwe.0 .0.r.clone(), vec![tglwe.0 .1]].concat();
let tgsw_ab: Vec<TGLev<N, K>> = [self.0.clone(), vec![self.1]].concat();
let tgsw_ab: Vec<TGLev> = [self.0.clone(), vec![self.1]].concat();
assert_eq!(tgsw_ab.len(), tglwe_ab.len());
let r: TGLWE<N, K> = zip_eq(tgsw_ab, tglwe_ab)
let r: TGLWE = zip_eq(tgsw_ab, tglwe_ab)
.map(|(tlev_i, tglwe_i)| tlev_i * tglwe_i.decompose(beta, l))
.sum();
r
@ -58,26 +62,36 @@ impl Mul> for TGGSW {
}
#[derive(Clone, Debug)]
pub struct TGLev<const N: usize, const K: usize>(pub(crate) Vec<TGLWE<N, K>>);
pub struct TGLev(pub(crate) Vec<TGLWE>);
impl TGLev {
pub fn encode(param: &Param, m: &Rq) -> Tn {
debug_assert_eq!(param.t, m.param.q); // plaintext modulus
impl<const N: usize, const K: usize> TGLev<N, K> {
pub fn encode<const T: u64>(m: &Rq<T, N>) -> Tn<N> {
let coeffs = m.coeffs();
Tn(array::from_fn(|i| T64(coeffs[i].0)))
Tn {
param: param.ring,
coeffs: m.coeffs().iter().map(|c_i| T64(c_i.v)).collect(),
}
}
pub fn decode<const T: u64>(p: &Tn<N>) -> Rq<T, N> {
Rq::<T, N>::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect())
pub fn decode(param: &Param, p: &Tn) -> Rq {
Rq::from_vec_u64(&param.pt(), p.coeffs().iter().map(|c| c.0).collect())
}
pub fn encrypt(
mut rng: impl Rng,
param: &Param,
beta: u32,
l: u32,
pk: &PublicKey<N, K>,
m: &Tn<N>,
pk: &PublicKey,
m: &Tn,
) -> Result<Self> {
let tlev: Vec<TGLWE<N, K>> = (1..l + 1)
let tlev: Vec<TGLWE> = (1..l + 1)
.map(|i| {
TGLWE::<N, K>::encrypt(&mut rng, pk, &(*m * (u64::MAX / beta.pow(i as u32) as u64)))
TGLWE::encrypt(
&mut rng,
&param,
pk,
&(m * &(u64::MAX / beta.pow(i as u32) as u64)),
)
})
.collect::<Result<Vec<_>>>()?;
@ -85,35 +99,36 @@ impl TGLev {
}
pub fn encrypt_s(
mut rng: impl Rng,
param: &Param,
_beta: u32, // TODO rm, and make beta=2 always
l: u32,
sk: &SecretKey<N, K>,
m: &Tn<N>,
sk: &SecretKey,
m: &Tn,
) -> Result<Self> {
let tlev: Vec<TGLWE<N, K>> = (1..l as u64 + 1)
let tlev: Vec<TGLWE> = (1..l as u64 + 1)
.map(|i| {
let aux = if i < 64 {
*m * (u64::MAX / (1u64 << i))
m * &(u64::MAX / (1u64 << i))
} else {
// 1<<64 would overflow, and anyways we're dividing u64::MAX
// by it, which would be equal to 1
*m
m.clone() // TODO rm clone
};
TGLWE::<N, K>::encrypt_s(&mut rng, sk, &aux)
TGLWE::encrypt_s(&mut rng, &param, sk, &aux)
})
.collect::<Result<Vec<_>>>()?;
Ok(Self(tlev))
}
pub fn decrypt(&self, sk: &SecretKey<N, K>, beta: u32) -> Tn<N> {
pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> Tn {
let pt = self.0[0].decrypt(sk);
pt.mul_div_round(beta as u64, u64::MAX)
}
}
impl<const N: usize, const K: usize> TGLev<N, K> {
pub fn iter(&self) -> std::slice::Iter<TGLWE<N, K>> {
impl TGLev {
pub fn iter(&self) -> std::slice::Iter<TGLWE> {
self.0.iter()
}
}
@ -121,14 +136,14 @@ impl TGLev {
// dot product between a TGLev and Vec<Tn<N>>, usually Vec<Tn<N>> comes from a
// decomposition of Tn<N>
// TGLev * Vec<Tn<N>> --> TGLWE
impl<const N: usize, const K: usize> Mul<Vec<Tn<N>>> for TGLev<N, K> {
type Output = TGLWE<N, K>;
fn mul(self, v: Vec<Tn<N>>) -> Self::Output {
impl Mul<Vec<Tn>> for TGLev {
type Output = TGLWE;
fn mul(self, v: Vec<Tn>) -> Self::Output {
assert_eq!(self.0.len(), v.len());
// l TGLWES
let tlwes: Vec<TGLWE<N, K>> = self.0;
let r: TGLWE<N, K> = zip_eq(v, tlwes).map(|(a_d_i, glwe_i)| glwe_i * a_d_i).sum();
let tlwes: Vec<TGLWE> = self.0;
let r: TGLWE = zip_eq(v, tlwes).map(|(a_d_i, glwe_i)| glwe_i * a_d_i).sum();
r
}
}
@ -141,38 +156,39 @@ mod tests {
use super::*;
#[test]
fn test_external_product() -> Result<()> {
const T: u64 = 16; // plaintext modulus
const K: usize = 4;
const N: usize = 64;
const KN: usize = K * N;
let param = Param {
ring: RingParam { q: u64::MAX, n: 64 },
k: 4,
t: 16, // plaintext modulus
};
let beta: u32 = 2;
let l: u32 = 64;
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..50 {
let (sk, _) = TGLWE::<N, K>::new_key::<KN>(&mut rng)?;
let (sk, _) = TGLWE::new_key(&mut rng, &param)?;
let m1: Rq<T, N> = Rq::rand_u64(&mut rng, msg_dist)?;
let p1: Tn<N> = TGLev::<N, K>::encode::<T>(&m1);
let m1: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: Tn = TGLev::encode(&param, &m1);
let m2: Rq<T, N> = Rq::rand_u64(&mut rng, msg_dist)?;
let p2: Tn<N> = TGLWE::<N, K>::encode::<T>(&m2); // scaled by delta
let m2: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p2: Tn = TGLWE::encode(&param, &m2); // scaled by delta
let tgsw = TGGSW::<N, K>::encrypt_s(&mut rng, beta, l, &sk, &p1)?;
let tlwe = TGLWE::<N, K>::encrypt_s(&mut rng, &sk, &p2)?;
let tgsw = TGGSW::encrypt_s(&mut rng, &param, beta, l, &sk, &p1)?;
let tlwe = TGLWE::encrypt_s(&mut rng, &param, &sk, &p2)?;
let res: TGLWE<N, K> = tgsw * tlwe;
let res: TGLWE = tgsw * tlwe;
// let p_recovered = res.decrypt(&sk, beta);
let p_recovered = res.decrypt(&sk);
// downscaled by delta^-1
let res_recovered = TGLWE::<N, K>::decode::<T>(&p_recovered);
let res_recovered = TGLWE::decode(&param, &p_recovered);
// assert_eq!(m1 * m2, m_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq::<T>(), res_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), res_recovered);
}
Ok(())

+ 176
- 130
tfhe/src/tglwe.rs

@ -7,155 +7,193 @@ use std::array;
use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, Sub};
use arith::{Ring, Rq, Tn, T64, TR};
use gfhe::{glwe, GLWE};
use arith::{Ring, RingParam, Rq, Tn, T64, TR};
use gfhe::{glwe, glwe::Param, GLWE};
use crate::tlev::TLev;
use crate::{tlwe, tlwe::TLWE};
// pub type SecretKey<const N: usize, const K: usize> = glwe::SecretKey<Tn<N>, K>;
#[derive(Clone, Debug)]
pub struct SecretKey<const N: usize, const K: usize>(pub glwe::SecretKey<Tn<N>, K>);
pub struct SecretKey(pub glwe::SecretKey<Tn>);
// pub struct SecretKey<const K: usize>(pub tlwe::SecretKey<K>);
impl<const N: usize, const K: usize> SecretKey<N, K> {
pub fn to_tlwe<const KN: usize>(&self) -> tlwe::SecretKey<KN> {
let s: TR<Tn<N>, K> = self.0 .0.clone();
impl SecretKey {
pub fn to_tlwe(&self, param: &Param) -> tlwe::SecretKey {
let s: TR<Tn> = self.0 .0.clone();
debug_assert_eq!(s.r.len(), param.k); // sanity check
let r: Vec<Vec<T64>> = s.0.iter().map(|s_i| s_i.coeffs()).collect();
let kn = param.k * param.ring.n;
let r: Vec<Vec<T64>> = s.r.iter().map(|s_i| s_i.coeffs()).collect();
let r: Vec<T64> = r.into_iter().flatten().collect();
tlwe::SecretKey::<KN>(glwe::SecretKey::<T64, KN>(TR::<T64, KN>::new(r)))
debug_assert_eq!(r.len(), kn); // sanity check
tlwe::SecretKey(glwe::SecretKey::<T64>(TR::<T64>::new(kn, r)))
}
}
pub type PublicKey<const N: usize, const K: usize> = glwe::PublicKey<Tn<N>, K>;
pub type PublicKey = glwe::PublicKey<Tn>;
#[derive(Clone, Debug)]
pub struct TGLWE<const N: usize, const K: usize>(pub GLWE<Tn<N>, K>);
pub struct TGLWE(pub GLWE<Tn>);
impl<const N: usize, const K: usize> TGLWE<N, K> {
pub fn zero() -> Self {
Self(GLWE::<Tn<N>, K>::zero())
impl TGLWE {
pub fn zero(k: usize, param: &RingParam) -> Self {
Self(GLWE::<Tn>::zero(k, param))
}
pub fn from_plaintext(p: Tn<N>) -> Self {
Self(GLWE::<Tn<N>, K>::from_plaintext(p))
pub fn from_plaintext(k: usize, param: &RingParam, p: Tn) -> Self {
Self(GLWE::<Tn>::from_plaintext(k, param, p))
}
pub fn new_key<const KN: usize>(
mut rng: impl Rng,
) -> Result<(SecretKey<N, K>, PublicKey<N, K>)> {
pub fn new_key(mut rng: impl Rng, param: &Param) -> Result<(SecretKey, PublicKey)> {
// assert_eq!(KN, K * N); // this is wip, while not being able to compute K*N
let (sk_tlwe, _) = TLWE::<KN>::new_key(&mut rng)?;
let (sk_tlwe, _) = TLWE::new_key(&mut rng, &param.lwe())?; //param.lwe() so that it uses K*N
debug_assert_eq!(sk_tlwe.0 .0.r.len(), param.lwe().k); // =KN (sanity check)
// let sk = crate::tlwe::sk_to_tglwe::<N, K, KN>(sk_tlwe);
let sk = sk_tlwe.to_tglwe::<N, K>();
let pk: PublicKey<N, K> = GLWE::pk_from_sk(rng, sk.0.clone())?;
let sk = sk_tlwe.to_tglwe(param);
let pk: PublicKey = GLWE::pk_from_sk(rng, param, sk.0.clone())?;
Ok((sk, pk))
}
pub fn encode<const P: u64>(m: &Rq<P, N>) -> Tn<N> {
let delta = u64::MAX / P; // floored
pub fn encode(param: &Param, m: &Rq) -> Tn {
debug_assert_eq!(param.t, m.param.q); // plaintext modulus
let p = param.t;
let delta = u64::MAX / p; // floored
let coeffs = m.coeffs();
Tn(array::from_fn(|i| T64(coeffs[i].0 * delta)))
// Tn(array::from_fn(|i| T64(coeffs[i].0 * delta)))
Tn {
param: param.ring,
coeffs: coeffs.iter().map(|c_i| T64(c_i.v * delta)).collect(),
}
}
pub fn decode<const P: u64>(p: &Tn<N>) -> Rq<P, N> {
let p = p.mul_div_round(P, u64::MAX);
Rq::<P, N>::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect())
pub fn decode(param: &Param, pt: &Tn) -> Rq {
let p = param.t;
let pt = pt.mul_div_round(p, u64::MAX);
Rq::from_vec_u64(&param.pt(), pt.coeffs().iter().map(|c| c.0).collect())
}
// encrypts with the given SecretKey (instead of PublicKey)
pub fn encrypt_s(rng: impl Rng, sk: &SecretKey<N, K>, p: &Tn<N>) -> Result<Self> {
let glwe = GLWE::encrypt_s(rng, &sk.0, p)?;
pub fn encrypt_s(rng: impl Rng, param: &Param, sk: &SecretKey, p: &Tn) -> Result<Self> {
let glwe = GLWE::encrypt_s(rng, param, &sk.0, p)?;
Ok(Self(glwe))
}
pub fn encrypt(rng: impl Rng, pk: &PublicKey<N, K>, p: &Tn<N>) -> Result<Self> {
let glwe = GLWE::encrypt(rng, &pk, p)?;
pub fn encrypt(rng: impl Rng, param: &Param, pk: &PublicKey, p: &Tn) -> Result<Self> {
let glwe = GLWE::encrypt(rng, param, &pk, p)?;
Ok(Self(glwe))
}
pub fn decrypt(&self, sk: &SecretKey<N, K>) -> Tn<N> {
pub fn decrypt(&self, sk: &SecretKey) -> Tn {
self.0.decrypt(&sk.0)
}
/// Sample extraction / Coefficient extraction
pub fn sample_extraction<const KN: usize>(&self, h: usize) -> TLWE<KN> {
assert!(h < N);
pub fn sample_extraction(&self, param: &Param, h: usize) -> TLWE {
let n = param.ring.n;
assert!(h < n);
let a: TR<Tn<N>, K> = self.0 .0.clone();
let a: TR<Tn> = self.0 .0.clone();
// set a_{n*i+j} = a_{i, h-j} if j \in {0, h}
// -a_{i, n+h-j} if j \in {h+1, n-1}
let new_a: Vec<T64> = a
.iter()
.flat_map(|a_i| {
let a_i = a_i.coeffs();
(0..N)
.map(|j| if j <= h { a_i[h - j] } else { -a_i[N + h - j] })
(0..n)
.map(|j| if j <= h { a_i[h - j] } else { -a_i[n + h - j] })
.collect::<Vec<T64>>()
})
.collect::<Vec<T64>>();
TLWE(GLWE(TR(new_a), self.0 .1.coeffs()[h]))
debug_assert_eq!(new_a.len(), param.k * param.ring.n); // sanity check
TLWE(GLWE(
TR {
// TODO use constructor `new`, which will check len with k
k: param.k * param.ring.n,
r: new_a,
},
self.0 .1.coeffs()[h],
))
}
pub fn left_rotate(&self, h: usize) -> Self {
dbg!(&h);
let (a, b): (TR<Tn<N>, K>, Tn<N>) = (self.0 .0.clone(), self.0 .1);
let (a, b): (TR<Tn>, Tn) = (self.0 .0.clone(), self.0 .1.clone());
Self(GLWE(a.left_rotate(h), b.left_rotate(h)))
}
}
impl<const N: usize, const K: usize> Add<TGLWE<N, K>> for TGLWE<N, K> {
impl Add<TGLWE> for TGLWE {
type Output = Self;
fn add(self, other: Self) -> Self {
debug_assert_eq!(self.0 .0.k, other.0 .0.k);
debug_assert_eq!(self.0 .1.param(), other.0 .1.param());
Self(self.0 + other.0)
}
}
impl<const N: usize, const K: usize> AddAssign for TGLWE<N, K> {
fn add_assign(&mut self, rhs: Self) {
self.0 += rhs.0
impl AddAssign for TGLWE {
fn add_assign(&mut self, other: Self) {
debug_assert_eq!(self.0 .0.k, other.0 .0.k);
debug_assert_eq!(self.0 .1.param(), other.0 .1.param());
self.0 += other.0
}
}
impl<const N: usize, const K: usize> Sum<TGLWE<N, K>> for TGLWE<N, K> {
fn sum<I>(iter: I) -> Self
impl Sum<TGLWE> for TGLWE {
fn sum<I>(mut iter: I) -> Self
where
I: Iterator<Item = Self>,
{
let mut acc = TGLWE::<N, K>::zero();
for e in iter {
acc += e;
}
acc
// let mut acc = TGLWE::<N, K>::zero();
// for e in iter {
// acc += e;
// }
// acc
let first = iter.next().unwrap();
iter.fold(first, |acc, e| acc + e)
}
}
impl<const N: usize, const K: usize> Sub<TGLWE<N, K>> for TGLWE<N, K> {
impl Sub<TGLWE> for TGLWE {
type Output = Self;
fn sub(self, other: Self) -> Self {
debug_assert_eq!(self.0 .0.k, other.0 .0.k);
debug_assert_eq!(self.0 .1.param(), other.0 .1.param());
Self(self.0 - other.0)
}
}
// plaintext addition
impl<const N: usize, const K: usize> Add<Tn<N>> for TGLWE<N, K> {
impl Add<Tn> for TGLWE {
type Output = Self;
fn add(self, plaintext: Tn<N>) -> Self {
let a: TR<Tn<N>, K> = self.0 .0;
let b: Tn<N> = self.0 .1 + plaintext;
fn add(self, plaintext: Tn) -> Self {
debug_assert_eq!(self.0 .1.param(), plaintext.param());
let a: TR<Tn> = self.0 .0;
let b: Tn = self.0 .1 + plaintext;
Self(GLWE(a, b))
}
}
// plaintext substraction
impl<const N: usize, const K: usize> Sub<Tn<N>> for TGLWE<N, K> {
impl Sub<Tn> for TGLWE {
type Output = Self;
fn sub(self, plaintext: Tn<N>) -> Self {
let a: TR<Tn<N>, K> = self.0 .0;
let b: Tn<N> = self.0 .1 - plaintext;
fn sub(self, plaintext: Tn) -> Self {
debug_assert_eq!(self.0 .1.param(), plaintext.param());
let a: TR<Tn> = self.0 .0;
let b: Tn = self.0 .1 - plaintext;
Self(GLWE(a, b))
}
}
// plaintext multiplication
impl<const N: usize, const K: usize> Mul<Tn<N>> for TGLWE<N, K> {
impl Mul<Tn> for TGLWE {
type Output = Self;
fn mul(self, plaintext: Tn<N>) -> Self {
let a: TR<Tn<N>, K> = TR(self.0 .0 .0.iter().map(|r_i| *r_i * plaintext).collect());
let b: Tn<N> = self.0 .1 * plaintext;
fn mul(self, plaintext: Tn) -> Self {
debug_assert_eq!(self.0 .1.param(), plaintext.param());
let a: TR<Tn> = TR {
k: self.0 .0.k,
r: self.0 .0.r.iter().map(|r_i| r_i * &plaintext).collect(),
};
let b: Tn = self.0 .1 * plaintext;
Self(GLWE(a, b))
}
}
@ -169,30 +207,31 @@ mod tests {
#[test]
fn test_encrypt_decrypt() -> Result<()> {
const T: u64 = 128; // msg space (msg modulus)
const N: usize = 64;
const K: usize = 16;
type S = TGLWE<N, K>;
let param = Param {
ring: RingParam { q: u64::MAX, n: 64 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 {
let (sk, pk) = TGLWE::<N, K>::new_key::<{ K * N }>(&mut rng)?;
let (sk, pk) = TGLWE::new_key(&mut rng, &param)?;
let m = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let p: Tn<N> = S::encode::<T>(&m);
let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p: Tn = TGLWE::encode(&param, &m);
let c = S::encrypt(&mut rng, &pk, &p)?;
let c = TGLWE::encrypt(&mut rng, &param, &pk, &p)?;
let p_recovered = c.decrypt(&sk);
let m_recovered = S::decode::<T>(&p_recovered);
let m_recovered = TGLWE::decode(&param, &p_recovered);
assert_eq!(m, m_recovered);
// same but using encrypt_s (with sk instead of pk))
let c = S::encrypt_s(&mut rng, &sk, &p)?;
let c = TGLWE::encrypt_s(&mut rng, &param, &sk, &p)?;
let p_recovered = c.decrypt(&sk);
let m_recovered = S::decode::<T>(&p_recovered);
let m_recovered = TGLWE::decode(&param, &p_recovered);
assert_eq!(m, m_recovered);
}
@ -202,31 +241,32 @@ mod tests {
#[test]
fn test_addition() -> Result<()> {
const T: u64 = 128;
const N: usize = 64;
const K: usize = 16;
type S = TGLWE<N, K>;
let param = Param {
ring: RingParam { q: u64::MAX, n: 64 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 {
let (sk, pk) = S::new_key::<{ K * N }>(&mut rng)?;
let (sk, pk) = TGLWE::new_key(&mut rng, &param)?;
let m1 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let m2 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let p1: Tn<N> = S::encode::<T>(&m1); // plaintext
let p2: Tn<N> = S::encode::<T>(&m2); // plaintext
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: Tn = TGLWE::encode(&param, &m1); // plaintext
let p2: Tn = TGLWE::encode(&param, &m2); // plaintext
let c1 = S::encrypt(&mut rng, &pk, &p1)?;
let c2 = S::encrypt(&mut rng, &pk, &p2)?;
let c1 = TGLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c2 = TGLWE::encrypt(&mut rng, &param, &pk, &p2)?;
let c3 = c1 + c2;
let p3_recovered = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered);
let m3_recovered = TGLWE::decode(&param, &p3_recovered);
assert_eq!((m1 + m2).remodule::<T>(), m3_recovered.remodule::<T>());
assert_eq!((m1 + m2).remodule(param.t), m3_recovered.remodule(param.t));
}
Ok(())
@ -234,28 +274,29 @@ mod tests {
#[test]
fn test_add_plaintext() -> Result<()> {
const T: u64 = 128;
const N: usize = 64;
const K: usize = 16;
type S = TGLWE<N, K>;
let param = Param {
ring: RingParam { q: u64::MAX, n: 64 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 {
let (sk, pk) = S::new_key::<{ K * N }>(&mut rng)?;
let (sk, pk) = TGLWE::new_key(&mut rng, &param)?;
let m1 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let m2 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let p1: Tn<N> = S::encode::<T>(&m1); // plaintext
let p2: Tn<N> = S::encode::<T>(&m2); // plaintext
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: Tn = TGLWE::encode(&param, &m1); // plaintext
let p2: Tn = TGLWE::encode(&param, &m2); // plaintext
let c1 = S::encrypt(&mut rng, &pk, &p1)?;
let c1 = TGLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c3 = c1 + p2;
let p3_recovered = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered);
let m3_recovered = TGLWE::decode(&param, &p3_recovered);
assert_eq!(m1 + m2, m3_recovered);
}
@ -265,30 +306,34 @@ mod tests {
#[test]
fn test_mul_plaintext() -> Result<()> {
const T: u64 = 128;
const N: usize = 64;
const K: usize = 16;
type S = TGLWE<N, K>;
let param = Param {
ring: RingParam { q: u64::MAX, n: 64 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 {
let (sk, pk) = S::new_key::<{ K * N }>(&mut rng)?;
let (sk, pk) = TGLWE::new_key(&mut rng, &param)?;
let m1 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let m2 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let p1: Tn<N> = S::encode::<T>(&m1);
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: Tn = TGLWE::encode(&param, &m1);
// don't scale up p2, set it directly from m2
let p2: Tn<N> = Tn(array::from_fn(|i| T64(m2.coeffs()[i].0)));
let p2: Tn = Tn {
param: param.ring,
coeffs: m2.coeffs().iter().map(|c_i| T64(c_i.v)).collect(),
};
let c1 = S::encrypt(&mut rng, &pk, &p1)?;
let c1 = TGLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c3 = c1 * p2;
let p3_recovered: Tn<N> = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq::<T>(), m3_recovered);
let p3_recovered: Tn = c3.decrypt(&sk);
let m3_recovered = TGLWE::decode(&param, &p3_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), m3_recovered);
}
Ok(())
@ -296,28 +341,29 @@ mod tests {
#[test]
fn test_sample_extraction() -> Result<()> {
const T: u64 = 128; // msg space (msg modulus)
const N: usize = 64;
const K: usize = 16;
const KN: usize = K * N;
let param = Param {
ring: RingParam { q: u64::MAX, n: 64 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..20 {
let (sk, pk) = TGLWE::<N, K>::new_key::<KN>(&mut rng)?;
let sk_tlwe = sk.to_tlwe::<KN>();
let (sk, pk) = TGLWE::new_key(&mut rng, &param)?;
let sk_tlwe = sk.to_tlwe(&param);
let m = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let p: Tn<N> = TGLWE::<N, K>::encode::<T>(&m);
let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p: Tn = TGLWE::encode(&param, &m);
let c = TGLWE::<N, K>::encrypt(&mut rng, &pk, &p)?;
let c = TGLWE::encrypt(&mut rng, &param, &pk, &p)?;
for h in 0..N {
let c_h: TLWE<KN> = c.sample_extraction(h);
for h in 0..param.ring.n {
let c_h: TLWE = c.sample_extraction(&param, h);
let p_recovered = c_h.decrypt(&sk_tlwe);
let m_recovered = TLWE::<KN>::decode::<T>(&p_recovered);
let m_recovered = TLWE::decode(&param, &p_recovered);
assert_eq!(m.coeffs()[h], m_recovered.coeffs()[0]);
}
}

+ 61
- 54
tfhe/src/tgsw.rs

@ -4,62 +4,63 @@ use rand::Rng;
use std::array;
use std::ops::{Add, Mul};
use arith::{Ring, Rq, Tn, T64, TR};
use arith::{Ring, RingParam, Rq, Tn, T64, TR};
use crate::tlev::TLev;
use crate::{
tglwe::TGLWE,
tlwe::{PublicKey, SecretKey, TLWE},
};
use gfhe::glwe::GLWE;
use gfhe::glwe::{Param, GLWE};
/// vector of length K+1 = [K], [1]
#[derive(Clone, Debug)]
pub struct TGSW<const K: usize>(pub(crate) Vec<TLev<K>>, TLev<K>);
pub struct TGSW(pub(crate) Vec<TLev>, TLev);
impl<const K: usize> TGSW<K> {
impl TGSW {
pub fn encrypt_s(
mut rng: impl Rng,
param: &Param,
beta: u32,
l: u32,
sk: &SecretKey<K>,
sk: &SecretKey,
m: &T64,
) -> Result<Self> {
let a: Vec<TLev<K>> = (0..K)
.map(|i| TLev::encrypt_s(&mut rng, beta, l, sk, &(-sk.0 .0 .0[i] * *m)))
let a: Vec<TLev> = (0..param.k)
.map(|i| TLev::encrypt_s(&mut rng, &param, beta, l, sk, &(-sk.0 .0.r[i] * *m)))
.collect::<Result<Vec<_>>>()?;
let b: TLev<K> = TLev::encrypt_s(&mut rng, beta, l, sk, m)?;
let b: TLev = TLev::encrypt_s(&mut rng, &param, beta, l, sk, m)?;
Ok(Self(a, b))
}
pub fn decrypt(&self, sk: &SecretKey<K>, beta: u32) -> T64 {
pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> T64 {
self.1.decrypt(sk, beta)
}
pub fn from_tlwe(_tlwe: TLWE<K>) -> Self {
pub fn from_tlwe(_tlwe: TLWE) -> Self {
todo!()
}
pub fn cmux(bit: Self, ct1: TLWE<K>, ct2: TLWE<K>) -> TLWE<K> {
pub fn cmux(bit: Self, ct1: TLWE, ct2: TLWE) -> TLWE {
ct1.clone() + (bit * (ct2 - ct1))
}
}
/// External product TGSW x TLWE
impl<const K: usize> Mul<TLWE<K>> for TGSW<K> {
type Output = TLWE<K>;
impl Mul<TLWE> for TGSW {
type Output = TLWE;
fn mul(self, tlwe: TLWE<K>) -> TLWE<K> {
fn mul(self, tlwe: TLWE) -> TLWE {
let beta: u32 = 2;
let l: u32 = 64; // TODO wip
// since N=1, each tlwe element is a vector of length=1, decomposed into
// l elements, and we have K of them
let tlwe_ab: Vec<T64> = [tlwe.0 .0 .0.clone(), vec![tlwe.0 .1]].concat();
let tlwe_ab: Vec<T64> = [tlwe.0 .0.r.clone(), vec![tlwe.0 .1]].concat();
let tgsw_ab: Vec<TLev<K>> = [self.0.clone(), vec![self.1]].concat();
let tgsw_ab: Vec<TLev> = [self.0.clone(), vec![self.1]].concat();
assert_eq!(tgsw_ab.len(), tlwe_ab.len());
let r: TLWE<K> = zip_eq(tgsw_ab, tlwe_ab)
let r: TLWE = zip_eq(tgsw_ab, tlwe_ab)
.map(|(tlev_i, tlwe_i)| tlev_i * tlwe_i.decompose(beta, l))
.sum();
r
@ -75,25 +76,26 @@ mod tests {
#[test]
fn test_encrypt_decrypt() -> Result<()> {
const T: u64 = 2; // plaintext modulus
const K: usize = 16;
type S = TGSW<K>;
let param = Param {
ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 2, // plaintext modulus
};
let beta: u32 = 2;
let l: u32 = 16;
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..50 {
let (sk, _) = TLWE::<K>::new_key(&mut rng)?;
let (sk, _) = TLWE::new_key(&mut rng, &param)?;
let m: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?;
let p: T64 = TLev::<K>::encode::<T>(&m); // plaintext
let m: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p: T64 = TLev::encode(&param, &m); // plaintext
let c = S::encrypt_s(&mut rng, beta, l, &sk, &p)?;
let c = TGSW::encrypt_s(&mut rng, &param, beta, l, &sk, &p)?;
let p_recovered = c.decrypt(&sk, beta);
let m_recovered = TLev::<K>::decode::<T>(&p_recovered);
let m_recovered = TLev::decode(&param, &p_recovered);
assert_eq!(m, m_recovered);
}
@ -103,36 +105,38 @@ mod tests {
#[test]
fn test_external_product() -> Result<()> {
const T: u64 = 2; // plaintext modulus
const K: usize = 32;
let param = Param {
ring: RingParam { q: u64::MAX, n: 1 },
k: 32,
t: 2, // plaintext modulus
};
let beta: u32 = 2;
let l: u32 = 64;
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..50 {
let (sk, _) = TLWE::<K>::new_key(&mut rng)?;
let (sk, _) = TLWE::new_key(&mut rng, &param)?;
let m1: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?;
let p1: T64 = TLev::<K>::encode::<T>(&m1);
let m1: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: T64 = TLev::encode(&param, &m1);
let m2: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?;
let p2: T64 = TLWE::<K>::encode::<T>(&m2); // scaled by delta
let m2: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p2: T64 = TLWE::encode(&param, &m2); // scaled by delta
let tgsw = TGSW::<K>::encrypt_s(&mut rng, beta, l, &sk, &p1)?;
let tlwe = TLWE::<K>::encrypt_s(&mut rng, &sk, &p2)?;
let tgsw = TGSW::encrypt_s(&mut rng, &param, beta, l, &sk, &p1)?;
let tlwe = TLWE::encrypt_s(&mut rng, &param, &sk, &p2)?;
let res: TLWE<K> = tgsw * tlwe;
let res: TLWE = tgsw * tlwe;
// let p_recovered = res.decrypt(&sk, beta);
let p_recovered = res.decrypt(&sk);
// downscaled by delta^-1
let res_recovered = TLWE::<K>::decode::<T>(&p_recovered);
let res_recovered = TLWE::decode(&param, &p_recovered);
// assert_eq!(m1 * m2, m_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq::<T>(), res_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), res_recovered);
}
Ok(())
@ -140,35 +144,38 @@ mod tests {
#[test]
fn test_cmux() -> Result<()> {
const T: u64 = 2; // plaintext modulus
const K: usize = 32;
let param = Param {
ring: RingParam { q: u64::MAX, n: 1 },
k: 32,
t: 2, // plaintext modulus
};
let beta: u32 = 2;
let l: u32 = 64;
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..50 {
let (sk, _) = TLWE::<K>::new_key(&mut rng)?;
let (sk, _) = TLWE::new_key(&mut rng, &param)?;
let m1: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?;
let p1: T64 = TLWE::<K>::encode::<T>(&m1); // scaled by delta
let m1: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: T64 = TLWE::encode(&param, &m1); // scaled by delta
let m2: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?;
let p2: T64 = TLWE::<K>::encode::<T>(&m2); // scaled by delta
let m2: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p2: T64 = TLWE::encode(&param, &m2); // scaled by delta
for bit_raw in 0..2 {
let bit = TGSW::<K>::encrypt_s(&mut rng, beta, l, &sk, &T64(bit_raw))?;
let bit = TGSW::encrypt_s(&mut rng, &param, beta, l, &sk, &T64(bit_raw))?;
let c1 = TLWE::<K>::encrypt_s(&mut rng, &sk, &p1)?;
let c2 = TLWE::<K>::encrypt_s(&mut rng, &sk, &p2)?;
let c1 = TLWE::encrypt_s(&mut rng, &param, &sk, &p1)?;
let c2 = TLWE::encrypt_s(&mut rng, &param, &sk, &p2)?;
let res: TLWE<K> = TGSW::cmux(bit, c1, c2);
let res: TLWE = TGSW::cmux(bit, c1, c2);
let p_recovered = res.decrypt(&sk);
// downscaled by delta^-1
let res_recovered = TLWE::<K>::decode::<T>(&p_recovered);
let res_recovered = TLWE::decode(&param, &p_recovered);
if bit_raw == 0 {
assert_eq!(m1, res_recovered);

+ 63
- 41
tfhe/src/tlev.rs

@ -4,32 +4,48 @@ use rand::Rng;
use std::array;
use std::ops::{Add, Mul};
use arith::{Ring, Rq, Tn, T64, TR};
use arith::{Ring, RingParam, Rq, Tn, T64, TR};
use crate::tglwe::TGLWE;
use crate::tlwe::{PublicKey, SecretKey, TLWE};
use gfhe::glwe::Param;
#[derive(Clone, Debug)]
pub struct TLev<const K: usize>(pub(crate) Vec<TLWE<K>>);
pub struct TLev(pub(crate) Vec<TLWE>);
impl TLev {
pub fn encode(param: &Param, m: &Rq) -> T64 {
assert_eq!(m.param.n, 1);
assert_eq!(param.t, m.param.q);
impl<const K: usize> TLev<K> {
pub fn encode<const T: u64>(m: &Rq<T, 1>) -> T64 {
let coeffs = m.coeffs();
T64(coeffs[0].0) // N=1, so take the only coeff
T64(coeffs[0].v) // N=1, so take the only coeff
}
pub fn decode<const T: u64>(p: &T64) -> Rq<T, 1> {
Rq::<T, 1>::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect())
pub fn decode(param: &Param, p: &T64) -> Rq {
Rq::from_vec_u64(
// &RingParam { q: u64::MAX, n: 1 },
&RingParam { q: param.t, n: 1 },
p.coeffs().iter().map(|c| c.0).collect(),
)
}
pub fn encrypt(
mut rng: impl Rng,
param: &Param,
beta: u32,
l: u32,
pk: &PublicKey<K>,
pk: &PublicKey,
m: &T64,
) -> Result<Self> {
let tlev: Vec<TLWE<K>> = (1..l + 1)
debug_assert_eq!(pk.1.k, param.k);
let tlev: Vec<TLWE> = (1..l + 1)
.map(|i| {
TLWE::<K>::encrypt(&mut rng, pk, &(*m * (u64::MAX / beta.pow(i as u32) as u64)))
TLWE::encrypt(
&mut rng,
param,
pk,
&(*m * (u64::MAX / beta.pow(i as u32) as u64)),
)
})
.collect::<Result<Vec<_>>>()?;
@ -37,12 +53,13 @@ impl TLev {
}
pub fn encrypt_s(
mut rng: impl Rng,
param: &Param,
_beta: u32, // TODO rm, and make beta=2 always
l: u32,
sk: &SecretKey<K>,
sk: &SecretKey,
m: &T64,
) -> Result<Self> {
let tlev: Vec<TLWE<K>> = (1..l as u64 + 1)
let tlev: Vec<TLWE> = (1..l as u64 + 1)
.map(|i| {
let aux = if i < 64 {
*m * (u64::MAX / (1u64 << i))
@ -51,22 +68,22 @@ impl TLev {
// by it, which would be equal to 1
*m
};
TLWE::<K>::encrypt_s(&mut rng, sk, &aux)
TLWE::encrypt_s(&mut rng, &param, sk, &aux)
})
.collect::<Result<Vec<_>>>()?;
Ok(Self(tlev))
}
pub fn decrypt(&self, sk: &SecretKey<K>, beta: u32) -> T64 {
pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> T64 {
let pt = self.0[0].decrypt(sk);
pt.mul_div_round(beta as u64, u64::MAX)
}
}
// TODO review u64::MAX, since is -1 of the value we actually want
impl<const K: usize> TLev<K> {
pub fn iter(&self) -> std::slice::Iter<TLWE<K>> {
impl TLev {
pub fn iter(&self) -> std::slice::Iter<TLWE> {
self.0.iter()
}
}
@ -74,14 +91,14 @@ impl TLev {
// dot product between a TLev and Vec<T64>, usually Vec<T64> comes from a
// decomposition of T64
// TLev * Vec<T64> --> TLWE
impl<const K: usize> Mul<Vec<T64>> for TLev<K> {
type Output = TLWE<K>;
impl Mul<Vec<T64>> for TLev {
type Output = TLWE;
fn mul(self, v: Vec<T64>) -> Self::Output {
assert_eq!(self.0.len(), v.len());
// l TLWES
let tlwes: Vec<TLWE<K>> = self.0;
let r: TLWE<K> = zip_eq(v, tlwes).map(|(a_d_i, glwe_i)| glwe_i * a_d_i).sum();
let tlwes: Vec<TLWE> = self.0;
let r: TLWE = zip_eq(v, tlwes).map(|(a_d_i, glwe_i)| glwe_i * a_d_i).sum();
r
}
}
@ -95,27 +112,29 @@ mod tests {
#[test]
fn test_encrypt_decrypt() -> Result<()> {
const T: u64 = 2; // plaintext modulus
const K: usize = 16;
type S = TLev<K>;
let param = Param {
ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 2, // plaintext modulus
};
let beta: u32 = 2;
let l: u32 = 16;
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 {
let (sk, pk) = TLWE::<K>::new_key(&mut rng)?;
let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?;
let p: T64 = S::encode::<T>(&m); // plaintext
let m: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p: T64 = TLev::encode(&param, &m); // plaintext
let c = S::encrypt(&mut rng, beta, l, &pk, &p)?;
let c = TLev::encrypt(&mut rng, &param, beta, l, &pk, &p)?;
let p_recovered = c.decrypt(&sk, beta);
let m_recovered = S::decode::<T>(&p_recovered);
let m_recovered = TLev::decode(&param, &p_recovered);
assert_eq!(m.remodule::<T>(), m_recovered.remodule::<T>());
assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t));
}
Ok(())
@ -123,32 +142,35 @@ mod tests {
#[test]
fn test_tlev_vect64_product() -> Result<()> {
const T: u64 = 2; // plaintext modulus
const K: usize = 16;
let param = Param {
ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 2, // plaintext modulus
};
let beta: u32 = 2;
let l: u32 = 16;
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 {
let (sk, pk) = TLWE::<K>::new_key(&mut rng)?;
let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m1: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?;
let m2: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?;
let p1: T64 = TLev::<K>::encode::<T>(&m1);
let p2: T64 = TLev::<K>::encode::<T>(&m2);
let m1: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: T64 = TLev::encode(&param, &m1);
let p2: T64 = TLev::encode(&param, &m2);
let c1 = TLev::<K>::encrypt(&mut rng, beta, l, &pk, &p1)?;
let c1 = TLev::encrypt(&mut rng, &param, beta, l, &pk, &p1)?;
let c2 = p2.decompose(beta, l);
let c3 = c1 * c2;
let p_recovered = c3.decrypt(&sk);
let m_recovered = TLev::<K>::decode::<T>(&p_recovered);
let m_recovered = TLev::decode(&param, &p_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq::<T>(), m_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), m_recovered);
}
Ok(())

+ 225
- 176
tfhe/src/tlwe.rs

@ -4,242 +4,276 @@ use rand::Rng;
use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, Sub};
use arith::{Ring, Rq, Tn, Zq, T64, TR};
use gfhe::{glwe, GLWE};
use arith::{Ring, RingParam, Rq, Tn, Zq, T64, TR};
use gfhe::{glwe, glwe::Param, GLWE};
use crate::tggsw::TGGSW;
use crate::tlev::TLev;
use crate::{tglwe, tglwe::TGLWE};
pub struct SecretKey<const K: usize>(pub glwe::SecretKey<T64, K>);
pub struct SecretKey(pub glwe::SecretKey<T64>);
impl<const KN: usize> SecretKey<KN> {
impl SecretKey {
/// from TFHE [2018-421] paper: A TLWE key k \in B^n, can be interpreted as a
/// TRLWE key K \in B_N[X]^k having the same sequence of coefficients and
/// vice-versa.
pub fn to_tglwe<const N: usize, const K: usize>(self) -> crate::tglwe::SecretKey<N, K> {
let s: TR<T64, KN> = self.0 .0;
pub fn to_tglwe(self, param: &Param) -> crate::tglwe::SecretKey {
let s: TR<T64> = self.0 .0; // of length K*N
assert_eq!(s.r.len(), param.k * param.ring.n); // sanity check
// split into K vectors, and interpret each of them as a T_N[X]/(X^N+1)
// polynomial
let r: Vec<Tn<N>> =
s.0.chunks(N)
.map(|v| Tn::<N>::from_vec(v.to_vec()))
let r: Vec<Tn> =
s.r.chunks(param.ring.n)
.map(|v| Tn::from_vec(&param.ring, v.to_vec()))
.collect();
crate::tglwe::SecretKey(glwe::SecretKey::<Tn<N>, K>(TR(r)))
crate::tglwe::SecretKey(glwe::SecretKey::<Tn>(TR { k: param.k, r }))
}
}
pub type PublicKey<const K: usize> = glwe::PublicKey<T64, K>;
pub type PublicKey = glwe::PublicKey<T64>;
#[derive(Clone, Debug)]
pub struct KSK<const K: usize>(Vec<TLev<K>>);
pub struct KSK(Vec<TLev>);
#[derive(Clone, Debug)]
pub struct TLWE<const K: usize>(pub GLWE<T64, K>);
pub struct TLWE(pub GLWE<T64>);
impl<const K: usize> TLWE<K> {
pub fn zero() -> Self {
Self(GLWE::<T64, K>::zero())
impl TLWE {
pub fn zero(k: usize, ring_param: &RingParam) -> Self {
Self(GLWE::<T64>::zero(k, ring_param))
}
pub fn new_key(rng: impl Rng) -> Result<(SecretKey<K>, PublicKey<K>)> {
let (sk, pk): (glwe::SecretKey<T64, K>, glwe::PublicKey<T64, K>) = GLWE::new_key(rng)?;
pub fn new_key(rng: impl Rng, param: &Param) -> Result<(SecretKey, PublicKey)> {
let (sk, pk): (glwe::SecretKey<T64>, glwe::PublicKey<T64>) = GLWE::new_key(rng, param)?;
Ok((SecretKey(sk), pk))
}
pub fn encode<const P: u64>(m: &Rq<P, 1>) -> T64 {
let delta = u64::MAX / P; // floored
// TODO use param instead of p:u64
pub fn encode(param: &Param, m: &Rq) -> T64 {
assert_eq!(param.ring.n, 1);
debug_assert_eq!(param.t, m.param.q); // plaintext modulus
//
let delta = u64::MAX / param.t; // floored
let coeffs = m.coeffs();
T64(coeffs[0].0 * delta)
T64(coeffs[0].v * delta)
}
pub fn decode<const P: u64>(p: &T64) -> Rq<P, 1> {
let p = p.mul_div_round(P, u64::MAX);
Rq::<P, 1>::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect())
pub fn decode(param: &Param, p: &T64) -> Rq {
let p = p.mul_div_round(param.t, u64::MAX);
Rq::from_vec_u64(&param.pt(), p.coeffs().iter().map(|c| c.0).collect())
}
// encrypts with the given SecretKey (instead of PublicKey)
pub fn encrypt_s(rng: impl Rng, sk: &SecretKey<K>, p: &T64) -> Result<Self> {
let glwe = GLWE::encrypt_s(rng, &sk.0, p)?;
pub fn encrypt_s(rng: impl Rng, param: &Param, sk: &SecretKey, p: &T64) -> Result<Self> {
let glwe = GLWE::encrypt_s(rng, param, &sk.0, p)?;
Ok(Self(glwe))
}
pub fn encrypt(rng: impl Rng, pk: &PublicKey<K>, p: &T64) -> Result<Self> {
let glwe = GLWE::encrypt(rng, &pk, p)?;
pub fn encrypt(rng: impl Rng, param: &Param, pk: &PublicKey, p: &T64) -> Result<Self> {
let glwe = GLWE::encrypt(rng, param, pk, p)?;
Ok(Self(glwe))
}
pub fn decrypt(&self, sk: &SecretKey<K>) -> T64 {
pub fn decrypt(&self, sk: &SecretKey) -> T64 {
self.0.decrypt(&sk.0)
}
pub fn new_ksk(
mut rng: impl Rng,
param: &Param,
beta: u32,
l: u32,
sk: &SecretKey<K>,
new_sk: &SecretKey<K>,
) -> Result<KSK<K>> {
let r: Vec<TLev<K>> = (0..K)
sk: &SecretKey,
new_sk: &SecretKey,
) -> Result<KSK> {
let r: Vec<TLev> = (0..param.k)
.into_iter()
.map(|i|
// treat sk_i as the msg being encrypted
TLev::<K>::encrypt_s(&mut rng, beta, l, &new_sk, &sk.0.0 .0[i]))
TLev::encrypt_s(&mut rng, param, beta, l, &new_sk, &sk.0.0 .r[i]))
.collect::<Result<Vec<_>>>()?;
Ok(KSK(r))
}
pub fn key_switch(&self, beta: u32, l: u32, ksk: &KSK<K>) -> Self {
let (a, b): (TR<T64, K>, T64) = (self.0 .0.clone(), self.0 .1);
pub fn key_switch(&self, param: &Param, beta: u32, l: u32, ksk: &KSK) -> Self {
let (a, b): (TR<T64>, T64) = (self.0 .0.clone(), self.0 .1);
let lhs: TLWE<K> = TLWE(GLWE(TR::zero(), b));
let lhs: TLWE = TLWE(GLWE(TR::zero(param.k * param.ring.n, &param.ring), b));
// K iterations, ksk.0 contains K times GLev
let rhs: TLWE<K> = zip_eq(a.0, ksk.0.clone())
let rhs: TLWE = zip_eq(a.r, ksk.0.clone())
.map(|(a_i, ksk_i)| ksk_i * a_i.decompose(beta, l)) // dot_product
.sum();
lhs - rhs
}
// modulus switch from Q (2^64) to Q2 (in blind_rotation Q2=K*N)
pub fn mod_switch<const Q2: u64>(&self) -> Self {
let a: TR<T64, K> = self.0 .0.mod_switch::<Q2>();
let b: T64 = self.0 .1.mod_switch::<Q2>();
pub fn mod_switch(&self, q2: u64) -> Self {
let a: TR<T64> = self.0 .0.mod_switch(q2);
let b: T64 = self.0 .1.mod_switch(q2);
Self(GLWE(a, b))
}
}
// NOTE: the ugly const generics are temporary
pub fn blind_rotation<const N: usize, const K: usize, const KN: usize, const KN2: u64>(
c: TLWE<KN>,
btk: BootstrappingKey<N, K, KN>,
table: TGLWE<N, K>,
) -> TGLWE<N, K> {
let c_kn: TLWE<KN> = c.mod_switch::<KN2>();
let (a, b): (TR<T64, KN>, T64) = (c_kn.0 .0, c_kn.0 .1);
pub fn blind_rotation(
param: &Param,
c: TLWE, // kn
btk: BootstrappingKey,
table: TGLWE, // n,k
) -> TGLWE {
debug_assert_eq!(c.0 .0.k, param.k); // TODO or k*n?
// TODO replace `param.k*param.ring.n` by `param.kn()`
let c_kn: TLWE = c.mod_switch((param.k * param.ring.n) as u64);
let (a, b): (TR<T64>, T64) = (c_kn.0 .0, c_kn.0 .1);
// two main parts: rotate by a known power of X, rotate by a secret
// power of X (using the C gate)
// table * X^-b, ie. left rotate
let v_xb: TGLWE<N, K> = table.left_rotate(b.0 as usize);
let v_xb: TGLWE = table.left_rotate(b.0 as usize);
// rotate by a secret power of X using the cmux gate
let mut c_j: TGLWE<N, K> = v_xb.clone();
let _ = (1..K).map(|j| {
c_j = TGGSW::<N, K>::cmux(
let mut c_j: TGLWE = v_xb.clone();
let _ = (1..param.k).map(|j| {
c_j = TGGSW::cmux(
btk.0[j].clone(),
c_j.clone(),
c_j.clone().left_rotate(a.0[j].0 as usize),
c_j.clone().left_rotate(a.r[j].0 as usize),
);
dbg!(&c_j);
// dbg!(&c_j);
});
c_j
}
pub fn bootstrapping<const N: usize, const K: usize, const KN: usize, const KN2: u64>(
btk: BootstrappingKey<N, K, KN>,
table: TGLWE<N, K>,
c: TLWE<KN>,
) -> TLWE<KN> {
let rotated: TGLWE<N, K> = blind_rotation::<N, K, KN, KN2>(c, btk.clone(), table);
let c_h: TLWE<KN> = rotated.sample_extraction(0);
let r = c_h.key_switch(2, 64, &btk.1);
pub fn bootstrapping(
param: &Param,
btk: BootstrappingKey,
table: TGLWE,
c: TLWE, // kn
) -> TLWE {
// kn
let rotated: TGLWE = blind_rotation(param, c, btk.clone(), table);
let c_h: TLWE = rotated.sample_extraction(&param, 0);
let r = c_h.key_switch(param, 2, 64, &btk.1);
r
}
#[derive(Clone, Debug)]
pub struct BootstrappingKey<const N: usize, const K: usize, const KN: usize>(
pub Vec<TGGSW<N, K>>,
pub KSK<KN>,
pub struct BootstrappingKey(
pub Vec<TGGSW>,
pub KSK, // kn
);
impl<const N: usize, const K: usize, const KN: usize> BootstrappingKey<N, K, KN> {
pub fn from_sk(mut rng: impl Rng, sk: &tglwe::SecretKey<N, K>) -> Result<Self> {
impl BootstrappingKey {
pub fn from_sk(mut rng: impl Rng, param: &Param, sk: &tglwe::SecretKey) -> Result<Self> {
let (beta, l) = (2u32, 64u32); // TMP
//
let s: TR<Tn<N>, K> = sk.0 .0.clone();
let (sk2, _) = TLWE::<KN>::new_key(&mut rng)?; // TLWE<KN> compatible with TGLWE<N,K>
let s: TR<Tn> = sk.0 .0.clone();
let (sk2, _) = TLWE::new_key(&mut rng, &param.lwe())?; // TLWE<KN> compatible with TGLWE<N,K>
// each btk_j = TGGSW_sk(s_i)
let btk: Vec<TGGSW<N, K>> = s
let btk: Vec<TGGSW> = s
.iter()
.map(|s_i| TGGSW::<N, K>::encrypt_s(&mut rng, beta, l, sk, s_i))
.map(|s_i| TGGSW::encrypt_s(&mut rng, param, beta, l, sk, s_i))
.collect::<Result<Vec<_>>>()?;
let ksk = TLWE::<KN>::new_ksk(&mut rng, beta, l, &sk.to_tlwe(), &sk2)?;
let ksk = TLWE::new_ksk(
&mut rng,
&param.lwe(),
beta,
l,
&sk.to_tlwe(&param.lwe()), // converted to length k*n
&sk2, // created with length k*n
)?;
debug_assert_eq!(ksk.0.len(), param.lwe().k);
debug_assert_eq!(ksk.0.len(), param.k * param.ring.n);
Ok(Self(btk, ksk))
}
}
pub fn compute_lookup_table<const T: u64, const K: usize, const N: usize>() -> TGLWE<N, K> {
pub fn compute_lookup_table(param: &Param) -> TGLWE {
// from 2021-1402:
// v(x) = \sum_j^{N-1} [(p_j / 2N mod p)/p] X^j
// matrix of coefficients with size K*N = delta x T
let delta: usize = N / T as usize;
let values: Vec<Zq<T>> = (0..T).map(|v| Zq::<T>::from_u64(v)).collect();
let coeffs: Vec<Zq<T>> = (0..T as usize)
let delta: usize = param.ring.n / param.t as usize;
let values: Vec<Zq> = (0..param.t).map(|v| Zq::from_u64(param.t, v)).collect();
let coeffs: Vec<Zq> = (0..param.t as usize)
.flat_map(|i| vec![values[i]; delta])
.collect();
let table = Rq::<T, N>::from_vec(coeffs);
let table = Rq::from_vec(&param.pt(), coeffs);
// encode the table as plaintext
let v: Tn<N> = TGLWE::<N, K>::encode::<T>(&table);
let v: Tn = TGLWE::encode(param, &table);
// encode the table as TGLWE ciphertext
let v: TGLWE<N, K> = TGLWE::<N, K>::from_plaintext(v);
let v: TGLWE = TGLWE::from_plaintext(param.k, &param.ring, v);
v
}
impl<const K: usize> Add<TLWE<K>> for TLWE<K> {
impl Add<TLWE> for TLWE {
type Output = Self;
fn add(self, other: Self) -> Self {
debug_assert_eq!(self.0 .0.k, other.0 .0.k);
debug_assert_eq!(self.0 .1.param(), other.0 .1.param());
Self(self.0 + other.0)
}
}
impl<const K: usize> AddAssign for TLWE<K> {
impl AddAssign for TLWE {
fn add_assign(&mut self, rhs: Self) {
debug_assert_eq!(self.0 .0.k, rhs.0 .0.k);
debug_assert_eq!(self.0 .1.param(), rhs.0 .1.param());
self.0 += rhs.0
}
}
impl<const K: usize> Sum<TLWE<K>> for TLWE<K> {
fn sum<I>(iter: I) -> Self
impl Sum<TLWE> for TLWE {
fn sum<I>(mut iter: I) -> Self
where
I: Iterator<Item = Self>,
{
let mut acc = TLWE::<K>::zero();
for e in iter {
acc += e;
}
acc
// let mut acc = TLWE::<K>::zero();
// for e in iter {
// acc += e;
// }
// acc
let first = iter.next().unwrap();
iter.fold(first, |acc, e| acc + e)
}
}
impl<const K: usize> Sub<TLWE<K>> for TLWE<K> {
impl Sub<TLWE> for TLWE {
type Output = Self;
fn sub(self, other: Self) -> Self {
debug_assert_eq!(self.0 .0.k, other.0 .0.k);
debug_assert_eq!(self.0 .1.param(), other.0 .1.param());
Self(self.0 - other.0)
}
}
// plaintext addition
impl<const K: usize> Add<T64> for TLWE<K> {
impl Add<T64> for TLWE {
type Output = Self;
fn add(self, plaintext: T64) -> Self {
let a: TR<T64, K> = self.0 .0;
let a: TR<T64> = self.0 .0;
let b: T64 = self.0 .1 + plaintext;
Self(GLWE(a, b))
}
}
// plaintext substraction
impl<const K: usize> Sub<T64> for TLWE<K> {
impl Sub<T64> for TLWE {
type Output = Self;
fn sub(self, plaintext: T64) -> Self {
let a: TR<T64, K> = self.0 .0;
let a: TR<T64> = self.0 .0;
let b: T64 = self.0 .1 - plaintext;
Self(GLWE(a, b))
}
}
// plaintext multiplication
impl<const K: usize> Mul<T64> for TLWE<K> {
impl Mul<T64> for TLWE {
type Output = Self;
fn mul(self, plaintext: T64) -> Self {
let a: TR<T64, K> = TR(self.0 .0 .0.iter().map(|r_i| *r_i * plaintext).collect());
let a: TR<T64> = TR {
k: self.0 .0.k,
r: self.0 .0.r.iter().map(|r_i| *r_i * plaintext).collect(),
};
let b: T64 = self.0 .1 * plaintext;
Self(GLWE(a, b))
}
@ -255,29 +289,31 @@ mod tests {
#[test]
fn test_encrypt_decrypt() -> Result<()> {
const T: u64 = 128; // msg space (msg modulus)
const K: usize = 16;
type S = TLWE<K>;
let param = Param {
ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let p: T64 = S::encode::<T>(&m);
let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p: T64 = TLWE::encode(&param, &m);
let c = S::encrypt(&mut rng, &pk, &p)?;
let c = TLWE::encrypt(&mut rng, &param, &pk, &p)?;
let p_recovered = c.decrypt(&sk);
let m_recovered = S::decode::<T>(&p_recovered);
let m_recovered = TLWE::decode(&param, &p_recovered);
assert_eq!(m, m_recovered);
// same but using encrypt_s (with sk instead of pk))
let c = S::encrypt_s(&mut rng, &sk, &p)?;
let c = TLWE::encrypt_s(&mut rng, &param, &sk, &p)?;
let p_recovered = c.decrypt(&sk);
let m_recovered = S::decode::<T>(&p_recovered);
let m_recovered = TLWE::decode(&param, &p_recovered);
assert_eq!(m, m_recovered);
}
@ -287,30 +323,32 @@ mod tests {
#[test]
fn test_addition() -> Result<()> {
const T: u64 = 128;
const K: usize = 16;
type S = TLWE<K>;
let param = Param {
ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m1 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let m2 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let p1: T64 = S::encode::<T>(&m1); // plaintext
let p2: T64 = S::encode::<T>(&m2); // plaintext
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: T64 = TLWE::encode(&param, &m1); // plaintext
let p2: T64 = TLWE::encode(&param, &m2); // plaintext
let c1 = S::encrypt(&mut rng, &pk, &p1)?;
let c2 = S::encrypt(&mut rng, &pk, &p2)?;
let c1 = TLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c2 = TLWE::encrypt(&mut rng, &param, &pk, &p2)?;
let c3 = c1 + c2;
let p3_recovered = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered);
let m3_recovered = TLWE::decode(&param, &p3_recovered);
assert_eq!((m1 + m2).remodule::<T>(), m3_recovered.remodule::<T>());
assert_eq!((m1 + m2).remodule(param.t), m3_recovered.remodule(param.t));
}
Ok(())
@ -318,27 +356,29 @@ mod tests {
#[test]
fn test_add_plaintext() -> Result<()> {
const T: u64 = 128;
const K: usize = 16;
type S = TLWE<K>;
let param = Param {
ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m1 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let m2 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let p1: T64 = S::encode::<T>(&m1); // plaintext
let p2: T64 = S::encode::<T>(&m2); // plaintext
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: T64 = TLWE::encode(&param, &m1); // plaintext
let p2: T64 = TLWE::encode(&param, &m2); // plaintext
let c1 = S::encrypt(&mut rng, &pk, &p1)?;
let c1 = TLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c3 = c1 + p2;
let p3_recovered = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered);
let m3_recovered = TLWE::decode(&param, &p3_recovered);
assert_eq!(m1 + m2, m3_recovered);
}
@ -348,30 +388,32 @@ mod tests {
#[test]
fn test_mul_plaintext() -> Result<()> {
const T: u64 = 128;
const K: usize = 16;
type S = TLWE<K>;
let param = Param {
ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m1 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let m2 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let p1: T64 = S::encode::<T>(&m1);
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: T64 = TLWE::encode(&param, &m1);
// don't scale up p2, set it directly from m2
// let p2: T64 = Tn(array::from_fn(|i| T64(m2.coeffs()[i].0)));
let p2: T64 = T64(m2.coeffs()[0].0);
let p2: T64 = T64(m2.coeffs()[0].v);
let c1 = S::encrypt(&mut rng, &pk, &p1)?;
let c1 = TLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c3 = c1 * p2;
let p3_recovered: T64 = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq::<T>(), m3_recovered);
let m3_recovered = TLWE::decode(&param, &p3_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), m3_recovered);
}
Ok(())
@ -379,38 +421,40 @@ mod tests {
#[test]
fn test_key_switch() -> Result<()> {
const T: u64 = 128; // plaintext modulus
const K: usize = 16;
type S = TLWE<K>;
let param = Param {
ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 128, // plaintext modulus
};
let beta: u32 = 2;
let l: u32 = 64;
let mut rng = rand::thread_rng();
let (sk, pk) = S::new_key(&mut rng)?;
let (sk2, _) = S::new_key(&mut rng)?;
let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let (sk2, _) = TLWE::new_key(&mut rng, &param)?;
// ksk to switch from sk to sk2
let ksk = S::new_ksk(&mut rng, beta, l, &sk, &sk2)?;
let ksk = TLWE::new_ksk(&mut rng, &param, beta, l, &sk, &sk2)?;
let msg_dist = Uniform::new(0_u64, T);
let m = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let p = S::encode::<T>(&m); // plaintext
//
let c = S::encrypt_s(&mut rng, &sk, &p)?;
let msg_dist = Uniform::new(0_u64, param.t);
let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p = TLWE::encode(&param, &m); // plaintext
//
let c = TLWE::encrypt_s(&mut rng, &param, &sk, &p)?;
let c2 = c.key_switch(beta, l, &ksk);
let c2 = c.key_switch(&param, beta, l, &ksk);
// decrypt with the 2nd secret key
let p_recovered = c2.decrypt(&sk2);
let m_recovered = S::decode::<T>(&p_recovered);
assert_eq!(m.remodule::<T>(), m_recovered.remodule::<T>());
let m_recovered = TLWE::decode(&param, &p_recovered);
assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t));
// do the same but now encrypting with pk
let c = S::encrypt(&mut rng, &pk, &p)?;
let c2 = c.key_switch(beta, l, &ksk);
let c = TLWE::encrypt(&mut rng, &param, &pk, &p)?;
let c2 = c.key_switch(&param, beta, l, &ksk);
let p_recovered = c2.decrypt(&sk2);
let m_recovered = S::decode::<T>(&p_recovered);
let m_recovered = TLWE::decode(&param, &p_recovered);
assert_eq!(m, m_recovered);
Ok(())
@ -418,39 +462,44 @@ mod tests {
#[test]
fn test_bootstrapping() -> Result<()> {
const T: u64 = 128; // plaintext modulus
const K: usize = 1;
const N: usize = 1024;
const KN: usize = K * N;
let param = Param {
ring: RingParam {
q: u64::MAX,
n: 1024,
},
k: 1,
t: 128, // plaintext modulus
};
// const T: u64 = 128; // plaintext modulus
// const K: usize = 1;
// const N: usize = 1024;
// const KN: usize = K * N;
let mut rng = rand::thread_rng();
let start = Instant::now();
let table: TGLWE<N, K> = compute_lookup_table::<T, K, N>();
let table: TGLWE = compute_lookup_table(&param);
println!("table took: {:?}", start.elapsed());
let (sk, _) = TGLWE::<N, K>::new_key::<KN>(&mut rng)?;
let sk_tlwe: SecretKey<KN> = sk.to_tlwe::<KN>();
let (sk, _) = TGLWE::new_key(&mut rng, &param)?;
let sk_tlwe: SecretKey = sk.to_tlwe(&param);
let start = Instant::now();
let btk = BootstrappingKey::<N, K, KN>::from_sk(&mut rng, &sk)?;
let btk = BootstrappingKey::from_sk(&mut rng, &param, &sk)?;
println!("btk took: {:?}", start.elapsed());
let msg_dist = Uniform::new(0_u64, T);
let m = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
dbg!(&m);
let p = TLWE::<K>::encode::<T>(&m); // plaintext
let msg_dist = Uniform::new(0_u64, param.t);
let m = Rq::rand_u64(&mut rng, msg_dist, &param.lwe().pt())?; // q=t, n=1
let p = TLWE::encode(&param.lwe(), &m); // plaintext
let c = TLWE::<KN>::encrypt_s(&mut rng, &sk_tlwe, &p)?;
let c = TLWE::encrypt_s(&mut rng, &param.lwe(), &sk_tlwe, &p)?;
let start = Instant::now();
// the ugly const generics are temporary
let bootstrapped: TLWE<KN> =
bootstrapping::<N, K, KN, { K as u64 * N as u64 }>(btk, table, c);
let bootstrapped: TLWE = bootstrapping(&param, btk, table, c);
println!("bootstrapping took: {:?}", start.elapsed());
let p_recovered: T64 = bootstrapped.decrypt(&sk_tlwe);
let m_recovered = TLWE::<KN>::decode::<T>(&p_recovered);
dbg!(&m_recovered);
let m_recovered = TLWE::decode(&param.lwe(), &p_recovered);
assert_eq!(m_recovered, m);
Ok(())

Loading…
Cancel
Save