tfhe: get rid of constant generics

This commit is contained in:
2025-08-13 19:31:43 +00:00
parent 2a9cbc71de
commit bb3288f211
16 changed files with 729 additions and 582 deletions

View File

@@ -19,7 +19,7 @@ pub use encoder::Encoder;
const ERR_SIGMA: f64 = 3.2;
#[derive(Clone, Copy, Debug)]
pub struct Params {
pub struct Param {
ring: RingParam,
t: u64,
}
@@ -30,33 +30,33 @@ pub struct PublicKey(Rq, Rq);
pub struct SecretKey(Rq);
pub struct CKKS {
params: Params,
param: Param,
encoder: Encoder,
}
impl CKKS {
pub fn new(params: &Params, delta: C<f64>) -> Self {
let encoder = Encoder::new(params.ring.n, delta);
pub fn new(param: &Param, delta: C<f64>) -> Self {
let encoder = Encoder::new(param.ring.n, delta);
Self {
params: params.clone(),
param: param.clone(),
encoder,
}
}
/// generate a new key pair (privK, pubK)
pub fn new_key(&self, mut rng: impl Rng) -> Result<(SecretKey, PublicKey)> {
let params = &self.params;
let param = &self.param;
let Xi_key = Uniform::new(-1_f64, 1_f64);
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
let e = Rq::rand_f64(&mut rng, Xi_err, &params.ring)?;
let e = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
let mut s = Rq::rand_f64(&mut rng, Xi_key, &params.ring)?;
let mut s = Rq::rand_f64(&mut rng, Xi_key, &param.ring)?;
// since s is going to be multiplied by other Rq elements, already
// compute its NTT
s.compute_evals();
let a = Rq::rand_f64(&mut rng, Xi_key, &params.ring)?;
let a = Rq::rand_f64(&mut rng, Xi_key, &param.ring)?;
let pk: PublicKey = PublicKey((&(-a.clone()) * &s) + e, a.clone()); // TODO rm clones
Ok((SecretKey(s), pk))
@@ -69,17 +69,17 @@ impl CKKS {
pk: &PublicKey,
m: &R,
) -> Result<(Rq, Rq)> {
let params = self.params;
let param = self.param;
let Xi_key = Uniform::new(-1_f64, 1_f64);
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
let e_0 = Rq::rand_f64(&mut rng, Xi_err, &params.ring)?;
let e_1 = Rq::rand_f64(&mut rng, Xi_err, &params.ring)?;
let e_0 = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
let e_1 = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
let v = Rq::rand_f64(&mut rng, Xi_key, &params.ring)?;
let v = Rq::rand_f64(&mut rng, Xi_key, &param.ring)?;
// let m: Rq = Rq::from(*m);
let m: Rq = m.clone().to_rq(params.ring.q); // TODO rm clone
let m: Rq = m.clone().to_rq(param.ring.q); // TODO rm clone
Ok((m + e_0 + &v * &pk.0.clone(), &v * &pk.1 + e_1))
}
@@ -127,7 +127,7 @@ mod tests {
let q: u64 = 2u64.pow(16) + 1;
let n: usize = 32;
let t: u64 = 50;
let params = Params {
let param = Param {
ring: RingParam { q, n },
t,
};
@@ -137,12 +137,12 @@ mod tests {
let mut rng = rand::thread_rng();
for _ in 0..1000 {
let ckks = CKKS::new(&params, scale_factor);
let ckks = CKKS::new(&param, scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?;
let m_raw: R =
Rq::rand_f64(&mut rng, Uniform::new(0_f64, t as f64), &params.ring)?.to_r();
Rq::rand_f64(&mut rng, Uniform::new(0_f64, t as f64), &param.ring)?.to_r();
let m = &m_raw * &scale_factor_u64;
let ct = ckks.encrypt(&mut rng, &pk, &m)?;
@@ -153,7 +153,7 @@ mod tests {
.iter()
.map(|e| (*e as f64 / (scale_factor_u64 as f64)).round() as u64)
.collect();
let m_decrypted = Rq::from_vec_u64(&params.ring, m_decrypted);
let m_decrypted = Rq::from_vec_u64(&param.ring, m_decrypted);
// assert_eq!(m_decrypted, Rq::from(m_raw));
assert_eq!(m_decrypted, m_raw.to_rq(q));
}
@@ -166,7 +166,7 @@ mod tests {
let q: u64 = 2u64.pow(16) + 1;
let n: usize = 16;
let t: u64 = 8;
let params = Params {
let param = Param {
ring: RingParam { q, n },
t,
};
@@ -175,7 +175,7 @@ mod tests {
let mut rng = rand::thread_rng();
for _ in 0..1000 {
let ckks = CKKS::new(&params, scale_factor);
let ckks = CKKS::new(&param, scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?;
let z: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, t))
@@ -215,7 +215,7 @@ mod tests {
let q: u64 = 2u64.pow(16) + 1;
let n: usize = 16;
let t: u64 = 8;
let params = Params {
let param = Param {
ring: RingParam { q, n },
t,
};
@@ -224,7 +224,7 @@ mod tests {
let mut rng = rand::thread_rng();
for _ in 0..1000 {
let ckks = CKKS::new(&params, scale_factor);
let ckks = CKKS::new(&param, scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?;
@@ -261,8 +261,8 @@ mod tests {
fn test_sub() -> Result<()> {
let q: u64 = 2u64.pow(16) + 1;
let n: usize = 16;
let t: u64 = 4;
let params = Params {
let t: u64 = 2;
let param = Param {
ring: RingParam { q, n },
t,
};
@@ -271,7 +271,7 @@ mod tests {
let mut rng = rand::thread_rng();
for _ in 0..1000 {
let ckks = CKKS::new(&params, scale_factor);
let ckks = CKKS::new(&param, scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?;