use anyhow::Result; use arith::{Matrix, Rq, C, R}; #[derive(Clone, Debug)] pub struct SecretKey(Rq); #[derive(Clone, Debug)] pub struct PublicKey(Rq, Rq); pub struct Encoder { n: usize, scale_factor: C, // Δ (delta) primitive: C, basis: Matrix>, basis_t: Matrix>, // transposed basis } /// returns the mitive root of unity fn primitive_root_of_unity(m: usize) -> C { let pi = C::::from(std::f64::consts::PI); ((C::::from(2f64) * pi * C::::i()) / C::::new(m as f64, 0f64)).exp() } /// where 'w' is 'omega', the primitive root of unity fn vandermonde(n: usize, w: C) -> Matrix> { let mut v: Vec>> = vec![]; for i in 0..n { let root = w.pow(2 * i as u32 + 1); let mut row: Vec> = vec![]; for j in 0..n { row.push(root.pow(j as u32)); } v.push(row); } Matrix::>(v) } impl Encoder { pub fn new(n: usize, scale_factor: C) -> Self { let primitive: C = primitive_root_of_unity(2 * n); let basis = vandermonde(n, primitive); let basis_t = basis.transpose(); Self { n, scale_factor, primitive, basis, basis_t, } } /// encode as described in the CKKS paper. /// from $\mathbb{C}^{N/2} \longrightarrow \mathbb{Z_q}[X]/(X^N +1) = R$ // TODO use alg.1 from 2018-1043, // or as in 2018-1073: $f(x) = 1N (U^T.conj() m + U^T m.conj())$ pub fn encode(&self, z: &[C]) -> Result { // $pi^{-1}: \mathbb{C}^{N/2} \longrightarrow \mathbb{H}$ let expanded = self.pi_inv(z); // scale the values let scaled: Vec> = expanded.iter().map(|e| *e * self.scale_factor).collect(); // but $\mathbb{H} \neq \sigma(R)$, since $\sigma(R) \subseteq \mathbb{H}$, so we need to // discretize $\pi^{-1}(z)$ into an element of $\sigma(R)$. // discretize \pi^-1(z_projected) to \sigma(R) // project 'scaled' into \sigma(R): // get the orthogonal basis (note: that would be doing Gram-Schmidt, which is not this, but // we're fine since the basis=Vandermonde matrix which is orthogonal, so we project z to it): // $z = \sum z_i * b_i, with z_i = /||b_i||^2$ let z_projected = self .basis_t .0 .iter() .map(|b_i| { // TODO: the b_j.conj() can be precomputed at initialization (of the basis) let num: C = scaled .iter() .zip(b_i.iter()) .map(|(z_j, b_j)| *z_j * b_j.conj()) .sum::>(); let den: C = b_i.iter().map(|b_j| *b_j * b_j.conj()).sum::>(); let mut z_i = num / den; z_i.im = 0.0; // get only the real component z_i }) .collect::>>(); // V * z_projected (V: Vandermonde matrix) let discretized = self.basis.mul_vec(&z_projected)?; // sigma_inv let r = self.sigma_inv(&discretized)?; // TMP: naive round, maybe do gaussian let coeffs = r.iter().map(|e| e.re.round() as i64).collect::>(); Ok(R::from_vec(self.n, coeffs)) } pub fn decode(&self, p: &R) -> Result>> { let p: Vec> = p .coeffs() .iter() .map(|&e| C::::new(e as f64, 0_f64)) // TODO review u64 to f64 conversion overflow .collect(); let in_sigma = self.sigma(&p)?; let deescalated: Vec> = in_sigma.iter().map(|e| *e / self.scale_factor).collect(); Ok(self.pi(&deescalated)) } /// pi: \mathbb{H} \longrightarrow \mathbb{C}^{N/2} fn pi(&self, z: &[C]) -> Vec> { z[..self.n / 2].to_vec() } /// pi^{-1}: \mathbb{C}^{N/2} \longrightarrow \mathbb{H} fn pi_inv(&self, z: &[C]) -> Vec> { z.iter() .cloned() .chain(z.iter().rev().map(|z_i| z_i.conj())) .collect() } fn sigma(&self, p: &[C]) -> Result>> { // the roots of unity are already calculated in the 2nd row of the transpose of the // Vandermonde matrix used as the basis (ie. the 2nd column of the Vandermonde matrix). // let roots = &self.basis_t[1]; // // Approach 1: evaluate p at the roots of unity // let mut z = vec![]; // for root_i in roots.iter() { // z.push(eval(p, root_i)); // } // Approach 2: Vandermonde * p let z: Vec> = self.basis.mul_vec(&p.to_vec())?; // TODO check using NTT-ish (2018-1043) for the encode/decode Ok(z) } fn sigma_inv(&self, z: &Vec>) -> Result>> { // $\alpha = A^{-1} * z$ let a = self.basis.solve(z)?; Ok(a.to_vec()) } } #[cfg(test)] mod tests { use super::*; use rand::Rng; #[test] fn test_encode_decode() -> Result<()> { const Q: u64 = 1024; const N: usize = 32; let n: usize = 32; let T = 128; // WIP let mut rng = rand::rng(); for _ in 0..100 { let z: Vec> = std::iter::repeat_with(|| { C::::new(rng.random_range(0..T) as f64, rng.random_range(0..T) as f64) }) .take(N / 2) .collect(); let delta = C::::new(64.0, 0.0); // delta = scaling factor let encoder = Encoder::new(n, delta); let m: R = encoder.encode(&z)?; // polynomial (encoded vec) \in R let z_decoded = encoder.decode(&m)?; // round it to compare it to the initial value let rounded_z_decoded: Vec> = z_decoded .iter() .map(|c| C::::new(c.re.round(), c.im.round())) .collect(); assert_eq!(rounded_z_decoded, z); } Ok(()) } }