tfhe: add blind_rotation & lookup table computation

This commit is contained in:
2025-08-03 19:12:47 +00:00
parent 7bfcf6f7c1
commit 2c20a2ed0e
5 changed files with 101 additions and 23 deletions

View File

@@ -1,20 +1,17 @@
use anyhow::Result;
use itertools::zip_eq;
use rand::distributions::Standard;
use rand::Rng;
use rand_distr::{Normal, Uniform};
use std::array;
use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, Sub};
use arith::{Ring, Rq, Tn, T64, TR};
use arith::{Ring, Rq, Tn, Zq, T64, TR};
use gfhe::{glwe, GLWE};
use crate::tggsw::TGGSW;
use crate::tlev::TLev;
use crate::{tglwe, tglwe::TGLWE};
// #[derive(Clone, Debug)]
pub struct SecretKey<const K: usize>(pub glwe::SecretKey<T64, K>);
// pub type SecretKey<const K: usize> = glwe::SecretKey<T64, K>;
impl<const KN: usize> SecretKey<KN> {
/// from TFHE [2018-421] paper: A TLWE key k \in B^n, can be interpreted as a
@@ -32,8 +29,6 @@ impl<const KN: usize> SecretKey<KN> {
}
}
// #[derive(Clone, Debug)]
// pub struct PublicKey<const K: usize>(glwe::PublicKey<T64, K>);
pub type PublicKey<const K: usize> = glwe::PublicKey<T64, K>;
#[derive(Clone, Debug)]
@@ -49,14 +44,12 @@ impl<const K: usize> TLWE<K> {
pub fn new_key(rng: impl Rng) -> Result<(SecretKey<K>, PublicKey<K>)> {
let (sk, pk): (glwe::SecretKey<T64, K>, glwe::PublicKey<T64, K>) = GLWE::new_key(rng)?;
// Ok((SecretKey(sk), PublicKey(pk)))
Ok((SecretKey(sk), pk))
}
pub fn encode<const P: u64>(m: &Rq<P, 1>) -> T64 {
let delta = u64::MAX / P; // floored
let coeffs = m.coeffs();
// Tn(array::from_fn(|i| T64(coeffs[i].0 * delta)))
T64(coeffs[0].0 * delta)
}
pub fn decode<const P: u64>(p: &T64) -> Rq<P, 1> {
@@ -105,6 +98,76 @@ impl<const K: usize> TLWE<K> {
lhs - rhs
}
// modulus switch from Q (2^64) to Q2 (in blind_rotation Q2=K*N)
pub fn mod_switch<const Q2: u64>(&self) -> Self {
let a: TR<T64, K> = self.0 .0.mod_switch::<Q2>();
let b: T64 = self.0 .1.mod_switch::<Q2>();
Self(GLWE(a, b))
}
}
// NOTE: the ugly const generics are temporary
pub fn blind_rotation<const N: usize, const K: usize, const KN: usize, const KN2: u64>(
c: TLWE<KN>,
btk: BootstrappingKey<N, K, KN>,
table: TGLWE<N, K>,
) -> TGLWE<N, K> {
let c_kn: TLWE<KN> = c.mod_switch::<KN2>();
let (a, b): (TR<T64, KN>, T64) = (c_kn.0 .0, c_kn.0 .1);
// two main parts: rotate by a known power of X, rotate by a secret
// power of X (using the C gate)
// table * X^-b, ie. left rotate
let v_xb: TGLWE<N, K> = table.left_rotate(b.0 as usize);
// rotate by a secret power of X using the cmux gate
let mut c_j: TGLWE<N, K> = v_xb.clone();
let _ = (1..K).map(|j| {
c_j = TGGSW::<N, K>::cmux(
btk.0[j].clone(),
c_j.clone(),
c_j.clone().left_rotate(a.0[j].0 as usize),
);
dbg!(&c_j);
});
c_j
}
#[derive(Clone, Debug)]
pub struct BootstrappingKey<const N: usize, const K: usize, const KN: usize>(pub Vec<TGGSW<N, K>>);
impl<const N: usize, const K: usize, const KN: usize> BootstrappingKey<N, K, KN> {
pub fn from_sk(mut rng: impl Rng, sk: &tglwe::SecretKey<N, K>) -> Result<Self> {
let (beta, l) = (2u32, 64u32); // TMP
//
let s: TR<Tn<N>, K> = sk.0 .0.clone();
// each btk_j = TGGSW_sk(s_i)
let btk: Vec<TGGSW<N, K>> = s
.iter()
.map(|s_i| TGGSW::<N, K>::encrypt_s(&mut rng, beta, l, sk, s_i))
.collect::<Result<Vec<_>>>()?;
Ok(Self(btk))
}
}
pub fn compute_lookup_table<const T: u64, const K: usize, const N: usize>() -> TGLWE<N, K> {
// from 2021-1402:
// v(x) = \sum_j^{N-1} [(p_j / 2N mod p)/p] X^j
// matrix of coefficients with size K*N = delta x T
let delta: usize = N / T as usize;
let values: Vec<Zq<T>> = (0..T).map(|v| Zq::<T>::from_u64(v)).collect();
let coeffs: Vec<Zq<T>> = (0..T as usize)
.flat_map(|i| vec![values[i]; delta])
.collect();
let table = Rq::<T, N>::from_vec(coeffs);
// encode the table as plaintext
let v: Tn<N> = TGLWE::<N, K>::encode::<T>(&table);
// encode the table as TGLWE ciphertext
let v: TGLWE<N, K> = TGLWE::<N, K>::from_plaintext(v);
v
}
impl<const K: usize> Add<TLWE<K>> for TLWE<K> {
@@ -170,6 +233,7 @@ impl<const K: usize> Mul<T64> for TLWE<K> {
mod tests {
use anyhow::Result;
use rand::distributions::Uniform;
use std::time::Instant;
use super::*;
@@ -186,9 +250,7 @@ mod tests {
let (sk, pk) = S::new_key(&mut rng)?;
let m = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
dbg!(&m);
let p: T64 = S::encode::<T>(&m);
dbg!(&p);
let c = S::encrypt(&mut rng, &pk, &p)?;
let p_recovered = c.decrypt(&sk);