Browse Source

add ckks encrypt & decrypt

gfhe-over-ring-trait
arnaucube 1 month ago
parent
commit
6090116a8b
2 changed files with 176 additions and 3 deletions
  1. +1
    -1
      bfv/src/lib.rs
  2. +175
    -2
      ckks/src/lib.rs

+ 1
- 1
bfv/src/lib.rs

@ -105,7 +105,7 @@ impl BFV {
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
// secret key
// let s = Rq::<Q, N>::rand_f64(&mut rng, Xi_key)?;
// let mut s = Rq::<Q, N>::rand_f64(&mut rng, Xi_key)?;
let mut s = Rq::<Q, N>::rand_u64(&mut rng, Xi_key)?;
// since s is going to be multiplied by other Rq elements, already
// compute its NTT

+ 175
- 2
ckks/src/lib.rs

@ -1,10 +1,183 @@
//! Implementation of BFV https://eprint.iacr.org/2012/144.pdf
//! Implementation of CKKS https://eprint.iacr.org/2016/421.pdf
#![allow(non_snake_case)]
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(clippy::upper_case_acronyms)]
#![allow(dead_code)] // TMP
pub mod encoder;
use arith::{Rq, C, R};
use anyhow::Result;
use rand::Rng;
use rand_distr::{Normal, Uniform};
pub mod encoder;
pub use encoder::Encoder;
// error deviation for the Gaussian(Normal) distribution
// sigma=3.2 from: https://eprint.iacr.org/2016/421.pdf page 17
const ERR_SIGMA: f64 = 3.2;
#[derive(Debug)]
pub struct PublicKey<const Q: u64, const N: usize>(Rq<Q, N>, Rq<Q, N>);
pub struct SecretKey<const Q: u64, const N: usize>(Rq<Q, N>);
pub struct CKKS<const Q: u64, const N: usize> {
encoder: Encoder<Q, N>,
}
impl<const Q: u64, const N: usize> CKKS<Q, N> {
pub fn new(delta: C<f64>) -> Self {
let encoder = Encoder::<Q, N>::new(delta);
Self { encoder }
}
/// generate a new key pair (privK, pubK)
pub fn new_key(&self, mut rng: impl Rng) -> Result<(SecretKey<Q, N>, PublicKey<Q, N>)> {
let Xi_key = Uniform::new(-1_f64, 1_f64);
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
let e = Rq::<Q, N>::rand_f64(&mut rng, Xi_err)?;
let mut s = Rq::<Q, N>::rand_f64(&mut rng, Xi_key)?;
// since s is going to be multiplied by other Rq elements, already
// compute its NTT
s.compute_evals();
let a = Rq::<Q, N>::rand_f64(&mut rng, Xi_key)?;
let pk: PublicKey<Q, N> = PublicKey((&(-a) * &s) + e, a.clone());
Ok((SecretKey(s), pk))
}
// encrypts a plaintext \in R=Z_Q[X]/(X^N+1)
fn encrypt(
&self, // TODO maybe rm?
mut rng: impl Rng,
pk: &PublicKey<Q, N>,
m: &R<N>,
) -> Result<(Rq<Q, N>, Rq<Q, N>)> {
let Xi_key = Uniform::new(-1_f64, 1_f64);
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
let e_0 = Rq::<Q, N>::rand_f64(&mut rng, Xi_err)?;
let e_1 = Rq::<Q, N>::rand_f64(&mut rng, Xi_err)?;
let v = Rq::<Q, N>::rand_f64(&mut rng, Xi_key)?;
let m: Rq<Q, N> = Rq::<Q, N>::from(*m);
Ok((m + e_0 + v * pk.0.clone(), v * pk.1.clone() + e_1))
}
fn decrypt(
&self, // TODO maybe rm?
sk: SecretKey<Q, N>,
c: (Rq<Q, N>, Rq<Q, N>),
) -> Result<R<N>> {
let m = c.0.clone() + c.1 * sk.0;
Ok(m.mod_centered_q())
}
pub fn encode_and_encrypt(
&self,
mut rng: impl Rng,
pk: &PublicKey<Q, N>,
z: &[C<f64>],
) -> Result<(Rq<Q, N>, Rq<Q, N>)> {
let m: R<N> = self.encoder.encode(&z)?; // polynomial (encoded vec) \in R
self.encrypt(&mut rng, pk, &m)
}
pub fn decrypt_and_decode(
&self,
sk: SecretKey<Q, N>,
c: (Rq<Q, N>, Rq<Q, N>),
) -> Result<Vec<C<f64>>> {
let d = self.decrypt(sk, c)?;
self.encoder.decode(&d)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_encrypt_decrypt() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const T: u64 = 16;
const N: usize = 8;
let scale_factor_u64 = 512_u64; // delta
let scale_factor = C::<f64>::new(512.0, 0.0); // delta
let mut rng = rand::thread_rng();
for _ in 0..1000 {
let ckks = CKKS::<Q, N>::new(scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?;
let m_raw: R<N> = Rq::<Q, N>::rand_f64(&mut rng, Uniform::new(0_f64, T as f64))?.to_r();
let m = m_raw * scale_factor_u64;
let ct = ckks.encrypt(&mut rng, &pk, &m)?;
let m_decrypted = ckks.decrypt(sk, ct)?;
let m_decrypted: Vec<u64> = m_decrypted
.coeffs()
.iter()
.map(|e| (*e as f64 / (scale_factor_u64 as f64)).round() as u64)
.collect();
let m_decrypted = Rq::<Q, N>::from_vec_u64(m_decrypted);
assert_eq!(m_decrypted, Rq::<Q, N>::from(m_raw));
}
Ok(())
}
#[test]
fn test_encode_encrypt_decrypt_decode() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const T: u64 = 16;
const N: usize = 4;
let scale_factor = C::<f64>::new(512.0, 0.0); // delta
let mut rng = rand::thread_rng();
for _ in 0..1000 {
let ckks = CKKS::<Q, N>::new(scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?;
let z: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, T))
.take(N / 2)
.collect();
let m: R<N> = ckks.encoder.encode(&z)?;
// sanity check
{
let z_decoded = ckks.encoder.decode(&m)?;
let rounded_z_decoded: Vec<C<f64>> = z_decoded
.iter()
.map(|c| C::<f64>::new(c.re.round(), c.im.round()))
.collect();
assert_eq!(rounded_z_decoded, z);
}
let ct = ckks.encrypt(&mut rng, &pk, &m)?;
let m_decrypted = ckks.decrypt(sk, ct)?;
let z_decrypted = ckks.encoder.decode(&m_decrypted)?;
let rounded_z_decrypted: Vec<C<f64>> = z_decrypted
.iter()
.map(|&c| C::<f64>::new(c.re.round(), c.im.round()))
.collect();
assert_eq!(rounded_z_decrypted, z);
}
Ok(())
}
}

Loading…
Cancel
Save