Rm const generics (#2)

* arith: get rid of constant generics. Reason:

using constant generics was great for allocating the arrays in the
stack, which is faster, but when started to use bigger parameter values,
in some cases it was overflowing the stack. This commit removes all the
constant generics in all of the `arith` crate, which in some cases slows
a bit the performance, but allows for bigger parameter values (on the
ones that affect lengths, like N and K).

* bfv: get rid of constant generics (reason in previous commit)

* ckks: get rid of constant generics (reason in two commits ago)

* group ring params under a single struct

* gfhe: get rid of constant generics

* tfhe: get rid of constant generics

* polish & clean a bit

* add methods for encoding constants for ct-pt-multiplication
This commit is contained in:
2025-08-14 18:32:43 +02:00
committed by GitHub
parent 13abadf6e1
commit fb1fb6b4e9
23 changed files with 2737 additions and 1887 deletions

View File

@@ -4,242 +4,275 @@ use rand::Rng;
use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, Sub};
use arith::{Ring, Rq, Tn, Zq, T64, TR};
use gfhe::{glwe, GLWE};
use arith::{Ring, RingParam, Rq, Tn, Zq, T64, TR};
use gfhe::{glwe, glwe::Param, GLWE};
use crate::tggsw::TGGSW;
use crate::tlev::TLev;
use crate::{tglwe, tglwe::TGLWE};
pub struct SecretKey<const K: usize>(pub glwe::SecretKey<T64, K>);
pub struct SecretKey(pub glwe::SecretKey<T64>);
impl<const KN: usize> SecretKey<KN> {
impl SecretKey {
/// from TFHE [2018-421] paper: A TLWE key k \in B^n, can be interpreted as a
/// TRLWE key K \in B_N[X]^k having the same sequence of coefficients and
/// vice-versa.
pub fn to_tglwe<const N: usize, const K: usize>(self) -> crate::tglwe::SecretKey<N, K> {
let s: TR<T64, KN> = self.0 .0;
pub fn to_tglwe(self, param: &Param) -> crate::tglwe::SecretKey {
let s: TR<T64> = self.0 .0; // of length K*N
assert_eq!(s.r.len(), param.k * param.ring.n); // sanity check
// split into K vectors, and interpret each of them as a T_N[X]/(X^N+1)
// polynomial
let r: Vec<Tn<N>> =
s.0.chunks(N)
.map(|v| Tn::<N>::from_vec(v.to_vec()))
let r: Vec<Tn> =
s.r.chunks(param.ring.n)
.map(|v| Tn::from_vec(&param.ring, v.to_vec()))
.collect();
crate::tglwe::SecretKey(glwe::SecretKey::<Tn<N>, K>(TR(r)))
crate::tglwe::SecretKey(glwe::SecretKey::<Tn>(TR { k: param.k, r }))
}
}
pub type PublicKey<const K: usize> = glwe::PublicKey<T64, K>;
pub type PublicKey = glwe::PublicKey<T64>;
#[derive(Clone, Debug)]
pub struct KSK<const K: usize>(Vec<TLev<K>>);
pub struct KSK(Vec<TLev>);
#[derive(Clone, Debug)]
pub struct TLWE<const K: usize>(pub GLWE<T64, K>);
pub struct TLWE(pub GLWE<T64>);
impl<const K: usize> TLWE<K> {
pub fn zero() -> Self {
Self(GLWE::<T64, K>::zero())
impl TLWE {
pub fn zero(k: usize, ring_param: &RingParam) -> Self {
Self(GLWE::<T64>::zero(k, ring_param))
}
pub fn new_key(rng: impl Rng) -> Result<(SecretKey<K>, PublicKey<K>)> {
let (sk, pk): (glwe::SecretKey<T64, K>, glwe::PublicKey<T64, K>) = GLWE::new_key(rng)?;
pub fn new_key(rng: impl Rng, param: &Param) -> Result<(SecretKey, PublicKey)> {
let (sk, pk): (glwe::SecretKey<T64>, glwe::PublicKey<T64>) = GLWE::new_key(rng, param)?;
Ok((SecretKey(sk), pk))
}
pub fn encode<const P: u64>(m: &Rq<P, 1>) -> T64 {
let delta = u64::MAX / P; // floored
pub fn encode(param: &Param, m: &Rq) -> T64 {
assert_eq!(param.ring.n, 1);
debug_assert_eq!(param.t, m.param.q); // plaintext modulus
let delta = u64::MAX / param.t; // floored
let coeffs = m.coeffs();
T64(coeffs[0].0 * delta)
T64(coeffs[0].v * delta)
}
pub fn decode<const P: u64>(p: &T64) -> Rq<P, 1> {
let p = p.mul_div_round(P, u64::MAX);
Rq::<P, 1>::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect())
pub fn decode(param: &Param, p: &T64) -> Rq {
let p = p.mul_div_round(param.t, u64::MAX);
Rq::from_vec_u64(&param.pt(), p.coeffs().iter().map(|c| c.0).collect())
}
/// encodes the given message as a TLWE constant/public value, for using it
/// in ct-pt-multiplication.
pub fn new_const(param: &Param, m: &Rq) -> T64 {
debug_assert_eq!(param.t, m.param.q);
T64(m.coeffs()[0].v)
}
// encrypts with the given SecretKey (instead of PublicKey)
pub fn encrypt_s(rng: impl Rng, sk: &SecretKey<K>, p: &T64) -> Result<Self> {
let glwe = GLWE::encrypt_s(rng, &sk.0, p)?;
pub fn encrypt_s(rng: impl Rng, param: &Param, sk: &SecretKey, p: &T64) -> Result<Self> {
let glwe = GLWE::encrypt_s(rng, param, &sk.0, p)?;
Ok(Self(glwe))
}
pub fn encrypt(rng: impl Rng, pk: &PublicKey<K>, p: &T64) -> Result<Self> {
let glwe = GLWE::encrypt(rng, &pk, p)?;
pub fn encrypt(rng: impl Rng, param: &Param, pk: &PublicKey, p: &T64) -> Result<Self> {
let glwe = GLWE::encrypt(rng, param, pk, p)?;
Ok(Self(glwe))
}
pub fn decrypt(&self, sk: &SecretKey<K>) -> T64 {
pub fn decrypt(&self, sk: &SecretKey) -> T64 {
self.0.decrypt(&sk.0)
}
pub fn new_ksk(
mut rng: impl Rng,
param: &Param,
beta: u32,
l: u32,
sk: &SecretKey<K>,
new_sk: &SecretKey<K>,
) -> Result<KSK<K>> {
let r: Vec<TLev<K>> = (0..K)
sk: &SecretKey,
new_sk: &SecretKey,
) -> Result<KSK> {
let r: Vec<TLev> = (0..param.k)
.into_iter()
.map(|i|
// treat sk_i as the msg being encrypted
TLev::<K>::encrypt_s(&mut rng, beta, l, &new_sk, &sk.0.0 .0[i]))
TLev::encrypt_s(&mut rng, param, beta, l, &new_sk, &sk.0.0 .r[i]))
.collect::<Result<Vec<_>>>()?;
Ok(KSK(r))
}
pub fn key_switch(&self, beta: u32, l: u32, ksk: &KSK<K>) -> Self {
let (a, b): (TR<T64, K>, T64) = (self.0 .0.clone(), self.0 .1);
pub fn key_switch(&self, param: &Param, beta: u32, l: u32, ksk: &KSK) -> Self {
let (a, b): (TR<T64>, T64) = (self.0 .0.clone(), self.0 .1);
let lhs: TLWE<K> = TLWE(GLWE(TR::zero(), b));
let lhs: TLWE = TLWE(GLWE(TR::zero(param.k * param.ring.n, &param.ring), b));
// K iterations, ksk.0 contains K times GLev
let rhs: TLWE<K> = zip_eq(a.0, ksk.0.clone())
let rhs: TLWE = zip_eq(a.r, ksk.0.clone())
.map(|(a_i, ksk_i)| ksk_i * a_i.decompose(beta, l)) // dot_product
.sum();
lhs - rhs
}
// modulus switch from Q (2^64) to Q2 (in blind_rotation Q2=K*N)
pub fn mod_switch<const Q2: u64>(&self) -> Self {
let a: TR<T64, K> = self.0 .0.mod_switch::<Q2>();
let b: T64 = self.0 .1.mod_switch::<Q2>();
pub fn mod_switch(&self, q2: u64) -> Self {
let a: TR<T64> = self.0 .0.mod_switch(q2);
let b: T64 = self.0 .1.mod_switch(q2);
Self(GLWE(a, b))
}
}
// NOTE: the ugly const generics are temporary
pub fn blind_rotation<const N: usize, const K: usize, const KN: usize, const KN2: u64>(
c: TLWE<KN>,
btk: BootstrappingKey<N, K, KN>,
table: TGLWE<N, K>,
) -> TGLWE<N, K> {
let c_kn: TLWE<KN> = c.mod_switch::<KN2>();
let (a, b): (TR<T64, KN>, T64) = (c_kn.0 .0, c_kn.0 .1);
pub fn blind_rotation(
param: &Param,
c: TLWE, // kn
btk: BootstrappingKey,
table: TGLWE, // n,k
) -> TGLWE {
debug_assert_eq!(c.0 .0.k, param.k);
// TODO replace `param.k*param.ring.n` by `param.kn()`
let c_kn: TLWE = c.mod_switch((param.k * param.ring.n) as u64);
let (a, b): (TR<T64>, T64) = (c_kn.0 .0, c_kn.0 .1);
// two main parts: rotate by a known power of X, rotate by a secret
// power of X (using the C gate)
// table * X^-b, ie. left rotate
let v_xb: TGLWE<N, K> = table.left_rotate(b.0 as usize);
let v_xb: TGLWE = table.left_rotate(b.0 as usize);
// rotate by a secret power of X using the cmux gate
let mut c_j: TGLWE<N, K> = v_xb.clone();
let _ = (1..K).map(|j| {
c_j = TGGSW::<N, K>::cmux(
let mut c_j: TGLWE = v_xb.clone();
let _ = (1..param.k).map(|j| {
c_j = TGGSW::cmux(
btk.0[j].clone(),
c_j.clone(),
c_j.clone().left_rotate(a.0[j].0 as usize),
c_j.clone().left_rotate(a.r[j].0 as usize),
);
dbg!(&c_j);
});
c_j
}
pub fn bootstrapping<const N: usize, const K: usize, const KN: usize, const KN2: u64>(
btk: BootstrappingKey<N, K, KN>,
table: TGLWE<N, K>,
c: TLWE<KN>,
) -> TLWE<KN> {
let rotated: TGLWE<N, K> = blind_rotation::<N, K, KN, KN2>(c, btk.clone(), table);
let c_h: TLWE<KN> = rotated.sample_extraction(0);
let r = c_h.key_switch(2, 64, &btk.1);
pub fn bootstrapping(
param: &Param,
btk: BootstrappingKey,
table: TGLWE,
c: TLWE, // kn
) -> TLWE {
// kn
let rotated: TGLWE = blind_rotation(param, c, btk.clone(), table);
let c_h: TLWE = rotated.sample_extraction(&param, 0);
let r = c_h.key_switch(param, 2, 64, &btk.1);
r
}
#[derive(Clone, Debug)]
pub struct BootstrappingKey<const N: usize, const K: usize, const KN: usize>(
pub Vec<TGGSW<N, K>>,
pub KSK<KN>,
pub struct BootstrappingKey(
pub Vec<TGGSW>,
pub KSK, // kn
);
impl<const N: usize, const K: usize, const KN: usize> BootstrappingKey<N, K, KN> {
pub fn from_sk(mut rng: impl Rng, sk: &tglwe::SecretKey<N, K>) -> Result<Self> {
impl BootstrappingKey {
pub fn from_sk(mut rng: impl Rng, param: &Param, sk: &tglwe::SecretKey) -> Result<Self> {
let (beta, l) = (2u32, 64u32); // TMP
//
let s: TR<Tn<N>, K> = sk.0 .0.clone();
let (sk2, _) = TLWE::<KN>::new_key(&mut rng)?; // TLWE<KN> compatible with TGLWE<N,K>
let s: TR<Tn> = sk.0 .0.clone();
let (sk2, _) = TLWE::new_key(&mut rng, &param.lwe())?; // TLWE<KN> compatible with TGLWE<N,K>
// each btk_j = TGGSW_sk(s_i)
let btk: Vec<TGGSW<N, K>> = s
let btk: Vec<TGGSW> = s
.iter()
.map(|s_i| TGGSW::<N, K>::encrypt_s(&mut rng, beta, l, sk, s_i))
.map(|s_i| TGGSW::encrypt_s(&mut rng, param, beta, l, sk, s_i))
.collect::<Result<Vec<_>>>()?;
let ksk = TLWE::<KN>::new_ksk(&mut rng, beta, l, &sk.to_tlwe(), &sk2)?;
let ksk = TLWE::new_ksk(
&mut rng,
&param.lwe(),
beta,
l,
&sk.to_tlwe(&param.lwe()), // converted to length k*n
&sk2, // created with length k*n
)?;
debug_assert_eq!(ksk.0.len(), param.lwe().k);
debug_assert_eq!(ksk.0.len(), param.k * param.ring.n);
Ok(Self(btk, ksk))
}
}
pub fn compute_lookup_table<const T: u64, const K: usize, const N: usize>() -> TGLWE<N, K> {
pub fn compute_lookup_table(param: &Param) -> TGLWE {
// from 2021-1402:
// v(x) = \sum_j^{N-1} [(p_j / 2N mod p)/p] X^j
// matrix of coefficients with size K*N = delta x T
let delta: usize = N / T as usize;
let values: Vec<Zq<T>> = (0..T).map(|v| Zq::<T>::from_u64(v)).collect();
let coeffs: Vec<Zq<T>> = (0..T as usize)
let delta: usize = param.ring.n / param.t as usize;
let values: Vec<Zq> = (0..param.t).map(|v| Zq::from_u64(param.t, v)).collect();
let coeffs: Vec<Zq> = (0..param.t as usize)
.flat_map(|i| vec![values[i]; delta])
.collect();
let table = Rq::<T, N>::from_vec(coeffs);
let table = Rq::from_vec(&param.pt(), coeffs);
// encode the table as plaintext
let v: Tn<N> = TGLWE::<N, K>::encode::<T>(&table);
let v: Tn = TGLWE::encode(param, &table);
// encode the table as TGLWE ciphertext
let v: TGLWE<N, K> = TGLWE::<N, K>::from_plaintext(v);
let v: TGLWE = TGLWE::from_plaintext(param.k, &param.ring, v);
v
}
impl<const K: usize> Add<TLWE<K>> for TLWE<K> {
impl Add<TLWE> for TLWE {
type Output = Self;
fn add(self, other: Self) -> Self {
debug_assert_eq!(self.0 .0.k, other.0 .0.k);
debug_assert_eq!(self.0 .1.param(), other.0 .1.param());
Self(self.0 + other.0)
}
}
impl<const K: usize> AddAssign for TLWE<K> {
impl AddAssign for TLWE {
fn add_assign(&mut self, rhs: Self) {
debug_assert_eq!(self.0 .0.k, rhs.0 .0.k);
debug_assert_eq!(self.0 .1.param(), rhs.0 .1.param());
self.0 += rhs.0
}
}
impl<const K: usize> Sum<TLWE<K>> for TLWE<K> {
fn sum<I>(iter: I) -> Self
impl Sum<TLWE> for TLWE {
fn sum<I>(mut iter: I) -> Self
where
I: Iterator<Item = Self>,
{
let mut acc = TLWE::<K>::zero();
for e in iter {
acc += e;
}
acc
let first = iter.next().unwrap();
iter.fold(first, |acc, e| acc + e)
}
}
impl<const K: usize> Sub<TLWE<K>> for TLWE<K> {
impl Sub<TLWE> for TLWE {
type Output = Self;
fn sub(self, other: Self) -> Self {
debug_assert_eq!(self.0 .0.k, other.0 .0.k);
debug_assert_eq!(self.0 .1.param(), other.0 .1.param());
Self(self.0 - other.0)
}
}
// plaintext addition
impl<const K: usize> Add<T64> for TLWE<K> {
impl Add<T64> for TLWE {
type Output = Self;
fn add(self, plaintext: T64) -> Self {
let a: TR<T64, K> = self.0 .0;
let a: TR<T64> = self.0 .0;
let b: T64 = self.0 .1 + plaintext;
Self(GLWE(a, b))
}
}
// plaintext substraction
impl<const K: usize> Sub<T64> for TLWE<K> {
impl Sub<T64> for TLWE {
type Output = Self;
fn sub(self, plaintext: T64) -> Self {
let a: TR<T64, K> = self.0 .0;
let a: TR<T64> = self.0 .0;
let b: T64 = self.0 .1 - plaintext;
Self(GLWE(a, b))
}
}
// plaintext multiplication
impl<const K: usize> Mul<T64> for TLWE<K> {
impl Mul<T64> for TLWE {
type Output = Self;
fn mul(self, plaintext: T64) -> Self {
let a: TR<T64, K> = TR(self.0 .0 .0.iter().map(|r_i| *r_i * plaintext).collect());
let a: TR<T64> = TR {
k: self.0 .0.k,
r: self.0 .0.r.iter().map(|r_i| *r_i * plaintext).collect(),
};
let b: T64 = self.0 .1 * plaintext;
Self(GLWE(a, b))
}
@@ -255,29 +288,32 @@ mod tests {
#[test]
fn test_encrypt_decrypt() -> Result<()> {
const T: u64 = 128; // msg space (msg modulus)
const K: usize = 16;
type S = TLWE<K>;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let p: T64 = S::encode::<T>(&m);
let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p: T64 = TLWE::encode(&param, &m);
let c = S::encrypt(&mut rng, &pk, &p)?;
let c = TLWE::encrypt(&mut rng, &param, &pk, &p)?;
let p_recovered = c.decrypt(&sk);
let m_recovered = S::decode::<T>(&p_recovered);
let m_recovered = TLWE::decode(&param, &p_recovered);
assert_eq!(m, m_recovered);
// same but using encrypt_s (with sk instead of pk))
let c = S::encrypt_s(&mut rng, &sk, &p)?;
let c = TLWE::encrypt_s(&mut rng, &param, &sk, &p)?;
let p_recovered = c.decrypt(&sk);
let m_recovered = S::decode::<T>(&p_recovered);
let m_recovered = TLWE::decode(&param, &p_recovered);
assert_eq!(m, m_recovered);
}
@@ -287,30 +323,33 @@ mod tests {
#[test]
fn test_addition() -> Result<()> {
const T: u64 = 128;
const K: usize = 16;
type S = TLWE<K>;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m1 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let m2 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let p1: T64 = S::encode::<T>(&m1); // plaintext
let p2: T64 = S::encode::<T>(&m2); // plaintext
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: T64 = TLWE::encode(&param, &m1); // plaintext
let p2: T64 = TLWE::encode(&param, &m2); // plaintext
let c1 = S::encrypt(&mut rng, &pk, &p1)?;
let c2 = S::encrypt(&mut rng, &pk, &p2)?;
let c1 = TLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c2 = TLWE::encrypt(&mut rng, &param, &pk, &p2)?;
let c3 = c1 + c2;
let p3_recovered = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered);
let m3_recovered = TLWE::decode(&param, &p3_recovered);
assert_eq!((m1 + m2).remodule::<T>(), m3_recovered.remodule::<T>());
assert_eq!((m1 + m2).remodule(param.t), m3_recovered.remodule(param.t));
}
Ok(())
@@ -318,27 +357,30 @@ mod tests {
#[test]
fn test_add_plaintext() -> Result<()> {
const T: u64 = 128;
const K: usize = 16;
type S = TLWE<K>;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m1 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let m2 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let p1: T64 = S::encode::<T>(&m1); // plaintext
let p2: T64 = S::encode::<T>(&m2); // plaintext
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: T64 = TLWE::encode(&param, &m1); // plaintext
let p2: T64 = TLWE::encode(&param, &m2); // plaintext
let c1 = S::encrypt(&mut rng, &pk, &p1)?;
let c1 = TLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c3 = c1 + p2;
let p3_recovered = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered);
let m3_recovered = TLWE::decode(&param, &p3_recovered);
assert_eq!(m1 + m2, m3_recovered);
}
@@ -348,30 +390,31 @@ mod tests {
#[test]
fn test_mul_plaintext() -> Result<()> {
const T: u64 = 128;
const K: usize = 16;
type S = TLWE<K>;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m1 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let m2 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let p1: T64 = S::encode::<T>(&m1);
// don't scale up p2, set it directly from m2
// let p2: T64 = Tn(array::from_fn(|i| T64(m2.coeffs()[i].0)));
let p2: T64 = T64(m2.coeffs()[0].0);
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: T64 = TLWE::encode(&param, &m1);
let p2: T64 = TLWE::new_const(&param, &m2); // as constant/public value
let c1 = S::encrypt(&mut rng, &pk, &p1)?;
let c1 = TLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c3 = c1 * p2;
let p3_recovered: T64 = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq::<T>(), m3_recovered);
let m3_recovered = TLWE::decode(&param, &p3_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), m3_recovered);
}
Ok(())
@@ -379,38 +422,41 @@ mod tests {
#[test]
fn test_key_switch() -> Result<()> {
const T: u64 = 128; // plaintext modulus
const K: usize = 16;
type S = TLWE<K>;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 128, // plaintext modulus
};
let beta: u32 = 2;
let l: u32 = 64;
let mut rng = rand::thread_rng();
let (sk, pk) = S::new_key(&mut rng)?;
let (sk2, _) = S::new_key(&mut rng)?;
let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let (sk2, _) = TLWE::new_key(&mut rng, &param)?;
// ksk to switch from sk to sk2
let ksk = S::new_ksk(&mut rng, beta, l, &sk, &sk2)?;
let ksk = TLWE::new_ksk(&mut rng, &param, beta, l, &sk, &sk2)?;
let msg_dist = Uniform::new(0_u64, T);
let m = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let p = S::encode::<T>(&m); // plaintext
//
let c = S::encrypt_s(&mut rng, &sk, &p)?;
let msg_dist = Uniform::new(0_u64, param.t);
let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p = TLWE::encode(&param, &m); // plaintext
let c2 = c.key_switch(beta, l, &ksk);
let c = TLWE::encrypt_s(&mut rng, &param, &sk, &p)?;
let c2 = c.key_switch(&param, beta, l, &ksk);
// decrypt with the 2nd secret key
let p_recovered = c2.decrypt(&sk2);
let m_recovered = S::decode::<T>(&p_recovered);
assert_eq!(m.remodule::<T>(), m_recovered.remodule::<T>());
let m_recovered = TLWE::decode(&param, &p_recovered);
assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t));
// do the same but now encrypting with pk
let c = S::encrypt(&mut rng, &pk, &p)?;
let c2 = c.key_switch(beta, l, &ksk);
let c = TLWE::encrypt(&mut rng, &param, &pk, &p)?;
let c2 = c.key_switch(&param, beta, l, &ksk);
let p_recovered = c2.decrypt(&sk2);
let m_recovered = S::decode::<T>(&p_recovered);
let m_recovered = TLWE::decode(&param, &p_recovered);
assert_eq!(m, m_recovered);
Ok(())
@@ -418,39 +464,40 @@ mod tests {
#[test]
fn test_bootstrapping() -> Result<()> {
const T: u64 = 128; // plaintext modulus
const K: usize = 1;
const N: usize = 1024;
const KN: usize = K * N;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam {
q: u64::MAX,
n: 1024,
},
k: 1,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng();
let start = Instant::now();
let table: TGLWE<N, K> = compute_lookup_table::<T, K, N>();
let table: TGLWE = compute_lookup_table(&param);
println!("table took: {:?}", start.elapsed());
let (sk, _) = TGLWE::<N, K>::new_key::<KN>(&mut rng)?;
let sk_tlwe: SecretKey<KN> = sk.to_tlwe::<KN>();
let (sk, _) = TGLWE::new_key(&mut rng, &param)?;
let sk_tlwe: SecretKey = sk.to_tlwe(&param);
let start = Instant::now();
let btk = BootstrappingKey::<N, K, KN>::from_sk(&mut rng, &sk)?;
let btk = BootstrappingKey::from_sk(&mut rng, &param, &sk)?;
println!("btk took: {:?}", start.elapsed());
let msg_dist = Uniform::new(0_u64, T);
let m = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
dbg!(&m);
let p = TLWE::<K>::encode::<T>(&m); // plaintext
let msg_dist = Uniform::new(0_u64, param.t);
let m = Rq::rand_u64(&mut rng, msg_dist, &param.lwe().pt())?; // q=t, n=1
let p = TLWE::encode(&param.lwe(), &m); // plaintext
let c = TLWE::<KN>::encrypt_s(&mut rng, &sk_tlwe, &p)?;
let c = TLWE::encrypt_s(&mut rng, &param.lwe(), &sk_tlwe, &p)?;
let start = Instant::now();
// the ugly const generics are temporary
let bootstrapped: TLWE<KN> =
bootstrapping::<N, K, KN, { K as u64 * N as u64 }>(btk, table, c);
let bootstrapped: TLWE = bootstrapping(&param, btk, table, c);
println!("bootstrapping took: {:?}", start.elapsed());
let p_recovered: T64 = bootstrapped.decrypt(&sk_tlwe);
let m_recovered = TLWE::<KN>::decode::<T>(&p_recovered);
dbg!(&m_recovered);
let m_recovered = TLWE::decode(&param.lwe(), &p_recovered);
assert_eq!(m_recovered, m);
Ok(())