tfhe: get rid of constant generics

This commit is contained in:
2025-08-13 19:31:43 +00:00
parent 2a9cbc71de
commit bb3288f211
16 changed files with 729 additions and 582 deletions

View File

@@ -23,7 +23,7 @@ pub struct Param {
p: u64,
}
impl Param {
// returns the plaintext params
// returns the plaintext param
pub fn pt(&self) -> RingParam {
RingParam {
q: self.t,
@@ -117,7 +117,7 @@ impl BFV {
// const DELTA: u64 = Q / T; // floor
/// generate a new key pair (privK, pubK)
pub fn new_key(mut rng: impl Rng, params: &Param) -> Result<(SecretKey, PublicKey)> {
pub fn new_key(mut rng: impl Rng, param: &Param) -> Result<(SecretKey, PublicKey)> {
// WIP: review probabilities
// let Xi_key = Uniform::new(-1_f64, 1_f64);
@@ -126,37 +126,37 @@ impl BFV {
// secret key
// let mut s = Rq::rand_f64(&mut rng, Xi_key)?;
let mut s = Rq::rand_u64(&mut rng, Xi_key, &params.ring)?;
let mut s = Rq::rand_u64(&mut rng, Xi_key, &param.ring)?;
// since s is going to be multiplied by other Rq elements, already
// compute its NTT
s.compute_evals();
// pk = (-a * s + e, a)
let a = Rq::rand_u64(&mut rng, Uniform::new(0_u64, params.ring.q), &params.ring)?;
let e = Rq::rand_f64(&mut rng, Xi_err, &params.ring)?;
let a = Rq::rand_u64(&mut rng, Uniform::new(0_u64, param.ring.q), &param.ring)?;
let e = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
let pk: PublicKey = PublicKey(&(&(-a.clone()) * &s) + &e, a.clone()); // TODO rm clones
Ok((SecretKey(s), pk))
}
// note: m is modulus t
pub fn encrypt(mut rng: impl Rng, params: &Param, pk: &PublicKey, m: &Rq) -> Result<RLWE> {
// assert params & inputs
debug_assert_eq!(params.ring, pk.0.param);
debug_assert_eq!(params.t, m.param.q);
debug_assert_eq!(params.ring.n, m.param.n);
pub fn encrypt(mut rng: impl Rng, param: &Param, pk: &PublicKey, m: &Rq) -> Result<RLWE> {
// assert param & inputs
debug_assert_eq!(param.ring, pk.0.param);
debug_assert_eq!(param.t, m.param.q);
debug_assert_eq!(param.ring.n, m.param.n);
let Xi_key = Uniform::new(-1_f64, 1_f64);
// let Xi_key = Uniform::new(0_u64, 2_u64);
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
let u = Rq::rand_f64(&mut rng, Xi_key, &params.ring)?;
let u = Rq::rand_f64(&mut rng, Xi_key, &param.ring)?;
// let u = Rq::rand_u64(&mut rng, Xi_key)?;
let e_1 = Rq::rand_f64(&mut rng, Xi_err, &params.ring)?;
let e_2 = Rq::rand_f64(&mut rng, Xi_err, &params.ring)?;
let e_1 = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
let e_2 = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
// migrate m's coeffs to the bigger modulus Q (from T)
let m = m.remodule(params.ring.q);
let c0 = &pk.0 * &u + e_1 + m * (params.ring.q / params.t); // floor(q/t)=DELTA
let m = m.remodule(param.ring.q);
let c0 = &pk.0 * &u + e_1 + m * (param.ring.q / param.t); // floor(q/t)=DELTA
let c1 = &pk.1 * &u + e_2;
Ok(RLWE(c0, c1))
}
@@ -280,7 +280,7 @@ mod tests {
#[test]
fn test_encrypt_decrypt() -> Result<()> {
let params = Param {
let param = Param {
ring: RingParam {
q: 2u64.pow(16) + 1, // q prime, and 2^q + 1 shape
n: 512,
@@ -292,13 +292,13 @@ mod tests {
let mut rng = rand::thread_rng();
for _ in 0..100 {
let (sk, pk) = BFV::new_key(&mut rng, &params)?;
let (sk, pk) = BFV::new_key(&mut rng, &param)?;
let msg_dist = Uniform::new(0_u64, params.t);
let m = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?;
let msg_dist = Uniform::new(0_u64, param.t);
let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let c = BFV::encrypt(&mut rng, &params, &pk, &m)?;
let m_recovered = BFV::decrypt(&params, &sk, &c);
let c = BFV::encrypt(&mut rng, &param, &pk, &m)?;
let m_recovered = BFV::decrypt(&param, &sk, &c);
assert_eq!(m, m_recovered);
}
@@ -308,7 +308,7 @@ mod tests {
#[test]
fn test_addition() -> Result<()> {
let params = Param {
let param = Param {
ring: RingParam {
q: 2u64.pow(16) + 1, // q prime, and 2^q + 1 shape
n: 128,
@@ -320,18 +320,18 @@ mod tests {
let mut rng = rand::thread_rng();
for _ in 0..100 {
let (sk, pk) = BFV::new_key(&mut rng, &params)?;
let (sk, pk) = BFV::new_key(&mut rng, &param)?;
let msg_dist = Uniform::new(0_u64, params.t);
let m1 = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?;
let msg_dist = Uniform::new(0_u64, param.t);
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let c1 = BFV::encrypt(&mut rng, &params, &pk, &m1)?;
let c2 = BFV::encrypt(&mut rng, &params, &pk, &m2)?;
let c1 = BFV::encrypt(&mut rng, &param, &pk, &m1)?;
let c2 = BFV::encrypt(&mut rng, &param, &pk, &m2)?;
let c3 = c1 + c2;
let m3_recovered = BFV::decrypt(&params, &sk, &c3);
let m3_recovered = BFV::decrypt(&param, &sk, &c3);
assert_eq!(m1 + m2, m3_recovered);
}
@@ -342,7 +342,7 @@ mod tests {
#[test]
fn test_constant_add_mul() -> Result<()> {
let q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape
let params = Param {
let param = Param {
ring: RingParam { q, n: 16 },
t: 8, // plaintext modulus
p: q * q,
@@ -350,26 +350,26 @@ mod tests {
let mut rng = rand::thread_rng();
let (sk, pk) = BFV::new_key(&mut rng, &params)?;
let (sk, pk) = BFV::new_key(&mut rng, &param)?;
let msg_dist = Uniform::new(0_u64, params.t);
let m1 = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?;
let m2_const = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?;
let c1 = BFV::encrypt(&mut rng, &params, &pk, &m1)?;
let msg_dist = Uniform::new(0_u64, param.t);
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2_const = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let c1 = BFV::encrypt(&mut rng, &param, &pk, &m1)?;
let c3_add = &c1 + &m2_const;
let m3_add_recovered = BFV::decrypt(&params, &sk, &c3_add);
let m3_add_recovered = BFV::decrypt(&param, &sk, &c3_add);
assert_eq!(&m1 + &m2_const, m3_add_recovered);
// test multiplication of a ciphertext by a constant
let rlk = BFV::rlk_key(&mut rng, &params, &sk)?;
let rlk = BFV::rlk_key(&mut rng, &param, &sk)?;
let c3_mul = BFV::mul_const(&rlk, &c1, &m2_const);
let m3_mul_recovered = BFV::decrypt(&params, &sk, &c3_mul);
let m3_mul_recovered = BFV::decrypt(&param, &sk, &c3_mul);
assert_eq!(
(m1.to_r() * m2_const.to_r()).to_rq(params.t).coeffs(),
(m1.to_r() * m2_const.to_r()).to_rq(param.t).coeffs(),
m3_mul_recovered.coeffs()
);
@@ -380,7 +380,7 @@ mod tests {
// TMP WIP
#[test]
#[ignore]
fn test_params() -> Result<()> {
fn test_param() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape
const N: usize = 32;
const T: u64 = 8; // plaintext modulus
@@ -504,30 +504,30 @@ mod tests {
#[test]
fn test_tensor() -> Result<()> {
let q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape
let params = Param {
let param = Param {
ring: RingParam { q, n: 16 },
t: 2, // plaintext modulus
p: q * q,
};
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, params.t);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..1_000 {
let m1 = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?;
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
test_tensor_opt(&mut rng, &params, m1, m2)?;
test_tensor_opt(&mut rng, &param, m1, m2)?;
}
Ok(())
}
fn test_tensor_opt(mut rng: impl Rng, params: &Param, m1: Rq, m2: Rq) -> Result<()> {
let (sk, pk) = BFV::new_key(&mut rng, &params)?;
fn test_tensor_opt(mut rng: impl Rng, param: &Param, m1: Rq, m2: Rq) -> Result<()> {
let (sk, pk) = BFV::new_key(&mut rng, &param)?;
let c1 = BFV::encrypt(&mut rng, &params, &pk, &m1)?;
let c2 = BFV::encrypt(&mut rng, &params, &pk, &m2)?;
let c1 = BFV::encrypt(&mut rng, &param, &pk, &m1)?;
let c2 = BFV::encrypt(&mut rng, &param, &pk, &m2)?;
let (c_a, c_b, c_c) = RLWE::tensor(params.t, &c1, &c2);
let (c_a, c_b, c_c) = RLWE::tensor(param.t, &c1, &c2);
// let (c_a, c_b, c_c) = RLWE::tensor_new::<PQ, T>(&c1, &c2);
// decrypt non-relinearized mul result
@@ -539,10 +539,10 @@ mod tests {
// &c_c.to_r(),
// &R::<N>::from_vec(arith::ring_n::naive_mul(&sk.0.to_r(), &sk.0.to_r())),
// ));
let m3: Rq = m3.mul_div_round(params.t, params.ring.q); // descale
let m3 = m3.remodule(params.t);
let m3: Rq = m3.mul_div_round(param.t, param.ring.q); // descale
let m3 = m3.remodule(param.t);
let naive = (m1.clone().to_r() * m2.clone().to_r()).to_rq(params.t); // TODO rm clones
let naive = (m1.clone().to_r() * m2.clone().to_r()).to_rq(param.t); // TODO rm clones
assert_eq!(
m3.coeffs().to_vec(),
naive.coeffs().to_vec(),
@@ -557,38 +557,38 @@ mod tests {
#[test]
fn test_mul_relin() -> Result<()> {
let q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape
let params = Param {
let param = Param {
ring: RingParam { q, n: 16 },
t: 2, // plaintext modulus
p: q * q,
};
let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, params.t);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..1_000 {
let m1 = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &params.pt())?;
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
test_mul_relin_opt(&mut rng, &params, m1, m2)?;
test_mul_relin_opt(&mut rng, &param, m1, m2)?;
}
Ok(())
}
fn test_mul_relin_opt(mut rng: impl Rng, params: &Param, m1: Rq, m2: Rq) -> Result<()> {
let (sk, pk) = BFV::new_key(&mut rng, &params)?;
fn test_mul_relin_opt(mut rng: impl Rng, param: &Param, m1: Rq, m2: Rq) -> Result<()> {
let (sk, pk) = BFV::new_key(&mut rng, &param)?;
let rlk = BFV::rlk_key(&mut rng, &params, &sk)?;
let rlk = BFV::rlk_key(&mut rng, &param, &sk)?;
let c1 = BFV::encrypt(&mut rng, &params, &pk, &m1)?;
let c2 = BFV::encrypt(&mut rng, &params, &pk, &m2)?;
let c1 = BFV::encrypt(&mut rng, &param, &pk, &m1)?;
let c2 = BFV::encrypt(&mut rng, &param, &pk, &m2)?;
let c3 = RLWE::mul(params.t, &rlk, &c1, &c2); // uses relinearize internally
let c3 = RLWE::mul(param.t, &rlk, &c1, &c2); // uses relinearize internally
let m3 = BFV::decrypt(&params, &sk, &c3);
let m3 = BFV::decrypt(&param, &sk, &c3);
let naive = (m1.clone().to_r() * m2.clone().to_r()).to_rq(params.t); // TODO rm clones
let naive = (m1.clone().to_r() * m2.clone().to_r()).to_rq(param.t); // TODO rm clones
assert_eq!(
m3.coeffs().to_vec(),
naive.coeffs().to_vec(),