tfhe: get rid of constant generics

This commit is contained in:
2025-08-13 19:31:43 +00:00
parent 2a9cbc71de
commit bb3288f211
16 changed files with 729 additions and 582 deletions

View File

@@ -19,7 +19,6 @@ pub mod tuple_ring;
pub mod ntt; pub mod ntt;
// expose objects // expose objects
pub use complex::C; pub use complex::C;
pub use matrix::Matrix; pub use matrix::Matrix;
pub use torus::T64; pub use torus::T64;

View File

@@ -6,11 +6,7 @@
//! generics; but once using real-world parameters, the stack could not handle //! generics; but once using real-world parameters, the stack could not handle
//! it, so moved to use Vec instead of fixed-sized arrays, and adapted the NTT //! it, so moved to use Vec instead of fixed-sized arrays, and adapted the NTT
//! implementation to that too. //! implementation to that too.
use crate::{ use crate::{ring::RingParam, ring_nq::Rq, zq::Zq};
ring::{Ring, RingParam},
ring_nq::Rq,
zq::Zq,
};
use std::collections::HashMap; use std::collections::HashMap;
@@ -197,6 +193,7 @@ const fn const_inv_mod(q: u64, x: u64) -> u64 {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::Ring;
use anyhow::Result; use anyhow::Result;

View File

@@ -1,16 +1,11 @@
//! Polynomial ring Z[X]/(X^N+1) //! Polynomial ring Z[X]/(X^N+1)
//! //!
use anyhow::Result;
use itertools::zip_eq; use itertools::zip_eq;
use rand::{distributions::Distribution, Rng}; use rand::{distributions::Distribution, Rng};
use std::array;
use std::borrow::Borrow;
use std::fmt; use std::fmt;
use std::iter::Sum; use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign};
use crate::Ring;
// TODO rename to not have name conflicts with the Ring trait (R: Ring) // TODO rename to not have name conflicts with the Ring trait (R: Ring)
// PolynomialRing element, where the PolynomialRing is R = Z[X]/(X^n +1) // PolynomialRing element, where the PolynomialRing is R = Z[X]/(X^n +1)

View File

@@ -4,8 +4,6 @@
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use itertools::zip_eq; use itertools::zip_eq;
use rand::{distributions::Distribution, Rng}; use rand::{distributions::Distribution, Rng};
use std::array;
use std::borrow::Borrow;
use std::fmt; use std::fmt;
use std::iter::Sum; use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign}; use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign};

View File

@@ -9,7 +9,6 @@
use itertools::zip_eq; use itertools::zip_eq;
use rand::{distributions::Distribution, Rng}; use rand::{distributions::Distribution, Rng};
use std::array;
use std::iter::Sum; use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign}; use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign};

View File

@@ -1,15 +1,9 @@
//! This file implements the struct for an Tuple of Ring Rq elements and its //! This file implements the struct for an Tuple of Ring Rq elements and its
//! operations, which are performed element-wise. //! operations, which are performed element-wise.
use anyhow::Result;
use itertools::zip_eq; use itertools::zip_eq;
use rand::{distributions::Distribution, Rng}; use rand::{distributions::Distribution, Rng};
use rand_distr::{Normal, Uniform}; use std::ops::{Add, Mul, Neg, Sub};
use std::iter::Sum;
use std::{
array,
ops::{Add, Mul, Neg, Sub},
};
use crate::{Ring, RingParam}; use crate::{Ring, RingParam};
@@ -28,23 +22,23 @@ impl<R: Ring> TR<R> {
assert_eq!(r.len(), k); assert_eq!(r.len(), k);
Self { k, r } Self { k, r }
} }
pub fn zero(k: usize, r_params: &RingParam) -> Self { pub fn zero(k: usize, r_param: &RingParam) -> Self {
Self { Self {
k, k,
r: (0..k).into_iter().map(|_| R::zero(r_params)).collect(), r: (0..k).into_iter().map(|_| R::zero(r_param)).collect(),
} }
} }
pub fn rand( pub fn rand(
mut rng: impl Rng, mut rng: impl Rng,
dist: impl Distribution<f64>, dist: impl Distribution<f64>,
k: usize, k: usize,
r_params: &RingParam, r_param: &RingParam,
) -> Self { ) -> Self {
Self { Self {
k, k,
r: (0..k) r: (0..k)
.into_iter() .into_iter()
.map(|_| R::rand(&mut rng, &dist, r_params)) .map(|_| R::rand(&mut rng, &dist, r_param))
.collect(), .collect(),
} }
} }

View File

@@ -1,5 +1,4 @@
use rand::{distributions::Distribution, Rng}; use rand::{distributions::Distribution, Rng};
use std::borrow::Borrow;
use std::fmt; use std::fmt;
use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub, SubAssign}; use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub, SubAssign};

View File

@@ -23,7 +23,7 @@ pub struct Param {
p: u64, p: u64,
} }
impl Param { impl Param {
// returns the plaintext params // returns the plaintext param
pub fn pt(&self) -> RingParam { pub fn pt(&self) -> RingParam {
RingParam { RingParam {
q: self.t, q: self.t,
@@ -117,7 +117,7 @@ impl BFV {
// const DELTA: u64 = Q / T; // floor // const DELTA: u64 = Q / T; // floor
/// generate a new key pair (privK, pubK) /// generate a new key pair (privK, pubK)
pub fn new_key(mut rng: impl Rng, params: &Param) -> Result<(SecretKey, PublicKey)> { pub fn new_key(mut rng: impl Rng, param: &Param) -> Result<(SecretKey, PublicKey)> {
// WIP: review probabilities // WIP: review probabilities
// let Xi_key = Uniform::new(-1_f64, 1_f64); // let Xi_key = Uniform::new(-1_f64, 1_f64);
@@ -126,37 +126,37 @@ impl BFV {
// secret key // secret key
// let mut s = Rq::rand_f64(&mut rng, Xi_key)?; // let mut s = Rq::rand_f64(&mut rng, Xi_key)?;
let mut s = Rq::rand_u64(&mut rng, Xi_key, &params.ring)?; let mut s = Rq::rand_u64(&mut rng, Xi_key, &param.ring)?;
// since s is going to be multiplied by other Rq elements, already // since s is going to be multiplied by other Rq elements, already
// compute its NTT // compute its NTT
s.compute_evals(); s.compute_evals();
// pk = (-a * s + e, a) // pk = (-a * s + e, a)
let a = Rq::rand_u64(&mut rng, Uniform::new(0_u64, params.ring.q), &params.ring)?; let a = Rq::rand_u64(&mut rng, Uniform::new(0_u64, param.ring.q), &param.ring)?;
let e = Rq::rand_f64(&mut rng, Xi_err, &params.ring)?; let e = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
let pk: PublicKey = PublicKey(&(&(-a.clone()) * &s) + &e, a.clone()); // TODO rm clones let pk: PublicKey = PublicKey(&(&(-a.clone()) * &s) + &e, a.clone()); // TODO rm clones
Ok((SecretKey(s), pk)) Ok((SecretKey(s), pk))
} }
// note: m is modulus t // note: m is modulus t
pub fn encrypt(mut rng: impl Rng, params: &Param, pk: &PublicKey, m: &Rq) -> Result<RLWE> { pub fn encrypt(mut rng: impl Rng, param: &Param, pk: &PublicKey, m: &Rq) -> Result<RLWE> {
// assert params & inputs // assert param & inputs
debug_assert_eq!(params.ring, pk.0.param); debug_assert_eq!(param.ring, pk.0.param);
debug_assert_eq!(params.t, m.param.q); debug_assert_eq!(param.t, m.param.q);
debug_assert_eq!(params.ring.n, m.param.n); debug_assert_eq!(param.ring.n, m.param.n);
let Xi_key = Uniform::new(-1_f64, 1_f64); let Xi_key = Uniform::new(-1_f64, 1_f64);
// let Xi_key = Uniform::new(0_u64, 2_u64); // let Xi_key = Uniform::new(0_u64, 2_u64);
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
let u = Rq::rand_f64(&mut rng, Xi_key, &params.ring)?; let u = Rq::rand_f64(&mut rng, Xi_key, &param.ring)?;
// let u = Rq::rand_u64(&mut rng, Xi_key)?; // let u = Rq::rand_u64(&mut rng, Xi_key)?;
let e_1 = Rq::rand_f64(&mut rng, Xi_err, &params.ring)?; let e_1 = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
let e_2 = Rq::rand_f64(&mut rng, Xi_err, &params.ring)?; let e_2 = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
// migrate m's coeffs to the bigger modulus Q (from T) // migrate m's coeffs to the bigger modulus Q (from T)
let m = m.remodule(params.ring.q); let m = m.remodule(param.ring.q);
let c0 = &pk.0 * &u + e_1 + m * (params.ring.q / params.t); // floor(q/t)=DELTA let c0 = &pk.0 * &u + e_1 + m * (param.ring.q / param.t); // floor(q/t)=DELTA
let c1 = &pk.1 * &u + e_2; let c1 = &pk.1 * &u + e_2;
Ok(RLWE(c0, c1)) Ok(RLWE(c0, c1))
} }
@@ -280,7 +280,7 @@ mod tests {
#[test] #[test]
fn test_encrypt_decrypt() -> Result<()> { fn test_encrypt_decrypt() -> Result<()> {
let params = Param { let param = Param {
ring: RingParam { ring: RingParam {
q: 2u64.pow(16) + 1, // q prime, and 2^q + 1 shape q: 2u64.pow(16) + 1, // q prime, and 2^q + 1 shape
n: 512, n: 512,
@@ -292,13 +292,13 @@ mod tests {
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
for _ in 0..100 { for _ in 0..100 {
let (sk, pk) = BFV::new_key(&mut rng, &params)?; let (sk, pk) = BFV::new_key(&mut rng, &param)?;
let msg_dist = Uniform::new(0_u64, params.t); let msg_dist = Uniform::new(0_u64, param.t);
let m = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?; let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let c = BFV::encrypt(&mut rng, &params, &pk, &m)?; let c = BFV::encrypt(&mut rng, &param, &pk, &m)?;
let m_recovered = BFV::decrypt(&params, &sk, &c); let m_recovered = BFV::decrypt(&param, &sk, &c);
assert_eq!(m, m_recovered); assert_eq!(m, m_recovered);
} }
@@ -308,7 +308,7 @@ mod tests {
#[test] #[test]
fn test_addition() -> Result<()> { fn test_addition() -> Result<()> {
let params = Param { let param = Param {
ring: RingParam { ring: RingParam {
q: 2u64.pow(16) + 1, // q prime, and 2^q + 1 shape q: 2u64.pow(16) + 1, // q prime, and 2^q + 1 shape
n: 128, n: 128,
@@ -320,18 +320,18 @@ mod tests {
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
for _ in 0..100 { for _ in 0..100 {
let (sk, pk) = BFV::new_key(&mut rng, &params)?; let (sk, pk) = BFV::new_key(&mut rng, &param)?;
let msg_dist = Uniform::new(0_u64, params.t); let msg_dist = Uniform::new(0_u64, param.t);
let m1 = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?; let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?; let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let c1 = BFV::encrypt(&mut rng, &params, &pk, &m1)?; let c1 = BFV::encrypt(&mut rng, &param, &pk, &m1)?;
let c2 = BFV::encrypt(&mut rng, &params, &pk, &m2)?; let c2 = BFV::encrypt(&mut rng, &param, &pk, &m2)?;
let c3 = c1 + c2; let c3 = c1 + c2;
let m3_recovered = BFV::decrypt(&params, &sk, &c3); let m3_recovered = BFV::decrypt(&param, &sk, &c3);
assert_eq!(m1 + m2, m3_recovered); assert_eq!(m1 + m2, m3_recovered);
} }
@@ -342,7 +342,7 @@ mod tests {
#[test] #[test]
fn test_constant_add_mul() -> Result<()> { fn test_constant_add_mul() -> Result<()> {
let q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape let q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape
let params = Param { let param = Param {
ring: RingParam { q, n: 16 }, ring: RingParam { q, n: 16 },
t: 8, // plaintext modulus t: 8, // plaintext modulus
p: q * q, p: q * q,
@@ -350,26 +350,26 @@ mod tests {
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let (sk, pk) = BFV::new_key(&mut rng, &params)?; let (sk, pk) = BFV::new_key(&mut rng, &param)?;
let msg_dist = Uniform::new(0_u64, params.t); let msg_dist = Uniform::new(0_u64, param.t);
let m1 = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?; let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2_const = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?; let m2_const = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let c1 = BFV::encrypt(&mut rng, &params, &pk, &m1)?; let c1 = BFV::encrypt(&mut rng, &param, &pk, &m1)?;
let c3_add = &c1 + &m2_const; let c3_add = &c1 + &m2_const;
let m3_add_recovered = BFV::decrypt(&params, &sk, &c3_add); let m3_add_recovered = BFV::decrypt(&param, &sk, &c3_add);
assert_eq!(&m1 + &m2_const, m3_add_recovered); assert_eq!(&m1 + &m2_const, m3_add_recovered);
// test multiplication of a ciphertext by a constant // test multiplication of a ciphertext by a constant
let rlk = BFV::rlk_key(&mut rng, &params, &sk)?; let rlk = BFV::rlk_key(&mut rng, &param, &sk)?;
let c3_mul = BFV::mul_const(&rlk, &c1, &m2_const); let c3_mul = BFV::mul_const(&rlk, &c1, &m2_const);
let m3_mul_recovered = BFV::decrypt(&params, &sk, &c3_mul); let m3_mul_recovered = BFV::decrypt(&param, &sk, &c3_mul);
assert_eq!( assert_eq!(
(m1.to_r() * m2_const.to_r()).to_rq(params.t).coeffs(), (m1.to_r() * m2_const.to_r()).to_rq(param.t).coeffs(),
m3_mul_recovered.coeffs() m3_mul_recovered.coeffs()
); );
@@ -380,7 +380,7 @@ mod tests {
// TMP WIP // TMP WIP
#[test] #[test]
#[ignore] #[ignore]
fn test_params() -> Result<()> { fn test_param() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape const Q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape
const N: usize = 32; const N: usize = 32;
const T: u64 = 8; // plaintext modulus const T: u64 = 8; // plaintext modulus
@@ -504,30 +504,30 @@ mod tests {
#[test] #[test]
fn test_tensor() -> Result<()> { fn test_tensor() -> Result<()> {
let q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape let q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape
let params = Param { let param = Param {
ring: RingParam { q, n: 16 }, ring: RingParam { q, n: 16 },
t: 2, // plaintext modulus t: 2, // plaintext modulus
p: q * q, p: q * q,
}; };
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, params.t); let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..1_000 { for _ in 0..1_000 {
let m1 = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?; let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?; let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
test_tensor_opt(&mut rng, &params, m1, m2)?; test_tensor_opt(&mut rng, &param, m1, m2)?;
} }
Ok(()) Ok(())
} }
fn test_tensor_opt(mut rng: impl Rng, params: &Param, m1: Rq, m2: Rq) -> Result<()> { fn test_tensor_opt(mut rng: impl Rng, param: &Param, m1: Rq, m2: Rq) -> Result<()> {
let (sk, pk) = BFV::new_key(&mut rng, &params)?; let (sk, pk) = BFV::new_key(&mut rng, &param)?;
let c1 = BFV::encrypt(&mut rng, &params, &pk, &m1)?; let c1 = BFV::encrypt(&mut rng, &param, &pk, &m1)?;
let c2 = BFV::encrypt(&mut rng, &params, &pk, &m2)?; let c2 = BFV::encrypt(&mut rng, &param, &pk, &m2)?;
let (c_a, c_b, c_c) = RLWE::tensor(params.t, &c1, &c2); let (c_a, c_b, c_c) = RLWE::tensor(param.t, &c1, &c2);
// let (c_a, c_b, c_c) = RLWE::tensor_new::<PQ, T>(&c1, &c2); // let (c_a, c_b, c_c) = RLWE::tensor_new::<PQ, T>(&c1, &c2);
// decrypt non-relinearized mul result // decrypt non-relinearized mul result
@@ -539,10 +539,10 @@ mod tests {
// &c_c.to_r(), // &c_c.to_r(),
// &R::<N>::from_vec(arith::ring_n::naive_mul(&sk.0.to_r(), &sk.0.to_r())), // &R::<N>::from_vec(arith::ring_n::naive_mul(&sk.0.to_r(), &sk.0.to_r())),
// )); // ));
let m3: Rq = m3.mul_div_round(params.t, params.ring.q); // descale let m3: Rq = m3.mul_div_round(param.t, param.ring.q); // descale
let m3 = m3.remodule(params.t); let m3 = m3.remodule(param.t);
let naive = (m1.clone().to_r() * m2.clone().to_r()).to_rq(params.t); // TODO rm clones let naive = (m1.clone().to_r() * m2.clone().to_r()).to_rq(param.t); // TODO rm clones
assert_eq!( assert_eq!(
m3.coeffs().to_vec(), m3.coeffs().to_vec(),
naive.coeffs().to_vec(), naive.coeffs().to_vec(),
@@ -557,38 +557,38 @@ mod tests {
#[test] #[test]
fn test_mul_relin() -> Result<()> { fn test_mul_relin() -> Result<()> {
let q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape let q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape
let params = Param { let param = Param {
ring: RingParam { q, n: 16 }, ring: RingParam { q, n: 16 },
t: 2, // plaintext modulus t: 2, // plaintext modulus
p: q * q, p: q * q,
}; };
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, params.t); let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..1_000 { for _ in 0..1_000 {
let m1 = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?; let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?; let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
test_mul_relin_opt(&mut rng, &params, m1, m2)?; test_mul_relin_opt(&mut rng, &param, m1, m2)?;
} }
Ok(()) Ok(())
} }
fn test_mul_relin_opt(mut rng: impl Rng, params: &Param, m1: Rq, m2: Rq) -> Result<()> { fn test_mul_relin_opt(mut rng: impl Rng, param: &Param, m1: Rq, m2: Rq) -> Result<()> {
let (sk, pk) = BFV::new_key(&mut rng, &params)?; let (sk, pk) = BFV::new_key(&mut rng, &param)?;
let rlk = BFV::rlk_key(&mut rng, &params, &sk)?; let rlk = BFV::rlk_key(&mut rng, &param, &sk)?;
let c1 = BFV::encrypt(&mut rng, &params, &pk, &m1)?; let c1 = BFV::encrypt(&mut rng, &param, &pk, &m1)?;
let c2 = BFV::encrypt(&mut rng, &params, &pk, &m2)?; let c2 = BFV::encrypt(&mut rng, &param, &pk, &m2)?;
let c3 = RLWE::mul(params.t, &rlk, &c1, &c2); // uses relinearize internally let c3 = RLWE::mul(param.t, &rlk, &c1, &c2); // uses relinearize internally
let m3 = BFV::decrypt(&params, &sk, &c3); let m3 = BFV::decrypt(&param, &sk, &c3);
let naive = (m1.clone().to_r() * m2.clone().to_r()).to_rq(params.t); // TODO rm clones let naive = (m1.clone().to_r() * m2.clone().to_r()).to_rq(param.t); // TODO rm clones
assert_eq!( assert_eq!(
m3.coeffs().to_vec(), m3.coeffs().to_vec(),
naive.coeffs().to_vec(), naive.coeffs().to_vec(),

View File

@@ -19,7 +19,7 @@ pub use encoder::Encoder;
const ERR_SIGMA: f64 = 3.2; const ERR_SIGMA: f64 = 3.2;
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug)]
pub struct Params { pub struct Param {
ring: RingParam, ring: RingParam,
t: u64, t: u64,
} }
@@ -30,33 +30,33 @@ pub struct PublicKey(Rq, Rq);
pub struct SecretKey(Rq); pub struct SecretKey(Rq);
pub struct CKKS { pub struct CKKS {
params: Params, param: Param,
encoder: Encoder, encoder: Encoder,
} }
impl CKKS { impl CKKS {
pub fn new(params: &Params, delta: C<f64>) -> Self { pub fn new(param: &Param, delta: C<f64>) -> Self {
let encoder = Encoder::new(params.ring.n, delta); let encoder = Encoder::new(param.ring.n, delta);
Self { Self {
params: params.clone(), param: param.clone(),
encoder, encoder,
} }
} }
/// generate a new key pair (privK, pubK) /// generate a new key pair (privK, pubK)
pub fn new_key(&self, mut rng: impl Rng) -> Result<(SecretKey, PublicKey)> { pub fn new_key(&self, mut rng: impl Rng) -> Result<(SecretKey, PublicKey)> {
let params = &self.params; let param = &self.param;
let Xi_key = Uniform::new(-1_f64, 1_f64); let Xi_key = Uniform::new(-1_f64, 1_f64);
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
let e = Rq::rand_f64(&mut rng, Xi_err, &params.ring)?; let e = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
let mut s = Rq::rand_f64(&mut rng, Xi_key, &params.ring)?; let mut s = Rq::rand_f64(&mut rng, Xi_key, &param.ring)?;
// since s is going to be multiplied by other Rq elements, already // since s is going to be multiplied by other Rq elements, already
// compute its NTT // compute its NTT
s.compute_evals(); s.compute_evals();
let a = Rq::rand_f64(&mut rng, Xi_key, &params.ring)?; let a = Rq::rand_f64(&mut rng, Xi_key, &param.ring)?;
let pk: PublicKey = PublicKey((&(-a.clone()) * &s) + e, a.clone()); // TODO rm clones let pk: PublicKey = PublicKey((&(-a.clone()) * &s) + e, a.clone()); // TODO rm clones
Ok((SecretKey(s), pk)) Ok((SecretKey(s), pk))
@@ -69,17 +69,17 @@ impl CKKS {
pk: &PublicKey, pk: &PublicKey,
m: &R, m: &R,
) -> Result<(Rq, Rq)> { ) -> Result<(Rq, Rq)> {
let params = self.params; let param = self.param;
let Xi_key = Uniform::new(-1_f64, 1_f64); let Xi_key = Uniform::new(-1_f64, 1_f64);
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
let e_0 = Rq::rand_f64(&mut rng, Xi_err, &params.ring)?; let e_0 = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
let e_1 = Rq::rand_f64(&mut rng, Xi_err, &params.ring)?; let e_1 = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
let v = Rq::rand_f64(&mut rng, Xi_key, &params.ring)?; let v = Rq::rand_f64(&mut rng, Xi_key, &param.ring)?;
// let m: Rq = Rq::from(*m); // let m: Rq = Rq::from(*m);
let m: Rq = m.clone().to_rq(params.ring.q); // TODO rm clone let m: Rq = m.clone().to_rq(param.ring.q); // TODO rm clone
Ok((m + e_0 + &v * &pk.0.clone(), &v * &pk.1 + e_1)) Ok((m + e_0 + &v * &pk.0.clone(), &v * &pk.1 + e_1))
} }
@@ -127,7 +127,7 @@ mod tests {
let q: u64 = 2u64.pow(16) + 1; let q: u64 = 2u64.pow(16) + 1;
let n: usize = 32; let n: usize = 32;
let t: u64 = 50; let t: u64 = 50;
let params = Params { let param = Param {
ring: RingParam { q, n }, ring: RingParam { q, n },
t, t,
}; };
@@ -137,12 +137,12 @@ mod tests {
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
for _ in 0..1000 { for _ in 0..1000 {
let ckks = CKKS::new(&params, scale_factor); let ckks = CKKS::new(&param, scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?; let (sk, pk) = ckks.new_key(&mut rng)?;
let m_raw: R = let m_raw: R =
Rq::rand_f64(&mut rng, Uniform::new(0_f64, t as f64), &params.ring)?.to_r(); Rq::rand_f64(&mut rng, Uniform::new(0_f64, t as f64), &param.ring)?.to_r();
let m = &m_raw * &scale_factor_u64; let m = &m_raw * &scale_factor_u64;
let ct = ckks.encrypt(&mut rng, &pk, &m)?; let ct = ckks.encrypt(&mut rng, &pk, &m)?;
@@ -153,7 +153,7 @@ mod tests {
.iter() .iter()
.map(|e| (*e as f64 / (scale_factor_u64 as f64)).round() as u64) .map(|e| (*e as f64 / (scale_factor_u64 as f64)).round() as u64)
.collect(); .collect();
let m_decrypted = Rq::from_vec_u64(&params.ring, m_decrypted); let m_decrypted = Rq::from_vec_u64(&param.ring, m_decrypted);
// assert_eq!(m_decrypted, Rq::from(m_raw)); // assert_eq!(m_decrypted, Rq::from(m_raw));
assert_eq!(m_decrypted, m_raw.to_rq(q)); assert_eq!(m_decrypted, m_raw.to_rq(q));
} }
@@ -166,7 +166,7 @@ mod tests {
let q: u64 = 2u64.pow(16) + 1; let q: u64 = 2u64.pow(16) + 1;
let n: usize = 16; let n: usize = 16;
let t: u64 = 8; let t: u64 = 8;
let params = Params { let param = Param {
ring: RingParam { q, n }, ring: RingParam { q, n },
t, t,
}; };
@@ -175,7 +175,7 @@ mod tests {
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
for _ in 0..1000 { for _ in 0..1000 {
let ckks = CKKS::new(&params, scale_factor); let ckks = CKKS::new(&param, scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?; let (sk, pk) = ckks.new_key(&mut rng)?;
let z: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, t)) let z: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, t))
@@ -215,7 +215,7 @@ mod tests {
let q: u64 = 2u64.pow(16) + 1; let q: u64 = 2u64.pow(16) + 1;
let n: usize = 16; let n: usize = 16;
let t: u64 = 8; let t: u64 = 8;
let params = Params { let param = Param {
ring: RingParam { q, n }, ring: RingParam { q, n },
t, t,
}; };
@@ -224,7 +224,7 @@ mod tests {
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
for _ in 0..1000 { for _ in 0..1000 {
let ckks = CKKS::new(&params, scale_factor); let ckks = CKKS::new(&param, scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?; let (sk, pk) = ckks.new_key(&mut rng)?;
@@ -261,8 +261,8 @@ mod tests {
fn test_sub() -> Result<()> { fn test_sub() -> Result<()> {
let q: u64 = 2u64.pow(16) + 1; let q: u64 = 2u64.pow(16) + 1;
let n: usize = 16; let n: usize = 16;
let t: u64 = 4; let t: u64 = 2;
let params = Params { let param = Param {
ring: RingParam { q, n }, ring: RingParam { q, n },
t, t,
}; };
@@ -271,7 +271,7 @@ mod tests {
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
for _ in 0..1000 { for _ in 0..1000 {
let ckks = CKKS::new(&params, scale_factor); let ckks = CKKS::new(&param, scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?; let (sk, pk) = ckks.new_key(&mut rng)?;

View File

@@ -69,7 +69,7 @@ impl<R: Ring> GLev<R> {
impl<R: Ring> Mul<Vec<R>> for GLev<R> { impl<R: Ring> Mul<Vec<R>> for GLev<R> {
type Output = GLWE<R>; type Output = GLWE<R>;
fn mul(self, v: Vec<R>) -> GLWE<R> { fn mul(self, v: Vec<R>) -> GLWE<R> {
// TODO debug_assert_eq of params // TODO debug_assert_eq of param
// l times GLWES // l times GLWES
let glwes: Vec<GLWE<R>> = self.0; let glwes: Vec<GLWE<R>> = self.0;

View File

@@ -22,13 +22,30 @@ pub struct Param {
pub t: u64, pub t: u64,
} }
impl Param { impl Param {
// returns the plaintext params /// returns the plaintext param
pub fn pt(&self) -> RingParam { pub fn pt(&self) -> RingParam {
// TODO think if maybe return a new truct "PtParam" to differenciate
// between the ciphertexxt (RingParam) and the plaintext param. Maybe it
// can be just a wrapper on top of RingParam.
RingParam { RingParam {
q: self.t, q: self.t,
n: self.ring.n, n: self.ring.n,
} }
} }
/// returns the LWE param for the given GLWE (self), that is, it uses k=K*N
/// as the length for the secret key. This follows [2018-421] where
/// TLWE sk: s \in B^n , where n=K*N
/// TRLWE sk: s \in B_N[X]^K
pub fn lwe(&self) -> Self {
Self {
ring: RingParam {
q: self.ring.q,
n: 1,
},
k: self.k * self.ring.n,
t: self.t,
}
}
} }
/// GLWE implemented over the `Ring` trait, so that it can be also instantiated /// GLWE implemented over the `Ring` trait, so that it can be also instantiated
@@ -46,8 +63,8 @@ pub struct PublicKey<R: Ring>(pub R, pub TR<R>);
pub struct KSK<R: Ring>(Vec<GLev<R>>); pub struct KSK<R: Ring>(Vec<GLev<R>>);
impl<R: Ring> GLWE<R> { impl<R: Ring> GLWE<R> {
pub fn zero(k: usize, params: &RingParam) -> Self { pub fn zero(k: usize, param: &RingParam) -> Self {
Self(TR::zero(k, &params), R::zero(&params)) Self(TR::zero(k, &param), R::zero(&param))
} }
pub fn from_plaintext(k: usize, param: &RingParam, p: R) -> Self { pub fn from_plaintext(k: usize, param: &RingParam, p: R) -> Self {
Self(TR::zero(k, &param), p) Self(TR::zero(k, &param), p)
@@ -187,6 +204,9 @@ impl GLWE<Rq> {
impl<R: Ring> Add<GLWE<R>> for GLWE<R> { impl<R: Ring> Add<GLWE<R>> for GLWE<R> {
type Output = Self; type Output = Self;
fn add(self, other: Self) -> Self { fn add(self, other: Self) -> Self {
debug_assert_eq!(self.0.k, other.0.k);
debug_assert_eq!(self.1.param(), other.1.param());
let a: TR<R> = self.0 + other.0; let a: TR<R> = self.0 + other.0;
let b: R = self.1 + other.1; let b: R = self.1 + other.1;
Self(a, b) Self(a, b)
@@ -196,6 +216,8 @@ impl<R: Ring> Add<GLWE<R>> for GLWE<R> {
impl<R: Ring> Add<R> for GLWE<R> { impl<R: Ring> Add<R> for GLWE<R> {
type Output = Self; type Output = Self;
fn add(self, plaintext: R) -> Self { fn add(self, plaintext: R) -> Self {
debug_assert_eq!(self.1.param(), plaintext.param());
let a: TR<R> = self.0; let a: TR<R> = self.0;
let b: R = self.1 + plaintext; let b: R = self.1 + plaintext;
Self(a, b) Self(a, b)
@@ -231,6 +253,9 @@ impl<R: Ring> Sum<GLWE<R>> for GLWE<R> {
impl<R: Ring> Sub<GLWE<R>> for GLWE<R> { impl<R: Ring> Sub<GLWE<R>> for GLWE<R> {
type Output = Self; type Output = Self;
fn sub(self, other: Self) -> Self { fn sub(self, other: Self) -> Self {
debug_assert_eq!(self.0.k, other.0.k);
debug_assert_eq!(self.1.param(), other.1.param());
let a: TR<R> = self.0 - other.0; let a: TR<R> = self.0 - other.0;
let b: R = self.1 - other.1; let b: R = self.1 - other.1;
Self(a, b) Self(a, b)
@@ -240,6 +265,8 @@ impl<R: Ring> Sub<GLWE<R>> for GLWE<R> {
impl<R: Ring> Mul<R> for GLWE<R> { impl<R: Ring> Mul<R> for GLWE<R> {
type Output = Self; type Output = Self;
fn mul(self, plaintext: R) -> Self { fn mul(self, plaintext: R) -> Self {
debug_assert_eq!(self.1.param(), plaintext.param());
let a: TR<R> = TR { let a: TR<R> = TR {
k: self.0.k, k: self.0.k,
r: self r: self
@@ -351,8 +378,7 @@ mod tests {
} }
} }
pub fn t_decode(param: &Param, pt: &Tn) -> Rq { pub fn t_decode(param: &Param, pt: &Tn) -> Rq {
let p = param.t; let pt = pt.mul_div_round(param.t, u64::MAX);
let pt = pt.mul_div_round(p, u64::MAX);
Rq::from_vec_u64(&param.pt(), pt.coeffs().iter().map(|c| c.0).collect()) Rq::from_vec_u64(&param.pt(), pt.coeffs().iter().map(|c| c.0).collect())
} }
#[test] #[test]

View File

@@ -4,53 +4,57 @@ use rand::Rng;
use std::array; use std::array;
use std::ops::{Add, Mul}; use std::ops::{Add, Mul};
use arith::{Ring, Rq, Tn, T64, TR}; use arith::{Ring, RingParam, Rq, Tn, T64, TR};
use crate::tglwe::{PublicKey, SecretKey, TGLWE}; use crate::tglwe::{PublicKey, SecretKey, TGLWE};
use gfhe::glwe::GLWE; use gfhe::glwe::{Param, GLWE};
/// vector of length K+1 = ([K * TGLev], [1 * TGLev]) /// vector of length K+1 = ([K * TGLev], [1 * TGLev])
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct TGGSW<const N: usize, const K: usize>(pub(crate) Vec<TGLev<N, K>>, TGLev<N, K>); pub struct TGGSW(pub(crate) Vec<TGLev>, TGLev);
impl<const N: usize, const K: usize> TGGSW<N, K> { impl TGGSW {
pub fn encrypt_s( pub fn encrypt_s(
mut rng: impl Rng, mut rng: impl Rng,
param: &Param,
beta: u32, beta: u32,
l: u32, l: u32,
sk: &SecretKey<N, K>, sk: &SecretKey,
m: &Tn<N>, m: &Tn,
) -> Result<Self> { ) -> Result<Self> {
let a: Vec<TGLev<N, K>> = (0..K) debug_assert_eq!(sk.0 .0.k, param.k);
.map(|i| TGLev::encrypt_s(&mut rng, beta, l, sk, &(-sk.0 .0 .0[i] * *m)))
let a: Vec<TGLev> = (0..param.k)
.map(|i| TGLev::encrypt_s(&mut rng, param, beta, l, sk, &(&-sk.0 .0.r[i].clone() * m)))
// TODO rm clone
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
let b: TGLev<N, K> = TGLev::encrypt_s(&mut rng, beta, l, sk, m)?; let b: TGLev = TGLev::encrypt_s(&mut rng, &param, beta, l, sk, m)?;
Ok(Self(a, b)) Ok(Self(a, b))
} }
pub fn decrypt(&self, sk: &SecretKey<N, K>, beta: u32) -> Tn<N> { pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> Tn {
self.1.decrypt(sk, beta) self.1.decrypt(sk, beta)
} }
pub fn cmux(bit: Self, ct1: TGLWE<N, K>, ct2: TGLWE<N, K>) -> TGLWE<N, K> { pub fn cmux(bit: Self, ct1: TGLWE, ct2: TGLWE) -> TGLWE {
ct1.clone() + (bit * (ct2 - ct1)) ct1.clone() + (bit * (ct2 - ct1))
} }
} }
/// External product TGGSW x TGLWE /// External product tggsw x tglwe
impl<const N: usize, const K: usize> Mul<TGLWE<N, K>> for TGGSW<N, K> { impl Mul<TGLWE> for TGGSW {
type Output = TGLWE<N, K>; type Output = TGLWE;
fn mul(self, tglwe: TGLWE<N, K>) -> TGLWE<N, K> { fn mul(self, tglwe: TGLWE) -> TGLWE {
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 64; // TODO wip let l: u32 = 64; // TODO wip
let tglwe_ab: Vec<Tn<N>> = [tglwe.0 .0 .0.clone(), vec![tglwe.0 .1]].concat(); let tglwe_ab: Vec<Tn> = [tglwe.0 .0.r.clone(), vec![tglwe.0 .1]].concat();
let tgsw_ab: Vec<TGLev<N, K>> = [self.0.clone(), vec![self.1]].concat(); let tgsw_ab: Vec<TGLev> = [self.0.clone(), vec![self.1]].concat();
assert_eq!(tgsw_ab.len(), tglwe_ab.len()); assert_eq!(tgsw_ab.len(), tglwe_ab.len());
let r: TGLWE<N, K> = zip_eq(tgsw_ab, tglwe_ab) let r: TGLWE = zip_eq(tgsw_ab, tglwe_ab)
.map(|(tlev_i, tglwe_i)| tlev_i * tglwe_i.decompose(beta, l)) .map(|(tlev_i, tglwe_i)| tlev_i * tglwe_i.decompose(beta, l))
.sum(); .sum();
r r
@@ -58,26 +62,36 @@ impl<const N: usize, const K: usize> Mul<TGLWE<N, K>> for TGGSW<N, K> {
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct TGLev<const N: usize, const K: usize>(pub(crate) Vec<TGLWE<N, K>>); pub struct TGLev(pub(crate) Vec<TGLWE>);
impl<const N: usize, const K: usize> TGLev<N, K> { impl TGLev {
pub fn encode<const T: u64>(m: &Rq<T, N>) -> Tn<N> { pub fn encode(param: &Param, m: &Rq) -> Tn {
let coeffs = m.coeffs(); debug_assert_eq!(param.t, m.param.q); // plaintext modulus
Tn(array::from_fn(|i| T64(coeffs[i].0)))
Tn {
param: param.ring,
coeffs: m.coeffs().iter().map(|c_i| T64(c_i.v)).collect(),
}
} }
pub fn decode<const T: u64>(p: &Tn<N>) -> Rq<T, N> { pub fn decode(param: &Param, p: &Tn) -> Rq {
Rq::<T, N>::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect()) Rq::from_vec_u64(&param.pt(), p.coeffs().iter().map(|c| c.0).collect())
} }
pub fn encrypt( pub fn encrypt(
mut rng: impl Rng, mut rng: impl Rng,
param: &Param,
beta: u32, beta: u32,
l: u32, l: u32,
pk: &PublicKey<N, K>, pk: &PublicKey,
m: &Tn<N>, m: &Tn,
) -> Result<Self> { ) -> Result<Self> {
let tlev: Vec<TGLWE<N, K>> = (1..l + 1) let tlev: Vec<TGLWE> = (1..l + 1)
.map(|i| { .map(|i| {
TGLWE::<N, K>::encrypt(&mut rng, pk, &(*m * (u64::MAX / beta.pow(i as u32) as u64))) TGLWE::encrypt(
&mut rng,
&param,
pk,
&(m * &(u64::MAX / beta.pow(i as u32) as u64)),
)
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
@@ -85,35 +99,36 @@ impl<const N: usize, const K: usize> TGLev<N, K> {
} }
pub fn encrypt_s( pub fn encrypt_s(
mut rng: impl Rng, mut rng: impl Rng,
param: &Param,
_beta: u32, // TODO rm, and make beta=2 always _beta: u32, // TODO rm, and make beta=2 always
l: u32, l: u32,
sk: &SecretKey<N, K>, sk: &SecretKey,
m: &Tn<N>, m: &Tn,
) -> Result<Self> { ) -> Result<Self> {
let tlev: Vec<TGLWE<N, K>> = (1..l as u64 + 1) let tlev: Vec<TGLWE> = (1..l as u64 + 1)
.map(|i| { .map(|i| {
let aux = if i < 64 { let aux = if i < 64 {
*m * (u64::MAX / (1u64 << i)) m * &(u64::MAX / (1u64 << i))
} else { } else {
// 1<<64 would overflow, and anyways we're dividing u64::MAX // 1<<64 would overflow, and anyways we're dividing u64::MAX
// by it, which would be equal to 1 // by it, which would be equal to 1
*m m.clone() // TODO rm clone
}; };
TGLWE::<N, K>::encrypt_s(&mut rng, sk, &aux) TGLWE::encrypt_s(&mut rng, &param, sk, &aux)
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
Ok(Self(tlev)) Ok(Self(tlev))
} }
pub fn decrypt(&self, sk: &SecretKey<N, K>, beta: u32) -> Tn<N> { pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> Tn {
let pt = self.0[0].decrypt(sk); let pt = self.0[0].decrypt(sk);
pt.mul_div_round(beta as u64, u64::MAX) pt.mul_div_round(beta as u64, u64::MAX)
} }
} }
impl<const N: usize, const K: usize> TGLev<N, K> { impl TGLev {
pub fn iter(&self) -> std::slice::Iter<TGLWE<N, K>> { pub fn iter(&self) -> std::slice::Iter<TGLWE> {
self.0.iter() self.0.iter()
} }
} }
@@ -121,14 +136,14 @@ impl<const N: usize, const K: usize> TGLev<N, K> {
// dot product between a TGLev and Vec<Tn<N>>, usually Vec<Tn<N>> comes from a // dot product between a TGLev and Vec<Tn<N>>, usually Vec<Tn<N>> comes from a
// decomposition of Tn<N> // decomposition of Tn<N>
// TGLev * Vec<Tn<N>> --> TGLWE // TGLev * Vec<Tn<N>> --> TGLWE
impl<const N: usize, const K: usize> Mul<Vec<Tn<N>>> for TGLev<N, K> { impl Mul<Vec<Tn>> for TGLev {
type Output = TGLWE<N, K>; type Output = TGLWE;
fn mul(self, v: Vec<Tn<N>>) -> Self::Output { fn mul(self, v: Vec<Tn>) -> Self::Output {
assert_eq!(self.0.len(), v.len()); assert_eq!(self.0.len(), v.len());
// l TGLWES // l TGLWES
let tlwes: Vec<TGLWE<N, K>> = self.0; let tlwes: Vec<TGLWE> = self.0;
let r: TGLWE<N, K> = zip_eq(v, tlwes).map(|(a_d_i, glwe_i)| glwe_i * a_d_i).sum(); let r: TGLWE = zip_eq(v, tlwes).map(|(a_d_i, glwe_i)| glwe_i * a_d_i).sum();
r r
} }
} }
@@ -141,38 +156,39 @@ mod tests {
use super::*; use super::*;
#[test] #[test]
fn test_external_product() -> Result<()> { fn test_external_product() -> Result<()> {
const T: u64 = 16; // plaintext modulus let param = Param {
const K: usize = 4; ring: RingParam { q: u64::MAX, n: 64 },
const N: usize = 64; k: 4,
const KN: usize = K * N; t: 16, // plaintext modulus
};
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 64; let l: u32 = 64;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T); let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..50 { for _ in 0..50 {
let (sk, _) = TGLWE::<N, K>::new_key::<KN>(&mut rng)?; let (sk, _) = TGLWE::new_key(&mut rng, &param)?;
let m1: Rq<T, N> = Rq::rand_u64(&mut rng, msg_dist)?; let m1: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: Tn<N> = TGLev::<N, K>::encode::<T>(&m1); let p1: Tn = TGLev::encode(&param, &m1);
let m2: Rq<T, N> = Rq::rand_u64(&mut rng, msg_dist)?; let m2: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p2: Tn<N> = TGLWE::<N, K>::encode::<T>(&m2); // scaled by delta let p2: Tn = TGLWE::encode(&param, &m2); // scaled by delta
let tgsw = TGGSW::<N, K>::encrypt_s(&mut rng, beta, l, &sk, &p1)?; let tgsw = TGGSW::encrypt_s(&mut rng, &param, beta, l, &sk, &p1)?;
let tlwe = TGLWE::<N, K>::encrypt_s(&mut rng, &sk, &p2)?; let tlwe = TGLWE::encrypt_s(&mut rng, &param, &sk, &p2)?;
let res: TGLWE<N, K> = tgsw * tlwe; let res: TGLWE = tgsw * tlwe;
// let p_recovered = res.decrypt(&sk, beta); // let p_recovered = res.decrypt(&sk, beta);
let p_recovered = res.decrypt(&sk); let p_recovered = res.decrypt(&sk);
// downscaled by delta^-1 // downscaled by delta^-1
let res_recovered = TGLWE::<N, K>::decode::<T>(&p_recovered); let res_recovered = TGLWE::decode(&param, &p_recovered);
// assert_eq!(m1 * m2, m_recovered); // assert_eq!(m1 * m2, m_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq::<T>(), res_recovered); assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), res_recovered);
} }
Ok(()) Ok(())

View File

@@ -7,155 +7,193 @@ use std::array;
use std::iter::Sum; use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, Sub}; use std::ops::{Add, AddAssign, Mul, Sub};
use arith::{Ring, Rq, Tn, T64, TR}; use arith::{Ring, RingParam, Rq, Tn, T64, TR};
use gfhe::{glwe, GLWE}; use gfhe::{glwe, glwe::Param, GLWE};
use crate::tlev::TLev; use crate::tlev::TLev;
use crate::{tlwe, tlwe::TLWE}; use crate::{tlwe, tlwe::TLWE};
// pub type SecretKey<const N: usize, const K: usize> = glwe::SecretKey<Tn<N>, K>; // pub type SecretKey<const N: usize, const K: usize> = glwe::SecretKey<Tn<N>, K>;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct SecretKey<const N: usize, const K: usize>(pub glwe::SecretKey<Tn<N>, K>); pub struct SecretKey(pub glwe::SecretKey<Tn>);
// pub struct SecretKey<const K: usize>(pub tlwe::SecretKey<K>); // pub struct SecretKey<const K: usize>(pub tlwe::SecretKey<K>);
impl<const N: usize, const K: usize> SecretKey<N, K> { impl SecretKey {
pub fn to_tlwe<const KN: usize>(&self) -> tlwe::SecretKey<KN> { pub fn to_tlwe(&self, param: &Param) -> tlwe::SecretKey {
let s: TR<Tn<N>, K> = self.0 .0.clone(); let s: TR<Tn> = self.0 .0.clone();
debug_assert_eq!(s.r.len(), param.k); // sanity check
let r: Vec<Vec<T64>> = s.0.iter().map(|s_i| s_i.coeffs()).collect(); let kn = param.k * param.ring.n;
let r: Vec<Vec<T64>> = s.r.iter().map(|s_i| s_i.coeffs()).collect();
let r: Vec<T64> = r.into_iter().flatten().collect(); let r: Vec<T64> = r.into_iter().flatten().collect();
tlwe::SecretKey::<KN>(glwe::SecretKey::<T64, KN>(TR::<T64, KN>::new(r))) debug_assert_eq!(r.len(), kn); // sanity check
tlwe::SecretKey(glwe::SecretKey::<T64>(TR::<T64>::new(kn, r)))
} }
} }
pub type PublicKey<const N: usize, const K: usize> = glwe::PublicKey<Tn<N>, K>; pub type PublicKey = glwe::PublicKey<Tn>;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct TGLWE<const N: usize, const K: usize>(pub GLWE<Tn<N>, K>); pub struct TGLWE(pub GLWE<Tn>);
impl<const N: usize, const K: usize> TGLWE<N, K> { impl TGLWE {
pub fn zero() -> Self { pub fn zero(k: usize, param: &RingParam) -> Self {
Self(GLWE::<Tn<N>, K>::zero()) Self(GLWE::<Tn>::zero(k, param))
} }
pub fn from_plaintext(p: Tn<N>) -> Self { pub fn from_plaintext(k: usize, param: &RingParam, p: Tn) -> Self {
Self(GLWE::<Tn<N>, K>::from_plaintext(p)) Self(GLWE::<Tn>::from_plaintext(k, param, p))
} }
pub fn new_key<const KN: usize>( pub fn new_key(mut rng: impl Rng, param: &Param) -> Result<(SecretKey, PublicKey)> {
mut rng: impl Rng,
) -> Result<(SecretKey<N, K>, PublicKey<N, K>)> {
// assert_eq!(KN, K * N); // this is wip, while not being able to compute K*N // assert_eq!(KN, K * N); // this is wip, while not being able to compute K*N
let (sk_tlwe, _) = TLWE::<KN>::new_key(&mut rng)?; let (sk_tlwe, _) = TLWE::new_key(&mut rng, &param.lwe())?; //param.lwe() so that it uses K*N
debug_assert_eq!(sk_tlwe.0 .0.r.len(), param.lwe().k); // =KN (sanity check)
// let sk = crate::tlwe::sk_to_tglwe::<N, K, KN>(sk_tlwe); // let sk = crate::tlwe::sk_to_tglwe::<N, K, KN>(sk_tlwe);
let sk = sk_tlwe.to_tglwe::<N, K>(); let sk = sk_tlwe.to_tglwe(param);
let pk: PublicKey<N, K> = GLWE::pk_from_sk(rng, sk.0.clone())?; let pk: PublicKey = GLWE::pk_from_sk(rng, param, sk.0.clone())?;
Ok((sk, pk)) Ok((sk, pk))
} }
pub fn encode<const P: u64>(m: &Rq<P, N>) -> Tn<N> { pub fn encode(param: &Param, m: &Rq) -> Tn {
let delta = u64::MAX / P; // floored debug_assert_eq!(param.t, m.param.q); // plaintext modulus
let p = param.t;
let delta = u64::MAX / p; // floored
let coeffs = m.coeffs(); let coeffs = m.coeffs();
Tn(array::from_fn(|i| T64(coeffs[i].0 * delta))) // Tn(array::from_fn(|i| T64(coeffs[i].0 * delta)))
Tn {
param: param.ring,
coeffs: coeffs.iter().map(|c_i| T64(c_i.v * delta)).collect(),
}
} }
pub fn decode<const P: u64>(p: &Tn<N>) -> Rq<P, N> { pub fn decode(param: &Param, pt: &Tn) -> Rq {
let p = p.mul_div_round(P, u64::MAX); let p = param.t;
Rq::<P, N>::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect()) let pt = pt.mul_div_round(p, u64::MAX);
Rq::from_vec_u64(&param.pt(), pt.coeffs().iter().map(|c| c.0).collect())
} }
// encrypts with the given SecretKey (instead of PublicKey) // encrypts with the given SecretKey (instead of PublicKey)
pub fn encrypt_s(rng: impl Rng, sk: &SecretKey<N, K>, p: &Tn<N>) -> Result<Self> { pub fn encrypt_s(rng: impl Rng, param: &Param, sk: &SecretKey, p: &Tn) -> Result<Self> {
let glwe = GLWE::encrypt_s(rng, &sk.0, p)?; let glwe = GLWE::encrypt_s(rng, param, &sk.0, p)?;
Ok(Self(glwe)) Ok(Self(glwe))
} }
pub fn encrypt(rng: impl Rng, pk: &PublicKey<N, K>, p: &Tn<N>) -> Result<Self> { pub fn encrypt(rng: impl Rng, param: &Param, pk: &PublicKey, p: &Tn) -> Result<Self> {
let glwe = GLWE::encrypt(rng, &pk, p)?; let glwe = GLWE::encrypt(rng, param, &pk, p)?;
Ok(Self(glwe)) Ok(Self(glwe))
} }
pub fn decrypt(&self, sk: &SecretKey<N, K>) -> Tn<N> { pub fn decrypt(&self, sk: &SecretKey) -> Tn {
self.0.decrypt(&sk.0) self.0.decrypt(&sk.0)
} }
/// Sample extraction / Coefficient extraction /// Sample extraction / Coefficient extraction
pub fn sample_extraction<const KN: usize>(&self, h: usize) -> TLWE<KN> { pub fn sample_extraction(&self, param: &Param, h: usize) -> TLWE {
assert!(h < N); let n = param.ring.n;
assert!(h < n);
let a: TR<Tn<N>, K> = self.0 .0.clone(); let a: TR<Tn> = self.0 .0.clone();
// set a_{n*i+j} = a_{i, h-j} if j \in {0, h} // set a_{n*i+j} = a_{i, h-j} if j \in {0, h}
// -a_{i, n+h-j} if j \in {h+1, n-1} // -a_{i, n+h-j} if j \in {h+1, n-1}
let new_a: Vec<T64> = a let new_a: Vec<T64> = a
.iter() .iter()
.flat_map(|a_i| { .flat_map(|a_i| {
let a_i = a_i.coeffs(); let a_i = a_i.coeffs();
(0..N) (0..n)
.map(|j| if j <= h { a_i[h - j] } else { -a_i[N + h - j] }) .map(|j| if j <= h { a_i[h - j] } else { -a_i[n + h - j] })
.collect::<Vec<T64>>() .collect::<Vec<T64>>()
}) })
.collect::<Vec<T64>>(); .collect::<Vec<T64>>();
debug_assert_eq!(new_a.len(), param.k * param.ring.n); // sanity check
TLWE(GLWE(TR(new_a), self.0 .1.coeffs()[h])) TLWE(GLWE(
TR {
// TODO use constructor `new`, which will check len with k
k: param.k * param.ring.n,
r: new_a,
},
self.0 .1.coeffs()[h],
))
} }
pub fn left_rotate(&self, h: usize) -> Self { pub fn left_rotate(&self, h: usize) -> Self {
dbg!(&h); let (a, b): (TR<Tn>, Tn) = (self.0 .0.clone(), self.0 .1.clone());
let (a, b): (TR<Tn<N>, K>, Tn<N>) = (self.0 .0.clone(), self.0 .1);
Self(GLWE(a.left_rotate(h), b.left_rotate(h))) Self(GLWE(a.left_rotate(h), b.left_rotate(h)))
} }
} }
impl<const N: usize, const K: usize> Add<TGLWE<N, K>> for TGLWE<N, K> { impl Add<TGLWE> for TGLWE {
type Output = Self; type Output = Self;
fn add(self, other: Self) -> Self { fn add(self, other: Self) -> Self {
debug_assert_eq!(self.0 .0.k, other.0 .0.k);
debug_assert_eq!(self.0 .1.param(), other.0 .1.param());
Self(self.0 + other.0) Self(self.0 + other.0)
} }
} }
impl<const N: usize, const K: usize> AddAssign for TGLWE<N, K> { impl AddAssign for TGLWE {
fn add_assign(&mut self, rhs: Self) { fn add_assign(&mut self, other: Self) {
self.0 += rhs.0 debug_assert_eq!(self.0 .0.k, other.0 .0.k);
debug_assert_eq!(self.0 .1.param(), other.0 .1.param());
self.0 += other.0
} }
} }
impl<const N: usize, const K: usize> Sum<TGLWE<N, K>> for TGLWE<N, K> { impl Sum<TGLWE> for TGLWE {
fn sum<I>(iter: I) -> Self fn sum<I>(mut iter: I) -> Self
where where
I: Iterator<Item = Self>, I: Iterator<Item = Self>,
{ {
let mut acc = TGLWE::<N, K>::zero(); // let mut acc = TGLWE::<N, K>::zero();
for e in iter { // for e in iter {
acc += e; // acc += e;
} // }
acc // acc
let first = iter.next().unwrap();
iter.fold(first, |acc, e| acc + e)
} }
} }
impl<const N: usize, const K: usize> Sub<TGLWE<N, K>> for TGLWE<N, K> { impl Sub<TGLWE> for TGLWE {
type Output = Self; type Output = Self;
fn sub(self, other: Self) -> Self { fn sub(self, other: Self) -> Self {
debug_assert_eq!(self.0 .0.k, other.0 .0.k);
debug_assert_eq!(self.0 .1.param(), other.0 .1.param());
Self(self.0 - other.0) Self(self.0 - other.0)
} }
} }
// plaintext addition // plaintext addition
impl<const N: usize, const K: usize> Add<Tn<N>> for TGLWE<N, K> { impl Add<Tn> for TGLWE {
type Output = Self; type Output = Self;
fn add(self, plaintext: Tn<N>) -> Self { fn add(self, plaintext: Tn) -> Self {
let a: TR<Tn<N>, K> = self.0 .0; debug_assert_eq!(self.0 .1.param(), plaintext.param());
let b: Tn<N> = self.0 .1 + plaintext;
let a: TR<Tn> = self.0 .0;
let b: Tn = self.0 .1 + plaintext;
Self(GLWE(a, b)) Self(GLWE(a, b))
} }
} }
// plaintext substraction // plaintext substraction
impl<const N: usize, const K: usize> Sub<Tn<N>> for TGLWE<N, K> { impl Sub<Tn> for TGLWE {
type Output = Self; type Output = Self;
fn sub(self, plaintext: Tn<N>) -> Self { fn sub(self, plaintext: Tn) -> Self {
let a: TR<Tn<N>, K> = self.0 .0; debug_assert_eq!(self.0 .1.param(), plaintext.param());
let b: Tn<N> = self.0 .1 - plaintext;
let a: TR<Tn> = self.0 .0;
let b: Tn = self.0 .1 - plaintext;
Self(GLWE(a, b)) Self(GLWE(a, b))
} }
} }
// plaintext multiplication // plaintext multiplication
impl<const N: usize, const K: usize> Mul<Tn<N>> for TGLWE<N, K> { impl Mul<Tn> for TGLWE {
type Output = Self; type Output = Self;
fn mul(self, plaintext: Tn<N>) -> Self { fn mul(self, plaintext: Tn) -> Self {
let a: TR<Tn<N>, K> = TR(self.0 .0 .0.iter().map(|r_i| *r_i * plaintext).collect()); debug_assert_eq!(self.0 .1.param(), plaintext.param());
let b: Tn<N> = self.0 .1 * plaintext;
let a: TR<Tn> = TR {
k: self.0 .0.k,
r: self.0 .0.r.iter().map(|r_i| r_i * &plaintext).collect(),
};
let b: Tn = self.0 .1 * plaintext;
Self(GLWE(a, b)) Self(GLWE(a, b))
} }
} }
@@ -169,30 +207,31 @@ mod tests {
#[test] #[test]
fn test_encrypt_decrypt() -> Result<()> { fn test_encrypt_decrypt() -> Result<()> {
const T: u64 = 128; // msg space (msg modulus) let param = Param {
const N: usize = 64; ring: RingParam { q: u64::MAX, n: 64 },
const K: usize = 16; k: 16,
type S = TGLWE<N, K>; t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T); let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = TGLWE::<N, K>::new_key::<{ K * N }>(&mut rng)?; let (sk, pk) = TGLWE::new_key(&mut rng, &param)?;
let m = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?; let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p: Tn<N> = S::encode::<T>(&m); let p: Tn = TGLWE::encode(&param, &m);
let c = S::encrypt(&mut rng, &pk, &p)?; let c = TGLWE::encrypt(&mut rng, &param, &pk, &p)?;
let p_recovered = c.decrypt(&sk); let p_recovered = c.decrypt(&sk);
let m_recovered = S::decode::<T>(&p_recovered); let m_recovered = TGLWE::decode(&param, &p_recovered);
assert_eq!(m, m_recovered); assert_eq!(m, m_recovered);
// same but using encrypt_s (with sk instead of pk)) // same but using encrypt_s (with sk instead of pk))
let c = S::encrypt_s(&mut rng, &sk, &p)?; let c = TGLWE::encrypt_s(&mut rng, &param, &sk, &p)?;
let p_recovered = c.decrypt(&sk); let p_recovered = c.decrypt(&sk);
let m_recovered = S::decode::<T>(&p_recovered); let m_recovered = TGLWE::decode(&param, &p_recovered);
assert_eq!(m, m_recovered); assert_eq!(m, m_recovered);
} }
@@ -202,31 +241,32 @@ mod tests {
#[test] #[test]
fn test_addition() -> Result<()> { fn test_addition() -> Result<()> {
const T: u64 = 128; let param = Param {
const N: usize = 64; ring: RingParam { q: u64::MAX, n: 64 },
const K: usize = 16; k: 16,
type S = TGLWE<N, K>; t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T); let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = S::new_key::<{ K * N }>(&mut rng)?; let (sk, pk) = TGLWE::new_key(&mut rng, &param)?;
let m1 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?; let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?; let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: Tn<N> = S::encode::<T>(&m1); // plaintext let p1: Tn = TGLWE::encode(&param, &m1); // plaintext
let p2: Tn<N> = S::encode::<T>(&m2); // plaintext let p2: Tn = TGLWE::encode(&param, &m2); // plaintext
let c1 = S::encrypt(&mut rng, &pk, &p1)?; let c1 = TGLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c2 = S::encrypt(&mut rng, &pk, &p2)?; let c2 = TGLWE::encrypt(&mut rng, &param, &pk, &p2)?;
let c3 = c1 + c2; let c3 = c1 + c2;
let p3_recovered = c3.decrypt(&sk); let p3_recovered = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered); let m3_recovered = TGLWE::decode(&param, &p3_recovered);
assert_eq!((m1 + m2).remodule::<T>(), m3_recovered.remodule::<T>()); assert_eq!((m1 + m2).remodule(param.t), m3_recovered.remodule(param.t));
} }
Ok(()) Ok(())
@@ -234,28 +274,29 @@ mod tests {
#[test] #[test]
fn test_add_plaintext() -> Result<()> { fn test_add_plaintext() -> Result<()> {
const T: u64 = 128; let param = Param {
const N: usize = 64; ring: RingParam { q: u64::MAX, n: 64 },
const K: usize = 16; k: 16,
type S = TGLWE<N, K>; t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T); let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = S::new_key::<{ K * N }>(&mut rng)?; let (sk, pk) = TGLWE::new_key(&mut rng, &param)?;
let m1 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?; let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?; let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: Tn<N> = S::encode::<T>(&m1); // plaintext let p1: Tn = TGLWE::encode(&param, &m1); // plaintext
let p2: Tn<N> = S::encode::<T>(&m2); // plaintext let p2: Tn = TGLWE::encode(&param, &m2); // plaintext
let c1 = S::encrypt(&mut rng, &pk, &p1)?; let c1 = TGLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c3 = c1 + p2; let c3 = c1 + p2;
let p3_recovered = c3.decrypt(&sk); let p3_recovered = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered); let m3_recovered = TGLWE::decode(&param, &p3_recovered);
assert_eq!(m1 + m2, m3_recovered); assert_eq!(m1 + m2, m3_recovered);
} }
@@ -265,30 +306,34 @@ mod tests {
#[test] #[test]
fn test_mul_plaintext() -> Result<()> { fn test_mul_plaintext() -> Result<()> {
const T: u64 = 128; let param = Param {
const N: usize = 64; ring: RingParam { q: u64::MAX, n: 64 },
const K: usize = 16; k: 16,
type S = TGLWE<N, K>; t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T); let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = S::new_key::<{ K * N }>(&mut rng)?; let (sk, pk) = TGLWE::new_key(&mut rng, &param)?;
let m1 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?; let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?; let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: Tn<N> = S::encode::<T>(&m1); let p1: Tn = TGLWE::encode(&param, &m1);
// don't scale up p2, set it directly from m2 // don't scale up p2, set it directly from m2
let p2: Tn<N> = Tn(array::from_fn(|i| T64(m2.coeffs()[i].0))); let p2: Tn = Tn {
param: param.ring,
coeffs: m2.coeffs().iter().map(|c_i| T64(c_i.v)).collect(),
};
let c1 = S::encrypt(&mut rng, &pk, &p1)?; let c1 = TGLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c3 = c1 * p2; let c3 = c1 * p2;
let p3_recovered: Tn<N> = c3.decrypt(&sk); let p3_recovered: Tn = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered); let m3_recovered = TGLWE::decode(&param, &p3_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq::<T>(), m3_recovered); assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), m3_recovered);
} }
Ok(()) Ok(())
@@ -296,28 +341,29 @@ mod tests {
#[test] #[test]
fn test_sample_extraction() -> Result<()> { fn test_sample_extraction() -> Result<()> {
const T: u64 = 128; // msg space (msg modulus) let param = Param {
const N: usize = 64; ring: RingParam { q: u64::MAX, n: 64 },
const K: usize = 16; k: 16,
const KN: usize = K * N; t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T); let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..20 { for _ in 0..20 {
let (sk, pk) = TGLWE::<N, K>::new_key::<KN>(&mut rng)?; let (sk, pk) = TGLWE::new_key(&mut rng, &param)?;
let sk_tlwe = sk.to_tlwe::<KN>(); let sk_tlwe = sk.to_tlwe(&param);
let m = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?; let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p: Tn<N> = TGLWE::<N, K>::encode::<T>(&m); let p: Tn = TGLWE::encode(&param, &m);
let c = TGLWE::<N, K>::encrypt(&mut rng, &pk, &p)?; let c = TGLWE::encrypt(&mut rng, &param, &pk, &p)?;
for h in 0..N { for h in 0..param.ring.n {
let c_h: TLWE<KN> = c.sample_extraction(h); let c_h: TLWE = c.sample_extraction(&param, h);
let p_recovered = c_h.decrypt(&sk_tlwe); let p_recovered = c_h.decrypt(&sk_tlwe);
let m_recovered = TLWE::<KN>::decode::<T>(&p_recovered); let m_recovered = TLWE::decode(&param, &p_recovered);
assert_eq!(m.coeffs()[h], m_recovered.coeffs()[0]); assert_eq!(m.coeffs()[h], m_recovered.coeffs()[0]);
} }
} }

View File

@@ -4,62 +4,63 @@ use rand::Rng;
use std::array; use std::array;
use std::ops::{Add, Mul}; use std::ops::{Add, Mul};
use arith::{Ring, Rq, Tn, T64, TR}; use arith::{Ring, RingParam, Rq, Tn, T64, TR};
use crate::tlev::TLev; use crate::tlev::TLev;
use crate::{ use crate::{
tglwe::TGLWE, tglwe::TGLWE,
tlwe::{PublicKey, SecretKey, TLWE}, tlwe::{PublicKey, SecretKey, TLWE},
}; };
use gfhe::glwe::GLWE; use gfhe::glwe::{Param, GLWE};
/// vector of length K+1 = [K], [1] /// vector of length K+1 = [K], [1]
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct TGSW<const K: usize>(pub(crate) Vec<TLev<K>>, TLev<K>); pub struct TGSW(pub(crate) Vec<TLev>, TLev);
impl<const K: usize> TGSW<K> { impl TGSW {
pub fn encrypt_s( pub fn encrypt_s(
mut rng: impl Rng, mut rng: impl Rng,
param: &Param,
beta: u32, beta: u32,
l: u32, l: u32,
sk: &SecretKey<K>, sk: &SecretKey,
m: &T64, m: &T64,
) -> Result<Self> { ) -> Result<Self> {
let a: Vec<TLev<K>> = (0..K) let a: Vec<TLev> = (0..param.k)
.map(|i| TLev::encrypt_s(&mut rng, beta, l, sk, &(-sk.0 .0 .0[i] * *m))) .map(|i| TLev::encrypt_s(&mut rng, &param, beta, l, sk, &(-sk.0 .0.r[i] * *m)))
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
let b: TLev<K> = TLev::encrypt_s(&mut rng, beta, l, sk, m)?; let b: TLev = TLev::encrypt_s(&mut rng, &param, beta, l, sk, m)?;
Ok(Self(a, b)) Ok(Self(a, b))
} }
pub fn decrypt(&self, sk: &SecretKey<K>, beta: u32) -> T64 { pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> T64 {
self.1.decrypt(sk, beta) self.1.decrypt(sk, beta)
} }
pub fn from_tlwe(_tlwe: TLWE<K>) -> Self { pub fn from_tlwe(_tlwe: TLWE) -> Self {
todo!() todo!()
} }
pub fn cmux(bit: Self, ct1: TLWE<K>, ct2: TLWE<K>) -> TLWE<K> { pub fn cmux(bit: Self, ct1: TLWE, ct2: TLWE) -> TLWE {
ct1.clone() + (bit * (ct2 - ct1)) ct1.clone() + (bit * (ct2 - ct1))
} }
} }
/// External product TGSW x TLWE /// External product TGSW x TLWE
impl<const K: usize> Mul<TLWE<K>> for TGSW<K> { impl Mul<TLWE> for TGSW {
type Output = TLWE<K>; type Output = TLWE;
fn mul(self, tlwe: TLWE<K>) -> TLWE<K> { fn mul(self, tlwe: TLWE) -> TLWE {
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 64; // TODO wip let l: u32 = 64; // TODO wip
// since N=1, each tlwe element is a vector of length=1, decomposed into // since N=1, each tlwe element is a vector of length=1, decomposed into
// l elements, and we have K of them // l elements, and we have K of them
let tlwe_ab: Vec<T64> = [tlwe.0 .0 .0.clone(), vec![tlwe.0 .1]].concat(); let tlwe_ab: Vec<T64> = [tlwe.0 .0.r.clone(), vec![tlwe.0 .1]].concat();
let tgsw_ab: Vec<TLev<K>> = [self.0.clone(), vec![self.1]].concat(); let tgsw_ab: Vec<TLev> = [self.0.clone(), vec![self.1]].concat();
assert_eq!(tgsw_ab.len(), tlwe_ab.len()); assert_eq!(tgsw_ab.len(), tlwe_ab.len());
let r: TLWE<K> = zip_eq(tgsw_ab, tlwe_ab) let r: TLWE = zip_eq(tgsw_ab, tlwe_ab)
.map(|(tlev_i, tlwe_i)| tlev_i * tlwe_i.decompose(beta, l)) .map(|(tlev_i, tlwe_i)| tlev_i * tlwe_i.decompose(beta, l))
.sum(); .sum();
r r
@@ -75,25 +76,26 @@ mod tests {
#[test] #[test]
fn test_encrypt_decrypt() -> Result<()> { fn test_encrypt_decrypt() -> Result<()> {
const T: u64 = 2; // plaintext modulus let param = Param {
const K: usize = 16; ring: RingParam { q: u64::MAX, n: 1 },
type S = TGSW<K>; k: 16,
t: 2, // plaintext modulus
};
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 16; let l: u32 = 16;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T); let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..50 { for _ in 0..50 {
let (sk, _) = TLWE::<K>::new_key(&mut rng)?; let (sk, _) = TLWE::new_key(&mut rng, &param)?;
let m: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?; let m: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p: T64 = TLev::<K>::encode::<T>(&m); // plaintext let p: T64 = TLev::encode(&param, &m); // plaintext
let c = S::encrypt_s(&mut rng, beta, l, &sk, &p)?; let c = TGSW::encrypt_s(&mut rng, &param, beta, l, &sk, &p)?;
let p_recovered = c.decrypt(&sk, beta); let p_recovered = c.decrypt(&sk, beta);
let m_recovered = TLev::<K>::decode::<T>(&p_recovered); let m_recovered = TLev::decode(&param, &p_recovered);
assert_eq!(m, m_recovered); assert_eq!(m, m_recovered);
} }
@@ -103,36 +105,38 @@ mod tests {
#[test] #[test]
fn test_external_product() -> Result<()> { fn test_external_product() -> Result<()> {
const T: u64 = 2; // plaintext modulus let param = Param {
const K: usize = 32; ring: RingParam { q: u64::MAX, n: 1 },
k: 32,
t: 2, // plaintext modulus
};
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 64; let l: u32 = 64;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T); let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..50 { for _ in 0..50 {
let (sk, _) = TLWE::<K>::new_key(&mut rng)?; let (sk, _) = TLWE::new_key(&mut rng, &param)?;
let m1: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?; let m1: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: T64 = TLev::<K>::encode::<T>(&m1); let p1: T64 = TLev::encode(&param, &m1);
let m2: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?; let m2: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p2: T64 = TLWE::<K>::encode::<T>(&m2); // scaled by delta let p2: T64 = TLWE::encode(&param, &m2); // scaled by delta
let tgsw = TGSW::<K>::encrypt_s(&mut rng, beta, l, &sk, &p1)?; let tgsw = TGSW::encrypt_s(&mut rng, &param, beta, l, &sk, &p1)?;
let tlwe = TLWE::<K>::encrypt_s(&mut rng, &sk, &p2)?; let tlwe = TLWE::encrypt_s(&mut rng, &param, &sk, &p2)?;
let res: TLWE<K> = tgsw * tlwe; let res: TLWE = tgsw * tlwe;
// let p_recovered = res.decrypt(&sk, beta); // let p_recovered = res.decrypt(&sk, beta);
let p_recovered = res.decrypt(&sk); let p_recovered = res.decrypt(&sk);
// downscaled by delta^-1 // downscaled by delta^-1
let res_recovered = TLWE::<K>::decode::<T>(&p_recovered); let res_recovered = TLWE::decode(&param, &p_recovered);
// assert_eq!(m1 * m2, m_recovered); // assert_eq!(m1 * m2, m_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq::<T>(), res_recovered); assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), res_recovered);
} }
Ok(()) Ok(())
@@ -140,35 +144,38 @@ mod tests {
#[test] #[test]
fn test_cmux() -> Result<()> { fn test_cmux() -> Result<()> {
const T: u64 = 2; // plaintext modulus let param = Param {
const K: usize = 32; ring: RingParam { q: u64::MAX, n: 1 },
k: 32,
t: 2, // plaintext modulus
};
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 64; let l: u32 = 64;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T); let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..50 { for _ in 0..50 {
let (sk, _) = TLWE::<K>::new_key(&mut rng)?; let (sk, _) = TLWE::new_key(&mut rng, &param)?;
let m1: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?; let m1: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: T64 = TLWE::<K>::encode::<T>(&m1); // scaled by delta let p1: T64 = TLWE::encode(&param, &m1); // scaled by delta
let m2: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?; let m2: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p2: T64 = TLWE::<K>::encode::<T>(&m2); // scaled by delta let p2: T64 = TLWE::encode(&param, &m2); // scaled by delta
for bit_raw in 0..2 { for bit_raw in 0..2 {
let bit = TGSW::<K>::encrypt_s(&mut rng, beta, l, &sk, &T64(bit_raw))?; let bit = TGSW::encrypt_s(&mut rng, &param, beta, l, &sk, &T64(bit_raw))?;
let c1 = TLWE::<K>::encrypt_s(&mut rng, &sk, &p1)?; let c1 = TLWE::encrypt_s(&mut rng, &param, &sk, &p1)?;
let c2 = TLWE::<K>::encrypt_s(&mut rng, &sk, &p2)?; let c2 = TLWE::encrypt_s(&mut rng, &param, &sk, &p2)?;
let res: TLWE<K> = TGSW::cmux(bit, c1, c2); let res: TLWE = TGSW::cmux(bit, c1, c2);
let p_recovered = res.decrypt(&sk); let p_recovered = res.decrypt(&sk);
// downscaled by delta^-1 // downscaled by delta^-1
let res_recovered = TLWE::<K>::decode::<T>(&p_recovered); let res_recovered = TLWE::decode(&param, &p_recovered);
if bit_raw == 0 { if bit_raw == 0 {
assert_eq!(m1, res_recovered); assert_eq!(m1, res_recovered);

View File

@@ -4,32 +4,48 @@ use rand::Rng;
use std::array; use std::array;
use std::ops::{Add, Mul}; use std::ops::{Add, Mul};
use arith::{Ring, Rq, Tn, T64, TR}; use arith::{Ring, RingParam, Rq, Tn, T64, TR};
use crate::tglwe::TGLWE; use crate::tglwe::TGLWE;
use crate::tlwe::{PublicKey, SecretKey, TLWE}; use crate::tlwe::{PublicKey, SecretKey, TLWE};
use gfhe::glwe::Param;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct TLev<const K: usize>(pub(crate) Vec<TLWE<K>>); pub struct TLev(pub(crate) Vec<TLWE>);
impl TLev {
pub fn encode(param: &Param, m: &Rq) -> T64 {
assert_eq!(m.param.n, 1);
assert_eq!(param.t, m.param.q);
impl<const K: usize> TLev<K> {
pub fn encode<const T: u64>(m: &Rq<T, 1>) -> T64 {
let coeffs = m.coeffs(); let coeffs = m.coeffs();
T64(coeffs[0].0) // N=1, so take the only coeff T64(coeffs[0].v) // N=1, so take the only coeff
} }
pub fn decode<const T: u64>(p: &T64) -> Rq<T, 1> { pub fn decode(param: &Param, p: &T64) -> Rq {
Rq::<T, 1>::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect()) Rq::from_vec_u64(
// &RingParam { q: u64::MAX, n: 1 },
&RingParam { q: param.t, n: 1 },
p.coeffs().iter().map(|c| c.0).collect(),
)
} }
pub fn encrypt( pub fn encrypt(
mut rng: impl Rng, mut rng: impl Rng,
param: &Param,
beta: u32, beta: u32,
l: u32, l: u32,
pk: &PublicKey<K>, pk: &PublicKey,
m: &T64, m: &T64,
) -> Result<Self> { ) -> Result<Self> {
let tlev: Vec<TLWE<K>> = (1..l + 1) debug_assert_eq!(pk.1.k, param.k);
let tlev: Vec<TLWE> = (1..l + 1)
.map(|i| { .map(|i| {
TLWE::<K>::encrypt(&mut rng, pk, &(*m * (u64::MAX / beta.pow(i as u32) as u64))) TLWE::encrypt(
&mut rng,
param,
pk,
&(*m * (u64::MAX / beta.pow(i as u32) as u64)),
)
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
@@ -37,12 +53,13 @@ impl<const K: usize> TLev<K> {
} }
pub fn encrypt_s( pub fn encrypt_s(
mut rng: impl Rng, mut rng: impl Rng,
param: &Param,
_beta: u32, // TODO rm, and make beta=2 always _beta: u32, // TODO rm, and make beta=2 always
l: u32, l: u32,
sk: &SecretKey<K>, sk: &SecretKey,
m: &T64, m: &T64,
) -> Result<Self> { ) -> Result<Self> {
let tlev: Vec<TLWE<K>> = (1..l as u64 + 1) let tlev: Vec<TLWE> = (1..l as u64 + 1)
.map(|i| { .map(|i| {
let aux = if i < 64 { let aux = if i < 64 {
*m * (u64::MAX / (1u64 << i)) *m * (u64::MAX / (1u64 << i))
@@ -51,22 +68,22 @@ impl<const K: usize> TLev<K> {
// by it, which would be equal to 1 // by it, which would be equal to 1
*m *m
}; };
TLWE::<K>::encrypt_s(&mut rng, sk, &aux) TLWE::encrypt_s(&mut rng, &param, sk, &aux)
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
Ok(Self(tlev)) Ok(Self(tlev))
} }
pub fn decrypt(&self, sk: &SecretKey<K>, beta: u32) -> T64 { pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> T64 {
let pt = self.0[0].decrypt(sk); let pt = self.0[0].decrypt(sk);
pt.mul_div_round(beta as u64, u64::MAX) pt.mul_div_round(beta as u64, u64::MAX)
} }
} }
// TODO review u64::MAX, since is -1 of the value we actually want // TODO review u64::MAX, since is -1 of the value we actually want
impl<const K: usize> TLev<K> { impl TLev {
pub fn iter(&self) -> std::slice::Iter<TLWE<K>> { pub fn iter(&self) -> std::slice::Iter<TLWE> {
self.0.iter() self.0.iter()
} }
} }
@@ -74,14 +91,14 @@ impl<const K: usize> TLev<K> {
// dot product between a TLev and Vec<T64>, usually Vec<T64> comes from a // dot product between a TLev and Vec<T64>, usually Vec<T64> comes from a
// decomposition of T64 // decomposition of T64
// TLev * Vec<T64> --> TLWE // TLev * Vec<T64> --> TLWE
impl<const K: usize> Mul<Vec<T64>> for TLev<K> { impl Mul<Vec<T64>> for TLev {
type Output = TLWE<K>; type Output = TLWE;
fn mul(self, v: Vec<T64>) -> Self::Output { fn mul(self, v: Vec<T64>) -> Self::Output {
assert_eq!(self.0.len(), v.len()); assert_eq!(self.0.len(), v.len());
// l TLWES // l TLWES
let tlwes: Vec<TLWE<K>> = self.0; let tlwes: Vec<TLWE> = self.0;
let r: TLWE<K> = zip_eq(v, tlwes).map(|(a_d_i, glwe_i)| glwe_i * a_d_i).sum(); let r: TLWE = zip_eq(v, tlwes).map(|(a_d_i, glwe_i)| glwe_i * a_d_i).sum();
r r
} }
} }
@@ -95,27 +112,29 @@ mod tests {
#[test] #[test]
fn test_encrypt_decrypt() -> Result<()> { fn test_encrypt_decrypt() -> Result<()> {
const T: u64 = 2; // plaintext modulus let param = Param {
const K: usize = 16; ring: RingParam { q: u64::MAX, n: 1 },
type S = TLev<K>; k: 16,
t: 2, // plaintext modulus
};
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 16; let l: u32 = 16;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T); let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = TLWE::<K>::new_key(&mut rng)?; let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?; let m: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p: T64 = S::encode::<T>(&m); // plaintext let p: T64 = TLev::encode(&param, &m); // plaintext
let c = S::encrypt(&mut rng, beta, l, &pk, &p)?; let c = TLev::encrypt(&mut rng, &param, beta, l, &pk, &p)?;
let p_recovered = c.decrypt(&sk, beta); let p_recovered = c.decrypt(&sk, beta);
let m_recovered = S::decode::<T>(&p_recovered); let m_recovered = TLev::decode(&param, &p_recovered);
assert_eq!(m.remodule::<T>(), m_recovered.remodule::<T>()); assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t));
} }
Ok(()) Ok(())
@@ -123,32 +142,35 @@ mod tests {
#[test] #[test]
fn test_tlev_vect64_product() -> Result<()> { fn test_tlev_vect64_product() -> Result<()> {
const T: u64 = 2; // plaintext modulus let param = Param {
const K: usize = 16; ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 2, // plaintext modulus
};
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 16; let l: u32 = 16;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T); let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = TLWE::<K>::new_key(&mut rng)?; let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m1: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?; let m1: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?; let m2: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: T64 = TLev::<K>::encode::<T>(&m1); let p1: T64 = TLev::encode(&param, &m1);
let p2: T64 = TLev::<K>::encode::<T>(&m2); let p2: T64 = TLev::encode(&param, &m2);
let c1 = TLev::<K>::encrypt(&mut rng, beta, l, &pk, &p1)?; let c1 = TLev::encrypt(&mut rng, &param, beta, l, &pk, &p1)?;
let c2 = p2.decompose(beta, l); let c2 = p2.decompose(beta, l);
let c3 = c1 * c2; let c3 = c1 * c2;
let p_recovered = c3.decrypt(&sk); let p_recovered = c3.decrypt(&sk);
let m_recovered = TLev::<K>::decode::<T>(&p_recovered); let m_recovered = TLev::decode(&param, &p_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq::<T>(), m_recovered); assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), m_recovered);
} }
Ok(()) Ok(())

View File

@@ -4,242 +4,276 @@ use rand::Rng;
use std::iter::Sum; use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, Sub}; use std::ops::{Add, AddAssign, Mul, Sub};
use arith::{Ring, Rq, Tn, Zq, T64, TR}; use arith::{Ring, RingParam, Rq, Tn, Zq, T64, TR};
use gfhe::{glwe, GLWE}; use gfhe::{glwe, glwe::Param, GLWE};
use crate::tggsw::TGGSW; use crate::tggsw::TGGSW;
use crate::tlev::TLev; use crate::tlev::TLev;
use crate::{tglwe, tglwe::TGLWE}; use crate::{tglwe, tglwe::TGLWE};
pub struct SecretKey<const K: usize>(pub glwe::SecretKey<T64, K>); pub struct SecretKey(pub glwe::SecretKey<T64>);
impl<const KN: usize> SecretKey<KN> { impl SecretKey {
/// from TFHE [2018-421] paper: A TLWE key k \in B^n, can be interpreted as a /// from TFHE [2018-421] paper: A TLWE key k \in B^n, can be interpreted as a
/// TRLWE key K \in B_N[X]^k having the same sequence of coefficients and /// TRLWE key K \in B_N[X]^k having the same sequence of coefficients and
/// vice-versa. /// vice-versa.
pub fn to_tglwe<const N: usize, const K: usize>(self) -> crate::tglwe::SecretKey<N, K> { pub fn to_tglwe(self, param: &Param) -> crate::tglwe::SecretKey {
let s: TR<T64, KN> = self.0 .0; let s: TR<T64> = self.0 .0; // of length K*N
assert_eq!(s.r.len(), param.k * param.ring.n); // sanity check
// split into K vectors, and interpret each of them as a T_N[X]/(X^N+1) // split into K vectors, and interpret each of them as a T_N[X]/(X^N+1)
// polynomial // polynomial
let r: Vec<Tn<N>> = let r: Vec<Tn> =
s.0.chunks(N) s.r.chunks(param.ring.n)
.map(|v| Tn::<N>::from_vec(v.to_vec())) .map(|v| Tn::from_vec(&param.ring, v.to_vec()))
.collect(); .collect();
crate::tglwe::SecretKey(glwe::SecretKey::<Tn<N>, K>(TR(r))) crate::tglwe::SecretKey(glwe::SecretKey::<Tn>(TR { k: param.k, r }))
} }
} }
pub type PublicKey<const K: usize> = glwe::PublicKey<T64, K>; pub type PublicKey = glwe::PublicKey<T64>;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct KSK<const K: usize>(Vec<TLev<K>>); pub struct KSK(Vec<TLev>);
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct TLWE<const K: usize>(pub GLWE<T64, K>); pub struct TLWE(pub GLWE<T64>);
impl<const K: usize> TLWE<K> { impl TLWE {
pub fn zero() -> Self { pub fn zero(k: usize, ring_param: &RingParam) -> Self {
Self(GLWE::<T64, K>::zero()) Self(GLWE::<T64>::zero(k, ring_param))
} }
pub fn new_key(rng: impl Rng) -> Result<(SecretKey<K>, PublicKey<K>)> { pub fn new_key(rng: impl Rng, param: &Param) -> Result<(SecretKey, PublicKey)> {
let (sk, pk): (glwe::SecretKey<T64, K>, glwe::PublicKey<T64, K>) = GLWE::new_key(rng)?; let (sk, pk): (glwe::SecretKey<T64>, glwe::PublicKey<T64>) = GLWE::new_key(rng, param)?;
Ok((SecretKey(sk), pk)) Ok((SecretKey(sk), pk))
} }
pub fn encode<const P: u64>(m: &Rq<P, 1>) -> T64 { // TODO use param instead of p:u64
let delta = u64::MAX / P; // floored pub fn encode(param: &Param, m: &Rq) -> T64 {
assert_eq!(param.ring.n, 1);
debug_assert_eq!(param.t, m.param.q); // plaintext modulus
//
let delta = u64::MAX / param.t; // floored
let coeffs = m.coeffs(); let coeffs = m.coeffs();
T64(coeffs[0].0 * delta) T64(coeffs[0].v * delta)
} }
pub fn decode<const P: u64>(p: &T64) -> Rq<P, 1> { pub fn decode(param: &Param, p: &T64) -> Rq {
let p = p.mul_div_round(P, u64::MAX); let p = p.mul_div_round(param.t, u64::MAX);
Rq::<P, 1>::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect()) Rq::from_vec_u64(&param.pt(), p.coeffs().iter().map(|c| c.0).collect())
} }
// encrypts with the given SecretKey (instead of PublicKey) // encrypts with the given SecretKey (instead of PublicKey)
pub fn encrypt_s(rng: impl Rng, sk: &SecretKey<K>, p: &T64) -> Result<Self> { pub fn encrypt_s(rng: impl Rng, param: &Param, sk: &SecretKey, p: &T64) -> Result<Self> {
let glwe = GLWE::encrypt_s(rng, &sk.0, p)?; let glwe = GLWE::encrypt_s(rng, param, &sk.0, p)?;
Ok(Self(glwe)) Ok(Self(glwe))
} }
pub fn encrypt(rng: impl Rng, pk: &PublicKey<K>, p: &T64) -> Result<Self> { pub fn encrypt(rng: impl Rng, param: &Param, pk: &PublicKey, p: &T64) -> Result<Self> {
let glwe = GLWE::encrypt(rng, &pk, p)?; let glwe = GLWE::encrypt(rng, param, pk, p)?;
Ok(Self(glwe)) Ok(Self(glwe))
} }
pub fn decrypt(&self, sk: &SecretKey<K>) -> T64 { pub fn decrypt(&self, sk: &SecretKey) -> T64 {
self.0.decrypt(&sk.0) self.0.decrypt(&sk.0)
} }
pub fn new_ksk( pub fn new_ksk(
mut rng: impl Rng, mut rng: impl Rng,
param: &Param,
beta: u32, beta: u32,
l: u32, l: u32,
sk: &SecretKey<K>, sk: &SecretKey,
new_sk: &SecretKey<K>, new_sk: &SecretKey,
) -> Result<KSK<K>> { ) -> Result<KSK> {
let r: Vec<TLev<K>> = (0..K) let r: Vec<TLev> = (0..param.k)
.into_iter() .into_iter()
.map(|i| .map(|i|
// treat sk_i as the msg being encrypted // treat sk_i as the msg being encrypted
TLev::<K>::encrypt_s(&mut rng, beta, l, &new_sk, &sk.0.0 .0[i])) TLev::encrypt_s(&mut rng, param, beta, l, &new_sk, &sk.0.0 .r[i]))
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
Ok(KSK(r)) Ok(KSK(r))
} }
pub fn key_switch(&self, beta: u32, l: u32, ksk: &KSK<K>) -> Self { pub fn key_switch(&self, param: &Param, beta: u32, l: u32, ksk: &KSK) -> Self {
let (a, b): (TR<T64, K>, T64) = (self.0 .0.clone(), self.0 .1); let (a, b): (TR<T64>, T64) = (self.0 .0.clone(), self.0 .1);
let lhs: TLWE<K> = TLWE(GLWE(TR::zero(), b)); let lhs: TLWE = TLWE(GLWE(TR::zero(param.k * param.ring.n, &param.ring), b));
// K iterations, ksk.0 contains K times GLev // K iterations, ksk.0 contains K times GLev
let rhs: TLWE<K> = zip_eq(a.0, ksk.0.clone()) let rhs: TLWE = zip_eq(a.r, ksk.0.clone())
.map(|(a_i, ksk_i)| ksk_i * a_i.decompose(beta, l)) // dot_product .map(|(a_i, ksk_i)| ksk_i * a_i.decompose(beta, l)) // dot_product
.sum(); .sum();
lhs - rhs lhs - rhs
} }
// modulus switch from Q (2^64) to Q2 (in blind_rotation Q2=K*N) // modulus switch from Q (2^64) to Q2 (in blind_rotation Q2=K*N)
pub fn mod_switch<const Q2: u64>(&self) -> Self { pub fn mod_switch(&self, q2: u64) -> Self {
let a: TR<T64, K> = self.0 .0.mod_switch::<Q2>(); let a: TR<T64> = self.0 .0.mod_switch(q2);
let b: T64 = self.0 .1.mod_switch::<Q2>(); let b: T64 = self.0 .1.mod_switch(q2);
Self(GLWE(a, b)) Self(GLWE(a, b))
} }
} }
// NOTE: the ugly const generics are temporary // NOTE: the ugly const generics are temporary
pub fn blind_rotation<const N: usize, const K: usize, const KN: usize, const KN2: u64>( pub fn blind_rotation(
c: TLWE<KN>, param: &Param,
btk: BootstrappingKey<N, K, KN>, c: TLWE, // kn
table: TGLWE<N, K>, btk: BootstrappingKey,
) -> TGLWE<N, K> { table: TGLWE, // n,k
let c_kn: TLWE<KN> = c.mod_switch::<KN2>(); ) -> TGLWE {
let (a, b): (TR<T64, KN>, T64) = (c_kn.0 .0, c_kn.0 .1); debug_assert_eq!(c.0 .0.k, param.k); // TODO or k*n?
// TODO replace `param.k*param.ring.n` by `param.kn()`
let c_kn: TLWE = c.mod_switch((param.k * param.ring.n) as u64);
let (a, b): (TR<T64>, T64) = (c_kn.0 .0, c_kn.0 .1);
// two main parts: rotate by a known power of X, rotate by a secret // two main parts: rotate by a known power of X, rotate by a secret
// power of X (using the C gate) // power of X (using the C gate)
// table * X^-b, ie. left rotate // table * X^-b, ie. left rotate
let v_xb: TGLWE<N, K> = table.left_rotate(b.0 as usize); let v_xb: TGLWE = table.left_rotate(b.0 as usize);
// rotate by a secret power of X using the cmux gate // rotate by a secret power of X using the cmux gate
let mut c_j: TGLWE<N, K> = v_xb.clone(); let mut c_j: TGLWE = v_xb.clone();
let _ = (1..K).map(|j| { let _ = (1..param.k).map(|j| {
c_j = TGGSW::<N, K>::cmux( c_j = TGGSW::cmux(
btk.0[j].clone(), btk.0[j].clone(),
c_j.clone(), c_j.clone(),
c_j.clone().left_rotate(a.0[j].0 as usize), c_j.clone().left_rotate(a.r[j].0 as usize),
); );
dbg!(&c_j); // dbg!(&c_j);
}); });
c_j c_j
} }
pub fn bootstrapping<const N: usize, const K: usize, const KN: usize, const KN2: u64>( pub fn bootstrapping(
btk: BootstrappingKey<N, K, KN>, param: &Param,
table: TGLWE<N, K>, btk: BootstrappingKey,
c: TLWE<KN>, table: TGLWE,
) -> TLWE<KN> { c: TLWE, // kn
let rotated: TGLWE<N, K> = blind_rotation::<N, K, KN, KN2>(c, btk.clone(), table); ) -> TLWE {
let c_h: TLWE<KN> = rotated.sample_extraction(0); // kn
let r = c_h.key_switch(2, 64, &btk.1); let rotated: TGLWE = blind_rotation(param, c, btk.clone(), table);
let c_h: TLWE = rotated.sample_extraction(&param, 0);
let r = c_h.key_switch(param, 2, 64, &btk.1);
r r
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct BootstrappingKey<const N: usize, const K: usize, const KN: usize>( pub struct BootstrappingKey(
pub Vec<TGGSW<N, K>>, pub Vec<TGGSW>,
pub KSK<KN>, pub KSK, // kn
); );
impl<const N: usize, const K: usize, const KN: usize> BootstrappingKey<N, K, KN> { impl BootstrappingKey {
pub fn from_sk(mut rng: impl Rng, sk: &tglwe::SecretKey<N, K>) -> Result<Self> { pub fn from_sk(mut rng: impl Rng, param: &Param, sk: &tglwe::SecretKey) -> Result<Self> {
let (beta, l) = (2u32, 64u32); // TMP let (beta, l) = (2u32, 64u32); // TMP
// //
let s: TR<Tn<N>, K> = sk.0 .0.clone(); let s: TR<Tn> = sk.0 .0.clone();
let (sk2, _) = TLWE::<KN>::new_key(&mut rng)?; // TLWE<KN> compatible with TGLWE<N,K> let (sk2, _) = TLWE::new_key(&mut rng, &param.lwe())?; // TLWE<KN> compatible with TGLWE<N,K>
// each btk_j = TGGSW_sk(s_i) // each btk_j = TGGSW_sk(s_i)
let btk: Vec<TGGSW<N, K>> = s let btk: Vec<TGGSW> = s
.iter() .iter()
.map(|s_i| TGGSW::<N, K>::encrypt_s(&mut rng, beta, l, sk, s_i)) .map(|s_i| TGGSW::encrypt_s(&mut rng, param, beta, l, sk, s_i))
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
let ksk = TLWE::<KN>::new_ksk(&mut rng, beta, l, &sk.to_tlwe(), &sk2)?;
let ksk = TLWE::new_ksk(
&mut rng,
&param.lwe(),
beta,
l,
&sk.to_tlwe(&param.lwe()), // converted to length k*n
&sk2, // created with length k*n
)?;
debug_assert_eq!(ksk.0.len(), param.lwe().k);
debug_assert_eq!(ksk.0.len(), param.k * param.ring.n);
Ok(Self(btk, ksk)) Ok(Self(btk, ksk))
} }
} }
pub fn compute_lookup_table<const T: u64, const K: usize, const N: usize>() -> TGLWE<N, K> { pub fn compute_lookup_table(param: &Param) -> TGLWE {
// from 2021-1402: // from 2021-1402:
// v(x) = \sum_j^{N-1} [(p_j / 2N mod p)/p] X^j // v(x) = \sum_j^{N-1} [(p_j / 2N mod p)/p] X^j
// matrix of coefficients with size K*N = delta x T // matrix of coefficients with size K*N = delta x T
let delta: usize = N / T as usize; let delta: usize = param.ring.n / param.t as usize;
let values: Vec<Zq<T>> = (0..T).map(|v| Zq::<T>::from_u64(v)).collect(); let values: Vec<Zq> = (0..param.t).map(|v| Zq::from_u64(param.t, v)).collect();
let coeffs: Vec<Zq<T>> = (0..T as usize) let coeffs: Vec<Zq> = (0..param.t as usize)
.flat_map(|i| vec![values[i]; delta]) .flat_map(|i| vec![values[i]; delta])
.collect(); .collect();
let table = Rq::<T, N>::from_vec(coeffs); let table = Rq::from_vec(&param.pt(), coeffs);
// encode the table as plaintext // encode the table as plaintext
let v: Tn<N> = TGLWE::<N, K>::encode::<T>(&table); let v: Tn = TGLWE::encode(param, &table);
// encode the table as TGLWE ciphertext // encode the table as TGLWE ciphertext
let v: TGLWE<N, K> = TGLWE::<N, K>::from_plaintext(v); let v: TGLWE = TGLWE::from_plaintext(param.k, &param.ring, v);
v v
} }
impl<const K: usize> Add<TLWE<K>> for TLWE<K> { impl Add<TLWE> for TLWE {
type Output = Self; type Output = Self;
fn add(self, other: Self) -> Self { fn add(self, other: Self) -> Self {
debug_assert_eq!(self.0 .0.k, other.0 .0.k);
debug_assert_eq!(self.0 .1.param(), other.0 .1.param());
Self(self.0 + other.0) Self(self.0 + other.0)
} }
} }
impl<const K: usize> AddAssign for TLWE<K> { impl AddAssign for TLWE {
fn add_assign(&mut self, rhs: Self) { fn add_assign(&mut self, rhs: Self) {
debug_assert_eq!(self.0 .0.k, rhs.0 .0.k);
debug_assert_eq!(self.0 .1.param(), rhs.0 .1.param());
self.0 += rhs.0 self.0 += rhs.0
} }
} }
impl<const K: usize> Sum<TLWE<K>> for TLWE<K> { impl Sum<TLWE> for TLWE {
fn sum<I>(iter: I) -> Self fn sum<I>(mut iter: I) -> Self
where where
I: Iterator<Item = Self>, I: Iterator<Item = Self>,
{ {
let mut acc = TLWE::<K>::zero(); // let mut acc = TLWE::<K>::zero();
for e in iter { // for e in iter {
acc += e; // acc += e;
} // }
acc // acc
let first = iter.next().unwrap();
iter.fold(first, |acc, e| acc + e)
} }
} }
impl<const K: usize> Sub<TLWE<K>> for TLWE<K> { impl Sub<TLWE> for TLWE {
type Output = Self; type Output = Self;
fn sub(self, other: Self) -> Self { fn sub(self, other: Self) -> Self {
debug_assert_eq!(self.0 .0.k, other.0 .0.k);
debug_assert_eq!(self.0 .1.param(), other.0 .1.param());
Self(self.0 - other.0) Self(self.0 - other.0)
} }
} }
// plaintext addition // plaintext addition
impl<const K: usize> Add<T64> for TLWE<K> { impl Add<T64> for TLWE {
type Output = Self; type Output = Self;
fn add(self, plaintext: T64) -> Self { fn add(self, plaintext: T64) -> Self {
let a: TR<T64, K> = self.0 .0; let a: TR<T64> = self.0 .0;
let b: T64 = self.0 .1 + plaintext; let b: T64 = self.0 .1 + plaintext;
Self(GLWE(a, b)) Self(GLWE(a, b))
} }
} }
// plaintext substraction // plaintext substraction
impl<const K: usize> Sub<T64> for TLWE<K> { impl Sub<T64> for TLWE {
type Output = Self; type Output = Self;
fn sub(self, plaintext: T64) -> Self { fn sub(self, plaintext: T64) -> Self {
let a: TR<T64, K> = self.0 .0; let a: TR<T64> = self.0 .0;
let b: T64 = self.0 .1 - plaintext; let b: T64 = self.0 .1 - plaintext;
Self(GLWE(a, b)) Self(GLWE(a, b))
} }
} }
// plaintext multiplication // plaintext multiplication
impl<const K: usize> Mul<T64> for TLWE<K> { impl Mul<T64> for TLWE {
type Output = Self; type Output = Self;
fn mul(self, plaintext: T64) -> Self { fn mul(self, plaintext: T64) -> Self {
let a: TR<T64, K> = TR(self.0 .0 .0.iter().map(|r_i| *r_i * plaintext).collect()); let a: TR<T64> = TR {
k: self.0 .0.k,
r: self.0 .0.r.iter().map(|r_i| *r_i * plaintext).collect(),
};
let b: T64 = self.0 .1 * plaintext; let b: T64 = self.0 .1 * plaintext;
Self(GLWE(a, b)) Self(GLWE(a, b))
} }
@@ -255,29 +289,31 @@ mod tests {
#[test] #[test]
fn test_encrypt_decrypt() -> Result<()> { fn test_encrypt_decrypt() -> Result<()> {
const T: u64 = 128; // msg space (msg modulus) let param = Param {
const K: usize = 16; ring: RingParam { q: u64::MAX, n: 1 },
type S = TLWE<K>; k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T); let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?; let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?; let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p: T64 = S::encode::<T>(&m); let p: T64 = TLWE::encode(&param, &m);
let c = S::encrypt(&mut rng, &pk, &p)?; let c = TLWE::encrypt(&mut rng, &param, &pk, &p)?;
let p_recovered = c.decrypt(&sk); let p_recovered = c.decrypt(&sk);
let m_recovered = S::decode::<T>(&p_recovered); let m_recovered = TLWE::decode(&param, &p_recovered);
assert_eq!(m, m_recovered); assert_eq!(m, m_recovered);
// same but using encrypt_s (with sk instead of pk)) // same but using encrypt_s (with sk instead of pk))
let c = S::encrypt_s(&mut rng, &sk, &p)?; let c = TLWE::encrypt_s(&mut rng, &param, &sk, &p)?;
let p_recovered = c.decrypt(&sk); let p_recovered = c.decrypt(&sk);
let m_recovered = S::decode::<T>(&p_recovered); let m_recovered = TLWE::decode(&param, &p_recovered);
assert_eq!(m, m_recovered); assert_eq!(m, m_recovered);
} }
@@ -287,30 +323,32 @@ mod tests {
#[test] #[test]
fn test_addition() -> Result<()> { fn test_addition() -> Result<()> {
const T: u64 = 128; let param = Param {
const K: usize = 16; ring: RingParam { q: u64::MAX, n: 1 },
type S = TLWE<K>; k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T); let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?; let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m1 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?; let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?; let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: T64 = S::encode::<T>(&m1); // plaintext let p1: T64 = TLWE::encode(&param, &m1); // plaintext
let p2: T64 = S::encode::<T>(&m2); // plaintext let p2: T64 = TLWE::encode(&param, &m2); // plaintext
let c1 = S::encrypt(&mut rng, &pk, &p1)?; let c1 = TLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c2 = S::encrypt(&mut rng, &pk, &p2)?; let c2 = TLWE::encrypt(&mut rng, &param, &pk, &p2)?;
let c3 = c1 + c2; let c3 = c1 + c2;
let p3_recovered = c3.decrypt(&sk); let p3_recovered = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered); let m3_recovered = TLWE::decode(&param, &p3_recovered);
assert_eq!((m1 + m2).remodule::<T>(), m3_recovered.remodule::<T>()); assert_eq!((m1 + m2).remodule(param.t), m3_recovered.remodule(param.t));
} }
Ok(()) Ok(())
@@ -318,27 +356,29 @@ mod tests {
#[test] #[test]
fn test_add_plaintext() -> Result<()> { fn test_add_plaintext() -> Result<()> {
const T: u64 = 128; let param = Param {
const K: usize = 16; ring: RingParam { q: u64::MAX, n: 1 },
type S = TLWE<K>; k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T); let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?; let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m1 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?; let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?; let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: T64 = S::encode::<T>(&m1); // plaintext let p1: T64 = TLWE::encode(&param, &m1); // plaintext
let p2: T64 = S::encode::<T>(&m2); // plaintext let p2: T64 = TLWE::encode(&param, &m2); // plaintext
let c1 = S::encrypt(&mut rng, &pk, &p1)?; let c1 = TLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c3 = c1 + p2; let c3 = c1 + p2;
let p3_recovered = c3.decrypt(&sk); let p3_recovered = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered); let m3_recovered = TLWE::decode(&param, &p3_recovered);
assert_eq!(m1 + m2, m3_recovered); assert_eq!(m1 + m2, m3_recovered);
} }
@@ -348,30 +388,32 @@ mod tests {
#[test] #[test]
fn test_mul_plaintext() -> Result<()> { fn test_mul_plaintext() -> Result<()> {
const T: u64 = 128; let param = Param {
const K: usize = 16; ring: RingParam { q: u64::MAX, n: 1 },
type S = TLWE<K>; k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T); let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?; let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m1 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?; let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?; let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: T64 = S::encode::<T>(&m1); let p1: T64 = TLWE::encode(&param, &m1);
// don't scale up p2, set it directly from m2 // don't scale up p2, set it directly from m2
// let p2: T64 = Tn(array::from_fn(|i| T64(m2.coeffs()[i].0))); // let p2: T64 = Tn(array::from_fn(|i| T64(m2.coeffs()[i].0)));
let p2: T64 = T64(m2.coeffs()[0].0); let p2: T64 = T64(m2.coeffs()[0].v);
let c1 = S::encrypt(&mut rng, &pk, &p1)?; let c1 = TLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c3 = c1 * p2; let c3 = c1 * p2;
let p3_recovered: T64 = c3.decrypt(&sk); let p3_recovered: T64 = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered); let m3_recovered = TLWE::decode(&param, &p3_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq::<T>(), m3_recovered); assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), m3_recovered);
} }
Ok(()) Ok(())
@@ -379,38 +421,40 @@ mod tests {
#[test] #[test]
fn test_key_switch() -> Result<()> { fn test_key_switch() -> Result<()> {
const T: u64 = 128; // plaintext modulus let param = Param {
const K: usize = 16; ring: RingParam { q: u64::MAX, n: 1 },
type S = TLWE<K>; k: 16,
t: 128, // plaintext modulus
};
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 64; let l: u32 = 64;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let (sk, pk) = S::new_key(&mut rng)?; let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let (sk2, _) = S::new_key(&mut rng)?; let (sk2, _) = TLWE::new_key(&mut rng, &param)?;
// ksk to switch from sk to sk2 // ksk to switch from sk to sk2
let ksk = S::new_ksk(&mut rng, beta, l, &sk, &sk2)?; let ksk = TLWE::new_ksk(&mut rng, &param, beta, l, &sk, &sk2)?;
let msg_dist = Uniform::new(0_u64, T); let msg_dist = Uniform::new(0_u64, param.t);
let m = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?; let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p = S::encode::<T>(&m); // plaintext let p = TLWE::encode(&param, &m); // plaintext
// //
let c = S::encrypt_s(&mut rng, &sk, &p)?; let c = TLWE::encrypt_s(&mut rng, &param, &sk, &p)?;
let c2 = c.key_switch(beta, l, &ksk); let c2 = c.key_switch(&param, beta, l, &ksk);
// decrypt with the 2nd secret key // decrypt with the 2nd secret key
let p_recovered = c2.decrypt(&sk2); let p_recovered = c2.decrypt(&sk2);
let m_recovered = S::decode::<T>(&p_recovered); let m_recovered = TLWE::decode(&param, &p_recovered);
assert_eq!(m.remodule::<T>(), m_recovered.remodule::<T>()); assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t));
// do the same but now encrypting with pk // do the same but now encrypting with pk
let c = S::encrypt(&mut rng, &pk, &p)?; let c = TLWE::encrypt(&mut rng, &param, &pk, &p)?;
let c2 = c.key_switch(beta, l, &ksk); let c2 = c.key_switch(&param, beta, l, &ksk);
let p_recovered = c2.decrypt(&sk2); let p_recovered = c2.decrypt(&sk2);
let m_recovered = S::decode::<T>(&p_recovered); let m_recovered = TLWE::decode(&param, &p_recovered);
assert_eq!(m, m_recovered); assert_eq!(m, m_recovered);
Ok(()) Ok(())
@@ -418,39 +462,44 @@ mod tests {
#[test] #[test]
fn test_bootstrapping() -> Result<()> { fn test_bootstrapping() -> Result<()> {
const T: u64 = 128; // plaintext modulus let param = Param {
const K: usize = 1; ring: RingParam {
const N: usize = 1024; q: u64::MAX,
const KN: usize = K * N; n: 1024,
},
k: 1,
t: 128, // plaintext modulus
};
// const T: u64 = 128; // plaintext modulus
// const K: usize = 1;
// const N: usize = 1024;
// const KN: usize = K * N;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let start = Instant::now(); let start = Instant::now();
let table: TGLWE<N, K> = compute_lookup_table::<T, K, N>(); let table: TGLWE = compute_lookup_table(&param);
println!("table took: {:?}", start.elapsed()); println!("table took: {:?}", start.elapsed());
let (sk, _) = TGLWE::<N, K>::new_key::<KN>(&mut rng)?; let (sk, _) = TGLWE::new_key(&mut rng, &param)?;
let sk_tlwe: SecretKey<KN> = sk.to_tlwe::<KN>(); let sk_tlwe: SecretKey = sk.to_tlwe(&param);
let start = Instant::now(); let start = Instant::now();
let btk = BootstrappingKey::<N, K, KN>::from_sk(&mut rng, &sk)?; let btk = BootstrappingKey::from_sk(&mut rng, &param, &sk)?;
println!("btk took: {:?}", start.elapsed()); println!("btk took: {:?}", start.elapsed());
let msg_dist = Uniform::new(0_u64, T); let msg_dist = Uniform::new(0_u64, param.t);
let m = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?; let m = Rq::rand_u64(&mut rng, msg_dist, &param.lwe().pt())?; // q=t, n=1
dbg!(&m); let p = TLWE::encode(&param.lwe(), &m); // plaintext
let p = TLWE::<K>::encode::<T>(&m); // plaintext
let c = TLWE::<KN>::encrypt_s(&mut rng, &sk_tlwe, &p)?; let c = TLWE::encrypt_s(&mut rng, &param.lwe(), &sk_tlwe, &p)?;
let start = Instant::now(); let start = Instant::now();
// the ugly const generics are temporary // the ugly const generics are temporary
let bootstrapped: TLWE<KN> = let bootstrapped: TLWE = bootstrapping(&param, btk, table, c);
bootstrapping::<N, K, KN, { K as u64 * N as u64 }>(btk, table, c);
println!("bootstrapping took: {:?}", start.elapsed()); println!("bootstrapping took: {:?}", start.elapsed());
let p_recovered: T64 = bootstrapped.decrypt(&sk_tlwe); let p_recovered: T64 = bootstrapped.decrypt(&sk_tlwe);
let m_recovered = TLWE::<KN>::decode::<T>(&p_recovered); let m_recovered = TLWE::decode(&param.lwe(), &p_recovered);
dbg!(&m_recovered);
assert_eq!(m_recovered, m); assert_eq!(m_recovered, m);
Ok(()) Ok(())