diff --git a/bfv/src/lib.rs b/bfv/src/lib.rs index 7e2fe54..bde6e40 100644 --- a/bfv/src/lib.rs +++ b/bfv/src/lib.rs @@ -105,7 +105,7 @@ impl BFV { let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; // secret key - // let s = Rq::::rand_f64(&mut rng, Xi_key)?; + // let mut s = Rq::::rand_f64(&mut rng, Xi_key)?; let mut s = Rq::::rand_u64(&mut rng, Xi_key)?; // since s is going to be multiplied by other Rq elements, already // compute its NTT diff --git a/ckks/src/lib.rs b/ckks/src/lib.rs index 524be92..2e36bce 100644 --- a/ckks/src/lib.rs +++ b/ckks/src/lib.rs @@ -1,10 +1,183 @@ -//! Implementation of BFV https://eprint.iacr.org/2012/144.pdf +//! Implementation of CKKS https://eprint.iacr.org/2016/421.pdf #![allow(non_snake_case)] #![allow(non_upper_case_globals)] #![allow(non_camel_case_types)] #![allow(clippy::upper_case_acronyms)] #![allow(dead_code)] // TMP -pub mod encoder; +use arith::{Rq, C, R}; + +use anyhow::Result; +use rand::Rng; +use rand_distr::{Normal, Uniform}; +pub mod encoder; pub use encoder::Encoder; + +// error deviation for the Gaussian(Normal) distribution +// sigma=3.2 from: https://eprint.iacr.org/2016/421.pdf page 17 +const ERR_SIGMA: f64 = 3.2; + +#[derive(Debug)] +pub struct PublicKey(Rq, Rq); + +pub struct SecretKey(Rq); + +pub struct CKKS { + encoder: Encoder, +} + +impl CKKS { + pub fn new(delta: C) -> Self { + let encoder = Encoder::::new(delta); + Self { encoder } + } + /// generate a new key pair (privK, pubK) + pub fn new_key(&self, mut rng: impl Rng) -> Result<(SecretKey, PublicKey)> { + let Xi_key = Uniform::new(-1_f64, 1_f64); + let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; + + let e = Rq::::rand_f64(&mut rng, Xi_err)?; + + let mut s = Rq::::rand_f64(&mut rng, Xi_key)?; + // since s is going to be multiplied by other Rq elements, already + // compute its NTT + s.compute_evals(); + + let a = Rq::::rand_f64(&mut rng, Xi_key)?; + + let pk: PublicKey = PublicKey((&(-a) * &s) + e, a.clone()); + Ok((SecretKey(s), pk)) + } + + // encrypts a plaintext \in R=Z_Q[X]/(X^N+1) + fn encrypt( + &self, // TODO maybe rm? + mut rng: impl Rng, + pk: &PublicKey, + m: &R, + ) -> Result<(Rq, Rq)> { + let Xi_key = Uniform::new(-1_f64, 1_f64); + let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; + + let e_0 = Rq::::rand_f64(&mut rng, Xi_err)?; + let e_1 = Rq::::rand_f64(&mut rng, Xi_err)?; + + let v = Rq::::rand_f64(&mut rng, Xi_key)?; + + let m: Rq = Rq::::from(*m); + + Ok((m + e_0 + v * pk.0.clone(), v * pk.1.clone() + e_1)) + } + + fn decrypt( + &self, // TODO maybe rm? + sk: SecretKey, + c: (Rq, Rq), + ) -> Result> { + let m = c.0.clone() + c.1 * sk.0; + Ok(m.mod_centered_q()) + } + + pub fn encode_and_encrypt( + &self, + mut rng: impl Rng, + pk: &PublicKey, + z: &[C], + ) -> Result<(Rq, Rq)> { + let m: R = self.encoder.encode(&z)?; // polynomial (encoded vec) \in R + + self.encrypt(&mut rng, pk, &m) + } + + pub fn decrypt_and_decode( + &self, + sk: SecretKey, + c: (Rq, Rq), + ) -> Result>> { + let d = self.decrypt(sk, c)?; + + self.encoder.decode(&d) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_encrypt_decrypt() -> Result<()> { + const Q: u64 = 2u64.pow(16) + 1; + const T: u64 = 16; + const N: usize = 8; + let scale_factor_u64 = 512_u64; // delta + let scale_factor = C::::new(512.0, 0.0); // delta + + let mut rng = rand::thread_rng(); + + for _ in 0..1000 { + let ckks = CKKS::::new(scale_factor); + + let (sk, pk) = ckks.new_key(&mut rng)?; + + let m_raw: R = Rq::::rand_f64(&mut rng, Uniform::new(0_f64, T as f64))?.to_r(); + let m = m_raw * scale_factor_u64; + + let ct = ckks.encrypt(&mut rng, &pk, &m)?; + let m_decrypted = ckks.decrypt(sk, ct)?; + + let m_decrypted: Vec = m_decrypted + .coeffs() + .iter() + .map(|e| (*e as f64 / (scale_factor_u64 as f64)).round() as u64) + .collect(); + let m_decrypted = Rq::::from_vec_u64(m_decrypted); + assert_eq!(m_decrypted, Rq::::from(m_raw)); + } + + Ok(()) + } + + #[test] + fn test_encode_encrypt_decrypt_decode() -> Result<()> { + const Q: u64 = 2u64.pow(16) + 1; + const T: u64 = 16; + const N: usize = 4; + let scale_factor = C::::new(512.0, 0.0); // delta + + let mut rng = rand::thread_rng(); + + for _ in 0..1000 { + let ckks = CKKS::::new(scale_factor); + let (sk, pk) = ckks.new_key(&mut rng)?; + + let z: Vec> = std::iter::repeat_with(|| C::::rand(&mut rng, T)) + .take(N / 2) + .collect(); + let m: R = ckks.encoder.encode(&z)?; + + // sanity check + { + let z_decoded = ckks.encoder.decode(&m)?; + let rounded_z_decoded: Vec> = z_decoded + .iter() + .map(|c| C::::new(c.re.round(), c.im.round())) + .collect(); + assert_eq!(rounded_z_decoded, z); + } + + let ct = ckks.encrypt(&mut rng, &pk, &m)?; + let m_decrypted = ckks.decrypt(sk, ct)?; + + let z_decrypted = ckks.encoder.decode(&m_decrypted)?; + + let rounded_z_decrypted: Vec> = z_decrypted + .iter() + .map(|&c| C::::new(c.re.round(), c.im.round())) + .collect(); + assert_eq!(rounded_z_decrypted, z); + } + + Ok(()) + } +}